mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-13 17:34:42 +08:00

Compare commits: 1 commit on gh/guangye.../annotate_f

| Author | SHA1 | Date |
|---|---|---|
|  | 2056d7fa22 |  |
@@ -143,8 +143,7 @@ init_command = [
     'tools/linter/adapters/pip_init.py',
     '--dry-run={{DRYRUN}}',
     'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
-    'numpy==2.1.0 ; python_version >= "3.12" and python_version <= "3.13"',
-    'numpy==2.3.4 ; python_version >= "3.14"',
+    'numpy==2.1.0 ; python_version >= "3.12"',
     'expecttest==0.3.0',
     'pyrefly==0.36.2',
     'sympy==1.13.3',
@@ -174,12 +174,6 @@ class TORCH_API Context {
   static long versionCuDNN() {
     return detail::getCUDAHooks().versionCuDNN();
   }
-  static long versionRuntimeCuDNN() {
-    return detail::getCUDAHooks().versionRuntimeCuDNN();
-  }
-  static long versionCuDNNFrontend() {
-    return detail::getCUDAHooks().versionCuDNNFrontend();
-  }
   static bool hasCuSOLVER() {
     return detail::getCUDAHooks().hasCuSOLVER();
   }
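Note (not part of the diff): the surviving query is reachable through ATen's global context. A minimal usage sketch, assuming only the public at::globalContext() / Context::versionCuDNN() API:

    #include <ATen/Context.h>
    #include <iostream>

    int main() {
      if (at::globalContext().hasCuDNN()) {
        // Version of cuDNN that ATen was built against.
        std::cout << "cuDNN: " << at::Context::versionCuDNN() << '\n';
      } else {
        std::cout << "cuDNN not available\n";
      }
      return 0;
    }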
@@ -6,7 +6,6 @@
 #include <c10/util/Half.h>
 #include <c10/util/Metaprogramming.h>
 #include <c10/util/complex.h>
-#include <torch/headeronly/core/Dispatch.h>
 
 #ifdef __CUDACC__
 #include <cuda.h> // For CUDA_VERSION
@@ -62,9 +61,12 @@ TORCH_API void record_kernel_function_dtype(std::string name);
     }                          \
   } while (0)
 
-#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
-  THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL(                      \
-      AT_PRIVATE_CHECK_SELECTIVE_BUILD, enum_type, HINT, __VA_ARGS__)
+#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...)                 \
+  case enum_type: {                                                           \
+    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);                              \
+    using HINT [[maybe_unused]] = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
+    return __VA_ARGS__();                                                     \
+  }
 
 #define AT_DISPATCH_CASE(enum_type, ...) \
   AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)

@@ -93,6 +95,14 @@ TORCH_API void record_kernel_function_dtype(std::string name);
     return __VA_ARGS__(); \
   }
 
+namespace detail {
+
+inline at::ScalarType scalar_type(at::ScalarType s) {
+  return s;
+}
+
+} // namespace detail
+
 // The AT_DISPATCH_* family of macros provides the ability to
 // conveniently generate specializations of a kernel over all of the
 // dtypes we care about in PyTorch. We call it "dispatch" because
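The identity overload added above is what lets the dispatch switch accept a bare ScalarType; richer overloads normalize dtype-carrying objects to the same enum. A self-contained sketch of the trick, with hypothetical names standing in for the real ones:

    // scalar_type() funnels whatever the caller passes into one enum.
    enum class ScalarType { Float, Double };

    struct Dtype {            // hypothetical stand-in for a dtype object
      ScalarType st;
    };

    inline ScalarType scalar_type(ScalarType s) { return s; }       // identity
    inline ScalarType scalar_type(const Dtype& d) { return d.st; }  // unwrap

    // A switch macro can now write scalar_type(TYPE) without caring
    // which of the two forms the call site supplied.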
@@ -180,13 +190,27 @@ TORCH_API void record_kernel_function_dtype(std::string name);
 // but we're just being safe (and it doesn't hurt.) Note we must
 // use it to shut up warnings about unused store.
 
-#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
-  THO_DISPATCH_SWITCH_TMPL(                 \
-      RECORD_KERNEL_FUNCTION_DTYPE,         \
-      TORCH_CHECK_NOT_IMPLEMENTED,          \
-      TYPE,                                 \
-      NAME,                                 \
-      __VA_ARGS__)
+#define AT_DISPATCH_SWITCH(TYPE, NAME, ...)                                 \
+  [&] {                                                                     \
+    const auto& the_type = TYPE;                                            \
+    constexpr const char* at_dispatch_name = NAME;                          \
+    /* don't use TYPE again in case it is an expensive or side-effect op */ \
+    at::ScalarType _st = ::detail::scalar_type(the_type);                   \
+    RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st);                    \
+    C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")             \
+    switch (_st) {                                                          \
+      __VA_ARGS__                                                           \
+      default:                                                              \
+        TORCH_CHECK_NOT_IMPLEMENTED(                                        \
+            false,                                                          \
+            '"',                                                            \
+            at_dispatch_name,                                               \
+            "\" not implemented for '",                                     \
+            toString(_st),                                                  \
+            "'");                                                           \
+    }                                                                       \
+    C10_DIAGNOSTIC_POP()                                                    \
+  }()
 
 #define AT_DISPATCH_CASE_FLOATING_TYPES(...)            \
   AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
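For orientation, a typical call site of the macro family assembled above; a sketch using the public AT_DISPATCH_FLOATING_TYPES convenience wrapper, with an illustrative kernel body:

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>

    // Squares x into out for every floating dtype; the lambda is stamped
    // out once per case, and scalar_t is bound by the selected case.
    void square_out(const at::Tensor& x, at::Tensor& out) {
      AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "square_out", [&] {
        const scalar_t* src = x.const_data_ptr<scalar_t>();
        scalar_t* dst = out.mutable_data_ptr<scalar_t>();
        for (int64_t i = 0; i < x.numel(); ++i) {
          dst[i] = src[i] * src[i];
        }
      });
    }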
@@ -1,8 +1,3 @@
 #pragma once
-
-#include <torch/headeronly/core/Dispatch_v2.h>
-
-// Get AT_DISPATCH_SWITCH and AT_DISPATCH_CASE:
 #include <ATen/Dispatch.h>
-
 // This is a new implementation of the AT_DISPATCH macro family from

@@ -79,19 +74,41 @@
 // macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
 // relied on GPT4 to help me get it right.
 
 // Public API macros
 
 // See documentation above
 #define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
-  THO_DISPATCH_V2_TMPL(                       \
-      AT_DISPATCH_SWITCH,                     \
-      AT_DISPATCH_CASE,                       \
-      TYPE,                                   \
-      NAME,                                   \
-      AT_WRAP(BODY),                          \
-      __VA_ARGS__)
+  AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
 
 // This macro lets you pass an arbitrary expression that may contain internal
 // commas to another macro without having the commas causing the expression
 // to be interpreted as being multiple arguments
 #define AT_WRAP(...) __VA_ARGS__
 
+#define AT_FLOAT8_TYPES                                           \
+  c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn,  \
+      c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
+
+#define AT_INTEGRAL_TYPES \
+  c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
+#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
+#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
+#define AT_INTEGRAL_TYPES_V2 \
+  AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
+#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
+#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
+// NB: not *actually* all types
+#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
+#define AT_ALL_TYPES_AND_COMPLEX \
+  AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
+
+// Helper macros
+
+// Unused helper macros, kept for BC:
+#define AT_AP_VAR(N, T, ...) \
+  AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
+#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
+#define AT_CONCAT_AUX(a, b) a##b
+#define AT_EXPAND(X) X
+
 // Ensure we never have too many scalar types for the expansion here to
 // support. To bump this, you must regenerate the macros below.
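The comma problem AT_WRAP guards against can be reproduced in isolation; a minimal sketch with hypothetical macro names:

    #include <utility>

    #define MY_WRAP(...) __VA_ARGS__   // same shape as AT_WRAP
    #define MAKE(BODY) auto value = BODY;

    // MAKE(std::pair<int, int>{1, 2});        // error: the bare comma
    //                                         // splits BODY into two args
    MAKE(MY_WRAP(std::pair<int, int>{1, 2}))   // OK: commas ride inside the
                                               // inner macro's parentheses

    int main() { return value.second - 2; }    // returns 0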
@@ -102,6 +119,12 @@ static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 60);
 
 num_args = 60
 
+nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
+args = ', '.join(f'_{i}' for i in range(1, num_args+1))
+
+print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
+print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
+
 for i in range(1, num_args+1):
     args = ', '.join(f'_{i}' for i in range(1, i+1))
     cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])

@@ -112,6 +135,8 @@ for i in range(1, num_args+1):
 // Begin generated code
 // clang-format off
 
+#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
+#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N
 #define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
 #define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
 #define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
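The generated pair above implements classic variadic argument counting; a self-contained demo shrunk to four slots (the real macro supports 60):

    #include <cstdio>

    #define DEMO_EXPAND(x) x
    #define DEMO_NUM_ARGS(...) \
      DEMO_EXPAND(DEMO_NUM_ARGS_AUX(__VA_ARGS__, 4, 3, 2, 1, 0))
    #define DEMO_NUM_ARGS_AUX(_1, _2, _3, _4, N, ...) N

    int main() {
      // Each real argument shifts the trailing 4,3,2,1,0 one slot to the
      // right, so whatever lands on N is the argument count.
      std::printf("%d %d %d\n",
                  DEMO_NUM_ARGS(a), DEMO_NUM_ARGS(a, b), DEMO_NUM_ARGS(a, b, c));
      return 0;  // prints: 1 2 3
    }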
@@ -21,7 +21,6 @@
 
 #if AT_CUDNN_ENABLED()
 #include <ATen/cudnn/cudnn-wrapper.h>
-#include <cudnn_frontend.h>
 #endif
 
 #if AT_MAGMA_ENABLED()

@@ -352,26 +351,6 @@ long CUDAHooks::versionCuDNN() const {
 #endif
 }
 
-long CUDAHooks::versionRuntimeCuDNN() const {
-#if AT_CUDNN_ENABLED()
-#ifndef USE_STATIC_CUDNN
-  return cudnnGetVersion();
-#else
-  return CUDNN_VERSION;
-#endif
-#else
-  TORCH_CHECK(false, "Cannot query CuDNN version if ATen_cuda is not built with CuDNN");
-#endif
-}
-
-long CUDAHooks::versionCuDNNFrontend() const {
-#if AT_CUDNN_ENABLED()
-  return CUDNN_FRONTEND_VERSION;
-#else
-  TORCH_CHECK(false, "Cannot query CuDNN Frontend version if ATen_cuda is not built with CuDNN");
-#endif
-}
-
 long CUDAHooks::versionMIOpen() const {
 #if AT_ROCM_ENABLED()
   return MIOPEN_VERSION_MAJOR * 10000 +
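Background on what the deleted hook distinguished: CUDNN_VERSION is fixed when the binary is compiled, while cudnnGetVersion() reflects the library actually loaded at run time; the two diverge when cuDNN is swapped under a prebuilt wheel. A minimal sketch, assuming only the standard cuDNN API:

    #include <cudnn.h>
    #include <cstdio>

    int main() {
      std::printf("compiled against cuDNN %d\n", CUDNN_VERSION);      // header macro
      std::printf("running against cuDNN %zu\n", cudnnGetVersion());  // loaded library
      return 0;
    }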
@@ -49,8 +49,6 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool hasCUDART() const override;
   long versionCUDART() const override;
   long versionCuDNN() const override;
-  long versionRuntimeCuDNN() const override;
-  long versionCuDNNFrontend() const override;
   long versionMIOpen() const override;
   std::string showConfig() const override;
   double batchnormMinEpsilonCuDNN() const override;

@@ -174,14 +174,6 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
     TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
   }
 
-  virtual long versionRuntimeCuDNN() const {
-    TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
-  }
-
-  virtual long versionCuDNNFrontend() const {
-    TORCH_CHECK(false, "Cannot query cuDNN Frontend version without ATen_cuda library. ", CUDA_HELP);
-  }
-
   virtual long versionMIOpen() const {
     TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
   }

@@ -409,7 +409,7 @@ struct ConvParams {
     if (!detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda() || !cudnn_enabled) {
       return false;
     }
-    static long cudnn_version = detail::getCUDAHooks().versionRuntimeCuDNN();
+    static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
     // broken on cuDNN 9.8 - 9.14
     if (cudnn_version >= 90800 && cudnn_version < 91500) {
       if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&

@@ -453,7 +453,7 @@ struct ConvParams {
     }
     // native kernel doesn't support 64-bit non-splittable case
     if (!(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
-      static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionRuntimeCuDNN() : -1;
+      static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
       // TODO(eqy): remove this once cuDNN fixes 64-bit depthwise support, first broken in 9.11x
       if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous) {
         if (cudnn_version < 0 || cudnn_version > 91000) {
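Both hunks keep the static qualifier on the cached version: a function-local static initializes once (thread-safely since C++11), so the hook query runs a single time per process rather than on every convolution call. A sketch with a hypothetical query function:

    #include <cstdio>

    static long query_version() {
      std::puts("queried");     // printed exactly once
      return 90800;             // stand-in value
    }

    long cached_version() {
      static long v = query_version();  // first call initializes, rest reuse
      return v;
    }

    int main() {
      cached_version();
      cached_version();         // no second "queried" line
      return 0;
    }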
@@ -478,7 +478,7 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
   const auto s_k = params.key.sym_size(2);
   const auto d_qk = params.query.sym_size(3);
   const auto d_v = params.value.sym_size(3);
-  long cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN();
+  long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
   if (cudnn_version < 8903) {
     if (debug) {
       TORCH_WARN("SDPA fprop requires cudnn 8.9.3 or higher");

@@ -709,7 +709,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
     return false;
 #endif
 #if defined(CUDNN_VERSION)
-  static auto cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN();
+  static auto cudnn_version = cudnnGetVersion();
   if (params.dropout > 0.0 && cudnn_version > 91100 && cudnn_version < 91400) {
     if (debug) {
       TORCH_WARN(CUDNN_VERSION, " cuDNN version does not support droppout in SDPA (9.11 - 9.13).");

@@ -10,8 +10,6 @@ set(AOTI_ABI_CHECK_TEST_SRCS
   ${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
-  ${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch.cpp
-  ${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch_v2.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp
@@ -1,82 +0,0 @@
-#include <gtest/gtest.h>
-
-#include <torch/headeronly/core/Dispatch.h>
-#include <torch/headeronly/core/Dispatch_v2.h>
-
-// MY_PRIVATE_CHECK_SELECTIVE_BUILD is a prelude to case block. For
-// testing, we do nothing:
-#define MY_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) /* empty */
-
-#define MY_PRIVATE_CASE_TYPE_USING_HINT(...) \
-  THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL(     \
-      MY_PRIVATE_CHECK_SELECTIVE_BUILD, __VA_ARGS__)
-
-#define MY_DISPATCH_CASE(...) \
-  THO_DISPATCH_CASE_TMPL(MY_PRIVATE_CASE_TYPE_USING_HINT, __VA_ARGS__)
-
-// MY_RECORD_KERNEL_FUNCTION_DTYPE is a prelude to switch
-// statement. For testing, we just avoid unused variable warning:
-#define MY_RECORD_KERNEL_FUNCTION_DTYPE(DISPATCHNAME, ENUMTYPE) \
-  (void)DISPATCHNAME
-
-// MY_CHECK_NOT_IMPLEMENTED is called in switch default block. For
-// testing, we count case mismatches:
-#define MY_CHECK_NOT_IMPLEMENTED(...) default_count++
-
-#define MY_DISPATCH_SWITCH(...) \
-  THO_DISPATCH_SWITCH_TMPL(     \
-      MY_RECORD_KERNEL_FUNCTION_DTYPE, MY_CHECK_NOT_IMPLEMENTED, __VA_ARGS__)
-
-// MY_CASE_FUNCTION is called in a case block. For testing, we count
-// case matches and ensure that scalar_t/index_t type is defined:
-#define MY_CASE_FUNCTION \
-  [&] {                  \
-    count++;             \
-    scalar_t tmp;        \
-    (void)tmp;           \
-  }
-#define MY_INDEX_CASE_FUNCTION \
-  [&] {                        \
-    count++;                   \
-    index_t tmp;               \
-    (void)tmp;                 \
-  }
-
-#define DEFINE_ITEM(TYPE, SCALARTYPE) ScalarType::SCALARTYPE,
-
-#define MY_DISPATCH_V2(TYPE, NAME, BODY, ...) \
-  THO_DISPATCH_V2_TMPL(                       \
-      MY_DISPATCH_SWITCH,                     \
-      MY_DISPATCH_CASE,                       \
-      TYPE,                                   \
-      NAME,                                   \
-      AT_WRAP(BODY),                          \
-      __VA_ARGS__)
-
-#define TEST_DISPATCH_V2(NAME, EXPECTEDCOUNT, ...)                             \
-  TEST(TestDispatchV2, NAME) {                                                 \
-    using torch::headeronly::ScalarType;                                       \
-    using torch::headeronly::impl::ScalarTypeToCPPTypeT;                       \
-    int8_t total_count = 0;                                                    \
-    int8_t count = 0;                                                          \
-    int8_t default_count = 0;                                                  \
-    for (ScalarType t :                                                        \
-         {AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ITEM)}) {       \
-      total_count++;                                                           \
-      MY_DISPATCH_V2(t, "test_my_dispatch_v2", MY_CASE_FUNCTION, __VA_ARGS__); \
-    }                                                                          \
-    EXPECT_EQ(count, EXPECTEDCOUNT);                                           \
-    EXPECT_EQ(default_count + count, total_count);                             \
-  }
-
-TEST_DISPATCH_V2(AT_FLOAT8_TYPES_, 5, AT_FLOAT8_TYPES);
-TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_, 5, AT_INTEGRAL_TYPES);
-TEST_DISPATCH_V2(AT_FLOATING_TYPES_, 2, AT_FLOATING_TYPES);
-TEST_DISPATCH_V2(AT_BAREBONES_UNSIGNED_TYPES_, 3, AT_BAREBONES_UNSIGNED_TYPES);
-TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_V2_, 8, AT_INTEGRAL_TYPES_V2);
-TEST_DISPATCH_V2(AT_COMPLEX_TYPES_, 2, AT_COMPLEX_TYPES);
-TEST_DISPATCH_V2(AT_QINT_TYPES_, 3, AT_QINT_TYPES);
-TEST_DISPATCH_V2(AT_ALL_TYPES_, 7, AT_ALL_TYPES);
-TEST_DISPATCH_V2(AT_ALL_TYPES_AND_COMPLEX_, 9, AT_ALL_TYPES_AND_COMPLEX);
-
-#undef DEFINE_ITEM
@@ -1,45 +0,0 @@
-#include <gtest/gtest.h>
-#include <torch/headeronly/core/Dispatch_v2.h>
-#include <torch/headeronly/util/Exception.h>
-
-#define DEFINE_ITEM(TYPE, SCALARTYPE) ScalarType::SCALARTYPE,
-
-#define TEST_DISPATCH_V2(NAME, EXPECTEDCOUNT, ...)                       \
-  TEST(TestThoDispatchV2, NAME) {                                        \
-    using torch::headeronly::ScalarType;                                 \
-    using torch::headeronly::impl::ScalarTypeToCPPTypeT;                 \
-    int8_t total_count = 0;                                              \
-    int8_t count = 0;                                                    \
-    int8_t default_count = 0;                                            \
-    for (ScalarType t :                                                  \
-         {AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ITEM)}) { \
-      total_count++;                                                     \
-      try {                                                              \
-        THO_DISPATCH_V2(                                                 \
-            t,                                                           \
-            "test_tho_dispatch_v2",                                      \
-            [&] {                                                        \
-              count++;                                                   \
-              scalar_t tmp;                                              \
-              (void)tmp;                                                 \
-            },                                                           \
-            __VA_ARGS__);                                                \
-      } catch (...) {                                                    \
-        default_count++; /* counts mismatches */                         \
-      }                                                                  \
-    }                                                                    \
-    EXPECT_EQ(count, EXPECTEDCOUNT);                                     \
-    EXPECT_EQ(default_count + count, total_count);                       \
-  }
-
-TEST_DISPATCH_V2(AT_FLOAT8_TYPES_, 5, AT_FLOAT8_TYPES);
-TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_, 5, AT_INTEGRAL_TYPES);
-TEST_DISPATCH_V2(AT_FLOATING_TYPES_, 2, AT_FLOATING_TYPES);
-TEST_DISPATCH_V2(AT_BAREBONES_UNSIGNED_TYPES_, 3, AT_BAREBONES_UNSIGNED_TYPES);
-TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_V2_, 8, AT_INTEGRAL_TYPES_V2);
-TEST_DISPATCH_V2(AT_COMPLEX_TYPES_, 2, AT_COMPLEX_TYPES);
-TEST_DISPATCH_V2(AT_QINT_TYPES_, 3, AT_QINT_TYPES);
-TEST_DISPATCH_V2(AT_ALL_TYPES_, 7, AT_ALL_TYPES);
-TEST_DISPATCH_V2(AT_ALL_TYPES_AND_COMPLEX_, 9, AT_ALL_TYPES_AND_COMPLEX);
-
-#undef DEFINE_ITEM
@@ -1,18 +1,11 @@
 # Owner(s): ["oncall: distributed"]
 
 import itertools
-from contextlib import nullcontext
-from typing import Any
 
 import torch
 import torch.distributed as dist
-from torch.distributed._local_tensor import (
-    local_tensor_mode,
-    LocalTensor,
-    LocalTensorMode,
-)
 from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
+from torch.distributed.tensor import distribute_tensor, DTensor
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.tensor._utils import (
     _compute_local_shape_and_global_offset,

@@ -21,7 +14,6 @@ from torch.distributed.tensor._utils import (
     compute_global_tensor_shape,
     compute_local_shape_and_global_offset,
     compute_local_tensor_info,
-    ExplicitRedistributionContext,
 )
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.placement_types import (
@@ -859,93 +851,5 @@ class Test2DStridedLocalShard(DTensorTestBase):
             self.assertEqual(global_tensor, dtensor_2d.full_tensor())
 
 
-class LocalTensorTestBase(TestCase):
-    def assertEqual(self, lhs, rhs, **kwargs):
-        mode = local_tensor_mode()
-        with nullcontext() if mode is None else mode.disable():
-            if isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor):
-                assert isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor)
-                super().assertEqual(lhs._ranks, rhs._ranks)
-                for r in lhs._ranks:
-                    super().assertEqual(
-                        lhs._local_tensors[r],
-                        rhs._local_tensors[r],
-                        lambda m: f"rank {r}: {m}",
-                    )
-            elif isinstance(lhs, LocalTensor) or isinstance(rhs, LocalTensor):
-                lhs, rhs = (lhs, rhs) if isinstance(lhs, LocalTensor) else (rhs, lhs)
-                for r in lhs._ranks:
-                    super().assertEqual(
-                        lhs._local_tensors[r], rhs, lambda m: f"rank {r}: {m}"
-                    )
-            else:
-                return super().assertEqual(lhs, rhs, **kwargs)
-
-    @property
-    def world_size(self):
-        raise NotImplementedError("override world-size in your subclass")
-
-    def build_device_mesh(self) -> DeviceMesh:
-        return init_device_mesh("cpu", (self.world_size,))
-
-    def setUp(self):
-        super().setUp()
-        torch.distributed.init_process_group(
-            # TODO: test other ranks too
-            "fake",
-            rank=0,
-            world_size=self.world_size,
-        )
-
-    def tearDown(self):
-        super().tearDown()
-        try:
-            dist.destroy_process_group()
-        except AssertionError:
-            pass
-
-
-class TestExplicitRedistribute(LocalTensorTestBase):
-    @property
-    def world_size(self):
-        return 4
-
-    def test_explicit_matmul(self):
-        with LocalTensorMode(self.world_size):
-            device_mesh = self.build_device_mesh()
-            dim = 128
-            x = torch.randn(8, dim, requires_grad=True)
-            A = torch.randn(dim, dim, requires_grad=True)
-
-            # Prepare DTensors
-            dx = distribute_tensor(x, device_mesh, [Shard(0)])
-            dA = distribute_tensor(A, device_mesh, [Shard(0)])
-
-            # implicit redistribute works as usual by default
-            with CommDebugMode() as comm_mode:
-                torch.matmul(dx, dA)
-            self.assertEqual(comm_mode.get_total_counts(), 1)
-
-            # explicit redistribute works too
-            with ExplicitRedistributionContext():
-                with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
-                    torch.matmul(dx, dA)
-
-            # explicit redistribute allows manual redistribute
-            with ExplicitRedistributionContext():
-                dA_repl = dA.redistribute(device_mesh, [Replicate()])
-                torch.matmul(dx, dA_repl)
-
-            dx = distribute_tensor(x, device_mesh, [Shard(0)])
-            dA = distribute_tensor(A, device_mesh, [Replicate()])
-            with ExplicitRedistributionContext():
-                dY = torch.matmul(dx, dA_repl)
-                loss = dY.sum()
-
-            # we now see the error during backwards
-            with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
-                loss.backward()
-
-
 if __name__ == "__main__":
     run_tests()
@@ -475,14 +475,17 @@ class TestFxGraphCache(TestCase):
 
         if device == GPU_TYPE and not HAS_GPU:
             raise unittest.SkipTest(f"requires {GPU_TYPE}")
-        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
+        if (
+            device == "cuda"
+            and torch.version.hip is None
+            and dtype == torch.bfloat16
+            and not SM80OrLater
+        ):
             raise unittest.SkipTest("requires SM80 or later")
         if use_static_cuda_launcher and not (device == "cuda" and bundle_triton):
             raise unittest.SkipTest(
                 "Static cuda launcher requires cuda and triton bundling"
             )
-        if use_static_cuda_launcher and TEST_WITH_ROCM:
-            raise unittest.SkipTest("Static cuda launcher doesn't work with ROCM")
 
         def fn(x, y):
             return (x * 2, y @ y)
@@ -27,9 +27,9 @@ from torch._inductor.fx_passes.post_grad import post_grad_passes
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import run_and_get_code, run_and_get_cpp_code
 from torch._inductor.virtualized import V
-from torch.testing._internal.common_utils import IS_MACOS
+from torch.testing._internal.common_utils import IS_MACOS, skipIfRocm
 from torch.testing._internal.triton_utils import requires_cuda_and_triton
 
 from torch.profiler import profile, ProfilerActivity
 
 try:
     from .test_aot_inductor_utils import AOTIRunnerUtil
@@ -941,5 +941,63 @@ copy_tests(
 )
 
 
+from torch.profiler._utils import _enrich_profiler_traces
+
+
+class TestProfilerStackTraceAugmentation(TestCase):
+    """
+    Test that profiler events are correctly augmented with stack traces
+    from both FX metadata and inductor kernel stack traces.
+    """
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @skipIfRocm
+    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
+    @config.patch("fallback_by_default", True)  # TODO update the config patch to inductor lite mode
+    @torch.compiler.config.patch("force_disable_caches", True)
+    def test_profiler_inductor_stack_trace_augmentation(self):
+        """
+        Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
+        augments profiler events with stack traces from inductor kernel metadata.
+        """
+
+        # Test model similar to test.py
+        class TestModel(torch.nn.Module):
+            def forward(self, c):
+                d = c * 2
+                d = d + 1
+                return d
+
+        device = "cuda"
+        model = TestModel().to(device)
+        c = torch.randn((64, 32), device=device)
+
+        # Force disable caches to ensure fresh compilation
+        torch.compiler.config.force_disable_caches = True
+
+        # Compile the model
+        compiled_model = torch.compile(model, fullgraph=True)
+
+        # Warmup
+        for _ in range(3):
+            _ = compiled_model(c)
+
+        # Profile with the compiled model
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+        ) as prof:
+            compiled_model(c)
+
+        actual_traces = _enrich_profiler_traces(prof)
+
+        self.assertExpectedInline(actual_traces, """\
+event=aten::mul node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
+event=cudaLaunchKernel node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
+event=aten::add node=torch.ops.aten.add.Tensor:2 stack_trace=d = d + 1
+event=cudaLaunchKernel node=torch.ops.aten.add.Tensor:2 stack_trace=d = d + 1""")
+
+    # TODO: add test that when enrich is not turned on there is no recordfast generated.
+
+
 if __name__ == "__main__":
     run_tests()
@@ -12,7 +12,6 @@ from torch._inductor.runtime.static_cuda_launcher import StaticallyLaunchedCudaKernel
 from torch._inductor.runtime.triton_compat import CompiledKernel, tl, triton
 from torch._inductor.runtime.triton_helpers import libdevice
 from torch._inductor.test_case import TestCase
-from torch.testing._internal.common_utils import skipIfRocm
 from torch.testing._internal.triton_utils import requires_cuda_and_triton
 
 

@@ -39,8 +38,9 @@ class TestStaticCudaLauncher(TestCase):
         # Just used by tests for now.
         # TODO: derive cubin_path from wherever triton stores the cubin file on disk.
         tmp_file = tempfile.NamedTemporaryFile(mode="wb", delete=False)
+        binary_key = "hsaco" if torch.version.hip else "cubin"
        with tmp_file:
-            tmp_file.write(kernel.asm["cubin"])
+            tmp_file.write(kernel.asm[binary_key])
         self.tmp_files.append(tmp_file)
         return tmp_file.name
 

@@ -64,7 +64,6 @@ class TestStaticCudaLauncher(TestCase):
         result.load_kernel(device_interface.current_device())
         return result
 
-    @skipIfRocm
     def test_basic(self):
         @triton.jit
         def simple_kernel(arg0, arg1):

@@ -91,7 +90,6 @@ class TestStaticCudaLauncher(TestCase):
     # 2. triton relies on inspect.get_source to get the type annotations
     # so I can't even use exec() to generate the test cases.
     # So we'll just make a few kernels by hand
-    @skipIfRocm
     def test_unsigned_integers(self):
         @triton.jit
         def unsigned_integers(

@@ -115,7 +113,6 @@ class TestStaticCudaLauncher(TestCase):
         launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50)
         self.assertEqual(new_arg0, arg0)
 
-    @skipIfRocm
     def test_signed_integers(self):
         @triton.jit
         def signed_integers(

@@ -139,7 +136,6 @@ class TestStaticCudaLauncher(TestCase):
         launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50)
         self.assertEqual(new_arg0, arg0)
 
-    @skipIfRocm
     def test_basic_1arg(self):
         @triton.jit
         def simple_kernel_1_arg(arg0):

@@ -164,7 +160,6 @@ class TestStaticCudaLauncher(TestCase):
         )
         self.assertEqual(new_arg0, arg0)
 
-    @skipIfRocm
     def test_constexpr(self):
         # Constexprs are compiled directly into the cubin file,
         # so we never need to pass it to StaticCudaLauncher.

@@ -193,7 +188,6 @@ class TestStaticCudaLauncher(TestCase):
         )
         self.assertEqual(new_arg0, arg0)
 
-    @skipIfRocm
     def test_implied_constant(self):
         """xnumel is unused in this kernel, but isn't explicitly marked as a constexpr"""
 

@@ -246,7 +240,6 @@ class TestStaticCudaLauncher(TestCase):
         launcher.run(1, 1, 1, stream, arg0, arg2, 128)
         self.assertEqual(arg1, arg2)
 
-    @skipIfRocm
     def test_kernel_no_args(self):
         # Just an easy way to test incompatible number of arguments
         @triton.jit

@@ -259,7 +252,6 @@ class TestStaticCudaLauncher(TestCase):
         stream = device_interface.get_raw_stream(device_interface.current_device())
         launcher.run(1, 1, 1, stream)
 
-    @skipIfRocm
     def test_high_shared_mem(self):
         @triton.jit
         def simple_kernel(arg0, arg1):

@@ -283,7 +275,6 @@ class TestStaticCudaLauncher(TestCase):
         launcher.run(1, 1, 1, stream, new_arg0, arg1)
         self.assertEqual(new_arg0, arg0)
 
-    @skipIfRocm
     def test_too_high_shared_mem(self):
         @triton.jit
         def simple_kernel(arg0, arg1):

@@ -303,7 +294,6 @@ class TestStaticCudaLauncher(TestCase):
             lambda: self._make_launcher(compiled_kernel),
         )
 
-    @skipIfRocm
     def test_kernel_empty_tensor(self):
         # Triton kernel generated by torch.compile of the following:
         # @torch.compile()

@@ -364,7 +354,6 @@ class TestStaticCudaLauncher(TestCase):
         launcher.run(1, 1, 1, stream, arg1, arg2, buf1, arg0, xnumel)
         self.assertEqual(buf0, buf1)
 
-    @skipIfRocm
     def test_kernel_many_args(self):
         N = 200
         # Make 200 arguments

@@ -405,7 +394,6 @@ class TestStaticTritonCompileResult(TestCase):
     Tests static cuda launcher with torch.compile()
     """
 
-    @skipIfRocm
     def test_basic_compile(self):
         @torch.compile
         def foo(x, y):

@@ -415,7 +403,6 @@ class TestStaticTritonCompileResult(TestCase):
         y = torch.randn(10, device="cuda")
         self.assertEqual(foo(x, y), x + y)
 
-    @skipIfRocm
     # The error gets raised on a worker, so we want to not use a separate process
     @torch._inductor.config.patch("compile_threads", 1)
     def test_incompatible_code(self):

@@ -438,7 +425,6 @@ class TestStaticTritonCompileResult(TestCase):
             lambda: foo(x),
         )
 
-    @skipIfRocm
     # The error gets raised on a worker, so we want to not use a separate process
     @torch._inductor.config.patch(
         {"compile_threads": 1, "static_launch_user_defined_triton_kernels": True}

@@ -460,7 +446,6 @@ class TestStaticTritonCompileResult(TestCase):
         x2 = x.clone().detach_()
         self.assertEqual(foo(x), x2 + 5)
 
-    @skipIfRocm
     def test_empty_tensor(self):
         @torch.compile()
         def foo(x, y):

@@ -472,7 +457,6 @@ class TestStaticTritonCompileResult(TestCase):
         result = foo(x, y)
         self.assertEqual(result, torch.cat(((x * 4), y + 10)))
 
-    @skipIfRocm
     def test_any(self):
         def fn(x):
             return (

@@ -492,7 +476,6 @@ class TestStaticTritonCompileResult(TestCase):
         compiled_result = compiled_fn(arg)
         self.assertEqual(eager_result, compiled_result)
 
-    @skipIfRocm
     def test_disable_static_cuda_launcher(self):
         @torch.compile
         def fn(x, y):
@@ -76,11 +76,8 @@ from torch.testing._internal.common_utils import (
 )
 from torch.testing._internal.jit_utils import JitTestCase
 
-import json
-import tempfile
 from torch.profiler import profile, ProfilerActivity
-from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
-from torch.autograd.profiler_util import _canonicalize_profiler_events
+from torch.profiler._utils import _enrich_profiler_traces
 
 try:
     from torchvision import models as torchvision_models
@@ -208,36 +205,6 @@ def side_effect_func(x: torch.Tensor):
     print(x)
 
 
-def _enrich_profiler_traces(prof):
-    """
-    Helper function to extract and augment profiler events with stack traces.
-
-    Args:
-        prof: A torch.profiler.profile object
-
-    Returns:
-        A string representing enriched events
-    """
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
-        trace_file = f.name
-        prof.export_chrome_trace(trace_file)
-
-    with open(trace_file) as f:
-        trace_data = json.load(f)
-
-    map_recorded_events_to_aten_ops_with_stack_trace(
-        trace_data
-    )
-
-    events = []
-    for event in trace_data["traceEvents"]:
-        if "args" in event and "stack_trace" in event["args"]:
-            events.append(event)
-
-    actual_traces = _canonicalize_profiler_events(events)
-    return actual_traces
-
-
 class TestFX(JitTestCase):
     def setUp(self):
         super().setUp()
@@ -4251,7 +4218,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     @skipIfRocm
-    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
+    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
     def test_profiler_stack_trace_augmentation(self):
         """
         Test that map_recorded_events_to_aten_ops_with_stack_trace correctly

@@ -4307,7 +4274,7 @@ event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     @skipIfRocm
-    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
+    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
     def test_profiler_multiple_modules(self):
         """
         Test that multiple compiled modules under the same profiler session

@@ -4351,7 +4318,7 @@ event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     @skipIfRocm
-    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
+    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
     def test_profiler_nested_graph_modules(self):
         """
         Test that nested graph modules (e.g., graph modules calling subgraphs)
@@ -1762,13 +1762,13 @@ class TestMPS(TestCaseMPS):
                     continue
                 # Running stats must be tracked in eval mode
                 if (track_running_stats):
-                    helper(shape, eps=1e-5, momentum=1, channels_last=channels_last,
+                    helper(shape, eps=0, momentum=1, channels_last=channels_last,
                            track_running_stats=track_running_stats, test_module=test_module)
                     helper(shape, channels_last=channels_last,
                            track_running_stats=track_running_stats, test_module=test_module)
                     helper(shape, eps=1e-05, momentum=0.1, wts=False, training=False, channels_last=channels_last,
                            track_running_stats=track_running_stats, test_module=test_module)
-                    helper(shape, eps=1e-5, momentum=1.0, wts=False, training=False, channels_last=channels_last,
+                    helper(shape, eps=0, momentum=1.0, wts=False, training=False, channels_last=channels_last,
                            track_running_stats=track_running_stats, test_module=test_module)
                     helper(shape, eps=1, momentum=1, wts=True, training=False, channels_last=channels_last,
                            track_running_stats=track_running_stats, test_module=test_module)

@@ -1776,7 +1776,7 @@ class TestMPS(TestCaseMPS):
                            track_running_stats=track_running_stats, test_module=test_module)
                     helper(shape, eps=1e-05, momentum=0.1, wts=False, training=True, channels_last=channels_last,
                            track_running_stats=track_running_stats, test_module=test_module)
-                    helper(shape, eps=1e-5, momentum=1.0, wts=False, training=True, channels_last=channels_last,
+                    helper(shape, eps=0, momentum=1.0, wts=False, training=True, channels_last=channels_last,
                            track_running_stats=track_running_stats, test_module=test_module)
                     helper(shape, eps=1, momentum=1, wts=True, training=True, channels_last=channels_last,
                            track_running_stats=track_running_stats, test_module=test_module)
@@ -739,11 +739,8 @@ enable_aot_compile = False
 # HACK: this is for testing custom ops profiling only
 _custom_ops_profile: Optional[Any] = None
 
-# Experimental: If True, graph module will register fx metadata during recompile()
-enrich_profiler_metadata: bool = Config(  # type: ignore[var-annotated]
-    default=False,
-    env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
-)
+# Deprecated! Please use the config in torch/fx/experimental/_config instead.
+enrich_profiler_metadata: bool = False
 
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
@@ -3228,7 +3228,7 @@ class InstructionTranslatorBase(
 
     def BUILD_SLICE(self, inst: Instruction) -> None:
         items = self.popn(inst.argval)
-        self.push(SliceVariable(items, tx=self))  # type: ignore[arg-type]
+        self.push(SliceVariable(items, tx=self))
 
     def BUILD_LIST(self, inst: Instruction) -> None:
         items = self.popn(inst.argval)

@@ -3607,7 +3607,7 @@ class InstructionTranslatorBase(
         obj = self.stack[-inst.arg]
         assert isinstance(obj, ListVariable)
         assert obj.is_mutable()
-        obj.call_method(self, "extend", [v], {})  # type: ignore[arg-type]
+        obj.call_method(self, "extend", [v], {})
 
     def LIST_TO_TUPLE(self, inst: Instruction) -> None:
         self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {}))  # type: ignore[arg-type]

@@ -3673,7 +3673,7 @@ class InstructionTranslatorBase(
     def MATCH_KEYS(self, inst: Instruction) -> None:
         tos = self.stack[-1]
         assert isinstance(tos, TupleVariable)
-        keys = tos.unpack_var_sequence(self)  # type: ignore[arg-type]
+        keys = tos.unpack_var_sequence(self)
         tos1 = self.stack[-2]
         assert isinstance(tos1, ConstDictVariable)
 
@@ -1513,7 +1513,7 @@ class WithExitFunctionVariable(VariableTracker):
         # Note here we reconstruct the context manager rather than the
         # exit function. The handler generated by BlockStackEntry
         # will re-enter the context in the resume function.
-        self.ctx.reconstruct_type(codegen)  # type: ignore[union-attr]
+        self.ctx.reconstruct_type(codegen)  # type: ignore[attr-defined]
         if codegen.tx.output.partial_convert:
             if sys.version_info >= (3, 11):
                 codegen.append_output(create_instruction("PUSH_NULL"))

@@ -1522,10 +1522,10 @@ class WithExitFunctionVariable(VariableTracker):
             # We rely on classes subtyping `GenericContextWrappingVariable`
             # to implement these fns and have these attributes
             codegen.extend_output(
-                [codegen.create_load_const(val) for val in self.ctx.target_values]  # type: ignore[union-attr]
+                [codegen.create_load_const(val) for val in self.ctx.target_values]  # type: ignore[arg-type]
             )
             codegen.extend_output(
-                create_call_function(len(self.ctx.target_values), False)  # type: ignore[union-attr]
+                create_call_function(len(self.ctx.target_values), False)  # type: ignore[arg-type]
             )
             codegen.append_output(create_setup_with(self.target))
             codegen.append_output(create_instruction("POP_TOP"))
@@ -82,8 +82,7 @@ class ItertoolsVariable(VariableTracker):
                 for item in itertools.product(*seqs, repeat=r)
             ]
             return variables.ListIteratorVariable(
-                items,  # type: ignore[arg-type]
-                mutation_type=ValueMutationNew(),
+                items, mutation_type=ValueMutationNew()
             )
         elif (
             self.value is itertools.combinations

@@ -99,8 +98,7 @@ class ItertoolsVariable(VariableTracker):
             for item in itertools.combinations(iterable, r):
                 items.append(variables.TupleVariable(list(item)))
             return variables.ListIteratorVariable(
-                items,  # type: ignore[arg-type]
-                mutation_type=ValueMutationNew(),
+                items, mutation_type=ValueMutationNew()
             )
         elif self.value is itertools.groupby:
             if any(kw != "key" for kw in kwargs.keys()):

@@ -183,8 +181,7 @@ class ItertoolsVariable(VariableTracker):
                     from_exc=e,
                 )
             return variables.ListIteratorVariable(
-                result,  # type: ignore[arg-type]
-                mutation_type=ValueMutationNew(),
+                result, mutation_type=ValueMutationNew()
             )
         elif self.value is itertools.repeat:
             if len(args) < 2:

@@ -215,8 +212,7 @@ class ItertoolsVariable(VariableTracker):
                 )
             ]
             return variables.ListIteratorVariable(
-                items,  # type: ignore[arg-type]
-                mutation_type=ValueMutationNew(),
+                items, mutation_type=ValueMutationNew()
             )
         else:
             return super().call_function(tx, args, kwargs)
@@ -1,3 +1,5 @@
+# mypy: ignore-errors
+
 """
 Variable tracking implementations for list-like data structures in Dynamo.
 

@@ -18,7 +20,7 @@ import collections
 import inspect
 import operator
 import sys
-from typing import Any, Optional, Sequence, TYPE_CHECKING
+from typing import Optional, TYPE_CHECKING
 
 import torch
 import torch.fx
@@ -58,11 +60,11 @@
 
 class BaseListVariable(VariableTracker):
     @staticmethod
-    def cls_for_instance(obj: Any) -> type["BaseListVariable"]:
+    def cls_for_instance(obj):
         return BaseListVariable.cls_for(type(obj))
 
     @staticmethod
-    def cls_for(obj: Any) -> type:
+    def cls_for(obj):
         return {
             iter: ListIteratorVariable,
             list: ListVariable,

@@ -78,38 +80,34 @@ class BaseListVariable(VariableTracker):
     def __init__(
         self,
         items: list[VariableTracker],
-        **kwargs: Any,
+        **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         assert isinstance(items, list)
         assert all(isinstance(x, VariableTracker) for x in items)
         self.items: list[VariableTracker] = items
 
-    def _as_proxy(self) -> list[Any]:
+    def _as_proxy(self):
         return [x.as_proxy() for x in self.items]
 
-    def modified(
-        self, items: list[VariableTracker], **kwargs: Any
-    ) -> "BaseListVariable":
+    def modified(self, items, **kwargs):
         return type(self)(items, **kwargs)
 
     @property
-    def value(self) -> Any:
+    def value(self):
         return self.as_python_constant()
 
-    def debug_repr_helper(self, prefix: str, suffix: str) -> str:
+    def debug_repr_helper(self, prefix, suffix):
         return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix
 
-    def as_python_constant(self) -> Any:
+    def as_python_constant(self):
         return self.python_type()([x.as_python_constant() for x in self.items])
 
-    def as_proxy(self) -> Any:
+    def as_proxy(self):
         assert self.python_type() is not SizeVariable
         return self.python_type()(self._as_proxy())
 
-    def getitem_const(
-        self, tx: "InstructionTranslator", arg: VariableTracker
-    ) -> VariableTracker:
+    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
         from .tensor import SymNodeVariable
 
         if isinstance(arg, SymNodeVariable):
@@ -136,16 +134,16 @@ class BaseListVariable(VariableTracker):
                 IndexError, tx, args=["list index out of range"]
             )
 
-    def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
+    def unpack_var_sequence(self, tx):
         return list(self.items)
 
     def call_method(
         self,
-        tx: "InstructionTranslator",
-        name: str,
-        args: list[VariableTracker],
-        kwargs: dict[str, VariableTracker],
-    ) -> VariableTracker:
+        tx,
+        name,
+        args: list["VariableTracker"],
+        kwargs: dict[str, "VariableTracker"],
+    ) -> "VariableTracker":
         if name == "__getitem__":
             from .tensor import TensorVariable
 

@@ -226,15 +224,15 @@ class BaseListVariable(VariableTracker):
             if type(self) is not type(args[0]):
                 tp_name = self.python_type_name()
                 other = args[0].python_type_name()
-                msg_vt = ConstantVariable.create(
+                msg = ConstantVariable.create(
                     f'can only concatenate {tp_name} (not "{other}") to {tp_name}'
                 )
-                raise_observed_exception(TypeError, tx, args=[msg_vt])
+                raise_observed_exception(TypeError, tx, args=[msg])
 
             if name == "__add__":
-                return type(self)(self.items + args[0].items, source=self.source)  # type: ignore[attr-defined]
+                return type(self)(self.items + args[0].items, source=self.source)
             else:
-                self.items += args[0].items  # type: ignore[attr-defined]
+                self.items += args[0].items
                 return self
         elif name in ("__mul__", "__imul__"):
             if kwargs or len(args) != 1:
@@ -246,10 +244,10 @@ class BaseListVariable(VariableTracker):
                 )
 
             if not (args[0].is_python_constant() and args[0].python_type() is int):
-                msg_vt = ConstantVariable.create(
+                msg = ConstantVariable.create(
                     f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'"
                 )
-                raise_observed_exception(TypeError, tx, args=[msg_vt])
+                raise_observed_exception(TypeError, tx, args=[msg])
 
             val = args[0].as_python_constant()
 

@@ -303,7 +301,7 @@ class BaseListVariable(VariableTracker):
 
 
 class RangeVariable(BaseListVariable):
-    def __init__(self, items: Sequence[VariableTracker], **kwargs: Any) -> None:
+    def __init__(self, items, **kwargs) -> None:
         items_to_map = items
         start = variables.ConstantVariable.create(0)
         stop = None

@@ -318,7 +316,7 @@ class RangeVariable(BaseListVariable):
         else:
             raise AssertionError
 
-        def maybe_as_int(x: VariableTracker) -> VariableTracker:
+        def maybe_as_int(x):
             return (
                 ConstantVariable(int(x.value)) if isinstance(x, ConstantVariable) else x
             )
@@ -331,22 +329,22 @@ class RangeVariable(BaseListVariable):
         assert stop is not None
         super().__init__([start, stop, step], **kwargs)
 
-    def debug_repr(self) -> str:
+    def debug_repr(self):
         return self.debug_repr_helper("range(", ")")
 
-    def python_type(self) -> type:
+    def python_type(self):
         return range
 
-    def start(self) -> Any:
+    def start(self):
         return self.items[0].as_python_constant()
 
-    def stop(self) -> Any:
+    def stop(self):
         return self.items[1].as_python_constant()
 
-    def step(self) -> Any:
+    def step(self):
         return self.items[2].as_python_constant()
 
-    def range_length(self) -> int:
+    def range_length(self):
         lo = self.start()
         hi = self.stop()
         step = self.step()

@@ -359,7 +357,7 @@ class RangeVariable(BaseListVariable):
         else:
             return 0
 
-    def _get_slice_indices(self, length: int, slice: slice) -> list[int]:
+    def _get_slice_indices(self, length, slice):
         step_is_negative = 0
 
         if slice.step is None:

@@ -408,7 +406,7 @@ class RangeVariable(BaseListVariable):
 
         return [start, stop, step]
 
-    def apply_index(self, index: int) -> VariableTracker:
+    def apply_index(self, index):
         length = self.range_length()
         if index < 0:
             index = length + index
@@ -423,12 +421,12 @@ class RangeVariable(BaseListVariable):
 
         return variables.ConstantVariable.create(self.start() + (index * self.step()))
 
-    def apply_slice(self, slice: slice) -> "RangeVariable":
+    def apply_slice(self, slice):
         (slice_start, slice_stop, slice_step) = self._get_slice_indices(
             self.range_length(), slice
         )
 
-        def compute_item(index: int) -> int:
+        def compute_item(index):
             return self.start() + (index * self.step())
 
         sub_step = self.step() * slice_step

@@ -444,12 +442,10 @@ class RangeVariable(BaseListVariable):
         )
         return result
 
-    def as_python_constant(self) -> range:
+    def as_python_constant(self):
         return range(*[x.as_python_constant() for x in self.items])
 
-    def getitem_const(
-        self, tx: "InstructionTranslator", arg: VariableTracker
-    ) -> VariableTracker:
+    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
         # implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
         index = arg.as_python_constant()
 
@@ -461,30 +457,28 @@ class RangeVariable(BaseListVariable):
             msg = ConstantVariable("range indices must be integers or slices")
             raise_observed_exception(TypeError, tx, args=[msg])
 
-    def as_proxy(self) -> range:
+    def as_proxy(self):
         return self.python_type()(*self._as_proxy())
 
-    def unpack_var_sequence(
-        self, tx: Optional["InstructionTranslator"] = None
-    ) -> list[VariableTracker]:
+    def unpack_var_sequence(self, tx=None):
         return [variables.ConstantVariable.create(x) for x in self.as_python_constant()]
 
     def reconstruct(self, codegen: "PyCodegen") -> None:
         assert "range" not in codegen.tx.f_globals
         codegen.add_push_null(
-            lambda: codegen.append_output(codegen.create_load_python_module(range))  # type: ignore[arg-type]
+            lambda: codegen.append_output(codegen.create_load_python_module(range))
         )
         codegen.foreach(self.items)
         codegen.extend_output(create_call_function(3, False))
 
     def call_obj_hasattr(
         self, tx: "InstructionTranslator", name: str
-    ) -> VariableTracker:
+    ) -> "VariableTracker":
         if self.python_type() is range:
             return variables.ConstantVariable.create(name in range.__dict__)
         return super().call_obj_hasattr(tx, name)
 
-    def range_equals(self, other: "RangeVariable") -> bool:
+    def range_equals(self, other: "RangeVariable"):
         r0, r1 = self, other
         if (
             self.range_length() != r1.range_length()

@@ -493,12 +487,12 @@ class RangeVariable(BaseListVariable):
         ):
             return False
 
-        if self.range_length() == 1:
+        if len(r0) == 1:
             return True
 
         return r0.step() == r1.step()
 
-    def range_count(self, x: VariableTracker) -> int:
+    def range_count(self, x: VariableTracker):
         # Based on CPython
         # https://github.com/guilhermeleobas/cpython/blob/baefaa6cba1d69efd2f930cdc56bca682c54b139/Objects/rangeobject.c#L442-L486
         x = x.as_python_constant()
@@ -517,13 +511,7 @@ class RangeVariable(BaseListVariable):
                 return int(re)
         return 0
 
-    def call_method(
-        self,
-        tx: "InstructionTranslator",
-        name: str,
-        args: list[VariableTracker],
-        kwargs: dict[str, VariableTracker],
-    ) -> VariableTracker:
+    def call_method(self, tx, name, args, kwargs):
         if name == "__iter__":
             if not all(var.is_python_constant() for var in self.items):
                 # Can't represent a `range_iterator` without well defined bounds

@@ -557,10 +545,7 @@ class RangeVariable(BaseListVariable):
             if pt is not range:
                 return ConstantVariable.create(NotImplemented)
 
-            if isinstance(other, RangeVariable):
-                cmp = self.range_equals(other)
-            else:
-                cmp = False
+            cmp = self.range_equals(other)
 
             # Two ranges are equal if they produce the same sequence of values
             if name == "__eq__":

@@ -569,7 +554,7 @@ class RangeVariable(BaseListVariable):
                 return ConstantVariable(not cmp)
         return super().call_method(tx, name, args, kwargs)
 
-    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
+    def var_getattr(self, tx: "InstructionTranslator", name):
         fields = ["start", "stop", "step"]
         if name in fields:
             return self.items[fields.index(name)]
@@ -583,11 +568,11 @@ class CommonListMethodsVariable(BaseListVariable):
 
     def call_method(
         self,
-        tx: "InstructionTranslator",
-        name: str,
-        args: list[VariableTracker],
-        kwargs: dict[str, VariableTracker],
-    ) -> VariableTracker:
+        tx,
+        name,
+        args: list["VariableTracker"],
+        kwargs: dict[str, "VariableTracker"],
+    ) -> "VariableTracker":
         from .tensor import SymNodeVariable
 
         if name == "append" and self.is_mutable():

@@ -691,9 +676,9 @@ class CommonListMethodsVariable(BaseListVariable):
                 self.items[key.evaluate_expr()] = value
             elif isinstance(key, SliceVariable):
                 if key.is_python_constant():
-                    self.items[key.as_python_constant()] = list(value.items)  # type: ignore[attr-defined]
+                    self.items[key.as_python_constant()] = list(value.items)
                 else:
-                    items_slice = slice(
+                    items = slice(
                         *[
                             (
                                 s.evaluate_expr()

@@ -703,7 +688,7 @@ class CommonListMethodsVariable(BaseListVariable):
                             for s in key.items
                         ]
                     )
-                    self.items[items_slice] = list(value.items)  # type: ignore[attr-defined]
+                    self.items[items] = list(value.items)
             else:
                 self.items[key.as_python_constant()] = value
             return ConstantVariable.create(None)

@@ -748,8 +733,8 @@ class CommonListMethodsVariable(BaseListVariable):
                     "0 args and 0 kwargs",
                     f"{len(args)} args and {len(kwargs)} kwargs",
                 )
-            items_lst: list[VariableTracker] = list(self.items)
-            return self.modified(items_lst, mutation_type=ValueMutationNew())
+            items = list(self.items)
+            return self.modified(items, mutation_type=ValueMutationNew())
         elif name == "reverse" and self.is_mutable():
             if args or kwargs:
                 raise_args_mismatch(
@@ -778,13 +763,13 @@ class CommonListMethodsVariable(BaseListVariable):
 
 
 class ListVariable(CommonListMethodsVariable):
-    def python_type(self) -> type:
+    def python_type(self):
         return list
 
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(length={len(self.items)})"
 
-    def debug_repr(self) -> str:
+    def debug_repr(self):
         return self.debug_repr_helper("[", "]")
 
     def reconstruct(self, codegen: "PyCodegen") -> None:

@@ -793,11 +778,11 @@ class ListVariable(CommonListMethodsVariable):
 
     def call_method(
         self,
-        tx: "InstructionTranslator",
-        name: str,
-        args: list[VariableTracker],
-        kwargs: dict[str, VariableTracker],
-    ) -> VariableTracker:
+        tx,
+        name,
+        args: list["VariableTracker"],
+        kwargs: dict[str, "VariableTracker"],
+    ) -> "VariableTracker":
         from .tensor import SymNodeVariable
 
         if name == "__setitem__" and self.is_mutable():

@@ -820,14 +805,14 @@ class ListVariable(CommonListMethodsVariable):
                     msg = ConstantVariable.create("can only assign an iterable")
                     raise_observed_exception(TypeError, tx, args=[msg])
 
-                key_as_const = key.as_python_constant()
-                if key_as_const.step == 0:
+                key = key.as_python_constant()
+                if key.step == 0:
                     msg = ConstantVariable.create("slice step cannot be zero")
                     raise_observed_exception(ValueError, tx, args=[msg])
 
-                value_unpack = value.force_unpack_var_sequence(tx)
+                value = value.force_unpack_var_sequence(tx)
                 try:
-                    self.items[key_as_const] = value_unpack
+                    self.items[key] = value
                 except Exception as exc:
                     raise_observed_exception(
                         type(exc),
@@ -874,7 +859,7 @@ class ListVariable(CommonListMethodsVariable):
             assert first_non_constant_key is not None
 
             try:
-                python_type = str(first_non_constant_key.python_type())
+                python_type = first_non_constant_key.python_type()
             except NotImplementedError:
                 python_type = "unknown"
 

@@ -919,7 +904,7 @@ class ListVariable(CommonListMethodsVariable):
 
         return super().call_method(tx, name, args, kwargs)
 
-    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
+    def var_getattr(self, tx, name):
         if name == "__class__":
             source = AttrSource(self.source, name) if self.source else None
             class_type = self.python_type()

@@ -931,19 +916,14 @@ class ListVariable(CommonListMethodsVariable):
 
     def call_obj_hasattr(
         self, tx: "InstructionTranslator", name: str
-    ) -> VariableTracker:
+    ) -> "VariableTracker":
         if self.python_type() is not list:
             return super().call_obj_hasattr(tx, name)
         return variables.ConstantVariable.create(hasattr([], name))
 
 
 class DequeVariable(CommonListMethodsVariable):
-    def __init__(
-        self,
-        items: list[VariableTracker],
-        maxlen: Optional[VariableTracker] = None,
-        **kwargs: Any,
-    ) -> None:
+    def __init__(self, items, maxlen=None, **kwargs) -> None:
         if maxlen is None:
             maxlen = ConstantVariable.create(None)
         assert maxlen.is_python_constant(), (

@@ -955,17 +935,17 @@ class DequeVariable(CommonListMethodsVariable):
             items = items[-maxlen.as_python_constant() :]
         super().__init__(items, **kwargs)
 
-    def python_type(self) -> type:
+    def python_type(self):
         return collections.deque
 
-    def debug_repr(self) -> str:
+    def debug_repr(self):
         if self.maxlen.as_python_constant() is None:
             return self.debug_repr_helper(
                 "deque([", "], maxlen=" + self.maxlen.debug_repr() + ")"
             )
         return self.debug_repr_helper("deque([", "])")
 
-    def as_python_constant(self) -> collections.deque[Any]:
+    def as_python_constant(self):
         return self.python_type()(
             [x.as_python_constant() for x in self.items],
             maxlen=self.maxlen.as_python_constant(),
@ -955,17 +935,17 @@ class DequeVariable(CommonListMethodsVariable):
|
||||
items = items[-maxlen.as_python_constant() :]
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def python_type(self) -> type:
|
||||
def python_type(self):
|
||||
return collections.deque
|
||||
|
||||
def debug_repr(self) -> str:
|
||||
def debug_repr(self):
|
||||
if self.maxlen.as_python_constant() is None:
|
||||
return self.debug_repr_helper(
|
||||
"deque([", "], maxlen=" + self.maxlen.debug_repr() + ")"
|
||||
)
|
||||
return self.debug_repr_helper("deque([", "])")
|
||||
|
||||
def as_python_constant(self) -> collections.deque[Any]:
|
||||
def as_python_constant(self):
|
||||
return self.python_type()(
|
||||
[x.as_python_constant() for x in self.items],
|
||||
maxlen=self.maxlen.as_python_constant(),
|
||||
@ -974,7 +954,7 @@ class DequeVariable(CommonListMethodsVariable):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.append_output(
|
||||
codegen.create_load_python_module(collections.deque) # type: ignore[arg-type]
|
||||
codegen.create_load_python_module(collections.deque)
|
||||
)
|
||||
)
|
||||
codegen.foreach(self.items)
|
||||
@ -982,18 +962,18 @@ class DequeVariable(CommonListMethodsVariable):
|
||||
codegen(self.maxlen)
|
||||
codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False))
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||
if name == "maxlen":
|
||||
return self.maxlen
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
if (
|
||||
name == "__setitem__"
|
||||
and self.is_mutable()
|
||||
@ -1088,20 +1068,20 @@ class DequeVariable(CommonListMethodsVariable):
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
) -> "VariableTracker":
|
||||
if self.python_type() is collections.deque:
|
||||
return variables.ConstantVariable.create(name in collections.deque.__dict__)
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
|
||||
|
||||
class TupleVariable(BaseListVariable):
|
||||
def python_type(self) -> type[tuple]: # type: ignore[type-arg]
|
||||
def python_type(self):
|
||||
return tuple
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(length={len(self.items)})"
|
||||
|
||||
def debug_repr(self) -> str:
|
||||
def debug_repr(self):
|
||||
return self.debug_repr_helper("(", ")")
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
@ -1110,14 +1090,14 @@ class TupleVariable(BaseListVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||
def var_getattr(self, tx, name):
|
||||
if name == "__class__":
|
||||
source = AttrSource(self.source, name) if self.source else None
|
||||
class_type = self.python_type()
|
||||
@ -1129,7 +1109,7 @@ class TupleVariable(BaseListVariable):
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
) -> "VariableTracker":
|
||||
if self.python_type() is not tuple:
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
return variables.ConstantVariable.create(hasattr((), name))
|
||||
@ -1147,18 +1127,18 @@ class SizeVariable(TupleVariable):
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
proxy: Optional[torch.fx.Proxy] = None,
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.proxy = proxy
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self) -> str:
|
||||
def debug_repr(self):
|
||||
return self.debug_repr_helper("torch.Size([", "])")
|
||||
|
||||
def python_type(self) -> type:
|
||||
def python_type(self):
|
||||
return torch.Size
|
||||
|
||||
def as_proxy(self) -> Any:
|
||||
def as_proxy(self):
|
||||
if self.proxy is not None:
|
||||
return self.proxy
|
||||
|
||||
@ -1213,10 +1193,10 @@ class SizeVariable(TupleVariable):
|
||||
] + create_call_function(1, False)
|
||||
codegen.extend_output(build_torch_size)
|
||||
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
def unpack_var_sequence(self, tx):
|
||||
return list(self.items)
|
||||
|
||||
def numel(self, tx: "InstructionTranslator") -> VariableTracker:
|
||||
def numel(self, tx):
|
||||
from .builtin import BuiltinVariable
|
||||
from .tensor import SymNodeVariable
|
||||
|
||||
@ -1246,11 +1226,11 @@ class SizeVariable(TupleVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
if name == "__getitem__":
|
||||
if kwargs or len(args) != 1:
|
||||
raise_args_mismatch(
|
||||
@ -1273,9 +1253,7 @@ class SizeVariable(TupleVariable):
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def get_item_dyn(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
from .tensor import SymNodeVariable
|
||||
|
||||
if isinstance(arg, SymNodeVariable):
|
||||
@ -1291,7 +1269,7 @@ class SizeVariable(TupleVariable):
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
) -> "VariableTracker":
|
||||
return variables.ConstantVariable.create(hasattr(torch.Size, name))
|
||||
|
||||
|
||||
@ -1302,39 +1280,33 @@ class NamedTupleVariable(TupleVariable):
|
||||
*TupleVariable._nonvar_fields,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
tuple_cls: type,
|
||||
dynamic_attributes: Optional[dict[str, VariableTracker]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
self.tuple_cls = tuple_cls
|
||||
self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {}
|
||||
|
||||
def is_namedtuple(self) -> bool:
|
||||
def is_namedtuple(self):
|
||||
return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
|
||||
getattr(self.tuple_cls, "_make", None)
|
||||
)
|
||||
|
||||
def is_structseq(self) -> bool:
|
||||
def is_structseq(self):
|
||||
return not self.is_namedtuple()
|
||||
|
||||
def fields(self) -> tuple[str, ...]:
|
||||
def fields(self):
|
||||
return namedtuple_fields(self.tuple_cls)
|
||||
|
||||
def debug_repr(self) -> str:
|
||||
def debug_repr(self):
|
||||
if self.is_structseq():
|
||||
# StructSequenceType(iterable)
|
||||
return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items]))
|
||||
# NamedTupleType(*iterable)
|
||||
return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items)))
|
||||
|
||||
def python_type(self) -> type:
|
||||
def python_type(self):
|
||||
return self.tuple_cls
|
||||
|
||||
def as_python_constant(self) -> Any:
|
||||
def as_python_constant(self):
|
||||
if self.is_structseq():
|
||||
# StructSequenceType(iterable)
|
||||
result = self.python_type()([x.as_python_constant() for x in self.items])
|
||||
@ -1356,7 +1328,7 @@ class NamedTupleVariable(TupleVariable):
|
||||
|
||||
return result
|
||||
|
||||
def as_proxy(self) -> Any:
|
||||
def as_proxy(self):
|
||||
assert self.python_type() is not SizeVariable
|
||||
if self.is_structseq():
|
||||
# StructSequenceType(iterable)
|
||||
@ -1370,10 +1342,7 @@ class NamedTupleVariable(TupleVariable):
|
||||
# StructSequenceType(iterable)
|
||||
# NamedTupleType(*iterable)
|
||||
# NamedTupleType._make(iterable)
|
||||
if self.is_structseq():
|
||||
create_fn = self.tuple_cls
|
||||
else:
|
||||
create_fn = self.tuple_cls._make # type: ignore[attr-defined]
|
||||
create_fn = self.tuple_cls if self.is_structseq() else self.tuple_cls._make
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.append_output(
|
||||
codegen.create_load_const_unchecked(create_fn)
|
||||
@ -1415,8 +1384,8 @@ class NamedTupleVariable(TupleVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
tx,
|
||||
name,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
@ -1477,9 +1446,7 @@ class NamedTupleVariable(TupleVariable):
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def getitem_const(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
if isinstance(arg, SliceVariable):
|
||||
# slicing a namedtuple produces a tuple
|
||||
return TupleVariable(
|
||||
@ -1488,8 +1455,8 @@ class NamedTupleVariable(TupleVariable):
|
||||
)
|
||||
return super().getitem_const(tx, arg)
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||
def check_and_create_method() -> Optional[VariableTracker]:
|
||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||
def check_and_create_method():
|
||||
method = inspect.getattr_static(self.tuple_cls, name, None)
|
||||
if isinstance(method, classmethod):
|
||||
# We need the unbounded cls method to avoid the inline __self__
|
||||
@ -1522,8 +1489,8 @@ class NamedTupleVariable(TupleVariable):
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
if name == "_fields":
|
||||
result_source = NamedTupleFieldsSource(self.source) if self.source else None
|
||||
return VariableTracker.build(tx, self.fields(), source=result_source)
|
||||
source = NamedTupleFieldsSource(self.source) if self.source else None
|
||||
return VariableTracker.build(tx, self.fields(), source=source)
|
||||
|
||||
if name in self.dynamic_attributes:
|
||||
return self.dynamic_attributes[name]
|
||||
@ -1538,19 +1505,14 @@ class NamedTupleVariable(TupleVariable):
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
) -> "VariableTracker":
|
||||
return variables.ConstantVariable.create(
|
||||
name in self.dynamic_attributes or hasattr(self.tuple_cls, name)
|
||||
)
|
||||
|
||||
|
||||
class SliceVariable(VariableTracker):
|
||||
def __init__(
|
||||
self,
|
||||
items: Sequence[VariableTracker],
|
||||
tx: Optional["InstructionTranslator"] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
def __init__(self, items, tx=None, **kwargs) -> None:
|
||||
items_to_map = items
|
||||
start, stop, step = [variables.ConstantVariable.create(None)] * 3
|
||||
|
||||
@ -1585,23 +1547,23 @@ class SliceVariable(VariableTracker):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def debug_repr(self) -> str:
|
||||
return "slice(" + ", ".join(i.debug_repr() for i in self.items) + ")"
|
||||
def debug_repr(self):
|
||||
return self.debug_repr_helper("slice(", ")")
|
||||
|
||||
def as_proxy(self) -> slice:
|
||||
def as_proxy(self):
|
||||
return slice(*[x.as_proxy() for x in self.items])
|
||||
|
||||
def python_type(self) -> type:
|
||||
def python_type(self):
|
||||
return slice
|
||||
|
||||
def as_python_constant(self) -> slice:
|
||||
def as_python_constant(self):
|
||||
return slice(*[guard_if_dyn(x) for x in self.items])
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.foreach(self.items)
|
||||
codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items)))
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||
if name in cmp_name_to_op_mapping:
|
||||
return variables.GetAttrVariable(self, name)
|
||||
fields = ["start", "stop", "step"]
|
||||
@ -1622,9 +1584,7 @@ class ListIteratorVariable(IteratorVariable):
|
||||
*IteratorVariable._nonvar_fields,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, items: list[VariableTracker], index: int = 0, **kwargs: Any
|
||||
) -> None:
|
||||
def __init__(self, items, index: int = 0, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(items, list)
|
||||
# Removing this check as it slows things down too much
|
||||
@ -1638,7 +1598,7 @@ class ListIteratorVariable(IteratorVariable):
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"
|
||||
|
||||
def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
|
||||
def next_variable(self, tx):
|
||||
assert self.is_mutable()
|
||||
old_index = self.index
|
||||
if old_index >= len(self.items) or self.is_exhausted:
|
||||
@ -1649,31 +1609,27 @@ class ListIteratorVariable(IteratorVariable):
|
||||
self.index += 1
|
||||
return self.items[old_index]
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
return variables.ConstantVariable.create(hasattr(iter([]), name))
|
||||
|
||||
def python_type(self) -> type:
|
||||
def python_type(self):
|
||||
return type(iter([]))
|
||||
|
||||
def as_python_constant(self) -> Any:
|
||||
def as_python_constant(self):
|
||||
if self.index > 0:
|
||||
raise NotImplementedError
|
||||
return iter([x.as_python_constant() for x in self.items])
|
||||
|
||||
def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
|
||||
def has_unpack_var_sequence(self, tx):
|
||||
return True
|
||||
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
def unpack_var_sequence(self, tx):
|
||||
if self.is_exhausted:
|
||||
return []
|
||||
self.is_exhausted = True
|
||||
return list(self.items[self.index :])
|
||||
|
||||
def force_unpack_var_sequence(
|
||||
self, tx: "InstructionTranslator"
|
||||
) -> list[VariableTracker]:
|
||||
def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
|
||||
return self.unpack_var_sequence(tx)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
@ -1700,37 +1656,27 @@ class RangeIteratorVariable(IteratorVariable):
|
||||
"iter_obj",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, start: int, stop: int, step: int, len_: int, **kwargs: Any
|
||||
) -> None:
|
||||
def __init__(self, start: int, stop: int, step: int, len_: int, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.step = step
|
||||
self.len = len_
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
def call_method(self, tx, name, args, kwargs):
|
||||
if name == "__next__":
|
||||
return self.next_variable(tx)
|
||||
elif name == "__iter__":
|
||||
return self
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
if self.python_type() is range_iterator:
|
||||
ri = iter(range(0))
|
||||
return ConstantVariable(hasattr(ri, name))
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
|
||||
def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
|
||||
def next_variable(self, tx):
|
||||
if self.len <= 0:
|
||||
raise_observed_exception(StopIteration, tx)
|
||||
|
||||
@ -1739,12 +1685,12 @@ class RangeIteratorVariable(IteratorVariable):
|
||||
self.start += self.step
|
||||
return ConstantVariable.create(current)
|
||||
|
||||
def python_type(self) -> type:
|
||||
def python_type(self):
|
||||
return range_iterator
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type]
|
||||
lambda: codegen.append_output(codegen.create_load_python_module(range))
|
||||
)
|
||||
codegen.append_output(codegen.create_load_const(self.start))
|
||||
codegen.append_output(codegen.create_load_const(self.stop))
|
||||
|
||||
@@ -70,7 +70,7 @@ from .common import (
)
from .cpp_utils import cexpr
from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta

from torch.fx.experimental import _config as fx_experimental_config

if TYPE_CHECKING:
    from collections.abc import Iterator, Sequence
@@ -1120,6 +1120,37 @@ class PythonWrapperCodegen(CodeGen):
        # Additional files that are dependent to the wrapper (ex. cubin files)
        self.additional_files = []

        # This is used to emit RecordFunctionFast markers that can be matched
        # with profiler traces for provenance tracking.
        #
        # Stores the (kernel_name, debug_handle) tuple
        # for the currently being generated kernel.
        self.current_kernel_debug_handle: Optional[tuple[str, int]] = None

        # set_current_kernel_debug_handle: Flag that controls whether
        # write_provenance_debug_handle() should update current_kernel_debug_handle.
        # This flag is automatically managed by kernel_debug_handle_context().
        self.set_current_kernel_debug_handle: bool = False

    @contextlib.contextmanager
    def kernel_debug_handle_context(self):
        """
        Context manager for kernel debug handle tracking.

        self.current_kernel_debug_handle can be updated within the context
        with wrapper.write_provenance_debug_handle
        and it will be reset after the context
        """
        old_flag_value = self.set_current_kernel_debug_handle
        old_handle_value = self.current_kernel_debug_handle
        self.set_current_kernel_debug_handle = True

        try:
            yield
        finally:
            self.set_current_kernel_debug_handle = old_flag_value
            self.current_kernel_debug_handle = old_handle_value

    @staticmethod
    def create(
        is_subgraph: bool,
@@ -1510,8 +1541,27 @@ class PythonWrapperCodegen(CodeGen):
    def generate_end(self, result: IndentedBuffer) -> None:
        return

    def generate_record_function_start(self) -> Optional[str]:
        record_func = self.current_kernel_debug_handle and fx_experimental_config.enrich_profiler_metadata
        if record_func:
            assert self.current_kernel_debug_handle
            kernel_name, debug_handle = self.current_kernel_debug_handle
            kernel_debug_handle = f"{kernel_name}:{debug_handle}"
            self.writeline(
                f"_rf_enter = torch._C._profiler._RecordFunctionFast('## inductor_kernel:{kernel_debug_handle} ##'); _rf_enter.__enter__()"
            )
            return "_rf_enter"
        else:
            return None

    def generate_record_function_end(self, record_func_var: Optional[str]):
        if record_func_var:
            self.writeline(f"{record_func_var}.__exit__(None, None, None)")

    def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None:
        record_func_var = self.generate_record_function_start()
        self.writeline(ExternKernelAllocLine(self, node))
        self.generate_record_function_end(record_func_var)

    def generate_extern_kernel_alloc(self, node: ir.ExternKernelAlloc):
        node.codegen_comment(self)
@@ -1671,7 +1721,9 @@ class PythonWrapperCodegen(CodeGen):
        raw_args: Sequence[Any],
        outputs: Sequence[ir.Buffer],
    ) -> None:
        record_func_var = self.generate_record_function_start()
        self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(get_args())})")
        self.generate_record_function_end(record_func_var)

    def generate(self, is_inference):
        with dynamo_timed("PythonWrapperCodegen.generate"):
@@ -3142,6 +3194,8 @@ class PythonWrapperCodegen(CodeGen):
        self.writeline(
            f"{self.comment} [Provenance debug handles] {kernel_name}:{debug_handle}"
        )
        if self.set_current_kernel_debug_handle:
            self.current_kernel_debug_handle = (kernel_name, debug_handle)

    def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool):
        assert old.get_dtype() == new.get_dtype()

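Taken together, generate_record_function_start/generate_record_function_end above make the generated wrapper bracket each fallback-kernel call with a profiler range. A minimal sketch of what the emitted code amounts to, assuming a hypothetical kernel name "triton_poi_fused_add_0" with debug handle 3 (the real names come from provenance tracking):

import torch

def run_kernel():  # stand-in for the actual generated kernel invocation
    return None

_rf_enter = torch._C._profiler._RecordFunctionFast(
    "## inductor_kernel:triton_poi_fused_add_0:3 ##"
)
_rf_enter.__enter__()  # opens the range that shows up in profiler traces
out = run_kernel()
_rf_enter.__exit__(None, None, None)
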
@@ -98,6 +98,8 @@ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExpr
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.monitor import _WaitCounter
from torch.utils._ordered_set import OrderedSet
from torch.fx.experimental import _config as fx_experimental_config


from .._dynamo.backends.common import aot_autograd
from .._dynamo.exc import ShortenTraceback, SkipFrame
@@ -1530,7 +1532,10 @@ class _InProcessFxCompile(FxCompile):
                # Dump provenance artifacts for debugging trace
                inductor_provenance_tracking_node_mappings = None
                inductor_kernel_stack_trace_str = None
                if config.trace.provenance_tracking_level != 0:
                if (
                    config.trace.provenance_tracking_level != 0
                    or fx_experimental_config.enrich_profiler_metadata
                ):
                    inductor_provenance_tracking_node_mappings = json.dumps(
                        torch._inductor.debug.dump_inductor_provenance_info()
                    )

@@ -1179,6 +1179,8 @@ torchinductor_worker_logpath: str = Config(
    default="",
)

fallback_by_default: bool = False


# config specific to codegen/cpp.py
class cpp:

@@ -1106,7 +1106,7 @@ def set_kernel_post_grad_provenance_tracing(
    Returns a unique int debug handler for each call to this function.
    """

    if config.trace.provenance_tracking_level == 0:
    if config.trace.provenance_tracking_level == 0 and not config.fallback_by_default:
        return None

    try:

@@ -1628,6 +1628,7 @@ class GraphLowering(torch.fx.Interpreter):
                        "inductor", "lowerings", lambda: repr(n)
                    )
                )
                or (n.op == "call_function" and config.fallback_by_default)
            ):
                debug("fallback_handler")
                result = fallback_handler(n.target, add_to_fallback_set=False)(

@@ -8079,27 +8079,28 @@ class FallbackKernel(ExternKernelAlloc):
            for v, a in zip(args_iter, kernel._schema.arguments)
        )

        self.codegen_comment(wrapper)
        if self.use_runtime_dispatch:
            exported_args = self.export_extern_kernel_node()
            assert self.python_kernel_name is not None
            assert self.op_overload is not None
        with wrapper.kernel_debug_handle_context():
            self.codegen_comment(wrapper)
            if self.use_runtime_dispatch:
                exported_args = self.export_extern_kernel_node()
                assert self.python_kernel_name is not None
                assert self.op_overload is not None

            wrapper.generate_fallback_kernel_with_runtime_lookup(
                self.get_name(),
                self.python_kernel_name,
                lambda: [*self.codegen_args(), *self.codegen_kwargs()],
                self.op_overload,
                exported_args,
                # NOTE: [special handling of all_reduce_coalesced_'s return value]
                self.outputs if self.outputs else self.mutation_outputs,
            )
        else:
            wrapper.generate_fallback_kernel(self)
            if isinstance(self.layout, Layout):
                self.codegen_size_asserts(wrapper)
                self.codegen_alignment_asserts(wrapper)
            self.codegen_memory_tracking(wrapper)
                wrapper.generate_fallback_kernel_with_runtime_lookup(
                    self.get_name(),
                    self.python_kernel_name,
                    lambda: [*self.codegen_args(), *self.codegen_kwargs()],
                    self.op_overload,
                    exported_args,
                    # NOTE: [special handling of all_reduce_coalesced_'s return value]
                    self.outputs if self.outputs else self.mutation_outputs,
                )
            else:
                wrapper.generate_fallback_kernel(self)
                if isinstance(self.layout, Layout):
                    self.codegen_size_asserts(wrapper)
                    self.codegen_alignment_asserts(wrapper)
                self.codegen_memory_tracking(wrapper)

        self.codegen_unbacked_symbol_defs(wrapper)


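For readers tracing the control flow: wrapping the whole codegen block in kernel_debug_handle_context() is what connects write_provenance_debug_handle() to the record-function markers. A rough sketch of the interaction, reusing `self` and `wrapper` from the method above (so it is not self-contained, just illustrative):

with wrapper.kernel_debug_handle_context():
    # codegen_comment() ends up calling write_provenance_debug_handle(),
    # which now records (kernel_name, debug_handle) on the wrapper...
    self.codegen_comment(wrapper)
    # ...so generate_fallback_kernel() can emit the _RecordFunctionFast
    # enter/exit lines around the ExternKernelAllocLine.
    wrapper.generate_fallback_kernel(self)
# On exit, both the flag and the saved handle are restored.
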
@@ -25,6 +25,8 @@ from __future__ import annotations
import dataclasses
import logging
import os
import random
import string
from functools import partial
from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union

@@ -70,6 +72,10 @@ if TYPE_CHECKING:

log = logging.getLogger(__name__)

# Used for profiler post-processing to match
# for the same compiled run
CALL_COMPILED_PREFIX = "Call CompiledFxGraph"


@dataclasses.dataclass
class OutputCode:
@@ -612,9 +618,18 @@ class CompiledFxGraph(OutputCode):
        try:
            # Checking the profiler directly is faster than nullcontext
            if torch.autograd.profiler._is_profiler_enabled:
                with record_function(
                    f"## Call CompiledFxGraph {self._fx_graph_cache_key} ##"
                ):
                # generate a random string to represent this unique run if no cache key
                run_key = (
                    self._fx_graph_cache_key
                    if self._fx_graph_cache_key
                    else "".join(random.choices(string.ascii_lowercase, k=51))
                )
                run_name = f"{CALL_COMPILED_PREFIX} {run_key}"
                if self.inductor_provenance_stack_traces_str:
                    torch.fx.traceback._register_fx_metadata(
                        run_name, self.inductor_provenance_stack_traces_str
                    )
                with record_function(f"## {run_name} ##"):
                    return self.current_callable(inputs)
            else:
                return self.current_callable(inputs)

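The effect of the run_key hunk is that every compiled-graph invocation gets a unique profiler range name even when no FX graph cache key exists. A self-contained sketch of the naming logic (the sample cache key is invented):

import random
import string

CALL_COMPILED_PREFIX = "Call CompiledFxGraph"

def run_name_for(cache_key):
    # Fall back to a random 51-char lowercase string when there is no cache key.
    run_key = cache_key or "".join(random.choices(string.ascii_lowercase, k=51))
    return f"{CALL_COMPILED_PREFIX} {run_key}"

print(run_name_for("f3a1c9"))  # named after the cache key
print(run_name_for(None))      # named after a random run key
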
@@ -38,7 +38,20 @@ class StaticallyLaunchedCudaKernel:
        # pyrefly: ignore [missing-attribute]
        self.name = kernel.src.fn.__name__
        # pyrefly: ignore [missing-attribute]
        self.cubin_raw = kernel.asm.get("cubin", None)
        if "hsaco" in kernel.asm:
            # pyrefly: ignore [missing-attribute]
            self.cubin_raw = kernel.asm["hsaco"]
            self.is_rocm = True
        # pyrefly: ignore [missing-attribute]
        elif "cubin" in kernel.asm:
            # pyrefly: ignore [missing-attribute]
            self.cubin_raw = kernel.asm["cubin"]
            self.is_rocm = False
        else:
            raise RuntimeError(
                "Expected either 'hsaco' (ROCm) or 'cubin' (CUDA) in kernel.asm"
            )

        # pyrefly: ignore [missing-attribute]
        self.cubin_path = kernel._cubin_path

@@ -245,12 +258,42 @@ class StaticallyLaunchedCudaKernel:
        # thing, it should always match.
        # Get rid of constants before passing to cubin launcher

        # Add a None if triton wants extra parameters for scratch spaces
        arg_tys = self.arg_tys
        for has_scratch in [self.has_global_scratch, self.has_profile_scratch]:
            if has_scratch:
                arg_tys = arg_tys + "O"
                args = (*args, None)

        if self.is_rocm:
            # ROCm/HIP kernel ABI: The Triton HIP backend ALWAYS includes both
            # global_scratch and profile_scratch parameters in the kernel signature,
            # even when the kernel doesn't use them (i.e., when has_*_scratch is False).
            #
            # This differs fundamentally from CUDA, where these parameters are only
            # present in the signature if the corresponding has_*_scratch flag is True.
            #
            # The flags indicate whether memory will be allocated/used:
            # - has_global_scratch: Whether global scratch workspace is needed
            # - has_profile_scratch: Whether profiling instrumentation is enabled
            #
            # However, regardless of flag values, we MUST always pass both parameters
            # to match the HIP kernel ABI. Passing None is safe:
            #
            # - If scratch is not needed (has_*_scratch=False or scratch_size=0):
            #   The None becomes nullptr, which the kernel never dereferences
            #
            # - If scratch is needed (has_*_scratch=True and scratch_size>0):
            #   The None becomes nullptr initially, but the HIP runtime intercepts
            #   the kernel launch, allocates the required scratch memory based on
            #   kernel metadata, and replaces the nullptr with a valid pointer before
            #   the kernel actually executes
            #
            # Not passing both parameters causes segmentation faults because the kernel
            # expects them at specific positions in the argument array.
            arg_tys = arg_tys + "OO"
            args = (*args, None, None)

        else:
            for has_scratch in [self.has_global_scratch, self.has_profile_scratch]:
                if has_scratch:
                    arg_tys = arg_tys + "O"
                    args = (*args, None)
        # pyrefly: ignore [bad-argument-type]
        assert len(args) == len(arg_tys)


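The branch above can be exercised in isolation. A toy illustration of the argument/type-string packing difference between the HIP and CUDA paths (all flag and argument values invented):

def pack_scratch_args(arg_tys, args, is_rocm, has_global, has_profile):
    if is_rocm:
        # HIP ABI: both scratch slots always exist in the signature.
        return arg_tys + "OO", (*args, None, None)
    # CUDA: a slot exists only when the corresponding flag is set.
    for has_scratch in (has_global, has_profile):
        if has_scratch:
            arg_tys, args = arg_tys + "O", (*args, None)
    return arg_tys, args

assert pack_scratch_args("ii", (1, 2), True, False, False) == ("iiOO", (1, 2, None, None))
assert pack_scratch_args("ii", (1, 2), False, True, False) == ("iiO", (1, 2, None))
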
@@ -1599,9 +1599,8 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
            return None

        def check_can_launch() -> StaticallyLaunchedCudaKernel:
            if triton_meta.get("device_type") != "cuda":
                # Only cuda kernels
                raise CannotStaticallyLaunchKernel("Non-cuda device")
            if triton_meta.get("device_type") not in ("cuda", "hip"):
                raise CannotStaticallyLaunchKernel("Non-cuda/ROCm device")

            if torch._inductor.config.cpp_wrapper:
                # If we're running with cpp wrapper, it doesn't
@@ -1627,10 +1626,11 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
                    "static launch does not support launch attributes"
                )

            binary_ext = "hsaco" if triton_meta.get("device_type") == "hip" else "cubin"
            cubin_location = os.path.join(
                triton_cache_dir(triton_meta.get("device", 0)),
                triton_hash_to_path_key(kernel.hash),
                f"{kernel.src.fn.__name__}.cubin",
                f"{kernel.src.fn.__name__}.{binary_ext}",
            )

            if not os.path.exists(cubin_location):
@@ -1662,10 +1662,11 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
        When loading from cache on disk, we want to reload cubin
        files from their appropriate location on disc.
        """
        binary_ext = "hsaco" if torch.version.hip else "cubin"
        cubin_location = os.path.join(
            triton_cache_dir(self.compile_meta.get("device", 0)),
            triton_hash_to_path_key(self.kernel.hash),
            f"{self.kernel.name}.cubin",
            f"{self.kernel.name}.{binary_ext}",
        )
        if not os.path.exists(cubin_location):
            if self.kernel.cubin_raw is not None:

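Both hunks compute the cached binary path the same way; a small sketch with invented values (the directory and hash stand in for triton_cache_dir(...) and triton_hash_to_path_key(...)):

import os

cache_dir = "/tmp/triton-cache"         # stand-in for triton_cache_dir(device)
hash_key = "c0ffee"                     # stand-in for triton_hash_to_path_key(...)
kernel_name = "triton_poi_fused_add_0"  # invented
device_type = "hip"

binary_ext = "hsaco" if device_type == "hip" else "cubin"
cubin_location = os.path.join(cache_dir, hash_key, f"{kernel_name}.{binary_ext}")
# -> /tmp/triton-cache/c0ffee/triton_poi_fused_add_0.hsaco
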
@@ -1224,43 +1224,3 @@ def _build_table(
        f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
    )
    return "".join(result)


# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
    """
    Extract and format all events with stack traces in a canonical way
    for deterministic testing.
    """
    events_with_traces = []

    for event in events:
        # Extract relevant fields
        event_name = event.get("name", "")
        node_name = event["args"].get("node_name", "")
        stack_trace = event["args"].get("stack_trace", "")

        # Get the last non-empty line of the stack trace
        lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
        stack_trace = lines[-1] if lines else ""

        events_with_traces.append(
            {
                "event_name": event_name[:20],
                "node_name": node_name,
                "stack_trace": stack_trace,
                "start_time": event.get("ts", 0),
            }
        )

    # Sort by node_name for deterministic ordering
    events_with_traces.sort(key=lambda x: x["start_time"])

    # Format as a string
    lines: list[str] = []
    for evt in events_with_traces:
        lines.append(
            f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
        )

    return "\n".join(lines)

@@ -2159,7 +2159,7 @@ PyObject* initModule() {
#ifdef USE_CUDA
  torch::cuda::initModule(module);
#endif
#if defined(USE_CUDA) && !defined(USE_ROCM)
#if defined(USE_CUDA)
  ASSERT_TRUE(StaticCudaLauncher_init(module));
#endif
#ifdef USE_MPS

@@ -2,7 +2,6 @@
// This file should only be compiled if this condition holds, so it should be
// safe.
#if defined(USE_CUDNN) || defined(USE_ROCM)
#include <ATen/detail/CUDAHooksInterface.h>
#include <torch/csrc/utils/pybind.h>

#include <tuple>
@@ -33,7 +32,11 @@ version_tuple getRuntimeVersion() {
}

size_t getVersionInt() {
  return at::detail::getCUDAHooks().versionRuntimeCuDNN();
#ifndef USE_STATIC_CUDNN
  return cudnnGetVersion();
#else
  return CUDNN_VERSION;
#endif
}

} // namespace

@@ -1,7 +1,4 @@
#if defined(USE_CUDA) && !defined(USE_ROCM)
// We disable this file from being hipified because there are CUDA drivers hip
// has not implemented yet. Also, we're passing in a cubin file directly, so it
// would take more work to support ROCM anyway.
#if defined(USE_CUDA) || defined(USE_ROCM)
#include <torch/csrc/utils/pythoncapi_compat.h>

#include <ATen/Context.h>
@@ -16,6 +13,11 @@
#include <torch/csrc/utils/python_numbers.h>
#include <filesystem>
#include <optional>

#if defined(USE_ROCM)
#include <hip/hip_runtime_api.h>
#endif

/**
Implements a static launcher for triton compiled CUDA kernels.
Given a path to a cubin file, a function name, and some metadata,
@@ -56,8 +58,14 @@ const at::cuda::NVRTC& nvrtc() {

CUdeviceptr getPointer(PyObject* obj) {
  CUdeviceptr data_ptr = 0;

  if (THPUtils_checkLong(obj)) {
#if defined(USE_ROCM)
    data_ptr = reinterpret_cast<hipDeviceptr_t>(THPUtils_unpackUInt64(obj));
#else
    data_ptr = THPUtils_unpackUInt64(obj);
#endif

    return data_ptr;
  }
  if (obj == Py_None) {
@@ -73,13 +81,25 @@ CUdeviceptr getPointer(PyObject* obj) {
  TORCH_CHECK(
      THPUtils_checkLong(ret),
      "data_ptr method of Pointer object must return 64-bit int");

#if defined(USE_ROCM)
  data_ptr = reinterpret_cast<hipDeviceptr_t>(THPUtils_unpackUInt64(ret));
#else
  data_ptr = THPUtils_unpackUInt64(ret);
#endif

  if (!data_ptr)
    return data_ptr;

  CUdeviceptr dev_ptr = 0;
#if defined(USE_ROCM)
  AT_CUDA_DRIVER_CHECK(hipPointerGetAttribute(
      &dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr));
#else
  AT_CUDA_DRIVER_CHECK(nvrtc().cuPointerGetAttribute(
      &dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr));
#endif

  return dev_ptr;
}

@@ -98,6 +118,15 @@ CUfunction loadKernel(
  }
  CUmodule mod = nullptr;
  CUfunction func = nullptr;

#if defined(USE_ROCM)
  AT_CUDA_DRIVER_CHECK(hipModuleLoad(&mod, filePath.c_str()));
  AT_CUDA_DRIVER_CHECK(hipModuleGetFunction(&func, mod, funcName.c_str()));
  int shared_optin = 0;
  AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute(
      &shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device));

#else
  AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoad(&mod, filePath.c_str()));
  AT_CUDA_DRIVER_CHECK(
      nvrtc().cuModuleGetFunction(&func, mod, funcName.c_str()));
@@ -106,6 +135,9 @@ CUfunction loadKernel(
      &shared_optin,
      CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
      device));

#endif

  // Shared memory logic from triton/third-party/nvidia/backend/driver.c
  // If we're using more than 48 KB of shared memory, and we have
  // access to more than 48 KB of shared memory on the device,
@@ -124,6 +156,21 @@ CUfunction loadKernel(
      " Reducing block sizes or `num_stages` may help.");
  if (sharedMemBytes > SHARED_MEM_STATIC_MAX &&
      shared_optin > SHARED_MEM_STATIC_MAX) {
#if defined(USE_ROCM)
    AT_CUDA_DRIVER_CHECK(hipFuncSetCacheConfig(func, hipFuncCachePreferShared));
    int shared_total = 0, shared_static = 0;
    AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute(
        &shared_total,
        hipDeviceAttributeMaxSharedMemoryPerMultiprocessor,
        device));
    AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute(
        &shared_static, HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
    AT_CUDA_DRIVER_CHECK(hipFuncSetAttribute(
        func,
        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
        shared_optin - shared_static));

#else
    AT_CUDA_DRIVER_CHECK(
        nvrtc().cuFuncSetCacheConfig(func, CU_FUNC_CACHE_PREFER_SHARED));
    int shared_total = 0, shared_static = 0;
@@ -137,6 +184,7 @@ CUfunction loadKernel(
        func,
        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
        shared_optin - shared_static));
#endif
  }
  return func;
}
@@ -152,6 +200,27 @@ inline void launchKernel(
    cudaStream_t stream) {
  // cta_args is always 1 for inductor generated triton kernels,
  // so we don't need to figure out grid dimension here
#if defined(USE_ROCM)
  int device = 0;
  AT_CUDA_DRIVER_CHECK(hipGetDevice(&device));
  int warp_size = 0;
  AT_CUDA_DRIVER_CHECK(
      hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, device));

  AT_CUDA_DRIVER_CHECK(hipModuleLaunchKernel(
      func,
      gridX,
      gridY,
      gridZ,
      warp_size * numWarps, // blockDim.x
      1, // blockDim.y
      1, // blockDim.z
      sharedMemBytes,
      stream,
      args,
      nullptr));

#else
  AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
      func,
      gridX,
@@ -164,6 +233,7 @@ inline void launchKernel(
      stream,
      args,
      nullptr));
#endif
}

template <typename FINAL, typename F>
@@ -269,11 +339,20 @@ PyObject* load_kernel(PyObject* self, PyObject* args) {
  CUdevice device = static_cast<CUdevice>(device_ptr); // NOLINT
  CUfunction func = nullptr;
  func = loadKernel(filePath, funcName, sharedMemBytes, device);
  // Taken from triton/nvidia/backend/driver.c

#if defined(USE_ROCM)
  AT_CUDA_DRIVER_CHECK(
      hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, func));
  AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute(
      &n_spills, HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func));

#else
  AT_CUDA_DRIVER_CHECK(
      nvrtc().cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, func));
  AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncGetAttribute(
      &n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func));

#endif
  n_spills /= 4;
  // Return a tuple of CUFunction, n_regs, n_spills
  return Py_BuildValue(
@@ -299,7 +378,6 @@ PyObject* launch_kernel_inner(
  std::array<uint64_t, MAX_ARGS> argStorage = {};
  std::array<void*, MAX_ARGS> kernelArgs = {};
  parseKernelArgs(varArgs, argTypes, argStorage.data(), kernelArgs.data());

  launchKernel(
      func,
      gridX,
@@ -386,13 +464,25 @@ PyObject* launch_kernel(PyObject* self, PyObject* args) {
    Py_RETURN_NONE;
  }
  CUcontext pctx = nullptr;
#if defined(USE_ROCM)
  AT_CUDA_DRIVER_CHECK(hipCtxGetCurrent(&pctx));
#else
  AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
#endif

  if (!pctx) {
    // Ensure device context exists
    CUdevice device = 0;
#if defined(USE_ROCM)
    AT_CUDA_DRIVER_CHECK(hipDeviceGet(&device, 0));
    AT_CUDA_DRIVER_CHECK(hipDevicePrimaryCtxRetain(&pctx, device));
    AT_CUDA_DRIVER_CHECK(hipCtxSetCurrent(pctx));
#else
    AT_CUDA_DRIVER_CHECK(nvrtc().cuDeviceGet(&device, 0));
    AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxRetain(&pctx, device));
    AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxSetCurrent(pctx));

#endif
  }
  CUfunction func = reinterpret_cast<CUfunction>(func_ptr); // NOLINT
  cudaStream_t cudaStream = reinterpret_cast<cudaStream_t>(stream); // NOLINT

@@ -1,5 +1,5 @@
#pragma once
#if defined(USE_CUDA) && !defined(USE_ROCM)
#if defined(USE_CUDA)
#include <torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h>
#include <torch/csrc/python_headers.h>


@@ -20,10 +20,7 @@ from torch.distributed.tensor._tp_conv import (
    convolution_backward_handler,
    convolution_handler,
)
from torch.distributed.tensor._utils import (
    ExplicitRedistributionContext,
    try_find_mesh_from_args,
)
from torch.distributed.tensor._utils import try_find_mesh_from_args
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
from torch.utils._debug_mode import get_active_debug_mode
from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -202,10 +199,6 @@ class OpDispatcher:
        if participating:
            # computation that happens in the current rank of the mesh, normal case
            if output_sharding.needs_redistribute:
                if ExplicitRedistributionContext.is_active():
                    raise RuntimeError(
                        f"Implicit redistribution occurred while ExplicitRedistributionContext was active for {op_info.schema}"
                    )
                # If sharding propagation decision needs redistribute, perform redistribute
                # on args first, which could potentially modify args (i.e. allgather certain arg)
                assert output_sharding.redistribute_schema is not None

@@ -18,33 +18,6 @@ from torch.distributed.tensor.placement_types import (
from torch.utils._typing_utils import not_none


class ExplicitRedistributionContext:
    """
    Within this context manager, DTensor will refuse to perform implicit redistribution,
    instead raising an error. Manual calls to ``redistribute()`` are required wherever a redistribution
    must occur to avoid erroring. This can be used to ensure that the user is aware of all redistribution.

    Note: it is easier to use this mode on just the forward pass of a typical DTensor program, as the backwards pass
    may contain implicit redistribution calls that are not visible to the user and difficult to replace with manual
    calls. Redistribution during backward can be made explicit by writing `autograd.Function`s that are no-op
    during forward and perform a manual redistribution during backwards.
    """

    _explicit_redistribute_mode = False

    @classmethod
    def is_active(cls) -> bool:
        return cls._explicit_redistribute_mode

    def __enter__(self):
        self.prev = ExplicitRedistributionContext._explicit_redistribute_mode
        ExplicitRedistributionContext._explicit_redistribute_mode = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        ExplicitRedistributionContext._explicit_redistribute_mode = self.prev


def _explicit_order_placements(
    mesh_shape: ShapeType, placements: Sequence[Placement]
) -> Sequence[tuple[int, Placement]]:

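For reference, this is roughly how the removed context manager was used (DTensor setup elided; `dt_a` and `dt_b` are hypothetical DTensors):

with ExplicitRedistributionContext():
    # Any op whose sharding propagation would trigger an implicit
    # redistribute now raises RuntimeError instead of silently communicating.
    out = torch.mm(dt_a, dt_b)
# Outside the context, implicit redistribution is allowed again.
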
@@ -2,6 +2,8 @@ import os
import sys
from typing import Optional

from torch.utils._config_module import Config, install_config_module


# [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors.
no_data_dependent_graph_break = (
@@ -100,7 +102,11 @@ backed_size_oblivious = False
# Skip dtype check in meta registrations. Only used for systems that does its own dtype checking.
skip_dtype_check_in_meta_registrations = False

from torch.utils._config_module import install_config_module
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config(  # type: ignore[var-annotated]
    default=False,
    env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)


install_config_module(sys.modules[__name__])

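With the flag now living in torch.fx.experimental._config, it can be toggled either in-process or through the environment variable shown above (the env var spelling is kept verbatim from the diff). A minimal sketch:

import torch.fx.experimental._config as fx_experimental_config

# In-process toggle:
fx_experimental_config.enrich_profiler_metadata = True

# Or, before the process starts (note the env var name above):
#   export TORCH_ENRICH_RPOFILER_STACK_TRACE=1
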
@@ -20,6 +20,7 @@ from torch.nn.modules.module import _addindent
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer

from ._compatibility import compatibility
from .experimental import _config as fx_experimental_config
from .graph import (
    _BoxedCodeGen,
    _custom_builtins,
@@ -858,14 +859,15 @@ class {module_name}(torch.nn.Module):
        called after editing the contained ``graph``, otherwise the generated
        code of this ``GraphModule`` will be out of date.
        """
        # Do not import anything inside recompile, it might slow down the
        # function and cause perf regression. Import outside of the method instead.
        if isinstance(self._graph._codegen, _PyTreeCodeGen):
            self._in_spec = self._graph._codegen.pytree_info.in_spec
            self._out_spec = self._graph._codegen.pytree_info.out_spec

        from torch._dynamo import config as dynamo_config

        python_code = self._graph.python_code(
            root_module="self", record_func=dynamo_config.enrich_profiler_metadata
            root_module="self",
            record_func=fx_experimental_config.enrich_profiler_metadata,
        )
        self._code = python_code.src
        self._lineno_map = python_code._lineno_map
@@ -874,7 +876,7 @@ class {module_name}(torch.nn.Module):
        cls = type(self)
        co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}

        if dynamo_config.enrich_profiler_metadata:
        if fx_experimental_config.enrich_profiler_metadata:
            # Generate metadata and register for profiler augmentation
            node_metadata: dict[int, dict[str, Any]] = {}
            for i, node in enumerate(self._graph.nodes):

@@ -43,10 +43,10 @@ should_preserve_node_meta = False
# =============================================================================
# Global in-memory registry for FX metadata
# Maps module_name -> metadata dict containing lineno_map and node_metadata
_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}
_FX_METADATA_REGISTRY: dict[str, str | dict[str, Any]] = {}


def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
def _register_fx_metadata(module_name: str, metadata: str | dict[str, Any]) -> None:
    """
    Register FX metadata in the global in-memory registry.

@@ -55,7 +55,7 @@ def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:

    Args:
        module_name: The module identifier (content-addressed filename)
        metadata: Metadata dict containing lineno_map, node_metadata, and source_code
        metadata: Metadata dict containing lineno_map, node_metadata, and source_code. If a str, it's a json dump that can be json loaded as a dict.
    """
    # TODO: add logging to tlparse
    _FX_METADATA_REGISTRY[module_name] = metadata

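After this change the registry accepts either form; a minimal sketch of registering a pre-serialized payload (module name and fields are invented, and the import path assumes the function stays in torch/fx/traceback.py as shown above):

import json
from torch.fx.traceback import _register_fx_metadata

payload = json.dumps({"lineno_map": {}, "node_metadata": {}, "source_code": ""})
_register_fx_metadata("fx_generated_module_0", payload)  # stored as a str
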
@@ -139,22 +139,3 @@ AT_FORALL_COMPLEX_TYPES
toString
<<
toUnderlying

# torch/headeronly/core/Dispatch_v2.h
THO_DISPATCH_V2_TMPL
THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL
THO_DISPATCH_CASE_TMPL
THO_DISPATCH_SWITCH_TMPL
# AT_WRAP, THO_AP_VAR_TMPL, AT_CONCAT, AT_CONCAT_AUX, AT_EXPAND are tested through THO_DISPATCH_V2_TMPL
# scalar_type is tested through THO_DISPATCH_SWITCH_TMPL
AT_FLOAT8_TYPES
AT_INTEGRAL_TYPES
AT_FLOATING_TYPES
AT_BAREBONES_UNSIGNED_TYPES
AT_INTEGRAL_TYPES_V2
AT_COMPLEX_TYPES
AT_QINT_TYPES
AT_ALL_TYPES
AT_ALL_TYPES_AND_COMPLEX
THO_DISPATCH_V2
# THO_EMPTY, THO_DISPATCH_CASE, THO_DISPATCH_SWITCH, THO_PRIVATE_CASE_TYPE_USING_HINT are tested through THO_DISPATCH_V2

@@ -1,73 +0,0 @@
#pragma once

#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>

// THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL is same as
// AT_PRIVATE_CASE_TYPE_USING_HINT but with a custom PRELUDE macro:
#define THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL(PRELUDE, enum_type, HINT, ...) \
  case enum_type: { \
    PRELUDE(enum_type); \
    using HINT [[maybe_unused]] = \
        torch::headeronly::impl::ScalarTypeToCPPTypeT<enum_type>; \
    return __VA_ARGS__(); \
  }

// THO_DISPATCH_CASE_TMPL is same as AT_DISPATCH_CASE but with a
// custom CASE_TYPE_USING_HINT macro:
#define THO_DISPATCH_CASE_TMPL(CASE_TYPE_USING_HINT, enum_type, ...) \
  CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)

namespace detail {
inline torch::headeronly::ScalarType scalar_type(
    torch::headeronly::ScalarType s) {
  return s;
}
} // namespace detail

// THO_DISPATCH_SWITCH_TMPL is same as AT_DISPATCH_SWITCH but with
// custom PRELUDE and CHECK_NOT_IMPLEMENTED macros:
#define THO_DISPATCH_SWITCH_TMPL( \
    PRELUDE, CHECK_NOT_IMPLEMENTED, TYPE, NAME, ...) \
  [&] { \
    const auto& the_type = TYPE; \
    constexpr const char* at_dispatch_name = NAME; \
    /* don't use TYPE again in case it is an expensive or side-effect op */ \
    torch::headeronly::ScalarType _st = ::detail::scalar_type(the_type); \
    PRELUDE(at_dispatch_name, _st); \
    C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \
    switch (_st) { \
      __VA_ARGS__ \
      default: \
        CHECK_NOT_IMPLEMENTED( \
            false, \
            '"', \
            at_dispatch_name, \
            "\" not implemented for '", \
            torch::headeronly::toString(_st), \
            "'"); \
    } \
    C10_DIAGNOSTIC_POP() \
  }()

// THO_EMPTY is a helper macro that discards its arguments.
#define THO_EMPTY(...)

// THO_PRIVATE_CASE_TYPE_USING_HINT is same as
// AT_PRIVATE_CASE_TYPE_USING_HINT with call to macro
// AT_PRIVATE_CHECK_SELECTIVE_BUILD removed.
#define THO_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
  THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL(THO_EMPTY, enum_type, HINT, __VA_ARGS__)

// THO_DISPATCH_SWITCH is same as AT_DISPATCH_SWITCH with call to
// macro RECORD_KERNEL_FUNCTION_DTYPE removed and using
// STD_TORCH_CHECK instead of TORCH_CHECK_NOT_IMPLEMENTED.
#define THO_DISPATCH_SWITCH(TYPE, NAME, ...) \
  THO_DISPATCH_SWITCH_TMPL(THO_EMPTY, STD_TORCH_CHECK, TYPE, NAME, __VA_ARGS__)

// THO_DISPATCH_CASE is same as AT_DISPATCH_CASE but using
// THO_PRIVATE_CASE_TYPE_USING_HINT instead of
// AT_PRIVATE_CASE_TYPE_USING_HINT.
#define THO_DISPATCH_CASE(enum_type, ...) \
  THO_DISPATCH_CASE_TMPL( \
      THO_PRIVATE_CASE_TYPE_USING_HINT, enum_type, __VA_ARGS__)
@ -1,170 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/core/Dispatch.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
// This file provides THO_DISPATCH_V2_TMPL macro that is a generalized
|
||||
// version of the original AT_DISPATCH_V2 (see ATen/Dispatch_v2.h for
|
||||
// documentation): THO_DISPATCH_V2_TMPL extends AT_DISPATCH_V2 with
|
||||
// extra DISPATCH_SWITCH and DISPATCH_CASE arguments for specifying
|
||||
// custom implementations of the original AT_DISPATCH_SWITCH and
|
||||
// AT_DISPATCH_CASE macros. Use the provided macros
|
||||
// THO_DISPATCH_SWITCH_TMPL and THO_DISPATCH_CASE_TMPL to define the
|
||||
// custom implementations of the switch and case macros, respectively.
|
||||
|
||||
// Public API macros
|
||||
|
||||
// THO_DISPATCH_V2_TMPL is same as AT_DISPATCH_V2 but with custom
|
||||
// DISPATCH_SWITCH and DISPATCH_CASE macro arguments:
|
||||
#define THO_DISPATCH_V2_TMPL( \
|
||||
DISPATCH_SWITCH, DISPATCH_CASE, TYPE, NAME, BODY, ...) \
|
||||
DISPATCH_SWITCH( \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
THO_AP_VAR_TMPL(DISPATCH_CASE, AT_WRAP(BODY), TYPE, __VA_ARGS__))
|
||||
|
||||
// THO_DISPATCH_V2 is same as AT_DISPATCH_V2 but using
|
||||
// THO_DISPATCH_SWITCH and THO_DISPATCH_CASE instead of
|
||||
// AT_DISPATCH_SWITCH and AT_DISPATCH_CASE, respectively.
|
||||
#define THO_DISPATCH_V2(TYPE, NAME, BODY, ...) \
|
||||
THO_DISPATCH_V2_TMPL( \
|
||||
THO_DISPATCH_SWITCH, THO_DISPATCH_CASE, TYPE, NAME, BODY, __VA_ARGS__)
|
||||
|
||||
// Type collection macros

// This macro lets you pass an arbitrary expression that may contain internal
// commas to another macro without the commas causing the expression
// to be interpreted as multiple arguments
#define AT_WRAP(...) __VA_ARGS__

#define AT_FLOAT8_TYPES                           \
  torch::headeronly::ScalarType::Float8_e5m2,     \
  torch::headeronly::ScalarType::Float8_e5m2fnuz, \
  torch::headeronly::ScalarType::Float8_e4m3fn,   \
  torch::headeronly::ScalarType::Float8_e4m3fnuz, \
  torch::headeronly::ScalarType::Float8_e8m0fnu

#define AT_INTEGRAL_TYPES                                                   \
  torch::headeronly::ScalarType::Byte, torch::headeronly::ScalarType::Char, \
  torch::headeronly::ScalarType::Int, torch::headeronly::ScalarType::Long,  \
  torch::headeronly::ScalarType::Short
#define AT_FLOATING_TYPES \
  torch::headeronly::ScalarType::Double, torch::headeronly::ScalarType::Float
#define AT_BAREBONES_UNSIGNED_TYPES       \
  torch::headeronly::ScalarType::UInt16,  \
  torch::headeronly::ScalarType::UInt32,  \
  torch::headeronly::ScalarType::UInt64
#define AT_INTEGRAL_TYPES_V2 \
  AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
#define AT_COMPLEX_TYPES                        \
  torch::headeronly::ScalarType::ComplexDouble, \
  torch::headeronly::ScalarType::ComplexFloat
#define AT_QINT_TYPES                                                           \
  torch::headeronly::ScalarType::QInt8, torch::headeronly::ScalarType::QUInt8, \
  torch::headeronly::ScalarType::QInt32
// NB: not *actually* all types
#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
#define AT_ALL_TYPES_AND_COMPLEX \
  AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
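// Illustrative note (editor's addition): the collection macros are plain
// comma-separated ScalarType lists, so they compose with AT_EXPAND; e.g.
// AT_ALL_TYPES_AND_COMPLEX expands to Byte, Char, Int, Long, Short, Double,
// Float, ComplexDouble, ComplexFloat -- a ready-made trailing type list for
// THO_DISPATCH_V2.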
// Helper macros

// THO_AP_VAR_TMPL is the same as AT_AP_VAR but with a custom
// DISPATCH_CASE macro argument:
#define THO_AP_VAR_TMPL(C, N, T, ...) \
  AT_EXPAND(                          \
      AT_CONCAT(THO_AP, AT_NUM_ARGS(__VA_ARGS__))(C, AT_WRAP(N), __VA_ARGS__))
#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
#define AT_CONCAT_AUX(a, b) a##b
#define AT_EXPAND(X) X

// Ensure we never have too many scalar types for the expansion here to
// support. To bump this, you must regenerate the macros below.
static_assert(static_cast<int>(torch::headeronly::ScalarType::NumOptions) < 60);

// Python code to regenerate the generated code below:
#if 0

num_args = 60

nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
args = ', '.join(f'_{i}' for i in range(1, num_args+1))

print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')

for i in range(1, num_args+1):
    args = ', '.join(f'_{i}' for i in range(1, i+1))
    cases = ' '.join([f'C(_{j}, N)' for j in range(1, i+1)])
    print(f'#define THO_AP{i}(C, N, {args}) {cases}')

#endif
// Begin generated code
// clang-format off
#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N
#define THO_AP1(C, N, _1) C(_1, N)
#define THO_AP2(C, N, _1, _2) C(_1, N) C(_2, N)
#define THO_AP3(C, N, _1, _2, _3) C(_1, N) C(_2, N) C(_3, N)
#define THO_AP4(C, N, _1, _2, _3, _4) C(_1, N) C(_2, N) C(_3, N) C(_4, N)
#define THO_AP5(C, N, _1, _2, _3, _4, _5) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N)
#define THO_AP6(C, N, _1, _2, _3, _4, _5, _6) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N)
#define THO_AP7(C, N, _1, _2, _3, _4, _5, _6, _7) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N)
#define THO_AP8(C, N, _1, _2, _3, _4, _5, _6, _7, _8) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N)
#define THO_AP9(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N)
#define THO_AP10(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N)
#define THO_AP11(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N)
#define THO_AP12(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N)
#define THO_AP13(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N)
#define THO_AP14(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N)
#define THO_AP15(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N)
#define THO_AP16(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N)
#define THO_AP17(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N)
#define THO_AP18(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N)
#define THO_AP19(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N)
#define THO_AP20(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N)
#define THO_AP21(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N)
#define THO_AP22(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N)
#define THO_AP23(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N)
#define THO_AP24(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N)
#define THO_AP25(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N)
#define THO_AP26(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N)
#define THO_AP27(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N)
#define THO_AP28(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N)
#define THO_AP29(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N)
#define THO_AP30(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N)
#define THO_AP31(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N)
#define THO_AP32(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N)
#define THO_AP33(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N)
#define THO_AP34(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N)
#define THO_AP35(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N)
#define THO_AP36(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N)
#define THO_AP37(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N)
#define THO_AP38(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N)
#define THO_AP39(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N)
#define THO_AP40(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N)
#define THO_AP41(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N)
#define THO_AP42(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N)
#define THO_AP43(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N)
#define THO_AP44(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N)
#define THO_AP45(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N)
#define THO_AP46(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N)
#define THO_AP47(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N)
#define THO_AP48(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N)
#define THO_AP49(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N)
#define THO_AP50(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N)
#define THO_AP51(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N)
#define THO_AP52(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N)
#define THO_AP53(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N)
#define THO_AP54(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N)
#define THO_AP55(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N)
#define THO_AP56(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N)
#define THO_AP57(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N) C(_57, N)
#define THO_AP58(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N) C(_57, N) C(_58, N)
#define THO_AP59(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N) C(_57, N) C(_58, N) C(_59, N)
#define THO_AP60(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N) C(_57, N) C(_58, N) C(_59, N) C(_60, N)
// End generated code
// clang-format on
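// Illustrative expansion (editor's addition): AT_NUM_ARGS counts its
// arguments by shifting the descending number list above, and AT_CONCAT then
// selects the matching applier macro; with generic tokens a, b, c:
//
//   AT_NUM_ARGS(a, b, c)    -> 3
//   THO_AP3(C, N, a, b, c)  -> C(a, N) C(b, N) C(c, N)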
@ -2838,9 +2838,6 @@ def batch_norm(
        # pyrefly: ignore [bad-argument-type]
        _verify_batch_size(input.size())

    if eps <= 0.0:
        raise ValueError(f"batch_norm eps must be positive, but got {eps}")

    return torch.batch_norm(
        input,
        weight,
@ -1,9 +1,11 @@
# mypy: allow-untyped-defs
import functools
import json
import operator
import re
from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, Optional, TYPE_CHECKING

from torch.autograd.profiler import profile
@ -402,13 +404,31 @@ def _init_for_cuda_graphs() -> None:
        pass


class ContextType(Enum):
    """Types of contexts in the profiler stack."""

    FX_GRAPH = "filename"
    FX_NODE = "node"
    COMPILED_GRAPH = "compiled_graph"
    INDUCTOR_NODE = "inductor_node"


def get_parent_context_type(context_type: ContextType) -> Optional[ContextType]:
    if context_type == ContextType.FX_NODE:
        return ContextType.FX_GRAPH
    elif context_type == ContextType.INDUCTOR_NODE:
        return ContextType.COMPILED_GRAPH
    else:
        return None
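# Illustrative sketch (editor's addition, not part of this diff): the helper
# encodes two independent parent chains, one for FX and one for Inductor;
# these values follow directly from the code above:
#
#   get_parent_context_type(ContextType.FX_NODE)        # -> ContextType.FX_GRAPH
#   get_parent_context_type(ContextType.INDUCTOR_NODE)  # -> ContextType.COMPILED_GRAPH
#   get_parent_context_type(ContextType.FX_GRAPH)       # -> None (top-level context)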
@dataclass
class TimelineEvent:
    """Represents an event in the profiler timeline."""

    timestamp: int
    event_type: Literal["start", "end", "regular"]
    marker_type: Optional[Literal["filename", "node"]]
    marker_type: Optional[ContextType]
    identifier: Optional[str | int]
    event: dict[str, Any]

@ -417,7 +437,7 @@ class TimelineEvent:
class ContextStackEntry:
    """Represents a context (filename or node) in the stack."""

    context_type: Literal["filename", "node"]
    context_type: ContextType
    identifier: str | int
    metadata: Optional[dict]
    tid: Optional[int] = None  # Thread ID associated with this context
@ -438,6 +458,8 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
    Returns:
        Dict mapping recorded event names to their aten operations with added stack traces
    """
    from torch._inductor.output_code import CALL_COMPILED_PREFIX
    from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX
    from torch.fx.traceback import _FX_METADATA_REGISTRY

    trace_events = traced_data.get("traceEvents", [])
@ -447,7 +469,7 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):

    def is_fx_marker_event(event):
        return (
            event.get("cat") == "cpu_op"
            event.get("cat") in ("cpu_op", "user_annotation")
            and event.get("name", "").startswith("## ")
            and event.get("name", "").endswith(" ##")
        )
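# Illustrative examples (editor's addition; the marker payloads shown are
# hypothetical): marker events are profiler annotations whose names are
# wrapped in "## ... ##", e.g.
#   "## <FX_GRAPH_MODULE_FILE_PREFIX>...py ##"   - an FX graph (filename) marker
#   "## 3 ##"                                    - an FX node index marker
#   "## inductor_kernel:<kernel name> ##"        - an Inductor kernel marker
# The actual prefixes come from FX_GRAPH_MODULE_FILE_PREFIX and
# CALL_COMPILED_PREFIX imported above.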
@ -469,14 +491,27 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
        if is_fx_marker_event(event):
            content = event["name"][3:-3]

            if content.endswith(".py"):
                append_fx_marker_event("filename", content, event)
            # Try different event types
            if content.startswith(FX_GRAPH_MODULE_FILE_PREFIX) and content.endswith(
                ".py"
            ):
                # FX graph event
                append_fx_marker_event(ContextType.FX_GRAPH, content, event)
            elif content.startswith(CALL_COMPILED_PREFIX):
                # Inductor compiled graph event
                append_fx_marker_event(ContextType.COMPILED_GRAPH, content, event)
            elif content.startswith("inductor_kernel:"):
                append_fx_marker_event(
                    ContextType.INDUCTOR_NODE, content[len("inductor_kernel:") :], event
                )
            else:
                # Try to parse as node index for FX graph
                # TODO: change to start with fx_node
                try:
                    node_index = int(content)
                    append_fx_marker_event(ContextType.FX_NODE, node_index, event)
                except ValueError:
                    pass
                append_fx_marker_event("node", node_index, event)  # type: ignore[possibly-undefined]

        else:
            # Regular event that needs augmentation
@ -495,23 +530,37 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
            case "start":
                assert timeline_event.identifier is not None

                if timeline_event.marker_type == "filename":
                if timeline_event.marker_type in (
                    ContextType.FX_GRAPH,
                    ContextType.COMPILED_GRAPH,
                ):
                    assert isinstance(timeline_event.identifier, str)
                    # Push filename context - query metadata registry on-demand
                    metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
                    tid = timeline_event.event.get("tid")

                    # TODO: add get method in traceback to try - catch and get
                    if isinstance(metadata, str):
                        metadata = json.loads(metadata)
                    context_stack.append(
                        ContextStackEntry(
                            "filename", timeline_event.identifier, metadata, tid
                            timeline_event.marker_type,
                            timeline_event.identifier,
                            metadata,
                            tid,
                        )
                    )
                elif timeline_event.marker_type == "node":
                elif timeline_event.marker_type in (
                    ContextType.FX_NODE,
                    ContextType.INDUCTOR_NODE,
                ):
                    # Find the current filename from stack
                    current_file_metadata = None
                    tid = timeline_event.event.get("tid")
                    parent_type = get_parent_context_type(timeline_event.marker_type)
                    for ctx_entry in reversed(context_stack):
                        if (
                            ctx_entry.context_type == "filename"
                            ctx_entry.context_type == parent_type
                            and ctx_entry.tid == tid
                        ):
                            current_file_metadata = ctx_entry.metadata
@ -520,14 +569,39 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
                    if current_file_metadata:
                        node_metadata = current_file_metadata.get("node_metadata", {})
                        if timeline_event.identifier in node_metadata:
                            node_meta: Optional[dict] = node_metadata[
                                timeline_event.identifier
                            ]
                            context_stack.append(
                                ContextStackEntry(
                                    "node", timeline_event.identifier, node_meta, tid
                            if ctx_entry.context_type == ContextType.FX_NODE:
                                node_meta: Optional[dict] = node_metadata[
                                    timeline_event.identifier
                                ]
                                context_stack.append(
                                    ContextStackEntry(
                                        ContextType.FX_NODE,
                                        timeline_event.identifier,
                                        node_meta,
                                        tid,
                                    )
                                )

                        if timeline_event.marker_type == ContextType.INDUCTOR_NODE:
                            # Look up stack traces for this kernel
                            # TODO: make a dictionary that maps from compiled key to stack traces dictionary
                            stack_traces = current_file_metadata.get(
                                timeline_event.identifier, []
                            )
                            if stack_traces:
                                # Store all stack traces as metadata
                                node_meta: Optional[dict] = {
                                    "stack_trace": stack_traces,
                                    "name": timeline_event.identifier,
                                }
                                context_stack.append(
                                    ContextStackEntry(
                                        ContextType.INDUCTOR_NODE,
                                        timeline_event.identifier,
                                        node_meta,
                                        tid,
                                    )
                                )

            case "end":
                # Pop from stack - search backwards to find matching context
@ -551,7 +625,10 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
            for ctx_entry in reversed(context_stack):
                # Only apply metadata from contexts with matching tid
                if ctx_entry.tid == event_tid:
                    if ctx_entry.context_type == "node" and ctx_entry.metadata:
                    if (
                        ctx_entry.context_type == ContextType.FX_NODE
                        and ctx_entry.metadata
                    ):
                        current_stack_trace = ctx_entry.metadata.get(
                            "stack_trace", "No model stack trace available"
                        )
@ -559,6 +636,19 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
                        # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes
                        # if nodes are nested, e.g. in nested graph modules
                        break
                    elif (
                        ctx_entry.context_type == ContextType.INDUCTOR_NODE
                        and ctx_entry.metadata
                    ):
                        # For inductor nodes, stack_trace is a list of traces
                        stack_traces_list = ctx_entry.metadata.get(
                            "stack_trace", []
                        )
                        if stack_traces_list:
                            # Store as a list - each trace gets its own entry
                            current_stack_trace = stack_traces_list
                        current_node_name = ctx_entry.metadata.get("name", "")
                        break

            # Augment the event
            if current_stack_trace or current_node_name:
@ -567,3 +657,81 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
                args["stack_trace"] = current_stack_trace
                if current_node_name:
                    args["node_name"] = current_node_name

import tempfile
import os

# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
    """
    Extract and format all events with stack traces in a canonical way
    for deterministic testing.
    """
    events_with_traces = []

    for event in events:
        # Extract relevant fields
        event_name = event.get("name", "")
        node_name = event["args"].get("node_name", "")
        stack_trace = event["args"].get("stack_trace", "")

        if isinstance(stack_trace, list):
            stack_trace = "\n".join(stack_trace)

        # Get the last non-empty line of the stack trace
        lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
        stack_trace = lines[-1] if lines else ""

        events_with_traces.append(
            {
                "event_name": event_name[:20],
                "node_name": node_name,
                "stack_trace": stack_trace,
                "start_time": event.get("ts", 0),
            }
        )

    # Sort by start time for deterministic ordering
    events_with_traces.sort(key=lambda x: x["start_time"])

    # Format as a string
    lines: list[str] = []
    for evt in events_with_traces:
        lines.append(
            f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
        )

    return "\n".join(lines)


def _enrich_profiler_traces(prof):
    """
    Helper function to extract and augment profiler events with stack traces.

    Args:
        prof: A torch.profiler.profile object

    Returns:
        A string representing enriched events
    """
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        trace_file = f.name

    try:
        prof.export_chrome_trace(trace_file)

        with open(trace_file) as f:
            trace_data = json.load(f)

        map_recorded_events_to_aten_ops_with_stack_trace(trace_data)

        events = []
        for event in trace_data["traceEvents"]:
            if "args" in event and "stack_trace" in event["args"]:
                events.append(event)

        actual_traces = _canonicalize_profiler_events(events)
        return actual_traces
    finally:
        if os.path.exists(trace_file):
            os.remove(trace_file)
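# Illustrative usage sketch (editor's addition, not part of this diff),
# assuming a hypothetical profiled callable `fn`:
#
#   from torch.profiler import profile
#
#   with profile() as prof:
#       fn()
#
#   print(_enrich_profiler_traces(prof))
#   # one line per enriched event, e.g.
#   # event=aten::mm node=mm stack_trace=x = torch.mm(a, b)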
@ -961,38 +961,6 @@ def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad
                     desc='zero_batch')]


def module_error_inputs_torch_nn_BatchNorm1d_2d_3d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    if module_info.module_cls == torch.nn.BatchNorm1d:
        input_shape = (2, 10)
    elif module_info.module_cls == torch.nn.BatchNorm2d:
        input_shape = (2, 10, 5, 5)
    else:
        input_shape = (2, 10, 4, 4, 4)

    return [
        ErrorModuleInput(
            ModuleInput(
                constructor_input=FunctionInput(10, eps=-1.0),
                forward_input=FunctionInput(make_input(input_shape)),
            ),
            error_on=ModuleErrorEnum.FORWARD_ERROR,
            error_type=ValueError,
            error_regex="eps must be positive"
        ),
        ErrorModuleInput(
            ModuleInput(
                constructor_input=FunctionInput(10, eps=0.0),
                forward_input=FunctionInput(make_input(input_shape)),
            ),
            error_on=ModuleErrorEnum.FORWARD_ERROR,
            error_type=ValueError,
            error_regex="eps must be positive"
        ),
    ]


def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, training, **kwargs):
    N = kwargs['N']
    lazy = kwargs.get('lazy', False)
@ -3462,7 +3430,6 @@ module_db: list[ModuleInfo] = [
    ModuleInfo(torch.nn.BatchNorm1d,
               train_and_eval_differ=True,
               module_inputs_func=module_inputs_torch_nn_BatchNorm1d,
               module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d,
               skips=(
                   # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
                   # RuntimeError: tried to get Double out of SymInt
@ -3481,7 +3448,6 @@ module_db: list[ModuleInfo] = [
    ModuleInfo(torch.nn.BatchNorm2d,
               train_and_eval_differ=True,
               module_inputs_func=module_inputs_torch_nn_BatchNorm2d,
               module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d,
               skips=(
                   # See https://github.com/pytorch/pytorch/issues/134580
                   DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format', active_if=operator.itemgetter('training')),
@ -3502,7 +3468,6 @@ module_db: list[ModuleInfo] = [
    ModuleInfo(torch.nn.BatchNorm3d,
               train_and_eval_differ=True,
               module_inputs_func=module_inputs_torch_nn_BatchNorm3d,
               module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d,
               skips=(
                   # not supported on MPS backend
                   DecorateInfo(skipMPS),