
xlsr's Introduction

XLSR

PyTorch implementation of the paper "Extremely Lightweight Quantization Robust Real-Time Single-Image Super Resolution for Mobile Devices"

xlsr's People

Contributors

cxzhou95


xlsr's Issues

Error in test.py: Unexpected key(s) in state_dict

When I run test.py with the default values (and the correct model path), I get the following error. What is the problem?

Traceback (most recent call last):
  File "C:\Users\Downloads\XLSR\test.py", line 110, in <module>
    model.load_state_dict(torch.load(os.path.join(opt.save_dir, 'best.pt'), map_location=device))
  File "C:\Users\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for XLSR:
        Missing key(s) in state_dict: "Gblocks.0.conv0.conv2d_block.0.weight", "Gblocks.0.conv0.conv2d_block.0.bias", "Gblocks.0.conv0.conv2d_block.1.weight", "Gblocks.0.conv0.conv2d_block.1.bias", "Gblocks.0.conv0.conv2d_block.2.weight", "Gblocks.0.conv0.conv2d_block.2.bias", "Gblocks.0.conv0.conv2d_block.3.weight", "Gblocks.0.conv0.conv2d_block.3.bias", "Gblocks.1.conv0.conv2d_block.0.weight", "Gblocks.1.conv0.conv2d_block.0.bias", "Gblocks.1.conv0.conv2d_block.1.weight", "Gblocks.1.conv0.conv2d_block.1.bias", "Gblocks.1.conv0.conv2d_block.2.weight", "Gblocks.1.conv0.conv2d_block.2.bias", "Gblocks.1.conv0.conv2d_block.3.weight", "Gblocks.1.conv0.conv2d_block.3.bias", "Gblocks.2.conv0.conv2d_block.0.weight", "Gblocks.2.conv0.conv2d_block.0.bias", "Gblocks.2.conv0.conv2d_block.1.weight", "Gblocks.2.conv0.conv2d_block.1.bias", "Gblocks.2.conv0.conv2d_block.2.weight", "Gblocks.2.conv0.conv2d_block.2.bias", "Gblocks.2.conv0.conv2d_block.3.weight", "Gblocks.2.conv0.conv2d_block.3.bias".
        Unexpected key(s) in state_dict: "Gblocks.0.conv0.weight", "Gblocks.0.conv0.bias", "Gblocks.1.conv0.weight", "Gblocks.1.conv0.bias", "Gblocks.2.conv0.weight", "Gblocks.2.conv0.bias".
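The two key lists point to a layout mismatch: the checkpoint stores each `Gblocks.N.conv0` as one grouped `nn.Conv2d` (`conv0.weight`, `conv0.bias`), while the model being built expects the per-group `GConv2d` layout proposed in the group conv2d issue further down (`conv0.conv2d_block.K.*`, with K = 0..3, i.e. 4 groups). Assuming that is the cause, one way to bridge it is to slice each grouped weight into per-group weights before loading. The helper `split_grouped_conv_keys` below is hypothetical, a sketch rather than part of the repository.

```python
import re

import torch

# Matches the single grouped-conv parameters stored in the checkpoint, e.g. "Gblocks.0.conv0.weight".
_GROUPED_KEY = re.compile(r"^(Gblocks\.\d+\.conv0)\.(weight|bias)$")


def split_grouped_conv_keys(state_dict, groups):
    """Hypothetical helper: rewrite Gblocks.N.conv0.{weight,bias} (one grouped nn.Conv2d)
    into Gblocks.N.conv0.conv2d_block.K.{weight,bias} (one nn.Conv2d per group)."""
    converted = {}
    for key, tensor in state_dict.items():
        match = _GROUPED_KEY.match(key)
        if match is None:
            converted[key] = tensor  # unrelated keys are passed through unchanged
            continue
        prefix, kind = match.groups()
        # A grouped conv weight has shape [out, in // groups, k, k] and its output channels
        # are ordered group by group, so chunking dim 0 yields each group's filters (and biases).
        for k, part in enumerate(torch.chunk(tensor, groups, dim=0)):
            converted[f"{prefix}.conv2d_block.{k}.{kind}"] = part.clone()
    return converted
```

In test.py this would become something like `model.load_state_dict(split_grouped_conv_keys(torch.load(path, map_location=device), groups=4))`. The alternative is to switch `conv0` back to a plain grouped `nn.Conv2d` so the model matches best.pt as saved.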

Converting quantized_model.pt to ONNX

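Hello, I tried to convert quantized_model.pt to an ONNX model, but since quantized_model.pt is saved with torch.jit.save, it seems it can no longer be converted with torch.onnx.export. So I changed the saving code to torch.save(quantized_model.state_dict(), os.path.join(opt.save_dir, 'quantized_model.pt')), but I am not sure how to load this model. My loading code is below, and it currently fails (am I missing a dequantization step?):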

quantized_model = XLSR_quantization(3)
quantized_model.load_state_dict(torch.load(os.path.join(opt.save_dir, 'quantized_model.pt'), map_location=device))
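A likely reason the load fails: the state_dict of a converted (int8) model contains quantized weights plus `scale` and `zero_point` entries, whereas a freshly constructed `XLSR_quantization(3)` is still a float model with stubs, so the keys do not line up. Dequantization is not a separate step; it happens in the `DeQuantStub` inside `forward`. A minimal sketch using PyTorch's eager-mode quantization API: rebuild the converted structure with `prepare` + `convert` (same qconfig and engine as when quantizing), then load. `TinyQuantNet` is a hypothetical stand-in for `XLSR_quantization`; whether the real model also needs module fusion before `prepare` is not verified here.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import DeQuantStub, QuantStub, convert, get_default_qconfig, prepare


class TinyQuantNet(nn.Module):
    """Hypothetical stand-in for XLSR_quantization: quant -> conv -> relu -> conv -> dequant."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # float -> int8 at the input
        self.conv0 = nn.Conv2d(3, 8, 3, padding=1)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(8, 3, 1)
        self.dequant = DeQuantStub()  # int8 -> float at the output (the "dequantization")

    def forward(self, x):
        return self.dequant(self.conv1(self.relu(self.conv0(self.quant(x)))))


def pick_engine():
    """Use fbgemm (x86) when available, otherwise qnnpack (ARM)."""
    engine = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
    torch.backends.quantized.engine = engine
    return engine


def build_quantized_skeleton(float_model, engine):
    """Rebuild the converted (int8) module structure so a quantized state_dict fits it.
    Observers are not calibrated here, so PyTorch may warn; the placeholder scales and
    zero points are overwritten by load_state_dict anyway."""
    float_model.eval()
    float_model.qconfig = get_default_qconfig(engine)
    return convert(prepare(float_model))
```

Applied to the code above, the idea would be `quantized_model = build_quantized_skeleton(XLSR_quantization(3), pick_engine())` followed by the same `load_state_dict(torch.load(...))` call.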

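Output problem after quantization

Hello! After applying PTQ (post-training quantization) to the best model you provided, I tested it on my own images. The output image values are all either 0 or 255. What could be the cause? Thanks for your answer!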
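With affine quint8 quantization, a value x is stored as q = clamp(round(x / scale) + zero_point, 0, 255), so anything outside the range seen during calibration is clipped to code 0 or 255. One plausible cause of an all-0/255 output (an assumption, not a confirmed diagnosis) is a range mismatch between calibration and test inputs, for example calibrating on images scaled to [0, 1] and then feeding raw [0, 255] pixels. The sketch below uses `torch.quantize_per_tensor` to show the effect and to measure how many values hit the limits; `scale = 1/255` and `zero_point = 0` are illustrative values, and both helpers are hypothetical.

```python
import torch


def affine_quant_roundtrip(x, scale, zero_point):
    """Quantize to quint8 and back: q = clamp(round(x / scale) + zero_point, 0, 255)."""
    q = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
    return q.dequantize()


def saturated_fraction(x, scale, zero_point):
    """Fraction of elements whose quint8 code lands on the limits 0 or 255."""
    q = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
    codes = q.int_repr()
    return ((codes == 0) | (codes == 255)).float().mean().item()
```

If the converted model's input quantizer is reachable (for instance as a hypothetical `model.quant` with `scale` and `zero_point` buffers), calling `saturated_fraction` on a test image with those parameters shows whether the input is being clipped before the network even runs.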

Typo in Gblock

There is an issue in the Gblock class:

class Gblock(nn.Module):
    def __init__(self, in_channels, out_channels, groups):
        super(Gblock, self).__init__()
        self.conv0 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, groups=groups)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        
    def forward(self, x):
        x = self.conv0(x)
        x = self.relu(x)
        x = self.conv1(x)
        return x

The line
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

should be replaced with

self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0)

for a general-purpose Gblock: conv1 receives the output of conv0, which has out_channels channels, and in_channels and out_channels do not need to be the same.
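A quick shape check makes the reasoning concrete. With the original `conv1`, any block where `in_channels != out_channels` fails with a channel-mismatch `RuntimeError`; with the fix below it runs. When the two sizes are equal, both versions are identical, so the fix does not change the model in that case. The sizes used in the check (8 → 16 channels, 4 groups) are arbitrary.

```python
import torch
import torch.nn as nn


class Gblock(nn.Module):
    """Gblock with the suggested fix: conv1 consumes conv0's out_channels outputs."""

    def __init__(self, in_channels, out_channels, groups):
        super(Gblock, self).__init__()
        # grouped 3x3 conv: in_channels -> out_channels (both must be divisible by groups)
        self.conv0 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, groups=groups)
        self.relu = nn.ReLU(inplace=True)
        # pointwise 1x1 conv mixing information across groups: out_channels -> out_channels
        self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        x = self.conv0(x)
        x = self.relu(x)
        x = self.conv1(x)
        return x
```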

test problem

Hi! Thanks for your work.
I would like to know how you test on the benchmark datasets, since the image sizes after ×3 downsampling do not match those of the corresponding HR images.
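A common convention in SR benchmarking, not confirmed to be what this repository does, is to crop each HR image to a multiple of the scale factor ("modcrop") before producing or comparing the LR input, so that the ×3 output and the cropped HR image have identical sizes; PSNR is then often computed after shaving `scale` pixels from each border (many papers also evaluate on the Y channel only, which is not shown here). A minimal sketch with hypothetical helpers `modcrop` and `psnr`, assuming tensors with values in [0, 1] and the spatial size in the last two dimensions:

```python
import math

import torch


def modcrop(img, scale):
    """Crop the last two (H, W) dimensions down to the nearest multiple of `scale`."""
    h, w = img.shape[-2:]
    return img[..., : h - h % scale, : w - w % scale]


def psnr(sr, hr, scale, max_val=1.0):
    """PSNR between an SR output and its modcropped HR image, shaving `scale` border pixels."""
    hr = modcrop(hr, scale)
    if sr.shape != hr.shape:
        raise ValueError(f"shape mismatch: {tuple(sr.shape)} vs {tuple(hr.shape)}")
    sr = sr[..., scale:-scale, scale:-scale]
    hr = hr[..., scale:-scale, scale:-scale]
    mse = torch.mean((sr.float() - hr.float()) ** 2).item()
    return float("inf") if mse == 0 else 10 * math.log10(max_val ** 2 / mse)
```

For example, an HR image of 256×257 is cropped to 255×255, and an 85×85 LR input upscaled ×3 yields exactly 255×255.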

Error in loading the quantized models

Loading the regular best.pt model works, but loading the quantized models (such as QAT_quantized_model.pt) throws an error. What is the problem?

python test.py --model exp/OneCyclicLR/QAT_quantized_model.pt --device cpu

Error:

C:\Users\AppData\Local\anaconda3\Lib\site-packages\torch\serialization.py:995: UserWarning: 'torch.load' received a zip file that looks like a TorchScript archive dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to silence this warning)
  warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
Traceback (most recent call last):
  File "C:\Users\Downloads\XLSR\test.py", line 109, in <module>
    model.load_state_dict(torch.load(opt.model, map_location=device))
  File "C:\Users\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 2103, in load_state_dict
    raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")
TypeError: Expected state_dict to be dict-like, got <class 'torch.jit._script.RecursiveScriptModule'>.

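Quantization training steps

Hello, is the training procedure as follows: 1) first train the floating-point model, 2) define the quantized model, load the floating-point model weights into it, and then continue with QAT training?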
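The two steps described in the question match the usual eager-mode QAT recipe in PyTorch; whether this repository follows exactly this flow is not confirmed here. A minimal sketch, assuming the quantizable model's state_dict keys match the float model's (`QuantStub`/`DeQuantStub` carry no weights, so adding them does not change the keys). `TinySR`, `run_qat`, and the L1 loss are hypothetical choices for illustration.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import DeQuantStub, QuantStub, convert, get_default_qat_qconfig, prepare_qat


class TinySR(nn.Module):
    """Hypothetical quantizable model with quant/dequant stubs around a single conv."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))


def run_qat(float_state_dict, quant_model, batches, engine, epochs=1, lr=1e-4):
    """Step 2: load trained float weights, insert fake-quant modules, fine-tune, convert to int8."""
    model = copy.deepcopy(quant_model)
    model.load_state_dict(float_state_dict)  # weights from step 1 (keys must match)
    model.train()                            # prepare_qat expects training mode
    model.qconfig = get_default_qat_qconfig(engine)
    prepare_qat(model, inplace=True)         # fake-quantize weights and activations
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        for lr_img, hr_img in batches:
            loss = F.l1_loss(model(lr_img), hr_img)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    model.eval()
    return convert(model)                    # real int8 modules for inference
```

The optimizer is created after `prepare_qat` on purpose, so it tracks the parameters of the swapped-in QAT modules.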

Unoptimized PyTorch group conv2d block

For those who experience slow inference, and hence slow training, of the XLSR module, you can try the following custom module instead of PyTorch's built-in grouped convolution.

Replace this line in the Gblock class

self.conv0 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, groups=groups)

with

self.conv0 = GConv2d(in_channels,out_channels,kernel_size=3,groups=groups)

where GConv2d is defined as:

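import torch
import torch.nn as nn

class GConv2d(nn.Module):
    """Grouped convolution implemented as `groups` independent nn.Conv2d layers."""
    def __init__(self, in_channels, out_channels, kernel_size=3, groups=4):
        super(GConv2d, self).__init__()
        self.conv2d_block = nn.ModuleList()
        self.groups = groups
        for _ in range(groups):
            # each branch maps in_channels // groups inputs to out_channels // groups outputs
            self.conv2d_block.append(nn.Conv2d(in_channels=in_channels // groups, out_channels=out_channels // groups, kernel_size=kernel_size, padding=kernel_size // 2))

    def forward(self, x):
        # split along the channel dim, convolve each chunk with its own branch, then concatenate
        return torch.cat([filterg(xg) for filterg, xg in zip(self.conv2d_block, torch.chunk(x, self.groups, 1))], dim=1)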

Personally, I get almost a 2× speedup with this approach during training.

*The groups parameter is known to result in slow code, according to
https://github.com/pytorch/pytorch/issues/18631
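Speedups like this depend on hardware, backend (CPU or cuDNN), tensor sizes and PyTorch version, so they are worth measuring on your own setup. A rough timing sketch, assuming the GConv2d class from this issue (repeated here so the block is self-contained) and arbitrary sizes; `time_train_step` is a hypothetical helper that discards the first iterations as warm-up and synchronizes CUDA before reading the clock.

```python
import time

import torch
import torch.nn as nn


class GConv2d(nn.Module):
    """Per-group convolutions, as proposed in this issue (same state_dict keys)."""

    def __init__(self, in_channels, out_channels, kernel_size=3, groups=4):
        super().__init__()
        self.groups = groups
        self.conv2d_block = nn.ModuleList(
            nn.Conv2d(in_channels // groups, out_channels // groups, kernel_size, padding=kernel_size // 2)
            for _ in range(groups)
        )

    def forward(self, x):
        chunks = torch.chunk(x, self.groups, dim=1)
        return torch.cat([conv(xg) for conv, xg in zip(self.conv2d_block, chunks)], dim=1)


def time_train_step(module, x, iters=20, warmup=3):
    """Average seconds per forward + backward pass, excluding `warmup` iterations."""
    for i in range(warmup + iters):
        if i == warmup:
            if x.is_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
        module.zero_grad()
        module(x).sum().backward()
    if x.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters
```

A comparison would then time `nn.Conv2d(32, 32, 3, padding=1, groups=4)` against `GConv2d(32, 32, 3, groups=4)` on the same input, for example `torch.randn(8, 32, 64, 64)` moved to the training device.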

Thanks for porting this to PyTorch

Hello, this is Mustafa, the author of the XLSR paper.

I would like to thank you for your clear implementation of my paper and for sharing it with the community. If you contribute any further academic work, please share it here; I would like to read it and see how my work helped others in their own work.

Good luck
