mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Natively support int truncation, don't guard on positive/negative (#122827)
This doesn't entirely fix the original problem that prompted this, but it seems to just be getting stuck in export constraint formatting now which seems like progress to me. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/122827 Approved by: https://github.com/avikchaudhuri
This commit is contained in:
committed by
PyTorch MergeBot
parent
c83900887f
commit
efa36ef092
@ -1086,6 +1086,36 @@ class TestExport(TestCase):
|
||||
inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
|
||||
self._test_export_same_as_eager(list_tensor_map, inps)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_crop_like(self):
|
||||
# https://fb.workplace.com/groups/1405155842844877/posts/8195050017188725/
|
||||
|
||||
# Minimal crop code copied from https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional
|
||||
class CropLike(torch.nn.Module):
|
||||
def forward(self, image, crop_height, crop_width):
|
||||
c, image_height, image_width = image.shape
|
||||
crop_top = int(round((image_height - crop_height) / 2.0))
|
||||
crop_left = int(round((image_width - crop_width) / 2.0))
|
||||
return image[
|
||||
...,
|
||||
crop_top : crop_top + crop_height,
|
||||
crop_left : crop_left + crop_width,
|
||||
]
|
||||
|
||||
crop = CropLike()
|
||||
imagew = Dim("width")
|
||||
imageh = Dim("height")
|
||||
dynamic_dims = {
|
||||
"image": {0: None, 1: imageh, 2: imagew},
|
||||
"crop_height": None,
|
||||
"crop_width": None,
|
||||
}
|
||||
args = (torch.rand(3, 512, 512), 150, 150)
|
||||
ecrop = export(crop, args=args, dynamic_shapes=dynamic_dims)
|
||||
|
||||
args = (torch.rand(3, 700, 700), 150, 150)
|
||||
self.assertEqual(ecrop.module()(*args), ecrop(*args))
|
||||
|
||||
def test_export_func_with_kwargs(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, arg1, arg2, kw1, kw2):
|
||||
|
@ -404,13 +404,13 @@ class TestPySymInt(TestCase):
|
||||
r = sym_int(a1 / 2)
|
||||
self.assertEqual(guard_int(r), 3)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(floor(s1/2), 3)""")
|
||||
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Trunc(s1/2), 3)""")
|
||||
|
||||
a3 = create_symint(shape_env, 3)
|
||||
r = sym_int(2.0 * torch.sym_float(a3))
|
||||
self.assertEqual(guard_int(r), 6)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(str(shape_env.guards[2][0]), """Eq(2*s2, 6)""")
|
||||
self.assertExpectedInline(str(shape_env.guards[2][0]), """Eq(Trunc(2.0*s2), 6)""")
|
||||
|
||||
def test_sym_sqrt(self):
|
||||
shape_env = ShapeEnv()
|
||||
@ -432,6 +432,18 @@ class TestPySymInt(TestCase):
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
|
||||
|
||||
def test_sym_trunc(self):
|
||||
shape_env = ShapeEnv()
|
||||
a0 = create_symint(shape_env, 5)
|
||||
r = math.trunc(a0 / 2)
|
||||
self.assertEqual(r, 2)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Trunc(s0/2), 2)""")
|
||||
r = torch.sym_int(torch.sym_sqrt(a0))
|
||||
self.assertEqual(r, 2)
|
||||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Trunc(OpaqueUnaryFn_sqrt(s0)), 2)""")
|
||||
|
||||
def test_sym_ceil(self):
|
||||
shape_env = ShapeEnv()
|
||||
a0 = create_symint(shape_env, 5)
|
||||
|
@ -1898,7 +1898,6 @@ symbolic_tensor_failures = {
|
||||
xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
|
||||
xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
|
||||
xfail('nn.functional.fractional_max_pool2d', ''), # argument 'size' must be tuple of ints, but found element of t...
|
||||
xfail('nn.functional.fractional_max_pool3d', ''), # argument 'size' must be tuple of ints, but found element of t...
|
||||
xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
|
||||
xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
|
||||
|
@ -339,6 +339,9 @@ class SymFloat:
|
||||
def __ge__(self, other) -> builtins.bool:
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __trunc__(self):
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __sym_max__(self, other):
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
@ -465,7 +468,7 @@ def sym_int(a):
|
||||
if isinstance(a, SymInt):
|
||||
return a
|
||||
elif isinstance(a, SymFloat):
|
||||
return math.floor(a) if a >= 0 else math.ceil(a) # type: ignore[arg-type, call-overload]
|
||||
return math.trunc(a)
|
||||
return py_int(a) # type: ignore[operator]
|
||||
|
||||
def sym_max(a, b):
|
||||
|
@ -449,6 +449,10 @@ class PythonPrinter(ExprPrinter):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.floor({self._print(expr.args[0])})"
|
||||
|
||||
def _print_Trunc(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.trunc({self._print(expr.args[0])})"
|
||||
|
||||
def _print_ceiling(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.ceil({self._print(expr.args[0])})"
|
||||
|
@ -562,6 +562,11 @@ class CppPrinter(ExprPrinter):
|
||||
r = f"std::floor({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_Trunc(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::trunc({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_Pow(self, expr):
|
||||
# Uses float constants to perform FP div
|
||||
base, exp = expr.args
|
||||
|
@ -305,6 +305,12 @@ class TritonPrinter(PythonPrinter):
|
||||
f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
|
||||
)
|
||||
|
||||
def _print_Trunc(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return (
|
||||
f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
|
||||
)
|
||||
|
||||
def _print_ceiling(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
|
||||
|
@ -233,6 +233,9 @@ class SymNode:
|
||||
def round(self, ndigits=None) -> "SymNode":
|
||||
return self._round(ndigits) # type: ignore[attr-defined]
|
||||
|
||||
def trunc(self) -> "SymNode":
|
||||
return self._trunc() # type: ignore[attr-defined]
|
||||
|
||||
def add(self, other) -> "SymNode":
|
||||
return self._add(other) # type: ignore[attr-defined]
|
||||
|
||||
@ -454,6 +457,7 @@ METHOD_TO_OPERATOR = {
|
||||
"ceil": math.ceil,
|
||||
"eq": operator.eq,
|
||||
"floor": math.floor,
|
||||
"trunc": math.trunc,
|
||||
"floordiv": operator.floordiv,
|
||||
"ge": operator.ge,
|
||||
"gt": operator.gt,
|
||||
@ -486,6 +490,7 @@ unary_magic_methods = {
|
||||
"neg",
|
||||
"sym_not",
|
||||
"pos",
|
||||
"trunc",
|
||||
}
|
||||
|
||||
|
||||
@ -548,7 +553,7 @@ for name in math_op_names:
|
||||
always_float_magic_methods.add(sym_name)
|
||||
|
||||
|
||||
always_int_magic_methods = {"ceil", "floor"}
|
||||
always_int_magic_methods = {"ceil", "floor", "trunc"}
|
||||
always_bool_magic_methods = {
|
||||
"eq",
|
||||
"ne",
|
||||
@ -653,6 +658,12 @@ def _sympy_floor(a):
|
||||
return _floor_ceil_helper(a, sympy.floor)
|
||||
|
||||
|
||||
def _sympy_trunc(a):
|
||||
from torch.utils._sympy.functions import Trunc
|
||||
|
||||
return Trunc(a)
|
||||
|
||||
|
||||
def _sympy_ceil(a):
|
||||
import sympy
|
||||
|
||||
@ -774,6 +785,7 @@ magic_methods = {
|
||||
"le": _sympy_le,
|
||||
"ge": _sympy_ge,
|
||||
"floor": _sympy_floor,
|
||||
"trunc": _sympy_trunc,
|
||||
"sym_float": _sympy_sym_float,
|
||||
"ceil": _sympy_ceil,
|
||||
"neg": operator.neg,
|
||||
|
@ -1717,7 +1717,7 @@ class DimConstraints:
|
||||
elif left.isdigit():
|
||||
relation_with_digit(right, flip(op), int(left))
|
||||
else:
|
||||
assert op == "=="
|
||||
assert op == "==", t
|
||||
results[left]["eq"] = sympy.sympify(right)
|
||||
|
||||
buf = ""
|
||||
|
@ -328,6 +328,17 @@ class IsNonOverlappingAndDenseIndicator(sympy.Function):
|
||||
return None
|
||||
|
||||
|
||||
class Trunc(sympy.Function):
|
||||
is_integer = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, number):
|
||||
if number.is_integer:
|
||||
return number
|
||||
elif isinstance(number, sympy.Number):
|
||||
return sympy.Integer(math.trunc(float(number)))
|
||||
|
||||
|
||||
class Round(sympy.Function):
|
||||
is_integer = True
|
||||
|
||||
|
@ -24,6 +24,7 @@ from .functions import (
|
||||
Round,
|
||||
RoundDecimal,
|
||||
TrueDiv,
|
||||
Trunc,
|
||||
Where,
|
||||
)
|
||||
|
||||
@ -51,6 +52,7 @@ def handlers():
|
||||
TrueDiv: "truediv",
|
||||
FloorDiv: "floordiv",
|
||||
CleanDiv: "div",
|
||||
Trunc: "trunc",
|
||||
Where: "where",
|
||||
sympy.Add: "add",
|
||||
sympy.Mul: "mul",
|
||||
|
@ -745,6 +745,13 @@ class SymPyValueRangeAnalysis:
|
||||
def atan(x):
|
||||
return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan)
|
||||
|
||||
@staticmethod
|
||||
def trunc(x):
|
||||
def trunc(x):
|
||||
return sympy.Integer(x) if x.is_finite else x
|
||||
|
||||
return ValueRanges.increasing_map(x, trunc)
|
||||
|
||||
|
||||
class ValueRangeAnalysis(SymPyValueRangeAnalysis):
|
||||
def __init__(self):
|
||||
@ -829,10 +836,7 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
|
||||
if x == ValueRanges.unknown():
|
||||
return x
|
||||
|
||||
def trunc(x):
|
||||
return sympy.Integer(x) if x.is_finite else x
|
||||
|
||||
return ValueRanges.increasing_map(x, trunc)
|
||||
return cls.trunc(x)
|
||||
|
||||
@classmethod
|
||||
def sub(cls, a, b):
|
||||
|
Reference in New Issue
Block a user