Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Add provenance to inductor IR nodes created after graph.run (#164255)"
This reverts commit b9e73e639e36f3aa628752161711e68878231b30.
Reverted https://github.com/pytorch/pytorch/pull/164255 on behalf of https://github.com/jeffdaily due to breaking ROCm: inductor/test_provenance_tracing.py::TestProvenanceTracingStackTraces::test_deferred_triton_kernels ([GH job link](https://github.com/pytorch/pytorch/actions/runs/18200790301/job/51821738132), [HUD commit link](b9e73e639e)) ([comment](https://github.com/pytorch/pytorch/pull/164255#issuecomment-3363360088))
```diff
--- a/test/inductor/test_max_autotune.py
+++ b/test/inductor/test_max_autotune.py
@@ -1497,7 +1497,7 @@ class TestMaxAutotune(TestCase):
             ).run(code[0])
         else:
             FileCheck().check("extern_kernels.bmm_dtype").check_regex(
-                "triton_.*_fused_mm_0.run"
+                "triton_.*_fused_0.run"
             ).check("decompose_k").run(code[0])
             check_divisors(code)
             torch.testing.assert_close(
@@ -1518,7 +1518,7 @@ class TestMaxAutotune(TestCase):
             ).run(code[0])
         else:
             FileCheck().check("extern_kernels.bmm_dtype").check_regex(
-                "triton_.*_fused_.*_0.run"
+                "triton_.*_fused_0.run"
             ).check("decompose_k").run(code[0])
             check_divisors(code)
             torch.testing.assert_close(
```
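Both hunks swap back the expected fused-kernel name that the reverted provenance change had enriched with op names. For context on the assertion pattern itself: FileCheck matches substrings and regexes in order against a blob of text. A minimal runnable sketch — the `generated` string below is fabricated for illustration, not real Inductor output:

```python
from torch._C import FileCheck

# Fabricated stand-in for Inductor-generated wrapper code.
generated = """
extern_kernels.bmm_dtype(buf0, arg0, arg1)
triton_poi_fused_0.run(buf1, 512, grid=grid(512))
# chosen template: decompose_k
"""

# Each .check()/.check_regex() must match at or after the previous match;
# .run() raises a RuntimeError when a pattern is missing or out of order.
FileCheck().check("extern_kernels.bmm_dtype").check_regex(
    "triton_.*_fused_0.run"
).check("decompose_k").run(generated)
```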
```diff
--- a/test/inductor/test_provenance_tracing.py
+++ b/test/inductor/test_provenance_tracing.py
@@ -13,7 +13,6 @@ import zipfile
 from pathlib import Path
 
 import torch
-from torch._C import FileCheck
 from torch._dynamo.utils import detect_fake_mode
 from torch._inductor import config
 from torch._inductor.debug import (
@@ -24,7 +23,6 @@ from torch._inductor.debug import (
 )
 from torch._inductor.fx_passes.post_grad import post_grad_passes
 from torch._inductor.test_case import run_tests, TestCase
-from torch._inductor.utils import run_and_get_code
 from torch._inductor.virtualized import V
 from torch.testing._internal.common_utils import IS_MACOS
 from torch.testing._internal.triton_utils import requires_cuda_and_triton
```
```diff
@@ -592,47 +590,6 @@ class TestProvenanceTracingStackTraces(TestCase):
                     f"Mismatch for key: {key}",
                 )
 
-    @torch._inductor.config.patch(
-        {"trace.provenance_tracking_level": 2, "max_autotune_gemm_backends": "ATEN"}
-    )
-    @requires_cuda_and_triton
-    def test_deferred_triton_kernels(self):
-        def foo(m, inp):
-            a = m(inp)
-            return a
-
-        foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
-
-        m = torch.nn.Linear(512, 512, bias=True).half().cuda()
-        inp = torch.rand([1, 512]).half().cuda()
-
-        expected = {
-            "extern_kernels.bias_addmm:1": [
-                "a = m(inp)",
-            ],
-        }
-
-        with self._setup_provenance_capture() as payload_buffer:
-            with torch.no_grad():
-                _, out_code = run_and_get_code(foo_c, m, inp)
-            payload_content = payload_buffer.getvalue().strip()
-            data = json.loads(payload_content)
-            self.assertEqual(set(data.keys()), set(expected.keys()))
-            for key, expected_lines in expected.items():
-                actual_lines = [self.extract_code_line(s) for s in data[key]]
-                self.assertEqual(
-                    sorted(actual_lines),
-                    sorted(expected_lines),
-                    f"Mismatch for key: {key}",
-                )
-
-        # Check that debug handle is in the output code
-        FileCheck().check(
-            "Topologically Sorted Source Nodes: [a], Original ATen: [aten.t, aten.addmm]"
-        ).check("[Provenance debug handles] extern_kernels.bias_addmm:1").check(
-            "extern_kernels.bias_addmm"
-        ).run(out_code[0])
-
     def _check_kernel_information_json(self, kernel_info, expected_kernels):
         """Validate kernel information JSON structure and content."""
         self.assertIsInstance(kernel_info, dict)
```
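The deleted test drives a compiled module under max-autotune, reads back the provenance JSON payload (via the file's _setup_provenance_capture helper), and FileChecks the debug handles in the generated wrapper. The run_and_get_code helper it uses returns both the result and the generated source; a hedged sketch with a made-up toy function (CUDA assumed, as in the test):

```python
import torch
from torch._inductor.utils import run_and_get_code

def toy(x):
    # made-up function for illustration; not from the deleted test
    return torch.relu(x) + 1

compiled = torch.compile(toy)
result, code = run_and_get_code(compiled, torch.randn(32, device="cuda"))
# `code` is a list of generated source strings; code[0] is the main wrapper,
# which is what the FileCheck assertions above scan for debug handles.
print(code[0][:300])
```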
```diff
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -68,7 +68,6 @@ from .exc import (
 )
 from .fx_utils import count_flops_fx
 from .ir import (
-    assign_origin_node,
     Constant,
     DonatedBuffer,
     FixedLayout,
```
```diff
@@ -1835,7 +1834,31 @@ class GraphLowering(torch.fx.Interpreter):
             if curr.has_large_inner_fn(threshold=100):
                 result.realize()
 
-        assign_origin_node(result, n)
+        # This is not complete, but it doesn't have to be: origin_node
+        # tracking is best effort. The logic here critically relies on direct
+        # TensorBox -> StorageBox denoting a non-view; we don't bother trying
+        # to get views to work. Feel free to add any extra cases as needed.
+        #
+        # Note: we can't YOLO tree_map over this result, because if there are
+        # buffers or a view involved, we might not be able to validly assign
+        # the origin_node here.
+        if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox):
+            if isinstance(result.data.data, ir.Loops):
+                result.data.data._post_init_setattr("origin_node", n)
+            elif isinstance(result.data.data, ir.Buffer):
+                result.data.data._post_init_setattr("origin_node", n)
+                if isinstance(result.data.data, ir.ComputedBuffer) and isinstance(
+                    result.data.data.data, ir.Loops
+                ):
+                    result.data.data.data._post_init_setattr("origin_node", n)
+            # Not really multi-output, can straightforwardly recurse in
+            elif (
+                isinstance(result.data.data, ir.MultiOutput)
+                and not result.data.data.indices
+            ):
+                if isinstance(result.data.data.inputs[0], ir.Buffer):
+                    result.data.data.inputs[0]._post_init_setattr("origin_node", n)
 
         self.register_users_of(result)
 
         new_unbacked_defs = OrderedSet[sympy.Symbol]()
```
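The restored inline block (previously factored into assign_origin_node) walks Inductor's wrapper chain: a TensorBox directly wrapping a StorageBox denotes a non-view, and the originating torch.fx.Node is pinned on the innermost Loops or Buffer. A toy sketch of that traversal shape, with stand-in classes rather than the real Inductor IR:

```python
from dataclasses import dataclass
from typing import Any, Optional

# Stand-ins for the real Inductor IR classes, only to show the nesting.
@dataclass
class Loops:
    origin_node: Optional[str] = None

@dataclass
class StorageBox:
    data: Any

@dataclass
class TensorBox:
    data: Any

def tag_origin(result: Any, node_name: str) -> None:
    # Best effort, mirroring the restored logic's shape: only a direct
    # TensorBox -> StorageBox chain (a non-view) gets its inner node tagged.
    if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
        inner = result.data.data
        if isinstance(inner, Loops):
            inner.origin_node = node_name

box = TensorBox(StorageBox(Loops()))
tag_origin(box, "mm_1")
assert box.data.data.origin_node == "mm_1"
```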
```diff
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -9399,30 +9399,3 @@ def maybe_free_symbols(s: object) -> OrderedSet[Symbol]:
         return free_symbols(s)
     else:
         return OrderedSet()
-
-
-def assign_origin_node(result: Any, n: torch.fx.Node) -> None:
-    # This is not complete, but it doesn't have to be: origin_node
-    # tracking is best effort. The logic here critically relies on direct
-    # TensorBox -> StorageBox denoting a non-view; we don't bother trying
-    # to get views to work. Feel free to add any extra cases as needed.
-    #
-    # Note: we can't YOLO tree_map over this result, because if there are
-    # buffers or a view involved, we might not be able to validly assign
-    # the origin_node here.
-    if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
-        if isinstance(result.data.data, Loops):
-            result.data.data._post_init_setattr("origin_node", n)
-        elif isinstance(result.data.data, Buffer):
-            result.data.data._post_init_setattr("origin_node", n)
-            if isinstance(result.data.data, ComputedBuffer) and isinstance(
-                result.data.data.data, Loops
-            ):
-                result.data.data.data._post_init_setattr("origin_node", n)
-        # Not really multi-output, can straightforwardly recurse in
-        elif (
-            isinstance(result.data.data, MultiOutput)
-            and not result.data.data.indices
-        ):
-            if isinstance(result.data.data.inputs[0], Buffer):
-                result.data.data.inputs[0]._post_init_setattr("origin_node", n)
```
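A note on the _post_init_setattr calls throughout this function: Inductor IR nodes are, as far as this pattern suggests, frozen dataclasses, so late binding of origin_node has to bypass the frozen guard via object.__setattr__. A minimal sketch of that pattern with a toy class, not the real IRNode:

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass(frozen=True)
class FrozenNode:
    name: str
    origin_node: Optional[Any] = None

    def _post_init_setattr(self, attr: str, value: Any) -> None:
        # Normal assignment on a frozen dataclass raises FrozenInstanceError;
        # object.__setattr__ sidesteps that for deliberate post-init mutation.
        object.__setattr__(self, attr, value)

buf = FrozenNode("buf0")
buf._post_init_setattr("origin_node", "addmm_node")
assert buf.origin_node == "addmm_node"
```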
```diff
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -50,7 +50,6 @@ from .dependencies import Dep, MemoryDep, StarDep, WeakDep
 from .exc import GPUTooOldForTriton, TritonMissing
 from .fx_utils import count_flops_fx
 from .ir import (
-    assign_origin_node,
     get_device_type,
     GraphPartitionSignature,
     MultiOutput,
```
```diff
@@ -3164,16 +3163,12 @@ class Scheduler:
                 node.node.finalize_as_triton_caller(min_node_unfused)
                 continue
 
-            with ir.IRNode.current_origins(multi_node.origins):
-                out_tensorbox = min_node_unfused.output_node()
+            out_tensorbox = min_node_unfused.output_node()
             out_storage = out_tensorbox.data  # type: ignore[union-attr]
             assert isinstance(out_storage, ir.StorageBox)
             out_buffer = out_storage.data
             assert isinstance(out_buffer, ir.OperationBuffer)
 
-            if multi_node.origin_node:
-                assign_origin_node(out_tensorbox, multi_node.origin_node)
-
             out_buffer.layout = multi_node.layout
             replace_operation_buffer(multi_node, out_buffer)
             new_scheduler_node = self.create_scheduler_node(out_buffer)
```
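The removed `with ir.IRNode.current_origins(multi_node.origins)` block wrapped output_node() so that IR nodes created inside it inherit the multi-template node's origins. A rough sketch of that context-manager pattern, assuming (as the surrounding code suggests) a class-level origin set that nodes snapshot at construction:

```python
import contextlib

class IRNodeSketch:
    # class-level set that newly constructed nodes snapshot
    _current_origins: frozenset = frozenset()

    def __init__(self) -> None:
        self.origins = set(IRNodeSketch._current_origins)

    @classmethod
    @contextlib.contextmanager
    def current_origins(cls, origins):
        old = cls._current_origins
        cls._current_origins = old | frozenset(origins)
        try:
            yield
        finally:
            cls._current_origins = old

with IRNodeSketch.current_origins({"fx_node_mm"}):
    node = IRNodeSketch()  # created inside the block -> inherits origins
assert "fx_node_mm" in node.origins
```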