mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
Refactor Symint Deduping to separate pass (#118938)
Previously Symint Deduping was done during proxy tracing which made it more difficult to reason about. This refactors the deduping to a separate pass. We only dedupe symints which are resolvable from input symint nodes so as to avoid inducing a dependency on the backward in the forward. potential fix for : https://github.com/pytorch/pytorch/issues/118224 Pull Request resolved: https://github.com/pytorch/pytorch/pull/118938 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
dea15c9fdc
commit
d0ca849fdf
@ -1037,6 +1037,25 @@ def forward(self, s0_1, s1_1, x_1, y_1):
|
||||
empty = torch.ops.aten.empty.memory_format([s0_1], device = device(type='cpu'), pin_memory = False)
|
||||
return ((s0_1, s1_1), empty)""")
|
||||
|
||||
def test_non_deduped_shape(self):
|
||||
def f(x, y):
|
||||
return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])
|
||||
|
||||
x = torch.empty(3, 1)
|
||||
y = torch.empty(5)
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
shape_env = ShapeEnv()
|
||||
|
||||
with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
|
||||
x = fake_mode.from_tensor(x)
|
||||
y = fake_mode.from_tensor(y)
|
||||
r = str(make_fx(f, tracing_mode="real")(x, y).code).strip()
|
||||
self.assertExpectedInline(r, """\
|
||||
def forward(self, x_1, y_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
|
||||
empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None
|
||||
return ((sym_size_int, sym_size_int_1), empty)""")
|
||||
|
||||
def test_unary(self):
|
||||
def f(x):
|
||||
|
||||
Reference in New Issue
Block a user