[Inductor][CuTeDSL] Move load_template up two directories (#165347)

Summary: Moves the function used to load CuTeDSL Jinja templates up two directories, out of the flex attention folder and into the shared Inductor utils module. This way it can be reused for more general Inductor templates in the future.
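The end state, condensed from the diff below: the generic loader takes an explicit template directory and lives in the shared Inductor utils module (`...utils` in the relative imports below), while the flex kernels bind their directory once with `functools.partial`. All names are taken from the change itself:

```python
from functools import partial
from pathlib import Path


# Generic helper, now in the shared Inductor utils module: resolve a Jinja
# template by name inside an explicit template directory and return its text.
def load_template(name: str, template_dir: Path) -> str:
    """Load a template file and return its content."""
    with open(template_dir / f"{name}.py.jinja") as f:
        return f.read()


# The flex attention kernels keep a thin, directory-bound alias, so call sites
# only change from load_template(...) to load_flex_template(...).
_FLEX_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_flex_template = partial(load_template, template_dir=_FLEX_TEMPLATE_DIR)
```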

Test Plan: `INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:flex_flash -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8`

Differential Revision: D84527470

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165347
Approved by: https://github.com/drisspg
Nikhil Patel
2025-10-15 16:34:58 +00:00
committed by PyTorch MergeBot
parent ffe3cb226a
commit 815d641599
5 changed files with 27 additions and 19 deletions

View File

@@ -3,6 +3,7 @@
 import math
 from collections.abc import Sequence
+from functools import partial
 from pathlib import Path
 from typing import Any, Optional, Union
@@ -36,6 +37,7 @@ from ...lowering import (
     to_dtype,
 )
 from ...select_algorithm import realize_inputs
+from ...utils import load_template
 
 
 SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]]
@@ -337,13 +339,7 @@ def next_power_of_two(n):
     return 2 ** math.ceil(math.log2(n))
 
 
-_TEMPLATE_DIR = Path(__file__).parent / "templates"
-
-
-def load_template(name: str) -> str:
-    """Load a template file and return its content."""
-    with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
-        return f.read()
+_FLEX_TEMPLATE_DIR = Path(__file__).parent / "templates"
+load_flex_template = partial(load_template, template_dir=_FLEX_TEMPLATE_DIR)
 
 
 # Template strings have been moved to templates/common.py.jinja

View File

@@ -29,7 +29,7 @@ from .common import (
     freeze_irnodes,
     get_fwd_subgraph_outputs,
     infer_dense_strides,
-    load_template,
+    load_flex_template,
     maybe_realize,
     set_head_dim_values,
     SubgraphResults,
@@ -79,9 +79,9 @@ def get_float32_precision():
 flex_attention_template = TritonTemplate(
     name="flex_attention",
     grid=flex_attention_grid,
-    source=load_template("flex_attention")
-    + load_template("utilities")
-    + load_template("common"),
+    source=load_flex_template("flex_attention")
+    + load_flex_template("utilities")
+    + load_flex_template("common"),
 )
@@ -464,7 +464,7 @@ def flex_attention_backward_grid(
 flex_attention_backward_template = TritonTemplate(
     name="flex_attention_backward",
     grid=flex_attention_backward_grid,
-    source=load_template("flex_backwards") + load_template("utilities"),
+    source=load_flex_template("flex_backwards") + load_flex_template("utilities"),
 )

View File

@@ -22,7 +22,7 @@ from .common import (
     create_num_blocks_fake_generator,
     freeze_irnodes,
     get_fwd_subgraph_outputs,
-    load_template,
+    load_flex_template,
     maybe_realize,
     set_head_dim_values,
 )
@@ -97,9 +97,9 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
 flex_decoding_template = TritonTemplate(
     name="flex_decoding",
     grid=flex_decoding_grid,
-    source=load_template("flex_decode")
-    + load_template("utilities")
-    + load_template("common"),
+    source=load_flex_template("flex_decode")
+    + load_flex_template("utilities")
+    + load_flex_template("common"),
 )

View File

@@ -12,7 +12,7 @@ from torch.fx import GraphModule
 from ...ir import FixedLayout, ShapeAsConstantBuffer, Subgraph, TensorBox
 from ...lowering import empty_strided
-from .common import infer_dense_strides, load_template, SubgraphResults
+from .common import infer_dense_strides, load_flex_template, SubgraphResults
 
 
 aten = torch.ops.aten
@@ -36,7 +36,8 @@ from ...codegen.cutedsl.cutedsl_template import CuteDSLTemplate
 flash_attention_cutedsl_template = CuteDSLTemplate(
-    name="flash_attention_cutedsl", source=load_template("flash_attention")
+    name="flash_attention_cutedsl",
+    source=load_flex_template("flash_attention"),
 )

View File

@@ -67,6 +67,10 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._pytree import tree_flatten, tree_map_only
 
 
+if TYPE_CHECKING:
+    from pathlib import Path
+
+
 OPTIMUS_EXCLUDE_POST_GRAD = [
     "activation_quantization_aten_pass",
     "inductor_autotune_lookup_table",
@@ -3886,3 +3890,10 @@ def is_nonfreeable_buffers(dep: Dep) -> bool:
     return dep_name.startswith(
         ("primals_", "arg", "fwd_rng_state", "bwd_rng_state", "tangents")
     )
+
+
+# Make sure to also include your jinja templates within torch_package_data in setup.py, or this function won't be able to find them
+def load_template(name: str, template_dir: Path) -> str:
+    """Load a template file and return its content."""
+    with open(template_dir / f"{name}.py.jinja") as f:
+        return f.read()
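Since the helper now takes the template directory as an argument, any other Inductor template family can bind its own directory the same way the flex kernels do. A minimal sketch, assuming the helper is importable as `torch._inductor.utils.load_template` (consistent with the `from ...utils import load_template` change above); the directory and template names are purely illustrative:

```python
from functools import partial
from pathlib import Path

from torch._inductor.utils import load_template

# Hypothetical template family: bind the loader to this module's templates/ dir.
_MY_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_my_template = partial(load_template, template_dir=_MY_TEMPLATE_DIR)

# load_my_template("my_kernel") would return the contents of
# templates/my_kernel.py.jinja as a string, ready to hand to a template class.
```

Per the comment in the diff above, any such templates also need to be listed under torch_package_data in setup.py so they ship with the package.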