
Comments (2)

awgu commented on May 9, 2024

@taekyounghan Do you need to pass sharded_model into train instead of model?

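A minimal sketch of that suggestion, assuming the FSDP wrapping and the train(...) call that appear in the traceback below (custom_tp_model, dp_mesh, and the epoch loop are illustrative, not from the original script):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Hedged sketch: keep a reference to the FSDP-wrapped module and use that reference
# everywhere in the training loop. `custom_tp_model`, `dp_mesh`, and the train(...)
# signature are assumed from the original script.
sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True)

for epoch in range(1, args.epochs + 1):
    # pass `sharded_model`, not the un-wrapped `model`
    train_accuracy = train(args, sharded_model, rank, world_size,
                           train_loader, optimizer, epoch, sampler=sampler1)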

taekyounghan commented on May 9, 2024

@taekyounghan Do you need to pass sharded_model into train instead of model?

Thanks, @awgu! I missed that point.

I've replaced every use of model with sharded_model.

The 2-D weight error is gone, but now RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators! occurs.

Is it because I haven't applied tensor parallelism to all layers?

Currently, only the input of the first self-attention layer is prepared for tensor parallelism.

Current TP

custom_tp_model = parallelize_module(
    model,
    tp_mesh,
    {
        "encoder.block.0.layer.0.SelfAttention": PrepareModuleInput(
            input_layouts=(Replicate()),
            desired_input_layouts=(Shard(0)),
            use_local_output=False
        )
    }
)

Error log

Traceback (most recent call last):
  File "/workspace/tkhan/pytorch_transformer_2d.py", line 508, in <module>
    fsdp_main(args)
  File "/workspace/tkhan/pytorch_transformer_2d.py", line 420, in fsdp_main
    train_accuracy = train(args, sharded_model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
  File "/workspace/tkhan/pytorch_transformer_2d.py", line 243, in train
    output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 1711, in forward
    encoder_outputs = self.encoder(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 1115, in forward
    layer_outputs = layer_module(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 695, in forward
    self_attention_outputs = self.layer[0](
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 602, in forward
    attention_output = self.SelfAttention(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 521, in forward
    query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 280, in __torch_dispatch__
    return DTensor._op_dispatcher.dispatch(
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/_tensor/dispatch.py", line 104, in dispatch
    op_info = self.unwrap_to_op_info(op_call, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/_tensor/dispatch.py", line 309, in unwrap_to_op_info
    raise RuntimeError(
RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
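That message is consistent with PrepareModuleInput turning the hidden states into a DTensor while self.q.weight is still a plain torch.Tensor, so F.linear sees mixed types. One hedged option is to also parallelize the projection layers of that same block so their weights become DTensors. The sketch below uses the common colwise-q/k/v, rowwise-o split on replicated inputs rather than the Shard(0) batch layout above, and T5's relative position bias and head reshaping may need extra handling; module paths follow the HF T5 names in the traceback:

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Hedged sketch: shard the attention projections on tp_mesh so every weight that
# F.linear touches is a DTensor, matching the DTensor activations.
custom_tp_model = parallelize_module(
    model,
    tp_mesh,
    {
        "encoder.block.0.layer.0.SelfAttention.q": ColwiseParallel(),
        "encoder.block.0.layer.0.SelfAttention.k": ColwiseParallel(),
        "encoder.block.0.layer.0.SelfAttention.v": ColwiseParallel(),
        "encoder.block.0.layer.0.SelfAttention.o": RowwiseParallel(),
    },
)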

Try 1

When I add column-wise parallelism for the embedding layer, RuntimeError: Cannot writeback when the parameter shape changes Expects torch.Size([16449536]) but got torch.Size([0]) occurs.

custom_tp_model = parallelize_module(
    model,
    tp_mesh,
    {
        "encoder.embed_tokens": ColwiseParallel(),
        "encoder.block.0.layer.0.SelfAttention": PrepareModuleInput(
            input_layouts=(Replicate()),
            desired_input_layouts=(Shard(0)),
            use_local_output=False
        )
    }
)

Error log

Traceback (most recent call last):
  File "/workspace/tkhan/pytorch_transformer_2d.py", line 509, in <module>
    fsdp_main(args)
  File "/workspace/tkhan/pytorch_transformer_2d.py", line 421, in fsdp_main
    train_accuracy = train(args, sharded_model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
  File "/workspace/tkhan/pytorch_transformer_2d.py", line 243, in train
    output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 835, in forward
    args, kwargs = _pre_forward(
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 380, in _pre_forward
    unshard_fn(state, handle)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 415, in _pre_forward_unshard
    _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 288, in _unshard
    ran_pre_unshard = handle.pre_unshard()
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1243, in pre_unshard
    ret = self._writeback_orig_params()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2256, in _writeback_orig_params
    self._writeback_tensor(
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2359, in _writeback_tensor
    raise RuntimeError(
RuntimeError: Cannot writeback when the parameter shape changes
Expects torch.Size([16449536]) but got torch.Size([0])
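For reference, this writeback error usually appears when parameter shapes change underneath an existing FSDP wrapper, e.g. when parallelize_module is applied after FSDP wrapping, or when FSDP is built over DTensor parameters without use_orig_params=True. A hedged sketch of the usual 2D ordering (mesh sizes are illustrative and the TP plan is just an example entry):

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# Hedged sketch: build a 2D mesh, apply tensor parallelism on the "tp" sub-mesh first,
# then wrap the TP'd model with FSDP on the "dp" sub-mesh. tp_size = 2 is illustrative.
tp_size = 2
world_size = dist.get_world_size()
mesh_2d = init_device_mesh("cuda", (world_size // tp_size, tp_size),
                           mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]

# 1) tensor parallelism first (example plan only; `model` comes from the original script)
custom_tp_model = parallelize_module(
    model, tp_mesh, {"encoder.embed_tokens": ColwiseParallel()}
)

# 2) then FSDP over the data-parallel dimension; use_orig_params=True is needed
#    once parameters are DTensors
sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True)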

Try 2

custom_tp_model = parallelize_module(
    model,
    tp_mesh,
    {
        #"encoder.embed_tokens": ColwiseParallel(input_layouts=Replicate()),
        "encoder.block.0.layer.0": PrepareModuleInput(
            input_layouts=(Replicate()),
            desired_input_layouts=(Shard(0)),
            use_local_output=True
        )
    }
)

Error log

Traceback (most recent call last):
  File "/workspace/tkhan/pytorch_transformer_2d.py", line 519, in <module>
    fsdp_main(args)
  File "/workspace/tkhan/pytorch_transformer_2d.py", line 431, in fsdp_main
    train_accuracy = train(args, sharded_model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
  File "/workspace/tkhan/pytorch_transformer_2d.py", line 243, in train
    output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 1711, in forward
    encoder_outputs = self.encoder(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 1115, in forward
    layer_outputs = layer_module(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 695, in forward
    self_attention_outputs = self.layer[0](
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 602, in forward
    attention_output = self.SelfAttention(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 561, in forward
    scores += position_bias_masked
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 0

Is there any way to resolve these errors?
Thanks for the kind and quick reply!

