Compare commits

...

24 Commits

Author SHA1 Message Date
653c0ecf35 WIP add pointwise strategy 2025-11-17 21:18:29 -08:00
057434a442 claude generate tests 2025-11-17 21:03:07 -08:00
9fd0af1c3b Notes on tensor_ops 2025-11-17 21:03:07 -08:00
53305e5379 Support mm via single-dim strategy 2025-11-17 21:03:07 -08:00
ea5f2aceda document things 2025-11-17 21:01:28 -08:00
83557a528f [DTensor] add register_single_dim_strategy
WIP for now, not ready to land

Experimenting with the idea of decomposing op strategies into
 - a core function that proposes a single-mesh-dim strategy
 - automatic expansion to the mesh inside the registration mechanism

Also, we plan to add a 'ShardingPlaceholder' to use for writing
single-mesh-dim strategies in a way that can be expanded at runtime to
any type of sharding discovered in the inputs.

For now, this relies on full enumeration of the single-dim strategy onto
the full mesh, and full enumeration of the combinations of different
sharding placements discovered in the input, but we should be able to
replace this with an algorithm to expand iteratively following the path
of lowest redistribution cost.
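
As a rough illustration of the expansion idea (a toy sketch only, not the actual `register_single_dim_strategy` API, which is still WIP): a single-dim strategy proposes input/output placements for one mesh dimension, and the registration layer enumerates combinations across mesh dimensions.

```python
import itertools

def mm_single_dim_strategy():
    # (self, other) -> output placements along a single mesh dimension
    return [
        ("Shard(0)", "Replicate", "Shard(0)"),
        ("Replicate", "Shard(1)", "Shard(1)"),
        ("Shard(1)", "Shard(0)", "Partial"),
        ("Replicate", "Replicate", "Replicate"),
    ]

def expand_to_mesh(single_dim_strategy, mesh_ndim):
    # Full enumeration over mesh dims, as described above; the eventual goal is
    # to expand iteratively along the lowest-redistribution-cost path instead.
    return list(itertools.product(single_dim_strategy(), repeat=mesh_ndim))

for combo in expand_to_mesh(mm_single_dim_strategy, mesh_ndim=2)[:3]:
    print(combo)
```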

ghstack-source-id: ed6977ea86d849b84d453408109cc4f602019c4d
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167677
2025-11-14 15:27:09 -08:00
54d05a0874 [DTensor] Fix mypy on register_op_strategy
ghstack-source-id: 59ef401df5d190a4c7611a327779ba00ba2a8d7c
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167673
2025-11-14 15:27:09 -08:00
bfddfde50c Add basic spin config and linting commands (#167226)
This PR adds a basic spin configuration to allow for linting. It is designed as a drop-in replacement for the current Makefile based solution, i.e. it sets up and updates lintrunner based on the hashes of certain configuration files.

Lintrunner is called via Uv's `uvx` command, separating its environment from the general development environment in an effort to reduce instances of competing requirements breaking environments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167226
Approved by: https://github.com/atalman, https://github.com/albanD
2025-11-14 15:35:42 +00:00
b6570615f8 [precompile] Integrate AOTI as a backend. (#167338)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167338
Approved by: https://github.com/jamesjwu
2025-11-14 15:33:11 +00:00
226850cc66 [ATen][CUDA] Add sm_121a flag for RowwiseScaledMM (#167734)
This PR adds an sm_121a flag for row-wise scaled matmuls on DGX Spark.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167734
Approved by: https://github.com/eqy, https://github.com/cyyever
2025-11-14 08:44:04 +00:00
f8a2ce3b9a Fix inplace ops on Partial DTensors to preserve aliasing semantics (#164729)
Fixes #163374.

Here is the output from reproducible code:

```
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811]
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] *****************************************
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] *****************************************
  aten::clamp_(dt: f32[][R], None, 2)
    redistribute_input(0, [P] -> [R])
      redistribute_input(t: f32[], [P] -> [R])
        _c10d_functional::all_reduce(t: f32[], sum, 0)
        _c10d_functional::wait_tensor(t: f32[])
    aten::clamp_(t: f32[], None, 2)
    aten::view(t: f32[], [])
(Replicate(),)
tensor(2., device='cuda:0')
```

The behavior now matches what was expected in issue #163374:

Expected behavior (from the issue):
  1. Placement should change from Partial(sum) to Replicate()
  2. Value should be tensor(2.) instead of tensor(144.)

  Actual output from this build:
  1. (Replicate(),) - placement is correct
  2. tensor(2., device='cuda:0') - value is correct

So the inplace operation now properly redistributes the partial DTensor to Replicate before performing the clamp and maintains the correct aliasing semantics. It also produces the expected clamped value.
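
A minimal sketch of the now-working pattern (not the exact repro from the issue), assuming a multi-process NCCL run launched with `torchrun`:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial

# torchrun --nproc-per-node=2 repro.py
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Each rank contributes a local value; the logical value is the sum over ranks.
local = torch.full((), 12.0, device="cuda")
dt = DTensor.from_local(local, mesh, [Partial()])

dt.clamp_(max=2)         # inplace op on a Partial DTensor
print(dt.placements)     # (Replicate(),) after the implicit redistribution
print(dt.full_tensor())  # tensor(2., device='cuda:...')
```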

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164729
Approved by: https://github.com/ezyang
2025-11-14 07:46:35 +00:00
e2c6834584 Revert "deprecate check_is_size and guard_size_oblivious (#167198)"
This reverts commit 50bf1f0b819f0b1cc9acbb0646ac9555bb9d44b9.

Reverted https://github.com/pytorch/pytorch/pull/167198 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167198#issuecomment-3531149912))
2025-11-14 06:46:15 +00:00
0e7235ed73 [xpu][feature] [1/3] add fp8 scaled_mm implementation for XPU (#165978)
This PR implements `scaled_mm` for XPU. It enables the following data types:
1. TensorWise Scaling: `fp8_e4m3` and `fp8_e5m2`
2. RowWise Scaling:  `fp8_e4m3` and `fp8_e5m2`

It leaves BlockWise Scaling to the next PR to keep the review effort smaller.

This is the first PR: it only adds `scaled_mm_xpu` but does not register it. We split this out to reduce the review effort.

Secondly, there is a `scaled_mm_v2` API in #164141 . We will align with it once v1 is cleaned up.
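
For orientation, a hedged sketch of a TensorWise-scaled fp8 matmul through the existing `torch._scaled_mm` entry point (shapes and scale computation are illustrative; on XPU this path only works once the op registration from #166056 is in place):

```python
import torch

device = "xpu"  # "cuda" also exposes _scaled_mm on supported GPUs
M, K, N = 64, 128, 32
a = torch.randn(M, K, device=device)
b = torch.randn(K, N, device=device)

# TensorWise scaling: a single float32 scale per tensor.
fp8 = torch.float8_e4m3fn
scale_a = a.abs().amax().float() / torch.finfo(fp8).max
scale_b = b.abs().amax().float() / torch.finfo(fp8).max

a_fp8 = (a / scale_a).to(fp8)
b_fp8 = (b / scale_b).to(fp8).t().contiguous().t()  # mat2 in column-major layout

out = torch._scaled_mm(a_fp8, b_fp8, scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.bfloat16)
print(out.shape)  # (64, 32)
```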

**Co-author:** @yuchengliu1, @carsonwang

## PR stack:

- -> https://github.com/pytorch/pytorch/pull/165978 : implementation of XPU scaled_mm and oneDNN kernel
- https://github.com/pytorch/pytorch/pull/167518 : implementation of XPU scaled_mm_v2
- https://github.com/pytorch/pytorch/pull/166056 : Op registration

## Test Status:

1. Relies on the changes in https://github.com/intel/torch-xpu-ops/pull/1746/; otherwise the op will fall back to CPU.
2. This PR does not include tests, the tests are enabled in #166056.

## Credit:

This work is based on @yuchengliu1's work in #140972 . We created a new PR to align the API and checks with CUDA, so there will be less porting effort.

## FP8 Task tracker:
We will track all the scaled_mm related tasks in: https://github.com/pytorch/pytorch/issues/167170

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165978
Approved by: https://github.com/liangan1, https://github.com/EikanWang

Co-authored-by: Eikan Wang <eikan.wang@intel.com>
2025-11-14 06:41:18 +00:00
3522e0ce74 Revert "Fix different seq length (#167481)"
This reverts commit c78e64622e62eb93a03a9c3762df3290d6c65362.

Reverted https://github.com/pytorch/pytorch/pull/167481 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167481#issuecomment-3530992724))
2025-11-14 06:05:45 +00:00
50bf1f0b81 deprecate check_is_size and guard_size_oblivious (#167198)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167198
Approved by: https://github.com/bobrenjc93
2025-11-14 05:35:29 +00:00
c78e64622e Fix different seq length (#167481)
Differential Revision: D86685546

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167481
Approved by: https://github.com/eellison
2025-11-14 05:31:29 +00:00
5623628894 [SymmMem] op to get remote tensors (#167779)
To support the use case in https://github.com/pytorch/helion/pull/1122, i.e.
```
@helion.kernel
def foo(
    x: Tensor,
    group_name: str
):
    x_remotes = torch.ops.symm_mem.get_remote_tensors(x, group_name)
    for t in x_remotes:
        ...
```

Helion uses fake tensor to trace a program, thus we cannot use the following code in a Helion function:
```
hdl = rendezvous(tensor)
remote_tensors = tuple(
    hdl.get_remote_tensor(peer, ...) for peer in range(world_size)
)
```
The reason is that when `tensor` is fake, the returned `hdl` is None, thus any subsequent call on it will fail.

This PR wraps the above functionality as an op:
```
lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]")
```
so that things like `hdl` are not exposed to Helion. The op also provides a `meta` implementation so that Helion can trace it without actually running the rendezvous.
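
A hedged sketch of the general pattern (illustrative namespace and world size, not the actual symm_mem registration code): define the op and give it a `meta` implementation so fake-tensor tracing never touches the real rendezvous.

```python
import torch

lib = torch.library.Library("my_symm_mem", "DEF")  # hypothetical namespace
lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]")

def get_remote_tensors_meta(x, group_name):
    world_size = 8  # assumed; a real impl would look this up from the group
    return [torch.empty_like(x) for _ in range(world_size)]

lib.impl("get_remote_tensors", get_remote_tensors_meta, "Meta")
```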

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167779
Approved by: https://github.com/yf225
2025-11-14 05:01:55 +00:00
2aba180114 Always track _local_scalar_dense output in tensorify_python_scalars. (#166573)
We need to track all symbols. We used to skip
`u = item()`
and fail with:
```
 File "/home/lsakka/pytorch10/pytorch/torch/fx/passes/_tensorify_python_scalars.py", line 149, in _sympy_interp
    expr_to_sym_proxy[expr]
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
KeyError: u0
```
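
A minimal sketch of the kind of program that exercises this pass (an unbacked `u0` coming from `.item()` and flowing into float arithmetic), assuming scalar-output capture is enabled:

```python
import torch
import torch._dynamo

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(backend="inductor")
def f(x):
    u = x.sum().item()    # data-dependent scalar -> unbacked symbol (u0)
    return x * (u + 1.0)  # the float use is what tensorify_python_scalars rewrites

print(f(torch.randn(4)))
```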

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166573
Approved by: https://github.com/bobrenjc93
2025-11-14 03:51:43 +00:00
45b2c3d312 [OpenReg][Feat][Docs] Enrich OpenReg device management implementation and add focused documentation (#165897)
## Summary
This PR enriches OpenReg device management codes and adds focused documentation.

## Key Changes
- Introduced device management documentation in `device.md`.
- Updated `OpenRegFunctions.h` and `OpenRegFunctions.cpp` to use `DeviceIndex` and added error handling.
- Implemented `check_device_index` function for validating device indices.
- Enhanced Python bindings in `Module.cpp` for device management.
- Added tests for invalid device index handling in `test_device.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165897
Approved by: https://github.com/fffrog
2025-11-14 03:08:23 +00:00
5b1e112cf9 [Dynamo] Improve-graph-break-skip-logs (#167067)
Fixes #150477

### Summary:

- Added frame information (function name, file, line number) to all graph break/skip messages
- Standardized message format: "torch.compile will skip tracing the frame <name> (<file> line <N>) and fall back to eager. Reason: <reason>"

### Impacts:
module: dynamo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167067
Approved by: https://github.com/williamwen42
2025-11-14 03:06:37 +00:00
5e6ac5c6e1 [Pytorch] Improve conversion to bfloat16 on aarch64/NEON (#166958)
Summary:
Autovectorization of casting to bfloat16_t is broken in clang-[17, 20] and fixed in clang-21.

We are adding vectorized workaround code, which improves conversion speed from the smaller int data types.

We've observed the following performance improvements, when compiling with clang-19 and targeting armv9a+sve2:

before:

uint8->bfloat16_t  ===> 319.433us
int8->bfloat16_t  ===> 320.216us
int16->bfloat16_t  ===> 326.899us
int32->bfloat16_t  ===> 327.925us

after:

uint8->bfloat16_t  ===> 185.189us  -----> 72% higher throughput
int8->bfloat16_t  ===> 169.790us  -----> 89% higher throughput
int16->bfloat16_t  ===> 180.744us  -----> 81% higher throughput
int32->bfloat16_t  ===> 185.129us  -----> 77% higher throughput
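
For a quick external sanity check (a hedged sketch, not the internal `operator_benchmark` setup from the Test Plan below), the same conversions can be timed with `torch.utils.benchmark`:

```python
import torch
from torch.utils.benchmark import Timer

x = torch.randint(0, 255, (1_000_000,), dtype=torch.uint8)
t = Timer(stmt="x.to(torch.bfloat16)", globals={"x": x, "torch": torch})
print(t.timeit(100))  # reports per-run time for the uint8 -> bfloat16 cast on CPU
```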

Test Plan:
Correctness:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Performance:

buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test

Differential Revision: D86207189

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166958
Approved by: https://github.com/mcfi
2025-11-14 02:40:08 +00:00
79317dc7a7 Fix no source name in backward kernel names; Add flex_attention HOP to "original_aten" node meta (#167749)
Fixes #167706

- Add `torch.fx.experimental.proxy_tensor.set_original_aten_op()` around the flex_attention HOP dispatch so we have `original_aten` populated for flex_attention
- Update the usages of `original_aten` to also expect HOP in addition to OpOverload

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167749
Approved by: https://github.com/drisspg
2025-11-14 02:24:22 +00:00
96a4c4b3d1 add device generalization support for distributed tests (#165067)
## MOTIVATION
To generalize Distributed test cases for non-CUDA devices

## CHANGES
- Replaced hard coded device/backends with torch.accelerator.current_accelerator() and dist.get_default_backend_for_device
- Use DistributedTestBase instead of MultiProcessTestCase to use common utilities
- Remove instantiate_device_tests and make use of torch.accelerator.current_accelerator for test/distributed/test_c10d_object_collectives.py
- fix deterministic context issue for non-cuda devices in test/distributed/optim/test_zero_redundancy_optimizer.py
- use torch.accelerator.device_count() for multi-gpu check in torch/testing/_internal/distributed/_tensor/common_dtensor.py
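
A condensed sketch of the substitution pattern used throughout these tests (mirroring the helper lines that appear in the diffs further below):

```python
import torch
import torch.distributed as dist

device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
backend = dist.get_default_backend_for_device(device_type)  # e.g. "nccl" for cuda
num_devices = torch.accelerator.device_count()
print(device_type, backend, num_devices)
```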

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165067
Approved by: https://github.com/guangyey, https://github.com/albanD
2025-11-14 02:21:11 +00:00
05bcfcc5d1 [Profiler] Add Documentation for FunctionEvent (#167688)
Summary:
Adds documentation for EventList, FunctionEvent and FunctionEventAvg.
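
A short usage sketch of the classes being documented (hedged; the exact attributes are whatever the new docs describe): `key_averages()` returns an `EventList` whose entries are `FunctionEventAvg` objects.

```python
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.mm(torch.randn(64, 64), torch.randn(64, 64))

events = prof.key_averages()  # EventList of FunctionEventAvg
for evt in events:
    print(evt.key, evt.cpu_time_total)
print(events.table(sort_by="cpu_time_total", row_limit=5))
```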

Closes https://github.com/pytorch/pytorch/issues/165907

Test Plan: N/A Documentation

Differential Revision: D86913697

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167688
Approved by: https://github.com/sanrise
2025-11-14 02:03:19 +00:00
60 changed files with 2667 additions and 284 deletions

.spin/cmds.py Normal file (330 lines added)
View File

@ -0,0 +1,330 @@
import hashlib
import subprocess
import sys
from pathlib import Path

import click
import spin


def file_digest(file, algorithm: str):
    try:
        return hashlib.file_digest(file, algorithm)
    except AttributeError:
        pass  # Fallback to manual implementation below
    hash = hashlib.new(algorithm)
    while chunk := file.read(8192):
        hash.update(chunk)
    return hash


def _hash_file(file):
    with open(file, "rb") as f:
        hash = file_digest(f, "sha256")
    return hash.hexdigest()


def _hash_files(files):
    hashes = {file: _hash_file(file) for file in files}
    return hashes


def _read_hashes(hash_file: Path):
    if not hash_file.exists():
        return {}
    with hash_file.open("r") as f:
        lines = f.readlines()
    hashes = {}
    for line in lines:
        hash = line[:64]
        file = line[66:].strip()
        hashes[file] = hash
    return hashes


def _updated_hashes(hash_file, files_to_hash):
    old_hashes = _read_hashes(hash_file)
    new_hashes = _hash_files(files_to_hash)
    if new_hashes != old_hashes:
        return new_hashes
    return None


@click.command()
def regenerate_version():
    """Regenerate version.py."""
    cmd = [
        sys.executable,
        "-m",
        "tools.generate_torch_version",
        "--is-debug=false",
    ]
    spin.util.run(cmd)


TYPE_STUBS = [
    (
        "Pytorch type stubs",
        Path(".lintbin/.pytorch-type-stubs.sha256"),
        [
            "aten/src/ATen/native/native_functions.yaml",
            "aten/src/ATen/native/tags.yaml",
            "tools/autograd/deprecated.yaml",
        ],
        [
            sys.executable,
            "-m",
            "tools.pyi.gen_pyi",
            "--native-functions-path",
            "aten/src/ATen/native/native_functions.yaml",
            "--tags-path",
            "aten/src/ATen/native/tags.yaml",
            "--deprecated-functions-path",
            "tools/autograd/deprecated.yaml",
        ],
    ),
    (
        "Datapipes type stubs",
        None,
        [],
        [
            sys.executable,
            "torch/utils/data/datapipes/gen_pyi.py",
        ],
    ),
]


@click.command()
def regenerate_type_stubs():
    """Regenerate type stubs."""
    for name, hash_file, files_to_hash, cmd in TYPE_STUBS:
        if hash_file:
            if hashes := _updated_hashes(hash_file, files_to_hash):
                click.echo(
                    f"Changes detected in type stub files for {name}. Regenerating..."
                )
                spin.util.run(cmd)
                hash_file.parent.mkdir(parents=True, exist_ok=True)
                with hash_file.open("w") as f:
                    for file, hash in hashes.items():
                        f.write(f"{hash}  {file}\n")
                click.echo("Type stubs and hashes updated.")
            else:
                click.echo(f"No changes detected in type stub files for {name}.")
        else:
            click.echo(f"No hash file for {name}. Regenerating...")
            spin.util.run(cmd)
            click.echo("Type stubs regenerated.")


@click.command()
def regenerate_clangtidy_files():
    """Regenerate clang-tidy files."""
    cmd = [
        sys.executable,
        "-m",
        "tools.linter.clang_tidy.generate_build_files",
    ]
    spin.util.run(cmd)


#: These linters are expected to need less than 3s cpu time total
VERY_FAST_LINTERS = {
    "ATEN_CPU_GPU_AGNOSTIC",
    "BAZEL_LINTER",
    "C10_NODISCARD",
    "C10_UNUSED",
    "CALL_ONCE",
    "CMAKE_MINIMUM_REQUIRED",
    "CONTEXT_DECORATOR",
    "COPYRIGHT",
    "CUBINCLUDE",
    "DEPLOY_DETECTION",
    "ERROR_PRONE_ISINSTANCE",
    "EXEC",
    "HEADER_ONLY_LINTER",
    "IMPORT_LINTER",
    "INCLUDE",
    "LINTRUNNER_VERSION",
    "MERGE_CONFLICTLESS_CSV",
    "META_NO_CREATE_UNBACKED",
    "NEWLINE",
    "NOQA",
    "NO_WORKFLOWS_ON_FORK",
    "ONCE_FLAG",
    "PYBIND11_INCLUDE",
    "PYBIND11_SPECIALIZATION",
    "PYPIDEP",
    "PYPROJECT",
    "RAWCUDA",
    "RAWCUDADEVICE",
    "ROOT_LOGGING",
    "TABS",
    "TESTOWNERS",
    "TYPEIGNORE",
    "TYPENOSKIP",
    "WORKFLOWSYNC",
}


#: These linters are expected to take a few seconds, but less than 10s cpu time total
FAST_LINTERS = {
    "CMAKE",
    "DOCSTRING_LINTER",
    "GHA",
    "NATIVEFUNCTIONS",
    "RUFF",
    "SET_LINTER",
    "SHELLCHECK",
    "SPACES",
}


#: These linters are expected to take more than 10s cpu time total;
#: some need more than 1 hour.
SLOW_LINTERS = {
    "ACTIONLINT",
    "CLANGFORMAT",
    "CLANGTIDY",
    "CODESPELL",
    "FLAKE8",
    "GB_REGISTRY",
    "PYFMT",
    "PYREFLY",
    "TEST_DEVICE_BIAS",
    "TEST_HAS_MAIN",
}


ALL_LINTERS = VERY_FAST_LINTERS | FAST_LINTERS | SLOW_LINTERS


LINTRUNNER_CACHE_INFO = (
    Path(".lintbin/.lintrunner.sha256"),
    [
        "requirements.txt",
        "pyproject.toml",
        ".lintrunner.toml",
    ],
)


LINTRUNNER_BASE_CMD = [
    "uvx",
    "--python",
    "3.10",
    "lintrunner@0.12.7",
]


@click.command()
def setup_lint():
    """Set up lintrunner with current CI version."""
    cmd = LINTRUNNER_BASE_CMD + ["init"]
    subprocess.run(cmd, check=True, capture_output=True, text=True)


def _check_linters():
    cmd = LINTRUNNER_BASE_CMD + ["list"]
    ret = spin.util.run(cmd, output=False, stderr=subprocess.PIPE)
    linters = {l.strip() for l in ret.stdout.decode().strip().split("\n")[1:]}

    unknown_linters = linters - ALL_LINTERS
    missing_linters = ALL_LINTERS - linters
    if unknown_linters:
        click.secho(
            f"Unknown linters found; please add them to the correct category "
            f"in .spin/cmds.py: {', '.join(unknown_linters)}",
            fg="yellow",
        )
    if missing_linters:
        click.secho(
            f"Missing linters found; please update the corresponding category "
            f"in .spin/cmds.py: {', '.join(missing_linters)}",
            fg="yellow",
        )
    return unknown_linters, missing_linters


@spin.util.extend_command(
    setup_lint,
    doc=f"""
If configuration has changed, update lintrunner.

Compares the stored old hashes of configuration files with new ones and
performs setup via setup-lint if the hashes have changed.

Hashes are stored in {LINTRUNNER_CACHE_INFO[0]}; the following files are
considered: {", ".join(LINTRUNNER_CACHE_INFO[1])}.
""",
)
@click.pass_context
def lazy_setup_lint(ctx, parent_callback, **kwargs):
    if hashes := _updated_hashes(*LINTRUNNER_CACHE_INFO):
        click.echo(
            "Changes detected in lint configuration files. Setting up linting tools..."
        )
        parent_callback(**kwargs)
        hash_file = LINTRUNNER_CACHE_INFO[0]
        hash_file.parent.mkdir(parents=True, exist_ok=True)
        with hash_file.open("w") as f:
            for file, hash in hashes.items():
                f.write(f"{hash}  {file}\n")
        click.echo("Linting tools set up and hashes updated.")
    else:
        click.echo("No changes detected in lint configuration files. Skipping setup.")
    click.echo("Regenerating version...")
    ctx.invoke(regenerate_version)
    click.echo("Regenerating type stubs...")
    ctx.invoke(regenerate_type_stubs)
    click.echo("Done.")
    _check_linters()


@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def lint(ctx, apply_patches, **kwargs):
    """Lint all files."""
    ctx.invoke(lazy_setup_lint)
    all_files_linters = VERY_FAST_LINTERS | FAST_LINTERS
    changed_files_linters = SLOW_LINTERS
    cmd = LINTRUNNER_BASE_CMD
    if apply_patches:
        cmd += ["--apply-patches"]
    all_files_cmd = cmd + [
        "--take",
        ",".join(all_files_linters),
        "--all-files",
    ]
    spin.util.run(all_files_cmd)
    changed_files_cmd = cmd + [
        "--take",
        ",".join(changed_files_linters),
    ]
    spin.util.run(changed_files_cmd)


@click.command()
@click.pass_context
def fixlint(ctx, **kwargs):
    """Autofix all files."""
    ctx.invoke(lint, apply_patches=True)


@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def quicklint(ctx, apply_patches, **kwargs):
    """Lint changed files."""
    ctx.invoke(lazy_setup_lint)
    cmd = LINTRUNNER_BASE_CMD
    if apply_patches:
        cmd += ["--apply-patches"]
    spin.util.run(cmd)


@click.command()
@click.pass_context
def quickfix(ctx, **kwargs):
    """Autofix changed files."""
    ctx.invoke(quicklint, apply_patches=True)

View File

@ -223,6 +223,62 @@ CONVERT_FROM_BF16_TEMPLATE(double)
CONVERT_FROM_BF16_TEMPLATE(float16_t)
#endif
#ifdef __ARM_FEATURE_BF16
// clang-[17, 20] crashes when autovectorizing static cast to bf16
// Below is a workaround to have some vectorization
// Works decently well for smaller int types
template <typename from_type>
inline void convertToBf16Impl(
const from_type* __restrict src,
c10::BFloat16* __restrict dst,
uint64_t n) {
bfloat16_t* dstPtr = reinterpret_cast<bfloat16_t*>(dst);
uint64_t loopBound = n - (n % 16);
uint64_t i = 0;
for (; i < loopBound; i += 16) {
float32x4_t a, b, c, d;
a[0] = static_cast<float>(src[i]);
a[1] = static_cast<float>(src[i + 1]);
a[2] = static_cast<float>(src[i + 2]);
a[3] = static_cast<float>(src[i + 3]);
b[0] = static_cast<float>(src[i + 4]);
b[1] = static_cast<float>(src[i + 5]);
b[2] = static_cast<float>(src[i + 6]);
b[3] = static_cast<float>(src[i + 7]);
c[0] = static_cast<float>(src[i + 8]);
c[1] = static_cast<float>(src[i + 9]);
c[2] = static_cast<float>(src[i + 10]);
c[3] = static_cast<float>(src[i + 11]);
d[0] = static_cast<float>(src[i + 12]);
d[1] = static_cast<float>(src[i + 13]);
d[2] = static_cast<float>(src[i + 14]);
d[3] = static_cast<float>(src[i + 15]);
vst1q_bf16(dstPtr + i, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(a), b));
vst1q_bf16(dstPtr + i + 8, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(c), d));
}
#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
for (; i < n; i++) {
float a = static_cast<float>(src[i]);
dstPtr[i] = vcvth_bf16_f32(a);
}
}
#define CONVERT_TO_BF16_TEMPLATE(from_type) \
template <> \
inline void convert(const from_type* src, c10::BFloat16* dst, int64_t n) { \
return convertToBf16Impl<from_type>(src, dst, n); \
}
CONVERT_TO_BF16_TEMPLATE(uint8_t)
CONVERT_TO_BF16_TEMPLATE(int8_t)
CONVERT_TO_BF16_TEMPLATE(int16_t)
CONVERT_TO_BF16_TEMPLATE(int32_t)
#endif
inline void convertBoolToBfloat16Impl(
const bool* __restrict src,
c10::BFloat16* __restrict dst,

View File

@ -0,0 +1,342 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/BlasBackend.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/ceil_div.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/xpu/Blas.h>
#include <torch/library.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
namespace at::native {
using at::blas::ScalingType;
using at::blas::SwizzleType;
namespace {
/*
* Scaling Type Determination:
* ---------------------------
* Conditions and corresponding Scaling Types:
*
* - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
* - Returns BlockWise (with additional size checks).
*
* - Else if scale.numel() == 1:
* - Returns TensorWise.
*
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) ==
* 1:
* - Returns RowWise.
*
* - Otherwise:
* - Returns Error.
*/
bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return at::isFloat8Type(t.scalar_type()) &&
scale.scalar_type() == at::kFloat && scale.numel() == 1;
}
bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return (
at::isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat &&
scale.dim() == 2 && scale.size(0) == t.size(0) && scale.size(1) == 1 &&
scale.is_contiguous());
}
bool is_desired_scaling(
const at::Tensor& t,
const at::Tensor& scale,
ScalingType desired_scaling) {
auto result = desired_scaling == ScalingType::TensorWise
? is_tensorwise_scaling(t, scale)
: is_rowwise_scaling(t, scale);
return result;
}
std::pair<ScalingType, ScalingType> get_joint_scaling(
std::initializer_list<std::pair<ScalingType, ScalingType>> options,
const at::Tensor& a,
const at::Tensor& b,
const at::Tensor& scale_a,
const at::Tensor& scale_b) {
for (auto [lhs, rhs] : options) {
if (is_desired_scaling(a, scale_a, lhs) &&
is_desired_scaling(b.t(), scale_b.t(), rhs)) {
return {lhs, rhs};
}
}
TORCH_CHECK(
false,
"Invalid scaling configuration.\n"
"- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
"- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
a.size(0),
", 1) and scale_b should be (1, ",
b.size(1),
"), and both should be contiguous.\n"
"Got a.dtype()=",
a.scalar_type(),
", scale_a.dtype()=",
scale_a.scalar_type(),
", scale_a.size()=",
scale_a.sizes(),
", scale_a.stride()=",
scale_a.strides(),
", ",
"b.dtype()=",
b.scalar_type(),
", scale_b.dtype()=",
scale_b.scalar_type(),
", scale_b.size()=",
scale_b.sizes(),
" and scale_b.stride()=",
scale_b.strides());
}
Tensor& _scaled_gemm(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const ScalingType scaling_choice_a,
const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out,
const std::optional<Tensor>& alpha = std::nullopt) {
// TODO: scale_result and alpha is not defined or used!
std::optional<Tensor> scaled_result = std::nullopt;
at::native::onednn::scaled_matmul(
mat1,
mat2,
out,
scale_a,
scale_b,
scaling_choice_a,
scaling_choice_b,
bias,
scaled_result,
use_fast_accum);
return out;
}
} // namespace
// Computes matrix multiply + bias while applying scaling to input and output
// matrices. Scales are only applicable when matrices are of Float8 type and
// assumed to be equal to 1.0 by default. If output matrix type is 16 or 32-bit
// type, scale_result is not applied. Known limitations:
// - Only works if mat1 is row-major and mat2 is column-major
// - Only works if matrices sizes are divisible by 32
// - If 1-dimensional tensors are used then scale_a should be size =
// mat1.size(0)
// and scale_b should have size = to mat2.size(1)
// Arguments:
// - `mat1`: the first operand of the matrix multiply, can be type
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `mat2`: the second operand of the matrix multiply, can be type
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
// - `out_dtype`: the output dtype, can either be a float8 or a higher
// precision floating point type
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose
// shape/strides/dtype depend on the scaling scheme
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose
// shape/strides/dtype depend on the scaling scheme
// - `scale_result`: a scalar tensor with the scale of the output, only
// utilized if the output is a float8 type
// - `use_fast_accum`: Not applicable for XPU. For now, it should always be
// false.
// - `out`: a reference to the output tensor
Tensor& _scaled_mm_out_xpu(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
// Note: fast_accum is not supported in XPU for now.
TORCH_CHECK(!use_fast_accum, "fast_accum is not supported in XPU for now.");
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0],
"mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0],
"x",
mat1.sizes()[1],
" and ",
mat2.sizes()[0],
"x",
mat2.sizes()[1],
")");
// Check what type of scaling we are doing based on inputs. This list is
// sorted by decreasing priority.
// List of supported datatypes for XPU with oneDNN:
// https://uxlfoundation.github.io/oneDNN/dev_guide_matmul.html#data-types
auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
{
std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
},
mat1,
mat2,
scale_a,
scale_b);
TORCH_CHECK(
!scale_result ||
(scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
"scale_result must be a float scalar");
TORCH_CHECK(
!bias || bias->numel() == mat2.sizes()[1],
"Bias must be size ",
mat2.sizes()[1],
" but got ",
bias->numel());
TORCH_CHECK(
mat1.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat1.sizes()[0],
"x",
mat1.sizes()[1],
").");
TORCH_CHECK(
mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0,
"mat2 shape (",
mat2.sizes()[0],
"x",
mat2.sizes()[1],
") must be divisible by 16");
// Check types
TORCH_CHECK(
!out_dtype || *out_dtype == out.scalar_type(),
"out_dtype must match output matrix type");
TORCH_CHECK(
at::isFloat8Type(mat1.scalar_type()),
"Expected mat1 to be Float8 matrix got ",
mat1.scalar_type());
TORCH_CHECK(
at::isFloat8Type(mat2.scalar_type()),
"Expected mat2 to be Float8 matrix got ",
mat2.scalar_type());
// TODO: oneDNN currently only supports e4m3 with group scales on BMG. It does
// not support 2D scales, only 1D. We need to add more checks there.
if (bias) {
TORCH_CHECK(
bias->scalar_type() == kFloat ||
bias->scalar_type() == c10::ScalarType::BFloat16 ||
bias->scalar_type() == c10::ScalarType::Half,
"Bias must be Float32 or BFloat16 or Half, but got ",
bias->scalar_type());
}
{
auto bias_ = bias.value_or(Tensor());
auto scale_result_ = scale_result.value_or(Tensor());
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{
{out, "out", 0},
{mat1, "mat1", 1},
{mat2, "mat2", 2},
{bias_, "bias", 3},
{scale_a, "scale_a", 4},
{scale_b, "scale_b", 5},
{scale_result_, "scale_result", 6}};
checkAllSameGPU(__func__, targs);
}
// Validation checks have passed; now resize the output to the actual size.
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
// If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
// kernels do not support this case).
if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) {
// `out` was created with `at::empty`. In the case where we are multiplying
// MxK by KxN and K is the zero dim, we need to initialize here to properly
// return a tensor of zeros.
if (mat1_sizes[1] == 0) {
out.zero_();
}
return out;
}
// TODO: Scale_result is not supported by now!!
return _scaled_gemm(
mat1,
mat2,
scale_a,
scale_b,
scaling_choice_a,
scaling_choice_b,
bias,
use_fast_accum,
out);
}
Tensor _scaled_mm_xpu(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
return _scaled_mm_out_xpu(
mat_a,
mat_b,
scale_a,
scale_b,
bias,
scale_result,
out_dtype,
use_fast_accum,
out);
}
} // namespace at::native

View File

@ -1,3 +1,4 @@
#include <ATen/BlasBackend.h>
#include <ATen/Tensor.h>
#include <ATen/core/Tensor.h>
#include <c10/core/ScalarType.h>
@ -8,7 +9,6 @@
#include <oneapi/dnnl/dnnl.hpp>
namespace at::native::onednn {
at::Tensor broadcast_bias2D(
at::Tensor& dst,
at::Tensor& bias,
@ -328,4 +328,236 @@ void quantized_matmul(
result.copy_(dst);
}
// Describes how to configure oneDNN scales for a given role/ScalingType
struct ScaleSpec {
// specifies the way scale values will be applied to an ARG tensor.
int mask;
// specifies how scales are grouped along dimensions where
// multiple scale factors are used.
dnnl::memory::dims groups;
// specifies data type for scale factors.
dnnl::memory::data_type dtype;
// Helper to compute expected number of elements for scale tensors
// arg_type: "src" for SRC (groups pattern {1, X}),
// "wei" for WEIGHTS (groups pattern {X, 1})
int64_t expected_numel(
int64_t outer_dim,
int64_t inner_dim,
const std::string& arg_type) const {
if (groups == dnnl::memory::dims{1, 1})
return 1; // tensorwise scaling
TORCH_CHECK(
arg_type == "src" || arg_type == "wei",
"Expected arg_type to be 'src' or 'wei', but got '",
arg_type,
"'");
// For rowwise: SRC groups={1, K}, WEI groups={K, 1}
TORCH_INTERNAL_ASSERT(
(groups == dnnl::memory::dims{1, inner_dim} ||
groups == dnnl::memory::dims{inner_dim, 1}),
"The groups must be either {1, inner_dim} or {inner_dim, 1}. But got ",
groups,
".");
return outer_dim;
}
// Normalize an incoming scale tensor to contiguous storage and appropriate
// dtype/view
at::Tensor normalize(const at::Tensor& scale) const {
TORCH_INTERNAL_ASSERT(
dtype == dnnl::memory::data_type::f32,
"tensor scale currently must be f32, but got scale dtype: ",
scale.scalar_type());
return scale.to(at::kFloat).contiguous();
}
};
// This function defines how to set scales mask and groups according to:
// https://github.com/uxlfoundation/oneDNN/blob/main/tests/benchdnn/doc/knobs_attr.md#--attr-scales
// The returned value will be used in
// `set_scales(arg, mask, groups, data_type)`.
inline ScaleSpec make_scale_spec(
at::blas::ScalingType scaling_type,
int64_t M,
int64_t K,
int64_t N,
const std::string& arg_type) {
TORCH_CHECK(
arg_type == "src" || arg_type == "wei",
"Expected arg_type to be 'src' or 'wei', but got '",
arg_type,
"'");
TORCH_INTERNAL_ASSERT(
(scaling_type == at::blas::ScalingType::TensorWise ||
scaling_type == at::blas::ScalingType::RowWise),
"Currently only support scaling_type for TensorWise or RowWise");
int64_t dim = K; // Currently only K is used for grouping
bool is_src = (arg_type == "src");
if (scaling_type == at::blas::ScalingType::TensorWise) {
// Scale tensorwise. The same as `--attr-scales=common`.
// mask=0 : scale whole tensor
// groups={1, 1}: indicates that there is only one group for scaling
return {0, {1, 1}, dnnl::memory::data_type::f32};
} else {
// (scaling_type == at::blas::ScalingType::RowWise)
// Scale RowWise. The same as `--attr-scales=per_dim_01`.
// mask={(1 << 0) | (1 << 1)}: Scale on both dim0 and dim1
// SRC: groups={1, K}, WEIGHTS: groups={K, 1}
return {
(1 << 0) | (1 << 1),
is_src ? dnnl::memory::dims{1, dim} : dnnl::memory::dims{dim, 1},
dnnl::memory::data_type::f32};
}
}
sycl::event scaled_matmul(
const Tensor& mat1,
const Tensor& mat2,
Tensor& result,
const Tensor& scale_a,
const Tensor& scale_b,
at::blas::ScalingType scaling_choice_a,
at::blas::ScalingType scaling_choice_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
bool use_fast_accum) {
auto& engine = GpuEngineManager::Instance().get_engine();
auto& stream = GpuStreamManager::Instance().get_stream();
// This function will do the following steps:
// 1. create memory descriptor
// 2. call write_to_dnnl_memory() to actually write memory
// 3. execute
const int64_t M = mat1.size(0);
const int64_t K = mat1.size(1);
const int64_t N = mat2.size(1);
// 1.1 Create memory descriptor
dnnl::memory::desc src_md = get_onednn_md(mat1);
dnnl::memory::desc weights_md = get_onednn_md(mat2);
dnnl::memory::desc dst_md = get_onednn_md(result);
// scale_a and scale_b have already been checked in the `is_desired_scaling()` call.
// So we could directly get their memory desc and set later.
dnnl::memory::desc scale_a_md = get_onednn_md(scale_a);
dnnl::memory::desc scale_b_md = get_onednn_md(scale_b);
dnnl::memory::desc bias_md;
bool with_bias = bias.has_value();
at::Tensor possible_reshaped_bias = bias.value_or(at::Tensor());
if (with_bias) {
if (possible_reshaped_bias.dim() == 1) {
possible_reshaped_bias =
possible_reshaped_bias.reshape({1, possible_reshaped_bias.size(0)});
bias_md = get_onednn_md(possible_reshaped_bias);
} else {
bias_md = get_onednn_md(possible_reshaped_bias);
}
}
// 1.2 Create primitive descriptor and set scales mask
const ScaleSpec src_spec = make_scale_spec(scaling_choice_a, M, K, N, "src");
const ScaleSpec wei_spec = make_scale_spec(scaling_choice_b, M, K, N, "wei");
dnnl::primitive_attr op_attr = dnnl::primitive_attr();
#if ONEDNN_SUPPORT_DETERMINISTIC
if (at::globalContext().deterministicAlgorithms() ||
at::globalContext().deterministicMkldnn())
op_attr.set_deterministic(true);
#endif
std::vector<int64_t> default_groups;
op_attr.set_scales(
DNNL_ARG_SRC, src_spec.mask, src_spec.groups, src_spec.dtype);
op_attr.set_scales(
DNNL_ARG_WEIGHTS, wei_spec.mask, wei_spec.groups, wei_spec.dtype);
// scale_result tensor currently only supports scalar(TensorWise Scaling).
bool with_dst_scale = scale_result && scale_result->defined();
if (with_dst_scale) {
op_attr.set_scales(DNNL_ARG_DST, 0, {1}, dnnl::memory::data_type::f32);
}
op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
// 1.3 Create the matmul primitive descriptor
dnnl::matmul::primitive_desc matmul_pd = with_bias
? dnnl::matmul::primitive_desc(
engine, src_md, weights_md, bias_md, dst_md, op_attr)
: dnnl::matmul::primitive_desc(
engine, src_md, weights_md, dst_md, op_attr);
// 1.4 (Possible) Additional Checks
// TODO: In case a memory desc does not align with the actual tensor,
// we might need to reorder weights, similar to CPU's reorder_if_differ_in()
// call (for example, when the weights differ from matmul_pd.weights_desc()).
// 2. Prepare memory
// Create memory
auto src_usr_m = make_onednn_memory(src_md, engine, mat1.data_ptr());
auto weights_usr_m = make_onednn_memory(weights_md, engine, mat2.data_ptr());
auto dst_usr_m = make_onednn_memory(dst_md, engine, result.data_ptr());
dnnl::memory b_usr_m;
if (with_bias) {
b_usr_m =
make_onednn_memory(bias_md, engine, possible_reshaped_bias.data_ptr());
}
// Prepare runtime scale memories (flat 1-D views) using the specs
auto make_scale_mem_from_spec = [&](const ScaleSpec& spec,
int64_t expected_numel,
const at::Tensor& scale_tensor) {
at::Tensor prepared = spec.normalize(scale_tensor);
TORCH_CHECK(
prepared.numel() == expected_numel,
"Scale buffer length mismatch. Expected ",
expected_numel,
", got ",
prepared.numel());
dnnl::memory::desc scale_md(
{prepared.numel()}, spec.dtype, dnnl::memory::format_tag::x);
return make_onednn_memory(scale_md, engine, prepared.data_ptr());
};
auto scratchpad =
make_onednn_memory(matmul_pd.scratchpad_desc(), engine, nullptr);
// 3. Setup Args for exec
std::unordered_map<int, dnnl::memory> args;
args.insert({DNNL_ARG_SRC, src_usr_m});
args.insert({DNNL_ARG_WEIGHTS, weights_usr_m});
args.insert({DNNL_ARG_DST, dst_usr_m});
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
if (with_bias) {
args.insert({DNNL_ARG_BIAS, b_usr_m});
}
// Attach runtime scales using specs
auto src_sc_mem = make_scale_mem_from_spec(
src_spec, src_spec.expected_numel(M, K, "src"), scale_a);
auto wei_sc_mem = make_scale_mem_from_spec(
wei_spec, wei_spec.expected_numel(N, K, "wei"), scale_b);
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_sc_mem});
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_sc_mem});
if (with_dst_scale) {
// Bind single f32 scalar as DST scale
at::Tensor dst_scale_f32 = scale_result->to(at::kFloat).contiguous();
dnnl::memory::desc dst_sc_md(
{1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
auto dst_sc_mem =
make_onednn_memory(dst_sc_md, engine, dst_scale_f32.data_ptr());
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_mem});
}
dnnl::matmul matmul_p = dnnl::matmul(matmul_pd);
sycl::event matmul_fwd_event =
dnnl::sycl_interop::execute(matmul_p, stream, args);
return matmul_fwd_event;
}
} // namespace at::native::onednn

View File

@ -78,6 +78,10 @@ dnnl::memory::data_type get_onednn_dtype(
return dnnl::memory::data_type::f32;
case at::ScalarType::BFloat16:
return dnnl::memory::data_type::bf16;
case at::ScalarType::Float8_e4m3fn:
return dnnl::memory::data_type::f8_e4m3;
case at::ScalarType::Float8_e5m2:
return dnnl::memory::data_type::f8_e5m2;
default:
if (!allow_undef) {
TORCH_CHECK(

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/BlasBackend.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
@ -202,4 +203,16 @@ void sdpa_backward(
Tensor& grad_query,
Tensor& grad_key,
Tensor& grad_value);
sycl::event scaled_matmul(
const Tensor& mat1,
const Tensor& mat2,
Tensor& result,
const Tensor& scale_a,
const Tensor& scale_b,
at::blas::ScalingType scaling_choice_a,
at::blas::ScalingType scaling_choice_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
bool use_fast_accum);
} // namespace at::native::onednn

View File

@ -118,6 +118,11 @@ if(INTERN_BUILD_ATEN_OPS)
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
endif()
endif()
if("${_arch}" STREQUAL "121a")
if(_existing_arch_flags MATCHES ".*compute_120.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
endif()
endif()
endforeach()
list(JOIN _file_compile_flags " " _file_compile_flags)
@ -126,7 +131,7 @@ if(INTERN_BUILD_ATEN_OPS)
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
"89;90a;100a;103a;120a")
"89;90a;100a;103a;120a;121a")
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
"90a")

View File

@ -0,0 +1,113 @@
# Device Management
## Background
Device management handles basic operations like querying how many devices are available and switching between them. Accelerator backends need to wrap their device runtime's APIs and expose them to PyTorch.
The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend - by streams, events, generators, and Python bindings.
## Design
Accelerator vendors need to implement these core functions:
| Function Name | Description | Application Scenarios |
| ------------------------- | ---------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- |
| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |
These functions are building blocks for more complex features like streams, events, and memory management. Make sure to validate inputs and handle errors properly.
## Implementation
This section shows how to implement device management using `set_device` as an example. The implementation requires:
1. C++ wrappers around the device runtime
2. Python bindings to expose the C++ functions
3. User-friendly Python APIs
### C++ Side
Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
:end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
:end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
:linenos:
```
### Binding
Expose the C++ functions to Python using pybind11:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
:end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
:end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
:linenos:
:emphasize-lines: 5
```
### Python Side
Wrap the C++ bindings with user-friendly Python functions:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
:language: python
:start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
:end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
:linenos:
```
Here's the complete mapping from C++ to Python:
| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
| -------------------- | ---------------------------------------- | -------------------------------- | -------------------------------------------- |
| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
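
A hedged end-to-end usage sketch of the Python APIs in the table above, assuming the OpenReg extension is built and `torch_openreg` is importable:

```python
import torch
import torch_openreg  # registers the backend under PrivateUse1

print(torch.openreg.device_count())    # e.g. 2 in the test setup
torch.openreg.set_device(1)
print(torch.openreg.current_device())  # 1
```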
## Guard
Device guards provide automatic device switching with exception safety. They're similar to lock guards in C++ - they switch device on construction and restore it on destruction.
Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
:language: c++
:start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:linenos:
```
**What needs to be implemented:**
1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
2. **getDevice()**: Get the current device
3. **setDevice()**: Set the active device
4. **Type checking**: Validate that device type matches the backend
This makes the guard available to PyTorch for the `PrivateUse1` device type. Users can then use standard PyTorch device guards with the custom backend.
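
For intuition, guard semantics can be emulated in Python with the bindings from the table above (a hedged illustration only; the real mechanism is the C++ `OpenRegGuardImpl` registered for `PrivateUse1`):

```python
import torch_openreg._C as _C  # assumed binding module

prev = _C._exchange_device(1)  # switch to device 1, remembering the previous one
try:
    pass  # run work on device 1
finally:
    _C._set_device(prev)       # restore on exit, mirroring guard destruction
```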
[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"

View File

@ -42,6 +42,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
:glob:
:maxdepth: 1
device
hooks
autoload
operators

View File

@ -376,3 +376,19 @@ keep-runtime-typing = true
[tool.codespell]
ignore-words = "tools/linter/dictionary.txt"
[tool.spin]
package = 'torch'
[tool.spin.commands]
"Build" = [
".spin/cmds.py:lint",
".spin/cmds.py:fixlint",
".spin/cmds.py:quicklint",
".spin/cmds.py:quickfix",
]
"Regenerate" = [
".spin/cmds.py:regenerate_version",
".spin/cmds.py:regenerate_type_stubs",
".spin/cmds.py:regenerate_clangtidy_files",
]

View File

@ -14,6 +14,7 @@ lintrunner ; platform_machine != "s390x" and platform_machine != "riscv64"
networkx>=2.5.1
optree>=0.13.0
psutil
spin
sympy>=1.13.3
typing-extensions>=4.13.2
wheel

View File

@ -4,17 +4,12 @@
#include <c10/util/Exception.h>
void orCheckFail(
const char* func,
const char* file,
uint32_t line,
const char* msg = "");
void orCheckFail(const char* func, const char* file, uint32_t line, const char* msg = "");
#define OPENREG_CHECK(EXPR, ...) \
do { \
const orError_t __err = EXPR; \
if (__err != orSuccess) { \
orCheckFail( \
__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
} \
#define OPENREG_CHECK(EXPR, ...) \
do { \
const orError_t __err = EXPR; \
if (C10_UNLIKELY(__err != orSuccess)) { \
orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
} \
} while (0)

View File

@ -1,3 +1,4 @@
#include <c10/util/Exception.h>
#include <include/openreg.h>
#include "OpenRegException.h"
@ -9,21 +10,22 @@ orError_t GetDeviceCount(int* dev_count) {
return orGetDeviceCount(dev_count);
}
orError_t GetDevice(c10::DeviceIndex* device) {
orError_t GetDevice(DeviceIndex* device) {
int tmp_device = -1;
auto err = orGetDevice(&tmp_device);
*device = static_cast<c10::DeviceIndex>(tmp_device);
*device = static_cast<DeviceIndex>(tmp_device);
return err;
}
orError_t SetDevice(c10::DeviceIndex device) {
// LITERALINCLUDE START: OPENREG SetDevice FUNCTION
orError_t SetDevice(DeviceIndex device) {
int cur_device = -1;
orGetDevice(&cur_device);
OPENREG_CHECK(orGetDevice(&cur_device));
if (device == cur_device) {
return orSuccess;
}
return orSetDevice(device);
}
// LITERALINCLUDE END: OPENREG SetDevice FUNCTION
int device_count_impl() {
int count = 0;
@ -31,34 +33,37 @@ int device_count_impl() {
return count;
}
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
OPENREG_EXPORT DeviceIndex device_count() noexcept {
// initialize number of devices only once
static int count = []() {
try {
auto result = device_count_impl();
TORCH_CHECK(
result <= std::numeric_limits<c10::DeviceIndex>::max(),
result <= std::numeric_limits<DeviceIndex>::max(),
"Too many devices, DeviceIndex overflowed");
return result;
} catch (const c10::Error& ex) {
} catch (const Error& ex) {
// We don't want to fail, but still log the warning
// msg() returns the message without the stack trace
TORCH_WARN("Device initialization: ", ex.msg());
return 0;
}
}();
return static_cast<c10::DeviceIndex>(count);
return static_cast<DeviceIndex>(count);
}
OPENREG_EXPORT c10::DeviceIndex current_device() {
c10::DeviceIndex cur_device = -1;
GetDevice(&cur_device);
OPENREG_EXPORT DeviceIndex current_device() {
DeviceIndex cur_device = -1;
OPENREG_CHECK(GetDevice(&cur_device));
return cur_device;
}
OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
SetDevice(device);
// LITERALINCLUDE START: OPENREG set_device FUNCTION
OPENREG_EXPORT void set_device(DeviceIndex device) {
check_device_index(device);
OPENREG_CHECK(SetDevice(device));
}
// LITERALINCLUDE END: OPENREG set_device FUNCTION
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
int current_device = -1;
@ -71,4 +76,8 @@ OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
return current_device;
}
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device) {
check_device_index(to_device);
return ExchangeDevice(to_device);
}
} // namespace c10::openreg

View File

@ -9,10 +9,20 @@
namespace c10::openreg {
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
OPENREG_EXPORT c10::DeviceIndex current_device();
OPENREG_EXPORT void set_device(c10::DeviceIndex device);
OPENREG_EXPORT DeviceIndex device_count() noexcept;
OPENREG_EXPORT DeviceIndex current_device();
OPENREG_EXPORT void set_device(DeviceIndex device);
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device);
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);
static inline void check_device_index(int64_t device) {
TORCH_CHECK(device >= 0 && device < c10::openreg::device_count(),
"The device index is out of range. It must be in [0, ",
static_cast<int>(c10::openreg::device_count()),
"), but got ",
static_cast<int>(device),
".");
}
} // namespace c10::openreg

View File

@ -2,6 +2,8 @@
namespace c10::openreg {
// LITERALINCLUDE START: OPENREG GUARD REGISTRATION
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
// LITERALINCLUDE END: OPENREG GUARD REGISTRATION
} // namespace c10::openreg

View File

@ -11,6 +11,7 @@
namespace c10::openreg {
// LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
static constexpr DeviceType static_type = c10::DeviceType::PrivateUse1;
@ -58,6 +59,7 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
set_device(d.index());
}
// LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
/**
* Set the current device to c10::Device, without checking for errors

View File

@ -27,6 +27,10 @@ class TestDevice(TestCase):
self.assertEqual(torch.accelerator.current_device_index(), 1)
self.assertEqual(torch.accelerator.current_device_index(), device)
def test_invalid_device_index(self):
with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
torch.accelerator.set_device_index(2)
if __name__ == "__main__":
run_tests()

View File

@ -34,18 +34,21 @@ static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
}
// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR
// LITERALINCLUDE START: MODULE SET DEVICE HELPER
PyObject* _setDevice(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
auto device = THPUtils_unpackLong(arg);
auto device = THPUtils_unpackDeviceIndex(arg);
torch::utils::device_lazy_init(at::kPrivateUse1);
c10::openreg::set_device(static_cast<c10::DeviceIndex>(device));
c10::openreg::set_device(device);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
// LITERALINCLUDE END: MODULE SET DEVICE HELPER
PyObject* _exchangeDevice(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice");

View File

@ -41,8 +41,13 @@ def current_device():
return torch_openreg._C._get_device()
# LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
def set_device(device) -> None:
return torch_openreg._C._set_device(device)
if device >= 0:
torch_openreg._C._set_device(device)
# LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
def init():

View File

@ -65,6 +65,7 @@ from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
curr_backend = dist.get_default_backend_for_device(device_type)
class SimpleModel(nn.Module):
@ -422,10 +423,10 @@ class TestFullyShard2DStateDict(DTensorTestBase):
@property
def backend(self):
# need to specify gloo backend for testing cpu offload
return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
return f"cpu:gloo,{device_type}:{curr_backend}"
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_fully_shard_tp_2d_set_full_state_dict(self):
dummy_model = SimpleModel().to(device_type)
mesh_2d = init_device_mesh(
@ -514,8 +515,8 @@ class Test2dFSDP1ParallelIntegration(DTensorTestBase):
).to_local()
self.assertEqual(param_m2, param_m1)
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_2d_ddp_integration_functionality(self) -> None:
model, twod_model, dp_pg = self.init_model(self.device_type)
optim = torch.optim.Adam(model.parameters(), lr=3e-5)
@ -566,8 +567,8 @@ class TestNew2dParallelTraining(DTensorTestBase):
p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local()
self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_2d_fsdp_state_enable_extension(self):
mesh_2d = init_device_mesh(
self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
@ -642,18 +643,18 @@ class TestNew2dParallelTraining(DTensorTestBase):
# Ensure all params are still the same after optimizer update.
self._compare_params(model, model_2d)
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_2d_e2e_training_default(self):
self._test_2d_e2e_training()
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_2d_e2e_training_use_orig_params(self):
self._test_2d_e2e_training(use_orig_params=True)
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_2d_e2e_training_not_use_orig_params(self):
# TODO: revisit the input_reshard API to understand why it fails multi-gpu tests.
# self._test_2d_e2e_training(recompute_activation=True)
@ -666,10 +667,10 @@ class TestNew2dParallelStateDict(DTensorTestBase):
@property
def backend(self):
# need to specify gloo backend for testing cpu offload
return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
return f"cpu:gloo,{device_type}:{curr_backend}"
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_fsdp_2d_extension(self):
"""
Test whether _fsdp_extension from the FSDP state has been set correctly.
@ -700,8 +701,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
model_1d_fsdp_state = _get_module_fsdp_state(model_1d)
self.assertEqual(model_1d_fsdp_state._fsdp_extension, None)
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
@parametrize("is_even_sharded_model", [True, False])
def test_2d_state_dict(self, is_even_sharded_model):
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@ -756,8 +757,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
torch.allclose(no_wrap_v, all_gather_two_d_v.to_local()), True
)
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
@parametrize("is_even_sharded_model", [True, False])
def test_2d_load_state_dict(self, is_even_sharded_model):
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@ -811,8 +812,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
self.assertEqual(v1.device_mesh, v2.device_mesh)
self.assertEqual(v1.placements, v2.placements)
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
@parametrize("is_even_sharded_model", [True, False])
def test_2d_optim_state_dict(self, is_even_sharded_model):
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@ -899,9 +900,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
else:
self.assertEqual(new_state, state)
@skip_if_lt_x_gpu(4)
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_fsdp1_tp_2d_set_full_state_dict(self):
"""
This is a workaround for loading a full state dict into an FSDP1+TP 2D model.

View File

@ -29,8 +29,8 @@ from torch.distributed.tensor.parallel import (
parallelize_module,
RowwiseParallel,
)
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
at_least_x_gpu,
MultiProcessTestCase,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
@ -40,7 +40,6 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
skip_but_pass_in_sandcastle_if,
TEST_XPU,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
@ -107,11 +106,9 @@ class ComposabilityTest(MultiProcessTestCase):
def device(self):
return self.rank
@requires_accelerator_dist_backend(["nccl", "xccl"])
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 4+ GPUs"
)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
def test_pp_and_dcp(self):
"""
Test that pipeline parallelism and distributed checkpointing can be used together and
@ -201,11 +198,9 @@ class ComposabilityTest(MultiProcessTestCase):
_dcp_test(self)
@requires_accelerator_dist_backend(["nccl", "xccl"])
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
@parametrize(
"ScheduleClass",
[
@ -355,11 +350,9 @@ class ComposabilityTest(MultiProcessTestCase):
torch.distributed.destroy_process_group()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
@parametrize(
"ScheduleClass",
[
@ -550,11 +543,9 @@ class ComposabilityTest(MultiProcessTestCase):
torch.distributed.destroy_process_group()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
@parametrize(
"ScheduleClass",
[

View File

@ -1,6 +1,5 @@
# Owner(s): ["oncall: distributed"]
import os
import sys
import torch
@ -18,8 +17,8 @@ from torch.distributed.algorithms.ddp_comm_hooks import (
)
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_nccl,
DistributedTestBase,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
@ -30,9 +29,12 @@ if TEST_WITH_DEV_DBG_ASAN:
sys.exit(0)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
def gpus_for_rank(world_size):
visible_devices = list(range(torch.cuda.device_count()))
gpus_per_process = torch.cuda.device_count() // world_size
visible_devices = list(range(torch.accelerator.device_count()))
gpus_per_process = torch.accelerator.device_count() // world_size
gpus_for_rank = []
for rank in range(world_size):
gpus_for_rank.append(
@ -60,27 +62,7 @@ class TestDdpCommHook(nn.Module):
return self.t0(x ** (1 + rank))
class DistributedDataParallelCommHookTest(MultiProcessTestCase):
def setUp(self):
super().setUp()
self._spawn_processes()
def tearDown(self):
try:
os.remove(self.file_name)
except OSError:
pass
def _get_process_group_nccl(self):
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend="nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
return dist.distributed_c10d._get_default_group()
class DistributedDataParallelCommHookTest(DistributedTestBase):
@property
def world_size(self):
return 2
@ -119,14 +101,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
param = next(model.parameters())
return param.grad
@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_allreduce_hook(self):
"""
This unit test verifies that the ``allreduce`` hook registered case gives the same result
as the no-hook-registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -135,14 +117,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)
@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_fp16compress_hook(self):
"""
This unit test verifies that the ``fp16 compress`` hook registered case
gives a result close to the no-hook-registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -151,14 +133,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_quantize_per_tensor_hook(self):
"""
This unit test verifies that the ``quantize per tensor`` hook registered case
gives a result close to the no-hook-registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -167,14 +149,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_quantize_per_channel_hook(self):
"""
This unit test verifies that the ``quantize per channel`` hook registered case
gives a result close to the no-hook-registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -185,14 +167,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_noop_hook(self):
"""
This unit test verifies that the ``noop`` hook registered case, followed by an allreduce,
gives the same result as the no-hook-registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -204,10 +186,10 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)
@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_is_last_hook(self):
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)
def hook(flags, bucket):
flags.append(bucket.is_last())
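A minimal sketch of the test shape these hunks converge on, using only names that appear in this diff (any create_pg behavior beyond what the calls above show is assumed):
# Sketch: a comm-hook test after the DistributedTestBase migration.
import torch
from torch.testing._internal.common_distributed import (
    DistributedTestBase,
    requires_accelerator_dist_backend,
    skip_if_lt_x_gpu,
)

device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"

class SketchCommHookTest(DistributedTestBase):
    @property
    def world_size(self):
        return 2

    @requires_accelerator_dist_backend()
    @skip_if_lt_x_gpu(2)
    def test_hook(self):
        # create_pg replaces the removed FileStore/init_process_group boilerplate.
        process_group = self.create_pg(device_type)
        self.assertIsNotNone(process_group)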

View File

@ -32,7 +32,7 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
class TestStateDictUtils(DTensorTestBase):
@property
def world_size(self):
return min(4, torch.cuda.device_count())
return min(4, torch.accelerator.device_count())
@with_comms
@skip_if_lt_x_gpu(2)
@ -49,7 +49,7 @@ class TestStateDictUtils(DTensorTestBase):
dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
)
self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
self.assertTrue(gathered_state_dict["dtensor"].is_cuda)
self.assertEqual(gathered_state_dict["dtensor"].device.type, self.device_type)
@with_comms
@skip_if_lt_x_gpu(4)
@ -69,14 +69,16 @@ class TestStateDictUtils(DTensorTestBase):
)
if dist.get_rank() in (0, 2):
self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
self.assertFalse(gathered_state_dict["dtensor"].is_cuda)
self.assertNotEqual(
gathered_state_dict["dtensor"].device.type, self.device_type
)
else:
self.assertEqual(gathered_state_dict, {})
@with_comms
@skip_if_lt_x_gpu(4)
def test_cpu_and_ranks_only(self):
device = torch.device("cuda")
device = torch.device(self.device_type)
state_dict = {
"tensor1": torch.arange(10, device=device),
"tensor2": torch.ones(10, device=device),
@ -85,7 +87,7 @@ class TestStateDictUtils(DTensorTestBase):
cpu_state_dict = _offload_state_dict_to_cpu(state_dict, ranks_only=(0, 2))
if dist.get_rank() in (0, 2):
for v in cpu_state_dict.values():
self.assertFalse(v.is_cuda)
self.assertNotEqual(v.device.type, self.device_type)
self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10))
self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10))
else:
@ -109,27 +111,27 @@ class TestStateDictUtils(DTensorTestBase):
for _ in range(10):
tensor, dtensor = create_dtensor()
ltensor.append(tensor)
ltensor.append(torch.ones(10, device=torch.device("cuda")))
ltensor.append(torch.ones(10, device=torch.device(self.device_type)))
ldtensor.append(dtensor)
ldtensor.append(torch.ones(10, device=torch.device("cuda")))
ldtensor.append(torch.ones(10, device=torch.device(self.device_type)))
tensor, dtensor = create_dtensor()
dist_state_dict = {
"local": dtensor,
"list": ldtensor,
"arange": torch.arange(10, device=torch.device("cuda")),
"arange": torch.arange(10, device=torch.device(self.device_type)),
}
state_dict = {
"local": tensor,
"list": ltensor,
"arange": torch.arange(10, device=torch.device("cuda")),
"arange": torch.arange(10, device=torch.device(self.device_type)),
}
self.assertEqual(state_dict, _gather_state_dict(dist_state_dict))
@with_comms
@skip_if_lt_x_gpu(2)
def test_create_cpu_state_dict(self):
device = torch.device("cuda")
device = torch.device(self.device_type)
rank = dist.get_rank()
# Scale tensors based on world size
# to fit in the tensor shards accurately.
@ -149,7 +151,7 @@ class TestStateDictUtils(DTensorTestBase):
metadata=ShardMetadata(
shard_offsets=[5 * rank, 0],
shard_sizes=[5, 10],
placement=f"rank:{rank}/cuda:{rank}",
placement=f"rank:{rank}/{self.device_type}:{rank}",
),
)
],
@ -159,7 +161,7 @@ class TestStateDictUtils(DTensorTestBase):
torch.arange(50 * scale_factor, device=device).reshape(
5 * scale_factor, 10
),
init_device_mesh("cuda", mesh_shape=(self.world_size,)),
init_device_mesh(self.device_type, mesh_shape=(self.world_size,)),
[Shard(0)],
),
"non_tensor_bytes_io": copy.deepcopy(buffer),
@ -245,7 +247,7 @@ class TestStateDictUtils(DTensorTestBase):
even_tensor = torch.randn(self.world_size, 2)
uneven_tensor = torch.randn(1, 2)
mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,))
mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
even_dtensor = distribute_tensor(
torch.randn(self.world_size, 2), mesh, [Shard(0)]
)
@ -273,10 +275,10 @@ class TestStateDictUtils(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(2)
def test_cpu_offload_for_dtensor(self):
device_mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,))
device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
sd = {
"k": DTensor.from_local(
torch.ones(8, 8, device="cuda"), device_mesh, [Shard(0)]
torch.ones(8, 8, device=self.device_type), device_mesh, [Shard(0)]
)
}
cpu_sd = _create_cpu_state_dict(sd)
@ -290,12 +292,12 @@ class TestStateDictUtils(DTensorTestBase):
self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
_copy_state_dict(sd, cpu_sd, non_blocking=True)
torch.cuda.synchronize()
torch.accelerator.synchronize()
self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
sd["k"] += 1
self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
_copy_state_dict(sd, cpu_sd, non_blocking=True)
torch.cuda.synchronize()
torch.accelerator.synchronize()
self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"]))

View File

@ -7,7 +7,7 @@
import copy
import sys
from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
from typing import Any, cast
import numpy as np
@ -40,7 +40,6 @@ from torch.testing._internal.common_distributed import (
skip_if_rocm_multiprocess,
skip_if_win32,
)
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -57,7 +56,17 @@ except ImportError:
HAS_TORCHVISION = False
device_type = str(get_devtype())
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
@contextmanager
def deterministic_algorithms(enabled=True):
prev_state = torch.are_deterministic_algorithms_enabled()
torch.use_deterministic_algorithms(enabled)
try:
yield
finally:
torch.use_deterministic_algorithms(prev_state)
class TestZeroRedundancyOptimizer(DistributedTestBase):
@ -1241,7 +1250,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
enabled=True, deterministic=True, benchmark=False
)
if "cuda" in device
else torch.use_deterministic_algorithms(True)
else deterministic_algorithms(True)
)
with det_ctx:
device_ids = [rank] if requires_ddp_rank(device) else None
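The deterministic_algorithms helper above exists because torch.use_deterministic_algorithms is a plain setter that returns None, so the previous non-CUDA branch had nothing usable for the with statement; the helper also restores the prior global flag on exit. A short usage sketch, assuming the helper defined in the hunk above is in scope:
import torch

# Sketch: determinism is enforced inside the block and the old setting restored afterwards.
prev = torch.are_deterministic_algorithms_enabled()
with deterministic_algorithms(True):
    assert torch.are_deterministic_algorithms_enabled()
assert torch.are_deterministic_algorithms_enabled() == prev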

View File

@ -32,6 +32,7 @@ from torch.distributed.tensor._ops._einsum_strategy import (
)
from torch.distributed.tensor._ops.utils import (
register_op_strategy,
register_single_dim_strategy,
replicate_op_strategy,
)
from torch.distributed.tensor.debug import CommDebugMode
@ -655,5 +656,202 @@ TestStrategyHashingWithLocalTensor = create_local_tensor_test_class(
TestStrategyHashing,
)
class TestSingleDimStrategy(DTensorTestBase):
@with_comms
def test_register_single_dim_strategy_replaces_existing_rule(self):
"""
Test that calling register_single_dim_strategy works and replaces an existing registered rule.
"""
from torch.distributed.tensor._ops._matrix_ops import (
_mm_like_strategy,
gen_single_dim_einsum_strategies,
)
mesh = self.build_device_mesh()
# Create test inputs
lhs_tensor = torch.randn(6, 8)
rhs_tensor = torch.randn(8, 12)
lhs_tensor_meta = extract_tensor_meta(lhs_tensor)
rhs_tensor_meta = extract_tensor_meta(rhs_tensor)
# Test a specific input sharding combination
lhs_placement = (Shard(1),)
rhs_placement = (Shard(0),)
lhs_spec = DTensorSpec(mesh, lhs_placement, lhs_tensor_meta)
rhs_spec = DTensorSpec(mesh, rhs_placement, rhs_tensor_meta)
# Create the OpSchema for mm operation
op_schema = OpSchema(
torch.ops.aten.mm.default,
(
OpStrategy([OpSpec(lhs_spec)]),
OpStrategy([OpSpec(rhs_spec)]),
),
{},
)
# Get the strategies from the old mm_like_strategy (what was used before)
old_style_strategy = _mm_like_strategy("mk,kn->mn", mesh, op_schema)
# Get the strategies from the new register_single_dim_strategy approach
# First, we need to get the single dim strategy function
def mm_single_dim_strategy_func(op_schema: OpSchema):
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
# Now expand it to full strategy using the same logic as register_single_dim_strategy
single_dim_strategies = mm_single_dim_strategy_func(op_schema)
all_mesh_dim_strategies = [single_dim_strategies] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
all_strategies = []
for strategy_comb in strategy_combs:
spec_list = [
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
]
all_strategies.append(
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
)
new_style_strategy = OpStrategy(all_strategies)
# Verify that both strategies produce the same set of shardings
old_strategy_set = {str(strategy) for strategy in old_style_strategy.strategies}
new_strategy_set = {str(strategy) for strategy in new_style_strategy.strategies}
self.assertEqual(
old_strategy_set,
new_strategy_set,
"Old and new strategies should produce the same shardings",
)
# Verify that the registration actually works by checking the propagator
propagator = DTensor._op_dispatcher.sharding_propagator
# Save the original strategy if it exists
original_strategy = None
if torch.ops.aten.mm.default in propagator.op_strategy_funcs:
original_strategy = propagator.op_strategy_funcs[torch.ops.aten.mm.default]
try:
# Register a custom single-dim strategy
@register_single_dim_strategy(torch.ops.aten.mm.default)
def custom_mm_single_dim_strategy(op_schema: OpSchema):
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
# Verify the strategy was registered
self.assertIn(
torch.ops.aten.mm.default,
propagator.op_strategy_funcs,
"Strategy should be registered after calling register_single_dim_strategy",
)
# Verify it replaced any existing rule
registered_func = propagator.op_strategy_funcs[torch.ops.aten.mm.default]
self.assertIsNotNone(
registered_func, "Registered strategy function should not be None"
)
# Test that the registered strategy produces valid output
result_strategy = registered_func(op_schema)
self.assertIsInstance(
result_strategy, OpStrategy, "Result should be an OpStrategy"
)
self.assertGreater(
len(result_strategy.strategies),
0,
"Strategy should contain at least one OpSpec",
)
finally:
# Restore original strategy if it existed
if original_strategy is not None:
propagator.op_strategy_funcs[torch.ops.aten.mm.default] = (
original_strategy
)
else:
if torch.ops.aten.mm.default in propagator.op_strategy_funcs:
del propagator.op_strategy_funcs[torch.ops.aten.mm.default]
# Clear the cache
propagator.propagate_op_sharding.cache.cache_clear()
@with_comms
def test_single_dim_strategy_shardings_match_full_strategy(self):
"""
Verify that the shardings produced by a single-dim strategy match those produced
by the full strategy implementation.
"""
from torch.distributed.tensor._ops._matrix_ops import (
gen_single_dim_einsum_strategies,
)
mesh = self.build_device_mesh()
# Create test inputs
lhs_tensor = torch.randn(6, 8)
rhs_tensor = torch.randn(8, 12)
lhs_tensor_meta = extract_tensor_meta(lhs_tensor)
rhs_tensor_meta = extract_tensor_meta(rhs_tensor)
# Test multiple input sharding combinations
mm_combs = (
(Shard(0), Replicate()),
(Replicate(), Shard(1)),
(Shard(1), Shard(0)),
(Replicate(), Replicate()),
)
for lhs_placement, rhs_placement in mm_combs:
lhs_spec = DTensorSpec(mesh, (lhs_placement,), lhs_tensor_meta)
rhs_spec = DTensorSpec(mesh, (rhs_placement,), rhs_tensor_meta)
op_schema = OpSchema(
torch.ops.aten.mm.default,
(
OpStrategy([OpSpec(lhs_spec)]),
OpStrategy([OpSpec(rhs_spec)]),
),
{},
)
# Get single-dim strategies
single_dim_strategies = gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
# Expand to full strategy (mimicking what register_single_dim_strategy does)
all_mesh_dim_strategies = [single_dim_strategies] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
expanded_strategies = []
for strategy_comb in strategy_combs:
spec_list = [
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
]
expanded_strategies.append(
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
)
# Verify that for the given input shardings, we can find a matching strategy
# with zero redistribute cost
found_zero_cost_strategy = False
for strategy in expanded_strategies:
if strategy.input_specs == (lhs_spec, rhs_spec):
# This strategy should have zero redistribute cost since inputs match
found_zero_cost_strategy = True
# In a real strategy, redistribute costs would be computed
# Here we just verify the structure is correct
self.assertEqual(
len(strategy.input_specs),
2,
"MM should have exactly 2 input specs",
)
self.assertIsNotNone(
strategy.output_specs, "Output spec should not be None"
)
break
self.assertTrue(
found_zero_cost_strategy,
f"Should find a strategy matching input shardings {lhs_placement}, {rhs_placement}",
)
if __name__ == "__main__":
run_tests()
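Condensed, the registration pattern the two tests above exercise looks like the following sketch; the decorator and helper are used exactly as in the test, while the mesh construction is an assumption for illustration:
# Sketch: registering a single-mesh-dim strategy for aten.mm, mirroring the test above.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor._ops._matrix_ops import gen_single_dim_einsum_strategies
from torch.distributed.tensor._ops.utils import register_single_dim_strategy

mesh = init_device_mesh("cpu", (2,))  # assumed 1-D mesh; the tests build theirs via build_device_mesh()

@register_single_dim_strategy(torch.ops.aten.mm.default)
def mm_single_dim_strategy(op_schema):
    # Propose shardings for a single mesh dim; the registration mechanism expands
    # them across all mesh dims, as the test reproduces by hand with itertools.product.
    return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)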

View File

@ -331,6 +331,25 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
self.assertEqual(z.placements, (Replicate(),))
self.assertEqual(z.to_local(), input)
def test_inplace_op_partial_to_replicate(self):
# test that in-place operations that require redistribution raise an error
# to preserve aliasing semantics (issue #163374)
device_mesh = self.build_device_mesh()
input_tensor = torch.tensor(64.0, device=self.device_type)
partial_dt = DTensor.from_local(
input_tensor, device_mesh, placements=(Partial(),)
)
self.assertTrue(partial_dt.placements[0].is_partial())
# Inplace ops that require placement changes (Partial -> Replicate) should error
with self.assertRaisesRegex(
RuntimeError,
"in-place operations that require placement changes are not supported",
):
partial_dt.clamp_(max=10)
if __name__ == "__main__":
run_tests()
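The supported way to get the same result is an explicit out-of-place redistribute followed by the op; a short sketch reusing the names from the test above:
# Sketch: out-of-place alternative to the rejected in-place clamp_ on a Partial DTensor.
replicated = partial_dt.redistribute(device_mesh, placements=[Replicate()])
clamped = replicated.clamp(max=10)  # new DTensor; no aliasing with the Partial input is implied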

View File

@ -24,7 +24,7 @@ from torch.distributed._functional_collectives import (
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_device_type import e4m3_type
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
DistributedTestBase,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
@ -59,12 +59,8 @@ if not dist.is_available():
sys.exit(0)
@requires_accelerator_dist_backend(["nccl", "xccl"])
class TestWithNCCL(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
@requires_accelerator_dist_backend()
class TestWithNCCL(DistributedTestBase):
@property
def world_size(self) -> int:
return 2
@ -78,16 +74,7 @@ class TestWithNCCL(MultiProcessTestCase):
return torch.device(self.rank)
def _init_process_group(self) -> None:
torch.accelerator.set_device_index(self.rank)
store = dist.FileStore(self.file_name, self.world_size)
backend = dist.get_default_backend_for_device(self.device.type)
dist.init_process_group(
backend=backend,
world_size=self.world_size,
rank=self.rank,
store=store,
)
self.create_pg(self.device.type)
torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
@skip_if_lt_x_gpu(2)

View File

@ -11,13 +11,10 @@ if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS
from torch.testing._internal.common_utils import (
run_tests,
skipIfHpu,
TEST_CUDA,
TEST_HPU,
TEST_WITH_DEV_DBG_ASAN,
)
@ -29,16 +26,8 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)
if TEST_HPU:
DEVICE = "hpu"
elif TEST_CUDA:
DEVICE = "cuda"
else:
DEVICE = "cpu"
device_module = torch.get_device_module(DEVICE)
device_count = device_module.device_count()
BACKEND = dist.get_default_backend_for_device(DEVICE)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
device_count = torch.accelerator.device_count()
def with_comms(func=None):
@ -49,11 +38,10 @@ def with_comms(func=None):
@wraps(func)
def wrapper(self, *args, **kwargs):
if DEVICE != "cpu" and device_count < self.world_size:
if device_type != "cpu" and device_count < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
kwargs["device"] = DEVICE
self.pg = self.create_pg(device=DEVICE)
self.pg = self.create_pg(device=device_type)
try:
return func(self, *args, **kwargs)
finally:
@ -64,7 +52,7 @@ def with_comms(func=None):
class TestObjectCollectives(DistributedTestBase):
@with_comms()
def test_all_gather_object(self, device):
def test_all_gather_object(self):
output = [None] * dist.get_world_size()
dist.all_gather_object(object_list=output, obj=self.rank)
@ -72,7 +60,7 @@ class TestObjectCollectives(DistributedTestBase):
self.assertEqual(i, v, f"rank: {self.rank}")
@with_comms()
def test_gather_object(self, device):
def test_gather_object(self):
output = [None] * dist.get_world_size() if self.rank == 0 else None
dist.gather_object(obj=self.rank, object_gather_list=output)
@ -82,7 +70,7 @@ class TestObjectCollectives(DistributedTestBase):
@skipIfHpu
@with_comms()
def test_send_recv_object_list(self, device):
def test_send_recv_object_list(self):
val = 99 if self.rank == 0 else None
object_list = [val] * dist.get_world_size()
if self.rank == 0:
@ -96,7 +84,7 @@ class TestObjectCollectives(DistributedTestBase):
self.assertEqual(None, object_list[0])
@with_comms()
def test_broadcast_object_list(self, device):
def test_broadcast_object_list(self):
val = 99 if self.rank == 0 else None
object_list = [val] * dist.get_world_size()
# TODO test with broadcast_object_list's device argument
@ -105,7 +93,7 @@ class TestObjectCollectives(DistributedTestBase):
self.assertEqual(99, object_list[0])
@with_comms()
def test_scatter_object_list(self, device):
def test_scatter_object_list(self):
input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
output_list = [None]
dist.scatter_object_list(
@ -123,34 +111,30 @@ class TestObjectCollectives(DistributedTestBase):
my_pg = dist.new_group(ranks, use_local_synchronization=True)
return rank, ranks, my_pg
@skipIfHpu
@with_comms()
def test_subpg_scatter_object(self, device):
def test_subpg_scatter_object(self):
rank, ranks, my_pg = self.setup_sub_pg()
out_list = [None]
dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg)
self.assertEqual(rank, out_list[0])
@skipIfHpu
@with_comms()
def test_subpg_all_gather_object(self, device):
def test_subpg_all_gather_object(self):
rank, ranks, my_pg = self.setup_sub_pg()
out_list = [None] * len(ranks)
dist.all_gather_object(out_list, rank, group=my_pg)
self.assertEqual(ranks, out_list)
@skipIfHpu
@with_comms()
def test_subpg_gather_object(self, device):
def test_subpg_gather_object(self):
rank, ranks, my_pg = self.setup_sub_pg()
out_list = [None] * len(ranks) if rank == ranks[0] else None
dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg)
if rank == ranks[0]:
self.assertEqual(ranks, out_list)
@skipIfHpu
@with_comms()
def test_subpg_broadcast_object(self, device):
def test_subpg_broadcast_object(self):
rank, ranks, my_pg = self.setup_sub_pg()
out_list = [None]
if rank == ranks[0]:
@ -159,7 +143,5 @@ class TestObjectCollectives(DistributedTestBase):
self.assertEqual(ranks[0], out_list[0])
devices = ("cpu", "cuda", "hpu")
instantiate_device_type_tests(TestObjectCollectives, globals(), only_for=devices)
if __name__ == "__main__":
run_tests()

View File

@ -29,7 +29,7 @@ from torch.distributed.tensor._collective_utils import (
)
from torch.distributed.tensor.placement_types import _Partial, Shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
from torch.testing._internal.common_utils import run_tests, TEST_HPU, TEST_XPU, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
@ -58,7 +58,7 @@ def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_ran
os.environ["LOCAL_RANK"] = f"{local_rank}"
@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend.")
@unittest.skipIf(TEST_XPU or TEST_HPU, "XPU/HPU does not support gloo backend.")
class DeviceMeshTestGlooBackend(DTensorTestBase):
@property
def backend(self):

View File

@ -208,6 +208,21 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
)
self.assertEqual(y, expected)
def test_get_remote_tensors(self) -> None:
"""
Check that get_remote_tensors returns each peer's symmetric-memory tensor.
"""
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
my_tensor = symm_mem.empty(1, device=self.device).fill_(self.rank)
remote_tensors = torch.ops.symm_mem.get_remote_tensors(my_tensor, group_name)
dist.barrier()
for peer, tensor in enumerate(remote_tensors):
self.assertEqual(tensor, peer)
@skipIfRocm
def test_nvshmem_put(self) -> None:
self._init_device()

View File

@ -1,9 +1,11 @@
# Owner(s): ["module: dynamo"]
import copy
import functools
import inspect
import os
import pickle
import unittest
from contextlib import contextmanager
from unittest.mock import patch
@ -13,13 +15,16 @@ import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.aot_compile import ModelInput, SerializableCallable
from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable
from torch._dynamo.exc import PackageError, Unsupported
from torch._dynamo.package import DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.fx._graph_pickler import GraphPickler
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
TEST_CUDA,
)
MY_LAMBDA = lambda x: x + 1 # noqa: E731
@ -599,6 +604,92 @@ from user code:
actual = compiled_fn(*inputs)
self.assertEqual(expected, actual)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks
def fn(x, y):
return x + y
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph(
fn,
(make_inputs(), {}),
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_module(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks
mod = SimpleLinearModule()
def make_inputs():
return (torch.randn(4, 3),)
compiled_mod = torch._dynamo.aot_compile.aot_compile_module(
mod,
[ModelInput(make_inputs(), {}, [])],
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)
def get_grads(m: torch.nn.Module):
return {name: p.grad for name, p in m.named_parameters()}
original_mod = copy.deepcopy(mod)
test_inputs = make_inputs()
expected = mod(*test_inputs)
expected.sum().backward()
expected_grads = get_grads(mod)
actual = compiled_mod(*test_inputs)
self.assertEqual(expected, actual)
serialized = compiled_mod.serialize()
compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized)
actual = compiled_fn(*test_inputs)
actual.sum().backward()
self.assertEqual(get_grads(original_mod), expected_grads)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_torch_compile(self):
with torch.device("cuda"):
def fn(x, y):
return x + y
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch.compile(
fn, fullgraph=True, options={"use_aoti": True}
).aot_compile((make_inputs(), {}))
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor")
self.assertEqual(expected, actual)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -952,7 +952,9 @@ User code traceback:
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break: skip: from user code at:
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
The graph break occurred in the following user code:
File "test_error_messages.py", line N, in fn
assert x is None
""",
@ -1078,6 +1080,88 @@ from user code:
""",
)
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_skipped_frame_with_verbose_traceback(self, records):
def fn(x):
with GenericCtxMgr():
torch._dynamo.graph_break()
return x + 1
torch.compile(fn, backend="eager")(torch.randn(3))
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
The graph break occurred in the following user code:
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
""",
)
self.assertExpectedInline(
munge_exc(records[0].exc_info[1], suppress_suffix=True, skip=0),
"""\
Graph break under GenericContextWrappingVariable
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
Hint: Move the offending context manager(s) to outside the compiled region.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
from user code:
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
""",
)
@make_logging_test(graph_breaks=True)
def test_skip_frame_in_loop_message(self, records):
def fn(x):
for i in range(2):
with GenericCtxMgr():
if x.sum() > 0:
x = x + 1
return x
torch.compile(fn, backend="eager")(torch.randn(3))
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
The graph break occurred in the following user code:
File "test_error_messages.py", line N, in fn
if x.sum() > 0:
""",
)
@make_logging_test(dynamo=logging.DEBUG)
def test_skip_frame_empty_function_message(self, records):
def empty_fn(x):
pass
torch.compile(empty_fn, backend="eager")(torch.randn(3))
skip_messages = [
r
for r in records
if "intentionally decided to skip the frame" in r.getMessage()
]
self.assertEqual(len(skip_messages), 1)
msg = munge_exc(skip_messages[0].getMessage(), suppress_suffix=True, skip=0)
msg = re.sub(r" (\d+)$", r" N", msg, flags=re.MULTILINE)
self.assertExpectedInline(
msg,
"""\
Skipping frame torch.compile intentionally decided to skip the frame empty_fn (test_error_messages.py line N) and fall back to eager.
Reason: no content in function call empty_fn test_error_messages.py N""",
)
@make_logging_test(graph_breaks=True)
def test_nested_compile_user_frames(self, records):
def fn(x):
@ -1624,6 +1708,110 @@ from user code:
)
class NestedGraphBreakLoggingTests(
LoggingTestCase, torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
):
@make_logging_test(graph_breaks=True)
def test_skipped_frame_with_verbose_traceback_nested(self, records):
global f1, f2, f3
class GenericCtxMgr:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
def f1(x):
with GenericCtxMgr():
torch._dynamo.graph_break()
return x + 1
def f2(x):
return f1(x + 2)
def f3(x):
return f2(x + 3)
torch.compile(f3, backend="eager")(torch.randn(3))
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
Graph break under GenericContextWrappingVariable
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
Hint: Move the offending context manager(s) to outside the compiled region.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
User code traceback:
File "test_error_messages.py", line N, in test_skipped_frame_with_verbose_traceback_nested
torch.compile(f3, backend="eager")(torch.randn(3))
File "test_error_messages.py", line N, in f3
return f2(x + 3)
File "test_error_messages.py", line N, in f2
return f1(x + 2)
File "test_error_messages.py", line N, in f1
torch._dynamo.graph_break()
""",
)
@make_logging_test(graph_breaks=True)
def test_skip_frame_in_loop_message_nested(self, records):
global f1, f2, f3
class GenericCtxMgr:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
def f1(x):
for i in range(2):
with GenericCtxMgr():
if x.sum() > 0:
x = x + 1
return x
def f2(x):
return f1(x + 4)
def f3(x):
return f2(x + 5)
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
Data-dependent branching
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
Hint: Use `torch.cond` to express dynamic control flow.
Developer debug context: attempted to jump with TensorVariable()
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html
User code traceback:
File "test_error_messages.py", line N, in test_skip_frame_in_loop_message_nested
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
File "test_error_messages.py", line N, in f3
return f2(x + 5)
File "test_error_messages.py", line N, in f2
return f1(x + 4)
File "test_error_messages.py", line N, in f1
if x.sum() > 0:
""",
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -14036,6 +14036,44 @@ class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
except Exception as e:
self.fail(f"torch.compile failed with error: {e}")
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_tensorify_track_item_symint(self):
def _random_resize(image: torch.Tensor):
image_metanet = image
default_patch_size = 14
rand_cnn_resolution = (224, 256)
min_nump = rand_cnn_resolution[0] // default_patch_size
max_nump = rand_cnn_resolution[1] // default_patch_size
new_nump = torch.randint(min_nump, max_nump + 1, (1,)).item()
torch._check(new_nump > 0)
torch._check(new_nump * default_patch_size > 1)
image_metanet = F.interpolate(
image_metanet,
size=(new_nump * default_patch_size, new_nump * default_patch_size),
mode="bilinear",
align_corners=True,
)
img_h_new, img_w_new = image_metanet.shape[2:]
return (img_h_new, img_w_new), image_metanet
_random_resize_compiled = torch.compile(fullgraph=True)(_random_resize)
# Test the function
input_tensor = torch.rand(1, 3, 224, 224)
(h, w), output = _random_resize_compiled(input_tensor)
# Verify output properties
self.assertEqual(output.shape[0], 1)
self.assertEqual(output.shape[1], 3)
self.assertEqual(output.shape[2], h)
self.assertEqual(output.shape[3], w)
self.assertTrue(h % 14 == 0)
self.assertTrue(w % 14 == 0)
self.assertTrue(224 <= h <= 256)
self.assertTrue(224 <= w <= 256)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -3249,7 +3249,14 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
V_sliced = V[:, :, :-128]
out_eager = flex_attention(Q, K_sliced, V_sliced)
out_compiled = func(Q, K_sliced, V_sliced)
out_compiled, code = run_and_get_code(func, Q, K_sliced, V_sliced)
# Make sure flex attention kernels have flex_attention in name
FileCheck().check_regex("triton_tem_fused_flex_attention.*").run(code[0])
FileCheck().check_regex("triton_tem_fused_flex_attention_backward.*").run(
code[1]
)
grad = torch.rand_like(out_eager)

View File

@ -2439,6 +2439,35 @@ class _TorchCompileInductorWrapper:
reset_cudagraph_trees()
class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
compiler_name = "aotinductor"
def __init__(self, mode, options, dynamic):
super().__init__(mode, options, dynamic)
self.apply_options({"cpp_wrapper": True})
self.apply_options({"aot_inductor.package": True})
def __call__(self, model_, inputs_):
from contextlib import nullcontext
from unittest import mock
from torch._guards import detect_fake_mode
from torch._inductor.virtualized import V
fake_mode = detect_fake_mode(inputs_)
ctx = (
mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
if fake_mode
else nullcontext()
)
with (
V.set_aot_compilation(True),
ctx,
torch._inductor.config.patch("enable_autograd_for_aot", True),
):
return super().__call__(model_, inputs_)
class _TorchCompileWrapper:
def __init__(self, backend, mode, options, dynamic):
from torch._dynamo.backends.registry import lookup_backend
@ -2672,8 +2701,10 @@ def compile(
backend = bisect_backend
guard_filter_fn = None
use_aoti = False
if options and isinstance(options, dict):
guard_filter_fn = options.pop("guard_filter_fn", None)
use_aoti = options.pop("use_aoti", False)
if torch.compiler.is_exporting():
warnings.warn(
@ -2700,7 +2731,10 @@ def compile(
return export_wrapped_fn
if backend == "inductor":
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
if use_aoti:
backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
else:
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
else:
backend = _TorchCompileWrapper(backend, mode, options, dynamic)
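From the caller's side, the new branch is selected through the compile options, as the CUDA tests earlier in this diff exercise; a short sketch with a hypothetical fn and example_inputs:
# Sketch: routing torch.compile through the AOTInductor wrapper added above.
compiled = torch.compile(fn, fullgraph=True, options={"use_aoti": True})
aot_fn = compiled.aot_compile((example_inputs, {}))  # same call shape as in the tests
out = aot_fn(*example_inputs)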

View File

@ -53,6 +53,7 @@ class CompileArtifacts:
argdefs: Optional[tuple[Any, ...]]
source_info: "SourceInfo"
device_type: str
backend_name: str
system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
def check_compatibility(self) -> None:
@ -166,7 +167,8 @@ class AOTCompiledFunction:
state = pickle.loads(data)
state["bytecode"] = SerializedCode.to_code_object(state["bytecode"])
deserializer, compiled_fn_state = state["compiled_fn"]
state["compiled_fn"] = deserializer(compiled_fn_state)
with torch._inductor.config.patch(enable_autograd_for_aot=True):
state["compiled_fn"] = deserializer(compiled_fn_state)
state["original_code"] = SerializedCode.to_code_object(state["original_code"])
artifacts = CompileArtifacts(**state)
@ -273,6 +275,7 @@ def aot_compile_fullgraph(
argdefs=fn.__defaults__,
source_info=source_info,
device_type=device_type,
backend_name=getattr(backend, "compiler_name", "unknown"),
)
aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)

View File

@ -1870,7 +1870,7 @@ class ConvertFrame:
raise
soft_fail = isinstance(e, Unsupported)
code = frame.f_code
# This is a soft failure, in the sense that the code path reaches here
# when we do not support graph breaks on bytecodes like LOAD_ATTR,
# BUILD_SET, etc. In such cases, we can fall back to eager without
@ -1885,7 +1885,13 @@ class ConvertFrame:
user_stack_formatted = "".join(
traceback.format_list(user_stack)
)
user_stack_trace = f"Graph break: skip: from user code at:\n{user_stack_formatted}"
frame_info = exc.format_frame_info(code)
user_stack_trace = (
"Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.\n"
f"torch.compile will skip tracing the frame {frame_info} and fall back to eager.\n"
"The graph break occurred in the following user code:\n"
f"{user_stack_formatted}"
)
torch._logging.trace_structured(
"artifact",
metadata_fn=lambda: {
@ -1897,6 +1903,7 @@ class ConvertFrame:
graph_break_log.debug(
user_stack_trace,
exc_info=True,
stack_info=config.verbose,
)
if not config.suppress_errors and not soft_fail:

View File

@ -794,6 +794,38 @@ def format_error_msg_verbose(
return msg
def format_frame_info(code: types.CodeType) -> str:
return (
f"{getattr(code, 'co_name', '<unknown>')} "
f"({getattr(code, 'co_filename', '<unknown>')} "
f"line {getattr(code, 'co_firstlineno', 0)})"
)
def format_skip_frame_message(code: Optional[types.CodeType], reason: str) -> str:
if code is not None:
frame_info = format_frame_info(code)
return (
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
f"Reason: {reason}"
)
else:
return (
f"torch.compile intentionally decided to skip the frame and fall back to eager.\n"
f"Reason: {reason}"
)
def format_loop_skip_frame_message(code: types.CodeType, frame_summary: str) -> str:
frame_info = format_frame_info(code)
return (
"Skipping frame because there is a graph break in a for/while loop\n"
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
f"Reason: Skipping frame because there is a graph break in a for/while loop.\n"
f"{frame_summary}"
)
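For reference, the message these helpers yield matches the fixtures asserted earlier; a sketch with a hypothetical function (the file name and line number come from wherever it is defined):
# Sketch: output of format_skip_frame_message for a real code object.
def empty_fn(x):
    pass

msg = format_skip_frame_message(empty_fn.__code__, "no content in function call")
# msg:
# torch.compile intentionally decided to skip the frame empty_fn (<defining file> line <lineno>) and fall back to eager.
# Reason: no content in function call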
def format_error_msg(
exc: Exception,
code: types.CodeType,

View File

@ -94,6 +94,8 @@ from .exc import (
BackendCompilerFailed,
collapse_resume_frames,
format_graph_break_message,
format_loop_skip_frame_message,
format_skip_frame_message,
get_stack_above_dynamo,
ResumePrologueTracingError,
StepUnsupported,
@ -605,9 +607,9 @@ def generic_jump(
)
# compile a partial subgraph prefix then jump into user code
if self.maybe_has_backedge():
msg = (
"Skipping frame because there is a graph break in a for/while loop\n"
f"{self.frame_summary()}"
msg = format_loop_skip_frame_message(
self.f_code,
"".join(traceback.format_list([self.frame_summary()])),
)
log.info(msg)
raise exc.SkipFrame(msg)
@ -883,9 +885,9 @@ def break_graph_if_unsupported(
)
if self.maybe_has_backedge():
msg = (
"Skipping frame because there is a graph break in a for/while loop\n"
f"{self.frame_summary()}"
msg = format_loop_skip_frame_message(
self.f_code,
"".join(traceback.format_list([self.frame_summary()])),
)
log.info(msg)
raise exc.SkipFrame(msg) from excp
@ -4626,8 +4628,9 @@ class InstructionTranslator(InstructionTranslatorBase):
and not self.error_on_graph_break
and not self.is_tracing_resume_prologue
):
raise exc.SkipFrame("because no content in function call")
raise exc.SkipFrame(
format_skip_frame_message(self.f_code, "no content in function call")
)
self.instruction_pointer = None
_step_logger()(
logging.INFO,

View File

@ -2248,12 +2248,15 @@ def skip_frame_if_in_functorch_mode(val: torch.Tensor) -> None:
try:
val.data_ptr() # will throw for functorch tensors
except RuntimeError as e:
from .exc import SkipFrame
from .exc import format_skip_frame_message, SkipFrame
# This will be GradTrackingTensor/BatchedTensor/etc
functorch_subclass_name = re.sub(r"\(.*", "", repr(val))
raise SkipFrame(
f"torch.compile cannot be run in context: {functorch_subclass_name}"
format_skip_frame_message(
None,
f"torch.compile cannot be run in context: {functorch_subclass_name}",
)
) from e

View File

@ -42,6 +42,7 @@ from torch._guards import Source
from .. import config, graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
from ..exc import (
format_skip_frame_message,
get_dynamo_observed_exception,
handle_observed_exception,
InfiniteGeneratorError,
@ -1652,8 +1653,13 @@ class SkipFunctionVariable(VariableTracker):
skip_frame_msg = kwargs.get("msg")
if skip_frame_msg:
skip_frame_msg = skip_frame_msg.as_python_constant()
else:
skip_frame_msg = ""
raise SkipFrame(
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
format_skip_frame_message(
tx.f_code,
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}",
)
)
elif self.value is torch._dynamo.step_unsupported:
raise StepUnsupported

View File

@ -3652,24 +3652,26 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
# - lifted args from tracing subgraph: [score_mod_other_buffers, mask_fn_other_buffers]
_, _, _, inp_arg_block_mask, inp_arg_scale, inp_arg_kernel_options = inp_args
block_mask = tuple(inp_arg_block_mask + (mask_fn_node,))
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=inp_args[:3]
+ (
score_mod_node,
block_mask,
inp_arg_scale,
inp_arg_kernel_options,
score_mod_lifted_args,
mask_fn_lifted_args,
with torch.fx.experimental.proxy_tensor.set_original_aten_op(self.value):
proxy = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=inp_args[:3]
+ (
score_mod_node,
block_mask,
inp_arg_scale,
inp_arg_kernel_options,
score_mod_lifted_args,
mask_fn_lifted_args,
),
kwargs={},
),
kwargs={},
),
example_value=None,
)
example_value=None,
)
return proxy
class AutogradFunctionApplyVariable(VariableTracker):

View File

@ -511,6 +511,7 @@ class GenericAOTAutogradResult(Generic[TForward, TBackward]):
).post_compile(
compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
)
compiled_fw_func._boxed_call = True
disable_amp = torch._C._is_any_autocast_enabled()
if needs_autograd:

View File

@ -356,9 +356,10 @@ def trace_flex_attention(
)
# pyrefly: ignore [missing-attribute]
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", flex_attention, proxy_args, {}
)
with torch.fx.experimental.proxy_tensor.set_original_aten_op(flex_attention):
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", flex_attention, proxy_args, {}
)
return track_tensor_tree(
example_out,
out_proxy,
@ -1114,23 +1115,26 @@ def flex_attention_backward_proxy_torch_dispatch_mode(
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
assert mode is not None, "Mode should always be enabled for python fallback key"
return trace_flex_attention_backward(
mode,
query,
key,
value,
out,
logsumexp,
grad_out,
grad_logsumexp,
fw_graph,
joint_graph,
block_mask,
scale,
kernel_options,
score_mod_other_buffers,
mask_mod_other_buffers,
)
with torch.fx.experimental.proxy_tensor.set_original_aten_op(
flex_attention_backward
):
return trace_flex_attention_backward(
mode,
query,
key,
value,
out,
logsumexp,
grad_out,
grad_logsumexp,
fw_graph,
joint_graph,
block_mask,
scale,
kernel_options,
score_mod_other_buffers,
mask_mod_other_buffers,
)
@flex_attention_backward.py_functionalize_impl

View File

@ -1640,7 +1640,9 @@ class _InProcessFxCompile(FxCompile):
# pyrefly: ignore [unbound-name]
(str, list, torch.fx.GraphModule),
), type(compiled_fn)
return CompiledAOTI(compiled_fn)
return CompiledAOTI(
filename=compiled_fn, device_type=graph.device_type
)
# TODO: Hoist this above V.aot_compilation
# pyrefly: ignore [unbound-name]
@ -2713,7 +2715,7 @@ def _compile_fx_main(
or torch._guards.TracingContext(fake_mode)
)
if V.aot_compilation:
if V.aot_compilation and not config.enable_autograd_for_aot:
from .utils import is_valid_aoti_model_name
is_valid_aoti_model_name()

View File

@ -1190,6 +1190,8 @@ autotune_lookup_table: dict[str, dict[str, Any]] = {}
file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))
enable_autograd_for_aot: bool = False
def get_worker_log_path() -> Optional[str]:
log_loc = None

View File

@ -773,9 +773,86 @@ class CompiledAOTI(OutputCode):
"""
filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
device_type: str
current_callable: Optional[Callable[..., Any]] = None
_cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict)
def __post_init__(self):
if not config.aot_inductor.link_libtorch:
return
if (
torch._inductor.cpp_builder._IS_MACOS
or torch._inductor.cpp_builder._IS_WINDOWS
):
return
if config.aot_inductor.cross_target_platform == "windows":
return
if config.aot_inductor.package_cpp_only:
return
if not config.enable_autograd_for_aot:
return
if isinstance(self.filename, list):
current_callable = next(
fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so")
)
else:
current_callable = self.filename
if isinstance(current_callable, torch.fx.GraphModule):
self.current_callable = current_callable
return
if self.device_type.startswith("cuda"):
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCuda( # type: ignore[call-arg]
current_callable,
1,
self.device_type,
"",
True,
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
elif self.device_type == "cpu":
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCpu( # type: ignore[call-arg]
current_callable, 1
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
else:
raise RuntimeError(f"unsupported device type {self.device_type}")
self.current_callable = current_callable
self._boxed_call = True
for file in self._cached_files:
if not os.path.exists(file):
with open(file, "wb") as f:
f.write(self._cached_files[file])
def __call__(self, inputs: Sequence[Any]) -> Any:
raise NotImplementedError("NYI")
if self.current_callable is None:
raise RuntimeError("AOTInductor compiled so is not loaded")
return self.current_callable(inputs)
def prepare_for_serialization(self) -> None:
self.current_callable = None
self._cached_files = {}
filenames: list[str] = []
if isinstance(self.filename, list):
filenames = self.filename # type: ignore[assignment]
elif isinstance(self.filename, str):
filenames = [self.filename]
for name in filenames:
with open(name, "rb") as f:
self._cached_files[name] = f.read()
def __getstate__(self):
state = self.__dict__.copy()
state["current_callable"] = None
return state
def post_compile(
self,
@ -783,10 +860,8 @@ class CompiledAOTI(OutputCode):
constants: CompiledFxGraphConstants,
graph_kwargs: _CompileFxKwargs,
) -> None:
pass
def prepare_for_serialization(self) -> None:
pass
if self.current_callable is None:
self.__post_init__()
def set_triton_bundle(self, triton_bundle: Any) -> None:
pass

View File

@ -781,9 +781,19 @@ def get_fused_kernel_name(
) -> str:
all_origins = aggregate_origins(node_schedule)
if descriptive_names == "original_aten":
def get_origin_meta_str(origin):
original_aten = origin.meta["original_aten"]
key = ""
if isinstance(original_aten, torch._ops.OpOverload):
key = original_aten._overloadpacket.__name__
elif isinstance(original_aten, torch._ops.HigherOrderOperator):
key = str(original_aten.name())
return key
# Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
sources = [
origin.meta["original_aten"]._overloadpacket.__name__
get_origin_meta_str(origin)
for origin in all_origins
if origin.op == "call_function"
and "original_aten" in origin.meta
@ -794,12 +804,22 @@ def get_fused_kernel_name(
# Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
sources = []
for origin in all_origins:
if origin.op == "call_function" and "source_fn_stack" in origin.meta:
source_fn = origin.meta["source_fn_stack"][-1]
if origin.op == "call_function":
source_fn = None
suffix = ""
if "source_fn_stack" in origin.meta:
source_fn = origin.meta["source_fn_stack"][-1]
elif "fwd_source_fn_stack" in origin.meta:
# backward nodes have "fwd_source_fn_stack" instead
source_fn = origin.meta["fwd_source_fn_stack"][-1]
suffix = "backward"
if not source_fn:
continue
if isinstance(source_fn[1], str):
sources.append(source_fn[1])
sources.append(source_fn[1] + suffix)
else:
sources.append(source_fn[1].__name__)
sources.append(source_fn[1].__name__ + suffix)
sources = sorted(OrderedSet(sources))
elif descriptive_names == "inductor_node":
sources = [
@ -852,11 +872,20 @@ def get_kernel_metadata(
for node in inductor_nodes:
if "original_aten" in node.meta and node.meta["original_aten"] is not None:
key = str(node.meta["original_aten"]._overloadpacket)
original_aten_dict[key].append(node.name)
original_aten = node.meta["original_aten"]
key = None
if isinstance(original_aten, torch._ops.OpOverload):
key = str(original_aten._overloadpacket)
elif isinstance(original_aten, torch._ops.HigherOrderOperator):
key = str(original_aten.name())
if key:
original_aten_dict[key].append(node.name)
if "from_node" in node.meta:
key = node.meta["from_node"][0].name
from_node_dict[key].append(node.name)
elif node.meta.get("partitioner_tag") == "is_backward":
# backward nodes currently don't have a "from node"
from_node_dict[node.name].append(node.name)
sort_str = "Topologically Sorted" if single_graph is not None else "Unsorted"
metadata = (
f"{wrapper.comment} {sort_str} Source Nodes: [{', '.join(from_node_dict.keys())}], "

View File

@ -891,10 +891,14 @@ class TorchLogsFormatter(logging.Formatter):
# exception handling - copied from logging.Formatter.format
s = record.message
if record.exc_info:
from torch._dynamo import config
should_format_exc = config.verbose or artifact_name != "graph_breaks"
# Cache the traceback text to avoid converting it multiple times
# (it's constant anyway)
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if should_format_exc:
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
if s[-1:] != "\n":
s = s + "\n"

View File

@ -24,7 +24,73 @@ __all__ = [
class EventList(list):
"""A list of Events (for pretty printing)."""
"""A list of profiling events with helper methods for analysis and visualization.
EventList extends the standard Python list to provide specialized methods for
working with profiling events (FunctionEvent or FunctionEventAvg objects).
It includes utilities for aggregating statistics, formatting output tables,
and exporting profiling data.
This class is typically returned by profiler methods and should not be
instantiated directly by users.
Args:
*args: Standard list arguments.
use_device (str, optional): Device type for profiling ("cuda", "xpu", etc.).
profile_memory (bool, optional): Whether memory profiling was enabled. Default: False.
with_flops (bool, optional): Whether to include FLOP counts. Default: False.
Attributes:
_use_device (str): Device type being profiled.
_profile_memory (bool): Whether memory profiling is enabled.
_with_flops (bool): Whether FLOP counting is enabled.
_tree_built (bool): Whether the event tree structure has been built.
Key Methods:
table(...): Format events as a table string for display.
export_chrome_trace(path): Export to Chrome tracing format.
export_stacks(path, metric): Export stack traces with metrics.
key_averages(...): Compute averaged statistics grouped by operation name.
total_average(): Compute aggregate totals across all events (sums, not averages).
Properties:
self_cpu_time_total: Sum of self CPU time across all events.
Example::
import torch
from torch.profiler import profile, ProfilerActivity
with profile(activities=[ProfilerActivity.CPU]) as prof:
x = torch.randn(100, 100)
y = torch.matmul(x, x)
# EventList is returned by prof.events()
events = prof.events()
# Display as formatted table
print(
events.table(
sort_by="cpu_time_total", row_limit=20, top_level_events_only=False
)
)
# Export to Chrome tracing format
events.export_chrome_trace("trace.json")
# Get averaged statistics
avg_events = events.key_averages()
print(avg_events.table())
# Export stack traces
events.export_stacks("stacks.txt", "self_cpu_time_total")
See Also:
- :class:`FunctionEvent`: Individual profiling event
- :class:`FunctionEventAvg`: Averaged profiling statistics
- :meth:`table`: Format events as a readable table
- :meth:`key_averages`: Aggregate events by operation name
"""
def __init__(self, *args, **kwargs):
use_device = kwargs.pop("use_device", None)
@ -373,10 +439,23 @@ class EventList(list):
return avg_list
def total_average(self):
"""Averages all events.
"""Compute aggregate statistics across all events.
Accumulates statistics from all events into a single FunctionEventAvg object.
This is primarily useful for computing total metrics (total CPU time, total
memory usage, etc.) across the entire profiling session, regardless of
operation type.
Note:
This sums up times and counts across ALL different operations, so the
"average" metrics (like cpu_time) represent the average time per operation
call across the entire session, mixing all operation types together.
For per-operation averages, use :meth:`key_averages` instead.
Returns:
A FunctionEventAvg object.
FunctionEventAvg: A single aggregate object with key="Total" containing
accumulated statistics.
"""
total_stat = FunctionEventAvg()
for evt in self:
@ -471,7 +550,64 @@ Kernel = namedtuple("Kernel", ["name", "device", "duration"])
class FunctionEvent(FormattedTimesMixin):
"""Profiling information about a single function."""
"""Profiling information about a single function.
FunctionEvent records the execution of a single operation during profiling.
These events are obtained from the profiler/kineto and contain detailed
timing and memory usage information.
.. note::
FunctionEvent objects are typically created by the profiler/kineto and should not
be instantiated directly by users. Access them through the profiler's output.
Attributes:
id (int): Unique identifier for this event.
node_id (int): Node identifier for distributed profiling (-1 if not applicable).
name (str): Name of the profiled function/operator.
overload_name (str): Overload name for the operator (requires _ExperimentalConfig(capture_overload_names=True) to be set).
trace_name (str): Same as name, just changes ProfilerStep* to ProfilerStep#
time_range (Interval): Time interval containing start and end timestamps in microseconds.
thread (int): Thread ID where the operation started.
fwd_thread (int): Thread ID of the corresponding forward operation.
kernels (List[Kernel]): List of device kernels launched by this operation.
count (int): Number of times this event was called (usually 1).
cpu_children (List[FunctionEvent]): Direct CPU child operations.
cpu_parent (FunctionEvent): Direct CPU parent operation.
input_shapes (Tuple[int, ...]): Shapes of input tensors (requires record_shapes=True).
concrete_inputs (List[Any]): Concrete input values (requires record_shapes=True).
kwinputs (Dict[str, Any]): Keyword arguments (requires record_shapes=True).
stack (List[str]): Python stack trace where the operation was called (requires with_stack=True).
scope (int): at::RecordScope identifier (0=forward, 1=backward, etc.).
use_device (str): Device type being profiled ("cuda", "xpu", etc.).
cpu_memory_usage (int): CPU memory allocated in bytes.
device_memory_usage (int): Device memory allocated in bytes.
is_async (bool): Whether this is an asynchronous operation.
is_remote (bool): Whether this operation occurred on a remote node.
sequence_nr (int): Sequence number for autograd operations.
device_type (DeviceType): Type of device (CPU, CUDA, XPU, PrivateUse1, etc.).
device_index (int): Index of the device (e.g., GPU 0, 1, 2).
device_resource_id (int): Resource ID on the device (i.e. stream ID).
is_legacy (bool): Whether this is from the legacy profiler.
flops (int): Estimated floating point operations.
is_user_annotation (bool): Whether this is a user-annotated region.
metadata_json (str): Additional metadata in JSON format.
Properties:
cpu_time_total (float): Total CPU time in microseconds.
device_time_total (float): Total device (CUDA/XPU/etc) time in microseconds.
self_cpu_time_total (float): CPU time excluding child operations.
self_device_time_total (float): Device time excluding child operations.
self_cpu_memory_usage (int): CPU memory usage excluding child operations.
self_device_memory_usage (int): Device memory usage excluding child operations.
cpu_time (float): Average CPU time per call.
device_time (float): Average device time per call.
key (str): Key used for grouping events (usually same as name).
See Also:
- :class:`torch.profiler.profile`: Context manager for profiling
- :class:`EventList`: List container for FunctionEvent objects with helper methods
- :class:`FunctionEventAvg`: Averaged statistics over multiple FunctionEvent objects
"""
def __init__(
self,
@ -701,7 +837,50 @@ class FunctionEvent(FormattedTimesMixin):
class FunctionEventAvg(FormattedTimesMixin):
"""Used to average stats over multiple FunctionEvent objects."""
"""Averaged profiling statistics over multiple FunctionEvent objects.
FunctionEventAvg aggregates statistics from multiple FunctionEvent objects
with the same key (typically same operation name). This is useful for getting
average performance metrics across multiple invocations of the same operation.
This class is typically created by calling :meth:`EventList.key_averages()` on
a profiler's event list.
Attributes:
key (str): Grouping key for the events (typically operation name).
count (int): Total number of events aggregated.
node_id (int): Node identifier for distributed profiling (-1 if not applicable).
is_async (bool): Whether the operations are asynchronous.
is_remote (bool): Whether the operations occurred on a remote node.
use_device (str): Device type being profiled ("cuda", "xpu", etc.).
cpu_time_total (int): Accumulated total CPU time in microseconds.
device_time_total (int): Accumulated total device time in microseconds.
self_cpu_time_total (int): Accumulated self CPU time (excluding children) in microseconds.
self_device_time_total (int): Accumulated self device time (excluding children) in microseconds.
input_shapes (List[List[int]]): Input tensor shapes (requires record_shapes=True).
overload_name (str): Operator overload name (requires _ExperimentalConfig(capture_overload_names=True) to be set).
stack (List[str]): Python stack trace where the operation was called (requires with_stack=True).
scope (int): at::RecordScope identifier (0=forward, 1=backward, etc.).
cpu_memory_usage (int): Accumulated CPU memory usage in bytes.
device_memory_usage (int): Accumulated device memory usage in bytes.
self_cpu_memory_usage (int): Accumulated self CPU memory usage in bytes.
self_device_memory_usage (int): Accumulated self device memory usage in bytes.
cpu_children (List[FunctionEvent]): CPU child events.
cpu_parent (FunctionEvent): CPU parent event.
device_type (DeviceType): Type of device (CPU, CUDA, XPU, PrivateUse1, etc.).
is_legacy (bool): Whether from legacy profiler.
flops (int): Total floating point operations.
is_user_annotation (bool): Whether this is a user-annotated region.
Properties:
cpu_time (float): Average CPU time per invocation.
device_time (float): Average device time per invocation.
See Also:
- :class:`EventList.key_averages`: Method that creates FunctionEventAvg objects
- :class:`FunctionEvent`: Individual profiling event
- :class:`EventList`: Container for profiling events
"""
def __init__(self) -> None:
self.key: Optional[str] = None

View File

@ -66,6 +66,12 @@ void initAOTIRunnerBindings(PyObject* module) {
int,
const std::string&,
const std::string&>())
.def(py::init<
const std::string&,
int,
const std::string&,
const std::string&,
const bool>())
.def(
"run",
&AOTIModelContainerRunnerCuda::run,

View File

@ -465,6 +465,39 @@ lib.define(
"_low_contention_reduce_scatter(Tensor tensor, str reduce_op, str group_name) -> Tensor"
)
lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]")
"""
Given a local tensor and a group name, return a tuple of tensors that are
symmetric on other devices. The returned tensors are ordered by rank IDs. The
length of the tuple equals the size of the group.
Note: this API works only when `world_within_direct_access()` returns True, i.e.
only when the group is within an NVLink domain or similar. It does not work across
network interfaces.
"""
@torch.library.impl(lib, "get_remote_tensors", "CUDA")
def _get_remote_tensors_default(
local: torch.Tensor, group_name: str
) -> tuple[torch.Tensor, ...]:
hdl = rendezvous(local, group_name)
if hdl is None:
raise ValueError("Tensor is not allocated from Symmetric Memory")
return tuple(
hdl.get_remote_tensor(peer, local.size(), local.dtype)
for peer in range(hdl.world_size)
)
@torch.library.impl(lib, "get_remote_tensors", "Meta")
def _get_remote_tensors_meta(
local: torch.Tensor, group_name: str
) -> tuple[torch.Tensor, ...]:
group = c10d._resolve_process_group(group_name)
return tuple(torch.empty_like(local) for _ in range(group.size()))
class _ScaleMode(Enum):
UNSCALED = "unscaled"

View File

@ -337,19 +337,34 @@ class OpDispatcher:
if is_inplace_op:
# inplace op should return self instead of re-wrapping
if output_sharding.output_spec is not None:
output_spec = output_sharding.output_spec
assert isinstance(output_spec, DTensorSpec)
assert isinstance(args[0], dtensor.DTensor)
# NOTE: aten.squeeze_.dim is an inplace op, but it may also change
# the inplace argument's tensor meta. We choose to special-case
# this op because, as far as we know, it is the only inplace op with
# such behavior. We can extend this special case if necessary.
if op_call == aten.squeeze_.dim:
output_spec = output_sharding.output_spec
assert isinstance(output_spec, DTensorSpec)
assert isinstance(args[0], dtensor.DTensor)
# update the spec to handle tensor meta changes
args[0]._spec = output_spec
# use return_and_correct_aliasing to match the outer and the inner
# aliasing. See https://github.com/pytorch/pytorch/pull/158954
return return_and_correct_aliasing(op_call, args, kwargs, args[0])
else:
# For all other inplace ops, check if placement changes are required
# Inplace operations that change placement are not supported because
# they would require redistribution, which breaks aliasing semantics.
# If there are views into the tensor, the views would not be updated.
if args[0]._spec.placements != output_spec.placements:
raise RuntimeError(
f"{op_call}: in-place operations that require placement changes "
f"are not supported. The operation would change placement from "
f"{args[0]._spec.placements} to {output_spec.placements}, "
f"which requires redistribution and breaks aliasing semantics. "
f"Please use the out-of-place version of this operation instead."
)
# Most inplace ops don't change tensor meta, so no spec update needed
return args[0]
else:
return None

View File

@ -23,6 +23,7 @@ from torch.distributed.tensor._ops.utils import (
map_placements_after_broadcast,
prod,
register_op_strategy,
register_single_dim_strategy,
)
from torch.distributed.tensor._utils import (
compute_local_shape_and_global_offset,
@ -237,10 +238,119 @@ def dot_strategy(op_schema: OpSchema) -> OpStrategy:
return _mm_like_strategy("i,i->", mesh, op_schema)
@register_op_strategy(aten.mm.default)
def mm_strategy(op_schema: OpSchema) -> OpStrategy:
# @register_op_strategy(aten.mm.default)
# def mm_strategy(op_schema: OpSchema) -> OpStrategy:
# mesh = op_schema.get_mesh_from_args()
# return _mm_like_strategy("mk,kn->mn", mesh, op_schema)
from ._einsum_strategy import EinsumDims
def gen_single_dim_einsum_strategies(
equation: str,
mesh: DeviceMesh,
*,
linearity: bool = False,
) -> list[list[Placement]]:
"""
Generate a strategy list for the ops that follow einsum style notation.
In principle, each mesh dim is independent of the other device mesh dims when we
generate strategies, so we generate the strategies over each device mesh dim and
take the cartesian product across all mesh dims. For each device mesh dim we
follow the rules below:
1. Shard on contracting dim: when both inputs shard the contracting dim over
the same device dim, the result is Partial over that device dim.
2. Shard on a noncontracting dim:
2.1: Shard on a batch dim: the output and both inputs should all shard on the
batch dim.
2.2: Shard on an lhs-only dim or rhs-only dim: the output and the corresponding
lhs or rhs input should shard on this free dim.
3. Linearity (Partial): If enabled, set Partial on output and inputs over
the same device mesh dim.
"""
# parse einop equation and extract dims
input_dims, output_dim = EinsumDims.parse_equation(equation)
edims = EinsumDims.parse_dims(input_dims, output_dim)
# generate strategies for each mesh dim and do cartesian product for final strategy. E.g., for a 2D mesh, we can have [P(),R,R]
strategies_over_one_mesh_dim = []
# placement list stores placements of [output, input1, input2, ...]
# first we always have replicate all for inputs and output
placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1)
strategies_over_one_mesh_dim.append(placement_list)
# split batch dim
for batch_dim in edims.batch_dims:
output_batch_dim = output_dim.index(batch_dim)
placement_list = [Shard(output_batch_dim)]
for input_dim in input_dims:
input_batch_dim = input_dim.index(batch_dim)
placement_list.append(Shard(input_batch_dim))
strategies_over_one_mesh_dim.append(placement_list)
# split contracting dim
for contracting_dim in edims.contracting_dims:
# Contracting dim can shard on same device axis for both inputs. This
# results in the output being Partial on that device axis. For example:
# bmk_{x},k_{x}n -> bmn{Ux} (becomes partial over device axis x)
placement_list = [Partial()]
for input_dim in input_dims:
input_contracting_dim = input_dim.index(contracting_dim)
placement_list.append(Shard(input_contracting_dim))
strategies_over_one_mesh_dim.append(placement_list)
# split lhs free dim
for lhs_dim in edims.lhs_out_only_dims:
lhs_free_dim_output = output_dim.index(lhs_dim)
lhs_free_dim_input = input_dims[0].index(lhs_dim)
# this means split the lhs input and output
# i.e. S(0), R -> S(0)
lhs_placement_list: list[Placement] = [
Shard(lhs_free_dim_output),
Shard(lhs_free_dim_input),
Replicate(),
]
strategies_over_one_mesh_dim.append(lhs_placement_list)
# split rhs free dim
for rhs_dim in edims.rhs_out_only_dims:
rhs_free_dim_output = output_dim.index(rhs_dim)
rhs_free_dim_input = input_dims[1].index(rhs_dim)
rhs_placement_list: list[Placement] = [
Shard(rhs_free_dim_output),
Replicate(),
Shard(rhs_free_dim_input),
]
strategies_over_one_mesh_dim.append(rhs_placement_list)
# linearity strategy
if linearity:
linearity_placement_list: list[Placement] = [Partial()]
for _ in input_dims:
linearity_placement_list.append(Partial())
strategies_over_one_mesh_dim.append(linearity_placement_list)
return strategies_over_one_mesh_dim
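# Illustrative sketch (not part of the diff): for the mm equation "mk,kn->mn" the
# rules above produce, per mesh dim, rows of [output, lhs, rhs] placements.
# Strings stand in for the Placement objects that gen_single_dim_einsum_strategies
# actually returns.
_mm_single_dim_rows_sketch = [
    ["R", "R", "R"],        # replicate everything
    ["P", "S(1)", "S(0)"],  # shard contracting dim k on both inputs -> Partial output
    ["S(0)", "S(0)", "R"],  # shard lhs-only dim m
    ["S(1)", "R", "S(1)"],  # shard rhs-only dim n
]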
@register_single_dim_strategy(aten.mm.default)
def mm_single_dim_strategy(op_schema: OpSchema) -> list[Placement]:
self_strategy, mat2_strategy = op_schema.args_schema
if not isinstance(self_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}")
if not isinstance(mat2_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(mat2_strategy)}")
# generate all possible strategies for mm
mesh = op_schema.get_mesh_from_args()
return _mm_like_strategy("mk,kn->mn", mesh, op_schema)
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
@register_op_strategy(aten.addmm.default)

View File

@ -18,6 +18,7 @@ from torch.distributed.tensor._ops.utils import (
map_placements_after_broadcast,
normalize_dim,
register_op_strategy,
register_single_dim_strategy,
)
from torch.distributed.tensor.placement_types import (
Partial,
@ -488,6 +489,58 @@ def linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType:
return pointwise_strategy(op_schema, linearity=linearity_type)
def single_mesh_dim_pointwise_strategy(op_schema: OpSchema, linearity: int = -1) -> list[list[Placement]]:
return single_mesh_dim_common_pointwise_strategy(op_schema.args_schema, linearity)
def single_mesh_dim_common_pointwise_strategy(
args_schema: Sequence[object],
linearity: int = -1,
scalar_tensor_idx: Optional[int] = None
) -> list[list[Placement]]:
"""
Common strategy for pointwise operations.
Args:
args_schema: Input arguments schema
linearity: depending on the operator, we support different types of linearity
-1: the operation does not support linearity
0: the unary operation that supports linearity, output propagates partial.
1: the binary operation supports add linearity, where it requires every operand
to be partial, output propagates partial.
2: the binary operation supports multiplicative linearity, where it requires
the primary operand to be partial, and the other operands to be replicate,
output propagates partial.
scalar_tensor_idx: Index of the Replicate scalar tensor for which we allow the mesh
to be different from the mesh of followed_strategy
"""
# handle broadcasting
common_shape = torch.broadcast_shapes(
*[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)]
)
placements_list = []
for i in range(len(common_shape)):
# Shard output dim i, and then shard each argument on its corresponding (non-broadcast) dim, if it has one
shard_placements = [Shard(i)]
for arg in args_schema:
if isinstance(arg, OpStrategy):
common_dim_to_arg_dim = infer_broadcast_dims_map(common_shape, arg.shape)
if common_dim_to_arg_dim[i] >= 0:
shard_placements.append(Shard(common_dim_to_arg_dim[i]))
else:
shard_placements.append(Replicate())
placements_list.append(shard_placements)
if linearity > 0:
# TODO implement partial
# TODO: can the same op support both add and multiplicative linearity?
pass
# TODO: handle scalar_tensor_idx
return placements_list
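# Illustrative sketch (not part of the diff) of what the function above yields for a
# broadcasted pointwise op: args with shapes (2, 3, 4) and (3, 4) broadcast to a
# common shape of (2, 3, 4), so each output dim i gets one row of
# [output, arg1, arg2] placements; strings stand in for Placement objects.
_pointwise_single_dim_rows_sketch = [
    ["S(0)", "S(0)", "R"],     # output dim 0 has no counterpart in the (3, 4) arg
    ["S(1)", "S(1)", "S(0)"],
    ["S(2)", "S(2)", "S(1)"],
]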
def common_pointwise_strategy(
args_schema: Sequence[object],
followed_strategy: OpStrategy,
@ -623,11 +676,15 @@ for op in linear_pointwise_ops:
linear_pointwise_strategy
)
for op in pointwise_ops:
register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
pointwise_strategy
)
# for op in pointwise_ops:
# register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
# pointwise_strategy
# )
for op in pointwise_ops:
register_single_dim_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
single_mesh_dim_pointwise_strategy
)
# TODO: add all for_each ops
for_each_ops = [

View File

@ -42,6 +42,8 @@ from torch.fx.experimental.symbolic_shapes import statically_known_true
aten = torch.ops.aten
# WHC- I think anywhere this is used, we can replace it with a corresponding single-dim passthrough strategy
# (any shard, replicate, and partial can all pass through, and then expand that to the mesh dims later)
def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType:
# For ops with a single tensor input, we perform a 1:1 mapping such that
# for each strategy that the input supports, we create a corresponding strategy.
@ -98,6 +100,28 @@ register_op_strategy(
)(propagate_single_input_strategy)
"""
WHC- equal_strategy is an example of baking an optimization into the sharding rule.
The unoptimized equal strategy (for one mesh dim) should look like this
S, S -> S
R, R -> R
P, P -> P * - this could work, I think, if we supported a Partial over a boolean 'and' reduction?
And this should be expanded to the full mesh.
But what this rule actually does is
- compare the two tensor args to equal: look at the strategies for each, which represent the I-O sharding relationship for the
op that produced those tensor args, and pick the one whose strategy (OpSpec) has the most Shard() placements in it.
Why? Because converting the other arg from R->S is cheaper than converting S->R.
- start with the assumption that the 'equal' op has the same strategy as the op that produced its max-shard input
- then adjust the placements from partial to replicate, since we don't support partial in equal
- finally, produce an OpSpec that only populates the 'output_specs' of OpSpec
TODO: why is it OK to populate only the output_specs of an OpSpec? Is it defined to mean that all input specs are the same as the output spec?
"""
@register_op_strategy(
[
aten.equal.default,
@ -141,6 +165,19 @@ def equal_strategy(op_schema: OpSchema) -> StrategyType:
return equal_strategy
"""
WHC
Seems like we could replace this with a single-mesh-dim strategy
S->S
R->R
P->R
The P->R case is odd, but makes sense:
* can't support P->P since it would be incorrect to create a new 'partial' tensor from ones, which would no longer be ones if we replicated them
* don't want to omit support for a Partial input, because that would force a replication of the input, which would be wasteful
"""
@register_op_strategy(
[
aten.empty_like.default,
@ -489,6 +526,19 @@ def replicate_tensor_dim(
)
"""
WHC- example of a complicated 'follow your inputs' strategy that would be useful to try out as a simple rule
It seems very simple to write this way (see the sketch after this note):
assert input, src same ndim
for i in range(input.ndim):
if i != slice_dim:
Shard(i), Shard(i) -> Shard(i)
"""
@register_op_strategy(aten.slice_scatter.default, schema_info=RuntimeSchemaInfo(2))
def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType:
# 1. number of dimensions in input and src need to match.

View File

@ -4,8 +4,7 @@ import functools
import itertools
import operator
from collections.abc import Callable, Iterable, Sequence
from typing import cast, Optional, TypeVar, Union
from typing_extensions import ParamSpec
from typing import cast, Optional, Union
import torch
from torch._prims_common import DimsSequenceType, DimsType
@ -28,10 +27,7 @@ from torch.distributed.tensor.placement_types import (
Replicate,
Shard,
)
_T = TypeVar("_T")
_P = ParamSpec("_P")
# from torch.testing._internal.distributed._tensor.common_dtensor import redistribute
# convenient wrapper to register sharding propagation rules
@ -54,11 +50,69 @@ def register_prop_rule(
return wrapper
def register_op_strategy(
op, schema_info=None
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
# pyre-fixme[2]: Parameter must be annotated.
def _expand_single_dim_strategy_to_mesh(single_dim_strategy: Callable[[OpSchema], list[list[Placement]]]) -> Callable[[OpSchema], StrategyType]:
"""
Expands the single-mesh-dim impl across all mesh dims, and expands ShardingPlaceholder into all
sharding types used by inputs.
"""
def expanded_strategy(op_schema: OpSchema) -> StrategyType:
strategies_over_one_mesh_dim = single_dim_strategy(op_schema)
inputs_strategy = op_schema.args_strategy
mesh = inputs_strategy[0].mesh
# TODO: handle 'ShardingPlaceholder' expansion (doesn't exist yet)
# TODO: add Replicate since its implicit in single_dim strategies
# TODO: filter out 'invalid' placements
# - ShardVar needs to say whether 'even sharding' is required or not
# copied from the einsum strategy.
# TODO: identify differences between this and 'expand_' util
all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
all_strategies = []
for strategy_comb in strategy_combs:
spec_list = [
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
]
arg_specs = spec_list[1:]
src_strategies = [s for s in op_schema.args_schema if isinstance(s, OpStrategy)]
assert len(arg_specs) == len(src_strategies), "expected one src strategy per arg spec"
all_strategies.append(
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:], redistribute_cost=[
generate_redistribute_costs(src_strategy, arg_spec) for (src_strategy, arg_spec) in zip(src_strategies, arg_specs)
])
)
return OpStrategy(all_strategies)
return expanded_strategy
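# Illustrative (not part of the diff): the itertools.product above turns N single-dim
# rows into N**mesh.ndim candidate strategies, one placement row per mesh dim.
import itertools
_rows = ["RRR", "S0S0R", "S1RS1", "PS1S0"]          # stand-ins for single-dim rows
_combs = list(itertools.product(_rows, repeat=2))   # e.g. a 2-D mesh
assert len(_combs) == len(_rows) ** 2               # 16 candidate strategies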
def register_single_dim_strategy(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[
[Callable[[OpSchema], list[list[Placement]]]], Callable[[OpSchema], list[list[Placement]]]
]:
"""
Registers a simplified op strategy that only considers a single mesh dim, taking care to expand it
to cover all the mesh dims present in the runtime inputs.
"""
def expanded_registration_wrapper(
single_dim_strategy: Callable[[OpSchema], list[list[Placement]]],
) -> Callable[[OpSchema], list[list[Placement]]]:
_expanded_strategy = _expand_single_dim_strategy_to_mesh(single_dim_strategy)
register_op_strategy(op, schema_info)(_expanded_strategy)
return single_dim_strategy
return expanded_registration_wrapper
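# Hedged usage sketch (hypothetical op and rows; only the decorator name comes from
# this diff): a single-dim strategy returns placement rows for one mesh dim, and the
# wrapper above expands them across the mesh when the strategy is evaluated.
#
# @register_single_dim_strategy(aten.relu.default)
# def relu_single_dim_strategy(op_schema: OpSchema) -> list[list[Placement]]:
#     input_strategy = op_schema.args_schema[0]
#     assert isinstance(input_strategy, OpStrategy)
#     # rows are [output, input]: replicate, or pass any sharded dim straight through
#     rows: list[list[Placement]] = [[Replicate(), Replicate()]]
#     for d in range(input_strategy.ndim):
#         rows.append([Shard(d), Shard(d)])
#     return rows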
def register_op_strategy(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[[Callable[[OpSchema], StrategyType]], Callable[[OpSchema], StrategyType]]:
# For every ATen op that accepts any args in this list,
# the arg itself can impact the strides (and potentially the sharding strategy)
# of the output tensor.
@ -68,7 +122,9 @@ def register_op_strategy(
"memory_format",
]
def wrapper(impl):
def wrapper(
impl: Callable[[OpSchema], StrategyType],
) -> Callable[[OpSchema], StrategyType]:
if isinstance(op, list):
overloads = op
else:
@ -159,7 +215,10 @@ def prod(xs: Iterable[int]) -> int:
def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
"""Check if the shape is shardable according to the spec."""
"""Check if the spec matches these criteria:
* any Shard placements in spec refer to valid tensor dims
* no empty local tensors (uneven sharding OK, as long as last rank has >0 size)
"""
# number of shards in each tensor dimension
shards_map = [1] * len(shape)
for i, placement in enumerate(spec.placements):
@ -225,6 +284,9 @@ def infer_broadcast_dims_map(
) -> list[int]:
# infer the broadcast dims map, where it maps from the common shape dim to the input shape dim
# this is aligned with the broadcast semantics
# e.g. if common_shape = [1, 2, 3, 4] and input_shape = [2, 3, 4],
# broadcast_dims_map will be [-1, 0, 1, 2]
# meaning that dim 0 in the output has no mapping to the input, and dim 1 in the output maps to dim 0 in the input
common_ndim = len(common_shape)
input_ndim = len(input_shape)
broadcast_dims_map = [-1] * common_ndim
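# A minimal standalone sketch of the mapping described in the comment above
# (illustrative only, not the exact upstream implementation): input dims are
# right-aligned against the common shape, per broadcast semantics.
def _broadcast_dims_map_sketch(common_shape, input_shape):
    common_ndim, input_ndim = len(common_shape), len(input_shape)
    dims_map = [-1] * common_ndim
    for i in range(input_ndim):
        # input dim i lines up with common dim (common_ndim - input_ndim + i)
        dims_map[common_ndim - input_ndim + i] = i
    return dims_map

assert _broadcast_dims_map_sketch([1, 2, 3, 4], [2, 3, 4]) == [-1, 0, 1, 2]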

View File

@ -1543,7 +1543,9 @@ ORIGINAL_ATEN: Optional[object] = None
@contextmanager
def set_original_aten_op(func: OpOverload) -> Generator[None, None, None]:
def set_original_aten_op(
func: OpOverload | torch._ops.HigherOrderOperator,
) -> Generator[None, None, None]:
global ORIGINAL_ATEN
if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta():
ORIGINAL_ATEN = func

View File

@ -207,12 +207,19 @@ def tensorify_python_scalars(
and node.target is torch.ops.aten._local_scalar_dense.default
):
dtype = node.args[0].meta["val"].dtype
if not dtype.is_floating_point:
continue
assert isinstance(node.args[0], fx.Node), node.args[0]
s = node.meta["val"].node.expr
expr_to_sym_proxy[s] = MetaProxy(
node, tracer=tracer, fake_mode=fake_mode
)
# only tensorify if the dtype is floating point
if not dtype.is_floating_point:
continue
expr_to_tensor_proxy[s] = MetaProxy(
node.args[0], tracer=tracer, fake_mode=fake_mode
)
@ -220,9 +227,7 @@ def tensorify_python_scalars(
expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default(
expr_to_tensor_proxy[s], torch.float64
)
expr_to_sym_proxy[s] = MetaProxy(
node, tracer=tracer, fake_mode=fake_mode
)
# pyrefly: ignore [bad-argument-type]
elif (sym_expr := _get_sym_val(node)) is not None:
if sym_expr not in expr_to_sym_proxy and not isinstance(

View File

@ -43,6 +43,7 @@ from torch.distributed.tensor.parallel import (
SequenceParallel,
)
from torch.testing._internal.common_distributed import (
ACCELERATOR_DIST_BACKENDS,
MultiProcContinuousTest,
MultiProcessTestCase,
MultiThreadedTestCase,
@ -396,14 +397,17 @@ class DTensorTestBase(MultiProcessTestCase):
return init_device_mesh(self.device_type, (self.world_size,))
def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
if backend is None:
backend = self.backend
requires_gpu = any(
gpu_backend in backend for gpu_backend in ACCELERATOR_DIST_BACKENDS
)
if requires_gpu and torch.accelerator.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
curr_backend = dist.get_default_backend_for_device(self.device_type)
if backend is None:
backend = self.backend
if backend not in [
"nccl",
"gloo",