Revert "Add provenance to inductor IR nodes created after graph.run (#164255)"

This reverts commit b9e73e639e36f3aa628752161711e68878231b30.

Reverted https://github.com/pytorch/pytorch/pull/164255 on behalf of https://github.com/jeffdaily because it broke ROCm: inductor/test_provenance_tracing.py::TestProvenanceTracingStackTraces::test_deferred_triton_kernels fails [GH job link](https://github.com/pytorch/pytorch/actions/runs/18200790301/job/51821738132) [HUD commit link](b9e73e639e) ([comment](https://github.com/pytorch/pytorch/pull/164255#issuecomment-3363360088))
This commit is contained in:
PyTorch MergeBot
2025-10-02 22:01:41 +00:00
parent f465ea6752
commit a34797e031
5 changed files with 28 additions and 80 deletions

View File

@ -1497,7 +1497,7 @@ class TestMaxAutotune(TestCase):
).run(code[0])
else:
FileCheck().check("extern_kernels.bmm_dtype").check_regex(
"triton_.*_fused_mm_0.run"
"triton_.*_fused_0.run"
).check("decompose_k").run(code[0])
check_divisors(code)
torch.testing.assert_close(
@ -1518,7 +1518,7 @@ class TestMaxAutotune(TestCase):
).run(code[0])
else:
FileCheck().check("extern_kernels.bmm_dtype").check_regex(
"triton_.*_fused_.*_0.run"
"triton_.*_fused_0.run"
).check("decompose_k").run(code[0])
check_divisors(code)
torch.testing.assert_close(

View File

@ -13,7 +13,6 @@ import zipfile
from pathlib import Path
import torch
from torch._C import FileCheck
from torch._dynamo.utils import detect_fake_mode
from torch._inductor import config
from torch._inductor.debug import (
@ -24,7 +23,6 @@ from torch._inductor.debug import (
)
from torch._inductor.fx_passes.post_grad import post_grad_passes
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch._inductor.virtualized import V
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.triton_utils import requires_cuda_and_triton
@ -592,47 +590,6 @@ class TestProvenanceTracingStackTraces(TestCase):
f"Mismatch for key: {key}",
)
# Runs this test with provenance tracking level 2 and the max-autotune GEMM
# backends config set to "ATEN".
@torch._inductor.config.patch(
{"trace.provenance_tracking_level": 2, "max_autotune_gemm_backends": "ATEN"}
)
@requires_cuda_and_triton
def test_deferred_triton_kernels(self):
# Compiles a half-precision Linear forward in "max-autotune-no-cudagraphs"
# mode, then checks (1) the captured provenance payload and (2) the
# generated output code.
#
# NOTE(review): indentation is not preserved in this view, so it is unclear
# which of the statements below sit inside the `with` blocks -- check the
# original file before restructuring.
#
# The statement text `a = m(inp)` reappears in `expected`, and the name `a`
# in the FileCheck pattern at the end; presumably both are derived from this
# function's source, so keep it verbatim.
def foo(m, inp):
a = m(inp)
return a
foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
m = torch.nn.Linear(512, 512, bias=True).half().cuda()
inp = torch.rand([1, 512]).half().cuda()
# Expected payload: provenance key -> list of source lines. The captured
# entries are passed through self.extract_code_line before comparison.
expected = {
"extern_kernels.bias_addmm:1": [
"a = m(inp)",
],
}
with self._setup_provenance_capture() as payload_buffer:
with torch.no_grad():
_, out_code = run_and_get_code(foo_c, m, inp)
# The captured text is parsed as JSON; its key set must equal expected's.
payload_content = payload_buffer.getvalue().strip()
data = json.loads(payload_content)
self.assertEqual(set(data.keys()), set(expected.keys()))
# Order-insensitive comparison of the source lines recorded for each key.
for key, expected_lines in expected.items():
actual_lines = [self.extract_code_line(s) for s in data[key]]
self.assertEqual(
sorted(actual_lines),
sorted(expected_lines),
f"Mismatch for key: {key}",
)
# Check that debug handle is in the output code
FileCheck().check(
"Topologically Sorted Source Nodes: [a], Original ATen: [aten.t, aten.addmm]"
).check("[Provenance debug handles] extern_kernels.bias_addmm:1").check(
"extern_kernels.bias_addmm"
).run(out_code[0])
def _check_kernel_information_json(self, kernel_info, expected_kernels):
"""Validate kernel information JSON structure and content."""
self.assertIsInstance(kernel_info, dict)

View File

@ -68,7 +68,6 @@ from .exc import (
)
from .fx_utils import count_flops_fx
from .ir import (
assign_origin_node,
Constant,
DonatedBuffer,
FixedLayout,
@ -1835,7 +1834,31 @@ class GraphLowering(torch.fx.Interpreter):
if curr.has_large_inner_fn(threshold=100):
result.realize()
assign_origin_node(result, n)
# This is not complete, but it doesn't have to be: origin_node
# tracking is best effort. The logic here critically relies on direct
# TensorBox -> StorageBox denoting a non-view; we don't bother trying
# to get views to work. Feel free to add any extra cases as needed.
#
# Note: we can't YOLO tree_map over this result, because if there are
# buffers or a view involved, we might not be able to validly assign
# the origin_node here.
if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox):
if isinstance(result.data.data, ir.Loops):
result.data.data._post_init_setattr("origin_node", n)
elif isinstance(result.data.data, ir.Buffer):
result.data.data._post_init_setattr("origin_node", n)
if isinstance(result.data.data, ir.ComputedBuffer) and isinstance(
result.data.data.data, ir.Loops
):
result.data.data.data._post_init_setattr("origin_node", n)
# Not really multi-output, can straightforwardly recurse in
elif (
isinstance(result.data.data, ir.MultiOutput)
and not result.data.data.indices
):
if isinstance(result.data.data.inputs[0], ir.Buffer):
result.data.data.inputs[0]._post_init_setattr("origin_node", n)
self.register_users_of(result)
new_unbacked_defs = OrderedSet[sympy.Symbol]()

View File

@ -9399,30 +9399,3 @@ def maybe_free_symbols(s: object) -> OrderedSet[Symbol]:
return free_symbols(s)
else:
return OrderedSet()
def assign_origin_node(result: Any, n: torch.fx.Node) -> None:
# Records FX node `n` as the "origin_node" attribute (through
# `_post_init_setattr`) on IR objects found beneath `result`. Returns None;
# a `result` that is not a TensorBox wrapping a StorageBox is left untouched.
#
# This is not complete, but it doesn't have to be: origin_node
# tracking is best effort. The logic here critically relies on direct
# TensorBox -> StorageBox denoting a non-view; we don't bother trying
# to get views to work. Feel free to add any extra cases as needed.
#
# Note: we can't YOLO tree_map over this result, because if there are
# buffers or a view involved, we might not be able to validly assign
# the origin_node here.
if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
# `result.data.data` is the object held by the StorageBox.
if isinstance(result.data.data, Loops):
result.data.data._post_init_setattr("origin_node", n)
elif isinstance(result.data.data, Buffer):
result.data.data._post_init_setattr("origin_node", n)
# A ComputedBuffer wrapping Loops: tag the inner Loops object as well.
if isinstance(result.data.data, ComputedBuffer) and isinstance(
result.data.data.data, Loops
):
result.data.data.data._post_init_setattr("origin_node", n)
# NOTE(review): indentation is lost in this view, so it is ambiguous
# whether the `elif` below pairs with the ComputedBuffer check just above
# or with the outer Loops/Buffer chain -- confirm against the original file.
# Not really multi-output, can straightforwardly recurse in
elif (
isinstance(result.data.data, MultiOutput)
and not result.data.data.indices
):
# Only the first input is considered, and only when it is a Buffer.
if isinstance(result.data.data.inputs[0], Buffer):
result.data.data.inputs[0]._post_init_setattr("origin_node", n)

View File

@ -50,7 +50,6 @@ from .dependencies import Dep, MemoryDep, StarDep, WeakDep
from .exc import GPUTooOldForTriton, TritonMissing
from .fx_utils import count_flops_fx
from .ir import (
assign_origin_node,
get_device_type,
GraphPartitionSignature,
MultiOutput,
@ -3164,16 +3163,12 @@ class Scheduler:
node.node.finalize_as_triton_caller(min_node_unfused)
continue
with ir.IRNode.current_origins(multi_node.origins):
out_tensorbox = min_node_unfused.output_node()
out_tensorbox = min_node_unfused.output_node()
out_storage = out_tensorbox.data # type: ignore[union-attr]
assert isinstance(out_storage, ir.StorageBox)
out_buffer = out_storage.data
assert isinstance(out_buffer, ir.OperationBuffer)
if multi_node.origin_node:
assign_origin_node(out_tensorbox, multi_node.origin_node)
out_buffer.layout = multi_node.layout
replace_operation_buffer(multi_node, out_buffer)
new_scheduler_node = self.create_scheduler_node(out_buffer)