mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
feat(fx): make_fx
should be aware of functions wrapped with @fx.wrap
(#93273)
Fixes https://github.com/pytorch/pytorch/issues/89421 The strategy is to patch the given function wrapped with `@torch.fx.wrap` so that if a tensor tracer is active, we will `proxy_call` the function. `proxy_call` will also skip certain checks if the function to proxy call is not a torch op (checked with `isinstance(.., OpOverload)`. @IvanYashchuk @ezyang @Chillee Pull Request resolved: https://github.com/pytorch/pytorch/pull/93273 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
dd8662d5c8
commit
6a4bf3b71b
@ -31,6 +31,7 @@ from torch.fx.node import Target, Argument, _format_arg
|
||||
from torch.fx.passes import shape_prop
|
||||
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
||||
from torch.fx.experimental.rewriter import RewritingTracer
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.operator_schemas import get_signature_for_torch_op
|
||||
from copy import deepcopy
|
||||
from collections import namedtuple
|
||||
@ -477,6 +478,45 @@ class TestFX(JitTestCase):
|
||||
self.assertIn('wrapped_decorated_fn', m.code)
|
||||
self.assertEqual(m(1), 1)
|
||||
|
||||
@unittest.skipIf(sys.version_info >= (3, 11, 0), "FX currently does not have 3.11 support")
|
||||
def test_wrap_with_make_fx(self):
|
||||
def to_trace(y):
|
||||
return a_lifted_leaf((4, y), 3) * a_lifted_leaf((3, 4), 5) * a_lifted_leaf((y, y), y)
|
||||
|
||||
expected_code = """def forward(self, y_1):
|
||||
a_lifted_leaf = __main___a_lifted_leaf((4, y_1), 3)
|
||||
a_lifted_leaf_1 = __main___a_lifted_leaf((3, 4), 5)
|
||||
mul = torch.ops.aten.mul.Tensor(a_lifted_leaf, 12); a_lifted_leaf = None
|
||||
a_lifted_leaf_2 = __main___a_lifted_leaf((y_1, y_1), y_1); y_1 = None
|
||||
mul_1 = torch.ops.aten.mul.Tensor(mul, a_lifted_leaf_2); mul = a_lifted_leaf_2 = None
|
||||
return mul_1"""
|
||||
|
||||
m = make_fx(to_trace, tracing_mode="real")(torch.tensor([10]))
|
||||
self.assertIn('a_lifted_leaf', m.code)
|
||||
# aten.add.Tensor should be internal to `a_lifted_leaf` when some of the parameters are tensors.
|
||||
# However, it should not be traced as the function is marked as opaque.
|
||||
self.assertNotIn('aten.add.Tensor', m.code)
|
||||
self.assertExpectedInline(
|
||||
m.code.strip(),
|
||||
expected_code
|
||||
)
|
||||
|
||||
m = make_fx(to_trace, tracing_mode="fake")(torch.tensor([10]))
|
||||
self.assertIn('a_lifted_leaf', m.code)
|
||||
self.assertNotIn('aten.add.Tensor', m.code)
|
||||
self.assertExpectedInline(
|
||||
m.code.strip(),
|
||||
expected_code
|
||||
)
|
||||
|
||||
m = make_fx(to_trace, tracing_mode="symbolic")(torch.tensor([10]))
|
||||
self.assertIn('a_lifted_leaf', m.code)
|
||||
self.assertNotIn('aten.add.Tensor', m.code)
|
||||
self.assertExpectedInline(
|
||||
m.code.strip(),
|
||||
expected_code
|
||||
)
|
||||
|
||||
def test_graph_edit_with_proxy(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, a, b):
|
||||
|
Reference in New Issue
Block a user