Revert "[Inductor][CuTeDSL] Move load_template up two directories (#165347) (#165576)"

This reverts commit febb60323018948b2b9d2cff35b3cc4e0d0c55c8.

Reverted https://github.com/pytorch/pytorch/pull/165576 on behalf of https://github.com/seemethere due to the following: this was actually reverted internally, and the current PR is linked to a stale diff, so the diff train tools think it landed via co-dev when it was actually reverted ([comment](https://github.com/pytorch/pytorch/pull/165576#issuecomment-3417510146))
This commit is contained in:
PyTorch MergeBot
2025-10-17 23:33:17 +00:00
parent 1b397420f2
commit 69c33898fa
5 changed files with 19 additions and 27 deletions

View File

@@ -3,7 +3,6 @@
 import math
 from collections.abc import Sequence
-from functools import partial
 from pathlib import Path
 from typing import Any, Optional, Union
@@ -37,7 +36,6 @@ from ...lowering import (
     to_dtype,
 )
 from ...select_algorithm import realize_inputs
-from ...utils import load_template
 SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]]
@@ -339,7 +337,13 @@ def next_power_of_two(n):
     return 2 ** math.ceil(math.log2(n))
-_FLEX_TEMPLATE_DIR = Path(__file__).parent / "templates"
-load_flex_template = partial(load_template, template_dir=_FLEX_TEMPLATE_DIR)
+_TEMPLATE_DIR = Path(__file__).parent / "templates"
+
+
+def load_template(name: str) -> str:
+    """Load a template file and return its content."""
+    with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
+        return f.read()
 # Template strings have been moved to templates/common.py.jinja
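
Note on the two deleted lines: functools.partial pre-binds the template_dir keyword of the shared loader so call sites pass only a template name, which is exactly what the restored module-local load_template does directly. A minimal, self-contained sketch of that pattern (the "templates" path below is a placeholder, not the real directory layout):

    from functools import partial
    from pathlib import Path


    def load_template(name: str, template_dir: Path) -> str:
        """Read <template_dir>/<name>.py.jinja and return its text."""
        with open(template_dir / f"{name}.py.jinja") as f:
            return f.read()


    # partial() fixes template_dir, yielding a one-argument callable that
    # behaves like the local load_template(name) restored above.
    load_flex_template = partial(load_template, template_dir=Path("templates"))
    # load_flex_template("common") -> contents of templates/common.py.jinja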

View File

@@ -29,7 +29,7 @@ from .common import (
     freeze_irnodes,
     get_fwd_subgraph_outputs,
     infer_dense_strides,
-    load_flex_template,
+    load_template,
     maybe_realize,
     set_head_dim_values,
     SubgraphResults,
@@ -79,9 +79,9 @@ def get_float32_precision():
 flex_attention_template = TritonTemplate(
     name="flex_attention",
     grid=flex_attention_grid,
-    source=load_flex_template("flex_attention")
-    + load_flex_template("utilities")
-    + load_flex_template("common"),
+    source=load_template("flex_attention")
+    + load_template("utilities")
+    + load_template("common"),
 )
@@ -464,7 +464,7 @@ def flex_attention_backward_grid(
 flex_attention_backward_template = TritonTemplate(
     name="flex_attention_backward",
     grid=flex_attention_backward_grid,
-    source=load_flex_template("flex_backwards") + load_flex_template("utilities"),
+    source=load_template("flex_backwards") + load_template("utilities"),
 )
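
In both registrations above, the source argument is just the concatenation of several jinja template files into one string. A hedged sketch of how that source is assembled, reusing the one-argument loader restored in common.py (build_flex_attention_source and the "templates" path are illustrative, not part of the codebase):

    from pathlib import Path

    _TEMPLATE_DIR = Path("templates")  # placeholder for the flex templates directory


    def load_template(name: str) -> str:
        """Load a template file and return its content."""
        with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
            return f.read()


    def build_flex_attention_source() -> str:
        # Kernel body plus shared helper sections, concatenated into the single
        # jinja source string that is handed to TritonTemplate(source=...).
        return (
            load_template("flex_attention")
            + load_template("utilities")
            + load_template("common")
        )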

View File

@@ -22,7 +22,7 @@ from .common import (
     create_num_blocks_fake_generator,
     freeze_irnodes,
     get_fwd_subgraph_outputs,
-    load_flex_template,
+    load_template,
     maybe_realize,
     set_head_dim_values,
 )
@@ -97,9 +97,9 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
 flex_decoding_template = TritonTemplate(
     name="flex_decoding",
     grid=flex_decoding_grid,
-    source=load_flex_template("flex_decode")
-    + load_flex_template("utilities")
-    + load_flex_template("common"),
+    source=load_template("flex_decode")
+    + load_template("utilities")
+    + load_template("common"),
 )

View File

@@ -12,7 +12,7 @@ from torch.fx import GraphModule
 from ...ir import FixedLayout, ShapeAsConstantBuffer, Subgraph, TensorBox
 from ...lowering import empty_strided
-from .common import infer_dense_strides, load_flex_template, SubgraphResults
+from .common import infer_dense_strides, load_template, SubgraphResults
 aten = torch.ops.aten
@@ -36,8 +36,7 @@ from ...codegen.cutedsl.cutedsl_template import CuteDSLTemplate
 flash_attention_cutedsl_template = CuteDSLTemplate(
-    name="flash_attention_cutedsl",
-    source=load_flex_template("flash_attention"),
+    name="flash_attention_cutedsl", source=load_template("flash_attention")
 )

View File

@@ -67,10 +67,6 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._pytree import tree_flatten, tree_map_only
-if TYPE_CHECKING:
-    from pathlib import Path
 OPTIMUS_EXCLUDE_POST_GRAD = [
     "activation_quantization_aten_pass",
     "inductor_autotune_lookup_table",
@@ -3890,10 +3886,3 @@ def is_nonfreeable_buffers(dep: Dep) -> bool:
     return dep_name.startswith(
         ("primals_", "arg", "fwd_rng_state", "bwd_rng_state", "tangents")
     )
-
-
-# Make sure to also include your jinja templates within torch_package_data in setup.py, or this function won't be able to find them
-def load_template(name: str, template_dir: Path) -> str:
-    """Load a template file and return its content."""
-    with open(template_dir / f"{name}.py.jinja") as f:
-        return f.read()
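
The comment deleted here is about packaging: the .py.jinja templates are data files, so they are only shipped in a built wheel if they are listed under torch_package_data in setup.py. A hedged illustration of the kind of entry that comment refers to; the exact key and glob in PyTorch's setup.py may differ:

    # setup.py (illustrative excerpt, not PyTorch's actual configuration)
    # torch_package_data feeds setuptools' package_data={"torch": torch_package_data}
    torch_package_data = [
        # hypothetical glob covering the flex-attention jinja templates
        "_inductor/kernel/flex/templates/*.py.jinja",
    ]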