RefineNet for Semantic Segmentation: Principles and PyTorch Implementation (with code link)

GCME 2019-01-01 08:48

![description](article-content/f103f876-6fb9-403c-b962-d3a1b8e36fa7)


**【1】Introduction to RefineNet**

![description](article-content/6c06894b-4013-4e6c-bb21-63cf57c4adc0)

RefineNet architecture diagram

RefineNet uses long-range residual connections to effectively fuse back the information lost during downsampling, producing high-resolution predictions. It fuses coarse high-level semantic features with fine-grained low-level features. Building on residual connections and the identity-mapping idea, it can be trained end to end. Chained residual pooling lets it aggregate rich contextual information.

RefineNet achieves state-of-the-art results on several datasets. The author implemented the network in PyTorch and trained it on the NYUD-v2 dataset.

Applying a classification CNN such as VGG or ResNet directly to semantic segmentation yields feature maps downsampled by a factor of 32 through repeated convolution and pooling, losing a great deal of spatial information and producing coarse results. Deconvolution can upsample the features but cannot recover the lost low-level detail; dilated (atrous) convolution can produce high-resolution feature maps instead, but at a large computational cost.
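To make the dilated-convolution trade-off concrete, here is the standard receptive-field arithmetic (an illustration, not from the paper): dilation inflates a k×k kernel's effective extent without adding parameters, but the feature maps stay at high resolution, which is where the compute cost comes from.

```python
# Effective extent of a dilated (atrous) k x k kernel: k + (k - 1) * (d - 1).
def effective_kernel(k, dilation):
    return k + (k - 1) * (dilation - 1)

# A 3x3 kernel at increasing dilation rates covers 3, 5, 9, 17 pixels per axis.
print([effective_kernel(3, d) for d in (1, 2, 4, 8)])  # [3, 5, 9, 17]
```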

Features from every layer of the network are useful for segmentation: high-level features help with category recognition, while low-level features help produce fine boundaries. Fusing all of these features improves segmentation accuracy, and RefineNet is built on exactly this idea.


**【2】RefineNet Network Structure**

RefineNet uses ResNet as its backbone, splitting ResNet's intermediate layers into four blocks by resolution. Each block is passed through a module called a RefineNet block, and the outputs are fused one after another into a final refined feature map. Except for RefineNet-4, every RefineNet block takes two inputs, fusing feature maps of different resolutions to refine them.
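A small sketch of the resolutions involved, assuming the standard ResNet stage strides of 1/4, 1/8, 1/16, and 1/32 and a hypothetical 512×512 input:

```python
# Spatial size of each ResNet stage output feeding RefineNet-1..4.
input_size = 512
stage_strides = [4, 8, 16, 32]
stage_sizes = [input_size // s for s in stage_strides]
print(stage_sizes)  # [128, 64, 32, 16]
```

Each cascaded fusion step merges a coarse map with the next finer one, so the refined map recovers resolution stage by stage.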

![description](article-content/06c37707-8858-4fb4-9d30-6d611b362d0e)

RefineNet Block

A RefineNet block consists of a Residual Convolution Unit (RCU), Multi-Resolution Fusion (MRF), and Chained Residual Pooling (CRP), each described below. The RefineNet block code is as follows:

```python
import torch.nn as nn


class BaseRefineNetBlock(nn.Module):
    def __init__(self, features, residual_conv_unit, multi_resolution_fusion,
                 chained_residual_pool, *shapes):
        super().__init__()

        # One two-RCU stack per input path.
        for i, shape in enumerate(shapes):
            feats = shape[0]
            self.add_module(
                "rcu{}".format(i),
                nn.Sequential(
                    residual_conv_unit(feats), residual_conv_unit(feats)))

        # MRF is only needed when there is more than one input to fuse
        # (RefineNet-4 has a single input and skips it).
        if len(shapes) != 1:
            self.mrf = multi_resolution_fusion(features, *shapes)
        else:
            self.mrf = None

        self.crp = chained_residual_pool(features)
        self.output_conv = residual_conv_unit(features)

    def forward(self, *xs):
        rcu_xs = []

        for i, x in enumerate(xs):
            rcu_xs.append(self.__getattr__("rcu{}".format(i))(x))

        if self.mrf is not None:
            out = self.mrf(*rcu_xs)
        else:
            out = rcu_xs[0]

        out = self.crp(out)
        return self.output_conv(out)
```

**RCU**

The RCU is a unit lifted from residual networks: two ReLU + 3×3 convolution pairs with an identity skip connection. The RCU code is as follows:

```python
class ResidualConvUnit(nn.Module):
    def __init__(self, features):
        super().__init__()

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)
        # Identity skip connection, as in ResNet.
        return out + x
```

**MRF**

MRF first applies a convolution to each input feature map to adapt it, then upsamples the smaller maps, and finally sums all inputs element-wise. The MRF code is as follows:

```python
class MultiResolutionFusion(nn.Module):
    def __init__(self, out_feats, *shapes):
        super().__init__()

        # The largest spatial size among the inputs is the fusion target.
        _, max_h, max_w = max(shapes, key=lambda x: x[1])

        self.scale_factors = []
        for i, shape in enumerate(shapes):
            feat, h, w = shape
            if max_h % h != 0:
                raise ValueError("max_size not divisible by shape {}".format(i))

            self.scale_factors.append(max_h // h)
            self.add_module(
                "resolve{}".format(i),
                nn.Conv2d(
                    feat,
                    out_feats,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False))

    def forward(self, *xs):
        # un_pool is the repo's upsampling helper (e.g. built on
        # nn.functional.interpolate); see the full code at the GitHub link.
        output = self.resolve0(xs[0])
        if self.scale_factors[0] != 1:
            output = un_pool(output, self.scale_factors[0])

        for i, x in enumerate(xs[1:], 1):
            tmp_out = self.__getattr__("resolve{}".format(i))(x)
            if self.scale_factors[i] != 1:
                tmp_out = un_pool(tmp_out, self.scale_factors[i])
            output = output + tmp_out

        return output
```
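To make the scale-factor logic concrete, here is the same computation in plain Python, with hypothetical (channels, height, width) input shapes:

```python
# MRF upsamples every input to the largest spatial size among the inputs
# before the element-wise sum; scale factors are integer ratios of heights.
shapes = [(256, 16, 16), (256, 32, 32)]  # hypothetical (feat, h, w) per input
_, max_h, max_w = max(shapes, key=lambda s: s[1])
scale_factors = [max_h // h for _, h, w in shapes]
print(scale_factors)  # [2, 1]
```

The integer-divisibility check in `__init__` exists precisely so these factors are whole numbers.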

**CRP**

CRP first applies a ReLU, then passes the result through a chain of residual blocks, each composed of a 5×5 max-pooling layer followed by a 3×3 convolution. The CRP code is as follows:

```python
class ChainedResidualPool(nn.Module):
    def __init__(self, feats, block_count=4):
        super().__init__()

        self.block_count = block_count
        self.relu = nn.ReLU(inplace=False)
        for i in range(0, block_count):
            self.add_module(
                "block{}".format(i),
                nn.Sequential(
                    nn.MaxPool2d(kernel_size=5, stride=1, padding=2),
                    nn.Conv2d(
                        feats,
                        feats,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=False)))

    def forward(self, x):
        x = self.relu(x)
        path = x

        # Each pooled path is added back onto the running sum; these are
        # the chained residual connections.
        for i in range(self.block_count):
            path = self.__getattr__("block{}".format(i))(path)
            x = x + path

        return x
```
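A note on why chaining stride-1 pools is effective (standard receptive-field arithmetic, an illustration rather than a claim from the paper): each additional 5×5 stride-1 pooling block widens the window context is gathered from by 4 pixels per axis, so a short chain already pools over a large region while keeping the resolution unchanged.

```python
# Effective pooling window after n chained 5x5 stride-1 max-pools:
# 1 + n * (5 - 1) pixels per axis.
def effective_window(block_count, kernel=5):
    return 1 + block_count * (kernel - 1)

# With the default block_count=4, CRP pools context from a 17x17 window.
print([effective_window(n) for n in range(1, 5)])  # [5, 9, 13, 17]
```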
Please follow the GitHub link in the article for the full RefineNet network definition and the complete training, testing, and prediction code.