Compare commits

...

32 Commits

Author SHA1 Message Date
5c30fb7061 update triton commit hash 2025-11-12 00:27:46 +00:00
e8d411e7f7 FSDPMemTracker fix with multihander hooks. (#165662)
Fixes #164663

## Issue
The torch model with multiple layers that is wrapped with fsdp2 registers pre and post forward hooks in a group using `_MultiHandler`. This becomes an issue during the context manager of the tracker where the hooks are reset and replaced. The hooks are all using the same fsdp state pointer so one reset will reset all.  So when the output layer was modified with a new pre and post forward hook it would delete the previous layer's initialization causing `KeyError` for the Norm layer as it is nonexistent.

## The Fix
Check to see if there are multiple `_MultiHandler` objects and `RemoveHandler` objects and only execute the remove hook once.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165662
Approved by: https://github.com/sanketpurandare
2025-11-11 23:49:36 +00:00
2e5233d7bd Revert "Support AC in default partitioner when functionalization is enabled (#166610)"
This reverts commit de773364be041ca7fd2dcaf35ca15c093fc9370b.

Reverted https://github.com/pytorch/pytorch/pull/166610 on behalf of https://github.com/soulitzer due to breaking internal tests ([comment](https://github.com/pytorch/pytorch/pull/166610#issuecomment-3519047226))
2025-11-11 23:01:09 +00:00
514dd96376 Remove --no-use-pep517 flag (#167096)
In pip 25.3 and newer, use of --no-use-pep517 has been removed (https://pip.pypa.io/en/stable/news/). In builds with pip 25.2, a warning message notes:

> DEPRECATION: Building 'torchvision' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to use the standardized build interface by setting the `--use-pep517` option, (possibly combined with `--no-build-isolation`), or adding a `pyproject.toml` file to the source tree of 'torchvision'. Discussion can be found at https://github.com/pypa/pip/issues/6334

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167096
Approved by: https://github.com/atalman
2025-11-11 23:00:35 +00:00
9ae62fcc18 [ROCm][CI] dynamo benchmarks update ci expected accuracy (#167574)
repvgg_a2 IMPROVED: accuracy=pass, expected=fail_accuracy

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167574
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-11-11 22:54:55 +00:00
ae71b0e163 Fix typo in torch._refs (#167310)
Should be a typo here, but it doesn't raise an error because the inner function splits it into `a` and `,`, and the `,` case check is skipped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167310
Approved by: https://github.com/eellison
2025-11-11 22:31:09 +00:00
5b6ff8148d Revert "[ARM] Improve LLM performance & mem usage using int4-bf16 KleidiAI kernels (#158250)"
This reverts commit 402c46503002f98ccfc023a733081fb0719223a1.

Reverted https://github.com/pytorch/pytorch/pull/158250 on behalf of https://github.com/izaitsevfb due to Broke some torch.compile jobs ([comment](https://github.com/pytorch/pytorch/pull/158250#issuecomment-3518944863))
2025-11-11 22:27:51 +00:00
1f7e4343e7 [ROCm][CI] Add docker-cache-rocm.yml to test MI3xx CI docker caching (#167554)
* Trigger this workflow on every completed run of `docker-builds.yml`
* Uses `ubuntu-latest` for downloading artifacts from `docker-build` workflow run
* Uses `linux.rocm.gfx942.docker-cache` to cache docker images as tarballs for MI3xx CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167554
Approved by: https://github.com/jeffdaily
2025-11-11 21:32:22 +00:00
b21856f5fc Revert "[DebugMode] record triton kernels, run-to-run determinism checks (#167028)"
This reverts commit 259ba0ecabd809edd35d12b4f992777cb5923b68.

Reverted https://github.com/pytorch/pytorch/pull/167028 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167028#issuecomment-3518811298))
2025-11-11 21:31:12 +00:00
259ba0ecab [DebugMode] record triton kernels, run-to-run determinism checks (#167028)
Following up on https://github.com/pytorch/pytorch/pull/166348, extends DebugMode to capture inductor triton kernels at runtime, and adds an API for checking run-to-run determinism based on tensor hashes.

The workflow looks something like...
```python
# do 1st run with hashes, get logs
with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
    compiled_model(*inputs)
logs1 = debug_mode.logs

# do 2nd run
with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
    compiled_model(*inputs)
logs2 = debug_mode.logs

# returns list of calls w/ mismatched outputs
mismatches = DebugMode.check_hash_mismatches(logs1, logs2)
```

Example dump off a smaller version of @drisspg's FlexAttention fwd+bwd determinism tests [script](https://gist.github.com/pianpwk/f65cc63811d12853709dcc77d7eb69f1) (without forced reduction order):
```
cfg: TestConfig(name='Standard', B=2, Hq=32, Hkv=32, Q=2048, KV=2048, Dqk=128, Dv=128)
DETERMINISM: fwd: True, bwd_q: False, bwd_k: False, bwd_v: True

$$$ DEBUG MODE DUMP $$$  (this is what the logs look like)

    [triton] triton_tem_fused_0(arg_Q=t: bf16[2, 32, 2048, 128], arg_K=t: bf16[2, 32, 2048, 128], arg_V=t: bf16[2, 32, 2048, 128], arg_LSE=t: f32[2, 32, 2048], arg_MAX=t: f32[2, 32, 2048], arg_KV_NUM_BLKS=t: i32[2, 32, 16], arg_KV_IDX=t: i32[2, 32, 16, 16], arg_FULL_KV_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_KV_IDX=t: i32[2, 32, 16, 16], out_ptr0=t: bf16[2, 32, 2048, 128])
    # post-kernel hashes: {arg_Q: 13385916.068706088, arg_K: 13389356.409105342, arg_V: 13384993.48412523, arg_LSE: 1347168.9026973695, arg_MAX: 81775.3811062593, arg_KV_NUM_BLKS: 1024.0, arg_KV_IDX: 122880.0, arg_FULL_KV_NUM_BLKS: 7680.0, arg_FULL_KV_IDX: 122880.0, out_ptr0: 924917.7918248245}

    [triton] triton_per_fused_zeros_0(in_ptr0=t: bf16[2, 32, 2048, 128], in_ptr1=t: bf16[2, 32, 2048, 128], out_ptr1=t: f32[2, 32, 2048], xnumel=131072, r0_numel=128)
    # post-kernel hashes: {in_ptr0: 924917.7918248245, in_ptr1: 13389213.797377996, out_ptr1: 81775.38106592931}

    [triton] triton_tem_fused_zeros_1(arg_Q=t: bf16[2, 32, 2048, 128], arg_K=t: bf16[2, 32, 2048, 128], arg_V=t: bf16[2, 32, 2048, 128], arg_LSE=t: f32[2, 32, 2048], arg_DELTA=t: f32[2, 32, 2048], arg_DO=t: bf16[2, 32, 2048, 128], arg_DQ=t: bf16[2, 32, 2048, 128], arg_DV=t: bf16[2, 32, 2048, 128], arg_KV_NUM_BLKS=t: i32[2, 32, 16], arg_KV_IDX=t: i32[2, 32, 16, 16], arg_Q_NUM_BLKS=t: i32[2, 32, 16], arg_Q_IDX=t: i32[2, 32, 16, 16], arg_FULL_KV_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_KV_IDX=t: i32[2, 32, 16, 16], arg_FULL_Q_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_Q_IDX=t: i32[2, 32, 16, 16], out_ptr0=t: bf16[2, 32, 2048, 128])
    # post-kernel hashes: {arg_Q: 13385916.068706088, arg_K: 13389356.409105342, arg_V: 13384993.48412523, arg_LSE: 1347168.9026973695, arg_DELTA: 81775.38106592931, arg_DO: 13389213.797377996, arg_DQ: 874474.8084187683, arg_DV: 727742.3138379117, arg_KV_NUM_BLKS: 1024.0, arg_KV_IDX: 122880.0, arg_Q_NUM_BLKS: 1024.0, arg_Q_IDX: 122880.0, arg_FULL_KV_NUM_BLKS: 7680.0, arg_FULL_KV_IDX: 122880.0, arg_FULL_Q_NUM_BLKS: 7680.0, arg_FULL_Q_IDX: 122880.0, out_ptr0: 700542.3431890717}

$$$ MISMATCHES $$$
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_0', 'arg_name': 'arg_MAX', 'pytree_path': None, 'hash1': 0.0, 'hash2': 81775.3811062593, 'rel_diff': 1.0, 'is_input_hash': False}  # I guess this one is misleading? not sure if I'm doing something wrong with waiting for kernel results
mismatch: {'call_type': 'triton kernel', 'call': 'triton_per_fused_zeros_0', 'arg_name': 'out_ptr1', 'pytree_path': None, 'hash1': 81775.3811062593, 'hash2': 81775.38106592931, 'rel_diff': 4.931801261646669e-10, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'arg_DELTA', 'pytree_path': None, 'hash1': 81775.3811062593, 'hash2': 81775.38106592931, 'rel_diff': 4.931801261646669e-10, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'arg_DQ', 'pytree_path': None, 'hash1': 874474.8097136207, 'hash2': 874474.8084187683, 'rel_diff': 1.480720012120795e-09, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'out_ptr0', 'pytree_path': None, 'hash1': 700542.3488049245, 'hash2': 700542.3431890717, 'rel_diff': 8.016435812581196e-09, 'is_input_hash': False}
```

note: current hash implementation is basically tensor norm, so tensor closeness -> hash closeness. This is likely to change soon, e.g. maybe to `torch.hash_tensor` (https://github.com/pytorch/pytorch/pull/154149) by default

Sample paste diff between log dumps from 2 runs:
<img width="1665" height="445" alt="Screenshot 2025-11-05 at 11 27 24 PM" src="https://github.com/user-attachments/assets/41402e37-f50b-4a9e-a17c-bb98b5917076" />

Another case where running this for FSDP2 on Llama3-8B, helped narrow down divergence b/w aot_eager <-> inductor, to inductor's FWD RMSNorm kernels: P2027003180

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167028
Approved by: https://github.com/v0i0
2025-11-11 20:37:53 +00:00
051f1fe8e3 Revert "[ROCm][CI] Update docker-cache-mi300.yml to test MI300 CI docker caching (#167554)"
This reverts commit ee387c43feada1cc2049b42a970ec4e2f12f210e.

Reverted https://github.com/pytorch/pytorch/pull/167554 on behalf of https://github.com/jithunnair-amd due to workflow had failure 'Unexpected input(s) 'run_id'' ([comment](https://github.com/pytorch/pytorch/pull/167554#issuecomment-3518642191))
2025-11-11 20:34:44 +00:00
ee387c43fe [ROCm][CI] Update docker-cache-mi300.yml to test MI300 CI docker caching (#167554)
Trigger this workflow on every completed run of `docker-builds.yml` and run on `ubuntu-latest` so it doesn't queue infinitely for `rocm-docker` label

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167554
Approved by: https://github.com/jeffdaily
2025-11-11 19:49:00 +00:00
3a944661d6 Cpython test_math.FMATests (#167217)
Resolves issues running the dynamo cpython math.fma tests.

Though math.fma is enabled to perform a multiply add in dynamo, torch.addcmul is currently used which doesn't guarantee the user request for fma. It was decided to not use inductor fma prim as it would break the contract of using aten/core ir in dynamo output - otherwise export=True may have issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167217
Approved by: https://github.com/guilhermeleobas
2025-11-11 19:26:18 +00:00
56034074ca Revert "[Inductor] Naive foreach autotune support (#162053)"
This reverts commit 6c5db82584bf71f5b1db3b598bbd00f44140c28d.

Reverted https://github.com/pytorch/pytorch/pull/162053 on behalf of https://github.com/mlazos due to Sorry, there's an internal slowdown due to the extra triton configs you added ([comment](https://github.com/pytorch/pytorch/pull/162053#issuecomment-3518423369))
2025-11-11 19:23:40 +00:00
8def619bbe [user-streams] wait_stream op (#167512)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167512
Approved by: https://github.com/williamwen42
ghstack dependencies: #167510, #167511
2025-11-11 19:18:03 +00:00
61883a5787 [user-streams] Allow new streams to be created and registered during compilation (#167511)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167511
Approved by: https://github.com/williamwen42
ghstack dependencies: #167510
2025-11-11 19:18:03 +00:00
d8ada1ee76 [user-streams] Allow new events to be created and registered during compilation (#167510)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167510
Approved by: https://github.com/williamwen42
2025-11-11 19:18:03 +00:00
fe841a1db4 [DeviceMesh] Log DeviceMesh.__init__ usage (#167375)
Adds (meta-internal-only) API usage logging for DeviceMesh creation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167375
Approved by: https://github.com/fduwjj
ghstack dependencies: #167374
2025-11-11 19:15:47 +00:00
b65829b84f [DTensor] Log API usage metrics for DTensor and DeviceMesh (#167374)
Logging propagate_op_sharding_non_cached is a compromise between
 - logging in DTensor.__init__ to catch ALL DTensor usage
 - sparing the overhead in a latency-senstitive region like
   DTensor.__init__
 - and 'real' DTensor usage should incur at least one call to sharding
   propagation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167374
Approved by: https://github.com/zpcore
2025-11-11 19:15:47 +00:00
b0e0ae97ba include thrust/distance.h explicitly in cuda sparse softmax (#167436)
`thrust::distance` is defined there
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167436
Approved by: https://github.com/Skylion007
2025-11-11 19:10:55 +00:00
f44a1ddcb2 Revert "[ROCm][CI] Update docker-cache-mi300.yml to test MI300 CI docker caching (#167554)"
This reverts commit 184e2cbc89570e1bf466b15d70fc36ed71be0eb9.

Reverted https://github.com/pytorch/pytorch/pull/167554 on behalf of https://github.com/jithunnair-amd due to Need to fix lint ([comment](https://github.com/pytorch/pytorch/pull/167554#issuecomment-3518382341))
2025-11-11 19:09:45 +00:00
184e2cbc89 [ROCm][CI] Update docker-cache-mi300.yml to test MI300 CI docker caching (#167554)
Trigger this workflow on every completed run of `docker-builds.yml` and run on `ubuntu-latest` so it doesn't queue infinitely for `rocm-docker` label

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167554
Approved by: https://github.com/jeffdaily
2025-11-11 19:07:19 +00:00
416421c7c4 fix failure of exporting compiled model with nested dynamic shapes (#166358)
## Problems
When exporting a compiled model with nested input like below
```python
import torch
from torch.export import export, Dim

def test_export_compiled_model_with_nested_dynamic_shapes():
   """Test exporting a compiled model with nested dict inputs and dynamic shapes."""
   print("Running test_export_compiled_model_with_nested_dynamic_shapes...")

   class M(torch.nn.Module):
       def forward(self, data_batch):
           return data_batch["a1"] + data_batch["a2"]

   m = M()
   compiled_m = torch.compile(m)
   example_args = ({
       "a1": torch.ones(3, 3),
       "a2": torch.ones(3, 3),
   },)
   dynamic_shapes = ({
       "a1": {0: Dim.DYNAMIC},
       "a2": {0: Dim.DYNAMIC},
   },)

   try:
       ep = export(compiled_m, example_args, dynamic_shapes=dynamic_shapes, strict=True)
       gm = ep.module()
       result_exported = gm(*example_args)
       result_compiled = compiled_m(*example_args)

       assert torch.allclose(result_exported, result_compiled), "Results don't match!"
       print("✓ test_export_compiled_model_with_nested_dynamic_shapes PASSED")
       return True
   except Exception as e:
       print(f"✗ test_export_compiled_model_with_nested_dynamic_shapes FAILED")
       print(f"Error: {e}")
       import traceback
       traceback.print_exc()
       return False

def test_export_compiled_model_with_kwargs_dynamic_shapes():
   """Test exporting a compiled model with kwargs and dynamic shapes."""
   print("\nRunning test_export_compiled_model_with_kwargs_dynamic_shapes...")

   class M(torch.nn.Module):
       def forward(self, a1, a2):
           return a1 + a2

   m = M()
   compiled_m = torch.compile(m)
   example_args = ()
   example_kwargs = {
       "a1": torch.ones(3, 3),
       "a2": torch.ones(3, 3),
   }
   dynamic_shapes = {
       "a1": {0: Dim.DYNAMIC},
       "a2": {0: Dim.DYNAMIC},
   }

   try:
       ep = export(compiled_m, example_args, kwargs=example_kwargs, dynamic_shapes=dynamic_shapes, strict=True)
       gm = ep.module()
       result_exported = gm(**example_kwargs)
       result_compiled = compiled_m(**example_kwargs)

       assert torch.allclose(result_exported, result_compiled), "Results don't match!"
       print("✓ test_export_compiled_model_with_kwargs_dynamic_shapes PASSED")
       return True
   except Exception as e:
       print(f"✗ test_export_compiled_model_with_kwargs_dynamic_shapes FAILED")
       print(f"Error: {e}")
       import traceback
       traceback.print_exc()
       return False

if __name__ == "__main__":
   print("Testing export of compiled models with dynamic shapes\n")
   print("=" * 70)

   results = []
   results.append(test_export_compiled_model_with_nested_dynamic_shapes())
   results.append(test_export_compiled_model_with_kwargs_dynamic_shapes())

   print("\n" + "=" * 70)
   print(f"\nResults: {sum(results)}/{len(results)} tests passed")

   if all(results):
       print("✓ All tests passed!")
   else:
       print("✗ Some tests failed")
       exit(1)
```

It will report
```
======================================================================
Running test_export_compiled_model_with_nested_dynamic_shapes...
✗ test_export_compiled_model_with_nested_dynamic_shapes FAILED
Error: Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs[0]` is a <class 'tuple'>, but `dynamic_shapes[0]` is a <class 'dict'>
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
Traceback (most recent call last):
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 614, in _tree_map_with_path
    return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 2055, in tree_map_with_path
    all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 2055, in <listcomp>
    all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1188, in flatten_up_to
    helper(self, tree, subtrees)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1185, in helper
    helper(subspec, subtree, subtrees)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1141, in helper
    raise ValueError(
ValueError: Node type mismatch; expected <class 'tuple'>, but got <class 'dict'>.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/chzhu/infinitrain/test_exprot.py", line 25, in test_export_compiled_model_with_nested_dynamic_shapes
    ep = export(compiled_m, example_args, dynamic_shapes=dynamic_shapes, strict=True)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/__init__.py", line 311, in export
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/__init__.py", line 277, in export
    return _export(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 2255, in _export
    ep = _export_for_training(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 2071, in _export_for_training
    export_artifact = export_func(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1415, in _strict_export
    gm_torch_level = _export_to_torch_ir(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 785, in _export_to_torch_ir
    _check_dynamic_shapes(combined_args, dynamic_shapes)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 1031, in _check_dynamic_shapes
    _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs")
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 686, in _tree_map_with_path
    _compare(tree_spec, other_tree_spec, [])
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 677, in _compare
    _compare(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 652, in _compare
    raise_mismatch_error(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 634, in raise_mismatch_error
    raise UserError(
torch._dynamo.exc.UserError: Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs[0]` is a <class 'tuple'>, but `dynamic_shapes[0]` is a <class 'dict'>
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

Running test_export_compiled_model_with_kwargs_dynamic_shapes...
✗ test_export_compiled_model_with_kwargs_dynamic_shapes FAILED
Error: When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['kwargs'] of `inputs`, but here they are ['a1', 'a2']. Since here `inputs` is a list/tuple enclosing a single dict, maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
Traceback (most recent call last):
  File "/home/chzhu/infinitrain/test_exprot.py", line 62, in test_export_compiled_model_with_kwargs_dynamic_shapes
    ep = export(compiled_m, example_args, kwargs=example_kwargs, dynamic_shapes=dynamic_shapes, strict=True)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/__init__.py", line 311, in export
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/__init__.py", line 277, in export
    return _export(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 2255, in _export
    ep = _export_for_training(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 2071, in _export_for_training
    export_artifact = export_func(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1415, in _strict_export
    gm_torch_level = _export_to_torch_ir(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 785, in _export_to_torch_ir
    _check_dynamic_shapes(combined_args, dynamic_shapes)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 1007, in _check_dynamic_shapes
    raise UserError(
torch._dynamo.exc.UserError: When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['kwargs'] of `inputs`, but here they are ['a1', 'a2']. Since here `inputs` is a list/tuple enclosing a single dict, maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

======================================================================
```
## Torch Version
(reproducible nightly version)

## Other Behavior
The model can export regularly when we test without compiling the model
```python
import torch
from torch.export import export, Dim

def test_export_compiled_model_with_nested_dynamic_shapes():
   """Test exporting a compiled model with nested dict inputs and dynamic shapes."""
   print("Running test_export_compiled_model_with_nested_dynamic_shapes...")

   class M(torch.nn.Module):
       def forward(self, data_batch):
           return data_batch["a1"] + data_batch["a2"]

   m = M()
   example_args = ({
       "a1": torch.ones(3, 3),
       "a2": torch.ones(3, 3),
   },)
   dynamic_shapes = ({
       "a1": {0: Dim.DYNAMIC},
       "a2": {0: Dim.DYNAMIC},
   },)

   try:
       ep = export(m, example_args, dynamic_shapes=dynamic_shapes, strict=True)
       gm = ep.module()
       result_exported = gm(*example_args)
       result_compiled = m(*example_args)

       assert torch.allclose(result_exported, result_compiled), "Results don't match!"
       print("✓ test_export_compiled_model_with_nested_dynamic_shapes PASSED")
       return True
   except Exception as e:
       print(f"✗ test_export_compiled_model_with_nested_dynamic_shapes FAILED")
       print(f"Error: {e}")
       import traceback
       traceback.print_exc()
       return False

def test_export_compiled_model_with_kwargs_dynamic_shapes():
   """Test exporting a compiled model with kwargs and dynamic shapes."""
   print("\nRunning test_export_compiled_model_with_kwargs_dynamic_shapes...")

   class M(torch.nn.Module):
       def forward(self, a1, a2):
           return a1 + a2

   m = M()
   example_args = ()
   example_kwargs = {
       "a1": torch.ones(3, 3),
       "a2": torch.ones(3, 3),
   }
   dynamic_shapes = {
       "a1": {0: Dim.DYNAMIC},
       "a2": {0: Dim.DYNAMIC},
   }

   try:
       ep = export(m, example_args, kwargs=example_kwargs, dynamic_shapes=dynamic_shapes, strict=True)
       gm = ep.module()
       result_exported = gm(**example_kwargs)
       result_compiled = m(**example_kwargs)

       assert torch.allclose(result_exported, result_compiled), "Results don't match!"
       print("✓ test_export_compiled_model_with_kwargs_dynamic_shapes PASSED")
       return True
   except Exception as e:
       print(f"✗ test_export_compiled_model_with_kwargs_dynamic_shapes FAILED")
       print(f"Error: {e}")
       import traceback
       traceback.print_exc()
       return False

if __name__ == "__main__":
   print("Testing export of compiled models with dynamic shapes\n")
   print("=" * 70)

   results = []
   results.append(test_export_compiled_model_with_nested_dynamic_shapes())
   results.append(test_export_compiled_model_with_kwargs_dynamic_shapes())

   print("\n" + "=" * 70)
   print(f"\nResults: {sum(results)}/{len(results)} tests passed")

   if all(results):
       print("✓ All tests passed!")
   else:
       print("✗ Some tests failed")
       exit(1)

```
## Root Cause

This is because of a side effect of torch.compile(model). When the model is being compiled, the input signature will become (*args, **kwargs) automatically. In the above example, the `data_batch` will be added into `args` in combined_args [here](dc011d3203/torch/export/dynamic_shapes.py (L720)), and it will look like
```
{'args': ({'a1': tensor([[1., 1., 1.]... 1., 1.]]), 'a2': tensor([[1., 1., 1.]... 1., 1.]])},)}
```
Without the compiling, the combined args will look like
```
{'data_batch': {'a1': tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]), 'a2': tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])}}

```
Thus causing the mismatch when we use treemap to match the dynamic shape with the input argos

The error is also reproducible when we setup kwargs as example argos (see the 2nd test above)
## Fix
Proposed fix: In [_combine_args](dc011d3203/torch/export/dynamic_shapes.py (L720)) we explicitly flatten out the kwargs and args into combined args.
## Side Effects
There are 2 existing tests that assume this behavior and
1. add `args` explicitly to dynamic shapes
2. wrap args into nested format in dynamic_shape

I have modified those test to make args and dynamic_shapes to be in consistent format.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166358
Approved by: https://github.com/angelayi
2025-11-11 19:04:58 +00:00
bd99ae3315 [Docs] Add warning that torch.export.load uses pickle (#167557)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167557
Approved by: https://github.com/zhxchen17, https://github.com/angelayi
2025-11-11 18:47:14 +00:00
ce8672c24f Fix use of TORCH_CHECK in torch/csrc/stable (#167495)
Tested by above PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167495
Approved by: https://github.com/janeyx99
ghstack dependencies: #166579, #166694, #166695, #167362
2025-11-11 17:58:30 +00:00
402c465030 [ARM] Improve LLM performance & mem usage using int4-bf16 KleidiAI kernels (#158250)
Co-authored-by: Nikhil Gupta [nikhil.gupta2@arm.com](mailto:nikhil.gupta2@arm.com)

This PR enables the use of KleidiAI INT4 kernels that directly produce BF16 outputs within PyTorch to boost LLM prefill & decode performance

**This change improves decode throughput by ~15% & reduces memory required to inference the model by 50%**

### Benchmark Setup
```
Model: meta-llama/Llama-3.1-8B
Test Platform: Neoverse V2
```
### Detailed Results

| Metric                           | With `--compile`         | Without `--compile`      |
|----------------------------------|---------------------------|---------------------------|
| Quantization Scheme              | INT4 symmetric channelwise | INT4 symmetric channelwise |
| Input Precision                  | BF16                      | BF16                      |
| Number of Layers Quantized       | 32                        | 32                        |
| Average Compression Ratio        | 87.49%                    | 87.49%                    |
| Total Quantization Time (s)      | 9.62                      | 10.32                     |
| Compile Time (First) (s)         | 134.48                    | 1.69                      |
| Compile Time (Second) (s)        | 80.44                     | 1.60                      |
| Compile Time (Subsequent) (s)    | 0.19                      | 0.22                      |
| Prefill Tokens                   | 54                        | 54                        |
| Decoded Tokens                   | 33                        | 33                        |
| Prefill Time (s)                 | 0.19                      | 0.22                      |
| Decode Time (s)                  | 0.76                      | 1.38                      |
| E2E Generation Time (s)          | 0.95                      | 1.60                      |
| Prefill Throughput (tokens/s)    | 288.13                    | 249.91                    |
| Decode Throughput (tokens/s)     | 43.42                     | 23.83                     |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158250
Approved by: https://github.com/malfet, https://github.com/aditew01, https://github.com/fadara01

Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-11 17:50:22 +00:00
573a79fffa [OpenReg] Initialize device stream states for all devices in initOpenRegStreamsOnce (#167528)
Fixes #167527

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167528
Approved by: https://github.com/fffrog
2025-11-11 16:53:22 +00:00
4945180468 Add empty tensor check for _pad_packed_sequence (#167521)
That prevents null pointer dereference

Fixes https://github.com/pytorch/pytorch/issues/149622
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167521
Approved by: https://github.com/albanD
2025-11-11 16:46:13 +00:00
1df723e6f5 [inductor] Fix constant creation (#167398)
We ran into this issue when debugging inductor-lite. Calling `torch.tensor` within a fake mode (which is the case inside of inductor) will create a FakeTensor, which causes this FakeTensor to be used as a constant within inductor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167398
Approved by: https://github.com/eellison, https://github.com/BoyuanFeng
2025-11-11 16:30:46 +00:00
f9b81e23e4 [ROCm] Disable group gemm CK path when composable kernel (CK) is not enabled (#167403)
For ROCm builds without CK support, ensure use_fast_path is false so that the CK path is not triggered, since CK is currently not available in this configuration.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167403
Approved by: https://github.com/Skylion007, https://github.com/ScottTodd, https://github.com/jeffdaily
2025-11-11 16:15:51 +00:00
ffe6cc39c7 [inductor] Optimize cold compile time when cudagraphs-partition is enabled (#167132)
Summary: When cudagraphs-parittion is enabled, we have seen an increase of cold compile time in the vllm benchmark (see https://github.com/vllm-project/vllm/issues/27080). After some profiling, we found Triton compilation time increased the most. Further investigation reveals it was caused by duplicated Triton kernels not being shared among different partitions. This PR fixes the issue by reusing the Trition kernel source code cache at the top-level PythonWrapperCodegen.

In theory we could further reduce the compilation time by completely skipping compiling duplicated partitions. That can come as a furture improvement.

Some vllm benchmarking data,

```
VLLM_USE_STANDALONE_COMPILE=0 VLLM_DISABLE_COMPILE_CACHE=1 vllm bench latency -O.cudagraph_mode=PIECEWISE -O.use_inductor_graph_partition=True --model meta-llama/Meta-Llama-3.1-8
```
Before:
```
torch.compile takes 69.18 s in total
```
After:
```
torch.compile takes 26.81 s in total
```

As a refrence, this is the compile time when turning off inductor graph partition. Looks like we still have some gap to close.
```
VLLM_USE_STANDALONE_COMPILE=0 VLLM_DISABLE_COMPILE_CACHE=1 vllm bench latency -O.cudagraph_mode=PIECEWISE -O.use_inductor_graph_partition=False --model meta-llama/Meta-Llama-3.1-8B

torch.compile takes 19.41 s in total
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167132
Approved by: https://github.com/eellison
ghstack dependencies: #167131
2025-11-11 15:54:31 +00:00
db1f3f6901 [inductor] Only generate compile-time auto-tuning block in the main graph (#167131)
Summary: When cudagraphs partition and autotune_at_compile_time are enabled, currently each subgraph will generate its own auto-tuning code block and run them once by one. This PR improves it by only generating one auto-tuning code block at the main graph level and execute it once time to auto-tune all the kernels.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167131
Approved by: https://github.com/eellison
2025-11-11 15:54:31 +00:00
44 changed files with 799 additions and 509 deletions

View File

@ -1 +1 @@
bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7
1070cd530573098dc8375422a1f1918d3815af3f

View File

@ -96,7 +96,6 @@ function pip_build_and_install() {
python3 -m pip wheel \
--no-build-isolation \
--no-deps \
--no-use-pep517 \
-w "${wheel_dir}" \
"${build_target}"
fi

View File

@ -63,7 +63,7 @@ self-hosted-runner:
- linux.rocm.gpu.gfx942.1
- linux.rocm.gpu.gfx942.2
- linux.rocm.gpu.gfx942.4
- rocm-docker
- linux.rocm.gfx942.docker-cache
# Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors)
- macos-m1-stable
- macos-m1-14

View File

@ -1,55 +0,0 @@
name: docker-cache-mi300
on:
# run every 6 hours
schedule:
- cron: 0 0,6,12,18 * * *
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
docker-cache:
if: github.repository_owner == 'pytorch'
runs-on: rocm-docker
steps:
- name: Checkout PyTorch
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
with:
no-sudo: true
- name: configure aws credentials
id: aws_creds
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Login to Amazon ECR
id: login-ecr
continue-on-error: false
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
with:
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
push: false
- name: Pull docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Tar and upload to S3 bucket
run: |
sudo docker save -o ~/docker-data/pytorch/pytorch_docker_image.tar ${{ steps.calculate-docker-image.outputs.docker-image }}
sudo rclone copy -P --s3-upload-concurrency 64 --s3-chunk-size 200M --s3-upload-cutoff 300M ~/docker-data/pytorch/pytorch_docker_image.tar oci:pytorchbucket0002/pytorch_docker_image --progress

108
.github/workflows/docker-cache-rocm.yml vendored Normal file
View File

@ -0,0 +1,108 @@
name: docker-cache-rocm
on:
workflow_run:
workflows: [docker-builds]
# TODO: Uncomment before merging
#branches: [main, release]
types:
- completed
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
actions: read
jobs:
download-docker-builds-artifacts:
if: github.repository_owner == 'pytorch'
name: download-docker-builds-artifacts
runs-on: ubuntu-latest
outputs:
pytorch-linux-jammy-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}
pytorch-linux-noble-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}
pytorch-linux-jammy-rocm-n-py3-benchmarks: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}
steps:
- name: Download artifacts
uses: actions/download-artifact@v4.1.7
with:
run-id: ${{ github.event.workflow_run.id }}
path: ./docker-builds-artifacts
merge-multiple: true
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Process artifacts
id: process-artifacts
run: |
ls -R ./docker-builds-artifacts
cat ./docker-builds-artifacts/*txt >> "${GITHUB_OUTPUT}"
cat "${GITHUB_OUTPUT}"
docker-cache:
if: github.repository_owner == 'pytorch'
needs: download-docker-builds-artifacts
strategy:
fail-fast: false
matrix:
runner: [linux.rocm.gfx942.docker-cache]
docker-image: [
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
]
runs-on: "${{ matrix.runner }}"
steps:
- name: debug
run: |
JSON_STRINGIFIED="${{ toJSON(needs.download-docker-builds-artifacts.outputs) }}"
echo "Outputs of download-docker-builds-artifacts job: ${JSON_STRINGIFIED}"
- name: configure aws credentials
id: aws_creds
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Login to Amazon ECR
id: login-ecr
continue-on-error: false
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
- name: Generate ghrc.io tag
id: ghcr-io-tag
run: |
ecr_image="${{ matrix.docker-image }}"
ghcr_image="ghcr.io/pytorch/ci-image:${ecr_image##*:}"
echo "ghcr_image=${ghcr_image}" >> "$GITHUB_OUTPUT"
- name: Pull docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.ghcr-io-tag.outputs.ghcr_image }}
- name: Save as tarball
run: |
docker_image_tag=${{ matrix.docker-image }}
docker_image_tag="${docker_image_tag#*:}" # Remove everything before and including first ":"
docker_image_tag="${docker_image_tag%-*}" # Remove everything after and including last "-"
ref_name=${{ github.event.workflow_run.head_branch }}
if [[ $ref_name =~ "release/" ]]; then
ref_suffix="release"
elif [[ $ref_name == "main" ]]; then
ref_suffix="main"
else
# TODO: Remove below
ref_suffix="main"
# echo "Unexpected branch in ref_name: ${ref_name}" && exit 1
fi
docker tag ${{ steps.ghcr-io-tag.outputs.ghcr_image }} ${{ matrix.docker-image }}
# mv is atomic operation, so we use intermediate tar.tmp file to prevent read-write contention
docker save -o ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ${{ matrix.docker-image }}
mv ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ~/pytorch-data/docker/${docker_image_tag}_${ref_suffix}.tar

View File

@ -142,6 +142,7 @@ Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArra
std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor& _batch_sizes, bool batch_first, const Scalar& padding_value, int64_t total_length) {
auto batch_sizes_t = _batch_sizes.contiguous();
checkLongTensor(batch_sizes_t);
TORCH_CHECK(batch_sizes_t.numel() > 0, "batch_sizes can not be empty");
int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
int64_t max_batch_size = batch_sizes[0];

View File

@ -669,9 +669,12 @@ std::optional<c10::ScalarType> out_dtype) {
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
// On non CK system(w/ ROCm), make sure use_fast_path is false
#if defined(USE_ROCM_CK_GEMM)
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif //USE_ROCM_CK_GEMM
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
@ -680,7 +683,11 @@ std::optional<c10::ScalarType> out_dtype) {
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
#if defined(USE_ROCM_CK_GEMM)
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#else
TORCH_WARN("ROCm: Group Gemm through CK not selected.");
#endif //USE_ROCM_CK_GEMM
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);

View File

@ -47,6 +47,7 @@
#include <c10/macros/Macros.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/distance.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
#include <thrust/gather.h>

View File

@ -50,7 +50,7 @@ nfnet_l0,pass,7
repvgg_a2,fail_accuracy,7
repvgg_a2,pass,7

1 name accuracy graph_breaks
50
51
52
53
54
55
56

View File

@ -10,7 +10,7 @@ tp2_dir="$top_dir/third_party"
pip install ninja
# Install onnx
pip install --no-use-pep517 -e "$tp2_dir/onnx"
pip install -e "$tp2_dir/onnx"
# Install caffe2 and pytorch
pip install -r "$top_dir/caffe2/requirements.txt"

View File

@ -140,6 +140,11 @@ static void initDeviceStreamState(DeviceIndex device_index) {
static void initOpenRegStreamsOnce() {
c10::call_once(init_flag, initGlobalStreamState);
for (const auto i : c10::irange(num_devices)) {
c10::call_once(
device_flags[i], initDeviceStreamState, static_cast<DeviceIndex>(i));
}
if (current_streams) {
return;
}
@ -202,8 +207,6 @@ OpenRegStream getStreamFromPool(const int priority, DeviceIndex device_index) {
if (device_index == -1) {
device_index = current_device();
}
c10::call_once(
device_flags[device_index], initDeviceStreamState, device_index);
auto pri_idx =
std::clamp(priority, 0, max_compile_time_stream_priorities - 1);
const auto idx = get_idx(priority_counters[device_index][pri_idx]);

View File

@ -180,6 +180,47 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
del model
del optim
def _test_tracker_multihandler_hook(self):
"""Should run without KeyError."""
class TestModule(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.norm1 = nn.RMSNorm(dim)
self.output1 = nn.Linear(dim, dim)
self.norm2 = nn.RMSNorm(dim)
self.output2 = nn.Linear(dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.norm1(x)
x = self.output1(x)
x = self.norm2(x)
x = self.output2(x)
return x
gc.collect()
torch.manual_seed(42)
dev = torch.device(torch.accelerator.current_device_index())
with torch.device(dev):
model = TestModule(128)
mesh = init_device_mesh(dev.type, (self.world_size,))
fully_shard([model.norm1, model.output1], mesh=mesh)
fully_shard([model.norm2, model.output2], mesh=mesh)
fully_shard(model, mesh=mesh)
fmt = FSDPMemTracker(model)
with fmt:
inp = torch.randn(16, 128, device=dev)
y = model(inp)
loss = y.sum()
loss.backward()
del inp
del model
class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
@property

View File

@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import unittest
import torch
import torch.distributed as dist
@ -371,6 +372,7 @@ class DTensorExportTest(TestCase):
# aot_export_joint_with_descriptors on strict-exported exported_program.module()
# is producing a joint graph with backward region missing
@unittest.expectedFailure
def test_strict_export_parallelize_module_with_dtensor_input(self):
self._run_test(strict_export_and_aot_export_joint_with_descriptors)

View File

@ -15,7 +15,7 @@ import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import default_partition, min_cut_rematerialization_partition
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
@ -24,7 +24,7 @@ from torch._dynamo.testing import (
)
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
@ -281,14 +281,7 @@ class ActivationCheckpointingViaTagsTests(
run(export_compiler)
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function(self, device, partition_fn):
def test_tags_function(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -304,22 +297,11 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function_via_global_checkpoint(self, device, partition_fn):
def test_tags_function_via_global_checkpoint(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -334,28 +316,17 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function_with_kwargs(self, device, partition_fn):
def test_tags_function_with_kwargs(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn, torch.sin(x), y, use_reentrant=False
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
)
x = torch.randn(4, 4, device=device, requires_grad=True)
@ -365,22 +336,11 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_sequential_layers(self, device, partition_fn):
def test_tags_sequential_layers(self, device):
def gn(x):
x = x.cos()
for _ in range(3):
@ -401,22 +361,11 @@ class ActivationCheckpointingViaTagsTests(
freqs=[2, 18],
ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
) # mm recomputed in the bwd
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x)
@requires_cuda_and_triton
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_multiple_checkpoints(self, device, partition_fn):
def test_tags_multiple_checkpoints(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -434,22 +383,11 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=6, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_module(self, device, partition_fn):
def test_tags_module(self, device):
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@ -473,22 +411,11 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x)
@requires_cuda_and_triton
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_decomps(self, device, partition_fn):
def test_tags_decomps(self, device):
# Ensures that tags are passed on through decompositions as well
class MockModule(torch.nn.Module):
def __init__(self) -> None:
@ -516,7 +443,6 @@ class ActivationCheckpointingViaTagsTests(
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
decompositions=lambda: import_module(
"torch._inductor.compile_fx"
).select_decomp_table(),
@ -776,14 +702,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn):
def test_compile_selective_checkpoint_must_recompute(self, device):
def context_fn_must_recompute_mm():
must_recompute_list = [
torch.ops.aten.mm.default,
@ -804,9 +723,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
),
)
def _test(context_fn, bw_compiler, partition_fn):
def _test(context_fn, bw_compiler):
def gn(x):
return torch.cos(torch.sin(torch.matmul(x, x) @ x))
return torch.sigmoid(torch.matmul(x, x))
def fn(x):
return torch.utils.checkpoint.checkpoint(
@ -820,14 +739,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
fw_compiler = functools.partial(
count_ops,
freq=2,
freq=1,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x)
@ -835,19 +754,17 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
context_fn=context_fn_must_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6)
freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3)
op=torch.ops.aten.mm.default,
),
partition_fn=partition_fn,
)
_test(
context_fn=context_fn_no_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=4, # 2 bwd mm ops per fwd matmul
freq=2, # 2 bwd mm ops per fwd matmul
op=torch.ops.aten.mm.default,
),
partition_fn=partition_fn,
)
def test_sac_with_partial_context_fn(self):
@ -884,16 +801,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_not_recompute_gemm(
self, device, partition_fn
):
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -933,22 +841,15 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
self, device, partition_fn
self, device
):
def selective_checkpointing_context_fn():
no_recompute_list = [
@ -988,7 +889,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
disable_functionalization=True,
)
self._validate(fn, backend, x, y)
@ -996,14 +897,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn):
def test_compile_selective_checkpoint_triton_kernel(self, device):
# Copy of the above test, but make sure that having a triton kernel in the
# region does not error.
def add_one(x):
@ -1063,21 +957,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn):
def test_compile_selective_checkpoint_tensor_subclass(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1120,21 +1007,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn):
def test_compile_selective_checkpoint_custom_rule(self, device):
def _get_custom_policy(meta):
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1192,21 +1072,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn):
def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
def selective_checkpointing_context_fn(no_recompute_list):
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
@ -1245,21 +1118,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn):
def test_compile_selective_checkpoint_outplace_op(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1297,21 +1163,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_list_ops(self, device, partition_fn):
def test_compile_selective_checkpoint_list_ops(self, device):
def selective_checkpointing_context_fn():
# recompute everything
no_recompute_list = []
@ -1347,7 +1206,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@ -1358,14 +1217,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
"requires TorchDispatchMode + torch.compile work to complete"
)
@requires_cuda_and_triton
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn):
def test_compile_selective_checkpoint_inplace_op(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1405,7 +1257,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@ -1413,14 +1265,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@torch._inductor.config.patch(fallback_random=True)
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_random_op(self, device, partition_fn):
def test_compile_selective_checkpoint_random_op(self, device):
for preserve_rng_state in [True, False]:
def selective_checkpointing_context_fn():
@ -1467,7 +1312,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
# NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
@ -1479,14 +1324,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_invalid_context(self, partition_fn):
def test_compile_selective_checkpoint_invalid_context(self):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y)) * y
@ -1515,7 +1353,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
with self.assertRaisesRegex(
Exception, "must generate a tuple of two `TorchDispatchMode`s"
@ -1524,14 +1362,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_parametrization(self, partition_fn):
def test_compile_selective_checkpoint_parametrization(self):
def sac_policy():
def _recomp_policy():
def _custom_policy(ctx, func, *args, **kwargs):
@ -1594,9 +1425,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
bw_compiler = functools.partial(
count_ops,
freqs=[
# 1 from mul recompute, 1 from mul backward
# w/o CSE, we have one extra mul
3 if partition_fn is default_partition else 2,
2, # 1 from mul recompute, 1 from mul backward
1,
],
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
@ -1605,7 +1434,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
model = MLPModule()

View File

@ -2363,6 +2363,34 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(output, expected))
assert cnt.frame_count == 1
@unittest.skipIf(sys.version_info < (3, 13), "math.fma introduced in python 3.13")
def test_math_fma(self):
def fma_func(a, b, c):
return math.fma(a, b, c)
# Test with scalar constants (constant folding path)
cnt = torch._dynamo.testing.CompileCounter()
cfma_scalars = torch._dynamo.optimize_assert(cnt)(fma_func)
assert cnt.frame_count == 0
expected = fma_func(2.0, 3.0, 4.0)
output = cfma_scalars(2.0, 3.0, 4.0)
self.assertEqual(output, expected)
assert cnt.frame_count == 0
# Test with tensors (Inductor path)
cnt2 = torch._dynamo.testing.CompileCounter()
cfma_tensors = torch._dynamo.optimize_assert(cnt2)(fma_func)
assert cnt2.frame_count == 0
x = torch.tensor(2.0)
y = torch.tensor(3.0)
z = torch.tensor(4.0)
expected_tensors = x * y + z
output_tensors = cfma_tensors(x, y, z)
torch.testing.assert_close(output_tensors, expected_tensors)
assert cnt2.frame_count == 1
@make_test
def test_numpy_meshgrid(x, y):
r1, r2 = np.meshgrid(x.numpy(), y.numpy())

View File

@ -335,6 +335,59 @@ class <lambda>(torch.nn.Module):
""",
)
@requires_cuda
@requires_multigpu()
def test_new_event_api(self) -> None:
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
from torch._dynamo.variables.streams import new_event
def event_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
e0_ind = new_event()
with torch.Stream(device="cuda:1"):
get_external_object_by_index(e0_ind).record()
e1_ind = new_event()
self.assertNotEqual(e0_ind, e1_ind)
self.assertNotEqual(
get_external_object_by_index(e0_ind),
get_external_object_by_index(e1_ind),
)
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
gm.graph.call_function(
get_external_object_by_index, args=(1,), kwargs={}
)
return gm
@torch.compile(backend=event_generation_backend)
def fn(x):
return x + 1
fn(torch.ones(2, 2, device="cuda:0"))
@requires_cuda
def test_new_stream_api(self) -> None:
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
from torch._dynamo.variables.streams import new_stream
def stream_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
s0_ind = new_stream()
s1_ind = new_stream()
self.assertNotEqual(s0_ind, s1_ind)
self.assertNotEqual(
get_external_object_by_index(s0_ind),
get_external_object_by_index(s1_ind),
)
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
gm.graph.call_function(
get_external_object_by_index, args=(1,), kwargs={}
)
return gm
@torch.compile(backend=stream_generation_backend)
def fn(x):
return x + 1
fn(torch.ones(2, 2, device="cuda:0"))
@requires_cuda
def test_stream_with_mutation(self):
def fn(x, y):
@ -523,6 +576,23 @@ class <lambda>(torch.nn.Module):
torch.accelerator.set_stream(original_stream)
reset_user_object_tracking()
@requires_cuda
def test_run_opcheck_wait_record_stream(self):
from torch._dynamo.variables.streams import wait_stream
from torch.library import opcheck
s0 = torch.Stream()
s1 = torch.Stream()
s2 = torch.Stream()
store_user_object_weakrefs(s0, s1, s2)
sample_inputs = [
(0, 1),
(2, 0),
]
for args in sample_inputs:
opcheck(wait_stream, args)
@requires_cuda
def test_inductor_lowering(self):
with patch("torch._inductor.config.implicit_fallbacks", False):

View File

@ -331,7 +331,12 @@ class TestDynamismExpression(TestCase):
return torch.ops.aten.slice.Tensor(*args)
inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
dynamic_shapes = (({0: Dim("dim")}, None, None, None),)
dynamic_shapes = (
{0: Dim("dim")},
None,
None,
None,
)
torch.export.export(
Slice(),
inp,
@ -5533,21 +5538,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
w = Wrapped()
if is_retracebility_test(self._testMethodName):
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Detected mismatch between the structure of `inputs` and `dynamic_shapes`"
": `inputs` has 2 elements, but `dynamic_shapes` has 1 elements",
):
export(w, args, dynamic_shapes={"args": ({0: batch}, {0: batch})})
else:
compiled = export(
w, args, dynamic_shapes={"args": ({0: batch}, {0: batch})}
)
expected = w(*args)
mod = compiled.module()
got = mod(*args)
self.assertTrue(torch.allclose(expected, got))
compiled = export(w, args, dynamic_shapes=({0: batch}, {0: batch}))
expected = w(*args)
mod = compiled.module()
got = mod(*args)
self.assertTrue(torch.allclose(expected, got))
def test_dynamic_shapes_builder_basic(self):
class M(torch.nn.Module):
@ -17504,6 +17499,105 @@ def forward(self, x):
exported_param_names = [name for name, _ in gm.named_parameters()]
self.assertEqual(original_param_names, exported_param_names)
def test_export_compiled_model_with_nested_dynamic_shapes(self):
class M(torch.nn.Module):
def forward(self, data_batch):
return data_batch["a1"] + data_batch["a2"]
m = M()
compiled_m = torch.compile(m)
example_args = (
{
"a1": torch.ones(3, 3),
"a2": torch.ones(3, 3),
},
)
dynamic_shapes = (
{
"a1": {0: Dim.DYNAMIC},
"a2": {0: Dim.DYNAMIC},
},
)
ep = export(
compiled_m, example_args, dynamic_shapes=dynamic_shapes, strict=True
)
gm = ep.module()
self.assertEqual(gm(*example_args), compiled_m(*example_args))
def test_export_model_with_nested_dynamic_shapes(self):
class M(torch.nn.Module):
def forward(self, data_batch):
return data_batch["a1"] + data_batch["a2"]
m = M()
example_args = (
{
"a1": torch.ones(3, 3),
"a2": torch.ones(3, 3),
},
)
B = torch.export.Dim("batch", min=1, max=65536)
dynamic_shapes = (
{
"a1": {0: B},
"a2": {0: B},
},
)
ep = export(m, example_args, dynamic_shapes=dynamic_shapes, strict=True)
gm = ep.module()
self.assertEqual(gm(*example_args), m(*example_args))
def test_export_compiled_model_with_kwargs_dynamic_shapes(self):
class M(torch.nn.Module):
def forward(self, a1, a2):
return a1 + a2
m = M()
compiled_m = torch.compile(m)
example_args = ()
example_kwargs = {
"a1": torch.ones(3, 3),
"a2": torch.ones(3, 3),
}
dynamic_shapes = {
"a1": {0: Dim.DYNAMIC},
"a2": {0: Dim.DYNAMIC},
}
ep = export(
compiled_m,
example_args,
kwargs=example_kwargs,
dynamic_shapes=dynamic_shapes,
strict=True,
)
gm = ep.module()
self.assertEqual(gm(**example_kwargs), compiled_m(**example_kwargs))
def test_export_model_with_kwargs_dynamic_shapes(self):
class M(torch.nn.Module):
def forward(self, a1, a2):
return a1 + a2
m = M()
example_args = ()
example_kwargs = {
"a1": torch.ones(3, 3),
"a2": torch.ones(3, 3),
}
dynamic_shapes = {
"a1": {0: Dim.DYNAMIC},
"a2": {0: Dim.DYNAMIC},
}
ep = export(
m,
example_args,
kwargs=example_kwargs,
dynamic_shapes=dynamic_shapes,
strict=True,
)
gm = ep.module()
self.assertEqual(gm(**example_kwargs), m(**example_kwargs))
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestExportCustomClass(TorchTestCase):

View File

@ -2640,7 +2640,7 @@ def forward(self, primals_1, primals_2):
return grad_output * x, grad_output * x
def f(a, b):
return FwBwMutation.apply(a, b).sin_().clone()
return FwBwMutation.apply(a, b)
inps = [
torch.ones(3, 3, requires_grad=True),
@ -2689,22 +2689,17 @@ def forward(self, primals_1, primals_2):
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
_foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
clone = torch.ops.aten.clone.default(mul)
sin_ = torch.ops.aten.sin_.default(mul); mul = None
clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None
return (clone_1, add, clone)""",
return (mul, add)""",
)
# important bit: there is 1 mutation in the bw
self.assertExpectedInline(
bw_graph[0].code.strip(),
"""\
def forward(self, add, clone, tangents_1):
cos = torch.ops.aten.cos.default(clone); clone = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
def forward(self, add, tangents_1):
_foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None
return (mul_2, None)""",
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
return (mul_1, None)""",
)
def test_fw_bw_mutation_no_functionalization2(self):

View File

@ -927,8 +927,8 @@ class GraphModule(torch.nn.Module):
op="call_function", target=torch.ops.aten.mm.default
)
self.assertEqual(len(mm_nodes), 4)
self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward")
self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward")
self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)

View File

@ -4101,6 +4101,53 @@ if HAS_CUDA_AND_TRITON:
compiled_out = compiled_foo(x)
self.assertEqual(eager_out, compiled_out)
# Use autotune_at_compile_time=True to test standalone_compile
@parametrize("autotune_at_compile_time", [True, False])
@config.patch("graph_partition", True)
def test_graph_partition_kernel_reuse(self, autotune_at_compile_time):
def foo(x):
# partition 1
x1 = x @ x
y1 = x1 + 1
z_cpu = y1.cpu() + 1
# partition 2
# partition 2 should reuse the fused triton kernel generated
# in partition 1
x2 = z_cpu.to("cuda") @ z_cpu.to("cuda")
y2 = x2 + 1
return y1, y2
with config.patch(
"triton.autotune_at_compile_time", autotune_at_compile_time
):
compiled_foo = torch.compile(foo)
x = torch.randn((20, 20), device="cuda")
eager_out = foo(x)
compiled_out, code = run_and_get_code(compiled_foo, x)
self.assertEqual(eager_out, compiled_out)
if autotune_at_compile_time:
# auto-tuning block should only appear once. We generate auto-tuning code
# for all the kernels no matter if they are defined in the main graph or
# subgraph, to avoid the overhead of executing multiple auto-tuning code blocks.
FileCheck().check_count(
"Compile-time auto-tuning block", 1, exactly=True
).run(code[0])
# triton_poi_fused_add_ should appear twice, first in the auto-tuning block,
# and then in the main code block
FileCheck().check_count(
"def triton_poi_fused_add_", 2, exactly=True
).run(code[0])
# cpu kernel definition should only appence once, not in the auto-tuning block
FileCheck().check_count(
"cpp_fused__to_copy_add_1 = ", 1, exactly=True
).run(code[0])
else:
# triton_poi_fused_add_ should appear once, because of kernel reuse
FileCheck().check_count(
"def triton_poi_fused_add_", 1, exactly=True
).run(code[0])
def test_meta_tensor(self):
def foobar(x, y):
return x * 2, y * 3

View File

@ -4,8 +4,9 @@ from functools import partial
from unittest import skipIf
import torch
from torch._inductor import config
from torch._inductor.ir import Pointwise
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.lowering import make_fallback, make_pointwise, register_lowering
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.virtualized import ops
from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu
@ -237,6 +238,17 @@ class TestCustomLowering(InductorTestCase):
out2 = fn_opt(a, b)
self.assertEqual(out1, out2)
@config.patch(joint_graph_constant_folding=False)
def test_constant_creation(self):
class M(torch.nn.Module):
def forward(self, x):
return x + torch.tensor(1)
make_fallback(torch.ops.aten.lift_fresh_copy.default)
self.assertTrue(
torch.allclose(torch.compile(M())(torch.ones(3)), torch.ones(3) + 1)
)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

View File

@ -492,6 +492,36 @@ class PackedSequenceTest(TestCase):
torch.randn([0, 1, 10]), torch.randn([11, 14, 14, 2]), True
)
def test_empty_packed_sequence(self):
"""
Regression test for https://github.com/pytorch/pytorch/issues/149622
Tests that pad_packed_sequence and unpack_sequence handle empty tensors
without segmentation fault (CVE-2025-2998, CVE-2025-2999)
"""
# Test case 1: pad_packed_sequence with empty tensors
# Previously caused segmentation fault
empty_data = torch.randn(0, 5)
empty_batch_sizes = torch.tensor([], dtype=torch.int64)
empty_packed = rnn_utils.PackedSequence(
empty_data, empty_batch_sizes, None, None
)
# Should not crash - either return empty result or raise informative error
with self.assertRaises(RuntimeError):
rnn_utils.pad_packed_sequence(empty_packed, batch_first=True)
# Test case 2: unpack_sequence with empty tensors
# Previously caused segmentation fault
empty_data = torch.tensor([])
empty_batch_sizes = torch.tensor([], dtype=torch.int64)
packed = rnn_utils.PackedSequence(
data=empty_data, batch_sizes=empty_batch_sizes
)
# Should not crash - either return empty list or raise informative error
with self.assertRaises(RuntimeError):
rnn_utils.unpack_sequence(packed)
if __name__ == "__main__":
run_tests()

View File

@ -2320,6 +2320,8 @@ if sys.version_info >= (3, 11):
torch_c_binding_in_graph_functions["math.exp2"] = TorchInGraphFunctionVariable
torch_c_binding_in_graph_functions["math.cbrt"] = TorchInGraphFunctionVariable
if sys.version_info >= (3, 13):
torch_c_binding_in_graph_functions["math.fma"] = TorchInGraphFunctionVariable
# In graph functions (including constant folding) that are not C bindings
# NOTE: [Cacheability of in-graph torch functions]

View File

@ -10,7 +10,10 @@ from torch.fx import has_side_effect, Proxy
from .. import graph_break_hints
from ..bytecode_transformation import create_call_function
from ..exc import TYPE_CHECKING, unimplemented
from ..graph_bytecode_inputs import get_external_object_by_index
from ..graph_bytecode_inputs import (
get_external_object_by_index,
register_graph_created_object,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import FxTracebackAnnotateVariable
@ -28,6 +31,26 @@ from torch._library.custom_ops import custom_op
Tensor = torch.Tensor
def new_event(*args: Any, **kwargs: Any) -> int:
event = torch.Event(*args, **kwargs)
return register_graph_created_object(
event,
EventVariable.make_construct_in_graph_event_fn(
TupleVariable([]), ConstDictVariable({})
),
)
def new_stream(*args: tuple[Any], **kwargs: Any) -> int:
stream = torch.Stream(*args, **kwargs) # type: ignore[no-matching-overload,call-overload]
return register_graph_created_object(
stream,
StreamVariable.make_construct_in_graph_stream_fn(
TupleVariable([]), ConstDictVariable({})
),
)
def _get_stream_by_index(index: int) -> torch.Stream:
stream = get_external_object_by_index(index)
assert isinstance(stream, torch.Stream), (
@ -115,6 +138,24 @@ def _(
has_side_effect(torch.ops.streams.wait_event.default)
@custom_op("streams::wait_stream", mutates_args=())
def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None:
waiting = _get_stream_by_index(waiting_stream_index)
waited_on = _get_stream_by_index(waited_on_stream_index)
waiting.wait_stream(waited_on)
@wait_stream.register_fake
def _(
event_index: int,
stream_index: int,
) -> None:
pass
has_side_effect(torch.ops.streams.wait_stream.default)
class SymbolicStreamState:
"""Track the currently entered stream if any"""

View File

@ -603,6 +603,21 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
VariableTracker.build(tx, polyfills.radians), args, kwargs
)
if hasattr(math, "fma"): # Python 3.13+
@register(math.fma)
def handle_fma(self, tx: "InstructionTranslator", *args, **kwargs):
if len(args) != 3 or kwargs:
return None
if all(isinstance(arg, variables.TensorVariable) for arg in args):
x, y, z = args
addcmul_fn = TorchInGraphFunctionVariable(torch.addcmul)
return addcmul_fn.call_function(tx, [z, x, y], {})
# Use math.fma if constants
return None
@register(torch.is_inference_mode_enabled)
def handle_is_inference_mode_enabled(self, tx: "InstructionTranslator"):
unimplemented(

View File

@ -27,7 +27,6 @@ from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
_proxy_tensor_disable_update_tensor_tracker,
get_proxy_mode,
maybe_disable_thunkify,
maybe_enable_thunkify,
)
@ -296,10 +295,6 @@ def create_joint(
(outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
fn, primals
)
mode = get_proxy_mode()
assert mode is not None
for node in mode.tracer.graph.nodes:
node.meta["partitioner_tag"] = "is_forward"
# TODO: I think this hook can also be eliminated now
if joint_fn_handle and joint_fn_handle.post_forward:

View File

@ -51,7 +51,6 @@ from ._activation_checkpointing.knapsack import (
)
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
from ._aot_autograd.functional_utils import assert_functional_graph
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
@ -298,10 +297,6 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_backward"
def _has_tag_is_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_forward"
def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
@ -1026,95 +1021,105 @@ def default_partition(
Returns:
Returns the generated forward and backward Fx graph modules.
"""
# Respect the original placement of ops rather than rely on dataflow.
forward_nodes = []
last_node = None
for node in joint_module.graph.nodes:
if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
last_node = node
assert last_node is not None
for node in joint_module.graph.nodes:
if not _is_tangent(node):
forward_nodes.append(node)
if node is last_node:
break
if has_recomputable_ops(joint_module):
return min_cut_rematerialization_partition(
joint_module,
_joint_inputs,
num_fwd_outputs=num_fwd_outputs,
static_lifetime_input_indices=static_lifetime_input_indices,
)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
forward_node_names = OrderedSet(
node.name for node in forward_nodes if node.op != "output"
node.name for node in forward_only_graph.nodes if node.op != "output"
)
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
assert_functional_graph(joint_module.graph)
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)
if static_lifetime_input_indices is None:
static_lifetime_input_indices = []
node_info = classify_nodes(
joint_module, static_lifetime_input_indices, num_fwd_outputs
)
order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
saved_values = []
saved_sym_nodes = []
def is_mutated_later_in_fw(node):
if _has_tag_is_backward(node):
return False
tensor_arg_aliases = [
x
for x in node.args
if isinstance(x, fx.Node)
and "val" in x.meta
and isinstance(x.meta["val"], torch.Tensor)
]
while len(tensor_arg_aliases) > 0:
a = tensor_arg_aliases.pop()
for u in a.users:
if not isinstance(u.target, torch._ops.OpOverload):
continue
# If we witness a mutation on our node later, and that mutation is not "must be in backward",
# then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
if (
# one of the args was mutated
u.target._schema.is_mutable
# and the mutation happens "later"
and order[u] > order[node]
# and the mutation happened during the forward
and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
):
for idx, alias_info in enumerate(u.target._schema.arguments):
if alias_info.is_write and u.args[idx] is a:
return True
elif u.target.is_view:
tensor_arg_aliases.append(u)
return False
for node in joint_module.graph.nodes:
if node.name not in forward_node_names:
# if a node isn't "required" to be in the forward, but any of its arguments
# are later mutated in the forward, then it must have been run in the forward
# (if not, and the node's arg was saved for backward, we would have mutated a saved value)
# NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
if is_mutated_later_in_fw(node):
saved_values.append(node)
continue
if is_sym_node(node):
# Symints must be kept separate from tensors so that PythonFunction only calls
# save_for_backward on tensors and stashes symints in autograd .ctx
saved_sym_nodes.append(node)
continue
if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
saved_values.append(node)
continue
if node.is_impure(impure_random=False) and node.op not in (
"placeholder",
"output",
):
# See is_impure in torch/fx/node.py
assert not graph_has_recomputable_ops, (
"Trying to apply AC on a graph with impure op",
node,
node.target,
)
saved_values.append(node)
continue
backward_usages = [n for n in node.users if n.name not in forward_node_names]
if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
continue
if (
elif (
"tensor_meta" not in node.meta
and node.op == "call_function"
and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
):
assert all(user.target == operator.getitem for user in node.users)
continue
if not must_recompute(node):
saved_values.append(node)
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
users = node.users
assert all(user.target is operator.getitem for user in users)
saved_values.extend(users)
else:
backward_usages = [
n for n in node.users if n.name not in forward_node_names
]
if "tensor_meta" in node.meta and all(
is_sym_node(n) for n in backward_usages
):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
else:
saved_values.append(node)
saved_values = list(dict.fromkeys(saved_values).keys())
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
if config._sync_decision_cross_ranks:
saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values)
if static_lifetime_input_nodes is None:
static_lifetime_input_nodes = node_info.static_lifetime_input_nodes
fw_module, bw_module = _extract_fwd_bwd_modules(
return _extract_fwd_bwd_modules(
joint_module,
saved_values,
saved_sym_nodes=saved_sym_nodes,
@ -1122,24 +1127,6 @@ def default_partition(
static_lifetime_input_nodes=static_lifetime_input_nodes,
)
if graph_has_recomputable_ops:
if graph_has_recomputable_rng_ops:
fw_module, bw_module = functionalize_rng_ops(
joint_module, fw_module, bw_module, len(saved_sym_nodes)
)
bw_module = reordering_to_mimic_autograd_engine(bw_module)
# raise all getitem ops to as early as possible
# this is helpful for memory, especially in the case of aot_eager backend
fw_module = raise_getitems(fw_module)
bw_module = raise_getitems(bw_module)
fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
if len(node_info.required_bw_nodes) > 0:
bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)
return fw_module, bw_module
INT_INF = int(1e6)
@ -1634,9 +1621,7 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
break
def cleanup_recompute_tags(
joint_module: fx.GraphModule, *, is_default_partition: bool
) -> fx.GraphModule:
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
"""
If there are two consecutive checkpointed blocks with no operator in
between, we would still want to stash the tensor at the boundary of
@ -1673,16 +1658,6 @@ def cleanup_recompute_tags(
# Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
# in forward graph outputs. With this, we can break the above circular dependency.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
elif (
"ac_graph_id" not in node.meta
and any(must_recompute(user) for user in node.users)
and is_default_partition
):
# This node is not part of the AC region and a user is marked as recompute.
# This means it's an input to the AC region and we should save it.
# For ease of landing, gate this to default partitioner only, but we should think
# about flipping the switch in general as well.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
return joint_module
@ -2790,59 +2765,6 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
return module
def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
elif _must_be_in_backward(node):
required_bw_nodes.add(node)
if node in required_bw_nodes:
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
)
static_lifetime_input_nodes = OrderedSet(
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
if node in required_fw_nodes:
fw_order[node] = fw_cnt
fw_cnt += 1
return NodeInfo(
inputs,
required_fw_nodes,
required_bw_nodes,
unclaimed_nodes,
fw_order,
static_lifetime_input_nodes,
)
def min_cut_rematerialization_partition(
joint_module: fx.GraphModule,
_joint_inputs,
@ -2891,16 +2813,68 @@ def min_cut_rematerialization_partition(
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False)
joint_module = cleanup_recompute_tags(joint_module)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)
def classify_nodes(joint_module, static_lifetime_input_indices):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
elif _must_be_in_backward(node):
required_bw_nodes.add(node)
if node in required_bw_nodes:
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(
filter(_is_fwd_seed_offset, joint_module.graph.nodes)
)
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
)
static_lifetime_input_nodes = OrderedSet(
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
if node in required_fw_nodes:
fw_order[node] = fw_cnt
fw_cnt += 1
return NodeInfo(
inputs,
required_fw_nodes,
required_bw_nodes,
unclaimed_nodes,
fw_order,
static_lifetime_input_nodes,
)
if static_lifetime_input_indices is None:
static_lifetime_input_indices = []
node_info = classify_nodes(
joint_module, static_lifetime_input_indices, num_fwd_outputs
)
node_info = classify_nodes(joint_module, static_lifetime_input_indices)
# networkx blows up on graphs with no required backward nodes
# Since there's nothing to partition anyway, and the default partitioner can "handle"

View File

@ -627,7 +627,7 @@ class ComboKernel(Kernel):
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
filename=__file__,
num_warps={self.num_warps},
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)

View File

@ -2259,7 +2259,7 @@ class PythonWrapperCodegen(CodeGen):
gpu: bool = True,
cpp_definition: Optional[str] = None,
):
if config.triton.autotune_at_compile_time:
if config.triton.autotune_at_compile_time and gpu:
body = self._format_kernel_definition(
kernel_name, kernel_body, metadata=metadata
)
@ -3745,6 +3745,13 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
super().__init__()
root = self.get_root_graph()
# Only generate auto-tuning block in the main graph
self.kernel_autotune_defs = root.kernel_autotune_defs
self.kernel_autotune_calls = root.kernel_autotune_calls
# Only store kernel src to name mapping in the main graph
self.src_to_kernel = root.src_to_kernel
def set_launcher_fn_name(self) -> None:
# This sets up the name of the function containing the launcher code of
# the subgraph.
@ -3837,3 +3844,16 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
# V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
# )
self.parent_wrapper.write_get_raw_stream_header_once()
@cache_on_self
def get_root_graph(self) -> PythonWrapperCodegen:
root: PythonWrapperCodegen | SubgraphPythonWrapperCodegen = self
while isinstance(root, SubgraphPythonWrapperCodegen):
root = root.parent_wrapper
assert isinstance(root, PythonWrapperCodegen)
return root
def generate_and_run_autotune_block(self):
# Only execute auto-tuning block in the main graph
pass

View File

@ -64,6 +64,7 @@ from torch.fx.experimental.symbolic_shapes import (
)
from torch.fx.node import Node
from torch.utils._ordered_set import OrderedSet
from torch.utils._python_dispatch import _disable_current_modes
from torch.utils._sympy.functions import CleanDiv, FloorDiv, Mod, ModularIndexing
from torch.utils._sympy.symbol import SymT
@ -6135,9 +6136,12 @@ class ExternKernel(InputsKernel):
if isinstance(x, (Expr, sympy.logic.boolalg.Boolean, int)):
return ShapeAsConstantBuffer(expr=x)
if isinstance(x, Constant):
return V.graph.add_tensor_constant(
torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
)
# We need to unset fake mode, or else the torch.tensor() call will
# turn into a FakeTensor
with _disable_current_modes():
return V.graph.add_tensor_constant(
torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
)
if isinstance(x, ConstantBuffer):
return x
if isinstance(x, TensorBox):

View File

@ -3607,24 +3607,13 @@ def user_autotune(
)
def foreach(triton_meta, filename=None, inductor_meta=None):
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
configs = []
# Naive autotuning path for num_warps
if not (
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
):
configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
for warps in [1, 2, 4, 8]:
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
return cached_autotune(
None,
configs,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,

View File

@ -702,7 +702,7 @@ def exp2(a):
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("a,"),
type_promoting_args=("a",),
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:

View File

@ -1,6 +1,5 @@
#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/csrc/stable/device_struct.h>
@ -120,7 +119,7 @@ struct FromImpl<ScalarType> {
case ScalarType::UInt64:
return from(aoti_torch_dtype_uint64());
default:
TORCH_CHECK(
STD_TORCH_CHECK(
false,
"Not yet supported ScalarType, please file an issue describing your use case.");
}
@ -151,7 +150,7 @@ struct FromImpl<DeviceType> {
case DeviceType::PrivateUse1:
return from(aoti_torch_device_type_privateuse1());
default:
TORCH_CHECK(
STD_TORCH_CHECK(
false,
"Not yet supported DeviceType, please file an issue describing your use case.");
}
@ -379,7 +378,7 @@ struct ToImpl<ScalarType> {
} else if (shim_scalartype == aoti_torch_dtype_uint64()) {
return ScalarType::UInt64;
} else {
TORCH_CHECK(
STD_TORCH_CHECK(
false,
"Not yet supported ScalarType ",
std::to_string(shim_scalartype),
@ -409,7 +408,7 @@ struct ToImpl<DeviceType> {
} else if (shim_devicetype == aoti_torch_device_type_privateuse1()) {
return DeviceType::PrivateUse1;
} else {
TORCH_CHECK(
STD_TORCH_CHECK(
false,
"Not yet supported DeviceType ",
std::to_string(shim_devicetype),

View File

@ -2,7 +2,7 @@ from collections.abc import Callable
from copy import deepcopy
from enum import auto, Enum
from functools import partial, wraps
from typing import Any, NamedTuple, Optional, TypeVar, Union
from typing import Any, NamedTuple, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
import torch
@ -17,6 +17,9 @@ from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary, weakref
if TYPE_CHECKING:
from torch.utils.hooks import RemovableHandle
_TOTAL_KEY = "Total"
__all__ = ["FSDPMemTracker"]
@ -365,14 +368,28 @@ class FSDPMemTracker(MemTracker):
# `FSDPParamGroup.post_forward` because during AC these won't be called.
# TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786)
# lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`.
# pyrefly: ignore [missing-attribute]
# get the unique _MultiHandlers/RemoveHandlers and store in dictionary
# the _MultiHandlers object will only need to be grabbed once.
unique_handlers: dict[RemovableHandle, bool] = {}
# pyrefly: ignore # missing-attribute
for module in self._root_mod.modules():
if isinstance(module, FSDPModule):
fsdp_state = module._get_fsdp_state()
if fsdp_param_group := fsdp_state._fsdp_param_group:
if not unique_handlers.get(fsdp_state._pre_forward_hook_handle):
unique_handlers[fsdp_state._pre_forward_hook_handle] = True
if not unique_handlers.get(fsdp_state._post_forward_hook_handle):
unique_handlers[fsdp_state._post_forward_hook_handle] = True
# call remove on the handles once
for f_hook_handle in unique_handlers.keys():
f_hook_handle.remove()
# pyrefly: ignore # missing-attribute
for module in self._root_mod.modules():
if isinstance(module, FSDPModule):
fsdp_state = module._get_fsdp_state()
if fsdp_param_group := fsdp_state._fsdp_param_group:
self._instrument_fsdp_sharded_params_grads(fsdp_param_group)
fsdp_state._pre_forward_hook_handle.remove()
fsdp_state._post_forward_hook_handle.remove()
fsdp_state._pre_forward_hook_handle = (
# pyrefly: ignore [missing-attribute]
module.register_forward_pre_hook(

View File

@ -194,6 +194,10 @@ else:
_rank_map: Optional[torch.Tensor] = None,
_root_mesh: Optional["DeviceMesh"] = None,
) -> None:
# no-op in OSS, logs API usage metrics in meta-internal runs
torch._C._log_api_usage_once(
"torch.distributed.device_mesh.DeviceMesh.__init__"
)
if mesh is not None:
if _layout is not None or _rank_map is not None:
raise TypeError(

View File

@ -359,6 +359,10 @@ class ShardingPropagator:
"""
Propagate the sharding for an operator given the op_schema.
"""
# no-op in OSS, logs API usage metrics in meta-internal runs
torch._C._log_api_usage_once(
"torch.distributed.tensor._sharding_prop.ShardingPropagator.propogate_op_sharding_non_cached"
)
# special case op, we don't need to propagate for local
# scalar. TODO: figure out a better way to handle this
if op_schema.op is aten._local_scalar_dense.default:

View File

@ -398,6 +398,9 @@ def load(
Under active development, saved files may not be usable in newer versions
of PyTorch.
.. warning::
:func:`torch.export.load()` uses pickle under the hood to load models. **Never load data from an untrusted source.**
Loads an :class:`ExportedProgram` previously saved with
:func:`torch.export.save <torch.export.save>`.

View File

@ -3,7 +3,7 @@ import dataclasses
import inspect
import logging
import sys
from collections import defaultdict
from collections import defaultdict, OrderedDict
from collections.abc import Callable
from enum import auto, Enum
from typing import Any, Optional, TYPE_CHECKING, Union
@ -721,7 +721,18 @@ def _combine_args(f, args, kwargs) -> dict[str, Any]:
else inspect.signature(f)
)
kwargs = kwargs if kwargs is not None else {}
return signature.bind(*args, **kwargs).arguments
combined_args = signature.bind(*args, **kwargs).arguments
# if `args` is in the key, flatten it into args_0, args_1, ...
if "args" in combined_args:
flattened_args = {f"args_{i}": v for i, v in enumerate(combined_args["args"])}
combined_args = OrderedDict({**combined_args, **flattened_args})
del combined_args["args"]
# flatten kwargs into combined_args
if "kwargs" in combined_args:
for k, v in combined_args["kwargs"].items():
combined_args[k] = v
del combined_args["kwargs"]
return combined_args
class ShapesCollection: