mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Relax equality check (#165460)
When an object is inherited from multiple types, the previous check would fail. So we should relax it to respect eager semantic Differential Revision: [D84635322](https://our.internmc.facebook.com/intern/diff/D84635322) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165460 Approved by: https://github.com/avikchaudhuri
This commit is contained in:
committed by
PyTorch MergeBot
parent
0aa7ebaf03
commit
2395d7d7da
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
# ruff: noqa: TRY002
|
# ruff: noqa: TRY002
|
||||||
|
|
||||||
|
import enum
|
||||||
import itertools
|
import itertools
|
||||||
import operator
|
import operator
|
||||||
import types
|
import types
|
||||||
@ -56,6 +57,30 @@ class DictTests(torch._dynamo.test_case.TestCase):
|
|||||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||||
self.assertEqual(fn(x), opt_fn(x))
|
self.assertEqual(fn(x), opt_fn(x))
|
||||||
|
|
||||||
|
def test_dict_contains_enum(self):
|
||||||
|
class TensorDim(str, enum.Enum):
|
||||||
|
DDP = "ddp"
|
||||||
|
FSDP = "fsdp"
|
||||||
|
CP = "cp"
|
||||||
|
TP = "tp"
|
||||||
|
|
||||||
|
class Foo(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
val = x.sin()
|
||||||
|
if TensorDim.DDP in {"ddp"}:
|
||||||
|
val += x.cos()
|
||||||
|
if "ddp" in {TensorDim.DDP}:
|
||||||
|
val += x.cos()
|
||||||
|
return val
|
||||||
|
|
||||||
|
inp = torch.randn(4, 4)
|
||||||
|
mod = Foo()
|
||||||
|
opt_f = torch.compile(mod)
|
||||||
|
self.assertEqual(mod(inp), opt_f(inp))
|
||||||
|
|
||||||
def test_dict_subclass_local_with_non_dict_method(self):
|
def test_dict_subclass_local_with_non_dict_method(self):
|
||||||
# Checks that add_1 method is inlined
|
# Checks that add_1 method is inlined
|
||||||
class MethodDict(dict):
|
class MethodDict(dict):
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import enum
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
@ -15191,6 +15192,45 @@ graph():
|
|||||||
filtered_nn_module_stack[1], "mod_list_2.slice(4, 5, None).0"
|
filtered_nn_module_stack[1], "mod_list_2.slice(4, 5, None).0"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_enum_str(self):
|
||||||
|
class TensorDim(str, enum.Enum):
|
||||||
|
DDP = "ddp"
|
||||||
|
FSDP = "fsdp"
|
||||||
|
CP = "cp"
|
||||||
|
TP = "tp"
|
||||||
|
|
||||||
|
class Foo(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
val = x.sin()
|
||||||
|
if TensorDim.DDP in {"ddp"}:
|
||||||
|
val += x.cos()
|
||||||
|
if "ddp" in {TensorDim.DDP}:
|
||||||
|
val += x.cos()
|
||||||
|
return val
|
||||||
|
|
||||||
|
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
|
||||||
|
|
||||||
|
inp = torch.randn(4, 4)
|
||||||
|
gm = export(Foo(), (inp,)).run_decompositions().module()
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(gm.graph).strip(),
|
||||||
|
"""\
|
||||||
|
graph():
|
||||||
|
%x : [num_users=4] = placeholder[target=x]
|
||||||
|
%_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
|
||||||
|
%sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%x,), kwargs = {})
|
||||||
|
%cos : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%x,), kwargs = {})
|
||||||
|
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%sin, %cos), kwargs = {})
|
||||||
|
%cos_1 : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%x,), kwargs = {})
|
||||||
|
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %cos_1), kwargs = {})
|
||||||
|
return (add_1,)""",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(gm(inp), Foo()(inp))
|
||||||
|
|
||||||
def test_split_const_gm_with_lifted_constants(self):
|
def test_split_const_gm_with_lifted_constants(self):
|
||||||
class Model(torch.nn.Module):
|
class Model(torch.nn.Module):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
@ -197,9 +197,11 @@ class ConstDictVariable(VariableTracker):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _eq_impl(a, b):
|
def _eq_impl(a, b):
|
||||||
# TODO: Put this in utils and share it between variables/builtin.py and here
|
# TODO: Put this in utils and share it between variables/builtin.py and here
|
||||||
if type(a) is not type(b):
|
type_a, type_b = type(a), type(b)
|
||||||
|
if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
|
||||||
return False
|
return False
|
||||||
elif isinstance(a, tuple):
|
|
||||||
|
if isinstance(a, tuple):
|
||||||
Hashable = ConstDictVariable._HashableTracker
|
Hashable = ConstDictVariable._HashableTracker
|
||||||
return len(a) == len(b) and all(
|
return len(a) == len(b) and all(
|
||||||
Hashable._eq_impl(u, v) for u, v in zip(a, b)
|
Hashable._eq_impl(u, v) for u, v in zip(a, b)
|
||||||
|
Reference in New Issue
Block a user