Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Make variables in dict LazyTrackers (not lazily guarded yet) and avoid using DICT_KEYS guard (#117625)
Make variables in dict lazy and remove the DICT_KEYS guard. We build the keys of a dict depth-first and rely on the guards of each element in the dict to create the correct guards. This allows us to remove the rather buggy DICT_KEYS guard and make the guards lazy. The guards are not completely lazy yet, as we instantiate them in `_HashableTracker._eq_impl`, but it should be possible to make them truly lazy. Adding new types to the supported key types should also be less error prone.

This is marginally less efficient when we graph break, but in turn we should graph break much less. It also makes the dict code easier to maintain (removes `is_hashable_python_var`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117625
Approved by: https://github.com/jansel, https://github.com/peterbell10, https://github.com/anijain2305
ghstack dependencies: #117982, #118098, #117983
Committed by: PyTorch MergeBot
Parent: 75a5c41921
Commit: eb2bdfae88
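Before the diff, a minimal self-contained sketch of the idea described in the commit message. The names LazyTracker and track_dict are hypothetical stand-ins for Dynamo's LazyVariableTracker / VariableBuilder machinery, not the real API: each key and value of a dict gets its own source and its own lazily realized tracker, so correctness comes from per-element guards rather than from one DICT_KEYS guard over the whole dict.

# Illustrative sketch only -- hypothetical names, not the real Dynamo classes.
class LazyTracker:
    """Wraps a Python value plus the source it came from; guards are built lazily."""

    def __init__(self, value, source):
        self.value = value
        self.source = source   # e.g. "dict_keys_getitem(d, 0)" or "d[k]"
        self.guards = None     # nothing installed yet -- this is the lazy part

    def realize_guards(self):
        # In Dynamo this is where EQUALS_MATCH / TENSOR_MATCH / ID_MATCH etc.
        # would be installed for this one element.
        if self.guards is None:
            self.guards = [f"{self.source} matches {type(self.value).__name__}"]
        return self.guards


def track_dict(d, dict_source="d"):
    # Keys are built depth-first, one source per key index; each value hangs
    # off its key's source.  No monolithic guard over d.keys() is needed.
    tracked = {}
    for i, (k, v) in enumerate(d.items()):
        key_source = f"dict_keys_getitem({dict_source}, {i})"
        tracked[LazyTracker(k, key_source)] = LazyTracker(v, f"{dict_source}[{key_source}]")
    return tracked

In the actual change below, the key source is ConstDictKeySource(self.get_source(), i), the value source is GetItemSource(self.get_source(), source_key), and both are wrapped with LazyVariableTracker.create; as the message notes, the guards are still forced in _HashableTracker._eq_impl rather than fully lazily.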
@@ -322,22 +322,22 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
actual_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_d_x_ : torch.Tensor, L_d_y_0_ : torch.Tensor, L_d_y_1_2_ : torch.Tensor):
l_d_x_ = L_d_x_
l_d_y_0_ = L_d_y_0_
l_d_y_1_2_ = L_d_y_1_2_
def forward(self, L_d_dict_keys_getitem_L_d_0_ : torch.Tensor, L_d_dict_keys_getitem_L_d_1_0_ : torch.Tensor, L_d_dict_keys_getitem_L_d_1_1_2_ : torch.Tensor):
l_d_dict_keys_getitem_l_d_0_ = L_d_dict_keys_getitem_L_d_0_
l_d_dict_keys_getitem_l_d_1_0_ = L_d_dict_keys_getitem_L_d_1_0_
l_d_dict_keys_getitem_l_d_1_1_2_ = L_d_dict_keys_getitem_L_d_1_1_2_

wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_d_x_, l_d_y_0_, l_d_y_1_2_); wrap_body_0 = l_d_x_ = l_d_y_0_ = l_d_y_1_2_ = None
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_d_dict_keys_getitem_l_d_0_, l_d_dict_keys_getitem_l_d_1_0_, l_d_dict_keys_getitem_l_d_1_1_2_); wrap_body_0 = l_d_dict_keys_getitem_l_d_0_ = l_d_dict_keys_getitem_l_d_1_0_ = l_d_dict_keys_getitem_l_d_1_1_2_ = None
getitem = wrap[0]; wrap = None
return (getitem,)

class GraphModule(torch.nn.Module):
def forward(self, l_d_x_, l_d_y_0_, l_d_y_1_2_):
sin = l_d_x_.sin(); l_d_x_ = None
cos = l_d_y_0_.cos(); l_d_y_0_ = None
def forward(self, l_d_dict_keys_getitem_l_d_0_, l_d_dict_keys_getitem_l_d_1_0_, l_d_dict_keys_getitem_l_d_1_1_2_):
sin = l_d_dict_keys_getitem_l_d_0_.sin(); l_d_dict_keys_getitem_l_d_0_ = None
cos = l_d_dict_keys_getitem_l_d_1_0_.cos(); l_d_dict_keys_getitem_l_d_1_0_ = None
add = sin + cos; sin = cos = None
sin_1 = l_d_y_1_2_.sin(); l_d_y_1_2_ = None
sin_1 = l_d_dict_keys_getitem_l_d_1_1_2_.sin(); l_d_dict_keys_getitem_l_d_1_1_2_ = None
sub = add - sin_1; add = sin_1 = None
return (sub,)
""", # NOQA: B950
@@ -2338,7 +2338,7 @@ class GraphModule(torch.nn.Module):
self.assertExpectedInline(
graph.code.strip(),
"""\
def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytree_in_1_0_0_0_ : torch.Tensor, L_pytree_in_2_ : torch.Tensor, L_pytree_in_3_0_ : torch.Tensor, L_pytree_in_3_1_0_ : torch.Tensor, L_pytree_in_3_2_ : torch.Tensor, L_pytree_in_4_g_ : torch.Tensor):
def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytree_in_1_0_0_0_ : torch.Tensor, L_pytree_in_2_ : torch.Tensor, L_pytree_in_3_0_ : torch.Tensor, L_pytree_in_3_1_0_ : torch.Tensor, L_pytree_in_3_2_ : torch.Tensor, L_pytree_in_4_dict_keys_getitem_L_pytree_in_4_0_ : torch.Tensor):
l_pred_ = L_pred_
l_pytree_in_0_ = L_pytree_in_0_
l_pytree_in_1_0_0_0_ = L_pytree_in_1_0_0_0_
@@ -2346,10 +2346,10 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre
l_pytree_in_3_0_ = L_pytree_in_3_0_
l_pytree_in_3_1_0_ = L_pytree_in_3_1_0_
l_pytree_in_3_2_ = L_pytree_in_3_2_
l_pytree_in_4_g_ = L_pytree_in_4_g_
l_pytree_in_4_dict_keys_getitem_l_pytree_in_4_0_ = L_pytree_in_4_dict_keys_getitem_L_pytree_in_4_0_
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_pytree_in_0_, l_pytree_in_1_0_0_0_, l_pytree_in_2_, l_pytree_in_3_0_, l_pytree_in_3_1_0_, l_pytree_in_3_2_, l_pytree_in_4_g_]); l_pred_ = cond_true_0 = cond_false_0 = l_pytree_in_0_ = l_pytree_in_1_0_0_0_ = l_pytree_in_2_ = l_pytree_in_3_0_ = l_pytree_in_3_1_0_ = l_pytree_in_3_2_ = l_pytree_in_4_g_ = None
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_pytree_in_0_, l_pytree_in_1_0_0_0_, l_pytree_in_2_, l_pytree_in_3_0_, l_pytree_in_3_1_0_, l_pytree_in_3_2_, l_pytree_in_4_dict_keys_getitem_l_pytree_in_4_0_]); l_pred_ = cond_true_0 = cond_false_0 = l_pytree_in_0_ = l_pytree_in_1_0_0_0_ = l_pytree_in_2_ = l_pytree_in_3_0_ = l_pytree_in_3_1_0_ = l_pytree_in_3_2_ = l_pytree_in_4_dict_keys_getitem_l_pytree_in_4_0_ = None
getitem = cond[0]; cond = None
return (getitem,)""", # noqa: B950
)
@@ -2516,6 +2516,20 @@ utils_device.CURRENT_DEVICE == None""".split(
# Extra calls don't recompile
self.assertEqual(cnts.frame_count, 2)

def test_dict_namedtuple(self):
def fn(d):
return d[3] * 2

args1 = {collections.namedtuple: None, 3: torch.randn(3)}
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
self.assertEqual(fn(args1), opt_fn(args1))
self.assertEqual(cnts.frame_count, 1)
# Test a failing namedtuple guard
args2 = {2: None, 3: torch.randn(3)}
self.assertEqual(fn(args2), opt_fn(args2))
self.assertEqual(cnts.frame_count, 2)

def test_dict_order_keys_tensors(self):
def fn(d, x):
return d[x] + 3
@@ -1843,7 +1843,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
handle.remove()
self.assertEqual(compiled_func(inp), outer_func(inp))
self.assertEqual(compiled_func(inp).item(), 7)
self.assertTrue("forward_hooks.keys" in failure_reason)
self.assertTrue("forward_hooks" in failure_reason)
self.assertEqual(cc.frame_count, 1 + 1)
self.assertEqual(cc.op_count, 6 + 4)

@@ -688,7 +688,7 @@ for test_param in supported_tests:
if 'constructor' not in test_param:
name = test_param.pop('module_name')
test_param['constructor'] = getattr(nn, name)
decorator = test_param.pop('decorator', None)
decorator = test_param.pop('decorator', lambda test: test)
test = ContextManagerTests(**test_param)
test_name = test.get_name()
if hasattr(TestExpandedWeightModule, test_name):
@@ -696,16 +696,14 @@ for test_param in supported_tests:
test_name_multi_input = test.get_name() + "_multiple_inputs"
if hasattr(TestExpandedWeightModule, test_name_multi_input):
raise RuntimeError('Found two tests with the same name: ' + test_name)
if decorator is not None:
fn = decorator(fn) # noqa: F821
if test.test_cpu:
setattr(TestExpandedWeightModule, test_name, lambda self, test=test: test.test_context_manager(self, 'cpu'))
setattr(TestExpandedWeightModule, test_name, decorator(lambda self, test=test: test.test_context_manager(self, 'cpu')))
setattr(TestExpandedWeightModule, test_name_multi_input,
lambda self, test=test: test.test_context_manager_multiple_inputs(self, 'cpu'))
decorator(lambda self, test=test: test.test_context_manager_multiple_inputs(self, 'cpu')))
if TEST_CUDA and test.test_cuda:
# since this checks derivatives, only use double for precision
setattr(TestExpandedWeightModule, test_name + '_cuda_double',
lambda self, test=test: test.test_context_manager(self, 'cuda'))
decorator(lambda self, test=test: test.test_context_manager(self, 'cuda')))

# ------------- HELPER FUNCTIONS -----------------

@@ -668,7 +668,7 @@ class GuardBuilder(GuardBuilderBase):
self._produce_guard_code(guard, [shape_guard], shape_env=True)

def TENSOR_MATCH(self, guard: Guard, value=None):
if guard.is_nn_module():
if guard.is_nn_module() or guard.originating_source.is_dict_key():
self.ID_MATCH(guard)
else:
if isinstance(value, TensorWeakRef):
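A hedged reading of the TENSOR_MATCH hunk above: a tensor whose source is a dict key is now guarded the same way an nn.Module attribute is, by object identity, instead of by a full tensor property match. A toy version of that dispatch (id_match and full_tensor_match are hypothetical stand-ins for GuardBuilder.ID_MATCH and the rest of the TENSOR_MATCH body):

# Hypothetical sketch; not the real GuardBuilder API.
def tensor_match(builder, guard, value=None):
    if guard.is_nn_module() or guard.originating_source.is_dict_key():
        # A dict key only needs to be recognized as "the same object",
        # so an identity guard is assumed to be enough here.
        builder.id_match(guard)
    else:
        builder.full_tensor_match(guard, value)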
@@ -346,6 +346,9 @@ class GetItemSource(ChainedSource):

@dataclasses.dataclass(frozen=True)
class ConstDictKeySource(GetItemSource):
def is_dict_key(self):
return True

def reconstruct(self, codegen):
return [
*codegen.create_load_import_from(utils.__name__, "dict_keys_getitem"),
@@ -96,7 +96,6 @@ from .dicts import (
DataClassVariable,
DefaultDictVariable,
HFPretrainedConfigVariable,
is_hashable_python_var,
PythonSysModulesVariable,
SetVariable,
)
@@ -413,9 +412,7 @@ class VariableBuilder:
return ConstDictVariable(result, type(value))
elif value is sys.modules:
return PythonSysModulesVariable(source=self.source)
elif istype(
value, (dict, collections.defaultdict, collections.OrderedDict)
) and all(is_hashable_python_var(k) for k in value.keys()):
elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
if not value and self.get_source().is_nn_module():
# It is faster to guard on 'false' property than to guard
# on actual dict keys, but we can't do this fast guard in general because
@@ -426,26 +423,22 @@ class VariableBuilder:
# but not completely secure job ensuring a property wasn't changed.
self.install_guards(GuardBuilder.BOOL_FALSE)
else:
self.install_guards(GuardBuilder.DICT_KEYS)
self.install_guards(GuardBuilder.LIST_LENGTH)

idx = 0

def build_key_value(k, v):
nonlocal idx
if ConstantVariable.is_literal(k):
key = ConstantVariable.create(k)
source_key = k
else:
source_key = ConstDictKeySource(self.get_source(), idx)
key = VariableBuilder(self.tx, source_key)(k)
# We need all the keys to be hashable. We do this within the
# _HashableTracker class in dicts.py
def build_key_value(i, k, v):
source_key = ConstDictKeySource(self.get_source(), i)
key = LazyVariableTracker.create(k, source_key)

source_value = GetItemSource(self.get_source(), source_key)
value = LazyVariableTracker.create(v, source_value)

idx += 1
return key, value

result = dict(build_key_value(k, v) for k, v in value.items())
result = dict(
build_key_value(i, k, v) for i, (k, v) in enumerate(value.items())
)

if istype(value, collections.defaultdict):
result = DefaultDictVariable(
@@ -2,15 +2,11 @@

import collections
import dataclasses
import enum
import functools
import inspect
import sys
from types import MethodWrapperType
from typing import Dict, List, Optional

import torch

from torch._subclasses.fake_tensor import is_fake

from .. import variables
@@ -28,32 +24,12 @@ from ..utils import dict_keys, dict_values, istype, specialize_symnode
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable

# Note: [Adding a new supported class the keys of ConstDictVarialble]
# You'll need to add it to:
# - `is_hashable_python_var` in this file
# - `is_hashable` in this file
# - `const_repr` in util.py, and perhaps modify DICT_KEYS in guards.py


def is_hashable_python_var(x):
# IMPORTANT: Keep me in sync with is_hashable!
# Even better, we should have a map of functions connecting the two
from torch import Tensor
from ..trace_rules import is_builtin_callable, is_numpy

return (
ConstantVariable.is_literal(x)
or isinstance(x, (Tensor, enum.Enum, type, torch.nn.Module, MethodWrapperType))
or is_builtin_callable(x)
or (isinstance(x, tuple) and all(is_hashable_python_var(e) for e in x))
or is_numpy(x)
)
# [Adding a new supported class within the keys of ConstDictVarialble]
# - Add its tracker type to is_hashable
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl


def is_hashable(x):
# IMPORTANT: Keep me in sync with is_hashable_python_var!
# Even better, we should have a map of functions connecting the two

if isinstance(x, variables.TensorVariable):
# Tensors are hashable if they have an example_value (a fake tensor)
# Most VT's should have one.
@@ -201,16 +177,6 @@ class ConstDictVariable(VariableTracker):
else:
return [create_instruction("BUILD_MAP", arg=len(self.items))]

@staticmethod
def _wrap_keys_python_var(d):
"""Wrap the keys of a dictionary with python objs as keys into Hashable objects"""
assert all(is_hashable_python_var(k) for k in d.keys())
Hashable = ConstDictVariable._HashableTracker
from .builder import SourcelessBuilder

build = SourcelessBuilder()
return {Hashable(build(k)): v for k, v in d.items()}

def getitem_const(self, arg: VariableTracker):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
@@ -290,8 +256,10 @@ class ConstDictVariable(VariableTracker):
else:
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
self.items.update(dict_vt.items)
# all keys in kwargs are valid (`str`s)
kwargs = ConstDictVariable._wrap_keys_python_var(kwargs)
# Wrap strings
kwargs = {
Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items()
}
self.items.update(kwargs)
return ConstantVariable.create(None)
elif name in ("get", "__getattr__") and args[0] in self:
@@ -765,6 +765,9 @@ def tracing(context: Optional[TracingContext]):
# TODO(voz): Consider a toplevel torch/_source.py
@dataclasses.dataclass(frozen=True)
class Source:
def is_dict_key(self):
return False

def reconstruct(self, codegen):
raise NotImplementedError()

@@ -788,6 +791,10 @@ class Source:
class ChainedSource(Source):
base: Source

def is_dict_key(self):
# Recurse until you either hit a ConstDictKey or a Source
return self.base.is_dict_key()


def detect_fake_mode(inputs: Any = None):
"""
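The two hunks above give every Source an is_dict_key() method: the base class answers False and ChainedSource defers to its base, so any source chained on top of a ConstDictKeySource still reports that it ultimately came from a dict key (which is what the TENSOR_MATCH change checks). A small runnable model of that recursion, with a deliberately flattened class hierarchy (in the real code ConstDictKeySource derives from GetItemSource, which is itself a ChainedSource):

# Simplified model -- not the real torch._guards / source classes.
import dataclasses

@dataclasses.dataclass(frozen=True)
class Source:
    def is_dict_key(self):
        return False

@dataclasses.dataclass(frozen=True)
class ChainedSource(Source):
    base: Source

    def is_dict_key(self):
        # Recurse until we reach a ConstDictKeySource (True) or a plain Source (False).
        return self.base.is_dict_key()

@dataclasses.dataclass(frozen=True)
class ConstDictKeySource(ChainedSource):
    index: int = 0

    def is_dict_key(self):
        return True

# A source derived from a dict key still identifies as a dict key:
key_source = ConstDictKeySource(base=Source(), index=0)
derived = ChainedSource(base=key_source)
assert derived.is_dict_key() and not Source().is_dict_key()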
@@ -16,7 +16,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn import _reduction as _Reduction
from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
gradcheck, gradgradcheck, set_default_dtype
gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo
from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
from torch.autograd.gradcheck import _get_numerical_jacobian, _iter_tensors
from torch.autograd import Variable
@@ -1709,6 +1709,7 @@ new_module_tests = [
input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
check_gradgrad=False,
default_dtype=torch.double,
decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
),
dict(
module_name='Embedding',
@@ -1718,6 +1719,7 @@ new_module_tests = [
check_gradgrad=False,
desc='discontiguous',
default_dtype=torch.double,
decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
),
dict(
module_name='EmbeddingBag',
@@ -1611,7 +1611,6 @@ dynamo_expected_failures = {
"TestTorchTidyProfiler.test_tensorimpl_invalidation_full", # profiler/test_profiler
"TestProfiler.test_profiler_tracing", # profiler/test_profiler
"TestProfiler.test_is_profiler_enabled", # profiler/test_profiler
"TestExperimentalUtils.test_utils_compute_idle_time", # profiler/test_profiler
"TestTorchTidyProfiler.test_optimizer_parameters_sgd", # profiler/test_profiler
"TestExperimentalUtils.test_profiler_name_pattern", # profiler/test_profiler
"TestTorchTidyProfiler.test_extra_fields", # profiler/test_profiler
@@ -1639,13 +1638,11 @@ dynamo_expected_failures = {
"TestTorchTidyProfiler.test_sparse_tensors", # profiler/test_profiler
"TestTorchTidyProfiler.test_optimizer", # profiler/test_profiler
"TestTorchTidyProfiler.test_tensorimpl_invalidation_keep_alive", # profiler/test_profiler
"TestExperimentalUtils.test_utils_compute_queue_depth", # profiler/test_profiler
"TestExperimentalUtils.test_profiler_pattern_match_helper", # profiler/test_profiler
"TestProfiler.test_export_stacks", # profiler/test_profiler
"TestProfiler.test_source_multithreaded_basic_work_in_main_thread_True", # profiler/test_profiler
"TestTorchTidyProfiler.test_mkldnn_tensors", # profiler/test_profiler
"TestRecordFunction.test_datapipe_with_record_function", # profiler/test_profiler
"TestProfiler.test_memory_profiler", # profiler/test_profiler
"TestTorchTidyProfiler.test_tensor_lists", # profiler/test_profiler
"TestTorchTidyProfiler.test_pointers_and_ids", # profiler/test_profiler
"TestTorchTidyProfiler.test_nnmodule_params", # profiler/test_profiler