Compare commits

...

2 Commits

517e8a5c80  Disable inlining on cudagraph fallback tests
    ghstack-source-id: de457d9cf9374657da7b65ba12e5a86d7b69b02e
    Pull Request resolved: https://github.com/pytorch/pytorch/pull/131557
    2024-07-23 16:14:29 -07:00

2ff5c7fc1a  Ensure tensor dict is populated with compiled autograd
    ghstack-source-id: 950085ab5cba038e7725a66739ced96dbcddd340
    Pull Request resolved: https://github.com/pytorch/pytorch/pull/131556
    2024-07-23 16:14:23 -07:00
2 changed files with 7 additions and 0 deletions

View File

@@ -1973,6 +1973,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             self.run_static_input_param_test(fn, 6)

         @torch._dynamo.config.patch("error_on_recompile", True)
+        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
         @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
         @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
         def test_fallback_to_eager_if_recompiling_too_many_times(self):
@@ -2008,6 +2009,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)

         @torch._dynamo.config.patch("error_on_recompile", True)
+        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
         @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
         @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
         def test_fallback_to_eager_if_recompiling_too_many_times_warn_only_once(self):
@@ -2052,6 +2054,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             )
             self.assertEqual(counters["inductor"]["cudagraph_skips"], 2)

+        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
         @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
         @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
         def test_fallback_to_eager_if_recompiling_too_many_times_due_to_cudagraph_managed_tensor(
@@ -2096,6 +2099,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             ).run(captured_output[0])
             self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)

+        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
         @torch._dynamo.config.patch("error_on_recompile", True)
         @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
         @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 1)

View File

@@ -1133,6 +1133,9 @@ class VariableBuilder:
                 source_i = GetItemSource(base=source, index=i, index_is_slice=False)
                 # access unpacked tensor from this list instead of from a lifted arg
                 self.tx.output.input_source_to_var[source_i] = tensor_variable
+                tensor_variable.proxy.node.meta["tensor_dict"] = value[
+                    i
+                ].__dict__.copy()

                 guard = functools.partial(
                     GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i])
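The three added lines store a shallow copy of each unpacked tensor's instance `__dict__` on the corresponding FX proxy node under `meta["tensor_dict"]`, so per-tensor attributes remain visible to downstream consumers (the commit message ties this to compiled autograd) for tensors unpacked from a list. A rough, standalone illustration of what that copy holds; the `_my_tag` attribute is made up for the example.

    import torch

    t = torch.randn(3)
    t._my_tag = "static_input"       # arbitrary metadata attached to the tensor instance

    tensor_dict = t.__dict__.copy()  # what ends up in node.meta["tensor_dict"]
    print(tensor_dict)               # {'_my_tag': 'static_input'}

    t._my_tag = "changed"
    print(tensor_dict["_my_tag"])    # still 'static_input': the copy is a snapshot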