
Comments (3)

lukas-blecher commented on September 27, 2024

If you only want to use the model for inference, you don't need the train transform. The test transform is just a normalization, so you could probably override this function with the ImageNet normalization.
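As a rough illustration of that suggestion, the normalization could be written as a small torch module so that it stays scriptable. This is only a sketch: the class name `ImageNetNormalize` is hypothetical, the input is assumed to be a float batch of shape (N, 3, H, W) scaled to [0, 1], and the mean/std values are the standard ImageNet statistics rather than values confirmed from the nougat code.

```python
import torch
from torch import nn


class ImageNetNormalize(nn.Module):
    """Hypothetical replacement for the test transform (not part of nougat)."""

    def __init__(self):
        super().__init__()
        # Standard ImageNet channel statistics, shaped to broadcast over (N, 3, H, W).
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # Assumes images are floats in [0, 1].
        return (images - self.mean) / self.std


# Pure tensor ops like these can be compiled with torch.jit.script.
normalize = torch.jit.script(ImageNetNormalize())
batch = torch.rand(2, 3, 32, 32)
normalized = normalize(batch)
```

Keeping the statistics in buffers means they move with the module when it is sent to another device or dtype.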

The training augmentations are implemented without torch support, but you could try to recreate them with torchvision transforms.

What did you execute to get this error? Maybe I could look into it a bit if you provide the code.


chophilip21 commented on September 27, 2024

Hi, thanks for responding!

I have modified your predict.py to see if I can convert the model to TorchScript. As you mentioned, defining a wrapper like the one below, so that the preprocessing script (prepare_input) is not called on the image, gets me past the first problem with train_transform.

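import torch


class NougatWrapper(torch.nn.Module):
    """Expose only model.inference so that prepare_input is never called."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, image):
        # image is expected to be an already preprocessed tensor
        output = self.model.inference(image_tensors=image)
        return output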

So the TorchScript logic is something like:

model = NougatModel.from_pretrained(args.checkpoint).to(torch.bfloat16)
nougat_wrapper = NougatWrapper(model)
nougat_wrapper.eval()
script_model = torch.jit.script(nougat_wrapper)

However, the above fails because many of the modules under transformers/models/mbart are not convertible.

File "/home/chophilip21/nougat/predict.py", line 188, in <module>
    main()
  File "/home/chophilip21/nougat/predict.py", line 139, in main
    script_model = torch.jit.script(nougat_wrapper)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 480, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 492, in create_script_module_impl
    method_stubs = stubs_fn(nn_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 761, in infer_methods_to_compile
    stubs.append(make_stub_from_method(nn_module, method))
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 73, in make_stub_from_method
    return make_stub(func, method_name)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 58, in make_stub
    ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/frontend.py", line 297, in get_jit_def
    return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/frontend.py", line 335, in build_def
    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/frontend.py", line 359, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/transformers/models/mbart/modeling_mbart.py", line 1703
    def forward(self, *args, **kwargs):
                              ~~~~~~~ <--- HERE
        return self.decoder(*args, **kwargs)

I did manage to get past the **kwargs restriction above by manually spelling out all the parameters, but unfortunately there are a whole bunch of other errors related to converting mbart.
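The restriction itself can be reproduced on a toy module, independent of mbart. The sketch below is only an illustration: `Inner`, `VarArgsWrapper`, and `ExplicitWrapper` are hypothetical names, and the explicit signature simply mirrors the toy inner module rather than the real decoder's parameters.

```python
import torch
from torch import nn


class Inner(nn.Module):
    """Toy stand-in for a decoder (hypothetical, not the mbart module)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return self.linear(x) * scale


class VarArgsWrapper(nn.Module):
    """Same pattern as the failing forward(self, *args, **kwargs)."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, *args, **kwargs):
        return self.inner(*args, **kwargs)


class ExplicitWrapper(nn.Module):
    """Workaround: list every parameter explicitly, with type annotations."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return self.inner(x, scale)


try:
    torch.jit.script(VarArgsWrapper(Inner()))
    varargs_scripted = True
except Exception:
    # The script frontend rejects *args/**kwargs in compiled signatures.
    varargs_scripted = False

scripted = torch.jit.script(ExplicitWrapper(Inner()))
result = scripted(torch.ones(2, 4), 0.5)
```

Spelling out the signature only removes this one error; anything else in the wrapped module that TorchScript cannot compile will still surface afterwards.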


chophilip21 commented on September 27, 2024

Okay, so instead of torch.jit.script, tracing gets past the mbart errors.

a, b = next(iter(dataloader))
nougat_wrapper = NougatWrapper(model)
nougat_wrapper.eval()
script_model = torch.jit.trace(nougat_wrapper, a)
script_model.save("nougat.pt")

But I can see that the output of the network is a dictionary of lists:

output = {
          "predictions": list(),
          "sequences": list(),
          "repeats": list(),
          "repetitions": list(),
      }

This causes a RuntimeError, which kind of makes sense: as you can see below, the predictions entry is the string output for the input PDF document.

Traceback (most recent call last):
  File "/home/chophilip21/latex_training/deploy.py", line 215, in <module>
    main()
  File "/home/chophilip21/latex_training/deploy.py", line 146, in main
    script_model = torch.jit.trace(nougat_wrapper, a)
  File "/home/chophilip21/miniconda3/envs/env/lib/python3.9/site-packages/torch/jit/_trace.py", line 794, in trace
    return trace_module(
  File "/home/chophilip21/miniconda3/envs/env/lib/python3.9/site-packages/torch/jit/_trace.py", line 1056, in trace_module
    module._c._create_method_from_trace(
RuntimeError: Tracer cannot infer type of {'predictions': ["\n\n# Beyond Linear Algebra\n\nBernd Sturmfels\n\n###### Abstract\n\nOur title challenges the reader to venture beyond linear algebra in designing models and in thinking about numerical algorithms for identifying solutions. This article accompanies the author's lecture at the International Congress of Mathematicians 2022. It covers recent advances in the study of critical point equations in optimization and statistics, and it explores the role of nonlinear algebra in the study of linear PDE with constant coefficients.\n\n"], 'sequences': tensor([[    0,    25, 15337,  7834, 14229,   221,   221, 31226,   300, 38411,
            92,  1296,   221,   221,  3323,   638,  2922,   221,   221,  5302,
          9928,  6320,   286,  8988,   321, 42898,  5497,  1684,  3521,   301,
         13099,  1287,   312,   301, 16199,  1369,  3247,  3343,   345,  8091,
          2459,    36,   732,  4190, 44919,   286,  4622,    29,   105, 25473,
           434,   286,  5764, 18043,   299, 25389,   896,  4271,   243,    40,
            38,    40,    40,    36,  1077, 10449,  2865, 10498,   301,   286,
           740,   299,  2679,  1383,  2330,   301,  3700,   312,  5462,    34,
           312,   491, 25960,   286,  1673,   299,  3802,  3521,   301,   286,
           740,   299,  1684, 18923,   363,  1932,  3538,    36,     2]]), 'repeats': [None], 'repetitions': ["# Beyond Linear Algebra\n\nBernd Sturmfels\n\n###### Abstract\n\nOur title challenges the reader to venture beyond linear algebra in designing models and in thinking about numerical algorithms for identifying solutions. This article accompanies the author's lecture at the International Congress of Mathematicians 2022. It covers recent advances in the study of critical point equations in optimization and statistics, and it explores the role of nonlinear algebra in the study of linear PDE with constant coefficients."]}
:Dictionary inputs to traced functions must have consistent type. Found List[str] and Tensor

It's probably because postprocess is called on output['predictions'], so maybe I need to move that outside of the model definition. I need to dig deeper into this, and I'm not entirely sure my approach is correct. If you have any better ideas, please let me know!
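One possible direction, sketched on a toy model rather than on nougat itself, is to have the wrapper return only the tensor entry of the dictionary and to do the string decoding outside the traced module. `ToyModel` and `TensorOnlyWrapper` are hypothetical names; the only thing taken from the output above is that 'sequences' is a tensor while the other entries are Python lists.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in that mimics the mixed-type output dictionary."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 5)

    def inference(self, image_tensors):
        sequences = self.proj(image_tensors).argmax(dim=-1)
        return {
            "predictions": ["decoded text"],  # List[str]
            "sequences": sequences,           # Tensor
            "repeats": [None],
            "repetitions": ["decoded text"],
        }


class TensorOnlyWrapper(nn.Module):
    """Return only the token ids so the traced output is a single tensor."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, image):
        return self.model.inference(image_tensors=image)["sequences"]


example = torch.randn(1, 3, 8)
wrapper = TensorOnlyWrapper(ToyModel()).eval()
traced = torch.jit.trace(wrapper, example)
token_ids = traced(example)  # decode to text outside the traced module
```

Note that torch.jit.trace records the operations executed for the example input, so data-dependent Python control flow, such as an autoregressive decoding loop, is fixed to whatever ran during tracing. Whether that is acceptable for the real model would need to be checked separately.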
