[MPS] Add support for Custom Kernels (#100661)

- This change introduces the following APIs for developing custom kernels on the MPS stream (see the usage sketch after this list):
  - `torch::mps::get_command_buffer()`
  - `torch::mps::get_dispatch_queue()`
  - `torch::mps::commit()`
- Add an ObjC test case
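
A minimal usage sketch of the three new APIs, distilled from the ObjC test added in this PR. It assumes an already-built `id<MTLComputePipelineState> pso` and a 1-D float MPS tensor `out` bound at buffer index 0; the function and variable names are illustrative placeholders, not part of this PR:

```objective-c++
// Sketch only: `runCustomKernel`, `pso` and `out` are assumed placeholders.
#include <algorithm>
#include <torch/torch.h>
#import <Metal/Metal.h>

static void runCustomKernel(id<MTLComputePipelineState> pso, const torch::Tensor& out) {
  // Encode into the MPS stream's active command buffer instead of creating our own.
  id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
  // Serialize our encoding against the MPS backend's own encoding work.
  dispatch_sync(torch::mps::get_dispatch_queue(), ^() {
    id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
    [encoder setComputePipelineState: pso];
    [encoder setBuffer: __builtin_bit_cast(id<MTLBuffer>, out.storage().data())
                offset: 0
               atIndex: 0];
    [encoder dispatchThreads: MTLSizeMake(out.numel(), 1, 1)
       threadsPerThreadgroup: MTLSizeMake(std::min(pso.maxTotalThreadsPerThreadgroup,
                                                   (NSUInteger)out.numel()), 1, 1)];
    [encoder endEncoding];
    // Hand the shared command buffer back to the MPS stream for submission.
    torch::mps::commit();
  });
  // Block until the stream finishes before reading `out` on the CPU.
  torch::mps::synchronize();
}
```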
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100661
Approved by: https://github.com/kulinseth, https://github.com/malfet
Ramin Azarmehr
2023-05-08 20:05:46 +00:00
committed by PyTorch MergeBot
parent d9d98b4d54
commit f39cda83d1
14 changed files with 196 additions and 49 deletions

View File

@@ -17,48 +17,55 @@ class Context;
 namespace at {

 struct TORCH_API MPSHooksInterface {
+  // this fails the implementation if MPSHooks functions are called, but
+  // MPS backend is not present.
+  #define FAIL_MPSHOOKS_FUNC(func) \
+    TORCH_CHECK(false, "Cannot execute ", func, "() without MPS backend.");
+
   virtual ~MPSHooksInterface() = default;

   // Initialize the MPS library state
   virtual void initMPS() const {
-    AT_ERROR("Cannot initialize MPS without MPS backend.");
+    FAIL_MPSHOOKS_FUNC(__func__);
   }
   virtual bool hasMPS() const {
     return false;
   }
   virtual bool isOnMacOS13orNewer(unsigned minor = 0) const {
-    AT_ERROR("MPS backend is not available.");
+    FAIL_MPSHOOKS_FUNC(__func__);
   }
   virtual const Generator& getDefaultMPSGenerator() const {
-    AT_ERROR("Cannot get default MPS generator without MPS backend.");
+    FAIL_MPSHOOKS_FUNC(__func__);
   }
   virtual Allocator* getMPSDeviceAllocator() const {
-    AT_ERROR("MPSDeviceAllocator requires MPS.");
+    FAIL_MPSHOOKS_FUNC(__func__);
   }
   virtual void deviceSynchronize() const {
-    AT_ERROR("Cannot synchronize MPS device without MPS backend.");
+    FAIL_MPSHOOKS_FUNC(__func__);
   }
+  virtual void commitStream() const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void* getCommandBuffer() const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void* getDispatchQueue() const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
   virtual void emptyCache() const {
-    AT_ERROR("Cannot execute emptyCache() without MPS backend.");
+    FAIL_MPSHOOKS_FUNC(__func__);
   }
   virtual size_t getCurrentAllocatedMemory() const {
-    AT_ERROR("Cannot execute getCurrentAllocatedMemory() without MPS backend.");
+    FAIL_MPSHOOKS_FUNC(__func__);
   }
   virtual size_t getDriverAllocatedMemory() const {
-    AT_ERROR("Cannot execute getDriverAllocatedMemory() without MPS backend.");
+    FAIL_MPSHOOKS_FUNC(__func__);
   }
   virtual void setMemoryFraction(double /*ratio*/) const {
-    AT_ERROR("Cannot execute setMemoryFraction() without MPS backend.");
+    FAIL_MPSHOOKS_FUNC(__func__);
   }
+  #undef FAIL_MPSHOOKS_FUNC
 };

 struct TORCH_API MPSHooksArgs {};

View File

@@ -80,7 +80,6 @@ class TORCH_API MPSDevice {
 TORCH_API bool is_available();
 TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
-TORCH_API void device_synchronize();
 TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);

 } // namespace mps

View File

@@ -141,9 +141,5 @@ bool is_macos_13_or_newer(MacOSVersion version) {
   return MPSDevice::getInstance()->isMacOS13Plus(version);
 }

-void device_synchronize() {
-  getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
-}
-
 } // namespace mps
 } // namespace at

View File

@@ -12,11 +12,22 @@ namespace at { namespace mps {
 struct MPSHooks : public at::MPSHooksInterface {
   MPSHooks(at::MPSHooksArgs) {}
   void initMPS() const override;
+
+  // MPSDevice interface
   bool hasMPS() const override;
   bool isOnMacOS13orNewer(unsigned minor) const override;
-  Allocator* getMPSDeviceAllocator() const override;
+
+  // MPSGeneratorImpl interface
   const Generator& getDefaultMPSGenerator() const override;
+
+  // MPSStream interface
   void deviceSynchronize() const override;
+  void commitStream() const override;
+  void* getCommandBuffer() const override;
+  void* getDispatchQueue() const override;
+
+  // MPSAllocator interface
+  Allocator* getMPSDeviceAllocator() const override;
   void emptyCache() const override;
   size_t getCurrentAllocatedMemory() const override;
   size_t getDriverAllocatedMemory() const override;

View File

@@ -1,9 +1,10 @@
 //  Copyright © 2022 Apple Inc.

-#include <ATen/mps/MPSHooks.h>
+#include <ATen/mps/MPSAllocatorInterface.h>
 #include <ATen/mps/MPSDevice.h>
 #include <ATen/mps/MPSGeneratorImpl.h>
-#include <ATen/mps/MPSAllocatorInterface.h>
+#include <ATen/mps/MPSHooks.h>
+#include <ATen/mps/MPSStream.h>

 namespace at {
 namespace mps {
@@ -26,7 +27,7 @@ bool MPSHooks::isOnMacOS13orNewer(unsigned minor) const {
     case 2:
       return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
     default:
-      TORCH_WARN("Can't check whether running on 13.",minor,"+ returning one for 13.2+");
+      TORCH_WARN("Can't check whether running on 13.", minor, "+ returning one for 13.2+");
       return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
   }
 }
@@ -40,7 +41,19 @@ const Generator& MPSHooks::getDefaultMPSGenerator() const {
 }

 void MPSHooks::deviceSynchronize() const {
-  at::mps::device_synchronize();
+  at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
 }

+void MPSHooks::commitStream() const {
+  at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT);
+}
+
+void* MPSHooks::getCommandBuffer() const {
+  return at::mps::getDefaultMPSStream()->commandBuffer();
+}
+
+void* MPSHooks::getDispatchQueue() const {
+  return at::mps::getDefaultMPSStream()->queue();
+}
+
 void MPSHooks::emptyCache() const {

View File

@@ -96,19 +96,10 @@ void MPSStream::commitAndWait() {
   }

   if (_commandBuffer) {
-    if (_enableCommitAndContinue) {
-      // no need to release the command buffer with CommitAndContinue
-      // This improves the performance by eliminating the overhead of recreating
-      // command buffers, and avoiding disruption to commitAndContinue's internal cache
-      id<MTLCommandBuffer> rootCommandBuffer = _commandBuffer.rootCommandBuffer;
-      [_commandBuffer commitAndContinue];
-      [rootCommandBuffer waitUntilCompleted];
-    } else {
-      [_commandBuffer commit];
-      [_commandBuffer waitUntilCompleted];
-      [_commandBuffer release];
-      _commandBuffer = nil;
-    }
+    [_commandBuffer commit];
+    [_commandBuffer waitUntilCompleted];
+    [_commandBuffer release];
+    _commandBuffer = nil;
   }
 }

View File

@@ -108,8 +108,11 @@ list(APPEND ATen_VEC_TEST_SRCS
 list(APPEND ATen_MPS_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/mps_test_print.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/mps_test_allocator.cpp
-)
+  ${CMAKE_CURRENT_SOURCE_DIR}/mps_test_allocator.cpp)
+
+if(APPLE AND USE_MPS)
+  list(APPEND ATen_MPS_TEST_SRCS
+    ${CMAKE_CURRENT_SOURCE_DIR}/mps_test_objc_interface.mm)
+endif()

 # Caffe2 specific tests
 if(BUILD_CAFFE2)

View File

@@ -0,0 +1,86 @@
+#include <gtest/gtest.h>
+#include <torch/torch.h>
+
+#import <Foundation/Foundation.h>
+#import <Metal/Metal.h>
+
+// this sample custom kernel is taken from:
+// https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu
+static const char* CUSTOM_KERNEL = R"MPS_ADD_ARRAYS(
+#include <metal_stdlib>
+using namespace metal;
+kernel void add_arrays(device const float* inA,
+                       device const float* inB,
+                       device float* result,
+                       uint index [[thread_position_in_grid]])
+{
+  result[index] = inA[index] + inB[index];
+}
+)MPS_ADD_ARRAYS";
+
+static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
+  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
+}
+
+TEST(MPSObjCInterfaceTest, MPSCustomKernel) {
+  const unsigned int tensor_length = 100000UL;
+
+  // fail if mps isn't available
+  ASSERT_TRUE(torch::mps::is_available());
+
+  torch::Tensor cpu_input1 = torch::randn({tensor_length}, at::device(at::kCPU));
+  torch::Tensor cpu_input2 = torch::randn({tensor_length}, at::device(at::kCPU));
+  torch::Tensor cpu_output = cpu_input1 + cpu_input2;
+
+  torch::Tensor mps_input1 = cpu_input1.detach().to(at::kMPS);
+  torch::Tensor mps_input2 = cpu_input2.detach().to(at::kMPS);
+  torch::Tensor mps_output = torch::empty({tensor_length}, at::device(at::kMPS));
+
+  @autoreleasepool {
+    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+    NSError* error = nil;
+    size_t numThreads = mps_output.numel();
+    id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource: [NSString stringWithUTF8String:CUSTOM_KERNEL]
+                                                              options: nil
+                                                                error: &error];
+    TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String);
+
+    id<MTLFunction> customFunction = [customKernelLibrary newFunctionWithName: @"add_arrays"];
+    TORCH_CHECK(customFunction, "Failed to create function state object for the kernel");
+
+    id<MTLComputePipelineState> kernelPSO = [device newComputePipelineStateWithFunction: customFunction error: &error];
+    TORCH_CHECK(kernelPSO, error.localizedDescription.UTF8String);
+
+    // Get a reference to the MPSStream MTLCommandBuffer.
+    id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
+    TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");
+
+    // Get a reference to the MPSStream dispatch_queue. This is used for CPU-side synchronization while encoding.
+    dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
+
+    dispatch_sync(serialQueue, ^(){
+      // Start a compute pass.
+      id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
+      TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
+
+      // Encode the pipeline state object and its parameters.
+      [computeEncoder setComputePipelineState: kernelPSO];
+      [computeEncoder setBuffer: getMTLBufferStorage(mps_input1) offset:0 atIndex:0];
+      [computeEncoder setBuffer: getMTLBufferStorage(mps_input2) offset:0 atIndex:1];
+      [computeEncoder setBuffer: getMTLBufferStorage(mps_output) offset:0 atIndex:2];
+
+      MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
+
+      // Calculate a thread group size.
+      NSUInteger threadsPerGroupSize = std::min(kernelPSO.maxTotalThreadsPerThreadgroup, numThreads);
+      MTLSize threadGroupSize = MTLSizeMake(threadsPerGroupSize, 1, 1);
+
+      // Encode the compute command.
+      [computeEncoder dispatchThreads: gridSize threadsPerThreadgroup: threadGroupSize];
+      [computeEncoder endEncoding];
+
+      torch::mps::commit();
+    });
+  }
+  // synchronize the MPS stream before reading back from MPS buffer
+  torch::mps::synchronize();
+
+  ASSERT_TRUE(at::allclose(cpu_output, mps_output.to(at::kCPU)));
+}

View File

@@ -1686,7 +1686,9 @@ if(BUILD_TEST)
   foreach(test_src ${Caffe2_MPS_TEST_SRCS})
     get_filename_component(test_name ${test_src} NAME_WE)
     add_executable(${test_name} "${test_src}")
-    target_link_libraries(${test_name} torch_library gtest_main)
+    find_library(metal NAMES Metal)
+    find_library(foundation NAMES Foundation)
+    target_link_libraries(${test_name} torch_library gtest_main ${metal} ${foundation})
     target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
     target_include_directories(${test_name} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>)
     target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE})

View File

@@ -1440,7 +1440,7 @@ class _TensorBase(metaclass=_TensorMeta):
 def _multiprocessing_init() -> None: ...

 # Defined in torch/csrc/mps/Module.cpp
-def _mps_synchronize() -> None: ...
+def _mps_deviceSynchronize() -> None: ...
 def _mps_get_default_generator() -> Generator: ...
 def _mps_emptyCache() -> None: ...
 def _mps_setMemoryFraction(fraction: _float) -> None: ...

View File

@@ -5,6 +5,15 @@
 #include <cstddef>
 #include <cstdint>

+#ifdef __OBJC__
+#include <Foundation/Foundation.h>
+#include <Metal/Metal.h>
+typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
+#else
+typedef void* MTLCommandBuffer_t;
+typedef void* dispatch_queue_t;
+#endif
+
 namespace torch {
 namespace mps {
@@ -15,7 +24,19 @@ bool TORCH_API is_available();
 void TORCH_API manual_seed(uint64_t seed);

 /// Waits for all streams on a MPS device to complete.
+/// See this link for more info:
+/// https://developer.apple.com/documentation/metal/mtlcommandbuffer/1443039-waituntilcompleted
 void TORCH_API synchronize();

+/// Submits the currently active command buffer to run on the MPS device
+void TORCH_API commit();
+
+/// Get the current command buffer to encode the Metal commands
+MTLCommandBuffer_t TORCH_API get_command_buffer();
+
+/// Get the dispatch_queue_t to synchronize encoding the custom kernels
+/// with the PyTorch MPS backend
+dispatch_queue_t TORCH_API get_dispatch_queue();
+
 } // namespace mps
 } // namespace torch
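
Note the `#ifdef __OBJC__` guard above: it lets this header be included from both ObjC++ and plain C++ translation units. A hedged sketch of how that looks from a plain `.cpp` file, where both handle types degrade to the `void*` typedefs (the function name and include path are assumptions, not part of this PR):

```objective-c++
// Hypothetical helper in a plain C++ TU: the opaque handles can be obtained
// here and passed along to an ObjC++ file that does the actual Metal encoding.
#include <torch/mps.h>

void* acquire_command_buffer() {
  MTLCommandBuffer_t cb = torch::mps::get_command_buffer();   // void* in C++
  dispatch_queue_t queue = torch::mps::get_dispatch_queue();  // void* in C++
  (void)queue;  // would be bridged back to a real dispatch_queue_t in ObjC++
  return cb;
}
```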

View File

@@ -25,9 +25,22 @@ void manual_seed(uint64_t seed) {
 }

 void synchronize() {
   TORCH_CHECK(is_available(), "No MPS devices are available");
   at::detail::getMPSHooks().deviceSynchronize();
 }

+void commit() {
+  at::detail::getMPSHooks().commitStream();
+}
+
+MTLCommandBuffer_t get_command_buffer() {
+  return static_cast<MTLCommandBuffer_t>(
+      at::detail::getMPSHooks().getCommandBuffer());
+}
+
+dispatch_queue_t get_dispatch_queue() {
+  return static_cast<dispatch_queue_t>(
+      at::detail::getMPSHooks().getDispatchQueue());
+}
+
 } // namespace mps
 } // namespace torch

View File

@@ -72,7 +72,9 @@ static PyObject* MPSModule_isMacOS13orNewer(PyObject* _unused, PyObject* args) {
   END_HANDLE_TH_ERRORS
 }

-static PyObject* MPSModule_synchronize(PyObject* _unused, PyObject* noargs) {
+static PyObject* MPSModule_deviceSynchronize(
+    PyObject* _unused,
+    PyObject* noargs) {
   HANDLE_TH_ERRORS
   at::detail::getMPSHooks().deviceSynchronize();
   Py_RETURN_NONE;
@@ -120,7 +122,10 @@ static PyObject* MPSModule_driverAllocatedMemory(
 // cppcoreguidelines-avoid-non-const-global-variables,
 // cppcoreguidelines-avoid-c-arrays)
 static struct PyMethodDef _MPSModule_methods[] = {
-    {"_mps_synchronize", MPSModule_synchronize, METH_NOARGS, nullptr},
+    {"_mps_deviceSynchronize",
+     MPSModule_deviceSynchronize,
+     METH_NOARGS,
+     nullptr},
     {"_mps_is_in_bad_fork", MPSModule_isInBadFork, METH_NOARGS, nullptr},
     {"_mps_is_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
     {"_mps_is_on_macos_13_or_newer",

View File

@@ -16,7 +16,7 @@ def _get_default_mps_generator() -> torch._C.Generator:
 def synchronize() -> None:
     r"""Waits for all kernels in all streams on a MPS device to complete."""
-    return torch._C._mps_synchronize()
+    return torch._C._mps_deviceSynchronize()

 def get_rng_state() -> Tensor:
     r"""Returns the random number generator state as a ByteTensor."""