Commit Graph

24 Commits

Author SHA1 Message Date
3255e7872b Enable all flake8-logging-format rules (#164655)
These rules are enabled by removing existing suppressions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164655
Approved by: https://github.com/janeyx99, https://github.com/mlazos
2025-10-19 00:59:28 +00:00
b3f6d49b69 Overlap scheduler improvements (#165318)
Bucketing a number of smallish improvements:

- Account for bucketing in overlap calculation: if an in-flight collective exists with the same bucket key, reduce new collectives estimated time by its latency time
-  Update compute domination so we are ordering based on compute idx, as opposed to compute depth, so we never reorder compute. this makes it a bit easier to reason about memory, and pre-fetching, although we can exploring reordering in the future.
- When we wait on a collective, force all collectives on the same process group as it that were enqueued prior to the collective to wait as well.

Better Memory Handling:
- Pre-fetch limiting - when scheduling collectives for overlap, only pre-fetch up to a certain distance, then schedule off-path collectives (which are typically memory reducing).
- When we are above peak memory, schedule waits.

TODO:
- for each compute node, we know its original memory in the graph. we could limit pre-fetching that goes across peak memory
- By scheduling off-path collectives for overlap, we reduce memory, but if there weren't enough compute for overlap, we need to proactively schedule them. not an issue yet on examples.
- config some hard coded constants, clean up enablement (can do in subsequent pr)

On small llama 2d backward :
578 of 618 potentially hideable collectives hidden
original mem 14.4GB, rescheduled mem, 15.9GB

on forward:
254/256 potentially hideable collectives hidden
original mem 5.8 gb, reshceduled mem 5.8GB

WIP: adding tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165318
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945, #165059
2025-10-15 21:58:47 +00:00
0d7994ca97 [inductor] do comm compute overlap at aten fx level (#163215)
This is first part of the stack that does comm/compute reordering, and then uses the exposure analysis to do bucketing.

Subsequent prs will handle:
- use of exposure analysis to do bucketing
- make sure inductor respects comm/compute overlapping done at fx level
- non-profiling mm estimation/rank broadcasting of profile results

Other mis:
- Validate accuracy of nccl estimations  ( use ruisi's profiling instead ?)

For a llama 2d parallelism test, on forward, we overlap all but 2 of potentially hidden collectives. For backward, we overlap 217/269 of potentially hidden collectives. If you increase `compute_overlap_multipler` (for fudge factor of inaccurate comms estimation), that goes down to all but 16 of potentially hidden collectives.

fwd example: https://gist.github.com/eellison/76209c49d8829c5f1e323d34a3f040c3

bwd example: https://gist.github.com/eellison/6cfc2285df53a94cfa4012f5fdae5c51

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163215
Approved by: https://github.com/IvanKobzarev
2025-09-30 04:53:58 +00:00
0f619c1f89 Revert "[inductor] do comm compute overlap at aten fx level (#163215)"
This reverts commit c9b5af9a384e7ef5f95613abe1622f5f55133c3a.

Reverted https://github.com/pytorch/pytorch/pull/163215 on behalf of https://github.com/yangw-dev due to seems fails inductor/test_aten_comm_compute_reordering for macos test, see c9b5af9a38 (51526707590-box) ([comment](https://github.com/pytorch/pytorch/pull/163215#issuecomment-3349177940))
2025-09-29 21:53:42 +00:00
c9b5af9a38 [inductor] do comm compute overlap at aten fx level (#163215)
This is first part of the stack that does comm/compute reordering, and then uses the exposure analysis to do bucketing.

Subsequent prs will handle:
- use of exposure analysis to do bucketing
- make sure inductor respects comm/compute overlapping done at fx level
- non-profiling mm estimation/rank broadcasting of profile results

Other mis:
- Validate accuracy of nccl estimations  ( use ruisi's profiling instead ?)

For a llama 2d parallelism test, on forward, we overlap all but 2 of potentially hidden collectives. For backward, we overlap 217/269 of potentially hidden collectives. If you increase `compute_overlap_multipler` (for fudge factor of inaccurate comms estimation), that goes down to all but 16 of potentially hidden collectives.

fwd example: https://gist.github.com/eellison/76209c49d8829c5f1e323d34a3f040c3

bwd example: https://gist.github.com/eellison/6cfc2285df53a94cfa4012f5fdae5c51

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163215
Approved by: https://github.com/IvanKobzarev
2025-09-29 18:18:03 +00:00
25c170b72e [inductor] Runtime estimations: use nccl estimator; mm only benchmark mode (#161405)
During comms reordering , sink wait iterative observed previous runtime estimations pretty off for collectives and mms.

Adding optional usage of:
- c10d.time_estimator for collectives, which is based on NCCL estimator

Benchmark mode only for matmuls, as they are highly dependent on mm backend

- The logic mostly copied from Ruisi's PRs for inductor simple_fsdp https://github.com/pytorch/pytorch/pull/157572

This estimations corrections are in default `BaseSchedulerNode.estimate_runtime()`

Differential Revision: [D81152294](https://our.internmc.facebook.com/intern/diff/D81152294)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161405
Approved by: https://github.com/eellison
2025-09-08 14:33:19 +00:00
02c7ab2f9b [cpp wrapper] add AOTI shim for collective ops (#154492)
Implementations:
1. Move collective ops to c10d namespace, so that we can call them externally.
2. Add AOTI shims for collective ops.

Testing
1. Add c10d functional UT for cpu.
2. Include the above one in cpp wrapper UT.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154492
Approved by: https://github.com/desertfire
2025-06-25 01:20:05 +00:00
b6d477fd56 [BE][Easy][16/19] enforce style for empty lines in import segments in torch/_i*/ (#129768)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129768
Approved by: https://github.com/jansel
2024-07-20 16:20:58 +00:00
8a09940a54 [inductor] fix compile time regression by caching get_gpu_type (#128363)
We observed signficant compile time regression in torchtitan when turning
on 2D parallel + torch.compile recently. So I decided to get a deeper
understanding why.

It turns out this is affecting **all the trainings** that have functional collectives
captured in the graph, not only 2D parallel (2D parallel was just the
job that happen to have collectives captured in the TP region).

The root cause is because when doing inductor lowering, we are calling
the comm analysis pass to get a estimated collective time for each
collective node in the graph, for each call to check the collective
node, we are calling `get_gpu_type()`, which under the hood calls a
`torch.utils.collect_env.run` to get the GPU info. However, this call is
super expensive! The reason is that this call effectively spawns a new
process and call `nvidia-smi` to get the GPU info, so the cost is **linear**
to the number of collective nodes in the graph.

see https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py#L75

The fix is to add a lru cache to the function, so that we only call this
once and reuse the cached results afterwards

torchtitan benchmark shows:
* before this fix: 2D parallel + fp8 compile time: 6min +
* after this fix: 2D parallel + fp8 compile time: 2min 48s (more than 100% improvement)

There're more room to improve the compile time, but this PR is trying to fix the biggest regression I found so far.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128363
Approved by: https://github.com/yf225
2024-06-11 18:02:13 +00:00
df0c69f32d [inductor] Add fallback for collectives size estimation for unbacked (#127562)
Differential Revision: [D57982928](https://our.internmc.facebook.com/intern/diff/D57982928)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127562
Approved by: https://github.com/yifuwang
2024-05-31 11:15:46 +00:00
58d8388ed3 Remove Inductor IRs for legacy functional collectives (#124992)
This PR completely removes the Inductor IR for legacy functional collectives:
- Removed the `CollectiveKernel` hiearchy and `Wait`, as well as the corresponding lowerings. These IRs are target (i.e. Python) specific and don't model node dependencies propoerly (e.g. they rely on `never_reuse_buffers` for correct behavior). They've been superceded by `ir._CollectiveKernel`.
- Removed `InPlaceHint` and the scheduler logic for handling it. `InPlaceHint` is a codegen-time buffer reuse mechanism controlled by the IR's codegen. It's a bit hacky and overlaps with the default buffer reuse mechanism. Removing it since it is only used by legacy functional collectives.
- Removed `OutputBuffer` and `MultiOutputNoSizeAssert` which are designed for and only used by legacy functional collectives.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124992
Approved by: https://github.com/Chillee, https://github.com/wanchaol
2024-05-05 19:49:58 +00:00
c5fafe9f48 [BE]: TRY002 - Ban raising vanilla exceptions (#124570)
Adds a ruff lint rule to ban raising raw exceptions. Most of these should at the very least be runtime exception, value errors, type errors or some other errors. There are hundreds of instance of these bad exception types already in the codebase, so I have noqa'd most of them. Hopefully this error code will get commiters to rethink what exception type they should raise when they submit a PR.

I also encourage people to gradually go and fix all the existing noqas that have been added so they can be removed overtime and our exception typing can be improved.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124570
Approved by: https://github.com/ezyang
2024-04-21 22:26:40 +00:00
71b8363f40 [inductor] Remove unused local variable. (#120227)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120227
Approved by: https://github.com/Skylion007
2024-04-13 03:19:13 +00:00
27ffede878 [reland] Fix estimate_nccl_collective_runtime (#118986)
`estimate_nccl_collective_runtime` has been broken and the errors have been silently swallowed by inductor. This PR:
- Fixes the issues described in https://github.com/pytorch/pytorch/issues/118497.
- Adds white-box testing so future issues can be surfaced in tests.
- Add support for native funcol IRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118986
Approved by: https://github.com/yf225
ghstack dependencies: #119102
2024-02-12 18:48:06 +00:00
7315ec7505 Revert "Fix estimate_nccl_collective_runtime (#118986)"
This reverts commit 0dab6fb35284ed47d1c6339e9d71e4ca3b50dc51.

Reverted https://github.com/pytorch/pytorch/pull/118986 on behalf of https://github.com/atalman due to Breaks internal tests ([comment](https://github.com/pytorch/pytorch/pull/118986#issuecomment-1934680463))
2024-02-08 18:11:53 +00:00
0dab6fb352 Fix estimate_nccl_collective_runtime (#118986)
`estimate_nccl_collective_runtime` has been broken and the errors have been silently swallowed by inductor. This PR:
- Fixes the issues described in https://github.com/pytorch/pytorch/issues/118497.
- Adds white-box testing so future issues can be surfaced in tests.
- Add support for native funcol IRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118986
Approved by: https://github.com/yf225
ghstack dependencies: #118910, #118911, #118437
2024-02-07 18:02:51 +00:00
4f5785b6b3 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 21:07:01 +00:00
40ece2e579 Revert "Enable possibly-undefined error code (#118533)"
This reverts commit 4f13f69a45ef53747e2eefffd65d91ce840b431b.

Reverted https://github.com/pytorch/pytorch/pull/118533 on behalf of https://github.com/clee2000 due to sorry i'm trying to figure out a codev merge conflict, if this works i'll be back to rebase and merge ([comment](https://github.com/pytorch/pytorch/pull/118533#issuecomment-1917695185))
2024-01-30 19:00:34 +00:00
4f13f69a45 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 05:08:10 +00:00
4667e20b3f Delete a bunch of type-ignores (#113990)
* Replaced `ignore[import]` by mypy config file entries
* Removed a bunch of ignores around previously-fixed attr-defined /
  call-arg issues
* Fixed some invalid / undefined types; added a few more type-ignores to
  squelch the downstream errors this exposed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113990
Approved by: https://github.com/eellison, https://github.com/Skylion007
ghstack dependencies: #113979
2023-11-18 02:48:38 +00:00
498a760802 Update comm_analysis.py license (#113184)
Consulted with legal, this is the right way to do it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113184
Approved by: https://github.com/Chillee, https://github.com/malfet
2023-11-07 22:58:56 +00:00
e9804aaacc Fix unit tests and add logging for Inductor intra-graph reordering (#111981)
1. Fix code to make unit tests pass (incl. collect_env issue called out by @int3  in https://github.com/pytorch/pytorch/pull/108091#discussion_r1362901686).
2. Add logging for Inductor intra-graph reordering passes (`TORCH_LOGS="overlap"`), for easier debugging. Example log:
```
[rank0]:[2023-10-24 16:28:26,446] [0/0] torch._inductor.comms.__overlap: [DEBUG] ==== Visualize overlap before reordering pass <function reorder_compute_for_overlap at 0x7fa68c5568e0> ====
[rank0]:[2023-10-24 16:28:26,446] [0/0] torch._inductor.comms.__overlap: [DEBUG] ComputedBuffer (size=[4, 4], stride=[4, 1]) (buf0)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] ExternKernelOut (extern_kernels.mm) (size=[4, 4], stride=[4, 1]) (buf1)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] InPlaceHint (size=[4, 4], stride=[4, 1]) (buf2)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] AllReduce (size=[4, 4], stride=[4, 1]) (buf3)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] Wait (size=[4, 4], stride=[4, 1]) (buf4)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] ComputedBuffer (size=[4, 4], stride=[4, 1]) (buf5)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] InPlaceHint (size=[4, 4], stride=[4, 1]) (buf6)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] AllReduce (size=[4, 4], stride=[4, 1]) (buf7)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] Wait (size=[4, 4], stride=[4, 1]) (buf8)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] ExternKernelOut (extern_kernels.mm) (size=[4, 4], stride=[4, 1]) (buf9)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] ComputedBuffer (size=[4, 4], stride=[4, 1]) (buf10)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] ExternKernelOut (extern_kernels.mm) (size=[4, 4], stride=[4, 1]) (buf11)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] Est. runtime (ms): 0.000228

[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] ==== Visualize overlap after reordering pass <function reorder_compute_for_overlap at 0x7fa68c5568e0> ====
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] InPlaceHint (size=[4, 4], stride=[4, 1]) (buf2)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] AllReduce (size=[4, 4], stride=[4, 1]) (buf3)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] | ComputedBuffer (size=[4, 4], stride=[4, 1]) (buf0)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] | ExternKernelOut (extern_kernels.mm) (size=[4, 4], stride=[4, 1]) (buf1)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] | ExternKernelOut (extern_kernels.mm) (size=[4, 4], stride=[4, 1]) (buf9)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] Wait (size=[4, 4], stride=[4, 1]) (buf4)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] ComputedBuffer (size=[4, 4], stride=[4, 1]) (buf5)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] InPlaceHint (size=[4, 4], stride=[4, 1]) (buf6)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] AllReduce (size=[4, 4], stride=[4, 1]) (buf7)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] Wait (size=[4, 4], stride=[4, 1]) (buf8)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] ComputedBuffer (size=[4, 4], stride=[4, 1]) (buf10)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] ExternKernelOut (extern_kernels.mm) (size=[4, 4], stride=[4, 1]) (buf11)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] Est. runtime (ms): 0.000217
```
The `| SomeComputeOp` means the compute op is overlapped with the comm op above.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111981
Approved by: https://github.com/wanchaol
2023-10-25 18:19:43 +00:00
4c6e85365f Add NVIDIA license to comm_analysis.py (#111670)
We adapted the cost model from NCCL code, we should apply their license here as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111670
Approved by: https://github.com/Chillee, https://github.com/wanchaol
2023-10-20 21:34:35 +00:00
b28cb43f5c Intra-graph reordering pass on Inductor scheduler IR (based on #100762) (#108091)
This PR implements intra-graph communication reordering pass on Inductor scheduler IR, based on Horace's previous PR #100762.

Main algorithm:
1. Greedily moves waits as late as possible (i.e. until we reach a use)
2. Greedily moves comms as early as possible (i.e. until we reach an input)
3. Move computes following simple heuristics to improve overlap.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108091
Approved by: https://github.com/Chillee, https://github.com/wanchaol
2023-10-14 14:51:24 +00:00