From 4627cfd1f99b40d2415529aaecfcf1e378932745 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 20 Dec 2024 08:19:50 -0800 Subject: [PATCH] [dynamo] Support user defined dicts (#143548) Pull Request resolved: https://github.com/pytorch/pytorch/pull/143548 Approved by: https://github.com/yanboliang, https://github.com/jansel, https://github.com/williamwen42 --- test/dynamo/test_dicts.py | 759 ++++++++++++++++++ test/dynamo/test_misc.py | 11 +- .../TestSerialization.test_serialization_dill | 0 test/test_serialization.py | 1 + torch/_dynamo/side_effects.py | 83 +- torch/_dynamo/utils.py | 8 + torch/_dynamo/variables/__init__.py | 1 + torch/_dynamo/variables/builder.py | 41 + torch/_dynamo/variables/dicts.py | 9 +- torch/_dynamo/variables/misc.py | 5 + torch/_dynamo/variables/user_defined.py | 39 +- 11 files changed, 944 insertions(+), 13 deletions(-) create mode 100644 test/dynamo/test_dicts.py delete mode 100644 test/dynamo_expected_failures/TestSerialization.test_serialization_dill diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py new file mode 100644 index 000000000000..33fd38cabbf4 --- /dev/null +++ b/test/dynamo/test_dicts.py @@ -0,0 +1,759 @@ +# Owner(s): ["module: dynamo"] + +# ruff: noqa: TRY002 +# flake8: noqa + +import dataclasses +from collections import OrderedDict +from dataclasses import dataclass, fields, is_dataclass +from typing import Any, Optional, Tuple + +import torch +import torch._dynamo.config +import torch._dynamo.test_case +import torch._dynamo.testing +import torch._functorch.config +import torch.nn +import torch.utils.checkpoint +from torch._dynamo.testing import same +from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_utils import TestCase + + +class SimpleDict(dict): + pass + + +class DictTests(torch._dynamo.test_case.TestCase): + def test_dict_subclass_instantiation(self): + def fn(x): + sd = SimpleDict(x=5) + return sd["x"] * x + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_dict_subclass_local_mutation(self): + def fn(x): + sd = SimpleDict(x=5) + z = sd["x"] * x + sd["x"] = 10 + return z * sd["x"] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_dict_subclass_local_with_non_dict_method(self): + # Checks that add_1 method is inlined + class MethodDict(dict): + def add_1(self, x): + return x + 1 + + def fn(x): + sd = MethodDict(x=5) + z = sd["x"] * x + sd["x"] = 10 + return sd.add_1(z * sd["x"]) + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_dict_subclass_methods_fallback_readonly(self): + sd = SimpleDict() + sd[2] = 5 + sd[4] = 10 + # check that regular attr accesses work well + sd.attr = 4 + + def fn(x): + for value in sd.values(): + x = x * value + for key in sd.keys(): + x = x * key + for k, v in sd.items(): + x = x * k + x = x * v + # for k in sd: + # x = x * k + + if 1 in sd: + x = x * 2 + else: + x = x * 3 + + x = x * sd.get(2, 0) + x = x * sd.get(3, 4) + x = len(sd) * x + x = x * sd.attr + return x + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + # Ensure a recompilation + sd[6] = 15 + self.assertEqual(fn(x), opt_fn(x)) + + def test_dict_subclass_instantiation_return(self): + def fn(x): + sd = SimpleDict(x=5 * x) + sd["y"] = 10 + 
return sd + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(type(ref), type(res)) + self.assertEqual(ref["x"], res["x"]) + self.assertEqual(ref["y"], res["y"]) + + def test_dict_subclass_methods_fallback_mutation(self): + def fn(sd, x): + for value in sd.values(): + x = x * value + sd[6] = 14 + for key in sd.keys(): + x = x * key + for k, v in sd.items(): + x = x * k + x = x * v + # for k in sd: + # x = x * k + + if 1 in sd: + x = x * 2 + else: + x = x * 3 + + x = x * sd.get(2, 0) + x = x * sd.get(3, 4) + x = len(sd) * x + x = x * sd.attr + sd.attr = 10 + x = x * sd.attr + return x + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + sd1 = SimpleDict() + sd1[2] = 5 + sd1[4] = 10 + sd1.attr = 4 + + sd2 = SimpleDict() + sd2[2] = 5 + sd2[4] = 10 + sd2.attr = 4 + self.assertTrue(sd1 == sd2) + + self.assertEqual(fn(sd1, x), opt_fn(sd2, x)) + self.assertTrue(sd1 == sd2) + self.assertTrue(sd1.attr == sd2.attr) + + def test_dict_subclass_setitem(self): + class SetItemDict(dict): + def __setitem__(self, key, value): + super().__setitem__(key, value + 1) + + def fn(x): + sd = SetItemDict(x=5 * x) + sd["y"] = 10 + return sd + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(type(ref), type(res)) + self.assertEqual(ref["x"], res["x"]) + self.assertEqual(ref["y"], res["y"]) + + +def is_tensor(x): + import torch + + return isinstance(x, torch.Tensor) + + +class ModelOutput(OrderedDict): + """ + Copied from transformers. + """ + + def __init_subclass__(cls) -> None: + """Register subclasses as pytree nodes. + + This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with + `static_graph=True` with modules that output `ModelOutput` subclasses. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Subclasses of ModelOutput must use the @dataclass decorator + # This check is done in __init__ because the @dataclass decorator operates after __init_subclass__ + # issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed + # Just need to check that the current class is not ModelOutput + is_modeloutput_subclass = self.__class__ != ModelOutput + + if is_modeloutput_subclass and not is_dataclass(self): + raise TypeError( + f"{self.__module__}.{self.__class__.__name__} is not a dataclasss." + " This is a subclass of ModelOutput and so must use the @dataclass decorator." + ) + + def __post_init__(self): + """Check the ModelOutput dataclass. + + Only occurs if @dataclass decorator has been used. + """ + class_fields = fields(self) + + # Safety and consistency checks + if not len(class_fields): + raise ValueError(f"{self.__class__.__name__} has no fields.") + if not all(field.default is None for field in class_fields[1:]): + raise ValueError( + f"{self.__class__.__name__} should not have more than one required field." 
+ ) + + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all( + getattr(self, field.name) is None for field in class_fields[1:] + ) + + if other_fields_are_none and not is_tensor(first_field): + if isinstance(first_field, dict): + iterator = first_field.items() + first_field_iterator = True + else: + try: + iterator = iter(first_field) + first_field_iterator = True + except TypeError: + first_field_iterator = False + + # if we provided an iterator as first field and the iterator is a (key, value) iterator + # set the associated fields + if first_field_iterator: + for idx, element in enumerate(iterator): + if ( + not isinstance(element, (list, tuple)) + or not len(element) == 2 + or not isinstance(element[0], str) + ): + if idx == 0: + # If we do not have an iterator of key/values, set it as attribute + self[class_fields[0].name] = first_field + else: + # If we have a mixed iterator, raise an error + raise ValueError( + f"Cannot set key/value for {element}. It needs to be a tuple (key, value)." + ) + break + setattr(self, element[0], element[1]) + if element[1] is not None: + self[element[0]] = element[1] + elif first_field is not None: + self[class_fields[0].name] = first_field + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise Exception( + f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." + ) + + def setdefault(self, *args, **kwargs): + raise Exception( + f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." + ) + + def pop(self, *args, **kwargs): + raise Exception( + f"You cannot use ``pop`` on a {self.__class__.__name__} instance." + ) + + def update(self, *args, **kwargs): + raise Exception( + f"You cannot use ``update`` on a {self.__class__.__name__} instance." + ) + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = dict(self.items()) + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def __reduce__(self): + if not is_dataclass(self): + return super().__reduce__() + callable, _args, *remaining = super().__reduce__() + args = tuple(getattr(self, field.name) for field in fields(self)) + return callable, args, *remaining + + def to_tuple(self) -> Tuple[Any]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. + """ + return tuple(self[k] for k in self.keys()) + + +@dataclass +class BaseModelOutput(ModelOutput): + """ + Copied from transformers + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class CausalLMOutputWithPast(ModelOutput): + """ + Copied from transformers + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). 
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. 
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +class TestModelOutput(torch._dynamo.test_case.TestCase): + def test_mo_create(self): + def fn(a, b): + tmp = BaseModelOutput(a + 1, attentions=b + 3) + return tmp + + torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=2) + + def test_mo_assign(self): + def fn(a, b): + tmp = BaseModelOutput(last_hidden_state=b + 3) + tmp.hidden_states = a + 7 + tmp["attentions"] = a + b + 6 + return tmp + + args = [torch.randn(10), torch.randn(10)] + obj1 = fn(*args) + + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize_assert(cnts)(fn) + obj2 = opt_fn(*args) + self.assertTrue(same(obj1.last_hidden_state, obj2.last_hidden_state)) + self.assertTrue(same(obj1.hidden_states, obj2.hidden_states)) + self.assertTrue(same(obj1.attentions, obj2.attentions)) + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(cnts.op_count, 4) + + def _common(self, fn, op_count): + args = [ + BaseModelOutput( + last_hidden_state=torch.randn(10), attentions=torch.randn(10) + ) + ] + obj1 = fn(*args) + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize_assert(cnts)(fn) + obj2 = opt_fn(*args) + self.assertTrue(same(obj1, obj2)) + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(cnts.op_count, op_count) + + def test_mo_getattr(self): + def fn(obj: BaseModelOutput): + x = obj.last_hidden_state * 10 + if obj.hidden_states is not None: + x += obj.hidden_states + if obj.attentions is not None: + x += obj.attentions + return x + + self._common(fn, 2) + + def test_mo_getattr_missing(self): + def fn(obj: BaseModelOutput): + if getattr(obj, "asdf", None) is not None: + obj.asdf += 1 + return obj.attentions + 1 + + self._common(fn, 1) + + def test_mo_getitem(self): + def fn(obj: BaseModelOutput): + x = obj["last_hidden_state"] * 10 + if "hidden_stats" in obj: + x += obj["hidden_states"] + if "attentions" in obj: + x += obj["attentions"] + return x + + self._common(fn, 2) + + def test_mo_tuple(self): + def fn(obj: BaseModelOutput): + a, b = obj.to_tuple() + return a + b * 10 + + self._common(fn, 2) + + def test_mo_index(self): + def fn(obj: BaseModelOutput): + return obj[0] * 10 + obj[1] + + self._common(fn, 2) + + def test_mo_init(self): + @dataclasses.dataclass + class MyDataClass(ModelOutput): + a: torch.Tensor + b: torch.Tensor = None + c: torch.Tensor = None + d: torch.Tensor = None + e: torch.Tensor = None + + def fn(obj): + class_fields = dataclasses.fields(obj) + assert 
len(class_fields) + assert all(field.default is None for field in class_fields[1:]) + other_fields_are_none = all( + getattr(obj, field.name) is None for field in class_fields[1:] + ) + assert not other_fields_are_none + + total = getattr(obj, class_fields[0].name) + for field in class_fields[1:]: + v = getattr(obj, field.name) + if v is not None: + total += v + + return total + + tensors = [torch.randn(10), torch.randn(10), torch.randn(10)] + obj1 = MyDataClass(*tensors) + correct1 = fn(obj1) + + obj2 = MyDataClass(*tensors) + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch.compile(fn, backend=cnts) + self.assertTrue(same(opt_fn(obj2), correct1)) + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(cnts.op_count, 2) + + def test_mo_init2(self): + # this ModelOutput subclass runs a different __post_init__ codepath + @dataclasses.dataclass + class MyDataClass(ModelOutput): + x: torch.FloatTensor = None + + def fn(x): + obj = MyDataClass(x=x * 5) + return obj + + inp = torch.randn(3, 3) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(inp).x, opt_fn(inp).x) + + def test_mo_init_with_disable(self): + # Can result in "non-function or method super: " + # graph breaks (although it may not be the first) + # Minimal repro for https://github.com/pytorch/pytorch/issues/126028 + @dataclasses.dataclass + class MyDataClass(ModelOutput): + x: torch.FloatTensor = None + + @torch._dynamo.disable(recursive=False) + def fn(x): + return MyDataClass(x=x) + + inp = torch.randn(3, 3) + opt_fn = torch.compile(fn, backend="eager") + self.assertEqual(fn(inp).x, opt_fn(inp).x) + + def test_mo_newkey(self): + obj = BaseModelOutput() + + def fn(obj): + return obj["wwww"] + 1 + + inp = torch.randn(3, 3) + obj["wwww"] = inp + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(obj), opt_fn(obj)) + + def test_mo_from_outside(self): + def fn(obj): + return obj.attentions + 1 + + obj = BaseModelOutput(attentions=torch.randn(3, 3)) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(obj), opt_fn(obj)) + + def test_mo_reconstruct_bytecode(self): + def fn(inp): + return BaseModelOutput(attentions=inp + 1) + + inp = torch.randn(3, 3) + opt_fn = torch.compile(fn, backend="eager") + self.assertEqual(fn(inp).attentions, opt_fn(inp).attentions) + + def test_none(self): + class Model(torch.nn.Module): + def forward(self, x): + x = x + 1 + return CausalLMOutputWithPast(loss=None, logits=x)[0] + + model = Model() + opt_model = torch.compile(model, backend="eager", fullgraph=True) + x = torch.randn(1, 1, 1, 1) + + self.assertTrue(same(model(x), opt_model(x))) + + def test_reconstruction(self): + torch._export.utils.register_dataclass_as_pytree_node( + CausalLMOutputWithPast, + serialized_type_name="test_reconstruction_CausalLMOutputWithPast", + ) + + class Model(torch.nn.Module): + def forward(self, x): + x = x + 1 + return CausalLMOutputWithPast(loss=x, logits=None) + + model = Model() + x = torch.randn(1, 1, 1, 1) + eo = torch._dynamo.export(Model(), aten_graph=True)(x) + self.assertTrue(same(model(x), eo.graph_module(x))) + + +class TestModelOutputBert(TestCase): + def test_HF_bert_model_output(self, device): + class BertPooler(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.dense = torch.nn.Linear(768, 768).to(device) + self.activation = torch.nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state 
corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + class BertEncoder(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward( + self, + hidden_states: torch.Tensor, + ) -> BaseModelOutputWithPastAndCrossAttentions: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=None, + hidden_states=None, + attentions=None, + cross_attentions=None, + ) + + class BertModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.encoder = BertEncoder() + self.pooler = BertPooler() + + def forward( + self, + sequence_output: torch.Tensor, + ) -> BaseModelOutputWithPoolingAndCrossAttentions: + encoder_outputs = self.encoder(sequence_output) + # test __getitem__ and to_tuple + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + # test CustomDictVariable.create + result = BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + # test __setattr__ + result.pooler_output = pooled_output + # test __setitem__ + result["pooler_output"] = pooled_output + return result + + sequence_output = torch.rand(1, 12, 768).to(device) + model = BertModel() + orig_result = model(sequence_output) + compiled_model = torch.compile(model, backend="eager") + compiled_result = compiled_model(sequence_output) + self.assertTrue( + torch.allclose( + orig_result.last_hidden_state, compiled_result.last_hidden_state + ) + ) + self.assertTrue( + torch.allclose(orig_result.pooler_output, compiled_result.pooler_output) + ) + + +devices = ["cpu"] + +instantiate_device_type_tests(TestModelOutputBert, globals(), only_for=devices) + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index b7c828f58db6..5aa2d9e8802e 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10599,7 +10599,7 @@ ShapeEnv not equal: field values don't match: foo() - def test_dict_subclass_cannot_be_initialized_in_graph(self): + def test_dict_subclass_initialization_in_graph(self): for super_class in ( collections.OrderedDict, dict, @@ -10615,11 +10615,10 @@ ShapeEnv not equal: field values don't match: assert "key" in c return c["key"] + 1 - fn_opt = torch.compile(fn, backend="eager", fullgraph=True) - with self.assertRaisesRegex( - torch._dynamo.exc.Unsupported, "call_function UserDefinedClassVariable" - ): - print(fn_opt(torch.zeros(1))) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + x = torch.rand(4) + self.assertEqual(fn(x), opt_fn(x)) @wrapDeterministicFlagAPITest def test_backward_deterministic_mode_mismatch_warning(self): diff --git a/test/dynamo_expected_failures/TestSerialization.test_serialization_dill b/test/dynamo_expected_failures/TestSerialization.test_serialization_dill deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/test_serialization.py b/test/test_serialization.py index 302788f7c6a0..aea2cf1a6f05 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -313,6 +313,7 @@ class 
SerializationMixin: not TEST_DILL or not HAS_DILL_AT_LEAST_0_3_1, '"dill" not found or not correct version' ) + @skipIfTorchDynamo("Different behavior between 3.11 and 3.13, causing CI issues") def test_serialization_dill(self): x = torch.randn(5, 5) diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 6f5beb4dc2b7..e758f825b39a 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import collections import contextlib import functools import inspect @@ -20,7 +21,7 @@ from .bytecode_transformation import ( from .codegen import PyCodegen from .exc import unimplemented from .source import GlobalSource, LocalCellSource, LocalSource, Source -from .utils import is_frozen_dataclass, nn_module_new, object_new +from .utils import dict_new, is_frozen_dataclass, nn_module_new, object_new from .variables.base import ( AttributeMutation, AttributeMutationExisting, @@ -37,6 +38,17 @@ def _manual_update_dict(dict_from, dict_to): dict_to[k] = v +def _manual_dict_setitem(dict_from, dict_to, mro_index): + # Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have + # to be careful because we don't want to trigger the user defined object + # setitem or clear. The mro_index is used to find the dict/OrderedDict from + # the class mro. + dict_class = type(dict_to).__mro__[mro_index] + dict_class.clear(dict_to) + for k, v in dict_from.items(): + dict_class.__setitem__(dict_to, k, v) + + class SideEffects: """ Track side effects (list mutation, setattr, etc) that need to be @@ -181,9 +193,9 @@ class SideEffects: @staticmethod def cls_supports_mutation_side_effects(cls): - return ( - inspect.getattr_static(cls, "__getattribute__", None) - is object.__getattribute__ + return inspect.getattr_static(cls, "__getattribute__", None) in ( + object.__getattribute__, + dict.__getattribute__, ) def is_attribute_mutation(self, item): @@ -254,6 +266,8 @@ class SideEffects: obj = torch.autograd.Function() elif issubclass(user_cls, torch.nn.Module): obj = nn_module_new(user_cls) + elif issubclass(user_cls, (dict, collections.OrderedDict)): + obj = dict_new(user_cls) else: try: obj = object_new(user_cls) @@ -284,6 +298,8 @@ class SideEffects: ] = variables.UserDefinedObjectVariable if issubclass(user_cls, torch.nn.Module): variable_cls = variables.UnspecializedNNModuleVariable + elif issubclass(user_cls, (dict, collections.OrderedDict)): + variable_cls = variables.UserDefinedDictVariable elif issubclass(user_cls, MutableMapping): variable_cls = variables.MutableMappingVariable elif is_frozen_dataclass(user_cls): @@ -442,7 +458,12 @@ class SideEffects: if isinstance(var, variables.AutogradFunctionContextVariable): unimplemented("AutogradFunctionContextVariable escaped") cg.add_push_null( - lambda: cg.load_import_from(utils.__name__, "object_new") + lambda: cg.load_import_from( + utils.__name__, + "dict_new" + if isinstance(var, variables.UserDefinedDictVariable) + else "object_new", + ) ) cg(var.mutation_type.cls_source) cg.extend_output(create_call_function(1, False)) @@ -695,6 +716,58 @@ class SideEffects: suffixes.append([cg.create_store_deref(var.local_name)]) elif self.is_attribute_mutation(var): + if isinstance(var, variables.UserDefinedDictVariable): + # Do dict related update manually here. The store_attr + # mutations will be applied later. 
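+                # Roughly, the bytecode assembled below performs
+                #   _manual_dict_setitem(dict_from=<codegen of var._dict_vt>,
+                #                        dict_to=<object loaded from var.source>,
+                #                        mro_index=<index of dict/OrderedDict in
+                #                                   type(var.value).__mro__>)
+                # so the original object is cleared and refilled through
+                # dict/OrderedDict.__setitem__ rather than through any
+                # user-defined override.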
+ varname_map = {} + for name in _manual_dict_setitem.__code__.co_varnames: + varname_map[name] = cg.tx.output.new_var() + + try: + mro_index = type(var.value).__mro__.index( + collections.OrderedDict + ) + except ValueError: + mro_index = type(var.value).__mro__.index(dict) + + cg.extend_output( + [ + create_instruction("LOAD_CONST", argval=mro_index), + create_instruction( + "STORE_FAST", argval=varname_map["mro_index"] + ), + ] + ) + + cg(var.source) # type: ignore[attr-defined] + cg.extend_output( + [ + create_instruction( + "STORE_FAST", argval=varname_map["dict_to"] + ) + ] + ) + + cg(var._dict_vt, allow_cache=False) # Don't codegen via source + cg.extend_output( + [ + create_instruction( + "STORE_FAST", argval=varname_map["dict_from"] + ) + ] + ) + + dict_update_insts = bytecode_from_template( + _manual_dict_setitem, varname_map=varname_map + ) + + suffixes.append( + [ + *dict_update_insts, + create_instruction("POP_TOP"), + ] + ) + # Applying mutations involves two steps: 1) Push all # reconstructed objects onto the stack. 2) Call STORE_ATTR to # apply the mutations. diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index c4e30ead1198..742e1893d817 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1896,6 +1896,14 @@ tuple_iterator: Type[Iterator[Any]] = type(iter(())) range_iterator: Type[Iterator[Any]] = type(iter(range(0))) tuple_iterator_len = tuple_iterator.__length_hint__ # type: ignore[attr-defined] object_new = object.__new__ +dict_new = dict.__new__ +dict_methods = { + method + for method in itertools.chain( + dict.__dict__.values(), collections.OrderedDict.__dict__.values() + ) + if callable(method) +} def nn_module_new(cls): diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 8cbd7291b67c..548941d86e39 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -113,6 +113,7 @@ from .user_defined import ( MutableMappingVariable, RemovableHandleVariable, UserDefinedClassVariable, + UserDefinedDictVariable, UserDefinedObjectVariable, ) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 5993bce53113..1b1a14851cb0 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -227,6 +227,7 @@ from .user_defined import ( MutableMappingVariable, SourcelessGraphModuleVariable, UserDefinedClassVariable, + UserDefinedDictVariable, UserDefinedObjectVariable, ) @@ -1232,6 +1233,46 @@ class VariableBuilder: fake_script_obj, source=self.source, ) + elif ( + isinstance(value, (dict, collections.OrderedDict)) + and type(value).__new__ is dict.__new__ + ): + # Construct a dict_vt that will reside inside the UserDefinedDictVariable + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # Guard on the key order + self.tx.output.guard_on_key_order.add(self.source.name()) + + # We need all the keys to be hashable. We do this within the + # _HashableTracker class in dicts.py + def build_key_value(i, k, v): + source_key = ConstDictKeySource(self.get_source(), i) + key = LazyVariableTracker.create(k, source_key) + + source_value = GetItemSource(self.get_source(), source_key) + value = LazyVariableTracker.create(v, source_value) + + return key, value + + result = dict( + build_key_value(i, k, v) for i, (k, v) in enumerate(value.items()) + ) + + # NB: This is deliberately kept ValueMutationNew because dict_vt is + # an internal representation. 
dict_vt tracks the mutation on the + # dict side. side_effects infra uses the UserDefinedDictVariable to + # apply side-effects of this dict_vt. + dict_vt = ConstDictVariable( + result, + user_cls=collections.OrderedDict + if isinstance(value, collections.OrderedDict) + else dict, + mutation_type=ValueMutationNew(), + ) + + result = UserDefinedDictVariable(value, dict_vt=dict_vt, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) elif issubclass(type(value), MutableMapping): self.install_guards(GuardBuilder.TYPE_MATCH) return MutableMappingVariable(value, source=self.source) diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 30b4fceb2477..cf88edb9938f 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -283,7 +283,14 @@ class ConstDictVariable(VariableTracker): arg_hashable = args and is_hashable(args[0]) - if name == "__getitem__": + if name == "__init__": + temp_dict_vt = variables.BuiltinVariable(dict).call_dict( + tx, *args, **kwargs + ) + tx.output.side_effects.mutation(self) + self.items.update(temp_dict_vt.items) + return ConstantVariable.create(None) + elif name == "__getitem__": assert len(args) == 1 return self.getitem_const_raise_exception_if_absent(tx, args[0]) elif name == "items": diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 8e2bc0b9a958..3f5f6e4bb2ab 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -236,6 +236,11 @@ class SuperVariable(VariableTracker): self.objvar, attr, variables.DeletedVariable() ) return variables.ConstantVariable(None) + elif ( + isinstance(self.objvar, variables.UserDefinedDictVariable) + and inner_fn in self.objvar._dict_methods + ): + return self.objvar._dict_vt.call_method(tx, name, args, kwargs) unimplemented(f"non-function or method super: {inner_fn}") diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 56deb6306a4d..361a0f5d6cba 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -42,6 +42,7 @@ from ..utils import ( build_checkpoint_variable, build_invoke_subgraph_variable, check_constant_args, + dict_methods, get_custom_getattr, has_torch_function, is_frozen_dataclass, @@ -637,7 +638,7 @@ class UserDefinedClassVariable(UserDefinedVariable): new_fn = inspect.getattr_static(self.value, "__new__", None) if isinstance(new_fn, staticmethod): new_fn = new_fn.__func__ - return new_fn in (object.__new__, Generic.__new__) + return new_fn in (object.__new__, Generic.__new__, dict.__new__) def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": if self.source: @@ -1441,6 +1442,42 @@ class RemovableHandleVariable(VariableTracker): return RemovableHandleClass +class UserDefinedDictVariable(UserDefinedObjectVariable): + """ + Represents user defined objects that are subclasses of dict/OrderedDict. + + Internally, it uses a ConstDictVariable to represent the dict part of the + variable tracker. For everything else, it falls back to + UserDefinedObjectVariable. 
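+
+    For example (mirroring test_dict_subclass_local_with_non_dict_method in
+    test/dynamo/test_dicts.py), a subclass such as
+
+        class MethodDict(dict):
+            def add_1(self, x):
+                return x + 1
+
+    has its dict methods (__getitem__, items, update, ...) dispatched to the
+    internal ConstDictVariable, while add_1 is inlined through the regular
+    UserDefinedObjectVariable machinery.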
+ """ + + _nonvar_fields = UserDefinedObjectVariable._nonvar_fields + + def __init__(self, value, dict_vt=None, **kwargs): + super().__init__(value, **kwargs) + self._dict_vt = dict_vt + if self._dict_vt is None: + assert ( + self.source is None + ), "dict_vt must be constructed by builder.py when source is present" + self._dict_vt = variables.ConstDictVariable( + {}, mutation_type=ValueMutationNew() + ) + self._dict_methods = dict_methods + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + method = self._maybe_get_baseclass_method(name) + if method in self._dict_methods: + return self._dict_vt.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + class MutableMappingVariable(UserDefinedObjectVariable): _nonvar_fields = UserDefinedObjectVariable._nonvar_fields
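
Illustrative usage sketch (mirrors test_dict_subclass_local_mutation from the
new test file; not an excerpt from the diff): with this change, torch.compile
can trace construction and mutation of a dict subclass without graph breaks.

    import torch

    class SimpleDict(dict):
        pass

    @torch.compile(backend="eager", fullgraph=True)
    def fn(x):
        sd = SimpleDict(x=5)   # dict subclass built inside the compiled frame
        z = sd["x"] * x
        sd["x"] = 10           # handled in-graph via the internal ConstDictVariable
        return z * sd["x"]

    fn(torch.randn(4))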