mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[dynamo] Support user defined dicts (#143548)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143548 Approved by: https://github.com/yanboliang, https://github.com/jansel, https://github.com/williamwen42
This commit is contained in:
committed by
PyTorch MergeBot
parent
9cb743d1f9
commit
4627cfd1f9
759
test/dynamo/test_dicts.py
Normal file
759
test/dynamo/test_dicts.py
Normal file
@ -0,0 +1,759 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
# ruff: noqa: TRY002
|
||||
# flake8: noqa
|
||||
|
||||
import dataclasses
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, fields, is_dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch._dynamo.config
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
import torch._functorch.config
|
||||
import torch.nn
|
||||
import torch.utils.checkpoint
|
||||
from torch._dynamo.testing import same
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
|
||||
class SimpleDict(dict):
|
||||
pass
|
||||
|
||||
|
||||
class DictTests(torch._dynamo.test_case.TestCase):
|
||||
def test_dict_subclass_instantiation(self):
|
||||
def fn(x):
|
||||
sd = SimpleDict(x=5)
|
||||
return sd["x"] * x
|
||||
|
||||
x = torch.randn(4)
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
def test_dict_subclass_local_mutation(self):
|
||||
def fn(x):
|
||||
sd = SimpleDict(x=5)
|
||||
z = sd["x"] * x
|
||||
sd["x"] = 10
|
||||
return z * sd["x"]
|
||||
|
||||
x = torch.randn(4)
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
def test_dict_subclass_local_with_non_dict_method(self):
|
||||
# Checks that add_1 method is inlined
|
||||
class MethodDict(dict):
|
||||
def add_1(self, x):
|
||||
return x + 1
|
||||
|
||||
def fn(x):
|
||||
sd = MethodDict(x=5)
|
||||
z = sd["x"] * x
|
||||
sd["x"] = 10
|
||||
return sd.add_1(z * sd["x"])
|
||||
|
||||
x = torch.randn(4)
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
def test_dict_subclass_methods_fallback_readonly(self):
|
||||
sd = SimpleDict()
|
||||
sd[2] = 5
|
||||
sd[4] = 10
|
||||
# check that regular attr accesses work well
|
||||
sd.attr = 4
|
||||
|
||||
def fn(x):
|
||||
for value in sd.values():
|
||||
x = x * value
|
||||
for key in sd.keys():
|
||||
x = x * key
|
||||
for k, v in sd.items():
|
||||
x = x * k
|
||||
x = x * v
|
||||
# for k in sd:
|
||||
# x = x * k
|
||||
|
||||
if 1 in sd:
|
||||
x = x * 2
|
||||
else:
|
||||
x = x * 3
|
||||
|
||||
x = x * sd.get(2, 0)
|
||||
x = x * sd.get(3, 4)
|
||||
x = len(sd) * x
|
||||
x = x * sd.attr
|
||||
return x
|
||||
|
||||
x = torch.randn(4)
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
# Ensure a recompilation
|
||||
sd[6] = 15
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
def test_dict_subclass_instantiation_return(self):
|
||||
def fn(x):
|
||||
sd = SimpleDict(x=5 * x)
|
||||
sd["y"] = 10
|
||||
return sd
|
||||
|
||||
x = torch.randn(4)
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
ref = fn(x)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(type(ref), type(res))
|
||||
self.assertEqual(ref["x"], res["x"])
|
||||
self.assertEqual(ref["y"], res["y"])
|
||||
|
||||
def test_dict_subclass_methods_fallback_mutation(self):
|
||||
def fn(sd, x):
|
||||
for value in sd.values():
|
||||
x = x * value
|
||||
sd[6] = 14
|
||||
for key in sd.keys():
|
||||
x = x * key
|
||||
for k, v in sd.items():
|
||||
x = x * k
|
||||
x = x * v
|
||||
# for k in sd:
|
||||
# x = x * k
|
||||
|
||||
if 1 in sd:
|
||||
x = x * 2
|
||||
else:
|
||||
x = x * 3
|
||||
|
||||
x = x * sd.get(2, 0)
|
||||
x = x * sd.get(3, 4)
|
||||
x = len(sd) * x
|
||||
x = x * sd.attr
|
||||
sd.attr = 10
|
||||
x = x * sd.attr
|
||||
return x
|
||||
|
||||
x = torch.randn(4)
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
|
||||
sd1 = SimpleDict()
|
||||
sd1[2] = 5
|
||||
sd1[4] = 10
|
||||
sd1.attr = 4
|
||||
|
||||
sd2 = SimpleDict()
|
||||
sd2[2] = 5
|
||||
sd2[4] = 10
|
||||
sd2.attr = 4
|
||||
self.assertTrue(sd1 == sd2)
|
||||
|
||||
self.assertEqual(fn(sd1, x), opt_fn(sd2, x))
|
||||
self.assertTrue(sd1 == sd2)
|
||||
self.assertTrue(sd1.attr == sd2.attr)
|
||||
|
||||
def test_dict_subclass_setitem(self):
|
||||
class SetItemDict(dict):
|
||||
def __setitem__(self, key, value):
|
||||
super().__setitem__(key, value + 1)
|
||||
|
||||
def fn(x):
|
||||
sd = SetItemDict(x=5 * x)
|
||||
sd["y"] = 10
|
||||
return sd
|
||||
|
||||
x = torch.randn(4)
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
ref = fn(x)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(type(ref), type(res))
|
||||
self.assertEqual(ref["x"], res["x"])
|
||||
self.assertEqual(ref["y"], res["y"])
|
||||
|
||||
|
||||
def is_tensor(x):
|
||||
import torch
|
||||
|
||||
return isinstance(x, torch.Tensor)
|
||||
|
||||
|
||||
class ModelOutput(OrderedDict):
|
||||
"""
|
||||
Copied from transformers.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
"""Register subclasses as pytree nodes.
|
||||
|
||||
This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
|
||||
`static_graph=True` with modules that output `ModelOutput` subclasses.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Subclasses of ModelOutput must use the @dataclass decorator
|
||||
# This check is done in __init__ because the @dataclass decorator operates after __init_subclass__
|
||||
# issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed
|
||||
# Just need to check that the current class is not ModelOutput
|
||||
is_modeloutput_subclass = self.__class__ != ModelOutput
|
||||
|
||||
if is_modeloutput_subclass and not is_dataclass(self):
|
||||
raise TypeError(
|
||||
f"{self.__module__}.{self.__class__.__name__} is not a dataclasss."
|
||||
" This is a subclass of ModelOutput and so must use the @dataclass decorator."
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Check the ModelOutput dataclass.
|
||||
|
||||
Only occurs if @dataclass decorator has been used.
|
||||
"""
|
||||
class_fields = fields(self)
|
||||
|
||||
# Safety and consistency checks
|
||||
if not len(class_fields):
|
||||
raise ValueError(f"{self.__class__.__name__} has no fields.")
|
||||
if not all(field.default is None for field in class_fields[1:]):
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} should not have more than one required field."
|
||||
)
|
||||
|
||||
first_field = getattr(self, class_fields[0].name)
|
||||
other_fields_are_none = all(
|
||||
getattr(self, field.name) is None for field in class_fields[1:]
|
||||
)
|
||||
|
||||
if other_fields_are_none and not is_tensor(first_field):
|
||||
if isinstance(first_field, dict):
|
||||
iterator = first_field.items()
|
||||
first_field_iterator = True
|
||||
else:
|
||||
try:
|
||||
iterator = iter(first_field)
|
||||
first_field_iterator = True
|
||||
except TypeError:
|
||||
first_field_iterator = False
|
||||
|
||||
# if we provided an iterator as first field and the iterator is a (key, value) iterator
|
||||
# set the associated fields
|
||||
if first_field_iterator:
|
||||
for idx, element in enumerate(iterator):
|
||||
if (
|
||||
not isinstance(element, (list, tuple))
|
||||
or not len(element) == 2
|
||||
or not isinstance(element[0], str)
|
||||
):
|
||||
if idx == 0:
|
||||
# If we do not have an iterator of key/values, set it as attribute
|
||||
self[class_fields[0].name] = first_field
|
||||
else:
|
||||
# If we have a mixed iterator, raise an error
|
||||
raise ValueError(
|
||||
f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
|
||||
)
|
||||
break
|
||||
setattr(self, element[0], element[1])
|
||||
if element[1] is not None:
|
||||
self[element[0]] = element[1]
|
||||
elif first_field is not None:
|
||||
self[class_fields[0].name] = first_field
|
||||
else:
|
||||
for field in class_fields:
|
||||
v = getattr(self, field.name)
|
||||
if v is not None:
|
||||
self[field.name] = v
|
||||
|
||||
def __delitem__(self, *args, **kwargs):
|
||||
raise Exception(
|
||||
f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
|
||||
)
|
||||
|
||||
def setdefault(self, *args, **kwargs):
|
||||
raise Exception(
|
||||
f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
|
||||
)
|
||||
|
||||
def pop(self, *args, **kwargs):
|
||||
raise Exception(
|
||||
f"You cannot use ``pop`` on a {self.__class__.__name__} instance."
|
||||
)
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
raise Exception(
|
||||
f"You cannot use ``update`` on a {self.__class__.__name__} instance."
|
||||
)
|
||||
|
||||
def __getitem__(self, k):
|
||||
if isinstance(k, str):
|
||||
inner_dict = dict(self.items())
|
||||
return inner_dict[k]
|
||||
else:
|
||||
return self.to_tuple()[k]
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name in self.keys() and value is not None:
|
||||
# Don't call self.__setitem__ to avoid recursion errors
|
||||
super().__setitem__(name, value)
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# Will raise a KeyException if needed
|
||||
super().__setitem__(key, value)
|
||||
# Don't call self.__setattr__ to avoid recursion errors
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __reduce__(self):
|
||||
if not is_dataclass(self):
|
||||
return super().__reduce__()
|
||||
callable, _args, *remaining = super().__reduce__()
|
||||
args = tuple(getattr(self, field.name) for field in fields(self))
|
||||
return callable, args, *remaining
|
||||
|
||||
def to_tuple(self) -> Tuple[Any]:
|
||||
"""
|
||||
Convert self to a tuple containing all the attributes/keys that are not `None`.
|
||||
"""
|
||||
return tuple(self[k] for k in self.keys())
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelOutput(ModelOutput):
|
||||
"""
|
||||
Copied from transformers
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CausalLMOutputWithPast(ModelOutput):
|
||||
"""
|
||||
Copied from transformers
|
||||
Base class for causal language model (or autoregressive) outputs.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
|
||||
"""
|
||||
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
|
||||
|
||||
Args:
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
|
||||
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
|
||||
hidden_size)` is output.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||
encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
|
||||
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
|
||||
input) to speed up sequential decoding.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||
weighted average in the cross-attention heads.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
|
||||
"""
|
||||
Base class for model's outputs that also contains a pooling of the last hidden states.
|
||||
|
||||
Args:
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
|
||||
Last layer hidden-state of the first token of the sequence (classification token) after further processing
|
||||
through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
|
||||
the classification token after processing through a linear layer and a tanh activation function. The linear
|
||||
layer weights are trained from the next sentence prediction (classification) objective during pretraining.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||
weighted average in the cross-attention heads.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||
encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
|
||||
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
|
||||
input) to speed up sequential decoding.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
pooler_output: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
|
||||
class TestModelOutput(torch._dynamo.test_case.TestCase):
|
||||
def test_mo_create(self):
|
||||
def fn(a, b):
|
||||
tmp = BaseModelOutput(a + 1, attentions=b + 3)
|
||||
return tmp
|
||||
|
||||
torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=2)
|
||||
|
||||
def test_mo_assign(self):
|
||||
def fn(a, b):
|
||||
tmp = BaseModelOutput(last_hidden_state=b + 3)
|
||||
tmp.hidden_states = a + 7
|
||||
tmp["attentions"] = a + b + 6
|
||||
return tmp
|
||||
|
||||
args = [torch.randn(10), torch.randn(10)]
|
||||
obj1 = fn(*args)
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
|
||||
obj2 = opt_fn(*args)
|
||||
self.assertTrue(same(obj1.last_hidden_state, obj2.last_hidden_state))
|
||||
self.assertTrue(same(obj1.hidden_states, obj2.hidden_states))
|
||||
self.assertTrue(same(obj1.attentions, obj2.attentions))
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
self.assertEqual(cnts.op_count, 4)
|
||||
|
||||
def _common(self, fn, op_count):
|
||||
args = [
|
||||
BaseModelOutput(
|
||||
last_hidden_state=torch.randn(10), attentions=torch.randn(10)
|
||||
)
|
||||
]
|
||||
obj1 = fn(*args)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
|
||||
obj2 = opt_fn(*args)
|
||||
self.assertTrue(same(obj1, obj2))
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
self.assertEqual(cnts.op_count, op_count)
|
||||
|
||||
def test_mo_getattr(self):
|
||||
def fn(obj: BaseModelOutput):
|
||||
x = obj.last_hidden_state * 10
|
||||
if obj.hidden_states is not None:
|
||||
x += obj.hidden_states
|
||||
if obj.attentions is not None:
|
||||
x += obj.attentions
|
||||
return x
|
||||
|
||||
self._common(fn, 2)
|
||||
|
||||
def test_mo_getattr_missing(self):
|
||||
def fn(obj: BaseModelOutput):
|
||||
if getattr(obj, "asdf", None) is not None:
|
||||
obj.asdf += 1
|
||||
return obj.attentions + 1
|
||||
|
||||
self._common(fn, 1)
|
||||
|
||||
def test_mo_getitem(self):
|
||||
def fn(obj: BaseModelOutput):
|
||||
x = obj["last_hidden_state"] * 10
|
||||
if "hidden_stats" in obj:
|
||||
x += obj["hidden_states"]
|
||||
if "attentions" in obj:
|
||||
x += obj["attentions"]
|
||||
return x
|
||||
|
||||
self._common(fn, 2)
|
||||
|
||||
def test_mo_tuple(self):
|
||||
def fn(obj: BaseModelOutput):
|
||||
a, b = obj.to_tuple()
|
||||
return a + b * 10
|
||||
|
||||
self._common(fn, 2)
|
||||
|
||||
def test_mo_index(self):
|
||||
def fn(obj: BaseModelOutput):
|
||||
return obj[0] * 10 + obj[1]
|
||||
|
||||
self._common(fn, 2)
|
||||
|
||||
def test_mo_init(self):
|
||||
@dataclasses.dataclass
|
||||
class MyDataClass(ModelOutput):
|
||||
a: torch.Tensor
|
||||
b: torch.Tensor = None
|
||||
c: torch.Tensor = None
|
||||
d: torch.Tensor = None
|
||||
e: torch.Tensor = None
|
||||
|
||||
def fn(obj):
|
||||
class_fields = dataclasses.fields(obj)
|
||||
assert len(class_fields)
|
||||
assert all(field.default is None for field in class_fields[1:])
|
||||
other_fields_are_none = all(
|
||||
getattr(obj, field.name) is None for field in class_fields[1:]
|
||||
)
|
||||
assert not other_fields_are_none
|
||||
|
||||
total = getattr(obj, class_fields[0].name)
|
||||
for field in class_fields[1:]:
|
||||
v = getattr(obj, field.name)
|
||||
if v is not None:
|
||||
total += v
|
||||
|
||||
return total
|
||||
|
||||
tensors = [torch.randn(10), torch.randn(10), torch.randn(10)]
|
||||
obj1 = MyDataClass(*tensors)
|
||||
correct1 = fn(obj1)
|
||||
|
||||
obj2 = MyDataClass(*tensors)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
self.assertTrue(same(opt_fn(obj2), correct1))
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
self.assertEqual(cnts.op_count, 2)
|
||||
|
||||
def test_mo_init2(self):
|
||||
# this ModelOutput subclass runs a different __post_init__ codepath
|
||||
@dataclasses.dataclass
|
||||
class MyDataClass(ModelOutput):
|
||||
x: torch.FloatTensor = None
|
||||
|
||||
def fn(x):
|
||||
obj = MyDataClass(x=x * 5)
|
||||
return obj
|
||||
|
||||
inp = torch.randn(3, 3)
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(fn(inp).x, opt_fn(inp).x)
|
||||
|
||||
def test_mo_init_with_disable(self):
|
||||
# Can result in "non-function or method super: <slot wrapper '__setattr__' of 'object' objects>"
|
||||
# graph breaks (although it may not be the first)
|
||||
# Minimal repro for https://github.com/pytorch/pytorch/issues/126028
|
||||
@dataclasses.dataclass
|
||||
class MyDataClass(ModelOutput):
|
||||
x: torch.FloatTensor = None
|
||||
|
||||
@torch._dynamo.disable(recursive=False)
|
||||
def fn(x):
|
||||
return MyDataClass(x=x)
|
||||
|
||||
inp = torch.randn(3, 3)
|
||||
opt_fn = torch.compile(fn, backend="eager")
|
||||
self.assertEqual(fn(inp).x, opt_fn(inp).x)
|
||||
|
||||
def test_mo_newkey(self):
|
||||
obj = BaseModelOutput()
|
||||
|
||||
def fn(obj):
|
||||
return obj["wwww"] + 1
|
||||
|
||||
inp = torch.randn(3, 3)
|
||||
obj["wwww"] = inp
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(fn(obj), opt_fn(obj))
|
||||
|
||||
def test_mo_from_outside(self):
|
||||
def fn(obj):
|
||||
return obj.attentions + 1
|
||||
|
||||
obj = BaseModelOutput(attentions=torch.randn(3, 3))
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(fn(obj), opt_fn(obj))
|
||||
|
||||
def test_mo_reconstruct_bytecode(self):
|
||||
def fn(inp):
|
||||
return BaseModelOutput(attentions=inp + 1)
|
||||
|
||||
inp = torch.randn(3, 3)
|
||||
opt_fn = torch.compile(fn, backend="eager")
|
||||
self.assertEqual(fn(inp).attentions, opt_fn(inp).attentions)
|
||||
|
||||
def test_none(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x = x + 1
|
||||
return CausalLMOutputWithPast(loss=None, logits=x)[0]
|
||||
|
||||
model = Model()
|
||||
opt_model = torch.compile(model, backend="eager", fullgraph=True)
|
||||
x = torch.randn(1, 1, 1, 1)
|
||||
|
||||
self.assertTrue(same(model(x), opt_model(x)))
|
||||
|
||||
def test_reconstruction(self):
|
||||
torch._export.utils.register_dataclass_as_pytree_node(
|
||||
CausalLMOutputWithPast,
|
||||
serialized_type_name="test_reconstruction_CausalLMOutputWithPast",
|
||||
)
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x = x + 1
|
||||
return CausalLMOutputWithPast(loss=x, logits=None)
|
||||
|
||||
model = Model()
|
||||
x = torch.randn(1, 1, 1, 1)
|
||||
eo = torch._dynamo.export(Model(), aten_graph=True)(x)
|
||||
self.assertTrue(same(model(x), eo.graph_module(x)))
|
||||
|
||||
|
||||
class TestModelOutputBert(TestCase):
|
||||
def test_HF_bert_model_output(self, device):
|
||||
class BertPooler(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.dense = torch.nn.Linear(768, 768).to(device)
|
||||
self.activation = torch.nn.Tanh()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
|
||||
class BertEncoder(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> BaseModelOutputWithPastAndCrossAttentions:
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=None,
|
||||
hidden_states=None,
|
||||
attentions=None,
|
||||
cross_attentions=None,
|
||||
)
|
||||
|
||||
class BertModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.encoder = BertEncoder()
|
||||
self.pooler = BertPooler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sequence_output: torch.Tensor,
|
||||
) -> BaseModelOutputWithPoolingAndCrossAttentions:
|
||||
encoder_outputs = self.encoder(sequence_output)
|
||||
# test __getitem__ and to_tuple
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = (
|
||||
self.pooler(sequence_output) if self.pooler is not None else None
|
||||
)
|
||||
# test CustomDictVariable.create
|
||||
result = BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
# test __setattr__
|
||||
result.pooler_output = pooled_output
|
||||
# test __setitem__
|
||||
result["pooler_output"] = pooled_output
|
||||
return result
|
||||
|
||||
sequence_output = torch.rand(1, 12, 768).to(device)
|
||||
model = BertModel()
|
||||
orig_result = model(sequence_output)
|
||||
compiled_model = torch.compile(model, backend="eager")
|
||||
compiled_result = compiled_model(sequence_output)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
orig_result.last_hidden_state, compiled_result.last_hidden_state
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(orig_result.pooler_output, compiled_result.pooler_output)
|
||||
)
|
||||
|
||||
|
||||
devices = ["cpu"]
|
||||
|
||||
instantiate_device_type_tests(TestModelOutputBert, globals(), only_for=devices)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
@ -10599,7 +10599,7 @@ ShapeEnv not equal: field values don't match:
|
||||
|
||||
foo()
|
||||
|
||||
def test_dict_subclass_cannot_be_initialized_in_graph(self):
|
||||
def test_dict_subclass_initialization_in_graph(self):
|
||||
for super_class in (
|
||||
collections.OrderedDict,
|
||||
dict,
|
||||
@ -10615,11 +10615,10 @@ ShapeEnv not equal: field values don't match:
|
||||
assert "key" in c
|
||||
return c["key"] + 1
|
||||
|
||||
fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.Unsupported, "call_function UserDefinedClassVariable"
|
||||
):
|
||||
print(fn_opt(torch.zeros(1)))
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
|
||||
x = torch.rand(4)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
@wrapDeterministicFlagAPITest
|
||||
def test_backward_deterministic_mode_mismatch_warning(self):
|
||||
|
@ -313,6 +313,7 @@ class SerializationMixin:
|
||||
not TEST_DILL or not HAS_DILL_AT_LEAST_0_3_1,
|
||||
'"dill" not found or not correct version'
|
||||
)
|
||||
@skipIfTorchDynamo("Different behavior between 3.11 and 3.13, causing CI issues")
|
||||
def test_serialization_dill(self):
|
||||
x = torch.randn(5, 5)
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import collections
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
@ -20,7 +21,7 @@ from .bytecode_transformation import (
|
||||
from .codegen import PyCodegen
|
||||
from .exc import unimplemented
|
||||
from .source import GlobalSource, LocalCellSource, LocalSource, Source
|
||||
from .utils import is_frozen_dataclass, nn_module_new, object_new
|
||||
from .utils import dict_new, is_frozen_dataclass, nn_module_new, object_new
|
||||
from .variables.base import (
|
||||
AttributeMutation,
|
||||
AttributeMutationExisting,
|
||||
@ -37,6 +38,17 @@ def _manual_update_dict(dict_from, dict_to):
|
||||
dict_to[k] = v
|
||||
|
||||
|
||||
def _manual_dict_setitem(dict_from, dict_to, mro_index):
|
||||
# Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have
|
||||
# to be careful because we don't want to trigger the user defined object
|
||||
# setitem or clear. The mro_index is used to find the dict/OrderedDict from
|
||||
# the class mro.
|
||||
dict_class = type(dict_to).__mro__[mro_index]
|
||||
dict_class.clear(dict_to)
|
||||
for k, v in dict_from.items():
|
||||
dict_class.__setitem__(dict_to, k, v)
|
||||
|
||||
|
||||
class SideEffects:
|
||||
"""
|
||||
Track side effects (list mutation, setattr, etc) that need to be
|
||||
@ -181,9 +193,9 @@ class SideEffects:
|
||||
|
||||
@staticmethod
|
||||
def cls_supports_mutation_side_effects(cls):
|
||||
return (
|
||||
inspect.getattr_static(cls, "__getattribute__", None)
|
||||
is object.__getattribute__
|
||||
return inspect.getattr_static(cls, "__getattribute__", None) in (
|
||||
object.__getattribute__,
|
||||
dict.__getattribute__,
|
||||
)
|
||||
|
||||
def is_attribute_mutation(self, item):
|
||||
@ -254,6 +266,8 @@ class SideEffects:
|
||||
obj = torch.autograd.Function()
|
||||
elif issubclass(user_cls, torch.nn.Module):
|
||||
obj = nn_module_new(user_cls)
|
||||
elif issubclass(user_cls, (dict, collections.OrderedDict)):
|
||||
obj = dict_new(user_cls)
|
||||
else:
|
||||
try:
|
||||
obj = object_new(user_cls)
|
||||
@ -284,6 +298,8 @@ class SideEffects:
|
||||
] = variables.UserDefinedObjectVariable
|
||||
if issubclass(user_cls, torch.nn.Module):
|
||||
variable_cls = variables.UnspecializedNNModuleVariable
|
||||
elif issubclass(user_cls, (dict, collections.OrderedDict)):
|
||||
variable_cls = variables.UserDefinedDictVariable
|
||||
elif issubclass(user_cls, MutableMapping):
|
||||
variable_cls = variables.MutableMappingVariable
|
||||
elif is_frozen_dataclass(user_cls):
|
||||
@ -442,7 +458,12 @@ class SideEffects:
|
||||
if isinstance(var, variables.AutogradFunctionContextVariable):
|
||||
unimplemented("AutogradFunctionContextVariable escaped")
|
||||
cg.add_push_null(
|
||||
lambda: cg.load_import_from(utils.__name__, "object_new")
|
||||
lambda: cg.load_import_from(
|
||||
utils.__name__,
|
||||
"dict_new"
|
||||
if isinstance(var, variables.UserDefinedDictVariable)
|
||||
else "object_new",
|
||||
)
|
||||
)
|
||||
cg(var.mutation_type.cls_source)
|
||||
cg.extend_output(create_call_function(1, False))
|
||||
@ -695,6 +716,58 @@ class SideEffects:
|
||||
suffixes.append([cg.create_store_deref(var.local_name)])
|
||||
|
||||
elif self.is_attribute_mutation(var):
|
||||
if isinstance(var, variables.UserDefinedDictVariable):
|
||||
# Do dict related update manually here. The store_attr
|
||||
# mutations will be applied later.
|
||||
varname_map = {}
|
||||
for name in _manual_dict_setitem.__code__.co_varnames:
|
||||
varname_map[name] = cg.tx.output.new_var()
|
||||
|
||||
try:
|
||||
mro_index = type(var.value).__mro__.index(
|
||||
collections.OrderedDict
|
||||
)
|
||||
except ValueError:
|
||||
mro_index = type(var.value).__mro__.index(dict)
|
||||
|
||||
cg.extend_output(
|
||||
[
|
||||
create_instruction("LOAD_CONST", argval=mro_index),
|
||||
create_instruction(
|
||||
"STORE_FAST", argval=varname_map["mro_index"]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
cg(var.source) # type: ignore[attr-defined]
|
||||
cg.extend_output(
|
||||
[
|
||||
create_instruction(
|
||||
"STORE_FAST", argval=varname_map["dict_to"]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
cg(var._dict_vt, allow_cache=False) # Don't codegen via source
|
||||
cg.extend_output(
|
||||
[
|
||||
create_instruction(
|
||||
"STORE_FAST", argval=varname_map["dict_from"]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
dict_update_insts = bytecode_from_template(
|
||||
_manual_dict_setitem, varname_map=varname_map
|
||||
)
|
||||
|
||||
suffixes.append(
|
||||
[
|
||||
*dict_update_insts,
|
||||
create_instruction("POP_TOP"),
|
||||
]
|
||||
)
|
||||
|
||||
# Applying mutations involves two steps: 1) Push all
|
||||
# reconstructed objects onto the stack. 2) Call STORE_ATTR to
|
||||
# apply the mutations.
|
||||
|
@ -1896,6 +1896,14 @@ tuple_iterator: Type[Iterator[Any]] = type(iter(()))
|
||||
range_iterator: Type[Iterator[Any]] = type(iter(range(0)))
|
||||
tuple_iterator_len = tuple_iterator.__length_hint__ # type: ignore[attr-defined]
|
||||
object_new = object.__new__
|
||||
dict_new = dict.__new__
|
||||
dict_methods = {
|
||||
method
|
||||
for method in itertools.chain(
|
||||
dict.__dict__.values(), collections.OrderedDict.__dict__.values()
|
||||
)
|
||||
if callable(method)
|
||||
}
|
||||
|
||||
|
||||
def nn_module_new(cls):
|
||||
|
@ -113,6 +113,7 @@ from .user_defined import (
|
||||
MutableMappingVariable,
|
||||
RemovableHandleVariable,
|
||||
UserDefinedClassVariable,
|
||||
UserDefinedDictVariable,
|
||||
UserDefinedObjectVariable,
|
||||
)
|
||||
|
||||
|
@ -227,6 +227,7 @@ from .user_defined import (
|
||||
MutableMappingVariable,
|
||||
SourcelessGraphModuleVariable,
|
||||
UserDefinedClassVariable,
|
||||
UserDefinedDictVariable,
|
||||
UserDefinedObjectVariable,
|
||||
)
|
||||
|
||||
@ -1232,6 +1233,46 @@ class VariableBuilder:
|
||||
fake_script_obj,
|
||||
source=self.source,
|
||||
)
|
||||
elif (
|
||||
isinstance(value, (dict, collections.OrderedDict))
|
||||
and type(value).__new__ is dict.__new__
|
||||
):
|
||||
# Construct a dict_vt that will reside inside the UserDefinedDictVariable
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
|
||||
|
||||
# Guard on the key order
|
||||
self.tx.output.guard_on_key_order.add(self.source.name())
|
||||
|
||||
# We need all the keys to be hashable. We do this within the
|
||||
# _HashableTracker class in dicts.py
|
||||
def build_key_value(i, k, v):
|
||||
source_key = ConstDictKeySource(self.get_source(), i)
|
||||
key = LazyVariableTracker.create(k, source_key)
|
||||
|
||||
source_value = GetItemSource(self.get_source(), source_key)
|
||||
value = LazyVariableTracker.create(v, source_value)
|
||||
|
||||
return key, value
|
||||
|
||||
result = dict(
|
||||
build_key_value(i, k, v) for i, (k, v) in enumerate(value.items())
|
||||
)
|
||||
|
||||
# NB: This is deliberately kept ValueMutationNew because dict_vt is
|
||||
# an internal representation. dict_vt tracks the mutation on the
|
||||
# dict side. side_effects infra uses the UserDefinedDictVariable to
|
||||
# apply side-effects of this dict_vt.
|
||||
dict_vt = ConstDictVariable(
|
||||
result,
|
||||
user_cls=collections.OrderedDict
|
||||
if isinstance(value, collections.OrderedDict)
|
||||
else dict,
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
|
||||
result = UserDefinedDictVariable(value, dict_vt=dict_vt, source=self.source)
|
||||
return self.tx.output.side_effects.track_object_existing(value, result)
|
||||
elif issubclass(type(value), MutableMapping):
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
return MutableMappingVariable(value, source=self.source)
|
||||
|
@ -283,7 +283,14 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
arg_hashable = args and is_hashable(args[0])
|
||||
|
||||
if name == "__getitem__":
|
||||
if name == "__init__":
|
||||
temp_dict_vt = variables.BuiltinVariable(dict).call_dict(
|
||||
tx, *args, **kwargs
|
||||
)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.update(temp_dict_vt.items)
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "__getitem__":
|
||||
assert len(args) == 1
|
||||
return self.getitem_const_raise_exception_if_absent(tx, args[0])
|
||||
elif name == "items":
|
||||
|
@ -236,6 +236,11 @@ class SuperVariable(VariableTracker):
|
||||
self.objvar, attr, variables.DeletedVariable()
|
||||
)
|
||||
return variables.ConstantVariable(None)
|
||||
elif (
|
||||
isinstance(self.objvar, variables.UserDefinedDictVariable)
|
||||
and inner_fn in self.objvar._dict_methods
|
||||
):
|
||||
return self.objvar._dict_vt.call_method(tx, name, args, kwargs)
|
||||
|
||||
unimplemented(f"non-function or method super: {inner_fn}")
|
||||
|
||||
|
@ -42,6 +42,7 @@ from ..utils import (
|
||||
build_checkpoint_variable,
|
||||
build_invoke_subgraph_variable,
|
||||
check_constant_args,
|
||||
dict_methods,
|
||||
get_custom_getattr,
|
||||
has_torch_function,
|
||||
is_frozen_dataclass,
|
||||
@ -637,7 +638,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
new_fn = inspect.getattr_static(self.value, "__new__", None)
|
||||
if isinstance(new_fn, staticmethod):
|
||||
new_fn = new_fn.__func__
|
||||
return new_fn in (object.__new__, Generic.__new__)
|
||||
return new_fn in (object.__new__, Generic.__new__, dict.__new__)
|
||||
|
||||
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||
if self.source:
|
||||
@ -1441,6 +1442,42 @@ class RemovableHandleVariable(VariableTracker):
|
||||
return RemovableHandleClass
|
||||
|
||||
|
||||
class UserDefinedDictVariable(UserDefinedObjectVariable):
|
||||
"""
|
||||
Represents user defined objects that are subclasses of dict/OrderedDict.
|
||||
|
||||
Internally, it uses a ConstDictVariable to represent the dict part of the
|
||||
variable tracker. For everything else, it falls back to
|
||||
UserDefinedObjectVariable.
|
||||
"""
|
||||
|
||||
_nonvar_fields = UserDefinedObjectVariable._nonvar_fields
|
||||
|
||||
def __init__(self, value, dict_vt=None, **kwargs):
|
||||
super().__init__(value, **kwargs)
|
||||
self._dict_vt = dict_vt
|
||||
if self._dict_vt is None:
|
||||
assert (
|
||||
self.source is None
|
||||
), "dict_vt must be constructed by builder.py when source is present"
|
||||
self._dict_vt = variables.ConstDictVariable(
|
||||
{}, mutation_type=ValueMutationNew()
|
||||
)
|
||||
self._dict_methods = dict_methods
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "List[VariableTracker]",
|
||||
kwargs: "Dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
method = self._maybe_get_baseclass_method(name)
|
||||
if method in self._dict_methods:
|
||||
return self._dict_vt.call_method(tx, name, args, kwargs)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
||||
class MutableMappingVariable(UserDefinedObjectVariable):
|
||||
_nonvar_fields = UserDefinedObjectVariable._nonvar_fields
|
||||
|
||||
|
Reference in New Issue
Block a user