Compare commits

..

53 Commits

Author SHA1 Message Date
6d8946c013 Fix char printing
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
2025-11-15 11:19:44 +08:00
4ed26f7382 distributed/debug: add an HTTP server for debugging running jobs (#167395)
This adds a debug HTTP server for debugging stuck or slow jobs. It runs the WorkerServer on every worker and then launches a separate flask process on rank 0 to have users connect to for debugging.

This can easily be improved to trigger profilers as well as visualize the data much better.

Initial handlers:
* pytorch profiler
* FlightRecorder data
* Python stacks

```
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"

from torch.distributed.debug import enable_debug_server

enable_debug_server()
```

Test plan:

```
torchrun --nnodes 1 --nproc_per_node=gpu ~/scripts/debug_test.py
```

<img width="1499" height="1629" alt="20251107_17h10m47s_grim" src="https://github.com/user-attachments/assets/a8b9a0cb-3bbf-4558-be12-5253e418214e" />
<img width="1192" height="1337" alt="20251107_17h10m39s_grim" src="https://github.com/user-attachments/assets/ac5d7011-4acb-4401-bf2c-f9b22c1466bd" />

<img width="984" height="851" alt="20251107_18h35m38s_grim" src="https://github.com/user-attachments/assets/98b3eb31-ed01-4345-90dd-c79345cf82ce" />
<img width="2880" height="777" alt="20251107_18h35m31s_grim" src="https://github.com/user-attachments/assets/8de84b8b-9d06-4bc8-a1bf-280a2958315b" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167395
Approved by: https://github.com/fduwjj
2025-11-14 23:14:38 +00:00
4c79305b87 [targets2buck] Clean up get_pt_ops_deps (#167690)
Summary: I didn't understand what this macro was doing so I created a bit of a mess, mess be gone!

Test Plan: `buck2 ctargets fbcode//caffe2/... fbsource//xplat/caffe2/...`

Reviewed By: mzlee

Differential Revision: D86460608

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167690
Approved by: https://github.com/seemethere
2025-11-14 23:05:24 +00:00
f4b8c4f907 backed size oblivious checks for expand() (#167689)
Summary:
Support semantics when using backed_size_oblivious, similar to https://github.com/pytorch/pytorch/pull/167232

We see errors in a model exported with dynamic shapes, like
```
RuntimeError: non-broadcasting semantics require s67 == 41

While executing %expand : [num_users=1] = call_method[target=expand](args = (%reshape_5, -1, -1, %getitem_9), kwargs = {})
```

Test Plan:
test_dynamic_shapes:
```
test_backed_size_oblivious_expand (test_dynamic_shapes.TestUbackedOps) ... I1112 14:07:54.724596 1386932 Logger.cpp:995] Dropping logs in unit tests.
ok
```

Differential Revision: D86902546

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167689
Approved by: https://github.com/laithsakka
2025-11-14 22:31:28 +00:00
d629b7a459 Move CppTypeToScalarType to torch/headeronly (#167610)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167610
Approved by: https://github.com/pearu, https://github.com/janeyx99
2025-11-14 22:21:45 +00:00
0922ba5f42 [BE] No need to pass const enum values by reference (#167868)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167868
Approved by: https://github.com/slayton58
2025-11-14 21:56:19 +00:00
c87295c044 [precompile] Support captured global tensors. (#167846)
Summary:
In vllm we saw cases where user intialize a tensor in the global scope and reference it in the forward body. This should be supported by pruning the used globals in the scope and serialize them along the artifacts similar to how we handle closure.

Use case example: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3n.py#L65

Test Plan:
test_aot_compile.py

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167846
Approved by: https://github.com/jamesjwu
2025-11-14 21:40:07 +00:00
7aa210d215 Revert "[CodeClean] Remove the Unused MACRO for AOT Inductor Runtime (#165139)"
This reverts commit fcd5f8c352b5b75bd32e57fa044ec5df095032da.

Reverted https://github.com/pytorch/pytorch/pull/165139 on behalf of https://github.com/jeanschmidt due to trying to hevert in the hopes it fixes internal errors, will land it back ([comment](https://github.com/pytorch/pytorch/pull/165139#issuecomment-3534662138))
2025-11-14 21:35:37 +00:00
5a368b8010 Revert "[CodeClean] Replace std::runtime_error with TORCH_CHECK (#165119)"
This reverts commit 398775a43e9808205f75c81d36f5087117d3f3f4.

Reverted https://github.com/pytorch/pytorch/pull/165119 on behalf of https://github.com/jeanschmidt due to trying to hevert in the hopes it fixes internal errors, will land it back ([comment](https://github.com/pytorch/pytorch/pull/165139#issuecomment-3534662138))
2025-11-14 21:35:37 +00:00
602102be50 Revert "Hide all symbols (except stable/headeronly/shim) if TORCH_STABLE_ONLY is defined (#167496)"
This reverts commit bc09a84150eaadaadab8a8ecd76cd9afc60d8a19.

Reverted https://github.com/pytorch/pytorch/pull/167496 on behalf of https://github.com/jeanschmidt due to trying to revert 165139, my intention is to land it again, so, will land this once both are reverted ([comment](https://github.com/pytorch/pytorch/pull/167496#issuecomment-3534641209))
2025-11-14 21:33:02 +00:00
200156e385 DTensor: avoid unnecessary DTensorSpec creation in _ToTorchTensor.backward (#167588)
Looks like the check here is cheap and has a potentially large payoff.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167588
Approved by: https://github.com/ezyang
2025-11-14 21:08:12 +00:00
a2daf3fc86 [Inductor] Add support bound methods in pattern matcher (#167795)
Fixes: #167776

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167795
Approved by: https://github.com/mlazos
2025-11-14 20:55:51 +00:00
52b45c16de Add reshape, view, flatten to torch/csrc/stable (#167600)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167600
Approved by: https://github.com/janeyx99
ghstack dependencies: #167592
2025-11-14 20:35:53 +00:00
2ef85bed5a Add empty to stable ops (#167592)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167592
Approved by: https://github.com/janeyx99
2025-11-14 20:35:53 +00:00
d99c6bcf69 [export] Disable side effects on dynamo_graph_capture_for_export and warn user. (#167763)
Summary:
as title.

Test Plan:
test_dynamo_graph_capture_side_effects

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167763
Approved by: https://github.com/tugsbayasgalan
2025-11-14 20:35:22 +00:00
8378abda84 [torch.export] Fix for flaky test_annotate_on_assert (#167805)
Summary: test_annotate_on_assert become flaky with PR 166341 (Details in https://github.com/pytorch/pytorch/issues/167432). Torchdynamo related metadata can vary depending on the caller. Removing the those metadata before comparison.

Test Plan:
```
buck test mode/opt caffe2/test:test_export -- 'test_annotate_on_assert'
```
https://www.internalfb.com/intern/testinfra/testrun/7036874728749661

Differential Revision: D87036890

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167805
Approved by: https://github.com/yushangdi
2025-11-14 19:56:51 +00:00
5b42a5d9a6 [doc] Add example for torch.is_storage (#161898)
Fixes #161858

### Summary:
Added comprehensive documentation examples for `torch.is_storage()` to help users understand how to check if an object is a PyTorch storage object.

### Impact:

- Enhances API Documentation
- Helps users distinguish between PyTorch storage objects and other types

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161898
Approved by: https://github.com/isuruf, https://github.com/malfet
2025-11-14 19:45:54 +00:00
caca3f2eec Revert "Re-land "Fix thread safety in getCurrentCUDABlasHandle and getCUDABlasLtWorkspace" (#167722)"
This reverts commit 40e6f090d91026947fbec92a42564ad492f37eae.

Reverted https://github.com/pytorch/pytorch/pull/167722 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167722#issuecomment-3534282212))
2025-11-14 19:38:22 +00:00
9e2bf129e1 [MPS] addmm complex fix (#167826)
Fixes #167727

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167826
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-14 19:29:09 +00:00
c429b1fc5c Ops convolution_backward optional flag bug (#165008)
Fixes #89629

When using torch.ops.aten.convolution_backward, the optional argument bias_sizes was being used in the python function registration without checking whether it was defined.

## For the fix
there are two modes to consider with different results.

First @dynamo.optimize("inductor") is the most demanding.
We cannot be wrong about the size passed into the function. But we should not ignore what the user wants/thinks they are doing. For this case, we want to throw an error when the user is wrong. If the user passes in None, we calculate the expected size directly.

Second @dynamo.optimize("eager") is very lenient.
We really can provide any value we want here. If the user is wrong about bias shape in eager mode, the op will just reshape the bias to the proper size so no error is thrown here.

## For testing
An OpInfo was added for torch.ops.aten.convolution_backward.default.
For the CUDA test_noncontiguous_samples test, a slightly updated error tolerance was necessary for the compounded add multiply (for 2x2 kernel).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165008
Approved by: https://github.com/bdhirsh
2025-11-14 19:24:45 +00:00
1176b2b0b7 [BE]: Update NVTX submodule to 3.3.0 (#167751)
Update NVTX to 3.3.0. Mostly fixes some errors in the bindings, improve C++20 support, and improve C++ bindings to NVTX. Header only library upgrade so should be mostly safe.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167751
Approved by: https://github.com/albanD, https://github.com/eqy
2025-11-14 19:24:37 +00:00
dd37a1a434 Fix NaN gradients in atan2_backward when both inputs are zero (#166787)
Fixes #165427

## Description of Bug 🐛

As reported in #165427, When both the input of  `atan2` function is zero the gradient becomes `NaN`. During the forward pass, `atan2` successfully avoids division-by-zero issue, but during backpropagation gradients become `NaN`.

This is because the backward pass calculates `(self * self + other * other).reciprocal()`, which becomes `inf` at `(0, 0)`. The subsequent multiplication by zero `(0 * inf)` results in `NaN`.

## Changes
- Added an `at::where` condition to handle zero denominators in `atan2_backward`.
- If denom is zero return 0 for the reciprocal; otherwise, use the original value.

## Testing
- Added` test_atan2_zero_gradient` in `test/test_autograd.py` to verify `atan2` returns `0.0` gradients for `(0,0)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166787
Approved by: https://github.com/soulitzer
2025-11-14 19:23:33 +00:00
a74adcf80e [codemod][lowrisk] Remove unused exception parameter from caffe2/caffe2/serialize/inline_container.cc (#167612)
Summary:
`-Wunused-exception-parameter` has identified an unused exception parameter. This diff removes it.

This:
```
try {
    ...
} catch (exception& e) {
    // no use of e
}
```
should instead be written as
```
} catch (exception&) {
```

If the code compiles, this is safe to land.

Test Plan: Sandcastle

Differential Revision: D85813824

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167612
Approved by: https://github.com/seemethere, https://github.com/malfet
2025-11-14 19:11:01 +00:00
5eac46a011 add assume_32bit_indexing inductor config (#167784)
when we know all tensor and intermediate tensors fit in 32 bit but use unbacked DS
we want a way to assume that we can use 32 bit indexing(we will runtime assert on it).

It is not practical to torch check every possible intermediate tensor size ahead of time.

This is needed to enhance vLLM perf with unbacked,  since in vLLM all tensors and
intermediates assumed to fit in 32 bits.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167784
Approved by: https://github.com/jansel
2025-11-14 19:04:22 +00:00
e0fff31ae3 [dynamo] Make global state guards and torch function stack guards droppable. (#167674)
Summary:
Prior to this PR we will always build global and torch funciton guards in all cases.

In this PR we did 2 changes to dynamo guards:
1. Created a new guard called "GLOBAL_STATE" which corresponds to the global state guard and can be filtered out using guard_filter_fn
2. Repurpose the existing "TORCH_FUNCTION_STATE" guard for checking torch function mode stack.

Also added a new helper `torch.compiler.skip_all_guards_unsafe` which can be useful for use cases like vllm

Test Plan:
CI

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167674
Approved by: https://github.com/anijain2305
2025-11-14 18:11:44 +00:00
7ede33b8e3 Tiling bug fix (#167771)
Fix for https://github.com/pytorch/pytorch/issues/166653.

Two fixes:
- We were inducing a split for broadcasted loads. e.g. (x // 16). While a split of 16 here will make the load coalesced in one of the tile vars, since the load is already in cache it's not worth splitting. And it would make the other tile var load from memory that isnt in cache.
- Add a slight term for uncoalesced memory. This prevents doing tiling for loads which are a small % of the overall kernel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167771
Approved by: https://github.com/v0i0
2025-11-14 17:32:42 +00:00
065176cd97 [export] Add pytree input check for dynamo_graph_capture_for_export (#167731)
Summary:
as title.

Test Plan:
pytest test/export/test_export.py -k test_invalid_pytree_dynamo_graph_capture

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167731
Approved by: https://github.com/tugsbayasgalan
2025-11-14 17:29:55 +00:00
eqy
02ee7dd7d3 [CUDA][Test] Add serialTest() to some largeTensorTest tests (#167471)
Try to prevent two big tests from overlapping in their memory usage

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167471
Approved by: https://github.com/soulitzer
2025-11-14 17:13:14 +00:00
99fdca8f4d [ROCm] Enable StaticCudaLauncher for ROCm (#166492)
This PR enables ROCm/HIP support for PyTorch's StaticCudaLauncher, which provides static compilation and launching of Triton kernels. The implementation has been tested on AMD MI300 and MI200 hardware.

**Changes**

**Python (torch/_inductor/runtime/)**
- static_cuda_launcher.py: Added ROCm detection, .hsaco binary support, and ROCm-specific scratch parameter handling
- triton_heuristics.py: Updated device type checks to support both cuda and hip

**C++ (torch/csrc/)**
- Module.cpp: Enabled StaticCudaLauncher for ROCm builds
- inductor/static_cuda_launcher.cpp: Added HIP API equivalents for all CUDA driver calls
- inductor/static_cuda_launcher.h: Updated header guard

**Tests (test/inductor/)**
- test_static_cuda_launcher.py: Removed @skipIfRocm decorators and updated binary file handling

**Enabled Unit Tests**
All tests in test/inductor/test_static_cuda_launcher.py now pass on ROCm:
1. test_basic
2. test_unsigned_integers
3. test_signed_integers
4. test_basic_1arg
5. test_constexpr
6. test_implied_constant
7. test_kernel_no_args
8. test_high_shared_mem
9. test_too_high_shared_mem
10. test_kernel_empty_tensor
11. test_kernel_many_args
12. test_basic_compile
13. test_incompatible_code
14. test_static_launch_user_defined_triton_kernels
15. test_empty_tensor
16. test_any
17. test_disable_static_cuda_launcher

In addition to this, the following tests from test/inductor/test_codecache.py also pass:
1. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_False_use_static_cuda_launcher_False
2. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_True_use_static_cuda_launcher_False
3. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_True_use_static_cuda_launcher_True
4. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_False_use_static_cuda_launcher_False
5. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_True_use_static_cuda_launcher_False
6. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_True_use_static_cuda_launcher_True

The following tests are skipped since triton bundling is necessary for StaticCudaLauncher:
1. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_False_use_static_cuda_launcher_True
2. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_False_use_static_cuda_launcher_True

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166492
Approved by: https://github.com/jeffdaily
2025-11-14 17:11:45 +00:00
9d1a74cb0c Fix mvlgamma_ FPE crash on x86 with integer input (#164230)
Fixes #161871.

Behaviour on arm:

```
PyTorch version: 2.10.0a0+gitdef3b05
Architecture: arm64
Platform: Darwin
Processor: arm

Testing mvlgamma_ with integer tensor on arm64...
 Got expected error: mvlgamma: result type Long can't be cast to the desired output type Float
```

and on x86:

```
PyTorch version: 2.10.0a0+git1310d6a
Architecture: x86_64
Platform: Linux
Processor: x86_64

Testing mvlgamma_ with integer tensor on x86_64...
 Got expected error: mvlgamma: result type Long can't be cast to the desired output type Float
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164230
Approved by: https://github.com/albanD
2025-11-14 17:09:10 +00:00
40e6f090d9 Re-land "Fix thread safety in getCurrentCUDABlasHandle and getCUDABlasLtWorkspace" (#167722)
Summary:
getCurrentCUDABlasHandle() and getCUDABlasLtWorkspace() use static mutable maps that are not protected from concurrent read-and-write. This leads to crashes.
This diff adds mutexes to synchronize access to the static maps.

Note: this is a re-land of D86316117 / https://github.com/pytorch/pytorch/pull/167248 (see comments for details)

Test Plan:
Use a GPU OD, run multi-threaded tests (cuda_cublas_handle_pool_test) with TSAN:
```
buck test fbcode//mode/dev-tsan fbcode//caffe2:cuda_cublas_handle_pool_test  -- --stress-runs 100
```
https://www.internalfb.com/intern/testinfra/testrun/14355223937501118

TSAN output (before synchronization was added): P2026731804

Differential Revision: D86964261

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167722
Approved by: https://github.com/malfet
2025-11-14 16:16:35 +00:00
bfddfde50c Add basic spin config and linting commands (#167226)
This PR adds a basic spin configuration to allow for linting. It is designed as a drop-in replacement for the current Makefile based solution, i.e. it sets up and updates lintrunner based on the hashes of certain configuration files.

Lintrunner is called via Uv's `uvx` command, separating its environment from the general development environment in an effort to reduce instances of competing requirements breaking environments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167226
Approved by: https://github.com/atalman, https://github.com/albanD
2025-11-14 15:35:42 +00:00
b6570615f8 [precompile] Integrate AOTI as a backend. (#167338)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167338
Approved by: https://github.com/jamesjwu
2025-11-14 15:33:11 +00:00
226850cc66 [ATen][CUDA] Add sm_121a flag for RowwiseScaledMM (#167734)
This PR add a sm_121a flag for row-wise scaled matmuls on DGX Spark.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167734
Approved by: https://github.com/eqy, https://github.com/cyyever
2025-11-14 08:44:04 +00:00
f8a2ce3b9a Fix inplace ops on Partial DTensors to preserve aliasing semantics (#164729)
Fixes #163374.

Here is the output from reproducible code:

```
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811]
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] *****************************************
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] *****************************************
  aten::clamp_(dt: f32[][R], None, 2)
    redistribute_input(0, [P] -> [R])
      redistribute_input(t: f32[], [P] -> [R])
        _c10d_functional::all_reduce(t: f32[], sum, 0)
        _c10d_functional::wait_tensor(t: f32[])
    aten::clamp_(t: f32[], None, 2)
    aten::view(t: f32[], [])
(Replicate(),)
tensor(2., device='cuda:0')
```

The behavior is now matching what you were expecting in issue #163374:

Expected behavior (from the issue):
  1. Placement should change from Partial(sum) to Replicate()
  2. Value should be tensor(2.) instead of tensor(144.)

  Actual output from this build:
  1. (Replicate(),) - placement is correct
  2. tensor(2., device='cuda:0') - value is correct

so the inplace operation now properly redistributes the partial DTensor to replicate before performing the clamp snd maintains the correct aliasing semantics. It also produces the expected clamped value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164729
Approved by: https://github.com/ezyang
2025-11-14 07:46:35 +00:00
e2c6834584 Revert "deprecate check_is_size and guard_size_oblivious (#167198)"
This reverts commit 50bf1f0b819f0b1cc9acbb0646ac9555bb9d44b9.

Reverted https://github.com/pytorch/pytorch/pull/167198 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167198#issuecomment-3531149912))
2025-11-14 06:46:15 +00:00
0e7235ed73 [xpu][feature] [1/3] add fp8 scaled_mm implementation for XPU (#165978)
This PR implements `scaled_mm` for XPU. It enables the following data types:
1. TensorWise Scaling: `fp8_e4m3` and `fp8_e5m2`
2. RowWise Scaling:  `fp8_e4m3` and `fp8_e5m2`

It leaves the BlockWise Scaling to next PR, so that it will have less reviewing efforts.

This is the first PR that only adds `scaled_mm_xpu` but does not registered. We separate this out for less reviewing efforts.

Secondly, there is a `scaled_mm_v2` API in #164141 . We will align with it once the v1 is cleaned up.

**Co-author:** @yuchengliu1, @carsonwang

## PR stack:

- -> https://github.com/pytorch/pytorch/pull/165978 : implementation of XPU scaled_mm and oneDNN kernel
- https://github.com/pytorch/pytorch/pull/167518 : implementation of XPU scaled_mm_v2
- https://github.com/pytorch/pytorch/pull/166056 : Op registration

## Test Status:

1. Relies on the changes in https://github.com/intel/torch-xpu-ops/pull/1746/, Otherwise the op will fallback to CPU.
2. This PR does not include tests, the tests are enabled in #166056.

## Credit:

This work is based on @yuchengliu1's work at #140972 . The purpose that we created a new PR is to align with the API / checks with CUDA, so there will be less porting efforts.

## FP8 Task tracker:
We will track all the scaled_mm related tasks in: https://github.com/pytorch/pytorch/issues/167170

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165978
Approved by: https://github.com/liangan1, https://github.com/EikanWang

Co-authored-by: Eikan Wang <eikan.wang@intel.com>
2025-11-14 06:41:18 +00:00
3522e0ce74 Revert "Fix different seq length (#167481)"
This reverts commit c78e64622e62eb93a03a9c3762df3290d6c65362.

Reverted https://github.com/pytorch/pytorch/pull/167481 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167481#issuecomment-3530992724))
2025-11-14 06:05:45 +00:00
50bf1f0b81 deprecate check_is_size and guard_size_oblivious (#167198)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167198
Approved by: https://github.com/bobrenjc93
2025-11-14 05:35:29 +00:00
c78e64622e Fix different seq length (#167481)
Differential Revision: D86685546

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167481
Approved by: https://github.com/eellison
2025-11-14 05:31:29 +00:00
5623628894 [SymmMem] op to get remote tensors (#167779)
To support use case in https://github.com/pytorch/helion/pull/1122, i.e.
```
@helion.kernel
def foo(
    x: Tensor,
    group_name: str
):
    x_remotes = torch.ops.symm_mem.get_remote_tensors(x, group_name)
    for t in x_remotes:
        ...
````

Helion uses fake tensor to trace a program, thus we cannot use the following code in a Helion function:
```
hdl = rendezvous(tensor)
remote_tensors = tuple(
    hdl.get_remote_tensor(peer, ...) for peer in range(world_size)
)
```
The reason is that when `tensor` is fake, the returned `hdl` is None, thus any subsequent call on it will fail.

This PR wraps the above functionality as an op:
```
lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]")
```
so that things like `hdl` is not exposed to Helion. The op also provides a `meta` implementation so that Helion can trace it without actually running the rendezvous.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167779
Approved by: https://github.com/yf225
2025-11-14 05:01:55 +00:00
2aba180114 Always track _local_scalar_dense output in tensorify_python_scalars. (#166573)
We need to track all symbols, we used to skip
u = item()
and fail with
```
 File "/home/lsakka/pytorch10/pytorch/torch/fx/passes/_tensorify_python_scalars.py", line 149, in _sympy_interp
    expr_to_sym_proxy[expr]
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
KeyError: u0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166573
Approved by: https://github.com/bobrenjc93
2025-11-14 03:51:43 +00:00
45b2c3d312 [OpenReg][Feat][Docs] Enrich OpenReg device management implementation and add focused documentation (#165897)
## Summary
This PR enriches OpenReg device management codes and adds focused documentation.

## Key Changes
- Introduced device management documentation in `device.md`.
- Updated `OpenRegFunctions.h` and `OpenRegFunctions.cpp` to use `DeviceIndex` and added error handling.
- Implemented `check_device_index` function for validating device indices.
- Enhanced Python bindings in `Module.cpp` for device management.
- Added tests for invalid device index handling in `test_device.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165897
Approved by: https://github.com/fffrog
2025-11-14 03:08:23 +00:00
5b1e112cf9 [Dynamo] Imporve-graph-break-skip-logs (#167067)
Fixes #150477

### Summary:

- Added frame information (function name, file, line number) to all graph break/skip messages
- Standardized message format: "torch.compile will skip tracing the frame <name> (<file> line <N>) and fall back to eager. Reason: <reason>"

### Impacts:
module: dynamo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167067
Approved by: https://github.com/williamwen42
2025-11-14 03:06:37 +00:00
5e6ac5c6e1 [Pytorch] Improve conversion to bfloat16 on aarch64/NEON (#166958)
Summary:
Autovectorization of casting to bfloat16_t is broken in clang-[17, 20], fixed in clang-21.

We are adding a workaround vectorized code, which improves conversion speed from smaller int data types.

We've observed the following performance improvements, when compiling with clang-19 and targeting armv9a+sve2:

before:

uint8->bfloat16_t  ===> 319.433us
int8->bfloat16_t  ===> 320.216us
int16->bfloat16_t  ===> 326.899us
int32->bfloat16_t  ===> 327.925us

after:

uint8->bfloat16_t  ===> 185.189us  -----> 72% higher throughput
int8->bfloat16_t  ===> 169.790us  -----> 89% higher throughput
int16->bfloat16_t  ===> 180.744us  -----> 81% higher throughput
int32->bfloat16_t  ===> 185.129us  -----> 77% higher throughput

Test Plan:
Correctness:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Performance:

buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test

Differential Revision: D86207189

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166958
Approved by: https://github.com/mcfi
2025-11-14 02:40:08 +00:00
79317dc7a7 Fix no source name in backward kernel names; Add flex_attention HOP to "original_aten" node meta (#167749)
Fixes #167706

- Add `torch.fx.experimental.proxy_tensor.set_original_aten_op()` around flex_atention HOP dispatch so we have `original_aten` populated for flex_attention
- Update the usages of `original_aten` to also expect HOP in addition to OpOverload

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167749
Approved by: https://github.com/drisspg
2025-11-14 02:24:22 +00:00
96a4c4b3d1 add device generalization support for distributed tests (#165067)
## MOTIVATION
To generalize Distributed test cases for non-CUDA devices

## CHANGES
- Replaced hard coded device/backends with torch.accelerator.current_accelerator() and dist.get_default_backend_for_device
- Use DistributedTestBase instead of MultiProcessTestCase to use common utilities
- Remove instantiate_device_tests and make use of torch.accelerator.current_accelerator for test/distributed/test_c10d_object_collectives.py
- fix deterministic context issue for non-cuda devices in test/distributed/optim/test_zero_redundancy_optimizer.py
- use torch.accelerator.device_count() for multi-gpu check in torch/testing/_internal/distributed/_tensor/common_dtensor.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165067
Approved by: https://github.com/guangyey, https://github.com/albanD
2025-11-14 02:21:11 +00:00
05bcfcc5d1 [Profiler] Add Documentation for FunctionEvent (#167688)
Summary:
Adds documentation for EventList, FunctionEvent and FunctionEventAvg.

Closes https://github.com/pytorch/pytorch/issues/165907

Test Plan: N/A Documentation

Differential Revision: D86913697

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167688
Approved by: https://github.com/sanrise
2025-11-14 02:03:19 +00:00
8cf0bdde45 [xpu][fix] Fix conv1d precision error (#162944)
Currently, conv1d converts the 3D view to 4D before calling onednn::convolution().
However, this function converts the 4D tensor to a channel-last memory format for computation, resulting in incorrect return results (the correct result should be channel-first).
This PR fixes this issue, ensuring that the output return value format is consistent with the expected format.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162944
Approved by: https://github.com/EikanWang
2025-11-14 01:12:21 +00:00
813e5eae9b [fx, 3.14] fix assert detection for 3.14 (#167700)
Failing test was `pytest test/export/test_export.py -k test_python_asserts_with_sym_int`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167700
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #167382, #167383, #167384, #167387, #167396, #167669
2025-11-14 01:00:43 +00:00
2ef236e3e3 [3.14, jit] skip jit tests on 3.14+, add jit deprecation warnings to user-facing API (#167669)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167669
Approved by: https://github.com/malfet, https://github.com/atalman
ghstack dependencies: #167382, #167383, #167384, #167387, #167396
2025-11-14 01:00:43 +00:00
532389fe9e [torchelastic] Add flush option to TailLog (#167169)
Differential Revision: D86366889

This PR adds the `flush` option to `TailLog`, and it will automatically flush (by setting `buffering=1`) the files opened by that `TailLog` instance.

This is mainly to resolve the race condition between the default flushing of `TailLog` and where we read the duplicated error files in the termination handler.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167169
Approved by: https://github.com/fduwjj
2025-11-14 00:21:26 +00:00
08de54f1ea [3.14] Skip failing spherical_bessel_j0 tests (#167691)
Starting with scipy 1.15, bool inputs error out.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167691
Approved by: https://github.com/williamwen42
2025-11-14 00:06:42 +00:00
323 changed files with 4854 additions and 1639 deletions

View File

@ -402,3 +402,6 @@ scikit-build==0.18.1
pyre-extensions==0.0.32
tabulate==0.9.0
#Description: These package are needed to build FBGEMM and torchrec on PyTorch CI
Flask==3.1.2
#Description: required for torch.distributed.debug

View File

@ -100,337 +100,6 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
)
def _compile_and_extract_symbols(
cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
) -> list[str]:
"""
Helper to compile a C++ file and extract all symbols.
Args:
cpp_content: C++ source code to compile
compile_flags: Compilation flags
exclude_list: List of symbol names to exclude. Defaults to ["main"].
Returns:
List of all symbols found in the object file (excluding those in exclude_list).
"""
import subprocess
import tempfile
if exclude_list is None:
exclude_list = ["main"]
with tempfile.TemporaryDirectory() as tmpdir:
tmppath = Path(tmpdir)
cpp_file = tmppath / "test.cpp"
obj_file = tmppath / "test.o"
cpp_file.write_text(cpp_content)
result = subprocess.run(
compile_flags + [str(cpp_file), "-o", str(obj_file)],
capture_output=True,
text=True,
timeout=60,
)
if result.returncode != 0:
raise RuntimeError(f"Compilation failed: {result.stderr}")
symbols = get_symbols(str(obj_file))
# Return all symbol names, excluding those in the exclude list
return [name for _addr, _stype, name in symbols if name not in exclude_list]
def check_stable_only_symbols(install_root: Path) -> None:
"""
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
This approach tests:
1. WITHOUT macros -> many torch symbols exposed
2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
4. WITH both macros -> zero torch symbols (all hidden)
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
test_cpp_content = """
// Main torch C++ API headers
#include <torch/torch.h>
#include <torch/all.h>
// ATen tensor library
#include <ATen/ATen.h>
// Core c10 headers (commonly used)
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Optional.h>
int main() { return 0; }
"""
base_compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c", # Compile only, don't link
]
# Compile WITHOUT any macros
symbols_without = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=base_compile_flags,
)
# We expect constexpr symbols, inline functions used by other headers etc.
# to produce symbols
num_symbols_without = len(symbols_without)
print(f"Found {num_symbols_without} symbols without any macros defined")
assert num_symbols_without != 0, (
"Expected a non-zero number of symbols without any macros"
)
# Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
symbols_with_stable_only = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_stable_only,
)
num_symbols_with_stable_only = len(symbols_with_stable_only)
assert num_symbols_with_stable_only == 0, (
f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
)
# Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
compile_flags_with_target_version = base_compile_flags + [
"-DTORCH_TARGET_VERSION=1"
]
symbols_with_target_version = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_target_version,
)
num_symbols_with_target_version = len(symbols_with_target_version)
assert num_symbols_with_target_version == 0, (
f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
)
# Compile WITH both macros (expect 0 symbols)
compile_flags_with_both = base_compile_flags + [
"-DTORCH_STABLE_ONLY",
"-DTORCH_TARGET_VERSION=1",
]
symbols_with_both = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_both,
)
num_symbols_with_both = len(symbols_with_both)
assert num_symbols_with_both == 0, (
f"Expected no symbols with both macros, but found {num_symbols_with_both}"
)
def check_stable_api_symbols(install_root: Path) -> None:
"""
Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
stable_dir = include_dir / "torch" / "csrc" / "stable"
assert stable_dir.exists(), f"Expected {stable_dir} to be present"
stable_headers = list(stable_dir.rglob("*.h"))
if not stable_headers:
raise RuntimeError("Could not find any stable headers")
includes = []
for header in stable_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_stable_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable = _compile_and_extract_symbols(
cpp_content=test_stable_content,
compile_flags=compile_flags,
)
num_symbols_stable = len(symbols_stable)
print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
assert num_symbols_stable > 0, (
f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable} symbols"
)
def check_headeronly_symbols(install_root: Path) -> None:
"""
Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Find all headers in torch/headeronly
headeronly_dir = include_dir / "torch" / "headeronly"
assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
headeronly_headers = list(headeronly_dir.rglob("*.h"))
if not headeronly_headers:
raise RuntimeError("Could not find any headeronly headers")
# Filter out platform-specific headers that may not compile everywhere
platform_specific_keywords = [
"cpu/vec",
]
filtered_headers = []
for header in headeronly_headers:
rel_path = header.relative_to(include_dir).as_posix()
if not any(
keyword in rel_path.lower() for keyword in platform_specific_keywords
):
filtered_headers.append(header)
includes = []
for header in filtered_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_headeronly_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_headeronly = _compile_and_extract_symbols(
cpp_content=test_headeronly_content,
compile_flags=compile_flags,
)
num_symbols_headeronly = len(symbols_headeronly)
print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
assert num_symbols_headeronly > 0, (
f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_headeronly} symbols"
)
def check_aoti_shim_symbols(install_root: Path) -> None:
"""
Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_shim_content = """
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
int main() {
int32_t (*fp1)() = &aoti_torch_device_type_cpu;
int32_t (*fp2)() = &aoti_torch_dtype_float32;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_shim = _compile_and_extract_symbols(
cpp_content=test_shim_content,
compile_flags=compile_flags,
)
num_symbols_shim = len(symbols_shim)
assert num_symbols_shim > 0, (
f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_shim} symbols"
)
def check_stable_c_shim_symbols(install_root: Path) -> None:
"""
Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Check if the stable C shim exists
stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
if not stable_shim.exists():
raise RuntimeError("Could not find stable c shim")
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_stable_shim_content = """
#include <torch/csrc/stable/c/shim.h>
int main() {
// Reference stable C API functions to create undefined symbols
AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable_shim = _compile_and_extract_symbols(
cpp_content=test_stable_shim_content,
compile_flags=compile_flags,
)
num_symbols_stable_shim = len(symbols_stable_shim)
assert num_symbols_stable_shim > 0, (
f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable_shim} symbols"
)
def check_lib_symbols_for_abi_correctness(lib: str) -> None:
print(f"lib: {lib}")
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
@ -460,13 +129,6 @@ def main() -> None:
check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)
# Check symbols when TORCH_STABLE_ONLY is defined
check_stable_only_symbols(install_root)
check_stable_api_symbols(install_root)
check_headeronly_symbols(install_root)
check_aoti_shim_symbols(install_root)
check_stable_c_shim_symbols(install_root)
if __name__ == "__main__":
main()

330
.spin/cmds.py Normal file
View File

@ -0,0 +1,330 @@
import hashlib
import subprocess
import sys
from pathlib import Path
import click
import spin
def file_digest(file, algorithm: str):
try:
return hashlib.file_digest(file, algorithm)
except AttributeError:
pass # Fallback to manual implementation below
hash = hashlib.new(algorithm)
while chunk := file.read(8192):
hash.update(chunk)
return hash
def _hash_file(file):
with open(file, "rb") as f:
hash = file_digest(f, "sha256")
return hash.hexdigest()
def _hash_files(files):
hashes = {file: _hash_file(file) for file in files}
return hashes
def _read_hashes(hash_file: Path):
if not hash_file.exists():
return {}
with hash_file.open("r") as f:
lines = f.readlines()
hashes = {}
for line in lines:
hash = line[:64]
file = line[66:].strip()
hashes[file] = hash
return hashes
def _updated_hashes(hash_file, files_to_hash):
old_hashes = _read_hashes(hash_file)
new_hashes = _hash_files(files_to_hash)
if new_hashes != old_hashes:
return new_hashes
return None
@click.command()
def regenerate_version():
"""Regenerate version.py."""
cmd = [
sys.executable,
"-m",
"tools.generate_torch_version",
"--is-debug=false",
]
spin.util.run(cmd)
TYPE_STUBS = [
(
"Pytorch type stubs",
Path(".lintbin/.pytorch-type-stubs.sha256"),
[
"aten/src/ATen/native/native_functions.yaml",
"aten/src/ATen/native/tags.yaml",
"tools/autograd/deprecated.yaml",
],
[
sys.executable,
"-m",
"tools.pyi.gen_pyi",
"--native-functions-path",
"aten/src/ATen/native/native_functions.yaml",
"--tags-path",
"aten/src/ATen/native/tags.yaml",
"--deprecated-functions-path",
"tools/autograd/deprecated.yaml",
],
),
(
"Datapipes type stubs",
None,
[],
[
sys.executable,
"torch/utils/data/datapipes/gen_pyi.py",
],
),
]
@click.command()
def regenerate_type_stubs():
"""Regenerate type stubs."""
for name, hash_file, files_to_hash, cmd in TYPE_STUBS:
if hash_file:
if hashes := _updated_hashes(hash_file, files_to_hash):
click.echo(
f"Changes detected in type stub files for {name}. Regenerating..."
)
spin.util.run(cmd)
hash_file.parent.mkdir(parents=True, exist_ok=True)
with hash_file.open("w") as f:
for file, hash in hashes.items():
f.write(f"{hash} {file}\n")
click.echo("Type stubs and hashes updated.")
else:
click.echo(f"No changes detected in type stub files for {name}.")
else:
click.echo(f"No hash file for {name}. Regenerating...")
spin.util.run(cmd)
click.echo("Type stubs regenerated.")
@click.command()
def regenerate_clangtidy_files():
"""Regenerate clang-tidy files."""
cmd = [
sys.executable,
"-m",
"tools.linter.clang_tidy.generate_build_files",
]
spin.util.run(cmd)
#: These linters are expected to need less than 3s cpu time total
VERY_FAST_LINTERS = {
"ATEN_CPU_GPU_AGNOSTIC",
"BAZEL_LINTER",
"C10_NODISCARD",
"C10_UNUSED",
"CALL_ONCE",
"CMAKE_MINIMUM_REQUIRED",
"CONTEXT_DECORATOR",
"COPYRIGHT",
"CUBINCLUDE",
"DEPLOY_DETECTION",
"ERROR_PRONE_ISINSTANCE",
"EXEC",
"HEADER_ONLY_LINTER",
"IMPORT_LINTER",
"INCLUDE",
"LINTRUNNER_VERSION",
"MERGE_CONFLICTLESS_CSV",
"META_NO_CREATE_UNBACKED",
"NEWLINE",
"NOQA",
"NO_WORKFLOWS_ON_FORK",
"ONCE_FLAG",
"PYBIND11_INCLUDE",
"PYBIND11_SPECIALIZATION",
"PYPIDEP",
"PYPROJECT",
"RAWCUDA",
"RAWCUDADEVICE",
"ROOT_LOGGING",
"TABS",
"TESTOWNERS",
"TYPEIGNORE",
"TYPENOSKIP",
"WORKFLOWSYNC",
}
#: These linters are expected to take a few seconds, but less than 10s cpu time total
FAST_LINTERS = {
"CMAKE",
"DOCSTRING_LINTER",
"GHA",
"NATIVEFUNCTIONS",
"RUFF",
"SET_LINTER",
"SHELLCHECK",
"SPACES",
}
#: These linters are expected to take more than 10s cpu time total;
#: some need more than 1 hour.
SLOW_LINTERS = {
"ACTIONLINT",
"CLANGFORMAT",
"CLANGTIDY",
"CODESPELL",
"FLAKE8",
"GB_REGISTRY",
"PYFMT",
"PYREFLY",
"TEST_DEVICE_BIAS",
"TEST_HAS_MAIN",
}
ALL_LINTERS = VERY_FAST_LINTERS | FAST_LINTERS | SLOW_LINTERS
LINTRUNNER_CACHE_INFO = (
Path(".lintbin/.lintrunner.sha256"),
[
"requirements.txt",
"pyproject.toml",
".lintrunner.toml",
],
)
LINTRUNNER_BASE_CMD = [
"uvx",
"--python",
"3.10",
"lintrunner@0.12.7",
]
@click.command()
def setup_lint():
"""Set up lintrunner with current CI version."""
cmd = LINTRUNNER_BASE_CMD + ["init"]
subprocess.run(cmd, check=True, capture_output=True, text=True)
def _check_linters():
cmd = LINTRUNNER_BASE_CMD + ["list"]
ret = spin.util.run(cmd, output=False, stderr=subprocess.PIPE)
linters = {l.strip() for l in ret.stdout.decode().strip().split("\n")[1:]}
unknown_linters = linters - ALL_LINTERS
missing_linters = ALL_LINTERS - linters
if unknown_linters:
click.secho(
f"Unknown linters found; please add them to the correct category "
f"in .spin/cmds.py: {', '.join(unknown_linters)}",
fg="yellow",
)
if missing_linters:
click.secho(
f"Missing linters found; please update the corresponding category "
f"in .spin/cmds.py: {', '.join(missing_linters)}",
fg="yellow",
)
return unknown_linters, missing_linters
@spin.util.extend_command(
setup_lint,
doc=f"""
If configuration has changed, update lintrunner.
Compares the stored old hashes of configuration files with new ones and
performs setup via setup-lint if the hashes have changed.
Hashes are stored in {LINTRUNNER_CACHE_INFO[0]}; the following files are
considered: {", ".join(LINTRUNNER_CACHE_INFO[1])}.
""",
)
@click.pass_context
def lazy_setup_lint(ctx, parent_callback, **kwargs):
if hashes := _updated_hashes(*LINTRUNNER_CACHE_INFO):
click.echo(
"Changes detected in lint configuration files. Setting up linting tools..."
)
parent_callback(**kwargs)
hash_file = LINTRUNNER_CACHE_INFO[0]
hash_file.parent.mkdir(parents=True, exist_ok=True)
with hash_file.open("w") as f:
for file, hash in hashes.items():
f.write(f"{hash} {file}\n")
click.echo("Linting tools set up and hashes updated.")
else:
click.echo("No changes detected in lint configuration files. Skipping setup.")
click.echo("Regenerating version...")
ctx.invoke(regenerate_version)
click.echo("Regenerating type stubs...")
ctx.invoke(regenerate_type_stubs)
click.echo("Done.")
_check_linters()
@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def lint(ctx, apply_patches, **kwargs):
"""Lint all files."""
ctx.invoke(lazy_setup_lint)
all_files_linters = VERY_FAST_LINTERS | FAST_LINTERS
changed_files_linters = SLOW_LINTERS
cmd = LINTRUNNER_BASE_CMD
if apply_patches:
cmd += ["--apply-patches"]
all_files_cmd = cmd + [
"--take",
",".join(all_files_linters),
"--all-files",
]
spin.util.run(all_files_cmd)
changed_files_cmd = cmd + [
"--take",
",".join(changed_files_linters),
]
spin.util.run(changed_files_cmd)
@click.command()
@click.pass_context
def fixlint(ctx, **kwargs):
"""Autofix all files."""
ctx.invoke(lint, apply_patches=True)
@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def quicklint(ctx, apply_patches, **kwargs):
"""Lint changed files."""
ctx.invoke(lazy_setup_lint)
cmd = LINTRUNNER_BASE_CMD
if apply_patches:
cmd += ["--apply-patches"]
spin.util.run(cmd)
@click.command()
@click.pass_context
def quickfix(ctx, **kwargs):
"""Autofix changed files."""
ctx.invoke(quicklint, apply_patches=True)

View File

@ -144,7 +144,7 @@ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
}
inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ')';
return out;
}

View File

@ -9,7 +9,7 @@ namespace indexing {
const EllipsisIndexType Ellipsis = EllipsisIndexType();
std::ostream& operator<<(std::ostream& stream, const Slice& slice) {
stream << slice.start() << ":" << slice.stop() << ":" << slice.step();
stream << slice.start() << ':' << slice.stop() << ':' << slice.step();
return stream;
}
@ -31,12 +31,12 @@ std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index)
}
std::ostream& operator<<(std::ostream& stream, const std::vector<TensorIndex>& tensor_indices) {
stream << "(";
stream << '(';
for (const auto i : c10::irange(tensor_indices.size())) {
stream << tensor_indices[i];
if (i < tensor_indices.size() - 1) stream << ", ";
}
stream << ")";
stream << ')';
return stream;
}

View File

@ -113,7 +113,7 @@ void TensorNames::checkUnique(const char* op_name) const {
std::ostream& operator<<(std::ostream& out, const TensorName& tensorname) {
out << tensorname.name_ << " (index ";
out << tensorname.origin_idx_ << " of ";
out << tensorname.origin_ << ")";
out << tensorname.origin_ << ')';
return out;
}

View File

@ -13,9 +13,9 @@ std::ostream& operator<<(std::ostream & out, const TensorGeometryArg& t) {
if (t.pos == 0) {
// 0 is distinguished; it usually indicates 'self' or the return
// tensor
out << "'" << t.name << "'";
out << '\'' << t.name << '\'';
} else {
out << "argument #" << t.pos << " '" << t.name << "'";
out << "argument #" << t.pos << " '" << t.name << '\'';
}
return out;
}
@ -154,7 +154,7 @@ void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
oss << "Tensor for " << t2 << " is on CPU, ";
}
oss << "but expected " << ((!t1->is_cpu() && !t2->is_cpu()) ? "them" : "it")
<< " to be on GPU (while checking arguments for " << c << ")";
<< " to be on GPU (while checking arguments for " << c << ')';
TORCH_CHECK(false, oss.str());
}
TORCH_CHECK(
@ -199,7 +199,7 @@ void checkScalarTypes(CheckedFrom c, const TensorArg& t,
i++;
}
oss << "; but got " << t->toString()
<< " instead (while checking arguments for " << c << ")";
<< " instead (while checking arguments for " << c << ')';
TORCH_CHECK(false, oss.str());
}
}

View File

@ -43,8 +43,8 @@ std::string get_mkldnn_version() {
// https://github.com/intel/ideep/issues/29
{
const dnnl_version_t* ver = dnnl_version();
ss << "Intel(R) MKL-DNN v" << ver->major << "." << ver->minor << "." << ver->patch
<< " (Git Hash " << ver->hash << ")";
ss << "Intel(R) MKL-DNN v" << ver->major << '.' << ver->minor << '.' << ver->patch
<< " (Git Hash " << ver->hash << ')';
}
#else
ss << "MKLDNN not found";
@ -81,7 +81,7 @@ std::string get_openmp_version() {
break;
}
if (ver_str) {
ss << " (a.k.a. OpenMP " << ver_str << ")";
ss << " (a.k.a. OpenMP " << ver_str << ')';
}
}
#else
@ -135,7 +135,7 @@ std::string show_config() {
#if defined(__GNUC__)
{
ss << " - GCC " << __GNUC__ << "." << __GNUC_MINOR__ << "\n";
ss << " - GCC " << __GNUC__ << '.' << __GNUC_MINOR__ << "\n";
}
#endif
@ -147,7 +147,7 @@ std::string show_config() {
#if defined(__clang_major__)
{
ss << " - clang " << __clang_major__ << "." << __clang_minor__ << "." << __clang_patchlevel__ << "\n";
ss << " - clang " << __clang_major__ << '.' << __clang_minor__ << '.' << __clang_patchlevel__ << "\n";
}
#endif
@ -200,7 +200,7 @@ std::string show_config() {
ss << " - Build settings: ";
for (const auto& pair : caffe2::GetBuildOptions()) {
if (!pair.second.empty()) {
ss << pair.first << "=" << pair.second << ", ";
ss << pair.first << '=' << pair.second << ", ";
}
}
ss << "\n";

View File

@ -209,7 +209,7 @@ struct CodeTemplate {
// to indent correctly in the context.
void emitIndent(std::ostream& out, size_t indent) const {
for ([[maybe_unused]] const auto i : c10::irange(indent)) {
out << " ";
out << ' ';
}
}
void emitStringWithIndents(

View File

@ -10,7 +10,7 @@ std::ostream& operator<<(std::ostream& out, const Dimname& dimname) {
if (dimname.type() == NameType::WILDCARD) {
out << "None";
} else {
out << "'" << dimname.symbol().toUnqualString() << "'";
out << '\'' << dimname.symbol().toUnqualString() << '\'';
}
return out;
}

View File

@ -5,7 +5,7 @@
namespace at {
std::ostream& operator<<(std::ostream& out, const Range& range) {
out << "Range[" << range.begin << ", " << range.end << "]";
out << "Range[" << range.begin << ", " << range.end << ']';
return out;
}

View File

@ -71,7 +71,7 @@ void TensorBase::enforce_invariants() {
void TensorBase::print() const {
if (defined()) {
std::cerr << "[" << toString() << " " << sizes() << "]" << '\n';
std::cerr << '[' << toString() << ' ' << sizes() << ']' << '\n';
} else {
std::cerr << "[UndefinedTensor]" << '\n';
}

View File

@ -9,7 +9,7 @@ APIVitals VitalsAPI;
std::ostream& operator<<(std::ostream& os, TorchVital const& tv) {
for (const auto& m : tv.attrs) {
os << "[TORCH_VITAL] " << tv.name << "." << m.first << "\t\t "
os << "[TORCH_VITAL] " << tv.name << '.' << m.first << "\t\t "
<< m.second.value << "\n";
}
return os;

View File

@ -100,18 +100,18 @@ inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) {
// this does match the way things are represented in the schema
inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
out << "(";
out << '(';
bool first = true;
for (const auto& set : aliasInfo.beforeSets()) {
if (first) {
first = false;
} else {
out << "|";
out << '|';
}
out << set.toUnqualString();
}
if (aliasInfo.isWrite()) {
out << "!";
out << '!';
}
if (aliasInfo.beforeSets() != aliasInfo.afterSets()) {
out << " -> ";
@ -120,12 +120,12 @@ inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
if (first) {
first = false;
} else {
out << "|";
out << '|';
}
out << set.toUnqualString();
}
}
out << ")";
out << ')';
return out;
}
} // namespace c10

View File

@ -198,7 +198,7 @@ inline void swap(Blob& lhs, Blob& rhs) noexcept {
}
inline std::ostream& operator<<(std::ostream& out, const Blob& v) {
return out << "Blob[" << v.TypeName() << "]";
return out << "Blob[" << v.TypeName() << ']';
}
} // namespace caffe2

View File

@ -100,7 +100,7 @@ struct TORCH_API ClassType : public NamedType {
std::string repr_str() const override {
std::stringstream ss;
ss << str()
<< " (of Python compilation unit at: " << compilation_unit().get() << ")";
<< " (of Python compilation unit at: " << compilation_unit().get() << ')';
return ss.str();
}

View File

@ -58,12 +58,12 @@ std::string DispatchKeyExtractor::dumpState() const {
std::ostringstream oss;
for (const auto i : c10::irange(c10::utils::bitset::NUM_BITS())) {
if (dispatch_arg_indices_reverse_.get(i)) {
oss << "1";
oss << '1';
} else {
oss << "0";
oss << '0';
}
}
oss << " " << nonFallthroughKeys_ << "\n";
oss << ' ' << nonFallthroughKeys_ << '\n';
return oss.str();
}

View File

@ -69,8 +69,8 @@ private:
void _print_dispatch_trace(const std::string& label, const std::string& op_name, const DispatchKeySet& dispatchKeySet) {
auto nesting_value = dispatch_trace_nesting_value();
for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
std::cerr << label << " op=[" << op_name << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
for (int64_t i = 0; i < nesting_value; ++i) std::cerr << ' ';
std::cerr << label << " op=[" << op_name << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << ']' << std::endl;
}
} // namespace detail

View File

@ -570,7 +570,7 @@ void OperatorEntry::checkInvariants() const {
std::string OperatorEntry::listAllDispatchKeys() const {
std::ostringstream str;
str << "[";
str << '[';
bool has_kernels = false;
for (auto k : allDispatchKeysInFullSet()) {
@ -584,7 +584,7 @@ std::string OperatorEntry::listAllDispatchKeys() const {
str << k;
has_kernels = true;
}
str << "]";
str << ']';
return str.str();
}

View File

@ -210,9 +210,9 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
out << schema.name();
if (!schema.overload_name().empty()) {
out << "." << schema.overload_name();
out << '.' << schema.overload_name();
}
out << "(";
out << '(';
bool seen_kwarg_only = false;
for (const auto i : c10::irange(schema.arguments().size())) {
@ -273,7 +273,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
}
if (need_paren) {
out << "(";
out << '(';
}
for (const auto i : c10::irange(returns.size())) {
if (i > 0) {
@ -288,7 +288,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
out << "...";
}
if (need_paren) {
out << ")";
out << ')';
}
return out;
}
@ -471,7 +471,7 @@ bool FunctionSchema::isForwardCompatibleWith(
if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) {
if (why_not) {
why_not
<< "'" << arguments().at(i).name() << "'"
<< '\'' << arguments().at(i).name() << '\''
<< " is not forward compatible with the older version of the schema";
}
return false;
@ -511,7 +511,7 @@ bool FunctionSchema::isForwardCompatibleWith(
.isForwardCompatibleWith(old.arguments().at(i))) {
if (why_not) {
why_not << "Out argument '"
<< "'" << arguments().at(i).name()
<< '\'' << arguments().at(i).name()
<< " is not FC with the older version of the schema";
}
return false;

View File

@ -571,7 +571,7 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
if (arg.N()) {
N = std::to_string(*arg.N());
}
out << "[" << N << "]";
out << '[' << N << ']';
} else {
out << unopt_type->str();
}
@ -582,15 +582,15 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
}
if (is_opt) {
out << "?";
out << '?';
}
if (!arg.name().empty()) {
out << " " << arg.name();
out << ' ' << arg.name();
}
if (arg.default_value()) {
out << "=";
out << '=';
if ((type->kind() == c10::TypeKind::StringType ||
unopt_type->kind() == c10::TypeKind::StringType) &&
arg.default_value().value().isString()) {

View File

@ -66,7 +66,7 @@ bool operator==(const ivalue::Tuple& lhs, const ivalue::Tuple& rhs) {
}
std::ostream& operator<<(std::ostream& out, const ivalue::EnumHolder& v) {
out << v.qualifiedClassName() << "." << v.name();
out << v.qualifiedClassName() << '.' << v.name();
return out;
}
@ -526,7 +526,7 @@ std::ostream& printMaybeAnnotatedList(
!elementTypeCanBeInferredFromMembers(list_elem_type)) {
out << "annotate(" << the_list.type<c10::Type>()->annotation_str() << ", ";
printList(out, the_list.toListRef(), "[", "]", formatter);
out << ")";
out << ')';
return out;
} else {
return printList(out, the_list.toListRef(), "[", "]", formatter);
@ -538,7 +538,7 @@ std::ostream& printDict(
std::ostream& out,
const Dict& v,
const IValueFormatter& formatter) {
out << "{";
out << '{';
bool first = true;
for (const auto& pair : v) {
@ -552,7 +552,7 @@ std::ostream& printDict(
first = false;
}
out << "}";
out << '}';
return out;
}
}
@ -565,8 +565,8 @@ static std::ostream& printMaybeAnnotatedDict(
auto value_type = the_dict.type()->castRaw<DictType>()->getValueType();
if (the_dict.toGenericDict().empty() ||
!elementTypeCanBeInferredFromMembers(value_type)) {
out << "annotate(" << the_dict.type<c10::Type>()->annotation_str() << ",";
printDict(out, the_dict.toGenericDict(), formatter) << ")";
out << "annotate(" << the_dict.type<c10::Type>()->annotation_str() << ',';
printDict(out, the_dict.toGenericDict(), formatter) << ')';
} else {
return printDict(out, the_dict.toGenericDict(), formatter);
}
@ -577,7 +577,7 @@ static std::ostream& printComplex(std::ostream & out, const IValue & v) {
c10::complex<double> d = v.toComplexDouble();
IValue real(d.real()), imag(std::abs(d.imag()));
auto sign = d.imag() >= 0 ? '+' : '-';
return out << real << sign << imag << "j";
return out << real << sign << imag << 'j';
}
std::ostream& IValue::repr(
@ -605,9 +605,9 @@ std::ostream& IValue::repr(
if (static_cast<double>(i) == d) {
// -0.0 (signed zero) needs to be parsed as -0.
if (i == 0 && std::signbit(d)) {
return out << "-" << i << ".";
return out << '-' << i << '.';
}
return out << i << ".";
return out << i << '.';
}
}
auto orig_prec = out.precision();
@ -643,20 +643,20 @@ std::ostream& IValue::repr(
device_stream << v.toDevice();
out << "torch.device(";
c10::printQuotedString(out, device_stream.str());
return out << ")";
return out << ')';
}
case IValue::Tag::Generator: {
auto generator = v.toGenerator();
out << "torch.Generator(device=";
c10::printQuotedString(out, generator.device().str());
out << ", seed=" << generator.current_seed() << ")";
out << ", seed=" << generator.current_seed() << ')';
return out;
}
case IValue::Tag::GenericDict:
return printMaybeAnnotatedDict(out, v, formatter);
case IValue::Tag::Enum: {
auto enum_holder = v.toEnumHolder();
return out << enum_holder->qualifiedClassName() << "." <<
return out << enum_holder->qualifiedClassName() << '.' <<
enum_holder->name();
}
case IValue::Tag::Object: {
@ -801,7 +801,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
if (c == FP_NORMAL || c == FP_ZERO) {
int64_t i = static_cast<int64_t>(d);
if (static_cast<double>(i) == d) {
return out << i << ".";
return out << i << '.';
}
}
auto orig_prec = out.precision();
@ -852,7 +852,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
return printDict(out, v.toGenericDict(), formatter);
case IValue::Tag::PyObject: {
auto py_obj = v.toPyObject();
return out << "<PyObject at" << py_obj << ">";
return out << "<PyObject at" << py_obj << '>';
}
case IValue::Tag::Generator:
return out << "Generator";
@ -862,16 +862,16 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
// TODO we should attempt to call __str__ if the object defines it.
auto obj = v.toObject();
// print this out the way python would do it
return out << "<" << obj->name() << " object at " << obj.get() << ">";
return out << '<' << obj->name() << " object at " << obj.get() << '>';
}
case IValue::Tag::Enum: {
auto enum_holder = v.toEnumHolder();
return out << "Enum<" << enum_holder->unqualifiedClassName() << "." <<
enum_holder->name() << ">";
return out << "Enum<" << enum_holder->unqualifiedClassName() << '.' <<
enum_holder->name() << '>';
}
}
return out << "<Invalid IValue tag=" << std::to_string(static_cast<uint32_t>(v.tag)) << ">";
return out << "<Invalid IValue tag=" << std::to_string(static_cast<uint32_t>(v.tag)) << '>';
}
#undef TORCH_FORALL_TAGS
@ -1050,7 +1050,7 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
std::stringstream err;
err << "Cannot serialize custom bound C++ class";
if (auto qualname = type()->name()) {
err << " " << qualname->qualifiedName();
err << ' ' << qualname->qualifiedName();
}
err << ". Please define serialization methods via def_pickle() for "
"this class.";

View File

@ -211,7 +211,7 @@ struct TORCH_API OptionalType : public UnionType {
std::string str() const override {
std::stringstream ss;
ss << getElementType()->str() << "?";
ss << getElementType()->str() << '?';
return ss.str();
}
@ -240,7 +240,7 @@ struct TORCH_API OptionalType : public UnionType {
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
std::stringstream ss;
ss << "Optional[" << getElementType()->annotation_str(printer) << "]";
ss << "Optional[" << getElementType()->annotation_str(printer) << ']';
return ss.str();
}
};
@ -906,7 +906,7 @@ struct TORCH_API ListType
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
std::stringstream ss;
ss << "List[" << getElementType()->annotation_str(printer) << "]";
ss << "List[" << getElementType()->annotation_str(printer) << ']';
return ss.str();
}
};
@ -946,7 +946,7 @@ struct TORCH_API DictType : public SharedType {
std::string str() const override {
std::stringstream ss;
ss << "Dict(" << getKeyType()->str() << ", " << getValueType()->str()
<< ")";
<< ')';
return ss.str();
}
@ -1018,7 +1018,7 @@ struct TORCH_API FutureType
std::string str() const override {
std::stringstream ss;
ss << "Future(" << getElementType()->str() << ")";
ss << "Future(" << getElementType()->str() << ')';
return ss.str();
}
TypePtr createWithContained(
@ -1041,7 +1041,7 @@ struct TORCH_API FutureType
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
std::stringstream ss;
ss << "Future[" << getElementType()->annotation_str(printer) << "]";
ss << "Future[" << getElementType()->annotation_str(printer) << ']';
return ss.str();
}
};
@ -1060,7 +1060,7 @@ struct TORCH_API AwaitType
std::string str() const override {
std::stringstream ss;
ss << "Await(" << getElementType()->str() << ")";
ss << "Await(" << getElementType()->str() << ')';
return ss.str();
}
TypePtr createWithContained(
@ -1083,7 +1083,7 @@ struct TORCH_API AwaitType
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
std::stringstream ss;
ss << "Await[" << getElementType()->annotation_str(printer) << "]";
ss << "Await[" << getElementType()->annotation_str(printer) << ']';
return ss.str();
}
};
@ -1102,7 +1102,7 @@ struct TORCH_API RRefType
std::string str() const override {
std::stringstream ss;
ss << "RRef(" << getElementType()->str() << ")";
ss << "RRef(" << getElementType()->str() << ')';
return ss.str();
}
TypePtr createWithContained(
@ -1115,7 +1115,7 @@ struct TORCH_API RRefType
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
std::stringstream ss;
ss << "RRef[" << getElementType()->annotation_str(printer) << "]";
ss << "RRef[" << getElementType()->annotation_str(printer) << ']';
return ss.str();
}
};

View File

@ -11,7 +11,7 @@ std::string toString(const OperatorName& opName) {
std::ostream& operator<<(std::ostream& os, const OperatorName& opName) {
os << opName.name;
if (!opName.overload_name.empty()) {
os << "." << opName.overload_name;
os << '.' << opName.overload_name;
}
return os;
}

View File

@ -65,7 +65,7 @@ VaryingShape<T> VaryingShape<T>::merge(const VaryingShape<T>& other) const {
template <typename T>
std::ostream& operator<<(std::ostream& out, const VaryingShape<T>& vs) {
out << "(";
out << '(';
if (!vs.size()) {
out << "*)";
return out;
@ -79,10 +79,10 @@ std::ostream& operator<<(std::ostream& out, const VaryingShape<T>& vs) {
if (v.has_value()) {
out << v.value();
} else {
out << "*";
out << '*';
}
}
out << ")";
out << ')';
return out;
}
@ -105,7 +105,7 @@ std::ostream& operator<<(
}
auto sizes_opt = ss.sizes();
os << "(";
os << '(';
for (size_t i = 0; i < rank_opt.value(); i++) {
if (i > 0) {
os << ", ";
@ -113,10 +113,10 @@ std::ostream& operator<<(
if(sizes_opt.has_value() && sizes_opt.value()[i].is_static()) {
os << sizes_opt.value()[i];
} else {
os << "*";
os << '*';
}
}
os << ")";
os << ')';
return os;
}
@ -131,17 +131,17 @@ std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s) {
}
std::ostream& operator<<(std::ostream& os, const Stride& s) {
os << "{";
os << '{';
if (s.stride_index_.has_value()) {
os << *s.stride_index_;
} else {
os << "*";
os << '*';
}
os << ":";
os << ':';
if (s.stride_.has_value()) {
os << *s.stride_;
} else {
os << "*";
os << '*';
}
os << '}';
return os;

View File

@ -67,7 +67,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
bool has_valid_strides_info = ndim > 0 &&
value->strides().isComplete() && value->strides().size() == ndim;
out << "(";
out << '(';
size_t i = 0;
bool symbolic = type_verbosity() == TypeVerbosity::Symbolic;
for (i = 0; i < *ndim; ++i) {
@ -79,7 +79,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
} else if (symbolic) {
out << value->symbolic_sizes().at(i);
} else {
out << "*";
out << '*';
}
}
if (has_valid_strides_info &&
@ -91,7 +91,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
}
out << value->strides()[i].value();
}
out << "]";
out << ']';
}
if (type_verbosity() >= TypeVerbosity::Full) {
if (value->requiresGrad()) {
@ -107,12 +107,12 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
out << "device=" << *value->device();
}
}
out << ")";
out << ')';
} else {
if (type_verbosity() >= TypeVerbosity::Full) {
size_t i = 0;
if (value->requiresGrad()) {
out << "("
out << '('
<< "requires_grad=" << *value->requiresGrad();
i++;
}
@ -120,7 +120,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
out << ((i++ > 0) ? ", " : "(") << "device=" << *value->device();
}
if (i > 0) {
out << ")";
out << ')';
}
}
}
@ -133,18 +133,18 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
out << *prim << "[]";
} else if (t.kind() == TypeKind::OptionalType) {
auto prim = t.castRaw<OptionalType>()->getElementType();
out << *prim << "?";
out << *prim << '?';
} else if(t.kind() == TypeKind::FutureType) {
auto elem = t.castRaw<FutureType>()->getElementType();
out << "Future[" << *elem << "]";
out << "Future[" << *elem << ']';
} else if(t.kind() == TypeKind::RRefType) {
auto elem = t.castRaw<RRefType>()->getElementType();
out << "RRef[" << *elem << "]";
out << "RRef[" << *elem << ']';
} else if(auto tup = t.cast<TupleType>()) {
if (tup->schema()) {
out << "NamedTuple";
}
out << "(";
out << '(';
for(size_t i = 0; i < tup->elements().size(); ++i) {
if(i > 0)
out << ", ";
@ -160,7 +160,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
out << *(tup->elements()[i]);
}
}
out << ")";
out << ')';
} else if (t.kind() == TypeKind::FunctionType) {
out << "Function";
} else {
@ -475,7 +475,7 @@ std::optional<TypePtr> unifyTypeList(
why_not << "Could not unify type list since element " << i << " of type "
<< elements.at(i)->repr_str()
<< " did not match the types before it ("
<< ret_type->repr_str() << ")";
<< ret_type->repr_str() << ')';
return std::nullopt;
}
ret_type = *maybe_unified;
@ -907,13 +907,13 @@ std::string TupleType::str() const {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
ss << name()->qualifiedName();
} else {
ss << "(";
ss << '(';
for(size_t i = 0; i < elements().size(); ++i) {
if(i > 0)
ss << ", ";
ss << elements()[i]->str();
}
ss << ")";
ss << ')';
}
return ss.str();
}

View File

@ -205,9 +205,9 @@ UnionType::UnionType(std::vector<TypePtr> reference, TypeKind kind) : SharedType
for (const auto i : c10::irange(reference.size())) {
msg << reference[i]->repr_str();
if (i > 0) {
msg << ",";
msg << ',';
}
msg << " ";
msg << ' ';
}
msg << "} has the single type " << types_[0]->repr_str()
<< ". Use the common supertype instead of creating a Union"

View File

@ -223,6 +223,62 @@ CONVERT_FROM_BF16_TEMPLATE(double)
CONVERT_FROM_BF16_TEMPLATE(float16_t)
#endif
#ifdef __ARM_FEATURE_BF16
// clang-[17, 20] crashes when autovectorizing static cast to bf16
// Below is a workaround to have some vectorization
// Works decently well for smaller int types
template <typename from_type>
inline void convertToBf16Impl(
const from_type* __restrict src,
c10::BFloat16* __restrict dst,
uint64_t n) {
bfloat16_t* dstPtr = reinterpret_cast<bfloat16_t*>(dst);
uint64_t loopBound = n - (n % 16);
uint64_t i = 0;
for (; i < loopBound; i += 16) {
float32x4_t a, b, c, d;
a[0] = static_cast<float>(src[i]);
a[1] = static_cast<float>(src[i + 1]);
a[2] = static_cast<float>(src[i + 2]);
a[3] = static_cast<float>(src[i + 3]);
b[0] = static_cast<float>(src[i + 4]);
b[1] = static_cast<float>(src[i + 5]);
b[2] = static_cast<float>(src[i + 6]);
b[3] = static_cast<float>(src[i + 7]);
c[0] = static_cast<float>(src[i + 8]);
c[1] = static_cast<float>(src[i + 9]);
c[2] = static_cast<float>(src[i + 10]);
c[3] = static_cast<float>(src[i + 11]);
d[0] = static_cast<float>(src[i + 12]);
d[1] = static_cast<float>(src[i + 13]);
d[2] = static_cast<float>(src[i + 14]);
d[3] = static_cast<float>(src[i + 15]);
vst1q_bf16(dstPtr + i, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(a), b));
vst1q_bf16(dstPtr + i + 8, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(c), d));
}
#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
for (; i < n; i++) {
float a = static_cast<float>(src[i]);
dstPtr[i] = vcvth_bf16_f32(a);
}
}
#define CONVERT_TO_BF16_TEMPLATE(from_type) \
template <> \
inline void convert(const from_type* src, c10::BFloat16* dst, int64_t n) { \
return convertToBf16Impl<from_type>(src, dst, n); \
}
CONVERT_TO_BF16_TEMPLATE(uint8_t)
CONVERT_TO_BF16_TEMPLATE(int8_t)
CONVERT_TO_BF16_TEMPLATE(int16_t)
CONVERT_TO_BF16_TEMPLATE(int32_t)
#endif
inline void convertBoolToBfloat16Impl(
const bool* __restrict src,
c10::BFloat16* __restrict dst,

View File

@ -80,7 +80,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
}
stream << buf[i];
}
stream << "]";
stream << ']';
return stream;
}

View File

@ -55,7 +55,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
}
stream << buf[i];
}
stream << "]";
stream << ']';
return stream;
}

View File

@ -411,16 +411,16 @@ std::string CUDAHooks::showConfig() const {
// HIP_VERSION value format was changed after ROCm v4.2 to include the patch number
if(v < 500) {
// If major=xx, minor=yy then format -> xxyy
oss << (v / 100) << "." << (v % 10);
oss << (v / 100) << '.' << (v % 10);
}
else {
// If major=xx, minor=yy & patch=zzzzz then format -> xxyyzzzzz
oss << (v / 10000000) << "." << (v / 100000 % 100) << "." << (v % 100000);
oss << (v / 10000000) << '.' << (v / 100000 % 100) << '.' << (v % 100000);
}
#else
oss << (v / 1000) << "." << (v / 10 % 100);
oss << (v / 1000) << '.' << (v / 10 % 100);
if (v % 10 != 0) {
oss << "." << (v % 10);
oss << '.' << (v % 10);
}
#endif
};
@ -448,9 +448,9 @@ std::string CUDAHooks::showConfig() const {
auto printCudnnStyleVersion = [&](size_t v) {
oss << (v / 1000) << "." << (v / 100 % 10);
oss << (v / 1000) << '.' << (v / 100 % 10);
if (v % 100 != 0) {
oss << "." << (v % 100);
oss << '.' << (v % 100);
}
};
@ -461,7 +461,7 @@ std::string CUDAHooks::showConfig() const {
if (cudnnCudartVersion != CUDART_VERSION) {
oss << " (built against CUDA ";
printCudaStyleVersion(cudnnCudartVersion);
oss << ")";
oss << ')';
}
oss << "\n";
if (cudnnVersion != CUDNN_VERSION) {
@ -472,11 +472,11 @@ std::string CUDAHooks::showConfig() const {
#endif
#else
// TODO: Check if miopen has the functions above and unify
oss << " - MIOpen " << MIOPEN_VERSION_MAJOR << "." << MIOPEN_VERSION_MINOR << "." << MIOPEN_VERSION_PATCH << "\n";
oss << " - MIOpen " << MIOPEN_VERSION_MAJOR << '.' << MIOPEN_VERSION_MINOR << '.' << MIOPEN_VERSION_PATCH << "\n";
#endif
#if AT_MAGMA_ENABLED()
oss << " - Magma " << MAGMA_VERSION_MAJOR << "." << MAGMA_VERSION_MINOR << "." << MAGMA_VERSION_MICRO << "\n";
oss << " - Magma " << MAGMA_VERSION_MAJOR << '.' << MAGMA_VERSION_MINOR << '.' << MAGMA_VERSION_MICRO << "\n";
#endif
return oss.str();

View File

@ -42,7 +42,7 @@ static inline void launch_jitted_vectorized_kernel_dynamic(
// The cache key includes all the parameters to generate_code + vec_size + dev_idx
std::stringstream ss;
ss << nInputs << "_" << nOutputs << f;
ss << nInputs << '_' << nOutputs << f;
ss << f_inputs_type_str << compute_type_str << result_type_str;
ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
ss << extra_args_types;
@ -144,7 +144,7 @@ static inline void launch_jitted_unrolled_kernel_dynamic(
// The cache key includes all the parameters to generate_code + dev_idx
std::stringstream ss;
ss << nInputs << "_" << nOutputs << f;
ss << nInputs << '_' << nOutputs << f;
ss << f_inputs_type_str << compute_type_str << result_type_str;
ss << contiguous << dynamic_casting;
ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);

View File

@ -52,10 +52,10 @@ TuningContext* getTuningContext() {
std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry) {
static const bool blaslog = c10::utils::get_env("PYTORCH_TUNABLEOP_BLAS_LOG") == "1";
if (!blaslog) {
return stream << entry.key_ << "," << entry.time_;
return stream << entry.key_ << ',' << entry.time_;
}
else {
return stream << entry.key_ << "," << entry.time_ << ",BLAS_PARAMS: " << entry.blas_sig_;
return stream << entry.key_ << ',' << entry.time_ << ",BLAS_PARAMS: " << entry.blas_sig_;
}
}
@ -156,10 +156,10 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
if (isNew) {
static const bool blaslog = c10::utils::get_env("PYTORCH_TUNABLEOP_BLAS_LOG") == "1";
if (!blaslog) {
untuned_file << op_signature << "," << params_signature << std::endl;
untuned_file << op_signature << ',' << params_signature << std::endl;
}
else {
untuned_file << op_signature << "," << params_signature << ",BLAS_PARAMS: " << blas_signature << std::endl;
untuned_file << op_signature << ',' << params_signature << ",BLAS_PARAMS: " << blas_signature << std::endl;
}
TUNABLE_LOG3("Untuned,", op_signature, ",", params_signature);
}
@ -201,7 +201,7 @@ void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const
if(!file_exists || file_empty) {
for(const auto& [key, val] : validators) {
(*realtime_out_) << "Validator," << key << "," << val << std::endl;
(*realtime_out_) << "Validator," << key << ',' << val << std::endl;
realtime_out_->flush();
}
validators_written_ = true;
@ -219,7 +219,7 @@ void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std
return;
}
(*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
(*realtime_out_) << op_sig << ',' << param_sig << ',' << result << std::endl;
realtime_out_->flush(); //ensure immediate write to disk
TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);

View File

@ -93,7 +93,7 @@ std::string cudnnTypeToString(cudnnDataType_t dtype) {
return "CUDNN_DATA_UINT8x4";
default:
std::ostringstream oss;
oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
oss << "(unknown data-type " << static_cast<int>(dtype) << ')';
return oss.str();
}
}
@ -168,7 +168,7 @@ std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) {
return "CUDNN_TENSOR_NHWC";
default:
std::ostringstream oss;
oss << "(unknown cudnn tensor format " << static_cast<int>(tformat) << ")";
oss << "(unknown cudnn tensor format " << static_cast<int>(tformat) << ')';
return oss.str();
}
}

View File

@ -346,15 +346,15 @@ void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int6
}
std::ostream& operator<< (std::ostream& os, const DynamicLayer& layer) {
os << layer.layerId() << ":" << layer.key();
os << layer.layerId() << ':' << layer.key();
return os;
}
std::ostream& operator<< (std::ostream& os, const std::vector<DynamicLayer>& dls) {
os << "DynamicLayerStack[ ";
for (const auto& layer : dls) {
os << layer << " ";
os << layer << ' ';
}
os << "]";
os << ']';
return os;
}

View File

@ -22,7 +22,7 @@ void dumpTensor(std::ostream& ss, const Tensor& tensor) {
if (batched) {
ss << "Batched[lvl=" << batched->level() << " dim=" << batched->bdim() << ", ";
dumpTensor(ss, batched->value());
ss << "]";
ss << ']';
return;
}
ss << "Tensor" << tensor.sizes();
@ -36,7 +36,7 @@ void dumpTensor(std::ostream& ss, const Tensor& tensor) {
ss << "dead, ";
}
dumpTensor(ss, wrapped->value());
ss << "]";
ss << ']';
}
void TensorWrapper::refreshMetadata() {

View File

@ -73,7 +73,7 @@ std::string miopenTypeToString(miopenDataType_t dtype) {
return "miopenBFloat16";
default:
std::ostringstream oss;
oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
oss << "(unknown data-type " << static_cast<int>(dtype) << ')';
return oss.str();
}
}

View File

@ -91,7 +91,7 @@ struct OperationInfo : BaseInfo {
std::stringstream kernelStr;
kernelStr << kernelName;
for (const Tensor& tensor : tensors) {
kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId);
kernelStr << ':' << BaseInfo::buildTensorString(tensor, includeBufferId);
}
return kernelStr.str();
}

View File

@ -39,9 +39,9 @@ std::string BaseInfo::buildTensorString(const Tensor& tensor, bool includeBuffer
// see comments for INCLUDE_BUFFER_ID
if (includeBufferId && deviceType == at::kMPS) {
id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ":" << buffer.retainCount << ")";
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ':' << buffer.retainCount << ')';
}
tensorStr << ":" << tensor.scalar_type() << tensor.sizes();
tensorStr << ':' << tensor.scalar_type() << tensor.sizes();
return tensorStr.str();
} else {
return "undefined";

View File

@ -167,7 +167,7 @@ static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, co
std::stringstream ss;
ss << arg_name << " should be greater than zero but got (";
std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int>(ss,", "));
ss << args.back() << ")" << " (while checking arguments for " << c << ")";
ss << args.back() << ")" << " (while checking arguments for " << c << ')';
TORCH_CHECK(false, ss.str());
}
}

View File

@ -639,7 +639,7 @@ static std::ostream& operator<<(std::ostream & out, const ConvParams<T>& params)
<< " deterministic = " << params.deterministic
<< " cudnn_enabled = " << params.cudnn_enabled
<< " allow_tf32 = " << params.allow_tf32
<< "}";
<< '}';
return out;
}

View File

@ -847,7 +847,7 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t
<< ", hop_length=" << hop_length << ", win_length=" << win_length \
<< ", window="; \
if (window.defined()) { \
SS << window.toString() << "{" << window.sizes() << "}"; \
SS << window.toString() << '{' << window.sizes() << '}'; \
} else { \
SS << "None"; \
} \
@ -1046,7 +1046,7 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const std::optional<int64_
<< ", hop_length=" << hop_length << ", win_length=" << win_length \
<< ", window="; \
if (window.defined()) { \
SS << window.toString() << "{" << window.sizes() << "}"; \
SS << window.toString() << '{' << window.sizes() << '}'; \
} else { \
SS << "None"; \
} \

View File

@ -904,19 +904,11 @@ Tensor mvlgamma(const Tensor& self, int64_t p) {
return args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER);
}
// since mvlgamma_ has different signature from its
// out and functional variant, we explicitly
// define it (instead of using structured kernel).
Tensor& mvlgamma_(Tensor& self, int64_t p) {
mvlgamma_check(self, p);
Tensor args = native::arange(
-p *HALF + HALF,
HALF,
HALF,
optTypeMetaToScalarType(self.options().dtype_opt()),
self.options().layout_opt(),
self.options().device_opt(),
self.options().pinned_memory_opt());
args = args.add(self.unsqueeze(-1));
const auto p2_sub_p = static_cast<double>(p * (p - 1));
return self.copy_(args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER));
return at::mvlgamma_out(self, self, p);
}
Tensor& mvlgamma_out(const Tensor& self, int64_t p, Tensor& result) {

View File

@ -78,9 +78,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const SwizzleType& swizzle_a,
const SwizzleType swizzle_a,
const Tensor& scale_b,
const SwizzleType& swizzle_b,
const SwizzleType swizzle_b,
const std::optional<at::Tensor>& offs,
Tensor& out) {
const bool a_is_2d = mat_a.dim() == 2;

View File

@ -11,7 +11,7 @@ static inline std::ostream& operator<<(std::ostream& out, dim3 dim) {
if (dim.y == 1 && dim.z == 1) {
out << dim.x;
} else {
out << "[" << dim.x << "," << dim.y << "," << dim.z << "]";
out << '[' << dim.x << ',' << dim.y << ',' << dim.z << ']';
}
return out;
}
@ -27,7 +27,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
out << "input_mult=[";
for (int i = 0; i < 3; i++) {
if (i != 0) {
out << ",";
out << ',';
}
out << config.input_mult[i];
}
@ -35,7 +35,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
out << "output_mult=[";
for (int i = 0; i < 2; i++) {
if (i != 0) {
out << ",";
out << ',';
}
out << config.output_mult[i];
}
@ -49,7 +49,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
out << "block=" << config.block() << ", ";
out << "grid=" << config.grid() << ", ";
out << "global_memory_size=" << config.global_memory_size();
out << ")";
out << ')';
return out;
}

View File

@ -364,9 +364,9 @@ void f8f8bf16_grouped_gemm_impl_sm90(
// reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(
// stride_output_h + group_count);
// std::cout << "PTRS " << mat_a.data_ptr() << " " << mat_b.data_ptr() << "
// std::cout << "PTRS " << mat_a.data_ptr() << ' ' << mat_b.data_ptr() << "
// "
// << out.data_ptr() << " " << scale_a.data_ptr() << " "
// << out.data_ptr() << ' ' << scale_a.data_ptr() << ' '
// << scale_b.data_ptr() << "\n";
// for (int i = 0; i < group_count; i++) {
// std::cout << "A " << (void*)inputA_ptrs_h[i] << "\n";

View File

@ -1057,14 +1057,14 @@ std::string generate_code(
// TODO these arrays are potentially of the different types, use function
// traits to determine the types
declare_load_arrays << f_inputs_type << " arg" << std::to_string(i)
<< "[" << std::to_string(thread_work_size) << "];\n";
<< '[' << std::to_string(thread_work_size) << "];\n";
}
env.s("declare_load_arrays", declare_load_arrays.str());
std::stringstream declare_store_arrays;
for (int i = 0; i < nOutputs; i++) {
declare_store_arrays << result_type << " out" << std::to_string(i)
<< "[" << std::to_string(thread_work_size) << "];\n";
<< '[' << std::to_string(thread_work_size) << "];\n";
}
env.s("declare_store_arrays", declare_store_arrays.str());
@ -1217,7 +1217,7 @@ std::string generate_code(
for (const auto i : c10::irange(nInputs)){
auto i_string = std::to_string(i);
vector_inputs << "auto * input" << i_string <<
" = reinterpret_cast<const scalar_t*>(data[" << i_string << "+" << nOutputs << "])" <<
" = reinterpret_cast<const scalar_t*>(data[" << i_string << '+' << nOutputs << "])" <<
" + block_work_size * idx;\n";
}
env.s("vector_inputs", vector_inputs.str());
@ -1543,17 +1543,17 @@ NvrtcFunction jit_pwise_function(
// Constructs file path by appending constructed cubin name to cache path
std::stringstream ss;
ss << *cache_dir << "/";
ss << *cache_dir << '/';
ss << kernel_name;
#ifdef USE_ROCM
ss << "_arch" << prop->gcnArchName;
#else
ss << "_arch" << cuda_major << "." << cuda_minor;
ss << "_arch" << cuda_major << '.' << cuda_minor;
#endif
ss << "_nvrtc" << nvrtc_major << "." << nvrtc_minor;
ss << "_nvrtc" << nvrtc_major << '.' << nvrtc_minor;
ss << (compile_to_sass ? "_sass" : "_ptx");
ss << "_" << code.length();
ss << "_" << hash_code;
ss << '_' << code.length();
ss << '_' << hash_code;
file_path = ss.str();
std::ifstream readin{file_path, std::ios::in | std::ifstream::binary};

View File

@ -115,7 +115,7 @@ std::ostream& operator<<(
std::copy(
strides.begin(), strides.end() - 1, std::ostream_iterator<int>(oss, ","));
oss << sizes.back();
output << oss.str() << "}";
output << oss.str() << '}';
return output;
}

View File

@ -53,7 +53,7 @@ std::ostream& operator<<(std::ostream& out, const ConvParams& params) {
<< " transposed = " << params.transposed
<< " output_padding = " << IntArrayRef{params.output_padding}
<< " groups = " << params.groups << " benchmark = " << params.benchmark
<< " deterministic = " << params.deterministic << "}";
<< " deterministic = " << params.deterministic << '}';
return out;
}
@ -337,10 +337,6 @@ Tensor _convolution_out(
TORCH_CHECK(
3 == ndim || 4 == ndim || 5 == ndim,
"convolution only supports 3D, 4D, 5D tensor");
// get computation format for Conv/TransposedConv
bool is_channels_last_suggested =
use_channels_last_for_conv(input_r, weight_r);
Tensor input = input_r, weight = weight_r;
// PyTorch does not support ChannelsLast1D case,
// thus we need the transformation here
@ -348,13 +344,8 @@ Tensor _convolution_out(
input = view4d(input_r);
weight = view4d(weight_r);
}
// ensure the input/weight/bias/output are congituous in desired format
at::MemoryFormat mfmt = is_channels_last_suggested
? get_cl_tag_by_ndim(input.ndimension())
: at::MemoryFormat::Contiguous;
auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
input = input.contiguous(mfmt);
weight = weight.contiguous(mfmt);
// get computation format for Conv/TransposedConv
bool is_channels_last_suggested = use_channels_last_for_conv(input, weight);
auto k = weight.ndimension();
if (k == input.ndimension() + 1) {
@ -388,6 +379,14 @@ Tensor _convolution_out(
expand_param_if_needed(output_padding_, "output_padding", dim);
params.groups = groups_;
}
// ensure the input/weight/bias/output are congituous in desired format
at::MemoryFormat mfmt = is_channels_last_suggested
? get_cl_tag_by_ndim(input.ndimension())
: at::MemoryFormat::Contiguous;
auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
input = input.contiguous(mfmt);
weight = weight.contiguous(mfmt);
check_shape_forward(input, weight, bias, params, true);
Tensor output;
@ -514,18 +513,9 @@ Tensor convolution_overrideable(
at::borrow_from_optional_tensor(bias_r_opt);
const Tensor& bias_r = *bias_r_maybe_owned;
auto k = weight_r.ndimension();
at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
if (xpu_conv_use_channels_last(input_r, weight_r)) {
backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d
: at::MemoryFormat::ChannelsLast;
}
Tensor input_c = input_r.contiguous(backend_memory_format);
Tensor weight_c = weight_r.contiguous(backend_memory_format);
return _convolution(
input_c,
weight_c,
input_r,
weight_r,
bias_r,
stride_,
padding_,

View File

@ -0,0 +1,342 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/BlasBackend.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/ceil_div.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/xpu/Blas.h>
#include <torch/library.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
namespace at::native {
using at::blas::ScalingType;
using at::blas::SwizzleType;
namespace {
/*
* Scaling Type Determination:
* ---------------------------
* Conditions and corresponding Scaling Types:
*
* - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
* - Returns BlockWise (with additional size checks).
*
* - Else if scale.numel() == 1:
* - Returns TensorWise.
*
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) ==
* 1:
* - Returns RowWise.
*
* - Otherwise:
* - Returns Error.
*/
bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return at::isFloat8Type(t.scalar_type()) &&
scale.scalar_type() == at::kFloat && scale.numel() == 1;
}
bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return (
at::isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat &&
scale.dim() == 2 && scale.size(0) == t.size(0) && scale.size(1) == 1 &&
scale.is_contiguous());
}
bool is_desired_scaling(
const at::Tensor& t,
const at::Tensor& scale,
ScalingType desired_scaling) {
auto result = desired_scaling == ScalingType::TensorWise
? is_tensorwise_scaling(t, scale)
: is_rowwise_scaling(t, scale);
return result;
}
std::pair<ScalingType, ScalingType> get_joint_scaling(
std::initializer_list<std::pair<ScalingType, ScalingType>> options,
const at::Tensor& a,
const at::Tensor& b,
const at::Tensor& scale_a,
const at::Tensor& scale_b) {
for (auto [lhs, rhs] : options) {
if (is_desired_scaling(a, scale_a, lhs) &&
is_desired_scaling(b.t(), scale_b.t(), rhs)) {
return {lhs, rhs};
}
}
TORCH_CHECK(
false,
"Invalid scaling configuration.\n"
"- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
"- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
a.size(0),
", 1) and scale_b should be (1, ",
b.size(1),
"), and both should be contiguous.\n"
"Got a.dtype()=",
a.scalar_type(),
", scale_a.dtype()=",
scale_a.scalar_type(),
", scale_a.size()=",
scale_a.sizes(),
", scale_a.stride()=",
scale_a.strides(),
", ",
"b.dtype()=",
b.scalar_type(),
", scale_b.dtype()=",
scale_b.scalar_type(),
", scale_b.size()=",
scale_b.sizes(),
" and scale_b.stride()=",
scale_b.strides());
}
Tensor& _scaled_gemm(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const ScalingType scaling_choice_a,
const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out,
const std::optional<Tensor>& alpha = std::nullopt) {
// TODO: scale_result and alpha is not defined or used!
std::optional<Tensor> scaled_result = std::nullopt;
at::native::onednn::scaled_matmul(
mat1,
mat2,
out,
scale_a,
scale_b,
scaling_choice_a,
scaling_choice_b,
bias,
scaled_result,
use_fast_accum);
return out;
}
} // namespace
// Computes matrix multiply + bias while applying scaling to input and output
// matrices Scales are only applicable when matrices are of Float8 type and
// assumed to be equal to 1.0 by default. If output matrix type is 16 or 32-bit
// type, scale_result is not applied. Known limitations:
// - Only works if mat1 is row-major and mat2 is column-major
// - Only works if matrices sizes are divisible by 32
// - If 1-dimensional tensors are used then scale_a should be size =
// mat1.size(0)
// and scale_b should have size = to mat2.size(1)
// Arguments:
// - `mat1`: the first operand of the matrix multiply, can be type
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `mat2`: the second operand of the matrix multiply, can be type
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
// - `out_dtype`: the output dtype, can either be a float8 or a higher
// precision floating point type
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose
// shape/strides/dtype depend on the scaling scheme
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose
// shape/strides/dtype depend on the scaling scheme
// - `scale_result`: a scalar tensor with the scale of the output, only
// utilized if the output is a float8 type
// - `use_fast_accum`: Not applicable for XPU. For now, it should always be
// false.
// - `out`: a reference to the output tensor
Tensor& _scaled_mm_out_xpu(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
// Note: fast_accum is not supported in XPU for now.
TORCH_CHECK(!use_fast_accum, "fast_accum is not supported in XPU for now.");
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0],
"mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0],
"x",
mat1.sizes()[1],
" and ",
mat2.sizes()[0],
"x",
mat2.sizes()[1],
")");
// Check what type of scaling we are doing based on inputs. This list is
// sorted by decreasing priority.
// List of supported datatypes for XPU with oneDNN:
// https://uxlfoundation.github.io/oneDNN/dev_guide_matmul.html#data-types
auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
{
std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
},
mat1,
mat2,
scale_a,
scale_b);
TORCH_CHECK(
!scale_result ||
(scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
"scale_result must be a float scalar");
TORCH_CHECK(
!bias || bias->numel() == mat2.sizes()[1],
"Bias must be size ",
mat2.sizes()[1],
" but got ",
bias->numel());
TORCH_CHECK(
mat1.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat1.sizes()[0],
"x",
mat1.sizes()[1],
").");
TORCH_CHECK(
mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0,
"mat2 shape (",
mat2.sizes()[0],
"x",
mat2.sizes()[1],
") must be divisible by 16");
// Check types
TORCH_CHECK(
!out_dtype || *out_dtype == out.scalar_type(),
"out_dtype must match output matrix type");
TORCH_CHECK(
at::isFloat8Type(mat1.scalar_type()),
"Expected mat1 to be Float8 matrix got ",
mat1.scalar_type());
TORCH_CHECK(
at::isFloat8Type(mat2.scalar_type()),
"Expected mat2 to be Float8 matrix got ",
mat2.scalar_type());
// TODO: oneDNN Currently only supports e4m3 with group scales on BMG. Not
// support 2D scales, only 1D. Needs to add more checks there.
if (bias) {
TORCH_CHECK(
bias->scalar_type() == kFloat ||
bias->scalar_type() == c10::ScalarType::BFloat16 ||
bias->scalar_type() == c10::ScalarType::Half,
"Bias must be Float32 or BFloat16 or Half, but got ",
bias->scalar_type());
}
{
auto bias_ = bias.value_or(Tensor());
auto scale_result_ = scale_result.value_or(Tensor());
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{
{out, "out", 0},
{mat1, "mat1", 1},
{mat2, "mat2", 2},
{bias_, "bias", 3},
{scale_a, "scale_a", 4},
{scale_b, "scale_b", 5},
{scale_result_, "scale_result", 6}};
checkAllSameGPU(__func__, targs);
}
// Validation checks have passed lets resize the output to actual size
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
// If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
// kernels do not support this case).
if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) {
// `out` was created with `at::empty`. In the case where we are multiplying
// MxK by KxN and K is the zero dim, we need to initialize here to properly
// return a tensor of zeros.
if (mat1_sizes[1] == 0) {
out.zero_();
}
return out;
}
// TODO: Scale_result is not supported by now!!
return _scaled_gemm(
mat1,
mat2,
scale_a,
scale_b,
scaling_choice_a,
scaling_choice_b,
bias,
use_fast_accum,
out);
}
Tensor _scaled_mm_xpu(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
return _scaled_mm_out_xpu(
mat_a,
mat_b,
scale_a,
scale_b,
bias,
scale_result,
out_dtype,
use_fast_accum,
out);
}
} // namespace at::native

View File

@ -1,3 +1,4 @@
#include <ATen/BlasBackend.h>
#include <ATen/Tensor.h>
#include <ATen/core/Tensor.h>
#include <c10/core/ScalarType.h>
@ -8,7 +9,6 @@
#include <oneapi/dnnl/dnnl.hpp>
namespace at::native::onednn {
at::Tensor broadcast_bias2D(
at::Tensor& dst,
at::Tensor& bias,
@ -328,4 +328,236 @@ void quantized_matmul(
result.copy_(dst);
}
// Describes how to configure oneDNN scales for a given role/ScalingType
struct ScaleSpec {
// specifies the way scale values will be applied to an ARG tensor.
int mask;
// specifies how scales are grouped along dimensions where
// multiple scale factors are used.
dnnl::memory::dims groups;
// specifies data type for scale factors.
dnnl::memory::data_type dtype;
// Helper to compute expected number of elements for scale tensors
// arg_type: "src" for SRC (groups pattern {1, X}),
// "wei" for WEIGHTS (groups pattern {X, 1})
int64_t expected_numel(
int64_t outer_dim,
int64_t inner_dim,
const std::string& arg_type) const {
if (groups == dnnl::memory::dims{1, 1})
return 1; // tensorwise scaling
TORCH_CHECK(
arg_type == "src" || arg_type == "wei",
"Expected arg_type to be 'src' or 'wei', but got '",
arg_type,
"'");
// For rowwise: SRC groups={1, K}, WEI groups={K, 1}
TORCH_INTERNAL_ASSERT(
(groups == dnnl::memory::dims{1, inner_dim} ||
groups == dnnl::memory::dims{inner_dim, 1}),
"The groups must be either {1, inner_dim} or {inner_dim, 1}. But got ",
groups,
".");
return outer_dim;
}
// Normalize an incoming scale tensor to contiguous storage and appropriate
// dtype/view
at::Tensor normalize(const at::Tensor& scale) const {
TORCH_INTERNAL_ASSERT(
dtype == dnnl::memory::data_type::f32,
"tensor scale currently must be f32, but got scale dtype: ",
scale.scalar_type());
return scale.to(at::kFloat).contiguous();
}
};
// This function defines how to set scales mask and groups according to:
// https://github.com/uxlfoundation/oneDNN/blob/main/tests/benchdnn/doc/knobs_attr.md#--attr-scales
// The returned value will be used in
// `set_scales(arg, mask, groups, data_type)`.
inline ScaleSpec make_scale_spec(
at::blas::ScalingType scaling_type,
int64_t M,
int64_t K,
int64_t N,
const std::string& arg_type) {
TORCH_CHECK(
arg_type == "src" || arg_type == "wei",
"Expected arg_type to be 'src' or 'wei', but got '",
arg_type,
"'");
TORCH_INTERNAL_ASSERT(
(scaling_type == at::blas::ScalingType::TensorWise ||
scaling_type == at::blas::ScalingType::RowWise),
"Currently only support scaling_type for TensorWise or RowWise");
int64_t dim = K; // Currently only K is used for grouping
bool is_src = (arg_type == "src");
if (scaling_type == at::blas::ScalingType::TensorWise) {
// Scale tensorwise. The same as `--attr-scales=common`.
// mask=0 : scale whole tensor
// groups={1, 1}: indicates that there is only one group for scaling
return {0, {1, 1}, dnnl::memory::data_type::f32};
} else {
// (scaling_type == at::blas::ScalingType::RowWise)
// Scale RowWise. The same as `--attr-scales=per_dim_01`.
// mask={(1 << 0) | (1 << 1)}: Scale on both dim0 and dim1
// SRC: groups={1, K}, WEIGHTS: groups={K, 1}
return {
(1 << 0) | (1 << 1),
is_src ? dnnl::memory::dims{1, dim} : dnnl::memory::dims{dim, 1},
dnnl::memory::data_type::f32};
}
}
sycl::event scaled_matmul(
const Tensor& mat1,
const Tensor& mat2,
Tensor& result,
const Tensor& scale_a,
const Tensor& scale_b,
at::blas::ScalingType scaling_choice_a,
at::blas::ScalingType scaling_choice_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
bool use_fast_accum) {
auto& engine = GpuEngineManager::Instance().get_engine();
auto& stream = GpuStreamManager::Instance().get_stream();
// This function will do steps with following steps
// 1. create memory descriptor
// 2. call write_to_dnnl_memory() to actually write memory
// 3. execute
const int64_t M = mat1.size(0);
const int64_t K = mat1.size(1);
const int64_t N = mat2.size(1);
// 1.1 Create memory descriptor
dnnl::memory::desc src_md = get_onednn_md(mat1);
dnnl::memory::desc weights_md = get_onednn_md(mat2);
dnnl::memory::desc dst_md = get_onednn_md(result);
// scale_a and scale_b has already be checked in `is_desired_scaling()` call.
// So we could directly get their memory desc and set later.
dnnl::memory::desc scale_a_md = get_onednn_md(scale_a);
dnnl::memory::desc scale_b_md = get_onednn_md(scale_b);
dnnl::memory::desc bias_md;
bool with_bias = bias.has_value();
at::Tensor possible_reshaped_bias = bias.value_or(at::Tensor());
if (with_bias) {
if (possible_reshaped_bias.dim() == 1) {
possible_reshaped_bias =
possible_reshaped_bias.reshape({1, possible_reshaped_bias.size(0)});
bias_md = get_onednn_md(possible_reshaped_bias);
} else {
bias_md = get_onednn_md(possible_reshaped_bias);
}
}
// 1.2 Create primitive descriptor and set scales mask
const ScaleSpec src_spec = make_scale_spec(scaling_choice_a, M, K, N, "src");
const ScaleSpec wei_spec = make_scale_spec(scaling_choice_b, M, K, N, "wei");
dnnl::primitive_attr op_attr = dnnl::primitive_attr();
#if ONEDNN_SUPPORT_DETERMINISTIC
if (at::globalContext().deterministicAlgorithms() ||
at::globalContext().deterministicMkldnn())
op_attr.set_deterministic(true);
#endif
std::vector<int64_t> default_groups;
op_attr.set_scales(
DNNL_ARG_SRC, src_spec.mask, src_spec.groups, src_spec.dtype);
op_attr.set_scales(
DNNL_ARG_WEIGHTS, wei_spec.mask, wei_spec.groups, wei_spec.dtype);
// scale_result tensor currently only supports scalar(TensorWise Scaling).
bool with_dst_scale = scale_result && scale_result->defined();
if (with_dst_scale) {
op_attr.set_scales(DNNL_ARG_DST, 0, {1}, dnnl::memory::data_type::f32);
}
op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
// 1.3 Create the matmul primitive descriptor
dnnl::matmul::primitive_desc matmul_pd = with_bias
? dnnl::matmul::primitive_desc(
engine, src_md, weights_md, bias_md, dst_md, op_attr)
: dnnl::matmul::primitive_desc(
engine, src_md, weights_md, dst_md, op_attr);
// 1.4 (Possible) Additional Checks
// TODO: In case there are memory desc does not align with the actual tensor,
// we might need to reorder weights similar to CPU's reorder_if_differ_in()
// call. For example, weights not the same as matmul_pd.weights_desc(),
// 2. Prepare memory
// Create memory
auto src_usr_m = make_onednn_memory(src_md, engine, mat1.data_ptr());
auto weights_usr_m = make_onednn_memory(weights_md, engine, mat2.data_ptr());
auto dst_usr_m = make_onednn_memory(dst_md, engine, result.data_ptr());
dnnl::memory b_usr_m;
if (with_bias) {
b_usr_m =
make_onednn_memory(bias_md, engine, possible_reshaped_bias.data_ptr());
}
// Prepare runtime scale memories (flat 1-D views) using the specs
auto make_scale_mem_from_spec = [&](const ScaleSpec& spec,
int64_t expected_numel,
const at::Tensor& scale_tensor) {
at::Tensor prepared = spec.normalize(scale_tensor);
TORCH_CHECK(
prepared.numel() == expected_numel,
"Scale buffer length mismatch. Expected ",
expected_numel,
", got ",
prepared.numel());
dnnl::memory::desc scale_md(
{prepared.numel()}, spec.dtype, dnnl::memory::format_tag::x);
return make_onednn_memory(scale_md, engine, prepared.data_ptr());
};
auto scratchpad =
make_onednn_memory(matmul_pd.scratchpad_desc(), engine, nullptr);
// 3. Setup Args for exec
std::unordered_map<int, dnnl::memory> args;
args.insert({DNNL_ARG_SRC, src_usr_m});
args.insert({DNNL_ARG_WEIGHTS, weights_usr_m});
args.insert({DNNL_ARG_DST, dst_usr_m});
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
if (with_bias) {
args.insert({DNNL_ARG_BIAS, b_usr_m});
}
// Attach runtime scales using specs
auto src_sc_mem = make_scale_mem_from_spec(
src_spec, src_spec.expected_numel(M, K, "src"), scale_a);
auto wei_sc_mem = make_scale_mem_from_spec(
wei_spec, wei_spec.expected_numel(N, K, "wei"), scale_b);
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_sc_mem});
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_sc_mem});
if (with_dst_scale) {
// Bind single f32 scalar as DST scale
at::Tensor dst_scale_f32 = scale_result->to(at::kFloat).contiguous();
dnnl::memory::desc dst_sc_md(
{1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
auto dst_sc_mem =
make_onednn_memory(dst_sc_md, engine, dst_scale_f32.data_ptr());
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_mem});
}
dnnl::matmul matmul_p = dnnl::matmul(matmul_pd);
sycl::event matmul_fwd_event =
dnnl::sycl_interop::execute(matmul_p, stream, args);
return matmul_fwd_event;
}
} // namespace at::native::onednn

View File

@ -78,6 +78,10 @@ dnnl::memory::data_type get_onednn_dtype(
return dnnl::memory::data_type::f32;
case at::ScalarType::BFloat16:
return dnnl::memory::data_type::bf16;
case at::ScalarType::Float8_e4m3fn:
return dnnl::memory::data_type::f8_e4m3;
case at::ScalarType::Float8_e5m2:
return dnnl::memory::data_type::f8_e5m2;
default:
if (!allow_undef) {
TORCH_CHECK(

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/BlasBackend.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
@ -202,4 +203,16 @@ void sdpa_backward(
Tensor& grad_query,
Tensor& grad_key,
Tensor& grad_value);
sycl::event scaled_matmul(
const Tensor& mat1,
const Tensor& mat2,
Tensor& result,
const Tensor& scale_a,
const Tensor& scale_b,
at::blas::ScalingType scaling_choice_a,
at::blas::ScalingType scaling_choice_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
bool use_fast_accum);
} // namespace at::native::onednn

View File

@ -82,7 +82,6 @@ NSArray<NSNumber*>* getTensorAxes(const TensorBase& t);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
std::string to_hex_key(float);
std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);

View File

@ -301,10 +301,6 @@ std::string getArrayRefString(const IntArrayRef s) {
return fmt::to_string(fmt::join(s, ","));
}
std::string to_hex_key(float f) {
return fmt::format("{:a}", f);
}
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
fmt::basic_memory_buffer<char, 100> buffer;
auto buf_iterator = std::back_inserter(buffer);

View File

@ -96,7 +96,9 @@ kernel void addmm(
auto bias =
biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y];
outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
static_cast<T>(alpha_beta[0] * sum + alpha_beta[1] * bias);
static_cast<T>(
c10::metal::mul(alpha_beta[0], sum) +
c10::metal::mul(alpha_beta[1], bias));
}
}

View File

@ -121,7 +121,7 @@ Tensor& do_metal_addmm(const Tensor& self,
const Scalar& alpha,
const Scalar& beta,
const Tensor& bias) {
if (beta.toDouble() == 0 && alpha.toDouble() == 1) {
if (beta.isFloatingPoint() && alpha.isFloatingPoint() && beta.toDouble() == 0 && alpha.toDouble() == 1) {
return do_metal_mm(self, other, output);
}
auto stream = getCurrentMPSStream();
@ -147,13 +147,15 @@ Tensor& do_metal_addmm(const Tensor& self,
std::array<int64_t, 2> i64;
std::array<int32_t, 2> i32;
std::array<float, 2> f32;
} alpha_beta;
std::array<c10::complex<float>, 2> c64;
} alpha_beta{};
if (output.scalar_type() == kLong) {
alpha_beta.i64 = {alpha.toLong(), beta.toLong()};
} else if (c10::isIntegralType(output.scalar_type(), true)) {
alpha_beta.i32 = {alpha.toInt(), beta.toInt()};
} else if (c10::isComplexType(output.scalar_type())) {
alpha_beta.c64 = {alpha.toComplexFloat(), beta.toComplexFloat()};
} else {
TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type()));
alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()};
}
constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs

View File

@ -244,8 +244,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
@autoreleasepool {
// the optional min/max refs could affect how we build the cached graph
std::string key = op_name + (has_min ? ("_min:" + to_hex_key(min_scalar)) : "") +
(has_max ? ("_max:" + to_hex_key(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
std::string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
(has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
if (has_min)
newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar

View File

@ -301,12 +301,12 @@ class AvgPoolMicrokernelTester {
ASSERT_NEAR(
float(int32_t(y[i * yStride() + k])), yFP[i * kc() + k], 0.5001f)
<< "at pixel " << i << ", channel " << k << ", n = " << n()
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
ASSERT_EQ(
uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
<< "at pixel " << i << ", channel " << k << ", n = " << n()
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
}
}
@ -396,12 +396,12 @@ class AvgPoolMicrokernelTester {
ASSERT_NEAR(
float(int32_t(y[i * yStride() + k])), yFP[i * kc() + k], 0.5001f)
<< "at pixel " << i << ", channel " << k << ", n = " << n()
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
ASSERT_EQ(
uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
<< "at pixel " << i << ", channel " << k << ", n = " << n()
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
}
}

View File

@ -232,7 +232,7 @@ class MaxPoolMicrokernelTester {
ASSERT_EQ(
uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
<< "at pixel " << i << ", channel " << k << ", n = " << n()
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
<< "), kc = " << kc();
}
}

View File

@ -17,7 +17,7 @@ inline std::vector<T> _expand_param_if_needed(
std::ostringstream ss;
ss << "expected " << param_name << " to be a single integer value or a "
<< "list of " << expected_dim << " values to match the convolution "
<< "dimensions, but got " << param_name << "=" << list_param;
<< "dimensions, but got " << param_name << '=' << list_param;
TORCH_CHECK(false, ss.str());
} else {
return list_param.vec();

View File

@ -358,9 +358,9 @@ std::string Adapter::stringize() const {
std::string device_type = get_device_type_str(properties.deviceType);
VkPhysicalDeviceLimits limits = properties.limits;
ss << "{" << std::endl;
ss << '{' << std::endl;
ss << " Physical Device Info {" << std::endl;
ss << " apiVersion: " << v_major << "." << v_minor << std::endl;
ss << " apiVersion: " << v_major << '.' << v_minor << std::endl;
ss << " driverversion: " << properties.driverVersion << std::endl;
ss << " deviceType: " << device_type << std::endl;
ss << " deviceName: " << properties.deviceName << std::endl;
@ -371,7 +371,7 @@ std::string Adapter::stringize() const {
#define PRINT_LIMIT_PROP_VEC3(name) \
ss << " " << std::left << std::setw(36) << #name << limits.name[0] \
<< "," << limits.name[1] << "," << limits.name[2] << std::endl;
<< ',' << limits.name[1] << ',' << limits.name[2] << std::endl;
ss << " Physical Device Limits {" << std::endl;
PRINT_LIMIT_PROP(maxImageDimension1D);
@ -425,7 +425,7 @@ std::string Adapter::stringize() const {
;
}
ss << " ]" << std::endl;
ss << "}";
ss << '}';
return ss.str();
}

View File

@ -33,7 +33,7 @@ std::ostream& operator<<(std::ostream& out, const VkResult result) {
VK_RESULT_CASE(VK_ERROR_FORMAT_NOT_SUPPORTED)
VK_RESULT_CASE(VK_ERROR_FRAGMENTED_POOL)
default:
out << "VK_ERROR_UNKNOWN (VkResult " << result << ")";
out << "VK_ERROR_UNKNOWN (VkResult " << result << ')';
break;
}
return out;
@ -46,7 +46,7 @@ std::ostream& operator<<(std::ostream& out, const VkResult result) {
//
std::ostream& operator<<(std::ostream& out, const SourceLocation& loc) {
out << loc.function << " at " << loc.file << ":" << loc.line;
out << loc.function << " at " << loc.file << ':' << loc.line;
return out;
}
@ -66,7 +66,7 @@ Error::Error(SourceLocation source_location, const char* cond, std::string msg)
: msg_(std::move(msg)), source_location_{source_location} {
std::ostringstream oss;
oss << "Exception raised from " << source_location_ << ": ";
oss << "(" << cond << ") is false! ";
oss << '(' << cond << ") is false! ";
oss << msg_;
what_ = oss.str();
}

View File

@ -173,8 +173,8 @@ void QueryPool::extract_results() {
static std::string stringize(const VkExtent3D& extents) {
std::stringstream ss;
ss << "{" << extents.width << ", " << extents.height << ", " << extents.depth
<< "}";
ss << '{' << extents.width << ", " << extents.height << ", " << extents.depth
<< '}';
return ss.str();
}

View File

@ -149,7 +149,7 @@ VKAPI_ATTR VkBool32 VKAPI_CALL debug_report_callback_fn(
(void)flags;
std::stringstream stream;
stream << layer_prefix << " " << message_code << " " << message << std::endl;
stream << layer_prefix << ' ' << message_code << ' ' << message << std::endl;
const std::string log = stream.str();
std::cout << log;

View File

@ -253,7 +253,7 @@ using vec4 = vec<4u>;
// uvec3 is the type representing tensor extents. Useful for debugging.
inline std::ostream& operator<<(std::ostream& os, const uvec3& v) {
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ")";
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ')';
return os;
}

View File

@ -73,8 +73,8 @@ TEST(TestScalar, TestScalar) {
Scalar bar = 3.0;
Half h = bar.toHalf();
Scalar h2 = h;
cout << "H2: " << h2.toDouble() << " " << what.toFloat() << " "
<< bar.toDouble() << " " << what.isIntegral(false) << "\n";
cout << "H2: " << h2.toDouble() << ' ' << what.toFloat() << ' '
<< bar.toDouble() << ' ' << what.isIntegral(false) << "\n";
auto gen = at::detail::getDefaultCPUGenerator();
{
// See Note [Acquire lock when using random generators]

View File

@ -19,7 +19,7 @@ TEST(Vitals, Basic) {
c10::utils::set_env("TORCH_VITAL", "1");
TORCH_VITAL_DEFINE(Testing);
TORCH_VITAL(Testing, Attribute0) << 1;
TORCH_VITAL(Testing, Attribute1) << "1";
TORCH_VITAL(Testing, Attribute1) << '1';
TORCH_VITAL(Testing, Attribute2) << 1.0f;
TORCH_VITAL(Testing, Attribute3) << 1.0;
auto t = at::ones({1, 1});

View File

@ -129,14 +129,14 @@ void showRtol(const at::Tensor& a, const at::Tensor& b) {
std::cout << "Max Diff allowed: " << maxDiff << std::endl;
if (diff.sizes().size() == 2) {
for (const auto y : c10::irange(diff.sizes()[0])) {
std::cout << y << ":";
std::cout << y << ':';
for (const auto x : c10::irange(diff.sizes()[1])) {
float diff_xy = diff[y][x].item<float>();
if (diff_xy > maxDiff) {
std::cout << std::setw(5) << x;
}
else {
std::cout << std::setw(5) << " ";
std::cout << std::setw(5) << ' ';
}
}
std::cout << std::endl;
@ -3276,7 +3276,7 @@ TEST_F(VulkanAPITest, masked_fill_invalidinputs_exceptions) {
void print_shape(const std::vector<int64_t>& shape) {
for (const auto& num : shape) {
std::cout << num << " ";
std::cout << num << ' ';
}
}
@ -3367,7 +3367,7 @@ void test_masked_fill_scalar(
print_shape(tmp_curr_input_shape);
std::cout << "], and mask of shape [";
print_shape(tmp_curr_mask_shape);
std::cout << "]" << std::endl;
std::cout << ']' << std::endl;
}
ASSERT_TRUE(check);
@ -4542,9 +4542,9 @@ void test_softmax(const at::IntArrayRef shape, bool log_softmax = false) {
if (!check) {
std::cout << "Softmax test failed on axis " << dim << "for tensor dims {";
for (uint32_t place = 0; place < shape.size() - 1; place++) {
std::cout << shape[place] << " ";
std::cout << shape[place] << ' ';
}
std::cout << shape.back() << "}" << std::endl;
std::cout << shape.back() << '}' << std::endl;
showRtol(out_cpu, out_vulkan.cpu());
}
ASSERT_TRUE(check);

View File

@ -95,7 +95,7 @@ void showRtol(
std::cout << "Max Diff found is: " << diff.max().item<double>() << std::endl;
if (diff.sizes().size() == 2) {
for (const auto y : c10::irange(diff.sizes()[0])) {
std::cout << y << ":";
std::cout << y << ':';
for (const auto x : c10::irange(diff.sizes()[1])) {
double diff_xy = diff[y][x].item<double>();
if (diff_xy > maxDiff) {
@ -109,7 +109,7 @@ void showRtol(
}
}
} else {
std::cout << std::setw(5) << " ";
std::cout << std::setw(5) << ' ';
}
}
std::cout << std::endl;
@ -148,19 +148,19 @@ using at::native::vulkan::api::utils::ivec4;
using at::native::vulkan::api::utils::vec4;
std::ostream& operator<<(std::ostream& os, const vec4& v) {
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
<< v.data[3u] << ")";
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
<< v.data[3u] << ')';
return os;
}
std::ostream& operator<<(std::ostream& os, const ivec3& v) {
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ")";
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ')';
return os;
}
std::ostream& operator<<(std::ostream& os, const ivec4& v) {
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
<< v.data[3u] << ")";
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
<< v.data[3u] << ')';
return os;
}
@ -3379,7 +3379,7 @@ bool _test_quantized_linear(
showRtol(out_cpu_dequant, out_vk_to_cpu_dequant);
}
if (xpos != -1 && ypos != -1) {
std::cout << "\nFailure caused on row/col: " << ypos << "/" << xpos
std::cout << "\nFailure caused on row/col: " << ypos << '/' << xpos
<< "\n";
std::cout << "Input tensor scale: " << scale << " zerop: " << zero_point
<< "\n";

View File

@ -2,6 +2,7 @@
# These load paths point to different files in internal and OSS environment
load("@bazel_skylib//lib:paths.bzl", "paths")
load("//tools/build_defs:cell_defs.bzl", "get_fbsource_cell")
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
@ -590,6 +591,9 @@ def pt_operator_query_codegen(
pt_allow_forced_schema_registration = True,
compatible_with = [],
apple_sdks = None):
if get_fbsource_cell() == "fbcode":
return
oplist_dir_name = name + "_pt_oplist"
# @lint-ignore BUCKLINT
@ -865,6 +869,9 @@ def define_buck_targets(
pt_xplat_cxx_library = fb_xplat_cxx_library,
c2_fbandroid_xplat_compiler_flags = [],
labels = []):
if get_fbsource_cell() == "fbcode":
return
# @lint-ignore BUCKLINT
fb_native.filegroup(
name = "metal_build_srcs",

View File

@ -176,7 +176,7 @@ std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
os << k;
first = false;
}
os << ")";
os << ')';
return os;
}

View File

@ -34,20 +34,6 @@ namespace c10 {
// See [dtype Macros note] in torch/headeronly/core/ScalarType.h
// regarding macros.
template <typename T>
struct CppTypeToScalarType;
#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
template <> \
struct CppTypeToScalarType<cpp_type> \
: std:: \
integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
};
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
#undef SPECIALIZE_CppTypeToScalarType
#define DEFINE_CONSTANT(_, name) \
constexpr ScalarType k##name = ScalarType::name;

View File

@ -33,7 +33,7 @@ std::ostream& operator<<(std::ostream& stream, const TensorOptions& options) {
} else {
stream << "(nullopt)";
}
stream << ")";
stream << ')';
return stream;
}

View File

@ -136,7 +136,7 @@ std::string c10_retrieve_device_side_assertion_info() {
// Something failed, let's talk about that
oss << failures_found
<< " CUDA device-side assertion failures were found on GPU #"
<< device_num << "!" << std::endl;
<< device_num << '!' << std::endl;
if (assertion_data_for_device.assertion_count >
C10_CUDA_DSA_ASSERTION_COUNT) {
oss << "But at least " << assertion_data_for_device.assertion_count
@ -151,17 +151,17 @@ std::string c10_retrieve_device_side_assertion_info() {
oss << "Assertion failure " << i << std::endl;
oss << " GPU assertion failure message = " << self.assertion_msg
<< std::endl;
oss << " File containing assertion = " << self.filename << ":"
oss << " File containing assertion = " << self.filename << ':'
<< self.line_number << std::endl;
oss << " Device function containing assertion = " << self.function_name
<< std::endl;
oss << " Thread ID that failed assertion = [" << self.thread_id[0] << ","
<< self.thread_id[1] << "," << self.thread_id[2] << "]" << std::endl;
oss << " Block ID that failed assertion = [" << self.block_id[0] << ","
<< self.block_id[1] << "," << self.block_id[2] << "]" << std::endl;
oss << " Thread ID that failed assertion = [" << self.thread_id[0] << ','
<< self.thread_id[1] << ',' << self.thread_id[2] << ']' << std::endl;
oss << " Block ID that failed assertion = [" << self.block_id[0] << ','
<< self.block_id[1] << ',' << self.block_id[2] << ']' << std::endl;
if (launch_info.generation_number == self.caller) {
oss << " File containing kernel launch = "
<< launch_info.launch_filename << ":" << launch_info.launch_linenum
<< launch_info.launch_filename << ':' << launch_info.launch_linenum
<< std::endl;
oss << " Function containing kernel launch = "
<< launch_info.launch_function << std::endl;

View File

@ -435,7 +435,7 @@ TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) {
if (i > 0) {
ASSERT_TRUE(res.find("Unknown") == std::string::npos)
<< i << " (before is " << toString(static_cast<DispatchKey>(i - 1))
<< ")";
<< ')';
} else {
ASSERT_TRUE(res.find("Unknown") == std::string::npos) << i;
}

View File

@ -98,7 +98,7 @@ struct Noncopyable {
};
std::ostream& operator<<(std::ostream& out, const Noncopyable& nc) {
out << "Noncopyable(" << nc.x << ")";
out << "Noncopyable(" << nc.x << ')';
return out;
}
} // namespace

View File

@ -198,13 +198,13 @@ ArrayRef(const std::initializer_list<T>&) -> ArrayRef<T>;
template <typename T>
std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
int i = 0;
out << "[";
out << '[';
for (const auto& e : list) {
if (i++ > 0)
out << ", ";
out << e;
}
out << "]";
out << ']';
return out;
}

View File

@ -107,7 +107,7 @@ class GetBacktraceImpl {
/*status*/ &status);
os << " frame #" << idx++ << "\t"
<< ((demangled != NULL && status == 0) ? demangled : symbol) << "["
<< ((demangled != NULL && status == 0) ? demangled : symbol) << '['
<< addr << "]\t" << std::endl;
}
free(demangled);
@ -413,8 +413,8 @@ class GetBacktraceImpl {
<< back_trace_[i_frame] << std::dec;
if (with_symbol) {
stream << std::setfill('0') << std::setw(16) << std::uppercase
<< std::hex << p_symbol->Address << std::dec << " " << module
<< "!" << p_symbol->Name;
<< std::hex << p_symbol->Address << std::dec << ' ' << module
<< '!' << p_symbol->Name;
} else {
stream << " <unknown symbol address> " << module << "!<unknown symbol>";
}
@ -424,7 +424,7 @@ class GetBacktraceImpl {
} else {
stream << "<unknown file> @ <unknown line number>";
}
stream << "]" << std::endl;
stream << ']' << std::endl;
}
return stream.str();

View File

@ -44,7 +44,7 @@ std::string Error::compute_what(bool include_backtrace) const {
if (context_.size() == 1) {
// Fold error and context in one line
oss << " (" << context_[0] << ")";
oss << " (" << context_[0] << ')';
} else {
for (const auto& c : context_) {
oss << "\n " << c;
@ -247,7 +247,7 @@ void WarningHandler::process(const Warning& warning) {
LOG_AT_FILE_LINE(
WARNING, warning.source_location().file, warning.source_location().line)
<< "Warning: " << warning.msg() << " (function "
<< warning.source_location().function << ")";
<< warning.source_location().function << ')';
}
std::string GetExceptionString(const std::exception& e) {

View File

@ -473,12 +473,12 @@ MessageLogger::MessageLogger(
if (GLOBAL_RANK != -1) {
stream_ << "[rank" << GLOBAL_RANK << "]:";
}
stream_ << "[" << CAFFE2_SEVERITY_PREFIX[std::min(4, GLOG_FATAL - severity_)]
stream_ << '[' << CAFFE2_SEVERITY_PREFIX[std::min(4, GLOG_FATAL - severity_)]
<< (timeinfo->tm_mon + 1) * 100 + timeinfo->tm_mday
<< std::setfill('0') << " " << std::setw(2) << timeinfo->tm_hour
<< ":" << std::setw(2) << timeinfo->tm_min << ":" << std::setw(2)
<< timeinfo->tm_sec << "." << std::setw(9) << ns << " "
<< c10::detail::StripBasename(std::string(file)) << ":" << line
<< std::setfill('0') << ' ' << std::setw(2) << timeinfo->tm_hour
<< ':' << std::setw(2) << timeinfo->tm_min << ':' << std::setw(2)
<< timeinfo->tm_sec << '.' << std::setw(9) << ns << ' '
<< c10::detail::StripBasename(std::string(file)) << ':' << line
<< "] ";
}

View File

@ -1412,13 +1412,13 @@ inline size_t capacity_in_bytes(const SmallVector<T, N>& X) {
template <typename T, unsigned N>
std::ostream& operator<<(std::ostream& out, const SmallVector<T, N>& list) {
int i = 0;
out << "[";
out << '[';
for (auto e : list) {
if (i++ > 0)
out << ", ";
out << e;
}
out << "]";
out << ']';
return out;
}

View File

@ -79,7 +79,7 @@ std::ostream& _str(std::ostream& ss, const std::wstring& wString) {
} // namespace detail
std::ostream& operator<<(std::ostream& out, const SourceLocation& loc) {
out << loc.function << " at " << loc.file << ":" << loc.line;
out << loc.function << " at " << loc.file << ':' << loc.line;
return out;
}

View File

@ -170,7 +170,7 @@ inline bool isPrint(char s) {
}
inline void printQuotedString(std::ostream& stmt, const std::string_view str) {
stmt << "\"";
stmt << '"';
for (auto s : str) {
switch (s) {
case '\\':
@ -224,7 +224,7 @@ inline void printQuotedString(std::ostream& stmt, const std::string_view str) {
break;
}
}
stmt << "\"";
stmt << '"';
}
template <typename T>

View File

@ -223,7 +223,7 @@ void FatalSignalHandler::fatalSignalHandler(int signum) {
// a single thread that wouldn't receive the SIGUSR2
if (std::cv_status::timeout == writingCond.wait_for(ul, 2s)) {
if (!signalReceived) {
std::cerr << "signal lost waiting for stacktrace " << pid << ":"
std::cerr << "signal lost waiting for stacktrace " << pid << ':'
<< tid << '\n';
break;
}

View File

@ -877,7 +877,7 @@ std::ostream& operator<<(
std::ostream& stream,
const SparseBitVector<ElementSize>& vec) {
bool first = true;
stream << "{";
stream << '{';
for (auto el : vec) {
if (first) {
first = false;
@ -886,7 +886,7 @@ std::ostream& operator<<(
}
stream << el;
}
stream << "}";
stream << '}';
return stream;
}

View File

@ -734,7 +734,7 @@ void PyTorchStreamWriter::setup(const string& file_name) {
file_name,
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary
);
} catch (const std::ios_base::failure& e) {
} catch (const std::ios_base::failure&) {
#ifdef _WIN32
// Windows have verbose error code, we prefer to use it than std errno.
uint32_t error_code = GetLastError();

View File

@ -118,6 +118,11 @@ if(INTERN_BUILD_ATEN_OPS)
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
endif()
endif()
if("${_arch}" STREQUAL "121a")
if(_existing_arch_flags MATCHES ".*compute_120.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
endif()
endif()
endforeach()
list(JOIN _file_compile_flags " " _file_compile_flags)
@ -126,7 +131,7 @@ if(INTERN_BUILD_ATEN_OPS)
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
"89;90a;100a;103a;120a")
"89;90a;100a;103a;120a;121a")
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
"90a")

View File

@ -0,0 +1,113 @@
# Device Management
## Background
Device management handles basic operations like querying how many devices are available and switching between them. Accelerator backends need to wrap their device runtime's APIs and expose them to PyTorch.
The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend - by streams, events, generators, and Python bindings.
## Design
Accelerator vendors need to implement these core functions:
| Function Name | Description | Application Scenarios |
| ------------------------- | ---------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- |
| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |
These functions are building blocks for more complex features like streams, events, and memory management. Make sure to validate inputs and handle errors properly.
## Implementation
This section shows how to implement device management using `set_device` as an example. The implementation requires:
1. C++ wrappers around the device runtime
2. Python bindings to expose the C++ functions
3. User-friendly Python APIs
### C++ Side
Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
:end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
:end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
:linenos:
```
### Binding
Expose the C++ functions to Python using pybind11:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
:end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
:end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
:linenos:
:emphasize-lines: 5
```
### Python Side
Wrap the C++ bindings with user-friendly Python functions:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
:language: python
:start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
:end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
:linenos:
```
Here's the complete mapping from C++ to Python:
| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
| -------------------- | ---------------------------------------- | -------------------------------- | -------------------------------------------- |
| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
## Guard
Device guards provide automatic device switching with exception safety. They're similar to lock guards in C++ - they switch device on construction and restore it on destruction.
Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
:language: c++
:start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:linenos:
```
**What needs to be implemented:**
1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
2. **getDevice()**: Get the current device
3. **setDevice()**: Set the active device
4. **Type checking**: Validate that device type matches the backend
This makes the guard available to PyTorch for the `PrivateUse1` device type. Users can then use standard PyTorch device guards with the custom backend.
[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"

View File

@ -42,6 +42,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
:glob:
:maxdepth: 1
device
hooks
autoload
operators

View File

@ -30,5 +30,6 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
skip_guard_on_all_nn_modules_unsafe
keep_tensor_guards_unsafe
skip_guard_on_globals_unsafe
skip_all_guards_unsafe
nested_compile_region
```

View File

@ -376,3 +376,19 @@ keep-runtime-typing = true
[tool.codespell]
ignore-words = "tools/linter/dictionary.txt"
[tool.spin]
package = 'torch'
[tool.spin.commands]
"Build" = [
".spin/cmds.py:lint",
".spin/cmds.py:fixlint",
".spin/cmds.py:quicklint",
".spin/cmds.py:quickfix",
]
"Regenerate" = [
".spin/cmds.py:regenerate_version",
".spin/cmds.py:regenerate_type_stubs",
".spin/cmds.py:regenerate_clangtidy_files",
]

View File

@ -14,6 +14,7 @@ lintrunner ; platform_machine != "s390x" and platform_machine != "riscv64"
networkx>=2.5.1
optree>=0.13.0
psutil
spin
sympy>=1.13.3
typing-extensions>=4.13.2
wheel

View File

@ -1358,45 +1358,6 @@ class concat_license_files:
# Need to create the proper LICENSE.txt for the wheel
class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
def _wrap_headers_with_macro(self, bdist_dir: Path) -> None:
"""Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).
Excludes:
- torch/include/torch/headeronly/*
- torch/include/torch/csrc/stable/*
- torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers)
- torch/include/torch/csrc/inductor/aoti_torch/generated/
"""
header_extensions = (".h", ".hpp", ".cuh")
header_files = [
f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}")
]
# Paths to exclude from wrapping
exclude_dir_patterns = [
"torch/include/torch/headeronly/",
"torch/include/torch/csrc/stable/",
"torch/include/torch/csrc/inductor/aoti_torch/c/",
"torch/include/torch/csrc/inductor/aoti_torch/generated/",
]
for header_file in header_files:
rel_path = header_file.relative_to(bdist_dir).as_posix()
if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
report(f"Skipping header: {rel_path}")
continue
original_content = header_file.read_text(encoding="utf-8")
wrapped_content = (
"#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
f"{original_content}"
"\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
)
header_file.write_text(wrapped_content, encoding="utf-8")
report(f"Wrapped header: {rel_path}")
def run(self) -> None:
with concat_license_files(include_files=True):
super().run()
@ -1419,14 +1380,6 @@ class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
# need an __init__.py file otherwise we wouldn't have a package
(bdist_dir / "torch" / "__init__.py").touch()
# Wrap all header files with TORCH_STABLE_ONLY macro
assert self.bdist_dir is not None, "bdist_dir should be set during wheel build"
bdist_dir = Path(self.bdist_dir)
report(
"-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
)
self._wrap_headers_with_macro(bdist_dir)
class clean(Command):
user_options: ClassVar[list[tuple[str, str | None, str]]] = []

View File

@ -13,6 +13,17 @@ TEST(TestScalarType, ScalarTypeToCPPTypeT) {
#undef DEFINE_CHECK
}
TEST(TestScalarType, CppTypeToScalarType) {
using torch::headeronly::CppTypeToScalarType;
using torch::headeronly::ScalarType;
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
EXPECT_EQ(CppTypeToScalarType<TYPE>::value, ScalarType::SCALARTYPE);
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
{ \
EXPECT_EQ( \

View File

@ -634,3 +634,38 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_parallel_for", &boxed_test_parallel_for);
m.impl("test_get_num_threads", &boxed_test_get_num_threads);
}
Tensor my_empty(
torch::headeronly::HeaderOnlyArrayRef<int64_t> size,
std::optional<torch::headeronly::ScalarType> dtype,
std::optional<torch::stable::Device> device,
std::optional<bool> pin_memory) {
return empty(size, dtype, device, pin_memory);
}
Tensor my_flatten(Tensor t, int64_t start_dim, int64_t end_dim) {
return flatten(t, start_dim, end_dim);
}
Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> shape) {
return reshape(t, shape);
}
Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> size) {
return view(t, size);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def(
"my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
m.def("my_flatten(Tensor t, int start_dim=0, int end_dim=-1) -> Tensor");
m.def("my_reshape(Tensor t, int[] shape) -> Tensor");
m.def("my_view(Tensor t, int[] size) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_empty", TORCH_BOX(&my_empty));
m.impl("my_flatten", TORCH_BOX(&my_flatten));
m.impl("my_reshape", TORCH_BOX(&my_reshape));
m.impl("my_view", TORCH_BOX(&my_view));
}

View File

@ -487,3 +487,58 @@ def test_get_num_threads() -> int:
Returns: int - the number of threads for the parallel backend
"""
return torch.ops.libtorch_agnostic.test_get_num_threads.default()
def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor:
"""
Creates an empty tensor with the specified size, dtype, device, and pin_memory.
Args:
size: list[int] - size of the tensor to create
dtype: ScalarType or None - data type of the tensor
device: Device or None - device on which to create the tensor
pin_memory: bool or None - whether to use pinned memory
Returns: Tensor - an uninitialized tensor with the specified properties
"""
return torch.ops.libtorch_agnostic.my_empty.default(size, dtype, device, pin_memory)
def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor:
"""
Flattens the input tensor from start_dim to end_dim into a single dimension.
Args:
t: Tensor - tensor to flatten
start_dim: int - first dimension to flatten (default: 0)
end_dim: int - last dimension to flatten (default: -1)
Returns: Tensor - flattened tensor
"""
return torch.ops.libtorch_agnostic.my_flatten.default(t, start_dim, end_dim)
def my_reshape(t, shape) -> Tensor:
"""
Returns a tensor with the same data but different shape.
Args:
t: Tensor - tensor to reshape
shape: list[int] - new shape for the tensor
Returns: Tensor - reshaped tensor
"""
return torch.ops.libtorch_agnostic.my_reshape.default(t, shape)
def my_view(t, size) -> Tensor:
"""
Returns a new tensor with the same data as the input tensor but of a different shape.
Args:
t: Tensor - tensor to view
size: list[int] - new size for the tensor
Returns: Tensor - tensor with new view
"""
return torch.ops.libtorch_agnostic.my_view.default(t, size)

View File

@ -33,7 +33,7 @@ class clean(distutils.command.clean.clean):
def get_extension():
extra_compile_args = {
"cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
"cxx": ["-fdiagnostics-color=always"],
}
extension = CppExtension

View File

@ -525,6 +525,97 @@ if not IS_WINDOWS:
expected_num_threads = torch.get_num_threads()
self.assertEqual(num_threads, expected_num_threads)
def test_my_empty(self, device):
import libtorch_agnostic
deterministic = torch.are_deterministic_algorithms_enabled()
try:
# set use_deterministic_algorithms to fill uninitialized memory
torch.use_deterministic_algorithms(True)
size = [2, 3]
result = libtorch_agnostic.ops.my_empty(size, None, None, None)
expected = torch.empty(size)
self.assertEqual(result, expected, exact_device=True)
result_float = libtorch_agnostic.ops.my_empty(
size, torch.float32, None, None
)
expected_float = torch.empty(size, dtype=torch.float32)
self.assertEqual(result_float, expected_float, exact_device=True)
result_with_device = libtorch_agnostic.ops.my_empty(
size, torch.float64, device, None
)
expected_with_device = torch.empty(
size, dtype=torch.float64, device=device
)
self.assertEqual(
result_with_device, expected_with_device, exact_device=True
)
if device == "cuda":
result_pinned = libtorch_agnostic.ops.my_empty(
size, torch.float32, "cpu", True
)
expected_pinned = torch.empty(
size, dtype=torch.float32, device="cpu", pin_memory=True
)
self.assertEqual(result_pinned, expected_pinned)
self.assertTrue(result_pinned.is_pinned())
finally:
torch.use_deterministic_algorithms(deterministic)
def test_my_flatten(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_flatten(t)
expected = torch.flatten(t)
self.assertEqual(result, expected)
result_start = libtorch_agnostic.ops.my_flatten(t, 1)
expected_start = torch.flatten(t, 1)
self.assertEqual(result_start, expected_start)
result_range = libtorch_agnostic.ops.my_flatten(t, 2, -1)
expected_range = torch.flatten(t, 2, -1)
self.assertEqual(result_range, expected_range)
def test_my_reshape(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_reshape(t, [6, 4])
expected = torch.reshape(t, [6, 4])
self.assertEqual(result, expected)
result_infer = libtorch_agnostic.ops.my_reshape(t, [-1, 4])
expected_infer = torch.reshape(t, [-1, 4])
self.assertEqual(result_infer, expected_infer)
result_flat = libtorch_agnostic.ops.my_reshape(t, [-1])
expected_flat = torch.reshape(t, [-1])
self.assertEqual(result_flat, expected_flat)
def test_my_view(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_view(t, [6, 4])
expected = t.view([6, 4])
self.assertEqual(result, expected)
result_infer = libtorch_agnostic.ops.my_view(t, [-1, 4])
expected_infer = t.view([-1, 4])
self.assertEqual(result_infer, expected_infer)
result_flat = libtorch_agnostic.ops.my_view(t, [-1])
expected_flat = t.view([-1])
self.assertEqual(result_flat, expected_flat)
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":

View File

@ -4,17 +4,12 @@
#include <c10/util/Exception.h>
void orCheckFail(
const char* func,
const char* file,
uint32_t line,
const char* msg = "");
void orCheckFail(const char* func, const char* file, uint32_t line, const char* msg = "");
#define OPENREG_CHECK(EXPR, ...) \
do { \
const orError_t __err = EXPR; \
if (__err != orSuccess) { \
orCheckFail( \
__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
} \
#define OPENREG_CHECK(EXPR, ...) \
do { \
const orError_t __err = EXPR; \
if (C10_UNLIKELY(__err != orSuccess)) { \
orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
} \
} while (0)

Some files were not shown because too many files have changed in this diff Show More