Add suppressions for _inductor/codegen (#165659)

Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete the directory's lines from the project-excludes field in pyrefly.toml
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165659
Approved by: https://github.com/oulgen
Author: Maggie Moss
Date: 2025-10-16 21:37:33 +00:00
Committed by: PyTorch MergeBot
Parent: cbc08c8993
Commit: 5641de7b6b
28 changed files with 157 additions and 11 deletions

View File

@@ -23,7 +23,7 @@ project-excludes = [
 # ==== below will be enabled directory by directory ====
 # ==== to test Pyrefly on a specific directory, simply comment it out ====
 "torch/_inductor/runtime",
-"torch/_inductor/codegen",
+"torch/_inductor/codegen/triton.py",
 # formatting issues, will turn on after adjusting where suppressions can be
 # in import statements
 "torch/linalg/__init__.py",

View File

@@ -950,6 +950,7 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
 or _all_in_parens(string)
 ):
 # don't put extra parens for strings that are already wrapped in parens
+# pyrefly: ignore # bad-return
 return string
 return f"({string})"
@@ -1736,7 +1737,9 @@ class KernelArgs:
 )
 )
 for outer, inner in chain(
-self.input_buffers.items(), self.output_buffers.items()
+# pyrefly: ignore # bad-argument-type
+self.input_buffers.items(),
+self.output_buffers.items(),
 ):
 if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
 continue
@@ -2047,6 +2050,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
 ) -> None:
 super().__init__()
 if increase_kernel_count:
+# pyrefly: ignore # bad-assignment
 metrics.generated_kernel_count += 1
 self.args = args or KernelArgs()
 self.loads = IndentedBuffer()
@@ -2113,6 +2117,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
 self.compute = compute
 self.stores = stores
 self.cse = cse
+# pyrefly: ignore # unbound-name
 if disallow_stores:
 assert not sb, "unexpected store inside swap_buffers"
@@ -2384,6 +2389,7 @@ class KernelTemplate:
 class DetailedTemplateSyntaxError(TemplateSyntaxError):
 def __init__(self, original_error: TemplateSyntaxError) -> None:
 super().__init__(
+# pyrefly: ignore # bad-argument-type
 original_error.message,
 original_error.lineno,
 original_error.name,
@@ -2395,6 +2401,7 @@ class KernelTemplate:
 error_info = f"Error in template at line {self.lineno}\n"
 error_info += f"Error message: {self.message}\n"
 if hasattr(self.original_error, "source"):
+# pyrefly: ignore # missing-attribute
 lines = self.original_error.source.split("\n")
 error_info += "Context:\n"
 start = max(0, self.lineno - 2)

View File

@@ -504,6 +504,7 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
 if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)):
 return cls(
 node1.scheduler,
+# pyrefly: ignore # bad-argument-type
 (
 list(node1.get_outer_nodes())
 if type(node1) is OuterLoopFusedSchedulerNode
@@ -1716,6 +1717,7 @@ class CppVecOverrides(CppOverrides):
 body_vec_var.dtype = dtype
 other_vec_var.dtype = dtype
 overrides: type[Union[CppOverrides, CppVecOverrides]] = (
+# pyrefly: ignore # bad-assignment
 V.kernel.overrides
 ) # type: ignore[has-type]
 code.writeline(
@@ -1759,6 +1761,7 @@ class CppVecOverrides(CppOverrides):
 csevar = V.kernel._load_or_store_non_contiguous( # type: ignore[assignment]
 None, index, dtype, V.kernel.compute
 )
+# pyrefly: ignore # missing-attribute
 csevar.update_on_args("index_expr", (expr, dtype), {})
 return csevar
@@ -2036,6 +2039,7 @@ class CppKernel(Kernel):
 # mask's dtype should be bool
 mask.dtype = torch.bool
+# pyrefly: ignore # bad-assignment
 self._load_mask = mask
 try:
 yield mask
@@ -2363,6 +2367,7 @@ class CppKernel(Kernel):
 sympy_index_symbol_with_prefix(SymT.XBLOCK, n)
 for n in range(len(self.ranges))
 ]
+# pyrefly: ignore # bad-assignment
 self.reduction_depth = len(lengths)
 return (
 self.itervars[: self.reduction_depth],
@@ -2610,7 +2615,9 @@ class CppKernel(Kernel):
 and end == self.ranges[var_id]
 ):
 end = 1
+# pyrefly: ignore # bad-argument-type
 conditions.append(f"{var} >= {cexpr_index(start)}")
+# pyrefly: ignore # bad-argument-type
 conditions.append(f"{var} < {cexpr_index(end)}")
 return True
@@ -4085,6 +4092,7 @@ class CppKernelProxy(CppKernel):
 and (dt := get_output_dtype(_node)) in DTYPE_LOWP_FP
 ):
 # No need to promote to float if all users are ops that accepts lowp fp input
+# pyrefly: ignore # bad-argument-type
 if all(is_lowp_fp_sink(user, dt) for user in _node.users):
 continue
 ops = _node.args[0]
@@ -4095,12 +4103,14 @@ class CppKernelProxy(CppKernel):
 _node.replace_all_uses_with(
 to_type_node, lambda n: n is not to_type_node
 )
+# pyrefly: ignore # bad-assignment
 metrics.cpp_to_dtype_count += 1
 elif (
 _node.target == "store"
 and (dt := get_input_dtype(_node)) in DTYPE_LOWP_FP
 ):
 ops, name, _, value_var, _ = _node.args
+# pyrefly: ignore # bad-argument-type
 if is_lowp_fp_source_no_promote(value_var, dt):
 continue
 dtype = V.graph.get_dtype(name)
@@ -4109,6 +4119,7 @@ class CppKernelProxy(CppKernel):
 "to_dtype", args=(ops, value_var, dtype)
 )
 _node.replace_input_with(value_var, to_type_node)
+# pyrefly: ignore # bad-assignment
 metrics.cpp_to_dtype_count += 1
 elif _node.target == "reduction":
 (
@@ -4178,6 +4189,7 @@ class CppKernelProxy(CppKernel):
 "to_dtype", args=(ops, value_var, src_dtype)
 )
 _node.replace_input_with(value_var, to_type_node)
+# pyrefly: ignore # bad-assignment
 metrics.cpp_to_dtype_count += 1
 # to_dtype_bitcast act as a lowp fp source:
@@ -4196,6 +4208,7 @@ class CppKernelProxy(CppKernel):
 _node.replace_all_uses_with(
 to_type_node, lambda n: n is not to_type_node
 )
+# pyrefly: ignore # bad-assignment
 metrics.cpp_to_dtype_count += 1
 def eliminate_to_dtype(sub_graph: torch.fx.Graph):
@@ -4289,6 +4302,7 @@ class CppKernelProxy(CppKernel):
 with kernel_group.new_kernel(cls, *args) as kernel:
 # Ugly hack to maintain the metrics kernel count since
 # we only count in CppKernelProxy, not those contained in it
+# pyrefly: ignore # bad-assignment
 metrics.generated_kernel_count -= 1
 run(kernel)
@@ -4360,6 +4374,7 @@ class CppKernelProxy(CppKernel):
 )
 if len(tiling_indices) == 1:
+# pyrefly: ignore # bad-assignment
 metrics.generated_cpp_vec_kernel_count += 1
 loop = self.loop_nest.tile(tiling_indices[0], factor=tiling_factors[0])
 vec_kernel = codegen_kernel(
@@ -4386,6 +4401,7 @@ class CppKernelProxy(CppKernel):
 and tiling_factors[0] == tiling_factors[1]
 )
+# pyrefly: ignore # bad-assignment
 metrics.generated_cpp_vec_kernel_count += 2
 outer_loop = self.loop_nest.tile(
 tiling_indices[0], factor=tiling_factors[0]
@@ -5134,10 +5150,12 @@ class CppScheduling(BaseScheduling):
 contiguous_index_expr = 0
 stride = 1
 for var, range in reversed(
+# pyrefly: ignore # missing-attribute
 scheduler_node._body.var_ranges.items()
 ):
 contiguous_index_expr += stride * var
 stride *= range
+# pyrefly: ignore # missing-attribute
 write_index_expr = scheduler_node._body.get_write_expr(
 scheduler_buffer.get_name()
 )
@@ -5206,6 +5224,7 @@ class CppScheduling(BaseScheduling):
 )
 local_buffers.append(local_buffer_used)
 local_to_global_buffers[local_buffer_used.name] = [] # type: ignore[index]
+# pyrefly: ignore # index-error
 local_to_global_buffers[local_buffer_used.name].append(
 global_buffer,
 )
@@ -5450,6 +5469,7 @@ class CppScheduling(BaseScheduling):
 wrapper = V.graph.wrapper_code
 debug_handle = set_kernel_post_grad_provenance_tracing(
 node_schedule, # type: ignore[arg-type]
+# pyrefly: ignore # bad-argument-type
 kernel_name,
 )
 wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
@@ -5771,6 +5791,7 @@ class LoopNest:
 loop = self.loops[par_depth.start_depth]
 loop.parallel = par_depth.parallel_depth
 if loop.is_reduction:
+# pyrefly: ignore # bad-assignment
 metrics.parallel_reduction_count += 1
 for i in range(par_depth.start_depth + 1, par_depth.parallel_depth):
 self.loops[i].collapsed = True

View File

@@ -396,12 +396,15 @@ def transpose_w(W: _T, trans_w: bool) -> _T:
 if isinstance(W, ir.IRNode):
 if trans_w:
 if not isinstance(W, ir.TensorBox):
+# pyrefly: ignore # bad-assignment
 W = ir.TensorBox(W)
 W = L.permute(W, [1, 0])
 else:
 if trans_w:
 assert isinstance(W, torch.Tensor)
+# pyrefly: ignore # bad-assignment
 W = W.transpose(0, 1)
+# pyrefly: ignore # bad-return
 return W
@@ -412,12 +415,15 @@ def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]:
 if B is not None:
 if isinstance(B, ir.IRNode):
 if not isinstance(B, ir.TensorBox):
+# pyrefly: ignore # bad-assignment
 B = ir.TensorBox(B)
 assert hasattr(X, "get_size")
+# pyrefly: ignore # missing-attribute
 B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
 else:
 assert isinstance(B, torch.Tensor)
 assert isinstance(X, torch.Tensor)
+# pyrefly: ignore # bad-assignment
 B = B.expand(X.shape[0], B.shape[-1])
 return B
@@ -1043,6 +1049,7 @@ class CppGemmTemplate(CppTemplate):
 return cls.prep_weight(
 new_inputs,
 new_layout,
+# pyrefly: ignore # bad-argument-type
 micro_gemm,
 pre_block_weights,
 use_int8_fast_compensation_path,
@@ -1066,6 +1073,7 @@ class CppGemmTemplate(CppTemplate):
 new_input_nodes, _ = cls.prep_weight(
 new_input_nodes,
 new_layout,
+# pyrefly: ignore # bad-argument-type
 micro_gemm,
 pre_block_weights,
 use_int8_fast_compensation_path,
@@ -1470,7 +1478,9 @@ class CppGemmTemplate(CppTemplate):
 assert isinstance(template_buffer, ir.IRNode)
 gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
 gemm_output_buffer = ir.Buffer(
-name=gemm_output_name, layout=template_buffer.layout
+# pyrefly: ignore # missing-attribute
+name=gemm_output_name,
+layout=template_buffer.layout,
 )
 current_input_buffer = gemm_output_buffer
 for i, creator in enumerate(epilogue_creators):
@@ -1481,6 +1491,7 @@ class CppGemmTemplate(CppTemplate):
 epilogues.append(
 ir.ComputedBuffer(
 name=buffer_name,
+# pyrefly: ignore # missing-attribute
 layout=template_buffer.layout,
 data=creator(current_input_buffer),
 )
@@ -1490,7 +1501,9 @@ class CppGemmTemplate(CppTemplate):
 reindexers.append(None)
 if i < len(epilogue_creators) - 1:
 current_input_buffer = ir.Buffer(
-name=buffer_name, layout=template_buffer.layout
+# pyrefly: ignore # missing-attribute
+name=buffer_name,
+layout=template_buffer.layout,
 )
 assert isinstance(Y, (ir.Buffer, ir.ReinterpretView))
@@ -1521,6 +1534,7 @@ class CppGemmTemplate(CppTemplate):
 self.n,
 self.k,
 input_dtype=X.get_dtype(),
+# pyrefly: ignore # missing-attribute
 input2_dtype=W.get_dtype(),
 output_dtype=output_dtype,
 compute_dtype=compute_dtype,

View File

@@ -183,12 +183,14 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
 )
 self.act_mapping = act_mapping
 self.gemm_grouped_num = gemm_grouped_num
+# pyrefly: ignore # bad-override
 self.output_node: list[ir.Buffer] = [
 ir.Buffer(name="buf_out" + str(idx), layout=layout)
 for idx in range(gemm_grouped_num)
 ]
 @classmethod
+# pyrefly: ignore # bad-override
 def add_choices(
 cls,
 choices: list[ChoiceCaller],
@@ -231,6 +233,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
 if isinstance(inputs[idx], torch.Tensor):
 W = inputs[idx]
 assert isinstance(W, torch.Tensor), "W must be a torch.Tensor"
+# pyrefly: ignore # unsupported-operation
 new_inputs[idx] = W.to_dense() if W.is_mkldnn else W
 return new_inputs, layout_or_out
@@ -246,8 +249,10 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
 new_input = new_inputs[wgt_idx]
 new_inputs[wgt_idx] = transpose_w(new_input, trans_w)
 for bias_idx in range(bias_start_idx, len(new_inputs)):
+# pyrefly: ignore # bad-argument-type
 new_bias = expand_bias(new_inputs[bias_idx], X)
 assert new_bias is not None
+# pyrefly: ignore # unsupported-operation
 new_inputs[bias_idx] = new_bias
 return new_inputs, layout_or_out
@@ -308,6 +313,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
 W_tensor = []
 for W_node in W_nodes:
 assert W_node.get_name() in V.graph.constants
+# pyrefly: ignore # bad-argument-type
 W_tensor.append(V.graph.constants[W_node.get_name()])
 new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = (
 W_tensor # type: ignore[assignment]
@@ -324,6 +330,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
 template_buffer.inputs[idx] = (
 ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
 )
+# pyrefly: ignore # bad-return
 return output
 template = DataProcessorTemplateWrapper(
@@ -362,6 +369,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
 cur_idx = bias_start_idx
 for inp_idx in range(self.gemm_grouped_num):
 inp = None
+# pyrefly: ignore # index-error
 if self.has_bias[inp_idx]:
 inp = self.input_nodes[cur_idx]
 cur_idx += 1
@@ -390,6 +398,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
 self.n,
 self.k,
 input_dtype=X_list[0].get_dtype(),
+# pyrefly: ignore # missing-attribute
 input2_dtype=W_list[0].get_dtype(),
 output_dtype=output_dtype,
 compute_dtype=compute_dtype,
@@ -427,6 +436,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
 for x_idx in range(wgt_start_idx):
 kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx]
 for w_idx in range(self.gemm_grouped_num):
+# pyrefly: ignore # unsupported-operation
 kernel_args["W" + str(w_idx)] = W_list[w_idx]
 for inp_idx in range(self.gemm_grouped_num):
 kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx]

View File

@@ -85,6 +85,7 @@ class CppTemplate(KernelTemplate):
 bmreq = CppBenchmarkRequest(
 kernel_name=kernel_name,
 input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
+# pyrefly: ignore # bad-argument-type
 output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
 extra_args=extra_args,
 source_code=code,
@@ -112,6 +113,7 @@ class CppTemplate(KernelTemplate):
 kernel_hash_name,
 self.name,
 self.input_nodes,
+# pyrefly: ignore # index-error
 self.output_node[0].get_layout()
 if isinstance(self.output_node, Iterable)
 else self.output_node.get_layout(),

View File

@@ -411,6 +411,7 @@ class CppTemplateKernel(CppKernel):
 )
 epilogue_nodes = scope.localize_nodes(epilogue_nodes)
 return self.store_pointwise_nodes(
+# pyrefly: ignore # bad-argument-type
 dst,
 epilogue_nodes, # type: ignore[arg-type]
 offsets,
@@ -422,6 +423,7 @@ class CppTemplateKernel(CppKernel):
 copy = L.copy(dst, src).data.data
 with LocalBufferContext(self.args) as scope:
 scope.add_local_buffer(src)
+# pyrefly: ignore # bad-argument-type
 return self.store_pointwise_nodes(dst, [copy])
 else:
 assert dst.layout == src.layout, f"{dst=}, {src=}"

View File

@@ -311,6 +311,7 @@ class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
 return res
 def store_reduction(self, name, index, value):
+# pyrefly: ignore # bad-argument-count
 return self._inner.store_reduction(*self.localize(name, index), value)

View File

@@ -307,6 +307,7 @@ class DeferredTritonCallWrapper:
 f"RAIIC10IValueHandle RAII_{arg_name}(tmp_{arg_name});",
 ]
 )
+# pyrefly: ignore # bad-argument-type
 total_args.append(f"tmp_{arg_name}")
 def process_args_for_input_shape(arg, arg_type, arg_signature=None):
@@ -331,6 +332,7 @@ class DeferredTritonCallWrapper:
 f"RAIIC10IValueHandle RAII_{arg_name}(tmp_{arg_name});",
 ]
 )
+# pyrefly: ignore # bad-argument-type
 total_args.append(f"tmp_{arg_name}")
 elif (
 isinstance(arg_type, type(SymbolicCallArg))
@@ -348,6 +350,7 @@ class DeferredTritonCallWrapper:
 for arg, arg_type, arg_signature in zip_longest(
 call_args, arg_types, arg_signatures
 ):
+# pyrefly: ignore # bad-argument-type
 ordered_argsname.append(f'"{arg}"')
 process_args_for_input_shape(arg, arg_type, arg_signature)
@@ -819,7 +822,9 @@ class CppWrapperGpu(CppWrapperCpu):
 if triton:
 call_args, arg_types = self.prepare_triton_wrapper_args(
-call_args, arg_types
+# pyrefly: ignore # bad-argument-type
+call_args,
+arg_types,
 )
 wrapper_name = f"call_{kernel_name}"
 if wrapper_name not in self._triton_call_wrappers:
@@ -843,10 +848,12 @@ class CppWrapperGpu(CppWrapperCpu):
 self.writeline(f"{wrapper_name}({', '.join(call_args)});")
 else:
 casted = []
+# pyrefly: ignore # no-matching-overload
 for arg_type, arg in zip(arg_types, call_args):
 new_arg = arg
 if arg_type.endswith("*") and arg != "nullptr":
 new_arg = f"{arg}.data_ptr()"
+# pyrefly: ignore # bad-argument-type
 casted.append(f"({arg_type}){cexpr(new_arg)}")
 call_args_str = ", ".join(casted)
 self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});")

View File

@@ -190,6 +190,7 @@ class CUDACPPScheduling(BaseScheduling):
 assert all(n.node is not None for n in nodes), (
 "All epilogue nodes should have an IRNode"
 )
+# pyrefly: ignore # redundant-cast
 return cast(
 list[BaseSchedulerNode], [n for n in nodes if n.node is not template_node]
 )

View File

@@ -72,6 +72,7 @@ class CUDATemplate(KernelTemplate):
 @classmethod
 @functools.lru_cache(None)
+# pyrefly: ignore # bad-override
 def _template_from_string(cls, source: str) -> Any:
 return KernelTemplate._template_from_string(source)

View File

@@ -163,6 +163,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
 ) -> None:
 self.example_inputs = example_inputs
 self.ast = ast.parse(self.source)
+# pyrefly: ignore # missing-attribute
 self.visit(self.ast)
 cc = int(cuda_env.get_cuda_arch())

View File

@@ -470,6 +470,7 @@ class CUDACompileSourceCapturingContext:
 self.sources.append(source_code)
 return _compile_method_orig(source_code, dst_file_ext)
+# pyrefly: ignore # bad-assignment
 self._compile_patch = mock.patch(
 "torch._inductor.codecache.CUDACodeCache.compile", my_compile
 )

View File

@@ -286,6 +286,7 @@ class CuteDSLTemplateKernel(Kernel):
 # Generate unpacking assignments: in_ptr4 = buffers[0], etc.
 unpacking_lines = []
 for i, buffer_name in enumerate(tensor_buffers):
+# pyrefly: ignore # bad-argument-type
 unpacking_lines.append(f"{buffer_name} = buffers[{i}]")
 return "\n ".join(unpacking_lines)
@@ -493,6 +494,7 @@ class ModificationWrapperCuteDSL(V.WrapperHandler): # type: ignore[name-defined
 """Convert index variable to symbolic form."""
 return sympy_index_symbol(str(index_var))
+# pyrefly: ignore # bad-override
 def store(
 self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
 ) -> str:

View File

@@ -274,7 +274,9 @@ class CuteDSLOpOverrides(OpOverrides):
 else "mlir_math.absi"
 )
 return CuteDSLOpOverrides._apply_unary_op(
-x, f"cute.TensorSSA({abs_op}({{x}}), {{x}}.shape, {{x}}.dtype)"
+# pyrefly: ignore # bad-argument-type
+x,
+f"cute.TensorSSA({abs_op}({{x}}), {{x}}.shape, {{x}}.dtype)",
 )
 @staticmethod

View File

@@ -43,6 +43,7 @@ class CuteDSLTemplate(KernelTemplate):
 @staticmethod
 @functools.lru_cache(None)
+# pyrefly: ignore # bad-override
 def _template_from_string(source: str) -> Any:
 return KernelTemplate._template_from_string(source)

View File

@@ -636,6 +636,7 @@ class DimensionInfo:
 return "hl.Var()"
 if replacements:
 replacements = {**replacements}
+# pyrefly: ignore # missing-attribute
 for sym in expr.free_symbols:
 if symbol_is_type(sym, SymT.TMP):
 assert isinstance(sym, sympy.Symbol)
@@ -709,8 +710,10 @@ class HalideKernel(SIMDKernel):
 def dtype_to_str(self, dtype: torch.dtype) -> str:
 return halide_type(dtype)
+# pyrefly: ignore # bad-override
 def create_cse_var(self, name, bounds=None, dtype=None, shape=None):
 self.body.writeline(f"{name} = hl.Func({name!r})")
+# pyrefly: ignore # bad-argument-type
 return HalideCSEVariable(name, bounds, dtype, shape)
 def finalize_indexing(self, indices: Sequence[sympy.Expr]):
@@ -728,6 +731,7 @@ class HalideKernel(SIMDKernel):
 self.index_replacements or self.halide_vars or self.reduction_renames
 )
 size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf) # type: ignore[arg-type]
+# pyrefly: ignore # bad-assignment
 indices = dict.fromkeys(map(super().prepare_indexing, indices))
 all_used_symbols = OrderedSet[Any]()
 sym_to_node = {
@@ -826,6 +830,7 @@ class HalideKernel(SIMDKernel):
 handled_count = len(nodes)
 had_fallback = True
 sym = sympy_index_symbol(f"h{len(self.halide_vars)}")
+# pyrefly: ignore # missing-argument
 if tree.is_reduction:
 self.reduction_renames[sym] = sympy_index_symbol(
 f"hr{len(self.halide_vars)}"
@@ -1222,8 +1227,10 @@ class HalideKernel(SIMDKernel):
 parts = []
 stride = 1
 for i, sym in enumerate(self.reduction_renames):
+# pyrefly: ignore # bad-argument-type
 parts.append(f"{index}[{i}]")
 if stride != 1:
+# pyrefly: ignore # unsupported-operation
 parts[-1] += f"*{stride}"
 stride *= self.halide_vars[sym]
 self.body.writeline(f"{result_var} = {' + '.join(parts)}")
@@ -1576,6 +1583,7 @@ class HalideKernel(SIMDKernel):
 hint = self._autoscheduler_workarounds(
 V.graph.sizevars.size_hint(dim.size, fallback=1), dims
 )
+# pyrefly: ignore # bad-argument-type
 range_hints.append(f"hl.Range(0, {hint})")
 if "out" not in arg.name:
 code.writeline(f"{arg.name}.dim({i}).set_min(0)")

View File

@@ -516,6 +516,7 @@ class MetalKernel(SIMDKernel):
 var = self.args.output(name)
 index = self.prepare_indexing(index)
 dtype_str = self.dtype_to_str(V.graph.get_dtype(name))
+# pyrefly: ignore # missing-argument
 reduction_dim = next(t for t in self.range_trees if t.is_reduction)
 # Only one thread in the reduction group needs to store the results
 line = f"{var}[{self.index_to_str(index)}] = static_cast<{dtype_str}>({value});"
@@ -582,6 +583,7 @@ class MetalKernel(SIMDKernel):
 reduction_idx = ""
 acc_buf_size = 1
 for rd in self.range_trees:
+# pyrefly: ignore # missing-argument
 if not rd.is_reduction:
 continue
 if reduction_idx:
@@ -678,7 +680,10 @@ class MetalKernel(SIMDKernel):
 )
 idx_val = self._new_idxvar(dtype, default_value=0, is_threadgroup=False) # type: ignore[assignment]
 idx_var = next(
-t for t in self.range_tree_nodes.values() if t.is_reduction
+# pyrefly: ignore # missing-argument
+t
+for t in self.range_tree_nodes.values()
+if t.is_reduction
 )
 cmp_op = ">" if reduction_type == "argmax" else "<"
 nan_suffix = (
@@ -745,6 +750,7 @@ class MetalKernel(SIMDKernel):
 index_expr = self.rename_indexing(entry.expr)
 index_str = self.sexpr(index_expr) # type: ignore[misc]
+# pyrefly: ignore # missing-argument
 if not entry.is_reduction or (
 isinstance(entry.root.numel, sympy.Integer)
 and entry.root.numel <= self.max_threadgroup_size
@@ -856,7 +862,10 @@ class MetalKernel(SIMDKernel):
 if self.inside_reduction:
 total_reduction_size = math.prod(
-t.numel for t in self.range_trees if t.is_reduction
+# pyrefly: ignore # missing-argument
+t.numel
+for t in self.range_trees
+if t.is_reduction
 )
 # If using dynamic shapes, set the threadgroup size to be the
 # max possible size
@@ -958,6 +967,7 @@ class MetalKernel(SIMDKernel):
 else:
 expr = V.graph.wrapper_code.generate_numel_expr(name, tree).inner
+# pyrefly: ignore # missing-argument
 if not tree.is_reduction or self.inside_reduction:
 args.append(str(expr))
 arg_types.append(int)
@@ -977,6 +987,7 @@ class MetalKernel(SIMDKernel):
 threads = [
 expr_printer(
 sympy.Min(v.numel, self.max_threadgroup_size) # type: ignore[misc]
+# pyrefly: ignore # missing-argument
 if v.is_reduction
 else v.numel
 )
@@ -992,6 +1003,7 @@ class MetalKernel(SIMDKernel):
 if self.inside_reduction:
 threads = [
 expr_printer(sympy.Min(v.numel, self.max_threadgroup_size)) # type: ignore[misc]
+# pyrefly: ignore # missing-argument
 if v.is_reduction
 else "1"
 for v in self.active_range_trees()

View File

@@ -306,6 +306,7 @@ class MultiKernelCall:
 # manually force a subkernel to ease perf testing
 picked_by_config = config.triton.multi_kernel - 2
 assert picked_by_config < len(self._kernels)
+# pyrefly: ignore # bad-assignment
 self.picked_kernel = picked_by_config
 elif not self.disable_cache:
 self.load_cache()
@@ -329,7 +330,9 @@ class MultiKernelCall:
 path = self.cache_file_path()
 if path.exists():
 with path.open() as fd:
+# pyrefly: ignore # bad-assignment
 self.picked_kernel = int(fd.read())
+# pyrefly: ignore # unsupported-operation
 assert self.picked_kernel >= 0 and self.picked_kernel < len(
 self._kernels
 )
@@ -599,5 +602,6 @@ class SizeHintMultiKernelCall(MultiKernelCall):
 self._dist_heuristic(shape_key, key) if key is not None else 2**62
 for key in self._kernel_hints
 ]
+# pyrefly: ignore # bad-assignment
 self.picked_kernel = dists.index(min(dists))
 self._cache_shape_choice(shape_key, self.picked_kernel)

View File

@@ -513,9 +513,11 @@ class CKGroupedConvFwdTemplate(CKTemplate):
 arg = f"/* {field_name} */ Tuple<{tuple_elements}>"
 else: # tile shape
 arg = f"/* {field_name} */ S<{tuple_elements}>"
+# pyrefly: ignore # bad-argument-type
 template_params.append(arg)
 else:
 if field_value is not None:
+# pyrefly: ignore # bad-argument-type
 template_params.append(f"/* {field_name} */ {field_value}")
 return self._template_from_string(template_definition).render(
 operation_name=op.name(),

View File

@@ -590,9 +590,11 @@ class CKGemmTemplate(CKTemplate):
 arg = f"/* {field_name} */ Tuple<{tuple_elements}>"
 else: # tile shape
 arg = f"/* {field_name} */ S<{tuple_elements}>"
+# pyrefly: ignore # bad-argument-type
 template_params.append(arg)
 else:
 if field_value is not None:
+# pyrefly: ignore # bad-argument-type
 template_params.append(f"/* {field_name} */ {field_value}")
 operation_name = op.name().replace("(", "").replace(",", "").replace(")", "")
 return self._template_from_string(template_definition).render(

View File

@@ -187,6 +187,7 @@ class IterationRangesRoot(IterationRanges):
 # True if the dimension is implemented as a single program looping over
 # the full dimension (currently only used for non-persistent reduction)
+# pyrefly: ignore # missing-argument
 assert not is_loop or (self.is_reduction and grid_dim is None)
 self.is_loop = is_loop
 # Index of corresponding dimension on triton tensors
@@ -374,6 +375,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
 sexpr: Callable[[sympy.Expr], str] = pexpr
 kexpr: Callable[[sympy.Expr], str]
 allow_block_ptr: bool = False
+# pyrefly: ignore # bad-override
 kernel_name: str
 def __init__(
@@ -570,6 +572,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
 if tree.tensor_dim is None:
 continue
+# pyrefly: ignore # missing-argument
 if not tree.is_reduction or self.inside_reduction:
 sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK"
 return sizes
@@ -962,7 +965,10 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
 def active_range_trees(self) -> list[IterationRangesRoot]:
 return [
-t for t in self.range_trees if not t.is_reduction or self.inside_reduction
+# pyrefly: ignore # missing-argument
+t
+for t in self.range_trees
+if not t.is_reduction or self.inside_reduction
 ]
 def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr:
@@ -1110,6 +1116,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
 numel = buf_size
 dtype = V.graph.get_dtype(arg)
 dtype_size = get_dtype_size(dtype)
+# pyrefly: ignore # bad-argument-type
 nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
 return sum(nbytes)
@@ -1130,6 +1137,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
 argdefs, call_args, _signature, _ = self.args.python_argdefs()
 uniform_stride_order = None
+# pyrefly: ignore # bad-assignment
 for arg_name in call_args:
 buf = V.graph.try_get_buffer(arg_name)
 if not buf:
@@ -1753,11 +1761,13 @@ class SIMDScheduling(BaseScheduling):
 for input_name in kernel.named_input_nodes.keys():
 subgraph_name = f"<LOAD_INPUT_{input_name}>"
+# pyrefly: ignore # missing-attribute
 partial_code.finalize_hook(subgraph_name, strict=False)
 num_store_subgraphs = kernel.get_store_output_count()
 for i in range(num_store_subgraphs):
 subgraph_name = kernel._get_store_output_subgraph_name(i)
+# pyrefly: ignore # missing-attribute
 partial_code.finalize_hook(subgraph_name)
 if isinstance(partial_code, str):
@@ -1879,6 +1889,7 @@ class SIMDScheduling(BaseScheduling):
 only_gen_src_code=True,
 )
 assert isinstance(src_code, str)
+# pyrefly: ignore # bad-argument-type
 src_codes.append(src_code)
 else:
 if size_hint is None:
@@ -2708,6 +2719,7 @@ class SIMDScheduling(BaseScheduling):
 perf_hint_log.info("possibly bad tiling: %s", ranked_tilings)
 # Optionally, prefer tiling into as many dimensions as possible.
+# pyrefly: ignore # unbound-name
 if config.triton.prefer_nd_tiling:
 ranked_tilings = (
 cls.get_nd_tilings(node_schedule, numel, reduction_numel)
@@ -2757,6 +2769,7 @@ class SIMDScheduling(BaseScheduling):
 hint_override=hint_override,
 )
+# pyrefly: ignore # missing-attribute
 src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
 return src_code

View File

@@ -80,6 +80,7 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
 bm_graph_lowering.graph_input_names.append(sym_inp.name)
 sym_inputs = [
+# pyrefly: ignore # no-matching-overload
 int(V.graph.sizevars.shape_env.size_hint(sym_var))
 for sym_var in self.sym_inputs
 ]

View File

@@ -379,6 +379,7 @@ class ComboKernel(Kernel):
 def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel:
 sub_kernel = triton_kernel
+# pyrefly: ignore # bad-assignment
 metrics.generated_kernel_count -= 1
 sub_kernel.args = self.args
 sub_kernel.iter_vars_count = self.iter_vars_count
@@ -434,10 +435,12 @@ class ComboKernel(Kernel):
 assert f"{tree.prefix}numel_{num}" in self.dynamic_shape_args
 uniquify_block_sizes.append(f"{tree.prefix}numel")
+# pyrefly: ignore # missing-argument
 if not tree.is_reduction:
 if isinstance(simplified_tree_numel, (Integer, int)):
 grid.append(int(simplified_tree_numel))
 else:
+# pyrefly: ignore # bad-argument-type
 grid.append(f"{tree.prefix}numel_{num}")
 if tree.is_reduction and sub_kernel.persistent_reduction:
@@ -475,8 +478,10 @@ class ComboKernel(Kernel):
 if sub_kernel.no_x_dim:
 min_x_blocks = x_numels
 x_numels = (
+# pyrefly: ignore # unsupported-operation
 -min_x_blocks
 if isinstance(x_numels, int)
+# pyrefly: ignore # redundant-cast
 else "-" + cast(str, x_numels)
 )
 else:
@@ -606,6 +611,7 @@ class ComboKernel(Kernel):
 "device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
 "constants": {},
 }
+# pyrefly: ignore # unsupported-operation
 triton_meta["configs"] = [config_of(signature)]
 mutated_args = self.get_mutated_args_sub_kernels()
 dispatch = self.dispatch_class
@@ -684,6 +690,7 @@ class ComboKernel(Kernel):
 for sub_kernel in self.sub_kernels:
 # TODO: we assume all sub_kernels have the same block size
 for tree in sub_kernel.range_trees:
+# pyrefly: ignore # missing-argument
 if tree.is_reduction and (
 not sub_kernel.inside_reduction or sub_kernel.persistent_reduction
 ):
@@ -722,6 +729,7 @@ class ComboKernel(Kernel):
 expr = V.graph.wrapper_code.generate_numel_expr(
 name, tree, suffix=str(num)
 )
+# pyrefly: ignore # missing-argument
 if not tree.is_reduction or sub_kernel.inside_reduction:
 call_args.append(expr)
 arg_types.append(type(expr))
@@ -733,6 +741,7 @@ class ComboKernel(Kernel):
 numel_name = f"{tree.prefix}numel_{num}"
 if numel_name not in self.dynamic_shape_args:
 continue
+# pyrefly: ignore # missing-argument
 if not tree.is_reduction or sub_kernel.inside_reduction:
 extra_args.append(
 str(
@@ -1012,6 +1021,7 @@ class ComboKernel(Kernel):
 for num, sub_kernel in enumerate(self.sub_kernels):
 meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim
 for i, tree in enumerate(sub_kernel.range_trees):
+# pyrefly: ignore # missing-argument
 if not tree.is_reduction:
 numel_name = f"{tree.prefix}numel_{num}"
 if numel_name in self.dynamic_shape_args:

View File

@@ -256,4 +256,5 @@ def config_of(
 equal_to_1 = equal_1_arg_indices(args, indices=indices)
+# pyrefly: ignore # bad-argument-type
 return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)

View File

@@ -1115,6 +1115,7 @@ class PythonWrapperCodegen(CodeGen):
 return PythonWrapperCodegen()
 def set_launcher_fn_name(self) -> None:
+# pyrefly: ignore # bad-assignment
 self.launcher_fn_name = "call"
 def write_constant(self, name: str, hashed: str) -> None:
@@ -1251,14 +1252,17 @@ class PythonWrapperCodegen(CodeGen):
 self.write_get_raw_stream_header()
 def add_meta_once(self, meta: TritonMetaParams) -> str:
+# pyrefly: ignore # bad-assignment
 meta = repr(meta)
 if meta not in self._metas:
 var = f"meta{len(self._metas)}"
+# pyrefly: ignore # unsupported-operation
 self._metas[meta] = var
 self.header.writeline(f"{var} = {meta}")
 if config.triton.autotune_at_compile_time:
 self.kernel_autotune_calls.writeline(f"{var} = {meta}")
 self._meta_vars.add(var)
+# pyrefly: ignore # index-error
 return self._metas[meta]
 @cache_on_self
@@ -1694,6 +1698,7 @@ class PythonWrapperCodegen(CodeGen):
 with self.set_writeline(self.wrapper_call.writeline):
 for line in self.lines:
 if isinstance(line, WrapperLine):
+# pyrefly: ignore # missing-attribute
 line.codegen(self.wrapper_call)
 else:
 self.wrapper_call.writeline(line)
@@ -2774,13 +2779,18 @@ class PythonWrapperCodegen(CodeGen):
 self,
 kernel_name=kernel_name,
 call_args=call_args,
+# pyrefly: ignore # bad-argument-type
 raw_keys=raw_keys,
+# pyrefly: ignore # bad-argument-type
 raw_args=raw_args,
+# pyrefly: ignore # bad-argument-type
 arg_types=arg_types,
 triton=triton,
+# pyrefly: ignore # bad-argument-type
 triton_meta=triton_meta,
 device=device,
 graph_name=V.graph.name,
+# pyrefly: ignore # bad-argument-type
 original_fxnode_name=original_fxnode_name,
 )
 )
@@ -2901,6 +2911,7 @@ class PythonWrapperCodegen(CodeGen):
 reused_args = {}
 for i, (arg, arg_type, raw_key, raw_arg) in enumerate(
+# pyrefly: ignore # no-matching-overload
 zip(call_args, arg_types, raw_keys, raw_args)
 ):
 key = None
@@ -3688,6 +3699,7 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
 def set_launcher_fn_name(self) -> None:
 # This sets up the name of the function containing the launcher code of
 # the subgraph.
+# pyrefly: ignore # bad-assignment
 self.launcher_fn_name = self.subgraph_name
 def write_header(self) -> None:

View File

@@ -186,6 +186,7 @@ class WrapperFxCodegen(PythonWrapperCodegen):
 """
 Get the input nodes corresponding to FX graph placeholders.
 """
+# pyrefly: ignore # missing-argument
 if V.aot_compilation and not self.is_subgraph:
 # AOT graphs must match the signature of the input module.
 return {
@@ -210,6 +211,7 @@ class WrapperFxCodegen(PythonWrapperCodegen):
 graph_inputs=self.get_fx_graph_inputs(),
 graph_outputs=self.get_graph_outputs(),
 subgms=self.subgms,
+# pyrefly: ignore # missing-argument
 is_subgraph=self.is_subgraph,
 ).generate()
@@ -992,13 +994,17 @@ class FxConverter:
 call_kwargs = {
 key: val
 for key, val in zip(signature, call_args)
+# pyrefly: ignore # missing-attribute
 if key not in constants and key not in cfg.kwargs
 }
 # Add constants stored as Triton metadata, in signature order.
 call_kwargs |= constants
 new_call_args = [
-call_kwargs[key] for key in signature if key not in cfg.kwargs
+# pyrefly: ignore # missing-attribute
+call_kwargs[key]
+for key in signature
+if key not in cfg.kwargs
 ]
 # Add Inductor's extra launcher args to the end.
@@ -1014,9 +1020,11 @@ class FxConverter:
 call_args = add_constants_to_call_args(call_args, kernel_config)
 call_args, grid = tuner._interpret_args_grid(call_args, kernel_config)
 call_kwargs = dict(zip(signature, call_args))
+# pyrefly: ignore # missing-attribute
 assert not any(kwarg in kernel_config.kwargs for kwarg in call_kwargs), (
 f"kwargs overlap config: {call_kwargs}"
 )
+# pyrefly: ignore # missing-attribute
 call_kwargs.update(kernel_config.kwargs)
 # Replace sympy.floor with FloorDiv, to make the expression traceable.

View File

@@ -356,7 +356,7 @@ def bucket_all_reduce(
 mode: str | None = None,
 ) -> None:
 if bucket_cap_mb_by_bucket_idx is None:
-from torch._inductor.fx_passes.bucketing import (
+from torch._inductor.fx_passes.bucketing import (  # pyrefly: ignore # missing-module-attribute
 bucket_cap_mb_by_bucket_idx_default,
 )