[dynamo] Support user defined dicts (#143548)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143548
Approved by: https://github.com/yanboliang, https://github.com/jansel, https://github.com/williamwen42
Animesh Jain
2024-12-20 08:19:50 -08:00
committed by PyTorch MergeBot
parent 9cb743d1f9
commit 4627cfd1f9
11 changed files with 944 additions and 13 deletions
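In short, torch.compile can now trace functions that construct and mutate subclasses of dict and OrderedDict instead of raising Unsupported ("call_function UserDefinedClassVariable", as the updated test in test_misc.py below shows). A minimal sketch of the newly supported pattern, distilled from the added tests (SimpleDict and the eager backend come from test/dynamo/test_dicts.py; nothing else here is part of the commit itself):

import torch

class SimpleDict(dict):
    pass

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    sd = SimpleDict(x=5)      # construct a dict subclass inside the compiled region
    sd["y"] = 10              # mutate it
    return sd["x"] * x + sd["y"]

fn(torch.randn(4))            # compiles with fullgraph=True; previously a hard error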

test/dynamo/test_dicts.py (new file, 759 lines added)
View File

@@ -0,0 +1,759 @@
# Owner(s): ["module: dynamo"]
# ruff: noqa: TRY002
# flake8: noqa
import dataclasses
from collections import OrderedDict
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Optional, Tuple
import torch
import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
import torch.nn
import torch.utils.checkpoint
from torch._dynamo.testing import same
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase
class SimpleDict(dict):
pass
class DictTests(torch._dynamo.test_case.TestCase):
def test_dict_subclass_instantiation(self):
def fn(x):
sd = SimpleDict(x=5)
return sd["x"] * x
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(x), opt_fn(x))
def test_dict_subclass_local_mutation(self):
def fn(x):
sd = SimpleDict(x=5)
z = sd["x"] * x
sd["x"] = 10
return z * sd["x"]
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(x), opt_fn(x))
def test_dict_subclass_local_with_non_dict_method(self):
# Checks that add_1 method is inlined
class MethodDict(dict):
def add_1(self, x):
return x + 1
def fn(x):
sd = MethodDict(x=5)
z = sd["x"] * x
sd["x"] = 10
return sd.add_1(z * sd["x"])
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(x), opt_fn(x))
def test_dict_subclass_methods_fallback_readonly(self):
sd = SimpleDict()
sd[2] = 5
sd[4] = 10
# check that regular attr accesses work well
sd.attr = 4
def fn(x):
for value in sd.values():
x = x * value
for key in sd.keys():
x = x * key
for k, v in sd.items():
x = x * k
x = x * v
# for k in sd:
# x = x * k
if 1 in sd:
x = x * 2
else:
x = x * 3
x = x * sd.get(2, 0)
x = x * sd.get(3, 4)
x = len(sd) * x
x = x * sd.attr
return x
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(x), opt_fn(x))
# Ensure a recompilation
sd[6] = 15
self.assertEqual(fn(x), opt_fn(x))
def test_dict_subclass_instantiation_return(self):
def fn(x):
sd = SimpleDict(x=5 * x)
sd["y"] = 10
return sd
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(type(ref), type(res))
self.assertEqual(ref["x"], res["x"])
self.assertEqual(ref["y"], res["y"])
def test_dict_subclass_methods_fallback_mutation(self):
def fn(sd, x):
for value in sd.values():
x = x * value
sd[6] = 14
for key in sd.keys():
x = x * key
for k, v in sd.items():
x = x * k
x = x * v
# for k in sd:
# x = x * k
if 1 in sd:
x = x * 2
else:
x = x * 3
x = x * sd.get(2, 0)
x = x * sd.get(3, 4)
x = len(sd) * x
x = x * sd.attr
sd.attr = 10
x = x * sd.attr
return x
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
sd1 = SimpleDict()
sd1[2] = 5
sd1[4] = 10
sd1.attr = 4
sd2 = SimpleDict()
sd2[2] = 5
sd2[4] = 10
sd2.attr = 4
self.assertTrue(sd1 == sd2)
self.assertEqual(fn(sd1, x), opt_fn(sd2, x))
self.assertTrue(sd1 == sd2)
self.assertTrue(sd1.attr == sd2.attr)
def test_dict_subclass_setitem(self):
class SetItemDict(dict):
def __setitem__(self, key, value):
super().__setitem__(key, value + 1)
def fn(x):
sd = SetItemDict(x=5 * x)
sd["y"] = 10
return sd
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(type(ref), type(res))
self.assertEqual(ref["x"], res["x"])
self.assertEqual(ref["y"], res["y"])
def is_tensor(x):
import torch
return isinstance(x, torch.Tensor)
class ModelOutput(OrderedDict):
"""
Copied from transformers.
"""
def __init_subclass__(cls) -> None:
"""Register subclasses as pytree nodes.
This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
`static_graph=True` with modules that output `ModelOutput` subclasses.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Subclasses of ModelOutput must use the @dataclass decorator
# This check is done in __init__ because the @dataclass decorator operates after __init_subclass__
# issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed
# Just need to check that the current class is not ModelOutput
is_modeloutput_subclass = self.__class__ != ModelOutput
if is_modeloutput_subclass and not is_dataclass(self):
raise TypeError(
f"{self.__module__}.{self.__class__.__name__} is not a dataclasss."
" This is a subclass of ModelOutput and so must use the @dataclass decorator."
)
def __post_init__(self):
"""Check the ModelOutput dataclass.
Only occurs if @dataclass decorator has been used.
"""
class_fields = fields(self)
# Safety and consistency checks
if not len(class_fields):
raise ValueError(f"{self.__class__.__name__} has no fields.")
if not all(field.default is None for field in class_fields[1:]):
raise ValueError(
f"{self.__class__.__name__} should not have more than one required field."
)
first_field = getattr(self, class_fields[0].name)
other_fields_are_none = all(
getattr(self, field.name) is None for field in class_fields[1:]
)
if other_fields_are_none and not is_tensor(first_field):
if isinstance(first_field, dict):
iterator = first_field.items()
first_field_iterator = True
else:
try:
iterator = iter(first_field)
first_field_iterator = True
except TypeError:
first_field_iterator = False
# if we provided an iterator as first field and the iterator is a (key, value) iterator
# set the associated fields
if first_field_iterator:
for idx, element in enumerate(iterator):
if (
not isinstance(element, (list, tuple))
or not len(element) == 2
or not isinstance(element[0], str)
):
if idx == 0:
# If we do not have an iterator of key/values, set it as attribute
self[class_fields[0].name] = first_field
else:
# If we have a mixed iterator, raise an error
raise ValueError(
f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
)
break
setattr(self, element[0], element[1])
if element[1] is not None:
self[element[0]] = element[1]
elif first_field is not None:
self[class_fields[0].name] = first_field
else:
for field in class_fields:
v = getattr(self, field.name)
if v is not None:
self[field.name] = v
def __delitem__(self, *args, **kwargs):
raise Exception(
f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
)
def setdefault(self, *args, **kwargs):
raise Exception(
f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
)
def pop(self, *args, **kwargs):
raise Exception(
f"You cannot use ``pop`` on a {self.__class__.__name__} instance."
)
def update(self, *args, **kwargs):
raise Exception(
f"You cannot use ``update`` on a {self.__class__.__name__} instance."
)
def __getitem__(self, k):
if isinstance(k, str):
inner_dict = dict(self.items())
return inner_dict[k]
else:
return self.to_tuple()[k]
def __setattr__(self, name, value):
if name in self.keys() and value is not None:
# Don't call self.__setitem__ to avoid recursion errors
super().__setitem__(name, value)
super().__setattr__(name, value)
def __setitem__(self, key, value):
# Will raise a KeyException if needed
super().__setitem__(key, value)
# Don't call self.__setattr__ to avoid recursion errors
super().__setattr__(key, value)
def __reduce__(self):
if not is_dataclass(self):
return super().__reduce__()
callable, _args, *remaining = super().__reduce__()
args = tuple(getattr(self, field.name) for field in fields(self))
return callable, args, *remaining
def to_tuple(self) -> Tuple[Any]:
"""
Convert self to a tuple containing all the attributes/keys that are not `None`.
"""
return tuple(self[k] for k in self.keys())
@dataclass
class BaseModelOutput(ModelOutput):
"""
Copied from transformers
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class CausalLMOutputWithPast(ModelOutput):
"""
Copied from transformers
Base class for causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
"""
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) after further processing
through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
the classification token after processing through a linear layer and a tanh activation function. The linear
layer weights are trained from the next sentence prediction (classification) objective during pretraining.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
input) to speed up sequential decoding.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class TestModelOutput(torch._dynamo.test_case.TestCase):
def test_mo_create(self):
def fn(a, b):
tmp = BaseModelOutput(a + 1, attentions=b + 3)
return tmp
torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=2)
def test_mo_assign(self):
def fn(a, b):
tmp = BaseModelOutput(last_hidden_state=b + 3)
tmp.hidden_states = a + 7
tmp["attentions"] = a + b + 6
return tmp
args = [torch.randn(10), torch.randn(10)]
obj1 = fn(*args)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
obj2 = opt_fn(*args)
self.assertTrue(same(obj1.last_hidden_state, obj2.last_hidden_state))
self.assertTrue(same(obj1.hidden_states, obj2.hidden_states))
self.assertTrue(same(obj1.attentions, obj2.attentions))
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 4)
def _common(self, fn, op_count):
args = [
BaseModelOutput(
last_hidden_state=torch.randn(10), attentions=torch.randn(10)
)
]
obj1 = fn(*args)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
obj2 = opt_fn(*args)
self.assertTrue(same(obj1, obj2))
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, op_count)
def test_mo_getattr(self):
def fn(obj: BaseModelOutput):
x = obj.last_hidden_state * 10
if obj.hidden_states is not None:
x += obj.hidden_states
if obj.attentions is not None:
x += obj.attentions
return x
self._common(fn, 2)
def test_mo_getattr_missing(self):
def fn(obj: BaseModelOutput):
if getattr(obj, "asdf", None) is not None:
obj.asdf += 1
return obj.attentions + 1
self._common(fn, 1)
def test_mo_getitem(self):
def fn(obj: BaseModelOutput):
x = obj["last_hidden_state"] * 10
if "hidden_stats" in obj:
x += obj["hidden_states"]
if "attentions" in obj:
x += obj["attentions"]
return x
self._common(fn, 2)
def test_mo_tuple(self):
def fn(obj: BaseModelOutput):
a, b = obj.to_tuple()
return a + b * 10
self._common(fn, 2)
def test_mo_index(self):
def fn(obj: BaseModelOutput):
return obj[0] * 10 + obj[1]
self._common(fn, 2)
def test_mo_init(self):
@dataclasses.dataclass
class MyDataClass(ModelOutput):
a: torch.Tensor
b: torch.Tensor = None
c: torch.Tensor = None
d: torch.Tensor = None
e: torch.Tensor = None
def fn(obj):
class_fields = dataclasses.fields(obj)
assert len(class_fields)
assert all(field.default is None for field in class_fields[1:])
other_fields_are_none = all(
getattr(obj, field.name) is None for field in class_fields[1:]
)
assert not other_fields_are_none
total = getattr(obj, class_fields[0].name)
for field in class_fields[1:]:
v = getattr(obj, field.name)
if v is not None:
total += v
return total
tensors = [torch.randn(10), torch.randn(10), torch.randn(10)]
obj1 = MyDataClass(*tensors)
correct1 = fn(obj1)
obj2 = MyDataClass(*tensors)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts)
self.assertTrue(same(opt_fn(obj2), correct1))
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 2)
def test_mo_init2(self):
# this ModelOutput subclass runs a different __post_init__ codepath
@dataclasses.dataclass
class MyDataClass(ModelOutput):
x: torch.FloatTensor = None
def fn(x):
obj = MyDataClass(x=x * 5)
return obj
inp = torch.randn(3, 3)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(inp).x, opt_fn(inp).x)
def test_mo_init_with_disable(self):
# Can result in "non-function or method super: <slot wrapper '__setattr__' of 'object' objects>"
# graph breaks (although it may not be the first)
# Minimal repro for https://github.com/pytorch/pytorch/issues/126028
@dataclasses.dataclass
class MyDataClass(ModelOutput):
x: torch.FloatTensor = None
@torch._dynamo.disable(recursive=False)
def fn(x):
return MyDataClass(x=x)
inp = torch.randn(3, 3)
opt_fn = torch.compile(fn, backend="eager")
self.assertEqual(fn(inp).x, opt_fn(inp).x)
def test_mo_newkey(self):
obj = BaseModelOutput()
def fn(obj):
return obj["wwww"] + 1
inp = torch.randn(3, 3)
obj["wwww"] = inp
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(obj), opt_fn(obj))
def test_mo_from_outside(self):
def fn(obj):
return obj.attentions + 1
obj = BaseModelOutput(attentions=torch.randn(3, 3))
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(obj), opt_fn(obj))
def test_mo_reconstruct_bytecode(self):
def fn(inp):
return BaseModelOutput(attentions=inp + 1)
inp = torch.randn(3, 3)
opt_fn = torch.compile(fn, backend="eager")
self.assertEqual(fn(inp).attentions, opt_fn(inp).attentions)
def test_none(self):
class Model(torch.nn.Module):
def forward(self, x):
x = x + 1
return CausalLMOutputWithPast(loss=None, logits=x)[0]
model = Model()
opt_model = torch.compile(model, backend="eager", fullgraph=True)
x = torch.randn(1, 1, 1, 1)
self.assertTrue(same(model(x), opt_model(x)))
def test_reconstruction(self):
torch._export.utils.register_dataclass_as_pytree_node(
CausalLMOutputWithPast,
serialized_type_name="test_reconstruction_CausalLMOutputWithPast",
)
class Model(torch.nn.Module):
def forward(self, x):
x = x + 1
return CausalLMOutputWithPast(loss=x, logits=None)
model = Model()
x = torch.randn(1, 1, 1, 1)
eo = torch._dynamo.export(Model(), aten_graph=True)(x)
self.assertTrue(same(model(x), eo.graph_module(x)))
class TestModelOutputBert(TestCase):
def test_HF_bert_model_output(self, device):
class BertPooler(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.dense = torch.nn.Linear(768, 768).to(device)
self.activation = torch.nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertEncoder(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
hidden_states: torch.Tensor,
) -> BaseModelOutputWithPastAndCrossAttentions:
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=None,
hidden_states=None,
attentions=None,
cross_attentions=None,
)
class BertModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.encoder = BertEncoder()
self.pooler = BertPooler()
def forward(
self,
sequence_output: torch.Tensor,
) -> BaseModelOutputWithPoolingAndCrossAttentions:
encoder_outputs = self.encoder(sequence_output)
# test __getitem__ and to_tuple
sequence_output = encoder_outputs[0]
pooled_output = (
self.pooler(sequence_output) if self.pooler is not None else None
)
# test CustomDictVariable.create
result = BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
# test __setattr__
result.pooler_output = pooled_output
# test __setitem__
result["pooler_output"] = pooled_output
return result
sequence_output = torch.rand(1, 12, 768).to(device)
model = BertModel()
orig_result = model(sequence_output)
compiled_model = torch.compile(model, backend="eager")
compiled_result = compiled_model(sequence_output)
self.assertTrue(
torch.allclose(
orig_result.last_hidden_state, compiled_result.last_hidden_state
)
)
self.assertTrue(
torch.allclose(orig_result.pooler_output, compiled_result.pooler_output)
)
devices = ["cpu"]
instantiate_device_type_tests(TestModelOutputBert, globals(), only_for=devices)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@@ -10599,7 +10599,7 @@ ShapeEnv not equal: field values don't match:
foo()
- def test_dict_subclass_cannot_be_initialized_in_graph(self):
+ def test_dict_subclass_initialization_in_graph(self):
for super_class in (
collections.OrderedDict,
dict,
@@ -10615,11 +10615,10 @@ ShapeEnv not equal: field values don't match:
assert "key" in c
return c["key"] + 1
- fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
- with self.assertRaisesRegex(
- torch._dynamo.exc.Unsupported, "call_function UserDefinedClassVariable"
- ):
- print(fn_opt(torch.zeros(1)))
+ opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+ x = torch.rand(4)
+ self.assertEqual(fn(x), opt_fn(x))
@wrapDeterministicFlagAPITest
def test_backward_deterministic_mode_mismatch_warning(self):

View File

@@ -313,6 +313,7 @@ class SerializationMixin:
not TEST_DILL or not HAS_DILL_AT_LEAST_0_3_1,
'"dill" not found or not correct version'
)
@skipIfTorchDynamo("Different behavior between 3.11 and 3.13, causing CI issues")
def test_serialization_dill(self):
x = torch.randn(5, 5)

View File

@@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
import collections
import contextlib
import functools
import inspect
@@ -20,7 +21,7 @@ from .bytecode_transformation import (
from .codegen import PyCodegen
from .exc import unimplemented
from .source import GlobalSource, LocalCellSource, LocalSource, Source
- from .utils import is_frozen_dataclass, nn_module_new, object_new
+ from .utils import dict_new, is_frozen_dataclass, nn_module_new, object_new
from .variables.base import (
AttributeMutation,
AttributeMutationExisting,
@@ -37,6 +38,17 @@ def _manual_update_dict(dict_from, dict_to):
dict_to[k] = v
def _manual_dict_setitem(dict_from, dict_to, mro_index):
# Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have
# to be careful because we don't want to trigger the user defined object
# setitem or clear. The mro_index is used to find the dict/OrderedDict from
# the class mro.
dict_class = type(dict_to).__mro__[mro_index]
dict_class.clear(dict_to)
for k, v in dict_from.items():
dict_class.__setitem__(dict_to, k, v)
class SideEffects:
"""
Track side effects (list mutation, setattr, etc) that need to be
@@ -181,9 +193,9 @@ class SideEffects:
@staticmethod
def cls_supports_mutation_side_effects(cls):
- return (
- inspect.getattr_static(cls, "__getattribute__", None)
- is object.__getattribute__
+ return inspect.getattr_static(cls, "__getattribute__", None) in (
+ object.__getattribute__,
+ dict.__getattribute__,
)
def is_attribute_mutation(self, item):
@@ -254,6 +266,8 @@ class SideEffects:
obj = torch.autograd.Function()
elif issubclass(user_cls, torch.nn.Module):
obj = nn_module_new(user_cls)
elif issubclass(user_cls, (dict, collections.OrderedDict)):
obj = dict_new(user_cls)
else:
try:
obj = object_new(user_cls)
@@ -284,6 +298,8 @@ class SideEffects:
] = variables.UserDefinedObjectVariable
if issubclass(user_cls, torch.nn.Module):
variable_cls = variables.UnspecializedNNModuleVariable
elif issubclass(user_cls, (dict, collections.OrderedDict)):
variable_cls = variables.UserDefinedDictVariable
elif issubclass(user_cls, MutableMapping):
variable_cls = variables.MutableMappingVariable
elif is_frozen_dataclass(user_cls):
@@ -442,7 +458,12 @@ class SideEffects:
if isinstance(var, variables.AutogradFunctionContextVariable):
unimplemented("AutogradFunctionContextVariable escaped")
cg.add_push_null(
- lambda: cg.load_import_from(utils.__name__, "object_new")
+ lambda: cg.load_import_from(
+ utils.__name__,
+ "dict_new"
+ if isinstance(var, variables.UserDefinedDictVariable)
+ else "object_new",
+ )
)
cg(var.mutation_type.cls_source)
cg.extend_output(create_call_function(1, False))
@@ -695,6 +716,58 @@ class SideEffects:
suffixes.append([cg.create_store_deref(var.local_name)])
elif self.is_attribute_mutation(var):
if isinstance(var, variables.UserDefinedDictVariable):
# Do dict related update manually here. The store_attr
# mutations will be applied later.
varname_map = {}
for name in _manual_dict_setitem.__code__.co_varnames:
varname_map[name] = cg.tx.output.new_var()
try:
mro_index = type(var.value).__mro__.index(
collections.OrderedDict
)
except ValueError:
mro_index = type(var.value).__mro__.index(dict)
cg.extend_output(
[
create_instruction("LOAD_CONST", argval=mro_index),
create_instruction(
"STORE_FAST", argval=varname_map["mro_index"]
),
]
)
cg(var.source) # type: ignore[attr-defined]
cg.extend_output(
[
create_instruction(
"STORE_FAST", argval=varname_map["dict_to"]
)
]
)
cg(var._dict_vt, allow_cache=False) # Don't codegen via source
cg.extend_output(
[
create_instruction(
"STORE_FAST", argval=varname_map["dict_from"]
)
]
)
dict_update_insts = bytecode_from_template(
_manual_dict_setitem, varname_map=varname_map
)
suffixes.append(
[
*dict_update_insts,
create_instruction("POP_TOP"),
]
)
# Applying mutations involves two steps: 1) Push all
# reconstructed objects onto the stack. 2) Call STORE_ATTR to
# apply the mutations.
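The reason _manual_dict_setitem above walks the MRO is to replay recorded mutations with the plain dict/OrderedDict clear and __setitem__, without re-triggering any user-defined overrides. A stand-alone sketch of that trick (LoudDict is a hypothetical class used only for illustration, not part of the commit):

class LoudDict(dict):
    def __setitem__(self, key, value):
        print("user __setitem__ called")
        super().__setitem__(key, value)

d = LoudDict()
d["a"] = 1                                    # goes through the user override (prints)
mro_index = type(d).__mro__.index(dict)       # same lookup the codegen stores via varname_map
type(d).__mro__[mro_index].__setitem__(d, "b", 2)   # dict.__setitem__ directly, no print
assert d == {"a": 1, "b": 2}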

View File

@@ -1896,6 +1896,14 @@ tuple_iterator: Type[Iterator[Any]] = type(iter(()))
range_iterator: Type[Iterator[Any]] = type(iter(range(0)))
tuple_iterator_len = tuple_iterator.__length_hint__ # type: ignore[attr-defined]
object_new = object.__new__
dict_new = dict.__new__
dict_methods = {
method
for method in itertools.chain(
dict.__dict__.values(), collections.OrderedDict.__dict__.values()
)
if callable(method)
}
def nn_module_new(cls):
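dict_new is the dict analogue of object_new: assuming standard CPython semantics, dict.__new__(cls) builds an instance of the subclass without running its __init__, which gives SideEffects an empty shell to replay mutations into (see the dict_new branch in side_effects.py above). A small sketch, with ShellDict as a hypothetical class for illustration:

class ShellDict(dict):
    def __init__(self, *args, **kwargs):
        raise AssertionError("never called by dict.__new__")

obj = dict.__new__(ShellDict)        # no __init__ call, so no exception
assert type(obj) is ShellDict and len(obj) == 0
obj["x"] = 5                         # ready to receive replayed mutations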

View File

@@ -113,6 +113,7 @@ from .user_defined import (
MutableMappingVariable,
RemovableHandleVariable,
UserDefinedClassVariable,
UserDefinedDictVariable,
UserDefinedObjectVariable,
)

View File

@@ -227,6 +227,7 @@ from .user_defined import (
MutableMappingVariable,
SourcelessGraphModuleVariable,
UserDefinedClassVariable,
UserDefinedDictVariable,
UserDefinedObjectVariable,
)
@@ -1232,6 +1233,46 @@ class VariableBuilder:
fake_script_obj,
source=self.source,
)
elif (
isinstance(value, (dict, collections.OrderedDict))
and type(value).__new__ is dict.__new__
):
# Construct a dict_vt that will reside inside the UserDefinedDictVariable
self.install_guards(GuardBuilder.TYPE_MATCH)
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
# Guard on the key order
self.tx.output.guard_on_key_order.add(self.source.name())
# We need all the keys to be hashable. We do this within the
# _HashableTracker class in dicts.py
def build_key_value(i, k, v):
source_key = ConstDictKeySource(self.get_source(), i)
key = LazyVariableTracker.create(k, source_key)
source_value = GetItemSource(self.get_source(), source_key)
value = LazyVariableTracker.create(v, source_value)
return key, value
result = dict(
build_key_value(i, k, v) for i, (k, v) in enumerate(value.items())
)
# NB: This is deliberately kept ValueMutationNew because dict_vt is
# an internal representation. dict_vt tracks the mutation on the
# dict side. side_effects infra uses the UserDefinedDictVariable to
# apply side-effects of this dict_vt.
dict_vt = ConstDictVariable(
result,
user_cls=collections.OrderedDict
if isinstance(value, collections.OrderedDict)
else dict,
mutation_type=ValueMutationNew(),
)
result = UserDefinedDictVariable(value, dict_vt=dict_vt, source=self.source)
return self.tx.output.side_effects.track_object_existing(value, result)
elif issubclass(type(value), MutableMapping):
self.install_guards(GuardBuilder.TYPE_MATCH)
return MutableMappingVariable(value, source=self.source)
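Putting the builder pieces together: a pre-existing dict subclass instance passed into a compiled function is wrapped as a UserDefinedDictVariable, guarded on type, length, and key order, and mutations performed in the graph are replayed on the original object afterwards. A condensed sketch that mirrors test_dict_subclass_methods_fallback_mutation from the new test file (eager backend for illustration):

import torch

class SimpleDict(dict):
    pass

@torch.compile(backend="eager", fullgraph=True)
def fn(sd, x):
    sd[6] = 14                        # mutate the input dict subclass
    sd.attr = 10                      # regular attribute mutation still works
    return x * sd.get(2, 0) * len(sd)

sd = SimpleDict({2: 5, 4: 10})
sd.attr = 4
out = fn(sd, torch.randn(4))
assert sd[6] == 14 and sd.attr == 10  # side effects land on the real object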

View File

@@ -283,7 +283,14 @@ class ConstDictVariable(VariableTracker):
arg_hashable = args and is_hashable(args[0])
if name == "__getitem__":
if name == "__init__":
temp_dict_vt = variables.BuiltinVariable(dict).call_dict(
tx, *args, **kwargs
)
tx.output.side_effects.mutation(self)
self.items.update(temp_dict_vt.items)
return ConstantVariable.create(None)
elif name == "__getitem__":
assert len(args) == 1
return self.getitem_const_raise_exception_if_absent(tx, args[0])
elif name == "items":

View File

@@ -236,6 +236,11 @@ class SuperVariable(VariableTracker):
self.objvar, attr, variables.DeletedVariable()
)
return variables.ConstantVariable(None)
elif (
isinstance(self.objvar, variables.UserDefinedDictVariable)
and inner_fn in self.objvar._dict_methods
):
return self.objvar._dict_vt.call_method(tx, name, args, kwargs)
unimplemented(f"non-function or method super: {inner_fn}")

View File

@@ -42,6 +42,7 @@ from ..utils import (
build_checkpoint_variable,
build_invoke_subgraph_variable,
check_constant_args,
dict_methods,
get_custom_getattr,
has_torch_function,
is_frozen_dataclass,
@@ -637,7 +638,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
new_fn = inspect.getattr_static(self.value, "__new__", None)
if isinstance(new_fn, staticmethod):
new_fn = new_fn.__func__
- return new_fn in (object.__new__, Generic.__new__)
+ return new_fn in (object.__new__, Generic.__new__, dict.__new__)
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
if self.source:
@@ -1441,6 +1442,42 @@ class RemovableHandleVariable(VariableTracker):
return RemovableHandleClass
class UserDefinedDictVariable(UserDefinedObjectVariable):
"""
Represents user defined objects that are subclasses of dict/OrderedDict.
Internally, it uses a ConstDictVariable to represent the dict part of the
variable tracker. For everything else, it falls back to
UserDefinedObjectVariable.
"""
_nonvar_fields = UserDefinedObjectVariable._nonvar_fields
def __init__(self, value, dict_vt=None, **kwargs):
super().__init__(value, **kwargs)
self._dict_vt = dict_vt
if self._dict_vt is None:
assert (
self.source is None
), "dict_vt must be constructed by builder.py when source is present"
self._dict_vt = variables.ConstDictVariable(
{}, mutation_type=ValueMutationNew()
)
self._dict_methods = dict_methods
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
method = self._maybe_get_baseclass_method(name)
if method in self._dict_methods:
return self._dict_vt.call_method(tx, name, args, kwargs)
return super().call_method(tx, name, args, kwargs)
class MutableMappingVariable(UserDefinedObjectVariable):
_nonvar_fields = UserDefinedObjectVariable._nonvar_fields
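A rough illustration of the dispatch rule in UserDefinedDictVariable.call_method (this is not the Dynamo code itself, and getattr_static stands in for _maybe_get_baseclass_method): methods that come from dict/OrderedDict are answered by the internal ConstDictVariable, while anything else, such as a user-defined method, falls back to UserDefinedObjectVariable and gets inlined.

import collections
import inspect
import itertools

# Same construction as dict_methods in torch/_dynamo/utils.py above
dict_methods = {
    method
    for method in itertools.chain(
        dict.__dict__.values(), collections.OrderedDict.__dict__.values()
    )
    if callable(method)
}

class MethodDict(dict):        # mirrors MethodDict in the new tests
    def add_1(self, x):
        return x + 1

assert inspect.getattr_static(MethodDict, "keys") in dict_methods       # handled by _dict_vt
assert inspect.getattr_static(MethodDict, "add_1") not in dict_methods  # inlined as user code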