Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-05 16:44:58 +08:00

Compare commits
10 Commits (quote-pyte... ... bf/bug-sta...)
| Author | SHA1 | Date |
|---|---|---|
| | 8cef91fb74 | |
| | 527b1109a8 | |
| | 3144713325 | |
| | eefa16342c | |
| | d02f68f484 | |
| | 68eb55c4b2 | |
| | 8d4b8ab430 | |
| | afd50bdd29 | |
| | 56dfd4c74b | |
| | 24db5c4451 | |
.github/actions/pytest-cache-download/action.yml (vendored, 12 lines changed)

@@ -38,9 +38,9 @@ runs:
run: |
python3 .github/scripts/pytest_cache.py \
--download \
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
--pr_identifier "$GITHUB_REF" \
--job_identifier "$JOB_IDENTIFIER" \
--temp_dir "$RUNNER_TEMP" \
--repo "$REPO" \
--bucket "$BUCKET" \
--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
--pr_identifier $GITHUB_REF \
--job_identifier $JOB_IDENTIFIER \
--temp_dir $RUNNER_TEMP \
--repo $REPO \
--bucket $BUCKET \
.github/actions/pytest-cache-upload/action.yml (vendored, 16 lines changed)

@@ -47,11 +47,11 @@ runs:
run: |
python3 .github/scripts/pytest_cache.py \
--upload \
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
--pr_identifier "$GITHUB_REF" \
--job_identifier "$JOB_IDENTIFIER" \
--sha "$SHA" \
--test_config "$TEST_CONFIG" \
--shard "$SHARD" \
--repo "$REPO" \
--temp_dir "$RUNNER_TEMP" \
--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
--pr_identifier $GITHUB_REF \
--job_identifier $JOB_IDENTIFIER \
--sha $SHA \
--test_config $TEST_CONFIG \
--shard $SHARD \
--repo $REPO \
--temp_dir $RUNNER_TEMP \
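The two action hunks above differ only in whether the environment expansions in the `run:` step are quoted. A minimal Python sketch (not part of the diff; the path is hypothetical) of the word-splitting that quoting guards against:

```python
import shlex

cache_dir = "/home/runner/my workspace/cache"  # hypothetical path containing a space

# Unquoted expansion (like $CACHE_DIR above): the path splits into two argv entries.
print(shlex.split(f"python3 pytest_cache.py --cache_dir {cache_dir}"))
# ['python3', 'pytest_cache.py', '--cache_dir', '/home/runner/my', 'workspace/cache']

# Quoted expansion (like "$CACHE_DIR" above): the path stays a single argument.
print(shlex.split(f'python3 pytest_cache.py --cache_dir "{cache_dir}"'))
# ['python3', 'pytest_cache.py', '--cache_dir', '/home/runner/my workspace/cache']
```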
.github/workflows/inductor-unittest.yml (vendored, 8 lines changed)

@@ -115,10 +115,10 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
]}
secrets: inherit
.github/workflows/inductor.yml (vendored, 14 lines changed)

@@ -84,13 +84,13 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
]}
build-additional-packages: "vision audio torchao"
@@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
@@ -23,8 +23,6 @@ C10_DIAGNOSTIC_POP()
#endif
namespace at {
namespace {
/*
These const variables defined the fp32 precisions for different backend
We have "generic", "cuda", "mkldnn" backend now and we can choose fp32

@@ -41,16 +39,6 @@ namespace {
->rnn
*/
C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
TORCH_WARN_ONCE(
"Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' "
"or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, "
"torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see "
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
);
}
} // namespace
Float32Backend str2backend(const std::string& name) {
if (name == "generic")
return Float32Backend::GENERIC;

@@ -206,7 +194,6 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
} else {
return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
}
warn_deprecated_fp32_precision_api();
return allow_tf32_cudnn;
}

@@ -214,7 +201,6 @@ void Context::setAllowTF32CuDNN(bool b) {
setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
allow_tf32_cudnn = b;
warn_deprecated_fp32_precision_api();
}
void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {

@@ -325,7 +311,6 @@ bool Context::allowTF32CuBLAS() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
"We suggest only using the new API to set the TF32 flag. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return allow_tf32_new;
}

@@ -349,7 +334,6 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
"We suggest only using the new API for matmul precision. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return float32_matmul_precision;
}

@@ -377,7 +361,6 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op)
void Context::setFloat32MatmulPrecision(const std::string &s) {
auto match = [this](const std::string & s_) {
warn_deprecated_fp32_precision_api();
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
if (s_ == "highest") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
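The hunks above remove the repeated deprecation warning from several legacy TF32 accessors. For reference, a short hedged sketch of the two styles the warning text contrasts (API names taken verbatim from the warning string; availability depends on the PyTorch version):

```python
import torch

# Legacy flags, which the warning says will be deprecated after PyTorch 2.9:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# New per-backend / per-op precision controls referenced by the warning:
torch.backends.cuda.matmul.fp32_precision = "tf32"   # or "ieee" to keep full fp32 matmuls
torch.backends.cudnn.conv.fp32_precision = "tf32"
```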
@@ -59,6 +59,24 @@
// forward declare
class cublasCommonArgs;
#ifndef _WIN32
namespace fbgemm_gpu {
// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
// To update supported ops means a submodule bump, which is.. painful. Instead, we
// can simply forward-declare the methods we want to use.. Works at least as a short-term
// thing, but should still be fixed somewhere/somehow.
at::Tensor f4f4bf16(
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
std::optional<at::Tensor>,
bool use_mx);
} // namespace fbgemm_gpu
#endif
using at::blas::ScalingType;
using at::blas::SwizzleType;

@@ -1087,26 +1105,47 @@ _scaled_mxfp4_mxfp4(
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
Tensor& out) {
#ifndef USE_ROCM
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
#endif
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
#else
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
// Packed FP4 format means actual-K = 2 * reported-K -- adjust
auto K_multiplier = 2;
#ifdef USE_ROCM
// AMD
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
#else
// NVIDIA
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
#endif
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
#ifdef USE_ROCM
// AMD
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
#else
// NVIDIA
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
#endif
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
"For Blockwise scaling both scales should be contiguous");
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
#ifdef USE_ROCM
// AMD
auto scaling_choice_a = ScalingType::BlockWise1x32;
auto scaling_choice_b = ScalingType::BlockWise1x32;

@@ -1121,11 +1160,30 @@ _scaled_mxfp4_mxfp4(
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
#else
// NVIDIA
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
// but we have one we need to use. Two clear options are to copy into
// our output (slow), or use a move-assignment-operator (faster).
// However, the compiler can complain about the explicit move preventing
// copy elision because the return from f4f4bf16 is a temporary object.
// So we don't explicitly move, and trust the compiler here...
// In the longer term this should be fixed on the FBGemm side.
out = fbgemm_gpu::f4f4bf16(
mat_a,
mat_b.transpose(-2, -1),
scale_a,
scale_b,
std::nullopt, /* global_scale */
true /* use_mx */
);
return out;
#endif
#endif
}
Tensor&

@@ -1250,17 +1308,20 @@ _scaled_mm_cuda_v2_out(
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
}
// Handle fp4 packed-K dimension
int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;
TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
" but got ", bias->numel());
TORCH_CHECK_VALUE(
mat_a.sizes()[1] % 16 == 0,
K_multiplier * mat_a.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes()[0],
"x",
mat_a.sizes()[1],
K_multiplier * mat_a.sizes()[1],
").");
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
mat_b.sizes()[1], ") must be divisible by 16");
// TODO(slayton): Existing checks, not sure if they should really be here.
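The K_multiplier handling above accounts for the packed FP4 format: each Float4_e2m1fn_x2 element stores two values, so the effective K is twice the reported trailing dimension. A small arithmetic sketch of the divisibility check (hypothetical shapes, mirroring the C++ check above):

```python
# Mirrors `K_multiplier * mat_a.sizes()[1] % 16 == 0` from the hunk above.
def trailing_dim_ok(reported_k: int, is_packed_fp4: bool) -> bool:
    k_multiplier = 2 if is_packed_fp4 else 1
    return (k_multiplier * reported_k) % 16 == 0

# A packed-fp4 operand with reported K=8 really holds 16 values along K, so it passes;
# an unpacked operand with K=8 does not.
print(trailing_dim_ok(8, is_packed_fp4=True))   # True
print(trailing_dim_ok(8, is_packed_fp4=False))  # False
```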
@@ -4,6 +4,7 @@ import os
import tempfile
from threading import Event
import torch._inductor.config as config
from torch._inductor.compile_worker.subproc_pool import (
raise_testexc,
SubprocException,

@@ -16,9 +17,12 @@ from torch.testing._internal.inductor_utils import HAS_CPU
class TestCompileWorker(TestCase):
def make_pool(self, size):
return SubprocPool(size)
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_basic_jobs(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
a = pool.submit(operator.add, 100, 1)
b = pool.submit(operator.sub, 100, 1)

@@ -29,7 +33,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_exception(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
a = pool.submit(raise_testexc)
with self.assertRaisesRegex(

@@ -42,7 +46,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_crash(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
with self.assertRaises(Exception):
a = pool.submit(os._exit, 1)

@@ -58,7 +62,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_quiesce(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
a = pool.submit(operator.add, 100, 1)
pool.quiesce()

@@ -75,7 +79,7 @@ class TestCompileWorker(TestCase):
os.environ["ROLE_RANK"] = "0"
with tempfile.NamedTemporaryFile(delete=True) as temp_log:
os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
pool.submit(operator.add, 100, 1)
self.assertEqual(os.path.exists(temp_log.name), True)

@@ -83,6 +87,12 @@ class TestCompileWorker(TestCase):
pool.shutdown()

@config.patch("quiesce_async_compile_time", 0.1)
class TestCompileWorkerWithTimer(TestCompileWorker):
def make_pool(self, size):
return SubprocPool(size, quiesce=True)

class TestTimer(TestCase):
def test_basics(self):
done = Event()
@@ -500,8 +500,13 @@ class PaddingTest(TestCaseBase):
forward_wrapper = wrapper_codes[0]
# make sure the load for softmax is aligned
if bias:
# addmm -> mm + bias and bias is fused with softmax
softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)"
else:
softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)"
self.assertTrue(
"tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
softmax_load_str in forward_wrapper,
f"forward_wrapper: {forward_wrapper}",
)
@@ -15280,7 +15280,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_native_layer_norm_relu",
"triton_poi_fused_addmm_native_layer_norm",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]

@@ -15293,7 +15293,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_LayerNorm_ReLU",
"triton_poi_fused_LayerNorm_Linear_ReLU",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]
@@ -1826,9 +1826,14 @@ def run_test_module(
test_name = test.name
# Printing the date here can help diagnose which tests are slow
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
start = time.perf_counter()
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}][{start}]")
handler = CUSTOM_HANDLERS.get(test_name, run_test)
return_code = handler(test, test_directory, options)
end = time.perf_counter()
print_to_stderr(
f"Finished {str(test)} ... [{datetime.now()}][{end}], took {(end - start) / 60:.2f}min"
)
assert isinstance(return_code, int) and not isinstance(return_code, bool), (
f"While running {str(test)} got non integer return code {return_code}"
)
@@ -7413,6 +7413,140 @@ class TestCudaDeviceParametrized(TestCase):
)

class TestFXMemoryProfiler(TestCase):
"""Tests for memory profiler augmentation with original stack traces."""
def collect_frames(
self, augmented_snapshot, collect_device_traces=True, collect_segments=True
):
"""Collects all frames that has node metadata from a memory snapshot."""
# Collect all frames with FX metadata
fx_frames = []
# Check device traces for FX debug fields
if collect_device_traces and "device_traces" in augmented_snapshot:
for trace_list in augmented_snapshot["device_traces"]:
for trace_entry in trace_list:
if isinstance(trace_entry, dict) and "frames" in trace_entry:
for frame in trace_entry["frames"]:
if isinstance(frame, dict):
# Check for FX debug fields
if "fx_node_op" in frame or "fx_node_name" in frame:
fx_frames.append(frame)
# Check segments/blocks for FX debug fields
if collect_segments and "segments" in augmented_snapshot:
for segment in augmented_snapshot["segments"]:
if "blocks" in segment:
for block in segment["blocks"]:
if "frames" in block:
for frame in block["frames"]:
if isinstance(frame, dict):
if "fx_node_op" in frame or "fx_node_name" in frame:
fx_frames.append(frame)
return fx_frames

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_fx_memory_profiler_augmentation(self):
"""Test that memory snapshots are augmented with FX debug information."""
# Create a simple model
class MLPModule(nn.Module):
def __init__(self, device):
super().__init__()
torch.manual_seed(5)
self.net1 = nn.Linear(10, 16, bias=True, device=device)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 10, bias=True, device=device)
def forward(self, x):
a = self.net1(x)
b = self.relu(a)
c = self.net2(b)
return c
device = "cuda"
mod = MLPModule(device)
with tempfile.TemporaryDirectory() as tmpdir:
torch.cuda.memory._record_memory_history()
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
result = compiled(torch.randn(10, 10, device=device))
augmented_snapshot = torch.cuda.memory._snapshot(
augment_with_fx_traces=True
)
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
torch.cuda.empty_cache()
fx_frames = self.collect_frames(augmented_snapshot)
if TEST_WITH_ROCM:
self.assertGreater(len(fx_frames), 0)
else:
self.assertEqual(len(fx_frames), 12)
for frame in fx_frames:
# Every FX frame should have both node_op and node_name
self.assertIn("fx_node_op", frame)
self.assertIn("fx_node_name", frame)
self.assertIn("fx_node_target", frame)
self.assertIn("fx_original_trace", frame)
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
fx_node_name = frame["fx_node_name"]
if fx_node_name == "addmm":
self.assertIn("a = self.net1(x)", frame["fx_original_trace"])
elif fx_node_name == "addmm_1":
self.assertIn("c = self.net2(b)", frame["fx_original_trace"])
elif fx_node_name == "relu":
self.assertIn("b = self.relu(a)", frame["fx_original_trace"])
# Test that when we have two graphs with the same src_code, they're not hashed
# to the same metadata
class MLPModule2(nn.Module):
def __init__(self, device):
super().__init__()
torch.manual_seed(5)
self.net1 = nn.Linear(10, 16, bias=True, device=device)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 10, bias=True, device=device)
def forward(self, x):
d = self.net1(x)
e = self.relu(d)
f = self.net2(e)
return f
mod = MLPModule2(device)
with tempfile.TemporaryDirectory() as tmpdir:
torch.cuda.memory._record_memory_history()
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
result = compiled(torch.randn(10, 10, device=device))
augmented_snapshot = torch.cuda.memory._snapshot(
augment_with_fx_traces=True
)
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
# avoid collecting segments from previous run for unit test purpose
fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False)
self.assertGreater(len(fx_frames), 0)
for frame in fx_frames:
# Every FX frame should have both node_op and node_name
self.assertIn("fx_node_op", frame)
self.assertIn("fx_node_name", frame)
self.assertIn("fx_node_target", frame)
self.assertIn("fx_original_trace", frame)
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
fx_node_name = frame["fx_node_name"]
if fx_node_name == "addmm":
self.assertIn("d = self.net1(x)", frame["fx_original_trace"])
elif fx_node_name == "addmm_1":
self.assertIn("f = self.net2(e)", frame["fx_original_trace"])
elif fx_node_name == "relu":
self.assertIn("e = self.relu(d)", frame["fx_original_trace"])

instantiate_parametrized_tests(TestCuda)
instantiate_parametrized_tests(TestCudaMallocAsync)
instantiate_parametrized_tests(TestCompileKernel)
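The test above exercises the new augment_with_fx_traces flow end to end. A condensed usage sketch built from the same calls (these are private helpers introduced by this change, so subject to change):

```python
import torch
import torch.nn as nn

torch._dynamo.config.enrich_profiler_metadata = True  # register FX metadata on recompile()

model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 10)).cuda()
compiled = torch.compile(model, backend="aot_eager", fullgraph=True)

torch.cuda.memory._record_memory_history()
compiled(torch.randn(10, 10, device="cuda"))

snapshot = torch.cuda.memory._snapshot(augment_with_fx_traces=True)
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)

# Frames that belong to FX-generated code now carry fx_node_* fields.
for trace_list in snapshot.get("device_traces", []):
    for entry in trace_list:
        for frame in entry.get("frames", []):
            if isinstance(frame, dict) and "fx_node_name" in frame:
                print(frame["fx_node_name"], frame.get("fx_original_trace", "")[:60])
```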
@@ -771,6 +771,7 @@ class TestFX(JitTestCase):
gm = GraphModule(tracer.root, graph)
expected = {1: 2, 2: 3, 3: 4, 4: 5}
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
self.assertEqual(gm._prologue_start, 4)
# test custom codegen
def transform_code(code):

@@ -780,6 +781,7 @@ class TestFX(JitTestCase):
gm.recompile()
expected = {2: 2, 3: 3, 4: 4, 5: 5}
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
self.assertEqual(gm._prologue_start, 4)
def test_graph_unique_names_manual(self):
graph: torch.fx.Graph = torch.fx.Graph()
@@ -209,42 +209,36 @@ def infer_scale_swizzle(mat, scale):
] == math.ceil(mat.shape[1] // 128):
return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE
# if we're checking for nvfp4, need to adjust for packed-K
K_multiplier = 2 if mat.dtype == torch.float4_e2m1fn_x2 else 1
# NVFP4
if (
(scale.numel()
== round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 16), 4)
or scale.numel()
== round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4))
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 16), 4))
and mat.dtype == torch.float4_e2m1fn_x2
and scale.dtype == torch.float8_e4m3fn
):
return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
# MXFP4 w/o swizzle
if (
(scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0])
and mat.dtype == torch.float4_e2m1fn_x2
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
# MX formats
if not torch.version.hip:
# MXFP8 w/ swizzle
# MX w/swizzle (NVIDIA)
if (
(scale.numel()
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 32), 4)
or scale.numel()
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4))
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 32), 4))
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
else:
# MXFP8 w/o swizzle
# MX w/o swizzle (AMD)
if (
(scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0])
(scale.numel() == math.ceil(mat.shape[0] // 32) * K_multiplier * mat.shape[1]
or scale.numel() == math.ceil(K_multiplier * mat.shape[1] // 32) * mat.shape[0])
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
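The K_multiplier above doubles the effective K of packed-FP4 operands when matching scale sizes. A small sketch of the two block-scale size formulas (plain Python with a local round_up helper, using true ceiling division as in the C++ ceil_div earlier in this diff; shapes are hypothetical):

```python
import math


def round_up(x: int, m: int) -> int:
    return ((x + m - 1) // m) * m


def expected_scale_elems(m: int, k: int, packed_fp4: bool, swizzled: bool) -> int:
    k_mult = 2 if packed_fp4 else 1
    if swizzled:
        # NVIDIA-style SWIZZLE_32_4_4 layout: rows padded to 128, scale groups padded to 4
        return round_up(m, 128) * round_up(math.ceil(k_mult * k / 32), 4)
    # un-swizzled 1x32 block scales (the ROCm branch above)
    return math.ceil(k_mult * k / 32) * m


# A 256x64 packed-fp4 operand has an effective K of 128, i.e. 4 scale groups per row:
print(expected_scale_elems(256, 64, packed_fp4=True, swizzled=False))  # 1024
print(expected_scale_elems(256, 64, packed_fp4=True, swizzled=True))   # 1024
```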
@@ -1868,7 +1862,7 @@ class TestFP8Matmul(TestCase):
(127, 96, 1024),
(1025, 128, 96)
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")

@@ -1882,8 +1876,12 @@ class TestFP8Matmul(TestCase):
if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0):
raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping")
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
fp4_scaling_dtype = torch.float8_e8m0fnu if recipe == "mxfp4" else torch.float8_e4m3fn
BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
if K % BLOCK_SIZE != 0:
raise unittest.SkipTest(f"K ({K}) must be divisible by BLOCK_SIZE ({BLOCK_SIZE}), skipping")
require_exact_match = True
approx_match_sqnr_target = 22.0

@@ -2061,7 +2059,7 @@ class TestFP8Matmul(TestCase):
B = B.clamp(min=min_val, max=max_val)
B = _bfloat16_to_float4_e2m1fn_x2(B)
approx_match_sqnr_target = 15 if torch.version.hip else 15.8
approx_match_sqnr_target = 15 if recipe == "mxfp4" else 15.8
C_ref = A_ref @ B_ref.t()
@@ -739,6 +739,12 @@ enable_aot_compile = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None

# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config(  # type: ignore[var-annotated]
default=False,
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)

if TYPE_CHECKING:
from torch.utils._config_typing import *  # noqa: F401, F403
@@ -179,6 +179,9 @@ def aot_stage1_graph_capture(
)
)

print(f"in aot_stage1_graph_capture. maybe_subclass_meta.fw_metadata.static_input_indices:{maybe_subclass_meta.fw_metadata.static_input_indices if maybe_subclass_meta is not None and maybe_subclass_meta.fw_metadata is not None else None}")
print(f"in aot_stage1_graph_capture. aot_state.fw_metadata.static_input_indices:{aot_state.fw_metadata.static_input_indices}")

return AOTGraphCapture(
wrappers=wrappers,
graph_module=graph,
@@ -423,6 +423,10 @@ def estimate_nccl_collective_runtime_from_fx_node(
from torch.distributed.distributed_c10d import _resolve_process_group

pg = _resolve_process_group(group_name)
if torch.distributed.distributed_c10d.get_backend(pg) == "fake":
# nccl estimator requires real process group
return None

fn = fx_node.target
assert isinstance(fn, torch._ops.OpOverload)
with torch.distributed._time_estimator(group=pg) as time_estimator:
@@ -2318,7 +2318,7 @@ def compile_fx_forward(
# force the outputs of invoke_subgraph subgraph to follow the
# original strides
_recursive_record_user_visible_output_idxs(gm)

print(f"in compile_fx_foward. static_input_idxs:{get_static_input_idxs(fixed)}")
return inner_compile(
gm,
example_inputs,
@@ -24,6 +24,7 @@ from typing_extensions import Never, ParamSpec
import torch._thread_safe_fork  # noqa: F401
from torch._inductor import config
from torch._inductor.codecache import torch_key
from torch._inductor.compile_worker.timer import Timer
from torch._inductor.compile_worker.tracked_process_pool import (
TrackedProcessPoolExecutor,
)

@@ -132,6 +133,7 @@ class SubprocPool:
nprocs: int,
pickler: Optional[SubprocPickler] = None,
kind: SubprocKind = SubprocKind.FORK,
quiesce: bool = False,
) -> None:
entry = os.path.join(os.path.dirname(__file__), "__main__.py")
self.pickler = pickler or SubprocPickler()

@@ -216,6 +218,13 @@ class SubprocPool:
"pytorch.wait_counter.subproc_pool.first_job"
).guard()

if quiesce:
self.timer: Optional[Timer] = Timer(
config.quiesce_async_compile_time, self.quiesce
)
else:
self.timer = None

# Start thread last to ensure all member variables are initialized
# before any access.
self.read_thread.start()

@@ -288,6 +297,8 @@ class SubprocPool:
with self.futures_lock:
if not self.running:
return
if self.timer:
self.timer.record_call()
if isinstance(result, _SubprocExceptionInfo):
# An exception occurred in the submitted job
self.pending_futures[job_id].set_exception(

@@ -322,6 +333,8 @@ class SubprocPool:
with self.write_lock:
if not self.running:
return
if self.timer:
self.timer.quit()
self.running = False
self.running_waitcounter.__exit__()
_send_msg(self.write_pipe, MsgHeader.SHUTDOWN)
|
||||
self.background_thread: Optional[Thread] = None
|
||||
self.last_called: Optional[float] = None
|
||||
self.duration = duration
|
||||
self.sleep_time = 60
|
||||
self.sleep_time = duration / 2
|
||||
self.call = call
|
||||
self.exit = False
|
||||
|
||||
|
||||
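The Timer change above ties the polling interval to the configured duration (duration / 2 instead of a fixed 60 s), so short quiesce timeouts such as the 0.1 s used in the tests actually fire promptly. A minimal sketch of such an inactivity timer (my own simplified version, not the torch._inductor implementation):

```python
import threading
import time
from typing import Callable, Optional


class InactivityTimer:
    """Calls `call` once no record_call() has been seen for `duration` seconds."""

    def __init__(self, duration: float, call: Callable[[], None]) -> None:
        self.duration = duration
        self.sleep_time = duration / 2  # poll twice per duration, as in the diff above
        self.call = call
        self.last_called: Optional[float] = None
        self.exit = False
        self.background_thread = threading.Thread(target=self._loop, daemon=True)
        self.background_thread.start()

    def record_call(self) -> None:
        self.last_called = time.time()

    def quit(self) -> None:
        self.exit = True

    def _loop(self) -> None:
        while not self.exit:
            time.sleep(self.sleep_time)
            if self.last_called is not None and time.time() - self.last_called > self.duration:
                self.last_called = None
                self.call()
```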
@@ -964,6 +964,11 @@ quiesce_async_compile_pool: bool = Config(
default=False,
)

# Time in seconds to wait before quiescing
quiesce_async_compile_time: int = Config(
default=60,
)

# Whether or not to enable statically launching CUDA kernels
# compiled by triton (instead of using triton's own launcher)
use_static_cuda_launcher: bool = static_cuda_launcher_default()
@@ -51,8 +51,8 @@ from ..utils import (
decode_device,
get_all_devices,
get_gpu_type,
has_uses_tagged_as,
is_gpu,
is_pointwise_use,
OPTIMUS_EXCLUDE_POST_GRAD,
)
from ..virtualized import V

@@ -1510,8 +1510,10 @@ def should_prefer_unfused_addmm(match):
if not is_gpu(inp.meta["val"].device.type):
return False

output = match.output_node()
return all(is_pointwise_use(use) for use in output.users)
return has_uses_tagged_as(
match.output_node(),
(torch.Tag.pointwise, torch.Tag.reduction),
)

@register_graph_pattern(
@@ -549,6 +549,70 @@ def is_pointwise_use(
return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)

class LogicalConnective(enum.Enum):
OR = enum.auto()
AND = enum.auto()

def has_uses(
target: Node,
use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
"""
Given a target, explore the uses of `target` by applying `use_selector_fn`
on them, and then aggregate these booleans with the `use_aggregate_type`
logical connective.

Uses in view ops will follow the views uses.
"""

def get_use_aggregate_fn(
use_aggregate_type: LogicalConnective,
) -> Callable[[Iterator[Any]], bool]:
match use_aggregate_type:
case LogicalConnective.AND:
return all
case LogicalConnective.OR:
return any
case _:
return any

use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)

def has_uses_impl(use: Node) -> bool:
if use.op != "call_function":
return False
if not (
isinstance(use.target, torch._ops.OpOverload)
or use.target is operator.getitem
):
return False

target = cast(torch._ops.OpOverload, use.target)
# Process getitem and view
if target is operator.getitem or is_view(target):
return use_aggregate_fn(has_uses_impl(user) for user in use.users)

return use_selector_fn(target)

return use_aggregate_fn(has_uses_impl(user) for user in target.users)

def has_uses_tagged_as(
target: Node,
use_tags: Collection[torch.Tag],
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
"""
Is there a use with given tags?
"""

return has_uses(
target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
)

def gen_gm_and_inputs(
target: Any, args: list[Any], kwargs: dict[str, Any]
) -> tuple[GraphModule, list[torch.Tensor]]:
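has_uses_tagged_as generalizes the previous is_pointwise_use check by looking through getitem/view uses and aggregating operator tags across a node's users. A hedged sketch of the underlying tag check on a traced graph (standalone; it inspects torch.Tag directly rather than importing the new private helper):

```python
import torch
import torch.fx


def fn(x):
    y = torch.ops.aten.mm.default(x, x)
    return torch.ops.aten.relu.default(y)


gm = torch.fx.symbolic_trace(fn)

# For each aten call, report the tags of the ops that consume its output; the helper
# above answers the same question ("are all/any users pointwise or reduction?") while
# also following getitem and view nodes.
for node in gm.graph.nodes:
    if node.op != "call_function" or not isinstance(node.target, torch._ops.OpOverload):
        continue
    user_tags = [
        list(user.target.tags)
        for user in node.users
        if isinstance(user.target, torch._ops.OpOverload)
    ]
    print(node.name, user_tags)
```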
@@ -31,10 +31,8 @@ template <typename T>
struct FromImpl {
static StableIValue call(
T val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
static_assert(
sizeof(T) <= sizeof(StableIValue),
"StableLibrary stack does not support parameter types larger than 64 bits.");

@@ -75,10 +73,8 @@ template <>
struct FromImpl<ScalarType> {
static StableIValue call(
ScalarType val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
switch (val) {
case ScalarType::Byte:
return from(aoti_torch_dtype_uint8());

@@ -133,10 +129,8 @@ template <>
struct FromImpl<std::nullopt_t> {
static StableIValue call(
std::nullopt_t val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
return from(nullptr);
}
};

@@ -190,10 +184,8 @@ template <>
struct FromImpl<torch::stable::Tensor> {
static StableIValue call(
const torch::stable::Tensor& val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
AtenTensorHandle new_ath;
TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
return from(new_ath);

@@ -209,10 +201,8 @@ template <typename T>
struct ToImpl {
static T call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
static_assert(std::is_trivially_copyable_v<T>);
// T may not have a default constructor. (For example, it might be
// c10::Device.) However, std::memcpy implicitly creates a T at the

@@ -249,10 +239,8 @@ template <>
struct ToImpl<ScalarType> {
static ScalarType call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
int32_t shim_scalartype = to<int32_t>(val);
if (shim_scalartype == aoti_torch_dtype_uint8()) {
return ScalarType::Byte;

@@ -309,10 +297,8 @@ template <>
struct ToImpl<std::nullopt_t> {
static std::nullopt_t call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
// val should be equivalent to from(nullptr)
return std::nullopt;
}

@@ -350,10 +336,8 @@ template <>
struct ToImpl<torch::stable::Tensor> {
static torch::stable::Tensor call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
return torch::stable::Tensor(to<AtenTensorHandle>(val));
}
};
@@ -4,12 +4,14 @@ r"""This package adds support for device memory management implemented in CUDA."
import collections
import contextlib
import ctypes
import os
import pickle
import re
import sys
import warnings
from inspect import signature
from typing import Any, Literal, Optional, TYPE_CHECKING
from typing_extensions import deprecated
from typing import Any, cast, Literal, Optional, TYPE_CHECKING, TypedDict
from typing_extensions import deprecated, NotRequired

import torch
from torch import _C

@@ -29,6 +31,60 @@ if TYPE_CHECKING:
from torch.types import Device

# Type definitions for memory profiler
class _Frame(TypedDict):
"""Frame information from memory profiler snapshots."""
filename: str
line: int
name: str
# Fields added by FX augmentation (optional)
fx_node_op: NotRequired[str]
fx_node_name: NotRequired[str]
fx_node_target: NotRequired[str]
fx_original_trace: NotRequired[str]

class _Block(TypedDict):
"""Memory block information."""
size: int
requested_size: int
address: int
state: str
frames: list[_Frame]

class _Segment(TypedDict):
"""Memory segment information."""
address: int
total_size: int
stream: int
segment_type: str
allocated_size: int
active_size: int
blocks: list[_Block]

class _TraceEntry(TypedDict):
"""Memory trace entry information."""
action: str
addr: NotRequired[int]
frames: list[_Frame]
size: int
stream: int
device_free: NotRequired[int]

class _Snapshot(TypedDict):
"""Memory snapshot structure."""
segments: list[_Segment]
device_traces: NotRequired[list[list[_TraceEntry]]]

__all__ = [
"caching_allocator_alloc",
"caching_allocator_delete",

@@ -964,7 +1020,120 @@ def _record_memory_history_impl(
_record_memory_history.__signature__ = signature(_record_memory_history_impl)  # type: ignore[attr-defined]

def _snapshot(device: "Device" = None):
def _augment_frames(frames: list[_Frame]) -> int:
"""
Augment a list of frames with FX debug information.

Args:
frames: List of frame dictionaries to augment

Returns:
The count of frames that were augmented.
"""
from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX

# Regex pattern to match FX generated files
_FX_GENERATED_PATTERN = re.compile(
rf"{re.escape(FX_GRAPH_MODULE_FILE_PREFIX)}.*\.py$"
)

count = 0
if not frames:
return count

for frame in frames:
if "filename" in frame and "line" in frame:
filename = frame["filename"]
lineno = frame["line"]

# Check if this looks like an FX generated file
if not _FX_GENERATED_PATTERN.search(os.path.basename(filename)):
continue

# Look up metadata from the global registry
from torch.fx.traceback import _FX_METADATA_REGISTRY

metadata = _FX_METADATA_REGISTRY.get(filename)
if metadata is None:
continue

lineno_map = metadata.get("lineno_map", {})
node_metadata = metadata.get("node_metadata", {})
prologue_start = metadata.get("prologue_start", 0)

# Get the node index for this line
node_idx = lineno_map.get(lineno - prologue_start)

if node_idx is not None and node_idx in node_metadata:
node_info = node_metadata[node_idx]
original_trace = node_info.get("stack_trace")
node_op = node_info.get("op")
node_name = node_info.get("name")
node_target = node_info.get("target")

# Always add node metadata
frame["fx_node_op"] = node_op
frame["fx_node_name"] = node_name
frame["fx_node_target"] = str(node_target)

# Add original trace if available
if original_trace:
frame["fx_original_trace"] = original_trace

count += 1

return count

def _augment_memory_snapshot_stack_traces(
snapshot: str | _Snapshot,
) -> _Snapshot:
"""
Augment a memory snapshot with original source stack traces from FX metadata.

IMPORTANT: This function reads from a global in-memory registry (_FX_METADATA_REGISTRY)
that is populated during graph module compilation. It must be called in the same
Python process where the FX graphs were compiled. It cannot be used to augment
snapshots loaded from disk in a different process.

Args:
snapshot: Either a memory snapshot dict or path to a snapshot pickle file

Returns:
The augmented snapshot dictionary with fx_node_op, fx_node_name,
fx_original_trace, and fx_node_info fields added to frames
"""

snapshot_dict: _Snapshot
if isinstance(snapshot, str):
# Load the memory snapshot
with open(snapshot, "rb") as f:
snapshot_dict = cast(_Snapshot, pickle.load(f))
else:
snapshot_dict = snapshot

# Process stack traces in the snapshot
augmented_count = 0

# Process blocks in segments (for regular allocations)
if "segments" in snapshot_dict:
for segment in snapshot_dict["segments"]:
if "blocks" in segment:
for block in segment["blocks"]:
if "frames" in block:
augmented_count += _augment_frames(block["frames"])

# Process device traces (for memory history)
if "device_traces" in snapshot_dict:
for trace_list in snapshot_dict["device_traces"]:
for trace_entry in trace_list:
if isinstance(trace_entry, dict) and "frames" in trace_entry:
augmented_count += _augment_frames(trace_entry["frames"])

return snapshot_dict

def _snapshot(device: "Device" = None, augment_with_fx_traces=False):
"""Save a snapshot of CUDA memory state at the time it was called.

The state is represented as a dictionary with the following structure.

@@ -1012,6 +1181,11 @@ def _snapshot(device: "Device" = None):
filename: str
line: int
name: str
# Optional FX debug fields (present when augment_with_fx_traces=True
# and the frame corresponds to FX-generated code)
fx_node_op: str  # FX node operation type (e.g., 'call_function', 'output')
fx_node_name: str  # FX node name (e.g., 'linear', 'relu_1')
fx_original_trace: str  # Original model source code stack trace

class TraceEntry(TypedDict):

@@ -1041,13 +1215,23 @@ def _snapshot(device: "Device" = None):
device_free: int  # only present for OOM, the amount of
# memory cuda still reports to be free

Args:
device: Device to capture snapshot for. If None, captures for current device.
augment_with_fx_traces: If True, augment stack trace frames with FX debug information
that maps generated FX code back to original model source code.
This adds fx_node_op, fx_node_name, fx_original_trace, and
fx_node_info fields to Frame objects. Default: False.

Returns:
The Snapshot dictionary object
"""
return _C._cuda_memorySnapshot(None)
s = _C._cuda_memorySnapshot(None)
if augment_with_fx_traces:
s = _augment_memory_snapshot_stack_traces(s)  # type: ignore[assignment, arg-type]
return s

def _dump_snapshot(filename="dump_snapshot.pickle"):
def _dump_snapshot(filename="dump_snapshot.pickle", augment_with_fx_traces=False):
"""
Save a pickled version of the `torch.memory._snapshot()` dictionary to a file.

@@ -1059,8 +1243,14 @@ def _dump_snapshot(filename="dump_snapshot.pickle"):

Args:
filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
augment_with_fx_traces (bool, optional): If True, augment the snapshot with FX debug information
before dumping. This maps generated FX code stack traces
back to original model source code. Defaults to False.
verbose (bool, optional): If True and augment_with_fx_traces is True, print verbose debug output
during augmentation. Defaults to False.
"""
s = _snapshot()
s = _snapshot(augment_with_fx_traces=augment_with_fx_traces)

with open(filename, "wb") as f:
pickle.dump(s, f)
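Because the augmentation reads the in-process _FX_METADATA_REGISTRY, the snapshot has to be augmented before it is written out; once dumped, the pickle can be inspected anywhere. A short hedged sketch using the extended _dump_snapshot signature shown above (private API, filename is arbitrary):

```python
import pickle
import torch

# In the process that compiled the FX graphs:
torch.cuda.memory._record_memory_history()
# ... run the compiled workload here ...
torch.cuda.memory._dump_snapshot("augmented_snapshot.pickle", augment_with_fx_traces=True)
torch.cuda.memory._record_memory_history(enabled=None)

# Later, possibly in a different process, the fx_* fields are already baked in:
with open("augmented_snapshot.pickle", "rb") as f:
    snap = pickle.load(f)
print(len(snap.get("segments", [])), len(snap.get("device_traces", [])))
```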
@@ -226,8 +226,10 @@ class PythonCode:
# Values in global scope during execution of `src_def`.
globals: dict[str, Any]
# Optional mapping from the forward function's line number to
# node index.
# node index. Line number starts at the prologue (i.e. forward()).
_lineno_map: Optional[dict[int, Optional[int]]]
# The line number of prologue in fn_code
_prologue_start: int = 0

def _format_target(base: str, target: str) -> str:

@@ -854,7 +856,14 @@ class CodeGen:

{prologue}
{code}"""
return PythonCode(fn_code, globals_, _lineno_map=lineno_map)
# The +4 accounts for the empty lines before prologue in fn_code
prologue_start = wrap_stmts.count("\n") + 4
return PythonCode(
fn_code,
globals_,
_lineno_map=lineno_map,
_prologue_start=prologue_start,
)

# Ideally, we'd like to refactor all of the pytree logic into this codegen
@@ -1,6 +1,8 @@
# mypy: allow-untyped-defs
import base64
import contextlib
import copy
import hashlib
import itertools
import linecache
import os

@@ -36,6 +38,7 @@ __all__ = [
]

_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"
FX_GRAPH_MODULE_FILE_PREFIX = "fx_generated_"

# Normal exec loses the source code, however we can work with

@@ -61,7 +64,13 @@ class _EvalCacheLoader:

key = self._get_key()
if co_fields:
key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
if "co_filename" in co_fields:
# If only co_filename is provided, use it directly as the key
if "co_firstlineno" not in co_fields or "co_name" not in co_fields:
key = co_fields["co_filename"]
else:
# Full co_fields with all three components
key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
self.eval_cache[key] = src

# Don't mutate globals so that this loader is only used

@@ -353,6 +362,36 @@ def _print_readable(
return output

def _metadata_hash(code: str, node_metadata: dict) -> str:
"""
Create a content-addressed hash from code and metadata.

Args:
code: The source code string
lineno_map: Mapping from line numbers to node indices
node_metadata: Metadata for each node

Returns:
A 51-character base32-encoded hash
"""
import json

# Create a deterministic string representation of all components
# We use JSON to ensure consistent serialization
hash_data = {
"code": code,
"node_metadata": node_metadata,
}
hashing_str = json.dumps(hash_data).encode("utf-8")

# [:51] to strip off the "Q====" suffix common to every hash value.
return (
base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
.decode("utf-8")
.lower()
)

class _WrappedCall:
def __init__(self, cls, cls_call):
self.cls = cls

@@ -825,9 +864,47 @@ class {module_name}(torch.nn.Module):
python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
self._lineno_map = python_code._lineno_map
self._prologue_start = python_code._prologue_start

cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
from torch._dynamo import config as dynamo_config

if dynamo_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
node_metadata: dict[int, dict[str, Any]] = {}
for i, node in enumerate(self._graph.nodes):
node_metadata[i] = {
"name": node.name,
"op": node.op,
"target": str(node.target),
"stack_trace": node.meta.get("stack_trace", None),
}

# Generate a content-addressed filename based on hash of code and metadata
# This ensures the same code+metadata always generates the same filename
hash_value = _metadata_hash(self._code, node_metadata)
file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"

filename = f"{file_stem}.py"

# Only include co_filename to use it directly as the cache key
co_fields = {
"co_filename": filename,
}

# Store metadata in global in-memory registry
metadata = {
"lineno_map": python_code._lineno_map,
"prologue_start": python_code._prologue_start,
"node_metadata": node_metadata,
}

# Register metadata in the global registry
from torch.fx.traceback import _register_fx_metadata

_register_fx_metadata(filename, metadata)

cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)

# Determine whether this class explicitly defines a __call__ implementation
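recompile() above derives the generated module's filename from a hash of the code plus per-node metadata, so two modules with identical source but different node info land in different registry entries. A small sketch of that content-addressing idea (a local re-implementation of the hashing shown above, not an import of the private helper; the example metadata is made up):

```python
import base64
import hashlib
import json


def metadata_hash(code: str, node_metadata: dict) -> str:
    hashing_str = json.dumps({"code": code, "node_metadata": node_metadata}).encode("utf-8")
    return base64.b32encode(hashlib.sha256(hashing_str).digest())[:51].decode("utf-8").lower()


code = "def forward(self, x):\n    return self.net1(x)\n"
meta_a = {0: {"name": "addmm", "op": "call_function"}}
meta_b = {0: {"name": "addmm_1", "op": "call_function"}}

# Same source code, different node metadata -> different generated filenames.
print(f"fx_generated__{metadata_hash(code, meta_a)}.py")
print(f"fx_generated__{metadata_hash(code, meta_b)}.py")
```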
@@ -38,6 +38,28 @@ current_meta: dict[str, Any] = {}
current_replay_node: Optional[Node] = None
should_preserve_node_meta = False

# =============================================================================
# FX Metadata Registry for Memory Profiler
# =============================================================================
# Global in-memory registry for FX metadata
# Maps module_name -> metadata dict containing lineno_map and node_metadata
_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}

def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
"""
Register FX metadata in the global in-memory registry.

This is called automatically during graph module compilation to store metadata
for later use by memory profiler augmentation.

Args:
module_name: The module identifier (content-addressed filename)
metadata: Metadata dict containing lineno_map, node_metadata, and source_code
"""
# TODO: add logging to tlparse
_FX_METADATA_REGISTRY[module_name] = metadata

@compatibility(is_backward_compatible=False)
class NodeSourceAction(Enum):
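_register_fx_metadata simply keys the metadata by the generated filename; the profiler-side augmentation later resolves a frame back to a node through lineno_map and node_metadata. A minimal sketch of that lookup flow (hand-built metadata and a hypothetical filename, mirroring the _augment_frames logic earlier in this diff; the registry is a private API added by this change):

```python
from torch.fx.traceback import _FX_METADATA_REGISTRY, _register_fx_metadata

filename = "fx_generated__example.py"  # hypothetical generated-module filename
_register_fx_metadata(
    filename,
    {
        "prologue_start": 4,
        "lineno_map": {1: 0},  # line offset within forward() -> node index
        "node_metadata": {
            0: {
                "name": "addmm",
                "op": "call_function",
                "target": "aten.addmm.default",
                "stack_trace": "a = self.net1(x)",
            }
        },
    },
)

# Resolving a frame that points at line 5 of the generated file:
meta = _FX_METADATA_REGISTRY[filename]
node_idx = meta["lineno_map"].get(5 - meta["prologue_start"])
print(meta["node_metadata"][node_idx]["name"])  # addmm
```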
@@ -806,7 +806,29 @@ function format_frames(frames) {
}
const frame_strings = frames
.filter(frameFilter)
.map(f => `${f.filename}:${f.line}:${f.name}`);
.map(f => {
let frame_str = `${f.filename}:${f.line}:${f.name}`;

// Add FX debug information if available
if (f.fx_node_op || f.fx_node_name || f.fx_node_target) {
const fx_parts = [];
if (f.fx_node_name) fx_parts.push(`node=${f.fx_node_name}`);
if (f.fx_node_op) fx_parts.push(`op=${f.fx_node_op}`);
if (f.fx_node_target) fx_parts.push(`target=${f.fx_node_target}`);
frame_str += `\n >> FX: ${fx_parts.join(', ')}`;
}

if (f.fx_original_trace) {
frame_str += `\n >> Original Model Code:`;
const original_lines = f.fx_original_trace.trim().split('\n');
// Show all lines of the original trace
for (const line of original_lines) {
frame_str += `\n ${line}`;
}
}

return frame_str;
});
return elideRepeats(frame_strings).join('\n');
}