Comments (16)

zorrofox commented on May 21, 2024

@PawKanarek Thanks a lot for your advice; I have the same issue. I think you have found the root cause of why the trained model is not changed.

shub-kris commented on May 21, 2024

@PawKanarek Can you also provide the training logs, please, and run with logging_steps=1?
Also use save_strategy="epoch".
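
A minimal sketch of how those two settings would fit into the TrainingArguments from the script above (everything else unchanged):

args = TrainingArguments(
    output_dir="output/trained_model",
    logging_steps=1,        # log the loss at every optimizer step
    save_strategy="epoch",  # write a checkpoint at the end of each epoch
    # ... the rest of the existing arguments ...
)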

PawKanarek commented on May 21, 2024

I modified the code a little to add some sanity checks.

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

def train():
    gemma2it = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")  # untouched copy, for sanity checks

    tokenizer = AutoTokenizer.from_pretrained("NousResearch/gemma-2b-it-tokenizer")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
    dataset = load_dataset("pawkanarek/poke_test", split="train")
    lora_config = LoraConfig(r=8, target_modules=["k_proj", "v_proj"], task_type="CAUSAL_LM")
    fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"], "xla": True, "xla_fsdp_v2": True, "xla_fsdp_grad_ckpt": True}
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        tokenizer=tokenizer,
        args=TrainingArguments(
            per_device_train_batch_size=64,
            num_train_epochs=4,
            output_dir="output/trained_model",
            optim="adafactor",
            dataloader_drop_last=True,  # Required for SPMD.
            fsdp="full_shard",
            fsdp_config=fsdp_config,
        ),
        peft_config=lora_config,
        max_seq_length=2048,
    )
    # 1
    trainer.train()
    print("comparing gemma2it with trainer.model")
    compare_weights(gemma2it, trainer.model) # different GemmaForCausalLM:2506172416 params vs SpmdFullyShardedDataParallel:3031123968 params
    
    # 2
    merged_model = trainer.model.merge_and_unload()
    print("comparing gemma2it with merged_model")
    compare_weights(gemma2it, merged_model) # different GemmaForCausalLM:2506172416 params vs GemmaForCausalLM:3030460416 params
    
    # 3
    print("saving merged_model")
    merged_model.to("cpu")
    merged_model.save_pretrained("output/merged_model")
    compare_weights(gemma2it, merged_model) # different GemmaForCausalLM:2506172416 params vs GemmaForCausalLM:3030460416 params

    # 4
    print("comparing loaded merged_model from disk with in-memory merged_model")
    loaded_merged_model = AutoModelForCausalLM.from_pretrained("output/merged_model")
    compare_weights(merged_model, loaded_merged_model) # different GemmaForCausalLM:3030460416 params vs GemmaForCausalLM:2506172416 params

    # 5
    print("comparing gemma2it with loaded merged_model from disk")
    compare_weights(gemma2it, loaded_merged_model) # models  GemmaForCausalLM and GemmaForCausalLM are the same

I added some sanity checks against the base, untouched gemma2it model, plus some mid-step comparisons:

  1. Check if the model after training, trainer.model, differs from the base gemma2it: yes, they differ in the number of parameters - that implies that training was successful.
  2. Check if the trained model after merging, merged_model, differs from the base gemma2it: yes, they differ in the number of parameters - that implies that merging was successful.
  3. Save the merged model and check if merged_model differs from the base gemma2it afterwards: yes, they differ in the number of parameters - that implies that saving does nothing to the parameters.
  4. Load the merged model from disk as loaded_merged_model and check if it differs from the merged_model before saving - YES, THEY ARE DIFFERENT :( - that implies that there is something wrong with loading the model (or saving).
    4.1. This warning popped up when loading the model from disk:
Some weights of the model checkpoint at output/merged_model were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', (...) 'model.layers.9._orig_module.self_attn.v_proj.weight']
- This IS expected if you are initializing GemmaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GemmaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GemmaForCausalLM were not initialized from the model checkpoint at output/merged_model and are newly initialized: ['model.layers.0.input_layernorm.weight', (...) 'model.layers.9.self_attn.v_proj.weight']
(...)You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  5. Check if the merged model loaded from disk, loaded_merged_model, differs from the base gemma2it: no, they are the same... - that implies that all my training was worthless...

Looks like there is something fishy in my code when saving / loading the model from disk... I'll update if I notice what's wrong. I will also check why my weights are saved under something called _orig_module.
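
A quick way to confirm that suspicion (a sketch, using the merged_model from the code above):

suspicious = [k for k in merged_model.state_dict() if "_orig_module" in k]
print(f"{len(suspicious)} keys contain '_orig_module', e.g. {suspicious[:3]}")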

zorrofox commented on May 21, 2024

Hi @PawKanarek

Please refer to #29388. By the way, have you tested LoRA fine-tuning performance on TPU XLA? I have experimented with LoRA, but it has no effect on the base model, and the generated output is exactly the same as the base model's.
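
For context, this is roughly how I compared the outputs (a sketch; "output/merged_model" is just an assumed path for the fine-tuned checkpoint):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
base = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")
tuned = AutoModelForCausalLM.from_pretrained("output/merged_model")  # assumed path

inputs = tokenizer("Hello, who are you?", return_tensors="pt")
for name, model in [("base", base), ("tuned", tuned)]:
    out = model.generate(**inputs, max_new_tokens=32)
    print(name, "->", tokenizer.decode(out[0], skip_special_tokens=True))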

PawKanarek commented on May 21, 2024

Hi @zorrofox, and thanks for the insight! Looks like my transformers fork didn't include the change from that PR.
What kind of fine-tuning performance are you talking about? Do you want to know how long it takes to train the model with LoRA, or how well the model behaves after fine-tuning?

PawKanarek commented on May 21, 2024

I used the trainer.save_pretrained function mentioned in PR #29388, but it didn't change anything - the trained model after saving is still exactly the same as before training.

PawKanarek commented on May 21, 2024

I think I fixed it, but I wouldn't recommend this fix to anyone, so I'm not even thinking about making a PR.

It's a patch rather than a fix, but I think it works. To check whether it really works, I will train gemma-2b-it until it overfits on the training dataset and then take a look at the inference output.

To apply my patch you would have to add a new parameter to save_pretrained:
https://github.com/huggingface/transformers/blob/f02aea27378dd57c2ced4b28ff9e58ec3876340a/src/transformers/modeling_utils.py#L2190C1-L2203C7

formatting_weights_func = None,

Also add this code before the sharding logic: https://github.com/huggingface/transformers/blob/03847ef45189d328a51f428b0a61a6b891e69f88/src/transformers/modeling_utils.py#L2429C1-L2437C111

# apply formatting to the weights before saving 
if formatting_weights_func is not None: 
    for old_key in list(state_dict.keys()):
        new_key = formatting_weights_func(old_key)
        logger.debug(f"changed {old_key=} to {new_key=}")
        state_dict[new_key] = state_dict.pop(old_key)

With these changes I can finally see a difference between the trained model loaded from disk and the base model it was trained from, and this warning is also gone:

Some weights of the model checkpoint at output/merged_model were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', (...) 'model.layers.9._orig_module.self_attn.v_proj.weight']

For completeness, here are the helper functions:

def compare_weights(model1, model2):
    name1, name2 = model1.__class__.__name__, model2.__class__.__name__
    sum1 = sum(p.numel() for p in model1.parameters())
    sum2 = sum(p.numel() for p in model2.parameters())

    if sum1 != sum2:
        print(f"!!! different in {name1}:{sum1} params vs {name2}:{sum2} params")

    for (n1, p1), (n2, p2) in zip(model1.named_parameters(), model2.named_parameters()):
        if n1 != n2:
            print(f"!!! Parameter names differ: {n1} != {n2}")
            return False
        if not torch.equal(p1.data, p2.data):
            print(f"!!! Parameter values differ: {n1}, {p1.data}, {p2.data}")
            return False
    return True

def formatting_func(old_key):
    return old_key.replace('._orig_module', '')

def train():
    # the same training config as before
    trainer.train()
    trainer_model = trainer.model.to('cpu')
    merged_model = trainer_model.merge_and_unload()
    merged_model.save_pretrained("output/merged_model", formatting_weights_func=formatting_func)
    
    loaded_merged_model = AutoModelForCausalLM.from_pretrained("output/merged_model")
    gemma2it = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")
    print("!!! comparing gemma2it with loaded merged_model from disk")
    compare_weights(gemma2it, loaded_merged_model) # !!! FINALLY !!! Parameter values differ: model.layers.0.self_attn.k_proj.weight, tensor([[-3.2043e-04,  8.1177e-03,  3.0365e-03,  ..., -5.3101e-03,

I'm not closing this issue, because I didn't fix it; the true issue is still hidden somewhere. This is only a workaround.
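
If you would rather not patch transformers at all, the same effect should be achievable by passing an already-cleaned state dict to save_pretrained, which accepts a state_dict argument; a sketch, untested on TPU:

clean_sd = {k.replace("._orig_module", ""): v for k, v in merged_model.state_dict().items()}
merged_model.save_pretrained("output/merged_model", state_dict=clean_sd)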

amyeroberts commented on May 21, 2024

cc @pacman100 @muellerzr @shub-kris

shub-kris commented on May 21, 2024

@PawKanarek Just to isolate the error, what happens if you run the same code on a GPU instead of a TPU?

shub-kris commented on May 21, 2024

@PawKanarek Also, after training, can you try saving with trainer.save_model('output_dir')?

shub-kris commented on May 21, 2024

@PawKanarek Also, did it work with your patch?
