Hi Phil,
Hope you are doing great and apologies for piggy-backing on this github repository. but I thought you might be able to provide me with some pointers as you have worked extensively with the Donut model.
I've been trying to convert this Donut model to Inferentia2. I've been basing myself on the excellent script provided , inference_transformers_vision.py but getting some exceptions when running the trace model python script.
Here is the code used (trace-model.py) :
`import torch
import os
import importlib
import requests
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
chip_type = os.environ.get("CHIP_TYPE", "inf2")
print(f"Selecting chip type: {chip_type}")
if chip_type == "inf1":
import torch_neuron as neuron_lib
elif chip_type == "inf2":
import torch_neuronx as neuron_lib
batch_size = 1
sequence_length = 128
model_name = 'naver-clova-ix/donut-base-finetuned-cord-v2'
#2. LOAD PRE-TRAINED MODEL
print(f'\nLoading pre-trained model: {model_name}')
processor = DonutProcessor.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name)
#3. TOKENIZE THE INPUT
#note: if you don't include return_tensors='pt' you'll get a list of lists which is easier for exploration but you cannot feed that into a model.
#Move model to GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
#let's perform inference on an image
url = "https://media.snopes.com/2017/07/walmart-jajket.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image = image.convert("RGB")
#prepare decoder inputs
task_prompt = ""
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
pixel_values = processor(image, return_tensors="pt").pixel_values
print('\nTracing model ...')
pipeline_cores = 1
model_traced = neuron_lib.trace(model, pixel_values, compiler_workdir=f'{chip_type}-compiler-workdir')
print(' tracing completed.')
model_traced.save('./compiled-model-bs-'+str(batch_size)+'.pt')
print('\n Model Traced and Saved')
`
And the exception I've been getting is shown below:
Tracing model ...
Traceback (most recent call last):
File "/trace-model/trace-model.py", line 47, in
model_traced = neuron_lib.trace(model, tuple(pixel_values), tuple(decoder_input_ids), compiler_workdir=f'{chip_type}-compiler-workdir')
File "/opt/conda/lib/python3.10/site-packages/torch_neuronx/xla_impl/trace.py", line 289, in trace
neff_filename, metaneff, flattener, packer = _trace(
File "/opt/conda/lib/python3.10/site-packages/torch_neuronx/xla_impl/trace.py", line 326, in _trace
hlo, input_parameter_names, constant_parameter_tensors, flattener, packer = xla_trace(
File "/opt/conda/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py", line 94, in xla_trace
outputs = func(*example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py", line 581, in forward
encoder_outputs = self.encoder(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/donut/modeling_donut_swin.py", line 934, in forward
embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/donut/modeling_donut_swin.py", line 177, in forward
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/donut/modeling_donut_swin.py", line 228, in forward
_, num_channels, height, width = pixel_values.shape
ValueError: not enough values to unpack (expected 4, got 3)
Please let me know what I might be missing here. thanks.