
PPO for SLT · trl · CLOSED

huggingface commented on July 30, 2024

PPO for SLT

Comments (6)

lvwerra commented on July 30, 2024

Hi @jpanaro

Glad you are interested in the library. Let's see if I understand correctly: your input is a continuous stream of features extracted from a stream of images, and you want to map that series of features to a sequence of text corresponding to the signs in the video.

While the PPOTrainer is fairly general it is mainly tested for use in combination with GPT-2. This is a model that predicts the next word based on the previous words, usually referred to as autoregressive language modelling. Therefore, GPT-2 models the following probability distribution:

P(x_t | x_0, x_1, ..., x_{t-1})

(the probability that the word x_t appears after the words x_0, ..., x_{t-1})

I think for your use-case this architecture needs some modifications. One way I could imagine this working is if you find a clever way to integrate your features into the context, such that the model has one of the following objectives:

One feature f for the whole series:

P(x_t | x_0, ..., x_{t-1}, f)

One feature for each word:

P(x_t | x_0, ..., x_{t-1}, f_0, ..., f_{t-1})

For t words there are n features:

P(x_t | x_0, ..., x_{t-1}, f_0, ..., f_n)

Which of these applies really depends on how your input features and output text are aligned. In any case, one way to modify the GPT-2 architecture for the extra features would be to enrich the embeddings of the input tokens (words) with your features. This happens here in the transformers library; it is where the input tokens are transformed into embeddings, which you can regard as the model's input features.
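To make this concrete, here is a minimal sketch of the per-word variant. Everything here is illustrative (the feature size, the dummy tensors), not a tested recipe: project each video feature into GPT-2's embedding space, add it to the token embedding, and feed the sum in via inputs_embeds:

```python
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel

# Minimal sketch (illustrative names/shapes, not a tested recipe):
# project one video feature per token into GPT-2's embedding space,
# add it to the token embedding, and feed the sum via inputs_embeds.
model = GPT2LMHeadModel.from_pretrained("gpt2")
feature_dim = 512                                  # hypothetical feature size
proj = nn.Linear(feature_dim, model.config.n_embd)

input_ids = torch.randint(0, model.config.vocab_size, (1, 10))  # dummy tokens
features = torch.randn(1, 10, feature_dim)                      # one feature per token

token_embeds = model.transformer.wte(input_ids)    # (batch, seq, n_embd)
enriched = token_embeds + proj(features)           # fuse text and video features

logits = model(inputs_embeds=enriched).logits      # (batch, seq, vocab)
```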

Alternatively you could try to use something like this architecture and then modify it to work with the PPOTrainer. Probably you just need to add a value estimation head, which PPO needs, like I did with GPT-2 (see here).
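The value head itself is tiny. A simplified sketch of the idea (not trl's exact class):

```python
import torch
import torch.nn as nn

class ValueHead(nn.Module):
    """Sketch of a PPO value head: one scalar value estimate per
    hidden state (cf. the GPT-2 version in trl, but simplified)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.summary = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_size)
        # returns:       (batch, seq_len), one value per position
        return self.summary(hidden_states).squeeze(-1)
```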

I have never done something like this so these are just suggestions. Let me know how it goes!

jpanaro commented on July 30, 2024

Thank you for the quick response. To answer a few of your questions and comments:

Yes, our input is the images (frames), and the ground truths we are given come in the form of sentences for each sequence.

I think the idea of integrating the features into the context for GPT-2 is really interesting, but unfortunately I am on somewhat of a deadline, so it looks as if I will have to explore that option later. Still a really novel approach!

I do really like the idea of adapting some aspects of the video captioning model for use with the PPOTrainer. You mentioned the addition of a value estimation head, which appears to take hidden state(s) and return a scalar for each one. I think this is well within my ability, and once I get the base transformer model up and running I will do my best to integrate it. Thank you for the idea!

I do have a few small questions about the architecture of the PPOTrainer:

  • So throughout the trainer, the model inputs (the queries and responses) are at most used to generate the associated logits, values, and logprobs needed by the various parts of the PPOTrainer. Since the fact that they are both effectively strings of tokens is almost irrelevant, any model input that can generate valid logits, logprobs, and values should work?
    (for example, if my 'query' and 'response' were simply the feature stream the model needs to generate those logits, logprobs, and values?)

  • Also, the "train_stats" structure is present throughout many of the functions. I am somewhat unfamiliar with W&B, but is this structure there purely for logging, or does it play a greater role in the actual functionality of the trainer?

lvwerra commented on July 30, 2024

Thanks for clarifying what you are trying to achieve. Answering your first question takes a little bit of explanation as the devil is in the details. So there are a few things to note:

  1. The PPOTrainer is designed to fine-tune a model rather than training it from scratch. Therefore, it also requires a reference model and the KL-divergence between the two models is used as an additional reward signal. Are you also using a pretrained model and just want to fine-tune it? You could of course set the KL-divergence factor to zero and thus ignore it but I have never attempted it and am not sure how well it works.
  2. Since GPT-2 is an autoregressive model, it already generates an output for each query token plus the actual response tokens. I suspect this would be similar in your transformer architecture. The PPOTrainer uses the query length to determine which output logits and logprobs to ignore in the optimisation step. In your case you can probably use all of the decoder outputs and just need the features in the encoder step. Just keep that in mind.
  3. The PPOTrainer concatenates the query and response tensors (since both are just token ids) and uses them as model input for the forward pass. This step is needed to get differentiable outputs. Since you have multimodal query tensors and an encoder/decoder architecture, you might need to adapt this slightly. The relevant code is here and the following batched_forward_pass; see the sketch after this list. I think it should not be too hard to adapt this for your architecture.
  4. That said, your statement is right: you should be able to use the PPOTrainer as long as the model generates valid logits, logprobs and values from your query/response pairs. The PPOTrainer expects the HuggingFace transformers format of the model outputs.
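To make points 2-4 concrete, here is a rough sketch of what that forward pass boils down to. This is illustrative, not the exact trl code: the model here is assumed to return both logits and values, and your encoder/decoder version would replace the concatenation with an encoder call on the features.

```python
import torch

def batched_forward_sketch(model, query_ids, response_ids):
    """Illustrative sketch of the PPOTrainer-style forward pass:
    run the model on query + response once and extract the logits,
    per-token logprobs, and value estimates."""
    input_ids = torch.cat([query_ids, response_ids], dim=1)
    logits, values = model(input_ids)        # assumes the model returns both
    # logprob of each actually-observed next token
    log_dist = torch.log_softmax(logits[:, :-1, :], dim=-1)
    logprobs = torch.gather(
        log_dist, 2, input_ids[:, 1:].unsqueeze(-1)
    ).squeeze(-1)                            # (batch, seq_len - 1)
    return logits, logprobs, values
```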

Finally, as for the train_stats object: you are right that this is a strictly passive component that gathers various statistics in a dictionary, which can then be logged via W&B (which I strongly recommend for tracking training progress). If you want to log or plot some information about the training progress yourself, have a look at its entries. See a W&B report of a TRL training run here. It is super easy to set up and helped me a lot when debugging the library as I developed it.
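Concretely, since train_stats is just a flat dictionary of metrics, the whole logging story is one call per optimisation step. A minimal sketch (the project name and the stats keys are dummy placeholders, not trl's exact key names):

```python
import wandb

wandb.init(project="ppo-slt")          # hypothetical project name

# train_stats as returned by the trainer is a flat dict of metrics:
train_stats = {"ppo/loss/total": 0.42, "ppo/mean_reward": 1.3}  # dummy values
wandb.log(train_stats)                 # every entry becomes a W&B chart
```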

I hope this helps. Let me know if you have any more questions.

jpanaro commented on July 30, 2024
  1. Completely understand. In my first project I used REINFORCE to fine-tune a seq2seq model that had been pretrained on the same dataset with a cross-entropy loss, so the plan is to do the same thing here, but with a Transformer instead of a seq2seq model and with the PPOTrainer instead of the code I wrote for REINFORCE (it was heavily based on the work in the paper here, if you are interested in taking a look). I am definitely going to integrate the KL-divergence using the cross-entropy model as the reference model, as it seems pretty critical to the success of the fine-tuning (see the sketch after this list).
  2. When you say "determine which output logits and logprobs to ignore" are you referring to the modification of logprobs and vpred found here?
  3. I agree, this should just be a matter of ensuring the dimensions all match up prior to the forward pass through the model.
  4. I think I should be able to manually format my model output to the HuggingFace format, seeing as I have all of the same information, just stored in a different way initially.
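Regarding point 1, my understanding is that the KL-shaped reward amounts to roughly the following; a sketch along the lines of trl's compute_rewards, with an arbitrary kl_coef rather than the library's adaptive one:

```python
import torch

def kl_shaped_reward(score, logprobs, ref_logprobs, kl_coef=0.2):
    """Sketch of a KL-penalised reward for PPO fine-tuning
    (kl_coef = 0.2 is an arbitrary placeholder).
    score: scalar task reward; logprobs/ref_logprobs: (gen_len,)."""
    kl = logprobs - ref_logprobs     # per-token approximate KL term
    rewards = -kl_coef * kl          # penalise drifting from the reference
    rewards[-1] += score             # task reward only at the final token
    return rewards
```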

In the past I have "manually" stored and processed my data and statistics using various helper scripts which quickly turns into a massive pain and bloats a lot of my files with excess "tracking" code. W&B seems like a cool alternative and I am running through the tutorial now, thanks for the suggestion!

Your help is invaluable, thank you a ton for the assistance so far!

lvwerra commented on July 30, 2024

When you say "determine which output logits and logprobs to ignore" are you referring to the modification of logprobs and vpred found here?

Yes, exactly. In my case the task is text continuation from the query. When calculating the logprobs etc. the model makes predictions for each query and response token. The predictions on the query part, however, are not relevant. I think in your case this is not a problem since all the generation is new.
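In code, the slicing amounts to something like the following (all names and shapes here are illustrative):

```python
import torch

# Illustrative tensors: batch of 2, query length 4, response length 3.
all_logprobs = torch.randn(2, 7)
all_values = torch.randn(2, 7)
gen_len = 3                              # number of generated tokens

logprobs = all_logprobs[:, -gen_len:]    # keep response positions only
values = all_values[:, -gen_len:]
# With an encoder/decoder model the decoder already emits only the
# response, so this slicing may be a no-op in your setup.
```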

Indeed, W&B is great for exactly that. If you add the appropriate lines to your code, all the metrics are logged at every step along with the relevant parameters and even the code.

Let me know if you have any further questions!

lvwerra commented on July 30, 2024

I'll close this issue for now. If you have any more questions, just let me know. In any case, if you publish your work I would very much like to read it!
