mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-19 10:04:58 +08:00

Compare commits: optimizer_… ... ciflow/tru… (14 commits)
| Author | SHA1 | Date |
|---|---|---|
| | fbb7943140 | |
| | ebb2001a48 | |
| | ae85307512 | |
| | 7921c0eb0e | |
| | dda2cb3769 | |
| | 4c5042b368 | |
| | e3c5b78999 | |
| | 14f370f551 | |
| | aa22d41f9b | |
| | d1f6dd6105 | |
| | 5333e51195 | |
| | 0e13964b74 | |
| | 20cae808f7 | |
| | 57927a620d | |
@@ -402,3 +402,6 @@ scikit-build==0.18.1
 pyre-extensions==0.0.32
 tabulate==0.9.0
 #Description: These package are needed to build FBGEMM and torchrec on PyTorch CI
+
+Jinja2==3.1.6
+#Description: required for torch.distributed.debug
@@ -1768,7 +1768,7 @@ test_operator_microbenchmark() {
 
   cd "${TEST_DIR}"/benchmarks/operator_benchmark
 
-  for OP_BENCHMARK_TESTS in optimizer; do
+  for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv; do
     $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
       --output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \
       --benchmark-name "PyTorch operator microbenchmark" --use-compile
.github/workflows/docker-builds.yml (vendored, 3 changes)
@@ -75,7 +75,8 @@ jobs:
           pytorch-linux-jammy-py3-clang12-onnx,
           pytorch-linux-jammy-linter,
           pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter,
-          pytorch-linux-jammy-py3-clang12-executorch,
+          # TODO: Re-enable me when docker pin update happens
+          # pytorch-linux-jammy-py3-clang12-executorch,
           pytorch-linux-jammy-py3.12-triton-cpu,
           pytorch-linux-noble-riscv64-py3.12-gcc14
         ]
.github/workflows/docker-cache-rocm.yml (vendored, 7 changes)
@@ -50,9 +50,10 @@ jobs:
       matrix:
         runner: [linux.rocm.gfx942.docker-cache]
         docker-image: [
-          "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
-          "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
-          "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
+          "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}"
+          #"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
+          #"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
+          #"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
         ]
     runs-on: "${{ matrix.runner }}"
     steps:
.github/workflows/trunk.yml (vendored, 1 change)
@@ -283,6 +283,7 @@ jobs:
     name: linux-jammy-py3-clang12-executorch
     uses: ./.github/workflows/_linux-build.yml
     needs: get-label-type
+    if: false # Has been broken for a while
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       build-environment: linux-jammy-py3-clang12-executorch
@@ -813,8 +813,43 @@ void smooth_l1_kernel(TensorIteratorBase& iter, double beta) {
 }

 void huber_kernel(TensorIterator& iter, double delta) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-      kBFloat16, kHalf, iter.dtype(), "huber_cpu", [&]() {
+  // Special-case kHalf: compute in float for numerical stability
+  if (iter.dtype() == kHalf) {
+    const float delta_val(static_cast<float>(delta));
+    const Vectorized<float> delta_vec(static_cast<float>(delta));
+    const Vectorized<float> point_five_vec(static_cast<float>(0.5));
+    cpu_kernel_vec(
+        iter,
+        // scalar lambda: convert half -> float, compute in float, cast back to half
+        [&delta_val] (at::Half a, at::Half b) -> at::Half {
+          float af = static_cast<float>(a);
+          float bf = static_cast<float>(b);
+          float z = std::abs(af - bf);
+          float out = z < delta_val
+              ? 0.5f * z * z
+              : delta_val * (z - 0.5f * delta_val);
+          return static_cast<at::Half>(out);
+        },
+        [&delta_vec, &point_five_vec] (Vectorized<Half> a, Vectorized<Half> b) {
+          auto [a0, a1] = convert_half_float(a);
+          auto [b0, b1] = convert_half_float(b);
+          auto z = (a0 - b0).abs();
+          a0 = Vectorized<float>::blendv(
+              point_five_vec * z * z,
+              delta_vec * (z - point_five_vec * delta_vec),
+              z >= delta_vec);
+          z = (a1 - b1).abs();
+          a1 = Vectorized<float>::blendv(
+              point_five_vec * z * z,
+              delta_vec * (z - point_five_vec * delta_vec),
+              z >= delta_vec);
+          return convert_float_half(a0, a1);
+        }
+    );
+    return;
+  }
+  else {
+    AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "huber_cpu", [&]() {
       using Vec = Vectorized<scalar_t>;
       const scalar_t delta_val(delta);
       const Vec delta_val_vec(delta_val);
@@ -835,6 +870,7 @@ void huber_kernel(TensorIterator& iter, double delta) {
             z >= delta_val_vec);
         });
       });
+  }
 }

 void sigmoid_backward_kernel(TensorIteratorBase& iter) {
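The piecewise function this kernel vectorizes is the standard Huber loss. As an editorial aside (not part of the diff), a minimal Python sketch of the same elementwise formula, checked against `torch.nn.functional.huber_loss` in half precision, might look like:

```python
import torch
import torch.nn.functional as F


def huber_reference(a: torch.Tensor, b: torch.Tensor, delta: float) -> torch.Tensor:
    # Same piecewise formula the kernel applies elementwise:
    # 0.5 * z^2 when z < delta, otherwise delta * (z - 0.5 * delta), with z = |a - b|.
    z = (a.float() - b.float()).abs()  # compute in float, as the kHalf path above does
    out = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
    return out.to(a.dtype)


a = torch.randn(1024, dtype=torch.half)
b = torch.randn(1024, dtype=torch.half)
ref = huber_reference(a, b, delta=1.0)
out = F.huber_loss(a, b, reduction="none", delta=1.0)
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)
```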
@@ -346,8 +346,9 @@ void dispatch_bf16_grouped_kernel_on_tile_size(
   bool small = (M <= 128 || N <= 128);
   cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
   const bool sm10x = properties != nullptr && properties->major == 10;
+  const bool sm11x = properties != nullptr && properties->major == 11;

-  if (sm10x) {
+  if (sm10x || sm11x) {
     if (small){
       bf16bf16_grouped_gemm_impl_sm90_sm100<
           cutlass::arch::Sm100,
@@ -958,8 +958,9 @@ void dispatch_fp8_rowwise_kernel_on_sm(
   const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9;
   const bool sm9x = properties != nullptr && properties->major == 9;
   const bool sm10x = properties != nullptr && properties->major == 10;
+  const bool sm11x = properties != nullptr && properties->major == 11;
   const bool sm12x = properties != nullptr && properties->major == 12;
-  if (!(sm89 || sm9x || sm10x || sm12x)) {
+  if (!(sm89 || sm9x || sm10x || sm11x || sm12x)) {
     TORCH_CHECK(
         false, "Rowwise scaling is not currently supported on your device");
   }
@@ -968,7 +969,7 @@ void dispatch_fp8_rowwise_kernel_on_sm(
     dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<
         /*ArchTag=*/cutlass::arch::Sm90,
         Types...>(XQ, WQ, x_scale, w_scale, bias, out);
-  } else if (sm10x) {
+  } else if (sm10x || sm11x) {
     dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<
         /*ArchTag=*/cutlass::arch::Sm100,
         Types...>(XQ, WQ, x_scale, w_scale, bias, out);
@@ -1,64 +0,0 @@ (deleted file)
import operator_benchmark as op_bench

import torch
import torch.optim as optim


"""Microbenchmarks for optimizer operators."""


optimizer_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["adamw", optim.AdamW],
        ["adam", optim.Adam],
        ["sgd", optim.SGD],
        ["rmsprop", optim.RMSprop],
        ["adagrad", optim.Adagrad],
    ],
)

optimizer_configs_long = op_bench.cross_product_configs(
    num_params=[1, 10, 100],
    param_size=[100000, 1000000, 10000000],
    device=["cuda"],
    tags=["long"],
)


class OptimizerBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, op_func, device, shape=None, num_params=None, param_size=None):
        if shape is not None:
            num_params = num_params if num_params is not None else 1
            self.params = [
                torch.randn(shape, device=device, requires_grad=True)
                for _ in range(num_params)
            ]
            for param in self.params:
                param.grad = torch.randn(shape, device=device)
        else:
            self.params = [
                torch.randn(param_size, device=device, requires_grad=True)
                for _ in range(num_params)
            ]
            for param in self.params:
                param.grad = torch.randn_like(param)

        kwargs = {"momentum": 0.9} if op_func == optim.SGD else {}
        self.optimizer = op_func(self.params, lr=0.001, **kwargs)
        self.inputs = {"dummy": self.params[0]}  # Added to run memory benchmarking

    def forward(self, dummy):
        self.optimizer.step()
        for param in self.params:
            param.grad = torch.randn_like(param)
        return self.params[0]


op_bench.generate_pt_tests_from_op_list(
    optimizer_list, optimizer_configs_long, OptimizerBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
@@ -113,6 +113,12 @@ if(INTERN_BUILD_ATEN_OPS)
         list(APPEND _file_compile_flags "-gencode;arch=compute_103a,code=sm_103a")
       endif()
     endif()
+    # We will need to gate against CUDA version, because sm_110a is available on CUDA 13.0+
+    if("${_arch}" STREQUAL "110a" AND CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
+      if(_existing_arch_flags MATCHES ".*compute_110.*")
+        list(APPEND _file_compile_flags "-gencode;arch=compute_110a,code=sm_110a")
+      endif()
+    endif()
     if("${_arch}" STREQUAL "120a")
       if(_existing_arch_flags MATCHES ".*compute_120.*")
         list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
@@ -132,13 +138,13 @@ if(INTERN_BUILD_ATEN_OPS)

   _BUILD_FOR_ADDITIONAL_ARCHS(
     "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
-    "89;90a;100a;103a;120a;121a")
+    "89;90a;100a;103a;110a;120a;121a")
   _BUILD_FOR_ADDITIONAL_ARCHS(
     "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
     "90a")
   _BUILD_FOR_ADDITIONAL_ARCHS(
     "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu"
-    "90a;100a;103a")
+    "90a;100a;103a;110a")

 endif()
@@ -987,6 +987,24 @@ In addition, `TORCH_DISTRIBUTED_DEBUG=DETAIL` can be used in conjunction with `T
collective desynchronization checks will work for all applications that use `c10d` collective calls backed by process groups created with the
{func}`torch.distributed.init_process_group` and {func}`torch.distributed.new_group` APIs.

### torch.distributed.debug HTTP Server

The `torch.distributed.debug` module provides an HTTP server that can be used to debug distributed applications. The server can
be started by calling {func}`torch.distributed.debug.start_debug_server`. This
allows users to collect data across all workers at runtime.

```{eval-rst}
.. automodule:: torch.distributed.debug
    :members:
    :undoc-members:
    :show-inheritance:
    :special-members: __init__
    :member-order: bysource
```
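As an editorial aside, a minimal sketch of how the server described above is driven (mirroring the `start_debug_server`/`stop_debug_server` calls exercised in `test/distributed/test_debug.py` later in this compare; the port number and single-rank setup are illustrative):

```python
import os

import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server

# Single-process setup: the debug server reads the usual rendezvous env vars.
store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(store.port)
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

start_debug_server(port=25999)
# While running, http://localhost:25999/ links to profiler output,
# stack traces (/stacks), and flight-recorder traces (/fr_trace).
stop_debug_server()
```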
## Logging

In addition to explicit debugging support via {func}`torch.distributed.monitored_barrier` and `TORCH_DISTRIBUTED_DEBUG`, the underlying C++ library of `torch.distributed` also outputs log
test/complex_tensor/test_complex_tensor.py (new file, 238 lines)

@@ -0,0 +1,238 @@
# Owner(s): ["module: complex"]
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.distributed as dist


# Support both when imported from elsewhere or directly as a file
try:
    from .utils import (
        COMPLEX_DTYPES,
        Descriptor,
        force_test_op_db,
        get_overload_packet_from_name,
        implemented_op_db,
        TestCase,
        Variant,
    )
except ImportError:
    from utils import (
        COMPLEX_DTYPES,
        Descriptor,
        force_test_op_db,
        get_overload_packet_from_name,
        implemented_op_db,
        TestCase,
        Variant,
    )

from torch._subclasses.complex_tensor._ops.common import ComplexTensorMode
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_utils import (
    run_tests,
    TestGradients,
    unMarkDynamoStrictTest,
)


if TYPE_CHECKING:
    from torch.testing._internal.opinfo.core import OpInfo

aten = torch.ops.aten

SKIPS = {
    Descriptor(op=aten.empty_like, variant=None): "Non-deterministic output",
    Descriptor(op=aten.randn_like, variant=None): "Non-deterministic output",
    Descriptor(op=aten.angle, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.asinh, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.atanh, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(
        op=aten.reciprocal, variant=Variant.GradCheck
    ): "Numerical inconsistency",
    Descriptor(op=aten.rsqrt, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.select, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.asin, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.log, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.sgn, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.cumprod, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.slice, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.sqrt, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.tan, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(
        op=aten.true_divide, variant=Variant.GradCheck
    ): "Numerical inconsistency",
    Descriptor(op=aten.prod, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.div, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.expm1, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.var, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.bmm, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.diagonal, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.sinh, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.abs, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.sin, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.atan, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.acos, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.acosh, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.cos, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.cosh, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.addmm, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.pow, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.log1p, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.tanh, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.mm, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.dot, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.mul, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.exp, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.to, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(
        op=aten.any, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.all, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.allclose, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.conj_physical, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten._conj_physical, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.cumprod, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.index_add, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.diagonal_scatter, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.flip, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.masked_fill, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.masked_scatter, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.rsub, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.ne, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.squeeze, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.index_select, variant=Variant.Distributed
    ): "Sharding propagation failed",
    Descriptor(op=aten.real, variant=Variant.Distributed): "No scalar support",
    Descriptor(op=aten.imag, variant=Variant.Distributed): "No scalar support",
    Descriptor(op=aten.isfinite, variant=Variant.Distributed): "No scalar support",
    Descriptor(op=aten.transpose, variant=Variant.Distributed): "No scalar support",
    Descriptor(op=aten.view_as_real, variant=Variant.Distributed): "No scalar support",
}

EXTRA_KWARGS = {
    Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Op): {
        "rtol": 2e-5,
        "atol": 5e-5,
    },
    Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Op): {
        "rtol": 1e-4,
        "atol": 1e-5,
    },
    Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Op): {
        "rtol": 2e-2,
        "atol": 2e-6,
    },
    Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Distributed): {
        "rtol": 2e-5,
        "atol": 5e-5,
    },
    Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Distributed): {
        "rtol": 1e-4,
        "atol": 1e-5,
    },
    Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Distributed): {
        "rtol": 2e-2,
        "atol": 2e-6,
    },
    Descriptor(op=aten.tan, dtype=torch.complex64, variant=Variant.Distributed): {
        "rtol": 2e-6,
        "atol": 1e-2,
    },
}


class TestComplexTensor(TestCase):
    _default_dtype_check_enabled = True

    @ops(
        implemented_op_db,
        dtypes=OpDTypes.supported,
        allowed_dtypes=list(COMPLEX_DTYPES),
    )
    def test_consistency(self, device, dtype, op: OpInfo):
        self.check_consistency(device, dtype, op, Variant.Op)

    @ops(force_test_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
    def test_maybe_error(self, device, dtype, op: OpInfo):
        self.check_consistency(device, dtype, op, Variant.Op)


@unMarkDynamoStrictTest
class TestComplexBwdGradients(TestGradients):
    _default_dtype_check_enabled = True

    @ops(
        implemented_op_db,
        dtypes=OpDTypes.supported_backward,
        allowed_dtypes=[torch.complex128],
    )
    def test_fn_grad(self, device: str, dtype: torch.dtype, op: OpInfo) -> None:
        test_info = Descriptor(
            op=get_overload_packet_from_name(op.name),
            device_type=torch.device(device).type,
            dtype=dtype,
            variant=Variant.GradCheck,
        )
        for xfail_info, reason in SKIPS.items():
            if xfail_info.matches(test_info):
                self.skipTest(reason)

        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
            self.skipTest(f"Skipped! {dtype=} is not in supported backward dtypes!")

        with ComplexTensorMode():
            op.gradcheck_fast_mode = False
            self._grad_test_helper(device, dtype, op, op.get_op())


instantiate_device_type_tests(TestComplexTensor, globals())
instantiate_device_type_tests(TestComplexBwdGradients, globals())


if dist.is_available():
    from torch.testing._internal.common_distributed import MultiProcessTestCase

    @unMarkDynamoStrictTest
    class TestComplexDistributed(TestCase, MultiProcessTestCase):
        @ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
        def test_distributed(self, device, dtype, op: OpInfo):
            self.check_consistency(device, dtype, op, Variant.Distributed)

    instantiate_device_type_tests(TestComplexDistributed, globals())

if __name__ == "__main__":
    run_tests()
test/complex_tensor/utils.py (new file, 214 lines)

@@ -0,0 +1,214 @@
from __future__ import annotations

from dataclasses import dataclass, field, fields
from enum import auto, Enum
from typing import Any, TYPE_CHECKING

import torch
import torch.distributed as dist
from torch._subclasses.complex_tensor._ops.common import (
    _as_complex_tensor,
    _as_interleaved,
    _get_op_name,
    COMPLEX_OPS_TABLE,
    COMPLEX_TO_REAL,
    FORCE_TEST_LIST,
    OpOverloadPacket,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
from torch.utils._pytree import tree_flatten


if TYPE_CHECKING:
    from collections.abc import Callable

    from torch.distributed.tensor import DTensor
    from torch.testing._internal.opinfo.core import OpInfo

COMPLEX_DTYPES = set(COMPLEX_TO_REAL)


class Variant(Enum):
    Op = auto()
    GradCheck = auto()
    Distributed = auto()


def _as_local(arg: DTensor | Any) -> torch.Tensor | Any:
    if not (dist.is_available() and isinstance(arg, dist.tensor.DTensor)):
        return arg

    return arg.full_tensor()


def _as_complex_dtensor(arg: torch.Tensor | Any) -> torch.Tensor | Any:
    if not isinstance(arg, torch.Tensor):
        return arg

    return dist.tensor.DTensor.from_local(_as_complex_tensor(arg))


TRANSFORM_FUNCS = {
    Variant.Op: _as_complex_tensor,
    Variant.Distributed: _as_complex_dtensor,
}


@dataclass(frozen=True, kw_only=True)
class Descriptor:
    op: OpOverloadPacket
    variant: Variant | None
    device_type: str | None = field(default=None)
    dtype: torch.dtype | None = field(default=None)

    def matches(self, other: Descriptor) -> bool:
        fields1 = fields(self)
        fields2 = fields(other)
        if fields1 != fields2:
            return False

        for f in fields1:
            f1 = getattr(self, f.name)
            f2 = getattr(other, f.name)
            if f1 is not None and f2 is not None and f1 != f2:
                return False

        return True


class TestCase(PytorchTestCase):
    def assertSameResult(
        self,
        expected: Callable[[], Any],
        actual: Callable[[], Any],
        *args,
        **kwargs,
    ) -> None:
        try:
            result_e = expected()
            exception_e = None
        except Exception as e:  # noqa: BLE001
            result_e = None
            exception_e = e

        try:
            result_a = actual()
            exception_a = None
        except Exception as e:  # noqa: BLE001
            result_a = None
            exception_a = e

        if (exception_e is None) != (exception_a is None):
            if exception_a is not None and exception_e is None:
                raise exception_a
        self.assertIs(
            type(exception_e),
            type(exception_a),
            f"\n{exception_e=}\n{exception_a=}",
        )

        if exception_e is None:
            flattened_e, spec_e = tree_flatten(result_e)
            flattened_a, spec_a = tree_flatten(result_a)

            self.assertEqual(
                spec_e,
                spec_a,
                "Both functions must return a result with the same tree structure.",
            )
            for value_e, value_a in zip(flattened_e, flattened_a, strict=True):
                value_e = _as_interleaved(_as_local(value_e))
                value_a = _as_interleaved(_as_local(value_a))

                self.assertEqual(value_e, value_a, *args, **kwargs)

    def check_consistency(
        self, device: str, dtype, op: OpInfo, variant: Variant
    ) -> None:
        try:
            from .test_complex_tensor import EXTRA_KWARGS, SKIPS
        except ImportError:
            from test_complex_tensor import EXTRA_KWARGS, SKIPS
        test_info = Descriptor(
            op=get_overload_packet_from_name(op.name),
            device_type=torch.device(device).type,
            dtype=dtype,
            variant=variant,
        )
        for xfail_info, reason in SKIPS.items():
            if xfail_info.matches(test_info):
                self.skipTest(reason)

        kwargs = {}
        for extra_info, extra_kw in EXTRA_KWARGS.items():
            if extra_info.matches(test_info):
                kwargs = extra_kw
                break
        sample_inputs = op.sample_inputs(device, dtype)
        transform_fn = TRANSFORM_FUNCS[variant]

        for sample_input in sample_inputs:

            def expected(sample_input=sample_input):
                return op(sample_input.input, *sample_input.args, **sample_input.kwargs)

            subclass_sample = sample_input.transform(transform_fn)

            def actual(subclass_sample=subclass_sample):
                return op(
                    subclass_sample.input,
                    *subclass_sample.args,
                    **subclass_sample.kwargs,
                )

            self.assertSameResult(expected, actual, **kwargs)


aten = torch.ops.aten

complex_op_db = tuple(
    filter(lambda op: any(op.supports_dtype(ct, "cpu") for ct in COMPLEX_DTYPES), op_db)
)


def get_overload_packet_from_name(name: str) -> OpOverloadPacket:
    for domain_name in torch.ops:
        op_namespace = getattr(torch.ops, domain_name)
        op: OpOverloadPacket | None = getattr(op_namespace, name, None)
        if op is not None:
            return op

    raise RuntimeError(f"No op with {name=} found.")


force_test_names = set(map(_get_op_name, FORCE_TEST_LIST))
implemented_op_names = (
    set(map(_get_op_name, COMPLEX_OPS_TABLE.keys())) - force_test_names
)
implemented_op_db = tuple(
    filter(lambda op: op.name in implemented_op_names, complex_op_db)
)
force_test_op_db = tuple(filter(lambda op: op.name in force_test_names, op_db))

tested_op_names = {op.name for op in implemented_op_db} | {
    op.name for op in force_test_op_db
}
non_tested_ops = {
    op for op in COMPLEX_OPS_TABLE if _get_op_name(op) not in tested_op_names
}


# TODO (hameerabbasi): There are a number of ops that don't have any associated
# OpInfos. We still need to write tests for those ops.
if len(non_tested_ops) != 0:
    import textwrap
    import warnings

    list_missing_ops = "\n".join(sorted([str(op) for op in non_tested_ops]))
    warnings.warn(
        "Not all implemented ops are tested. List of ops missing tests:"
        f"\n{textwrap.indent(list_missing_ops, ' ')}",
        UserWarning,
        stacklevel=2,
    )
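As an editorial aside, the wildcard semantics of `Descriptor.matches` above can be illustrated with a short sketch (it assumes the `Descriptor`/`Variant` definitions from this file; fields left at `None` match anything, which is how a generic skip entry matches a fully-specified test descriptor):

```python
import torch

aten = torch.ops.aten

# A skip entry that leaves device_type/dtype unset (None) ...
skip_entry = Descriptor(op=aten.tan, variant=Variant.GradCheck)

# ... matches a fully specified descriptor for the same op and variant.
test_info = Descriptor(
    op=aten.tan,
    variant=Variant.GradCheck,
    device_type="cpu",
    dtype=torch.complex128,
)
assert skip_entry.matches(test_info)

# A different op never matches, regardless of the other fields.
assert not Descriptor(op=aten.cos, variant=None).matches(test_info)
```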
@@ -6,7 +6,7 @@ import torch.distributed._functional_collectives as funcol
 import torch.nn as nn
 from torch.distributed.tensor import DeviceMesh, DTensor, Shard
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.testing._internal.common_distributed import requires_nccl
+from torch.testing._internal.common_distributed import requires_accelerator_dist_backend
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
 from torch.testing._internal.distributed.fake_pg import FakeStore
@@ -14,6 +14,9 @@ from torch.testing._internal.distributed.fake_pg import FakeStore

 c10d_functional = torch.ops.c10d_functional
 c10d_ops = torch.ops.c10d
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)


 class TestCommMode(TestCase):
@@ -28,7 +31,7 @@ class TestCommMode(TestCase):
         dist.init_process_group(
             backend="fake", rank=1, world_size=self.world_size, store=store
         )
-        self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device_type = device_type
         self.world_pg = dist.distributed_c10d._get_default_group()

     def checksAssert(self, comm_mode, key, expected_value, expected_total_value):
@@ -111,12 +114,12 @@ class TestCommMode(TestCase):
         self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
         self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 0)

-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     def test_comm_mode_with_c10d(self):
-        if not torch.cuda.is_available():
+        if not torch.accelerator.is_available():
             return

-        inp = torch.rand(2, 8, 16).cuda()
+        inp = torch.rand(2, 8, 16).to(device_type)
         all_gather_out = inp.new_empty(self.world_size * 2, 8, 16)

         comm_mode = CommDebugMode()
@@ -658,11 +658,11 @@ class DTensorMeshTest(DTensorTestBase):

     @with_comms
     def test_dtensor_device_mesh_device_conversion(self):
-        # construct a cuda device mesh
+        # construct a gpu device mesh
         mesh = self.build_device_mesh()

-        # construct from a cpu local tensor with cuda device mesh
-        # should automatically convert the dist tensor to cuda
+        # construct from a cpu local tensor with gpu device mesh
+        # should automatically convert the dist tensor to gpu
         placements = [Shard(0)]
         local_tensor = torch.randn(3, 3)
         dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
@@ -711,7 +711,7 @@ class DTensorMeshTest(DTensorTestBase):
     @with_comms
     def test_dtensor_2d_mesh(self):
         mesh_tensor = torch.arange(self.world_size).reshape(2, 4)
-        # construct a cuda device mesh
+        # construct a gpu device mesh
         mesh = DeviceMesh(self.device_type, mesh_tensor)

         # construct a dist tensor on 2d device mesh and test if works
@@ -733,7 +733,7 @@ class DTensorMeshTest(DTensorTestBase):

     @with_comms
     def test_device_mesh_nd(self):
-        # construct a cuda device mesh
+        # construct a gpu device mesh
        mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        # construct a dist tensor on 3d device mesh and test if works
@@ -1064,8 +1064,8 @@ class TestDTensorPlacementTypes(DTensorTestBase):
         # Keep everything deterministic.
         torch.manual_seed(0)
         tensor = torch.rand(size)
-        if self.device_type == "cuda":
-            return tensor.cuda()
+        if self.device_type != "cpu":
+            return tensor.to(self.device_type)
         else:
             return tensor
@@ -39,6 +39,7 @@ from torch.distributed.tensor.parallel import (
     RowwiseParallel,
 )
 from torch.distributed.tensor.placement_types import _StridedShard
+from torch.testing._internal.common_device_type import skipXPUIf
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import get_devtype
 from torch.testing._internal.common_utils import (
@@ -47,8 +48,6 @@ from torch.testing._internal.common_utils import (
     run_tests,
     skipIfHpu,
     skipIfTorchDynamo,
-    TEST_CUDA,
-    TEST_HPU,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -95,6 +94,10 @@ aot_eager_graph = aot_autograd(
     partition_fn=min_cut_rematerialization_partition,
 )

+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+

 def _apply_sharding(mod: nn.Module, shard_dim: int, device_mesh: DeviceMesh):
     """
@@ -141,7 +144,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):

     @property
     def device_type(self) -> str:
-        return "cuda" if TEST_CUDA else "hpu" if TEST_HPU else "cpu"
+        return device_type

     @property
     def world_size(self) -> int:
@@ -160,9 +163,9 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
         res = fn(x)
         res.to_local().sum().backward()

-    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "accelerator not available")
     def test_dtensor_basic_export(self):
-        mesh = DeviceMesh("cuda", torch.arange(self.world_size))
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

         param = torch.randn(4, 4)
         param_x = DTensor.from_local(param, mesh, [Shard(0)], run_check=False)
@@ -188,10 +191,10 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
         )
         self.assertExpectedInline(
             str(ep.graph_module.code).strip(),
-            """\
+            f"""\
 def forward(self, b_buffer, x):
     _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(x, dtype = torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
-    to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda')); x = None
+    to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}')); x = None
     view_as = torch.ops.aten.view_as.default(to, to); to = None
     dtensor___init__0 = self.dtensor___init__0
     dtensor_const_func_spec0 = self.dtensor_const_func_spec0
@@ -206,10 +209,10 @@ def forward(self, b_buffer, x):
         # add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add
         self.assertExpectedInline(
             str(ep.run_decompositions({}).graph_module.code).strip(),
-            """\
+            f"""\
 def forward(self, b_parametrizations_buffer_original0, x):
     _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None
-    _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None
+    _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}', index=0)); x = None
     view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None
     add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
     view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None
@@ -377,6 +380,7 @@ def forward(self, b_parametrizations_buffer_original0, x):
         self.assertEqual(res, ref)

     @skipIfHpu
+    @skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1981")
     def test_dtensor_dynamic_loss_parallel_log_softmax(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

@@ -815,13 +819,13 @@ def forward(self, b_parametrizations_buffer_original0, x):
             out = layer_norm.permute(0, 2, 1)
             return out

-        x = torch.randn(4, 2, 4, requires_grad=True, device="cuda")
+        x = torch.randn(4, 2, 4, requires_grad=True, device=self.device_type)
         x_dt = DTensor.from_local(x, mesh, [Shard(1)], run_check=False)

-        y = torch.randn(4, requires_grad=True, device="cuda")
+        y = torch.randn(4, requires_grad=True, device=self.device_type)
         y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False)

-        z = torch.randn(4, requires_grad=True, device="cuda")
+        z = torch.randn(4, requires_grad=True, device=self.device_type)
         z_dt = DTensor.from_local(z, mesh, [Replicate()], run_check=False)

         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
@@ -919,7 +923,7 @@ def forward(self, b_parametrizations_buffer_original0, x):
         # pass in tensor as inputs/outputs, create DTensor and run redistribute
         # (allgather collective) inside the fn
         def fn(x_dt):
-            if x_dt.device_mesh.device_type == "cuda":
+            if x_dt.device_mesh.device_type == f"{self.device_type}":
                 return x_dt + 1
             else:
                 return x_dt + 2
@@ -1051,7 +1055,7 @@ def forward(self, primals_1):

         model = FakeTransformer().to(self.device_type)

-        tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
+        tp_mesh = init_device_mesh(self.device_type, (2,), mesh_dim_names=("tp",))

         # apply sequence parallel
         parallel_plan = {
@@ -27,8 +27,6 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
-    TEST_CUDA,
-    TEST_HPU,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     create_local_tensor_test_class,
@@ -541,7 +539,7 @@ class RedistributeTest(DTensorTestBase):
         local_out_dt = out_dt.to_local()
         local_expected_dt = expected_dt.to_local()
         self.assertEqual(out_dt.to_local(), expected_dt.to_local())
-        if TEST_HPU or TEST_CUDA:
+        if torch.accelerator.is_available():
             self.assertEqual(
                 comm_mode.get_comm_counts()[
                     torch.ops._dtensor.shard_dim_alltoall
@@ -296,8 +296,8 @@ class DistTensorOpsTest(DTensorTestBase):
         self.assertEqual(dist_tensor.dtype, torch.float32)
         self.assertEqual(zeros_like_dt.dtype, torch.bfloat16)

-    @with_comms
     @skip_if_lt_x_gpu(4)
+    @with_comms
     def test_stack(self):
         mesh_2d = DeviceMesh(
             self.device_type, torch.arange(self.world_size).reshape(2, 2)
@@ -30,7 +30,7 @@ from torch.testing._internal.common_utils import skipIfRocm
 from torch.testing._internal.inductor_utils import HAS_GPU


-def estimate_aten_runtime(fx_node, compute_multiplier=1.0):
+def estimate_aten_runtime(fx_node, override_size=None, compute_multiplier=1.0):
     # for tests, assume a matmul can hide a single collective
     if "c10" in str(fx_node.target):
         return 1.0
@@ -1112,7 +1112,7 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):

         # Use 0.5 compute multiplier so each collective needs 2 matmuls to be fully hidden
         def estimate_with_half_compute(fx_node, override_size=None):
-            return estimate_aten_runtime(fx_node, compute_multiplier=0.5)
+            return estimate_aten_runtime(fx_node, override_size, compute_multiplier=0.5)

         def func(a, b, *, ranks):
             # Two all_gathers that will be hidden by multiple compute operations
@@ -1162,6 +1162,56 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
            correct = func(a, b, ranks=ranks)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(get_bucket_patches())
    def test_bucketing_with_convert_dtype(self):
        """Test that all_gathers with dtype conversion get bucketed and produce correct results."""

        def func(a, b, c, d, *, ranks):
            # Convert inputs to float16 before all_gather
            a_fp16 = a.to(torch.float16)
            b_fp16 = b.to(torch.float16)

            # Two all_gathers with converted dtypes
            ag1 = _functional_collectives.all_gather_tensor(a_fp16, 0, ranks)
            ag2 = _functional_collectives.all_gather_tensor(b_fp16, 0, ranks)

            # same dtype
            ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks)
            ag4 = _functional_collectives.all_gather_tensor(d, 0, ranks)

            return ag1, ag2, ag3, ag4

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            a = torch.ones(4, 4, dtype=torch.float32, device=device_type)
            b = torch.ones(4, 4, dtype=torch.float64, device=device_type) * 2
            c = torch.ones(4, 4, dtype=torch.float16, device=device_type) * 3
            d = torch.ones(4, 4, dtype=torch.float64, device=device_type) * 4
            ranks = list(range(self.world_size))

            func_c = functools.partial(func, ranks=ranks)
            compiled = torch.compile(func_c)
            out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d)

            # Should have 1 bucketed all_gather (both ag1 and ag2 bucketed together)
            FileCheck().check_count(
                "torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True
            ).run(aten_graph_str)

            # Verify convert_element_type ops are removed (dtype conversion handled by _pre_bucket_all_gather)
            FileCheck().check_not("torch.ops.prims.convert_element_type").run(
                aten_graph_str
            )

            # Verify correctness - this tests that dtype conversion is handled correctly
            correct = func(a, b, c, d, ranks=ranks)
            self.assertTrue(same(out, correct))


def get_toy_model(device_type: str):
    """
test/distributed/test_debug.py (new file, 56 lines)

@@ -0,0 +1,56 @@
# Owner(s): ["oncall: distributed"]

import os

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

import torch
import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server
from torch.testing._internal.common_utils import run_tests, TestCase


session = requests.Session()
retry_strategy = Retry(total=5, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("http://", adapter)
session.mount("https://", adapter)


class TestDebug(TestCase):
    def test_basics(self) -> None:
        store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(store.port)
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"

        port = 25999

        def fetch(path: str) -> str:
            resp = session.get(f"http://localhost:{port}{path}")
            resp.raise_for_status()
            return resp.text

        start_debug_server(port=port)

        self.assertIn("torch profiler", fetch("/"))
        self.assertIn("View 0", fetch("/profile?duration=0.01"))
        self.assertIn("test_basics", fetch("/stacks"))
        self.assertIn("pg_status", fetch("/fr_trace"))

        if torch.cuda.is_available():
            self.assertIn("pg_status", fetch("/fr_trace_nccl"))

        # test errors
        resp = session.get(f"http://localhost:{port}/blah")
        self.assertEqual(resp.status_code, 404)
        self.assertIn("Handler not found: /blah", resp.text)

        stop_debug_server()


if __name__ == "__main__":
    run_tests()
@@ -667,6 +667,94 @@ class TestOverlapPreservingBucketing(InductorTestCase):
            str(traced.graph)
        )

    def test_can_bucket_with_convert_dtype_as_hiding_nodes(self):
        """
        Test that all_gathers can bucket when convert_element_type ops ARE the hiding nodes.

        Graph structure:
        ag1_start -> convert1 (hides ag1) -> ag1_wait -> ag2_start -> convert2 (hides ag2) -> ag2_wait

        The convert_element_type ops ARE hiding nodes - no matmuls.
        This tests that dependencies are transferred correctly when convert nodes are erased.
        """

        def func(a, b, c):
            group_name = "0"
            group_size = 1

            ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
                a, group_size, group_name
            )
            b = torch.ops.prims.convert_element_type.default(b, torch.float16)
            ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)

            ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
                b, group_size, group_name
            )
            ag3 = torch.ops._c10d_functional.all_gather_into_tensor(
                c, group_size, group_name
            )

            mm = ag1_out @ ag1_out

            ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
            ag3_out = torch.ops._c10d_functional.wait_tensor(ag3)

            return ag1_out, ag2_out, ag3_out, mm

        with FakeTensorMode():
            a = torch.ones(4, 4, device=self.device, dtype=torch.float32)
            b = torch.ones(4, 4, device=self.device, dtype=torch.float32)
            c = torch.ones(4, 4, device=self.device, dtype=torch.float32)

            traced = make_fx(func)(a, b, c)

        # Find nodes
        ag1, ag2, ag3 = traced.graph.find_nodes(
            op="call_function",
            target=torch.ops._c10d_functional.all_gather_into_tensor.default,
        )
        convert1 = traced.graph.find_nodes(
            op="call_function",
            target=torch.ops.prims.convert_element_type.default,
        )[0]
        mm = traced.graph.find_nodes(
            op="call_function",
            target=torch.ops.aten.mm.default,
        )[0]

        hiding_annotations = {
            ag1: convert1,
            ag2: mm,
            ag3: mm,
        }

        # Build collective info and ancestors
        collective_info = build_collective_info(traced.graph, hiding_annotations)
        node_ancestors = compute_ancestors(traced.graph)
        scheduled = OrderedSet(traced.graph.nodes)

        # Run bucketing
        from torch._inductor.fx_passes.overlap_preserving_bucketer import (
            OverlapPreservingBucketer,
        )

        bucketer = OverlapPreservingBucketer(
            traced.graph,
            collective_info,
            node_ancestors,
            scheduled,
        )
        bucketer.bucket_collectives()

        graph_str = str(traced.graph)

        f = FileCheck()
        f.check_count("%all_gather_into_tensor", 1, exactly=True)
        f.check("pre_bucket_all_gather").check("wait_tensor").check(
            "%all_gather_into_tensor_out"
        ).run(graph_str)


if __name__ == "__main__":
    run_tests()
@@ -828,9 +828,6 @@ inductor_one_sample["cuda"] = {
     "nn.functional.fractional_max_pool3d": {f16, f32, f64},
     "nn.functional.group_norm": {f16},
     "nn.functional.hinge_embedding_loss": {f16},
-    # Enabling all tests for this test fails randomly
-    # See https://github.com/pytorch/pytorch/issues/129238
-    "nn.functional.huber_loss": {f16},
     "nn.functional.interpolate.bicubic": {f16},
     "nn.functional.interpolate.bilinear": {f16},
     "nn.functional.interpolate.trilinear": {f16},
@@ -948,9 +945,6 @@ inductor_one_sample["xpu"] = {
     "nn.functional.fractional_max_pool3d": {f16, f32, f64},
     "nn.functional.group_norm": {f16},
     "nn.functional.hinge_embedding_loss": {f16},
-    # Enabling all tests for this test fails randomly
-    # See https://github.com/pytorch/pytorch/issues/129238
-    "nn.functional.huber_loss": {f16},
     "nn.functional.interpolate.bicubic": {f16},
     "nn.functional.interpolate.bilinear": {f16},
     "nn.functional.interpolate.trilinear": {f16},
@@ -110,6 +110,16 @@ class AttentionBlock(nn.Module):
         return self.out_proj(attn_out)


+def pack_sequences(seqs, device):
+    x_packed = torch.cat(seqs, dim=0)
+    seq_lens = torch.tensor([len(s) for s in seqs], device=device)
+    cu_seq = torch.zeros(len(seqs) + 1, device=device, dtype=torch.int32)
+    cu_seq[1:] = seq_lens.cumsum(0)
+    max_len = seq_lens.max().item()
+
+    return x_packed, cu_seq, max_len
+
+
 def create_variable_length_batch(
     shape: VarlenShape, device: torch.device, dtype: torch.dtype
 ):
@@ -119,16 +129,15 @@ def create_variable_length_batch(
         seq_lengths.append(min(length, shape.max_seq_len))

     seq_lengths = torch.tensor(seq_lengths, device=device)
-    total_tokens = seq_lengths.sum().item()

-    x_packed = torch.randn(
-        total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
-    )
+    sequences = [
+        torch.randn(
+            seq_len, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
+        )
+        for seq_len in seq_lengths
+    ]

-    cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
-    cu_seq[1:] = seq_lengths.cumsum(0)
-
-    max_len = seq_lengths.max().item()
+    x_packed, cu_seq, max_len = pack_sequences(sequences, device)
     x_padded = torch.zeros(
         shape.batch_size, max_len, shape.embed_dim, device=device, dtype=dtype
     )
@@ -146,7 +155,6 @@ def create_variable_length_batch(
         "x_packed": x_packed,
         "x_padded": x_padded,
         "max_len": max_len,
-        "total_tokens": total_tokens,
     }
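As an editorial aside, a small standalone sketch of what the `pack_sequences` helper above produces (`cu_seq` holds cumulative sequence offsets into the packed tensor, which is what varlen attention consumes):

```python
import torch


def pack_sequences(seqs, device):
    # Same logic as the test helper in the diff above: concatenate
    # variable-length sequences and build cumulative offsets.
    x_packed = torch.cat(seqs, dim=0)
    seq_lens = torch.tensor([len(s) for s in seqs], device=device)
    cu_seq = torch.zeros(len(seqs) + 1, device=device, dtype=torch.int32)
    cu_seq[1:] = seq_lens.cumsum(0)
    return x_packed, cu_seq, seq_lens.max().item()


seqs = [torch.randn(n, 8) for n in (3, 5, 2)]
x_packed, cu_seq, max_len = pack_sequences(seqs, device="cpu")
print(x_packed.shape)   # torch.Size([10, 8])
print(cu_seq.tolist())  # [0, 3, 8, 10]
print(max_len)          # 5
```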
@@ -428,6 +436,143 @@

            start_idx = end_idx

    @skipIfRocm(msg="ROCM does not support variable length attention")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
    )
    @parametrize("dtype", [torch.bfloat16, torch.float16])
    @parametrize("is_causal", [False, True])
    def test_batch_invariance(self, device, dtype, is_causal):
        torch.manual_seed(42)

        batch_size = 4
        max_seq_len = 512
        embed_dim = 1024
        num_heads = 16

        head_dim = embed_dim // num_heads

        seq_lengths = []
        for _ in range(batch_size):
            length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64
            seq_lengths.append(min(length, max_seq_len))

        sequences_q = [
            torch.testing.make_tensor(
                (seq_len, num_heads, head_dim),
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            for seq_len in seq_lengths
        ]
        sequences_k = [
            torch.testing.make_tensor(
                (seq_len, num_heads, head_dim),
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            for seq_len in seq_lengths
        ]
        sequences_v = [
            torch.testing.make_tensor(
                (seq_len, num_heads, head_dim),
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            for seq_len in seq_lengths
        ]

        q_packed_orig = torch.cat(sequences_q, dim=0)
        k_packed_orig = torch.cat(sequences_k, dim=0)
        v_packed_orig = torch.cat(sequences_v, dim=0)

        seq_lens = torch.tensor(seq_lengths, device=device)
        cu_seq_orig = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
        cu_seq_orig[1:] = seq_lens.cumsum(0)
        max_len_orig = seq_lens.max().item()

        original_output = varlen_attn(
            q_packed_orig,
            k_packed_orig,
            v_packed_orig,
            cu_seq_orig,
            cu_seq_orig,
            max_len_orig,
            max_len_orig,
            is_causal,
        )

        perm = torch.randperm(batch_size)
        permuted_sequences_q = [sequences_q[perm[i]] for i in range(batch_size)]
        permuted_sequences_k = [sequences_k[perm[i]] for i in range(batch_size)]
        permuted_sequences_v = [sequences_v[perm[i]] for i in range(batch_size)]

        q_packed_perm = torch.cat(permuted_sequences_q, dim=0)
        k_packed_perm = torch.cat(permuted_sequences_k, dim=0)
        v_packed_perm = torch.cat(permuted_sequences_v, dim=0)

        permuted_seq_lens = torch.tensor(
            [seq_lengths[perm[i]] for i in range(batch_size)], device=device
        )
        cu_seq_perm = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
        cu_seq_perm[1:] = permuted_seq_lens.cumsum(0)
        max_len_perm = permuted_seq_lens.max().item()

        permuted_output = varlen_attn(
            q_packed_perm,
            k_packed_perm,
            v_packed_perm,
            cu_seq_perm,
            cu_seq_perm,
            max_len_perm,
            max_len_perm,
            is_causal,
        )

        for i in range(batch_size):
            orig_idx = perm[i].item()

            orig_start = cu_seq_orig[orig_idx].item()
            orig_end = cu_seq_orig[orig_idx + 1].item()
            orig_seq_output = original_output[orig_start:orig_end]

            perm_start = cu_seq_perm[i].item()
            perm_end = cu_seq_perm[i + 1].item()
            perm_seq_output = permuted_output[perm_start:perm_end]

            self.assertEqual(orig_seq_output, perm_seq_output)

        original_grad_out = torch.ones_like(original_output)
        permuted_grad_out = torch.ones_like(permuted_output)

        original_grad = torch.autograd.grad(
            outputs=original_output,
            inputs=q_packed_orig,
            grad_outputs=original_grad_out,
        )[0]

        permuted_grad = torch.autograd.grad(
            outputs=permuted_output,
            inputs=q_packed_perm,
            grad_outputs=permuted_grad_out,
        )[0]

        for i in range(batch_size):
            orig_idx = perm[i].item()

            orig_start = cu_seq_orig[orig_idx].item()
            orig_end = cu_seq_orig[orig_idx + 1].item()
            orig_seq_grad = original_grad[orig_start:orig_end]

            perm_start = cu_seq_perm[i].item()
            perm_end = cu_seq_perm[i + 1].item()
            perm_seq_grad = permuted_grad[perm_start:perm_end]

            self.assertEqual(orig_seq_grad, perm_seq_grad)


device_types = ("cuda",)
@@ -100,7 +100,9 @@ class Logger:
     def _set_static_graph(self) -> None: ...

 class _WorkerServer:
-    def __init__(self, socket_path: str) -> None: ...
+    port: int
+
+    def __init__(self, host_or_file: str, port: int = ...) -> None: ...
     def shutdown(self) -> None: ...

 def get_debug_level(): ...
@@ -206,6 +208,7 @@ class Store:
         desired_value: str,
     ) -> bytes: ...
     def delete_key(self, key: str) -> bool: ...
+    def multi_get(self, keys: list[str]) -> list[bytes]: ...
     def num_keys(self) -> int: ...
     def set_timeout(self, timeout: timedelta): ...
     @overload
@@ -872,3 +875,15 @@ class ProcessGroupXCCL(Backend):

 def _set_process_group(pg: ProcessGroup) -> None: ...
 def _current_process_group() -> ProcessGroup: ...
+
+class _Request:
+    def body(self) -> bytes: ...
+    def get_param(self, str) -> str: ...
+
+class _Response:
+    def set_content(self, content: str | bytes, content_type: str) -> None: ...
+    def set_status(self, status: int) -> None: ...
+
+def _register_handler(
+    name: str, handler: Callable[[_Request, _Response], None]
+) -> None: ...
@@ -60,6 +60,7 @@ class _ExperimentalConfig:
         verbose: bool = ...,
         performance_events: list[str] = ...,
         enable_cuda_sync_events: bool = ...,
+        profile_all_threads: bool = ...,
     ) -> None: ...

 class ProfilerConfig:
@ -341,12 +341,58 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:


def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
    sz_bytes = 0
    for node in fx_node.all_input_nodes:
        if (t := node.meta.get("val")) is not None:
            numel = get_size_numel(t.size())
            sz_bytes += numel * get_dtype_size(t.dtype)
    return sz_bytes
    """Estimate the size of a collective operation in bytes, including inputs and outputs."""
    input_bytes = None

    args, kwargs = fx_node.args, fx_node.kwargs
    kwargs = dict(kwargs)

    # dont double count pre-allocated buffer passed in
    kwargs.pop("out", None)

    def tensor_bytes(t) -> int:
        return get_size_numel(t.size()) * get_dtype_size(t.dtype)

    def add_inp_bytes(inp: torch.fx.Node):
        t = inp.meta.get("val", None)
        if t is None:
            return

        nonlocal input_bytes
        if input_bytes is None:
            input_bytes = 0
        input_bytes += tensor_bytes(t)

    pytree.tree_map_only(
        torch.fx.Node,
        add_inp_bytes,
        (args, kwargs),
    )

    output_tensor = fx_node.meta.get("val", None)

    if input_bytes is None or output_tensor is None:
        return 0

    output_bytes = (
        get_size_numel(output_tensor.size()) * output_tensor.element_size()
    )  # pyre-ignore

    return input_bytes + output_bytes


def estimate_fx_collective_memory_footprint(fx_node: torch.fx.Node) -> int:
    """Estimate the memory footprint of a collective operation in bytes.

    This returns the total bytes that need to be live concurrently in memory.
    For all_reduce, we divide by 2 since it can be done in-place.
    """
    from torch._inductor.fx_passes.bucketing import (
        is_all_reduce_tensor as is_all_reduce,
    )

    size = estimate_fx_collective_size(fx_node)
    return size if not is_all_reduce(fx_node) else size // 2
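As a rough illustration of the accounting above (not part of the diff): an all_gather counts its input plus a world_size-times-larger output, while an all_reduce estimate is halved because it can run in-place. The shape, dtype, and world size below are made up for the sketch.

import torch

def tensor_bytes(t: torch.Tensor) -> int:
    # mirrors get_size_numel(t.size()) * get_dtype_size(t.dtype)
    return t.numel() * t.element_size()

inp = torch.empty(1024, 1024, dtype=torch.bfloat16)  # hypothetical collective input
world_size = 4

all_gather_footprint = tensor_bytes(inp) + world_size * tensor_bytes(inp)  # input + gathered output
all_reduce_footprint = (tensor_bytes(inp) + tensor_bytes(inp)) // 2        # halved: in-place capable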

def estimate_nccl_collective_runtime_from_fx_node(
@ -489,15 +489,34 @@ def all_reduce_merge_fn_to_trace(
    return new_outs


# List of all torch dtypes for serialization through custom ops
# TODO: custom ops support list[dtype] input
_ALL_DTYPES = tuple(
    [
        getattr(torch, attr)
        for attr in dir(torch)
        if isinstance(getattr(torch, attr), torch.dtype)
    ]
)


@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={})
def _pre_bucket_all_gather(
    ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    out_dtype_ints: list[
        int
    ],  # dtype enum values, that inputs are converted to before all_gather
    rank: int,
) -> torch.Tensor:
    ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
    # Convert int indices back to torch.dtype
    out_dtypes = [_ALL_DTYPES[d] for d in out_dtype_ints]
    ins_split_sizes_bytes = [
        ag_in.numel() * out_dtype.itemsize
        for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True)
    ]
    bucket_dtype_size_bytes = dtype.itemsize
    ins_split_sizes = [
        _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
@ -507,8 +526,14 @@ def _pre_bucket_all_gather(
    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
    foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
    ag_ins_flattened = [ag_in.reshape(-1).view(dtype) for ag_in in ag_ins]
    torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
    # View each destination slice as its output dtype, then copy
    # The copy operation handles dtype conversion from input dtype to output dtype
    foreach_copy_dsts_typed = [
        dst.view(out_dtype)
        for dst, out_dtype in zip(foreach_copy_dsts, out_dtypes, strict=True)
    ]
    ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
    torch._foreach_copy_(foreach_copy_dsts_typed, ag_ins_flattened)
    return new_ag_out


@ -517,9 +542,14 @@ def _pre_bucket_all_gather_fake(
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    out_dtype_ints: list[int],
    rank: int,
) -> torch.Tensor:
    ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
    out_dtypes = [_ALL_DTYPES[d] for d in out_dtype_ints]
    ins_split_sizes_bytes = [
        ag_in.numel() * out_dtype.itemsize
        for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True)
    ]
    bucket_dtype_size_bytes = dtype.itemsize
    ins_split_sizes = [
        _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
@ -541,12 +571,9 @@ def all_gather_merge_fn_to_trace_custom_ops(
    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
    rank: int,
) -> list[torch.Tensor]:
    ag_ins = [
        torch._prims.convert_element_type(_ag_in, out_dtype)
        if _ag_in.dtype != out_dtype
        else _ag_in
        for _ag_in, out_dtype in zip(_ag_ins, out_dtypes)
    ]
    # Don't create convert_element_type ops - _pre_bucket_all_gather handles conversion
    # by viewing destination slices as output dtypes and letting copy do the conversion
    ag_ins = _ag_ins
    ins_sizes = [ag_in.shape for ag_in in ag_ins]
    ins_split_sizes_bytes = [
        ag_in.numel() * out_dtype.itemsize
@ -557,8 +584,13 @@ def all_gather_merge_fn_to_trace_custom_ops(
        _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
    ]
    ag_input_numel = sum(ins_split_sizes)

    # Convert out_dtypes to indices for custom_op
    # TODO: custom ops support list[dtype] input
    out_dtype_ints = [_ALL_DTYPES.index(dt) for dt in out_dtypes]

    new_ag_out = torch.ops.bucketing._pre_bucket_all_gather(
        ag_ins, group_size, group_name, dtype, rank
        ag_ins, group_size, group_name, dtype, out_dtype_ints, rank
    )
    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
    wait_tensor = torch.ops.c10d_functional.wait_tensor(
@ -721,6 +753,20 @@ def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
    return replacements, new_nodes


def has_mergeable_all_gather_convert_dtype(n: torch.fx.Node) -> bool:
    node_in = n.args[0]
    return (
        is_all_gather_into_tensor(n)
        and isinstance(node_in, torch.fx.Node)
        and node_in.op == "call_function"
        and (
            node_in.target is torch.ops.prims.convert_element_type.default
            or node_in.target is torch.ops.aten._to_copy.default
        )
        and len(node_in.users) == 1
    )


def process_collective_bucket(
    g: torch.fx.Graph,
    bucket_nodes: list[torch.fx.Node],
@ -755,13 +801,7 @@ def process_collective_bucket(

        # Handle convert_element_type operations (for all_gather)
        node_in = n.args[0]
        if (
            is_all_gather_into_tensor(n)
            and isinstance(node_in, torch.fx.Node)  # Add type check
            and node_in.op == "call_function"
            and node_in.target is torch.ops.prims.convert_element_type.default
            and len(node_in.users) == 1
        ):
        if has_mergeable_all_gather_convert_dtype(n):
            ag_node_to_pre_nodes[n].append(node_in)
            node_in = node_in.args[0]
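A minimal sketch (not part of the diff) of the dtype-to-index round trip that _ALL_DTYPES enables, assuming the tuple is built exactly as above; since custom ops cannot yet take list[dtype], the dtypes travel to the op as integer indices and are recovered inside it.

import torch

_ALL_DTYPES = tuple(
    getattr(torch, attr)
    for attr in dir(torch)
    if isinstance(getattr(torch, attr), torch.dtype)
)

out_dtypes = [torch.bfloat16, torch.float32]                     # example output dtypes
out_dtype_ints = [_ALL_DTYPES.index(dt) for dt in out_dtypes]    # serialize at the call site
assert [_ALL_DTYPES[i] for i in out_dtype_ints] == out_dtypes    # deserialize inside the op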
@ -3,12 +3,14 @@ from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Literal, Optional

import torch
import torch.fx as fx
from torch._dynamo.utils import counters
from torch._inductor.augmented_graph_helper import AugmentedGraphHelper
from torch._inductor.fx_passes.bucketing import (
    bucket_key,
    BucketMode,
    has_mergeable_all_gather_convert_dtype,
    is_all_gather_into_tensor as is_all_gather,
    is_reduce_scatter_tensor as is_reduce_scatter,
    is_wait_tensor,
@ -207,6 +209,7 @@ class OverlapPreservingBucketer:

            prev_event = event
            position += 1

        return head

    def _populate_node_to_event(self, pg: str) -> None:
@ -231,7 +234,6 @@ class OverlapPreservingBucketer:
                self.aug_graph.add_extra_dep(n=info.wait_node, dep=hn)

    def bucket_collectives(self) -> None:
        """Main entry point for bucketing collectives."""
        # Group collectives by PG first
        pg_collectives: dict[str, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
        for start in self.collective_info:
@ -281,6 +283,15 @@ class OverlapPreservingBucketer:
        # Apply topological sort with all dependencies
        from torch._dynamo.graph_deduplication import _stable_topological_sort

        for n, deps in additional_deps.items():
            torch._check(
                not n._erased, lambda: f"Erased node deps not transferred: {n}"
            )
            for d in deps:
                torch._check(
                    not d._erased, lambda: f"Erased node deps not transferred: {d}"
                )

        _stable_topological_sort(self.graph, additional_deps)

        # After topological sort, preserve dependencies using effect tokens
@ -762,6 +773,11 @@ class OverlapPreservingBucketer:
        old_starts = list(bucket)
        old_waits = [self.collective_info[n].wait_node for n in bucket]

        fused_convert_dtypes = []
        for n in old_starts:
            if has_mergeable_all_gather_convert_dtype(n):
                fused_convert_dtypes.append(n.args[0])

        # Find where to place the bucketed operations
        next_node = bucket[0]
        while next_node in bucket:
@ -809,6 +825,22 @@ class OverlapPreservingBucketer:
        for old_wait in old_waits:
            erased_to_new[old_wait] = new_wait

        # Handle convert_element_type nodes that were fused and erased
        # The bucketed operation may have a _pre_bucket op that handles dtype conversion
        if fused_convert_dtypes:
            # all gather bucketing may fuse in dtype conversion into the bucketing
            # if so, we need to transfer hiding deps from the old dtype conversion
            # to the new bucketing node
            new_convert_dtypes_node = new_start.kwargs["out"]
            assert isinstance(new_convert_dtypes_node, fx.Node)
            assert (
                new_convert_dtypes_node.target
                == torch.ops.bucketing._pre_bucket_all_gather.default
            )

            for n in fused_convert_dtypes:
                erased_to_new[n] = new_convert_dtypes_node

        # Transfer all dependencies from old nodes to new nodes
        self.aug_graph.transfer_erased_node_deps(erased_to_new)
@ -11,7 +11,7 @@ from typing import Any, Literal
import torch
import torch.fx as fx
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.comm_analysis import estimate_fx_collective_size
from torch._inductor.comm_analysis import estimate_fx_collective_memory_footprint
from torch._inductor.fx_passes.bucketing import _schedulable_wait_node, is_wait_tensor
from torch._inductor.fx_passes.memory_estimator import (
    _is_releasable,
@ -45,21 +45,26 @@ def get_group_name(n: fx.Node) -> str:

def get_custom_estimation(
    n: fx.Node,
    custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
    custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
    | None = None,
    override_size: int | None = None,
) -> float | None:
    if custom_runtime_estimation is None:
        return None

    return custom_runtime_estimation(n)
    return custom_runtime_estimation(n, override_size)


def estimate_collective_time(
    n: fx.Node,
    override_size: int | None = None,
    custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
    custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
    | None = None,
) -> float:
    """Estimate the runtime of a collective operation, optionally with an overridden size."""
    if (est := get_custom_estimation(n, custom_runtime_estimation)) is not None:
    if (
        est := get_custom_estimation(n, custom_runtime_estimation, override_size)
    ) is not None:
        return est

    # Use analytical model (benchmarking is handled separately in alignment)
@ -99,7 +104,8 @@ def get_collective_do_bench() -> Callable[[Callable[[], Any]], float]:

def benchmark_node_with_cache_key(
    n: fx.Node,
    custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
    custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
    | None = None,
) -> tuple[float, str | None]:
    """Benchmark a compute node and return (runtime, cache_key)."""
    assert is_compute_node(n)
@ -142,7 +148,9 @@ def benchmark_node_with_cache_key(
    if unbacked_tensor:
        return 0, key

    if (est := get_custom_estimation(n, custom_runtime_estimation)) is not None:
    if (
        est := get_custom_estimation(n, custom_runtime_estimation, None)
    ) is not None:
        set_cached_node_time(key, est)
        return est, key

@ -154,7 +162,8 @@ def benchmark_node_with_cache_key(

def benchmark_node(
    n: fx.Node,
    custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
    custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
    | None = None,
) -> float:
    return benchmark_node_with_cache_key(n, custom_runtime_estimation)[0]

@ -236,7 +245,7 @@ class OverlapScheduler:
        insert_overlap_deps: bool,
        compute_overlap_multipler: float,
        max_coll_distance: int,
        custom_runtime_estimation: Callable[[fx.Node], float | None] | None,
        custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] | None,
        collective_estimator: Literal["analytical", "benchmark"],
    ):
        self.gm = gm
@ -318,7 +327,7 @@ class OverlapScheduler:
            info = CollectiveInfo(
                start_node=start,
                wait_node=node,
                size_bytes=estimate_fx_collective_size(start),
                size_bytes=estimate_fx_collective_memory_footprint(start),
                estimated_time_ms=coll_time_ms,
                exposed_time_ms=coll_time_ms,  # Initially fully exposed
            )
@ -431,7 +440,10 @@ class OverlapScheduler:
        # Benchmark CUDA events (non-deterministic, needs alignment)
        # Skip collectives with custom estimation
        for n in collective_nodes:
            if get_custom_estimation(n, self.custom_runtime_estimation) is not None:
            if (
                get_custom_estimation(n, self.custom_runtime_estimation, None)
                is not None
            ):
                continue

            # Benchmark actual size
@ -1000,7 +1012,8 @@ def schedule_overlap_bucketing(
    insert_overlap_deps: bool = False,
    compute_overlap_multipler: float = 1.0,
    max_coll_distance: int = 1000,
    custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
    custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
    | None = None,
    collective_estimator: Literal["analytical", "benchmark"] = "analytical",
) -> torch.fx.GraphModule:
    """Schedule nodes to maximize compute-collective overlap.
torch/_subclasses/complex_tensor/__init__.py
Normal file
9
torch/_subclasses/complex_tensor/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
from ._core import ComplexTensor
|
||||
from ._ops import ComplexTensorMode, is_complex_tensor
|
||||
|
||||
|
||||
__all__ = ["ComplexTensor", "ComplexTensorMode", "is_complex_tensor"]
|
||||
|
||||
ComplexTensor.__module__ = __name__
|
||||
ComplexTensorMode.__module__ = __name__
|
||||
is_complex_tensor.__module__ = __name__
|
||||
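A minimal usage sketch (not part of the diff) based on the new subclass files in this change: from_interleaved/as_interleaved convert between a native complex tensor and the split real/imaginary representation.

import torch
from torch._subclasses.complex_tensor import ComplexTensor, is_complex_tensor

z = torch.randn(3, 2, dtype=torch.complex64)
ct = ComplexTensor.from_interleaved(z)   # stored as separate real and imaginary parts
assert is_complex_tensor(ct)
roundtrip = ct.as_interleaved()          # recombine into an interleaved complex tensor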
torch/_subclasses/complex_tensor/_core.py (new file, 151 lines)
@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._ops import OpOverload
|
||||
from torch._prims_common import DeviceLikeType
|
||||
from torch.autograd.function import FunctionCtx
|
||||
|
||||
|
||||
class ComplexTensor(Tensor):
|
||||
"""A class that decomposes all ops on complex Tensors into their real and imaginary parts."""
|
||||
|
||||
_re: Tensor
|
||||
_im: Tensor
|
||||
|
||||
def __new__(cls, real: Tensor, imag: Tensor) -> Self:
|
||||
"""Initialize a ComplexTensor from its real and imaginary parts."""
|
||||
from ._ops.common import REAL_TO_COMPLEX
|
||||
|
||||
shape = real.shape
|
||||
device = real.device
|
||||
|
||||
# TODO (hameerabbasi): `torch.compile` sometimes fails here without making these
|
||||
# contiguous. Why?
|
||||
real = real.contiguous()
|
||||
imag = imag.contiguous()
|
||||
|
||||
# TODO (hameerabbasi):
|
||||
# What should we do with dtype?
|
||||
# We could convert to the complex type (float32 -> complex64), but we
|
||||
# can't use that model for say `bfloat16` which does not have a
|
||||
# corresponding complex dtype.
|
||||
# If we want to support this complex rep using any float type (see
|
||||
# https://github.com/pytorch/pytorch/issues/95100)
|
||||
# We either need to:
|
||||
# 1) add the complex types for say `complexbf32`, knowing they can't really be used anywhere
|
||||
# else.
|
||||
# 2) We use the real float dtype here, and it is up to the user to know
|
||||
# that dtype=float<size> here really means complex<2xSize> with dtype
|
||||
# matching that of re/im parts alone
|
||||
# I'm going with 1 for now, so that I can make gradcheck and some complex
|
||||
# ops work properly, but might want to discuss this in the RFP.
|
||||
dtype = REAL_TO_COMPLEX.get(real.dtype)
|
||||
if dtype is None:
|
||||
raise TypeError(
|
||||
"Unsupported dtype for constituent tensors. Supported dtypes are: "
|
||||
f"{set(REAL_TO_COMPLEX.keys())!r}."
|
||||
)
|
||||
storage_offset = real.storage_offset()
|
||||
strides = real.stride()
|
||||
layout = real.layout
|
||||
pin_memory = real.is_pinned()
|
||||
|
||||
assert shape == imag.shape, f"Expected imag shape {shape}, got {imag.shape}"
|
||||
assert device == imag.device, (
|
||||
f"Expected imag device {device}, got {imag.device}"
|
||||
)
|
||||
assert real.dtype == imag.dtype, (
|
||||
f"Expected imag dtype {real.dtype}, got {imag.dtype}"
|
||||
)
|
||||
assert pin_memory == imag.is_pinned(), (
|
||||
f"Expected imag pinning {pin_memory}, got {imag.is_pinned()}"
|
||||
)
|
||||
|
||||
res = Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
cls,
|
||||
shape,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
storage_offset=storage_offset,
|
||||
strides=strides,
|
||||
pin_memory=pin_memory,
|
||||
layout=layout,
|
||||
requires_grad=False,
|
||||
)
|
||||
res._re = real.clone().detach()
|
||||
res._im = imag.clone().detach()
|
||||
|
||||
return res
|
||||
|
||||
@property
|
||||
def re(self) -> Tensor:
|
||||
return self._re
|
||||
|
||||
@property
|
||||
def im(self) -> Tensor:
|
||||
return self._im
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(
|
||||
cls,
|
||||
func: OpOverload,
|
||||
types: tuple[type, ...],
|
||||
args: tuple = (),
|
||||
kwargs: dict | None = None,
|
||||
):
|
||||
from ._ops.common import lookup_complex
|
||||
|
||||
kwargs = {} if kwargs is None else kwargs
|
||||
|
||||
impl = lookup_complex(func, *args, **kwargs)
|
||||
if impl is None:
|
||||
return NotImplemented
|
||||
|
||||
return impl(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def from_interleaved(t: Tensor) -> ComplexTensor:
|
||||
t_real = torch.real(t)
|
||||
t_imag = torch.imag(t) if t.dtype.is_complex else torch.zeros_like(t_real)
|
||||
return Complex.apply(t_real, t_imag)
|
||||
|
||||
def as_interleaved(self) -> Tensor:
|
||||
return torch.complex(self.real, self.imag)
|
||||
|
||||
@staticmethod
|
||||
def __tensor_unflatten__(
|
||||
inner_tensors: dict[str, Tensor],
|
||||
meta: Any,
|
||||
outer_size: tuple[int, ...],
|
||||
outer_stride: tuple[int, ...],
|
||||
) -> ComplexTensor:
|
||||
assert meta is None
|
||||
re, im = inner_tensors["re"], inner_tensors["im"]
|
||||
return ComplexTensor(re, im)
|
||||
|
||||
def __tensor_flatten__(self) -> tuple[list[str], Any]:
|
||||
return ["re", "im"], None
|
||||
|
||||
def __repr__(self, *, tensor_contents=None) -> str:
|
||||
return f"ComplexTensor(real={self.re!r}, imag={self.im!r})"
|
||||
|
||||
def is_pinned(self, device: DeviceLikeType | None = None) -> bool:
|
||||
return self.re.is_pinned(device)
|
||||
|
||||
|
||||
class Complex(Function):
|
||||
@staticmethod
|
||||
def forward(ctx: FunctionCtx, real: Tensor, imag: Tensor) -> ComplexTensor: # type: ignore[bad-override]
|
||||
return ComplexTensor(real, imag)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: FunctionCtx, grad_output: ComplexTensor) -> tuple[Tensor, Tensor]: # type: ignore[bad-override]
|
||||
return grad_output.real, grad_output.imag
|
||||
torch/_subclasses/complex_tensor/_ops/__init__.py (new file, 5 lines)
@ -0,0 +1,5 @@
from . import aten, prims
from .common import ComplexTensorMode, is_complex_tensor


__all__ = ["ComplexTensorMode", "is_complex_tensor", "aten", "prims"]
torch/_subclasses/complex_tensor/_ops/aten.py (new file, 921 lines)
@ -0,0 +1,921 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from .._core import ComplexTensor
|
||||
from .common import (
|
||||
_get_func_name,
|
||||
COMPLEX_TO_REAL,
|
||||
complex_to_real_dtype,
|
||||
is_complex,
|
||||
OpType,
|
||||
promote_tensors,
|
||||
register_binary_nonlinear,
|
||||
register_complex,
|
||||
register_error,
|
||||
register_force_test,
|
||||
register_simple,
|
||||
split_complex_arg,
|
||||
split_complex_tensor,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
||||
def register_binary_linear(op: OpType):
|
||||
def impl_with_alpha(
|
||||
lhs: ComplexTensor, rhs: ComplexTensor, *args, alpha, **kwargs
|
||||
) -> ComplexTensor:
|
||||
return op(lhs, aten.mul(rhs, alpha, *args, **kwargs), *args, **kwargs)
|
||||
|
||||
def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
|
||||
alpha = kwargs.pop("alpha", None)
|
||||
if alpha is not None:
|
||||
return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs)
|
||||
a_r, a_i = split_complex_arg(lhs)
|
||||
b_r, b_i = split_complex_arg(rhs)
|
||||
out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i)
|
||||
u = op(a_r, b_r, *args, **kwargs)
|
||||
v = op(a_i, b_i, *args, **kwargs)
|
||||
return ComplexTensor(u.to(out_dt), v.to(out_dt))
|
||||
|
||||
return register_complex(op, impl)
|
||||
|
||||
|
||||
@register_complex(aten.real)
|
||||
def real_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
re, _ = split_complex_tensor(self)
|
||||
return re
|
||||
|
||||
|
||||
@register_complex(aten.imag)
|
||||
def imag_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
_, im = split_complex_tensor(self)
|
||||
return im
|
||||
|
||||
|
||||
@register_complex(aten.is_pinned)
|
||||
def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> bool:
|
||||
return self.is_pinned(device)
|
||||
|
||||
|
||||
SIMPLE_OPS_LIST = [
|
||||
aten.slice,
|
||||
aten.flatten,
|
||||
aten.view,
|
||||
aten.diagonal,
|
||||
aten.expand,
|
||||
aten.unsqueeze,
|
||||
aten.unsqueeze_,
|
||||
aten.mean,
|
||||
aten.sum,
|
||||
aten.clone,
|
||||
aten.neg,
|
||||
aten.flip,
|
||||
aten.permute,
|
||||
aten.repeat,
|
||||
aten.index_select,
|
||||
aten.split,
|
||||
aten.split_with_sizes,
|
||||
aten.cumsum,
|
||||
aten.detach,
|
||||
aten.select,
|
||||
aten.squeeze,
|
||||
aten.zero_,
|
||||
aten.transpose,
|
||||
aten.t,
|
||||
aten.gather,
|
||||
]
|
||||
|
||||
for simple_op in SIMPLE_OPS_LIST:
|
||||
globals()[_get_func_name(simple_op)] = register_simple(simple_op)
|
||||
|
||||
# TODO (hameerabbasi): Not being tested
|
||||
SIMPLE_FORCE_TESTED_OPS = [
|
||||
aten.copy,
|
||||
aten.col2im,
|
||||
aten.alias,
|
||||
aten.lift_fresh,
|
||||
aten._unsafe_view,
|
||||
aten.index,
|
||||
aten._neg_view,
|
||||
aten.avg_pool2d,
|
||||
aten.avg_pool3d,
|
||||
aten.avg_pool2d_backward,
|
||||
aten.avg_pool3d_backward,
|
||||
aten.masked_scatter_backward,
|
||||
aten.select_backward,
|
||||
aten.slice_backward,
|
||||
aten.embedding,
|
||||
]
|
||||
|
||||
for simple_op in SIMPLE_FORCE_TESTED_OPS:
|
||||
globals()[_get_func_name(simple_op)] = register_force_test(
|
||||
simple_op, register_simple(simple_op)
|
||||
)
|
||||
|
||||
del simple_op
|
||||
|
||||
# some binary ops which we can stamp out
|
||||
mul_impl = register_binary_nonlinear(aten.mul)
|
||||
mul__impl = register_binary_nonlinear(aten.mul_)
|
||||
mm_impl = register_binary_nonlinear(aten.mm)
|
||||
dot_impl = register_binary_nonlinear(aten.dot)
|
||||
bmm_impl = register_binary_nonlinear(aten.bmm)
|
||||
|
||||
# TODO (hameerabbasi): Not being tested
|
||||
convolution_impl = register_force_test(
|
||||
aten.convolution, register_binary_nonlinear(aten.convolution)
|
||||
)
|
||||
|
||||
slice_scatter_impl = register_force_test(
|
||||
aten.slice_scatter, register_binary_linear(aten.slice_scatter)
|
||||
)
|
||||
select_scatter_impl = register_force_test(
|
||||
aten.select_scatter, register_binary_linear(aten.select_scatter)
|
||||
)
|
||||
|
||||
add_impl = register_binary_linear(aten.add)
|
||||
add__impl = register_binary_linear(aten.add_)
|
||||
sub_impl = register_binary_linear(aten.sub)
|
||||
sub__impl = register_binary_linear(aten.sub_)
|
||||
diagonal_scatter_impl = register_binary_linear(aten.diagonal_scatter)
|
||||
fill__impl = register_binary_linear(aten.fill_)
|
||||
|
||||
|
||||
@register_complex(aten.rsub)
|
||||
def rsub_impl(lhs: ComplexTensor, rhs: ComplexTensor, alpha=None) -> ComplexTensor:
|
||||
if alpha is None:
|
||||
return torch.sub(rhs, lhs) # type: ignore[bad-return]
|
||||
return torch.sub(rhs, lhs, alpha=alpha) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.div)
|
||||
@register_complex(aten.true_divide)
|
||||
def div_impl(lhs: ComplexTensor, rhs: ComplexTensor, *, rounding_mode=None):
|
||||
if rounding_mode is not None:
|
||||
raise NotImplementedError(
|
||||
"`rounding_mode` other than `None` not implemented for`ComplexTensor`."
|
||||
)
|
||||
a_r, a_i = split_complex_tensor(lhs)
|
||||
if not is_complex(rhs):
|
||||
return ComplexTensor(a_r / rhs, a_i / rhs)
|
||||
b_r, b_i = split_complex_arg(rhs)
|
||||
out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i)
|
||||
num_r = a_r * b_r + a_i * b_i
|
||||
num_i = a_i * b_r - a_r * b_i
|
||||
den = b_r * b_r + b_i * b_i
|
||||
return ComplexTensor(
|
||||
(num_r / den).to(out_dt),
|
||||
(num_i / den).to(out_dt),
|
||||
)
|
||||
|
||||
|
||||
@register_complex(aten.reciprocal)
|
||||
def reciprocal_impl(self: ComplexTensor):
|
||||
self_r, self_i = split_complex_tensor(self)
|
||||
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
|
||||
den = self_r * self_r + self_i * self_i
|
||||
return ComplexTensor(
|
||||
aten.div(self_r, den).to(out_dt),
|
||||
aten.div(-self_i, den).to(out_dt),
|
||||
)
|
||||
|
||||
|
||||
# reductions
|
||||
@register_complex(aten.prod)
|
||||
def prod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
|
||||
out_dt, (self,) = promote_tensors(self)
|
||||
dtype = kwargs.pop("dtype", out_dt)
|
||||
kwargs["dtype"] = complex_to_real_dtype(self.dtype)
|
||||
|
||||
prod_r = torch.prod(torch.abs(self), *args, **kwargs)
|
||||
sum_phi = torch.sum(torch.angle(self), *args, **kwargs)
|
||||
u = prod_r * torch.cos(sum_phi)
|
||||
v = prod_r * torch.sin(sum_phi)
|
||||
return ComplexTensor(u, v).to(dtype) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.pow)
|
||||
def pow_impl(self: ComplexTensor, exponent: ComplexTensor) -> ComplexTensor:
|
||||
out_dt, (self, exponent) = promote_tensors(self, exponent)
|
||||
return torch.exp(exponent * torch.log(self)).to(out_dt) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.cumprod)
|
||||
def cumprod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
|
||||
dtype = kwargs.pop("dtype", self.dtype)
|
||||
kwargs["dtype"] = complex_to_real_dtype(dtype)
|
||||
|
||||
prod_r = torch.cumprod(torch.abs(self), *args, **kwargs)
|
||||
sum_phi = torch.cumsum(torch.angle(self), *args, **kwargs)
|
||||
u = prod_r * torch.cos(sum_phi)
|
||||
v = prod_r * torch.sin(sum_phi)
|
||||
return ComplexTensor(u, v)
|
||||
|
||||
|
||||
# unary funcs,
|
||||
# most of these are simple or require some kind of identity
|
||||
@register_complex(aten.abs)
|
||||
def abs_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
out_dt, (x, y) = promote_tensors(x, y)
|
||||
result = torch.hypot(x, y)
|
||||
return result.to(out_dt)
|
||||
|
||||
|
||||
@register_complex(aten.angle)
|
||||
def angle_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
return torch.atan2(y, x)
|
||||
|
||||
|
||||
@register_complex(aten.acos)
|
||||
def acos_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
_, y = split_complex_tensor(self)
|
||||
acosh_z = torch.acosh(self)
|
||||
assert isinstance(acosh_z, ComplexTensor)
|
||||
acosh_z_re, acosh_z_im = split_complex_tensor(acosh_z)
|
||||
sign_im = 2 * torch.signbit(y) - 1
|
||||
return ComplexTensor(torch.abs(acosh_z_im), sign_im * torch.abs(acosh_z_re))
|
||||
|
||||
|
||||
@register_complex(aten.asin)
|
||||
def asin_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
asinh_iz = torch.asinh(ComplexTensor(-y, x))
|
||||
assert isinstance(asinh_iz, ComplexTensor)
|
||||
asinh_iz_re, asinh_iz_im = split_complex_tensor(asinh_iz)
|
||||
return ComplexTensor(asinh_iz_im, -asinh_iz_re)
|
||||
|
||||
|
||||
@register_complex(aten.atan)
|
||||
def atan_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
tanh_iz = torch.atanh(ComplexTensor(-y, x))
|
||||
assert isinstance(tanh_iz, ComplexTensor)
|
||||
tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz)
|
||||
return ComplexTensor(tanh_iz_im, -tanh_iz_re)
|
||||
|
||||
|
||||
@register_complex(aten.asinh)
|
||||
def asinh_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
out_dt, (self,) = promote_tensors(self)
|
||||
return torch.log(self + torch.sqrt(self * self + 1)).to(out_dt) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.acosh)
|
||||
def acosh_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
out_dt, (self,) = promote_tensors(self)
|
||||
return torch.log(self + torch.sqrt(self * self - 1)).to(out_dt) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.atanh)
|
||||
def atanh_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
out_dt, (x, y) = promote_tensors(x, y)
|
||||
|
||||
ret = 0.5 * (
|
||||
torch.log(ComplexTensor(1 + x, y)) - torch.log(ComplexTensor(1 - x, -y))
|
||||
)
|
||||
assert isinstance(ret, ComplexTensor)
|
||||
ret_re, ret_im = split_complex_tensor(ret)
|
||||
|
||||
return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.cos)
|
||||
def cos_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
return torch.cosh(ComplexTensor(-y, x)) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.cosh)
|
||||
def cosh_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
out_dt, (x, y) = promote_tensors(x, y)
|
||||
u = torch.cosh(x) * torch.cos(y)
|
||||
v = torch.sinh(x) * torch.sin(y)
|
||||
return ComplexTensor(u.to(out_dt), v.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.sin)
|
||||
def sin_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
sinh_iz = torch.sinh(ComplexTensor(-y, x))
|
||||
assert isinstance(sinh_iz, ComplexTensor)
|
||||
sinh_iz_re, sinh_iz_im = split_complex_tensor(sinh_iz)
|
||||
return ComplexTensor(sinh_iz_im, -sinh_iz_re)
|
||||
|
||||
|
||||
@register_complex(aten.sinh)
|
||||
def sinh_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
out_dt, (x, y) = promote_tensors(x, y)
|
||||
u = torch.sinh(x) * torch.cos(y)
|
||||
v = torch.cosh(x) * torch.sin(y)
|
||||
return ComplexTensor(u.to(out_dt), v.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.tan)
|
||||
def tan_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
tanh_iz = torch.tanh(ComplexTensor(-y, x))
|
||||
assert isinstance(tanh_iz, ComplexTensor)
|
||||
tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz)
|
||||
return ComplexTensor(tanh_iz_im, -tanh_iz_re)
|
||||
|
||||
|
||||
@register_complex(aten.tanh)
|
||||
def tanh_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
out_dt, (x, y) = promote_tensors(x, y)
|
||||
|
||||
_2x = 2 * x
|
||||
_2y = 2 * y
|
||||
_d = torch.cosh(_2x) + torch.cos(_2y)
|
||||
_2xsh = torch.sinh(_2x)
|
||||
|
||||
out_re = _2xsh / _d
|
||||
out_im = torch.sin(_2y) / _d
|
||||
|
||||
return ComplexTensor(out_re.to(out_dt), out_im.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.exp)
|
||||
def exp_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
out_dt, (x, y) = promote_tensors(x, y)
|
||||
ex = torch.exp(x)
|
||||
u = ex * torch.cos(y)
|
||||
v = ex * torch.sin(y)
|
||||
return ComplexTensor(u.to(out_dt), v.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.expm1)
|
||||
def expm1_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
out_dt, (x, y) = promote_tensors(x, y)
|
||||
# TODO (hameerabbasi): The two lines below may have numerical issues
|
||||
ex = torch.exp(x)
|
||||
u = ex * torch.cos(y) - 1
|
||||
v = ex * torch.sin(y)
|
||||
return ComplexTensor(u.to(out_dt), v.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.log)
|
||||
def log_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
out_dt, (self,) = promote_tensors(self)
|
||||
re = torch.log(torch.abs(self))
|
||||
im = torch.angle(self)
|
||||
return ComplexTensor(re, im).to(out_dt) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.log1p)
|
||||
def log1p_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
# TODO (hameerabbasi): The line below may have numerical issues
|
||||
return torch.log(ComplexTensor(x + 1, y)) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.any)
|
||||
def any_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
return torch.any(x, *args, **kwargs) | torch.any(y, *args, **kwargs)
|
||||
|
||||
|
||||
@register_complex(aten.all)
|
||||
def all_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
return torch.any(x, *args, **kwargs) & torch.any(y, *args, **kwargs)
|
||||
|
||||
|
||||
@register_complex(aten.eq)
|
||||
def eq_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor:
|
||||
a_r, a_i = split_complex_arg(self)
|
||||
b_r, b_i = split_complex_arg(rhs)
|
||||
return torch.eq(a_r, b_r, *args, **kwargs) & torch.eq(a_i, b_i, *args, **kwargs)
|
||||
|
||||
|
||||
@register_complex(aten.ne)
|
||||
def ne_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor:
|
||||
a_r, a_i = split_complex_tensor(self)
|
||||
b_r, b_i = split_complex_arg(rhs)
|
||||
return torch.ne(a_r, b_r, *args, **kwargs) | torch.ne(a_i, b_i, *args, **kwargs)
|
||||
|
||||
|
||||
@register_complex(aten.isnan)
|
||||
def isnan_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
re, im = split_complex_tensor(self)
|
||||
return torch.isnan(re) | torch.isnan(im)
|
||||
|
||||
|
||||
@register_complex(aten.isinf)
|
||||
def isinf_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
re, im = split_complex_tensor(self)
|
||||
return torch.isinf(re) | torch.isinf(im)
|
||||
|
||||
|
||||
@register_complex(aten.isfinite)
|
||||
def isfinite_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
re, im = split_complex_tensor(self)
|
||||
return torch.isfinite(re) & torch.isfinite(im)
|
||||
|
||||
|
||||
@register_complex(aten.isclose)
|
||||
def isclose_impl(
|
||||
self: ComplexTensor,
|
||||
rhs: ComplexTensor,
|
||||
rtol=1e-5,
|
||||
atol=1e-8,
|
||||
equal_nan: bool = False,
|
||||
) -> torch.Tensor:
|
||||
abs_diff = torch.abs(self - rhs)
|
||||
abs_other = torch.abs(rhs)
|
||||
basic_condition = abs_diff <= (rtol * abs_other + atol)
|
||||
|
||||
# This is the nontrivial part
|
||||
if equal_nan:
|
||||
a_r, a_i = split_complex_tensor(self)
|
||||
b_r, b_i = split_complex_arg(rhs)
|
||||
|
||||
a_r_nan = torch.isnan(a_r)
|
||||
b_r_nan = torch.isnan(b_r)
|
||||
a_i_nan = torch.isnan(a_i)
|
||||
b_i_nan = torch.isnan(b_i)
|
||||
a_nan = a_r_nan | a_i_nan
|
||||
|
||||
# This logical expression makes sure that the isnan of both the real and imaginary parts
|
||||
# matches (so 1 + nan*i doesn't equal nan + 1*i)
|
||||
equal_nan_condition = ((a_r_nan == b_r_nan) & (a_i_nan == b_i_nan)) & a_nan
|
||||
return basic_condition | equal_nan_condition
|
||||
|
||||
return basic_condition
|
||||
|
||||
|
||||
ERROR_OPS_LIST = [
|
||||
aten.lt,
|
||||
aten.le,
|
||||
aten.gt,
|
||||
aten.ge,
|
||||
aten.amin,
|
||||
aten.amax,
|
||||
aten.clamp,
|
||||
aten.ceil,
|
||||
aten.floor,
|
||||
aten.minimum,
|
||||
aten.maximum,
|
||||
aten.trunc,
|
||||
aten.sign,
|
||||
aten.argmax,
|
||||
aten.argmin,
|
||||
aten.sort,
|
||||
aten.topk,
|
||||
aten.round,
|
||||
aten.fmod,
|
||||
]
|
||||
|
||||
|
||||
ERROR_TYPES = {
|
||||
aten.minimum: RuntimeError,
|
||||
aten.maximum: RuntimeError,
|
||||
aten.argmax: RuntimeError,
|
||||
aten.argmin: RuntimeError,
|
||||
aten.sort: RuntimeError,
|
||||
aten.topk: RuntimeError,
|
||||
}
|
||||
|
||||
|
||||
for err_op in ERROR_OPS_LIST:
|
||||
globals()[_get_func_name(err_op)] = register_error(
|
||||
err_op, ERROR_TYPES.get(err_op, NotImplementedError)
|
||||
)
|
||||
|
||||
del err_op
|
||||
|
||||
|
||||
@register_complex(aten.masked_scatter)
|
||||
def masked_scatter_impl(
|
||||
self: ComplexTensor, mask: torch.Tensor, source: ComplexTensor
|
||||
) -> ComplexTensor:
|
||||
self_r, self_i = split_complex_tensor(self)
|
||||
source_r, source_i = split_complex_arg(source)
|
||||
ret_r = torch.masked_scatter(self_r, mask, source_r)
|
||||
ret_i = torch.masked_scatter(self_i, mask, source_i)
|
||||
|
||||
return ComplexTensor(ret_r, ret_i)
|
||||
|
||||
|
||||
@register_complex(aten.where)
|
||||
def where_impl(mask: torch.Tensor, x: ComplexTensor, y: ComplexTensor) -> ComplexTensor:
|
||||
x_r, x_i = split_complex_arg(x)
|
||||
y_r, y_i = split_complex_arg(y)
|
||||
|
||||
ret_r = torch.where(mask, x_r, y_r)
|
||||
ret_i = torch.where(mask, x_i, y_i)
|
||||
|
||||
return ComplexTensor(ret_r, ret_i)
|
||||
|
||||
|
||||
@register_complex(aten.full_like)
|
||||
def full_like_impl(
|
||||
input: ComplexTensor,
|
||||
fill_value: complex,
|
||||
*args,
|
||||
dtype: torch.dtype | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | ComplexTensor:
|
||||
# Note: Cannot be merged with the cases below due to the `fill_value` argument
|
||||
input_r, input_i = split_complex_tensor(input)
|
||||
if dtype is not None and dtype not in COMPLEX_TO_REAL:
|
||||
return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs)
|
||||
|
||||
if dtype is not None:
|
||||
kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
|
||||
|
||||
fv_r, fv_i = split_complex_arg(fill_value)
|
||||
ret_r = torch.full_like(input_r, fv_r, *args, **kwargs)
|
||||
ret_i = torch.full_like(input_i, fv_i, *args, **kwargs)
|
||||
|
||||
return ComplexTensor(ret_r, ret_i)
|
||||
|
||||
|
||||
def register_like(op: OpType) -> Callable[..., torch.Tensor | ComplexTensor]:
|
||||
def impl(
|
||||
self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs
|
||||
) -> torch.Tensor | ComplexTensor:
|
||||
self_re, self_im = split_complex_tensor(self)
|
||||
|
||||
if dtype is not None and dtype not in COMPLEX_TO_REAL:
|
||||
return op(self_re, *args, dtype=dtype, **kwargs)
|
||||
|
||||
if dtype is not None:
|
||||
kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
|
||||
|
||||
ret_re = op(self_re, *args, **kwargs)
|
||||
ret_im = op(self_im, *args, **kwargs)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
func_name = _get_func_name(op)
|
||||
impl.__name__ = func_name
|
||||
impl.__qualname__ = func_name
|
||||
|
||||
return register_complex(op, impl)
|
||||
|
||||
|
||||
LIKE_OPS_LIST = [
|
||||
aten.empty_like,
|
||||
aten.zeros_like,
|
||||
aten.randn_like,
|
||||
aten.new_zeros,
|
||||
]
|
||||
|
||||
for like_op in LIKE_OPS_LIST:
|
||||
globals()[_get_func_name(like_op)] = register_like(like_op)
|
||||
|
||||
del like_op
|
||||
|
||||
|
||||
@register_complex(aten.cat)
|
||||
def cat_impl(tensors: Sequence[ComplexTensor], dim: int = 0) -> ComplexTensor:
|
||||
tensors_r = []
|
||||
tensors_i = []
|
||||
|
||||
for t in tensors:
|
||||
t_r, t_i = split_complex_arg(t)
|
||||
tensors_r.append(t_r)
|
||||
tensors_i.append(t_i)
|
||||
|
||||
ret_r = torch.cat(tensors_r, dim=dim)
|
||||
ret_i = torch.cat(tensors_i, dim=dim)
|
||||
|
||||
return ComplexTensor(ret_r, ret_i)
|
||||
|
||||
|
||||
@register_complex(aten.sgn)
|
||||
def sgn_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
self_r, self_i = split_complex_tensor(self)
|
||||
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
|
||||
abs_self = torch.abs(ComplexTensor(self_r, self_i))
|
||||
mask = (self_r != 0) | (self_i != 0)
|
||||
masked_sgn = ComplexTensor(
|
||||
(self_r / abs_self).to(out_dt), (self_i / abs_self).to(out_dt)
|
||||
)
|
||||
return torch.where(mask, masked_sgn, 0) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.sqrt)
|
||||
def sqrt_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
self_r, self_i = split_complex_tensor(self)
|
||||
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
|
||||
self = ComplexTensor(self_r, self_i)
|
||||
self_abs_sqrt = torch.sqrt(torch.abs(self))
|
||||
self_half_angle = 0.5 * torch.angle(self)
|
||||
|
||||
ret_r = self_abs_sqrt * torch.cos(self_half_angle)
|
||||
ret_i = self_abs_sqrt * torch.sin(self_half_angle)
|
||||
|
||||
return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.rsqrt)
|
||||
def rsqrt_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
self_r, self_i = split_complex_tensor(self)
|
||||
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
|
||||
self = ComplexTensor(self_r, self_i)
|
||||
self_abs_rsqrt = torch.rsqrt(torch.abs(self))
|
||||
self_neg_half_angle = -0.5 * torch.angle(self)
|
||||
|
||||
ret_r = self_abs_rsqrt * torch.cos(self_neg_half_angle)
|
||||
ret_i = self_abs_rsqrt * torch.sin(self_neg_half_angle)
|
||||
|
||||
return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.addmm)
|
||||
def addmm_impl(
|
||||
input: ComplexTensor,
|
||||
mat1: ComplexTensor,
|
||||
mat2: ComplexTensor,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
beta: complex = 1,
|
||||
alpha: complex = 1,
|
||||
) -> ComplexTensor:
|
||||
ret = beta * input + alpha * torch.mm(mat1, mat2)
|
||||
assert isinstance(ret, ComplexTensor)
|
||||
ret_r, ret_i = split_complex_tensor(ret)
|
||||
if out_dtype is not None:
|
||||
out_dtype = COMPLEX_TO_REAL[out_dtype]
|
||||
ret_r, ret_i = ret_r.to(out_dtype), ret_i.to(out_dtype)
|
||||
return ComplexTensor(ret_r, ret_i)
|
||||
|
||||
|
||||
def elemwise_nonzero(self: ComplexTensor) -> torch.Tensor:
|
||||
re, im = split_complex_tensor(self)
|
||||
return (re != 0) | (im != 0)
|
||||
|
||||
|
||||
def register_nonzero_impl(op: OpType):
|
||||
def nonzero_impl(
|
||||
self: ComplexTensor, other: ComplexTensor, *args, **kwargs
|
||||
) -> torch.Tensor:
|
||||
return op(elemwise_nonzero(self), elemwise_nonzero(other), *args, **kwargs)
|
||||
|
||||
func_name = _get_func_name(op)
|
||||
nonzero_impl.__name__ = func_name
|
||||
nonzero_impl.__qualname__ = func_name
|
||||
|
||||
return register_complex(op, nonzero_impl)
|
||||
|
||||
|
||||
logical_and_impl = register_nonzero_impl(aten.logical_and)
|
||||
logical_or_impl = register_nonzero_impl(aten.logical_or)
|
||||
logical_xor_impl = register_nonzero_impl(aten.logical_xor)
|
||||
|
||||
|
||||
@register_complex(aten.logical_not)
|
||||
def logical_not_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
|
||||
return torch.logical_not(elemwise_nonzero(self), *args, **kwargs)
|
||||
|
||||
|
||||
@register_complex(aten.view_as_real)
|
||||
def view_as_real_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
re, im = split_complex_tensor(self)
|
||||
return torch.stack([re, im], dim=-1)
|
||||
|
||||
|
||||
@register_complex(aten.linalg_vector_norm)
|
||||
def linalg_vector_norm_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
|
||||
return torch.linalg.vector_norm(torch.abs(self), *args, **kwargs)
|
||||
|
||||
|
||||
@register_force_test(aten.copy_)
|
||||
def copy__impl(self: ComplexTensor, src, *args, **kwargs):
|
||||
self_re, self_im = split_complex_tensor(self)
|
||||
src_re, src_im = split_complex_arg(src)
|
||||
|
||||
ret_re = self_re.copy_(src_re, *args, **kwargs)
|
||||
ret_im = self_im.copy_(src_im, *args, **kwargs)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
|
||||
@register_complex(aten._local_scalar_dense)
|
||||
def _local_scalar_dense_impl(self: ComplexTensor, *args, **kwargs) -> complex:
|
||||
x, y = split_complex_tensor(self)
|
||||
u = aten._local_scalar_dense(x, *args, **kwargs)
|
||||
v = aten._local_scalar_dense(y, *args, **kwargs)
|
||||
return complex(u, v)
|
||||
|
||||
|
||||
@register_complex(aten.allclose)
|
||||
def allclose_impl(
|
||||
input: torch.Tensor,
|
||||
other: torch.Tensor,
|
||||
rtol: float = 1e-05,
|
||||
atol: float = 1e-08,
|
||||
equal_nan: bool = False,
|
||||
) -> bool:
|
||||
return torch.all(
|
||||
torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
|
||||
).item() # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.stack)
|
||||
def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor:
|
||||
re_im_tuples = [split_complex_arg(self_i) for self_i in self]
|
||||
u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs)
|
||||
v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs)
|
||||
return ComplexTensor(u, v)
|
||||
|
||||
|
||||
# TODO (hameerabbasi): Not being tested
|
||||
@register_complex(aten._conj_physical)
|
||||
@register_complex(aten.conj_physical)
|
||||
def conj_physical_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
re, im = split_complex_tensor(self)
|
||||
return ComplexTensor(re, -im)
|
||||
|
||||
|
||||
# TODO (hameerabbasi): Not being tested
|
||||
@register_complex(aten._conj)
|
||||
def _conj_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
re, im = split_complex_tensor(self)
|
||||
return ComplexTensor(re, torch._neg_view(im))
|
||||
|
||||
|
||||
@register_complex(aten.index_add)
|
||||
def index_add_impl(
|
||||
self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs
|
||||
) -> ComplexTensor:
|
||||
alpha = kwargs.pop("alpha", None)
|
||||
if alpha is not None:
|
||||
source = source * alpha
|
||||
self_re, self_im = split_complex_arg(self)
|
||||
source_re, source_im = split_complex_arg(source)
|
||||
|
||||
ret_re = self_re.index_add(dim, index, source_re)
|
||||
ret_im = self_im.index_add(dim, index, source_im)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
|
||||
# TODO (hameerabbasi): Not being tested
|
||||
@register_complex(aten.index_add_)
|
||||
def index_add__impl(
|
||||
self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs
|
||||
) -> ComplexTensor:
|
||||
alpha = kwargs.pop("alpha", None)
|
||||
if alpha is not None:
|
||||
source = source * alpha
|
||||
|
||||
self_re, self_im = split_complex_arg(self)
|
||||
source_re, source_im = split_complex_arg(source)
|
||||
|
||||
ret_re = self_re.index_add_(dim, index, source_re)
|
||||
ret_im = self_im.index_add_(dim, index, source_im)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
|
||||
@register_complex(aten.masked_fill)
|
||||
def masked_fill_impl(
|
||||
self: ComplexTensor, mask: torch.Tensor, value: complex
|
||||
) -> ComplexTensor:
|
||||
self_re, self_im = split_complex_arg(self)
|
||||
value_re, value_im = split_complex_arg(value)
|
||||
|
||||
ret_re = self_re.masked_fill(mask, value_re)
|
||||
ret_im = self_im.masked_fill(mask, value_im)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
|
||||
# TODO (hameerabbasi): Not being tested
|
||||
@register_complex(aten.masked_fill_)
|
||||
def masked_fill__impl(
|
||||
self: ComplexTensor, mask: torch.Tensor, value: complex
|
||||
) -> ComplexTensor:
|
||||
self_re, self_im = split_complex_arg(self)
|
||||
value_re, value_im = split_complex_arg(value)
|
||||
|
||||
ret_re = self_re.masked_fill_(mask, value_re)
|
||||
ret_im = self_im.masked_fill_(mask, value_im)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
|
||||
@register_complex(aten.constant_pad_nd)
|
||||
def constant_pad_nd_impl(
|
||||
self: ComplexTensor, pad, value: complex | None = None
|
||||
) -> ComplexTensor:
|
||||
self_re, self_im = split_complex_tensor(self)
|
||||
if value is None:
|
||||
ret_re = aten.constant_pad_nd(self_re, pad)
|
||||
ret_im = aten.constant_pad_nd(self_im, pad)
|
||||
else:
|
||||
value_re, value_im = split_complex_arg(value)
|
||||
ret_re = aten.constant_pad_nd(self_re, pad, value_re)
|
||||
ret_im = aten.constant_pad_nd(self_im, pad, value_im)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
|
||||
@register_complex(aten.var)
|
||||
def var_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
|
||||
self_re, self_im = split_complex_tensor(self)
|
||||
return torch.var(self_re, *args, **kwargs) + torch.var(self_im, *args, **kwargs)
|
||||
|
||||
|
||||
@register_complex(aten.scatter_add)
|
||||
def scatter_add_impl(
|
||||
self: ComplexTensor, dim, index, src: ComplexTensor
|
||||
) -> ComplexTensor:
|
||||
self_re, self_im = split_complex_arg(self)
|
||||
src_re, src_im = split_complex_arg(src)
|
||||
|
||||
ret_re = torch.scatter_add(self_re, dim, index, src_re)
|
||||
ret_im = torch.scatter_add(self_im, dim, index, src_im)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
|
||||
@register_complex(aten.scatter_add_)
|
||||
def scatter_add__impl(
|
||||
self: ComplexTensor, dim, index, src: ComplexTensor
|
||||
) -> ComplexTensor:
|
||||
self_re, self_im = split_complex_arg(self)
|
||||
src_re, src_im = split_complex_arg(src)
|
||||
|
||||
out_re = self_re.scatter_add_(dim, index, src_re)
|
||||
out_im = self_im.scatter_add_(dim, index, src_im)
|
||||
|
||||
return ComplexTensor(out_re, out_im)
|
||||
|
||||
|
||||
@register_complex(aten.index_put_)
|
||||
def index_put__impl(
|
||||
self: ComplexTensor,
|
||||
indices: tuple[torch.Tensor, ...],
|
||||
values: ComplexTensor,
|
||||
accumulate: bool = False,
|
||||
) -> ComplexTensor:
|
||||
self_re, self_im = split_complex_arg(self)
|
||||
values_re, values_im = split_complex_arg(values)
|
||||
|
||||
out_re = self_re.index_put_(indices, values_re, accumulate=accumulate)
|
||||
out_im = self_im.index_put_(indices, values_im, accumulate=accumulate)
|
||||
|
||||
return ComplexTensor(out_re, out_im)
|
||||
|
||||
|
||||
@register_complex(aten.tanh_backward)
|
||||
def tanh_backward(out_grad: torch.Tensor, y: torch.Tensor):
|
||||
return out_grad * (1.0 - y * y).conj_physical()
|
||||
|
||||
|
||||
@register_complex(aten.diagonal_backward)
|
||||
def diagonal_backward(
|
||||
grad_output: torch.Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int
|
||||
):
|
||||
grad_input = grad_output.new_zeros(input_sizes)
|
||||
return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)
|
||||
|
||||
|
||||
def _dt_to_real(dt: torch.dtype | Any) -> torch.dtype | Any:
|
||||
if not isinstance(dt, torch.dtype):
|
||||
return dt
|
||||
|
||||
return COMPLEX_TO_REAL[dt]
|
||||
|
||||
|
||||
def register_to_impl(op: OpType):
|
||||
"""Register an op similar to `aten.to`, but may have different signatures."""
|
||||
|
||||
def impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor | ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
try:
|
||||
args = tuple(_dt_to_real(a) for a in args)
|
||||
kwargs = {k: _dt_to_real(v) for k, v in kwargs.items()}
|
||||
except KeyError:
|
||||
return op(x, *args, **kwargs)
|
||||
|
||||
return ComplexTensor(op(x, *args, **kwargs), op(y, *args, **kwargs))
|
||||
|
||||
func_name = _get_func_name(op)
|
||||
impl.__name__ = func_name
|
||||
impl.__qualname__ = func_name
|
||||
|
||||
return register_complex(op, impl)
|
||||
|
||||
|
||||
to_impl = register_to_impl(aten.to)
|
||||
_to_copy_impl = register_to_impl(aten._to_copy)
|
||||
torch/_subclasses/complex_tensor/_ops/common.py (new file, 317 lines)
@ -0,0 +1,317 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any, overload, TypeAlias
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._decomp import get_decompositions
|
||||
from torch._ops import OpOverload, OpOverloadPacket
|
||||
from torch._refs import is_complex as _is_complex
|
||||
from torch.types import Number
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
from .._core import ComplexTensor
|
||||
|
||||
|
||||
OpType: TypeAlias = OpOverloadPacket | OpOverload
|
||||
|
||||
TableType: TypeAlias = dict[OpType, Callable]
|
||||
|
||||
# Mapping from ops to implementations
|
||||
COMPLEX_OPS_TABLE: TableType = {}
|
||||
|
||||
COMPLEX_TO_REAL = {
|
||||
torch.complex128: torch.float64,
|
||||
torch.complex64: torch.float32,
|
||||
torch.complex32: torch.float16,
|
||||
}
|
||||
|
||||
REAL_TO_COMPLEX = {v: k for k, v in COMPLEX_TO_REAL.items()}
|
||||
|
||||
# Used to promote dtypes in `promote_real_cpu_tensors`
|
||||
PROMOTE_TYPES = {
|
||||
torch.float16: torch.float32,
|
||||
torch.bfloat16: torch.float32,
|
||||
torch.complex32: torch.complex64,
|
||||
}
|
||||
|
||||
|
||||
def is_complex_tensor(obj: Any, /) -> TypeIs[ComplexTensor]:
|
||||
r"""Returns True if the input is a ComplexTensor, else False
|
||||
|
||||
Args:
|
||||
a: any input
|
||||
|
||||
Examples:
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> from torch.complex import ComplexTensor
|
||||
>>> data = torch.zeros((3, 2), dtype=torch.complex64)
|
||||
>>> ct = ComplexTensor.from_interleaved(data)
|
||||
>>> is_complex_tensor(ct)
|
||||
True
|
||||
"""
|
||||
return isinstance(obj, ComplexTensor)
|
||||
|
||||
|
||||
@overload
|
||||
def promote_tensors(
|
||||
*tensors: ComplexTensor,
|
||||
) -> tuple[torch.dtype, tuple[ComplexTensor, ...]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def promote_tensors(
|
||||
*tensors: Tensor,
|
||||
) -> tuple[torch.dtype, tuple[Tensor, ...]]: ...
|
||||
|
||||
|
||||
def promote_tensors(
|
||||
*tensors: Tensor | ComplexTensor,
|
||||
) -> tuple[torch.dtype, tuple[Tensor | ComplexTensor, ...]]:
|
||||
"""
|
||||
Promotes all tensors to a common dtype.
|
||||
Additionally promotes CPU tensors to at least `float32`.
|
||||
"""
|
||||
tensor = next(t for t in tensors if isinstance(t, Tensor))
|
||||
out_dt = tensor.dtype
|
||||
for t in tensors:
|
||||
if isinstance(t, Tensor):
|
||||
out_dt = torch.promote_types(out_dt, t.dtype)
|
||||
|
||||
prom_dt = PROMOTE_TYPES.get(out_dt, out_dt)
|
||||
return out_dt, tuple(
|
||||
t.to(prom_dt) if isinstance(t, Tensor) else torch.asarray(t, dtype=prom_dt)
|
||||
for t in tensors
|
||||
)
|
||||
|
||||
|
||||
def register_complex(
|
||||
op: OpType,
|
||||
func_impl: Callable | None = None,
|
||||
):
|
||||
"""Decorator to register an implementation for some ops in some dispatch tables"""
|
||||
|
||||
def inner(func):
|
||||
if COMPLEX_OPS_TABLE.get(op, func) is not func:
|
||||
raise RuntimeError(f"Attempted to register multiple functions for {op}")
|
||||
COMPLEX_OPS_TABLE[op] = func
|
||||
return func
|
||||
|
||||
if func_impl is None:
|
||||
return inner
|
||||
|
||||
return inner(func_impl)
|
||||
|
||||
|
||||
FORCE_TEST_LIST: list[OpType] = []
|
||||
|
||||
|
||||
def register_force_test(op: OpType, *args, **kwargs):
|
||||
"""Will attempt to test these ops even if they err on "normal" inputs"""
|
||||
FORCE_TEST_LIST.append(op)
|
||||
return register_complex(op, *args, **kwargs)
|
||||
|
||||
|
||||
DECOMPOSITIONS = get_decompositions(list(torch.ops.aten)) # type: ignore[no-matching-overload]
|
||||
|
||||
|
||||
def lookup_complex(func: OpOverload, *args, **kwargs) -> Callable | None:
|
||||
"""
|
||||
Lookup an impl from the table.
|
||||
|
||||
Try the particular overload first, then the overload packet.
|
||||
|
||||
If nothing is found, try the decompositions with both.
|
||||
"""
|
||||
return COMPLEX_OPS_TABLE.get(
|
||||
func,
|
||||
COMPLEX_OPS_TABLE.get(
|
||||
func.overloadpacket,
|
||||
DECOMPOSITIONS.get(func, DECOMPOSITIONS.get(func.overloadpacket)),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def is_complex(x: Any, /) -> bool:
|
||||
"""Utility to detect if a given object is (known) to be complex."""
|
||||
return (isinstance(x, Tensor) and _is_complex(x)) or isinstance(x, complex)
|
||||
|
||||
|
||||
@overload
|
||||
def split_complex_arg(
|
||||
arg: Tensor | ComplexTensor,
|
||||
) -> tuple[Tensor, Tensor]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def split_complex_arg(
|
||||
arg: complex | Number,
|
||||
) -> tuple[Number, Number]: ...
|
||||
|
||||
|
||||
def split_complex_arg(
|
||||
arg: Tensor | ComplexTensor | complex | Number,
|
||||
) -> tuple[Tensor, Tensor] | tuple[Number, Number]:
|
||||
"""
|
||||
Split a complex argument into a real/imaginary component.
|
||||
|
||||
If real, use zero for the imaginary part.
|
||||
"""
|
||||
if isinstance(arg, ComplexTensor):
|
||||
return split_complex_tensor(arg)
|
||||
if isinstance(arg, Tensor):
|
||||
if is_complex(arg):
|
||||
return arg.real, arg.imag
|
||||
return arg, torch.zeros_like(arg)
|
||||
# TODO (hameerabbasi): Should there be a `torch.SymComplex`?
|
||||
if isinstance(arg, complex):
|
||||
return arg.real, arg.imag
|
||||
if isinstance(arg, float | torch.SymFloat):
|
||||
return arg, 0.0
|
||||
if isinstance(arg, int | torch.SymInt):
|
||||
return arg, 0
|
||||
if isinstance(arg, bool | torch.SymBool):
|
||||
return arg, False
|
||||
raise TypeError(f"Expected tensor or number got, {type(arg)}")
|
||||
|
||||
|
||||
def split_complex_tensor(complex_tensor: ComplexTensor) -> tuple[Tensor, Tensor]:
|
||||
"""Split a ComplexTensor into its real and imaginary parts."""
|
||||
return complex_tensor.re, complex_tensor.im
|
||||
|
||||
|
||||
def complex_to_real_dtype(dtype: torch.dtype) -> torch.dtype:
|
||||
"""Convert a complex dtype to the dtype of its real part. Return other dtypes as-is."""
|
||||
return COMPLEX_TO_REAL.get(dtype, dtype)
|
||||
|
||||
|
||||
def _get_op_name(op: OpType) -> str:
|
||||
"""Get the op name from the op."""
|
||||
if isinstance(op, OpOverload):
|
||||
op = op.overloadpacket
|
||||
return str(op).split(".", 1)[1]
|
||||
|
||||
|
||||
def _get_func_name(op: OpType) -> str:
|
||||
"""Get the name of the implementation function from the op."""
|
||||
return f"{_get_op_name(op)}_impl"
|
||||
|
||||
|
||||
def register_error(op: OpType, exc_type: type[Exception] = NotImplementedError):
|
||||
msg = f"`aten.{_get_op_name(op)}` not implemented for `{ComplexTensor.__name__}`."
|
||||
|
||||
def ordered_impl(*args, **kwargs):
|
||||
raise exc_type(msg)
|
||||
|
||||
func_name = _get_func_name(op)
|
||||
ordered_impl.__name__ = func_name
|
||||
ordered_impl.__qualname__ = func_name
|
||||
|
||||
return register_force_test(op, ordered_impl)
|
||||
|
||||
|
||||
def register_binary_nonlinear(op: OpType) -> Callable:
|
||||
"""Register a "multiplication-style" op, e.g. aten.mul, aten.mm, ..."""
|
||||
|
||||
def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
|
||||
a_r, a_i = split_complex_arg(lhs)
|
||||
b_r, b_i = split_complex_arg(rhs)
|
||||
out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i)
|
||||
real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs)
|
||||
imag = op(a_r, b_i, *args, **kwargs) + op(a_i, b_r, *args, **kwargs)
|
||||
return ComplexTensor(real.to(out_dt), imag.to(out_dt))
|
||||
|
||||
func_name = _get_func_name(op)
|
||||
impl.__name__ = func_name
|
||||
impl.__qualname__ = func_name
|
||||
|
||||
return register_complex(op, impl)
|
||||
|
||||
|
||||
def register_simple(op: OpType):
|
||||
"""Register an op which can be applied independently to the real and complex parts to get the result."""
|
||||
|
||||
def impl(
|
||||
self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs
|
||||
) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
if dtype is not None and dtype not in COMPLEX_TO_REAL:
|
||||
raise RuntimeError(
|
||||
"Non-complex `dtype` specified, please write custom impl."
|
||||
)
|
||||
|
||||
if dtype in COMPLEX_TO_REAL:
|
||||
assert dtype is not None
|
||||
kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
|
||||
|
||||
u = op(x, *args, **kwargs)
|
||||
v = op(y, *args, **kwargs)
|
||||
|
||||
u_flat, u_spec = tree_flatten(u)
|
||||
v_flat, v_spec = tree_flatten(v)
|
||||
assert u_spec == v_spec
|
||||
out_flat = [
|
||||
ComplexTensor(ui, vi) for ui, vi in zip(u_flat, v_flat, strict=False)
|
||||
]
|
||||
return tree_unflatten(out_flat, u_spec)
|
||||
|
||||
func_name = _get_func_name(op)
|
||||
impl.__name__ = func_name
|
||||
impl.__qualname__ = func_name
|
||||
|
||||
return register_complex(op, impl)
|
||||
|
||||
|
||||
def _as_complex_tensor(arg: Tensor | Any) -> Tensor | ComplexTensor | Any:
|
||||
"""Convert a Tensor with complex dtypes to a ComplexTensor. Pass along other args as-is."""
|
||||
if (
|
||||
not isinstance(arg, ComplexTensor)
|
||||
and isinstance(arg, Tensor)
|
||||
and arg.dtype in COMPLEX_TO_REAL
|
||||
):
|
||||
return ComplexTensor.from_interleaved(arg)
|
||||
return arg
|
||||
|
||||
|
||||
def _as_interleaved(arg: ComplexTensor | Any) -> Tensor | Any:
|
||||
"""Convert a ComplexTensor to a Tensor with a complex dtype. Pass other arguments as-is."""
|
||||
if isinstance(arg, ComplexTensor):
|
||||
return arg.as_interleaved()
|
||||
return arg
|
||||
|
||||
|
||||
class ComplexTensorMode(TorchDispatchMode):
|
||||
_compile: bool
|
||||
|
||||
""" A TorchDispatchMode to replace any Tensor that has a complex dtype with a ComplexTensor for the computation. """
|
||||
|
||||
def __init__(self, _dispatch_key=None, *, _compile: bool = False):
|
||||
"""Initialize a ComplexTensorMode.
|
||||
|
||||
Args:
|
||||
_dispatch_key: passed on to TorchDispatchMode
|
||||
_compile: Compile the op before the computation
|
||||
"""
|
||||
super().__init__(_dispatch_key)
|
||||
self._compile = _compile
|
||||
|
||||
def __torch_dispatch__(
|
||||
self,
|
||||
func: OpOverload,
|
||||
types: tuple[type],
|
||||
args: tuple = (),
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
# TODO (hameerabbasi): Test perf with `_compile` set to `True`
|
||||
if self._compile:
|
||||
func = torch.compile(func) # type: ignore[bad-assignment]
|
||||
|
||||
args = tree_map(_as_complex_tensor, args)
|
||||
kwargs = tree_map(_as_complex_tensor, kwargs)
|
||||
|
||||
return tree_map(_as_interleaved, func(*args, **kwargs))
|
||||
torch/_subclasses/complex_tensor/_ops/prims.py (new file, 34 lines)
@ -0,0 +1,34 @@
import torch

from .._core import ComplexTensor
from .common import (
    complex_to_real_dtype,
    register_complex,
    register_force_test,
    split_complex_tensor,
)


prims = torch.ops.prims
aten = torch.ops.aten


# TODO (hameerabbasi): Not being tested
@register_force_test(prims.convert_element_type)
def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTensor:
    dtype = complex_to_real_dtype(dtype)
    u, v = split_complex_tensor(x)
    u_out = prims.convert_element_type(u, dtype)
    v_out = prims.convert_element_type(v, dtype)

    return ComplexTensor(u_out, v_out)


@register_complex(prims.conj_physical)
def conj_physical_impl(self: ComplexTensor) -> ComplexTensor:
    return aten._conj_physical(self)


@register_complex(prims.conj)
def conj_impl(self: ComplexTensor) -> ComplexTensor:
    return aten._conj(self)
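As a sketch of the registration pattern these prim implementations follow (illustrative only, not part of the diff): a component-wise op can be registered against the same table. In this PR such ops would normally go through the `register_simple` helper in `common.py`, so hand-registering `aten.neg` as below would conflict with any existing entry for it.

@register_complex(aten.neg)
def neg_impl(self: ComplexTensor) -> ComplexTensor:
    # Negation distributes over the real and imaginary components.
    u, v = split_complex_tensor(self)
    return ComplexTensor(aten.neg(u), aten.neg(v))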
@ -138,7 +138,7 @@ inline void PyErr_SetString(PyObject* type, const std::string& message) {
|
||||
throw; \
|
||||
} \
|
||||
} \
|
||||
catch (const std::exception& e) { \
|
||||
catch (const std::exception&) { \
|
||||
torch::translate_exception_to_python(std::current_exception()); \
|
||||
return retval; \
|
||||
}
|
||||
|
||||
@ -81,7 +81,7 @@ c10::intrusive_ptr<Backend> ProcessGroup::getBackend(
|
||||
ProcessGroup::BackendType backendType{ProcessGroup::BackendType::UNDEFINED};
|
||||
try {
|
||||
backendType = deviceTypeToBackendType_.at(deviceType);
|
||||
} catch (const std::out_of_range& e) {
|
||||
} catch (const std::out_of_range&) {
|
||||
TORCH_CHECK(
|
||||
false, "No backend type associated with device type ", deviceType);
|
||||
}
|
||||
|
||||
@ -246,7 +246,7 @@ class UvTcpServer : public UvTcpSocket {
|
||||
uv_err_name(uv_res),
|
||||
uv_strerror(uv_res)));
|
||||
res->cacheSocketPort();
|
||||
} catch (std::exception& ex) {
|
||||
} catch (std::exception&) {
|
||||
res->close();
|
||||
throw;
|
||||
}
|
||||
@ -322,7 +322,7 @@ class UvTcpServer : public UvTcpSocket {
|
||||
uv_err_name(uv_res),
|
||||
uv_strerror(uv_res)));
|
||||
res->cacheSocketPort();
|
||||
} catch (std::exception& ex) {
|
||||
} catch (std::exception&) {
|
||||
res->close();
|
||||
throw;
|
||||
}
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
|
||||
|
||||
#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
|
||||
res.setStatus(200);
|
||||
}};
|
||||
|
||||
RegisterHandler frTracehandler(
|
||||
"fr_trace_json",
|
||||
[](const Request&, Response& res) {
|
||||
auto trace = ::c10d::dump_fr_trace_json(true, true);
|
||||
res.setContent(std::move(trace), "application/json");
|
||||
res.setStatus(200);
|
||||
});
|
||||
|
||||
} // namespace
|
||||
|
||||
void registerHandler(const std::string& name, HandlerFunc f) {
|
||||
|
||||
@ -18,6 +18,14 @@ class TORCH_API Request {
|
||||
virtual const std::string& body() const = 0;
|
||||
|
||||
virtual const std::multimap<std::string, std::string>& params() const = 0;
|
||||
|
||||
std::string getParam(const std::string& key) const {
|
||||
auto it = params().find(key);
|
||||
if (it != params().end()) {
|
||||
return it->second;
|
||||
}
|
||||
return "";
|
||||
}
|
||||
};
|
||||
|
||||
// Response represents a response to the handler. This conceptually maps to an
|
||||
|
||||
@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
|
||||
TORCH_CHECK(
|
||||
server_.bind_to_port(hostOrFile, 80),
|
||||
fmt::format("Error binding to {}", hostOrFile));
|
||||
} else if (port == 0) {
|
||||
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
|
||||
port_ = server_.bind_to_any_port(hostOrFile);
|
||||
TORCH_CHECK(
|
||||
port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port));
|
||||
} else {
|
||||
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
|
||||
TORCH_CHECK(
|
||||
server_.bind_to_port(hostOrFile, port),
|
||||
fmt::format("Error binding to {}:{}", hostOrFile, port));
|
||||
port_ = port;
|
||||
}
|
||||
|
||||
serverThread_ = std::thread([this]() {
|
||||
|
||||
@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
|
||||
|
||||
void shutdown();
|
||||
|
||||
int port() {
|
||||
return port_;
|
||||
}
|
||||
|
||||
private:
|
||||
httplib::Server server_;
|
||||
std::thread serverThread_;
|
||||
int port_;
|
||||
};
|
||||
|
||||
} // namespace c10d::control_plane
|
||||
|
||||
@ -46,6 +46,7 @@
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <pybind11/chrono.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
|
||||
#include <torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp>
|
||||
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
|
||||
@ -4209,7 +4210,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
|
||||
}),
|
||||
py::arg("host_or_file"),
|
||||
py::arg("port") = -1)
|
||||
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
|
||||
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown)
|
||||
.def_property_readonly(
|
||||
"port", &::c10d::control_plane::WorkerServer::port);
|
||||
|
||||
module.def(
|
||||
"_get_handler",
|
||||
@ -4225,6 +4228,25 @@ such as `dist.all_reduce(tensor, async_op=True)`.
|
||||
Returns the handler with the specified name.
|
||||
)");
|
||||
|
||||
module.def(
|
||||
"_register_handler",
|
||||
[](const std::string& name, const py::function& handler) {
|
||||
::c10d::control_plane::registerHandler(
|
||||
name,
|
||||
[handler](
|
||||
const ::c10d::control_plane::Request& req,
|
||||
::c10d::control_plane::Response& res) {
|
||||
py::gil_scoped_acquire acquire;
|
||||
handler(std::ref(req), std::ref(res));
|
||||
});
|
||||
},
|
||||
|
||||
py::arg("name"),
|
||||
py::arg("handler"),
|
||||
R"(
|
||||
Registers a handler by name.
|
||||
)");
|
||||
|
||||
module.def(
|
||||
"_get_handler_names",
|
||||
&::c10d::control_plane::getHandlerNames,
|
||||
@ -4242,12 +4264,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
|
||||
// Default constructor.
|
||||
.def(py::init<>())
|
||||
.def("body", &::c10d::control_plane::Request::body)
|
||||
.def("params", &::c10d::control_plane::Request::params);
|
||||
.def("get_param", &::c10d::control_plane::Request::getParam);
|
||||
|
||||
py::class_<
|
||||
::c10d::control_plane::Response,
|
||||
std::shared_ptr<::c10d::control_plane::Response>,
|
||||
PythonResponse>(
|
||||
py::class_<::c10d::control_plane::Response, PythonResponse>(
|
||||
module,
|
||||
"_Response",
|
||||
R"(
|
||||
|
||||
@ -353,7 +353,7 @@ static PyObject* NodeBase__update_args_kwargs(
|
||||
Py_CLEAR(node->_kwargs);
|
||||
node->_kwargs = map_aggregate(args[1], visit_fn);
|
||||
Py_RETURN_NONE;
|
||||
} catch (const PythonError& e) {
|
||||
} catch (const PythonError&) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
@ -397,7 +397,7 @@ static PyObject* NodeBase__replace_input_with(
|
||||
|
||||
PyObject* update_args[2] = {new_args.get(), new_kwargs.get()};
|
||||
return NodeBase__update_args_kwargs(self, update_args, 2);
|
||||
} catch (const PythonError& e) {
|
||||
} catch (const PythonError&) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
@ -802,7 +802,7 @@ static PyObject* py_map_aggregate(
|
||||
// args[0]: aggregate, args[1]: callable fn
|
||||
return map_aggregate(
|
||||
args[0], [fn](PyObject* a) { return PyObject_CallOneArg(fn, a); });
|
||||
} catch (const PythonError& e) {
|
||||
} catch (const PythonError&) {
|
||||
return nullptr; // error should already be set
|
||||
}
|
||||
}
|
||||
@ -824,7 +824,7 @@ static PyObject* py_map_arg(
|
||||
}
|
||||
return Py_NewRef(a);
|
||||
});
|
||||
} catch (const PythonError& e) {
|
||||
} catch (const PythonError&) {
|
||||
return nullptr; // error should already be set
|
||||
}
|
||||
}
|
||||
|
||||
@ -117,7 +117,7 @@ struct type_caster<torch::jit::IValue> {
|
||||
try {
|
||||
value = torch::jit::toTypeInferredIValue(src);
|
||||
return true;
|
||||
} catch (std::exception& e) {
|
||||
} catch (std::exception&) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -142,7 +142,7 @@ struct type_caster<torch::jit::Symbol> {
|
||||
std::string src_str;
|
||||
try {
|
||||
src_str = py::cast<std::string>(src);
|
||||
} catch (std::exception& e) {
|
||||
} catch (std::exception&) {
|
||||
return false;
|
||||
}
|
||||
value = torch::jit::Symbol::fromQualString(src_str);
|
||||
|
||||
@ -285,7 +285,7 @@ struct FromImpl<torch::headeronly::HeaderOnlyArrayRef<T>> {
|
||||
torch_list_push_back(new_list_handle, from(elem)));
|
||||
}
|
||||
return from(new_list_handle);
|
||||
} catch (const std::runtime_error& e) {
|
||||
} catch (const std::runtime_error&) {
|
||||
if (new_list_handle != nullptr) {
|
||||
// clean up memory if an error was thrown
|
||||
TORCH_ERROR_CODE_CHECK(torch_delete_list(new_list_handle));
|
||||
@ -553,7 +553,7 @@ struct ToImpl<std::vector<T>> {
|
||||
}
|
||||
TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle));
|
||||
return result;
|
||||
} catch (const std::runtime_error& e) {
|
||||
} catch (const std::runtime_error&) {
|
||||
// clean up memory if an exception is thrown, and rethrow
|
||||
TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle));
|
||||
throw;
|
||||
|
||||
torch/distributed/debug/__init__.py (new file, 82 lines)
@ -0,0 +1,82 @@
import logging
import multiprocessing
import socket

# import for registration side effect
import torch.distributed.debug._handlers  # noqa: F401
from torch._C._distributed_c10d import _WorkerServer
from torch.distributed.debug._store import get_rank, tcpstore_client


__all__ = [
    "start_debug_server",
    "stop_debug_server",
]

logger: logging.Logger = logging.getLogger(__name__)

_WORKER_SERVER: _WorkerServer | None = None
_DEBUG_SERVER_PROC: multiprocessing.Process | None = None


def start_debug_server(port: int = 25999, worker_port: int = 0) -> None:
    """
    Start the debug server stack on all workers. The frontend debug server is
    only started on rank 0, while the per-rank worker servers are started on
    all ranks.

    This server provides an HTTP frontend that allows for debugging slow and
    deadlocked distributed jobs across all ranks simultaneously. It collects
    data such as stack traces, FlightRecorder events, and performance profiles.

    WARNING: This is intended to be used only in trusted network environments.
    The debug server is not designed to be secure and should not be exposed to
    the public internet. See SECURITY.md for more details.

    WARNING: This is an experimental feature and may change at any time.

    Args:
        port (int): The port to start the frontend debug server on.
        worker_port (int): The port to start the worker server on. Defaults to 0, which
            will cause the worker server to bind to an ephemeral port.
    """
    global _WORKER_SERVER, _DEBUG_SERVER_PROC

    assert _WORKER_SERVER is None, "debug server already started"
    assert _DEBUG_SERVER_PROC is None, "debug server already started"

    logger.info("Starting debug server on port %d", port)

    store = tcpstore_client()

    _WORKER_SERVER = _WorkerServer("::", worker_port)

    RANK = get_rank()
    store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}")

    from torch.distributed.debug._frontend import main

    if RANK == 0:
        _DEBUG_SERVER_PROC = multiprocessing.Process(
            target=main, args=(port,), daemon=True
        )
        _DEBUG_SERVER_PROC.start()


def stop_debug_server() -> None:
    """
    Shut down the per-rank worker server and stop the frontend debug server process.
    """
    global _WORKER_SERVER, _DEBUG_SERVER_PROC

    assert _DEBUG_SERVER_PROC is not None
    assert _WORKER_SERVER is not None

    logger.info("Stopping debug server")

    _DEBUG_SERVER_PROC.terminate()
    _WORKER_SERVER.shutdown()
    _DEBUG_SERVER_PROC.join()

    _WORKER_SERVER = None
    _DEBUG_SERVER_PROC = None
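A minimal usage sketch for the two entry points above, assuming the job is launched with torchrun (which sets the RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT variables that torch.distributed.debug._store reads); the script name and sleep are illustrative.

# torchrun --nproc-per-node=2 debug_demo.py
import time

import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server

dist.init_process_group("gloo")
start_debug_server(port=25999)  # frontend is served from rank 0 only

# ... training loop; browse http://<rank0-host>:25999 while it runs ...
time.sleep(60)

stop_debug_server()
dist.destroy_process_group()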
torch/distributed/debug/_frontend.py (new file, 353 lines)
@ -0,0 +1,353 @@
import json
|
||||
import logging
|
||||
import socket
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import requests
|
||||
from jinja2 import DictLoader, Environment
|
||||
|
||||
from torch.distributed.debug._store import get_world_size, tcpstore_client
|
||||
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fetch_all(
|
||||
endpoint: str, args: str = ""
|
||||
) -> tuple[list[str], Iterator[requests.Response]]:
|
||||
store = tcpstore_client()
|
||||
keys = [f"rank{r}" for r in range(get_world_size())]
|
||||
addrs = store.multi_get(keys)
|
||||
addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs]
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
resps = executor.map(requests.post, addrs)
|
||||
|
||||
return addrs, resps
|
||||
|
||||
|
||||
def format_json(blob: str):
|
||||
parsed = json.loads(blob)
|
||||
return json.dumps(parsed, indent=2)
|
||||
|
||||
|
||||
templates = {
|
||||
"base.html": """
|
||||
<!doctype html>
|
||||
<head>
|
||||
<title>{% block title %}{% endblock %} - PyTorch Distributed</title>
|
||||
<link rel="shortcut icon" type="image/x-icon" href="https://pytorch.org/favicon.ico?">
|
||||
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
font-family:
|
||||
-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,
|
||||
"Helvetica Neue",Arial,"Noto Sans",sans-serif,"Apple Color Emoji",
|
||||
"Segoe UI Emoji","Segoe UI Symbol","Noto Color Emoji";
|
||||
font-size: 1rem;
|
||||
font-weight: 400;
|
||||
line-height: 1.5;
|
||||
color: #212529;
|
||||
text-align: left;
|
||||
background-color: #fff;
|
||||
}
|
||||
h1, h2, h3, h4, h5, h6, .h1, .h2, .h3, .h4, .h5, .h6 {
|
||||
margin-bottom: .5rem;
|
||||
font-weight: 500;
|
||||
line-height: 1.2;
|
||||
}
|
||||
nav {
|
||||
background-color: rgba(0, 0, 0, 0.17);
|
||||
padding: 10px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 16px;
|
||||
justify-content: flex-start;
|
||||
}
|
||||
nav h1 {
|
||||
display: inline-block;
|
||||
margin: 0;
|
||||
}
|
||||
nav a {
|
||||
margin: 0 8px;
|
||||
}
|
||||
section {
|
||||
max-width: 1280px;
|
||||
padding: 16px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
pre {
|
||||
white-space: pre-wrap;
|
||||
max-width: 100%;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<nav>
|
||||
<h1>Torch Distributed Debug Server</h1>
|
||||
|
||||
<a href="/">Home</a> <!--@lint-ignore-->
|
||||
<a href="/stacks">Python Stack Traces</a> <!--@lint-ignore-->
|
||||
<a href="/fr_trace">FlightRecorder</a> <!--@lint-ignore-->
|
||||
<a href="/fr_trace_nccl">FlightRecorder NCCL</a> <!--@lint-ignore-->
|
||||
<a href="/profile">torch profiler</a> <!--@lint-ignore-->
|
||||
</nav>
|
||||
|
||||
<section class="content">
|
||||
{% block header %}{% endblock %}
|
||||
{% block content %}{% endblock %}
|
||||
</section>
|
||||
""",
|
||||
"index.html": """
|
||||
{% extends "base.html" %}
|
||||
{% block header %}
|
||||
<h1>{% block title %}Index{% endblock %}</h1>
|
||||
{% endblock %}
|
||||
{% block content %}
|
||||
Hi
|
||||
{% endblock %}
|
||||
""",
|
||||
"raw_resp.html": """
|
||||
{% extends "base.html" %}
|
||||
{% block header %}
|
||||
<h1>{% block title %}{{title}}{% endblock %}</h1>
|
||||
{% endblock %}
|
||||
{% block content %}
|
||||
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
|
||||
<h2>Rank {{ i }}: {{ addr }}</h2>
|
||||
{% if resp.status_code != 200 %}
|
||||
<p>Failed to fetch: status={{ resp.status_code }}</p>
|
||||
<pre>{{ resp.text }}</pre>
|
||||
{% else %}
|
||||
<pre>{{ resp.text }}</pre>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endblock %}
|
||||
""",
|
||||
"json_resp.html": """
|
||||
{% extends "base.html" %}
|
||||
{% block header %}
|
||||
<h1>{% block title %}{{ title }}{% endblock %}</h1>
|
||||
{% endblock %}
|
||||
{% block content %}
|
||||
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
|
||||
<h2>Rank {{ i }}: {{ addr }}</h2>
|
||||
{% if resp.status_code != 200 %}
|
||||
<p>Failed to fetch: status={{ resp.status_code }}</p>
|
||||
<pre>{{ resp.text }}</pre>
|
||||
{% else %}
|
||||
<pre>{{ format_json(resp.text) }}</pre>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endblock %}
|
||||
""",
|
||||
"profile.html": """
|
||||
{% extends "base.html" %}
|
||||
{% block header %}
|
||||
<h1>{% block title %}torch.profiler{% endblock %}</h1>
|
||||
{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<form action="/profile" method="get">
|
||||
<label for="duration">Duration (seconds):</label>
|
||||
<input type="number" id="duration" name="duration" value="{{ duration }}" min="1" max="60">
|
||||
<input type="submit" value="Submit">
|
||||
</form>
|
||||
|
||||
<script>
|
||||
function stringToArrayBuffer(str) {
|
||||
const encoder = new TextEncoder();
|
||||
return encoder.encode(str).buffer;
|
||||
}
|
||||
async function openPerfetto(data) {
|
||||
const ui = window.open('https://ui.perfetto.dev/#!/');
|
||||
if (!ui) { alert('Popup blocked. Allow popups for this page and click again.'); return; }
|
||||
|
||||
// Perfetto readiness handshake: PING until we receive PONG
|
||||
await new Promise((resolve, reject) => {
|
||||
const onMsg = (e) => {
|
||||
if (e.source === ui && e.data === 'PONG') {
|
||||
window.removeEventListener('message', onMsg);
|
||||
clearInterval(pinger);
|
||||
resolve();
|
||||
}
|
||||
};
|
||||
window.addEventListener('message', onMsg);
|
||||
const pinger = setInterval(() => { try { ui.postMessage('PING', '*'); } catch (_e) {} }, 250);
|
||||
setTimeout(() => { clearInterval(pinger); window.removeEventListener('message', onMsg); reject(); }, 20000);
|
||||
}).catch(() => { alert('Perfetto UI did not respond. Try again.'); return; });
|
||||
|
||||
ui.postMessage({
|
||||
perfetto: {
|
||||
buffer: stringToArrayBuffer(JSON.stringify(data)),
|
||||
title: "torch profiler",
|
||||
fileName: "trace.json",
|
||||
}
|
||||
}, '*');
|
||||
}
|
||||
</script>
|
||||
|
||||
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
|
||||
<h2>Rank {{ i }}: {{ addr }}</h2>
|
||||
{% if resp.status_code != 200 %}
|
||||
<p>Failed to fetch: status={{ resp.status_code }}</p>
|
||||
<pre>{{ resp.text }}</pre>
|
||||
{% else %}
|
||||
<script>
|
||||
function run{{ i }}() {
|
||||
var data = {{ resp.text | safe }};
|
||||
openPerfetto(data);
|
||||
}
|
||||
</script>
|
||||
|
||||
<button onclick="run{{ i }}()">View {{ i }}</button>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endblock %}
|
||||
""",
|
||||
}
|
||||
|
||||
|
||||
class _IPv6HTTPServer(ThreadingHTTPServer):
|
||||
address_family: socket.AddressFamily = socket.AF_INET6 # pyre-ignore
|
||||
request_queue_size: int = 1024
|
||||
|
||||
|
||||
class HTTPRequestHandler(BaseHTTPRequestHandler):
|
||||
frontend: "FrontendServer"
|
||||
|
||||
def do_GET(self):
|
||||
self.frontend._handle_request(self)
|
||||
|
||||
def get_path(self) -> str:
|
||||
return urlparse(self.path).path
|
||||
|
||||
def get_query(self) -> dict[str, list[str]]:
|
||||
return parse_qs(urlparse(self.path).query)
|
||||
|
||||
def get_query_arg(
|
||||
self, name: str, default: object = None, type: type = str
|
||||
) -> object:
|
||||
query = self.get_query()
|
||||
if name not in query:
|
||||
return default
|
||||
return type(query[name][0])
|
||||
|
||||
|
||||
class FrontendServer:
|
||||
def __init__(self, port: int):
|
||||
# Setup templates
|
||||
loader = DictLoader(templates)
|
||||
self._jinja_env = Environment(loader=loader, enable_async=True)
|
||||
self._jinja_env.globals.update(
|
||||
zip=zip,
|
||||
format_json=format_json,
|
||||
enumerate=enumerate,
|
||||
)
|
||||
|
||||
# Create routes
|
||||
self._routes = {
|
||||
"/": self._handle_index,
|
||||
"/stacks": self._handle_stacks,
|
||||
"/fr_trace": self._handle_fr_trace,
|
||||
"/fr_trace_nccl": self._handle_fr_trace_nccl,
|
||||
"/profile": self._handle_profiler,
|
||||
}
|
||||
|
||||
# Create HTTP server
|
||||
RequestHandlerClass = type(
|
||||
"HTTPRequestHandler",
|
||||
(HTTPRequestHandler,),
|
||||
{"frontend": self},
|
||||
)
|
||||
|
||||
server_address = ("", port)
|
||||
self._server = _IPv6HTTPServer(server_address, RequestHandlerClass)
|
||||
|
||||
self._thread = threading.Thread(
|
||||
target=self._serve,
|
||||
args=(),
|
||||
daemon=True,
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
def _serve(self) -> None:
|
||||
try:
|
||||
self._server.serve_forever()
|
||||
except Exception:
|
||||
logger.exception("got exception in checkpoint server")
|
||||
|
||||
def join(self) -> None:
|
||||
self._thread.join()
|
||||
|
||||
def _handle_request(self, req: HTTPRequestHandler) -> None:
|
||||
path = req.get_path()
|
||||
if path not in self._routes:
|
||||
req.send_error(404, f"Handler not found: {path}")
|
||||
return
|
||||
|
||||
handler = self._routes[path]
|
||||
try:
|
||||
resp = handler(req)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Exception in checkpoint server when handling %s",
|
||||
path,
|
||||
)
|
||||
req.send_error(500, str(e))
|
||||
return
|
||||
|
||||
req.send_response(200)
|
||||
req.send_header("Content-type", "text/html")
|
||||
req.end_headers()
|
||||
req.wfile.write(resp)
|
||||
|
||||
def _render_template(self, template: str, **kwargs: object) -> bytes:
|
||||
return self._jinja_env.get_template(template).render(**kwargs).encode()
|
||||
|
||||
def _handle_index(self, req: HTTPRequestHandler) -> bytes:
|
||||
return self._render_template("index.html")
|
||||
|
||||
def _handle_stacks(self, req: HTTPRequestHandler) -> bytes:
|
||||
addrs, resps = fetch_all("dump_traceback")
|
||||
return self._render_template(
|
||||
"raw_resp.html", title="Stacks", addrs=addrs, resps=resps
|
||||
)
|
||||
|
||||
def _handle_fr_trace(self, req: HTTPRequestHandler) -> bytes:
|
||||
addrs, resps = fetch_all("fr_trace_json")
|
||||
|
||||
return self._render_template(
|
||||
"json_resp.html",
|
||||
title="FlightRecorder",
|
||||
addrs=addrs,
|
||||
resps=resps,
|
||||
)
|
||||
|
||||
def _handle_fr_trace_nccl(self, req: HTTPRequestHandler) -> bytes:
|
||||
addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")
|
||||
|
||||
return self._render_template(
|
||||
"json_resp.html",
|
||||
title="FlightRecorder NCCL",
|
||||
addrs=addrs,
|
||||
resps=resps,
|
||||
)
|
||||
|
||||
def _handle_profiler(self, req: HTTPRequestHandler) -> bytes:
|
||||
duration = req.get_query_arg("duration", default=1.0, type=float)
|
||||
|
||||
addrs, resps = fetch_all("torch_profile", f"duration={duration}")
|
||||
|
||||
return self._render_template("profile.html", addrs=addrs, resps=resps)
|
||||
|
||||
|
||||
def main(port: int) -> None:
|
||||
server = FrontendServer(port=port)
|
||||
logger.info("Frontend server started on port %d", server._server.server_port)
|
||||
server.join()
|
||||
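A sketch of how an extra page could be hung off the frontend above (illustrative only; the `/ping` route and `_handle_ping` helper are hypothetical additions, though the worker-side "ping" handler they call is registered in the Handlers.cpp hunk earlier in this diff):

class MyFrontendServer(FrontendServer):
    def __init__(self, port: int):
        super().__init__(port)
        # The dispatcher consults self._routes per request, so routes added
        # here are picked up by the already-running server thread.
        self._routes["/ping"] = self._handle_ping

    def _handle_ping(self, req: HTTPRequestHandler) -> bytes:
        # Reuse fetch_all to fan a registered worker handler out to all ranks.
        addrs, resps = fetch_all("ping")
        return self._render_template(
            "raw_resp.html", title="Ping", addrs=addrs, resps=resps
        )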
torch/distributed/debug/_handlers.py (new file, 22 lines)
@ -0,0 +1,22 @@
import tempfile
import time

from torch._C._distributed_c10d import _register_handler, _Request, _Response
from torch.profiler import _ExperimentalConfig, profile


def _torch_profile(req: _Request, resp: _Response) -> None:
    experimental_config = _ExperimentalConfig(
        profile_all_threads=True,
    )
    duration = float(req.get_param("duration"))
    with profile(record_shapes=True, experimental_config=experimental_config) as prof:
        time.sleep(duration)

    with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f:
        prof.export_chrome_trace(f.name)
        resp.set_content(open(f.name, "rb").read(), "application/json")
        resp.set_status(200)


_register_handler("torch_profile", _torch_profile)
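As a sketch of wiring up an additional worker-side endpoint through the same hook (illustrative; the "echo" handler name is made up for this example, and each worker would then serve it at /handler/echo):

def _echo(req: _Request, resp: _Response) -> None:
    # Echo a query parameter back to the caller.
    msg = req.get_param("msg")
    resp.set_content(f"echo: {msg}".encode(), "text/plain")
    resp.set_status(200)


_register_handler("echo", _echo)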
torch/distributed/debug/_store.py (new file, 24 lines)
@ -0,0 +1,24 @@
import os

import torch.distributed as dist


def get_rank() -> int:
    return int(os.environ["RANK"])


def get_world_size() -> int:
    return int(os.environ["WORLD_SIZE"])


def tcpstore_client() -> dist.Store:
    MASTER_ADDR = os.environ["MASTER_ADDR"]
    MASTER_PORT = int(os.environ["MASTER_PORT"])

    store = dist.TCPStore(
        host_name=MASTER_ADDR,
        port=MASTER_PORT,
        is_master=False,
    )
    store = dist.PrefixStore("debug_server", store)
    return store
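A small sketch of how these helpers fit together; it mirrors what `fetch_all` in `_frontend.py` above does, and the printed URLs are illustrative.

# Each worker publishes "rank{N}" -> "http://<host>:<worker_port>" into the
# "debug_server"-prefixed store; the frontend reads them back to fan out requests.
store = tcpstore_client()
keys = [f"rank{r}" for r in range(get_world_size())]
urls = [addr.decode() for addr in store.multi_get(keys)]
# e.g. ["http://host0:54321", "http://host1:43210"]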
@ -9,7 +9,7 @@ from collections.abc import Callable, Iterable
from enum import Enum
from functools import partial
from typing import Any, Optional
from typing_extensions import Self
from typing_extensions import deprecated, Self
from warnings import warn

import torch
@ -408,6 +408,11 @@ class _KinetoProfile:
        )
        return MemoryProfile(self.profiler.kineto_results)

    @deprecated(
        "`export_memory_timeline` is deprecated and will be removed in a future version. "
        "Please use `torch.cuda.memory._record_memory_history` and `torch.cuda.memory._export_memory_snapshot` instead.",
        category=FutureWarning,
    )
    def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
        """Export memory event information from the profiler collected
        tree for a given device, and export a timeline plot. There are 3
@ -429,6 +434,11 @@ class _KinetoProfile:
          ``torch.profiler._memory_profiler.Category``.

        Output: Memory timeline written as gzipped JSON, JSON, or HTML.

        .. deprecated::
            ``export_memory_timeline`` is deprecated and will be removed in a future version.
            Please use ``torch.cuda.memory._record_memory_history`` and
            ``torch.cuda.memory._export_memory_snapshot`` instead.
        """
        # Default to device 0, if unset. Fallback on cpu.
        if device is None:
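Since the deprecation above points users at the memory-snapshot APIs, here is a brief sketch of the suggested replacement workflow (a sketch of the migration path, not part of this diff; the file name is illustrative):

import torch

torch.cuda.memory._record_memory_history(max_entries=100000)
# ... run the CUDA workload whose memory timeline you used to export ...
torch.cuda.memory._export_memory_snapshot("snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
# Inspect the snapshot with the viewer at https://pytorch.org/memory_viz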
@ -386,7 +386,7 @@ class DTensorTestBase(MultiProcessTestCase):

    @property
    def backend(self) -> str:
        backend = dist.get_default_backend_for_device(DEVICE_TYPE)
        backend = dist.get_default_backend_for_device(self.device_type)
        return backend

    def init_manual_seed_for_rank(self) -> None:

@ -249,6 +249,8 @@ def object_annotation(obj):
        if len(filename) > FRAME_FILENAME_LIMIT:
            filename = "..." + filename[-(FRAME_FILENAME_LIMIT - 3):]
        return f"frame\n{filename}:{obj.f_lineno}"
    elif is_cuda_tensor(obj):
        return f"object\n{type(obj).__module__}.{type(obj).__name__} ({obj.shape})"
    else:
        return f"object\n{type(obj).__module__}.{type(obj).__name__}"