End TritonBundle on non-cache write codepaths (#139698)

Summary:
When we bypass the cache write in Inductor, we were also forgetting to reset the Triton bundle. This change moves resetting the bundle into the post_compile step so that it gets reset uniformly on every codepath.

This diff also turns on the cache for internal builds so that we can do a code rollout.
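
For illustration, a minimal sketch of the resulting pattern (simplified from the FxGraphCache hunk below; `compile_fx_fn` and its arguments are the stand-ins used in that hunk, not a public API):

    TritonBundler.begin_compile()
    try:
        compiled_graph = compile_fx_fn(gm, example_inputs, inputs_to_check, fx_kwargs)
        # collect() packages the kernels compiled since begin_compile() and also
        # finalizes the bundle on the successful cache-write path.
        compiled_graph._triton_bundle, triton_bundler_meta = TritonBundler.collect()
    finally:
        # end_compile() now runs on every codepath, including cache-write bypasses,
        # so a stale bundle can no longer leak into the next compilation.
        TritonBundler.end_compile()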

Test Plan: updated tests

Differential Revision: D65457224

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139698
Approved by: https://github.com/ezyang
Author: Oguz Ulgen
Date: 2024-11-05 17:00:40 +00:00
Committed by: PyTorch MergeBot
parent 4d5cc1b4ef
commit c0d21b6581
4 changed files with 112 additions and 42 deletions


@@ -123,7 +123,8 @@ class TestFxGraphCache(TestCase):
     @parametrize("dtype", (torch.float32, torch.bfloat16))
     @parametrize("dynamic", (False, True))
     @parametrize("bundle_triton", (False, True))
-    def test_cache_load_function(self, device, dtype, dynamic, bundle_triton):
+    @parametrize("grad", (False, True))
+    def test_cache_load_function(self, device, dtype, dynamic, bundle_triton, grad):
         """
         Verify that we can populate and load functions from the cache.
         """
@@ -132,23 +133,43 @@ class TestFxGraphCache(TestCase):
         if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
             raise unittest.SkipTest("requires SM80 or later")
 
-        def fn(x, y):
-            return (x * 2, y @ y)
+        grad_multiplier = 2 if grad else 1
 
-        a = torch.rand(25, dtype=dtype, device=device)
-        b = torch.rand(5, 5, dtype=dtype, device=device)
+        def fn(x, y):
+            yy = y @ y
+            return x * 2 + yy.view(25)
+
+        a_orig = torch.rand(25, dtype=dtype, device=device)
+        b_orig = torch.rand(5, 5, dtype=dtype, device=device)
 
         with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
             compiled_fn = torch.compile(fn, dynamic=dynamic)
 
+            a1 = a_orig.clone().requires_grad_(grad)
+            b1 = b_orig.clone().requires_grad_(grad)
+            a2 = a_orig.clone().requires_grad_(grad)
+            b2 = b_orig.clone().requires_grad_(grad)
+
             # A first call should miss in the cache.
-            self.assertEqual(fn(a, b), compiled_fn(a, b))
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            eager_result = fn(a1, b1)
+            compiled_result = compiled_fn(a2, b2)
+            self.assertEqual(eager_result, compiled_result)
+            if grad:
+                eager_result.sum().backward()
+                compiled_result.sum().backward()
+                self.assertEqual(a1.grad, a2.grad)
+                self.assertEqual(b1.grad, b2.grad)
+
+            self.assertEqual(
+                counters["inductor"]["fxgraph_cache_miss"], grad_multiplier * 1
+            )
             self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
             self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
 
             if bundle_triton and device != "cpu":
-                self.assertEqual(counters["inductor"]["triton_bundler_save_kernel"], 7)
+                self.assertEqual(
+                    counters["inductor"]["triton_bundler_save_kernel"],
+                    grad_multiplier * 7,
+                )
                 self.assertEqual(
                     counters["inductor"]["triton_bundler_read_and_emit_kernel"], 0
                 )
@@ -161,15 +182,37 @@ class TestFxGraphCache(TestCase):
             PyCodeCache.cache_clear()
             shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
 
-            self.assertEqual(fn(a, b), compiled_fn(a, b))
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-            self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
+            a1 = a_orig.clone().requires_grad_(grad)
+            b1 = b_orig.clone().requires_grad_(grad)
+            a2 = a_orig.clone().requires_grad_(grad)
+            b2 = b_orig.clone().requires_grad_(grad)
+
+            eager_result = fn(a1, b1)
+            compiled_result = compiled_fn(a2, b2)
+            self.assertEqual(eager_result, compiled_result)
+            if grad:
+                eager_result.sum().backward()
+                compiled_result.sum().backward()
+                self.assertEqual(a1.grad, a2.grad)
+                self.assertEqual(b1.grad, b2.grad)
+
+            self.assertEqual(
+                counters["inductor"]["fxgraph_cache_miss"], grad_multiplier * 1
+            )
+            self.assertEqual(
+                counters["inductor"]["fxgraph_cache_hit"], grad_multiplier * 1
+            )
+            self.assertEqual(
+                counters["inductor"]["fxgraph_lookup_write_file"], grad_multiplier * 1
+            )
 
             if bundle_triton and device != "cpu":
-                self.assertEqual(counters["inductor"]["triton_bundler_save_kernel"], 7)
                 self.assertEqual(
-                    counters["inductor"]["triton_bundler_read_and_emit_kernel"], 7
+                    counters["inductor"]["triton_bundler_save_kernel"],
+                    grad_multiplier * 7,
+                )
+                self.assertEqual(
+                    counters["inductor"]["triton_bundler_read_and_emit_kernel"],
+                    grad_multiplier * 7,
                 )
 
     @requires_triton()
@@ -448,29 +491,34 @@ class TestFxGraphCache(TestCase):
     @requires_triton()
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
-    def test_triton_higher_order_op_bypass(self):
+    @parametrize("bundle_triton", (False, True))
+    @parametrize("grad", (False, True))
+    def test_triton_higher_order_op_bypass(self, bundle_triton, grad):
         """
-        Verify that we bypass the cache when we have a triton higher order ops.
+        Verify that we bypass the cache when we have a triton higher order ops
+        and that bundler start/end works with a cache bypass.
         """
 
         def fn(x, y):
-            output = torch.zeros_like(x)
-            n_elements = output.numel()
+            n_elements = x.numel()
             grid = lambda meta: (  # noqa: E731
                 triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
             )
-            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
-            return output
+            add_kernel[grid](x, y, x, n_elements, BLOCK_SIZE=4)
+            return x
 
-        compiled_fn = torch.compile(fn, fullgraph=True)
+        with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
+            compiled_fn = torch.compile(fn, fullgraph=True)
 
-        x = torch.randn(4, device=GPU_TYPE)
-        y = torch.randn(4, device=GPU_TYPE)
-        compiled_fn(x, y)
+            x = torch.randn(4, device=GPU_TYPE, requires_grad=grad)
+            y = torch.randn(4, device=GPU_TYPE, requires_grad=grad)
+            result = compiled_fn(x, y)
+            if grad:
+                result.sum().backward()
 
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-        self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+            self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)
 
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})


@@ -1145,9 +1145,12 @@ class FxGraphCache:
         triton_bundler_meta = TritonBundler.read_and_emit(bundle)
         if (meta := triton_bundler_meta) is not None:
             cache_info["triton_bundler_meta"] = str(meta)
-            get_chromium_event_logger().add_event_data(
-                "inductor_compile", cached_kernel_names=meta.cached_kernel_names
-            )
+            logger = get_chromium_event_logger()
+            if "inductor_compile" in logger.get_stack():
+                # TODO: Clean up autograd cache integration
+                logger.add_event_data(
+                    "inductor_compile", cached_kernel_names=meta.cached_kernel_names
+                )
 
         inductor_meta = autotune_cache.inductor_meta_from_config()
         AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
@@ -1504,13 +1507,20 @@ class FxGraphCache:
             assert compiled_graph is None
             assert key_info is not None
             start_time = cache_info["cache_event_time"]
-            compiled_graph = compile_fx_fn(
-                gm, example_inputs, inputs_to_check, fx_kwargs
-            )
-            compiled_graph._time_taken_ns = time_ns() - start_time
-            cache_key = key_info[0]
-            compiled_graph._fx_graph_cache_key = cache_key
-            compiled_graph._triton_bundle, triton_bundler_meta = TritonBundler.collect()
+            TritonBundler.begin_compile()
+            try:
+                compiled_graph = compile_fx_fn(
+                    gm, example_inputs, inputs_to_check, fx_kwargs
+                )
+                compiled_graph._time_taken_ns = time_ns() - start_time
+                cache_key = key_info[0]
+                compiled_graph._fx_graph_cache_key = cache_key
+                (
+                    compiled_graph._triton_bundle,
+                    triton_bundler_meta,
+                ) = TritonBundler.collect()
+            finally:
+                TritonBundler.end_compile()
             if triton_bundler_meta is not None:
                 cache_info["triton_bundler_meta"] = str(triton_bundler_meta)
             cache_info["time_taken_ns"] = compiled_graph._time_taken_ns


@@ -97,7 +97,6 @@ from .runtime import autotune_cache
 from .runtime.autotune_cache import AutotuneCacheBundler
 from .scheduler import BaseSchedulerNode
 from .sizevars import SizeVarAllocator
-from .triton_bundler import TritonBundler
 from .utils import (
     convert_shape_to_inductor,
     gather_origins,
@@ -1966,7 +1965,6 @@ class GraphLowering(torch.fx.Interpreter):
         inductor_meta = autotune_cache.inductor_meta_from_config()
         AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
-        TritonBundler.begin_compile()
 
         try:
             linemap = [(line_no, node.stack_trace) for line_no, node in linemap]  # type: ignore[misc]


@@ -70,6 +70,8 @@ class TritonBundler:
     - TritonBundler.begin_compile is called when we start compiling in Inductor
     - TritonBundler.put is called each time a Triton Kernel is compiled
     - TritonBundler.collect is called when a cache entry is being generated
+    - TritonBundler.end_compile is called to indicate bundling is completed,
+      collect will execute this function as well.
     - TritonBundler.read_and_emit is called when a cache entry is read
     """
@@ -92,7 +94,9 @@
         if not config.is_fbcode():
             return False
 
-        return justknobs_check("pytorch/remote_cache:bundle_triton_into_fx_graph_cache")
+        return justknobs_check(
+            "pytorch/remote_cache:bundle_triton_into_fx_graph_cache_v2"
+        )
 
     @classmethod
     def begin_compile(cls) -> None:
@@ -102,9 +106,19 @@
         """
         if not TritonBundler.is_enabled():
             return
+        log.debug("TritonBundler.begin_compile is called")
         assert cls._entries is None
         cls._entries = []
 
+    @classmethod
+    def end_compile(cls) -> None:
+        """
+        Finalizes the TritonBundler. If collect is not yet called, it
+        discards the current bundle.
+        """
+        log.debug("TritonBundler.end_compile is called")
+        cls._entries = None
+
     @classmethod
     def put(cls, kernel_hash: str, device: int) -> None:
         """
@@ -127,7 +141,7 @@
         This function also finalizes the current bundle.
         """
         if not TritonBundler.is_enabled():
-            cls._entries = None
+            cls.end_compile()
             return [], None
 
         with dynamo_timed(
@@ -171,7 +185,7 @@
                                 artifacts,
                             )
                         )
-                cls._entries = None
+                cls.end_compile()
                 return result, TritonBundlerMetadata(kernel_names)
 
         return [], None