正弦位置编码


  • 参考苏剑林老师的bert4keras中的代码
class SinusoidalPositionEmbedding(Layer):
    """Fixed sine/cosine position embedding.

    For a position ``p`` and frequency index ``i`` in
    ``[0, output_dim // 2)`` the layer computes the angle
    ``p * 10000 ** (-2 * i / output_dim)`` and emits its sine and cosine,
    interleaved along the last axis as ``[sin_0, cos_0, sin_1, cos_1, ...]``.

    Arguments:
        output_dim: size of the position embedding. It is expected to be
            even: ``output_dim // 2`` frequencies each contribute one sine
            and one cosine value, and the result is reshaped to
            ``output_dim`` channels.
        merge_mode: how the position embedding is combined with the input.
            ``'add'`` returns ``inputs + embedding``, ``'mul'`` returns
            ``inputs * embedding``, and any other value concatenates the
            embedding to the input along the last axis.
        custom_position_ids: when true, the layer is called with a pair
            ``[inputs, position_ids]`` instead of a single tensor.
    """

    def __init__(
            self, output_dim, merge_mode='add', custom_position_ids=False, **kwargs
    ):
        super(SinusoidalPositionEmbedding, self).__init__(**kwargs)
        self.output_dim = output_dim
        self.merge_mode = merge_mode
        self.custom_position_ids = custom_position_ids

    def call(self, inputs):
        """Compute the position embedding and merge it with ``inputs``.

        ``inputs`` is a tensor whose first two axes are (batch, seq_len).
        If ``custom_position_ids`` is set, ``inputs`` is instead a pair
        ``[inputs, position_ids]`` where ``position_ids`` has shape
        (batch, seq_len).
        """
        if self.custom_position_ids:
            # Unpack before reading the shape: the shape must come from the
            # embedding tensor, not from the [inputs, position_ids] pair.
            inputs, position_ids = inputs
            # The einsum below multiplies the ids with float frequencies, so
            # integer ids have to be cast first.
            if 'float' not in K.dtype(position_ids):
                position_ids = K.cast(position_ids, K.floatx())

        input_shape = K.shape(inputs)
        batch_size, seq_len = input_shape[0], input_shape[1]

        if not self.custom_position_ids:
            # Default ids [[0, 1, ..., seq_len - 1]]; the [None] adds a
            # leading batch axis of size 1.
            position_ids = K.arange(0, seq_len, dtype=K.floatx())[None]

        # Only half of the channels need a frequency because each frequency
        # yields a (sin, cos) pair.
        indices = K.arange(0, self.output_dim // 2, dtype=K.floatx())
        # Inverse frequencies: 10000 ** (-2i / output_dim).
        indices = K.pow(10000.0, -2 * indices / self.output_dim)

        # Outer product of positions and frequencies:
        # shape = (btz, seq_len, output_dim // 2)
        pos_embeddings = tf.einsum('bn,d->bnd', position_ids, indices)
        # [..., None] appends a trailing axis of size 1 (like expand_dims
        # with axis -1); concatenating on that axis pairs each sine with its
        # cosine: shape = (btz, seq_len, output_dim // 2, 2)
        pos_embeddings = K.concatenate([
            K.sin(pos_embeddings)[..., None],
            K.cos(pos_embeddings)[..., None]
        ])

        # Flatten the (frequency, sin/cos) axes into one channel axis:
        # shape = (btz, seq_len, output_dim)
        pos_embeddings = K.reshape(
            pos_embeddings, (-1, seq_len, self.output_dim)
        )

        if self.merge_mode == 'add':
            return inputs + pos_embeddings
        elif self.merge_mode == 'mul':
            return inputs * pos_embeddings
        else:
            if not self.custom_position_ids:
                # The default embedding has a batch axis of size 1; unlike
                # add/mul, concatenation does not broadcast, so repeat it
                # for every sample in the batch.
                pos_embeddings = K.tile(pos_embeddings, [batch_size, 1, 1])
            return K.concatenate([inputs, pos_embeddings])

    def compute_output_shape(self, input_shape):
        """Return the output shape for the given input shape(s).

        'add' and 'mul' keep the input shape; any other merge mode grows the
        last axis by ``output_dim``.
        """
        if self.custom_position_ids:
            # The first entry is the shape of the embedding input.
            input_shape = input_shape[0]

        if self.merge_mode in ['add', 'mul']:
            return input_shape
        else:
            return input_shape[:2] + (input_shape[2] + self.output_dim,)

    def get_config(self):
        """Return the layer configuration, including the base layer config."""
        config = {
            'output_dim': self.output_dim,
            'merge_mode': self.merge_mode,
            'custom_position_ids': self.custom_position_ids,
        }
        base_config = super(SinusoidalPositionEmbedding, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))