[dynamo] Support fx.traceback.annotate as decorator (#165805)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165805
Approved by: https://github.com/Lucaskabela, https://github.com/SherlockNoMad, https://github.com/yushangdi
committed by PyTorch MergeBot
parent 23417ae50f
commit d9f94e0d7d
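This change teaches Dynamo to trace fx.traceback.annotate when it is applied as a decorator, in addition to the already-supported context-manager form. A minimal usage sketch, adapted from the test added below (the metadata keys and module shape are simply the ones that test happens to use, not anything required by the API):

import torch
import torch.fx.traceback as fx_traceback

class Mod(torch.nn.Module):
    # New here: annotate applied as a decorator is now handled by Dynamo.
    @fx_traceback.annotate({"fdsp_bucket": 0})
    def sin(self, x):
        return torch.sin(x)

    def forward(self, x):
        # The context-manager form keeps working and composes with the
        # decorator: nodes produced by self.sin carry both annotations.
        with fx_traceback.annotate({"pp_stage": 0}):
            y = self.sin(x) - 2
        return y / 3

opt_m = torch.compile(Mod(), fullgraph=True)
opt_m(torch.randn(10))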
@@ -238,6 +238,56 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
 ('call_function', 'getitem_5', {'compile_inductor': 0})""",  # noqa: B950
         )
 
+    def test_as_decorator(self):
+        class Mod(torch.nn.Module):
+            @fx_traceback.annotate({"fdsp_bucket": 0})
+            def sin(self, x):
+                return torch.sin(x)
+
+            def forward(self, x):
+                with fx_traceback.annotate({"pp_stage": 0}):
+                    sin = self.sin(x)
+                    sub = sin - 2
+                    mul = sub * 2
+                div = mul / 3
+                return div
+
+        m = Mod()
+        backend = AotEagerAndRecordGraphs()
+        opt_m = torch.compile(m, backend=backend, fullgraph=True)
+        x = torch.randn(10, requires_grad=True)
+        m(x)
+        opt_m(x).sum().backward()
+
+        self.assertEqual(len(backend.fw_graphs), 1)
+        self.assertEqual(len(backend.bw_graphs), 1)
+
+        dynamo_metadata = fx_traceback._get_custom_metadata(backend.graphs[0])
+        fw_metadata = fx_traceback._get_custom_metadata(backend.fw_graphs[0])
+        bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0])
+        self.assertExpectedInline(
+            str(dynamo_metadata),
+            """\
+('placeholder', 'l_x_', {'pp_stage': 0, 'fdsp_bucket': 0})
+('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
+('call_function', 'sub', {'pp_stage': 0})
+('call_function', 'mul', {'pp_stage': 0})""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(fw_metadata),
+            """\
+('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
+('call_function', 'sub', {'pp_stage': 0})
+('call_function', 'mul', {'pp_stage': 0})""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(bw_metadata),
+            """\
+('call_function', 'mul_1', {'pp_stage': 0})
+('call_function', 'cos', {'pp_stage': 0, 'fdsp_bucket': 0})
+('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""",  # noqa: B950
+        )
+
 
 if __name__ == "__main__":
     run_tests()
@@ -126,6 +126,7 @@ supported_ctx_manager_classes = dict.fromkeys(
         torch.cpu.amp.autocast_mode.autocast,
         torch.cuda.amp.autocast_mode.autocast,
         torch.fx.traceback.annotate,
+        torch.fx.traceback.annotate.__wrapped__,  # type: ignore[attr-defined]
         # We'll let Dynamo inline into the contextlib part of these context
         # manager instances, all the way till it invokes the wrapped function
         # itself (at which point we wrap it back to special context manager
@@ -364,7 +365,10 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
             assert len(args) <= 1 and len(kwargs) == 0
             inf_mode = args[0].as_python_constant() if len(args) == 1 else True
             return InferenceModeVariable.create(tx, inf_mode)
-        elif self.value is torch.fx.traceback.annotate:
+        elif self.value in (
+            torch.fx.traceback.annotate,
+            torch.fx.traceback.annotate.__wrapped__,  # type: ignore[attr-defined]
+        ):
             assert len(args) <= 1 and len(kwargs) == 0
             return FxTracebackAnnotateVariable(
                 args[0].as_python_constant(), source=self.source
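A plausible reason both objects need to be matched, assuming annotate is implemented with contextlib.contextmanager (which the __wrapped__ attribute and the accompanying type: ignore suggest): functools.wraps inside contextmanager exposes the original generator function as __wrapped__, and the instance returned by annotate(...) is also a ContextDecorator, which is what makes the decorator form possible in the first place. The sketch below uses a hypothetical stand-in, not PyTorch's actual implementation, to show where __wrapped__ comes from:

import contextlib

@contextlib.contextmanager
def annotate(meta):
    # Hypothetical stand-in for torch.fx.traceback.annotate, used only to
    # illustrate the contextmanager machinery.
    print("enter", meta)
    try:
        yield
    finally:
        print("exit", meta)

# contextlib.contextmanager wraps the generator with functools.wraps, which
# stores the original generator function as __wrapped__ on the public callable.
assert annotate.__wrapped__ is not annotate

# The object returned by annotate(...) is a _GeneratorContextManager, which
# inherits from ContextDecorator, so it can also be applied as a decorator.
@annotate({"bucket": 0})
def f(x):
    return x + 1

f(1)  # prints enter/exit around the call

Since Dynamo inlines through that contextlib plumbing (per the comment in supported_ctx_manager_classes above), it can end up looking at either annotate itself or annotate.__wrapped__, which appears to be why the identity check is replaced with a membership test.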