Compare commits

..

166 Commits

Author SHA1 Message Date
a38b1ee364 Merge branch 'main' into adi/onednn_aarch64 2025-11-06 10:31:51 +00:00
a51208c656 Check cluster_dims attribute exists before access (#167187)
Error in Helion CI's AMD job: https://github.com/pytorch/helion/actions/runs/19118581048/job/54633730633
```
>                   (binary.metadata.num_ctas, *binary.metadata.cluster_dims)
                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    if hasattr(binary, "metadata")
                    else ()
                )
            ),
            "function": get_first_attr(binary, "function", "cu_function"),
            "runner": get_first_attr(binary, "run", "c_wrapper"),
            "math": math_lib,
            "torch": torch_lib,
            "triton": triton_lib,
        }
E       torch._inductor.exc.InductorError: AttributeError: 'KernelMetadata' object has no attribute 'cluster_dims'
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167187
Approved by: https://github.com/oulgen
2025-11-06 08:02:57 +00:00
ed4aa449b6 CustomOp Inline Fusion (#165952)
Add Inline Fusion Support for Custom Op Autotuning
--------------------------------------------------

This PR extends PyTorch Inductor's custom op autotuning with inline fusion capabilities, enabling the winning decomposition to be inlined directly into the computation graph for fusion with surrounding operations.

### Usage

```python

def decompose_k_implementation(
    a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
) -> torch.Tensor:
    """Matrix multiply with k-way decomposition."""
    ...

@torch.library.custom_op("my_lib::matmul_relu", mutates_args={})
def custom_matmul_relu_dk(
    a: torch.Tensor, b: torch.Tensor, k_splits: int
) -> torch.Tensor:
    return torch.relu(decompose_k_implementation(a, b, k_splits))

register_custom_op_autotuning(
    custom_op=custom_matmul_relu_dk,
    configs=[
        CustomOpConfig(k_splits=2),
        CustomOpConfig(k_splits=4),
        CustomOpConfig(k_splits=8),
        CustomOpConfig(k_splits=32),
        CustomOpConfig(k_splits=64),
    ],
    name="decompose_k_autotuned",
    input_gen_fns={
        "a": lambda fake: torch.randn_like(fake, device='cuda'),
        "b": lambda fake: torch.randn_like(fake, device='cuda'),
    }
)
```

### How It Works
Enable optimizations from Inductor by inlining the best decomposition, allowing fusion with surrounding elementwise operations and other graph-level optimizations. This provide potentially better performance and memory efficiency.
During customop autotuning phase, we still benchmarks all CustomOpConfigs to find the fastest implementation. Then during inline fusion, inductor inline the decompositions into the main graph, converting the winning choice to individual ComputedBuffer IR nodes (fusable). At the end, Inductor automatically fuses inlined operations with surrounding elementwise ops (e.g., bias add, ReLU, scaling). Note that the winning choice must be a SubgraphChoiceCaller (decomposition-based) rather than an ExternKernelChoice for inlining to work. If the ExternKernelChoice is returned, no inline happens.

Performance Results
Benchmarked on matmul+relu workload with decompose-k fusion (H100 GPU, 15 test shapes):
<img width="782" height="377" alt="Screenshot 2025-11-04 at 12 43 11 AM" src="https://github.com/user-attachments/assets/22131d4c-a8ce-4f55-bdcd-ac758ddad8cd" />

Metric | Result
-- | --
Average Speedup vs ATen | 1.28x
Max Speedup vs ATen | 1.41x

<br class="Apple-interchange-newline">

The performance comparison are detailed in the below plots. We spot that on most use cases, the inline fusion gains better performance compared to aten baseline and the current torch.compile.
<img width="4874" height="3545" alt="image" src="https://github.com/user-attachments/assets/190a1233-412f-4f34-84cd-9b7cb582f504" />

**Test**: `test_decompose_k_with_fusion` demonstrates decompose-k with inline fusion enabled.

--------------

### Integration to mm.py decomposeK with a flag enable_inline_subgraph_fusion=True in config (deprecated to avoid breaking async compilation. removed from the PR already)
FP32:
<img width="738" height="357" alt="Screenshot 2025-11-04 at 12 05 08 AM" src="https://github.com/user-attachments/assets/ee421d22-c426-42f2-8dcd-4dcc547d6219" />
FP16:
<img width="769" height="403" alt="Screenshot 2025-11-04 at 12 13 49 AM" src="https://github.com/user-attachments/assets/346d1ffc-15af-40b0-9378-cf9b297711c2" />

The TCF column represents torch compile fusion, which is close to custom_op decomposek. The difference might due to different candidate k values.

#### Usage:
Note: this only happens when we don't benchmark_epilogue_fusion, i.e., not using multi_template_buffer.

```python
# Define the matmul+relu function
    def matmul_relu(x, y):
        return torch.nn.functional.relu(torch.matmul(x, y))

    # Compile with inline subgraph fusion enabled
    @torch.compile
    def compiled_matmul_relu(x, y):
        return matmul_relu(x, y)

    # Reset dynamo to ensure clean compilation
    torch._dynamo.reset()

    with config.patch(
        {
            "max_autotune": True,
            # CRITICAL: These two flags enable inline subgraph fusion
            "benchmark_epilogue_fusion": False,  # Must be False for inline fusion!
            "enable_inline_subgraph_fusion": True,  # Enable inline fusion
        }
    ):
        # Compile and run
        result = compiled_matmul_relu(a, b)
        torch.cuda.synchronize()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165952
Approved by: https://github.com/PaulZhang12, https://github.com/eellison
2025-11-06 06:59:10 +00:00
9eebda944d make narrow_tensor_symint DDE-free (#166379)
https://github.com/pytorch/pytorch/issues/158081

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166379
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #166361
2025-11-06 06:09:22 +00:00
09d8953fb4 Update tensorpipe submodule (#167108)
To pick a single change 2b4cd91092 that should fix compilation errors with clang-21
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167108
Approved by: https://github.com/Skylion007
2025-11-06 06:08:13 +00:00
8b2365094d Expose torch.compiler.config.force_disable_caches as a public API (#166699)
Exposing this flag as some upstream frameworks (like vLLM) could benefit from knowing whether torch.compile caches are enabled or not to adjust their own caching behavior.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166699
Approved by: https://github.com/oulgen, https://github.com/mlazos
2025-11-06 05:59:05 +00:00
7b423c2d21 [user-streams] Mark stream ops as side effectful (#167152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167152
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167141, #167151
2025-11-06 05:03:18 +00:00
46b3f913b3 [user-streams] Add record/wait ops (#167151)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167151
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167141
2025-11-06 05:03:18 +00:00
f7b7f40a6f [user-streams] Enable stream ops to work in eager (#167141)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167141
Approved by: https://github.com/Lucaskabela
2025-11-06 05:03:18 +00:00
91337ae3ff [audio hash update] update the pinned audio hash (#167031)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167031
Approved by: https://github.com/pytorchbot
2025-11-06 04:57:05 +00:00
eea951758f [dynamo, 3.14] disable dynamo cpython tests in 3.14 (again) (#167000)
The previous PR was not enough to prevent errors caused by cpython dynamo tests in 3.14
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167000
Approved by: https://github.com/mlazos, https://github.com/guilhermeleobas
2025-11-06 04:34:33 +00:00
3feea296a5 torch.fx: add debug-level logging to Interpreter.run_node (#117351) (#166622)
### Summary
Adds a debug-level logging statement to torch.fx.Interpreter.run_node, as proposed in [#117351](https://github.com/pytorch/pytorch/issues/117351), to make FX graph execution traceable when debugging or instrumenting model transformations.

When debug logging is enabled, each executed node emits a single structured log line formatted via `LazyString(lambda: n.format_node())`, deferring string construction unless logging is active.

### Example Output
With `logging.DEBUG` enabled:

```
run_node x = x()
run_node add = _operator.add(x, 1)
run_node clamp = torch.clamp(add, min=0.0, max=5.0)
run_node output = output(clamp)
```

With `logging.DEBUG` disabled no additional output is produced (unchanged default behavior).

### Test Plan

Verified locally with Python 3.11 on macOS using a PyTorch build from source.

- With `logging.DEBUG` enabled: each node emits a debug log via LazyString.
- With `logging.DEBUG` disabled: no additional output.
- Confirmed all `Interpreter` tests pass locally:
`pytest test/test_fx.py -k "Interpreter"`

Updated the example output to reflect the new `_format_fx_node` helper and inclusion of `kwargs`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166622
Approved by: https://github.com/aorenste
2025-11-06 04:33:09 +00:00
c3c3653418 [1/N] Add return types of Python functions (#167162)
This PR adds return types of some Python functions. Most of them return `None`. The types were added automatically by ruff `ANN` rules.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167162
Approved by: https://github.com/Lucaskabela
2025-11-06 04:32:14 +00:00
f72772b184 [PP] make runtime dbg log print custom actions (#167113)
Previously the log only printed if the default implementation for an
action was used, now it prints before dispatching to custom registered
actions.

Tested by running on autoparallel graph runner and observing forward
pass action logged

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167113
Approved by: https://github.com/sanketpurandare, https://github.com/Skylion007
2025-11-06 04:20:50 +00:00
981dd71893 Refactor: extract OperatorArgsKwargsView from parseIValuesToPyArgsKwargs (#166368)
Intended to make it easier to reuse this logic for processing operator arguments as IValues in following PR(s).

Testing: python test/test_python_dispatch.py (broke during development, seems to work now)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166368
Approved by: https://github.com/albanD
2025-11-06 04:18:54 +00:00
d31599f40b [7/N] Fix unused loop variables in tests (#167043)
This PR continues to fix or remove unused loop variables in tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167043
Approved by: https://github.com/Lucaskabela
2025-11-06 03:36:59 +00:00
85fab6c9b0 Fix duplicate benchmarking entries for addmm (#166652)
There have been duplicate entries for addmm in dashboard. This PR fixes the duplicate entries issues
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166652
Approved by: https://github.com/yangw-dev
2025-11-06 03:25:03 +00:00
c08ce30d18 [ci][cpu] Update compiler to GCC-13 in jammy-aarch64 (#166849)
This is needed because manylinux uses GCC-13 since #152825
As a result of the current compiler version mismatches, we've seen tests passing jammy-aarch64 pre-commit CI, but failing for wheels built in manylinux
Related to: #166736

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166849
Approved by: https://github.com/robert-hardwick, https://github.com/malfet, https://github.com/Skylion007, https://github.com/atalman
2025-11-06 03:14:16 +00:00
e1a1aeaf5b [1/N] Use key in dict for existence checks (#167035)
This PR uses `key in dict` expressions for existence checks of dict elements in Python code. This operation is more efficient than `key in dict.keys()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167035
Approved by: https://github.com/janeyx99
2025-11-06 02:25:10 +00:00
943227f57b [c10d] Fix split_group bug by having the parent pg option deep copied (#167125)
Summary: Inside group_split api, we share the reference of PG option with parent PG if a PG option is not explicitly specified. This is bad because if we split parent pg multiple times, we will run into errors.

Test Plan: UT + internal test.

Differential Revision: D86225394

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167125
Approved by: https://github.com/Skylion007
2025-11-06 02:08:05 +00:00
3a2d75a086 Change template 'Release highlight for proposed Feature'->'New Feature for Release' (#167145)
Makes it simpler and more clear

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167145
Approved by: https://github.com/huydhn
2025-11-06 02:01:57 +00:00
69af74972b Bugfix to forward autodiff causing different datatype 2 (#165784)
Fixes #160513

## The Problem Summary
The issue boiled down to data type promotion logic. The code base has two different functions that deal with dtype promotion logic. If it is purely multi-dimensional tensor operations, the cpp code gets triggered and that follows the numpy dtype promotion logic.  That is why in #160513 NDim tensors are fine as NDim dtypes gets precedence.  The issue came with python scalars and 0Dim tensors. When it detects "scalars", a python implementation of dtype promotion logic gets triggered (torch/_prims_common/__init__.py:1544). Since this is in python, the implementation can't distinguish what is from a wrapped tensor and a 0Dim tensor and thus will just take the highest dtype which is the python double wrapped number.

## The Fix
The python implementation for dtype promotion had to know where the scalar came from. Once the scalar can be distinguished then the appropriate dtype can be set. The first approach was to try and expose the `is_wrapped_number` method but this came with a big issue.  During the `forward_ad` the derivative of those scalars turned out to be `ZeroTensor`s.  The `ZeroTensor` internally uses a hack to initialize a meta dtype tensor which skips expensive dispatch operations. But the copy would not grab everything especially the `is_number_wrapped_` property.  I thought about modifying the copy but that seemed to go away from the spirit of what the copy was intended for and plus the tests for `is_wrapped_number_` requires `dim > 0` and a scalar `ZeroTensor` is a meta dtype tensor which complicates things.

So I chose the route of creating a new property called `was_wrapped_number` and exposed this property to the python tensor API. I had to modify the autograd code generation to set `was_wrapped_number` in the mul, add, and div operations in  `VariableType.cpp`.  Once this property was set, the dtype promotion logic could be updated to consider wrapped numbers and 0Dim numbers. Once that hierarchy was taken care of, the buggy behavior was fixed.

I wrote a new ops testing module `TestForwardADWithScalars`.  I saw that this bug was unique and required new testing paradigm. This only tests the multiply, add, and divide and I chose this because all operations boil down to these three operations.

[edit]: Just used `efficientzerotensor` meta and converted that to a python number. Since wrapped number is converted back to a python number, dtype promotion is preserved.  The constraint to achieve this happened by setting the forward grad zero tensor of a wrapped number with a wrapped number flag since the tangent of the wrapped number should still be a wrapped number. After that this specific zerotensor was then sent through as a meta type in the `BinaryOps.cpp` to get appropriate dtype for resulting arithmetic.

@ezyang @OihanJoyot

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165784
Approved by: https://github.com/ezyang
2025-11-06 01:59:53 +00:00
7432676187 [MPS] Fix crash in BCELoss backwards with reduction="none" and inputs with trailing 1s in shape (#166786)
Fixes #166746 by removing squeezes that caused shape mismatches when calling backwards through `BCELoss(reduction='none')`.

Based on running these tests, it seems MPSGraph can handle inputs without squeezing.
```
python test/test_mps.py TestMPS -k test_bce
python test/test_mps.py TestConsistency -k binary_cross
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166786
Approved by: https://github.com/malfet
2025-11-06 01:55:38 +00:00
fd5edda1ed Reland "Add model code stack trace to torch.profile (#166677)" (#167110)
```python
python test/test_fx.py -k profiler
```

Insert `torch._C._profiler._RecordFunctionFast` to fx graph codegen.

We post-process the profiler dump using `map_recorded_events_to_aten_ops_with_stack_trace` to add the stack trace to the dump'd trace.

`map_recorded_events_to_aten_ops_with_stack_trace` queries `fx.traceback._FX_METADATA_REGISTRY` for node metadata. Each graph module has a hash'd fake file name (e.g. `fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py`), which is the key to the registry.

One can do `fx_g.enrich_profiler_metadata()` to add debugging info. Or `fx_g.enrich_profiler_metadata(enable=False)` to remove.

`aot_eager` makes calls `fx_g.enrich_profiler_metadata()` if TORCH_ENRICH_RPOFILER_STACK_TRACE is set or _dynamo.config.enrich_profiler_metadata=True.

<img width="1188" height="565" alt="Screenshot 2025-10-31 at 4 40 52 PM" src="https://github.com/user-attachments/assets/41e8113f-3e6d-439b-bffd-cfbf0c03a47a" />

Example code gen'd.
```
def forward(self, args_list):
    args_iter = iter(args_list)
    arg0_1 = next(args_iter)
    arg1_1 = next(args_iter)
    args_list.clear()
    _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py ##'); _rf.__enter__()
    repeated_subgraph0 = self.repeated_subgraph0
    _rf_invoke_subgraph = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_invoke_subgraph.__enter__()
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1);  repeated_subgraph0 = arg0_1 = arg1_1 = None
    _rf_invoke_subgraph.__exit__(None, None, None)
    _rf_getitem = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_getitem.__enter__()
    getitem = invoke_subgraph[0];  invoke_subgraph = None
    _rf_getitem.__exit__(None, None, None)
    return (getitem,)
    _rf.__exit__(None, None, None)

def forward(self, arg0_1, arg1_1):
    _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__ozpadpj5cxoalxeyopej33g2vvtvhxg4xsk7bhx7ldmcibtybyn.py ##'); _rf.__enter__()
    _rf_mul = torch._C._profiler._RecordFunctionFast('## 2 ##'); _rf_mul.__enter__()
    mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    _rf_mul.__exit__(None, None, None)
    _rf_sin = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_sin.__enter__()
    sin = torch.ops.aten.sin.default(mul);  mul = None
    _rf_sin.__exit__(None, None, None)
    _rf_add = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_add.__enter__()
    add = torch.ops.aten.add.Tensor(sin, 5);  sin = None
    _rf_add.__exit__(None, None, None)
    return (add,)
    _rf.__exit__(None, None, None)

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167110
Approved by: https://github.com/pianpwk
2025-11-06 01:14:27 +00:00
872d1daec2 Avoid DDE in narrow with unbacked start (#166361)
Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice.
The only edge case is when start<0 and start+length ==0 in that case slice and narrow would deviate,
for that case we shall pass dim_size instead of start+length

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361
Approved by: https://github.com/aorenste
2025-11-06 01:04:19 +00:00
eqy
6cd57e6fc2 [cuBLAS] Force tensor-core-no-reduction algo in cuBLASLt for n=1 cases (#166735)
Ostensibly useful for batch-invariance purposes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166735
Approved by: https://github.com/ngimel
2025-11-06 00:50:42 +00:00
d29efba8fa Move almalinux docker image to DEVTOOLSET 13 (#167018)
1. Update general Almalinux image to Devtoolset 13.
2. Fix ROCm images, missing devtoolset-13
This image used by Linux Job in test-infra
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167018
Approved by: https://github.com/sudharssun, https://github.com/d4l3k
2025-11-06 00:34:40 +00:00
a344069f2a Add missing skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION) to test/test_transformers.py (#166969)
This PR adds missing skips for efficient attention tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166969
Approved by: https://github.com/jeffdaily
2025-11-05 23:16:51 +00:00
af829c0dad [ROCm] Skip nvfp4 tests on ROCm (#167066)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167066
Approved by: https://github.com/jeffdaily, https://github.com/slayton58
2025-11-05 23:15:17 +00:00
3869aa115b fix fr reset api (#166970)
Summary:
- there are various places that access fr's `entries_` field
- if we empty the entries_ on reset, the accesses can result in an error
- so we only perform a soft delete instead of clearing out the entries copletely
  - only reset id_ on the reset
  - keep track of a reset_epoch which increments everytime reset is called
  - dump_entries only returns entries from the latest epoch
  - api's that access entries also check if the reset epoch matches
- make the `next_` always track the index in the circular buffer - this change was needed to make the soft delete's implementation easier

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/166970).
* #166972
* #166971
* __->__ #166970

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166970
Approved by: https://github.com/fduwjj
2025-11-05 23:06:00 +00:00
47eb34b7ac [ATEN][CUDA] Reduce register pressure in radix_sort_pairs to improve torch.sort performance (#167094)
# Summary
This PR improves `torch.sort` and `torch.unique` performance by **15% to 50%** on NVIDIA GPUs by optimizing CUDA register allocation in radix sort operations.

The key change: specialize `OpaqueType<N>` to use native integer types (uint8_t, uint16_t, uint32_t, uint64_t) for common sizes (1, 2, 4, 8 bytes) instead of `char data[N]`. This enables more efficient register allocation while preserving the template deduplication strategy.

The following table shows the speedup on various input shapes and GPUs. Sorting is performed on the last dimension, and baseline torch version is 2.9.0.

| GPU  | input shape | input dtype | **Before** **(ms)** | After (ms) | Speedup |
| ---- | ----------- | ----------- | ------------------- | ---------- | ------- |
| H100 | (16, 1e6)   | int32       | 1.61                | 1.37       | 1.18×   |
| H100 | (1, 1e8)    | int32       | 6.6                 | 5.0        | 1.3×    |
| H20  | (16, 1e6)   | int64       | 3.57                | 3.03       | 1.18×   |
| H20  | (1, 1e8)    | int64       | 19.3                | 13.0       | 1.48×   |

# Analysis

`torch.sort` and `torch.unique` use `radix_sort_pairs`, which internally calls `cub::DeviceRadixSort::SortPairs`. Since values are only copied (never compared), we cast them to `OpaqueType<sizeof(value_t)>` to minimize template instantiations. For example, both `int32` and `float32` values map to the same `OpaqueType<4>.`

## The Problem

The previous `char data[N]` implementation causes inefficient register allocation. Here is one reason I find from SASS code. For 8-byte types:

- `char data[8]:` Compiler may allocate 8 registers (one per byte)

- `uint64_t data`: Compiler allocates 2 registers (standard 64-bit handling)

This happens because the compiler doesn't recognize char[8] as a cohesive 64-bit value, treating each byte independently, which increases register pressure and reduces GPU occupancy.

From Nsight Compute, when using `char data[8]`, the registers per thread is 166, and corresponding theoretical occupancy is 18.75%. When using native `uint64_t`, the registers per thread is 80, and corresponding theoretical occupancy is 37.5%.

## The Solution

Specialize `OpaqueType<N>` for common sizes using native integer types:

```
// Before
template <int N> struct alignas(N) OpaqueType { char data[N]; };

// After
template <int N> struct alignas(N) OpaqueType { char data[N]; }; // fallback
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
```

This preserves the template deduplication strategy (all 8-byte types still use the same `OpaqueType<8>` instantiation) while enabling better register allocation.

# Testing & Compatibility
## Testing:
 Correctness tests pass for various input types (bfloat16, int32, float32, int64), shapes, and dimensions (1, 2, 3)
 Register usage reduction verified with NSight Compute
 Linter passes
## Compatibility:
 No API/ABI changes
 Template instantiation count unchanged

# Reference
For detailed analysis, please refere to my previous blog: [Performance Optimization of torch.sort on GPU](https://yywangcs.notion.site/Performance-Optimization-of-torch-sort-on-GPU-192fc9f5d8058018a1bec1efa35da3f9)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167094
Approved by: https://github.com/ngimel, https://github.com/Skylion007
2025-11-05 22:34:19 +00:00
08200280ce [CP][BE][3/N] Add _templated_ring_attention to the backward compatility stub (#166991)
While `_templated_ring_attention` is a private API, it is unfortunatelly used by some packages.
Add it to __all__ so that people can still use it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166991
Approved by: https://github.com/XilunWu
ghstack dependencies: #166456, #166501
2025-11-05 22:22:55 +00:00
ad7a57262c [12/N] Apply ruff UP035 rule (#166929)
This PR continues to apply ruff UP035 rule to test code and some remaining torch files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166929
Approved by: https://github.com/Lucaskabela
2025-11-05 22:06:19 +00:00
711a775878 fix nccl estimations (#167093)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167093
Approved by: https://github.com/kwen2501, https://github.com/eellison
2025-11-05 22:01:49 +00:00
e9a688f02e [DebugMode] output, tensor id annotations for DebugMode (#165076)
Adds optional "node" id for tensors, output info annotations to DebugMode, with `DebugMode(record_output=True, record_ids=True)`

Example output for `test_debug_mode_mm`, with both enabled:
```
  torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0))  ->  dt$12: f32[8, 32]| S(0)
    aten::mm(dt$2: f32[8, 8]| S(0), dt$3: f32[8, 32]| S(0))
      redistribute_input(1, S(0) -> R)
        redistribute_input(t$4: f32[1, 32], trace: S(0)->R)
          _c10d_functional::all_gather_into_tensor(t$5: f32[1, 32], 8, 0)  ->  t$6: f32[8, 32]
          _c10d_functional::wait_tensor(t$7: f32[8, 32])  ->  t$8: f32[8, 32]
      aten::mm(t$9: f32[1, 8], t$10: f32[8, 32])  ->  t$11: f32[1, 32]
  <method 'sum' of 'torch._C.TensorBase' objects>(dt$13: f32[8, 32]| S(0))  ->  dt$17: f32[]| P
    aten::sum(dt$14: f32[8, 32]| S(0))
      aten::sum(t$15: f32[1, 32])  ->  t$16: f32[]"""
```

Sadly the only way to get DTensor op outputs is to set `record_torchfunction=True`, as dispatch calls just defer to DTensor's dispatch logic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165076
Approved by: https://github.com/zpcore
2025-11-05 22:00:11 +00:00
e69aaaf45a [user-streams] Add backward test (#167021)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167021
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167019
2025-11-05 21:24:44 +00:00
fd8f368d31 [user-streams] Add graph annotation checks (#167019)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167019
Approved by: https://github.com/Lucaskabela
2025-11-05 21:24:44 +00:00
13d2cc7bd2 Remove python workaround for ContextDecorator (#167049)
This PR removes the import workaround for ContextDecorator because the import always succeeds in Py 3.10+.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167049
Approved by: https://github.com/Skylion007
2025-11-05 20:56:04 +00:00
c6c913d18e Add torch::stable::Tensor sizes and strides (#165153)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165153
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #164991, #165152
2025-11-05 20:55:34 +00:00
ef3f953966 Revert "[DebugMode] output, tensor id annotations for DebugMode (#165076)"
This reverts commit a64c7d740428010d700b4bcd395af8a7b2d5c21f.

Reverted https://github.com/pytorch/pytorch/pull/165076 on behalf of https://github.com/wdvr due to Sorry but this is breaking internally. See diff [D86245252](https://l.workplace.com/l.php?u=https%3A%2F%2Fwww.internalfb.com%2Fdiff%2FD86245252&h=AT1oPbS1XTv6HjYeYdxmDMW1-jlT0pS8yBO2iSfbPfUB9ydsEjFXBNT56QhV1v5TKc4_QaQNxykNowSKmb4fgenjOyCv20NuL7oV_Id5fhh32hhv1IpjgsDJYK-PBFfSfv_miLIWfNgj902KcgXojbBgDcDzQeS9lNt0GQ) for details. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/165076#issuecomment-3493358159))
2025-11-05 20:52:43 +00:00
ea44f12bce [13/N] Apply ruff UP035 rule (#167048)
This PR continues to apply ruff UP035 rule to test code and some remaining torch files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167048
Approved by: https://github.com/Skylion007
2025-11-05 20:51:53 +00:00
a74fe75c45 Don't hardcode double argument for reduction base (#166951)
Fixes https://github.com/pytorch/pytorch/issues/43254

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166951
Approved by: https://github.com/ngimel, https://github.com/Skylion007
ghstack dependencies: #166813
2025-11-05 20:34:15 +00:00
6d30666bc1 Revert "[12/N] Apply ruff UP035 rule (#166929)"
This reverts commit 5863ba1b2e4de9ea0ae16a663465ec5d3d6f9f52.

Reverted https://github.com/pytorch/pytorch/pull/166929 on behalf of https://github.com/donigian due to Temporarily need to revert this to continue a revert for #165076. @cyyever Please re-merge after revert of #165076. ([comment](https://github.com/pytorch/pytorch/pull/166929#issuecomment-3493090596))
2025-11-05 20:02:47 +00:00
8e8cbb85ee Revert "[Inductor] Fix unbacked float symbol handling in kernel codegen (#166890)"
This reverts commit 0c7a4a6b48d49306eae8d0a9ee8d32b1899e5e23.

Reverted https://github.com/pytorch/pytorch/pull/166890 on behalf of https://github.com/malfet due to Looks like it broke torchfuzz tests, see fbd70fb84e/1 and same test on slow ([comment](https://github.com/pytorch/pytorch/pull/166890#issuecomment-3493011038))
2025-11-05 19:42:39 +00:00
fbd70fb84e Update typing docs to reference pyrefly (#166883)
Replacing mypy codumentation in the CONTRIBUTING.MD file with pyrefly references. I have made initial changes to https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch documentation, and will replace the script at the bottom with one tailored to the pyrefly tool as a follow-up.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166883
Approved by: https://github.com/malfet
2025-11-05 19:35:38 +00:00
6c5db82584 [Inductor] Naive foreach autotune support (#162053)
Initial autotuning support for foreach kernels, 4x improvement for some kernels in internal workload. More improvements can surely be made here in the future. Removing num_warps for definition to enable autotune support in generated wrapper code.

Before:
triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |

After:
triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |

num_warps=8 default due to https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton_combo_kernel.py#L374

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053
Approved by: https://github.com/mlazos, https://github.com/naromero77amd, https://github.com/jeffdaily

Co-authored-by: Nichols A. Romero <nick.romero@amd.com>
2025-11-05 19:27:23 +00:00
6052a01b71 [BE][Typing][Dynamo] Type torch/_dynamo/variables/dicts.py (#167022)
Provides type coverage to torch/_dynamo/variables/dicts.py

Coverage report:
`mypy torch/_dynamo/variables/dicts.py --linecount-report /tmp/coverage_log`

Compare before to after - we go from 0 lines and 0 funcs covered to 1547 lines and 89 funcs covered

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167022
Approved by: https://github.com/Skylion007
2025-11-05 19:18:35 +00:00
14b153bcf2 include DTensor metadata when pretty-printing fx.Graphs (#166750)
Example below. You need to trace your function with DTensor inputs in order for the graph proxies to run on DTensor (and not the inner local tensor). You also need to run with `tracing_mode="fake"`, or with your own `FakeTensorMode`, to see the nice DTensor printing. If this doesn't feel very ergonomic then maybe we can find some better UX for printing a graph with DTensor in it:

<img width="1446" height="582" alt="image" src="https://github.com/user-attachments/assets/99ea5ce6-1008-4ba5-b58e-542cd34a340b" />

```
import torch
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.distributed.tensor import distribute_tensor, Shard, Replicate
from torch.utils._debug_mode import DebugMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils import _pytree as pytree

world_size = 8
device_type = "cpu"
fake_store = FakeStore()
torch.distributed.init_process_group("fake", store=fake_store, rank=0, world_size=world_size)
device_mesh = torch.distributed.init_device_mesh(device_type, (world_size,))
dim = 128

A = torch.randn(8, dim)
B = torch.randn(dim, dim)
dA = distribute_tensor(A, device_mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, device_mesh, [Replicate()]).requires_grad_()

def f(dA, dB):
    dy = dA @ dB
    loss = dy.sum()
    loss.backward()
    return dA.grad, dB.grad

# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode='fake')(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
gm.print_readable(colored=True)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166750
Approved by: https://github.com/ezyang, https://github.com/wconstab, https://github.com/Skylion007
2025-11-05 18:58:54 +00:00
641de23c96 ci: Add aarch64 docker builds for modern clang (#166416)
Should enable us to build using some arm optimizations that are only
available on the newest versions of clang.

Signed-off-by: Eli Uriegas <eliuriegas@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166416
Approved by: https://github.com/malfet
2025-11-05 18:55:56 +00:00
89165c0a2b Update triton to 3.5.1 release (#166968)
This includes sm103 https://github.com/triton-lang/triton/pull/8485 fix

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166968
Approved by: https://github.com/Lucaskabela, https://github.com/njriasan
2025-11-05 18:26:34 +00:00
dcc2ba4ca4 Add some code for exploring the space of accessible size/stride configs via plain views (#167076)
We are working on a translation from as_strided to view operations, but
only when the as_strided is representable as a plain view.  A useful
testing utility in this situation is the ability to enumerate all valid
views on an original tensor.  So we have a small test here that shows
it is possible.

To avoid an explosion of states, we don't handle permutes and size=1,
which are degenerate cases (you can always do a single permute and
a series of unsqueezes to get to the final desired state.)

Authored with claude code assistance.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167076
Approved by: https://github.com/albanD
ghstack dependencies: #166868, #166867
2025-11-05 18:25:19 +00:00
ad5c7c20e0 Revert "[cuDNN] Smoke-test runtime cuDNN version matches compile time version in CI (#165922)"
This reverts commit 1d3f5e19da068ec1340db041b7105b287a513578.

Reverted https://github.com/pytorch/pytorch/pull/165922 on behalf of https://github.com/atalman due to Introduces Segfault in linux-jammy-cuda12.8-py3.10-gcc11 ([comment](https://github.com/pytorch/pytorch/pull/165922#issuecomment-3492667312))
2025-11-05 18:13:57 +00:00
c86540f120 Revert "Add model code stack trace to torch.profile (#166677)"
This reverts commit c00696144dae1f02e04ce345480b55e46c7d32a8.

Reverted https://github.com/pytorch/pytorch/pull/166677 on behalf of https://github.com/jeffdaily due to broke rocm ([comment](https://github.com/pytorch/pytorch/pull/166677#issuecomment-3492658160))
2025-11-05 18:11:11 +00:00
c17aa0f113 [ROCm] Enable group gemm through CK (#166334)
Fixes #161366
All the 4 types of dimension matrix are supported.
2d-2d, 2d-3d, 3d-3d, 3d-2d. The corresponding test cases in test_matmul_cuda are working
for both forward and backward pass.
The CK path is enabled for gfx942, gfx950.
ToDo: Need to enable support on gfx90a since the ck kernel used in this commit produces gpu error,
might require a different CK kernel config, based on the profiler result on gfx90a.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166334
Approved by: https://github.com/atalman
2025-11-05 18:03:59 +00:00
4ff068c33a [Code Clean] Replace assert with if statement and raise AssertionError (#166935)
Including:
- `torch/profiler/profiler.py`

Fixes part of #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166935
Approved by: https://github.com/fffrog, https://github.com/albanD
2025-11-05 17:59:16 +00:00
0c7a4a6b48 [Inductor] Fix unbacked float symbol handling in kernel codegen (#166890)
When a fn compiled with `torch.compile` calls `.item()` on a float tensor arg (e.g., for thresholds in `torch.clamp`), the generated triton kernel references an unbacked float symbol (e.g., `zuf0`) that was never added to the kernel's parameter list, causing a compilation error.

Fixes: #166888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166890
Approved by: https://github.com/eellison
2025-11-05 17:50:08 +00:00
f93ee16fb6 [CI] Parse xml and upload json while running (#166988)
Then we can point an ClickHouse ingestor at this s3 path and get them into ClickHouse while the job is running.

use filelock to make sure each json is uploaded once so we don't end up with dups in ClickHouse
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166988
Approved by: https://github.com/izaitsevfb
2025-11-05 17:19:24 +00:00
9c2c3dbc15 Revert "Update triton to 3.5.1 release (#166968)"
This reverts commit b4e4ee81d386db922d8f63359f9870eff1f44052.

Reverted https://github.com/pytorch/pytorch/pull/166968 on behalf of https://github.com/malfet due to It might have caused deadlock/test timeouts, see d4dcd0354c/1 ([comment](https://github.com/pytorch/pytorch/pull/166968#issuecomment-3492399396))
2025-11-05 17:12:30 +00:00
d4dcd0354c [pytree][dynamo] add test to ensure tree_map preserves dict order (#166236)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166236
Approved by: https://github.com/mlazos
2025-11-05 17:04:40 +00:00
aba2fa3259 Fix clang-21 warnings (#166859)
Fixes compiler warnings thrown by Clang-21

Fixes #166755

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166859
Approved by: https://github.com/aditew01, https://github.com/fadara01, https://github.com/malfet
2025-11-05 16:55:51 +00:00
d2d13bf62d Invert unary read and write for fusion (#161404)
For [this repro](https://gist.github.com/eellison/75a99616a0fcca0436316bbfd8987fae) enables fusion of `to_blocked` with the prior `to_mx` calculation, so that there is only a single kernel per tensor, resulting in a 10% speedup of the non conversion code (need to update my local devserver to 12.9 to time the matmul as well).

The `to_mx` kernel has a contiguous write:

```Py
op6_op7: FusedSchedulerNode(SchedulerNode,SchedulerNode)
op6_op7.writes = [MemoryDep('buf6', c0, {c0: 2097152}), MemoryDep('buf7', c0, {c0: 67108864})]
op6_op7.unmet_dependencies = []
op6_op7.met_dependencies = [MemoryDep('arg1_1', c0, {c0: 67108864})]
op6_op7.outputs = [
    buf6: ComputedBuffer
    buf6.layout = FixedLayout('cuda:0', torch.float32, size=[8192, 256], stride=[256, 1])
    buf6.users = [
        NodeUser(node=SchedulerNode(name='op7'), can_inplace=False, is_weak=False),
        NodeUser(node=SchedulerNode(name='op9'), can_inplace=False, is_weak=False),
    ]
    buf7: ComputedBuffer
    buf7.layout = FixedLayout('cuda:0', torch.float8_e4m3fn, size=[8192, 256, 32], stride=[8192, 32, 1])
    buf7.users = [NodeUser(node=ExternKernelSchedulerNode(name='op10'), can_inplace=False, is_weak=False)]
]
```

While the `to_blocked` has a single discontiguous read and a single contiguous write.

```Py
op9: SchedulerNode(ComputedBuffer)
op9.writes = [MemoryDep('buf9', c0, {c0: 2097152})]
op9.unmet_dependencies = [   MemoryDep('buf6', 32768*((c0//32768)) + 8192*(((ModularIndexing(c0, 1, 16))//4)) + 256*(ModularIndexing(c0, 16, 32)) + 4*(ModularIndexing(c0, 512, 64)) + (ModularIndexing(ModularIndexing(c0, 1, 16), 1, 4)), {c0: 2097152})]
op9.met_dependencies = []
op9.outputs = [
    buf9: ComputedBuffer
    buf9.layout = FixedLayout('cuda:0', torch.float8_e8m0fnu, size=[2097152], stride=[1])
    buf9.users = [NodeUser(node=ExternKernelSchedulerNode(name='op10'), can_inplace=False, is_weak=False)]
]
```

To enable fusion, we invert the read, giving op9 and contiguous read and discontiguous write. More explanation here: https://gist.github.com/eellison/6f9f4a7ec10a860150b15b719f9285a9

[Tlparse with this optimization](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/eellison/custom/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000).

[Tlparse without this optimization](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/eellison/custom/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161404
Approved by: https://github.com/shunting314
2025-11-05 16:10:52 +00:00
7a6ff88196 Widen ops support to take in IntHOArrayRef vs only std::vec (#165152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165152
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #164991
2025-11-05 16:00:24 +00:00
59563dfe56 Refactor out headeronly ArrayRef (#164991)
Differential Revision: [D85091961](https://our.internmc.facebook.com/intern/diff/D85091961)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164991
Approved by: https://github.com/swolchok
2025-11-05 16:00:24 +00:00
5c639466f7 Revert "[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#167003)"
This reverts commit 658c5f879c37142b1df51c7eb6c5a5bb06318597.

Reverted https://github.com/pytorch/pytorch/pull/167003 on behalf of https://github.com/atalman due to regressed vllm signal: [GH job link](https://github.com/pytorch/pytorch/actions/runs/19093785744/job/54553796743) [HUD commit link](658c5f879c) ([comment](https://github.com/pytorch/pytorch/pull/167003#issuecomment-3491527704))
2025-11-05 14:30:15 +00:00
0b4dd08e04 [dynamo] Introduce _set_lru_cache (#167038)
Addresses the short-term plan for https://github.com/pytorch/pytorch/issues/166926. This PR can't be defaulted on, that would be terrible for cache look up times.

There's a proper fix in the works by @williamwen42.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167038
Approved by: https://github.com/williamwen42
2025-11-05 09:05:11 +00:00
edd8d356b6 fixes keyerror when loading parameter with unsaved optimizer state (#165228)
Fixes #164257

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165228
Approved by: https://github.com/fegin
2025-11-05 08:07:46 +00:00
658c5f879c [Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#167003)
Summary: This is a reland of https://github.com/pytorch/pytorch/pull/165036?fbclid=IwY2xjawN3RL1leHRuA2FlbQIxMQBicmlkETExOEcxcnVhNVA1TzRSVmhiAR63GOEpJbZA-JhQ0CSj9ji8H_RHBUhDwYNDtxjOYfDol56OGqmC4r7jPP96Fw_aem_bWvtMfVifLQrnpv1YB_fJA, which previously contained a minor bug in the logic that determined whether the kernel should be enabled. As a result, it was incorrectly activated on non-Blackwell GPUs.

Test Plan:
Inductor test (fbcode):
`INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:cutedsl_grouped_mm -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1"`

Tritonbench (fbcode):
`clear; CUDA_VISIBLE_DEVICES=7 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/opt //pytorch/tritonbench:run -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1" -- --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_cute_grouped_mm --precision bf16  --num-inputs 1 --metrics tflops,accuracy`

Tritonbench(oss):
`clear; CUDA_VISIBLE_DEVICES=2 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 python run.py --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_triton_grouped_mm --precision bf16  --num-inputs 1 --metrics tflops,accuracy`

Unit Tests(oss):
`clear; python test/inductor/test_cutedsl_grouped_mm.py`

Differential Revision: D86231180

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167003
Approved by: https://github.com/jananisriram
2025-11-05 06:51:30 +00:00
59a6c83dfe [fx] Add strict argument validation to Interpreter.boxed_run (#166784)
# Summary

This PR fixes an issue where `torch.fx.Interpreter.boxed_run` would silently ignore extra input arguments instead of validating the argument count.

Previously, `boxed_run` would only consume as many inputs as there were placeholder nodes and then clear the entire `args_list`, hiding potential bugs. This change introduces a strict check to ensure `len(args_list)` matches the number of placeholder nodes, raising a `RuntimeError` on a mismatch.

Fixes #166583.

# Changes

* Validate `len(args_list)` against the number of placeholder nodes at the beginning of `boxed_run`.
* Raise a `RuntimeError` with a clear message ("extra arguments" or "missing arguments") if the counts do not match.
* Move `args_list.clear()` to only execute after successful validation and environment setup. If an error is raised, `args_list` is preserved for debugging.

# Testing

* Added `test_interpreter_boxed_run_argument_validation` to `test/test_fx.py`.
* This test covers three scenarios:
    1.  Correct number of arguments (succeeds, `args_list` is cleared).
    2.  Extra arguments (raises `RuntimeError`, `args_list` is preserved).
    3.  Missing arguments (raises `RuntimeError`, `args_list` is preserved).

# User-facing impact / BC notes

This is a bug fix. Code that was incorrectly passing the wrong number of arguments to `boxed_run` will now fail fast with a `RuntimeError` instead of executing silently with unintended inputs. Correctly written code is unaffected.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166784
Approved by: https://github.com/ezyang, https://github.com/xmfan
2025-11-05 06:39:32 +00:00
431dfe8692 [dynamo] extend collections.defaultdict support with *args, **kwargs and custom default_factory (#166793)
Fixes #166238

Extend `collections.defaultdict` to accept `*args` and `**kwargs` in the constructor. And also support custom `default_factory`, such as `dd.default_factory` (a `GetAttrVariable`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166793
Approved by: https://github.com/guilhermeleobas
2025-11-05 06:09:39 +00:00
c00696144d Add model code stack trace to torch.profile (#166677)
```python
python test/test_fx.py -k profiler
```

Insert `torch._C._profiler._RecordFunctionFast` to fx graph codegen.

We post-process the profiler dump using `map_recorded_events_to_aten_ops_with_stack_trace` to add the stack trace to the dump'd trace.

`map_recorded_events_to_aten_ops_with_stack_trace` queries `fx.traceback._FX_METADATA_REGISTRY` for node metadata. Each graph module has a hash'd fake file name (e.g. `fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py`), which is the key to the registry.

One can do `fx_g.enrich_profiler_metadata()` to add debugging info. Or `fx_g.enrich_profiler_metadata(enable=False)` to remove.

`aot_eager` makes calls `fx_g.enrich_profiler_metadata()` if TORCH_ENRICH_RPOFILER_STACK_TRACE is set or _dynamo.config.enrich_profiler_metadata=True.

<img width="1188" height="565" alt="Screenshot 2025-10-31 at 4 40 52 PM" src="https://github.com/user-attachments/assets/41e8113f-3e6d-439b-bffd-cfbf0c03a47a" />

Example code gen'd.
```
def forward(self, args_list):
    args_iter = iter(args_list)
    arg0_1 = next(args_iter)
    arg1_1 = next(args_iter)
    args_list.clear()
    _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py ##'); _rf.__enter__()
    repeated_subgraph0 = self.repeated_subgraph0
    _rf_invoke_subgraph = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_invoke_subgraph.__enter__()
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1);  repeated_subgraph0 = arg0_1 = arg1_1 = None
    _rf_invoke_subgraph.__exit__(None, None, None)
    _rf_getitem = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_getitem.__enter__()
    getitem = invoke_subgraph[0];  invoke_subgraph = None
    _rf_getitem.__exit__(None, None, None)
    return (getitem,)
    _rf.__exit__(None, None, None)

def forward(self, arg0_1, arg1_1):
    _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__ozpadpj5cxoalxeyopej33g2vvtvhxg4xsk7bhx7ldmcibtybyn.py ##'); _rf.__enter__()
    _rf_mul = torch._C._profiler._RecordFunctionFast('## 2 ##'); _rf_mul.__enter__()
    mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    _rf_mul.__exit__(None, None, None)
    _rf_sin = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_sin.__enter__()
    sin = torch.ops.aten.sin.default(mul);  mul = None
    _rf_sin.__exit__(None, None, None)
    _rf_add = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_add.__enter__()
    add = torch.ops.aten.add.Tensor(sin, 5);  sin = None
    _rf_add.__exit__(None, None, None)
    return (add,)
    _rf.__exit__(None, None, None)

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166677
Approved by: https://github.com/ezyang
2025-11-05 06:08:34 +00:00
9ffc480c5a Add min/max support for barebones uint types (#166813)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166813
Approved by: https://github.com/Skylion007
2025-11-05 04:44:21 +00:00
14956eaef4 [ROCm][CI] revert ROCm magma commit hash to last known good (#167044)
PR https://github.com/pytorch/pytorch/pull/166693 updated the magma commit hash but this has been linked to ROCm 7.1 CI failures.  Go back to last known working magma version.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167044
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-11-05 04:18:04 +00:00
066c5c57a9 Fix typo in gloo_hip library name (#166502)
The typo was never noticed; conditions to enable it require system gloo: `-DUSE_SYSTEM_GLOO=ON -DUSE_GLOO=ON -DUSE_DISTRIBUTED=ON -DUSE_ROCM=ON`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166502
Approved by: https://github.com/jerryzh168, https://github.com/cyyever
2025-11-05 04:14:01 +00:00
08ef852a4b [unified v2][apple] Clean up APPLETVOS from caffe2 (#166953)
Summary: This is not used, so delete it

Test Plan:
```
$ buck targets xplat/... > /dev/null
```

Reviewed By: dtolnay

Differential Revision: D86125712

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166953
Approved by: https://github.com/seemethere
2025-11-05 03:09:56 +00:00
56fc99915b Fix typos in complex numbers docs (#166671)
This PR fixes two small typos in the complex numbers docs:

1. "numbercial" -> "numerical"
2. "easily to switch" -> "easily switch to"
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166671
Approved by: https://github.com/jcaip, https://github.com/Arpitha781, https://github.com/mlazos, https://github.com/cyyever
2025-11-05 03:05:06 +00:00
5863ba1b2e [12/N] Apply ruff UP035 rule (#166929)
This PR continues to apply ruff UP035 rule to test code and some remaining torch files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166929
Approved by: https://github.com/Lucaskabela
2025-11-05 03:03:41 +00:00
a743f9eeb5 Revert "Avoid DDE in narrow with unbacked start (#166361)"
This reverts commit ed45c5f38df6aa419c67d139d932c2c94404223a.

Reverted https://github.com/pytorch/pytorch/pull/166361 on behalf of https://github.com/malfet due to Looks like it broke test_torchfuzz subtests, see 01e6e35c7f/1 ([comment](https://github.com/pytorch/pytorch/pull/166361#issuecomment-3488916766))
2025-11-05 02:39:55 +00:00
53b03f1a2b Revert "make narrow_tensor_symint DDE-free (#166379)"
This reverts commit d7e2d0ad301b5d0db049bf5d2a2fc7ff9c89c58c.

Reverted https://github.com/pytorch/pytorch/pull/166379 on behalf of https://github.com/malfet due to Need to revert previous PR in the stack ([comment](https://github.com/pytorch/pytorch/pull/166379#issuecomment-3488910172))
2025-11-05 02:36:46 +00:00
cd5d810c3a Annotation should be deepcopied (#167017)
The annotation should be deepcopied. Otherwise all nodes with the same `seq_nr` share the same underlying dict

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167017
Approved by: https://github.com/yiming0416
2025-11-05 02:22:33 +00:00
01e6e35c7f Send / recv support in local tensor (#166595)
This change introduces LocalRunnerMode that allows you to run multiple
    SPMD functions concurrently. SMPD functions are executing one at a time,
    yielding execution capability while waiting for send or receive operations
    to complete. Send and receive peer operations only supported while running
    under LocalRunnerMode.

    The example test in this change demonstrates how ranks are sending data
    to the next peer and receiving data from the previous peer (ring).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166595
Approved by: https://github.com/wconstab, https://github.com/ezyang
2025-11-05 01:36:44 +00:00
bcd159bcdd Fix the vmap op fallback bug (#166032)
## The bug

In some environments, if run:

```py
    def inner_func(x):
      return x.to(torch.float32, memory_format=torch.channels_last)

    x = torch.randn(2, 2, 3, 4, device="cpu", dtype=torch.float64)
    torch.vmap(inner_func)(x)
```

we get:

```
E       RuntimeError: Batching rule not implemented for aten::to.dtype_layout; the fallback path doesn't work on out= or view ops.
```

Otherwise, it would always fallback and result in an error for ops like `to.dtype` and `to.dtype_layout` even the kernels are registered.

## The cause

The alias key of `FuncTorchBatchedDecomposition` is not properly translated to runtime dispatch keys when updating the dispatch table of `OperatorEntry::dispatchTable_`. [[link](984b096d10/aten/src/ATen/core/dispatch/OperatorEntry.cpp (L500-L501))]
The [`getRuntimeDispatchKeySet`](f3fa560dec/c10/core/DispatchKeySet.cpp (L62)) use if-else to translate all other alias keys but `FuncTorchBatchedDecomposition`.

This would result in not finding the kernel in many cases.

## The fix

This PR adds one more `if` statement to `getRuntimeDispatchKeySet` to map `FuncTorchBatchedDecomposition` to the corresponding runtime dispatch key, `FuncTorchBatched`.
So, that the dispatch table can be properly updated.

This fix allows people to use ops inside vmaps in more environments and across more compilers.

## Why does it work without the PR

As long as the `FuncTorchBatchedDecomposition` [[link](51319ca090/aten/src/ATen/functorch/BatchRulesDecompositions.cpp (L35))]
is registered before the fallback method of `FuncTorchBatched` [[link](d311a3d1dc/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp (L759))], everything runs fine.

In this case, it relies on the registration of the fallback method to update the dispatch table, which flushes all the kernels in `OperatorEntry::kernels_` into `dispatchTable_`, among which there are kernels registered with `FuncTorchBatchedDecomposition`.

## When does it fail

However, the order of the op registration and the fallback registration is not garanteed at all.
It relies on the C++ static initialization order, which varies from environment to environment.
On our compiler, it the fallback registration goes first and the alias key kernels under `FuncTorchBatchedDecomposition` comes later and not get flushed into the dispatch table by the fallback registration.
Therefore, it cannot find the kernel for it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166032
Approved by: https://github.com/albanD
2025-11-05 01:16:58 +00:00
64ae31c5d3 [HOP][print] Add HOP subclass for printing (#166660)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166660
Approved by: https://github.com/angelayi, https://github.com/anijain2305

Co-authored-by: Angela Yi <yiangela7@gmail.com>
2025-11-05 01:16:49 +00:00
45da6e1fe1 [CD] Upload XPU inductor benchmark test reports to s3 (#166954)
As the title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166954
Approved by: https://github.com/atalman
2025-11-05 01:02:57 +00:00
39160dba0c shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529

To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch.

This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.

For more info:  [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/kwen2501
2025-11-05 00:54:40 +00:00
f2fbc81c50 [RFC] Add experimental Pallas TorchInductor backend (#166822)
Very simple Pallas TorchInductor backend
Given
```
import torch

def f(x, y):
    return x.sin() + y

torch._inductor.config.cuda_backend="pallas"

x = torch.randn(4).cuda()
y = torch.randn(4).cuda()

compiled = torch.compile(f, backend="inductor", fullgraph=True)
torch.testing.assert_close(compiled(x, y), f(x, y))
```
it outputs
```
import torch
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from torch.utils import dlpack as torch_dlpack
def pallas_fused_add_sin_56b646d2_kernel(in_ptr0, in_ptr1, out_ptr0):
    tmp0 = in_ptr0[...]
    tmp1 = jnp.sin(tmp0)
    tmp2 = in_ptr1[...]
    tmp3 = tmp1 + tmp2
    out_ptr0[...] = tmp3
def pallas_fused_add_sin_56b646d2_main(in_ptr0, in_ptr1, out_ptr0, stream=None):
    # Convert Torch -> JAX for inputs
    in_ptr0_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(in_ptr0))
    in_ptr1_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(in_ptr1))
    # Prepare output spec from PyTorch tensor
    # Map PyTorch dtype to JAX dtype string
    _torch_dtype_to_jax = {
        torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16,
        torch.int32: jnp.int32, torch.int64: jnp.int64, torch.int16: jnp.int16, torch.int8: jnp.int8,
        torch.uint8: jnp.uint8, torch.bool: jnp.bool_,
    }
    out_spec = jax.ShapeDtypeStruct(out_ptr0.shape, _torch_dtype_to_jax[out_ptr0.dtype])
    compiled = pl.pallas_call(
        lambda *refs: pallas_fused_add_sin_56b646d2_kernel(*refs),
        out_shape=out_spec,
        grid=(1,),
    )
    res = compiled(in_ptr0_jax, in_ptr1_jax)
    # Copy result back into the provided torch output tensor
    res_t = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(res))
    out_ptr0.copy_(res_t)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166822
Approved by: https://github.com/jansel
ghstack dependencies: #166976, #166982
2025-11-05 00:52:41 +00:00
4271ffe918 don't produce invalid grid configs (#166974)
Proper fix for #164048, fixes gather too, reverts #164049
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166974
Approved by: https://github.com/eqy
2025-11-05 00:20:27 +00:00
7eefcfb1db [BE][Typing][Dynamo] Type torch/_dynamo/variables/ctx_manager.py (#166878)
Provides type coverage to torch/_dynamo/variables/ctx_manager.py

Coverage report:
`mypy torch/_dynamo/variables/ctx_manager.py --linecount-report /tmp/coverage_log`

Compare before to after - we go from 0 lines and 0 funcs covered to 1541 lines and 144 funcs covered

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166878
Approved by: https://github.com/Skylion007
2025-11-04 23:54:18 +00:00
4b12c0344d Add default .github/copilot-instructions.md and item in .gitignore for allowing local changes (#166864)
Fixes [#166850](https://github.com/pytorch/pytorch/issues/166850)

- Create a default `.github/copilot-instructions.md` file (used Claude Sonnet 4.5 in Copilot).
- Add `.github/copilot-instructions.md` to the `.gitignore` file.

The prompt used is below, which is preset by Copilot:
```
Analyze this codebase to generate or update `.github/copilot-instructions.md` for guiding AI coding agents.

Focus on discovering the essential knowledge that would help an AI agents be immediately productive in this codebase. Consider aspects like:
- The "big picture" architecture that requires reading multiple files to understand - major components, service boundaries, data flows, and the "why" behind structural decisions
- Critical developer workflows (builds, tests, debugging) especially commands that aren't obvious from file inspection alone
- Project-specific conventions and patterns that differ from common practices
- Integration points, external dependencies, and cross-component communication patterns

Source existing AI conventions from `**/{.github/copilot-instructions.md,AGENT.md,AGENTS.md,CLAUDE.md,.cursorrules,.windsurfrules,.clinerules,.cursor/rules/**,.windsurf/rules/**,.clinerules/**,README.md}` (do one glob search).

Guidelines (read more at https://aka.ms/vscode-instructions-docs):
- If `.github/copilot-instructions.md` exists, merge intelligently - preserve valuable content while updating outdated sections
- Write concise, actionable instructions (~20-50 lines) using markdown structure
- Include specific examples from the codebase when describing patterns
- Avoid generic advice ("write tests", "handle errors") - focus on THIS project's specific approaches
- Document only discoverable patterns, not aspirational practices
- Reference key files/directories that exemplify important patterns

Update `.github/copilot-instructions.md` for the user, then ask for feedback on any unclear or incomplete sections to iterate.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166864
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-04 23:53:56 +00:00
661b639663 use_cpp_bmm_template supports more use cases (#165469)
Summary: In certain scenarios, such as when the first stride is 0, the entire tensor may not be contiguous, but the 2D matrix within each batch can still be contiguous, allowing us to apply max autotune. This diff specifically checks for contiguity within the 2D matrix of each batch, and enables more uses for cpp bmm template.

Differential Revision: D84561331

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165469
Approved by: https://github.com/desertfire
2025-11-04 23:47:17 +00:00
0cd809f60c [inductor][AMD] Filter out invalid Triton Configs for MI350X _scaled_mm (#166442)
Summary: Mirrors change done in D81180838 but for inductor. Without this change, running _scaled_mm on MI350X accelerator would crash.

Test Plan: HIP_VISIBLE_DEVICES=7 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 buck2 run mode/opt-amd-gpu   -m rocm70   -c fbcode.rocm_arch=mi350   scripts/jchunx/gemm:scaled_mm_microbench --   --csv_file /home/jchunx/scripts/fp8_shapes.csv   --backend triton,aten   --fast_accum=true   2>&1 | tee ~/logs/scaled_mm.log

Reviewed By: bilal

Differential Revision: D85694383

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166442
Approved by: https://github.com/bilal
2025-11-04 23:47:11 +00:00
a96728d188 Clarify safety of CUDA graph memory pool sharing across graphs that are replayed in arbtirary order. (#166975)
Some users at pytorch conference were asking me about whether it is safe to share a memory pool among cuda graphs that never run concurrently, but may run in arbitrary order, if they don't depend upon each other's output. Even though your capture order doesn't match replay order in this situation, this is safe. However, our documents confusingly said this wasn't allowed. This update is intended to help with that. Since vLLM essentially depends upon this behavior, I call it out specifically.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166975
Approved by: https://github.com/eellison, https://github.com/BoyuanFeng
2025-11-04 23:36:03 +00:00
c1e91bd4c3 [export] Codemod unittests to use new graph capture API (#166957)
Summary:
as title.

Test Plan:
pytest test/functorch/test_aot_joint_with_descriptors.py
pytest test/higher_order_ops/test_local_map.py

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166957
Approved by: https://github.com/angelayi, https://github.com/yushangdi
2025-11-04 22:55:30 +00:00
d7e2d0ad30 make narrow_tensor_symint DDE-free (#166379)
https://github.com/pytorch/pytorch/issues/158081

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166379
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #166361
2025-11-04 22:43:15 +00:00
81038fd326 Revert "Add model code stack trace to torch.profile (#166677)"
This reverts commit e8052f2f99de1fb7284e38082ff5714e17cd9562.

Reverted https://github.com/pytorch/pytorch/pull/166677 on behalf of https://github.com/malfet due to Broke lint, please rebase, we've moved from mypy to pyrefly ([comment](https://github.com/pytorch/pytorch/pull/166677#issuecomment-3488219996))
2025-11-04 22:26:35 +00:00
e020fb3431 [Minor][Inductor] move some combo kernel log from warning to debug (#166993)
Combo kernel warns for long reduction and large pointwise. This becomes too spammy for users such as vLLM.

This PR moves these logs from warn to debug. I validated the spammy log is removed on llama-3.1-8B.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166993
Approved by: https://github.com/zou3519, https://github.com/eellison
2025-11-04 22:09:27 +00:00
e8052f2f99 Add model code stack trace to torch.profile (#166677)
```python
python test/test_fx.py -k profiler
```

Insert `torch._C._profiler._RecordFunctionFast` to fx graph codegen.

We post-process the profiler dump using `map_recorded_events_to_aten_ops_with_stack_trace` to add the stack trace to the dump'd trace.

`map_recorded_events_to_aten_ops_with_stack_trace` queries `fx.traceback._FX_METADATA_REGISTRY` for node metadata. Each graph module has a hash'd fake file name (e.g. `fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py`), which is the key to the registry.

One can do `fx_g.enrich_profiler_metadata()` to add debugging info. Or `fx_g.enrich_profiler_metadata(enable=False)` to remove.

`aot_eager` makes calls `fx_g.enrich_profiler_metadata()` if TORCH_ENRICH_RPOFILER_STACK_TRACE is set or _dynamo.config.enrich_profiler_metadata=True.

<img width="1188" height="565" alt="Screenshot 2025-10-31 at 4 40 52 PM" src="https://github.com/user-attachments/assets/41e8113f-3e6d-439b-bffd-cfbf0c03a47a" />

Example code gen'd.
```
def forward(self, args_list):
    args_iter = iter(args_list)
    arg0_1 = next(args_iter)
    arg1_1 = next(args_iter)
    args_list.clear()
    _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py ##'); _rf.__enter__()
    repeated_subgraph0 = self.repeated_subgraph0
    _rf_invoke_subgraph = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_invoke_subgraph.__enter__()
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1);  repeated_subgraph0 = arg0_1 = arg1_1 = None
    _rf_invoke_subgraph.__exit__(None, None, None)
    _rf_getitem = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_getitem.__enter__()
    getitem = invoke_subgraph[0];  invoke_subgraph = None
    _rf_getitem.__exit__(None, None, None)
    return (getitem,)
    _rf.__exit__(None, None, None)

def forward(self, arg0_1, arg1_1):
    _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__ozpadpj5cxoalxeyopej33g2vvtvhxg4xsk7bhx7ldmcibtybyn.py ##'); _rf.__enter__()
    _rf_mul = torch._C._profiler._RecordFunctionFast('## 2 ##'); _rf_mul.__enter__()
    mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    _rf_mul.__exit__(None, None, None)
    _rf_sin = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_sin.__enter__()
    sin = torch.ops.aten.sin.default(mul);  mul = None
    _rf_sin.__exit__(None, None, None)
    _rf_add = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_add.__enter__()
    add = torch.ops.aten.add.Tensor(sin, 5);  sin = None
    _rf_add.__exit__(None, None, None)
    return (add,)
    _rf.__exit__(None, None, None)

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166677
Approved by: https://github.com/ezyang
ghstack dependencies: #166676
2025-11-04 22:05:36 +00:00
a64c7d7404 [DebugMode] output, tensor id annotations for DebugMode (#165076)
Adds optional "node" id for tensors, output info annotations to DebugMode, with `DebugMode(record_output=True, record_ids=True)`

Example output for `test_debug_mode_mm`, with both enabled:
```
  torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0))  ->  dt$12: f32[8, 32]| S(0)
    aten::mm(dt$2: f32[8, 8]| S(0), dt$3: f32[8, 32]| S(0))
      redistribute_input(1, S(0) -> R)
        redistribute_input(t$4: f32[1, 32], trace: S(0)->R)
          _c10d_functional::all_gather_into_tensor(t$5: f32[1, 32], 8, 0)  ->  t$6: f32[8, 32]
          _c10d_functional::wait_tensor(t$7: f32[8, 32])  ->  t$8: f32[8, 32]
      aten::mm(t$9: f32[1, 8], t$10: f32[8, 32])  ->  t$11: f32[1, 32]
  <method 'sum' of 'torch._C.TensorBase' objects>(dt$13: f32[8, 32]| S(0))  ->  dt$17: f32[]| P
    aten::sum(dt$14: f32[8, 32]| S(0))
      aten::sum(t$15: f32[1, 32])  ->  t$16: f32[]"""
```

Sadly the only way to get DTensor op outputs is to set `record_torchfunction=True`, as dispatch calls just defer to DTensor's dispatch logic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165076
Approved by: https://github.com/zpcore
2025-11-04 21:30:46 +00:00
cdca63db8c Fix quoting in pytest_cache.py invocations (#166955)
Especially the job identifier can contain spaces so needs to be quoted

Fixes e.g. https://github.com/pytorch/pytorch/actions/runs/19063797853/job/54449422160#step:15:52

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166955
Approved by: https://github.com/Skylion007
2025-11-04 21:28:19 +00:00
ed45c5f38d Avoid DDE in narrow with unbacked start (#166361)
Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice.
The only edge case is when start<0 and start+length ==0 in that case slice and narrow would deviate,
for that case we shall pass dim_size instead of start+length

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361
Approved by: https://github.com/aorenste
2025-11-04 21:24:57 +00:00
7f0e932136 [dynamo] don't use LocalSource for temp variables created by side_effects (#166917)
Fixes https://github.com/pytorch/pytorch/issues/166900

Implementation notes:
- I tried to disallow guard generation before side effect application in order to futureproof improper guard generation. However, this was not feasible since it is possible to realize lazy VTs while generating side effects (e.g. realizing a constant variable that is used in a deque update).
- `codegen_save_tempvars` now generates `TempLocalSource` for create temporary variables now, so that they won't get confused with `LocalSource` - we should error out when we attempt to create guards for `TempLocalSource`. I considered using `SyntheticLocalSource`, but that has additional `subguards_allowed` behavior that we may not want to have for temp variables.
- We moved the guard installation for constant user-defined pytree objects from `as_python_constant` to `__init__`. Objects created outside the compile-region will be guarded, while objects created inside the compile-region will not be guarded.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166917
Approved by: https://github.com/anijain2305
2025-11-04 21:16:18 +00:00
2673f8b007 Fix torch.linalg.eig inductor stride mismatch (#162484)
Fixes #159445

### Summary

- Fixed a stride layout issue in the `torch.linalg.eig` meta kernel that prevented successful compilation with the inductor backend. The meta kernel was producing incorrect row-major strides.

- LAPACK/BLAS libraries (underlying implementation) expect column-major layout

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162484
Approved by: https://github.com/isuruf
2025-11-04 21:06:58 +00:00
4e1bd16738 inductor: Switch quiesce to use timer based implementation. (#166581)
Major change is to switch to a timer based implementation. Additionally,
we get rid of the context manager for turning of the compile pool. We
still have the warmup calls.

Note that this only modifies the async_compile methods, the fx pool is
left running.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166581
Approved by: https://github.com/masnesral
ghstack dependencies: #166467
2025-11-04 21:01:49 +00:00
871d0cd196 If USE_CUDA=1 is set, do not fallback to no CUDA (#166982)
So many times i build pytorch only to notice chef nuked my nvcc and i wasted 30m building a cpu version, lets hard error fast

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166982
Approved by: https://github.com/malfet
ghstack dependencies: #166976
2025-11-04 20:51:14 +00:00
2bba37309b [inductor] runtime estimations disable use_nccl_estimator by default (#166973)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166973
Approved by: https://github.com/eellison, https://github.com/jathu
2025-11-04 20:48:22 +00:00
b4e4ee81d3 Update triton to 3.5.1 release (#166968)
This includes sm103 https://github.com/triton-lang/triton/pull/8485 fix

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166968
Approved by: https://github.com/Lucaskabela, https://github.com/njriasan
2025-11-04 20:34:13 +00:00
3283eaa5ba Upload test stats for trunk/sha tag (#166916)
Noticed that workflow runs for `trunk/{sha}` tags (issued by autorevert) don't populate test_run_s3 Clickhouse table.

This PR is addressing this by changing the gate condition to upload tests stats.

see https://github.com/pytorch/pytorch/actions/runs/19054297956/job/54421254448#step:8:23
as an evidence that HEAD_BRANCH is correctly populated for trunk tags.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166916
Approved by: https://github.com/huydhn, https://github.com/clee2000
2025-11-04 20:33:56 +00:00
397d9fe2ae [inductor] coordesc not tune XBLOCK for mix-order-reduction (#166669)
For mix-order reduction, we current force XBLOCK to be 1 to simplify codegen. Don't tune it in CDT.

Differential Revision: [](https://our.internmc.facebook.com/intern/diff/)

Differential Revision: [D86224689](https://our.internmc.facebook.com/intern/diff/D86224689)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166669
Approved by: https://github.com/jansel, https://github.com/mlazos, https://github.com/eellison, https://github.com/v0i0
2025-11-04 20:27:07 +00:00
d77c24caac Revert "[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#165036)"
This reverts commit 0e1a88904f4a5e30634b196678b56e1d6ec074f5.

Reverted https://github.com/pytorch/pytorch/pull/165036 on behalf of https://github.com/atalman due to regressed vllm signal: [GH job link](https://github.com/pytorch/pytorch/actions/runs/19059329909/job/54439919668) [HUD commit link](0e1a88904f) ([comment](https://github.com/pytorch/pytorch/pull/165036#issuecomment-3487846555))
2025-11-04 20:13:33 +00:00
cef98ae5cb [aotd] Compiled saved tensor hooks context (#166887)
Draft to expose compiled saved tensor hook context to selectively apply them.
Exposing node, fw_graph, bw_graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166887
Approved by: https://github.com/bdhirsh
2025-11-04 20:07:00 +00:00
52ea135f77 [BE] Delete Python-3.9 stdlib definitions from torch.package (#166768)
And simplify the entire function to just assert and return

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166768
Approved by: https://github.com/cyyever, https://github.com/atalman
2025-11-04 19:33:14 +00:00
a5f3035aaf More pyrefly local errors (#166976)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166976
Approved by: https://github.com/maggiemoss, https://github.com/Skylion007
2025-11-04 18:51:35 +00:00
1d3f5e19da [cuDNN] Smoke-test runtime cuDNN version matches compile time version in CI (#165922)
Fix and regression test for https://github.com/pytorch/pytorch/issues/165801

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165922
Approved by: https://github.com/malfet, https://github.com/atalman, https://github.com/Skylion007, https://github.com/drisspg

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: Andrey Talman <atalman@fb.com>
2025-11-04 18:46:43 +00:00
496277a8ff [ROCm][CI] Lower runner check gpu count for distributed jobs (#166961)
This is a PR to temporarily relieve the queueing that is caused by an mi250 node outage. See this ticket for more information:
https://github.com/pytorch/pytorch/issues/166866

It relaxes the GPU count check to allow distributed jobs to run on 2-GPU runners

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166961
Approved by: https://github.com/jeffdaily
2025-11-04 18:44:21 +00:00
53f75cd5ba Fixed some syntax errors in SECURITY.md file. (#166718)
Fixed some syntax errors in SECURITY.md file including PyTorch's capitalization problems, some grammatical inconsistencies, etc
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166718
Approved by: https://github.com/mikaylagawarecki
2025-11-04 18:18:38 +00:00
527b1109a8 Delete deprecated fp32 precision warnings (#166956)
The deprecation warning led to warning spamming in PyTorch APIs, like
torch.compile. This is not how a deprecation warning should go: if we
add a deprecation warning, we'd better update our built-in APIs to
prevent warning spam.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166956
Approved by: https://github.com/albanD
2025-11-04 17:50:04 +00:00
clr
3144713325 subproc_pool: Add support for enabling quiesce via a timer (#166467)
This adds the capability to subproc pool to enable quiesce via a timer

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166467
Approved by: https://github.com/masnesral
2025-11-04 17:37:41 +00:00
eefa16342c [Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer (#166165)
Prefer unfused addmm when there is at least a single elemwise/reduction consumer..

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166165
Approved by: https://github.com/eellison
2025-11-04 17:23:04 +00:00
d02f68f484 [BE] Use [[maybe_unused]] (#166865)
Instead of `(void) foo; // Unused parameter` trick, as this is a C++17 standard feature

Will replace further repetitions of the same pattern soon after
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166865
Approved by: https://github.com/mikaylagawarecki, https://github.com/Skylion007, https://github.com/janeyx99
2025-11-04 17:08:28 +00:00
68eb55c4b2 Add model code stack trace to cuda.memory._snapshot (#166676)
We store a mapping between generated fx graph code and original model code stack trace in `fx.traceback._FX_METADATA_REGISTRY`. And we do a post-processing on the memory snapshot to append the original model stack trace information.

To achieve this, the biggest change we had to do in `aot_eager` mode is to give each generated fx graph a unique stack trace, i.e. it cannot just be `<eval_with_key>`. We set co_filename to **pretend** that the code is from `co_filename` file. Now instead of `<eval_with_key>` in stack trace, we get something like `fx_generated_3a4b5c6d7e8f9a0.py`.

`augment_with_fx_traces` arg is added to `torch.cuda.memory._snapshot` and `_dump_snapshot`. When the arg is set to True, a post-processing will run to populate the original model stack trace to the snapshot frames.

The new behavior of GraphModule can be controlled by `TORCH_ENRICH_RPOFILER_STACK_TRACE` or `_dynamo.config.enrich_profiler_metadata=True`.

Alternative:

Instead of setting co_filename, we can also do it like below:
Note that if we do it this way, we will need to dump the file to make the graph module torch-scriptable. TorchScript requires source access in order to carry out compilation, so we need to make sure original .py files are available.
```
        key = filename
        globals_copy = globals.copy()
        globals_copy["__file__"] = key
        globals_copy["__name__"] = key
        linecache.lazycache(key, globals_copy)
        exec(compile(src, key, "exec"), globals)
````

Other changes:

- Update `MemoryViz.js` to display fx node information and original model code if exist

```
python test/test_fx.py -k test_lineno_map
python test/test_fx.py -k test_custom_traceback_raised
python test/test_public_bindings.py
python test/test_cuda.py -k test_fx_memory
python test/test_fx.py -k test_informative_co_filename
python test/test_fx.py -k test_autowrap_functions
python test/dynamo/test_utils.py -k test_inductor_provenance
```

```python
# Profile with memory snapshot
torch.cuda.memory._record_memory_history()

with  torch._dynamo.config.patch("enrich_profiler_stack_trace", True):
    compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
    result = compiled(torch.randn(10, 10, device="cuda:0"))

torch.cuda.memory._dump_snapshot("memory_snapshot.pickle", augment_with_fx_traces=True)
torch.cuda.memory._record_memory_history(enabled=None)
```

<img width="913" height="711" alt="Screenshot 2025-10-30 at 10 40 44 AM" src="https://github.com/user-attachments/assets/8d7a1833-f98d-4756-b666-1d63ab57b27b" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166676
Approved by: https://github.com/albanD, https://github.com/ezyang
2025-11-04 17:01:02 +00:00
8d4b8ab430 [ez] Print some more test timing info in the logs (#166447)
You can just subtract timestamps, but this makes it easier
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166447
Approved by: https://github.com/Skylion007
2025-11-04 16:45:22 +00:00
afd50bdd29 [CI] Use smaller amx + avx2 runners for inductor test? (#164989)
Results from CI:
No failures but generally takes longer, maybe ~20% increase in time?
But the smaller runner is ~25% of the cost of the current runner, so in terms of cost this is a decrease

If the 20% is too much, we can try the 4x larger runners, which are about half the cost of the current runner, so it would probably still result in cost savings with hopefully less impact to time

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164989
Approved by: https://github.com/BoyuanFeng, https://github.com/huydhn
2025-11-04 16:43:06 +00:00
56dfd4c74b Add CUDA MXFP4 scaled mm support via. FBGEMM (#166526)
Summary:

* Pull in `f4f4bf16` from FBGemm to provide MXFP4 support for CUDA
* Add testing

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166526
Approved by: https://github.com/drisspg, https://github.com/ngimel
2025-11-04 15:53:16 +00:00
24db5c4451 [inductor] do not hard fail on FakePG with nccl estimator (#166869)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166869
Approved by: https://github.com/eellison
ghstack dependencies: #166521
2025-11-04 15:22:38 +00:00
cc8bfd1206 Docker release build: Use 13.0.0 nvidia docker (#166904)
Forward fix for failing Docker release builds
Related to: https://github.com/pytorch/pytorch/issues/166897

Nightly Docker build failure https://github.com/pytorch/pytorch/actions/runs/18900508440/job/53946606434
Due to missing base image:
```
ERROR: failed to build: failed to solve: docker.io/nvidia/cuda:13.0.2-devel-ubuntu22.04: not found
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166904
Approved by: https://github.com/tinglvv, https://github.com/malfet
2025-11-04 13:58:10 +00:00
c45b156605 Fix DeepSeek scaling tensor handling (#166752)
Summary:

cuBlasLt enforces size/stride requirements for 1x128 and 128x128 blockwise scaling
kernels, some of which weren't being handled, causing silent incorrect
answers especially for 128x128 scaling cases.

cuBlasLt enforces ([docs](https://docs.nvidia.com/cuda/cublas/#scaling-factors-layouts)) for deepseek-style
scaling, for `A: MxN`, `B: KxN` you have the following:

```Py
L = K // 128
L4 = round_up(L, 4)

1x128 x 128x128:
* A_scale: [M, K // 128], stride: [1, M]
* B_scale: [L4, N // 128], stride: [1, L4]

128x128 x 1x128:
* A_scale: [L4, M // 128], stride: [1, L4]
* B_scale: [N, K // 128], stride: [1, N]

1x128 x 1x128:
* A_scale: [M, K // 128], stride: [1, M]
* B_scale: [N, K // 128], stride: [1, N]
```

Notable here is the `L4` term, which means that we must round up to the nearest multiple of 4 blocks
in the `K` dimension. This wasn't enforced previously, and caused silent wrong answers
where `(K // 128) % 4 != 0`.

Test Plan:

Reviewers:

Subscribers:

@vkuzo

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166752
Approved by: https://github.com/drisspg, https://github.com/vkuzo
2025-11-04 13:32:24 +00:00
8fff7e36b4 [xpu][test] Add UT for expandable segments (#166495)
# Motivation
This PR aims to reuse some UT to validate the expandable segment feature.

# Additional Context
Currently, the failure is related to the internal track `GSD-11403`, we could get the fix when upgrading the driver to `ci-neo-master-034630` or greater
TODO: add test conv and gemm into this test case when upgrading the driver.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166495
Approved by: https://github.com/albanD, https://github.com/EikanWang, https://github.com/gujinghui
ghstack dependencies: #166299, #166292, #166424
2025-11-04 08:01:35 +00:00
82fa2aa269 DTensor: Fix trivial as_strided case, add alias support (#166867)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166867
Approved by: https://github.com/albanD
ghstack dependencies: #166868
2025-11-04 07:18:32 +00:00
09e0285608 [xpu][feature][inductor] Enable decompose_mm_pass and UT on Intel GPU (#166613)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166613
Approved by: https://github.com/hl475
2025-11-04 06:58:05 +00:00
d980d8dc79 [dynamo] Implement __sym_float__ for SymBool to fix multiplication TypeError (#165264)
Fixes #164684

### Description

Symbolic tracing fails during multiplication between a `SymBool` and a `Tensor`. This scenario is triggered when `.item()` is called on a 0-dim boolean tensor within a `torch.compile` region. In compile mode, this yields a `SymBool`, and the subsequent `SymBool * FakeTensor` operation is unsupported, leading to a `TypeError` or a data-dependent `UserError`.

### Solution

This PR addresses the issue at the type-conversion level, as suggested by reviewers.

The root cause of the TypeError is that torch.sym_float() (which is called by _maybe_convert_to_dtype during type promotion for aten.mul) lacks a conversion path for SymBool and incorrectly falls back to builtins.float(SymBool).

This fix addresses this by implementing the __sym_float__(self) method within the SymBool class (defined in torch/__init__.py).

The torch.sym_float(a) utility function is already designed to check for hasattr(a, "__sym_float__") before falling back to builtins.float(). By adding this method, SymBool instances now correctly advertise their ability to be cast to SymFloat. The new method implementation leverages self.node.sym_float() to correctly convert the symbolic boolean value to its symbolic float representation (0.0 or 1.0), resolving the TypeError at its source.

This approach is more fundamental than modifying a specific operation in builtin.py and ensures SymBool can be correctly promoted to SymFloat in any operation, while still preserving its boolean nature for control flow operations like guard_or_false (which is verified by a new test case).

### Verification

1.  **Bug Reproduced**: The initial `UserError: Could not guard on data-dependent expression` was successfully reproduced with the script from the issue. As shown below
<img width="1369" height="945" alt="Screenshot 2025-10-13 at 10 29 05" src="https://github.com/user-attachments/assets/8daa4555-3347-4af5-906a-02150b8df9d1" />

2.  **Fix Validated**: After applying the code changes, the same script now runs to completion, printing ` eager success` and ` compile success`. As shown below
<img width="1228" height="82" alt="Screenshot 2025-10-13 at 10 29 21" src="https://github.com/user-attachments/assets/94c4f143-b898-4dda-9bff-0ad5450a30fa" />

3. Added a new test class DynamoOpPromotionTests to test/dynamo/test_misc.py with three new test cases:
1. test_symbool_tensor_mul_does_not_fail: Verifies that the original bug report code (with .item() + *) no longer raises an error when compiled.
2. test_symbool_guard_or_false: Verifies that this fix does not cause a regression for guard_or_false(SymBool) (the concern raised by reviewers).
3. test_symbool_tensor_mul: Verifies the behavior of Tensor(bool) * Tensor(float) (without .item()) for completeness.
All new tests were added and pass locally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165264
Approved by: https://github.com/laithsakka, https://github.com/Lucaskabela
2025-11-04 06:33:20 +00:00
c7d00de115 [xpu][fix] Fix XPU oneDNN memory query bug: pointer to array (#166830)
# Motivation

I believe this is a bug - here's why:
In [dnnl_common_types.h](98132c4908/include/oneapi/dnnl/dnnl_common_types.h (L116-L125)) is defined as a pointer to an `int64_t[12]` array;
We can confirm this from the implementation in [memory_desc.cpp](98132c4908/src/common/memory_desc.cpp (L746-L748)) where the member indeed points to an internal array.

# Solution

Therefore, when accessing `md_padded_dims`, we should first dereference the pointer and then use it with an index - directly using it without dereferencing would corrupt memory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166830
Approved by: https://github.com/EikanWang
2025-11-04 06:12:40 +00:00
d3cf90ada5 Revert "[inductor] require shape in TritonCSEVariable (#162275)"
This reverts commit c21868b4359586550b12e1d9102283c792f45dff.

Reverted https://github.com/pytorch/pytorch/pull/162275 on behalf of https://github.com/izaitsevfb due to breaking test_rms_norm_bwd_float32_split_reductions_True_shape2 ([comment](https://github.com/pytorch/pytorch/pull/162275#issuecomment-3484049109))
2025-11-04 06:06:18 +00:00
0e1a88904f [Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#165036)
Make sure you're on cutlass 4.2.0+

Test Plan:
Tritonbench(oss):
`clear; CUDA_VISIBLE_DEVICES=2 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 python run.py --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_triton_grouped_mm --precision bf16  --num-inputs 1 --metrics tflops,accuracy`

Unit Tests(oss):
`clear; python test/inductor/test_cutedsl_grouped_mm.py`

Differential Revision: D82010227

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165036
Approved by: https://github.com/alexsamardzic, https://github.com/drisspg, https://github.com/mlazos
2025-11-04 05:58:58 +00:00
3232caa078 [XPU][Fix] Register convolution_overrideable for flops count (#166839)
Fixes #166838
1. Register `convolution_overrideable` key for flop_counter. CUDA relies on keys with `cudnn_convolution`. For devices like `XPU`, it falls to `convolution_overrideable`. Without the correct registration, the flop_couter will silently return 0 for XPU in line:
e1d011d6eb/torch/_inductor/analysis/profile_analysis.py (L178-L179)

2. Enable the tests when enabling the XPU on `test_analysis.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166839
Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/jansel
2025-11-04 05:56:29 +00:00
a6c6acea9d [11/N] Apply ruff UP035 rule (#166225)
This PR continues to apply ruff UP035 rule to inductor code. ruff UP035 rule aims to use Python 3.10 syntax and libraries.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166225
Approved by: https://github.com/aorenste
2025-11-04 04:53:40 +00:00
55be1cc739 [dynamo, 3.14] add explicit SymFloat int conversion (#166902)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166902
Approved by: https://github.com/malfet, https://github.com/pianpwk
ghstack dependencies: #166757, #166894, #166895
2025-11-04 04:38:03 +00:00
344cebda52 [dynamo, 3.14] disable cpython dynamo unittests if 3.14 (#166895)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166895
Approved by: https://github.com/guilhermeleobas
ghstack dependencies: #166757, #166894
2025-11-04 04:38:03 +00:00
ba72c6b981 [dynamo, 3.14] fix dynamo error message test for 3.14 (#166894)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166894
Approved by: https://github.com/malfet
ghstack dependencies: #166757
2025-11-04 04:38:03 +00:00
888efcc453 [dynamo, 3.14] support tracing type.__dict__[__annotations__].__get__ to trace through typing.get_type_hints (#166757)
This is covered by `test_get_type_hints` in test/dynamo/test_repros.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166757
Approved by: https://github.com/Lucaskabela
2025-11-04 04:38:03 +00:00
24aa9a2ef7 [ROCm][CI] Add distributed testing back to trunk.yml (#166915)
Adding distributed testing back to trunk since we have been observing [reasonable queueing](https://hud.pytorch.org/queue_time_analysis?dateRange=30&startDate=2025-10-05T01%3A44%3A55.924Z&endDate=2025-11-04T01%3A44%3A55.925Z&granularity=week&chartType=bar&repos=pytorch%2Fpytorch&category=machine_type&machineTypes=linux.rocm.gpu.gfx942.1&items=linux.rocm.gpu.gfx942.1) based on current MI3xx capacity.

Partially addresses https://github.com/pytorch/pytorch/issues/166108.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166915
Approved by: https://github.com/jeffdaily
2025-11-04 04:29:29 +00:00
f70faf2b9a [xpu][feature] Introduce PeerToPeerAccess API for XPU (#166424)
# Motivation
This PR introduces support for peer-to-peer (P2P) access between devices, including querying and enabling P2P connections between two devices.
It supports two categories of allocations:
- Regular allocations;
- Expandable segment allocations.

# Additional Context
The follow-up is that we should use this feature to optimize our copy kernel when P2P is supported.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166424
Approved by: https://github.com/gujinghui, https://github.com/albanD
ghstack dependencies: #166299, #166292
2025-11-04 04:03:28 +00:00
167e64ba1a [xpu][feature] Support expandable segment feature for XPU (#166292)
# Motivation
This PR intends to add expandable segment feature support on XPU. This will help
- Reduce memory fragmentation;
- Gradually map physical pages into virtual address space as needed.

# Additional Context
The traditional caching allocator frequently allocates and frees device memory blocks. However, over time, with varying tensor size, the device address space becomes fragmented. Even when there's enough total free memory, a lack of contiguous space can cause large allocations to fail.
The **expandable segment** feature addresses this by dynamically extending physical memory within a reserved virtual address range, reducing fragmentation and minimizing reallocation overhead.
The potential drawbacks are
- Virtual memory overhead;
- Potential page mapping overhead;
- Increased complexity.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166292
Approved by: https://github.com/albanD, https://github.com/EikanWang, https://github.com/gujinghui
ghstack dependencies: #166299
2025-11-04 04:03:28 +00:00
875b18d53c [xpu][feature] Introduce ExpandableSegment for XPU (#166299)
# Motivation
This PR intends to add `ExpandableSegment` struct, which is used to help support the expandable segment feature. I split it to a single PR to facilitate the code review.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166299
Approved by: https://github.com/EikanWang, https://github.com/albanD, https://github.com/gujinghui
2025-11-04 04:03:28 +00:00
eec3749c44 [DebugMode] .fwd_stack_trace for autograd bwd ops (#166842)
In #166440, didn't realize you could turn on anomaly mode while disabling NaN checks for these stacks. Adding them to `debug_mode.operators[*].fwd_stack_trace`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166842
Approved by: https://github.com/yushangdi, https://github.com/mikaylagawarecki
2025-11-04 03:28:43 +00:00
40133fe966 Fix MSCV C++ compilation error of pycore_stackref.h header (#165686)
Wraps the header in a C file and compile it using a C compiler, which should support designated initializers

Fix issue #160647

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165686
Approved by: https://github.com/williamwen42
2025-11-04 02:51:16 +00:00
f288433d3e [dynamo] Raise on as_python_constant error on getattr (#166909)
This ensures that we graph break at the right time, leading to the right
stack trace.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166909
Approved by: https://github.com/tugsbayasgalan
2025-11-04 02:45:59 +00:00
864633fca0 [xpu][test] Enable test_fxir_backend tests for XPU (#166493)
This PR enables `test_fxir_backend.py`'s tests formerly skipped xpu tests. No additional changes needed for the features.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166493
Approved by: https://github.com/angelayi, https://github.com/EikanWang
2025-11-04 02:14:46 +00:00
c21868b435 [inductor] require shape in TritonCSEVariable (#162275)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162275
Approved by: https://github.com/mlazos
ghstack dependencies: #164158
2025-11-04 02:13:41 +00:00
a0a8eca01a Fixes torch.compile(nn.ModuleList()) changes bool() behavior (#159208)
Fixes #159139
## The Cause

The bug occurs because the OptimizedModule wrapper in torch._dynamo.eval_frame doesn't call the len method. This causes Python's bool() check to fall back to the default object truthiness (always True) instead of correctly evaluating containers with len() == 0 as False.
## The Fix

A very easy fix . I just added the len method to OptimizedModule in torch._dynamo.eval_frame class to delegate the call to the original module
```python
def __len__(self):
    """
    Proxy the len() call to the original module to fix truthiness checks.
    """
    return len(self._orig_mod)
```
This successfully fixes the issue . The script now works as expected.
## Reproduction Script
```python
import torch
import torch.nn as nn

# Create an empty nn.ModuleList
original = nn.ModuleList()

# Compile it using torch.compile
compiled = torch.compile(original)

# Compare their boolean evaluations
print(f"bool(original): {bool(original)}")
print(f"bool(compiled): {bool(compiled)}")

# Trigger failure if they differ
assert bool(original) == bool(compiled), "BUG: truthiness behavior mismatch after compilation"
```
## Output

bool(original): False
bool(compiled): False

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159208
Approved by: https://github.com/Lucaskabela

Co-authored-by: pushkar-hue <pushkarsharma.rtm@gmail.com>
Co-authored-by: Lucas Kabela <lucasakabela@gmail.com>
2025-11-04 02:12:10 +00:00
0958f307d9 Add _heapq polyfill (#161093)
----

* Redirect `_heapq.*` functions to the python implementation
* Handle TypeError in PolyfilledFunctionVariable to raise observed exceptions
* Implement `__next__` method in IteratorVariable class

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161093
Approved by: https://github.com/Lucaskabela
2025-11-04 02:11:33 +00:00
7551507c41 [BE][Typing][Dynamo] Type torch/_dynamo/variables/builtin.py (#166745)
Provides type coverage to torch/_dynamo/variables/builtin.py

### Coverage report:
`mypy torch/_dynamo/variables/builtin.py --linecount-report /tmp/coverage_log`
Compare before to after - we go from 2213 lines and 64 funcs covered to 3212 lines and 85 funcs covered

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166745
Approved by: https://github.com/williamwen42
2025-11-04 01:33:10 +00:00
f92834d477 Fix unused assignments (#166791)
This PR cleans up unused assignments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166791
Approved by: https://github.com/xmfan
2025-11-04 01:07:19 +00:00
e1fc01bef8 Enable clang-tidy on some excluded headers (#166835)
This PR enables clang-tidy on some excluded headers.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166835
Approved by: https://github.com/Skylion007
2025-11-04 00:37:32 +00:00
22a745737a Remove ifndef C10_MOBILE around aoti_torch_abi_version impl (#166882)
See if after the headeronly migration the mobile build would still fail.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166882
Approved by: https://github.com/mikaylagawarecki
2025-11-04 00:37:22 +00:00
ee708ea96c fix test_type_hints (#163150)
Fixes #163149

### Summary:
Fixes mypy type checking failures in `test_type_hints` by consolidating typing imports and eliminating duplicate/conflicting import patterns that caused mypy to fail resolving type annotations.

### Impact:

- `test_type_hints` works fine now
- module: tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163150
Approved by: https://github.com/Skylion007
2025-11-04 00:29:22 +00:00
64819e3701 [Pytorch] Improve conversion from bf16 on aarch64/NEON (#166880)
Summary:
Conversion from/to bfloat16 was not getting covered by conversion templates, because these used bfloat16_t as data type instead of the custom c10::BFloat16

Conversion by casting from/to bfloat16_t is broken in clang-[17, 20], fixed in clang-21.
Because Pytorch does not currently have CI running binaries compiled using clang-21, we won't implement this approach for now.

We are currently only adding conversion from bfloat16, as it can be implementing by zero-extending into a 4-byte float.

We've observed the following performance improvements, when compiling with clang-19 and targeting armv9a+sve2:

Before:

bfloat16_t->uint8  ===> 423.583us
bfloat16_t->int8  ===> 424.090us
bfloat16_t->int16  ===> 430.817us
bfloat16_t->int64  ===> 571.547us
bfloat16_t->double ===> 459.089us

After:

bfloat16_t->uint8  ===> 123.783us  ----> 342% higher throughput
bfloat16_t->int8  ===> 131.575us  -----> 322% higher throughput
bfloat16_t->int16  ===> 136.794us  ----> 315% higher throughput
bfloat16_t->int64  ===> 177.699us  ----> 322% higher throughput
bfloat16_t->double  ===> 165.556us  ---> 277% higher throughput

Test Plan:
Correctness:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Performance:
buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test

Differential Revision: D86119613

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166880
Approved by: https://github.com/mcfi, https://github.com/aditew01
2025-11-04 00:19:42 +00:00
79ff2c66c8 Revert "Fix unused assignments (#166791)"
This reverts commit 5125872aeb737fa20ea2ec08338e9342cba694e7.

Reverted https://github.com/pytorch/pytorch/pull/166791 on behalf of https://github.com/cyyever due to incomplete PR ([comment](https://github.com/pytorch/pytorch/pull/166791#issuecomment-3483116247))
2025-11-04 00:13:50 +00:00
665a411351 Revert "[CUDA] Skip pynvml test on platforms that don't have complete support (#159689)"
This reverts commit 68e31e2f814f9f6a9fb87381367e6b33e17c1c2b.

Reverted https://github.com/pytorch/pytorch/pull/159689 on behalf of https://github.com/izaitsevfb due to breaking internal tests [D86127316] ([comment](https://github.com/pytorch/pytorch/pull/159689#issuecomment-3483095879))
2025-11-04 00:10:14 +00:00
5c89bdb461 [MPS] Fix smooth_l1_loss backward for fp16 (#166687)
- Enable fp16 implementation for CPU, by using `convert_to_float` primitives instead of `convert_bfloat16_float` and extending bf16 implementation to half
- Simplify OpInfo definitions for the backward

Originally PR used `AT_DISPATCH_ALL_TYPES_AND(kHalf,`, but it cause ICE with gcc-13 when compiled with SVE128:
```
/opt/rh/gcc-toolset-13/root/usr/bin/c++ -DAT_BUILD_ARM_VEC256_WITH_SLEEF -DAT_PER_OPERATOR_HEADERS -DBUILD_ONEDNN_GRAPH -DCAFFE2_BUILD_MAIN_LIB -DCAFFE2_PERF_WITH_SVE=1 -DCPUINFO_SUPPORTED_PLATFORM=1 -DENABLE_IPC_FABRIC -DFMT_HEADER_ONLY=1 -DFXDIV_USE_INLINE_ASSEMBLY=0 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_POSIX_FALLOCATE=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DKINETO_NAMESPACE=libkineto -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DNNP_CONVOLUTION_ONLY=0 -DNNP_INFERENCE_ONLY=0 -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DUSE_C10D_GLOO -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_MIMALLOC -DUSE_RPC -DUSE_TENSORPIPE -DXNN_LOG_LEVEL=0 -D_FILE_OFFSET_BITS=64 -Dtorch_cpu_EXPORTS -I/pytorch/build/aten/src -I/pytorch/aten/src -I/pytorch/build -I/pytorch -I/pytorch/nlohmann -I/pytorch/moodycamel -I/pytorch/third_party/mimalloc/include -I/pytorch/torch/csrc/api -I/pytorch/torch/csrc/api/include -I/pytorch/caffe2/aten/src/TH -I/pytorch/build/caffe2/aten/src/TH -I/pytorch/build/caffe2/aten/src -I/acl -I/acl/include -I/pytorch/build/caffe2/../aten/src -I/pytorch/torch/csrc -I/pytorch/torch/headeronly -I/pytorch/third_party/miniz-3.0.2 -I/pytorch/third_party/kineto/libkineto/include -I/pytorch/third_party/kineto/libkineto/src -I/pytorch/third_party/cpp-httplib -I/pytorch/aten/src/ATen/.. -I/pytorch/third_party/FXdiv/include -I/pytorch/c10/.. -I/pytorch/third_party/pthreadpool/include -I/pytorch/third_party/cpuinfo/include -I/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/include -I/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src -I/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/include -I/pytorch/third_party/NNPACK/include -I/pytorch/third_party/FP16/include -I/pytorch/third_party/tensorpipe -I/pytorch/build/third_party/tensorpipe -I/pytorch/third_party/tensorpipe/third_party/libnop/include -I/pytorch/third_party/kleidiai -I/pytorch/third_party/fmt/include -I/pytorch/build/third_party/ideep/mkl-dnn/include -I/pytorch/third_party/ideep/mkl-dnn/src/../include -I/pytorch/third_party/onnx -I/pytorch/build/third_party/onnx -I/pytorch/third_party/flatbuffers/include -isystem /pytorch/build/third_party/gloo -isystem /pytorch/cmake/../third_party/gloo -isystem /pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /pytorch/third_party/protobuf/src -isystem /opt/OpenBLAS/include -isystem /pytorch/third_party/XNNPACK/include -isystem /pytorch/cmake/../third_party/eigen -isystem /pytorch/third_party/ideep/mkl-dnn/include/oneapi/dnnl -isystem /pytorch/third_party/ideep/include -isystem /pytorch/INTERFACE -isystem /pytorch/third_party/nlohmann/include -isystem /pytorch/third_party/concurrentqueue -isystem /pytorch/build/include -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_PYTORCH_QNNPACK -DAT_BUILD_ARM_VEC256_WITH_SLEEF -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -DC10_NODEPRECATED -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -faligned-new -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-dangling-reference -Wno-error=dangling-reference -Wno-stringop-overflow -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION -DHAVE_ARM_BF16_CPU_DEFINITION -O3 -DNDEBUG -DNDEBUG -fPIC -fdiagnostics-color=always -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -D__NEON__ -DBLAS_HAS_SBGEMM -Wall -Wextra -Wdeprecated -Wunused -Wno-unused-parameter -Wno-missing-field-initializers -Wno-array-bounds -Wno-unknown-pragmas -Wno-strict-overflow -Wno-strict-aliasing -Wredundant-move -Wno-interference-size -Wno-maybe-uninitialized -fvisibility=hidden -pthread -fopenmp -O3  -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256 -DCPU_CAPABILITY=SVE256 -DCPU_CAPABILITY_SVE256 -MD -MT caffe2/CMakeFiles/torch_cpu.dir/__/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp.SVE256.cpp.o -MF caffe2/CMakeFiles/torch_cpu.dir/__/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp.SVE256.cpp.o.d -o caffe2/CMakeFiles/torch_cpu.dir/__/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp.SVE256.cpp.o -c /pytorch/build/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp.SVE256.cpp
during RTL pass: expand
In file included from /pytorch/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp:6,
                 from /pytorch/build/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp.SVE256.cpp:1:
/pytorch/aten/src/ATen/native/cpu/Loops.h: In function ‘void at::native::SVE256::vectorized_loop(char**, int64_t, int64_t, func_t&&, vec_func_t&&) [with func_t = at::native::{anonymous}::smooth_l1_backward_cpu_kernel(at::TensorIterator&, const c10::Scalar&, double)::<lambda()>::<lambda()>::<lambda(scalar_t, scalar_t, scalar_t)>&; vec_func_t = at::native::{anonymous}::smooth_l1_backward_cpu_kernel(at::TensorIterator&, const c10::Scalar&, double)::<lambda()>::<lambda()>::<lambda(at::vec::SVE256::Vectorized<c10::Half>, at::vec::SVE256::Vectorized<c10::Half>, at::vec::SVE256::Vectorized<c10::Half>)>&]’:
/pytorch/aten/src/ATen/native/cpu/Loops.h:200:1: internal compiler error: in expand_insn, at optabs.cc:8185
  200 | vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) {
      | ^~~~~~~~~~~~~~~
Please submit a full bug report, with preprocessed source.
See <http://bugzilla.redhat.com/bugzilla> for instructions.
Preprocessed source stored into /tmp/ccgYMlTo.out file, please attach this to your bugreport.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166687
Approved by: https://github.com/Skylion007
2025-11-03 23:54:54 +00:00
7b64ad906c [FSDP][Replicate] got rid of reshard_after_forward and updated test cases (#166469)
**Summary:** I have gotten of reshard_after_forward and shard_placement as inputs for replicate as there will be no sharding. I have also updated all the necessary tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166469
Approved by: https://github.com/weifengpy
ghstack dependencies: #166433, #166459
2025-11-03 23:48:18 +00:00
d944279def [FSDP][Replicate] added two replicate overload declarations and changed device_mesh to mesh (#166459)
**Summary:** Just like in fully_shard, I added two overload replicate functions. The `@overload` declarations are necessary because the `@contract` decorator uses `ParamSpec` to capture function parameters, which creates a generic `_ContractFn` protocol signature (`*args: _P.args, **kwargs: _P.kwargs`) that Pyrefly cannot properly type-check when calling the function with explicit keyword arguments. In addition, to make the api cleaner I changed device_mesh input argument to mesh to match fully_shard formatting.

**Test Cases**
1.  pytest test/distributed/_composable/test_replicate_with_fsdp.py
2. pytest test/distributed/_composable/test_replicate_training.py
3. pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166459
Approved by: https://github.com/weifengpy
ghstack dependencies: #166433
2025-11-03 23:35:21 +00:00
5048e4701d explicitly remove call_mod_node_to_replace after inlining the submodule in const_fold._inline_module` (#166871)
Summary:
https://github.com/pytorch/pytorch/pull/166609 updated `is_impure` check to now check ops inside a subgraph to decide whether a `call_module` node is pure or not.

This change of behavior affects dead code elimination, commonly run as `gm.graph.eliminate_dead_code()`. Specifically, dead code elimination will not erase a node that has no users if this node has side effect or is impure. With above mentioned pr, dead code elimination no longer eliminates unused subgraphs that contain side-effectful ops.

This affects `const_fold.split_const_subgraph`, what this function does is:
1. split a graph into two submodules, one containing all const ops and one containing non-const ops
2. inline the submodule containing non-const ops back to main graph.
3. run dead code elimination to remove the unused non-const submodule.

With pr #166609 step 3 no longer erases the unused module. As an example, exported graph
```
 graph():
    %x : [num_users=2] = placeholder[target=x]
    %_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
    %empty_permuted : [num_users=1] = call_function[target=torch.ops.aten.empty_permuted.default](args = ([5, 10], [0, 1]), kwargs = {device: cpu, pin_memory: False})
    %bernoulli : [num_users=1] = call_function[target=torch.ops.aten.bernoulli.p](args = (%empty_permuted, 0.6), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, %bernoulli), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%mul, 0.6), kwargs = {})
    return (div,)
```
After running const_fold, empty_permuted is const-folded, the rest of ops are not, and the main graph looks like
```
graph():
    %x : [num_users=3] = placeholder[target=x]
    %_fx_const_folded_attrs : [num_users=2] = get_attr[target=_FX_CONST_FOLDED_ATTRS]
    %_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
    %bernoulli_p : [num_users=1] = call_function[target=torch.ops.aten.bernoulli.p](args = (%_fx_const_folded_attrs, 0.6), kwargs = {})
    %mul_tensor : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, %bernoulli_p), kwargs = {})
    %div_tensor : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%mul_tensor, 0.6), kwargs = {})
    %submod_1 : [num_users=0] = call_module[target=submod_1](args = (%x, %_fx_const_folded_attrs), kwargs = {})
    return (div_tensor,)
```
`submod_1` is dangling, unused, and just inlined into the graph.

## Fix
This pr updates `const_fold._inline_module` function to explicitly remove the non-const submodule which is unused, after it has inlined the submodule's ops into main graph.

Test Plan:
Added a test in `test_fx_const_fold.py`.

The test would have failed before this PR becuase it yields above example graph leaving an unused `call_module[target=submod_1]` op.

With the PR, the module is erased from main graph correctly.

Differential Revision: D86056354

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166871
Approved by: https://github.com/blaine-rister, https://github.com/mlazos
2025-11-03 23:23:10 +00:00
616314cfd5 [FSDP][Replicate] final version integrating 1D device mesh replicate into fsdp (#166433)
**Summary:** I have created a new composable replicate api that's integrated into FSDP's codebase with minimal changes. The key changes I made are when we use DDPMeshInfo, we use Replicate placements, prevent initial sharding of parameters, set worldsize to 1 to skip allgathers and reducescatter.

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_training.py
2. pytest test_pp_composability.py
3. pytest test_replicate_with_fsdp.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166433
Approved by: https://github.com/weifengpy
2025-11-03 23:20:23 +00:00
2b7e4c3ef2 [DCP] Add option to use PrefixStore to create checkpoint background process (#166560)
Summary:
DCP checkpoint background process currently determines the port used for pg via get_free_port().

During checkpoint background process initialization, gloo pg init occasionally times out on the first call but succeeds in a subsequent call.

We hypothesized that the timeouts are related to the port being used, and the solution would be to create the pg with PrefixStore and reuse the master port.

This diff adds the option for checkpoint background process to use PrefixStore with MASTER_ADDR + MASTER_PORT.

The default behavior is unchanged. Enabling the new PrefixStore behavior requires setting "DCP_USE_PREFIX_STORE" env var to "1".

context:
 https://fb.workplace.com/groups/319878845696681/permalink/1516883985996155/

Differential Revision: D84928180

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166560
Approved by: https://github.com/meetv18
2025-11-03 23:08:12 +00:00
6c98657239 Add some Triton related suppressions that don't show on CI (#166868)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166868
Approved by: https://github.com/maggiemoss, https://github.com/zou3519
2025-11-03 22:54:50 +00:00
86b2d82e84 Revert "[Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer (#166165)"
This reverts commit 94f2657c4b534136aa8958bc35d44ceac5ccd60c.

Reverted https://github.com/pytorch/pytorch/pull/166165 on behalf of https://github.com/izaitsevfb due to breaks test_LinearAndSoftmax_codegen test ([comment](https://github.com/pytorch/pytorch/pull/166165#issuecomment-3482926991))
2025-11-03 22:52:41 +00:00
d6b27c4cef upgrade to oneDNN v3.10 2025-10-22 14:10:01 +00:00
457 changed files with 12989 additions and 3153 deletions

View File

@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8
ENV LANG en_US.UTF-8
ENV LANGUAGE en_US.UTF-8
ARG DEVTOOLSET_VERSION=11
ARG DEVTOOLSET_VERSION=13
RUN yum -y update
RUN yum -y install epel-release
# install glibc-langpack-en make sure en_US.UTF-8 locale is available
RUN yum -y install glibc-langpack-en
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
# Just add everything as a safe.directory for git since these will be used in multiple places with git
RUN git config --global --add safe.directory '*'
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
@ -41,6 +41,7 @@ RUN bash ./install_conda.sh && rm install_conda.sh
# Install CUDA
FROM base as cuda
ARG CUDA_VERSION=12.6
ARG DEVTOOLSET_VERSION=13
RUN rm -rf /usr/local/cuda-*
ADD ./common/install_cuda.sh install_cuda.sh
COPY ./common/install_nccl.sh install_nccl.sh
@ -50,7 +51,8 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
# Preserve CUDA_VERSION for the builds
ENV CUDA_VERSION=${CUDA_VERSION}
# Make things in our path by default
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
FROM cuda as cuda12.6
RUN bash ./install_cuda.sh 12.6
@ -68,8 +70,22 @@ FROM cuda as cuda13.0
RUN bash ./install_cuda.sh 13.0
ENV DESIRED_CUDA=13.0
FROM ${ROCM_IMAGE} as rocm
FROM ${ROCM_IMAGE} as rocm_base
ARG DEVTOOLSET_VERSION=13
ENV LC_ALL en_US.UTF-8
ENV LANG en_US.UTF-8
ENV LANGUAGE en_US.UTF-8
# Install devtoolset on ROCm base image
RUN yum -y update && \
yum -y install epel-release && \
yum -y install glibc-langpack-en && \
yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
RUN git config --global --add safe.directory '*'
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
FROM rocm_base as rocm
ARG PYTORCH_ROCM_ARCH
ARG DEVTOOLSET_VERSION=13
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
ADD ./common/install_mkl.sh install_mkl.sh
RUN bash ./install_mkl.sh && rm install_mkl.sh
@ -88,6 +104,7 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0
# Final step
FROM ${BASE_TARGET} as final
ARG DEVTOOLSET_VERSION=13
COPY --from=openssl /opt/openssl /opt/openssl
COPY --from=patchelf /patchelf /usr/local/bin/patchelf
COPY --from=conda /opt/conda /opt/conda

View File

@ -63,7 +63,7 @@ docker build \
--target final \
--progress plain \
--build-arg "BASE_TARGET=${BASE_TARGET}" \
--build-arg "DEVTOOLSET_VERSION=11" \
--build-arg "DEVTOOLSET_VERSION=13" \
${EXTRA_BUILD_ARGS} \
-t ${tmp_tag} \
$@ \

View File

@ -261,9 +261,9 @@ case "$tag" in
PYTHON_VERSION=3.10
CUDA_VERSION=12.8.1
;;
pytorch-linux-jammy-aarch64-py3.10-gcc11)
pytorch-linux-jammy-aarch64-py3.10-gcc13)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11
GCC_VERSION=13
ACL=yes
VISION=yes
OPENBLAS=yes
@ -271,9 +271,19 @@ case "$tag" in
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
pytorch-linux-jammy-aarch64-py3.10-clang21)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11
CLANG_VERSION=21
ACL=yes
VISION=yes
OPENBLAS=yes
# snadampal: skipping llvm src build install because the current version
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=13
ACL=yes
VISION=yes
OPENBLAS=yes

View File

@ -1 +1 @@
7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd
bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7

View File

@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
# work around ubuntu apt-get conflicts
sudo apt-get -y -f install
wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
if [[ $CLANG_VERSION == 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
if [[ $CLANG_VERSION -ge 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
fi
fi

View File

@ -7,11 +7,11 @@ if [ -n "$GCC_VERSION" ]; then
# Need the official toolchain repo to get alternate packages
add-apt-repository ppa:ubuntu-toolchain-r/test
apt-get update
apt-get install -y g++-$GCC_VERSION
apt-get install -y g++-$GCC_VERSION gfortran-$GCC_VERSION
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50
update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50
update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50
update-alternatives --install /usr/bin/gfortran gfortran /usr/bin/gfortran-"$GCC_VERSION" 50
# Cleanup package manager
apt-get autoclean && apt-get clean

View File

@ -10,6 +10,7 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
OPENBLAS_BUILD_FLAGS="
CC=gcc
NUM_THREADS=128
USE_OPENMP=1
NO_SHARED=0

View File

@ -1 +1 @@
3.5.0
3.5.1

View File

@ -6,8 +6,8 @@ set -eou pipefail
# The script expects DESIRED_CUDA and PACKAGE_NAME to be set
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
# post merge of https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
# https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
# Folders for the build
PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata
@ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE
# Fetch magma sources and verify checksum
pushd ${PACKAGE_DIR}
git clone https://github.com/icl-utk-edu/magma
git clone https://github.com/jeffdaily/magma
pushd magma
git checkout ${MAGMA_VERSION}
popd

View File

@ -1,11 +1,11 @@
name: 🚀 Release highlight for proposed Feature
name: 🚀 New Feature for Release
description: Submit a Release highlight for proposed Feature
labels: ["release-feature-request"]
body:
- type: textarea
attributes:
label: Release highlight for proposed Feature
label: New Feature for Release
description: >
Example: “A torch.special module, analogous to SciPy's special module.”
- type: input

View File

@ -38,9 +38,9 @@ runs:
run: |
python3 .github/scripts/pytest_cache.py \
--download \
--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
--pr_identifier $GITHUB_REF \
--job_identifier $JOB_IDENTIFIER \
--temp_dir $RUNNER_TEMP \
--repo $REPO \
--bucket $BUCKET \
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
--pr_identifier "$GITHUB_REF" \
--job_identifier "$JOB_IDENTIFIER" \
--temp_dir "$RUNNER_TEMP" \
--repo "$REPO" \
--bucket "$BUCKET" \

View File

@ -47,11 +47,11 @@ runs:
run: |
python3 .github/scripts/pytest_cache.py \
--upload \
--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
--pr_identifier $GITHUB_REF \
--job_identifier $JOB_IDENTIFIER \
--sha $SHA \
--test_config $TEST_CONFIG \
--shard $SHARD \
--repo $REPO \
--temp_dir $RUNNER_TEMP \
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
--pr_identifier "$GITHUB_REF" \
--job_identifier "$JOB_IDENTIFIER" \
--sha "$SHA" \
--test_config "$TEST_CONFIG" \
--shard "$SHARD" \
--repo "$REPO" \
--temp_dir "$RUNNER_TEMP" \

View File

@ -1 +1 @@
3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2
ad5816f0eee1c873df1b7d371c69f1f811a89387

125
.github/copilot-instructions.md vendored Normal file
View File

@ -0,0 +1,125 @@
# PyTorch Copilot Instructions
This is the PyTorch machine learning framework codebase. These instructions help AI agents navigate and contribute effectively.
## Architecture Overview
### Core Components
- **c10/** - Core library (C++-10 compatible) for essential, binary-size-conscious functionality
- **aten/** - ATen tensor library (C++), PyTorch's foundation without autograd
- `aten/src/ATen/native/` - Modern operator implementations (CPU/CUDA/MPS/sparse)
- `aten/src/ATen/native/native_functions.yaml` - **Critical**: Declarative operator registry
- **torch/** - Python bindings and public API
- `torch/csrc/` - C++ Python bindings (hand-written and generated)
- `torch/csrc/autograd/` - Reverse-mode automatic differentiation
- `torch/csrc/jit/` - TorchScript JIT compiler
- **torchgen/** - Code generation tooling that reads `native_functions.yaml`
- **tools/** - Build scripts, autograd derivatives, code generation
### The Code Generation Workflow
**Most operator changes require editing `native_functions.yaml`**, not direct C++ files. This YAML file:
1. Declares operator signatures, variants (function/method), and dispatch behavior
2. Gets processed by `torchgen/` to generate C++/Python bindings
3. Produces headers in `build/aten/src/ATen/` during compilation
Example entry structure:
```yaml
- func: my_op(Tensor self, Scalar alpha=1) -> Tensor
variants: function, method
dispatch:
CPU: my_op_cpu
CUDA: my_op_cuda
```
After editing `native_functions.yaml`, implement kernels in `aten/src/ATen/native/` (see `aten/src/ATen/native/README.md`).
## Development Workflows
### Building from Source
**Never run `setup.py` directly** - use pip with editable install:
```bash
python -m pip install --no-build-isolation -v -e .
```
Speed up builds:
- `DEBUG=1` - Debug symbols with `-g -O0`
- `USE_CUDA=0` - Skip CUDA compilation
- `BUILD_TEST=0` - Skip C++ test binaries
- Install `ninja` (`pip install ninja`) for faster builds
- Use `ccache` for incremental compilation caching
Rebuild specific targets: `(cd build && ninja <target>)`
### Testing
**Critical**: DO NOT run entire test suites. Run specific tests only:
```bash
python test/test_torch.py TestTorch.test_specific_case
```
**Test structure**: All tests use `torch.testing._internal.common_utils`:
```python
from torch.testing._internal.common_utils import run_tests, TestCase
class TestFeature(TestCase):
def test_something(self):
# Use self.assertEqual for tensor comparisons
pass
if __name__ == "__main__":
run_tests()
```
**For bug fixes**: Create a standalone reproduction script first, verify it fails, then fix and add to appropriate test file.
### Linting
Run linter (not pre-commit): `lintrunner -a` (auto-applies fixes)
## Project-Specific Conventions
### Memory and Storage
- **Storage is never nullptr** (but `StorageImpl.data` may be nullptr for unallocated outputs)
- CUDA device info lives in storage objects
### Python-C++ Integration (`torch/csrc/`)
- Always include `Python.h` **first** to avoid `_XOPEN_SOURCE` redefinition errors
- Use `pybind11::gil_scoped_acquire` before calling Python API or using `THPObjectPtr`
- Wrap entry points with `HANDLE_TH_ERRORS` / `END_HANDLE_TH_ERRORS` for exception conversion
### Dispatch System
- PyTorch uses operator dispatch to route calls to backend-specific kernels
- Prefer `CompositeExplicitAutograd` dispatch when writing device-agnostic compound ops
- See `aten/src/ATen/native/README.md` for dispatch keyword guidance
## Git Workflow (AI Agent Specific)
When preparing PRs from this environment:
```bash
git stash -u
git reset --hard $(cat /tmp/orig_work.txt) # Reset to LOCAL branch
git stash pop
# Resolve conflicts if necessary
```
## Common Gotchas
1. **Editing generated files** - If it's in `build/`, don't edit it. Edit the source template or `native_functions.yaml`
2. **NVCC template compilation** - NVCC is stricter about C++ than gcc/clang; code working on Linux may fail Windows CI
3. **Windows symbol visibility** - Use `TORCH_API` macros for exported symbols (required on Windows, optional on Linux)
4. **No internet access** - DO NOT attempt to install dependencies during development
## Key Files Reference
- `AGENTS.md` - Instructions specific to AI coding agents
- `CONTRIBUTING.md` - Comprehensive human contributor guide
- `GLOSSARY.md` - Terminology (ATen, kernels, operations, JIT, TorchScript)
- `aten/src/ATen/native/README.md` - Operator implementation guide
- `tools/autograd/derivatives.yaml` - Gradient definitions for autograd
## Performance Debugging
Use `TORCH_SHOW_CPP_STACKTRACES=1` for C++ traces in Python errors. For profiling, prefer `py-spy` over manual instrumentation.

View File

@ -28,7 +28,7 @@ CUDA_ARCHES_FULL_VERSION = {
"12.6": "12.6.3",
"12.8": "12.8.1",
"12.9": "12.9.1",
"13.0": "13.0.2",
"13.0": "13.0.0",
}
CUDA_ARCHES_CUDNN_VERSION = {
"12.6": "9",

View File

@ -97,8 +97,8 @@ jobs:
shell: bash
run: |
ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx')
if [[ $ngpu -lt 4 ]]; then
echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs"
if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus.
echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs"
exit 1
fi

View File

@ -344,5 +344,21 @@ jobs:
if-no-files-found: ignore
path: ./**/core.[1-9]*
- name: Authenticate with AWS
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results
# The max duration enforced by the server side
role-duration-seconds: 18000
aws-region: us-east-1
- name: Upload the benchmark results
uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main
with:
benchmark-results-dir: test/test-reports
dry-run: false
schema-version: v3
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Teardown XPU
uses: ./.github/actions/teardown-xpu

View File

@ -77,9 +77,11 @@ jobs:
pytorch-linux-noble-riscv64-py3.12-gcc14
]
include:
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks
runner: linux.arm64.m7g.4xlarge
timeout-minutes: 600
# Docker uploads fail from LF runners, see https://github.com/pytorch/pytorch/pull/137358

View File

@ -8,6 +8,7 @@ on:
- docker.Makefile
- .github/workflows/docker-release.yml
- .github/scripts/generate_docker_release_matrix.py
- .github/scripts/generate_binary_build_matrix.py
push:
branches:
- nightly

View File

@ -72,7 +72,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.arm64.m7g.4xlarge
build-environment: linux-jammy-aarch64-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_cpu_aarch64", shard: 1, num_shards: 9, runner: "linux.arm64.m7g.metal" },

View File

@ -115,10 +115,10 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
]}
secrets: inherit

View File

@ -84,13 +84,13 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
]}
build-additional-packages: "vision audio torchao"

View File

@ -33,7 +33,7 @@ jobs:
with:
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-jammy-aarch64-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13
runner: linux.arm64.m7g.4xlarge
test-matrix: |
{ include: [

View File

@ -60,7 +60,7 @@ jobs:
with:
build-environment: linux-jammy-aarch64-py3.10
runner: linux.arm64.m7g.4xlarge
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13
test-matrix: |
{ include: [
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },

View File

@ -204,6 +204,7 @@ jobs:
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.4" },
]}
secrets: inherit
@ -221,7 +222,7 @@ jobs:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl"
secrets: inherit
inductor-build:

View File

@ -211,7 +211,6 @@ exclude_patterns = [
'**/*pb.h',
'**/*inl.h',
'aten/src/ATen/cpu/FlushDenormal.cpp',
'aten/src/ATen/cpu/Utils.cpp',
'aten/src/ATen/cpu/vml.h',
'aten/src/ATen/CPUFixedAllocator.h',
'aten/src/ATen/Parallel*.h',
@ -230,8 +229,6 @@ exclude_patterns = [
'c10/util/win32-headers.h',
'c10/test/**/*.h',
'third_party/**/*',
'torch/csrc/api/include/torch/nn/modules/common.h',
'torch/csrc/api/include/torch/linalg.h',
'torch/csrc/autograd/generated/**',
'torch/csrc/distributed/**/*.cu',
'torch/csrc/distributed/c10d/WinSockUtils.hpp',
@ -243,7 +240,6 @@ exclude_patterns = [
'torch/csrc/utils/generated_serialization_types.h',
'torch/csrc/utils/pythoncapi_compat.h',
'torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h',
'aten/src/ATen/ExpandBase.h',
]
init_command = [
'python3',

View File

@ -234,7 +234,17 @@ option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON)
option(USE_ASAN "Use Address+Undefined Sanitizers" OFF)
option(USE_LSAN "Use Leak Sanitizer" OFF)
option(USE_TSAN "Use Thread Sanitizer" OFF)
# Track whether USE_CUDA was explicitly set by the user (before option() is called)
# If USE_CUDA is already defined in cache, it means user explicitly set it
if(DEFINED CACHE{USE_CUDA})
set(_USE_CUDA_EXPLICITLY_SET TRUE)
else()
set(_USE_CUDA_EXPLICITLY_SET FALSE)
endif()
option(USE_CUDA "Use CUDA" ON)
option(USE_XPU "Use XPU" ON)
cmake_dependent_option(
BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON

View File

@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
- [Python Unit Testing](#python-unit-testing)
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
- [Local linting](#local-linting)
- [Running `mypy`](#running-mypy)
- [Running `pyrefly`](#running-pyrefly)
- [C++ Unit Testing](#c-unit-testing)
- [Run Specific CI Jobs](#run-specific-ci-jobs)
- [Merging your Change](#merging-your-change)
@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
**Prerequisites**:
The following packages should be installed with `pip`:
- `expecttest` and `hypothesis` - required to run tests
- `mypy` - recommended for linting
- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
- `pytest` - recommended to run tests more selectively
Running
```
@ -350,15 +350,32 @@ make lint
Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)
#### Running `mypy`
#### Running `pyrefly`
`mypy` is an optional static type checker for Python. We have multiple `mypy`
configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
**Getting Started with Pyrefly:**
To run type checking on the PyTorch codebase:
```bash
pyrefly check
```
For more detailed error information with summaries:
```bash
pyrefly check --summarize-errors
```
**Learn More:**
- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
See [Guide for adding type annotations to
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
for more information on how to set up `mypy` and tackle type annotation
tasks.
for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
### C++ Unit Testing

View File

@ -1,7 +1,7 @@
# Security Policy
- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
- [**Using Pytorch Securely**](#using-pytorch-securely)
- [**Using PyTorch Securely**](#using-pytorch-securely)
- [Untrusted models](#untrusted-models)
- [TorchScript models](#torchscript-models)
- [Untrusted inputs](#untrusted-inputs)
@ -10,28 +10,28 @@
- [**CI/CD security principles**](#cicd-security-principles)
## Reporting Security Issues
Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.
Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new
All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
https://www.facebook.com/whitehat
## Using Pytorch Securely
**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
## Using PyTorch Securely
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
### Untrusted models
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].
**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.
@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
### TorchScript models
TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
### Untrusted inputs during training and prediction
@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some
### Data privacy
**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).
**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
### Using distributed features

View File

@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")

View File

@ -23,8 +23,6 @@ C10_DIAGNOSTIC_POP()
#endif
namespace at {
namespace {
/*
These const variables defined the fp32 precisions for different backend
We have "generic", "cuda", "mkldnn" backend now and we can choose fp32
@ -41,16 +39,6 @@ namespace {
->rnn
*/
C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
TORCH_WARN_ONCE(
"Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' "
"or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, "
"torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see "
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
);
}
} // namespace
Float32Backend str2backend(const std::string& name) {
if (name == "generic")
return Float32Backend::GENERIC;
@ -206,7 +194,6 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
} else {
return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
}
warn_deprecated_fp32_precision_api();
return allow_tf32_cudnn;
}
@ -214,7 +201,6 @@ void Context::setAllowTF32CuDNN(bool b) {
setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
allow_tf32_cudnn = b;
warn_deprecated_fp32_precision_api();
}
void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
@ -325,7 +311,6 @@ bool Context::allowTF32CuBLAS() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
"We suggest only using the new API to set the TF32 flag. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return allow_tf32_new;
}
@ -349,7 +334,6 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
"We suggest only using the new API for matmul precision. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return float32_matmul_precision;
}
@ -377,7 +361,6 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op)
void Context::setFloat32MatmulPrecision(const std::string &s) {
auto match = [this](const std::string & s_) {
warn_deprecated_fp32_precision_api();
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
if (s_ == "highest") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;

View File

@ -191,7 +191,7 @@ class Vectorized<BFloat16> {
auto vals = svreinterpret_u16_bf16(values);
vals = sveor_u16_x(ptrue, vals, mask);
return svreinterpret_bf16_u16(vals);
};
}
Vectorized<BFloat16> round() const;
Vectorized<BFloat16> tan() const;
Vectorized<BFloat16> tanh() const;
@ -349,47 +349,47 @@ Vectorized<BFloat16> inline Vectorized<BFloat16>::frac() const {
return convert_float_bfloat16(v1, v2); \
}
DEFINE_BF16_FUNC_VIA_FLOAT(isnan);
DEFINE_BF16_FUNC_VIA_FLOAT(angle);
DEFINE_BF16_FUNC_VIA_FLOAT(acos);
DEFINE_BF16_FUNC_VIA_FLOAT(acosh);
DEFINE_BF16_FUNC_VIA_FLOAT(asin);
DEFINE_BF16_FUNC_VIA_FLOAT(atan);
DEFINE_BF16_FUNC_VIA_FLOAT(atanh);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign);
DEFINE_BF16_FUNC_VIA_FLOAT(erf);
DEFINE_BF16_FUNC_VIA_FLOAT(erfc);
DEFINE_BF16_FUNC_VIA_FLOAT(exp);
DEFINE_BF16_FUNC_VIA_FLOAT(exp2);
DEFINE_BF16_FUNC_VIA_FLOAT(expm1);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot);
DEFINE_BF16_FUNC_VIA_FLOAT(i0);
DEFINE_BF16_FUNC_VIA_FLOAT(i0e);
DEFINE_BF16_FUNC_VIA_FLOAT(digamma);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter);
DEFINE_BF16_FUNC_VIA_FLOAT(log);
DEFINE_BF16_FUNC_VIA_FLOAT(log2);
DEFINE_BF16_FUNC_VIA_FLOAT(log10);
DEFINE_BF16_FUNC_VIA_FLOAT(log1p);
DEFINE_BF16_FUNC_VIA_FLOAT(sin);
DEFINE_BF16_FUNC_VIA_FLOAT(sinh);
DEFINE_BF16_FUNC_VIA_FLOAT(cos);
DEFINE_BF16_FUNC_VIA_FLOAT(cosh);
DEFINE_BF16_FUNC_VIA_FLOAT(ceil);
DEFINE_BF16_FUNC_VIA_FLOAT(floor);
DEFINE_BF16_FUNC_VIA_FLOAT(round);
DEFINE_BF16_FUNC_VIA_FLOAT(tan);
DEFINE_BF16_FUNC_VIA_FLOAT(tanh);
DEFINE_BF16_FUNC_VIA_FLOAT(trunc);
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma);
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt);
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal);
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow);
DEFINE_BF16_FUNC_VIA_FLOAT(isnan)
DEFINE_BF16_FUNC_VIA_FLOAT(angle)
DEFINE_BF16_FUNC_VIA_FLOAT(acos)
DEFINE_BF16_FUNC_VIA_FLOAT(acosh)
DEFINE_BF16_FUNC_VIA_FLOAT(asin)
DEFINE_BF16_FUNC_VIA_FLOAT(atan)
DEFINE_BF16_FUNC_VIA_FLOAT(atanh)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign)
DEFINE_BF16_FUNC_VIA_FLOAT(erf)
DEFINE_BF16_FUNC_VIA_FLOAT(erfc)
DEFINE_BF16_FUNC_VIA_FLOAT(exp)
DEFINE_BF16_FUNC_VIA_FLOAT(exp2)
DEFINE_BF16_FUNC_VIA_FLOAT(expm1)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot)
DEFINE_BF16_FUNC_VIA_FLOAT(i0)
DEFINE_BF16_FUNC_VIA_FLOAT(i0e)
DEFINE_BF16_FUNC_VIA_FLOAT(digamma)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter)
DEFINE_BF16_FUNC_VIA_FLOAT(log)
DEFINE_BF16_FUNC_VIA_FLOAT(log2)
DEFINE_BF16_FUNC_VIA_FLOAT(log10)
DEFINE_BF16_FUNC_VIA_FLOAT(log1p)
DEFINE_BF16_FUNC_VIA_FLOAT(sin)
DEFINE_BF16_FUNC_VIA_FLOAT(sinh)
DEFINE_BF16_FUNC_VIA_FLOAT(cos)
DEFINE_BF16_FUNC_VIA_FLOAT(cosh)
DEFINE_BF16_FUNC_VIA_FLOAT(ceil)
DEFINE_BF16_FUNC_VIA_FLOAT(floor)
DEFINE_BF16_FUNC_VIA_FLOAT(round)
DEFINE_BF16_FUNC_VIA_FLOAT(tan)
DEFINE_BF16_FUNC_VIA_FLOAT(tanh)
DEFINE_BF16_FUNC_VIA_FLOAT(trunc)
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma)
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt)
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal)
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow)
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(
const Vectorized<BFloat16>& other) const {

View File

@ -191,22 +191,37 @@ inline void convert(const at::Half* src, bool* dst, int64_t n) {
}
#endif
#ifdef __ARM_FEATURE_BF16
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
CONVERT_TEMPLATE(bfloat16_t, int8_t)
CONVERT_TEMPLATE(bfloat16_t, int16_t)
CONVERT_TEMPLATE(bfloat16_t, int32_t)
CONVERT_TEMPLATE(bfloat16_t, int64_t)
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
CONVERT_TEMPLATE(bfloat16_t, float)
CONVERT_TEMPLATE(bfloat16_t, double)
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
CONVERT_TEMPLATE(int8_t, bfloat16_t)
CONVERT_TEMPLATE(int16_t, bfloat16_t)
CONVERT_TEMPLATE(int32_t, bfloat16_t)
CONVERT_TEMPLATE(int64_t, bfloat16_t)
CONVERT_TEMPLATE(float, bfloat16_t)
CONVERT_TEMPLATE(double, bfloat16_t)
template <typename to_type>
inline void convertFromBf16Impl(
const c10::BFloat16* __restrict src,
to_type* __restrict dst,
int64_t n) {
const uint16_t* srcPtr = reinterpret_cast<const uint16_t*>(src);
uint64_t len = static_cast<uint64_t>(n);
for (uint64_t i = 0; i < len; i++) {
uint32_t tmp = static_cast<uint32_t>(srcPtr[i]) << 16;
float tmpF;
__builtin_memcpy(&tmpF, &tmp, sizeof(float));
dst[i] = static_cast<to_type>(tmpF);
}
}
#define CONVERT_FROM_BF16_TEMPLATE(to_type) \
template <> \
inline void convert(const c10::BFloat16* src, to_type* dst, int64_t n) { \
return convertFromBf16Impl<to_type>(src, dst, n); \
}
CONVERT_FROM_BF16_TEMPLATE(uint8_t)
CONVERT_FROM_BF16_TEMPLATE(int8_t)
CONVERT_FROM_BF16_TEMPLATE(int16_t)
CONVERT_FROM_BF16_TEMPLATE(int32_t)
CONVERT_FROM_BF16_TEMPLATE(int64_t)
CONVERT_FROM_BF16_TEMPLATE(float)
CONVERT_FROM_BF16_TEMPLATE(double)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
CONVERT_FROM_BF16_TEMPLATE(float16_t)
#endif
inline void convertBoolToBfloat16Impl(
const bool* __restrict src,
@ -247,8 +262,6 @@ inline void convert(const c10::BFloat16* src, bool* dst, int64_t n) {
#endif
#endif
template <typename src_t>
struct VecConvert<
float,

View File

@ -388,6 +388,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
#ifndef USE_ROCM
at::Half halpha;
at::Half hbeta;
uint32_t mask = -1;
#endif
void * alpha_ptr = &alpha;
void * beta_ptr = &beta;
@ -427,7 +428,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
if (fp16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
uint32_t mask =
mask =
fp16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@ -444,7 +445,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
if (bf16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
uint32_t mask =
mask =
bf16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@ -511,17 +512,41 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
// on Blackwell+, we fake a n > 1 matmul when querying heuristics
// to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance
#ifndef USE_ROCM
const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10;
#else
const bool lie_to_cublaslt = false;
#endif
if (lie_to_cublaslt) {
CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T);
CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc);
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
FakeBdesc.descriptor(),
FakeCdesc.descriptor(),
FakeCdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
} else {
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
}
if (returnedResult == 0) {
cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
}

View File

@ -55,6 +55,14 @@ struct numeric_limits<int8_t> {
static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
};
template <>
struct numeric_limits<uint16_t> {
static inline __host__ __device__ uint16_t lowest() { return 0; }
static inline __host__ __device__ uint16_t max() { return UINT16_MAX; }
static inline __host__ __device__ uint16_t lower_bound() { return 0; }
static inline __host__ __device__ uint16_t upper_bound() { return UINT16_MAX; }
};
template <>
struct numeric_limits<int16_t> {
static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
@ -63,6 +71,14 @@ struct numeric_limits<int16_t> {
static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
};
template <>
struct numeric_limits<uint32_t> {
static inline __host__ __device__ uint32_t lowest() { return 0; }
static inline __host__ __device__ uint32_t max() { return UINT32_MAX; }
static inline __host__ __device__ uint32_t lower_bound() { return 0; }
static inline __host__ __device__ uint32_t upper_bound() { return UINT32_MAX; }
};
template <>
struct numeric_limits<int32_t> {
static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
@ -71,6 +87,21 @@ struct numeric_limits<int32_t> {
static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
};
template <>
struct numeric_limits<uint64_t> {
#ifdef _MSC_VER
static inline __host__ __device__ uint64_t lowest() { return 0; }
static inline __host__ __device__ uint64_t max() { return _UI64_MAX; }
static inline __host__ __device__ uint64_t lower_bound() { return 0; }
static inline __host__ __device__ uint64_t upper_bound() { return _UI64_MAX; }
#else
static inline __host__ __device__ uint64_t lowest() { return 0; }
static inline __host__ __device__ uint64_t max() { return UINT64_MAX; }
static inline __host__ __device__ uint64_t lower_bound() { return 0; }
static inline __host__ __device__ uint64_t upper_bound() { return UINT64_MAX; }
#endif
};
template <>
struct numeric_limits<int64_t> {
#ifdef _MSC_VER

View File

@ -24,7 +24,13 @@ namespace detail {
// radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting
// it as an opaque type.
// We use native integer types for 1/2/4/8-byte values to reduce
// register usage in CUDA kernels. For sizes > 8 fall back to char array.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
template<typename key_t, int value_size>
void radix_sort_pairs_impl(

View File

@ -1009,12 +1009,25 @@ static Device correct_out_device(const Tensor& self, const Tensor& other) {
}
}
static Tensor send_to_meta(const Tensor& self, const Device& device) {
Tensor out_meta;
if (self._is_zerotensor() && self.unsafeGetTensorImpl()->is_wrapped_number()) {
out_meta = at::_efficientzerotensor(self.sizes(), self.options().device(device));
out_meta.unsafeGetTensorImpl()->set_wrapped_number(true);
} else {
out_meta = self.to(device);
}
return out_meta;
}
Tensor mul_zerotensor(const Tensor& self, const Tensor& other) {
auto out_device = correct_out_device(self, other);
// hack to use the TensorIterator to get the correct broadcasting and type promotion logic
auto device_ = Device(DeviceType::Meta);
constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_));
auto self_meta = send_to_meta(self, device_);
auto other_meta = send_to_meta(other, device_);
auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self_meta, other_meta);
return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
}
@ -1023,7 +1036,9 @@ Tensor div_zerotensor(const Tensor& self, const Tensor& other) {
// hack to use the TensorIterator to get the correct broadcasting and type promotion logic
auto device_ = Device(DeviceType::Meta);
constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_));
auto self_meta = send_to_meta(self, device_);
auto other_meta = send_to_meta(other, device_);
auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self_meta, other_meta);
if (self._is_zerotensor()) {
if (other._is_zerotensor()) {
@ -1052,8 +1067,9 @@ static Tensor maybe_add_maybe_sub(const Tensor& self, const Tensor& other, const
// hack to use the TensorIterator to get the correct broadcasting and type promotion logic
auto device_ = Device(DeviceType::Meta);
constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
auto meta_out = at::_ops::add_Tensor::redispatch(
meta_dks, self.to(device_), other.to(device_), alpha);
auto self_meta = send_to_meta(self, device_);
auto other_meta = send_to_meta(other, device_);
auto meta_out = at::_ops::add_Tensor::redispatch(meta_dks, self_meta, other_meta, alpha);
auto get_out_like = [&] (const Tensor& tensor)
{

View File

@ -1,5 +1,6 @@
#include <ATen/core/ATen_fwd.h>
#include <c10/core/ScalarType.h>
#include <c10/core/SymInt.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
@ -1710,11 +1711,37 @@ Tensor narrow_symint(
"], but got ",
start,
")")
if (start < 0) {
start = start + cur_size;
auto cond1 = TORCH_GUARD_OR_FALSE(start.sym_lt(0));
auto cond2 = TORCH_GUARD_OR_FALSE(start.sym_ge(0));
if (cond1 || cond2) {
if (cond1) {
start = start + cur_size;
}
TORCH_SYM_CHECK(
start.sym_le(cur_size - length),
"start (",
start,
") + length (",
length,
") exceeds dimension size (",
cur_size,
").");
return at::slice_symint(self, dim, start, start + length, 1);
}
// Unbacked start handling!
// Bounds check without converting start:
// - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
// length <= 0
// - If start >= 0: need start + length <= cur_size
auto end = start + length;
TORCH_SYM_CHECK(
start.sym_le(cur_size - length),
(start.sym_lt(0).sym_and((end).sym_le(0)))
.sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
"start (",
start,
") + length (",
@ -1722,7 +1749,28 @@ Tensor narrow_symint(
") exceeds dimension size (",
cur_size,
").");
return at::slice_symint(self, dim, start, start + length, 1);
if (TORCH_GUARD_OR_FALSE(end.sym_ne(0))) {
return at::slice_symint(self, dim, start, end, 1);
} else {
// Cannot statically determine the condition due to unbacked.
// This is an interesting situation; when start is negative and
// start + length == 0, slice and narrow do different things.
// i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to
// pass curr_size instead of 0. Otherwise, they would do the same thing.
// This says at runtime: if start < 0 and end == 0, then pass curr_size
// instead of 0.
auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
auto result =
at::slice_symint(self, dim, start, end + use_different * cur_size, 1);
// Ensure slice allocated unbacked size is specialized to length.
SymInt new_size = result.sym_size(dim);
TORCH_SYM_CHECK(new_size.sym_eq(length), "")
return result;
}
}
// This overload exists purely for XLA, because they wanted to pass in
@ -1736,8 +1784,8 @@ Tensor narrow_tensor_symint(
start.dim() == 0 &&
isIntegralType(start.scalar_type(), /*includeBool=*/false),
"start must be an 0-dim integral Tensor.");
int64_t st = start.item<int64_t>();
return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length));
c10::SymInt st = start.item().toSymInt();
return at::narrow_symint(self, dim, std::move(st), std::move(length));
}
std::

View File

@ -293,7 +293,7 @@ struct ComputeLocationBase<scalar_t, /*align_corners=*/false> {
, empty(size <= 0) {}
inline Vec unnormalize(const Vec &in) const {
return (in + Vec(1)) * Vec(scaling_factor) - Vec(0.5);
return (in + Vec(static_cast<scalar_t>(1))) * Vec(scaling_factor) - Vec(static_cast<scalar_t>(0.5));
}
inline Vec clip_coordinates(const Vec &in) const {
@ -831,7 +831,7 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bicubic,
// constant used in cubic convolution
// could be -0.5 or -0.75, use the same value in UpSampleBicubic2d.h
const Vec A = Vec(-0.75);
const Vec A = Vec(static_cast<scalar_t>(-0.75));
ApplyGridSample(const TensorAccessor<const scalar_t, 4>& input)
: inp_H(input.size(2))

View File

@ -92,7 +92,8 @@ void addcdiv_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) {
void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, double beta) {
ScalarType dtype = iter.dtype(0);
if (dtype == kBFloat16) {
if (at::isReducedFloatingType(dtype)) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "smooth_l1_backward_cpu_out", [&]() {
auto norm_val = norm.to<float>();
float beta_val(beta);
auto norm_val_vec = Vectorized<float>(norm_val);
@ -101,9 +102,9 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
const auto zero_vec = Vectorized<float>(0);
const auto pos_1_vec = Vectorized<float>(1);
cpu_kernel_vec(iter,
[=](BFloat16 input, BFloat16 target, BFloat16 grad_output) -> BFloat16 {
[=](scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
const auto x = float(input) - float(target);
if (x <= -beta){
if (x <= -beta) {
return -norm_val * float(grad_output);
}else if (x >= beta){
return norm_val * float(grad_output);
@ -112,14 +113,14 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
}
},
[norm_val_vec, beta_val_vec, neg_1_vec, zero_vec, pos_1_vec](
Vectorized<BFloat16> input, Vectorized<BFloat16> target, Vectorized<BFloat16> grad_output) -> Vectorized<BFloat16> {
Vectorized<scalar_t> input, Vectorized<scalar_t> target, Vectorized<scalar_t> grad_output) -> Vectorized<scalar_t> {
// using two blendv calls to simulate the 3 cases
// 1 if x >= beta
// -1 if x <= -beta
// x / beta if |x| < beta
auto [input0, input1] = convert_bfloat16_float(input);
auto [target0, target1] = convert_bfloat16_float(target);
auto [grad_output0, grad_output1] = convert_bfloat16_float(grad_output);
auto [input0, input1] = convert_to_float(input);
auto [target0, target1] = convert_to_float(target);
auto [grad_output0, grad_output1] = convert_to_float(grad_output);
auto x = input0 - target0;
auto pos_or_neg_1_vec = Vectorized<float>::blendv(
neg_1_vec, pos_1_vec, x > zero_vec);
@ -135,9 +136,10 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
output = Vectorized<float>::blendv(
x / beta_val_vec, pos_or_neg_1_vec, x_abs >= beta_val_vec);
input1 = norm_val_vec * output * grad_output1;
return convert_float_bfloat16(input0, input1);
return convert_from_float<scalar_t>(input0, input1);
}
);
});
} else {
AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] {
auto norm_val = norm.to<scalar_t>();

View File

@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
});
}
template <typename func_t, typename vec_func_t>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
template <typename func_t, typename vec_func_t, typename ident_t = double>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast<ident_t>(0)) {
using traits = binary_function_traits<func_t>;
static_assert(
all_same<

View File

@ -5,6 +5,7 @@
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/OpMathType.h>
@ -78,12 +79,12 @@ void min_all_kernel_impl(Tensor& result, const Tensor& input) {
reduce_all_impl<int64_t>(result, input, upper_bound<int64_t>(),
[=](int64_t a, int64_t b) -> int64_t { return min_impl(a, b); });
} else {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "min_all", [&] {
AT_DISPATCH_V2(input.scalar_type(), "min_all", AT_WRAP([&] {
using Vec = Vectorized<opmath_type<scalar_t>>;
reduce_all_impl_vec<scalar_t>(result, input, upper_bound<scalar_t>(),
[=] (scalar_t a , scalar_t b) -> scalar_t { return min_impl(a, b); },
[=](Vec a, Vec b) -> Vec { return minimum(a, b); });
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
}
}
@ -103,12 +104,12 @@ void max_all_kernel_impl(Tensor& result, const Tensor& input) {
reduce_all_impl<int64_t>(result, input, lower_bound<int64_t>(),
[=](int64_t a, int64_t b) -> int64_t { return max_impl(a, b); });
} else {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "max_all", [&] {
AT_DISPATCH_V2(input.scalar_type(), "max_all", AT_WRAP([&] {
using Vec = Vectorized<opmath_type<scalar_t>>;
reduce_all_impl_vec<scalar_t>(result, input, lower_bound<scalar_t>(),
[=] (scalar_t a , scalar_t b) -> scalar_t { return max_impl(a, b); },
[=](Vec a, Vec b) -> Vec { return maximum(a, b); });
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
}
}
@ -199,7 +200,7 @@ void aminmax_allreduce_kernel(
}
);
} else {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] {
AT_DISPATCH_V2(input.scalar_type(), "aminmax_cpu", AT_WRAP([&] {
using Vec = Vectorized<opmath_type<scalar_t>>;
using scalar_t_pair = std::pair<scalar_t, scalar_t>;
reduce_all_impl_vec_two_outputs<scalar_t>(
@ -214,7 +215,7 @@ void aminmax_allreduce_kernel(
[=](Vec a, Vec b) -> Vec { return minimum(a, b); },
[=](Vec a, Vec b) -> Vec { return maximum(a, b); }
);
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
}
}

View File

@ -3,6 +3,7 @@
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/OpMathType.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
@ -338,43 +339,24 @@ void or_kernel_impl(TensorIterator& iter) {
}
}
template<typename scalar_t>
struct MinValuesOps: public at::native::MinOps<scalar_t> {
using arg_t = typename MinOps<scalar_t>::arg_t;
static scalar_t project(arg_t arg) {
return arg.first;
}
};
void min_values_kernel_impl(TensorIterator& iter) {
if (iter.dtype() == kLong) {
// This case is special because of Vectorized<int64_t> does not
// handle upper_bound<int64_t>().
// See: https://github.com/pytorch/pytorch/issues/43254
using scalar_t = int64_t;
binary_kernel_reduce(
iter,
MinValuesOps<scalar_t>{},
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
return;
}
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cpu", [&iter] {
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
static_cast<double>(upper_bound<scalar_t>()));
});
upper_bound<scalar_t>());
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}
void max_values_kernel_impl(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cpu", [&iter] {
AT_DISPATCH_V2(iter.dtype(), "max_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t { return max_impl(a, b); },
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return maximum(a, b); },
lower_bound<scalar_t>());
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}
void argmax_kernel_impl(TensorIterator &iter) {

View File

@ -11,6 +11,7 @@
#include <vector>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/Parallel.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorIterator.h>
@ -106,7 +107,7 @@ void min_kernel_impl(
bool keepdim) {
int64_t self_dim_size = ensure_nonempty_size(self, dim);
AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "min_cpu", [&] {
AT_DISPATCH_V2(self.scalar_type(), "min_cpu", AT_WRAP([&] {
compare_base_kernel<scalar_t>(result, indice, self, dim, keepdim, [&] (
scalar_t* result_data, int64_t* indice_data,
const scalar_t* self_data, auto self_dim_stride) {
@ -128,7 +129,7 @@ void min_kernel_impl(
*indice_data = index;
}
);
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool);
}
void max_kernel_impl(
@ -139,7 +140,7 @@ void max_kernel_impl(
bool keepdim) {
int64_t self_dim_size = ensure_nonempty_size(self, dim);
AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "max_cpu", [&] {
AT_DISPATCH_V2(self.scalar_type(), "max_cpu", AT_WRAP([&] {
compare_base_kernel<scalar_t>(result, indice, self, dim, keepdim, [&] (
scalar_t* result_data, int64_t* indice_data,
const scalar_t* self_data, auto self_dim_stride) {
@ -161,7 +162,7 @@ void max_kernel_impl(
*indice_data = index;
}
);
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool);
}
void aminmax_kernel(
@ -186,7 +187,7 @@ void aminmax_kernel(
return;
}
AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "aminmax_cpu", [&] {
AT_DISPATCH_V2(self.scalar_type(), "aminmax_cpu", AT_WRAP([&] {
compare_base_kernel<scalar_t, scalar_t>(min_result, max_result, self, wrap_dim, keepdim, [&] (
scalar_t* min_result_data, scalar_t* max_result_data,
const scalar_t* self_data, auto self_dim_stride) {
@ -209,7 +210,7 @@ void aminmax_kernel(
*max_result_data = max_number;
}
);
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half);
}
void where_kernel_impl(TensorIterator &iter) {

View File

@ -22,6 +22,9 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#ifdef USE_ROCM
#include <ATen/native/hip/ck_group_gemm.h>
#endif
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
@ -666,12 +669,19 @@ std::optional<c10::ScalarType> out_dtype) {
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
// fast path, no d2h sync needed
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}

View File

@ -5,7 +5,6 @@
#include <array>
#include <type_traits>
#include <ATen/core/TensorBase.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/cuda/CUDAContext.h>
@ -74,7 +73,6 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
char* const out_ptr = static_cast<char*>(iter.data_ptr(0));
char* const in_ptr = static_cast<char*>(iter.data_ptr(1));
if (is_gather_like && num_indices==1) {
const size_t element_size = iter.element_size(0);
constexpr size_t alignment = 16;
@ -84,16 +82,9 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
auto ind_dim_size = index_size[0];
auto inp_stride_bytes = index_stride[0];
auto out_stride_bytes = iter.strides(0)[1];
// avoid grid overflow in the fast kernel
const int64_t vec_chunks = ceil_div(slice_size, alignment);
const int64_t blocks_per_slice_upper = ceil_div(vec_chunks, (int64_t)launch_size_nd);
const int max_grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
// if it's an eligible grid we use the fast path, otherwise default to slower path
if (blocks_per_slice_upper <= max_grid_y) {
at::native::vectorized_gather_kernel_launch<alignment, int64_t>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
return;
}
at::native::vectorized_gather_kernel_launch<alignment, int64_t>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
return;
}
}

View File

@ -13,11 +13,12 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx,
if (allow_neg_indices) {
ind = (ind < 0) ? ind + ind_dim_size : ind;
}
CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds", "Expected 0 <= index < ind_dim_size(%ld), but got index = %ld", ind_dim_size, ind);
int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits
if (off >= slice_size) return;
auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
at::native::memory::st_vec<Alignment>(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits
CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds");
// off is guaranteed to be within int32 limits
for (int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; off < slice_size; off += blockDim.x * gridDim.y * Alignment) {
auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
at::native::memory::st_vec<Alignment>(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits
}
}
@ -30,7 +31,9 @@ void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int
auto num_threads = at::round_up(
at::ceil_div(slice_size_in_bytes, Alignment),
static_cast<int64_t>(C10_WARP_SIZE));
dim3 grid = {static_cast<uint32_t>(num_ind), static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1};
uint32_t grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
grid_y = std::min(static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), grid_y);
dim3 grid = {static_cast<uint32_t>(num_ind), grid_y, 1};
auto block = std::min(max_num_threads, num_threads);
vectorized_gather_kernel<Alignment, index_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices);

View File

@ -1,5 +1,6 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReduceAllOps.h>
@ -28,22 +29,22 @@ void _min_max_values_kernel_cuda_impl(TensorIterator& iter) {
}
void aminmax_allreduce_launch_kernel(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND3(
kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_all_cuda", [&] {
AT_DISPATCH_V2(
iter.input_dtype(), "aminmax_all_cuda", AT_WRAP([&] {
_min_max_values_kernel_cuda_impl<scalar_t>(iter);
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}
void aminmax_launch_kernel(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND3(
kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_cuda", [&]() {
AT_DISPATCH_V2(
iter.input_dtype(), "aminmax_cuda", AT_WRAP([&]() {
gpu_reduce_kernel<scalar_t, scalar_t>(
iter,
MinMaxOps<scalar_t, scalar_t, int32_t>{},
thrust::pair<scalar_t, scalar_t>(
at::numeric_limits<scalar_t>::upper_bound(),
at::numeric_limits<scalar_t>::lower_bound()));
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}
} // namespace at::native

View File

@ -1,5 +1,6 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReduceAllOps.h>
@ -33,27 +34,27 @@ void max_values_kernel_cuda_impl(TensorIterator& iter) {
}
void max_values_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND3(
kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cuda", [&]() {
AT_DISPATCH_V2(
iter.dtype(), "max_values_cuda", AT_WRAP([&]() {
max_values_kernel_cuda_impl<scalar_t>(iter);
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}
void max_launch_kernel(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND3(
kBFloat16, kHalf, kBool, iter.input_dtype(), "max_cuda", [&]() {
AT_DISPATCH_V2(
iter.input_dtype(), "max_cuda", AT_WRAP([&]() {
gpu_reduce_kernel<scalar_t, scalar_t>(
iter,
MaxOps<scalar_t>{},
thrust::pair<scalar_t, int64_t>(
at::numeric_limits<scalar_t>::lower_bound(), 0));
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}
void max_all_launch_kernel(TensorIterator &iter) {
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_all_cuda", [&] {
AT_DISPATCH_V2(iter.input_dtype(), "max_all_cuda", AT_WRAP([&] {
max_values_kernel_cuda_impl<scalar_t>(iter);
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}
REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda)

View File

@ -12,6 +12,7 @@
#include <ATen/NumericUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/NumericUtils.h>
#include <ATen/cuda/NumericLimits.cuh>
@ -33,24 +34,24 @@ void min_values_kernel_cuda_impl(TensorIterator& iter) {
}
void min_values_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cuda", [&]() {
AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() {
min_values_kernel_cuda_impl<scalar_t>(iter);
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}
void min_launch_kernel(TensorIterator &iter) {
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_cuda", [&]() {
AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() {
gpu_reduce_kernel<scalar_t, scalar_t>(
iter,
MinOps<scalar_t>{},
thrust::pair<scalar_t, int64_t>(at::numeric_limits<scalar_t>::upper_bound(), 0));
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}
void min_all_launch_kernel(TensorIterator &iter) {
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_all_cuda", [&] {
AT_DISPATCH_V2(iter.input_dtype(), "min_all_cuda", AT_WRAP([&] {
min_values_kernel_cuda_impl<scalar_t>(iter);
});
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}
REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda)

View File

@ -59,6 +59,24 @@
// forward declare
class cublasCommonArgs;
#ifndef _WIN32
namespace fbgemm_gpu {
// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
// To update supported ops means a submodule bump, which is.. painful. Instead, we
// can simply forward-declare the methods we want to use.. Works at least as a short-term
// thing, but should still be fixed somewhere/somehow.
at::Tensor f4f4bf16(
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
std::optional<at::Tensor>,
bool use_mx);
} // namespace fbgemm_gpu
#endif
using at::blas::ScalingType;
using at::blas::SwizzleType;
@ -767,33 +785,6 @@ _scaled_rowwise_rowwise(
return out;
}
// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1,
// and strides become somewhat meaningless
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
if (scale_type == ScalingType::BlockWise1x128) {
TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
"at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
auto expected_size = ceil_div<int64_t>(t.size(1), 128);
TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
"at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
} else if (scale_type == ScalingType::BlockWise128x128) {
TORCH_CHECK_VALUE(check_size_stride(
scale,
0,
ceil_div<int64_t>(t.size(0), 128),
ceil_div<int64_t>(t.size(1), 128)),
"at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
TORCH_CHECK(check_size_stride(
scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
"at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), "elements and stride(1) ", 1, "if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
}
}
void
_check_deepseek_support() {
#ifndef USE_ROCM
@ -806,7 +797,7 @@ _check_deepseek_support() {
}
// Only in cublasLt >= 12.9
TORCH_CHECK_NOT_IMPLEMENTED(
CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900,
CUBLAS_VERSION >= 120900 && cublasLtGetVersion() >= 120900,
"DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9"
);
#endif
@ -823,23 +814,61 @@ _scaled_block1x128_block1x128(
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
// CUDA: Only Hopper GPUs
// As: [M x K // 128], stride: [1, M]
// Bs: [N x K // 128], stride: [1, N]
_check_deepseek_support();
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
// check types
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) &&
isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
);
const int64_t M = mat_a.sizes()[0];
const int64_t K = mat_a.sizes()[1];
const int64_t N = mat_b.sizes()[1];
// scale_a shape
TORCH_CHECK_VALUE(
scale_a.size(0) == M &&
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
scale_a.scalar_type() == kFloat,
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
);
// scale_a stride
TORCH_CHECK_VALUE(
scale_a.stride(0) == 1 &&
(
scale_a.stride(1) == M ||
(scale_a.size(1) == 1 && scale_b.stride(1) == 1)
),
"scale_a strides must be (", 1, ", ", M, "); got: ", scale_a.strides()
);
// scale_b shape
TORCH_CHECK_VALUE(
scale_b.size(0) == N &&
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
scale_b.scalar_type() == kFloat,
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
);
// scale_b stride
TORCH_CHECK_VALUE(
scale_b.stride(0) == 1 &&
(
scale_b.stride(1) == N ||
(
scale_b.size(1) == 1 &&
scale_b.stride(1) == 1
)
),
"scale_b strides must be (", 1, ", ", N, "); got: ", scale_a.strides()
);
auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
@ -861,24 +890,65 @@ _scaled_block128x128_block1x128(
Tensor& out) {
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
// CUDA: Only Hopper GPUs
_check_deepseek_support();
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", ceil_div<int64_t>(mat_a.sizes()[0], 128), " x ", ceil_div<int64_t>(mat_a.sizes()[1], 128), " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
// A: [M, K], B: [K, N] are FP8, scales are fp32
// As: [round_up(K // 128, 4), M // 128], stride: [M // 128, 1]
// Bs: [N x K // 128], stride: [1, N]
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) &&
isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
);
const int64_t M = mat_a.sizes()[0];
const int64_t K = mat_a.sizes()[1];
const int64_t N = mat_b.sizes()[1];
// scale_a shape
TORCH_CHECK_VALUE(
scale_a.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
scale_a.size(1) == ceil_div<int64_t>(M, 128) &&
scale_a.scalar_type() == kFloat,
"scale_a must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ",
ceil_div<int64_t>(M, 128), " Float elements, got ", scale_a.sizes()
);
// scale_a stride
TORCH_CHECK_VALUE(
scale_a.stride(0) == 1 &&
(
scale_a.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
(
scale_a.size(1) == 1 &&
scale_a.stride(1) == 1
)
),
"scale_a must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
);
// scale_b shape
TORCH_CHECK_VALUE(
scale_b.size(0) == N &&
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
scale_b.scalar_type() == kFloat,
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
);
// scale_b stride
TORCH_CHECK_VALUE(
scale_b.stride(0) == 1 &&
(
scale_b.stride(1) == N ||
(
scale_b.size(1) == 1 &&
scale_b.stride(1) == 1
)
),
"scale_b must have strides (1, ", N, "); got ", scale_b.strides()
);
auto scaling_choice_a = ScalingType::BlockWise128x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
@ -900,24 +970,62 @@ _scaled_block1x128_block128x128(
Tensor& out) {
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
// CUDA: Only Hopper GPUs
_check_deepseek_support();
// A: [M, K], B: [K, N] are FP8, scales are fp32
// As: [M x K // 128], stride: [1, M]
// Bs: [round_up(K // 128, 4) x N // 128], stride: [1, N // 128]
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) &&
isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
);
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())
int64_t M = mat_a.size(0);
int64_t K = mat_a.size(1);
int64_t N = mat_b.size(1);
// scale_a shape
TORCH_CHECK_VALUE(
scale_a.size(0) == M &&
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
scale_a.scalar_type() == kFloat,
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
);
// scale_a stride
TORCH_CHECK_VALUE(
scale_a.stride(0) == 1 &&
(
scale_a.stride(1) == M ||
(
scale_a.size(1) == 1 &&
scale_a.stride(1) == 1
)
),
"scale_a must have strides (1, ", M, "); got ", scale_b.strides()
);
// scale_b shape
TORCH_CHECK_VALUE(
scale_b.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
scale_b.size(1) == ceil_div<int64_t>(N, 128) &&
scale_b.scalar_type() == kFloat,
"scale_b must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ", ceil_div<int64_t>(N, 128), " Float elements, got ", scale_b.sizes()
);
// scale_b stride
TORCH_CHECK_VALUE(
scale_b.stride(0) == 1 &&
(
scale_b.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
(
scale_b.size(1) == 1 &&
scale_b.stride(1) == 1
)
),
"scale_b must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
);
auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise128x128;
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
@ -997,26 +1105,47 @@ _scaled_mxfp4_mxfp4(
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
Tensor& out) {
#ifndef USE_ROCM
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
#endif
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
#else
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
// Packed FP4 format means actual-K = 2 * reported-K -- adjust
auto K_multiplier = 2;
#ifdef USE_ROCM
// AMD
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
#else
// NVIDIA
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
#endif
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
#ifdef USE_ROCM
// AMD
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
#else
// NVIDIA
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
#endif
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
"For Blockwise scaling both scales should be contiguous");
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
#ifdef USE_ROCM
// AMD
auto scaling_choice_a = ScalingType::BlockWise1x32;
auto scaling_choice_b = ScalingType::BlockWise1x32;
@ -1031,11 +1160,30 @@ _scaled_mxfp4_mxfp4(
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
#else
// NVIDIA
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
// but we have one we need to use. Two clear options are to copy into
// our output (slow), or use a move-assignment-operator (faster).
// However, the compiler can complain about the explicit move preventing
// copy elision because the return from f4f4bf16 is a temporary object.
// So we don't explicitly move, and trust the compiler here...
// In the longer term this should be fixed on the FBGemm side.
out = fbgemm_gpu::f4f4bf16(
mat_a,
mat_b.transpose(-2, -1),
scale_a,
scale_b,
std::nullopt, /* global_scale */
true /* use_mx */
);
return out;
#endif
#endif
}
Tensor&
@ -1160,17 +1308,20 @@ _scaled_mm_cuda_v2_out(
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
}
// Handle fp4 packed-K dimension
int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;
TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
" but got ", bias->numel());
TORCH_CHECK_VALUE(
mat_a.sizes()[1] % 16 == 0,
K_multiplier * mat_a.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes()[0],
"x",
mat_a.sizes()[1],
K_multiplier * mat_a.sizes()[1],
").");
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
mat_b.sizes()[1], ") must be divisible by 16");
// TODO(slayton): Existing checks, not sure if they should really be here.

View File

@ -0,0 +1,19 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <optional>
namespace at {
namespace hip {
namespace detail {
void group_gemm_ck(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
at::Tensor& out);
} // namespace detail
} // namespace hip
} // namespace at

View File

@ -0,0 +1,462 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/hip/HIPContext.h>
#include <ATen/Tensor.h>
#include <ATen/TensorAccessor.h>
#include <c10/hip/HIPStream.h>
#include <iostream>
#include <vector>
#include <optional>
#include <type_traits>
#include <ck/ck.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/utility/tuple.hpp>
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
namespace at {
namespace hip {
namespace detail {
namespace CkTypes {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
}
template <typename ALayout, typename BLayout, typename DataType>
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
1, 1,
S<1,32,1,8>, 4
>;
template <typename ALayout, typename BLayout, typename DataType>
void launch_grouped_bgemm_ck_impl_dispatch(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
at::Tensor& out)
{
using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
using PassThrough = CkTypes::PassThrough;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a_ptrs, p_b_ptrs;
std::vector<void*> p_e_ptrs;
// Note: d_ptrs will be resized after we populate the other vectors
const int mat_a_dim = mat_a.dim();
const int mat_b_dim = mat_b.dim();
const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
const size_t a_element_size = mat_a.element_size();
const size_t b_element_size = mat_b.element_size();
const size_t out_element_size = out.element_size();
// for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
if (mat_a_dim == 2 && mat_b_dim == 2) {
// 2D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
const int M = mat_a.size(0); // number of rows in A
const int N = mat_b.size(1); // number of columns in B
const int K = mat_a.size(1); // columns in A == rows in B
// for 2d*2d input, output is 3d.
// for each group, A columns (K) are sliced. M and N dimensions are not sliced.
for (int i = 0; i < num_groups; ++i) {
int start_k = (i == 0) ? 0 : offs_accessor[i-1];
int end_k = offs_accessor[i];
int k = end_k - start_k;
//K dimension are sliced, hence select stride(1) always.
//K dimension is always dimension 1, regardless of memory layout (row/column major)
const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
const void* group_b_ptr;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
// Leading dimension = distance between rows = stride(0)
ldb = mat_b.stride(0);
} else {
// Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
// Leading dimension = distance between columns = stride(1)
ldb = mat_b.stride(1);
}
// Calculate output pointer for group i in 3D tensor [num_groups, M, N]
// stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [M,K]: leading dimension = distance between rows = stride(0)
lda = mat_a.stride(0);
} else {
// Column-major A [M,K]: leading dimension = distance between columns = stride(1)
lda = mat_a.stride(1);
}
// Output is always row-major in 3D tensor [num_groups, M, N]
// Leading dimension for each group's [M,N] slice = stride(1) = N
ldc = out.stride(1);
size_t output_group_bytes = M * N * out_element_size;
void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(k),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 2 && mat_b_dim == 3) {
// 2D*3D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 2d*3d input, output is 2d.
// A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
// Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
const int K = mat_a.size(1); // columns in A
// For 2D-3D case: The output determines N (result width)
const int N = out.size(1); // N is the width of the output tensor
for (int i = 0; i < num_groups; ++i) {
int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
int end_m = offs_accessor[i];
int m = end_m - start_m;
// Skip zero-sized groups but continue processing subsequent groups
if (m <= 0) {
continue;
}
// Select A rows for group i: skip start_m rows
const void* group_a_ptr;
int lda;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
lda = mat_a.stride(0); // distance between rows
} else {
// Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
// Detect stride pattern for A tensor to determine appropriate lda calculation
bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));
if (a_is_strided_tensor) {
// For strided A tensors: stride(0) gives the actual leading dimension
lda = mat_a.stride(0);
} else {
// For non-strided A tensors: use the M dimension (total rows)
lda = mat_a.size(0); // Total M dimension for column-major layout
}
}
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
} else {
// Detect stride pattern to determine appropriate ldb calculation
bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));
if (is_strided_tensor) {
// For strided tensors: stride(2) gives the actual leading dimension
ldb = mat_b.stride(2);
} else {
// For non-strided tensors: use the N dimension
ldb = mat_b.size(1);
}
}
// Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)
gemm_descs.push_back({
static_cast<ck::index_t>(m),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 3) {
// 3d*3d input, output is 3d - batched matrix multiplication
// A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
// Each batch is processed as a separate GEMM operation
const int batch_size = mat_a.size(0);
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)
// Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
int N;
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
N = mat_b.size(2);
} else if (mat_b.size(2) == K) {
// B is [batch, n, k] - transposed layout
N = mat_b.size(1);
} else {
TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
}
for (int i = 0; i < batch_size; ++i) {
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
// Select output batch for group i: Output[i, :, :]
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldb, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
} else {
// Column-major A: leading dimension = distance between columns = stride(2)
lda = mat_a.stride(2);
}
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B: leading dimension = distance between rows
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(1); // stride between K rows
} else {
// B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
}
} else {
// Column-major B: leading dimension = distance between columns
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(2); // stride between N columns
} else {
// B is [batch, n, k] - transposed layout
ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
}
}
// Output is typically row-major: leading dimension = distance between rows = stride(1)
ldc = out.stride(1);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 2) {
// 3D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 3d*2d input, output is 3d.
// A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
// Offset divides N dimension of B, each group gets different slice of B and different batch of A
const int batch_size = mat_a.size(0); // n_groups
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A
// For row-major A and B case: B should be [K, total_N]
const int total_N = mat_b.size(1); // B is [K, total_N] for row-major
for (int i = 0; i < num_groups; ++i) {
int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
int end_n = offs_accessor[i];
int n = end_n - start_n;
// Skip zero-sized groups but continue processing subsequent groups
if (n <= 0) {
continue;
}
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
const void* group_b_ptr;
int ldb;
// Check if B is row-major or column-major
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(0); // distance between rows (should be total_N)
} else {
// Column-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(1); // distance between columns (should be K)
}
// Select output slice for group i: Output[:, start_n:end_n]
void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;
int lda, ldc;
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
// Output is row-major: leading dimension = distance between rows = stride(0)
ldc = out.stride(0);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(n),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
}
TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");
// Initialize d_ptrs with the correct size
std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());
static DeviceOp gemm_instance;
auto argument = gemm_instance.MakeArgument(
p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
);
TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
"CK Group GEMM: argument unsupported (shape/strides/type config)");
size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);
void* gemm_arg_buf = nullptr;
void* ws_buf = nullptr;
hipMalloc(&gemm_arg_buf, arg_buf_size);
hipMalloc(&ws_buf, ws_size);
gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
gemm_instance.SetWorkSpacePointer(&argument, ws_buf);
auto invoker = gemm_instance.MakeInvoker();
hipStream_t stream = c10::hip::getCurrentHIPStream();
invoker.Run(argument, {stream});
hipFree(gemm_arg_buf);
hipFree(ws_buf);
}
void group_gemm_ck(
const at::Tensor& input_a,
const at::Tensor& input_b_colmajor,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& /*bias*/,
at::Tensor& out)
{
// Detect if input_a is row-major based on stride pattern
bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
// Ensure tensor A is row-major and contiguous if not already
at::Tensor mat_a = input_a;
if (!a_row_major) {
// If A is not row-major, make it contiguous (row-major)
mat_a = input_a.contiguous();
}
// Force tensor B to be column-major using double transpose trick
// This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
at::Tensor mat_b = input_b_colmajor;
if (!b_col_major) {
mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
}
// For 3D tensors, check the last dimension stride for row-major detection
a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);
if (mat_a.dtype() == at::kBFloat16) {
// bf16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kHalf) {
// fp16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kFloat) {
// fp32 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
}
}
} // namespace detail
} // namespace hip
} // namespace at

View File

@ -157,10 +157,10 @@ bool onednn_strides_check(const Tensor& src) {
return true;
dnnl_dims_t blocks = {0};
int perm[DNNL_MAX_NDIMS] = {0};
std::array<int, DNNL_MAX_NDIMS> perm = {0};
for (int d = 0; d < md_ndims; ++d) {
// no strides check needed for empty tensor
if (md_padded_dims[d] == nullptr)
if ((*md_padded_dims)[d] == 0)
return true;
// no strides verification for runtime dims
@ -178,14 +178,15 @@ bool onednn_strides_check(const Tensor& src) {
// A custom comparator to yield linear order on perm
auto idx_sorter = [&](const int a, const int b) -> bool {
if (strides[a] == strides[b] && md_padded_dims[a] == md_padded_dims[b])
if (strides[a] == strides[b] &&
(*md_padded_dims)[a] == (*md_padded_dims)[b])
return a < b;
else if (strides[a] == strides[b])
return md_padded_dims[a] < md_padded_dims[b];
return (*md_padded_dims)[a] < (*md_padded_dims)[b];
else
return strides[a] < strides[b];
};
std::sort(perm, perm + md_ndims, idx_sorter);
std::sort(perm.begin(), perm.begin() + md_ndims, idx_sorter);
auto min_stride = block_size;
for (int idx = 0; idx < md_ndims; ++idx) {
@ -199,9 +200,10 @@ bool onednn_strides_check(const Tensor& src) {
return false;
// update min_stride for next iteration
const auto padded_dim = *md_padded_dims[d];
const auto padded_dim = (*md_padded_dims)[d];
min_stride = block_size * strides[d] * (padded_dim / blocks[d]);
}
return true;
}

View File

@ -212,17 +212,12 @@ static Tensor& bce_loss_out_impl(const Tensor& input,
loss.resize_((reduction == Reduction::None || grad_output.defined()) ? target.sizes() : IntArrayRef({}));
TORCH_CHECK(loss.is_mps());
Tensor loss_squeezed = loss.squeeze();
Tensor input_squeezed = input.squeeze();
Tensor target_squeezed = target.squeeze();
@autoreleasepool {
std::string key =
op_name + reductionToString(reduction) + getTensorsStringKey({input_squeezed, target_squeezed, weight});
std::string key = op_name + reductionToString(reduction) + getTensorsStringKey({input, target, weight});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_squeezed);
newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target_squeezed);
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
MPSGraphTensor* bceLossUnweighted = nil;
// if grad_output is defined, then it's a backward pass
@ -252,12 +247,12 @@ static Tensor& bce_loss_out_impl(const Tensor& input,
newCachedGraph->gradInputTensor = bceLoss;
}
} else {
newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input_squeezed.sizes().size());
newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input.sizes().size());
}
});
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_squeezed);
Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target_squeezed);
Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss_squeezed);
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target);
Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss);
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
@ -370,7 +365,7 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg,
onValue:-1.0f
offValue:0.0f
name:nil];
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, inputTensor.dataType);
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, [inputTensor dataType]);
if (isWeightsArrayValid) {
oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor
secondaryTensor:weightTensor
@ -705,6 +700,7 @@ static void smooth_l1_loss_template(const Tensor& input,
TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.");
TORCH_CHECK(input.is_mps());
TORCH_CHECK(target.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64");
if ((input.numel() == 0) || (target.numel() == 0)) {
reduction == Reduction::Mean ? output.fill_(std::numeric_limits<float>::quiet_NaN()) : output.zero_();
return;
@ -771,7 +767,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:MPSDataTypeFloat32];
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:[inputTensor dataType]];
// xn - yn
MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor
secondaryTensor:targetTensor
@ -797,7 +793,8 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
name:@"lossTensor"];
MPSGraphTensor* outputTensor = lossTensor;
if (reduction == Reduction::Mean) {
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel() dataType:MPSDataTypeFloat32];
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel()
dataType:[lossTensor dataType]];
outputTensor = [mpsGraph divisionWithPrimaryTensor:lossTensor secondaryTensor:numelTensor name:nil];
}
MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor

View File

@ -1,7 +1,7 @@
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
load("//tools/build_defs:fb_xplat_cxx_test.bzl", "fb_xplat_cxx_test")
load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX")
load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "IOS", "MACOSX")
# Shared by internal and OSS BUCK
def define_qnnpack(third_party, labels = []):
@ -21,7 +21,7 @@ def define_qnnpack(third_party, labels = []):
("src", "requantization/*.h"),
]),
header_namespace = "",
apple_sdks = (IOS, MACOSX, APPLETVOS),
apple_sdks = (IOS, MACOSX),
compiler_flags = [
"-O2",
"-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION",
@ -82,7 +82,7 @@ def define_qnnpack(third_party, labels = []):
("src", "requantization/*.h"),
]),
header_namespace = "",
apple_sdks = (IOS, MACOSX, APPLETVOS),
apple_sdks = (IOS, MACOSX),
compiler_flags = [
"-O3",
"-ffast-math",
@ -129,7 +129,7 @@ def define_qnnpack(third_party, labels = []):
("src", "requantization/*.h"),
]),
header_namespace = "",
apple_sdks = (IOS, MACOSX, APPLETVOS),
apple_sdks = (IOS, MACOSX),
compiler_flags = [
"-O3",
"-ffast-math",
@ -184,7 +184,7 @@ def define_qnnpack(third_party, labels = []):
("src", "requantization/*.h"),
]),
header_namespace = "",
apple_sdks = (IOS, MACOSX, APPLETVOS),
apple_sdks = (IOS, MACOSX),
compiler_flags = [
"-O3",
"-ffast-math",
@ -236,7 +236,7 @@ def define_qnnpack(third_party, labels = []):
],
),
header_namespace = "",
apple_sdks = (IOS, MACOSX, APPLETVOS),
apple_sdks = (IOS, MACOSX),
compiler_flags = [
"-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION",
],
@ -291,7 +291,7 @@ def define_qnnpack(third_party, labels = []):
("src", "qnnpack/*.h"),
("include", "*.h"),
]),
apple_sdks = (IOS, MACOSX, APPLETVOS),
apple_sdks = (IOS, MACOSX),
compiler_flags = [
"-O2",
"-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION",
@ -398,7 +398,7 @@ def define_qnnpack(third_party, labels = []):
("src", "requantization/*.h"),
]),
header_namespace = "",
apple_sdks = (IOS, MACOSX, APPLETVOS),
apple_sdks = (IOS, MACOSX),
compiler_flags = [
"-O3",
"-ffast-math",
@ -465,7 +465,7 @@ def define_qnnpack(third_party, labels = []):
("src", "requantization/*.h"),
]),
header_namespace = "",
apple_sdks = (IOS, MACOSX, APPLETVOS),
apple_sdks = (IOS, MACOSX),
compiler_flags = [
"-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION",
"-Wno-unused-command-line-argument",
@ -525,7 +525,7 @@ def define_qnnpack(third_party, labels = []):
("src", "qnnpack/*.h"),
]),
header_namespace = "",
apple_sdks = (IOS, MACOSX, APPLETVOS),
apple_sdks = (IOS, MACOSX),
compiler_flags = [
"-O3",
"-ffast-math",

View File

@ -0,0 +1,63 @@
#include <ATen/xpu/PeerToPeerAccess.h>
#include <ATen/xpu/XPUContext.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <c10/xpu/XPUCachingAllocator.h>
namespace at::xpu {
// p2pAccessEnabled_ is a flattened 2D matrix of size [num_devices x
// num_devices].
// Each element represents whether device[i] can access device[j]:
// 1 -> access allowed
// 0 -> access not allowed
// -1 -> unknown (not yet queried)
static std::vector<int8_t> p2pAccessEnabled_;
namespace detail {
// Initializes the peer-to-peer (P2P) access capability cache.
void init_p2p_access_cache(c10::DeviceIndex num_devices) {
// By default, each device can always access itself (diagonal entries = 1).
// For simplicity, all entries are initialized to -1 except the diagonal.
static bool once [[maybe_unused]] = [num_devices]() {
p2pAccessEnabled_.clear();
p2pAccessEnabled_.resize(num_devices * num_devices, -1);
for (const auto i : c10::irange(num_devices)) {
p2pAccessEnabled_[i * num_devices + i] = 1;
}
return true;
}();
}
} // namespace detail
bool get_p2p_access(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
at::globalContext().lazyInitDevice(c10::DeviceType::XPU);
check_device_index(dev);
check_device_index(dev_to_access);
auto& cache =
p2pAccessEnabled_[dev * c10::xpu::device_count() + dev_to_access];
if (cache != -1) {
return static_cast<bool>(cache);
}
// Query the hardware to determine if P2P access is supported
cache = static_cast<int8_t>(
c10::xpu::get_raw_device(dev).ext_oneapi_can_access_peer(
c10::xpu::get_raw_device(dev_to_access),
sycl::ext::oneapi::peer_access::access_supported));
if (cache) {
XPUCachingAllocator::enablePeerAccess(dev, dev_to_access);
}
return static_cast<bool>(cache);
}
} // namespace at::xpu

View File

@ -0,0 +1,15 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/macros/Macros.h>
namespace at::xpu {
namespace detail {
void init_p2p_access_cache(c10::DeviceIndex num_devices);
} // namespace detail
TORCH_XPU_API bool get_p2p_access(
c10::DeviceIndex dev,
c10::DeviceIndex dev_to_access);
} // namespace at::xpu

View File

@ -1,3 +1,4 @@
#include <ATen/xpu/PeerToPeerAccess.h>
#include <ATen/xpu/PinnedMemoryAllocator.h>
#include <ATen/xpu/XPUContext.h>
#include <ATen/xpu/XPUDevice.h>
@ -12,6 +13,7 @@ void XPUHooks::init() const {
C10_LOG_API_USAGE_ONCE("aten.init.xpu");
const auto device_count = c10::xpu::device_count_ensure_non_zero();
c10::xpu::XPUCachingAllocator::init(device_count);
at::xpu::detail::init_p2p_access_cache(device_count);
}
bool XPUHooks::hasXPU() const {

View File

@ -53,10 +53,8 @@ class AddmmBenchmark(op_bench.TorchBenchmarkBase):
return torch.addmm(input_one, mat1, mat2)
op_bench.generate_pt_test(addmm_long_configs + addmm_long_configs, AddmmBenchmark)
op_bench.generate_pt_gradient_test(
addmm_long_configs + addmm_long_configs, AddmmBenchmark
)
op_bench.generate_pt_test(addmm_short_configs + addmm_long_configs, AddmmBenchmark)
op_bench.generate_pt_gradient_test(addmm_long_configs, AddmmBenchmark)
"""Mircobenchmark for addbmm operator."""
@ -107,9 +105,7 @@ addbmm_short_configs = op_bench.cross_product_configs(
)
op_bench.generate_pt_test(addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark)
op_bench.generate_pt_gradient_test(
addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark
)
op_bench.generate_pt_gradient_test(addbmm_long_configs, AddbmmBenchmark)
if __name__ == "__main__":
op_bench.benchmark_runner.main()

View File

@ -8,7 +8,7 @@ load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
load("//tools/build_defs/windows:windows_flag_map.bzl", "windows_convert_gcc_clang_flags")
load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")
load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
load("//tools/build_defs:platform_defs.bzl", "APPLETVOS", "IOS", "MACOSX")
load("//tools/build_defs:platform_defs.bzl", "IOS", "MACOSX")
load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
load("//tools/build_defs/android:build_mode_defs.bzl", is_production_build_android = "is_production_build")
load("//tools/build_defs/apple:build_mode_defs.bzl", is_production_build_ios = "is_production_build", is_profile_build_ios = "is_profile_build")
@ -1090,7 +1090,7 @@ def define_buck_targets(
srcs = [
"caffe2/core/common.cc",
],
apple_sdks = (IOS, MACOSX, APPLETVOS),
apple_sdks = (IOS, MACOSX),
compiler_flags = get_pt_compiler_flags(),
labels = labels,
# @lint-ignore BUCKLINT link_whole

View File

@ -929,6 +929,7 @@ libtorch_python_core_sources = [
"torch/csrc/dynamo/guards.cpp",
"torch/csrc/dynamo/utils.cpp",
"torch/csrc/dynamo/init.cpp",
"torch/csrc/dynamo/stackref_bridge.c",
"torch/csrc/functorch/init.cpp",
"torch/csrc/fx/node.cpp",
"torch/csrc/mps/Module.cpp",
@ -1024,6 +1025,7 @@ libtorch_python_core_sources = [
libtorch_python_distributed_core_sources = [
"torch/csrc/distributed/c10d/init.cpp",
"torch/csrc/distributed/c10d/python_comm_hook.cpp",
"torch/csrc/distributed/c10d/python_callback_work.cpp",
]
libtorch_python_distributed_sources = libtorch_python_distributed_core_sources + [

View File

@ -59,6 +59,9 @@ constexpr DispatchKeySet nested_dispatch_keyset =
{DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) |
DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
constexpr DispatchKeySet functorch_batched_dispatch_keyset =
DispatchKeySet(DispatchKey::FuncTorchBatched);
DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
switch (t) {
@ -77,6 +80,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
return backend_dispatch_keyset;
case DispatchKey::CompositeExplicitAutogradNonFunctional:
return non_functional_backend_dispatch_keyset;
case DispatchKey::FuncTorchBatchedDecomposition:
return functorch_batched_dispatch_keyset;
default:
return DispatchKeySet(t);
}

View File

@ -1,4 +1,5 @@
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>
namespace c10 {
@ -111,4 +112,17 @@ bool SymBool::has_hint() const {
return toSymNodeImpl()->has_hint();
}
SymInt SymBool::toSymInt() const {
// If concrete bool, return concrete SymInt
if (auto ma = maybe_as_bool()) {
return SymInt(*ma ? 1 : 0);
}
// Symbolic case: use sym_ite to convert bool to int (0 or 1)
auto node = toSymNodeImpl();
auto one_node = node->wrap_int(1);
auto zero_node = node->wrap_int(0);
return SymInt(node->sym_ite(one_node, zero_node));
}
} // namespace c10

View File

@ -12,6 +12,8 @@
namespace c10 {
class SymInt;
class C10_API SymBool {
public:
/*implicit*/ SymBool(bool b) : data_(b) {}
@ -80,6 +82,10 @@ class C10_API SymBool {
return toSymNodeImplUnowned()->constant_bool();
}
// Convert SymBool to SymInt (0 or 1)
// This is the C++ equivalent of Python's cast_symbool_to_symint_guardless
SymInt toSymInt() const;
bool is_heap_allocated() const {
return ptr_;
}

View File

@ -18,6 +18,7 @@
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <array>
#include <cstddef>
@ -40,200 +41,99 @@ namespace c10 {
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
///
/// NOTE: We have refactored out the headeronly parts of the ArrayRef struct
/// into HeaderOnlyArrayRef. As adding `virtual` would change the performance of
/// the underlying constexpr calls, we rely on apparent-type dispatch for
/// inheritance. This should be fine because their memory format is the same,
/// and it is never incorrect for ArrayRef to call HeaderOnlyArrayRef methods.
/// However, you should prefer to use ArrayRef when possible, because its use
/// of TORCH_CHECK will lead to better user-facing error messages.
template <typename T>
class ArrayRef final {
class ArrayRef final : public HeaderOnlyArrayRef<T> {
public:
using iterator = const T*;
using const_iterator = const T*;
using size_type = size_t;
using value_type = T;
using reverse_iterator = std::reverse_iterator<iterator>;
private:
/// The start of the array, in an external buffer.
const T* Data;
/// The number of elements.
size_type Length;
void debugCheckNullptrInvariant() {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
Data != nullptr || Length == 0,
"created ArrayRef with nullptr and non-zero length! std::optional relies on this being illegal");
}
public:
/// @name Constructors
/// @name Constructors, all inherited from HeaderOnlyArrayRef except for
/// SmallVector. As inherited constructors won't work with class template
/// argument deduction (CTAD) until C++23, we add deduction guides after
/// the class definition to enable CTAD.
/// @{
/// Construct an empty ArrayRef.
/* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}
/// Construct an ArrayRef from a single element.
// TODO Make this explicit
constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}
/// Construct an ArrayRef from a pointer and length.
constexpr ArrayRef(const T* data, size_t length)
: Data(data), Length(length) {
debugCheckNullptrInvariant();
}
/// Construct an ArrayRef from a range.
constexpr ArrayRef(const T* begin, const T* end)
: Data(begin), Length(end - begin) {
debugCheckNullptrInvariant();
}
using HeaderOnlyArrayRef<T>::HeaderOnlyArrayRef;
/// Construct an ArrayRef from a SmallVector. This is templated in order to
/// avoid instantiating SmallVectorTemplateCommon<T> whenever we
/// copy-construct an ArrayRef.
/// NOTE: this is the only constructor that is not inherited from
/// HeaderOnlyArrayRef.
template <typename U>
/* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)
: Data(Vec.data()), Length(Vec.size()) {
debugCheckNullptrInvariant();
}
template <
typename Container,
typename U = decltype(std::declval<Container>().data()),
typename = std::enable_if_t<
(std::is_same_v<U, T*> || std::is_same_v<U, T const*>)>>
/* implicit */ ArrayRef(const Container& container)
: Data(container.data()), Length(container.size()) {
debugCheckNullptrInvariant();
}
/// Construct an ArrayRef from a std::vector.
// The enable_if stuff here makes sure that this isn't used for
// std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
// bitfield.
template <typename A>
/* implicit */ ArrayRef(const std::vector<T, A>& Vec)
: Data(Vec.data()), Length(Vec.size()) {
static_assert(
!std::is_same_v<T, bool>,
"ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
}
/// Construct an ArrayRef from a std::array
template <size_t N>
/* implicit */ constexpr ArrayRef(const std::array<T, N>& Arr)
: Data(Arr.data()), Length(N) {}
/// Construct an ArrayRef from a C array.
template <size_t N>
// NOLINTNEXTLINE(*c-arrays*)
/* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}
/// Construct an ArrayRef from a std::initializer_list.
/* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
: Data(
std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
: std::begin(Vec)),
Length(Vec.size()) {}
: HeaderOnlyArrayRef<T>(Vec.data(), Vec.size()) {}
/// @}
/// @name Simple Operations
/// @name Simple Operations, mostly inherited from HeaderOnlyArrayRef
/// @{
constexpr iterator begin() const {
return Data;
}
constexpr iterator end() const {
return Data + Length;
}
// These are actually the same as iterator, since ArrayRef only
// gives you const iterators.
constexpr const_iterator cbegin() const {
return Data;
}
constexpr const_iterator cend() const {
return Data + Length;
}
constexpr reverse_iterator rbegin() const {
return reverse_iterator(end());
}
constexpr reverse_iterator rend() const {
return reverse_iterator(begin());
}
/// Check if all elements in the array satisfy the given expression
constexpr bool allMatch(const std::function<bool(const T&)>& pred) const {
return std::all_of(cbegin(), cend(), pred);
}
/// empty - Check if the array is empty.
constexpr bool empty() const {
return Length == 0;
}
constexpr const T* data() const {
return Data;
}
/// size - Get the array size.
constexpr size_t size() const {
return Length;
}
/// front - Get the first element.
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr const T& front() const {
TORCH_CHECK(
!empty(), "ArrayRef: attempted to access front() of empty list");
return Data[0];
!this->empty(), "ArrayRef: attempted to access front() of empty list");
return this->Data[0];
}
/// back - Get the last element.
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr const T& back() const {
TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list");
return Data[Length - 1];
}
/// equals - Check for element-wise equality.
constexpr bool equals(ArrayRef RHS) const {
return Length == RHS.Length && std::equal(begin(), end(), RHS.begin());
TORCH_CHECK(
!this->empty(), "ArrayRef: attempted to access back() of empty list");
return this->Data[this->Length - 1];
}
/// slice(n, m) - Take M elements of the array starting at element N
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr ArrayRef<T> slice(size_t N, size_t M) const {
TORCH_CHECK(
N + M <= size(),
N + M <= this->size(),
"ArrayRef: invalid slice, N = ",
N,
"; M = ",
M,
"; size = ",
size());
return ArrayRef<T>(data() + N, M);
this->size());
return ArrayRef<T>(this->data() + N, M);
}
/// slice(n) - Chop off the first N elements of the array.
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr ArrayRef<T> slice(size_t N) const {
TORCH_CHECK(
N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size());
return slice(N, size() - N);
N <= this->size(),
"ArrayRef: invalid slice, N = ",
N,
"; size = ",
this->size());
return slice(N, this->size() - N); // should this slice be this->slice?
}
/// @}
/// @name Operator Overloads
/// @{
constexpr const T& operator[](size_t Index) const {
return Data[Index];
}
/// Vector compatibility
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
/// STD_TORCH_CHECK
constexpr const T& at(size_t Index) const {
TORCH_CHECK(
Index < Length,
Index < this->Length,
"ArrayRef: invalid index Index = ",
Index,
"; Length = ",
Length);
return Data[Index];
this->Length);
return this->Data[Index];
}
/// Disallow accidental assignment from a temporary.
@ -253,16 +153,48 @@ class ArrayRef final {
std::enable_if_t<std::is_same_v<U, T>, ArrayRef<T>>& operator=(
std::initializer_list<U>) = delete;
/// @}
/// @name Expensive Operations
/// @{
std::vector<T> vec() const {
return std::vector<T>(Data, Data + Length);
}
/// @}
};
/// Deduction guides for ArrayRef to support CTAD with inherited constructors
/// These mirror the constructors inherited from HeaderOnlyArrayRef
/// @{
// Single element constructor
template <typename T>
ArrayRef(const T&) -> ArrayRef<T>;
// Pointer and length constructor
template <typename T>
ArrayRef(const T*, size_t) -> ArrayRef<T>;
// Range constructor (begin, end)
template <typename T>
ArrayRef(const T*, const T*) -> ArrayRef<T>;
// Generic container constructor (anything with .data() and .size())
template <typename Container>
ArrayRef(const Container&) -> ArrayRef<
std::remove_pointer_t<decltype(std::declval<Container>().data())>>;
// std::vector constructor
template <typename T, typename A>
ArrayRef(const std::vector<T, A>&) -> ArrayRef<T>;
// std::array constructor
template <typename T, size_t N>
ArrayRef(const std::array<T, N>&) -> ArrayRef<T>;
// C array constructor
template <typename T, size_t N>
ArrayRef(const T (&)[N]) -> ArrayRef<T>;
// std::initializer_list constructor
template <typename T>
ArrayRef(const std::initializer_list<T>&) -> ArrayRef<T>;
/// @}
template <typename T>
std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
int i = 0;

View File

@ -21,13 +21,20 @@ using stream_set = ska::flat_hash_set<xpu::XPUStream>;
struct Block;
typedef bool (*Comparison)(const Block*, const Block*);
bool BlockComparatorSize(const Block* a, const Block* b);
bool BlockComparatorAddress(const Block* a, const Block* b);
struct BlockPool {
BlockPool(bool small) : blocks(BlockComparatorSize), is_small(small) {}
BlockPool(bool small)
: blocks(BlockComparatorSize),
unmapped(BlockComparatorAddress),
is_small(small) {}
std::set<Block*, Comparison> blocks;
std::set<Block*, Comparison> unmapped;
const bool is_small;
};
struct ExpandableSegment;
struct Block {
DeviceIndex device;
sycl::queue* queue{nullptr}; // underlying queue of the allocation stream
@ -37,9 +44,11 @@ struct Block {
BlockPool* pool{nullptr}; // owning memory pool
void* ptr{nullptr}; // memory address
bool allocated{false}; // in-use flag
bool mapped{true}; // True if this Block is backed by physical pages
Block* prev{nullptr}; // prev block if split from a larger allocation
Block* next{nullptr}; // next block if split from a larger allocation
int event_count{0}; // number of outstanding XPU events
ExpandableSegment* expandable_segment{nullptr}; // owning expandable segment
Block(
DeviceIndex device,
@ -66,6 +75,20 @@ struct Block {
bool is_split() const {
return (prev != nullptr) || (next != nullptr);
}
// Inserts this block between two existing blocks with [before, this, after].
void splice(Block* before, Block* after) {
if (before) {
TORCH_INTERNAL_ASSERT(before->next == after);
before->next = this;
}
prev = before;
if (after) {
TORCH_INTERNAL_ASSERT(after->prev == before);
after->prev = this;
}
next = after;
}
};
bool BlockComparatorSize(const Block* a, const Block* b) {
@ -80,6 +103,221 @@ bool BlockComparatorSize(const Block* a, const Block* b) {
reinterpret_cast<uintptr_t>(b->ptr);
}
bool BlockComparatorAddress(const Block* a, const Block* b) {
if (a->queue != b->queue) {
return reinterpret_cast<uintptr_t>(a->queue) <
reinterpret_cast<uintptr_t>(b->queue);
}
return reinterpret_cast<uintptr_t>(a->ptr) <
reinterpret_cast<uintptr_t>(b->ptr);
}
// Represents a contiguous virtual memory segment mapped for allocation.
struct SegmentRange {
SegmentRange(void* addr, size_t bytes)
: ptr(static_cast<char*>(addr)), size(bytes) {}
char* ptr; // Starting address of the mapped range.
size_t size; // Size in bytes of the mapped range.
};
struct ExpandableSegment {
ExpandableSegment(
c10::DeviceIndex device,
std::optional<sycl::queue*> queue,
size_t segment_size,
std::vector<c10::DeviceIndex> peers)
: device_(device),
queue_(queue),
// 2MB for small pool, 20MB for large pool
segment_size_(segment_size),
peers_(std::move(peers)) {
const auto device_total =
c10::xpu::get_raw_device(device)
.get_info<sycl::info::device::global_mem_size>();
// The extra 1/8 allows flexibility for remapping or moving pages within the
// segment when unmapping earlier regions.
constexpr float kVirtualMemOversubscriptFactor = 1.125f; // 1 + 1/8
max_handles_ = numSegments(device_total * kVirtualMemOversubscriptFactor);
ptr_ = sycl::ext::oneapi::experimental::reserve_virtual_mem(
segment_size_ * max_handles_, xpu::get_device_context());
}
C10_DISABLE_COPY_AND_ASSIGN(ExpandableSegment);
ExpandableSegment(ExpandableSegment&&) = delete;
ExpandableSegment& operator=(ExpandableSegment&&) = delete;
// Maps a virtual memory range to physical memory.
SegmentRange map(SegmentRange range) {
auto begin = segmentLeft(range.ptr);
auto end = segmentRight(range.ptr + range.size);
TORCH_INTERNAL_ASSERT(ptr() + begin * segment_size_ == range.ptr);
if (begin == end) {
return rangeFromHandles(begin, end);
}
// Ensure handles_ vector is large enough to hold all segments.
if (end > handles_.size()) {
handles_.resize(end, std::nullopt);
}
// Allocate and map physical memory for each segment.
for (const auto i : c10::irange(begin, end)) {
TORCH_INTERNAL_ASSERT(!handles_.at(i));
try {
// Allocate physical memory for each segment. Construct the physical_mem
// in-place to avoid copies.
handles_.at(i).emplace(
xpu::get_raw_device(device_),
xpu::get_device_context(),
segment_size_);
// Map the allocated physical memory into the virtual address space.
handles_.at(i).value().map(
ptr_ + i * segment_size_,
segment_size_,
sycl::ext::oneapi::experimental::address_access_mode::read_write);
} catch (const sycl::exception& e) {
// Allocation failure: typically sycl::errc::memory_allocation.
// Mapping failure: typically sycl::errc::runtime (e.g., OOM due to
// over-subscription).
// Note: constructing physical_mem may over-subscribe device memory but
// not immediately trigger OOM. The actual OOM can occur during map().
// Roll back all segments allocated or mapped in this operation.
handles_.at(i) = std::nullopt;
for (const auto j : c10::irange(begin, i)) {
sycl::ext::oneapi::experimental::unmap(
reinterpret_cast<void*>(ptr_ + segment_size_ * j),
segment_size_,
xpu::get_device_context());
handles_.at(j) = std::nullopt;
}
trimHandles();
return rangeFromHandles(begin, begin);
}
}
return rangeFromHandles(begin, end);
}
// Unmap a virtual memory range from physical memory.
SegmentRange unmap(SegmentRange range) {
auto begin = segmentRight(range.ptr);
auto end = segmentLeft(range.ptr + range.size);
if (begin >= end) {
return SegmentRange{range.ptr, 0};
}
unmapHandles(begin, end);
return rangeFromHandles(begin, end);
}
// Returns the base pointer of the virtual memory segment.
char* ptr() const {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<char*>(ptr_);
}
// Returns the total size of the virtual memory segment.
size_t size() const {
return max_handles_ * segment_size_;
}
~ExpandableSegment() {
forEachAllocatedRange(
[&](size_t begin, size_t end) { unmapHandles(begin, end); });
sycl::ext::oneapi::experimental::free_virtual_mem(
ptr_, segment_size_ * max_handles_, xpu::get_device_context());
}
private:
// Unmaps the physical memory handles in the range [begin, end) from the
// segment.
void unmapHandles(size_t begin, size_t end) {
// Currently, we don't support IPC shared memory with expandable segments.
TORCH_INTERNAL_ASSERT(queue_);
// As explained in Note [Safe to Free Blocks on BlockPool], additional
// synchronization is unnecessary here because the memory is already safe to
// release.
for (const auto i : c10::irange(begin, end)) {
// Note: physical_mem's destructor does NOT automatically unmap any mapped
// ranges. Users must explicitly call unmap on all ranges before
// destroying the physical_mem object.
sycl::ext::oneapi::experimental::unmap(
reinterpret_cast<void*>(ptr_ + segment_size_ * i),
segment_size_,
xpu::get_device_context());
// Here physical_mem object is being destructed.
handles_.at(i) = std::nullopt;
}
trimHandles();
}
// Remove trailing unused handles from the end of handles_.
void trimHandles() {
while (!handles_.empty() && !handles_.back()) {
handles_.pop_back();
}
}
// Iterates over all contiguous ranges of allocated segments in `handles_`,
// and invokes the provided function `fn(start, end)` for each range.
// Each range is defined as a half-open interval [start, end).
void forEachAllocatedRange(const std::function<void(size_t, size_t)>& fn) {
size_t start = 0;
for (const auto i : c10::irange(handles_.size())) {
if (handles_.at(i) && (i == 0 || !handles_.at(i - 1))) {
start = i;
}
if (handles_.at(i) && (i + 1 == handles_.size() || !handles_.at(i + 1))) {
fn(start, i + 1);
}
}
}
// Returns the number of full segments required to cover `size` bytes.
// Rounds up to ensure partial segments are counted.
size_t numSegments(size_t size) const {
return (size + segment_size_ - 1) / segment_size_;
}
// Returns the index of the segment that contains the pointer `p`,
// relative to the base pointer `ptr_`. This is the *inclusive* lower bound
// of the segment that includes `p`.
size_t segmentLeft(char* p) const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(p >= ptr() && p < ptr() + size());
size_t offset = p - ptr();
return offset / segment_size_;
}
// Returns the index of the segment just *past* the one containing pointer
// `p`, relative to the base pointer `ptr_`. This is the *exclusive* upper
// bound, useful for [begin, end) style ranges.
// If `p` lies exactly on a segment boundary, this is equal to segmentLeft(p).
// Otherwise, it rounds up and returns segmentLeft(p) + 1.
size_t segmentRight(char* p) const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(p >= ptr() && p < ptr() + size());
size_t offset = p - ptr();
return numSegments(offset);
}
// Constructs a SegmentRange spanning indices [start, end).
SegmentRange rangeFromHandles(size_t begin, size_t end) {
return SegmentRange(
ptr() + segment_size_ * begin, segment_size_ * (end - begin));
}
c10::DeviceIndex device_{-1};
std::optional<sycl::queue*> queue_;
// Virtual memory address used for reservation.
uintptr_t ptr_{0};
// Size of each segment in bytes.
size_t segment_size_{0};
// Maximum number of segments that can be allocated in this segment.
size_t max_handles_{0};
// Physical memory handles for the segments.
std::vector<std::optional<sycl::ext::oneapi::experimental::physical_mem>>
handles_{};
// Peer devices on which this memory could be accessible, reserved.
std::vector<c10::DeviceIndex> peers_{};
};
struct AllocParams {
AllocParams(
DeviceIndex device,
@ -125,10 +363,12 @@ class DeviceCachingAllocator {
DeviceIndex device_index;
size_t allowed_memory_maximum = 0;
bool set_fraction = false;
std::vector<ExpandableSegment*> expandable_segments;
std::vector<c10::DeviceIndex> devices_with_peer_access; // reserved
size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
if (!src || src->allocated || src->event_count > 0 ||
!src->stream_uses.empty()) {
!src->stream_uses.empty() || dst->mapped != src->mapped) {
return 0;
}
@ -147,7 +387,8 @@ class DeviceCachingAllocator {
}
const size_t subsumed_size = src->size;
dst->size += subsumed_size;
auto erased = pool.blocks.erase(src);
auto erased =
src->mapped ? pool.blocks.erase(src) : pool.unmapped.erase(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
delete src;
@ -230,12 +471,175 @@ class DeviceCachingAllocator {
}
}
// Finds the first (lowest-address) block in any segment that has sufficient
// contiguous free virtual address space to satisfy `size`. The available
// space may span multiple adjacent blocks, which can include both free and
// unmapped segments.
Block* find_expandable_block(
c10::DeviceIndex device,
sycl::queue* queue,
BlockPool* pool,
size_t size) {
Block key(device, queue, 0);
auto allocatable = [](Block* b) {
return b && !b->allocated && b->event_count == 0 &&
b->stream_uses.empty();
};
auto has_available_address_space = [&](Block* b) {
size_t bytes = 0;
while (bytes < size && allocatable(b)) {
bytes += b->size;
b = b->next;
}
return bytes >= size;
};
for (auto it = pool->unmapped.lower_bound(&key);
it != pool->unmapped.end() && (*it)->queue == queue;
++it) {
Block* c = *it;
// The unmapped block might have a free mapped block right before it.
// By starting from the previous block, we can use both:
// [Free Mapped Block] + [Unmapped Block] = More contiguous space
if (allocatable(c->prev)) {
c = c->prev;
}
if (has_available_address_space(c)) {
return c;
}
}
auto segment_size = pool->is_small ? kSmallBuffer : kLargeBuffer;
expandable_segments.emplace_back(new ExpandableSegment(
device, queue, segment_size, devices_with_peer_access));
ExpandableSegment* es = expandable_segments.back();
Block* candidate = new Block(device, queue, es->size(), pool, es->ptr());
candidate->mapped = false;
candidate->expandable_segment = es;
pool->unmapped.insert(candidate);
return candidate;
}
bool map_block(Block* to_map, size_t size) {
TORCH_INTERNAL_ASSERT(!to_map->mapped && size <= to_map->size);
auto mapped_range =
to_map->expandable_segment->map(SegmentRange{to_map->ptr, size});
// Failed to map the memory
if (mapped_range.size == 0) {
return false;
}
TORCH_INTERNAL_ASSERT(
mapped_range.ptr == to_map->ptr && mapped_range.size >= size);
BlockPool& pool = *to_map->pool;
pool.unmapped.erase(to_map);
to_map->mapped = true;
if (mapped_range.size < to_map->size) {
// to_map -> remaining -> to_map->next(?)
Block* remaining = new Block(
to_map->device,
to_map->queue,
to_map->size - mapped_range.size,
&pool,
static_cast<char*>(to_map->ptr) + mapped_range.size);
remaining->mapped = false;
remaining->expandable_segment = to_map->expandable_segment;
remaining->splice(to_map, to_map->next);
pool.unmapped.insert(remaining);
to_map->size = mapped_range.size;
}
try_merge_blocks(to_map, to_map->prev, pool);
try_merge_blocks(to_map, to_map->next, pool);
pool.blocks.insert(to_map);
StatTypes stat_types = get_stat_types_for_pool(*to_map->pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].increase(mapped_range.size);
});
return true;
}
Block* try_allocate_expandable_block(
c10::DeviceIndex device,
sycl::queue* queue,
BlockPool* pool,
size_t size) {
// Candidate points to the start of a chain of contiguous blocks with
// sufficient virtual address space (>= size). The chain may consist of:
// Case 1: [Unmapped Block] -> null
// Case 2: [Unmapped Block] -> [Free Mapped Block]
// Case 3: [Free Mapped Block] -> [Unmapped Block]
Block* candidate = find_expandable_block(device, queue, pool, size);
// Map first block if unmapped (Case 1 & 2), use std::min to avoid
// over-mapping.
if (!candidate->mapped &&
!map_block(candidate, std::min(candidate->size, size))) {
return nullptr;
}
TORCH_INTERNAL_ASSERT(candidate->mapped);
// Map additional blocks until we have enough continuous space (Case 3).
// Each map_block() call merges newly mapped blocks with adjacent free
// blocks
while (candidate->size < size) {
auto remaining = size - candidate->size;
auto new_candidate = candidate->next;
// Map only what we need from the `new_candidate` block.
if (!map_block(new_candidate, std::min(remaining, new_candidate->size))) {
return nullptr;
}
candidate = new_candidate;
}
// Remove from the free pool; block will be marked as `allocated` in
// alloc_found_block()
pool->blocks.erase(candidate);
return candidate;
}
bool get_free_block(AllocParams& p) {
BlockPool& pool = *p.pool;
auto it = pool.blocks.lower_bound(&p.search_key);
if (it == pool.blocks.end() || (*it)->queue != p.queue()) {
return false;
}
if ((*it)->expandable_segment) {
if (AcceleratorAllocatorConfig::use_expandable_segments()) {
// When expandable segments are enabled, consider both the current block
// and any immediately adjacent unmapped region as a single expandable
// area. For "best fit" allocation, we use the total expandable size
// instead of just the block's current size, so that blocks which can
// grow into a larger contiguous range are preferred.
auto expandable_size = [](Block* b) {
// b->next may belong to pool.unmapped (reserved but not mapped)
return b->size + (b->next && !b->next->mapped ? b->next->size : 0);
};
auto next = it;
next++;
// Looks for the best fit block with expandable size.
while ((*it)->expandable_segment && next != pool.blocks.end() &&
(*next)->queue == p.queue() &&
expandable_size(*next) < expandable_size(*it)) {
it = next++;
}
} else {
// Expandable segments were previously enabled, but are now disabled
// (e.g. to avoid IPC issues). Skip any expandable blocks and only
// find from regular non-expandable segments.
do {
it++;
} while (it != pool.blocks.end() && (*it)->expandable_segment &&
(*it)->queue == p.queue());
if (it == pool.blocks.end() || (*it)->queue != p.queue()) {
return false;
}
}
}
p.block = *it;
pool.blocks.erase(it);
return true;
@ -252,6 +656,10 @@ class DeviceCachingAllocator {
size >
allowed_memory_maximum) {
return false;
} else if (AcceleratorAllocatorConfig::use_expandable_segments()) {
p.block =
try_allocate_expandable_block(device, p.queue(), p.pool, p.size());
return bool(p.block);
}
void* ptr = sycl::aligned_alloc_device(
kDeviceAlignment,
@ -265,6 +673,7 @@ class DeviceCachingAllocator {
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].increase(size);
});
TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
return true;
}
@ -283,6 +692,27 @@ class DeviceCachingAllocator {
xpu_events.clear();
}
void release_expandable_segment(Block* block) {
// See Note [Safe to Free Blocks on BlockPool], additional synchronization
// is unnecessary here because this function is only called by
// release_cached_blocks().
TORCH_INTERNAL_ASSERT(
block->size == block->expandable_segment->size(),
"block disagrees with segment");
TORCH_INTERNAL_ASSERT(!block->mapped);
auto it = std::find(
expandable_segments.begin(),
expandable_segments.end(),
block->expandable_segment);
TORCH_INTERNAL_ASSERT(it != expandable_segments.end());
expandable_segments.erase(it);
block->pool->unmapped.erase(block);
delete block->expandable_segment;
delete block;
}
void release_block(Block* block) {
/*
* Note [Safe to Free Blocks on BlockPool]
@ -293,6 +723,7 @@ class DeviceCachingAllocator {
* We have to do a device-level synchronization before free these blocks to
* guarantee that all kernels can access to the blocks have finished.
*/
TORCH_INTERNAL_ASSERT(!block->expandable_segment);
sycl::free(block->ptr, xpu::get_device_context());
auto* pool = block->pool;
pool->blocks.erase(block);
@ -305,15 +736,80 @@ class DeviceCachingAllocator {
delete block;
}
void unmap_block(Block* block) {
auto unmapped =
block->expandable_segment->unmap(SegmentRange{block->ptr, block->size});
if (unmapped.size == 0) {
return;
}
block->pool->blocks.erase(block);
ptrdiff_t before_size = unmapped.ptr - static_cast<char*>(block->ptr);
if (before_size > 0) {
// If the actual unmapped region starts after block->ptr due to alignment,
// the region before unmapped.ptr is still mapped.
// [Prev Block?] -> [Before Block] -> [Unmapped Block]
Block* before_free = new Block(
block->device, block->queue, before_size, block->pool, block->ptr);
before_free->expandable_segment = block->expandable_segment;
before_free->splice(block->prev, block);
block->pool->blocks.insert(before_free);
}
auto after_size = block->size - (before_size + unmapped.size);
if (after_size > 0) {
// If the actual unmapped region ends before block->ptr + block->size,
// the region after (unmapped.ptr + unmapped.size) is still mapped.
// [Unmapped Block] -> [After Block] -> [Next Block?]
Block* after_free = new Block(
block->device,
block->queue,
after_size,
block->pool,
unmapped.ptr + unmapped.size);
after_free->expandable_segment = block->expandable_segment;
after_free->splice(block, block->next);
block->pool->blocks.insert(after_free);
}
// [Before Mapped Block?] -> [Unmapped Block] -> [After Mapped Block?]
block->ptr = unmapped.ptr;
block->size = unmapped.size;
block->mapped = false;
try_merge_blocks(block, block->prev, *block->pool);
try_merge_blocks(block, block->next, *block->pool);
block->pool->unmapped.insert(block);
StatTypes stat_types = get_stat_types_for_pool(*block->pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].decrease(unmapped.size);
});
}
void release_blocks(BlockPool& pool) {
std::vector<Block*> to_unmap;
// Frees all non-split blocks in the given pool.
auto it = pool.blocks.begin();
while (it != pool.blocks.end()) {
Block* block = *it;
++it;
if (!block->prev && !block->next) {
if (block->expandable_segment) {
// unmap_block() modifies the free pool, so collect items to free first
// to avoid iterator invalidation.
to_unmap.push_back(block);
} else if (!block->prev && !block->next) {
release_block(block);
}
}
for (Block* block : to_unmap) {
unmap_block(block);
// After unmap_block(), expandable segment blocks with no neighbors are
// also released.
if (!block->prev && !block->next) {
release_expandable_segment(block);
}
}
}
bool release_cached_blocks() {
@ -328,7 +824,8 @@ class DeviceCachingAllocator {
bool should_split(const Block* block, size_t size) {
size_t remaining = block->size - size;
if (block->pool->is_small) {
if (block->pool->is_small ||
AcceleratorAllocatorConfig::use_expandable_segments()) {
return remaining >= kMinBlockSize;
} else {
return remaining > kSmallSize;
@ -361,6 +858,7 @@ class DeviceCachingAllocator {
remaining = block;
block = new Block(device, queue, size, pool, block->ptr);
block->expandable_segment = remaining->expandable_segment;
block->prev = remaining->prev;
if (block->prev) {
block->prev->next = block;
@ -599,6 +1097,15 @@ class XPUAllocator : public DeviceAllocator {
return block;
}
void assertValidDevice(DeviceIndex device) {
const auto device_num = device_allocators.size();
TORCH_CHECK(
0 <= device && device < static_cast<int64_t>(device_num),
"Invalid device argument ",
device,
": did you call init?");
}
public:
std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocators;
@ -711,15 +1218,6 @@ class XPUAllocator : public DeviceAllocator {
xpu::getCurrentXPUStream().queue().memcpy(dest, src, count);
}
void assertValidDevice(DeviceIndex device) {
const auto device_num = device_allocators.size();
TORCH_CHECK(
0 <= device && device < static_cast<int64_t>(device_num),
"Invalid device argument ",
device,
": did you call init?");
}
DeviceStats getDeviceStats(DeviceIndex device) override {
assertValidDevice(device);
return device_allocators[device]->getStats();
@ -735,6 +1233,13 @@ class XPUAllocator : public DeviceAllocator {
device_allocators[device]->resetAccumulatedStats();
}
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
assertValidDevice(dev);
assertValidDevice(dev_to_access);
c10::xpu::get_raw_device(dev).ext_oneapi_enable_peer_access(
c10::xpu::get_raw_device(dev_to_access));
}
double getMemoryFraction(DeviceIndex device) {
assertValidDevice(device);
return device_allocators[device]->getMemoryFraction();
@ -793,6 +1298,10 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) {
return allocator.recordStream(dataPtr, stream);
}
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
return allocator.enablePeerAccess(dev, dev_to_access);
}
double getMemoryFraction(DeviceIndex device) {
return allocator.getMemoryFraction(device);
}

View File

@ -25,6 +25,10 @@ C10_XPU_API void raw_delete(void* ptr);
C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);
C10_XPU_API void enablePeerAccess(
c10::DeviceIndex dev,
c10::DeviceIndex dev_to_access);
C10_XPU_API double getMemoryFraction(DeviceIndex device);
C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device);

View File

@ -1307,7 +1307,7 @@ endif()
if(USE_MKLDNN_ACL)
find_package(ACL REQUIRED)
target_include_directories(torch_cpu PRIVATE ${ACL_INCLUDE_DIRS})
target_include_directories(torch_cpu SYSTEM PRIVATE ${ACL_INCLUDE_DIRS})
endif()
target_include_directories(torch_cpu PRIVATE ${ATen_CPU_INCLUDE})

View File

@ -26,7 +26,7 @@ find_library(Gloo_CUDA_LIBRARY
# if Gloo + HIP is desired, Gloo_HIP_LIBRARY
# needs to be linked to desired target
find_library(Gloo_HIP_LIBRARY
NAMES gloo_hiop
NAMES gloo_hip
DOC "Gloo's HIP support/code"
)

View File

@ -18,6 +18,7 @@ IF(NOT MKLDNN_FOUND)
SET(IDEEP_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep")
SET(MKLDNN_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn")
SET(ONEDNN_AARCH64_TAG "v3.10-rc")
if(USE_XPU) # Build oneDNN GPU library
if(WIN32)
@ -96,6 +97,14 @@ IF(NOT MKLDNN_FOUND)
FIND_PACKAGE(BLAS)
FIND_PATH(IDEEP_INCLUDE_DIR ideep.hpp PATHS ${IDEEP_ROOT} PATH_SUFFIXES include)
FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h dnnl_ukernel.hpp dnnl_ukernel.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include/oneapi/dnnl)
# Checkout the oneDNN version defined by ONEDNN_AARCH64_TAG for CPU_AARCH64
IF(CPU_AARCH64)
EXECUTE_PROCESS(
COMMAND git${CMAKE_EXECUTABLE_SUFFIX} checkout ${ONEDNN_AARCH64_TAG}
WORKING_DIRECTORY ${MKLDNN_ROOT}
)
FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h dnnl_ukernel.hpp dnnl_ukernel.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include)
ENDIF(CPU_AARCH64)
IF(NOT MKLDNN_INCLUDE_DIR)
MESSAGE("MKLDNN_INCLUDE_DIR not found")
EXECUTE_PROCESS(COMMAND git${CMAKE_EXECUTABLE_SUFFIX} submodule update --init mkl-dnn WORKING_DIRECTORY ${IDEEP_ROOT})

View File

@ -28,6 +28,15 @@ endif()
# Find CUDA.
find_package(CUDA)
if(NOT CUDA_FOUND)
# If user explicitly set USE_CUDA=1, error out instead of falling back
if(_USE_CUDA_EXPLICITLY_SET AND USE_CUDA)
message(FATAL_ERROR
"PyTorch: CUDA was explicitly requested (USE_CUDA=1) but cannot be found. "
"Please check your CUDA installation, ensure CUDA toolkit is installed, "
"and that CUDA_HOME or CMAKE_CUDA_COMPILER is set correctly. "
"If you want to build without CUDA, please set USE_CUDA=0.")
endif()
message(WARNING
"PyTorch: CUDA cannot be found. Depending on whether you are building "
"PyTorch or a PyTorch dependent library, the next warning / error will "

View File

@ -45,7 +45,7 @@ supported for complex tensors.
## Transition from the old representation
Users who currently worked around the lack of complex tensors with real tensors of shape {math}`(..., 2)`
can easily to switch using the complex tensors in their code using {func}`torch.view_as_complex`
can easily switch to using the complex tensors in their code using {func}`torch.view_as_complex`
and {func}`torch.view_as_real`. Note that these functions dont perform any copy and return a
view of the input tensor.
@ -140,7 +140,7 @@ through the same optimizer on the {func}`torch.view_as_real` equivalent of the c
`real_optim` and `complex_optim` will compute the same updates on the parameters, though there may be slight numerical
discrepancies between the two optimizers, similar to numerical discrepancies between foreach vs forloop optimizers
and capturable vs default optimizers. For more details, see [numbercial accuracy](https://pytorch.org/docs/stable/notes/numerical_accuracy.html).
and capturable vs default optimizers. For more details, see [numerical accuracy](https://pytorch.org/docs/stable/notes/numerical_accuracy.html).
Specifically, while you can think of our optimizer's handling of complex tensors as the same as optimizing over their
`p.real` and `p.imag` pieces separately, the implementation details are not precisely that. Note that the

View File

@ -394,6 +394,10 @@ an opaque group handle that can be given as a `group` argument to all collective
.. autofunction:: new_group
```
```{eval-rst}
.. autofunction:: torch.distributed.distributed_c10d.shrink_group
```
```{eval-rst}
.. autofunction:: get_group_rank
```

View File

@ -1720,6 +1720,16 @@ and can be used to share memory across graphs as shown::
g1.replay()
g2.replay()
It's also safe to share a memory pool across separate graphs that do not depend
on each other's outputs, provided they never run concurrently.
Be aware that replaying one graph can clobber another graph's outputs when
they share a pool, unless :meth:`~torch.Tensor.clone` is called on the outputs
beforehand.
This pattern is frequently used in inference servers that accept variable batch
sizes at runtime.
vLLM is a notable example; see `here <https://github.com/vllm-project/vllm/blob/938a81692ea318e59ead4750e7e7425bfd6a4896/vllm/platforms/interface.py#L508-L515>`__
and `here <https://github.com/vllm-project/vllm/blob/938a81692ea318e59ead4750e7e7425bfd6a4896/vllm/compilation/cuda_graph.py#L86-L89>`__.
With :func:`torch.cuda.make_graphed_callables`, if you want to graph several
callables and you know they'll always run in the same order (and never concurrently)
pass them as a tuple in the same order they'll run in the live workload, and

View File

@ -12,6 +12,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS
${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp
@ -44,6 +45,10 @@ endif()
# Disable unused-variable warnings for variables that are only used to test compilation
target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-variable)
target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-but-set-variable)
# Add -Wno-dangling-pointer for GCC 13
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13)
target_compile_options_if_supported(test_aoti_abi_check -Wno-dangling-pointer)
endif()
foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS})
foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES})

View File

@ -0,0 +1,52 @@
#include <gtest/gtest.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <vector>
using torch::headeronly::HeaderOnlyArrayRef;
TEST(TestHeaderOnlyArrayRef, TestEmpty) {
HeaderOnlyArrayRef<float> arr;
ASSERT_TRUE(arr.empty());
}
TEST(TestHeaderOnlyArrayRef, TestSingleton) {
float val = 5.0f;
HeaderOnlyArrayRef<float> arr(val);
ASSERT_FALSE(arr.empty());
EXPECT_EQ(arr.size(), 1);
EXPECT_EQ(arr[0], val);
}
TEST(TestHeaderOnlyArrayRef, TestAPIs) {
std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
HeaderOnlyArrayRef<int> arr(vec);
ASSERT_FALSE(arr.empty());
EXPECT_EQ(arr.size(), 7);
for (size_t i = 0; i < arr.size(); i++) {
EXPECT_EQ(arr[i], i + 1);
EXPECT_EQ(arr.at(i), i + 1);
}
EXPECT_EQ(arr.front(), 1);
EXPECT_EQ(arr.back(), 7);
ASSERT_TRUE(arr.slice(3, 4).equals(arr.slice(3)));
}
TEST(TestHeaderOnlyArrayRef, TestFromInitializerList) {
std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
HeaderOnlyArrayRef<int> arr({1, 2, 3, 4, 5, 6, 7});
auto res_vec = arr.vec();
for (size_t i = 0; i < vec.size(); i++) {
EXPECT_EQ(vec[i], res_vec[i]);
}
}
TEST(TestHeaderOnlyArrayRef, TestFromRange) {
std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
HeaderOnlyArrayRef<int> arr(vec.data() + 3, vec.data() + 7);
auto res_vec = arr.vec();
for (size_t i = 0; i < res_vec.size(); i++) {
EXPECT_EQ(vec[i + 3], res_vec[i]);
}
}

View File

@ -70,6 +70,13 @@ if(NOT MSVC)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12)
target_compile_options_if_supported(test_api "-Wno-error=nonnull")
endif()
# Add -Wno-error=array-bounds for GCC 13+
# See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=113239
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13)
target_compile_options_if_supported(test_api "-Wno-error=array-bounds")
endif()
endif()
if(INSTALL_TEST)

View File

@ -47,20 +47,10 @@ Tensor sgd_out_of_place(
STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");
int64_t *param_sizes;
int64_t *param_strides;
aoti_torch_get_sizes(param.get(), &param_sizes);
aoti_torch_get_strides(param.get(), &param_strides);
// testing Tensor strides + stride
STD_TORCH_CHECK(param.strides()[0] == param.stride(0));
int32_t param_dtype;
aoti_torch_get_dtype(param.get(), &param_dtype);
int32_t param_device_type;
aoti_torch_get_device_type(param.get(), &param_device_type);
AtenTensorHandle out_ath;
aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
auto out = Tensor(out_ath);
auto out = new_empty(param, param.sizes());
sgd_math(
reinterpret_cast<float*>(param.data_ptr()),
@ -311,10 +301,9 @@ void boxed_fill_infinity(
}
Tensor my_pad(Tensor t) {
std::vector<int64_t> padding = {1, 2, 2, 1};
std::string mode = "constant";
double value = 0.0;
return pad(t, padding, mode, value);
return pad(t, {1, 2, 2, 1}, mode, value);
}
void boxed_my_pad(
@ -342,6 +331,11 @@ void boxed_my_narrow(
}
Tensor my_new_empty_dtype_variant(Tensor t) {
// Still using a std::vector below even though people can just pass in an
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
// directly.
// This is to test that passing in a std::vector works for BC. (It gets
// implicitly converted to HeaderOnlyArrayRef too!)
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
return new_empty(t, sizes, dtype);
@ -353,9 +347,8 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui
}
Tensor my_new_zeros_dtype_variant(Tensor t) {
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(at::ScalarType::Float);
return new_zeros(t, sizes, dtype);
return new_zeros(t, {2, 5}, dtype);
}
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
@ -429,8 +422,7 @@ void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs)
}
Tensor my_amax_vec(Tensor t) {
std::vector<int64_t> v = {0,1};
return amax(t, v, false);
return amax(t, {0,1}, false);
}
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {

View File

@ -256,23 +256,25 @@ class TestSDPA(NNTestCase):
)
rand_upward_privateuse1 = rand_upward.to("openreg")
grad_input_mask = [True, True, True, True]
torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
rand_upward_privateuse1,
q_privateuse1,
k_privateuse1,
v_privateuse1,
attn_mask_privateuse1,
grad_input_mask,
output,
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
dropout_p=0.0,
is_causal=False,
philox_seed=philox_seed,
philox_offset=philox_offset,
_grad_q, _grad_k, _grad_v, _grad_attn_mask = (
torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
rand_upward_privateuse1,
q_privateuse1,
k_privateuse1,
v_privateuse1,
attn_mask_privateuse1,
grad_input_mask,
output,
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
dropout_p=0.0,
is_causal=False,
philox_seed=philox_seed,
philox_offset=philox_offset,
)
)

View File

@ -392,11 +392,11 @@ class ComposabilityTest(MultiProcessTestCase):
replicate_size = self.world_size // (pp_size)
device_mesh = init_device_mesh(
device_type,
mesh_shape=(replicate_size, 1, pp_size),
mesh_dim_names=("replicate", "shard", "pp"),
mesh_shape=(replicate_size, pp_size),
mesh_dim_names=("replicate", "pp"),
)
torch.manual_seed(42)
dp_mesh = device_mesh["replicate", "shard"]
dp_mesh = device_mesh["replicate"]
pp_mesh = device_mesh["pp"]
pp_group = device_mesh["pp"].get_group()
@ -416,15 +416,13 @@ class ComposabilityTest(MultiProcessTestCase):
param_dtype=MixedPrecisionParam,
reduce_dtype=torch.float32,
)
replicate_config = {"mp_policy": mp_policy}
replicate_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id in range(len(partial_model)):
replicate(
partial_model[layer_id],
device_mesh=dp_mesh,
**replicate_config,
reshard_after_forward=False,
)
dp_model = replicate(partial_model, device_mesh=dp_mesh, **replicate_config)
dp_model = replicate(partial_model, **replicate_config)
return dp_model
# Apply same precision to reference model (without replicate)
@ -582,11 +580,11 @@ class ComposabilityTest(MultiProcessTestCase):
replicate_size = self.world_size // (pp_size)
device_mesh = init_device_mesh(
device_type,
mesh_shape=(replicate_size, 1, pp_size),
mesh_dim_names=("replicate", "shard", "pp"),
mesh_shape=(replicate_size, pp_size),
mesh_dim_names=("replicate", "pp"),
)
torch.manual_seed(42)
dp_mesh = device_mesh["replicate", "shard"]
dp_mesh = device_mesh["replicate"]
pp_mesh = device_mesh["pp"]
pp_group = device_mesh["pp"].get_group()
dp_group = device_mesh["replicate"].get_group()
@ -648,10 +646,9 @@ class ComposabilityTest(MultiProcessTestCase):
for layer_id in range(len(partial_model)):
replicate(
partial_model[layer_id],
device_mesh=dp_mesh,
reshard_after_forward=False,
mesh=dp_mesh,
)
dp_model = replicate(partial_model, device_mesh=dp_mesh)
dp_model = replicate(partial_model, mesh=dp_mesh)
return dp_model
def pipelined_models_parameters(start_layer, model):

View File

@ -3,7 +3,7 @@
import copy
import dataclasses
import functools
from typing import Optional, Union
from typing import Optional
import torch
import torch.distributed as dist
@ -14,7 +14,6 @@ from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
_get_gradient_divide_factors,
)
from torch.distributed.tensor import Shard
from torch.testing._internal.common_distributed import (
requires_nccl_version,
SaveForwardInputsModel,
@ -46,35 +45,20 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
def _init_models_and_optims(
self,
reshard_after_forward: Union[bool, int],
param_dtype: Optional[torch.dtype],
reduce_dtype: Optional[torch.dtype],
use_shard_placement_fn,
):
torch.manual_seed(42)
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
largest_dim = -1
largest_dim_size = -1
for dim, dim_size in enumerate(param.shape):
if dim_size > largest_dim_size:
largest_dim = dim
largest_dim_size = dim_size
assert largest_dim >= 0, f"{param.shape}"
return Shard(largest_dim)
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype
)
shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None
replicate_fn = functools.partial(
replicate,
reshard_after_forward=reshard_after_forward,
mp_policy=mp_policy,
shard_placement_fn=shard_placement_fn,
)
for mlp in model:
replicate_fn(mlp)
@ -82,27 +66,13 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
return ref_model, ref_optim, model, optim
def _get_use_shard_placement_fn_vals_for_bf16_reduce(self):
use_shard_placement_fn_vals = [False]
if self.world_size == 2:
# For world size >2, gradient elements get reduced in different
# orders for the baseline vs. dim-1 sharding, leading to numeric
# differences for bf16 reduction, so only test world size 2.
use_shard_placement_fn_vals.append(True)
return use_shard_placement_fn_vals
@skipIfRocmVersionLessThan((7, 0))
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_compute_dtype(self):
use_shard_placement_fn_vals = (
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
)
self.run_subtests(
{
"param_dtype": [torch.bfloat16, torch.float16],
"reshard_after_forward": [False, True],
"use_shard_placement_fn": use_shard_placement_fn_vals,
},
self._test_compute_dtype,
)
@ -110,14 +80,10 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
def _test_compute_dtype(
self,
param_dtype: torch.dtype,
reshard_after_forward: Union[bool, int],
use_shard_placement_fn: bool,
):
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=None,
use_shard_placement_fn=use_shard_placement_fn,
)
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
orig_reduce_scatter = dist.reduce_scatter_tensor
@ -175,39 +141,14 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_reduce_dtype(self):
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_shard_placement_fn": [False, True],
},
self._test_reduce_dtype_fp32_reduce,
)
use_shard_placement_fn_vals = (
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
)
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_shard_placement_fn": use_shard_placement_fn_vals,
},
self._test_reduce_dtype_bf16_reduce,
)
self._test_reduce_dtype_fp32_reduce()
self._test_reduce_dtype_bf16_reduce()
def _test_reduce_dtype_fp32_reduce(
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
):
if (
self.world_size > 2
and isinstance(reshard_after_forward, int)
and use_shard_placement_fn
):
return
def _test_reduce_dtype_fp32_reduce(self):
param_dtype, reduce_dtype = torch.bfloat16, torch.float32
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
use_shard_placement_fn=use_shard_placement_fn,
)
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
orig_reduce_scatter = dist.reduce_scatter_tensor
@ -249,14 +190,12 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
check_sharded_parity(self, ref_model, model)
def _test_reduce_dtype_bf16_reduce(
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
self,
):
param_dtype, reduce_dtype = torch.float32, torch.bfloat16
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
use_shard_placement_fn=use_shard_placement_fn,
)
group = dist.distributed_c10d._get_default_group()
orig_reduce_scatter = dist.reduce_scatter_tensor
@ -321,12 +260,8 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for mlp in model:
replicate(
mlp, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
)
replicate(
model, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
)
replicate(mlp, mp_policy=mp_policy)
replicate(model, mp_policy=mp_policy)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
orig_reduce_scatter = dist.reduce_scatter_tensor

View File

@ -108,84 +108,70 @@ class TestReplicateRegisteredParams(FSDPTestMultiThread):
"""Tests the parameter registration after forward."""
device = torch.device(device_type.type, 0)
# Single Replicate group
for reshard_after_forward in (True, False, None):
torch.manual_seed(42)
model = MLP(3, device)
# Since seed is per process, not per thread, we broadcast to ensure
# the same parameters across ranks
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model, reshard_after_forward=reshard_after_forward) # root only
inp = torch.randn((2, 3), device=device_type.type)
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
if reshard_after_forward:
self._assert_dtensor_params(model.parameters())
else:
self._assert_tensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
torch.manual_seed(42)
model = MLP(3, device)
# Since seed is per process, not per thread, we broadcast to ensure
# the same parameters across ranks
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model) # root only
inp = torch.randn((2, 3), device=device_type.type)
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
self._assert_tensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
# Multiple Replicate groups
for reshard_after_forward in (True, False, None):
torch.manual_seed(42)
model = nn.Sequential(MLP(3, device), MLP(3, device))
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model[0].in_proj, reshard_after_forward=reshard_after_forward)
replicate(model[0].out_proj, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
torch.manual_seed(42)
model = nn.Sequential(MLP(3, device), MLP(3, device))
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model[0].in_proj)
replicate(model[0].out_proj)
replicate(model)
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
non_root_params = list(model[0].in_proj.parameters()) + list(
model[0].out_proj.parameters()
)
root_params = list(set(model.parameters()) - set(non_root_params))
if reshard_after_forward is None:
self._assert_dtensor_params(non_root_params)
self._assert_tensor_params(root_params)
elif reshard_after_forward:
self._assert_dtensor_params(non_root_params)
self._assert_dtensor_params(root_params)
else:
self._assert_tensor_params(non_root_params)
self._assert_tensor_params(root_params)
self._assert_same_params(model.parameters(), ref_model.parameters())
for module in model.modules():
if isinstance(module, FSDPModule):
module.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
non_root_params = list(model[0].in_proj.parameters()) + list(
model[0].out_proj.parameters()
)
root_params = list(set(model.parameters()) - set(non_root_params))
self._assert_tensor_params(non_root_params)
self._assert_tensor_params(root_params)
self._assert_same_params(model.parameters(), ref_model.parameters())
for module in model.modules():
if isinstance(module, FSDPModule):
module.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
@skip_if_lt_x_gpu(1)
def test_param_registration_after_backward(self):
"""Tests the parameter registration after backward."""
device = torch.device(device_type.type, 0)
# Single Replicate group
for reshard_after_forward in (True, False):
model = MLP(8, device)
replicate(model, reshard_after_forward=reshard_after_forward) # root only
inp = torch.randn((2, 8), device=device_type.type)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
model = MLP(8, device)
replicate(model) # root only
inp = torch.randn((2, 8), device=device_type.type)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
# Multiple Replicate groups
for reshard_after_forward in (True, False):
model = MLP(8, device)
replicate(model.in_proj, reshard_after_forward=reshard_after_forward)
replicate(model.out_proj, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
model = MLP(8, device)
replicate(model.in_proj)
replicate(model.out_proj)
replicate(model)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
def _assert_tensor_params(self, params: Iterable[nn.Parameter]):
# need to iterate over the list multiple times
@ -287,14 +273,11 @@ class TestReplicate1DTrainingCore(FSDPTest):
[(7, 15), (15, 3)],
[(16, 17), (17, 8)],
],
"use_shard_placement_fn": [False],
},
self._test_train_parity_single_group,
)
def _test_train_parity_single_group(
self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool
):
def _test_train_parity_single_group(self, lin_shapes: list[tuple[int, int]]):
torch.manual_seed(42)
model = nn.Sequential(
nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1])
@ -333,7 +316,6 @@ class TestReplicate1DTrainingCore(FSDPTest):
"""
self.run_subtests(
{
"reshard_after_forward": [True, False],
"test_device_type": [device_type.type],
"offload_policy": [OffloadPolicy()],
"delay_after_forward": [False, True],
@ -354,7 +336,6 @@ class TestReplicate1DTrainingCore(FSDPTest):
"""
self.run_subtests(
{
"reshard_after_forward": [True], # save CI time
"offload_policy": [
CPUOffloadPolicy(pin_memory=True),
CPUOffloadPolicy(pin_memory=False),
@ -371,7 +352,6 @@ class TestReplicate1DTrainingCore(FSDPTest):
def _test_train_parity_multi_group(
self,
reshard_after_forward: Union[bool, int],
offload_policy: OffloadPolicy,
test_device_type: str,
delay_after_forward: bool,
@ -405,13 +385,12 @@ class TestReplicate1DTrainingCore(FSDPTest):
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
mesh = init_device_mesh(
test_device_type,
(self.world_size, 1),
mesh_dim_names=("replicate", "shard"),
(self.world_size,),
mesh_dim_names=("replicate",),
)
fully_shard_fn = functools.partial(
replicate,
device_mesh=mesh,
reshard_after_forward=reshard_after_forward,
mesh=mesh,
offload_policy=offload_policy,
)
for module in model.modules():
@ -527,12 +506,10 @@ class TestReplicate1DTrainingCore(FSDPTest):
Tests parity when running a module that participates multiple
times in forward.
"""
self.run_subtests(
{"reshard_after_forward": [True, False]},
self._test_multi_forward_module,
)
def _test_multi_forward_module(self, reshard_after_forward: Union[bool, int]):
self._test_multi_forward_module()
def _test_multi_forward_module(self):
class MultiForwardModule(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
@ -687,7 +664,6 @@ class TestReplicateTrainingCompose(FSDPTest):
"""
self.run_subtests(
{
"reshard_after_forward": [True, False],
"checkpoint_impl": ["composable", "utils", "wrapper"],
"module_grouping": ["block", "mem_eff", "mem_eff_weight_tied"],
"test_device_type": [device_type.type],
@ -697,7 +673,6 @@ class TestReplicateTrainingCompose(FSDPTest):
def _test_train_parity_with_activation_checkpointing(
self,
reshard_after_forward: Union[bool, int],
checkpoint_impl: str,
module_grouping: str,
test_device_type: str,
@ -740,12 +715,11 @@ class TestReplicateTrainingCompose(FSDPTest):
# Apply Replicate
device_mesh = init_device_mesh(
test_device_type,
(self.world_size, 1),
mesh_dim_names=("replicate", "shard"),
(self.world_size,),
mesh_dim_names=("replicate",),
)
fsdp_kwargs = {
"reshard_after_forward": reshard_after_forward,
"device_mesh": device_mesh,
"mesh": device_mesh,
}
if module_grouping == "mem_eff":
assert model_args.n_layers == 3
@ -809,7 +783,6 @@ class TestReplicateSharedParams(FSDPTest):
def test_train_parity_with_shared_params(self):
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
},
self._test_train_shared_params,
@ -817,7 +790,6 @@ class TestReplicateSharedParams(FSDPTest):
def _test_train_shared_params(
self,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
):
torch.manual_seed(42)
@ -830,8 +802,8 @@ class TestReplicateSharedParams(FSDPTest):
if isinstance(module, TransformerBlock):
if use_activation_checkpointing:
checkpoint(module)
replicate(module, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
replicate(module)
replicate(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
torch.manual_seed(42 + self.rank + 1)
@ -868,11 +840,11 @@ class TestReplicateGradientAccumulation(FSDPTest):
with/without resharding after backward.
"""
shard_size, replicate_size = 1, self.world_size
replicate_size = self.world_size
meshes = init_device_mesh(
device_type.type,
(replicate_size, shard_size),
mesh_dim_names=("replicate", "shard"),
(replicate_size,),
mesh_dim_names=("replicate",),
)
self.run_subtests(
{
@ -928,8 +900,7 @@ class TestReplicateGradientAccumulation(FSDPTest):
ref_model = copy.deepcopy(model).to(device_type)
replicate_fn = functools.partial(
replicate,
device_mesh=mesh,
reshard_after_forward=reshard_after_forward,
mesh=mesh,
offload_policy=offload_policy,
)
for mlp in model[1:]:
@ -1040,8 +1011,8 @@ class TestReplicateGradientAccumulation(FSDPTest):
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, TransformerBlock):
replicate(module, reshard_after_forward=False)
replicate(model, reshard_after_forward=False)
replicate(module)
replicate(model)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
num_microbatches = 3
@ -1145,8 +1116,8 @@ class TestReplicateTPTraining(FSDPTest):
def init_global_mesh(self) -> DeviceMesh:
return init_device_mesh(
device_type.type,
(2, 1, 2),
mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
(2, 2),
mesh_dim_names=("dp_replicate", "tp"),
)
@skip_if_lt_x_gpu(8)
@ -1154,7 +1125,6 @@ class TestReplicateTPTraining(FSDPTest):
global_mesh = self.init_global_mesh()
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
"mlp_dim": [3, 5, 16, 17],
"foreach": [False],
@ -1165,12 +1135,11 @@ class TestReplicateTPTraining(FSDPTest):
def _test_replicate_tp(
self,
global_mesh: DeviceMesh,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
mlp_dim: int,
foreach: bool,
):
dp_mesh, tp_mesh = global_mesh["dp_replicate", "dp_shard"], global_mesh["tp"]
dp_mesh, tp_mesh = global_mesh["dp_replicate"], global_mesh["tp"]
dp_pg = dp_mesh._flatten().get_group() # used for `replicate()`
torch.manual_seed(42)
@ -1197,8 +1166,8 @@ class TestReplicateTPTraining(FSDPTest):
continue
if use_activation_checkpointing:
checkpoint(module)
replicate(module, device_mesh=dp_mesh)
replicate(model, device_mesh=dp_mesh)
replicate(module, mesh=dp_mesh)
replicate(model, mesh=dp_mesh)
# Checking parameters match orig model is critical to validate .full_tensor correctly replicates the
# strided-sharded layers.
@ -1229,11 +1198,9 @@ class TestReplicateTPTraining(FSDPTest):
for _, p in model.named_parameters():
self.assertIsInstance(p, DTensor)
self.assertEqual(p.device_mesh.ndim, 3)
self.assertEqual(len(p.placements), 3)
self.assertEqual(
p.device_mesh.mesh_dim_names, ("dp_replicate", "dp_shard", "tp")
)
self.assertEqual(p.device_mesh.ndim, 2)
self.assertEqual(len(p.placements), 2)
self.assertEqual(p.device_mesh.mesh_dim_names, ("dp_replicate", "tp"))
if __name__ == "__main__":

View File

@ -120,7 +120,7 @@ class ReplicateTest(MultiProcessTestCase):
if i % 2 == 0:
self.assertTrue("replicate" in _get_registry(layer))
for parameter in layer.parameters():
self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
self.assertEqual(parameter.placements, (Replicate(),))
elif i % 2 == 1:
self.assertTrue("fully_shard" in _get_registry(layer))
for parameter in layer.parameters():
@ -197,14 +197,14 @@ class ReplicateTest(MultiProcessTestCase):
]
global_mesh = self.init_replicate_tp_mesh()
replicate_mesh = global_mesh["replicate", "shard"]
replicate_mesh = global_mesh["replicate"]
for layer in layers:
replicate(layer, device_mesh=replicate_mesh)
replicate(layer, mesh=replicate_mesh)
for parameter in layer.parameters():
self.assertEqual(parameter.device_mesh.shape, (2, 1))
self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
self.assertEqual(parameter.device_mesh.shape, (2,))
self.assertEqual(parameter.placements, (Replicate(),))
@skip_if_lt_x_gpu(2)
def test_train_replicate_fsdp(self):
@ -263,7 +263,6 @@ class ReplicateTest(MultiProcessTestCase):
run_subtests(
self,
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
"mlp_dim": [3, 16, 17],
},
@ -273,7 +272,6 @@ class ReplicateTest(MultiProcessTestCase):
def _test_train_parity_2d_mlp(
self,
global_mesh: DeviceMesh,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
mlp_dim: int,
):
@ -287,13 +285,12 @@ class ReplicateTest(MultiProcessTestCase):
torch.manual_seed(42)
model = MLPStack(mlp_dim)
ref_model = copy.deepcopy(model).cuda()
replicate(ref_model, device_mesh=replicate_shard_mesh)
replicate(ref_model, mesh=replicate_mesh)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
model.parallelize(
tp_mesh,
replicate_shard_mesh,
use_activation_checkpointing,
reshard_after_forward=reshard_after_forward,
)
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)

View File

@ -1,16 +1,26 @@
# Owner(s): ["oncall: distributed checkpointing"]
import os
import sys
from unittest.mock import patch
import torch
import torch.testing._internal.common_utils as common
from torch import distributed as dist
from torch.distributed.checkpoint._async_process_executor import (
_ProcessBasedAsyncCheckpointExecutor,
_ProcessGroupInitInfo,
)
from torch.distributed.checkpoint.api import CheckpointException
from torch.distributed.checkpoint.storage import StorageWriter
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.common_distributed import skip_if_win32
from torch.testing._internal.common_utils import (
retry_on_connect_failures,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
@ -110,47 +120,184 @@ class TestAsyncProcessExecutor(DTensorTestBase):
"epoch": 5,
}
# 1. Simulate a failure in creating PG in background process.
with patch(
"torch.distributed.checkpoint._async_process_executor.get_free_port",
return_value=-1,
):
with self.assertRaises(ValueError) as _:
with patch.dict(os.environ, {}, clear=False):
os.environ.pop("DCP_USE_PREFIX_STORE", None)
# 1. Simulate a failure in creating PG in background process.
with patch(
"torch.distributed.checkpoint._async_process_executor.get_free_port",
return_value=-1,
):
with self.assertRaises(ValueError) as _:
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
fut = proc_executor.execute_save(
staging_future_or_state_dict=test_state_dict,
)
fut.result()
# 2. Attempt save with failing storage writer
with patch(
"torch.distributed.checkpoint._async_process_executor.get_free_port",
return_value=get_free_port(),
) as mock_get_free_port:
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
fut = proc_executor.execute_save(
staging_future_or_state_dict=test_state_dict,
storage_writer=TestStorageWriter(behavior="fail_once"),
)
fut.result()
self.assertIn(
"fail_once policy triggered failure", str(fut.exception())
)
# Verify new process was created for this attempt
if dist.get_rank() == 0:
mock_get_free_port.assert_called_once()
# 2. Attempt save with failing storage writer
with patch(
"torch.distributed.checkpoint._async_process_executor.get_free_port",
return_value=get_free_port(),
) as mock_get_free_port:
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
fut = proc_executor.execute_save(
staging_future_or_state_dict=test_state_dict,
storage_writer=TestStorageWriter(behavior="fail_once"),
)
self.assertIn("fail_once policy triggered failure", str(fut.exception()))
# Verify new process was created for this attempt
if dist.get_rank() == 0:
mock_get_free_port.assert_called_once()
# 3. Second save attempt with successful storage writer - process should still be alive
with patch(
"torch.distributed.checkpoint._async_process_executor.get_free_port",
) as mock_get_free_port:
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
fut = proc_executor.execute_save(
staging_future_or_state_dict=test_state_dict,
storage_writer=TestStorageWriter(behavior="success"),
)
result = fut.result()
# Verify process is still alive
mock_get_free_port.assert_not_called()
# Verify successful save
self.assertIsNotNone(result)
# 3. Second save attempt with successful storage writer - process should still be alive
with patch(
"torch.distributed.checkpoint._async_process_executor.get_free_port",
) as mock_get_free_port:
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
fut = proc_executor.execute_save(
staging_future_or_state_dict=test_state_dict,
storage_writer=TestStorageWriter(behavior="success"),
)
result = fut.result()
# Verify process is still alive
mock_get_free_port.assert_not_called()
# Verify successful save
self.assertIsNotNone(result)
class TestAsyncProcessExecutorPrefixStore(TestCase):
@skip_if_win32()
@retry_on_connect_failures
def test_checkpoint_save_with_prefix_store_enabled(self) -> None:
"""Test that checkpoint save works when DCP_USE_PREFIX_STORE is enabled."""
test_state_dict = {
"model": {"weight": torch.randn(4, 4), "bias": torch.randn(4)},
"optimizer": {"param_groups": [{"lr": 0.01}]},
"epoch": 5,
}
master_addr = "localhost"
master_port = str(common.find_free_port())
with patch.dict(
os.environ,
{
"DCP_USE_PREFIX_STORE": "1",
"MASTER_ADDR": master_addr,
"MASTER_PORT": master_port,
},
):
with patch(
"torch.distributed.checkpoint._async_process_executor.get_free_port"
) as mock_get_free_port:
dist.init_process_group(
backend=dist.Backend.GLOO,
rank=0,
world_size=1,
)
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
fut = proc_executor.execute_save(
staging_future_or_state_dict=test_state_dict,
storage_writer=TestStorageWriter(behavior="success"),
)
result = fut.result()
self.assertIsNotNone(result)
mock_get_free_port.assert_not_called()
class TestProcessGroupInitInfo(DTensorTestBase):
"""Test suite for _ProcessGroupInitInfo."""
@with_comms
def test_process_group_init_info_with_default_pg(self) -> None:
"""Test that ProcessGroupInitInfo correctly initializes."""
with patch.dict(os.environ, {}, clear=False):
os.environ.pop("DCP_USE_PREFIX_STORE", None)
pg_init_info = _ProcessGroupInitInfo()
self.assertEqual(pg_init_info.global_rank, dist.get_rank())
self.assertEqual(pg_init_info.world_size, dist.get_world_size())
self.assertIsNotNone(pg_init_info.tcp_store_master_addr)
self.assertGreater(pg_init_info.tcp_store_master_port, 0)
self.assertEqual(pg_init_info.use_prefix_store, False)
@with_comms
def test_process_group_init_info_with_prefix_store_env_var(self) -> None:
"""Test that ProcessGroupInitInfo handles DCP_USE_PREFIX_STORE environment variable."""
# Flag enabled, addr/port correctly defined
with patch.dict(
os.environ,
{
"DCP_USE_PREFIX_STORE": "1",
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
},
):
pg_init_info = _ProcessGroupInitInfo()
self.assertTrue(pg_init_info.use_prefix_store)
# Missing port
with patch.dict(
os.environ, {"DCP_USE_PREFIX_STORE": "1", "MASTER_ADDR": "localhost"}
):
with self.assertRaises(CheckpointException):
pg_init_info = _ProcessGroupInitInfo()
# Missing addr
with patch.dict(
os.environ, {"DCP_USE_PREFIX_STORE": "1", "MASTER_PORT": "12345"}
):
with self.assertRaises(CheckpointException):
pg_init_info = _ProcessGroupInitInfo()
# Invalid port
with patch.dict(
os.environ,
{
"DCP_USE_PREFIX_STORE": "1",
"MASTER_ADDR": "localhost",
"MASTER_PORT": "a",
},
):
with self.assertRaises(CheckpointException):
pg_init_info = _ProcessGroupInitInfo()
@with_comms
def test_process_group_init_info_without_prefix_store_env_var(self) -> None:
"""Test that ProcessGroupInitInfo defaults to not using prefix store."""
# Env var set to 0
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "0"}):
pg_init_info = _ProcessGroupInitInfo()
self.assertFalse(pg_init_info.use_prefix_store)
# Missing env var
with patch.dict(os.environ, {}, clear=False):
os.environ.pop("DCP_USE_PREFIX_STORE", None)
pg_init_info = _ProcessGroupInitInfo()
self.assertFalse(pg_init_info.use_prefix_store)
# Invalid env var
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "2"}):
pg_init_info = _ProcessGroupInitInfo()
self.assertFalse(pg_init_info.use_prefix_store)
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "true"}):
pg_init_info = _ProcessGroupInitInfo()
self.assertFalse(pg_init_info.use_prefix_store)
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "false"}):
pg_init_info = _ProcessGroupInitInfo()
self.assertFalse(pg_init_info.use_prefix_store)
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": ""}):
pg_init_info = _ProcessGroupInitInfo()
self.assertFalse(pg_init_info.use_prefix_store)
if __name__ == "__main__":

View File

@ -5,8 +5,16 @@ import contextlib
import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
Replicate,
Shard,
)
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -42,22 +50,24 @@ class TestDTensorDebugMode(TestCase):
x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)
with DebugMode(record_torchfunction=True) as debug_mode:
with DebugMode(
record_torchfunction=True, record_ids=True, record_output=True
) as debug_mode:
torch.mm(x_dtensor, y_dtensor).sum()
self.assertExpectedInline(
debug_mode.debug_string(),
"""\
torch.mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0))
aten::mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0))
torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$6: f32[8, 32]| S(0)
aten::mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0))
redistribute_input(1, S(0) -> R)
redistribute_input(t: f32[1, 32], trace: S(0)->R)
_c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
_c10d_functional::wait_tensor(t: f32[8, 32])
aten::mm(t: f32[1, 8], t: f32[8, 32])
<method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 32]| S(0))
aten::sum(dt: f32[8, 32]| S(0))
aten::sum(t: f32[1, 32])""",
redistribute_input(t$2: f32[1, 32], trace: S(0)->R)
_c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32]
_c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32]
aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32]
<method 'sum' of 'torch._C.TensorBase' objects>(dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P
aten::sum(dt$6: f32[8, 32]| S(0))
aten::sum(t$5: f32[1, 32]) -> t$7: f32[]""",
)
self.assertTrue(isinstance(debug_mode.operators[0], _OpCall))
@ -415,6 +425,40 @@ class TestDTensorDebugMode(TestCase):
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])""",
)
with DebugMode(record_stack_trace=True) as debug_mode:
out = mod(inp).sum()
out.backward()
sum_op = [
op for op in debug_mode.operators if str(op.op) == "aten.sum.dim_IntList"
][-1]
self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
def test_pretty_print_dtensor_make_fx(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
A = torch.randn(8, 32)
B = torch.randn(32, 32)
dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_()
def f(dA, dB):
dy = dA @ dB
loss = dy.sum()
loss.backward()
return dA.grad, dB.grad
# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode="fake")(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
# Colored is nice for actual viewing, not using in this test though
gm_str = gm.print_readable(colored=False, print_output=False)
self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str)
instantiate_parametrized_tests(TestDTensorDebugMode)

View File

@ -3,7 +3,8 @@
import itertools
import random
import unittest
from typing import Any, Callable, ClassVar, Optional
from collections.abc import Callable
from typing import Any, ClassVar, Optional
import torch
import torch.distributed as dist

View File

@ -1019,6 +1019,28 @@ class DTensorMeshTest(DTensorTestBase):
except ValueError:
self.fail("Unexpected ValueError raised with run_check=False")
@with_comms
def test_as_strided_identity(self):
# Test calling as_strided with the same size/stride/offset as input tensor
# This should be a no-op but currently fails
device_mesh = self.build_device_mesh()
placements = [Shard(0)]
local_tensor = torch.randn(3, 4, device=self.device_type)
dtensor = DTensor.from_local(local_tensor, device_mesh, placements)
# Get the current size, stride, and storage_offset
size = dtensor.size()
stride = dtensor.stride()
storage_offset = dtensor.storage_offset()
# Call as_strided with the exact same parameters
result = dtensor.as_strided(size, stride, storage_offset)
# The result should be identical to the input
self.assertEqual(result.size(), dtensor.size())
self.assertEqual(result.stride(), dtensor.stride())
self.assertEqual(result.to_local(), dtensor.to_local())
DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
DTensorMeshTest,

View File

@ -2,6 +2,7 @@
import copy
import json
import logging
import os
import pickle
import random
@ -21,6 +22,7 @@ from unittest import mock, SkipTest
import torch
import torch.distributed as c10d
import torch.distributed._functional_collectives as _functional_collectives
from torch.distributed.distributed_c10d import SHRINK_ABORT as NCCL_SHRINK_ABORT
if not c10d.is_available() or not c10d.is_nccl_available():
@ -47,12 +49,15 @@ from torch._C._distributed_c10d import ErrorType, OpType, WorkResult
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
get_required_world_size,
get_timeout,
init_multigpu_helper,
MultiProcessTestCase,
requires_multicast_support,
requires_nccl,
requires_nccl_shrink,
requires_nccl_version,
requires_world_size,
skip_if_lt_x_gpu,
skip_if_rocm_multiprocess,
sm_is_or_higher_than,
@ -88,6 +93,53 @@ BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
)
_start_time = time.time()
_logger = logging.getLogger(__name__)
def _ts():
return time.time() - _start_time
def configure(level=logging.INFO, force=False):
try:
logging.basicConfig(
level=level,
format="%(asctime)s %(name)s %(levelname)s: %(message)s",
force=force,
)
except TypeError:
logging.basicConfig(
level=level, format="%(asctime)s %(name)s %(levelname)s: %(message)s"
)
def log_test_info(rank, message):
_logger.info("[%7.3fs][Rank %s] %s", _ts(), rank, message)
def log_test_success(rank, message):
_logger.info("[%7.3fs][Rank %s] ✅ %s", _ts(), rank, message)
def log_test_validation(rank, message):
_logger.info("[%7.3fs][Rank %s] ✓ %s", _ts(), rank, message)
def log_test_warning(rank, message):
_logger.warning("[%7.3fs][Rank %s] ⚠️ %s", _ts(), rank, message)
def log_test_error(rank, message):
_logger.error("[%7.3fs][Rank %s] ✗ %s", _ts(), rank, message)
_log_configure = configure
_log_configure(level=logging.INFO, force=True)
class RendezvousEnvTest(TestCase):
@retry_on_connect_failures
@requires_nccl()
@ -317,7 +369,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
@property
def world_size(self):
return 2
return get_required_world_size(self, 2)
@property
def rank_to_GPU(self):
@ -1255,6 +1307,628 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
pg_2 = c10d.new_group([0, 1])
self.assertEqual(pg_2.group_desc, "undefined")
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_basic(self):
"""Test basic shrink_group functionality."""
self._perform_shrink_test([1], "Basic shrink test")
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_validation(self):
"""Test input validation in shrink_group."""
device, pg = self._setup_shrink_test("validation")
def _test_invalid_input(ranks, description, expected_exception):
"""Helper to test invalid inputs."""
try:
c10d.shrink_group(ranks)
self.fail(f"Expected {expected_exception.__name__} for {description}")
except expected_exception:
log_test_validation(self.rank, f"{description}")
except Exception:
if expected_exception is Exception: # Accept any exception
log_test_validation(self.rank, f"{description}")
else:
raise
# Test cases
_test_invalid_input([], "Empty exclusion list", ValueError)
if self.world_size > 1:
_test_invalid_input([0, 0, 1], "Duplicate ranks", Exception)
_test_invalid_input([self.world_size + 1], "Out of bounds rank", Exception)
log_test_success(self.rank, "All validation tests passed")
dist.destroy_process_group()
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_backend_properties(self):
"""Test that backend properties are preserved after shrinking."""
test_name = "Backend Properties Test"
ranks_to_exclude = [0]
# Reuse _setup_shrink_test for complete setup (device, environment, and process group)
device, pg = self._setup_shrink_test("backend_properties")
# Follow _perform_shrink_test pattern from here
log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")
is_excluded = self.rank in ranks_to_exclude
log_test_info(
self.rank,
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
)
# Store original backend property values (not references) before shrinking
original_timeout = None
original_high_priority = None
if not is_excluded:
original_backend = pg._get_backend(device)
original_timeout = original_backend.options._timeout
original_high_priority = original_backend.options.is_high_priority_stream
log_test_info(
self.rank,
f"Storing original backend properties: timeout={original_timeout}, high_priority={original_high_priority}",
)
if is_excluded:
log_test_info(
self.rank,
f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
)
dist.destroy_process_group() # hang without it
return
# Only non-excluded ranks proceed with shrink (same as _perform_shrink_test)
log_test_info(self.rank, "Non-excluded rank calling shrink_group")
shrunk_pg = c10d.shrink_group(ranks_to_exclude)
# Reuse _validate_shrunk_group helper (same as _perform_shrink_test)
expected_size = self.world_size - len(ranks_to_exclude)
_ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)
# Add custom backend properties validation
new_backend = shrunk_pg._get_backend(device)
log_test_info(self.rank, "Validating backend properties are preserved")
new_timeout = new_backend.options._timeout
new_high_priority = new_backend.options.is_high_priority_stream
log_test_info(
self.rank,
f"Timeout comparison - original: {original_timeout}, new: {new_timeout}",
)
self.assertEqual(
original_timeout, new_timeout, f"{test_name}: timeout not preserved"
)
log_test_info(
self.rank,
f"High priority stream comparison - original: {original_high_priority}, new: {new_high_priority}",
)
self.assertEqual(
original_high_priority,
new_high_priority,
f"{test_name}: high_priority_stream not preserved",
)
log_test_validation(
self.rank, f"{test_name}: Backend properties preserved successfully"
)
log_test_success(
self.rank, f"{test_name} successful (shrink + backend validation)"
)
# Cleanup (same as _perform_shrink_test)
dist.destroy_process_group()
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_multiple_comms(self):
"""Test shrink_group with multiple communicators and subgroup invalidation."""
device, pg = self._setup_shrink_test("multiple_comms")
# Create subgroup [0, 1] and test shrinking it
subgroup = c10d.new_group([0, 1])
if self.rank <= 1:
# Shrink subgroup: exclude rank 1
if self.rank == 0: # Only rank 0 remains
shrunk_subgroup = c10d.shrink_group([1], group=subgroup)
self.assertEqual(shrunk_subgroup.size(), 1)
# Test communication on shrunk subgroup
tensor = torch.full((1,), self.rank).cuda(device)
c10d.all_reduce(tensor, group=shrunk_subgroup)
self.assertEqual(tensor.item(), 0) # Only rank 0
log_test_success(self.rank, "Subgroup shrinking successful")
dist.barrier() # Sync before default group test
# Shrink default group: exclude last rank
ranks_to_exclude = [self.world_size - 1]
if self.rank not in ranks_to_exclude:
shrunk_default = c10d.shrink_group(ranks_to_exclude)
expected_size = self.world_size - 1
self.assertEqual(shrunk_default.size(), expected_size)
# Test collective on shrunk default group
tensor = torch.full((1,), self.rank).cuda(device)
c10d.all_reduce(tensor, group=shrunk_default)
expected_sum = sum(
range(self.world_size - 1)
) # 0 + 1 + ... + (world_size-2)
self.assertEqual(tensor.item(), expected_sum)
log_test_success(self.rank, "Default group shrinking successful")
# Note: After shrinking default group, the old subgroup is invalid
# due to global rank reassignment
dist.destroy_process_group()
def _test_shrink_group_with_flag(self, shrink_flag, flag_name, rank_to_exclude):
"""Helper method to test shrink_group with a specific flag."""
if self.world_size < 2:
log_test_info(self.rank, f"Skipping (needs ≥2 GPUs, got {self.world_size})")
return
ranks_to_exclude = [rank_to_exclude]
log_test_info(self.rank, f"Using {flag_name} flag (value: {shrink_flag})")
if flag_name == "NCCL_SHRINK_ABORT":
log_test_info(
self.rank,
"ABORT flag will terminate ongoing operations before shrinking",
)
self._perform_shrink_test(
ranks_to_exclude, f"{flag_name} flag test", shrink_flags=shrink_flag
)
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_flags(self):
"""Test shrink_group with different shrink flags."""
# Test ABORT flags
log_test_info(self.rank, "Testing NCCL_SHRINK_ABORT flag")
self._test_shrink_group_with_flag(NCCL_SHRINK_ABORT, "NCCL_SHRINK_ABORT", 1)
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_nccl_config(self):
"""Verify that passing NCCL config via pg_options influences the shrunk group's backend options."""
device, pg = self._setup_shrink_test("config")
if self.rank == self.world_size - 1:
# excluded rank should not call shrink_group
dist.destroy_process_group()
return
# Prepare pg_options with NCCL config overrides
# Capture parent's current backend options to ensure we can prove override vs inherit
parent_backend = pg._get_backend(torch.device("cuda"))
parent_hp = parent_backend.options.is_high_priority_stream
parent_blocking = parent_backend.options.config.blocking
# Choose overrides that differ from the parent (flip where possible)
override_hp = not parent_hp
if parent_blocking in (0, 1):
override_blocking = 1 - parent_blocking
else:
# If undefined or unexpected, set to 1 which is a concrete value
override_blocking = 1
opts = c10d.ProcessGroupNCCL.Options()
opts.is_high_priority_stream = override_hp
opts.config.blocking = override_blocking
shrunk_pg = c10d.shrink_group([self.world_size - 1], pg_options=opts)
# Validate backend options propagated
backend = shrunk_pg._get_backend(torch.device("cuda"))
# is_high_priority_stream should exactly match our override and differ from parent
self.assertEqual(backend.options.is_high_priority_stream, override_hp)
self.assertNotEqual(backend.options.is_high_priority_stream, parent_hp)
# config is a struct; check representative field and difference from parent when meaningful
self.assertEqual(backend.options.config.blocking, override_blocking)
if parent_blocking in (0, 1):
self.assertNotEqual(backend.options.config.blocking, parent_blocking)
dist.destroy_process_group()
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_performance(self):
"""Test shrink_group performance and regression detection."""
import time
ranks_to_exclude = self._get_default_ranks_to_exclude()
is_excluded = self.rank in ranks_to_exclude
if not ranks_to_exclude:
log_test_info(self.rank, "Skipping performance test (world_size=1)")
return
log_test_info(self.rank, f"Performance test with {self.world_size} processes")
device, pg = self._setup_shrink_test("performance")
if not is_excluded:
log_test_info(self.rank, "Measuring shrink_group performance")
start_time = time.time()
shrunk_pg = c10d.shrink_group(ranks_to_exclude)
end_time = time.time()
elapsed_time = end_time - start_time
log_test_info(self.rank, f"shrink_group: {elapsed_time:.3f}s")
# Regression check: should complete within reasonable time
self.assertLess(
elapsed_time,
30.0,
f"shrink_group took {elapsed_time:.3f}s, possible regression",
)
# Test collective performance
expected_size = self.world_size - len(ranks_to_exclude)
self._validate_shrunk_group(shrunk_pg, expected_size, "performance")
collective_start = time.time()
_ = self._test_collective_on_shrunk_group(
shrunk_pg, device, ranks_to_exclude, "performance"
)
collective_time = time.time() - collective_start
log_test_info(self.rank, f"all_reduce: {collective_time:.3f}s")
log_test_success(self.rank, "Performance test passed")
else:
log_test_info(self.rank, "Excluded rank - waiting")
dist.destroy_process_group()
@requires_nccl_shrink()
@requires_world_size(4)
def test_shrink_group_multiple_exclusions(self):
"""Test shrink_group with multiple ranks excluded at once."""
# Scale exclusions with world size
ranks_to_exclude = list(range(2, self.world_size, 2)) # Every other rank from 2
self._perform_shrink_test(ranks_to_exclude, "Multiple exclusions test")
@requires_nccl_shrink()
@requires_world_size(3)
def test_shrink_group_multiple_iterations(self):
"""Test multiple shrink operations in sequence."""
log_test_info(
self.rank,
f"Starting test_shrink_group_multiple_iterations with world_size={self.world_size}",
)
store = c10d.FileStore(self.file_name, self.world_size)
device = torch.device(f"cuda:{self.rank}")
_ = self._create_process_group_nccl(store, self.opts(), device_id=device)
# Track current effective world size throughout shrinking operations
current_world_size = self.world_size
log_test_info(self.rank, f"Initial world_size: {current_world_size}")
# First shrinking: exclude the last rank(s)
first_exclusion = [self.world_size - 1]
if self.world_size >= 6:
first_exclusion.append(
self.world_size - 2
) # Exclude last two ranks for larger sizes
log_test_info(self.rank, f"First shrinking: excluding ranks {first_exclusion}")
if self.rank not in first_exclusion:
# Only non-excluded ranks should call shrink_group
first_pg = c10d.shrink_group(first_exclusion)
self.assertIsNotNone(first_pg)
# IMPORTANT: Update world size after first shrinking
current_world_size = first_pg.size()
expected_first_size = self.world_size - len(first_exclusion)
log_test_info(
self.rank,
f"After first shrinking: world_size {self.world_size} -> {current_world_size}",
)
self.assertEqual(first_pg.size(), expected_first_size)
# Second shrinking: exclude another rank from the remaining group
# Choose a rank that's in the middle range
if current_world_size >= 3:
second_exclusion = [
current_world_size - 1
] # Exclude the new "last" rank
log_test_info(
self.rank,
f"Second shrinking from group of size {current_world_size}: excluding ranks {second_exclusion}",
)
if self.rank not in second_exclusion:
# Only non-excluded ranks should call shrink_group for second iteration
second_pg = c10d.shrink_group(second_exclusion, group=first_pg)
self.assertIsNotNone(second_pg)
# IMPORTANT: Update world size after second shrinking
final_world_size = second_pg.size()
expected_final_size = current_world_size - len(second_exclusion)
log_test_info(
self.rank,
f"After second shrinking: world_size {current_world_size} -> {final_world_size}",
)
self.assertEqual(second_pg.size(), expected_final_size)
# Test collective on final group
tensor = torch.full((1,), self.rank).cuda(device)
log_test_info(
self.rank,
f"Performing all_reduce on final group (size {final_world_size}) with tensor: {tensor.item()}",
)
c10d.all_reduce(tensor, group=second_pg)
log_test_info(
self.rank,
f"Final all_reduce completed, result: {tensor.item()}",
)
# Calculate expected sum of remaining ranks
all_excluded = set(first_exclusion + second_exclusion)
remaining_ranks = [
r for r in range(self.world_size) if r not in all_excluded
]
expected_sum = sum(remaining_ranks)
log_test_info(
self.rank,
f"Remaining ranks: {remaining_ranks}, expected sum: {expected_sum}, actual: {tensor.item()}",
)
self.assertEqual(tensor.item(), expected_sum)
log_test_info(self.rank, "Final verification passed")
else:
log_test_info(
self.rank,
"This rank excluded in second shrinking, not calling shrink_group",
)
else:
log_test_info(
self.rank, "Skipping second shrinking (remaining group too small)"
)
else:
log_test_info(
self.rank,
"This rank excluded in first shrinking, not calling shrink_group",
)
log_test_info(self.rank, "Destroying process group")
dist.destroy_process_group()
log_test_info(self.rank, "test_shrink_group_multiple_iterations completed")
# Helper methods for optimized shrink group tests
def _setup_shrink_test(self, test_suffix, world_size=None, warmup=True):
"""Common setup for shrink group tests."""
os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
world_size = world_size or self.world_size
store = c10d.FileStore(self.file_name + f"_{test_suffix}", world_size)
device = torch.device(f"cuda:{self.rank}")
c10d.init_process_group(
"nccl",
world_size=world_size,
rank=self.rank,
store=store,
pg_options=self.opts(),
device_id=device,
)
pg = c10d.distributed_c10d._get_default_group()
if warmup:
c10d.all_reduce(torch.ones(1).cuda(device), group=pg)
return device, pg
def _validate_shrunk_group(self, shrunk_pg, expected_size, test_name=""):
"""Validate properties of a shrunk process group."""
self.assertIsNotNone(shrunk_pg, f"{test_name}: shrunk_pg should not be None")
actual_size = shrunk_pg.size()
self.assertEqual(
actual_size, expected_size, f"{test_name}: group size mismatch"
)
new_rank = shrunk_pg.rank()
self.assertTrue(
0 <= new_rank < expected_size, f"{test_name}: invalid new rank {new_rank}"
)
log_test_info(
self.rank,
f"{test_name}: world_size {self.world_size} -> {actual_size}, rank {self.rank} -> {new_rank}",
)
return new_rank
def _test_collective_on_shrunk_group(
self, shrunk_pg, device, ranks_to_exclude, test_name=""
):
"""Test collective communication on shrunk group and verify correctness."""
test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
c10d.all_reduce(test_tensor, group=shrunk_pg)
result = test_tensor.item()
expected_sum = sum(
r for r in range(self.world_size) if r not in ranks_to_exclude
)
self.assertEqual(
result, expected_sum, f"{test_name}: collective result mismatch"
)
log_test_info(
self.rank, f"{test_name}: collective passed ({result} == {expected_sum})"
)
return result
def _perform_shrink_test(
self, ranks_to_exclude, test_name, shrink_flags=0, with_collective=True
):
"""Complete shrink test flow: setup, shrink, validate, test collective, cleanup.
Consistent API: All ranks perform setup to initialize distributed environment.
ONLY non-excluded ranks call shrink_group() for both default and non-default groups.
Excluded ranks perform setup, then exit without calling shrink_group() or waiting.
"""
log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")
is_excluded = self.rank in ranks_to_exclude
log_test_info(
self.rank,
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
)
# All ranks (including excluded ones) perform setup to initialize distributed environment
device, pg = self._setup_shrink_test(test_name.lower().replace(" ", "_"))
is_default_group = pg == c10d.distributed_c10d._get_default_group()
if is_excluded:
log_test_info(
self.rank,
f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
)
if shrink_flags & NCCL_SHRINK_ABORT:
log_test_info(self.rank, f"Using abort for excluded rank {self.rank}")
pg._get_backend(torch.device(device)).abort()
log_test_info(
self.rank, f"cleanup resources for excluded rank {self.rank}"
)
dist.destroy_process_group()
log_test_info(self.rank, f"Excluded rank {self.rank} - exit")
else:
log_test_info(
self.rank, f"Using regular destroy for excluded rank {self.rank}"
)
dist.destroy_process_group()
return None
# Only non-excluded ranks proceed with shrink
log_test_info(
self.rank,
f"Non-excluded rank calling shrink_group (default_group={is_default_group})",
)
shrunk_pg = c10d.shrink_group(ranks_to_exclude, shrink_flags=shrink_flags)
log_test_info(
self.rank,
f"Non-excluded rank calling shrink_group (default_group={is_default_group}) done",
)
# Non-excluded ranks: validate and test the new group
expected_size = self.world_size - len(ranks_to_exclude)
_ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)
if with_collective:
_ = self._test_collective_on_shrunk_group(
shrunk_pg, device, ranks_to_exclude, test_name
)
log_test_success(self.rank, f"{test_name} successful (shrink + collective)")
else:
log_test_success(self.rank, f"{test_name} successful (shrink only)")
dist.destroy_process_group()
return shrunk_pg
def _get_default_ranks_to_exclude(self):
"""Get default ranks to exclude based on world size."""
if self.world_size <= 1:
return []
return [self.world_size - 1] # Exclude last rank by default
@requires_nccl_shrink()
@requires_world_size(3)
def test_shrink_group_vs_abort_reinit_performance(self):
"""Compare performance of shrink_group vs traditional abort+reinit (simplified for reliability)."""
log_test_info(self.rank, "=== TEST 1: abort+reinit ===")
device, pg1 = self._setup_shrink_test("_perf_reinit")
torch.cuda.synchronize(device)
# Test 1: Traditional abort + reinit
start_time = time.perf_counter()
dist.destroy_process_group()
device, new_pg = self._setup_shrink_test("perf_shrink_test1")
reinit_time = time.perf_counter() - start_time
# Test collective with original rank values for fair comparison (non-blocking mode)
test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
work = c10d.all_reduce(test_tensor, group=new_pg, async_op=True)
work.wait()
torch.cuda.synchronize(device)
# Verify correctness
expected_sum = sum(r for r in range(self.world_size))
self.assertEqual(test_tensor.item(), expected_sum, "Reinit collective failed")
log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
dist.destroy_process_group(new_pg)
# Test 2: shrink_group with NCCL_SHRINK_ABORT
log_test_info(self.rank, "=== TEST 2: shrink_group ===")
ranks_to_exclude = [self.world_size - 1]
is_excluded = self.rank in ranks_to_exclude
log_test_info(
self.rank,
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
)
device, pg1 = self._setup_shrink_test("perf_shrink_test2") # Unique suffix
shrink_time = 0
if not is_excluded:
torch.cuda.synchronize(device) # Ensure accurate timing
start_time = time.perf_counter()
shrunk_pg = c10d.shrink_group(
ranks_to_exclude, shrink_flags=NCCL_SHRINK_ABORT
)
c10d.all_reduce(torch.ones(1).cuda(device), group=shrunk_pg)
shrink_time = time.perf_counter() - start_time
# Test collective communication on shrunk group (non-blocking mode)
test_tensor = torch.full(
(1,), self.rank, device=device, dtype=torch.float32
)
work = c10d.all_reduce(test_tensor, group=shrunk_pg, async_op=True)
work.wait()
# Verify correctness
expected_sum = sum(
r for r in range(self.world_size) if r not in ranks_to_exclude
)
self.assertEqual(
test_tensor.item(),
expected_sum,
"shrink_test: collective result mismatch",
)
torch.cuda.synchronize(device) # Ensure operations complete
log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
dist.destroy_process_group()
else:
log_test_info(self.rank, "Excluded from shrink test - exiting immediately")
dist.destroy_process_group()
return
# Performance analysis (only for participating ranks)
if shrink_time > 0 and reinit_time > 0:
speedup = reinit_time / shrink_time
time_saved = reinit_time - shrink_time
log_test_info(self.rank, "=== PERFORMANCE RESULTS ===")
log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
log_test_info(self.rank, f"time_saved: {time_saved:+.4f}s")
log_test_info(self.rank, f"speedup: {speedup:.2f}x")
if speedup > 1.1:
log_test_success(self.rank, "shrink_group significantly faster")
elif speedup > 0.9:
log_test_info(self.rank, "≈ comparable performance")
else:
log_test_warning(self.rank, "abort+reinit faster")
log_test_info(self.rank, "Performance test completed")
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_deterministic_mode_no_break(self):
@ -5115,6 +5789,229 @@ class NCCLTraceTest(NCCLTraceTestBase):
else:
self.assertTrue("duration_ms" not in t["entries"][0])
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_circular_buffer_full(self, timing_enabled):
"""
Test that when the circular buffer in entries_ is full and we call reset,
then fill the buffer with new entries, dump_entries returns only the new
entries and not the old ones.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill the buffer completely with 10 entries
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify buffer is full with 10 entries
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 10)
# Now reset the flight recorder
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Add new entries after reset - fill the buffer completely again
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify we get exactly 10 new entries, not 20
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 10)
# Verify all entries have the expected properties (from after reset)
# After reset, record IDs should start from 0 again
for i, entry in enumerate(t["entries"]):
self.assertIn("profiling_name", entry)
self.assertEqual(entry["profiling_name"], "nccl:all_reduce")
self.assertIn("record_id", entry)
# Record IDs should be sequential starting from 0 after reset
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_partial_overwrite(self, timing_enabled):
"""
Test that when the circular buffer is full, we reset, and then add fewer
entries than the buffer size, we only get the new entries.
This tests that old entries at the end of the circular buffer are properly
filtered out based on reset_epoch.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill the buffer completely
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Reset the flight recorder
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Add only 3 new entries (much less than buffer size)
for _ in range(3):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify we only get the 3 new entries, not 10
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 3)
# Verify record IDs start from 0 after reset
for i, entry in enumerate(t["entries"]):
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_wraparound(self, timing_enabled):
"""
Test that when we reset in the middle of the circular buffer and then
wrap around, dump_entries correctly returns only entries from the current
epoch in the correct order.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill half the buffer
for _ in range(5):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Reset at this point (reset happens at index 5)
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Now add 8 entries, which will wrap around
# (5->9 fills rest of buffer, then 0->2 wraps around)
for _ in range(8):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Should get exactly 8 entries, properly ordered
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 8)
# Entries should be in chronological order
# The dump_entries() method returns entries from next_ to end, then 0 to next_
# After filtering old entries, we should have 8 entries in order
# Verify record IDs start from 0 after reset (id_ is reset in reset_all())
for i, entry in enumerate(t["entries"]):
self.assertIn("profiling_name", entry)
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_multiple_resets(self, timing_enabled):
"""
Test multiple consecutive resets to ensure each reset properly increments
the epoch and filters out entries from previous epochs.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# First batch: 2 entries
for _ in range(2):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# First reset
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Second batch: 3 entries
for _ in range(3):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Second reset
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Third batch: 4 entries
for _ in range(4):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Should only see the last 4 entries
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 4)
# Verify record IDs start from 0 after the last reset
for i, entry in enumerate(t["entries"]):
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
def check_if_test_is_skipped(fn):
def wrapper(self, *args, **kwargs):
@ -5446,6 +6343,14 @@ class ProcessGroupNCCLLargerScaleTest(MultiProcessTestCase):
if self.rank == 6 or self.rank == 7:
dist.broadcast(tensor2, 6, group=ng2)
self.assertEqual(tensor2, torch.full((1,), 6))
# Test the case when the split changes the pg option of split group
# while the parent pg option is not changed.
new_pg = c10d.new_group([0, 1, 2, 3, 4, 5, 6, 7], device_id=device)
backend_new_pg = new_pg._get_backend(torch.device(device))
self.assertEqual(len(backend_new_pg.options.global_ranks_in_group), 8)
c10d.split_group(new_pg, [[0, 2, 4, 6], [1, 3, 5, 7]])
self.assertEqual(len(backend_new_pg.options.global_ranks_in_group), 8)
# a barrier and a cuda sync before destroying all pgs.
dist.barrier(pg)
torch.cuda.synchronize()

View File

@ -7,6 +7,8 @@ import torch
import torch.distributed as dist
from torch.distributed._local_tensor import (
local_tensor_mode,
LocalIntNode,
LocalRunnerMode,
LocalTensor,
LocalTensorMode,
)
@ -17,8 +19,10 @@ from torch.distributed.tensor import (
Partial,
Replicate,
Shard,
zeros,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import reduce_local_int
class LocalTensorTestBase(TestCase):
@ -411,5 +415,78 @@ class TestLocalTensorWorld8(LocalTensorTestBase):
self.assertEqual(full_tensor, local_res)
from torch.distributed._local_tensor._c10d import local_p2p_op, wait_all
class TestLocalRunner(LocalTensorTestBase):
world_size = 6
@staticmethod
def _get_pp_peer(pp_index, mesh, dim, dir):
pp_meshes = mesh._get_all_submeshes(dim)
pp_ret = {}
for pp_mesh in pp_meshes:
global_rank = pp_mesh.mesh[pp_index].item()
global_peer = pp_mesh.mesh[(pp_index + dir) % pp_mesh.size()].item()
pp_ret[global_rank] = global_peer
return torch.SymInt(LocalIntNode(pp_ret))
def _run_dp_pp(
self,
mesh: DeviceMesh,
pp_index: int,
actual: list[torch.Tensor | None],
expected: list[torch.Tensor | None],
) -> None:
ltm = LocalTensorMode(mesh.size())
with ltm:
dp_mesh = mesh["dp"]
pp_mesh = mesh["pp"]
x = torch.rand(2, 4)
xd = distribute_tensor(x, dp_mesh, [Shard(0)])
xd = xd * 2
x = x * 2
yd = zeros(*xd.shape, device_mesh=dp_mesh, placements=[Shard(0)])
if pp_index != pp_mesh.size(0) - 1:
# Send to next pp rank
pp_next_rank = TestLocalRunner._get_pp_peer(pp_index, mesh, "pp", +1)
local_p2p_op(pp_next_rank, xd, dist.isend)
expected[pp_index + 1] = ltm.tensor_map(
x,
lambda r, t: t
if reduce_local_int(pp_next_rank, lambda vals: r in vals.values())
else torch.zeros_like(t),
)
if pp_index != 0:
# Receive from prev pp rank
pp_prev_rank = TestLocalRunner._get_pp_peer(pp_index, mesh, "pp", -1)
rw = local_p2p_op(pp_prev_rank, yd, dist.irecv)
wait_all(rw)
y = yd.full_tensor()
actual[pp_index] = y
def test_dp_pp(self):
pp_size = 3
mesh = init_device_mesh(
"cpu", (self.world_size // pp_size, pp_size), mesh_dim_names=("dp", "pp")
)
actual: list[torch.Tensor | None] = [None] * pp_size
expected: list[torch.Tensor | None] = [None] * pp_size
with LocalRunnerMode(
self.world_size,
pp_size,
lambda pp_index: self._run_dp_pp(mesh, pp_index, actual, expected),
):
pass
self.assertEqual(actual, expected)
if __name__ == "__main__":
run_tests()

View File

@ -1,5 +1,5 @@
diff --git a/test/dynamo/cpython/3_13/test_heapq.py b/test/dynamo/cpython/3_13/test_heapq.py
index 1aa8e4e2897..94315fa68b4 100644
index 1aa8e4e2897..bc177c2943e 100644
--- a/test/dynamo/cpython/3_13/test_heapq.py
+++ b/test/dynamo/cpython/3_13/test_heapq.py
@@ -1,3 +1,23 @@
@ -35,7 +35,7 @@ index 1aa8e4e2897..94315fa68b4 100644
def test_py_functions(self):
for fname in func_names:
self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq')
@@ -27,24 +47,7 @@ class TestModules(TestCase):
@@ -27,24 +47,12 @@ class TestModules(TestCase):
self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
@ -46,12 +46,15 @@ index 1aa8e4e2897..94315fa68b4 100644
- # However, doctest can't easily find all docstrings in the module (loading
- # it through import_fresh_module seems to confuse it), so we specifically
- # create a finder which returns the doctests from the merge method.
-
+@torch._dynamo.disable
+def randrange(*args):
+ return random.randrange(*args)
- class HeapqMergeDocTestFinder:
- def find(self, *args, **kwargs):
- dtf = doctest.DocTestFinder()
- return dtf.find(py_heapq.merge)
-
- tests.addTests(doctest.DocTestSuite(py_heapq,
- test_finder=HeapqMergeDocTestFinder()))
- return tests
@ -61,7 +64,155 @@ index 1aa8e4e2897..94315fa68b4 100644
def test_push_pop(self):
# 1) Push 256 random numbers and pop them off, verifying all's OK.
@@ -264,12 +267,12 @@ class TestHeap:
@@ -52,7 +60,8 @@ class TestHeap:
data = []
self.check_invariant(heap)
for i in range(256):
- item = random.random()
+ with torch._dynamo.error_on_graph_break(False):
+ item = random.random()
data.append(item)
self.module.heappush(heap, item)
self.check_invariant(heap)
@@ -83,14 +92,16 @@ class TestHeap:
def test_heapify(self):
for size in list(range(30)) + [20000]:
- heap = [random.random() for dummy in range(size)]
+ with torch._dynamo.error_on_graph_break(False):
+ heap = [random.random() for dummy in range(size)]
self.module.heapify(heap)
self.check_invariant(heap)
self.assertRaises(TypeError, self.module.heapify, None)
def test_naive_nbest(self):
- data = [random.randrange(2000) for i in range(1000)]
+ with torch._dynamo.error_on_graph_break(False):
+ data = [randrange(2000) for i in range(1000)]
heap = []
for item in data:
self.module.heappush(heap, item)
@@ -113,7 +124,8 @@ class TestHeap:
# heap instead of a min heap, it could go faster still via
# heapify'ing all of data (linear time), then doing 10 heappops
# (10 log-time steps).
- data = [random.randrange(2000) for i in range(1000)]
+ with torch._dynamo.error_on_graph_break(False):
+ data = [randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
for item in data[10:]:
@@ -126,7 +138,8 @@ class TestHeap:
self.assertRaises(IndexError, self.module.heapreplace, [], None)
def test_nbest_with_pushpop(self):
- data = [random.randrange(2000) for i in range(1000)]
+ with torch._dynamo.error_on_graph_break(False):
+ data = [randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
for item in data[10:]:
@@ -163,8 +176,9 @@ class TestHeap:
def test_heapsort(self):
# Exercise everything with repeated heapsort checks
for trial in range(100):
- size = random.randrange(50)
- data = [random.randrange(25) for i in range(size)]
+ with torch._dynamo.error_on_graph_break(False):
+ size = randrange(50)
+ data = [randrange(25) for i in range(size)]
if trial & 1: # Half of the time, use heapify
heap = data[:]
self.module.heapify(heap)
@@ -177,12 +191,13 @@ class TestHeap:
def test_merge(self):
inputs = []
- for i in range(random.randrange(25)):
- row = []
- for j in range(random.randrange(100)):
- tup = random.choice('ABC'), random.randrange(-500, 500)
- row.append(tup)
- inputs.append(row)
+ with torch._dynamo.error_on_graph_break(False):
+ for i in range(randrange(25)):
+ row = []
+ for j in range(randrange(100)):
+ tup = random.choice('ABC'), randrange(-500, 500)
+ row.append(tup)
+ inputs.append(row)
for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
for reverse in [False, True]:
@@ -209,12 +224,14 @@ class TestHeap:
list(self.module.merge(iterable(), iterable()))
def test_merge_stability(self):
- class Int(int):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class Int(int):
+ pass
inputs = [[], [], [], []]
for i in range(20000):
- stream = random.randrange(4)
- x = random.randrange(500)
+ with torch._dynamo.error_on_graph_break(False):
+ stream = randrange(4)
+ x = randrange(500)
obj = Int(x)
obj.pair = (x, stream)
inputs[stream].append(obj)
@@ -224,7 +241,8 @@ class TestHeap:
self.assertEqual(result, sorted(result))
def test_nsmallest(self):
- data = [(random.randrange(2000), i) for i in range(1000)]
+ with torch._dynamo.error_on_graph_break(False):
+ data = [(randrange(2000), i) for i in range(1000)]
for f in (None, lambda x: x[0] * 547 % 2000):
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
self.assertEqual(list(self.module.nsmallest(n, data)),
@@ -233,7 +251,8 @@ class TestHeap:
sorted(data, key=f)[:n])
def test_nlargest(self):
- data = [(random.randrange(2000), i) for i in range(1000)]
+ with torch._dynamo.error_on_graph_break(False):
+ data = [(randrange(2000), i) for i in range(1000)]
for f in (None, lambda x: x[0] * 547 % 2000):
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
self.assertEqual(list(self.module.nlargest(n, data)),
@@ -248,28 +267,29 @@ class TestHeap:
data = [comp(x) for x in data]
self.module.heapify(data)
return [self.module.heappop(data).x for i in range(len(data))]
- class LT:
- def __init__(self, x):
- self.x = x
- def __lt__(self, other):
- return self.x > other.x
- class LE:
- def __init__(self, x):
- self.x = x
- def __le__(self, other):
- return self.x >= other.x
- data = [random.random() for i in range(100)]
+ with torch._dynamo.error_on_graph_break(False):
+ class LT:
+ def __init__(self, x):
+ self.x = x
+ def __lt__(self, other):
+ return self.x > other.x
+ class LE:
+ def __init__(self, x):
+ self.x = x
+ def __le__(self, other):
+ return self.x >= other.x
+ data = [random.random() for i in range(100)]
target = sorted(data, reverse=True)
self.assertEqual(hsort(data, LT), target)
self.assertRaises(TypeError, data, LE)
@ -76,7 +227,7 @@ index 1aa8e4e2897..94315fa68b4 100644
module = c_heapq
@@ -374,7 +377,7 @@ class SideEffectLT:
@@ -374,7 +394,7 @@ class SideEffectLT:
return self.value < other.value
@ -85,7 +236,48 @@ index 1aa8e4e2897..94315fa68b4 100644
def test_non_sequence(self):
for f in (self.module.heapify, self.module.heappop):
@@ -464,13 +467,13 @@ class TestErrorHandling:
@@ -435,10 +455,11 @@ class TestErrorHandling:
def test_comparison_operator_modifiying_heap(self):
# See bpo-39421: Strong references need to be taken
# when comparing objects as they can alter the heap
- class EvilClass(int):
- def __lt__(self, o):
- heap.clear()
- return NotImplemented
+ with torch._dynamo.error_on_graph_break(False):
+ class EvilClass(int):
+ def __lt__(self, o):
+ heap.clear()
+ return NotImplemented
heap = []
self.module.heappush(heap, EvilClass(0))
@@ -446,15 +467,16 @@ class TestErrorHandling:
def test_comparison_operator_modifiying_heap_two_heaps(self):
- class h(int):
- def __lt__(self, o):
- list2.clear()
- return NotImplemented
+ with torch._dynamo.error_on_graph_break(False):
+ class h(int):
+ def __lt__(self, o):
+ list2.clear()
+ return NotImplemented
- class g(int):
- def __lt__(self, o):
- list1.clear()
- return NotImplemented
+ class g(int):
+ def __lt__(self, o):
+ list1.clear()
+ return NotImplemented
list1, list2 = [], []
@@ -464,13 +486,13 @@ class TestErrorHandling:
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))

View File

@ -47,6 +47,11 @@ class TestModules(__TestCase):
self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
@torch._dynamo.disable
def randrange(*args):
return random.randrange(*args)
class _TestHeap:
def test_push_pop(self):
@ -55,7 +60,8 @@ class _TestHeap:
data = []
self.check_invariant(heap)
for i in range(256):
item = random.random()
with torch._dynamo.error_on_graph_break(False):
item = random.random()
data.append(item)
self.module.heappush(heap, item)
self.check_invariant(heap)
@ -86,14 +92,16 @@ class _TestHeap:
def test_heapify(self):
for size in list(range(30)) + [20000]:
heap = [random.random() for dummy in range(size)]
with torch._dynamo.error_on_graph_break(False):
heap = [random.random() for dummy in range(size)]
self.module.heapify(heap)
self.check_invariant(heap)
self.assertRaises(TypeError, self.module.heapify, None)
def test_naive_nbest(self):
data = [random.randrange(2000) for i in range(1000)]
with torch._dynamo.error_on_graph_break(False):
data = [randrange(2000) for i in range(1000)]
heap = []
for item in data:
self.module.heappush(heap, item)
@ -116,7 +124,8 @@ class _TestHeap:
# heap instead of a min heap, it could go faster still via
# heapify'ing all of data (linear time), then doing 10 heappops
# (10 log-time steps).
data = [random.randrange(2000) for i in range(1000)]
with torch._dynamo.error_on_graph_break(False):
data = [randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
for item in data[10:]:
@ -129,7 +138,8 @@ class _TestHeap:
self.assertRaises(IndexError, self.module.heapreplace, [], None)
def test_nbest_with_pushpop(self):
data = [random.randrange(2000) for i in range(1000)]
with torch._dynamo.error_on_graph_break(False):
data = [randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
for item in data[10:]:
@ -166,8 +176,9 @@ class _TestHeap:
def test_heapsort(self):
# Exercise everything with repeated heapsort checks
for trial in range(100):
size = random.randrange(50)
data = [random.randrange(25) for i in range(size)]
with torch._dynamo.error_on_graph_break(False):
size = randrange(50)
data = [randrange(25) for i in range(size)]
if trial & 1: # Half of the time, use heapify
heap = data[:]
self.module.heapify(heap)
@ -180,12 +191,13 @@ class _TestHeap:
def test_merge(self):
inputs = []
for i in range(random.randrange(25)):
row = []
for j in range(random.randrange(100)):
tup = random.choice('ABC'), random.randrange(-500, 500)
row.append(tup)
inputs.append(row)
with torch._dynamo.error_on_graph_break(False):
for i in range(randrange(25)):
row = []
for j in range(randrange(100)):
tup = random.choice('ABC'), randrange(-500, 500)
row.append(tup)
inputs.append(row)
for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
for reverse in [False, True]:
@ -212,12 +224,14 @@ class _TestHeap:
list(self.module.merge(iterable(), iterable()))
def test_merge_stability(self):
class Int(int):
pass
with torch._dynamo.error_on_graph_break(False):
class Int(int):
pass
inputs = [[], [], [], []]
for i in range(20000):
stream = random.randrange(4)
x = random.randrange(500)
with torch._dynamo.error_on_graph_break(False):
stream = randrange(4)
x = randrange(500)
obj = Int(x)
obj.pair = (x, stream)
inputs[stream].append(obj)
@ -227,7 +241,8 @@ class _TestHeap:
self.assertEqual(result, sorted(result))
def test_nsmallest(self):
data = [(random.randrange(2000), i) for i in range(1000)]
with torch._dynamo.error_on_graph_break(False):
data = [(randrange(2000), i) for i in range(1000)]
for f in (None, lambda x: x[0] * 547 % 2000):
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
self.assertEqual(list(self.module.nsmallest(n, data)),
@ -236,7 +251,8 @@ class _TestHeap:
sorted(data, key=f)[:n])
def test_nlargest(self):
data = [(random.randrange(2000), i) for i in range(1000)]
with torch._dynamo.error_on_graph_break(False):
data = [(randrange(2000), i) for i in range(1000)]
for f in (None, lambda x: x[0] * 547 % 2000):
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
self.assertEqual(list(self.module.nlargest(n, data)),
@ -251,17 +267,18 @@ class _TestHeap:
data = [comp(x) for x in data]
self.module.heapify(data)
return [self.module.heappop(data).x for i in range(len(data))]
class LT:
def __init__(self, x):
self.x = x
def __lt__(self, other):
return self.x > other.x
class LE:
def __init__(self, x):
self.x = x
def __le__(self, other):
return self.x >= other.x
data = [random.random() for i in range(100)]
with torch._dynamo.error_on_graph_break(False):
class LT:
def __init__(self, x):
self.x = x
def __lt__(self, other):
return self.x > other.x
class LE:
def __init__(self, x):
self.x = x
def __le__(self, other):
return self.x >= other.x
data = [random.random() for i in range(100)]
target = sorted(data, reverse=True)
self.assertEqual(hsort(data, LT), target)
self.assertRaises(TypeError, data, LE)
@ -438,10 +455,11 @@ class _TestErrorHandling:
def test_comparison_operator_modifiying_heap(self):
# See bpo-39421: Strong references need to be taken
# when comparing objects as they can alter the heap
class EvilClass(int):
def __lt__(self, o):
heap.clear()
return NotImplemented
with torch._dynamo.error_on_graph_break(False):
class EvilClass(int):
def __lt__(self, o):
heap.clear()
return NotImplemented
heap = []
self.module.heappush(heap, EvilClass(0))
@ -449,15 +467,16 @@ class _TestErrorHandling:
def test_comparison_operator_modifiying_heap_two_heaps(self):
class h(int):
def __lt__(self, o):
list2.clear()
return NotImplemented
with torch._dynamo.error_on_graph_break(False):
class h(int):
def __lt__(self, o):
list2.clear()
return NotImplemented
class g(int):
def __lt__(self, o):
list1.clear()
return NotImplemented
class g(int):
def __lt__(self, o):
list1.clear()
return NotImplemented
list1, list2 = [], []

View File

@ -283,7 +283,7 @@ class TestCompilerBisector(TestCase):
)
def test_bisect_pre_grad_graph(self):
def f(x):
for i in range(5):
for _ in range(5):
x = x + 1
return x.relu()

View File

@ -36,6 +36,15 @@ class DummyUserDict(UserDict):
pass
class FakeMapping:
def __init__(self, value: Any) -> None:
self._value = value
self.keys = lambda: ["a", "b", "c"] # not required to be a method
def __getitem__(self, key: str) -> Any:
return self._value
class DictTests(torch._dynamo.test_case.TestCase):
def test_dict_subclass_instantiation(self):
def fn(x):
@ -666,6 +675,18 @@ class DictTests(torch._dynamo.test_case.TestCase):
for k1, m2 in zip(modules, module_dict.children()):
self.assertTrue(modules[k1] is m2)
# FIXME: see comment in torch/_dynamo/polyfills/__init__.py:mutable_mapping_update
@unittest.expectedFailure
def test_dict_construct_from_mapping_like(self):
def fn(x):
fm = FakeMapping(x)
d = dict(fm, x=x)
return d
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(x), opt_fn(x))
def test_dict_subclass_initialization_in_graph(self):
for super_class in (
OrderedDict,
@ -1087,12 +1108,52 @@ class DictTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
@unittest.expectedFailure
def test_newly_constructed_default_dict_no_default_factory(self):
def f1(x):
d = defaultdict()
try:
d[1] += 42
except KeyError:
d[1] = 1
return x + 1, d
x = torch.ones(2)
ref = f1(x)
res = torch.compile(f1, backend="eager", fullgraph=True)(x)
self.assertEqual(ref, res)
def f2(x):
d = defaultdict(None)
try:
d[1] += 42
except KeyError:
d[1] = 1
return x + 1, d
ref = f2(x)
res = torch.compile(f2, backend="eager", fullgraph=True)(x)
self.assertEqual(ref, res)
def f3(x):
d = defaultdict(None, {1: 10})
d[1] += 42
try:
d[2] += 24
except KeyError:
d[2] = 1
return x + 1, d
ref = f3(x)
res = torch.compile(f3, backend="eager", fullgraph=True)(x)
self.assertEqual(ref, res)
def test_newly_constructed_default_dict_with_dict(self):
def f(x):
d = defaultdict(dict, {2: {"a": 1}})
d[0] = {"b": 2}
return x + 1, d
d = dict([("a", 1), ("b", 2)], c=3) # noqa: C406
dd = defaultdict(list, d, d=4, e=5)
dd["x"].append(42)
return x + 1, d, dd
x = torch.ones(2)
ref = f(x)

View File

@ -427,17 +427,29 @@ from user code:
optree.tree_flatten_with_path(d)
return torch.sin(x)
def post_munge(s):
s = re.sub(
r"optree\.\S*\.flatten_with_path",
"optree.<path>.flatten_with_path",
s,
)
return re.sub(
r"qualname: \S*flatten_with_path",
"qualname: <path>.flatten_with_path",
s,
)
fn(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 1)
first_graph_break = next(iter(counters["graph_break"].keys()))
self.assertExpectedInline(
first_graph_break,
post_munge(first_graph_break),
"""\
Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten_with_path.
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.flatten_with_path.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
Developer debug context: module: optree._C, qualname: PyCapsule.flatten_with_path, skip reason: <missing reason>
Developer debug context: module: optree._C, qualname: <path>.flatten_with_path, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
)

View File

@ -8,21 +8,11 @@ from torch._dynamo.graph_deduplication import apply_graph_deduplication
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.output_graph import FakeRootModule
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
extract_graph_and_tracker,
normalize_gm,
)
from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm
from torch.compiler import allow_in_graph
from torch.utils._ordered_set import OrderedSet
def extract_graph(fn, *args, **kwargs):
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs
def graph_str(gm):
return normalize_gm(gm.print_readable(print_output=False))
@ -40,7 +30,7 @@ class GraphDededuplicationTests(TestCase):
super().tearDown()
def run_and_return_graphs(self, fn, *args, **kwargs):
return extract_graph(fn, *args, **kwargs)
return extract_graph(fn, *args, **kwargs)[0:3]
def run_and_get_simple_graph(self):
def fn(x, y):

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
from collections.abc import Sequence
from typing import Any, Callable, Union
from collections.abc import Callable, Sequence
from typing import Any, Union
import torch
import torch._dynamo

View File

@ -69,6 +69,7 @@ from torch.fx.experimental.symbolic_shapes import (
constrain_unify,
ConstraintViolationError,
expect_true,
guard_or_false,
guard_size_oblivious,
ShapeEnv,
)
@ -100,7 +101,6 @@ from torch.testing._internal.common_utils import (
wrapDeterministicFlagAPITest,
)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.logging_utils import logs_to_string
pytree_modules = {
@ -13194,6 +13194,30 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
self.assertEqual(actual, expected)
@parametrize_pytree_module
def test_pytree_tree_map_dict_order(self, pytree):
def fn(tree):
new_tree = pytree.tree_map(lambda x: x, tree)
return list(new_tree.keys()), list(new_tree.values())
x = torch.randn(3, 2)
fn_opt = torch.compile(fullgraph=True)(fn)
tree1 = {"b": x + 2, "a": x, "c": x - 1}
expected1 = fn(tree1)
actual1 = fn_opt(tree1)
self.assertEqual(actual1, expected1)
tree2 = collections.OrderedDict([("b", x + 2), ("a", x), ("c", x - 1)])
expected2 = fn(tree2)
actual2 = fn_opt(tree2)
self.assertEqual(actual2, expected2)
tree3 = collections.defaultdict(int, {"b": x + 2, "a": x, "c": x - 1})
expected3 = fn(tree3)
actual3 = fn_opt(tree3)
self.assertEqual(actual3, expected3)
@parametrize_pytree_module
def test_pytree_tree_map_only(self, pytree):
if not callable(getattr(pytree, "tree_map_only", None)):
@ -13219,6 +13243,27 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 9)
def test_pytree_register_constant_with_side_effect(self):
class Foo:
pass
class Bar:
def __eq__(self, other):
return super().__eq__(other)
def __hash__(self):
return 0
python_pytree.register_constant(Bar)
@torch.compile(backend="eager", fullgraph=True)
def fn(x, obj):
obj.attr = {3: Bar()}
return x + 1
inp = torch.ones(3)
self.assertEqual(fn(inp, Foo()), inp + 1)
class TestTracer(JitTestCase):
def test_jit_save(self):
@ -13636,6 +13681,74 @@ instantiate_device_type_tests(
)
class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
def test_symbool_tensor_mul(self):
def symbool_mul_fn(x_bool, sentinel):
result = x_bool * sentinel
return result
x_true = torch.tensor([True], device="cuda")
x_false = torch.tensor([False], device="cuda")
sentinel = torch.tensor(2.0, requires_grad=True, device="cuda")
eager_result_true = symbool_mul_fn(x_true, sentinel)
eager_result_false = symbool_mul_fn(x_false, sentinel)
compiled_fn = torch.compile(symbool_mul_fn, fullgraph=True, dynamic=True)
compiled_result_true = compiled_fn(x_true, sentinel)
compiled_result_false = compiled_fn(x_false, sentinel)
self.assertEqual(eager_result_true, compiled_result_true)
self.assertEqual(eager_result_false, compiled_result_false)
self.assertEqual(compiled_result_true.item(), 2.0)
self.assertEqual(compiled_result_false.item(), 0.0)
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
def test_symbool_guard_or_false(self):
def symbool_guard_fn(a_bool_tensor, b):
u0 = a_bool_tensor.item()
# Make sure guard_or_false still handles SymBool produced by .item()
if guard_or_false(u0):
return b * 10
else:
return b * 100
compiled_guard_fn = torch.compile(
symbool_guard_fn, backend="eager", dynamic=True
)
a_true = torch.tensor(True, device="cuda")
a_false = torch.tensor(False, device="cuda")
b = torch.randn(6, device="cuda")
eager_res_true = symbool_guard_fn(a_true, b)
compiled_res_true = compiled_guard_fn(a_true, b)
self.assertEqual(eager_res_true, compiled_res_true)
eager_res_false = symbool_guard_fn(a_false, b)
compiled_res_false = compiled_guard_fn(a_false, b)
self.assertEqual(eager_res_false, compiled_res_false)
self.assertEqual(compiled_res_true, b * 10)
self.assertEqual(compiled_res_false, b * 100)
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
def test_symbool_tensor_mul_does_not_fail(self):
def fuzzed_program(arg_0, sentinel):
var_node_2 = arg_0
var_node_1 = torch.squeeze(var_node_2)
var_node_0 = var_node_1.item()
result = var_node_0 * sentinel
if result.is_complex():
result = result.real
return result
sentinel = torch.tensor(1.0, requires_grad=True, device="cuda")
arg_0 = torch.tensor([True], dtype=torch.bool, device="cuda")
args = (arg_0,) + (sentinel,)
try:
compiled_program = torch.compile(
fuzzed_program, fullgraph=True, dynamic=True
)
compiled_program(*args)
except Exception as e:
self.fail(f"torch.compile failed with error: {e}")
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
from typing import Callable, NamedTuple, Optional
from typing import NamedTuple, Optional, TYPE_CHECKING
import torch
import torch._dynamo
@ -7,6 +7,10 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same
if TYPE_CHECKING:
from collections.abc import Callable
"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo

View File

@ -48,6 +48,7 @@ from torch._dynamo.testing import (
CompileCounter,
CompileCounterWithBackend,
EagerAndRecordGraphs,
expectedFailureDynamic,
rand_strided,
same,
skipIfNotPy312,
@ -1000,6 +1001,18 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.exit_stack.close()
super().tearDown()
def test_compiled_module_truthiness(self):
# Test with empty ModuleList
original_empty = nn.ModuleList()
compiled_empty = torch.compile(original_empty)
self.assertEqual(bool(original_empty), bool(compiled_empty))
self.assertFalse(bool(compiled_empty))
# Test with non-empty ModuleList
original_filled = nn.ModuleList([nn.Linear(10, 5)])
compiled_filled = torch.compile(original_filled)
self.assertEqual(bool(original_filled), bool(compiled_filled))
self.assertTrue(bool(compiled_filled))
def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals, builder):
root = guard_manager_wrapper.root
cloned_root = root.clone_manager(lambda x: True)
@ -7443,6 +7456,93 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
msg,
)
@expectedFailureDynamic
def test_dynamo_default_lru_cache_behavior(self):
@torch.compile(backend="eager")
def fn(x):
return x + 10
torch._dynamo.reset()
assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
# Step 1: Compile a static shapes graph
x = torch.randn(10, 10)
fn(x)
a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(a), 1)
static_shapes_cache_entry = a[0]
# Step 2: Compile a dynamic shapes graph
y = torch.randn(20, 20)
fn(y)
b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(b), 2)
self.assertEqual(b[1], static_shapes_cache_entry)
dynamic_shapes_cache_entry = b[0]
# Step 3: Run with Step 1's inputs
# LRU cache will match against dynamic shape graph first
fn(x)
c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(c), 2)
self.assertEqual(c[0], dynamic_shapes_cache_entry)
self.assertEqual(c[1], static_shapes_cache_entry)
@expectedFailureDynamic
def test_dynamo_disable_lru_cache_behavior(self):
@torch.compile(backend="eager")
def fn(x):
return x + 10
def run():
torch._dynamo.reset()
assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
# Step 1: Compile a static shapes graph
x = torch.randn(10, 10)
fn(x)
a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(a), 1)
static_shapes_cache_entry = a[0]
# Step 2: Compile a dynamic shapes graph
y = torch.randn(20, 20)
fn(y)
b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(b), 2)
self.assertEqual(b[0], static_shapes_cache_entry)
dynamic_shapes_cache_entry = b[1]
# Step 3: Run with Step 1's inputs
# LRU cache is disabled, we should still have static entry first
fn(x)
c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(c), 2)
self.assertEqual(c[0], static_shapes_cache_entry)
self.assertEqual(c[1], dynamic_shapes_cache_entry)
try:
torch._C._dynamo.eval_frame._set_lru_cache(False)
run()
finally:
torch._C._dynamo.eval_frame._set_lru_cache(True)
class ReproTestsDevice(torch._dynamo.test_case.TestCase):
def test_sub_alpha_scalar_repro(self, device):

Some files were not shown because too many files have changed in this diff Show More