class SinusoidalPositionEmbedding(Layer):
    """Sinusoidal (sin-cos) position embedding layer.

    Builds fixed position encodings (the layer creates no weights) and merges
    them with the input according to ``merge_mode``:

    - ``'add'``: element-wise sum of the input and the encodings;
    - ``'mul'``: element-wise product of the input and the encodings;
    - any other value: the encodings are concatenated to the input along
      the last axis.

    ``output_dim`` should be even: ``output_dim // 2`` frequencies are
    generated, each contributing one sin and one cos value, and the result is
    reshaped to a last axis of size ``output_dim``.
    """

    def __init__(
        self, output_dim, merge_mode='add', custom_position_ids=False, **kwargs
    ):
        super(SinusoidalPositionEmbedding, self).__init__(**kwargs)
        self.output_dim = output_dim
        self.merge_mode = merge_mode
        self.custom_position_ids = custom_position_ids

    def call(self, inputs):
        """Compute the position encodings and merge them into the input.

        If ``custom_position_ids`` is set, ``inputs`` must be a pair
        ``[inputs, position_ids]`` whose second element holds the custom
        position ids; otherwise positions ``0 .. seq_len - 1`` are generated
        from the second axis of ``inputs``.
        """
        if self.custom_position_ids:
            # Unpack BEFORE querying the shape: here ``inputs`` arrives as a
            # [tensor, position_ids] pair, and the shape must be read from
            # the tensor itself rather than from the pair.
            inputs, position_ids = inputs
            # Custom ids may be integers, while the frequencies below are
            # K.floatx(); tf.einsum expects operands of the same dtype.
            if K.dtype(position_ids) != K.floatx():
                position_ids = K.cast(position_ids, K.floatx())
        else:
            position_ids = None

        input_shape = K.shape(inputs)
        batch_size, seq_len = input_shape[0], input_shape[1]

        if position_ids is None:
            # [None] adds a leading axis: [[0, 1, ..., seq_len - 1]],
            # i.e. shape (1, seq_len).
            position_ids = K.arange(0, seq_len, dtype=K.floatx())[None]

        # Only half of the channels are indexed (i = 0 .. output_dim//2 - 1);
        # each index yields a (sin, cos) pair below.
        indices = K.arange(0, self.output_dim // 2, dtype=K.floatx())
        # Frequencies 10000^(-2i / output_dim).
        indices = K.pow(10000.0, -2 * indices / self.output_dim)
        # Outer product position * frequency: shape (b, seq_len, output_dim//2),
        # where b is 1 for the generated ids.
        pos_embeddings = tf.einsum('bn,d->bnd', position_ids, indices)
        # [..., None] appends a trailing axis of size 1 (same effect as
        # K.expand_dims(x, -1)); concatenating on the last axis then pairs
        # sin and cos per frequency: shape (b, seq_len, output_dim//2, 2).
        pos_embeddings = K.concatenate([
            K.sin(pos_embeddings)[..., None],
            K.cos(pos_embeddings)[..., None]
        ])
        # Flatten the (frequency, sin/cos) axes so the last axis interleaves
        # sin, cos, sin, cos, ...: shape (b, seq_len, output_dim).
        pos_embeddings = K.reshape(
            pos_embeddings, (-1, seq_len, self.output_dim)
        )

        if self.merge_mode == 'add':
            return inputs + pos_embeddings
        elif self.merge_mode == 'mul':
            return inputs * pos_embeddings
        else:
            if not self.custom_position_ids:
                # Generated encodings have a batch axis of 1; concatenation
                # does not broadcast, so repeat them for every sample.
                pos_embeddings = K.tile(pos_embeddings, [batch_size, 1, 1])
            return K.concatenate([inputs, pos_embeddings])

    def compute_output_shape(self, input_shape):
        """Return the output shape for the given input shape(s).

        With ``custom_position_ids`` the first entry of ``input_shape`` is
        used. ``'add'`` and ``'mul'`` keep that shape; any other merge mode
        grows the last axis by ``output_dim``.
        """
        if self.custom_position_ids:
            input_shape = input_shape[0]
        if self.merge_mode in ['add', 'mul']:
            return input_shape
        else:
            return input_shape[:2] + (input_shape[2] + self.output_dim,)

    def get_config(self):
        """Return the layer config: the base config plus this layer's options."""
        config = super(SinusoidalPositionEmbedding, self).get_config()
        config.update({
            'output_dim': self.output_dim,
            'merge_mode': self.merge_mode,
            'custom_position_ids': self.custom_position_ids,
        })
        return config