[dynamo] Extend LazyVariableTracker to tuples (#117426)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117426
Approved by: https://github.com/lezcano, https://github.com/jansel
This commit is contained in:
Animesh Jain
2024-01-17 20:53:23 -08:00
committed by PyTorch MergeBot
parent 26956980c6
commit 6e4e81a9ef
19 changed files with 61 additions and 63 deletions

View File

@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
detectron2_fcos_r_50_fpn,pass,41
detectron2_fcos_r_50_fpn,pass,35
@ -354,7 +354,7 @@ vgg16,pass,0
vision_maskrcnn,pass,17
vision_maskrcnn,pass,16

1 name accuracy graph_breaks
86 timm_vovnet pass 0
87 torch_multimodal_clip pass 0
88 tts_angular pass 2
89 vgg16 pass 0
90 vision_maskrcnn pass 17 16
91 yolov3 pass 2
92
354
355
356
357
358
359
360

View File

@ -294,7 +294,7 @@ vgg16,pass,7
vision_maskrcnn,pass,35
vision_maskrcnn,pass,34

1 name accuracy graph_breaks
294
295
296
297
298
299
300

View File

@ -54,47 +54,47 @@ densenet121,pass,0
detectron2_fasterrcnn_r_101_c4,pass,52
detectron2_fasterrcnn_r_101_c4,pass,51
detectron2_fasterrcnn_r_101_dc5,pass,52
detectron2_fasterrcnn_r_101_dc5,pass,51
detectron2_fasterrcnn_r_101_fpn,pass,56
detectron2_fasterrcnn_r_101_fpn,pass,55
detectron2_fasterrcnn_r_50_c4,pass,52
detectron2_fasterrcnn_r_50_c4,pass,51
detectron2_fasterrcnn_r_50_dc5,pass,52
detectron2_fasterrcnn_r_50_dc5,pass,51
detectron2_fasterrcnn_r_50_fpn,pass,56
detectron2_fasterrcnn_r_50_fpn,pass,55
detectron2_fcos_r_50_fpn,pass,44
detectron2_fcos_r_50_fpn,pass,38
detectron2_maskrcnn_r_101_c4,fail_accuracy,67
detectron2_maskrcnn_r_101_c4,fail_accuracy,66
detectron2_maskrcnn_r_101_fpn,pass,74
detectron2_maskrcnn_r_101_fpn,pass,73
detectron2_maskrcnn_r_50_c4,pass,67
detectron2_maskrcnn_r_50_c4,pass,66
detectron2_maskrcnn_r_50_fpn,pass,74
detectron2_maskrcnn_r_50_fpn,pass,73
@ -322,7 +322,7 @@ vgg16,pass,0
vision_maskrcnn,pass,29
vision_maskrcnn,pass,28

1 name accuracy graph_breaks
54 phlippe_resnet pass 0
55 pyhpc_equation_of_state pass 0
56 pyhpc_isoneutral_mixing pass 0
57 pyhpc_turbulent_kinetic_energy pass 0
58 pytorch_CycleGAN_and_pix2pix pass 0
59 pytorch_stargan pass 0
60 pytorch_unet pass 0
61 resnet152 pass 0
62 resnet18 pass 0
63 resnet50 pass 0
64 resnet50_quantized_qat eager_fail_to_run 0
65 resnext50_32x4d pass 0
66 shufflenet_v2_x1_0 pass 0
67 soft_actor_critic pass 0
68 speech_transformer pass 10
69 squeezenet1_1 pass 0
70 stable_diffusion_unet pass_due_to_skip 0
71 timm_efficientdet model_fail_to_load 0
72 timm_efficientnet pass 0
73 timm_nfnet pass 0
74 timm_regnet pass 0
75 timm_resnest pass 0
76 timm_vision_transformer pass 0
77 timm_vision_transformer_large pass_due_to_skip 0
78 timm_vovnet pass 0
79 torch_multimodal_clip pass 0
80 tts_angular pass 2
81 vgg16 pass 0
82 vision_maskrcnn pass 29 28
83 yolov3 pass 2
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
322
323
324
325
326
327
328

View File

@ -346,7 +346,7 @@ vgg16,pass,0
vision_maskrcnn,pass,17
vision_maskrcnn,pass,16

1 name accuracy graph_breaks
346
347
348
349
350
351
352

View File

@ -286,7 +286,7 @@ vgg16,pass,7
vision_maskrcnn,pass,35
vision_maskrcnn,pass,34

1 name accuracy graph_breaks
286
287
288
289
290
291
292

View File

@ -274,7 +274,7 @@ vgg16,pass,0
vision_maskrcnn,pass,29
vision_maskrcnn,pass,28

1 name accuracy graph_breaks
274
275
276
277
278
279
280

View File

@ -346,7 +346,7 @@ vgg16,pass,0
vision_maskrcnn,pass,17
vision_maskrcnn,pass,16

1 name accuracy graph_breaks
346
347
348
349
350
351
352

View File

@ -286,7 +286,7 @@ vgg16,pass,7
vision_maskrcnn,pass,35
vision_maskrcnn,pass,34

1 name accuracy graph_breaks
286
287
288
289
290
291
292

View File

@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
detectron2_fcos_r_50_fpn,pass,41
detectron2_fcos_r_50_fpn,pass,35
@ -354,7 +354,7 @@ vgg16,pass,0
vision_maskrcnn,pass,17
vision_maskrcnn,pass,16

1 name accuracy graph_breaks
86 timm_vovnet pass 0
87 torch_multimodal_clip pass 0
88 tts_angular pass 2
89 vgg16 pass 0
90 vision_maskrcnn pass 17 16
91 yolov3 pass 2
92
354
355
356
357
358
359
360

View File

@ -294,7 +294,7 @@ vgg16,pass,7
vision_maskrcnn,pass,35
vision_maskrcnn,pass,34

1 name accuracy graph_breaks
294
295
296
297
298
299
300

View File

@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
detectron2_fcos_r_50_fpn,pass,42
detectron2_fcos_r_50_fpn,pass,36
@ -354,7 +354,7 @@ vgg16,pass,0
vision_maskrcnn,pass,17
vision_maskrcnn,pass,16

1 name accuracy graph_breaks
86 timm_vovnet pass 0
87 torch_multimodal_clip pass 0
88 tts_angular pass 2
89 vgg16 pass 0
90 vision_maskrcnn pass 17 16
91 yolov3 pass 2
92
354
355
356
357
358
359
360

View File

@ -294,7 +294,7 @@ vgg16,pass,7
vision_maskrcnn,pass,35
vision_maskrcnn,pass,34

1 name accuracy graph_breaks
294
295
296
297
298
299
300

View File

@ -785,6 +785,21 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(ref[1]["e"], res[1]["e"]))
self.assertTrue(same(ref[1][param], res[1][param]))
def test_dict_tuple_lazy_guard(self):
@torch.compile(backend="eager")
def fn(x, y):
return torch.sin(x) * y[1]
fn(torch.randn(3), {1: 1, 2: 2})
# Changing the value of the other key should not cause recompilation
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
fn(torch.randn(3), {1: 1, 2: 3})
fn(torch.randn(3), (1, 2, 3))
# Changing the value of index 0, 2 (not 1) should not cause recompilation
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
fn(torch.randn(3), (11, 2, 13))
@make_test
def test_call_dict1(x):
d1 = dict()

View File

@ -266,9 +266,7 @@ tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""".split(
opt_f([7, 8])
for line in """\
len(L['x']) == 3
L['x'][0] == 4
L['x'][1] == 5""".split(
len(L['x']) == 3""".split(
"\n"
):
self.assertIn(line, filter_reasons())
@ -278,9 +276,7 @@ L['x'][1] == 5""".split(
for line in """\
len(L['x']) == 2
L['x'][0] == 7
len(L['x']) == 3
L['x'][0] == 4""".split(
len(L['x']) == 3""".split(
"\n"
):
self.assertIn(line, filter_reasons())

View File

@ -419,7 +419,6 @@ class TestGradTransform(TestCase):
expected = -y * x.sin()
self.assertEqual(result, expected)
@xfailIfTorchDynamo
def test_grad_of_vjp_of_grad_composition(self, device):
x = torch.randn([], device=device)
y = torch.randn([], device=device)

View File

@ -4934,6 +4934,7 @@ class TestLinalg(TestCase):
@precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2})
@skipCUDAIfNoMagmaAndNoCusolver
@skipIfTorchDynamo("Runtime error with torch._C._linalg.linalg_lu_factor")
@skipCPUIfNoLapack
@dtypes(*floating_and_complex_types())
def test_linalg_lu_family(self, device, dtype):

View File

@ -801,9 +801,10 @@ class VariableBuilder:
unimplemented("list elements are pointing to the list itself")
output = [
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(item)
LazyVariableTracker.create(item, source=GetItemSource(self.get_source(), i))
for i, item in enumerate(value)
]
result = BaseListVariable.cls_for_instance(value)(
output, mutable_local=MutableLocal()
)

View File

@ -493,6 +493,13 @@ class BuiltinVariable(VariableTracker):
k: v.as_python_constant() for k, v in kwargs.items()
}
def has_constant_handler(self, args, kwargs):
constant_args = check_constant_args(args, kwargs)
unspec_python_args = self.unspec_python_args(*args, **kwargs)
return self.can_constant_fold_through() and (
constant_args or unspec_python_args
)
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
@ -501,14 +508,9 @@ class BuiltinVariable(VariableTracker):
args = [v.realize() for v in args]
kwargs = {k: v.realize() for k, v in kwargs.items()}
constant_args = check_constant_args(args, kwargs)
tensor_args = self.tensor_args(*args, **kwargs)
unspec_python_args = self.unspec_python_args(*args, **kwargs)
has_constant_handler = self.can_constant_fold_through() and (
constant_args or unspec_python_args
)
assert isinstance(args, (list, tuple))
assert isinstance(kwargs, dict)
tensor_args = self.tensor_args(*args, **kwargs)
# args[0] is list and args[1] is unspec
if self.fn is operator.getitem and not isinstance(
@ -646,6 +648,7 @@ class BuiltinVariable(VariableTracker):
try:
inspect.signature(handler).bind(tx, *args, **kwargs)
except TypeError as exc:
has_constant_handler = self.has_constant_handler(args, kwargs)
if not has_constant_handler:
log.warning(
"incorrect arg count %s %s and no constant handler",
@ -660,11 +663,17 @@ class BuiltinVariable(VariableTracker):
if result is not None:
return result
except Unsupported as exc:
has_constant_handler = self.has_constant_handler(args, kwargs)
if not has_constant_handler:
raise
# Actually, we will handle this just fine
exc.remove_from_stats()
# NB: call to has_constant_handler is deliberately delayed post generic
# handler because has_constant_handler calls as_python_constant
# internally which realizes LazyVariableTracker for ConstantVariables,
# unnecessarily putting guards on objects which might not actually be used.
has_constant_handler = self.has_constant_handler(args, kwargs)
if has_constant_handler:
# constant fold
return variables.ConstantVariable.create(

View File

@ -126,8 +126,6 @@ dynamo_expected_failures = {
"TestLinalgCPU.test_inverse_cpu_complex128",
"TestLinalgCPU.test_norm_dtype_cpu_complex128",
"TestLinalgCPU.test_householder_product_cpu_float64",
"TestLinalgCPU.test_linalg_lu_family_cpu_float32",
"TestLinalgCPU.test_linalg_lu_family_cpu_float64",
"TestLinalgCPU.test_addr_integral_cpu_int64",
"TestLinalgCPU.test_norm_vector_cpu_float32",
"TestLinalgCPU.test_solve_cpu_complex128",
@ -152,7 +150,6 @@ dynamo_expected_failures = {
"TestLinalgCPU.test_addmm_sizes_cpu_float32",
"TestLinalgCPU.test_norm_bfloat16_and_half_cpu_float16",
"TestLinalgCPU.test_householder_product_cpu_complex64",
"TestLinalgCPU.test_linalg_lu_family_cpu_complex128",
"TestLinalgCPU.test_inverse_cpu_float64",
"TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex64",
"TestLinalgCPU.test_pinv_cpu_complex64",
@ -161,7 +158,6 @@ dynamo_expected_failures = {
"TestLinalgCPU.test_einsum_sublist_format_cpu_complex128",
"TestLinalgCPU.test_geqrf_cpu_complex64",
"TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float64",
"TestLinalgCPU.test_linalg_lu_family_cpu_complex64",
"TestLinalgCPU.test_geqrf_cpu_float64",
"TestLinalgCPU.test_householder_product_cpu_complex128",
"TestLinalgCPU.test_geqrf_cpu_float32",
@ -821,10 +817,8 @@ dynamo_expected_failures = {
"TestIndexing.test_index_no_floats", # torch_np/numpy_tests/core/test_indexing
"TestBooleanIndexing.test_boolean_indexing_weirdness", # torch_np/numpy_tests/core/test_indexing
"TestBooleanIndexing.test_bool_as_int_argument_errors", # torch_np/numpy_tests/core/test_indexing
"TestBroadcastedAssignments.test_simple_broadcasting_errors", # torch_np/numpy_tests/core/test_indexing
"TestFloatNonIntegerArgument.test_non_integer_argument_errors", # torch_np/numpy_tests/core/test_indexing
"TestIndexing.test_slicing_no_floats", # torch_np/numpy_tests/core/test_indexing
"TestBroadcastedAssignments.test_prepend_not_one", # torch_np/numpy_tests/core/test_indexing
"TestFloatNonIntegerArgument.test_reduce_axis_float_index", # torch_np/numpy_tests/core/test_indexing
"TestEinsum.test_different_paths_dtype_e", # torch_np/numpy_tests/core/test_einsum
"TestEinsum.test_different_paths_dtype_B", # torch_np/numpy_tests/core/test_einsum
@ -2073,7 +2067,6 @@ dynamo_expected_failures = {
"TestMkldnnCPU.test_tanh_cpu", # test_mkldnn
"TestMkldnnCPU.test_conv2d_cpu", # test_mkldnn
"TestMkldnnCPU.test_batch_norm_3d_cpu", # test_mkldnn
"TestFunctionSchema.test_serialize_and_deserialize", # test_function_schema
"FakeTensorOperatorInvariants.test_like_ops", # test_fake_tensor
"FakeTensorConverterTest.test_memoized_conversion_from_meta", # test_fake_tensor
"FakeTensorOperatorInvariants.test_non_kwarg_only_device", # test_fake_tensor
@ -2794,7 +2787,6 @@ dynamo_expected_failures = {
"TestVmapOperatorsLegacy.test_contiguous", # test_legacy_vmap
"TestVmapAPILegacy.test_accepts_nested_inputs", # test_legacy_vmap
"TestVmapAPILegacy.test_nested_out_dims", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_add_cpu", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_inplace_manyview_cpu", # test_legacy_vmap
"TestVmapAPILegacy.test_functools_partial", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_unrelated_output_cpu", # test_legacy_vmap
@ -2803,21 +2795,16 @@ dynamo_expected_failures = {
"TestVmapAPILegacy.test_single_input", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_chunk", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_mul_cpu", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_reshape_cpu", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_unrelated_output_multiple_grad_cpu", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_stack", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_select", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_binary_pointwise_ops", # test_legacy_vmap
"TestVmapAPILegacy.test_non_tensor_output_raises", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_max_cpu", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_binary_cross_entropy_cpu", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_diagonal", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_select_cpu", # test_legacy_vmap
"TestVmapAPILegacy.test_nonzero_out_dims", # test_legacy_vmap
"TestVmapAPILegacy.test_unsupported_op_err_msg", # test_legacy_vmap
"TestVmapAPILegacy.test_batched_gradient_basic", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_slice", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_min_cpu", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_expand_as", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_unfold", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_sigmoid_cpu", # test_legacy_vmap
@ -2827,16 +2814,11 @@ dynamo_expected_failures = {
"TestVmapOperatorsLegacy.test_new_empty_strided", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_is_floating_point", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_split", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_stack_cpu", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_fill_and_zero_inplace", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_is_complex", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_expand_cpu", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_as_strided", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_slice_cpu", # test_legacy_vmap
"TestVmapAPILegacy.test_nested_with_different_map_dim", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_new_zeros", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_trace_cpu", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_permute_cpu", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_view_as", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_logsumexp_cpu", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_log1p_cpu", # test_legacy_vmap
@ -2850,17 +2832,13 @@ dynamo_expected_failures = {
"TestVmapBatchedGradientLegacyCPU.test_inplace_on_view_cpu", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_new_empty", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_lgamma_cpu", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_threshold_cpu", # test_legacy_vmap
"TestVmapAPILegacy.test_multiple_out_dims", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_result_type", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_sum_dim", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_to", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_diagonal_cpu", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_sub_cpu", # test_legacy_vmap
"TestVmapAPILegacy.test_backward_unsupported_interaction", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_comparison_ops", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_is_contiguous", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_cat", # test_legacy_vmap
"TestVmapAPILegacy.test_multiple_outputs", # test_legacy_vmap
"TestVmapAPILegacy.test_inplace_fallback_unary", # test_legacy_vmap
"TestVmapAPILegacy.test_out_dim_out_of_bounds_err_msg", # test_legacy_vmap
@ -2875,7 +2853,6 @@ dynamo_expected_failures = {
"TestVmapOperatorsLegacy.test_no_random_op_support", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_unbind", # test_legacy_vmap
"TestVmapAPILegacy.test_non_default_in_dims_out_dims", # test_legacy_vmap
"TestVmapBatchedGradientLegacyCPU.test_median_cpu", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_T_numpy", # test_legacy_vmap
"TestNamedTensor.test_addmv", # test_namedtensor
"TestNamedTensor.test_cummax_cummin", # test_namedtensor