diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py
index 7a2ceb0f26b9..3bebcfc345b7 100644
--- a/test/dynamo/test_dynamic_shapes.py
+++ b/test/dynamo/test_dynamic_shapes.py
@@ -49,6 +49,12 @@ def make_dynamic_cls(cls):
         suffix,
         (config, "assume_static_by_default", False),
         (config, "specialize_int", False),
+        # When we unspecialize float, we wobble tests by changing
+        # the op count since previously we would just specialize and constant
+        # fold floats into the graph, whereas when we unspecialize we will have
+        # ops for item, add, and all other tensorified operations. Since these
+        # tests really aren't testing that, we purposely specialize floats here.
+        (config, "specialize_float", True),
         (fx_config, "translation_validation", TEST_Z3),
         (fx_config, "check_shape_env_recorded_events", True),
         (fx_config, "validate_shape_env_version_key", True),
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index e788f687142d..a17dbc313b6d 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -1271,7 +1271,12 @@ utils_device.CURRENT_DEVICE == None""".split(
 
         torch._dynamo.testing.standard_test(self, fn=fn2, nargs=1, expected_ops=1)
 
-    @torch._dynamo.config.patch(specialize_float=False)
+    # When we unspecialize float, we wobble this test by changing
+    # the op count since previously we would just specialize and constant
+    # fold floats into the graph, whereas when we unspecialize we will have
+    # ops for item, add, and all other tensorified operations. Since this
+    # test really isn't testing that, we purposely specialize floats here.
+    @torch._dynamo.config.patch(specialize_float=True)
     def test_config_obj(self):
         class Cfg:
             def __init__(self) -> None:
@@ -1296,7 +1301,7 @@ utils_device.CURRENT_DEVICE == None""".split(
         cfg2.val = 2.0
         v = opt_fn(v, cfg2)  # 7
         self.assertEqual(v[0], 7)
-        self.assertEqual(cnts.op_count, 9)
+        self.assertEqual(cnts.op_count, 8)
 
     def test_config_getattr_default(self):
         class Cfg:
@@ -3620,7 +3625,12 @@ utils_device.CURRENT_DEVICE == None""".split(
 
         self.assertTrue(same(out[0], out[1]))
 
-    @torch._dynamo.config.patch(specialize_float=False)
+    # When we unspecialize float, we wobble this test by changing
+    # the op count since previously we would just specialize and constant
+    # fold floats into the graph, whereas when we unspecialize we will have
+    # ops for item, add, and all other tensorified operations. Since this
+    # test really isn't testing that, we purposely specialize floats here.
+    @torch._dynamo.config.patch(specialize_float=True)
     def test_closure_out_of_scope_cell(self):
         cell1 = torch.rand(1).item()
         cell2 = torch.rand(3, 3)
@@ -3642,7 +3652,12 @@ utils_device.CURRENT_DEVICE == None""".split(
         self.assertEqual(cnts.frame_count, 1)
         self.assertEqual(cnts.op_count, 1)
 
-    @torch._dynamo.config.patch(specialize_float=False)
+    # When we unspecialize float, we wobble this test by changing
+    # the op count since previously we would just specialize and constant
+    # fold floats into the graph, whereas when we unspecialize we will have
+    # ops for item, add, and all other tensorified operations. Since this
+    # test really isn't testing that, we purposely specialize floats here.
+    @torch._dynamo.config.patch(specialize_float=True)
     def test_closure_out_of_scope_cell_with_mutation(self):
         cell1 = torch.rand(1).item()
         orig1 = cell1
@@ -3669,18 +3684,8 @@ utils_device.CURRENT_DEVICE == None""".split(
             result1, result2, _ = opt_fn()
             self.assertAlmostEqual(orig1 + 1 * i, result1)
             self.assertTrue(torch.allclose(orig2 + 10 * i, result2))
-            if i == 1:
-                # No automatic dynamic
-                self.assertEqual(cnts.frame_count, 1)
-                self.assertEqual(cnts.op_count, 3)
-            elif i == 2:
-                # Automatic dynamic float arguments kicked in
-                self.assertEqual(cnts.frame_count, 1)
-                self.assertEqual(cnts.op_count, 6)
-            else:
-                # No more recompiles
-                self.assertEqual(cnts.frame_count, 0)
-                self.assertEqual(cnts.op_count, 0)
+            self.assertEqual(cnts.frame_count, 1)
+            self.assertEqual(cnts.op_count, 3)
             cnts.clear()
 
     def test_closure_with_mutation_and_graph_break(self):
diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py
index 79fdb0a37add..6bac5cd10e35 100644
--- a/test/dynamo/test_unspec.py
+++ b/test/dynamo/test_unspec.py
@@ -653,14 +653,7 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
             self.assertEqual(fn_opt(x, y1), fn(x, y1))
             self.assertEqual(fn_opt(x, y2), fn(x, y2))
             self.assertEqual(fn_opt(x, y3), fn(x, y3))
-            if i == 0:
-                # This is kind of quirky part of automatic dynamic,
-                # since it just uses source name + tx.f_code as the key
-                # subsequent recompilations will actually reuse the automatic
-                # dynamic choices.
-                self.assertEqual(cnt.frame_count, 2)
-            else:
-                self.assertEqual(cnt.frame_count, 1)
+            self.assertEqual(cnt.frame_count, 1)
 
     @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=False)
     def test_unspec_float_input_f64(self):
diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index e4ccc12201c1..ce1bba53f8b3 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -2536,9 +2536,9 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_float_$is_traceable, m) {
             self.assertEqual(counters["stats"]["unique_graphs"], 3)
         else:
             self.check_output_and_recompiles(
-                fn, count=[1, 4], compiler_fn=make_compiler_fn(fullgraph=False)
+                fn, count=[1, 3], compiler_fn=make_compiler_fn(fullgraph=False)
             )
-            self.assertEqual(counters["stats"]["unique_graphs"], 3)
+            self.assertEqual(counters["stats"]["unique_graphs"], 2)
 
     @parametrize("is_traceable", (True, False))
     @scoped_load_inline
diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py
index cc6f152fb42f..e4cb30bd1791 100644
--- a/test/inductor/test_cpu_select_algorithm.py
+++ b/test/inductor/test_cpu_select_algorithm.py
@@ -265,6 +265,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
     )
     @dtypes(torch.float, torch.bfloat16, torch.half)
     @torch.fx.experimental._config.patch(use_duck_shape=False)
+    @torch._dynamo.config.patch(specialize_float=True)
     def test_linear_with_pointwise(
         self, batch_size, in_features, out_features, bias, epilogue, dtype
     ):
diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py
index fbc11009f567..52a705911166 100644
--- a/test/inductor/test_mkldnn_pattern_matcher.py
+++ b/test/inductor/test_mkldnn_pattern_matcher.py
@@ -4119,7 +4119,17 @@ class TestPatternMatcher(TestPatternMatcherBase):
         self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1)
 
 
-@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
+# When testing kernel counts, unspecializing float causes wobbling of our tests because
+# we end up reusing the same compiled region across tests. Thus we purposely specialize floats
+# here since we primarily care about number of kernels generated in the absence of compile
+# caching.
+@dynamo_config.patch(
+    {
+        "dynamic_shapes": True,
+        "assume_static_by_default": False,
+        "specialize_float": True,
+    }
+)
 class TestDynamicPatternMatcher(TestPatternMatcherBase):
     _test_conv_unary_cpu_base = TestPatternMatcher._test_conv_unary_cpu_base
     test_conv2d_unary_dynamic_shapes = TestPatternMatcher.test_conv2d_unary_cpu
diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py
index 332ec3c86ac5..bacecfcc5f95 100644
--- a/test/inductor/test_torchinductor_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_dynamic_shapes.py
@@ -1043,6 +1043,19 @@ class TestInductorDynamic(TestCase):
         self.assertEqual(fn(x, 4.0), fn_opt(x, 4.0))
         self.assertEqual(cnt.frame_count, 2)
 
+    def test_unspecialized_float_dynamic(self):
+        def fn(x, y):
+            return x * y
+
+        cnt = CompileCounterWithBackend("inductor")
+        fn_opt = torch.compile(fn, dynamic=True, backend=cnt)
+        x = torch.randn(5, 5)
+
+        self.assertEqual(fn(x, 2.0), fn_opt(x, 2.0))
+        self.assertEqual(fn(x, 3.0), fn_opt(x, 3.0))
+        self.assertEqual(fn(x, 4.0), fn_opt(x, 4.0))
+        self.assertEqual(cnt.frame_count, 1)
+
     @torch._dynamo.config.patch(specialize_float=False)
     def test_unspecialized_float_fallback_symint_specialization(self):
         def fn(x, y):
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 87ebe99a987f..0e900c094d21 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -2089,7 +2089,10 @@ class VariableBuilder:
             # python test/inductor/test_compiled_optimizers.py CompiledOptimizerTests.test_rmsprop_weight_decay_maximize_capturable_cuda  # noqa: B950
             or torch._inductor.config.triton.cudagraphs
             or justknobs_check("pytorch/compiler:unspecialize_float_killswitch", False)
-            or frame_state_entry.scalar is not auto_dynamic
+            or (
+                config.assume_static_by_default
+                and frame_state_entry.scalar is not auto_dynamic
+            )
         ):
             self.install_guards(GuardBuilder.CONSTANT_MATCH)
             return ConstantVariable.create(value=value, source=self.source)
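Below is a minimal standalone sketch, not part of the patch, of the behavior the new test_unspecialized_float_dynamic test above asserts: with dynamic=True and floats left unspecialized, varying a Python float argument reuses one dynamic graph instead of recompiling per value. It assumes torch._dynamo.testing.CompileCounterWithBackend, the same helper the test uses.

import torch
from torch._dynamo.testing import CompileCounterWithBackend


def fn(x, y):
    return x * y


# Count compiled frames while still lowering through inductor.
cnt = CompileCounterWithBackend("inductor")
fn_opt = torch.compile(fn, dynamic=True, backend=cnt)
x = torch.randn(5, 5)

# The float argument changes on every call; on the unspecialized-float path
# exercised by this patch, that should not force a recompile.
for y in (2.0, 3.0, 4.0):
    assert torch.allclose(fn_opt(x, y), fn(x, y))

print(cnt.frame_count)  # expected: 1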