Bucketing a number of smallish improvements:
- Account for bucketing in overlap calculation: if an in-flight collective exists with the same bucket key, reduce new collectives estimated time by its latency time
- Update compute domination so we are ordering based on compute idx, as opposed to compute depth, so we never reorder compute. this makes it a bit easier to reason about memory, and pre-fetching, although we can exploring reordering in the future.
- When we wait on a collective, force all collectives on the same process group as it that were enqueued prior to the collective to wait as well.
Better Memory Handling:
- Pre-fetch limiting - when scheduling collectives for overlap, only pre-fetch up to a certain distance, then schedule off-path collectives (which are typically memory reducing).
- When we are above peak memory, schedule waits.
TODO:
- for each compute node, we know its original memory in the graph. we could limit pre-fetching that goes across peak memory
- By scheduling off-path collectives for overlap, we reduce memory, but if there weren't enough compute for overlap, we need to proactively schedule them. not an issue yet on examples.
- config some hard coded constants, clean up enablement (can do in subsequent pr)
On small llama 2d backward :
578 of 618 potentially hideable collectives hidden
original mem 14.4GB, rescheduled mem, 15.9GB
on forward:
254/256 potentially hideable collectives hidden
original mem 5.8 gb, reshceduled mem 5.8GB
WIP: adding tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165318
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945, #165059
This is first part of the stack that does comm/compute reordering, and then uses the exposure analysis to do bucketing.
Subsequent prs will handle:
- use of exposure analysis to do bucketing
- make sure inductor respects comm/compute overlapping done at fx level
- non-profiling mm estimation/rank broadcasting of profile results
Other mis:
- Validate accuracy of nccl estimations ( use ruisi's profiling instead ?)
For a llama 2d parallelism test, on forward, we overlap all but 2 of potentially hidden collectives. For backward, we overlap 217/269 of potentially hidden collectives. If you increase `compute_overlap_multipler` (for fudge factor of inaccurate comms estimation), that goes down to all but 16 of potentially hidden collectives.
fwd example: https://gist.github.com/eellison/76209c49d8829c5f1e323d34a3f040c3
bwd example: https://gist.github.com/eellison/6cfc2285df53a94cfa4012f5fdae5c51
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163215
Approved by: https://github.com/IvanKobzarev
This is first part of the stack that does comm/compute reordering, and then uses the exposure analysis to do bucketing.
Subsequent prs will handle:
- use of exposure analysis to do bucketing
- make sure inductor respects comm/compute overlapping done at fx level
- non-profiling mm estimation/rank broadcasting of profile results
Other mis:
- Validate accuracy of nccl estimations ( use ruisi's profiling instead ?)
For a llama 2d parallelism test, on forward, we overlap all but 2 of potentially hidden collectives. For backward, we overlap 217/269 of potentially hidden collectives. If you increase `compute_overlap_multipler` (for fudge factor of inaccurate comms estimation), that goes down to all but 16 of potentially hidden collectives.
fwd example: https://gist.github.com/eellison/76209c49d8829c5f1e323d34a3f040c3
bwd example: https://gist.github.com/eellison/6cfc2285df53a94cfa4012f5fdae5c51
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163215
Approved by: https://github.com/IvanKobzarev
Implementations:
1. Move collective ops to c10d namespace, so that we can call them externally.
2. Add AOTI shims for collective ops.
Testing
1. Add c10d functional UT for cpu.
2. Include the above one in cpp wrapper UT.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154492
Approved by: https://github.com/desertfire
We observed signficant compile time regression in torchtitan when turning
on 2D parallel + torch.compile recently. So I decided to get a deeper
understanding why.
It turns out this is affecting **all the trainings** that have functional collectives
captured in the graph, not only 2D parallel (2D parallel was just the
job that happen to have collectives captured in the TP region).
The root cause is because when doing inductor lowering, we are calling
the comm analysis pass to get a estimated collective time for each
collective node in the graph, for each call to check the collective
node, we are calling `get_gpu_type()`, which under the hood calls a
`torch.utils.collect_env.run` to get the GPU info. However, this call is
super expensive! The reason is that this call effectively spawns a new
process and call `nvidia-smi` to get the GPU info, so the cost is **linear**
to the number of collective nodes in the graph.
see https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py#L75
The fix is to add a lru cache to the function, so that we only call this
once and reuse the cached results afterwards
torchtitan benchmark shows:
* before this fix: 2D parallel + fp8 compile time: 6min +
* after this fix: 2D parallel + fp8 compile time: 2min 48s (more than 100% improvement)
There're more room to improve the compile time, but this PR is trying to fix the biggest regression I found so far.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128363
Approved by: https://github.com/yf225
This PR completely removes the Inductor IR for legacy functional collectives:
- Removed the `CollectiveKernel` hiearchy and `Wait`, as well as the corresponding lowerings. These IRs are target (i.e. Python) specific and don't model node dependencies propoerly (e.g. they rely on `never_reuse_buffers` for correct behavior). They've been superceded by `ir._CollectiveKernel`.
- Removed `InPlaceHint` and the scheduler logic for handling it. `InPlaceHint` is a codegen-time buffer reuse mechanism controlled by the IR's codegen. It's a bit hacky and overlaps with the default buffer reuse mechanism. Removing it since it is only used by legacy functional collectives.
- Removed `OutputBuffer` and `MultiOutputNoSizeAssert` which are designed for and only used by legacy functional collectives.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124992
Approved by: https://github.com/Chillee, https://github.com/wanchaol
Adds a ruff lint rule to ban raising raw exceptions. Most of these should at the very least be runtime exception, value errors, type errors or some other errors. There are hundreds of instance of these bad exception types already in the codebase, so I have noqa'd most of them. Hopefully this error code will get commiters to rethink what exception type they should raise when they submit a PR.
I also encourage people to gradually go and fix all the existing noqas that have been added so they can be removed overtime and our exception typing can be improved.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124570
Approved by: https://github.com/ezyang
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
This PR implements intra-graph communication reordering pass on Inductor scheduler IR, based on Horace's previous PR #100762.
Main algorithm:
1. Greedily moves waits as late as possible (i.e. until we reach a use)
2. Greedily moves comms as early as possible (i.e. until we reach an input)
3. Move computes following simple heuristics to improve overlap.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108091
Approved by: https://github.com/Chillee, https://github.com/wanchaol