Use OpOverload instead of OpOverloadPacket for size/stride/etc slots (#112119)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112119
Approved by: https://github.com/yanboliang
committed by PyTorch MergeBot
parent ab20bab729
commit dd24e92949
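
For context (not part of the commit message): torch.ops.aten.sym_size is an OpOverloadPacket that bundles every overload of aten::sym_size, while torch.ops.aten.sym_size.int is the single OpOverload with the fixed schema sym_size.int(Tensor self, int dim) -> SymInt. A minimal sketch, assuming a PyTorch build that includes this change, of how the updated expected graphs below can be reproduced with make_fx:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a):
    # Under symbolic tracing a.shape[0] is a SymInt, so the size query is
    # recorded as a node in the traced graph.
    return a * a.shape[0]

g = make_fx(f, tracing_mode="symbolic")(torch.empty(4))
# With this change the printed graph calls the specific OpOverload
# torch.ops.aten.sym_size.int(a_1, 0) instead of the OpOverloadPacket
# torch.ops.aten.sym_size(a_1, 0).
print(g.code)
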
@@ -1,6 +1,6 @@
 # Owner(s): ["module: ProxyTensor"]
 
-from torch.testing._internal.common_utils import TestCase, run_tests, xfail_inherited_tests
+from torch.testing._internal.common_utils import TestCase, run_tests
 import torch
 import unittest
 import warnings
@@ -747,9 +747,6 @@ class TestGenericProxyTensorFake(TestGenericProxyTensor):
     tracing_mode = "fake"
 
 
-@xfail_inherited_tests([
-    "test_make_fx_overloads",
-])
 class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
     tracing_mode = "symbolic"
 
@@ -933,8 +930,8 @@ class TestSymbolicTracing(TestCase):
         r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
         self.assertExpectedInline(r, """\
 def forward(self, x_1, y_1):
-    sym_size = torch.ops.aten.sym_size(y_1, 0); y_1 = None
-    resize_ = torch.ops.aten.resize_.default(x_1, [sym_size]); x_1 = sym_size = None
+    sym_size_int = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None
+    resize_ = torch.ops.aten.resize_.default(x_1, [sym_size_int]); x_1 = sym_size_int = None
     return None""")
 
 
@@ -1059,8 +1056,8 @@ def forward(self, a_1, b_1):
         r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
         self.assertExpectedInline(r, """\
 def forward(self, a_1):
-    sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None
-    mul = sym_size * 2; sym_size = None
+    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0); a_1 = None
+    mul = sym_size_int * 2; sym_size_int = None
     empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None
     return empty""")
 
@@ -1119,8 +1116,8 @@ def forward(self, a_1):
         self.assertExpectedInline(
             r, """\
 def forward(self, x_1):
-    sym_size = torch.ops.aten.sym_size(x_1, 0)
-    scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); sym_size = None
+    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
+    scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); sym_size_int = None
     select = torch.ops.aten.select.int(x_1, 0, 0)
     copy_ = torch.ops.aten.copy_.default(select, scalar_tensor); select = scalar_tensor = None
     return x_1""" # noqa: B950
@@ -1165,21 +1162,21 @@ def forward(self, crop_camera_1, mask_1):
     select = torch.ops.aten.select.int(eye, 0, 0)
     select_1 = torch.ops.aten.select.int(select, 0, 0); select = None
     copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy); select_1 = lift_fresh_copy = None
-    sym_size = torch.ops.aten.sym_size(index, 0)
-    expand = torch.ops.aten.expand.default(eye, [sym_size, 3, 3])
-    view = torch.ops.aten.view.default(expand, [sym_size, 3, 3]); expand = None
-    sym_size_1 = torch.ops.aten.sym_size(crop_camera_1, 1)
-    sym_size_2 = torch.ops.aten.sym_size(crop_camera_1, 2)
-    expand_1 = torch.ops.aten.expand.default(index, [sym_size, sym_size_1, sym_size_2]); index = None
-    view_1 = torch.ops.aten.view.default(expand_1, [sym_size, sym_size_1, sym_size_2]); expand_1 = sym_size_1 = sym_size_2 = None
+    sym_size_int = torch.ops.aten.sym_size.int(index, 0)
+    expand = torch.ops.aten.expand.default(eye, [sym_size_int, 3, 3])
+    view = torch.ops.aten.view.default(expand, [sym_size_int, 3, 3]); expand = None
+    sym_size_int_1 = torch.ops.aten.sym_size.int(crop_camera_1, 1)
+    sym_size_int_2 = torch.ops.aten.sym_size.int(crop_camera_1, 2)
+    expand_1 = torch.ops.aten.expand.default(index, [sym_size_int, sym_size_int_1, sym_size_int_2]); index = None
+    view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None
     bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
-    view_2 = torch.ops.aten.view.default(bmm, [sym_size, 3, 3]); bmm = None
-    mul = sym_size * 3
+    view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None
+    mul = sym_size_int * 3
     view_3 = torch.ops.aten.view.default(view_2, [mul, 3]); view_2 = mul = None
     mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None
-    view_4 = torch.ops.aten.view.default(mm, [sym_size, 3, 3]); mm = sym_size = None
+    view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None
     index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4); crop_camera_1 = mask_1 = view_4 = None
-    return None""")
+    return None""") # noqa: B950
 
     def test_unbacked_slice(self):
         def f(x, m):
@@ -1241,8 +1238,8 @@ def forward(self, images_1, handedness_1, valid_1):
         r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip()
         self.assertExpectedInline(r, """\
 def forward(self, a_1):
-    sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None
-    neg = -sym_size; sym_size = None
+    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0); a_1 = None
+    neg = -sym_size_int; sym_size_int = None
     add = neg + 10; neg = None
     empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False); add = None
     return empty""")
@@ -1304,8 +1301,8 @@ def forward(self, lengths_1, values_1):
         r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
         self.assertExpectedInline(r, """\
 def forward(self, a_1):
-    sym_size = torch.ops.aten.sym_size(a_1, 0)
-    pow_1 = sym_size ** 0.5; sym_size = None
+    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
+    pow_1 = sym_size_int ** 0.5; sym_size_int = None
     div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None
     return div""")
 
@@ -1317,15 +1314,15 @@ def forward(self, a_1):
         r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
         self.assertExpectedInline(r, """\
 def forward(self, a_1):
-    sym_size = torch.ops.aten.sym_size(a_1, 0)
-    div = torch.ops.aten.div.Tensor(a_1, sym_size); a_1 = sym_size = None
+    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
+    div = torch.ops.aten.div.Tensor(a_1, sym_size_int); a_1 = sym_size_int = None
     return div""")
 
         r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip()
         self.assertExpectedInline(r, """\
 def forward(self, a_1):
-    sym_size = torch.ops.aten.sym_size(a_1, 0)
-    sym_float = torch.sym_float(sym_size); sym_size = None
+    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
+    sym_float = torch.sym_float(sym_size_int); sym_size_int = None
     div = torch.ops.prims.div.default(a_1, sym_float); a_1 = sym_float = None
     return div""")
 
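
For reference (not from the commit): the packet/overload distinction can be inspected directly on the op objects; a small sketch, assuming a recent PyTorch build:

import torch

packet = torch.ops.aten.sym_size        # OpOverloadPacket: all overloads of aten::sym_size
overload = torch.ops.aten.sym_size.int  # OpOverload: sym_size.int(Tensor self, int dim) -> SymInt

print(type(packet))    # <class 'torch._ops.OpOverloadPacket'>
print(type(overload))  # <class 'torch._ops.OpOverload'>

x = torch.empty(2, 3)
# The resolved overload is directly callable, exactly as the traced graphs call it.
print(overload(x, 1))  # prints 3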