This reverts commit febb60323018948b2b9d2cff35b3cc4e0d0c55c8.

Reverted https://github.com/pytorch/pytorch/pull/165576 on behalf of https://github.com/seemethere due to: this was actually reverted internally; the current PR is linked to a stale diff, so the diff train tools think it landed via co-dev when it was actually reverted ([comment](https://github.com/pytorch/pytorch/pull/165576#issuecomment-3417510146))
@@ -3,7 +3,6 @@
 
 import math
 from collections.abc import Sequence
-from functools import partial
 from pathlib import Path
 from typing import Any, Optional, Union
 
@@ -37,7 +36,6 @@ from ...lowering import (
     to_dtype,
 )
 from ...select_algorithm import realize_inputs
-from ...utils import load_template
 
 
 SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]]
@@ -339,7 +337,13 @@ def next_power_of_two(n):
     return 2 ** math.ceil(math.log2(n))
 
 
-_FLEX_TEMPLATE_DIR = Path(__file__).parent / "templates"
-load_flex_template = partial(load_template, template_dir=_FLEX_TEMPLATE_DIR)
+_TEMPLATE_DIR = Path(__file__).parent / "templates"
+
+
+def load_template(name: str) -> str:
+    """Load a template file and return its content."""
+    with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
+        return f.read()
+
 
 # Template strings have been moved to templates/common.py.jinja
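For orientation, here is a minimal sketch of the two loader styles this hunk toggles between, built only from names that appear in the diff (`load_template`, `load_flex_template`, the `templates` directory). The side-by-side layout and the `_local` suffix are illustrative, not part of the patch; in the real files both loaders are simply called `load_template`, one shared in `...utils` and one local to `common.py`.

```python
from functools import partial
from pathlib import Path


def load_template(name: str, template_dir: Path) -> str:
    """Shared loader (the version the reverted PR kept in ...utils)."""
    with open(template_dir / f"{name}.py.jinja") as f:
        return f.read()


# Reverted style: bind the flex template directory once with functools.partial
# and reuse the shared loader everywhere in the flex kernels.
_FLEX_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_flex_template = partial(load_template, template_dir=_FLEX_TEMPLATE_DIR)

# Restored style: a loader local to common.py with the directory hard-wired.
# (Renamed here only to keep this sketch self-contained in one module.)
_TEMPLATE_DIR = Path(__file__).parent / "templates"


def load_template_local(name: str) -> str:
    """Load a template file and return its content."""
    with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
        return f.read()
```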
@@ -29,7 +29,7 @@ from .common import (
     freeze_irnodes,
     get_fwd_subgraph_outputs,
     infer_dense_strides,
-    load_flex_template,
+    load_template,
     maybe_realize,
     set_head_dim_values,
     SubgraphResults,
@@ -79,9 +79,9 @@ def get_float32_precision():
 flex_attention_template = TritonTemplate(
     name="flex_attention",
     grid=flex_attention_grid,
-    source=load_flex_template("flex_attention")
-    + load_flex_template("utilities")
-    + load_flex_template("common"),
+    source=load_template("flex_attention")
+    + load_template("utilities")
+    + load_template("common"),
 )
 
 
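The `source` argument in these TritonTemplate calls is plain string concatenation: the kernel-specific Jinja file and the shared `utilities` and `common` files are joined into one template string. A toy sketch of that assembly, with placeholder strings standing in for the real `.py.jinja` contents (the comments describe what those files presumably hold, which is not shown in this diff):

```python
# Placeholder fragments standing in for the real .py.jinja files.
flex_attention_src = "{# flex_attention: kernel-specific body #}\n"
utilities_src = "{# utilities: helpers shared across the flex kernels #}\n"
common_src = "{# common: code shared by the forward templates #}\n"

# Mirrors source=load_template("flex_attention") + load_template("utilities")
# + load_template("common") from the diff: the result is one flat Jinja string.
source = flex_attention_src + utilities_src + common_src
print(source)
```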
@@ -464,7 +464,7 @@ def flex_attention_backward_grid(
 flex_attention_backward_template = TritonTemplate(
     name="flex_attention_backward",
     grid=flex_attention_backward_grid,
-    source=load_flex_template("flex_backwards") + load_flex_template("utilities"),
+    source=load_template("flex_backwards") + load_template("utilities"),
 )
 
 
@@ -22,7 +22,7 @@ from .common import (
     create_num_blocks_fake_generator,
     freeze_irnodes,
     get_fwd_subgraph_outputs,
-    load_flex_template,
+    load_template,
     maybe_realize,
     set_head_dim_values,
 )
@@ -97,9 +97,9 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
 flex_decoding_template = TritonTemplate(
     name="flex_decoding",
     grid=flex_decoding_grid,
-    source=load_flex_template("flex_decode")
-    + load_flex_template("utilities")
-    + load_flex_template("common"),
+    source=load_template("flex_decode")
+    + load_template("utilities")
+    + load_template("common"),
 )
 
 
@@ -12,7 +12,7 @@ from torch.fx import GraphModule
 
 from ...ir import FixedLayout, ShapeAsConstantBuffer, Subgraph, TensorBox
 from ...lowering import empty_strided
-from .common import infer_dense_strides, load_flex_template, SubgraphResults
+from .common import infer_dense_strides, load_template, SubgraphResults
 
 
 aten = torch.ops.aten
@@ -36,8 +36,7 @@ from ...codegen.cutedsl.cutedsl_template import CuteDSLTemplate
 
 
 flash_attention_cutedsl_template = CuteDSLTemplate(
-    name="flash_attention_cutedsl",
-    source=load_flex_template("flash_attention"),
+    name="flash_attention_cutedsl", source=load_template("flash_attention")
 )
 
 
@@ -67,10 +67,6 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._pytree import tree_flatten, tree_map_only
 
 
-if TYPE_CHECKING:
-    from pathlib import Path
-
-
 OPTIMUS_EXCLUDE_POST_GRAD = [
     "activation_quantization_aten_pass",
     "inductor_autotune_lookup_table",
@@ -3890,10 +3886,3 @@ def is_nonfreeable_buffers(dep: Dep) -> bool:
     return dep_name.startswith(
         ("primals_", "arg", "fwd_rng_state", "bwd_rng_state", "tangents")
     )
-
-
-# Make sure to also include your jinja templates within torch_package_data in setup.py, or this function won't be able to find them
-def load_template(name: str, template_dir: Path) -> str:
-    """Load a template file and return its content."""
-    with open(template_dir / f"{name}.py.jinja") as f:
-        return f.read()
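The comment on the deleted helper carries the operational caveat: the loader resolves templates relative to the installed package, so the `.py.jinja` files have to ship with it (hence the pointer at `torch_package_data` in setup.py). Purely as an illustrative variant, not part of this patch, a loader that fails with a clearer message when a template was not packaged might look like this:

```python
from pathlib import Path


def load_template(name: str, template_dir: Path) -> str:
    """Load a Jinja template file and return its content.

    Raises a descriptive error when the template is missing, which is the
    failure mode the deleted comment warns about (templates not listed in
    the package data of setup.py).
    """
    template_path = template_dir / f"{name}.py.jinja"
    if not template_path.is_file():
        raise FileNotFoundError(
            f"{template_path} does not exist; was it shipped with the package?"
        )
    return template_path.read_text()
```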