Make variables in dict LazyTrackers (not lazily guarded yet) and avoid using DICT_KEYS guard (#117625)

Make the variables stored in dicts lazy and remove the DICT_KEYS guard.

We build the keys of a dict depth-first and rely on the guards of each
element in the dict to create the correct guards. This allows us to
remove the rather buggy DICT_KEYS guard and make the per-element guards
lazy. The guards are not completely lazy yet, as we still instantiate
them in `_HashableTracker._eq_impl`, but it should be possible to make
them truly lazy.
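
As a rough illustration of the per-key idea, here is a toy sketch (this is not torch._dynamo code; `DictKeySource`, `LazyTracker`, and the guard strings are invented stand-ins for the real machinery):

```python
# Toy sketch of per-key lazy tracking -- not torch._dynamo code.
from dataclasses import dataclass, field
from typing import Any, List


@dataclass(frozen=True)
class DictKeySource:
    # Names "the i-th key of dict <base>" without inspecting the key itself.
    base: str
    index: int

    def name(self) -> str:
        return f"list({self.base}.keys())[{self.index}]"


@dataclass
class LazyTracker:
    # Holds the key and its source; a guard is emitted only when the key is
    # realized (roughly what currently happens in _HashableTracker._eq_impl).
    value: Any
    source: DictKeySource
    emitted: List[str] = field(default_factory=list)

    def realize(self) -> Any:
        if not self.emitted:
            self.emitted.append(f"EQUALS_MATCH({self.source.name()}, {self.value!r})")
        return self.value


def track_keys(base: str, d: dict) -> List[LazyTracker]:
    # One tracker per key replaces a single DICT_KEYS-style guard over all keys.
    return [LazyTracker(k, DictKeySource(base, i)) for i, k in enumerate(d)]


trackers = track_keys("L['d']", {"a": 1, 2: 3})
print(trackers[1].realize())  # 2
print(trackers[1].emitted)    # ["EQUALS_MATCH(list(L['d'].keys())[1], 2)"]
```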

This also makes it less error-prone to add new supported types for dict
keys.
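
As a simplified illustration of that extensibility: with a single wrapper in charge of hashing and equality, a new key type mostly needs a comparison rule. The classes below (`Tracked`, `HashableKey`) are invented stand-ins for the real `_HashableTracker` machinery, not the actual implementation:

```python
# Simplified stand-in for the _HashableTracker idea -- not the real class.
class Tracked:
    # Minimal "tracker" that wraps an arbitrary Python value.
    def __init__(self, value):
        self.value = value


class HashableKey:
    # Centralizes hashing/equality for dict keys; supporting a new key type
    # only means defining how two such keys compare here, instead of keeping
    # a parallel is_hashable_python_var and a monolithic DICT_KEYS guard in sync.
    def __init__(self, tracked: Tracked):
        self.tracked = tracked

    def _underlying(self):
        return self.tracked.value

    def __hash__(self):
        return hash(self._underlying())

    def __eq__(self, other):
        if not isinstance(other, HashableKey):
            return NotImplemented
        return self._underlying() == other._underlying()


d = {HashableKey(Tracked("weight")): 0, HashableKey(Tracked(3)): 1}
print(d[HashableKey(Tracked(3))])  # 1
```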

This is marginally less efficient when we graph break, but in turn we
should graph break much less. It also makes the dict handling code easier
to maintain (it removes `is_hashable_python_var`).
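
The user-visible effect is roughly what the new `test_dict_namedtuple` test in the diff below checks: a dict keyed by a non-literal object (here the type `collections.namedtuple`) compiles with per-key guards, and swapping the keys triggers exactly one recompile. A minimal standalone version of that scenario:

```python
import collections

import torch
import torch._dynamo
import torch._dynamo.testing


def fn(d):
    return d[3] * 2


cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)

args1 = {collections.namedtuple: None, 3: torch.randn(3)}
assert torch.equal(fn(args1), opt_fn(args1))
assert cnts.frame_count == 1

# Different keys -> the per-key guard fails and we recompile once.
args2 = {2: None, 3: torch.randn(3)}
assert torch.equal(fn(args2), opt_fn(args2))
assert cnts.frame_count == 2
```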

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117625
Approved by: https://github.com/jansel, https://github.com/peterbell10, https://github.com/anijain2305
ghstack dependencies: #117982, #118098, #117983
Author: lezcano
Date: 2024-02-01 10:35:50 +00:00
Committed by: PyTorch MergeBot
Parent: 75a5c41921
Commit: eb2bdfae88
11 changed files with 62 additions and 80 deletions


@@ -322,22 +322,22 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
actual_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_d_x_ : torch.Tensor, L_d_y_0_ : torch.Tensor, L_d_y_1_2_ : torch.Tensor):
l_d_x_ = L_d_x_
l_d_y_0_ = L_d_y_0_
l_d_y_1_2_ = L_d_y_1_2_
def forward(self, L_d_dict_keys_getitem_L_d_0_ : torch.Tensor, L_d_dict_keys_getitem_L_d_1_0_ : torch.Tensor, L_d_dict_keys_getitem_L_d_1_1_2_ : torch.Tensor):
l_d_dict_keys_getitem_l_d_0_ = L_d_dict_keys_getitem_L_d_0_
l_d_dict_keys_getitem_l_d_1_0_ = L_d_dict_keys_getitem_L_d_1_0_
l_d_dict_keys_getitem_l_d_1_1_2_ = L_d_dict_keys_getitem_L_d_1_1_2_
wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_d_x_, l_d_y_0_, l_d_y_1_2_); wrap_body_0 = l_d_x_ = l_d_y_0_ = l_d_y_1_2_ = None
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_d_dict_keys_getitem_l_d_0_, l_d_dict_keys_getitem_l_d_1_0_, l_d_dict_keys_getitem_l_d_1_1_2_); wrap_body_0 = l_d_dict_keys_getitem_l_d_0_ = l_d_dict_keys_getitem_l_d_1_0_ = l_d_dict_keys_getitem_l_d_1_1_2_ = None
getitem = wrap[0]; wrap = None
return (getitem,)
class GraphModule(torch.nn.Module):
def forward(self, l_d_x_, l_d_y_0_, l_d_y_1_2_):
sin = l_d_x_.sin(); l_d_x_ = None
cos = l_d_y_0_.cos(); l_d_y_0_ = None
def forward(self, l_d_dict_keys_getitem_l_d_0_, l_d_dict_keys_getitem_l_d_1_0_, l_d_dict_keys_getitem_l_d_1_1_2_):
sin = l_d_dict_keys_getitem_l_d_0_.sin(); l_d_dict_keys_getitem_l_d_0_ = None
cos = l_d_dict_keys_getitem_l_d_1_0_.cos(); l_d_dict_keys_getitem_l_d_1_0_ = None
add = sin + cos; sin = cos = None
sin_1 = l_d_y_1_2_.sin(); l_d_y_1_2_ = None
sin_1 = l_d_dict_keys_getitem_l_d_1_1_2_.sin(); l_d_dict_keys_getitem_l_d_1_1_2_ = None
sub = add - sin_1; add = sin_1 = None
return (sub,)
""", # NOQA: B950
@@ -2338,7 +2338,7 @@ class GraphModule(torch.nn.Module):
self.assertExpectedInline(
graph.code.strip(),
"""\
def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytree_in_1_0_0_0_ : torch.Tensor, L_pytree_in_2_ : torch.Tensor, L_pytree_in_3_0_ : torch.Tensor, L_pytree_in_3_1_0_ : torch.Tensor, L_pytree_in_3_2_ : torch.Tensor, L_pytree_in_4_g_ : torch.Tensor):
def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytree_in_1_0_0_0_ : torch.Tensor, L_pytree_in_2_ : torch.Tensor, L_pytree_in_3_0_ : torch.Tensor, L_pytree_in_3_1_0_ : torch.Tensor, L_pytree_in_3_2_ : torch.Tensor, L_pytree_in_4_dict_keys_getitem_L_pytree_in_4_0_ : torch.Tensor):
l_pred_ = L_pred_
l_pytree_in_0_ = L_pytree_in_0_
l_pytree_in_1_0_0_0_ = L_pytree_in_1_0_0_0_
@@ -2346,10 +2346,10 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre
l_pytree_in_3_0_ = L_pytree_in_3_0_
l_pytree_in_3_1_0_ = L_pytree_in_3_1_0_
l_pytree_in_3_2_ = L_pytree_in_3_2_
l_pytree_in_4_g_ = L_pytree_in_4_g_
l_pytree_in_4_dict_keys_getitem_l_pytree_in_4_0_ = L_pytree_in_4_dict_keys_getitem_L_pytree_in_4_0_
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_pytree_in_0_, l_pytree_in_1_0_0_0_, l_pytree_in_2_, l_pytree_in_3_0_, l_pytree_in_3_1_0_, l_pytree_in_3_2_, l_pytree_in_4_g_]); l_pred_ = cond_true_0 = cond_false_0 = l_pytree_in_0_ = l_pytree_in_1_0_0_0_ = l_pytree_in_2_ = l_pytree_in_3_0_ = l_pytree_in_3_1_0_ = l_pytree_in_3_2_ = l_pytree_in_4_g_ = None
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_pytree_in_0_, l_pytree_in_1_0_0_0_, l_pytree_in_2_, l_pytree_in_3_0_, l_pytree_in_3_1_0_, l_pytree_in_3_2_, l_pytree_in_4_dict_keys_getitem_l_pytree_in_4_0_]); l_pred_ = cond_true_0 = cond_false_0 = l_pytree_in_0_ = l_pytree_in_1_0_0_0_ = l_pytree_in_2_ = l_pytree_in_3_0_ = l_pytree_in_3_1_0_ = l_pytree_in_3_2_ = l_pytree_in_4_dict_keys_getitem_l_pytree_in_4_0_ = None
getitem = cond[0]; cond = None
return (getitem,)""", # noqa: B950
)


@@ -2516,6 +2516,20 @@ utils_device.CURRENT_DEVICE == None""".split(
# Extra calls don't recompile
self.assertEqual(cnts.frame_count, 2)
def test_dict_namedtuple(self):
def fn(d):
return d[3] * 2
args1 = {collections.namedtuple: None, 3: torch.randn(3)}
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
self.assertEqual(fn(args1), opt_fn(args1))
self.assertEqual(cnts.frame_count, 1)
# Test a failing namedtuple guard
args2 = {2: None, 3: torch.randn(3)}
self.assertEqual(fn(args2), opt_fn(args2))
self.assertEqual(cnts.frame_count, 2)
def test_dict_order_keys_tensors(self):
def fn(d, x):
return d[x] + 3


@@ -1843,7 +1843,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
handle.remove()
self.assertEqual(compiled_func(inp), outer_func(inp))
self.assertEqual(compiled_func(inp).item(), 7)
self.assertTrue("forward_hooks.keys" in failure_reason)
self.assertTrue("forward_hooks" in failure_reason)
self.assertEqual(cc.frame_count, 1 + 1)
self.assertEqual(cc.op_count, 6 + 4)


@@ -688,7 +688,7 @@ for test_param in supported_tests:
if 'constructor' not in test_param:
name = test_param.pop('module_name')
test_param['constructor'] = getattr(nn, name)
decorator = test_param.pop('decorator', None)
decorator = test_param.pop('decorator', lambda test: test)
test = ContextManagerTests(**test_param)
test_name = test.get_name()
if hasattr(TestExpandedWeightModule, test_name):
@@ -696,16 +696,14 @@ for test_param in supported_tests:
test_name_multi_input = test.get_name() + "_multiple_inputs"
if hasattr(TestExpandedWeightModule, test_name_multi_input):
raise RuntimeError('Found two tests with the same name: ' + test_name)
if decorator is not None:
fn = decorator(fn) # noqa: F821
if test.test_cpu:
setattr(TestExpandedWeightModule, test_name, lambda self, test=test: test.test_context_manager(self, 'cpu'))
setattr(TestExpandedWeightModule, test_name, decorator(lambda self, test=test: test.test_context_manager(self, 'cpu')))
setattr(TestExpandedWeightModule, test_name_multi_input,
lambda self, test=test: test.test_context_manager_multiple_inputs(self, 'cpu'))
decorator(lambda self, test=test: test.test_context_manager_multiple_inputs(self, 'cpu')))
if TEST_CUDA and test.test_cuda:
# since this checks derivatives, only use double for precision
setattr(TestExpandedWeightModule, test_name + '_cuda_double',
lambda self, test=test: test.test_context_manager(self, 'cuda'))
decorator(lambda self, test=test: test.test_context_manager(self, 'cuda')))
# ------------- HELPER FUNCTIONS -----------------


@@ -668,7 +668,7 @@ class GuardBuilder(GuardBuilderBase):
self._produce_guard_code(guard, [shape_guard], shape_env=True)
def TENSOR_MATCH(self, guard: Guard, value=None):
if guard.is_nn_module():
if guard.is_nn_module() or guard.originating_source.is_dict_key():
self.ID_MATCH(guard)
else:
if isinstance(value, TensorWeakRef):


@@ -346,6 +346,9 @@ class GetItemSource(ChainedSource):
@dataclasses.dataclass(frozen=True)
class ConstDictKeySource(GetItemSource):
def is_dict_key(self):
return True
def reconstruct(self, codegen):
return [
*codegen.create_load_import_from(utils.__name__, "dict_keys_getitem"),


@@ -96,7 +96,6 @@ from .dicts import (
DataClassVariable,
DefaultDictVariable,
HFPretrainedConfigVariable,
is_hashable_python_var,
PythonSysModulesVariable,
SetVariable,
)
@@ -413,9 +412,7 @@ class VariableBuilder:
return ConstDictVariable(result, type(value))
elif value is sys.modules:
return PythonSysModulesVariable(source=self.source)
elif istype(
value, (dict, collections.defaultdict, collections.OrderedDict)
) and all(is_hashable_python_var(k) for k in value.keys()):
elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
if not value and self.get_source().is_nn_module():
# It is faster to guard on 'false' property than to guard
# on actual dict keys, but we can't do this fast guard in general because
@@ -426,26 +423,22 @@ class VariableBuilder:
# but not completely secure job ensuring a property wasn't changed.
self.install_guards(GuardBuilder.BOOL_FALSE)
else:
self.install_guards(GuardBuilder.DICT_KEYS)
self.install_guards(GuardBuilder.LIST_LENGTH)
idx = 0
def build_key_value(k, v):
nonlocal idx
if ConstantVariable.is_literal(k):
key = ConstantVariable.create(k)
source_key = k
else:
source_key = ConstDictKeySource(self.get_source(), idx)
key = VariableBuilder(self.tx, source_key)(k)
# We need all the keys to be hashable. We do this within the
# _HashableTracker class in dicts.py
def build_key_value(i, k, v):
source_key = ConstDictKeySource(self.get_source(), i)
key = LazyVariableTracker.create(k, source_key)
source_value = GetItemSource(self.get_source(), source_key)
value = LazyVariableTracker.create(v, source_value)
idx += 1
return key, value
result = dict(build_key_value(k, v) for k, v in value.items())
result = dict(
build_key_value(i, k, v) for i, (k, v) in enumerate(value.items())
)
if istype(value, collections.defaultdict):
result = DefaultDictVariable(


@@ -2,15 +2,11 @@
import collections
import dataclasses
import enum
import functools
import inspect
import sys
from types import MethodWrapperType
from typing import Dict, List, Optional
import torch
from torch._subclasses.fake_tensor import is_fake
from .. import variables
@@ -28,32 +24,12 @@ from ..utils import dict_keys, dict_values, istype, specialize_symnode
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
# Note: [Adding a new supported class to the keys of ConstDictVariable]
# You'll need to add it to:
# - `is_hashable_python_var` in this file
# - `is_hashable` in this file
# - `const_repr` in util.py, and perhaps modify DICT_KEYS in guards.py
def is_hashable_python_var(x):
# IMPORTANT: Keep me in sync with is_hashable!
# Even better, we should have a map of functions connecting the two
from torch import Tensor
from ..trace_rules import is_builtin_callable, is_numpy
return (
ConstantVariable.is_literal(x)
or isinstance(x, (Tensor, enum.Enum, type, torch.nn.Module, MethodWrapperType))
or is_builtin_callable(x)
or (isinstance(x, tuple) and all(is_hashable_python_var(e) for e in x))
or is_numpy(x)
)
# [Adding a new supported class within the keys of ConstDictVariable]
# - Add its tracker type to is_hashable
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
def is_hashable(x):
# IMPORTANT: Keep me in sync with is_hashable_python_var!
# Even better, we should have a map of functions connecting the two
if isinstance(x, variables.TensorVariable):
# Tensors are hashable if they have an example_value (a fake tensor)
# Most VT's should have one.
@@ -201,16 +177,6 @@ class ConstDictVariable(VariableTracker):
else:
return [create_instruction("BUILD_MAP", arg=len(self.items))]
@staticmethod
def _wrap_keys_python_var(d):
"""Wrap the keys of a dictionary with python objs as keys into Hashable objects"""
assert all(is_hashable_python_var(k) for k in d.keys())
Hashable = ConstDictVariable._HashableTracker
from .builder import SourcelessBuilder
build = SourcelessBuilder()
return {Hashable(build(k)): v for k, v in d.items()}
def getitem_const(self, arg: VariableTracker):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
@@ -290,8 +256,10 @@ class ConstDictVariable(VariableTracker):
else:
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
self.items.update(dict_vt.items)
# all keys in kwargs are valid (`str`s)
kwargs = ConstDictVariable._wrap_keys_python_var(kwargs)
# Wrap strings
kwargs = {
Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items()
}
self.items.update(kwargs)
return ConstantVariable.create(None)
elif name in ("get", "__getattr__") and args[0] in self:


@@ -765,6 +765,9 @@ def tracing(context: Optional[TracingContext]):
# TODO(voz): Consider a toplevel torch/_source.py
@dataclasses.dataclass(frozen=True)
class Source:
def is_dict_key(self):
return False
def reconstruct(self, codegen):
raise NotImplementedError()
@@ -788,6 +791,10 @@ class Source:
class ChainedSource(Source):
base: Source
def is_dict_key(self):
# Recurse until you either hit a ConstDictKey or a Source
return self.base.is_dict_key()
def detect_fake_mode(inputs: Any = None):
"""


@@ -16,7 +16,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn import _reduction as _Reduction
from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
gradcheck, gradgradcheck, set_default_dtype
gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo
from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
from torch.autograd.gradcheck import _get_numerical_jacobian, _iter_tensors
from torch.autograd import Variable
@@ -1709,6 +1709,7 @@ new_module_tests = [
input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
check_gradgrad=False,
default_dtype=torch.double,
decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
),
dict(
module_name='Embedding',
@@ -1718,6 +1719,7 @@ new_module_tests = [
check_gradgrad=False,
desc='discontiguous',
default_dtype=torch.double,
decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
),
dict(
module_name='EmbeddingBag',


@@ -1611,7 +1611,6 @@ dynamo_expected_failures = {
"TestTorchTidyProfiler.test_tensorimpl_invalidation_full", # profiler/test_profiler
"TestProfiler.test_profiler_tracing", # profiler/test_profiler
"TestProfiler.test_is_profiler_enabled", # profiler/test_profiler
"TestExperimentalUtils.test_utils_compute_idle_time", # profiler/test_profiler
"TestTorchTidyProfiler.test_optimizer_parameters_sgd", # profiler/test_profiler
"TestExperimentalUtils.test_profiler_name_pattern", # profiler/test_profiler
"TestTorchTidyProfiler.test_extra_fields", # profiler/test_profiler
@@ -1639,13 +1638,11 @@ dynamo_expected_failures = {
"TestTorchTidyProfiler.test_sparse_tensors", # profiler/test_profiler
"TestTorchTidyProfiler.test_optimizer", # profiler/test_profiler
"TestTorchTidyProfiler.test_tensorimpl_invalidation_keep_alive", # profiler/test_profiler
"TestExperimentalUtils.test_utils_compute_queue_depth", # profiler/test_profiler
"TestExperimentalUtils.test_profiler_pattern_match_helper", # profiler/test_profiler
"TestProfiler.test_export_stacks", # profiler/test_profiler
"TestProfiler.test_source_multithreaded_basic_work_in_main_thread_True", # profiler/test_profiler
"TestTorchTidyProfiler.test_mkldnn_tensors", # profiler/test_profiler
"TestRecordFunction.test_datapipe_with_record_function", # profiler/test_profiler
"TestProfiler.test_memory_profiler", # profiler/test_profiler
"TestTorchTidyProfiler.test_tensor_lists", # profiler/test_profiler
"TestTorchTidyProfiler.test_pointers_and_ids", # profiler/test_profiler
"TestTorchTidyProfiler.test_nnmodule_params", # profiler/test_profiler