More testing of Python arithmetic operators between tensors and scalars (see 157266) (#157632)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157632
Approved by: https://github.com/ezyang, https://github.com/Skylion007
This commit is contained in:
Tom Ritchford
2025-07-05 14:31:46 +00:00
committed by PyTorch MergeBot
parent ee9ac36c23
commit 3e56a9cdfb
3 changed files with 350 additions and 118 deletions

View File

@ -7,12 +7,18 @@ from torch import randn, Tensor
# See ../pass/arithmetic_ops.py for more information
TENSOR, INT, FLOAT = randn(3), 2, 1.5
TENSOR, FLOAT = randn(3), 1.5
FLOAT & TENSOR # E: Unsupported operand types for & ("float" and "Tensor")
FLOAT | TENSOR # E: Unsupported operand types for | ("float" and "Tensor")
FLOAT ^ TENSOR # E: Unsupported operand types for ^ ("float" and "Tensor")
# FIXME: false negatives (https://github.com/pytorch/pytorch/issues/155701)
#
# FLOAT << TENSOR # E: Unsupported operand types for & ("float" and "Tensor")
# FLOAT >> TENSOR # E: Unsupported operand types for & ("float" and "Tensor")
#
# TENSOR & FLOAT # E: Unsupported operand types for & ("Tensor" and "float" )
# TENSOR | FLOAT # E: Unsupported operand types for | ("Tensor" and "float" )
# TENSOR ^ FLOAT # E: Unsupported operand types for ^ ("Tensor" and "float" )
# TENSOR << FLOAT # E: Unsupported operand types for & ("Tensor" and "float")
# TENSOR >> FLOAT # E: Unsupported operand types for & ("Tensor" and "float")

View File

@ -4,151 +4,205 @@ from typing_extensions import assert_type, TypeAlias
from torch import randn, Tensor
TENSOR, INT, FLOAT, BOOL = randn(3), 2, 1.5, True
# Test deduced types of arithmetic operations between tensors, ints, floats and bools
# The expected type should always be `Tensor`: `Any` and `bool` below are wrong.
# The expected type should always be `Tensor`, but isn't.
# See https://github.com/pytorch/pytorch/issues/145838
TENSOR, INT, FLOAT, BOOL = randn(3), 2, 1.5, True
#
# Unary ops
#
assert_type(+TENSOR, Tensor)
assert_type(-TENSOR, Tensor)
assert_type(~TENSOR, Tensor)
# Binary ops
#
# Binary ops that return a bolean
#
# Operator ==
assert_type(TENSOR == TENSOR, Tensor)
assert_type(TENSOR != TENSOR, Tensor)
assert_type(TENSOR < TENSOR, Tensor)
assert_type(TENSOR > TENSOR, Tensor)
assert_type(TENSOR <= TENSOR, Tensor)
assert_type(TENSOR >= TENSOR, Tensor)
assert_type(TENSOR + TENSOR, Tensor)
assert_type(TENSOR - TENSOR, Tensor)
assert_type(TENSOR * TENSOR, Tensor)
assert_type(TENSOR // TENSOR, Tensor)
assert_type(TENSOR / TENSOR, Tensor)
assert_type(TENSOR % TENSOR, Tensor)
assert_type(TENSOR**TENSOR, Tensor)
assert_type(TENSOR << TENSOR, Tensor)
assert_type(TENSOR >> TENSOR, Tensor)
assert_type(TENSOR & TENSOR, Tensor)
assert_type(TENSOR | TENSOR, Tensor)
assert_type(TENSOR ^ TENSOR, Tensor)
assert_type(TENSOR == BOOL, Tensor)
assert_type(TENSOR != BOOL, Tensor)
assert_type(TENSOR < BOOL, Tensor)
assert_type(TENSOR > BOOL, Tensor)
assert_type(TENSOR <= BOOL, Tensor)
assert_type(TENSOR >= BOOL, Tensor)
assert_type(TENSOR + BOOL, Tensor)
assert_type(TENSOR - BOOL, Tensor)
assert_type(TENSOR * BOOL, Tensor)
assert_type(TENSOR // BOOL, Tensor)
assert_type(TENSOR / BOOL, Tensor)
assert_type(TENSOR % BOOL, Tensor)
assert_type(TENSOR**BOOL, Tensor)
assert_type(TENSOR << BOOL, Tensor)
assert_type(TENSOR >> BOOL, Tensor)
assert_type(TENSOR & BOOL, Tensor)
assert_type(TENSOR | BOOL, Tensor)
assert_type(TENSOR ^ BOOL, Tensor)
assert_type(BOOL == TENSOR, bool)
assert_type(BOOL != TENSOR, bool)
assert_type(BOOL < TENSOR, Tensor)
assert_type(BOOL > TENSOR, Tensor)
assert_type(BOOL <= TENSOR, Tensor)
assert_type(BOOL >= TENSOR, Tensor)
assert_type(BOOL + TENSOR, Tensor)
assert_type(BOOL - TENSOR, Tensor)
assert_type(BOOL * TENSOR, Tensor)
assert_type(BOOL // TENSOR, Tensor)
assert_type(BOOL / TENSOR, Tensor)
assert_type(BOOL % TENSOR, Tensor)
assert_type(BOOL**TENSOR, Tensor)
assert_type(BOOL << TENSOR, Tensor)
assert_type(BOOL >> TENSOR, Tensor)
assert_type(BOOL & TENSOR, Tensor)
assert_type(BOOL | TENSOR, Tensor)
assert_type(BOOL ^ TENSOR, Tensor)
assert_type(BOOL == TENSOR, bool) # Should be Tensor
assert_type(TENSOR == INT, Tensor)
assert_type(TENSOR != INT, Tensor)
assert_type(TENSOR < INT, Tensor)
assert_type(TENSOR > INT, Tensor)
assert_type(TENSOR <= INT, Tensor)
assert_type(TENSOR >= INT, Tensor)
assert_type(TENSOR + INT, Tensor)
assert_type(TENSOR - INT, Tensor)
assert_type(TENSOR * INT, Tensor)
assert_type(TENSOR // INT, Tensor)
assert_type(TENSOR / INT, Tensor)
assert_type(TENSOR % INT, Tensor)
assert_type(TENSOR**INT, Tensor)
assert_type(TENSOR << INT, Tensor)
assert_type(TENSOR >> INT, Tensor)
assert_type(TENSOR & INT, Tensor)
assert_type(TENSOR | INT, Tensor)
assert_type(TENSOR ^ INT, Tensor)
assert_type(INT == TENSOR, bool)
assert_type(INT != TENSOR, bool)
assert_type(INT < TENSOR, Tensor)
assert_type(INT > TENSOR, Tensor)
assert_type(INT <= TENSOR, Tensor)
assert_type(INT >= TENSOR, Tensor)
assert_type(INT + TENSOR, Tensor)
assert_type(INT - TENSOR, Tensor)
assert_type(INT * TENSOR, Tensor)
assert_type(INT // TENSOR, Tensor)
assert_type(INT / TENSOR, Tensor)
assert_type(INT % TENSOR, Tensor)
assert_type(INT**TENSOR, Tensor)
assert_type(INT << TENSOR, Tensor)
assert_type(INT >> TENSOR, Tensor)
assert_type(INT & TENSOR, Tensor)
assert_type(INT | TENSOR, Tensor)
assert_type(INT ^ TENSOR, Tensor)
assert_type(INT == TENSOR, bool) # Should be Tensor
assert_type(TENSOR == FLOAT, Tensor)
assert_type(TENSOR != FLOAT, Tensor)
assert_type(TENSOR < FLOAT, Tensor)
assert_type(TENSOR > FLOAT, Tensor)
assert_type(TENSOR <= FLOAT, Tensor)
assert_type(TENSOR >= FLOAT, Tensor)
assert_type(TENSOR + FLOAT, Tensor)
assert_type(TENSOR - FLOAT, Tensor)
assert_type(TENSOR * FLOAT, Tensor)
assert_type(TENSOR // FLOAT, Tensor)
assert_type(TENSOR / FLOAT, Tensor)
assert_type(TENSOR % FLOAT, Tensor)
assert_type(TENSOR**FLOAT, Tensor)
assert_type(TENSOR << FLOAT, Tensor)
assert_type(TENSOR >> FLOAT, Tensor)
assert_type(TENSOR & FLOAT, Tensor)
assert_type(TENSOR | FLOAT, Tensor)
assert_type(TENSOR ^ FLOAT, Tensor)
assert_type(FLOAT == TENSOR, bool) # Should be Tensor
assert_type(FLOAT == TENSOR, bool)
assert_type(FLOAT != TENSOR, bool)
# Operator !=
assert_type(TENSOR != TENSOR, Tensor)
assert_type(TENSOR != BOOL, Tensor)
assert_type(BOOL != TENSOR, bool) # Should be Tensor
assert_type(TENSOR != INT, Tensor)
assert_type(INT != TENSOR, bool) # Should be Tensor
assert_type(TENSOR != FLOAT, Tensor)
assert_type(FLOAT != TENSOR, bool) # Should be Tensor
# Operator <
assert_type(TENSOR < TENSOR, Tensor)
assert_type(TENSOR < BOOL, Tensor)
assert_type(BOOL < TENSOR, Tensor)
assert_type(TENSOR < INT, Tensor)
assert_type(INT < TENSOR, Tensor)
assert_type(TENSOR < FLOAT, Tensor)
assert_type(FLOAT < TENSOR, Tensor)
# Operator >
assert_type(TENSOR > TENSOR, Tensor)
assert_type(TENSOR > BOOL, Tensor)
assert_type(BOOL > TENSOR, Tensor)
assert_type(TENSOR > INT, Tensor)
assert_type(INT > TENSOR, Tensor)
assert_type(TENSOR > FLOAT, Tensor)
assert_type(FLOAT > TENSOR, Tensor)
# Operator <=
assert_type(TENSOR <= TENSOR, Tensor)
assert_type(TENSOR <= BOOL, Tensor)
assert_type(BOOL <= TENSOR, Tensor)
assert_type(TENSOR <= INT, Tensor)
assert_type(INT <= TENSOR, Tensor)
assert_type(TENSOR <= FLOAT, Tensor)
assert_type(FLOAT <= TENSOR, Tensor)
# Operator >=
assert_type(TENSOR >= TENSOR, Tensor)
assert_type(TENSOR >= BOOL, Tensor)
assert_type(BOOL >= TENSOR, Tensor)
assert_type(TENSOR >= INT, Tensor)
assert_type(INT >= TENSOR, Tensor)
assert_type(TENSOR >= FLOAT, Tensor)
assert_type(FLOAT >= TENSOR, Tensor)
#
# Binary ops that take and return ints or floats
#
# Operator +
assert_type(TENSOR + TENSOR, Tensor)
assert_type(TENSOR + BOOL, Tensor)
assert_type(BOOL + TENSOR, Tensor)
assert_type(TENSOR + INT, Tensor)
assert_type(INT + TENSOR, Tensor)
assert_type(TENSOR + FLOAT, Tensor)
assert_type(FLOAT + TENSOR, Tensor)
# Operator -
assert_type(TENSOR - TENSOR, Tensor)
assert_type(TENSOR - BOOL, Tensor)
assert_type(BOOL - TENSOR, Tensor)
assert_type(TENSOR - INT, Tensor)
assert_type(INT - TENSOR, Tensor)
assert_type(TENSOR - FLOAT, Tensor)
assert_type(FLOAT - TENSOR, Tensor)
# Operator *
assert_type(TENSOR * TENSOR, Tensor)
assert_type(TENSOR * BOOL, Tensor)
assert_type(BOOL * TENSOR, Tensor)
assert_type(TENSOR * INT, Tensor)
assert_type(INT * TENSOR, Tensor)
assert_type(TENSOR * FLOAT, Tensor)
assert_type(FLOAT * TENSOR, Tensor)
# Operator //
assert_type(TENSOR // TENSOR, Tensor)
assert_type(TENSOR // BOOL, Tensor)
assert_type(BOOL // TENSOR, Tensor)
assert_type(TENSOR // INT, Tensor)
assert_type(INT // TENSOR, Tensor)
assert_type(TENSOR // FLOAT, Tensor)
assert_type(FLOAT // TENSOR, Tensor)
# Operator /
assert_type(TENSOR / TENSOR, Tensor)
assert_type(TENSOR / BOOL, Tensor)
assert_type(BOOL / TENSOR, Tensor)
assert_type(TENSOR / INT, Tensor)
assert_type(INT / TENSOR, Tensor)
assert_type(TENSOR / FLOAT, Tensor)
assert_type(FLOAT / TENSOR, Tensor)
# Operator %
assert_type(TENSOR % TENSOR, Tensor)
assert_type(TENSOR % BOOL, Tensor)
assert_type(BOOL % TENSOR, Tensor)
assert_type(TENSOR % INT, Tensor)
assert_type(INT % TENSOR, Tensor)
assert_type(TENSOR % FLOAT, Tensor)
assert_type(FLOAT % TENSOR, Tensor)
# Operator **
assert_type(TENSOR**TENSOR, Tensor)
assert_type(TENSOR**BOOL, Tensor)
assert_type(BOOL**TENSOR, Tensor)
assert_type(TENSOR**INT, Tensor)
assert_type(INT**TENSOR, Tensor)
assert_type(TENSOR**FLOAT, Tensor)
assert_type(FLOAT**TENSOR, Tensor)
assert_type(FLOAT << TENSOR, Tensor)
assert_type(FLOAT >> TENSOR, Tensor)
#
# Matrix multiplication
#
# Operator @
assert_type(TENSOR @ TENSOR, Tensor)
assert_type(TENSOR @ BOOL, Tensor) # Should fail type checking
assert_type(BOOL @ TENSOR, Tensor) # type: ignore[operator]
assert_type(TENSOR @ INT, Tensor) # Should fail type checking
assert_type(INT @ TENSOR, Tensor) # type: ignore[operator]
assert_type(TENSOR @ FLOAT, Tensor) # Should fail type checking
assert_type(FLOAT @ TENSOR, Tensor) # type: ignore[operator]
#
# Binary ops that take and return ints only
#
# Operator <<
assert_type(TENSOR << TENSOR, Tensor)
assert_type(TENSOR << BOOL, Tensor)
assert_type(BOOL << TENSOR, Tensor)
assert_type(TENSOR << INT, Tensor)
assert_type(INT << TENSOR, Tensor)
assert_type(TENSOR << FLOAT, Tensor) # Should fail type checking
assert_type(FLOAT << TENSOR, Tensor) # Should fail type checking
# Operator >>
assert_type(TENSOR >> TENSOR, Tensor)
assert_type(TENSOR >> BOOL, Tensor)
assert_type(BOOL >> TENSOR, Tensor)
assert_type(TENSOR >> INT, Tensor)
assert_type(INT >> TENSOR, Tensor)
assert_type(TENSOR >> FLOAT, Tensor) # Should fail type checking
assert_type(FLOAT >> TENSOR, Tensor) # Should fail type checking
# Operator &
assert_type(TENSOR & TENSOR, Tensor)
assert_type(TENSOR & BOOL, Tensor)
assert_type(BOOL & TENSOR, Tensor)
assert_type(TENSOR & INT, Tensor)
assert_type(INT & TENSOR, Tensor)
assert_type(TENSOR & FLOAT, Tensor) # Should fail type checking
assert_type(FLOAT & TENSOR, Tensor) # type: ignore[operator]
# Operator |
assert_type(TENSOR | TENSOR, Tensor)
assert_type(TENSOR | BOOL, Tensor)
assert_type(BOOL | TENSOR, Tensor)
assert_type(TENSOR | INT, Tensor)
assert_type(INT | TENSOR, Tensor)
assert_type(TENSOR | FLOAT, Tensor) # Should fail type checking
assert_type(FLOAT | TENSOR, Tensor) # type: ignore[operator]
# Operator ^
assert_type(TENSOR ^ TENSOR, Tensor)
assert_type(TENSOR ^ BOOL, Tensor)
assert_type(BOOL ^ TENSOR, Tensor)
assert_type(TENSOR ^ INT, Tensor)
assert_type(INT ^ TENSOR, Tensor)
assert_type(TENSOR ^ FLOAT, Tensor) # Should fail type checking
assert_type(FLOAT ^ TENSOR, Tensor) # type: ignore[operator]

View File

@ -0,0 +1,172 @@
# mypy: ignore-errors
# Owner(s): ["module: unknown"]
import token
from itertools import product
from pathlib import Path
import torch
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
MM = "@"
BINARY_RETURNS_BOOL = "!=", "<", "<=", "==", ">", ">="
BINARY_ACCEPTS_FLOAT_OR_INT = "%", "*", "**", "+", "-", "/", "//"
BINARY_ACCEPTS_INT_ONLY = "&", "<<", ">>", "^", "|"
BINARY_OPS = (
*BINARY_RETURNS_BOOL,
*BINARY_ACCEPTS_FLOAT_OR_INT,
*BINARY_ACCEPTS_INT_ONLY,
MM,
)
BINARY_RETURNS_FLOAT = ("/",)
UNARY_ACCEPTS_FLOAT_OR_INT = "+", "-"
UNARY_ACCEPTS_INT_ONLY = ("~",)
UNARY_OPS = *UNARY_ACCEPTS_FLOAT_OR_INT, *UNARY_ACCEPTS_INT_ONLY
PUNCTUATION = ",", ";"
OPERATORS = *UNARY_OPS, *BINARY_OPS, *PUNCTUATION
FLOATS = 1.5, torch.tensor((2.5, 3.5))
INTS = 3, torch.tensor((1, 2))
ALL = *FLOATS, *INTS
TYPE_TEST_FILE = Path(__file__).parent / "pass/arithmetic_ops.py"
class TestPythonOperators(TestCase):
# Prove that UNARY_OPS, BINARY_OPS, and OPERATORS are correct and complete
def test_operators_are_correct_and_complete(self):
self.assertFalse(set(OPERATORS).difference(token.EXACT_TOKEN_TYPES))
unary, binary, punctuation = {}, {}, {}
for op in token.EXACT_TOKEN_TYPES:
if op in PUNCTUATION:
punctuation[op] = True
else:
try:
unary[op] = compile(f"{op}1 ; {op}a", op, "single")
except SyntaxError:
pass
try:
binary[op] = compile(f"2 {op} 3 ; a {op} b", op, "single")
except SyntaxError:
pass
self.assertEqual(sorted(unary), sorted(UNARY_OPS))
self.assertEqual(sorted(binary), sorted(BINARY_OPS))
self.assertEqual(sorted(punctuation), sorted(PUNCTUATION))
def test_type_tests_are_complete(self):
binary, unary = {}, []
with TYPE_TEST_FILE.open() as fp:
# Looking for lines like: assert_type(TENSOR ^ BOOL, Tensor)
# But not: assert_type(BOOL ^ BINARY, Binary)
lines = (i for i in fp if "TENSOR" in i)
for line in lines:
if expr := line.partition("assert_type(")[2].partition(",")[0]:
if expr[0].isalpha():
# ** formats differently from all other operators
a, op, b = expr.replace("**", " ** ").split()
binary.setdefault(op, []).append((a, b))
else:
unary.append(expr[0])
self.assertEqual(sorted(unary), sorted(UNARY_OPS))
self.assertEqual(sorted(binary), sorted(BINARY_OPS))
value, *values = binary.values()
self.assertEqual(values, [value] * len(values))
@parametrize("a, op, b", product(ALL, BINARY_OPS, ALL))
def test_binary(self, a, op, b):
try:
r = eval(f"a {op} b")
except Exception as e:
r = e
any_tensor = isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor)
any_float = _any_float(a, b)
returns_float = any_float or op in BINARY_RETURNS_FLOAT
if op == MM:
if not (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)):
self.assertIsInstance(r, TypeError)
elif a is b:
self.assertIsInstance(r, torch.Tensor)
else:
self.assertIsInstance(r, RuntimeError)
elif any_tensor:
if op in BINARY_ACCEPTS_INT_ONLY and any_float:
# See https://github.com/pytorch/pytorch/issues/15754
self.assertIsInstance(r, NotImplementedError)
else:
self.assertIsInstance(r, torch.Tensor)
if op in BINARY_RETURNS_BOOL:
self.assertEqual(r.dtype, torch.bool)
elif op in BINARY_ACCEPTS_INT_ONLY:
self.assertFalse(r.dtype.is_floating_point)
elif op in BINARY_ACCEPTS_FLOAT_OR_INT:
self.assertEqual(r.dtype.is_floating_point, returns_float)
else:
self.assertFalse("Logic error")
elif op in BINARY_RETURNS_BOOL:
self.assertIsInstance(r, bool)
elif op in BINARY_ACCEPTS_INT_ONLY:
if any_float:
self.assertIsInstance(r, TypeError)
else:
self.assertIsInstance(r, int)
elif returns_float:
self.assertIsInstance(r, float)
else:
self.assertIsInstance(r, int)
@parametrize("op, a", product(UNARY_OPS, ALL))
def test_unary(self, op, a):
try:
r = eval(f"{op} a")
except Exception as e:
r = e
if op in UNARY_ACCEPTS_INT_ONLY and _any_float(a):
self.assertIsInstance(r, TypeError)
elif isinstance(a, torch.Tensor):
self.assertIsInstance(r, torch.Tensor)
elif op in UNARY_ACCEPTS_INT_ONLY:
self.assertIsInstance(r, int)
elif isinstance(a, float):
self.assertIsInstance(r, float)
else:
self.assertIsInstance(r, int)
def _any_float(*x):
for i in x:
if isinstance(i, float) or (
isinstance(i, torch.Tensor) and i.dtype.is_floating_point
):
return True
return False
instantiate_parametrized_tests(TestPythonOperators)
if __name__ == "__main__":
run_tests()