Adds bucketing of multiple dtypes so they can be processed in one bucketed collective.
The first target is bucketing bf16 and f32 together, but it can already be used with other dtypes.
For now, multi-dtype bucketing is only supported in "custom_ops" mode; the non-custom_ops path needs additional work on the inductor side.
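To illustrate the idea (a minimal sketch, not the PR's implementation): multi-dtype bucketing amounts to flattening tensors of different dtypes into one byte buffer so a single collective can transport all of them, then viewing the result back per dtype.
```python
import torch

def pack_bucket(tensors: list[torch.Tensor]) -> torch.Tensor:
    # Reinterpret each contiguous tensor's storage as raw bytes and concatenate,
    # so bf16 and f32 tensors can share one bucketed buffer.
    return torch.cat([t.contiguous().view(torch.uint8).flatten() for t in tensors])

def unpack_bucket(buf: torch.Tensor, metas: list[tuple[torch.dtype, torch.Size]]) -> list[torch.Tensor]:
    out, offset = [], 0
    for dtype, shape in metas:
        nbytes = dtype.itemsize * shape.numel()
        out.append(buf[offset : offset + nbytes].view(dtype).view(shape))
        offset += nbytes
    return out

# A bf16 and an f32 tensor travel in one bucket.
bufs = [torch.randn(4, dtype=torch.bfloat16), torch.randn(3, dtype=torch.float32)]
metas = [(t.dtype, t.shape) for t in bufs]
restored = unpack_bucket(pack_bucket(bufs), metas)
assert all(torch.equal(a, b) for a, b in zip(bufs, restored))
```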
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162470
Approved by: https://github.com/eellison
A number of smallish bucketing improvements:
- Account for bucketing in the overlap calculation: if an in-flight collective with the same bucket key exists, reduce the new collective's estimated time by its latency component (see the sketch after this list).
- Update compute domination so we order based on compute idx, as opposed to compute depth, so we never reorder compute. This makes it a bit easier to reason about memory and pre-fetching, although we can explore reordering in the future.
- When we wait on a collective, also force waits on all collectives on the same process group that were enqueued before it.
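A minimal sketch of the first bullet, with hypothetical names (the real pass works on fx nodes and its own runtime estimates): if a collective with the same bucket key is already in flight, the new collective is assumed to ride in the same bucket, so only its bandwidth term is counted.
```python
def estimate_collective_time_ms(coll, in_flight, latency_ms, bandwidth_ms):
    # `bucket_key` stands in for whatever groups collectives into one bucket
    # (process group, op, dtype bucket, ...); all names here are hypothetical.
    if any(other.bucket_key == coll.bucket_key for other in in_flight):
        return bandwidth_ms            # latency is already paid by the in-flight bucket
    return latency_ms + bandwidth_ms
```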
Better Memory Handling:
- Pre-fetch limiting: when scheduling collectives for overlap, only pre-fetch up to a certain distance, then schedule off-path collectives (which typically reduce memory).
- When we are above peak memory, schedule waits (see the sketch after this list).
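A minimal sketch of these two heuristics, with assumed names rather than the actual scheduler code:
```python
def next_scheduling_action(curr_mem_bytes, peak_mem_bytes, candidate_compute_idx,
                           current_compute_idx, max_prefetch_distance):
    if curr_mem_bytes > peak_mem_bytes:
        return "schedule_wait"                  # free in-flight collective buffers first
    if candidate_compute_idx - current_compute_idx > max_prefetch_distance:
        return "schedule_off_path_collective"   # typically memory reducing
    return "prefetch_collective_for_overlap"
```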
TODO:
- For each compute node, we know its original memory in the graph; we could limit pre-fetching that crosses peak memory.
- By scheduling off-path collectives for overlap we reduce memory, but if there isn't enough compute to overlap with, we need to proactively schedule them. Not an issue yet on the examples.
- Make some hard-coded constants configurable and clean up enablement (can be done in a subsequent PR).
On small llama 2d backward:
578 of 618 potentially hideable collectives hidden
original mem 14.4GB, rescheduled mem 15.9GB
On forward:
254/256 potentially hideable collectives hidden
original mem 5.8GB, rescheduled mem 5.8GB
WIP: adding tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165318
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945, #165059
When the autobucketing pass is registered as the aot_eager backend's `fw_compiler` and `bw_compiler`, this PR ensures the tensors are all-gathered on the "cpu"/"cuda" device instead of the "meta" device.
When we call `dist.all_gather_object`, it creates a new byte storage outside `no_dispatch` [here](a2e2e1d8c0/torch/distributed/distributed_c10d.py (L3303)), which ends up on the meta device. Thus, I updated the code to use `unset_fake_temporarily`, which gathers real tensors from the other ranks.
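A minimal sketch of the fix as I understand it (variable names are illustrative, not the exact code in `overlap_scheduling.py`): run the object gather outside fake-tensor mode so the byte tensor is created on a real device rather than "meta".
```python
import torch.distributed as dist
from torch._subclasses.fake_tensor import unset_fake_temporarily

def gather_runtime_estimations(runtime_estimations, pg):
    gathered = [None] * dist.get_world_size(pg)
    # Temporarily disable fake-tensor dispatch so all_gather_object's internal
    # ByteTensor is allocated on cpu/cuda instead of the meta device.
    with unset_fake_temporarily():
        dist.all_gather_object(gathered, runtime_estimations, group=pg)
    return gathered
```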
This is needed to unblock the aot_eager+autobucketing pass in this [PR](https://github.com/pytorch/torchtitan/pull/1813).
Otherwise, I hit the following error:
```bash
traceback : Traceback (most recent call last):
File "/home/ruisizhang123/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 358, in wrapper
return f(*args, **kwargs)
File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 607, in train
self.train_step(data_iterator)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 507, in train_step
loss = self.forward_backward_step(input_dict, labels)
File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 483, in forward_backward_step
pred = model_parts[0](inputs, **extra_inputs, **extra_args)
File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 418, in __call__
return super().__call__(*args, **kwargs)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1784, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1795, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 901, in compile_wrapper
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2359, in _call_user_compiler
raise BackendCompilerFailed(
self.compiler_fn, e, inspect.currentframe()
).with_traceback(e.__traceback__) from None
File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2334, in _call_user_compiler
compiled_fn = compiler_fn(gm, example_inputs)
File "/home/ruisizhang123/pytorch/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/ruisizhang123/pytorch/torch/__init__.py", line 2441, in __call__
return self.compiler_fn(model_, inputs_, **self.kwargs)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/_dynamo/backends/common.py", line 117, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
File "/home/ruisizhang123/pytorch/torch/_functorch/aot_autograd.py", line 1100, in aot_module_simplified
compiled_fn, _ = aot_stage2_compile(
~~~~~~~~~~~~~~~~~~^
aot_state,
^^^^^^^^^^
...<4 lines>...
inference_compiler,
^^^^^^^^^^^^^^^^^^^
)
^
File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 257, in aot_stage2_compile
return aot_stage2_autograd(aot_state, aot_graph_capture)
File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 1696, in aot_stage2_autograd
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
File "/home/ruisizhang123/torchtitan/torchtitan/experiments/simple_fsdp/backend.py", line 35, in aten_autobucketing_reordering_pass
schedule_overlap_bucketing(gm)
~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^
File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 755, in schedule_overlap_bucketing
).run()
~~~^^
File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 358, in run
self._align_compute_nodes_runtime_estimations_across_all_distributed_ranks()
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 337, in _align_compute_nodes_runtime_estimations_across_all_distributed_ranks
dist.all_gather_object(
~~~~~~~~~~~~~~~~~~~~~~^
gathered_runtime_estimations, runtime_estimations, pg
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/ruisizhang123/pytorch/torch/distributed/c10d_logger.py", line 82, in wrapper
return func(*args, **kwargs)
File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3170, in all_gather_object
input_tensor, local_size = _object_to_tensor(obj, current_device, group)
~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3079, in _object_to_tensor
byte_tensor = torch.ByteTensor(byte_storage).to(device)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised:
RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "meta". This is no longer allowed; the devices must match.
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165063
Approved by: https://github.com/eellison
It is generally recommended to use `is`/`is not` to compare types. This series of changes applies that suggestion across the code base, with the aim of finally enabling the related linter checks.
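For example (types are singletons, so identity comparison is both clearer and immune to `__eq__` overrides):
```python
x = 3

# Preferred
if type(x) is int:
    ...

# Discouraged; this is what linters such as flake8's E721 flag
if type(x) == int:
    ...
```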
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165037
Approved by: https://github.com/mlazos
- Update the Memory Estimator to use node storages for analysis instead of manually inspecting operator schemas, which simplifies bookkeeping (see the sketch below). This will also allow me to reuse this component elsewhere.
- Factor it out into a separate class so the same logic can be used in scheduling (node allocations / aliasing / uses).
- Add tests for correctness - right now only on fwd/bwd individually, not with both.
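A minimal sketch of storage-based counting (assumed names; it ignores frees and only illustrates why keying on storages avoids double-counting aliases):
```python
import torch
from torch.multiprocessing.reductions import StorageWeakRef

def allocated_bytes(nodes):
    seen, total = set(), 0
    for node in nodes:
        val = node.meta.get("val")  # fake tensor recorded on the fx node
        if not isinstance(val, torch.Tensor):
            continue
        ref = StorageWeakRef(val.untyped_storage())
        if ref not in seen:         # views/aliases share a storage, so they add nothing
            seen.add(ref)
            total += val.untyped_storage().nbytes()
    return total
```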
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164783
Approved by: https://github.com/ruisizhang123
ghstack dependencies: #164738
Adds a [control_deps](https://en.wikipedia.org/wiki/Control_dependency) higher-order operator to enforce explicit scheduling dependencies in FX graphs. This prevents unwanted operation reordering/fusion by giving nodes additional dependencies, which we also respect in inductor by adding weak deps on them.
This can be generally useful (such as for ordering collectives), but in this case I am using it so that fusions do not interfere with comm-compute overlap planned at the aten level.
There's definitely some similarity with the `with_effects` hop. Talked with @angelayi - when @zou3519 is back we will figure out how we want to consolidate.
The implementation needs to be a subgraph (as opposed to `with_effects`) because inductor relies on `V.graph.current_node`. Changing the signature of the node with `with_effects` breaks this, and additionally breaks striding constraints on the wrapped node - see this [TODO](aed66248a0/torch/fx/experimental/proxy_tensor.py (L1246-L1249)). By maintaining the node with its original calling structure inside a subgraph, this all works.
Example transformation:
Before:
```
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 1), kwargs = {})
%mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %arg1_1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 2), kwargs = {})
```
After:
```
add: "f32[256, 256]" = torch.ops.aten.add.Tensor(arg0_1, 1)
mm: "f32[256, 256]" = torch.ops.higher_order.control_deps((add,), subgraph_mm, arg1_1, arg1_1)
mul: "f32[256, 256]" = torch.ops.higher_order.control_deps((mm,), subgraph_mul, add)
```
The mm operation now explicitly depends on add completing first, and mul depends on mm, with original operations preserved in subgraphs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164568
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
Add a deterministic mode that skips the on-device benchmarking that we know can affect numerics. This includes:
- pad-mm
- dynamic rblock scaling
- template autotuning
- coordinate descent tuning for reduction
- reduction config autotuning in CachingAutotuner. For reductions, both RBLOCK and num_warps can affect numerics; XBLOCK does not, so we can still autotune XBLOCK for reductions.
- benchmarking for computation communication reordering pass
The mode definitely has a perf hit.
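A hedged usage sketch - I'm assuming the mode is exposed as an inductor config flag named `deterministic`; the exact name may differ:
```python
import torch

torch._inductor.config.deterministic = True  # assumed flag name for the new mode

@torch.compile
def f(x, y):
    return (x @ y).relu().sum(dim=-1)

# With the mode on, compiled numerics should not depend on on-device
# benchmarking/autotuning results (at some perf cost).
out = f(torch.randn(64, 64, device="cuda"), torch.randn(64, 64, device="cuda"))
```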
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163589
Approved by: https://github.com/v0i0
tl;dr: performs bucketing while preserving comm-compute overlap.
In comm-compute overlap we will have a graph with:
```
def foo(...):
ag = all_gather(...)
hiding_compute = mm(...)
wait(ag)
```
There is no explicit dependency between the hiding compute and the collectives, but we want to add implicit dependencies from wait->hiding_compute, and from hiding_compute->all_gather to preserve overlap.
Additionally, while bucketing, we will merge collective starts and collective waits together. In this case, we will want to treat the two nodes as a single subgraph - each node in the merged set will have the union of all deps in the set.
We perform bucketing while augmenting the graph with these relationships. This can be done separately from comm-compute overlap, so long as the hiding-compute relationships are passed in.
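A minimal sketch (assumed names) of how those implicit edges might be represented for the example above, and how merged bucket nodes take the union of their members' deps:
```python
# Extra scheduling edges for the `foo` example: node -> nodes it must follow.
hiding_deps = {
    "hiding_compute": {"ag"},        # hiding_compute must start after the all_gather start
    "wait_ag": {"hiding_compute"},   # the wait must come after the hiding compute
}

def merged_node_deps(merged_members, deps):
    # A bucketed start (or wait) node inherits the union of all members' deps.
    return set().union(*(deps.get(m, set()) for m in merged_members))
```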
TODO:
- Need to instrument the fx graph so inductor respects these relationships.
- The compile time of the bucketing search can be sped up significantly by limiting which portion of the graph we traverse.
- More memory-aware handling.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163960
Approved by: https://github.com/ruisizhang123, https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754, #163959
This is the first part of the stack that does comm/compute reordering and then uses the exposure analysis to do bucketing.
Subsequent PRs will handle:
- use of exposure analysis to do bucketing
- make sure inductor respects comm/compute overlapping done at fx level
- non-profiling mm estimation/rank broadcasting of profile results
Other misc:
- Validate accuracy of nccl estimations (use ruisi's profiling instead?)
For a llama 2d parallelism test, on forward we overlap all but 2 of the potentially hidden collectives. For backward, we overlap 217/269 of the potentially hidden collectives. If you increase `compute_overlap_multipler` (a fudge factor for inaccurate comms estimation), that goes down to all but 16 of the potentially hidden collectives.
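A minimal sketch of how such a fudge factor could gate the overlap decision (assumed names; exactly how the real config applies the multiplier may differ):
```python
def is_hidden(collective_est_ms: float, hiding_compute_est_ms: float,
              compute_overlap_multipler: float = 1.0) -> bool:
    # A larger multiplier demands more compute per unit of estimated comm time,
    # making the pass more conservative when comm estimates may be too low.
    return collective_est_ms * compute_overlap_multipler <= hiding_compute_est_ms
```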
fwd example: https://gist.github.com/eellison/76209c49d8829c5f1e323d34a3f040c3
bwd example: https://gist.github.com/eellison/6cfc2285df53a94cfa4012f5fdae5c51
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163215
Approved by: https://github.com/IvanKobzarev
Summary:
Running a toy example through `torch.compile(fullgraph=True, backend="inductor")` with the default inductor config, I tried to see what passes are run in each of the pre-grad, joint-graph, and post-grad phases by printing out the subsystem in `GraphTransformObserver`. However, the subsystem showed up as None for a bunch of transforms run in each of those phases, so this PR adds some additional annotations.
Note that these annotations are probably not a complete set, since other transforms may run based on changes to the config that are not covered here.
Hopefully this doesn't change behavior. However, I did notice that bisecting relies on disabling various phases, which means that some passes that previously would *not* be disabled (because their subsystem was `None`) now will be.
Test Plan: existing tests + manual test described in summary
Differential Revision: D83306676
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163922
Approved by: https://github.com/jansel
Current state: a shape mismatch failure when mm+rs (reduce_scatter) scatters on the last mm dim.
This adds a separate path to handle the last dim for aten.mm; scaled_mm should be handled similarly but needs an additional PR,
so the scaled_mm case is disabled via the matmul filter function.
Adds an inductor.config option for this change, True by default, for fast debuggability of the new path.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162794
Approved by: https://github.com/fegin
Summary:
What: Enables CUDA support for the concat-linear int8_mm WOQ optimization pattern by:
- Updating pattern validation to accept CUDA devices
- Adding test coverage for CUDA
Why: Extend WOQ to more device types
Test Plan:
```
buck2 run 'fbcode//mode/opt' //caffe2/test/inductor:cuda_select_algorithm
```
Rollback Plan:
Differential Revision: D80884518
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161848
Approved by: https://github.com/jerryzh168
Summary:
What: Enables CUDA support for the int8_mm WOQ optimization pattern by:
- Fixing dtype conversion in weight_int8pack_mm_kernel to match CPU
- Updating pattern validation to accept CUDA devices
- Adding test coverage for CUDA
Why: Extend WOQ to more device types
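For context, a rough sketch of the dequantize-then-matmul pattern this targets, in eager form with assumed shapes (the pass itself matches the inductor graph and rewrites it to the fused int8-weight mm kernel, e.g. `torch._weight_int8pack_mm`):
```python
import torch

def woq_int8_linear(x: torch.Tensor, w_int8: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # x: [B, K] bf16/fp16 activation; w_int8: [N, K] int8 weight; scales: [N] per-channel scales.
    return (x @ w_int8.to(x.dtype).t()) * scales
```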
Test Plan:
```
buck2 run 'fbcode//mode/opt' //caffe2/test/inductor:cuda_select_algorithm
```
Rollback Plan:
Reviewed By: jerryzh168
Differential Revision: D80882442
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161680
Approved by: https://github.com/jerryzh168