Compare commits

...

4 Commits

Author SHA1 Message Date
94f9f4a41c Synchronize mps backend in the timer 2025-05-02 10:51:59 -07:00
996379b04f Do the strides correctly 2025-05-02 10:40:15 -07:00
3912d5b44d Delete copy constructor and assignment from MPSCachedKernel. Include MPSSequoiaOps header to Linear. Add check for macos version to use correct code path to avoid breaking previous OS. 2025-04-29 14:27:11 -07:00
ad993f34ca Adding a direct MPS kernel path to linear op and MPS kernel caching mechanism for improved perf. 2025-04-25 13:37:04 -07:00
4 changed files with 235 additions and 55 deletions

View File

@@ -100,6 +100,7 @@ MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
const TensorBase& input,
bool includesInt64 = false);
MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray);
MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes = nil, MPSShape* strides = nil);
// The MPSShape could vary based on memory format
@@ -160,6 +161,26 @@ string get_mem_format_string(c10::MemoryFormat memory_format);
using MPSCacheKey = uint64_t;
struct MPSCachedKernel {
MPSCachedKernel(NSObject* object) : _object([object retain]) {}
virtual ~MPSCachedKernel() {
[_object release];
_object = nullptr;
}
// Delete copy constructor and assignment
MPSCachedKernel(const MPSCachedKernel&) = delete;
void operator=(const MPSCachedKernel&) = delete;
template <typename T>
inline T* kernel() const {
return (T*)_object;
}
private:
NSObject* _object = nullptr;
};
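As a rough illustration (not part of the diff), the wrapper is intended to be the sole owner of the Objective-C object it caches: the constructor retains it, the destructor releases it, and the deleted copy operations keep two wrappers from releasing the same object. A minimal sketch, assuming manual reference counting as in the surrounding MPS code and an MPSNDArrayMatrixMultiplication kernel like the one used in Linear.mm below:
// Hypothetical sketch: wrap an MPSNDArrayMatrixMultiplication kernel in MPSCachedKernel.
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSNDArrayMatrixMultiplication* mm =
    [[[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2] autorelease];
auto* cached = new MPSCachedKernel(mm);                          // constructor retains mm
auto* kernel = cached->kernel<MPSNDArrayMatrixMultiplication>(); // typed access to the wrapped object
// ... encode `kernel` into a compute command encoder ...
delete cached;                                                   // destructor releases mm exactly once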
// derive this class to cache a graph and its inputs/outputs
// can be used to store any NSObject
struct MPSCachedGraph {
@@ -214,6 +235,97 @@ struct MPSBinaryGradCachedGraph : public MPSCachedGraph {
MPSGraphTensor* gradInputTensor_ = nil;
};
struct MPSKernelCache {
typedef MPSCachedKernel* (^CreateCachedKernelBlock)();
struct CacheEntry {
CacheEntry(const std::string& key, MPSCachedKernel* cachedKernel) : cachedKernel_(cachedKernel), key_(key) {}
MPSCachedKernel* cachedKernel_ = nullptr;
std::string key_;
};
public:
static MPSKernelCache* getInstance() {
if (_instance_cache == nullptr) {
_instance_cache = new MPSKernelCache();
}
return _instance_cache;
}
~MPSKernelCache() {
dispatch_release(serialQueue_);
for (const auto& i : cache_) {
delete i.second.cachedKernel_;
}
}
// Disallow the copy constructor and operator= functions
MPSKernelCache(const MPSKernelCache&) = delete;
void operator=(const MPSKernelCache&) = delete;
MPSCachedKernel* CreateCachedKernel(const std::string& key, CreateCachedKernelBlock createCacheBlock) {
__block MPSCachedKernel* cachedKernel = nil;
MPSCacheKey hash = std::hash<std::string>{}(key);
dispatch_sync_with_rethrow(serialQueue_, ^() {
if (cache_.count(hash) != 0) {
auto& entry = cache_.at(hash);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached kernel!\n");
cachedKernel = entry.cachedKernel_;
} else {
cachedKernel = createCacheBlock();
CacheEntry entry(key, cachedKernel);
cache_.emplace(hash, entry);
}
});
return cachedKernel;
}
template <typename T>
inline T* CreateCachedKernelAs(const std::string& key, CreateCachedKernelBlock createCacheBlock) {
return static_cast<T*>(CreateCachedKernel(key, createCacheBlock));
}
MPSCachedKernel* LookUp(const std::string& key) const {
__block MPSCachedKernel* cachedKernel = nil;
MPSCacheKey hash = std::hash<std::string>{}(key);
dispatch_sync_with_rethrow(serialQueue_, ^() {
if (cache_.count(hash) != 0) {
auto& entry = cache_.at(hash);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached kernel!\n");
cachedKernel = entry.cachedKernel_;
}
});
return cachedKernel;
}
template <typename T>
inline T* LookUpAs(const std::string& key) const {
return static_cast<T*>(LookUp(key));
}
private:
MPSKernelCache() {
serialQueue_ = dispatch_queue_create("kernel cache queue", DISPATCH_QUEUE_SERIAL);
}
static MPSKernelCache* _instance_cache;
std::unordered_map<MPSCacheKey, CacheEntry> cache_;
dispatch_queue_t serialQueue_ = nullptr;
};
// Common template for creating cached kernel if missing
template <typename T>
inline T* LookUpOrCreateCachedKernel(const std::string& key, std::function<MPSKernel*()> instantiate) {
auto cache_ = MPSKernelCache::getInstance();
if (auto rc = cache_->LookUpAs<T>(key)) {
return rc;
}
return cache_->CreateCachedKernelAs<T>(key, ^mps::MPSCachedKernel*() {
auto k_ = new mps::MPSCachedKernel(instantiate());
return k_;
});
}
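For reference, a hedged sketch of how this helper is expected to be called; it mirrors the mps_linear call sites in the Linear.mm diff further down. The string key encodes the op name plus tensor shapes and dtypes, so each unique configuration instantiates its kernel once and later calls hit the cache.
// Illustrative only; the real call sites appear in the Linear.mm diff below.
const std::string key = "mps_linear" + getTensorsStringKey({input, weight, bias}, true, true);
auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>(key, [&]() {
  return [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2];
});
auto* kernel = cachedKernel->kernel<MPSNDArrayMatrixMultiplication>();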
// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
// Cache holding various keys mapped to graphs

View File

@@ -468,7 +468,7 @@ MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes, MPSShape* stride
offset:t.storage_offset() * t.element_size()
descriptor:srcTensorDesc] autorelease];
if (strides != nil) {
srcNDArray = [srcNDArray arrayViewWithShape:sizes strides:strides];
srcNDArray = getStridedMPSNDArray(t, srcNDArray);
}
return srcNDArray;
}
@@ -477,7 +477,7 @@ MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes, const I
return getMPSNDArray(t, getMPSShape(sizes.empty() ? t.sizes() : sizes), strides.empty() ? nil : getMPSShape(strides));
}
static MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray) {
MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray) {
auto strides = src.strides();
auto sizes = src.sizes();
auto nStrides = strides.size();
@@ -779,6 +779,8 @@ string get_mem_format_string(c10::MemoryFormat memory_format) {
MPSGraphCache* MPSGraphCache::_instance_cache = nullptr;
MPSKernelCache* MPSKernelCache::_instance_cache = nullptr;
void MPSGraphCache::profileCachedGraph(const CacheEntry& cacheEntry) const {
auto& profiler = getMPSProfiler();
if (profiler.isOperationProfilingEnabled()) {

View File

@@ -1,6 +1,8 @@
// Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/ExpandUtils.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/MPSGraphSequoiaOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/ops/linear_backward_native.h>
#include <ATen/ops/linear_native.h>
@@ -9,6 +11,60 @@ namespace at::native {
using namespace mps;
static void _mps_linear_nograph(const Tensor& input, const Tensor& weight, const Tensor& bias, Tensor& output) {
bool is_bias_defined = bias.defined();
MPSStream* mpsStream = getCurrentMPSStream();
id<MTLDevice> device = MPSDevice::getInstance()->device();
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
const string key = "mps_linear" + getTensorsStringKey({input, weight, bias}, true, true);
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
mpsStream->endKernelCoalescing();
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
MPSDataType mpsDataType = getMPSDataType(weight.scalar_type());
auto inputNDArray = getMPSNDArray(input, input.sizes(), input.strides());
auto outNDArray = getMPSNDArray(output, output.sizes(), output.strides());
id<MTLBuffer> weightBuf = getMTLBufferStorage(weight);
MPSNDArrayDescriptor* weightDesc = [MPSNDArrayDescriptor descriptorWithDataType:mpsDataType
shape:getMPSShape(weight.sizes())];
weightDesc.preferPackedRows = YES;
[weightDesc transposeDimension:0 withDimension:1];
MPSNDArray* weightNDArray = [[MPSNDArray alloc] initWithBuffer:weightBuf
offset:weight.storage_offset() * weight.element_size()
descriptor:weightDesc];
if (is_bias_defined) {
auto biasNDArray = getMPSNDArray(bias, bias.sizes(), bias.strides());
auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>(
key, [&]() { return [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:3]; });
auto kernel = cachedKernel->kernel<MPSNDArrayMatrixMultiplication>();
getMPSProfiler().beginProfileKernel(kernel, "mps_linear", {input, weight, bias});
[kernel encodeToCommandEncoder:computeEncoder
commandBuffer:commandBuffer
sourceArrays:@[ inputNDArray, weightNDArray, biasNDArray ]
destinationArray:outNDArray];
getMPSProfiler().endProfileKernel(kernel);
} else {
auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>(
key, [&]() { return [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2]; });
auto kernel = cachedKernel->kernel<MPSNDArrayMatrixMultiplication>();
getMPSProfiler().beginProfileKernel(kernel, "mps_linear", {input, weight, bias});
[kernel encodeToCommandEncoder:computeEncoder
commandBuffer:commandBuffer
sourceArrays:@[ inputNDArray, weightNDArray ]
destinationArray:outNDArray];
getMPSProfiler().endProfileKernel(kernel);
}
}
});
}
Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::optional<Tensor>& bias_opt) {
// wT = transpose(weight);
// y=x*wT+b
@@ -19,6 +75,8 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::opt
TORCH_CHECK(input.is_mps(), "Tensor for argument input is on ", input.device(), " but expected on mps");
TORCH_CHECK(supportedFloatingOrComplexType(weight_arg), "MPS device does not support linear for non-float weights");
TORCH_CHECK(weight_arg.is_mps(), "Tensor for argument weight is on ", weight_arg.device(), " but expected on mps");
TORCH_CHECK((input.scalar_type() != kComplexFloat && input.scalar_type() != kComplexHalf),
"mps linear does not support complex types");
const Tensor& bias = *(at::borrow_from_optional_tensor(bias_opt));
const bool is_bias_defined = bias.defined();
@@ -54,66 +112,70 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::opt
return output;
}
MPSStream* stream = getCurrentMPSStream();
bool is_macos_15_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
if (is_macos_15_or_newer) {
_mps_linear_nograph(input, weight, bias, output);
} else {
MPSStream* stream = getCurrentMPSStream();
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* weightTensor_ = nil;
MPSGraphTensor* biasTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* weightTensor_ = nil;
MPSGraphTensor* biasTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
@autoreleasepool {
std::string key = "mps_linear" + getTensorsStringKey({input, weight, bias});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);
@autoreleasepool {
string key = "mps_linear" + getTensorsStringKey({input, weight, bias});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
dimension:-1
withDimension:-2
name:nil];
// matrixMultiplicationWithPrimary crashes for 5D tensors, see https://github.com/pytorch/pytorch/issues/114942
bool doReshape = input.dim() > 4;
if (!doReshape && is_bias_defined) {
// workaround to improve the performance with 3D+ inputs
doReshape =
input_size.size() > 2 && input_size[0] > 1 && input_size[1] >= 1 && input_size[1] <= 32 && bias.dim() <= 1;
}
auto inputFlattened = doReshape ? [mpsGraph flatten2DTensor:inputTensor axis:-1 name:nil] : inputTensor;
auto outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:inputFlattened
secondaryTensor:weightTransposeTensor
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
dimension:-1
withDimension:-2
name:nil];
// matrixMultiplicationWithPrimary crashes for 5D tensors, see https://github.com/pytorch/pytorch/issues/114942
bool doReshape = input.dim() > 4;
if (!doReshape && is_bias_defined) {
// workaround to improve the performance with 3D+ inputs
doReshape = input_size.size() > 2 && input_size[0] > 1 && input_size[1] >= 1 && input_size[1] <= 32 &&
bias.dim() <= 1;
}
auto inputFlattened = doReshape ? [mpsGraph flatten2DTensor:inputTensor axis:-1 name:nil] : inputTensor;
auto outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:inputFlattened
secondaryTensor:weightTransposeTensor
name:nil];
if (is_bias_defined) {
newCachedGraph->biasTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, bias);
outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor
secondaryTensor:newCachedGraph->biasTensor_
name:nil];
}
if (doReshape) {
outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:getMPSShape(output_size) name:nil];
}
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->weightTensor_ = weightTensor;
newCachedGraph->outputTensor_ = outputTensor;
});
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight);
Placeholder biasPlaceholder = Placeholder();
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData();
if (is_bias_defined) {
newCachedGraph->biasTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, bias);
outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor
secondaryTensor:newCachedGraph->biasTensor_
name:nil];
biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias);
feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData();
}
if (doReshape) {
outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:getMPSShape(output_size) name:nil];
}
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->weightTensor_ = weightTensor;
newCachedGraph->outputTensor_ = outputTensor;
});
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight);
Placeholder biasPlaceholder = Placeholder();
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData();
if (is_bias_defined) {
biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias);
feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData();
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
// Shave off '1' present at the end of the shape

View File

@@ -21,6 +21,10 @@ elif torch.xpu.is_available():
def timer() -> float:
torch.xpu.synchronize()
return timeit.default_timer()
elif torch.mps.is_available():
def timer() -> float:
torch.mps.synchronize()
return timeit.default_timer()
elif torch._C._get_privateuse1_backend_name() != "privateuseone":
privateuse1_device_handler = getattr(torch, torch._C._get_privateuse1_backend_name(), None) \
if torch._C._get_privateuse1_backend_name() != "cpu" else None