class SAFM(Module):
    """Multi-scale feature modulation block.

    The input is split along dimension 1 into ``n_levels`` equal chunks.
    Chunk 0 is filtered at full resolution; chunk ``i > 0`` is adaptively
    max-pooled to roughly ``1 / 2**i`` of the spatial size, filtered, and
    resized back to the input resolution with nearest-neighbour
    interpolation. The filtered chunks are concatenated, mixed by a 1x1
    convolution, passed through the activation, and multiplied element-wise
    with the original input.

    NOTE(review): ``Module``, ``ModuleList``, ``Conv2d`` and ``GELU`` are
    presumably the ``torch.nn`` classes, and the ``torch_*`` helpers the
    matching torch functions -- confirm against the file's imports.

    Args:
        dim: Number of channels (size of dimension 1) of the input.
        n_levels: Number of scales; must be positive and divide ``dim``.

    Raises:
        ValueError: If ``n_levels`` is not positive or does not divide
            ``dim`` evenly.
    """

    def __init__(self, dim, n_levels=4):
        super().__init__()
        # Fail early with a clear message: an uneven split makes ``chunk``
        # produce pieces that do not match the per-level convolutions, which
        # otherwise only surfaces as a channel-mismatch error in ``forward``.
        if n_levels < 1 or dim % n_levels != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive "
                f"n_levels ({n_levels})"
            )
        self.n_levels = n_levels
        chunk_dim = dim // n_levels

        # Spatial weighting: one 3x3 convolution per level with
        # groups == channels (one filter per channel).
        self.mfr = ModuleList([
            Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim)
            for _ in range(self.n_levels)
        ])

        # Feature aggregation: 1x1 convolution over the concatenated levels.
        self.aggr = Conv2d(dim, dim, 1, 1, 0)

        # Activation applied before modulating the input.
        self.act = GELU()

    def forward(self, x):
        """Return ``x`` modulated by its aggregated multi-scale features.

        Args:
            x: Input whose dimension 1 has ``dim`` entries and whose last two
                dimensions are spatial (assumed ``(N, C, H, W)`` -- confirm
                with callers).

        Returns:
            A tensor with the same shape as ``x``.
        """
        h, w = x.size()[-2:]
        chunks = x.chunk(self.n_levels, dim=1)

        out = []
        for level, (conv, chunk) in enumerate(zip(self.mfr, chunks)):
            if level == 0:
                # The finest level is processed at full resolution.
                out.append(conv(chunk))
                continue

            # Clamp to 1 so inputs smaller than 2**level still yield a valid
            # (non-empty) pooling target instead of a zero-sized one.
            scale = 2 ** level
            p_size = (max(h // scale, 1), max(w // scale, 1))
            s = torch_nn_adaptive_max_pool2d(chunk, p_size)
            s = conv(s)
            s = torch_nn_interpolate(s, size=(h, w), mode='nearest')
            out.append(s)

        out = self.aggr(torch_cat(out, dim=1))
        return self.act(out) * x