diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py
index a83322f6154f..758225534d0e 100644
--- a/test/inductor/test_codecache.py
+++ b/test/inductor/test_codecache.py
@@ -123,7 +123,8 @@ class TestFxGraphCache(TestCase):
     @parametrize("dtype", (torch.float32, torch.bfloat16))
     @parametrize("dynamic", (False, True))
     @parametrize("bundle_triton", (False, True))
-    def test_cache_load_function(self, device, dtype, dynamic, bundle_triton):
+    @parametrize("grad", (False, True))
+    def test_cache_load_function(self, device, dtype, dynamic, bundle_triton, grad):
         """
         Verify that we can populate and load functions from the cache.
         """
@@ -132,23 +133,43 @@ class TestFxGraphCache(TestCase):
         if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
             raise unittest.SkipTest("requires SM80 or later")
 
-        def fn(x, y):
-            return (x * 2, y @ y)
+        grad_multiplier = 2 if grad else 1
 
-        a = torch.rand(25, dtype=dtype, device=device)
-        b = torch.rand(5, 5, dtype=dtype, device=device)
+        def fn(x, y):
+            yy = y @ y
+            return x * 2 + yy.view(25)
+
+        a_orig = torch.rand(25, dtype=dtype, device=device)
+        b_orig = torch.rand(5, 5, dtype=dtype, device=device)
 
         with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
            compiled_fn = torch.compile(fn, dynamic=dynamic)
 
+            a1 = a_orig.clone().requires_grad_(grad)
+            b1 = b_orig.clone().requires_grad_(grad)
+            a2 = a_orig.clone().requires_grad_(grad)
+            b2 = b_orig.clone().requires_grad_(grad)
+
             # A first call should miss in the cache.
-            self.assertEqual(fn(a, b), compiled_fn(a, b))
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            eager_result = fn(a1, b1)
+            compiled_result = compiled_fn(a2, b2)
+            self.assertEqual(eager_result, compiled_result)
+            if grad:
+                eager_result.sum().backward()
+                compiled_result.sum().backward()
+                self.assertEqual(a1.grad, a2.grad)
+                self.assertEqual(b1.grad, b2.grad)
+            self.assertEqual(
+                counters["inductor"]["fxgraph_cache_miss"], grad_multiplier * 1
+            )
             self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
             self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
 
             if bundle_triton and device != "cpu":
-                self.assertEqual(counters["inductor"]["triton_bundler_save_kernel"], 7)
+                self.assertEqual(
+                    counters["inductor"]["triton_bundler_save_kernel"],
+                    grad_multiplier * 7,
+                )
                 self.assertEqual(
                     counters["inductor"]["triton_bundler_read_and_emit_kernel"], 0
                 )
@@ -161,15 +182,37 @@ class TestFxGraphCache(TestCase):
             PyCodeCache.cache_clear()
             shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
 
-            self.assertEqual(fn(a, b), compiled_fn(a, b))
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-            self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
+            a1 = a_orig.clone().requires_grad_(grad)
+            b1 = b_orig.clone().requires_grad_(grad)
+            a2 = a_orig.clone().requires_grad_(grad)
+            b2 = b_orig.clone().requires_grad_(grad)
+
+            eager_result = fn(a1, b1)
+            compiled_result = compiled_fn(a2, b2)
+            self.assertEqual(eager_result, compiled_result)
+            if grad:
+                eager_result.sum().backward()
+                compiled_result.sum().backward()
+                self.assertEqual(a1.grad, a2.grad)
+                self.assertEqual(b1.grad, b2.grad)
+            self.assertEqual(
+                counters["inductor"]["fxgraph_cache_miss"], grad_multiplier * 1
+            )
+            self.assertEqual(
+                counters["inductor"]["fxgraph_cache_hit"], grad_multiplier * 1
+            )
+            self.assertEqual(
+                counters["inductor"]["fxgraph_lookup_write_file"], grad_multiplier * 1
+            )
 
             if bundle_triton and device != "cpu":
-                self.assertEqual(counters["inductor"]["triton_bundler_save_kernel"], 7)
                 self.assertEqual(
-                    counters["inductor"]["triton_bundler_read_and_emit_kernel"], 7
+                    counters["inductor"]["triton_bundler_save_kernel"],
+                    grad_multiplier * 7,
+                )
+                self.assertEqual(
+                    counters["inductor"]["triton_bundler_read_and_emit_kernel"],
+                    grad_multiplier * 7,
                 )
 
     @requires_triton()
@@ -448,29 +491,34 @@ class TestFxGraphCache(TestCase):
     @requires_triton()
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
-    def test_triton_higher_order_op_bypass(self):
+    @parametrize("bundle_triton", (False, True))
+    @parametrize("grad", (False, True))
+    def test_triton_higher_order_op_bypass(self, bundle_triton, grad):
         """
-        Verify that we bypass the cache when we have a triton higher order ops.
+        Verify that we bypass the cache when we have a triton higher order ops
+        and that bundler start/end works with a cache bypass.
         """
 
         def fn(x, y):
-            output = torch.zeros_like(x)
-            n_elements = output.numel()
+            n_elements = x.numel()
             grid = lambda meta: (  # noqa: E731
                 triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
             )
-            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
-            return output
+            add_kernel[grid](x, y, x, n_elements, BLOCK_SIZE=4)
+            return x
 
-        compiled_fn = torch.compile(fn, fullgraph=True)
+        with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
+            compiled_fn = torch.compile(fn, fullgraph=True)
 
-        x = torch.randn(4, device=GPU_TYPE)
-        y = torch.randn(4, device=GPU_TYPE)
-        compiled_fn(x, y)
+            x = torch.randn(4, device=GPU_TYPE, requires_grad=grad)
+            y = torch.randn(4, device=GPU_TYPE, requires_grad=grad)
+            result = compiled_fn(x, y)
+            if grad:
+                result.sum().backward()
 
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-        self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+            self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)
 
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index fc4cbba2a702..8519a93d2f7c 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -1145,9 +1145,12 @@ class FxGraphCache:
             triton_bundler_meta = TritonBundler.read_and_emit(bundle)
             if (meta := triton_bundler_meta) is not None:
                 cache_info["triton_bundler_meta"] = str(meta)
-                get_chromium_event_logger().add_event_data(
-                    "inductor_compile", cached_kernel_names=meta.cached_kernel_names
-                )
+                logger = get_chromium_event_logger()
+                if "inductor_compile" in logger.get_stack():
+                    # TODO: Clean up autograd cache integration
+                    logger.add_event_data(
+                        "inductor_compile", cached_kernel_names=meta.cached_kernel_names
+                    )
 
         inductor_meta = autotune_cache.inductor_meta_from_config()
         AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
@@ -1504,13 +1507,20 @@ class FxGraphCache:
             assert compiled_graph is None
             assert key_info is not None
             start_time = cache_info["cache_event_time"]
-            compiled_graph = compile_fx_fn(
-                gm, example_inputs, inputs_to_check, fx_kwargs
-            )
-            compiled_graph._time_taken_ns = time_ns() - start_time
-            cache_key = key_info[0]
-            compiled_graph._fx_graph_cache_key = cache_key
-            compiled_graph._triton_bundle, triton_bundler_meta = TritonBundler.collect()
+            TritonBundler.begin_compile()
+            try:
+                compiled_graph = compile_fx_fn(
+                    gm, example_inputs, inputs_to_check, fx_kwargs
+                )
+                compiled_graph._time_taken_ns = time_ns() - start_time
+                cache_key = key_info[0]
+                compiled_graph._fx_graph_cache_key = cache_key
+                (
+                    compiled_graph._triton_bundle,
+                    triton_bundler_meta,
+                ) = TritonBundler.collect()
+            finally:
+                TritonBundler.end_compile()
             if triton_bundler_meta is not None:
                 cache_info["triton_bundler_meta"] = str(triton_bundler_meta)
             cache_info["time_taken_ns"] = compiled_graph._time_taken_ns
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 67c72fc2f304..86b580739294 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -97,7 +97,6 @@ from .runtime import autotune_cache
 from .runtime.autotune_cache import AutotuneCacheBundler
 from .scheduler import BaseSchedulerNode
 from .sizevars import SizeVarAllocator
-from .triton_bundler import TritonBundler
 from .utils import (
     convert_shape_to_inductor,
     gather_origins,
@@ -1966,7 +1965,6 @@ class GraphLowering(torch.fx.Interpreter):
 
         inductor_meta = autotune_cache.inductor_meta_from_config()
         AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
-        TritonBundler.begin_compile()
 
         try:
             linemap = [(line_no, node.stack_trace) for line_no, node in linemap]  # type: ignore[misc]
diff --git a/torch/_inductor/triton_bundler.py b/torch/_inductor/triton_bundler.py
index 7835f168b376..ff33730079c4 100644
--- a/torch/_inductor/triton_bundler.py
+++ b/torch/_inductor/triton_bundler.py
@@ -70,6 +70,8 @@ class TritonBundler:
     - TritonBundler.begin_compile is called when we start compiling in Inductor
     - TritonBundler.put is called each time a Triton Kernel is compiled
     - TritonBundler.collect is called when a cache entry is being generated
+    - TritonBundler.end_compile is called to indicate bundling is completed,
+      collect will execute this function as well.
     - TritonBundler.read_and_emit is called when a cache entry is read
     """
 
@@ -92,7 +94,9 @@ class TritonBundler:
         if not config.is_fbcode():
             return False
 
-        return justknobs_check("pytorch/remote_cache:bundle_triton_into_fx_graph_cache")
+        return justknobs_check(
+            "pytorch/remote_cache:bundle_triton_into_fx_graph_cache_v2"
+        )
 
     @classmethod
     def begin_compile(cls) -> None:
@@ -102,9 +106,19 @@ class TritonBundler:
         """
         if not TritonBundler.is_enabled():
             return
+        log.debug("TritonBundler.begin_compile is called")
         assert cls._entries is None
         cls._entries = []
 
+    @classmethod
+    def end_compile(cls) -> None:
+        """
+        Finalizes the TritonBundler. If collect is not yet called, it
+        discards the current bundle.
+        """
+        log.debug("TritonBundler.end_compile is called")
+        cls._entries = None
+
     @classmethod
     def put(cls, kernel_hash: str, device: int) -> None:
         """
@@ -127,7 +141,7 @@ class TritonBundler:
         This function also finalizes the current bundle.
         """
         if not TritonBundler.is_enabled():
-            cls._entries = None
+            cls.end_compile()
             return [], None
 
         with dynamo_timed(
@@ -171,7 +185,7 @@ class TritonBundler:
                                 artifacts,
                             )
                         )
-                cls._entries = None
+                cls.end_compile()
                 return result, TritonBundlerMetadata(kernel_names)
 
         return [], None
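
For reference, the codecache.py hunk above wraps compilation in a begin_compile/end_compile pair so that a failed compile cannot leak bundler state into the next one. Below is a minimal, illustrative sketch of that lifecycle, not part of the diff; `run_inductor_compile` is a hypothetical stand-in for `compile_fx_fn`, and only the TritonBundler calls mirror the API added in this patch.

    # Hypothetical sketch of the TritonBundler lifecycle used by FxGraphCache above.
    from torch._inductor.triton_bundler import TritonBundler

    def compile_with_bundle(run_inductor_compile):
        TritonBundler.begin_compile()  # start recording Triton kernels for this compile
        try:
            compiled_graph = run_inductor_compile()
            # collect() packages the recorded kernels into cacheable artifacts and
            # finalizes the bundle (it calls end_compile itself on success).
            bundle, bundler_meta = TritonBundler.collect()
            return compiled_graph, bundle, bundler_meta
        finally:
            # Safe to call again: end_compile() simply clears the recorded entries,
            # so an exception during compilation discards the partial bundle.
            TritonBundler.end_compile()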