mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add suppressions for _inductor/codegen (#165659)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

Step 1: delete the relevant lines from the project-excludes field in the pyrefly.toml file.
Step 2: run pyrefly check.
Step 3: add suppressions, then clean up any unused suppressions.

Before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199
After: INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165659
Approved by: https://github.com/oulgen
committed by PyTorch MergeBot
parent cbc08c8993
commit 5641de7b6b
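Each suppression is a line-level "# pyrefly: ignore" comment placed directly above the statement pyrefly flags, naming the error code being silenced. A minimal sketch of the pattern, taken from the metrics counter touched in the diff below (the wrapper function and import form here are illustrative, not part of the change):

    from torch._inductor import metrics

    def count_generated_kernel() -> None:
        # pyrefly: ignore # bad-assignment
        metrics.generated_kernel_count += 1

Step 3 of the test plan then removes any such comments that no longer suppress a reported error.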
@@ -23,7 +23,7 @@ project-excludes = [
# ==== below will be enabled directory by directory ====
# ==== to test Pyrefly on a specific directory, simply comment it out ====
"torch/_inductor/runtime",
"torch/_inductor/codegen",
"torch/_inductor/codegen/triton.py",
# formatting issues, will turn on after adjusting where suppressions can be
# in import statements
"torch/linalg/__init__.py",
@@ -950,6 +950,7 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
or _all_in_parens(string)
):
# don't put extra parens for strings that are already wrapped in parens
# pyrefly: ignore # bad-return
return string
return f"({string})"

@@ -1736,7 +1737,9 @@ class KernelArgs:
)
)
for outer, inner in chain(
self.input_buffers.items(), self.output_buffers.items()
# pyrefly: ignore # bad-argument-type
self.input_buffers.items(),
self.output_buffers.items(),
):
if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
continue

@@ -2047,6 +2050,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
) -> None:
super().__init__()
if increase_kernel_count:
# pyrefly: ignore # bad-assignment
metrics.generated_kernel_count += 1
self.args = args or KernelArgs()
self.loads = IndentedBuffer()

@@ -2113,6 +2117,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
self.compute = compute
self.stores = stores
self.cse = cse
# pyrefly: ignore # unbound-name
if disallow_stores:
assert not sb, "unexpected store inside swap_buffers"

@@ -2384,6 +2389,7 @@ class KernelTemplate:
class DetailedTemplateSyntaxError(TemplateSyntaxError):
def __init__(self, original_error: TemplateSyntaxError) -> None:
super().__init__(
# pyrefly: ignore # bad-argument-type
original_error.message,
original_error.lineno,
original_error.name,

@@ -2395,6 +2401,7 @@ class KernelTemplate:
error_info = f"Error in template at line {self.lineno}\n"
error_info += f"Error message: {self.message}\n"
if hasattr(self.original_error, "source"):
# pyrefly: ignore # missing-attribute
lines = self.original_error.source.split("\n")
error_info += "Context:\n"
start = max(0, self.lineno - 2)
@@ -504,6 +504,7 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)):
return cls(
node1.scheduler,
# pyrefly: ignore # bad-argument-type
(
list(node1.get_outer_nodes())
if type(node1) is OuterLoopFusedSchedulerNode

@@ -1716,6 +1717,7 @@ class CppVecOverrides(CppOverrides):
body_vec_var.dtype = dtype
other_vec_var.dtype = dtype
overrides: type[Union[CppOverrides, CppVecOverrides]] = (
# pyrefly: ignore # bad-assignment
V.kernel.overrides
) # type: ignore[has-type]
code.writeline(

@@ -1759,6 +1761,7 @@ class CppVecOverrides(CppOverrides):
csevar = V.kernel._load_or_store_non_contiguous( # type: ignore[assignment]
None, index, dtype, V.kernel.compute
)
# pyrefly: ignore # missing-attribute
csevar.update_on_args("index_expr", (expr, dtype), {})
return csevar

@@ -2036,6 +2039,7 @@ class CppKernel(Kernel):
# mask's dtype should be bool
mask.dtype = torch.bool

# pyrefly: ignore # bad-assignment
self._load_mask = mask
try:
yield mask

@@ -2363,6 +2367,7 @@ class CppKernel(Kernel):
sympy_index_symbol_with_prefix(SymT.XBLOCK, n)
for n in range(len(self.ranges))
]
# pyrefly: ignore # bad-assignment
self.reduction_depth = len(lengths)
return (
self.itervars[: self.reduction_depth],

@@ -2610,7 +2615,9 @@ class CppKernel(Kernel):
and end == self.ranges[var_id]
):
end = 1
# pyrefly: ignore # bad-argument-type
conditions.append(f"{var} >= {cexpr_index(start)}")
# pyrefly: ignore # bad-argument-type
conditions.append(f"{var} < {cexpr_index(end)}")
return True

@@ -4085,6 +4092,7 @@ class CppKernelProxy(CppKernel):
and (dt := get_output_dtype(_node)) in DTYPE_LOWP_FP
):
# No need to promote to float if all users are ops that accepts lowp fp input
# pyrefly: ignore # bad-argument-type
if all(is_lowp_fp_sink(user, dt) for user in _node.users):
continue
ops = _node.args[0]

@@ -4095,12 +4103,14 @@ class CppKernelProxy(CppKernel):
_node.replace_all_uses_with(
to_type_node, lambda n: n is not to_type_node
)
# pyrefly: ignore # bad-assignment
metrics.cpp_to_dtype_count += 1
elif (
_node.target == "store"
and (dt := get_input_dtype(_node)) in DTYPE_LOWP_FP
):
ops, name, _, value_var, _ = _node.args
# pyrefly: ignore # bad-argument-type
if is_lowp_fp_source_no_promote(value_var, dt):
continue
dtype = V.graph.get_dtype(name)

@@ -4109,6 +4119,7 @@ class CppKernelProxy(CppKernel):
"to_dtype", args=(ops, value_var, dtype)
)
_node.replace_input_with(value_var, to_type_node)
# pyrefly: ignore # bad-assignment
metrics.cpp_to_dtype_count += 1
elif _node.target == "reduction":
(

@@ -4178,6 +4189,7 @@ class CppKernelProxy(CppKernel):
"to_dtype", args=(ops, value_var, src_dtype)
)
_node.replace_input_with(value_var, to_type_node)
# pyrefly: ignore # bad-assignment
metrics.cpp_to_dtype_count += 1

# to_dtype_bitcast act as a lowp fp source:

@@ -4196,6 +4208,7 @@ class CppKernelProxy(CppKernel):
_node.replace_all_uses_with(
to_type_node, lambda n: n is not to_type_node
)
# pyrefly: ignore # bad-assignment
metrics.cpp_to_dtype_count += 1

def eliminate_to_dtype(sub_graph: torch.fx.Graph):

@@ -4289,6 +4302,7 @@ class CppKernelProxy(CppKernel):
with kernel_group.new_kernel(cls, *args) as kernel:
# Ugly hack to maintain the metrics kernel count since
# we only count in CppKernelProxy, not those contained in it
# pyrefly: ignore # bad-assignment
metrics.generated_kernel_count -= 1

run(kernel)

@@ -4360,6 +4374,7 @@ class CppKernelProxy(CppKernel):
)

if len(tiling_indices) == 1:
# pyrefly: ignore # bad-assignment
metrics.generated_cpp_vec_kernel_count += 1
loop = self.loop_nest.tile(tiling_indices[0], factor=tiling_factors[0])
vec_kernel = codegen_kernel(

@@ -4386,6 +4401,7 @@ class CppKernelProxy(CppKernel):
and tiling_factors[0] == tiling_factors[1]
)

# pyrefly: ignore # bad-assignment
metrics.generated_cpp_vec_kernel_count += 2
outer_loop = self.loop_nest.tile(
tiling_indices[0], factor=tiling_factors[0]

@@ -5134,10 +5150,12 @@ class CppScheduling(BaseScheduling):
contiguous_index_expr = 0
stride = 1
for var, range in reversed(
# pyrefly: ignore # missing-attribute
scheduler_node._body.var_ranges.items()
):
contiguous_index_expr += stride * var
stride *= range
# pyrefly: ignore # missing-attribute
write_index_expr = scheduler_node._body.get_write_expr(
scheduler_buffer.get_name()
)

@@ -5206,6 +5224,7 @@ class CppScheduling(BaseScheduling):
)
local_buffers.append(local_buffer_used)
local_to_global_buffers[local_buffer_used.name] = [] # type: ignore[index]
# pyrefly: ignore # index-error
local_to_global_buffers[local_buffer_used.name].append(
global_buffer,
)

@@ -5450,6 +5469,7 @@ class CppScheduling(BaseScheduling):
wrapper = V.graph.wrapper_code
debug_handle = set_kernel_post_grad_provenance_tracing(
node_schedule, # type: ignore[arg-type]
# pyrefly: ignore # bad-argument-type
kernel_name,
)
wrapper.write_provenance_debug_handle(kernel_name, debug_handle)

@@ -5771,6 +5791,7 @@ class LoopNest:
loop = self.loops[par_depth.start_depth]
loop.parallel = par_depth.parallel_depth
if loop.is_reduction:
# pyrefly: ignore # bad-assignment
metrics.parallel_reduction_count += 1
for i in range(par_depth.start_depth + 1, par_depth.parallel_depth):
self.loops[i].collapsed = True
@@ -396,12 +396,15 @@ def transpose_w(W: _T, trans_w: bool) -> _T:
if isinstance(W, ir.IRNode):
if trans_w:
if not isinstance(W, ir.TensorBox):
# pyrefly: ignore # bad-assignment
W = ir.TensorBox(W)
W = L.permute(W, [1, 0])
else:
if trans_w:
assert isinstance(W, torch.Tensor)
# pyrefly: ignore # bad-assignment
W = W.transpose(0, 1)
# pyrefly: ignore # bad-return
return W

@@ -412,12 +415,15 @@ def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]:
if B is not None:
if isinstance(B, ir.IRNode):
if not isinstance(B, ir.TensorBox):
# pyrefly: ignore # bad-assignment
B = ir.TensorBox(B)
assert hasattr(X, "get_size")
# pyrefly: ignore # missing-attribute
B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
else:
assert isinstance(B, torch.Tensor)
assert isinstance(X, torch.Tensor)
# pyrefly: ignore # bad-assignment
B = B.expand(X.shape[0], B.shape[-1])
return B

@@ -1043,6 +1049,7 @@ class CppGemmTemplate(CppTemplate):
return cls.prep_weight(
new_inputs,
new_layout,
# pyrefly: ignore # bad-argument-type
micro_gemm,
pre_block_weights,
use_int8_fast_compensation_path,

@@ -1066,6 +1073,7 @@ class CppGemmTemplate(CppTemplate):
new_input_nodes, _ = cls.prep_weight(
new_input_nodes,
new_layout,
# pyrefly: ignore # bad-argument-type
micro_gemm,
pre_block_weights,
use_int8_fast_compensation_path,

@@ -1470,7 +1478,9 @@ class CppGemmTemplate(CppTemplate):
assert isinstance(template_buffer, ir.IRNode)
gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
gemm_output_buffer = ir.Buffer(
name=gemm_output_name, layout=template_buffer.layout
# pyrefly: ignore # missing-attribute
name=gemm_output_name,
layout=template_buffer.layout,
)
current_input_buffer = gemm_output_buffer
for i, creator in enumerate(epilogue_creators):

@@ -1481,6 +1491,7 @@ class CppGemmTemplate(CppTemplate):
epilogues.append(
ir.ComputedBuffer(
name=buffer_name,
# pyrefly: ignore # missing-attribute
layout=template_buffer.layout,
data=creator(current_input_buffer),
)

@@ -1490,7 +1501,9 @@ class CppGemmTemplate(CppTemplate):
reindexers.append(None)
if i < len(epilogue_creators) - 1:
current_input_buffer = ir.Buffer(
name=buffer_name, layout=template_buffer.layout
# pyrefly: ignore # missing-attribute
name=buffer_name,
layout=template_buffer.layout,
)

assert isinstance(Y, (ir.Buffer, ir.ReinterpretView))

@@ -1521,6 +1534,7 @@ class CppGemmTemplate(CppTemplate):
self.n,
self.k,
input_dtype=X.get_dtype(),
# pyrefly: ignore # missing-attribute
input2_dtype=W.get_dtype(),
output_dtype=output_dtype,
compute_dtype=compute_dtype,
@@ -183,12 +183,14 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
)
self.act_mapping = act_mapping
self.gemm_grouped_num = gemm_grouped_num
# pyrefly: ignore # bad-override
self.output_node: list[ir.Buffer] = [
ir.Buffer(name="buf_out" + str(idx), layout=layout)
for idx in range(gemm_grouped_num)
]

@classmethod
# pyrefly: ignore # bad-override
def add_choices(
cls,
choices: list[ChoiceCaller],

@@ -231,6 +233,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
if isinstance(inputs[idx], torch.Tensor):
W = inputs[idx]
assert isinstance(W, torch.Tensor), "W must be a torch.Tensor"
# pyrefly: ignore # unsupported-operation
new_inputs[idx] = W.to_dense() if W.is_mkldnn else W
return new_inputs, layout_or_out

@@ -246,8 +249,10 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
new_input = new_inputs[wgt_idx]
new_inputs[wgt_idx] = transpose_w(new_input, trans_w)
for bias_idx in range(bias_start_idx, len(new_inputs)):
# pyrefly: ignore # bad-argument-type
new_bias = expand_bias(new_inputs[bias_idx], X)
assert new_bias is not None
# pyrefly: ignore # unsupported-operation
new_inputs[bias_idx] = new_bias
return new_inputs, layout_or_out

@@ -308,6 +313,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
W_tensor = []
for W_node in W_nodes:
assert W_node.get_name() in V.graph.constants
# pyrefly: ignore # bad-argument-type
W_tensor.append(V.graph.constants[W_node.get_name()])
new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = (
W_tensor # type: ignore[assignment]

@@ -324,6 +330,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
template_buffer.inputs[idx] = (
ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
)
# pyrefly: ignore # bad-return
return output

template = DataProcessorTemplateWrapper(

@@ -362,6 +369,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
cur_idx = bias_start_idx
for inp_idx in range(self.gemm_grouped_num):
inp = None
# pyrefly: ignore # index-error
if self.has_bias[inp_idx]:
inp = self.input_nodes[cur_idx]
cur_idx += 1

@@ -390,6 +398,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
self.n,
self.k,
input_dtype=X_list[0].get_dtype(),
# pyrefly: ignore # missing-attribute
input2_dtype=W_list[0].get_dtype(),
output_dtype=output_dtype,
compute_dtype=compute_dtype,

@@ -427,6 +436,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
for x_idx in range(wgt_start_idx):
kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx]
for w_idx in range(self.gemm_grouped_num):
# pyrefly: ignore # unsupported-operation
kernel_args["W" + str(w_idx)] = W_list[w_idx]
for inp_idx in range(self.gemm_grouped_num):
kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx]
@@ -85,6 +85,7 @@ class CppTemplate(KernelTemplate):
bmreq = CppBenchmarkRequest(
kernel_name=kernel_name,
input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
# pyrefly: ignore # bad-argument-type
output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
extra_args=extra_args,
source_code=code,

@@ -112,6 +113,7 @@ class CppTemplate(KernelTemplate):
kernel_hash_name,
self.name,
self.input_nodes,
# pyrefly: ignore # index-error
self.output_node[0].get_layout()
if isinstance(self.output_node, Iterable)
else self.output_node.get_layout(),
@@ -411,6 +411,7 @@ class CppTemplateKernel(CppKernel):
)
epilogue_nodes = scope.localize_nodes(epilogue_nodes)
return self.store_pointwise_nodes(
# pyrefly: ignore # bad-argument-type
dst,
epilogue_nodes, # type: ignore[arg-type]
offsets,

@@ -422,6 +423,7 @@ class CppTemplateKernel(CppKernel):
copy = L.copy(dst, src).data.data
with LocalBufferContext(self.args) as scope:
scope.add_local_buffer(src)
# pyrefly: ignore # bad-argument-type
return self.store_pointwise_nodes(dst, [copy])
else:
assert dst.layout == src.layout, f"{dst=}, {src=}"
@@ -311,6 +311,7 @@ class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
return res

def store_reduction(self, name, index, value):
# pyrefly: ignore # bad-argument-count
return self._inner.store_reduction(*self.localize(name, index), value)
@@ -307,6 +307,7 @@ class DeferredTritonCallWrapper:
f"RAIIC10IValueHandle RAII_{arg_name}(tmp_{arg_name});",
]
)
# pyrefly: ignore # bad-argument-type
total_args.append(f"tmp_{arg_name}")

def process_args_for_input_shape(arg, arg_type, arg_signature=None):

@@ -331,6 +332,7 @@ class DeferredTritonCallWrapper:
f"RAIIC10IValueHandle RAII_{arg_name}(tmp_{arg_name});",
]
)
# pyrefly: ignore # bad-argument-type
total_args.append(f"tmp_{arg_name}")
elif (
isinstance(arg_type, type(SymbolicCallArg))

@@ -348,6 +350,7 @@ class DeferredTritonCallWrapper:
for arg, arg_type, arg_signature in zip_longest(
call_args, arg_types, arg_signatures
):
# pyrefly: ignore # bad-argument-type
ordered_argsname.append(f'"{arg}"')
process_args_for_input_shape(arg, arg_type, arg_signature)

@@ -819,7 +822,9 @@ class CppWrapperGpu(CppWrapperCpu):

if triton:
call_args, arg_types = self.prepare_triton_wrapper_args(
call_args, arg_types
# pyrefly: ignore # bad-argument-type
call_args,
arg_types,
)
wrapper_name = f"call_{kernel_name}"
if wrapper_name not in self._triton_call_wrappers:

@@ -843,10 +848,12 @@ class CppWrapperGpu(CppWrapperCpu):
self.writeline(f"{wrapper_name}({', '.join(call_args)});")
else:
casted = []
# pyrefly: ignore # no-matching-overload
for arg_type, arg in zip(arg_types, call_args):
new_arg = arg
if arg_type.endswith("*") and arg != "nullptr":
new_arg = f"{arg}.data_ptr()"
# pyrefly: ignore # bad-argument-type
casted.append(f"({arg_type}){cexpr(new_arg)}")
call_args_str = ", ".join(casted)
self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});")
@@ -190,6 +190,7 @@ class CUDACPPScheduling(BaseScheduling):
assert all(n.node is not None for n in nodes), (
"All epilogue nodes should have an IRNode"
)
# pyrefly: ignore # redundant-cast
return cast(
list[BaseSchedulerNode], [n for n in nodes if n.node is not template_node]
)
@@ -72,6 +72,7 @@ class CUDATemplate(KernelTemplate):

@classmethod
@functools.lru_cache(None)
# pyrefly: ignore # bad-override
def _template_from_string(cls, source: str) -> Any:
return KernelTemplate._template_from_string(source)
@@ -163,6 +163,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
) -> None:
self.example_inputs = example_inputs
self.ast = ast.parse(self.source)
# pyrefly: ignore # missing-attribute
self.visit(self.ast)

cc = int(cuda_env.get_cuda_arch())
@@ -470,6 +470,7 @@ class CUDACompileSourceCapturingContext:
self.sources.append(source_code)
return _compile_method_orig(source_code, dst_file_ext)

# pyrefly: ignore # bad-assignment
self._compile_patch = mock.patch(
"torch._inductor.codecache.CUDACodeCache.compile", my_compile
)
@@ -286,6 +286,7 @@ class CuteDSLTemplateKernel(Kernel):
# Generate unpacking assignments: in_ptr4 = buffers[0], etc.
unpacking_lines = []
for i, buffer_name in enumerate(tensor_buffers):
# pyrefly: ignore # bad-argument-type
unpacking_lines.append(f"{buffer_name} = buffers[{i}]")

return "\n ".join(unpacking_lines)

@@ -493,6 +494,7 @@ class ModificationWrapperCuteDSL(V.WrapperHandler): # type: ignore[name-defined
"""Convert index variable to symbolic form."""
return sympy_index_symbol(str(index_var))

# pyrefly: ignore # bad-override
def store(
self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
) -> str:
@@ -274,7 +274,9 @@ class CuteDSLOpOverrides(OpOverrides):
else "mlir_math.absi"
)
return CuteDSLOpOverrides._apply_unary_op(
x, f"cute.TensorSSA({abs_op}({{x}}), {{x}}.shape, {{x}}.dtype)"
# pyrefly: ignore # bad-argument-type
x,
f"cute.TensorSSA({abs_op}({{x}}), {{x}}.shape, {{x}}.dtype)",
)

@staticmethod
@@ -43,6 +43,7 @@ class CuteDSLTemplate(KernelTemplate):

@staticmethod
@functools.lru_cache(None)
# pyrefly: ignore # bad-override
def _template_from_string(source: str) -> Any:
return KernelTemplate._template_from_string(source)
@@ -636,6 +636,7 @@ class DimensionInfo:
return "hl.Var()"
if replacements:
replacements = {**replacements}
# pyrefly: ignore # missing-attribute
for sym in expr.free_symbols:
if symbol_is_type(sym, SymT.TMP):
assert isinstance(sym, sympy.Symbol)

@@ -709,8 +710,10 @@ class HalideKernel(SIMDKernel):
def dtype_to_str(self, dtype: torch.dtype) -> str:
return halide_type(dtype)

# pyrefly: ignore # bad-override
def create_cse_var(self, name, bounds=None, dtype=None, shape=None):
self.body.writeline(f"{name} = hl.Func({name!r})")
# pyrefly: ignore # bad-argument-type
return HalideCSEVariable(name, bounds, dtype, shape)

def finalize_indexing(self, indices: Sequence[sympy.Expr]):

@@ -728,6 +731,7 @@ class HalideKernel(SIMDKernel):
self.index_replacements or self.halide_vars or self.reduction_renames
)
size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf) # type: ignore[arg-type]
# pyrefly: ignore # bad-assignment
indices = dict.fromkeys(map(super().prepare_indexing, indices))
all_used_symbols = OrderedSet[Any]()
sym_to_node = {

@@ -826,6 +830,7 @@ class HalideKernel(SIMDKernel):
handled_count = len(nodes)
had_fallback = True
sym = sympy_index_symbol(f"h{len(self.halide_vars)}")
# pyrefly: ignore # missing-argument
if tree.is_reduction:
self.reduction_renames[sym] = sympy_index_symbol(
f"hr{len(self.halide_vars)}"

@@ -1222,8 +1227,10 @@ class HalideKernel(SIMDKernel):
parts = []
stride = 1
for i, sym in enumerate(self.reduction_renames):
# pyrefly: ignore # bad-argument-type
parts.append(f"{index}[{i}]")
if stride != 1:
# pyrefly: ignore # unsupported-operation
parts[-1] += f"*{stride}"
stride *= self.halide_vars[sym]
self.body.writeline(f"{result_var} = {' + '.join(parts)}")

@@ -1576,6 +1583,7 @@ class HalideKernel(SIMDKernel):
hint = self._autoscheduler_workarounds(
V.graph.sizevars.size_hint(dim.size, fallback=1), dims
)
# pyrefly: ignore # bad-argument-type
range_hints.append(f"hl.Range(0, {hint})")
if "out" not in arg.name:
code.writeline(f"{arg.name}.dim({i}).set_min(0)")
@@ -516,6 +516,7 @@ class MetalKernel(SIMDKernel):
var = self.args.output(name)
index = self.prepare_indexing(index)
dtype_str = self.dtype_to_str(V.graph.get_dtype(name))
# pyrefly: ignore # missing-argument
reduction_dim = next(t for t in self.range_trees if t.is_reduction)
# Only one thread in the reduction group needs to store the results
line = f"{var}[{self.index_to_str(index)}] = static_cast<{dtype_str}>({value});"

@@ -582,6 +583,7 @@ class MetalKernel(SIMDKernel):
reduction_idx = ""
acc_buf_size = 1
for rd in self.range_trees:
# pyrefly: ignore # missing-argument
if not rd.is_reduction:
continue
if reduction_idx:

@@ -678,7 +680,10 @@ class MetalKernel(SIMDKernel):
)
idx_val = self._new_idxvar(dtype, default_value=0, is_threadgroup=False) # type: ignore[assignment]
idx_var = next(
t for t in self.range_tree_nodes.values() if t.is_reduction
# pyrefly: ignore # missing-argument
t
for t in self.range_tree_nodes.values()
if t.is_reduction
)
cmp_op = ">" if reduction_type == "argmax" else "<"
nan_suffix = (

@@ -745,6 +750,7 @@ class MetalKernel(SIMDKernel):
index_expr = self.rename_indexing(entry.expr)
index_str = self.sexpr(index_expr) # type: ignore[misc]

# pyrefly: ignore # missing-argument
if not entry.is_reduction or (
isinstance(entry.root.numel, sympy.Integer)
and entry.root.numel <= self.max_threadgroup_size

@@ -856,7 +862,10 @@ class MetalKernel(SIMDKernel):

if self.inside_reduction:
total_reduction_size = math.prod(
t.numel for t in self.range_trees if t.is_reduction
# pyrefly: ignore # missing-argument
t.numel
for t in self.range_trees
if t.is_reduction
)
# If using dynamic shapes, set the threadgroup size to be the
# max possible size

@@ -958,6 +967,7 @@ class MetalKernel(SIMDKernel):
else:
expr = V.graph.wrapper_code.generate_numel_expr(name, tree).inner

# pyrefly: ignore # missing-argument
if not tree.is_reduction or self.inside_reduction:
args.append(str(expr))
arg_types.append(int)

@@ -977,6 +987,7 @@ class MetalKernel(SIMDKernel):
threads = [
expr_printer(
sympy.Min(v.numel, self.max_threadgroup_size) # type: ignore[misc]
# pyrefly: ignore # missing-argument
if v.is_reduction
else v.numel
)

@@ -992,6 +1003,7 @@ class MetalKernel(SIMDKernel):
if self.inside_reduction:
threads = [
expr_printer(sympy.Min(v.numel, self.max_threadgroup_size)) # type: ignore[misc]
# pyrefly: ignore # missing-argument
if v.is_reduction
else "1"
for v in self.active_range_trees()
@@ -306,6 +306,7 @@ class MultiKernelCall:
# manually force a subkernel to ease perf testing
picked_by_config = config.triton.multi_kernel - 2
assert picked_by_config < len(self._kernels)
# pyrefly: ignore # bad-assignment
self.picked_kernel = picked_by_config
elif not self.disable_cache:
self.load_cache()

@@ -329,7 +330,9 @@ class MultiKernelCall:
path = self.cache_file_path()
if path.exists():
with path.open() as fd:
# pyrefly: ignore # bad-assignment
self.picked_kernel = int(fd.read())
# pyrefly: ignore # unsupported-operation
assert self.picked_kernel >= 0 and self.picked_kernel < len(
self._kernels
)

@@ -599,5 +602,6 @@ class SizeHintMultiKernelCall(MultiKernelCall):
self._dist_heuristic(shape_key, key) if key is not None else 2**62
for key in self._kernel_hints
]
# pyrefly: ignore # bad-assignment
self.picked_kernel = dists.index(min(dists))
self._cache_shape_choice(shape_key, self.picked_kernel)
@@ -513,9 +513,11 @@ class CKGroupedConvFwdTemplate(CKTemplate):
arg = f"/* {field_name} */ Tuple<{tuple_elements}>"
else: # tile shape
arg = f"/* {field_name} */ S<{tuple_elements}>"
# pyrefly: ignore # bad-argument-type
template_params.append(arg)
else:
if field_value is not None:
# pyrefly: ignore # bad-argument-type
template_params.append(f"/* {field_name} */ {field_value}")
return self._template_from_string(template_definition).render(
operation_name=op.name(),
@@ -590,9 +590,11 @@ class CKGemmTemplate(CKTemplate):
arg = f"/* {field_name} */ Tuple<{tuple_elements}>"
else: # tile shape
arg = f"/* {field_name} */ S<{tuple_elements}>"
# pyrefly: ignore # bad-argument-type
template_params.append(arg)
else:
if field_value is not None:
# pyrefly: ignore # bad-argument-type
template_params.append(f"/* {field_name} */ {field_value}")
operation_name = op.name().replace("(", "").replace(",", "").replace(")", "")
return self._template_from_string(template_definition).render(
@@ -187,6 +187,7 @@ class IterationRangesRoot(IterationRanges):

# True if the dimension is implemented as a single program looping over
# the full dimension (currently only used for non-persistent reduction)
# pyrefly: ignore # missing-argument
assert not is_loop or (self.is_reduction and grid_dim is None)
self.is_loop = is_loop
# Index of corresponding dimension on triton tensors

@@ -374,6 +375,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
sexpr: Callable[[sympy.Expr], str] = pexpr
kexpr: Callable[[sympy.Expr], str]
allow_block_ptr: bool = False
# pyrefly: ignore # bad-override
kernel_name: str

def __init__(

@@ -570,6 +572,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
if tree.tensor_dim is None:
continue

# pyrefly: ignore # missing-argument
if not tree.is_reduction or self.inside_reduction:
sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK"
return sizes

@@ -962,7 +965,10 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):

def active_range_trees(self) -> list[IterationRangesRoot]:
return [
t for t in self.range_trees if not t.is_reduction or self.inside_reduction
# pyrefly: ignore # missing-argument
t
for t in self.range_trees
if not t.is_reduction or self.inside_reduction
]

def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr:

@@ -1110,6 +1116,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
numel = buf_size
dtype = V.graph.get_dtype(arg)
dtype_size = get_dtype_size(dtype)
# pyrefly: ignore # bad-argument-type
nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
return sum(nbytes)

@@ -1130,6 +1137,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):

argdefs, call_args, _signature, _ = self.args.python_argdefs()
uniform_stride_order = None
# pyrefly: ignore # bad-assignment
for arg_name in call_args:
buf = V.graph.try_get_buffer(arg_name)
if not buf:

@@ -1753,11 +1761,13 @@ class SIMDScheduling(BaseScheduling):

for input_name in kernel.named_input_nodes.keys():
subgraph_name = f"<LOAD_INPUT_{input_name}>"
# pyrefly: ignore # missing-attribute
partial_code.finalize_hook(subgraph_name, strict=False)

num_store_subgraphs = kernel.get_store_output_count()
for i in range(num_store_subgraphs):
subgraph_name = kernel._get_store_output_subgraph_name(i)
# pyrefly: ignore # missing-attribute
partial_code.finalize_hook(subgraph_name)

if isinstance(partial_code, str):

@@ -1879,6 +1889,7 @@ class SIMDScheduling(BaseScheduling):
only_gen_src_code=True,
)
assert isinstance(src_code, str)
# pyrefly: ignore # bad-argument-type
src_codes.append(src_code)
else:
if size_hint is None:

@@ -2708,6 +2719,7 @@ class SIMDScheduling(BaseScheduling):
perf_hint_log.info("possibly bad tiling: %s", ranked_tilings)

# Optionally, prefer tiling into as many dimensions as possible.
# pyrefly: ignore # unbound-name
if config.triton.prefer_nd_tiling:
ranked_tilings = (
cls.get_nd_tilings(node_schedule, numel, reduction_numel)

@@ -2757,6 +2769,7 @@ class SIMDScheduling(BaseScheduling):
hint_override=hint_override,
)

# pyrefly: ignore # missing-attribute
src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
return src_code
@@ -80,6 +80,7 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
bm_graph_lowering.graph_input_names.append(sym_inp.name)

sym_inputs = [
# pyrefly: ignore # no-matching-overload
int(V.graph.sizevars.shape_env.size_hint(sym_var))
for sym_var in self.sym_inputs
]
@@ -379,6 +379,7 @@ class ComboKernel(Kernel):

def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel:
sub_kernel = triton_kernel
# pyrefly: ignore # bad-assignment
metrics.generated_kernel_count -= 1
sub_kernel.args = self.args
sub_kernel.iter_vars_count = self.iter_vars_count

@@ -434,10 +435,12 @@ class ComboKernel(Kernel):
assert f"{tree.prefix}numel_{num}" in self.dynamic_shape_args
uniquify_block_sizes.append(f"{tree.prefix}numel")

# pyrefly: ignore # missing-argument
if not tree.is_reduction:
if isinstance(simplified_tree_numel, (Integer, int)):
grid.append(int(simplified_tree_numel))
else:
# pyrefly: ignore # bad-argument-type
grid.append(f"{tree.prefix}numel_{num}")

if tree.is_reduction and sub_kernel.persistent_reduction:

@@ -475,8 +478,10 @@ class ComboKernel(Kernel):
if sub_kernel.no_x_dim:
min_x_blocks = x_numels
x_numels = (
# pyrefly: ignore # unsupported-operation
-min_x_blocks
if isinstance(x_numels, int)
# pyrefly: ignore # redundant-cast
else "-" + cast(str, x_numels)
)
else:

@@ -606,6 +611,7 @@ class ComboKernel(Kernel):
"device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
"constants": {},
}
# pyrefly: ignore # unsupported-operation
triton_meta["configs"] = [config_of(signature)]
mutated_args = self.get_mutated_args_sub_kernels()
dispatch = self.dispatch_class

@@ -684,6 +690,7 @@ class ComboKernel(Kernel):
for sub_kernel in self.sub_kernels:
# TODO: we assume all sub_kernels have the same block size
for tree in sub_kernel.range_trees:
# pyrefly: ignore # missing-argument
if tree.is_reduction and (
not sub_kernel.inside_reduction or sub_kernel.persistent_reduction
):

@@ -722,6 +729,7 @@ class ComboKernel(Kernel):
expr = V.graph.wrapper_code.generate_numel_expr(
name, tree, suffix=str(num)
)
# pyrefly: ignore # missing-argument
if not tree.is_reduction or sub_kernel.inside_reduction:
call_args.append(expr)
arg_types.append(type(expr))

@@ -733,6 +741,7 @@ class ComboKernel(Kernel):
numel_name = f"{tree.prefix}numel_{num}"
if numel_name not in self.dynamic_shape_args:
continue
# pyrefly: ignore # missing-argument
if not tree.is_reduction or sub_kernel.inside_reduction:
extra_args.append(
str(

@@ -1012,6 +1021,7 @@ class ComboKernel(Kernel):
for num, sub_kernel in enumerate(self.sub_kernels):
meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim
for i, tree in enumerate(sub_kernel.range_trees):
# pyrefly: ignore # missing-argument
if not tree.is_reduction:
numel_name = f"{tree.prefix}numel_{num}"
if numel_name in self.dynamic_shape_args:
@@ -256,4 +256,5 @@ def config_of(

equal_to_1 = equal_1_arg_indices(args, indices=indices)

# pyrefly: ignore # bad-argument-type
return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)
@@ -1115,6 +1115,7 @@ class PythonWrapperCodegen(CodeGen):
return PythonWrapperCodegen()

def set_launcher_fn_name(self) -> None:
# pyrefly: ignore # bad-assignment
self.launcher_fn_name = "call"

def write_constant(self, name: str, hashed: str) -> None:

@@ -1251,14 +1252,17 @@ class PythonWrapperCodegen(CodeGen):
self.write_get_raw_stream_header()

def add_meta_once(self, meta: TritonMetaParams) -> str:
# pyrefly: ignore # bad-assignment
meta = repr(meta)
if meta not in self._metas:
var = f"meta{len(self._metas)}"
# pyrefly: ignore # unsupported-operation
self._metas[meta] = var
self.header.writeline(f"{var} = {meta}")
if config.triton.autotune_at_compile_time:
self.kernel_autotune_calls.writeline(f"{var} = {meta}")
self._meta_vars.add(var)
# pyrefly: ignore # index-error
return self._metas[meta]

@cache_on_self

@@ -1694,6 +1698,7 @@ class PythonWrapperCodegen(CodeGen):
with self.set_writeline(self.wrapper_call.writeline):
for line in self.lines:
if isinstance(line, WrapperLine):
# pyrefly: ignore # missing-attribute
line.codegen(self.wrapper_call)
else:
self.wrapper_call.writeline(line)

@@ -2774,13 +2779,18 @@ class PythonWrapperCodegen(CodeGen):
self,
kernel_name=kernel_name,
call_args=call_args,
# pyrefly: ignore # bad-argument-type
raw_keys=raw_keys,
# pyrefly: ignore # bad-argument-type
raw_args=raw_args,
# pyrefly: ignore # bad-argument-type
arg_types=arg_types,
triton=triton,
# pyrefly: ignore # bad-argument-type
triton_meta=triton_meta,
device=device,
graph_name=V.graph.name,
# pyrefly: ignore # bad-argument-type
original_fxnode_name=original_fxnode_name,
)
)

@@ -2901,6 +2911,7 @@ class PythonWrapperCodegen(CodeGen):

reused_args = {}
for i, (arg, arg_type, raw_key, raw_arg) in enumerate(
# pyrefly: ignore # no-matching-overload
zip(call_args, arg_types, raw_keys, raw_args)
):
key = None

@@ -3688,6 +3699,7 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
def set_launcher_fn_name(self) -> None:
# This sets up the name of the function containing the launcher code of
# the subgraph.
# pyrefly: ignore # bad-assignment
self.launcher_fn_name = self.subgraph_name

def write_header(self) -> None:
@@ -186,6 +186,7 @@ class WrapperFxCodegen(PythonWrapperCodegen):
"""
Get the input nodes corresponding to FX graph placeholders.
"""
# pyrefly: ignore # missing-argument
if V.aot_compilation and not self.is_subgraph:
# AOT graphs must match the signature of the input module.
return {

@@ -210,6 +211,7 @@ class WrapperFxCodegen(PythonWrapperCodegen):
graph_inputs=self.get_fx_graph_inputs(),
graph_outputs=self.get_graph_outputs(),
subgms=self.subgms,
# pyrefly: ignore # missing-argument
is_subgraph=self.is_subgraph,
).generate()

@@ -992,13 +994,17 @@ class FxConverter:
call_kwargs = {
key: val
for key, val in zip(signature, call_args)
# pyrefly: ignore # missing-attribute
if key not in constants and key not in cfg.kwargs
}

# Add constants stored as Triton metadata, in signature order.
call_kwargs |= constants
new_call_args = [
call_kwargs[key] for key in signature if key not in cfg.kwargs
# pyrefly: ignore # missing-attribute
call_kwargs[key]
for key in signature
if key not in cfg.kwargs
]

# Add Inductor's extra launcher args to the end.

@@ -1014,9 +1020,11 @@ class FxConverter:
call_args = add_constants_to_call_args(call_args, kernel_config)
call_args, grid = tuner._interpret_args_grid(call_args, kernel_config)
call_kwargs = dict(zip(signature, call_args))
# pyrefly: ignore # missing-attribute
assert not any(kwarg in kernel_config.kwargs for kwarg in call_kwargs), (
f"kwargs overlap config: {call_kwargs}"
)
# pyrefly: ignore # missing-attribute
call_kwargs.update(kernel_config.kwargs)

# Replace sympy.floor with FloorDiv, to make the expression traceable.
@@ -356,7 +356,7 @@ def bucket_all_reduce(
mode: str | None = None,
) -> None:
if bucket_cap_mb_by_bucket_idx is None:
from torch._inductor.fx_passes.bucketing import (
from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute
bucket_cap_mb_by_bucket_idx_default,
)