def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None:
    """Write the C++ lines that launch the MPS kernel ``name``.

    The first time a given ``name`` is seen, two declarations are written
    before the launch: the kernel function object, obtained by calling
    ``getKernelFunction("generated_kernel")`` on the library variable, and
    its ``AOTIMetalKernelFunctionHandle``. ``name`` is then recorded in
    ``self._used_kernel_names`` so later calls skip those declarations.

    Every call then writes a ``runCommandBlock`` lambda that starts encoding
    and contains one line per entry of ``call_args``.

    Args:
        name: Name of the kernel function variable in the generated code.
            The library variable name is ``name`` with a trailing ``_func``
            removed (presumably every kernel name carries that suffix --
            confirm against the caller).
        call_args: Already-formatted C++ statements, written one per line
            inside the command block.
    """
    # Only add the handle definition if the kernel is not already used.
    if name not in self._used_kernel_names:
        self._used_kernel_names.add(name)

        # removesuffix() strips "_func" only when it is really there, instead
        # of blindly dropping the last five characters of the name.
        lib_name = name.removesuffix("_func")
        self.writeline(
            f'auto {name} = {lib_name}.getKernelFunction("generated_kernel");'
        )
        self.writeline(
            f"auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get());"
        )

    self.writeline(f"{name}->runCommandBlock([&] {{")
    self.writeline(f" {name}->startEncoding();")
    for call_arg in call_args:
        self.writeline(f" {call_arg}")
    self.writeline("});")