When the autobucketing pass is registered as the aot_eager backend's `fw_compiler` and `bw_compiler`, this PR ensures the tensors are all-gathered on a real "cpu"/"cuda" device instead of the "meta" device.
When we call `dist.all_gather_object`, it creates a new byte storage outside `no_dispatch` [here](a2e2e1d8c0/torch/distributed/distributed_c10d.py (L3303)), and that storage ends up on the meta device. I therefore updated the code to use `unset_fake_temporarily`, so that real tensors are gathered from the other ranks.
This change is needed to unblock the aot_eager + autobucketing pass in this [PR](https://github.com/pytorch/torchtitan/pull/1813).
Without it, I hit the following error:
```bash
traceback : Traceback (most recent call last):
File "/home/ruisizhang123/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 358, in wrapper
return f(*args, **kwargs)
File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 607, in train
self.train_step(data_iterator)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 507, in train_step
loss = self.forward_backward_step(input_dict, labels)
File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 483, in forward_backward_step
pred = model_parts[0](inputs, **extra_inputs, **extra_args)
File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 418, in __call__
return super().__call__(*args, **kwargs)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1784, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1795, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 901, in compile_wrapper
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2359, in _call_user_compiler
raise BackendCompilerFailed(
self.compiler_fn, e, inspect.currentframe()
).with_traceback(e.__traceback__) from None
File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2334, in _call_user_compiler
compiled_fn = compiler_fn(gm, example_inputs)
File "/home/ruisizhang123/pytorch/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/ruisizhang123/pytorch/torch/__init__.py", line 2441, in __call__
return self.compiler_fn(model_, inputs_, **self.kwargs)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/_dynamo/backends/common.py", line 117, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
File "/home/ruisizhang123/pytorch/torch/_functorch/aot_autograd.py", line 1100, in aot_module_simplified
compiled_fn, _ = aot_stage2_compile(
~~~~~~~~~~~~~~~~~~^
aot_state,
^^^^^^^^^^
...<4 lines>...
inference_compiler,
^^^^^^^^^^^^^^^^^^^
)
^
File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 257, in aot_stage2_compile
return aot_stage2_autograd(aot_state, aot_graph_capture)
File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 1696, in aot_stage2_autograd
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
File "/home/ruisizhang123/torchtitan/torchtitan/experiments/simple_fsdp/backend.py", line 35, in aten_autobucketing_reordering_pass
schedule_overlap_bucketing(gm)
~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^
File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 755, in schedule_overlap_bucketing
).run()
~~~^^
File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 358, in run
self._align_compute_nodes_runtime_estimations_across_all_distributed_ranks()
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 337, in _align_compute_nodes_runtime_estimations_across_all_distributed_ranks
dist.all_gather_object(
~~~~~~~~~~~~~~~~~~~~~~^
gathered_runtime_estimations, runtime_estimations, pg
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/ruisizhang123/pytorch/torch/distributed/c10d_logger.py", line 82, in wrapper
return func(*args, **kwargs)
File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3170, in all_gather_object
input_tensor, local_size = _object_to_tensor(obj, current_device, group)
~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3079, in _object_to_tensor
byte_tensor = torch.ByteTensor(byte_storage).to(device)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised:
RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "meta". This is no longer allowed; the devices must match.
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165063
Approved by: https://github.com/eellison