Files
pytorch/torch/utils/_sympy/singleton_int.py
Xuehai Pan 30293319a8 [BE][Easy][19/19] enforce style for empty lines in import segments in torch/[o-z]*/ (#129771)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129771
Approved by: https://github.com/justinchuby, https://github.com/janeyx99
2024-08-01 17:07:14 +00:00

97 lines
2.9 KiB
Python

# mypy: allow-untyped-defs
import sympy
from sympy.multipledispatch import dispatch
__all__ = ["SingletonInt"]
class SingletonInt(sympy.AtomicExpr):
# This is probably not super important unless we are in multiple dispatch
# situations with other more exotic Expr types.
_op_priority = 99999
def __new__(cls, *args, coeff=None, **kwargs):
instance = super().__new__(cls, *args, **kwargs)
return instance
# The semantics of this class should match that of NestedIntSymNodeImpl in
# c10/core/NestedIntSymNodeImpl.h
def __init__(self, val, *, coeff=1):
self._val = val
self._coeff = coeff
super().__init__()
# See NOTE [ Inequalities with nested int ]
def _eval_Eq(self, other):
if (
isinstance(other, SingletonInt)
and other._val == self._val
and self._coeff == other._coeff
):
return sympy.true
else:
return sympy.false
# This is necessary so that calling expr.free_symbols on exprs that contain
# this Singleton does not error
@property
def free_symbols(self):
return set()
def __mul__(self, other):
if isinstance(other, SingletonInt):
raise ValueError(
"SingletonInt cannot be multiplied by another SingletonInt"
)
return SingletonInt(self._val, coeff=self._coeff * other)
def __rmul__(self, other):
if isinstance(other, SingletonInt):
raise ValueError(
"SingletonInt cannot be multiplied by another SingletonInt"
)
return SingletonInt(self._val, coeff=self._coeff * other)
# Make sure we promptly raise an error instead of falling back to building
# an expression tree. There are probably more ops, how can we be exhaustive?
def __add__(self, other):
raise NotImplementedError("NYI")
def __sub__(self, other):
raise NotImplementedError("NYI")
def __truediv__(self, other):
raise NotImplementedError("NYI")
def __floordiv__(self, other):
raise NotImplementedError("NYI")
def __mod__(self, other):
raise NotImplementedError("NYI")
# See NOTE [ Inequalities with nested int ]
@dispatch(sympy.Integer, SingletonInt)
def _eval_is_ge(a, b):
if a < 2:
return sympy.false
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
@dispatch(SingletonInt, sympy.Integer) # type: ignore[no-redef]
def _eval_is_ge(a, b): # noqa: F811
if b <= 2:
return sympy.true
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
@dispatch(SingletonInt, SingletonInt) # type: ignore[no-redef]
def _eval_is_ge(a, b): # noqa: F811
if a._val == b._val:
if a._coeff >= b._coeff:
return sympy.true
else:
return sympy.false
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")