ncnnqat

ncnnqat is a quantization-aware training (QAT) package for NCNN on PyTorch.

Installation

  • Supported Platforms: Linux

  • Accelerators and GPUs: NVIDIA GPUs via CUDA driver 10.1.

  • Dependencies:

    • python >= 3.5, < 4
    • pytorch >= 1.6
    • numpy >= 1.18.1
    • onnx >= 1.7.0
    • onnx-simplifier >= 0.3.5
  • Install ncnnqat via PyPI:

    $ pip install ncnnqat  (to do)

    For now, installing from source is recommended.

  • Or install ncnnqat from the repository:

    $ git clone https://github.com/ChenShisen/ncnnqat
    $ cd ncnnqat
    $ make install

Usage

  • Merge BN weights into conv and freeze BN

    Fine-tuning from a well-trained model is recommended; in that case, call register_quantization_hook and merge_freeze_bn from the beginning of training. Otherwise, call them only after a few epochs of regular training.

    from ncnnqat import unquant_weight, merge_freeze_bn, register_quantization_hook, save_table
    ...
    ...
        for epoch in range(epoch_train):
            model.train()
            # well_epoch: the epoch at which QAT starts (0 when fine-tuning a well-trained model)
            if epoch == well_epoch:
                register_quantization_hook(model)
            if epoch >= well_epoch:
                model = merge_freeze_bn(model)  # switches BN layers to eval() mode during training
    ...
  • Unquantize weights before updating them

    ...
    ...
        model.apply(unquant_weight)  # restore the original float weights before the update
        optimizer.step()
    ...
  • Export the ONNX model and save the NCNN quantization table after training

    ...
    ...
        onnx_path = "./xxx/model.onnx"
        table_path = "./xxx/model.table"
        dummy_input = torch.randn(1, 3, img_size, img_size, device='cuda')  # img_size: the model's input resolution
        input_names = ["input"]
        output_names = ["fc"]
        torch.onnx.export(model, dummy_input, onnx_path, verbose=False, input_names=input_names, output_names=output_names)
        save_table(model, onnx_path=onnx_path, table=table_path)  # write the ncnn quantization table
    ...

    If you wrap the model with "model = nn.DataParallel(model)", torch.onnx.export cannot export it directly. Save its state_dict first, build a new model on a single GPU, load the weights into it, and then export the ONNX model.

    ...
    ...
        model_s = new_net()  # new_net: placeholder for your model's constructor (a fresh, unwrapped instance)
        model_s.cuda()
        register_quantization_hook(model_s)
        #model_s = merge_freeze_bn(model_s)
        onnx_path = "./xxx/model.onnx"
        table_path = "./xxx/model.table"
        dummy_input = torch.randn(1, 3, img_size, img_size, device='cuda')
        input_names = ["input"]
        output_names = ["fc"]
        # model is the nn.DataParallel-wrapped network: strip only the leading 'module.' from each key
        state_dict = {(k[len('module.'):] if k.startswith('module.') else k): v
                      for k, v in model.state_dict().items()}
        model_s.load_state_dict(state_dict)
        torch.onnx.export(model_s, dummy_input, onnx_path, verbose=False, input_names=input_names, output_names=output_names)
        save_table(model_s, onnx_path=onnx_path, table=table_path)
    ...
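
The BN-merge step relies on the standard folding identity. In eval mode, BN maps a conv output y to gamma * (y - mean) / sqrt(var + eps) + beta, so it can be absorbed into the conv itself: each output channel's kernel is scaled by gamma / sqrt(var + eps), and its bias becomes beta + (b - mean) * gamma / sqrt(var + eps). The plain-Python sketch below checks that arithmetic on flattened per-channel kernels. It is not ncnnqat's merge_freeze_bn; fold_bn_into_conv, conv_point and bn_point are hypothetical names used only for illustration.

```python
import math
from typing import List, Tuple


def fold_bn_into_conv(
    weight: List[List[float]],   # one flattened kernel per output channel
    bias: List[float],           # conv bias (use 0.0 if the conv has none)
    gamma: List[float],          # BN weight
    beta: List[float],           # BN bias
    running_mean: List[float],
    running_var: List[float],
    eps: float = 1e-5,
) -> Tuple[List[List[float]], List[float]]:
    """Return (weight, bias) of a conv equivalent to conv followed by eval-mode BN."""
    folded_w, folded_b = [], []
    for c, kernel in enumerate(weight):
        factor = gamma[c] / math.sqrt(running_var[c] + eps)
        folded_w.append([w * factor for w in kernel])
        folded_b.append(beta[c] + (bias[c] - running_mean[c]) * factor)
    return folded_w, folded_b


def conv_point(kernel: List[float], b: float, x: List[float]) -> float:
    """One conv output value: dot product of a kernel with an input patch, plus bias."""
    return sum(k * xi for k, xi in zip(kernel, x)) + b


def bn_point(y: float, g: float, be: float, m: float, v: float, eps: float = 1e-5) -> float:
    """Eval-mode BN applied to one value of one channel."""
    return g * (y - m) / math.sqrt(v + eps) + be


if __name__ == "__main__":
    w, b = [[1.0, -2.0], [0.5, 0.5]], [0.1, -0.3]
    gamma, beta = [1.5, 0.8], [0.2, -0.1]
    mean, var = [0.05, 0.4], [0.9, 2.0]
    fw, fb = fold_bn_into_conv(w, b, gamma, beta, mean, var)
    x = [0.7, -1.2]  # one input patch
    for c in range(2):
        ref = bn_point(conv_point(w[c], b[c], x), gamma[c], beta[c], mean[c], var[c])
        print(c, ref, conv_point(fw[c], fb[c], x))  # conv+BN and folded conv agree
```

The fold is exact only while BN uses its running statistics rather than per-batch statistics, which is why freezing BN (eval mode) goes together with merging it during QAT.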

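The quantize-then-restore order in the training loop, and what a weight line of the quantization table could contain, can be illustrated with a toy per-channel symmetric int8 scheme. Everything here is an assumption for illustration, not taken from ncnnqat's source: weights are fake-quantized with scale = 127 / max|w| per output channel, the optimizer updates the float copy, and the table stores one "<layer>_param_0 s0 s1 ..." line of per-channel weight scales (activation scales, which would appear as separate "<blob> scale" lines, are not covered). ToyQuantLayer, channel_scale, fake_quantize and table_weight_line are hypothetical names.

```python
from typing import List

QMAX = 127  # assumed symmetric int8 range [-127, 127]


def channel_scale(kernel: List[float]) -> float:
    """Per-output-channel scale mapping max|w| onto QMAX (0.0 for an all-zero channel)."""
    absmax = max(abs(w) for w in kernel)
    return QMAX / absmax if absmax > 0 else 0.0


def fake_quantize(kernel: List[float], scale: float) -> List[float]:
    """Round onto the int8 grid, clamp, and dequantize back to float."""
    if scale == 0.0:
        return list(kernel)
    return [max(-QMAX, min(QMAX, round(w * scale))) / scale for w in kernel]


class ToyQuantLayer:
    """Hypothetical stand-in for one conv layer with a weight-quantization hook."""

    def __init__(self, weight: List[List[float]]):
        self.weight = weight        # float weights, one kernel per output channel
        self._float_backup = None   # float copy while fake-quantized weights are active

    def quantize_weight(self) -> None:
        # Before forward: remember the float weights, then use fake-quantized ones.
        self._float_backup = [list(k) for k in self.weight]
        self.weight = [fake_quantize(k, channel_scale(k)) for k in self.weight]

    def restore_float_weight(self) -> None:
        # Plays the role of unquant_weight: run before the optimizer step.
        if self._float_backup is not None:
            self.weight = self._float_backup
            self._float_backup = None

    def sgd_step(self, grads: List[List[float]], lr: float) -> None:
        # Plain SGD on whatever weights are currently held (should be the float ones).
        self.weight = [[w - lr * g for w, g in zip(k, gk)]
                       for k, gk in zip(self.weight, grads)]


def table_weight_line(layer_name: str, weight: List[List[float]]) -> str:
    """One weight line in the assumed ncnn table layout: '<name>_param_0 s0 s1 ...'."""
    scales = [channel_scale(k) for k in weight]
    return layer_name + "_param_0 " + " ".join("{:f}".format(s) for s in scales)


if __name__ == "__main__":
    layer = ToyQuantLayer([[0.5, -0.25, 0.1], [0.02, 0.01, -0.04]])
    layer.quantize_weight()          # forward and backward would run here
    layer.restore_float_weight()     # like model.apply(unquant_weight)
    layer.sgd_step([[0.001, 0.0, 0.0], [0.0, 0.0, 0.0]], lr=0.1)  # like optimizer.step()
    print(layer.weight)
    print(table_weight_line("conv1", layer.weight))
```

Restoring the float weights before the update matters because a learning-rate-sized change (0.0001 in the demo) is usually far smaller than one quantization step (1 / scale, about 0.004 for the first channel), so applying it to the rounded weights would keep snapping them back to the same grid point.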
Code Examples

CIFAR-10 quantization-aware training example:

    $ python test/test_cifar10.py

Results

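  • CIFAR-10

    Results (accuracy):

    net            fp32 (onnx)   ncnnqat   ncnn aciq   ncnn kl
    mobilenet_v2   0.91          0.9066    0.9033      0.9066
    resnet18       0.94          0.93333   0.9367      0.937

    ncnn aciq and ncnn kl are ncnn's own post-training int8 quantization with the ACIQ and KL calibration methods.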
  • COCO

    ....
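
The CIFAR-10 numbers are easier to compare as accuracy drops relative to the FP32 (ONNX) baseline. The snippet below only subtracts the reported values (copied from the results table; nothing is re-measured) and expresses each drop in percentage points. RESULTS and drops_in_points are illustrative names.

```python
# Accuracy values copied from the CIFAR-10 results table.
RESULTS = {
    "mobilenet_v2": {"fp32": 0.91, "ncnnqat": 0.9066, "ncnn aciq": 0.9033, "ncnn kl": 0.9066},
    "resnet18": {"fp32": 0.94, "ncnnqat": 0.93333, "ncnn aciq": 0.9367, "ncnn kl": 0.937},
}


def drops_in_points(results):
    """Accuracy drop of each int8 variant vs. the FP32 baseline, in percentage points."""
    out = {}
    for net, row in results.items():
        base = row["fp32"]
        out[net] = {k: round((base - v) * 100, 2) for k, v in row.items() if k != "fp32"}
    return out


if __name__ == "__main__":
    for net, drops in drops_in_points(RESULTS).items():
        print(net, drops)
```

With these numbers, ncnnqat matches ncnn kl on mobilenet_v2 (0.34 points each) but trails both ncnn aciq and ncnn kl on resnet18 (0.67 points versus 0.33 and 0.30).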

Todo

....

Contributors

chenshisen
