mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[aoti][mps] Fix deduplication of kernels (#156843)
Previously kernels generated for MPS were not correctly deduplicated, so the compiler would emit multiple copies of the same kernel. Pull Request resolved: https://github.com/pytorch/pytorch/pull/156843 Approved by: https://github.com/desertfire
This commit is contained in:
committed by
PyTorch MergeBot
parent
977abe786d
commit
17dab018e3
@ -6,7 +6,7 @@ import sys
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing import FileCheck, make_tensor
|
||||
from torch.testing._internal.common_dtype import get_all_dtypes
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
@ -274,6 +274,38 @@ class MPSBasicTestsAOTI(TestCase):
|
||||
dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}}
|
||||
self.check_model(m, inp, dynamic_shapes)
|
||||
|
||||
def test_reuse_kernel(self):
    """AOTI-compiled output should contain exactly one copy of a generated
    MPS kernel that is used more than once (kernel deduplication).

    The check is done by grepping the emitted C++ wrapper source for the
    kernel-library lookup and asserting it appears exactly once.
    """

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x, y):
            # Two sin calls interleaved with two mm calls — presumably the
            # two sin ops lower to the same generated Metal kernel, which
            # exercises the deduplication path (TODO confirm upstream).
            a = torch.sin(x)
            b = torch.mm(a, y)
            c = torch.sin(b)
            d = torch.mm(b, c)
            return d

    example_inputs = (
        torch.randn(87, 87, device="mps"),
        torch.randn(87, 87, device="mps"),
    )
    model = Model()

    ep = torch.export.export(model, example_inputs)
    package_path = torch._export.aot_compile(ep.module(), example_inputs)

    # The wrapper should fetch the shared kernel from its shader library
    # exactly once; duplicates would produce additional occurrences.
    target_str = 'mps_lib_0.getKernelFunction("generated_kernel")'
    target_count = 1

    # aot_compile returns a package path; the generated C++ source sits
    # next to it with a .cpp extension.
    with open(os.path.splitext(package_path)[0] + ".cpp") as cpp:
        src_code = cpp.read()
        FileCheck().check_count(
            target_str,
            target_count,
            exactly=True,
        ).run(src_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
@ -3,6 +3,7 @@ from typing import Any, Optional
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from ..ir import GraphPartitionSignature
|
||||
from ..virtualized import V
|
||||
@ -11,6 +12,10 @@ from .wrapper import PythonWrapperCodegen
|
||||
|
||||
|
||||
class CppWrapperMps(CppWrapperGpu):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._used_kernel_names: OrderedSet[str] = OrderedSet()
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
is_subgraph: bool,
|
||||
@ -81,14 +86,23 @@ class CppWrapperMps(CppWrapperGpu):
|
||||
def wrap_kernel_call(self, name: str, call_args: list[str]) -> str:
    """Return the C++ source that invokes the Metal kernel ``name``.

    The first call for a given kernel also emits the lines that look the
    kernel function up in its shader library and wrap it in an
    ``AOTIMetalKernelFunctionHandle``; subsequent calls for the same kernel
    reuse those definitions instead of re-emitting them (deduplication,
    tracked via ``self._used_kernel_names``).

    NOTE(review): the exact leading whitespace inside the emitted C++
    snippets was lost in this view — confirm against the upstream file.
    """
    # Kernel names follow the "<lib_name>_func" convention (see the
    # cpp_wrapper branch in MetalScheduling); strip the suffix to recover
    # the shader-library variable name.
    lib_name = name[: -len("_func")]
    calling_args = " ".join(call_args)

    kernel_call_str = ""

    # Only add handle definition if the kernel is not already used
    if name not in self._used_kernel_names:
        self._used_kernel_names.add(name)
        kernel_call_str += f"""
    auto {name} = {lib_name}.getKernelFunction("generated_kernel");
    auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get());
"""
    # The actual dispatch is emitted on every call.
    kernel_call_str += f"""
    {name}->runCommandBlock([&] {{
        {name}->startEncoding();
        {calling_args}
    }});
"""
    return kernel_call_str
|
||||
|
||||
@staticmethod
|
||||
def get_device_include_path(device: str) -> str:
|
||||
|
@ -972,15 +972,18 @@ class MetalScheduling(SIMDScheduling):
|
||||
mps_lib_name = f"mps_lib_{wrapper.next_kernel_suffix()}"
|
||||
|
||||
if V.graph.cpp_wrapper:
|
||||
src_code = (
|
||||
f"at::native::mps::DynamicMetalShaderLibrary {mps_lib_name}"
|
||||
+ src_code
|
||||
)
|
||||
kernel_name = f"{mps_lib_name}_func"
|
||||
else:
|
||||
kernel_name = f"{mps_lib_name}.generated_kernel"
|
||||
|
||||
wrapper.src_to_kernel[src_code] = kernel_name
|
||||
|
||||
if V.graph.cpp_wrapper:
|
||||
src_code = (
|
||||
f"at::native::mps::DynamicMetalShaderLibrary {mps_lib_name}"
|
||||
+ src_code
|
||||
)
|
||||
|
||||
origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
|
||||
metadata_comment = f"{origins}\n{detailed_origins}"
|
||||
wrapper.define_kernel(mps_lib_name, src_code, metadata_comment, gpu=False)
|
||||
|
Reference in New Issue
Block a user