Compare commits


1 Commit

48 changed files with 776 additions and 1023 deletions

View File

@ -143,8 +143,7 @@ init_command = [
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
'numpy==2.1.0 ; python_version >= "3.12" and python_version <= "3.13"',
'numpy==2.3.4 ; python_version >= "3.14"',
'numpy==2.1.0 ; python_version >= "3.12"',
'expecttest==0.3.0',
'pyrefly==0.36.2',
'sympy==1.13.3',

View File

@ -174,12 +174,6 @@ class TORCH_API Context {
static long versionCuDNN() {
return detail::getCUDAHooks().versionCuDNN();
}
static long versionRuntimeCuDNN() {
return detail::getCUDAHooks().versionRuntimeCuDNN();
}
static long versionCuDNNFrontend() {
return detail::getCUDAHooks().versionCuDNNFrontend();
}
static bool hasCuSOLVER() {
return detail::getCUDAHooks().hasCuSOLVER();
}

View File

@ -6,7 +6,6 @@
#include <c10/util/Half.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/complex.h>
#include <torch/headeronly/core/Dispatch.h>
#ifdef __CUDACC__
#include <cuda.h> // For CUDA_VERSION
@ -62,9 +61,12 @@ TORCH_API void record_kernel_function_dtype(std::string name);
} \
} while (0)
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \
AT_PRIVATE_CHECK_SELECTIVE_BUILD, enum_type, HINT, __VA_ARGS__)
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
case enum_type: { \
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
using HINT [[maybe_unused]] = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
return __VA_ARGS__(); \
}
#define AT_DISPATCH_CASE(enum_type, ...) \
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
@ -93,6 +95,14 @@ TORCH_API void record_kernel_function_dtype(std::string name);
return __VA_ARGS__(); \
}
namespace detail {
inline at::ScalarType scalar_type(at::ScalarType s) {
return s;
}
} // namespace detail
// The AT_DISPATCH_* family of macros provides the ability to
// conveniently generate specializations of a kernel over all of the
// dtypes we care about in PyTorch. We call it "dispatch" because
@ -180,13 +190,27 @@ TORCH_API void record_kernel_function_dtype(std::string name);
// but we're just being safe (and it doesn't hurt.) Note we must
// use it to shut up warnings about unused store.
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
THO_DISPATCH_SWITCH_TMPL( \
RECORD_KERNEL_FUNCTION_DTYPE, \
TORCH_CHECK_NOT_IMPLEMENTED, \
TYPE, \
NAME, \
__VA_ARGS__)
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
constexpr const char* at_dispatch_name = NAME; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \
switch (_st) { \
__VA_ARGS__ \
default: \
TORCH_CHECK_NOT_IMPLEMENTED( \
false, \
'"', \
at_dispatch_name, \
"\" not implemented for '", \
toString(_st), \
"'"); \
} \
C10_DIAGNOSTIC_POP() \
}()
#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \

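For orientation, the AT_DISPATCH_SWITCH / AT_DISPATCH_CASE pieces above are normally consumed through wrappers such as AT_DISPATCH_FLOATING_TYPES. A minimal usage sketch follows; the kernel name and lambda body are hypothetical and not part of this diff.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical illustration: scale a tensor in place, dispatching on its dtype.
// AT_DISPATCH_FLOATING_TYPES expands to AT_DISPATCH_SWITCH over the
// AT_DISPATCH_CASE_FLOATING_TYPES cases shown above; each case binds scalar_t.
void scale_(at::Tensor& self, double factor) {
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "scale_", [&] {
    scalar_t* data = self.data_ptr<scalar_t>();
    for (int64_t i = 0; i < self.numel(); ++i) {
      data[i] = static_cast<scalar_t>(data[i] * factor);
    }
  });
}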
View File

@ -1,8 +1,3 @@
#pragma once
#include <torch/headeronly/core/Dispatch_v2.h>
// Get AT_DISPATCH_SWITCH and AT_DISPATCH_CASE:
#include <ATen/Dispatch.h>
// This is a new implementation of the AT_DISPATCH macro family from
@ -79,19 +74,41 @@
// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
// relied on GPT4 to help me get it right.
// Public API macros
// See documentation above
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
THO_DISPATCH_V2_TMPL( \
AT_DISPATCH_SWITCH, \
AT_DISPATCH_CASE, \
TYPE, \
NAME, \
AT_WRAP(BODY), \
__VA_ARGS__)
AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
// This macro lets you pass an arbitrary expression that may contain internal
// commas to another macro without having the commas causing the expression
// to be interpreted as being multiple arguments
#define AT_WRAP(...) __VA_ARGS__
#define AT_FLOAT8_TYPES \
c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
#define AT_INTEGRAL_TYPES \
c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
#define AT_INTEGRAL_TYPES_V2 \
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
// NB: not *actually* all types
#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
#define AT_ALL_TYPES_AND_COMPLEX \
AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
// Helper macros
// Unused helper macros, kept for BC:
#define AT_AP_VAR(N, T, ...) \
AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
#define AT_CONCAT_AUX(a, b) a##b
#define AT_EXPAND(X) X
// Ensure we never have too many scalar types for the expansion here to
// support. To bump this, you must regenerate the macros below.
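A hedged usage sketch of the v2 form defined above, assuming the conventional call shape of AT_DISPATCH_V2: the body is wrapped in AT_WRAP so its internal commas survive macro expansion, and type-list macros are spliced in with AT_EXPAND alongside individual dtypes. The op name and body below are hypothetical.

#include <ATen/ATen.h>
#include <ATen/Dispatch_v2.h>

// Hypothetical call site, only to show the macro shape; "fill_one" is not an
// op touched by this diff.
void fill_one(at::Tensor& self) {
  AT_DISPATCH_V2(
      self.scalar_type(),
      "fill_one",
      AT_WRAP([&] {
        // scalar_t is bound per dtype case, exactly as with the v1 macros.
        self.fill_(static_cast<scalar_t>(1));
      }),
      AT_EXPAND(AT_ALL_TYPES),
      c10::kHalf,
      c10::kBFloat16);
}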
@ -102,6 +119,12 @@ static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 60);
num_args = 60
nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
args = ', '.join(f'_{i}' for i in range(1, num_args+1))
print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
for i in range(1, num_args+1):
args = ', '.join(f'_{i}' for i in range(1, i+1))
cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
@ -112,6 +135,8 @@ for i in range(1, num_args+1):
// Begin generated code
// clang-format off
#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N
#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
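The generated AT_NUM_ARGS / AT_AP* macros above implement the classic variadic argument-counting trick. The following self-contained sketch (simplified names, capped at three arguments, not ATen code) shows the same mechanism in isolation:

// DEMO_COUNT(...) expands to the number of arguments (up to 3 here), and
// DEMO_DISPATCH pastes that count onto DEMO_AP to select the per-arity macro,
// mirroring AT_NUM_ARGS + AT_CONCAT(AT_AP, N) above.
#define DEMO_EXPAND(X) X
#define DEMO_CONCAT_AUX(a, b) a##b
#define DEMO_CONCAT(a, b) DEMO_CONCAT_AUX(a, b)
#define DEMO_COUNT(...) DEMO_EXPAND(DEMO_COUNT_AUX(__VA_ARGS__, 3, 2, 1, 0))
#define DEMO_COUNT_AUX(_1, _2, _3, N, ...) N
#define DEMO_CASE(x, fn) case x: fn(); break;
#define DEMO_AP1(fn, _1) DEMO_CASE(_1, fn)
#define DEMO_AP2(fn, _1, _2) DEMO_CASE(_1, fn) DEMO_CASE(_2, fn)
#define DEMO_AP3(fn, _1, _2, _3) DEMO_CASE(_1, fn) DEMO_CASE(_2, fn) DEMO_CASE(_3, fn)
#define DEMO_DISPATCH(fn, ...) \
  DEMO_EXPAND(DEMO_CONCAT(DEMO_AP, DEMO_COUNT(__VA_ARGS__))(fn, __VA_ARGS__))

static_assert(DEMO_COUNT(10, 20) == 2, "two arguments are counted as 2");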

View File

@ -21,7 +21,6 @@
#if AT_CUDNN_ENABLED()
#include <ATen/cudnn/cudnn-wrapper.h>
#include <cudnn_frontend.h>
#endif
#if AT_MAGMA_ENABLED()
@ -352,26 +351,6 @@ long CUDAHooks::versionCuDNN() const {
#endif
}
long CUDAHooks::versionRuntimeCuDNN() const {
#if AT_CUDNN_ENABLED()
#ifndef USE_STATIC_CUDNN
return cudnnGetVersion();
#else
return CUDNN_VERSION;
#endif
#else
TORCH_CHECK(false, "Cannot query CuDNN version if ATen_cuda is not built with CuDNN");
#endif
}
long CUDAHooks::versionCuDNNFrontend() const {
#if AT_CUDNN_ENABLED()
return CUDNN_FRONTEND_VERSION;
#else
TORCH_CHECK(false, "Cannot query CuDNN Frontend version if ATen_cuda is not built with CuDNN");
#endif
}
long CUDAHooks::versionMIOpen() const {
#if AT_ROCM_ENABLED()
return MIOPEN_VERSION_MAJOR * 10000 +

View File

@ -49,8 +49,6 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasCUDART() const override;
long versionCUDART() const override;
long versionCuDNN() const override;
long versionRuntimeCuDNN() const override;
long versionCuDNNFrontend() const override;
long versionMIOpen() const override;
std::string showConfig() const override;
double batchnormMinEpsilonCuDNN() const override;

View File

@ -174,14 +174,6 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
}
virtual long versionRuntimeCuDNN() const {
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
}
virtual long versionCuDNNFrontend() const {
TORCH_CHECK(false, "Cannot query cuDNN Frontend version without ATen_cuda library. ", CUDA_HELP);
}
virtual long versionMIOpen() const {
TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
}

View File

@ -409,7 +409,7 @@ struct ConvParams {
if (!detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda() || !cudnn_enabled) {
return false;
}
static long cudnn_version = detail::getCUDAHooks().versionRuntimeCuDNN();
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
// broken on cuDNN 9.8 - 9.14
if (cudnn_version >= 90800 && cudnn_version < 91500) {
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
@ -453,7 +453,7 @@ struct ConvParams {
}
// native kernel doesn't support 64-bit non-splittable case
if (!(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionRuntimeCuDNN() : -1;
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
// TODO(eqy): remove this once cuDNN fixes 64-bit depthwise support, first broken in 9.11x
if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous) {
if (cudnn_version < 0 || cudnn_version > 91000) {

View File

@ -478,7 +478,7 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
const auto s_k = params.key.sym_size(2);
const auto d_qk = params.query.sym_size(3);
const auto d_v = params.value.sym_size(3);
long cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN();
long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
if (cudnn_version < 8903) {
if (debug) {
TORCH_WARN("SDPA fprop requires cudnn 8.9.3 or higher");
@ -709,7 +709,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
return false;
#endif
#if defined(CUDNN_VERSION)
static auto cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN();
static auto cudnn_version = cudnnGetVersion();
if (params.dropout > 0.0 && cudnn_version > 91100 && cudnn_version < 91400) {
if (debug) {
TORCH_WARN(CUDNN_VERSION, " cuDNN version does not support dropout in SDPA (9.11 - 9.13).");

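For readers decoding the thresholds in these hunks: the compared integers use cuDNN's own version packing, which changed width at 9.0 (8.9.3 packs to 8903, 9.8.0 to 90800). A small hypothetical helper, shown only as a sketch, makes such checks explicit:

// Assumption: cuDNN 8.x packs versions as major*1000 + minor*100 + patch,
// while cuDNN 9.x packs them as major*10000 + minor*100 + patch, matching the
// 8903 and 90800/91500 constants above.
inline bool cudnn_version_at_least(long version, int major, int minor, int patch = 0) {
  const long packed = (major >= 9)
      ? major * 10000L + minor * 100L + patch
      : major * 1000L + minor * 100L + patch;
  return version >= packed;
}
// e.g. cudnn_version_at_least(cudnn_version, 8, 9, 3) mirrors `cudnn_version >= 8903`.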
View File

@ -10,8 +10,6 @@ set(AOTI_ABI_CHECK_TEST_SRCS
${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch_v2.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp

View File

@ -1,82 +0,0 @@
#include <gtest/gtest.h>
#include <torch/headeronly/core/Dispatch.h>
#include <torch/headeronly/core/Dispatch_v2.h>
// MY_PRIVATE_CHECK_SELECTIVE_BUILD is a prelude to case block. For
// testing, we do nothing:
#define MY_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) /* empty */
#define MY_PRIVATE_CASE_TYPE_USING_HINT(...) \
THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \
MY_PRIVATE_CHECK_SELECTIVE_BUILD, __VA_ARGS__)
#define MY_DISPATCH_CASE(...) \
THO_DISPATCH_CASE_TMPL(MY_PRIVATE_CASE_TYPE_USING_HINT, __VA_ARGS__)
// MY_RECORD_KERNEL_FUNCTION_DTYPE is a prelude to switch
// statement. For testing, we just avoid unused variable warning:
#define MY_RECORD_KERNEL_FUNCTION_DTYPE(DISPATCHNAME, ENUMTYPE) \
(void)DISPATCHNAME
// MY_CHECK_NOT_IMPLEMENTED is called in switch default block. For
// testing, we count case mismatches:
#define MY_CHECK_NOT_IMPLEMENTED(...) default_count++
#define MY_DISPATCH_SWITCH(...) \
THO_DISPATCH_SWITCH_TMPL( \
MY_RECORD_KERNEL_FUNCTION_DTYPE, MY_CHECK_NOT_IMPLEMENTED, __VA_ARGS__)
// MY_CASE_FUNCTION is called in a case block. For testing, we count
// case matches and ensure that scalar_t/index_t type is defined:
#define MY_CASE_FUNCTION \
[&] { \
count++; \
scalar_t tmp; \
(void)tmp; \
}
#define MY_INDEX_CASE_FUNCTION \
[&] { \
count++; \
index_t tmp; \
(void)tmp; \
}
#define DEFINE_ITEM(TYPE, SCALARTYPE) ScalarType::SCALARTYPE,
#define MY_DISPATCH_V2(TYPE, NAME, BODY, ...) \
THO_DISPATCH_V2_TMPL( \
MY_DISPATCH_SWITCH, \
MY_DISPATCH_CASE, \
TYPE, \
NAME, \
AT_WRAP(BODY), \
__VA_ARGS__)
#define TEST_DISPATCH_V2(NAME, EXPECTEDCOUNT, ...) \
TEST(TestDispatchV2, NAME) { \
using torch::headeronly::ScalarType; \
using torch::headeronly::impl::ScalarTypeToCPPTypeT; \
int8_t total_count = 0; \
int8_t count = 0; \
int8_t default_count = 0; \
for (ScalarType t : \
{AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ITEM)}) { \
total_count++; \
MY_DISPATCH_V2(t, "test_my_dispatch_v2", MY_CASE_FUNCTION, __VA_ARGS__); \
} \
EXPECT_EQ(count, EXPECTEDCOUNT); \
EXPECT_EQ(default_count + count, total_count); \
}
TEST_DISPATCH_V2(AT_FLOAT8_TYPES_, 5, AT_FLOAT8_TYPES);
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_, 5, AT_INTEGRAL_TYPES);
TEST_DISPATCH_V2(AT_FLOATING_TYPES_, 2, AT_FLOATING_TYPES);
TEST_DISPATCH_V2(AT_BAREBONES_UNSIGNED_TYPES_, 3, AT_BAREBONES_UNSIGNED_TYPES);
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_V2_, 8, AT_INTEGRAL_TYPES_V2);
TEST_DISPATCH_V2(AT_COMPLEX_TYPES_, 2, AT_COMPLEX_TYPES);
TEST_DISPATCH_V2(AT_QINT_TYPES_, 3, AT_QINT_TYPES);
TEST_DISPATCH_V2(AT_ALL_TYPES_, 7, AT_ALL_TYPES);
TEST_DISPATCH_V2(AT_ALL_TYPES_AND_COMPLEX_, 9, AT_ALL_TYPES_AND_COMPLEX);
#undef DEFINE_ITEM

View File

@ -1,45 +0,0 @@
#include <gtest/gtest.h>
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/util/Exception.h>
#define DEFINE_ITEM(TYPE, SCALARTYPE) ScalarType::SCALARTYPE,
#define TEST_DISPATCH_V2(NAME, EXPECTEDCOUNT, ...) \
TEST(TestThoDispatchV2, NAME) { \
using torch::headeronly::ScalarType; \
using torch::headeronly::impl::ScalarTypeToCPPTypeT; \
int8_t total_count = 0; \
int8_t count = 0; \
int8_t default_count = 0; \
for (ScalarType t : \
{AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ITEM)}) { \
total_count++; \
try { \
THO_DISPATCH_V2( \
t, \
"test_tho_dispatch_v2", \
[&] { \
count++; \
scalar_t tmp; \
(void)tmp; \
}, \
__VA_ARGS__); \
} catch (...) { \
default_count++; /* counts mismatches */ \
} \
} \
EXPECT_EQ(count, EXPECTEDCOUNT); \
EXPECT_EQ(default_count + count, total_count); \
}
TEST_DISPATCH_V2(AT_FLOAT8_TYPES_, 5, AT_FLOAT8_TYPES);
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_, 5, AT_INTEGRAL_TYPES);
TEST_DISPATCH_V2(AT_FLOATING_TYPES_, 2, AT_FLOATING_TYPES);
TEST_DISPATCH_V2(AT_BAREBONES_UNSIGNED_TYPES_, 3, AT_BAREBONES_UNSIGNED_TYPES);
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_V2_, 8, AT_INTEGRAL_TYPES_V2);
TEST_DISPATCH_V2(AT_COMPLEX_TYPES_, 2, AT_COMPLEX_TYPES);
TEST_DISPATCH_V2(AT_QINT_TYPES_, 3, AT_QINT_TYPES);
TEST_DISPATCH_V2(AT_ALL_TYPES_, 7, AT_ALL_TYPES);
TEST_DISPATCH_V2(AT_ALL_TYPES_AND_COMPLEX_, 9, AT_ALL_TYPES_AND_COMPLEX);
#undef DEFINE_ITEM

View File

@ -1,18 +1,11 @@
# Owner(s): ["oncall: distributed"]
import itertools
from contextlib import nullcontext
from typing import Any
import torch
import torch.distributed as dist
from torch.distributed._local_tensor import (
local_tensor_mode,
LocalTensor,
LocalTensorMode,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed.tensor import distribute_tensor, DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._utils import (
_compute_local_shape_and_global_offset,
@ -21,7 +14,6 @@ from torch.distributed.tensor._utils import (
compute_global_tensor_shape,
compute_local_shape_and_global_offset,
compute_local_tensor_info,
ExplicitRedistributionContext,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import (
@ -859,93 +851,5 @@ class Test2DStridedLocalShard(DTensorTestBase):
self.assertEqual(global_tensor, dtensor_2d.full_tensor())
class LocalTensorTestBase(TestCase):
def assertEqual(self, lhs, rhs, **kwargs):
mode = local_tensor_mode()
with nullcontext() if mode is None else mode.disable():
if isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor):
assert isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor)
super().assertEqual(lhs._ranks, rhs._ranks)
for r in lhs._ranks:
super().assertEqual(
lhs._local_tensors[r],
rhs._local_tensors[r],
lambda m: f"rank {r}: {m}",
)
elif isinstance(lhs, LocalTensor) or isinstance(rhs, LocalTensor):
lhs, rhs = (lhs, rhs) if isinstance(lhs, LocalTensor) else (rhs, lhs)
for r in lhs._ranks:
super().assertEqual(
lhs._local_tensors[r], rhs, lambda m: f"rank {r}: {m}"
)
else:
return super().assertEqual(lhs, rhs, **kwargs)
@property
def world_size(self):
raise NotImplementedError("override world-size in your subclass")
def build_device_mesh(self) -> DeviceMesh:
return init_device_mesh("cpu", (self.world_size,))
def setUp(self):
super().setUp()
torch.distributed.init_process_group(
# TODO: test other ranks too
"fake",
rank=0,
world_size=self.world_size,
)
def tearDown(self):
super().tearDown()
try:
dist.destroy_process_group()
except AssertionError:
pass
class TestExplicitRedistribute(LocalTensorTestBase):
@property
def world_size(self):
return 4
def test_explicit_matmul(self):
with LocalTensorMode(self.world_size):
device_mesh = self.build_device_mesh()
dim = 128
x = torch.randn(8, dim, requires_grad=True)
A = torch.randn(dim, dim, requires_grad=True)
# Prepare DTensors
dx = distribute_tensor(x, device_mesh, [Shard(0)])
dA = distribute_tensor(A, device_mesh, [Shard(0)])
# implicit redistribute works as usual by default
with CommDebugMode() as comm_mode:
torch.matmul(dx, dA)
self.assertEqual(comm_mode.get_total_counts(), 1)
# explicit redistribute works too
with ExplicitRedistributionContext():
with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
torch.matmul(dx, dA)
# explicit redistribute allows manual redistribute
with ExplicitRedistributionContext():
dA_repl = dA.redistribute(device_mesh, [Replicate()])
torch.matmul(dx, dA_repl)
dx = distribute_tensor(x, device_mesh, [Shard(0)])
dA = distribute_tensor(A, device_mesh, [Replicate()])
with ExplicitRedistributionContext():
dY = torch.matmul(dx, dA_repl)
loss = dY.sum()
# we now see the error during backwards
with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
loss.backward()
if __name__ == "__main__":
run_tests()

View File

@ -475,14 +475,17 @@ class TestFxGraphCache(TestCase):
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
if (
device == "cuda"
and torch.version.hip is None
and dtype == torch.bfloat16
and not SM80OrLater
):
raise unittest.SkipTest("requires SM80 or later")
if use_static_cuda_launcher and not (device == "cuda" and bundle_triton):
raise unittest.SkipTest(
"Static cuda launcher requires cuda and triton bundling"
)
if use_static_cuda_launcher and TEST_WITH_ROCM:
raise unittest.SkipTest("Static cuda launcher doesn't work with ROCM")
def fn(x, y):
return (x * 2, y @ y)

View File

@ -27,9 +27,9 @@ from torch._inductor.fx_passes.post_grad import post_grad_passes
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code, run_and_get_cpp_code
from torch._inductor.virtualized import V
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.common_utils import IS_MACOS, skipIfRocm
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.profiler import profile, ProfilerActivity
try:
from .test_aot_inductor_utils import AOTIRunnerUtil
@ -941,5 +941,63 @@ copy_tests(
)
from torch.profiler._utils import _enrich_profiler_traces
class TestProfilerStackTraceAugmentation(TestCase):
"""
Test that profiler events are correctly augmented with stack traces
from both FX metadata and inductor kernel stack traces.
"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
@config.patch("fallback_by_default", True) # TODO update the config patch to inductor lite mode
@torch.compiler.config.patch("force_disable_caches", True)
def test_profiler_inductor_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
augments profiler events with stack traces from inductor kernel metadata.
"""
# Test model similar to test.py
class TestModel(torch.nn.Module):
def forward(self, c):
d = c * 2
d = d + 1
return d
device = "cuda"
model = TestModel().to(device)
c = torch.randn((64, 32), device=device)
# Force disable caches to ensure fresh compilation
torch.compiler.config.force_disable_caches = True
# Compile the model
compiled_model = torch.compile(model, fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(c)
# Profile with the compiled model
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
compiled_model(c)
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::mul node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
event=cudaLaunchKernel node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
event=aten::add node=torch.ops.aten.add.Tensor:2 stack_trace=d = d + 1
event=cudaLaunchKernel node=torch.ops.aten.add.Tensor:2 stack_trace=d = d + 1""")
# TODO: add test that when enrich is not turned on there is no recordfast generated.
if __name__ == "__main__":
run_tests()

View File

@ -12,7 +12,6 @@ from torch._inductor.runtime.static_cuda_launcher import StaticallyLaunchedCudaK
from torch._inductor.runtime.triton_compat import CompiledKernel, tl, triton
from torch._inductor.runtime.triton_helpers import libdevice
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.triton_utils import requires_cuda_and_triton
@ -39,8 +38,9 @@ class TestStaticCudaLauncher(TestCase):
# Just used by tests for now.
# TODO: derive cubin_path from wherever triton stores the cubin file on disk.
tmp_file = tempfile.NamedTemporaryFile(mode="wb", delete=False)
binary_key = "hsaco" if torch.version.hip else "cubin"
with tmp_file:
tmp_file.write(kernel.asm["cubin"])
tmp_file.write(kernel.asm[binary_key])
self.tmp_files.append(tmp_file)
return tmp_file.name
@ -64,7 +64,6 @@ class TestStaticCudaLauncher(TestCase):
result.load_kernel(device_interface.current_device())
return result
@skipIfRocm
def test_basic(self):
@triton.jit
def simple_kernel(arg0, arg1):
@ -91,7 +90,6 @@ class TestStaticCudaLauncher(TestCase):
# 2. triton relies on inspect.get_source to get the type annotations
# so I can't even use exec() to generate the test cases.
# So we'll just make a few kernels by hand
@skipIfRocm
def test_unsigned_integers(self):
@triton.jit
def unsigned_integers(
@ -115,7 +113,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_signed_integers(self):
@triton.jit
def signed_integers(
@ -139,7 +136,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_basic_1arg(self):
@triton.jit
def simple_kernel_1_arg(arg0):
@ -164,7 +160,6 @@ class TestStaticCudaLauncher(TestCase):
)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_constexpr(self):
# Constexprs are compiled directly into the cubin file,
# so we never need to pass it to StaticCudaLauncher.
@ -193,7 +188,6 @@ class TestStaticCudaLauncher(TestCase):
)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_implied_constant(self):
"""xnumel is unused in this kernel, but isn't explicitly marked as a constexpr"""
@ -246,7 +240,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, arg0, arg2, 128)
self.assertEqual(arg1, arg2)
@skipIfRocm
def test_kernel_no_args(self):
# Just an easy way to test incompatible number of arguments
@triton.jit
@ -259,7 +252,6 @@ class TestStaticCudaLauncher(TestCase):
stream = device_interface.get_raw_stream(device_interface.current_device())
launcher.run(1, 1, 1, stream)
@skipIfRocm
def test_high_shared_mem(self):
@triton.jit
def simple_kernel(arg0, arg1):
@ -283,7 +275,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, new_arg0, arg1)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_too_high_shared_mem(self):
@triton.jit
def simple_kernel(arg0, arg1):
@ -303,7 +294,6 @@ class TestStaticCudaLauncher(TestCase):
lambda: self._make_launcher(compiled_kernel),
)
@skipIfRocm
def test_kernel_empty_tensor(self):
# Triton kernel generated by torch.compile of the following:
# @torch.compile()
@ -364,7 +354,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, arg1, arg2, buf1, arg0, xnumel)
self.assertEqual(buf0, buf1)
@skipIfRocm
def test_kernel_many_args(self):
N = 200
# Make 200 arguments
@ -405,7 +394,6 @@ class TestStaticTritonCompileResult(TestCase):
Tests static cuda launcher with torch.compile()
"""
@skipIfRocm
def test_basic_compile(self):
@torch.compile
def foo(x, y):
@ -415,7 +403,6 @@ class TestStaticTritonCompileResult(TestCase):
y = torch.randn(10, device="cuda")
self.assertEqual(foo(x, y), x + y)
@skipIfRocm
# The error gets raised on a worker, so we want to not use a separate process
@torch._inductor.config.patch("compile_threads", 1)
def test_incompatible_code(self):
@ -438,7 +425,6 @@ class TestStaticTritonCompileResult(TestCase):
lambda: foo(x),
)
@skipIfRocm
# The error gets raised on a worker, so we want to not use a separate process
@torch._inductor.config.patch(
{"compile_threads": 1, "static_launch_user_defined_triton_kernels": True}
@ -460,7 +446,6 @@ class TestStaticTritonCompileResult(TestCase):
x2 = x.clone().detach_()
self.assertEqual(foo(x), x2 + 5)
@skipIfRocm
def test_empty_tensor(self):
@torch.compile()
def foo(x, y):
@ -472,7 +457,6 @@ class TestStaticTritonCompileResult(TestCase):
result = foo(x, y)
self.assertEqual(result, torch.cat(((x * 4), y + 10)))
@skipIfRocm
def test_any(self):
def fn(x):
return (
@ -492,7 +476,6 @@ class TestStaticTritonCompileResult(TestCase):
compiled_result = compiled_fn(arg)
self.assertEqual(eager_result, compiled_result)
@skipIfRocm
def test_disable_static_cuda_launcher(self):
@torch.compile
def fn(x, y):

View File

@ -76,11 +76,8 @@ from torch.testing._internal.common_utils import (
)
from torch.testing._internal.jit_utils import JitTestCase
import json
import tempfile
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
from torch.autograd.profiler_util import _canonicalize_profiler_events
from torch.profiler._utils import _enrich_profiler_traces
try:
from torchvision import models as torchvision_models
@ -208,36 +205,6 @@ def side_effect_func(x: torch.Tensor):
print(x)
def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.
Args:
prof: A torch.profiler.profile object
Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
trace_file = f.name
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
trace_data = json.load(f)
map_recorded_events_to_aten_ops_with_stack_trace(
trace_data
)
events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)
actual_traces = _canonicalize_profiler_events(events)
return actual_traces
class TestFX(JitTestCase):
def setUp(self):
super().setUp()
@ -4251,7 +4218,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def test_profiler_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
@ -4307,7 +4274,7 @@ event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def test_profiler_multiple_modules(self):
"""
Test that multiple compiled modules under the same profiler session
@ -4351,7 +4318,7 @@ event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def test_profiler_nested_graph_modules(self):
"""
Test that nested graph modules (e.g., graph modules calling subgraphs)

View File

@ -1762,13 +1762,13 @@ class TestMPS(TestCaseMPS):
continue
# Running stats must be tracked in eval mode
if (track_running_stats):
helper(shape, eps=1e-5, momentum=1, channels_last=channels_last,
helper(shape, eps=0, momentum=1, channels_last=channels_last,
track_running_stats=track_running_stats, test_module=test_module)
helper(shape, channels_last=channels_last,
track_running_stats=track_running_stats, test_module=test_module)
helper(shape, eps=1e-05, momentum=0.1, wts=False, training=False, channels_last=channels_last,
track_running_stats=track_running_stats, test_module=test_module)
helper(shape, eps=1e-5, momentum=1.0, wts=False, training=False, channels_last=channels_last,
helper(shape, eps=0, momentum=1.0, wts=False, training=False, channels_last=channels_last,
track_running_stats=track_running_stats, test_module=test_module)
helper(shape, eps=1, momentum=1, wts=True, training=False, channels_last=channels_last,
track_running_stats=track_running_stats, test_module=test_module)
@ -1776,7 +1776,7 @@ class TestMPS(TestCaseMPS):
track_running_stats=track_running_stats, test_module=test_module)
helper(shape, eps=1e-05, momentum=0.1, wts=False, training=True, channels_last=channels_last,
track_running_stats=track_running_stats, test_module=test_module)
helper(shape, eps=1e-5, momentum=1.0, wts=False, training=True, channels_last=channels_last,
helper(shape, eps=0, momentum=1.0, wts=False, training=True, channels_last=channels_last,
track_running_stats=track_running_stats, test_module=test_module)
helper(shape, eps=1, momentum=1, wts=True, training=True, channels_last=channels_last,
track_running_stats=track_running_stats, test_module=test_module)

View File

@ -739,11 +739,8 @@ enable_aot_compile = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
default=False,
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
# Deprecated! Please use the config in torch/fx/experimental/_config instead.
enrich_profiler_metadata: bool = False
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403

View File

@ -3228,7 +3228,7 @@ class InstructionTranslatorBase(
def BUILD_SLICE(self, inst: Instruction) -> None:
items = self.popn(inst.argval)
self.push(SliceVariable(items, tx=self)) # type: ignore[arg-type]
self.push(SliceVariable(items, tx=self))
def BUILD_LIST(self, inst: Instruction) -> None:
items = self.popn(inst.argval)
@ -3607,7 +3607,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, ListVariable)
assert obj.is_mutable()
obj.call_method(self, "extend", [v], {}) # type: ignore[arg-type]
obj.call_method(self, "extend", [v], {})
def LIST_TO_TUPLE(self, inst: Instruction) -> None:
self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) # type: ignore[arg-type]
@ -3673,7 +3673,7 @@ class InstructionTranslatorBase(
def MATCH_KEYS(self, inst: Instruction) -> None:
tos = self.stack[-1]
assert isinstance(tos, TupleVariable)
keys = tos.unpack_var_sequence(self) # type: ignore[arg-type]
keys = tos.unpack_var_sequence(self)
tos1 = self.stack[-2]
assert isinstance(tos1, ConstDictVariable)

View File

@ -1513,7 +1513,7 @@ class WithExitFunctionVariable(VariableTracker):
# Note here we reconstruct the context manager rather than the
# exit function. The handler generated by BlockStackEntry
# will re-enter the context in the resume function.
self.ctx.reconstruct_type(codegen) # type: ignore[union-attr]
self.ctx.reconstruct_type(codegen) # type: ignore[attr-defined]
if codegen.tx.output.partial_convert:
if sys.version_info >= (3, 11):
codegen.append_output(create_instruction("PUSH_NULL"))
@ -1522,10 +1522,10 @@ class WithExitFunctionVariable(VariableTracker):
# We rely on classes subtyping `GenericContextWrappingVariable`
# to implement these fns and have these attributes
codegen.extend_output(
[codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[union-attr]
[codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[arg-type]
)
codegen.extend_output(
create_call_function(len(self.ctx.target_values), False) # type: ignore[union-attr]
create_call_function(len(self.ctx.target_values), False) # type: ignore[arg-type]
)
codegen.append_output(create_setup_with(self.target))
codegen.append_output(create_instruction("POP_TOP"))

View File

@ -82,8 +82,7 @@ class ItertoolsVariable(VariableTracker):
for item in itertools.product(*seqs, repeat=r)
]
return variables.ListIteratorVariable(
items, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
items, mutation_type=ValueMutationNew()
)
elif (
self.value is itertools.combinations
@ -99,8 +98,7 @@ class ItertoolsVariable(VariableTracker):
for item in itertools.combinations(iterable, r):
items.append(variables.TupleVariable(list(item)))
return variables.ListIteratorVariable(
items, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
items, mutation_type=ValueMutationNew()
)
elif self.value is itertools.groupby:
if any(kw != "key" for kw in kwargs.keys()):
@ -183,8 +181,7 @@ class ItertoolsVariable(VariableTracker):
from_exc=e,
)
return variables.ListIteratorVariable(
result, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
result, mutation_type=ValueMutationNew()
)
elif self.value is itertools.repeat:
if len(args) < 2:
@ -215,8 +212,7 @@ class ItertoolsVariable(VariableTracker):
)
]
return variables.ListIteratorVariable(
items, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
items, mutation_type=ValueMutationNew()
)
else:
return super().call_function(tx, args, kwargs)

View File

@ -1,3 +1,5 @@
# mypy: ignore-errors
"""
Variable tracking implementations for list-like data structures in Dynamo.
@ -18,7 +20,7 @@ import collections
import inspect
import operator
import sys
from typing import Any, Optional, Sequence, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING
import torch
import torch.fx
@ -58,11 +60,11 @@ if TYPE_CHECKING:
class BaseListVariable(VariableTracker):
@staticmethod
def cls_for_instance(obj: Any) -> type["BaseListVariable"]:
def cls_for_instance(obj):
return BaseListVariable.cls_for(type(obj))
@staticmethod
def cls_for(obj: Any) -> type:
def cls_for(obj):
return {
iter: ListIteratorVariable,
list: ListVariable,
@ -78,38 +80,34 @@ class BaseListVariable(VariableTracker):
def __init__(
self,
items: list[VariableTracker],
**kwargs: Any,
**kwargs,
) -> None:
super().__init__(**kwargs)
assert isinstance(items, list)
assert all(isinstance(x, VariableTracker) for x in items)
self.items: list[VariableTracker] = items
def _as_proxy(self) -> list[Any]:
def _as_proxy(self):
return [x.as_proxy() for x in self.items]
def modified(
self, items: list[VariableTracker], **kwargs: Any
) -> "BaseListVariable":
def modified(self, items, **kwargs):
return type(self)(items, **kwargs)
@property
def value(self) -> Any:
def value(self):
return self.as_python_constant()
def debug_repr_helper(self, prefix: str, suffix: str) -> str:
def debug_repr_helper(self, prefix, suffix):
return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix
def as_python_constant(self) -> Any:
def as_python_constant(self):
return self.python_type()([x.as_python_constant() for x in self.items])
def as_proxy(self) -> Any:
def as_proxy(self):
assert self.python_type() is not SizeVariable
return self.python_type()(self._as_proxy())
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
from .tensor import SymNodeVariable
if isinstance(arg, SymNodeVariable):
@ -136,16 +134,16 @@ class BaseListVariable(VariableTracker):
IndexError, tx, args=["list index out of range"]
)
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
def unpack_var_sequence(self, tx):
return list(self.items)
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
if name == "__getitem__":
from .tensor import TensorVariable
@ -226,15 +224,15 @@ class BaseListVariable(VariableTracker):
if type(self) is not type(args[0]):
tp_name = self.python_type_name()
other = args[0].python_type_name()
msg_vt = ConstantVariable.create(
msg = ConstantVariable.create(
f'can only concatenate {tp_name} (not "{other}") to {tp_name}'
)
raise_observed_exception(TypeError, tx, args=[msg_vt])
raise_observed_exception(TypeError, tx, args=[msg])
if name == "__add__":
return type(self)(self.items + args[0].items, source=self.source) # type: ignore[attr-defined]
return type(self)(self.items + args[0].items, source=self.source)
else:
self.items += args[0].items # type: ignore[attr-defined]
self.items += args[0].items
return self
elif name in ("__mul__", "__imul__"):
if kwargs or len(args) != 1:
@ -246,10 +244,10 @@ class BaseListVariable(VariableTracker):
)
if not (args[0].is_python_constant() and args[0].python_type() is int):
msg_vt = ConstantVariable.create(
msg = ConstantVariable.create(
f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'"
)
raise_observed_exception(TypeError, tx, args=[msg_vt])
raise_observed_exception(TypeError, tx, args=[msg])
val = args[0].as_python_constant()
@ -303,7 +301,7 @@ class BaseListVariable(VariableTracker):
class RangeVariable(BaseListVariable):
def __init__(self, items: Sequence[VariableTracker], **kwargs: Any) -> None:
def __init__(self, items, **kwargs) -> None:
items_to_map = items
start = variables.ConstantVariable.create(0)
stop = None
@ -318,7 +316,7 @@ class RangeVariable(BaseListVariable):
else:
raise AssertionError
def maybe_as_int(x: VariableTracker) -> VariableTracker:
def maybe_as_int(x):
return (
ConstantVariable(int(x.value)) if isinstance(x, ConstantVariable) else x
)
@ -331,22 +329,22 @@ class RangeVariable(BaseListVariable):
assert stop is not None
super().__init__([start, stop, step], **kwargs)
def debug_repr(self) -> str:
def debug_repr(self):
return self.debug_repr_helper("range(", ")")
def python_type(self) -> type:
def python_type(self):
return range
def start(self) -> Any:
def start(self):
return self.items[0].as_python_constant()
def stop(self) -> Any:
def stop(self):
return self.items[1].as_python_constant()
def step(self) -> Any:
def step(self):
return self.items[2].as_python_constant()
def range_length(self) -> int:
def range_length(self):
lo = self.start()
hi = self.stop()
step = self.step()
@ -359,7 +357,7 @@ class RangeVariable(BaseListVariable):
else:
return 0
def _get_slice_indices(self, length: int, slice: slice) -> list[int]:
def _get_slice_indices(self, length, slice):
step_is_negative = 0
if slice.step is None:
@ -408,7 +406,7 @@ class RangeVariable(BaseListVariable):
return [start, stop, step]
def apply_index(self, index: int) -> VariableTracker:
def apply_index(self, index):
length = self.range_length()
if index < 0:
index = length + index
@ -423,12 +421,12 @@ class RangeVariable(BaseListVariable):
return variables.ConstantVariable.create(self.start() + (index * self.step()))
def apply_slice(self, slice: slice) -> "RangeVariable":
def apply_slice(self, slice):
(slice_start, slice_stop, slice_step) = self._get_slice_indices(
self.range_length(), slice
)
def compute_item(index: int) -> int:
def compute_item(index):
return self.start() + (index * self.step())
sub_step = self.step() * slice_step
@ -444,12 +442,10 @@ class RangeVariable(BaseListVariable):
)
return result
def as_python_constant(self) -> range:
def as_python_constant(self):
return range(*[x.as_python_constant() for x in self.items])
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
# implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
index = arg.as_python_constant()
@ -461,30 +457,28 @@ class RangeVariable(BaseListVariable):
msg = ConstantVariable("range indices must be integers or slices")
raise_observed_exception(TypeError, tx, args=[msg])
def as_proxy(self) -> range:
def as_proxy(self):
return self.python_type()(*self._as_proxy())
def unpack_var_sequence(
self, tx: Optional["InstructionTranslator"] = None
) -> list[VariableTracker]:
def unpack_var_sequence(self, tx=None):
return [variables.ConstantVariable.create(x) for x in self.as_python_constant()]
def reconstruct(self, codegen: "PyCodegen") -> None:
assert "range" not in codegen.tx.f_globals
codegen.add_push_null(
lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type]
lambda: codegen.append_output(codegen.create_load_python_module(range))
)
codegen.foreach(self.items)
codegen.extend_output(create_call_function(3, False))
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
) -> "VariableTracker":
if self.python_type() is range:
return variables.ConstantVariable.create(name in range.__dict__)
return super().call_obj_hasattr(tx, name)
def range_equals(self, other: "RangeVariable") -> bool:
def range_equals(self, other: "RangeVariable"):
r0, r1 = self, other
if (
self.range_length() != r1.range_length()
@ -493,12 +487,12 @@ class RangeVariable(BaseListVariable):
):
return False
if self.range_length() == 1:
if len(r0) == 1:
return True
return r0.step() == r1.step()
def range_count(self, x: VariableTracker) -> int:
def range_count(self, x: VariableTracker):
# Based on CPython
# https://github.com/guilhermeleobas/cpython/blob/baefaa6cba1d69efd2f930cdc56bca682c54b139/Objects/rangeobject.c#L442-L486
x = x.as_python_constant()
@ -517,13 +511,7 @@ class RangeVariable(BaseListVariable):
return int(re)
return 0
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
def call_method(self, tx, name, args, kwargs):
if name == "__iter__":
if not all(var.is_python_constant() for var in self.items):
# Can't represent a `range_iterator` without well defined bounds
@ -557,10 +545,7 @@ class RangeVariable(BaseListVariable):
if pt is not range:
return ConstantVariable.create(NotImplemented)
if isinstance(other, RangeVariable):
cmp = self.range_equals(other)
else:
cmp = False
cmp = self.range_equals(other)
# Two ranges are equal if they produce the same sequence of values
if name == "__eq__":
@ -569,7 +554,7 @@ class RangeVariable(BaseListVariable):
return ConstantVariable(not cmp)
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
def var_getattr(self, tx: "InstructionTranslator", name):
fields = ["start", "stop", "step"]
if name in fields:
return self.items[fields.index(name)]
@ -583,11 +568,11 @@ class CommonListMethodsVariable(BaseListVariable):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
from .tensor import SymNodeVariable
if name == "append" and self.is_mutable():
@ -691,9 +676,9 @@ class CommonListMethodsVariable(BaseListVariable):
self.items[key.evaluate_expr()] = value
elif isinstance(key, SliceVariable):
if key.is_python_constant():
self.items[key.as_python_constant()] = list(value.items) # type: ignore[attr-defined]
self.items[key.as_python_constant()] = list(value.items)
else:
items_slice = slice(
items = slice(
*[
(
s.evaluate_expr()
@ -703,7 +688,7 @@ class CommonListMethodsVariable(BaseListVariable):
for s in key.items
]
)
self.items[items_slice] = list(value.items) # type: ignore[attr-defined]
self.items[items] = list(value.items)
else:
self.items[key.as_python_constant()] = value
return ConstantVariable.create(None)
@ -748,8 +733,8 @@ class CommonListMethodsVariable(BaseListVariable):
"0 args and 0 kwargs",
f"{len(args)} args and {len(kwargs)} kwargs",
)
items_lst: list[VariableTracker] = list(self.items)
return self.modified(items_lst, mutation_type=ValueMutationNew())
items = list(self.items)
return self.modified(items, mutation_type=ValueMutationNew())
elif name == "reverse" and self.is_mutable():
if args or kwargs:
raise_args_mismatch(
@ -778,13 +763,13 @@ class CommonListMethodsVariable(BaseListVariable):
class ListVariable(CommonListMethodsVariable):
def python_type(self) -> type:
def python_type(self):
return list
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)})"
def debug_repr(self) -> str:
def debug_repr(self):
return self.debug_repr_helper("[", "]")
def reconstruct(self, codegen: "PyCodegen") -> None:
@ -793,11 +778,11 @@ class ListVariable(CommonListMethodsVariable):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
from .tensor import SymNodeVariable
if name == "__setitem__" and self.is_mutable():
@ -820,14 +805,14 @@ class ListVariable(CommonListMethodsVariable):
msg = ConstantVariable.create("can only assign an iterable")
raise_observed_exception(TypeError, tx, args=[msg])
key_as_const = key.as_python_constant()
if key_as_const.step == 0:
key = key.as_python_constant()
if key.step == 0:
msg = ConstantVariable.create("slice step cannot be zero")
raise_observed_exception(ValueError, tx, args=[msg])
value_unpack = value.force_unpack_var_sequence(tx)
value = value.force_unpack_var_sequence(tx)
try:
self.items[key_as_const] = value_unpack
self.items[key] = value
except Exception as exc:
raise_observed_exception(
type(exc),
@ -874,7 +859,7 @@ class ListVariable(CommonListMethodsVariable):
assert first_non_constant_key is not None
try:
python_type = str(first_non_constant_key.python_type())
python_type = first_non_constant_key.python_type()
except NotImplementedError:
python_type = "unknown"
@ -919,7 +904,7 @@ class ListVariable(CommonListMethodsVariable):
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
def var_getattr(self, tx, name):
if name == "__class__":
source = AttrSource(self.source, name) if self.source else None
class_type = self.python_type()
@ -931,19 +916,14 @@ class ListVariable(CommonListMethodsVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
) -> "VariableTracker":
if self.python_type() is not list:
return super().call_obj_hasattr(tx, name)
return variables.ConstantVariable.create(hasattr([], name))
class DequeVariable(CommonListMethodsVariable):
def __init__(
self,
items: list[VariableTracker],
maxlen: Optional[VariableTracker] = None,
**kwargs: Any,
) -> None:
def __init__(self, items, maxlen=None, **kwargs) -> None:
if maxlen is None:
maxlen = ConstantVariable.create(None)
assert maxlen.is_python_constant(), (
@ -955,17 +935,17 @@ class DequeVariable(CommonListMethodsVariable):
items = items[-maxlen.as_python_constant() :]
super().__init__(items, **kwargs)
def python_type(self) -> type:
def python_type(self):
return collections.deque
def debug_repr(self) -> str:
def debug_repr(self):
if self.maxlen.as_python_constant() is None:
return self.debug_repr_helper(
"deque([", "], maxlen=" + self.maxlen.debug_repr() + ")"
)
return self.debug_repr_helper("deque([", "])")
def as_python_constant(self) -> collections.deque[Any]:
def as_python_constant(self):
return self.python_type()(
[x.as_python_constant() for x in self.items],
maxlen=self.maxlen.as_python_constant(),
@ -974,7 +954,7 @@ class DequeVariable(CommonListMethodsVariable):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.append_output(
codegen.create_load_python_module(collections.deque) # type: ignore[arg-type]
codegen.create_load_python_module(collections.deque)
)
)
codegen.foreach(self.items)
@ -982,18 +962,18 @@ class DequeVariable(CommonListMethodsVariable):
codegen(self.maxlen)
codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False))
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
def var_getattr(self, tx: "InstructionTranslator", name):
if name == "maxlen":
return self.maxlen
return super().var_getattr(tx, name)
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
if (
name == "__setitem__"
and self.is_mutable()
@ -1088,20 +1068,20 @@ class DequeVariable(CommonListMethodsVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
) -> "VariableTracker":
if self.python_type() is collections.deque:
return variables.ConstantVariable.create(name in collections.deque.__dict__)
return super().call_obj_hasattr(tx, name)
class TupleVariable(BaseListVariable):
def python_type(self) -> type[tuple]: # type: ignore[type-arg]
def python_type(self):
return tuple
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)})"
def debug_repr(self) -> str:
def debug_repr(self):
return self.debug_repr_helper("(", ")")
def reconstruct(self, codegen: "PyCodegen") -> None:
@ -1110,14 +1090,14 @@ class TupleVariable(BaseListVariable):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
def var_getattr(self, tx, name):
if name == "__class__":
source = AttrSource(self.source, name) if self.source else None
class_type = self.python_type()
@ -1129,7 +1109,7 @@ class TupleVariable(BaseListVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
) -> "VariableTracker":
if self.python_type() is not tuple:
return super().call_obj_hasattr(tx, name)
return variables.ConstantVariable.create(hasattr((), name))
@ -1147,18 +1127,18 @@ class SizeVariable(TupleVariable):
self,
items: list[VariableTracker],
proxy: Optional[torch.fx.Proxy] = None,
**kwargs: Any,
**kwargs,
) -> None:
self.proxy = proxy
super().__init__(items, **kwargs)
def debug_repr(self) -> str:
def debug_repr(self):
return self.debug_repr_helper("torch.Size([", "])")
def python_type(self) -> type:
def python_type(self):
return torch.Size
def as_proxy(self) -> Any:
def as_proxy(self):
if self.proxy is not None:
return self.proxy
@ -1213,10 +1193,10 @@ class SizeVariable(TupleVariable):
] + create_call_function(1, False)
codegen.extend_output(build_torch_size)
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
def unpack_var_sequence(self, tx):
return list(self.items)
def numel(self, tx: "InstructionTranslator") -> VariableTracker:
def numel(self, tx):
from .builtin import BuiltinVariable
from .tensor import SymNodeVariable
@ -1246,11 +1226,11 @@ class SizeVariable(TupleVariable):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
if name == "__getitem__":
if kwargs or len(args) != 1:
raise_args_mismatch(
@ -1273,9 +1253,7 @@ class SizeVariable(TupleVariable):
return super().call_method(tx, name, args, kwargs)
def get_item_dyn(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker):
from .tensor import SymNodeVariable
if isinstance(arg, SymNodeVariable):
@ -1291,7 +1269,7 @@ class SizeVariable(TupleVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
) -> "VariableTracker":
return variables.ConstantVariable.create(hasattr(torch.Size, name))
@ -1302,39 +1280,33 @@ class NamedTupleVariable(TupleVariable):
*TupleVariable._nonvar_fields,
}
def __init__(
self,
items: list[VariableTracker],
tuple_cls: type,
dynamic_attributes: Optional[dict[str, VariableTracker]] = None,
**kwargs: Any,
) -> None:
def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None:
super().__init__(items, **kwargs)
self.tuple_cls = tuple_cls
self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {}
def is_namedtuple(self) -> bool:
def is_namedtuple(self):
return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
getattr(self.tuple_cls, "_make", None)
)
def is_structseq(self) -> bool:
def is_structseq(self):
return not self.is_namedtuple()
def fields(self) -> tuple[str, ...]:
def fields(self):
return namedtuple_fields(self.tuple_cls)
def debug_repr(self) -> str:
def debug_repr(self):
if self.is_structseq():
# StructSequenceType(iterable)
return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items]))
# NamedTupleType(*iterable)
return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items)))
def python_type(self) -> type:
def python_type(self):
return self.tuple_cls
def as_python_constant(self) -> Any:
def as_python_constant(self):
if self.is_structseq():
# StructSequenceType(iterable)
result = self.python_type()([x.as_python_constant() for x in self.items])
@ -1356,7 +1328,7 @@ class NamedTupleVariable(TupleVariable):
return result
def as_proxy(self) -> Any:
def as_proxy(self):
assert self.python_type() is not SizeVariable
if self.is_structseq():
# StructSequenceType(iterable)
@ -1370,10 +1342,7 @@ class NamedTupleVariable(TupleVariable):
# StructSequenceType(iterable)
# NamedTupleType(*iterable)
# NamedTupleType._make(iterable)
if self.is_structseq():
create_fn = self.tuple_cls
else:
create_fn = self.tuple_cls._make # type: ignore[attr-defined]
create_fn = self.tuple_cls if self.is_structseq() else self.tuple_cls._make
codegen.add_push_null(
lambda: codegen.append_output(
codegen.create_load_const_unchecked(create_fn)
@ -1415,8 +1384,8 @@ class NamedTupleVariable(TupleVariable):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
tx,
name,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
@ -1477,9 +1446,7 @@ class NamedTupleVariable(TupleVariable):
return super().call_method(tx, name, args, kwargs)
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
if isinstance(arg, SliceVariable):
# slicing a namedtuple produces a tuple
return TupleVariable(
@ -1488,8 +1455,8 @@ class NamedTupleVariable(TupleVariable):
)
return super().getitem_const(tx, arg)
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
def check_and_create_method() -> Optional[VariableTracker]:
def var_getattr(self, tx: "InstructionTranslator", name):
def check_and_create_method():
method = inspect.getattr_static(self.tuple_cls, name, None)
if isinstance(method, classmethod):
# We need the unbounded cls method to avoid the inline __self__
@ -1522,8 +1489,8 @@ class NamedTupleVariable(TupleVariable):
return super().var_getattr(tx, name)
if name == "_fields":
result_source = NamedTupleFieldsSource(self.source) if self.source else None
return VariableTracker.build(tx, self.fields(), source=result_source)
source = NamedTupleFieldsSource(self.source) if self.source else None
return VariableTracker.build(tx, self.fields(), source=source)
if name in self.dynamic_attributes:
return self.dynamic_attributes[name]
@ -1538,19 +1505,14 @@ class NamedTupleVariable(TupleVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
) -> "VariableTracker":
return variables.ConstantVariable.create(
name in self.dynamic_attributes or hasattr(self.tuple_cls, name)
)
class SliceVariable(VariableTracker):
def __init__(
self,
items: Sequence[VariableTracker],
tx: Optional["InstructionTranslator"] = None,
**kwargs: Any,
) -> None:
def __init__(self, items, tx=None, **kwargs) -> None:
items_to_map = items
start, stop, step = [variables.ConstantVariable.create(None)] * 3
@ -1585,23 +1547,23 @@ class SliceVariable(VariableTracker):
super().__init__(**kwargs)
def debug_repr(self) -> str:
return "slice(" + ", ".join(i.debug_repr() for i in self.items) + ")"
def debug_repr(self):
return self.debug_repr_helper("slice(", ")")
def as_proxy(self) -> slice:
def as_proxy(self):
return slice(*[x.as_proxy() for x in self.items])
def python_type(self) -> type:
def python_type(self):
return slice
def as_python_constant(self) -> slice:
def as_python_constant(self):
return slice(*[guard_if_dyn(x) for x in self.items])
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.foreach(self.items)
codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items)))
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
def var_getattr(self, tx: "InstructionTranslator", name):
if name in cmp_name_to_op_mapping:
return variables.GetAttrVariable(self, name)
fields = ["start", "stop", "step"]
@ -1622,9 +1584,7 @@ class ListIteratorVariable(IteratorVariable):
*IteratorVariable._nonvar_fields,
}
def __init__(
self, items: list[VariableTracker], index: int = 0, **kwargs: Any
) -> None:
def __init__(self, items, index: int = 0, **kwargs) -> None:
super().__init__(**kwargs)
assert isinstance(items, list)
# Removing this check as it slows things down too much
@ -1638,7 +1598,7 @@ class ListIteratorVariable(IteratorVariable):
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"
def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
def next_variable(self, tx):
assert self.is_mutable()
old_index = self.index
if old_index >= len(self.items) or self.is_exhausted:
@ -1649,31 +1609,27 @@ class ListIteratorVariable(IteratorVariable):
self.index += 1
return self.items[old_index]
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
def call_obj_hasattr(self, tx, name):
return variables.ConstantVariable.create(hasattr(iter([]), name))
def python_type(self) -> type:
def python_type(self):
return type(iter([]))
def as_python_constant(self) -> Any:
def as_python_constant(self):
if self.index > 0:
raise NotImplementedError
return iter([x.as_python_constant() for x in self.items])
def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
def has_unpack_var_sequence(self, tx):
return True
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
def unpack_var_sequence(self, tx):
if self.is_exhausted:
return []
self.is_exhausted = True
return list(self.items[self.index :])
def force_unpack_var_sequence(
self, tx: "InstructionTranslator"
) -> list[VariableTracker]:
def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
return self.unpack_var_sequence(tx)
def reconstruct(self, codegen: "PyCodegen") -> None:
@ -1700,37 +1656,27 @@ class RangeIteratorVariable(IteratorVariable):
"iter_obj",
}
def __init__(
self, start: int, stop: int, step: int, len_: int, **kwargs: Any
) -> None:
def __init__(self, start: int, stop: int, step: int, len_: int, **kwargs):
super().__init__(**kwargs)
self.start = start
self.stop = stop
self.step = step
self.len = len_
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
def call_method(self, tx, name, args, kwargs):
if name == "__next__":
return self.next_variable(tx)
elif name == "__iter__":
return self
return super().call_method(tx, name, args, kwargs)
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
def call_obj_hasattr(self, tx, name):
if self.python_type() is range_iterator:
ri = iter(range(0))
return ConstantVariable(hasattr(ri, name))
return super().call_obj_hasattr(tx, name)
def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
def next_variable(self, tx):
if self.len <= 0:
raise_observed_exception(StopIteration, tx)
@ -1739,12 +1685,12 @@ class RangeIteratorVariable(IteratorVariable):
self.start += self.step
return ConstantVariable.create(current)
def python_type(self) -> type:
def python_type(self):
return range_iterator
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen: "PyCodegen"):
codegen.add_push_null(
lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type]
lambda: codegen.append_output(codegen.create_load_python_module(range))
)
codegen.append_output(codegen.create_load_const(self.start))
codegen.append_output(codegen.create_load_const(self.stop))

View File

@ -70,7 +70,7 @@ from .common import (
)
from .cpp_utils import cexpr
from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta
from torch.fx.experimental import _config as fx_experimental_config
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
@ -1120,6 +1120,37 @@ class PythonWrapperCodegen(CodeGen):
# Additional files that are dependent to the wrapper (ex. cubin files)
self.additional_files = []
# This is used to emit RecordFunctionFast markers that can be matched
# with profiler traces for provenance tracking.
#
# Stores the (kernel_name, debug_handle) tuple
# for the currently being generated kernel.
self.current_kernel_debug_handle: Optional[tuple[str, int]] = None
# set_current_kernel_debug_handle: Flag that controls whether
# write_provenance_debug_handle() should update current_kernel_debug_handle.
# This flag is automatically managed by kernel_debug_handle_context().
self.set_current_kernel_debug_handle: bool = False
@contextlib.contextmanager
def kernel_debug_handle_context(self):
"""
Context manager for kernel debug handle tracking.
self.current_kernel_debug_handle can be updated within the context
via wrapper.write_provenance_debug_handle,
and it is reset when the context exits
"""
old_flag_value = self.set_current_kernel_debug_handle
old_handle_value = self.current_kernel_debug_handle
self.set_current_kernel_debug_handle = True
try:
yield
finally:
self.set_current_kernel_debug_handle = old_flag_value
self.current_kernel_debug_handle = old_handle_value
@staticmethod
def create(
is_subgraph: bool,
@ -1510,8 +1541,27 @@ class PythonWrapperCodegen(CodeGen):
def generate_end(self, result: IndentedBuffer) -> None:
return
def generate_record_function_start(self) -> Optional[str]:
record_func = self.current_kernel_debug_handle and fx_experimental_config.enrich_profiler_metadata
if record_func:
assert self.current_kernel_debug_handle
kernel_name, debug_handle = self.current_kernel_debug_handle
kernel_debug_handle = f"{kernel_name}:{debug_handle}"
self.writeline(
f"_rf_enter = torch._C._profiler._RecordFunctionFast('## inductor_kernel:{kernel_debug_handle} ##'); _rf_enter.__enter__()"
)
return "_rf_enter"
else:
return None
def generate_record_function_end(self, record_func_var: Optional[str]):
if record_func_var:
self.writeline(f"{record_func_var}.__exit__(None, None, None)")
def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None:
record_func_var = self.generate_record_function_start()
self.writeline(ExternKernelAllocLine(self, node))
self.generate_record_function_end(record_func_var)
def generate_extern_kernel_alloc(self, node: ir.ExternKernelAlloc):
node.codegen_comment(self)
@ -1671,7 +1721,9 @@ class PythonWrapperCodegen(CodeGen):
raw_args: Sequence[Any],
outputs: Sequence[ir.Buffer],
) -> None:
record_func_var = self.generate_record_function_start()
self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(get_args())})")
self.generate_record_function_end(record_func_var)
def generate(self, is_inference):
with dynamo_timed("PythonWrapperCodegen.generate"):
@ -3142,6 +3194,8 @@ class PythonWrapperCodegen(CodeGen):
self.writeline(
f"{self.comment} [Provenance debug handles] {kernel_name}:{debug_handle}"
)
if self.set_current_kernel_debug_handle:
self.current_kernel_debug_handle = (kernel_name, debug_handle)
def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool):
assert old.get_dtype() == new.get_dtype()
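For context, the debug-handle plumbing added above follows a plain save/restore pattern: kernel_debug_handle_context() snapshots the current state, write_provenance_debug_handle() updates it while a kernel is being generated, and the old state is restored on exit; generate_record_function_start() then turns the (kernel_name, debug_handle) pair into a '## inductor_kernel:<name>:<handle> ##' RecordFunctionFast marker. A minimal standalone sketch of the same pattern (hypothetical HandleTracker class, not the actual PythonWrapperCodegen):

import contextlib

class HandleTracker:
    def __init__(self):
        self.current_handle = None   # (kernel_name, debug_handle) or None
        self.tracking_enabled = False

    @contextlib.contextmanager
    def debug_handle_context(self):
        # snapshot state, enable tracking, restore everything on exit
        old_flag, old_handle = self.tracking_enabled, self.current_handle
        self.tracking_enabled = True
        try:
            yield
        finally:
            self.tracking_enabled = old_flag
            self.current_handle = old_handle

    def write_debug_handle(self, kernel_name, debug_handle):
        # only record the handle while a context is active
        if self.tracking_enabled:
            self.current_handle = (kernel_name, debug_handle)

tracker = HandleTracker()
with tracker.debug_handle_context():
    tracker.write_debug_handle("triton_poi_fused_add_0", 3)
    assert tracker.current_handle == ("triton_poi_fused_add_0", 3)
assert tracker.current_handle is None  # restored after the context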

View File

@ -98,6 +98,8 @@ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExpr
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.monitor import _WaitCounter
from torch.utils._ordered_set import OrderedSet
from torch.fx.experimental import _config as fx_experimental_config
from .._dynamo.backends.common import aot_autograd
from .._dynamo.exc import ShortenTraceback, SkipFrame
@ -1530,7 +1532,10 @@ class _InProcessFxCompile(FxCompile):
# Dump provenance artifacts for debugging trace
inductor_provenance_tracking_node_mappings = None
inductor_kernel_stack_trace_str = None
if config.trace.provenance_tracking_level != 0:
if (
config.trace.provenance_tracking_level != 0
or fx_experimental_config.enrich_profiler_metadata
):
inductor_provenance_tracking_node_mappings = json.dumps(
torch._inductor.debug.dump_inductor_provenance_info()
)

View File

@ -1179,6 +1179,8 @@ torchinductor_worker_logpath: str = Config(
default="",
)
fallback_by_default: bool = False
# config specific to codegen/cpp.py
class cpp:

View File

@ -1106,7 +1106,7 @@ def set_kernel_post_grad_provenance_tracing(
Returns a unique int debug handle for each call to this function.
"""
if config.trace.provenance_tracking_level == 0:
if config.trace.provenance_tracking_level == 0 and not config.fallback_by_default:
return None
try:

View File

@ -1628,6 +1628,7 @@ class GraphLowering(torch.fx.Interpreter):
"inductor", "lowerings", lambda: repr(n)
)
)
or (n.op == "call_function" and config.fallback_by_default)
):
debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)(

View File

@ -8079,27 +8079,28 @@ class FallbackKernel(ExternKernelAlloc):
for v, a in zip(args_iter, kernel._schema.arguments)
)
self.codegen_comment(wrapper)
if self.use_runtime_dispatch:
exported_args = self.export_extern_kernel_node()
assert self.python_kernel_name is not None
assert self.op_overload is not None
with wrapper.kernel_debug_handle_context():
self.codegen_comment(wrapper)
if self.use_runtime_dispatch:
exported_args = self.export_extern_kernel_node()
assert self.python_kernel_name is not None
assert self.op_overload is not None
wrapper.generate_fallback_kernel_with_runtime_lookup(
self.get_name(),
self.python_kernel_name,
lambda: [*self.codegen_args(), *self.codegen_kwargs()],
self.op_overload,
exported_args,
# NOTE: [special handling of all_reduce_coalesced_'s return value]
self.outputs if self.outputs else self.mutation_outputs,
)
else:
wrapper.generate_fallback_kernel(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)
self.codegen_memory_tracking(wrapper)
wrapper.generate_fallback_kernel_with_runtime_lookup(
self.get_name(),
self.python_kernel_name,
lambda: [*self.codegen_args(), *self.codegen_kwargs()],
self.op_overload,
exported_args,
# NOTE: [special handling of all_reduce_coalesced_'s return value]
self.outputs if self.outputs else self.mutation_outputs,
)
else:
wrapper.generate_fallback_kernel(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)
self.codegen_memory_tracking(wrapper)
self.codegen_unbacked_symbol_defs(wrapper)

View File

@ -25,6 +25,8 @@ from __future__ import annotations
import dataclasses
import logging
import os
import random
import string
from functools import partial
from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union
@ -70,6 +72,10 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
# Used in profiler post-processing to match events
# belonging to the same compiled run
CALL_COMPILED_PREFIX = "Call CompiledFxGraph"
@dataclasses.dataclass
class OutputCode:
@ -612,9 +618,18 @@ class CompiledFxGraph(OutputCode):
try:
# Checking the profiler directly is faster than nullcontext
if torch.autograd.profiler._is_profiler_enabled:
with record_function(
f"## Call CompiledFxGraph {self._fx_graph_cache_key} ##"
):
# generate a random string to represent this unique run if no cache key
run_key = (
self._fx_graph_cache_key
if self._fx_graph_cache_key
else "".join(random.choices(string.ascii_lowercase, k=51))
)
run_name = f"{CALL_COMPILED_PREFIX} {run_key}"
if self.inductor_provenance_stack_traces_str:
torch.fx.traceback._register_fx_metadata(
run_name, self.inductor_provenance_stack_traces_str
)
with record_function(f"## {run_name} ##"):
return self.current_callable(inputs)
else:
return self.current_callable(inputs)
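The run naming above keeps profiler events matchable even for uncached graphs: if no FX graph cache key is available, a random 51-character key stands in for it. A hedged sketch of just the naming logic (standalone function; CALL_COMPILED_PREFIX and the key length are taken from the diff, the rest is illustrative):

import random
import string

CALL_COMPILED_PREFIX = "Call CompiledFxGraph"

def make_run_name(fx_graph_cache_key=None):
    # fall back to a random key so distinct uncached runs stay distinguishable
    run_key = fx_graph_cache_key or "".join(
        random.choices(string.ascii_lowercase, k=51)
    )
    return f"{CALL_COMPILED_PREFIX} {run_key}"

print(make_run_name("fkey123"))  # "Call CompiledFxGraph fkey123"
print(make_run_name())           # "Call CompiledFxGraph <51 random letters>"

The wrapper then surrounds the compiled call with record_function(f"## {run_name} ##") so the marker shows up in the trace.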

View File

@ -38,7 +38,20 @@ class StaticallyLaunchedCudaKernel:
# pyrefly: ignore [missing-attribute]
self.name = kernel.src.fn.__name__
# pyrefly: ignore [missing-attribute]
self.cubin_raw = kernel.asm.get("cubin", None)
if "hsaco" in kernel.asm:
# pyrefly: ignore [missing-attribute]
self.cubin_raw = kernel.asm["hsaco"]
self.is_rocm = True
# pyrefly: ignore [missing-attribute]
elif "cubin" in kernel.asm:
# pyrefly: ignore [missing-attribute]
self.cubin_raw = kernel.asm["cubin"]
self.is_rocm = False
else:
raise RuntimeError(
"Expected either 'hsaco' (ROCm) or 'cubin' (CUDA) in kernel.asm"
)
# pyrefly: ignore [missing-attribute]
self.cubin_path = kernel._cubin_path
@ -245,12 +258,42 @@ class StaticallyLaunchedCudaKernel:
# thing, it should always match.
# Get rid of constants before passing to cubin launcher
# Add a None if triton wants extra parameters for scratch spaces
arg_tys = self.arg_tys
for has_scratch in [self.has_global_scratch, self.has_profile_scratch]:
if has_scratch:
arg_tys = arg_tys + "O"
args = (*args, None)
if self.is_rocm:
# ROCm/HIP kernel ABI: The Triton HIP backend ALWAYS includes both
# global_scratch and profile_scratch parameters in the kernel signature,
# even when the kernel doesn't use them (i.e., when has_*_scratch is False).
#
# This differs fundamentally from CUDA, where these parameters are only
# present in the signature if the corresponding has_*_scratch flag is True.
#
# The flags indicate whether memory will be allocated/used:
# - has_global_scratch: Whether global scratch workspace is needed
# - has_profile_scratch: Whether profiling instrumentation is enabled
#
# However, regardless of flag values, we MUST always pass both parameters
# to match the HIP kernel ABI. Passing None is safe:
#
# - If scratch is not needed (has_*_scratch=False or scratch_size=0):
# The None becomes nullptr, which the kernel never dereferences
#
# - If scratch is needed (has_*_scratch=True and scratch_size>0):
# The None becomes nullptr initially, but the HIP runtime intercepts
# the kernel launch, allocates the required scratch memory based on
# kernel metadata, and replaces the nullptr with a valid pointer before
# the kernel actually executes
#
# Not passing both parameters causes segmentation faults because the kernel
# expects them at specific positions in the argument array.
arg_tys = arg_tys + "OO"
args = (*args, None, None)
else:
for has_scratch in [self.has_global_scratch, self.has_profile_scratch]:
if has_scratch:
arg_tys = arg_tys + "O"
args = (*args, None)
# pyrefly: ignore [bad-argument-type]
assert len(args) == len(arg_tys)
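The scratch-argument handling above reduces to a small padding rule: on ROCm both scratch slots are always appended, on CUDA only the slots whose has_*_scratch flag is set. A standalone sketch of that rule (hypothetical helper, not the actual StaticallyLaunchedCudaKernel method):

def pad_scratch_args(arg_tys, args, is_rocm, has_global_scratch, has_profile_scratch):
    # ROCm/HIP: the kernel ABI always has both scratch slots, so always pass two Nones
    if is_rocm:
        return arg_tys + "OO", (*args, None, None)
    # CUDA: a scratch slot exists only when the corresponding flag is set
    for has_scratch in (has_global_scratch, has_profile_scratch):
        if has_scratch:
            arg_tys += "O"
            args = (*args, None)
    return arg_tys, args

print(pad_scratch_args("ii", (1, 2), is_rocm=True,
                       has_global_scratch=False, has_profile_scratch=False))
# ('iiOO', (1, 2, None, None))
print(pad_scratch_args("ii", (1, 2), is_rocm=False,
                       has_global_scratch=True, has_profile_scratch=False))
# ('iiO', (1, 2, None))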

View File

@ -1599,9 +1599,8 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
return None
def check_can_launch() -> StaticallyLaunchedCudaKernel:
if triton_meta.get("device_type") != "cuda":
# Only cuda kernels
raise CannotStaticallyLaunchKernel("Non-cuda device")
if triton_meta.get("device_type") not in ("cuda", "hip"):
raise CannotStaticallyLaunchKernel("Non-cuda/ROCm device")
if torch._inductor.config.cpp_wrapper:
# If we're running with cpp wrapper, it doesn't
@ -1627,10 +1626,11 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
"static launch does not support launch attributes"
)
binary_ext = "hsaco" if triton_meta.get("device_type") == "hip" else "cubin"
cubin_location = os.path.join(
triton_cache_dir(triton_meta.get("device", 0)),
triton_hash_to_path_key(kernel.hash),
f"{kernel.src.fn.__name__}.cubin",
f"{kernel.src.fn.__name__}.{binary_ext}",
)
if not os.path.exists(cubin_location):
@ -1662,10 +1662,11 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
When loading from cache on disk, we want to reload cubin
files from their appropriate location on disk.
"""
binary_ext = "hsaco" if torch.version.hip else "cubin"
cubin_location = os.path.join(
triton_cache_dir(self.compile_meta.get("device", 0)),
triton_hash_to_path_key(self.kernel.hash),
f"{self.kernel.name}.cubin",
f"{self.kernel.name}.{binary_ext}",
)
if not os.path.exists(cubin_location):
if self.kernel.cubin_raw is not None:
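The cache lookup above only changes which binary extension is expected per backend. A hypothetical helper showing the resulting path shape (triton_cache_dir and triton_hash_to_path_key are the real helpers used above; this function is illustrative only):

import os

def expected_binary_path(cache_dir: str, hash_key: str, kernel_name: str, is_hip: bool) -> str:
    # ROCm kernels are stored as .hsaco, CUDA kernels as .cubin
    ext = "hsaco" if is_hip else "cubin"
    return os.path.join(cache_dir, hash_key, f"{kernel_name}.{ext}")

print(expected_binary_path("/tmp/triton_cache", "ab12cd", "triton_poi_fused_add_0", is_hip=True))
# /tmp/triton_cache/ab12cd/triton_poi_fused_add_0.hsaco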

View File

@ -1224,43 +1224,3 @@ def _build_table(
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
)
return "".join(result)
# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []
for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")
# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""
events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)
# Sort by node_name for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])
# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)
return "\n".join(lines)

View File

@ -2159,7 +2159,7 @@ PyObject* initModule() {
#ifdef USE_CUDA
torch::cuda::initModule(module);
#endif
#if defined(USE_CUDA) && !defined(USE_ROCM)
#if defined(USE_CUDA)
ASSERT_TRUE(StaticCudaLauncher_init(module));
#endif
#ifdef USE_MPS

View File

@ -2,7 +2,6 @@
// This file should only be compiled if this condition holds, so it should be
// safe.
#if defined(USE_CUDNN) || defined(USE_ROCM)
#include <ATen/detail/CUDAHooksInterface.h>
#include <torch/csrc/utils/pybind.h>
#include <tuple>
@ -33,7 +32,11 @@ version_tuple getRuntimeVersion() {
}
size_t getVersionInt() {
return at::detail::getCUDAHooks().versionRuntimeCuDNN();
#ifndef USE_STATIC_CUDNN
return cudnnGetVersion();
#else
return CUDNN_VERSION;
#endif
}
} // namespace

View File

@ -1,7 +1,4 @@
#if defined(USE_CUDA) && !defined(USE_ROCM)
// We disable this file from being hipified because there are CUDA drivers hip
// has not implemented yet. Also, we're passing in a cubin file directly, so it
// would take more work to support ROCM anyway.
#if defined(USE_CUDA) || defined(USE_ROCM)
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <ATen/Context.h>
@ -16,6 +13,11 @@
#include <torch/csrc/utils/python_numbers.h>
#include <filesystem>
#include <optional>
#if defined(USE_ROCM)
#include <hip/hip_runtime_api.h>
#endif
/**
Implements a static launcher for triton compiled CUDA kernels.
Given a path to a cubin file, a function name, and some metadata,
@ -56,8 +58,14 @@ const at::cuda::NVRTC& nvrtc() {
CUdeviceptr getPointer(PyObject* obj) {
CUdeviceptr data_ptr = 0;
if (THPUtils_checkLong(obj)) {
#if defined(USE_ROCM)
data_ptr = reinterpret_cast<hipDeviceptr_t>(THPUtils_unpackUInt64(obj));
#else
data_ptr = THPUtils_unpackUInt64(obj);
#endif
return data_ptr;
}
if (obj == Py_None) {
@ -73,13 +81,25 @@ CUdeviceptr getPointer(PyObject* obj) {
TORCH_CHECK(
THPUtils_checkLong(ret),
"data_ptr method of Pointer object must return 64-bit int");
#if defined(USE_ROCM)
data_ptr = reinterpret_cast<hipDeviceptr_t>(THPUtils_unpackUInt64(ret));
#else
data_ptr = THPUtils_unpackUInt64(ret);
#endif
if (!data_ptr)
return data_ptr;
CUdeviceptr dev_ptr = 0;
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(hipPointerGetAttribute(
&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr));
#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuPointerGetAttribute(
&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr));
#endif
return dev_ptr;
}
@ -98,6 +118,15 @@ CUfunction loadKernel(
}
CUmodule mod = nullptr;
CUfunction func = nullptr;
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(hipModuleLoad(&mod, filePath.c_str()));
AT_CUDA_DRIVER_CHECK(hipModuleGetFunction(&func, mod, funcName.c_str()));
int shared_optin = 0;
AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute(
&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device));
#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoad(&mod, filePath.c_str()));
AT_CUDA_DRIVER_CHECK(
nvrtc().cuModuleGetFunction(&func, mod, funcName.c_str()));
@ -106,6 +135,9 @@ CUfunction loadKernel(
&shared_optin,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
#endif
// Shared memory logic from triton/third-party/nvidia/backend/driver.c
// If we're using more than 48 KB of shared memory, and we have
// access to more than 48 KB of shared memory on the device,
@ -124,6 +156,21 @@ CUfunction loadKernel(
" Reducing block sizes or `num_stages` may help.");
if (sharedMemBytes > SHARED_MEM_STATIC_MAX &&
shared_optin > SHARED_MEM_STATIC_MAX) {
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(hipFuncSetCacheConfig(func, hipFuncCachePreferShared));
int shared_total = 0, shared_static = 0;
AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute(
&shared_total,
hipDeviceAttributeMaxSharedMemoryPerMultiprocessor,
device));
AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute(
&shared_static, HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
AT_CUDA_DRIVER_CHECK(hipFuncSetAttribute(
func,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
#else
AT_CUDA_DRIVER_CHECK(
nvrtc().cuFuncSetCacheConfig(func, CU_FUNC_CACHE_PREFER_SHARED));
int shared_total = 0, shared_static = 0;
@ -137,6 +184,7 @@ CUfunction loadKernel(
func,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
#endif
}
return func;
}
@ -152,6 +200,27 @@ inline void launchKernel(
cudaStream_t stream) {
// cta_args is always 1 for inductor generated triton kernels,
// so we don't need to figure out grid dimension here
#if defined(USE_ROCM)
int device = 0;
AT_CUDA_DRIVER_CHECK(hipGetDevice(&device));
int warp_size = 0;
AT_CUDA_DRIVER_CHECK(
hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, device));
AT_CUDA_DRIVER_CHECK(hipModuleLaunchKernel(
func,
gridX,
gridY,
gridZ,
warp_size * numWarps, // blockDim.x
1, // blockDim.y
1, // blockDim.z
sharedMemBytes,
stream,
args,
nullptr));
#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
func,
gridX,
@ -164,6 +233,7 @@ inline void launchKernel(
stream,
args,
nullptr));
#endif
}
template <typename FINAL, typename F>
@ -269,11 +339,20 @@ PyObject* load_kernel(PyObject* self, PyObject* args) {
CUdevice device = static_cast<CUdevice>(device_ptr); // NOLINT
CUfunction func = nullptr;
func = loadKernel(filePath, funcName, sharedMemBytes, device);
// Taken from triton/nvidia/backend/driver.c
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(
hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, func));
AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute(
&n_spills, HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func));
#else
AT_CUDA_DRIVER_CHECK(
nvrtc().cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, func));
AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncGetAttribute(
&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func));
#endif
n_spills /= 4;
// Return a tuple of CUFunction, n_regs, n_spills
return Py_BuildValue(
@ -299,7 +378,6 @@ PyObject* launch_kernel_inner(
std::array<uint64_t, MAX_ARGS> argStorage = {};
std::array<void*, MAX_ARGS> kernelArgs = {};
parseKernelArgs(varArgs, argTypes, argStorage.data(), kernelArgs.data());
launchKernel(
func,
gridX,
@ -386,13 +464,25 @@ PyObject* launch_kernel(PyObject* self, PyObject* args) {
Py_RETURN_NONE;
}
CUcontext pctx = nullptr;
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(hipCtxGetCurrent(&pctx));
#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
#endif
if (!pctx) {
// Ensure device context exists
CUdevice device = 0;
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(hipDeviceGet(&device, 0));
AT_CUDA_DRIVER_CHECK(hipDevicePrimaryCtxRetain(&pctx, device));
AT_CUDA_DRIVER_CHECK(hipCtxSetCurrent(pctx));
#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuDeviceGet(&device, 0));
AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxRetain(&pctx, device));
AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxSetCurrent(pctx));
#endif
}
CUfunction func = reinterpret_cast<CUfunction>(func_ptr); // NOLINT
cudaStream_t cudaStream = reinterpret_cast<cudaStream_t>(stream); // NOLINT
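On the Python side, getPointer above accepts a kernel argument that is either a raw integer address, None, or any object whose data_ptr() method returns a 64-bit integer (a torch.Tensor qualifies). A hedged illustration of that duck-typed contract (FakeDevicePointer is a toy class, not part of the launcher):

class FakeDevicePointer:
    """Toy object satisfying the duck-typed contract checked by getPointer."""
    def __init__(self, address: int):
        self._address = address

    def data_ptr(self) -> int:
        # must return a 64-bit integer address
        return self._address

# All three of these argument forms are handled by the launcher's parsing:
examples = [
    0x7F0000000000,                      # raw integer address
    None,                                # null pointer
    FakeDevicePointer(0x7F0000001000),   # object exposing data_ptr()
]
for arg in examples:
    ptr = arg.data_ptr() if hasattr(arg, "data_ptr") else arg
    print(ptr)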

View File

@ -1,5 +1,5 @@
#pragma once
#if defined(USE_CUDA) && !defined(USE_ROCM)
#if defined(USE_CUDA)
#include <torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h>
#include <torch/csrc/python_headers.h>

View File

@ -20,10 +20,7 @@ from torch.distributed.tensor._tp_conv import (
convolution_backward_handler,
convolution_handler,
)
from torch.distributed.tensor._utils import (
ExplicitRedistributionContext,
try_find_mesh_from_args,
)
from torch.distributed.tensor._utils import try_find_mesh_from_args
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
from torch.utils._debug_mode import get_active_debug_mode
from torch.utils._python_dispatch import return_and_correct_aliasing
@ -202,10 +199,6 @@ class OpDispatcher:
if participating:
# computation that happens in the current rank of the mesh, normal case
if output_sharding.needs_redistribute:
if ExplicitRedistributionContext.is_active():
raise RuntimeError(
f"Implicit redistribution occurred while ExplicitRedistributionContext was active for {op_info.schema}"
)
# If sharding propagation decision needs redistribute, perform redistribute
# on args first, which could potentially modify args (i.e. allgather certain arg)
assert output_sharding.redistribute_schema is not None

View File

@ -18,33 +18,6 @@ from torch.distributed.tensor.placement_types import (
from torch.utils._typing_utils import not_none
class ExplicitRedistributionContext:
"""
Within this context manager, DTensor will refuse to perform implicit redistribution,
instead raising an error. Manual calls to ``redistribute()`` are required wherever a redistribution
must occur to avoid erroring. This can be used to ensure that the user is aware of all redistribution.
Note: it is easier to use this mode on just the forward pass of a typical DTensor program, as the backwards pass
may contain implicit redistribution calls that are not visible to the user and difficult to replace with manual
calls. Redistribution during backward can be made explicit by writing `autograd.Function`s that are no-op
during forward and perform a manual redistribution during backwards.
"""
_explicit_redistribute_mode = False
@classmethod
def is_active(cls) -> bool:
return cls._explicit_redistribute_mode
def __enter__(self):
self.prev = ExplicitRedistributionContext._explicit_redistribute_mode
ExplicitRedistributionContext._explicit_redistribute_mode = True
return self
def __exit__(self, exc_type, exc_val, exc_tb):
ExplicitRedistributionContext._explicit_redistribute_mode = self.prev
def _explicit_order_placements(
mesh_shape: ShapeType, placements: Sequence[Placement]
) -> Sequence[tuple[int, Placement]]:

View File

@ -2,6 +2,8 @@ import os
import sys
from typing import Optional
from torch.utils._config_module import Config, install_config_module
# [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors.
no_data_dependent_graph_break = (
@ -100,7 +102,11 @@ backed_size_oblivious = False
# Skip dtype check in meta registrations. Only used for systems that do their own dtype checking.
skip_dtype_check_in_meta_registrations = False
from torch.utils._config_module import install_config_module
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
default=False,
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
install_config_module(sys.modules[__name__])
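A hedged usage sketch for the new config (assuming it behaves like other torch config modules after install_config_module; the environment variable name, including its spelling, is copied verbatim from the default above):

import os

# Option 1: set the environment variable before torch is imported
# (the default is read at import time; name copied verbatim from above)
os.environ["TORCH_ENRICH_RPOFILER_STACK_TRACE"] = "1"

# Option 2: flip the config directly at runtime
from torch.fx.experimental import _config as fx_experimental_config
fx_experimental_config.enrich_profiler_metadata = True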

View File

@ -20,6 +20,7 @@ from torch.nn.modules.module import _addindent
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
from ._compatibility import compatibility
from .experimental import _config as fx_experimental_config
from .graph import (
_BoxedCodeGen,
_custom_builtins,
@ -858,14 +859,15 @@ class {module_name}(torch.nn.Module):
called after editing the contained ``graph``, otherwise the generated
code of this ``GraphModule`` will be out of date.
"""
# Do not import anything inside recompile; it might slow down the
# function and cause a perf regression. Import outside of the method instead.
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
from torch._dynamo import config as dynamo_config
python_code = self._graph.python_code(
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
root_module="self",
record_func=fx_experimental_config.enrich_profiler_metadata,
)
self._code = python_code.src
self._lineno_map = python_code._lineno_map
@ -874,7 +876,7 @@ class {module_name}(torch.nn.Module):
cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
if dynamo_config.enrich_profiler_metadata:
if fx_experimental_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
node_metadata: dict[int, dict[str, Any]] = {}
for i, node in enumerate(self._graph.nodes):

View File

@ -43,10 +43,10 @@ should_preserve_node_meta = False
# =============================================================================
# Global in-memory registry for FX metadata
# Maps module_name -> metadata dict containing lineno_map and node_metadata
_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}
_FX_METADATA_REGISTRY: dict[str, str | dict[str, Any]] = {}
def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
def _register_fx_metadata(module_name: str, metadata: str | dict[str, Any]) -> None:
"""
Register FX metadata in the global in-memory registry.
@ -55,7 +55,7 @@ def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
Args:
module_name: The module identifier (content-addressed filename)
metadata: Metadata dict containing lineno_map, node_metadata, and source_code
metadata: Metadata dict containing lineno_map, node_metadata, and source_code. If a str, it is a JSON dump that can be loaded into such a dict with json.loads.
"""
# TODO: add logging to tlparse
_FX_METADATA_REGISTRY[module_name] = metadata
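Since registry values may now be either a metadata dict or a JSON dump of one, consumers have to normalize on read. A self-contained sketch of that convention (_lookup_fx_metadata is a hypothetical helper; only the registry/registration shape mirrors the code above):

import json
from typing import Any, Union

_FX_METADATA_REGISTRY: dict[str, Union[str, dict[str, Any]]] = {}

def _register_fx_metadata(module_name: str, metadata: Union[str, dict[str, Any]]) -> None:
    _FX_METADATA_REGISTRY[module_name] = metadata

def _lookup_fx_metadata(module_name: str) -> dict[str, Any]:
    # values may be a dict or a JSON dump of one, so normalize on read
    metadata = _FX_METADATA_REGISTRY[module_name]
    if isinstance(metadata, str):
        metadata = json.loads(metadata)
    return metadata

_register_fx_metadata("GraphModule_0", {"lineno_map": {}, "node_metadata": {}})
_register_fx_metadata("run_abc", json.dumps({"stack_traces": ["example"]}))
assert _lookup_fx_metadata("run_abc") == {"stack_traces": ["example"]}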

View File

@ -139,22 +139,3 @@ AT_FORALL_COMPLEX_TYPES
toString
<<
toUnderlying
# torch/headeronly/core/Dispatch_v2.h
THO_DISPATCH_V2_TMPL
THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL
THO_DISPATCH_CASE_TMPL
THO_DISPATCH_SWITCH_TMPL
# AT_WRAP, THO_AP_VAR_TMPL, AT_CONCAT, AT_CONCAT_AUX, AT_EXPAND are tested through THO_DISPATCH_V2_TMPL
# scalar_type is tested through THO_DISPATCH_SWITCH_TMPL
AT_FLOAT8_TYPES
AT_INTEGRAL_TYPES
AT_FLOATING_TYPES
AT_BAREBONES_UNSIGNED_TYPES
AT_INTEGRAL_TYPES_V2
AT_COMPLEX_TYPES
AT_QINT_TYPES
AT_ALL_TYPES
AT_ALL_TYPES_AND_COMPLEX
THO_DISPATCH_V2
# THO_EMPTY, THO_DISPATCH_CASE, THO_DISPATCH_SWITCH, THO_PRIVATE_CASE_TYPE_USING_HINT are tested through THO_DISPATCH_V2

View File

@ -1,73 +0,0 @@
#pragma once
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
// THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL is same as
// AT_PRIVATE_CASE_TYPE_USING_HINT but with a custom PRELUDE macro:
#define THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL(PRELUDE, enum_type, HINT, ...) \
case enum_type: { \
PRELUDE(enum_type); \
using HINT [[maybe_unused]] = \
torch::headeronly::impl::ScalarTypeToCPPTypeT<enum_type>; \
return __VA_ARGS__(); \
}
// THO_DISPATCH_CASE_TMPL is same as AT_DISPATCH_CASE but with a
// custom CASE_TYPE_USING_HINT macro:
#define THO_DISPATCH_CASE_TMPL(CASE_TYPE_USING_HINT, enum_type, ...) \
CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
namespace detail {
inline torch::headeronly::ScalarType scalar_type(
torch::headeronly::ScalarType s) {
return s;
}
} // namespace detail
// THO_DISPATCH_SWITCH_TMPL is same as AT_DISPATCH_SWITCH but with
// custom PRELUDE and CHECK_NOT_IMPLEMENTED macros:
#define THO_DISPATCH_SWITCH_TMPL( \
PRELUDE, CHECK_NOT_IMPLEMENTED, TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
constexpr const char* at_dispatch_name = NAME; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
torch::headeronly::ScalarType _st = ::detail::scalar_type(the_type); \
PRELUDE(at_dispatch_name, _st); \
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \
switch (_st) { \
__VA_ARGS__ \
default: \
CHECK_NOT_IMPLEMENTED( \
false, \
'"', \
at_dispatch_name, \
"\" not implemented for '", \
torch::headeronly::toString(_st), \
"'"); \
} \
C10_DIAGNOSTIC_POP() \
}()
// THO_EMPTY is a helper macro that discards its arguments.
#define THO_EMPTY(...)
// THO_PRIVATE_CASE_TYPE_USING_HINT is same as
// AT_PRIVATE_CASE_TYPE_USING_HINT with call to macro
// AT_PRIVATE_CHECK_SELECTIVE_BUILD removed.
#define THO_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL(THO_EMPTY, enum_type, HINT, __VA_ARGS__)
// THO_DISPATCH_SWITCH is same as AT_DISPATCH_SWITCH with call to
// macro RECORD_KERNEL_FUNCTION_DTYPE removed and using
// STD_TORCH_CHECK instead of TORCH_CHECK_NOT_IMPLEMENTED.
#define THO_DISPATCH_SWITCH(TYPE, NAME, ...) \
THO_DISPATCH_SWITCH_TMPL(THO_EMPTY, STD_TORCH_CHECK, TYPE, NAME, __VA_ARGS__)
// THO_DISPATCH_CASE is same as AT_DISPATCH_CASE but using
// THO_PRIVATE_CASE_TYPE_USING_HINT instead of
// AT_PRIVATE_CASE_TYPE_USING_HINT.
#define THO_DISPATCH_CASE(enum_type, ...) \
THO_DISPATCH_CASE_TMPL( \
THO_PRIVATE_CASE_TYPE_USING_HINT, enum_type, __VA_ARGS__)

View File

@ -1,170 +0,0 @@
#pragma once
#include <torch/headeronly/core/Dispatch.h>
#include <torch/headeronly/core/ScalarType.h>
// This file provides THO_DISPATCH_V2_TMPL macro that is a generalized
// version of the original AT_DISPATCH_V2 (see ATen/Dispatch_v2.h for
// documentation): THO_DISPATCH_V2_TMPL extends AT_DISPATCH_V2 with
// extra DISPATCH_SWITCH and DISPATCH_CASE arguments for specifying
// custom implementations of the original AT_DISPATCH_SWITCH and
// AT_DISPATCH_CASE macros. Use the provided macros
// THO_DISPATCH_SWITCH_TMPL and THO_DISPATCH_CASE_TMPL to define the
// custom implementations of the switch and case macros, respectively.
// Public API macros
// THO_DISPATCH_V2_TMPL is same as AT_DISPATCH_V2 but with custom
// DISPATCH_SWITCH and DISPATCH_CASE macro arguments:
#define THO_DISPATCH_V2_TMPL( \
DISPATCH_SWITCH, DISPATCH_CASE, TYPE, NAME, BODY, ...) \
DISPATCH_SWITCH( \
TYPE, \
NAME, \
THO_AP_VAR_TMPL(DISPATCH_CASE, AT_WRAP(BODY), TYPE, __VA_ARGS__))
// THO_DISPATCH_V2 is same as AT_DISPATCH_V2 but using
// THO_DISPATCH_SWITCH and THO_DISPATCH_CASE instead of
// AT_DISPATCH_SWITCH and AT_DISPATCH_CASE, respectively.
#define THO_DISPATCH_V2(TYPE, NAME, BODY, ...) \
THO_DISPATCH_V2_TMPL( \
THO_DISPATCH_SWITCH, THO_DISPATCH_CASE, TYPE, NAME, BODY, __VA_ARGS__)
// Type collection macros
// This macro lets you pass an arbitrary expression that may contain internal
// commas to another macro without having the commas causing the expression
// to be interpreted as being multiple arguments
#define AT_WRAP(...) __VA_ARGS__
#define AT_FLOAT8_TYPES \
torch::headeronly::ScalarType::Float8_e5m2, \
torch::headeronly::ScalarType::Float8_e5m2fnuz, \
torch::headeronly::ScalarType::Float8_e4m3fn, \
torch::headeronly::ScalarType::Float8_e4m3fnuz, \
torch::headeronly::ScalarType::Float8_e8m0fnu
#define AT_INTEGRAL_TYPES \
torch::headeronly::ScalarType::Byte, torch::headeronly::ScalarType::Char, \
torch::headeronly::ScalarType::Int, torch::headeronly::ScalarType::Long, \
torch::headeronly::ScalarType::Short
#define AT_FLOATING_TYPES \
torch::headeronly::ScalarType::Double, torch::headeronly::ScalarType::Float
#define AT_BAREBONES_UNSIGNED_TYPES \
torch::headeronly::ScalarType::UInt16, \
torch::headeronly::ScalarType::UInt32, \
torch::headeronly::ScalarType::UInt64
#define AT_INTEGRAL_TYPES_V2 \
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
#define AT_COMPLEX_TYPES \
torch::headeronly::ScalarType::ComplexDouble, \
torch::headeronly::ScalarType::ComplexFloat
#define AT_QINT_TYPES \
torch::headeronly::ScalarType::QInt8, torch::headeronly::ScalarType::QUInt8, \
torch::headeronly::ScalarType::QInt32
// NB: not *actually* all types
#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
#define AT_ALL_TYPES_AND_COMPLEX \
AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
// Helper macros
// THO_AP_VAR_TMPL is same as AT_AP_VAR but with a custom
// DISPATCH_CASE macro argument:
#define THO_AP_VAR_TMPL(C, N, T, ...) \
AT_EXPAND( \
AT_CONCAT(THO_AP, AT_NUM_ARGS(__VA_ARGS__))(C, AT_WRAP(N), __VA_ARGS__))
#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
#define AT_CONCAT_AUX(a, b) a##b
#define AT_EXPAND(X) X
// Ensure we never have too many scalar types for the expansion here to
// support. To bump this, you must regenerate the macros below.
static_assert(static_cast<int>(torch::headeronly::ScalarType::NumOptions) < 60);
// Python code to regenerate generate code below:
#if 0
num_args = 60
nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
args = ', '.join(f'_{i}' for i in range(1, num_args+1))
print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
for i in range(1, num_args+1):
args = ', '.join(f'_{i}' for i in range(1, i+1))
cases = ' '.join([f'C(_{j}, N)' for j in range(1, i+1)])
print(f'#define THO_AP{i}(C, N, {args}) {cases}')
#endif
// Begin generated code
// clang-format off
#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N
#define THO_AP1(C, N, _1) C(_1, N)
#define THO_AP2(C, N, _1, _2) C(_1, N) C(_2, N)
#define THO_AP3(C, N, _1, _2, _3) C(_1, N) C(_2, N) C(_3, N)
#define THO_AP4(C, N, _1, _2, _3, _4) C(_1, N) C(_2, N) C(_3, N) C(_4, N)
#define THO_AP5(C, N, _1, _2, _3, _4, _5) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N)
#define THO_AP6(C, N, _1, _2, _3, _4, _5, _6) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N)
#define THO_AP7(C, N, _1, _2, _3, _4, _5, _6, _7) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N)
#define THO_AP8(C, N, _1, _2, _3, _4, _5, _6, _7, _8) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N)
#define THO_AP9(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N)
#define THO_AP10(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N)
#define THO_AP11(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N)
#define THO_AP12(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N)
#define THO_AP13(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N)
#define THO_AP14(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N)
#define THO_AP15(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N)
#define THO_AP16(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N)
#define THO_AP17(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N)
#define THO_AP18(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N)
#define THO_AP19(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N)
#define THO_AP20(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N)
#define THO_AP21(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N)
#define THO_AP22(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N)
#define THO_AP23(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N)
#define THO_AP24(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N)
#define THO_AP25(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N)
#define THO_AP26(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N)
#define THO_AP27(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N)
#define THO_AP28(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N)
#define THO_AP29(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N)
#define THO_AP30(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N)
#define THO_AP31(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N)
#define THO_AP32(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N)
#define THO_AP33(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N)
#define THO_AP34(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N)
#define THO_AP35(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N)
#define THO_AP36(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N)
#define THO_AP37(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N)
#define THO_AP38(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N)
#define THO_AP39(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N)
#define THO_AP40(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N)
#define THO_AP41(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N)
#define THO_AP42(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N)
#define THO_AP43(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N)
#define THO_AP44(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N)
#define THO_AP45(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N)
#define THO_AP46(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N)
#define THO_AP47(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N)
#define THO_AP48(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N)
#define THO_AP49(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N)
#define THO_AP50(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N)
#define THO_AP51(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N)
#define THO_AP52(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N)
#define THO_AP53(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N)
#define THO_AP54(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N)
#define THO_AP55(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N)
#define THO_AP56(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N)
#define THO_AP57(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N) C(_57, N)
#define THO_AP58(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N) C(_57, N) C(_58, N)
#define THO_AP59(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N) C(_57, N) C(_58, N) C(_59, N)
#define THO_AP60(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N) C(_57, N) C(_58, N) C(_59, N) C(_60, N)
// End generated code
// clang-format on

View File

@ -2838,9 +2838,6 @@ def batch_norm(
# pyrefly: ignore [bad-argument-type]
_verify_batch_size(input.size())
if eps <= 0.0:
raise ValueError(f"batch_norm eps must be positive, but got {eps}")
return torch.batch_norm(
input,
weight,
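For reference, a standalone sketch of the eps validation this hunk drops from functional batch_norm (illustration only, not part of the patch):

def _check_batch_norm_eps(eps: float) -> None:
    if eps <= 0.0:
        raise ValueError(f"batch_norm eps must be positive, but got {eps}")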

View File

@ -1,9 +1,11 @@
# mypy: allow-untyped-defs
import functools
import json
import operator
import re
from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, Optional, TYPE_CHECKING
from torch.autograd.profiler import profile
@ -402,13 +404,31 @@ def _init_for_cuda_graphs() -> None:
pass
class ContextType(Enum):
"""Types of contexts in the profiler stack."""
FX_GRAPH = "filename"
FX_NODE = "node"
COMPILED_GRAPH = "compiled_graph"
INDUCTOR_NODE = "inductor_node"
def get_parent_context_type(context_type: ContextType) -> Optional[ContextType]:
if context_type == ContextType.FX_NODE:
return ContextType.FX_GRAPH
elif context_type == ContextType.INDUCTOR_NODE:
return ContextType.COMPILED_GRAPH
else:
return None
@dataclass
class TimelineEvent:
"""Represents an event in the profiler timeline."""
timestamp: int
event_type: Literal["start", "end", "regular"]
marker_type: Optional[Literal["filename", "node"]]
marker_type: Optional[ContextType]
identifier: Optional[str | int]
event: dict[str, Any]
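A quick illustration (not part of the diff, assuming the ContextType enum and get_parent_context_type defined above): node-level contexts resolve to their enclosing graph-level contexts, and graph-level contexts have no parent.

assert get_parent_context_type(ContextType.FX_NODE) is ContextType.FX_GRAPH
assert get_parent_context_type(ContextType.INDUCTOR_NODE) is ContextType.COMPILED_GRAPH
assert get_parent_context_type(ContextType.FX_GRAPH) is None
assert get_parent_context_type(ContextType.COMPILED_GRAPH) is None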
@ -417,7 +437,7 @@ class TimelineEvent:
class ContextStackEntry:
"""Represents a context (filename or node) in the stack."""
context_type: Literal["filename", "node"]
context_type: ContextType
identifier: str | int
metadata: Optional[dict]
tid: Optional[int] = None # Thread ID associated with this context
@ -438,6 +458,8 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
Returns:
Dict mapping recorded event names to their aten operations with added stack traces
"""
from torch._inductor.output_code import CALL_COMPILED_PREFIX
from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX
from torch.fx.traceback import _FX_METADATA_REGISTRY
trace_events = traced_data.get("traceEvents", [])
@ -447,7 +469,7 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
def is_fx_marker_event(event):
return (
event.get("cat") == "cpu_op"
event.get("cat") in ("cpu_op", "user_annotation")
and event.get("name", "").startswith("## ")
and event.get("name", "").endswith(" ##")
)
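For reference, a hedged illustration of the widened predicate (the kernel name below is made up): user_annotation events now qualify alongside cpu_op events, but the "## ... ##" framing is still required.

assert is_fx_marker_event({"cat": "cpu_op", "name": "## 3 ##"})
assert is_fx_marker_event({"cat": "user_annotation", "name": "## inductor_kernel:triton_poi_fused_add_0 ##"})
assert not is_fx_marker_event({"cat": "kernel", "name": "## 3 ##"})
assert not is_fx_marker_event({"cat": "cpu_op", "name": "aten::add"})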
@ -469,14 +491,27 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
if is_fx_marker_event(event):
content = event["name"][3:-3]
if content.endswith(".py"):
append_fx_marker_event("filename", content, event)
# Classify the marker content by prefix to decide which context type it opens
if content.startswith(FX_GRAPH_MODULE_FILE_PREFIX) and content.endswith(
".py"
):
# FX graph event
append_fx_marker_event(ContextType.FX_GRAPH, content, event)
elif content.startswith(CALL_COMPILED_PREFIX):
# Inductor compiled graph event
append_fx_marker_event(ContextType.COMPILED_GRAPH, content, event)
elif content.startswith("inductor_kernel:"):
append_fx_marker_event(
ContextType.INDUCTOR_NODE, content[len("inductor_kernel:") :], event
)
else:
# Try to parse as node index for FX graph
# TODO: change the marker format so FX node markers carry an explicit 'fx_node' prefix
try:
node_index = int(content)
append_fx_marker_event(ContextType.FX_NODE, node_index, event)
except ValueError:
pass
append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined]
else:
# Regular event that needs augmentation
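A standalone sketch (not part of the diff) of the branch order above. The two graph prefixes are taken as parameters because their literal values come from torch._inductor.output_code.CALL_COMPILED_PREFIX and torch.fx.graph_module.FX_GRAPH_MODULE_FILE_PREFIX and are not reproduced here.

def classify_marker_content(content, fx_graph_prefix, compiled_prefix):
    # Graph-level prefixes first, then the inductor kernel prefix,
    # then fall back to a bare integer FX node index.
    if content.startswith(fx_graph_prefix) and content.endswith(".py"):
        return ("fx_graph", content)
    if content.startswith(compiled_prefix):
        return ("compiled_graph", content)
    if content.startswith("inductor_kernel:"):
        return ("inductor_node", content[len("inductor_kernel:"):])
    try:
        return ("fx_node", int(content))
    except ValueError:
        return None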
@ -495,23 +530,37 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
case "start":
assert timeline_event.identifier is not None
if timeline_event.marker_type == "filename":
if timeline_event.marker_type in (
ContextType.FX_GRAPH,
ContextType.COMPILED_GRAPH,
):
assert isinstance(timeline_event.identifier, str)
# Push a graph-level context - query the metadata registry on demand
metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
tid = timeline_event.event.get("tid")
# TODO: add a get method to the traceback metadata registry that wraps this lookup in try/except
if isinstance(metadata, str):
metadata = json.loads(metadata)
context_stack.append(
ContextStackEntry(
"filename", timeline_event.identifier, metadata, tid
timeline_event.marker_type,
timeline_event.identifier,
metadata,
tid,
)
)
elif timeline_event.marker_type == "node":
elif timeline_event.marker_type in (
ContextType.FX_NODE,
ContextType.INDUCTOR_NODE,
):
# Find the metadata of the nearest enclosing parent context on the stack
current_file_metadata = None
tid = timeline_event.event.get("tid")
parent_type = get_parent_context_type(timeline_event.marker_type)
for ctx_entry in reversed(context_stack):
if (
ctx_entry.context_type == "filename"
ctx_entry.context_type == parent_type
and ctx_entry.tid == tid
):
current_file_metadata = ctx_entry.metadata
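A minimal sketch (not part of the diff) of the reversed-stack lookup above: the nearest enclosing entry whose context type equals the node's parent type and whose thread id matches supplies the metadata.

def find_parent_metadata(context_stack, parent_type, tid):
    # Walk from the innermost context outwards.
    for ctx_entry in reversed(context_stack):
        if ctx_entry.context_type == parent_type and ctx_entry.tid == tid:
            return ctx_entry.metadata
    return None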
@ -520,14 +569,39 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
if current_file_metadata:
node_metadata = current_file_metadata.get("node_metadata", {})
if timeline_event.identifier in node_metadata:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
"node", timeline_event.identifier, node_meta, tid
if ctx_entry.context_type == ContextType.FX_NODE:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
ContextType.FX_NODE,
timeline_event.identifier,
node_meta,
tid,
)
)
if timeline_event.marker_type == ContextType.INDUCTOR_NODE:
# Look up stack traces for this kernel
# TODO: keep a dictionary that maps each compiled-graph key to its stack-trace dictionary
stack_traces = current_file_metadata.get(
timeline_event.identifier, []
)
if stack_traces:
# Store all stack traces as metadata
node_meta: Optional[dict] = {
"stack_trace": stack_traces,
"name": timeline_event.identifier,
}
context_stack.append(
ContextStackEntry(
ContextType.INDUCTOR_NODE,
timeline_event.identifier,
node_meta,
tid,
)
)
case "end":
# Pop from stack - search backwards to find matching context
@ -551,7 +625,10 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
for ctx_entry in reversed(context_stack):
# Only apply metadata from contexts with matching tid
if ctx_entry.tid == event_tid:
if ctx_entry.context_type == "node" and ctx_entry.metadata:
if (
ctx_entry.context_type == ContextType.FX_NODE
and ctx_entry.metadata
):
current_stack_trace = ctx_entry.metadata.get(
"stack_trace", "No model stack trace available"
)
@ -559,6 +636,19 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
# Do we want to attach only the stack trace of the lowest node, or the stack traces of all nodes,
# when nodes are nested (e.g. in nested graph modules)?
break
elif (
ctx_entry.context_type == ContextType.INDUCTOR_NODE
and ctx_entry.metadata
):
# For inductor nodes, stack_trace is a list of traces
stack_traces_list = ctx_entry.metadata.get(
"stack_trace", []
)
if stack_traces_list:
# Store as a list - each trace gets its own entry
current_stack_trace = stack_traces_list
current_node_name = ctx_entry.metadata.get("name", "")
break
# Augment the event
if current_stack_trace or current_node_name:
@ -567,3 +657,81 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
args["stack_trace"] = current_stack_trace
if current_node_name:
args["node_name"] = current_node_name
import tempfile
import os
# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []
for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")
if isinstance(stack_trace, list):
stack_trace = "\n".join(stack_trace)
# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""
events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)
# Sort by start time for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])
# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)
return "\n".join(lines)
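Illustration only (the event dict is made up): each enriched event collapses to one "event=... node=... stack_trace=..." line, keeping only the last non-empty line of its stack trace.

example_events = [{
    "name": "aten::mm",
    "ts": 100,
    "args": {
        "node_name": "mm_1",
        "stack_trace": 'File "model.py", line 10, in forward\n    return x @ w',
    },
}]
print(_canonicalize_profiler_events(example_events))
# event=aten::mm node=mm_1 stack_trace=return x @ w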
def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.
Args:
prof: A torch.profiler.profile object
Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
trace_file = f.name
try:
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
trace_data = json.load(f)
map_recorded_events_to_aten_ops_with_stack_trace(trace_data)
events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)
actual_traces = _canonicalize_profiler_events(events)
return actual_traces
finally:
if os.path.exists(trace_file):
os.remove(trace_file)
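A usage sketch (illustrative, not from the diff): profile a compiled callable and print the canonicalized, stack-trace-annotated events. Whether any traces show up depends on the FX metadata recorded during compilation.

import torch
from torch.profiler import profile

compiled = torch.compile(lambda x: (x @ x).relu())
compiled(torch.randn(8, 8))  # warm up so compilation happens outside the profiled region

with profile() as prof:
    compiled(torch.randn(8, 8))

print(_enrich_profiler_traces(prof))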

View File

@ -961,38 +961,6 @@ def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad
desc='zero_batch')]
def module_error_inputs_torch_nn_BatchNorm1d_2d_3d(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
if module_info.module_cls == torch.nn.BatchNorm1d:
input_shape = (2, 10)
elif module_info.module_cls == torch.nn.BatchNorm2d:
input_shape = (2, 10, 5, 5)
else:
input_shape = (2, 10, 4, 4, 4)
return [
ErrorModuleInput(
ModuleInput(
constructor_input=FunctionInput(10, eps=-1.0),
forward_input=FunctionInput(make_input(input_shape)),
),
error_on=ModuleErrorEnum.FORWARD_ERROR,
error_type=ValueError,
error_regex="eps must be positive"
),
ErrorModuleInput(
ModuleInput(
constructor_input=FunctionInput(10, eps=0.0),
forward_input=FunctionInput(make_input(input_shape)),
),
error_on=ModuleErrorEnum.FORWARD_ERROR,
error_type=ValueError,
error_regex="eps must be positive"
),
]
def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, training, **kwargs):
N = kwargs['N']
lazy = kwargs.get('lazy', False)
@ -3462,7 +3430,6 @@ module_db: list[ModuleInfo] = [
ModuleInfo(torch.nn.BatchNorm1d,
train_and_eval_differ=True,
module_inputs_func=module_inputs_torch_nn_BatchNorm1d,
module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d,
skips=(
# tracking here rather than in the list in test_aotdispatch.py as eval mode passes
# RuntimeError: tried to get Double out of SymInt
@ -3481,7 +3448,6 @@ module_db: list[ModuleInfo] = [
ModuleInfo(torch.nn.BatchNorm2d,
train_and_eval_differ=True,
module_inputs_func=module_inputs_torch_nn_BatchNorm2d,
module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d,
skips=(
# See https://github.com/pytorch/pytorch/issues/134580
DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format', active_if=operator.itemgetter('training')),
@ -3502,7 +3468,6 @@ module_db: list[ModuleInfo] = [
ModuleInfo(torch.nn.BatchNorm3d,
train_and_eval_differ=True,
module_inputs_func=module_inputs_torch_nn_BatchNorm3d,
module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d,
skips=(
# not supported on MPS backend
DecorateInfo(skipMPS),