Relax equality check (#165460)

When an object inherits from multiple types (for example a str-based enum member), the previous exact-type equality check would fail. So we relax it to respect eager semantics.
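In eager Python, a member of a str-based enum.Enum hashes and compares equal to its plain-string value even though the exact types differ, which is the case the old exact-type check rejected. A minimal sketch of that eager behavior (TensorDim is the same enum used in the new tests):

import enum


class TensorDim(str, enum.Enum):
    DDP = "ddp"


# Eager semantics: the str mixin supplies __eq__ and __hash__, so the enum
# member and its value are interchangeable in equality and membership checks.
assert TensorDim.DDP == "ddp"
assert TensorDim.DDP in {"ddp"}
assert "ddp" in {TensorDim.DDP}

# Yet the exact types differ, which is what the previous check keyed on.
assert type(TensorDim.DDP) is not str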

Differential Revision: [D84635322](https://our.internmc.facebook.com/intern/diff/D84635322)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165460
Approved by: https://github.com/avikchaudhuri
Author: Tugsbayasgalan Manlaibaatar
Date: 2025-10-15 08:00:22 -07:00
Committed by: PyTorch MergeBot
Parent: 0aa7ebaf03
Commit: 2395d7d7da

3 changed files with 69 additions and 2 deletions


@@ -2,6 +2,7 @@
 # ruff: noqa: TRY002
+import enum
 import itertools
 import operator
 import types
@@ -56,6 +57,30 @@ class DictTests(torch._dynamo.test_case.TestCase):
         opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
         self.assertEqual(fn(x), opt_fn(x))
 
+    def test_dict_contains_enum(self):
+        class TensorDim(str, enum.Enum):
+            DDP = "ddp"
+            FSDP = "fsdp"
+            CP = "cp"
+            TP = "tp"
+
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                val = x.sin()
+                if TensorDim.DDP in {"ddp"}:
+                    val += x.cos()
+                if "ddp" in {TensorDim.DDP}:
+                    val += x.cos()
+                return val
+
+        inp = torch.randn(4, 4)
+        mod = Foo()
+        opt_f = torch.compile(mod)
+        self.assertEqual(mod(inp), opt_f(inp))
+
     def test_dict_subclass_local_with_non_dict_method(self):
         # Checks that add_1 method is inlined
         class MethodDict(dict):


@@ -4,6 +4,7 @@
 import contextlib
 import copy
 import dataclasses
+import enum
 import functools
 import logging
 import math
@@ -15191,6 +15192,45 @@ graph():
             filtered_nn_module_stack[1], "mod_list_2.slice(4, 5, None).0"
         )
 
+    def test_enum_str(self):
+        class TensorDim(str, enum.Enum):
+            DDP = "ddp"
+            FSDP = "fsdp"
+            CP = "cp"
+            TP = "tp"
+
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                val = x.sin()
+                if TensorDim.DDP in {"ddp"}:
+                    val += x.cos()
+                if "ddp" in {TensorDim.DDP}:
+                    val += x.cos()
+                return val
+
+        from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
+
+        inp = torch.randn(4, 4)
+        gm = export(Foo(), (inp,)).run_decompositions().module()
+        self.assertExpectedInline(
+            str(gm.graph).strip(),
+            """\
+graph():
+    %x : [num_users=4] = placeholder[target=x]
+    %_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
+    %sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%x,), kwargs = {})
+    %cos : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%x,), kwargs = {})
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sin, %cos), kwargs = {})
+    %cos_1 : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%x,), kwargs = {})
+    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %cos_1), kwargs = {})
+    return (add_1,)""",
+        )
+        self.assertEqual(gm(inp), Foo()(inp))
+
     def test_split_const_gm_with_lifted_constants(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:


@@ -197,9 +197,11 @@ class ConstDictVariable(VariableTracker):
        @staticmethod
        def _eq_impl(a, b):
            # TODO: Put this in utils and share it between variables/builtin.py and here
-           if type(a) is not type(b):
+           type_a, type_b = type(a), type(b)
+           if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
                return False
-           elif isinstance(a, tuple):
+           if isinstance(a, tuple):
                Hashable = ConstDictVariable._HashableTracker
                return len(a) == len(b) and all(
                    Hashable._eq_impl(u, v) for u, v in zip(a, b)
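
For reference, a standalone sketch of the behavioral difference (the helper functions below are illustrative and not part of the Dynamo code):

import enum


class TensorDim(str, enum.Enum):
    DDP = "ddp"


def exact_type_match(a, b):
    # Old check: both sides must have exactly the same type.
    return type(a) is type(b)


def relaxed_type_match(a, b):
    # New check: a subclass relationship in either direction is enough,
    # which matches how eager Python treats str-Enum members and strings.
    type_a, type_b = type(a), type(b)
    return issubclass(type_a, type_b) or issubclass(type_b, type_a)


print(exact_type_match(TensorDim.DDP, "ddp"))    # False -> compiled result diverged from eager
print(relaxed_type_match(TensorDim.DDP, "ddp"))  # True  -> agrees with eager membership checks
print(relaxed_type_match(1, "ddp"))              # False -> unrelated types are still rejected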