mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157632 Approved by: https://github.com/ezyang, https://github.com/Skylion007
430 lines
13 KiB
Python
430 lines
13 KiB
Python
from typing import Union
|
|
from typing_extensions import assert_type, TypeAlias
|
|
|
|
from torch import randn, Tensor
|
|
|
|
|
|
# Test deduced types of arithmetic operations between tensors, ints, floats and bools
|
|
# The expected type should always be `Tensor`, but isn't.
|
|
# See https://github.com/pytorch/pytorch/issues/145838
# Each `assert_type(expr, T)` in this file pins the type that is currently
# deduced for `expr`. A trailing "Should be ..." or "Should fail ..." comment
# marks a statement whose current outcome differs from the desired one.
|
|
|
|
# One sample operand of each kind: a tensor, an int, a float and a bool.
TENSOR, INT, FLOAT, BOOL = randn(3), 2, 1.5, True
|
|
|
|
#
|
|
# Unary ops
|
|
#
|
|
|
|
# Unary plus, negation and bitwise inversion of a tensor are all asserted
# to be `Tensor`.
assert_type(+TENSOR, Tensor)
|
|
assert_type(-TENSOR, Tensor)
|
|
assert_type(~TENSOR, Tensor)
|
|
|
|
#
|
|
# Binary ops that return a boolean
|
|
#
|
|
|
|
# Operator ==
# Every comparison operator in this section is tried with the tensor on
# either side of a bool, an int and a float.
# For `==` and `!=`, a scalar on the left is currently deduced as `bool`
# (flagged "Should be Tensor"); the four ordering operators are asserted as
# `Tensor` in both operand orders.
# NOTE(review): presumably the scalar's own `__eq__`/`__ne__` signature is
# what the checker picks up for the flagged lines -- confirm before relying
# on this explanation.
|
|
assert_type(TENSOR == TENSOR, Tensor)
|
|
assert_type(TENSOR == BOOL, Tensor)
|
|
assert_type(BOOL == TENSOR, bool) # Should be Tensor
|
|
assert_type(TENSOR == INT, Tensor)
|
|
assert_type(INT == TENSOR, bool) # Should be Tensor
|
|
assert_type(TENSOR == FLOAT, Tensor)
|
|
assert_type(FLOAT == TENSOR, bool) # Should be Tensor
|
|
|
|
# Operator !=
|
|
assert_type(TENSOR != TENSOR, Tensor)
|
|
assert_type(TENSOR != BOOL, Tensor)
|
|
assert_type(BOOL != TENSOR, bool) # Should be Tensor
|
|
assert_type(TENSOR != INT, Tensor)
|
|
assert_type(INT != TENSOR, bool) # Should be Tensor
|
|
assert_type(TENSOR != FLOAT, Tensor)
|
|
assert_type(FLOAT != TENSOR, bool) # Should be Tensor
|
|
|
|
# Operator <
|
|
assert_type(TENSOR < TENSOR, Tensor)
|
|
assert_type(TENSOR < BOOL, Tensor)
|
|
assert_type(BOOL < TENSOR, Tensor)
|
|
assert_type(TENSOR < INT, Tensor)
|
|
assert_type(INT < TENSOR, Tensor)
|
|
assert_type(TENSOR < FLOAT, Tensor)
|
|
assert_type(FLOAT < TENSOR, Tensor)
|
|
|
|
# Operator >
|
|
assert_type(TENSOR > TENSOR, Tensor)
|
|
assert_type(TENSOR > BOOL, Tensor)
|
|
assert_type(BOOL > TENSOR, Tensor)
|
|
assert_type(TENSOR > INT, Tensor)
|
|
assert_type(INT > TENSOR, Tensor)
|
|
assert_type(TENSOR > FLOAT, Tensor)
|
|
assert_type(FLOAT > TENSOR, Tensor)
|
|
|
|
# Operator <=
|
|
assert_type(TENSOR <= TENSOR, Tensor)
|
|
assert_type(TENSOR <= BOOL, Tensor)
|
|
assert_type(BOOL <= TENSOR, Tensor)
|
|
assert_type(TENSOR <= INT, Tensor)
|
|
assert_type(INT <= TENSOR, Tensor)
|
|
assert_type(TENSOR <= FLOAT, Tensor)
|
|
assert_type(FLOAT <= TENSOR, Tensor)
|
|
|
|
# Operator >=
|
|
assert_type(TENSOR >= TENSOR, Tensor)
|
|
assert_type(TENSOR >= BOOL, Tensor)
|
|
assert_type(BOOL >= TENSOR, Tensor)
|
|
assert_type(TENSOR >= INT, Tensor)
|
|
assert_type(INT >= TENSOR, Tensor)
|
|
assert_type(TENSOR >= FLOAT, Tensor)
|
|
assert_type(FLOAT >= TENSOR, Tensor)
|
|
|
|
#
|
|
# Binary ops that take and return ints or floats
|
|
#
|
|
|
|
# Operator +
# Each arithmetic operator in this section is asserted to give `Tensor` for
# tensor/tensor and for a bool, int or float on either side of the tensor;
# no statement here carries a "Should be ..." or `type: ignore` marker.
|
|
assert_type(TENSOR + TENSOR, Tensor)
|
|
assert_type(TENSOR + BOOL, Tensor)
|
|
assert_type(BOOL + TENSOR, Tensor)
|
|
assert_type(TENSOR + INT, Tensor)
|
|
assert_type(INT + TENSOR, Tensor)
|
|
assert_type(TENSOR + FLOAT, Tensor)
|
|
assert_type(FLOAT + TENSOR, Tensor)
|
|
|
|
# Operator -
|
|
assert_type(TENSOR - TENSOR, Tensor)
|
|
assert_type(TENSOR - BOOL, Tensor)
|
|
assert_type(BOOL - TENSOR, Tensor)
|
|
assert_type(TENSOR - INT, Tensor)
|
|
assert_type(INT - TENSOR, Tensor)
|
|
assert_type(TENSOR - FLOAT, Tensor)
|
|
assert_type(FLOAT - TENSOR, Tensor)
|
|
|
|
# Operator *
|
|
assert_type(TENSOR * TENSOR, Tensor)
|
|
assert_type(TENSOR * BOOL, Tensor)
|
|
assert_type(BOOL * TENSOR, Tensor)
|
|
assert_type(TENSOR * INT, Tensor)
|
|
assert_type(INT * TENSOR, Tensor)
|
|
assert_type(TENSOR * FLOAT, Tensor)
|
|
assert_type(FLOAT * TENSOR, Tensor)
|
|
|
|
# Operator //
|
|
assert_type(TENSOR // TENSOR, Tensor)
|
|
assert_type(TENSOR // BOOL, Tensor)
|
|
assert_type(BOOL // TENSOR, Tensor)
|
|
assert_type(TENSOR // INT, Tensor)
|
|
assert_type(INT // TENSOR, Tensor)
|
|
assert_type(TENSOR // FLOAT, Tensor)
|
|
assert_type(FLOAT // TENSOR, Tensor)
|
|
|
|
# Operator /
|
|
assert_type(TENSOR / TENSOR, Tensor)
|
|
assert_type(TENSOR / BOOL, Tensor)
|
|
assert_type(BOOL / TENSOR, Tensor)
|
|
assert_type(TENSOR / INT, Tensor)
|
|
assert_type(INT / TENSOR, Tensor)
|
|
assert_type(TENSOR / FLOAT, Tensor)
|
|
assert_type(FLOAT / TENSOR, Tensor)
|
|
|
|
# Operator %
|
|
assert_type(TENSOR % TENSOR, Tensor)
|
|
assert_type(TENSOR % BOOL, Tensor)
|
|
assert_type(BOOL % TENSOR, Tensor)
|
|
assert_type(TENSOR % INT, Tensor)
|
|
assert_type(INT % TENSOR, Tensor)
|
|
assert_type(TENSOR % FLOAT, Tensor)
|
|
assert_type(FLOAT % TENSOR, Tensor)
|
|
|
|
# Operator **
|
|
assert_type(TENSOR**TENSOR, Tensor)
|
|
assert_type(TENSOR**BOOL, Tensor)
|
|
assert_type(BOOL**TENSOR, Tensor)
|
|
assert_type(TENSOR**INT, Tensor)
|
|
assert_type(INT**TENSOR, Tensor)
|
|
assert_type(TENSOR**FLOAT, Tensor)
|
|
assert_type(FLOAT**TENSOR, Tensor)
|
|
|
|
#
|
|
# Matrix multiplication
|
|
#
|
|
|
|
# Operator @
# Only tensor @ tensor is a plain expectation here.
# A scalar on the right is currently accepted and deduced as `Tensor`
# (flagged "Should fail type checking"). A scalar on the left carries
# `type: ignore[operator]`, i.e. an operator error is suppressed on that line.
|
|
assert_type(TENSOR @ TENSOR, Tensor)
|
|
assert_type(TENSOR @ BOOL, Tensor) # Should fail type checking
|
|
assert_type(BOOL @ TENSOR, Tensor) # type: ignore[operator]
|
|
assert_type(TENSOR @ INT, Tensor) # Should fail type checking
|
|
assert_type(INT @ TENSOR, Tensor) # type: ignore[operator]
|
|
assert_type(TENSOR @ FLOAT, Tensor) # Should fail type checking
|
|
assert_type(FLOAT @ TENSOR, Tensor) # type: ignore[operator]
|
|
|
|
#
|
|
# Binary ops that take and return ints only
|
|
#
|
|
|
|
# Operator <<
# Shift and bitwise operators: tensor, bool and int operands are asserted as
# `Tensor` in both operand orders.
# Float operands are included as well: for `<<` and `>>` both orders are
# currently accepted (flagged "Should fail type checking"); for `&`, `|` and
# `^` the tensor-on-the-left form is accepted and flagged the same way, while
# the float-on-the-left form carries `type: ignore[operator]`.
|
|
assert_type(TENSOR << TENSOR, Tensor)
|
|
assert_type(TENSOR << BOOL, Tensor)
|
|
assert_type(BOOL << TENSOR, Tensor)
|
|
assert_type(TENSOR << INT, Tensor)
|
|
assert_type(INT << TENSOR, Tensor)
|
|
assert_type(TENSOR << FLOAT, Tensor) # Should fail type checking
|
|
assert_type(FLOAT << TENSOR, Tensor) # Should fail type checking
|
|
|
|
# Operator >>
|
|
assert_type(TENSOR >> TENSOR, Tensor)
|
|
assert_type(TENSOR >> BOOL, Tensor)
|
|
assert_type(BOOL >> TENSOR, Tensor)
|
|
assert_type(TENSOR >> INT, Tensor)
|
|
assert_type(INT >> TENSOR, Tensor)
|
|
assert_type(TENSOR >> FLOAT, Tensor) # Should fail type checking
|
|
assert_type(FLOAT >> TENSOR, Tensor) # Should fail type checking
|
|
|
|
# Operator &
|
|
assert_type(TENSOR & TENSOR, Tensor)
|
|
assert_type(TENSOR & BOOL, Tensor)
|
|
assert_type(BOOL & TENSOR, Tensor)
|
|
assert_type(TENSOR & INT, Tensor)
|
|
assert_type(INT & TENSOR, Tensor)
|
|
assert_type(TENSOR & FLOAT, Tensor) # Should fail type checking
|
|
assert_type(FLOAT & TENSOR, Tensor) # type: ignore[operator]
|
|
|
|
# Operator |
|
|
assert_type(TENSOR | TENSOR, Tensor)
|
|
assert_type(TENSOR | BOOL, Tensor)
|
|
assert_type(BOOL | TENSOR, Tensor)
|
|
assert_type(TENSOR | INT, Tensor)
|
|
assert_type(INT | TENSOR, Tensor)
|
|
assert_type(TENSOR | FLOAT, Tensor) # Should fail type checking
|
|
assert_type(FLOAT | TENSOR, Tensor) # type: ignore[operator]
|
|
|
|
# Operator ^
|
|
assert_type(TENSOR ^ TENSOR, Tensor)
|
|
assert_type(TENSOR ^ BOOL, Tensor)
|
|
assert_type(BOOL ^ TENSOR, Tensor)
|
|
assert_type(TENSOR ^ INT, Tensor)
|
|
assert_type(INT ^ TENSOR, Tensor)
|
|
assert_type(TENSOR ^ FLOAT, Tensor) # Should fail type checking
|
|
assert_type(FLOAT ^ TENSOR, Tensor) # type: ignore[operator]
|
|
|
|
|
|
NUMBER: TypeAlias = Union[int, float, bool]
|
|
|
|
|
|
class Binary:
|
|
"""
|
|
This class demonstrates what is possible by overriding every magic method
|
|
relating to binary operations.
|
|
"""
|
|
|
|
def __add__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __and__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __div__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __eq__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __floordiv__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __ge__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __gt__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __le__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __lshift__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __lt__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __mod__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __mul__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __ne__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __or__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __pow__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __radd__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __rand__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __rdiv__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __rfloordiv__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __rlshift__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __rmod__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __rmul__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __ror__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __rpow__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __rrshift__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __rshift__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __rsub__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __rtruediv__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __rxor__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __sub__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __truediv__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
def __xor__(self, other: NUMBER) -> "Binary": # type: ignore[override]
|
|
return self
|
|
|
|
|
|
BINARY = Binary()
|
|
|
|
# The same 18 operators are applied six times: `Binary` on the left and then
# on the right of an int, a float and a bool.
# With `Binary` on the left every result is asserted as `Binary`. With the
# scalar on the left, `==` and `!=` are asserted as `bool` and every other
# operator as `Binary`.
assert_type(BINARY + INT, Binary)
|
|
assert_type(BINARY & INT, Binary)
|
|
assert_type(BINARY / INT, Binary)
|
|
assert_type(BINARY == INT, Binary)
|
|
assert_type(BINARY // INT, Binary)
|
|
assert_type(BINARY >= INT, Binary)
|
|
assert_type(BINARY > INT, Binary)
|
|
assert_type(BINARY <= INT, Binary)
|
|
assert_type(BINARY << INT, Binary)
|
|
assert_type(BINARY < INT, Binary)
|
|
assert_type(BINARY % INT, Binary)
|
|
assert_type(BINARY * INT, Binary)
|
|
assert_type(BINARY != INT, Binary)
|
|
assert_type(BINARY | INT, Binary)
|
|
assert_type(BINARY**INT, Binary)
|
|
assert_type(BINARY >> INT, Binary)
|
|
assert_type(BINARY - INT, Binary)
|
|
assert_type(BINARY ^ INT, Binary)
|
|
|
|
assert_type(INT + BINARY, Binary)
|
|
assert_type(INT & BINARY, Binary)
|
|
assert_type(INT / BINARY, Binary)
|
|
assert_type(INT == BINARY, bool)
|
|
assert_type(INT // BINARY, Binary)
|
|
assert_type(INT >= BINARY, Binary)
|
|
assert_type(INT > BINARY, Binary)
|
|
assert_type(INT <= BINARY, Binary)
|
|
assert_type(INT << BINARY, Binary)
|
|
assert_type(INT < BINARY, Binary)
|
|
assert_type(INT % BINARY, Binary)
|
|
assert_type(INT * BINARY, Binary)
|
|
assert_type(INT != BINARY, bool)
|
|
assert_type(INT | BINARY, Binary)
|
|
assert_type(INT**BINARY, Binary)
|
|
assert_type(INT >> BINARY, Binary)
|
|
assert_type(INT - BINARY, Binary)
|
|
assert_type(INT ^ BINARY, Binary)
|
|
|
|
assert_type(BINARY + FLOAT, Binary)
|
|
assert_type(BINARY & FLOAT, Binary)
|
|
assert_type(BINARY / FLOAT, Binary)
|
|
assert_type(BINARY == FLOAT, Binary)
|
|
assert_type(BINARY // FLOAT, Binary)
|
|
assert_type(BINARY >= FLOAT, Binary)
|
|
assert_type(BINARY > FLOAT, Binary)
|
|
assert_type(BINARY <= FLOAT, Binary)
|
|
assert_type(BINARY << FLOAT, Binary)
|
|
assert_type(BINARY < FLOAT, Binary)
|
|
assert_type(BINARY % FLOAT, Binary)
|
|
assert_type(BINARY * FLOAT, Binary)
|
|
assert_type(BINARY != FLOAT, Binary)
|
|
assert_type(BINARY | FLOAT, Binary)
|
|
assert_type(BINARY**FLOAT, Binary)
|
|
assert_type(BINARY >> FLOAT, Binary)
|
|
assert_type(BINARY - FLOAT, Binary)
|
|
assert_type(BINARY ^ FLOAT, Binary)
|
|
|
|
assert_type(FLOAT + BINARY, Binary)
|
|
assert_type(FLOAT & BINARY, Binary)
|
|
assert_type(FLOAT / BINARY, Binary)
|
|
assert_type(FLOAT == BINARY, bool)
|
|
assert_type(FLOAT // BINARY, Binary)
|
|
assert_type(FLOAT >= BINARY, Binary)
|
|
assert_type(FLOAT > BINARY, Binary)
|
|
assert_type(FLOAT <= BINARY, Binary)
|
|
assert_type(FLOAT << BINARY, Binary)
|
|
assert_type(FLOAT < BINARY, Binary)
|
|
assert_type(FLOAT % BINARY, Binary)
|
|
assert_type(FLOAT * BINARY, Binary)
|
|
assert_type(FLOAT != BINARY, bool)
|
|
assert_type(FLOAT | BINARY, Binary)
|
|
assert_type(FLOAT**BINARY, Binary)
|
|
assert_type(FLOAT >> BINARY, Binary)
|
|
assert_type(FLOAT - BINARY, Binary)
|
|
assert_type(FLOAT ^ BINARY, Binary)
|
|
|
|
assert_type(BINARY + BOOL, Binary)
|
|
assert_type(BINARY & BOOL, Binary)
|
|
assert_type(BINARY / BOOL, Binary)
|
|
assert_type(BINARY == BOOL, Binary)
|
|
assert_type(BINARY // BOOL, Binary)
|
|
assert_type(BINARY >= BOOL, Binary)
|
|
assert_type(BINARY > BOOL, Binary)
|
|
assert_type(BINARY <= BOOL, Binary)
|
|
assert_type(BINARY << BOOL, Binary)
|
|
assert_type(BINARY < BOOL, Binary)
|
|
assert_type(BINARY % BOOL, Binary)
|
|
assert_type(BINARY * BOOL, Binary)
|
|
assert_type(BINARY != BOOL, Binary)
|
|
assert_type(BINARY | BOOL, Binary)
|
|
assert_type(BINARY**BOOL, Binary)
|
|
assert_type(BINARY >> BOOL, Binary)
|
|
assert_type(BINARY - BOOL, Binary)
|
|
assert_type(BINARY ^ BOOL, Binary)
|
|
|
|
assert_type(BOOL + BINARY, Binary)
|
|
assert_type(BOOL & BINARY, Binary)
|
|
assert_type(BOOL / BINARY, Binary)
|
|
assert_type(BOOL == BINARY, bool)
|
|
assert_type(BOOL // BINARY, Binary)
|
|
assert_type(BOOL >= BINARY, Binary)
|
|
assert_type(BOOL > BINARY, Binary)
|
|
assert_type(BOOL <= BINARY, Binary)
|
|
assert_type(BOOL << BINARY, Binary)
|
|
assert_type(BOOL < BINARY, Binary)
|
|
assert_type(BOOL % BINARY, Binary)
|
|
assert_type(BOOL * BINARY, Binary)
|
|
assert_type(BOOL != BINARY, bool)
|
|
assert_type(BOOL | BINARY, Binary)
|
|
assert_type(BOOL**BINARY, Binary)
|
|
assert_type(BOOL >> BINARY, Binary)
|
|
assert_type(BOOL - BINARY, Binary)
|
|
assert_type(BOOL ^ BINARY, Binary)
|