Make hashing a SymInt raise an error again (#130548)

See https://github.com/pytorch/pytorch/issues/130547

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130548
Approved by: https://github.com/Skylion007, https://github.com/albanD, https://github.com/lezcano
This commit is contained in:
Edward Z. Yang
2024-07-16 06:52:16 -07:00
committed by PyTorch MergeBot
parent 1d8baa4df2
commit 408c921d96
5 changed files with 59 additions and 34 deletions

View File

@ -7,6 +7,7 @@ import itertools
import math
import operator
import re
import unittest
import numpy as np
@ -1262,11 +1263,15 @@ class TestSymNumberMagicMethods(TestCase):
def get_constant_bool(self, val):
return SymBool(torch._C._get_constant_bool_symnode(val))
@unittest.expectedFailure
def test_symint_hashing(self):
shape_env = ShapeEnv()
hash(create_symint(shape_env, 3))
def test_symnode_hashing(self):
shape_env = ShapeEnv()
# These all trigger specialization when hashed
hash(create_symint(shape_env, 3))
hash(create_symbool(shape_env, True))
# We should be passing in float here, but create_symbol currently
# only supports int

View File

@ -520,24 +520,31 @@ class SymInt:
return self.node.expr
def __hash__(self) -> builtins.int:
return hash(self._get_int())
if self.node.is_nested_int():
return hash(self.node.nested_int())
else:
# We could support constant SymInts as well, but not doing it for now
raise TypeError("unhashable type: non-nested SymInt")
# TODO: Force specialization
# This can't be done because the TypeError here is load bearing
# for einops
# https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
# return hash(builtins.int(self))
def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]:
"""Represent this int as an exact integer ratio"""
return self._get_int(), 1
return self, 1
def bit_length(self) -> "SymInt":
return SymInt(self.node.wrap_int(self._get_int().bit_length()))
def bit_length(self) -> builtins.int:
# TODO: A more relaxed guard is possible here, where you guard to
# allow all integer quantities which would result in the same bit
# length. We can also just make a dedicated Sympy function for
# computing this quantity and represent it symbolically.
return builtins.int(self).bit_length()
def conjugate(self) -> "SymInt":
return self
def _get_int(self) -> builtins.int:
if self.node.is_nested_int():
return self.node.nested_int()
else:
return builtins.int(self)
class SymFloat:
"""
@ -638,7 +645,7 @@ class SymFloat:
def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
"""Represent this float as an exact integer ratio"""
return self._get_float().as_integer_ratio()
return builtins.float(self).as_integer_ratio()
def __repr__(self):
return self.node._graph_repr()
@ -647,10 +654,7 @@ class SymFloat:
return self.node.expr
def __hash__(self):
return hash(self._get_float())
def _get_float(self) -> builtins.float:
return self.node.float_() if self.node.is_constant() else builtins.float(self)
return hash(builtins.float(self))
class SymBool:

View File

@ -2122,21 +2122,29 @@ def wrap_fx_proxy_cls(
):
set_example_value(proxy.node, example_value)
return EventVariable(proxy, example_value, **options)
elif isinstance(example_value, int) and proxy.node.target in [
torch.sym_int,
getattr,
operator.getitem,
torch._utils._element_size,
torch.seed,
operator.mod,
torch._functorch.vmap._validate_and_get_batch_size,
# some mac builds are missing torch.distributed.get_rank()
getattr(torch.distributed, "get_rank", _missing),
getattr(torch.distributed, "get_world_size", _missing),
# This always wants to be in the graph, even if the constraint
# results in a constant int
torch._constrain_as_size,
]:
elif isinstance(example_value, int) and (
proxy.node.target
in [
torch.sym_int,
getattr,
operator.getitem,
torch._utils._element_size,
torch.seed,
operator.mod,
torch._functorch.vmap._validate_and_get_batch_size,
# some mac builds are missing torch.distributed.get_rank()
getattr(torch.distributed, "get_rank", _missing),
getattr(torch.distributed, "get_world_size", _missing),
# This always wants to be in the graph, even if the constraint
# results in a constant int
torch._constrain_as_size,
]
or (
# TODO: this is a little sus, because we didn't check what the self is
proxy.node.op == "call_method"
and proxy.node.target in ["bit_length"]
)
):
set_example_value(proxy.node, example_value)
return ConstantVariable.create(example_value, **options)
elif isinstance(example_value, torch.backends.cuda.SDPAParams):

View File

@ -9,6 +9,7 @@ import typing
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union
import torch
import torch.fx as fx
import torch.utils._pytree as pytree
@ -932,7 +933,14 @@ class TritonHOPifier:
self.raise_unsupported("Grid can have at most rank 3")
assert len(grids) != 0
if len(set(grids)) == 1:
def intify(x):
if isinstance(x, torch.SymInt):
return int(x)
else:
return x
if len(set(pytree.tree_map(intify, grids))) == 1:
# If there's only one unique grid, lets simplify
grids = [grids[0]]

View File

@ -90,7 +90,7 @@ def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
pass
# Going via an iterator directly is slower than via list comprehension.
axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
if not allow_duplicate and len(set(axis)) != len(axis):
if not allow_duplicate and len(set(map(int, axis))) != len(axis):
if argname:
raise ValueError(f"repeated axis in `{argname}` argument")
else: