[MPS] Prerequisite for MPS C++ extension (#102483)
In order to add MPS kernels to the torchvision codebase, we need to expose the MPS headers and allow Objective-C++ (.mm) files to be used in C++ extensions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102483
Approved by: https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: 0c9117a61f
Commit: 3c0072e7c0
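For context on how this change is meant to be consumed: once the MPS headers are installed and BuildExtension accepts .mm sources, an out-of-tree project (such as torchvision) can declare an Objective-C++ extension in its own setup.py. The sketch below is illustrative only and mirrors the test/cpp_extensions/setup.py change further down in this diff; the package and file names (my_mps_ext, my_kernels.mm) are hypothetical placeholders, not part of this PR.

# Hypothetical downstream setup.py sketch (names are placeholders).
import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

ext_modules = []
if torch.backends.mps.is_available():
    ext_modules.append(
        CppExtension(
            "my_mps_ext",        # hypothetical extension module name
            ["my_kernels.mm"],   # Objective-C++ source, accepted once '.mm' is registered
        )
    )

setup(
    name="my_mps_ext",
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)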
@@ -544,7 +544,7 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake"
 
 set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS})
 if(NOT INTERN_BUILD_MOBILE)
-  list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${miopen_h})
+  list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${mps_h} ${native_mps_h} ${miopen_h})
   # Metal
   if(USE_PYTORCH_METAL_EXPORT)
     # Add files needed from exporting metal models(optimized_for_mobile)
@@ -5,7 +5,6 @@
 #include <ATen/Tensor.h>
 #include <ATen/mps/MPSStream.h>
 #include <ATen/mps/MPSAllocatorInterface.h>
-#include <fmt/format.h>
 
 #include <os/signpost.h>
 #include <os/log.h>
@@ -51,15 +50,7 @@ struct BaseInfo {
   // handle used to identify the profile info's instance (usually the pointer)
   const uintptr_t handle;
 
-  virtual const std::string toString(double gpuTime = 0, double schedulingTime = 0) const {
-    // the gpuTime will be non-zero mainly for event-based signposts.
-    // The interval-based signposts will have "duration" as well as accumulated
-    // total GPU time, up to the point of execution.
-    return fmt::format("{}{}",
-                       gpuTime > 0.0 ? fmt::format(", gpu={:.3f} ms", gpuTime) : "",
-                       schedulingTime > 0.0 ? fmt::format(", cpu={:.3f} ms", schedulingTime) : "");
-  }
-
+  virtual const std::string toString(double gpuTime = 0, double schedulingTime = 0) const;
   // builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
   static std::string buildTensorString(const Tensor& tensor, bool includeBufferId = false) {
     if (tensor.defined()) {
@@ -91,11 +82,7 @@ struct OperationInfo : BaseInfo {
   uint64_t runCount = 0;
   std::string strKey;
 
-  const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override {
-    return fmt::format("aten::{} (id={}{}, run={}{})",
-                       strKey, type == Type::GRAPH ? "G" : "K", profileId, runCount,
-                       BaseInfo::toString(gpuTime, schedulingTime));
-  }
+  const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
 
   // builds a string for a kernel
   static std::string buildKernelString(const std::string& kernelName,
@@ -123,12 +110,7 @@ struct CpuFbInfo : BaseInfo {
   std::string strKey;
   uint64_t startTime = 0;
 
-  const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override {
-    return fmt::format("CPU Fallback::{} (id={}, run={}, CopyOverhead={}{})",
-                       strKey, profileId, runCount,
-                       getIMPSAllocator()->formatSize(currentCopyOverhead),
-                       BaseInfo::toString(0.0, schedulingTime));
-  }
+  const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
 
   void updateCopyOverhead(const TensorList& tensors) {
     currentCopyOverhead = 0;
@@ -161,27 +143,9 @@ struct CopyInfo : BaseInfo {
   // for copies that don't use blitters, we measure CPU time
   uint64_t startTime = 0;
 
-  const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override {
-    return fmt::format("{}Copy{}: {} --> {} (len={}{})",
-                       // Copies could be using Blit Encoder, or using regular
-                       // memcpy() on Unified memory
-                       usesBlitter ? "Blit" : "Mem",
-                       // CopySync indicates COMMIT_AND_WAIT was used to synchronize
-                       // the GPU stream with CPU after the blocking copy
-                       isNonBlocking ? "" : "Sync", srcStrKey, dstStrKey,
-                       getIMPSAllocator()->formatSize(length),
-                       BaseInfo::toString(gpuTime, schedulingTime));
-  }
+  const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
 
-  static std::string buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId = false) {
-    if (tensor.has_value()) {
-      return BaseInfo::buildTensorString(*tensor, includeBufferId);
-    }
-    // if tensor is not defined (e.g., copy_blit_mps()), then use buffer
-    // pointer to build the string.
-    const bool isBufferOnMPS = isStorageOnMPS(buffer, tensor);
-    return fmt::format("{}:{:p}", isBufferOnMPS ? "MPS" : "CPU", buffer);
-  }
+  static std::string buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId = false);
 
   static bool isStorageOnMPS(const void* buffer, const OptionalTensorRef tensor) {
     if (tensor.has_value()) {
@@ -2,6 +2,7 @@
 
 #include <ATen/mps/MPSProfiler.h>
 #include <c10/util/Exception.h>
+#include <fmt/format.h>
 
 // these need to be literal strings when passed to os_signpost*()
 // function macros; so no LUTs could be used
@@ -20,6 +21,57 @@
 namespace at::mps {
 namespace Profiler {
 
+const std::string BaseInfo::toString(double gpuTime, double schedulingTime) const {
+  // the gpuTime will be non-zero mainly for event-based signposts.
+  // The interval-based signposts will have "duration" as well as accumulated
+  // total GPU time, up to the point of execution.
+  return fmt::format("{}{}",
+                     gpuTime > 0.0 ? fmt::format(", gpu={:.3f} ms", gpuTime) : "",
+                     schedulingTime > 0.0 ? fmt::format(", cpu={:.3f} ms", schedulingTime) : "");
+}
+
+const std::string OperationInfo::toString(double gpuTime, double schedulingTime) const {
+  return fmt::format("aten::{} (id={}{}, run={}{})",
+                     strKey,
+                     type == Type::GRAPH ? "G" : "K",
+                     profileId,
+                     runCount,
+                     BaseInfo::toString(gpuTime, schedulingTime));
+}
+
+const std::string CpuFbInfo::toString(double gpuTime, double schedulingTime) const {
+  return fmt::format("CPU Fallback::{} (id={}, run={}, CopyOverhead={}{})",
+                     strKey,
+                     profileId,
+                     runCount,
+                     getIMPSAllocator()->formatSize(currentCopyOverhead),
+                     BaseInfo::toString(0.0, schedulingTime));
+}
+
+const std::string CopyInfo::toString(double gpuTime, double schedulingTime) const {
+  return fmt::format("{}Copy{}: {} --> {} (len={}{})",
+                     // Copies could be using Blit Encoder, or using regular
+                     // memcpy() on Unified memory
+                     usesBlitter ? "Blit" : "Mem",
+                     // CopySync indicates COMMIT_AND_WAIT was used to synchronize
+                     // the GPU stream with CPU after the blocking copy
+                     isNonBlocking ? "" : "Sync",
+                     srcStrKey,
+                     dstStrKey,
+                     getIMPSAllocator()->formatSize(length),
+                     BaseInfo::toString(gpuTime, schedulingTime));
+}
+
+std::string CopyInfo::buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId) {
+  if (tensor.has_value()) {
+    return BaseInfo::buildTensorString(*tensor, includeBufferId);
+  }
+  // if tensor is not defined (e.g., copy_blit_mps()), then use buffer
+  // pointer to build the string.
+  const bool isBufferOnMPS = isStorageOnMPS(buffer, tensor);
+  return fmt::format("{}:{:p}", isBufferOnMPS ? "MPS" : "CPU", buffer);
+}
+
 MPSProfiler::MPSProfiler() : m_os_log_events(nullptr), m_os_log_intervals(nullptr) {
   // see enum LogOptions for the description.
   static const char* log_options_str = getenv(kEVLogProfileInfoStr);
setup.py (2 additions)
@@ -1090,6 +1090,7 @@ def main():
                 'include/ATen/hip/detail/*.cuh',
                 'include/ATen/hip/detail/*.h',
                 'include/ATen/hip/impl/*.h',
+                'include/ATen/mps/*.h',
                 'include/ATen/miopen/*.h',
                 'include/ATen/detail/*.h',
                 'include/ATen/native/*.h',
@@ -1098,6 +1099,7 @@ def main():
                 'include/ATen/native/cuda/*.cuh',
                 'include/ATen/native/hip/*.h',
                 'include/ATen/native/hip/*.cuh',
+                'include/ATen/native/mps/*.h',
                 'include/ATen/native/quantized/*.h',
                 'include/ATen/native/quantized/cpu/*.h',
                 'include/ATen/quantized/*.h',
test/cpp_extensions/mps_extension.mm (new file, 76 lines)
@@ -0,0 +1,76 @@
+#include <torch/extension.h>
+#include <ATen/native/mps/OperationUtils.h>
+
+// this sample custom kernel is taken from:
+// https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu
+static const char* CUSTOM_KERNEL = R"MPS_ADD_ARRAYS(
+#include <metal_stdlib>
+using namespace metal;
+kernel void add_arrays(device const float* inA,
+                       device const float* inB,
+                       device float* result,
+                       uint index [[thread_position_in_grid]])
+{
+    result[index] = inA[index] + inB[index];
+}
+)MPS_ADD_ARRAYS";
+
+at::Tensor get_cpu_add_output(at::Tensor & cpu_input1, at::Tensor & cpu_input2) {
+  return cpu_input1 + cpu_input2;
+}
+
+at::Tensor get_mps_add_output(at::Tensor & mps_input1, at::Tensor & mps_input2) {
+
+  // smoke tests
+  TORCH_CHECK(mps_input1.is_mps());
+  TORCH_CHECK(mps_input2.is_mps());
+  TORCH_CHECK(mps_input1.sizes() == mps_input2.sizes());
+
+  using namespace at::native::mps;
+  at::Tensor mps_output = at::empty_like(mps_input1);
+
+  @autoreleasepool {
+    id<MTLDevice> device = MPSDevice::getInstance()->device();
+    NSError *error = nil;
+    size_t numThreads = mps_output.numel();
+    id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource: [NSString stringWithUTF8String:CUSTOM_KERNEL]
+                                                              options: nil
+                                                                error: &error];
+    TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String);
+
+    id<MTLFunction> customFunction = [customKernelLibrary newFunctionWithName: @"add_arrays"];
+    TORCH_CHECK(customFunction, "Failed to create function state object for the kernel");
+
+    id<MTLComputePipelineState> kernelPSO = [device newComputePipelineStateWithFunction: customFunction error: &error];
+    TORCH_CHECK(kernelPSO, error.localizedDescription.UTF8String);
+
+    MPSStream* mpsStream = getCurrentMPSStream();
+
+    dispatch_sync(mpsStream->queue(), ^() {
+      // Start a compute pass.
+      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
+      TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
+
+      // Encode the pipeline state object and its parameters.
+      [computeEncoder setComputePipelineState: kernelPSO];
+      [computeEncoder setBuffer: getMTLBufferStorage(mps_input1) offset:0 atIndex:0];
+      [computeEncoder setBuffer: getMTLBufferStorage(mps_input2) offset:0 atIndex:1];
+      [computeEncoder setBuffer: getMTLBufferStorage(mps_output) offset:0 atIndex:2];
+      MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
+
+      // Calculate a thread group size.
+      NSUInteger threadsPerGroupSize = std::min(kernelPSO.maxTotalThreadsPerThreadgroup, numThreads);
+      MTLSize threadGroupSize = MTLSizeMake(threadsPerGroupSize, 1, 1);
+
+      // Encode the compute command.
+      [computeEncoder dispatchThreads: gridSize threadsPerThreadgroup: threadGroupSize];
+
+    });
+  }
+  return mps_output;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("get_cpu_add_output", &get_cpu_add_output);
+  m.def("get_mps_add_output", &get_mps_add_output);
+}
@@ -49,6 +49,14 @@ if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None
                       'nvcc': ['-O2']})
     ext_modules.append(extension)
 
+if torch.backends.mps.is_available():
+    extension = CppExtension(
+        'torch_test_cpp_extension.mps',
+        ['mps_extension.mm'],
+        extra_compile_args=CXX_FLAGS,
+    )
+    ext_modules.append(extension)
+
 # todo(mkozuki): Figure out the root cause
 if (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None:
     # malfet: One should not assume that PyTorch re-exports CUDA dependencies
@@ -85,6 +85,19 @@ class TestCppExtensionAOT(common.TestCase):
         # 2 * sigmoid(0) = 2 * 0.5 = 1
         self.assertEqual(z, torch.ones_like(z))
 
+    @unittest.skipIf(not torch.backends.mps.is_available(), "MPS not found")
+    def test_mps_extension(self):
+        import torch_test_cpp_extension.mps as mps_extension
+
+        tensor_length = 100000
+        x = torch.zeros(tensor_length, device="cpu", dtype=torch.float32)
+        y = torch.zeros(tensor_length, device="cpu", dtype=torch.float32)
+
+        cpu_output = mps_extension.get_cpu_add_output(x, y)
+        mps_output = mps_extension.get_mps_add_output(x.to("mps"), y.to("mps"))
+
+        self.assertEqual(cpu_output, mps_output.to("cpu"))
+
     @common.skipIfRocm
     @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
     @unittest.skipIf(not TEST_CUDA, "CUDA not found")
@@ -27,6 +27,7 @@ if TEST_CUDA and torch.version.cuda is not None:  # the skip CUDNN test for ROCm
 TEST_CUDNN = (
     TEST_CUDA and CUDNN_HEADER_EXISTS and torch.backends.cudnn.is_available()
 )
+TEST_MPS = torch.backends.mps.is_available()
 IS_WINDOWS = sys.platform == "win32"
 
 
@@ -116,6 +117,26 @@ class TestCppExtensionJIT(common.TestCase):
         # 2 * sigmoid(0) = 2 * 0.5 = 1
         self.assertEqual(z, torch.ones_like(z))
 
+    @unittest.skipIf(not TEST_MPS, "MPS not found")
+    def test_mps_extension(self):
+        module = torch.utils.cpp_extension.load(
+            name="torch_test_mps_extension",
+            sources=[
+                "cpp_extensions/mps_extension.mm",
+            ],
+            verbose=True,
+            keep_intermediates=False,
+        )
+
+        tensor_length = 100000
+        x = torch.zeros(tensor_length, device="cpu", dtype=torch.float32)
+        y = torch.zeros(tensor_length, device="cpu", dtype=torch.float32)
+
+        cpu_output = module.get_cpu_add_output(x, y)
+        mps_output = module.get_mps_add_output(x.to("mps"), y.to("mps"))
+
+        self.assertEqual(cpu_output, mps_output.to("cpu"))
+
     def _run_jit_cuda_archflags(self, flags, expected):
         # Compile an extension with given `flags`
         def _check_cuobjdump_output(expected_values, is_ptx=False):
@@ -524,8 +524,10 @@ class BuildExtension(build_ext):
             if 'nvcc_dlink' in extension.extra_compile_args:
                 assert self.use_ninja, f"With dlink=True, ninja is required to build cuda extension {extension.name}."
 
-        # Register .cu, .cuh and .hip as valid source extensions.
+        # Register .cu, .cuh, .hip, and .mm as valid source extensions.
         self.compiler.src_extensions += ['.cu', '.cuh', '.hip']
+        if torch.backends.mps.is_built():
+            self.compiler.src_extensions += ['.mm']
         # Save the original _compile method for later.
         if self.compiler.compiler_type == 'msvc':
             self.compiler._cpp_extensions += ['.cu', '.cuh']