mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Use maxint to bound integers. (#96121)
We don't actually support arbitrary precision integers. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/96121 Approved by: https://github.com/tugsbayasgalan, https://github.com/lezcano
This commit is contained in:
committed by
PyTorch MergeBot
parent
a6e3e7905e
commit
98ff841a75
@ -970,6 +970,16 @@ def forward(self, crop_camera_1, mask_1):
|
||||
index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_2); crop_camera_1 = mask_1 = view_2 = None
|
||||
return None""")
|
||||
|
||||
def test_unbacked_slice(self):
|
||||
def f(x, m):
|
||||
x = x[m]
|
||||
return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)]
|
||||
|
||||
make_fx(f, tracing_mode="symbolic")(
|
||||
torch.randn((12, 3, 3)),
|
||||
torch.randint(0, 2, (12,), dtype=torch.bool)
|
||||
)
|
||||
|
||||
@unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
|
||||
def test_unbacked_batch_resnet(self):
|
||||
mod = torchvision.models.resnet18()
|
||||
|
@ -2,6 +2,7 @@
|
||||
# Owner(s): ["oncall: pt2"]
|
||||
|
||||
import itertools
|
||||
import sys
|
||||
|
||||
import sympy
|
||||
from torch.testing._internal.common_utils import (
|
||||
@ -50,6 +51,8 @@ CONSTANTS = [
|
||||
2**24,
|
||||
2**32,
|
||||
2**37 - 1,
|
||||
sys.maxsize - 1,
|
||||
sys.maxsize,
|
||||
]
|
||||
# less constants for N^2 situations
|
||||
LESS_CONSTANTS = [-1, 0, 1, 2, 100]
|
||||
|
@ -426,6 +426,8 @@ def nonzero(fake_mode, func, arg):
|
||||
raise DynamicOutputShapeException(func)
|
||||
|
||||
if arg.nonzero_memo is None:
|
||||
import sys
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import constrain_range
|
||||
|
||||
nnz = fake_mode.shape_env.create_unbacked_symint()
|
||||
@ -438,9 +440,7 @@ def nonzero(fake_mode, func, arg):
|
||||
# disjoint with what can actually occur. But this is fine:
|
||||
# remember, the hypothesis is that if your later code works
|
||||
# with N >= 2, it will work with N = 1 and N = 0.
|
||||
lower = 2
|
||||
upper = None
|
||||
constrain_range(nnz, min=lower, max=upper)
|
||||
constrain_range(nnz, min=2, max=sys.maxsize - 1)
|
||||
|
||||
arg._nonzero_memo = nnz
|
||||
arg._nonzero_memo_vc = arg._version
|
||||
|
@ -1334,7 +1334,7 @@ class ShapeEnv:
|
||||
def create_unbacked_symint(self):
|
||||
symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
|
||||
self.var_to_stack[symbol] = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
|
||||
self.var_to_range[symbol] = ValueRanges.unknown()
|
||||
self.var_to_range[symbol] = ValueRanges(-sys.maxsize - 1, sys.maxsize)
|
||||
return SymInt(SymNode(symbol, self, int, None))
|
||||
|
||||
# This is guaranteed to return a symbol or its negation is a sympy.Symbol,
|
||||
@ -1361,7 +1361,10 @@ class ShapeEnv:
|
||||
|
||||
# We also infer that it must be not 0/1
|
||||
lower = 2 if self.specialize_zero_one else 0
|
||||
self.var_to_range[sympy_expr] = ValueRanges(lower, sympy.oo)
|
||||
# NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT
|
||||
# as a sentinel sometimes. Your sizevar isn't going to be
|
||||
# anywhere near the max 64-bit integer anyway.
|
||||
self.var_to_range[sympy_expr] = ValueRanges(lower, sys.maxsize - 1)
|
||||
|
||||
if not dyn and self.duck_shape:
|
||||
# This implements duck-shaping: input sizes that match are assigned
|
||||
@ -1577,12 +1580,20 @@ class ShapeEnv:
|
||||
if not _simplified:
|
||||
for symbol, sources in symbol_to_source.items():
|
||||
assert sources
|
||||
assert symbol.is_integer
|
||||
r = self.var_to_range[symbol]
|
||||
bounds = []
|
||||
if r.lower != -sympy.oo:
|
||||
bounds.append(str(r.lower))
|
||||
bounds.append(source_ref(sources[0]))
|
||||
if r.upper != sympy.oo:
|
||||
# NB: This looks like an off-by-one error but it's not: the
|
||||
# upper bound may be sys.maxsize - 1 because we intentionally
|
||||
# exclude sys.maxsize from our bounds to deal with direct
|
||||
# == INT_MAX guards, but it's still dumb to actually test it.
|
||||
# Note that you can be off by a pretty large constant and it
|
||||
# won't matter because sizes in practice will be no where near
|
||||
# the 64-bit limit.
|
||||
if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
|
||||
bounds.append(str(r.upper))
|
||||
if len(bounds) > 1:
|
||||
exprs.append(" <= ".join(bounds))
|
||||
|
Reference in New Issue
Block a user