Natively support int truncation, don't guard on positive/negative (#122827)

This doesn't entirely fix the original problem that prompted this, but
it seems to just be getting stuck in export constraint formatting now
which seems like progress to me.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122827
Approved by: https://github.com/avikchaudhuri
This commit is contained in:
Edward Z. Yang
2024-04-10 19:31:01 -04:00
committed by PyTorch MergeBot
parent c83900887f
commit efa36ef092
12 changed files with 98 additions and 10 deletions

View File

@ -1086,6 +1086,36 @@ class TestExport(TestCase):
inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
self._test_export_same_as_eager(list_tensor_map, inps)
@unittest.expectedFailure
def test_crop_like(self):
# https://fb.workplace.com/groups/1405155842844877/posts/8195050017188725/
# Minimal crop code copied from https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional
class CropLike(torch.nn.Module):
def forward(self, image, crop_height, crop_width):
c, image_height, image_width = image.shape
crop_top = int(round((image_height - crop_height) / 2.0))
crop_left = int(round((image_width - crop_width) / 2.0))
return image[
...,
crop_top : crop_top + crop_height,
crop_left : crop_left + crop_width,
]
crop = CropLike()
imagew = Dim("width")
imageh = Dim("height")
dynamic_dims = {
"image": {0: None, 1: imageh, 2: imagew},
"crop_height": None,
"crop_width": None,
}
args = (torch.rand(3, 512, 512), 150, 150)
ecrop = export(crop, args=args, dynamic_shapes=dynamic_dims)
args = (torch.rand(3, 700, 700), 150, 150)
self.assertEqual(ecrop.module()(*args), ecrop(*args))
def test_export_func_with_kwargs(self):
class Module(torch.nn.Module):
def forward(self, arg1, arg2, kw1, kw2):

View File

@ -404,13 +404,13 @@ class TestPySymInt(TestCase):
r = sym_int(a1 / 2)
self.assertEqual(guard_int(r), 3)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(floor(s1/2), 3)""")
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Trunc(s1/2), 3)""")
a3 = create_symint(shape_env, 3)
r = sym_int(2.0 * torch.sym_float(a3))
self.assertEqual(guard_int(r), 6)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[2][0]), """Eq(2*s2, 6)""")
self.assertExpectedInline(str(shape_env.guards[2][0]), """Eq(Trunc(2.0*s2), 6)""")
def test_sym_sqrt(self):
shape_env = ShapeEnv()
@ -432,6 +432,18 @@ class TestPySymInt(TestCase):
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
def test_sym_trunc(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5)
r = math.trunc(a0 / 2)
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Trunc(s0/2), 2)""")
r = torch.sym_int(torch.sym_sqrt(a0))
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Trunc(OpaqueUnaryFn_sqrt(s0)), 2)""")
def test_sym_ceil(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5)

View File

@ -1898,7 +1898,6 @@ symbolic_tensor_failures = {
xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
xfail('nn.functional.fractional_max_pool2d', ''), # argument 'size' must be tuple of ints, but found element of t...
xfail('nn.functional.fractional_max_pool3d', ''), # argument 'size' must be tuple of ints, but found element of t...
xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition

View File

@ -339,6 +339,9 @@ class SymFloat:
def __ge__(self, other) -> builtins.bool:
raise AssertionError("type stub not overridden")
def __trunc__(self):
raise AssertionError("type stub not overridden")
def __sym_max__(self, other):
raise AssertionError("type stub not overridden")
@ -465,7 +468,7 @@ def sym_int(a):
if isinstance(a, SymInt):
return a
elif isinstance(a, SymFloat):
return math.floor(a) if a >= 0 else math.ceil(a) # type: ignore[arg-type, call-overload]
return math.trunc(a)
return py_int(a) # type: ignore[operator]
def sym_max(a, b):

View File

@ -449,6 +449,10 @@ class PythonPrinter(ExprPrinter):
assert len(expr.args) == 1
return f"math.floor({self._print(expr.args[0])})"
def _print_Trunc(self, expr):
assert len(expr.args) == 1
return f"math.trunc({self._print(expr.args[0])})"
def _print_ceiling(self, expr):
assert len(expr.args) == 1
return f"math.ceil({self._print(expr.args[0])})"

View File

@ -562,6 +562,11 @@ class CppPrinter(ExprPrinter):
r = f"std::floor({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_Trunc(self, expr):
assert len(expr.args) == 1
r = f"std::trunc({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_Pow(self, expr):
# Uses float constants to perform FP div
base, exp = expr.args

View File

@ -305,6 +305,12 @@ class TritonPrinter(PythonPrinter):
f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
)
def _print_Trunc(self, expr):
assert len(expr.args) == 1
return (
f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
)
def _print_ceiling(self, expr):
assert len(expr.args) == 1
return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"

View File

@ -233,6 +233,9 @@ class SymNode:
def round(self, ndigits=None) -> "SymNode":
return self._round(ndigits) # type: ignore[attr-defined]
def trunc(self) -> "SymNode":
return self._trunc() # type: ignore[attr-defined]
def add(self, other) -> "SymNode":
return self._add(other) # type: ignore[attr-defined]
@ -454,6 +457,7 @@ METHOD_TO_OPERATOR = {
"ceil": math.ceil,
"eq": operator.eq,
"floor": math.floor,
"trunc": math.trunc,
"floordiv": operator.floordiv,
"ge": operator.ge,
"gt": operator.gt,
@ -486,6 +490,7 @@ unary_magic_methods = {
"neg",
"sym_not",
"pos",
"trunc",
}
@ -548,7 +553,7 @@ for name in math_op_names:
always_float_magic_methods.add(sym_name)
always_int_magic_methods = {"ceil", "floor"}
always_int_magic_methods = {"ceil", "floor", "trunc"}
always_bool_magic_methods = {
"eq",
"ne",
@ -653,6 +658,12 @@ def _sympy_floor(a):
return _floor_ceil_helper(a, sympy.floor)
def _sympy_trunc(a):
from torch.utils._sympy.functions import Trunc
return Trunc(a)
def _sympy_ceil(a):
import sympy
@ -774,6 +785,7 @@ magic_methods = {
"le": _sympy_le,
"ge": _sympy_ge,
"floor": _sympy_floor,
"trunc": _sympy_trunc,
"sym_float": _sympy_sym_float,
"ceil": _sympy_ceil,
"neg": operator.neg,

View File

@ -1717,7 +1717,7 @@ class DimConstraints:
elif left.isdigit():
relation_with_digit(right, flip(op), int(left))
else:
assert op == "=="
assert op == "==", t
results[left]["eq"] = sympy.sympify(right)
buf = ""

View File

@ -328,6 +328,17 @@ class IsNonOverlappingAndDenseIndicator(sympy.Function):
return None
class Trunc(sympy.Function):
is_integer = True
@classmethod
def eval(cls, number):
if number.is_integer:
return number
elif isinstance(number, sympy.Number):
return sympy.Integer(math.trunc(float(number)))
class Round(sympy.Function):
is_integer = True

View File

@ -24,6 +24,7 @@ from .functions import (
Round,
RoundDecimal,
TrueDiv,
Trunc,
Where,
)
@ -51,6 +52,7 @@ def handlers():
TrueDiv: "truediv",
FloorDiv: "floordiv",
CleanDiv: "div",
Trunc: "trunc",
Where: "where",
sympy.Add: "add",
sympy.Mul: "mul",

View File

@ -745,6 +745,13 @@ class SymPyValueRangeAnalysis:
def atan(x):
return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan)
@staticmethod
def trunc(x):
def trunc(x):
return sympy.Integer(x) if x.is_finite else x
return ValueRanges.increasing_map(x, trunc)
class ValueRangeAnalysis(SymPyValueRangeAnalysis):
def __init__(self):
@ -829,10 +836,7 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
if x == ValueRanges.unknown():
return x
def trunc(x):
return sympy.Integer(x) if x.is_finite else x
return ValueRanges.increasing_map(x, trunc)
return cls.trunc(x)
@classmethod
def sub(cls, a, b):