mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Finish typing common.py (#146225)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146225 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
dca5cc0255
commit
3a67c0e48d
@ -1,4 +1,3 @@
|
|||||||
# mypy: allow-untyped-defs
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
@ -32,7 +31,6 @@ import sympy
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.fx
|
import torch.fx
|
||||||
from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
|
|
||||||
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
|
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
|
||||||
from torch.utils import _pytree as pytree
|
from torch.utils import _pytree as pytree
|
||||||
from torch.utils._ordered_set import OrderedSet
|
from torch.utils._ordered_set import OrderedSet
|
||||||
@ -42,6 +40,7 @@ from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
|
|||||||
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
|
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
|
||||||
|
|
||||||
from .. import config, metrics
|
from .. import config, metrics
|
||||||
|
from ..dtype_propagation import DtypePropagationOpsHandler
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
boolean_ops,
|
boolean_ops,
|
||||||
DeferredLineBase,
|
DeferredLineBase,
|
||||||
@ -57,7 +56,9 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..ir import FixedLayout, IRNode
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
|
||||||
from ..loop_body import LoopBody
|
from ..loop_body import LoopBody
|
||||||
from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
|
from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
|
||||||
from .wrapper import PythonWrapperCodegen
|
from .wrapper import PythonWrapperCodegen
|
||||||
@ -121,7 +122,7 @@ class WorkspaceArg:
|
|||||||
dtype: torch.dtype = torch.uint8
|
dtype: torch.dtype = torch.uint8
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def unique_name(prefix="workspace_") -> str:
|
def unique_name(prefix: str = "workspace_") -> str:
|
||||||
return f"{prefix}{next(V.graph.workspace_id)}"
|
return f"{prefix}{next(V.graph.workspace_id)}"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -456,7 +457,9 @@ def init_backend_registration() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def index_prevent_reordering(
|
def index_prevent_reordering(
|
||||||
index: list[sympy.Expr], index_vars, sizes
|
index: Sequence[sympy.Expr],
|
||||||
|
index_vars: Sequence[sympy.Expr],
|
||||||
|
sizes: Sequence[sympy.Expr],
|
||||||
) -> list[sympy.Expr]:
|
) -> list[sympy.Expr]:
|
||||||
from ..ir import FlexibleLayout
|
from ..ir import FlexibleLayout
|
||||||
|
|
||||||
@ -631,11 +634,11 @@ class DataTypePropagation:
|
|||||||
return self.propagate_graph(self.graphs["root"])
|
return self.propagate_graph(self.graphs["root"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def propagate_loopbody(cls, body) -> Optional[torch.dtype]:
|
def propagate_loopbody(cls, body: LoopBody) -> Optional[torch.dtype]:
|
||||||
return cls(body).propagate()
|
return cls(body).propagate()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def propagate_scheduler_node(cls, node) -> Optional[torch.dtype]:
|
def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]:
|
||||||
from ..loop_body import LoopBody
|
from ..loop_body import LoopBody
|
||||||
from ..scheduler import SchedulerNode
|
from ..scheduler import SchedulerNode
|
||||||
|
|
||||||
@ -673,7 +676,7 @@ class OpDecompositions:
|
|||||||
return ops.mul(x, x)
|
return ops.mul(x, x)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def erfc(x: OpVarT):
|
def erfc(x: OpVarT) -> OpVarT:
|
||||||
return ops.sub(ops.constant(1, torch.float32), ops.erf(x))
|
return ops.sub(ops.constant(1, torch.float32), ops.erf(x))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -759,7 +762,7 @@ def _all_in_parens(string: str) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
class OpOverrides(OpDecompositions):
|
class OpOverrides(OpDecompositions):
|
||||||
def __init__(self, parent):
|
def __init__(self, parent: OpsHandler[OpVarT]) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._parent = parent
|
self._parent = parent
|
||||||
|
|
||||||
@ -1109,12 +1112,12 @@ def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[OpVarT]:
|
|||||||
class DeferredLine(DeferredLineBase):
|
class DeferredLine(DeferredLineBase):
|
||||||
"""A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
|
"""A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
|
||||||
|
|
||||||
def __init__(self, name, line):
|
def __init__(self, name: str, line: str):
|
||||||
super().__init__(line)
|
super().__init__(line)
|
||||||
self.name = name
|
self.name = name
|
||||||
assert not isinstance(line, DeferredLineBase)
|
assert not isinstance(line, DeferredLineBase)
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self) -> Optional[str]:
|
||||||
if all(
|
if all(
|
||||||
self.name not in x
|
self.name not in x
|
||||||
for x in (
|
for x in (
|
||||||
@ -1127,14 +1130,14 @@ class DeferredLine(DeferredLineBase):
|
|||||||
return self.line
|
return self.line
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _new_line(self, line):
|
def _new_line(self, line: str) -> DeferredLine:
|
||||||
return DeferredLine(self.name, line)
|
return DeferredLine(self.name, line)
|
||||||
|
|
||||||
|
|
||||||
class BracesBuffer(IndentedBuffer):
|
class BracesBuffer(IndentedBuffer):
|
||||||
def indent(self, offset=1) -> contextlib.AbstractContextManager[None]:
|
def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]:
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def ctx():
|
def ctx() -> Iterator[None]:
|
||||||
for _ in range(offset):
|
for _ in range(offset):
|
||||||
self.writeline("{")
|
self.writeline("{")
|
||||||
self._indent += 1
|
self._indent += 1
|
||||||
@ -1163,7 +1166,7 @@ class ArgName:
|
|||||||
# is_constexpr=True is used to attach a " : tl.constexpr" into the argument list
|
# is_constexpr=True is used to attach a " : tl.constexpr" into the argument list
|
||||||
is_constexpr: bool = False
|
is_constexpr: bool = False
|
||||||
|
|
||||||
def full_name(self):
|
def full_name(self) -> str:
|
||||||
return f"{self.name}{' : tl.constexpr' if self.is_constexpr else ''}"
|
return f"{self.name}{' : tl.constexpr' if self.is_constexpr else ''}"
|
||||||
|
|
||||||
|
|
||||||
@ -1509,8 +1512,8 @@ class CSEVariable:
|
|||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return hash(self.name)
|
return hash(self.name)
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
return type(other) == type(self) and other.name == self.name
|
return isinstance(other, CSEVariable) and other.name == self.name
|
||||||
|
|
||||||
def update_on_args(self, name: str, args: Any, kwargs: Any) -> None:
|
def update_on_args(self, name: str, args: Any, kwargs: Any) -> None:
|
||||||
pass
|
pass
|
||||||
@ -1557,7 +1560,7 @@ class CSE(Generic[CSEVariableType, AugmentedKeyT]):
|
|||||||
self.invalidated_stores: OrderedSet[str] = OrderedSet()
|
self.invalidated_stores: OrderedSet[str] = OrderedSet()
|
||||||
self.varname_map: dict[str, CSEVariableType] = varname_map or {}
|
self.varname_map: dict[str, CSEVariableType] = varname_map or {}
|
||||||
|
|
||||||
def invalidate(self, keep_vars: OrderedSet[CSEVariable]):
|
def invalidate(self, keep_vars: OrderedSet[CSEVariable]) -> None:
|
||||||
for name, tmp in [*self.store_cache.items()]:
|
for name, tmp in [*self.store_cache.items()]:
|
||||||
if tmp not in keep_vars:
|
if tmp not in keep_vars:
|
||||||
del self.store_cache[name]
|
del self.store_cache[name]
|
||||||
@ -1578,6 +1581,14 @@ class CSE(Generic[CSEVariableType, AugmentedKeyT]):
|
|||||||
reduction_cache=self.reduction_cache,
|
reduction_cache=self.reduction_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def scoped_copy(self) -> typing.Self:
|
||||||
|
"""Return a copy of self using ScopedDict so changes to *_cache aren't visible in self"""
|
||||||
|
new_cse = self.clone()
|
||||||
|
new_cse._cache = ScopedDict(self._cache)
|
||||||
|
new_cse.reduction_cache = ScopedDict(self.reduction_cache)
|
||||||
|
new_cse.store_cache = ScopedDict(self.store_cache)
|
||||||
|
return new_cse
|
||||||
|
|
||||||
def augment_key(self, cache_key: str) -> AugmentedKeyT:
|
def augment_key(self, cache_key: str) -> AugmentedKeyT:
|
||||||
"Override this method to augment cache key with backend specifics"
|
"Override this method to augment cache key with backend specifics"
|
||||||
return cast(AugmentedKeyT, cache_key)
|
return cast(AugmentedKeyT, cache_key)
|
||||||
@ -1730,7 +1741,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
self.kernel_name: Optional[str] = None
|
self.kernel_name: Optional[str] = None
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def set_current_node(self, node):
|
def set_current_node(self, node: SchedulerNode) -> Iterator[None]:
|
||||||
prior = self.current_node
|
prior = self.current_node
|
||||||
self.current_node = node
|
self.current_node = node
|
||||||
self.node_to_bounds = node._body.bounds().get_bounds()
|
self.node_to_bounds = node._body.bounds().get_bounds()
|
||||||
@ -1740,16 +1751,16 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
self.current_node = prior
|
self.current_node = prior
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def swap_buffers(self, lb, cb=None, sb=None):
|
def swap_buffers(
|
||||||
def scope_cse(cse: CSE[CSEVariableType, Any]):
|
self,
|
||||||
new_cse = cse.clone()
|
lb: IndentedBuffer,
|
||||||
new_cse._cache = ScopedDict(cse._cache)
|
cb: Optional[IndentedBuffer] = None,
|
||||||
new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
|
sb: Optional[IndentedBuffer] = None,
|
||||||
new_cse.store_cache = ScopedDict(cse.store_cache)
|
) -> Iterator[None]:
|
||||||
return new_cse
|
|
||||||
|
|
||||||
if cb is None:
|
if cb is None:
|
||||||
cb = lb
|
cb = lb
|
||||||
|
if disallow_stores := sb is None:
|
||||||
|
sb = IndentedBuffer()
|
||||||
loads = self.loads
|
loads = self.loads
|
||||||
compute = self.compute
|
compute = self.compute
|
||||||
stores = self.stores
|
stores = self.stores
|
||||||
@ -1757,7 +1768,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
self.loads = lb
|
self.loads = lb
|
||||||
self.compute = cb
|
self.compute = cb
|
||||||
self.stores = sb
|
self.stores = sb
|
||||||
self.cse = scope_cse(cse)
|
self.cse = cse.scoped_copy()
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
@ -1765,11 +1776,13 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
self.compute = compute
|
self.compute = compute
|
||||||
self.stores = stores
|
self.stores = stores
|
||||||
self.cse = cse
|
self.cse = cse
|
||||||
|
if disallow_stores:
|
||||||
|
assert not sb, "unexpected store inside swap_buffers"
|
||||||
|
|
||||||
def load(self, name: str, index: sympy.Expr) -> CSEVariable:
|
def load(self, name: str, index: sympy.Expr) -> CSEVariable:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def indirect_load(self, name: str, index: sympy.Expr):
|
def indirect_load(self, name: str, index: sympy.Expr) -> CSEVariable:
|
||||||
"""A load that depends on an index we have read"""
|
"""A load that depends on an index we have read"""
|
||||||
prior = self.loads
|
prior = self.loads
|
||||||
try:
|
try:
|
||||||
@ -1779,7 +1792,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
finally:
|
finally:
|
||||||
self.loads = prior
|
self.loads = prior
|
||||||
|
|
||||||
def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
|
def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def store(
|
def store(
|
||||||
@ -1815,7 +1828,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
) -> tuple[CSEVariable, ...]:
|
) -> tuple[CSEVariable, ...]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def var_ranges(self):
|
def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def bucketize(
|
def bucketize(
|
||||||
@ -1869,21 +1882,21 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
|
|
||||||
def check_bounds(
|
def check_bounds(
|
||||||
self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
|
self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
|
||||||
):
|
) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def index_to_str(self, index: sympy.Expr) -> str:
|
def index_to_str(self, index: sympy.Expr) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self) -> typing.Self:
|
||||||
# TODO: hoist this to top level
|
# TODO: hoist this to top level
|
||||||
class CSEProxy:
|
class CSEProxy:
|
||||||
self.name = "CSEProxy"
|
name = "CSEProxy"
|
||||||
vr_analysis = ValueRangeAnalysis()
|
vr_analysis = ValueRangeAnalysis()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
|
def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
|
||||||
def inner(*args, **kwargs):
|
def inner(*args: Any, **kwargs: Any) -> CSEVariable:
|
||||||
bounds = CSEProxy._bound_variable(name, *args, **kwargs)
|
bounds = CSEProxy._bound_variable(name, *args, **kwargs)
|
||||||
|
|
||||||
value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
|
value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
|
||||||
@ -1891,7 +1904,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
|
|
||||||
output_idx = 0
|
output_idx = 0
|
||||||
|
|
||||||
def do_cse(v):
|
def do_cse(v: str) -> CSEVariable:
|
||||||
# cpp backend doesn't set current device - TODO: fix
|
# cpp backend doesn't set current device - TODO: fix
|
||||||
if V.graph.current_device is not None:
|
if V.graph.current_device is not None:
|
||||||
device_str = V.graph.get_current_device_or_throw().type
|
device_str = V.graph.get_current_device_or_throw().type
|
||||||
@ -1950,7 +1963,9 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
return inner
|
return inner
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _bound_variable(name, *args, **kwargs):
|
def _bound_variable(
|
||||||
|
name: str, *args: Any, **kwargs: Any
|
||||||
|
) -> ValueRanges[Any]:
|
||||||
"""
|
"""
|
||||||
If the variable comes from an FX node, we forward the bound we have already computed
|
If the variable comes from an FX node, we forward the bound we have already computed
|
||||||
Else, if the variable is created when codegen'ing another op, we try to compute its bounds
|
Else, if the variable is created when codegen'ing another op, we try to compute its bounds
|
||||||
@ -1979,7 +1994,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
# If there is no FX bound but we know how to compute one we do so
|
# If there is no FX bound but we know how to compute one we do so
|
||||||
assert not kwargs
|
assert not kwargs
|
||||||
|
|
||||||
def arg_to_bound(x):
|
def arg_to_bound(x: Any) -> Any:
|
||||||
if isinstance(x, CSEVariable):
|
if isinstance(x, CSEVariable):
|
||||||
return x.bounds
|
return x.bounds
|
||||||
elif isinstance(x, sympy.Expr):
|
elif isinstance(x, sympy.Expr):
|
||||||
@ -1996,8 +2011,8 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
var: CSEVariable,
|
var: CSEVariable,
|
||||||
size: Union[sympy.Expr, int],
|
size: Union[sympy.Expr, int],
|
||||||
check: bool = True,
|
check: bool = True,
|
||||||
wrap_neg=True,
|
wrap_neg: bool = True,
|
||||||
):
|
) -> sympy.Symbol:
|
||||||
if isinstance(size, int):
|
if isinstance(size, int):
|
||||||
size = sympy.Integer(size)
|
size = sympy.Integer(size)
|
||||||
assert isinstance(size, sympy.Expr), size
|
assert isinstance(size, sympy.Expr), size
|
||||||
@ -2045,7 +2060,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def check_bounds(
|
def check_bounds(
|
||||||
expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
|
expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
|
||||||
):
|
) -> None:
|
||||||
return self.check_bounds(expr, size, lower, upper)
|
return self.check_bounds(expr, size, lower, upper)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -2067,7 +2082,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_store_cache(name: str, value: CSEVariable):
|
def _update_store_cache(name: str, value: CSEVariable) -> None:
|
||||||
value = cast(CSEVariableType, value)
|
value = cast(CSEVariableType, value)
|
||||||
self.cse.store_cache[name] = value
|
self.cse.store_cache[name] = value
|
||||||
if self.current_node and name in V.graph.name_to_buffer:
|
if self.current_node and name in V.graph.name_to_buffer:
|
||||||
@ -2087,7 +2102,9 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
return None # type: ignore[return-value]
|
return None # type: ignore[return-value]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
|
def store_reduction(
|
||||||
|
name: str, index: sympy.Expr, value: CSEVariable
|
||||||
|
) -> None:
|
||||||
self.store_buffer_names.add(name)
|
self.store_buffer_names.add(name)
|
||||||
CSEProxy._update_store_cache(name, value)
|
CSEProxy._update_store_cache(name, value)
|
||||||
|
|
||||||
@ -2215,7 +2232,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
self.exit_stack.enter_context(V.set_kernel_handler(self))
|
self.exit_stack.enter_context(V.set_kernel_handler(self))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||||
self.remove_kernel_local_buffers()
|
self.remove_kernel_local_buffers()
|
||||||
super().__exit__(exc_type, exc_val, exc_tb)
|
super().__exit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
@ -2271,7 +2288,9 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
self.args.inplace_buffers[name] = REMOVED
|
self.args.inplace_buffers[name] = REMOVED
|
||||||
self.removed_buffers.add(name)
|
self.removed_buffers.add(name)
|
||||||
|
|
||||||
def rename_indexing(self, index) -> sympy.Expr:
|
def rename_indexing(
|
||||||
|
self, index: Union[list[sympy.Expr], tuple[sympy.Expr, ...], sympy.Expr]
|
||||||
|
) -> sympy.Expr:
|
||||||
# adds the necessary kernel args for index expressions
|
# adds the necessary kernel args for index expressions
|
||||||
# and renames variables in index expressions to kernel arg names
|
# and renames variables in index expressions to kernel arg names
|
||||||
if isinstance(index, (list, tuple)):
|
if isinstance(index, (list, tuple)):
|
||||||
@ -2292,7 +2311,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
|||||||
}
|
}
|
||||||
return sympy_subs(index, replacements)
|
return sympy_subs(index, replacements)
|
||||||
|
|
||||||
def create_cse_var(self, *args, **kwargs):
|
def create_cse_var(self, *args: Any, **kwargs: Any) -> CSEVariable:
|
||||||
return CSEVariable(*args, **kwargs)
|
return CSEVariable(*args, **kwargs)
|
||||||
|
|
||||||
def arg_name(self, node: IRNode) -> Optional[str]:
|
def arg_name(self, node: IRNode) -> Optional[str]:
|
||||||
@ -2313,7 +2332,7 @@ class OptimizationContext:
|
|||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
def jinja2_env():
|
def jinja2_env() -> Any:
|
||||||
try:
|
try:
|
||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
@ -2332,7 +2351,9 @@ class KernelTemplate:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def indent_except_first(source: str, num_indents: int, indents_spacing=4):
|
def indent_except_first(
|
||||||
|
source: str, num_indents: int, indents_spacing: int = 4
|
||||||
|
) -> str:
|
||||||
lines = source.splitlines(True)
|
lines = source.splitlines(True)
|
||||||
if len(lines) > 1:
|
if len(lines) > 1:
|
||||||
lines[1:] = [
|
lines[1:] = [
|
||||||
@ -2341,64 +2362,74 @@ class KernelTemplate:
|
|||||||
return "".join(lines)
|
return "".join(lines)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _template_from_string(source):
|
def _template_from_string(source: str) -> Any:
|
||||||
env = jinja2_env()
|
env = jinja2_env()
|
||||||
if env is None:
|
if env is None:
|
||||||
return None
|
return None
|
||||||
env.filters["indent_except_first"] = KernelTemplate.indent_except_first
|
env.filters["indent_except_first"] = KernelTemplate.indent_except_first
|
||||||
from jinja2 import TemplateSyntaxError
|
from jinja2 import TemplateSyntaxError
|
||||||
|
|
||||||
class DetailedTemplateSyntaxError(TemplateSyntaxError):
|
|
||||||
def __init__(self, original_error):
|
|
||||||
super().__init__(
|
|
||||||
original_error.message,
|
|
||||||
original_error.lineno,
|
|
||||||
original_error.name,
|
|
||||||
original_error.filename,
|
|
||||||
)
|
|
||||||
self.original_error = original_error
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
error_info = f"Error in template at line {self.lineno}\n"
|
|
||||||
error_info += f"Error message: {self.message}\n"
|
|
||||||
if hasattr(self.original_error, "source"):
|
|
||||||
lines = self.original_error.source.split("\n")
|
|
||||||
error_info += "Context:\n"
|
|
||||||
start = max(0, self.lineno - 2)
|
|
||||||
end = min(len(lines), self.lineno + 2)
|
|
||||||
for i in range(start, end):
|
|
||||||
if i == self.lineno - 1:
|
|
||||||
error_info += f"{i + 1}: --> {lines[i]}\n"
|
|
||||||
if hasattr(self.original_error, "column"):
|
|
||||||
error_info += (
|
|
||||||
" "
|
|
||||||
+ " " * (self.original_error.column - 1)
|
|
||||||
+ "^\n"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
error_info += f"{i + 1}: {lines[i]}\n"
|
|
||||||
return error_info
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return env.from_string(source)
|
return env.from_string(source)
|
||||||
except TemplateSyntaxError as e:
|
except TemplateSyntaxError as e:
|
||||||
|
|
||||||
|
class DetailedTemplateSyntaxError(TemplateSyntaxError):
|
||||||
|
def __init__(self, original_error: TemplateSyntaxError) -> None:
|
||||||
|
super().__init__(
|
||||||
|
original_error.message,
|
||||||
|
original_error.lineno,
|
||||||
|
original_error.name,
|
||||||
|
original_error.filename,
|
||||||
|
)
|
||||||
|
self.original_error = original_error
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
error_info = f"Error in template at line {self.lineno}\n"
|
||||||
|
error_info += f"Error message: {self.message}\n"
|
||||||
|
if hasattr(self.original_error, "source"):
|
||||||
|
lines = self.original_error.source.split("\n")
|
||||||
|
error_info += "Context:\n"
|
||||||
|
start = max(0, self.lineno - 2)
|
||||||
|
end = min(len(lines), self.lineno + 2)
|
||||||
|
for i in range(start, end):
|
||||||
|
if i == self.lineno - 1:
|
||||||
|
error_info += f"{i + 1}: --> {lines[i]}\n"
|
||||||
|
if hasattr(self.original_error, "column"):
|
||||||
|
error_info += (
|
||||||
|
" "
|
||||||
|
+ " " * (self.original_error.column - 1)
|
||||||
|
+ "^\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
error_info += f"{i + 1}: {lines[i]}\n"
|
||||||
|
return error_info
|
||||||
|
|
||||||
raise DetailedTemplateSyntaxError(e) from e
|
raise DetailedTemplateSyntaxError(e) from e
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fake_get_dtype(fake_out):
|
def _fake_get_dtype(
|
||||||
|
fake_outs: Union[list[Buffer], Buffer]
|
||||||
|
) -> Callable[[str], torch.dtype]:
|
||||||
_get_dtype_real = V.graph.get_dtype
|
_get_dtype_real = V.graph.get_dtype
|
||||||
|
if isinstance(fake_outs, (list, tuple)):
|
||||||
|
lookup = {buf.get_name(): buf.get_dtype() for buf in fake_outs}
|
||||||
|
else:
|
||||||
|
lookup = {fake_outs.get_name(): fake_outs.get_dtype()}
|
||||||
|
|
||||||
def get_dtype(name):
|
def get_dtype(name: str) -> torch.dtype:
|
||||||
if name == fake_out.get_name():
|
result = lookup.get(name)
|
||||||
return fake_out.get_dtype()
|
if result is not None:
|
||||||
|
return result
|
||||||
return _get_dtype_real(name)
|
return _get_dtype_real(name)
|
||||||
|
|
||||||
return get_dtype
|
return get_dtype
|
||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
def maybe_append_choice(self, choices, **kwargs):
|
def maybe_append_choice(
|
||||||
|
self, choices: list[Any], **kwargs: Any
|
||||||
|
) -> Optional[NotImplementedError]:
|
||||||
"""
|
"""
|
||||||
Maybe generates a new ChoiceCaller and appends it into existing choices.
|
Maybe generates a new ChoiceCaller and appends it into existing choices.
|
||||||
Returns None if success, otherwise returns the error.
|
Returns None if success, otherwise returns the error.
|
||||||
@ -2413,7 +2444,7 @@ class KernelTemplate:
|
|||||||
except NotImplementedError as e:
|
except NotImplementedError as e:
|
||||||
return e
|
return e
|
||||||
|
|
||||||
def generate(self, **kwargs) -> torch._inductor.ir.ChoiceCaller:
|
def generate(self, **kwargs: Any) -> ChoiceCaller:
|
||||||
"""
|
"""
|
||||||
Generates a ChoiceCaller instance from the given arguments.
|
Generates a ChoiceCaller instance from the given arguments.
|
||||||
"""
|
"""
|
||||||
|
@ -188,18 +188,6 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
|
|||||||
for idx in range(gemm_grouped_num)
|
for idx in range(gemm_grouped_num)
|
||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _fake_get_dtype(fake_outs: list[ir.Buffer]) -> Callable[[str], torch.dtype]:
|
|
||||||
_get_dtype_real = V.graph.get_dtype
|
|
||||||
|
|
||||||
def get_dtype(name: str) -> torch.dtype:
|
|
||||||
for fake_out in fake_outs:
|
|
||||||
if name == fake_out.get_name():
|
|
||||||
return fake_out.get_dtype()
|
|
||||||
return _get_dtype_real(name)
|
|
||||||
|
|
||||||
return get_dtype
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_choices(
|
def add_choices(
|
||||||
cls,
|
cls,
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
|
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
@ -29,7 +31,7 @@ ReductionType = Literal[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _arg_str(a) -> str:
|
def _arg_str(a: object) -> str:
|
||||||
if isinstance(a, sympy.Expr):
|
if isinstance(a, sympy.Expr):
|
||||||
return sympy_str(a)
|
return sympy_str(a)
|
||||||
return str(a)
|
return str(a)
|
||||||
@ -44,7 +46,7 @@ class OpsHandler(Protocol[T]):
|
|||||||
"""
|
"""
|
||||||
Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
|
Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
|
||||||
as well as the contract for op handlers. The type T signifies the domain
|
as well as the contract for op handlers. The type T signifies the domain
|
||||||
of the abstract analysis AKA what all of the functions return / take as arguments
|
of the abstract analysis AKA what all the functions return / take as arguments
|
||||||
anywhere compute occurs.
|
anywhere compute occurs.
|
||||||
|
|
||||||
While these operators are typically dtype polymorphic (e.g., you can use mul
|
While these operators are typically dtype polymorphic (e.g., you can use mul
|
||||||
@ -247,7 +249,7 @@ class OpsHandler(Protocol[T]):
|
|||||||
# TODO: in practice, this seems to actually return None, but not returning
|
# TODO: in practice, this seems to actually return None, but not returning
|
||||||
# a T makes common __getattr__ idioms not type correctly. Figure out if
|
# a T makes common __getattr__ idioms not type correctly. Figure out if
|
||||||
# this should be returning something.
|
# this should be returning something.
|
||||||
def store_reduction(self, name: str, index: sympy.Expr, value: T) -> T:
|
def store_reduction(self, name: str, index: sympy.Expr, value: T) -> None:
|
||||||
"""
|
"""
|
||||||
Store the fully accumulated result of 'reduction' to the memory
|
Store the fully accumulated result of 'reduction' to the memory
|
||||||
location 'name' offset by 'expr'.
|
location 'name' offset by 'expr'.
|
||||||
@ -1046,7 +1048,7 @@ class ExtractConstantsHandler(NoopHandler):
|
|||||||
def __init__(self, device):
|
def __init__(self, device):
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def constant(self, value: Any, dtype: torch.dtype) -> "torch._inductor.ir.Constant":
|
def constant(self, value: Any, dtype: torch.dtype) -> torch._inductor.ir.Constant:
|
||||||
from torch._inductor import ir
|
from torch._inductor import ir
|
||||||
|
|
||||||
return ir.Constant(value=value, dtype=dtype, device=self.device)
|
return ir.Constant(value=value, dtype=dtype, device=self.device)
|
||||||
|
Reference in New Issue
Block a user