Compare commits

..

14 Commits

Author SHA1 Message Date
fbb7943140 batch invariance testing 2025-11-18 13:55:54 -08:00
ebb2001a48 [codemod][lowrisk] Remove unused exception parameter from caffe2/torch/csrc/Exceptions.h (#168056)
Summary:
`-Wunused-exception-parameter` has identified an unused exception parameter. This diff removes it.

This:
```
try {
    ...
} catch (exception& e) {
    // no use of e
}
```
should instead be written as
```
} catch (exception&) {
```

If the code compiles, this is safe to land.

Test Plan: Sandcastle

Reviewed By: dtolnay

Differential Revision: D87273132

Pull Request resolved: https://github.com/pytorch/pytorch/pull/168056
Approved by: https://github.com/malfet, https://github.com/Skylion007
2025-11-18 20:21:48 +00:00
ae85307512 huber_loss numerical issue (#166952)
For GPU: it was previously reported that only a single sample could be tested with the huber_loss functional. The current snapshot of the code does not appear to suffer from the numerical issues reported before.

For CPU: while testing the GPU path, it was discovered that the Half dtype appears to be numerically unstable. This commit resolves the CPU issue by upcasting Half to float for the computation.
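A minimal sketch (assumed values, not the PR's kernel code) of the upcast-to-float idea for the CPU Half path:

```
# Hedged sketch: do the Huber arithmetic in float32 and cast back to half at the end.
import torch
import torch.nn.functional as F

x = torch.full((1024,), 7.25, dtype=torch.half)
y = torch.full((1024,), 5.00, dtype=torch.half)

# Direct half computation can accumulate rounding error in 0.5*z*z and delta*(z - 0.5*delta).
direct = F.huber_loss(x, y, delta=1.0)
# Upcast path roughly mirrors what the CPU kernel now does internally for kHalf.
upcast = F.huber_loss(x.float(), y.float(), delta=1.0).half()
print(direct, upcast)
```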

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166952
Approved by: https://github.com/benjaminglass1, https://github.com/isuruf
2025-11-18 20:06:29 +00:00
7921c0eb0e [ROCm][CI] Limit caching to ROCm jammy docker images (#168088)
Since the currently intended workflow on the new MI3xx CI capacity is [trunk-rocm-mi300.yml](d91269e8ce/.github/workflows/trunk-rocm-mi300.yml (L54)), which only needs the jammy images, limit caching to those images to optimize docker caching times.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/168088
Approved by: https://github.com/jeffdaily
2025-11-18 20:04:20 +00:00
dda2cb3769 Handled erased hiding nodes from dtype bucketing (#167863)
The dtype fusing in bucketing was causing nodes that had dependencies to be erased. Transfer those deps over to the new nodes, and also add an assertion that none of our deps are erased to catch this type of error in the future.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167863
Approved by: https://github.com/fmassa
ghstack dependencies: #167852, #167853
2025-11-18 19:50:08 +00:00
4c5042b368 Fix all gather bucketing fusion of dtype casts (#167853)
The all gather bucketing was partway toward fusing dtype casts into the bucket. We do this by allocating the group bucket buffer, then viewing each slice of it as the destination dtype. We then foreach_copy_ into the allocated buffer, with each collective input copying into its destination dtype.

This logic was not fully firing and was causing an issue in a later part of the stack, so it is fixed here.

Note: custom ops don't yet support list[dtype], so this is worked around with list[int]; a follow-up will fix it.
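A minimal, standalone sketch of the copy-with-dtype-view idea described above (hypothetical sizes and dtypes, not the actual helper in this PR):

```
# Allocate one flat bucket buffer, view each slice as its destination dtype,
# and let torch._foreach_copy_ do the dtype conversion during the copy.
import torch

inputs = [torch.randn(4, dtype=torch.float32),   # converted to float16 by the copy
          torch.randn(8, dtype=torch.float16)]   # copied as-is
out_dtypes = [torch.float16, torch.float16]

# Byte-sized bucket so split sizes can be expressed in bytes (illustrative choice).
nbytes = [t.numel() * dt.itemsize for t, dt in zip(inputs, out_dtypes)]
bucket = torch.empty(sum(nbytes), dtype=torch.uint8)

dsts = [s.view(dt) for s, dt in zip(torch.split(bucket, nbytes), out_dtypes)]
torch._foreach_copy_(dsts, [t.reshape(-1) for t in inputs])
```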

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167853
Approved by: https://github.com/ruisizhang123
ghstack dependencies: #167852
2025-11-18 19:50:08 +00:00
e3c5b78999 small changes (#167852)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167852
Approved by: https://github.com/fmassa
2025-11-18 19:50:08 +00:00
14f370f551 [xpu][test] port some distributed tensor test files for Intel GPU (#161703)
This is another PR to port distributed tensor tests to Intel GPU; the companion PR is https://github.com/pytorch/pytorch/pull/161604
We enable Intel GPU with the following methods while keeping the original code style as much as possible:

Use torch.accelerator for general GPU support (see the sketch below)
Skip cases with known issues when running on XPU
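A minimal sketch of the torch.accelerator pattern these ports rely on (mirroring the device_type snippets in the diffs below):

```
# Pick the current accelerator (cuda, xpu, ...) if available, otherwise fall back to CPU.
import torch

acc = torch.accelerator.current_accelerator(check_available=True)
device_type = acc.type if acc is not None else "cpu"

x = torch.randn(4, 4, device=device_type)
print(device_type, torch.accelerator.is_available())
```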

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161703
Approved by: https://github.com/guangyey, https://github.com/d4l3k, https://github.com/albanD
2025-11-18 19:49:44 +00:00
aa22d41f9b [refcycle-logger] Output tensor size in the refcycle visualization (#167079)
Summary:
As title.

Knowing the size of the leaked tensor is useful; it allows us to focus on the largest leaks.

Differential Revision: D86218574

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167079
Approved by: https://github.com/kausv
2025-11-18 19:48:15 +00:00
d1f6dd6105 distributed/debug: add an HTTP server for debugging running jobs (#167395)
This adds a debug HTTP server for debugging stuck or slow jobs. It runs the WorkerServer on every worker and then launches a separate Flask process on rank 0 that users connect to for debugging.

This can easily be extended to trigger profilers as well as to visualize the data much better.

Initial handlers:
* pytorch profiler
* FlightRecorder data
* Python stacks

```
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"

from torch.distributed.debug import enable_debug_server

enable_debug_server()
```
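As a hedged illustration, the handlers above can then be queried over plain HTTP; the port and paths below follow the test added in this PR and may differ in other setups:

```
# Query the debug server's handlers over HTTP (port/paths taken from the test in this PR).
import requests

base = "http://localhost:25999"
print(requests.get(f"{base}/stacks").text)                  # Python stacks
print(requests.get(f"{base}/fr_trace").text)                # FlightRecorder data
print(requests.get(f"{base}/profile?duration=0.01").text)   # pytorch profiler
```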

Test plan:

```
torchrun --nnodes 1 --nproc_per_node=gpu ~/scripts/debug_test.py
```

<img width="2000" height="1045" alt="20251117_16h58m18s_grim" src="https://github.com/user-attachments/assets/82305b75-227c-4412-a481-00b622db6bd1" />
<img width="2172" height="1624" alt="20251117_16h58m11s_grim" src="https://github.com/user-attachments/assets/def9841c-c7e6-483a-81c3-cf0c56f6bad8" />
<img width="1985" height="1635" alt="20251117_16h58m03s_grim" src="https://github.com/user-attachments/assets/04fcf148-df58-41b4-8754-8706ee0d1de6" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167395
Approved by: https://github.com/fduwjj, https://github.com/malfet, https://github.com/atalman
2025-11-18 19:00:24 +00:00
5333e51195 [CUDA][Thor] Enable CUTLASS matmuls on Thor (#164836)
This PR enables special matmuls on Thor devices. This includes row-wise scaled matmul on `fp8` and group gemm on `bfloat16`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164836
Approved by: https://github.com/ngimel
2025-11-18 18:45:47 +00:00
0e13964b74 [CI] Disable ET tests (again) (#168090)
Repetition of https://github.com/pytorch/pytorch/pull/155708
The tests have been broken for a while, and the ET pin in PyTorch is so old that `torch==2.10.0.dev20250915` can no longer be found in the nightly indices
Pull Request resolved: https://github.com/pytorch/pytorch/pull/168090
Approved by: https://github.com/atalman, https://github.com/yangw-dev
2025-11-18 18:08:12 +00:00
20cae808f7 ComplexTensor subclass (#167621)
This PR introduces a `Tensor` subclass which represents a complex tensor in terms of two real ones. Ops are decomposed into individual ops on the real and imaginary parts.

It is compatible with `torch.compile`, so long as the real ops used are also compatible. Autograd "works", but is a WIP due to differing edge-case behaviour.
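A toy illustration (not this PR's actual subclass API) of decomposing a complex op into real ops on the real and imaginary parts:

```
# Complex multiply expressed purely with real-tensor ops, checked against torch.complex.
import torch

a_re, a_im = torch.randn(3), torch.randn(3)
b_re, b_im = torch.randn(3), torch.randn(3)

out_re = a_re * b_re - a_im * b_im   # Re[(a_re + i*a_im) * (b_re + i*b_im)]
out_im = a_re * b_im + a_im * b_re   # Im[(a_re + i*a_im) * (b_re + i*b_im)]

ref = torch.complex(a_re, a_im) * torch.complex(b_re, b_im)
torch.testing.assert_close(torch.complex(out_re, out_im), ref)
```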
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167621
Approved by: https://github.com/ezyang
2025-11-18 17:57:33 +00:00
57927a620d [Profiler] Deprecate export_memory_timeline method (#168036)
Summary: The export_memory_timeline method in torch.profiler is being deprecated in favor of the newer memory snapshot API (torch.cuda.memory._record_memory_history and torch.cuda.memory._export_memory_snapshot). This change adds the deprecated decorator from typing_extensions and updates the docstring to guide users to the recommended alternative. The decorator will emit a FutureWarning at runtime, and the docstring now includes a .. deprecated:: directive for documentation visibility.
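A hedged sketch of the recommended alternative named above (requires a CUDA build; the workload and file name are illustrative):

```
# Record allocator history, run the workload, then export a snapshot for analysis.
import torch

torch.cuda.memory._record_memory_history(max_entries=100000)
x = torch.randn(1024, 1024, device="cuda")                   # workload to profile
torch.cuda.memory._export_memory_snapshot("snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)       # stop recording
```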

Test Plan: Manual verification that the decorator is properly applied and the deprecation message is informative.

Differential Revision: D87272399

Pull Request resolved: https://github.com/pytorch/pytorch/pull/168036
Approved by: https://github.com/valentinandrei
2025-11-18 17:56:50 +00:00
53 changed files with 3100 additions and 181 deletions

View File

@ -402,3 +402,6 @@ scikit-build==0.18.1
pyre-extensions==0.0.32
tabulate==0.9.0
#Description: These package are needed to build FBGEMM and torchrec on PyTorch CI
Jinja2==3.1.6
#Description: required for torch.distributed.debug

View File

@ -1768,7 +1768,7 @@ test_operator_microbenchmark() {
cd "${TEST_DIR}"/benchmarks/operator_benchmark
for OP_BENCHMARK_TESTS in optimizer; do
for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv; do
$TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
--output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \
--benchmark-name "PyTorch operator microbenchmark" --use-compile

View File

@ -75,7 +75,8 @@ jobs:
pytorch-linux-jammy-py3-clang12-onnx,
pytorch-linux-jammy-linter,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter,
pytorch-linux-jammy-py3-clang12-executorch,
# TODO: Re-enable me when docker pin update happens
# pytorch-linux-jammy-py3-clang12-executorch,
pytorch-linux-jammy-py3.12-triton-cpu,
pytorch-linux-noble-riscv64-py3.12-gcc14
]

View File

@ -50,9 +50,10 @@ jobs:
matrix:
runner: [linux.rocm.gfx942.docker-cache]
docker-image: [
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}"
#"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
#"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
#"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
]
runs-on: "${{ matrix.runner }}"
steps:

View File

@ -283,6 +283,7 @@ jobs:
name: linux-jammy-py3-clang12-executorch
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
if: false # Has been broken for a while
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3-clang12-executorch

View File

@ -813,8 +813,43 @@ void smooth_l1_kernel(TensorIteratorBase& iter, double beta) {
}
void huber_kernel(TensorIterator& iter, double delta) {
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16, kHalf, iter.dtype(), "huber_cpu", [&]() {
// Special-case kHalf: compute in float for numerical stability
if (iter.dtype() == kHalf) {
const float delta_val(static_cast<float>(delta));
const Vectorized<float> delta_vec(static_cast<float>(delta));
const Vectorized<float> point_five_vec(static_cast<float>(0.5));
cpu_kernel_vec(
iter,
// scalar lambda: convert half -> float, compute in float, cast back to half
[&delta_val] (at::Half a, at::Half b) -> at::Half {
float af = static_cast<float>(a);
float bf = static_cast<float>(b);
float z = std::abs(af - bf);
float out = z < delta_val
? 0.5f * z * z
: delta_val * (z - 0.5f * delta_val);
return static_cast<at::Half>(out);
},
[&delta_vec, &point_five_vec] (Vectorized<Half> a, Vectorized<Half> b) {
auto [a0, a1] = convert_half_float(a);
auto [b0, b1] = convert_half_float(b);
auto z = (a0 - b0).abs();
a0 = Vectorized<float>::blendv(
point_five_vec * z * z,
delta_vec * (z - point_five_vec * delta_vec),
z >= delta_vec);
z = (a1 - b1).abs();
a1 = Vectorized<float>::blendv(
point_five_vec * z * z,
delta_vec * (z - point_five_vec * delta_vec),
z >= delta_vec);
return convert_float_half(a0, a1);
}
);
return;
}
else {
AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "huber_cpu", [&]() {
using Vec = Vectorized<scalar_t>;
const scalar_t delta_val(delta);
const Vec delta_val_vec(delta_val);
@ -835,6 +870,7 @@ void huber_kernel(TensorIterator& iter, double delta) {
z >= delta_val_vec);
});
});
}
}
void sigmoid_backward_kernel(TensorIteratorBase& iter) {

View File

@ -346,8 +346,9 @@ void dispatch_bf16_grouped_kernel_on_tile_size(
bool small = (M <= 128 || N <= 128);
cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
const bool sm10x = properties != nullptr && properties->major == 10;
const bool sm11x = properties != nullptr && properties->major == 11;
if (sm10x) {
if (sm10x || sm11x) {
if (small){
bf16bf16_grouped_gemm_impl_sm90_sm100<
cutlass::arch::Sm100,

View File

@ -958,8 +958,9 @@ void dispatch_fp8_rowwise_kernel_on_sm(
const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9;
const bool sm9x = properties != nullptr && properties->major == 9;
const bool sm10x = properties != nullptr && properties->major == 10;
const bool sm11x = properties != nullptr && properties->major == 11;
const bool sm12x = properties != nullptr && properties->major == 12;
if (!(sm89 || sm9x || sm10x || sm12x)) {
if (!(sm89 || sm9x || sm10x || sm11x || sm12x)) {
TORCH_CHECK(
false, "Rowwise scaling is not currently supported on your device");
}
@ -968,7 +969,7 @@ void dispatch_fp8_rowwise_kernel_on_sm(
dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<
/*ArchTag=*/cutlass::arch::Sm90,
Types...>(XQ, WQ, x_scale, w_scale, bias, out);
} else if (sm10x) {
} else if (sm10x || sm11x) {
dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<
/*ArchTag=*/cutlass::arch::Sm100,
Types...>(XQ, WQ, x_scale, w_scale, bias, out);

View File

@ -1,64 +0,0 @@
import operator_benchmark as op_bench
import torch
import torch.optim as optim
"""Microbenchmarks for optimizer operators."""
optimizer_list = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[
["adamw", optim.AdamW],
["adam", optim.Adam],
["sgd", optim.SGD],
["rmsprop", optim.RMSprop],
["adagrad", optim.Adagrad],
],
)
optimizer_configs_long = op_bench.cross_product_configs(
num_params=[1, 10, 100],
param_size=[100000, 1000000, 10000000],
device=["cuda"],
tags=["long"],
)
class OptimizerBenchmark(op_bench.TorchBenchmarkBase):
def init(self, op_func, device, shape=None, num_params=None, param_size=None):
if shape is not None:
num_params = num_params if num_params is not None else 1
self.params = [
torch.randn(shape, device=device, requires_grad=True)
for _ in range(num_params)
]
for param in self.params:
param.grad = torch.randn(shape, device=device)
else:
self.params = [
torch.randn(param_size, device=device, requires_grad=True)
for _ in range(num_params)
]
for param in self.params:
param.grad = torch.randn_like(param)
kwargs = {"momentum": 0.9} if op_func == optim.SGD else {}
self.optimizer = op_func(self.params, lr=0.001, **kwargs)
self.inputs = {"dummy": self.params[0]} # Added to run memory benchmarking
def forward(self, dummy):
self.optimizer.step()
for param in self.params:
param.grad = torch.randn_like(param)
return self.params[0]
op_bench.generate_pt_tests_from_op_list(
optimizer_list, optimizer_configs_long, OptimizerBenchmark
)
if __name__ == "__main__":
op_bench.benchmark_runner.main()

View File

@ -113,6 +113,12 @@ if(INTERN_BUILD_ATEN_OPS)
list(APPEND _file_compile_flags "-gencode;arch=compute_103a,code=sm_103a")
endif()
endif()
# We will need to gate against CUDA version, because sm_110a is available on CUDA 13.0+
if("${_arch}" STREQUAL "110a" AND CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
if(_existing_arch_flags MATCHES ".*compute_110.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_110a,code=sm_110a")
endif()
endif()
if("${_arch}" STREQUAL "120a")
if(_existing_arch_flags MATCHES ".*compute_120.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
@ -132,13 +138,13 @@ if(INTERN_BUILD_ATEN_OPS)
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
"89;90a;100a;103a;120a;121a")
"89;90a;100a;103a;110a;120a;121a")
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
"90a")
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu"
"90a;100a;103a")
"90a;100a;103a;110a")
endif()

View File

@ -987,6 +987,24 @@ In addition, `TORCH_DISTRIBUTED_DEBUG=DETAIL` can be used in conjunction with `T
collective desynchronization checks will work for all applications that use `c10d` collective calls backed by process groups created with the
{func}`torch.distributed.init_process_group` and {func}`torch.distributed.new_group` APIs.
### torch.distributed.debug HTTP Server
The `torch.distributed.debug` module provides an HTTP server that can be used to debug distributed applications. The server can
be started by calling {func}`torch.distributed.debug.start_debug_server`. This
allows users to collect data across all workers at runtime.
```{eval-rst}
.. automodule:: torch.distributed.debug
:members:
:undoc-members:
:show-inheritance:
:special-members: __init__
:member-order: bysource
```
## Logging
In addition to explicit debugging support via {func}`torch.distributed.monitored_barrier` and `TORCH_DISTRIBUTED_DEBUG`, the underlying C++ library of `torch.distributed` also outputs log

View File

@ -0,0 +1,238 @@
# Owner(s): ["module: complex"]
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
import torch.distributed as dist
# Support both when imported from elsewhere or directly as a file
try:
from .utils import (
COMPLEX_DTYPES,
Descriptor,
force_test_op_db,
get_overload_packet_from_name,
implemented_op_db,
TestCase,
Variant,
)
except ImportError:
from utils import (
COMPLEX_DTYPES,
Descriptor,
force_test_op_db,
get_overload_packet_from_name,
implemented_op_db,
TestCase,
Variant,
)
from torch._subclasses.complex_tensor._ops.common import ComplexTensorMode
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
OpDTypes,
ops,
)
from torch.testing._internal.common_utils import (
run_tests,
TestGradients,
unMarkDynamoStrictTest,
)
if TYPE_CHECKING:
from torch.testing._internal.opinfo.core import OpInfo
aten = torch.ops.aten
SKIPS = {
Descriptor(op=aten.empty_like, variant=None): "Non-deterministic output",
Descriptor(op=aten.randn_like, variant=None): "Non-deterministic output",
Descriptor(op=aten.angle, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.asinh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.atanh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(
op=aten.reciprocal, variant=Variant.GradCheck
): "Numerical inconsistency",
Descriptor(op=aten.rsqrt, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.select, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.asin, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.log, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.sgn, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.cumprod, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.slice, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.sqrt, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.tan, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(
op=aten.true_divide, variant=Variant.GradCheck
): "Numerical inconsistency",
Descriptor(op=aten.prod, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.div, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.expm1, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.var, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.bmm, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.diagonal, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.sinh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.abs, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.sin, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.atan, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.acos, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.acosh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.cos, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.cosh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.addmm, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.pow, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.log1p, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.tanh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.mm, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.dot, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.mul, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.exp, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.to, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(
op=aten.any, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.all, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.allclose, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.conj_physical, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten._conj_physical, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.cumprod, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.index_add, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.diagonal_scatter, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.flip, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.masked_fill, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.masked_scatter, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.rsub, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.ne, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.squeeze, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.index_select, variant=Variant.Distributed
): "Sharding propagation failed",
Descriptor(op=aten.real, variant=Variant.Distributed): "No scalar support",
Descriptor(op=aten.imag, variant=Variant.Distributed): "No scalar support",
Descriptor(op=aten.isfinite, variant=Variant.Distributed): "No scalar support",
Descriptor(op=aten.transpose, variant=Variant.Distributed): "No scalar support",
Descriptor(op=aten.view_as_real, variant=Variant.Distributed): "No scalar support",
}
EXTRA_KWARGS = {
Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Op): {
"rtol": 2e-5,
"atol": 5e-5,
},
Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Op): {
"rtol": 1e-4,
"atol": 1e-5,
},
Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Op): {
"rtol": 2e-2,
"atol": 2e-6,
},
Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Distributed): {
"rtol": 2e-5,
"atol": 5e-5,
},
Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Distributed): {
"rtol": 1e-4,
"atol": 1e-5,
},
Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Distributed): {
"rtol": 2e-2,
"atol": 2e-6,
},
Descriptor(op=aten.tan, dtype=torch.complex64, variant=Variant.Distributed): {
"rtol": 2e-6,
"atol": 1e-2,
},
}
class TestComplexTensor(TestCase):
_default_dtype_check_enabled = True
@ops(
implemented_op_db,
dtypes=OpDTypes.supported,
allowed_dtypes=list(COMPLEX_DTYPES),
)
def test_consistency(self, device, dtype, op: OpInfo):
self.check_consistency(device, dtype, op, Variant.Op)
@ops(force_test_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
def test_maybe_error(self, device, dtype, op: OpInfo):
self.check_consistency(device, dtype, op, Variant.Op)
@unMarkDynamoStrictTest
class TestComplexBwdGradients(TestGradients):
_default_dtype_check_enabled = True
@ops(
implemented_op_db,
dtypes=OpDTypes.supported_backward,
allowed_dtypes=[torch.complex128],
)
def test_fn_grad(self, device: str, dtype: torch.dtype, op: OpInfo) -> None:
test_info = Descriptor(
op=get_overload_packet_from_name(op.name),
device_type=torch.device(device).type,
dtype=dtype,
variant=Variant.GradCheck,
)
for xfail_info, reason in SKIPS.items():
if xfail_info.matches(test_info):
self.skipTest(reason)
if dtype not in op.supported_backward_dtypes(torch.device(device).type):
self.skipTest(f"Skipped! {dtype=} is not in supported backward dtypes!")
with ComplexTensorMode():
op.gradcheck_fast_mode = False
self._grad_test_helper(device, dtype, op, op.get_op())
instantiate_device_type_tests(TestComplexTensor, globals())
instantiate_device_type_tests(TestComplexBwdGradients, globals())
if dist.is_available():
from torch.testing._internal.common_distributed import MultiProcessTestCase
@unMarkDynamoStrictTest
class TestComplexDistributed(TestCase, MultiProcessTestCase):
@ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
def test_distributed(self, device, dtype, op: OpInfo):
self.check_consistency(device, dtype, op, Variant.Distributed)
instantiate_device_type_tests(TestComplexDistributed, globals())
if __name__ == "__main__":
run_tests()

View File

@ -0,0 +1,214 @@
from __future__ import annotations
from dataclasses import dataclass, field, fields
from enum import auto, Enum
from typing import Any, TYPE_CHECKING
import torch
import torch.distributed as dist
from torch._subclasses.complex_tensor._ops.common import (
_as_complex_tensor,
_as_interleaved,
_get_op_name,
COMPLEX_OPS_TABLE,
COMPLEX_TO_REAL,
FORCE_TEST_LIST,
OpOverloadPacket,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
from torch.utils._pytree import tree_flatten
if TYPE_CHECKING:
from collections.abc import Callable
from torch.distributed.tensor import DTensor
from torch.testing._internal.opinfo.core import OpInfo
COMPLEX_DTYPES = set(COMPLEX_TO_REAL)
class Variant(Enum):
Op = auto()
GradCheck = auto()
Distributed = auto()
def _as_local(arg: DTensor | Any) -> torch.Tensor | Any:
if not (dist.is_available() and isinstance(arg, dist.tensor.DTensor)):
return arg
return arg.full_tensor()
def _as_complex_dtensor(arg: torch.Tensor | Any) -> torch.Tensor | Any:
if not isinstance(arg, torch.Tensor):
return arg
return dist.tensor.DTensor.from_local(_as_complex_tensor(arg))
TRANSFORM_FUNCS = {
Variant.Op: _as_complex_tensor,
Variant.Distributed: _as_complex_dtensor,
}
@dataclass(frozen=True, kw_only=True)
class Descriptor:
op: OpOverloadPacket
variant: Variant | None
device_type: str | None = field(default=None)
dtype: torch.dtype | None = field(default=None)
def matches(self, other: Descriptor) -> bool:
fields1 = fields(self)
fields2 = fields(other)
if fields1 != fields2:
return False
for f in fields1:
f1 = getattr(self, f.name)
f2 = getattr(other, f.name)
if f1 is not None and f2 is not None and f1 != f2:
return False
return True
class TestCase(PytorchTestCase):
def assertSameResult(
self,
expected: Callable[[], Any],
actual: Callable[[], Any],
*args,
**kwargs,
) -> None:
try:
result_e = expected()
exception_e = None
except Exception as e: # noqa: BLE001
result_e = None
exception_e = e
try:
result_a = actual()
exception_a = None
except Exception as e: # noqa: BLE001
result_a = None
exception_a = e
if (exception_e is None) != (exception_a is None):
if exception_a is not None and exception_e is None:
raise exception_a
self.assertIs(
type(exception_e),
type(exception_a),
f"\n{exception_e=}\n{exception_a=}",
)
if exception_e is None:
flattened_e, spec_e = tree_flatten(result_e)
flattened_a, spec_a = tree_flatten(result_a)
self.assertEqual(
spec_e,
spec_a,
"Both functions must return a result with the same tree structure.",
)
for value_e, value_a in zip(flattened_e, flattened_a, strict=True):
value_e = _as_interleaved(_as_local(value_e))
value_a = _as_interleaved(_as_local(value_a))
self.assertEqual(value_e, value_a, *args, **kwargs)
def check_consistency(
self, device: str, dtype, op: OpInfo, variant: Variant
) -> None:
try:
from .test_complex_tensor import EXTRA_KWARGS, SKIPS
except ImportError:
from test_complex_tensor import EXTRA_KWARGS, SKIPS
test_info = Descriptor(
op=get_overload_packet_from_name(op.name),
device_type=torch.device(device).type,
dtype=dtype,
variant=variant,
)
for xfail_info, reason in SKIPS.items():
if xfail_info.matches(test_info):
self.skipTest(reason)
kwargs = {}
for extra_info, extra_kw in EXTRA_KWARGS.items():
if extra_info.matches(test_info):
kwargs = extra_kw
break
sample_inputs = op.sample_inputs(device, dtype)
transform_fn = TRANSFORM_FUNCS[variant]
for sample_input in sample_inputs:
def expected(sample_input=sample_input):
return op(sample_input.input, *sample_input.args, **sample_input.kwargs)
subclass_sample = sample_input.transform(transform_fn)
def actual(subclass_sample=subclass_sample):
return op(
subclass_sample.input,
*subclass_sample.args,
**subclass_sample.kwargs,
)
self.assertSameResult(expected, actual, **kwargs)
aten = torch.ops.aten
complex_op_db = tuple(
filter(lambda op: any(op.supports_dtype(ct, "cpu") for ct in COMPLEX_DTYPES), op_db)
)
def get_overload_packet_from_name(name: str) -> OpOverloadPacket:
for domain_name in torch.ops:
op_namespace = getattr(torch.ops, domain_name)
op: OpOverloadPacket | None = getattr(op_namespace, name, None)
if op is not None:
return op
raise RuntimeError(f"No op with {name=} found.")
force_test_names = set(map(_get_op_name, FORCE_TEST_LIST))
implemented_op_names = (
set(map(_get_op_name, COMPLEX_OPS_TABLE.keys())) - force_test_names
)
implemented_op_db = tuple(
filter(lambda op: op.name in implemented_op_names, complex_op_db)
)
force_test_op_db = tuple(filter(lambda op: op.name in force_test_names, op_db))
tested_op_names = {op.name for op in implemented_op_db} | {
op.name for op in force_test_op_db
}
non_tested_ops = {
op for op in COMPLEX_OPS_TABLE if _get_op_name(op) not in tested_op_names
}
# TODO (hameerabbasi): There are a number of ops that don't have any associated
# OpInfos. We still need to write tests for those ops.
if len(non_tested_ops) != 0:
import textwrap
import warnings
list_missing_ops = "\n".join(sorted([str(op) for op in non_tested_ops]))
warnings.warn(
"Not all implemented ops are tested. List of ops missing tests:"
f"\n{textwrap.indent(list_missing_ops, ' ')}",
UserWarning,
stacklevel=2,
)

View File

@ -6,7 +6,7 @@ import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed.tensor import DeviceMesh, DTensor, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import requires_nccl
from torch.testing._internal.common_distributed import requires_accelerator_dist_backend
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore
@ -14,6 +14,9 @@ from torch.testing._internal.distributed.fake_pg import FakeStore
c10d_functional = torch.ops.c10d_functional
c10d_ops = torch.ops.c10d
device_type = (
acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)
class TestCommMode(TestCase):
@ -28,7 +31,7 @@ class TestCommMode(TestCase):
dist.init_process_group(
backend="fake", rank=1, world_size=self.world_size, store=store
)
self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
self.device_type = device_type
self.world_pg = dist.distributed_c10d._get_default_group()
def checksAssert(self, comm_mode, key, expected_value, expected_total_value):
@ -111,12 +114,12 @@ class TestCommMode(TestCase):
self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 0)
@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
def test_comm_mode_with_c10d(self):
if not torch.cuda.is_available():
if not torch.accelerator.is_available():
return
inp = torch.rand(2, 8, 16).cuda()
inp = torch.rand(2, 8, 16).to(device_type)
all_gather_out = inp.new_empty(self.world_size * 2, 8, 16)
comm_mode = CommDebugMode()

View File

@ -658,11 +658,11 @@ class DTensorMeshTest(DTensorTestBase):
@with_comms
def test_dtensor_device_mesh_device_conversion(self):
# construct a cuda device mesh
# construct a gpu device mesh
mesh = self.build_device_mesh()
# construct from a cpu local tensor with cuda device mesh
# should automatically convert the dist tensor to cuda
# construct from a cpu local tensor with gpu device mesh
# should automatically convert the dist tensor to gpu
placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
@ -711,7 +711,7 @@ class DTensorMeshTest(DTensorTestBase):
@with_comms
def test_dtensor_2d_mesh(self):
mesh_tensor = torch.arange(self.world_size).reshape(2, 4)
# construct a cuda device mesh
# construct a gpu device mesh
mesh = DeviceMesh(self.device_type, mesh_tensor)
# construct a dist tensor on 2d device mesh and test if works
@ -733,7 +733,7 @@ class DTensorMeshTest(DTensorTestBase):
@with_comms
def test_device_mesh_nd(self):
# construct a cuda device mesh
# construct a gpu device mesh
mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2)
mesh = DeviceMesh(self.device_type, mesh_tensor)
# construct a dist tensor on 3d device mesh and test if works
@ -1064,8 +1064,8 @@ class TestDTensorPlacementTypes(DTensorTestBase):
# Keep everything deterministic.
torch.manual_seed(0)
tensor = torch.rand(size)
if self.device_type == "cuda":
return tensor.cuda()
if self.device_type != "cpu":
return tensor.to(self.device_type)
else:
return tensor

View File

@ -39,6 +39,7 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel,
)
from torch.distributed.tensor.placement_types import _StridedShard
from torch.testing._internal.common_device_type import skipXPUIf
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import (
@ -47,8 +48,6 @@ from torch.testing._internal.common_utils import (
run_tests,
skipIfHpu,
skipIfTorchDynamo,
TEST_CUDA,
TEST_HPU,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@ -95,6 +94,10 @@ aot_eager_graph = aot_autograd(
partition_fn=min_cut_rematerialization_partition,
)
device_type = (
acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)
def _apply_sharding(mod: nn.Module, shard_dim: int, device_mesh: DeviceMesh):
"""
@ -141,7 +144,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
@property
def device_type(self) -> str:
return "cuda" if TEST_CUDA else "hpu" if TEST_HPU else "cpu"
return device_type
@property
def world_size(self) -> int:
@ -160,9 +163,9 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
res = fn(x)
res.to_local().sum().backward()
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available(), "accelerator not available")
def test_dtensor_basic_export(self):
mesh = DeviceMesh("cuda", torch.arange(self.world_size))
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
param = torch.randn(4, 4)
param_x = DTensor.from_local(param, mesh, [Shard(0)], run_check=False)
@ -188,10 +191,10 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
)
self.assertExpectedInline(
str(ep.graph_module.code).strip(),
"""\
f"""\
def forward(self, b_buffer, x):
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(x, dtype = torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda')); x = None
to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}')); x = None
view_as = torch.ops.aten.view_as.default(to, to); to = None
dtensor___init__0 = self.dtensor___init__0
dtensor_const_func_spec0 = self.dtensor_const_func_spec0
@ -206,10 +209,10 @@ def forward(self, b_buffer, x):
# add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add
self.assertExpectedInline(
str(ep.run_decompositions({}).graph_module.code).strip(),
"""\
f"""\
def forward(self, b_parametrizations_buffer_original0, x):
_assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None
_to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None
_to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}', index=0)); x = None
view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None
add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None
@ -377,6 +380,7 @@ def forward(self, b_parametrizations_buffer_original0, x):
self.assertEqual(res, ref)
@skipIfHpu
@skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1981")
def test_dtensor_dynamic_loss_parallel_log_softmax(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@ -815,13 +819,13 @@ def forward(self, b_parametrizations_buffer_original0, x):
out = layer_norm.permute(0, 2, 1)
return out
x = torch.randn(4, 2, 4, requires_grad=True, device="cuda")
x = torch.randn(4, 2, 4, requires_grad=True, device=self.device_type)
x_dt = DTensor.from_local(x, mesh, [Shard(1)], run_check=False)
y = torch.randn(4, requires_grad=True, device="cuda")
y = torch.randn(4, requires_grad=True, device=self.device_type)
y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False)
z = torch.randn(4, requires_grad=True, device="cuda")
z = torch.randn(4, requires_grad=True, device=self.device_type)
z_dt = DTensor.from_local(z, mesh, [Replicate()], run_check=False)
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
@ -919,7 +923,7 @@ def forward(self, b_parametrizations_buffer_original0, x):
# pass in tensor as inputs/outputs, create DTensor and run redistribute
# (allgather collective) inside the fn
def fn(x_dt):
if x_dt.device_mesh.device_type == "cuda":
if x_dt.device_mesh.device_type == f"{self.device_type}":
return x_dt + 1
else:
return x_dt + 2
@ -1051,7 +1055,7 @@ def forward(self, primals_1):
model = FakeTransformer().to(self.device_type)
tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
tp_mesh = init_device_mesh(self.device_type, (2,), mesh_dim_names=("tp",))
# apply sequence parallel
parallel_plan = {

View File

@ -27,8 +27,6 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_CUDA,
TEST_HPU,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
@ -541,7 +539,7 @@ class RedistributeTest(DTensorTestBase):
local_out_dt = out_dt.to_local()
local_expected_dt = expected_dt.to_local()
self.assertEqual(out_dt.to_local(), expected_dt.to_local())
if TEST_HPU or TEST_CUDA:
if torch.accelerator.is_available():
self.assertEqual(
comm_mode.get_comm_counts()[
torch.ops._dtensor.shard_dim_alltoall

View File

@ -296,8 +296,8 @@ class DistTensorOpsTest(DTensorTestBase):
self.assertEqual(dist_tensor.dtype, torch.float32)
self.assertEqual(zeros_like_dt.dtype, torch.bfloat16)
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_stack(self):
mesh_2d = DeviceMesh(
self.device_type, torch.arange(self.world_size).reshape(2, 2)

View File

@ -30,7 +30,7 @@ from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_GPU
def estimate_aten_runtime(fx_node, compute_multiplier=1.0):
def estimate_aten_runtime(fx_node, override_size=None, compute_multiplier=1.0):
# for tests, assume a matmul can hide a single collective
if "c10" in str(fx_node.target):
return 1.0
@ -1112,7 +1112,7 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
# Use 0.5 compute multiplier so each collective needs 2 matmuls to be fully hidden
def estimate_with_half_compute(fx_node, override_size=None):
return estimate_aten_runtime(fx_node, compute_multiplier=0.5)
return estimate_aten_runtime(fx_node, override_size, compute_multiplier=0.5)
def func(a, b, *, ranks):
# Two all_gathers that will be hidden by multiple compute operations
@ -1162,6 +1162,56 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
correct = func(a, b, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_bucketing_with_convert_dtype(self):
"""Test that all_gathers with dtype conversion get bucketed and produce correct results."""
def func(a, b, c, d, *, ranks):
# Convert inputs to float16 before all_gather
a_fp16 = a.to(torch.float16)
b_fp16 = b.to(torch.float16)
# Two all_gathers with converted dtypes
ag1 = _functional_collectives.all_gather_tensor(a_fp16, 0, ranks)
ag2 = _functional_collectives.all_gather_tensor(b_fp16, 0, ranks)
# same dtype
ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks)
ag4 = _functional_collectives.all_gather_tensor(d, 0, ranks)
return ag1, ag2, ag3, ag4
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(4, 4, dtype=torch.float32, device=device_type)
b = torch.ones(4, 4, dtype=torch.float64, device=device_type) * 2
c = torch.ones(4, 4, dtype=torch.float16, device=device_type) * 3
d = torch.ones(4, 4, dtype=torch.float64, device=device_type) * 4
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d)
# Should have 1 bucketed all_gather (both ag1 and ag2 bucketed together)
FileCheck().check_count(
"torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True
).run(aten_graph_str)
# Verify convert_element_type ops are removed (dtype conversion handled by _pre_bucket_all_gather)
FileCheck().check_not("torch.ops.prims.convert_element_type").run(
aten_graph_str
)
# Verify correctness - this tests that dtype conversion is handled correctly
correct = func(a, b, c, d, ranks=ranks)
self.assertTrue(same(out, correct))
def get_toy_model(device_type: str):
"""

View File

@ -0,0 +1,56 @@
# Owner(s): ["oncall: distributed"]
import os
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import torch
import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server
from torch.testing._internal.common_utils import run_tests, TestCase
session = requests.Session()
retry_strategy = Retry(total=5, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("http://", adapter)
session.mount("https://", adapter)
class TestDebug(TestCase):
def test_basics(self) -> None:
store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(store.port)
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
port = 25999
def fetch(path: str) -> str:
resp = session.get(f"http://localhost:{port}{path}")
resp.raise_for_status()
return resp.text
start_debug_server(port=port)
self.assertIn("torch profiler", fetch("/"))
self.assertIn("View 0", fetch("/profile?duration=0.01"))
self.assertIn("test_basics", fetch("/stacks"))
self.assertIn("pg_status", fetch("/fr_trace"))
if torch.cuda.is_available():
self.assertIn("pg_status", fetch("/fr_trace_nccl"))
# test errors
resp = session.get(f"http://localhost:{port}/blah")
self.assertEqual(resp.status_code, 404)
self.assertIn("Handler not found: /blah", resp.text)
stop_debug_server()
if __name__ == "__main__":
run_tests()

View File

@ -667,6 +667,94 @@ class TestOverlapPreservingBucketing(InductorTestCase):
str(traced.graph)
)
def test_can_bucket_with_convert_dtype_as_hiding_nodes(self):
"""
Test that all_gathers can bucket when convert_element_type ops ARE the hiding nodes.
Graph structure:
ag1_start -> convert1 (hides ag1) -> ag1_wait -> ag2_start -> convert2 (hides ag2) -> ag2_wait
The convert_element_type ops ARE hiding nodes - no matmuls.
This tests that dependencies are transferred correctly when convert nodes are erased.
"""
def func(a, b, c):
group_name = "0"
group_size = 1
ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
a, group_size, group_name
)
b = torch.ops.prims.convert_element_type.default(b, torch.float16)
ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
b, group_size, group_name
)
ag3 = torch.ops._c10d_functional.all_gather_into_tensor(
c, group_size, group_name
)
mm = ag1_out @ ag1_out
ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
ag3_out = torch.ops._c10d_functional.wait_tensor(ag3)
return ag1_out, ag2_out, ag3_out, mm
with FakeTensorMode():
a = torch.ones(4, 4, device=self.device, dtype=torch.float32)
b = torch.ones(4, 4, device=self.device, dtype=torch.float32)
c = torch.ones(4, 4, device=self.device, dtype=torch.float32)
traced = make_fx(func)(a, b, c)
# Find nodes
ag1, ag2, ag3 = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.all_gather_into_tensor.default,
)
convert1 = traced.graph.find_nodes(
op="call_function",
target=torch.ops.prims.convert_element_type.default,
)[0]
mm = traced.graph.find_nodes(
op="call_function",
target=torch.ops.aten.mm.default,
)[0]
hiding_annotations = {
ag1: convert1,
ag2: mm,
ag3: mm,
}
# Build collective info and ancestors
collective_info = build_collective_info(traced.graph, hiding_annotations)
node_ancestors = compute_ancestors(traced.graph)
scheduled = OrderedSet(traced.graph.nodes)
# Run bucketing
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
OverlapPreservingBucketer,
)
bucketer = OverlapPreservingBucketer(
traced.graph,
collective_info,
node_ancestors,
scheduled,
)
bucketer.bucket_collectives()
graph_str = str(traced.graph)
f = FileCheck()
f.check_count("%all_gather_into_tensor", 1, exactly=True)
f.check("pre_bucket_all_gather").check("wait_tensor").check(
"%all_gather_into_tensor_out"
).run(graph_str)
if __name__ == "__main__":
run_tests()

View File

@ -828,9 +828,6 @@ inductor_one_sample["cuda"] = {
"nn.functional.fractional_max_pool3d": {f16, f32, f64},
"nn.functional.group_norm": {f16},
"nn.functional.hinge_embedding_loss": {f16},
# Enabling all tests for this test fails randomly
# See https://github.com/pytorch/pytorch/issues/129238
"nn.functional.huber_loss": {f16},
"nn.functional.interpolate.bicubic": {f16},
"nn.functional.interpolate.bilinear": {f16},
"nn.functional.interpolate.trilinear": {f16},
@ -948,9 +945,6 @@ inductor_one_sample["xpu"] = {
"nn.functional.fractional_max_pool3d": {f16, f32, f64},
"nn.functional.group_norm": {f16},
"nn.functional.hinge_embedding_loss": {f16},
# Enabling all tests for this test fails randomly
# See https://github.com/pytorch/pytorch/issues/129238
"nn.functional.huber_loss": {f16},
"nn.functional.interpolate.bicubic": {f16},
"nn.functional.interpolate.bilinear": {f16},
"nn.functional.interpolate.trilinear": {f16},

View File

@ -110,6 +110,16 @@ class AttentionBlock(nn.Module):
return self.out_proj(attn_out)
def pack_sequences(seqs, device):
x_packed = torch.cat(seqs, dim=0)
seq_lens = torch.tensor([len(s) for s in seqs], device=device)
cu_seq = torch.zeros(len(seqs) + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lens.cumsum(0)
max_len = seq_lens.max().item()
return x_packed, cu_seq, max_len
def create_variable_length_batch(
shape: VarlenShape, device: torch.device, dtype: torch.dtype
):
@ -119,16 +129,15 @@ def create_variable_length_batch(
seq_lengths.append(min(length, shape.max_seq_len))
seq_lengths = torch.tensor(seq_lengths, device=device)
total_tokens = seq_lengths.sum().item()
x_packed = torch.randn(
total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
)
sequences = [
torch.randn(
seq_len, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
)
for seq_len in seq_lengths
]
cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)
max_len = seq_lengths.max().item()
x_packed, cu_seq, max_len = pack_sequences(sequences, device)
x_padded = torch.zeros(
shape.batch_size, max_len, shape.embed_dim, device=device, dtype=dtype
)
@ -146,7 +155,6 @@ def create_variable_length_batch(
"x_packed": x_packed,
"x_padded": x_padded,
"max_len": max_len,
"total_tokens": total_tokens,
}
@ -428,6 +436,143 @@ class TestVarlenAttention(NNTestCase):
start_idx = end_idx
@skipIfRocm(msg="ROCM does not support variable length attention")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
@parametrize("is_causal", [False, True])
def test_batch_invariance(self, device, dtype, is_causal):
torch.manual_seed(42)
batch_size = 4
max_seq_len = 512
embed_dim = 1024
num_heads = 16
head_dim = embed_dim // num_heads
seq_lengths = []
for _ in range(batch_size):
length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64
seq_lengths.append(min(length, max_seq_len))
sequences_q = [
torch.testing.make_tensor(
(seq_len, num_heads, head_dim),
device=device,
dtype=dtype,
requires_grad=True,
)
for seq_len in seq_lengths
]
sequences_k = [
torch.testing.make_tensor(
(seq_len, num_heads, head_dim),
device=device,
dtype=dtype,
requires_grad=True,
)
for seq_len in seq_lengths
]
sequences_v = [
torch.testing.make_tensor(
(seq_len, num_heads, head_dim),
device=device,
dtype=dtype,
requires_grad=True,
)
for seq_len in seq_lengths
]
q_packed_orig = torch.cat(sequences_q, dim=0)
k_packed_orig = torch.cat(sequences_k, dim=0)
v_packed_orig = torch.cat(sequences_v, dim=0)
seq_lens = torch.tensor(seq_lengths, device=device)
cu_seq_orig = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
cu_seq_orig[1:] = seq_lens.cumsum(0)
max_len_orig = seq_lens.max().item()
original_output = varlen_attn(
q_packed_orig,
k_packed_orig,
v_packed_orig,
cu_seq_orig,
cu_seq_orig,
max_len_orig,
max_len_orig,
is_causal,
)
perm = torch.randperm(batch_size)
permuted_sequences_q = [sequences_q[perm[i]] for i in range(batch_size)]
permuted_sequences_k = [sequences_k[perm[i]] for i in range(batch_size)]
permuted_sequences_v = [sequences_v[perm[i]] for i in range(batch_size)]
q_packed_perm = torch.cat(permuted_sequences_q, dim=0)
k_packed_perm = torch.cat(permuted_sequences_k, dim=0)
v_packed_perm = torch.cat(permuted_sequences_v, dim=0)
permuted_seq_lens = torch.tensor(
[seq_lengths[perm[i]] for i in range(batch_size)], device=device
)
cu_seq_perm = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
cu_seq_perm[1:] = permuted_seq_lens.cumsum(0)
max_len_perm = permuted_seq_lens.max().item()
permuted_output = varlen_attn(
q_packed_perm,
k_packed_perm,
v_packed_perm,
cu_seq_perm,
cu_seq_perm,
max_len_perm,
max_len_perm,
is_causal,
)
for i in range(batch_size):
orig_idx = perm[i].item()
orig_start = cu_seq_orig[orig_idx].item()
orig_end = cu_seq_orig[orig_idx + 1].item()
orig_seq_output = original_output[orig_start:orig_end]
perm_start = cu_seq_perm[i].item()
perm_end = cu_seq_perm[i + 1].item()
perm_seq_output = permuted_output[perm_start:perm_end]
self.assertEqual(orig_seq_output, perm_seq_output)
original_grad_out = torch.ones_like(original_output)
permuted_grad_out = torch.ones_like(permuted_output)
original_grad = torch.autograd.grad(
outputs=original_output,
inputs=q_packed_orig,
grad_outputs=original_grad_out,
)[0]
permuted_grad = torch.autograd.grad(
outputs=permuted_output,
inputs=q_packed_perm,
grad_outputs=permuted_grad_out,
)[0]
for i in range(batch_size):
orig_idx = perm[i].item()
orig_start = cu_seq_orig[orig_idx].item()
orig_end = cu_seq_orig[orig_idx + 1].item()
orig_seq_grad = original_grad[orig_start:orig_end]
perm_start = cu_seq_perm[i].item()
perm_end = cu_seq_perm[i + 1].item()
perm_seq_grad = permuted_grad[perm_start:perm_end]
self.assertEqual(orig_seq_grad, perm_seq_grad)
device_types = ("cuda",)

View File

@ -100,7 +100,9 @@ class Logger:
def _set_static_graph(self) -> None: ...
class _WorkerServer:
def __init__(self, socket_path: str) -> None: ...
port: int
def __init__(self, host_or_file: str, port: int = ...) -> None: ...
def shutdown(self) -> None: ...
def get_debug_level(): ...
@ -206,6 +208,7 @@ class Store:
desired_value: str,
) -> bytes: ...
def delete_key(self, key: str) -> bool: ...
def multi_get(self, keys: list[str]) -> list[bytes]: ...
def num_keys(self) -> int: ...
def set_timeout(self, timeout: timedelta): ...
@overload
@ -872,3 +875,15 @@ class ProcessGroupXCCL(Backend):
def _set_process_group(pg: ProcessGroup) -> None: ...
def _current_process_group() -> ProcessGroup: ...
class _Request:
def body(self) -> bytes: ...
def get_param(self, str) -> str: ...
class _Response:
def set_content(self, content: str | bytes, content_type: str) -> None: ...
def set_status(self, status: int) -> None: ...
def _register_handler(
name: str, handler: Callable[[_Request, _Response], None]
) -> None: ...

View File

@ -60,6 +60,7 @@ class _ExperimentalConfig:
verbose: bool = ...,
performance_events: list[str] = ...,
enable_cuda_sync_events: bool = ...,
profile_all_threads: bool = ...,
) -> None: ...
class ProfilerConfig:

View File

@ -341,12 +341,58 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
sz_bytes = 0
for node in fx_node.all_input_nodes:
if (t := node.meta.get("val")) is not None:
numel = get_size_numel(t.size())
sz_bytes += numel * get_dtype_size(t.dtype)
return sz_bytes
"""Estimate the size of a collective operation in bytes, including inputs and outputs."""
input_bytes = None
args, kwargs = fx_node.args, fx_node.kwargs
kwargs = dict(kwargs)
# dont double count pre-allocated buffer passed in
kwargs.pop("out", None)
def tensor_bytes(t) -> int:
return get_size_numel(t.size()) * get_dtype_size(t.dtype)
def add_inp_bytes(inp: torch.fx.Node):
t = inp.meta.get("val", None)
if t is None:
return
nonlocal input_bytes
if input_bytes is None:
input_bytes = 0
input_bytes += tensor_bytes(t)
pytree.tree_map_only(
torch.fx.Node,
add_inp_bytes,
(args, kwargs),
)
output_tensor = fx_node.meta.get("val", None)
if input_bytes is None or output_tensor is None:
return 0
output_bytes = (
get_size_numel(output_tensor.size()) * output_tensor.element_size()
) # pyre-ignore
return input_bytes + output_bytes
def estimate_fx_collective_memory_footprint(fx_node: torch.fx.Node) -> int:
"""Estimate the memory footprint of a collective operation in bytes.
This returns the total bytes that need to be live concurrently in memory.
For all_reduce, we divide by 2 since it can be done in-place.
"""
from torch._inductor.fx_passes.bucketing import (
is_all_reduce_tensor as is_all_reduce,
)
size = estimate_fx_collective_size(fx_node)
return size if not is_all_reduce(fx_node) else size // 2
def estimate_nccl_collective_runtime_from_fx_node(

View File

@ -489,15 +489,34 @@ def all_reduce_merge_fn_to_trace(
return new_outs
# List of all torch dtypes for serialization through custom ops
# TODO: custom ops support list[dtype] input
_ALL_DTYPES = tuple(
[
getattr(torch, attr)
for attr in dir(torch)
if isinstance(getattr(torch, attr), torch.dtype)
]
)
@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={})
def _pre_bucket_all_gather(
ag_ins: list[torch.Tensor],
group_size: int,
group_name: str,
dtype: torch.dtype, # type: ignore[name-defined]
out_dtype_ints: list[
int
], # dtype enum values, that inputs are converted to before all_gather
rank: int,
) -> torch.Tensor:
ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
# Convert int indices back to torch.dtype
out_dtypes = [_ALL_DTYPES[d] for d in out_dtype_ints]
ins_split_sizes_bytes = [
ag_in.numel() * out_dtype.itemsize
for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True)
]
bucket_dtype_size_bytes = dtype.itemsize
ins_split_sizes = [
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
@ -507,8 +526,14 @@ def _pre_bucket_all_gather(
new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
ag_ins_flattened = [ag_in.reshape(-1).view(dtype) for ag_in in ag_ins]
torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
# View each destination slice as its output dtype, then copy
# The copy operation handles dtype conversion from input dtype to output dtype
foreach_copy_dsts_typed = [
dst.view(out_dtype)
for dst, out_dtype in zip(foreach_copy_dsts, out_dtypes, strict=True)
]
ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
torch._foreach_copy_(foreach_copy_dsts_typed, ag_ins_flattened)
return new_ag_out
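
The key trick above is that the bucket buffer is allocated once in the bucket dtype, each destination slice is reinterpreted as that input's output dtype, and `torch._foreach_copy_` performs the cast while copying. A standalone sketch of the same idea (illustrative shapes and dtypes; the real op also writes into the rank's slice of the group buffer):

```
import torch

bucket_dtype = torch.bfloat16
inputs = [torch.randn(4, dtype=torch.float32), torch.randn(6, dtype=torch.bfloat16)]
out_dtypes = [torch.float16, torch.bfloat16]     # dtypes the inputs are cast to

split_bytes = [t.numel() * dt.itemsize for t, dt in zip(inputs, out_dtypes)]
splits = [b // bucket_dtype.itemsize for b in split_bytes]

buf = torch.empty(sum(splits), dtype=bucket_dtype)
dsts = [d.view(dt) for d, dt in zip(torch.split(buf, splits), out_dtypes)]
# the copy casts fp32 -> fp16 (and bf16 -> bf16) while writing into the bucket buffer
torch._foreach_copy_(dsts, [t.reshape(-1) for t in inputs])
```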
@ -517,9 +542,14 @@ def _pre_bucket_all_gather_fake(
group_size: int,
group_name: str,
dtype: torch.dtype, # type: ignore[name-defined]
out_dtype_ints: list[int],
rank: int,
) -> torch.Tensor:
ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
out_dtypes = [_ALL_DTYPES[d] for d in out_dtype_ints]
ins_split_sizes_bytes = [
ag_in.numel() * out_dtype.itemsize
for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True)
]
bucket_dtype_size_bytes = dtype.itemsize
ins_split_sizes = [
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
@ -541,12 +571,9 @@ def all_gather_merge_fn_to_trace_custom_ops(
out_dtypes: list[torch.dtype], # type: ignore[name-defined]
rank: int,
) -> list[torch.Tensor]:
ag_ins = [
torch._prims.convert_element_type(_ag_in, out_dtype)
if _ag_in.dtype != out_dtype
else _ag_in
for _ag_in, out_dtype in zip(_ag_ins, out_dtypes)
]
# Don't create convert_element_type ops - _pre_bucket_all_gather handles conversion
# by viewing destination slices as output dtypes and letting copy do the conversion
ag_ins = _ag_ins
ins_sizes = [ag_in.shape for ag_in in ag_ins]
ins_split_sizes_bytes = [
ag_in.numel() * out_dtype.itemsize
@ -557,8 +584,13 @@ def all_gather_merge_fn_to_trace_custom_ops(
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
]
ag_input_numel = sum(ins_split_sizes)
# Convert out_dtypes to indices for custom_op
# TODO: pass dtypes directly once custom ops support list[dtype] inputs
out_dtype_ints = [_ALL_DTYPES.index(dt) for dt in out_dtypes]
new_ag_out = torch.ops.bucketing._pre_bucket_all_gather(
ag_ins, group_size, group_name, dtype, rank
ag_ins, group_size, group_name, dtype, out_dtype_ints, rank
)
new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
wait_tensor = torch.ops.c10d_functional.wait_tensor(
@ -721,6 +753,20 @@ def _insert_fn_trace_before_node( # type: ignore[no-untyped-def]
return replacements, new_nodes
def has_mergeable_all_gather_convert_dtype(n: torch.fx.Node) -> bool:
node_in = n.args[0]
return (
is_all_gather_into_tensor(n)
and isinstance(node_in, torch.fx.Node)
and node_in.op == "call_function"
and (
node_in.target is torch.ops.prims.convert_element_type.default
or node_in.target is torch.ops.aten._to_copy.default
)
and len(node_in.users) == 1
)
def process_collective_bucket(
g: torch.fx.Graph,
bucket_nodes: list[torch.fx.Node],
@ -755,13 +801,7 @@ def process_collective_bucket(
# Handle convert_element_type operations (for all_gather)
node_in = n.args[0]
if (
is_all_gather_into_tensor(n)
and isinstance(node_in, torch.fx.Node) # Add type check
and node_in.op == "call_function"
and node_in.target is torch.ops.prims.convert_element_type.default
and len(node_in.users) == 1
):
if has_mergeable_all_gather_convert_dtype(n):
ag_node_to_pre_nodes[n].append(node_in)
node_in = node_in.args[0]

View File

@ -3,12 +3,14 @@ from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Literal, Optional
import torch
import torch.fx as fx
from torch._dynamo.utils import counters
from torch._inductor.augmented_graph_helper import AugmentedGraphHelper
from torch._inductor.fx_passes.bucketing import (
bucket_key,
BucketMode,
has_mergeable_all_gather_convert_dtype,
is_all_gather_into_tensor as is_all_gather,
is_reduce_scatter_tensor as is_reduce_scatter,
is_wait_tensor,
@ -207,6 +209,7 @@ class OverlapPreservingBucketer:
prev_event = event
position += 1
return head
def _populate_node_to_event(self, pg: str) -> None:
@ -231,7 +234,6 @@ class OverlapPreservingBucketer:
self.aug_graph.add_extra_dep(n=info.wait_node, dep=hn)
def bucket_collectives(self) -> None:
"""Main entry point for bucketing collectives."""
# Group collectives by PG first
pg_collectives: dict[str, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
for start in self.collective_info:
@ -281,6 +283,15 @@ class OverlapPreservingBucketer:
# Apply topological sort with all dependencies
from torch._dynamo.graph_deduplication import _stable_topological_sort
for n, deps in additional_deps.items():
torch._check(
not n._erased, lambda: f"Erased node deps not transferred: {n}"
)
for d in deps:
torch._check(
not d._erased, lambda: f"Erased node deps not transferred: {d}"
)
_stable_topological_sort(self.graph, additional_deps)
# After topological sort, preserve dependencies using effect tokens
@ -762,6 +773,11 @@ class OverlapPreservingBucketer:
old_starts = list(bucket)
old_waits = [self.collective_info[n].wait_node for n in bucket]
fused_convert_dtypes = []
for n in old_starts:
if has_mergeable_all_gather_convert_dtype(n):
fused_convert_dtypes.append(n.args[0])
# Find where to place the bucketed operations
next_node = bucket[0]
while next_node in bucket:
@ -809,6 +825,22 @@ class OverlapPreservingBucketer:
for old_wait in old_waits:
erased_to_new[old_wait] = new_wait
# Handle convert_element_type nodes that were fused and erased
# The bucketed operation may have a _pre_bucket op that handles dtype conversion
if fused_convert_dtypes:
# all-gather bucketing may fuse the dtype conversion into the bucketed op;
# if so, we need to transfer the hiding deps from the old dtype-conversion nodes
# to the new bucketing node
new_convert_dtypes_node = new_start.kwargs["out"]
assert isinstance(new_convert_dtypes_node, fx.Node)
assert (
new_convert_dtypes_node.target
== torch.ops.bucketing._pre_bucket_all_gather.default
)
for n in fused_convert_dtypes:
erased_to_new[n] = new_convert_dtypes_node
# Transfer all dependencies from old nodes to new nodes
self.aug_graph.transfer_erased_node_deps(erased_to_new)
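
For intuition, a generic sketch of the kind of remapping `transfer_erased_node_deps` has to perform (this is not the AugmentedGraphHelper API, just an illustration of moving recorded deps from erased nodes onto their replacements):

```
def remap_deps(deps: dict, erased_to_new: dict) -> dict:
    """deps maps node -> set of nodes it must stay ordered with."""
    new_deps: dict = {}
    for node, node_deps in deps.items():
        node = erased_to_new.get(node, node)
        remapped = {erased_to_new.get(d, d) for d in node_deps}
        new_deps.setdefault(node, set()).update(remapped)
    return new_deps
```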

View File

@ -11,7 +11,7 @@ from typing import Any, Literal
import torch
import torch.fx as fx
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.comm_analysis import estimate_fx_collective_size
from torch._inductor.comm_analysis import estimate_fx_collective_memory_footprint
from torch._inductor.fx_passes.bucketing import _schedulable_wait_node, is_wait_tensor
from torch._inductor.fx_passes.memory_estimator import (
_is_releasable,
@ -45,21 +45,26 @@ def get_group_name(n: fx.Node) -> str:
def get_custom_estimation(
n: fx.Node,
custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
| None = None,
override_size: int | None = None,
) -> float | None:
if custom_runtime_estimation is None:
return None
return custom_runtime_estimation(n)
return custom_runtime_estimation(n, override_size)
def estimate_collective_time(
n: fx.Node,
override_size: int | None = None,
custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
| None = None,
) -> float:
"""Estimate the runtime of a collective operation, optionally with an overridden size."""
if (est := get_custom_estimation(n, custom_runtime_estimation)) is not None:
if (
est := get_custom_estimation(n, custom_runtime_estimation, override_size)
) is not None:
return est
# Use analytical model (benchmarking is handled separately in alignment)
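
A sketch of a callback matching the new `custom_runtime_estimation` signature (`my_estimation` and the assumed bandwidth are illustrative names, not part of this PR): returning `None` defers to the analytical model or benchmarking, and `override_size` lets the scheduler re-estimate the same collective at a hypothetical bucketed size.

```
import torch.fx as fx

ASSUMED_BW_BYTES_PER_MS = 25 * (1024 ** 3) / 1e3   # assumed ~25 GB/s effective bandwidth


def my_estimation(n: fx.Node, override_size: int | None) -> float | None:
    size = override_size
    if size is None:
        val = n.meta.get("val")
        if val is None:
            return None                            # fall back to the default estimator
        size = val.numel() * val.element_size()
    return size / ASSUMED_BW_BYTES_PER_MS          # estimated milliseconds
```

Such a callback would be passed as `custom_runtime_estimation=my_estimation` to `schedule_overlap_bucketing` further down in this diff.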
@ -99,7 +104,8 @@ def get_collective_do_bench() -> Callable[[Callable[[], Any]], float]:
def benchmark_node_with_cache_key(
n: fx.Node,
custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
| None = None,
) -> tuple[float, str | None]:
"""Benchmark a compute node and return (runtime, cache_key)."""
assert is_compute_node(n)
@ -142,7 +148,9 @@ def benchmark_node_with_cache_key(
if unbacked_tensor:
return 0, key
if (est := get_custom_estimation(n, custom_runtime_estimation)) is not None:
if (
est := get_custom_estimation(n, custom_runtime_estimation, None)
) is not None:
set_cached_node_time(key, est)
return est, key
@ -154,7 +162,8 @@ def benchmark_node_with_cache_key(
def benchmark_node(
n: fx.Node,
custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
| None = None,
) -> float:
return benchmark_node_with_cache_key(n, custom_runtime_estimation)[0]
@ -236,7 +245,7 @@ class OverlapScheduler:
insert_overlap_deps: bool,
compute_overlap_multipler: float,
max_coll_distance: int,
custom_runtime_estimation: Callable[[fx.Node], float | None] | None,
custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] | None,
collective_estimator: Literal["analytical", "benchmark"],
):
self.gm = gm
@ -318,7 +327,7 @@ class OverlapScheduler:
info = CollectiveInfo(
start_node=start,
wait_node=node,
size_bytes=estimate_fx_collective_size(start),
size_bytes=estimate_fx_collective_memory_footprint(start),
estimated_time_ms=coll_time_ms,
exposed_time_ms=coll_time_ms, # Initially fully exposed
)
@ -431,7 +440,10 @@ class OverlapScheduler:
# Benchmark CUDA events (non-deterministic, needs alignment)
# Skip collectives with custom estimation
for n in collective_nodes:
if get_custom_estimation(n, self.custom_runtime_estimation) is not None:
if (
get_custom_estimation(n, self.custom_runtime_estimation, None)
is not None
):
continue
# Benchmark actual size
@ -1000,7 +1012,8 @@ def schedule_overlap_bucketing(
insert_overlap_deps: bool = False,
compute_overlap_multipler: float = 1.0,
max_coll_distance: int = 1000,
custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
| None = None,
collective_estimator: Literal["analytical", "benchmark"] = "analytical",
) -> torch.fx.GraphModule:
"""Schedule nodes to maximize compute-collective overlap.

View File

@ -0,0 +1,9 @@
from ._core import ComplexTensor
from ._ops import ComplexTensorMode, is_complex_tensor
__all__ = ["ComplexTensor", "ComplexTensorMode", "is_complex_tensor"]
ComplexTensor.__module__ = __name__
ComplexTensorMode.__module__ = __name__
is_complex_tensor.__module__ = __name__

View File

@ -0,0 +1,151 @@
from __future__ import annotations
from typing import Any, TYPE_CHECKING
from typing_extensions import Self
import torch
from torch import Tensor
from torch.autograd import Function
if TYPE_CHECKING:
from torch._ops import OpOverload
from torch._prims_common import DeviceLikeType
from torch.autograd.function import FunctionCtx
class ComplexTensor(Tensor):
"""A class that decomposes all ops on complex Tensors into their real and imaginary parts."""
_re: Tensor
_im: Tensor
def __new__(cls, real: Tensor, imag: Tensor) -> Self:
"""Initialize a ComplexTensor from its real and imaginary parts."""
from ._ops.common import REAL_TO_COMPLEX
shape = real.shape
device = real.device
# TODO (hameerabbasi): `torch.compile` sometimes fails here without making these
# contiguous. Why?
real = real.contiguous()
imag = imag.contiguous()
# TODO (hameerabbasi):
# What should we do with dtype?
# We could convert to the complex type (float32 -> complex64), but we
# can't use that model for say `bfloat16` which does not have a
# corresponding complex dtype.
# If we want to support this complex rep using any float type (see
# https://github.com/pytorch/pytorch/issues/95100)
# We either need to:
# 1) add the complex types for say `complexbf32`, knowing they can't really be used anywhere
# else.
# 2) We use the real float dtype here, and it is up to the user to know
# that dtype=float<size> here really means complex<2xSize> with dtype
# matching that of re/im parts alone
# I'm going with 1 for now, so that I can make gradcheck and some complex
# ops work properly, but might want to discuss this in the RFP.
dtype = REAL_TO_COMPLEX.get(real.dtype)
if dtype is None:
raise TypeError(
"Unsupported dtype for constituent tensors. Supported dtypes are: "
f"{set(REAL_TO_COMPLEX.keys())!r}."
)
storage_offset = real.storage_offset()
strides = real.stride()
layout = real.layout
pin_memory = real.is_pinned()
assert shape == imag.shape, f"Expected imag shape {shape}, got {imag.shape}"
assert device == imag.device, (
f"Expected imag device {device}, got {imag.device}"
)
assert real.dtype == imag.dtype, (
f"Expected imag dtype {real.dtype}, got {imag.dtype}"
)
assert pin_memory == imag.is_pinned(), (
f"Expected imag pinning {pin_memory}, got {imag.is_pinned()}"
)
res = Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
shape,
device=device,
dtype=dtype,
storage_offset=storage_offset,
strides=strides,
pin_memory=pin_memory,
layout=layout,
requires_grad=False,
)
res._re = real.clone().detach()
res._im = imag.clone().detach()
return res
@property
def re(self) -> Tensor:
return self._re
@property
def im(self) -> Tensor:
return self._im
@classmethod
def __torch_dispatch__(
cls,
func: OpOverload,
types: tuple[type, ...],
args: tuple = (),
kwargs: dict | None = None,
):
from ._ops.common import lookup_complex
kwargs = {} if kwargs is None else kwargs
impl = lookup_complex(func, *args, **kwargs)
if impl is None:
return NotImplemented
return impl(*args, **kwargs)
@staticmethod
def from_interleaved(t: Tensor) -> ComplexTensor:
t_real = torch.real(t)
t_imag = torch.imag(t) if t.dtype.is_complex else torch.zeros_like(t_real)
return Complex.apply(t_real, t_imag)
def as_interleaved(self) -> Tensor:
return torch.complex(self.real, self.imag)
@staticmethod
def __tensor_unflatten__(
inner_tensors: dict[str, Tensor],
meta: Any,
outer_size: tuple[int, ...],
outer_stride: tuple[int, ...],
) -> ComplexTensor:
assert meta is None
re, im = inner_tensors["re"], inner_tensors["im"]
return ComplexTensor(re, im)
def __tensor_flatten__(self) -> tuple[list[str], Any]:
return ["re", "im"], None
def __repr__(self, *, tensor_contents=None) -> str:
return f"ComplexTensor(real={self.re!r}, imag={self.im!r})"
def is_pinned(self, device: DeviceLikeType | None = None) -> bool:
return self.re.is_pinned(device)
class Complex(Function):
@staticmethod
def forward(ctx: FunctionCtx, real: Tensor, imag: Tensor) -> ComplexTensor: # type: ignore[bad-override]
return ComplexTensor(real, imag)
@staticmethod
def backward(ctx: FunctionCtx, grad_output: ComplexTensor) -> tuple[Tensor, Tensor]: # type: ignore[bad-override]
return grad_output.real, grad_output.imag
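
A usage sketch, assuming the `ComplexTensor` class defined above is in scope (for example, imported from this module):

```
import torch

z = ComplexTensor(torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0]))
print(z.dtype)                 # torch.complex64, mapped from float32 via REAL_TO_COMPLEX
print(z.re, z.im)              # the real and imaginary constituents
w = ComplexTensor.from_interleaved(torch.tensor([1 + 2j, 3 - 4j]))
print(w.as_interleaved())      # back to a regular complex-dtype tensor
```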

View File

@ -0,0 +1,5 @@
from . import aten, prims
from .common import ComplexTensorMode, is_complex_tensor
__all__ = ["ComplexTensorMode", "is_complex_tensor", "aten", "prims"]

View File

@ -0,0 +1,921 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from .._core import ComplexTensor
from .common import (
_get_func_name,
COMPLEX_TO_REAL,
complex_to_real_dtype,
is_complex,
OpType,
promote_tensors,
register_binary_nonlinear,
register_complex,
register_error,
register_force_test,
register_simple,
split_complex_arg,
split_complex_tensor,
)
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from typing import Any
aten = torch.ops.aten
def register_binary_linear(op: OpType):
def impl_with_alpha(
lhs: ComplexTensor, rhs: ComplexTensor, *args, alpha, **kwargs
) -> ComplexTensor:
return op(lhs, aten.mul(rhs, alpha, *args, **kwargs), *args, **kwargs)
def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
alpha = kwargs.pop("alpha", None)
if alpha is not None:
return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs)
a_r, a_i = split_complex_arg(lhs)
b_r, b_i = split_complex_arg(rhs)
out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i)
u = op(a_r, b_r, *args, **kwargs)
v = op(a_i, b_i, *args, **kwargs)
return ComplexTensor(u.to(out_dt), v.to(out_dt))
return register_complex(op, impl)
@register_complex(aten.real)
def real_impl(self: ComplexTensor) -> torch.Tensor:
re, _ = split_complex_tensor(self)
return re
@register_complex(aten.imag)
def imag_impl(self: ComplexTensor) -> torch.Tensor:
_, im = split_complex_tensor(self)
return im
@register_complex(aten.is_pinned)
def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> bool:
return self.is_pinned(device)
SIMPLE_OPS_LIST = [
aten.slice,
aten.flatten,
aten.view,
aten.diagonal,
aten.expand,
aten.unsqueeze,
aten.unsqueeze_,
aten.mean,
aten.sum,
aten.clone,
aten.neg,
aten.flip,
aten.permute,
aten.repeat,
aten.index_select,
aten.split,
aten.split_with_sizes,
aten.cumsum,
aten.detach,
aten.select,
aten.squeeze,
aten.zero_,
aten.transpose,
aten.t,
aten.gather,
]
for simple_op in SIMPLE_OPS_LIST:
globals()[_get_func_name(simple_op)] = register_simple(simple_op)
# TODO (hameerabbasi): Not being tested
SIMPLE_FORCE_TESTED_OPS = [
aten.copy,
aten.col2im,
aten.alias,
aten.lift_fresh,
aten._unsafe_view,
aten.index,
aten._neg_view,
aten.avg_pool2d,
aten.avg_pool3d,
aten.avg_pool2d_backward,
aten.avg_pool3d_backward,
aten.masked_scatter_backward,
aten.select_backward,
aten.slice_backward,
aten.embedding,
]
for simple_op in SIMPLE_FORCE_TESTED_OPS:
globals()[_get_func_name(simple_op)] = register_force_test(
simple_op, register_simple(simple_op)
)
del simple_op
# some binary ops which we can stamp out
mul_impl = register_binary_nonlinear(aten.mul)
mul__impl = register_binary_nonlinear(aten.mul_)
mm_impl = register_binary_nonlinear(aten.mm)
dot_impl = register_binary_nonlinear(aten.dot)
bmm_impl = register_binary_nonlinear(aten.bmm)
# TODO (hameerabbasi): Not being tested
convolution_impl = register_force_test(
aten.convolution, register_binary_nonlinear(aten.convolution)
)
slice_scatter_impl = register_force_test(
aten.slice_scatter, register_binary_linear(aten.slice_scatter)
)
select_scatter_impl = register_force_test(
aten.select_scatter, register_binary_linear(aten.select_scatter)
)
add_impl = register_binary_linear(aten.add)
add__impl = register_binary_linear(aten.add_)
sub_impl = register_binary_linear(aten.sub)
sub__impl = register_binary_linear(aten.sub_)
diagonal_scatter_impl = register_binary_linear(aten.diagonal_scatter)
fill__impl = register_binary_linear(aten.fill_)
@register_complex(aten.rsub)
def rsub_impl(lhs: ComplexTensor, rhs: ComplexTensor, alpha=None) -> ComplexTensor:
if alpha is None:
return torch.sub(rhs, lhs) # type: ignore[bad-return]
return torch.sub(rhs, lhs, alpha=alpha) # type: ignore[bad-return]
@register_complex(aten.div)
@register_complex(aten.true_divide)
def div_impl(lhs: ComplexTensor, rhs: ComplexTensor, *, rounding_mode=None):
if rounding_mode is not None:
raise NotImplementedError(
"`rounding_mode` other than `None` is not implemented for `ComplexTensor`."
)
a_r, a_i = split_complex_tensor(lhs)
if not is_complex(rhs):
return ComplexTensor(a_r / rhs, a_i / rhs)
b_r, b_i = split_complex_arg(rhs)
out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i)
num_r = a_r * b_r + a_i * b_i
num_i = a_i * b_r - a_r * b_i
den = b_r * b_r + b_i * b_i
return ComplexTensor(
(num_r / den).to(out_dt),
(num_i / den).to(out_dt),
)
@register_complex(aten.reciprocal)
def reciprocal_impl(self: ComplexTensor):
self_r, self_i = split_complex_tensor(self)
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
den = self_r * self_r + self_i * self_i
return ComplexTensor(
aten.div(self_r, den).to(out_dt),
aten.div(-self_i, den).to(out_dt),
)
# reductions
@register_complex(aten.prod)
def prod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
out_dt, (self,) = promote_tensors(self)
dtype = kwargs.pop("dtype", out_dt)
kwargs["dtype"] = complex_to_real_dtype(self.dtype)
prod_r = torch.prod(torch.abs(self), *args, **kwargs)
sum_phi = torch.sum(torch.angle(self), *args, **kwargs)
u = prod_r * torch.cos(sum_phi)
v = prod_r * torch.sin(sum_phi)
return ComplexTensor(u, v).to(dtype) # type: ignore[bad-return]
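
A quick numeric check of the polar-form identity used here, prod_k z_k = (prod_k |z_k|) * exp(i * sum_k arg(z_k)):

```
import torch

z = torch.randn(5, dtype=torch.complex128)
expected = torch.prod(z)
r = torch.prod(torch.abs(z))
phi = torch.sum(torch.angle(z))
assert torch.allclose(torch.complex(r * torch.cos(phi), r * torch.sin(phi)), expected)
```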
@register_complex(aten.pow)
def pow_impl(self: ComplexTensor, exponent: ComplexTensor) -> ComplexTensor:
out_dt, (self, exponent) = promote_tensors(self, exponent)
return torch.exp(exponent * torch.log(self)).to(out_dt) # type: ignore[bad-return]
@register_complex(aten.cumprod)
def cumprod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
dtype = kwargs.pop("dtype", self.dtype)
kwargs["dtype"] = complex_to_real_dtype(dtype)
prod_r = torch.cumprod(torch.abs(self), *args, **kwargs)
sum_phi = torch.cumsum(torch.angle(self), *args, **kwargs)
u = prod_r * torch.cos(sum_phi)
v = prod_r * torch.sin(sum_phi)
return ComplexTensor(u, v)
# unary funcs,
# most of these are simple or require some kind of identity
@register_complex(aten.abs)
def abs_impl(self: ComplexTensor) -> torch.Tensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_tensors(x, y)
result = torch.hypot(x, y)
return result.to(out_dt)
@register_complex(aten.angle)
def angle_impl(self: ComplexTensor) -> torch.Tensor:
x, y = split_complex_tensor(self)
return torch.atan2(y, x)
@register_complex(aten.acos)
def acos_impl(self: ComplexTensor) -> ComplexTensor:
_, y = split_complex_tensor(self)
acosh_z = torch.acosh(self)
assert isinstance(acosh_z, ComplexTensor)
acosh_z_re, acosh_z_im = split_complex_tensor(acosh_z)
sign_im = 2 * torch.signbit(y) - 1
return ComplexTensor(torch.abs(acosh_z_im), sign_im * torch.abs(acosh_z_re))
@register_complex(aten.asin)
def asin_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
asinh_iz = torch.asinh(ComplexTensor(-y, x))
assert isinstance(asinh_iz, ComplexTensor)
asinh_iz_re, asinh_iz_im = split_complex_tensor(asinh_iz)
return ComplexTensor(asinh_iz_im, -asinh_iz_re)
@register_complex(aten.atan)
def atan_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
tanh_iz = torch.atanh(ComplexTensor(-y, x))
assert isinstance(tanh_iz, ComplexTensor)
tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz)
return ComplexTensor(tanh_iz_im, -tanh_iz_re)
@register_complex(aten.asinh)
def asinh_impl(self: ComplexTensor) -> ComplexTensor:
out_dt, (self,) = promote_tensors(self)
return torch.log(self + torch.sqrt(self * self + 1)).to(out_dt) # type: ignore[bad-return]
@register_complex(aten.acosh)
def acosh_impl(self: ComplexTensor) -> ComplexTensor:
out_dt, (self,) = promote_tensors(self)
return torch.log(self + torch.sqrt(self * self - 1)).to(out_dt) # type: ignore[bad-return]
@register_complex(aten.atanh)
def atanh_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_tensors(x, y)
ret = 0.5 * (
torch.log(ComplexTensor(1 + x, y)) - torch.log(ComplexTensor(1 - x, -y))
)
assert isinstance(ret, ComplexTensor)
ret_re, ret_im = split_complex_tensor(ret)
return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt))
@register_complex(aten.cos)
def cos_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
return torch.cosh(ComplexTensor(-y, x)) # type: ignore[bad-return]
@register_complex(aten.cosh)
def cosh_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_tensors(x, y)
u = torch.cosh(x) * torch.cos(y)
v = torch.sinh(x) * torch.sin(y)
return ComplexTensor(u.to(out_dt), v.to(out_dt))
@register_complex(aten.sin)
def sin_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
sinh_iz = torch.sinh(ComplexTensor(-y, x))
assert isinstance(sinh_iz, ComplexTensor)
sinh_iz_re, sinh_iz_im = split_complex_tensor(sinh_iz)
return ComplexTensor(sinh_iz_im, -sinh_iz_re)
@register_complex(aten.sinh)
def sinh_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_tensors(x, y)
u = torch.sinh(x) * torch.cos(y)
v = torch.cosh(x) * torch.sin(y)
return ComplexTensor(u.to(out_dt), v.to(out_dt))
@register_complex(aten.tan)
def tan_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
tanh_iz = torch.tanh(ComplexTensor(-y, x))
assert isinstance(tanh_iz, ComplexTensor)
tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz)
return ComplexTensor(tanh_iz_im, -tanh_iz_re)
@register_complex(aten.tanh)
def tanh_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_tensors(x, y)
_2x = 2 * x
_2y = 2 * y
_d = torch.cosh(_2x) + torch.cos(_2y)
_2xsh = torch.sinh(_2x)
out_re = _2xsh / _d
out_im = torch.sin(_2y) / _d
return ComplexTensor(out_re.to(out_dt), out_im.to(out_dt))
@register_complex(aten.exp)
def exp_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_tensors(x, y)
ex = torch.exp(x)
u = ex * torch.cos(y)
v = ex * torch.sin(y)
return ComplexTensor(u.to(out_dt), v.to(out_dt))
@register_complex(aten.expm1)
def expm1_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_tensors(x, y)
# TODO (hameerabbasi): The two lines below may have numerical issues
ex = torch.exp(x)
u = ex * torch.cos(y) - 1
v = ex * torch.sin(y)
return ComplexTensor(u.to(out_dt), v.to(out_dt))
@register_complex(aten.log)
def log_impl(self: ComplexTensor) -> ComplexTensor:
out_dt, (self,) = promote_tensors(self)
re = torch.log(torch.abs(self))
im = torch.angle(self)
return ComplexTensor(re, im).to(out_dt) # type: ignore[bad-return]
@register_complex(aten.log1p)
def log1p_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
# TODO (hameerabbasi): The line below may have numerical issues
return torch.log(ComplexTensor(x + 1, y)) # type: ignore[bad-return]
@register_complex(aten.any)
def any_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
x, y = split_complex_tensor(self)
return torch.any(x, *args, **kwargs) | torch.any(y, *args, **kwargs)
@register_complex(aten.all)
def all_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
x, y = split_complex_tensor(self)
# an element of a complex tensor is truthy iff its real or imaginary part is nonzero
return torch.all((x != 0) | (y != 0), *args, **kwargs)
@register_complex(aten.eq)
def eq_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor:
a_r, a_i = split_complex_arg(self)
b_r, b_i = split_complex_arg(rhs)
return torch.eq(a_r, b_r, *args, **kwargs) & torch.eq(a_i, b_i, *args, **kwargs)
@register_complex(aten.ne)
def ne_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor:
a_r, a_i = split_complex_tensor(self)
b_r, b_i = split_complex_arg(rhs)
return torch.ne(a_r, b_r, *args, **kwargs) | torch.ne(a_i, b_i, *args, **kwargs)
@register_complex(aten.isnan)
def isnan_impl(self: ComplexTensor) -> torch.Tensor:
re, im = split_complex_tensor(self)
return torch.isnan(re) | torch.isnan(im)
@register_complex(aten.isinf)
def isinf_impl(self: ComplexTensor) -> torch.Tensor:
re, im = split_complex_tensor(self)
return torch.isinf(re) | torch.isinf(im)
@register_complex(aten.isfinite)
def isfinite_impl(self: ComplexTensor) -> torch.Tensor:
re, im = split_complex_tensor(self)
return torch.isfinite(re) & torch.isfinite(im)
@register_complex(aten.isclose)
def isclose_impl(
self: ComplexTensor,
rhs: ComplexTensor,
rtol=1e-5,
atol=1e-8,
equal_nan: bool = False,
) -> torch.Tensor:
abs_diff = torch.abs(self - rhs)
abs_other = torch.abs(rhs)
basic_condition = abs_diff <= (rtol * abs_other + atol)
# This is the nontrivial part
if equal_nan:
a_r, a_i = split_complex_tensor(self)
b_r, b_i = split_complex_arg(rhs)
a_r_nan = torch.isnan(a_r)
b_r_nan = torch.isnan(b_r)
a_i_nan = torch.isnan(a_i)
b_i_nan = torch.isnan(b_i)
a_nan = a_r_nan | a_i_nan
# This logical expression makes sure that the isnan of both the real and imaginary parts
# matches (so 1 + nan*i doesn't equal nan + 1*i)
equal_nan_condition = ((a_r_nan == b_r_nan) & (a_i_nan == b_i_nan)) & a_nan
return basic_condition | equal_nan_condition
return basic_condition
ERROR_OPS_LIST = [
aten.lt,
aten.le,
aten.gt,
aten.ge,
aten.amin,
aten.amax,
aten.clamp,
aten.ceil,
aten.floor,
aten.minimum,
aten.maximum,
aten.trunc,
aten.sign,
aten.argmax,
aten.argmin,
aten.sort,
aten.topk,
aten.round,
aten.fmod,
]
ERROR_TYPES = {
aten.minimum: RuntimeError,
aten.maximum: RuntimeError,
aten.argmax: RuntimeError,
aten.argmin: RuntimeError,
aten.sort: RuntimeError,
aten.topk: RuntimeError,
}
for err_op in ERROR_OPS_LIST:
globals()[_get_func_name(err_op)] = register_error(
err_op, ERROR_TYPES.get(err_op, NotImplementedError)
)
del err_op
@register_complex(aten.masked_scatter)
def masked_scatter_impl(
self: ComplexTensor, mask: torch.Tensor, source: ComplexTensor
) -> ComplexTensor:
self_r, self_i = split_complex_tensor(self)
source_r, source_i = split_complex_arg(source)
ret_r = torch.masked_scatter(self_r, mask, source_r)
ret_i = torch.masked_scatter(self_i, mask, source_i)
return ComplexTensor(ret_r, ret_i)
@register_complex(aten.where)
def where_impl(mask: torch.Tensor, x: ComplexTensor, y: ComplexTensor) -> ComplexTensor:
x_r, x_i = split_complex_arg(x)
y_r, y_i = split_complex_arg(y)
ret_r = torch.where(mask, x_r, y_r)
ret_i = torch.where(mask, x_i, y_i)
return ComplexTensor(ret_r, ret_i)
@register_complex(aten.full_like)
def full_like_impl(
input: ComplexTensor,
fill_value: complex,
*args,
dtype: torch.dtype | None = None,
**kwargs,
) -> torch.Tensor | ComplexTensor:
# Note: Cannot be merged with the cases below due to the `fill_value` argument
input_r, input_i = split_complex_tensor(input)
if dtype is not None and dtype not in COMPLEX_TO_REAL:
return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs)
if dtype is not None:
kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
fv_r, fv_i = split_complex_arg(fill_value)
ret_r = torch.full_like(input_r, fv_r, *args, **kwargs)
ret_i = torch.full_like(input_i, fv_i, *args, **kwargs)
return ComplexTensor(ret_r, ret_i)
def register_like(op: OpType) -> Callable[..., torch.Tensor | ComplexTensor]:
def impl(
self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs
) -> torch.Tensor | ComplexTensor:
self_re, self_im = split_complex_tensor(self)
if dtype is not None and dtype not in COMPLEX_TO_REAL:
return op(self_re, *args, dtype=dtype, **kwargs)
if dtype is not None:
kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
ret_re = op(self_re, *args, **kwargs)
ret_im = op(self_im, *args, **kwargs)
return ComplexTensor(ret_re, ret_im)
func_name = _get_func_name(op)
impl.__name__ = func_name
impl.__qualname__ = func_name
return register_complex(op, impl)
LIKE_OPS_LIST = [
aten.empty_like,
aten.zeros_like,
aten.randn_like,
aten.new_zeros,
]
for like_op in LIKE_OPS_LIST:
globals()[_get_func_name(like_op)] = register_like(like_op)
del like_op
@register_complex(aten.cat)
def cat_impl(tensors: Sequence[ComplexTensor], dim: int = 0) -> ComplexTensor:
tensors_r = []
tensors_i = []
for t in tensors:
t_r, t_i = split_complex_arg(t)
tensors_r.append(t_r)
tensors_i.append(t_i)
ret_r = torch.cat(tensors_r, dim=dim)
ret_i = torch.cat(tensors_i, dim=dim)
return ComplexTensor(ret_r, ret_i)
@register_complex(aten.sgn)
def sgn_impl(self: ComplexTensor) -> ComplexTensor:
self_r, self_i = split_complex_tensor(self)
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
abs_self = torch.abs(ComplexTensor(self_r, self_i))
mask = (self_r != 0) | (self_i != 0)
masked_sgn = ComplexTensor(
(self_r / abs_self).to(out_dt), (self_i / abs_self).to(out_dt)
)
return torch.where(mask, masked_sgn, 0) # type: ignore[bad-return]
@register_complex(aten.sqrt)
def sqrt_impl(self: ComplexTensor) -> ComplexTensor:
self_r, self_i = split_complex_tensor(self)
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
self = ComplexTensor(self_r, self_i)
self_abs_sqrt = torch.sqrt(torch.abs(self))
self_half_angle = 0.5 * torch.angle(self)
ret_r = self_abs_sqrt * torch.cos(self_half_angle)
ret_i = self_abs_sqrt * torch.sin(self_half_angle)
return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt))
@register_complex(aten.rsqrt)
def rsqrt_impl(self: ComplexTensor) -> ComplexTensor:
self_r, self_i = split_complex_tensor(self)
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
self = ComplexTensor(self_r, self_i)
self_abs_rsqrt = torch.rsqrt(torch.abs(self))
self_neg_half_angle = -0.5 * torch.angle(self)
ret_r = self_abs_rsqrt * torch.cos(self_neg_half_angle)
ret_i = self_abs_rsqrt * torch.sin(self_neg_half_angle)
return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt))
@register_complex(aten.addmm)
def addmm_impl(
input: ComplexTensor,
mat1: ComplexTensor,
mat2: ComplexTensor,
out_dtype: torch.dtype | None = None,
beta: complex = 1,
alpha: complex = 1,
) -> ComplexTensor:
ret = beta * input + alpha * torch.mm(mat1, mat2)
assert isinstance(ret, ComplexTensor)
ret_r, ret_i = split_complex_tensor(ret)
if out_dtype is not None:
out_dtype = COMPLEX_TO_REAL[out_dtype]
ret_r, ret_i = ret_r.to(out_dtype), ret_i.to(out_dtype)
return ComplexTensor(ret_r, ret_i)
def elemwise_nonzero(self: ComplexTensor) -> torch.Tensor:
re, im = split_complex_tensor(self)
return (re != 0) | (im != 0)
def register_nonzero_impl(op: OpType):
def nonzero_impl(
self: ComplexTensor, other: ComplexTensor, *args, **kwargs
) -> torch.Tensor:
return op(elemwise_nonzero(self), elemwise_nonzero(other), *args, **kwargs)
func_name = _get_func_name(op)
nonzero_impl.__name__ = func_name
nonzero_impl.__qualname__ = func_name
return register_complex(op, nonzero_impl)
logical_and_impl = register_nonzero_impl(aten.logical_and)
logical_or_impl = register_nonzero_impl(aten.logical_or)
logical_xor_impl = register_nonzero_impl(aten.logical_xor)
@register_complex(aten.logical_not)
def logical_not_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
return torch.logical_not(elemwise_nonzero(self), *args, **kwargs)
@register_complex(aten.view_as_real)
def view_as_real_impl(self: ComplexTensor) -> torch.Tensor:
re, im = split_complex_tensor(self)
return torch.stack([re, im], dim=-1)
@register_complex(aten.linalg_vector_norm)
def linalg_vector_norm_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
return torch.linalg.vector_norm(torch.abs(self), *args, **kwargs)
@register_force_test(aten.copy_)
def copy__impl(self: ComplexTensor, src, *args, **kwargs):
self_re, self_im = split_complex_tensor(self)
src_re, src_im = split_complex_arg(src)
ret_re = self_re.copy_(src_re, *args, **kwargs)
ret_im = self_im.copy_(src_im, *args, **kwargs)
return ComplexTensor(ret_re, ret_im)
@register_complex(aten._local_scalar_dense)
def _local_scalar_dense_impl(self: ComplexTensor, *args, **kwargs) -> complex:
x, y = split_complex_tensor(self)
u = aten._local_scalar_dense(x, *args, **kwargs)
v = aten._local_scalar_dense(y, *args, **kwargs)
return complex(u, v)
@register_complex(aten.allclose)
def allclose_impl(
input: torch.Tensor,
other: torch.Tensor,
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
) -> bool:
return torch.all(
torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
).item() # type: ignore[bad-return]
@register_complex(aten.stack)
def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor:
re_im_tuples = [split_complex_arg(self_i) for self_i in self]
u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs)
v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs)
return ComplexTensor(u, v)
# TODO (hameerabbasi): Not being tested
@register_complex(aten._conj_physical)
@register_complex(aten.conj_physical)
def conj_physical_impl(self: ComplexTensor) -> ComplexTensor:
re, im = split_complex_tensor(self)
return ComplexTensor(re, -im)
# TODO (hameerabbasi): Not being tested
@register_complex(aten._conj)
def _conj_impl(self: ComplexTensor) -> ComplexTensor:
re, im = split_complex_tensor(self)
return ComplexTensor(re, torch._neg_view(im))
@register_complex(aten.index_add)
def index_add_impl(
self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs
) -> ComplexTensor:
alpha = kwargs.pop("alpha", None)
if alpha is not None:
source = source * alpha
self_re, self_im = split_complex_arg(self)
source_re, source_im = split_complex_arg(source)
ret_re = self_re.index_add(dim, index, source_re)
ret_im = self_im.index_add(dim, index, source_im)
return ComplexTensor(ret_re, ret_im)
# TODO (hameerabbasi): Not being tested
@register_complex(aten.index_add_)
def index_add__impl(
self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs
) -> ComplexTensor:
alpha = kwargs.pop("alpha", None)
if alpha is not None:
source = source * alpha
self_re, self_im = split_complex_arg(self)
source_re, source_im = split_complex_arg(source)
ret_re = self_re.index_add_(dim, index, source_re)
ret_im = self_im.index_add_(dim, index, source_im)
return ComplexTensor(ret_re, ret_im)
@register_complex(aten.masked_fill)
def masked_fill_impl(
self: ComplexTensor, mask: torch.Tensor, value: complex
) -> ComplexTensor:
self_re, self_im = split_complex_arg(self)
value_re, value_im = split_complex_arg(value)
ret_re = self_re.masked_fill(mask, value_re)
ret_im = self_im.masked_fill(mask, value_im)
return ComplexTensor(ret_re, ret_im)
# TODO (hameerabbasi): Not being tested
@register_complex(aten.masked_fill_)
def masked_fill__impl(
self: ComplexTensor, mask: torch.Tensor, value: complex
) -> ComplexTensor:
self_re, self_im = split_complex_arg(self)
value_re, value_im = split_complex_arg(value)
ret_re = self_re.masked_fill_(mask, value_re)
ret_im = self_im.masked_fill_(mask, value_im)
return ComplexTensor(ret_re, ret_im)
@register_complex(aten.constant_pad_nd)
def constant_pad_nd_impl(
self: ComplexTensor, pad, value: complex | None = None
) -> ComplexTensor:
self_re, self_im = split_complex_tensor(self)
if value is None:
ret_re = aten.constant_pad_nd(self_re, pad)
ret_im = aten.constant_pad_nd(self_im, pad)
else:
value_re, value_im = split_complex_arg(value)
ret_re = aten.constant_pad_nd(self_re, pad, value_re)
ret_im = aten.constant_pad_nd(self_im, pad, value_im)
return ComplexTensor(ret_re, ret_im)
@register_complex(aten.var)
def var_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
self_re, self_im = split_complex_tensor(self)
return torch.var(self_re, *args, **kwargs) + torch.var(self_im, *args, **kwargs)
@register_complex(aten.scatter_add)
def scatter_add_impl(
self: ComplexTensor, dim, index, src: ComplexTensor
) -> ComplexTensor:
self_re, self_im = split_complex_arg(self)
src_re, src_im = split_complex_arg(src)
ret_re = torch.scatter_add(self_re, dim, index, src_re)
ret_im = torch.scatter_add(self_im, dim, index, src_im)
return ComplexTensor(ret_re, ret_im)
@register_complex(aten.scatter_add_)
def scatter_add__impl(
self: ComplexTensor, dim, index, src: ComplexTensor
) -> ComplexTensor:
self_re, self_im = split_complex_arg(self)
src_re, src_im = split_complex_arg(src)
out_re = self_re.scatter_add_(dim, index, src_re)
out_im = self_im.scatter_add_(dim, index, src_im)
return ComplexTensor(out_re, out_im)
@register_complex(aten.index_put_)
def index_put__impl(
self: ComplexTensor,
indices: tuple[torch.Tensor, ...],
values: ComplexTensor,
accumulate: bool = False,
) -> ComplexTensor:
self_re, self_im = split_complex_arg(self)
values_re, values_im = split_complex_arg(values)
out_re = self_re.index_put_(indices, values_re, accumulate=accumulate)
out_im = self_im.index_put_(indices, values_im, accumulate=accumulate)
return ComplexTensor(out_re, out_im)
@register_complex(aten.tanh_backward)
def tanh_backward(out_grad: torch.Tensor, y: torch.Tensor):
return out_grad * (1.0 - y * y).conj_physical()
@register_complex(aten.diagonal_backward)
def diagonal_backward(
grad_output: torch.Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int
):
grad_input = grad_output.new_zeros(input_sizes)
return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)
def _dt_to_real(dt: torch.dtype | Any) -> torch.dtype | Any:
if not isinstance(dt, torch.dtype):
return dt
return COMPLEX_TO_REAL[dt]
def register_to_impl(op: OpType):
"""Register an op similar to `aten.to`, but may have different signatures."""
def impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor | ComplexTensor:
x, y = split_complex_tensor(self)
try:
args = tuple(_dt_to_real(a) for a in args)
kwargs = {k: _dt_to_real(v) for k, v in kwargs.items()}
except KeyError:
return op(x, *args, **kwargs)
return ComplexTensor(op(x, *args, **kwargs), op(y, *args, **kwargs))
func_name = _get_func_name(op)
impl.__name__ = func_name
impl.__qualname__ = func_name
return register_complex(op, impl)
to_impl = register_to_impl(aten.to)
_to_copy_impl = register_to_impl(aten._to_copy)

View File

@ -0,0 +1,317 @@
from collections.abc import Callable
from typing import Any, overload, TypeAlias
from typing_extensions import TypeIs
import torch
from torch import Tensor
from torch._decomp import get_decompositions
from torch._ops import OpOverload, OpOverloadPacket
from torch._refs import is_complex as _is_complex
from torch.types import Number
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from .._core import ComplexTensor
OpType: TypeAlias = OpOverloadPacket | OpOverload
TableType: TypeAlias = dict[OpType, Callable]
# Mapping from ops to implementations
COMPLEX_OPS_TABLE: TableType = {}
COMPLEX_TO_REAL = {
torch.complex128: torch.float64,
torch.complex64: torch.float32,
torch.complex32: torch.float16,
}
REAL_TO_COMPLEX = {v: k for k, v in COMPLEX_TO_REAL.items()}
# Used to promote dtypes in `promote_tensors`
PROMOTE_TYPES = {
torch.float16: torch.float32,
torch.bfloat16: torch.float32,
torch.complex32: torch.complex64,
}
def is_complex_tensor(obj: Any, /) -> TypeIs[ComplexTensor]:
r"""Returns True if the input is a ComplexTensor, else False
Args:
obj: any input
Examples:
>>> # xdoctest: +SKIP
>>> from torch.complex import ComplexTensor
>>> data = torch.zeros((3, 2), dtype=torch.complex64)
>>> ct = ComplexTensor.from_interleaved(data)
>>> is_complex_tensor(ct)
True
"""
return isinstance(obj, ComplexTensor)
@overload
def promote_tensors(
*tensors: ComplexTensor,
) -> tuple[torch.dtype, tuple[ComplexTensor, ...]]: ...
@overload
def promote_tensors(
*tensors: Tensor,
) -> tuple[torch.dtype, tuple[Tensor, ...]]: ...
def promote_tensors(
*tensors: Tensor | ComplexTensor,
) -> tuple[torch.dtype, tuple[Tensor | ComplexTensor, ...]]:
"""
Promotes all tensors to a common dtype.
Additionally promotes CPU tensors to at least `float32`.
"""
tensor = next(t for t in tensors if isinstance(t, Tensor))
out_dt = tensor.dtype
for t in tensors:
if isinstance(t, Tensor):
out_dt = torch.promote_types(out_dt, t.dtype)
prom_dt = PROMOTE_TYPES.get(out_dt, out_dt)
return out_dt, tuple(
t.to(prom_dt) if isinstance(t, Tensor) else torch.asarray(t, dtype=prom_dt)
for t in tensors
)
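
A small illustration of this contract, assuming `promote_tensors` above is in scope: the first element of the return is the dtype the result should be cast back to, while the returned tensors are upcast for the computation.

```
import torch

out_dt, (t,) = promote_tensors(torch.randn(2, dtype=torch.float16))
print(out_dt, t.dtype)   # torch.float16 torch.float32: compute in fp32, cast back to fp16
```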
def register_complex(
op: OpType,
func_impl: Callable | None = None,
):
"""Decorator to register an implementation for some ops in some dispatch tables"""
def inner(func):
if COMPLEX_OPS_TABLE.get(op, func) is not func:
raise RuntimeError(f"Attempted to register multiple functions for {op}")
COMPLEX_OPS_TABLE[op] = func
return func
if func_impl is None:
return inner
return inner(func_impl)
FORCE_TEST_LIST: list[OpType] = []
def register_force_test(op: OpType, *args, **kwargs):
"""Will attempt to test these ops even if they err on "normal" inputs"""
FORCE_TEST_LIST.append(op)
return register_complex(op, *args, **kwargs)
DECOMPOSITIONS = get_decompositions(list(torch.ops.aten)) # type: ignore[no-matching-overload]
def lookup_complex(func: OpOverload, *args, **kwargs) -> Callable | None:
"""
Lookup an impl from the table.
Try the particular overload first, then the overload packet.
If nothing is found, try the decompositions with both.
"""
return COMPLEX_OPS_TABLE.get(
func,
COMPLEX_OPS_TABLE.get(
func.overloadpacket,
DECOMPOSITIONS.get(func, DECOMPOSITIONS.get(func.overloadpacket)),
),
)
def is_complex(x: Any, /) -> bool:
"""Utility to detect if a given object is (known) to be complex."""
return (isinstance(x, Tensor) and _is_complex(x)) or isinstance(x, complex)
@overload
def split_complex_arg(
arg: Tensor | ComplexTensor,
) -> tuple[Tensor, Tensor]: ...
@overload
def split_complex_arg(
arg: complex | Number,
) -> tuple[Number, Number]: ...
def split_complex_arg(
arg: Tensor | ComplexTensor | complex | Number,
) -> tuple[Tensor, Tensor] | tuple[Number, Number]:
"""
Split a complex argument into a real/imaginary component.
If real, use zero for the imaginary part.
"""
if isinstance(arg, ComplexTensor):
return split_complex_tensor(arg)
if isinstance(arg, Tensor):
if is_complex(arg):
return arg.real, arg.imag
return arg, torch.zeros_like(arg)
# TODO (hameerabbasi): Should there be a `torch.SymComplex`?
if isinstance(arg, complex):
return arg.real, arg.imag
if isinstance(arg, float | torch.SymFloat):
return arg, 0.0
if isinstance(arg, int | torch.SymInt):
return arg, 0
if isinstance(arg, bool | torch.SymBool):
return arg, False
raise TypeError(f"Expected tensor or number, got {type(arg)}")
def split_complex_tensor(complex_tensor: ComplexTensor) -> tuple[Tensor, Tensor]:
"""Split a ComplexTensor into its real and imaginary parts."""
return complex_tensor.re, complex_tensor.im
def complex_to_real_dtype(dtype: torch.dtype) -> torch.dtype:
"""Convert a complex dtype to the dtype of its real part. Return other dtypes as-is."""
return COMPLEX_TO_REAL.get(dtype, dtype)
def _get_op_name(op: OpType) -> str:
"""Get the op name from the op."""
if isinstance(op, OpOverload):
op = op.overloadpacket
return str(op).split(".", 1)[1]
def _get_func_name(op: OpType) -> str:
"""Get the name of the implementation function from the op."""
return f"{_get_op_name(op)}_impl"
def register_error(op: OpType, exc_type: type[Exception] = NotImplementedError):
msg = f"`aten.{_get_op_name(op)}` not implemented for `{ComplexTensor.__name__}`."
def ordered_impl(*args, **kwargs):
raise exc_type(msg)
func_name = _get_func_name(op)
ordered_impl.__name__ = func_name
ordered_impl.__qualname__ = func_name
return register_force_test(op, ordered_impl)
def register_binary_nonlinear(op: OpType) -> Callable:
"""Register a "multiplication-style" op, e.g. aten.mul, aten.mm, ..."""
def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
a_r, a_i = split_complex_arg(lhs)
b_r, b_i = split_complex_arg(rhs)
out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i)
real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs)
imag = op(a_r, b_i, *args, **kwargs) + op(a_i, b_r, *args, **kwargs)
return ComplexTensor(real.to(out_dt), imag.to(out_dt))
func_name = _get_func_name(op)
impl.__name__ = func_name
impl.__qualname__ = func_name
return register_complex(op, impl)
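
A quick numeric check of the product rule this registration relies on, (a_r + i a_i)(b_r + i b_i) = (a_r b_r - a_i b_i) + i (a_r b_i + a_i b_r), against PyTorch's native complex multiply:

```
import torch

a_r, a_i, b_r, b_i = (torch.randn(4) for _ in range(4))
expected = torch.complex(a_r, a_i) * torch.complex(b_r, b_i)
real = a_r * b_r - a_i * b_i
imag = a_r * b_i + a_i * b_r
assert torch.allclose(torch.complex(real, imag), expected, atol=1e-5)
```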
def register_simple(op: OpType):
"""Register an op which can be applied independently to the real and complex parts to get the result."""
def impl(
self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs
) -> ComplexTensor:
x, y = split_complex_tensor(self)
if dtype is not None and dtype not in COMPLEX_TO_REAL:
raise RuntimeError(
"Non-complex `dtype` specified, please write custom impl."
)
if dtype in COMPLEX_TO_REAL:
assert dtype is not None
kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
u = op(x, *args, **kwargs)
v = op(y, *args, **kwargs)
u_flat, u_spec = tree_flatten(u)
v_flat, v_spec = tree_flatten(v)
assert u_spec == v_spec
out_flat = [
ComplexTensor(ui, vi) for ui, vi in zip(u_flat, v_flat, strict=False)
]
return tree_unflatten(out_flat, u_spec)
func_name = _get_func_name(op)
impl.__name__ = func_name
impl.__qualname__ = func_name
return register_complex(op, impl)
def _as_complex_tensor(arg: Tensor | Any) -> Tensor | ComplexTensor | Any:
"""Convert a Tensor with complex dtypes to a ComplexTensor. Pass along other args as-is."""
if (
not isinstance(arg, ComplexTensor)
and isinstance(arg, Tensor)
and arg.dtype in COMPLEX_TO_REAL
):
return ComplexTensor.from_interleaved(arg)
return arg
def _as_interleaved(arg: ComplexTensor | Any) -> Tensor | Any:
"""Convert a ComplexTensor to a Tensor with a complex dtype. Pass other arguments as-is."""
if isinstance(arg, ComplexTensor):
return arg.as_interleaved()
return arg
class ComplexTensorMode(TorchDispatchMode):
"""A TorchDispatchMode to replace any Tensor that has a complex dtype with a ComplexTensor for the computation."""
_compile: bool
def __init__(self, _dispatch_key=None, *, _compile: bool = False):
"""Initialize a ComplexTensorMode.
Args:
_dispatch_key: passed on to TorchDispatchMode
_compile: Compile the op before the computation
"""
super().__init__(_dispatch_key)
self._compile = _compile
def __torch_dispatch__(
self,
func: OpOverload,
types: tuple[type],
args: tuple = (),
kwargs: dict[str, Any] | None = None,
):
if kwargs is None:
kwargs = {}
# TODO (hameerabbasi): Test perf with `_compile` set to `True`
if self._compile:
func = torch.compile(func) # type: ignore[bad-assignment]
args = tree_map(_as_complex_tensor, args)
kwargs = tree_map(_as_complex_tensor, kwargs)
return tree_map(_as_interleaved, func(*args, **kwargs))
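
A usage sketch, assuming `ComplexTensorMode` defined above is in scope: inside the mode, tensors with a complex dtype are rewrapped as `ComplexTensor` for the computation and results are converted back to interleaved complex tensors.

```
import torch

x = torch.randn(3, dtype=torch.complex64)
with ComplexTensorMode():
    y = x * x + 1          # runs through the ComplexTensor implementations
print(y.dtype)             # torch.complex64, returned as a regular complex tensor
```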

View File

@ -0,0 +1,34 @@
import torch
from .._core import ComplexTensor
from .common import (
complex_to_real_dtype,
register_complex,
register_force_test,
split_complex_tensor,
)
prims = torch.ops.prims
aten = torch.ops.aten
# TODO (hameerabbasi): Not being tested
@register_force_test(prims.convert_element_type)
def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTensor:
dtype = complex_to_real_dtype(dtype)
u, v = split_complex_tensor(x)
u_out = prims.convert_element_type(u, dtype)
v_out = prims.convert_element_type(v, dtype)
return ComplexTensor(u_out, v_out)
@register_complex(prims.conj_physical)
def conj_physical_impl(self: ComplexTensor) -> ComplexTensor:
return aten._conj_physical(self)
@register_complex(prims.conj)
def conj_impl(self: ComplexTensor) -> ComplexTensor:
return aten._conj(self)

View File

@ -138,7 +138,7 @@ inline void PyErr_SetString(PyObject* type, const std::string& message) {
throw; \
} \
} \
catch (const std::exception& e) { \
catch (const std::exception&) { \
torch::translate_exception_to_python(std::current_exception()); \
return retval; \
}

View File

@ -81,7 +81,7 @@ c10::intrusive_ptr<Backend> ProcessGroup::getBackend(
ProcessGroup::BackendType backendType{ProcessGroup::BackendType::UNDEFINED};
try {
backendType = deviceTypeToBackendType_.at(deviceType);
} catch (const std::out_of_range& e) {
} catch (const std::out_of_range&) {
TORCH_CHECK(
false, "No backend type associated with device type ", deviceType);
}

View File

@ -246,7 +246,7 @@ class UvTcpServer : public UvTcpSocket {
uv_err_name(uv_res),
uv_strerror(uv_res)));
res->cacheSocketPort();
} catch (std::exception& ex) {
} catch (std::exception&) {
res->close();
throw;
}
@ -322,7 +322,7 @@ class UvTcpServer : public UvTcpSocket {
uv_err_name(uv_res),
uv_strerror(uv_res)));
res->cacheSocketPort();
} catch (std::exception& ex) {
} catch (std::exception&) {
res->close();
throw;
}

View File

@ -1,5 +1,7 @@
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>
#include <fmt/format.h>
#include <mutex>
#include <shared_mutex>
@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
res.setStatus(200);
}};
RegisterHandler frTracehandler(
"fr_trace_json",
[](const Request&, Response& res) {
auto trace = ::c10d::dump_fr_trace_json(true, true);
res.setContent(std::move(trace), "application/json");
res.setStatus(200);
});
} // namespace
void registerHandler(const std::string& name, HandlerFunc f) {

View File

@ -18,6 +18,14 @@ class TORCH_API Request {
virtual const std::string& body() const = 0;
virtual const std::multimap<std::string, std::string>& params() const = 0;
std::string getParam(const std::string& key) const {
auto it = params().find(key);
if (it != params().end()) {
return it->second;
}
return "";
}
};
// Response represents a response to the handler. This conceptually maps to an

View File

@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
TORCH_CHECK(
server_.bind_to_port(hostOrFile, 80),
fmt::format("Error binding to {}", hostOrFile));
} else if (port == 0) {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
port_ = server_.bind_to_any_port(hostOrFile);
TORCH_CHECK(
port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port));
} else {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
TORCH_CHECK(
server_.bind_to_port(hostOrFile, port),
fmt::format("Error binding to {}:{}", hostOrFile, port));
port_ = port;
}
serverThread_ = std::thread([this]() {

View File

@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
void shutdown();
int port() {
return port_;
}
private:
httplib::Server server_;
std::thread serverThread_;
int port_;
};
} // namespace c10d::control_plane

View File

@ -46,6 +46,7 @@
#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
@ -4209,7 +4210,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
}),
py::arg("host_or_file"),
py::arg("port") = -1)
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown)
.def_property_readonly(
"port", &::c10d::control_plane::WorkerServer::port);
module.def(
"_get_handler",
@ -4225,6 +4228,25 @@ such as `dist.all_reduce(tensor, async_op=True)`.
Returns the handler with the specified name.
)");
module.def(
"_register_handler",
[](const std::string& name, const py::function& handler) {
::c10d::control_plane::registerHandler(
name,
[handler](
const ::c10d::control_plane::Request& req,
::c10d::control_plane::Response& res) {
py::gil_scoped_acquire acquire;
handler(std::ref(req), std::ref(res));
});
},
py::arg("name"),
py::arg("handler"),
R"(
Registers a handler by name.
)");
module.def(
"_get_handler_names",
&::c10d::control_plane::getHandlerNames,
@ -4242,12 +4264,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
// Default constructor.
.def(py::init<>())
.def("body", &::c10d::control_plane::Request::body)
.def("params", &::c10d::control_plane::Request::params);
.def("get_param", &::c10d::control_plane::Request::getParam);
py::class_<
::c10d::control_plane::Response,
std::shared_ptr<::c10d::control_plane::Response>,
PythonResponse>(
py::class_<::c10d::control_plane::Response, PythonResponse>(
module,
"_Response",
R"(

View File

@ -353,7 +353,7 @@ static PyObject* NodeBase__update_args_kwargs(
Py_CLEAR(node->_kwargs);
node->_kwargs = map_aggregate(args[1], visit_fn);
Py_RETURN_NONE;
} catch (const PythonError& e) {
} catch (const PythonError&) {
return nullptr;
}
}
@ -397,7 +397,7 @@ static PyObject* NodeBase__replace_input_with(
PyObject* update_args[2] = {new_args.get(), new_kwargs.get()};
return NodeBase__update_args_kwargs(self, update_args, 2);
} catch (const PythonError& e) {
} catch (const PythonError&) {
return nullptr;
}
}
@ -802,7 +802,7 @@ static PyObject* py_map_aggregate(
// args[0]: aggregate, args[1]: callable fn
return map_aggregate(
args[0], [fn](PyObject* a) { return PyObject_CallOneArg(fn, a); });
} catch (const PythonError& e) {
} catch (const PythonError&) {
return nullptr; // error should already be set
}
}
@ -824,7 +824,7 @@ static PyObject* py_map_arg(
}
return Py_NewRef(a);
});
} catch (const PythonError& e) {
} catch (const PythonError&) {
return nullptr; // error should already be set
}
}

View File

@ -117,7 +117,7 @@ struct type_caster<torch::jit::IValue> {
try {
value = torch::jit::toTypeInferredIValue(src);
return true;
} catch (std::exception& e) {
} catch (std::exception&) {
return false;
}
}
@ -142,7 +142,7 @@ struct type_caster<torch::jit::Symbol> {
std::string src_str;
try {
src_str = py::cast<std::string>(src);
} catch (std::exception& e) {
} catch (std::exception&) {
return false;
}
value = torch::jit::Symbol::fromQualString(src_str);

View File

@ -285,7 +285,7 @@ struct FromImpl<torch::headeronly::HeaderOnlyArrayRef<T>> {
torch_list_push_back(new_list_handle, from(elem)));
}
return from(new_list_handle);
} catch (const std::runtime_error& e) {
} catch (const std::runtime_error&) {
if (new_list_handle != nullptr) {
// clean up memory if an error was thrown
TORCH_ERROR_CODE_CHECK(torch_delete_list(new_list_handle));
@ -553,7 +553,7 @@ struct ToImpl<std::vector<T>> {
}
TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle));
return result;
} catch (const std::runtime_error& e) {
} catch (const std::runtime_error&) {
// clean up memory if an exception is thrown, and rethrow
TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle));
throw;

View File

@ -0,0 +1,82 @@
import logging
import multiprocessing
import socket
# import for registration side effect
import torch.distributed.debug._handlers # noqa: F401
from torch._C._distributed_c10d import _WorkerServer
from torch.distributed.debug._store import get_rank, tcpstore_client
__all__ = [
"start_debug_server",
"stop_debug_server",
]
logger: logging.Logger = logging.getLogger(__name__)
_WORKER_SERVER: _WorkerServer | None = None
_DEBUG_SERVER_PROC: multiprocessing.Process | None = None
def start_debug_server(port: int = 25999, worker_port: int = 0) -> None:
"""
Start the debug server stack on all workers. The frontend debug server is
started only on rank 0, while the per-rank worker servers are started on all
ranks.
This server provides an HTTP frontend for debugging slow and deadlocked
distributed jobs across all ranks simultaneously. It collects data such as
stack traces, FlightRecorder events, and performance profiles.
WARNING: This is intended to be used only in trusted network environments.
The debug server is not designed to be secure and should not be exposed to
the public internet. See SECURITY.md for more details.
WARNING: This is an experimental feature and may change at any time.
Args:
port (int): The port to start the frontend debug server on.
worker_port (int): The port to start the worker server on. Defaults to 0, which
will cause the worker server to bind to an ephemeral port.
"""
global _WORKER_SERVER, _DEBUG_SERVER_PROC
assert _WORKER_SERVER is None, "debug server already started"
assert _DEBUG_SERVER_PROC is None, "debug server already started"
logger.info("Starting debug server on port %d", port)
store = tcpstore_client()
_WORKER_SERVER = _WorkerServer("::", worker_port)
RANK = get_rank()
store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}")
from torch.distributed.debug._frontend import main
if RANK == 0:
_DEBUG_SERVER_PROC = multiprocessing.Process(
target=main, args=(port,), daemon=True
)
_DEBUG_SERVER_PROC.start()
def stop_debug_server() -> None:
"""
Shut down the per-rank worker server and stop the frontend debug server process.
"""
global _WORKER_SERVER, _DEBUG_SERVER_PROC
assert _DEBUG_SERVER_PROC is not None
assert _WORKER_SERVER is not None
logger.info("Stopping debug server")
_DEBUG_SERVER_PROC.terminate()
_WORKER_SERVER.shutdown()
_DEBUG_SERVER_PROC.join()
_WORKER_SERVER = None
_DEBUG_SERVER_PROC = None
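For orientation, a hedged usage sketch of this module (not part of the diff), assuming the file above is `torch/distributed/debug/__init__.py` and the job is launched with `torchrun` so that `RANK`, `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT` are set:

```
# Hedged sketch: wrap a training job with the debug server stack.
import os

from torch.distributed.debug import start_debug_server, stop_debug_server

start_debug_server(port=25999)  # frontend on rank 0; a worker server on every rank

# ... training loop; browse http://<rank0-host>:25999 while the job runs ...

if int(os.environ["RANK"]) == 0:
    # As written above, stop_debug_server() asserts that the frontend process
    # exists, which is only the case on rank 0.
    stop_debug_server()
```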

View File

@@ -0,0 +1,353 @@
import json
import logging
import socket
import threading
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import parse_qs, urlparse
import requests
from jinja2 import DictLoader, Environment
from torch.distributed.debug._store import get_world_size, tcpstore_client
logger: logging.Logger = logging.getLogger(__name__)
def fetch_all(
endpoint: str, args: str = ""
) -> tuple[list[str], Iterator[requests.Response]]:
store = tcpstore_client()
keys = [f"rank{r}" for r in range(get_world_size())]
addrs = store.multi_get(keys)
addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs]
with ThreadPoolExecutor(max_workers=10) as executor:
resps = executor.map(requests.post, addrs)
return addrs, resps
def format_json(blob: str):
parsed = json.loads(blob)
return json.dumps(parsed, indent=2)
templates = {
"base.html": """
<!doctype html>
<head>
<title>{% block title %}{% endblock %} - PyTorch Distributed</title>
<link rel="shortcut icon" type="image/x-icon" href="https://pytorch.org/favicon.ico?">
<style>
body {
margin: 0;
font-family:
-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,
"Helvetica Neue",Arial,"Noto Sans",sans-serif,"Apple Color Emoji",
"Segoe UI Emoji","Segoe UI Symbol","Noto Color Emoji";
font-size: 1rem;
font-weight: 400;
line-height: 1.5;
color: #212529;
text-align: left;
background-color: #fff;
}
h1, h2, h3, h4, h5, h6, .h1, .h2, .h3, .h4, .h5, .h6 {
margin-bottom: .5rem;
font-weight: 500;
line-height: 1.2;
}
nav {
background-color: rgba(0, 0, 0, 0.17);
display: flex;
align-items: center;
padding: 16px;
justify-content: flex-start;
}
nav h1 {
display: inline-block;
margin: 0;
}
nav a {
margin: 0 8px;
}
section {
max-width: 1280px;
padding: 16px;
margin: 0 auto;
}
pre {
white-space: pre-wrap;
max-width: 100%;
}
</style>
</head>
<nav>
<h1>Torch Distributed Debug Server</h1>
<a href="/">Home</a> <!--@lint-ignore-->
<a href="/stacks">Python Stack Traces</a> <!--@lint-ignore-->
<a href="/fr_trace">FlightRecorder</a> <!--@lint-ignore-->
<a href="/fr_trace_nccl">FlightRecorder NCCL</a> <!--@lint-ignore-->
<a href="/profile">torch profiler</a> <!--@lint-ignore-->
</nav>
<section class="content">
{% block header %}{% endblock %}
{% block content %}{% endblock %}
</section>
""",
"index.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}Index{% endblock %}</h1>
{% endblock %}
{% block content %}
Hi
{% endblock %}
""",
"raw_resp.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}{{title}}{% endblock %}</h1>
{% endblock %}
{% block content %}
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<pre>{{ resp.text }}</pre>
{% endif %}
{% endfor %}
{% endblock %}
""",
"json_resp.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}{{ title }}{% endblock %}</h1>
{% endblock %}
{% block content %}
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<pre>{{ format_json(resp.text) }}</pre>
{% endif %}
{% endfor %}
{% endblock %}
""",
"profile.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}torch.profiler{% endblock %}</h1>
{% endblock %}
{% block content %}
<form action="/profile" method="get">
<label for="duration">Duration (seconds):</label>
<input type="number" id="duration" name="duration" value="{{ duration }}" min="1" max="60">
<input type="submit" value="Submit">
</form>
<script>
function stringToArrayBuffer(str) {
const encoder = new TextEncoder();
return encoder.encode(str).buffer;
}
async function openPerfetto(data) {
const ui = window.open('https://ui.perfetto.dev/#!/');
if (!ui) { alert('Popup blocked. Allow popups for this page and click again.'); return; }
// Perfetto readiness handshake: PING until we receive PONG
let perfettoReady = true;
await new Promise((resolve, reject) => {
const onMsg = (e) => {
if (e.source === ui && e.data === 'PONG') {
window.removeEventListener('message', onMsg);
clearInterval(pinger);
resolve();
}
};
window.addEventListener('message', onMsg);
const pinger = setInterval(() => { try { ui.postMessage('PING', '*'); } catch (_e) {} }, 250);
setTimeout(() => { clearInterval(pinger); window.removeEventListener('message', onMsg); reject(); }, 20000);
}).catch(() => { perfettoReady = false; alert('Perfetto UI did not respond. Try again.'); });
// Bail out instead of posting the trace to an unresponsive UI.
if (!perfettoReady) { return; }
ui.postMessage({
perfetto: {
buffer: stringToArrayBuffer(JSON.stringify(data)),
title: "torch profiler",
fileName: "trace.json",
}
}, '*');
}
</script>
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<script>
function run{{ i }}() {
var data = {{ resp.text | safe }};
openPerfetto(data);
}
</script>
<button onclick="run{{ i }}()">View {{ i }}</button>
{% endif %}
{% endfor %}
{% endblock %}
""",
}
class _IPv6HTTPServer(ThreadingHTTPServer):
address_family: socket.AddressFamily = socket.AF_INET6 # pyre-ignore
request_queue_size: int = 1024
class HTTPRequestHandler(BaseHTTPRequestHandler):
frontend: "FrontendServer"
def do_GET(self):
self.frontend._handle_request(self)
def get_path(self) -> str:
return urlparse(self.path).path
def get_query(self) -> dict[str, list[str]]:
return parse_qs(urlparse(self.path).query)
def get_query_arg(
self, name: str, default: object = None, type: type = str
) -> object:
query = self.get_query()
if name not in query:
return default
return type(query[name][0])
class FrontendServer:
def __init__(self, port: int):
# Setup templates
loader = DictLoader(templates)
self._jinja_env = Environment(loader=loader, enable_async=True)
self._jinja_env.globals.update(
zip=zip,
format_json=format_json,
enumerate=enumerate,
)
# Create routes
self._routes = {
"/": self._handle_index,
"/stacks": self._handle_stacks,
"/fr_trace": self._handle_fr_trace,
"/fr_trace_nccl": self._handle_fr_trace_nccl,
"/profile": self._handle_profiler,
}
# Create HTTP server
RequestHandlerClass = type(
"HTTPRequestHandler",
(HTTPRequestHandler,),
{"frontend": self},
)
server_address = ("", port)
self._server = _IPv6HTTPServer(server_address, RequestHandlerClass)
self._thread = threading.Thread(
target=self._serve,
args=(),
daemon=True,
)
self._thread.start()
def _serve(self) -> None:
try:
self._server.serve_forever()
except Exception:
logger.exception("got exception in checkpoint server")
def join(self) -> None:
self._thread.join()
def _handle_request(self, req: HTTPRequestHandler) -> None:
path = req.get_path()
if path not in self._routes:
req.send_error(404, f"Handler not found: {path}")
return
handler = self._routes[path]
try:
resp = handler(req)
except Exception as e:
logger.exception(
"Exception in checkpoint server when handling %s",
path,
)
req.send_error(500, str(e))
return
req.send_response(200)
req.send_header("Content-type", "text/html")
req.end_headers()
req.wfile.write(resp)
def _render_template(self, template: str, **kwargs: object) -> bytes:
return self._jinja_env.get_template(template).render(**kwargs).encode()
def _handle_index(self, req: HTTPRequestHandler) -> bytes:
return self._render_template("index.html")
def _handle_stacks(self, req: HTTPRequestHandler) -> bytes:
addrs, resps = fetch_all("dump_traceback")
return self._render_template(
"raw_resp.html", title="Stacks", addrs=addrs, resps=resps
)
def _handle_fr_trace(self, req: HTTPRequestHandler) -> bytes:
addrs, resps = fetch_all("fr_trace_json")
return self._render_template(
"json_resp.html",
title="FlightRecorder",
addrs=addrs,
resps=resps,
)
def _handle_fr_trace_nccl(self, req: HTTPRequestHandler) -> bytes:
addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")
return self._render_template(
"json_resp.html",
title="FlightRecorder NCCL",
addrs=addrs,
resps=resps,
)
def _handle_profiler(self, req: HTTPRequestHandler) -> bytes:
duration = req.get_query_arg("duration", default=1.0, type=float)
addrs, resps = fetch_all("torch_profile", f"duration={duration}")
return self._render_template("profile.html", addrs=addrs, resps=resps)
def main(port: int) -> None:
server = FrontendServer(port=port)
logger.info("Frontend server started on port %d", server._server.server_port)
server.join()
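The frontend resolves each worker's address from the TCPStore and POSTs to `<addr>/handler/<endpoint>`, as `fetch_all` above shows. A hedged sketch of querying a single worker directly with the same store layout (the `dump_traceback` endpoint is the one `_handle_stacks` uses):

```
# Hedged sketch: fetch the Python stacks from rank 0's worker server directly.
import os

import requests
import torch.distributed as dist

store = dist.PrefixStore(
    "debug_server",
    dist.TCPStore(
        host_name=os.environ["MASTER_ADDR"],
        port=int(os.environ["MASTER_PORT"]),
        is_master=False,
    ),
)
addr = store.get("rank0").decode()
resp = requests.post(f"{addr}/handler/dump_traceback")
print(resp.status_code)
print(resp.text[:500])
```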

View File

@@ -0,0 +1,22 @@
import tempfile
import time
from torch._C._distributed_c10d import _register_handler, _Request, _Response
from torch.profiler import _ExperimentalConfig, profile
def _torch_profile(req: _Request, resp: _Response) -> None:
experimental_config = _ExperimentalConfig(
profile_all_threads=True,
)
duration = float(req.get_param("duration"))
with profile(record_shapes=True, experimental_config=experimental_config) as prof:
time.sleep(duration)
with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f:
prof.export_chrome_trace(f.name)
resp.set_content(open(f.name, "rb").read(), "application/json")
resp.set_status(200)
_register_handler("torch_profile", _torch_profile)

View File

@@ -0,0 +1,24 @@
import os
import torch.distributed as dist
def get_rank() -> int:
return int(os.environ["RANK"])
def get_world_size() -> int:
return int(os.environ["WORLD_SIZE"])
def tcpstore_client() -> dist.Store:
MASTER_ADDR = os.environ["MASTER_ADDR"]
MASTER_PORT = int(os.environ["MASTER_PORT"])
store = dist.TCPStore(
host_name=MASTER_ADDR,
port=MASTER_PORT,
is_master=False,
)
store = dist.PrefixStore("debug_server", store)
return store

View File

@@ -9,7 +9,7 @@ from collections.abc import Callable, Iterable
from enum import Enum
from functools import partial
from typing import Any, Optional
- from typing_extensions import Self
+ from typing_extensions import deprecated, Self
from warnings import warn
import torch
@ -408,6 +408,11 @@ class _KinetoProfile:
)
return MemoryProfile(self.profiler.kineto_results)
+ @deprecated(
+ "`export_memory_timeline` is deprecated and will be removed in a future version. "
+ "Please use `torch.cuda.memory._record_memory_history` and `torch.cuda.memory._export_memory_snapshot` instead.",
+ category=FutureWarning,
+ )
def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
"""Export memory event information from the profiler collected
tree for a given device, and export a timeline plot. There are 3
@ -429,6 +434,11 @@ class _KinetoProfile:
``torch.profiler._memory_profiler.Category``.
Output: Memory timeline written as gzipped JSON, JSON, or HTML.
+ .. deprecated::
+ ``export_memory_timeline`` is deprecated and will be removed in a future version.
+ Please use ``torch.cuda.memory._record_memory_history`` and
+ ``torch.cuda.memory._export_memory_snapshot`` instead.
"""
# Default to device 0, if unset. Fallback on cpu.
if device is None:

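A hedged sketch of the migration the deprecation message suggests, using the memory-snapshot APIs it names (argument values here are illustrative):

```
# Record allocator history, run the workload, then export a snapshot instead of
# calling the deprecated export_memory_timeline().
import torch

torch.cuda.memory._record_memory_history(max_entries=100000)

# ... run the code whose memory usage should be inspected ...

torch.cuda.memory._export_memory_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```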
View File

@@ -386,7 +386,7 @@ class DTensorTestBase(MultiProcessTestCase):
@property
def backend(self) -> str:
- backend = dist.get_default_backend_for_device(DEVICE_TYPE)
+ backend = dist.get_default_backend_for_device(self.device_type)
return backend
def init_manual_seed_for_rank(self) -> None:

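For context, `get_default_backend_for_device` maps a device-type string to that device's default process-group backend; a hedged illustration (the exact mapping depends on how PyTorch was built):

```
# Hedged illustration of the helper used in the backend property above.
import torch.distributed as dist

print(dist.get_default_backend_for_device("cpu"))   # typically "gloo"
print(dist.get_default_backend_for_device("cuda"))  # typically "nccl"
```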
View File

@@ -249,6 +249,8 @@ def object_annotation(obj):
if len(filename) > FRAME_FILENAME_LIMIT:
filename = "..." + filename[-(FRAME_FILENAME_LIMIT - 3):]
return f"frame\n{filename}:{obj.f_lineno}"
+ elif is_cuda_tensor(obj):
+ return f"object\n{type(obj).__module__}.{type(obj).__name__} ({obj.shape})"
else:
return f"object\n{type(obj).__module__}.{type(obj).__name__}"