Comments (8)

reflelia commented on June 3, 2024

In src/train_eval_ende_full.py, len(batch['input_pixels']) is 16, but len(outputs) is 1.

wjhou commented on June 3, 2024

Hi,

Based on your description, the input format is fine.

The generate() function in the Transformers library receives the main input (e.g., of shape BxS, where B is the batch size and S is the input size) and creates decoder_input_ids based on it. If the main input is set to input_pixels, the batch size of the created decoder_input_ids should be B.
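To make the mechanism concrete, here is a minimal, framework-free sketch under a simplified model of generate(): the batch size B is read from the first dimension of the main input, and one start token per example is placed in decoder_input_ids. The function names, the start token id of 2, and the nested lists standing in for tensors are all hypothetical; this is not the Transformers implementation.

```python
from typing import Any, Dict, List


def infer_batch_size(model_inputs: Dict[str, Any], main_input_name: str) -> int:
    # Simplified model: B is the length of the first dimension
    # of the main input (e.g. input_pixels).
    return len(model_inputs[main_input_name])


def init_decoder_input_ids(batch_size: int, decoder_start_token_id: int) -> List[List[int]]:
    # One start token per example, i.e. a B x 1 "tensor".
    return [[decoder_start_token_id] for _ in range(batch_size)]


# Toy batch: 16 examples, each a placeholder for an image tensor.
model_inputs = {"input_pixels": [[0.0] * 4 for _ in range(16)]}
B = infer_batch_size(model_inputs, "input_pixels")
decoder_input_ids = init_decoder_input_ids(B, decoder_start_token_id=2)  # hypothetical id
print(B, len(decoder_input_ids), len(decoder_input_ids[0]))  # 16 16 1
```

If B is inferred correctly, generation produces one output sequence per example, so len(outputs) should equal len(batch['input_pixels']).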

However, later versions of the Transformers library modified the generate() function and use a different method to initialize decoder_input_ids. So I have two guesses:

  1. If you use a later version of Transformers, you will need to specify input_ids in model_inputs, for example:
    "input_ids": torch.ones_like(batch["input_pixels"][:, :1, 0]).cuda() * model.config.bos_token_id

  2. If you are using the same version of Transformers, have you modified any code in the evaluation function?

Best,
Ethan

reflelia commented on June 3, 2024

I'm using transformers==4.15.0 (the same version) and didn't edit the evaluation function.
The outputs.shape in eval_func in train_eval_ende_full.py is torch.Size([1, 64]). Is that correct?

And if I use transformers >= 4.16, the following error occurs:
TypeError: VLBartEncoder.forward() got an unexpected keyword argument 'pixel_values'

wjhou commented on June 3, 2024

Hi,

The output size is not correct. It seems that decoder_input_ids is not initialized correctly.

Have you modified main_input_name at line 871 of src/models/modeling_bart_custom.py? generate() uses this key to obtain the batch size.
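A toy illustration of why this key matters, assuming (not verified against the transformers 4.15.0 source) that generation falls back to a batch size of 1 when it cannot find the main input among the model inputs. Under that assumption only one sequence is generated, which would be consistent with the torch.Size([1, 64]) output reported earlier. The function name and fallback rule are hypothetical.

```python
from typing import Any, Dict


def resolve_batch_size(model_inputs: Dict[str, Any], main_input_name: str) -> int:
    # Hypothetical simplified rule: use the main input if it is present,
    # otherwise fall back to generating a single sequence.
    if main_input_name in model_inputs:
        return len(model_inputs[main_input_name])
    return 1


batch = {"input_pixels": [[0.0] * 4 for _ in range(16)]}

# Key matches the batch: one decoder sequence per example.
print(resolve_batch_size(batch, "input_pixels"))  # 16
# Key does not match (e.g. left as "input_ids"): batch size collapses to 1.
print(resolve_batch_size(batch, "input_ids"))  # 1
```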

Best,
Ethan

reflelia commented on June 3, 2024

I didn't edit that. I also tried replacing the src folder with a fresh copy, but the same error occurred.
Are the input arguments below correct?
./script/run_iu_xray.sh debug checkpoint_name plan_model_name_or_path plan_eval_file
checkpoint_name : trained observation planning pytorch_model.bin
plan_model_name_or_path : prediction.json folder path
plan_eval_file : prediction_eval_step_132.json

wjhou commented on June 3, 2024

Hi,

The checkpoint_name is for the pre-trained ResNet model (similar to issue #5), not the planner. Stage 2 training does not require the trained planner, so you can remove this argument in a similar way.

Regarding decoder_input_ids, you can try adding the item "input_ids": torch.ones_like(batch["input_pixels"][:, :1, 0]).cuda() * model.config.bos_token_id to model_inputs. This reproduces the previous behavior of generate() for creating decoder_input_ids.
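A framework-free sketch of this workaround's intent, assuming the goal is an explicit input_ids entry with one BOS token per example (B rows, 1 column) so that generate() sizes decoder_input_ids by the real batch. The helper name and the BOS id of 0 are hypothetical; in the actual code the value comes from model.config.bos_token_id and the tensor is built with torch as in the line above.

```python
from typing import Any, Dict, List


def add_bos_input_ids(model_inputs: Dict[str, Any], bos_token_id: int) -> Dict[str, Any]:
    # Build a B x 1 block of BOS tokens, one row per example in input_pixels.
    batch_size = len(model_inputs["input_pixels"])
    input_ids: List[List[int]] = [[bos_token_id] for _ in range(batch_size)]
    # Return a new dict so the caller's original batch is left untouched.
    return {**model_inputs, "input_ids": input_ids}


model_inputs = {"input_pixels": [[0.0] * 4 for _ in range(16)]}
model_inputs = add_bos_input_ids(model_inputs, bos_token_id=0)  # hypothetical BOS id
print(len(model_inputs["input_ids"]), model_inputs["input_ids"][0])  # 16 [0]
```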

An alternative is to directly set the batch size to 1.

The code for both stages has been tested locally and should work.

Best,
Ethan

reflelia commented on June 3, 2024

It works! Thank you!

wjhou commented on June 3, 2024

Reopen this issue if you have further questions.
