basic SymInt test for functionalization (#80418)

`expand` is one of a handful of ops with SymInt support today, so this PR gives a basic test that shows functionalization properly mapping `expand.SymInt` -> `expand_copy.SymInt`. I added the logic to handle this properly in https://github.com/pytorch/pytorch/pull/80251, but didn't add a test for it. (see the [code](https://github.com/pytorch/pytorch/pull/80251/files#diff-da7d91d9e59774e3ee8d120a0f97e52058b73125fd7edd55b5c2e71d4ce5629dR330))

I want to add a more comprehensive test that also shows something more E2E (using `PySymInt`'s to avoid baking in shapes, running functionalization, and fx-tracing the output to show that functionalization ran properly), but I think it's currently blocked on some other work.

At least today, `FakeSymbolicTensor` doesn't play well with `make_fx` (but @Chillee mentioned - should it?)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80418
Approved by: https://github.com/ezyang, https://github.com/albanD
This commit is contained in:
Brian Hirsh
2022-07-11 15:11:58 -07:00
committed by PyTorch MergeBot
parent f84b30f790
commit f2dcb11bac

View File

@ -658,6 +658,23 @@ def forward(self, a_1):
return add_tensor
""")
def test_expand_symint(self):
# Once some existing SymInt bugs are ironed out, we should update
# this test to plumb FakeSymbolicTensors through it
def f(x):
return x.expand(x.size(0), x.size(1))
self.assert_functionalization(f, torch.ones(2, 2))
logs = self.get_logs(f, torch.ones(2, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
expand_copy_sym_int = torch.ops.aten.expand_copy.SymInt(a_1, [2, 2]); a_1 = None
return expand_copy_sym_int
""")
def test_fill_(self):
def f(x):
y = x + x