[Inductor][CuTeDSL] Move load_template up two directories

Summary:
This is a reland of D84527470

Moves the function used to load CuTeDSL Jinja templates up two directories, out of the flex attention folder. This way it can be used for more general Inductor templates in the future.
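For reference, the relocated helper takes the template directory as an explicit argument, so each template family can bind its own directory once via functools.partial (as the flex code does in the diff below). A minimal sketch of that pattern for a hypothetical future non-flex template; the my_kernel names and directory are illustrative, not part of this change:

    from functools import partial
    from pathlib import Path

    from torch._inductor.utils import load_template

    # Illustrative only: a future non-flex Inductor template family could bind
    # its own template directory the same way the flex kernels do.
    _MY_TEMPLATE_DIR = Path(__file__).parent / "templates"
    load_my_template = partial(load_template, template_dir=_MY_TEMPLATE_DIR)

    # Reads <this directory>/templates/my_kernel.py.jinja and returns its text.
    my_source = load_my_template("my_kernel")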

Test Plan: INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:cutedsl_grouped_mm -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8

Differential Revision: D85013024
Author: Nikhil Patel (2025-10-19 11:39:41 -07:00)
Committed by: Facebook GitHub Bot
Parent: fa0db212e7
Commit: ce7be304e9

5 changed files with 25 additions and 18 deletions


@@ -3,6 +3,7 @@
 import math
 from collections.abc import Sequence
+from functools import partial
 from pathlib import Path
 from typing import Any, Optional, Union
@@ -36,6 +37,7 @@ from ...lowering import (
     to_dtype,
 )
 from ...select_algorithm import realize_inputs
+from ...utils import load_template
 SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]]
@@ -337,13 +339,8 @@ def next_power_of_two(n):
     return 2 ** math.ceil(math.log2(n))
-_TEMPLATE_DIR = Path(__file__).parent / "templates"
-def load_template(name: str) -> str:
-    """Load a template file and return its content."""
-    with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
-        return f.read()
+_FLEX_TEMPLATE_DIR = Path(__file__).parent / "templates"
+load_flex_template = partial(load_template, template_dir=_FLEX_TEMPLATE_DIR)
 # Template strings have been moved to templates/common.py.jinja


@@ -29,7 +29,7 @@ from .common import (
     freeze_irnodes,
     get_fwd_subgraph_outputs,
     infer_dense_strides,
-    load_template,
+    load_flex_template,
     maybe_realize,
     set_head_dim_values,
     SubgraphResults,
@@ -79,9 +79,9 @@ def get_float32_precision():
 flex_attention_template = TritonTemplate(
     name="flex_attention",
     grid=flex_attention_grid,
-    source=load_template("flex_attention")
-    + load_template("utilities")
-    + load_template("common"),
+    source=load_flex_template("flex_attention")
+    + load_flex_template("utilities")
+    + load_flex_template("common"),
 )
@@ -464,7 +464,7 @@ def flex_attention_backward_grid(
 flex_attention_backward_template = TritonTemplate(
     name="flex_attention_backward",
     grid=flex_attention_backward_grid,
-    source=load_template("flex_backwards") + load_template("utilities"),
+    source=load_flex_template("flex_backwards") + load_flex_template("utilities"),
 )


@@ -22,7 +22,7 @@ from .common import (
     create_num_blocks_fake_generator,
     freeze_irnodes,
     get_fwd_subgraph_outputs,
-    load_template,
+    load_flex_template,
     maybe_realize,
     set_head_dim_values,
 )
@@ -97,9 +97,9 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
 flex_decoding_template = TritonTemplate(
     name="flex_decoding",
     grid=flex_decoding_grid,
-    source=load_template("flex_decode")
-    + load_template("utilities")
-    + load_template("common"),
+    source=load_flex_template("flex_decode")
+    + load_flex_template("utilities")
+    + load_flex_template("common"),
 )


@@ -12,7 +12,7 @@ from torch.fx import GraphModule
 from ...ir import FixedLayout, ShapeAsConstantBuffer, Subgraph, TensorBox
 from ...lowering import empty_strided
-from .common import infer_dense_strides, load_template, SubgraphResults
+from .common import infer_dense_strides, load_flex_template, SubgraphResults
 aten = torch.ops.aten
@@ -36,7 +36,7 @@ from ...codegen.cutedsl.cutedsl_template import CuteDSLTemplate
 flash_attention_cutedsl_template = CuteDSLTemplate(
-    name="flash_attention_cutedsl", source=load_template("flash_attention")
+    name="flash_attention_cutedsl", source=load_flex_template("flash_attention")
 )


@@ -67,6 +67,9 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._pytree import tree_flatten, tree_map_only
+if TYPE_CHECKING:
+    from pathlib import Path
 OPTIMUS_EXCLUDE_POST_GRAD = [
     "activation_quantization_aten_pass",
     "inductor_autotune_lookup_table",
@@ -3885,3 +3888,10 @@ def is_nonfreeable_buffers(dep: Dep) -> bool:
     return dep_name.startswith(
         ("primals_", "arg", "fwd_rng_state", "bwd_rng_state", "tangents")
     )
+# Make sure to also include your jinja templates within torch_package_data in setup.py, or this function won't be able to find them
+def load_template(name: str, template_dir: Path) -> str:
+    """Load a template file and return its content."""
+    with open(template_dir / f"{name}.py.jinja") as f:
+        return f.read()
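The comment above the new function is about packaging: templates are only found at runtime if the .py.jinja files ship with the wheel. A hedged sketch of the kind of torch_package_data glob in setup.py that would cover such templates; the exact path and the existing entries in setup.py may differ:

    # setup.py (illustrative sketch; the real torch_package_data list has many
    # more entries, and the exact glob for the flex templates may differ)
    torch_package_data = [
        # ...
        "_inductor/kernel/flex/templates/*.jinja",
    ]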