from __future__ import division
import os
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn
from mxnet.gluon.nn import BatchNorm
from gluoncv.model_zoo.fcn import _FCNHead
from mxnet import nd
from .askc import LCNASKCFuse
from model.atac.backbone import ATACBlockV1, conv1ATAC, DynamicCell
from model.atac.convolution import LearnedCell, ChaDyReFCell, SeqDyReFCell, SK_ChaDyReFCell, \
SK_1x1DepthDyReFCell, SK_MSSpaDyReFCell, SK_SpaDyReFCell, Direct_AddCell, SKCell, \
SK_SeqDyReFCell, Sub_MSSpaDyReFCell, SK_MSSeqDyReFCell, iAAMSSpaDyReFCell
from model.atac.convolution import \
LearnedConv, ChaDyReFConv, SeqDyReFConv, SK_ChaDyReFConv, \
SK_1x1DepthDyReFConv, SK_MSSpaDyReFConv, SK_SpaDyReFConv, Direct_AddConv, SKConv, \
SK_SeqDyReFConv
# , SK_MSSeqDyReFConv
from .activation import xUnit, SpaATAC, ChaATAC, SeqATAC, MSSeqATAC, MSSeqATACAdd, \
MSSeqATACConcat, MSSeqAttentionMap, xUnitAttentionMap
from model.atac.fusion import Direct_AddFuse_Reduce, SK_MSSpaFuse, SKFuse_Reduce, LocalChaFuse, \
GlobalChaFuse, \
LocalGlobalChaFuse_Reduce, LocalLocalChaFuse_Reduce, GlobalGlobalChaFuse_Reduce, \
AYforXplusYChaFuse_Reduce, XplusAYforYChaFuse_Reduce, IASKCChaFuse_Reduce,\
GAUChaFuse_Reduce, SpaFuse_Reduce, ConcatFuse_Reduce, AXYforXplusYChaFuse_Reduce,\
BiLocalChaFuse_Reduce, BiGlobalChaFuse_Reduce, LocalGAUChaFuse_Reduce, GlobalSpaFuse,\
AsymBiLocalChaFuse_Reduce, BiSpaChaFuse_Reduce, AsymBiSpaChaFuse_Reduce, LocalSpaFuse, \
BiGlobalLocalChaFuse_Reduce
# from gluoncv.model_zoo.resnetv1b import BasicBlockV1b
from gluoncv.model_zoo.cifarresnet import CIFARBasicBlockV1
class ASKCResNetFPN(HybridBlock):
    """FPN-style segmentation network on a CIFAR-style ResNet backbone.

    A stem is followed by 3 or 4 residual stages built from
    ``CIFARBasicBlockV1``. Deeper feature maps are bilinearly resized and
    merged top-down into shallower ones by the fusion module selected with
    ``fuse_mode``. An ``_FCNHead`` on the shallowest fused map produces the
    prediction.

    Parameters
    ----------
    layers : sequence of int
        Number of residual blocks per stage. Must have 3 or 4 entries; with
        4 entries an extra stage (``layer4``) and fusion (``fuse34``) are built.
    channels : sequence of int
        ``channels[0]`` is the stem width (the stem outputs
        ``2 * int(channels[0])`` channels). ``channels[i]`` for ``i >= 1`` is
        the width of stage ``i``. Needs ``len(layers) + 1`` entries.
    fuse_mode : str
        Name of the fusion module used at every merge; see ``_fuse_layer``.
    act_dilation
        Forwarded to the fusion module only when ``fuse_mode == 'SpaFuse'``.
    classes : int, default 1
        Number of output channels of the prediction head.
    tinyFlag : bool, default False
        If True, the stem does not downsample (a single stride-1 conv and no
        max-pool), so every feature map is 4x larger than in the default mode.
    norm_layer : Block class, default BatchNorm
        Normalisation layer used in the stem and in the residual stages.
    norm_kwargs : dict or None
        Extra keyword arguments for ``norm_layer``.

    Raises
    ------
    ValueError
        If ``layers`` does not have 3 or 4 entries, or ``fuse_mode`` is unknown.
    """

    def __init__(self, layers, channels, fuse_mode, act_dilation, classes=1, tinyFlag=False,
                 norm_layer=BatchNorm, norm_kwargs=None, **kwargs):
        super(ASKCResNetFPN, self).__init__(**kwargs)
        self.layer_num = len(layers)
        if self.layer_num not in (3, 4):
            raise ValueError('layers must have 3 or 4 entries, got %d' % self.layer_num)
        self.tinyFlag = tinyFlag
        # The residual blocks accept norm_kwargs=None themselves; the stem unpacks
        # it directly, so it needs a dict.
        bn_kwargs = {} if norm_kwargs is None else norm_kwargs
        with self.name_scope():
            stem_width = int(channels[0])
            self.stem = nn.HybridSequential(prefix='stem')
            # Normalise the raw input without learned gamma/beta.
            self.stem.add(norm_layer(scale=False, center=False, **bn_kwargs))
            if tinyFlag:
                # Tiny inputs: one stride-1 conv, spatial resolution preserved.
                stem_convs = ((stem_width * 2, 1),)
            else:
                # Default: stride-2 conv here plus the max-pool below -> 1/4 resolution.
                stem_convs = ((stem_width, 2), (stem_width, 1), (stem_width * 2, 1))
            for out_channels, stride in stem_convs:
                self.stem.add(nn.Conv2D(channels=out_channels, kernel_size=3, strides=stride,
                                        padding=1, use_bias=False))
                self.stem.add(norm_layer(in_channels=out_channels, **bn_kwargs))
                self.stem.add(nn.Activation('relu'))
            if not tinyFlag:
                self.stem.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1))
            # Both stem variants end with a conv of 2 * stem_width channels.
            stem_out_channels = stem_width * 2

            # Sub-blocks are created in a fixed order (head, stages, fusions) because
            # Gluon derives parameter names from creation order.
            self.head = _FCNHead(in_channels=channels[1], channels=classes)
            self.layer1 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[0],
                                           channels=channels[1], stride=1, stage_index=1,
                                           in_channels=stem_out_channels,
                                           norm_layer=norm_layer, norm_kwargs=norm_kwargs)
            self.layer2 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[1],
                                           channels=channels[2], stride=2, stage_index=2,
                                           in_channels=channels[1],
                                           norm_layer=norm_layer, norm_kwargs=norm_kwargs)
            self.layer3 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[2],
                                           channels=channels[3], stride=2, stage_index=3,
                                           in_channels=channels[2],
                                           norm_layer=norm_layer, norm_kwargs=norm_kwargs)
            if self.layer_num == 4:
                self.layer4 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[3],
                                               channels=channels[4], stride=2, stage_index=4,
                                               in_channels=channels[3],
                                               norm_layer=norm_layer, norm_kwargs=norm_kwargs)
                # Merges layer4 output into layer3 output; built with channels[3].
                self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[3],
                                               act_dilation=act_dilation)
            self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[2],
                                           act_dilation=act_dilation)
            self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[1],
                                           act_dilation=act_dilation)

    def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0,
                    norm_layer=BatchNorm, norm_kwargs=None):
        """Build one residual stage of ``layers`` blocks.

        Only the first block uses ``stride`` and ``in_channels``. It is built
        with ``downsample=True`` whenever it changes the channel count or the
        stride. The remaining blocks map ``channels`` -> ``channels`` at stride 1.

        Returns
        -------
        nn.HybridSequential
            The stage, with prefix ``'stage<stage_index>_'``.
        """
        stage = nn.HybridSequential(prefix='stage%d_' % stage_index)
        with stage.name_scope():
            needs_downsample = (channels != in_channels) or (stride != 1)
            stage.add(block(channels, stride, needs_downsample, in_channels=in_channels,
                            prefix='', norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            for _ in range(layers - 1):
                stage.add(block(channels, 1, False, in_channels=channels, prefix='',
                                norm_layer=norm_layer, norm_kwargs=norm_kwargs))
        return stage

    def _fuse_layer(self, fuse_mode, channels, act_dilation):
        """Instantiate the fusion module named by ``fuse_mode``.

        Parameters
        ----------
        fuse_mode : str
            One of the keys of the table below, or ``'SpaFuse'``.
        channels : int
            Passed as ``channels`` to the fusion module.
        act_dilation
            Only used by ``'SpaFuse'``.

        Raises
        ------
        ValueError
            If ``fuse_mode`` is not supported.
        """
        if fuse_mode == 'SpaFuse':
            # NOTE(review): keyword spelled `act_dialtion` exactly as in the original
            # call; presumably it matches SpaFuse_Reduce's signature -- confirm
            # there before correcting the spelling.
            return SpaFuse_Reduce(channels=channels, act_dialtion=act_dilation)
        # Modes whose module takes only `channels`. LocalCha, GlobalCha, LocalSpa,
        # GlobalSpa and SK_MSSpa are not offered here.
        fuse_classes = {
            'Direct_Add': Direct_AddFuse_Reduce,
            'Concat': ConcatFuse_Reduce,
            'SK': SKFuse_Reduce,
            'LocalGlobalCha': LocalGlobalChaFuse_Reduce,
            'LocalLocalCha': LocalLocalChaFuse_Reduce,
            'GlobalGlobalCha': GlobalGlobalChaFuse_Reduce,
            'IASKCChaFuse': IASKCChaFuse_Reduce,
            'AYforXplusY': AYforXplusYChaFuse_Reduce,
            'AXYforXplusY': AXYforXplusYChaFuse_Reduce,
            'XplusAYforY': XplusAYforYChaFuse_Reduce,
            'GAU': GAUChaFuse_Reduce,
            'LocalGAU': LocalGAUChaFuse_Reduce,
            'BiLocalCha': BiLocalChaFuse_Reduce,
            'BiGlobalLocalCha': BiGlobalLocalChaFuse_Reduce,
            'AsymBiLocalCha': AsymBiLocalChaFuse_Reduce,
            'BiGlobalCha': BiGlobalChaFuse_Reduce,
            'BiSpaCha': BiSpaChaFuse_Reduce,
            'AsymBiSpaCha': AsymBiSpaChaFuse_Reduce,
        }
        fuse_cls = fuse_classes.get(fuse_mode)
        if fuse_cls is None:
            raise ValueError('Unknown fuse_mode')
        return fuse_cls(channels=channels)

    def hybrid_forward(self, F, x):
        """Run the backbone, fuse features top-down and predict.

        In the default mode the head output is resized to the input's
        (height, width). In tiny mode the head output is returned as-is; its
        input was already resized to (height, width).

        NOTE(review): this reads ``x.shape``, which is available for NDArray
        inputs; after ``hybridize()`` ``x`` would be a Symbol and this line
        will likely fail -- confirm before hybridizing.
        """
        _, _, hei, wid = x.shape
        # The default stem downsamples by 4 and the tiny stem does not, so every
        # map below is `base` times smaller than in tiny mode (sizes approximate).
        base = 1 if self.tinyFlag else 4
        x = self.stem(x)             # 1/base of input, 2*channels[0] channels
        c1 = self.layer1(x)          # 1/base,       channels[1]
        c2 = self.layer2(c1)         # 1/(2*base),   channels[2]
        out = self.layer3(c2)        # 1/(4*base),   channels[3]
        if self.layer_num == 4:
            c4 = self.layer4(out)    # 1/(8*base),   channels[4]
            c4 = F.contrib.BilinearResize2D(c4, height=hei // (4 * base),
                                            width=wid // (4 * base))
            out = self.fuse34(c4, out)
        out = F.contrib.BilinearResize2D(out, height=hei // (2 * base),
                                         width=wid // (2 * base))
        out = self.fuse23(out, c2)
        out = F.contrib.BilinearResize2D(out, height=hei // base, width=wid // base)
        out = self.fuse12(out, c1)
        pred = self.head(out)
        if self.tinyFlag:
            return pred
        # Upsample from 1/4 resolution back to the input size.
        return F.contrib.BilinearResize2D(pred, height=hei, width=wid)

    def evaluate(self, x):
        """evaluating network with inputs and targets"""
        return self.forward(x)