fix dynamic float when dynamic=True (#149564)
Fixes https://github.com/pytorch/pytorch/issues/149406#issuecomment-2738111733. Previously, floats only became dynamic via automatic dynamic (i.e., after a recompile triggered by a changed float value); now, when dynamic=True is set, floats are made dynamic on the first compile.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149564
Approved by: https://github.com/laithsakka
Commit 621c801f78 (parent 8f7fbe3d7d), committed by PyTorch MergeBot.
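In effect, the change looks like the following standalone sketch, modeled on the test_unspecialized_float_dynamic test added in the diff below (CompileCounterWithBackend comes from torch._dynamo.testing; the frame-count expectation reflects the new behavior this PR introduces):

import torch
from torch._dynamo.testing import CompileCounterWithBackend

def fn(x, y):
    return x * y

cnt = CompileCounterWithBackend("inductor")
fn_opt = torch.compile(fn, dynamic=True, backend=cnt)
x = torch.randn(5, 5)

# With dynamic=True, the float argument is treated as dynamic on the first
# compile, so varying it across calls no longer triggers a recompile.
fn_opt(x, 2.0)
fn_opt(x, 3.0)
fn_opt(x, 4.0)
assert cnt.frame_count == 1  # before this PR, floats only went dynamic via automatic dynamic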
@@ -49,6 +49,12 @@ def make_dynamic_cls(cls):
         suffix,
         (config, "assume_static_by_default", False),
         (config, "specialize_int", False),
+        # When we unspecialize float, we wobble tests by changing
+        # the op count since previously we would just specialize and constant
+        # fold floats into the graph, whereas when we unspecialize we will have
+        # ops for item, add, and all other tensorified operations. Since these
+        # tests really aren't testing that, we purposely specialize floats here.
+        (config, "specialize_float", True),
         (fx_config, "translation_validation", TEST_Z3),
         (fx_config, "check_shape_env_recorded_events", True),
         (fx_config, "validate_shape_env_version_key", True),
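As a side note, a minimal illustration of the config knob the comment above refers to (the "eager" backend and the example function are just for demonstration; the op-count effect is as described in the comment, not asserted here):

import torch

def fn(x, y):
    # y is a plain Python float
    return x + y

# With specialize_float=True, Dynamo bakes y into the graph as a constant
# (constant folding); with specialize_float=False, y is handled symbolically,
# which introduces extra ops (item, add, etc.) and shifts op counts.
with torch._dynamo.config.patch(specialize_float=True):
    compiled = torch.compile(fn, backend="eager")
    compiled(torch.randn(3), 0.5)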
@@ -1271,7 +1271,12 @@ utils_device.CURRENT_DEVICE == None""".split(

         torch._dynamo.testing.standard_test(self, fn=fn2, nargs=1, expected_ops=1)

-    @torch._dynamo.config.patch(specialize_float=False)
+    # When we unspecialize float, we wobble this test by changing
+    # the op count since previously we would just specialize and constant
+    # fold floats into the graph, whereas when we unspecialize we will have
+    # ops for item, add, and all other tensorified operations. Since this
+    # test really isn't testing that, we purposely specialize floats here.
+    @torch._dynamo.config.patch(specialize_float=True)
     def test_config_obj(self):
         class Cfg:
             def __init__(self) -> None:
@@ -1296,7 +1301,7 @@ utils_device.CURRENT_DEVICE == None""".split(
         cfg2.val = 2.0
         v = opt_fn(v, cfg2)  # 7
         self.assertEqual(v[0], 7)
-        self.assertEqual(cnts.op_count, 9)
+        self.assertEqual(cnts.op_count, 8)

     def test_config_getattr_default(self):
         class Cfg:
@@ -3620,7 +3625,12 @@ utils_device.CURRENT_DEVICE == None""".split(

         self.assertTrue(same(out[0], out[1]))

-    @torch._dynamo.config.patch(specialize_float=False)
+    # When we unspecialize float, we wobble this test by changing
+    # the op count since previously we would just specialize and constant
+    # fold floats into the graph, whereas when we unspecialize we will have
+    # ops for item, add, and all other tensorified operations. Since this
+    # test really isn't testing that, we purposely specialize floats here.
+    @torch._dynamo.config.patch(specialize_float=True)
     def test_closure_out_of_scope_cell(self):
         cell1 = torch.rand(1).item()
         cell2 = torch.rand(3, 3)
@@ -3642,7 +3652,12 @@ utils_device.CURRENT_DEVICE == None""".split(
         self.assertEqual(cnts.frame_count, 1)
         self.assertEqual(cnts.op_count, 1)

-    @torch._dynamo.config.patch(specialize_float=False)
+    # When we unspecialize float, we wobble this test by changing
+    # the op count since previously we would just specialize and constant
+    # fold floats into the graph, whereas when we unspecialize we will have
+    # ops for item, add, and all other tensorified operations. Since this
+    # test really isn't testing that, we purposely specialize floats here.
+    @torch._dynamo.config.patch(specialize_float=True)
     def test_closure_out_of_scope_cell_with_mutation(self):
         cell1 = torch.rand(1).item()
         orig1 = cell1
@@ -3669,18 +3684,8 @@ utils_device.CURRENT_DEVICE == None""".split(
             result1, result2, _ = opt_fn()
             self.assertAlmostEqual(orig1 + 1 * i, result1)
             self.assertTrue(torch.allclose(orig2 + 10 * i, result2))
-            if i == 1:
-                # No automatic dynamic
             self.assertEqual(cnts.frame_count, 1)
             self.assertEqual(cnts.op_count, 3)
-            elif i == 2:
-                # Automatic dynamic float arguments kicked in
-                self.assertEqual(cnts.frame_count, 1)
-                self.assertEqual(cnts.op_count, 6)
-            else:
-                # No more recompiles
-                self.assertEqual(cnts.frame_count, 0)
-                self.assertEqual(cnts.op_count, 0)
             cnts.clear()

     def test_closure_with_mutation_and_graph_break(self):
|
@ -653,13 +653,6 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(fn_opt(x, y1), fn(x, y1))
|
||||
self.assertEqual(fn_opt(x, y2), fn(x, y2))
|
||||
self.assertEqual(fn_opt(x, y3), fn(x, y3))
|
||||
if i == 0:
|
||||
# This is kind of quirky part of automatic dynamic,
|
||||
# since it just uses source name + tx.f_code as the key
|
||||
# subsequent recompilations will actually reuse the automatic
|
||||
# dynamic choices.
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
else:
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
@torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=False)
|
||||
|
@@ -2536,9 +2536,9 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_float_$is_traceable, m) {
            self.assertEqual(counters["stats"]["unique_graphs"], 3)
        else:
            self.check_output_and_recompiles(
-                fn, count=[1, 4], compiler_fn=make_compiler_fn(fullgraph=False)
+                fn, count=[1, 3], compiler_fn=make_compiler_fn(fullgraph=False)
            )
-            self.assertEqual(counters["stats"]["unique_graphs"], 3)
+            self.assertEqual(counters["stats"]["unique_graphs"], 2)

    @parametrize("is_traceable", (True, False))
    @scoped_load_inline
@@ -265,6 +265,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
     )
     @dtypes(torch.float, torch.bfloat16, torch.half)
     @torch.fx.experimental._config.patch(use_duck_shape=False)
+    @torch._dynamo.config.patch(specialize_float=True)
     def test_linear_with_pointwise(
         self, batch_size, in_features, out_features, bias, epilogue, dtype
     ):
@@ -4119,7 +4119,17 @@ class TestPatternMatcher(TestPatternMatcherBase):
         self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1)


-@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
+# When testing kernel counts, unspecializing float causes wobbling of our tests because
+# we end up reusing the same compiled region across tests. Thus we purposely specialize floats
+# here since we primarily care about number of kernels generated in the absence of compile
+# caching.
+@dynamo_config.patch(
+    {
+        "dynamic_shapes": True,
+        "assume_static_by_default": False,
+        "specialize_float": True,
+    }
+)
 class TestDynamicPatternMatcher(TestPatternMatcherBase):
     _test_conv_unary_cpu_base = TestPatternMatcher._test_conv_unary_cpu_base
     test_conv2d_unary_dynamic_shapes = TestPatternMatcher.test_conv2d_unary_cpu
@@ -1043,6 +1043,19 @@ class TestInductorDynamic(TestCase):
         self.assertEqual(fn(x, 4.0), fn_opt(x, 4.0))
         self.assertEqual(cnt.frame_count, 2)

+    def test_unspecialized_float_dynamic(self):
+        def fn(x, y):
+            return x * y
+
+        cnt = CompileCounterWithBackend("inductor")
+        fn_opt = torch.compile(fn, dynamic=True, backend=cnt)
+        x = torch.randn(5, 5)
+
+        self.assertEqual(fn(x, 2.0), fn_opt(x, 2.0))
+        self.assertEqual(fn(x, 3.0), fn_opt(x, 3.0))
+        self.assertEqual(fn(x, 4.0), fn_opt(x, 4.0))
+        self.assertEqual(cnt.frame_count, 1)
+
     @torch._dynamo.config.patch(specialize_float=False)
     def test_unspecialized_float_fallback_symint_specialization(self):
         def fn(x, y):
@@ -2089,7 +2089,10 @@ class VariableBuilder:
            # python test/inductor/test_compiled_optimizers.py CompiledOptimizerTests.test_rmsprop_weight_decay_maximize_capturable_cuda  # noqa: B950
            or torch._inductor.config.triton.cudagraphs
            or justknobs_check("pytorch/compiler:unspecialize_float_killswitch", False)
-            or frame_state_entry.scalar is not auto_dynamic
+            or (
+                config.assume_static_by_default
+                and frame_state_entry.scalar is not auto_dynamic
+            )
        ):
            self.install_guards(GuardBuilder.CONSTANT_MATCH)
            return ConstantVariable.create(value=value, source=self.source)
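To make the guard change above concrete, here is an illustrative restatement of the condition; the function name and boolean parameters are hypothetical stand-ins for the surrounding VariableBuilder state, not a real API:

def should_constant_specialize_float(cudagraphs_enabled, killswitch_enabled,
                                     assume_static_by_default, scalar_is_auto_dynamic):
    # Sketch of the condition in VariableBuilder after this PR.
    # Old: the last clause was just `not scalar_is_auto_dynamic`, so a float
    # that automatic dynamic had not yet marked dynamic was always
    # constant-specialized, even under torch.compile(dynamic=True).
    # New: specialization additionally requires assume_static_by_default,
    # and dynamic=True sets assume_static_by_default=False, so floats go
    # dynamic on the first compile.
    return (
        cudagraphs_enabled
        or killswitch_enabled
        or (assume_static_by_default and not scalar_is_auto_dynamic)
    )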