If you want to use the model in inference mode only, you don't need the train transform. The test transform is just a normalization, so you could probably overwrite this function with the ImageNet normalization.
The training augmentations are implemented without torch, but you could try to recreate them with torchvision transforms.
What did you execute to get this error? Maybe I could look into it a bit if you provide the code.
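For the inference-only route above, here is a minimal sketch of an ImageNet-style normalization written purely with tensor ops, so it can sit inside a wrapped module and be traced or scripted along with it. It assumes the input is a float batch of shape (N, 3, H, W) with values in [0, 1] and uses the standard ImageNet mean/std. `ImageNetNormalize` is a hypothetical name, and Nougat's own test transform may differ in details, so compare outputs against it before swapping it in.

```python
import torch
from torch import nn

# Standard ImageNet per-channel statistics (assumed to match what the model was trained with).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class ImageNetNormalize(nn.Module):
    """Hypothetical helper: normalize a float (N, 3, H, W) batch with values in [0, 1]."""

    def __init__(self) -> None:
        super().__init__()
        # Buffers follow .to(device/dtype) and are saved with the module.
        self.register_buffer("mean", torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(IMAGENET_STD).view(1, 3, 1, 1))

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        return (images - self.mean) / self.std


if __name__ == "__main__":
    norm = ImageNetNormalize()
    batch = torch.rand(2, 3, 8, 8)
    out = norm(batch)
    print(tuple(out.shape))
    # Only tensor ops are involved, so scripting the module works.
    scripted = torch.jit.script(norm)
    print(torch.allclose(scripted(batch), out))
```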
Hi, thanks for responding!
I have modified your predict.py to see if I can convert the model into TorchScript. As you mentioned, defining a wrapper like the one below, so that the preprocessing step (prepare_input) is not called on the image, does get me past the first problem regarding train_transform.
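```python
import torch


class NougatWrapper(torch.nn.Module):
    """Calls model.inference() directly, so prepare_input (and train_transform) is never invoked."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, image):
        # `image` must already be a preprocessed image tensor batch
        output = self.model.inference(image_tensors=image)
        return output
```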
So the TorchScript logic is roughly:
```python
import torch
from nougat import NougatModel

# args.checkpoint comes from predict.py's argument parser
model = NougatModel.from_pretrained(args.checkpoint).to(torch.bfloat16)
nougat_wrapper = NougatWrapper(model)
nougat_wrapper.eval()
script_model = torch.jit.script(nougat_wrapper)
```
However, this fails because many of the modules under transformers/models/mbart can't be converted:
```text
  File "/home/chophilip21/nougat/predict.py", line 188, in <module>
    main()
  File "/home/chophilip21/nougat/predict.py", line 139, in main
    script_model = torch.jit.script(nougat_wrapper)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 480, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 492, in create_script_module_impl
    method_stubs = stubs_fn(nn_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 761, in infer_methods_to_compile
    stubs.append(make_stub_from_method(nn_module, method))
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 73, in make_stub_from_method
    return make_stub(func, method_name)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 58, in make_stub
    ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/frontend.py", line 297, in get_jit_def
    return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/frontend.py", line 335, in build_def
    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/frontend.py", line 359, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/transformers/models/mbart/modeling_mbart.py", line 1703
    def forward(self, *args, **kwargs):
                              ~~~~~~~ <--- HERE
        return self.decoder(*args, **kwargs)
```
I did manage to get past the **kwargs restriction above by explicitly listing all the parameters, but unfortunately there are a whole bunch of other errors related to converting MBart.
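A toy reproduction of the **kwargs error and the workaround described above, as a sketch: the classes `Inner`, `VarArgsWrapper` and `ExplicitWrapper` and the helper `try_script` are hypothetical, not taken from Nougat or transformers. TorchScript rejects a `forward(*args, **kwargs)` signature, while the same delegation with every parameter written out and type-annotated compiles. This only clears that one error; it does nothing about the other MBart constructs that fail to convert.

```python
import torch
from torch import nn
from torch.jit.frontend import NotSupportedError


class Inner(nn.Module):
    """Hypothetical submodule standing in for the real decoder."""

    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return x * scale


class VarArgsWrapper(nn.Module):
    """Same pattern as the failing MBart wrapper: forward(*args, **kwargs)."""

    def __init__(self) -> None:
        super().__init__()
        self.inner = Inner()

    def forward(self, *args, **kwargs):
        return self.inner(*args, **kwargs)


class ExplicitWrapper(nn.Module):
    """Identical delegation, but every parameter is spelled out and typed."""

    def __init__(self) -> None:
        super().__init__()
        self.inner = Inner()

    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return self.inner(x, scale)


def try_script(module: nn.Module) -> bool:
    """Return True if torch.jit.script accepts the module, False on NotSupportedError."""
    try:
        torch.jit.script(module)
        return True
    except NotSupportedError as err:
        print("NotSupportedError:", str(err).splitlines()[0])
        return False


if __name__ == "__main__":
    print(try_script(VarArgsWrapper()))   # rejected: *args/**kwargs in forward
    print(try_script(ExplicitWrapper()))  # accepted
```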
Okay, so instead of torch.jit.script, I tried torch.jit.trace, which doesn't run into the MBart scripting errors:
```python
# a: one batch of preprocessed page images (b is not needed for tracing)
a, b = next(iter(dataloader))
nougat_wrapper = NougatWrapper(model)
nougat_wrapper.eval()
script_model = torch.jit.trace(nougat_wrapper, a)
script_model.save("nougat.pt")
```
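One caveat worth checking before relying on a traced model: torch.jit.trace runs the module once on the example input and records only the tensor operations that actually executed, so Python control flow that depends on tensor values is frozen into the trace. If the traced forward contains a decoding loop whose length depends on the input (as autoregressive generation generally does), the saved file may replay the example's loop length for every input. The toy function `double_until_large` below is hypothetical and only illustrates the effect.

```python
import warnings

import torch


def double_until_large(x: torch.Tensor) -> torch.Tensor:
    """Toy data-dependent loop: keep doubling x until its sum exceeds 10."""
    while bool(x.sum() <= 10):  # Python branch on a tensor value
        x = x * 2
    return x


def trace_loop(example: torch.Tensor):
    with warnings.catch_warnings():
        # The tracer warns that the bool() conversion may make the trace incorrect;
        # that is exactly the behavior being demonstrated, so silence it here.
        warnings.simplefilter("ignore", category=torch.jit.TracerWarning)
        return torch.jit.trace(double_until_large, example)


if __name__ == "__main__":
    traced = trace_loop(torch.tensor([1.0]))  # records four doublings (1 -> 16)
    probe = torch.tensor([6.0])
    print("eager :", double_until_large(probe).item())  # 12.0: one doubling suffices
    print("traced:", traced(probe).item())              # 96.0: replays four doublings
```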
But the output of the network is a dictionary of lists:
```python
output = {
    "predictions": list(),
    "sequences": list(),
    "repeats": list(),
    "repetitions": list(),
}
```
This raises a RuntimeError, which kind of makes sense, since part of the output is the extracted text of the input PDF document as strings.
```text
Traceback (most recent call last):
  File "/home/chophilip21/latex_training/deploy.py", line 215, in <module>
    main()
  File "/home/chophilip21/latex_training/deploy.py", line 146, in main
    script_model = torch.jit.trace(nougat_wrapper, a)
  File "/home/chophilip21/miniconda3/envs/env/lib/python3.9/site-packages/torch/jit/_trace.py", line 794, in trace
    return trace_module(
  File "/home/chophilip21/miniconda3/envs/env/lib/python3.9/site-packages/torch/jit/_trace.py", line 1056, in trace_module
    module._c._create_method_from_trace(
RuntimeError: Tracer cannot infer type of {'predictions': ["\n\n# Beyond Linear Algebra\n\nBernd Sturmfels\n\n###### Abstract\n\nOur title challenges the reader to venture beyond linear algebra in designing models and in thinking about numerical algorithms for identifying solutions. This article accompanies the author's lecture at the International Congress of Mathematicians 2022. It covers recent advances in the study of critical point equations in optimization and statistics, and it explores the role of nonlinear algebra in the study of linear PDE with constant coefficients.\n\n"], 'sequences': tensor([[ 0, 25, 15337, 7834, 14229, 221, 221, 31226, 300, 38411,
92, 1296, 221, 221, 3323, 638, 2922, 221, 221, 5302,
9928, 6320, 286, 8988, 321, 42898, 5497, 1684, 3521, 301,
13099, 1287, 312, 301, 16199, 1369, 3247, 3343, 345, 8091,
2459, 36, 732, 4190, 44919, 286, 4622, 29, 105, 25473,
434, 286, 5764, 18043, 299, 25389, 896, 4271, 243, 40,
38, 40, 40, 36, 1077, 10449, 2865, 10498, 301, 286,
740, 299, 2679, 1383, 2330, 301, 3700, 312, 5462, 34,
312, 491, 25960, 286, 1673, 299, 3802, 3521, 301, 286,
740, 299, 1684, 18923, 363, 1932, 3538, 36, 2]]), 'repeats': [None], 'repetitions': ["# Beyond Linear Algebra\n\nBernd Sturmfels\n\n###### Abstract\n\nOur title challenges the reader to venture beyond linear algebra in designing models and in thinking about numerical algorithms for identifying solutions. This article accompanies the author's lecture at the International Congress of Mathematicians 2022. It covers recent advances in the study of critical point equations in optimization and statistics, and it explores the role of nonlinear algebra in the study of linear PDE with constant coefficients."]}
:Dictionary inputs to traced functions must have consistent type. Found List[str] and Tensor
```
It's probably because postprocess is being called for output['predictions'], so maybe I need to move that step outside of the model definition. I need to dig deeper, but I'm not entirely sure the way I'm approaching this is correct. If you have any better ideas, please let me know!
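One way to restructure this, sketched under assumptions: the error says the output dict mixes List[str] and Tensor, and a traced module should return tensors only. So the traced wrapper could return just output['sequences'] (which the error shows is a tensor), with text decoding and postprocessing done in plain Python afterwards. `ToyOCRModel`, `TensorOnlyWrapper`, `decode` and the toy vocabulary are hypothetical stand-ins; in Nougat the decoding step would presumably go through the model's tokenizer plus postprocess, which this sketch does not verify. This fixes the output-type error only; it does not by itself make a generation loop trace correctly.

```python
import torch
from torch import nn


class ToyOCRModel(nn.Module):
    """Hypothetical stand-in whose inference() returns a dict mixing strings and tensors."""

    def __init__(self, vocab_size: int = 4) -> None:
        super().__init__()
        self.proj = nn.Linear(vocab_size, vocab_size)

    def inference(self, image_tensors: torch.Tensor) -> dict:
        sequences = self.proj(image_tensors).argmax(dim=-1, keepdim=True)
        return {
            "predictions": ["decoded text"],  # List[str]: the part the tracer rejects
            "sequences": sequences,  # Tensor: safe to return from a traced module
        }


class TensorOnlyWrapper(nn.Module):
    """Return only the token ids so the traced output is a plain tensor."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return self.model.inference(image_tensors=image)["sequences"]


def decode(sequences: torch.Tensor, vocab: list) -> list:
    """Map token ids back to text in plain Python, outside the traced graph."""
    return ["".join(vocab[i] for i in row) for row in sequences.tolist()]


if __name__ == "__main__":
    model = ToyOCRModel().eval()
    example = torch.rand(2, 4)  # stands in for a batch of preprocessed images
    traced = torch.jit.trace(TensorOnlyWrapper(model), example)
    print(decode(traced(example), ["a", "b", "c", "d"]))
```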