Comments (2)
@taekyounghan Do you need to pass sharded_model into train instead of model?
Thanks, @awgu! I missed that point. I've changed every model to sharded_model.
With that fixed, the 2-D weight error is gone, but a new error occurs: RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
Is it because I haven't written tensor parallelism for all layers? Currently, only the input preparation for the first self-attention layer is written.
Current TP
custom_tp_model = parallelize_module(
    model,
    tp_mesh,
    {
        "encoder.block.0.layer.0.SelfAttention": PrepareModuleInput(
            input_layouts=(Replicate()),
            desired_input_layouts=(Shard(0)),
            use_local_output=False
        )
    }
)
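For context on what I suspect is missing: the plan above only converts the input of one submodule to a DTensor, while the q/k/v/o projection weights stay plain tensors, which matches the mixed torch.Tensor/DTensor error below. A minimal sketch of a fuller plan that also shards the projection weights, assuming the T5 module names from the traceback and the public torch.distributed.tensor.parallel API (not a verified fix):

```
# Sketch (assumption): shard the attention projections of the first encoder
# block so their weights become DTensors alongside the DTensor activations.
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

custom_tp_model = parallelize_module(
    model,
    tp_mesh,
    {
        # q/k/v are column-sharded; the output projection o is row-sharded,
        # so its output is reduced back to a replicated tensor.
        "encoder.block.0.layer.0.SelfAttention.q": ColwiseParallel(),
        "encoder.block.0.layer.0.SelfAttention.k": ColwiseParallel(),
        "encoder.block.0.layer.0.SelfAttention.v": ColwiseParallel(),
        "encoder.block.0.layer.0.SelfAttention.o": RowwiseParallel(),
    },
)
```

Presumably every module whose activations become DTensors needs a matching plan entry (i.e. all encoder/decoder blocks, not just block 0), otherwise the first unparallelized layer hits the same mixed-tensor dispatch error.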
Error log
Traceback (most recent call last):
File "/workspace/tkhan/pytorch_transformer_2d.py", line 508, in <module>
fsdp_main(args)
File "/workspace/tkhan/pytorch_transformer_2d.py", line 420, in fsdp_main
train_accuracy = train(args, sharded_model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
File "/workspace/tkhan/pytorch_transformer_2d.py", line 243, in train
output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 1711, in forward
encoder_outputs = self.encoder(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 1115, in forward
layer_outputs = layer_module(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 695, in forward
self_attention_outputs = self.layer[0](
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 602, in forward
attention_output = self.SelfAttention(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 521, in forward
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
return F.linear(input, self.weight, self.bias)
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 280, in __torch_dispatch__
return DTensor._op_dispatcher.dispatch(
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/_tensor/dispatch.py", line 104, in dispatch
op_info = self.unwrap_to_op_info(op_call, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/_tensor/dispatch.py", line 309, in unwrap_to_op_info
raise RuntimeError(
RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
Try
So When I add Colwise parallelism of Embedding layer, RuntimeError: Cannot writeback when the parameter shape changes Expects torch.Size([16449536]) but got torch.Size([0])
occurs
custom_tp_model = parallelize_module(
    model,
    tp_mesh,
    {
        "encoder.embed_tokens": ColwiseParallel(),
        "encoder.block.0.layer.0.SelfAttention": PrepareModuleInput(
            input_layouts=(Replicate()),
            desired_input_layouts=(Shard(0)),
            use_local_output=False
        )
    }
)
Error log
Traceback (most recent call last):
File "/workspace/tkhan/pytorch_transformer_2d.py", line 509, in <module>
fsdp_main(args)
File "/workspace/tkhan/pytorch_transformer_2d.py", line 421, in fsdp_main
train_accuracy = train(args, sharded_model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
File "/workspace/tkhan/pytorch_transformer_2d.py", line 243, in train
output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 835, in forward
args, kwargs = _pre_forward(
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 380, in _pre_forward
unshard_fn(state, handle)
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 415, in _pre_forward_unshard
_unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 288, in _unshard
ran_pre_unshard = handle.pre_unshard()
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1243, in pre_unshard
ret = self._writeback_orig_params()
File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2256, in _writeback_orig_params
self._writeback_tensor(
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2359, in _writeback_tensor
raise RuntimeError(
RuntimeError: Cannot writeback when the parameter shape changes
Expects torch.Size([16449536]) but got torch.Size([0])
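One hedged reading of this writeback error: FSDP flattens parameters when it wraps the model, so if parallelize_module runs on an already-FSDP-wrapped model, a flat parameter gets swapped for a DTensor shard of a different size. The usual 2-D recipe applies TP first, then FSDP on the data-parallel sub-mesh. A minimal sketch, assuming a 2-D device mesh, FSDP's device_mesh argument (PyTorch >= 2.2), and placeholder names dp_size, tp_size, and tp_plan:

```
# Sketch (assumption): TP on the "tp" sub-mesh first, then FSDP on "dp".
# dp_size, tp_size, and tp_plan are hypothetical placeholders.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import parallelize_module

mesh_2d = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
tp_mesh, dp_mesh = mesh_2d["tp"], mesh_2d["dp"]

tp_model = parallelize_module(model, tp_mesh, tp_plan)       # TP first
sharded_model = FSDP(tp_model,                               # then FSDP
                     device_mesh=dp_mesh,
                     use_orig_params=True)                   # needed to compose with DTensor
```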
Try 2
custom_tp_model = parallelize_module(
    model,
    tp_mesh,
    {
        # "encoder.embed_tokens": ColwiseParallel(input_layouts=Replicate()),
        "encoder.block.0.layer.0": PrepareModuleInput(
            input_layouts=(Replicate()),
            desired_input_layouts=(Shard(0)),
            use_local_output=True
        )
    }
)
Error log
Traceback (most recent call last):
File "/workspace/tkhan/pytorch_transformer_2d.py", line 519, in <module>
fsdp_main(args)
File "/workspace/tkhan/pytorch_transformer_2d.py", line 431, in fsdp_main
train_accuracy = train(args, sharded_model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
File "/workspace/tkhan/pytorch_transformer_2d.py", line 243, in train
output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 1711, in forward
encoder_outputs = self.encoder(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 1115, in forward
layer_outputs = layer_module(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 695, in forward
self_attention_outputs = self.layer[0](
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 602, in forward
attention_output = self.SelfAttention(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 561, in forward
scores += position_bias_masked
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 0
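A possible reading of the 2-vs-4 mismatch (an assumption, not confirmed): Shard(0) shards dimension 0 of the block input, which is the batch dimension, so on a 2-way TP mesh each rank's attention scores carry a local batch of 2 while position_bias, built from the still-full-size mask, keeps batch 4. The arithmetic, as a sanity check under that assumption:

```python
# Assumption: Shard(0) splits dim 0 (the batch) across a hypothetical
# 2-way TP mesh, while position_bias keeps the global batch size.
tp_degree = 2
global_batch = 4
local_batch = global_batch // tp_degree  # what the sharded activations see

print(local_batch, global_batch)  # matches "tensor a (2)" vs "tensor b (4)"
```

If the intent was to shard along the sequence dimension, Shard(1) (or the SequenceParallel style) may be closer to what the TP input-preparation API expects.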
Is there any way to resolve these errors?
Thanks for the kind and quick reply!