Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 12:54:11 +08:00
End TritonBundle on non-cache write codepaths (#139698)
Summary: When we bypass the cache write in Inductor, we were also forgetting to reset the bundle. This change moves resetting the bundle into the post_compile step so that it gets reset uniformly. This diff also turns on the cache internally so that we can do a code rollout.

Test Plan: updated tests

Differential Revision: D65457224

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139698
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 4d5cc1b4ef
Commit: c0d21b6581
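The fix described in the summary is, in essence, to wrap graph compilation in a begin_compile/end_compile pair so the Triton bundle is reset even when the cache write is bypassed. Below is a minimal sketch of that pattern, assuming simplified placeholder names (compile_fn, gm, example_inputs) rather than the exact Inductor call signatures; the real change to FxGraphCache appears in the diff further down.

# Minimal sketch of the bundler-reset pattern this PR introduces. compile_fn,
# gm, and example_inputs are placeholders, not the exact Inductor internals.
from torch._inductor.triton_bundler import TritonBundler


def compile_and_bundle(compile_fn, gm, example_inputs):
    TritonBundler.begin_compile()  # start accumulating compiled Triton kernels
    try:
        compiled_graph = compile_fn(gm, example_inputs)
        # collect() packages the accumulated kernels for the cache entry and
        # finalizes the bundle on the normal cache-write path.
        compiled_graph._triton_bundle, meta = TritonBundler.collect()
    finally:
        # end_compile() also runs on bypass/error paths, so the bundle state
        # is always discarded instead of leaking into the next compile.
        TritonBundler.end_compile()
    return compiled_graph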
@@ -123,7 +123,8 @@ class TestFxGraphCache(TestCase):
     @parametrize("dtype", (torch.float32, torch.bfloat16))
     @parametrize("dynamic", (False, True))
     @parametrize("bundle_triton", (False, True))
-    def test_cache_load_function(self, device, dtype, dynamic, bundle_triton):
+    @parametrize("grad", (False, True))
+    def test_cache_load_function(self, device, dtype, dynamic, bundle_triton, grad):
         """
         Verify that we can populate and load functions from the cache.
         """
@@ -132,23 +133,43 @@ class TestFxGraphCache(TestCase):
         if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
             raise unittest.SkipTest("requires SM80 or later")

-        def fn(x, y):
-            return (x * 2, y @ y)
+        grad_multiplier = 2 if grad else 1

-        a = torch.rand(25, dtype=dtype, device=device)
-        b = torch.rand(5, 5, dtype=dtype, device=device)
+        def fn(x, y):
+            yy = y @ y
+            return x * 2 + yy.view(25)
+
+        a_orig = torch.rand(25, dtype=dtype, device=device)
+        b_orig = torch.rand(5, 5, dtype=dtype, device=device)

         with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
             compiled_fn = torch.compile(fn, dynamic=dynamic)

+            a1 = a_orig.clone().requires_grad_(grad)
+            b1 = b_orig.clone().requires_grad_(grad)
+            a2 = a_orig.clone().requires_grad_(grad)
+            b2 = b_orig.clone().requires_grad_(grad)
+
             # A first call should miss in the cache.
-            self.assertEqual(fn(a, b), compiled_fn(a, b))
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            eager_result = fn(a1, b1)
+            compiled_result = compiled_fn(a2, b2)
+            self.assertEqual(eager_result, compiled_result)
+            if grad:
+                eager_result.sum().backward()
+                compiled_result.sum().backward()
+                self.assertEqual(a1.grad, a2.grad)
+                self.assertEqual(b1.grad, b2.grad)
+            self.assertEqual(
+                counters["inductor"]["fxgraph_cache_miss"], grad_multiplier * 1
+            )
             self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
             self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)

             if bundle_triton and device != "cpu":
-                self.assertEqual(counters["inductor"]["triton_bundler_save_kernel"], 7)
+                self.assertEqual(
+                    counters["inductor"]["triton_bundler_save_kernel"],
+                    grad_multiplier * 7,
+                )
                 self.assertEqual(
                     counters["inductor"]["triton_bundler_read_and_emit_kernel"], 0
                 )
@@ -161,15 +182,37 @@ class TestFxGraphCache(TestCase):
             PyCodeCache.cache_clear()
             shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)

-            self.assertEqual(fn(a, b), compiled_fn(a, b))
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-            self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
+            a1 = a_orig.clone().requires_grad_(grad)
+            b1 = b_orig.clone().requires_grad_(grad)
+            a2 = a_orig.clone().requires_grad_(grad)
+            b2 = b_orig.clone().requires_grad_(grad)
+
+            eager_result = fn(a1, b1)
+            compiled_result = compiled_fn(a2, b2)
+            self.assertEqual(eager_result, compiled_result)
+            if grad:
+                eager_result.sum().backward()
+                compiled_result.sum().backward()
+                self.assertEqual(a1.grad, a2.grad)
+                self.assertEqual(b1.grad, b2.grad)
+            self.assertEqual(
+                counters["inductor"]["fxgraph_cache_miss"], grad_multiplier * 1
+            )
+            self.assertEqual(
+                counters["inductor"]["fxgraph_cache_hit"], grad_multiplier * 1
+            )
+            self.assertEqual(
+                counters["inductor"]["fxgraph_lookup_write_file"], grad_multiplier * 1
+            )

             if bundle_triton and device != "cpu":
-                self.assertEqual(counters["inductor"]["triton_bundler_save_kernel"], 7)
                 self.assertEqual(
-                    counters["inductor"]["triton_bundler_read_and_emit_kernel"], 7
+                    counters["inductor"]["triton_bundler_save_kernel"],
+                    grad_multiplier * 7,
+                )
+                self.assertEqual(
+                    counters["inductor"]["triton_bundler_read_and_emit_kernel"],
+                    grad_multiplier * 7,
                 )

     @requires_triton()
@@ -448,29 +491,34 @@ class TestFxGraphCache(TestCase):
     @requires_triton()
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
-    def test_triton_higher_order_op_bypass(self):
+    @parametrize("bundle_triton", (False, True))
+    @parametrize("grad", (False, True))
+    def test_triton_higher_order_op_bypass(self, bundle_triton, grad):
         """
-        Verify that we bypass the cache when we have a triton higher order ops.
+        Verify that we bypass the cache when we have a triton higher order ops
+        and that bundler start/end works with a cache bypass.
         """

         def fn(x, y):
-            output = torch.zeros_like(x)
-            n_elements = output.numel()
+            n_elements = x.numel()
             grid = lambda meta: (  # noqa: E731
                 triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
             )
-            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
-            return output
+            add_kernel[grid](x, y, x, n_elements, BLOCK_SIZE=4)
+            return x

-        compiled_fn = torch.compile(fn, fullgraph=True)
+        with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
+            compiled_fn = torch.compile(fn, fullgraph=True)

-        x = torch.randn(4, device=GPU_TYPE)
-        y = torch.randn(4, device=GPU_TYPE)
-        compiled_fn(x, y)
+            x = torch.randn(4, device=GPU_TYPE, requires_grad=grad)
+            y = torch.randn(4, device=GPU_TYPE, requires_grad=grad)
+            result = compiled_fn(x, y)
+            if grad:
+                result.sum().backward()

-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-        self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+            self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)

     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
@@ -1145,9 +1145,12 @@ class FxGraphCache:
         triton_bundler_meta = TritonBundler.read_and_emit(bundle)
         if (meta := triton_bundler_meta) is not None:
             cache_info["triton_bundler_meta"] = str(meta)
-            get_chromium_event_logger().add_event_data(
-                "inductor_compile", cached_kernel_names=meta.cached_kernel_names
-            )
+            logger = get_chromium_event_logger()
+            if "inductor_compile" in logger.get_stack():
+                # TODO: Clean up autograd cache integration
+                logger.add_event_data(
+                    "inductor_compile", cached_kernel_names=meta.cached_kernel_names
+                )

         inductor_meta = autotune_cache.inductor_meta_from_config()
         AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
@@ -1504,13 +1507,20 @@ class FxGraphCache:
             assert compiled_graph is None
             assert key_info is not None
             start_time = cache_info["cache_event_time"]
-            compiled_graph = compile_fx_fn(
-                gm, example_inputs, inputs_to_check, fx_kwargs
-            )
-            compiled_graph._time_taken_ns = time_ns() - start_time
-            cache_key = key_info[0]
-            compiled_graph._fx_graph_cache_key = cache_key
-            compiled_graph._triton_bundle, triton_bundler_meta = TritonBundler.collect()
+            TritonBundler.begin_compile()
+            try:
+                compiled_graph = compile_fx_fn(
+                    gm, example_inputs, inputs_to_check, fx_kwargs
+                )
+                compiled_graph._time_taken_ns = time_ns() - start_time
+                cache_key = key_info[0]
+                compiled_graph._fx_graph_cache_key = cache_key
+                (
+                    compiled_graph._triton_bundle,
+                    triton_bundler_meta,
+                ) = TritonBundler.collect()
+            finally:
+                TritonBundler.end_compile()
             if triton_bundler_meta is not None:
                 cache_info["triton_bundler_meta"] = str(triton_bundler_meta)
             cache_info["time_taken_ns"] = compiled_graph._time_taken_ns
@@ -97,7 +97,6 @@ from .runtime import autotune_cache
 from .runtime.autotune_cache import AutotuneCacheBundler
 from .scheduler import BaseSchedulerNode
 from .sizevars import SizeVarAllocator
-from .triton_bundler import TritonBundler
 from .utils import (
     convert_shape_to_inductor,
     gather_origins,
@@ -1966,7 +1965,6 @@ class GraphLowering(torch.fx.Interpreter):

         inductor_meta = autotune_cache.inductor_meta_from_config()
         AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
-        TritonBundler.begin_compile()

         try:
             linemap = [(line_no, node.stack_trace) for line_no, node in linemap]  # type: ignore[misc]
@@ -70,6 +70,8 @@ class TritonBundler:
     - TritonBundler.begin_compile is called when we start compiling in Inductor
     - TritonBundler.put is called each time a Triton Kernel is compiled
    - TritonBundler.collect is called when a cache entry is being generated
+    - TritonBundler.end_compile is called to indicate bundling is completed,
+      collect will execute this function as well.
     - TritonBundler.read_and_emit is called when a cache entry is read
     """

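As a companion to the lifecycle listed in the docstring above, here is an illustrative split of the hooks by cache path. This is a sketch of the call order only, not the real Inductor call sites; save_entry and load_entry are hypothetical helpers standing in for the FX graph cache plumbing.

# Illustration of the TritonBundler hook order per cache path; save_entry and
# load_entry are hypothetical stand-ins for the FX graph cache plumbing.
from torch._inductor.triton_bundler import TritonBundler


def on_cache_write(save_entry, compiled_graph):
    # By this point Inductor has called begin_compile(), and each compiled
    # Triton kernel has been registered via put(kernel_hash, device).
    bundle, meta = TritonBundler.collect()  # package kernels; finalizes bundle
    save_entry(compiled_graph, bundle)


def on_cache_hit(load_entry, key):
    compiled_graph, bundle = load_entry(key)
    # Re-emit the bundled kernels so the cached graph can run without
    # recompiling them locally.
    meta = TritonBundler.read_and_emit(bundle)
    return compiled_graph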
@@ -92,7 +94,9 @@ class TritonBundler:
         if not config.is_fbcode():
             return False

-        return justknobs_check("pytorch/remote_cache:bundle_triton_into_fx_graph_cache")
+        return justknobs_check(
+            "pytorch/remote_cache:bundle_triton_into_fx_graph_cache_v2"
+        )

     @classmethod
     def begin_compile(cls) -> None:
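The hunk above also bumps the internal JustKnobs switch to a _v2 knob, which lines up with the summary's note about turning on the cache internally for a code rollout. Below is a hedged reconstruction of the surrounding is_enabled() gate: only the is_fbcode/JustKnobs tail is taken verbatim from the diff, and the explicit-config check is inferred from the tests' use of config.patch(bundle_triton_into_fx_graph_cache=...).

# Hedged reconstruction of the enablement gate; only the is_fbcode/JustKnobs
# tail appears in the diff above, the explicit-config check is inferred.
from torch._inductor import config
from torch._utils_internal import justknobs_check


def bundling_enabled() -> bool:
    # An explicit config value (as the tests set via config.patch) wins.
    if (flag := config.bundle_triton_into_fx_graph_cache) is not None:
        return flag
    # Outside fbcode there is no JustKnobs service, so default to off.
    if not config.is_fbcode():
        return False
    # Internal rollout is gated by a JustKnobs switch, bumped to _v2 here.
    return justknobs_check(
        "pytorch/remote_cache:bundle_triton_into_fx_graph_cache_v2"
    )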
@@ -102,9 +106,19 @@ class TritonBundler:
         """
         if not TritonBundler.is_enabled():
             return
+        log.debug("TritonBundler.begin_compile is called")
         assert cls._entries is None
         cls._entries = []

     @classmethod
+    def end_compile(cls) -> None:
+        """
+        Finalizes the TritonBundler. If collect is not yet called, it
+        discards the current bundle.
+        """
+        log.debug("TritonBundler.end_compile is called")
+        cls._entries = None
+
+    @classmethod
     def put(cls, kernel_hash: str, device: int) -> None:
         """
@@ -127,7 +141,7 @@ class TritonBundler:
         This function also finalizes the current bundle.
         """
         if not TritonBundler.is_enabled():
-            cls._entries = None
+            cls.end_compile()
             return [], None

         with dynamo_timed(
@@ -171,7 +185,7 @@ class TritonBundler:
                             artifacts,
                         )
                     )
-                cls._entries = None
+                cls.end_compile()
                 return result, TritonBundlerMetadata(kernel_names)

         return [], None