Comments (7)

RohitMidha23 commented on June 15, 2024

Understood, will keep it in mind next time.

Here is a detailed description of our data setup:

|            | In Distribution | Out of Distribution |
| ---------- | --------------- | ------------------- |
| Transcribe | 70 hours        | 0 hours             |
| Translate  | 250 hours       | ~500 hours          |

Yes, the language is part of Whisper's pre-training dataset, but it is extremely low-resource.

RohitMidha23 commented on June 15, 2024

@sanchit-gandhi any suggestions on this?

sanchit-gandhi commented on June 15, 2024

Hey @RohitMidha23 - super sorry for the late reply here! Just a quick preface: we typically reserve GitHub issues for bug reports and feature requests for the Transformers library (e.g. feature X is broken with model Y), and the Hugging Face forum for questions regarding example scripts and use-case applications (e.g. how do I fine-tune model X for task Y). Considering this, your question would be more appropriate for the forum! Just something to bear in mind to ensure you get help as quickly as possible, and so that the answer is maximally visible to other members of the community. We can discuss your question here for now, but I'd appreciate it if you could copy and paste the final thread over to the forum!

Thanks for your issue description! Could I ask a few follow-up questions:

  1. How much translate and transcribe data do you have? (per task)
  2. Is this language part of the Whisper pre-training dataset?

I imagine you'd actually get best performance doing one of the following:

  1. Two rounds of fine-tuning: train on transcribe, then on translate (i.e. as you've proposed above)
  2. Train on the translate and transcribe datasets jointly in a single round of fine-tuning
  3. Train on translate, but supplement your small translate dataset with speech-translation data in other languages

These are all valid options for boosting your effective dataset size. As to which one will work best: that depends on how in-distribution your different datasets are, and how much of each one you have. The best thing is to try each approach and see which works best!
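
To make option 2 concrete, here is a rough, untested sketch of how a joint training set could be built so that each example carries its own task token in its labels. The model ID, dataset paths, column names ("audio", "text", "duration") and language token are placeholders you'd need to adapt to your setup:

from datasets import concatenate_datasets, load_dataset
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small")  # placeholder model ID

def prepare(batch, language, task):
    # compute log-Mel input features from the raw audio
    audio = batch["audio"]
    batch["input_features"] = processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_features[0]
    # encode the labels with this example's prefix tokens, e.g.
    # <|startoftranscript|><|hi|><|transcribe|><|notimestamps|> ...
    # (which language token to use for the translate direction depends on your setup)
    processor.tokenizer.set_prefix_tokens(language=language, task=task)
    batch["labels"] = processor.tokenizer(batch["text"]).input_ids
    return batch

transcribe = load_dataset("...", split="train")
transcribe = transcribe.map(
    prepare,
    fn_kwargs={"language": "hi", "task": "transcribe"},
    remove_columns=transcribe.column_names,
)

translate = load_dataset("...", split="train")
translate = translate.map(
    prepare,
    fn_kwargs={"language": "hi", "task": "translate"},
    remove_columns=translate.column_names,
)

# one joint training set; the task is carried per example by its label prefix tokens
train_dataset = concatenate_datasets([transcribe, translate])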

Related: @eustlb has been trying option 3 for the Distil-Whisper experiments and has gotten promising first results (see the link shared below for details).

sanchit-gandhi commented on June 15, 2024

Great - thanks for the clarification! Based on the above, I would give options 1 and 2 a go to see if they give you a good starting point. You can then supplement with additional data in linguistically related languages if you need more data.

RohitMidha23 commented on June 15, 2024

@sanchit-gandhi we tried out both methods 1 and 2. With method 2, we see that even with the translate token being passed, the output contains transcribed words. Any known reasons for this?
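
One inference-time sanity check could look roughly like this (a sketch only; the checkpoint path and audio are placeholders, and it assumes a recent transformers version that accepts language / task directly in generate()):

from transformers import WhisperForConditionalGeneration, WhisperProcessor

checkpoint = "path/to/finetuned-checkpoint"  # placeholder
processor = WhisperProcessor.from_pretrained(checkpoint)
model = WhisperForConditionalGeneration.from_pretrained(checkpoint)

audio_array = [0.0] * 16_000  # placeholder: 1 second of silence at 16 kHz
inputs = processor(audio_array, sampling_rate=16_000, return_tensors="pt")

# under Whisper's original convention the language token is the *source* audio language;
# adjust if the fine-tuning used a different convention
generated_ids = model.generate(inputs.input_features, language="hi", task="translate")
print(processor.batch_decode(generated_ids, skip_special_tokens=True))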

With method 3, how much data would we need in a related language?
Does the training schedule play a part here?

sanchit-gandhi commented on June 15, 2024

Hey @RohitMidha23 - awesome to hear! Could you share the code that you're using for 2 so I can take a look?

The most related experiment I know of for 3 is when @eustlb mixed French with Spanish for distillation efforts: https://github.com/huggingface/distil-whisper/tree/main/training#3-language-mixing

Here, we used 500 hours of Spanish data to supplement 400 hours of French and saw a 7.5% WER improvement. I would imagine a few hundred hours of closely related data should help you here as well.
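
If you want to experiment with mixing via the datasets library, a rough (untested) sketch is below. The dataset paths and the mixing ratio are placeholders, and it assumes both datasets have been preprocessed identically so their features match:

from datasets import interleave_datasets, load_dataset

low_resource = load_dataset("...", split="train")   # your translate set
related_lang = load_dataset("...", split="train")   # speech-translation data in a related language

# sample ~2/3 of examples from the target language and ~1/3 from the related one,
# stopping as soon as one (weighted) source runs out
mixed = interleave_datasets(
    [low_resource, related_lang],
    probabilities=[0.67, 0.33],
    seed=0,
    stopping_strategy="first_exhausted",
)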

RohitMidha23 commented on June 15, 2024

Most of the code is picked up from your Hugging Face blog, @sanchit-gandhi.

Dataset Loading

from datasets import concatenate_datasets, load_dataset

translate_dataset = load_dataset("...")
transcribe_dataset = load_dataset("...")

# combine the translate and transcribe splits into a single training set
dataset = concatenate_datasets(
    [
        translate_dataset["train"],
        transcribe_dataset["train"],
        transcribe_dataset["test"],
    ]
)

# sort by audio duration so similar-length examples are batched together
dataset = dataset.sort("duration")

from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch

# No changes made here:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [
            {"input_features": feature["input_features"]} for feature in features
        ]
        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt"
        )

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's appended later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

Note: here I'm simply combining the mapped datasets (each mapped with its own task-specific tokenizer).

Processor

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained(MODEL_ID, task="translate")
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

Model

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(
    MODEL_ID,
    use_safetensors=True,
)

model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

Trainer

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=1e-5,
    num_train_epochs=num_epochs,
    bf16=True,
    warmup_ratio=0.1,
    predict_with_generate=True,
    generation_max_length=444,
    save_strategy="steps",
    save_steps=500,
    logging_steps=15,
    report_to=["wandb"],
    push_to_hub=True,
    hub_strategy="all_checkpoints",
    save_total_limit=5,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset,
    data_collator=data_collator,
    tokenizer=processor.tokenizer,
)

There is one part that I'm unsure about, though:

if TASK == "transcribe":
    model.generation_config.language = "<|hi|>"  # from: https://github.com/huggingface/transformers/pull/28687#issuecomment-1970980372
else:
    model.generation_config.language = "<|en|>"

I'm not sure how this can be specified during training because we're mixing the datasets.
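
My current working assumption is that during training (teacher forcing) the task is picked up from each example's label prefix tokens, and that generation_config.language / task only come into play once generate() is called, e.g. for predict_with_generate at evaluation time. If that's right, something like the rough sketch below would let me evaluate each task separately (translate_eval / transcribe_eval are hypothetical held-out splits, and it assumes a recent transformers version that forwards generation kwargs from Seq2SeqTrainer.evaluate to generate()):

# evaluate each task separately by passing generation kwargs at eval time
translate_metrics = trainer.evaluate(
    eval_dataset=translate_eval,
    metric_key_prefix="eval_translate",
    language="<|hi|>",   # placeholder language token; which one is correct depends on the setup
    task="translate",
)
transcribe_metrics = trainer.evaluate(
    eval_dataset=transcribe_eval,
    metric_key_prefix="eval_transcribe",
    language="<|hi|>",
    task="transcribe",
)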

Pseudo Labeling

Another question I had, along the same lines: if I use a model such as Gemini to pseudo-label large amounts of data, how would I use that data with the Trainer? Do I need to write the training loop myself to, say, change the loss for such samples, or can that be specified somehow through the Trainer?
The idea is that I could pre-train (or train) Whisper with the above methods so it learns the very basics of the language before moving on to the harder language present in my in-distribution dataset.
Thoughts?
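
For concreteness, what I had in mind is roughly the sketch below (untested; the "sample_weight" column and the class name are made up, the data collator would need to add that column to each batch, and remove_unused_columns=False would be needed so the extra dataset column reaches the collator). I'd then just swap Seq2SeqTrainer for this class when building the trainer.

import torch
from transformers import Seq2SeqTrainer

class WeightedSeq2SeqTrainer(Seq2SeqTrainer):
    """Down-weights pseudo-labelled examples via a per-example float "sample_weight"
    (hypothetical column) that the data collator adds to each batch."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        sample_weight = inputs.pop("sample_weight")  # shape: (batch_size,)
        labels = inputs["labels"]
        outputs = model(**inputs)

        # token-level cross-entropy, ignoring the -100 padding positions
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
        logits = outputs.logits
        token_loss = loss_fct(
            logits.view(-1, logits.size(-1)), labels.view(-1)
        ).view(labels.size())

        # mean loss per example over valid tokens, then apply the per-example weight
        valid = (labels != -100).float()
        per_example = (token_loss * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1)
        loss = (per_example * sample_weight.to(per_example.dtype)).mean()
        return (loss, outputs) if return_outputs else loss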
