[aoti][mps] Improve tabbing in cpp generation (#158351)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158351
Approved by: https://github.com/desertfire, https://github.com/malfet
ghstack dependencies: #158349, #158350
This commit is contained in:
angelayi
2025-07-22 13:38:34 -07:00
committed by PyTorch MergeBot
parent 84058d1179
commit cc372ad557

View File

@ -81,11 +81,11 @@ class CppWrapperMps(CppWrapperGpu):
for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])):
if isinstance(arg_type, torch.dtype):
new_args.append(
f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});\n"
f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});"
)
elif arg_type in (int, sympy.core.symbol.Symbol):
new_args.append(
f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});\n"
f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});"
)
else:
raise NotImplementedError(
@ -110,28 +110,26 @@ class CppWrapperMps(CppWrapperGpu):
"cpp",
)
with debug_printer_manager:
self.writeline(self.wrap_mps_kernel_call(kernel_name, new_args))
def wrap_mps_kernel_call(self, name: str, call_args: list[str]) -> str:
lib_name = name[: -len("_func")]
calling_args = " ".join(call_args)
kernel_call_str = ""
self.write_mps_kernel_call(kernel_name, new_args)
def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None:
# Only add handle definition if the kernel is not already used
lib_name = name[: -len("_func")]
if name not in self._used_kernel_names:
self._used_kernel_names.add(name)
kernel_call_str += f"""
auto {name} = {lib_name}.getKernelFunction("generated_kernel");
auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get());
"""
kernel_call_str += f"""
{name}->runCommandBlock([&] {{
{name}->startEncoding();
{calling_args}
}});
"""
return kernel_call_str
self.writeline(
f'auto {name} = {lib_name}.getKernelFunction("generated_kernel");'
)
self.writeline(
f"auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get());"
)
self.writeline(f"{name}->runCommandBlock([&] {{")
self.writeline(f" {name}->startEncoding();")
for call_arg in call_args:
self.writeline(f" {call_arg}")
self.writeline("});")
@staticmethod
def get_device_include_path(device: str) -> str: