Compare commits


1 Commit

33f776b894 [export] Add pytree input check for dynamo_graph_capture_for_export
Summary:
as title.

Test Plan:
pytest test/export/test_export.py -k test_invalid_pytree_dynamo_graph_capture
2025-11-13 10:59:31 -08:00
31 changed files with 350 additions and 957 deletions

View File

@ -1,7 +1,6 @@
#pragma once
#include <ATen/native/CompositeRandomAccessorCommon.h>
#include <thrust/swap.h>
#include <thrust/tuple.h>
namespace at { namespace native {

View File

@ -267,15 +267,15 @@ void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, con
* outer dimensions, which contains several "inner rows").
* Each thread processes a single inner row at a time.
*/
template<typename scalar_t, typename index_t, class BinaryOp>
template<typename scalar_t, class BinaryOp>
__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_,
const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
const scalar_t init, BinaryOp binary_op)
{
for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
const scalar_t *src = src_ + static_cast<index_t>(orow) * row_size * num_irows + irow;
scalar_t *tgt = tgt_ + (index_t) orow * row_size * num_irows + irow;
const scalar_t *src = src_ + orow * row_size * num_irows + irow;
scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
scalar_t acc = init;
for (uint32_t col = 0; col < row_size; ++col) {
@ -409,15 +409,10 @@ __host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
check_fits_in_unsigned(num_irows, "num_irows");
check_fits_in_unsigned(num_orows, "num_orows");
check_fits_in_unsigned(row_size, "row_size");
if (static_cast<size_t>(num_irows) * num_orows * row_size <= UINT_MAX) {
tensor_kernel_scan_outer_dim<scalar_t, uint32_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
num_orows, num_irows, row_size, init, binary_op);
} else {
tensor_kernel_scan_outer_dim<scalar_t, size_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
num_orows, num_irows, row_size, init, binary_op);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
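The index_t template parameter and the size_t fallback launch removed in this hunk existed so the offset arithmetic (orow * row_size * num_irows + irow) could be done in 64-bit once the tensor holds more than UINT_MAX elements; the 48 GB cumsum test deleted further down hits exactly that case. A quick sanity check of the numbers (a sketch in Python; mapping the test shape onto the kernel's orow/irow parameters is my reading of the code above):

# Shape from the deleted test_cumsum_outer_dim_64bit_indexing:
# torch.zeros(309504, 1, 16384) scanned over dim=1.
num_orows, row_size, num_irows = 309504, 1, 16384
max_offset = num_orows * row_size * num_irows   # upper bound of orow * row_size * num_irows + irow
print(max_offset)               # 5_070_913_536
print(max_offset > 2**32 - 1)   # True: a uint32 offset overflows, which is what the size_t path guarded against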

View File

@ -7518,7 +7518,7 @@
- func: _sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: sparse_mask_projection
SparseCPU, SparseCUDA: sparse_mask_projection
autogen: _sparse_mask_projection.out
- func: _to_cpu(Tensor[] tensors) -> Tensor[]

View File

@ -30,12 +30,10 @@
#include <thrust/binary_search.h>
#include <thrust/device_ptr.h>
#include <thrust/distance.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>
#include <cuda_runtime_api.h>
#include <cusparse.h>

View File

@ -445,33 +445,6 @@ static SparseTensor& mul_out_dense_sparse_mps(
return out;
}
static std::tuple<Tensor, Tensor, int64_t> mps_intersect_binary_search(
const Tensor& A_keys,
const Tensor& B_keys,
int64_t lenA,
int64_t lenB,
bool boolean_flag) {
auto stream = getCurrentMPSStream();
auto outA_idx = at::empty({lenA}, A_keys.options().dtype(at::kLong));
auto outB_idx = at::empty({lenA}, A_keys.options().dtype(at::kLong));
auto counter = at::zeros({1}, A_keys.options().dtype(at::kInt));
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
static_cast<uint32_t>(lenB), boolean_flag);
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
}
});
const auto match_count = static_cast<int64_t>(counter.item<int32_t>());
return std::make_tuple(std::move(outA_idx), std::move(outB_idx), match_count);
}
SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTensor& r_) {
TORCH_CHECK(r_.is_mps(), "mul: expected 'out' to be MPS, but got ", r_.device());
@ -550,10 +523,22 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
auto [outA_idx, outB_idx, M_int64] = mps_intersect_binary_search(
A_keys, B_keys, lenA, lenB, A_is_lhs);
auto outA_idx = at::empty({lenA}, at::device(device).dtype(kLong));
auto outB_idx = at::empty({lenA}, at::device(device).dtype(kLong));
auto counter = at::zeros({1}, at::device(device).dtype(kInt));
const auto M = static_cast<uint32_t>(M_int64); // number of structural matches
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
static_cast<uint32_t>(lenB), A_is_lhs);
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
}
});
const uint32_t M = counter.item<int32_t>(); // number of structural matches
r_.resize_as_(lhs);
@ -777,14 +762,6 @@ SparseTensor& add_out_sparse_mps(const SparseTensor& self,
using OptTensor = std::optional<Tensor>;
static Tensor create_sparse_output_values(
const Tensor& template_values,
int64_t output_nnz,
ScalarType dtype) {
auto out_val_sizes = template_values.sizes().vec();
out_val_sizes[0] = output_nnz;
return at::zeros(out_val_sizes, template_values.options().dtype(dtype));
}
static void sparse_mask_apply_out_mps_kernel(
Tensor& result,
@ -806,9 +783,9 @@ static void sparse_mask_apply_out_mps_kernel(
auto src = src_in.coalesce();
auto mask = coalesce_mask ? mask_in.coalesce() : mask_in;
const auto src_nnz = src._nnz();
const auto mask_nnz = mask._nnz();
const auto sd = src.sparse_dim();
const int64_t src_nnz = src._nnz();
const int64_t mask_nnz = mask._nnz();
const int64_t sd = src.sparse_dim();
result.sparse_resize_(mask.sizes(), mask.sparse_dim(), mask.dense_dim());
auto commonDtype = at::result_type(src, mask);
@ -837,27 +814,53 @@ static void sparse_mask_apply_out_mps_kernel(
return;
}
auto mask_indices = mask._indices().contiguous();
auto src_values = src._values().to(commonDtype).contiguous();
auto out_values = create_sparse_output_values(src_values, mask_nnz, commonDtype);
if (src_nnz == 0) {
alias_into_sparse(result, mask_indices, out_values);
auto out_indices = mask._indices().contiguous();
auto src_values = src._values().to(commonDtype);
auto out_val_sizes = src_values.sizes().vec();
out_val_sizes[0] = mask_nnz;
auto out_values = at::zeros(out_val_sizes, src_values.options());
alias_into_sparse(result, out_indices, out_values);
result._coalesced_(mask.is_coalesced());
return;
}
auto mask_keys = flatten_indices(mask._indices().contiguous(), mask.sizes().slice(0, sd)).contiguous();
auto src_keys = flatten_indices(src._indices().contiguous(), src.sizes().slice(0, sd)).contiguous();
auto mask_indices = mask._indices().contiguous();
auto src_indices = src._indices().contiguous();
auto src_values = src._values().to(commonDtype).contiguous();
const auto A_is_src = (src_nnz <= mask_nnz);
const auto lenA = A_is_src ? src_nnz : mask_nnz;
const auto lenB = A_is_src ? mask_nnz : src_nnz;
auto mask_keys = flatten_indices(mask_indices, mask.sizes().slice(0, sd)).contiguous();
auto src_keys = flatten_indices(src_indices, src.sizes().slice(0, sd)).contiguous();
const bool A_is_src = (src_nnz <= mask_nnz);
const int64_t lenA = A_is_src ? src_nnz : mask_nnz;
const int64_t lenB = A_is_src ? mask_nnz : src_nnz;
auto A_keys = A_is_src ? src_keys : mask_keys;
auto B_keys = A_is_src ? mask_keys : src_keys;
auto [outA_idx, outB_idx, M] = mps_intersect_binary_search(
A_keys, B_keys, lenA, lenB, A_is_src);
const auto device = result.device();
auto stream = getCurrentMPSStream();
auto outA_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
auto outB_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
auto counter = at::zeros({1}, at::device(device).dtype(at::kInt));
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
static_cast<uint32_t>(lenB), A_is_src);
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
}
});
const int64_t M = static_cast<int64_t>(counter.item<int32_t>());
auto out_val_sizes = src_values.sizes().vec();
out_val_sizes[0] = mask_nnz;
auto out_values = at::zeros(out_val_sizes, src_values.options());
if (M > 0) {
auto src_match = outA_idx.narrow(0, 0, M);
@ -875,70 +878,6 @@ static void sparse_mask_apply_out_mps_kernel(
result._coalesced_(mask.is_coalesced());
}
static void sparse_mask_projection_out_mps_kernel(
Tensor& result,
const Tensor& lhs,
const Tensor& rhs,
const OptTensor& /*x_hash_opt*/,
bool accumulate_matches) {
TORCH_CHECK(lhs.is_sparse() && rhs.is_sparse(), "sparse_mask_projection: expected sparse COO");
TORCH_CHECK(lhs.is_mps() && rhs.is_mps(), "sparse_mask_projection: expected MPS tensors");
TORCH_CHECK(lhs.sparse_dim() == rhs.sparse_dim(), "sparse_dim mismatch");
auto lhs_c = lhs.coalesce();
auto rhs_c = rhs.coalesce();
const auto sd = lhs_c.sparse_dim();
const auto lhs_nnz = lhs_c._nnz();
const auto rhs_nnz = rhs_c._nnz();
auto commonDtype = at::result_type(lhs_c, rhs_c);
TORCH_CHECK(canCast(commonDtype, result.scalar_type()),
"Can't convert ", commonDtype, " to output ", result.scalar_type());
result.sparse_resize_(lhs.sizes(), lhs.sparse_dim(), lhs.dense_dim());
auto lhs_indices = lhs_c._indices().contiguous();
auto rhs_values = rhs_c._values().to(commonDtype).contiguous();
auto out_values = create_sparse_output_values(rhs_values, lhs_nnz, commonDtype);
if (lhs_nnz > 0 && rhs_nnz > 0) {
auto lhs_keys = flatten_indices(lhs_indices, lhs_c.sizes().slice(0, sd)).contiguous();
auto rhs_keys = flatten_indices(rhs_c._indices().contiguous(), rhs_c.sizes().slice(0, sd)).contiguous();
const auto A_is_lhs = (lhs_nnz <= rhs_nnz);
const auto lenA = A_is_lhs ? lhs_nnz : rhs_nnz;
const auto lenB = A_is_lhs ? rhs_nnz : lhs_nnz;
auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
auto [outA_idx, outB_idx, M] = mps_intersect_binary_search(
A_keys, B_keys, lenA, lenB, A_is_lhs);
if (M > 0) {
auto idx_in_A = outA_idx.narrow(0, 0, M);
auto idx_in_B = outB_idx.narrow(0, 0, M);
auto idx_in_lhs = A_is_lhs ? idx_in_A : idx_in_B;
auto idx_in_rhs = A_is_lhs ? idx_in_B : idx_in_A;
const auto view_cols = rhs_values.numel() / std::max<int64_t>(rhs_nnz, 1);
auto rhs_rows = rhs_values.index_select(0, idx_in_rhs).contiguous();
auto rhs_rows_2d = rhs_rows.view({M, view_cols});
auto out_2d = out_values.view({lhs_nnz, view_cols});
if (accumulate_matches) {
out_2d.index_add_(0, idx_in_lhs, rhs_rows_2d);
} else {
out_2d.index_copy_(0, idx_in_lhs, rhs_rows_2d);
}
}
}
alias_into_sparse(result, lhs._indices(), out_values);
result._coalesced_(lhs.is_coalesced());
}
static void sparse_mask_intersection_out_mps_kernel(
Tensor& result,
const Tensor& lhs,
@ -1063,5 +1002,4 @@ Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
}
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
REGISTER_MPS_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_mps_kernel);
} // namespace at::native
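As a rough mental model of what the intersect_binary_search Metal kernel computes in these hunks: both inputs are the flattened, sorted sparse keys, and for every key of the shorter side A the kernel binary-searches the longer side B, writing out matching index pairs plus a match counter. A CPU-side sketch with torch.searchsorted (illustrative only, not the MPS implementation):

import torch

def intersect_sorted_keys(A_keys: torch.Tensor, B_keys: torch.Tensor):
    # Assumes 1-D int64 keys, sorted ascending and non-empty, as produced by
    # flatten_indices() on coalesced sparse tensors.
    pos = torch.searchsorted(B_keys, A_keys).clamp_max(B_keys.numel() - 1)
    hit = B_keys[pos] == A_keys
    outA_idx = hit.nonzero().flatten()   # positions in A with a structural match
    outB_idx = pos[hit]                  # corresponding positions in B
    return outA_idx, outB_idx, int(hit.sum())   # the third value plays the role of the counter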

View File

@ -31,8 +31,6 @@ from torch.utils._debug_mode import (
_RedistributeCall,
_TritonKernelCall,
DebugMode,
hash_tensor_fn,
norm_hash_fn,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._triton import has_triton_package
@ -117,28 +115,6 @@ class TestDTensorDebugMode(TestCase):
"aten::sum(t: f32[1, 32]) # {'hash': " in debug_mode.debug_string()
)
# check tuple hash functions
with (
DebugMode() as debug_mode,
DebugMode.log_tensor_hashes(hash_fn=["norm", "hash_tensor"]),
):
mm(x_dtensor, y_dtensor)
output_hash = debug_mode.operators[-1].log["hash"]
norm_ = lambda x: norm_hash_fn(x, use_scalar=True) # noqa: E731
hash_ = lambda x: hash_tensor_fn(x, use_scalar=True) # noqa: E731
self.assertEqual(output_hash[0], norm_(eager_out))
self.assertEqual(output_hash[1], hash_(eager_out))
# some edge cases
self.assertEqual(norm_(torch.tensor(torch.nan)), torch.nan)
self.assertEqual(norm_(torch.tensor(torch.inf)), torch.inf)
self.assertEqual(norm_(torch.complex(torch.ones(4), torch.zeros(4))), 4)
self.assertEqual(hash_(torch.ones(4, dtype=torch.float8_e5m2)), 0)
self.assertEqual(hash_(torch.ones(4, dtype=torch.int8)), 0)
self.assertEqual(hash_(torch.ones(5, dtype=torch.int8)), 1)
def test_debug_string_inside_context(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

View File

@ -664,101 +664,6 @@ class TestViewOps(DTensorTestBase):
)
self.assertEqual(dist_x.placements, [Partial(), Shard(0)])
@with_comms
def test_storage_offset_slice(self):
"""
Test that storage_offset is properly tracked on DTensor when slicing
a replicated tensor.
"""
mesh = init_device_mesh(self.device_type, (self.world_size,))
# Create a replicated DTensor
tensor = torch.randn(10, device=self.device_type)
dtensor = distribute_tensor(tensor, mesh, [Replicate()])
# Perform a slice operation [1:]
with CommDebugMode() as comm_mode:
sliced_dtensor = dtensor[1:]
# Slicing should not trigger any communication
self.assertEqual(comm_mode.get_total_counts(), 0)
# Verify that the DTensor's storage_offset matches the expected value
self.assertEqual(sliced_dtensor.storage_offset(), 1)
# Verify that the local tensor also has the correct storage_offset
self.assertEqual(sliced_dtensor.to_local().storage_offset(), 1)
# Verify the shape is correct
self.assertEqual(sliced_dtensor.shape, torch.Size([9]))
# Verify the values are correct
expected = tensor[1:]
self.assertEqual(sliced_dtensor.full_tensor(), expected)
@with_comms
def test_storage_offset_shard_dim0_slice_dim1(self):
"""
Test that storage_offset is properly tracked when tensor is sharded on dim 0
and sliced on dim 1.
"""
mesh = init_device_mesh(self.device_type, (self.world_size,))
# Create a 2D tensor and shard on dim 0
tensor = torch.randn(12, 8, device=self.device_type)
dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
# Perform a slice operation [:, 2:]
with CommDebugMode() as comm_mode:
sliced_dtensor = dtensor[:, 2:]
# Slicing should not trigger any communication
self.assertEqual(comm_mode.get_total_counts(), 0)
# The storage_offset should be 2 (skipping 2 elements in each row)
self.assertEqual(sliced_dtensor.storage_offset(), 2)
# Verify that the local tensor also has the correct storage_offset
self.assertEqual(sliced_dtensor.to_local().storage_offset(), 2)
# Verify the shape is correct
expected_shape = torch.Size([12, 6])
self.assertEqual(sliced_dtensor.shape, expected_shape)
# Verify the values are correct
expected = tensor[:, 2:]
self.assertEqual(sliced_dtensor.full_tensor(), expected)
@with_comms
def test_storage_offset_shard_dim1_slice_dim0(self):
"""
Test that storage_offset is properly tracked when tensor is sharded on dim 1
and sliced on dim 0.
"""
mesh = init_device_mesh(self.device_type, (self.world_size,))
# Create a 2D tensor and shard on dim 1
tensor = torch.randn(10, 12, device=self.device_type)
dtensor = distribute_tensor(tensor, mesh, [Shard(1)])
# Perform a slice operation [2:, :]
with CommDebugMode() as comm_mode:
sliced_dtensor = dtensor[2:, :]
# Slicing should not trigger any communication
self.assertEqual(comm_mode.get_total_counts(), 0)
local_dim1_size = 12 // self.world_size
expected_offset = 2 * local_dim1_size
self.assertEqual(sliced_dtensor.storage_offset(), expected_offset)
self.assertEqual(sliced_dtensor.to_local().storage_offset(), expected_offset)
# Verify the shape is correct
expected_shape = torch.Size([8, 12])
self.assertEqual(sliced_dtensor.shape, expected_shape)
# Verify the values are correct
expected = tensor[2:, :]
self.assertEqual(sliced_dtensor.full_tensor(), expected)
TestViewOpsWithLocalTensor = create_local_tensor_test_class(
TestViewOps,

View File

@ -1,9 +1,11 @@
# Owner(s): ["module: dynamo"]
import copy
import functools
import inspect
import os
import pickle
import unittest
from contextlib import contextmanager
from unittest.mock import patch
@ -13,13 +15,16 @@ import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.aot_compile import ModelInput, SerializableCallable
from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable
from torch._dynamo.exc import PackageError, Unsupported
from torch._dynamo.package import DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.fx._graph_pickler import GraphPickler
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
TEST_CUDA,
)
MY_LAMBDA = lambda x: x + 1 # noqa: E731
@ -599,6 +604,92 @@ from user code:
actual = compiled_fn(*inputs)
self.assertEqual(expected, actual)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks
def fn(x, y):
return x + y
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph(
fn,
(make_inputs(), {}),
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_module(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks
mod = SimpleLinearModule()
def make_inputs():
return (torch.randn(4, 3),)
compiled_mod = torch._dynamo.aot_compile.aot_compile_module(
mod,
[ModelInput(make_inputs(), {}, [])],
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)
def get_grads(m: torch.nn.Module):
return {name: p.grad for name, p in m.named_parameters()}
original_mod = copy.deepcopy(mod)
test_inputs = make_inputs()
expected = mod(*test_inputs)
expected.sum().backward()
expected_grads = get_grads(mod)
actual = compiled_mod(*test_inputs)
self.assertEqual(expected, actual)
serialized = compiled_mod.serialize()
compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized)
actual = compiled_fn(*test_inputs)
actual.sum().backward()
self.assertEqual(get_grads(original_mod), expected_grads)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_torch_compile(self):
with torch.device("cuda"):
def fn(x, y):
return x + y
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch.compile(
fn, fullgraph=True, options={"use_aoti": True}
).aot_compile((make_inputs(), {}))
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor")
self.assertEqual(expected, actual)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
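For orientation, the new tests boil down to the flow below; combined with the torch/__init__.py hunk later in this diff, options={"use_aoti": True} swaps the inductor backend for _TorchCompileAOTInductorWrapper. A condensed, CUDA-only sketch of test_aot_compile_with_aoti_torch_compile (the file path is illustrative):

import torch

def fn(x, y):
    return x + y

with torch.device("cuda"):
    example_inputs = (torch.randn(3, 4), torch.randn(3, 4))
    compiled_fn = torch.compile(
        fn, fullgraph=True, options={"use_aoti": True}
    ).aot_compile((example_inputs, {}))

    out = compiled_fn(*example_inputs)
    compiled_fn.save_compiled_function("/tmp/fn_aoti.bin")   # illustrative path
    with open("/tmp/fn_aoti.bin", "rb") as f:
        reloaded = torch.compiler.load_compiled_function(f)
    assert reloaded._artifacts.backend_name == "aotinductor"
    assert torch.equal(reloaded(*example_inputs), out)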

View File

@ -470,7 +470,7 @@ class <lambda>(torch.nn.Module):
)
@requires_cuda
def test_stream_backward_simple(self) -> None:
def test_stream_backward(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
@ -524,68 +524,7 @@ class GraphModule(torch.nn.Module):
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
""",
)
@requires_cuda
def test_stream_backward_sync(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
with s0:
y0 = 2 * x + y
with s2:
z = 2 * x + y
return y0, z
inp = (
torch.ones(2, 2, device="cuda:0", requires_grad=True) + 1,
torch.ones(2, 2, device="cuda:0", requires_grad=True),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
bw_graphs,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
# Annotation: {'stream': 1}
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
return (add, add_1)
""",
)
actual[1].sum().backward()
self.assertExpectedInline(
print_graph(bw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
# Annotation: {'stream': 0}
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
#
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
""",

View File

@ -15295,12 +15295,12 @@ graph():
def forward(self, block):
return block.a + block.b
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
with self.assertRaisesRegex(
torch._dynamo.exc.UserError, "It looks like one of the inputs with type"
):
_dynamo_graph_capture_for_export(Foo())(
dynamo_graph_capture_for_export(Foo())(
Block(torch.randn(4, 4), torch.randn(4, 4))
)

View File

@ -1,71 +0,0 @@
# Owner(s): ["oncall: export"]
from torch._dynamo import config as dynamo_config
from torch._dynamo.testing import make_test_cls_with_patches
from torch._export import config as export_config
try:
from . import test_export, testing
except ImportError:
import test_export # @manual=fbcode//caffe2/test:test_export-library
import testing # @manual=fbcode//caffe2/test:test_export-library
from torch.export import export
test_classes = {}
def mocked_strict_export(*args, **kwargs):
# If the user already specified strict, don't override it
if "strict" in kwargs:
return export(*args, **kwargs)
return export(*args, **kwargs, strict=True)
def make_dynamic_cls(cls):
# Some tests check that the class name ends with the suffix; keep
# `_strict` at the end of the string as a result
suffix = test_export.INLINE_AND_INSTALL_STRICT_SUFFIX
cls_prefix = "InlineAndInstall"
cls_a = testing.make_test_cls_with_mocked_export(
cls,
"StrictExport",
suffix,
mocked_strict_export,
xfail_prop="_expected_failure_strict",
)
test_class = make_test_cls_with_patches(
cls_a,
cls_prefix,
"",
(export_config, "use_new_tracer_experimental", True),
(dynamo_config, "install_free_tensors", True),
(dynamo_config, "inline_inbuilt_nn_modules", True),
xfail_prop="_expected_failure_inline_and_install",
)
test_classes[test_class.__name__] = test_class
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
globals()[test_class.__name__] = test_class
test_class.__module__ = __name__
return test_class
tests = [
test_export.TestDynamismExpression,
test_export.TestExport,
]
for test in tests:
make_dynamic_cls(test)
del test
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -1394,357 +1394,6 @@ class HasDecompTest(TestCase):
check_case(groups=1, C_in=8, C_out=12) # groups=1 bigger
check_case(groups=2, C_in=8, C_out=12) # grouped conv
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
def test_mm_decompose_mm_dde(self):
def fuzzed_program(
arg_0,
arg_1,
arg_2,
arg_3,
arg_4,
arg_5,
arg_6,
arg_7,
arg_8,
arg_9,
arg_10,
arg_11,
arg_12,
arg_13,
arg_14,
arg_15,
arg_16,
arg_17,
arg_18,
sentinel,
):
var_node_6 = (
arg_0 # size=(9, 9, 9), stride=(81, 9, 1), dtype=float64, device=cuda
)
var_node_7 = (
arg_1 # size=(9, 9, 11), stride=(99, 11, 1), dtype=float64, device=cuda
)
var_node_5 = torch.matmul(
var_node_6.to(torch.float64), var_node_7.to(torch.float64)
) # size=(9, 9, 11), stride=(99, 11, 1), dtype=float64, device=cuda
var_node_9 = torch.full(
(9, 11, 12), 1.5758497316910556, dtype=torch.float64
) # size=(9, 11, 12), stride=(132, 12, 1), dtype=float64, device=cuda
var_node_10 = (
arg_2 # size=(9, 12, 8), stride=(96, 8, 1), dtype=float64, device=cuda
)
var_node_8 = torch.matmul(
var_node_9.to(torch.float64), var_node_10.to(torch.float64)
) # size=(9, 11, 8), stride=(88, 8, 1), dtype=float64, device=cuda
var_node_4 = torch.matmul(
var_node_5.to(torch.float64), var_node_8.to(torch.float64)
) # size=(9, 9, 8), stride=(72, 8, 1), dtype=float64, device=cuda
var_node_13 = arg_3 # size=(9, 8, 13), stride=(104, 13, 1), dtype=float64, device=cuda
var_node_14 = (
arg_4 # size=(9, 13, 7), stride=(91, 7, 1), dtype=float64, device=cuda
)
var_node_12 = torch.matmul(
var_node_13.to(torch.float64), var_node_14.to(torch.float64)
) # size=(9, 8, 7), stride=(56, 7, 1), dtype=float64, device=cuda
var_node_15 = arg_5 # size=(9, 7, 16), stride=(112, 16, 1), dtype=float64, device=cuda
var_node_11 = torch.matmul(
var_node_12.to(torch.float64), var_node_15.to(torch.float64)
) # size=(9, 8, 16), stride=(128, 16, 1), dtype=float64, device=cuda
var_node_3 = torch.matmul(
var_node_4.to(torch.float64), var_node_11.to(torch.float64)
) # size=(9, 9, 16), stride=(144, 16, 1), dtype=float64, device=cuda
var_node_17 = arg_6 # size=(9, 16, 12), stride=(192, 12, 1), dtype=float64, device=cuda
var_node_18 = arg_7 # size=(9, 12, 11), stride=(132, 11, 1), dtype=float64, device=cuda
var_node_16 = torch.matmul(
var_node_17.to(torch.float64), var_node_18.to(torch.float64)
) # size=(9, 16, 11), stride=(176, 11, 1), dtype=float64, device=cuda
var_node_2 = torch.matmul(
var_node_3.to(torch.float64), var_node_16.to(torch.float64)
) # size=(9, 9, 11), stride=(99, 11, 1), dtype=float64, device=cuda
var_node_23 = torch.full(
(156, 8), -0.5249394453404403, dtype=torch.float64
) # size=(156, 8), stride=(8, 1), dtype=float64, device=cuda
var_node_24 = torch.full(
(8, 9), 0.9331226188585692, dtype=torch.float64
) # size=(8, 9), stride=(9, 1), dtype=float64, device=cuda
var_node_22 = torch.matmul(
var_node_23.to(torch.float64), var_node_24.to(torch.float64)
) # size=(156, 9), stride=(9, 1), dtype=float64, device=cuda
var_node_26 = torch.full(
(9, 13), -0.9276381954691514, dtype=torch.float64
) # size=(9, 13), stride=(13, 1), dtype=float64, device=cuda
var_node_27 = torch.full(
(13, 16), 0.024752238943232543, dtype=torch.float64
) # size=(13, 16), stride=(16, 1), dtype=float64, device=cuda
var_node_25 = torch.matmul(
var_node_26.to(torch.float64), var_node_27.to(torch.float64)
) # size=(9, 16), stride=(16, 1), dtype=float64, device=cuda
var_node_21 = torch.matmul(
var_node_22.to(torch.float64), var_node_25.to(torch.float64)
) # size=(156, 16), stride=(16, 1), dtype=float64, device=cuda
var_node_29 = arg_8
_x_nz = torch.zeros(
(9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
dtype=torch.bool,
device=var_node_29.device,
)
_x_nz_flat = _x_nz.reshape(-1)
_x_nz_flat[:9] = True
var_node_28 = torch.nonzero(
_x_nz
) # size=(9, 11), stride=(11, 1), dtype=int64, device=cuda
var_node_20 = torch.nn.functional.embedding(
torch.clamp(var_node_28.to(torch.int64), 0, var_node_21.size(0) - 1),
var_node_21,
) # size=(9, 11, 16), stride=(176, 16, 1), dtype=float64, device=cuda
var_node_33 = torch.full(
(9, 16, 5), 1.0707914920634904, dtype=torch.float64
) # size=(9, 16, 5), stride=(80, 5, 1), dtype=float64, device=cuda
var_node_34 = torch.full(
(9, 5, 10), -0.44934093079047227, dtype=torch.float64
) # size=(9, 5, 10), stride=(50, 10, 1), dtype=float64, device=cuda
var_node_32 = torch.matmul(
var_node_33.to(torch.float64), var_node_34.to(torch.float64)
) # size=(9, 16, 10), stride=(160, 10, 1), dtype=float64, device=cuda
var_node_36 = (
arg_9 # size=(9, 10, 1), stride=(10, 1, 1), dtype=float64, device=cuda
)
var_node_37 = torch.full(
(9, 1, 11), -1.874293687140311, dtype=torch.float64
) # size=(9, 1, 11), stride=(11, 11, 1), dtype=float64, device=cuda
var_node_35 = torch.matmul(
var_node_36.to(torch.float64), var_node_37.to(torch.float64)
) # size=(9, 10, 11), stride=(110, 11, 1), dtype=float64, device=cuda
var_node_31 = torch.matmul(
var_node_32.to(torch.float64), var_node_35.to(torch.float64)
) # size=(9, 16, 11), stride=(176, 11, 1), dtype=float64, device=cuda
var_node_40 = torch.full(
(990, 2), 0.4084376380351558, dtype=torch.float64
) # size=(990, 2), stride=(2, 1), dtype=float64, device=cuda
var_node_41 = torch.full(
(2,), 0.982671965550022, dtype=torch.float64
) # size=(2,), stride=(1,), dtype=float64, device=cuda
var_node_39 = torch.matmul(
var_node_40.to(torch.float64), var_node_41.to(torch.float64)
) # size=(990,), stride=(1,), dtype=float64, device=cuda
var_node_38 = torch.reshape(
var_node_39, [9, 11, 10]
) # size=(9, 11, 10), stride=(110, 10, 1), dtype=float64, device=cuda
var_node_30 = torch.matmul(
var_node_31.to(torch.float64), var_node_38.to(torch.float64)
) # size=(9, 16, 10), stride=(160, 10, 1), dtype=float64, device=cuda
var_node_19 = torch.matmul(
var_node_20.to(torch.float64), var_node_30.to(torch.float64)
) # size=(9, 11, 10), stride=(110, 10, 1), dtype=float64, device=cuda
var_node_1 = torch.matmul(
var_node_2.to(torch.float64), var_node_19.to(torch.float64)
) # size=(9, 9, 10), stride=(90, 10, 1), dtype=float64, device=cuda
var_node_47 = arg_10 # size=(9, 10, 15), stride=(150, 15, 1), dtype=float64, device=cuda
var_node_48 = torch.full(
(9, 15, 2), -0.3349339402390618, dtype=torch.float64
) # size=(9, 15, 2), stride=(30, 2, 1), dtype=float64, device=cuda
var_node_46 = torch.matmul(
var_node_47.to(torch.float64), var_node_48.to(torch.float64)
) # size=(9, 10, 2), stride=(20, 2, 1), dtype=float64, device=cuda
var_node_50 = (
arg_11 # size=(9, 2, 7), stride=(14, 7, 1), dtype=float64, device=cuda
)
var_node_51 = (
arg_12 # size=(9, 7, 2), stride=(14, 2, 1), dtype=float64, device=cuda
)
var_node_49 = torch.matmul(
var_node_50.to(torch.float64), var_node_51.to(torch.float64)
) # size=(9, 2, 2), stride=(4, 2, 1), dtype=float64, device=cuda
var_node_45 = torch.matmul(
var_node_46.to(torch.float64), var_node_49.to(torch.float64)
) # size=(9, 10, 2), stride=(20, 2, 1), dtype=float64, device=cuda
var_node_52 = torch.full(
(9, 2, 1), -0.4046675639434615, dtype=torch.float64
) # size=(9, 2, 1), stride=(2, 1, 1), dtype=float64, device=cuda
var_node_44 = torch.matmul(
var_node_45.to(torch.float64), var_node_52.to(torch.float64)
) # size=(9, 10, 1), stride=(10, 1, 1), dtype=float64, device=cuda
var_node_56 = (
arg_13 # size=(9, 1, 1), stride=(1, 1, 1), dtype=float64, device=cuda
)
var_node_55 = torch.nn.functional.rms_norm(
var_node_56.to(torch.float64), (1,)
) # size=(9, 1, 1), stride=(1, 1, 1), dtype=float64, device=cuda
var_node_57 = torch.full(
(9, 1, 8), 0.17877664640931384, dtype=torch.float64
) # size=(9, 1, 8), stride=(8, 8, 1), dtype=float64, device=cuda
var_node_54 = torch.matmul(
var_node_55.to(torch.float64), var_node_57.to(torch.float64)
) # size=(9, 1, 8), stride=(8, 8, 1), dtype=float64, device=cuda
var_node_60 = arg_14 # size=(9, 8, 10), stride=(80, 10, 1), dtype=float64, device=cuda
var_node_61 = torch.full(
(9, 10, 6), 0.43614806380221494, dtype=torch.float64
) # size=(9, 10, 6), stride=(60, 6, 1), dtype=float64, device=cuda
var_node_59 = torch.matmul(
var_node_60.to(torch.float64), var_node_61.to(torch.float64)
) # size=(9, 8, 6), stride=(48, 6, 1), dtype=float64, device=cuda
var_node_63 = (
arg_15 # size=(9, 6, 3), stride=(18, 3, 1), dtype=float64, device=cuda
)
var_node_64 = torch.full(
(9, 3, 8), -0.042774422041922854, dtype=torch.float64
) # size=(9, 3, 8), stride=(24, 8, 1), dtype=float64, device=cuda
var_node_62 = torch.matmul(
var_node_63.to(torch.float64), var_node_64.to(torch.float64)
) # size=(9, 6, 8), stride=(48, 8, 1), dtype=float64, device=cuda
var_node_58 = torch.matmul(
var_node_59.to(torch.float64), var_node_62.to(torch.float64)
) # size=(9, 8, 8), stride=(64, 8, 1), dtype=float64, device=cuda
var_node_53 = torch.matmul(
var_node_54.to(torch.float64), var_node_58.to(torch.float64)
) # size=(9, 1, 8), stride=(8, 8, 1), dtype=float64, device=cuda
var_node_43 = torch.matmul(
var_node_44.to(torch.float64), var_node_53.to(torch.float64)
) # size=(9, 10, 8), stride=(80, 8, 1), dtype=float64, device=cuda
var_node_68 = arg_16 # size=(9, 8, 16), stride=(128, 16, 1), dtype=float64, device=cuda
var_node_70 = torch.full(
(9, 16, 15), 0.24947808634496438, dtype=torch.float64
) # size=(9, 16, 15), stride=(240, 15, 1), dtype=float64, device=cuda
var_node_71 = torch.full(
(9, 15, 7), -0.09035245509773453, dtype=torch.float64
) # size=(9, 15, 7), stride=(105, 7, 1), dtype=float64, device=cuda
var_node_69 = torch.matmul(
var_node_70.to(torch.float64), var_node_71.to(torch.float64)
) # size=(9, 16, 7), stride=(112, 7, 1), dtype=float64, device=cuda
var_node_67 = torch.matmul(
var_node_68.to(torch.float64), var_node_69.to(torch.float64)
) # size=(9, 8, 7), stride=(56, 7, 1), dtype=float64, device=cuda
var_node_74 = torch.full(
(9, 7, 1), 0.05671950481832341, dtype=torch.float64
) # size=(9, 7, 1), stride=(7, 1, 1), dtype=float64, device=cuda
var_node_73 = torch.nn.functional.gelu(
var_node_74
) # size=(9, 7, 1), stride=(7, 1, 1), dtype=float64, device=cuda
var_node_76 = torch.full(
(9, 1, 2), -0.019912810353597852, dtype=torch.float64
) # size=(9, 1, 2), stride=(2, 2, 1), dtype=float64, device=cuda
var_node_77 = (
arg_17 # size=(9, 2, 7), stride=(14, 7, 1), dtype=float64, device=cuda
)
var_node_75 = torch.matmul(
var_node_76.to(torch.float64), var_node_77.to(torch.float64)
) # size=(9, 1, 7), stride=(7, 7, 1), dtype=float64, device=cuda
var_node_72 = torch.matmul(
var_node_73.to(torch.float64), var_node_75.to(torch.float64)
) # size=(9, 7, 7), stride=(49, 7, 1), dtype=float64, device=cuda
var_node_66 = torch.matmul(
var_node_67.to(torch.float64), var_node_72.to(torch.float64)
) # size=(9, 8, 7), stride=(56, 7, 1), dtype=float64, device=cuda
var_node_78 = arg_18 # size=(9, 7, 13), stride=(91, 13, 1), dtype=float64, device=cuda
var_node_65 = torch.matmul(
var_node_66.to(torch.float64), var_node_78.to(torch.float64)
) # size=(9, 8, 13), stride=(104, 13, 1), dtype=float64, device=cuda
var_node_42 = torch.matmul(
var_node_43.to(torch.float64), var_node_65.to(torch.float64)
) # size=(9, 10, 13), stride=(130, 13, 1), dtype=float64, device=cuda
var_node_0 = torch.matmul(
var_node_1.to(torch.float64), var_node_42.to(torch.float64)
) # size=(9, 9, 13), stride=(117, 13, 1), dtype=float64, device=cuda
# Ensure gradient computation by multiplying with sentinel and taking real part
result = var_node_0 * sentinel
if result.is_complex():
result = result.real
return result
# Sentinel tensor to ensure gradient computation
sentinel = torch.tensor(1.0, requires_grad=True)
arg_0 = torch.as_strided(
torch.randn(729).to(torch.float64), (9, 9, 9), (81, 9, 1)
)
arg_1 = torch.as_strided(
torch.randn(891).to(torch.float64), (9, 9, 11), (99, 11, 1)
)
arg_2 = torch.as_strided(
torch.randn(864).to(torch.float64), (9, 12, 8), (96, 8, 1)
)
arg_3 = torch.as_strided(
torch.randn(936).to(torch.float64), (9, 8, 13), (104, 13, 1)
)
arg_4 = torch.as_strided(
torch.randn(819).to(torch.float64), (9, 13, 7), (91, 7, 1)
)
arg_5 = torch.as_strided(
torch.randn(1008).to(torch.float64), (9, 7, 16), (112, 16, 1)
)
arg_6 = torch.as_strided(
torch.randn(1728).to(torch.float64), (9, 16, 12), (192, 12, 1)
)
arg_7 = torch.as_strided(
torch.randn(1188).to(torch.float64), (9, 12, 11), (132, 11, 1)
)
arg_8 = torch.as_strided(
torch.randint(0, 2, (1,), dtype=torch.int8).bool(),
(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
)
arg_9 = torch.as_strided(
torch.randn(90).to(torch.float64), (9, 10, 1), (10, 1, 1)
)
arg_10 = torch.as_strided(
torch.randn(1350).to(torch.float64), (9, 10, 15), (150, 15, 1)
)
arg_11 = torch.as_strided(
torch.randn(126).to(torch.float64), (9, 2, 7), (14, 7, 1)
)
arg_12 = torch.as_strided(
torch.randn(126).to(torch.float64), (9, 7, 2), (14, 2, 1)
)
arg_13 = torch.as_strided(
torch.randn(9).to(torch.float64), (9, 1, 1), (1, 1, 1)
)
arg_14 = torch.as_strided(
torch.randn(720).to(torch.float64), (9, 8, 10), (80, 10, 1)
)
arg_15 = torch.as_strided(
torch.randn(162).to(torch.float64), (9, 6, 3), (18, 3, 1)
)
arg_16 = torch.as_strided(
torch.randn(1152).to(torch.float64), (9, 8, 16), (128, 16, 1)
)
arg_17 = torch.as_strided(
torch.randn(126).to(torch.float64), (9, 2, 7), (14, 7, 1)
)
arg_18 = torch.as_strided(
torch.randn(819).to(torch.float64), (9, 7, 13), (91, 13, 1)
)
args = (
arg_0,
arg_1,
arg_2,
arg_3,
arg_4,
arg_5,
arg_6,
arg_7,
arg_8,
arg_9,
arg_10,
arg_11,
arg_12,
arg_13,
arg_14,
arg_15,
arg_16,
arg_17,
arg_18,
) + (sentinel,)
result_original = fuzzed_program(*args)
compiled_program = torch.compile(fuzzed_program, fullgraph=True, dynamic=True)
result_compiled = compiled_program(*args)
# Both should succeed without NameError
self.assertTrue(
torch.allclose(result_original, result_compiled, rtol=1e-5, atol=1e-5)
)
if __name__ == "__main__":
run_tests()

View File

@ -4524,17 +4524,6 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
run(torch.rand(2, 10), torch.rand(2, 10))
self.assertEqual(cnt.frame_count, 2)
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
def test_unbacked_view_extra(self):
def fn(x):
i0 = x.nonzero().size(0)
y = torch.zeros((i0, 192))
return y.view([12, -1, 192])
res1 = torch.compile(fn, fullgraph=True)(torch.ones((12,)))
res2 = fn(torch.ones((12,)))
self.assertEqual(res1, res2)
instantiate_parametrized_tests(TestUnbacked)

View File

@ -3755,44 +3755,6 @@ as the input tensor excluding its innermost dimension'):
with ctx:
self.assertEqual(torch.mean(t), expected)
def test_scalar_tensor_as_dim_argument(self):
"""Tests that scalar tensors work correctly as dimension arguments.
This tests the fix for the PythonArgParser bug where scalar Tensors
passed to IntList/SymIntList parameters would be incorrectly handled.
"""
x = torch.ones(1, 2, 3, 4, 5)
# Scalar tensors should work correctly (same as passing an int)
result_tensor = x.sum(dim=torch.tensor(3))
result_int = x.sum(dim=3)
self.assertEqual(result_tensor.shape, result_int.shape)
self.assertEqual(result_tensor.shape, torch.Size([1, 2, 3, 5]))
# Test with different integer dtypes
for dtype in [torch.int32, torch.int64, torch.int16, torch.int8]:
dim_tensor = torch.tensor(1, dtype=dtype)
result = x.sum(dim=dim_tensor)
expected = x.sum(dim=1)
self.assertEqual(result.shape, expected.shape)
@skipIfTorchDynamo("Test uses random.randint which creates FakeTensors")
def test_scalar_tensor_dim_compiled_mode(self):
"""Tests that scalar FakeTensors from random.randint work correctly in compiled mode."""
def foo():
x = torch.ones(2, 2, 2)
return x.sum(dim=random.randint(0, 0))
@torch.compile
def foo_compile():
x = torch.ones(2, 2, 2)
return x.sum(dim=random.randint(0, 0))
result_eager = foo()
result_compiled = foo_compile()
self.assertEqual(result_eager.shape, result_compiled.shape)
self.assertEqual(result_eager.shape, torch.Size([2, 2]))
instantiate_device_type_tests(TestReductions, globals())
if __name__ == '__main__':

View File

@ -2236,6 +2236,7 @@ class TestSparse(TestSparseBase):
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(torch.float32, torch.complex64)
@expectedFailureMPS
@skipIfCrossRef
def test_sparse_mask_backward(self, device, dtype):
from itertools import product, repeat
@ -2245,6 +2246,7 @@ class TestSparse(TestSparseBase):
nnzs = (0, 5, 15, 25)
lhs_data = torch.arange(1, 26, device=device).reshape(shape).to(dtype).to_sparse(sparse_dims)
rhs_data = lhs_data.clone()
for nnz in nnzs:
for lhs_is_coalesced, rhs_is_coalesced in product(*repeat((True, False), 2)):
@ -2264,9 +2266,8 @@ class TestSparse(TestSparseBase):
# sparsity_pattern(lhs) == sparsity_pattern(lhs.grad).
# lhs.sparse_mask(lhs_mask) accomplishes that.
lhs_mask = lhs.detach().clone()
gradcheck(lambda x: x.sparse_mask(lhs_mask).sparse_mask(rhs).to_dense(masked_grad=True), (lhs,),
masked=True, eps=3e-4, atol=5e-5)
gradcheck(lambda x: x.sparse_mask(rhs).to_dense(masked_grad=False), (lhs,), masked=False, eps=3e-4, atol=5e-5)
gradcheck(lambda x: x.sparse_mask(lhs_mask).sparse_mask(rhs).to_dense(masked_grad=True), (lhs,), masked=True)
gradcheck(lambda x: x.sparse_mask(rhs).to_dense(masked_grad=False), (lhs,), masked=False)
@coalescedonoff
@dtypes(torch.double, torch.cdouble)

View File

@ -1781,14 +1781,6 @@ class TestTorchDeviceType(TestCase):
self.assertEqual(b[0, :], d[0, :], atol=3e-5, rtol=3e-5)
self.assertEqual(b[-1, :], d[-1, :], atol=3e-5, rtol=3e-5)
@onlyCUDA
@largeTensorTest('48GB')
def test_cumsum_outer_dim_64bit_indexing(self, device):
x = torch.zeros(309504, 1, 16384, device=device)
torch.exp(x)
cumsum = torch.cumsum(x, dim=1)
self.assertEqual(cumsum.max().item(), 0., atol=0., rtol=0.)
@expectedFailureMeta # expected a non-deterministic error, but it was not raised
@onlyNativeDeviceTypes
def test_nondeterministic_alert_put(self, device):

View File

@ -2439,6 +2439,35 @@ class _TorchCompileInductorWrapper:
reset_cudagraph_trees()
class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
compiler_name = "aotinductor"
def __init__(self, mode, options, dynamic):
super().__init__(mode, options, dynamic)
self.apply_options({"cpp_wrapper": True})
self.apply_options({"aot_inductor.package": True})
def __call__(self, model_, inputs_):
from contextlib import nullcontext
from unittest import mock
from torch._guards import detect_fake_mode
from torch._inductor.virtualized import V
fake_mode = detect_fake_mode(inputs_)
ctx = (
mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
if fake_mode
else nullcontext()
)
with (
V.set_aot_compilation(True),
ctx,
torch._inductor.config.patch("enable_autograd_for_aot", True),
):
return super().__call__(model_, inputs_)
class _TorchCompileWrapper:
def __init__(self, backend, mode, options, dynamic):
from torch._dynamo.backends.registry import lookup_backend
@ -2672,8 +2701,10 @@ def compile(
backend = bisect_backend
guard_filter_fn = None
use_aoti = False
if options and isinstance(options, dict):
guard_filter_fn = options.pop("guard_filter_fn", None)
use_aoti = options.pop("use_aoti", False)
if torch.compiler.is_exporting():
warnings.warn(
@ -2700,7 +2731,10 @@ def compile(
return export_wrapped_fn
if backend == "inductor":
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
if use_aoti:
backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
else:
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
else:
backend = _TorchCompileWrapper(backend, mode, options, dynamic)

View File

@ -4562,8 +4562,6 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> b
@aten.matmul.out.py_impl(DispatchKey.CompositeImplicitAutograd)
@out_wrapper(pass_is_out=True)
def matmul(tensor1, tensor2, *, is_out=False):
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
dim_tensor1 = tensor1.dim()
dim_tensor2 = tensor2.dim()
assert dim_tensor1 != 0 and dim_tensor2 != 0
@ -4632,11 +4630,11 @@ def matmul(tensor1, tensor2, *, is_out=False):
if (
dim_tensor1 == 3
and dim_tensor2 == 3
and guard_or_true(batch_tensor1[0] != batch_tensor2[0])
and batch_tensor1[0] != batch_tensor2[0]
):
if guard_or_false(batch_tensor1[0] == 1) and tensor1.requires_grad:
if batch_tensor1[0] == 1 and tensor1.requires_grad:
return matmul(tensor1.squeeze(0), tensor2)
if guard_or_false(batch_tensor2[0] == 1) and tensor2.requires_grad:
if batch_tensor2[0] == 1 and tensor2.requires_grad:
return matmul(tensor1, tensor2.squeeze(0))
# expand the batch portion (i.e. cut off matrix dimensions and expand rest)
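The two lines changed here drop the guarded comparisons: guard_or_true(cond) evaluates cond like a guard but falls back to True when cond involves unbacked (data-dependent) symbols that cannot be decided, while guard_or_false falls back to False, so this branch never raised a data-dependent error. A small sketch of that behaviour (hedged; with plain ints both helpers reduce to ordinary booleans, and SymInts from a compiled region get the fallback):

from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

def pick_branch(batch1, batch2):
    # Undecidable comparisons fall back instead of raising.
    if guard_or_true(batch1 != batch2):
        if guard_or_false(batch1 == 1):
            return "squeeze lhs"
        if guard_or_false(batch2 == 1):
            return "squeeze rhs"
    return "expand batch dims"

print(pick_branch(1, 5))   # "squeeze lhs"
print(pick_branch(3, 3))   # "expand batch dims"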

View File

@ -53,6 +53,7 @@ class CompileArtifacts:
argdefs: Optional[tuple[Any, ...]]
source_info: "SourceInfo"
device_type: str
backend_name: str
system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
def check_compatibility(self) -> None:
@ -273,6 +274,7 @@ def aot_compile_fullgraph(
argdefs=fn.__defaults__,
source_info=source_info,
device_type=device_type,
backend_name=getattr(backend, "compiler_name", "unknown"),
)
aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)

View File

@ -499,6 +499,9 @@ def pytreeify(
root = mod.__self__
flat_real_args, in_spec = pytree.tree_flatten((args, kwargs))
torch._dynamo.eval_frame.check_user_input_output(
flat_real_args[1 if root else 0 :], UserErrorType.INVALID_INPUT
)
class Yield(Exception):
pass
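This is the input check the headline commit adds: the flattened user inputs are validated with check_user_input_output, which raises a UserError for unsupported input types instead of failing later during tracing. The test_export hunk earlier in this diff exercises it; a condensed sketch of the failing call (Block is assumed here to be a plain, unregistered dataclass of two tensors, matching the test's intent):

import dataclasses
import torch
from torch._dynamo.functional_export import dynamo_graph_capture_for_export

@dataclasses.dataclass
class Block:   # deliberately not registered as a pytree node
    a: torch.Tensor
    b: torch.Tensor

class Foo(torch.nn.Module):
    def forward(self, block):
        return block.a + block.b

# Expected to raise torch._dynamo.exc.UserError
# ("It looks like one of the inputs with type ...").
dynamo_graph_capture_for_export(Foo())(Block(torch.randn(4, 4), torch.randn(4, 4)))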

View File

@ -511,6 +511,7 @@ class GenericAOTAutogradResult(Generic[TForward, TBackward]):
).post_compile(
compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
)
compiled_fw_func._boxed_call = True
disable_amp = torch._C._is_any_autocast_enabled()
if needs_autograd:

View File

@ -33,7 +33,6 @@ from .graph_capture_wrappers import (
handle_effect_tokens_fn,
)
from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
from .streams import assign_backward_streams
from .utils import (
call_and_expect_output_descs,
copy_fwd_metadata_to_bw_nodes,
@ -474,9 +473,6 @@ def aot_dispatch_autograd_graph(
# fw node match might be erased
copy_fwd_metadata_to_bw_nodes(fx_g)
# After copying metadata, assign streams to gradient accumulation nodes
assign_backward_streams(fx_g)
fx_g.graph.eliminate_dead_code()
if not aot_config.disable_functionalization:
# There should be *NO* mutating ops in the graph at this point.

View File

@ -1,53 +0,0 @@
from typing import Optional, TypeAlias
import torch.fx
import torch.fx.traceback
from torch._dynamo.graph_utils import _get_flat_args
Node: TypeAlias = torch.fx.Node
def is_gradient_acc(node: Node) -> bool:
return node.meta.get("is_gradient_acc", False)
def get_stream(node: Node) -> Optional[int]:
maybe_annotation = node.meta.get("custom", None)
if maybe_annotation is not None:
return node.meta["custom"].get("stream", None)
else:
return None
def set_stream(node: Node, ind: int) -> None:
if "custom" in node.meta:
node.meta["custom"].update({"stream": ind})
else:
node.meta["custom"] = {"stream": ind}
def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
"""Assigns backward streams to gradient accumulation nodes"""
# NB: iterate in reverse order to more closely match eager
# the user node stream will be populated first
for node in reversed(list(gm.graph.nodes)):
if is_gradient_acc(node):
# Accumulation stream selection. Follow the rules from top to bottom to determine the accumulation stream:
# 1. Match first stream assignment of the first user with a stream
# 2. Match first stream assignment encountered in the args from left to right
# This differs from eager in some cases:
# Specifically the eager code uses the autograd node to determine the stream,
# crucially this does not necessarily correspond to the FX graph node. For example,
# in the backward for an add node with a constant we will pass through, and during backward tracing
# no op will be added to the FX graph, so our stream assignment will differ in this case.
gradients = _get_flat_args(node, {})
users = list(node.users.keys())
# All gradients will be on the same device; they are coerced with a .to() node if they were not
for neighbor in users + gradients:
ind = get_stream(neighbor)
if ind is not None:
set_stream(node, ind)
break
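For readers skimming the deleted helper: the stream index lives under node.meta["custom"]["stream"], which is all get_stream and set_stream read and write, and the accumulation node simply adopts the first index found among its users and args. A minimal sketch of that metadata plumbing on an ordinary FX graph (illustrative, outside of AOTAutograd):

import torch
import torch.fx

def f(x):
    return x + x

gm = torch.fx.symbolic_trace(f)
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")

add_node.meta.setdefault("custom", {})["stream"] = 1     # what set_stream(node, 1) wrote
print(add_node.meta.get("custom", {}).get("stream"))     # what get_stream(node) returned: 1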

View File

@ -1640,7 +1640,9 @@ class _InProcessFxCompile(FxCompile):
# pyrefly: ignore [unbound-name]
(str, list, torch.fx.GraphModule),
), type(compiled_fn)
return CompiledAOTI(compiled_fn)
return CompiledAOTI(
filename=compiled_fn, device_type=graph.device_type
)
# TODO: Hoist this above V.aot_compilation
# pyrefly: ignore [unbound-name]
@ -2713,7 +2715,7 @@ def _compile_fx_main(
or torch._guards.TracingContext(fake_mode)
)
if V.aot_compilation:
if V.aot_compilation and not config.enable_autograd_for_aot:
from .utils import is_valid_aoti_model_name
is_valid_aoti_model_name()

View File

@ -1193,6 +1193,8 @@ autotune_lookup_table: dict[str, dict[str, Any]] = {}
file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))
enable_autograd_for_aot: bool = False
def get_worker_log_path() -> Optional[str]:
log_loc = None

View File

@ -773,9 +773,83 @@ class CompiledAOTI(OutputCode):
"""
filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
device_type: str
current_callable: Optional[Callable[..., Any]] = None
_cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict)
def __post_init__(self):
if not config.aot_inductor.link_libtorch:
return
if (
torch._inductor.cpp_builder._IS_MACOS
or torch._inductor.cpp_builder._IS_WINDOWS
):
return
if config.aot_inductor.cross_target_platform == "windows":
return
if config.aot_inductor.package_cpp_only:
return
if isinstance(self.filename, list):
current_callable = next(
fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so")
)
else:
current_callable = self.filename
if isinstance(current_callable, torch.fx.GraphModule):
self.current_callable = current_callable
return
if self.device_type.startswith("cuda"):
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCuda( # type: ignore[call-arg]
current_callable,
1,
self.device_type,
"",
True,
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
elif self.device_type == "cpu":
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCpu( # type: ignore[call-arg]
current_callable, 1
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
else:
raise RuntimeError(f"unsupported device type {self.device_type}")
self.current_callable = current_callable
self._boxed_call = True
for file in self._cached_files:
if not os.path.exists(file):
with open(file, "wb") as f:
f.write(self._cached_files[file])
def __call__(self, inputs: Sequence[Any]) -> Any:
raise NotImplementedError("NYI")
if self.current_callable is None:
raise RuntimeError("AOTInductor compiled so is not loaded")
return self.current_callable(inputs)
def prepare_for_serialization(self) -> None:
self.current_callable = None
self._cached_files = {}
filenames: list[str] = []
if isinstance(self.filename, list):
filenames = self.filename # type: ignore[assignment]
elif isinstance(self.filename, str):
filenames = [self.filename]
for name in filenames:
with open(name, "rb") as f:
self._cached_files[name] = f.read()
def __getstate__(self):
state = self.__dict__.copy()
state["current_callable"] = None
return state
def post_compile(
self,
@ -783,10 +857,8 @@ class CompiledAOTI(OutputCode):
constants: CompiledFxGraphConstants,
graph_kwargs: _CompileFxKwargs,
) -> None:
pass
def prepare_for_serialization(self) -> None:
pass
if self.current_callable is None:
self.__post_init__()
def set_triton_bundle(self, triton_bundle: Any) -> None:
pass

View File

@ -1751,7 +1751,7 @@ static PyObject* THPVariable_dtensor_new(
Tensor tensor = make_tensor_for_subclass_helper(
/*sym_sizes=*/tuple_to_symintlist(sizes.ptr()),
/*sym_strides=*/tuple_to_symintlist(stride.ptr()),
/*sym_storage_offset=*/local_tensor.sym_storage_offset(),
/*sym_storage_offset=*/std::nullopt,
options,
/*storage_size=*/std::nullopt,
extra_dispatch_keys);

View File

@ -66,6 +66,12 @@ void initAOTIRunnerBindings(PyObject* module) {
int,
const std::string&,
const std::string&>())
.def(py::init<
const std::string&,
int,
const std::string&,
const std::string&,
const bool>())
.def(
"run",
&AOTIModelContainerRunnerCuda::run,

View File

@ -565,16 +565,8 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
return std::vector<c10::SymInt>(size1, si);
}
if (size1 > 0 && THPVariable_Check(args[i])) {
return std::vector<c10::SymInt>(
size1, THPVariable_Unpack(args[i]).item().toSymInt());
}
PyObject* arg = args[i];
auto tuple = PyTuple_Check(arg);
if (!tuple) {
TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
}
// NOLINTNEXTLINE(bugprone-branch-clone)
const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
std::vector<c10::SymInt> res;
@ -653,13 +645,7 @@ inline std::vector<int64_t> PythonArgs::intlistWithDefault(
if (size1 > 0 && torch::is_dynint(py::handle(arg))) {
return std::vector<int64_t>(size1, py::handle(arg).cast<int>());
}
if (size1 > 0 && THPVariable_Check(arg)) {
return std::vector<int64_t>(size1, THPVariable_Unpack(arg).item<int64_t>());
}
auto tuple = PyTuple_Check(arg);
if (!tuple) {
TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
}
// NOLINTNEXTLINE(bugprone-branch-clone)
const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
std::vector<int64_t> res(size2);
@ -730,9 +716,6 @@ inline c10::OptionalArray<c10::SymInt> PythonArgs::symintlistOptional(int i) {
inline std::vector<double> PythonArgs::getDoublelist(int i) {
PyObject* arg = args[i];
auto tuple = PyTuple_Check(arg);
if (!tuple) {
TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
}
// NOLINTNEXTLINE(bugprone-branch-clone)
auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
std::vector<double> res(size);
@ -906,9 +889,6 @@ inline at::Dimname PythonArgs::dimname(int i) {
inline std::vector<at::Dimname> parseDimnameList(PyObject* arg) {
auto tuple = PyTuple_Check(arg);
if (!tuple) {
TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
}
// NOLINTNEXTLINE(bugprone-branch-clone)
auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
std::vector<at::Dimname> res;

View File

@ -7037,16 +7037,52 @@ class ShapeEnv:
ok = len(free_unbacked_symbols(new_var)) == 0
if ok:
self._set_replacement(free[0], new_var, "solve")
except NotImplementedError:
pass
else:
# expression has mod.
if expr.has(Mod):
mod_expr = next(iter(expr.atoms(Mod)))
try:
r = try_solve(expr, mod_expr, floordiv_inequality=False)
if r is not None and r[1] == 0:
self._add_divisible(mod_expr)
# This is a little bit of extra logic to make things like
# torch.empty(i0, q).view(c, -1, q) work out
p, q = mod_expr.args
if (
isinstance(q, sympy.Number)
and isinstance(p, sympy.Mul)
and len(p.args) == 2
):
c, i0 = p.args
# Given Mod(c * i0, q) == 0
if (
isinstance(c, sympy.Number)
and isinstance(i0, sympy.Symbol)
and self.is_unbacked_symint(i0)
):
# We have Mod(i0, q / c) == 0, which means we can
# rewrite i0 as (q / gcd(q, c)) * i1
d = q / sympy.gcd(q, c) # TODO: CleanDiv?
i1 = self.create_unbacked_symint().node.expr
# Propagate the value ranges. It doesn't really
# matter if we use truediv or floordiv, because we
# have established divisibility.
self._update_var_to_range(
i1,
SymPyValueRangeAnalysis.floordiv(
self.var_to_range[i0], ValueRanges.wrap(d)
),
)
# Propagate hints (real tensor tracing)
if i0 in self.unbacked_var_to_val:
self.set_unbacked_var_to_val(
i1, self.unbacked_var_to_val[i0] // d
)
# Propagate size-like-ness
if i0 in self.size_like:
self.size_like.add(i1)
self._set_replacement(i0, d * i1, "divisibility")
except NotImplementedError:
pass
return
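A tiny worked instance of the divisibility rewrite added in this hunk, with illustrative numbers: given Mod(c*i0, q) == 0 with c = 2 and q = 6, the code sets d = q / gcd(q, c) = 3 and replaces i0 with 3*i1, after which the modulus is identically zero:

import sympy

c, q = sympy.Integer(2), sympy.Integer(6)
i1 = sympy.Symbol("i1", integer=True, positive=True)   # stands in for the fresh unbacked symint

d = q / sympy.gcd(q, c)          # 6 / 2 = 3
i0_replacement = d * i1          # the "divisibility" replacement: i0 := 3*i1
print(sympy.simplify(sympy.Mod(c * i0_replacement, q)))  # 0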

View File

@ -39,7 +39,7 @@ import os
import traceback
import weakref
from collections.abc import Callable
from typing import Any, Optional, TYPE_CHECKING, Union # noqa: F401
from typing import Any, Optional, TYPE_CHECKING # noqa: F401
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@ -157,25 +157,21 @@ def _arg_to_str(arg, attributes, tensor_memo=None) -> str:
return str(arg)
def norm_hash_fn(
t: torch.Tensor, use_scalar: bool = False
) -> Union[torch.Tensor, float]:
def default_hash_fn(t: torch.Tensor, use_scalar: bool = False) -> torch.Tensor:
"""
from Observer. Computes a hash for a tensor by converting it to float (if needed), making it contiguous,
replacing NaN/inf values with fixed numbers, and then computing the L1 norm in float64 or complex128.
This is used to generate a deterministic summary value for tensor comparison.
"""
with torch._C._DisablePythonDispatcher():
with torch._C._DisablePythonDispatcher(), torch._C._DisableTorchDispatch():
if not (t.is_floating_point() or t.is_complex()):
t = t.float()
t = t.contiguous()
# Clean the tensor to handle NaN/inf values, then compute norm
t_clean = torch.nan_to_num(t, nan=0.0, posinf=1.0, neginf=-1.0)
if t.is_complex():
t_float = t.to(dtype=torch.complex128)
else:
t_float = t.to(dtype=torch.float64)
out = t_float.norm(p=1)
dtype = torch.complex128 if t.is_complex() else torch.float64
out = t_clean.norm(p=1, dtype=dtype)
if use_scalar:
return out.item()
return out
@ -188,28 +184,6 @@ def _compute_rel_diff(hash1, hash2):
return numerator / denominator
def hash_tensor_fn(
t: torch.Tensor, use_scalar: bool = False
) -> Union[torch.Tensor, int]:
"""
wrapper over torch.hash_tensor
"""
if isinstance(t, torch.distributed.tensor.DTensor):
t = t.to_local()
if t.is_floating_point():
t_clean = t.to(dtype=torch.float64)
elif t.is_complex():
t_clean = t.to(dtype=torch.complex128).view(torch.float64)
else:
t_clean = t.to(dtype=torch.int64)
out = torch.hash_tensor(t_clean)
if use_scalar:
return out.item() # type: ignore[attribute]
return out
def _get_stack_trace() -> str:
from torch.fx.experimental.symbolic_shapes import uninteresting_files
@ -923,43 +897,20 @@ class DebugMode(TorchDispatchMode):
@staticmethod
@contextlib.contextmanager
def log_tensor_hashes(
hash_fn: Union[Callable, str, list[str]] = "norm", hash_inputs: bool = False
):
def log_tensor_hashes(hash_fn: Callable | None = None, hash_inputs: bool = False):
"""
Installs hook for tensor hash logging.
hash_fn: One of:
- Custom-defined hash function
- String: one of ("norm", "hash_tensor")
- "norm": uses norm_hash_fn; basically tensor's L1 norm
- "hash_tensor": uses torch.hash_tensor (XOR sum reduction)
- List of strings: returns tuple of hashes from above options
hash_fn: optional function for custom hashing
hash_inputs: if True, also hashes tensors in (args, kwargs), storing them in "input_hash".
NOTE: this is currently a post-hook, so e.g. inplace ops will log the "output" hashes.
"""
def hash_fn_option(hash_type):
assert isinstance(hash_type, str) and hash_type in ["norm", "hash_tensor"]
return functools.partial(
norm_hash_fn if hash_type == "norm" else hash_tensor_fn, use_scalar=True
)
if callable(hash_fn):
fn = hash_fn
elif isinstance(hash_fn, str):
fn = hash_fn_option(hash_fn)
elif isinstance(hash_fn, list):
fns = [hash_fn_option(fn) for fn in hash_fn]
fn = lambda x: tuple(fn(x) for fn in fns) # noqa: E731
else:
raise NotImplementedError(
f"log_tensor_hashes() expected hash_fn to be callable, str, or list[str], but found {type(hash_fn)}"
)
if hash_fn is None:
hash_fn = functools.partial(default_hash_fn, use_scalar=True)
def _tree_hash(obj):
return tree_map(
lambda x: fn(x) if isinstance(x, torch.Tensor) else None, obj
lambda x: hash_fn(x) if isinstance(x, torch.Tensor) else None, obj
)
def _dispatch_hash_hook(func, types, args, kwargs, result):
@ -979,9 +930,9 @@ class DebugMode(TorchDispatchMode):
try:
if hash_inputs:
_old_input_hfn = _TRITON_INPUT_HASH_FN
_TRITON_INPUT_HASH_FN = fn
_TRITON_INPUT_HASH_FN = hash_fn
_old_output_hfn = _TRITON_OUTPUT_HASH_FN
_TRITON_OUTPUT_HASH_FN = fn
_TRITON_OUTPUT_HASH_FN = hash_fn
with DebugMode.dispatch_hooks(log_hook=_dispatch_hash_hook):
yield
finally:
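With the string and list options for hash_fn gone in this version of log_tensor_hashes, customization goes through a plain callable. A minimal usage sketch against the signature shown above (assumes real-valued tensors; the hash function here is my own stand-in, not a library helper):

import torch
from torch.utils._debug_mode import DebugMode

def l1_hash(t: torch.Tensor) -> float:
    # Rough analogue of default_hash_fn(use_scalar=True): an L1-norm summary.
    return t.detach().to(torch.float64).abs().sum().item()

with DebugMode() as debug_mode, DebugMode.log_tensor_hashes(hash_fn=l1_hash, hash_inputs=True):
    torch.mm(torch.randn(4, 4), torch.randn(4, 4))

print(debug_mode.debug_string())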