mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add torch dispatch mode to ProxyTensor tracing (#77174)
Uses a mode for ProxyTensor tracing so that it traces factory functions as well cc @dhruvbird Pull Request resolved: https://github.com/pytorch/pytorch/pull/77174 Approved by: https://github.com/ezyang
This commit is contained in:
@ -710,6 +710,30 @@ class TestFXExperimental(JitTestCase):
|
||||
inp = torch.randn(3, requires_grad=True)
|
||||
torch.testing.assert_close(traced_graph(inp), f(inp))
|
||||
|
||||
def test_mode_tracing_factory_function(self):
|
||||
def f(x):
|
||||
return x + torch.randn(x.shape)
|
||||
|
||||
traced = make_fx(f, trace_factory_functions=True)(torch.randn(3))
|
||||
self.assertTrue(
|
||||
any(
|
||||
isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn'
|
||||
for node in traced.graph.nodes
|
||||
)
|
||||
)
|
||||
|
||||
def test_mode_tracing_factory_function_default_behavior(self):
|
||||
def f(x):
|
||||
return x + torch.randn(x.shape)
|
||||
|
||||
traced = make_fx(f)(torch.randn(3)) # default behavior should not trace factory functions
|
||||
self.assertFalse(
|
||||
any(
|
||||
isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn'
|
||||
for node in traced.graph.nodes
|
||||
)
|
||||
)
|
||||
|
||||
def test_call_to_assert_with_msg(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, a, b):
|
||||
|
||||
Reference in New Issue
Block a user