mirror of https://github.com/pytorch/pytorch.git
[Inductor][CuTeDSL] Move load_template up two directories
Summary: This is a reland of D84527470. It moves the function used to load CuTeDSL Jinja templates up out of the flex attention folder so that it can be used for more general Inductor templates in the future.

Test Plan: INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:cutedsl_grouped_mm -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8

Differential Revision: D85013024
committed by Facebook GitHub Bot
parent fa0db212e7
commit ce7be304e9
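For context on the shape of the change, the diff below replaces the flex-attention-local load_template with a shared, directory-parameterized loader and has the flex code rebind it via functools.partial. A minimal sketch of the resulting pattern, assembled from the hunks below (the file placement and surrounding code are simplified for illustration, not the actual PyTorch module layout):

from functools import partial
from pathlib import Path


# Shared helper (this diff adds it to the generic Inductor utils module).
# Per the comment in the diff, the .py.jinja files must also be listed under
# torch_package_data in setup.py, or packaged builds will not find them.
def load_template(name: str, template_dir: Path) -> str:
    """Load a Jinja template file and return its content."""
    with open(template_dir / f"{name}.py.jinja") as f:
        return f.read()


# Each template family binds its own template directory once; flex attention does:
_FLEX_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_flex_template = partial(load_template, template_dir=_FLEX_TEMPLATE_DIR)

# Call sites then only name the template, e.g.:
#   source=load_flex_template("flex_attention") + load_flex_template("common")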
@@ -3,6 +3,7 @@
 import math
 from collections.abc import Sequence
+from functools import partial
 from pathlib import Path
 from typing import Any, Optional, Union
@@ -36,6 +37,7 @@ from ...lowering import (
     to_dtype,
 )
 from ...select_algorithm import realize_inputs
+from ...utils import load_template


 SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]]
@@ -337,13 +339,8 @@ def next_power_of_two(n):
     return 2 ** math.ceil(math.log2(n))


-_TEMPLATE_DIR = Path(__file__).parent / "templates"
-
-
-def load_template(name: str) -> str:
-    """Load a template file and return its content."""
-    with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
-        return f.read()
+_FLEX_TEMPLATE_DIR = Path(__file__).parent / "templates"
+load_flex_template = partial(load_template, template_dir=_FLEX_TEMPLATE_DIR)


 # Template strings have been moved to templates/common.py.jinja
@@ -29,7 +29,7 @@ from .common import (
     freeze_irnodes,
     get_fwd_subgraph_outputs,
     infer_dense_strides,
-    load_template,
+    load_flex_template,
     maybe_realize,
     set_head_dim_values,
     SubgraphResults,
@@ -79,9 +79,9 @@ def get_float32_precision():
 flex_attention_template = TritonTemplate(
     name="flex_attention",
     grid=flex_attention_grid,
-    source=load_template("flex_attention")
-    + load_template("utilities")
-    + load_template("common"),
+    source=load_flex_template("flex_attention")
+    + load_flex_template("utilities")
+    + load_flex_template("common"),
 )
@@ -464,7 +464,7 @@ def flex_attention_backward_grid(
 flex_attention_backward_template = TritonTemplate(
     name="flex_attention_backward",
     grid=flex_attention_backward_grid,
-    source=load_template("flex_backwards") + load_template("utilities"),
+    source=load_flex_template("flex_backwards") + load_flex_template("utilities"),
 )
@@ -22,7 +22,7 @@ from .common import (
     create_num_blocks_fake_generator,
     freeze_irnodes,
     get_fwd_subgraph_outputs,
-    load_template,
+    load_flex_template,
     maybe_realize,
     set_head_dim_values,
 )
@@ -97,9 +97,9 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
 flex_decoding_template = TritonTemplate(
     name="flex_decoding",
     grid=flex_decoding_grid,
-    source=load_template("flex_decode")
-    + load_template("utilities")
-    + load_template("common"),
+    source=load_flex_template("flex_decode")
+    + load_flex_template("utilities")
+    + load_flex_template("common"),
 )
@@ -12,7 +12,7 @@ from torch.fx import GraphModule

 from ...ir import FixedLayout, ShapeAsConstantBuffer, Subgraph, TensorBox
 from ...lowering import empty_strided
-from .common import infer_dense_strides, load_template, SubgraphResults
+from .common import infer_dense_strides, load_flex_template, SubgraphResults


 aten = torch.ops.aten
@@ -36,7 +36,7 @@ from ...codegen.cutedsl.cutedsl_template import CuteDSLTemplate


 flash_attention_cutedsl_template = CuteDSLTemplate(
-    name="flash_attention_cutedsl", source=load_template("flash_attention")
+    name="flash_attention_cutedsl", source=load_flex_template("flash_attention")
 )
@@ -67,6 +67,9 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._pytree import tree_flatten, tree_map_only


+if TYPE_CHECKING:
+    from pathlib import Path
+
 OPTIMUS_EXCLUDE_POST_GRAD = [
     "activation_quantization_aten_pass",
     "inductor_autotune_lookup_table",
@@ -3885,3 +3888,10 @@ def is_nonfreeable_buffers(dep: Dep) -> bool:
     return dep_name.startswith(
         ("primals_", "arg", "fwd_rng_state", "bwd_rng_state", "tangents")
     )
+
+
+# Make sure to also include your jinja templates within torch_package_data in setup.py, or this function won't be able to find them
+def load_template(name: str, template_dir: Path) -> str:
+    """Load a template file and return its content."""
+    with open(template_dir / f"{name}.py.jinja") as f:
+        return f.read()
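A note on the packaging comment in the last hunk: because the loader reads .py.jinja files from disk at runtime, those templates only reach a packaged build if setup.py's torch_package_data globs cover them. The snippet below is a purely illustrative sketch of such an entry; the actual list in PyTorch's setup.py is much longer and the exact glob may differ.

# Hypothetical excerpt of the torch_package_data list in setup.py;
# the glob shown here is an assumption for illustration only.
torch_package_data = [
    "_inductor/kernel/flex/templates/*.jinja",
    # ... other package data entries ...
]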