Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[reland] Transfer "stack_trace" in post_grad passes (#158752)
Summary: We transfer the stack trace in post_grad passes. We shouldn't add "stack_trace" to _COPY_META_FIELDS, because _COPY_META_FIELDS is used in proxy.py, where stack_trace is set explicitly. Since the stack trace is used by more and more debugging tools, we should also start testing it more rigorously. This PR starts by adding a first test verifying that the stack trace is preserved through post_grad_passes.

Test Plan:
```
buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing -- -r test_pattern_matcher_transfer_meta
buck run mode/dev-nosan fbcode//caffe2/test/inductor:auto_functionalize -- --rcaffe2/test/inductor:auto_functionalize_old
```

Rollback Plan:

Differential Revision: D78669729

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158752
Approved by: https://github.com/jingsh
Committed by PyTorch MergeBot · parent a155f742ad · commit 21c97bd565
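The functional change itself is small; the last hunk below shows it in context in the pattern-matcher's `_transfer_meta` helper. For orientation, a condensed sketch of the updated helper follows (the signature is simplified here and may differ from the in-tree one; the body mirrors the diff below):

```
import torch
import torch.fx


def _transfer_meta(new_meta: dict, old_node: torch.fx.Node) -> None:
    # Copy only the meta fields that proxy.py marks as transferable.
    # "stack_trace" is intentionally not in _COPY_META_FIELDS, since
    # proxy.py sets it explicitly, so it is copied as a separate step.
    new_meta.update(
        (k, v)
        for k, v in old_node.meta.items()
        if k in torch.fx.proxy._COPY_META_FIELDS
    )
    if "stack_trace" in old_node.meta:
        new_meta["stack_trace"] = old_node.meta["stack_trace"]
```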
@@ -21,6 +21,7 @@ from torch._dynamo.testing import (
    normalize_gm,
)
from torch._higher_order_ops.schema import find_hop_schema
from torch._inductor import config as inductor_config
from torch._inductor.pattern_matcher import (
    CallFunctionVarArgs,
    PatternMatcherPass,
@@ -619,6 +620,7 @@ class GraphModule(torch.nn.Module):
        self.assertEqual(ref, res)
        res.sum().backward()

    @inductor_config.patch("fx_graph_cache", False)
    def test_dropout_checks_joint_graph(self):
        # `dropout` tests that joint graph passes (not just the partitioner) are run
        # on the hop graphs. Inductor rng functionalization happens in the joint
@@ -675,9 +677,9 @@
        sin: "f32[8]" = torch.ops.aten.sin.default(primals_0)

        inductor_seeds_default: "i64[1]" = torch.ops.prims.inductor_seeds.default(1, device(type='cpu'))

        inductor_lookup_seed_default: "i64[]" = torch.ops.prims.inductor_lookup_seed.default(inductor_seeds_default, 0); inductor_seeds_default = None
        inductor_random_default: "f32[8]" = torch.ops.prims.inductor_random.default([8], inductor_lookup_seed_default, 'rand'); inductor_lookup_seed_default = None

        gt: "b8[8]" = torch.ops.aten.gt.Scalar(inductor_random_default, 0.5); inductor_random_default = None
        mul: "f32[8]" = torch.ops.aten.mul.Tensor(gt, sin); sin = None
        mul_1: "f32[8]" = torch.ops.aten.mul.Tensor(mul, 2.0); mul = None
@@ -690,6 +692,7 @@ class GraphModule(torch.nn.Module):
""",
        )

    @inductor_config.patch("fx_graph_cache", False)
    def test_dropout_checks_joint_graph_inference(self):
        # Checks that joint graph results in inductor seeds for just the inference graph
        @nested_compile_region
@@ -719,9 +722,9 @@ class <lambda>(torch.nn.Module):
    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[8]"):
            inductor_seeds_default: "i64[1]" = torch.ops.prims.inductor_seeds.default(1, device(type='cpu'))

            inductor_lookup_seed_default: "i64[]" = torch.ops.prims.inductor_lookup_seed.default(inductor_seeds_default, 0); inductor_seeds_default = None
            inductor_random_default: "f32[8]" = torch.ops.prims.inductor_random.default([8], inductor_lookup_seed_default, 'rand'); inductor_lookup_seed_default = None

            gt: "b8[8]" = torch.ops.aten.gt.Scalar(inductor_random_default, 0.5); inductor_random_default = None
            sin: "f32[8]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
            mul: "f32[8]" = torch.ops.aten.mul.Tensor(gt, sin); gt = sin = None
@@ -917,6 +920,7 @@ class GraphModule(torch.nn.Module):
""",
        )

    @inductor_config.patch("fx_graph_cache", False)
    def test_view_to_reshape(self):
        @nested_compile_region
        def gn(x):
@@ -185,9 +185,15 @@ class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
            post_grad_graphs,
            """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        # Custom comment for test
        foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = foo_default = None
        return ()""",  # noqa: B950
            ignore_comments=True,
        )

        # stack trace should be in post_grad_graph
        self.assertTrue(
            "code: torch.ops.mylib.foo(x, y, z, 2, n)" in post_grad_graphs,
        )

        eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
@@ -328,10 +334,16 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
            post_grad_graphs,
            """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        # Custom comment for test
        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \
arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
        return ()""",
            ignore_comments=True,
        )

        # stack trace should be in post_grad_graph
        self.assertTrue(
            "code: torch.ops.mylib.foo(x, y, z, 2, n)" in post_grad_graphs,
        )

        eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
@@ -9,12 +9,15 @@ import unittest
from pathlib import Path

import torch
from torch._dynamo.utils import detect_fake_mode
from torch._inductor import config
from torch._inductor.debug import (
    create_mapping_pre_post_grad_nodes,
    create_node_mapping_kernel_to_post_grad,
)
from torch._inductor.fx_passes.post_grad import post_grad_passes
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.virtualized import V
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.testing._internal.triton_utils import requires_cuda
@@ -427,5 +430,58 @@ class TestProvenanceTracingNodeMapping(TestCase):
        )


class TestProvenanceTracingNodeMeta(TestCase):
    def get_node_with_target(self, gm, target):
        """
        Return the first node in gm with the given target
        """
        return next(iter([node for node in gm.graph.nodes if node.target == target]))

    @requires_cuda  # test only works for cuda pattern matcher
    def test_pattern_matcher_transfer_meta(self):
        """
        Test that stack trace is transferred when a node is decomposed in post_grad_passes
        """

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 16)
                self.relu = torch.nn.ReLU()
                self.sigmoid = torch.nn.Sigmoid()

            def forward(self, x):
                x = self.fc1(x)
                x = self.relu(x)
                x = self.sigmoid(x)
                return x * 3

        x = torch.randn(8, 10).to("cuda")
        example_inputs = (x,)
        model = Model().to("cuda")

        # mimic the before_post_grad graph
        ep = torch.export.export(model, example_inputs).run_decompositions()
        gm = ep.module()

        # Set fake mode for V
        fake_inputs = [
            node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder"
        ]
        fake_mode = detect_fake_mode(fake_inputs)
        V.set_fake_mode(fake_mode)

        addmm_node = self.get_node_with_target(gm, torch.ops.aten.addmm.default)
        stack_trace = addmm_node.meta["stack_trace"]

        post_grad_passes(gm, True)  # for this test, is_inference doesn't matter

        mm_node = self.get_node_with_target(gm, torch.ops.aten.mm.default)
        add_node = self.get_node_with_target(gm, torch.ops.aten.add.Tensor)

        self.assertEqual(add_node.meta["stack_trace"], stack_trace)
        self.assertEqual(mm_node.meta["stack_trace"], stack_trace)


if __name__ == "__main__":
    run_tests()
@@ -143,6 +143,8 @@ def _transfer_meta(
        for k, v in old_node.meta.items()
        if k in torch.fx.proxy._COPY_META_FIELDS
    )
    if "stack_trace" in old_node.meta:
        new_meta["stack_trace"] = old_node.meta["stack_trace"]


class Match:
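With the two lines added above, nodes produced by pattern-matcher rewrites retain the source location of the node they replaced. As a rough illustration of how a debugging tool might consume the transferred field (a hypothetical helper, not part of this commit):

```
import torch


def dump_stack_traces(gm: torch.fx.GraphModule) -> None:
    # Print the innermost recorded source line for each node that carries a
    # stack trace; after this commit, nodes created by post_grad rewrites
    # keep the field from the node they replaced.
    for node in gm.graph.nodes:
        trace = node.meta.get("stack_trace")
        if trace:
            print(f"{node.name}: {trace.strip().splitlines()[-1]}")
```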