Add torch dispatch mode to ProxyTensor tracing (#77174)

Uses a mode for ProxyTensor tracing so that it traces factory functions as well

cc @dhruvbird
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77174
Approved by: https://github.com/ezyang
This commit is contained in:
samdow
2022-05-19 19:53:57 +00:00
committed by PyTorch MergeBot
parent 327d313705
commit ba0ca0f591
2 changed files with 107 additions and 47 deletions

View File

@ -710,6 +710,30 @@ class TestFXExperimental(JitTestCase):
inp = torch.randn(3, requires_grad=True)
torch.testing.assert_close(traced_graph(inp), f(inp))
def test_mode_tracing_factory_function(self):
def f(x):
return x + torch.randn(x.shape)
traced = make_fx(f, trace_factory_functions=True)(torch.randn(3))
self.assertTrue(
any(
isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn'
for node in traced.graph.nodes
)
)
def test_mode_tracing_factory_function_default_behavior(self):
def f(x):
return x + torch.randn(x.shape)
traced = make_fx(f)(torch.randn(3)) # default behavior should not trace factory functions
self.assertFalse(
any(
isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn'
for node in traced.graph.nodes
)
)
def test_call_to_assert_with_msg(self):
class M(torch.nn.Module):
def forward(self, a, b):