Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[Inductor][CuTeDSL] Move load_template up two directories (#165347)
Summary: Moves the function used to load CuTeDSL Jinja templates up out of the flex attention folder so that it can be reused for more general Inductor templates in the future.

Test Plan:
`INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:flex_flash -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8`

Differential Revision: D84527470

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165347
Approved by: https://github.com/drisspg
Committed by: PyTorch MergeBot
Parent: ffe3cb226a
Commit: 815d641599
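For context, here is a minimal sketch of the pattern this change introduces, reconstructed from the summary and the diff below; module paths in the comments follow the diff's relative imports and are otherwise assumed.

# Sketch of the refactor (assumed layout): a generic, directory-agnostic loader in the
# shared Inductor utils module, with each template folder binding its own directory
# via functools.partial.
from functools import partial
from pathlib import Path


def load_template(name: str, template_dir: Path) -> str:
    """Load a Jinja template file from the given directory and return its content."""
    with open(template_dir / f"{name}.py.jinja") as f:
        return f.read()


# In the flex attention module: bind the folder once, then call
# load_flex_template("flex_attention"), load_flex_template("utilities"), etc.
_FLEX_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_flex_template = partial(load_template, template_dir=_FLEX_TEMPLATE_DIR)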
@@ -3,6 +3,7 @@
 
 import math
 from collections.abc import Sequence
+from functools import partial
 from pathlib import Path
 from typing import Any, Optional, Union
 
@@ -36,6 +37,7 @@ from ...lowering import (
     to_dtype,
 )
 from ...select_algorithm import realize_inputs
+from ...utils import load_template
 
 
 SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]]
@@ -337,13 +339,7 @@ def next_power_of_two(n):
     return 2 ** math.ceil(math.log2(n))
 
 
-_TEMPLATE_DIR = Path(__file__).parent / "templates"
-
-
-def load_template(name: str) -> str:
-    """Load a template file and return its content."""
-    with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
-        return f.read()
-
+_FLEX_TEMPLATE_DIR = Path(__file__).parent / "templates"
+load_flex_template = partial(load_template, template_dir=_FLEX_TEMPLATE_DIR)
 
 # Template strings have been moved to templates/common.py.jinja
@@ -29,7 +29,7 @@ from .common import (
     freeze_irnodes,
     get_fwd_subgraph_outputs,
     infer_dense_strides,
-    load_template,
+    load_flex_template,
     maybe_realize,
     set_head_dim_values,
     SubgraphResults,
@@ -79,9 +79,9 @@ def get_float32_precision():
 flex_attention_template = TritonTemplate(
     name="flex_attention",
     grid=flex_attention_grid,
-    source=load_template("flex_attention")
-    + load_template("utilities")
-    + load_template("common"),
+    source=load_flex_template("flex_attention")
+    + load_flex_template("utilities")
+    + load_flex_template("common"),
 )
 
 
@@ -464,7 +464,7 @@ def flex_attention_backward_grid(
 flex_attention_backward_template = TritonTemplate(
     name="flex_attention_backward",
     grid=flex_attention_backward_grid,
-    source=load_template("flex_backwards") + load_template("utilities"),
+    source=load_flex_template("flex_backwards") + load_flex_template("utilities"),
 )
 
 
@@ -22,7 +22,7 @@ from .common import (
     create_num_blocks_fake_generator,
     freeze_irnodes,
     get_fwd_subgraph_outputs,
-    load_template,
+    load_flex_template,
     maybe_realize,
     set_head_dim_values,
 )
@@ -97,9 +97,9 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
 flex_decoding_template = TritonTemplate(
     name="flex_decoding",
     grid=flex_decoding_grid,
-    source=load_template("flex_decode")
-    + load_template("utilities")
-    + load_template("common"),
+    source=load_flex_template("flex_decode")
+    + load_flex_template("utilities")
+    + load_flex_template("common"),
 )
 
 
@@ -12,7 +12,7 @@ from torch.fx import GraphModule
 
 from ...ir import FixedLayout, ShapeAsConstantBuffer, Subgraph, TensorBox
 from ...lowering import empty_strided
-from .common import infer_dense_strides, load_template, SubgraphResults
+from .common import infer_dense_strides, load_flex_template, SubgraphResults
 
 
 aten = torch.ops.aten
@@ -36,7 +36,8 @@ from ...codegen.cutedsl.cutedsl_template import CuteDSLTemplate
 
 
 flash_attention_cutedsl_template = CuteDSLTemplate(
-    name="flash_attention_cutedsl", source=load_template("flash_attention")
+    name="flash_attention_cutedsl",
+    source=load_flex_template("flash_attention"),
 )
 
 
@@ -67,6 +67,10 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._pytree import tree_flatten, tree_map_only
 
 
+if TYPE_CHECKING:
+    from pathlib import Path
+
+
 OPTIMUS_EXCLUDE_POST_GRAD = [
     "activation_quantization_aten_pass",
     "inductor_autotune_lookup_table",
@@ -3886,3 +3890,10 @@ def is_nonfreeable_buffers(dep: Dep) -> bool:
     return dep_name.startswith(
         ("primals_", "arg", "fwd_rng_state", "bwd_rng_state", "tangents")
     )
+
+
+# Make sure to also include your jinja templates within torch_package_data in setup.py, or this function won't be able to find them
+def load_template(name: str, template_dir: Path) -> str:
+    """Load a template file and return its content."""
+    with open(template_dir / f"{name}.py.jinja") as f:
+        return f.read()
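Usage note (not part of the diff): because the loader now takes template_dir explicitly, a future Inductor template folder can reuse it by binding its own directory with functools.partial, mirroring what the flex code does above; per the in-code comment in the last hunk, the template files themselves must also be listed under torch_package_data in setup.py so they ship with the package. A hedged sketch with hypothetical names follows.

# Hypothetical future caller; _MY_TEMPLATE_DIR, load_my_template, and "my_kernel"
# are placeholder names, not real PyTorch identifiers.
from functools import partial
from pathlib import Path

# Module path inferred from the diff's relative import (from ...utils import load_template).
from torch._inductor.utils import load_template

_MY_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_my_template = partial(load_template, template_dir=_MY_TEMPLATE_DIR)

# e.g. source=load_my_template("my_kernel") when constructing a new Inductor template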