Compare commits

..

10 Commits

Author SHA1 Message Date
8cef91fb74 prints for static inputs indices 2025-11-04 13:19:14 -08:00
527b1109a8 Delete deprecated fp32 precision warnings (#166956)
The deprecation warning led to warning spamming in PyTorch APIs, like
torch.compile. This is not how a deprecation warning should go: if we
add a deprecation warning, we'd better update our built-in APIs to
prevent warning spam.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166956
Approved by: https://github.com/albanD
2025-11-04 17:50:04 +00:00
clr
3144713325 subproc_pool: Add support for enabling quiesce via a timer (#166467)
This adds the capability to subproc pool to enable quiesce via a timer

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166467
Approved by: https://github.com/masnesral
2025-11-04 17:37:41 +00:00
eefa16342c [Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer (#166165)
Prefer unfused addmm when there is at least a single elemwise/reduction consumer..

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166165
Approved by: https://github.com/eellison
2025-11-04 17:23:04 +00:00
d02f68f484 [BE] Use [[maybe_unused]] (#166865)
Instead of `(void) foo; // Unused parameter` trick, as this is a C++17 standard feature

Will replace further repetitions of the same pattern soon after
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166865
Approved by: https://github.com/mikaylagawarecki, https://github.com/Skylion007, https://github.com/janeyx99
2025-11-04 17:08:28 +00:00
68eb55c4b2 Add model code stack trace to cuda.memory._snapshot (#166676)
We store a mapping between generated fx graph code and original model code stack trace in `fx.traceback._FX_METADATA_REGISTRY`. And we do a post-processing on the memory snapshot to append the original model stack trace information.

To achieve this, the biggest change we had to do in `aot_eager` mode is to give each generated fx graph a unique stack trace, i.e. it cannot just be `<eval_with_key>`. We set co_filename to **pretend** that the code is from `co_filename` file. Now instead of `<eval_with_key>` in stack trace, we get something like `fx_generated_3a4b5c6d7e8f9a0.py`.

`augment_with_fx_traces` arg is added to `torch.cuda.memory._snapshot` and `_dump_snapshot`. When the arg is set to True, a post-processing will run to populate the original model stack trace to the snapshot frames.

The new behavior of GraphModule can be controlled by `TORCH_ENRICH_RPOFILER_STACK_TRACE` or `_dynamo.config.enrich_profiler_metadata=True`.

Alternative:

Instead of setting co_filename, we can also do it like below:
Note that if we do it this way, we will need to dump the file to make the graph module torch-scriptable. TorchScript requires source access in order to carry out compilation, so we need to make sure original .py files are available.
```
        key = filename
        globals_copy = globals.copy()
        globals_copy["__file__"] = key
        globals_copy["__name__"] = key
        linecache.lazycache(key, globals_copy)
        exec(compile(src, key, "exec"), globals)
````

Other changes:

- Update `MemoryViz.js` to display fx node information and original model code if exist

```
python test/test_fx.py -k test_lineno_map
python test/test_fx.py -k test_custom_traceback_raised
python test/test_public_bindings.py
python test/test_cuda.py -k test_fx_memory
python test/test_fx.py -k test_informative_co_filename
python test/test_fx.py -k test_autowrap_functions
python test/dynamo/test_utils.py -k test_inductor_provenance
```

```python
# Profile with memory snapshot
torch.cuda.memory._record_memory_history()

with  torch._dynamo.config.patch("enrich_profiler_stack_trace", True):
    compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
    result = compiled(torch.randn(10, 10, device="cuda:0"))

torch.cuda.memory._dump_snapshot("memory_snapshot.pickle", augment_with_fx_traces=True)
torch.cuda.memory._record_memory_history(enabled=None)
```

<img width="913" height="711" alt="Screenshot 2025-10-30 at 10 40 44 AM" src="https://github.com/user-attachments/assets/8d7a1833-f98d-4756-b666-1d63ab57b27b" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166676
Approved by: https://github.com/albanD, https://github.com/ezyang
2025-11-04 17:01:02 +00:00
8d4b8ab430 [ez] Print some more test timing info in the logs (#166447)
You can just subtract timestamps, but this makes it easier
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166447
Approved by: https://github.com/Skylion007
2025-11-04 16:45:22 +00:00
afd50bdd29 [CI] Use smaller amx + avx2 runners for inductor test? (#164989)
Results from CI:
No failures but generally takes longer, maybe ~20% increase in time?
But the smaller runner is ~25% of the cost of the current runner, so in terms of cost this is a decrease

If the 20% is too much, we can try the 4x larger runners, which are about half the cost of the current runner, so it would probably still result in cost savings with hopefully less impact to time

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164989
Approved by: https://github.com/BoyuanFeng, https://github.com/huydhn
2025-11-04 16:43:06 +00:00
56dfd4c74b Add CUDA MXFP4 scaled mm support via. FBGEMM (#166526)
Summary:

* Pull in `f4f4bf16` from FBGemm to provide MXFP4 support for CUDA
* Add testing

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166526
Approved by: https://github.com/drisspg, https://github.com/ngimel
2025-11-04 15:53:16 +00:00
24db5c4451 [inductor] do not hard fail on FakePG with nccl estimator (#166869)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166869
Approved by: https://github.com/eellison
ghstack dependencies: #166521
2025-11-04 15:22:38 +00:00
29 changed files with 729 additions and 130 deletions

View File

@ -38,9 +38,9 @@ runs:
run: |
python3 .github/scripts/pytest_cache.py \
--download \
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
--pr_identifier "$GITHUB_REF" \
--job_identifier "$JOB_IDENTIFIER" \
--temp_dir "$RUNNER_TEMP" \
--repo "$REPO" \
--bucket "$BUCKET" \
--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
--pr_identifier $GITHUB_REF \
--job_identifier $JOB_IDENTIFIER \
--temp_dir $RUNNER_TEMP \
--repo $REPO \
--bucket $BUCKET \

View File

@ -47,11 +47,11 @@ runs:
run: |
python3 .github/scripts/pytest_cache.py \
--upload \
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
--pr_identifier "$GITHUB_REF" \
--job_identifier "$JOB_IDENTIFIER" \
--sha "$SHA" \
--test_config "$TEST_CONFIG" \
--shard "$SHARD" \
--repo "$REPO" \
--temp_dir "$RUNNER_TEMP" \
--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
--pr_identifier $GITHUB_REF \
--job_identifier $JOB_IDENTIFIER \
--sha $SHA \
--test_config $TEST_CONFIG \
--shard $SHARD \
--repo $REPO \
--temp_dir $RUNNER_TEMP \

View File

@ -115,10 +115,10 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
]}
secrets: inherit

View File

@ -84,13 +84,13 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
]}
build-additional-packages: "vision audio torchao"

View File

@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")

View File

@ -23,8 +23,6 @@ C10_DIAGNOSTIC_POP()
#endif
namespace at {
namespace {
/*
These const variables defined the fp32 precisions for different backend
We have "generic", "cuda", "mkldnn" backend now and we can choose fp32
@ -41,16 +39,6 @@ namespace {
->rnn
*/
C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
TORCH_WARN_ONCE(
"Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' "
"or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, "
"torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see "
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
);
}
} // namespace
Float32Backend str2backend(const std::string& name) {
if (name == "generic")
return Float32Backend::GENERIC;
@ -206,7 +194,6 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
} else {
return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
}
warn_deprecated_fp32_precision_api();
return allow_tf32_cudnn;
}
@ -214,7 +201,6 @@ void Context::setAllowTF32CuDNN(bool b) {
setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
allow_tf32_cudnn = b;
warn_deprecated_fp32_precision_api();
}
void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
@ -325,7 +311,6 @@ bool Context::allowTF32CuBLAS() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
"We suggest only using the new API to set the TF32 flag. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return allow_tf32_new;
}
@ -349,7 +334,6 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
"We suggest only using the new API for matmul precision. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return float32_matmul_precision;
}
@ -377,7 +361,6 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op)
void Context::setFloat32MatmulPrecision(const std::string &s) {
auto match = [this](const std::string & s_) {
warn_deprecated_fp32_precision_api();
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
if (s_ == "highest") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;

View File

@ -59,6 +59,24 @@
// forward declare
class cublasCommonArgs;
#ifndef _WIN32
namespace fbgemm_gpu {
// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
// To update supported ops means a submodule bump, which is.. painful. Instead, we
// can simply forward-declare the methods we want to use.. Works at least as a short-term
// thing, but should still be fixed somewhere/somehow.
at::Tensor f4f4bf16(
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
std::optional<at::Tensor>,
bool use_mx);
} // namespace fbgemm_gpu
#endif
using at::blas::ScalingType;
using at::blas::SwizzleType;
@ -1087,26 +1105,47 @@ _scaled_mxfp4_mxfp4(
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
Tensor& out) {
#ifndef USE_ROCM
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
#endif
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
#else
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
// Packed FP4 format means actual-K = 2 * reported-K -- adjust
auto K_multiplier = 2;
#ifdef USE_ROCM
// AMD
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
#else
// NVIDIA
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
#endif
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
#ifdef USE_ROCM
// AMD
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
#else
// NVIDIA
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
#endif
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
"For Blockwise scaling both scales should be contiguous");
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
#ifdef USE_ROCM
// AMD
auto scaling_choice_a = ScalingType::BlockWise1x32;
auto scaling_choice_b = ScalingType::BlockWise1x32;
@ -1121,11 +1160,30 @@ _scaled_mxfp4_mxfp4(
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
#else
// NVIDIA
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
// but we have one we need to use. Two clear options are to copy into
// our output (slow), or use a move-assignment-operator (faster).
// However, the compiler can complain about the explicit move preventing
// copy elision because the return from f4f4bf16 is a temporary object.
// So we don't explicitly move, and trust the compiler here...
// In the longer term this should be fixed on the FBGemm side.
out = fbgemm_gpu::f4f4bf16(
mat_a,
mat_b.transpose(-2, -1),
scale_a,
scale_b,
std::nullopt, /* global_scale */
true /* use_mx */
);
return out;
#endif
#endif
}
Tensor&
@ -1250,17 +1308,20 @@ _scaled_mm_cuda_v2_out(
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
}
// Handle fp4 packed-K dimension
int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;
TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
" but got ", bias->numel());
TORCH_CHECK_VALUE(
mat_a.sizes()[1] % 16 == 0,
K_multiplier * mat_a.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes()[0],
"x",
mat_a.sizes()[1],
K_multiplier * mat_a.sizes()[1],
").");
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
mat_b.sizes()[1], ") must be divisible by 16");
// TODO(slayton): Existing checks, not sure if they should really be here.

View File

@ -4,6 +4,7 @@ import os
import tempfile
from threading import Event
import torch._inductor.config as config
from torch._inductor.compile_worker.subproc_pool import (
raise_testexc,
SubprocException,
@ -16,9 +17,12 @@ from torch.testing._internal.inductor_utils import HAS_CPU
class TestCompileWorker(TestCase):
def make_pool(self, size):
return SubprocPool(size)
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_basic_jobs(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
a = pool.submit(operator.add, 100, 1)
b = pool.submit(operator.sub, 100, 1)
@ -29,7 +33,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_exception(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
a = pool.submit(raise_testexc)
with self.assertRaisesRegex(
@ -42,7 +46,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_crash(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
with self.assertRaises(Exception):
a = pool.submit(os._exit, 1)
@ -58,7 +62,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_quiesce(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
a = pool.submit(operator.add, 100, 1)
pool.quiesce()
@ -75,7 +79,7 @@ class TestCompileWorker(TestCase):
os.environ["ROLE_RANK"] = "0"
with tempfile.NamedTemporaryFile(delete=True) as temp_log:
os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
pool.submit(operator.add, 100, 1)
self.assertEqual(os.path.exists(temp_log.name), True)
@ -83,6 +87,12 @@ class TestCompileWorker(TestCase):
pool.shutdown()
@config.patch("quiesce_async_compile_time", 0.1)
class TestCompileWorkerWithTimer(TestCompileWorker):
def make_pool(self, size):
return SubprocPool(size, quiesce=True)
class TestTimer(TestCase):
def test_basics(self):
done = Event()

View File

@ -500,8 +500,13 @@ class PaddingTest(TestCaseBase):
forward_wrapper = wrapper_codes[0]
# make sure the load for softmax is aligned
if bias:
# addmm -> mm + bias and bias is fused with softmax
softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)"
else:
softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)"
self.assertTrue(
"tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
softmax_load_str in forward_wrapper,
f"forward_wrapper: {forward_wrapper}",
)

View File

@ -15280,7 +15280,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_native_layer_norm_relu",
"triton_poi_fused_addmm_native_layer_norm",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]
@ -15293,7 +15293,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_LayerNorm_ReLU",
"triton_poi_fused_LayerNorm_Linear_ReLU",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]

View File

@ -1826,9 +1826,14 @@ def run_test_module(
test_name = test.name
# Printing the date here can help diagnose which tests are slow
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
start = time.perf_counter()
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}][{start}]")
handler = CUSTOM_HANDLERS.get(test_name, run_test)
return_code = handler(test, test_directory, options)
end = time.perf_counter()
print_to_stderr(
f"Finished {str(test)} ... [{datetime.now()}][{end}], took {(end - start) / 60:.2f}min"
)
assert isinstance(return_code, int) and not isinstance(return_code, bool), (
f"While running {str(test)} got non integer return code {return_code}"
)

View File

@ -7413,6 +7413,140 @@ class TestCudaDeviceParametrized(TestCase):
)
class TestFXMemoryProfiler(TestCase):
"""Tests for memory profiler augmentation with original stack traces."""
def collect_frames(
self, augmented_snapshot, collect_device_traces=True, collect_segments=True
):
"""Collects all frames that has node metadata from a memory snapshot."""
# Collect all frames with FX metadata
fx_frames = []
# Check device traces for FX debug fields
if collect_device_traces and "device_traces" in augmented_snapshot:
for trace_list in augmented_snapshot["device_traces"]:
for trace_entry in trace_list:
if isinstance(trace_entry, dict) and "frames" in trace_entry:
for frame in trace_entry["frames"]:
if isinstance(frame, dict):
# Check for FX debug fields
if "fx_node_op" in frame or "fx_node_name" in frame:
fx_frames.append(frame)
# Check segments/blocks for FX debug fields
if collect_segments and "segments" in augmented_snapshot:
for segment in augmented_snapshot["segments"]:
if "blocks" in segment:
for block in segment["blocks"]:
if "frames" in block:
for frame in block["frames"]:
if isinstance(frame, dict):
if "fx_node_op" in frame or "fx_node_name" in frame:
fx_frames.append(frame)
return fx_frames
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_fx_memory_profiler_augmentation(self):
"""Test that memory snapshots are augmented with FX debug information."""
# Create a simple model
class MLPModule(nn.Module):
def __init__(self, device):
super().__init__()
torch.manual_seed(5)
self.net1 = nn.Linear(10, 16, bias=True, device=device)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 10, bias=True, device=device)
def forward(self, x):
a = self.net1(x)
b = self.relu(a)
c = self.net2(b)
return c
device = "cuda"
mod = MLPModule(device)
with tempfile.TemporaryDirectory() as tmpdir:
torch.cuda.memory._record_memory_history()
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
result = compiled(torch.randn(10, 10, device=device))
augmented_snapshot = torch.cuda.memory._snapshot(
augment_with_fx_traces=True
)
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
torch.cuda.empty_cache()
fx_frames = self.collect_frames(augmented_snapshot)
if TEST_WITH_ROCM:
self.assertGreater(len(fx_frames), 0)
else:
self.assertEqual(len(fx_frames), 12)
for frame in fx_frames:
# Every FX frame should have both node_op and node_name
self.assertIn("fx_node_op", frame)
self.assertIn("fx_node_name", frame)
self.assertIn("fx_node_target", frame)
self.assertIn("fx_original_trace", frame)
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
fx_node_name = frame["fx_node_name"]
if fx_node_name == "addmm":
self.assertIn("a = self.net1(x)", frame["fx_original_trace"])
elif fx_node_name == "addmm_1":
self.assertIn("c = self.net2(b)", frame["fx_original_trace"])
elif fx_node_name == "relu":
self.assertIn("b = self.relu(a)", frame["fx_original_trace"])
# Test that when we have two graphs with the same src_code, they're not hashed
# to the same metadata
class MLPModule2(nn.Module):
def __init__(self, device):
super().__init__()
torch.manual_seed(5)
self.net1 = nn.Linear(10, 16, bias=True, device=device)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 10, bias=True, device=device)
def forward(self, x):
d = self.net1(x)
e = self.relu(d)
f = self.net2(e)
return f
mod = MLPModule2(device)
with tempfile.TemporaryDirectory() as tmpdir:
torch.cuda.memory._record_memory_history()
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
result = compiled(torch.randn(10, 10, device=device))
augmented_snapshot = torch.cuda.memory._snapshot(
augment_with_fx_traces=True
)
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
# avoid collecting segments from previous run for unit test purpose
fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False)
self.assertGreater(len(fx_frames), 0)
for frame in fx_frames:
# Every FX frame should have both node_op and node_name
self.assertIn("fx_node_op", frame)
self.assertIn("fx_node_name", frame)
self.assertIn("fx_node_target", frame)
self.assertIn("fx_original_trace", frame)
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
fx_node_name = frame["fx_node_name"]
if fx_node_name == "addmm":
self.assertIn("d = self.net1(x)", frame["fx_original_trace"])
elif fx_node_name == "addmm_1":
self.assertIn("f = self.net2(e)", frame["fx_original_trace"])
elif fx_node_name == "relu":
self.assertIn("e = self.relu(d)", frame["fx_original_trace"])
instantiate_parametrized_tests(TestCuda)
instantiate_parametrized_tests(TestCudaMallocAsync)
instantiate_parametrized_tests(TestCompileKernel)

View File

@ -771,6 +771,7 @@ class TestFX(JitTestCase):
gm = GraphModule(tracer.root, graph)
expected = {1: 2, 2: 3, 3: 4, 4: 5}
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
self.assertEqual(gm._prologue_start, 4)
# test custom codegen
def transform_code(code):
@ -780,6 +781,7 @@ class TestFX(JitTestCase):
gm.recompile()
expected = {2: 2, 3: 3, 4: 4, 5: 5}
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
self.assertEqual(gm._prologue_start, 4)
def test_graph_unique_names_manual(self):
graph: torch.fx.Graph = torch.fx.Graph()

View File

@ -209,42 +209,36 @@ def infer_scale_swizzle(mat, scale):
] == math.ceil(mat.shape[1] // 128):
return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE
# if we're checking for nvfp4, need to adjust for packed-K
K_multiplier = 2 if mat.dtype == torch.float4_e2m1fn_x2 else 1
# NVFP4
if (
(scale.numel()
== round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 16), 4)
or scale.numel()
== round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4))
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 16), 4))
and mat.dtype == torch.float4_e2m1fn_x2
and scale.dtype == torch.float8_e4m3fn
):
return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
# MXFP4 w/o swizzle
if (
(scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0])
and mat.dtype == torch.float4_e2m1fn_x2
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
# MX formats
if not torch.version.hip:
# MXFP8 w/ swizzle
# MX w/swizzle (NVIDIA)
if (
(scale.numel()
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 32), 4)
or scale.numel()
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4))
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 32), 4))
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
else:
# MXFP8 w/o swizzle
# MX w/o swizzle (AMD)
if (
(scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0])
(scale.numel() == math.ceil(mat.shape[0] // 32) * K_multiplier * mat.shape[1]
or scale.numel() == math.ceil(K_multiplier * mat.shape[1] // 32) * mat.shape[0])
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
@ -1868,7 +1862,7 @@ class TestFP8Matmul(TestCase):
(127, 96, 1024),
(1025, 128, 96)
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
@ -1882,8 +1876,12 @@ class TestFP8Matmul(TestCase):
if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0):
raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping")
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
fp4_scaling_dtype = torch.float8_e8m0fnu if recipe == "mxfp4" else torch.float8_e4m3fn
BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
if K % BLOCK_SIZE != 0:
raise unittest.SkipTest(f"K ({K}) must be divisible by BLOCK_SIZE ({BLOCK_SIZE}), skipping")
require_exact_match = True
approx_match_sqnr_target = 22.0
@ -2061,7 +2059,7 @@ class TestFP8Matmul(TestCase):
B = B.clamp(min=min_val, max=max_val)
B = _bfloat16_to_float4_e2m1fn_x2(B)
approx_match_sqnr_target = 15 if torch.version.hip else 15.8
approx_match_sqnr_target = 15 if recipe == "mxfp4" else 15.8
C_ref = A_ref @ B_ref.t()

View File

@ -739,6 +739,12 @@ enable_aot_compile = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
default=False,
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403

View File

@ -179,6 +179,9 @@ def aot_stage1_graph_capture(
)
)
print(f"in aot_stage1_graph_capture. maybe_subclass_meta.fw_metadata.static_input_indices:{maybe_subclass_meta.fw_metadata.static_input_indices if maybe_subclass_meta is not None and maybe_subclass_meta.fw_metadata is not None else None}")
print(f"in aot_stage1_graph_capture. aot_state.fw_metadata.static_input_indices:{aot_state.fw_metadata.static_input_indices}")
return AOTGraphCapture(
wrappers=wrappers,
graph_module=graph,

View File

@ -423,6 +423,10 @@ def estimate_nccl_collective_runtime_from_fx_node(
from torch.distributed.distributed_c10d import _resolve_process_group
pg = _resolve_process_group(group_name)
if torch.distributed.distributed_c10d.get_backend(pg) == "fake":
# nccl estimator requires real process group
return None
fn = fx_node.target
assert isinstance(fn, torch._ops.OpOverload)
with torch.distributed._time_estimator(group=pg) as time_estimator:

View File

@ -2318,7 +2318,7 @@ def compile_fx_forward(
# force the outputs of invoke_subgraph subgraph to follow the
# original strides
_recursive_record_user_visible_output_idxs(gm)
print(f"in compile_fx_foward. static_input_idxs:{get_static_input_idxs(fixed)}")
return inner_compile(
gm,
example_inputs,

View File

@ -24,6 +24,7 @@ from typing_extensions import Never, ParamSpec
import torch._thread_safe_fork # noqa: F401
from torch._inductor import config
from torch._inductor.codecache import torch_key
from torch._inductor.compile_worker.timer import Timer
from torch._inductor.compile_worker.tracked_process_pool import (
TrackedProcessPoolExecutor,
)
@ -132,6 +133,7 @@ class SubprocPool:
nprocs: int,
pickler: Optional[SubprocPickler] = None,
kind: SubprocKind = SubprocKind.FORK,
quiesce: bool = False,
) -> None:
entry = os.path.join(os.path.dirname(__file__), "__main__.py")
self.pickler = pickler or SubprocPickler()
@ -216,6 +218,13 @@ class SubprocPool:
"pytorch.wait_counter.subproc_pool.first_job"
).guard()
if quiesce:
self.timer: Optional[Timer] = Timer(
config.quiesce_async_compile_time, self.quiesce
)
else:
self.timer = None
# Start thread last to ensure all member variables are initialized
# before any access.
self.read_thread.start()
@ -288,6 +297,8 @@ class SubprocPool:
with self.futures_lock:
if not self.running:
return
if self.timer:
self.timer.record_call()
if isinstance(result, _SubprocExceptionInfo):
# An exception occurred in the submitted job
self.pending_futures[job_id].set_exception(
@ -322,6 +333,8 @@ class SubprocPool:
with self.write_lock:
if not self.running:
return
if self.timer:
self.timer.quit()
self.running = False
self.running_waitcounter.__exit__()
_send_msg(self.write_pipe, MsgHeader.SHUTDOWN)

View File

@ -17,7 +17,7 @@ class Timer:
self.background_thread: Optional[Thread] = None
self.last_called: Optional[float] = None
self.duration = duration
self.sleep_time = 60
self.sleep_time = duration / 2
self.call = call
self.exit = False

View File

@ -964,6 +964,11 @@ quiesce_async_compile_pool: bool = Config(
default=False,
)
# Time in seconds to wait before quiescing
quiesce_async_compile_time: int = Config(
default=60,
)
# Whether or not to enable statically launching CUDA kernels
# compiled by triton (instead of using triton's own launcher)
use_static_cuda_launcher: bool = static_cuda_launcher_default()

View File

@ -51,8 +51,8 @@ from ..utils import (
decode_device,
get_all_devices,
get_gpu_type,
has_uses_tagged_as,
is_gpu,
is_pointwise_use,
OPTIMUS_EXCLUDE_POST_GRAD,
)
from ..virtualized import V
@ -1510,8 +1510,10 @@ def should_prefer_unfused_addmm(match):
if not is_gpu(inp.meta["val"].device.type):
return False
output = match.output_node()
return all(is_pointwise_use(use) for use in output.users)
return has_uses_tagged_as(
match.output_node(),
(torch.Tag.pointwise, torch.Tag.reduction),
)
@register_graph_pattern(

View File

@ -549,6 +549,70 @@ def is_pointwise_use(
return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)
class LogicalConnective(enum.Enum):
OR = enum.auto()
AND = enum.auto()
def has_uses(
target: Node,
use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
"""
Given a target, explore the uses of `target` by applying `use_selector_fn`
on them, and then aggregate these booleans with the `use_aggregate_type`
logical connective.
Uses in view ops will follow the views uses.
"""
def get_use_aggregate_fn(
use_aggregate_type: LogicalConnective,
) -> Callable[[Iterator[Any]], bool]:
match use_aggregate_type:
case LogicalConnective.AND:
return all
case LogicalConnective.OR:
return any
case _:
return any
use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)
def has_uses_impl(use: Node) -> bool:
if use.op != "call_function":
return False
if not (
isinstance(use.target, torch._ops.OpOverload)
or use.target is operator.getitem
):
return False
target = cast(torch._ops.OpOverload, use.target)
# Process getitem and view
if target is operator.getitem or is_view(target):
return use_aggregate_fn(has_uses_impl(user) for user in use.users)
return use_selector_fn(target)
return use_aggregate_fn(has_uses_impl(user) for user in target.users)
def has_uses_tagged_as(
target: Node,
use_tags: Collection[torch.Tag],
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
"""
Is there a use with given tags?
"""
return has_uses(
target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
)
def gen_gm_and_inputs(
target: Any, args: list[Any], kwargs: dict[str, Any]
) -> tuple[GraphModule, list[torch.Tensor]]:

View File

@ -31,10 +31,8 @@ template <typename T>
struct FromImpl {
static StableIValue call(
T val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
static_assert(
sizeof(T) <= sizeof(StableIValue),
"StableLibrary stack does not support parameter types larger than 64 bits.");
@ -75,10 +73,8 @@ template <>
struct FromImpl<ScalarType> {
static StableIValue call(
ScalarType val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
switch (val) {
case ScalarType::Byte:
return from(aoti_torch_dtype_uint8());
@ -133,10 +129,8 @@ template <>
struct FromImpl<std::nullopt_t> {
static StableIValue call(
std::nullopt_t val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
return from(nullptr);
}
};
@ -190,10 +184,8 @@ template <>
struct FromImpl<torch::stable::Tensor> {
static StableIValue call(
const torch::stable::Tensor& val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
AtenTensorHandle new_ath;
TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
return from(new_ath);
@ -209,10 +201,8 @@ template <typename T>
struct ToImpl {
static T call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
static_assert(std::is_trivially_copyable_v<T>);
// T may not have a default constructor. (For example, it might be
// c10::Device.) However, std::memcpy implicitly creates a T at the
@ -249,10 +239,8 @@ template <>
struct ToImpl<ScalarType> {
static ScalarType call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
int32_t shim_scalartype = to<int32_t>(val);
if (shim_scalartype == aoti_torch_dtype_uint8()) {
return ScalarType::Byte;
@ -309,10 +297,8 @@ template <>
struct ToImpl<std::nullopt_t> {
static std::nullopt_t call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
// val should be equivalent to from(nullptr)
return std::nullopt;
}
@ -350,10 +336,8 @@ template <>
struct ToImpl<torch::stable::Tensor> {
static torch::stable::Tensor call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
return torch::stable::Tensor(to<AtenTensorHandle>(val));
}
};

View File

@ -4,12 +4,14 @@ r"""This package adds support for device memory management implemented in CUDA."
import collections
import contextlib
import ctypes
import os
import pickle
import re
import sys
import warnings
from inspect import signature
from typing import Any, Literal, Optional, TYPE_CHECKING
from typing_extensions import deprecated
from typing import Any, cast, Literal, Optional, TYPE_CHECKING, TypedDict
from typing_extensions import deprecated, NotRequired
import torch
from torch import _C
@ -29,6 +31,60 @@ if TYPE_CHECKING:
from torch.types import Device
# Type definitions for memory profiler
class _Frame(TypedDict):
"""Frame information from memory profiler snapshots."""
filename: str
line: int
name: str
# Fields added by FX augmentation (optional)
fx_node_op: NotRequired[str]
fx_node_name: NotRequired[str]
fx_node_target: NotRequired[str]
fx_original_trace: NotRequired[str]
class _Block(TypedDict):
"""Memory block information."""
size: int
requested_size: int
address: int
state: str
frames: list[_Frame]
class _Segment(TypedDict):
"""Memory segment information."""
address: int
total_size: int
stream: int
segment_type: str
allocated_size: int
active_size: int
blocks: list[_Block]
class _TraceEntry(TypedDict):
"""Memory trace entry information."""
action: str
addr: NotRequired[int]
frames: list[_Frame]
size: int
stream: int
device_free: NotRequired[int]
class _Snapshot(TypedDict):
"""Memory snapshot structure."""
segments: list[_Segment]
device_traces: NotRequired[list[list[_TraceEntry]]]
__all__ = [
"caching_allocator_alloc",
"caching_allocator_delete",
@ -964,7 +1020,120 @@ def _record_memory_history_impl(
_record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined]
def _snapshot(device: "Device" = None):
def _augment_frames(frames: list[_Frame]) -> int:
"""
Augment a list of frames with FX debug information.
Args:
frames: List of frame dictionaries to augment
Returns:
The count of frames that were augmented.
"""
from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX
# Regex pattern to match FX generated files
_FX_GENERATED_PATTERN = re.compile(
rf"{re.escape(FX_GRAPH_MODULE_FILE_PREFIX)}.*\.py$"
)
count = 0
if not frames:
return count
for frame in frames:
if "filename" in frame and "line" in frame:
filename = frame["filename"]
lineno = frame["line"]
# Check if this looks like an FX generated file
if not _FX_GENERATED_PATTERN.search(os.path.basename(filename)):
continue
# Look up metadata from the global registry
from torch.fx.traceback import _FX_METADATA_REGISTRY
metadata = _FX_METADATA_REGISTRY.get(filename)
if metadata is None:
continue
lineno_map = metadata.get("lineno_map", {})
node_metadata = metadata.get("node_metadata", {})
prologue_start = metadata.get("prologue_start", 0)
# Get the node index for this line
node_idx = lineno_map.get(lineno - prologue_start)
if node_idx is not None and node_idx in node_metadata:
node_info = node_metadata[node_idx]
original_trace = node_info.get("stack_trace")
node_op = node_info.get("op")
node_name = node_info.get("name")
node_target = node_info.get("target")
# Always add node metadata
frame["fx_node_op"] = node_op
frame["fx_node_name"] = node_name
frame["fx_node_target"] = str(node_target)
# Add original trace if available
if original_trace:
frame["fx_original_trace"] = original_trace
count += 1
return count
def _augment_memory_snapshot_stack_traces(
snapshot: str | _Snapshot,
) -> _Snapshot:
"""
Augment a memory snapshot with original source stack traces from FX metadata.
IMPORTANT: This function reads from a global in-memory registry (_FX_METADATA_REGISTRY)
that is populated during graph module compilation. It must be called in the same
Python process where the FX graphs were compiled. It cannot be used to augment
snapshots loaded from disk in a different process.
Args:
snapshot: Either a memory snapshot dict or path to a snapshot pickle file
Returns:
The augmented snapshot dictionary with fx_node_op, fx_node_name,
fx_original_trace, and fx_node_info fields added to frames
"""
snapshot_dict: _Snapshot
if isinstance(snapshot, str):
# Load the memory snapshot
with open(snapshot, "rb") as f:
snapshot_dict = cast(_Snapshot, pickle.load(f))
else:
snapshot_dict = snapshot
# Process stack traces in the snapshot
augmented_count = 0
# Process blocks in segments (for regular allocations)
if "segments" in snapshot_dict:
for segment in snapshot_dict["segments"]:
if "blocks" in segment:
for block in segment["blocks"]:
if "frames" in block:
augmented_count += _augment_frames(block["frames"])
# Process device traces (for memory history)
if "device_traces" in snapshot_dict:
for trace_list in snapshot_dict["device_traces"]:
for trace_entry in trace_list:
if isinstance(trace_entry, dict) and "frames" in trace_entry:
augmented_count += _augment_frames(trace_entry["frames"])
return snapshot_dict
def _snapshot(device: "Device" = None, augment_with_fx_traces=False):
"""Save a snapshot of CUDA memory state at the time it was called.
The state is represented as a dictionary with the following structure.
@ -1012,6 +1181,11 @@ def _snapshot(device: "Device" = None):
filename: str
line: int
name: str
# Optional FX debug fields (present when augment_with_fx_traces=True
# and the frame corresponds to FX-generated code)
fx_node_op: str # FX node operation type (e.g., 'call_function', 'output')
fx_node_name: str # FX node name (e.g., 'linear', 'relu_1')
fx_original_trace: str # Original model source code stack trace
class TraceEntry(TypedDict):
@ -1041,13 +1215,23 @@ def _snapshot(device: "Device" = None):
device_free: int # only present for OOM, the amount of
# memory cuda still reports to be free
Args:
device: Device to capture snapshot for. If None, captures for current device.
augment_with_fx_traces: If True, augment stack trace frames with FX debug information
that maps generated FX code back to original model source code.
This adds fx_node_op, fx_node_name, fx_original_trace, and
fx_node_info fields to Frame objects. Default: False.
Returns:
The Snapshot dictionary object
"""
return _C._cuda_memorySnapshot(None)
s = _C._cuda_memorySnapshot(None)
if augment_with_fx_traces:
s = _augment_memory_snapshot_stack_traces(s) # type: ignore[assignment, arg-type]
return s
def _dump_snapshot(filename="dump_snapshot.pickle"):
def _dump_snapshot(filename="dump_snapshot.pickle", augment_with_fx_traces=False):
"""
Save a pickled version of the `torch.memory._snapshot()` dictionary to a file.
@ -1059,8 +1243,14 @@ def _dump_snapshot(filename="dump_snapshot.pickle"):
Args:
filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
augment_with_fx_traces (bool, optional): If True, augment the snapshot with FX debug information
before dumping. This maps generated FX code stack traces
back to original model source code. Defaults to False.
verbose (bool, optional): If True and augment_with_fx_traces is True, print verbose debug output
during augmentation. Defaults to False.
"""
s = _snapshot()
s = _snapshot(augment_with_fx_traces=augment_with_fx_traces)
with open(filename, "wb") as f:
pickle.dump(s, f)

View File

@ -226,8 +226,10 @@ class PythonCode:
# Values in global scope during execution of `src_def`.
globals: dict[str, Any]
# Optional mapping from the forward function's line number to
# node index.
# node index. Line number starts at the prologue (i.e. forward()).
_lineno_map: Optional[dict[int, Optional[int]]]
# The line number of prologue in fn_code
_prologue_start: int = 0
def _format_target(base: str, target: str) -> str:
@ -854,7 +856,14 @@ class CodeGen:
{prologue}
{code}"""
return PythonCode(fn_code, globals_, _lineno_map=lineno_map)
# The +4 accounts for the empty lines before prologue in fn_code
prologue_start = wrap_stmts.count("\n") + 4
return PythonCode(
fn_code,
globals_,
_lineno_map=lineno_map,
_prologue_start=prologue_start,
)
# Ideally, we'd like to refactor all of the pytree logic into this codegen

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs
import base64
import contextlib
import copy
import hashlib
import itertools
import linecache
import os
@ -36,6 +38,7 @@ __all__ = [
]
_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"
FX_GRAPH_MODULE_FILE_PREFIX = "fx_generated_"
# Normal exec loses the source code, however we can work with
@ -61,7 +64,13 @@ class _EvalCacheLoader:
key = self._get_key()
if co_fields:
key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
if "co_filename" in co_fields:
# If only co_filename is provided, use it directly as the key
if "co_firstlineno" not in co_fields or "co_name" not in co_fields:
key = co_fields["co_filename"]
else:
# Full co_fields with all three components
key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
self.eval_cache[key] = src
# Don't mutate globals so that this loader is only used
@ -353,6 +362,36 @@ def _print_readable(
return output
def _metadata_hash(code: str, node_metadata: dict) -> str:
"""
Create a content-addressed hash from code and metadata.
Args:
code: The source code string
lineno_map: Mapping from line numbers to node indices
node_metadata: Metadata for each node
Returns:
A 51-character base32-encoded hash
"""
import json
# Create a deterministic string representation of all components
# We use JSON to ensure consistent serialization
hash_data = {
"code": code,
"node_metadata": node_metadata,
}
hashing_str = json.dumps(hash_data).encode("utf-8")
# [:51] to strip off the "Q====" suffix common to every hash value.
return (
base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
.decode("utf-8")
.lower()
)
class _WrappedCall:
def __init__(self, cls, cls_call):
self.cls = cls
@ -825,9 +864,47 @@ class {module_name}(torch.nn.Module):
python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
self._lineno_map = python_code._lineno_map
self._prologue_start = python_code._prologue_start
cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
from torch._dynamo import config as dynamo_config
if dynamo_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
node_metadata: dict[int, dict[str, Any]] = {}
for i, node in enumerate(self._graph.nodes):
node_metadata[i] = {
"name": node.name,
"op": node.op,
"target": str(node.target),
"stack_trace": node.meta.get("stack_trace", None),
}
# Generate a content-addressed filename based on hash of code and metadata
# This ensures the same code+metadata always generates the same filename
hash_value = _metadata_hash(self._code, node_metadata)
file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"
filename = f"{file_stem}.py"
# Only include co_filename to use it directly as the cache key
co_fields = {
"co_filename": filename,
}
# Store metadata in global in-memory registry
metadata = {
"lineno_map": python_code._lineno_map,
"prologue_start": python_code._prologue_start,
"node_metadata": node_metadata,
}
# Register metadata in the global registry
from torch.fx.traceback import _register_fx_metadata
_register_fx_metadata(filename, metadata)
cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
# Determine whether this class explicitly defines a __call__ implementation

View File

@ -38,6 +38,28 @@ current_meta: dict[str, Any] = {}
current_replay_node: Optional[Node] = None
should_preserve_node_meta = False
# =============================================================================
# FX Metadata Registry for Memory Profiler
# =============================================================================
# Global in-memory registry for FX metadata
# Maps module_name -> metadata dict containing lineno_map and node_metadata
_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}
def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
"""
Register FX metadata in the global in-memory registry.
This is called automatically during graph module compilation to store metadata
for later use by memory profiler augmentation.
Args:
module_name: The module identifier (content-addressed filename)
metadata: Metadata dict containing lineno_map, node_metadata, and source_code
"""
# TODO: add logging to tlparse
_FX_METADATA_REGISTRY[module_name] = metadata
@compatibility(is_backward_compatible=False)
class NodeSourceAction(Enum):

View File

@ -806,7 +806,29 @@ function format_frames(frames) {
}
const frame_strings = frames
.filter(frameFilter)
.map(f => `${f.filename}:${f.line}:${f.name}`);
.map(f => {
let frame_str = `${f.filename}:${f.line}:${f.name}`;
// Add FX debug information if available
if (f.fx_node_op || f.fx_node_name || f.fx_node_target) {
const fx_parts = [];
if (f.fx_node_name) fx_parts.push(`node=${f.fx_node_name}`);
if (f.fx_node_op) fx_parts.push(`op=${f.fx_node_op}`);
if (f.fx_node_target) fx_parts.push(`target=${f.fx_node_target}`);
frame_str += `\n >> FX: ${fx_parts.join(', ')}`;
}
if (f.fx_original_trace) {
frame_str += `\n >> Original Model Code:`;
const original_lines = f.fx_original_trace.trim().split('\n');
// Show all lines of the original trace
for (const line of original_lines) {
frame_str += `\n ${line}`;
}
}
return frame_str;
});
return elideRepeats(frame_strings).join('\n');
}