Comments (8)
in src/train_eval_ende_full.py, len(batch['input_pixels']) is 16, but len(outputs) is 1
from organ.
Hi,
The input format is fine according to your descriptions.
The generate function in the Transformers library receives the main input (e.g., of shape BxS, where B is the batch size and S is the input length) and creates the decoder_input_ids from it. If the main input is set to input_pixels, the created decoder_input_ids should have batch size B.
However, later versions of the Transformers library modified the generate function and use a different init method to create the decoder_input_ids. So I have two guesses:
- If you use a later version of Transformers, you will need to specify input_ids in the model_inputs, such as: "input_ids": torch.ones_like(batch["input_pixels"][:, :1, 0]).cuda() * model.config.bos_token_id
- If you use the same version of Transformers, may I ask whether you have modified any lines of code in the evaluation function?
Best,
Ethan
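The suggested input_ids entry can be tried in isolation to see what shape it produces. The following is only a minimal sketch: the batch contents, the (B, S, D) shape of input_pixels, and the BOS_TOKEN_ID value are assumptions made for illustration (the thread uses model.config.bos_token_id and calls .cuda(); this sketch stays on the CPU).

```python
import torch

# Hypothetical stand-ins: in the repo these come from the evaluation loop
# and from model.config.bos_token_id.
BATCH_SIZE = 16
BOS_TOKEN_ID = 0  # assumed placeholder value

# Assumed shape (B, S, D); the thread only says the main input is B x S.
batch = {"input_pixels": torch.rand(BATCH_SIZE, 4, 8)}

# Same expression as in the thread, minus .cuda():
# take the first position of each sample, giving one entry per batch item.
input_ids = torch.ones_like(batch["input_pixels"][:, :1, 0]) * BOS_TOKEN_ID

# The extra item is added next to the pixel input.
model_inputs = {
    "input_pixels": batch["input_pixels"],
    "input_ids": input_ids,
}

# One row per sample, so the batch size B is visible from input_ids.
print(model_inputs["input_ids"].shape)
```

Note that torch.ones_like keeps the dtype of input_pixels, so the resulting tensor is floating point in this sketch.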
from organ.
I am using transformers==4.15.0 (the same version) and didn't edit the evaluation function.
The outputs.shape in eval_func in train_eval_ende_full.py is torch.Size([1, 64]). Is that correct?
And if I use transformers >= 4.16, the following error occurs:
TypeError: VLBartEncoder.forward() got an unexpected keyword argument 'pixel_values'
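Since the behavior differs between 4.15.0 and 4.16 or later, it can help to confirm which release is actually installed before debugging further. This is a rough sketch using only the standard library; the helper names major_minor and is_tested_series are hypothetical, and the (4, 15) series is taken from the version quoted in this thread.

```python
import re
from importlib import metadata

# Release series mentioned in this thread (transformers==4.15.0).
TESTED_SERIES = (4, 15)


def major_minor(version: str) -> tuple:
    """Extract (major, minor) from a version string such as '4.15.0'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def is_tested_series(version: str) -> bool:
    """True if the version belongs to the series quoted in the thread."""
    return major_minor(version) == TESTED_SERIES


if __name__ == "__main__":
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        print("transformers is not installed in this environment")
    else:
        status = "matches" if is_tested_series(installed) else "differs from"
        print(f"transformers {installed} {status} the 4.15 series")
```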
from organ.
Hi,
The output size is not correct. It seems that the decoder_input_ids is not initialized correctly.
Have you modified the main_input_name at line 871 in the src/models/modeling_bart_custom.py file? This should be the key that the generate() function uses to obtain the batch size.
Best,
Ethan
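A quick way to catch this mismatch early is to compare the number of generated sequences with the batch size of the main input. The helper below is a hypothetical sketch, not part of the repository; it relies only on len(), so it works with tensors or plain lists, and the default key input_pixels is taken from this thread.

```python
def check_batch_alignment(batch, outputs, main_input_name="input_pixels"):
    """Return the batch size if outputs has one row per input sample.

    Raises ValueError when the counts differ, which is the symptom
    reported in this issue (16 inputs, 1 output).
    """
    expected = len(batch[main_input_name])
    actual = len(outputs)
    if actual != expected:
        raise ValueError(
            f"generate() returned {actual} sequence(s) for "
            f"{expected} input(s) under key {main_input_name!r}"
        )
    return actual


# Toy data reproducing the reported shapes with plain lists:
# 16 inputs, but a single output sequence of length 64.
toy_batch = {"input_pixels": [[0.0] for _ in range(16)]}
toy_outputs = [[0] * 64]

try:
    check_batch_alignment(toy_batch, toy_outputs)
except ValueError as err:
    print(f"mismatch detected: {err}")
```

Calling such a check right after generate() in the evaluation function would surface the problem before metrics are computed on misaligned outputs.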
from organ.
I didn't edit that, and I tried replacing the src folder with a fresh copy, but the same error occurred.
Are the input arguments below correct?
./script/run_iu_xray.sh debug checkpoint_name plan_model_name_or_path plan_eval_file
checkpoint_name : trained observation planning pytorch_model.bin
plan_model_name_or_path : prediction.json folder path
plan_eval_file : prediction_eval_step_132.json
from organ.
Hi,
The checkpoint_name is for the pre-trained ResNet model (similar to issue #5), not the planner. Stage 2 training does not require the trained planner, so you may remove this argument in a similar way.
As for the decoder_input_ids, you may try adding an item, "input_ids": torch.ones_like(batch["input_pixels"][:, :1, 0]).cuda() * model.config.bos_token_id, to the model_inputs. This uses the previous behavior of generate() to create the decoder_input_ids.
An alternative is to set the batch size to 1 directly.
The code for both stages has been tested locally and should work fine.
Best,
Ethan
from organ.
It works! Thank you!
from organ.
Reopen this issue if you have further questions.
from organ.