mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
basic SymInt test for functionalization (#80418)
`expand` is one of a handful of ops with SymInt support today, so this PR gives a basic test that shows functionalization properly mapping `expand.SymInt` -> `expand_copy.SymInt`. I added the logic to handle this properly in https://github.com/pytorch/pytorch/pull/80251, but didn't add a test for it. (see the [code](https://github.com/pytorch/pytorch/pull/80251/files#diff-da7d91d9e59774e3ee8d120a0f97e52058b73125fd7edd55b5c2e71d4ce5629dR330)) I want to add a more comprehensive test that also shows something more E2E (using `PySymInt`'s to avoid baking in shapes, running functionalization, and fx-tracing the output to show that functionalization ran properly), but I think it's currently blocked on some other work. At least today, `FakeSymbolicTensor` doesn't play well with `make_fx` (but @Chillee mentioned - should it?) Pull Request resolved: https://github.com/pytorch/pytorch/pull/80418 Approved by: https://github.com/ezyang, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
f84b30f790
commit
f2dcb11bac
@ -658,6 +658,23 @@ def forward(self, a_1):
|
||||
return add_tensor
|
||||
""")
|
||||
|
||||
def test_expand_symint(self):
|
||||
# Once some existing SymInt bugs are ironed out, we should update
|
||||
# this test to plumb FakeSymbolicTensors through it
|
||||
def f(x):
|
||||
return x.expand(x.size(0), x.size(1))
|
||||
|
||||
self.assert_functionalization(f, torch.ones(2, 2))
|
||||
logs = self.get_logs(f, torch.ones(2, 2))
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
expand_copy_sym_int = torch.ops.aten.expand_copy.SymInt(a_1, [2, 2]); a_1 = None
|
||||
return expand_copy_sym_int
|
||||
""")
|
||||
|
||||
def test_fill_(self):
|
||||
def f(x):
|
||||
y = x + x
|
||||
|
Reference in New Issue
Block a user