相对位置编码-NAZHA


  • 参考苏剑林老师的bert4keras中的代码
class RelativePositionEmbedding(Layer):
    """相对位置编码
    来自论文:https://arxiv.org/abs/1803.02155
    """

    def __init__(
            self, input_dim, output_dim, embeddings_initializer='zeros', **kwargs
    ):
        super(RelativePositionEmbedding, self).__init__(**kwargs)
        self.input_dim = input_dim  # 129
        self.output_dim = output_dim  # attention_head_size每一头的维度 768/12=64
        self.embeddings_initializer = initializers.get(embeddings_initializer)

    def build(self, input_shape):
        super(RelativePositionEmbedding, self).build(input_shape)
        # 初始化待训练的位置编码权重
        self.embeddings = self.add_weight(
            name='embeddings',
            shape=(self.input_dim, self.output_dim),
            initializer=self.embeddings_initializer,
        )

    def call(self, inputs):
        # 根据位置id获取位置编码  每一种的不一样,比如位置id=1的和位置id=-1的就不一样,读取进来,
        pos_ids = self.compute_position_ids(inputs)
        return K.gather(self.embeddings, pos_ids) # 输出的时候需要是(btz,seq_len,dim)


    def compute_position_ids(self, inputs):
        q, v = inputs  # [x, x]
        # 计算位置差
        # 一维[0,1,...,q_seq_len]
        q_idxs = K.arange(0, K.shape(q)[1], dtype='int32')
        # [[0] [1] [2] [3] ... [q_seq_len]]
        q_idxs = K.expand_dims(q_idxs, 1)
        v_idxs = K.arange(0, K.shape(v)[1], dtype='int32')
        # [[0,1,...,v_seq_len]]
        v_idxs = K.expand_dims(v_idxs, 0)
        pos_ids = v_idxs - q_idxs
        '''以q_seq_len=v_seq_len=9为例:
         [[ 0  1  2  3  4  5  6  7  8  9]
          [-1  0  1  2  3  4  5  6  7  8]
          [-2 -1  0  1  2  3  4  5  6  7]
          [-3 -2 -1  0  1  2  3  4  5  6]
          [-4 -3 -2 -1  0  1  2  3  4  5]
          [-5 -4 -3 -2 -1  0  1  2  3  4]
          [-6 -5 -4 -3 -2 -1  0  1  2  3]
          [-7 -6 -5 -4 -3 -2 -1  0  1  2]
          [-8 -7 -6 -5 -4 -3 -2 -1  0  1]
          [-9 -8 -7 -6 -5 -4 -3 -2 -1  0]]
          相对位置编码就比较简单的用这种差几位数来表示相对位置
        '''
        # 后处理操作
        max_position = (self.input_dim - 1) // 2
        '''
        K.clip:逐元素clip,将pos_ids中超出(-max_position, max_position)范围的数强制变为边界值
        1、作者假设精确的相对位置编码在超出了一定距离之后是没有必要的
        2、截断最大距离使得模型的泛化效果好,可以更好的generalize到没有在训练阶段出现过的序列长度上
        比如上面的例子中,截到(-4,4)之间为:
        [[ 0  1  2  3  4  4  4  4  4  4]
         [-1  0  1  2  3  4  4  4  4  4]
         [-2 -1  0  1  2  3  4  4  4  4]
         [-3 -2 -1  0  1  2  3  4  4  4]
         [-4 -3 -2 -1  0  1  2  3  4  4]
         [-4 -4 -3 -2 -1  0  1  2  3  4]
         [-4 -4 -4 -3 -2 -1  0  1  2  3]
         [-4 -4 -4 -4 -3 -2 -1  0  1  2]
         [-4 -4 -4 -4 -4 -3 -2 -1  0  1]
         [-4 -4 -4 -4 -4 -4 -3 -2 -1  0]]
        '''
        pos_ids = K.clip(pos_ids, -max_position,
                         max_position)
        pos_ids = pos_ids + max_position  # shape=(q_seq_lenv, v_seq_len)
        return pos_ids

    def compute_output_shape(self, input_shape):
        return (None, None, self.output_dim)

    def compute_mask(self, inputs, mask):
        return mask[0]

    def get_config(self):
        config = {
            'input_dim': self.input_dim,
            'output_dim': self.output_dim,
            'embeddings_initializer':
                initializers.serialize(self.embeddings_initializer),
        }
        base_config = super(RelativePositionEmbedding, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))