Compare commits

...

2 Commits

3 changed files with 217 additions and 4 deletions

View File

@@ -147,6 +147,7 @@ MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStre
MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar);
MPSGraph* make_mps_graph();
MPSKernel* make_mps_kernel();
void printTensorNDArray(const TensorBase& t);
MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape* shape, MPSDataType mpsType);
@@ -160,8 +161,22 @@ string get_mem_format_string(c10::MemoryFormat memory_format);
using MPSCacheKey = uint64_t;
// derive this class to cache a graph and its inputs/outputs
// can be used to store any NSObject
struct MPSCachedKernel {
  MPSCachedKernel(NSObject* object) : _object([object retain]) {}
  virtual ~MPSCachedKernel() {
    [_object release];
    _object = nullptr;
  }

  template <typename T>
  inline T* kernel() const {
    return (T*)_object;
  }

 private:
  NSObject* _object = nullptr;
};
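For illustration only (not part of this diff): MPSCachedKernel retains the object it is handed and releases it on destruction, and kernel<T>() hands back the concrete MPS kernel type. A minimal ownership sketch, assuming manual retain/release (as the struct itself uses) and an id<MTLDevice> named device already in scope:

// illustration only, not part of the change above
MPSNDArrayMatrixMultiplication* matmul =
    [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2]; // +1 (caller)
auto* cached = new MPSCachedKernel(matmul); // constructor retains: +2
[matmul release];                           // the cache is now the sole owner
auto* k = cached->kernel<MPSNDArrayMatrixMultiplication>(); // typed access, no ownership change
// ... encode k ...
delete cached;                              // destructor releases the kernel object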
struct MPSCachedGraph {
  MPSCachedGraph(NSObject* object) : _object([object retain]) {}
  virtual ~MPSCachedGraph() {
@@ -214,6 +229,101 @@ struct MPSBinaryGradCachedGraph : public MPSCachedGraph {
  MPSGraphTensor* gradInputTensor_ = nil;
};
struct MPSKernelCache {
  typedef MPSCachedKernel* (^CreateCachedKernelBlock)();

  struct CacheEntry {
    CacheEntry(const std::string& key, MPSCachedKernel* cachedKernel) : cachedKernel_(cachedKernel), key_(key) {}
    MPSCachedKernel* cachedKernel_ = nullptr;
    std::string key_;
  };

 public:
  static MPSKernelCache* getInstance() {
    if (_instance_cache == nullptr) {
      _instance_cache = new MPSKernelCache();
    }
    return _instance_cache;
  }

  ~MPSKernelCache() {
    dispatch_release(serialQueue_);
    for (const auto& i : cache_) {
      delete i.second.cachedKernel_;
    }
  }

  // Disallow the copy constructor and operator= functions
  MPSKernelCache(const MPSKernelCache&) = delete;
  void operator=(const MPSKernelCache&) = delete;

  MPSCachedKernel* CreateCachedKernel(const std::string& key, CreateCachedKernelBlock createCacheBlock) {
    __block MPSCachedKernel* cachedKernel = nil;
    MPSCacheKey hash = std::hash<std::string>{}(key);
    dispatch_sync_with_rethrow(serialQueue_, ^() {
      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS kernel cache!\n");
        cachedKernel = entry.cachedKernel_;
      } else {
        cachedKernel = createCacheBlock();
        CacheEntry entry(key, cachedKernel);
        cache_.emplace(hash, entry);
        profileCachedKernel(entry);
      }
    });
    return cachedKernel;
  }

  template <typename T>
  inline T* CreateCachedKernelAs(const std::string& key, CreateCachedKernelBlock createCacheBlock) {
    return static_cast<T*>(CreateCachedKernel(key, createCacheBlock));
  }

  MPSCachedKernel* LookUp(const std::string& key) const {
    __block MPSCachedKernel* cachedKernel = nil;
    MPSCacheKey hash = std::hash<std::string>{}(key);
    dispatch_sync_with_rethrow(serialQueue_, ^() {
      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS kernel cache!\n");
        cachedKernel = entry.cachedKernel_;
      }
    });
    return cachedKernel;
  }

  template <typename T>
  inline T* LookUpAs(const std::string& key) const {
    return static_cast<T*>(LookUp(key));
  }

 private:
  MPSKernelCache() {
    serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL);
  }

  void profileCachedKernel(const CacheEntry& cacheEntry) const;

  static MPSKernelCache* _instance_cache;
  std::unordered_map<MPSCacheKey, CacheEntry> cache_;
  dispatch_queue_t serialQueue_ = nullptr;
};
// Common template for creating a cached kernel if missing
template <typename T>
inline T* LookUpOrCreateCachedKernel(const std::string& key, std::function<MPSKernel*()> instantiate) {
  auto cache_ = MPSKernelCache::getInstance();
  if (auto rc = cache_->LookUpAs<T>(key)) {
    return rc;
  }
  return cache_->CreateCachedKernelAs<T>(key, ^mps::MPSCachedKernel*() {
    auto k_ = new mps::MPSCachedKernel(instantiate());
    return k_;
  });
}
// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
// Cache holding various keys mapped to graphs
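A usage sketch for the LookUpOrCreateCachedKernel helper above (illustration only; it mirrors how Linear.mm uses it later in this diff, with key and device assumed to be in scope):

auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>(key, [&]() {
  return [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2];
});
auto matmul = cachedKernel->kernel<MPSNDArrayMatrixMultiplication>();
// Subsequent calls with the same key return the cached kernel instead of creating a new one.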

View File

@@ -315,7 +315,7 @@ std::string getArrayRefString(const IntArrayRef s) {
 std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
   std::string str;
-  // The key format per tensor would look like ":Float32[1,1,1,10]:"
+  // The key format per tensor would look like ":Float32[1,1,1,10,]:"
   for (const Tensor& tensor : tensors) {
     str += ":";
     if (tensor.defined()) {
@@ -328,7 +328,7 @@ std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, boo
           str += "-1";
         } else {
-          str +=
-              std::string([[getMPSShape(tensor) valueForKey:@"description"] componentsJoinedByString:@","].UTF8String);
+          str += getArrayRefString(tensor.sizes());
         }
       }
       str += "]";
@@ -778,6 +778,19 @@ string get_mem_format_string(c10::MemoryFormat memory_format) {
MPSGraphCache* MPSGraphCache::_instance_cache = nullptr;

void MPSKernelCache::profileCachedKernel(const CacheEntry& cacheEntry) const {
  auto& profiler = getMPSProfiler();
  if (profiler.isOperationProfilingEnabled()) {
    std::string kernelKey = cacheEntry.key_;
    // for interval-based signpost tracing, we begin the interval here to be able
    // to measure the time it takes to create the kernel (if newly created),
    // and also the time potentially spent on gather/scatter of the kernel's input tensors
    // profiler.beginProfileKernel(cacheEntry.cachedKernel_->kernel(), kernelKey, true);
  }
}

MPSKernelCache* MPSKernelCache::_instance_cache = nullptr;

void MPSGraphCache::profileCachedGraph(const CacheEntry& cacheEntry) const {
  auto& profiler = getMPSProfiler();
  if (profiler.isOperationProfilingEnabled()) {

View File

@@ -4,12 +4,101 @@
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/ops/linear_backward_native.h>
#include <ATen/ops/linear_native.h>
#include <ATen/mps/MPSProfiler.h>
namespace at::native {
using namespace mps;
static Tensor _mps_linear_new(const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt) {
  if (input.scalar_type() == kComplexFloat || input.scalar_type() == kComplexHalf) {
    TORCH_CHECK(false, "mps linear does not support complex types");
  }
  TORCH_CHECK(supportedFloatingOrComplexType(input), "MPS device does not support linear for non-float inputs");
  TORCH_CHECK(input.is_mps(), "Tensor for argument input is on ", input.device(), " but expected on mps");
  TORCH_CHECK(supportedFloatingOrComplexType(weight), "MPS device does not support linear for non-float weights");
  TORCH_CHECK(weight.is_mps(), "Tensor for argument weight is on ", weight.device(), " but expected on mps");

  Tensor bias;
  bool has_bias = bias_opt.has_value();
  if (has_bias) {
    bias = bias_opt.value();
    TORCH_CHECK(bias.is_mps(), "Tensor for argument bias is on ", bias.device(), " but expected on mps");
    TORCH_CHECK(supportedFloatingOrComplexType(bias), "MPS device does not support linear for non-float bias");
  }

  if (input.numel() == 0 || weight.numel() == 0) {
    auto input_size = input.sizes();
    std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
    output_size.push_back(weight.size(0));
    Tensor output =
        at::empty(output_size, input.scalar_type(), std::nullopt, kMPS, std::nullopt, input.suggest_memory_format());
    return output;
  }

  auto input_size = input.sizes();
  std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
  output_size.push_back(weight.size(0));
  Tensor output =
      at::empty(output_size, input.scalar_type(), std::nullopt, kMPS, std::nullopt, input.suggest_memory_format());

  MPSStream* mpsStream = getCurrentMPSStream();
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
  const string key = "mps_linear" + getTensorsStringKey({input, weight, bias}, true);

  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
    @autoreleasepool {
      mpsStream->endKernelCoalescing();

      MPSDataType mpsDataType = getMPSDataType(weight.scalar_type());
      auto inputNDArray = getMPSNDArray(input, input.sizes(), input.strides());
      auto outNDArray = getMPSNDArray(output, output.sizes(), output.strides());

      id<MTLBuffer> weightBuf = getMTLBufferStorage(weight);
      MPSNDArrayDescriptor* weightTensorDesc =
          [MPSNDArrayDescriptor descriptorWithDataType:mpsDataType shape:getMPSShape(weight.sizes())];
      weightTensorDesc.preferPackedRows = YES;
      [weightTensorDesc transposeDimension:0 withDimension:1];
      MPSNDArray* weightNDArray = [[MPSNDArray alloc] initWithBuffer:weightBuf
                                                              offset:weight.storage_offset() * weight.element_size()
                                                          descriptor:weightTensorDesc];

      id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
      if (has_bias) {
        auto biasNDArray = getMPSNDArray(bias, bias.sizes(), bias.strides());
        auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>(key, [&]() {
          return [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:3];
        });
        auto kernel = cachedKernel->kernel<MPSNDArrayMatrixMultiplication>();
        getMPSProfiler().beginProfileKernel(kernel, "mps_linear", {input, weight, bias});
        [kernel encodeToCommandEncoder:computeEncoder
                         commandBuffer:commandBuffer
                          sourceArrays:@[ inputNDArray, weightNDArray, biasNDArray ]
                      destinationArray:outNDArray];
        getMPSProfiler().endProfileKernel(kernel);
      } else {
        auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>(key, [&]() {
          return [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2];
        });
        auto kernel = cachedKernel->kernel<MPSNDArrayMatrixMultiplication>();
        getMPSProfiler().beginProfileKernel(kernel, "mps_linear", {input, weight, bias});
        [kernel encodeToCommandEncoder:computeEncoder
                         commandBuffer:commandBuffer
                          sourceArrays:@[ inputNDArray, weightNDArray ]
                      destinationArray:outNDArray];
        getMPSProfiler().endProfileKernel(kernel);
      }
    }
  });
  return output;
}
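For reference (not part of the diff): this kernel path computes the same y = x * wT + b as the graph path below, with weight of shape [out_features, in_features]; the transpose is expressed by swapping dimensions 0 and 1 on the weight's MPSNDArrayDescriptor above rather than by materializing a transposed tensor.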
Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::optional<Tensor>& bias_opt) {
  return _mps_linear_new(input, weight_arg, bias_opt);

  // wT = transpose(weight);
  // y = x * wT + b
@@ -67,6 +156,7 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::opt
  @autoreleasepool {
    string key = "mps_linear" + getTensorsStringKey({input, weight, bias});
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
      MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);