feat(fx): make_fx should be aware of functions wrapped with @fx.wrap (#93273)

Fixes https://github.com/pytorch/pytorch/issues/89421

The strategy is to patch any function wrapped with `@torch.fx.wrap` so that, if a tensor tracer is active, we `proxy_call` the function instead of tracing into its body.
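
A minimal sketch of the patching idea (the helpers `_find_active_tensor_tracer` and `_proxy_call` below are illustrative stand-ins, not the actual `proxy_tensor` internals):

```python
import functools

# Illustrative stand-ins for proxy_tensor internals; the real names differ.
_ACTIVE_TRACER = None

def _find_active_tensor_tracer():
    return _ACTIVE_TRACER

def _proxy_call(tracer, fn, args, kwargs):
    # The real version records an FX call node on the tracer's graph.
    return tracer.create_proxy("call_function", fn, args, kwargs)

def _patch_wrapped_function(orig_fn):
    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        tracer = _find_active_tensor_tracer()
        if tracer is not None:
            # A tensor tracer is active: emit a single opaque call node
            # for orig_fn instead of tracing through its body.
            return _proxy_call(tracer, orig_fn, args, kwargs)
        return orig_fn(*args, **kwargs)  # no tracer: run eagerly
    return wrapped
```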

`proxy_call` also skips certain checks when the function being proxied is not a torch op (detected with `isinstance(..., OpOverload)`).
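
The rough shape of that guard (a sketch only; the real `proxy_call` does considerably more):

```python
from torch._ops import OpOverload

def _proxy_call_guard_sketch(func):
    # Schema-driven checks only make sense for real torch operators.
    if isinstance(func, OpOverload):
        schema = func._schema  # every OpOverload carries a schema
        # ... mutability/alias checks based on the schema go here ...
        return schema is not None
    # An @fx.wrap'd Python function has no OpOverload schema, so the
    # schema-based checks are skipped and the call is proxied as-is.
    return True
```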

@IvanYashchuk @ezyang @Chillee
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93273
Approved by: https://github.com/ezyang
jon-chuang
2023-02-02 01:57:49 +00:00
committed by PyTorch MergeBot
parent dd8662d5c8
commit 6a4bf3b71b
4 changed files with 99 additions and 26 deletions


@@ -31,6 +31,7 @@ from torch.fx.node import Target, Argument, _format_arg
 from torch.fx.passes import shape_prop
 from torch.fx.immutable_collections import immutable_dict, immutable_list
 from torch.fx.experimental.rewriter import RewritingTracer
+from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.operator_schemas import get_signature_for_torch_op
 from copy import deepcopy
 from collections import namedtuple
@@ -477,6 +478,45 @@ class TestFX(JitTestCase):
         self.assertIn('wrapped_decorated_fn', m.code)
         self.assertEqual(m(1), 1)
 
+    @unittest.skipIf(sys.version_info >= (3, 11, 0), "FX currently does not have 3.11 support")
+    def test_wrap_with_make_fx(self):
+        def to_trace(y):
+            return a_lifted_leaf((4, y), 3) * a_lifted_leaf((3, 4), 5) * a_lifted_leaf((y, y), y)
+
+        expected_code = """def forward(self, y_1):
+    a_lifted_leaf = __main___a_lifted_leaf((4, y_1), 3)
+    a_lifted_leaf_1 = __main___a_lifted_leaf((3, 4), 5)
+    mul = torch.ops.aten.mul.Tensor(a_lifted_leaf, 12);  a_lifted_leaf = None
+    a_lifted_leaf_2 = __main___a_lifted_leaf((y_1, y_1), y_1);  y_1 = None
+    mul_1 = torch.ops.aten.mul.Tensor(mul, a_lifted_leaf_2);  mul = a_lifted_leaf_2 = None
+    return mul_1"""
+
+        m = make_fx(to_trace, tracing_mode="real")(torch.tensor([10]))
+        self.assertIn('a_lifted_leaf', m.code)
+        # aten.add.Tensor should be internal to `a_lifted_leaf` when some of the parameters are tensors.
+        # However, it should not be traced as the function is marked as opaque.
+        self.assertNotIn('aten.add.Tensor', m.code)
+        self.assertExpectedInline(
+            m.code.strip(),
+            expected_code
+        )
+
+        m = make_fx(to_trace, tracing_mode="fake")(torch.tensor([10]))
+        self.assertIn('a_lifted_leaf', m.code)
+        self.assertNotIn('aten.add.Tensor', m.code)
+        self.assertExpectedInline(
+            m.code.strip(),
+            expected_code
+        )
+
+        m = make_fx(to_trace, tracing_mode="symbolic")(torch.tensor([10]))
+        self.assertIn('a_lifted_leaf', m.code)
+        self.assertNotIn('aten.add.Tensor', m.code)
+        self.assertExpectedInline(
+            m.code.strip(),
+            expected_code
+        )
+
     def test_graph_edit_with_proxy(self):
         class M(torch.nn.Module):
             def forward(self, a, b):
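
For reference, a minimal end-to-end sketch of the behavior this change enables (standalone, outside the test suite):

```python
import torch
import torch.fx
from torch.fx.experimental.proxy_tensor import make_fx

@torch.fx.wrap
def a_lifted_leaf(a, b):
    # Marked opaque: make_fx records a single call node for this
    # function instead of tracing the additions inside it.
    return a[0] + a[1] + b

def to_trace(y):
    return a_lifted_leaf((4, y), 3) * a_lifted_leaf((y, y), y)

gm = make_fx(to_trace, tracing_mode="real")(torch.tensor([10]))
print(gm.code)  # shows a_lifted_leaf call nodes, no aten.add.Tensor
```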