Fix convit_base (#95174)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95174
Approved by: https://github.com/ngimel, https://github.com/jansel, https://github.com/atalman
This commit is contained in:
Edward Z. Yang
2023-02-20 16:30:14 -08:00
committed by PyTorch MergeBot
parent 92e03cd583
commit 7ca623c2e1
2 changed files with 9 additions and 1 deletions

View File

@ -107,6 +107,9 @@ class TestValueRanges(TestCase):
self.assertEqual(r.lower, r.upper)
self.assertEqual(ref_r, r.lower)
def test_pow_half(self):
ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5))
@parametrize("fn", BINARY_OPS)
def test_binary_ref(self, fn):
for a, b in itertools.product(CONSTANTS, repeat=2):

View File

@ -360,6 +360,11 @@ class ValueRangeAnalysis:
@classmethod
def pow(cls, a, b):
def is_integer(val):
return isinstance(val, int) or (
hasattr(val, "is_integer") and val.is_integer
)
a = ValueRanges.wrap(a)
b = ValueRanges.wrap(b)
if a.is_singleton() and b.is_singleton():
@ -367,7 +372,7 @@ class ValueRangeAnalysis:
if r == sympy.zoo:
return ValueRanges.unknown()
return ValueRanges.wrap(r)
elif b.is_singleton() and b.lower >= 0 and isinstance(b.lower, int):
elif b.is_singleton() and is_integer(b.lower) and b.lower >= 0:
i = ValueRanges.wrap(1)
for _ in range(b.lower):
i = cls.mul(i, a)