Compare commits


99 Commits

Author SHA1 Message Date
0f49e915a9 rebase 2025-05-30 14:30:12 -07:00
2f1217f944 benchmarking 2025-05-30 14:27:37 -07:00
e0bf01e87b Script for consolidation of sharded safetensor files
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154743

Script to consolidate sharded safetensors files with DCP into full tensors. It relies on file-system operations to read and copy bytes directly, instead of the traditional approach of loading, re-sharding, and then saving again, because users will have models that are larger than the allotted memory.
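A minimal sketch of the byte-copy idea, not the DCP script itself: it assumes each tensor lives whole in exactly one shard file (the real consolidation also has to stitch together shards of the same tensor, which is why it needs DCP's extra metadata), and it relies only on the public safetensors layout (an 8-byte little-endian header length followed by a JSON header whose data_offsets are relative to the data section).

```python
import json
import struct
from pathlib import Path

def read_header(path):
    with open(path, "rb") as f:
        (n,) = struct.unpack("<Q", f.read(8))   # safetensors: 8-byte LE header length
        return json.loads(f.read(n)), 8 + n     # header dict, start of data section

def consolidate(shard_paths, out_path):
    plan, entries, offset = [], {}, 0
    for p in map(Path, shard_paths):
        header, data_start = read_header(p)
        for name, meta in header.items():
            if name == "__metadata__":
                continue
            begin, end = meta["data_offsets"]
            entries[name] = {"dtype": meta["dtype"], "shape": meta["shape"],
                             "data_offsets": [offset, offset + (end - begin)]}
            plan.append((p, data_start + begin, end - begin))
            offset += end - begin
    header_bytes = json.dumps(entries).encode()
    with open(out_path, "wb") as out:
        out.write(struct.pack("<Q", len(header_bytes)))
        out.write(header_bytes)
        for p, start, length in plan:           # stream raw bytes; never materialize tensors
            with open(p, "rb") as src:
                src.seek(start)
                while length:
                    buf = src.read(min(1 << 24, length))
                    out.write(buf)
                    length -= len(buf)
```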

Differential Revision: [D75536985](https://our.internmc.facebook.com/intern/diff/D75536985/)
ghstack-source-id: 287291639
2025-05-30 14:18:51 -07:00
3b5ae0e9fc Support re-sharding for safetensors checkpoints
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154519

This change adds the ability to support re-sharding for HF safetensors checkpoints.
This is done by adding more metadata when saving each file; the metadata captures the size and offset of the saved shard. On load, this information can be used to create the chunks belonging to the TensorStorageMetadata class, which enables re-sharding.
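Purely illustrative: the key names below are hypothetical, not the PR's actual schema, but they show the kind of per-shard information (offset and size within the full tensor) that makes it possible to rebuild TensorStorageMetadata chunks on load and re-shard to a different layout.

```python
# Hypothetical shape of the extra metadata saved alongside each file (illustrative keys only).
shard_metadata = {
    "model.layers.0.attn.wq.weight": {
        "global_shape": [4096, 4096],   # size of the full, unsharded tensor
        "shard_offsets": [2048, 0],     # where this saved shard starts within the full tensor
        "shard_shape": [2048, 4096],    # size of the saved shard
    }
}
```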

Differential Revision: [D75226344](https://our.internmc.facebook.com/intern/diff/D75226344/)
ghstack-source-id: 286572125
2025-05-30 10:40:32 -07:00
5f5f654a3e Updates to HFStorageReader to use TensorStorageMetadata instead of BytesStorageMetadata
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154518

As we prepare to support re-sharding, the current approach of using BytesStorageMetadata to read safetensors won't work anymore. Previously, we didn't need to read the metadata from the safetensors file header, because we loaded the contents of the file directly into tensors with safetensor.load(), which handled the metadata and deserialization. Now, in preparation for handling re-sharding, we need to read the metadata directly from the header of the safetensors file and store it in TensorStorageMetadata objects so that we can perform re-sharding. Re-sharding won't work yet, as extra metadata needs to be stored on each save; that will be added in a subsequent PR.
In addition, this PR adds an integration test on top of the unit tests.
It also removes the HfFileSystem import, since that is only needed when users use HfFileSystem, and we want to support any backend.
ghstack-source-id: 286649070
@exported-using-ghexport

Differential Revision: [D74891998](https://our.internmc.facebook.com/intern/diff/D74891998/)
2025-05-30 10:40:30 -07:00
21931cbbc6 Changes to HFStorageWriter to support saving shards of tensors
As we move towards supporting saving partial tensors natively with HFStorageWriter, there are some simple changes that need to be made to make this happen.
- The current approach for distributed writes is that every rank has the full tensors, but we split up the writing of those full tensors across all available ranks. We're removing that logic from the HFSavePlanner and instead assuming that every rank has a shard, saving every rank's local state.
    - As a result we could probably remove the HFSavePlanner, but we're keeping it as a placeholder for now.

- The current file naming doesn't support shards, since it uses the format "model-00001-of-00004.safetensors"; if every rank writes the same file names, the ranks will overwrite each other. This adds a shard-00001 prefix so that per-rank files don't overwrite each other.
- Don't save the metadata file models.safetensors.index.json if sharding is enabled. That file expects a 1:1 mapping between tensor and filename, which doesn't make sense in the sharded saving approach, so we can just drop it.
- Make the "fqn_to_file_index" map optional. This map describes which file each tensor is saved in; if users don't provide it, we just save all the tensors to one file. If they run into issues, they can choose how to split up their tensors to stay under the 5GB HF remote-storage soft limit on file size (see the naming sketch after this list).
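A small naming sketch; the helper names are mine, not HFStorageWriter's actual internals. It shows the shard prefix that keeps ranks from overwriting each other's files, and the optional fqn_to_file_index map defaulting everything to a single file:

```python
from typing import Dict, Optional

def shard_file_name(shard_idx: int, file_idx: int, num_files: int) -> str:
    # e.g. "shard-00003-model-00001-of-00001.safetensors": the shard prefix keeps rank files distinct
    return f"shard-{shard_idx:05d}-model-{file_idx:05d}-of-{num_files:05d}.safetensors"

def file_index_for(fqn: str, fqn_to_file_index: Optional[Dict[str, int]]) -> int:
    # optional map: when absent, every tensor goes into one file
    return 0 if fqn_to_file_index is None else fqn_to_file_index[fqn]

print(shard_file_name(3, 1, 1))  # shard-00003-model-00001-of-00001.safetensors
```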

Differential Revision: [D75099862](https://our.internmc.facebook.com/intern/diff/D75099862/)

ghstack-source-id: 286648122
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154742
2025-05-30 10:40:28 -07:00
ef4d57329b [CAG] Support for call_module at copy paste aot bwd graph (#153827)
Adds support for `call_module` in `copy_paste_aot_backward_graph`, which was added recently with PT2.7.

The problem is observed with the HPU backend in the example repro below, due to the creation of fused modules.

```
import torch

device = 'cpu' #'hpu'
backend = 'inductor' #'hpu_backend'

def fn(t1):
    t1 = t1 * 1
    t1_grad = torch.ones_like(t1, device=device)
    t1.backward(t1_grad, retain_graph=True)
    return t1

t1 = torch.ones(1, requires_grad=True, device=device) #.squeeze()
compiled_fn = torch.compile(fn, backend=backend)
result = compiled_fn(t1)

with torch._dynamo.compiled_autograd._enable(torch.compile(backend=backend)):
    result_grad = torch.ones_like(result, device=device)
    result.backward(result_grad)

print(f'{result_grad=}')
print(f'{t1.grad=}')
```

With this change I'm getting the same results as on CPU; however, I'm facing the problem below when running with a scalar (the t1 tensor after squeeze):
`torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <built-in function getitem>(*(FakeTensor(..., device='hpu:0', size=()), 0), **{}): got IndexError('invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number')`

While on CPU there's following warning and None returned:
`repro.py:23: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at pytorch/build/aten/src/ATen/core/TensorBody.h:489.)
  print(f'{t1.grad=}')
t1.grad=None`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153827
Approved by: https://github.com/xmfan
2025-05-28 22:52:40 +00:00
d62a33c002 [ez] add docblock for _expandsums (#154397)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154397
Approved by: https://github.com/laithsakka
ghstack dependencies: #154400, #154398, #154396, #154399
2025-05-28 22:43:26 +00:00
0c00e32632 [ez] add docblock for _eval_is_non_overlapping_and_dense (#154399)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154399
Approved by: https://github.com/laithsakka
ghstack dependencies: #154400, #154398, #154396
2025-05-28 22:40:03 +00:00
0f56318152 [precompile] Add Exception type PackageError for unsupported precompile features. (#154430)
Summary:
Today when guard serialization fails, dynamo will raise an internal error like:

```
torch._dynamo.exc.InternalTorchDynamoError: RuntimeError: CLOSURE_MATCH guard cannot be serialized.
```

Adding a dedicated PackageError type to surface the error more clearly.
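A hedged usage sketch: the failing step below is a stand-in, but torch._dynamo.exc.PackageError is the type this PR introduces, so callers can tell "precompile doesn't support this feature" apart from genuine internal errors.

```python
from torch._dynamo.exc import PackageError

def do_precompile_step():
    # stand-in for a guard-serialization step that hits an unsupported feature
    raise PackageError("CLOSURE_MATCH guard cannot be serialized.")

try:
    do_precompile_step()
except PackageError as err:  # dedicated type instead of a generic InternalTorchDynamoError
    print(f"precompile does not support this program: {err}")
```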

Test Plan: CI

Differential Revision: D75452124

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154430
Approved by: https://github.com/jamesjwu, https://github.com/jansel
2025-05-28 22:34:51 +00:00
11129d9317 Add new ops in fallback ops (#154251)
Fixes #ISSUE_NUMBER

## Background

Task: [T222738229](https://www.internalfb.com/intern/tasks/?t=222738229)

It's the first starter task on the project **_Enabling TorchNative Standalone on Whisper_**. We are using the C shim (cshim) to create a layer of abstraction between **_libtorch_** and **_AOTInductor-generated artifacts_**.

So we need to add an entry in the cshim for every libtorch API surface we care about, and we only care about operators that AOTInductor does not handle. For this task, we only wanted to add entries for the following ops.

## What I've done

Four new fallback ops that show up in the Whisper model are added in torchgen/aoti/fallback_ops.py (see the sketch after this list):

- aten.permute (default)
- aten.squeeze (dim)
- aten.abs (default)
- aten.hann_window (default)
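Roughly what the registration looks like, based on my reading of torchgen/aoti/fallback_ops.py (the dict name and value format may differ): each fallback op is keyed by its "namespace.op.overload" name.

```python
# Sketch of the new entries in torchgen/aoti/fallback_ops.py (exact format may differ).
inductor_fallback_ops = {
    # ...existing entries...
    "aten.permute.default": {},
    "aten.squeeze.dim": {},
    "aten.abs.default": {},
    "aten.hann_window.default": {},
}
```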

Then I ran the command below to generate the new C shim header files, as described [here](7e86a7c015/torchgen/gen.py (L2424-L2436%20for%20details)):
`python torchgen/gen.py --update-aoti-c-shim`

Then, `python setup.py develop` to rebuild PyTorch

## Testing

Four new tests have also been added in test/inductor/test_aot_inductor.py:

- test_proxy_executor_permute
- test_proxy_executor_abs
- test_proxy_executor_squeeze
- test_proxy_executor_hann

I ran these commands to test it (inside local pytorch root folder):

`python test/inductor/test_aot_inductor.py -k test_proxy_executor_permute`
`python test/inductor/test_aot_inductor.py -k test_proxy_executor_abs`
`python test/inductor/test_aot_inductor.py -k test_proxy_executor_squeeze`
`python test/inductor/test_aot_inductor.py -k test_proxy_executor_hann`

## NOTE:
I didn't see any particular ordering of the tests inside _test/inductor/test_aot_inductor.py_, so I added the new tests just after the test given in the example.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154251
Approved by: https://github.com/angelayi
2025-05-28 22:11:07 +00:00
d2f506cae8 [ca] disable ca for functorch grad and run all HOO tests (#154147)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154147
Approved by: https://github.com/zou3519
ghstack dependencies: #154133
2025-05-28 22:06:13 +00:00
857f21631d [ca] fix hop_db tests (#154133)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154133
Approved by: https://github.com/zou3519
2025-05-28 22:06:13 +00:00
ed348e7026 Add docblock for TrackedFake (#154396)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154396
Approved by: https://github.com/laithsakka
ghstack dependencies: #154400, #154398
2025-05-28 21:19:49 +00:00
d311b79c12 add docblock for _fast_expand (#154398)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154398
Approved by: https://github.com/laithsakka
ghstack dependencies: #154400
2025-05-28 21:16:47 +00:00
e7318b863d [ez] add docblock to cast_symbool_to_symint_guardless (#154400)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154400
Approved by: https://github.com/laithsakka
2025-05-28 21:11:53 +00:00
f6dcc45c44 [Kineto x Insight] Add device to activity type map in pytorch (#154253)
Summary: Update the device-to-ActivityType map in PyTorch. Needs to be exported to GitHub.

Test Plan:
Run the on-demand e2e test; the Insight profiler is triggered during profiling.
P1819539581: https://www.internalfb.com/intern/paste/P1819539581/
{F1978519960}

The Insight profiler is not enabled when mtia_insight is not specified in the config.
{F1978527200}

Reviewed By: fenypatel99

Differential Revision: D75246621

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154253
Approved by: https://github.com/Skylion007
2025-05-28 20:36:19 +00:00
e25074d462 [c10d][CI] Change expected return code in Sandcastle for Nan tests (#154441)
Fixing internal error caused by #153167.

`skip_but_pass_in_sandcastle_if` returns exit code 0, but `test_nan_assert` expects exit code -6.
So we need to set the expected return code conditionally on `IS_SANDCASTLE`, as sketched below.
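A minimal sketch of that conditional, assuming `IS_SANDCASTLE` is the existing flag in `torch.testing._internal.common_utils`; -6 corresponds to -SIGABRT on Linux.

```python
import signal

from torch.testing._internal.common_utils import IS_SANDCASTLE

# Skipped-but-passed tests exit cleanly; the real NaN assert dies with SIGABRT (-6 on Linux).
expected_return_code = 0 if IS_SANDCASTLE else -signal.SIGABRT
```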

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154441
Approved by: https://github.com/fduwjj, https://github.com/nWEIdia
ghstack dependencies: #153167
2025-05-28 20:35:52 +00:00
c381103fd7 Fix the logic of set_cpu_affinity (#154503)
While investigating https://github.com/pytorch/pytorch/issues/152566, I found two issues with how the cpu affinity is set in benchmark job:

* The current logic doesn't work with cgroups slices, the mechanism behind the multi-tenant runner:
    * `lscpu` returns all CPUs, not just the ones available to the cgroup, whereas `nproc` works correctly. For example, on H100, `lscpu` returns 192 CPUs while `nproc` returns 24 (192 / 8).
    * Setting `taskset -c 0-N` blindly is wrong because CPU 0 is only available to the first tenant, aka alice. For example, running `taskset -c 0 ls` on any other tenant will fail. To fix this, the IDs of the available CPUs can be fetched by calling `os.sched_getaffinity(0)`.
* The last bug is that `taskset` works with logical CPUs (https://www.man7.org/linux/man-pages/man1/taskset.1.html), so using the result of `test_inductor_get_core_number` is also wrong, because that function returns the number of physical CPUs (see the sketch after this list).
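A short standalone Python sketch of the fix's core idea (the actual change is in the CI shell script, shown in the diff further down): derive the usable CPUs from the cgroup-aware affinity mask instead of `lscpu`, then build the `taskset` range from those logical CPU IDs.

```python
import os

# os.sched_getaffinity(0) honors cgroup slices on Linux, so multi-tenant runners
# only see the CPUs they are actually allowed to use (unlike lscpu).
available = sorted(os.sched_getaffinity(0))
start_cpu, end_cpu = available[0], available[-1]
print(f"taskset -c {start_cpu}-{end_cpu} <benchmark command>")
```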

### Testing

CPU benchmark jobs look ok

* [aarch64 torch.compile benchmark](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2021%20May%202025%2016%3A40%3A28%20GMT&stopTime=Wed%2C%2028%20May%202025%2016%3A40%3A28%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cpu%20(aarch64)&lBranch=fix-cpu-affinity-cgroups&lCommit=9a6288e083d650c470623f5fe136b1060824021c&rBranch=main&rCommit=dec5ab8d984b8a608140911351d877b9ddb141c2)
* [x86 micro benchmark](https://hud.pytorch.org/benchmark/llms?startTime=Wed%2C%2021%20May%202025%2016%3A41%3A26%20GMT&stopTime=Wed%2C%2028%20May%202025%2016%3A41%3A26%20GMT&granularity=day&lBranch=main&lCommit=c1b7dbc52aaa49f4cd147bbe5935110a4a10e3e3&rBranch=refs/tags/ciflow/inductor-micro-benchmark-cpu-x86/154503&rCommit=9a6288e083d650c470623f5fe136b1060824021c&repoName=pytorch%2Fpytorch&benchmarkName=&modelName=All%20Models&backendName=All%20Backends&modeName=All%20Modes&dtypeName=All%20DType&deviceName=cpu%20(x86_64)&archName=All%20Platforms)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154503
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-05-28 19:38:20 +00:00
66f53889d5 [nativert] port semaphore to c10 util (#153504)
Summary:
nativert RFC: https://github.com/zhxchen17/rfcs/blob/master/RFC-0043-torch-native-runtime.md

To land the runtime into PyTorch core, we will gradually land logical parts of the code into the Github issue and get each piece properly reviewed.

This diff adds a simple semaphore interface to c10 until C++20, where we get counting_semaphore.

Going to need an OSS build export to take a look at this...

Test Plan: CI

Differential Revision: D73882656

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153504
Approved by: https://github.com/zhxchen17
2025-05-28 19:17:30 +00:00
24980d2641 [ROCm][CI] Update build-environment for mi300 workflows (#153134)
so their test times are tracked separately in https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/test-times.json. Currently, both MI200 and MI300 test times get combined into the same key `linux-focal-rocm-py3.10`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153134
Approved by: https://github.com/huydhn
2025-05-28 19:04:53 +00:00
d4ab8e74f3 Revert "Fix the Problems About Defining Static Variable in Inline Function (#147095)"
This reverts commit c6fc11af760d4ad1f01cc699a3c6488ab5f41770.

Reverted https://github.com/pytorch/pytorch/pull/147095 on behalf of https://github.com/izaitsevfb due to still fails to link internally at meta ([comment](https://github.com/pytorch/pytorch/pull/147095#issuecomment-2917221575))
2025-05-28 18:22:39 +00:00
1c7a70b483 [AOTI][cutlass backend] Do not remove the cutlass kernel .o file after packaging (#154155)
Differential Revision: [D75253009](https://our.internmc.facebook.com/intern/diff/D75253009/)

In general, we want to cache the cutlass kernels.

We also saw an error saying the .o file was not found.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154155
Approved by: https://github.com/chenyang78
2025-05-28 17:35:19 +00:00
66ac724b56 pyfmt lint torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py (#154488)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154488
Approved by: https://github.com/Skylion007
ghstack dependencies: #154483, #154484, #154485, #154487
2025-05-28 17:07:15 +00:00
dfe0f48123 pyfmt lint torch/_export/serde/schema.py (#154487)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154487
Approved by: https://github.com/Skylion007
ghstack dependencies: #154483, #154484, #154485
2025-05-28 17:07:15 +00:00
92cebed1bd pyfmt lint torch/_export/serde/serialize.py (#154485)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154485
Approved by: https://github.com/Skylion007
ghstack dependencies: #154483, #154484
2025-05-28 17:07:07 +00:00
b4fe5ca58a pyfmt lint torch/utils/weak.py (#154484)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154484
Approved by: https://github.com/Skylion007
ghstack dependencies: #154483
2025-05-28 17:06:58 +00:00
4de1b25df7 Remove empty files from exclude lint rule (#154483)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154483
Approved by: https://github.com/Skylion007
2025-05-28 17:06:50 +00:00
70539308ac [dynamo] updating gb_type names for uniqueness (#154452)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154452
Approved by: https://github.com/williamwen42
2025-05-28 16:54:10 +00:00
e313152a33 SDPA fix memory efficient attention for large batch dim (#154029)
Fixes #146704

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154029
Approved by: https://github.com/ngimel
2025-05-28 16:53:53 +00:00
3b38989b5f Remove MemPoolContext (#154042)
Removes MemPoolContext from custom user mempools. The ground truth for which pool should be used is the active pool in graph_pools; MemPoolContext just introduced an opportunity for the pool pointed to by MemPoolContext and the active pool in graph_pools to go out of sync (see all the asserts in the code meant to prevent that, and yet it could still happen in a multithreaded scenario; see my recent PR #153990).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154042
Approved by: https://github.com/albanD, https://github.com/syed-ahmed
2025-05-28 16:35:48 +00:00
d23aa7e182 Add deprecation warning for torch.ao.quantization (#153892)
Summary:
As titled.

Test Plan:
(ao) $ PYTHONWARNINGS='default' python
Python 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer
printing warning
*/anaconda3/envs/ao/lib/python3.10/site-packages/torch/ao/quantization/__init__.py:36: DeprecationWarning: torch.ao.quantization is deprecated. Plan is to
1. Remove eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead
2. Remove fx graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx, torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e)
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e)
see https://dev-discuss.pytorch.org/t/torch-ao-quantization-migration-plan/2810 for more details
  warnings.warn(
>>> a = XNNPACKQuantizer()
*/anaconda3/envs/ao/lib/python3.10/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py:281: DeprecationWarning: XNNPACKQuantizer is deprecated! Please use xnnpack quantizer in ExecuTorch (https://github.com/pytorch/executorch/tree/main/backends/xnnpack/quantizer) instead
  warnings.warn(f"{self.__class__.__name__} is deprecated! Please use xnnpack quantizer in ExecuTorch (https://github.com/pytorch/executorch/tree/main/backends/xnnpack/quantizer) instead", DeprecationWarning)
>>>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153892
Approved by: https://github.com/Skylion007
2025-05-28 16:25:30 +00:00
5bf74753f6 [precompile] Prune local scope variables for guard serialization. (#154431)
Summary: Prune unused local objects from the serialized local scope if they are not used in guard reconstruction. This is helpful when a user program takes in things like local callable functions or when the function call is recursive.

Test Plan:
test/dynamo/test_guard_serialization.py -k test_function_locals

Before pruning locals:
```
state = GuardsState(output_graph=OutputGraphGuardsState(local_scope={'x': tensor([ 0.0461,  0.4024, -1.0115]), 'g': <function ...aints=None, _guards=<torch._guards.GuardsSet object at 0x7fbccc7e9fc0>, _aotautograd_guards=[]), shape_code_parts=None)

    def pickle_guards_state(state: GuardsState) -> bytes:
        buf = io.BytesIO()
        pickler = GuardsStatePickler(buf)
        try:
            pickler.dump(state)
        except AttributeError as e:
>           raise torch._dynamo.exc.PackageError(str(e)) from e
E           torch._dynamo.exc.PackageError: Can't pickle local object 'TestGuardSerialization.test_function_locals.<locals>.foo'
```
After the diff
```
Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D75452123

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154431
Approved by: https://github.com/jansel
2025-05-28 16:03:02 +00:00
9db7bcb3fe [Dynamo] Introduce hook receiving list of traced code objects (#153622)
This PR:
* Expands `Hooks` with a new, optional `frame_traced_fn` field. It should be a callable receiving the list of traced code objects
* Maintains a list of `traced_code` objects in the `TracingContext` of an `OutputGraph`
    *  Whenever an `inline_call()` is encountered, the corresponding code object is added to this list
    * `OutputGraph`'s associated `f_code` is added to the list just before the hook is called

I believe use of this hook should enable the source code hashing that vLLM does in a better way than monkey-patching `inline_call()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153622
Approved by: https://github.com/jansel
2025-05-28 15:40:09 +00:00
476e0a643a [ez] add docblock for ShapeGuardPythonPrinter (#154403)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154403
Approved by: https://github.com/jingsh
ghstack dependencies: #154374, #154375, #154376, #154386, #154401, #154404, #154405, #154377, #154378, #154379, #154380, #154381, #154383, #154384, #154385, #154402
2025-05-28 14:17:17 +00:00
473a93eb58 [ez] add docblock for _ShapeGuardPrinter (#154402)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154402
Approved by: https://github.com/jingsh
ghstack dependencies: #154374, #154375, #154376, #154386, #154401, #154404, #154405, #154377, #154378, #154379, #154380, #154381, #154383, #154384, #154385
2025-05-28 14:13:22 +00:00
35a473e364 [ez] add docblock for guard_scalar (#154385)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154385
Approved by: https://github.com/jingsh
ghstack dependencies: #154374, #154375, #154376, #154386, #154401, #154404, #154405, #154377, #154378, #154379, #154380, #154381, #154383, #154384
2025-05-28 14:10:07 +00:00
ee4f433963 [ez] add docblock for _guard_or (#154384)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154384
Approved by: https://github.com/pianpwk
ghstack dependencies: #154374, #154375, #154376, #154386, #154401, #154404, #154405, #154377, #154378, #154379, #154380, #154381, #154383
2025-05-28 14:06:29 +00:00
e9b97d19b1 [ez] Make SymNodeImpl comments less misleading (#154480)
As discussed in the DS workchat, it's easy for users to get confused by
guarding behavior in these supposedly non-guarding methods. The TL;DR is that in the
case of non-Pythonic compilers like XLA, we actually do guard. I've
updated the comments accordingly to reduce confusion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154480
Approved by: https://github.com/pianpwk, https://github.com/Skylion007
2025-05-28 14:04:32 +00:00
a75e3a02be Revert "[dynamo, nested graph breaks] small fixes to resume function generation (#151056)"
This reverts commit 28e7aa21c522e92ea01a62dfdc5e3b74e398d8f0.

Reverted https://github.com/pytorch/pytorch/pull/151056 on behalf of https://github.com/malfet due to Not sure which one, but it broke test_error_messages, see 203b0efd63/1 ([comment](https://github.com/pytorch/pytorch/pull/151056#issuecomment-2916437433))
2025-05-28 13:53:50 +00:00
9603d6382d Revert "[dynamo, nested graph breaks] refactor codegen to minimize NULL codegen'ing (#153510)"
This reverts commit 1fe98429222a8ba5e16dd9381f50a8fb90edcf0e.

Reverted https://github.com/pytorch/pytorch/pull/153510 on behalf of https://github.com/malfet due to Not sure which one, but it broke test_error_messages, see 203b0efd63/1 ([comment](https://github.com/pytorch/pytorch/pull/151056#issuecomment-2916437433))
2025-05-28 13:53:50 +00:00
5fd7004dc9 Revert "[dynamo, nested graph breaks] remove block stack graph break in output_graph (#153772)"
This reverts commit 9a66c30bdc563c62375e5030c4103b67515b8dac.

Reverted https://github.com/pytorch/pytorch/pull/153772 on behalf of https://github.com/malfet due to Not sure which one, but it broke test_error_messages, see 203b0efd63/1 ([comment](https://github.com/pytorch/pytorch/pull/151056#issuecomment-2916437433))
2025-05-28 13:53:50 +00:00
e86439ed5b Revert "[dynamo, nested graph breaks] add skip_frame debugging function (#153773)"
This reverts commit aadf9eae63c4793e1107a3b21ede30e5289eeaca.

Reverted https://github.com/pytorch/pytorch/pull/153773 on behalf of https://github.com/malfet due to Not sure which one, but it broke test_error_messages, see 203b0efd63/1 ([comment](https://github.com/pytorch/pytorch/pull/151056#issuecomment-2916437433))
2025-05-28 13:53:50 +00:00
203b0efd63 [PP] Allow unused kwargs in ZB path (#153498)
This fixes the case where an unused kwarg is passed to the PP stage forward: we would try to call `torch.autograd.grad()` and update its gradients even though it shouldn't have gradients, leading to this error:

```
[rank3]:[rank3]: File "/data/users/howardhuang/pytorch/torch/distributed/pipelining/stage.py", line 613, in
[rank3]:[rank3]: return lambda: stage_backward_input(
[rank3]:[rank3]: File "/data/users/howardhuang/pytorch/torch/distributed/pipelining/_backward.py", line 199, in stage_backward_input
[rank3]:[rank3]: dinputs = torch.autograd.grad(
[rank3]:[rank3]: File "/data/users/howardhuang/pytorch/torch/autograd/__init__.py", line 503, in grad
[rank3]:[rank3]: result = _engine_run_backward(
[rank3]:[rank3]: File "/data/users/howardhuang/pytorch/torch/autograd/graph.py", line 824, in _engine_run_backward
[rank3]:[rank3]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank3]:[rank3]: RuntimeError: One of the differentiated Tensors does not require grad
```

related issues: https://github.com/pytorch/torchtitan/issues/1188

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153498
Approved by: https://github.com/kwen2501
2025-05-28 13:34:04 +00:00
cf7451f279 Fix signature of torch.sparse_coo_tensor() (#152681)
Fixes #145371

@pearu I searched all the related code and found these lines; I'm wondering whether this is the root cause of the issue. Could you review? Thanks a lot!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152681
Approved by: https://github.com/Skylion007, https://github.com/pearu, https://github.com/nikitaved
2025-05-28 13:16:41 +00:00
f58143b945 [Typing] Refactor torch.types.Device in torch/cuda/__init__.py (#153447)
Part of: #152952
Follow up: #153027

Here is the definition of `torch.types.Device`:

ab997d9ff5/torch/types.py (L74)

So `Optional[Union[Device, int]]` is equivalent to `torch.types.Device`.
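An illustrative before/after (the function itself is hypothetical); the refactor simply swaps the spelled-out union for the existing alias, which accepts the same values:

```python
from typing import Optional, Union

import torch

def synchronize_old(device: Optional[Union[torch.device, int]] = None) -> None: ...

def synchronize_new(device: torch.types.Device = None) -> None: ...  # same accepted values
```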

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153447
Approved by: https://github.com/cyyever, https://github.com/Skylion007
2025-05-28 10:09:31 +00:00
fdc339003b Revert "[AOTI] Support multi-arch when using package_cpp_only (#154414)"
This reverts commit a84d8c4a1cc515db274366537afd0b1492800c2d.

Reverted https://github.com/pytorch/pytorch/pull/154414 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing ROCm trunk job ([comment](https://github.com/pytorch/pytorch/pull/154414#issuecomment-2915597821))
2025-05-28 09:23:31 +00:00
853958f82c Fix: Replacements can cause runtime assertions to disappear and can cause invalid inductor code. (#153661)
Let's first explore a couple of problems related to replacements and runtime assertions.

#### example problem 1
Suppose we have a runtime assertion u0==s0, where u0 is an input coming from mark_unbacked. A replacement u0=s0 will be added, and the function f(u0, s0) becomes f(s0, s0); this leads to the assert not being inserted during insert_deferred_runtime_asserts.
The reason is that the insert_deferred_runtime_asserts logic inserts each assertion once all of its inputs have been seen, but u0 will never be seen. The same thing can happen when we defer an assertion on backed symbols, e.g. s0==s2, etc.

#### example problem 2
Consider u0==s0, where u0 comes from a call to .item(). Imagine that later on s0 is specialized to 2. In that case s0 won't be seen as an input during insert_deferred_runtime_asserts and the assertion won't be inserted into the graph. Worse, Inductor will generate code in the cpp wrapper that refers to s0 while it does not exist, causing a failure.
internal xref: https://fb.workplace.com/groups/1075192433118967/permalink/1669766396994898/

## The solution :
Insertion of runtime assertions depends on detecting that the symbols used in those assertions have been seen; those symbols are either graph inputs or generated in the graph by data-dependent ops like .item().

The issues above happen when the symbols are graph inputs. To force those symbols to exist in the graph and be seen by runtime-assertion insertion, we do not apply replacements to placeholder expressions during codegen or during runtime-assertion insertion.

This should not have performance overhead: we already optimized the graph with replacements, and the only effect is not mistakenly dropping graph inputs that are used in runtime assertions.
I added extended testing. An unrelated follow-up I noticed: we might want to rename unbacked symbols in runtime assertions when we do unbacked renaming, but that's a different issue.

Other approaches that did not work :
#### Ban replacements on unbacked symbols
1. Does not work when we defer runtime assertions on backed symbols, e.g. s0==s1. We could also ban such replacements, but then problem 2 becomes more problematic.
2. It affects the quality of reasoning, in a bad way.

#### Apply specializations to runtime assertions before codegen
1. Can fix some issues, but may also lead to runtime assertions becoming no-ops.
2. Does not fix the case where runtime assertions are not inserted during insert_deferred_runtime_asserts because the input is not detected.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153661
Approved by: https://github.com/jansel
2025-05-28 09:08:05 +00:00
aadf9eae63 [dynamo, nested graph breaks] add skip_frame debugging function (#153773)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153773
Approved by: https://github.com/jansel
ghstack dependencies: #151056, #153510, #153772
2025-05-28 08:54:09 +00:00
9a66c30bdc [dynamo, nested graph breaks] remove block stack graph break in output_graph (#153772)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153772
Approved by: https://github.com/jansel
ghstack dependencies: #151056, #153510
2025-05-28 08:54:09 +00:00
1fe9842922 [dynamo, nested graph breaks] refactor codegen to minimize NULL codegen'ing (#153510)
Stop codegening NULLs that we need to pop later. Some output_graph.py changes to prepare for nested graph break support.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153510
Approved by: https://github.com/jansel
ghstack dependencies: #151056
2025-05-28 08:54:09 +00:00
28e7aa21c5 [dynamo, nested graph breaks] small fixes to resume function generation (#151056)
Old: ~pack resume function stack + locals into a list: we need to be able to pass frame stack+locals in lists to hand off to nested functions in the future, so we implement this part first.~

We are no longer doing this right now since GraphModule/guard variable naming gets messed up. Going forward, our approach will be to keep the top frame unpacked, but pack the rest of the contents of other frames in a list.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151056
Approved by: https://github.com/jansel
2025-05-28 08:54:09 +00:00
cyy
9d04c0f352 Remove outdated CUDA 11 conditions (#154313)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154313
Approved by: https://github.com/eqy
2025-05-28 08:44:58 +00:00
1d9b7dd2d1 [PGO] suggest dynamic whitelist for recompilations (#154189)
Suggests `TORCH_COMPILE_DYNAMIC_SOURCES` based on tensor size changes in the PGO code state, including parameters.

Closing #153442 which took the dynamo guards approach.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154189
Approved by: https://github.com/bobrenjc93
2025-05-28 07:11:43 +00:00
fe760b6636 [ez] add docblock for _free_unbacked_symbols_with_path (#154383)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154383
Approved by: https://github.com/pianpwk
ghstack dependencies: #154374, #154375, #154376, #154386, #154401, #154404, #154405, #154377, #154378, #154379, #154380, #154381
2025-05-28 05:53:50 +00:00
8e25ba6963 [ez] add docblock for find_symbol_binding_fx_nodes (#154381)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154381
Approved by: https://github.com/pianpwk
ghstack dependencies: #154374, #154375, #154376, #154386, #154401, #154404, #154405, #154377, #154378, #154379, #154380
2025-05-28 05:44:26 +00:00
08c29deb5f [ez] add docblock to is_symbol_binding_fx_node (#154380)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154380
Approved by: https://github.com/pianpwk
ghstack dependencies: #154374, #154375, #154376, #154386, #154401, #154404, #154405, #154377, #154378, #154379
2025-05-28 05:41:19 +00:00
07405a6cff [ez] add docblock for free_unbacked_symbols (#154379)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154379
Approved by: https://github.com/pianpwk
ghstack dependencies: #154374, #154375, #154376, #154386, #154401, #154404, #154405, #154377, #154378
2025-05-28 05:37:25 +00:00
dcdaef5206 [ez] add docblock for free_symbols (#154378)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154378
Approved by: https://github.com/pianpwk
ghstack dependencies: #154374, #154375, #154376, #154386, #154401, #154404, #154405, #154377
2025-05-28 05:34:25 +00:00
abc3fdc7ac [ez] add docblock for _iterate_exprs (#154377)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154377
Approved by: https://github.com/pianpwk
ghstack dependencies: #154374, #154375, #154376, #154386, #154401, #154404, #154405
2025-05-28 05:28:58 +00:00
ab6cb85cb0 [ez] add docblock for _remove_effect_token_unbacked_bindings (#154405)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154405
Approved by: https://github.com/Skylion007, https://github.com/pianpwk
ghstack dependencies: #154374, #154375, #154376, #154386, #154401, #154404
2025-05-28 05:16:14 +00:00
fde8f6a8b8 [ez] add docblock for _suggest_torch_checks (#154404)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154404
Approved by: https://github.com/Skylion007
ghstack dependencies: #154374, #154375, #154376, #154386, #154401
2025-05-28 04:45:55 +00:00
b82fb57b67 [ez] add docblock for RuntimeAssert (#154401)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154401
Approved by: https://github.com/Skylion007
ghstack dependencies: #154374, #154375, #154376, #154386
2025-05-28 04:43:22 +00:00
d64b4a91dd [ez] remove unused function _constrain_symbol_range (#154386)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154386
Approved by: https://github.com/Skylion007
ghstack dependencies: #154374, #154375, #154376
2025-05-28 04:41:00 +00:00
ef90cc18d7 use definitely_contiguous for _prim_elementwise_meta short circuit (#153441)
This verifies that the check short circuit is not material. See https://github.com/pytorch/pytorch/pull/153431
```
import torch
from torch.export import Dim, export
class MyModel(torch.nn.Module):
    def forward(self, x, ranks):
        first_k = ranks.max().item()
        torch._check_is_size(first_k)
        narrow = x.narrow(dim = 1, start = 0, length = first_k)
        lt = narrow < narrow.size(1)
        return lt
inps = (
    torch.randn((8, 16), device="cuda"),
    torch.arange(8, device="cuda", dtype=torch.int8)
)
spec = {
    "x": (Dim.AUTO, Dim.AUTO),
    "ranks": (Dim.AUTO,),
}
traced = export(MyModel(), inps, dynamic_shapes=spec, strict=True).run_decompositions({})

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153441
Approved by: https://github.com/jansel
ghstack dependencies: #153432
2025-05-28 03:41:26 +00:00
39df901b2a introduce definitely_contiguous and use it for reshape and tensor meta data computation. (#153432)
When a tensor has unbacked symbols, it can be general enough to represent both contiguous and non-contiguous tensors;
in that case we can't really evaluate is_contiguous. In many places in the code base we check is_contiguous to take a fast path, but the general path usually works for both contiguous and non-contiguous tensors, so in those cases we probably want to use the definitely_contiguous API.

This is applied to reshape in this PR, and also to tensor metadata computation: the metadata will now have an attribute saying the tensor is contiguous only when it is always contiguous, i.e. we store that only if definitely_contiguous is true.
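A conceptual sketch only, not the PR's implementation: a definitely-contiguous style check answers True only when contiguity is provable without guarding (here via statically_known_true), so callers holding unbacked symbols fall back to the general path instead of guarding.

```python
from torch.fx.experimental.symbolic_shapes import statically_known_true

def definitely_contiguous_sketch(sizes, strides) -> bool:
    expected_stride = 1
    for size, stride in zip(reversed(sizes), reversed(strides)):
        if statically_known_true(size == 1):
            continue  # size-1 dims may have arbitrary strides
        if not statically_known_true(stride == expected_stride):
            return False  # unknown or false: not *definitely* contiguous, take the general path
        expected_stride = expected_stride * size
    return True
```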

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153432
Approved by: https://github.com/bobrenjc93
2025-05-28 03:41:26 +00:00
54f1f29fed [dynamo] dynamic gb_type -> static gb_type (#154435)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154435
Approved by: https://github.com/williamwen42
2025-05-28 03:14:26 +00:00
f12ce4e36b [Intel GPU] convolution fusion at XPU backend (#154202)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154202
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/etaf
ghstack dependencies: #140365
2025-05-28 03:14:18 +00:00
c6fc11af76 Fix the Problems About Defining Static Variable in Inline Function (#147095)
Refer to https://github.com/pytorch/pytorch/issues/125465 for more information

- Remove unused header files
- Move the inline function that defines the static variable to .cc

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147095
Approved by: https://github.com/cyyever, https://github.com/albanD
2025-05-28 02:47:16 +00:00
855eff8e8e Don't CSE unbacked nodes (#154387)
* #154440
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154387
Approved by: https://github.com/TroyGarden
ghstack dependencies: #154440
2025-05-28 02:21:56 +00:00
919a1a17e3 [ez] Replace misleading implementations with NYI (#154440)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154440
Approved by: https://github.com/Skylion007, https://github.com/pianpwk
2025-05-28 02:21:56 +00:00
a84d8c4a1c [AOTI] Support multi-arch when using package_cpp_only (#154414)
Summary: Add support for multi_arch_kernel_binary in the package_cpp_only mode. More specifically, generate cmake targets that compile .ptx files into .fatbin and embed them in the final shared library or binary.

Differential Revision: [D75452096](https://our.internmc.facebook.com/intern/diff/D75452096)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154414
Approved by: https://github.com/angelayi
ghstack dependencies: #154412, #154413
2025-05-28 01:20:38 +00:00
cde82d25b7 [AOTI] Add a multi_arch_kernel_binary option (#154413)
Summary: CUDA can support multi-arch with the fatbin format. Add this multi_arch_kernel_binary option, so the compiled model binary can run across different GPU archs.

Differential Revision: [D75452094](https://our.internmc.facebook.com/intern/diff/D75452094)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154413
Approved by: https://github.com/angelayi
ghstack dependencies: #154412
2025-05-28 01:20:38 +00:00
4d8f3d537a [AOTI][refactor] Rename embed_cubin to embed_kernel_binary (#154412)
Summary: Rename as it is not CUDA specific.

Differential Revision: [D75452095](https://our.internmc.facebook.com/intern/diff/D75452095)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154412
Approved by: https://github.com/angelayi
2025-05-28 01:20:28 +00:00
e79790e14b [ez] add docblock for _sympy_from_args (#154376)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154376
Approved by: https://github.com/Skylion007
ghstack dependencies: #154374, #154375
2025-05-27 23:43:13 +00:00
fe082c5ffe Move inductor workflows focal (ubuntu 20.04) -> jammy (ubuntu 22.04) (#154153)
Trying to fix: https://github.com/pytorch/pytorch/issues/154157

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154153
Approved by: https://github.com/Skylion007, https://github.com/huydhn, https://github.com/nike4949, https://github.com/cyyever
2025-05-27 23:16:21 +00:00
3f10c9d8af Fixed an issue with XPU skip so the test_decompose_mem_bound_mm.py suite can be run correctly (#153245)
Fixes #153239

Replaced the custom decorator with the common one, although the better way to skip the whole suite would be to add it to the skip list in run_test.py.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153245
Approved by: https://github.com/jeffdaily
2025-05-27 23:10:25 +00:00
4b39832412 [CI] Update torchbench pin (#154453)
Related to https://github.com/pytorch/pytorch/issues/154446
Pins the torchbench repo to https://github.com/pytorch/benchmark/pull/2620, which pins opacus to version ``1.5.3``

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154453
Approved by: https://github.com/wdvr, https://github.com/malfet
2025-05-27 23:08:42 +00:00
247ea229ba Create issue template: Release highlight for proposed Feature (#154125)
Authors: @anitakat @atalman

This is related to: https://github.com/pytorch/pytorch/issues/152134 . Adding RFC template for feature submissions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154125
Approved by: https://github.com/anitakat, https://github.com/ZainRizvi, https://github.com/albanD
2025-05-27 22:45:21 +00:00
53affa273b [MTIA Aten Backend][1.3/n] Migrate remaining view ops, which all need explicit register in native_functions.yaml (#154337)
See context in D75266206.

This diff/PR migrates all the remaining view ops, which all need changes in `native_functions.yaml` and thus need to be exported to PR.

Ops covered by this diff:
- _reshape_alias
- unfold

internal: Also delete the entire aten_mtia_view_ops.cpp file, and update corresponding build config.

Differential Revision: [D75385411](https://our.internmc.facebook.com/intern/diff/D75385411/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154337
Approved by: https://github.com/nautsimon
ghstack dependencies: #154336
2025-05-27 22:18:12 +00:00
eaf355cb11 [BE] Clean up unused parameter input in AOTIModel (#154276)
Summary: As title

Test Plan: CI

Differential Revision: D74691763

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154276
Approved by: https://github.com/Skylion007
2025-05-27 22:17:32 +00:00
241f8dc84d Revert "Remove outdated CUDA 11 conditions (#154313)"
This reverts commit 3936e6141c09dab94f21e4fdab7bea4bddf62ac2.

Reverted https://github.com/pytorch/pytorch/pull/154313 on behalf of https://github.com/izaitsevfb due to breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/154313#issuecomment-2914230005))
2025-05-27 21:54:41 +00:00
6be829535f [ROCm] Improve vectorized elementwise kernel performance in MI300X (#153634)
* Use non-temporal loads to improve the vectorized elementwise kernel performance on MI300
* Use thread_work_size of 8 or 16 for vectorized elementwise kernel

Co-author: @amd-hhashemi

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153634
Approved by: https://github.com/jeffdaily
2025-05-27 20:49:32 +00:00
555fc05868 Revert "[Inductor] Improve typing, and prepare for ABI-compatible AOTI C-shim dispatching (#154371)"
This reverts commit 6169ca0b65bcb382faa1a2287278b3717c18f127.

Reverted https://github.com/pytorch/pytorch/pull/154371 on behalf of https://github.com/benjaminglass1 due to Appears to have broken main ([comment](https://github.com/pytorch/pytorch/pull/154371#issuecomment-2913975736))
2025-05-27 20:39:09 +00:00
7359705232 Add CPython tests for unittest (#150788)
Tests:
* test_assertions.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150788
Approved by: https://github.com/williamwen42
2025-05-27 20:26:17 +00:00
12fc06d267 Add CPython complex tests (#152015)
Tests:
* test_complex.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152015
Approved by: https://github.com/williamwen42
2025-05-27 20:24:28 +00:00
3b218e56dc Add CPython tests for iter/sort (#150797)
Tests:
* test_iter.py
* test_sort.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150797
Approved by: https://github.com/williamwen42
2025-05-27 20:22:34 +00:00
4fd8a54a41 [ez] add docblock for is_accessor_node (#154375)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154375
Approved by: https://github.com/Skylion007, https://github.com/pianpwk
ghstack dependencies: #154374
2025-05-27 19:47:32 +00:00
b367e5f6a6 [ROCm][Windows] Fix building torch 2.8 wheel with ROCm (added hipblasLt and rocblas directories) (#153144)
Since rocblas.dll and hipblaslt.dll are copied to torch/lib, the rocblas and hipblaslt directories need to be stored there too (otherwise, after wheel installation, we get an error while searching for files in rocblas/library and hipblaslt/library, which don't exist). This PR fixes that issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153144
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-05-27 19:40:28 +00:00
fa6ca59079 Revert "Move inductor workflows focal (ubuntu 20.04) -> jammy (ubuntu 22.04) (#154153)"
This reverts commit 2bd95f3a1f07132aa00f5c438c5228866d7dd1f8.

Reverted https://github.com/pytorch/pytorch/pull/154153 on behalf of https://github.com/malfet due to Broke inductor tests, see b8452e55bc/1 ([comment](https://github.com/pytorch/pytorch/pull/154153#issuecomment-2913738047))
2025-05-27 19:23:28 +00:00
6169ca0b65 [Inductor] Improve typing, and prepare for ABI-compatible AOTI C-shim dispatching (#154371)
Prepares for the next PR in the stack by tightening up typing on a `cpp_wrapper` interface that's only used in one (well-typed) place, as well as downstream effects of that change. In particular, this enabled:

1. removing a number of now clearly unnecessary asserts
2. adding a few more targeted asserts to validate the code's current assumptions
3. removing some unneeded control flow in several functions

As far as I can tell, this PR should be functionally neutral. One argument was removed from a `cpp_wrapper` public API, but that argument was unused, and only had a single callsite.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154371
Approved by: https://github.com/desertfire
2025-05-27 19:17:41 +00:00
75bbd4989c [dynamo] Support using symint from dispatcher-style tensor subclass (#154130)
Fixes #146932.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154130
Approved by: https://github.com/laithsakka
2025-05-27 19:05:46 +00:00
8c0f07f944 Revert "[ROCm] Improve vectorized elementwise kernel performance in MI300X (#153634)"
This reverts commit 0d4de7872ac019abbd6e87b3391b2276d9d05bd4.

Reverted https://github.com/pytorch/pytorch/pull/153634 on behalf of https://github.com/malfet due to Broke inductor jobs, see b8452e55bc/1 ([comment](https://github.com/pytorch/pytorch/pull/153634#issuecomment-2913619071))
2025-05-27 19:02:59 +00:00
b8452e55bc [Kineto x Insight] Update Kineto submodule (#154426)
Summary: We add a new ActivityType::MTIA_INSIGHT in 20f652846f

Test Plan: CI

Differential Revision: D75454945

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154426
Approved by: https://github.com/Skylion007
2025-05-27 18:29:29 +00:00
5075df6fee Make torch importable if compiled without TensorPipe (#154382)
By delaying the import/hiding it behind `torch.distributed.rpc.is_tensorpipe_avaiable()` check
Fixes https://github.com/pytorch/pytorch/issues/154300

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154382
Approved by: https://github.com/Skylion007
ghstack dependencies: #154325
2025-05-27 18:13:38 +00:00
f472ea63bb [BE] Fix typos in SyntaxError description (#154436)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154436
Approved by: https://github.com/seemethere, https://github.com/wdvr, https://github.com/ZainRizvi
2025-05-27 18:08:58 +00:00
cfbd99fdfd [Pytorch] Add option to CPU Blas GEMM to avoid output downcast (#154012)
Summary:
Dot product for a single output element consists of 3 steps (both input vectors have elements of type scalar_t):
1. elementwise vector multiply (scalar_t x scalar_t -> opmath_t)
2. vector reduction to a scalar value (opmath_t -> opmath_t)
3. optional downcast if opmath_t != out_t

The current blas kernel performs steps 1 and 2 correctly, but for step 3, it will always downcast to scalar_t even when opmath_t == output_t (and then do an upcast back to output_t), which results in precision loss. This diff fixes the precision loss in the BlasKernel
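A small numeric illustration of the issue (not the BlasKernel itself): when opmath_t and out_t are both float32, the old path's extra round-trip through bfloat16 loses bits.

```python
import torch

a = torch.randn(4096, dtype=torch.bfloat16)
b = torch.randn(4096, dtype=torch.bfloat16)

acc = (a.float() * b.float()).sum()          # steps 1-2: accumulate in opmath_t (float32)
out_new = acc                                # step 3 skipped: opmath_t == out_t, no downcast
out_old = acc.to(torch.bfloat16).float()     # old behavior: downcast to scalar_t, then upcast

print((out_new - out_old).abs().item())      # typically nonzero: precision lost to bfloat16
```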

Test Plan: Attention CI passes

Differential Revision: D75023858

topic: not user facing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154012
Approved by: https://github.com/Valentine233, https://github.com/aditew01, https://github.com/CaoE, https://github.com/drisspg
2025-05-27 17:43:21 +00:00
1ca082d9a1 [ez] Rewrite comment to be more friendly to non haskellers (#151421)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151421
Approved by: https://github.com/aorenste
2025-05-27 17:32:34 +00:00
70fbd5e08c [ez] Add docblock for resolve_unbacked_bindings (#154374)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154374
Approved by: https://github.com/Skylion007, https://github.com/pianpwk
2025-05-27 17:05:49 +00:00
286 changed files with 8288 additions and 1284 deletions

View File

@@ -820,16 +820,7 @@ test_inductor_torchbench_smoketest_perf() {
done
}
test_inductor_get_core_number() {
if [[ "${TEST_CONFIG}" == *aarch64* ]]; then
echo "$(($(lscpu | grep 'Cluster(s):' | awk '{print $2}') * $(lscpu | grep 'Core(s) per cluster:' | awk '{print $4}')))"
else
echo "$(($(lscpu | grep 'Socket(s):' | awk '{print $2}') * $(lscpu | grep 'Core(s) per socket:' | awk '{print $4}')))"
fi
}
test_inductor_set_cpu_affinity(){
#set jemalloc
JEMALLOC_LIB="$(find /usr/lib -name libjemalloc.so.2)"
export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD"
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
@@ -841,14 +832,23 @@ test_inductor_set_cpu_affinity(){
export KMP_AFFINITY=granularity=fine,compact,1,0
export KMP_BLOCKTIME=1
fi
cores=$(test_inductor_get_core_number)
# Set number of cores to 16 on Aarch64 for performance runs.
# Use nproc here instead of lscpu because it takes into account cgroups slice
cpus=$(nproc)
thread_per_core=$(lscpu | grep 'Thread(s) per core:' | awk '{print $4}')
cores=$((cpus / thread_per_core))
# Set number of cores to 16 on aarch64 for performance runs
if [[ "${TEST_CONFIG}" == *aarch64* && $cores -gt 16 ]]; then
cores=16
fi
export OMP_NUM_THREADS=$cores
end_core=$((cores-1))
export TASKSET="taskset -c 0-$end_core"
# Handle cgroups slice start and end CPU
start_cpu=$(python -c 'import os; print(min(os.sched_getaffinity(0)))')
# Leaving one physical CPU for other tasks
end_cpu=$(($(python -c 'import os; print(max(os.sched_getaffinity(0)))') - thread_per_core))
export TASKSET="taskset -c $start_cpu-$end_cpu"
}
test_inductor_torchbench_cpu_smoketest_perf(){

View File

@@ -0,0 +1,111 @@
name: 🚀 Release highlight for proposed Feature
description: Submit a Release highlight for proposed Feature
labels: ["release-feature-request"]
body:
  - type: textarea
    attributes:
      label: Release highlight for proposed Feature
      description: >
        Example: “A torch.special module, analogous to SciPy's special module.”
  - type: input
    id: contact
    attributes:
      label: Point(s) of contact
      description: How can we get in touch with you if we need more info?
      placeholder: ex. github username
    validations:
      required: false
  - type: dropdown
    attributes:
      label: Release Mode (pytorch/pytorch features only)
      description: |
        If "out-of-tree", please include the GH repo name
      options:
        - In-tree
        - Out-of-tree
    validations:
      required: true
  - type: textarea
    attributes:
      label: Out-Of-Tree Repo
      description: >
        please include the GH repo name
    validations:
      required: false
  - type: textarea
    attributes:
      label: Description and value to the user
      description: >
        Please provide a brief description of the feature and how it will benefit the user.
    validations:
      required: false
  - type: textarea
    attributes:
      label: Link to design doc, GitHub issues, past submissions, etc
    validations:
      required: false
  - type: textarea
    attributes:
      label: What feedback adopters have provided
      description: >
        Please list users/teams that have tried the feature and provided feedback. If that feedback motivated material changes (API, doc, etc..), a quick overview of the changes and the status (planned, in progress, implemented) would be helpful as well.
    validations:
      required: false
  - type: dropdown
    attributes:
      label: Plan for documentations / tutorials
      description: |
        Select One of the following options
      options:
        - Tutorial exists
        - Will submit a PR to pytorch/tutorials
        - Will submit a PR to a repo
        - Tutorial is not needed
    validations:
      required: true
  - type: textarea
    attributes:
      label: Additional context for tutorials
      description: >
        Please provide a link for existing tutorial or link to a repo or context for why tutorial is not needed.
    validations:
      required: false
  - type: dropdown
    attributes:
      label: Marketing/Blog Coverage
      description: |
        Are you requesting feature Inclusion in the release blogs?
      options:
        - "Yes"
        - "No"
    validations:
      required: true
  - type: textarea
    attributes:
      label: Are you requesting other marketing assistance with this feature?
      description: >
        E.g. supplementary blogs, social media amplification, etc.
    validations:
      required: false
  - type: textarea
    attributes:
      label: Release Version
      description: >
        Please include release version for marketing coverage.
    validations:
      required: false
  - type: textarea
    attributes:
      label: OS / Platform / Compute Coverage
      description: >
        Please list the platforms supported by the proposed feature. If the feature supports all the platforms, write "all". Goal of this section is to clearly share if this feature works in all PyTorch configurations or is it limited to only certain platforms/configurations (e.g. CPU only, GPU only, Linux only, etc...)
    validations:
      required: false
  - type: textarea
    attributes:
      label: Testing Support (CI, test cases, etc..)
      description: >
        Please provide an overview of test coverage. This includes unit testing and integration testing, but if E2E validation testing has been done to show that the feature works for a certain set of use cases or models please mention that as well.
    validations:
      required: false

View File

@@ -1 +1 @@
6693f5845f212d8af3513f8b8d275d5b65db9caf
e03a63be43e33596f7f0a43b0f530353785e4a59

View File

@@ -38,12 +38,12 @@ jobs:
opt_out_experiments: lf
linux-jammy-rocm-py3_10-inductor-build:
name: rocm-py3.10-inductor
name: rocm-py3.10-inductor-mi300
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
build-environment: linux-jammy-rocm-py3.10-mi300
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
test-matrix: |
{ include: [
@@ -56,11 +56,11 @@ jobs:
permissions:
id-token: write
contents: read
name: rocm-py3.10-inductor
name: rocm-py3.10-inductor-mi300
uses: ./.github/workflows/_rocm-test.yml
needs: linux-jammy-rocm-py3_10-inductor-build
with:
build-environment: linux-jammy-rocm-py3.10
build-environment: linux-jammy-rocm-py3.10-mi300
docker-image: ${{ needs.linux-jammy-rocm-py3_10-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-inductor-build.outputs.test-matrix }}
secrets: inherit

View File

@ -50,12 +50,12 @@ jobs:
curr_ref_type: ${{ github.ref_type }}
linux-jammy-rocm-py3_10-build:
name: linux-jammy-rocm-py3.10
name: linux-jammy-rocm-py3.10-mi300
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
build-environment: linux-jammy-rocm-py3.10-mi300
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
test-matrix: |
{ include: [
@ -69,13 +69,13 @@ jobs:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
name: linux-jammy-rocm-py3.10-mi300
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
build-environment: linux-jammy-rocm-py3.10-mi300
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit

View File

@ -38,12 +38,12 @@ jobs:
linux-jammy-rocm-py3_10-build:
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
name: linux-jammy-rocm-py3.10
name: linux-jammy-rocm-py3.10-mi300
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
build-environment: linux-jammy-rocm-py3.10-mi300
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
@ -61,13 +61,13 @@ jobs:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
name: linux-jammy-rocm-py3.10-mi300
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
build-environment: linux-jammy-rocm-py3.10-mi300
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit

View File

@ -1160,12 +1160,6 @@ exclude_patterns = [
'torch/_inductor/autoheuristic/artifacts/**',
# These files are all grandfathered in, feel free to remove from this list
# as necessary
'test/_nvfuser/__init__.py',
'test/_nvfuser/test_dynamo.py',
'test/_nvfuser/test_python_frontend.py',
'test/_nvfuser/test_torchscript.py',
'test/delete.py',
'test/expect/__init__.py',
'test/quantization/__init__.py',
'test/quantization/core/__init__.py',
'test/quantization/core/experimental/apot_fx_graph_mode_ptq.py',
@ -1322,12 +1316,6 @@ exclude_patterns = [
'torch/_export/passes/const_prop_pass.py',
'torch/_export/passes/functionalize_side_effectful_ops_pass.py',
'torch/_export/passes/replace_sym_size_ops_pass.py',
'torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py',
'torch/_export/serde/__init__.py',
'torch/_export/serde/schema.py',
'torch/_export/serde/serialize.py',
'torch/_export/serde/upgrade.py',
'torch/_export/trace.py',
'torch/testing/_internal/__init__.py',
'torch/testing/_internal/autocast_test_lists.py',
'torch/testing/_internal/autograd_function_db.py',
@ -1444,7 +1432,6 @@ exclude_patterns = [
'torch/utils/throughput_benchmark.py',
'torch/utils/viz/__init__.py',
'torch/utils/viz/_cycles.py',
'torch/utils/weak.py',
]
init_command = [
'python3',

View File

@ -184,6 +184,12 @@ new_local_repository(
path = "third_party/nlohmann",
)
new_local_repository(
name = "moodycamel",
build_file = "//third_party:moodycamel.BUILD",
path = "third_party/concurrentqueue",
)
new_local_repository(
name = "tensorpipe",
build_file = "//third_party:tensorpipe.BUILD",

View File

@ -78,7 +78,7 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
return CUDA_R_64I;
case c10::ScalarType::BFloat16:
return CUDA_R_16BF;
#if defined(CUDA_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60300)
#if !defined(USE_ROCM) || ROCM_VERSION >= 60300
case c10::ScalarType::Float8_e4m3fn:
return CUDA_R_8F_E4M3;
case c10::ScalarType::Float8_e5m2:
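For context only (a sketch added by the editor, not part of the diff): the hunk above replaces the CUDA_VERSION check with a ROCm >= 6.3 gate, so the FP8 scalar types map to their cuBLAS/hipBLAS data types on both backends. On the Python side these correspond to:

import torch

# FP8 tensors whose scalar types go through ScalarTypeToCudaDataType above.
x = torch.randn(4, 4).to(torch.float8_e4m3fn)
y = torch.randn(4, 4).to(torch.float8_e5m2)
print(x.dtype, y.dtype)  # torch.float8_e4m3fn torch.float8_e5m2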

View File

@ -139,7 +139,7 @@ void CUDAGraph::capture_end() {
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
// cudaGraphInstantiateWithFlags
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233
#if (defined(CUDA_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60200))
#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
int version = 0;
AT_CUDA_CHECK(cudaDriverGetVersion(&version));
if (version < 11040) {
@ -154,7 +154,7 @@ void CUDAGraph::capture_end() {
#endif
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
#if (defined(CUDA_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60200))
#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
} else {
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
graph_,

View File

@ -135,6 +135,7 @@ CBLAS_TRANSPOSE to_apple_accelerate_transpose(TransposeType trans) {
} // namespace (anonymous)
DEFINE_DISPATCH(gemm_stub);
DEFINE_DISPATCH(gemm_no_downcast_stub);
void gemm(
TransposeType transa, TransposeType transb,
@ -452,18 +453,18 @@ void gemm(
// for the fallback path, first compute gemm with beta = 0,
// and then add c in full precision.
int64_t c_size = n * m;
std::vector<at::BFloat16> bfloat_c(c_size, 0.f);
gemm_stub(
std::vector<float> float_c(c_size, 0.f);
gemm_no_downcast_stub(
at::kCPU, at::kBFloat16,
transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, bfloat_c.data(), m);
transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m);
for (const auto j : c10::irange(n)) {
for (const auto i : c10::irange(m)) {
auto offset = j * ldc + i;
// beta == 0 won't propagate NaN from C
if (beta == 0.f) {
c[offset] = c10::convert<float>(bfloat_c[j * m + i]);
c[offset] = float_c[j * m + i];
} else {
c[offset] = beta * c[offset] + c10::convert<float>(bfloat_c[j * m + i]);
c[offset] = beta * c[offset] + float_c[j * m + i];
}
}
}
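A minimal numeric sketch of the fallback math in this hunk (an illustration of the accumulation change only, not the dispatch plumbing): the bf16 fallback now accumulates alpha * (a @ b) in float32 through gemm_no_downcast_stub and applies beta * c in full precision, instead of round-tripping the intermediate through bfloat16.

import torch

m, n, k = 4, 3, 5
alpha, beta = 1.0, 0.5
a = torch.randn(m, k, dtype=torch.bfloat16)
b = torch.randn(k, n, dtype=torch.bfloat16)
c = torch.randn(m, n, dtype=torch.float32)

prod = alpha * (a.float() @ b.float())   # float32 accumulation, no bf16 downcast
# beta == 0 must not propagate NaNs from c, hence the branch in the kernel above.
c = prod if beta == 0.0 else beta * c + prod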

View File

@ -29,6 +29,18 @@ using gemm_fn = void(*)(
DECLARE_DISPATCH(gemm_fn, gemm_stub)
using gemm_no_downcast_fn = void(*)(
at::ScalarType type,
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const Scalar& alpha,
const void *a, int64_t lda,
const void *b, int64_t ldb,
const Scalar& beta,
void *c, int64_t ldc);
DECLARE_DISPATCH(gemm_no_downcast_fn, gemm_no_downcast_stub)
template <typename scalar_t>
void gemm(
TransposeType transa, TransposeType transb,

View File

@ -24,6 +24,7 @@
#include <ATen/native/cpu/SerialStackImpl.h>
#include <ATen/native/cpu/StackKernel.h>
#include <ATen/quantized/QTensorImpl.h>
#include <c10/core/Contiguity.h>
#include <c10/core/GradMode.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
@ -1993,11 +1994,15 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
TORCH_CHECK(false, "reshape is not implemented for sparse tensors");
}
if (self.is_contiguous() && !self.is_mkldnn()) {
auto sym_sizes = self.sym_sizes();
auto sym_strides = self.sym_strides();
auto sym_numel = self.sym_numel();
if (definitely_contiguous(sym_sizes, sym_strides, sym_numel) &&
!self.is_mkldnn()) {
return self.view_symint(proposed_shape);
}
c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel());
c10::SymDimVector shape = infer_size_dv(proposed_shape, sym_numel);
if (self.is_mkldnn()) {
return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape));
@ -2005,8 +2010,7 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
// `computeStride` returns the proper strides to use if this
// `reshape` can be just a view.
auto stride =
at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape);
auto stride = at::detail::computeStride(sym_sizes, sym_strides, shape);
// NB: Even though we have viewable geometry and the target strides here,
// we do not just call `as_strided` on `self` because the backward

View File

@ -99,7 +99,7 @@ auto sum(int64_t N, Func f) {
return partial_sums[0];
}
template <typename scalar_t, typename opmath_t>
template <typename scalar_t, typename opmath_t, typename out_t>
std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
gemm_notrans_(
int64_t m,
@ -111,7 +111,7 @@ gemm_notrans_(
const scalar_t* b,
int64_t ldb,
opmath_t beta,
scalar_t* c,
out_t* c,
int64_t ldc) {
// c *= beta
scale_(m, n, beta, c, ldc);
@ -135,7 +135,7 @@ gemm_notrans_(
}
// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
template <typename scalar_t, typename opmath_t>
template <typename scalar_t, typename opmath_t, typename out_t>
std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
gemm_notrans_(
int64_t m,
@ -147,7 +147,7 @@ gemm_notrans_(
const scalar_t* b,
int64_t ldb,
opmath_t beta,
scalar_t* c,
out_t* c,
int64_t ldc) {
// c += alpha * (a @ b)
for (const auto i : c10::irange(m)) {
@ -165,7 +165,7 @@ gemm_notrans_(
}
}
template <typename scalar_t, typename opmath_t>
template <typename scalar_t, typename opmath_t, typename out_t>
void gemm_transa_(
TransposeType transa,
int64_t m, int64_t n, int64_t k,
@ -173,7 +173,7 @@ void gemm_transa_(
const scalar_t *a, int64_t lda,
const scalar_t *b, int64_t ldb,
opmath_t beta,
scalar_t *c, int64_t ldc) {
out_t *c, int64_t ldc) {
// c = alpha * (a.T @ b) + beta * c
const scalar_t *a_ = a;
for (const auto i : c10::irange(m)) {
@ -225,6 +225,7 @@ void gemm_transb_impl(
}
}
// in this case, scalar_t == opmath_t == out_t so out_t template param is not needed
template <typename scalar_t, typename opmath_t>
std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
gemm_transb_(
@ -247,7 +248,7 @@ gemm_transb_(
}
// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
template <typename scalar_t, typename opmath_t>
template <typename scalar_t, typename opmath_t, typename out_t>
std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
gemm_transb_(
TransposeType transb,
@ -260,7 +261,7 @@ gemm_transb_(
const scalar_t* b,
int64_t ldb,
opmath_t beta,
scalar_t* c,
out_t* c,
int64_t ldc) {
// We need to calculate full-precision dot products for correctness;
// users notice error accumulation with reduced-width types (e.g.,
@ -304,7 +305,7 @@ gemm_transb_(
}
}
template <typename scalar_t, typename opmath_t>
template <typename scalar_t, typename opmath_t, typename out_t>
void gemm_transab_(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
@ -312,7 +313,7 @@ void gemm_transab_(
const scalar_t *a, int64_t lda,
const scalar_t *b, int64_t ldb,
opmath_t beta,
scalar_t *c, int64_t ldc) {
out_t *c, int64_t ldc) {
// c = beta * c + alpha * (a.T @ b.T)
for (const auto i : c10::irange(m)) {
for (const auto j : c10::irange(n)) {
@ -436,7 +437,7 @@ void gemm_transa_(
}
#endif // !defined(C10_MOBILE)
template <typename scalar_t, typename opmath_t>
template <typename scalar_t, typename opmath_t, typename out_t>
void gemm_core_(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
@ -444,7 +445,7 @@ void gemm_core_(
const scalar_t *a, int64_t lda,
const scalar_t *b, int64_t ldb,
opmath_t beta,
scalar_t *c, int64_t ldc) {
out_t *c, int64_t ldc) {
if (transa == TransposeType::NoTranspose &&
transb == TransposeType::NoTranspose) {
return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
@ -493,6 +494,27 @@ void cpublas_gemm_impl(
});
}
void cpublas_gemm_no_downcast_impl(
at::ScalarType type,
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const Scalar& alpha,
const void *a, int64_t lda,
const void *b, int64_t ldb,
const Scalar& beta,
void *c, int64_t ldc) {
_AT_DISPATCH_GEMM_TYPES(type, "cpublas_gemm_no_downcast_impl", [&]{
using opmath_t = at::opmath_type<scalar_t>;
gemm_core_(
transa, transb, m, n, k,
alpha.to<opmath_t>(),
static_cast<const scalar_t *>(a), lda,
static_cast<const scalar_t *>(b), ldb,
beta.to<opmath_t>(),
static_cast<opmath_t *>(c), ldc);
});
}
void cpublas_axpy_impl(at::ScalarType type, int64_t n, const Scalar& _a, const void *_x, int64_t incx, void *_y, int64_t incy){
if (type == at::kBool) {
auto a = _a.to<bool>();
@ -530,6 +552,7 @@ void cpublas_copy_impl(at::ScalarType type, int64_t n, const void *_x, int64_t i
REGISTER_DISPATCH(cpublas::gemm_stub, &cpublas::cpublas_gemm_impl)
REGISTER_DISPATCH(cpublas::gemm_no_downcast_stub, &cpublas::cpublas_gemm_no_downcast_impl)
REGISTER_DISPATCH(cpublas::axpy_stub, &cpublas::cpublas_axpy_impl)
REGISTER_DISPATCH(cpublas::copy_stub, &cpublas::cpublas_copy_impl)

View File

@ -3,6 +3,7 @@
#include <ATen/core/ATen_fwd.h>
#include <ATen/core/interned_strings.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/mkldnn/xpu/FusionUtils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/ops/full.h>
@ -309,81 +310,6 @@ static at::Tensor view3d(const at::Tensor& tensor) {
return tensor.squeeze(2);
}
Attr get_onednn_conv_sum_attr(
const Tensor& input_r,
const Tensor& weight_r,
IntArrayRef stride_,
IntArrayRef padding_,
IntArrayRef dilation_,
Tensor& accumu,
double scale,
Tensor& output,
bool& is_fused,
Attr attr = Attr(),
bool force_inplace = false) {
is_fused = true;
if (scale == 0.f)
return attr;
auto ndim = input_r.ndimension();
auto output_size = conv_dst_size(
ndim,
input_r.sizes(),
weight_r.sizes(),
padding_,
padding_,
stride_,
dilation_);
MemoryFormat mem_fmt = at::MemoryFormat::Contiguous;
auto input_fmt = input_r.suggest_memory_format();
auto input_is_cl =
(input_fmt == at::MemoryFormat::ChannelsLast ||
input_fmt == at::MemoryFormat::ChannelsLast3d);
auto weight_fmt = weight_r.suggest_memory_format();
auto weight_is_cl =
(weight_fmt == at::MemoryFormat::ChannelsLast ||
weight_fmt == at::MemoryFormat::ChannelsLast3d);
bool propagate_channels_last = input_is_cl || weight_is_cl;
if (propagate_channels_last)
mem_fmt = get_cl_tag_by_ndim(ndim);
Tensor out = at::empty(output_size, input_r.options().memory_format(mem_fmt));
if (!onednn::binary_valid(out, accumu)) {
is_fused = false;
return attr;
}
// For post-sum and post-binary-add, onednn needs sum/binary scale=1.f
// Thus we need the following transformation
// conv(src, wei) + scale * accumu
// scale * (1/scale * conv(src, wei) + sum (or binary))
if (scale != 1.f)
attr.append_post_eltwise(
/* scale */ 1.f,
/* alpha */ 1.f / scale,
/* beta */ 0.f,
attr.kind_with_linear);
if (force_inplace) {
// If sizes are the same, post sum is used.
output = accumu;
attr.append_post_sum(/* sum_scale */ 1.f);
} else {
// If sizes are different, post binary is used.
attr.append_post_binary(attr.kind_with_binary_add, accumu);
}
if (scale != 1.f)
attr.append_post_eltwise(
/* scale */ 1.f,
/* alpha */ scale,
/* beta */ 0.f,
attr.kind_with_linear);
return attr;
}
} // namespace impl
using namespace impl;
@ -476,6 +402,8 @@ Tensor _convolution_out(
params.output_padding,
params.groups);
output = at::empty(dst_tz, input.options(), mfmt);
} else {
output = output_r;
}
onednn::deconvolution(
@ -518,6 +446,8 @@ Tensor _convolution_out(
params.stride,
params.dilation);
output = at::empty(dst_tz, input.options(), mfmt);
} else {
output = output_r;
}
onednn::convolution(
output,
@ -751,6 +681,119 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}
Tensor convolution_pointwise(
const Tensor& input_t,
const Tensor& weight_t,
const std::optional<Tensor>& bias_opt,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
std::string_view attr,
torch::List<std::optional<at::Scalar>> scalars,
std::optional<std::string_view> algorithm) {
c10::DeviceGuard device_guard(input_t.device());
Attr att;
att = construct_unary_attr(att, attr, scalars, algorithm);
const Tensor bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
return _convolution(
input_t,
weight_t,
bias,
stride,
padding,
dilation,
/*transposed*/ false,
/*output_padding*/ {0},
groups,
att);
}
Tensor convolution_pointwise_binary(
const Tensor& input_t,
const Tensor& other_t,
const Tensor& weight_t,
const std::optional<Tensor>& bias_opt,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
std::string_view binary_attr,
std::optional<at::Scalar> alpha,
std::optional<std::string_view> unary_attr,
torch::List<std::optional<at::Scalar>> unary_scalars,
std::optional<std::string_view> unary_algorithm) {
c10::DeviceGuard device_guard(input_t.device());
Tensor output;
Tensor bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
// Step1: Construct binary attr
Attr attr;
attr = construct_binary_attr(attr, binary_attr, other_t);
// Step2: Append unary attr
if (unary_attr.has_value())
attr = construct_unary_attr(
attr, unary_attr.value(), unary_scalars, unary_algorithm);
Tensor res = _convolution_out(
output,
input_t,
weight_t,
bias,
stride,
padding,
dilation,
/*transposed*/ false,
/*output_padding*/ {0},
groups,
attr);
// Step3: Run conv
return res;
}
Tensor& convolution_pointwise_binary_(
Tensor& other_t,
const Tensor& input_t,
const Tensor& weight_t,
const std::optional<Tensor>& bias_opt,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
std::string_view binary_attr,
std::optional<at::Scalar> alpha,
std::optional<std::string_view> unary_attr,
torch::List<std::optional<at::Scalar>> unary_scalars,
std::optional<std::string_view> unary_algorithm) {
c10::DeviceGuard device_guard(input_t.device());
Tensor bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
// Step1: Construct binary attr
Attr attr;
attr = construct_binary_attr(attr, binary_attr, other_t);
// Step2: Append unary attr
if (unary_attr.has_value())
attr = construct_unary_attr(
attr, unary_attr.value(), unary_scalars, unary_algorithm);
_convolution_out(
other_t,
input_t,
weight_t,
bias,
stride,
padding,
dilation,
/*transposed*/ false,
/*output_padding*/ {0},
groups,
attr);
// Step3: Run conv
return other_t;
}
TORCH_LIBRARY_IMPL(aten, XPU, m) {
m.impl("convolution_overrideable", TORCH_FN(convolution_overrideable));
m.impl(
@ -758,4 +801,16 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
TORCH_FN(convolution_backward_overrideable));
}
TORCH_LIBRARY_IMPL(mkldnn, XPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise"),
TORCH_FN(convolution_pointwise));
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise.binary"),
TORCH_FN(convolution_pointwise_binary));
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise_.binary"),
TORCH_FN(convolution_pointwise_binary_));
}
} // namespace at::native::xpu

View File

@ -4981,7 +4981,7 @@
device_check: NoCheck
device_guard: False
dispatch:
CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, ZeroTensor, MPS: _reshape_alias
CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, ZeroTensor, MPS, MTIA: _reshape_alias
# We don't need to support mkldnn since this is handled explicitly by the reshape operator.
- func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor
@ -10236,7 +10236,7 @@
device_check: NoCheck
device_guard: False
dispatch:
CPU, CUDA, Meta, MPS: unfold
CPU, CUDA, Meta, MPS, MTIA: unfold
QuantizedCPU, QuantizedCUDA: unfold
- func: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor
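For reference, a short Python illustration (run on CPU, purely as a sketch) of the unfold op whose dispatch list gains the MTIA key above; the semantics are unchanged, only the dispatch coverage grows:

import torch

x = torch.arange(8.)
windows = x.unfold(0, 3, 2)   # (dimension, size, step)
# tensor([[0., 1., 2.],
#         [2., 3., 4.],
#         [4., 5., 6.]])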

View File

@ -356,13 +356,14 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values_,
computed_sizes[static_cast<size_t>(sparse_dim + d)] = values.size(d + 1);
}
return at::_sparse_coo_tensor_with_dims_and_tensors(
sparse_dim,
dense_dim,
computed_sizes,
return at::native::_sparse_coo_tensor_unsafe(
indices,
values,
values.options().layout(kSparse),
computed_sizes,
optTypeMetaToScalarType(options.dtype_opt()),
options.layout_opt(),
options.device_opt(),
options.pinned_memory_opt(),
is_coalesced);
}
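The hunk reroutes size-inferred construction through _sparse_coo_tensor_unsafe. A hedged Python sketch of the user-visible path (the public entry point is torch.sparse_coo_tensor; the internal routing above is an implementation detail):

import torch

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3.0, 4.0, 5.0])
s = torch.sparse_coo_tensor(i, v)   # size (2, 3) inferred from the max index per dim
print(s.to_dense())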

View File

@ -46,6 +46,7 @@
#include <ATen/ops/_triton_multi_head_attention_native.h>
#include <ATen/ops/_triton_scaled_dot_attention.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/linear.h>
#include <ATen/ops/narrow_native.h>
@ -963,33 +964,98 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
std::optional<double> scale) {
// Used for tracking usage statistics
C10_LOG_API_USAGE_ONCE("torch.sdpa.mem_efficient_attention");
// Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head)
// Key -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head)
// Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head)
Tensor q_t = query.transpose(1, 2);
Tensor k_t = key.transpose(1, 2);
Tensor v_t = value.transpose(1, 2);
constexpr int64_t MAX_BATCH_SIZE = (1LL << 16) - 1;
int64_t batch_size = query.size(0);
sdp::CustomMaskType custom_mask_type = is_causal
? sdp::CustomMaskType::CausalFromTopLeft
: sdp::CustomMaskType::NoCustomMask;
if (batch_size > MAX_BATCH_SIZE) {
TORCH_CHECK(!compute_log_sumexp && (dropout_p == 0.0),
"Efficient attention cannot produce valid seed, logsumexp and offset outputs when "
"the batch size exceeds (", MAX_BATCH_SIZE, ").");
}
auto process_chunk = [&](const Tensor& q_chunk,
const Tensor& k_chunk,
const Tensor& v_chunk,
const std::optional<Tensor>& bias_chunk)
-> std::tuple<Tensor, Tensor, Tensor, Tensor> {
Tensor q_t = q_chunk.transpose(1, 2);
Tensor k_t = k_chunk.transpose(1, 2);
Tensor v_t = v_chunk.transpose(1, 2);
auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
q_t,
k_t,
v_t,
attn_bias,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
dropout_p,
static_cast<int64_t>(custom_mask_type),
compute_log_sumexp,
scale);
sdp::CustomMaskType custom_mask_type = is_causal
? sdp::CustomMaskType::CausalFromTopLeft
: sdp::CustomMaskType::NoCustomMask;
attention = attention.transpose(1, 2);
return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset));
auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] =
at::_efficient_attention_forward(
q_t,
k_t,
v_t,
bias_chunk,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
dropout_p,
static_cast<int64_t>(custom_mask_type),
compute_log_sumexp,
scale);
attention = attention.transpose(1, 2);
return std::make_tuple(std::move(attention),
std::move(log_sumexp),
std::move(seed),
std::move(offset));
};
// when bs is larger than allowed maximum, process in chunks
if (batch_size > MAX_BATCH_SIZE) {
int64_t start = 0;
int64_t end = std::min(start + MAX_BATCH_SIZE, batch_size);
Tensor query_chunk = query.slice(0, start, end);
Tensor key_chunk = key.slice(0, start, end);
Tensor value_chunk = value.slice(0, start, end);
std::optional<Tensor> bias_chunk;
if (attn_bias.has_value()) {
bias_chunk = attn_bias.value().slice(0, start, end);
}
auto [attn, log_sumexp, seed, offset] =
process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
int dim = attn.dim();
std::vector<int64_t> sizes;
sizes.reserve(dim);
sizes.push_back(batch_size);
for (int i = 1; i < dim; i++) {
sizes.push_back(attn.size(i));
}
Tensor final_attention = at::empty_strided(sizes, attn.strides(), attn.options());
final_attention.slice(0, start, end).copy_(attn);
for (start = end; start < batch_size; start += MAX_BATCH_SIZE) {
end = std::min(start + MAX_BATCH_SIZE, batch_size);
query_chunk = query.slice(0, start, end);
key_chunk = key.slice(0, start, end);
value_chunk = value.slice(0, start, end);
if (attn_bias.has_value()) {
bias_chunk = attn_bias.value().slice(0, start, end);
} else {
bias_chunk.reset();
}
auto [chunk_attn, chunk_log_sumexp, chunk_seed, chunk_offset] =
process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
final_attention.slice(0, start, end).copy_(chunk_attn);
}
return std::make_tuple(std::move(final_attention),
std::move(log_sumexp),
std::move(seed),
std::move(offset));
}
// when bs is within the allowed size, no need to chunk it
else {
return process_chunk(query, key, value, attn_bias);
}
}
int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
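A minimal Python sketch of the chunking pattern this hunk implements (illustration only: the kernel calls _efficient_attention_forward; scaled_dot_product_attention stands in here, and q/k/v are assumed to share head dims so empty_like(q) is a valid output buffer):

import torch
import torch.nn.functional as F

MAX_BATCH_SIZE = (1 << 16) - 1  # same limit as the kernel above

def chunked_attention(q, k, v):
    bs = q.size(0)
    if bs <= MAX_BATCH_SIZE:
        return F.scaled_dot_product_attention(q, k, v)
    out = torch.empty_like(q)  # preallocate once, copy each chunk's result in
    for start in range(0, bs, MAX_BATCH_SIZE):
        end = min(start + MAX_BATCH_SIZE, bs)
        out[start:end].copy_(
            F.scaled_dot_product_attention(q[start:end], k[start:end], v[start:end])
        )
    return out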

View File

@ -2,7 +2,7 @@ add_loop_eager,compile_time_instruction_count,2953000000,0.015
add_loop_eager_dynamic,compile_time_instruction_count,5808000000,0.025
add_loop_eager_dynamic,compile_time_instruction_count,5738000000,0.025
@ -10,7 +10,7 @@ add_loop_inductor,compile_time_instruction_count,29370000000,0.015
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44010000000,0.025
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44490000000,0.025
@ -22,11 +22,11 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,939900000,0.015
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18140000000,0.015
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18270000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16220000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16310000000,0.015
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000
update_hint_regression,compile_time_instruction_count,1681000000,0.02
update_hint_regression,compile_time_instruction_count,1700000000,0.02
float_args,compile_time_instruction_count,449800000,0.015
float_args,compile_time_instruction_count,452500000,0.015
@ -54,24 +54,24 @@ symint_sum_loop,compile_time_instruction_count,4262000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2091000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2112000000,0.015
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5981000000,0.015
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6022000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8585000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8672000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1900000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1917000000,0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3818000000,0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3859000000,0.015
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10350000000,0.015
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10420000000,0.015


View File

@ -178,6 +178,7 @@ THIRD_PARTY_LIBS = {
"psimd": ["//xplat/third-party/psimd:psimd", "//third_party:psimd"],
"pthreadpool": ["//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
"pthreadpool_header": ["//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
"moodycamel": ["//third-party/moodycamel:moodycamel", "//third_party:moodycamel"],
"pyyaml": ["//third-party/pypi/pyyaml:pyyaml", "//third_party:pyyaml"],
"rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"],
"ruy": ["//third-party/ruy:ruy_xplat_lib", "//third_party:ruy_lib"],

View File

@ -15,6 +15,7 @@ cxx_library(
"//third_party:cpuinfo",
"//third_party:fmt",
"//third_party:glog",
"//third_party:moodycamel",
],
exported_deps = [],
compiler_flags = [

View File

@ -96,6 +96,7 @@ if(NOT BUILD_LIBTORCHLESS)
endif()
target_link_libraries(c10 PRIVATE fmt::fmt-header-only)
target_link_libraries(c10 PRIVATE nlohmann)
target_link_libraries(c10 PRIVATE moodycamel)
if(C10_USE_NUMA)
message(STATUS "NUMA paths:")

View File

@ -12,24 +12,49 @@ namespace c10 {
template <typename T>
bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
bool is_contiguous = true;
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) {
return is_contiguous;
return true;
}
T z = 1;
T expected_stride = 1;
// NB: make sure we do signed arithmetic
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
const auto& size_d = sizes[d];
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(strides[d], z))) {
z *= size_d;
} else {
is_contiguous = false;
break;
}
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(size_d, 1))) {
continue;
}
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected_stride))) {
return false;
}
expected_stride *= size_d;
}
return is_contiguous;
return true;
}
// This function returns true if the tensor is contiguous, and false if it's
// not or if we can't determine whether it is contiguous due to unbacked symbols
// (it could be either in that case, based on the actual runtime data).
template <typename T>
bool definitely_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) {
return true;
}
T expected_stride = 1;
// NB: make sure we do signed arithmetic
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
const auto& size_d = sizes[d];
if (TORCH_GUARD_OR_FALSE(sym_eq(size_d, 1))) {
continue;
}
if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride))) {
return false;
}
expected_stride *= size_d;
}
return true;
}
template <typename T>
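A plain-Python sketch of the definitely_contiguous logic added above (illustration only; maybe_eq and maybe_ne are hypothetical helpers standing in for the symbolic comparisons, returning True, False, or None when the answer cannot be decided, e.g. for unbacked symbols). TORCH_GUARD_OR_FALSE treats an undecidable comparison as False and TORCH_GUARD_OR_TRUE treats it as True, so any stride that cannot be proven equal to the expected value makes the answer "not definitely contiguous":

def definitely_contiguous(sizes, strides, numel, maybe_eq, maybe_ne):
    if maybe_eq(numel, 0) is True:                # GUARD_OR_FALSE: None -> False
        return True
    expected_stride = 1
    for size_d, stride_d in zip(reversed(sizes), reversed(strides)):
        if maybe_eq(size_d, 1) is True:           # size-1 dims never break contiguity
            continue
        if maybe_ne(stride_d, expected_stride) in (True, None):  # GUARD_OR_TRUE
            return False
        expected_stride *= size_d
    return True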

View File

@ -188,18 +188,21 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
return guard_bool(file, line);
}
virtual bool guard_or_false(const char* file, int64_t line) {
// No improvement for unbacked SymBools by default, replace this
// with a better implementation!
// Note: PT2 primarily uses PythonSymNodeImpl for this functionality.
// XLA is currently the main consumer of this fallback path since it uses
// ahead-of-time compilation and cannot depend on Python runtime.
return guard_bool(file, line);
}
virtual bool statically_known_true(const char* file, int64_t line) {
// No improvement for unbacked SymBools by default, replace this
// with a better implementation!
// Note: PT2 primarily uses PythonSymNodeImpl for this functionality.
// XLA is currently the main consumer of this fallback path since it uses
// ahead-of-time compilation and cannot depend on Python runtime.
return guard_bool(file, line);
}
virtual bool guard_or_true(const char* file, int64_t line) {
// No improvement for unbacked SymBools by default, replace this
// with a better implementation!
// Note: PT2 primarily uses PythonSymNodeImpl for this functionality.
// XLA is currently the main consumer of this fallback path since it uses
// ahead-of-time compilation and cannot depend on Python runtime.
return guard_bool(file, line);
}
virtual bool expect_true(const char* file, int64_t line) {

View File

@ -833,8 +833,9 @@ class EventPool {
// CUDA graphs helper
struct PrivatePool {
PrivatePool(MempoolId_t id)
PrivatePool(MempoolId_t id, CUDAAllocator* allocator = nullptr)
: id(std::move(id)),
allocator_(allocator),
large_blocks(/*small=*/false, this),
small_blocks(/*small=*/true, this) {}
PrivatePool(const PrivatePool&) = delete;
@ -855,8 +856,14 @@ struct PrivatePool {
// distinguish private blocks by adding a "pool id" check above the stream
// check in BlockComparator. BlockComparator is performance- critical though,
// I'd rather not add more logic to it.
CUDAAllocator* allocator_;
BlockPool large_blocks;
BlockPool small_blocks;
public:
CUDAAllocator* allocator() {
return allocator_;
}
};
MempoolId_t BlockPool::owner_MempoolId() const {
@ -905,9 +912,8 @@ struct MempoolIdHash {
};
cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) {
auto active_pool = MemPoolContext::getActiveMemPool();
if (active_pool && active_pool->allocator() && p.pool->owner_PrivatePool) {
*ptr = active_pool->allocator()->raw_alloc(size);
if (p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator()) {
*ptr = p.pool->owner_PrivatePool->allocator()->raw_alloc(size);
return *ptr ? cudaSuccess : cudaErrorMemoryAllocation;
} else {
return C10_CUDA_ERROR_HANDLED(cudaMalloc(ptr, size));
@ -1277,14 +1283,14 @@ class DeviceCachingAllocator {
alloc_block(params, false, context, lock))
// Free all non-split cached blocks and retry alloc.
|| (C10_LIKELY(captures_underway.empty()) &&
release_cached_blocks(context) &&
release_cached_blocks(context, {0, 0}) &&
alloc_block(params, true, context, lock));
}
// we are about to oom, try to use existing mempools as a last resort
if (!block_found && params.err == cudaErrorMemoryAllocation) {
// if already trying to use a mempool, then just oom
auto active_pool = MemPoolContext::getActiveMemPool();
bool active_pool = params.pool->owner_PrivatePool;
if (!active_pool) {
for (MempoolId_t mempool_id : use_on_oom_pools) {
auto tid = std::this_thread::get_id();
@ -1671,10 +1677,10 @@ class DeviceCachingAllocator {
}
/** returns cached blocks to the system allocator **/
void emptyCache() {
void emptyCache(MempoolId_t mempool_id) {
auto context = maybeGatherContext(RecordContext::ALL);
std::lock_guard<std::recursive_mutex> lock(mutex);
release_cached_blocks(context);
release_cached_blocks(context, mempool_id);
}
/** Retrieves size of largest unused block held by the memory cache **/
@ -1992,16 +1998,10 @@ class DeviceCachingAllocator {
/** Dump a complete snapshot of the memory held by the allocator. Potentially
* VERY expensive. **/
std::vector<SegmentInfo> snapshot() {
std::vector<SegmentInfo> snapshot(MempoolId_t mempool_id) {
std::lock_guard<std::recursive_mutex> lock(mutex);
std::vector<Block*> all_blocks;
MempoolId_t mempool_id = {0, 0};
auto active_mempool = MemPoolContext::getActiveMemPool();
if (active_mempool) {
mempool_id = active_mempool->id();
}
if (mempool_id.first != 0 || mempool_id.second != 0) {
// If there is an active mempool, we find the corresponding PrivatePool
@ -2011,7 +2011,7 @@ class DeviceCachingAllocator {
all_blocks = get_private_pool_head_blocks(pool->second.get());
}
} else {
// When snapshot is called outside a MemPoolContext, we return
// When snapshot is called with the default mempool_id, we return
// all the blocks in the CUDACachingAllocator (as returned by
// get_all_blocks).
all_blocks = get_all_blocks();
@ -2130,11 +2130,11 @@ class DeviceCachingAllocator {
}
}
void ensureExistsAndIncrefPool(MempoolId_t mempool_id) {
void createOrIncrefPool(MempoolId_t mempool_id, CUDAAllocator* allocator) {
// Create a PrivatePool object if it does not exist yet
// and increment its use_count
std::lock_guard<std::recursive_mutex> lock(mutex);
ensure_exists_and_incref_pool(mempool_id);
create_or_incref_pool(mempool_id, allocator);
}
void setUseOnOOM(MempoolId_t mempool_id) {
@ -2150,7 +2150,7 @@ class DeviceCachingAllocator {
MempoolId_t mempool_id,
std::function<bool(cudaStream_t)> filter) {
std::lock_guard<std::recursive_mutex> lock(mutex);
ensure_exists_and_incref_pool(mempool_id);
create_or_incref_pool(mempool_id);
for (auto it2 = captures_underway.begin(); it2 != captures_underway.end();
++it2) {
TORCH_CHECK(
@ -2272,21 +2272,24 @@ class DeviceCachingAllocator {
return blocks;
}
void ensure_exists_and_incref_pool(MempoolId_t mempool_id) {
void create_or_incref_pool(
MempoolId_t mempool_id,
CUDAAllocator* allocator = nullptr) {
auto it = graph_pools.find(mempool_id);
if (it == graph_pools.end()) {
// mempool_id does not reference an existing pool.
// Make a new pool for CUDAGraph capture or torch.cuda.use_mem_pool
// usage. use_count is initially 1, which means the pool is
// being used since somebody called ensureExistsAndIncrefPool.
// being used since somebody called createOrIncrefPool.
graph_pools.emplace(
mempool_id, std::make_unique<PrivatePool>(mempool_id));
mempool_id, std::make_unique<PrivatePool>(mempool_id, allocator));
} else {
// mempool_id references an existing pool, which the current CUDAGraph
// capture or torch.cuda.use_mem_pool will
// share. Check this pool is live (at least one other capture already
// references it). Increment it to establish the usage.
TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
TORCH_INTERNAL_ASSERT(allocator == nullptr);
it->second->use_count++;
}
}
@ -2776,7 +2779,8 @@ class DeviceCachingAllocator {
bool in_fbcode = false;
#endif
auto active_pool = MemPoolContext::getActiveMemPool();
bool active_pool =
p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator();
if (set_fraction &&
total_allocated_memory + size > allowed_memory_maximum) {
p.err = cudaErrorMemoryAllocation;
@ -2801,12 +2805,6 @@ class DeviceCachingAllocator {
}
return bool(p.block);
} else {
if (active_pool && active_pool->allocator() &&
p.pool->owner_PrivatePool) {
// Ensure that active_pool and p.pool are the same
auto pp = get_private_pool(active_pool->id());
TORCH_INTERNAL_ASSERT(pp == p.pool->owner_PrivatePool);
}
if (CUDAAllocatorConfig::release_lock_on_cudamalloc()) {
// At scope exit, acquire the lock again. This provides safety against
// any potential exceptions in the cudaMallocMaybeCapturing function.
@ -2926,13 +2924,9 @@ class DeviceCachingAllocator {
return true;
}
bool release_cached_blocks(const std::shared_ptr<GatheredContext>& context) {
MempoolId_t mempool_id = {0, 0};
auto active_mempool = MemPoolContext::getActiveMemPool();
if (active_mempool) {
mempool_id = active_mempool->id();
}
bool release_cached_blocks(
const std::shared_ptr<GatheredContext>& context,
MempoolId_t mempool_id) {
if (mempool_id.first == 0 && mempool_id.second == 0) {
// If there is no active mempool, we work on releasing *all* blocks.
@ -3005,15 +2999,10 @@ class DeviceCachingAllocator {
context ? context : block->context_when_segment_allocated);
auto* pool = block->pool;
auto active_pool = MemPoolContext::getActiveMemPool();
if (active_pool && active_pool->allocator() && pool->owner_PrivatePool) {
// Ensure that active_pool and pool are the same
auto pp = get_private_pool(active_pool->id());
TORCH_INTERNAL_ASSERT(pp == pool->owner_PrivatePool);
if (pool->owner_PrivatePool && pool->owner_PrivatePool->allocator()) {
// If the block belongs to a private pool with a user-provided allocator,
// we use that allocator's delete function.
active_pool->allocator()->raw_delete((void*)block->ptr);
pool->owner_PrivatePool->allocator()->raw_delete((void*)block->ptr);
} else {
C10_CUDA_CHECK(cudaFree((void*)block->ptr));
}
@ -3589,9 +3578,9 @@ class NativeCachingAllocator : public CUDAAllocator {
}
}
void emptyCache() override {
void emptyCache(MempoolId_t mempool_id) override {
for (auto& da : device_allocator)
da->emptyCache();
da->emptyCache(mempool_id);
}
void enable(bool value) override {
@ -3639,7 +3628,7 @@ class NativeCachingAllocator : public CUDAAllocator {
device_allocator[block->device]->recordStream(block, stream);
}
SnapshotInfo snapshot() override {
SnapshotInfo snapshot(MempoolId_t mempool_id) override {
// Set-up converter to convert timestamps from tsc to microseconds.
auto tsc_to_ns = clock_converter.makeConverter();
auto tsc_to_us = [=](approx_time_t t_approx) {
@ -3657,7 +3646,7 @@ class NativeCachingAllocator : public CUDAAllocator {
// Get the device_traces' TraceEntry lists.
for (auto& da : device_allocator) {
result.device_traces.emplace_back(da->trace(tsc_to_us));
auto snap = da->snapshot();
auto snap = da->snapshot(mempool_id);
result.segments.insert(result.segments.end(), snap.begin(), snap.end());
}
@ -3785,11 +3774,13 @@ class NativeCachingAllocator : public CUDAAllocator {
device_allocator[device]->resetPeakStats();
}
void ensureExistsAndIncrefPool(
void createOrIncrefPool(
c10::DeviceIndex device,
MempoolId_t mempool_id) override {
MempoolId_t mempool_id,
CUDAAllocator* allocator) override {
assertValidDevice(device);
device_allocator[device]->ensureExistsAndIncrefPool(std::move(mempool_id));
device_allocator[device]->createOrIncrefPool(
std::move(mempool_id), allocator);
}
void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
@ -4134,7 +4125,7 @@ MemPool::MemPool(
id_ = {uuid_++, 0};
}
device_ = c10::cuda::current_device();
CUDACachingAllocator::ensureExistsAndIncrefPool(device_, id_);
CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
if (use_on_oom) {
CUDACachingAllocator::setUseOnOOM(device_, id_);
}
@ -4143,8 +4134,7 @@ MemPool::MemPool(
MemPool::~MemPool() {
TORCH_INTERNAL_ASSERT(use_count() == 1);
CUDACachingAllocator::releasePool(device_, id_);
auto ctx = MemPoolContext(this);
c10::cuda::CUDACachingAllocator::emptyCache();
c10::cuda::CUDACachingAllocator::emptyCache(id_);
}
MempoolId_t MemPool::id() {
@ -4170,23 +4160,4 @@ MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
return {uuid_++, 0};
}
// Note that active_mempool_ is a global variable here
// and not inside MemPoolContext class, because in windows we
// can't use __declspec(dllexport) and __declspec(thread)
// together: https://stackoverflow.com/a/50967977
static thread_local MemPool* active_mempool_ = nullptr;
MemPoolContext::MemPoolContext(MemPool* mempool)
: prev_mempool_(active_mempool_) {
active_mempool_ = mempool;
}
MemPoolContext::~MemPoolContext() {
active_mempool_ = prev_mempool_;
}
MemPool* MemPoolContext::getActiveMemPool() {
return active_mempool_;
}
} // namespace c10::cuda
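A hedged usage sketch of the reworked pool ownership (assumptions: the Python bindings torch.cuda.MemPool and torch.cuda.use_mem_pool route to createOrIncrefPool and the pool-scoped emptyCache shown above; exact Python signatures may differ from this sketch). The key change is that the pool, not a thread-local MemPoolContext, now carries its allocator, and releasing cached blocks is scoped by mempool_id:

import torch

pool = torch.cuda.MemPool()            # creates (or increfs) a private pool
with torch.cuda.use_mem_pool(pool):    # allocations in this scope land in the pool
    x = torch.empty(1024, device="cuda")

del x
# The MemPool destructor now calls emptyCache(id_) for just this pool,
# instead of activating a MemPoolContext and emptying the global cache.
del pool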

View File

@ -211,7 +211,7 @@ class CUDAAllocator : public Allocator {
virtual bool initialized() = 0;
virtual double getMemoryFraction(c10::DeviceIndex device) = 0;
virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0;
virtual void emptyCache() = 0;
virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0;
virtual void enable(bool value) = 0;
virtual bool isEnabled() const = 0;
virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0;
@ -221,7 +221,7 @@ class CUDAAllocator : public Allocator {
c10::DeviceIndex device) = 0;
virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0;
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
virtual SnapshotInfo snapshot() = 0;
virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0;
virtual void beginAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
@ -239,13 +239,14 @@ class CUDAAllocator : public Allocator {
" does not yet support getPoolUseCount. "
"If you need it, please file an issue describing your use case.");
}
virtual void ensureExistsAndIncrefPool(
virtual void createOrIncrefPool(
c10::DeviceIndex /*device*/,
MempoolId_t /*mempool_id*/) {
MempoolId_t /*mempool_id*/,
CUDAAllocator* allocator = nullptr) {
TORCH_CHECK(
false,
name(),
" does not yet support ensureExistsAndIncrefPool. "
" does not yet support createOrIncrefPool. "
"If you need it, please file an issue describing your use case.");
}
virtual void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
@ -364,7 +365,7 @@ inline void setMemoryFraction(double fraction, c10::DeviceIndex device) {
return get()->setMemoryFraction(fraction, device);
}
inline void emptyCache() {
inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
return get()->emptyCache();
}
@ -401,8 +402,8 @@ inline void resetPeakStats(c10::DeviceIndex device) {
return get()->resetPeakStats(device);
}
inline SnapshotInfo snapshot() {
return get()->snapshot();
inline SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) {
return get()->snapshot(mempool_id);
}
inline std::shared_ptr<AllocatorState> getCheckpointState(
@ -475,10 +476,11 @@ inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) {
inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) {
return get()->releasePool(device, mempool_id);
}
inline void ensureExistsAndIncrefPool(
inline void createOrIncrefPool(
c10::DeviceIndex device,
MempoolId_t mempool_id) {
get()->ensureExistsAndIncrefPool(device, mempool_id);
MempoolId_t mempool_id,
CUDAAllocator* allocator_ptr = nullptr) {
get()->createOrIncrefPool(device, mempool_id, allocator_ptr);
}
inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
get()->setUseOnOOM(device, mempool_id);
@ -555,26 +557,4 @@ struct C10_CUDA_API MemPool {
c10::DeviceIndex device_;
};
// MemPoolContext holds the currently active pool and stashes the previous
// pool. On deletion it makes the previous pool active.
struct C10_CUDA_API MemPoolContext {
MemPoolContext(MemPool* mempool);
~MemPoolContext();
// getActiveMemPool() can be used to get the currently active pool.
// For instance: in CUDACachingAllocator, we can route allocations
// to a user provided allocator, by doing:
//
// auto active_pool = MemPoolContext::getActiveMemPool();
// if (active_pool && active_pool->allocator()) {
// ptr = active_pool->allocator()->raw_alloc(size);
// }
//
static MemPool* getActiveMemPool();
private:
MemPool* prev_mempool_;
};
} // namespace c10::cuda

View File

@ -496,7 +496,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
// introduces performance nondeterminism.
}
void emptyCache() override {
void emptyCache(/*unused*/ MempoolId_t mempool_id) override {
std::lock_guard<std::mutex> lk(general_mutex);
for (int dev = 0; dev < device_count; dev++) {
@ -778,7 +778,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrUsedMemHigh, &zero));
}
SnapshotInfo snapshot() override {
SnapshotInfo snapshot(MempoolId_t mempool_id) override {
TORCH_CHECK(
false,
"Calling snapshot with backend:cudaMallocAsync is not meaningful. "

View File

@ -0,0 +1,35 @@
#include <c10/util/Semaphore.h>
#include <c10/util/irange.h>
#include <gtest/gtest.h>
#include <thread>
using namespace ::testing;
TEST(SemaphoreTest, TestConcurrency) {
auto num_threads = std::thread::hardware_concurrency();
auto num_incr = 10000;
c10::Semaphore sem;
std::vector<std::thread> threads;
for ([[maybe_unused]] const auto _ : c10::irange(num_threads)) {
threads.emplace_back([num_incr = num_incr, &sem]() {
for ([[maybe_unused]] const auto _ : c10::irange(num_incr)) {
sem.release();
}
for ([[maybe_unused]] const auto _ : c10::irange(num_incr)) {
sem.acquire();
}
sem.release(num_incr);
for ([[maybe_unused]] const auto _ : c10::irange(num_incr)) {
sem.acquire();
}
});
}
std::for_each(
threads.begin(), threads.end(), [](std::thread& t) { t.join(); });
EXPECT_FALSE(sem.tryAcquire());
}

View File

@ -289,8 +289,8 @@ class C10_API OutOfMemoryError : public Error {
using Error::Error;
};
// Used for handling syntacitc erros in input arguments.
// They shuld turn into SytnaxError when the cross into Python
// Used for handling syntactic errors in input arguments.
// These turn into SyntaxError when they cross into Python.
class C10_API SyntaxError : public Error {
using Error::Error;
};

71
c10/util/Semaphore.h Normal file
View File

@ -0,0 +1,71 @@
#pragma once
#include <version>
/*
a simple semaphore interface.
*/
// note: __cpp_lib_semaphore will not be defined on some Apple platforms
// even if the standard is >= C++20.
#if __has_include(<semaphore>) && defined(__cpp_lib_semaphore) && __cpp_lib_semaphore >= 201907L
#define C10_SEMAPHORE_USE_STL
#endif
#ifdef C10_SEMAPHORE_USE_STL
#include <semaphore>
#else
// To use moodycamel semaphore, we need to include the header file
// for concurrentqueue first. Hiding implementation detail here.
#ifdef BLOCK_SIZE
#pragma push_macro("BLOCK_SIZE")
#undef BLOCK_SIZE
#include <moodycamel/concurrentqueue.h> // @manual
#pragma pop_macro("BLOCK_SIZE")
#else
#include <moodycamel/concurrentqueue.h> // @manual
#endif
#include <moodycamel/lightweightsemaphore.h> // @manual
#endif
namespace c10 {
class Semaphore {
public:
Semaphore(int32_t initial_count = 0) : impl_(initial_count) {}
void release(int32_t n = 1) {
#ifdef C10_SEMAPHORE_USE_STL
impl_.release(n);
#else
impl_.signal(n);
#endif
}
void acquire() {
#ifdef C10_SEMAPHORE_USE_STL
impl_.acquire();
#else
impl_.wait();
#endif
}
bool tryAcquire() {
#ifdef C10_SEMAPHORE_USE_STL
return impl_.try_acquire();
#else
return impl_.tryWait();
#endif
}
private:
#ifdef C10_SEMAPHORE_USE_STL
std::counting_semaphore<> impl_;
#else
moodycamel::LightweightSemaphore impl_;
#endif
};
} // namespace c10
#undef C10_SEMAPHORE_USE_STL

View File

@ -36,6 +36,7 @@ def define_targets(rules):
":bit_cast",
"//c10/macros",
"@fmt",
"@moodycamel//:moodycamel",
] + rules.select({
"//c10:using_gflags": ["@com_github_gflags_gflags//:gflags"],
"//conditions:default": [],

View File

@ -1154,6 +1154,7 @@ if(USE_DISTRIBUTED AND USE_TENSORPIPE)
list(APPEND Caffe2_DEPENDENCY_LIBS tensorpipe)
list(APPEND Caffe2_DEPENDENCY_LIBS nlohmann)
list(APPEND Caffe2_DEPENDENCY_LIBS moodycamel)
if(USE_CUDA)
list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS tensorpipe_cuda)
elseif(USE_ROCM)
@ -1713,3 +1714,7 @@ target_include_directories(httplib SYSTEM INTERFACE ${PROJECT_SOURCE_DIR}/third_
# Include nlohmann-json
add_library(nlohmann INTERFACE IMPORTED)
include_directories(nlohmann SYSTEM INTERFACE ${PROJECT_SOURCE_DIR}/third_party/nlohmann/include)
# Include moodycamel
add_library(moodycamel INTERFACE IMPORTED)
include_directories(moodycamel SYSTEM INTERFACE ${PROJECT_SOURCE_DIR}/third_party/concurrentqueue)

View File

@ -2282,7 +2282,6 @@ coverage_ignore_classes = [
"UnsynchronizedAccessError",
# torch.cuda.memory
"MemPool",
"MemPoolContext",
# torch.distributed.elastic.multiprocessing.errors
"ChildFailedError",
"ProcessFailure",

View File

@ -128,7 +128,6 @@ Memory management
CUDAPluggableAllocator
change_current_allocator
MemPool
MemPoolContext
.. currentmodule:: torch.cuda.memory

View File

@ -748,6 +748,25 @@ class build_ext(setuptools.command.build_ext.build_ext):
self.copy_file(export_lib, target_lib)
# In the ROCm-on-Windows case, copy rocblas and hipblaslt files into
# torch/lib/rocblas/library and torch/lib/hipblaslt/library
use_rocm = os.environ.get("USE_ROCM")
if use_rocm:
rocm_dir_path = os.environ.get("ROCM_DIR")
rocm_bin_path = os.path.join(rocm_dir_path, "bin")
rocblas_dir = os.path.join(rocm_bin_path, "rocblas")
target_rocblas_dir = os.path.join(target_dir, "rocblas")
os.makedirs(target_rocblas_dir, exist_ok=True)
self.copy_tree(rocblas_dir, target_rocblas_dir)
hipblaslt_dir = os.path.join(rocm_bin_path, "hipblaslt")
target_hipblaslt_dir = os.path.join(target_dir, "hipblaslt")
os.makedirs(target_hipblaslt_dir, exist_ok=True)
self.copy_tree(hipblaslt_dir, target_hipblaslt_dir)
else:
report("The specified environment variable does not exist.")
def build_extensions(self):
self.create_compile_commands()

View File

@ -0,0 +1,135 @@
# Owner(s): ["oncall: distributed_checkpointing"]
import os
import sys
import torch
import torch.distributed.checkpoint as dist_cp
from torch import distributed as dist
from torch.distributed.checkpoint.scripts._consolidate_hf_safetensors import (
consolidate_safetensors_files,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
class TestConsolidateHFSafeTensors(DTensorTestBase):
def _create_d_tensors(self) -> None:
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
mesh_shape = (self.world_size,)
mesh_1d = init_device_mesh(self.device_type, mesh_shape)
# Create local tensor with row-wise sharding
rows_per_rank = global_tensor.shape[0] // self.world_size
start_row = self.rank * rows_per_rank
end_row = start_row + rows_per_rank
local_tensor = global_tensor[start_row:end_row].clone()
# Create DTensor with row-wise sharding
dtensor = DTensor.from_local(
local_tensor,
device_mesh=mesh_1d,
placements=[Shard(0)],
shape=global_tensor.shape,
stride=(4, 1),
)
# Create local tensor with column-wise sharding
cols_per_rank = global_tensor.shape[1] // self.world_size
start_col = self.rank * cols_per_rank
end_col = start_col + cols_per_rank
local_tensor_col = global_tensor[:, start_col:end_col].clone()
# Create DTensor with column-wise sharding
dtensor_col = DTensor.from_local(
local_tensor_col,
device_mesh=mesh_1d,
placements=[Shard(1)], # Column-wise sharding
shape=global_tensor.shape,
stride=(4, 1),
)
state_dict_to_save = {"dtensor": dtensor, "dtensor_col": dtensor_col}
dist_cp.save(
state_dict=state_dict_to_save,
storage_writer=dist_cp._HuggingFaceStorageWriter(
path=self.temp_dir, save_sharded=True
),
)
dist.barrier()
os.sync()
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(2)
def test_consolidate_to_one_file(self) -> None:
try:
import safetensors
except ImportError:
print("safetensors not installed")
sys.exit(0)
checkpoint_dir = self.temp_dir
output_dir = os.path.join(checkpoint_dir, "consolidated")
os.makedirs(output_dir, exist_ok=True)
self._create_d_tensors()
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
if self.rank == 0:
consolidate_safetensors_files(checkpoint_dir, output_dir)
file_path = os.path.join(output_dir, "model-00001-of-00001.safetensors")
loaded_dict = safetensors.torch.load_file(file_path)
self.assertEqual(loaded_dict.keys(), {"dtensor", "dtensor_col"})
self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor))
self.assertTrue(torch.equal(loaded_dict["dtensor_col"], global_tensor))
dist.barrier()
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(2)
def test_consolidate_to_two_files(self):
try:
import safetensors
except ImportError:
print("safetensors not installed")
sys.exit(0)
checkpoint_dir = self.temp_dir
output_dir = os.path.join(checkpoint_dir, "consolidated")
os.makedirs(output_dir, exist_ok=True)
self._create_d_tensors()
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
if self.rank == 0:
fqn_to_index_mapping = {"dtensor": 1, "dtensor_col": 2}
consolidate_safetensors_files(
checkpoint_dir, output_dir, fqn_to_index_mapping
)
file1_path = os.path.join(output_dir, "model-00001-of-00002.safetensors")
file2_path = os.path.join(output_dir, "model-00002-of-00002.safetensors")
loaded_dict = safetensors.torch.load_file(file1_path)
self.assertEqual(loaded_dict.keys(), {"dtensor"})
self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor))
loaded_dict_col = safetensors.torch.load_file(file2_path)
self.assertEqual(loaded_dict_col.keys(), {"dtensor_col"})
self.assertTrue(torch.equal(loaded_dict_col["dtensor_col"], global_tensor))
dist.barrier()
if __name__ == "__main__":
run_tests()

View File

@ -0,0 +1,420 @@
# Owner(s): ["oncall: distributed_checkpointing"]
import sys
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint import _HuggingFaceLoadPlanner
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard, zeros
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
run_tests,
TestCase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
CHECKPOINT_DIR = "checkpoint"
class MyTestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear_1 = torch.nn.Linear(5, 5)
self.linear_2 = torch.nn.Linear(5, 1)
self.emb = torch.nn.EmbeddingBag(5, 10)
class TestSingleRankSaveLoad(TestCase):
@with_temp_dir
def test_save(self) -> None:
try:
from safetensors.torch import load_file
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
state_dict_to_save = MyTestModule().state_dict()
dist_cp.save(
state_dict=state_dict_to_save,
storage_writer=dist_cp._HuggingFaceStorageWriter(
path=CHECKPOINT_DIR
),
)
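        # Read the written file back with safetensors directly to verify the round trip.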
state_dict_loaded = load_file(CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")
self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys()))
for key in state_dict_to_save.keys():
self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_loaded[key]))
@with_temp_dir
def test_load(self) -> None:
try:
from safetensors.torch import save_file
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
state_dict_to_save = MyTestModule().state_dict()
state_dict_to_load = MyTestModule().state_dict()
save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")
dist_cp.load(
state_dict=state_dict_to_load,
storage_reader=dist_cp._HuggingFaceStorageReader(
path=CHECKPOINT_DIR
),
)
self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys()))
for key in state_dict_to_save.keys():
self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_to_load[key]))
@with_temp_dir
def test_load_into_empty_dict(self) -> None:
try:
from safetensors.torch import save_file
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
state_dict_to_save = MyTestModule().state_dict()
save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")
state_dict_loaded = _load_state_dict_from_keys(
storage_reader=dist_cp._HuggingFaceStorageReader(
path=CHECKPOINT_DIR
),
)
self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys()))
for key in state_dict_to_save.keys():
self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_loaded[key]))
@with_temp_dir
def test_load_allowing_resize(self) -> None:
try:
from safetensors.torch import save_file
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
state_dict_to_save = MyTestModule().state_dict()
save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")
        state_dict_to_load = {}
for key in state_dict_to_save.keys():
state_dict_to_load[key] = torch.zeros(1)
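        # The placeholders above deliberately have the wrong shape; allow_tensor_resize
        # lets the planner resize them to the shapes stored in the checkpoint.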
dist_cp.load(
state_dict=state_dict_to_load,
storage_reader=dist_cp._HuggingFaceStorageReader(
path=CHECKPOINT_DIR
),
planner=_HuggingFaceLoadPlanner(allow_tensor_resize=True),
)
self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys()))
for key in state_dict_to_save.keys():
self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_to_load[key]))
ONE_D_PLACEMENTS = [
[Shard(0)],
[Replicate()],
]
ONE_D_TO_ONE_D_PLACEMENTS = [
([Replicate()], [Shard(0)]),
([Shard(0)], [Replicate()]),
]
TWO_D_PLACEMENTS = [
[Replicate(), Replicate()],
[Replicate(), Shard(0)],
[Shard(0), Replicate()],
[Shard(0), Shard(0)],
]
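# Every ordered pair of distinct 2-D placements, used to exercise all resharding paths.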
TWO_D_TO_TWO_D_PLACEMENTS = []
for p1 in TWO_D_PLACEMENTS:
for p2 in TWO_D_PLACEMENTS:
if p1 != p2:
TWO_D_TO_TWO_D_PLACEMENTS.append((p1, p2))
@instantiate_parametrized_tests
class TestDTensorReshardPlacementChange(DTensorTestBase):
"""
    Test DCP resharding for DTensor with placement changes, but without world_size or mesh_tensor changes.
"""
@with_comms
@skip_if_lt_x_gpu(2)
@with_temp_dir
def test_1d_to_1d_reshard_placement_change(self) -> None:
try:
import safetensors
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
for one_d_to_one_d_placements in ONE_D_TO_ONE_D_PLACEMENTS:
original_placement, new_placement = one_d_to_one_d_placements
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
mesh_shape = (self.world_size,)
device_mesh = init_device_mesh(self.device_type, mesh_shape)
dtensor = distribute_tensor(
global_tensor, device_mesh, placements=original_placement
)
state_dict_to_save = {"dtensor": dtensor}
dist_cp.save(
state_dict=state_dict_to_save,
storage_writer=dist_cp._HuggingFaceStorageWriter(
path=CHECKPOINT_DIR,
save_sharded=True,
),
)
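            # save_sharded=True writes one shard file per rank instead of a single
            # consolidated set of model-*.safetensors files.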
zero_dtensor = zeros(
[4, 4], device_mesh=device_mesh, placements=new_placement
)
state_dict_to_load = {"dtensor": zero_dtensor}
dist_cp.load(
state_dict=state_dict_to_load,
storage_reader=dist_cp._HuggingFaceStorageReader(
CHECKPOINT_DIR,
),
)
# materialize the whole tensor to compare with the original global_tensor
state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
device_mesh,
placements=[Replicate()],
)
self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())
# redistribute the tensor back to its original placement for comparison.
state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
device_mesh,
placements=original_placement,
)
self.assertEqual(
state_dict_to_save["dtensor"].to_local(),
state_dict_to_load["dtensor"].to_local(),
)
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
def test_2d_to_2d_reshard_placement_change(self) -> None:
try:
import safetensors
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
for two_d_to_two_d_placements in TWO_D_TO_TWO_D_PLACEMENTS:
original_placement, new_placement = two_d_to_two_d_placements
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(self.device_type, mesh_shape)
dtensor = distribute_tensor(
global_tensor,
mesh_2d,
placements=original_placement,
)
state_dict_to_save = {"dtensor": dtensor}
dist_cp.save(
state_dict=state_dict_to_save,
storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
planner=dist_cp.DefaultSavePlanner(),
)
zero_dtensor = zeros([4, 4], device_mesh=mesh_2d, placements=new_placement)
state_dict_to_load = {"dtensor": zero_dtensor}
dist_cp.load(
state_dict=state_dict_to_load,
storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
)
state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
mesh_2d,
placements=[Replicate(), Replicate()],
)
self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())
state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
mesh_2d,
placements=original_placement,
)
self.assertEqual(
state_dict_to_save["dtensor"].to_local(),
state_dict_to_load["dtensor"].to_local(),
)
class TestDTensorReshardMeshChange(DTensorTestBase):
"""
    Test DCP resharding for DTensor with placement and mesh_tensor changes.
"""
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(2)
def test_1d_to_2d_reshard_mesh_change(self) -> None:
try:
import safetensors
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
for placements_1d in ONE_D_PLACEMENTS:
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
mesh_shape = (self.world_size,)
mesh_1d = init_device_mesh(self.device_type, mesh_shape)
dtensor = distribute_tensor(
global_tensor, mesh_1d, placements=placements_1d
)
state_dict_to_save = {"dtensor": dtensor}
dist_cp.save(
state_dict=state_dict_to_save,
storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
)
for placements_2d in TWO_D_PLACEMENTS:
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(self.device_type, mesh_shape)
zero_dtensor = zeros(
[4, 4], device_mesh=mesh_2d, placements=placements_2d
)
state_dict_to_load = {"dtensor": zero_dtensor}
dist_cp.load(
state_dict=state_dict_to_load,
storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
planner=dist_cp.DefaultLoadPlanner(),
)
                # materialize the whole tensor to compare with the original global_tensor
state_dict_to_load["dtensor"] = state_dict_to_load[
"dtensor"
].redistribute(
mesh_2d,
placements=[Replicate(), Replicate()],
)
self.assertEqual(
global_tensor, state_dict_to_load["dtensor"].to_local()
)
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_2d_to_1d_reshard_mesh_change(self) -> None:
try:
import safetensors
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
for placements_2d in TWO_D_PLACEMENTS:
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(self.device_type, mesh_shape)
dtensor = distribute_tensor(
global_tensor, mesh_2d, placements=placements_2d
)
state_dict_to_save = {"dtensor": dtensor}
dist_cp.save(
state_dict=state_dict_to_save,
storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
planner=dist_cp.DefaultSavePlanner(),
)
for placements_1d in ONE_D_PLACEMENTS:
mesh_shape = (self.world_size,)
mesh_1d = init_device_mesh(self.device_type, mesh_shape)
zero_dtensor = zeros(
[4, 4], device_mesh=mesh_1d, placements=placements_1d
)
state_dict_to_load = {"dtensor": zero_dtensor}
dist_cp.load(
state_dict=state_dict_to_load,
storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
planner=dist_cp.DefaultLoadPlanner(),
)
                # materialize the whole tensor to compare with the original global_tensor
state_dict_to_load["dtensor"] = state_dict_to_load[
"dtensor"
].redistribute(
mesh_1d,
placements=[Replicate()],
)
self.assertEqual(
global_tensor, state_dict_to_load["dtensor"].to_local()
)
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(2)
def test_dtensor_checkpoint_resharding_with_empty_shard(self):
"""
        Test DTensor checkpoint resharding when the DTensor contains empty shards.
"""
try:
import safetensors
except ImportError:
print("safetensors not installed")
sys.exit(0)
tensor = torch.rand(1).cuda()
mesh = init_device_mesh(self.device_type, (self.world_size,))
dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
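        # A 1-element tensor sharded over world_size ranks leaves every rank but one
        # with an empty shard, which is exactly the case this test exercises.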
ref_state_dict = {"dtensor": dtensor}
dist_cp.save(
state_dict=ref_state_dict,
storage_writer=dist_cp._HuggingFaceStorageWriter(path=self.temp_dir, save_sharded=True),
)
tensor = torch.rand(1).cuda()
mesh_2 = init_device_mesh(self.device_type, (2, self.world_size // 2))
dtensor = distribute_tensor(tensor, mesh_2, [Shard(0), Shard(0)])
state_dict = {"dtensor": dtensor}
dist_cp.load(
state_dict=state_dict,
storage_reader=dist_cp._HuggingFaceStorageReader(self.temp_dir),
)
if __name__ == "__main__":
run_tests()

View File

@ -8,10 +8,7 @@ import tempfile
from unittest.mock import MagicMock
import torch
from torch.distributed.checkpoint._hf_planner import (
_FqnToFileMapping,
_HuggingFaceLoadPlanner,
)
from torch.distributed.checkpoint import DefaultLoadPlanner
from torch.distributed.checkpoint._hf_storage import (
_HuggingFaceStorageReader,
_HuggingFaceStorageWriter,
@ -21,24 +18,25 @@ from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.filesystem import _StorageInfo, FileSystem
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
Metadata,
MetadataIndex,
TensorProperties,
TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadPlan, SavePlan
from torch.distributed.checkpoint.planner_helpers import (
_create_read_items,
_create_write_item_for_tensor,
from torch.distributed.checkpoint.planner import (
LoadItemType,
LoadPlan,
ReadItem,
SavePlan,
)
from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor
from torch.distributed.checkpoint.storage import WriteResult
from torch.testing._internal.common_utils import run_tests, TestCase
class TestHfStorage(TestCase):
def test_write_data_hf(self) -> None:
mock_module = MagicMock()
sys.modules["safetensors"] = mock_module
sys.modules["huggingface_hub"] = mock_module
mock_module = MagicMock()
mock_module.save.return_value = b""
sys.modules["safetensors.torch"] = mock_module
@ -46,7 +44,7 @@ class TestHfStorage(TestCase):
with tempfile.TemporaryDirectory() as path:
writer = _HuggingFaceStorageWriter(
path=path,
fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 1},
fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 2},
)
writer.fs = FileSystem()
@ -59,7 +57,7 @@ class TestHfStorage(TestCase):
save_plan = SavePlan(
[write_item_1, write_item_2],
storage_data=_FqnToFileMapping({"tensor_0": 1, "tensor_1": 1}),
storage_data={"fqn_to_file_mapping": {"tensor_0": 1, "tensor_1": 2}},
)
save_planner = DefaultSavePlanner()
save_planner.set_up_planner(state_dict=state_dict)
@ -76,7 +74,7 @@ class TestHfStorage(TestCase):
),
size_in_bytes=tensor0.numel() * tensor0.element_size(),
storage_data=_StorageInfo(
relative_path="model-00001-of-00001.safetensors",
relative_path="model-00001-of-00002.safetensors",
offset=0,
length=tensor0.numel() * tensor0.element_size(),
),
@ -87,7 +85,68 @@ class TestHfStorage(TestCase):
),
size_in_bytes=tensor1.numel() * tensor1.element_size(),
storage_data=_StorageInfo(
relative_path="model-00001-of-00001.safetensors",
relative_path="model-00002-of-00002.safetensors",
offset=0,
length=tensor1.numel() * tensor1.element_size(),
),
),
]
self.assertEqual(
actual_write_results,
expected_write_results,
)
def test_write_data_with_sharding(self) -> None:
mock_module = MagicMock()
mock_module.save.return_value = b""
sys.modules["safetensors.torch"] = mock_module
with tempfile.TemporaryDirectory() as path:
writer = _HuggingFaceStorageWriter(
path=path,
save_sharded=True,
)
writer.fs = FileSystem()
tensor0 = torch.rand(4)
tensor1 = torch.rand(10)
write_item_1 = _create_write_item_for_tensor("tensor_0", tensor0)
write_item_2 = _create_write_item_for_tensor("tensor_1", tensor1)
state_dict = {"tensor_0": tensor0, "tensor_1": tensor1}
save_plan = SavePlan(
[write_item_1, write_item_2],
storage_data={"shard_index": 1},
)
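            # shard_index controls the shard-XXXXX prefix of the output file name,
            # e.g. shard-00001-model-00001-of-00001.safetensors in the expected results below.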
save_planner = DefaultSavePlanner()
save_planner.set_up_planner(state_dict=state_dict)
write_results = writer.write_data(save_plan, save_planner)
write_results.wait()
actual_write_results = write_results.value()
expected_write_results = [
WriteResult(
index=MetadataIndex(
fqn="tensor_0", offset=torch.Size([0]), index=None
),
size_in_bytes=tensor0.numel() * tensor0.element_size(),
storage_data=_StorageInfo(
relative_path="shard-00001-model-00001-of-00001.safetensors",
offset=0,
length=tensor0.numel() * tensor0.element_size(),
),
),
WriteResult(
index=MetadataIndex(
fqn="tensor_1", offset=torch.Size([0]), index=None
),
size_in_bytes=tensor1.numel() * tensor1.element_size(),
storage_data=_StorageInfo(
relative_path="shard-00001-model-00001-of-00001.safetensors",
offset=0,
length=tensor1.numel() * tensor1.element_size(),
),
@ -100,43 +159,84 @@ class TestHfStorage(TestCase):
)
def test_read_data_hf(self) -> None:
mock_module = MagicMock()
sys.modules["safetensors"] = mock_module
sys.modules["huggingface_hub"] = mock_module
mock_safetensors = MagicMock()
sys.modules["safetensors"] = mock_safetensors
name = "tensor_0"
tensor_0 = torch.rand(4)
mock_module = MagicMock()
mock_module.load.return_value = {name: tensor_0}
sys.modules["safetensors.torch"] = mock_module
# Create test tensors
tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0])
# Mock the deserialize function to return our test tensors
# The format matches what's expected in the read_data method
mock_safetensors.deserialize.return_value = [
("tensor_0", {
"data": tensor_0.numpy().tobytes(),
"dtype": "F32",
"shape": [4]
}),
]
with tempfile.TemporaryDirectory() as path:
# Create the reader
reader = _HuggingFaceStorageReader(path=path)
reader.fs = FileSystem()
file_name = "model-00001-of-00001"
pathlib.Path(os.path.join(path, file_name)).touch()
# Create test file
file_name = "model-00001-of-00001.safetensors"
file_path = os.path.join(path, file_name)
pathlib.Path(file_path).touch()
reader.set_up_storage_reader(
Metadata(
state_dict_metadata={name: BytesStorageMetadata()},
storage_data={name: file_name},
),
is_coordinator=True,
)
# Set up storage data with _StorageInfo objects
storage_data = {
"tensor_0": _StorageInfo(file_path, 0, tensor_0.numel() * tensor_0.element_size()),
}
read_items = _create_read_items(name, BytesStorageMetadata(), file_name)
reader.storage_data = storage_data
# Create target tensors that will be updated by read_data
target_tensor_0 = torch.zeros(4)
state_dict = {
"tensor_0": target_tensor_0,
}
# Create read items for the load plan
read_items = []
for name, tensor in state_dict.items():
storage_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None)
dest_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None)
read_items.append(
ReadItem(
type=LoadItemType.TENSOR,
storage_index=storage_index,
dest_index=dest_index,
storage_offsets=[0, 0],
dest_offsets=[0, 0],
lengths=tensor.size(),
)
)
# Create load plan and planner
load_plan = LoadPlan(read_items)
load_planner = _HuggingFaceLoadPlanner()
load_planner.set_up_planner(state_dict={name: torch.rand(4)})
load_planner = DefaultLoadPlanner()
load_planner.set_up_planner(
state_dict=state_dict,
metadata=Metadata(
state_dict_metadata={
"tensor_0": TensorStorageMetadata(
properties=TensorProperties(dtype=torch.float32),
size=torch.Size([4]),
chunks=[ChunkStorageMetadata(offsets=[0], sizes=torch.Size([4]))])},
storage_data=storage_data)
)
read_data = reader.read_data(load_plan, load_planner)
read_data.wait()
# Call read_data
future = reader.read_data(load_plan, load_planner)
future.wait()
loaded_tensor = load_planner.original_state_dict[name]
self.assertEqual(loaded_tensor, tensor_0)
# Verify results - the target tensors should now contain the values from our test tensor
self.assertTrue(torch.equal(state_dict["tensor_0"], tensor_0))
def test_metadata_hf(self) -> None:
def test_write_metadata_hf(self) -> None:
mock_module = MagicMock()
sys.modules["huggingface_hub"] = mock_module
with tempfile.TemporaryDirectory() as path:
@ -160,7 +260,6 @@ class TestHfStorage(TestCase):
writer = _HuggingFaceStorageWriter(
path=path,
fqn_to_index_mapping=_FqnToFileMapping({}),
)
writer.fs = FileSystem()
writer.finish(
@ -185,26 +284,16 @@ class TestHfStorage(TestCase):
metadata = json.load(f)
self.assertEqual(metadata, expected_metadata)
reader = _HuggingFaceStorageReader(path=path)
reader.fs = FileSystem()
metadata = reader.read_metadata()
self.assertEqual(metadata.storage_data, expected_metadata["weight_map"])
def test_read_metadata_when_metadata_file_does_not_exist(self) -> None:
mock_module = MagicMock()
sys.modules["huggingface_hub"] = mock_module
def test_read_metadata_hf(self):
with tempfile.TemporaryDirectory() as path:
reader = _HuggingFaceStorageReader(path=path)
reader.fs = FileSystem()
            # There is one safetensors file but no metadata file,
            # so metadata is built from the safetensors file header.
keys = ["tensor_0", "tensor_1"]
key = "tensor_0"
file_name = "test.safetensors"
with open(os.path.join(path, file_name), "wb") as f:
                # write the metadata header the same way it would appear in a safetensors file
metadata_contents = json.dumps(
{"tensor_0": "value_0", "tensor_1": "value_1"}
{'tensor_0': {'dtype': "F32", "shape": [5, 10], "data_offsets": [0, 200]}}
)
metadata_bytes = metadata_contents.encode("utf-8")
@ -216,13 +305,16 @@ class TestHfStorage(TestCase):
self.assertEqual(
metadata.state_dict_metadata,
{
keys[0]: BytesStorageMetadata(),
keys[1]: BytesStorageMetadata(),
key: TensorStorageMetadata(
properties=TensorProperties(dtype=torch.float32),
size=torch.Size([5, 10]),
chunks=[ChunkStorageMetadata(offsets=[0, 0], sizes=torch.Size([5, 10]))],
),
},
)
self.assertEqual(
metadata.storage_data,
{keys[0]: file_name, keys[1]: file_name},
{key: _StorageInfo(os.path.join(path, file_name), 0, 200, transform_descriptors=None)},
)

View File

@ -109,6 +109,27 @@ class MLPModule(torch.nn.Module):
return x
class MLPKWargModule(torch.nn.Module):
def __init__(self, d_hid: int, layer_num):
super().__init__()
self.net1 = torch.nn.Linear(d_hid, d_hid)
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(d_hid, d_hid)
self.layer_num = layer_num
def forward(self, x, unused_kwarg: torch.Tensor = torch.zeros(1)):
x = self.net1(x)
x = self.relu(x)
x = self.net2(x)
# Test when only 1 module has extra outputs
# TODO: handle this case later
# if self.layer_num == 0:
# return x, unused_kwarg
# else:
# return x
return x
# Multi-MLP model
class MultiMLP(torch.nn.Module):
def __init__(self, d_hid: int, n_layers: int = 2):
@ -125,6 +146,29 @@ class MultiMLP(torch.nn.Module):
return x
# Multi-MLP with kwargs model
class MultiMLPKwargs(torch.nn.Module):
def __init__(self, d_hid: int, n_layers: int = 2):
super().__init__()
self.layers = torch.nn.ModuleList(
[MLPKWargModule(d_hid, i) for i in range(n_layers)]
)
# For testing purpose only, this should be defined by user
self.split_spec = {
f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
}
def forward(self, x, unused_kwarg: torch.Tensor = torch.zeros(1)):
for layer in self.layers:
# TODO: handle this case later
# if layer.layer_num == 0:
# x, _ = layer(x, unused_kwarg)
# else:
# x = layer(x)
x = layer(x)
return x
class CustomLinearDx(Function):
@staticmethod
def forward(ctx, input_val, weight, bias, module, layer_idx):

View File

@ -4,7 +4,7 @@ import copy
import logging
import tempfile
from model_registry import ModelWithKwargs, MultiMLP, MultiMLPWithDw
from model_registry import ModelWithKwargs, MultiMLP, MultiMLPKwargs, MultiMLPWithDw
from schedule_registry import (
ScheduleUnbalanced,
ScheduleVShaped,
@ -946,6 +946,113 @@ class ScheduleTest(MultiProcContinousTest):
ref_p = ref_submod.get_parameter(name)
torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize(
"ScheduleClass",
[ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B],
)
def test_zero_bubble_with_model_kwargs(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
full_mod = MultiMLPKwargs(d_hid, n_layers=n_stages)
full_mod.to(self.device)
ref_mod = copy.deepcopy(full_mod)
x = torch.randn(batch_size, d_hid, device=self.device)
unused_kwarg = torch.tensor([1.0], device=self.device)
with torch.no_grad():
y = ref_mod(x)
# Add a small perturbation
target = y + torch.randn(batch_size, d_hid, device=self.device)
loss_fn = torch.nn.MSELoss(reduction="sum")
# Get a submodule, e.g. `layers.0` or `layers.1`
stage_indices = [
self.rank + i * self.world_size for i in range(stages_per_rank)
]
submod_names = [f"layers.{i}" for i in stage_indices]
stage_modules = [
full_mod.get_submodule(submod_name) for submod_name in submod_names
]
# Run reference
for _ in range(2):
ref_stage_modules = [
ref_mod.get_submodule(submod_name) for submod_name in submod_names
]
for stage_module in ref_stage_modules:
stage_module.zero_grad()
ref_mod.zero_grad()
ref_out = ref_mod(x, unused_kwarg=unused_kwarg)
ref_loss = loss_fn(ref_out, target)
ref_loss.backward()
# Create a pipeline stage to wrap that submodule
stages = [
PipelineStage(
stage_module,
stage_idx,
n_stages,
self.device,
)
for stage_module, stage_idx in zip(stage_modules, stage_indices)
]
# Attach to a schedule
num_microbatches = (
ScheduleClass.num_microbatches
if hasattr(ScheduleClass, "num_microbatches")
else 2 * self.world_size
)
schedule = ScheduleClass(
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
)
for _ in range(2):
# Zero gradients
for stage_module in stage_modules:
stage_module.zero_grad()
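            # Rank 0 feeds the inputs; the unused kwarg is cloned and expanded so each
            # microbatch receives its own copy.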
if self.rank == 0:
schedule.step(
x,
unused_kwarg=unused_kwarg.clone()
.unsqueeze(0)
.expand(num_microbatches, -1),
)
elif self.rank == self.world_size - 1:
losses = []
out = schedule.step(target=target, losses=losses)
else:
schedule.step()
dist.barrier()
# Last rank checks result
if self.rank == self.world_size - 1:
# Check output
torch.testing.assert_close(out, ref_out)
# Check loss
pipe_loss = sum(losses)
torch.testing.assert_close(pipe_loss, ref_loss)
# Every rank checks gradients
for stage_module, submod_name in zip(stage_modules, submod_names):
# Get corresponding submodule from reference model
ref_submod = ref_mod.get_submodule(submod_name)
# Check gradients per parameter
for name, p in stage_module.named_parameters():
ref_p = ref_submod.get_parameter(name)
try:
torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=5e-3)
except AssertionError:
print(
f"Gradient test failed for {name}: {p.grad=} vs {ref_p.grad=}"
)
raise
instantiate_parametrized_tests(ScheduleTest)

View File

@ -56,6 +56,7 @@ from torch.testing._internal.common_distributed import (
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_SANDCASTLE,
MI300_ARCH,
parametrize,
retry_on_connect_failures,
@ -286,13 +287,15 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
# These tests are expected to throw SIGABRT(6); adding the negative sign
# bc the test return code is actually -6
# But if we are in Sandcastle, `skip_but_pass_in_sandcastle` would return 0.
TEST_NAN_ASSERT_RETURN = 0 if IS_SANDCASTLE else -signal.SIGABRT
self.special_return_code_checks = {
self.test_nan_assert_float16.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_float32.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_float64.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_bfloat16.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_float8_e4m3fn.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_float8_e5m2.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_float16.__wrapped__: TEST_NAN_ASSERT_RETURN,
self.test_nan_assert_float32.__wrapped__: TEST_NAN_ASSERT_RETURN,
self.test_nan_assert_float64.__wrapped__: TEST_NAN_ASSERT_RETURN,
self.test_nan_assert_bfloat16.__wrapped__: TEST_NAN_ASSERT_RETURN,
self.test_nan_assert_float8_e4m3fn.__wrapped__: TEST_NAN_ASSERT_RETURN,
self.test_nan_assert_float8_e5m2.__wrapped__: TEST_NAN_ASSERT_RETURN,
}
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests

View File

@ -0,0 +1,231 @@
diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py
index 6ff1a8ab29d..ab5bd3dab62 100644
--- a/test/dynamo/cpython/3_13/test_complex.py
+++ b/test/dynamo/cpython/3_13/test_complex.py
@@ -1,16 +1,143 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import (
+ run_tests,
+ xfailIfTorchDynamo,
+)
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
import sys
-from test import support
-from test.support.testcase import ComplexesAreIdenticalMixin
-from test.support.numbers import (
- VALID_UNDERSCORE_LITERALS,
- INVALID_UNDERSCORE_LITERALS,
+import importlib.abc
+
+redirect_imports = (
+ "test.mapping_tests",
+ "test.typinganndata",
+ "test.test_grammar",
+ "test.test_math",
+ "test.test_iter",
+ "test.typinganndata.ann_module",
)
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+ def find_spec(self, fullname, path, target=None):
+ # Check if the import is the problematic one
+ if fullname in redirect_imports:
+ try:
+ # Attempt to import the standalone module
+ name = fullname.removeprefix("test.")
+ r = importlib.import_module(name)
+ # Redirect the module in sys.modules
+ sys.modules[fullname] = r
+ # Return a module spec from the found module
+ return importlib.util.find_spec(name)
+ except ImportError:
+ return None
+ return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
+import unittest
+import sys
+from test import support
+from test.support.testcase import ComplexesAreIdenticalMixin
from random import random
from math import isnan, copysign
+import math
import operator
+VALID_UNDERSCORE_LITERALS = [
+ '0_0_0',
+ '4_2',
+ '1_0000_0000',
+ '0b1001_0100',
+ '0xffff_ffff',
+ '0o5_7_7',
+ '1_00_00.5',
+ '1_00_00.5e5',
+ '1_00_00e5_1',
+ '1e1_0',
+ '.1_4',
+ '.1_4e1',
+ '0b_0',
+ '0x_f',
+ '0o_5',
+ '1_00_00j',
+ '1_00_00.5j',
+ '1_00_00e5_1j',
+ '.1_4j',
+ '(1_2.5+3_3j)',
+ '(.5_6j)',
+]
+INVALID_UNDERSCORE_LITERALS = [
+ # Trailing underscores:
+ '0_',
+ '42_',
+ '1.4j_',
+ '0x_',
+ '0b1_',
+ '0xf_',
+ '0o5_',
+ '0 if 1_Else 1',
+ # Underscores in the base selector:
+ '0_b0',
+ '0_xf',
+ '0_o5',
+ # Old-style octal, still disallowed:
+ '0_7',
+ '09_99',
+ # Multiple consecutive underscores:
+ '4_______2',
+ '0.1__4',
+ '0.1__4j',
+ '0b1001__0100',
+ '0xffff__ffff',
+ '0x___',
+ '0o5__77',
+ '1e1__0',
+ '1e1__0j',
+ # Underscore right before a dot:
+ '1_.4',
+ '1_.4j',
+ # Underscore right after a dot:
+ '1._4',
+ '1._4j',
+ '._5',
+ '._5j',
+ # Underscore right after a sign:
+ '1.0e+_1',
+ '1.0e+_1j',
+ # Underscore right before j:
+ '1.4_j',
+ '1.4e5_j',
+ # Underscore right before e:
+ '1_e1',
+ '1.4_e1',
+ '1.4_e1j',
+ # Underscore right after e:
+ '1e_1',
+ '1.4e_1',
+ '1.4e_1j',
+ # Complex cases with parens:
+ '(1+1.5_j_)',
+ '(1+1.5_j)',
+]
+
INF = float("inf")
NAN = float("nan")
DBL_MAX = sys.float_info.max
@@ -45,7 +172,40 @@ class WithComplex:
def __complex__(self):
return self.value
-class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
+class ComplexTest(__TestCase):
+
+ def assertFloatIdentical(self, x, y):
+ """Fail unless floats x and y are identical, in the sense that:
+ (1) both x and y are nans, or
+ (2) both x and y are infinities, with the same sign, or
+ (3) both x and y are zeros, with the same sign, or
+ (4) x and y are both finite and nonzero, and x == y
+
+ """
+ msg = 'floats {!r} and {!r} are not identical'
+
+ if math.isnan(x) or math.isnan(y):
+ if math.isnan(x) and math.isnan(y):
+ return
+ elif x == y:
+ if x != 0.0:
+ return
+ # both zero; check that signs match
+ elif math.copysign(1.0, x) == math.copysign(1.0, y):
+ return
+ else:
+ msg += ': zeros have different signs'
+ self.fail(msg.format(x, y))
+
+ def assertComplexesAreIdentical(self, x, y):
+ """Fail unless complex numbers x and y have equal values and signs.
+
+ In particular, if x and y both have real (or imaginary) part
+ zero, but the zeros have different signs, this test will fail.
+
+ """
+ self.assertFloatIdentical(x.real, y.real)
+ self.assertFloatIdentical(x.imag, y.imag)
def assertAlmostEqual(self, a, b):
if isinstance(a, complex):
@@ -74,6 +234,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
# check that relative difference < eps
self.assertTrue(abs((x-y)/y) < eps)
+ def assertFloatsAreIdentical(self, x, y):
+ """assert that floats x and y are identical, in the sense that:
+ (1) both x and y are nans, or
+ (2) both x and y are infinities, with the same sign, or
+ (3) both x and y are zeros, with the same sign, or
+ (4) x and y are both finite and nonzero, and x == y
+
+ """
+ msg = 'floats {!r} and {!r} are not identical'
+
+ if isnan(x) or isnan(y):
+ if isnan(x) and isnan(y):
+ return
+ elif x == y:
+ if x != 0.0:
+ return
+ # both zero; check that signs match
+ elif copysign(1.0, x) == copysign(1.0, y):
+ return
+ else:
+ msg += ': zeros have different signs'
+ self.fail(msg.format(x, y))
+
def assertClose(self, x, y, eps=1e-9):
"""Return true iff complexes x and y "are close"."""
self.assertCloseAbs(x.real, y.real, eps)
@@ -855,4 +1038,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
if __name__ == "__main__":
- unittest.main()
+ run_tests()

File diff suppressed because it is too large

View File

@ -0,0 +1,85 @@
diff --git a/test/dynamo/cpython/3_13/test_iter.py b/test/dynamo/cpython/3_13/test_iter.py
index 1b9f3cf7624..d0c68f4314c 100644
--- a/test/dynamo/cpython/3_13/test_iter.py
+++ b/test/dynamo/cpython/3_13/test_iter.py
@@ -1,3 +1,57 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import (
+ skipIfTorchDynamo,
+ run_tests,
+)
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+ "test.mapping_tests",
+ "test.typinganndata",
+ "test.test_grammar",
+ "test.test_math",
+ "test.test_iter",
+ "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+ def find_spec(self, fullname, path, target=None):
+ # Check if the import is the problematic one
+ if fullname in redirect_imports:
+ try:
+ # Attempt to import the standalone module
+ name = fullname.removeprefix("test.")
+ r = importlib.import_module(name)
+ # Redirect the module in sys.modules
+ sys.modules[fullname] = r
+ # Return a module spec from the found module
+ return importlib.util.find_spec(name)
+ except ImportError:
+ return None
+ return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
# Test iterators.
import sys
@@ -104,7 +158,7 @@ class EmptyIterClass:
# Main test suite
-class TestCase(unittest.TestCase):
+class TestCase(__TestCase):
# Helper to check that an iterator returns a given sequence
def check_iterator(self, it, seq, pickle=True):
@@ -635,6 +689,7 @@ class TestCase(unittest.TestCase):
pass
# Test zip()'s use of iterators.
+ @skipIfTorchDynamo("infinite loop")
def test_builtin_zip(self):
self.assertEqual(list(zip()), [])
self.assertEqual(list(zip(*[])), [])
@@ -1187,4 +1242,4 @@ class TestCase(unittest.TestCase):
if __name__ == "__main__":
- unittest.main()
+ run_tests()

File diff suppressed because it is too large

View File

@ -0,0 +1,101 @@
diff --git a/test/dynamo/cpython/3_13/test_sort.py b/test/dynamo/cpython/3_13/test_sort.py
index 2a7cfb7affa..d661ae544b9 100644
--- a/test/dynamo/cpython/3_13/test_sort.py
+++ b/test/dynamo/cpython/3_13/test_sort.py
@@ -1,3 +1,54 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+ "test.mapping_tests",
+ "test.typinganndata",
+ "test.test_grammar",
+ "test.test_math",
+ "test.test_iter",
+ "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+ def find_spec(self, fullname, path, target=None):
+ # Check if the import is the problematic one
+ if fullname in redirect_imports:
+ try:
+ # Attempt to import the standalone module
+ name = fullname.removeprefix("test.")
+ r = importlib.import_module(name)
+ # Redirect the module in sys.modules
+ sys.modules[fullname] = r
+ # Return a module spec from the found module
+ return importlib.util.find_spec(name)
+ except ImportError:
+ return None
+ return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
from test import support
import random
import unittest
@@ -39,7 +90,7 @@ def check(tag, expected, raw, compare=None):
nerrors += 1
return
-class TestBase(unittest.TestCase):
+class TestBase(__TestCase):
def testStressfully(self):
# Try a variety of sizes at and around powers of 2, and at powers of 10.
sizes = [0]
@@ -151,7 +202,7 @@ class TestBase(unittest.TestCase):
self.assertEqual(forced, native)
#==============================================================================
-class TestBugs(unittest.TestCase):
+class TestBugs(__TestCase):
def test_bug453523(self):
# bug 453523 -- list.sort() crasher.
@@ -188,7 +239,7 @@ class TestBugs(unittest.TestCase):
#==============================================================================
-class TestDecorateSortUndecorate(unittest.TestCase):
+class TestDecorateSortUndecorate(__TestCase):
def test_decorated(self):
data = 'The quick Brown fox Jumped over The lazy Dog'.split()
@@ -309,7 +360,7 @@ def check_against_PyObject_RichCompareBool(self, L):
self.assertIs(opt, ref)
#note: not assertEqual! We want to ensure *identical* behavior.
-class TestOptimizedCompares(unittest.TestCase):
+class TestOptimizedCompares(__TestCase):
def test_safe_object_compare(self):
heterogeneous_lists = [[0, 'foo'],
[0.0, 'foo'],
@@ -408,4 +459,4 @@ class TestOptimizedCompares(unittest.TestCase):
#==============================================================================
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -0,0 +1,462 @@
# ======= BEGIN Dynamo patch =======
# Owner(s): ["module: dynamo"]
# ruff: noqa
# flake8: noqa
import sys
import torch
import torch._dynamo.test_case
import unittest
from torch._dynamo.test_case import CPythonTestCase
from torch.testing._internal.common_utils import run_tests
__TestCase = CPythonTestCase
# redirect import statements
import sys
import importlib.abc
redirect_imports = (
"test.mapping_tests",
"test.typinganndata",
"test.test_grammar",
"test.test_math",
"test.test_iter",
"test.typinganndata.ann_module",
)
class RedirectImportFinder(importlib.abc.MetaPathFinder):
def find_spec(self, fullname, path, target=None):
# Check if the import is the problematic one
if fullname in redirect_imports:
try:
# Attempt to import the standalone module
name = fullname.removeprefix("test.")
r = importlib.import_module(name)
# Redirect the module in sys.modules
sys.modules[fullname] = r
# Return a module spec from the found module
return importlib.util.find_spec(name)
except ImportError:
return None
return None
# Add the custom finder to sys.meta_path
sys.meta_path.insert(0, RedirectImportFinder())
# ======= END DYNAMO PATCH =======
from test import support
import random
import unittest
from functools import cmp_to_key
verbose = support.verbose
nerrors = 0
def check(tag, expected, raw, compare=None):
global nerrors
if verbose:
print(" checking", tag)
orig = raw[:] # save input in case of error
if compare:
raw.sort(key=cmp_to_key(compare))
else:
raw.sort()
if len(expected) != len(raw):
print("error in", tag)
print("length mismatch;", len(expected), len(raw))
print(expected)
print(orig)
print(raw)
nerrors += 1
return
for i, good in enumerate(expected):
maybe = raw[i]
if good is not maybe:
print("error in", tag)
print("out of order at index", i, good, maybe)
print(expected)
print(orig)
print(raw)
nerrors += 1
return
class TestBase(__TestCase):
def testStressfully(self):
# Try a variety of sizes at and around powers of 2, and at powers of 10.
sizes = [0]
for power in range(1, 10):
n = 2 ** power
sizes.extend(range(n-1, n+2))
sizes.extend([10, 100, 1000])
class Complains(object):
maybe_complain = True
def __init__(self, i):
self.i = i
def __lt__(self, other):
if Complains.maybe_complain and random.random() < 0.001:
if verbose:
print(" complaining at", self, other)
raise RuntimeError
return self.i < other.i
def __repr__(self):
return "Complains(%d)" % self.i
class Stable(object):
def __init__(self, key, i):
self.key = key
self.index = i
def __lt__(self, other):
return self.key < other.key
def __repr__(self):
return "Stable(%d, %d)" % (self.key, self.index)
for n in sizes:
x = list(range(n))
if verbose:
print("Testing size", n)
s = x[:]
check("identity", x, s)
s = x[:]
s.reverse()
check("reversed", x, s)
s = x[:]
random.shuffle(s)
check("random permutation", x, s)
y = x[:]
y.reverse()
s = x[:]
check("reversed via function", y, s, lambda a, b: (b>a)-(b<a))
if verbose:
print(" Checking against an insane comparison function.")
print(" If the implementation isn't careful, this may segfault.")
s = x[:]
s.sort(key=cmp_to_key(lambda a, b: int(random.random() * 3) - 1))
check("an insane function left some permutation", x, s)
if len(x) >= 2:
def bad_key(x):
raise RuntimeError
s = x[:]
self.assertRaises(RuntimeError, s.sort, key=bad_key)
x = [Complains(i) for i in x]
s = x[:]
random.shuffle(s)
Complains.maybe_complain = True
it_complained = False
try:
s.sort()
except RuntimeError:
it_complained = True
if it_complained:
Complains.maybe_complain = False
check("exception during sort left some permutation", x, s)
s = [Stable(random.randrange(10), i) for i in range(n)]
augmented = [(e, e.index) for e in s]
augmented.sort() # forced stable because ties broken by index
x = [e for e, i in augmented] # a stable sort of s
check("stability", x, s)
def test_small_stability(self):
from itertools import product
from operator import itemgetter
# Exhaustively test stability across all lists of small lengths
# and only a few distinct elements.
# This can provoke edge cases that randomization is unlikely to find.
# But it can grow very expensive quickly, so don't overdo it.
NELTS = 3
MAXSIZE = 9
pick0 = itemgetter(0)
for length in range(MAXSIZE + 1):
# There are NELTS ** length distinct lists.
for t in product(range(NELTS), repeat=length):
xs = list(zip(t, range(length)))
# Stability forced by index in each element.
forced = sorted(xs)
# Use key= to hide the index from compares.
native = sorted(xs, key=pick0)
self.assertEqual(forced, native)
#==============================================================================
class TestBugs(__TestCase):
def test_bug453523(self):
# bug 453523 -- list.sort() crasher.
# If this fails, the most likely outcome is a core dump.
# Mutations during a list sort should raise a ValueError.
class C:
def __lt__(self, other):
if L and random.random() < 0.75:
L.pop()
else:
L.append(3)
return random.random() < 0.5
L = [C() for i in range(50)]
self.assertRaises(ValueError, L.sort)
def test_undetected_mutation(self):
# Python 2.4a1 did not always detect mutation
memorywaster = []
for i in range(20):
def mutating_cmp(x, y):
L.append(3)
L.pop()
return (x > y) - (x < y)
L = [1,2]
self.assertRaises(ValueError, L.sort, key=cmp_to_key(mutating_cmp))
def mutating_cmp(x, y):
L.append(3)
del L[:]
return (x > y) - (x < y)
self.assertRaises(ValueError, L.sort, key=cmp_to_key(mutating_cmp))
memorywaster = [memorywaster]
#==============================================================================
class TestDecorateSortUndecorate(__TestCase):
def test_decorated(self):
data = 'The quick Brown fox Jumped over The lazy Dog'.split()
copy = data[:]
random.shuffle(data)
data.sort(key=str.lower)
def my_cmp(x, y):
xlower, ylower = x.lower(), y.lower()
return (xlower > ylower) - (xlower < ylower)
copy.sort(key=cmp_to_key(my_cmp))
def test_baddecorator(self):
data = 'The quick Brown fox Jumped over The lazy Dog'.split()
self.assertRaises(TypeError, data.sort, key=lambda x,y: 0)
def test_stability(self):
data = [(random.randrange(100), i) for i in range(200)]
copy = data[:]
data.sort(key=lambda t: t[0]) # sort on the random first field
copy.sort() # sort using both fields
self.assertEqual(data, copy) # should get the same result
def test_key_with_exception(self):
# Verify that the wrapper has been removed
data = list(range(-2, 2))
dup = data[:]
self.assertRaises(ZeroDivisionError, data.sort, key=lambda x: 1/x)
self.assertEqual(data, dup)
def test_key_with_mutation(self):
data = list(range(10))
def k(x):
del data[:]
data[:] = range(20)
return x
self.assertRaises(ValueError, data.sort, key=k)
def test_key_with_mutating_del(self):
data = list(range(10))
class SortKiller(object):
def __init__(self, x):
pass
def __del__(self):
del data[:]
data[:] = range(20)
def __lt__(self, other):
return id(self) < id(other)
self.assertRaises(ValueError, data.sort, key=SortKiller)
def test_key_with_mutating_del_and_exception(self):
data = list(range(10))
## dup = data[:]
class SortKiller(object):
def __init__(self, x):
if x > 2:
raise RuntimeError
def __del__(self):
del data[:]
data[:] = list(range(20))
self.assertRaises(RuntimeError, data.sort, key=SortKiller)
## major honking subtlety: we *can't* do:
##
## self.assertEqual(data, dup)
##
## because there is a reference to a SortKiller in the
## traceback and by the time it dies we're outside the call to
## .sort() and so the list protection gimmicks are out of
## date (this cost some brain cells to figure out...).
def test_reverse(self):
data = list(range(100))
random.shuffle(data)
data.sort(reverse=True)
self.assertEqual(data, list(range(99,-1,-1)))
def test_reverse_stability(self):
data = [(random.randrange(100), i) for i in range(200)]
copy1 = data[:]
copy2 = data[:]
def my_cmp(x, y):
x0, y0 = x[0], y[0]
return (x0 > y0) - (x0 < y0)
def my_cmp_reversed(x, y):
x0, y0 = x[0], y[0]
return (y0 > x0) - (y0 < x0)
data.sort(key=cmp_to_key(my_cmp), reverse=True)
copy1.sort(key=cmp_to_key(my_cmp_reversed))
self.assertEqual(data, copy1)
copy2.sort(key=lambda x: x[0], reverse=True)
self.assertEqual(data, copy2)
#==============================================================================
def check_against_PyObject_RichCompareBool(self, L):
## The idea here is to exploit the fact that unsafe_tuple_compare uses
## PyObject_RichCompareBool for the second elements of tuples. So we have,
## for (most) L, sorted(L) == [y[1] for y in sorted([(0,x) for x in L])]
## This will work as long as __eq__ => not __lt__ for all the objects in L,
## which holds for all the types used below.
##
## Testing this way ensures that the optimized implementation remains consistent
## with the naive implementation, even if changes are made to any of the
## richcompares.
##
## This function tests sorting for three lists (it randomly shuffles each one):
## 1. L
## 2. [(x,) for x in L]
## 3. [((x,),) for x in L]
random.seed(0)
random.shuffle(L)
L_1 = L[:]
L_2 = [(x,) for x in L]
L_3 = [((x,),) for x in L]
for L in [L_1, L_2, L_3]:
optimized = sorted(L)
reference = [y[1] for y in sorted([(0,x) for x in L])]
for (opt, ref) in zip(optimized, reference):
self.assertIs(opt, ref)
#note: not assertEqual! We want to ensure *identical* behavior.
class TestOptimizedCompares(__TestCase):
def test_safe_object_compare(self):
heterogeneous_lists = [[0, 'foo'],
[0.0, 'foo'],
[('foo',), 'foo']]
for L in heterogeneous_lists:
self.assertRaises(TypeError, L.sort)
self.assertRaises(TypeError, [(x,) for x in L].sort)
self.assertRaises(TypeError, [((x,),) for x in L].sort)
float_int_lists = [[1,1.1],
[1<<70,1.1],
[1.1,1],
[1.1,1<<70]]
for L in float_int_lists:
check_against_PyObject_RichCompareBool(self, L)
def test_unsafe_object_compare(self):
# This test is by ppperry. It ensures that unsafe_object_compare is
# verifying ms->key_richcompare == tp->richcompare before comparing.
class WackyComparator(int):
def __lt__(self, other):
elem.__class__ = WackyList2
return int.__lt__(self, other)
class WackyList1(list):
pass
class WackyList2(list):
def __lt__(self, other):
raise ValueError
L = [WackyList1([WackyComparator(i), i]) for i in range(10)]
elem = L[-1]
with self.assertRaises(ValueError):
L.sort()
L = [WackyList1([WackyComparator(i), i]) for i in range(10)]
elem = L[-1]
with self.assertRaises(ValueError):
[(x,) for x in L].sort()
# The following test is also by ppperry. It ensures that
# unsafe_object_compare handles Py_NotImplemented appropriately.
class PointlessComparator:
def __lt__(self, other):
return NotImplemented
L = [PointlessComparator(), PointlessComparator()]
self.assertRaises(TypeError, L.sort)
self.assertRaises(TypeError, [(x,) for x in L].sort)
# The following tests go through various types that would trigger
# ms->key_compare = unsafe_object_compare
lists = [list(range(100)) + [(1<<70)],
[str(x) for x in range(100)] + ['\uffff'],
[bytes(x) for x in range(100)],
[cmp_to_key(lambda x,y: x<y)(x) for x in range(100)]]
for L in lists:
check_against_PyObject_RichCompareBool(self, L)
def test_unsafe_latin_compare(self):
check_against_PyObject_RichCompareBool(self, [str(x) for
x in range(100)])
def test_unsafe_long_compare(self):
check_against_PyObject_RichCompareBool(self, [x for
x in range(100)])
def test_unsafe_float_compare(self):
check_against_PyObject_RichCompareBool(self, [float(x) for
x in range(100)])
def test_unsafe_tuple_compare(self):
# This test was suggested by Tim Peters. It verifies that the tuple
# comparison respects the current tuple compare semantics, which do not
# guarantee that x < x <=> (x,) < (x,)
#
# Note that we don't have to put anything in tuples here, because
# the check function does a tuple test automatically.
check_against_PyObject_RichCompareBool(self, [float('nan')]*100)
check_against_PyObject_RichCompareBool(self, [float('nan') for
_ in range(100)])
def test_not_all_tuples(self):
self.assertRaises(TypeError, [(1.0, 1.0), (False, "A"), 6].sort)
self.assertRaises(TypeError, [('a', 1), (1, 'a')].sort)
self.assertRaises(TypeError, [(1, 'a'), ('a', 1)].sort)
def test_none_in_tuples(self):
expected = [(None, 1), (None, 2)]
actual = sorted([(None, 2), (None, 1)])
self.assertEqual(actual, expected)
#==============================================================================
if __name__ == "__main__":
run_tests()

View File

@ -0,0 +1,89 @@
diff --git a/test/dynamo/cpython/3_13/test_unittest/test_assertions.py b/test/dynamo/cpython/3_13/test_unittest/test_assertions.py
index 1dec947ea76..5a8c2a9d3af 100644
--- a/test/dynamo/cpython/3_13/test_unittest/test_assertions.py
+++ b/test/dynamo/cpython/3_13/test_unittest/test_assertions.py
@@ -1,3 +1,54 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch.testing._internal.common_utils import run_tests
+
+
+__TestCase = torch._dynamo.test_case.CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+ "test.mapping_tests",
+ "test.typinganndata",
+ "test.test_grammar",
+ "test.test_math",
+ "test.test_iter",
+ "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+ def find_spec(self, fullname, path, target=None):
+ # Check if the import is the problematic one
+ if fullname in redirect_imports:
+ try:
+ # Attempt to import the standalone module
+ name = fullname.removeprefix("test.")
+ r = importlib.import_module(name)
+ # Redirect the module in sys.modules
+ sys.modules[fullname] = r
+ # Return a module spec from the found module
+ return importlib.util.find_spec(name)
+ except ImportError:
+ return None
+ return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
import datetime
import warnings
import weakref
@@ -6,7 +57,7 @@ from test.support import gc_collect
from itertools import product
-class Test_Assertions(unittest.TestCase):
+class Test_Assertions(__TestCase):
def test_AlmostEqual(self):
self.assertAlmostEqual(1.00000001, 1.0)
self.assertNotAlmostEqual(1.0000001, 1.0)
@@ -141,12 +192,13 @@ class Test_Assertions(unittest.TestCase):
self.fail('assertNotRegex should have failed.')
-class TestLongMessage(unittest.TestCase):
+class TestLongMessage(__TestCase):
"""Test that the individual asserts honour longMessage.
This actually tests all the message behaviour for
asserts that use longMessage."""
def setUp(self):
+ super().setUp()
class TestableTestFalse(unittest.TestCase):
longMessage = False
failureException = self.failureException
@@ -414,4 +466,4 @@ class TestLongMessage(unittest.TestCase):
if __name__ == "__main__":
- unittest.main()
+ run_tests()

View File

@ -0,0 +1,469 @@
# ======= BEGIN Dynamo patch =======
# Owner(s): ["module: dynamo"]
# ruff: noqa
# flake8: noqa
import sys
import torch
import torch._dynamo.test_case
import unittest
from torch.testing._internal.common_utils import run_tests
__TestCase = torch._dynamo.test_case.CPythonTestCase
# redirect import statements
import sys
import importlib.abc
redirect_imports = (
"test.mapping_tests",
"test.typinganndata",
"test.test_grammar",
"test.test_math",
"test.test_iter",
"test.typinganndata.ann_module",
)
class RedirectImportFinder(importlib.abc.MetaPathFinder):
def find_spec(self, fullname, path, target=None):
# Check if the import is the problematic one
if fullname in redirect_imports:
try:
# Attempt to import the standalone module
name = fullname.removeprefix("test.")
r = importlib.import_module(name)
# Redirect the module in sys.modules
sys.modules[fullname] = r
# Return a module spec from the found module
return importlib.util.find_spec(name)
except ImportError:
return None
return None
# Add the custom finder to sys.meta_path
sys.meta_path.insert(0, RedirectImportFinder())
# ======= END DYNAMO PATCH =======
import datetime
import warnings
import weakref
import unittest
from test.support import gc_collect
from itertools import product
class Test_Assertions(__TestCase):
def test_AlmostEqual(self):
self.assertAlmostEqual(1.00000001, 1.0)
self.assertNotAlmostEqual(1.0000001, 1.0)
self.assertRaises(self.failureException,
self.assertAlmostEqual, 1.0000001, 1.0)
self.assertRaises(self.failureException,
self.assertNotAlmostEqual, 1.00000001, 1.0)
self.assertAlmostEqual(1.1, 1.0, places=0)
self.assertRaises(self.failureException,
self.assertAlmostEqual, 1.1, 1.0, places=1)
self.assertAlmostEqual(0, .1+.1j, places=0)
self.assertNotAlmostEqual(0, .1+.1j, places=1)
self.assertRaises(self.failureException,
self.assertAlmostEqual, 0, .1+.1j, places=1)
self.assertRaises(self.failureException,
self.assertNotAlmostEqual, 0, .1+.1j, places=0)
self.assertAlmostEqual(float('inf'), float('inf'))
self.assertRaises(self.failureException, self.assertNotAlmostEqual,
float('inf'), float('inf'))
def test_AmostEqualWithDelta(self):
self.assertAlmostEqual(1.1, 1.0, delta=0.5)
self.assertAlmostEqual(1.0, 1.1, delta=0.5)
self.assertNotAlmostEqual(1.1, 1.0, delta=0.05)
self.assertNotAlmostEqual(1.0, 1.1, delta=0.05)
self.assertAlmostEqual(1.0, 1.0, delta=0.5)
self.assertRaises(self.failureException, self.assertNotAlmostEqual,
1.0, 1.0, delta=0.5)
self.assertRaises(self.failureException, self.assertAlmostEqual,
                          1.1, 1.0, delta=0.05)
        self.assertRaises(self.failureException, self.assertNotAlmostEqual,
                          1.1, 1.0, delta=0.5)

        self.assertRaises(TypeError, self.assertAlmostEqual,
                          1.1, 1.0, places=2, delta=2)
        self.assertRaises(TypeError, self.assertNotAlmostEqual,
                          1.1, 1.0, places=2, delta=2)

        first = datetime.datetime.now()
        second = first + datetime.timedelta(seconds=10)
        self.assertAlmostEqual(first, second,
                               delta=datetime.timedelta(seconds=20))
        self.assertNotAlmostEqual(first, second,
                                  delta=datetime.timedelta(seconds=5))

    def test_assertRaises(self):
        def _raise(e):
            raise e
        self.assertRaises(KeyError, _raise, KeyError)
        self.assertRaises(KeyError, _raise, KeyError("key"))
        try:
            self.assertRaises(KeyError, lambda: None)
        except self.failureException as e:
            self.assertIn("KeyError not raised", str(e))
        else:
            self.fail("assertRaises() didn't fail")
        try:
            self.assertRaises(KeyError, _raise, ValueError)
        except ValueError:
            pass
        else:
            self.fail("assertRaises() didn't let exception pass through")
        with self.assertRaises(KeyError) as cm:
            try:
                raise KeyError
            except Exception as e:
                exc = e
                raise
        self.assertIs(cm.exception, exc)

        with self.assertRaises(KeyError):
            raise KeyError("key")
        try:
            with self.assertRaises(KeyError):
                pass
        except self.failureException as e:
            self.assertIn("KeyError not raised", str(e))
        else:
            self.fail("assertRaises() didn't fail")
        try:
            with self.assertRaises(KeyError):
                raise ValueError
        except ValueError:
            pass
        else:
            self.fail("assertRaises() didn't let exception pass through")

    def test_assertRaises_frames_survival(self):
        # Issue #9815: assertRaises should avoid keeping local variables
        # in a traceback alive.
        class A:
            pass
        wr = None

        class Foo(unittest.TestCase):

            def foo(self):
                nonlocal wr
                a = A()
                wr = weakref.ref(a)
                try:
                    raise OSError
                except OSError:
                    raise ValueError

            def test_functional(self):
                self.assertRaises(ValueError, self.foo)

            def test_with(self):
                with self.assertRaises(ValueError):
                    self.foo()

        Foo("test_functional").run()
        gc_collect()  # For PyPy or other GCs.
        self.assertIsNone(wr())

        Foo("test_with").run()
        gc_collect()  # For PyPy or other GCs.
        self.assertIsNone(wr())

    def testAssertNotRegex(self):
        self.assertNotRegex('Ala ma kota', r'r+')
        try:
            self.assertNotRegex('Ala ma kota', r'k.t', 'Message')
        except self.failureException as e:
            self.assertIn('Message', e.args[0])
        else:
            self.fail('assertNotRegex should have failed.')


class TestLongMessage(__TestCase):
    """Test that the individual asserts honour longMessage.
    This actually tests all the message behaviour for
    asserts that use longMessage."""

    def setUp(self):
        super().setUp()

        class TestableTestFalse(unittest.TestCase):
            longMessage = False
            failureException = self.failureException

            def testTest(self):
                pass

        class TestableTestTrue(unittest.TestCase):
            longMessage = True
            failureException = self.failureException

            def testTest(self):
                pass

        self.testableTrue = TestableTestTrue('testTest')
        self.testableFalse = TestableTestFalse('testTest')

    def testDefault(self):
        self.assertTrue(unittest.TestCase.longMessage)

    def test_formatMsg(self):
        self.assertEqual(self.testableFalse._formatMessage(None, "foo"), "foo")
        self.assertEqual(self.testableFalse._formatMessage("foo", "bar"), "foo")

        self.assertEqual(self.testableTrue._formatMessage(None, "foo"), "foo")
        self.assertEqual(self.testableTrue._formatMessage("foo", "bar"), "bar : foo")

        # This blows up if _formatMessage uses string concatenation
        self.testableTrue._formatMessage(object(), 'foo')

    def test_formatMessage_unicode_error(self):
        one = ''.join(chr(i) for i in range(255))
        # this used to cause a UnicodeDecodeError constructing msg
        self.testableTrue._formatMessage(one, '\uFFFD')

    def assertMessages(self, methodName, args, errors):
        """
        Check that methodName(*args) raises the correct error messages.
        errors should be a list of 4 regex that match the error when:
           1) longMessage = False and no msg passed;
           2) longMessage = False and msg passed;
           3) longMessage = True and no msg passed;
           4) longMessage = True and msg passed;
        """
        def getMethod(i):
            useTestableFalse = i < 2
            if useTestableFalse:
                test = self.testableFalse
            else:
                test = self.testableTrue
            return getattr(test, methodName)

        for i, expected_regex in enumerate(errors):
            testMethod = getMethod(i)
            kwargs = {}
            withMsg = i % 2
            if withMsg:
                kwargs = {"msg": "oops"}

            with self.assertRaisesRegex(self.failureException,
                                        expected_regex=expected_regex):
                testMethod(*args, **kwargs)

    def testAssertTrue(self):
        self.assertMessages('assertTrue', (False,),
                            ["^False is not true$", "^oops$", "^False is not true$",
                             "^False is not true : oops$"])

    def testAssertFalse(self):
        self.assertMessages('assertFalse', (True,),
                            ["^True is not false$", "^oops$", "^True is not false$",
                             "^True is not false : oops$"])

    def testNotEqual(self):
        self.assertMessages('assertNotEqual', (1, 1),
                            ["^1 == 1$", "^oops$", "^1 == 1$",
                             "^1 == 1 : oops$"])

    def testAlmostEqual(self):
        self.assertMessages(
            'assertAlmostEqual', (1, 2),
            [r"^1 != 2 within 7 places \(1 difference\)$", "^oops$",
             r"^1 != 2 within 7 places \(1 difference\)$",
             r"^1 != 2 within 7 places \(1 difference\) : oops$"])

    def testNotAlmostEqual(self):
        self.assertMessages('assertNotAlmostEqual', (1, 1),
                            ["^1 == 1 within 7 places$", "^oops$",
                             "^1 == 1 within 7 places$", "^1 == 1 within 7 places : oops$"])

    def test_baseAssertEqual(self):
        self.assertMessages('_baseAssertEqual', (1, 2),
                            ["^1 != 2$", "^oops$", "^1 != 2$", "^1 != 2 : oops$"])

    def testAssertSequenceEqual(self):
        # Error messages are multiline so not testing on full message
        # assertTupleEqual and assertListEqual delegate to this method
        self.assertMessages('assertSequenceEqual', ([], [None]),
                            [r"\+ \[None\]$", "^oops$", r"\+ \[None\]$",
                             r"\+ \[None\] : oops$"])

    def testAssertSetEqual(self):
        self.assertMessages('assertSetEqual', (set(), set([None])),
                            ["None$", "^oops$", "None$",
                             "None : oops$"])

    def testAssertIn(self):
        self.assertMessages('assertIn', (None, []),
                            [r'^None not found in \[\]$', "^oops$",
                             r'^None not found in \[\]$',
                             r'^None not found in \[\] : oops$'])

    def testAssertNotIn(self):
        self.assertMessages('assertNotIn', (None, [None]),
                            [r'^None unexpectedly found in \[None\]$', "^oops$",
                             r'^None unexpectedly found in \[None\]$',
                             r'^None unexpectedly found in \[None\] : oops$'])

    def testAssertDictEqual(self):
        self.assertMessages('assertDictEqual', ({}, {'key': 'value'}),
                            [r"\+ \{'key': 'value'\}$", "^oops$",
                             r"\+ \{'key': 'value'\}$",
                             r"\+ \{'key': 'value'\} : oops$"])

    def testAssertMultiLineEqual(self):
        self.assertMessages('assertMultiLineEqual', ("", "foo"),
                            [r"\+ foo\n$", "^oops$",
                             r"\+ foo\n$",
                             r"\+ foo\n : oops$"])

    def testAssertLess(self):
        self.assertMessages('assertLess', (2, 1),
                            ["^2 not less than 1$", "^oops$",
                             "^2 not less than 1$", "^2 not less than 1 : oops$"])

    def testAssertLessEqual(self):
        self.assertMessages('assertLessEqual', (2, 1),
                            ["^2 not less than or equal to 1$", "^oops$",
                             "^2 not less than or equal to 1$",
                             "^2 not less than or equal to 1 : oops$"])

    def testAssertGreater(self):
        self.assertMessages('assertGreater', (1, 2),
                            ["^1 not greater than 2$", "^oops$",
                             "^1 not greater than 2$",
                             "^1 not greater than 2 : oops$"])

    def testAssertGreaterEqual(self):
        self.assertMessages('assertGreaterEqual', (1, 2),
                            ["^1 not greater than or equal to 2$", "^oops$",
                             "^1 not greater than or equal to 2$",
                             "^1 not greater than or equal to 2 : oops$"])

    def testAssertIsNone(self):
        self.assertMessages('assertIsNone', ('not None',),
                            ["^'not None' is not None$", "^oops$",
                             "^'not None' is not None$",
                             "^'not None' is not None : oops$"])

    def testAssertIsNotNone(self):
        self.assertMessages('assertIsNotNone', (None,),
                            ["^unexpectedly None$", "^oops$",
                             "^unexpectedly None$",
                             "^unexpectedly None : oops$"])

    def testAssertIs(self):
        self.assertMessages('assertIs', (None, 'foo'),
                            ["^None is not 'foo'$", "^oops$",
                             "^None is not 'foo'$",
                             "^None is not 'foo' : oops$"])

    def testAssertIsNot(self):
        self.assertMessages('assertIsNot', (None, None),
                            ["^unexpectedly identical: None$", "^oops$",
                             "^unexpectedly identical: None$",
                             "^unexpectedly identical: None : oops$"])

    def testAssertRegex(self):
        self.assertMessages('assertRegex', ('foo', 'bar'),
                            ["^Regex didn't match:",
                             "^oops$",
                             "^Regex didn't match:",
                             "^Regex didn't match: (.*) : oops$"])

    def testAssertNotRegex(self):
        self.assertMessages('assertNotRegex', ('foo', 'foo'),
                            ["^Regex matched:",
                             "^oops$",
                             "^Regex matched:",
                             "^Regex matched: (.*) : oops$"])

    def assertMessagesCM(self, methodName, args, func, errors):
        """
        Check that the correct error messages are raised while executing:
            with method(*args):
                func()
        *errors* should be a list of 4 regex that match the error when:
           1) longMessage = False and no msg passed;
           2) longMessage = False and msg passed;
           3) longMessage = True and no msg passed;
           4) longMessage = True and msg passed;
        """
        p = product((self.testableFalse, self.testableTrue),
                    ({}, {"msg": "oops"}))
        for (cls, kwargs), err in zip(p, errors):
            method = getattr(cls, methodName)
            with self.assertRaisesRegex(cls.failureException, err):
                with method(*args, **kwargs) as cm:
                    func()

    def testAssertRaises(self):
        self.assertMessagesCM('assertRaises', (TypeError,), lambda: None,
                              ['^TypeError not raised$', '^oops$',
                               '^TypeError not raised$',
                               '^TypeError not raised : oops$'])

    def testAssertRaisesRegex(self):
        # test error not raised
        self.assertMessagesCM('assertRaisesRegex', (TypeError, 'unused regex'),
                              lambda: None,
                              ['^TypeError not raised$', '^oops$',
                               '^TypeError not raised$',
                               '^TypeError not raised : oops$'])

        # test error raised but with wrong message
        def raise_wrong_message():
            raise TypeError('foo')
        self.assertMessagesCM('assertRaisesRegex', (TypeError, 'regex'),
                              raise_wrong_message,
                              ['^"regex" does not match "foo"$', '^oops$',
                               '^"regex" does not match "foo"$',
                               '^"regex" does not match "foo" : oops$'])

    def testAssertWarns(self):
        self.assertMessagesCM('assertWarns', (UserWarning,), lambda: None,
                              ['^UserWarning not triggered$', '^oops$',
                               '^UserWarning not triggered$',
                               '^UserWarning not triggered : oops$'])

    def test_assertNotWarns(self):
        def warn_future():
            warnings.warn('xyz', FutureWarning, stacklevel=2)
        self.assertMessagesCM('_assertNotWarns', (FutureWarning,),
                              warn_future,
                              ['^FutureWarning triggered$',
                               '^oops$',
                               '^FutureWarning triggered$',
                               '^FutureWarning triggered : oops$'])

    def testAssertWarnsRegex(self):
        # test error not raised
        self.assertMessagesCM('assertWarnsRegex', (UserWarning, 'unused regex'),
                              lambda: None,
                              ['^UserWarning not triggered$', '^oops$',
                               '^UserWarning not triggered$',
                               '^UserWarning not triggered : oops$'])

        # test warning raised but with wrong message
        def raise_wrong_message():
            warnings.warn('foo')
        self.assertMessagesCM('assertWarnsRegex', (UserWarning, 'regex'),
                              raise_wrong_message,
                              ['^"regex" does not match "foo"$', '^oops$',
                               '^"regex" does not match "foo"$',
                               '^"regex" does not match "foo" : oops$'])


if __name__ == "__main__":
    run_tests()
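
The TestLongMessage cases above pin down how longMessage interacts with an explicit msg argument in unittest. As a quick orientation, here is a minimal sketch using plain stdlib unittest (the class names are hypothetical and not part of this PR) that reproduces the two message formats those regexes expect:

import unittest

class Quiet(unittest.TestCase):
    longMessage = False          # an explicit msg replaces the standard message

    def test(self):
        self.assertEqual(1, 2, msg="oops")   # fails with just "oops"

class Verbose(unittest.TestCase):
    longMessage = True           # default: standard message + " : " + msg

    def test(self):
        self.assertEqual(1, 2, msg="oops")   # fails with "1 != 2 : oops"

if __name__ == "__main__":
    unittest.main(exit=False)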


@@ -130,14 +130,17 @@ class _multiply_invoke(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, s69: "Sym(s21)"):
def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]"):
l_inputs_ = L_inputs_
l_sizes_0_ = L_sizes_0_
getitem: "f32[s21]" = l_inputs_[0]
getitem_1: "f32[s21]" = l_inputs_[1]
getitem_2: "f32[s21]" = l_inputs_[2]; l_inputs_ = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [s69], False)]); getitem = s69 = None
size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
@@ -160,14 +163,17 @@ class GraphModule(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, s69: "Sym(s21)"):
def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]"):
l_inputs_ = L_inputs_
l_sizes_0_ = L_sizes_0_
getitem: "f32[s21]" = l_inputs_[0]
getitem_1: "f32[s21]" = l_inputs_[1]
getitem_2: "f32[s21]" = l_inputs_[2]; l_inputs_ = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [s69], False)]); getitem = s69 = None
size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
@@ -242,15 +248,18 @@ class GraphModule(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s45)"):
def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s45)"):
l_inputs_ = L_inputs_
l_sizes_0_ = L_sizes_0_
l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter
getitem: "f32[s21]" = l_inputs_[0]
getitem_1: "f32[s21]" = l_inputs_[1]
getitem_2: "f32[s21]" = l_inputs_[2]; l_inputs_ = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [s69], False)]); getitem = s69 = None
size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
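
The three expected-graph updates above follow one pattern: the free SymInt input s69 is replaced by a size read off the new tensor input L_sizes_0_ via .size(1). A rough, hedged illustration of that idea using only the public torch.compile API (the function, names, and shapes below are made up for illustration and are not from this PR):

import torch

@torch.compile(dynamic=True, fullgraph=True)
def f(sizes, x):
    n = sizes.size(1)        # the dimension is captured as a SymInt in the graph
    return x.reshape(n, -1)  # downstream ops consume the symbolic size

f(torch.empty(2, 6), torch.randn(6, 2))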


@@ -474,13 +474,13 @@ class GraphModule(torch.nn.Module):
return invoke_quant_test(inner, x, y, scheme="nf4")
with self.assertRaisesRegex(
RuntimeError, "Encountered aliasing during higher order op tracing for HOP"
RuntimeError, "Encountered aliasing during higher order op tracing"
):
f(inner, x, y)
with self.assertRaisesRegex(
RuntimeError,
"Encountered input mutation during higher order op tracing for HOP",
"Encountered input mutation during higher order op tracing",
):
f(inner2, x, y)


@@ -16,6 +16,7 @@ import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.bytecode_transformation import transform_code_object
from torch._dynamo.exc import PackageError
from torch._dynamo.guards import CheckFunctionManager, CompileId
from torch._dynamo.symbolic_convert import (
ExceptionStack,
@@ -235,6 +236,15 @@ pytree.register_constant(CustomConstantType)
class TestGuardSerialization(torch._inductor.test_case.TestCase):
def test_function_locals(self):
def foo(x):
return x + 1
def fn(x, g):
return g(x) + 1
self._test_serialization("TENSOR_MATCH", fn, torch.randn(3), foo)
def _tracefunc(self, frame, event, arg):
if event != "call":
return
@@ -481,7 +491,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
# === example subclass defined locally (error) ===
local_sub = LocalSubclass(torch.randn(3))
with self.assertRaisesRegex(
RuntimeError, "Please define the class at global scope"
PackageError, "Please define the class at global scope"
):
self._test_serialization("TENSOR_SUBCLASS_METADATA_MATCH", fn, local_sub)
@@ -646,7 +656,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
# we don't support NN_MODULE because it adds an ID_MATCH guard, and we don't
# support that in serialization
with self.assertRaisesRegex(
RuntimeError, "NN_MODULE guard cannot be serialized."
PackageError, "NN_MODULE guard cannot be serialized."
):
self._test_serialization("NN_MODULE", fn, m, x)
@@ -662,7 +672,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
# we don't support FUNCTION_MATCH because it adds an ID_MATCH guard, and we don't
# support that in serialization
with self.assertRaisesRegex(
RuntimeError, "FUNCTION_MATCH guard cannot be serialized."
PackageError, "FUNCTION_MATCH guard cannot be serialized."
):
self._test_serialization("FUNCTION_MATCH", fn, x)
@@ -676,7 +686,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
# we don't support CLOSURE_MATCH because it adds a FUNCTION_MATCH guard, and we don't
# support that in serialization
with self.assertRaisesRegex(
RuntimeError, "CLOSURE_MATCH guard cannot be serialized."
PackageError, "CLOSURE_MATCH guard cannot be serialized."
):
self._test_serialization("CLOSURE_MATCH", fn, x)
@@ -795,7 +805,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
return pytree.tree_leaves(x)[0] + 1
with self.assertRaisesRegex(
RuntimeError, "DICT_VERSION guard cannot be serialized."
PackageError, "DICT_VERSION guard cannot be serialized."
):
self._test_serialization("DICT_VERSION", fn, {"t": torch.randn(3)})
@@ -847,7 +857,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
return x + id(x)
with self.assertRaisesRegex(
RuntimeError, "ID_MATCH guard cannot be serialized."
PackageError, "ID_MATCH guard cannot be serialized."
):
self._test_serialization("ID_MATCH", fn, torch.randn(3))
@@ -1023,7 +1033,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
x = torch.randn(3, 2)
with self.assertRaisesRegex(
RuntimeError, "DUPLICATE_INPUT guard cannot be serialized"
PackageError, "DUPLICATE_INPUT guard cannot be serialized"
):
self._test_serialization("DUPLICATE_INPUT", fn, x, x)
@@ -1040,7 +1050,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
return params[0].sum()
with self.assertRaisesRegex(
RuntimeError, "WEAKREF_ALIVE guard cannot be serialized"
PackageError, "WEAKREF_ALIVE guard cannot be serialized"
):
with torch.set_grad_enabled(False):
self._test_serialization("WEAKREF_ALIVE", fn)
@@ -1159,7 +1169,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
with torch._C.DisableTorchFunction():
self._test_check_fn(ref, loaded, {"x": x}, False)
with self.assertRaisesRegex(
RuntimeError,
PackageError,
"defined in local scope. Please define the class at global scope",
):
with LocalTorchFunctionMode():
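
The hunks above change the expected exception for unserializable guards from RuntimeError to PackageError, imported from torch._dynamo.exc as shown at the top of the file. A minimal, hedged sketch of the assertion pattern the updated tests rely on (the message string is copied from the tests; everything else here is illustrative, not from the PR):

import unittest
from torch._dynamo.exc import PackageError

class GuardSerializationMessageExample(unittest.TestCase):
    def test_unserializable_guard_message(self):
        # The updated tests assert on the new exception type plus message text.
        with self.assertRaisesRegex(PackageError, "cannot be serialized"):
            raise PackageError("ID_MATCH guard cannot be serialized.")

if __name__ == "__main__":
    unittest.main(exit=False)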


@@ -3,6 +3,7 @@
import contextlib
import importlib.util
import os
import re
import tempfile
import torch._dynamo.config
@@ -54,6 +55,104 @@ class PgoTest(torch._dynamo.test_case.TestCase):
f(torch.randn(2, 6))
self.assertEqual(cnts.frame_count, 1)
def test_whitelist_suggestion(self):
cnts = CompileCounter()
@torch.compile(backend=cnts, fullgraph=True)
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(4, 4)
self.attr = torch.randn(4)
def forward(self, x, y):
return self.lin(x) + self.attr + y
sources = [
"L['x']",
"L['self']._modules['lin']._parameters['weight']",
"L['self']._modules['lin']._parameters['bias']",
"L['self'].attr",
"L['y']",
]
def check_whitelist(sources_):
state = torch._dynamo.pgo.render_code_state(
torch._dynamo.pgo.get_code_state()
)
whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(
1
)
for src in sources_:
self.assertTrue(src in whitelist)
# check growing whitelist
f = Foo()
f(torch.randn(2, 4), torch.randn(4))
# only x
f(torch.randn(4, 4), torch.randn(4))
check_whitelist(sources[:1])
# x, lin.weight
f.lin = torch.nn.Linear(8, 4)
f(torch.randn(8, 8), torch.randn(4))
check_whitelist(sources[:2])
# x, y, lin.weight, lin.bias, attr
f.lin = torch.nn.Linear(8, 8)
f.attr = torch.randn(8)
f(torch.randn(8, 8), torch.randn(8))
check_whitelist(sources)
# now use suggested whitelist
self.reset()
cnts.clear()
state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state())
whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(1)
with torch.compiler.config.patch(dynamic_sources=whitelist):
f = Foo()
f(torch.randn(2, 4), torch.randn(4))
f(torch.randn(4, 4), torch.randn(4))
f.lin = torch.nn.Linear(8, 8)
f.attr = torch.randn(8)
f(torch.randn(8, 8), torch.randn(8))
self.assertEqual(cnts.frame_count, 1)
def test_pgo_dynamic_params(self):
cnts = CompileCounter()
@torch.compile(backend=cnts, fullgraph=True)
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = None
def forward(self, x):
return self.lin(x)
f = Foo()
def run():
self.reset()
cnts.clear()
f.lin = torch.nn.Linear(4, 4)
f(torch.randn(2, 4))
f(torch.randn(4, 4))
f.lin = torch.nn.Linear(8, 8)
f(torch.randn(8, 8))
# recompile each run
run()
self.assertEqual(cnts.frame_count, 3)
# parameter static shapes are forced static, so we recompile once
run()
self.assertEqual(cnts.frame_count, 2)
# flags are flipped, PGO records dynamism, so params are dynamically compiled to start
torch._dynamo.config.force_parameter_static_shapes = False
torch._dynamo.config.force_nn_module_property_static_shapes = False
run()
self.assertEqual(cnts.frame_count, 1)
def test_njt(self):
cnts = CompileCounter()
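
test_whitelist_suggestion above shows the round trip: PGO renders a TORCH_COMPILE_DYNAMIC_SOURCES="..." suggestion from the recorded code state, and that string can be fed back to the compiler. A minimal, hedged sketch of consuming such a suggestion (the source list is hypothetical; torch.compiler.config.patch is used exactly as in the test, and the same string can instead be exported as the TORCH_COMPILE_DYNAMIC_SOURCES environment variable before the process starts):

import torch

suggested = "L['x'],L['y']"  # hypothetical whitelist copied from the rendered code state

with torch.compiler.config.patch(dynamic_sources=suggested):
    compiled = torch.compile(lambda x, y: x + y, fullgraph=True)
    compiled(torch.randn(2, 4), torch.randn(4))  # whitelisted dims are dynamic up front
    compiled(torch.randn(4, 4), torch.randn(4))  # so a new size of x should not recompile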


@@ -3226,6 +3226,25 @@ class GraphModule(torch.nn.Module):
lengths = torch.tensor([2, 4, 3])
self._validate_compile(fn, arg_fn=lambda: (values, lengths))
def test_in_graph_construction_from_input_6(self):
# Construct with symbolic int.
def fn(values, offsets, max_seqlen):
t = torch.nested.nested_tensor_from_jagged(
values, offsets, max_seqlen=max_seqlen
)
return torch.nested.nested_tensor_from_jagged(
values, t.offsets(), max_seqlen=t._maybe_max_seqlen
)
opt_fn = torch.compile(fn, fullgraph=True, dynamic=True)
values = torch.randn(10, 5)
offsets = torch.tensor([0, 2, 4, 7, 10])
max_seqlen = 5
ref = fn(values, offsets, max_seqlen)
res = opt_fn(values, offsets, max_seqlen)
self.assertEqualIgnoringNestedInts(ref, res)
#
# Case 2: in-graph construction where offsets are graph intermediates
#


@@ -1,5 +1,6 @@
# Owner(s): ["module: dynamo"]
import dataclasses
import os
import pprint
import sys
from unittest import mock
@@ -141,6 +142,69 @@ class TestUtils(TestCase):
compilation_events = [arg[0][0] for arg in log_event.call_args_list]
self.assertEqual(compilation_events[-1].num_graph_breaks, 2)
def test_frame_traced_hook(self):
from utils import add, break_it
traced_code_lists = []
def get_traced_code(s):
nonlocal traced_code_lists
traced_code_lists.append(s)
def get_filenames(traced_code_lists):
return [
[code.co_filename for code in code_list]
for code_list in traced_code_lists
]
utils_path = os.path.join(os.path.dirname(__file__), "utils.py")
# === no inlining ===
@torch.compile(options={"frame_traced_fn": get_traced_code})
def fn(x):
return x * 2
x = torch.randn(3)
traced_code_lists = []
fn(x)
# expect hook to be called once with this file
self.assertEqual(get_filenames(traced_code_lists), [[__file__]])
# === successful inlining ===
@torch.compile(options={"frame_traced_fn": get_traced_code})
def fn(x):
return add(x) * 2
x = torch.randn(3)
traced_code_lists = []
fn(x)
utils_path = os.path.join(os.path.dirname(__file__), "utils.py")
# expect hook to be called once, with both this file and the file of the inlined func
self.assertEqual(get_filenames(traced_code_lists), [[utils_path, __file__]])
# === graph break occurs during inlining ===
@torch.compile(options={"frame_traced_fn": get_traced_code})
def fn(x):
y = break_it(x)
return y * 2
x = torch.randn(3)
traced_code_lists = []
fn(x)
# expect hook to be called twice: once for this file, once for the file of the inlined func
self.assertEqual(get_filenames(traced_code_lists), [[__file__], [utils_path]])
# === empty graph ===
@torch.compile(options={"frame_traced_fn": get_traced_code})
def fn(x):
return x
x = torch.randn(3)
traced_code_lists = []
fn(x)
# hook is not expected to be called at all for an empty graph
self.assertEqual(traced_code_lists, [])
class TestModel(torch.nn.Module):
def __init__(self):
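
The new test_frame_traced_hook above exercises a frame_traced_fn compile option: a hook that is called per traced frame with the list of code objects Dynamo captured. A minimal, hedged usage sketch built only from what the test shows (the option name and hook signature); anything beyond that is an assumption:

import torch

def log_traced_files(code_objects):
    # Receives the code objects traced for one compiled frame.
    print(sorted({code.co_filename for code in code_objects}))

@torch.compile(options={"frame_traced_fn": log_traced_files})
def double(x):
    return x * 2

double(torch.randn(3))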


@@ -39,6 +39,10 @@ def add(x):
return x + 1
def break_it(x):
return x.sum().item()
def create_dummy_module_and_function():
module = types.ModuleType("dummy_module")
module.__spec__ = importlib.machinery.ModuleSpec(

Some files were not shown because too many files have changed in this diff.