[BE][MPS] Apply clang-format to mps headers (#140906)

It was a mistake to amiss them in the past

All changes in this PR except ones to .lintrunner.toml are generated by running
`lintrunner -a --take CLANGFORMAT --all-files`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140906
Approved by: https://github.com/Skylion007
This commit is contained in:
Nikita Shulga
2024-11-17 10:13:01 -08:00
committed by PyTorch MergeBot
parent 5a7e147ef3
commit 99014a297c
24 changed files with 896 additions and 787 deletions

View File

@ -56,10 +56,12 @@ code = 'CLANGFORMAT'
include_patterns = [
'aten/src/ATen/*.h',
'aten/src/ATen/mps/**/*.mm',
'aten/src/ATen/mps/**/*.h',
'aten/src/ATen/xpu/**/*.h',
'aten/src/ATen/xpu/**/*.cpp',
'aten/src/ATen/native/mps/**/*.metal',
'aten/src/ATen/native/mps/**/*.mm',
'aten/src/ATen/native/mps/**/*.h',
'aten/src/ATen/native/vulkan/**/*.h',
'aten/src/ATen/native/vulkan/**/*.cpp',
'aten/src/ATen/native/cuda/MultiTensorApply.cuh',

View File

@ -12,8 +12,7 @@ C10_EXPORT TensorBase empty_mps(
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt);
C10_EXPORT TensorBase empty_mps(
IntArrayRef size, const TensorOptions &options);
C10_EXPORT TensorBase empty_mps(IntArrayRef size, const TensorOptions& options);
C10_EXPORT TensorBase empty_strided_mps(
IntArrayRef size,
@ -24,6 +23,6 @@ C10_EXPORT TensorBase empty_strided_mps(
C10_EXPORT TensorBase empty_strided_mps(
IntArrayRef size,
IntArrayRef stride,
const TensorOptions &options);
const TensorOptions& options);
} // namespace at::detail

View File

@ -2,7 +2,7 @@
namespace at::mps {
static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
static const char* SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
struct __attribute__ ((packed)) packed_uint5{{
uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};
@ -120,7 +120,7 @@ kernel void scatter_kernel_1(uint linear_index [[thread_position_in
}}
)METAL_SCATTER";
static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
static const char* GATHER_OPS_TEMPLATE = R"METAL_GATHER(
struct __attribute__ ((packed)) packed_uint5{{
uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};

View File

@ -6,12 +6,12 @@
#include <ATen/mps/MPSEvent.h>
#include <ATen/mps/MPSStream.h>
#include <c10/util/flat_hash_map.h>
#include <mach/vm_page_size.h>
#include <cstdio>
#include <mutex>
#include <set>
#include <unordered_set>
#include <mach/vm_page_size.h>
#include <c10/util/flat_hash_map.h>
// this implementation is based on CUDACachingAllocator.
// It utilizes Metal Heaps to improve the performance with buffer allocation.
@ -19,32 +19,34 @@
// TODO: Unify the logic with CUDACachingAllocator and remove redundant code.
namespace at::mps::HeapAllocator {
static const size_t kMaxSmallAlloc = MB(1); // largest "small" allocation is 1 MiB
static const size_t kMinLargeAlloc = MB(10); // allocations between 1 and 10 MiB may use kLargeHeap
static const size_t kRoundLarge = MB(2); // round up large allocations to 2 MiB
static const size_t kSmallHeap = MB(8); // "small" allocations are packed in 8 MiB heaps
static const size_t kLargeHeap = MB(32); // "large" allocations may be packed in 32 MiB heaps
static const size_t kXLargeHeapD = MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
static const size_t kXLargeHeapU = MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
static const size_t kMaxSmallAlloc = MB(1); // largest "small" allocation is 1 MiB
static const size_t kMinLargeAlloc = MB(10); // allocations between 1 and 10 MiB may use kLargeHeap
static const size_t kRoundLarge = MB(2); // round up large allocations to 2 MiB
static const size_t kSmallHeap = MB(8); // "small" allocations are packed in 8 MiB heaps
static const size_t kLargeHeap = MB(32); // "large" allocations may be packed in 32 MiB heaps
static const size_t kXLargeHeapD =
MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
static const size_t kXLargeHeapU =
MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
static const size_t kMaxScalarAlloc = (sizeof(int64_t)); // largest "scalar" allocation
// buffer pools could be customized with a combination of usage flags
enum UsageFlags : uint32_t {
PRIVATE = 0,
SMALL = (1 << 0), // small heaps have sizes of kSmallHeap, and large ones kLargeHeap
SHARED = (1 << 1), // shared pools allocated on devices with unified memory; otherwise, private between host/device
SMALL = (1 << 0), // small heaps have sizes of kSmallHeap, and large ones kLargeHeap
SHARED = (1 << 1), // shared pools allocated on devices with unified memory; otherwise, private between host/device
MANAGED = (1 << 2), // managed storage mode
HAZARD = (1 << 3), // enables Automatic Hazard Tracking for the resources allocated on the pool
SCALAR = (1 << 4), // used to import CPU scalar values to GPU and use them in MPS Stream
HAZARD = (1 << 3), // enables Automatic Hazard Tracking for the resources allocated on the pool
SCALAR = (1 << 4), // used to import CPU scalar values to GPU and use them in MPS Stream
};
// debug verbosity flags
enum DebugVerbosity : uint32_t {
SILENT = 0,
PROFILING = (1 << 0), // print generic profiling data for total system memory usage
SILENT = 0,
PROFILING = (1 << 0), // print generic profiling data for total system memory usage
ALLOCATIONS = (1 << 1), // print buffer allocations
RECYCLES = (1 << 2), // print buffer recycling
RELEASES = (1 << 3), // print buffer releases
LARGE_ONLY = (1 << 4), // only log large buffer pool transactions
RECYCLES = (1 << 2), // print buffer recycling
RELEASES = (1 << 3), // print buffer releases
LARGE_ONLY = (1 << 4), // only log large buffer pool transactions
};
struct HeapBlock;
@ -67,10 +69,8 @@ struct BufferBlock {
// Metal events used to sync GPU/CPU operations on the shared-storage buffers
MPSEventPtr event;
BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr,
HeapBlock* Heap = nullptr) :
buffer(Buffer), size(Size), requested_size(RequestedSize),
heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) { }
BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr, HeapBlock* Heap = nullptr)
: buffer(Buffer), size(Size), requested_size(RequestedSize), heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) {}
static bool Comparator(const BufferBlock* a, const BufferBlock* b) {
return (a->size != b->size) ? a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer;
@ -79,15 +79,19 @@ struct BufferBlock {
assert(((Alignment - 1) & Alignment) == 0);
return ((Size + Alignment - 1) & ~(Alignment - 1));
}
uint32_t retainCount() const { return [buffer retainCount]; }
uint32_t retainCount() const {
return [buffer retainCount];
}
};
typedef bool (*BufferComparison)(const BufferBlock*, const BufferBlock*);
struct BufferPool;
struct AllocParams {
AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool) :
search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) { }
size_t size() const { return search_key.size; }
AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool)
: search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) {}
size_t size() const {
return search_key.size;
}
BufferBlock search_key;
BufferPool* pool;
@ -102,7 +106,9 @@ struct AllocParams {
struct HeapBlock {
id<MTLHeap> heap;
struct { size_t total, available; } size;
struct {
size_t total, available;
} size;
BufferPool* pool;
unsigned int n_buffers = 0;
id_t heap_id;
@ -111,9 +117,12 @@ struct HeapBlock {
// counter to assign unique ids to heap blocks
static uint64_t heap_counter;
HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool *Pool = nullptr) :
heap(Heap), size({.total = Size, .available = Size}), pool(Pool),
heap_id(Heap ? ++heap_counter : 0), is_split(true) { }
HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool* Pool = nullptr)
: heap(Heap),
size({.total = Size, .available = Size}),
pool(Pool),
heap_id(Heap ? ++heap_counter : 0),
is_split(true) {}
static MTLResourceOptions getOptions(uint32_t usage) {
// TODO: check the caching performance of write-combined mode
@ -126,16 +135,17 @@ struct HeapBlock {
else
options |= MTLResourceStorageModePrivate;
options |= (usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;
options |=
(usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;
return options;
}
static HeapBlock* createHeapBlock(AllocParams& params, id<MTLDevice> device, uint32_t usage) {
HeapBlock *heapBlock = nullptr;
HeapBlock* heapBlock = nullptr;
bool is_split = true;
const size_t size = params.size();
MTLHeapDescriptor *d = [MTLHeapDescriptor new];
MTLHeapDescriptor* d = [MTLHeapDescriptor new];
if (d) {
const size_t kXLargeHeap = params.has_unified_memory ? kXLargeHeapU : kXLargeHeapD;
if (size <= kMaxSmallAlloc) {
@ -152,10 +162,11 @@ struct HeapBlock {
d.cpuCacheMode = MTLCPUCacheModeDefaultCache;
// this automatically handles Metal buffer access synchronizations at the
// cost of slightly lower performance.
d.hazardTrackingMode = (usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
d.hazardTrackingMode =
(usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
d.resourceOptions = getOptions(usage);
d.type = MTLHeapTypeAutomatic;
id<MTLHeap> heap = [device newHeapWithDescriptor: d];
id<MTLHeap> heap = [device newHeapWithDescriptor:d];
if (heap) {
[heap setPurgeableState:MTLPurgeableStateNonVolatile];
const size_t heap_size = heapAvailableSize(heap);
@ -169,8 +180,8 @@ struct HeapBlock {
return heapBlock;
}
static bool Comparator(const HeapBlock* a, const HeapBlock* b) {
return (a->size.available != b->size.available) ? a->size.available < b->size.available :
(uintptr_t)a->heap < (uintptr_t)b->heap;
return (a->size.available != b->size.available) ? a->size.available < b->size.available
: (uintptr_t)a->heap < (uintptr_t)b->heap;
}
static NSUInteger heapAvailableSize(id<MTLHeap> heap, size_t Alignment = vm_page_size) {
return [heap maxAvailableSizeWithAlignment:Alignment];
@ -205,8 +216,12 @@ struct HeapBlock {
size.available = 0;
return retainCount;
}
uint32_t retainCount() const { return [heap retainCount]; }
void updateAvailableSize() { size.available = heapAvailableSize(heap); }
uint32_t retainCount() const {
return [heap retainCount];
}
void updateAvailableSize() {
size.available = heapAvailableSize(heap);
}
};
typedef bool (*HeapComparison)(const HeapBlock*, const HeapBlock*);
@ -219,9 +234,8 @@ struct BufferPool {
SCALAR,
};
BufferPool(const id<MTLDevice> Device, uint32_t Usage) :
device(Device), usage(Usage),
heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) { }
BufferPool(const id<MTLDevice> Device, uint32_t Usage)
: device(Device), usage(Usage), heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) {}
const id<MTLDevice> device;
// usage flags to customize the pool for various purposes (see UsageFlags enum)
@ -248,12 +262,12 @@ struct BufferPool {
};
class MPSHeapAllocatorImpl {
public:
explicit MPSHeapAllocatorImpl() :
m_device(at::mps::MPSDevice::getInstance()->device()),
m_max_buffer_size([m_device maxBufferLength]),
m_stream(getDefaultMPSStream()),
m_event_pool(getMPSEventPool()) {
public:
explicit MPSHeapAllocatorImpl()
: m_device(at::mps::MPSDevice::getInstance()->device()),
m_max_buffer_size([m_device maxBufferLength]),
m_stream(getDefaultMPSStream()),
m_event_pool(getMPSEventPool()) {
init_allocator();
}
~MPSHeapAllocatorImpl() {
@ -298,34 +312,50 @@ public:
// (see m_high_watermark_ratio for description)
void setHighWatermarkRatio(double ratio);
// (see m_low_watermark_limit for description)
size_t getLowWatermarkLimit() const { return m_low_watermark_limit; }
size_t getLowWatermarkLimit() const {
return m_low_watermark_limit;
}
// (see m_max_total_allowed_size for description)
size_t getHighWatermarkLimit() const { return m_max_total_allowed_size; }
size_t getHighWatermarkLimit() const {
return m_max_total_allowed_size;
}
// (see m_total_allocated_memory for description)
size_t getTotalAllocatedMemory() const { return m_total_allocated_memory; }
size_t getTotalAllocatedMemory() const {
return m_total_allocated_memory;
}
// (see m_current_allocated_memory for description)
size_t getCurrentAllocatedMemory() const { return m_current_allocated_memory; }
size_t getCurrentAllocatedMemory() const {
return m_current_allocated_memory;
}
// total GPU memory allocated in the process by Metal driver; including
// implicit allocations from MPS/MPSGraph frameworks and MPSHeapAllocatorImpl.
size_t getDriverAllocatedMemory() const { return current_allocated_size(); }
size_t getDriverAllocatedMemory() const {
return current_allocated_size();
}
// recommended Max memory for Metal
size_t getRecommendedMaxMemory() const { return max_device_size(); }
size_t getRecommendedMaxMemory() const {
return max_device_size();
}
// (see enum DebugVerbosity for description)
uint32_t getDebugVerbosity() const { return m_debug_verbosity; }
uint32_t getDebugVerbosity() const {
return m_debug_verbosity;
}
// returns the device that we allocate from
inline id<MTLDevice> Device() const { return m_device; }
inline id<MTLDevice> Device() const {
return m_device;
}
// TODO: make a common function to do size unit conversions in PyTorch.
inline std::string format_size(uint64_t size) const;
private:
private:
// (see m_high_watermark_ratio for description)
constexpr static double default_high_watermark_ratio = 1.7;
// we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize.
constexpr static double default_high_watermark_upper_bound = 2.0;
// (see m_low_watermark_ratio for description)
// on unified memory, we could allocate beyond the recommendedMaxWorkingSetSize
constexpr static double default_low_watermark_ratio_unified = 1.4;
constexpr static double default_low_watermark_ratio_unified = 1.4;
constexpr static double default_low_watermark_ratio_discrete = 1.0;
const id<MTLDevice> m_device;
@ -387,14 +417,19 @@ private:
size_t get_allocation_size(size_t size, uint32_t usage) const;
// maximum size of device memory available for allocation in current process
// Note: the recommendedMaxWorkingSetSize is typically 75% of the total system memory.
size_t max_device_size() const { return [m_device recommendedMaxWorkingSetSize]; }
size_t max_device_size() const {
return [m_device recommendedMaxWorkingSetSize];
}
// there are implicit allocations from MPS backend, so we need to query the 'device' for
// total allocated size instead of manually tracking in MPSAllocator
size_t current_allocated_size() const { return [m_device currentAllocatedSize]; }
size_t current_allocated_size() const {
return [m_device currentAllocatedSize];
}
bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) const {
for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block ? buffer_block->buffer : nullptr, event);
MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(
buffer_block ? buffer_block->buffer : nullptr, event);
}
return true;
}

View File

@ -2,9 +2,9 @@
#pragma once
#include <ATen/core/ATen_fwd.h>
#include <c10/core/Allocator.h>
#include <c10/util/Registry.h>
#include <ATen/core/ATen_fwd.h>
#define MB(x) (x * 1048576UL)
@ -13,17 +13,19 @@ namespace at::mps {
// this is a public interface to access MPSAllocator.
// Do not declare methods that would depend on MPS or Metal frameworks.
class IMPSAllocator : public c10::Allocator {
public:
public:
// see the comments in MPSAllocator.h for the description of these methods.
virtual void emptyCache() const = 0;
virtual void freeInactiveBuffers() const = 0;
virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0;
virtual IntArrayRef getBufferShape(const void* ptr) const = 0;
virtual id_t getBufferId(const void* ptr) const = 0;
virtual void setBufferShape(const void* ptr, const IntArrayRef& shape) const = 0;
virtual void setBufferShape(const void* ptr, const IntArrayRef& shape)
const = 0;
virtual bool isSharedBuffer(const void* ptr) const = 0;
virtual bool isSharedStorageSupported() const = 0;
virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0;
virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size)
const = 0;
virtual std::string formatSize(size_t size) const = 0;
virtual void setLowWatermarkRatio(double ratio) const = 0;
virtual void setHighWatermarkRatio(double ratio) const = 0;
@ -34,7 +36,8 @@ public:
virtual size_t getCurrentAllocatedMemory() const = 0;
virtual size_t getDriverAllocatedMemory() const = 0;
virtual size_t getRecommendedMaxMemory() const = 0;
virtual std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const = 0;
virtual std::pair<const void*, uint32_t> getSharedBufferPtr(
const void* ptr) const = 0;
virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0;
virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0;
};
@ -43,16 +46,17 @@ class IMpsAllocatorCallback {
public:
enum class EventType {
ALLOCATED, // buffer got allocated to be used immediately
RECYCLED, // buffer pulled from free list to be reused
FREED, // buffer put to free list for future recycling
RELEASED, // buffer memory released
RECYCLED, // buffer pulled from free list to be reused
FREED, // buffer put to free list for future recycling
RELEASED, // buffer memory released
ALLOCATION_FAILED // buffer allocation failed
};
virtual ~IMpsAllocatorCallback() = default;
virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
};
// MPS allocator will execute every registered callback when a block of memory is freed.
// MPS allocator will execute every registered callback when a block of memory
// is freed.
TORCH_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__)

View File

@ -5,7 +5,6 @@
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>

View File

@ -11,7 +11,7 @@ namespace at::mps {
// NOTE: don't create instances of this class directly.
// Use MPSEventPool to acquire instances of MPSEvent.
class MPSEvent {
public:
public:
explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
~MPSEvent();
@ -26,16 +26,21 @@ public:
// blocks the CPU thread until all the GPU work that were scheduled
// prior to recording this event are completed.
bool synchronize();
// resets this event with new parameters in case it gets reused from the event pool
// resets this event with new parameters in case it gets reused from the event
// pool
void reset(MPSStream* stream, bool enable_timing);
// returns the unique ID of the event instance
id_t getID() const { return m_id; }
id_t getID() const {
return m_id;
}
// returns the completion timestamp of the event
uint64_t getCompletionTime() const { return m_completion_time; }
uint64_t getCompletionTime() const {
return m_completion_time;
}
// if already recorded, waits for cpu_sync_cv to be signaled
void waitForCpuSync();
private:
private:
id_t m_id;
// enables measuring the completion time of the notifyListener of this event
bool m_enable_timing;
@ -63,7 +68,7 @@ private:
typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;
class MPSEventPool {
public:
public:
explicit MPSEventPool(MPSStream* default_stream);
~MPSEventPool();
@ -80,7 +85,7 @@ public:
// returns elapsed time between two recorded events in milliseconds
double elapsedTime(id_t start_event_id, id_t end_event_id);
private:
private:
MPSStream* m_default_stream = nullptr;
std::recursive_mutex m_mutex;
std::stack<std::unique_ptr<MPSEvent>> m_pool{};

View File

@ -17,7 +17,8 @@ struct rng_data_pod {
};
TORCH_API const Generator& getDefaultMPSGenerator();
TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
TORCH_API Generator
createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
} // namespace mps::detail
@ -37,12 +38,20 @@ struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
void update_philox_counters();
void set_engine(at::Philox4_32 engine) { engine_ = engine; }
at::Philox4_32 engine() { return engine_; }
uint32_t* state_data() { return data_.state.data(); }
static DeviceType device_type() { return DeviceType::MPS; }
void set_engine(at::Philox4_32 engine) {
engine_ = engine;
}
at::Philox4_32 engine() {
return engine_;
}
uint32_t* state_data() {
return data_.state.data();
}
static DeviceType device_type() {
return DeviceType::MPS;
}
private:
private:
mps::detail::rng_data_pod data_;
at::Philox4_32 engine_;

View File

@ -1,12 +1,12 @@
// Copyright © 2022 Apple Inc.
#pragma once
#include <ATen/Context.h>
#include <ATen/mps/MPSEvent.h>
#include <ATen/mps/MPSStream.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <ATen/Context.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSEvent.h>
#ifdef __OBJC__
#include <Foundation/Foundation.h>
@ -18,11 +18,10 @@
#include <c10/core/MemoryFormat.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorImpl.h>
#include <sys/_types/_size_t.h>
#include <memory>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/util/intrusive_ptr.h>
#include <sys/_types/_size_t.h>
#include <memory>
namespace at::mps {
@ -30,7 +29,8 @@ typedef MPSEvent* mpsEvent_t;
// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
// https://github.com/pytorch/pytorch/issues/77170
struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
struct TORCH_API MPSGuardImpl final
: public c10::impl::DeviceGuardImplInterface {
static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;
// constructor
@ -83,7 +83,7 @@ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface
}
DeviceIndex deviceCount() const noexcept override {
if (at::hasMPS()) {
//TODO: extend it for multi-device case
// TODO: extend it for multi-device case
return 1;
} else {
return 0;
@ -91,28 +91,22 @@ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface
}
// Event-related functions
void createEvent(
mpsEvent_t* event,
const EventFlag flag) const;
void createEvent(mpsEvent_t* event, const EventFlag flag) const;
void destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept override;
void destroyEvent(void* event, const DeviceIndex device_index)
const noexcept override;
void record(
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override;
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override;
void block(
void* event,
const Stream& stream) const override;
void block(void* event, const Stream& stream) const override;
bool queryEvent(void* event) const override;
void synchronizeDevice(const DeviceIndex device_index) const override;
};
/// A variant of OptionalDeviceGuard that is specialized for MPS.
@ -175,7 +169,6 @@ struct OptionalMPSGuard {
c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
};
C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl)
} // namespace at::mps

View File

@ -2,8 +2,8 @@
#pragma once
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/Generator.h>
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/mps/MPSEvent.h>
#include <optional>
@ -38,7 +38,8 @@ struct MPSHooks : public at::MPSHooksInterface {
Allocator* getPinnedMemoryAllocator() const override;
// MPSProfiler interface
void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override;
void profilerStartTrace(const std::string& mode, bool waitUntilCompleted)
const override;
void profilerStopTrace() const override;
// MPSEvent interface
@ -48,7 +49,8 @@ struct MPSHooks : public at::MPSHooksInterface {
void waitForEvent(uint32_t event_id) const override;
void synchronizeEvent(uint32_t event_id) const override;
bool queryEvent(uint32_t event_id) const override;
double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;
double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id)
const override;
// Compatibility with Accelerator API
bool hasPrimaryContext(DeviceIndex device_index) const override {

View File

@ -3,11 +3,11 @@
#pragma once
#include <ATen/Tensor.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSStream.h>
#include <os/signpost.h>
#include <os/log.h>
#include <os/signpost.h>
#include <atomic>
#include <ctime>
@ -29,8 +29,8 @@ struct BaseInfo {
CPU_FALLBACK,
};
BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle) :
type(infoType), profileId(Id), handle(Handle) { }
BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle)
: type(infoType), profileId(Id), handle(Handle) {}
virtual ~BaseInfo() = default;
// type of profiling info
@ -41,30 +41,36 @@ struct BaseInfo {
// since it's possible to use event and interval-based signposts at the
// same time, we need separate IDs for each.
os_signpost_id_t eventSignpostId = 0, intervalSignpostId = 0;
// accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime - GPUStartTime")
// accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime -
// GPUStartTime")
std::atomic<double> totalGpuTime{0.0};
// accumulated Scheduling time in ms (obtained from CompletionHandler's "KernelEndTime - KernelStartTime")
// accumulated Scheduling time in ms (obtained from CompletionHandler's
// "KernelEndTime - KernelStartTime")
std::atomic<double> totalSchedulingTime{0.0};
// indicates if the operation or copy execution has completed
std::atomic_bool completed{false};
// handle used to identify the profile info's instance (usually the pointer)
const uintptr_t handle;
virtual const std::string toString(double gpuTime = 0, double schedulingTime = 0) const;
virtual const std::string toString(
double gpuTime = 0,
double schedulingTime = 0) const;
// builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
static std::string buildTensorString(const Tensor& tensor, bool includeBufferId = false) {
static std::string buildTensorString(
const Tensor& tensor,
bool includeBufferId = false) {
if (tensor.defined()) {
std::stringstream tensorStr;
auto deviceType = tensor.device().type();
tensorStr << c10::DeviceTypeName(deviceType);
// see comments for INCLUDE_BUFFER_ID
if (includeBufferId && deviceType == at::kMPS) {
id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer))
<< ":" << buffer.retainCount << ")";
id<MTLBuffer> buffer =
__builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ":"
<< buffer.retainCount << ")";
}
tensorStr << ":"
<< tensor.scalar_type() << tensor.sizes();
tensorStr << ":" << tensor.scalar_type() << tensor.sizes();
return tensorStr.str();
} else {
return "undefined";
@ -76,21 +82,28 @@ struct BaseInfo {
};
struct OperationInfo : BaseInfo {
OperationInfo(const void* Handle, bool IsGraph, uint64_t Id, const std::string& StrKey) :
BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)), strKey(StrKey) { }
OperationInfo(
const void* Handle,
bool IsGraph,
uint64_t Id,
const std::string& StrKey)
: BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)),
strKey(StrKey) {}
uint64_t runCount = 0;
std::string strKey;
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
const std::string toString(double gpuTime = 0, double schedulingTime = 0)
const override;
// builds a string for a kernel
static std::string buildKernelString(const std::string& kernelName,
const TensorList& tensors,
bool includeBufferId = false) {
static std::string buildKernelString(
const std::string& kernelName,
const TensorList& tensors,
bool includeBufferId = false) {
std::stringstream kernelStr;
kernelStr << kernelName;
for (const Tensor& tensor: tensors) {
for (const Tensor& tensor : tensors) {
kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId);
}
return kernelStr.str();
@ -98,23 +111,24 @@ struct OperationInfo : BaseInfo {
};
struct CpuFbInfo : BaseInfo {
CpuFbInfo(uint64_t Id, const std::string& OpName) :
BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) { }
CpuFbInfo(uint64_t Id, const std::string& OpName)
: BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) {}
uint64_t runCount = 0;
// the current and total overhead of copies in bytes required to convert the Op's
// input tensors from MPS to CPU and then output from CPU back to MPS
// the current and total overhead of copies in bytes required to convert the
// Op's input tensors from MPS to CPU and then output from CPU back to MPS
size_t currentCopyOverhead = 0;
size_t totalCopyOverhead = 0;
std::string opName;
std::string strKey;
uint64_t startTime = 0;
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
const std::string toString(double gpuTime = 0, double schedulingTime = 0)
const override;
void updateCopyOverhead(const TensorList& tensors) {
currentCopyOverhead = 0;
for (const Tensor& tensor: tensors) {
for (const Tensor& tensor : tensors) {
if (tensor.defined()) {
currentCopyOverhead += tensor.nbytes();
}
@ -130,9 +144,17 @@ struct CopyInfo : BaseInfo {
CPU_TO_MPS,
};
CopyInfo(const void* Handle, size_t Length, uint64_t Id, bool IsNonBlocking, bool UsesBlitter) :
BaseInfo(Type::COPY, Id, uintptr_t(Handle)), kind(Kind::MPS_TO_MPS),
length(Length), isNonBlocking(IsNonBlocking), usesBlitter(UsesBlitter) { }
CopyInfo(
const void* Handle,
size_t Length,
uint64_t Id,
bool IsNonBlocking,
bool UsesBlitter)
: BaseInfo(Type::COPY, Id, uintptr_t(Handle)),
kind(Kind::MPS_TO_MPS),
length(Length),
isNonBlocking(IsNonBlocking),
usesBlitter(UsesBlitter) {}
Kind kind;
size_t length;
@ -143,11 +165,17 @@ struct CopyInfo : BaseInfo {
// for copies that don't use blitters, we measure CPU time
uint64_t startTime = 0;
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
const std::string toString(double gpuTime = 0, double schedulingTime = 0)
const override;
static std::string buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId = false);
static std::string buildTensorString(
const void* buffer,
const OptionalTensorRef tensor,
bool includeBufferId = false);
static bool isStorageOnMPS(const void* buffer, const OptionalTensorRef tensor) {
static bool isStorageOnMPS(
const void* buffer,
const OptionalTensorRef tensor) {
if (tensor.has_value()) {
return tensor->device().type() == at::kMPS;
}
@ -156,8 +184,11 @@ struct CopyInfo : BaseInfo {
return getIMPSAllocator()->getUnalignedBufferSize(buffer) >= 0;
}
static Kind getCopyKind(const void* srcBuffer, const void* dstBuffer,
const OptionalTensorRef srcTensor, const OptionalTensorRef dstTensor) {
static Kind getCopyKind(
const void* srcBuffer,
const void* dstBuffer,
const OptionalTensorRef srcTensor,
const OptionalTensorRef dstTensor) {
const bool isSrcOnMPS = isStorageOnMPS(srcBuffer, srcTensor);
const bool isDstOnMPS = isStorageOnMPS(dstBuffer, dstTensor);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isSrcOnMPS || isDstOnMPS);
@ -171,8 +202,9 @@ struct CopyInfo : BaseInfo {
};
struct CopyStat : CopyInfo {
explicit CopyStat(std::string CopyKindStr) :
CopyInfo(nullptr, 0, 0, false, false), kindStr(std::move(CopyKindStr)) {}
explicit CopyStat(std::string CopyKindStr)
: CopyInfo(nullptr, 0, 0, false, false),
kindStr(std::move(CopyKindStr)) {}
// total number of copies
size_t totalCount = 0;
// number of Scalar copies (i.e., less than sizeof(int64))
@ -188,29 +220,29 @@ struct CopyStat : CopyInfo {
};
class MPSProfiler {
public:
public:
// lower 16 bits used for profiler options
enum ProfileOptions : uint32_t {
OPTIONS_NONE = 0,
// ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK, etc.)
// (used for convenience to not compute bit flags by OR-ing manually)
// ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK,
// etc.) (used for convenience to not compute bit flags by OR-ing manually)
// trace all signpost types using events
ALL_SIGNPOST_EVENTS = (1 << 0),
ALL_SIGNPOST_EVENTS = (1 << 0),
// trace all signpost types using intervals
ALL_SIGNPOST_INTERVALS = (1 << 1),
// always wait for command buffer to finish executing after each commit
WAIT_UNTIL_COMPLETED = (1 << 2),
WAIT_UNTIL_COMPLETED = (1 << 2),
// for interval-based signposts, include the scheduling portion of
// Graph/Kernel/Copy executions as well.
// if flag is disable, only "GPU run time" is included in interval,
// and not schedule time.
INCLUDE_SCHEDULE_INTERVAL = (1 << 3),
// use these if you need to trace signposts types individually (rarely required)
// trace signpost using intervals
// use these if you need to trace signposts types individually (rarely
// required) trace signpost using intervals
USE_INTERVALS = (1 << 4),
// trace signpost by emitting events
USE_EVENTS = (1 << 5),
USE_EVENTS = (1 << 5),
// used for sanity check (Change this when new option added)
OPTIONS_COUNT = (USE_EVENTS << 1) - 1,
};
@ -222,9 +254,9 @@ public:
// trace signposts for PyTorch operation executions
RUN_OPERATION = (1 << 16),
// trace signposts for blitter copies
BLIT_COPY = (1 << 17),
BLIT_COPY = (1 << 17),
// trace signposts for ops that fall back on CPU
CPU_FALLBACK = (1 << 18),
CPU_FALLBACK = (1 << 18),
// used for sanity check (Change this when new type added)
SIGNPOST_COUNT = (CPU_FALLBACK << 1) - 1,
};
@ -235,39 +267,44 @@ public:
// Info logging options during execution
// -------------------------------------
// prints operation info (id/key/run_count) during execution
OPERATION_INFO = (1 << 0),
OPERATION_INFO = (1 << 0),
// prints copy info (src/dst tensors/buffers, size, etc.) during execution
COPY_INFO = (1 << 1),
// prints CPU Fallback info (id/runCount/opName/copyOverhead) during execution
CPU_FALLBACK_INFO = (1 << 2),
COPY_INFO = (1 << 1),
// prints CPU Fallback info (id/runCount/opName/copyOverhead) during
// execution
CPU_FALLBACK_INFO = (1 << 2),
// Profiling Statistics logging options when process terminates
// ------------------------------------------------------------
// prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before process terminates
// this is convenient to not combine following stats bit flags manually
ALL_STATS = (1 << 3),
// prints operation stats (GPU times, run count, etc.) before process terminates
OPERATION_STATS = (1 << 4),
// prints copies stats (GPU times, copy kinds, sizes, etc.) before process terminates
COPY_STATS = (1 << 5),
// prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before
// process terminates this is convenient to not combine following stats bit
// flags manually
ALL_STATS = (1 << 3),
// prints operation stats (GPU times, run count, etc.) before process
// terminates
OPERATION_STATS = (1 << 4),
// prints copies stats (GPU times, copy kinds, sizes, etc.) before process
// terminates
COPY_STATS = (1 << 5),
// prints CPU Fallback stats (CPU times, run times, size of MPS<->CPU copies
// for tensors, etc.) before process terminates
CPU_FALLBACK_STATS = (1 << 6),
CPU_FALLBACK_STATS = (1 << 6),
// Metadata format options when logging the info
// ---------------------------------------------
// if enabled, includes GPU run time in metadata (i.e., GPUEndTime-GPUStartTime
// from Metal Command Buffers) (e.g., [GPU=0.324 ms])
INCLUDE_GPU_TIME = (1 << 7),
// if enabled, includes GPU run time in metadata (i.e.,
// GPUEndTime-GPUStartTime from Metal Command Buffers) (e.g., [GPU=0.324
// ms])
INCLUDE_GPU_TIME = (1 << 7),
// if enabled, includes GPU scheduling time in metadata separately
// (i.e., KernelEndTime-KernelStartTime from Metal Command Buffers)
// e.g., [GPU=0.324 ms, KRNL=0.036 ms]
INCLUDE_KERNEL_TIME = (1 << 8),
// if enabled, includes the unique buffer ID in metadata for the storage
// of a tensor that was allocated on MPSAllocator. This is useful (along with
// the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are involved
// with various operations.
INCLUDE_BUFFER_ID = (1 << 9),
// of a tensor that was allocated on MPSAllocator. This is useful (along
// with the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are
// involved with various operations.
INCLUDE_BUFFER_ID = (1 << 9),
// used for sanity check (Change this when new option added)
LOG_COUNT = (INCLUDE_BUFFER_ID << 1) - 1,
@ -276,15 +313,28 @@ public:
explicit MPSProfiler();
~MPSProfiler();
// the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
// the beginProfile*() functions return a profileId which is unique per graph/kernel/copy
uint64_t beginProfileKernel(const void* handle, const std::string& strKey, bool isGraph);
uint64_t beginProfileKernel(const void* handle, const std::string& kernelName, const TensorList& tensors);
uint64_t beginProfileCopy(const void* srcBuffer, const void* dstBuffer,
const OptionalTensorRef srcTensor,
const OptionalTensorRef dstTensor,
size_t length, bool isNonBlocking, bool usesBlitter = true);
uint64_t beginProfileCPUFallback(const std::string& opName, const TensorList& tensors);
// the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal
// Kernels the beginProfile*() functions return a profileId which is unique
// per graph/kernel/copy
uint64_t beginProfileKernel(
const void* handle,
const std::string& strKey,
bool isGraph);
uint64_t beginProfileKernel(
const void* handle,
const std::string& kernelName,
const TensorList& tensors);
uint64_t beginProfileCopy(
const void* srcBuffer,
const void* dstBuffer,
const OptionalTensorRef srcTensor,
const OptionalTensorRef dstTensor,
size_t length,
bool isNonBlocking,
bool usesBlitter = true);
uint64_t beginProfileCPUFallback(
const std::string& opName,
const TensorList& tensors);
void beginProfileGPUInterval(const void* handle);
void endProfileCopy(uint64_t profileId, SyncType syncType);
@ -309,22 +359,25 @@ public:
// logging are enabled for the SignpostTypes
bool isOperationProfilingEnabled() const {
return (m_signpost_types & SignpostTypes::RUN_OPERATION) ||
(m_log_options & (LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
(m_log_options &
(LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
}
bool isCopyProfilingEnabled() const {
return (m_signpost_types & SignpostTypes::BLIT_COPY) ||
(m_log_options & (LogOptions::COPY_INFO | LogOptions::COPY_STATS));
(m_log_options & (LogOptions::COPY_INFO | LogOptions::COPY_STATS));
}
bool isCPUFallbackProfilingEnabled() const {
return (m_signpost_types & SignpostTypes::CPU_FALLBACK) ||
(m_log_options & (LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
(m_log_options &
(LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
}
bool isSignpostTracingEnabled() const {
return (m_signpost_types != SignpostTypes::SIGNPOST_NONE);
}
private:
// indicates what type of signpost types are enabled and traced by MPS profiler.
// indicates what type of signpost types are enabled and traced by MPS
// profiler.
uint32_t m_signpost_types = 0;
uint32_t m_profile_options = 0;
uint32_t m_log_options = 0;
@ -332,14 +385,15 @@ public:
uint64_t m_graph_counter = 0;
uint64_t m_cpu_fb_counter = 0;
uint64_t m_copy_counter = 0;
// technically, it's possible to trace both events and intervals at the same time
// so we use separate os_log categories for them
// technically, it's possible to trace both events and intervals at the same
// time so we use separate os_log categories for them
os_log_t m_os_log_events;
os_log_t m_os_log_intervals;
// stats logging could run either from destructor or signal handler
// so this is used to check if logging has already started.
std::atomic_bool hasLoggedStats{false};
// indicates there are pending completionHandler callbacks that haven't been called yet.
// indicates there are pending completionHandler callbacks that haven't been
// called yet.
std::atomic_bool hasPendingCompletionHandlers{false};
// used to capture sigint signal to log profiling stats
static struct sigaction currentSigint, previousSigint;
@ -347,40 +401,62 @@ public:
// We use the following lists for two reasons:
// 1- for interval-based signposts the "begin" point won't be in same function
// as the "end" point where we need to be able to retrieve signpost's info
// 2- if Operations info need to be logged when process ends using LogOptions::OPERATION_INFO.
// 2- if Operations info need to be logged when process ends using
// LogOptions::OPERATION_INFO.
// the pointer key for this map is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
// this list is retained and could be logged along with aggregate profiling numbers when the process ends.
std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>> m_op_info_list{};
// the string key for this map is the op name that we fall back to execute on CPU
// this list is retained and could be logged along with aggregate profiling numbers when the process ends.
std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>> m_cpu_fb_info_list{};
// the pointer key for this map is either "MPSGraph*" or
// "id<MTLComputePipelineState>" for Metal Kernels this list is retained and
// could be logged along with aggregate profiling numbers when the process
// ends.
std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>>
m_op_info_list{};
// the string key for this map is the op name that we fall back to execute on
// CPU this list is retained and could be logged along with aggregate
// profiling numbers when the process ends.
std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>>
m_cpu_fb_info_list{};
// this list contains the info for copies, and its key is the unique profileId
// which is generated from m_copy_counter
// The copyInfo list is not retained.
std::unordered_map<uint64_t, std::unique_ptr<CopyInfo>> m_copy_info_list{};
// a short list that contains copy stats
std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>> m_copy_stat_list{};
std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>>
m_copy_stat_list{};
mutable MTLCaptureManager *captureManager = nil;
mutable MTLCaptureManager* captureManager = nil;
unsigned captureCount = 0;
void initialize();
void beginProfileExecution(BaseInfo& info, bool cpuExecution = false);
void endProfileExecution(BaseInfo& info, os_signpost_id_t event_signpost_id,
os_signpost_id_t interval_signpost_id,
double gpuTime, double schedulingTime);
void endProfileExecution(
BaseInfo& info,
os_signpost_id_t event_signpost_id,
os_signpost_id_t interval_signpost_id,
double gpuTime,
double schedulingTime);
void addProfilerScheduledHandler(BaseInfo& info);
void addProfilerCompletedHandler(BaseInfo& info, SyncType syncType);
void emitSignpostEvent(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
const std::string& msg) const;
void beginSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
const std::string& msg) const;
void endSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id) const;
void emitSignpostEvent(
SignpostTypes signpost_type,
os_signpost_id_t signpost_id,
const std::string& msg) const;
void beginSignpostInterval(
SignpostTypes signpost_type,
os_signpost_id_t signpost_id,
const std::string& msg) const;
void endSignpostInterval(
SignpostTypes signpost_type,
os_signpost_id_t signpost_id) const;
void updateCopyStats(const CopyInfo& copyInfo, double gpuTime, double schedulingTime);
// returns true if logging the profiling info "during the execution" is enabled
bool isProfileInfoLoggingEnabled(BaseInfo::Type infoType, bool isExecutionEnded);
void updateCopyStats(
const CopyInfo& copyInfo,
double gpuTime,
double schedulingTime);
// returns true if logging the profiling info "during the execution" is
// enabled
bool isProfileInfoLoggingEnabled(
BaseInfo::Type infoType,
bool isExecutionEnded);
// logs all the profiling stats that are enabled
void logProfilingStats();
// logs kernel profiling stats when the process ends.
@ -390,7 +466,9 @@ public:
// logs copy profiling stats when the process ends.
void logCopyProfilingStats(std::FILE* f) const;
os_signpost_id_t generateSignpostId(os_signpost_type_t signpostType, const void* ptr = nullptr);
os_signpost_id_t generateSignpostId(
os_signpost_type_t signpostType,
const void* ptr = nullptr);
static SignpostTypes getSignpostType(BaseInfo::Type infoType);
static void handleIntSignal(int signal);
};

View File

@ -5,10 +5,10 @@
#include <cstdint>
#include <utility>
#include <c10/core/DeviceGuard.h>
#include <c10/util/Exception.h>
#include <c10/core/Stream.h>
#include <ATen/mps/MPSDevice.h>
#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/util/Exception.h>
#ifdef __OBJC__
#include <Foundation/Foundation.h>
@ -32,7 +32,6 @@ typedef void* MTLDevice_t;
#define nil NULL;
#endif
namespace at::mps {
//-----------------------------------------------------------------
@ -40,16 +39,15 @@ namespace at::mps {
//-----------------------------------------------------------------
enum class SyncType {
NONE, // no commit to command buffer
COMMIT, // commit and flush the command buffer
COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
COMMIT_AND_CONTINUE,// commit and continue with a new underlying command buffer
COMMIT_ADAPTIVE, // commit adaptively based on available memory
NONE, // no commit to command buffer
COMMIT, // commit and flush the command buffer
COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
COMMIT_AND_CONTINUE, // commit and continue with a new underlying command buffer
COMMIT_ADAPTIVE, // commit adaptively based on available memory
};
class TORCH_API MPSStream
{
public:
class TORCH_API MPSStream {
public:
enum Unchecked { UNCHECKED };
/// Construct a MPSStream from a Stream. This construction is checked,
@ -57,41 +55,64 @@ public:
explicit MPSStream(Stream stream);
~MPSStream();
MTLCommandQueue_t commandQueue() const { return _commandQueue; };
dispatch_queue_t queue() const { return _serialQueue; }
MTLCommandQueue_t commandQueue() const {
return _commandQueue;
};
dispatch_queue_t queue() const {
return _serialQueue;
}
MPSCommandBuffer* commandBuffer();
MTLComputeCommandEncoder_t commandEncoder();
void endKernelCoalescing();
void synchronize(SyncType syncType);
void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
void copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
size_t length, size_t srcOffset, size_t dstOffset,
uint64_t profileId, SyncType syncType = SyncType::NONE);
void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
size_t length, size_t srcOffset, size_t dstOffset,
bool non_blocking, uint64_t profileId);
void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE);
void copy(id<MTLBuffer> srcBuffer,
id<MTLBuffer> dstBuffer,
size_t length,
size_t srcOffset,
size_t dstOffset,
uint64_t profileId,
SyncType syncType = SyncType::NONE);
void copy_and_sync(id<MTLBuffer> srcBuffer,
id<MTLBuffer> dstBuffer,
size_t length,
size_t srcOffset,
size_t dstOffset,
bool non_blocking,
uint64_t profileId);
void executeMPSGraph(MPSGraph* mpsGraph,
NSDictionary* feeds,
NSDictionary* results,
SyncType syncType = SyncType::NONE);
void addCompletedHandler(MTLCommandBufferHandler block);
/// Get the MPS device index that this stream is associated with.
c10::DeviceIndex device_index() const { return _stream.device_index(); }
c10::DeviceIndex device_index() const {
return _stream.device_index();
}
MTLCommandQueue_t stream() const { return _commandQueue; };
MTLCommandQueue_t stream() const {
return _commandQueue;
};
MTLDevice_t device() const { return [_commandQueue device];}
MTLDevice_t device() const {
return [_commandQueue device];
}
/// Explicit conversion to Stream.
Stream unwrap() const { return _stream; }
Stream unwrap() const {
return _stream;
}
private:
private:
Stream _stream;
MTLCommandQueue_t _commandQueue = nil;
MPSCommandBuffer* _commandBuffer = nil;
MPSCommandBuffer* _prevCommandBuffer = nil;
MTLComputeCommandEncoder_t _commandEncoder = nil;
MPSGraphExecutionDescriptor *_executionDescriptor = nil;
MPSGraphCompilationDescriptor *_compilationDescriptor = nil;
MPSGraphExecutionDescriptor* _executionDescriptor = nil;
MPSGraphCompilationDescriptor* _compilationDescriptor = nil;
dispatch_queue_t _serialQueue = nullptr;
// CommitAndContinue is enabled by default
bool _enableCommitAndContinue = true;
@ -117,8 +138,7 @@ TORCH_API MPSStream* getDefaultMPSStream();
// MPSStreamImpl
//-----------------------------------------------------------------
class TORCH_API MPSStreamImpl
{
class TORCH_API MPSStreamImpl {
public:
/**
* Gets single instance of the MPSStream.

View File

@ -2,44 +2,40 @@
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
#if !defined(__MAC_15_0) && \
(!defined(MAC_OS_X_VERSION_15_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_15_0))
#if !defined(__MAC_15_0) && (!defined(MAC_OS_X_VERSION_15_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_15_0))
@interface MPSNDArrayIdentity : MPSNDArrayUnaryKernel
-(MPSNDArray * __nullable) reshapeWithCommandBuffer: (__nullable id <MTLCommandBuffer>) cmdBuf
sourceArray: (MPSNDArray * __nonnull) sourceArray
shape: (MPSShape * __nonnull) shape
destinationArray: (MPSNDArray * __nullable) destinationArray;
- (MPSNDArray* __nullable)reshapeWithCommandBuffer:(__nullable id<MTLCommandBuffer>)cmdBuf
sourceArray:(MPSNDArray* __nonnull)sourceArray
shape:(MPSShape* __nonnull)shape
destinationArray:(MPSNDArray* __nullable)destinationArray;
@end
@interface MPSNDArrayDescriptor()
@property (readwrite, nonatomic) BOOL preferPackedRows;
@interface MPSNDArrayDescriptor ()
@property(readwrite, nonatomic) BOOL preferPackedRows;
@end
@interface MPSNDArray()
-(nonnull instancetype) initWithBuffer:(id<MTLBuffer> _Nonnull) buffer
offset:(NSUInteger) offset
descriptor:(MPSNDArrayDescriptor * _Nonnull) descriptor;
-(MPSNDArray * __nullable) arrayViewWithShape:(MPSShape * _Nullable) shape
strides:(MPSShape * _Nonnull) strides;
@interface MPSNDArray ()
- (nonnull instancetype)initWithBuffer:(id<MTLBuffer> _Nonnull)buffer
offset:(NSUInteger)offset
descriptor:(MPSNDArrayDescriptor* _Nonnull)descriptor;
- (MPSNDArray* __nullable)arrayViewWithShape:(MPSShape* _Nullable)shape strides:(MPSShape* _Nonnull)strides;
@end
typedef NS_ENUM(NSInteger, MTLMathMode)
{
MTLMathModeSafe = 0,
MTLMathModeRelaxed = 1,
MTLMathModeFast = 2,
typedef NS_ENUM(NSInteger, MTLMathMode) {
MTLMathModeSafe = 0,
MTLMathModeRelaxed = 1,
MTLMathModeFast = 2,
};
typedef NS_ENUM(NSInteger, MTLMathFloatingPointFunctions)
{
MTLMathFloatingPointFunctionsFast = 0,
MTLMathFloatingPointFunctionsPrecise = 1,
typedef NS_ENUM(NSInteger, MTLMathFloatingPointFunctions) {
MTLMathFloatingPointFunctionsFast = 0,
MTLMathFloatingPointFunctionsPrecise = 1,
};
@interface MTLCompileOptions()
@property (readwrite, nonatomic) MTLMathMode mathMode;
@property (readwrite, nonatomic) MTLMathFloatingPointFunctions mathFloatingPointFunctions;
@interface MTLCompileOptions ()
@property(readwrite, nonatomic) MTLMathMode mathMode;
@property(readwrite, nonatomic) MTLMathFloatingPointFunctions mathFloatingPointFunctions;
@end
#endif

View File

@ -2,52 +2,47 @@
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
#if !defined(__MAC_14_0) && \
(!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0))
#if !defined(__MAC_14_0) && (!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0))
typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode)
{
MPSGraphFFTScalingModeNone = 0L,
MPSGraphFFTScalingModeSize = 1L,
MPSGraphFFTScalingModeUnitary = 2L,
typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode) {
MPSGraphFFTScalingModeNone = 0L,
MPSGraphFFTScalingModeSize = 1L,
MPSGraphFFTScalingModeUnitary = 2L,
};
@interface FakeMPSGraphFFTDescriptor : NSObject<NSCopying>
@property (readwrite, nonatomic) BOOL inverse;
@property (readwrite, nonatomic) MPSGraphFFTScalingMode scalingMode;
@property (readwrite, nonatomic) BOOL roundToOddHermitean;
+(nullable instancetype) descriptor;
@property(readwrite, nonatomic) BOOL inverse;
@property(readwrite, nonatomic) MPSGraphFFTScalingMode scalingMode;
@property(readwrite, nonatomic) BOOL roundToOddHermitean;
+ (nullable instancetype)descriptor;
@end
@compatibility_alias MPSGraphFFTDescriptor FakeMPSGraphFFTDescriptor;
@interface MPSGraph (SonomaOps)
-(MPSGraphTensor * _Nonnull) conjugateWithTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)conjugateWithTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name;
-(MPSGraphTensor * _Nonnull) realPartOfTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)realPartOfTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name;
- (MPSGraphTensor* _Nonnull)fastFourierTransformWithTensor:(MPSGraphTensor* _Nonnull)tensor
axes:(NSArray<NSNumber*>* _Nonnull)axes
descriptor:(MPSGraphFFTDescriptor* _Nonnull)descriptor
name:(NSString* _Nullable)name;
-(MPSGraphTensor * _Nonnull) fastFourierTransformWithTensor:(MPSGraphTensor * _Nonnull) tensor
axes:(NSArray<NSNumber *> * _Nonnull) axes
descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)realToHermiteanFFTWithTensor:(MPSGraphTensor* _Nonnull)tensor
axes:(NSArray<NSNumber*>* _Nonnull)axes
descriptor:(MPSGraphFFTDescriptor* _Nonnull)descriptor
name:(NSString* _Nullable)name;
-(MPSGraphTensor * _Nonnull) realToHermiteanFFTWithTensor:(MPSGraphTensor * _Nonnull) tensor
axes:(NSArray<NSNumber *> * _Nonnull) axes
descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor
name:(NSString * _Nullable) name;
-(MPSGraphTensor * _Nonnull) HermiteanToRealFFTWithTensor:(MPSGraphTensor * _Nonnull) tensor
axes:(NSArray<NSNumber *> * _Nonnull) axes
descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)HermiteanToRealFFTWithTensor:(MPSGraphTensor* _Nonnull)tensor
axes:(NSArray<NSNumber*>* _Nonnull)axes
descriptor:(MPSGraphFFTDescriptor* _Nonnull)descriptor
name:(NSString* _Nullable)name;
@end
// define BFloat16 enums for MacOS13
#define MPSDataTypeBFloat16 ((MPSDataType) (MPSDataTypeAlternateEncodingBit | MPSDataTypeFloat16))
#define MPSDataTypeBFloat16 ((MPSDataType)(MPSDataTypeAlternateEncodingBit | MPSDataTypeFloat16))
// define Metal version
#define MTLLanguageVersion3_1 ((MTLLanguageVersion) ((3 << 16) + 1))
#define MTLLanguageVersion3_1 ((MTLLanguageVersion)((3 << 16) + 1))
#endif

View File

@ -2,30 +2,29 @@
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
// TODO: Remove me when moved to MacOS 13
#if !defined(__MAC_13_2) && \
(!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2))
#if !defined(__MAC_13_2) && (!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2))
@interface FakeMPSGraphConvolution3DOpDescriptor : NSObject<NSCopying>
@property (readwrite, nonatomic) NSUInteger strideInX;
@property (readwrite, nonatomic) NSUInteger strideInY;
@property (readwrite, nonatomic) NSUInteger strideInZ;
@property (readwrite, nonatomic) NSUInteger dilationRateInX;
@property (readwrite, nonatomic) NSUInteger dilationRateInY;
@property (readwrite, nonatomic) NSUInteger dilationRateInZ;
@property(readwrite, nonatomic) NSUInteger strideInX;
@property(readwrite, nonatomic) NSUInteger strideInY;
@property(readwrite, nonatomic) NSUInteger strideInZ;
@property(readwrite, nonatomic) NSUInteger dilationRateInX;
@property(readwrite, nonatomic) NSUInteger dilationRateInY;
@property(readwrite, nonatomic) NSUInteger dilationRateInZ;
@property (readwrite, nonatomic) NSUInteger paddingLeft;
@property (readwrite, nonatomic) NSUInteger paddingRight;
@property (readwrite, nonatomic) NSUInteger paddingTop;
@property (readwrite, nonatomic) NSUInteger paddingBottom;
@property (readwrite, nonatomic) NSUInteger paddingFront;
@property (readwrite, nonatomic) NSUInteger paddingBack;
@property(readwrite, nonatomic) NSUInteger paddingLeft;
@property(readwrite, nonatomic) NSUInteger paddingRight;
@property(readwrite, nonatomic) NSUInteger paddingTop;
@property(readwrite, nonatomic) NSUInteger paddingBottom;
@property(readwrite, nonatomic) NSUInteger paddingFront;
@property(readwrite, nonatomic) NSUInteger paddingBack;
@property (readwrite, nonatomic) MPSGraphPaddingStyle paddingStyle;
@property (readwrite, nonatomic) MPSGraphTensorNamedDataLayout dataLayout;
@property (readwrite, nonatomic) MPSGraphTensorNamedDataLayout weightsLayout;
@property(readwrite, nonatomic) MPSGraphPaddingStyle paddingStyle;
@property(readwrite, nonatomic) MPSGraphTensorNamedDataLayout dataLayout;
@property(readwrite, nonatomic) MPSGraphTensorNamedDataLayout weightsLayout;
@property (readwrite, nonatomic) NSUInteger groups;
@property(readwrite, nonatomic) NSUInteger groups;
@end
@ -35,163 +34,163 @@
@interface MPSGraph (VenturaOps)
#if !defined(__MAC_13_0) && \
(!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0))
#if !defined(__MAC_13_0) && (!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0))
typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
{
MPSGraphResizeNearestRoundingModeRoundPreferCeil = 0L,
MPSGraphResizeNearestRoundingModeRoundPreferFloor = 1L,
MPSGraphResizeNearestRoundingModeCeil = 2L,
MPSGraphResizeNearestRoundingModeFloor = 3L,
MPSGraphResizeNearestRoundingModeRoundToEven = 4L,
MPSGraphResizeNearestRoundingModeRoundToOdd = 5L,
typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode) {
MPSGraphResizeNearestRoundingModeRoundPreferCeil = 0L,
MPSGraphResizeNearestRoundingModeRoundPreferFloor = 1L,
MPSGraphResizeNearestRoundingModeCeil = 2L,
MPSGraphResizeNearestRoundingModeFloor = 3L,
MPSGraphResizeNearestRoundingModeRoundToEven = 4L,
MPSGraphResizeNearestRoundingModeRoundToOdd = 5L,
};
// Define complex enums for MacOS 12
#define MPSDataTypeComplexBit 0x01000000
#define MPSDataTypeComplexFloat32 ((MPSDataType) (MPSDataTypeFloatBit | MPSDataTypeComplexBit | 64))
#define MPSDataTypeComplexFloat16 ((MPSDataType) (MPSDataTypeFloatBit | MPSDataTypeComplexBit | 32))
#define MPSDataTypeComplexFloat32 ((MPSDataType)(MPSDataTypeFloatBit | MPSDataTypeComplexBit | 64))
#define MPSDataTypeComplexFloat16 ((MPSDataType)(MPSDataTypeFloatBit | MPSDataTypeComplexBit | 32))
#endif
- (MPSGraphTensor * _Nonnull) convolution3DWithSourceTensor:(MPSGraphTensor * _Nonnull) source
weightsTensor:(MPSGraphTensor * _Nonnull) weights
descriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) descriptor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)convolution3DWithSourceTensor:(MPSGraphTensor* _Nonnull)source
weightsTensor:(MPSGraphTensor* _Nonnull)weights
descriptor:(MPSGraphConvolution3DOpDescriptor* _Nonnull)descriptor
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) convolution3DDataGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient
weightsTensor:(MPSGraphTensor * _Nonnull) weights
outputShape:(MPSShape * _Nonnull) outputShape
forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)
convolution3DDataGradientWithIncomingGradientTensor:(MPSGraphTensor* _Nonnull)incomingGradient
weightsTensor:(MPSGraphTensor* _Nonnull)weights
outputShape:(MPSShape* _Nonnull)outputShape
forwardConvolutionDescriptor:
(MPSGraphConvolution3DOpDescriptor* _Nonnull)forwardConvolutionDescriptor
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) convolution3DWeightsGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient
sourceTensor:(MPSGraphTensor * _Nonnull) source
outputShape:(MPSShape * _Nonnull) outputShape
forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)
convolution3DWeightsGradientWithIncomingGradientTensor:(MPSGraphTensor* _Nonnull)incomingGradient
sourceTensor:(MPSGraphTensor* _Nonnull)source
outputShape:(MPSShape* _Nonnull)outputShape
forwardConvolutionDescriptor:
(MPSGraphConvolution3DOpDescriptor* _Nonnull)forwardConvolutionDescriptor
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull)cumulativeSumWithTensor:(MPSGraphTensor * _Nonnull)tensor
axis:(NSInteger)axis
name:(NSString * _Nullable)name;
- (MPSGraphTensor* _Nonnull)cumulativeSumWithTensor:(MPSGraphTensor* _Nonnull)tensor
axis:(NSInteger)axis
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull)sortWithTensor:(MPSGraphTensor * _Nonnull)tensor
axis:(NSInteger)axis
name:(NSString * _Nullable)name;
- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axis:(NSInteger)axis
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor
axis:(NSInteger) axis
descending:(BOOL) descending
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axis:(NSInteger)axis
descending:(BOOL)descending
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor
axisTensor:(MPSGraphTensor * _Nonnull) axisTensor
descending:(BOOL) descending
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axisTensor:(MPSGraphTensor* _Nonnull)axisTensor
descending:(BOOL)descending
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor
axisTensor:(MPSGraphTensor * _Nonnull) axisTensor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axisTensor:(MPSGraphTensor* _Nonnull)axisTensor
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull)argSortWithTensor:(MPSGraphTensor * _Nonnull)tensor
axis:(NSInteger)axis
name:(NSString * _Nullable)name;
- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axis:(NSInteger)axis
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor
axis:(NSInteger) axis
descending:(BOOL) descending
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axis:(NSInteger)axis
descending:(BOOL)descending
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor
axisTensor:(MPSGraphTensor * _Nonnull) axisTensor
descending:(BOOL) descending
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axisTensor:(MPSGraphTensor* _Nonnull)axisTensor
descending:(BOOL)descending
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor
axisTensor:(MPSGraphTensor * _Nonnull) axisTensor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor
axisTensor:(MPSGraphTensor* _Nonnull)axisTensor
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull)inverseOfTensor:(MPSGraphTensor * _Nonnull) inputTensor
name:(NSString * _Nullable)name;
- (MPSGraphTensor* _Nonnull)inverseOfTensor:(MPSGraphTensor* _Nonnull)inputTensor name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
sizeTensor:(MPSGraphTensor * _Nonnull) size
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
centerResult:(BOOL) centerResult
alignCorners:(BOOL) alignCorners
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeNearestWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor
sizeTensor:(MPSGraphTensor* _Nonnull)size
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
centerResult:(BOOL)centerResult
alignCorners:(BOOL)alignCorners
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
sizeTensor:(MPSGraphTensor * _Nonnull) size
scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeNearestWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor
sizeTensor:(MPSGraphTensor* _Nonnull)size
scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
sizeTensor:(MPSGraphTensor * _Nonnull) size
centerResult:(BOOL) centerResult
alignCorners:(BOOL) alignCorners
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeBilinearWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor
sizeTensor:(MPSGraphTensor* _Nonnull)size
centerResult:(BOOL)centerResult
alignCorners:(BOOL)alignCorners
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
sizeTensor:(MPSGraphTensor * _Nonnull) size
scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeBilinearWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor
sizeTensor:(MPSGraphTensor* _Nonnull)size
scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
input:(MPSGraphTensor * _Nonnull) input
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
centerResult:(BOOL) centerResult
alignCorners:(BOOL) alignCorners
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeNearestWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient
input:(MPSGraphTensor* _Nonnull)input
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
centerResult:(BOOL)centerResult
alignCorners:(BOOL)alignCorners
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
input:(MPSGraphTensor * _Nonnull) input
scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeNearestWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient
input:(MPSGraphTensor* _Nonnull)input
scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
input:(MPSGraphTensor * _Nonnull) input
centerResult:(BOOL) centerResult
alignCorners:(BOOL) alignCorners
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeBilinearWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient
input:(MPSGraphTensor* _Nonnull)input
centerResult:(BOOL)centerResult
alignCorners:(BOOL)alignCorners
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
input:(MPSGraphTensor * _Nonnull) input
scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
layout:(MPSGraphTensorNamedDataLayout) layout
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)resizeBilinearWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient
input:(MPSGraphTensor* _Nonnull)input
scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset
layout:(MPSGraphTensorNamedDataLayout)layout
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) sampleGridWithSourceTensor:(MPSGraphTensor * _Nonnull) source
coordinateTensor:(MPSGraphTensor * _Nonnull) coordinates
layout:(MPSGraphTensorNamedDataLayout) layout
normalizeCoordinates:(BOOL) normalizeCoordinates
relativeCoordinates:(BOOL) relativeCoordinates
alignCorners:(BOOL) alignCorners
paddingMode:(MPSGraphPaddingMode) paddingMode
samplingMode:(MPSGraphResizeMode) samplingMode
constantValue:(double) constantValue
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)sampleGridWithSourceTensor:(MPSGraphTensor* _Nonnull)source
coordinateTensor:(MPSGraphTensor* _Nonnull)coordinates
layout:(MPSGraphTensorNamedDataLayout)layout
normalizeCoordinates:(BOOL)normalizeCoordinates
relativeCoordinates:(BOOL)relativeCoordinates
alignCorners:(BOOL)alignCorners
paddingMode:(MPSGraphPaddingMode)paddingMode
samplingMode:(MPSGraphResizeMode)samplingMode
constantValue:(double)constantValue
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) sampleGridWithSourceTensor:(MPSGraphTensor * _Nonnull) source
coordinateTensor:(MPSGraphTensor * _Nonnull) coordinates
layout:(MPSGraphTensorNamedDataLayout) layout
normalizeCoordinates:(BOOL) normalizeCoordinates
relativeCoordinates:(BOOL) relativeCoordinates
alignCorners:(BOOL) alignCorners
paddingMode:(MPSGraphPaddingMode) paddingMode
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
constantValue:(double) constantValue
name:(NSString * _Nullable) name;
- (MPSGraphTensor * _Nonnull) truncateWithTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)sampleGridWithSourceTensor:(MPSGraphTensor* _Nonnull)source
coordinateTensor:(MPSGraphTensor* _Nonnull)coordinates
layout:(MPSGraphTensorNamedDataLayout)layout
normalizeCoordinates:(BOOL)normalizeCoordinates
relativeCoordinates:(BOOL)relativeCoordinates
alignCorners:(BOOL)alignCorners
paddingMode:(MPSGraphPaddingMode)paddingMode
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
constantValue:(double)constantValue
name:(NSString* _Nullable)name;
- (MPSGraphTensor* _Nonnull)truncateWithTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name;
@end

View File

@ -26,7 +26,7 @@
// Fwd declarations
namespace at {
struct TensorIteratorBase;
struct TensorIteratorBase;
}
using namespace at::mps;
@ -35,7 +35,9 @@ namespace at::native::mps {
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
struct MPSScalar {
id<MTLBuffer> getMTLBuffer() const { return __builtin_bit_cast(id<MTLBuffer>, buffer.get()); }
id<MTLBuffer> getMTLBuffer() const {
return __builtin_bit_cast(id<MTLBuffer>, buffer.get());
}
size_t size = 0;
ScalarType type = ScalarType::Undefined;
@ -48,13 +50,10 @@ struct MPSScalar {
c10::complex<float> cf;
c10::complex<at::Half> ch;
at::BFloat16 bf16;
} value {};
} value{};
};
void runMPSGraph(MPSStream* mpsStream,
MPSGraph* mpsGraph,
NSDictionary* feeds,
NSDictionary* results);
void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results);
MPSDataType getMPSDataType(ScalarType scalar_type);
static inline MPSDataType getMPSDataType(const TensorBase& t) {
@ -64,7 +63,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type);
static inline MPSDataType getMPSScalarType(const TensorBase& t) {
return getMPSScalarType(t.scalar_type());
}
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false);
static inline std::string getMPSTypeString(const TensorBase& t, bool short_name = false) {
return getMPSTypeString(t.scalar_type(), short_name);
@ -81,10 +80,18 @@ std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
Tensor& scatterViewTensor(const Tensor& src, Tensor& output);
bool canSliceViewTensor(const TensorBase& src, MPSShape *mpsShape);
MPSGraphTensorData* getMPSGraphTensorDataForView(const TensorBase& src, MPSShape *mpsShape, const MPSDataType mpsDataType);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input, bool includesInt64 = false);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input, bool includesInt64 = false);
bool canSliceViewTensor(const TensorBase& src, MPSShape* mpsShape);
MPSGraphTensorData* getMPSGraphTensorDataForView(const TensorBase& src,
MPSShape* mpsShape,
const MPSDataType mpsDataType);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const TensorBase& input,
bool includesInt64 = false);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const TensorBase& input,
bool includesInt64 = false);
MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes = nil, MPSShape* strides = nil);
@ -102,8 +109,12 @@ class Placeholder {
Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}
Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {}
Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray);
Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr,
bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid, bool useMPSStridedAPI = true);
Placeholder(MPSGraphTensor* mpsGraphTensor,
const Tensor& self,
MPSShape* mpsShape = nullptr,
bool gatherTensorData = true,
MPSDataType dataType = MPSDataTypeInvalid,
bool useMPSStridedAPI = true);
MPSGraphTensor* getMPSGraphTensor() {
return _placeholder;
}
@ -123,21 +134,21 @@ class Placeholder {
void resize_tensor(Tensor* output);
Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device);
MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor);
MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType);
MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType);
MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const TensorBase& tensor);
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor);
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, ScalarType toType);
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType);
MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const TensorBase& tensor);
MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar);
MPSGraph* make_mps_graph();
void printTensorNDArray(const TensorBase& t);
MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape *shape, MPSDataType mpsType);
MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape* shape, MPSDataType mpsType);
MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const TensorBase& tensor);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar);
MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const TensorBase& tensor);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar);
string get_mem_format_string(c10::MemoryFormat memory_format);
@ -145,75 +156,73 @@ using MPSCacheKey = uint64_t;
// derive this class to cache a graph and its inputs/outputs
// can be used to store any NSObject
struct MPSCachedGraph
{
MPSCachedGraph(NSObject *object) : _object([object retain]) {}
struct MPSCachedGraph {
MPSCachedGraph(NSObject* object) : _object([object retain]) {}
virtual ~MPSCachedGraph() {
[_object release];
_object = nullptr;
[_object release];
_object = nullptr;
}
template<typename T>
template <typename T>
inline T* as() {
return static_cast<T*>(this);
}
MPSGraph *graph() const { return (MPSGraph *)_object; }
NSObject *object() const { return _object; }
private:
NSObject *_object = nullptr;
MPSGraph* graph() const {
return (MPSGraph*)_object;
}
NSObject* object() const {
return _object;
}
private:
NSObject* _object = nullptr;
};
struct MPSUnaryCachedGraph : public MPSCachedGraph
{
MPSUnaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct MPSUnaryCachedGraph : public MPSCachedGraph {
MPSUnaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
struct MPSUnaryGradCachedGraph : public MPSCachedGraph
{
MPSUnaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *gradOutputTensor_ = nil;
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil; // some backward input is actually the forward's output
MPSGraphTensor *gradInputTensor_ = nil;
struct MPSUnaryGradCachedGraph : public MPSCachedGraph {
MPSUnaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil; // some backward input is actually the forward's output
MPSGraphTensor* gradInputTensor_ = nil;
};
struct MPSBinaryCachedGraph : public MPSCachedGraph
{
MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *otherTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct MPSBinaryCachedGraph : public MPSCachedGraph {
MPSBinaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* otherTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
struct MPSBinaryGradCachedGraph : public MPSCachedGraph
{
MPSBinaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *gradOutputTensor_ = nil;
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *otherTensor_ = nil;
MPSGraphTensor *gradInputTensor_ = nil;
struct MPSBinaryGradCachedGraph : public MPSCachedGraph {
MPSBinaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* otherTensor_ = nil;
MPSGraphTensor* gradInputTensor_ = nil;
};
// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
// Cache holding various keys mapped to graphs
struct MPSGraphCache
{
typedef MPSCachedGraph * (^CreateCachedGraphBlock)();
struct MPSGraphCache {
typedef MPSCachedGraph* (^CreateCachedGraphBlock)();
struct CacheEntry {
CacheEntry(const std::string& key, MPSCachedGraph *cachedGraph) : cachedGraph_(cachedGraph), key_(key) {}
CacheEntry(const std::string& key, MPSCachedGraph* cachedGraph) : cachedGraph_(cachedGraph), key_(key) {}
MPSCachedGraph* cachedGraph_ = nullptr;
std::string key_;
};
public:
static MPSGraphCache* getInstance() {
if(_instance_cache == nullptr) {
if (_instance_cache == nullptr) {
_instance_cache = new MPSGraphCache();
}
return _instance_cache;
@ -232,7 +241,6 @@ struct MPSGraphCache
void operator=(const MPSGraphCache&) = delete;
MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
__block MPSCachedGraph* cachedGraph = nil;
MPSCacheKey hash = std::hash<std::string>{}(key);
@ -253,19 +261,17 @@ struct MPSGraphCache
return cachedGraph;
}
template<typename T>
template <typename T>
inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
return static_cast<T *>(CreateCachedGraph(key, createCacheBlock));
return static_cast<T*>(CreateCachedGraph(key, createCacheBlock));
}
MPSCachedGraph* LookUp(const std::string& key) const {
__block MPSCachedGraph* cachedGraph = nullptr;
MPSCacheKey hash = std::hash<std::string>{}(key);
dispatch_sync(serialQueue_, ^() {
if (cache_.count(hash) != 0) {
auto& entry = cache_.at(hash);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
@ -276,9 +282,9 @@ struct MPSGraphCache
return cachedGraph;
}
template<typename T>
template <typename T>
inline T* LookUpAs(const std::string& key) const {
return static_cast<T *>(LookUp(key));
return static_cast<T*>(LookUp(key));
}
private:
@ -292,14 +298,13 @@ struct MPSGraphCache
static MPSGraphCache* _instance_cache;
std::unordered_map<MPSCacheKey, CacheEntry> cache_;
dispatch_queue_t serialQueue_ = nullptr;
};
// Common template for creating graph with a specified cache if missing
template<typename T>
template <typename T>
inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(MPSGraph*, T*)> instantiate) {
auto cache_ = MPSGraphCache::getInstance();
if (auto rc = cache_->LookUpAs<T>(key)) {
if (auto rc = cache_->LookUpAs<T>(key)) {
return rc;
}
return cache_->CreateCachedGraphAs<T>(key, ^mps::MPSCachedGraph*() {
@ -317,10 +322,12 @@ inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(M
// Common math operations
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \
if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \
TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name, \
", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \
if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \
TORCH_WARN_ONCE( \
"MPS: no support for int64 for ", \
op_name, \
", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
}
/**
@ -335,17 +342,19 @@ inline bool is_dense_in_storage(const TensorBase& t) {
return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
}
class MetalShaderLibrary {
public:
MetalShaderLibrary(const std::string& src): shaderSource(src), nparams(0), compile_options(nullptr){}
MetalShaderLibrary(const std::string& src, unsigned nparams_): shaderSource(src), nparams(nparams_), compile_options(nullptr){}
MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_): shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
public:
MetalShaderLibrary(const std::string& src) : shaderSource(src), nparams(0), compile_options(nullptr) {}
MetalShaderLibrary(const std::string& src, unsigned nparams_)
: shaderSource(src), nparams(nparams_), compile_options(nullptr) {}
MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_)
: shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
inline id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname).first;
}
id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname, const std::initializer_list<std::string>& params) {
id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname,
const std::initializer_list<std::string>& params) {
return getLibraryPipelineState(getLibrary(params), fname).first;
}
inline id<MTLFunction> getMTLFunction(const std::string& fname) {
@ -355,12 +364,15 @@ public:
return getLibraryPipelineState(getLibrary(params), fname).second;
}
static MetalShaderLibrary& getBundledLibrary();
protected:
protected:
virtual id<MTLLibrary> getLibrary();
virtual id<MTLLibrary> getLibrary(const std::initializer_list<std::string>& params);
id<MTLLibrary> library = nil;
private:
std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
private:
std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getLibraryPipelineState(id<MTLLibrary> lib,
const std::string& fname);
id<MTLLibrary> compileLibrary(const std::string& src);
std::string shaderSource;
@ -370,24 +382,21 @@ private:
std::unordered_map<std::string, std::pair<id<MTLComputePipelineState>, id<MTLFunction>>> cplMap;
};
template<typename encoder_t,
typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> || std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
template <typename encoder_t,
typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> ||
std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigned idx) {
[encoder setBuffer:getMTLBufferStorage(t)
offset:t.storage_offset() * t.element_size()
atIndex:idx];
[encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx];
}
template<typename T,
typename = std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T, float>>>
template <typename T, typename = std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T, float>>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const T val, unsigned idx) {
[encoder setBytes:&val length:sizeof(T) atIndex: idx];
[encoder setBytes:&val length:sizeof(T) atIndex:idx];
}
template<typename Container,
typename = std::enable_if_t<std::is_integral_v<typename Container::size_type>>>
template <typename Container, typename = std::enable_if_t<std::is_integral_v<typename Container::size_type>>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const Container& values, unsigned idx) {
[encoder setBytes:values.data() length:sizeof(typename Container::value_type) * values.size() atIndex: idx];
[encoder setBytes:values.data() length:sizeof(typename Container::value_type) * values.size() atIndex:idx];
}
static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
@ -400,38 +409,40 @@ static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
[encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
}
id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder, const TensorIteratorBase& iter, bool use_64bit_index = false);
id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder,
const TensorIteratorBase& iter,
bool use_64bit_index = false);
inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) {
return @{ p1.getMPSGraphTensor(): p1.getMPSGraphTensorData() };
return @{p1.getMPSGraphTensor() : p1.getMPSGraphTensorData()};
}
inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2) {
return @{
p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
};
return @{
p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
};
}
inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3) {
return @{
p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(),
};
return @{
p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
p3.getMPSGraphTensor() : p3.getMPSGraphTensorData(),
};
}
inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3, Placeholder& p4) {
return @{
p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(),
p4.getMPSGraphTensor(): p4.getMPSGraphTensorData(),
};
return @{
p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
p3.getMPSGraphTensor() : p3.getMPSGraphTensorData(),
p4.getMPSGraphTensor() : p4.getMPSGraphTensorData(),
};
}
inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, Placeholder& result) {
runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
}
inline bool supportsComplex() {
@ -464,7 +475,7 @@ inline void checkSupportsBFloat16() {
inline bool needsGather(const TensorBase& t) {
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset()) ;
return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset());
}
} // namespace at::native::mps

View File

@ -1,13 +1,14 @@
// Copyright © 2022 Apple Inc.
#define AT_DISPATCH_MPS_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__))
#define AT_DISPATCH_MPS_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) AT_DISPATCH_CASE( \
at::ScalarType::Half, \
__VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__))

View File

@ -1,5 +1,8 @@
#pragma once
namespace at::native::mps {
void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& output);
void complex_mul_out(
const Tensor& input,
const Tensor& other,
const Tensor& output);
}

View File

@ -17,8 +17,7 @@ void _fused_adam_amsgrad_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
void _fused_adam_amsgrad_mps_impl_(
TensorList params,
@ -34,7 +33,6 @@ void _fused_adam_amsgrad_mps_impl_(
const double eps,
const bool maximize,
const std::optional<at::Tensor>& grad_scale,
const std::optional<at::Tensor>& found_inf
);
const std::optional<at::Tensor>& found_inf);
} // namespace at::native::mps

View File

@ -16,8 +16,7 @@ void _fused_adam_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
void _fused_adam_mps_impl_(
TensorList params,
@ -32,6 +31,5 @@ void _fused_adam_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
} // namespace at::native::mps

View File

@ -17,8 +17,7 @@ void _fused_adamw_amsgrad_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
void _fused_adamw_amsgrad_mps_impl_(
TensorList params,
@ -34,6 +33,5 @@ void _fused_adamw_amsgrad_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
} // namespace at::native::mps

View File

@ -16,8 +16,7 @@ void _fused_adamw_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
void _fused_adamw_mps_impl_(
TensorList params,
@ -32,7 +31,6 @@ void _fused_adamw_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
} // namespace at::native::mps

View File

@ -436,10 +436,12 @@ REGISTER_FUSED_SGD_MOMENTUM_OP(half);
)METAL";
static std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getCPLState(const std::string& fname) {
static std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getCPLState(
const std::string& fname) {
static MetalShaderLibrary lib(FUSED_ADAM_OPS, 0);
return std::make_pair(lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname));
return std::make_pair(
lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname));
}
} //namespace mps
} // namespace mps
} // namespace at::native

View File

@ -16,18 +16,15 @@ struct MetadataArguments { // the size of this struct must be less than 4 bytes
};
struct FusedAdamEncodingFunctor {
void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize
) const {
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize) const {
float lr_lv = lr;
float beta1_lv = beta1;
float beta2_lv = beta2;
@ -35,12 +32,8 @@ struct FusedAdamEncodingFunctor {
float eps_lv = eps;
uint8_t maximize_lv = maximize;
[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
mtl_setBytes(computeEncoder, lr_lv, 2);
mtl_setBytes(computeEncoder, beta1_lv, 3);
mtl_setBytes(computeEncoder, beta2_lv, 4);
@ -49,29 +42,23 @@ struct FusedAdamEncodingFunctor {
mtl_setBytes(computeEncoder, maximize_lv, 7);
}
void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const at::Tensor& lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize
) const {
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const at::Tensor& lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize) const {
float beta1_lv = beta1;
float beta2_lv = beta2;
float weight_decay_lv = weight_decay;
float eps_lv = eps;
uint8_t maximize_lv = maximize;
[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
mtl_setBuffer(computeEncoder, lr, 2);
mtl_setBytes(computeEncoder, beta1_lv, 3);
mtl_setBytes(computeEncoder, beta2_lv, 4);
@ -86,144 +73,117 @@ struct FusedSgdEncodingFunctor {};
template <>
struct FusedSgdEncodingFunctor<true> {
void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double weight_decay,
const double momentum,
const double lr,
const double dampening,
const bool nesterov,
const bool maximize,
const bool is_first_step
) const {
float weight_decay_lv = weight_decay;
float momentum_lv = momentum;
float lr_lv = lr;
float dampening_lv = dampening;
uint8_t nesterov_lv = nesterov;
uint8_t maximize_lv = maximize;
uint8_t is_first_step_lv = is_first_step;
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double weight_decay,
const double momentum,
const double lr,
const double dampening,
const bool nesterov,
const bool maximize,
const bool is_first_step) const {
float weight_decay_lv = weight_decay;
float momentum_lv = momentum;
float lr_lv = lr;
float dampening_lv = dampening;
uint8_t nesterov_lv = nesterov;
uint8_t maximize_lv = maximize;
uint8_t is_first_step_lv = is_first_step;
[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
mtl_setBytes(computeEncoder, momentum_lv, 3);
mtl_setBytes(computeEncoder, lr_lv, 4);
mtl_setBytes(computeEncoder, dampening_lv, 5);
mtl_setBytes(computeEncoder, nesterov_lv, 6);
mtl_setBytes(computeEncoder, maximize_lv, 7);
mtl_setBytes(computeEncoder, is_first_step_lv, 8);
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
mtl_setBytes(computeEncoder, momentum_lv, 3);
mtl_setBytes(computeEncoder, lr_lv, 4);
mtl_setBytes(computeEncoder, dampening_lv, 5);
mtl_setBytes(computeEncoder, nesterov_lv, 6);
mtl_setBytes(computeEncoder, maximize_lv, 7);
mtl_setBytes(computeEncoder, is_first_step_lv, 8);
}
void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double weight_decay,
const double momentum,
const at::Tensor& lr,
const double dampening,
const bool nesterov,
const bool maximize,
const bool is_first_step
) const {
float weight_decay_lv = weight_decay;
float momentum_lv = momentum;
float dampening_lv = dampening;
uint8_t nesterov_lv = nesterov;
uint8_t maximize_lv = maximize;
uint8_t is_first_step_lv = is_first_step;
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double weight_decay,
const double momentum,
const at::Tensor& lr,
const double dampening,
const bool nesterov,
const bool maximize,
const bool is_first_step) const {
float weight_decay_lv = weight_decay;
float momentum_lv = momentum;
float dampening_lv = dampening;
uint8_t nesterov_lv = nesterov;
uint8_t maximize_lv = maximize;
uint8_t is_first_step_lv = is_first_step;
[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
mtl_setBytes(computeEncoder, momentum_lv, 3);
mtl_setBuffer(computeEncoder, lr, 4);
mtl_setBytes(computeEncoder, dampening_lv, 5);
mtl_setBytes(computeEncoder, nesterov_lv, 6);
mtl_setBytes(computeEncoder, maximize_lv, 7);
mtl_setBytes(computeEncoder, is_first_step_lv, 8);
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
mtl_setBytes(computeEncoder, momentum_lv, 3);
mtl_setBuffer(computeEncoder, lr, 4);
mtl_setBytes(computeEncoder, dampening_lv, 5);
mtl_setBytes(computeEncoder, nesterov_lv, 6);
mtl_setBytes(computeEncoder, maximize_lv, 7);
mtl_setBytes(computeEncoder, is_first_step_lv, 8);
}
};
template <>
struct FusedSgdEncodingFunctor<false> {
void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double weight_decay,
const double lr,
const bool maximize
) const {
float weight_decay_lv = weight_decay;
float lr_lv = lr;
uint8_t maximize_lv = maximize;
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double weight_decay,
const double lr,
const bool maximize) const {
float weight_decay_lv = weight_decay;
float lr_lv = lr;
uint8_t maximize_lv = maximize;
[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
mtl_setBytes(computeEncoder, lr_lv, 3);
mtl_setBytes(computeEncoder, maximize_lv, 4);
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
mtl_setBytes(computeEncoder, lr_lv, 3);
mtl_setBytes(computeEncoder, maximize_lv, 4);
}
void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double weight_decay,
const at::Tensor& lr,
const bool maximize
) const {
float weight_decay_lv = weight_decay;
uint8_t maximize_lv = maximize;
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double weight_decay,
const at::Tensor& lr,
const bool maximize) const {
float weight_decay_lv = weight_decay;
uint8_t maximize_lv = maximize;
[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
mtl_setBuffer(computeEncoder, lr, 3);
mtl_setBytes(computeEncoder, maximize_lv, 4);
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
mtl_setBuffer(computeEncoder, lr, 3);
mtl_setBytes(computeEncoder, maximize_lv, 4);
}
};
template <int depth, uint32_t kThreadGroupSize, typename encoder_func_t, typename... ArgTypes>
static void multi_tensor_apply_for_fused_optimizer(
const std::string& kernel_name,
std::vector<std::vector<at::Tensor>>& tensor_lists,
at::TensorList state_steps,
encoder_func_t encode,
ArgTypes... args
) {
static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_name,
std::vector<std::vector<at::Tensor>>& tensor_lists,
at::TensorList state_steps,
encoder_func_t encode,
ArgTypes... args) {
const auto num_tensors = tensor_lists[0].size();
if (num_tensors == 0) {
return;
}
TORCH_CHECK(
tensor_lists.size() == depth,
"Number of tensor lists has to match the depth");
TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth");
for (const auto& d : c10::irange(depth)) {
TORCH_CHECK(
tensor_lists[d][0].scalar_type() == at::ScalarType::Float || tensor_lists[d][0].scalar_type() == at::ScalarType::Half, "Only float and half are supported");
TORCH_CHECK(tensor_lists[d][0].scalar_type() == at::ScalarType::Float ||
tensor_lists[d][0].scalar_type() == at::ScalarType::Half,
"Only float and half are supported");
}
id<MTLDevice> device = MPSDevice::getInstance()->device();
@ -251,7 +211,8 @@ static void multi_tensor_apply_for_fused_optimizer(
// BufferIndex is the index in the kernel function
auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease];
id<MTLBuffer> tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
id<MTLBuffer> tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
options:0] autorelease];
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
int64_t tensor_loc = 0;
@ -265,10 +226,11 @@ static void multi_tensor_apply_for_fused_optimizer(
}
for (const auto& d : c10::irange(depth)) {
mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors + tensor_loc);
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageRead | MTLResourceUsageWrite];
mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors + tensor_loc);
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
usage:MTLResourceUsageRead | MTLResourceUsageWrite];
}
if (state_steps.size() > 0){
if (state_steps.size() > 0) {
mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors + tensor_loc);
[computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
}
@ -281,47 +243,50 @@ static void multi_tensor_apply_for_fused_optimizer(
TORCH_CHECK(chunks > -1);
for (const auto& chunk : c10::irange(chunks)) {
metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1;
metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk;
metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1;
metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk;
threadgroup_loc++;
threadgroup_loc++;
const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1;
// Reach the maximum threadgroups per dispatch
const auto blocks_full = threadgroup_loc == kmaxThreadGroups;
const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1;
// Reach the maximum threadgroups per dispatch
const auto blocks_full = threadgroup_loc == kmaxThreadGroups;
if (tensor_full || blocks_full){
encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...);
MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1);
uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup];
MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1);
[computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
if (tensor_full || blocks_full) {
encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...);
MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1);
uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup];
MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1);
[computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
// Reset
threadgroup_loc = 0;
if (chunk == chunks - 1) {
// last chunk
tensor_loc = 0;
tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
} else {
// reuse the current tensor since the current one isn't done.
metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1];
// Reset
threadgroup_loc = 0;
if (chunk == chunks - 1) {
// last chunk
tensor_loc = 0;
tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
options:0] autorelease];
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
} else {
// reuse the current tensor since the current one isn't done.
metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1];
tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
options:0] autorelease];
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
for (const auto& d : c10::irange(depth)) {
mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors);
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageWrite | MTLResourceUsageRead];
}
if (state_steps.size() > 0){
mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors);
[computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
}
tensor_loc = 1;
}
for (const auto& d : c10::irange(depth)) {
mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors);
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
usage:MTLResourceUsageWrite | MTLResourceUsageRead];
}
if (state_steps.size() > 0) {
mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors);
[computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
}
tensor_loc = 1;
}
}
}
}
@ -334,7 +299,6 @@ static void multi_tensor_apply_for_fused_optimizer(
}
getMPSProfiler().endProfileKernel(fusedOptimizerPSO);
}
});
}