mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[aoti][mps] Improve tabbing in cpp generation (#158351)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158351
Approved by: https://github.com/desertfire, https://github.com/malfet
ghstack dependencies: #158349, #158350
This commit is contained in:
committed by
PyTorch MergeBot
parent
84058d1179
commit
cc372ad557
@ -81,11 +81,11 @@ class CppWrapperMps(CppWrapperGpu):
|
||||
for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])):
|
||||
if isinstance(arg_type, torch.dtype):
|
||||
new_args.append(
|
||||
f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});\n"
|
||||
f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});"
|
||||
)
|
||||
elif arg_type in (int, sympy.core.symbol.Symbol):
|
||||
new_args.append(
|
||||
f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});\n"
|
||||
f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});"
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
@ -110,28 +110,26 @@ class CppWrapperMps(CppWrapperGpu):
|
||||
"cpp",
|
||||
)
|
||||
with debug_printer_manager:
|
||||
self.writeline(self.wrap_mps_kernel_call(kernel_name, new_args))
|
||||
|
||||
def wrap_mps_kernel_call(self, name: str, call_args: list[str]) -> str:
|
||||
lib_name = name[: -len("_func")]
|
||||
calling_args = " ".join(call_args)
|
||||
|
||||
kernel_call_str = ""
|
||||
self.write_mps_kernel_call(kernel_name, new_args)
|
||||
|
||||
def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None:
|
||||
# Only add handle definition if the kernel is not already used
|
||||
lib_name = name[: -len("_func")]
|
||||
if name not in self._used_kernel_names:
|
||||
self._used_kernel_names.add(name)
|
||||
kernel_call_str += f"""
|
||||
auto {name} = {lib_name}.getKernelFunction("generated_kernel");
|
||||
auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get());
|
||||
"""
|
||||
kernel_call_str += f"""
|
||||
{name}->runCommandBlock([&] {{
|
||||
{name}->startEncoding();
|
||||
{calling_args}
|
||||
}});
|
||||
"""
|
||||
return kernel_call_str
|
||||
|
||||
self.writeline(
|
||||
f'auto {name} = {lib_name}.getKernelFunction("generated_kernel");'
|
||||
)
|
||||
self.writeline(
|
||||
f"auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get());"
|
||||
)
|
||||
|
||||
self.writeline(f"{name}->runCommandBlock([&] {{")
|
||||
self.writeline(f" {name}->startEncoding();")
|
||||
for call_arg in call_args:
|
||||
self.writeline(f" {call_arg}")
|
||||
self.writeline("});")
|
||||
|
||||
@staticmethod
|
||||
def get_device_include_path(device: str) -> str:
|
||||
|
Reference in New Issue
Block a user