绝对位置编码-BERT


  • 参考苏剑林老师的bert4keras中的代码
from keras.layers import Layer
import keras.backend as K
from keras import initializers
import tensorflow as tf
class PositionEmbedding(Layer):
    """Trainable absolute position embedding.

    The layer owns an embedding table of shape ``(input_dim, output_dim)``
    and merges the rows selected for each position into the incoming
    features.

    Args:
        input_dim: Number of rows in the position table (max position).
        output_dim: Width of every position vector (embedding size).
        merge_mode: ``'add'`` sums features and position vectors, ``'mul'``
            multiplies them element-wise; any other value concatenates the
            position vectors to the features.
        hierarchical: Falsy disables hierarchical decomposition. ``True``
            uses ``alpha = 0.4``; any other truthy value is used as ``alpha``
            directly. When enabled, each position id is looked up twice, via
            ``id // input_dim`` and ``id % input_dim``, and the two rows are
            blended with weights ``alpha`` and ``1 - alpha``.
        embeddings_initializer: Initializer (name or object) for the table;
            resolved through ``initializers.get``.
        custom_position_ids: When true, the layer is called with a pair
            ``[inputs, position_ids]`` instead of a single tensor.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(
            self,
            input_dim,
            output_dim,
            merge_mode='add',
            hierarchical=None,
            embeddings_initializer='zeros',
            custom_position_ids=False,
            **kwargs
    ):
        super(PositionEmbedding, self).__init__(**kwargs)
        self.input_dim = input_dim  # number of positions in the table
        self.output_dim = output_dim  # embedding size of each position
        self.merge_mode = merge_mode  # 'add', 'mul', or anything else => concat
        self.hierarchical = hierarchical
        self.embeddings_initializer = initializers.get(embeddings_initializer)
        self.custom_position_ids = custom_position_ids

    def build(self, input_shape):
        """Create the trainable position table of shape (input_dim, output_dim)."""
        super(PositionEmbedding, self).build(input_shape)
        self.embeddings = self.add_weight(
            name='embeddings',
            shape=(self.input_dim, self.output_dim),
            initializer=self.embeddings_initializer
        )

    def call(self, inputs):
        """Merge position embeddings into ``inputs``.

        Args:
            inputs: The feature tensor, or ``[features, position_ids]`` when
                ``custom_position_ids`` is set.

        Returns:
            The features combined with the position embeddings according to
            ``merge_mode``.
        """
        if self.custom_position_ids:
            # Unpack BEFORE taking the shape: in this mode ``inputs`` arrives
            # as a [features, position_ids] pair, and a shape op applied to
            # that pair (two tensors of different rank/dtype) fails.
            inputs, position_ids = inputs
            if K.dtype(position_ids) != 'int32':
                position_ids = K.cast(position_ids, 'int32')

        # Dynamic shape of the feature tensor; the first two axes are used as
        # batch size and sequence length below.
        input_shape = K.shape(inputs)
        batch_size, seq_len = input_shape[0], input_shape[1]

        if not self.custom_position_ids:
            # Default ids 0 .. seq_len - 1, with a leading axis of size 1
            # added by ``[None]`` so they broadcast over the batch.
            position_ids = K.arange(0, seq_len, dtype='int32')[None]

        if self.hierarchical:
            alpha = 0.4 if self.hierarchical is True else self.hierarchical
            # Re-base the table on its first row, then blend a "coarse" row
            # (id // input_dim) with a "fine" row (id % input_dim).
            embeddings = self.embeddings - alpha * self.embeddings[:1]
            embeddings = embeddings / (1 - alpha)
            embeddings_x = K.gather(embeddings, position_ids // self.input_dim)
            embeddings_y = K.gather(embeddings, position_ids % self.input_dim)
            pos_embeddings = alpha * embeddings_x + (1 - alpha) * embeddings_y
        elif self.custom_position_ids:
            # Look up the rows addressed by the caller-supplied ids.
            pos_embeddings = K.gather(self.embeddings, position_ids)
        else:
            # Positions are simply 0 .. seq_len - 1, so slice the table
            # directly and add a leading axis of size 1.
            pos_embeddings = self.embeddings[None, :seq_len]

        if self.merge_mode == 'add':
            return inputs + pos_embeddings
        elif self.merge_mode == 'mul':
            # NOTE(review): with the default 'zeros' initializer this product
            # is all zeros at initialisation -- confirm that is intended.
            return inputs * pos_embeddings
        else:
            if not self.custom_position_ids:
                # The position tensor has a leading axis of size 1 here;
                # repeat it along the batch axis so it can be concatenated.
                pos_embeddings = K.tile(pos_embeddings, [batch_size, 1, 1])
            return K.concatenate([inputs, pos_embeddings])

    def compute_output_shape(self, input_shape):
        """Return the output shape for the given input shape(s).

        For 'add' and 'mul' the feature shape is returned unchanged;
        otherwise the last axis grows by ``output_dim``.
        """
        if self.custom_position_ids:
            # First entry is the shape of the feature tensor.
            input_shape = input_shape[0]

        if self.merge_mode in ['add', 'mul']:
            return input_shape
        else:
            return input_shape[:2] + (input_shape[2] + self.output_dim,)

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        config = {
            'input_dim': self.input_dim,
            'output_dim': self.output_dim,
            'merge_mode': self.merge_mode,
            'hierarchical': self.hierarchical,
            'embeddings_initializer':
                initializers.serialize(self.embeddings_initializer),
            'custom_position_ids': self.custom_position_ids,
        }
        base_config = super(PositionEmbedding, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))