Compare commits

...

1 Commit

Author SHA1 Message Date
f9da58f230 wip on reset_to_zero support 2024-11-18 14:14:21 -08:00
5 changed files with 111 additions and 55 deletions

View File

@@ -1407,46 +1407,6 @@ def forward(self, x_1, output_1):
        self.assertEqual(compiled_out, eager_out)

    @requires_gpu
    def test_triton_kernel_reset_to_zero(self):
        @triton.autotune(
            configs=[
                triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
                triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
            ],
            key=["n_elements"],
            reset_to_zero=["out_ptr"],
        )
        @triton.jit
        def add_kernel_autotuned_reset(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        @torch.compile(fullgraph=True)
        def f(x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_autotuned_reset[grid](x, y, output, n_elements)
            return output

        x = torch.randn(4, device=GPU_TYPE)
        msg = "Only configs, keys, and restore_value are supported for triton.autotune"
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            f(x, x)

    @requires_gpu
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
@@ -2010,6 +1970,67 @@ def forward(self, arg0_1, arg1_1):
        # make sure x was restored after autotuning
        torch.testing.assert_close(x, prev + 1)

    # we have to enable autotune at compile time for this test
    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @common_utils.parametrize("autotune_at_compile_time", [True])
    def test_triton_kernel_reset_to_zero(self, backend, autotune_at_compile_time):
        if autotune_at_compile_time and backend != "inductor":
            raise unittest.SkipTest("compile-time autotuning only exists in inductor")

        @triton.autotune(
            configs=[
                triton.Config(
                    {"BLOCK_SIZE": 64, "COND": 1234}, num_stages=3, num_warps=8
                ),
                triton.Config(
                    {"BLOCK_SIZE": 32, "COND": 1234}, num_stages=3, num_warps=8
                ),
                triton.Config(
                    {"BLOCK_SIZE": 16, "COND": 1234}, num_stages=3, num_warps=8
                ),
            ],
            key=[],
            reset_to_zero=["counter"],
        )
        @triton.jit
        def increment_kernel(
            in_ptr0,
            counter,  # reset this to zero every time
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
            COND: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            in_ptr_vals = tl.load(in_ptr0 + offsets, mask=mask)
            count = tl.load(counter + offsets, mask=mask)
            # count should always be zero
            tl.store(in_ptr0 + offsets, in_ptr_vals + count, mask=mask)

        @torch.compile(fullgraph=True, backend=backend)
        def f(x, y):
            n_elements = x.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            increment_kernel[grid](x, y, n_elements=n_elements)
            return x

        x = torch.rand(4, device=GPU_TYPE)
        y = torch.clone(x)
        rand = torch.rand(4, device=GPU_TYPE)

        # during autotuning, x should not change in value
        with torch._inductor.config.patch(
            {"triton.autotune_at_compile_time": autotune_at_compile_time}
        ):
            # we will add rand a single time to x
            f(x, rand)

        self.assertEqual(y + rand, x)

    @requires_gpu
    @parametrize("dtype", (torch.float16, torch.float32, torch.float64))
    def test_triton_kernel_float64_constant(self, dtype):
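
Note (not part of the diff): the new test relies on Triton's reset_to_zero semantics, in which the autotuner zeroes the named tensor arguments before each benchmarked configuration so that accumulator-style buffers start from a clean state. A minimal sketch of that zeroing step in plain Python; the helper name and the example argument dict are illustrative only:

    import torch

    def zero_named_args(kwargs, reset_to_zero=("counter",)):
        # Zero every tensor argument whose name is listed in reset_to_zero
        # before a candidate config is timed.
        for name in reset_to_zero:
            arg = kwargs.get(name)
            if isinstance(arg, torch.Tensor):
                arg.zero_()

    # e.g. before timing one config of increment_kernel:
    # zero_named_args({"in_ptr0": x, "counter": rand, "n_elements": 4})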

View File

@@ -1077,14 +1077,6 @@ class TritonHOPifier:
                    and defaults["prune_configs_by"].default
                    != kernel.early_config_prune
                )
                # Set via reset_to_zero argument
                # https://github.com/triton-lang/triton/pull/5083
                # changes kernel.reset_idx to kernel.reset_to_zero
                or (hasattr(kernel, "reset_idx") and len(kernel.reset_idx) != 0)
                or (
                    hasattr(kernel, "reset_to_zero")
                    and len(kernel.reset_to_zero) != 0
                )
                or (
                    "use_cuda_graph" in defaults
                    and defaults["use_cuda_graph"].default != kernel.use_cuda_graph
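
Note (not part of the diff): the guard removed here is what previously made torch.compile reject reset_to_zero, and the reset_idx / reset_to_zero split it checks is the same Triton version difference that ir.py below normalizes. A small helper along these lines captures that normalization; the function name is hypothetical:

    from typing import List

    def resolve_reset_to_zero_names(kernel) -> List[str]:
        # Older Triton: Autotuner.reset_idx holds indices into fn.arg_names.
        if hasattr(kernel, "reset_idx"):
            return [kernel.fn.arg_names[i] for i in kernel.reset_idx]
        # Newer Triton (triton-lang/triton#5083): reset_to_zero holds names.
        if hasattr(kernel, "reset_to_zero"):
            return list(kernel.reset_to_zero)
        return []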

View File

@@ -1503,6 +1503,7 @@ class PythonWrapperCodegen(CodeGen):
        configs,
        kwargs,
        restore_value_args,
        reset_to_zero_args,
    ):
        from torch.utils._triton import patch_triton_dtype_repr

@@ -1590,6 +1591,9 @@ class PythonWrapperCodegen(CodeGen):
        if restore_value_args:
            triton_meta["restore_value"] = tuple(restore_value_args)

        if reset_to_zero_args:
            triton_meta["reset_to_zero"] = tuple(reset_to_zero_args)

        # Distinguish between different functions using function id
        cache_key: List[Any] = [id(kernel.fn)]
        if len(configs) > 0:
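
Note (not part of the diff): with this change, a user kernel autotuned with both restore_value and reset_to_zero ends up carrying both entries in the generated triton_meta, stored as tuples of argument names, roughly like the sketch below (all other fields omitted, values illustrative):

    triton_meta = {
        # ... signature, constants, device info, etc. ...
        "restore_value": ("some_arg",),
        "reset_to_zero": ("counter",),
    }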

View File

@@ -5483,6 +5483,7 @@ class UserDefinedTritonKernel(ExternKernel):
        kernel = kernel_side_table.get_kernel(self.kernel_idx)
        configs = []
        restore_value_args = []
        reset_to_zero_args = []
        if isinstance(kernel, Autotuner):
            # https://github.com/triton-lang/triton/pull/5083
            # changes kernel.restore_idx to kernel.restore_value

@@ -5492,16 +5493,28 @@ class UserDefinedTritonKernel(ExternKernel):
            else:
                assert hasattr(kernel, "restore_value")
                restore_value_args.extend(kernel.restore_value)

            if hasattr(kernel, "reset_idx"):
                for i in kernel.reset_idx:
                    reset_to_zero_args.append(kernel.fn.arg_names[i])
            elif hasattr(kernel, "reset_to_zero"):
                reset_to_zero_args.extend(kernel.reset_to_zero)

            configs = kernel.configs
            kernel = kernel.fn
        return kernel, configs, restore_value_args
        return kernel, configs, restore_value_args, reset_to_zero_args

    def codegen(self, wrapper) -> None:  # type: ignore[no-untyped-def]
        kernel, configs, restore_value_args = self.get_kernel_and_metadata()
        (
            kernel,
            configs,
            restore_value_args,
            reset_to_zero_args,
        ) = self.get_kernel_and_metadata()

        # Definition of kernel
        new_name, triton_meta = wrapper.define_user_defined_triton_kernel(
            kernel, configs, self.kwargs, restore_value_args
            kernel, configs, self.kwargs, restore_value_args, reset_to_zero_args
        )
        raw_args = [
            self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel

@@ -5608,7 +5621,7 @@ class UserDefinedTritonKernel(ExternKernel):
        self.kernel_idx = kernel_idx
        self.grid = grid

        kernel, configs, _ = self.get_kernel_and_metadata()
        kernel, configs, _, _ = self.get_kernel_and_metadata()

        # If we are autotuning, not all arguments will be passed
        self.ordered_kwargs_for_cpp_kernel = [

View File

@@ -202,6 +202,7 @@ class CachingAutotuner(KernelInterface):
        configs,
        save_cache_hook,
        mutated_arg_names: List[str],  # see [Note: clone mutated buffers]
        reset_to_zero_arg_names: List[str],
        optimize_mem,
        heuristic_type,
        size_hints=None,

@@ -226,6 +227,7 @@ class CachingAutotuner(KernelInterface):
        self.inductor_meta = {} if inductor_meta is None else inductor_meta
        self.save_cache_hook = save_cache_hook
        self.mutated_arg_names = mutated_arg_names
        self.reset_to_zero_arg_names = reset_to_zero_arg_names
        self.optimize_mem = optimize_mem
        self.configs = configs
        self.heuristic_type = heuristic_type

@@ -736,7 +738,7 @@ class CachingAutotuner(KernelInterface):
    def bench(self, launcher, *args, grid, with_profiler=False, **kwargs):
        """Measure the performance of a given launcher"""
        # we don't skip configs wiht spilled registers when auto-tuning custom
        # we don't skip configs with spilled registers when auto-tuning custom
        # (user-written) Triton kernels, as (i) we don't have any knowledge or
        # control over the kernel code; (ii) there is empirical evidence that
        # for some (complicated) custom Triton kernels, a register-spilling

@@ -760,6 +762,9 @@ class CachingAutotuner(KernelInterface):
            cloned_args, cloned_kwargs = self.maybe_clone_args(
                cpu_copies, *args, **kwargs
            )
            # reset to zero before evaluating any config
            self.reset_to_zero_args(*args, **kwargs)
            launcher(
                *cloned_args,
                **cloned_kwargs,
@@ -823,6 +828,17 @@ class CachingAutotuner(KernelInterface):
            arg, cpu_arg = pair
            arg.copy_(cpu_arg, non_blocking=True)

    def reset_to_zero_args(self, *args, **kwargs):
        for i, arg in enumerate(args):
            if self.fn.arg_names[i] in self.reset_to_zero_arg_names:
                assert isinstance(arg, torch.Tensor)
                arg.zero_()
        for name, arg in kwargs.items():
            if name in self.reset_to_zero_arg_names:
                assert isinstance(arg, torch.Tensor)
                arg.zero_()

    def maybe_clone_args(
        self, exclude: Container[str], *args, **kwargs
    ) -> Tuple[List[Any], Dict[str, Any]]:

@@ -875,6 +891,7 @@ class CachingAutotuner(KernelInterface):
                k.shared,
            )

        self.reset_to_zero_args(*args, **kwargs)
        return timings

    def autotune_to_one_config(self, *args, **kwargs):
@@ -904,7 +921,7 @@ class CachingAutotuner(KernelInterface):
                else launcher.bin.metadata["name"]
            ),
            "grid_x": grid_x,
            "grid_y": grid_y,
            "grid_z": grid_z,
            "x_block": launcher.config.kwargs.get("XBLOCK", 1),
            "y_block": launcher.config.kwargs.get("YBLOCK", None),
@@ -1234,12 +1251,19 @@ def cached_autotune(
        if disabled:
            log.debug("autotune caching is disabled by config.force_disable_caches")

    mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
    mutated_arg_names = inductor_meta.pop("mutated_arg_names", [])
    optimize_mem = inductor_meta.pop("optimize_mem", True)

    # Anything that will be reset or set to zero is mutated
    if "restore_value" in triton_meta:
        mutated_arg_names += triton_meta.pop("restore_value")

    reset_to_zero_arg_names: List[str] = []
    if "reset_to_zero" in triton_meta:
        reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero"))
        mutated_arg_names.extend(reset_to_zero_arg_names)

    def decorator(fn):
        # Remove XBLOCK from config if it's not a function argument.
        # This way, coordinate descent tuning will not try to tune it.
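
Note (not part of the diff): a concrete trace of the bookkeeping above, with made-up argument names, showing that reset-to-zero arguments are tracked separately and also counted as mutated:

    triton_meta = {"restore_value": ("out_ptr",), "reset_to_zero": ("counter",)}
    mutated_arg_names: list = []
    reset_to_zero_arg_names: list = []

    if "restore_value" in triton_meta:
        mutated_arg_names += triton_meta.pop("restore_value")
    if "reset_to_zero" in triton_meta:
        reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero"))
        mutated_arg_names.extend(reset_to_zero_arg_names)

    # mutated_arg_names == ["out_ptr", "counter"]
    # reset_to_zero_arg_names == ["counter"]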
@@ -1265,6 +1289,7 @@ def cached_autotune(
                configs=configs,
                save_cache_hook=autotune_cache and autotune_cache.save,
                mutated_arg_names=mutated_arg_names,
                reset_to_zero_arg_names=reset_to_zero_arg_names,
                optimize_mem=optimize_mem,
                heuristic_type=heuristic_type,
                size_hints=size_hints,

@@ -1279,6 +1304,7 @@ def cached_autotune(
                configs=configs,
                save_cache_hook=autotune_cache and autotune_cache.save,
                mutated_arg_names=mutated_arg_names,
                reset_to_zero_arg_names=reset_to_zero_arg_names,
                optimize_mem=optimize_mem,
                heuristic_type=heuristic_type,
                size_hints=size_hints,
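
Note (not part of the diff): at benchmarking time the pieces above combine roughly as follows. Because reset-to-zero arguments are also recorded as mutated, they are cloned like any other mutated buffer before a config is launched, and the real arguments named in reset_to_zero are zeroed before each config is evaluated and once more after all configs have been timed. A condensed, simplified sketch of the per-config step, not a verbatim excerpt of CachingAutotuner:

    import torch

    def bench_one_config_sketch(launcher, args_by_name, mutated, reset_to_zero):
        # 1) clone anything the kernel may mutate; reset-to-zero args are in `mutated`
        cloned = {
            k: v.clone() if k in mutated and isinstance(v, torch.Tensor) else v
            for k, v in args_by_name.items()
        }
        # 2) reset to zero before evaluating this config
        for k in reset_to_zero:
            args_by_name[k].zero_()
        # 3) time the candidate config on the cloned arguments
        launcher(**cloned)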