Add suppressions for _inductor/codegen (#165659)

Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete the directory's lines from the project-excludes field in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions and clean up unused suppressions (see the sketch below)
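
For reference, a suppression is just a comment placed on the line above the reported error, naming the pyrefly error code. A minimal sketch (the snippet and error code mirror the first common.py hunk in the diff below):

    # don't put extra parens for strings that are already wrapped in parens
    # pyrefly: ignore # bad-return
    return string
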
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165659
Approved by: https://github.com/oulgen
Author: Maggie Moss
Date: 2025-10-16 21:37:33 +00:00
Committed by: PyTorch MergeBot
Parent: cbc08c8993
Commit: 5641de7b6b
28 changed files with 157 additions and 11 deletions

View File

@ -23,7 +23,7 @@ project-excludes = [
# ==== below will be enabled directory by directory ====
# ==== to test Pyrefly on a specific directory, simply comment it out ====
"torch/_inductor/runtime",
"torch/_inductor/codegen",
"torch/_inductor/codegen/triton.py",
# formatting issues, will turn on after adjusting where suppressions can be
# in import statements
"torch/linalg/__init__.py",

View File

@ -950,6 +950,7 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
or _all_in_parens(string)
):
# don't put extra parens for strings that are already wrapped in parens
# pyrefly: ignore # bad-return
return string
return f"({string})"
@ -1736,7 +1737,9 @@ class KernelArgs:
)
)
for outer, inner in chain(
self.input_buffers.items(), self.output_buffers.items()
# pyrefly: ignore # bad-argument-type
self.input_buffers.items(),
self.output_buffers.items(),
):
if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
continue
@ -2047,6 +2050,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
) -> None:
super().__init__()
if increase_kernel_count:
# pyrefly: ignore # bad-assignment
metrics.generated_kernel_count += 1
self.args = args or KernelArgs()
self.loads = IndentedBuffer()
@ -2113,6 +2117,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
self.compute = compute
self.stores = stores
self.cse = cse
# pyrefly: ignore # unbound-name
if disallow_stores:
assert not sb, "unexpected store inside swap_buffers"
@ -2384,6 +2389,7 @@ class KernelTemplate:
class DetailedTemplateSyntaxError(TemplateSyntaxError):
def __init__(self, original_error: TemplateSyntaxError) -> None:
super().__init__(
# pyrefly: ignore # bad-argument-type
original_error.message,
original_error.lineno,
original_error.name,
@ -2395,6 +2401,7 @@ class KernelTemplate:
error_info = f"Error in template at line {self.lineno}\n"
error_info += f"Error message: {self.message}\n"
if hasattr(self.original_error, "source"):
# pyrefly: ignore # missing-attribute
lines = self.original_error.source.split("\n")
error_info += "Context:\n"
start = max(0, self.lineno - 2)

View File

@ -504,6 +504,7 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)):
return cls(
node1.scheduler,
# pyrefly: ignore # bad-argument-type
(
list(node1.get_outer_nodes())
if type(node1) is OuterLoopFusedSchedulerNode
@ -1716,6 +1717,7 @@ class CppVecOverrides(CppOverrides):
body_vec_var.dtype = dtype
other_vec_var.dtype = dtype
overrides: type[Union[CppOverrides, CppVecOverrides]] = (
# pyrefly: ignore # bad-assignment
V.kernel.overrides
) # type: ignore[has-type]
code.writeline(
@ -1759,6 +1761,7 @@ class CppVecOverrides(CppOverrides):
csevar = V.kernel._load_or_store_non_contiguous( # type: ignore[assignment]
None, index, dtype, V.kernel.compute
)
# pyrefly: ignore # missing-attribute
csevar.update_on_args("index_expr", (expr, dtype), {})
return csevar
@ -2036,6 +2039,7 @@ class CppKernel(Kernel):
# mask's dtype should be bool
mask.dtype = torch.bool
# pyrefly: ignore # bad-assignment
self._load_mask = mask
try:
yield mask
@ -2363,6 +2367,7 @@ class CppKernel(Kernel):
sympy_index_symbol_with_prefix(SymT.XBLOCK, n)
for n in range(len(self.ranges))
]
# pyrefly: ignore # bad-assignment
self.reduction_depth = len(lengths)
return (
self.itervars[: self.reduction_depth],
@ -2610,7 +2615,9 @@ class CppKernel(Kernel):
and end == self.ranges[var_id]
):
end = 1
# pyrefly: ignore # bad-argument-type
conditions.append(f"{var} >= {cexpr_index(start)}")
# pyrefly: ignore # bad-argument-type
conditions.append(f"{var} < {cexpr_index(end)}")
return True
@ -4085,6 +4092,7 @@ class CppKernelProxy(CppKernel):
and (dt := get_output_dtype(_node)) in DTYPE_LOWP_FP
):
# No need to promote to float if all users are ops that accepts lowp fp input
# pyrefly: ignore # bad-argument-type
if all(is_lowp_fp_sink(user, dt) for user in _node.users):
continue
ops = _node.args[0]
@ -4095,12 +4103,14 @@ class CppKernelProxy(CppKernel):
_node.replace_all_uses_with(
to_type_node, lambda n: n is not to_type_node
)
# pyrefly: ignore # bad-assignment
metrics.cpp_to_dtype_count += 1
elif (
_node.target == "store"
and (dt := get_input_dtype(_node)) in DTYPE_LOWP_FP
):
ops, name, _, value_var, _ = _node.args
# pyrefly: ignore # bad-argument-type
if is_lowp_fp_source_no_promote(value_var, dt):
continue
dtype = V.graph.get_dtype(name)
@ -4109,6 +4119,7 @@ class CppKernelProxy(CppKernel):
"to_dtype", args=(ops, value_var, dtype)
)
_node.replace_input_with(value_var, to_type_node)
# pyrefly: ignore # bad-assignment
metrics.cpp_to_dtype_count += 1
elif _node.target == "reduction":
(
@ -4178,6 +4189,7 @@ class CppKernelProxy(CppKernel):
"to_dtype", args=(ops, value_var, src_dtype)
)
_node.replace_input_with(value_var, to_type_node)
# pyrefly: ignore # bad-assignment
metrics.cpp_to_dtype_count += 1
# to_dtype_bitcast act as a lowp fp source:
@ -4196,6 +4208,7 @@ class CppKernelProxy(CppKernel):
_node.replace_all_uses_with(
to_type_node, lambda n: n is not to_type_node
)
# pyrefly: ignore # bad-assignment
metrics.cpp_to_dtype_count += 1
def eliminate_to_dtype(sub_graph: torch.fx.Graph):
@ -4289,6 +4302,7 @@ class CppKernelProxy(CppKernel):
with kernel_group.new_kernel(cls, *args) as kernel:
# Ugly hack to maintain the metrics kernel count since
# we only count in CppKernelProxy, not those contained in it
# pyrefly: ignore # bad-assignment
metrics.generated_kernel_count -= 1
run(kernel)
@ -4360,6 +4374,7 @@ class CppKernelProxy(CppKernel):
)
if len(tiling_indices) == 1:
# pyrefly: ignore # bad-assignment
metrics.generated_cpp_vec_kernel_count += 1
loop = self.loop_nest.tile(tiling_indices[0], factor=tiling_factors[0])
vec_kernel = codegen_kernel(
@ -4386,6 +4401,7 @@ class CppKernelProxy(CppKernel):
and tiling_factors[0] == tiling_factors[1]
)
# pyrefly: ignore # bad-assignment
metrics.generated_cpp_vec_kernel_count += 2
outer_loop = self.loop_nest.tile(
tiling_indices[0], factor=tiling_factors[0]
@ -5134,10 +5150,12 @@ class CppScheduling(BaseScheduling):
contiguous_index_expr = 0
stride = 1
for var, range in reversed(
# pyrefly: ignore # missing-attribute
scheduler_node._body.var_ranges.items()
):
contiguous_index_expr += stride * var
stride *= range
# pyrefly: ignore # missing-attribute
write_index_expr = scheduler_node._body.get_write_expr(
scheduler_buffer.get_name()
)
@ -5206,6 +5224,7 @@ class CppScheduling(BaseScheduling):
)
local_buffers.append(local_buffer_used)
local_to_global_buffers[local_buffer_used.name] = [] # type: ignore[index]
# pyrefly: ignore # index-error
local_to_global_buffers[local_buffer_used.name].append(
global_buffer,
)
@ -5450,6 +5469,7 @@ class CppScheduling(BaseScheduling):
wrapper = V.graph.wrapper_code
debug_handle = set_kernel_post_grad_provenance_tracing(
node_schedule, # type: ignore[arg-type]
# pyrefly: ignore # bad-argument-type
kernel_name,
)
wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
@ -5771,6 +5791,7 @@ class LoopNest:
loop = self.loops[par_depth.start_depth]
loop.parallel = par_depth.parallel_depth
if loop.is_reduction:
# pyrefly: ignore # bad-assignment
metrics.parallel_reduction_count += 1
for i in range(par_depth.start_depth + 1, par_depth.parallel_depth):
self.loops[i].collapsed = True

View File

@ -396,12 +396,15 @@ def transpose_w(W: _T, trans_w: bool) -> _T:
if isinstance(W, ir.IRNode):
if trans_w:
if not isinstance(W, ir.TensorBox):
# pyrefly: ignore # bad-assignment
W = ir.TensorBox(W)
W = L.permute(W, [1, 0])
else:
if trans_w:
assert isinstance(W, torch.Tensor)
# pyrefly: ignore # bad-assignment
W = W.transpose(0, 1)
# pyrefly: ignore # bad-return
return W
@ -412,12 +415,15 @@ def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]:
if B is not None:
if isinstance(B, ir.IRNode):
if not isinstance(B, ir.TensorBox):
# pyrefly: ignore # bad-assignment
B = ir.TensorBox(B)
assert hasattr(X, "get_size")
# pyrefly: ignore # missing-attribute
B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
else:
assert isinstance(B, torch.Tensor)
assert isinstance(X, torch.Tensor)
# pyrefly: ignore # bad-assignment
B = B.expand(X.shape[0], B.shape[-1])
return B
@ -1043,6 +1049,7 @@ class CppGemmTemplate(CppTemplate):
return cls.prep_weight(
new_inputs,
new_layout,
# pyrefly: ignore # bad-argument-type
micro_gemm,
pre_block_weights,
use_int8_fast_compensation_path,
@ -1066,6 +1073,7 @@ class CppGemmTemplate(CppTemplate):
new_input_nodes, _ = cls.prep_weight(
new_input_nodes,
new_layout,
# pyrefly: ignore # bad-argument-type
micro_gemm,
pre_block_weights,
use_int8_fast_compensation_path,
@ -1470,7 +1478,9 @@ class CppGemmTemplate(CppTemplate):
assert isinstance(template_buffer, ir.IRNode)
gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
gemm_output_buffer = ir.Buffer(
name=gemm_output_name, layout=template_buffer.layout
# pyrefly: ignore # missing-attribute
name=gemm_output_name,
layout=template_buffer.layout,
)
current_input_buffer = gemm_output_buffer
for i, creator in enumerate(epilogue_creators):
@ -1481,6 +1491,7 @@ class CppGemmTemplate(CppTemplate):
epilogues.append(
ir.ComputedBuffer(
name=buffer_name,
# pyrefly: ignore # missing-attribute
layout=template_buffer.layout,
data=creator(current_input_buffer),
)
@ -1490,7 +1501,9 @@ class CppGemmTemplate(CppTemplate):
reindexers.append(None)
if i < len(epilogue_creators) - 1:
current_input_buffer = ir.Buffer(
name=buffer_name, layout=template_buffer.layout
# pyrefly: ignore # missing-attribute
name=buffer_name,
layout=template_buffer.layout,
)
assert isinstance(Y, (ir.Buffer, ir.ReinterpretView))
@ -1521,6 +1534,7 @@ class CppGemmTemplate(CppTemplate):
self.n,
self.k,
input_dtype=X.get_dtype(),
# pyrefly: ignore # missing-attribute
input2_dtype=W.get_dtype(),
output_dtype=output_dtype,
compute_dtype=compute_dtype,

View File

@ -183,12 +183,14 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
)
self.act_mapping = act_mapping
self.gemm_grouped_num = gemm_grouped_num
# pyrefly: ignore # bad-override
self.output_node: list[ir.Buffer] = [
ir.Buffer(name="buf_out" + str(idx), layout=layout)
for idx in range(gemm_grouped_num)
]
@classmethod
# pyrefly: ignore # bad-override
def add_choices(
cls,
choices: list[ChoiceCaller],
@ -231,6 +233,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
if isinstance(inputs[idx], torch.Tensor):
W = inputs[idx]
assert isinstance(W, torch.Tensor), "W must be a torch.Tensor"
# pyrefly: ignore # unsupported-operation
new_inputs[idx] = W.to_dense() if W.is_mkldnn else W
return new_inputs, layout_or_out
@ -246,8 +249,10 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
new_input = new_inputs[wgt_idx]
new_inputs[wgt_idx] = transpose_w(new_input, trans_w)
for bias_idx in range(bias_start_idx, len(new_inputs)):
# pyrefly: ignore # bad-argument-type
new_bias = expand_bias(new_inputs[bias_idx], X)
assert new_bias is not None
# pyrefly: ignore # unsupported-operation
new_inputs[bias_idx] = new_bias
return new_inputs, layout_or_out
@ -308,6 +313,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
W_tensor = []
for W_node in W_nodes:
assert W_node.get_name() in V.graph.constants
# pyrefly: ignore # bad-argument-type
W_tensor.append(V.graph.constants[W_node.get_name()])
new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = (
W_tensor # type: ignore[assignment]
@ -324,6 +330,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
template_buffer.inputs[idx] = (
ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
)
# pyrefly: ignore # bad-return
return output
template = DataProcessorTemplateWrapper(
@ -362,6 +369,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
cur_idx = bias_start_idx
for inp_idx in range(self.gemm_grouped_num):
inp = None
# pyrefly: ignore # index-error
if self.has_bias[inp_idx]:
inp = self.input_nodes[cur_idx]
cur_idx += 1
@ -390,6 +398,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
self.n,
self.k,
input_dtype=X_list[0].get_dtype(),
# pyrefly: ignore # missing-attribute
input2_dtype=W_list[0].get_dtype(),
output_dtype=output_dtype,
compute_dtype=compute_dtype,
@ -427,6 +436,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
for x_idx in range(wgt_start_idx):
kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx]
for w_idx in range(self.gemm_grouped_num):
# pyrefly: ignore # unsupported-operation
kernel_args["W" + str(w_idx)] = W_list[w_idx]
for inp_idx in range(self.gemm_grouped_num):
kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx]

View File

@ -85,6 +85,7 @@ class CppTemplate(KernelTemplate):
bmreq = CppBenchmarkRequest(
kernel_name=kernel_name,
input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
# pyrefly: ignore # bad-argument-type
output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
extra_args=extra_args,
source_code=code,
@ -112,6 +113,7 @@ class CppTemplate(KernelTemplate):
kernel_hash_name,
self.name,
self.input_nodes,
# pyrefly: ignore # index-error
self.output_node[0].get_layout()
if isinstance(self.output_node, Iterable)
else self.output_node.get_layout(),

View File

@ -411,6 +411,7 @@ class CppTemplateKernel(CppKernel):
)
epilogue_nodes = scope.localize_nodes(epilogue_nodes)
return self.store_pointwise_nodes(
# pyrefly: ignore # bad-argument-type
dst,
epilogue_nodes, # type: ignore[arg-type]
offsets,
@ -422,6 +423,7 @@ class CppTemplateKernel(CppKernel):
copy = L.copy(dst, src).data.data
with LocalBufferContext(self.args) as scope:
scope.add_local_buffer(src)
# pyrefly: ignore # bad-argument-type
return self.store_pointwise_nodes(dst, [copy])
else:
assert dst.layout == src.layout, f"{dst=}, {src=}"

View File

@ -311,6 +311,7 @@ class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
return res
def store_reduction(self, name, index, value):
# pyrefly: ignore # bad-argument-count
return self._inner.store_reduction(*self.localize(name, index), value)

View File

@ -307,6 +307,7 @@ class DeferredTritonCallWrapper:
f"RAIIC10IValueHandle RAII_{arg_name}(tmp_{arg_name});",
]
)
# pyrefly: ignore # bad-argument-type
total_args.append(f"tmp_{arg_name}")
def process_args_for_input_shape(arg, arg_type, arg_signature=None):
@ -331,6 +332,7 @@ class DeferredTritonCallWrapper:
f"RAIIC10IValueHandle RAII_{arg_name}(tmp_{arg_name});",
]
)
# pyrefly: ignore # bad-argument-type
total_args.append(f"tmp_{arg_name}")
elif (
isinstance(arg_type, type(SymbolicCallArg))
@ -348,6 +350,7 @@ class DeferredTritonCallWrapper:
for arg, arg_type, arg_signature in zip_longest(
call_args, arg_types, arg_signatures
):
# pyrefly: ignore # bad-argument-type
ordered_argsname.append(f'"{arg}"')
process_args_for_input_shape(arg, arg_type, arg_signature)
@ -819,7 +822,9 @@ class CppWrapperGpu(CppWrapperCpu):
if triton:
call_args, arg_types = self.prepare_triton_wrapper_args(
call_args, arg_types
# pyrefly: ignore # bad-argument-type
call_args,
arg_types,
)
wrapper_name = f"call_{kernel_name}"
if wrapper_name not in self._triton_call_wrappers:
@ -843,10 +848,12 @@ class CppWrapperGpu(CppWrapperCpu):
self.writeline(f"{wrapper_name}({', '.join(call_args)});")
else:
casted = []
# pyrefly: ignore # no-matching-overload
for arg_type, arg in zip(arg_types, call_args):
new_arg = arg
if arg_type.endswith("*") and arg != "nullptr":
new_arg = f"{arg}.data_ptr()"
# pyrefly: ignore # bad-argument-type
casted.append(f"({arg_type}){cexpr(new_arg)}")
call_args_str = ", ".join(casted)
self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});")

View File

@ -190,6 +190,7 @@ class CUDACPPScheduling(BaseScheduling):
assert all(n.node is not None for n in nodes), (
"All epilogue nodes should have an IRNode"
)
# pyrefly: ignore # redundant-cast
return cast(
list[BaseSchedulerNode], [n for n in nodes if n.node is not template_node]
)

View File

@ -72,6 +72,7 @@ class CUDATemplate(KernelTemplate):
@classmethod
@functools.lru_cache(None)
# pyrefly: ignore # bad-override
def _template_from_string(cls, source: str) -> Any:
return KernelTemplate._template_from_string(source)

View File

@ -163,6 +163,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
) -> None:
self.example_inputs = example_inputs
self.ast = ast.parse(self.source)
# pyrefly: ignore # missing-attribute
self.visit(self.ast)
cc = int(cuda_env.get_cuda_arch())

View File

@ -470,6 +470,7 @@ class CUDACompileSourceCapturingContext:
self.sources.append(source_code)
return _compile_method_orig(source_code, dst_file_ext)
# pyrefly: ignore # bad-assignment
self._compile_patch = mock.patch(
"torch._inductor.codecache.CUDACodeCache.compile", my_compile
)

View File

@ -286,6 +286,7 @@ class CuteDSLTemplateKernel(Kernel):
# Generate unpacking assignments: in_ptr4 = buffers[0], etc.
unpacking_lines = []
for i, buffer_name in enumerate(tensor_buffers):
# pyrefly: ignore # bad-argument-type
unpacking_lines.append(f"{buffer_name} = buffers[{i}]")
return "\n ".join(unpacking_lines)
@ -493,6 +494,7 @@ class ModificationWrapperCuteDSL(V.WrapperHandler): # type: ignore[name-defined
"""Convert index variable to symbolic form."""
return sympy_index_symbol(str(index_var))
# pyrefly: ignore # bad-override
def store(
self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
) -> str:

View File

@ -274,7 +274,9 @@ class CuteDSLOpOverrides(OpOverrides):
else "mlir_math.absi"
)
return CuteDSLOpOverrides._apply_unary_op(
x, f"cute.TensorSSA({abs_op}({{x}}), {{x}}.shape, {{x}}.dtype)"
# pyrefly: ignore # bad-argument-type
x,
f"cute.TensorSSA({abs_op}({{x}}), {{x}}.shape, {{x}}.dtype)",
)
@staticmethod

View File

@ -43,6 +43,7 @@ class CuteDSLTemplate(KernelTemplate):
@staticmethod
@functools.lru_cache(None)
# pyrefly: ignore # bad-override
def _template_from_string(source: str) -> Any:
return KernelTemplate._template_from_string(source)

View File

@ -636,6 +636,7 @@ class DimensionInfo:
return "hl.Var()"
if replacements:
replacements = {**replacements}
# pyrefly: ignore # missing-attribute
for sym in expr.free_symbols:
if symbol_is_type(sym, SymT.TMP):
assert isinstance(sym, sympy.Symbol)
@ -709,8 +710,10 @@ class HalideKernel(SIMDKernel):
def dtype_to_str(self, dtype: torch.dtype) -> str:
return halide_type(dtype)
# pyrefly: ignore # bad-override
def create_cse_var(self, name, bounds=None, dtype=None, shape=None):
self.body.writeline(f"{name} = hl.Func({name!r})")
# pyrefly: ignore # bad-argument-type
return HalideCSEVariable(name, bounds, dtype, shape)
def finalize_indexing(self, indices: Sequence[sympy.Expr]):
@ -728,6 +731,7 @@ class HalideKernel(SIMDKernel):
self.index_replacements or self.halide_vars or self.reduction_renames
)
size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf) # type: ignore[arg-type]
# pyrefly: ignore # bad-assignment
indices = dict.fromkeys(map(super().prepare_indexing, indices))
all_used_symbols = OrderedSet[Any]()
sym_to_node = {
@ -826,6 +830,7 @@ class HalideKernel(SIMDKernel):
handled_count = len(nodes)
had_fallback = True
sym = sympy_index_symbol(f"h{len(self.halide_vars)}")
# pyrefly: ignore # missing-argument
if tree.is_reduction:
self.reduction_renames[sym] = sympy_index_symbol(
f"hr{len(self.halide_vars)}"
@ -1222,8 +1227,10 @@ class HalideKernel(SIMDKernel):
parts = []
stride = 1
for i, sym in enumerate(self.reduction_renames):
# pyrefly: ignore # bad-argument-type
parts.append(f"{index}[{i}]")
if stride != 1:
# pyrefly: ignore # unsupported-operation
parts[-1] += f"*{stride}"
stride *= self.halide_vars[sym]
self.body.writeline(f"{result_var} = {' + '.join(parts)}")
@ -1576,6 +1583,7 @@ class HalideKernel(SIMDKernel):
hint = self._autoscheduler_workarounds(
V.graph.sizevars.size_hint(dim.size, fallback=1), dims
)
# pyrefly: ignore # bad-argument-type
range_hints.append(f"hl.Range(0, {hint})")
if "out" not in arg.name:
code.writeline(f"{arg.name}.dim({i}).set_min(0)")

View File

@ -516,6 +516,7 @@ class MetalKernel(SIMDKernel):
var = self.args.output(name)
index = self.prepare_indexing(index)
dtype_str = self.dtype_to_str(V.graph.get_dtype(name))
# pyrefly: ignore # missing-argument
reduction_dim = next(t for t in self.range_trees if t.is_reduction)
# Only one thread in the reduction group needs to store the results
line = f"{var}[{self.index_to_str(index)}] = static_cast<{dtype_str}>({value});"
@ -582,6 +583,7 @@ class MetalKernel(SIMDKernel):
reduction_idx = ""
acc_buf_size = 1
for rd in self.range_trees:
# pyrefly: ignore # missing-argument
if not rd.is_reduction:
continue
if reduction_idx:
@ -678,7 +680,10 @@ class MetalKernel(SIMDKernel):
)
idx_val = self._new_idxvar(dtype, default_value=0, is_threadgroup=False) # type: ignore[assignment]
idx_var = next(
t for t in self.range_tree_nodes.values() if t.is_reduction
# pyrefly: ignore # missing-argument
t
for t in self.range_tree_nodes.values()
if t.is_reduction
)
cmp_op = ">" if reduction_type == "argmax" else "<"
nan_suffix = (
@ -745,6 +750,7 @@ class MetalKernel(SIMDKernel):
index_expr = self.rename_indexing(entry.expr)
index_str = self.sexpr(index_expr) # type: ignore[misc]
# pyrefly: ignore # missing-argument
if not entry.is_reduction or (
isinstance(entry.root.numel, sympy.Integer)
and entry.root.numel <= self.max_threadgroup_size
@ -856,7 +862,10 @@ class MetalKernel(SIMDKernel):
if self.inside_reduction:
total_reduction_size = math.prod(
t.numel for t in self.range_trees if t.is_reduction
# pyrefly: ignore # missing-argument
t.numel
for t in self.range_trees
if t.is_reduction
)
# If using dynamic shapes, set the threadgroup size to be the
# max possible size
@ -958,6 +967,7 @@ class MetalKernel(SIMDKernel):
else:
expr = V.graph.wrapper_code.generate_numel_expr(name, tree).inner
# pyrefly: ignore # missing-argument
if not tree.is_reduction or self.inside_reduction:
args.append(str(expr))
arg_types.append(int)
@ -977,6 +987,7 @@ class MetalKernel(SIMDKernel):
threads = [
expr_printer(
sympy.Min(v.numel, self.max_threadgroup_size) # type: ignore[misc]
# pyrefly: ignore # missing-argument
if v.is_reduction
else v.numel
)
@ -992,6 +1003,7 @@ class MetalKernel(SIMDKernel):
if self.inside_reduction:
threads = [
expr_printer(sympy.Min(v.numel, self.max_threadgroup_size)) # type: ignore[misc]
# pyrefly: ignore # missing-argument
if v.is_reduction
else "1"
for v in self.active_range_trees()

View File

@ -306,6 +306,7 @@ class MultiKernelCall:
# manually force a subkernel to ease perf testing
picked_by_config = config.triton.multi_kernel - 2
assert picked_by_config < len(self._kernels)
# pyrefly: ignore # bad-assignment
self.picked_kernel = picked_by_config
elif not self.disable_cache:
self.load_cache()
@ -329,7 +330,9 @@ class MultiKernelCall:
path = self.cache_file_path()
if path.exists():
with path.open() as fd:
# pyrefly: ignore # bad-assignment
self.picked_kernel = int(fd.read())
# pyrefly: ignore # unsupported-operation
assert self.picked_kernel >= 0 and self.picked_kernel < len(
self._kernels
)
@ -599,5 +602,6 @@ class SizeHintMultiKernelCall(MultiKernelCall):
self._dist_heuristic(shape_key, key) if key is not None else 2**62
for key in self._kernel_hints
]
# pyrefly: ignore # bad-assignment
self.picked_kernel = dists.index(min(dists))
self._cache_shape_choice(shape_key, self.picked_kernel)

View File

@ -513,9 +513,11 @@ class CKGroupedConvFwdTemplate(CKTemplate):
arg = f"/* {field_name} */ Tuple<{tuple_elements}>"
else: # tile shape
arg = f"/* {field_name} */ S<{tuple_elements}>"
# pyrefly: ignore # bad-argument-type
template_params.append(arg)
else:
if field_value is not None:
# pyrefly: ignore # bad-argument-type
template_params.append(f"/* {field_name} */ {field_value}")
return self._template_from_string(template_definition).render(
operation_name=op.name(),

View File

@ -590,9 +590,11 @@ class CKGemmTemplate(CKTemplate):
arg = f"/* {field_name} */ Tuple<{tuple_elements}>"
else: # tile shape
arg = f"/* {field_name} */ S<{tuple_elements}>"
# pyrefly: ignore # bad-argument-type
template_params.append(arg)
else:
if field_value is not None:
# pyrefly: ignore # bad-argument-type
template_params.append(f"/* {field_name} */ {field_value}")
operation_name = op.name().replace("(", "").replace(",", "").replace(")", "")
return self._template_from_string(template_definition).render(

View File

@ -187,6 +187,7 @@ class IterationRangesRoot(IterationRanges):
# True if the dimension is implemented as a single program looping over
# the full dimension (currently only used for non-persistent reduction)
# pyrefly: ignore # missing-argument
assert not is_loop or (self.is_reduction and grid_dim is None)
self.is_loop = is_loop
# Index of corresponding dimension on triton tensors
@ -374,6 +375,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
sexpr: Callable[[sympy.Expr], str] = pexpr
kexpr: Callable[[sympy.Expr], str]
allow_block_ptr: bool = False
# pyrefly: ignore # bad-override
kernel_name: str
def __init__(
@ -570,6 +572,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
if tree.tensor_dim is None:
continue
# pyrefly: ignore # missing-argument
if not tree.is_reduction or self.inside_reduction:
sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK"
return sizes
@ -962,7 +965,10 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
def active_range_trees(self) -> list[IterationRangesRoot]:
return [
t for t in self.range_trees if not t.is_reduction or self.inside_reduction
# pyrefly: ignore # missing-argument
t
for t in self.range_trees
if not t.is_reduction or self.inside_reduction
]
def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr:
@ -1110,6 +1116,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
numel = buf_size
dtype = V.graph.get_dtype(arg)
dtype_size = get_dtype_size(dtype)
# pyrefly: ignore # bad-argument-type
nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
return sum(nbytes)
@ -1130,6 +1137,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
argdefs, call_args, _signature, _ = self.args.python_argdefs()
uniform_stride_order = None
# pyrefly: ignore # bad-assignment
for arg_name in call_args:
buf = V.graph.try_get_buffer(arg_name)
if not buf:
@ -1753,11 +1761,13 @@ class SIMDScheduling(BaseScheduling):
for input_name in kernel.named_input_nodes.keys():
subgraph_name = f"<LOAD_INPUT_{input_name}>"
# pyrefly: ignore # missing-attribute
partial_code.finalize_hook(subgraph_name, strict=False)
num_store_subgraphs = kernel.get_store_output_count()
for i in range(num_store_subgraphs):
subgraph_name = kernel._get_store_output_subgraph_name(i)
# pyrefly: ignore # missing-attribute
partial_code.finalize_hook(subgraph_name)
if isinstance(partial_code, str):
@ -1879,6 +1889,7 @@ class SIMDScheduling(BaseScheduling):
only_gen_src_code=True,
)
assert isinstance(src_code, str)
# pyrefly: ignore # bad-argument-type
src_codes.append(src_code)
else:
if size_hint is None:
@ -2708,6 +2719,7 @@ class SIMDScheduling(BaseScheduling):
perf_hint_log.info("possibly bad tiling: %s", ranked_tilings)
# Optionally, prefer tiling into as many dimensions as possible.
# pyrefly: ignore # unbound-name
if config.triton.prefer_nd_tiling:
ranked_tilings = (
cls.get_nd_tilings(node_schedule, numel, reduction_numel)
@ -2757,6 +2769,7 @@ class SIMDScheduling(BaseScheduling):
hint_override=hint_override,
)
# pyrefly: ignore # missing-attribute
src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
return src_code

View File

@ -80,6 +80,7 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
bm_graph_lowering.graph_input_names.append(sym_inp.name)
sym_inputs = [
# pyrefly: ignore # no-matching-overload
int(V.graph.sizevars.shape_env.size_hint(sym_var))
for sym_var in self.sym_inputs
]

View File

@ -379,6 +379,7 @@ class ComboKernel(Kernel):
def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel:
sub_kernel = triton_kernel
# pyrefly: ignore # bad-assignment
metrics.generated_kernel_count -= 1
sub_kernel.args = self.args
sub_kernel.iter_vars_count = self.iter_vars_count
@ -434,10 +435,12 @@ class ComboKernel(Kernel):
assert f"{tree.prefix}numel_{num}" in self.dynamic_shape_args
uniquify_block_sizes.append(f"{tree.prefix}numel")
# pyrefly: ignore # missing-argument
if not tree.is_reduction:
if isinstance(simplified_tree_numel, (Integer, int)):
grid.append(int(simplified_tree_numel))
else:
# pyrefly: ignore # bad-argument-type
grid.append(f"{tree.prefix}numel_{num}")
if tree.is_reduction and sub_kernel.persistent_reduction:
@ -475,8 +478,10 @@ class ComboKernel(Kernel):
if sub_kernel.no_x_dim:
min_x_blocks = x_numels
x_numels = (
# pyrefly: ignore # unsupported-operation
-min_x_blocks
if isinstance(x_numels, int)
# pyrefly: ignore # redundant-cast
else "-" + cast(str, x_numels)
)
else:
@ -606,6 +611,7 @@ class ComboKernel(Kernel):
"device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
"constants": {},
}
# pyrefly: ignore # unsupported-operation
triton_meta["configs"] = [config_of(signature)]
mutated_args = self.get_mutated_args_sub_kernels()
dispatch = self.dispatch_class
@ -684,6 +690,7 @@ class ComboKernel(Kernel):
for sub_kernel in self.sub_kernels:
# TODO: we assume all sub_kernels have the same block size
for tree in sub_kernel.range_trees:
# pyrefly: ignore # missing-argument
if tree.is_reduction and (
not sub_kernel.inside_reduction or sub_kernel.persistent_reduction
):
@ -722,6 +729,7 @@ class ComboKernel(Kernel):
expr = V.graph.wrapper_code.generate_numel_expr(
name, tree, suffix=str(num)
)
# pyrefly: ignore # missing-argument
if not tree.is_reduction or sub_kernel.inside_reduction:
call_args.append(expr)
arg_types.append(type(expr))
@ -733,6 +741,7 @@ class ComboKernel(Kernel):
numel_name = f"{tree.prefix}numel_{num}"
if numel_name not in self.dynamic_shape_args:
continue
# pyrefly: ignore # missing-argument
if not tree.is_reduction or sub_kernel.inside_reduction:
extra_args.append(
str(
@ -1012,6 +1021,7 @@ class ComboKernel(Kernel):
for num, sub_kernel in enumerate(self.sub_kernels):
meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim
for i, tree in enumerate(sub_kernel.range_trees):
# pyrefly: ignore # missing-argument
if not tree.is_reduction:
numel_name = f"{tree.prefix}numel_{num}"
if numel_name in self.dynamic_shape_args:

View File

@ -256,4 +256,5 @@ def config_of(
equal_to_1 = equal_1_arg_indices(args, indices=indices)
# pyrefly: ignore # bad-argument-type
return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)

View File

@ -1115,6 +1115,7 @@ class PythonWrapperCodegen(CodeGen):
return PythonWrapperCodegen()
def set_launcher_fn_name(self) -> None:
# pyrefly: ignore # bad-assignment
self.launcher_fn_name = "call"
def write_constant(self, name: str, hashed: str) -> None:
@ -1251,14 +1252,17 @@ class PythonWrapperCodegen(CodeGen):
self.write_get_raw_stream_header()
def add_meta_once(self, meta: TritonMetaParams) -> str:
# pyrefly: ignore # bad-assignment
meta = repr(meta)
if meta not in self._metas:
var = f"meta{len(self._metas)}"
# pyrefly: ignore # unsupported-operation
self._metas[meta] = var
self.header.writeline(f"{var} = {meta}")
if config.triton.autotune_at_compile_time:
self.kernel_autotune_calls.writeline(f"{var} = {meta}")
self._meta_vars.add(var)
# pyrefly: ignore # index-error
return self._metas[meta]
@cache_on_self
@ -1694,6 +1698,7 @@ class PythonWrapperCodegen(CodeGen):
with self.set_writeline(self.wrapper_call.writeline):
for line in self.lines:
if isinstance(line, WrapperLine):
# pyrefly: ignore # missing-attribute
line.codegen(self.wrapper_call)
else:
self.wrapper_call.writeline(line)
@ -2774,13 +2779,18 @@ class PythonWrapperCodegen(CodeGen):
self,
kernel_name=kernel_name,
call_args=call_args,
# pyrefly: ignore # bad-argument-type
raw_keys=raw_keys,
# pyrefly: ignore # bad-argument-type
raw_args=raw_args,
# pyrefly: ignore # bad-argument-type
arg_types=arg_types,
triton=triton,
# pyrefly: ignore # bad-argument-type
triton_meta=triton_meta,
device=device,
graph_name=V.graph.name,
# pyrefly: ignore # bad-argument-type
original_fxnode_name=original_fxnode_name,
)
)
@ -2901,6 +2911,7 @@ class PythonWrapperCodegen(CodeGen):
reused_args = {}
for i, (arg, arg_type, raw_key, raw_arg) in enumerate(
# pyrefly: ignore # no-matching-overload
zip(call_args, arg_types, raw_keys, raw_args)
):
key = None
@ -3688,6 +3699,7 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
def set_launcher_fn_name(self) -> None:
# This sets up the name of the function containing the launcher code of
# the subgraph.
# pyrefly: ignore # bad-assignment
self.launcher_fn_name = self.subgraph_name
def write_header(self) -> None:

View File

@ -186,6 +186,7 @@ class WrapperFxCodegen(PythonWrapperCodegen):
"""
Get the input nodes corresponding to FX graph placeholders.
"""
# pyrefly: ignore # missing-argument
if V.aot_compilation and not self.is_subgraph:
# AOT graphs must match the signature of the input module.
return {
@ -210,6 +211,7 @@ class WrapperFxCodegen(PythonWrapperCodegen):
graph_inputs=self.get_fx_graph_inputs(),
graph_outputs=self.get_graph_outputs(),
subgms=self.subgms,
# pyrefly: ignore # missing-argument
is_subgraph=self.is_subgraph,
).generate()
@ -992,13 +994,17 @@ class FxConverter:
call_kwargs = {
key: val
for key, val in zip(signature, call_args)
# pyrefly: ignore # missing-attribute
if key not in constants and key not in cfg.kwargs
}
# Add constants stored as Triton metadata, in signature order.
call_kwargs |= constants
new_call_args = [
call_kwargs[key] for key in signature if key not in cfg.kwargs
# pyrefly: ignore # missing-attribute
call_kwargs[key]
for key in signature
if key not in cfg.kwargs
]
# Add Inductor's extra launcher args to the end.
@ -1014,9 +1020,11 @@ class FxConverter:
call_args = add_constants_to_call_args(call_args, kernel_config)
call_args, grid = tuner._interpret_args_grid(call_args, kernel_config)
call_kwargs = dict(zip(signature, call_args))
# pyrefly: ignore # missing-attribute
assert not any(kwarg in kernel_config.kwargs for kwarg in call_kwargs), (
f"kwargs overlap config: {call_kwargs}"
)
# pyrefly: ignore # missing-attribute
call_kwargs.update(kernel_config.kwargs)
# Replace sympy.floor with FloorDiv, to make the expression traceable.

View File

@ -356,7 +356,7 @@ def bucket_all_reduce(
mode: str | None = None,
) -> None:
if bucket_cap_mb_by_bucket_idx is None:
from torch._inductor.fx_passes.bucketing import (
from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute
bucket_cap_mb_by_bucket_idx_default,
)