[BE][MPS] Apply clang-format to mps headers (#140906)
It was a mistake to omit them in the past. All changes in this PR except the ones to .lintrunner.toml were generated by running `lintrunner -a --take CLANGFORMAT --all-files`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140906 Approved by: https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: 5a7e147ef3
Commit: 99014a297c
.lintrunner.toml
@@ -56,10 +56,12 @@ code = 'CLANGFORMAT'
include_patterns = [
    'aten/src/ATen/*.h',
    'aten/src/ATen/mps/**/*.mm',
    'aten/src/ATen/mps/**/*.h',
    'aten/src/ATen/xpu/**/*.h',
    'aten/src/ATen/xpu/**/*.cpp',
    'aten/src/ATen/native/mps/**/*.metal',
    'aten/src/ATen/native/mps/**/*.mm',
    'aten/src/ATen/native/mps/**/*.h',
    'aten/src/ATen/native/vulkan/**/*.h',
    'aten/src/ATen/native/vulkan/**/*.cpp',
    'aten/src/ATen/native/cuda/MultiTensorApply.cuh',
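The hunks below are mechanical clang-format output. As a rough, hypothetical illustration of the conventions the CLANGFORMAT linter enforces on these headers -- pointer and reference qualifiers attached to the type, constructor initializer lists broken after the colon, and one-line accessors expanded -- consider this small fragment; it is not part of the commit, and the names in it are made up:

// before formatting:
//   static const char *kKey = "scatter";
//   Example(size_t Size, void *Buffer = nullptr) : buffer(Buffer), size(Size) { }
//   size_t count() const { return size; }
// after clang-format with the repository's style:
static const char* kKey = "scatter";

struct Example {
  Example(size_t Size, void* Buffer = nullptr)
      : buffer(Buffer), size(Size) {}
  size_t count() const {
    return size;
  }
  void* buffer;
  size_t size;
};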
@@ -12,8 +12,7 @@ C10_EXPORT TensorBase empty_mps(
    std::optional<Device> device_opt,
    std::optional<bool> pin_memory_opt,
    std::optional<c10::MemoryFormat> memory_format_opt);
C10_EXPORT TensorBase empty_mps(
    IntArrayRef size, const TensorOptions &options);
C10_EXPORT TensorBase empty_mps(IntArrayRef size, const TensorOptions& options);

C10_EXPORT TensorBase empty_strided_mps(
    IntArrayRef size,
@@ -24,6 +23,6 @@ C10_EXPORT TensorBase empty_strided_mps(
C10_EXPORT TensorBase empty_strided_mps(
    IntArrayRef size,
    IntArrayRef stride,
    const TensorOptions &options);
    const TensorOptions& options);

} // namespace at::detail
@@ -2,7 +2,7 @@

namespace at::mps {

static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
static const char* SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
struct __attribute__ ((packed)) packed_uint5{{
  uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};
@@ -120,7 +120,7 @@ kernel void scatter_kernel_1(uint linear_index [[thread_position_in
}}
)METAL_SCATTER";

static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
static const char* GATHER_OPS_TEMPLATE = R"METAL_GATHER(
struct __attribute__ ((packed)) packed_uint5{{
  uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};
@@ -6,12 +6,12 @@
#include <ATen/mps/MPSEvent.h>
#include <ATen/mps/MPSStream.h>

#include <c10/util/flat_hash_map.h>
#include <mach/vm_page_size.h>
#include <cstdio>
#include <mutex>
#include <set>
#include <unordered_set>
#include <mach/vm_page_size.h>
#include <c10/util/flat_hash_map.h>

// this implementation is based on CUDACachingAllocator.
// It utilizes Metal Heaps to improve the performance with buffer allocation.
@@ -19,32 +19,34 @@
// TODO: Unify the logic with CUDACachingAllocator and remove redundant code.
namespace at::mps::HeapAllocator {

static const size_t kMaxSmallAlloc = MB(1); // largest "small" allocation is 1 MiB
static const size_t kMinLargeAlloc = MB(10); // allocations between 1 and 10 MiB may use kLargeHeap
static const size_t kRoundLarge = MB(2); // round up large allocations to 2 MiB
static const size_t kSmallHeap = MB(8); // "small" allocations are packed in 8 MiB heaps
static const size_t kLargeHeap = MB(32); // "large" allocations may be packed in 32 MiB heaps
static const size_t kXLargeHeapD = MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
static const size_t kXLargeHeapU = MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
static const size_t kMaxSmallAlloc = MB(1); // largest "small" allocation is 1 MiB
static const size_t kMinLargeAlloc = MB(10); // allocations between 1 and 10 MiB may use kLargeHeap
static const size_t kRoundLarge = MB(2); // round up large allocations to 2 MiB
static const size_t kSmallHeap = MB(8); // "small" allocations are packed in 8 MiB heaps
static const size_t kLargeHeap = MB(32); // "large" allocations may be packed in 32 MiB heaps
static const size_t kXLargeHeapD =
    MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
static const size_t kXLargeHeapU =
    MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
static const size_t kMaxScalarAlloc = (sizeof(int64_t)); // largest "scalar" allocation

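To make the size classes above concrete, here is a small self-contained sketch of how a requested allocation could map to a heap size under the commented policy. It is illustrative only and not the allocator's actual code (the real logic lives in the MPSAllocator implementation); pick_heap_size and its rounding rule are assumptions made for the example.

#include <cstddef>

namespace heap_sketch {
constexpr size_t MiB(size_t x) { return x * 1048576UL; }
constexpr size_t kMaxSmallAlloc = MiB(1);
constexpr size_t kMinLargeAlloc = MiB(10);
constexpr size_t kRoundLarge = MiB(2);
constexpr size_t kSmallHeap = MiB(8);
constexpr size_t kLargeHeap = MiB(32);
constexpr size_t kXLargeHeapD = MiB(128);
constexpr size_t kXLargeHeapU = MiB(1024);

// Hypothetical mapping from a requested size to the heap size it would live in.
inline size_t pick_heap_size(size_t requested, bool has_unified_memory) {
  if (requested <= kMaxSmallAlloc) {
    return kSmallHeap; // small buffers are packed into 8 MiB heaps
  }
  if (requested < kMinLargeAlloc) {
    return kLargeHeap; // 1-10 MiB buffers may share 32 MiB heaps
  }
  // round large requests up to a 2 MiB multiple
  const size_t rounded = ((requested + kRoundLarge - 1) / kRoundLarge) * kRoundLarge;
  const size_t xlarge = has_unified_memory ? kXLargeHeapU : kXLargeHeapD;
  return rounded <= xlarge ? xlarge : rounded; // oversized requests get a dedicated heap
}
} // namespace heap_sketch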
// buffer pools could be customized with a combination of usage flags
|
||||
enum UsageFlags : uint32_t {
|
||||
PRIVATE = 0,
|
||||
SMALL = (1 << 0), // small heaps have sizes of kSmallHeap, and large ones kLargeHeap
|
||||
SHARED = (1 << 1), // shared pools allocated on devices with unified memory; otherwise, private between host/device
|
||||
SMALL = (1 << 0), // small heaps have sizes of kSmallHeap, and large ones kLargeHeap
|
||||
SHARED = (1 << 1), // shared pools allocated on devices with unified memory; otherwise, private between host/device
|
||||
MANAGED = (1 << 2), // managed storage mode
|
||||
HAZARD = (1 << 3), // enables Automatic Hazard Tracking for the resources allocated on the pool
|
||||
SCALAR = (1 << 4), // used to import CPU scalar values to GPU and use them in MPS Stream
|
||||
HAZARD = (1 << 3), // enables Automatic Hazard Tracking for the resources allocated on the pool
|
||||
SCALAR = (1 << 4), // used to import CPU scalar values to GPU and use them in MPS Stream
|
||||
};
|
||||
// debug verbosity flags
|
||||
enum DebugVerbosity : uint32_t {
|
||||
SILENT = 0,
|
||||
PROFILING = (1 << 0), // print generic profiling data for total system memory usage
|
||||
SILENT = 0,
|
||||
PROFILING = (1 << 0), // print generic profiling data for total system memory usage
|
||||
ALLOCATIONS = (1 << 1), // print buffer allocations
|
||||
RECYCLES = (1 << 2), // print buffer recycling
|
||||
RELEASES = (1 << 3), // print buffer releases
|
||||
LARGE_ONLY = (1 << 4), // only log large buffer pool transactions
|
||||
RECYCLES = (1 << 2), // print buffer recycling
|
||||
RELEASES = (1 << 3), // print buffer releases
|
||||
LARGE_ONLY = (1 << 4), // only log large buffer pool transactions
|
||||
};
|
||||
|
||||
struct HeapBlock;
|
||||
@ -67,10 +69,8 @@ struct BufferBlock {
|
||||
// Metal events used to sync GPU/CPU operations on the shared-storage buffers
|
||||
MPSEventPtr event;
|
||||
|
||||
BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr,
|
||||
HeapBlock* Heap = nullptr) :
|
||||
buffer(Buffer), size(Size), requested_size(RequestedSize),
|
||||
heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) { }
|
||||
BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr, HeapBlock* Heap = nullptr)
|
||||
: buffer(Buffer), size(Size), requested_size(RequestedSize), heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) {}
|
||||
|
||||
static bool Comparator(const BufferBlock* a, const BufferBlock* b) {
|
||||
return (a->size != b->size) ? a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer;
|
||||
@ -79,15 +79,19 @@ struct BufferBlock {
|
||||
assert(((Alignment - 1) & Alignment) == 0);
|
||||
return ((Size + Alignment - 1) & ~(Alignment - 1));
|
||||
}
|
||||
uint32_t retainCount() const { return [buffer retainCount]; }
|
||||
uint32_t retainCount() const {
|
||||
return [buffer retainCount];
|
||||
}
|
||||
};
|
||||
typedef bool (*BufferComparison)(const BufferBlock*, const BufferBlock*);
|
||||
|
||||
struct BufferPool;
|
||||
struct AllocParams {
|
||||
AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool) :
|
||||
search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) { }
|
||||
size_t size() const { return search_key.size; }
|
||||
AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool)
|
||||
: search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) {}
|
||||
size_t size() const {
|
||||
return search_key.size;
|
||||
}
|
||||
|
||||
BufferBlock search_key;
|
||||
BufferPool* pool;
|
||||
@ -102,7 +106,9 @@ struct AllocParams {
|
||||
|
||||
struct HeapBlock {
|
||||
id<MTLHeap> heap;
|
||||
struct { size_t total, available; } size;
|
||||
struct {
|
||||
size_t total, available;
|
||||
} size;
|
||||
BufferPool* pool;
|
||||
unsigned int n_buffers = 0;
|
||||
id_t heap_id;
|
||||
@ -111,9 +117,12 @@ struct HeapBlock {
|
||||
// counter to assign unique ids to heap blocks
|
||||
static uint64_t heap_counter;
|
||||
|
||||
HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool *Pool = nullptr) :
|
||||
heap(Heap), size({.total = Size, .available = Size}), pool(Pool),
|
||||
heap_id(Heap ? ++heap_counter : 0), is_split(true) { }
|
||||
HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool* Pool = nullptr)
|
||||
: heap(Heap),
|
||||
size({.total = Size, .available = Size}),
|
||||
pool(Pool),
|
||||
heap_id(Heap ? ++heap_counter : 0),
|
||||
is_split(true) {}
|
||||
|
||||
static MTLResourceOptions getOptions(uint32_t usage) {
|
||||
// TODO: check the caching performance of write-combined mode
|
||||
@ -126,16 +135,17 @@ struct HeapBlock {
|
||||
else
|
||||
options |= MTLResourceStorageModePrivate;
|
||||
|
||||
options |= (usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;
|
||||
options |=
|
||||
(usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;
|
||||
|
||||
return options;
|
||||
}
|
||||
|
||||
static HeapBlock* createHeapBlock(AllocParams& params, id<MTLDevice> device, uint32_t usage) {
|
||||
HeapBlock *heapBlock = nullptr;
|
||||
HeapBlock* heapBlock = nullptr;
|
||||
bool is_split = true;
|
||||
const size_t size = params.size();
|
||||
MTLHeapDescriptor *d = [MTLHeapDescriptor new];
|
||||
MTLHeapDescriptor* d = [MTLHeapDescriptor new];
|
||||
if (d) {
|
||||
const size_t kXLargeHeap = params.has_unified_memory ? kXLargeHeapU : kXLargeHeapD;
|
||||
if (size <= kMaxSmallAlloc) {
|
||||
@ -152,10 +162,11 @@ struct HeapBlock {
|
||||
d.cpuCacheMode = MTLCPUCacheModeDefaultCache;
|
||||
// this automatically handles Metal buffer access synchronizations at the
|
||||
// cost of slightly lower performance.
|
||||
d.hazardTrackingMode = (usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
|
||||
d.hazardTrackingMode =
|
||||
(usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
|
||||
d.resourceOptions = getOptions(usage);
|
||||
d.type = MTLHeapTypeAutomatic;
|
||||
id<MTLHeap> heap = [device newHeapWithDescriptor: d];
|
||||
id<MTLHeap> heap = [device newHeapWithDescriptor:d];
|
||||
if (heap) {
|
||||
[heap setPurgeableState:MTLPurgeableStateNonVolatile];
|
||||
const size_t heap_size = heapAvailableSize(heap);
|
||||
@ -169,8 +180,8 @@ struct HeapBlock {
|
||||
return heapBlock;
|
||||
}
|
||||
static bool Comparator(const HeapBlock* a, const HeapBlock* b) {
|
||||
return (a->size.available != b->size.available) ? a->size.available < b->size.available :
|
||||
(uintptr_t)a->heap < (uintptr_t)b->heap;
|
||||
return (a->size.available != b->size.available) ? a->size.available < b->size.available
|
||||
: (uintptr_t)a->heap < (uintptr_t)b->heap;
|
||||
}
|
||||
static NSUInteger heapAvailableSize(id<MTLHeap> heap, size_t Alignment = vm_page_size) {
|
||||
return [heap maxAvailableSizeWithAlignment:Alignment];
|
||||
@ -205,8 +216,12 @@ struct HeapBlock {
|
||||
size.available = 0;
|
||||
return retainCount;
|
||||
}
|
||||
uint32_t retainCount() const { return [heap retainCount]; }
|
||||
void updateAvailableSize() { size.available = heapAvailableSize(heap); }
|
||||
uint32_t retainCount() const {
|
||||
return [heap retainCount];
|
||||
}
|
||||
void updateAvailableSize() {
|
||||
size.available = heapAvailableSize(heap);
|
||||
}
|
||||
};
|
||||
typedef bool (*HeapComparison)(const HeapBlock*, const HeapBlock*);
|
||||
|
||||
@ -219,9 +234,8 @@ struct BufferPool {
|
||||
SCALAR,
|
||||
};
|
||||
|
||||
BufferPool(const id<MTLDevice> Device, uint32_t Usage) :
|
||||
device(Device), usage(Usage),
|
||||
heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) { }
|
||||
BufferPool(const id<MTLDevice> Device, uint32_t Usage)
|
||||
: device(Device), usage(Usage), heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) {}
|
||||
|
||||
const id<MTLDevice> device;
|
||||
// usage flags to customize the pool for various purposes (see UsageFlags enum)
|
||||
@ -248,12 +262,12 @@ struct BufferPool {
|
||||
};
|
||||
|
||||
class MPSHeapAllocatorImpl {
|
||||
public:
|
||||
explicit MPSHeapAllocatorImpl() :
|
||||
m_device(at::mps::MPSDevice::getInstance()->device()),
|
||||
m_max_buffer_size([m_device maxBufferLength]),
|
||||
m_stream(getDefaultMPSStream()),
|
||||
m_event_pool(getMPSEventPool()) {
|
||||
public:
|
||||
explicit MPSHeapAllocatorImpl()
|
||||
: m_device(at::mps::MPSDevice::getInstance()->device()),
|
||||
m_max_buffer_size([m_device maxBufferLength]),
|
||||
m_stream(getDefaultMPSStream()),
|
||||
m_event_pool(getMPSEventPool()) {
|
||||
init_allocator();
|
||||
}
|
||||
~MPSHeapAllocatorImpl() {
|
||||
@ -298,34 +312,50 @@ public:
|
||||
// (see m_high_watermark_ratio for description)
|
||||
void setHighWatermarkRatio(double ratio);
|
||||
// (see m_low_watermark_limit for description)
|
||||
size_t getLowWatermarkLimit() const { return m_low_watermark_limit; }
|
||||
size_t getLowWatermarkLimit() const {
|
||||
return m_low_watermark_limit;
|
||||
}
|
||||
// (see m_max_total_allowed_size for description)
|
||||
size_t getHighWatermarkLimit() const { return m_max_total_allowed_size; }
|
||||
size_t getHighWatermarkLimit() const {
|
||||
return m_max_total_allowed_size;
|
||||
}
|
||||
// (see m_total_allocated_memory for description)
|
||||
size_t getTotalAllocatedMemory() const { return m_total_allocated_memory; }
|
||||
size_t getTotalAllocatedMemory() const {
|
||||
return m_total_allocated_memory;
|
||||
}
|
||||
// (see m_current_allocated_memory for description)
|
||||
size_t getCurrentAllocatedMemory() const { return m_current_allocated_memory; }
|
||||
size_t getCurrentAllocatedMemory() const {
|
||||
return m_current_allocated_memory;
|
||||
}
|
||||
// total GPU memory allocated in the process by Metal driver; including
|
||||
// implicit allocations from MPS/MPSGraph frameworks and MPSHeapAllocatorImpl.
|
||||
size_t getDriverAllocatedMemory() const { return current_allocated_size(); }
|
||||
size_t getDriverAllocatedMemory() const {
|
||||
return current_allocated_size();
|
||||
}
|
||||
// recommended Max memory for Metal
|
||||
size_t getRecommendedMaxMemory() const { return max_device_size(); }
|
||||
size_t getRecommendedMaxMemory() const {
|
||||
return max_device_size();
|
||||
}
|
||||
// (see enum DebugVerbosity for description)
|
||||
uint32_t getDebugVerbosity() const { return m_debug_verbosity; }
|
||||
uint32_t getDebugVerbosity() const {
|
||||
return m_debug_verbosity;
|
||||
}
|
||||
// returns the device that we allocate from
|
||||
inline id<MTLDevice> Device() const { return m_device; }
|
||||
inline id<MTLDevice> Device() const {
|
||||
return m_device;
|
||||
}
|
||||
|
||||
// TODO: make a common function to do size unit conversions in PyTorch.
|
||||
inline std::string format_size(uint64_t size) const;
|
||||
|
||||
private:
|
||||
private:
|
||||
// (see m_high_watermark_ratio for description)
|
||||
constexpr static double default_high_watermark_ratio = 1.7;
|
||||
// we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize.
|
||||
constexpr static double default_high_watermark_upper_bound = 2.0;
|
||||
// (see m_low_watermark_ratio for description)
|
||||
// on unified memory, we could allocate beyond the recommendedMaxWorkingSetSize
|
||||
constexpr static double default_low_watermark_ratio_unified = 1.4;
|
||||
constexpr static double default_low_watermark_ratio_unified = 1.4;
|
||||
constexpr static double default_low_watermark_ratio_discrete = 1.0;
|
||||
|
||||
const id<MTLDevice> m_device;
|
||||
@ -387,14 +417,19 @@ private:
|
||||
size_t get_allocation_size(size_t size, uint32_t usage) const;
|
||||
// maximum size of device memory available for allocation in current process
|
||||
// Note: the recommendedMaxWorkingSetSize is typically 75% of the total system memory.
|
||||
size_t max_device_size() const { return [m_device recommendedMaxWorkingSetSize]; }
|
||||
size_t max_device_size() const {
|
||||
return [m_device recommendedMaxWorkingSetSize];
|
||||
}
|
||||
// there are implicit allocations from MPS backend, so we need to query the 'device' for
|
||||
// total allocated size instead of manually tracking in MPSAllocator
|
||||
size_t current_allocated_size() const { return [m_device currentAllocatedSize]; }
|
||||
size_t current_allocated_size() const {
|
||||
return [m_device currentAllocatedSize];
|
||||
}
|
||||
|
||||
bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) const {
|
||||
for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
|
||||
MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block ? buffer_block->buffer : nullptr, event);
|
||||
MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(
|
||||
buffer_block ? buffer_block->buffer : nullptr, event);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@@ -2,9 +2,9 @@

#pragma once

#include <ATen/core/ATen_fwd.h>
#include <c10/core/Allocator.h>
#include <c10/util/Registry.h>
#include <ATen/core/ATen_fwd.h>

#define MB(x) (x * 1048576UL)

@@ -13,17 +13,19 @@ namespace at::mps {
// this is a public interface to access MPSAllocator.
// Do not declare methods that would depend on MPS or Metal frameworks.
class IMPSAllocator : public c10::Allocator {
public:
 public:
  // see the comments in MPSAllocator.h for the description of these methods.
  virtual void emptyCache() const = 0;
  virtual void freeInactiveBuffers() const = 0;
  virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0;
  virtual IntArrayRef getBufferShape(const void* ptr) const = 0;
  virtual id_t getBufferId(const void* ptr) const = 0;
  virtual void setBufferShape(const void* ptr, const IntArrayRef& shape) const = 0;
  virtual void setBufferShape(const void* ptr, const IntArrayRef& shape)
      const = 0;
  virtual bool isSharedBuffer(const void* ptr) const = 0;
  virtual bool isSharedStorageSupported() const = 0;
  virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0;
  virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size)
      const = 0;
  virtual std::string formatSize(size_t size) const = 0;
  virtual void setLowWatermarkRatio(double ratio) const = 0;
  virtual void setHighWatermarkRatio(double ratio) const = 0;
@@ -34,7 +36,8 @@ public:
  virtual size_t getCurrentAllocatedMemory() const = 0;
  virtual size_t getDriverAllocatedMemory() const = 0;
  virtual size_t getRecommendedMaxMemory() const = 0;
  virtual std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const = 0;
  virtual std::pair<const void*, uint32_t> getSharedBufferPtr(
      const void* ptr) const = 0;
  virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0;
  virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0;
};
@@ -43,16 +46,17 @@ class IMpsAllocatorCallback {
 public:
  enum class EventType {
    ALLOCATED, // buffer got allocated to be used immediately
    RECYCLED, // buffer pulled from free list to be reused
    FREED, // buffer put to free list for future recycling
    RELEASED, // buffer memory released
    RECYCLED, // buffer pulled from free list to be reused
    FREED, // buffer put to free list for future recycling
    RELEASED, // buffer memory released
    ALLOCATION_FAILED // buffer allocation failed
  };
  virtual ~IMpsAllocatorCallback() = default;
  virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
};

// MPS allocator will execute every registered callback when a block of memory is freed.
// MPS allocator will execute every registered callback when a block of memory
// is freed.
TORCH_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
  C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__)
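As a hypothetical usage sketch of the callback registry declared above (the class name, registration key, and logging behavior are illustrative assumptions, not part of PyTorch):

#include <ATen/mps/MPSAllocatorInterface.h>
#include <iostream>

namespace at::mps {

// A toy callback; the MPS allocator invokes every registered callback
// whenever a buffer changes state (allocated, recycled, freed, released).
struct LoggingAllocatorCallback : public IMpsAllocatorCallback {
  void executeMPSAllocatorCallback(void* ptr, EventType event) override {
    if (event == EventType::RELEASED) {
      std::cout << "MPS buffer released: " << ptr << "\n";
    }
  }
};

// registers the class under the (hypothetical) key "logging_callback"
REGISTER_MPS_ALLOCATOR_CALLBACK(logging_callback, LoggingAllocatorCallback);

} // namespace at::mps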
@@ -5,7 +5,6 @@
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
@@ -11,7 +11,7 @@ namespace at::mps {
// NOTE: don't create instances of this class directly.
// Use MPSEventPool to acquire instances of MPSEvent.
class MPSEvent {
public:
 public:
  explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
  ~MPSEvent();

@@ -26,16 +26,21 @@ public:
  // blocks the CPU thread until all the GPU work that were scheduled
  // prior to recording this event are completed.
  bool synchronize();
  // resets this event with new parameters in case it gets reused from the event pool
  // resets this event with new parameters in case it gets reused from the event
  // pool
  void reset(MPSStream* stream, bool enable_timing);
  // returns the unique ID of the event instance
  id_t getID() const { return m_id; }
  id_t getID() const {
    return m_id;
  }
  // returns the completion timestamp of the event
  uint64_t getCompletionTime() const { return m_completion_time; }
  uint64_t getCompletionTime() const {
    return m_completion_time;
  }
  // if already recorded, waits for cpu_sync_cv to be signaled
  void waitForCpuSync();

private:
 private:
  id_t m_id;
  // enables measuring the completion time of the notifyListener of this event
  bool m_enable_timing;
@@ -63,7 +68,7 @@ private:
typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;

class MPSEventPool {
public:
 public:
  explicit MPSEventPool(MPSStream* default_stream);
  ~MPSEventPool();

@@ -80,7 +85,7 @@ public:
  // returns elapsed time between two recorded events in milliseconds
  double elapsedTime(id_t start_event_id, id_t end_event_id);

private:
 private:
  MPSStream* m_default_stream = nullptr;
  std::recursive_mutex m_mutex;
  std::stack<std::unique_ptr<MPSEvent>> m_pool{};
@@ -17,7 +17,8 @@ struct rng_data_pod {
};

TORCH_API const Generator& getDefaultMPSGenerator();
TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
TORCH_API Generator
createMPSGenerator(uint64_t seed_val = default_rng_seed_val);

} // namespace mps::detail

@@ -37,12 +38,20 @@ struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
  void update_philox_counters();

  void set_engine(at::Philox4_32 engine) { engine_ = engine; }
  at::Philox4_32 engine() { return engine_; }
  uint32_t* state_data() { return data_.state.data(); }
  static DeviceType device_type() { return DeviceType::MPS; }
  void set_engine(at::Philox4_32 engine) {
    engine_ = engine;
  }
  at::Philox4_32 engine() {
    return engine_;
  }
  uint32_t* state_data() {
    return data_.state.data();
  }
  static DeviceType device_type() {
    return DeviceType::MPS;
  }

private:
 private:
  mps::detail::rng_data_pod data_;
  at::Philox4_32 engine_;

@@ -1,12 +1,12 @@
// Copyright © 2022 Apple Inc.

#pragma once
#include <ATen/Context.h>
#include <ATen/mps/MPSEvent.h>
#include <ATen/mps/MPSStream.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <ATen/Context.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSEvent.h>

#ifdef __OBJC__
#include <Foundation/Foundation.h>
@ -18,11 +18,10 @@
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/core/Storage.h>
|
||||
#include <c10/core/TensorImpl.h>
|
||||
#include <sys/_types/_size_t.h>
|
||||
#include <memory>
|
||||
#include <c10/core/UndefinedTensorImpl.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
|
||||
#include <sys/_types/_size_t.h>
|
||||
#include <memory>
|
||||
|
||||
namespace at::mps {
|
||||
|
||||
@ -30,7 +29,8 @@ typedef MPSEvent* mpsEvent_t;
|
||||
|
||||
// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
|
||||
// https://github.com/pytorch/pytorch/issues/77170
|
||||
struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
struct TORCH_API MPSGuardImpl final
|
||||
: public c10::impl::DeviceGuardImplInterface {
|
||||
static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;
|
||||
|
||||
// constructor
|
||||
@ -83,7 +83,7 @@ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface
|
||||
}
|
||||
DeviceIndex deviceCount() const noexcept override {
|
||||
if (at::hasMPS()) {
|
||||
//TODO: extend it for multi-device case
|
||||
// TODO: extend it for multi-device case
|
||||
return 1;
|
||||
} else {
|
||||
return 0;
|
||||
@ -91,28 +91,22 @@ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface
|
||||
}
|
||||
|
||||
// Event-related functions
|
||||
void createEvent(
|
||||
mpsEvent_t* event,
|
||||
const EventFlag flag) const;
|
||||
void createEvent(mpsEvent_t* event, const EventFlag flag) const;
|
||||
|
||||
void destroyEvent(
|
||||
void* event,
|
||||
const DeviceIndex device_index) const noexcept override;
|
||||
void destroyEvent(void* event, const DeviceIndex device_index)
|
||||
const noexcept override;
|
||||
|
||||
void record(
|
||||
void** event,
|
||||
const Stream& stream,
|
||||
const DeviceIndex device_index,
|
||||
const EventFlag flag) const override;
|
||||
void** event,
|
||||
const Stream& stream,
|
||||
const DeviceIndex device_index,
|
||||
const EventFlag flag) const override;
|
||||
|
||||
void block(
|
||||
void* event,
|
||||
const Stream& stream) const override;
|
||||
void block(void* event, const Stream& stream) const override;
|
||||
|
||||
bool queryEvent(void* event) const override;
|
||||
|
||||
void synchronizeDevice(const DeviceIndex device_index) const override;
|
||||
|
||||
};
|
||||
|
||||
/// A variant of OptionalDeviceGuard that is specialized for MPS.
|
||||
@ -175,7 +169,6 @@ struct OptionalMPSGuard {
|
||||
c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
|
||||
};
|
||||
|
||||
|
||||
C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl)
|
||||
|
||||
} // namespace at::mps
|
||||
|
@@ -2,8 +2,8 @@

#pragma once

#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/Generator.h>
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/mps/MPSEvent.h>
#include <optional>

@@ -38,7 +38,8 @@ struct MPSHooks : public at::MPSHooksInterface {
  Allocator* getPinnedMemoryAllocator() const override;

  // MPSProfiler interface
  void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override;
  void profilerStartTrace(const std::string& mode, bool waitUntilCompleted)
      const override;
  void profilerStopTrace() const override;

  // MPSEvent interface
@@ -48,7 +49,8 @@ struct MPSHooks : public at::MPSHooksInterface {
  void waitForEvent(uint32_t event_id) const override;
  void synchronizeEvent(uint32_t event_id) const override;
  bool queryEvent(uint32_t event_id) const override;
  double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;
  double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id)
      const override;

  // Compatibility with Accelerator API
  bool hasPrimaryContext(DeviceIndex device_index) const override {
@@ -3,11 +3,11 @@
#pragma once

#include <ATen/Tensor.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSStream.h>

#include <os/signpost.h>
#include <os/log.h>
#include <os/signpost.h>

#include <atomic>
#include <ctime>
@@ -29,8 +29,8 @@ struct BaseInfo {
    CPU_FALLBACK,
  };

  BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle) :
      type(infoType), profileId(Id), handle(Handle) { }
  BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle)
      : type(infoType), profileId(Id), handle(Handle) {}
  virtual ~BaseInfo() = default;

  // type of profiling info
@ -41,30 +41,36 @@ struct BaseInfo {
|
||||
// since it's possible to use event and interval-based signposts at the
|
||||
// same time, we need separate IDs for each.
|
||||
os_signpost_id_t eventSignpostId = 0, intervalSignpostId = 0;
|
||||
// accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime - GPUStartTime")
|
||||
// accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime -
|
||||
// GPUStartTime")
|
||||
std::atomic<double> totalGpuTime{0.0};
|
||||
// accumulated Scheduling time in ms (obtained from CompletionHandler's "KernelEndTime - KernelStartTime")
|
||||
// accumulated Scheduling time in ms (obtained from CompletionHandler's
|
||||
// "KernelEndTime - KernelStartTime")
|
||||
std::atomic<double> totalSchedulingTime{0.0};
|
||||
// indicates if the operation or copy execution has completed
|
||||
std::atomic_bool completed{false};
|
||||
// handle used to identify the profile info's instance (usually the pointer)
|
||||
const uintptr_t handle;
|
||||
|
||||
virtual const std::string toString(double gpuTime = 0, double schedulingTime = 0) const;
|
||||
virtual const std::string toString(
|
||||
double gpuTime = 0,
|
||||
double schedulingTime = 0) const;
|
||||
// builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
|
||||
static std::string buildTensorString(const Tensor& tensor, bool includeBufferId = false) {
|
||||
static std::string buildTensorString(
|
||||
const Tensor& tensor,
|
||||
bool includeBufferId = false) {
|
||||
if (tensor.defined()) {
|
||||
std::stringstream tensorStr;
|
||||
auto deviceType = tensor.device().type();
|
||||
tensorStr << c10::DeviceTypeName(deviceType);
|
||||
// see comments for INCLUDE_BUFFER_ID
|
||||
if (includeBufferId && deviceType == at::kMPS) {
|
||||
id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
||||
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer))
|
||||
<< ":" << buffer.retainCount << ")";
|
||||
id<MTLBuffer> buffer =
|
||||
__builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
||||
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ":"
|
||||
<< buffer.retainCount << ")";
|
||||
}
|
||||
tensorStr << ":"
|
||||
<< tensor.scalar_type() << tensor.sizes();
|
||||
tensorStr << ":" << tensor.scalar_type() << tensor.sizes();
|
||||
return tensorStr.str();
|
||||
} else {
|
||||
return "undefined";
|
||||
@ -76,21 +82,28 @@ struct BaseInfo {
|
||||
};
|
||||
|
||||
struct OperationInfo : BaseInfo {
|
||||
OperationInfo(const void* Handle, bool IsGraph, uint64_t Id, const std::string& StrKey) :
|
||||
BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)), strKey(StrKey) { }
|
||||
OperationInfo(
|
||||
const void* Handle,
|
||||
bool IsGraph,
|
||||
uint64_t Id,
|
||||
const std::string& StrKey)
|
||||
: BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)),
|
||||
strKey(StrKey) {}
|
||||
|
||||
uint64_t runCount = 0;
|
||||
std::string strKey;
|
||||
|
||||
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
|
||||
const std::string toString(double gpuTime = 0, double schedulingTime = 0)
|
||||
const override;
|
||||
|
||||
// builds a string for a kernel
|
||||
static std::string buildKernelString(const std::string& kernelName,
|
||||
const TensorList& tensors,
|
||||
bool includeBufferId = false) {
|
||||
static std::string buildKernelString(
|
||||
const std::string& kernelName,
|
||||
const TensorList& tensors,
|
||||
bool includeBufferId = false) {
|
||||
std::stringstream kernelStr;
|
||||
kernelStr << kernelName;
|
||||
for (const Tensor& tensor: tensors) {
|
||||
for (const Tensor& tensor : tensors) {
|
||||
kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId);
|
||||
}
|
||||
return kernelStr.str();
|
||||
@ -98,23 +111,24 @@ struct OperationInfo : BaseInfo {
|
||||
};
|
||||
|
||||
struct CpuFbInfo : BaseInfo {
|
||||
CpuFbInfo(uint64_t Id, const std::string& OpName) :
|
||||
BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) { }
|
||||
CpuFbInfo(uint64_t Id, const std::string& OpName)
|
||||
: BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) {}
|
||||
|
||||
uint64_t runCount = 0;
|
||||
// the current and total overhead of copies in bytes required to convert the Op's
|
||||
// input tensors from MPS to CPU and then output from CPU back to MPS
|
||||
// the current and total overhead of copies in bytes required to convert the
|
||||
// Op's input tensors from MPS to CPU and then output from CPU back to MPS
|
||||
size_t currentCopyOverhead = 0;
|
||||
size_t totalCopyOverhead = 0;
|
||||
std::string opName;
|
||||
std::string strKey;
|
||||
uint64_t startTime = 0;
|
||||
|
||||
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
|
||||
const std::string toString(double gpuTime = 0, double schedulingTime = 0)
|
||||
const override;
|
||||
|
||||
void updateCopyOverhead(const TensorList& tensors) {
|
||||
currentCopyOverhead = 0;
|
||||
for (const Tensor& tensor: tensors) {
|
||||
for (const Tensor& tensor : tensors) {
|
||||
if (tensor.defined()) {
|
||||
currentCopyOverhead += tensor.nbytes();
|
||||
}
|
||||
@ -130,9 +144,17 @@ struct CopyInfo : BaseInfo {
|
||||
CPU_TO_MPS,
|
||||
};
|
||||
|
||||
CopyInfo(const void* Handle, size_t Length, uint64_t Id, bool IsNonBlocking, bool UsesBlitter) :
|
||||
BaseInfo(Type::COPY, Id, uintptr_t(Handle)), kind(Kind::MPS_TO_MPS),
|
||||
length(Length), isNonBlocking(IsNonBlocking), usesBlitter(UsesBlitter) { }
|
||||
CopyInfo(
|
||||
const void* Handle,
|
||||
size_t Length,
|
||||
uint64_t Id,
|
||||
bool IsNonBlocking,
|
||||
bool UsesBlitter)
|
||||
: BaseInfo(Type::COPY, Id, uintptr_t(Handle)),
|
||||
kind(Kind::MPS_TO_MPS),
|
||||
length(Length),
|
||||
isNonBlocking(IsNonBlocking),
|
||||
usesBlitter(UsesBlitter) {}
|
||||
|
||||
Kind kind;
|
||||
size_t length;
|
||||
@ -143,11 +165,17 @@ struct CopyInfo : BaseInfo {
|
||||
// for copies that don't use blitters, we measure CPU time
|
||||
uint64_t startTime = 0;
|
||||
|
||||
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
|
||||
const std::string toString(double gpuTime = 0, double schedulingTime = 0)
|
||||
const override;
|
||||
|
||||
static std::string buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId = false);
|
||||
static std::string buildTensorString(
|
||||
const void* buffer,
|
||||
const OptionalTensorRef tensor,
|
||||
bool includeBufferId = false);
|
||||
|
||||
static bool isStorageOnMPS(const void* buffer, const OptionalTensorRef tensor) {
|
||||
static bool isStorageOnMPS(
|
||||
const void* buffer,
|
||||
const OptionalTensorRef tensor) {
|
||||
if (tensor.has_value()) {
|
||||
return tensor->device().type() == at::kMPS;
|
||||
}
|
||||
@ -156,8 +184,11 @@ struct CopyInfo : BaseInfo {
|
||||
return getIMPSAllocator()->getUnalignedBufferSize(buffer) >= 0;
|
||||
}
|
||||
|
||||
static Kind getCopyKind(const void* srcBuffer, const void* dstBuffer,
|
||||
const OptionalTensorRef srcTensor, const OptionalTensorRef dstTensor) {
|
||||
static Kind getCopyKind(
|
||||
const void* srcBuffer,
|
||||
const void* dstBuffer,
|
||||
const OptionalTensorRef srcTensor,
|
||||
const OptionalTensorRef dstTensor) {
|
||||
const bool isSrcOnMPS = isStorageOnMPS(srcBuffer, srcTensor);
|
||||
const bool isDstOnMPS = isStorageOnMPS(dstBuffer, dstTensor);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isSrcOnMPS || isDstOnMPS);
|
||||
@ -171,8 +202,9 @@ struct CopyInfo : BaseInfo {
|
||||
};
|
||||
|
||||
struct CopyStat : CopyInfo {
|
||||
explicit CopyStat(std::string CopyKindStr) :
|
||||
CopyInfo(nullptr, 0, 0, false, false), kindStr(std::move(CopyKindStr)) {}
|
||||
explicit CopyStat(std::string CopyKindStr)
|
||||
: CopyInfo(nullptr, 0, 0, false, false),
|
||||
kindStr(std::move(CopyKindStr)) {}
|
||||
// total number of copies
|
||||
size_t totalCount = 0;
|
||||
// number of Scalar copies (i.e., less than sizeof(int64))
|
||||
@ -188,29 +220,29 @@ struct CopyStat : CopyInfo {
|
||||
};
|
||||
|
||||
class MPSProfiler {
|
||||
public:
|
||||
public:
|
||||
// lower 16 bits used for profiler options
|
||||
enum ProfileOptions : uint32_t {
|
||||
OPTIONS_NONE = 0,
|
||||
// ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK, etc.)
|
||||
// (used for convenience to not compute bit flags by OR-ing manually)
|
||||
// ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK,
|
||||
// etc.) (used for convenience to not compute bit flags by OR-ing manually)
|
||||
// trace all signpost types using events
|
||||
ALL_SIGNPOST_EVENTS = (1 << 0),
|
||||
ALL_SIGNPOST_EVENTS = (1 << 0),
|
||||
// trace all signpost types using intervals
|
||||
ALL_SIGNPOST_INTERVALS = (1 << 1),
|
||||
// always wait for command buffer to finish executing after each commit
|
||||
WAIT_UNTIL_COMPLETED = (1 << 2),
|
||||
WAIT_UNTIL_COMPLETED = (1 << 2),
|
||||
// for interval-based signposts, include the scheduling portion of
|
||||
// Graph/Kernel/Copy executions as well.
|
||||
// if flag is disable, only "GPU run time" is included in interval,
|
||||
// and not schedule time.
|
||||
INCLUDE_SCHEDULE_INTERVAL = (1 << 3),
|
||||
|
||||
// use these if you need to trace signposts types individually (rarely required)
|
||||
// trace signpost using intervals
|
||||
// use these if you need to trace signposts types individually (rarely
|
||||
// required) trace signpost using intervals
|
||||
USE_INTERVALS = (1 << 4),
|
||||
// trace signpost by emitting events
|
||||
USE_EVENTS = (1 << 5),
|
||||
USE_EVENTS = (1 << 5),
|
||||
// used for sanity check (Change this when new option added)
|
||||
OPTIONS_COUNT = (USE_EVENTS << 1) - 1,
|
||||
};
|
||||
@ -222,9 +254,9 @@ public:
|
||||
// trace signposts for PyTorch operation executions
|
||||
RUN_OPERATION = (1 << 16),
|
||||
// trace signposts for blitter copies
|
||||
BLIT_COPY = (1 << 17),
|
||||
BLIT_COPY = (1 << 17),
|
||||
// trace signposts for ops that fall back on CPU
|
||||
CPU_FALLBACK = (1 << 18),
|
||||
CPU_FALLBACK = (1 << 18),
|
||||
// used for sanity check (Change this when new type added)
|
||||
SIGNPOST_COUNT = (CPU_FALLBACK << 1) - 1,
|
||||
};
|
||||
@ -235,39 +267,44 @@ public:
|
||||
// Info logging options during execution
|
||||
// -------------------------------------
|
||||
// prints operation info (id/key/run_count) during execution
|
||||
OPERATION_INFO = (1 << 0),
|
||||
OPERATION_INFO = (1 << 0),
|
||||
// prints copy info (src/dst tensors/buffers, size, etc.) during execution
|
||||
COPY_INFO = (1 << 1),
|
||||
// prints CPU Fallback info (id/runCount/opName/copyOverhead) during execution
|
||||
CPU_FALLBACK_INFO = (1 << 2),
|
||||
COPY_INFO = (1 << 1),
|
||||
// prints CPU Fallback info (id/runCount/opName/copyOverhead) during
|
||||
// execution
|
||||
CPU_FALLBACK_INFO = (1 << 2),
|
||||
|
||||
// Profiling Statistics logging options when process terminates
|
||||
// ------------------------------------------------------------
|
||||
// prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before process terminates
|
||||
// this is convenient to not combine following stats bit flags manually
|
||||
ALL_STATS = (1 << 3),
|
||||
// prints operation stats (GPU times, run count, etc.) before process terminates
|
||||
OPERATION_STATS = (1 << 4),
|
||||
// prints copies stats (GPU times, copy kinds, sizes, etc.) before process terminates
|
||||
COPY_STATS = (1 << 5),
|
||||
// prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before
|
||||
// process terminates this is convenient to not combine following stats bit
|
||||
// flags manually
|
||||
ALL_STATS = (1 << 3),
|
||||
// prints operation stats (GPU times, run count, etc.) before process
|
||||
// terminates
|
||||
OPERATION_STATS = (1 << 4),
|
||||
// prints copies stats (GPU times, copy kinds, sizes, etc.) before process
|
||||
// terminates
|
||||
COPY_STATS = (1 << 5),
|
||||
// prints CPU Fallback stats (CPU times, run times, size of MPS<->CPU copies
|
||||
// for tensors, etc.) before process terminates
|
||||
CPU_FALLBACK_STATS = (1 << 6),
|
||||
CPU_FALLBACK_STATS = (1 << 6),
|
||||
|
||||
// Metadata format options when logging the info
|
||||
// ---------------------------------------------
|
||||
// if enabled, includes GPU run time in metadata (i.e., GPUEndTime-GPUStartTime
|
||||
// from Metal Command Buffers) (e.g., [GPU=0.324 ms])
|
||||
INCLUDE_GPU_TIME = (1 << 7),
|
||||
// if enabled, includes GPU run time in metadata (i.e.,
|
||||
// GPUEndTime-GPUStartTime from Metal Command Buffers) (e.g., [GPU=0.324
|
||||
// ms])
|
||||
INCLUDE_GPU_TIME = (1 << 7),
|
||||
// if enabled, includes GPU scheduling time in metadata separately
|
||||
// (i.e., KernelEndTime-KernelStartTime from Metal Command Buffers)
|
||||
// e.g., [GPU=0.324 ms, KRNL=0.036 ms]
|
||||
INCLUDE_KERNEL_TIME = (1 << 8),
|
||||
// if enabled, includes the unique buffer ID in metadata for the storage
|
||||
// of a tensor that was allocated on MPSAllocator. This is useful (along with
|
||||
// the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are involved
|
||||
// with various operations.
|
||||
INCLUDE_BUFFER_ID = (1 << 9),
|
||||
// of a tensor that was allocated on MPSAllocator. This is useful (along
|
||||
// with the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are
|
||||
// involved with various operations.
|
||||
INCLUDE_BUFFER_ID = (1 << 9),
|
||||
|
||||
// used for sanity check (Change this when new option added)
|
||||
LOG_COUNT = (INCLUDE_BUFFER_ID << 1) - 1,
|
||||
@ -276,15 +313,28 @@ public:
|
||||
explicit MPSProfiler();
|
||||
~MPSProfiler();
|
||||
|
||||
// the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
|
||||
// the beginProfile*() functions return a profileId which is unique per graph/kernel/copy
|
||||
uint64_t beginProfileKernel(const void* handle, const std::string& strKey, bool isGraph);
|
||||
uint64_t beginProfileKernel(const void* handle, const std::string& kernelName, const TensorList& tensors);
|
||||
uint64_t beginProfileCopy(const void* srcBuffer, const void* dstBuffer,
|
||||
const OptionalTensorRef srcTensor,
|
||||
const OptionalTensorRef dstTensor,
|
||||
size_t length, bool isNonBlocking, bool usesBlitter = true);
|
||||
uint64_t beginProfileCPUFallback(const std::string& opName, const TensorList& tensors);
|
||||
// the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal
|
||||
// Kernels the beginProfile*() functions return a profileId which is unique
|
||||
// per graph/kernel/copy
|
||||
uint64_t beginProfileKernel(
|
||||
const void* handle,
|
||||
const std::string& strKey,
|
||||
bool isGraph);
|
||||
uint64_t beginProfileKernel(
|
||||
const void* handle,
|
||||
const std::string& kernelName,
|
||||
const TensorList& tensors);
|
||||
uint64_t beginProfileCopy(
|
||||
const void* srcBuffer,
|
||||
const void* dstBuffer,
|
||||
const OptionalTensorRef srcTensor,
|
||||
const OptionalTensorRef dstTensor,
|
||||
size_t length,
|
||||
bool isNonBlocking,
|
||||
bool usesBlitter = true);
|
||||
uint64_t beginProfileCPUFallback(
|
||||
const std::string& opName,
|
||||
const TensorList& tensors);
|
||||
void beginProfileGPUInterval(const void* handle);
|
||||
|
||||
void endProfileCopy(uint64_t profileId, SyncType syncType);
|
||||
@ -309,22 +359,25 @@ public:
|
||||
// logging are enabled for the SignpostTypes
|
||||
bool isOperationProfilingEnabled() const {
|
||||
return (m_signpost_types & SignpostTypes::RUN_OPERATION) ||
|
||||
(m_log_options & (LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
|
||||
(m_log_options &
|
||||
(LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
|
||||
}
|
||||
bool isCopyProfilingEnabled() const {
|
||||
return (m_signpost_types & SignpostTypes::BLIT_COPY) ||
|
||||
(m_log_options & (LogOptions::COPY_INFO | LogOptions::COPY_STATS));
|
||||
(m_log_options & (LogOptions::COPY_INFO | LogOptions::COPY_STATS));
|
||||
}
|
||||
bool isCPUFallbackProfilingEnabled() const {
|
||||
return (m_signpost_types & SignpostTypes::CPU_FALLBACK) ||
|
||||
(m_log_options & (LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
|
||||
(m_log_options &
|
||||
(LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
|
||||
}
|
||||
bool isSignpostTracingEnabled() const {
|
||||
return (m_signpost_types != SignpostTypes::SIGNPOST_NONE);
|
||||
}
|
||||
|
||||
private:
|
||||
// indicates what type of signpost types are enabled and traced by MPS profiler.
|
||||
// indicates what type of signpost types are enabled and traced by MPS
|
||||
// profiler.
|
||||
uint32_t m_signpost_types = 0;
|
||||
uint32_t m_profile_options = 0;
|
||||
uint32_t m_log_options = 0;
|
||||
@ -332,14 +385,15 @@ public:
|
||||
uint64_t m_graph_counter = 0;
|
||||
uint64_t m_cpu_fb_counter = 0;
|
||||
uint64_t m_copy_counter = 0;
|
||||
// technically, it's possible to trace both events and intervals at the same time
|
||||
// so we use separate os_log categories for them
|
||||
// technically, it's possible to trace both events and intervals at the same
|
||||
// time so we use separate os_log categories for them
|
||||
os_log_t m_os_log_events;
|
||||
os_log_t m_os_log_intervals;
|
||||
// stats logging could run either from destructor or signal handler
|
||||
// so this is used to check if logging has already started.
|
||||
std::atomic_bool hasLoggedStats{false};
|
||||
// indicates there are pending completionHandler callbacks that haven't been called yet.
|
||||
// indicates there are pending completionHandler callbacks that haven't been
|
||||
// called yet.
|
||||
std::atomic_bool hasPendingCompletionHandlers{false};
|
||||
// used to capture sigint signal to log profiling stats
|
||||
static struct sigaction currentSigint, previousSigint;
|
||||
@ -347,40 +401,62 @@ public:
|
||||
// We use the following lists for two reasons:
|
||||
// 1- for interval-based signposts the "begin" point won't be in same function
|
||||
// as the "end" point where we need to be able to retrieve signpost's info
|
||||
// 2- if Operations info need to be logged when process ends using LogOptions::OPERATION_INFO.
|
||||
// 2- if Operations info need to be logged when process ends using
|
||||
// LogOptions::OPERATION_INFO.
|
||||
|
||||
// the pointer key for this map is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
|
||||
// this list is retained and could be logged along with aggregate profiling numbers when the process ends.
|
||||
std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>> m_op_info_list{};
|
||||
// the string key for this map is the op name that we fall back to execute on CPU
|
||||
// this list is retained and could be logged along with aggregate profiling numbers when the process ends.
|
||||
std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>> m_cpu_fb_info_list{};
|
||||
// the pointer key for this map is either "MPSGraph*" or
|
||||
// "id<MTLComputePipelineState>" for Metal Kernels this list is retained and
|
||||
// could be logged along with aggregate profiling numbers when the process
|
||||
// ends.
|
||||
std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>>
|
||||
m_op_info_list{};
|
||||
// the string key for this map is the op name that we fall back to execute on
|
||||
// CPU this list is retained and could be logged along with aggregate
|
||||
// profiling numbers when the process ends.
|
||||
std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>>
|
||||
m_cpu_fb_info_list{};
|
||||
// this list contains the info for copies, and its key is the unique profileId
|
||||
// which is generated from m_copy_counter
|
||||
// The copyInfo list is not retained.
|
||||
std::unordered_map<uint64_t, std::unique_ptr<CopyInfo>> m_copy_info_list{};
|
||||
// a short list that contains copy stats
|
||||
std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>> m_copy_stat_list{};
|
||||
std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>>
|
||||
m_copy_stat_list{};
|
||||
|
||||
mutable MTLCaptureManager *captureManager = nil;
|
||||
mutable MTLCaptureManager* captureManager = nil;
|
||||
unsigned captureCount = 0;
|
||||
|
||||
void initialize();
|
||||
void beginProfileExecution(BaseInfo& info, bool cpuExecution = false);
|
||||
void endProfileExecution(BaseInfo& info, os_signpost_id_t event_signpost_id,
|
||||
os_signpost_id_t interval_signpost_id,
|
||||
double gpuTime, double schedulingTime);
|
||||
void endProfileExecution(
|
||||
BaseInfo& info,
|
||||
os_signpost_id_t event_signpost_id,
|
||||
os_signpost_id_t interval_signpost_id,
|
||||
double gpuTime,
|
||||
double schedulingTime);
|
||||
void addProfilerScheduledHandler(BaseInfo& info);
|
||||
void addProfilerCompletedHandler(BaseInfo& info, SyncType syncType);
|
||||
void emitSignpostEvent(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
|
||||
const std::string& msg) const;
|
||||
void beginSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
|
||||
const std::string& msg) const;
|
||||
void endSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id) const;
|
||||
void emitSignpostEvent(
|
||||
SignpostTypes signpost_type,
|
||||
os_signpost_id_t signpost_id,
|
||||
const std::string& msg) const;
|
||||
void beginSignpostInterval(
|
||||
SignpostTypes signpost_type,
|
||||
os_signpost_id_t signpost_id,
|
||||
const std::string& msg) const;
|
||||
void endSignpostInterval(
|
||||
SignpostTypes signpost_type,
|
||||
os_signpost_id_t signpost_id) const;
|
||||
|
||||
void updateCopyStats(const CopyInfo& copyInfo, double gpuTime, double schedulingTime);
|
||||
// returns true if logging the profiling info "during the execution" is enabled
|
||||
bool isProfileInfoLoggingEnabled(BaseInfo::Type infoType, bool isExecutionEnded);
|
||||
void updateCopyStats(
|
||||
const CopyInfo& copyInfo,
|
||||
double gpuTime,
|
||||
double schedulingTime);
|
||||
// returns true if logging the profiling info "during the execution" is
|
||||
// enabled
|
||||
bool isProfileInfoLoggingEnabled(
|
||||
BaseInfo::Type infoType,
|
||||
bool isExecutionEnded);
|
||||
// logs all the profiling stats that are enabled
|
||||
void logProfilingStats();
|
||||
// logs kernel profiling stats when the process ends.
|
||||
@ -390,7 +466,9 @@ public:
|
||||
// logs copy profiling stats when the process ends.
|
||||
void logCopyProfilingStats(std::FILE* f) const;
|
||||
|
||||
os_signpost_id_t generateSignpostId(os_signpost_type_t signpostType, const void* ptr = nullptr);
|
||||
os_signpost_id_t generateSignpostId(
|
||||
os_signpost_type_t signpostType,
|
||||
const void* ptr = nullptr);
|
||||
static SignpostTypes getSignpostType(BaseInfo::Type infoType);
|
||||
static void handleIntSignal(int signal);
|
||||
};
|
||||
|
@@ -5,10 +5,10 @@
#include <cstdint>
#include <utility>

#include <c10/core/DeviceGuard.h>
#include <c10/util/Exception.h>
#include <c10/core/Stream.h>
#include <ATen/mps/MPSDevice.h>
#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/util/Exception.h>

#ifdef __OBJC__
#include <Foundation/Foundation.h>
@@ -32,7 +32,6 @@ typedef void* MTLDevice_t;
#define nil NULL;
#endif

namespace at::mps {

//-----------------------------------------------------------------
@@ -40,16 +39,15 @@ namespace at::mps {
//-----------------------------------------------------------------

enum class SyncType {
  NONE, // no commit to command buffer
  COMMIT, // commit and flush the command buffer
  COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
  COMMIT_AND_CONTINUE,// commit and continue with a new underlying command buffer
  COMMIT_ADAPTIVE, // commit adaptively based on available memory
  NONE, // no commit to command buffer
  COMMIT, // commit and flush the command buffer
  COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
  COMMIT_AND_CONTINUE, // commit and continue with a new underlying command buffer
  COMMIT_ADAPTIVE, // commit adaptively based on available memory
};

class TORCH_API MPSStream
{
public:
class TORCH_API MPSStream {
 public:
  enum Unchecked { UNCHECKED };

  /// Construct a MPSStream from a Stream. This construction is checked,
@ -57,41 +55,64 @@ public:
|
||||
explicit MPSStream(Stream stream);
|
||||
|
||||
~MPSStream();
|
||||
MTLCommandQueue_t commandQueue() const { return _commandQueue; };
|
||||
dispatch_queue_t queue() const { return _serialQueue; }
|
||||
MTLCommandQueue_t commandQueue() const {
|
||||
return _commandQueue;
|
||||
};
|
||||
dispatch_queue_t queue() const {
|
||||
return _serialQueue;
|
||||
}
|
||||
|
||||
MPSCommandBuffer* commandBuffer();
|
||||
MTLComputeCommandEncoder_t commandEncoder();
|
||||
void endKernelCoalescing();
|
||||
void synchronize(SyncType syncType);
|
||||
void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
|
||||
void copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
|
||||
size_t length, size_t srcOffset, size_t dstOffset,
|
||||
uint64_t profileId, SyncType syncType = SyncType::NONE);
|
||||
void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
|
||||
size_t length, size_t srcOffset, size_t dstOffset,
|
||||
bool non_blocking, uint64_t profileId);
|
||||
void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE);
void copy(id<MTLBuffer> srcBuffer,
id<MTLBuffer> dstBuffer,
size_t length,
size_t srcOffset,
size_t dstOffset,
uint64_t profileId,
SyncType syncType = SyncType::NONE);
void copy_and_sync(id<MTLBuffer> srcBuffer,
id<MTLBuffer> dstBuffer,
size_t length,
size_t srcOffset,
size_t dstOffset,
bool non_blocking,
uint64_t profileId);
void executeMPSGraph(MPSGraph* mpsGraph,
NSDictionary* feeds,
NSDictionary* results,
SyncType syncType = SyncType::NONE);
void addCompletedHandler(MTLCommandBufferHandler block);

/// Get the MPS device index that this stream is associated with.
c10::DeviceIndex device_index() const { return _stream.device_index(); }
c10::DeviceIndex device_index() const {
return _stream.device_index();
}

MTLCommandQueue_t stream() const { return _commandQueue; };
MTLCommandQueue_t stream() const {
return _commandQueue;
};

MTLDevice_t device() const { return [_commandQueue device];}
MTLDevice_t device() const {
return [_commandQueue device];
}

/// Explicit conversion to Stream.
Stream unwrap() const { return _stream; }
Stream unwrap() const {
return _stream;
}

private:
private:
Stream _stream;
MTLCommandQueue_t _commandQueue = nil;
MPSCommandBuffer* _commandBuffer = nil;
MPSCommandBuffer* _prevCommandBuffer = nil;
MTLComputeCommandEncoder_t _commandEncoder = nil;
MPSGraphExecutionDescriptor *_executionDescriptor = nil;
MPSGraphCompilationDescriptor *_compilationDescriptor = nil;
MPSGraphExecutionDescriptor* _executionDescriptor = nil;
MPSGraphCompilationDescriptor* _compilationDescriptor = nil;
dispatch_queue_t _serialQueue = nullptr;
// CommitAndContinue is enabled by default
bool _enableCommitAndContinue = true;
@ -117,8 +138,7 @@ TORCH_API MPSStream* getDefaultMPSStream();
// MPSStreamImpl
//-----------------------------------------------------------------

class TORCH_API MPSStreamImpl
{
class TORCH_API MPSStreamImpl {
public:
/**
* Gets single instance of the MPSStream.
@ -2,44 +2,40 @@

#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

#if !defined(__MAC_15_0) && \
(!defined(MAC_OS_X_VERSION_15_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_15_0))
#if !defined(__MAC_15_0) && (!defined(MAC_OS_X_VERSION_15_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_15_0))

@interface MPSNDArrayIdentity : MPSNDArrayUnaryKernel
-(MPSNDArray * __nullable) reshapeWithCommandBuffer: (__nullable id <MTLCommandBuffer>) cmdBuf
sourceArray: (MPSNDArray * __nonnull) sourceArray
shape: (MPSShape * __nonnull) shape
destinationArray: (MPSNDArray * __nullable) destinationArray;
- (MPSNDArray* __nullable)reshapeWithCommandBuffer:(__nullable id<MTLCommandBuffer>)cmdBuf
sourceArray:(MPSNDArray* __nonnull)sourceArray
shape:(MPSShape* __nonnull)shape
destinationArray:(MPSNDArray* __nullable)destinationArray;
@end

@interface MPSNDArrayDescriptor()
@property (readwrite, nonatomic) BOOL preferPackedRows;
@interface MPSNDArrayDescriptor ()
@property(readwrite, nonatomic) BOOL preferPackedRows;
@end

@interface MPSNDArray()
-(nonnull instancetype) initWithBuffer:(id<MTLBuffer> _Nonnull) buffer
offset:(NSUInteger) offset
descriptor:(MPSNDArrayDescriptor * _Nonnull) descriptor;
-(MPSNDArray * __nullable) arrayViewWithShape:(MPSShape * _Nullable) shape
strides:(MPSShape * _Nonnull) strides;
@interface MPSNDArray ()
- (nonnull instancetype)initWithBuffer:(id<MTLBuffer> _Nonnull)buffer
offset:(NSUInteger)offset
descriptor:(MPSNDArrayDescriptor* _Nonnull)descriptor;
- (MPSNDArray* __nullable)arrayViewWithShape:(MPSShape* _Nullable)shape strides:(MPSShape* _Nonnull)strides;
@end

typedef NS_ENUM(NSInteger, MTLMathMode)
{
MTLMathModeSafe = 0,
MTLMathModeRelaxed = 1,
MTLMathModeFast = 2,
typedef NS_ENUM(NSInteger, MTLMathMode) {
MTLMathModeSafe = 0,
MTLMathModeRelaxed = 1,
MTLMathModeFast = 2,
};

typedef NS_ENUM(NSInteger, MTLMathFloatingPointFunctions)
{
MTLMathFloatingPointFunctionsFast = 0,
MTLMathFloatingPointFunctionsPrecise = 1,
typedef NS_ENUM(NSInteger, MTLMathFloatingPointFunctions) {
MTLMathFloatingPointFunctionsFast = 0,
MTLMathFloatingPointFunctionsPrecise = 1,
};

@interface MTLCompileOptions()
@property (readwrite, nonatomic) MTLMathMode mathMode;
@property (readwrite, nonatomic) MTLMathFloatingPointFunctions mathFloatingPointFunctions;
@interface MTLCompileOptions ()
@property(readwrite, nonatomic) MTLMathMode mathMode;
@property(readwrite, nonatomic) MTLMathFloatingPointFunctions mathFloatingPointFunctions;
@end

#endif
@ -2,52 +2,47 @@

#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

#if !defined(__MAC_14_0) && \
(!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0))
#if !defined(__MAC_14_0) && (!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0))

typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode)
{
MPSGraphFFTScalingModeNone = 0L,
MPSGraphFFTScalingModeSize = 1L,
MPSGraphFFTScalingModeUnitary = 2L,
typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode) {
MPSGraphFFTScalingModeNone = 0L,
MPSGraphFFTScalingModeSize = 1L,
MPSGraphFFTScalingModeUnitary = 2L,
};

@interface FakeMPSGraphFFTDescriptor : NSObject<NSCopying>
@property (readwrite, nonatomic) BOOL inverse;
@property (readwrite, nonatomic) MPSGraphFFTScalingMode scalingMode;
@property (readwrite, nonatomic) BOOL roundToOddHermitean;
+(nullable instancetype) descriptor;
@property(readwrite, nonatomic) BOOL inverse;
@property(readwrite, nonatomic) MPSGraphFFTScalingMode scalingMode;
@property(readwrite, nonatomic) BOOL roundToOddHermitean;
+ (nullable instancetype)descriptor;
@end

@compatibility_alias MPSGraphFFTDescriptor FakeMPSGraphFFTDescriptor;

@interface MPSGraph (SonomaOps)
-(MPSGraphTensor * _Nonnull) conjugateWithTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)conjugateWithTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name;

-(MPSGraphTensor * _Nonnull) realPartOfTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)realPartOfTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name;

- (MPSGraphTensor* _Nonnull)fastFourierTransformWithTensor:(MPSGraphTensor* _Nonnull)tensor
axes:(NSArray<NSNumber*>* _Nonnull)axes
descriptor:(MPSGraphFFTDescriptor* _Nonnull)descriptor
name:(NSString* _Nullable)name;

-(MPSGraphTensor * _Nonnull) fastFourierTransformWithTensor:(MPSGraphTensor * _Nonnull) tensor
axes:(NSArray<NSNumber *> * _Nonnull) axes
descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)realToHermiteanFFTWithTensor:(MPSGraphTensor* _Nonnull)tensor
axes:(NSArray<NSNumber*>* _Nonnull)axes
descriptor:(MPSGraphFFTDescriptor* _Nonnull)descriptor
name:(NSString* _Nullable)name;

-(MPSGraphTensor * _Nonnull) realToHermiteanFFTWithTensor:(MPSGraphTensor * _Nonnull) tensor
axes:(NSArray<NSNumber *> * _Nonnull) axes
descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor
name:(NSString * _Nullable) name;

-(MPSGraphTensor * _Nonnull) HermiteanToRealFFTWithTensor:(MPSGraphTensor * _Nonnull) tensor
axes:(NSArray<NSNumber *> * _Nonnull) axes
descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)HermiteanToRealFFTWithTensor:(MPSGraphTensor* _Nonnull)tensor
axes:(NSArray<NSNumber*>* _Nonnull)axes
descriptor:(MPSGraphFFTDescriptor* _Nonnull)descriptor
name:(NSString* _Nullable)name;
@end

// define BFloat16 enums for MacOS13
#define MPSDataTypeBFloat16 ((MPSDataType) (MPSDataTypeAlternateEncodingBit | MPSDataTypeFloat16))
#define MPSDataTypeBFloat16 ((MPSDataType)(MPSDataTypeAlternateEncodingBit | MPSDataTypeFloat16))

// define Metal version
#define MTLLanguageVersion3_1 ((MTLLanguageVersion) ((3 << 16) + 1))
#define MTLLanguageVersion3_1 ((MTLLanguageVersion)((3 << 16) + 1))
#endif

@ -2,30 +2,29 @@
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

// TODO: Remove me when moved to MacOS 13
#if !defined(__MAC_13_2) && \
(!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2))
#if !defined(__MAC_13_2) && (!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2))

@interface FakeMPSGraphConvolution3DOpDescriptor : NSObject<NSCopying>

@property (readwrite, nonatomic) NSUInteger strideInX;
@property (readwrite, nonatomic) NSUInteger strideInY;
@property (readwrite, nonatomic) NSUInteger strideInZ;
@property (readwrite, nonatomic) NSUInteger dilationRateInX;
@property (readwrite, nonatomic) NSUInteger dilationRateInY;
@property (readwrite, nonatomic) NSUInteger dilationRateInZ;
@property(readwrite, nonatomic) NSUInteger strideInX;
@property(readwrite, nonatomic) NSUInteger strideInY;
@property(readwrite, nonatomic) NSUInteger strideInZ;
@property(readwrite, nonatomic) NSUInteger dilationRateInX;
@property(readwrite, nonatomic) NSUInteger dilationRateInY;
@property(readwrite, nonatomic) NSUInteger dilationRateInZ;

@property (readwrite, nonatomic) NSUInteger paddingLeft;
@property (readwrite, nonatomic) NSUInteger paddingRight;
@property (readwrite, nonatomic) NSUInteger paddingTop;
@property (readwrite, nonatomic) NSUInteger paddingBottom;
@property (readwrite, nonatomic) NSUInteger paddingFront;
@property (readwrite, nonatomic) NSUInteger paddingBack;
@property(readwrite, nonatomic) NSUInteger paddingLeft;
@property(readwrite, nonatomic) NSUInteger paddingRight;
@property(readwrite, nonatomic) NSUInteger paddingTop;
@property(readwrite, nonatomic) NSUInteger paddingBottom;
@property(readwrite, nonatomic) NSUInteger paddingFront;
@property(readwrite, nonatomic) NSUInteger paddingBack;

@property (readwrite, nonatomic) MPSGraphPaddingStyle paddingStyle;
@property (readwrite, nonatomic) MPSGraphTensorNamedDataLayout dataLayout;
@property (readwrite, nonatomic) MPSGraphTensorNamedDataLayout weightsLayout;
@property(readwrite, nonatomic) MPSGraphPaddingStyle paddingStyle;
@property(readwrite, nonatomic) MPSGraphTensorNamedDataLayout dataLayout;
@property(readwrite, nonatomic) MPSGraphTensorNamedDataLayout weightsLayout;

@property (readwrite, nonatomic) NSUInteger groups;
@property(readwrite, nonatomic) NSUInteger groups;

@end

@ -35,163 +34,163 @@

@interface MPSGraph (VenturaOps)

#if !defined(__MAC_13_0) && \
(!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0))
#if !defined(__MAC_13_0) && (!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0))

typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
{
MPSGraphResizeNearestRoundingModeRoundPreferCeil = 0L,
MPSGraphResizeNearestRoundingModeRoundPreferFloor = 1L,
MPSGraphResizeNearestRoundingModeCeil = 2L,
MPSGraphResizeNearestRoundingModeFloor = 3L,
MPSGraphResizeNearestRoundingModeRoundToEven = 4L,
MPSGraphResizeNearestRoundingModeRoundToOdd = 5L,
typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode) {
MPSGraphResizeNearestRoundingModeRoundPreferCeil = 0L,
MPSGraphResizeNearestRoundingModeRoundPreferFloor = 1L,
MPSGraphResizeNearestRoundingModeCeil = 2L,
MPSGraphResizeNearestRoundingModeFloor = 3L,
MPSGraphResizeNearestRoundingModeRoundToEven = 4L,
MPSGraphResizeNearestRoundingModeRoundToOdd = 5L,
};

// Define complex enums for MacOS 12
#define MPSDataTypeComplexBit 0x01000000
#define MPSDataTypeComplexFloat32 ((MPSDataType) (MPSDataTypeFloatBit | MPSDataTypeComplexBit | 64))
#define MPSDataTypeComplexFloat16 ((MPSDataType) (MPSDataTypeFloatBit | MPSDataTypeComplexBit | 32))
#define MPSDataTypeComplexFloat32 ((MPSDataType)(MPSDataTypeFloatBit | MPSDataTypeComplexBit | 64))
#define MPSDataTypeComplexFloat16 ((MPSDataType)(MPSDataTypeFloatBit | MPSDataTypeComplexBit | 32))
#endif

- (MPSGraphTensor * _Nonnull) convolution3DWithSourceTensor:(MPSGraphTensor * _Nonnull) source
weightsTensor:(MPSGraphTensor * _Nonnull) weights
descriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) descriptor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)convolution3DWithSourceTensor:(MPSGraphTensor* _Nonnull)source
weightsTensor:(MPSGraphTensor* _Nonnull)weights
descriptor:(MPSGraphConvolution3DOpDescriptor* _Nonnull)descriptor
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) convolution3DDataGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient
weightsTensor:(MPSGraphTensor * _Nonnull) weights
outputShape:(MPSShape * _Nonnull) outputShape
forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)
convolution3DDataGradientWithIncomingGradientTensor:(MPSGraphTensor* _Nonnull)incomingGradient
weightsTensor:(MPSGraphTensor* _Nonnull)weights
outputShape:(MPSShape* _Nonnull)outputShape
forwardConvolutionDescriptor:
(MPSGraphConvolution3DOpDescriptor* _Nonnull)forwardConvolutionDescriptor
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) convolution3DWeightsGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient
sourceTensor:(MPSGraphTensor * _Nonnull) source
outputShape:(MPSShape * _Nonnull) outputShape
forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)
convolution3DWeightsGradientWithIncomingGradientTensor:(MPSGraphTensor* _Nonnull)incomingGradient
sourceTensor:(MPSGraphTensor* _Nonnull)source
outputShape:(MPSShape* _Nonnull)outputShape
forwardConvolutionDescriptor:
(MPSGraphConvolution3DOpDescriptor* _Nonnull)forwardConvolutionDescriptor
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull)cumulativeSumWithTensor:(MPSGraphTensor * _Nonnull)tensor
axis:(NSInteger)axis
name:(NSString * _Nullable)name;
- (MPSGraphTensor* _Nonnull)cumulativeSumWithTensor:(MPSGraphTensor* _Nonnull)tensor
axis:(NSInteger)axis
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull)sortWithTensor:(MPSGraphTensor * _Nonnull)tensor
axis:(NSInteger)axis
name:(NSString * _Nullable)name;
- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axis:(NSInteger)axis
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor
axis:(NSInteger) axis
descending:(BOOL) descending
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axis:(NSInteger)axis
descending:(BOOL)descending
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor
axisTensor:(MPSGraphTensor * _Nonnull) axisTensor
descending:(BOOL) descending
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axisTensor:(MPSGraphTensor* _Nonnull)axisTensor
descending:(BOOL)descending
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor
axisTensor:(MPSGraphTensor * _Nonnull) axisTensor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axisTensor:(MPSGraphTensor* _Nonnull)axisTensor
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull)argSortWithTensor:(MPSGraphTensor * _Nonnull)tensor
axis:(NSInteger)axis
name:(NSString * _Nullable)name;
- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axis:(NSInteger)axis
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor
axis:(NSInteger) axis
descending:(BOOL) descending
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axis:(NSInteger)axis
descending:(BOOL)descending
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor
axisTensor:(MPSGraphTensor * _Nonnull) axisTensor
descending:(BOOL) descending
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axisTensor:(MPSGraphTensor* _Nonnull)axisTensor
descending:(BOOL)descending
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor
axisTensor:(MPSGraphTensor * _Nonnull) axisTensor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axisTensor:(MPSGraphTensor* _Nonnull)axisTensor
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull)inverseOfTensor:(MPSGraphTensor * _Nonnull) inputTensor
name:(NSString * _Nullable)name;
- (MPSGraphTensor* _Nonnull)inverseOfTensor:(MPSGraphTensor* _Nonnull)inputTensor name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
sizeTensor:(MPSGraphTensor * _Nonnull) size
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
centerResult:(BOOL) centerResult
alignCorners:(BOOL) alignCorners
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeNearestWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor
sizeTensor:(MPSGraphTensor* _Nonnull)size
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
centerResult:(BOOL)centerResult
alignCorners:(BOOL)alignCorners
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
sizeTensor:(MPSGraphTensor * _Nonnull) size
scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeNearestWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor
sizeTensor:(MPSGraphTensor* _Nonnull)size
scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
sizeTensor:(MPSGraphTensor * _Nonnull) size
centerResult:(BOOL) centerResult
alignCorners:(BOOL) alignCorners
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeBilinearWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor
sizeTensor:(MPSGraphTensor* _Nonnull)size
centerResult:(BOOL)centerResult
alignCorners:(BOOL)alignCorners
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
sizeTensor:(MPSGraphTensor * _Nonnull) size
scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeBilinearWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor
sizeTensor:(MPSGraphTensor* _Nonnull)size
scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
input:(MPSGraphTensor * _Nonnull) input
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
centerResult:(BOOL) centerResult
alignCorners:(BOOL) alignCorners
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeNearestWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient
input:(MPSGraphTensor* _Nonnull)input
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
centerResult:(BOOL)centerResult
alignCorners:(BOOL)alignCorners
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
input:(MPSGraphTensor * _Nonnull) input
scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeNearestWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient
input:(MPSGraphTensor* _Nonnull)input
scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
input:(MPSGraphTensor * _Nonnull) input
centerResult:(BOOL) centerResult
alignCorners:(BOOL) alignCorners
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeBilinearWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient
input:(MPSGraphTensor* _Nonnull)input
centerResult:(BOOL)centerResult
alignCorners:(BOOL)alignCorners
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
input:(MPSGraphTensor * _Nonnull) input
scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeBilinearWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient
input:(MPSGraphTensor* _Nonnull)input
scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) sampleGridWithSourceTensor:(MPSGraphTensor * _Nonnull) source
coordinateTensor:(MPSGraphTensor * _Nonnull) coordinates
layout:(MPSGraphTensorNamedDataLayout) layout
normalizeCoordinates:(BOOL) normalizeCoordinates
relativeCoordinates:(BOOL) relativeCoordinates
alignCorners:(BOOL) alignCorners
paddingMode:(MPSGraphPaddingMode) paddingMode
samplingMode:(MPSGraphResizeMode) samplingMode
constantValue:(double) constantValue
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)sampleGridWithSourceTensor:(MPSGraphTensor* _Nonnull)source
coordinateTensor:(MPSGraphTensor* _Nonnull)coordinates
layout:(MPSGraphTensorNamedDataLayout)layout
normalizeCoordinates:(BOOL)normalizeCoordinates
relativeCoordinates:(BOOL)relativeCoordinates
alignCorners:(BOOL)alignCorners
paddingMode:(MPSGraphPaddingMode)paddingMode
samplingMode:(MPSGraphResizeMode)samplingMode
constantValue:(double)constantValue
name:(NSString* _Nullable)name;

- (MPSGraphTensor * _Nonnull) sampleGridWithSourceTensor:(MPSGraphTensor * _Nonnull) source
coordinateTensor:(MPSGraphTensor * _Nonnull) coordinates
layout:(MPSGraphTensorNamedDataLayout) layout
normalizeCoordinates:(BOOL) normalizeCoordinates
relativeCoordinates:(BOOL) relativeCoordinates
alignCorners:(BOOL) alignCorners
paddingMode:(MPSGraphPaddingMode) paddingMode
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
constantValue:(double) constantValue
name:(NSString * _Nullable) name;
- (MPSGraphTensor * _Nonnull) truncateWithTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)sampleGridWithSourceTensor:(MPSGraphTensor* _Nonnull)source
coordinateTensor:(MPSGraphTensor* _Nonnull)coordinates
layout:(MPSGraphTensorNamedDataLayout)layout
normalizeCoordinates:(BOOL)normalizeCoordinates
relativeCoordinates:(BOOL)relativeCoordinates
alignCorners:(BOOL)alignCorners
paddingMode:(MPSGraphPaddingMode)paddingMode
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
constantValue:(double)constantValue
name:(NSString* _Nullable)name;
- (MPSGraphTensor* _Nonnull)truncateWithTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name;

@end

@ -26,7 +26,7 @@

// Fwd declarations
namespace at {
struct TensorIteratorBase;
struct TensorIteratorBase;
}
using namespace at::mps;

@ -35,7 +35,9 @@ namespace at::native::mps {
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());

struct MPSScalar {
id<MTLBuffer> getMTLBuffer() const { return __builtin_bit_cast(id<MTLBuffer>, buffer.get()); }
id<MTLBuffer> getMTLBuffer() const {
return __builtin_bit_cast(id<MTLBuffer>, buffer.get());
}

size_t size = 0;
ScalarType type = ScalarType::Undefined;
@ -48,13 +50,10 @@ struct MPSScalar {
c10::complex<float> cf;
c10::complex<at::Half> ch;
at::BFloat16 bf16;
} value {};
} value{};
};

void runMPSGraph(MPSStream* mpsStream,
MPSGraph* mpsGraph,
NSDictionary* feeds,
NSDictionary* results);
void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results);

MPSDataType getMPSDataType(ScalarType scalar_type);
static inline MPSDataType getMPSDataType(const TensorBase& t) {
@ -64,7 +63,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type);
static inline MPSDataType getMPSScalarType(const TensorBase& t) {
return getMPSScalarType(t.scalar_type());
}
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false);
static inline std::string getMPSTypeString(const TensorBase& t, bool short_name = false) {
return getMPSTypeString(t.scalar_type(), short_name);
@ -81,10 +80,18 @@ std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
Tensor& scatterViewTensor(const Tensor& src, Tensor& output);
bool canSliceViewTensor(const TensorBase& src, MPSShape *mpsShape);
MPSGraphTensorData* getMPSGraphTensorDataForView(const TensorBase& src, MPSShape *mpsShape, const MPSDataType mpsDataType);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input, bool includesInt64 = false);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input, bool includesInt64 = false);
bool canSliceViewTensor(const TensorBase& src, MPSShape* mpsShape);
MPSGraphTensorData* getMPSGraphTensorDataForView(const TensorBase& src,
MPSShape* mpsShape,
const MPSDataType mpsDataType);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const TensorBase& input,
bool includesInt64 = false);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const TensorBase& input,
bool includesInt64 = false);

MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes = nil, MPSShape* strides = nil);
@ -102,8 +109,12 @@ class Placeholder {
Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}
Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {}
Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray);
Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr,
bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid, bool useMPSStridedAPI = true);
Placeholder(MPSGraphTensor* mpsGraphTensor,
const Tensor& self,
MPSShape* mpsShape = nullptr,
bool gatherTensorData = true,
MPSDataType dataType = MPSDataTypeInvalid,
bool useMPSStridedAPI = true);
MPSGraphTensor* getMPSGraphTensor() {
return _placeholder;
}
@ -123,21 +134,21 @@ class Placeholder {
void resize_tensor(Tensor* output);
Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device);
MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor);
MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType);
MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType);
MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const TensorBase& tensor);
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor);
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, ScalarType toType);
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType);
MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const TensorBase& tensor);
MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar);

MPSGraph* make_mps_graph();
void printTensorNDArray(const TensorBase& t);
MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape *shape, MPSDataType mpsType);
MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape* shape, MPSDataType mpsType);

MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const TensorBase& tensor);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar);
MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const TensorBase& tensor);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar);

string get_mem_format_string(c10::MemoryFormat memory_format);

@ -145,75 +156,73 @@ using MPSCacheKey = uint64_t;

// derive this class to cache a graph and its inputs/outputs
// can be used to store any NSObject
struct MPSCachedGraph
{
MPSCachedGraph(NSObject *object) : _object([object retain]) {}
struct MPSCachedGraph {
MPSCachedGraph(NSObject* object) : _object([object retain]) {}
virtual ~MPSCachedGraph() {
[_object release];
_object = nullptr;
[_object release];
_object = nullptr;
}

template<typename T>
template <typename T>
inline T* as() {
return static_cast<T*>(this);
}

MPSGraph *graph() const { return (MPSGraph *)_object; }
NSObject *object() const { return _object; }
private:
NSObject *_object = nullptr;
MPSGraph* graph() const {
return (MPSGraph*)_object;
}
NSObject* object() const {
return _object;
}

private:
NSObject* _object = nullptr;
};

struct MPSUnaryCachedGraph : public MPSCachedGraph
{
MPSUnaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct MPSUnaryCachedGraph : public MPSCachedGraph {
MPSUnaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};

struct MPSUnaryGradCachedGraph : public MPSCachedGraph
{
MPSUnaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *gradOutputTensor_ = nil;
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil; // some backward input is actually the forward's output
MPSGraphTensor *gradInputTensor_ = nil;
struct MPSUnaryGradCachedGraph : public MPSCachedGraph {
MPSUnaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil; // some backward input is actually the forward's output
MPSGraphTensor* gradInputTensor_ = nil;
};

struct MPSBinaryCachedGraph : public MPSCachedGraph
{
MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *otherTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct MPSBinaryCachedGraph : public MPSCachedGraph {
MPSBinaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* otherTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};

struct MPSBinaryGradCachedGraph : public MPSCachedGraph
{
MPSBinaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *gradOutputTensor_ = nil;
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *otherTensor_ = nil;
MPSGraphTensor *gradInputTensor_ = nil;
struct MPSBinaryGradCachedGraph : public MPSCachedGraph {
MPSBinaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* otherTensor_ = nil;
MPSGraphTensor* gradInputTensor_ = nil;
};

// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
// Cache holding various keys mapped to graphs
struct MPSGraphCache
{
typedef MPSCachedGraph * (^CreateCachedGraphBlock)();
struct MPSGraphCache {
typedef MPSCachedGraph* (^CreateCachedGraphBlock)();

struct CacheEntry {
CacheEntry(const std::string& key, MPSCachedGraph *cachedGraph) : cachedGraph_(cachedGraph), key_(key) {}
CacheEntry(const std::string& key, MPSCachedGraph* cachedGraph) : cachedGraph_(cachedGraph), key_(key) {}
MPSCachedGraph* cachedGraph_ = nullptr;
std::string key_;
};

public:

static MPSGraphCache* getInstance() {
if(_instance_cache == nullptr) {
if (_instance_cache == nullptr) {
_instance_cache = new MPSGraphCache();
}
return _instance_cache;
@ -232,7 +241,6 @@ struct MPSGraphCache
void operator=(const MPSGraphCache&) = delete;

MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) {

__block MPSCachedGraph* cachedGraph = nil;

MPSCacheKey hash = std::hash<std::string>{}(key);
@ -253,19 +261,17 @@ struct MPSGraphCache
return cachedGraph;
}

template<typename T>
template <typename T>
inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
return static_cast<T *>(CreateCachedGraph(key, createCacheBlock));
return static_cast<T*>(CreateCachedGraph(key, createCacheBlock));
}

MPSCachedGraph* LookUp(const std::string& key) const {

__block MPSCachedGraph* cachedGraph = nullptr;

MPSCacheKey hash = std::hash<std::string>{}(key);

dispatch_sync(serialQueue_, ^() {

if (cache_.count(hash) != 0) {
auto& entry = cache_.at(hash);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
@ -276,9 +282,9 @@ struct MPSGraphCache
return cachedGraph;
}

template<typename T>
template <typename T>
inline T* LookUpAs(const std::string& key) const {
return static_cast<T *>(LookUp(key));
return static_cast<T*>(LookUp(key));
}

private:
@ -292,14 +298,13 @@ struct MPSGraphCache
static MPSGraphCache* _instance_cache;
std::unordered_map<MPSCacheKey, CacheEntry> cache_;
dispatch_queue_t serialQueue_ = nullptr;

};

// Common template for creating graph with a specified cache if missing
template<typename T>
template <typename T>
inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(MPSGraph*, T*)> instantiate) {
auto cache_ = MPSGraphCache::getInstance();
if (auto rc = cache_->LookUpAs<T>(key)) {
if (auto rc = cache_->LookUpAs<T>(key)) {
return rc;
}
return cache_->CreateCachedGraphAs<T>(key, ^mps::MPSCachedGraph*() {
@ -317,10 +322,12 @@ inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(M
// Common math operations
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);

#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \
if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \
TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name, \
", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \
if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \
TORCH_WARN_ONCE( \
"MPS: no support for int64 for ", \
op_name, \
", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
}

/**
@ -335,17 +342,19 @@ inline bool is_dense_in_storage(const TensorBase& t) {
return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
}

class MetalShaderLibrary {
public:
MetalShaderLibrary(const std::string& src): shaderSource(src), nparams(0), compile_options(nullptr){}
MetalShaderLibrary(const std::string& src, unsigned nparams_): shaderSource(src), nparams(nparams_), compile_options(nullptr){}
MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_): shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
public:
MetalShaderLibrary(const std::string& src) : shaderSource(src), nparams(0), compile_options(nullptr) {}
MetalShaderLibrary(const std::string& src, unsigned nparams_)
: shaderSource(src), nparams(nparams_), compile_options(nullptr) {}
MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_)
: shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
inline id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname).first;
}
id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname, const std::initializer_list<std::string>& params) {
id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname,
const std::initializer_list<std::string>& params) {
return getLibraryPipelineState(getLibrary(params), fname).first;
}
inline id<MTLFunction> getMTLFunction(const std::string& fname) {
@ -355,12 +364,15 @@ public:
return getLibraryPipelineState(getLibrary(params), fname).second;
}
static MetalShaderLibrary& getBundledLibrary();
protected:

protected:
virtual id<MTLLibrary> getLibrary();
virtual id<MTLLibrary> getLibrary(const std::initializer_list<std::string>& params);
id<MTLLibrary> library = nil;
private:
std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);

private:
std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getLibraryPipelineState(id<MTLLibrary> lib,
const std::string& fname);

id<MTLLibrary> compileLibrary(const std::string& src);
std::string shaderSource;
@ -370,24 +382,21 @@ private:
std::unordered_map<std::string, std::pair<id<MTLComputePipelineState>, id<MTLFunction>>> cplMap;
};

template<typename encoder_t,
typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> || std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
template <typename encoder_t,
typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> ||
std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigned idx) {
[encoder setBuffer:getMTLBufferStorage(t)
offset:t.storage_offset() * t.element_size()
atIndex:idx];
[encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx];
}

template<typename T,
typename = std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T, float>>>
template <typename T, typename = std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T, float>>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const T val, unsigned idx) {
[encoder setBytes:&val length:sizeof(T) atIndex: idx];
[encoder setBytes:&val length:sizeof(T) atIndex:idx];
}

template<typename Container,
typename = std::enable_if_t<std::is_integral_v<typename Container::size_type>>>
template <typename Container, typename = std::enable_if_t<std::is_integral_v<typename Container::size_type>>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const Container& values, unsigned idx) {
[encoder setBytes:values.data() length:sizeof(typename Container::value_type) * values.size() atIndex: idx];
[encoder setBytes:values.data() length:sizeof(typename Container::value_type) * values.size() atIndex:idx];
}

static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
@ -400,38 +409,40 @@ static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
[encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
}

id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder, const TensorIteratorBase& iter, bool use_64bit_index = false);
id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder,
const TensorIteratorBase& iter,
bool use_64bit_index = false);

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) {
return @{ p1.getMPSGraphTensor(): p1.getMPSGraphTensorData() };
return @{p1.getMPSGraphTensor() : p1.getMPSGraphTensorData()};
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2) {
return @{
p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
};
return @{
p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
};
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3) {
return @{
p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(),
};
return @{
p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
p3.getMPSGraphTensor() : p3.getMPSGraphTensorData(),
};
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3, Placeholder& p4) {
return @{
p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(),
p4.getMPSGraphTensor(): p4.getMPSGraphTensorData(),
};
return @{
p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
p3.getMPSGraphTensor() : p3.getMPSGraphTensorData(),
p4.getMPSGraphTensor() : p4.getMPSGraphTensorData(),
};
}

inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, Placeholder& result) {
runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
}

inline bool supportsComplex() {
@ -464,7 +475,7 @@ inline void checkSupportsBFloat16() {

inline bool needsGather(const TensorBase& t) {
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset()) ;
return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset());
}

} // namespace at::native::mps
@ -1,13 +1,14 @@
// Copyright © 2022 Apple Inc.

#define AT_DISPATCH_MPS_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__))
#define AT_DISPATCH_MPS_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) AT_DISPATCH_CASE( \
at::ScalarType::Half, \
__VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__))
@ -1,5 +1,8 @@
#pragma once

namespace at::native::mps {
void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& output);
void complex_mul_out(
const Tensor& input,
const Tensor& other,
const Tensor& output);
}
@ -17,8 +17,7 @@ void _fused_adam_amsgrad_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);

void _fused_adam_amsgrad_mps_impl_(
TensorList params,
@ -34,7 +33,6 @@ void _fused_adam_amsgrad_mps_impl_(
const double eps,
const bool maximize,
const std::optional<at::Tensor>& grad_scale,
const std::optional<at::Tensor>& found_inf
);
const std::optional<at::Tensor>& found_inf);

} // namespace at::native::mps
@ -16,8 +16,7 @@ void _fused_adam_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);

void _fused_adam_mps_impl_(
TensorList params,
@ -32,6 +31,5 @@ void _fused_adam_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
} // namespace at::native::mps
@ -17,8 +17,7 @@ void _fused_adamw_amsgrad_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);

void _fused_adamw_amsgrad_mps_impl_(
TensorList params,
@ -34,6 +33,5 @@ void _fused_adamw_amsgrad_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
} // namespace at::native::mps
@ -16,8 +16,7 @@ void _fused_adamw_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);

void _fused_adamw_mps_impl_(
TensorList params,
@ -32,7 +31,6 @@ void _fused_adamw_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);

} // namespace at::native::mps
@ -436,10 +436,12 @@ REGISTER_FUSED_SGD_MOMENTUM_OP(half);

)METAL";

static std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getCPLState(const std::string& fname) {
static std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getCPLState(
const std::string& fname) {
static MetalShaderLibrary lib(FUSED_ADAM_OPS, 0);
return std::make_pair(lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname));
return std::make_pair(
lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname));
}

} //namespace mps
} // namespace mps
} // namespace at::native
@ -16,18 +16,15 @@ struct MetadataArguments { // the size of this struct must be less than 4 bytes
};

struct FusedAdamEncodingFunctor {
void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize
) const {

void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize) const {
float lr_lv = lr;
float beta1_lv = beta1;
float beta2_lv = beta2;
@ -35,12 +32,8 @@ struct FusedAdamEncodingFunctor {
float eps_lv = eps;
uint8_t maximize_lv = maximize;

[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
mtl_setBytes(computeEncoder, lr_lv, 2);
mtl_setBytes(computeEncoder, beta1_lv, 3);
mtl_setBytes(computeEncoder, beta2_lv, 4);
@ -49,29 +42,23 @@ struct FusedAdamEncodingFunctor {
mtl_setBytes(computeEncoder, maximize_lv, 7);
}

void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const at::Tensor& lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize
) const {
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const at::Tensor& lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize) const {
float beta1_lv = beta1;
float beta2_lv = beta2;
float weight_decay_lv = weight_decay;
float eps_lv = eps;
uint8_t maximize_lv = maximize;

[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
mtl_setBuffer(computeEncoder, lr, 2);
mtl_setBytes(computeEncoder, beta1_lv, 3);
mtl_setBytes(computeEncoder, beta2_lv, 4);
@ -86,144 +73,117 @@ struct FusedSgdEncodingFunctor {};

template <>
struct FusedSgdEncodingFunctor<true> {
void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double weight_decay,
const double momentum,
const double lr,
const double dampening,
const bool nesterov,
const bool maximize,
const bool is_first_step
) const {
float weight_decay_lv = weight_decay;
|
||||
float momentum_lv = momentum;
|
||||
float lr_lv = lr;
|
||||
float dampening_lv = dampening;
|
||||
uint8_t nesterov_lv = nesterov;
|
||||
uint8_t maximize_lv = maximize;
|
||||
uint8_t is_first_step_lv = is_first_step;
|
||||
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
|
||||
id<MTLBuffer>& tensorArgumentBuffer,
|
||||
const MetadataArguments& metadata_arguments,
|
||||
const double weight_decay,
|
||||
const double momentum,
|
||||
const double lr,
|
||||
const double dampening,
|
||||
const bool nesterov,
|
||||
const bool maximize,
|
||||
const bool is_first_step) const {
|
||||
float weight_decay_lv = weight_decay;
|
||||
float momentum_lv = momentum;
|
||||
float lr_lv = lr;
|
||||
float dampening_lv = dampening;
|
||||
uint8_t nesterov_lv = nesterov;
|
||||
uint8_t maximize_lv = maximize;
|
||||
uint8_t is_first_step_lv = is_first_step;
|
||||
|
||||
[computeEncoder setBuffer:tensorArgumentBuffer
|
||||
offset:0
|
||||
atIndex:0];
|
||||
[computeEncoder setBytes:&metadata_arguments
|
||||
length:sizeof(MetadataArguments)
|
||||
atIndex:1];
|
||||
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
|
||||
mtl_setBytes(computeEncoder, momentum_lv, 3);
|
||||
mtl_setBytes(computeEncoder, lr_lv, 4);
|
||||
mtl_setBytes(computeEncoder, dampening_lv, 5);
|
||||
mtl_setBytes(computeEncoder, nesterov_lv, 6);
|
||||
mtl_setBytes(computeEncoder, maximize_lv, 7);
|
||||
mtl_setBytes(computeEncoder, is_first_step_lv, 8);
|
||||
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
|
||||
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
|
||||
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
|
||||
mtl_setBytes(computeEncoder, momentum_lv, 3);
|
||||
mtl_setBytes(computeEncoder, lr_lv, 4);
|
||||
mtl_setBytes(computeEncoder, dampening_lv, 5);
|
||||
mtl_setBytes(computeEncoder, nesterov_lv, 6);
|
||||
mtl_setBytes(computeEncoder, maximize_lv, 7);
|
||||
mtl_setBytes(computeEncoder, is_first_step_lv, 8);
|
||||
}
|
||||
|
||||
void operator()(
|
||||
id<MTLComputeCommandEncoder>& computeEncoder,
|
||||
id<MTLBuffer>& tensorArgumentBuffer,
|
||||
const MetadataArguments& metadata_arguments,
|
||||
const double weight_decay,
|
||||
const double momentum,
|
||||
const at::Tensor& lr,
|
||||
const double dampening,
|
||||
const bool nesterov,
|
||||
const bool maximize,
|
||||
const bool is_first_step
|
||||
) const {
|
||||
float weight_decay_lv = weight_decay;
|
||||
float momentum_lv = momentum;
|
||||
float dampening_lv = dampening;
|
||||
uint8_t nesterov_lv = nesterov;
|
||||
uint8_t maximize_lv = maximize;
|
||||
uint8_t is_first_step_lv = is_first_step;
|
||||
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
|
||||
id<MTLBuffer>& tensorArgumentBuffer,
|
||||
const MetadataArguments& metadata_arguments,
|
||||
const double weight_decay,
|
||||
const double momentum,
|
||||
const at::Tensor& lr,
|
||||
const double dampening,
|
||||
const bool nesterov,
|
||||
const bool maximize,
|
||||
const bool is_first_step) const {
|
||||
float weight_decay_lv = weight_decay;
|
||||
float momentum_lv = momentum;
|
||||
float dampening_lv = dampening;
|
||||
uint8_t nesterov_lv = nesterov;
|
||||
uint8_t maximize_lv = maximize;
|
||||
uint8_t is_first_step_lv = is_first_step;
|
||||
|
||||
[computeEncoder setBuffer:tensorArgumentBuffer
|
||||
offset:0
|
||||
atIndex:0];
|
||||
[computeEncoder setBytes:&metadata_arguments
|
||||
length:sizeof(MetadataArguments)
|
||||
atIndex:1];
|
||||
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
|
||||
mtl_setBytes(computeEncoder, momentum_lv, 3);
|
||||
mtl_setBuffer(computeEncoder, lr, 4);
|
||||
mtl_setBytes(computeEncoder, dampening_lv, 5);
|
||||
mtl_setBytes(computeEncoder, nesterov_lv, 6);
|
||||
mtl_setBytes(computeEncoder, maximize_lv, 7);
|
||||
mtl_setBytes(computeEncoder, is_first_step_lv, 8);
|
||||
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
|
||||
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
|
||||
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
|
||||
mtl_setBytes(computeEncoder, momentum_lv, 3);
|
||||
mtl_setBuffer(computeEncoder, lr, 4);
|
||||
mtl_setBytes(computeEncoder, dampening_lv, 5);
|
||||
mtl_setBytes(computeEncoder, nesterov_lv, 6);
|
||||
mtl_setBytes(computeEncoder, maximize_lv, 7);
|
||||
mtl_setBytes(computeEncoder, is_first_step_lv, 8);
|
||||
}
|
||||
};
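Each encoding functor follows the same pattern: the argument buffer holding the tensor pointers is bound at index 0, the fixed-size MetadataArguments struct is copied inline at index 1, scalar hyper-parameters are narrowed to float/uint8_t and copied with mtl_setBytes, and an at::Tensor learning rate is instead bound with mtl_setBuffer so the kernel reads its current value on device. A hedged sketch of one call, assuming the encoder, argument buffer, and metadata have been set up as in multi_tensor_apply_for_fused_optimizer below; the hyper-parameter values are illustrative:

    // Illustrative only: encode one dispatch worth of fused-SGD-with-momentum arguments.
    FusedSgdEncodingFunctor<true> sgd_encoder;
    sgd_encoder(computeEncoder,
                tensorArgumentBuffer,
                metadata_arguments,
                /*weight_decay=*/1e-2,
                /*momentum=*/0.9,
                /*lr=*/1e-3,
                /*dampening=*/0.0,
                /*nesterov=*/false,
                /*maximize=*/false,
                /*is_first_step=*/true);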

template <>
struct FusedSgdEncodingFunctor<false> {
  void operator()(
    id<MTLComputeCommandEncoder>& computeEncoder,
    id<MTLBuffer>& tensorArgumentBuffer,
    const MetadataArguments& metadata_arguments,
    const double weight_decay,
    const double lr,
    const bool maximize
  ) const {
    float weight_decay_lv = weight_decay;
    float lr_lv = lr;
    uint8_t maximize_lv = maximize;
  void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
                  id<MTLBuffer>& tensorArgumentBuffer,
                  const MetadataArguments& metadata_arguments,
                  const double weight_decay,
                  const double lr,
                  const bool maximize) const {
    float weight_decay_lv = weight_decay;
    float lr_lv = lr;
    uint8_t maximize_lv = maximize;

    [computeEncoder setBuffer:tensorArgumentBuffer
                       offset:0
                      atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                      length:sizeof(MetadataArguments)
                     atIndex:1];
    mtl_setBytes(computeEncoder, weight_decay_lv, 2);
    mtl_setBytes(computeEncoder, lr_lv, 3);
    mtl_setBytes(computeEncoder, maximize_lv, 4);
    [computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
    [computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
    mtl_setBytes(computeEncoder, weight_decay_lv, 2);
    mtl_setBytes(computeEncoder, lr_lv, 3);
    mtl_setBytes(computeEncoder, maximize_lv, 4);
  }

  void operator()(
    id<MTLComputeCommandEncoder>& computeEncoder,
    id<MTLBuffer>& tensorArgumentBuffer,
    const MetadataArguments& metadata_arguments,
    const double weight_decay,
    const at::Tensor& lr,
    const bool maximize
  ) const {
    float weight_decay_lv = weight_decay;
    uint8_t maximize_lv = maximize;
  void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
                  id<MTLBuffer>& tensorArgumentBuffer,
                  const MetadataArguments& metadata_arguments,
                  const double weight_decay,
                  const at::Tensor& lr,
                  const bool maximize) const {
    float weight_decay_lv = weight_decay;
    uint8_t maximize_lv = maximize;

    [computeEncoder setBuffer:tensorArgumentBuffer
                       offset:0
                      atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                      length:sizeof(MetadataArguments)
                     atIndex:1];
    mtl_setBytes(computeEncoder, weight_decay_lv, 2);
    mtl_setBuffer(computeEncoder, lr, 3);
    mtl_setBytes(computeEncoder, maximize_lv, 4);
    [computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
    [computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
    mtl_setBytes(computeEncoder, weight_decay_lv, 2);
    mtl_setBuffer(computeEncoder, lr, 3);
    mtl_setBytes(computeEncoder, maximize_lv, 4);
  }
};
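A functor like the ones above is what multi_tensor_apply_for_fused_optimizer (next) receives as its encode argument; the trailing ArgTypes... are forwarded to it on every flush. A rough sketch of a caller, assuming a hypothetical depth of 4 (param/grad/exp_avg/exp_avg_sq), a kernel name, and a kThreadGroupSize constant defined elsewhere; all of these names and values are illustrative, not taken from this diff:

    // Hypothetical call site; the real kernel name and depth live outside this diff.
    std::vector<std::vector<at::Tensor>> tensor_lists = {params, grads, exp_avgs, exp_avg_sqs};
    multi_tensor_apply_for_fused_optimizer<4, kThreadGroupSize>(
        "fused_adam", tensor_lists, state_steps, FusedAdamEncodingFunctor(),
        lr, beta1, beta2, weight_decay, eps, maximize);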

template <int depth, uint32_t kThreadGroupSize, typename encoder_func_t, typename... ArgTypes>
static void multi_tensor_apply_for_fused_optimizer(
    const std::string& kernel_name,
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    at::TensorList state_steps,
    encoder_func_t encode,
    ArgTypes... args
) {
static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_name,
                                                    std::vector<std::vector<at::Tensor>>& tensor_lists,
                                                    at::TensorList state_steps,
                                                    encoder_func_t encode,
                                                    ArgTypes... args) {
  const auto num_tensors = tensor_lists[0].size();

  if (num_tensors == 0) {
    return;
  }

  TORCH_CHECK(
      tensor_lists.size() == depth,
      "Number of tensor lists has to match the depth");
  TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth");
  for (const auto& d : c10::irange(depth)) {
    TORCH_CHECK(
        tensor_lists[d][0].scalar_type() == at::ScalarType::Float || tensor_lists[d][0].scalar_type() == at::ScalarType::Half, "Only float and half are supported");
    TORCH_CHECK(tensor_lists[d][0].scalar_type() == at::ScalarType::Float ||
                    tensor_lists[d][0].scalar_type() == at::ScalarType::Half,
                "Only float and half are supported");
  }

  id<MTLDevice> device = MPSDevice::getInstance()->device();
@ -251,7 +211,8 @@ static void multi_tensor_apply_for_fused_optimizer(

      // BufferIndex is the index in the kernel function
      auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease];
      id<MTLBuffer> tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
      id<MTLBuffer> tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
                                                                 options:0] autorelease];
      [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];

      int64_t tensor_loc = 0;
@ -265,10 +226,11 @@ static void multi_tensor_apply_for_fused_optimizer(
        }

        for (const auto& d : c10::irange(depth)) {
          mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors + tensor_loc);
          [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageRead | MTLResourceUsageWrite];
          mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors + tensor_loc);
          [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
                                usage:MTLResourceUsageRead | MTLResourceUsageWrite];
        }
        if (state_steps.size() > 0){
        if (state_steps.size() > 0) {
          mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors + tensor_loc);
          [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
        }
@ -281,47 +243,50 @@ static void multi_tensor_apply_for_fused_optimizer(
        TORCH_CHECK(chunks > -1);

        for (const auto& chunk : c10::irange(chunks)) {
          metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1;
          metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk;
          metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1;
          metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk;

          threadgroup_loc++;
          threadgroup_loc++;

          const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1;
          // Reach the maximum threadgroups per dispatch
          const auto blocks_full = threadgroup_loc == kmaxThreadGroups;
          const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1;
          // Reach the maximum threadgroups per dispatch
          const auto blocks_full = threadgroup_loc == kmaxThreadGroups;

          if (tensor_full || blocks_full){
            encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...);
            MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1);
            uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup];
            MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1);
            [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
          if (tensor_full || blocks_full) {
            encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...);
            MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1);
            uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup];
            MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1);
            [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];

            // Reset
            threadgroup_loc = 0;
            if (chunk == chunks - 1) {
              // last chunk
              tensor_loc = 0;
              tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
              [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
            } else {
              // reuse the current tensor since the current one isn't done.
              metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1];
            // Reset
            threadgroup_loc = 0;
            if (chunk == chunks - 1) {
              // last chunk
              tensor_loc = 0;
              tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
                                                          options:0] autorelease];
              [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
            } else {
              // reuse the current tensor since the current one isn't done.
              metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1];

              tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
              [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
              tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
                                                          options:0] autorelease];
              [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];

              for (const auto& d : c10::irange(depth)) {
                mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors);
                [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageWrite | MTLResourceUsageRead];
              }
              if (state_steps.size() > 0){
                mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors);
                [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
              }
              tensor_loc = 1;
              }
              for (const auto& d : c10::irange(depth)) {
                mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors);
                [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
                                      usage:MTLResourceUsageWrite | MTLResourceUsageRead];
              }
              if (state_steps.size() > 0) {
                mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors);
                [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
              }
              tensor_loc = 1;
            }
          }
        }
      }

@ -334,7 +299,6 @@ static void multi_tensor_apply_for_fused_optimizer(
      }

      getMPSProfiler().endProfileKernel(fusedOptimizerPSO);

    }
  });
}
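The loop above assigns one Metal threadgroup per chunk of each tensor and flushes a dispatch whenever either kmaxTensors tensors have been packed into the argument buffer or kmaxThreadGroups entries have been written into the metadata table. The same threadgroup-to-(tensor, chunk) bookkeeping, written as a standalone sketch with an assumed kChunkSize (the real chunk size and the `chunks` computation sit outside the hunks shown here):

    // Sketch of the mapping recorded in threadgroup_to_tensor / threadgroup_to_chunk above.
    // kChunkSize is an assumed constant; the diff computes `chunks` from it elsewhere.
    #include <cstdint>
    #include <vector>

    struct Mapping {
      int64_t tensor;
      int64_t chunk;
    };

    std::vector<Mapping> plan_threadgroups(const std::vector<int64_t>& numels, int64_t kChunkSize) {
      std::vector<Mapping> out;
      for (int64_t t = 0; t < static_cast<int64_t>(numels.size()); ++t) {
        const int64_t chunks = (numels[t] + kChunkSize - 1) / kChunkSize; // ceil division
        for (int64_t c = 0; c < chunks; ++c) {
          out.push_back({t, c}); // one threadgroup handles one chunk of one tensor
        }
      }
      return out;
    }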