Mirror of https://github.com/pytorch/pytorch.git
Make hashing a SymInt raise an error again (#130548)
See https://github.com/pytorch/pytorch/issues/130547

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130548
Approved by: https://github.com/Skylion007, https://github.com/albanD, https://github.com/lezcano
commit 408c921d96 (committed by PyTorch MergeBot)
parent 1d8baa4df2
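In effect, hashing a plain (non-nested) SymInt once again raises TypeError instead of silently specializing the symbol to its concrete hint. A minimal sketch of the restored contract; `FakeSymInt` below is a hypothetical stand-in, not PyTorch API:

```python
# Hypothetical stand-in mirroring the new SymInt.__hash__ contract; not a
# real PyTorch class.
class FakeSymInt:
    def __init__(self, is_nested, nested_value=None):
        self.is_nested = is_nested
        self.nested_value = nested_value

    def __hash__(self):
        if self.is_nested:
            # Nested ints (used by nested/jagged tensors) remain hashable.
            return hash(self.nested_value)
        # Plain symbolic ints refuse to hash rather than silently specialize.
        raise TypeError("unhashable type: non-nested SymInt")


hash(FakeSymInt(True, 7))  # fine: nested ints keep working as dict keys
try:
    hash(FakeSymInt(False))
except TypeError as e:
    print(e)  # unhashable type: non-nested SymInt
```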
test/test_dynamic_shapes.py

```diff
@@ -7,6 +7,7 @@ import itertools
 import math
 import operator
 import re
+import unittest
 
 import numpy as np
```
```diff
@@ -1262,11 +1263,15 @@ class TestSymNumberMagicMethods(TestCase):
     def get_constant_bool(self, val):
         return SymBool(torch._C._get_constant_bool_symnode(val))
 
+    @unittest.expectedFailure
+    def test_symint_hashing(self):
+        shape_env = ShapeEnv()
+        hash(create_symint(shape_env, 3))
+
     def test_symnode_hashing(self):
         shape_env = ShapeEnv()
 
         # These all trigger specialization when hashed
-        hash(create_symint(shape_env, 3))
         hash(create_symbool(shape_env, True))
         # We should be passing in float here, but create_symbol currently
         # only supports int
```
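For readers unfamiliar with the decorator used above: `unittest.expectedFailure` inverts the test outcome, so `test_symint_hashing` now passes exactly while `hash(SymInt)` keeps raising. A self-contained illustration:

```python
import unittest


class Unhashable:
    __hash__ = None  # the same mechanism CPython uses to mark a type unhashable


class Demo(unittest.TestCase):
    # The test "passes" as long as the body raises; if hashing ever starts
    # succeeding again, unittest reports an unexpected success.
    @unittest.expectedFailure
    def test_hash_raises(self):
        hash(Unhashable())


if __name__ == "__main__":
    unittest.main()
```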
torch/__init__.py

```diff
@@ -520,24 +520,31 @@ class SymInt:
         return self.node.expr
 
     def __hash__(self) -> builtins.int:
-        return hash(self._get_int())
+        if self.node.is_nested_int():
+            return hash(self.node.nested_int())
+        else:
+            # We could support constant SymInts as well, but not doing it for now
+            raise TypeError("unhashable type: non-nested SymInt")
+            # TODO: Force specialization
+            # This can't be done because the TypeError here is load bearing
+            # for einops
+            # https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
+            # return hash(builtins.int(self))
 
-    def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
+    def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]:
         """Represent this int as an exact integer ratio"""
-        return self._get_int(), 1
+        return self, 1
 
-    def bit_length(self) -> "SymInt":
-        return SymInt(self.node.wrap_int(self._get_int().bit_length()))
+    def bit_length(self) -> builtins.int:
+        # TODO: A more relaxed guard is possible here, where you guard to
+        # allow all integer quantities which would result in the same bit
+        # length. We can also just make a dedicated Sympy function for
+        # computing this quantity and represent it symbolically.
+        return builtins.int(self).bit_length()
 
     def conjugate(self) -> "SymInt":
         return self
 
-    def _get_int(self) -> builtins.int:
-        if self.node.is_nested_int():
-            return self.node.nested_int()
-        else:
-            return builtins.int(self)
-
 
 class SymFloat:
     """
```
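Two of the comments above carry the key reasoning. First, the TypeError is "load bearing" because downstream code uses hashability as a duck-typing signal; a hedged sketch of the pattern the einops link refers to (`classify_axis` is hypothetical, not einops' actual code):

```python
# Hypothetical sketch: hashability used as a duck-typing signal. If SymInt
# hashed by specializing, this kind of check would silently treat a dynamic
# size as a constant instead of taking the symbolic path.
def classify_axis(x):
    try:
        hash(x)
    except TypeError:
        return "symbolic"   # take the dynamic-shape path
    return "constant"       # safe to cache / use as a dict key
```

Second, the `bit_length` TODO: instead of guarding on the exact value, one could guard only on the range of integers that share the same bit length. A sketch of those bounds:

```python
# Any n' with lo <= n' < hi has the same bit_length as n, so a guard on
# this interval would be strictly weaker than specializing on n itself.
def bit_length_bounds(n: int):
    b = n.bit_length()
    lo = 1 << (b - 1) if b > 0 else 0
    hi = 1 << b
    return b, lo, hi

assert bit_length_bounds(12) == (4, 8, 16)  # 8 <= 12 < 16
```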
```diff
@@ -638,7 +645,7 @@ class SymFloat:
 
     def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
         """Represent this float as an exact integer ratio"""
-        return self._get_float().as_integer_ratio()
+        return builtins.float(self).as_integer_ratio()
 
     def __repr__(self):
         return self.node._graph_repr()
```
```diff
@@ -647,10 +654,7 @@ class SymFloat:
         return self.node.expr
 
     def __hash__(self):
-        return hash(self._get_float())
-
-    def _get_float(self) -> builtins.float:
-        return self.node.float_() if self.node.is_constant() else builtins.float(self)
+        return hash(builtins.float(self))
 
 
 class SymBool:
```
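For contrast, SymFloat stays hashable, but only by specializing: `builtins.float(self)` forces the symbolic value to a concrete float (installing a guard) before hashing. A toy model of that trade-off (`FakeSymFloat` is hypothetical):

```python
class FakeSymFloat:
    def __init__(self, concrete: float):
        self._concrete = concrete

    def __float__(self) -> float:
        # In the real SymFloat, this conversion is where a specialization
        # guard would land.
        return self._concrete

    def __hash__(self) -> int:
        return hash(float(self))  # mirrors hash(builtins.float(self))


assert hash(FakeSymFloat(2.5)) == hash(2.5)
```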
torch/_dynamo/variables/builder.py

```diff
@@ -2122,21 +2122,29 @@ def wrap_fx_proxy_cls(
     ):
         set_example_value(proxy.node, example_value)
         return EventVariable(proxy, example_value, **options)
-    elif isinstance(example_value, int) and proxy.node.target in [
-        torch.sym_int,
-        getattr,
-        operator.getitem,
-        torch._utils._element_size,
-        torch.seed,
-        operator.mod,
-        torch._functorch.vmap._validate_and_get_batch_size,
-        # some mac builds are missing torch.distributed.get_rank()
-        getattr(torch.distributed, "get_rank", _missing),
-        getattr(torch.distributed, "get_world_size", _missing),
-        # This always wants to be in the graph, even if the constraint
-        # results in a constant int
-        torch._constrain_as_size,
-    ]:
+    elif isinstance(example_value, int) and (
+        proxy.node.target
+        in [
+            torch.sym_int,
+            getattr,
+            operator.getitem,
+            torch._utils._element_size,
+            torch.seed,
+            operator.mod,
+            torch._functorch.vmap._validate_and_get_batch_size,
+            # some mac builds are missing torch.distributed.get_rank()
+            getattr(torch.distributed, "get_rank", _missing),
+            getattr(torch.distributed, "get_world_size", _missing),
+            # This always wants to be in the graph, even if the constraint
+            # results in a constant int
+            torch._constrain_as_size,
+        ]
+        or (
+            # TODO: this is a little sus, because we didn't check what the self is
+            proxy.node.op == "call_method"
+            and proxy.node.target in ["bit_length"]
+        )
+    ):
         set_example_value(proxy.node, example_value)
         return ConstantVariable.create(example_value, **options)
     elif isinstance(example_value, torch.backends.cuda.SDPAParams):
```
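The restructured condition adds exactly one case: an fx `call_method` node for `bit_length` whose example value is an int is now also wrapped as a constant. A simplified restatement of the predicate (the target names below are abbreviated stand-ins; the real list is the one in the diff):

```python
# Simplified restatement of the new branch condition; the real code matches
# proxy.node.op / proxy.node.target against the full target list shown above.
def wraps_as_constant_int(example_value, node_op, node_target) -> bool:
    int_targets = {"sym_int", "getitem", "seed"}  # abbreviated stand-ins
    return isinstance(example_value, int) and (
        node_target in int_targets
        or (node_op == "call_method" and node_target == "bit_length")
    )


assert wraps_as_constant_int(5, "call_method", "bit_length")
assert not wraps_as_constant_int(5, "call_function", "randn")
```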
torch/_higher_order_ops/triton_kernel_wrap.py

```diff
@@ -9,6 +9,7 @@ import typing
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Union
 
 import torch
 import torch.fx as fx
 
+import torch.utils._pytree as pytree
```
```diff
@@ -932,7 +933,14 @@ class TritonHOPifier:
             self.raise_unsupported("Grid can have at most rank 3")
 
         assert len(grids) != 0
-        if len(set(grids)) == 1:
+
+        def intify(x):
+            if isinstance(x, torch.SymInt):
+                return int(x)
+            else:
+                return x
+
+        if len(set(pytree.tree_map(intify, grids))) == 1:
             # If there's only one unique grid, lets simplify
             grids = [grids[0]]
 
```
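The `intify` helper exists because `set(grids)` now blows up when a grid entry is a SymInt (hashing raises), so grid leaves are mapped to plain ints before deduplication. A minimal sketch of the mechanism with plain values standing in for SymInts:

```python
import torch.utils._pytree as pytree

# Stand-in grids; in the real code some entries may be SymInts, which a bare
# set(grids) could no longer hash. Mapping leaves to int first (specializing
# them) restores set-based deduplication.
grids = [(128, 1, 1), (128, 1, 1)]
unique = set(pytree.tree_map(lambda x: int(x), grids))
assert len(unique) == 1  # only one distinct grid, so it can be simplified
```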
torch/_numpy/_util.py

```diff
@@ -90,7 +90,7 @@ def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
         pass
     # Going via an iterator directly is slower than via list comprehension.
     axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
-    if not allow_duplicate and len(set(axis)) != len(axis):
+    if not allow_duplicate and len(set(map(int, axis))) != len(axis):
         if argname:
             raise ValueError(f"repeated axis in `{argname}` argument")
         else:
```
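Same workaround here: the duplicate-axis check builds its set from `int(ax)` so that a symbolic axis value never hits `SymInt.__hash__`. A tiny sketch with plain ints:

```python
# Converting through int() before hashing sidesteps SymInt.__hash__ while
# preserving the duplicate check (int(ax) == ax for already-plain ints).
axis = (0, 1, 1)
assert len(set(map(int, axis))) != len(axis)  # duplicate axis detected
```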