mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Add support for Custom Kernels (#100661)
- This change introduces these APIs to enable developing custom kernels on the MPS stream:
  - `torch::mps::get_command_buffer()`
  - `torch::mps::get_dispatch_queue()`
  - `torch::mps::commit()`
- Add an Objective-C test case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100661
Approved by: https://github.com/kulinseth, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
d9d98b4d54
commit
f39cda83d1
@ -17,48 +17,55 @@ class Context;
|
||||
namespace at {
|
||||
|
||||
struct TORCH_API MPSHooksInterface {
|
||||
// this fails the implementation if MPSHooks functions are called, but
|
||||
// MPS backend is not present.
|
||||
#define FAIL_MPSHOOKS_FUNC(func) \
|
||||
TORCH_CHECK(false, "Cannot execute ", func ,"() without MPS backend.");
|
||||
|
||||
virtual ~MPSHooksInterface() = default;
|
||||
|
||||
// Initialize the MPS library state
|
||||
virtual void initMPS() const {
|
||||
AT_ERROR("Cannot initialize MPS without MPS backend.");
|
||||
FAIL_MPSHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual bool hasMPS() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual bool isOnMacOS13orNewer(unsigned minor = 0) const {
|
||||
AT_ERROR("MPS backend is not available.");
|
||||
FAIL_MPSHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual const Generator& getDefaultMPSGenerator() const {
|
||||
AT_ERROR("Cannot get default MPS generator without MPS backend.");
|
||||
FAIL_MPSHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual Allocator* getMPSDeviceAllocator() const {
|
||||
AT_ERROR("MPSDeviceAllocator requires MPS.");
|
||||
FAIL_MPSHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual void deviceSynchronize() const {
|
||||
AT_ERROR("Cannot synchronize MPS device without MPS backend.");
|
||||
FAIL_MPSHOOKS_FUNC(__func__);
|
||||
}
|
||||
virtual void commitStream() const {
|
||||
FAIL_MPSHOOKS_FUNC(__func__);
|
||||
}
|
||||
virtual void* getCommandBuffer() const {
|
||||
FAIL_MPSHOOKS_FUNC(__func__);
|
||||
}
|
||||
virtual void* getDispatchQueue() const {
|
||||
FAIL_MPSHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual void emptyCache() const {
|
||||
AT_ERROR("Cannot execute emptyCache() without MPS backend.");
|
||||
FAIL_MPSHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual size_t getCurrentAllocatedMemory() const {
|
||||
AT_ERROR("Cannot execute getCurrentAllocatedMemory() without MPS backend.");
|
||||
FAIL_MPSHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual size_t getDriverAllocatedMemory() const {
|
||||
AT_ERROR("Cannot execute getDriverAllocatedMemory() without MPS backend.");
|
||||
FAIL_MPSHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual void setMemoryFraction(double /*ratio*/) const {
|
||||
AT_ERROR("Cannot execute setMemoryFraction() without MPS backend.");
|
||||
FAIL_MPSHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
#undef FAIL_MPSHOOKS_FUNC
|
||||
};
|
||||
|
||||
struct TORCH_API MPSHooksArgs {};
|
||||
|
||||
@ -80,7 +80,6 @@ class TORCH_API MPSDevice {
|
||||
|
||||
TORCH_API bool is_available();
|
||||
TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
|
||||
TORCH_API void device_synchronize();
|
||||
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
|
||||
|
||||
} // namespace mps
|
||||
|
||||
@ -141,9 +141,5 @@ bool is_macos_13_or_newer(MacOSVersion version) {
|
||||
return MPSDevice::getInstance()->isMacOS13Plus(version);
|
||||
}
|
||||
|
||||
void device_synchronize() {
|
||||
getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
|
||||
}
|
||||
|
||||
} // namespace mps
|
||||
} // namespace at
|
||||
|
||||
@ -12,11 +12,22 @@ namespace at { namespace mps {
|
||||
struct MPSHooks : public at::MPSHooksInterface {
|
||||
MPSHooks(at::MPSHooksArgs) {}
|
||||
void initMPS() const override;
|
||||
|
||||
// MPSDevice interface
|
||||
bool hasMPS() const override;
|
||||
bool isOnMacOS13orNewer(unsigned minor) const override;
|
||||
Allocator* getMPSDeviceAllocator() const override;
|
||||
|
||||
// MPSGeneratorImpl interface
|
||||
const Generator& getDefaultMPSGenerator() const override;
|
||||
|
||||
// MPSStream interface
|
||||
void deviceSynchronize() const override;
|
||||
void commitStream() const override;
|
||||
void* getCommandBuffer() const override;
|
||||
void* getDispatchQueue() const override;
|
||||
|
||||
// MPSAllocator interface
|
||||
Allocator* getMPSDeviceAllocator() const override;
|
||||
void emptyCache() const override;
|
||||
size_t getCurrentAllocatedMemory() const override;
|
||||
size_t getDriverAllocatedMemory() const override;
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
#include <ATen/mps/MPSHooks.h>
|
||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||
#include <ATen/mps/MPSDevice.h>
|
||||
#include <ATen/mps/MPSGeneratorImpl.h>
|
||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||
#include <ATen/mps/MPSHooks.h>
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
|
||||
namespace at {
|
||||
namespace mps {
|
||||
@ -26,7 +27,7 @@ bool MPSHooks::isOnMacOS13orNewer(unsigned minor) const {
|
||||
case 2:
|
||||
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
|
||||
default:
|
||||
TORCH_WARN("Can't check whether running on 13.",minor,"+ returning one for 13.2+");
|
||||
TORCH_WARN("Can't check whether running on 13.", minor, "+ returning one for 13.2+");
|
||||
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
|
||||
}
|
||||
}
|
||||
@ -40,7 +41,19 @@ const Generator& MPSHooks::getDefaultMPSGenerator() const {
|
||||
}
|
||||
|
||||
void MPSHooks::deviceSynchronize() const {
|
||||
at::mps::device_synchronize();
|
||||
at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
|
||||
}
|
||||
|
||||
// Calls synchronize(SyncType::COMMIT) on the default MPS stream.
void MPSHooks::commitStream() const {
  auto defaultStream = at::mps::getDefaultMPSStream();
  defaultStream->synchronize(SyncType::COMMIT);
}
|
||||
|
||||
void* MPSHooks::getCommandBuffer() const {
|
||||
return at::mps::getDefaultMPSStream()->commandBuffer();
|
||||
}
|
||||
|
||||
void* MPSHooks::getDispatchQueue() const {
|
||||
return at::mps::getDefaultMPSStream()->queue();
|
||||
}
|
||||
|
||||
void MPSHooks::emptyCache() const {
|
||||
@ -96,19 +96,10 @@ void MPSStream::commitAndWait() {
|
||||
}
|
||||
|
||||
if (_commandBuffer) {
|
||||
if (_enableCommitAndContinue) {
|
||||
// no need to release the command buffer with CommitAndContinue
|
||||
// This improves the performance by eliminating the overhead of recreating
|
||||
// command buffers, and avoiding disruption to commitAndContinue's internal cache
|
||||
id<MTLCommandBuffer> rootCommandBuffer = _commandBuffer.rootCommandBuffer;
|
||||
[_commandBuffer commitAndContinue];
|
||||
[rootCommandBuffer waitUntilCompleted];
|
||||
} else {
|
||||
[_commandBuffer commit];
|
||||
[_commandBuffer waitUntilCompleted];
|
||||
[_commandBuffer release];
|
||||
_commandBuffer = nil;
|
||||
}
|
||||
[_commandBuffer commit];
|
||||
[_commandBuffer waitUntilCompleted];
|
||||
[_commandBuffer release];
|
||||
_commandBuffer = nil;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -108,8 +108,11 @@ list(APPEND ATen_VEC_TEST_SRCS
|
||||
|
||||
list(APPEND ATen_MPS_TEST_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mps_test_print.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mps_test_allocator.cpp
|
||||
)
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mps_test_allocator.cpp)
|
||||
if(APPLE AND USE_MPS)
|
||||
list(APPEND ATen_MPS_TEST_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mps_test_objc_interface.mm)
|
||||
endif()
|
||||
|
||||
# Caffe2 specific tests
|
||||
if(BUILD_CAFFE2)
|
||||
|
||||
86
aten/src/ATen/test/mps_test_objc_interface.mm
Normal file
86
aten/src/ATen/test/mps_test_objc_interface.mm
Normal file
@ -0,0 +1,86 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <torch/torch.h>
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <Metal/Metal.h>
|
||||
|
||||
// Metal Shading Language source for the sample `add_arrays` kernel: each
// thread writes result[index] = inA[index] + inB[index].
// This sample custom kernel is taken from:
// https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu
// The raw string must contain only the kernel source, because its contents are
// handed verbatim to the Metal compiler at runtime.
static const char* CUSTOM_KERNEL = R"MPS_ADD_ARRAYS(
#include <metal_stdlib>
using namespace metal;
kernel void add_arrays(device const float* inA,
                       device const float* inB,
                       device float* result,
                       uint index [[thread_position_in_grid]])
{
    result[index] = inA[index] + inB[index];
}
)MPS_ADD_ARRAYS";
|
||||
|
||||
// Reinterprets the data pointer of the tensor's storage as an id<MTLBuffer>.
// NOTE(review): presumably only meaningful for tensors allocated on the MPS
// device -- confirm against the MPS allocator.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
  auto storageData = tensor.storage().data();
  return __builtin_bit_cast(id<MTLBuffer>, storageData);
}
|
||||
|
||||
// Encodes the sample add_arrays Metal kernel on the command buffer obtained
// from torch::mps::get_command_buffer(), inside the dispatch queue obtained
// from torch::mps::get_dispatch_queue(), and checks the MPS result against the
// same addition computed on the CPU.
TEST(MPSObjCInterfaceTest, MPSCustomKernel) {
  const unsigned int tensor_length = 100000UL;

  // fail if mps isn't available
  ASSERT_TRUE(torch::mps::is_available());

  // Reference result computed on the CPU.
  torch::Tensor cpu_input1 = torch::randn({tensor_length}, at::device(at::kCPU));
  torch::Tensor cpu_input2 = torch::randn({tensor_length}, at::device(at::kCPU));
  torch::Tensor cpu_output = cpu_input1 + cpu_input2;

  // Same inputs on the MPS device, plus an uninitialized output tensor that
  // the custom kernel writes into.
  torch::Tensor mps_input1 = cpu_input1.detach().to(at::kMPS);
  torch::Tensor mps_input2 = cpu_input2.detach().to(at::kMPS);
  torch::Tensor mps_output = torch::empty({tensor_length}, at::device(at::kMPS));

  @autoreleasepool {
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    NSError *error = nil;
    size_t numThreads = mps_output.numel();
    id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource: [NSString stringWithUTF8String:CUSTOM_KERNEL]
                                                              options: nil
                                                                error: &error];
    TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String);

    id<MTLFunction> customFunction = [customKernelLibrary newFunctionWithName: @"add_arrays"];
    TORCH_CHECK(customFunction, "Failed to create function state object for the kernel");

    id<MTLComputePipelineState> kernelPSO = [device newComputePipelineStateWithFunction: customFunction error: &error];
    TORCH_CHECK(kernelPSO, error.localizedDescription.UTF8String);

    // Get a reference to the MPSStream MTLCommandBuffer.
    id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
    TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

    // Get a reference to the MPSStream dispatch_queue. This is used for CPU side synchronization while encoding.
    dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
    dispatch_sync(serialQueue, ^(){
      // Start a compute pass.
      id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
      TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");

      // Encode the pipeline state object and its parameters.
      // Buffer indices 0/1/2 match the kernel's inA/inB/result parameters.
      [computeEncoder setComputePipelineState: kernelPSO];
      [computeEncoder setBuffer: getMTLBufferStorage(mps_input1) offset:0 atIndex:0];
      [computeEncoder setBuffer: getMTLBufferStorage(mps_input2) offset:0 atIndex:1];
      [computeEncoder setBuffer: getMTLBufferStorage(mps_output) offset:0 atIndex:2];
      MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);

      // Calculate a thread group size.
      NSUInteger threadsPerGroupSize = std::min(kernelPSO.maxTotalThreadsPerThreadgroup, numThreads);
      MTLSize threadGroupSize = MTLSizeMake(threadsPerGroupSize, 1, 1);

      // Encode the compute command.
      [computeEncoder dispatchThreads: gridSize threadsPerThreadgroup: threadGroupSize];
      [computeEncoder endEncoding];

      torch::mps::commit();
    });
  }
  // synchronize the MPS stream before reading back from MPS buffer
  torch::mps::synchronize();

  ASSERT_TRUE(at::allclose(cpu_output, mps_output.to(at::kCPU)));
}
|
||||
@ -1686,7 +1686,9 @@ if(BUILD_TEST)
|
||||
foreach(test_src ${Caffe2_MPS_TEST_SRCS})
|
||||
get_filename_component(test_name ${test_src} NAME_WE)
|
||||
add_executable(${test_name} "${test_src}")
|
||||
target_link_libraries(${test_name} torch_library gtest_main)
|
||||
find_library(metal NAMES Metal)
|
||||
find_library(foundation NAMES Foundation)
|
||||
target_link_libraries(${test_name} torch_library gtest_main ${metal} ${foundation})
|
||||
target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
|
||||
target_include_directories(${test_name} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>)
|
||||
target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE})
|
||||
|
||||
@ -1440,7 +1440,7 @@ class _TensorBase(metaclass=_TensorMeta):
|
||||
def _multiprocessing_init() -> None: ...
|
||||
|
||||
# Defined in torch/csrc/mps/Module.cpp
|
||||
def _mps_synchronize() -> None: ...
|
||||
def _mps_deviceSynchronize() -> None: ...
|
||||
def _mps_get_default_generator() -> Generator: ...
|
||||
def _mps_emptyCache() -> None: ...
|
||||
def _mps_setMemoryFraction(fraction: _float) -> None: ...
|
||||
|
||||
@ -5,6 +5,15 @@
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#ifdef __OBJC__
|
||||
#include <Foundation/Foundation.h>
|
||||
#include <Metal/Metal.h>
|
||||
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
|
||||
#else
|
||||
typedef void* MTLCommandBuffer_t;
|
||||
typedef void* dispatch_queue_t;
|
||||
#endif
|
||||
|
||||
namespace torch {
|
||||
namespace mps {
|
||||
|
||||
@ -15,7 +24,19 @@ bool TORCH_API is_available();
|
||||
void TORCH_API manual_seed(uint64_t seed);
|
||||
|
||||
/// Waits for all streams on an MPS device to complete.
|
||||
/// See this link for more info:
|
||||
/// https://developer.apple.com/documentation/metal/mtlcommandbuffer/1443039-waituntilcompleted
|
||||
void TORCH_API synchronize();
|
||||
|
||||
/// Submits the currently active command buffer to run on the MPS device
|
||||
void TORCH_API commit();
|
||||
|
||||
/// Get the current command buffer to encode the Metal commands
|
||||
MTLCommandBuffer_t TORCH_API get_command_buffer();
|
||||
|
||||
/// Get the dispatch_queue_t to synchronize encoding the custom kernels
|
||||
/// with the PyTorch MPS backend
|
||||
dispatch_queue_t TORCH_API get_dispatch_queue();
|
||||
|
||||
} // namespace mps
|
||||
} // namespace torch
|
||||
|
||||
@ -25,9 +25,22 @@ void manual_seed(uint64_t seed) {
|
||||
}
|
||||
|
||||
void synchronize() {
|
||||
TORCH_CHECK(is_available(), "No MPS devices are available");
|
||||
at::detail::getMPSHooks().deviceSynchronize();
|
||||
}
|
||||
|
||||
// Forwards to commitStream() on the registered MPS hooks.
void commit() {
  const auto& hooks = at::detail::getMPSHooks();
  hooks.commitStream();
}
|
||||
|
||||
// Fetches the opaque command buffer pointer from the MPS hooks and casts it
// to MTLCommandBuffer_t.
MTLCommandBuffer_t get_command_buffer() {
  void* rawCommandBuffer = at::detail::getMPSHooks().getCommandBuffer();
  return static_cast<MTLCommandBuffer_t>(rawCommandBuffer);
}
|
||||
|
||||
// Fetches the opaque dispatch queue pointer from the MPS hooks and casts it
// to dispatch_queue_t.
dispatch_queue_t get_dispatch_queue() {
  void* rawDispatchQueue = at::detail::getMPSHooks().getDispatchQueue();
  return static_cast<dispatch_queue_t>(rawDispatchQueue);
}
|
||||
|
||||
} // namespace mps
|
||||
} // namespace torch
|
||||
|
||||
@ -72,7 +72,9 @@ static PyObject* MPSModule_isMacOS13orNewer(PyObject* _unused, PyObject* args) {
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* MPSModule_synchronize(PyObject* _unused, PyObject* noargs) {
|
||||
static PyObject* MPSModule_deviceSynchronize(
|
||||
PyObject* _unused,
|
||||
PyObject* noargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
at::detail::getMPSHooks().deviceSynchronize();
|
||||
Py_RETURN_NONE;
|
||||
@ -120,7 +122,10 @@ static PyObject* MPSModule_driverAllocatedMemory(
|
||||
// cppcoreguidelines-avoid-non-const-global-variables,
|
||||
// cppcoreguidelines-avoid-c-arrays)
|
||||
static struct PyMethodDef _MPSModule_methods[] = {
|
||||
{"_mps_synchronize", MPSModule_synchronize, METH_NOARGS, nullptr},
|
||||
{"_mps_deviceSynchronize",
|
||||
MPSModule_deviceSynchronize,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{"_mps_is_in_bad_fork", MPSModule_isInBadFork, METH_NOARGS, nullptr},
|
||||
{"_mps_is_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
|
||||
{"_mps_is_on_macos_13_or_newer",
|
||||
|
||||
@ -16,7 +16,7 @@ def _get_default_mps_generator() -> torch._C.Generator:
|
||||
|
||||
def synchronize() -> None:
|
||||
r"""Waits for all kernels in all streams on a MPS device to complete."""
|
||||
return torch._C._mps_synchronize()
|
||||
return torch._C._mps_deviceSynchronize()
|
||||
|
||||
def get_rng_state() -> Tensor:
|
||||
r"""Returns the random number generator state as a ByteTensor."""
|
||||
|
||||
Reference in New Issue
Block a user