卷积神经网络(CNN)的参数计算-keras


import numpy as np
import keras
from keras import layers,models
from keras.utils import plot_model

# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
truncated_length = 80  # length of every input sequence fed to the model
# Randomly initialised embedding matrix of shape (10000, 64); it is passed to
# the Embedding layer below as its initial weights.
init_embedding = np.random.random(size=(10000, 64))
embedding_trainable = True  # the word-embedding layer is trainable
filter_sizes = [3, 4, 5]  # window width of each parallel Conv1D branch
kernel_size = 128  # NOTE(review): never referenced by this script; kept only for compatibility
num_filters = 256  # number of output channels of every Conv1D branch
num_classes = 2

# Derive the embedding dimensions from the matrix itself so the model and the
# hand-computed parameter counts below can never drift apart.
vocab_size, embedding_dim = init_embedding.shape

# ---------------------------------------------------------------------------
# Model: Embedding -> parallel (Conv1D + global max pooling) -> Dropout -> Dense
# ---------------------------------------------------------------------------
inputs = layers.Input(shape=(truncated_length,))
x = layers.Embedding(vocab_size,
                     output_dim=embedding_dim,
                     weights=[init_embedding],
                     input_length=truncated_length,
                     trainable=embedding_trainable)(inputs)
cnn_list = []
for filter_size in filter_sizes:
    # One convolution branch per window width; global max pooling collapses
    # the sequence axis so every branch contributes `num_filters` features.
    _cnn = layers.Conv1D(num_filters, filter_size, activation='relu', padding='same')(x)
    _cnn = layers.GlobalMaxPooling1D()(_cnn)
    cnn_list.append(_cnn)
cnn = layers.concatenate(cnn_list, axis=1)
dropout = layers.Dropout(0.1)(cnn)
dense = layers.Dense(num_classes, activation='softmax')(dropout)
model = models.Model(inputs=[inputs], outputs=[dense])
model.compile(loss=keras.losses.CategoricalCrossentropy(), optimizer=keras.optimizers.Nadam(), metrics=['accuracy'])
model.summary()

# ---------------------------------------------------------------------------
# Hand-computed parameter counts, to be compared with the "Param #" column
# printed by model.summary() above.
# ---------------------------------------------------------------------------
# Embedding: one `embedding_dim`-sized vector per vocabulary entry.
# (The original printed 1000*64 = 64000, which did not match the 10000-row
# matrix; the correct value is 10000*64 = 640000.)
print(vocab_size * embedding_dim)  # embedding
# Conv1D: window_width * input_channels * filters weights, plus one bias per filter.
for filter_size in filter_sizes:
    print(filter_size * embedding_dim * num_filters + num_filters)  # filter_sizes = 3, 4, 5
# Dense: the concatenated branches give len(filter_sizes) * num_filters inputs;
# each output unit has one weight per input plus one bias.
dense_inputs = len(filter_sizes) * num_filters
print(dense_inputs * num_classes + num_classes)  # dense

# plot_model(model, to_file='model.png', show_shapes=True,show_dtype=False,)
Using TensorFlow backend.
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 80)           0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 80, 64)       640000      input_1[0][0]                    
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 80, 256)      49408       embedding_1[0][0]                
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, 80, 256)      65792       embedding_1[0][0]                
__________________________________________________________________________________________________
conv1d_3 (Conv1D)               (None, 80, 256)      82176       embedding_1[0][0]                
__________________________________________________________________________________________________
global_max_pooling1d_1 (GlobalM (None, 256)          0           conv1d_1[0][0]                   
__________________________________________________________________________________________________
global_max_pooling1d_2 (GlobalM (None, 256)          0           conv1d_2[0][0]                   
__________________________________________________________________________________________________
global_max_pooling1d_3 (GlobalM (None, 256)          0           conv1d_3[0][0]                   
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 768)          0           global_max_pooling1d_1[0][0]     
                                                                 global_max_pooling1d_2[0][0]     
                                                                 global_max_pooling1d_3[0][0]     
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 768)          0           concatenate_1[0][0]              
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 2)            1538        dropout_1[0][0]                  
==================================================================================================
Total params: 838,914
Trainable params: 838,914
Non-trainable params: 0
__________________________________________________________________________________________________
64000
49408
65792
82176
1538