Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-16 07:24:54 +08:00)

Compare commits (1 commit): ciflow/tru ... zhxchen17/

| Author | SHA1 | Date |
|---|---|---|
| | 33f776b894 | |
@@ -1,7 +1,6 @@
#pragma once

#include <ATen/native/CompositeRandomAccessorCommon.h>
#include <thrust/swap.h>
#include <thrust/tuple.h>

namespace at { namespace native {

@@ -267,15 +267,15 @@ void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, con
 * outer dimensions, which contains several "inner rows").
 * Each thread processes a single inner row at a time.
 */
template<typename scalar_t, typename index_t, class BinaryOp>
template<typename scalar_t, class BinaryOp>
__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_,
    const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
    const scalar_t init, BinaryOp binary_op)
{
  for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
    for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
      const scalar_t *src = src_ + static_cast<index_t>(orow) * row_size * num_irows + irow;
      scalar_t *tgt = tgt_ + (index_t) orow * row_size * num_irows + irow;
      const scalar_t *src = src_ + orow * row_size * num_irows + irow;
      scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
      scalar_t acc = init;

      for (uint32_t col = 0; col < row_size; ++col) {
@@ -409,15 +409,10 @@ __host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(row_size, "row_size");
  if (static_cast<size_t>(num_irows) * num_orows * row_size <= UINT_MAX) {
    tensor_kernel_scan_outer_dim<scalar_t, uint32_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
  tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
      num_orows, num_irows, row_size, init, binary_op);
  } else {
    tensor_kernel_scan_outer_dim<scalar_t, size_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
        num_orows, num_irows, row_size, init, binary_op);
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

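Note: the scan_outer_dim hunk above toggles between a single 32-bit launch and a uint32_t/size_t dispatch, and the case it matters for is the large-tensor cumsum covered by test_cumsum_outer_dim_64bit_indexing further down in this diff. A minimal sketch of that scenario, assuming a CUDA device with roughly 48 GB free:

```python
import torch

# Sketch only: 309504 * 1 * 16384 elements exceeds 2**32, so a flattened
# offset such as orow * row_size * num_irows overflows uint32_t unless the
# kernel is instantiated with a 64-bit index type.
x = torch.zeros(309504, 1, 16384, device="cuda")
assert x.numel() > 2**32
cumsum = torch.cumsum(x, dim=1)   # scan over the size-1 outer dimension
print(cumsum.max().item())        # expected 0.0 when indexing is correct
```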
@@ -7518,7 +7518,7 @@
- func: _sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor
  variants: method
  dispatch:
    SparseCPU, SparseCUDA, SparseMPS: sparse_mask_projection
    SparseCPU, SparseCUDA: sparse_mask_projection
  autogen: _sparse_mask_projection.out

- func: _to_cpu(Tensor[] tensors) -> Tensor[]
@@ -30,12 +30,10 @@

#include <thrust/binary_search.h>
#include <thrust/device_ptr.h>
#include <thrust/distance.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>

#include <cuda_runtime_api.h>
#include <cusparse.h>

@@ -445,33 +445,6 @@ static SparseTensor& mul_out_dense_sparse_mps(
  return out;
}

static std::tuple<Tensor, Tensor, int64_t> mps_intersect_binary_search(
    const Tensor& A_keys,
    const Tensor& B_keys,
    int64_t lenA,
    int64_t lenB,
    bool boolean_flag) {

  auto stream = getCurrentMPSStream();
  auto outA_idx = at::empty({lenA}, A_keys.options().dtype(at::kLong));
  auto outB_idx = at::empty({lenA}, A_keys.options().dtype(at::kLong));
  auto counter = at::zeros({1}, A_keys.options().dtype(at::kInt));

  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
      auto enc = stream->commandEncoder();
      [enc setComputePipelineState:pso];
      mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
                  static_cast<uint32_t>(lenB), boolean_flag);
      mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
    }
  });

  const auto match_count = static_cast<int64_t>(counter.item<int32_t>());
  return std::make_tuple(std::move(outA_idx), std::move(outB_idx), match_count);
}


SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTensor& r_) {
  TORCH_CHECK(r_.is_mps(), "mul: expected 'out' to be MPS, but got ", r_.device());
@@ -550,10 +523,22 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
  auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
  auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;

  auto [outA_idx, outB_idx, M_int64] = mps_intersect_binary_search(
      A_keys, B_keys, lenA, lenB, A_is_lhs);
  auto outA_idx = at::empty({lenA}, at::device(device).dtype(kLong));
  auto outB_idx = at::empty({lenA}, at::device(device).dtype(kLong));
  auto counter = at::zeros({1}, at::device(device).dtype(kInt));

  const auto M = static_cast<uint32_t>(M_int64); // number of structural matches
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
      auto enc = stream->commandEncoder();
      [enc setComputePipelineState:pso];
      mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
                  static_cast<uint32_t>(lenB), A_is_lhs);
      mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
    }
  });

  const uint32_t M = counter.item<int32_t>(); // number of structural matches

  r_.resize_as_(lhs);

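Note: the MPS hunks above and below all revolve around the same structural-intersection step: coalesced COO indices are flattened into scalar keys, the shorter key list (A) is binary-searched against the longer sorted one (B), and the matching positions on both sides plus a match count are gathered. A rough CPU-side Python sketch of that idea follows; it is not the Metal kernel itself, and the helper name is illustrative only:

```python
import torch

def intersect_sorted_keys(A_keys: torch.Tensor, B_keys: torch.Tensor):
    """For each key on the (shorter) A side, locate it in the sorted,
    non-empty B side and keep only exact matches. Assumes both inputs are
    1-D int64 key tensors, as produced by flattening coalesced COO indices."""
    pos = torch.searchsorted(B_keys, A_keys)       # candidate slot in B
    pos = pos.clamp(max=B_keys.numel() - 1)
    hit = B_keys[pos] == A_keys                    # true structural match
    outA_idx = torch.nonzero(hit, as_tuple=False).flatten()
    outB_idx = pos[hit]
    return outA_idx, outB_idx, int(hit.sum())      # match count plays the role of M

# toy usage: keys 3 and 7 are common to both sides
A = torch.tensor([1, 3, 7, 9])
B = torch.tensor([0, 3, 4, 7, 8])
print(intersect_sorted_keys(A, B))  # (tensor([1, 2]), tensor([1, 3]), 2)
```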
@@ -777,14 +762,6 @@ SparseTensor& add_out_sparse_mps(const SparseTensor& self,

using OptTensor = std::optional<Tensor>;

static Tensor create_sparse_output_values(
    const Tensor& template_values,
    int64_t output_nnz,
    ScalarType dtype) {
  auto out_val_sizes = template_values.sizes().vec();
  out_val_sizes[0] = output_nnz;
  return at::zeros(out_val_sizes, template_values.options().dtype(dtype));
}

static void sparse_mask_apply_out_mps_kernel(
    Tensor& result,
@@ -806,9 +783,9 @@ static void sparse_mask_apply_out_mps_kernel(
  auto src = src_in.coalesce();
  auto mask = coalesce_mask ? mask_in.coalesce() : mask_in;

  const auto src_nnz = src._nnz();
  const auto mask_nnz = mask._nnz();
  const auto sd = src.sparse_dim();
  const int64_t src_nnz = src._nnz();
  const int64_t mask_nnz = mask._nnz();
  const int64_t sd = src.sparse_dim();
  result.sparse_resize_(mask.sizes(), mask.sparse_dim(), mask.dense_dim());

  auto commonDtype = at::result_type(src, mask);
@@ -837,27 +814,53 @@ static void sparse_mask_apply_out_mps_kernel(
    return;
  }

  auto mask_indices = mask._indices().contiguous();
  auto src_values = src._values().to(commonDtype).contiguous();
  auto out_values = create_sparse_output_values(src_values, mask_nnz, commonDtype);

  if (src_nnz == 0) {
    alias_into_sparse(result, mask_indices, out_values);
    auto out_indices = mask._indices().contiguous();
    auto src_values = src._values().to(commonDtype);
    auto out_val_sizes = src_values.sizes().vec();
    out_val_sizes[0] = mask_nnz;
    auto out_values = at::zeros(out_val_sizes, src_values.options());
    alias_into_sparse(result, out_indices, out_values);
    result._coalesced_(mask.is_coalesced());
    return;
  }

  auto mask_keys = flatten_indices(mask._indices().contiguous(), mask.sizes().slice(0, sd)).contiguous();
  auto src_keys = flatten_indices(src._indices().contiguous(), src.sizes().slice(0, sd)).contiguous();
  auto mask_indices = mask._indices().contiguous();
  auto src_indices = src._indices().contiguous();
  auto src_values = src._values().to(commonDtype).contiguous();

  const auto A_is_src = (src_nnz <= mask_nnz);
  const auto lenA = A_is_src ? src_nnz : mask_nnz;
  const auto lenB = A_is_src ? mask_nnz : src_nnz;
  auto mask_keys = flatten_indices(mask_indices, mask.sizes().slice(0, sd)).contiguous();
  auto src_keys = flatten_indices(src_indices, src.sizes().slice(0, sd)).contiguous();

  const bool A_is_src = (src_nnz <= mask_nnz);
  const int64_t lenA = A_is_src ? src_nnz : mask_nnz;
  const int64_t lenB = A_is_src ? mask_nnz : src_nnz;
  auto A_keys = A_is_src ? src_keys : mask_keys;
  auto B_keys = A_is_src ? mask_keys : src_keys;

  auto [outA_idx, outB_idx, M] = mps_intersect_binary_search(
      A_keys, B_keys, lenA, lenB, A_is_src);
  const auto device = result.device();
  auto stream = getCurrentMPSStream();

  auto outA_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
  auto outB_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
  auto counter = at::zeros({1}, at::device(device).dtype(at::kInt));

  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
      auto enc = stream->commandEncoder();
      [enc setComputePipelineState:pso];
      mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
                  static_cast<uint32_t>(lenB), A_is_src);
      mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
    }
  });

  const int64_t M = static_cast<int64_t>(counter.item<int32_t>());

  auto out_val_sizes = src_values.sizes().vec();
  out_val_sizes[0] = mask_nnz;
  auto out_values = at::zeros(out_val_sizes, src_values.options());

  if (M > 0) {
    auto src_match = outA_idx.narrow(0, 0, M);
@ -875,70 +878,6 @@ static void sparse_mask_apply_out_mps_kernel(
|
||||
result._coalesced_(mask.is_coalesced());
|
||||
}
|
||||
|
||||
static void sparse_mask_projection_out_mps_kernel(
|
||||
Tensor& result,
|
||||
const Tensor& lhs,
|
||||
const Tensor& rhs,
|
||||
const OptTensor& /*x_hash_opt*/,
|
||||
bool accumulate_matches) {
|
||||
|
||||
TORCH_CHECK(lhs.is_sparse() && rhs.is_sparse(), "sparse_mask_projection: expected sparse COO");
|
||||
TORCH_CHECK(lhs.is_mps() && rhs.is_mps(), "sparse_mask_projection: expected MPS tensors");
|
||||
TORCH_CHECK(lhs.sparse_dim() == rhs.sparse_dim(), "sparse_dim mismatch");
|
||||
|
||||
auto lhs_c = lhs.coalesce();
|
||||
auto rhs_c = rhs.coalesce();
|
||||
|
||||
const auto sd = lhs_c.sparse_dim();
|
||||
const auto lhs_nnz = lhs_c._nnz();
|
||||
const auto rhs_nnz = rhs_c._nnz();
|
||||
|
||||
auto commonDtype = at::result_type(lhs_c, rhs_c);
|
||||
TORCH_CHECK(canCast(commonDtype, result.scalar_type()),
|
||||
"Can't convert ", commonDtype, " to output ", result.scalar_type());
|
||||
|
||||
result.sparse_resize_(lhs.sizes(), lhs.sparse_dim(), lhs.dense_dim());
|
||||
|
||||
auto lhs_indices = lhs_c._indices().contiguous();
|
||||
auto rhs_values = rhs_c._values().to(commonDtype).contiguous();
|
||||
auto out_values = create_sparse_output_values(rhs_values, lhs_nnz, commonDtype);
|
||||
|
||||
if (lhs_nnz > 0 && rhs_nnz > 0) {
|
||||
auto lhs_keys = flatten_indices(lhs_indices, lhs_c.sizes().slice(0, sd)).contiguous();
|
||||
auto rhs_keys = flatten_indices(rhs_c._indices().contiguous(), rhs_c.sizes().slice(0, sd)).contiguous();
|
||||
|
||||
const auto A_is_lhs = (lhs_nnz <= rhs_nnz);
|
||||
const auto lenA = A_is_lhs ? lhs_nnz : rhs_nnz;
|
||||
const auto lenB = A_is_lhs ? rhs_nnz : lhs_nnz;
|
||||
auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
|
||||
auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
|
||||
|
||||
auto [outA_idx, outB_idx, M] = mps_intersect_binary_search(
|
||||
A_keys, B_keys, lenA, lenB, A_is_lhs);
|
||||
|
||||
if (M > 0) {
|
||||
auto idx_in_A = outA_idx.narrow(0, 0, M);
|
||||
auto idx_in_B = outB_idx.narrow(0, 0, M);
|
||||
auto idx_in_lhs = A_is_lhs ? idx_in_A : idx_in_B;
|
||||
auto idx_in_rhs = A_is_lhs ? idx_in_B : idx_in_A;
|
||||
|
||||
const auto view_cols = rhs_values.numel() / std::max<int64_t>(rhs_nnz, 1);
|
||||
auto rhs_rows = rhs_values.index_select(0, idx_in_rhs).contiguous();
|
||||
auto rhs_rows_2d = rhs_rows.view({M, view_cols});
|
||||
auto out_2d = out_values.view({lhs_nnz, view_cols});
|
||||
|
||||
if (accumulate_matches) {
|
||||
out_2d.index_add_(0, idx_in_lhs, rhs_rows_2d);
|
||||
} else {
|
||||
out_2d.index_copy_(0, idx_in_lhs, rhs_rows_2d);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
alias_into_sparse(result, lhs._indices(), out_values);
|
||||
result._coalesced_(lhs.is_coalesced());
|
||||
}
|
||||
|
||||
static void sparse_mask_intersection_out_mps_kernel(
|
||||
Tensor& result,
|
||||
const Tensor& lhs,
|
||||
@@ -1063,5 +1002,4 @@ Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
}

REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
REGISTER_MPS_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_mps_kernel);
} // namespace at::native
@ -31,8 +31,6 @@ from torch.utils._debug_mode import (
|
||||
_RedistributeCall,
|
||||
_TritonKernelCall,
|
||||
DebugMode,
|
||||
hash_tensor_fn,
|
||||
norm_hash_fn,
|
||||
)
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch.utils._triton import has_triton_package
|
||||
@ -117,28 +115,6 @@ class TestDTensorDebugMode(TestCase):
|
||||
"aten::sum(t: f32[1, 32]) # {'hash': " in debug_mode.debug_string()
|
||||
)
|
||||
|
||||
# check tuple hash functions
|
||||
with (
|
||||
DebugMode() as debug_mode,
|
||||
DebugMode.log_tensor_hashes(hash_fn=["norm", "hash_tensor"]),
|
||||
):
|
||||
mm(x_dtensor, y_dtensor)
|
||||
|
||||
output_hash = debug_mode.operators[-1].log["hash"]
|
||||
norm_ = lambda x: norm_hash_fn(x, use_scalar=True) # noqa: E731
|
||||
hash_ = lambda x: hash_tensor_fn(x, use_scalar=True) # noqa: E731
|
||||
|
||||
self.assertEqual(output_hash[0], norm_(eager_out))
|
||||
self.assertEqual(output_hash[1], hash_(eager_out))
|
||||
|
||||
# some edge cases
|
||||
self.assertEqual(norm_(torch.tensor(torch.nan)), torch.nan)
|
||||
self.assertEqual(norm_(torch.tensor(torch.inf)), torch.inf)
|
||||
self.assertEqual(norm_(torch.complex(torch.ones(4), torch.zeros(4))), 4)
|
||||
self.assertEqual(hash_(torch.ones(4, dtype=torch.float8_e5m2)), 0)
|
||||
self.assertEqual(hash_(torch.ones(4, dtype=torch.int8)), 0)
|
||||
self.assertEqual(hash_(torch.ones(5, dtype=torch.int8)), 1)
|
||||
|
||||
def test_debug_string_inside_context(self):
|
||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
||||
|
||||
@ -664,101 +664,6 @@ class TestViewOps(DTensorTestBase):
|
||||
)
|
||||
self.assertEqual(dist_x.placements, [Partial(), Shard(0)])
|
||||
|
||||
@with_comms
|
||||
def test_storage_offset_slice(self):
|
||||
"""
|
||||
Test that storage_offset is properly tracked on DTensor when slicing
|
||||
a replicated tensor.
|
||||
"""
|
||||
mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
|
||||
# Create a replicated DTensor
|
||||
tensor = torch.randn(10, device=self.device_type)
|
||||
dtensor = distribute_tensor(tensor, mesh, [Replicate()])
|
||||
|
||||
# Perform a slice operation [1:]
|
||||
with CommDebugMode() as comm_mode:
|
||||
sliced_dtensor = dtensor[1:]
|
||||
# Slicing should not trigger any communication
|
||||
self.assertEqual(comm_mode.get_total_counts(), 0)
|
||||
|
||||
# Verify that the DTensor's storage_offset matches the expected value
|
||||
self.assertEqual(sliced_dtensor.storage_offset(), 1)
|
||||
|
||||
# Verify that the local tensor also has the correct storage_offset
|
||||
self.assertEqual(sliced_dtensor.to_local().storage_offset(), 1)
|
||||
|
||||
# Verify the shape is correct
|
||||
self.assertEqual(sliced_dtensor.shape, torch.Size([9]))
|
||||
|
||||
# Verify the values are correct
|
||||
expected = tensor[1:]
|
||||
self.assertEqual(sliced_dtensor.full_tensor(), expected)
|
||||
|
||||
@with_comms
|
||||
def test_storage_offset_shard_dim0_slice_dim1(self):
|
||||
"""
|
||||
Test that storage_offset is properly tracked when tensor is sharded on dim 0
|
||||
and sliced on dim 1.
|
||||
"""
|
||||
mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
|
||||
# Create a 2D tensor and shard on dim 0
|
||||
tensor = torch.randn(12, 8, device=self.device_type)
|
||||
dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
|
||||
|
||||
# Perform a slice operation [:, 2:]
|
||||
with CommDebugMode() as comm_mode:
|
||||
sliced_dtensor = dtensor[:, 2:]
|
||||
# Slicing should not trigger any communication
|
||||
self.assertEqual(comm_mode.get_total_counts(), 0)
|
||||
|
||||
# The storage_offset should be 2 (skipping 2 elements in each row)
|
||||
self.assertEqual(sliced_dtensor.storage_offset(), 2)
|
||||
|
||||
# Verify that the local tensor also has the correct storage_offset
|
||||
self.assertEqual(sliced_dtensor.to_local().storage_offset(), 2)
|
||||
|
||||
# Verify the shape is correct
|
||||
expected_shape = torch.Size([12, 6])
|
||||
self.assertEqual(sliced_dtensor.shape, expected_shape)
|
||||
|
||||
# Verify the values are correct
|
||||
expected = tensor[:, 2:]
|
||||
self.assertEqual(sliced_dtensor.full_tensor(), expected)
|
||||
|
||||
@with_comms
|
||||
def test_storage_offset_shard_dim1_slice_dim0(self):
|
||||
"""
|
||||
Test that storage_offset is properly tracked when tensor is sharded on dim 1
|
||||
and sliced on dim 0.
|
||||
"""
|
||||
mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
|
||||
# Create a 2D tensor and shard on dim 1
|
||||
tensor = torch.randn(10, 12, device=self.device_type)
|
||||
dtensor = distribute_tensor(tensor, mesh, [Shard(1)])
|
||||
|
||||
# Perform a slice operation [2:, :]
|
||||
with CommDebugMode() as comm_mode:
|
||||
sliced_dtensor = dtensor[2:, :]
|
||||
# Slicing should not trigger any communication
|
||||
self.assertEqual(comm_mode.get_total_counts(), 0)
|
||||
|
||||
local_dim1_size = 12 // self.world_size
|
||||
expected_offset = 2 * local_dim1_size
|
||||
self.assertEqual(sliced_dtensor.storage_offset(), expected_offset)
|
||||
|
||||
self.assertEqual(sliced_dtensor.to_local().storage_offset(), expected_offset)
|
||||
|
||||
# Verify the shape is correct
|
||||
expected_shape = torch.Size([8, 12])
|
||||
self.assertEqual(sliced_dtensor.shape, expected_shape)
|
||||
|
||||
# Verify the values are correct
|
||||
expected = tensor[2:, :]
|
||||
self.assertEqual(sliced_dtensor.full_tensor(), expected)
|
||||
|
||||
|
||||
TestViewOpsWithLocalTensor = create_local_tensor_test_class(
|
||||
TestViewOps,
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -13,13 +15,16 @@ import torch._inductor.config
|
||||
import torch._inductor.test_case
|
||||
import torch.onnx.operators
|
||||
import torch.utils.cpp_extension
|
||||
from torch._dynamo.aot_compile import ModelInput, SerializableCallable
|
||||
from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable
|
||||
from torch._dynamo.exc import PackageError, Unsupported
|
||||
from torch._dynamo.package import DynamoCache
|
||||
from torch._dynamo.precompile_context import PrecompileContext
|
||||
from torch._inductor.runtime.runtime_utils import cache_dir
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
from torch.testing._internal.common_utils import instantiate_parametrized_tests
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
TEST_CUDA,
|
||||
)
|
||||
|
||||
|
||||
MY_LAMBDA = lambda x: x + 1 # noqa: E731
|
||||
@ -599,6 +604,92 @@ from user code:
|
||||
actual = compiled_fn(*inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "requires cuda")
|
||||
def test_aot_compile_with_aoti(self):
|
||||
with torch.device("cuda"):
|
||||
from torch._dynamo.hooks import Hooks
|
||||
|
||||
def fn(x, y):
|
||||
return x + y
|
||||
|
||||
def make_inputs():
|
||||
return (torch.randn(3, 4), torch.randn(3, 4))
|
||||
|
||||
compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph(
|
||||
fn,
|
||||
(make_inputs(), {}),
|
||||
Hooks(),
|
||||
torch._TorchCompileAOTInductorWrapper(None, None, None),
|
||||
)
|
||||
|
||||
test_inputs = make_inputs()
|
||||
expected = fn(*test_inputs)
|
||||
actual = compiled_fn(*test_inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
compiled_fn.save_compiled_function(self.path())
|
||||
with open(self.path(), "rb") as f:
|
||||
compiled_fn = torch.compiler.load_compiled_function(f)
|
||||
actual = compiled_fn(*test_inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "requires cuda")
|
||||
def test_aot_compile_with_aoti_module(self):
|
||||
with torch.device("cuda"):
|
||||
from torch._dynamo.hooks import Hooks
|
||||
|
||||
mod = SimpleLinearModule()
|
||||
|
||||
def make_inputs():
|
||||
return (torch.randn(4, 3),)
|
||||
|
||||
compiled_mod = torch._dynamo.aot_compile.aot_compile_module(
|
||||
mod,
|
||||
[ModelInput(make_inputs(), {}, [])],
|
||||
Hooks(),
|
||||
torch._TorchCompileAOTInductorWrapper(None, None, None),
|
||||
)
|
||||
|
||||
def get_grads(m: torch.nn.Module):
|
||||
return {name: p.grad for name, p in m.named_parameters()}
|
||||
|
||||
original_mod = copy.deepcopy(mod)
|
||||
test_inputs = make_inputs()
|
||||
expected = mod(*test_inputs)
|
||||
expected.sum().backward()
|
||||
expected_grads = get_grads(mod)
|
||||
|
||||
actual = compiled_mod(*test_inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
serialized = compiled_mod.serialize()
|
||||
compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized)
|
||||
actual = compiled_fn(*test_inputs)
|
||||
actual.sum().backward()
|
||||
self.assertEqual(get_grads(original_mod), expected_grads)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "requires cuda")
|
||||
def test_aot_compile_with_aoti_torch_compile(self):
|
||||
with torch.device("cuda"):
|
||||
|
||||
def fn(x, y):
|
||||
return x + y
|
||||
|
||||
def make_inputs():
|
||||
return (torch.randn(3, 4), torch.randn(3, 4))
|
||||
|
||||
compiled_fn = torch.compile(
|
||||
fn, fullgraph=True, options={"use_aoti": True}
|
||||
).aot_compile((make_inputs(), {}))
|
||||
test_inputs = make_inputs()
|
||||
expected = fn(*test_inputs)
|
||||
actual = compiled_fn(*test_inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
compiled_fn.save_compiled_function(self.path())
|
||||
with open(self.path(), "rb") as f:
|
||||
compiled_fn = torch.compiler.load_compiled_function(f)
|
||||
actual = compiled_fn(*test_inputs)
|
||||
self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor")
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -470,7 +470,7 @@ class <lambda>(torch.nn.Module):
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
def test_stream_backward_simple(self) -> None:
|
||||
def test_stream_backward(self) -> None:
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s0 = torch.Stream()
|
||||
@ -524,68 +524,7 @@ class GraphModule(torch.nn.Module):
|
||||
# Annotation: {'stream': 1}
|
||||
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
|
||||
return (add_3, add_2)
|
||||
""",
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
def test_stream_backward_sync(self) -> None:
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s0 = torch.Stream()
|
||||
with s0:
|
||||
y0 = 2 * x + y
|
||||
with s2:
|
||||
z = 2 * x + y
|
||||
|
||||
return y0, z
|
||||
|
||||
inp = (
|
||||
torch.ones(2, 2, device="cuda:0", requires_grad=True) + 1,
|
||||
torch.ones(2, 2, device="cuda:0", requires_grad=True),
|
||||
)
|
||||
expected = fn(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
bw_graphs,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertExpectedInline(
|
||||
print_graph(fw_graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
|
||||
# Annotation: {'stream': 1}
|
||||
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
|
||||
return (add, add_1)
|
||||
""",
|
||||
)
|
||||
|
||||
actual[1].sum().backward()
|
||||
self.assertExpectedInline(
|
||||
print_graph(bw_graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
|
||||
# Annotation: {'stream': 0}
|
||||
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
|
||||
|
||||
#
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
|
||||
|
||||
# Annotation: {'stream': 1}
|
||||
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
|
||||
return (add_3, add_2)
|
||||
""",
|
||||
|
||||
@ -15295,12 +15295,12 @@ graph():
|
||||
def forward(self, block):
|
||||
return block.a + block.b
|
||||
|
||||
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
|
||||
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError, "It looks like one of the inputs with type"
|
||||
):
|
||||
_dynamo_graph_capture_for_export(Foo())(
|
||||
dynamo_graph_capture_for_export(Foo())(
|
||||
Block(torch.randn(4, 4), torch.randn(4, 4))
|
||||
)
|
||||
|
||||
|
||||
@ -1,71 +0,0 @@
|
||||
# Owner(s): ["oncall: export"]
|
||||
|
||||
|
||||
from torch._dynamo import config as dynamo_config
|
||||
from torch._dynamo.testing import make_test_cls_with_patches
|
||||
from torch._export import config as export_config
|
||||
|
||||
|
||||
try:
|
||||
from . import test_export, testing
|
||||
except ImportError:
|
||||
import test_export # @manual=fbcode//caffe2/test:test_export-library
|
||||
import testing # @manual=fbcode//caffe2/test:test_export-library
|
||||
|
||||
from torch.export import export
|
||||
|
||||
|
||||
test_classes = {}
|
||||
|
||||
|
||||
def mocked_strict_export(*args, **kwargs):
|
||||
# If user already specified strict, don't make it strict
|
||||
if "strict" in kwargs:
|
||||
return export(*args, **kwargs)
|
||||
return export(*args, **kwargs, strict=True)
|
||||
|
||||
|
||||
def make_dynamic_cls(cls):
|
||||
# Some test check for ending in suffix; need to make
|
||||
# the `_strict` for end of string as a result
|
||||
suffix = test_export.INLINE_AND_INSTALL_STRICT_SUFFIX
|
||||
|
||||
cls_prefix = "InlineAndInstall"
|
||||
|
||||
cls_a = testing.make_test_cls_with_mocked_export(
|
||||
cls,
|
||||
"StrictExport",
|
||||
suffix,
|
||||
mocked_strict_export,
|
||||
xfail_prop="_expected_failure_strict",
|
||||
)
|
||||
test_class = make_test_cls_with_patches(
|
||||
cls_a,
|
||||
cls_prefix,
|
||||
"",
|
||||
(export_config, "use_new_tracer_experimental", True),
|
||||
(dynamo_config, "install_free_tensors", True),
|
||||
(dynamo_config, "inline_inbuilt_nn_modules", True),
|
||||
xfail_prop="_expected_failure_inline_and_install",
|
||||
)
|
||||
|
||||
test_classes[test_class.__name__] = test_class
|
||||
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
|
||||
globals()[test_class.__name__] = test_class
|
||||
test_class.__module__ = __name__
|
||||
return test_class
|
||||
|
||||
|
||||
tests = [
|
||||
test_export.TestDynamismExpression,
|
||||
test_export.TestExport,
|
||||
]
|
||||
for test in tests:
|
||||
make_dynamic_cls(test)
|
||||
del test
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
@ -1394,357 +1394,6 @@ class HasDecompTest(TestCase):
|
||||
check_case(groups=1, C_in=8, C_out=12) # groups=1 bigger
|
||||
check_case(groups=2, C_in=8, C_out=12) # grouped conv
|
||||
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
|
||||
def test_mm_decompose_mm_dde(self):
|
||||
def fuzzed_program(
|
||||
arg_0,
|
||||
arg_1,
|
||||
arg_2,
|
||||
arg_3,
|
||||
arg_4,
|
||||
arg_5,
|
||||
arg_6,
|
||||
arg_7,
|
||||
arg_8,
|
||||
arg_9,
|
||||
arg_10,
|
||||
arg_11,
|
||||
arg_12,
|
||||
arg_13,
|
||||
arg_14,
|
||||
arg_15,
|
||||
arg_16,
|
||||
arg_17,
|
||||
arg_18,
|
||||
sentinel,
|
||||
):
|
||||
var_node_6 = (
|
||||
arg_0 # size=(9, 9, 9), stride=(81, 9, 1), dtype=float64, device=cuda
|
||||
)
|
||||
var_node_7 = (
|
||||
arg_1 # size=(9, 9, 11), stride=(99, 11, 1), dtype=float64, device=cuda
|
||||
)
|
||||
var_node_5 = torch.matmul(
|
||||
var_node_6.to(torch.float64), var_node_7.to(torch.float64)
|
||||
) # size=(9, 9, 11), stride=(99, 11, 1), dtype=float64, device=cuda
|
||||
var_node_9 = torch.full(
|
||||
(9, 11, 12), 1.5758497316910556, dtype=torch.float64
|
||||
) # size=(9, 11, 12), stride=(132, 12, 1), dtype=float64, device=cuda
|
||||
var_node_10 = (
|
||||
arg_2 # size=(9, 12, 8), stride=(96, 8, 1), dtype=float64, device=cuda
|
||||
)
|
||||
var_node_8 = torch.matmul(
|
||||
var_node_9.to(torch.float64), var_node_10.to(torch.float64)
|
||||
) # size=(9, 11, 8), stride=(88, 8, 1), dtype=float64, device=cuda
|
||||
var_node_4 = torch.matmul(
|
||||
var_node_5.to(torch.float64), var_node_8.to(torch.float64)
|
||||
) # size=(9, 9, 8), stride=(72, 8, 1), dtype=float64, device=cuda
|
||||
var_node_13 = arg_3 # size=(9, 8, 13), stride=(104, 13, 1), dtype=float64, device=cuda
|
||||
var_node_14 = (
|
||||
arg_4 # size=(9, 13, 7), stride=(91, 7, 1), dtype=float64, device=cuda
|
||||
)
|
||||
var_node_12 = torch.matmul(
|
||||
var_node_13.to(torch.float64), var_node_14.to(torch.float64)
|
||||
) # size=(9, 8, 7), stride=(56, 7, 1), dtype=float64, device=cuda
|
||||
var_node_15 = arg_5 # size=(9, 7, 16), stride=(112, 16, 1), dtype=float64, device=cuda
|
||||
var_node_11 = torch.matmul(
|
||||
var_node_12.to(torch.float64), var_node_15.to(torch.float64)
|
||||
) # size=(9, 8, 16), stride=(128, 16, 1), dtype=float64, device=cuda
|
||||
var_node_3 = torch.matmul(
|
||||
var_node_4.to(torch.float64), var_node_11.to(torch.float64)
|
||||
) # size=(9, 9, 16), stride=(144, 16, 1), dtype=float64, device=cuda
|
||||
var_node_17 = arg_6 # size=(9, 16, 12), stride=(192, 12, 1), dtype=float64, device=cuda
|
||||
var_node_18 = arg_7 # size=(9, 12, 11), stride=(132, 11, 1), dtype=float64, device=cuda
|
||||
var_node_16 = torch.matmul(
|
||||
var_node_17.to(torch.float64), var_node_18.to(torch.float64)
|
||||
) # size=(9, 16, 11), stride=(176, 11, 1), dtype=float64, device=cuda
|
||||
var_node_2 = torch.matmul(
|
||||
var_node_3.to(torch.float64), var_node_16.to(torch.float64)
|
||||
) # size=(9, 9, 11), stride=(99, 11, 1), dtype=float64, device=cuda
|
||||
var_node_23 = torch.full(
|
||||
(156, 8), -0.5249394453404403, dtype=torch.float64
|
||||
) # size=(156, 8), stride=(8, 1), dtype=float64, device=cuda
|
||||
var_node_24 = torch.full(
|
||||
(8, 9), 0.9331226188585692, dtype=torch.float64
|
||||
) # size=(8, 9), stride=(9, 1), dtype=float64, device=cuda
|
||||
var_node_22 = torch.matmul(
|
||||
var_node_23.to(torch.float64), var_node_24.to(torch.float64)
|
||||
) # size=(156, 9), stride=(9, 1), dtype=float64, device=cuda
|
||||
var_node_26 = torch.full(
|
||||
(9, 13), -0.9276381954691514, dtype=torch.float64
|
||||
) # size=(9, 13), stride=(13, 1), dtype=float64, device=cuda
|
||||
var_node_27 = torch.full(
|
||||
(13, 16), 0.024752238943232543, dtype=torch.float64
|
||||
) # size=(13, 16), stride=(16, 1), dtype=float64, device=cuda
|
||||
var_node_25 = torch.matmul(
|
||||
var_node_26.to(torch.float64), var_node_27.to(torch.float64)
|
||||
) # size=(9, 16), stride=(16, 1), dtype=float64, device=cuda
|
||||
var_node_21 = torch.matmul(
|
||||
var_node_22.to(torch.float64), var_node_25.to(torch.float64)
|
||||
) # size=(156, 16), stride=(16, 1), dtype=float64, device=cuda
|
||||
var_node_29 = arg_8
|
||||
_x_nz = torch.zeros(
|
||||
(9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
|
||||
dtype=torch.bool,
|
||||
device=var_node_29.device,
|
||||
)
|
||||
_x_nz_flat = _x_nz.reshape(-1)
|
||||
_x_nz_flat[:9] = True
|
||||
var_node_28 = torch.nonzero(
|
||||
_x_nz
|
||||
) # size=(9, 11), stride=(11, 1), dtype=int64, device=cuda
|
||||
var_node_20 = torch.nn.functional.embedding(
|
||||
torch.clamp(var_node_28.to(torch.int64), 0, var_node_21.size(0) - 1),
|
||||
var_node_21,
|
||||
) # size=(9, 11, 16), stride=(176, 16, 1), dtype=float64, device=cuda
|
||||
var_node_33 = torch.full(
|
||||
(9, 16, 5), 1.0707914920634904, dtype=torch.float64
|
||||
) # size=(9, 16, 5), stride=(80, 5, 1), dtype=float64, device=cuda
|
||||
var_node_34 = torch.full(
|
||||
(9, 5, 10), -0.44934093079047227, dtype=torch.float64
|
||||
) # size=(9, 5, 10), stride=(50, 10, 1), dtype=float64, device=cuda
|
||||
var_node_32 = torch.matmul(
|
||||
var_node_33.to(torch.float64), var_node_34.to(torch.float64)
|
||||
) # size=(9, 16, 10), stride=(160, 10, 1), dtype=float64, device=cuda
|
||||
var_node_36 = (
|
||||
arg_9 # size=(9, 10, 1), stride=(10, 1, 1), dtype=float64, device=cuda
|
||||
)
|
||||
var_node_37 = torch.full(
|
||||
(9, 1, 11), -1.874293687140311, dtype=torch.float64
|
||||
) # size=(9, 1, 11), stride=(11, 11, 1), dtype=float64, device=cuda
|
||||
var_node_35 = torch.matmul(
|
||||
var_node_36.to(torch.float64), var_node_37.to(torch.float64)
|
||||
) # size=(9, 10, 11), stride=(110, 11, 1), dtype=float64, device=cuda
|
||||
var_node_31 = torch.matmul(
|
||||
var_node_32.to(torch.float64), var_node_35.to(torch.float64)
|
||||
) # size=(9, 16, 11), stride=(176, 11, 1), dtype=float64, device=cuda
|
||||
var_node_40 = torch.full(
|
||||
(990, 2), 0.4084376380351558, dtype=torch.float64
|
||||
) # size=(990, 2), stride=(2, 1), dtype=float64, device=cuda
|
||||
var_node_41 = torch.full(
|
||||
(2,), 0.982671965550022, dtype=torch.float64
|
||||
) # size=(2,), stride=(1,), dtype=float64, device=cuda
|
||||
var_node_39 = torch.matmul(
|
||||
var_node_40.to(torch.float64), var_node_41.to(torch.float64)
|
||||
) # size=(990,), stride=(1,), dtype=float64, device=cuda
|
||||
var_node_38 = torch.reshape(
|
||||
var_node_39, [9, 11, 10]
|
||||
) # size=(9, 11, 10), stride=(110, 10, 1), dtype=float64, device=cuda
|
||||
var_node_30 = torch.matmul(
|
||||
var_node_31.to(torch.float64), var_node_38.to(torch.float64)
|
||||
) # size=(9, 16, 10), stride=(160, 10, 1), dtype=float64, device=cuda
|
||||
var_node_19 = torch.matmul(
|
||||
var_node_20.to(torch.float64), var_node_30.to(torch.float64)
|
||||
) # size=(9, 11, 10), stride=(110, 10, 1), dtype=float64, device=cuda
|
||||
var_node_1 = torch.matmul(
|
||||
var_node_2.to(torch.float64), var_node_19.to(torch.float64)
|
||||
) # size=(9, 9, 10), stride=(90, 10, 1), dtype=float64, device=cuda
|
||||
var_node_47 = arg_10 # size=(9, 10, 15), stride=(150, 15, 1), dtype=float64, device=cuda
|
||||
var_node_48 = torch.full(
|
||||
(9, 15, 2), -0.3349339402390618, dtype=torch.float64
|
||||
) # size=(9, 15, 2), stride=(30, 2, 1), dtype=float64, device=cuda
|
||||
var_node_46 = torch.matmul(
|
||||
var_node_47.to(torch.float64), var_node_48.to(torch.float64)
|
||||
) # size=(9, 10, 2), stride=(20, 2, 1), dtype=float64, device=cuda
|
||||
var_node_50 = (
|
||||
arg_11 # size=(9, 2, 7), stride=(14, 7, 1), dtype=float64, device=cuda
|
||||
)
|
||||
var_node_51 = (
|
||||
arg_12 # size=(9, 7, 2), stride=(14, 2, 1), dtype=float64, device=cuda
|
||||
)
|
||||
var_node_49 = torch.matmul(
|
||||
var_node_50.to(torch.float64), var_node_51.to(torch.float64)
|
||||
) # size=(9, 2, 2), stride=(4, 2, 1), dtype=float64, device=cuda
|
||||
var_node_45 = torch.matmul(
|
||||
var_node_46.to(torch.float64), var_node_49.to(torch.float64)
|
||||
) # size=(9, 10, 2), stride=(20, 2, 1), dtype=float64, device=cuda
|
||||
var_node_52 = torch.full(
|
||||
(9, 2, 1), -0.4046675639434615, dtype=torch.float64
|
||||
) # size=(9, 2, 1), stride=(2, 1, 1), dtype=float64, device=cuda
|
||||
var_node_44 = torch.matmul(
|
||||
var_node_45.to(torch.float64), var_node_52.to(torch.float64)
|
||||
) # size=(9, 10, 1), stride=(10, 1, 1), dtype=float64, device=cuda
|
||||
var_node_56 = (
|
||||
arg_13 # size=(9, 1, 1), stride=(1, 1, 1), dtype=float64, device=cuda
|
||||
)
|
||||
var_node_55 = torch.nn.functional.rms_norm(
|
||||
var_node_56.to(torch.float64), (1,)
|
||||
) # size=(9, 1, 1), stride=(1, 1, 1), dtype=float64, device=cuda
|
||||
var_node_57 = torch.full(
|
||||
(9, 1, 8), 0.17877664640931384, dtype=torch.float64
|
||||
) # size=(9, 1, 8), stride=(8, 8, 1), dtype=float64, device=cuda
|
||||
var_node_54 = torch.matmul(
|
||||
var_node_55.to(torch.float64), var_node_57.to(torch.float64)
|
||||
) # size=(9, 1, 8), stride=(8, 8, 1), dtype=float64, device=cuda
|
||||
var_node_60 = arg_14 # size=(9, 8, 10), stride=(80, 10, 1), dtype=float64, device=cuda
|
||||
var_node_61 = torch.full(
|
||||
(9, 10, 6), 0.43614806380221494, dtype=torch.float64
|
||||
) # size=(9, 10, 6), stride=(60, 6, 1), dtype=float64, device=cuda
|
||||
var_node_59 = torch.matmul(
|
||||
var_node_60.to(torch.float64), var_node_61.to(torch.float64)
|
||||
) # size=(9, 8, 6), stride=(48, 6, 1), dtype=float64, device=cuda
|
||||
var_node_63 = (
|
||||
arg_15 # size=(9, 6, 3), stride=(18, 3, 1), dtype=float64, device=cuda
|
||||
)
|
||||
var_node_64 = torch.full(
|
||||
(9, 3, 8), -0.042774422041922854, dtype=torch.float64
|
||||
) # size=(9, 3, 8), stride=(24, 8, 1), dtype=float64, device=cuda
|
||||
var_node_62 = torch.matmul(
|
||||
var_node_63.to(torch.float64), var_node_64.to(torch.float64)
|
||||
) # size=(9, 6, 8), stride=(48, 8, 1), dtype=float64, device=cuda
|
||||
var_node_58 = torch.matmul(
|
||||
var_node_59.to(torch.float64), var_node_62.to(torch.float64)
|
||||
) # size=(9, 8, 8), stride=(64, 8, 1), dtype=float64, device=cuda
|
||||
var_node_53 = torch.matmul(
|
||||
var_node_54.to(torch.float64), var_node_58.to(torch.float64)
|
||||
) # size=(9, 1, 8), stride=(8, 8, 1), dtype=float64, device=cuda
|
||||
var_node_43 = torch.matmul(
|
||||
var_node_44.to(torch.float64), var_node_53.to(torch.float64)
|
||||
) # size=(9, 10, 8), stride=(80, 8, 1), dtype=float64, device=cuda
|
||||
var_node_68 = arg_16 # size=(9, 8, 16), stride=(128, 16, 1), dtype=float64, device=cuda
|
||||
var_node_70 = torch.full(
|
||||
(9, 16, 15), 0.24947808634496438, dtype=torch.float64
|
||||
) # size=(9, 16, 15), stride=(240, 15, 1), dtype=float64, device=cuda
|
||||
var_node_71 = torch.full(
|
||||
(9, 15, 7), -0.09035245509773453, dtype=torch.float64
|
||||
) # size=(9, 15, 7), stride=(105, 7, 1), dtype=float64, device=cuda
|
||||
var_node_69 = torch.matmul(
|
||||
var_node_70.to(torch.float64), var_node_71.to(torch.float64)
|
||||
) # size=(9, 16, 7), stride=(112, 7, 1), dtype=float64, device=cuda
|
||||
var_node_67 = torch.matmul(
|
||||
var_node_68.to(torch.float64), var_node_69.to(torch.float64)
|
||||
) # size=(9, 8, 7), stride=(56, 7, 1), dtype=float64, device=cuda
|
||||
var_node_74 = torch.full(
|
||||
(9, 7, 1), 0.05671950481832341, dtype=torch.float64
|
||||
) # size=(9, 7, 1), stride=(7, 1, 1), dtype=float64, device=cuda
|
||||
var_node_73 = torch.nn.functional.gelu(
|
||||
var_node_74
|
||||
) # size=(9, 7, 1), stride=(7, 1, 1), dtype=float64, device=cuda
|
||||
var_node_76 = torch.full(
|
||||
(9, 1, 2), -0.019912810353597852, dtype=torch.float64
|
||||
) # size=(9, 1, 2), stride=(2, 2, 1), dtype=float64, device=cuda
|
||||
var_node_77 = (
|
||||
arg_17 # size=(9, 2, 7), stride=(14, 7, 1), dtype=float64, device=cuda
|
||||
)
|
||||
var_node_75 = torch.matmul(
|
||||
var_node_76.to(torch.float64), var_node_77.to(torch.float64)
|
||||
) # size=(9, 1, 7), stride=(7, 7, 1), dtype=float64, device=cuda
|
||||
var_node_72 = torch.matmul(
|
||||
var_node_73.to(torch.float64), var_node_75.to(torch.float64)
|
||||
) # size=(9, 7, 7), stride=(49, 7, 1), dtype=float64, device=cuda
|
||||
var_node_66 = torch.matmul(
|
||||
var_node_67.to(torch.float64), var_node_72.to(torch.float64)
|
||||
) # size=(9, 8, 7), stride=(56, 7, 1), dtype=float64, device=cuda
|
||||
var_node_78 = arg_18 # size=(9, 7, 13), stride=(91, 13, 1), dtype=float64, device=cuda
|
||||
var_node_65 = torch.matmul(
|
||||
var_node_66.to(torch.float64), var_node_78.to(torch.float64)
|
||||
) # size=(9, 8, 13), stride=(104, 13, 1), dtype=float64, device=cuda
|
||||
var_node_42 = torch.matmul(
|
||||
var_node_43.to(torch.float64), var_node_65.to(torch.float64)
|
||||
) # size=(9, 10, 13), stride=(130, 13, 1), dtype=float64, device=cuda
|
||||
var_node_0 = torch.matmul(
|
||||
var_node_1.to(torch.float64), var_node_42.to(torch.float64)
|
||||
) # size=(9, 9, 13), stride=(117, 13, 1), dtype=float64, device=cuda
|
||||
# Ensure gradient computation by multiplying with sentinel and taking real part
|
||||
result = var_node_0 * sentinel
|
||||
if result.is_complex():
|
||||
result = result.real
|
||||
return result
|
||||
|
||||
# Sentinel tensor to ensure gradient computation
|
||||
sentinel = torch.tensor(1.0, requires_grad=True)
|
||||
|
||||
arg_0 = torch.as_strided(
|
||||
torch.randn(729).to(torch.float64), (9, 9, 9), (81, 9, 1)
|
||||
)
|
||||
arg_1 = torch.as_strided(
|
||||
torch.randn(891).to(torch.float64), (9, 9, 11), (99, 11, 1)
|
||||
)
|
||||
arg_2 = torch.as_strided(
|
||||
torch.randn(864).to(torch.float64), (9, 12, 8), (96, 8, 1)
|
||||
)
|
||||
arg_3 = torch.as_strided(
|
||||
torch.randn(936).to(torch.float64), (9, 8, 13), (104, 13, 1)
|
||||
)
|
||||
arg_4 = torch.as_strided(
|
||||
torch.randn(819).to(torch.float64), (9, 13, 7), (91, 7, 1)
|
||||
)
|
||||
arg_5 = torch.as_strided(
|
||||
torch.randn(1008).to(torch.float64), (9, 7, 16), (112, 16, 1)
|
||||
)
|
||||
arg_6 = torch.as_strided(
|
||||
torch.randn(1728).to(torch.float64), (9, 16, 12), (192, 12, 1)
|
||||
)
|
||||
arg_7 = torch.as_strided(
|
||||
torch.randn(1188).to(torch.float64), (9, 12, 11), (132, 11, 1)
|
||||
)
|
||||
arg_8 = torch.as_strided(
|
||||
torch.randint(0, 2, (1,), dtype=torch.int8).bool(),
|
||||
(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
|
||||
(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
|
||||
)
|
||||
arg_9 = torch.as_strided(
|
||||
torch.randn(90).to(torch.float64), (9, 10, 1), (10, 1, 1)
|
||||
)
|
||||
arg_10 = torch.as_strided(
|
||||
torch.randn(1350).to(torch.float64), (9, 10, 15), (150, 15, 1)
|
||||
)
|
||||
arg_11 = torch.as_strided(
|
||||
torch.randn(126).to(torch.float64), (9, 2, 7), (14, 7, 1)
|
||||
)
|
||||
arg_12 = torch.as_strided(
|
||||
torch.randn(126).to(torch.float64), (9, 7, 2), (14, 2, 1)
|
||||
)
|
||||
arg_13 = torch.as_strided(
|
||||
torch.randn(9).to(torch.float64), (9, 1, 1), (1, 1, 1)
|
||||
)
|
||||
arg_14 = torch.as_strided(
|
||||
torch.randn(720).to(torch.float64), (9, 8, 10), (80, 10, 1)
|
||||
)
|
||||
arg_15 = torch.as_strided(
|
||||
torch.randn(162).to(torch.float64), (9, 6, 3), (18, 3, 1)
|
||||
)
|
||||
arg_16 = torch.as_strided(
|
||||
torch.randn(1152).to(torch.float64), (9, 8, 16), (128, 16, 1)
|
||||
)
|
||||
arg_17 = torch.as_strided(
|
||||
torch.randn(126).to(torch.float64), (9, 2, 7), (14, 7, 1)
|
||||
)
|
||||
arg_18 = torch.as_strided(
|
||||
torch.randn(819).to(torch.float64), (9, 7, 13), (91, 13, 1)
|
||||
)
|
||||
|
||||
args = (
|
||||
arg_0,
|
||||
arg_1,
|
||||
arg_2,
|
||||
arg_3,
|
||||
arg_4,
|
||||
arg_5,
|
||||
arg_6,
|
||||
arg_7,
|
||||
arg_8,
|
||||
arg_9,
|
||||
arg_10,
|
||||
arg_11,
|
||||
arg_12,
|
||||
arg_13,
|
||||
arg_14,
|
||||
arg_15,
|
||||
arg_16,
|
||||
arg_17,
|
||||
arg_18,
|
||||
) + (sentinel,)
|
||||
result_original = fuzzed_program(*args)
|
||||
compiled_program = torch.compile(fuzzed_program, fullgraph=True, dynamic=True)
|
||||
result_compiled = compiled_program(*args)
|
||||
|
||||
# Both should succeed without NameError
|
||||
self.assertTrue(
|
||||
torch.allclose(result_original, result_compiled, rtol=1e-5, atol=1e-5)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -4524,17 +4524,6 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
|
||||
run(torch.rand(2, 10), torch.rand(2, 10))
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
|
||||
def test_unbacked_view_extra(self):
|
||||
def fn(x):
|
||||
i0 = x.nonzero().size(0)
|
||||
y = torch.zeros((i0, 192))
|
||||
return y.view([12, -1, 192])
|
||||
|
||||
res1 = torch.compile(fn, fullgraph=True)(torch.ones((12,)))
|
||||
res2 = fn(torch.ones((12,)))
|
||||
self.assertEqual(res1, res2)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestUnbacked)
|
||||
|
||||
|
||||
@ -3755,44 +3755,6 @@ as the input tensor excluding its innermost dimension'):
|
||||
with ctx:
|
||||
self.assertEqual(torch.mean(t), expected)
|
||||
|
||||
def test_scalar_tensor_as_dim_argument(self):
|
||||
"""Tests that scalar tensors work correctly as dimension arguments.
|
||||
|
||||
This tests the fix for the PythonArgParser bug where scalar Tensors
|
||||
passed to IntList/SymIntList parameters would be incorrectly handled.
|
||||
"""
|
||||
x = torch.ones(1, 2, 3, 4, 5)
|
||||
|
||||
# Scalar tensors should work correctly (same as passing an int)
|
||||
result_tensor = x.sum(dim=torch.tensor(3))
|
||||
result_int = x.sum(dim=3)
|
||||
self.assertEqual(result_tensor.shape, result_int.shape)
|
||||
self.assertEqual(result_tensor.shape, torch.Size([1, 2, 3, 5]))
|
||||
|
||||
# Test with different integer dtypes
|
||||
for dtype in [torch.int32, torch.int64, torch.int16, torch.int8]:
|
||||
dim_tensor = torch.tensor(1, dtype=dtype)
|
||||
result = x.sum(dim=dim_tensor)
|
||||
expected = x.sum(dim=1)
|
||||
self.assertEqual(result.shape, expected.shape)
|
||||
|
||||
@skipIfTorchDynamo("Test uses random.randint which creates FakeTensors")
|
||||
def test_scalar_tensor_dim_compiled_mode(self):
|
||||
"""Tests that scalar FakeTensors from random.randint work correctly in compiled mode."""
|
||||
def foo():
|
||||
x = torch.ones(2, 2, 2)
|
||||
return x.sum(dim=random.randint(0, 0))
|
||||
|
||||
@torch.compile
|
||||
def foo_compile():
|
||||
x = torch.ones(2, 2, 2)
|
||||
return x.sum(dim=random.randint(0, 0))
|
||||
|
||||
result_eager = foo()
|
||||
result_compiled = foo_compile()
|
||||
self.assertEqual(result_eager.shape, result_compiled.shape)
|
||||
self.assertEqual(result_eager.shape, torch.Size([2, 2]))
|
||||
|
||||
instantiate_device_type_tests(TestReductions, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -2236,6 +2236,7 @@ class TestSparse(TestSparseBase):
|
||||
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
@dtypesIfMPS(torch.float32, torch.complex64)
|
||||
@expectedFailureMPS
|
||||
@skipIfCrossRef
|
||||
def test_sparse_mask_backward(self, device, dtype):
|
||||
from itertools import product, repeat
|
||||
@ -2245,6 +2246,7 @@ class TestSparse(TestSparseBase):
|
||||
nnzs = (0, 5, 15, 25)
|
||||
|
||||
lhs_data = torch.arange(1, 26, device=device).reshape(shape).to(dtype).to_sparse(sparse_dims)
|
||||
rhs_data = lhs_data.clone()
|
||||
|
||||
for nnz in nnzs:
|
||||
for lhs_is_coalesced, rhs_is_coalesced in product(*repeat((True, False), 2)):
|
||||
@ -2264,9 +2266,8 @@ class TestSparse(TestSparseBase):
|
||||
# sparsity_pattern(lhs) == sparsity_pattern(lhs.grad).
|
||||
# lhs.sparse_mask(lhs_mask) accomplishes that.
|
||||
lhs_mask = lhs.detach().clone()
|
||||
gradcheck(lambda x: x.sparse_mask(lhs_mask).sparse_mask(rhs).to_dense(masked_grad=True), (lhs,),
|
||||
masked=True, eps=3e-4, atol=5e-5)
|
||||
gradcheck(lambda x: x.sparse_mask(rhs).to_dense(masked_grad=False), (lhs,), masked=False, eps=3e-4, atol=5e-5)
|
||||
gradcheck(lambda x: x.sparse_mask(lhs_mask).sparse_mask(rhs).to_dense(masked_grad=True), (lhs,), masked=True)
|
||||
gradcheck(lambda x: x.sparse_mask(rhs).to_dense(masked_grad=False), (lhs,), masked=False)
|
||||
|
||||
@coalescedonoff
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
|
||||
@@ -1781,14 +1781,6 @@ class TestTorchDeviceType(TestCase):
        self.assertEqual(b[0, :], d[0, :], atol=3e-5, rtol=3e-5)
        self.assertEqual(b[-1, :], d[-1, :], atol=3e-5, rtol=3e-5)

    @onlyCUDA
    @largeTensorTest('48GB')
    def test_cumsum_outer_dim_64bit_indexing(self, device):
        x = torch.zeros(309504, 1, 16384, device=device)
        torch.exp(x)
        cumsum = torch.cumsum(x, dim=1)
        self.assertEqual(cumsum.max().item(), 0., atol=0., rtol=0.)

    @expectedFailureMeta  # expected a non-determinitic error, but it was not raised
    @onlyNativeDeviceTypes
    def test_nondeterministic_alert_put(self, device):

@@ -2439,6 +2439,35 @@ class _TorchCompileInductorWrapper:
            reset_cudagraph_trees()


class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
    compiler_name = "aotinductor"

    def __init__(self, mode, options, dynamic):
        super().__init__(mode, options, dynamic)
        self.apply_options({"cpp_wrapper": True})
        self.apply_options({"aot_inductor.package": True})

    def __call__(self, model_, inputs_):
        from contextlib import nullcontext
        from unittest import mock

        from torch._guards import detect_fake_mode
        from torch._inductor.virtualized import V

        fake_mode = detect_fake_mode(inputs_)
        ctx = (
            mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
            if fake_mode
            else nullcontext()
        )
        with (
            V.set_aot_compilation(True),
            ctx,
            torch._inductor.config.patch("enable_autograd_for_aot", True),
        ):
            return super().__call__(model_, inputs_)


class _TorchCompileWrapper:
    def __init__(self, backend, mode, options, dynamic):
        from torch._dynamo.backends.registry import lookup_backend
@@ -2672,8 +2701,10 @@ def compile(
        backend = bisect_backend

    guard_filter_fn = None
    use_aoti = False
    if options and isinstance(options, dict):
        guard_filter_fn = options.pop("guard_filter_fn", None)
        use_aoti = options.pop("use_aoti", False)

    if torch.compiler.is_exporting():
        warnings.warn(
@@ -2700,7 +2731,10 @@
        return export_wrapped_fn

    if backend == "inductor":
        backend = _TorchCompileInductorWrapper(mode, options, dynamic)
        if use_aoti:
            backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
        else:
            backend = _TorchCompileInductorWrapper(mode, options, dynamic)
    else:
        backend = _TorchCompileWrapper(backend, mode, options, dynamic)

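For reference, the intended entry point for the two hunks above is exercised by test_aot_compile_with_aoti_torch_compile earlier in this diff (the test runs under a CUDA device context). A condensed sketch of that flow; the file path here is illustrative only:

```python
import torch

def fn(x, y):
    return x + y

with torch.device("cuda"):
    inputs = (torch.randn(3, 4), torch.randn(3, 4))

    # options={"use_aoti": True} routes the inductor backend through the new
    # _TorchCompileAOTInductorWrapper; aot_compile precompiles for these inputs.
    compiled_fn = torch.compile(
        fn, fullgraph=True, options={"use_aoti": True}
    ).aot_compile((inputs, {}))

    out = compiled_fn(*inputs)

    # The artifact can be saved and loaded back later.
    compiled_fn.save_compiled_function("/tmp/fn.pt")
    with open("/tmp/fn.pt", "rb") as f:
        reloaded = torch.compiler.load_compiled_function(f)
    assert torch.equal(reloaded(*inputs), out)
```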
@@ -4562,8 +4562,6 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> b
@aten.matmul.out.py_impl(DispatchKey.CompositeImplicitAutograd)
@out_wrapper(pass_is_out=True)
def matmul(tensor1, tensor2, *, is_out=False):
    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

    dim_tensor1 = tensor1.dim()
    dim_tensor2 = tensor2.dim()
    assert dim_tensor1 != 0 and dim_tensor2 != 0
@@ -4632,11 +4630,11 @@ def matmul(tensor1, tensor2, *, is_out=False):
        if (
            dim_tensor1 == 3
            and dim_tensor2 == 3
            and guard_or_true(batch_tensor1[0] != batch_tensor2[0])
            and batch_tensor1[0] != batch_tensor2[0]
        ):
            if guard_or_false(batch_tensor1[0] == 1) and tensor1.requires_grad:
            if batch_tensor1[0] == 1 and tensor1.requires_grad:
                return matmul(tensor1.squeeze(0), tensor2)
            if guard_or_false(batch_tensor2[0] == 1) and tensor2.requires_grad:
            if batch_tensor2[0] == 1 and tensor2.requires_grad:
                return matmul(tensor1, tensor2.squeeze(0))

        # expand the batch portion (i.e. cut off matrix dimensions and expand rest)

@@ -53,6 +53,7 @@ class CompileArtifacts:
    argdefs: Optional[tuple[Any, ...]]
    source_info: "SourceInfo"
    device_type: str
    backend_name: str
    system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)

    def check_compatibility(self) -> None:
@@ -273,6 +274,7 @@ def aot_compile_fullgraph(
        argdefs=fn.__defaults__,
        source_info=source_info,
        device_type=device_type,
        backend_name=getattr(backend, "compiler_name", "unknown"),
    )
    aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)

@@ -499,6 +499,9 @@ def pytreeify(
        root = mod.__self__

    flat_real_args, in_spec = pytree.tree_flatten((args, kwargs))
    torch._dynamo.eval_frame.check_user_input_output(
        flat_real_args[1 if root else 0 :], UserErrorType.INVALID_INPUT
    )

    class Yield(Exception):
        pass

@@ -511,6 +511,7 @@ class GenericAOTAutogradResult(Generic[TForward, TBackward]):
        ).post_compile(
            compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
        )
        compiled_fw_func._boxed_call = True
        disable_amp = torch._C._is_any_autocast_enabled()

        if needs_autograd:
@@ -33,7 +33,6 @@ from .graph_capture_wrappers import (
    handle_effect_tokens_fn,
)
from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
from .streams import assign_backward_streams
from .utils import (
    call_and_expect_output_descs,
    copy_fwd_metadata_to_bw_nodes,
@@ -474,9 +473,6 @@ def aot_dispatch_autograd_graph(
    # fw node match might be erased
    copy_fwd_metadata_to_bw_nodes(fx_g)

    # After copying metadata, assign streams to gradient accumulation nodes
    assign_backward_streams(fx_g)

    fx_g.graph.eliminate_dead_code()
    if not aot_config.disable_functionalization:
        # There should be *NO* mutating ops in the graph at this point.

@ -1,53 +0,0 @@
from typing import Optional, TypeAlias

import torch.fx
import torch.fx.traceback
from torch._dynamo.graph_utils import _get_flat_args


Node: TypeAlias = torch.fx.Node


def is_gradient_acc(node: Node) -> bool:
    return node.meta.get("is_gradient_acc", False)


def get_stream(node: Node) -> Optional[int]:
    maybe_annotation = node.meta.get("custom", None)
    if maybe_annotation is not None:
        return node.meta["custom"].get("stream", None)
    else:
        return None


def set_stream(node: Node, ind: int) -> None:
    if "custom" in node.meta:
        node.meta["custom"].update({"stream": ind})
    else:
        node.meta["custom"] = {"stream": ind}


def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
    """Assigns backward streams to gradient accumulation nodes"""

    # NB: iterate in reverse order to more closely match eager
    # the user node stream will be populated first
    for node in reversed(list(gm.graph.nodes)):
        if is_gradient_acc(node):
            # Accumulation stream selection. Follow the rules from top to bottom to determine the accumulation stream:
            # 1. Match first stream assignment of the first user with a stream
            # 2. Match first stream assignment encountered in the args from left to right
            # This differs from eager in some cases:
            # Specifically the eager code uses the autograd node to determine the stream,
            # crucially this does not necessarily correspond to the FX graph node. For example,
            # in the backward for an add node with a constant we will passthrough and during backward tracing,
            # no op will be added to the FX graph, so our stream assignment will differ in this case.
            gradients = _get_flat_args(node, {})
            users = list(node.users.keys())

            # All gradients will be on same device, they will be coerced if they were not with a .to() node
            for neighbor in users + gradients:
                ind = get_stream(neighbor)
                if ind is not None:
                    set_stream(node, ind)
                    break

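For readers unfamiliar with the convention used by the removed helper, here is a minimal sketch of how the node.meta["custom"]["stream"] annotation it relied on is written and read back. The traced function, the node chosen, and the stream index are arbitrary illustration values, not part of this diff.

import torch
import torch.fx


def f(x, y):
    return x * y + y


gm = torch.fx.symbolic_trace(f)
nodes = list(gm.graph.nodes)

# Tag one call node with a stream index using the same meta layout as set_stream().
nodes[2].meta.setdefault("custom", {})["stream"] = 1

# Reading it back mirrors get_stream() from the deleted helper.
stream = nodes[2].meta.get("custom", {}).get("stream", None)
print(stream)  # 1
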
@ -1640,7 +1640,9 @@ class _InProcessFxCompile(FxCompile):
                # pyrefly: ignore [unbound-name]
                (str, list, torch.fx.GraphModule),
            ), type(compiled_fn)
            return CompiledAOTI(compiled_fn)
            return CompiledAOTI(
                filename=compiled_fn, device_type=graph.device_type
            )

            # TODO: Hoist this above V.aot_compilation
            # pyrefly: ignore [unbound-name]

@ -2713,7 +2715,7 @@ def _compile_fx_main(
        or torch._guards.TracingContext(fake_mode)
    )

    if V.aot_compilation:
    if V.aot_compilation and not config.enable_autograd_for_aot:
        from .utils import is_valid_aoti_model_name

        is_valid_aoti_model_name()

@ -1193,6 +1193,8 @@ autotune_lookup_table: dict[str, dict[str, Any]] = {}

file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))

enable_autograd_for_aot: bool = False


def get_worker_log_path() -> Optional[str]:
    log_loc = None

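To make the new knob concrete, a minimal sketch of how it would be toggled, assuming a build that already contains the config entry added above (on builds without this change the attribute does not exist and assignment will fail):

import torch._inductor.config as inductor_config

# Default is False; setting it True skips the AOT model-name check in
# _compile_fx_main per the gating change shown above.
inductor_config.enable_autograd_for_aot = True
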
@ -773,9 +773,83 @@ class CompiledAOTI(OutputCode):
    """

    filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
    device_type: str
    current_callable: Optional[Callable[..., Any]] = None
    _cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        if not config.aot_inductor.link_libtorch:
            return

        if (
            torch._inductor.cpp_builder._IS_MACOS
            or torch._inductor.cpp_builder._IS_WINDOWS
        ):
            return

        if config.aot_inductor.cross_target_platform == "windows":
            return

        if config.aot_inductor.package_cpp_only:
            return

        if isinstance(self.filename, list):
            current_callable = next(
                fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so")
            )
        else:
            current_callable = self.filename

        if isinstance(current_callable, torch.fx.GraphModule):
            self.current_callable = current_callable
            return

        if self.device_type.startswith("cuda"):
            current_callable = (
                torch._C._aoti.AOTIModelContainerRunnerCuda(  # type: ignore[call-arg]
                    current_callable,
                    1,
                    self.device_type,
                    "",
                    True,
                ).run  # type: ignore[attr-defined]
            )  # type: ignore[attr-defined]
        elif self.device_type == "cpu":
            current_callable = (
                torch._C._aoti.AOTIModelContainerRunnerCpu(  # type: ignore[call-arg]
                    current_callable, 1
                ).run  # type: ignore[attr-defined]
            )  # type: ignore[attr-defined]
        else:
            raise RuntimeError(f"unsupported device type {self.device_type}")
        self.current_callable = current_callable
        self._boxed_call = True
        for file in self._cached_files:
            if not os.path.exists(file):
                with open(file, "wb") as f:
                    f.write(self._cached_files[file])

    def __call__(self, inputs: Sequence[Any]) -> Any:
        raise NotImplementedError("NYI")
        if self.current_callable is None:
            raise RuntimeError("AOTInductor compiled so is not loaded")
        return self.current_callable(inputs)

    def prepare_for_serialization(self) -> None:
        self.current_callable = None
        self._cached_files = {}
        filenames: list[str] = []
        if isinstance(self.filename, list):
            filenames = self.filename  # type: ignore[assignment]
        elif isinstance(self.filename, str):
            filenames = [self.filename]
        for name in filenames:
            with open(name, "rb") as f:
                self._cached_files[name] = f.read()

    def __getstate__(self):
        state = self.__dict__.copy()
        state["current_callable"] = None
        return state

    def post_compile(
        self,

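For orientation, a minimal sketch of the container-runner API that the __post_init__ above wraps. The "model.so" path is a placeholder for an AOTInductor-compiled CPU artifact, and the example inputs must match whatever model the shared object was compiled for; this is not part of the diff itself.

import torch

# CPU runner, mirroring AOTIModelContainerRunnerCpu(current_callable, 1) above.
runner = torch._C._aoti.AOTIModelContainerRunnerCpu("model.so", 1)

example_inputs = [torch.randn(2, 3)]  # shapes are illustrative only
outputs = runner.run(example_inputs)  # returns a list of output tensors
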
@ -783,10 +857,8 @@ class CompiledAOTI(OutputCode):
        constants: CompiledFxGraphConstants,
        graph_kwargs: _CompileFxKwargs,
    ) -> None:
        pass

    def prepare_for_serialization(self) -> None:
        pass
        if self.current_callable is None:
            self.__post_init__()

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        pass

@ -1751,7 +1751,7 @@ static PyObject* THPVariable_dtensor_new(
  Tensor tensor = make_tensor_for_subclass_helper(
      /*sym_sizes=*/tuple_to_symintlist(sizes.ptr()),
      /*sym_strides=*/tuple_to_symintlist(stride.ptr()),
      /*sym_storage_offset=*/local_tensor.sym_storage_offset(),
      /*sym_storage_offset=*/std::nullopt,
      options,
      /*storage_size=*/std::nullopt,
      extra_dispatch_keys);

@ -66,6 +66,12 @@ void initAOTIRunnerBindings(PyObject* module) {
              int,
              const std::string&,
              const std::string&>())
      .def(py::init<
              const std::string&,
              int,
              const std::string&,
              const std::string&,
              const bool>())
      .def(
          "run",
          &AOTIModelContainerRunnerCuda::run,

@ -565,16 +565,8 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
    return std::vector<c10::SymInt>(size1, si);
  }

  if (size1 > 0 && THPVariable_Check(args[i])) {
    return std::vector<c10::SymInt>(
        size1, THPVariable_Unpack(args[i]).item().toSymInt());
  }

  PyObject* arg = args[i];
  auto tuple = PyTuple_Check(arg);
  if (!tuple) {
    TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
  }
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  std::vector<c10::SymInt> res;

@ -653,13 +645,7 @@ inline std::vector<int64_t> PythonArgs::intlistWithDefault(
  if (size1 > 0 && torch::is_dynint(py::handle(arg))) {
    return std::vector<int64_t>(size1, py::handle(arg).cast<int>());
  }
  if (size1 > 0 && THPVariable_Check(arg)) {
    return std::vector<int64_t>(size1, THPVariable_Unpack(arg).item<int64_t>());
  }
  auto tuple = PyTuple_Check(arg);
  if (!tuple) {
    TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
  }
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  std::vector<int64_t> res(size2);

@ -730,9 +716,6 @@ inline c10::OptionalArray<c10::SymInt> PythonArgs::symintlistOptional(int i) {
inline std::vector<double> PythonArgs::getDoublelist(int i) {
  PyObject* arg = args[i];
  auto tuple = PyTuple_Check(arg);
  if (!tuple) {
    TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
  }
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  std::vector<double> res(size);

@ -906,9 +889,6 @@ inline at::Dimname PythonArgs::dimname(int i) {

inline std::vector<at::Dimname> parseDimnameList(PyObject* arg) {
  auto tuple = PyTuple_Check(arg);
  if (!tuple) {
    TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
  }
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  std::vector<at::Dimname> res;

@ -7037,16 +7037,52 @@ class ShapeEnv:
                        ok = len(free_unbacked_symbols(new_var)) == 0
                        if ok:
                            self._set_replacement(free[0], new_var, "solve")

            except NotImplementedError:
                pass
        else:
            # expression has mod.
            if expr.has(Mod):
                mod_expr = next(iter(expr.atoms(Mod)))
                try:
                    r = try_solve(expr, mod_expr, floordiv_inequality=False)
                    if r is not None and r[1] == 0:
                        self._add_divisible(mod_expr)
                        # This is a little bit of extra logic to make things like
                        # torch.empty(i0, q).view(c, -1, q) work out
                        p, q = mod_expr.args
                        if (
                            isinstance(q, sympy.Number)
                            and isinstance(p, sympy.Mul)
                            and len(p.args) == 2
                        ):
                            c, i0 = p.args
                            # Given Mod(c * i0, q) == 0
                            if (
                                isinstance(c, sympy.Number)
                                and isinstance(i0, sympy.Symbol)
                                and self.is_unbacked_symint(i0)
                            ):
                                # We have Mod(i0, q / c) == 0, which means we can
                                # rewrite i0 as (q / gcd(q, c)) * i1
                                d = q / sympy.gcd(q, c)  # TODO: CleanDiv?
                                i1 = self.create_unbacked_symint().node.expr
                                # Propagate the value ranges. It doesn't really
                                # matter if we use truediv or floordiv, because we
                                # have established divisibility.
                                self._update_var_to_range(
                                    i1,
                                    SymPyValueRangeAnalysis.floordiv(
                                        self.var_to_range[i0], ValueRanges.wrap(d)
                                    ),
                                )
                                # Propagate hints (real tensor tracing)
                                if i0 in self.unbacked_var_to_val:
                                    self.set_unbacked_var_to_val(
                                        i1, self.unbacked_var_to_val[i0] // d
                                    )
                                # Propagate size-like-ness
                                if i0 in self.size_like:
                                    self.size_like.add(i1)
                                self._set_replacement(i0, d * i1, "divisibility")

                except NotImplementedError:
                    pass
        return

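The arithmetic behind the "divisibility" replacement above, worked in plain sympy as an aside; the concrete values c=3, q=12 and the symbol names are illustration only, not taken from the diff:

import sympy

c, q = sympy.Integer(3), sympy.Integer(12)
i1 = sympy.Symbol("i1", integer=True, nonnegative=True)

# From Mod(c * i0, q) == 0, i0 must be a multiple of d = q / gcd(q, c).
d = q / sympy.gcd(q, c)   # 12 / gcd(12, 3) = 4
replacement = d * i1      # i0 is rewritten as 4*i1
print(d, replacement)     # 4 4*i1

# The mod constraint now holds identically: Mod(3 * 4*i1, 12) simplifies to 0.
print(sympy.simplify(sympy.Mod(c * replacement, q)))
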
@ -39,7 +39,7 @@ import os
import traceback
import weakref
from collections.abc import Callable
from typing import Any, Optional, TYPE_CHECKING, Union  # noqa: F401
from typing import Any, Optional, TYPE_CHECKING  # noqa: F401

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

@ -157,25 +157,21 @@ def _arg_to_str(arg, attributes, tensor_memo=None) -> str:
    return str(arg)


def norm_hash_fn(
    t: torch.Tensor, use_scalar: bool = False
) -> Union[torch.Tensor, float]:
def default_hash_fn(t: torch.Tensor, use_scalar: bool = False) -> torch.Tensor:
    """
    from Observer. Computes a hash for a tensor by converting it to float (if needed), making it contiguous,
    replacing NaN/inf values with fixed numbers, and then computing the L1 norm in float64 or complex128.
    This is used to generate a deterministic summary value for tensor comparison.
    """
    with torch._C._DisablePythonDispatcher():
    with torch._C._DisablePythonDispatcher(), torch._C._DisableTorchDispatch():
        if not (t.is_floating_point() or t.is_complex()):
            t = t.float()
        t = t.contiguous()
        # Clean the tensor to handle NaN/inf values, then compute norm
        t_clean = torch.nan_to_num(t, nan=0.0, posinf=1.0, neginf=-1.0)

        if t.is_complex():
            t_float = t.to(dtype=torch.complex128)
        else:
            t_float = t.to(dtype=torch.float64)

        out = t_float.norm(p=1)
        dtype = torch.complex128 if t.is_complex() else torch.float64
        out = t_clean.norm(p=1, dtype=dtype)
        if use_scalar:
            return out.item()
        return out

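What the L1-norm hash above boils down to for a small tensor, as a standalone sketch; the sample values are arbitrary and only show that NaN/inf are mapped to fixed numbers before the reduction:

import torch

t = torch.tensor([1.0, -2.0, float("nan"), float("inf")])
t_clean = torch.nan_to_num(t, nan=0.0, posinf=1.0, neginf=-1.0)  # [1, -2, 0, 1]
hash_value = t_clean.norm(p=1, dtype=torch.float64)
print(hash_value.item())  # 1 + 2 + 0 + 1 = 4.0
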
@ -188,28 +184,6 @@ def _compute_rel_diff(hash1, hash2):
    return numerator / denominator


def hash_tensor_fn(
    t: torch.Tensor, use_scalar: bool = False
) -> Union[torch.Tensor, int]:
    """
    wrapper over torch.hash_tensor
    """
    if isinstance(t, torch.distributed.tensor.DTensor):
        t = t.to_local()

    if t.is_floating_point():
        t_clean = t.to(dtype=torch.float64)
    elif t.is_complex():
        t_clean = t.to(dtype=torch.complex128).view(torch.float64)
    else:
        t_clean = t.to(dtype=torch.int64)

    out = torch.hash_tensor(t_clean)
    if use_scalar:
        return out.item()  # type: ignore[attribute]
    return out


def _get_stack_trace() -> str:
    from torch.fx.experimental.symbolic_shapes import uninteresting_files

@ -923,43 +897,20 @@ class DebugMode(TorchDispatchMode):

    @staticmethod
    @contextlib.contextmanager
    def log_tensor_hashes(
        hash_fn: Union[Callable, str, list[str]] = "norm", hash_inputs: bool = False
    ):
    def log_tensor_hashes(hash_fn: Callable | None = None, hash_inputs: bool = False):
        """
        Installs hook for tensor hash logging.

        hash_fn: One of:
            - Custom-defined hash function
            - String: one of ("norm", "hash_tensor")
                - "norm": uses norm_hash_fn; basically tensor's L1 norm
                - "hash_tensor": uses torch.hash_tensor (XOR sum reduction)
            - List of strings: returns tuple of hashes from above options
        hash_fn: optional function for custom hashing
        hash_inputs: if True, also hashes tensors in (args, kwargs), storing them in "input_hash".
        NOTE: this is currently a post-hook, so e.g. inplace ops will log the "output" hashes.
        """

        def hash_fn_option(hash_type):
            assert isinstance(hash_type, str) and hash_type in ["norm", "hash_tensor"]
            return functools.partial(
                norm_hash_fn if hash_type == "norm" else hash_tensor_fn, use_scalar=True
            )

        if callable(hash_fn):
            fn = hash_fn
        elif isinstance(hash_fn, str):
            fn = hash_fn_option(hash_fn)
        elif isinstance(hash_fn, list):
            fns = [hash_fn_option(fn) for fn in hash_fn]
            fn = lambda x: tuple(fn(x) for fn in fns)  # noqa: E731
        else:
            raise NotImplementedError(
                f"log_tensor_hashes() expected hash_fn to be callable, str, or list[str], but found {type(hash_fn)}"
            )
        if hash_fn is None:
            hash_fn = functools.partial(default_hash_fn, use_scalar=True)

        def _tree_hash(obj):
            return tree_map(
                lambda x: fn(x) if isinstance(x, torch.Tensor) else None, obj
                lambda x: hash_fn(x) if isinstance(x, torch.Tensor) else None, obj
            )

        def _dispatch_hash_hook(func, types, args, kwargs, result):

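A minimal usage sketch for the context manager above, assuming the DebugMode utility lives in torch.utils._debug_mode as in current PyTorch; a plain callable hash_fn is used because that form is accepted by both the old and the new signature in this hunk, and the hash function itself is a made-up example:

import torch
from torch.utils._debug_mode import DebugMode


def my_hash(t: torch.Tensor) -> float:
    # Arbitrary illustrative hash: absolute sum in float64.
    return t.detach().to(torch.float64).abs().sum().item()


with DebugMode() as debug_mode, DebugMode.log_tensor_hashes(hash_fn=my_hash):
    torch.mm(torch.randn(4, 4), torch.randn(4, 4))

print(debug_mode.debug_string())  # dispatch calls annotated with output hashes
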
@ -979,9 +930,9 @@ class DebugMode(TorchDispatchMode):
        try:
            if hash_inputs:
                _old_input_hfn = _TRITON_INPUT_HASH_FN
                _TRITON_INPUT_HASH_FN = fn
                _TRITON_INPUT_HASH_FN = hash_fn
                _old_output_hfn = _TRITON_OUTPUT_HASH_FN
                _TRITON_OUTPUT_HASH_FN = fn
                _TRITON_OUTPUT_HASH_FN = hash_fn
            with DebugMode.dispatch_hooks(log_hook=_dispatch_hash_hook):
                yield
        finally: