diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 4a9ad6e0a081..3419834d8567 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -1497,7 +1497,7 @@ class TestMaxAutotune(TestCase): ).run(code[0]) else: FileCheck().check("extern_kernels.bmm_dtype").check_regex( - "triton_.*_fused_mm_0.run" + "triton_.*_fused_0.run" ).check("decompose_k").run(code[0]) check_divisors(code) torch.testing.assert_close( @@ -1518,7 +1518,7 @@ class TestMaxAutotune(TestCase): ).run(code[0]) else: FileCheck().check("extern_kernels.bmm_dtype").check_regex( - "triton_.*_fused_.*_0.run" + "triton_.*_fused_0.run" ).check("decompose_k").run(code[0]) check_divisors(code) torch.testing.assert_close( diff --git a/test/inductor/test_provenance_tracing.py b/test/inductor/test_provenance_tracing.py index c898e6276537..16fefea92efb 100644 --- a/test/inductor/test_provenance_tracing.py +++ b/test/inductor/test_provenance_tracing.py @@ -13,7 +13,6 @@ import zipfile from pathlib import Path import torch -from torch._C import FileCheck from torch._dynamo.utils import detect_fake_mode from torch._inductor import config from torch._inductor.debug import ( @@ -24,7 +23,6 @@ from torch._inductor.debug import ( ) from torch._inductor.fx_passes.post_grad import post_grad_passes from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import run_and_get_code from torch._inductor.virtualized import V from torch.testing._internal.common_utils import IS_MACOS from torch.testing._internal.triton_utils import requires_cuda_and_triton @@ -592,47 +590,6 @@ class TestProvenanceTracingStackTraces(TestCase): f"Mismatch for key: {key}", ) - @torch._inductor.config.patch( - {"trace.provenance_tracking_level": 2, "max_autotune_gemm_backends": "ATEN"} - ) - @requires_cuda_and_triton - def test_deferred_triton_kernels(self): - def foo(m, inp): - a = m(inp) - return a - - foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo) - - m = torch.nn.Linear(512, 512, bias=True).half().cuda() - inp = torch.rand([1, 512]).half().cuda() - - expected = { - "extern_kernels.bias_addmm:1": [ - "a = m(inp)", - ], - } - - with self._setup_provenance_capture() as payload_buffer: - with torch.no_grad(): - _, out_code = run_and_get_code(foo_c, m, inp) - payload_content = payload_buffer.getvalue().strip() - data = json.loads(payload_content) - self.assertEqual(set(data.keys()), set(expected.keys())) - for key, expected_lines in expected.items(): - actual_lines = [self.extract_code_line(s) for s in data[key]] - self.assertEqual( - sorted(actual_lines), - sorted(expected_lines), - f"Mismatch for key: {key}", - ) - - # Check that debug handle is in the output code - FileCheck().check( - "Topologically Sorted Source Nodes: [a], Original ATen: [aten.t, aten.addmm]" - ).check("[Provenance debug handles] extern_kernels.bias_addmm:1").check( - "extern_kernels.bias_addmm" - ).run(out_code[0]) - def _check_kernel_information_json(self, kernel_info, expected_kernels): """Validate kernel information JSON structure and content.""" self.assertIsInstance(kernel_info, dict) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 5e5a0e631d9c..1f48b61c2a10 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -68,7 +68,6 @@ from .exc import ( ) from .fx_utils import count_flops_fx from .ir import ( - assign_origin_node, Constant, DonatedBuffer, FixedLayout, @@ -1835,7 +1834,31 @@ class GraphLowering(torch.fx.Interpreter): if curr.has_large_inner_fn(threshold=100): result.realize() - assign_origin_node(result, n) + # This is not complete, but it doesn't have to be: origin_node + # tracking is best effort. The logic here critically relies on direct + # TensorBox -> StorageBox denoting a non-view; we don't bother trying + # to get views to work. Feel free to add any extra cases as needed. + # + # Note: we can't YOLO tree_map over this result, because if there are + # buffers or a view involved, we might not be able to validly assign + # the origin_node here. + if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox): + if isinstance(result.data.data, ir.Loops): + result.data.data._post_init_setattr("origin_node", n) + elif isinstance(result.data.data, ir.Buffer): + result.data.data._post_init_setattr("origin_node", n) + if isinstance(result.data.data, ir.ComputedBuffer) and isinstance( + result.data.data.data, ir.Loops + ): + result.data.data.data._post_init_setattr("origin_node", n) + # Not really multi-output, can straightforwardly recurse in + elif ( + isinstance(result.data.data, ir.MultiOutput) + and not result.data.data.indices + ): + if isinstance(result.data.data.inputs[0], ir.Buffer): + result.data.data.inputs[0]._post_init_setattr("origin_node", n) + self.register_users_of(result) new_unbacked_defs = OrderedSet[sympy.Symbol]() diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index a874d1931a7d..10d596c99bf1 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -9399,30 +9399,3 @@ def maybe_free_symbols(s: object) -> OrderedSet[Symbol]: return free_symbols(s) else: return OrderedSet() - - -def assign_origin_node(result: Any, n: torch.fx.Node) -> None: - # This is not complete, but it doesn't have to be: origin_node - # tracking is best effort. The logic here critically relies on direct - # TensorBox -> StorageBox denoting a non-view; we don't bother trying - # to get views to work. Feel free to add any extra cases as needed. - # - # Note: we can't YOLO tree_map over this result, because if there are - # buffers or a view involved, we might not be able to validly assign - # the origin_node here. - if isinstance(result, TensorBox) and isinstance(result.data, StorageBox): - if isinstance(result.data.data, Loops): - result.data.data._post_init_setattr("origin_node", n) - elif isinstance(result.data.data, Buffer): - result.data.data._post_init_setattr("origin_node", n) - if isinstance(result.data.data, ComputedBuffer) and isinstance( - result.data.data.data, Loops - ): - result.data.data.data._post_init_setattr("origin_node", n) - # Not really multi-output, can straightforwardly recurse in - elif ( - isinstance(result.data.data, MultiOutput) - and not result.data.data.indices - ): - if isinstance(result.data.data.inputs[0], Buffer): - result.data.data.inputs[0]._post_init_setattr("origin_node", n) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index adf203c9bca8..878c5992c3f9 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -50,7 +50,6 @@ from .dependencies import Dep, MemoryDep, StarDep, WeakDep from .exc import GPUTooOldForTriton, TritonMissing from .fx_utils import count_flops_fx from .ir import ( - assign_origin_node, get_device_type, GraphPartitionSignature, MultiOutput, @@ -3164,16 +3163,12 @@ class Scheduler: node.node.finalize_as_triton_caller(min_node_unfused) continue - with ir.IRNode.current_origins(multi_node.origins): - out_tensorbox = min_node_unfused.output_node() + out_tensorbox = min_node_unfused.output_node() out_storage = out_tensorbox.data # type: ignore[union-attr] assert isinstance(out_storage, ir.StorageBox) out_buffer = out_storage.data assert isinstance(out_buffer, ir.OperationBuffer) - if multi_node.origin_node: - assign_origin_node(out_tensorbox, multi_node.origin_node) - out_buffer.layout = multi_node.layout replace_operation_buffer(multi_node, out_buffer) new_scheduler_node = self.create_scheduler_node(out_buffer)