SE-Inception v3架构的模型搭建(keras代码实现)

Mark 2018-09-28 16:47
关注文章

首先,先上SENet架构的原理图:

description

图是将SE模块嵌入到Inception结构的一个示例。方框旁边的维度信息代表该层的输出。这里我们使用global average pooling作为Squeeze操作。紧接着两个Fully Connected 层组成一个Bottleneck结构去建模通道间的相关性,并输出和输入特征同样数目的权重。我们首先将特征维度降低到输入的1/16,然后经过ReLu激活后再通过一个Fully Connected 层升回到原来的维度。 这样做比直接用一个Fully Connected层的好处在于:

  1. 具有更多的非线性,可以更好地拟合通道间复杂的相关性;
  2. 极大地减少了参数量和计算量。然后通过一个Sigmoid的门获得0~1之间归一化的权重,最后通过一个Scale的操作来将归一化后的权重加权到每个通道的特征上。

SENet架构(Squeeze And Excitation),无非就是Squeeze操作和Excitation操作

  • 首先是Squeeze操作,我们顺着空间维度来进行特征压缩,将每个二维的特征通道变成一个实数,这个实数某种程度上具有全局的感受野,并且输出的维度和输入的特征通道数相匹配。它表征着在特征通道上响应的全局分布,而且使得靠近输入的层也可以获得全局的感受野,这一点在很多任务中都是非常有用的。
  • 其次是Excitation操作,它是一个类似于循环神经网络中门的机制。通过参数 来为每个特征通道生成权重,其中参数 被学习用来显式地建模特征通道间的相关性。
  • 最后是一个Reweight的操作,我们将Excitation的输出的权重看做是进过特征选择后的每个特征通道的重要性,然后通过乘法逐通道加权到先前的特征上,完成在通道维度上的对原始特征的重标定

接下来是代码实现:

def build_model(out_dims, input_shape=(224, 224, 3)):
    inputs_dim = Input(input_shape)
    x = Lambda(lambda x: x / 255.0)(inputs_dim) #在模型里进行归一化预处理

    x = InceptionV3(include_top=False,
                weights='imagenet',
                input_tensor=None,
                input_shape=(224, 224, 3),
                pooling=max)(x)

    squeeze = GlobalAveragePooling2D()(x)

    excitation = Dense(units=2048 // 11)(squeeze)
    excitation = Activation('relu')(excitation)
    excitation = Dense(units=2048)(excitation)
    excitation = Activation('sigmoid')(excitation)
    excitation = Reshape((1, 1, 2048))(excitation)

    scale = multiply([x, excitation])

    x = GlobalAveragePooling2D()(scale)
    dp_1 = Dropout(0.6)(x)
    fc2 = Dense(out_dims)(dp_1)
    fc2 = Activation('sigmoid')(fc2) #此处注意,为sigmoid函数
    model = Model(inputs=inputs_dim, outputs=fc2)
    return model

在此声明 文中部分是非原创,本人只是为想实现SENet架构的小伙伴们搞个福利,贴了代码实现出来以便大家参考使用

{{panelTitle}}(1)
支持Markdown和数学公式,公式格式:\\(...\\)或\\[...\\]
Mark 2018-09-28 16:48

大佬些 有什么意见或者是见解可以提出来,我也好学习学习

2019-04-17 23:42

大佬你好,请问为什么在scale = multiply([x, excitation])之后还要x = GlobalAveragePooling2D()(scale)一下啊,这个操作的目的是什么呢?

关注微信公众号