pytorch/test/inductor/test_codecache.py
Yuanyuan Chen fdab48a7c1 Enable all PIE rules on ruff (#165814)
This PR enables all PIE rules on ruff. Some rules from this family were already enabled; the newly added rules are:
```
PIE796  Enum contains duplicate value: {value}
PIE808  Unnecessary start argument in range
```
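
For reference, a minimal hand-written sketch (not taken from the PR) of the patterns these two rules flag:

```
import enum


class Color(enum.Enum):
    RED = 1
    CRIMSON = 1  # PIE796: duplicate value; CRIMSON silently becomes an alias of RED


for i in range(0, 10):  # PIE808: the explicit start of 0 is redundant; write range(10)
    print(i)
```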

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165814
Approved by: https://github.com/ezyang
2025-10-18 07:36:18 +00:00


# Owner(s): ["module: inductor"]
import functools
import logging
import os
import pickle
import shutil
import subprocess
import sys
import tempfile
import textwrap
import unittest
from contextlib import contextmanager
from typing import Optional, Union
from typing_extensions import override
from unittest import mock
import torch
from torch._dynamo import reset
from torch._dynamo.package import DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._dynamo.utils import counters
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._inductor import config, metrics
from torch._inductor.codecache import (
BypassFxGraphCache,
cuda_compile_command,
CUDACodeCache,
FxGraphCachePickler,
FxGraphHashDetails,
PyCodeCache,
TensorMetadata,
TensorMetadataAndValues,
)
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.custom_graph_pass import (
CustomGraphModulePass,
CustomGraphPass,
CustomPartitionerFn,
get_hash_for_files,
)
from torch._inductor.graph import GraphLowering
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import clear_caches, fresh_cache
from torch._library import capture_triton
from torch.compiler._cache import (
CacheArtifact,
CacheArtifactFactory,
CacheArtifactManager,
)
from torch.testing._internal.common_cuda import (
SM80OrLater,
TEST_MULTIGPU,
with_tf32_off,
)
from torch.testing._internal.common_device_type import largeTensorTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_FBCODE,
IS_SANDCASTLE,
parametrize,
TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
HAS_GPU,
HAS_MULTIGPU,
HAS_TRITON,
patch_inductor_backend,
requires_gpu,
requires_triton,
)
from torch.testing._internal.triton_utils import requires_cuda_and_triton
try:
from . import custom_inductor_config
except ImportError:
import custom_inductor_config
if HAS_TRITON:
import triton # @manual
from torch.testing._internal.triton_utils import add_kernel, sub_kernel
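# Run every test in this file with dynamo's fake-tensor cache enabled and with its
# cross-checking mode on, so cached fake-tensor results are validated against a fresh computation.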
torch._dynamo.config.fake_tensor_cache_enabled = True
torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True
class LogCaptureHandler(logging.Handler):
def __init__(self, level):
super().__init__(level)
self.records = []
def emit(self, record):
self.records.append(record)
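# capture_logs temporarily attaches a handler to the named logger and yields the list of
# records emitted while the with-block runs; the original level is restored on exit.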
@contextmanager
def capture_logs(log_name, log_level):
try:
logger = logging.getLogger(log_name)
old_level = logger.level
handler = logging.Handler()
logger.setLevel(log_level)
log_records = []
def emit(record):
log_records.append(record)
handler.emit = emit
logger.addHandler(handler)
yield log_records
finally:
logger.removeHandler(handler)
logger.setLevel(old_level)
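# Small two-conv module with an explicit graph break between the convs, so compiling it
# produces multiple graphs (and therefore multiple cache entries).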
class MyModelConv2d(torch.nn.Module):
def __init__(self, dim=512):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, dim, kernel_size=3, stride=2, bias=False)
self.conv2 = torch.nn.Conv2d(dim, dim, kernel_size=3, stride=2, bias=False)
def forward(self, x):
x = self.conv1(x)
torch._dynamo.graph_break()
x = self.conv2(x)
return x
class TestPyCodeCache(TestCase):
def test_linemaps_empty(self):
src = """import torch"""
(key, path) = PyCodeCache.write(src, "")
# Load with an empty linemap
PyCodeCache.load_by_key_path(key, path, linemap=[])
stack_frames = PyCodeCache.stack_frames_for_code(path, 0)
self.assertEqual(stack_frames, None)
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
def test_editable_cached_wrapper(self):
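# Populate the FX graph cache from one subprocess, confirm a second subprocess gets a
# cache hit, then hand-edit the cached wrapper and check that a third run executes the edit.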
with tempfile.TemporaryDirectory() as tmpdir:
env = os.environ.copy()
env["TORCHINDUCTOR_CACHE_DIR"] = tmpdir
step1 = textwrap.dedent(
"""
import glob
import os
import torch
import warnings
from torch._inductor import config
warnings.filterwarnings("ignore")
config.fx_graph_cache = True
config.fx_graph_remote_cache = False
torch._dynamo.reset()
@torch.compile(backend="inductor")
def f(x):
return x * 2
f(torch.ones(2))
cache_dir = os.environ["TORCHINDUCTOR_CACHE_DIR"]
pyfiles = glob.glob(os.path.join(cache_dir, "**", "*.py"), recursive=True)
print(pyfiles[0])
"""
)
wrapper_path = (
subprocess.check_output([sys.executable, "-c", step1], env=env)
.decode()
.strip()
)
step2 = textwrap.dedent(
"""
import torch
import warnings
from torch._dynamo.utils import counters
from torch._inductor import config
warnings.filterwarnings("ignore")
config.fx_graph_cache = True
config.fx_graph_remote_cache = False
torch._dynamo.reset()
@torch.compile(backend="inductor")
def f(x):
return x * 2
f(torch.ones(2))
print(counters["inductor"]["fxgraph_cache_hit"])
"""
)
hit = (
subprocess.check_output([sys.executable, "-c", step2], env=env)
.decode()
.strip()
)
self.assertEqual(hit, "1")
with open(wrapper_path) as f:
src = f.read()
with open(wrapper_path, "w") as f:
f.write(
src.replace(
"def call(self, args):",
"def call(self, args):\n print('debug')",
)
)
step3 = textwrap.dedent(
"""
import torch
import warnings
from torch._inductor import config
warnings.filterwarnings("ignore")
config.fx_graph_cache = True
config.fx_graph_remote_cache = False
torch._dynamo.reset()
@torch.compile(backend="inductor")
def f(x):
return x * 2
f(torch.ones(2))
"""
)
out = subprocess.check_output(
[sys.executable, "-c", step3], env=env
).decode()
self.assertIn("debug", out)
@instantiate_parametrized_tests
class TestFxGraphCache(TestCase):
device_type = GPU_TYPE
def setUp(self):
super().setUp()
counters.clear()
DynamoCache.clear()
PrecompileContext.clear()
AOTAutogradCache.clear()
PatchCaches.setUp()
CacheArtifactManager.clear()
torch._dynamo.reset()
def tearDown(self):
super().tearDown()
PatchCaches.tearDown()
def reset(self):
AOTAutogradCache.clear()
DynamoCache.clear()
PrecompileContext.clear()
PyCodeCache.cache_clear(purge=True)
torch._dynamo.reset()
clear_caches()
@requires_triton()
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"compile_threads": 1})
@parametrize("device", (GPU_TYPE, "cpu"))
@parametrize("dtype", (torch.float32, torch.bfloat16))
@parametrize("dynamic", (False, True))
@parametrize("bundle_triton", (False, True))
@parametrize("use_static_cuda_launcher", (False, True))
@parametrize("grad", (False, True))
def test_cache_load_function(
self, device, dtype, dynamic, bundle_triton, use_static_cuda_launcher, grad
):
"""
Verify that we can populate and load functions from the cache.
"""
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
raise unittest.SkipTest("requires SM80 or later")
if use_static_cuda_launcher and not (device == "cuda" and bundle_triton):
raise unittest.SkipTest(
"Static cuda launcher requires cuda and triton bundling"
)
if use_static_cuda_launcher and TEST_WITH_ROCM:
raise unittest.SkipTest("Static cuda launcher doesn't work with ROCM")
grad_multiplier = 2 if grad else 1
def fn(x, y):
yy = y @ y
return x * 2 + yy.view(25)
a_orig = torch.rand(25, dtype=dtype, device=device)
b_orig = torch.rand(5, 5, dtype=dtype, device=device)
with config.patch(
bundle_triton_into_fx_graph_cache=bundle_triton,
use_static_cuda_launcher=use_static_cuda_launcher,
):
compiled_fn = torch.compile(fn, dynamic=dynamic)
a1 = a_orig.clone().requires_grad_(grad)
b1 = b_orig.clone().requires_grad_(grad)
a2 = a_orig.clone().requires_grad_(grad)
b2 = b_orig.clone().requires_grad_(grad)
# A first call should miss in the cache.
eager_result = fn(a1, b1)
compiled_result = compiled_fn(a2, b2)
self.assertEqual(eager_result, compiled_result)
if grad:
eager_result.sum().backward()
compiled_result.sum().backward()
self.assertEqual(a1.grad, a2.grad)
self.assertEqual(b1.grad, b2.grad)
self.assertEqual(
counters["inductor"]["fxgraph_cache_miss"], grad_multiplier * 1
)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
# we expect:
# .ttir
# .ttgir
# .llir
# .ptx (cuda) or .spv (xpu)
# .json
# __grp__.*.json
# optionally, we can also get
# .cubin (CUDA only)
# .source (new versions of triton only, triton-lang/triton#6992)
# to avoid depending on the device and triton version, just assert that
# we have at least 6 kernels.
save_and_read_min_artifact_count = 6
if bundle_triton and device != "cpu":
self.assertGreaterEqual(
counters["inductor"]["triton_bundler_save_kernel"],
grad_multiplier * save_and_read_min_artifact_count,
)
self.assertEqual(
counters["inductor"]["triton_bundler_read_and_emit_kernel"], 0
)
if use_static_cuda_launcher:
self.assertEqual(
counters["inductor"]["triton_bundler_save_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)
self.assertEqual(
counters["inductor"]["triton_bundler_load_static_autotuner"], 0
)
# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
self.reset()
# Clean triton kernels
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
a1 = a_orig.clone().requires_grad_(grad)
b1 = b_orig.clone().requires_grad_(grad)
a2 = a_orig.clone().requires_grad_(grad)
b2 = b_orig.clone().requires_grad_(grad)
eager_result = fn(a1, b1)
compiled_result = compiled_fn(a2, b2)
self.assertEqual(eager_result, compiled_result)
if grad:
eager_result.sum().backward()
compiled_result.sum().backward()
self.assertEqual(a1.grad, a2.grad)
self.assertEqual(b1.grad, b2.grad)
self.assertEqual(
counters["inductor"]["fxgraph_cache_miss"], grad_multiplier * 1
)
self.assertEqual(
counters["inductor"]["fxgraph_cache_hit"], grad_multiplier * 1
)
self.assertEqual(
counters["inductor"]["fxgraph_lookup_write_file"], grad_multiplier * 1
)
if bundle_triton and device != "cpu":
self.assertGreaterEqual(
counters["inductor"]["triton_bundler_save_kernel"],
grad_multiplier * save_and_read_min_artifact_count,
)
self.assertGreaterEqual(
counters["inductor"]["triton_bundler_read_and_emit_kernel"],
grad_multiplier * save_and_read_min_artifact_count,
)
if use_static_cuda_launcher:
self.assertEqual(
counters["inductor"]["triton_bundler_save_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)
self.assertEqual(
counters["inductor"]["triton_bundler_load_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)
self.reset()
a1 = a_orig.clone().requires_grad_(grad)
b1 = b_orig.clone().requires_grad_(grad)
a2 = a_orig.clone().requires_grad_(grad)
b2 = b_orig.clone().requires_grad_(grad)
eager_result = fn(a1, b1)
if grad:
eager_result.sum().backward()
with torch.compiler.config.patch({"cache_key_tag": "test"}):
compiled_result = compiled_fn(a2, b2)
if grad:
compiled_result.sum().backward()
self.assertEqual(eager_result, compiled_result)
if grad:
self.assertEqual(a1.grad, a2.grad)
self.assertEqual(b1.grad, b2.grad)
self.assertEqual(
counters["inductor"]["fxgraph_cache_miss"], grad_multiplier * 2
)
self.assertEqual(
counters["inductor"]["fxgraph_cache_hit"], grad_multiplier * 1
)
self.assertEqual(
counters["inductor"]["fxgraph_lookup_write_file"], grad_multiplier * 1
)
if bundle_triton and device != "cpu":
self.assertGreaterEqual(
counters["inductor"]["triton_bundler_save_kernel"],
grad_multiplier * save_and_read_min_artifact_count * 2,
)
self.assertGreaterEqual(
counters["inductor"]["triton_bundler_read_and_emit_kernel"],
grad_multiplier * save_and_read_min_artifact_count,
)
if use_static_cuda_launcher:
self.assertEqual(
counters["inductor"]["triton_bundler_save_static_autotuner"],
grad_multiplier * 2 if device == "cuda" else 0,
)
self.assertEqual(
counters["inductor"]["triton_bundler_load_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)
@requires_triton()
@config.patch({"fx_graph_remote_cache": True})
@parametrize("device", (GPU_TYPE, "cpu"))
@parametrize("dtype", (torch.float32, torch.bfloat16))
@parametrize("dynamic", (False, True))
@parametrize("bundle_triton", (False, True))
@parametrize("use_static_cuda_launcher", (False, True))
@config.patch(
{"compile_threads": 1}
) # Can't check globalStats if there are workers
def test_remote_cache_load_function(
self, device, dtype, dynamic, bundle_triton, use_static_cuda_launcher
):
from unittest.mock import patch
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
raise unittest.SkipTest("requires SM80 or later")
if use_static_cuda_launcher and not (device == "cuda" and bundle_triton):
raise unittest.SkipTest(
"Static cuda launcher requires cuda and triton bundling"
)
if use_static_cuda_launcher and TEST_WITH_ROCM:
raise unittest.SkipTest("Static cuda launcher doesn't work with ROCM")
def fn(x, y):
return (x * 2, y @ y)
a = torch.rand(25, dtype=dtype, device=device)
b = torch.rand(5, 5, dtype=dtype, device=device)
with (
config.patch(
{
"fx_graph_remote_cache": True,
"bundle_triton_into_fx_graph_cache": bundle_triton,
"use_static_cuda_launcher": use_static_cuda_launcher,
}
),
patch.dict(os.environ),
PatchCaches(),
):
os.environ.pop("TRITON_CACHE_MANAGER", None)
for _ in range(4):
with fresh_cache():
compiled_fn = torch.compile(fn, dynamic=dynamic)
self.assertEqual(fn(a, b), compiled_fn(a, b))
reset()
self.assertEqual(global_stats.fx_graph, Stats(1, 3, 1))
with torch.compiler.config.patch({"cache_key_tag": "test"}), fresh_cache():
compiled_fn = torch.compile(fn, dynamic=dynamic)
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(global_stats.fx_graph, Stats(2, 3, 2))
# Check that the cache entries seem reasonable
for k in global_stats.fx_graph.cache.keys():
self.assertRegex(k, r"pt2:fx-graph-v1::[0-9a-z]{52}:c[0-9]+")
@requires_triton()
@config.patch(
{
"fx_graph_cache": True,
"fx_graph_remote_cache": False,
"autotune_local_cache": True,
}
)
@parametrize("device", (GPU_TYPE, "cpu"))
@parametrize("dtype", (torch.float32, torch.bfloat16))
@parametrize("dynamic", (False, True))
@torch._functorch.config.patch({"enable_autograd_cache": False})
def test_cache_hot_load(self, device, dtype, dynamic):
"""
Verify that we can populate and hot load functions from the cache.
"""
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
raise unittest.SkipTest("requires SM80 or later")
def fn(x, y):
return x.sin() @ y
a = torch.rand(100, 100, dtype=dtype, device=device)
b = torch.rand(100, 100, dtype=dtype, device=device)
# Record artifacts
with fresh_cache():
compiled_fn = torch.compile(fn, dynamic=dynamic)
# A first call should miss in the cache.
eager_result = fn(a, b)
compiled_result = compiled_fn(a, b)
self.assertEqual(eager_result, compiled_result)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
artifacts = torch.compiler.save_cache_artifacts()
self.assertIsNotNone(artifacts)
artifact_bytes, cache_info = artifacts
autotune_expect = 1 if device == GPU_TYPE else 0
self.assertEqual(len(cache_info.inductor_artifacts), 1)
self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
self.assertEqual(len(cache_info.aot_autograd_artifacts), 0)
self.assertEqual(len(cache_info.pgo_artifacts), 0)
self.reset()
# Clean triton kernels
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
# We did not load anything, so we don't hit yet
with fresh_cache():
eager_result = fn(a, b)
compiled_result = compiled_fn(a, b)
self.assertEqual(eager_result, compiled_result)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
self.reset()
# Clean triton kernels
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
# Hot load and hit
with fresh_cache():
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
self.assertEqual(len(cache_info.inductor_artifacts), 1)
self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
self.assertEqual(len(cache_info.aot_autograd_artifacts), 0)
self.assertEqual(len(cache_info.pgo_artifacts), 0)
eager_result = fn(a, b)
compiled_result = compiled_fn(a, b)
self.assertEqual(eager_result, compiled_result)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
@requires_triton()
@config.patch(
{
"fx_graph_cache": True,
"fx_graph_remote_cache": False,
"autotune_local_cache": True,
}
)
@torch._dynamo.config.patch(
{
"caching_precompile": True,
}
)
@parametrize("dynamic", (False, True))
@parametrize("device", (GPU_TYPE, "cpu"))
@parametrize("dtype", (torch.float32, torch.bfloat16))
def test_cache_hot_load_caching_precompile(self, device, dtype, dynamic):
"""
Verify that we can populate and hot load functions from the cache.
"""
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
raise unittest.SkipTest("requires SM80 or later")
def fn(x, y):
return x.sin() @ y
a = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)
b = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)
# Record artifacts
with fresh_cache():
compiled_fn = torch.compile(fn, dynamic=dynamic)
# A first call should miss in the cache.
eager_result = fn(a, b)
compiled_result = compiled_fn(a, b)
compiled_result.sum().backward()
self.assertEqual(eager_result, compiled_result)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 1)
self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 0)
artifacts = torch.compiler.save_cache_artifacts()
self.assertIsNotNone(artifacts)
artifact_bytes, cache_info = artifacts
autotune_expect = 2 if device == GPU_TYPE else 0
self.assertEqual(len(cache_info.inductor_artifacts), 2)
self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
self.assertEqual(len(cache_info.aot_autograd_artifacts), 1)
self.assertEqual(len(cache_info.pgo_artifacts), 0)
self.assertEqual(len(cache_info.precompile_artifacts), 1)
self.reset()
# Clean triton kernels
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
# We did not load anything, so we don't hit yet
with fresh_cache():
eager_result = fn(a, b)
# With caching precompile, we have to call torch.compile on the function again
# to trigger the cache lookup
compiled_fn = torch.compile(fn, dynamic=dynamic)
compiled_result = compiled_fn(a, b)
compiled_result.sum().backward()
self.assertEqual(eager_result, compiled_result)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 2)
self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 0)
self.reset()
# Clean triton kernels
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
# Hot load and hit
with fresh_cache(), torch.compiler.set_stance("fail_on_recompile"):
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
self.assertEqual(len(cache_info.inductor_artifacts), 2)
self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
self.assertEqual(len(cache_info.aot_autograd_artifacts), 1)
self.assertEqual(len(cache_info.pgo_artifacts), 0)
self.assertEqual(len(cache_info.precompile_artifacts), 1)
# With caching precompile, we have to call torch.compile on the function again
# to trigger the cache lookup
compiled_fn = torch.compile(fn, dynamic=dynamic)
eager_result = fn(a, b)
compiled_result = compiled_fn(a, b)
compiled_result.sum().backward()
self.assertEqual(eager_result, compiled_result)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 2)
self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 1)
@config.patch(
{
"fx_graph_cache": True,
"fx_graph_remote_cache": False,
}
)
def test_cache_hot_load_repeat(self):
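# need_serialize() should be False right after artifacts are saved or loaded, and should
# flip to True once a new graph (new input shapes) is compiled and recorded.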
def fn(x, y):
return x @ y.sin()
compiled_fn = torch.compile(fn, dynamic=False)
a = torch.randn(4, 4)
b = torch.randn(4, 4)
a2 = torch.randn(4, 8)
b2 = torch.randn(8, 4)
with fresh_cache():
eager_result = fn(a, b)
compiled_result = compiled_fn(a, b)
self.assertEqual(eager_result, compiled_result)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
artifacts = torch.compiler.save_cache_artifacts()
self.assertFalse(torch.compiler._cache.CacheArtifactManager.need_serialize())
self.assertIsNotNone(artifacts)
artifact_bytes, cache_info = artifacts
self.reset()
with fresh_cache():
torch.compiler.load_cache_artifacts(artifact_bytes)
eager_result = fn(a, b)
compiled_result = compiled_fn(a, b)
self.assertEqual(eager_result, compiled_result)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertFalse(torch.compiler._cache.CacheArtifactManager.need_serialize())
self.reset()
with fresh_cache():
eager_result = fn(a2, b2)
compiled_result = compiled_fn(a2, b2)
self.assertEqual(eager_result, compiled_result)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertTrue(torch.compiler._cache.CacheArtifactManager.need_serialize())
@torch._dynamo.config.patch(automatic_dynamic_local_pgo=True)
@torch._functorch.config.patch({"enable_autograd_cache": False})
@config.patch({"fx_graph_cache": True, "fx_graph_remote_cache": False})
def test_cache_hot_load_pgo(self):
"""
Verify that we can populate and hot load functions from the cache with pgo.
"""
backend = torch._dynamo.testing.CompileCounterWithBackend("inductor")
@torch.compile(backend=backend, fullgraph=True)
def f(x):
return x * 2
# Record artifacts
with torch.compiler.config.patch(job_id=self.id()), fresh_cache():
f(torch.randn(2, 3))
f(torch.randn(2, 4))
self.assertEqual(backend.frame_count, 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
artifacts = torch.compiler.save_cache_artifacts()
self.assertIsNotNone(artifacts)
artifact_bytes, cache_info = artifacts
self.assertEqual(len(cache_info.inductor_artifacts), 2)
self.assertEqual(len(cache_info.autotune_artifacts), 0)
self.assertEqual(len(cache_info.aot_autograd_artifacts), 0)
self.assertEqual(len(cache_info.pgo_artifacts), 2)
self.reset()
backend.clear()
# Clean triton kernels
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
# Hot load and hit
with torch.compiler.config.patch({"job_id": self.id()}), fresh_cache():
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
self.assertEqual(len(cache_info.inductor_artifacts), 2)
self.assertEqual(len(cache_info.autotune_artifacts), 0)
self.assertEqual(len(cache_info.aot_autograd_artifacts), 0)
self.assertEqual(len(cache_info.pgo_artifacts), 2)
f(torch.randn(2, 5))
f(torch.randn(2, 6))
self.assertEqual(backend.frame_count, 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
@torch._dynamo.config.patch(automatic_dynamic_local_pgo=True)
@torch._functorch.config.patch({"enable_autograd_cache": False})
@config.patch({"fx_graph_cache": True, "fx_graph_remote_cache": False})
def test_cache_hot_load_pgo_swap_file_names(self):
"""
Verify that we can populate and hot load functions from the cache with pgo,
even when the job name/version used in the cache file names changes.
"""
backend = torch._dynamo.testing.CompileCounterWithBackend("inductor")
@torch.compile(backend=backend, fullgraph=True)
def f(x):
return x * 2
# Record artifacts
with mock.patch(
"torch._utils_internal.get_mast_job_name_version", return_value=("foo", 5)
):
with fresh_cache():
f(torch.randn(2, 3))
f(torch.randn(2, 4))
self.assertEqual(backend.frame_count, 2)
artifacts = torch.compiler.save_cache_artifacts()
self.assertIsNotNone(artifacts)
artifact_bytes, cache_info = artifacts
self.assertEqual(len(cache_info.pgo_artifacts), 2)
self.reset()
backend.clear()
# Clean triton kernels
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
# Hot load and hit
with (
mock.patch(
"torch._utils_internal.get_mast_job_name_version",
return_value=("bar", 10),
),
fresh_cache(),
):
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
self.assertEqual(len(cache_info.pgo_artifacts), 2)
f(torch.randn(2, 5))
f(torch.randn(2, 6))
self.assertEqual(backend.frame_count, 1)
def test_cache_hot_load_empty(self):
self.assertIsNone(torch.compiler.save_cache_artifacts())
def test_cache_hot_load_generic(self):
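# Exercise the generic CacheArtifact plumbing: register a custom artifact type, record
# entries through a stub cache, and round-trip them via save/load_cache_artifacts.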
class CacheStub:
def __init__(self):
self.cache = {}
def lookup(self, key):
content = self.cache.get(key)
if content is None:
return None
CacheArtifactManager.record_artifact(
ArbitraryCacheArtifact.type(), key, content
)
return content
def save(self, key, content):
self.cache[key] = content
CacheArtifactManager.record_artifact(
ArbitraryCacheArtifact.type(), key, content
)
def clear(self):
self.cache.clear()
cache_stub = CacheStub()
@CacheArtifactFactory.register
class ArbitraryCacheArtifact(CacheArtifact):
@override
def populate_cache(self) -> None:
cache_stub.cache[self.key] = self.content.decode()
@override
@staticmethod
def type() -> str:
return "test"
@override
@staticmethod
def encode(content: str) -> bytes:
return content.encode()
test_cache = {"1": "foo", "2": "bar", "foo": "bar"}
for k, v in test_cache.items():
cache_stub.save(k, v)
artifacts = torch.compiler.save_cache_artifacts()
self.assertIsNotNone(artifacts)
artifact_bytes, cache_info = artifacts
self.assertEqual(len(cache_info.test_artifacts), 3)
cache_stub.clear()
CacheArtifactManager.clear()
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
self.assertEqual(len(cache_info.test_artifacts), 3)
self.assertEqual(cache_stub.cache, test_cache)
CacheArtifactManager.clear()
cache_stub.lookup("foo")
artifacts = torch.compiler.save_cache_artifacts()
self.assertIsNotNone(artifacts)
_, cache_info = artifacts
self.assertEqual(len(cache_info.test_artifacts), 1)
@requires_triton()
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@parametrize("device", (GPU_TYPE, "cpu"))
@parametrize("dtype", (torch.float32, torch.float64))
@parametrize("dynamic", (False, True))
def test_cache_load_model(self, device, dtype, dynamic):
"""
Verify that we can populate and load models from the cache.
"""
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
def fn(mod, x):
mod.zero_grad()
mod(x).sum().backward()
return [p.grad for p in mod.parameters()]
compiled_fn = torch.compile(fn, dynamic=dynamic)
mod = MyModelConv2d().to(device=device, dtype=dtype)
inp = torch.randn(2, 3, 16, 32, device=device, dtype=dtype)
# The first call should see all cache misses.
counters.clear()
grads1 = compiled_fn(mod, inp)
self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
# The second should see all hits. (First reset so in-memory guards
# don't prevent compilation).
counters.clear()
self.reset()
grads2 = compiled_fn(mod, inp)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)
# And the results should be the same.
self.assertEqual(grads1, grads2)
@largeTensorTest("64GB", device=GPU_TYPE, inductor=True)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@parametrize("device", (GPU_TYPE,))
@parametrize("dtype", (torch.float16, torch.bfloat16))
def test_cache_load_with_guards_int32_bounds(self, device, dtype):
"""
Test caching the same graph, but under conditions that introduce guards
for tensor sizes < int32.
"""
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
raise unittest.SkipTest("requires CUDA SM80 or later")
def fn(x, y):
return (x + x, y + y)
compiled_fn = torch.compile(fn, dynamic=True)
# Iterate over different shapes, varying whether the total
# size is below or above int32. For each combination, we expect
# different guards around whether the symbolic sizes do or do
# not exceed int32.
shapes = (
((5, 6), (7, 8)),
((5, 6), (47000, 47001)),
((47000, 47001), (5, 6)),
)
for a_shape, b_shape in shapes:
a = torch.rand(a_shape, device=device, dtype=dtype)
b = torch.rand(b_shape, device=device, dtype=dtype)
# AVOID a dynamo reset here. We expect guards to have been
# added that will be violated with the new shape. We should
# see a recompilation (along with a cache miss).
counters.clear()
res1 = compiled_fn(a, b)
self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
# A second call should hit. (Reset here to force compilation).
counters.clear()
self.reset()
res2 = compiled_fn(a, b)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(res1, res2)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@parametrize("device", (GPU_TYPE, "cpu"))
@parametrize("dtype", (torch.float32, torch.bfloat16))
def test_cache_load_with_guards_static_bounds(self, device, dtype):
"""
Test caching the same graph, but under conditions that introduce guards
for static bounds.
"""
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
raise unittest.SkipTest("requires SM80 or later")
# See lowering; for all of the pooling operators, we always guard and
# make the height/width static.
def fn(x):
return torch.nn.functional.adaptive_avg_pool2d(x, [5, 7])
compiled_fn = torch.compile(fn, dynamic=True)
# Iterate over different input shapes. Each new shape should cause
# a cache miss.
shapes = ((1, 64, 8, 9), (1, 64, 9, 10), (1, 64, 10, 11))
for shape in shapes:
x = torch.rand(shape, device=device, dtype=dtype)
# AVOID a dynamo reset here. For each cache hit, we expect guards
# to have been added that will be violated with each new shape.
# We should see a recompilation (along with a cache miss).
counters.clear()
res1 = compiled_fn(x)
self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
# A second call should hit.
counters.clear()
self.reset()
res2 = compiled_fn(x)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(res1, res2)
@config.patch("fx_graph_cache", True)
@torch._functorch.config.patch({"enable_autograd_cache": False})
@config.patch("fx_graph_remote_cache", False)
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
@requires_cuda_and_triton
def test_no_arguments_tensor_device_guards(self):
"""
Usually, when there are example inputs, the device index of the inputs
is sufficient to make sure we don't cache hit on results from different
cuda devices.
When the compiled function takes no arguments, we still need to include
the cuda device index in the cache key.
"""
@torch.compile
def f():
y = torch.randn(3, device="cuda")
return (y,)
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
result = f()
self.assertEqual(result[0].device, torch.device("cuda:0"))
self.reset()
# Should not cache hit with device guard
with torch.cuda._DeviceGuard(1):
torch.cuda.set_device(1)
result = f()
self.assertEqual(result[0].device, torch.device("cuda:1"))
@config.patch("fx_graph_cache", True)
@torch._functorch.config.patch({"enable_autograd_cache": False})
@config.patch("fx_graph_remote_cache", False)
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
@requires_cuda_and_triton
def test_tensor_device_guards_cpu_tensor(self):
"""
CPU tensor arguments should still cache hit
"""
@torch.compile
def f(x):
return x.sin()
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
result = f(torch.randn(3, device="cpu"))
self.assertEqual(result.device, torch.device("cpu"))
self.reset()
# Should still cache hit under a different device guard, since only CPU tensors are involved
with torch.cuda._DeviceGuard(1):
torch.cuda.set_device(1)
result = f(torch.randn(3, device="cpu"))
self.assertEqual(result.device, torch.device("cpu"))
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@parametrize("device", (GPU_TYPE, "cpu"))
def test_constant_handling(self, device):
"""
Test that different constants are recognized correctly.
"""
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
def fn1(x):
return x + torch.tensor(list(range(12)), device=device)
def fn2(x):
return x + torch.tensor(list(range(1, 13)), device=device)
a = torch.rand(12, device=device)
compiled_fn1 = torch.compile(fn1)
compiled_fn2 = torch.compile(fn2)
# A call to fn1 should miss in the cache.
self.assertEqual(fn1(a), compiled_fn1(a))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
# A call to fn2 should also miss (the constant is different)
self.assertEqual(fn2(a), compiled_fn2(a))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@parametrize("variant", ("v1", "v2"))
def test_auto_functionalized_caching(self, variant):
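# Graphs containing auto_functionalized custom mutating ops should cache correctly under
# both v1 and v2; switching to a different custom op must be a fresh cache miss.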
if variant == "v1":
patch = torch._inductor.config.patch(enable_auto_functionalized_v2=False)
else:
assert variant == "v2"
patch = torch._inductor.config.patch(enable_auto_functionalized_v2=True)
@torch.library.custom_op("mylib::sin_inplace", mutates_args=["x"])
def sin_inplace(x: torch.Tensor) -> None:
x.sin_()
@torch.library.custom_op("mylib::cos_inplace", mutates_args=["x"])
def cos_inplace(x: torch.Tensor) -> None:
x.cos_()
@torch.compile(fullgraph=True)
def fn(x, op):
y = torch.empty_like(x)
op(y)
return y
x = torch.randn(3)
with patch:
# A first call should miss in the cache.
fn(x, sin_inplace)
self.reset()
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
self.reset()
fn(x, sin_inplace)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
# A third call with different operator should have a cache miss
self.reset()
fn(x, cos_inplace)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
@requires_cuda_and_triton
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@with_tf32_off
def test_flex_attention_caching(self):
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
block_mask = create_block_mask(
lambda b, h, q, kv: q >= kv, None, None, 512, 512
)
def score_mod(score, b, h, q, kv):
return score + (q - kv)
def fn(q, k, v):
return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)
def score_mod2(score, b, h, q, kv):
return score
def fn2(q, k, v):
return flex_attention(q, k, v, score_mod=score_mod2, block_mask=block_mask)
a, b, c = (torch.randn(1, 4, 512, 64).cuda() for _ in range(3))
compiled_fn = torch.compile(fn)
compiled_fn2 = torch.compile(fn2)
atol, rtol = 1e-4, 1e-4
# A first call should miss in the cache.
self.assertEqual(fn(a, b, c), compiled_fn(a, b, c), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
self.reset()
self.assertEqual(fn(a, b, c), compiled_fn(a, b, c), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
# A third call with different score_mod should have a cache miss
self.reset()
self.assertEqual(fn2(a, b, c), compiled_fn2(a, b, c), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
@requires_gpu()
@requires_triton()
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@parametrize("bundle_triton", (False, True))
def test_higher_order_op_bypass(self, bundle_triton):
"""
Verify that we bypass the cache when we have higher order ops,
and that bundler start/end works with a cache bypass.
"""
def fn(x):
def true_fn(x: torch.Tensor):
return x.cos()
def false_fn(x: torch.Tensor):
return x.sin()
return torch.cond(x.shape[0], true_fn, false_fn, (x,))
with config.patch(
bundle_triton_into_fx_graph_cache=bundle_triton,
):
compiled_fn = torch.compile(fn, dynamic=True, fullgraph=True)
x = torch.randn(4, 4, device=GPU_TYPE)
compiled_fn(x)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)
@requires_gpu()
@requires_triton()
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@parametrize("bundle_triton", (False, True))
def test_triton_higher_order_op(self, bundle_triton):
"""
Verify that we can cache a user defined triton kernel higher order op
"""
def fn(x, y):
n_elements = x.numel()
grid = lambda meta: ( # noqa: E731
triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
)
add_kernel[grid](x, y, x, n_elements, BLOCK_SIZE=4)
return x
def fn2(x, y):
n_elements = x.numel()
grid = lambda meta: ( # noqa: E731
triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
)
sub_kernel[grid](x, y, x, n_elements, BLOCK_SIZE=4)
return x
with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
compiled_fn = torch.compile(fn, fullgraph=True)
compiled_fn2 = torch.compile(fn2, fullgraph=True)
x = torch.randn(4, device=GPU_TYPE)
y = torch.randn(4, device=GPU_TYPE)
compiled_fn(x, y)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
self.reset()
# Clean PyCodeCache and triton kernels
PyCodeCache.cache_clear()
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
compiled_fn(x, y)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
self.reset()
# Clean PyCodeCache and triton kernels
PyCodeCache.cache_clear()
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
compiled_fn2(x, y)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
@requires_gpu()
@requires_triton()
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@parametrize("bundle_triton", (False, True))
def test_triton_higher_order_op_different_configs(self, bundle_triton):
"""
Verify that user defined triton kernels with
different configs are cached separately.
"""
add_kernel1 = triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
],
key=[],
)(add_kernel)
add_kernel2 = triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
],
key=[],
)(add_kernel)
def fn(x, y):
n_elements = x.numel()
grid = lambda meta: ( # noqa: E731
triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
)
add_kernel1[grid](x, y, x, n_elements)
return x
def fn2(x, y):
n_elements = x.numel()
grid = lambda meta: ( # noqa: E731
triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
)
add_kernel2[grid](x, y, x, n_elements)
return x
with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
compiled_fn = torch.compile(fn, fullgraph=True)
compiled_fn2 = torch.compile(fn2, fullgraph=True)
x = torch.randn(4, device=GPU_TYPE)
y = torch.randn(4, device=GPU_TYPE)
compiled_fn(x, y)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
self.reset()
# Clean PyCodeCache and triton kernels
PyCodeCache.cache_clear()
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
compiled_fn(x, y)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
self.reset()
# Clean PyCodeCache and triton kernels
PyCodeCache.cache_clear()
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
compiled_fn2(x, y)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
@requires_gpu()
@requires_triton()
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"compile_threads": 1})
@parametrize("bundle_triton", (False, True))
@parametrize("use_static_cuda_launcher", (False, True))
def test_triton_op(self, bundle_triton, use_static_cuda_launcher):
if use_static_cuda_launcher and TEST_WITH_ROCM:
raise unittest.SkipTest("Static cuda launcher doesn't work with ROCM")
libname = "my_cool_namespace"
opname = "my_triton_operator"
@torch._library.triton_op(f"{libname}::{opname}", mutates_args={})
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output = torch.empty_like(x)
n_elements = output.numel()
def grid(meta):
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
return output
def f(x, y):
return add(x, y)
compile_threads = 1 if use_static_cuda_launcher else config.compile_threads
with config.patch(
bundle_triton_into_fx_graph_cache=bundle_triton,
use_static_cuda_launcher=use_static_cuda_launcher,
compile_threads=compile_threads,
):
compiled_fn = torch.compile(f, fullgraph=True)
x = torch.randn(4, device=GPU_TYPE)
y = torch.randn(4, device=GPU_TYPE)
compiled_fn(x, y)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
self.reset()
# Clean PyCodeCache and triton kernels
PyCodeCache.cache_clear()
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
compiled_fn(x, y)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
def test_generated_kernel_count(self):
"""
Test that we bump the generated_kernel_count metric on a cache hit.
"""
torch._logging.set_logs(inductor_metrics=True)
def fn(x, y):
return (x * y + y,)
a = torch.rand(5, 5)
b = torch.rand(5, 5)
compiled_fn = torch.compile(fn)
metrics.reset()
self.assertEqual(metrics.generated_kernel_count, 0)
# Verify the "miss" case.
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(metrics.generated_kernel_count, 1)
# Verify the "hit" case
self.reset()
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(metrics.generated_kernel_count, 2)
torch._logging.set_logs()
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
def test_inductor_counters(self):
"""
Test that we bump the inductor counters on a cache hit.
"""
def fn(a, b):
return torch.mm(a, b)
a = torch.rand(8, 32, device="cpu")
b = torch.rand(32, 8, device="cpu")
compiled_fn = torch.compile(fn)
# Verify the "miss" case.
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
# Verify the "hit" case.
self.reset()
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
def test_cache_clear(self):
"""
Test clearing the cache.
"""
def fn(x, y):
return (x * y,)
a = torch.rand(5, 5)
b = torch.rand(5, 5)
compiled_fn = torch.compile(fn)
# A first call should miss in the cache.
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
# A second call should hit.
counters.clear()
self.reset()
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
# Clear the cache; now we should miss.
counters.clear()
self.reset()
torch._inductor.codecache.FxGraphCache.clear()
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
def test_cache_with_nt(self):
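# Nested (jagged) tensor inputs of different lengths should share one dynamic graph,
# which is then served from the cache after a reset.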
def gen_nt(r):
values = torch.randn(r, 16)
offsets = torch.tensor([0, 2, 3, 6, 13, r])
return torch.nested.nested_tensor_from_jagged(values, offsets)
def fn(nt):
if nt.values().size(0) % 16 == 0:
return nt.sin()
return nt.cos()
inp1 = gen_nt(19)
inp2 = gen_nt(20)
counters.clear()
torch.compile(fn)(inp1)
torch.compile(fn)(inp2)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.reset()
counters.clear()
torch.compile(fn)(inp1)
torch.compile(fn)(inp2)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
def test_cache_with_symint_non_arg_guard(self):
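# A guard derived from a plain int argument (ref_id) should not prevent a cache hit on
# the second, identical call.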
def fn(x, ref_id):
self_id = 22
if self_id == ref_id:
x = torch.mul(x, 1.0)
else:
x = torch.mul(x, 0)
return x
x = torch.ones(2)
counters.clear()
torch.compile(fn, fullgraph=True, dynamic=True)(x, 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.reset()
counters.clear()
torch.compile(fn, fullgraph=True, dynamic=True)(x, 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
def test_cache_guard(self):
def f(x, val):
if val > 5:
return x.sin()
else:
return x.cos()
x = torch.ones(2)
a = torch.compile(f, dynamic=True)(x, 6)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.reset()
counters.clear()
b = torch.compile(f, dynamic=True)(x, 4)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertNotEqual(a, b)
@config.patch({"fx_graph_cache": False, "fx_graph_remote_cache": False})
@requires_cuda_and_triton
@unittest.expectedFailure # TODO: pass in optimize_mem at runtime
def test_async_compile_cache(self):
class SimpleFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x * 2
@staticmethod
def backward(ctx, grad_output):
return grad_output * 2
x = torch.rand([10], requires_grad=True, device="cuda")
counters.clear()
sf = SimpleFunction
out = torch.compile(sf.apply)(x)
out.sum().backward()
self.assertEqual(counters["inductor"]["async_compile_cache_miss"], 1)
self.assertEqual(counters["inductor"]["async_compile_cache_hit"], 1)
@config.patch({"fx_graph_cache": True})
def test_cache_guard_overspec(self):
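# The first loop compiles twice (a specialized graph, then a dynamic recompile); after a
# reset both graphs should come back as cache hits rather than new misses.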
b = torch.tensor([0, 2, 4, 6, 8])
@torch.compile
class MyModel(torch.nn.Module):
def forward(self, x):
return torch.isin(x, b)
model = MyModel()
counters.clear()
for i in range(1, 5):
model(torch.arange(i))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.reset()
counters.clear()
for i in range(1, 5):
model(torch.arange(i))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"freezing": True})
@parametrize("device", (GPU_TYPE, "cpu"))
@parametrize("inlinable", (True, False))
def test_freezing(self, device, inlinable):
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
# For machines with mkldnn_fp16 support, weight_pack in mkldnn_fusion.py causes
# the creation of a mkldnn format tensor which the current implementation does
# not support.
if (
device == "cpu"
and torch.backends.mkldnn.is_available()
and torch.ops.mkldnn._is_mkldnn_fp16_supported()
):
raise unittest.SkipTest("mkldnn tensors unsupported")
# The shape of the frozen constant determines if it will be inlined.
shape = (4,) if inlinable else (8, 8)
class MM(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(shape))
def forward(self, x):
return x @ self.param
dtype = torch.float16
# Populate a cache entry.
mod1 = MM().to(device=device, dtype=dtype)
with torch.no_grad():
x = torch.rand(shape).to(device=device, dtype=dtype)
out0 = mod1(x)
out1 = torch.compile(mod1)(x)
self.assertEqual(out0, out1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
counters.clear()
self.reset()
# Same nn.Module, but with different parameters. In the case that the param can
# be inlined, we should consider the actual tensor value and we expect a cache
# miss (because the values are different here). If the param cannot be inlined,
# then we consider only the tensor metadata and we expect a cache hit.
mod2 = MM().to(device=device, dtype=dtype)
self.assertNotEqual(mod1.param, mod2.param)
with torch.no_grad():
x = torch.rand(shape).to(device=device, dtype=dtype)
out0 = mod2(x)
out1 = torch.compile(mod2)(x)
self.assertEqual(out0, out1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(
counters["inductor"]["fxgraph_cache_miss"], 1 if inlinable else 0
)
self.assertEqual(
counters["inductor"]["fxgraph_cache_hit"], 0 if inlinable else 1
)
@instantiate_parametrized_tests
class TestStandaloneCompile(TestCase):
def setUp(self):
super().setUp()
counters.clear()
PatchCaches.setUp()
CacheArtifactManager.clear()
def tearDown(self):
super().tearDown()
PatchCaches.tearDown()
def reset(self):
AOTAutogradCache.clear()
PyCodeCache.cache_clear(purge=True)
torch._dynamo.reset()
clear_caches()
def capture(self, fn, dynamic=None):
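# Helper backend that records the FX GraphModule, example inputs, and backend kwargs
# produced by dynamo, without running inductor on them.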
def inner(*args):
gm = None
actual_args = None
kwargs = None
def backend(gm_, args_, **kwargs_):
nonlocal gm
nonlocal actual_args
nonlocal kwargs
gm = gm_
actual_args = args_
kwargs = kwargs_
return gm
_ = torch.compile(fn, fullgraph=True, backend=backend, dynamic=dynamic)(
*args
)
return gm, actual_args, kwargs
return inner
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("device", (GPU_TYPE, "cpu"))
@parametrize("format", ("binary", "unpacked"))
@parametrize("dynamic", (False, True))
@parametrize("graph_partition", (False, True))
def test_basic(
self, device: str, format: str, dynamic: bool, graph_partition: bool
) -> None:
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
mod = torch.nn.Linear(1, 3, device=device)
x = torch.randn(4, 1, device=device)
if dynamic:
torch._dynamo.mark_dynamic(x, 0)
def f(x):
with torch.no_grad():
return mod(x), x.sin()
eager_out = f(x)
with (
tempfile.TemporaryDirectory() as temp_dir,
config.patch(graph_partition=graph_partition),
):
path = (
temp_dir
if format == "unpacked"
else os.path.join(temp_dir, "compiled_artifact.bin")
)
with fresh_cache():
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(gm, args)
compiled_artifact.save(path=path, format=format)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
with fresh_cache():
loaded = torch._inductor.CompiledArtifact.load(path=path, format=format)
if dynamic:
concrete_args = [
4 if isinstance(a, torch.SymInt) else a for a in args
]
else:
concrete_args = args
compiled_out = loaded(*concrete_args)
self.assertEqual(eager_out, compiled_out)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("dynamic", (False, True))
def test_call_in_backend(self, dynamic: bool) -> None:
mod = torch.nn.Linear(1, 3)
x = torch.randn(4, 1)
if dynamic:
torch._dynamo.mark_dynamic(x, 0)
def f(x):
with torch.no_grad():
return mod(x)
eager_out = f(x)
def backend(gm, args, **kwargs):
return torch._inductor.standalone_compile(gm, args)
with fresh_cache():
compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x)
self.assertEqual(eager_out, compiled_out)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_save_in_new_path(self) -> None:
mod = torch.nn.Linear(1, 3)
x = torch.randn(4, 1)
def f(x):
with torch.no_grad():
return mod(x)
eager_out = f(x)
with tempfile.TemporaryDirectory() as temp_dir:
path = os.path.join(temp_dir, "new_dir")
with fresh_cache():
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(gm, args)
compiled_artifact.save(path=path, format="unpacked")
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
with fresh_cache():
loaded = torch._inductor.CompiledArtifact.load(
path=path, format="unpacked"
)
compiled_out = loaded(*args)[0]
self.assertEqual(eager_out, compiled_out)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("device", (GPU_TYPE, "cpu"))
def test_modify_unpacked_file(self, device: str) -> None:
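# Save the compiled artifact in unpacked form, hand-edit the generated kernel source
# (2.0 -> 8.0), and verify the reloaded artifact runs the modified code.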
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
x = torch.ones(4, device=device)
def f(x):
with torch.no_grad():
return 2 * x, x.sin()
eager_out = f(x)
with tempfile.TemporaryDirectory() as temp_dir:
with fresh_cache():
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(gm, args)
compiled_out = compiled_artifact(*args)
self.assertEqual(eager_out, compiled_out)
compiled_artifact.save(path=temp_dir, format="unpacked")
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
with fresh_cache():
# Now modify the output file and expect to see the changes
for subdir in os.listdir(temp_dir):
if subdir in ["aotautograd", "fxgraph"]:
continue
subdir_path = os.path.join(temp_dir, subdir)
for file in os.listdir(subdir_path):
file_path = os.path.join(subdir_path, file)
assert os.path.isfile(file_path)
with open(file_path) as f:
file_contents = f.read()
if device == GPU_TYPE:
file_contents = file_contents.replace(
"tmp1 = 2.0", "tmp1 = 8.0"
)
else:
assert device == "cpu"
file_contents = file_contents.replace(
"auto tmp1 = static_cast<float>(2.0);",
"auto tmp1 = static_cast<float>(8.0);",
)
with open(file_path, "w") as f:
f.write(file_contents)
loaded = torch._inductor.CompiledArtifact.load(
path=temp_dir, format="unpacked"
)
compiled_out = loaded(*args)
self.assertEqual(4 * eager_out[0], compiled_out[0])
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@unittest.skipIf(IS_FBCODE, "torch import error")
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_different_process(self):
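        """
        Save a compiled artifact, then load and run it from a separate Python
        process and compare against the eager result.
        """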
x = torch.ones(4, 1)
def f(x):
return x.sin() * 2
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
with tempfile.TemporaryDirectory() as temp_dir:
path = normalize_path_separator(
os.path.join(temp_dir, "compiled_artifact.bin")
)
with fresh_cache():
compiled_artifact = torch._inductor.standalone_compile(gm, args)
compiled_artifact.save(path=path)
script = f"""
import torch
from torch._inductor.utils import fresh_cache
arg = torch.ones(4, 1)
with fresh_cache():
loaded = torch._inductor.CompiledArtifact.load(path="{path}")
compiled_result = loaded(arg)[0]
eager_result = arg.sin() * 2
if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
raise RuntimeError("tensors do not match")
"""
try:
subprocess.check_output(
[sys.executable, "-c", script],
stderr=subprocess.STDOUT,
cwd=os.path.dirname(os.path.realpath(__file__)),
)
except subprocess.CalledProcessError as e:
self.fail(
msg=(
"Subprocess exception while attempting to run test: "
+ e.output.decode("utf-8")
)
)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_dynamic_shapes_from_graph(self):
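        """
        With dynamic_shapes="from_graph", the compiled artifact keeps the
        symbolic batch dimension and can be called with a different size.
        """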
def f(x):
return x.shape[0] * x
x = torch.ones(3)
torch._dynamo.mark_dynamic(x, 0)
with fresh_cache():
# captured graph is lambda s0, x: x * s0
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes="from_graph"
)
x = torch.ones(4)
(result,) = compiled_artifact(4, x)
self.assertEqual(result, x * 4)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"autograd_cache_normalize_inputs": True})
def test_split_module(self):
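        """
        Split a captured graph into three submodules with split_module and
        compile each one with standalone_compile, checking the FX graph and
        AOTAutograd cache counters across the three compiles.
        """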
class Mod(torch.nn.Module):
def forward(self, x, a0, a1, b0, b1, c0, c1):
x = x + (a0**2) + (a1 / 2)
x = x + (b0**2) + (b1 / 2)
x = x + (c0**2) + (c1 / 2)
return x
seen = 0
splits = [4, 8]
def split(n):
nonlocal seen
if seen < splits[0]:
seen += 1
return 0
elif seen < splits[1]:
seen += 1
return 1
else:
seen += 1
return 2
def t():
return torch.randn([])
x = t()
a0 = t()
a1 = t()
b0 = t()
b1 = t()
c0 = t()
c1 = t()
example_inputs = (x, a0, a1, b0, b1, c0, c1)
gm, inps, _ = self.capture(Mod())(*example_inputs)
split = torch.fx.passes.split_module.split_module(gm, gm, split)
# Each of the split graphs only has one output.
ca0 = torch._inductor.standalone_compile(split.submod_0, (a0, x, a1))
ca1 = torch._inductor.standalone_compile(split.submod_1, (b0, x, b1))
ca2 = torch._inductor.standalone_compile(split.submod_2, (c0, x, c1))
y = ca0(a0, x, a1)
y = ca1(b0, y, b1)
y = ca2(c0, y, c1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
        # TODO: split_module causes ca1 and ca2 to have different type annotations
        # for the parameter x, so we get only one AOTAutogradCache hit instead of two
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)
expected = Mod()(*example_inputs)
self.assertEqual(y, expected)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("config_patches", [True, False])
def test_dynamic_shapes_from_example_inputs(self, config_patches):
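        """
        With dynamic_shapes="from_example_inputs", the example inputs passed to
        standalone_compile are specialized on (baked in), even though the
        captured graph itself was dynamic.
        """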
def f(x):
return x.shape[0] * x
x = torch.ones(3)
torch._dynamo.mark_dynamic(x, 0)
with fresh_cache():
# captured graph is lambda s0, x: x * s0
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
if config_patches:
config_patches = {"fx_graph_cache": True}
else:
config_patches = None
# specialized on example inputs
compiled_artifact = torch._inductor.standalone_compile(
gm,
(5, torch.ones(4)),
dynamic_shapes="from_example_inputs",
options={"config_patches": config_patches},
)
x = torch.ones(4)
(result,) = compiled_artifact(3, x)
# int 5 was baked in!
self.assertEqual(result, x * 5)
# size 4 was baked in
with self.assertRaisesRegex(AssertionError, "expected size 5==4"):
x = torch.randn(5)
(result,) = compiled_artifact(4, x)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("dynamic_shapes", ["from_graph", "from_example_inputs"])
def test_static_shapes(self, dynamic_shapes):
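        """
        A graph captured with dynamic=False is specialized on the input size,
        regardless of the dynamic_shapes option passed to standalone_compile.
        """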
def f(x):
return x.shape[0] * x
static_x = torch.randn(3)
with fresh_cache():
# static_gm is lambda x: x * 3
static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
static_gm, [static_x], dynamic_shapes=dynamic_shapes
)
x = torch.randn(3)
(result,) = compiled_artifact(x)
self.assertEqual(result, x * 3)
with self.assertRaisesRegex(AssertionError, "expected size 4==3"):
x = torch.randn(4)
(result,) = compiled_artifact(x)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("dynamic_shapes", ["from_tracing_context", "from_graph"])
def test_backend(self, dynamic_shapes):
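        """
        standalone_compile can be called from inside a dynamo backend; with
        "from_tracing_context" or "from_graph" the dynamic dimension is kept
        and the artifact can be called with a different size.
        """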
def f(x):
return x.shape[0] * x
x = torch.randn(3)
torch._dynamo.mark_dynamic(x, 0)
def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes=dynamic_shapes
)
y = torch.randn(4)
(result,) = compiled_artifact(4, y)
self.assertEqual(result, y * 4)
return compiled_artifact
torch._dynamo.reset()
_ = torch.compile(f, backend=backend)(x)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_backend_dynamic_shapes_from_example_inputs(self):
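        """
        Same as above, but calling standalone_compile from inside a dynamo
        backend with "from_example_inputs": the explicit example inputs are
        specialized on instead of the traced dynamic shapes.
        """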
def f(x):
return x.shape[0] * x
x = torch.ones(4)
torch._dynamo.mark_dynamic(x, 0)
def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs"
)
y = torch.ones(4)
(result,) = compiled_artifact(4, y)
# 5 was baked in
self.assertEqual(result, y * 5)
# shape of y was baked in
with self.assertRaisesRegex(AssertionError, "expected size 5==4"):
y = torch.ones(5)
(result,) = compiled_artifact(4, y)
return compiled_artifact
torch._dynamo.reset()
_ = torch.compile(f, backend=backend)(x)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize(
"dynamic_shapes", ["from_tracing_context", "from_graph", "from_example_inputs"]
)
def test_backend_static_shapes(self, dynamic_shapes):
        # On static_x, all of these options should produce a static graph,
        # but that is hard to verify directly, so these are just smoke tests.
static_x = torch.randn(3)
def f(x):
return x.shape[0] * x
def backend(gm, args, **kwargs):
return torch._inductor.standalone_compile(
gm, args, dynamic_shapes=dynamic_shapes
)
result = torch.compile(f, backend=backend)(static_x)
self.assertEqual(result, static_x * 3)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
def test_custom_pass_handling(self):
"""
Test that properly-registered custom hooks allow caching.
"""
class TestCustomGraphPass(CustomGraphPass):
def __call__(self, graph: torch.fx.graph.Graph) -> None:
return None
def uuid(self) -> Optional[Union[bytes, str]]:
return "uuid"
def fn(a, b):
return torch.mm(a, b)
a = torch.rand(8, 32, device="cpu")
b = torch.rand(32, 8, device="cpu")
compiled_fn = torch.compile(fn)
# The cache should be bypassed if a custom hook doesn't use CustomGraphPass.
with config.patch({"post_grad_custom_pre_pass": lambda x: x}):
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
# With proper usage, we expect normal caching.
custom_pass = TestCustomGraphPass()
with config.patch(
{
"post_grad_custom_pre_pass": custom_pass,
"post_grad_custom_post_pass": custom_pass,
"joint_custom_pre_pass": custom_pass,
"joint_custom_post_pass": custom_pass,
}
):
self.reset()
counters.clear()
# Cache miss
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.reset()
counters.clear()
# Cache hit
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
class TestCustomPartitionerFn(CustomPartitionerFn):
def __init__(self):
self._uuid = None
def __call__(
self, gm, joint_inputs, **kwargs
) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
return gm, gm # Dummy implementation
def uuid(self) -> Optional[Union[bytes, str]]:
return self._uuid
class TestFxGraphCacheHashing(TestCase):
def test_parameter_constants(self):
"""
Test the hashing of parameter constants.
"""
small = torch.nn.Parameter(torch.rand(8))
large = torch.nn.Parameter(torch.rand(32))
self.assertTrue(GraphLowering.can_inline_constant(small))
self.assertFalse(GraphLowering.can_inline_constant(large))
# By default, we hash the metadata and values independent of the size.
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)
data = pickler.dumps(small)
self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
data = pickler.dumps(large)
self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
# For frozen parameters, we only hash the values of small tensors.
gm._has_frozen_params = True
gm._frozen_param0 = small
gm._frozen_param1 = large
small._is_frozen_param = True
large._is_frozen_param = True
pickler = FxGraphCachePickler(gm)
data = pickler.dumps(small)
self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
data = pickler.dumps(large)
self.assertIsInstance(pickle.loads(data), TensorMetadata)
def test_hash_fake_tensors(self):
"""
Test hashing (pickling) FakeTensors with various characteristics.
"""
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)
with torch._subclasses.FakeTensorMode():
# Verify that FakeTensors get pickled into a TensorMetadata:
data = pickler.dumps(torch.randn(1))
self.assertIsInstance(pickle.loads(data), TensorMetadata)
# Different shapes:
self.assertEqual(
pickler.dumps(torch.randn(3)),
pickler.dumps(torch.randn(3)),
)
self.assertNotEqual(
pickler.dumps(torch.randn(3)),
pickler.dumps(torch.randn(4)),
)
self.assertNotEqual(
pickler.dumps(torch.randn(3)),
pickler.dumps(torch.randn(3, 3)),
)
self.assertEqual(
pickler.dumps(torch.randn(3, 3)),
pickler.dumps(torch.randn(3, 3)),
)
self.assertNotEqual(
pickler.dumps(torch.randn(3, 3)),
pickler.dumps(torch.randn(3, 4)),
)
self.assertNotEqual(
pickler.dumps(torch.randn(3, 3)),
pickler.dumps(torch.randn(4, 3)),
)
# Different strides:
self.assertEqual(
pickler.dumps(torch.randn(3, 3)),
pickler.dumps(torch.randn(3, 3).transpose(0, 1).transpose(0, 1)),
)
self.assertNotEqual(
pickler.dumps(torch.randn(3, 3)),
pickler.dumps(torch.randn(3, 3).transpose(0, 1)),
)
# Different storage offsets:
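            # (a sliced tensor and a fresh tensor of the same shape hash the
            # same below, i.e. the storage offset is not part of the hash)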
self.assertEqual(
pickler.dumps(torch.randn(3)[1:]),
pickler.dumps(torch.randn(3)[1:]),
)
self.assertEqual(
pickler.dumps(torch.randn(3)[1:]),
pickler.dumps(torch.randn(2)),
)
# Different dtypes:
self.assertEqual(
pickler.dumps(torch.randn(3, dtype=torch.float32)),
pickler.dumps(torch.randn(3, dtype=torch.float32)),
)
self.assertNotEqual(
pickler.dumps(torch.randn(3, dtype=torch.float32)),
pickler.dumps(torch.randn(3, dtype=torch.float64)),
)
# Different 'requires_grad':
self.assertEqual(
pickler.dumps(torch.randn(3, requires_grad=True)),
pickler.dumps(torch.randn(3, requires_grad=True)),
)
self.assertNotEqual(
pickler.dumps(torch.randn(3, requires_grad=True)),
pickler.dumps(torch.randn(3, requires_grad=False)),
)
# Different memory formats:
self.assertNotEqual(
pickler.dumps(torch.randn(1, 2, 3, 4)),
pickler.dumps(
torch.randn(1, 2, 3, 4).to(memory_format=torch.channels_last)
),
)
# Different devices:
self.assertEqual(
pickler.dumps(torch.randn(3, device="meta")),
pickler.dumps(torch.randn(3, device="meta")),
)
self.assertNotEqual(
pickler.dumps(torch.randn(3, device="meta")),
pickler.dumps(torch.randn(3, device="cpu")),
)
if HAS_MULTIGPU:
self.assertEqual(
pickler.dumps(torch.randn(3, device=f"{GPU_TYPE}:1")),
pickler.dumps(torch.randn(3, device=f"{GPU_TYPE}:1")),
)
self.assertNotEqual(
pickler.dumps(torch.randn(3, device=f"{GPU_TYPE}:0")),
pickler.dumps(torch.randn(3, device=f"{GPU_TYPE}:1")),
)
def test_hash_kwargs(self):
"""
Test the special handling of the kwargs when hashing, i.e.,
ordering of the kwargs dict and any set arguments.
"""
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)
# Dict order of the kwargs should not affect hashes.
details1 = FxGraphHashDetails(None, [], {"a": 0, "z": 1}, [])
details2 = FxGraphHashDetails(None, [], {"z": 1, "a": 0}, [])
self.assertEqual(
pickler.dumps(details1),
pickler.dumps(details2),
)
# Different kwarg values should affect hashes.
details1 = FxGraphHashDetails(None, [], {"a": 0}, [])
details2 = FxGraphHashDetails(None, [], {"a": 1}, [])
self.assertNotEqual(
pickler.dumps(details1),
pickler.dumps(details2),
)
        # Set order should not affect hashes. Sets are unordered, but sorting
        # and creating a new set can change the iteration order.
set1 = {"a", "b", "c", "d", "e", "f", "g"}
set2 = set(sorted(set1)) # noqa: C414
details1 = FxGraphHashDetails(None, [], {"a": set1}, [])
details2 = FxGraphHashDetails(None, [], {"a": set2}, [])
self.assertEqual(
pickler.dumps(details1),
pickler.dumps(details2),
)
# But different set contents should affect hashes.
details1 = FxGraphHashDetails(None, [], {"a": {1, 2, 3}}, [])
details2 = FxGraphHashDetails(None, [], {"a": {1, 2}}, [])
self.assertNotEqual(
pickler.dumps(details1),
pickler.dumps(details2),
)
def test_hash_config_changes(self):
"""
Test that different config settings affect hashes.
"""
with config.patch({"max_autotune": False}):
details1 = FxGraphHashDetails(None, [], {}, [])
details2 = FxGraphHashDetails(None, [], {}, [])
with config.patch({"max_autotune": True}):
details3 = FxGraphHashDetails(None, [], {}, [])
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)
self.assertEqual(
pickler.dumps(details1),
pickler.dumps(details2),
)
self.assertNotEqual(
pickler.dumps(details1),
pickler.dumps(details3),
)
def test_hash_private_config_changes(self):
"""
Test that private config settings affect hashes.
"""
with config.patch({"_micro_pipeline_tp": False}):
details1 = FxGraphHashDetails(None, [], {}, [])
details2 = FxGraphHashDetails(None, [], {}, [])
with config.patch({"_micro_pipeline_tp": True}):
details3 = FxGraphHashDetails(None, [], {}, [])
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)
self.assertEqual(
pickler.dumps(details1),
pickler.dumps(details2),
)
self.assertNotEqual(
pickler.dumps(details1),
pickler.dumps(details3),
)
def test_non_serializable_custom_passes_causes_cache_miss(self):
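        """
        Changing _fuse_ddp_communication_passes to a different string causes a
        cache miss, while setting it to a non-serializable callable causes an
        explicit cache bypass (with a log message explaining why).
        """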
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(4, 4))
def forward(self, x):
return x @ self.param
mod1 = Mod()
mod_compiled = torch.compile(mod1)
with torch.no_grad():
x = torch.rand(4, 4)
# miss
mod_compiled(x)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
# hit
torch._dynamo.reset()
mod_compiled(x)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
torch._dynamo.reset()
counters.clear()
# hit
mod_compiled(x)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
with config.patch({"_fuse_ddp_communication_passes": ["new_pass_foo_bar"]}):
# miss (private config changed)
torch._dynamo.reset()
mod_compiled(x)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
torch._dynamo.reset()
counters.clear()
with (
capture_logs("torch._inductor.codecache", logging.INFO) as logs,
config.patch({"_fuse_ddp_communication_passes": [lambda *args: None]}),
):
# bypass (custom pass is not serializable)
mod_compiled(x)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
counters.clear()
# assert that our bypass is explicit
self.assertTrue(
any(
x.getMessage()
== "Bypassing FX Graph Cache because 'Unsupported _fuse_ddp_communication_pass'"
for x in logs
)
)
def test_hash_custom_passes(self):
"""
Test CustomGraphPass usage.
"""
class TestCustomGraphPass(CustomGraphPass):
def __init__(self):
self._uuid = None
def __call__(self, graph: torch.fx.graph.Graph) -> None:
return None
def uuid(self) -> Optional[Union[bytes, str]]:
return self._uuid
custom_pass = TestCustomGraphPass()
with config.patch({"post_grad_custom_pre_pass": custom_pass}):
custom_pass._uuid = "1"
details1 = FxGraphHashDetails(None, [], {}, [])
details2 = FxGraphHashDetails(None, [], {}, [])
custom_pass._uuid = "2"
details3 = FxGraphHashDetails(None, [], {}, [])
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)
self.assertEqual(
pickler.dumps(details1),
pickler.dumps(details2),
)
self.assertNotEqual(
pickler.dumps(details1),
pickler.dumps(details3),
)
def test_hash_custom_backend_pass(self):
"""
Test CustomGraphModulePass usage.
"""
class TestCustomGraphModulePass(CustomGraphModulePass):
def __init__(self):
self._uuid = None
def __call__(self, gm: torch.fx.GraphModule) -> None:
return None
def uuid(self) -> Optional[Union[bytes, str]]:
return self._uuid
custom_pass = TestCustomGraphModulePass()
with patch_inductor_backend("cpu", custom_pass=custom_pass):
custom_pass._uuid = "1"
details1 = FxGraphHashDetails(None, [], {}, [])
details2 = FxGraphHashDetails(None, [], {}, [])
custom_pass._uuid = "2"
details3 = FxGraphHashDetails(None, [], {}, [])
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)
self.assertEqual(
pickler.dumps(details1),
pickler.dumps(details2),
)
self.assertNotEqual(
pickler.dumps(details1),
pickler.dumps(details3),
)
def test_hash_custom_backend_config(self):
"""
        Test cache correctness when a custom inductor codegen config
        is installed.
"""
with patch_inductor_backend(
"cpu", custom_backend_config=custom_inductor_config
):
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)
details1 = FxGraphHashDetails(None, [], {}, [])
details2 = FxGraphHashDetails(None, [], {}, [])
self.assertEqual(pickler.dumps(details1), pickler.dumps(details2))
custom_inductor_config.enable_optimisation = True
details3 = FxGraphHashDetails(None, [], {}, [])
self.assertNotEqual(pickler.dumps(details2), pickler.dumps(details3))
torch._dynamo.reset()
counters.clear()
custom_inductor_config.enable_optimisation = False
x = torch.zeros(32)
y = torch.zeros(32)
compiled_fn = torch.compile(torch.add)
compiled_fn(x, y)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
torch._dynamo.reset()
counters.clear()
compiled_fn(x, y)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
torch._dynamo.reset()
counters.clear()
# Changing the custom config should trigger a recompilation
custom_inductor_config.enable_optimisation = True
compiled_fn(x, y)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
def test_hash_custom_partitioner_fn(self):
"""
Test that the custom partitioner function's UUID is properly used in the FX graph cache hashing.
"""
custom_partitioner_fn = TestCustomPartitionerFn()
with config.patch({"custom_partitioner_fn": custom_partitioner_fn}):
custom_partitioner_fn._uuid = "1"
details1 = FxGraphHashDetails(None, [], {}, [])
details2 = FxGraphHashDetails(None, [], {}, [])
custom_partitioner_fn._uuid = "2"
details3 = FxGraphHashDetails(None, [], {}, [])
self.assertEqual(details1._custom_partitioner_fn, "1")
self.assertEqual(details2._custom_partitioner_fn, "1")
self.assertEqual(details3._custom_partitioner_fn, "2")
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)
self.assertEqual(
pickler.dumps(details1),
pickler.dumps(details2),
)
self.assertNotEqual(
pickler.dumps(details1),
pickler.dumps(details3),
)
def test_bypass_unsupported(self):
"""
        Test that _reduce_unsupported raises BypassFxGraphCache when pickling
        an object we cannot cache (BackwardState).
"""
gm = torch.fx.GraphModule({}, torch.fx.Graph())
with self.assertRaises(BypassFxGraphCache):
FxGraphCachePickler(gm).dumps(
torch.fx.experimental._backward_state.BackwardState()
)
def test_stable_strings(self):
"""
Test that objects containing identical strings pickle the same
even if they are not the same id.
"""
s1 = "string"
s2 = "strin" # codespell:ignore
s2 += "g"
self.assertNotEqual(id(s1), id(s2))
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)
self.assertEqual(
pickler.dumps([s1, s1]),
pickler.dumps([s1, s2]),
)
def test_get_hash_for_files(self):
"""
Test the get_hash_for_files helper.
"""
# delete=True does not work on Windows.
# See https://docs.python.org/3.12/library/tempfile.html#tempfile.NamedTemporaryFile
with tempfile.NamedTemporaryFile(delete=False) as temp:
try:
temp.write(b"contents")
temp.flush()
hash1 = get_hash_for_files((temp.name,))
get_hash_for_files.cache_clear()
hash2 = get_hash_for_files((temp.name,))
temp.write(b" ")
temp.flush()
get_hash_for_files.cache_clear()
hash3 = get_hash_for_files((temp.name,))
self.assertEqual(hash1, hash2)
self.assertNotEqual(hash1, hash3)
finally:
temp.close()
os.unlink(temp.name)
class TestCudaCompileCommand(TestCase):
@requires_cuda_and_triton
def test_cuda_compile_command(self):
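        """
        Sanity-check the nvcc command line produced by cuda_compile_command,
        with and without extra args, and that CUDACodeCache.compile forwards
        it to subprocess.check_output.
        """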
cmd_no_extra_args: str = cuda_compile_command(
["abc.cu", "def.cu"], "output", "so"
)
assert "nvcc " in cmd_no_extra_args, cmd_no_extra_args
assert "abc.cu" in cmd_no_extra_args, cmd_no_extra_args
assert "def.cu" in cmd_no_extra_args, cmd_no_extra_args
assert "output" in cmd_no_extra_args, cmd_no_extra_args
cmd_extra_args: str = cuda_compile_command(
["abc.cu", "def.cu"], "output", "so", ["-Wwhatever", "-nothing"]
)
assert "nvcc " in cmd_extra_args, cmd_extra_args
assert " -Wwhatever" in cmd_extra_args, cmd_extra_args
assert " -nothing" in cmd_extra_args, cmd_extra_args
assert "abc.cu" in cmd_extra_args, cmd_extra_args
assert "def.cu" in cmd_extra_args, cmd_extra_args
assert "output " in cmd_extra_args, cmd_extra_args
with mock.patch("subprocess.check_output") as check_output_mock:
CUDACodeCache.compile("test123.cu", "so", ["-Wsomething"])
check_output_mock.assert_called()
cmd_parts: list[str] = check_output_mock.call_args[0][0]
assert cmd_parts[0].endswith("nvcc"), cmd_parts
assert "-Wsomething" in cmd_parts, cmd_parts
assert "-DNDEBUG" in cmd_parts, cmd_parts
@instantiate_parametrized_tests
class TestAutotuneCache(TestCase):
device_type = GPU_TYPE
def setUp(self):
super().setUp()
counters.clear()
PatchCaches.setUp()
def tearDown(self):
super().tearDown()
PatchCaches.tearDown()
def reset(self):
PyCodeCache.cache_clear(purge=True)
torch._dynamo.reset()
clear_caches()
@requires_cuda_and_triton
@unittest.skipIf(not SM80OrLater, "Requires SM80+")
@unittest.skipIf(
TEST_WITH_ROCM, "Requires static cuda launcher, which does not support ROCM"
)
@config.patch({"use_static_cuda_launcher": True})
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"autotune_local_cache": False})
@config.patch({"autotune_remote_cache": True})
@config.patch({"bundled_autotune_remote_cache": False})
@config.patch({"max_autotune": True})
@config.patch(
{"compile_threads": 1}
) # Worker processes do not register PatchCaches() properly
def test_autotune_cache_warm_start(self):
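        """
        Autotune results stored in the remote cache on the first compile should
        be reused on a warm start: reset dynamo but keep FxGraphCache and expect
        autotune cache hits alongside an fxgraph cache hit.
        """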
class Model(torch.nn.Module):
def forward(self, x, y, a, b):
return x + y, a + b
def f(x, y, a, b):
return Model()(x, y, a, b)
x = torch.randn(100, 100).cuda()
y = torch.randn(100, 100).cuda()
a = torch.randn(1000, 100).cuda()
b = torch.randn(1000, 100).cuda()
f_compiled = torch.compile(f, fullgraph=True)
with PatchCaches():
a1 = f_compiled(x, y, a, b)
self.assertEqual(global_stats.autotune_remote, Stats(2, 0, 2))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
            # Don't reset FxGraphCache; check that the compiled graph loads from it again
torch._dynamo.reset()
a2 = f_compiled(x, y, a, b)
self.assertEqual(a1, a2)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(global_stats.autotune_remote, Stats(2, 2, 2))
# Check that the cache entries seem reasonable
for k in global_stats.autotune_remote.cache.keys():
self.assertRegex(k, r"[0-9a-z]{52}")
for k in global_stats.triton.cache.keys():
self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+")
@requires_cuda_and_triton
@unittest.skipIf(not SM80OrLater, "Requires SM80+")
@config.patch({"fx_graph_cache": False})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"autotune_local_cache": False})
@config.patch({"autotune_remote_cache": True})
@config.patch({"bundled_autotune_remote_cache": False})
@config.patch({"max_autotune": True})
@config.patch(
{"compile_threads": 1}
) # Worker processes do not register PatchCaches() properly
def test_autotune_cache(self):
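        """
        Autotune results go to the remote autotune cache on the first compile
        and are hit again after a full reset of the local compile caches.
        """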
class Model(torch.nn.Module):
def forward(self, x, y, a, b):
return x + y, a + b
def f(x, y, a, b):
return Model()(x, y, a, b)
x = torch.randn(100, 100).cuda()
y = torch.randn(100, 100).cuda()
a = torch.randn(1000, 100).cuda()
b = torch.randn(1000, 100).cuda()
f_compiled = torch.compile(f, fullgraph=True)
with PatchCaches():
f_compiled(x, y, a, b)
self.assertEqual(global_stats.autotune_remote, Stats(2, 0, 2))
self.reset()
f_compiled(x, y, a, b)
self.assertEqual(global_stats.autotune_remote, Stats(2, 2, 2))
# Check that the cache entries seem reasonable
for k in global_stats.autotune_remote.cache.keys():
self.assertRegex(k, r"[0-9a-z]{52}")
for k in global_stats.triton.cache.keys():
self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+")
@requires_cuda_and_triton
@unittest.skipIf(not SM80OrLater, "Requires SM80+")
@config.patch({"fx_graph_cache": False})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"autotune_local_cache": True})
@config.patch({"autotune_remote_cache": False})
@config.patch({"bundled_autotune_remote_cache": True})
@config.patch({"compile_threads": 1})
@config.patch({"max_autotune": True})
def test_bundled_autotune_remote_cache(self):
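        """
        With bundled_autotune_remote_cache enabled, the three local autotune
        results are bundled into a single bundled-autotune entry; changing
        cache_key_tag starts a fresh set of cache keys.
        """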
class Model(torch.nn.Module):
def forward(self, a, b, c, d, e, f):
return a + b, c + d, e + f
def f(a, b, c, d, e, f):
return Model()(a, b, c, d, e, f)
f_compiled = torch.compile(f, fullgraph=True)
a = torch.randn(101, 100).cuda()
b = torch.randn(101, 100).cuda()
c = torch.randn(102, 100).cuda()
d = torch.randn(102, 100).cuda()
e = torch.randn(103, 100).cuda()
f = torch.randn(103, 100).cuda()
with PatchCaches():
f_compiled(a, b, c, d, e, f)
self.assertEqual(global_stats.autotune_local, Stats(3, 0, 3))
self.assertEqual(global_stats.bundled_autotune, Stats(1, 0, 1))
self.reset()
f_compiled(a, b, c, d, e, f)
self.assertEqual(global_stats.autotune_local, Stats(6, 3, 3))
self.assertEqual(global_stats.bundled_autotune, Stats(1, 1, 1))
with torch.compiler.config.patch({"cache_key_tag": "test"}):
global_stats.reset()
self.reset()
f_compiled(a, b, c, d, e, f)
self.assertEqual(global_stats.autotune_local, Stats(3, 0, 3))
self.assertEqual(global_stats.bundled_autotune, Stats(1, 0, 1))
self.reset()
f_compiled(a, b, c, d, e, f)
self.assertEqual(global_stats.autotune_local, Stats(6, 3, 3))
self.assertEqual(global_stats.bundled_autotune, Stats(1, 1, 1))
# Check that the cache entries seem reasonable
for k in global_stats.autotune_local.cache.keys():
self.assertRegex(k, r"tmp[^/]*/([^/]{2})/[^/]{64}\.best_config")
for k in global_stats.bundled_autotune.cache.keys():
self.assertRegex(k, r"pt2:bundled-autotune-v1::[0-9a-z]{64}:c[0-9]+")
for k in global_stats.triton.cache.keys():
self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+")
@requires_triton()
@requires_cuda_and_triton
@unittest.skipIf(not SM80OrLater, "Requires SM80+")
@config.patch({"fx_graph_cache": False})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"bundled_autotune_remote_cache": False})
@config.patch({"max_autotune": True})
@config.patch(
{"compile_threads": 1}
) # Worker processes do not register PatchCaches() properly
@parametrize("remote_cache", (True, False))
def test_modified_autotune_cache(self, remote_cache):
"""
        If a developer changes the way the autotune cache is handled,
        it may break the cache. This happened with #150122. This test
        ensures that if the torch source changes, old cache entries
        are invalidated.
"""
def mock_torch_key(value: str) -> bytes:
return value.encode("utf-8")
def get_autotune_stats():
if remote_cache:
return global_stats.autotune_remote
return global_stats.autotune_local
def fn(x, y):
return (x + y).relu()
x = torch.randn(100, 100).to(GPU_TYPE)
y = torch.randn(100, 100).to(GPU_TYPE)
with config.patch(
{
"autotune_local_cache": not remote_cache,
"autotune_remote_cache": remote_cache,
}
):
with PatchCaches():
with mock.patch(
"torch._inductor.codecache.torch_key",
functools.partial(mock_torch_key, "torchkey1"),
):
f_compiled = torch.compile(fn, fullgraph=True)
res1 = f_compiled(x, y)
self.assertEqual(get_autotune_stats(), Stats(1, 0, 1))
torch._dynamo.reset()
PyCodeCache.cache_clear()
with mock.patch(
"torch._inductor.codecache.torch_key",
functools.partial(mock_torch_key, "torchkey2"),
):
f_compiled = torch.compile(fn, fullgraph=True)
res2 = f_compiled(x, y)
self.assertEqual(get_autotune_stats(), Stats(2, 0, 2))
self.assertEqual(res1, res2)
class TestRemoteAOTAutogradCache(TestCase):
@requires_cuda_and_triton
@unittest.skipIf(not SM80OrLater, "Requires SM80+")
@config.patch({"fx_graph_cache": False})
@config.patch({"fx_graph_remote_cache": True})
@torch._functorch.config.patch({"enable_autograd_cache": False})
@torch._functorch.config.patch({"enable_remote_autograd_cache": True})
def test_autograd_remote_cache(self):
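        """
        With the remote AOTAutograd and FX graph caches enabled, a recompile
        after a dynamo reset hits both caches; changing cache_key_tag produces
        new cache keys and misses.
        """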
def f(a, b):
return a + b
f_compiled = torch.compile(f)
a = torch.randn(101, 100, device="cuda", requires_grad=False)
b = torch.randn(101, 100, device="cuda", requires_grad=False)
with PatchCaches():
f_compiled(a, b)
self.assertEqual(global_stats.aot_autograd, Stats(1, 0, 1))
self.assertEqual(global_stats.fx_graph, Stats(1, 0, 1))
torch._dynamo.reset()
f_compiled(a, b)
self.assertEqual(global_stats.aot_autograd, Stats(1, 1, 1))
self.assertEqual(global_stats.fx_graph, Stats(1, 1, 1))
torch._dynamo.reset()
with torch.compiler.config.patch({"cache_key_tag": "test"}):
f_compiled(a, b)
self.assertEqual(global_stats.aot_autograd, Stats(2, 1, 2))
self.assertEqual(global_stats.fx_graph, Stats(2, 1, 2))
# Check that the cache entries seem reasonable
for k in global_stats.aot_autograd.cache.keys():
self.assertRegex(k, r"pt2:autograd-experimental::[0-9a-z]{52}:c[0-9]+")
for k in global_stats.fx_graph.cache.keys():
self.assertRegex(k, r"pt2:fx-graph-v1::[0-9a-z]{52}:c[0-9]+")
@requires_cuda_and_triton
@unittest.skipIf(not SM80OrLater, "Requires SM80+")
@config.patch({"fx_graph_cache": False})
@config.patch({"fx_graph_remote_cache": True})
@torch._functorch.config.patch({"enable_autograd_cache": False})
@torch._functorch.config.patch({"enable_remote_autograd_cache": True})
def test_autograd_remote_lazy_backward(self):
"""
Lazily compile the backward, and lazily save to cache
"""
def fn(a, b):
return a.cos() + b
with PatchCaches():
a = torch.randn(25, requires_grad=True)
b = torch.randn(25, requires_grad=True)
a2 = a.detach().clone().requires_grad_(True)
b2 = b.detach().clone().requires_grad_(True)
compiled_fn = torch.compile(fn, backend="inductor")
self.assertEqual(fn(a, b), compiled_fn(a2, b2))
self.assertEqual(global_stats.aot_autograd, Stats(0, 0, 1))
# Clear dynamo and run again. Should be a cache miss still, because backward hasn't run
torch._dynamo.reset()
self.assertEqual(fn(a, b), compiled_fn(a2, b2))
self.assertEqual(global_stats.aot_autograd, Stats(0, 0, 2))
# Now let's run the backward
fn(a, b).sum().backward()
compiled_fn(a2, b2).sum().backward()
self.assertEqual(a.grad, a2.grad)
self.assertEqual(b.grad, b2.grad)
self.assertEqual(global_stats.aot_autograd, Stats(1, 0, 2))
# Clear dynamo and rerun everything, now there should be a cache hit
torch._dynamo.reset()
a = torch.randn(25, requires_grad=True)
b = torch.randn(25, requires_grad=True)
a2 = a.detach().clone().requires_grad_(True)
b2 = b.detach().clone().requires_grad_(True)
self.assertEqual(fn(a, b), compiled_fn(a2, b2))
self.assertEqual(global_stats.aot_autograd, Stats(1, 1, 2))
fn(a, b).sum().backward()
compiled_fn(a2, b2).sum().backward()
self.assertEqual(a.grad, a2.grad)
self.assertEqual(b.grad, b2.grad)
class TestUtils(TestCase):
@config.patch({"fx_graph_remote_cache": False})
def test_fresh_cache(self):
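        """
        fresh_cache() should give each compilation an empty PyCodeCache and a
        distinct cache directory while still producing identical results.
        """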
def fn(x, y):
return x + y
a = torch.rand(10)
b = torch.rand(10)
with fresh_cache():
self.assertEqual(len(PyCodeCache.modules), 0)
res1 = torch.compile(fn)(a, b)
cache_dir1 = cache_dir()
torch._dynamo.reset()
with fresh_cache():
self.assertEqual(len(PyCodeCache.modules), 0)
res2 = torch.compile(fn)(a, b)
cache_dir2 = cache_dir()
self.assertEqual(res1, res2)
self.assertNotEqual(cache_dir1, cache_dir2)
# This combination of settings exposed a bug where we cleared the
# PyCodeCache disk artifacts while they were still needed:
@requires_cuda_and_triton
@config.patch(
{
"coordinate_descent_tuning": True,
"force_disable_caches": True,
}
)
def test_force_disable_coordinate_descent(self):
def fn():
inp = torch.randn(32, 50, 768, device="cuda")
weight = torch.randn(768, 768, device="cuda")
layer = torch.nn.LayerNorm(768, device="cuda")
return layer(inp @ weight)
torch.compile(fn)()
if __name__ == "__main__":
run_tests()