mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Relax equality check (#165460)
When an object is inherited from multiple types, the previous check would fail. So we should relax it to respect eager semantic Differential Revision: [D84635322](https://our.internmc.facebook.com/intern/diff/D84635322) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165460 Approved by: https://github.com/avikchaudhuri
This commit is contained in:
committed by
PyTorch MergeBot
parent
0aa7ebaf03
commit
2395d7d7da
@ -2,6 +2,7 @@
|
||||
|
||||
# ruff: noqa: TRY002
|
||||
|
||||
import enum
|
||||
import itertools
|
||||
import operator
|
||||
import types
|
||||
@ -56,6 +57,30 @@ class DictTests(torch._dynamo.test_case.TestCase):
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
def test_dict_contains_enum(self):
|
||||
class TensorDim(str, enum.Enum):
|
||||
DDP = "ddp"
|
||||
FSDP = "fsdp"
|
||||
CP = "cp"
|
||||
TP = "tp"
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
val = x.sin()
|
||||
if TensorDim.DDP in {"ddp"}:
|
||||
val += x.cos()
|
||||
if "ddp" in {TensorDim.DDP}:
|
||||
val += x.cos()
|
||||
return val
|
||||
|
||||
inp = torch.randn(4, 4)
|
||||
mod = Foo()
|
||||
opt_f = torch.compile(mod)
|
||||
self.assertEqual(mod(inp), opt_f(inp))
|
||||
|
||||
def test_dict_subclass_local_with_non_dict_method(self):
|
||||
# Checks that add_1 method is inlined
|
||||
class MethodDict(dict):
|
||||
|
@ -4,6 +4,7 @@
|
||||
import contextlib
|
||||
import copy
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
@ -15191,6 +15192,45 @@ graph():
|
||||
filtered_nn_module_stack[1], "mod_list_2.slice(4, 5, None).0"
|
||||
)
|
||||
|
||||
def test_enum_str(self):
|
||||
class TensorDim(str, enum.Enum):
|
||||
DDP = "ddp"
|
||||
FSDP = "fsdp"
|
||||
CP = "cp"
|
||||
TP = "tp"
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
val = x.sin()
|
||||
if TensorDim.DDP in {"ddp"}:
|
||||
val += x.cos()
|
||||
if "ddp" in {TensorDim.DDP}:
|
||||
val += x.cos()
|
||||
return val
|
||||
|
||||
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
|
||||
|
||||
inp = torch.randn(4, 4)
|
||||
gm = export(Foo(), (inp,)).run_decompositions().module()
|
||||
self.assertExpectedInline(
|
||||
str(gm.graph).strip(),
|
||||
"""\
|
||||
graph():
|
||||
%x : [num_users=4] = placeholder[target=x]
|
||||
%_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
|
||||
%sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%x,), kwargs = {})
|
||||
%cos : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%x,), kwargs = {})
|
||||
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sin, %cos), kwargs = {})
|
||||
%cos_1 : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%x,), kwargs = {})
|
||||
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %cos_1), kwargs = {})
|
||||
return (add_1,)""",
|
||||
)
|
||||
|
||||
self.assertEqual(gm(inp), Foo()(inp))
|
||||
|
||||
def test_split_const_gm_with_lifted_constants(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
@ -197,9 +197,11 @@ class ConstDictVariable(VariableTracker):
|
||||
@staticmethod
|
||||
def _eq_impl(a, b):
|
||||
# TODO: Put this in utils and share it between variables/builtin.py and here
|
||||
if type(a) is not type(b):
|
||||
type_a, type_b = type(a), type(b)
|
||||
if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
|
||||
return False
|
||||
elif isinstance(a, tuple):
|
||||
|
||||
if isinstance(a, tuple):
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
return len(a) == len(b) and all(
|
||||
Hashable._eq_impl(u, v) for u, v in zip(a, b)
|
||||
|
Reference in New Issue
Block a user