# NOTE: imports below assume a tf.keras setup; adjust them to match the project's Keras backend.
from tensorflow.keras import backend as K
from tensorflow.keras import initializers
from tensorflow.keras.layers import Layer


class RelativePositionEmbedding(Layer):
"""相对位置编码
来自论文:https://arxiv.org/abs/1803.02155
"""
def __init__(
self, input_dim, output_dim, embeddings_initializer='zeros', **kwargs
):
super(RelativePositionEmbedding, self).__init__(**kwargs)
        self.input_dim = input_dim  # number of distinct relative positions, e.g. 129 = 2 * 64 + 1
        self.output_dim = output_dim  # per-head dimension (attention_head_size), e.g. 768 / 12 = 64
self.embeddings_initializer = initializers.get(embeddings_initializer)
def build(self, input_shape):
super(RelativePositionEmbedding, self).build(input_shape)
        # Create the trainable relative position embedding table.
self.embeddings = self.add_weight(
name='embeddings',
shape=(self.input_dim, self.output_dim),
initializer=self.embeddings_initializer,
)
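        # The table holds one output_dim-sized vector per clipped relative distance,
        # i.e. input_dim rows covering distances in [-max_position, max_position]
        # with max_position = (input_dim - 1) // 2 (see compute_position_ids below).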
def call(self, inputs):
        # Look up one embedding vector per relative position id; each id has its own row,
        # e.g. relative position +1 and relative position -1 map to different vectors.
        pos_ids = self.compute_position_ids(inputs)
        # Gathering with a (q_seq_len, v_seq_len) index matrix yields an output of shape
        # (q_seq_len, v_seq_len, output_dim), shared across the batch.
        return K.gather(self.embeddings, pos_ids)
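        # Illustrative gather semantics (toy values, input_dim=3, output_dim=1):
        #   embeddings = [[e0], [e1], [e2]]
        #   pos_ids    = [[1, 2], [0, 1]]                  # shape (2, 2)
        #   K.gather(embeddings, pos_ids)
        #   -> [[[e1], [e2]], [[e0], [e1]]]                # shape (2, 2, 1)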
def compute_position_ids(self, inputs):
        q, v = inputs  # typically [x, x]: queries and values come from the same sequence
        # Compute pairwise position differences.
        # 1-D indices [0, 1, ..., q_seq_len - 1]
        q_idxs = K.arange(0, K.shape(q)[1], dtype='int32')
        # column vector of shape (q_seq_len, 1): [[0], [1], ..., [q_seq_len - 1]]
        q_idxs = K.expand_dims(q_idxs, 1)
        # 1-D indices [0, 1, ..., v_seq_len - 1]
        v_idxs = K.arange(0, K.shape(v)[1], dtype='int32')
        # row vector of shape (1, v_seq_len): [[0, 1, ..., v_seq_len - 1]]
        v_idxs = K.expand_dims(v_idxs, 0)
        # Broadcasting gives a (q_seq_len, v_seq_len) matrix of relative distances.
        pos_ids = v_idxs - q_idxs
        '''For example, with q_seq_len = v_seq_len = 10:
        [[ 0  1  2  3  4  5  6  7  8  9]
         [-1  0  1  2  3  4  5  6  7  8]
         [-2 -1  0  1  2  3  4  5  6  7]
         [-3 -2 -1  0  1  2  3  4  5  6]
         [-4 -3 -2 -1  0  1  2  3  4  5]
         [-5 -4 -3 -2 -1  0  1  2  3  4]
         [-6 -5 -4 -3 -2 -1  0  1  2  3]
         [-7 -6 -5 -4 -3 -2 -1  0  1  2]
         [-8 -7 -6 -5 -4 -3 -2 -1  0  1]
         [-9 -8 -7 -6 -5 -4 -3 -2 -1  0]]
        The relative position is simply represented by this signed offset, i.e. how many
        positions apart the two tokens are.
        '''
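        # A minimal NumPy sketch of the same difference matrix (illustrative only):
        #   import numpy as np
        #   q_idxs = np.arange(10)[:, None]   # shape (10, 1)
        #   v_idxs = np.arange(10)[None, :]   # shape (1, 10)
        #   pos_ids = v_idxs - q_idxs         # broadcasts to shape (10, 10), as shown above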
        # Post-processing: clip and shift the relative distances.
        max_position = (self.input_dim - 1) // 2  # e.g. (129 - 1) // 2 = 64
        '''
        K.clip performs an element-wise clip: values of pos_ids outside [-max_position, max_position]
        are forced to the boundary values.
        1. The assumption is that exact relative positions are not useful beyond a certain distance.
        2. Truncating the maximum distance also helps the model generalize to sequence lengths
           that never appeared during training.
        For the example above, clipping to [-4, 4] gives:
        [[ 0  1  2  3  4  4  4  4  4  4]
         [-1  0  1  2  3  4  4  4  4  4]
         [-2 -1  0  1  2  3  4  4  4  4]
         [-3 -2 -1  0  1  2  3  4  4  4]
         [-4 -3 -2 -1  0  1  2  3  4  4]
         [-4 -4 -3 -2 -1  0  1  2  3  4]
         [-4 -4 -4 -3 -2 -1  0  1  2  3]
         [-4 -4 -4 -4 -3 -2 -1  0  1  2]
         [-4 -4 -4 -4 -4 -3 -2 -1  0  1]
         [-4 -4 -4 -4 -4 -4 -3 -2 -1  0]]
        '''
        pos_ids = K.clip(pos_ids, -max_position, max_position)
        pos_ids = pos_ids + max_position  # shift to non-negative ids; shape (q_seq_len, v_seq_len)
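        # Illustrative mapping for max_position = 4: a raw distance of -7 is clipped to -4 and
        # shifted to 0; a raw distance of +9 is clipped to +4 and shifted to 8. Every id therefore
        # lands in [0, 2 * max_position], which indexes a valid row of self.embeddings.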
return pos_ids
    def compute_output_shape(self, input_shape):
        # (q_seq_len, v_seq_len, output_dim)
        return (None, None, self.output_dim)
    def compute_mask(self, inputs, mask):
        # Propagate the mask of the first (query) input.
        return mask[0]
def get_config(self):
config = {
'input_dim': self.input_dim,
'output_dim': self.output_dim,
'embeddings_initializer':
initializers.serialize(self.embeddings_initializer),
}
base_config = super(RelativePositionEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
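

# ------------------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original layer). The batch size,
# sequence length and hidden size below are assumptions; input_dim=129 and output_dim=64 mirror
# the values mentioned in the comments above. Requires the tf.keras imports at the top of the file.
# ------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    import tensorflow as tf

    x = tf.zeros((2, 6, 768))  # dummy hidden states: (batch, seq_len, hidden)
    rel_pos = RelativePositionEmbedding(input_dim=129, output_dim=64)
    # The layer is called on a [q, v] pair; for self-attention both are the same tensor.
    rel_embeddings = rel_pos([x, x])
    # Output is shared across the batch: (q_seq_len, v_seq_len, output_dim) = (6, 6, 64).
    print(rel_embeddings.shape)
    # In Shaw et al.-style attention, these per-pair vectors are typically folded into the
    # attention logits, e.g. via something like
    # tf.einsum('bhqd,qvd->bhqv', queries, rel_embeddings).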