Refactor Symint Deduping to separate pass (#118938)

Previously, SymInt deduping was done during proxy tracing, which made it more difficult to reason about. This refactors the deduping into a separate pass.

We only dedupe SymInts that are resolvable from input SymInt nodes, so as to avoid inducing a dependency on the backward in the forward.
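
A rough sketch of what such a pass can look like (hedged: this is not the code added by this PR; the helper name `dedupe_input_symints` and the keying on `(target, args, kwargs)` are assumptions for illustration):

```python
# Illustrative sketch only -- not the pass added in this PR.
# Dedupe SymInt-producing call_function nodes (e.g. aten.sym_size.int)
# whose arguments are all graph placeholders, so a rewritten use can
# never start depending on a node that only exists in another graph
# (e.g. the backward).
import torch.fx as fx


def dedupe_input_symints(gm: fx.GraphModule) -> fx.GraphModule:
    placeholders = {n for n in gm.graph.nodes if n.op == "placeholder"}
    canonical = {}  # (target, args, kwargs) -> first node that computed it

    for node in list(gm.graph.nodes):
        if node.op != "call_function":
            continue
        # Only touch nodes resolvable directly from graph inputs.
        if not all(
            (not isinstance(a, fx.Node)) or a in placeholders for a in node.args
        ):
            continue

        key = (node.target, node.args, tuple(sorted(node.kwargs.items())))
        if key in canonical:
            node.replace_all_uses_with(canonical[key])  # reroute duplicate uses
            gm.graph.erase_node(node)                   # drop the redundant node
        else:
            canonical[key] = node

    gm.graph.lint()
    gm.recompile()
    return gm
```

Restricting the rewrite to nodes whose arguments are placeholders is what keeps the forward graph from picking up a reference to a SymInt node that is only produced elsewhere.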

Potential fix for: https://github.com/pytorch/pytorch/issues/118224

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118938
Approved by: https://github.com/ezyang
Author: Elias Ellison
Date: 2024-02-06 01:50:16 +00:00
Committed by: PyTorch MergeBot
Parent: dea15c9fdc
Commit: d0ca849fdf

3 changed files with 121 additions and 112 deletions


@@ -1037,6 +1037,25 @@ def forward(self, s0_1, s1_1, x_1, y_1):
    empty = torch.ops.aten.empty.memory_format([s0_1], device = device(type='cpu'), pin_memory = False)
    return ((s0_1, s1_1), empty)""")

    def test_non_deduped_shape(self):
        def f(x, y):
            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])

        x = torch.empty(3, 1)
        y = torch.empty(5)
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
        shape_env = ShapeEnv()
        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
            x = fake_mode.from_tensor(x)
            y = fake_mode.from_tensor(y)
            r = str(make_fx(f, tracing_mode="real")(x, y).code).strip()
            self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
    empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None
    return ((sym_size_int, sym_size_int_1), empty)""")

    def test_unary(self):
        def f(x):