Annotate graph.py (#131400)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131400
Approved by: https://github.com/shunting314
Author: eellison
Date: 2024-07-22 18:35:39 -07:00
Committed by: PyTorch MergeBot
Parent: 102d8e5a63
Commit: 16a2a1aad3
13 changed files with 184 additions and 122 deletions

View File

@ -1357,7 +1357,7 @@ class KernelArgs:
for outer, inner in self.sizevars.items():
arg_defs.append(inner)
call_args.append(outer)
arg_types.append(type(outer))
arg_types.append(type(outer)) # type: ignore[arg-type]
precompile_args.append(SizeArg(inner, outer))
if V.graph.wrapper_code:
V.graph.wrapper_code.ensure_size_computed(outer)

View File

@ -449,7 +449,7 @@ class CppPackedGemmTemplate(CppTemplate):
W.get_name() + "_BMatrixCompens",
)
else:
BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0)
BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0) # type: ignore[assignment]
new_inputs.append(BCompensate)
return new_inputs, layout_or_out

View File

@ -860,7 +860,7 @@ class SIMDKernel(Kernel):
argdefs, call_args, signature, _ = self.args.python_argdefs()
uniform_stride_order = None
for arg_name in call_args:
buf = V.graph.get_buffer(arg_name)
buf = V.graph.try_get_buffer(arg_name)
if buf and len(buf.layout.size) == 4:
# ignore the tensor if only 1 dimension is non-zero
if len([x for x in buf.layout.size if x == 1]) == 3:
@ -877,13 +877,13 @@ class SIMDKernel(Kernel):
stride_order_list = [
ir.get_stride_order(V.graph.get_buffer(name).layout.stride)
if V.graph.get_buffer(name)
if V.graph.try_get_buffer(name)
else None
for name in call_args
]
size_list = [
V.graph.get_buffer(name).layout.size
if V.graph.get_buffer(name)
if V.graph.try_get_buffer(name)
else None
for name in call_args
]

View File

@ -2361,7 +2361,7 @@ class TritonKernel(SIMDKernel):
var_names = []
for arg_name, arg_sig in zip(call_args, signature):
var_name = f"arg_{next(name_cnt)}"
buf = V.graph.get_buffer(arg_name)
buf = V.graph.try_get_buffer(arg_name)
if buf:
result.writeline(
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long

View File

@ -80,7 +80,7 @@ def is_unaligned_buffer(arg: TensorArg):
if V.graph.scheduler:
layout = V.graph.scheduler.get_buffer_layout(buf_name)
else:
buffer = V.graph.get_buffer(buf_name)
buffer = V.graph.try_get_buffer(buf_name)
# output arg
if not buffer:
assert buf_name == V.kernel.output_node.name

View File

@ -155,7 +155,7 @@ TritonGrid = Union[
def user_defined_kernel_grid_fn_code(
name: str,
configs: List[triton.Config],
configs: List[triton.Config], # type: ignore[name-defined]
grids: List[TritonGrid],
wrapper: Optional[WrapperCodeGen] = None,
) -> Tuple[str, str]:
@ -1342,8 +1342,8 @@ class WrapperCodeGen(CodeGen):
compile_wrapper.splice(kernel.src, strip=True)
# Also include any possible kernel being called indirectly
from triton import JITFunction
from triton.language import constexpr
from triton import JITFunction # type: ignore[name-defined, attr-defined]
from triton.language import constexpr # type: ignore[name-defined]
# global constexpr vars handled above
symbols_included = {original_name}
@ -1522,7 +1522,7 @@ class WrapperCodeGen(CodeGen):
def generate_example_arg_value(self, arg, arg_type=None, raw_arg=None, index=None):
if isinstance(arg_type, torch_dtype):
if V.graph.get_buffer(arg) is not None:
if V.graph.try_get_buffer(arg) is not None:
buf_name = arg
buf = V.graph.get_buffer(arg)
else:

View File

@ -555,7 +555,7 @@ def extract_input_node_reduction_ranges(
if read.name in seen:
continue
seen.add(read.name)
buffer = V.graph.get_buffer(read.name)
buffer = V.graph.try_get_buffer(read.name)
if buffer is None:
continue
op = buffer.get_defining_op()

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-defs
import functools
import itertools
import logging
@ -9,13 +8,17 @@ import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from types import ModuleType
from typing import (
Any,
Callable,
DefaultDict,
Dict,
Iterable,
List,
NoReturn,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
@ -23,17 +26,21 @@ from typing import (
)
import sympy
from sympy import Expr
import torch
import torch._logging
import torch.fx
from torch import device, Tensor
from torch._decomp import get_decompositions
from torch._dynamo.utils import defake, dynamo_timed
from torch._logging import LazyString, trace_structured
from torch._prims_common import make_channels_last_strides_for
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx import GraphModule
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
from torch.fx.experimental.symbolic_shapes import (
free_unbacked_symbols,
has_free_symbols,
@ -42,6 +49,8 @@ from torch.fx.experimental.symbolic_shapes import (
ShapeEnv,
SymTypes,
)
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.utils._mode_utils import no_dispatch
from torch.utils._sympy.numbers import int_oo
@ -82,6 +91,8 @@ from .lowering import (
needs_realized_inputs,
unsupported_output_tensor,
)
from .scheduler import BaseSchedulerNode
from .sizevars import SizeVarAllocator
from .utils import (
convert_shape_to_inductor,
@ -112,11 +123,11 @@ if config.is_fbcode():
from torch._inductor.fb.utils import log_module_code
else:
def log_module_code(*args, **kwargs):
def log_module_code(*args: Any, **kwargs: Any) -> None:
pass
def supported_dtype_of_cpp_wrapper(dtype, cuda):
def supported_dtype_of_cpp_wrapper(dtype: torch.device, cuda: bool) -> bool:
supported_dtype = {
torch.float32,
torch.float64,
@ -141,7 +152,7 @@ def supported_dtype_of_cpp_wrapper(dtype, cuda):
return dtype in supported_dtype
def may_get_constant_buffer_dtype(constant_buffer):
def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]:
assert isinstance(
constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
@ -159,12 +170,14 @@ def may_get_constant_buffer_dtype(constant_buffer):
return None
def is_magic_method(op):
def is_magic_method(op: Any) -> bool:
magic_ops = {method_to_operator(m) for m in magic_methods}
return op in magic_ops
def getattr_recursive(obj, target):
def getattr_recursive(
obj: GraphModule, target: str
) -> Union[Tensor, torch._C.ScriptObject, GraphModule]:
target_atoms = target.split(".")
attr_itr = obj
for i, atom in enumerate(target_atoms):
@ -176,7 +189,7 @@ def getattr_recursive(obj, target):
return attr_itr
def mark_nodes_dislike_padding(g):
def mark_nodes_dislike_padding(g: Graph) -> None:
"""
Nodes like convolution/convolution_backward want their inputs to be dense.
If we pad their inputs, we incur extra calls to copy kernels! On the other hand, padding usually helps reduction.
@ -207,7 +220,9 @@ def mark_nodes_dislike_padding(g):
aten.scatter_reduce,
}
def _get_overload_packet(node):
def _get_overload_packet(
node: torch.fx.Node,
) -> Optional[torch._ops.OpOverloadPacket]:
return (
node.target._overloadpacket
if node.op == "call_function"
@ -237,7 +252,9 @@ def mark_nodes_dislike_padding(g):
class GraphLowering(torch.fx.Interpreter):
graph_outputs: List[ir.IRNode]
def symbolic_sizes_strides(self, ex: torch.Tensor):
def symbolic_sizes_strides(
self, ex: torch.Tensor
) -> Tuple[Union[List[int], List[Expr]], Union[List[int], List[Expr]]]:
"""
Support dynamic shapes and dynamic strides by assigning variables
to each dimension. We duck-shape tensors, so if two tensors
@ -272,7 +289,9 @@ class GraphLowering(torch.fx.Interpreter):
stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride]
return size, stride
def static_sizes_strides(self, ex: torch.Tensor):
def static_sizes_strides(
self, ex: torch.Tensor
) -> Tuple[List[sympy.Expr], List[sympy.Expr]]:
"""
Primarily used for weights
"""
@ -284,19 +303,21 @@ class GraphLowering(torch.fx.Interpreter):
self,
gm: torch.fx.GraphModule,
example_inputs: Optional[List[torch.Tensor]] = None,
shape_env=None,
graph_id=None,
cpp_wrapper=False,
aot_mode=False,
user_visible_outputs=None,
layout_opt=None,
extern_node_serializer=None,
is_inference=False,
is_const_graph=False,
const_output_index=None,
const_code=None,
const_module=None,
name=None,
shape_env: Optional[ShapeEnv] = None,
graph_id: Optional[int] = None,
cpp_wrapper: bool = False,
aot_mode: bool = False,
user_visible_outputs: Optional[Dict[str, None]] = None,
layout_opt: Optional[bool] = None,
extern_node_serializer: Optional[
Callable[[List[ir.ExternKernelNode]], Any]
] = None,
is_inference: bool = False,
is_const_graph: bool = False,
const_output_index: Optional[Dict[str, int]] = None,
const_code: Optional[str] = None,
const_module: Optional["GraphLowering"] = None,
name: Optional[str] = None,
):
super().__init__(gm)
self.example_inputs = example_inputs
@ -370,7 +391,7 @@ class GraphLowering(torch.fx.Interpreter):
self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
self.name_to_op: Dict[str, ir.Operation] = {}
self.creation_time = time.time()
self.name = name
self.name = name # type: ignore[assignment]
self.cpp_wrapper = cpp_wrapper
# record multi_kernel choice for cpp_wrapper so the second pass knows
@ -407,7 +428,7 @@ class GraphLowering(torch.fx.Interpreter):
self.dynamo_flat_name_to_original_fqn = self.module.meta.get(
"dynamo_flat_name_to_original_fqn", {}
)
self.allocated_constant_name = (
self.allocated_constant_name: Dict[str, str] = (
const_module.allocated_constant_name if const_module is not None else {}
)
init_backend_registration()
@ -417,12 +438,14 @@ class GraphLowering(torch.fx.Interpreter):
self.aligned_inputs: Set[str] = set()
self.no_fuse_buffer_names: Set[str] = set()
def has_feature(self, device, feature):
def has_feature(
self, device: Union[torch._inductor.ir.IRNode, device], feature: BackendFeature
) -> bool:
assert isinstance(feature, BackendFeature), feature
return feature in self.get_backend_features(get_device_type(device))
@staticmethod
def decide_layout_opt(gm, *, is_inference) -> bool:
def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool:
"""
Decide if we should enable layout optimization for this graph based on
heuristics.
@ -470,19 +493,21 @@ class GraphLowering(torch.fx.Interpreter):
)
return False
def is_grouped(n):
return n.args[-1] > 1 and n.args[1].meta["val"].size(1) > 1
def is_grouped(n: Any) -> bool:
meta_val = n.args[1].meta["val"] # type: ignore[union-attr, operator]
assert isinstance(meta_val, torch.Tensor)
return n.args[-1] > 1 and meta_val.size(1) > 1 # type: ignore[union-attr, operator]
def is_in_out_channel(n):
def is_in_out_channel(n: torch.fx.Node) -> bool:
return (
n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1)
and n.args[1].meta["val"].size(2) > 1
n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1) # type: ignore[union-attr, operator]
and n.args[1].meta["val"].size(2) > 1 # type: ignore[union-attr, operator]
)
def is_small_channel(n):
def is_small_channel(n: torch.fx.Node) -> bool:
return (
n.args[1].meta["val"].size(0) <= 64
and n.args[1].meta["val"].size(1) <= 64
n.args[1].meta["val"].size(0) <= 64 # type: ignore[union-attr, operator]
and n.args[1].meta["val"].size(1) <= 64 # type: ignore[union-attr, operator]
)
# only grouped convolutions benchmarked as slower in conv samples for inference only
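For orientation (based on the aten.convolution schema, which is not spelled out in the diff): args[1] is the weight tensor and args[-1] the group count, so the helpers above classify a convolution by its weight's metadata. A minimal sketch of the same shape checks on a plain, illustrative weight:

import torch

weight = torch.empty(32, 16, 3, 3)  # (out_channels, in_channels // groups, kH, kW)
small_channel = weight.size(0) <= 64 and weight.size(1) <= 64                   # True for this weight
in_out_channel = weight.size(0) * 2 <= weight.size(1) and weight.size(2) > 1    # False for this weight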
@ -614,7 +639,7 @@ class GraphLowering(torch.fx.Interpreter):
name=self.qualify_name(subgraph_name),
)
def find_nodes_prefer_channels_last(self):
def find_nodes_prefer_channels_last(self) -> Set[Node]:
"""
The rule to decide if a node prefers channels last is simple.
1. if it's input/output of a convolution
@ -664,12 +689,12 @@ class GraphLowering(torch.fx.Interpreter):
return output_set
def warn_fallback(self, name):
def warn_fallback(self, name: str) -> None:
if name not in self._warned_fallback:
self._warned_fallback.add(name)
perf_hint_log.info("Using FallbackKernel: %s", name)
def add_device_info(self, device: torch.device):
def add_device_info(self, device: torch.device) -> None:
self.device_types.add(device.type)
if device.index is not None:
self.device_idxs.add(device.index)
@ -677,10 +702,12 @@ class GraphLowering(torch.fx.Interpreter):
self.device_node_mapping[device] = V.graph.current_node
@property
def fake_mode(self):
def fake_mode(self) -> torch._subclasses.fake_tensor.FakeTensorMode:
return V.fake_mode
def get_buffer(self, buffer_name: str):
def try_get_buffer(
self, buffer_name: str
) -> Optional[Union[ir.TensorBox, ir.Buffer]]:
if buffer_name in self.name_to_buffer:
return self.name_to_buffer[buffer_name]
if buffer_name in self.graph_inputs:
@ -693,9 +720,16 @@ class GraphLowering(torch.fx.Interpreter):
data.device, data.dtype, *V.graph.static_sizes_strides(data)
),
)
return None
def get_dtype(self, buffer_name: str):
def get_buffer(self, buffer_name: str) -> Union[ir.TensorBox, ir.Buffer]:
buf = self.try_get_buffer(buffer_name)
if buf is not None:
return buf
raise RuntimeError(f"Failed to find buffer matching name {buffer_name}")
def get_dtype(self, buffer_name: str) -> torch.dtype:
if buffer_name in self.constants:
return self.constants[buffer_name].dtype
if buffer_name in self.name_to_buffer:
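The two hunks above split buffer lookup into a non-raising and a raising variant: try_get_buffer returns None for an unknown name, while get_buffer now raises. A minimal usage sketch of the new contract, with a hypothetical buffer name:

buf = V.graph.try_get_buffer("buf0")  # None if "buf0" is neither a buffer nor a graph input
if buf is not None:
    print(buf.get_dtype())
buf = V.graph.get_buffer("buf0")      # raises RuntimeError if "buf0" cannot be found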
@ -707,7 +741,7 @@ class GraphLowering(torch.fx.Interpreter):
return self.get_dtype(m.group(1))
raise KeyError(f"could not find {buffer_name}")
def get_numel(self, buffer_name: str):
def get_numel(self, buffer_name: str) -> Union[int, Expr]:
from .ir import MultiOutputLayout
if buffer_name in self.constants:
@ -722,10 +756,10 @@ class GraphLowering(torch.fx.Interpreter):
raise KeyError(f"could not find {buffer_name}")
@dynamo_timed
def run(self, *args):
def run(self, *args: Any) -> Any:
return super().run(*args)
def register_operation(self, op: ir.Operation):
def register_operation(self, op: ir.Operation) -> str:
assert op.operation_name is None, f"Operation registered twice: {op}"
assert isinstance(op, ir.Operation)
name = self.qualify_name(f"op{len(self.operations)}")
@ -734,7 +768,7 @@ class GraphLowering(torch.fx.Interpreter):
op.operation_name = name
return name
def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False):
def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str:
name = self.qualify_name(f"buf{len(self.buffers)}")
self.buffers.append(buffer)
self.name_to_buffer[name] = buffer
@ -754,8 +788,10 @@ class GraphLowering(torch.fx.Interpreter):
self.lists[name] = operation_names
return name
def register_users_of(self, node_output):
def register(value):
def register_users_of(
self, node_output: Union[Iterable[ir.IRNode], ir.IRNode]
) -> None:
def register(value: Union[Iterable[ir.IRNode], ir.IRNode]) -> None:
if isinstance(value, (list, tuple)):
for x in value:
register(x)
@ -765,7 +801,7 @@ class GraphLowering(torch.fx.Interpreter):
register(node_output)
def mark_buffer_mutated(self, name: str):
def mark_buffer_mutated(self, name: str) -> None:
"""
When a buffer is mutated we need to make sure all the reads to
the old version are realized before the mutation happens.
@ -779,7 +815,7 @@ class GraphLowering(torch.fx.Interpreter):
for user in self.name_to_users[name]:
user.realize()
def get_original_value_of_constant(self, name: str):
def get_original_value_of_constant(self, name: str) -> torch.Tensor:
"""
In AOTI, module buffers may have been mutated during the tracing and compilation.
Thus we need to read from previously stored original buffers, to make sure the
@ -795,7 +831,9 @@ class GraphLowering(torch.fx.Interpreter):
else self.constants[name]
)
def allocate_non_dup_const_name(self, name, data):
def allocate_non_dup_const_name(
self, name: Optional[str], data: Union[Tensor]
) -> str:
orig_name = name
if not config.aot_inductor.use_runtime_constant_folding:
for constant_name, value in self.constants.items():
@ -813,6 +851,7 @@ class GraphLowering(torch.fx.Interpreter):
if name is None:
name = f"constant{len(self.constants)}"
assert name is not None
if name[0].isdigit():
name = f"constant_{name}"
name = self.qualify_name(name)
@ -830,10 +869,12 @@ class GraphLowering(torch.fx.Interpreter):
f"{tuple(data.size())!r} {tuple(data.stride())!r} "
f"{hash(data):x}"
)
self.allocated_constant_name[name] = orig_name
self.allocated_constant_name[name] = orig_name # type: ignore[assignment]
return name
def add_tensor_constant(self, data, name=None):
def add_tensor_constant(
self, data: Tensor, name: Optional[str] = None
) -> TensorBox:
new_name = self.allocate_non_dup_const_name(name, data)
return TensorBox.create(
ir.ConstantBuffer(
@ -842,7 +883,7 @@ class GraphLowering(torch.fx.Interpreter):
)
)
def constant_name(self, name: str, device_override: Optional[torch.device]):
def constant_name(self, name: str, device_override: Optional[torch.device]) -> str:
"""
We AOT copy constants to the devices they are needed on.
If device_override doesn't match the constant's device, then
@ -858,7 +899,9 @@ class GraphLowering(torch.fx.Interpreter):
self.constants[name].to(device_override),
)
def placeholder(self, target: str, args, kwargs):
def placeholder(
self, target: str, args: Tuple[object], kwargs: Dict[str, object]
) -> Union[Expr, TensorBox, None]:
example = super().placeholder(target, args, kwargs)
self.graph_input_names.append(target)
if isinstance(example, SymTypes):
@ -882,7 +925,7 @@ class GraphLowering(torch.fx.Interpreter):
# the first N inputs are weights
sizes, strides = self.static_sizes_strides(example)
else:
sizes, strides = self.symbolic_sizes_strides(example)
sizes, strides = self.symbolic_sizes_strides(example) # type: ignore[assignment]
# TODO(jansel): handle input aliasing
target = self.qualify_name(target)
tensor = TensorBox.create(
@ -912,7 +955,7 @@ class GraphLowering(torch.fx.Interpreter):
self.aligned_inputs.add(target)
return tensor
def call_function(self, target, args, kwargs):
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg]
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
return super().call_function(target, args, kwargs)
@ -923,7 +966,9 @@ class GraphLowering(torch.fx.Interpreter):
# passthrough lowerings from .pattern_matcher
return target(*args, **kwargs)
def get_custom_op_layout_constraints(target, args, kwargs):
def get_custom_op_layout_constraints(
target: torch._ops.OpOverload, args: Any, kwargs: Dict[str, Any]
) -> Tuple[Optional[Callable], Tuple[Any], Dict[str, Any]]: # type: ignore[type-arg]
# Custom operations that require preserving stride order
# which run through implicit fallback must constrain their
# arguments' fx strides
@ -971,8 +1016,8 @@ class GraphLowering(torch.fx.Interpreter):
raise MissingOperatorWithoutDecomp(target, args, kwargs)
try:
log.debug(" via %s", lowerings[target])
out = lowerings[target](*args, **kwargs)
log.debug(" via %s", lowerings[target]) # type: ignore[index]
out = lowerings[target](*args, **kwargs) # type: ignore[index]
return out
except Exception as e:
raise LoweringException(e, target, args, kwargs).with_traceback(
@ -986,9 +1031,11 @@ class GraphLowering(torch.fx.Interpreter):
"""
return len(t.shape) == 1 and t.shape[0] <= 8
def get_attr(self, target, args, kwargs):
def get_attr(
self, target: str, args: Tuple[()], kwargs: Dict[str, object]
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
# this is a constant
value = getattr_recursive(self.module, target)
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
if isinstance(value, torch.fx.GraphModule):
return ir.Subgraph(name=target, graph_module=value)
@ -998,6 +1045,7 @@ class GraphLowering(torch.fx.Interpreter):
self.constant_reprs[target] = ""
return TorchBindObject(target, value)
assert isinstance(value, torch.Tensor)
if (
config.aot_inductor.use_runtime_constant_folding
or config.always_keep_tensor_constants
@ -1017,13 +1065,15 @@ class GraphLowering(torch.fx.Interpreter):
return self.add_tensor_constant(value, target)
def call_module(self, target, args, kwargs):
def call_module(self, target: Any, args: Any, kwargs: Any) -> NoReturn:
raise AssertionError
def call_method(self, target, args, kwargs):
def call_method(self, target: Any, args: Any, kwargs: Any) -> NoReturn:
raise AssertionError
def output(self, target, args, kwargs):
def output(
self, target: str, args: Tuple[object], kwargs: Dict[str, object]
) -> None:
result = super().output(target, args, kwargs)
if not isinstance(result, (tuple, list)):
# nested subgraphs can have singleton outputs
@ -1099,12 +1149,12 @@ class GraphLowering(torch.fx.Interpreter):
self.graph_id if self.graph_id is not None else -1,
)
def finalize(self):
def finalize(self) -> None:
for buf in self.buffers:
buf.decide_layout()
@contextmanager
def set_current_node(self, node: torch.fx.Node):
def set_current_node(self, node: torch.fx.Node): # type: ignore[no-untyped-def]
old = self.current_node
try:
self.current_node = node
@ -1114,9 +1164,9 @@ class GraphLowering(torch.fx.Interpreter):
def try_match_insignificant_strides(
self,
tensor,
tensor: Union[ir.TensorBox, ir.BaseView],
meta_strides_inp: Tuple[Union[int, torch.SymInt], ...],
) -> ir.TensorBox:
) -> Union[ir.TensorBox, ir.BaseView]:
"""
Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant
dimensions - size 0 or 1 - will be updated.
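For intuition (illustrative values, not taken from the diff): a stride on a dimension of size 0 or 1 never affects which elements are addressed, so it can be rewritten to match the eager-mode meta strides without changing the tensor:

import torch

t = torch.arange(4).reshape(1, 4)
a = torch.as_strided(t, (1, 4), (4, 1))
b = torch.as_strided(t, (1, 4), (100, 1))  # dim-0 stride differs but is never multiplied by a nonzero index
assert torch.equal(a, b)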
@ -1135,9 +1185,13 @@ class GraphLowering(torch.fx.Interpreter):
self.sizevars.statically_known_equals(s1, s2)
for s1, s2 in zip(meta_strides, tensor.get_stride())
):
return tensor
return tensor # type: ignore[arg-type]
def significant_strides_equal(shape, meta_strides, tensor_strides):
def significant_strides_equal(
shape: Sequence[Union[Expr, int]],
meta_strides: Sequence[Union[Expr, int]],
tensor_strides: Sequence[Union[Expr, int]],
) -> bool:
for dim, s1, s2 in zip(shape, meta_strides, tensor_strides):
if self.sizevars.statically_known_leq(dim, 1): # type: ignore[arg-type]
continue
@ -1167,8 +1221,8 @@ class GraphLowering(torch.fx.Interpreter):
)
return ir.TensorBox(torch._inductor.ir.ReinterpretView(storage, new_layout))
def run_node(self, n: torch.fx.Node):
def debug(msg):
def run_node(self, n: torch.fx.Node) -> object:
def debug(msg: str) -> None:
log.debug("lowering %s %s", LazyString(n.format_node), msg)
buffer_watermark = len(self.buffers)
@ -1193,7 +1247,7 @@ class GraphLowering(torch.fx.Interpreter):
elif n.op == "call_function" and n.target in layout_constraints:
debug("layout_constraints")
args, kwargs = layout_constraints[n.target](n, *args, **kwargs) # type: ignore[index]
result = self.call_function(n.target, args, kwargs)
result = self.call_function(n.target, args, kwargs) # type: ignore[arg-type]
elif is_magic_method(n.target):
# TODO: this is sus, it probably should be handled in the
# lowerings themselves similarly to sym_size/sym-stride
@ -1392,7 +1446,7 @@ class GraphLowering(torch.fx.Interpreter):
for op in self.operations[operation_watermark:]:
new_unbacked_defs |= op.get_unbacked_symbol_defs()
def format_new_defs():
def format_new_defs() -> str:
r = []
for buf in self.buffers[buffer_watermark:]:
r.append(
@ -1427,7 +1481,7 @@ class GraphLowering(torch.fx.Interpreter):
# This is all doable, it just hasn't been done yet.
shape_env = V.graph.sizevars.shape_env
def make_assert(expr, msg):
def make_assert(expr: Expr, msg: str) -> None:
assert_op = ir.AssertScalar(expr, msg)
self.register_buffer(assert_op, set_name=True)
self.register_operation(assert_op)
@ -1438,7 +1492,7 @@ class GraphLowering(torch.fx.Interpreter):
vr = shape_env.var_to_range[i0]
if not shape_env._default_unspecified_value_range().issubset(vr):
def is_convertible(s):
def is_convertible(s: Expr) -> bool:
if s in (int_oo, -int_oo):
return False
try:
@ -1492,7 +1546,7 @@ class GraphLowering(torch.fx.Interpreter):
return result
def validate_can_generate_cpp_wrapper(self):
def validate_can_generate_cpp_wrapper(self) -> None:
if config.disable_cpp_codegen:
raise CppWrapperCodeGenError("C++ codegen is disabled")
@ -1511,7 +1565,7 @@ class GraphLowering(torch.fx.Interpreter):
if not supported_dtype_of_cpp_wrapper(dtype, self.cuda):
raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}")
def init_wrapper_code(self):
def init_wrapper_code(self) -> None:
self.cuda = "cuda" in self.device_types
if self.cpp_wrapper:
self.validate_can_generate_cpp_wrapper()
@ -1541,7 +1595,7 @@ class GraphLowering(torch.fx.Interpreter):
self.const_module.wrapper_code.src_to_kernel
)
def codegen_with_cpp_wrapper(self):
def codegen_with_cpp_wrapper(self) -> Tuple[str, List[Tuple[int, Node]]]:
"""
For CPU, the cpp wrapper codegen is done in one pass.
For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python
@ -1559,7 +1613,9 @@ class GraphLowering(torch.fx.Interpreter):
if not config.triton.autotune_at_compile_time:
def materialize(x):
def materialize(
x: Union[torch.SymInt, torch.SymFloat, torch.Tensor]
) -> Union[int, float, torch.Tensor]:
if isinstance(x, (torch.SymInt, torch.SymFloat)):
# Need concrete value to run dynamic shapes and tune the result
return x.node.hint
@ -1617,7 +1673,10 @@ class GraphLowering(torch.fx.Interpreter):
# f, the inputs x will be mutated twice in the process:
# once here, and again when running the compiled model;
# this will also lead to a numerically incorrect output
real_inputs[idx] = clone_preserve_strides(real_inputs[idx])
mutated_inp = real_inputs[idx]
assert isinstance(mutated_inp, torch.Tensor)
real_inputs[idx] = clone_preserve_strides(mutated_inp)
del mutated_inp
with torch.utils._python_dispatch._disable_current_modes():
compiled(real_inputs)
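The clone above guards against mutating an input twice: once during this autotuning run and again when the compiled model is executed for real. Illustrative only (not the Inductor API): running an in-place op on the same tensor twice yields a different result than a single run would:

import torch

x = torch.zeros(3)
def step(t):
    t.add_(1)   # in-place mutation, as a mutating compiled graph would do
    return t

step(x)         # autotuning run
step(x)         # real run; x is now 2.0 everywhere instead of the expected 1.0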
@ -1636,7 +1695,7 @@ class GraphLowering(torch.fx.Interpreter):
# cpu
return self.codegen()
def codegen(self):
def codegen(self) -> Tuple[str, List[Tuple[int, Node]]]:
from .scheduler import Scheduler
self.init_wrapper_code()
@ -1650,7 +1709,7 @@ class GraphLowering(torch.fx.Interpreter):
self.wrapper_code.pop_codegened_graph()
return result
def codegen_subgraph(self, parent_graph):
def codegen_subgraph(self, parent_graph: "GraphLowering") -> None:
"""
This is a more compact version of the `codegen()` above
where we codegen this graph as a subgraph of some parent
@ -1669,7 +1728,11 @@ class GraphLowering(torch.fx.Interpreter):
self.scheduler = Scheduler(self.operations)
self.scheduler.codegen()
def count_bytes(self):
def count_bytes(
self,
) -> Tuple[
int, List[Tuple[BaseSchedulerNode, int]], List[Tuple[BaseSchedulerNode, float]]
]:
total_bytes = 0
node_counts = []
node_runtimes = []
@ -1678,15 +1741,16 @@ class GraphLowering(torch.fx.Interpreter):
total_bytes += num_bytes
node_counts.append((node, num_bytes // 4))
node_runtimes.append((node, node.get_estimated_runtime()))
return total_bytes, node_counts, node_runtimes
@staticmethod
def save_output_code(code: str):
def save_output_code(code: str) -> None:
# No-op to be patched for unit tests
pass
@dynamo_timed(phase_name="code_gen", fwd_only=False)
def compile_to_module(self):
def compile_to_module(self) -> ModuleType:
from .codecache import PyCodeCache
code, linemap = (
@ -1696,7 +1760,7 @@ class GraphLowering(torch.fx.Interpreter):
GraphLowering.save_output_code(code)
output_code_log.debug("Output code: \n%s", code)
try:
linemap = [(line_no, node.stack_trace) for line_no, node in linemap]
linemap = [(line_no, node.stack_trace) for line_no, node in linemap] # type: ignore[misc]
key, path = PyCodeCache.write(code)
except Exception:
trace_structured(
@ -1715,12 +1779,12 @@ class GraphLowering(torch.fx.Interpreter):
mod = PyCodeCache.load_by_key_path(
key,
path,
linemap=linemap,
linemap=linemap, # type: ignore[arg-type]
attrs={**self.constants, **self.torchbind_constants},
)
self.cache_key = key
self.cache_path = path
self.cache_linemap = linemap
self.cache_linemap = linemap # type: ignore[assignment]
# Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
# TODO. Revisit this once the logging API is more mature
@ -1735,7 +1799,7 @@ class GraphLowering(torch.fx.Interpreter):
V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
return mod
def compile_to_fn(self):
def compile_to_fn(self) -> Any:
if self.aot_mode:
from .codecache import AotCodeCompiler
@ -1764,7 +1828,7 @@ class GraphLowering(torch.fx.Interpreter):
else:
return self.compile_to_module().call
def get_output_names(self):
def get_output_names(self) -> List[str]:
return [
node.get_name()
for node in self.graph_outputs
@ -1772,7 +1836,7 @@ class GraphLowering(torch.fx.Interpreter):
and not isinstance(node, ir.ShapeAsConstantBuffer)
]
def is_unspec_arg(self, name: str):
def is_unspec_arg(self, name: str) -> bool:
# dynamo wraps unspec variable as 0d CPU tensor,
# need to convert to scalar during codegen (triton only)
return (

View File

@ -16,6 +16,7 @@ from typing import (
Any,
Callable,
ClassVar,
ContextManager,
Dict,
Iterable,
List,
@ -4429,9 +4430,9 @@ class ExternKernel(InputsKernel):
# NOTE: Don't use extract_read_writes here as it fails when
# make_loader() inlines the computation
x_unwrap_view = x.unwrap_view()
x_unwrap_view_fx_node = V.graph.get_buffer(
x_unwrap_view.get_name()
).get_origin_node()
buf = V.graph.get_buffer(x_unwrap_view.get_name())
assert buf is not None
x_unwrap_view_fx_node = buf.get_origin_node()
# Prefer channels last format according to how the format is set from eager.
if (
x_unwrap_view_fx_node is not None
@ -5474,7 +5475,7 @@ class FallbackKernel(ExternKernelAlloc):
self.op_overload = kernel
self.unflatten_args = unflatten_args
self.kwargs = {} if kwargs is None else kwargs
V.graph.warn_fallback(self.python_kernel_name)
V.graph.warn_fallback(self.python_kernel_name) # type: ignore[arg-type]
# args that are aliased
self.alias_names: List[str] = []
@ -5848,8 +5849,8 @@ class FallbackKernel(ExternKernelAlloc):
@classmethod
def create(cls, kernel, *args, **kwargs):
fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,)
context = (
V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext()
context: ContextManager[None] = (
V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext() # type: ignore[assignment]
)
with context:
(

View File

@ -1654,7 +1654,7 @@ def _warn_complex_not_supported():
# There are some types (CPU) which we accept as input but not as
# output.
def unsupported_input_tensor(t: torch._subclasses.FakeTensor, parent=None):
def unsupported_input_tensor(t: torch.Tensor, parent=None):
"Do not support reading or writing to this tensor"
if t.is_complex():
# Complex views are supported with IR ComplexView
@ -1668,7 +1668,7 @@ def unsupported_input_tensor(t: torch._subclasses.FakeTensor, parent=None):
return False
def unsupported_output_tensor(t: torch._subclasses.FakeTensor, parent=None):
def unsupported_output_tensor(t: torch.Tensor, parent=None):
"Do not support writing tensor but can read from it"
if unsupported_input_tensor(t, parent):
return True

View File

@ -8,7 +8,7 @@ import os
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, List, Set, Tuple, TYPE_CHECKING, Union
from typing import Dict, List, Set, Tuple, TYPE_CHECKING
from torch._inductor import config
from torch._inductor.utils import get_benchmark_name
@ -16,12 +16,7 @@ from torch._inductor.utils import get_benchmark_name
# Prevent circular import
if TYPE_CHECKING:
from torch._inductor.scheduler import (
BaseSchedulerNode,
ExternKernelSchedulerNode,
NopKernelSchedulerNode,
SchedulerNode,
)
from torch._inductor.scheduler import BaseSchedulerNode
# counter for tracking how many kernels have been generated
generated_kernel_count = 0
@ -29,7 +24,7 @@ generated_cpp_vec_kernel_count = 0
num_bytes_accessed = 0
nodes_num_elem: List[
Tuple[
Union[NopKernelSchedulerNode, SchedulerNode, ExternKernelSchedulerNode],
BaseSchedulerNode,
int,
]
] = []

View File

@ -751,7 +751,7 @@ class TritonTemplate(KernelTemplate):
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes),
input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes), # type: ignore[arg-type]
output_tensor_meta=TensorMeta.from_irnodes(layout),
)

View File

@ -300,7 +300,9 @@ class SizeVarAllocator:
return False
def statically_known_equals(self, left: Expr, right: Union[Expr, int]) -> bool:
def statically_known_equals(
self, left: Union[Expr, int], right: Union[Expr, int]
) -> bool:
"""
Returns a bool indicating if it is sound to optimize as if left and right are equal.
"""