[BE][MPS] Add MPS to clang format (#96562)

I'm getting tired of asking people to add a space after `if` and all that jazz, so let's let the linter do that.
Add a section for the Objective-C language, where the column limit is extended to 120 characters and `AlignAfterOpenBracket` is set to `Align`.
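
To make the effect concrete, here is the kind of reformatting these settings produce, excerpted from the `MPSDevice::metalIndexingFunction` hunk further down in this diff (a fragment for illustration, not a standalone program; the restored indentation is approximate):
```objc
// Before: manual wrapping and spaces after the selector colons
_mtl_indexing_library = [_mtl_device newLibraryWithSource: [NSString stringWithCString: mps::indexing_metal_shaders encoding:NSASCIIStringEncoding]
                                                  options: options
                                                    error: &error];

// After clang-format with the new ObjC section: arguments wrap at 120 columns,
// continuation lines align on the selector colons, and spaces after colons are removed
_mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders
                                                                              encoding:NSASCIIStringEncoding]
                                                   options:options
                                                     error:&error];
```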

All `.mm` changes in this PR were made by running the linter as follows:
```
lintrunner --take CLANGFORMAT --all-files --apply-patches
```
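
To reformat only a subset of files instead of the whole repo, lintrunner can (to the best of my understanding; treat the exact invocation as an assumption) be pointed at explicit paths; omitting `--apply-patches` reports the formatting diff without modifying files:
```
lintrunner --take CLANGFORMAT --apply-patches aten/src/ATen/mps/MPSDevice.mm aten/src/ATen/mps/MPSStream.mm
```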

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96562
Approved by: https://github.com/seemethere, https://github.com/janeyx99, https://github.com/ZainRizvi, https://github.com/izaitsevfb, https://github.com/PaliC, https://github.com/albanD
Authored by: Nikita Shulga, 2023-03-10 23:17:54 +00:00
Committed by: PyTorch MergeBot
Parent: a7689e73f6
Commit: 4242e698a3
48 changed files with 8289 additions and 9129 deletions

View File

@ -60,9 +60,6 @@ MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
@ -85,4 +82,11 @@ SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
---
Language: ObjC
ColumnLimit: 120
AlignAfterOpenBracket: Align
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
...

View File

@ -49,6 +49,8 @@ init_command = [
code = 'CLANGFORMAT'
include_patterns = [
'aten/src/ATen/*.h',
'aten/src/ATen/mps/**/*.mm',
'aten/src/ATen/native/mps/**/*.mm',
'aten/src/ATen/native/vulkan/**/*.h',
'aten/src/ATen/native/vulkan/**/*.cpp',
'c10/**/*.h',

View File

@ -1,10 +1,10 @@
// Copyright © 2022 Apple Inc.
#include <ATen/CPUFunctions.h>
#include <ATen/EmptyTensor.h>
#include <ATen/mps/MPSAllocator.h>
#include <c10/core/Allocator.h>
#include <c10/core/Storage.h>
#include <ATen/CPUFunctions.h>
#include <ATen/EmptyTensor.h>
#include <iostream>
namespace at {
@ -19,25 +19,26 @@ uint64_t HeapBlock::heap_counter = 0;
void MPSHeapAllocatorImpl::init_allocator() {
// debug verbosity flags (see DebugVerbosity enum)
static const char *verbosity_str = getenv("PYTORCH_DEBUG_MPS_ALLOCATOR");
static const char* verbosity_str = getenv("PYTORCH_DEBUG_MPS_ALLOCATOR");
m_debug_verbosity = verbosity_str ? strtol(verbosity_str, nullptr, 0) : DebugVerbosity::SILENT;
static const char *high_watermark_ratio_str = getenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO");
const double high_watermark_ratio = high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) :
default_high_watermark_ratio;
static const char* high_watermark_ratio_str = getenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO");
const double high_watermark_ratio =
high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) : default_high_watermark_ratio;
setHighWatermarkRatio(high_watermark_ratio);
const double default_low_watermark_ratio = m_device.hasUnifiedMemory ? default_low_watermark_ratio_unified :
default_low_watermark_ratio_discrete;
static const char *low_watermark_ratio_str = getenv("PYTORCH_MPS_LOW_WATERMARK_RATIO");
const double low_watermark_ratio = low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio;
const double default_low_watermark_ratio =
m_device.hasUnifiedMemory ? default_low_watermark_ratio_unified : default_low_watermark_ratio_discrete;
static const char* low_watermark_ratio_str = getenv("PYTORCH_MPS_LOW_WATERMARK_RATIO");
const double low_watermark_ratio =
low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio;
setLowWatermarkRatio(low_watermark_ratio);
}
void MPSHeapAllocatorImpl::setHighWatermarkRatio(double ratio) {
TORCH_CHECK(ratio >= 0.0 && ratio <= default_high_watermark_upper_bound, "invalid high watermark ratio ", ratio);
m_max_total_allowed_size = (ratio == 0.0) ? std::numeric_limits<size_t>::max() :
static_cast<size_t>(ratio * (double)max_device_size());
m_max_total_allowed_size =
(ratio == 0.0) ? std::numeric_limits<size_t>::max() : static_cast<size_t>(ratio * (double)max_device_size());
if (m_debug_verbosity & DebugVerbosity::PROFILING) {
std::cerr << "\nHigh watermark memory allocation limit: "
<< (ratio == 0.0 ? "unlimited" : format_size(m_max_total_allowed_size)) << "\n";
@ -47,11 +48,12 @@ void MPSHeapAllocatorImpl::setHighWatermarkRatio(double ratio) {
void MPSHeapAllocatorImpl::setLowWatermarkRatio(double ratio) {
// used for comparison with lower_watermark_ratio
const double high_watermark_limit = m_high_watermark_ratio == 0.0 ? default_high_watermark_upper_bound : m_high_watermark_ratio;
const double high_watermark_limit =
m_high_watermark_ratio == 0.0 ? default_high_watermark_upper_bound : m_high_watermark_ratio;
TORCH_CHECK(ratio >= 0.0 && ratio <= high_watermark_limit, "invalid low watermark ratio ", ratio);
// we use this to detect if there's memory pressure
m_low_watermark_limit = (ratio == 0.0) ? std::numeric_limits<size_t>::max() :
static_cast<size_t>(ratio * (double)max_device_size());
m_low_watermark_limit =
(ratio == 0.0) ? std::numeric_limits<size_t>::max() : static_cast<size_t>(ratio * (double)max_device_size());
if (m_debug_verbosity & DebugVerbosity::PROFILING) {
std::cerr << "Low watermark memory allocation limit: "
<< (ratio == 0.0 ? "unlimited" : format_size(m_low_watermark_limit)) << "\n";
@ -61,7 +63,7 @@ void MPSHeapAllocatorImpl::setLowWatermarkRatio(double ratio) {
HeapBlock* MPSHeapAllocatorImpl::get_free_heap(AllocParams& params) {
BufferPool& pool = *params.pool;
HeapBlock *heap_block = nullptr;
HeapBlock* heap_block = nullptr;
HeapBlock search_key(params.size());
auto it = pool.heaps.lower_bound(&search_key);
@ -69,10 +71,8 @@ HeapBlock* MPSHeapAllocatorImpl::get_free_heap(AllocParams& params) {
heap_block = HeapBlock::createHeapBlock(params, pool.device, pool.usage);
if (heap_block) {
if (m_debug_verbosity & DebugVerbosity::ALLOCATIONS) {
std::cerr << "\nAllocated "
<< ((pool.usage & UsageFlags::SHARED) ? "shared " : "private ")
<< " heap #" << heap_block->heap_id
<< " of size " << format_size(heap_block->size.total)
std::cerr << "\nAllocated " << ((pool.usage & UsageFlags::SHARED) ? "shared " : "private ") << " heap #"
<< heap_block->heap_id << " of size " << format_size(heap_block->size.total)
<< " (#heaps: " << (pool.heaps.size() + 1)
<< ", current allocated: " << format_size(current_allocated_size()) << ")\n";
}
@ -91,7 +91,7 @@ bool MPSHeapAllocatorImpl::alloc_buffer(AllocParams& params) {
current_allocated_size() + params.size() > m_max_total_allowed_size) {
return false;
}
HeapBlock *heap = get_free_heap(params);
HeapBlock* heap = get_free_heap(params);
if (!heap) {
return false; // this will cause releasing pool buffers to free up memory
}
@ -110,16 +110,13 @@ bool MPSHeapAllocatorImpl::alloc_buffer(AllocParams& params) {
if ((m_debug_verbosity & DebugVerbosity::ALLOCATIONS) &&
(!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
std::cerr << "Allocated "
<< ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
<< ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "")
<< " buffer #" << params.buffer_block->buf_id
<< " of size " << format_size(params.size())
<< " at " << params.buffer_block->buffer
<< " from heap #" << heap->heap_id
std::cerr << "Allocated " << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
<< ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") << " buffer #"
<< params.buffer_block->buf_id << " of size " << format_size(params.size()) << " at "
<< params.buffer_block->buffer << " from heap #" << heap->heap_id
<< " (requested: " << format_size(params.requested_size)
<< ", heap: " << format_size(heap->size.available)
<< ", total: " << format_size(m_total_allocated_memory) << ")\n";
<< ", heap: " << format_size(heap->size.available) << ", total: " << format_size(m_total_allocated_memory)
<< ")\n";
}
return true;
}
@ -158,8 +155,8 @@ bool MPSHeapAllocatorImpl::get_free_buffer(AllocParams& params) {
// this will skip unnecessary garbage collection as we'll reuse the newly released space
params.has_memory_pressure = false;
} else if (params.has_memory_pressure) {
// the oversized buffer is busy and not reusable at the moment. So release it (and potentially its heap container)
// in allocator, and ARC will later free up its backing memory when the busy command buffer finishes.
// the oversized buffer is busy and not reusable at the moment. So release it (and potentially its heap
// container) in allocator, and ARC will later free up its backing memory when the busy command buffer finishes.
release_buffer(buffer_block, true);
} else {
// only if there's no memory pressure, we'll reuse the oversized buffer
@ -177,15 +174,12 @@ bool MPSHeapAllocatorImpl::get_free_buffer(AllocParams& params) {
if ((m_debug_verbosity & DebugVerbosity::RECYCLES) &&
(!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
std::cerr << "Reusing "
<< ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
<< ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "")
<< " buffer #" << params.buffer_block->buf_id
<< " of size " << format_size(params.buffer_block->size)
<< " at " << params.buffer_block->buffer
<< " (requested: " << format_size(params.requested_size)
<< ", use#: " << params.buffer_block->use_count + 1
<< ", retain#: " << params.buffer_block->retainCount() << ")\n";
std::cerr << "Reusing " << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
<< ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") << " buffer #"
<< params.buffer_block->buf_id << " of size " << format_size(params.buffer_block->size) << " at "
<< params.buffer_block->buffer << " (requested: " << format_size(params.requested_size)
<< ", use#: " << params.buffer_block->use_count + 1 << ", retain#: " << params.buffer_block->retainCount()
<< ")\n";
}
return true;
}
@ -214,7 +208,8 @@ BufferBlock* MPSHeapAllocatorImpl::alloc_buffer_block(size_t size, uint32_t usag
alloc_buffer(params) ||
// Callbacks might release more memory (eg. by forcing a GC in the host language) thus
// we can retry getting a free buffer in the pool, before trying to alloc again.
(trigger_memory_callbacks(nullptr, IMpsAllocatorCallback::EventType::ALLOCATION_FAILED) && get_free_buffer(params)) ||
(trigger_memory_callbacks(nullptr, IMpsAllocatorCallback::EventType::ALLOCATION_FAILED) &&
get_free_buffer(params)) ||
// Free enough available cached blocks to satisfy alloc and retry alloc.
(release_available_cached_buffers(params) && alloc_buffer(params)) ||
// Free all cached buffers and retry alloc.
@ -229,16 +224,30 @@ BufferBlock* MPSHeapAllocatorImpl::alloc_buffer_block(size_t size, uint32_t usag
// chunk of requested size couldn't be found.
if (!block_found || !buffer_block) {
if (m_high_watermark_ratio > 0.0) {
TORCH_CHECK(false, "MPS backend out of memory (MPS allocated: ", format_size(m_total_allocated_memory),
", other allocations: ", format_size(current_allocated_size() - m_total_allocated_memory),
", max allowed: ", format_size(m_max_total_allowed_size), "). Tried to allocate ", format_size(alloc_size),
" on ", ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
TORCH_CHECK(
false,
"MPS backend out of memory (MPS allocated: ",
format_size(m_total_allocated_memory),
", other allocations: ",
format_size(current_allocated_size() - m_total_allocated_memory),
", max allowed: ",
format_size(m_max_total_allowed_size),
"). Tried to allocate ",
format_size(alloc_size),
" on ",
((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
" pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).");
} else {
TORCH_CHECK(false, "MPS backend out of memory (MPS allocated: ", format_size(m_total_allocated_memory),
", other allocations: ", format_size(current_allocated_size() - m_total_allocated_memory),
"). Tried to allocate ", format_size(alloc_size),
" on ", ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"), " pool.");
TORCH_CHECK(false,
"MPS backend out of memory (MPS allocated: ",
format_size(m_total_allocated_memory),
", other allocations: ",
format_size(current_allocated_size() - m_total_allocated_memory),
"). Tried to allocate ",
format_size(alloc_size),
" on ",
((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
" pool.");
}
}
buffer_block->in_use = true;
@ -270,7 +279,7 @@ BufferBlock* MPSHeapAllocatorImpl::get_allocated_buffer_block(void* ptr) {
}
bool MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove_empty_heap) {
HeapBlock *heap_block = buffer_block->heap;
HeapBlock* heap_block = buffer_block->heap;
BufferPool& pool = *heap_block->pool;
m_total_allocated_memory -= buffer_block->size;
pool.allocated_size -= buffer_block->size;
@ -284,12 +293,9 @@ bool MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove
if ((m_debug_verbosity & DebugVerbosity::RELEASES) &&
(!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
std::cerr << "Released buffer #" << buffer_block->buf_id
<< " of size " << format_size(buffer_block->size)
<< " from heap #" << heap_block->heap_id
<< " (heap size: " << format_size(heap_block->size.available)
<< ", use#: " << buffer_block->use_count
<< ", retain#: " << retainCount
std::cerr << "Released buffer #" << buffer_block->buf_id << " of size " << format_size(buffer_block->size)
<< " from heap #" << heap_block->heap_id << " (heap size: " << format_size(heap_block->size.available)
<< ", use#: " << buffer_block->use_count << ", retain#: " << retainCount
<< ", gc#: " << buffer_block->gc_count << ")\n";
}
delete buffer_block;
@ -298,10 +304,9 @@ bool MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove
pool.heaps_pending_update.erase(heap_block);
retainCount = heap_block->releaseMTLHeap();
if (m_debug_verbosity & DebugVerbosity::RELEASES) {
std::cerr << "Released heap #" << heap_block->heap_id
<< " of size " << format_size(heap_block->size.total)
<< " (current allocated: " << format_size(current_allocated_size())
<< ", retain#: " << retainCount << ")\n";
std::cerr << "Released heap #" << heap_block->heap_id << " of size " << format_size(heap_block->size.total)
<< " (current allocated: " << format_size(current_allocated_size()) << ", retain#: " << retainCount
<< ")\n";
}
delete heap_block;
return true;
@ -312,7 +317,7 @@ bool MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove
if (retainCount > 1) {
pool.heaps_pending_update.insert(heap_block);
m_mutex.unlock();
m_stream->addCompletedHandler(^(id <MTLCommandBuffer>) {
m_stream->addCompletedHandler(^(id<MTLCommandBuffer>) {
std::lock_guard<std::recursive_mutex> lock(m_mutex);
// check if the heap block still exists
if (pool.heaps_pending_update.find(heap_block) != pool.heaps_pending_update.end()) {
@ -333,13 +338,11 @@ void MPSHeapAllocatorImpl::release_buffers(BufferPool& pool) {
return;
}
if ((m_debug_verbosity & DebugVerbosity::RELEASES)) {
std::cerr << "Releasing " << pool.buffers.size()
<< " buffers from "
<< ((pool.usage & UsageFlags::SMALL ) ? "small " : "large ")
std::cerr << "Releasing " << pool.buffers.size() << " buffers from "
<< ((pool.usage & UsageFlags::SMALL) ? "small " : "large ")
<< ((pool.usage & UsageFlags::SHARED) ? "shared" : "private")
<< ((pool.usage & UsageFlags::SCALAR) ? " scalar" : "")
<< " pool (total size: " << format_size(pool.allocated_size)
<< ", #buffers: " << pool.n_buffers << ")\n";
<< " pool (total size: " << format_size(pool.allocated_size) << ", #buffers: " << pool.n_buffers << ")\n";
}
auto it = pool.buffers.begin();
while (it != pool.buffers.end()) {
@ -381,10 +384,8 @@ bool MPSHeapAllocatorImpl::release_available_cached_buffers(AllocParams& params)
bool MPSHeapAllocatorImpl::release_cached_buffers() {
if (m_debug_verbosity >= DebugVerbosity::PROFILING) {
std::cerr << "Attempting to release cached buffers (MPS allocated: "
<< format_size(m_total_allocated_memory)
<< ", other allocations: "
<< format_size(current_allocated_size() - m_total_allocated_memory) << ")\n";
std::cerr << "Attempting to release cached buffers (MPS allocated: " << format_size(m_total_allocated_memory)
<< ", other allocations: " << format_size(current_allocated_size() - m_total_allocated_memory) << ")\n";
}
// before releasing the buffers make sure the command buffer has finished.
// we need to release the lock temporarily as synchronizing may cause deadlock with completion handlers.
@ -445,11 +446,10 @@ void MPSHeapAllocatorImpl::garbage_collect_cached_buffers(AllocParams& params) {
}
}
if (m_debug_verbosity & DebugVerbosity::RELEASES) {
std::cerr << "Garbage collected " << freed_count
<< " buffers from large "
std::cerr << "Garbage collected " << freed_count << " buffers from large "
<< ((pool.usage & UsageFlags::SHARED) ? "shared" : "private")
<< " pool (total reclaimed: " << format_size(gc_reclaimed)
<< ", #buffers: " << pool.buffers.size() << ")\n";
<< " pool (total reclaimed: " << format_size(gc_reclaimed) << ", #buffers: " << pool.buffers.size()
<< ")\n";
}
}
@ -464,7 +464,7 @@ id<MTLBuffer> MPSHeapAllocatorImpl::malloc(size_t size, uint32_t usage) {
bool MPSHeapAllocatorImpl::isSharedBuffer(void* ptr) {
std::lock_guard<std::recursive_mutex> lock(m_mutex);
BufferBlock *buffer_block = get_allocated_buffer_block(ptr);
BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
// it's OK for the buffer_block to not exist yet
return buffer_block && (buffer_block->heap->pool->usage & UsageFlags::SHARED);
}
@ -487,9 +487,9 @@ id<MTLBuffer> MPSHeapAllocatorImpl::allocScalarBufferWithValue(void* value, size
ssize_t MPSHeapAllocatorImpl::getUnalignedBufferSize(void* ptr) {
std::lock_guard<std::recursive_mutex> lock(m_mutex);
BufferBlock *buffer_block = get_allocated_buffer_block(ptr);
BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
if (buffer_block) {
return (ssize_t) buffer_block->requested_size;
return (ssize_t)buffer_block->requested_size;
}
// -1 indicates the passed buffer pointer wasn't found
return -1;
@ -498,7 +498,7 @@ ssize_t MPSHeapAllocatorImpl::getUnalignedBufferSize(void* ptr) {
void MPSHeapAllocatorImpl::setBufferShape(void* ptr, const IntArrayRef& shape) {
std::lock_guard<std::recursive_mutex> lock(m_mutex);
BufferBlock *buffer_block = get_allocated_buffer_block(ptr);
BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
TORCH_INTERNAL_ASSERT(buffer_block, "failed to find the buffer ", ptr);
// note that the IntArrayRef doesn't own the underlying data, and the backing
// memory for shape data must persist as long as the buffer is in use.
@ -509,7 +509,7 @@ void MPSHeapAllocatorImpl::setBufferShape(void* ptr, const IntArrayRef& shape) {
IntArrayRef MPSHeapAllocatorImpl::getBufferShape(void* ptr) {
std::lock_guard<std::recursive_mutex> lock(m_mutex);
BufferBlock *buffer_block = get_allocated_buffer_block(ptr);
BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
if (buffer_block && buffer_block->shape.size() > 0) {
return IntArrayRef{buffer_block->shape};
}
@ -517,7 +517,7 @@ IntArrayRef MPSHeapAllocatorImpl::getBufferShape(void* ptr) {
}
void MPSHeapAllocatorImpl::free(void* ptr) {
BufferBlock *buffer_block = nullptr;
BufferBlock* buffer_block = nullptr;
{
std::lock_guard<std::recursive_mutex> lock(m_mutex);
@ -531,7 +531,7 @@ void MPSHeapAllocatorImpl::free(void* ptr) {
}
// we sync the scalar pool manually with completion handler at the time buffer is
// freed when the MPSScalar instance goes our of scope
m_stream->addCompletedHandler(^(id <MTLCommandBuffer>) {
m_stream->addCompletedHandler(^(id<MTLCommandBuffer>) {
std::lock_guard<std::recursive_mutex> lock(m_mutex);
free_buffer(buffer_block);
});
@ -555,10 +555,15 @@ inline std::string MPSHeapAllocatorImpl::format_size(uint64_t size) const {
std::ostringstream os;
os.precision(2);
os << std::fixed;
if (size <= 1024UL) { os << size << " bytes"; }
else if (size <= 1048576UL) { os << ((float) size / 1024.0) << " KB"; }
else if (size <= 1073741824UL) { os << ((float) size / 1048576.0) << " MB"; }
else { os << ((float) size / 1073741824.0) << " GB"; }
if (size <= 1024UL) {
os << size << " bytes";
} else if (size <= 1048576UL) {
os << ((float)size / 1024.0) << " KB";
} else if (size <= 1073741824UL) {
os << ((float)size / 1048576.0) << " MB";
} else {
os << ((float)size / 1073741824.0) << " GB";
}
return os.str();
}
@ -574,16 +579,13 @@ HeapAllocator::MPSHeapAllocatorImpl& _getAllocImpl() {
// MPS allocator struct to be registered with Pytorch
struct TORCH_API MPSAllocator final : public IMPSAllocator {
public:
explicit MPSAllocator(uint32_t Usage) :
m_has_unified_memory(_getAllocImpl().Device().hasUnifiedMemory), m_usage(Usage)
{
public:
explicit MPSAllocator(uint32_t Usage)
: m_has_unified_memory(_getAllocImpl().Device().hasUnifiedMemory), m_usage(Usage) {
if (_getAllocImpl().getDebugVerbosity()) {
if (!(m_usage & HeapAllocator::UsageFlags::SHARED) || m_has_unified_memory) {
std::cerr << "Initializing "
<< ((m_usage & HeapAllocator::UsageFlags::SHARED) ? "shared" : "private")
<< " heap allocator on "
<< (m_has_unified_memory ? "unified" : "discrete")
std::cerr << "Initializing " << ((m_usage & HeapAllocator::UsageFlags::SHARED) ? "shared" : "private")
<< " heap allocator on " << (m_has_unified_memory ? "unified" : "discrete")
<< " device memory of size "
<< _getAllocImpl().format_size(_getAllocImpl().Device().recommendedMaxWorkingSetSize) << "\n";
}
@ -593,34 +595,64 @@ public:
~MPSAllocator() override {
_getAllocImpl().emptyCache();
}
DeleterFnPtr raw_deleter() const override { return &Delete; }
DeleterFnPtr raw_deleter() const override {
return &Delete;
}
DataPtr allocate(const size_t nbytes) const override {
__block id<MTLBuffer> buf = nbytes > 0 ? _getAllocImpl().malloc(nbytes, m_usage) : nullptr;
return { buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
return {buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
}
// implementation of IMPSAllocator interface
DataPtr allocScalarBufferWithValue(void *value, size_t size) const override {
DataPtr allocScalarBufferWithValue(void* value, size_t size) const override {
id<MTLBuffer> buf = _getAllocImpl().allocScalarBufferWithValue(value, size);
return { buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
return {buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
}
bool isSharedBuffer(void* ptr) const override {
return _getAllocImpl().isSharedBuffer(ptr);
}
bool isSharedStorageSupported() const override {
return m_has_unified_memory;
}
void emptyCache() const override {
_getAllocImpl().emptyCache();
}
ssize_t getUnalignedBufferSize(void* ptr) const override {
return _getAllocImpl().getUnalignedBufferSize(ptr);
}
IntArrayRef getBufferShape(void* ptr) const override {
return _getAllocImpl().getBufferShape(ptr);
}
void setBufferShape(void* ptr, const IntArrayRef& shape) const override {
_getAllocImpl().setBufferShape(ptr, shape);
}
size_t getTotalAllocatedMemory() const override {
return _getAllocImpl().getTotalAllocatedMemory();
}
size_t getCurrentAllocatedMemory() const override {
return _getAllocImpl().getCurrentAllocatedMemory();
}
size_t getDriverAllocatedMemory() const override {
return _getAllocImpl().getDriverAllocatedMemory();
}
ssize_t getLowWatermarkValue() const override {
return _getAllocImpl().getLowWatermarkValue();
}
size_t getLowWatermarkLimit() const override {
return _getAllocImpl().getLowWatermarkLimit();
}
size_t getHighWatermarkLimit() const override {
return _getAllocImpl().getHighWatermarkLimit();
}
void setLowWatermarkRatio(double ratio) const override {
_getAllocImpl().setLowWatermarkRatio(ratio);
}
void setHighWatermarkRatio(double ratio) const override {
_getAllocImpl().setHighWatermarkRatio(ratio);
}
bool isSharedBuffer(void* ptr) const override { return _getAllocImpl().isSharedBuffer(ptr); }
bool isSharedStorageSupported() const override { return m_has_unified_memory; }
void emptyCache() const override { _getAllocImpl().emptyCache(); }
ssize_t getUnalignedBufferSize(void* ptr) const override { return _getAllocImpl().getUnalignedBufferSize(ptr); }
IntArrayRef getBufferShape(void* ptr) const override { return _getAllocImpl().getBufferShape(ptr); }
void setBufferShape(void* ptr, const IntArrayRef& shape) const override { _getAllocImpl().setBufferShape(ptr, shape); }
size_t getTotalAllocatedMemory() const override { return _getAllocImpl().getTotalAllocatedMemory(); }
size_t getCurrentAllocatedMemory() const override { return _getAllocImpl().getCurrentAllocatedMemory(); }
size_t getDriverAllocatedMemory() const override { return _getAllocImpl().getDriverAllocatedMemory(); }
ssize_t getLowWatermarkValue() const override { return _getAllocImpl().getLowWatermarkValue(); }
size_t getLowWatermarkLimit() const override { return _getAllocImpl().getLowWatermarkLimit(); }
size_t getHighWatermarkLimit() const override { return _getAllocImpl().getHighWatermarkLimit(); }
void setLowWatermarkRatio(double ratio) const override { _getAllocImpl().setLowWatermarkRatio(ratio); }
void setHighWatermarkRatio(double ratio) const override { _getAllocImpl().setHighWatermarkRatio(ratio); }
private:
private:
bool m_has_unified_memory;
uint32_t m_usage;
@ -662,15 +694,13 @@ namespace native {
// Pinned memory will be helpful on Apple Silicon Macs with Unified memory as we
// will be able to use SharedStorageMode for MTLBuffer allocations. This will
// avoid extra copies on DataLoading operations.
bool is_pinned_mps(const Tensor& self, c10::optional<Device> device)
{
bool is_pinned_mps(const Tensor& self, c10::optional<Device> device) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_mps());
return at::mps::_getSharedAllocator().isSharedBuffer(self.storage().data());
}
// torch.pin_memory() implementation
Tensor _pin_memory_mps(const Tensor& self, c10::optional<Device> device)
{
Tensor _pin_memory_mps(const Tensor& self, c10::optional<Device> device) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_mps());
auto* shared_allocator = at::mps::getIMPSAllocator(true);
TORCH_CHECK(shared_allocator, "unable to pin memory on a non-unified memory device");

View File

@ -2,10 +2,10 @@
#include <c10/util/CallOnce.h>
#include <ATen/mps/IndexKernels.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/IndexKernels.h>
namespace at {
namespace mps {
@ -23,9 +23,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
}
MPSDevice* MPSDevice::getInstance() {
c10::call_once(mpsdev_init, [] {
mps_device = std::unique_ptr<MPSDevice>(new MPSDevice());
});
c10::call_once(mpsdev_init, [] { mps_device = std::unique_ptr<MPSDevice>(new MPSDevice()); });
return mps_device.get();
}
@ -33,25 +31,31 @@ id<MTLFunction> MPSDevice::metalIndexingFunction(const std::string& kernel, MTLF
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
NSError* error = nil;
if (!_mtl_indexing_library) {
MTLCompileOptions *options = [MTLCompileOptions new];
[options setLanguageVersion: getMetalLanguageVersion(_mtl_device)];
[options setFastMathEnabled: YES];
_mtl_indexing_library = [_mtl_device newLibraryWithSource: [NSString stringWithCString: mps::indexing_metal_shaders encoding:NSASCIIStringEncoding]
options: options
error: &error];
MTLCompileOptions* options = [MTLCompileOptions new];
[options setLanguageVersion:getMetalLanguageVersion(_mtl_device)];
[options setFastMathEnabled:YES];
_mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders
encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]);
}
id<MTLFunction> indexFunction = nil;
if (constantValues) {
indexFunction = [[_mtl_indexing_library newFunctionWithName: [NSString stringWithUTF8String: kernel.c_str()]
constantValues: constantValues
error: &error] autorelease];
indexFunction = [[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]
constantValues:constantValues
error:&error] autorelease];
} else {
indexFunction = [[_mtl_indexing_library newFunctionWithName: [NSString stringWithUTF8String: kernel.c_str()]] autorelease];
indexFunction =
[[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease];
}
TORCH_CHECK(indexFunction, "Failed to create specialized function state object: ", kernel, ", error: ", [[error description] UTF8String]);
TORCH_CHECK(indexFunction,
"Failed to create specialized function state object: ",
kernel,
", error: ",
[[error description] UTF8String]);
return indexFunction;
}
@ -63,49 +67,52 @@ MPSDevice::~MPSDevice() {
_mtl_indexing_library = nil;
}
MPSDevice::MPSDevice(): _mtl_device(nil), _mtl_indexing_library(nil) {
MPSDevice::MPSDevice() : _mtl_device(nil), _mtl_indexing_library(nil) {
// Check that MacOS 12.3+ version of MPS framework is available
// Create the MPSGraph and check method introduced in 12.3+
// which is used by MPS backend.
id mpsCD = NSClassFromString(@"MPSGraph");
if ([mpsCD instancesRespondToSelector:@selector(LSTMWithSourceTensor:
recurrentWeight:
inputWeight:
bias:
initState:
initCell:
descriptor:
name:)] == NO) {
if ([mpsCD instancesRespondToSelector:@selector
(LSTMWithSourceTensor:recurrentWeight:inputWeight:bias:initState:initCell:descriptor:name:)] == NO) {
return;
}
NSArray* devices = [MTLCopyAllDevices() autorelease];
for (unsigned long i = 0 ; i < [devices count] ; i++) {
for (unsigned long i = 0; i < [devices count]; i++) {
id<MTLDevice> device = devices[i];
if(![device isLowPower]) { // exclude Intel GPUs
if (![device isLowPower]) { // exclude Intel GPUs
_mtl_device = [device retain];
break;
}
}
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
}
bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
id mpsCD = NSClassFromString(@"MPSGraph");
static bool _macos_13_0_plus = [mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == YES;
static bool _macos_13_1_plus = [mpsCD instancesRespondToSelector:@selector(
sampleGridWithSourceTensor:coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode:samplingMode:constantValue:name:)] == YES;
static bool _macos_13_2_plus = [mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES;
static bool _macos_13_0_plus = [mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:
axis:name:)] == YES;
static bool _macos_13_1_plus =
[mpsCD instancesRespondToSelector:@selector
(sampleGridWithSourceTensor:
coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode
:samplingMode:constantValue:name:)] == YES;
static bool _macos_13_2_plus =
[mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES;
static bool _macos_13_3_plus = [_mtl_device respondsToSelector:@selector(maximumConcurrentCompilationTaskCount)];
switch (version) {
case MacOSVersion::MACOS_VER_13_0_PLUS: return _macos_13_0_plus;
case MacOSVersion::MACOS_VER_13_1_PLUS: return _macos_13_1_plus;
case MacOSVersion::MACOS_VER_13_2_PLUS: return _macos_13_2_plus;
case MacOSVersion::MACOS_VER_13_3_PLUS: return _macos_13_3_plus;
default: return false;
case MacOSVersion::MACOS_VER_13_0_PLUS:
return _macos_13_0_plus;
case MacOSVersion::MACOS_VER_13_1_PLUS:
return _macos_13_1_plus;
case MacOSVersion::MACOS_VER_13_2_PLUS:
return _macos_13_2_plus;
case MacOSVersion::MACOS_VER_13_3_PLUS:
return _macos_13_3_plus;
default:
return false;
}
}

View File

@ -4,40 +4,46 @@
namespace at {
void mps_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack)
{
TORCH_WARN_ONCE("The operator '", op.schema().operator_name(), "' is not currently supported ",
void mps_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
TORCH_WARN_ONCE("The operator '",
op.schema().operator_name(),
"' is not currently supported ",
"on the MPS backend and will fall back to run on the CPU.",
" This may have performance implications.");
native::cpu_fallback(op, stack);
}
void mps_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack)
{
TORCH_CHECK_NOT_IMPLEMENTED(false, "The operator '", op.schema().operator_name(), "' is not currently implemented ",
void mps_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack){TORCH_CHECK_NOT_IMPLEMENTED(
false,
"The operator '",
op.schema().operator_name(),
"' is not currently implemented ",
"for the MPS device. If you want this op to be added in priority during the prototype ",
"phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. ",
"As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` ",
"to use the CPU as a fallback for this op. WARNING: this will be slower than running natively ",
"on MPS.")
}
"on MPS.")}
// This dispatch should never be called for tensor on MPS but is frequently called
// If one of them are on CPU
Tensor slow_conv2d_forward_mps(
const Tensor &self,
const Tensor &weight,
Tensor slow_conv2d_forward_mps(const Tensor& self,
const Tensor& weight,
IntArrayRef kernel_size,
const c10::optional<Tensor> &bias,
const c10::optional<Tensor>& bias,
IntArrayRef stride,
IntArrayRef padding) {
TORCH_CHECK(self.device() == weight.device(), __func__, ": input(device='", self.device(), "') and weight(device=", weight.device(), "') must be on the same device");
TORCH_CHECK(self.device() == weight.device(),
__func__,
": input(device='",
self.device(),
"') and weight(device=",
weight.device(),
"') must be on the same device");
TORCH_INTERNAL_ASSERT(false, __func__, " should not be called for both tensors on MPS device");
}
TORCH_LIBRARY_IMPL(_, MPS, m) {
static const char *enable_mps_fallback = getenv("PYTORCH_ENABLE_MPS_FALLBACK");
static const char* enable_mps_fallback = getenv("PYTORCH_ENABLE_MPS_FALLBACK");
if (!enable_mps_fallback || std::stoi(enable_mps_fallback) == 0) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&mps_error_fallback>());
} else {

View File

@ -24,7 +24,8 @@ Generator createMPSGenerator(uint64_t seed_val) {
MPSGeneratorImpl::MPSGeneratorImpl(uint64_t seed_in)
: c10::GeneratorImpl{Device(DeviceType::MPS), DispatchKeySet(c10::DispatchKey::MPS)},
data_({.seed = seed_in}), engine_(seed_in, 0, 0) { }
data_({.seed = seed_in}),
engine_(seed_in, 0, 0) {}
void MPSGeneratorImpl::set_current_seed(uint64_t seed) {
data_.seed = seed;
@ -60,7 +61,8 @@ c10::intrusive_ptr<c10::TensorImpl> MPSGeneratorImpl::get_state() const {
static const size_t seed_size = sizeof(uint64_t);
static const size_t total_size = states_size + seed_size;
auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
auto state_tensor = at::detail::empty_cpu(
{(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
auto rng_state = state_tensor.data_ptr<uint8_t>();
auto current_seed = this->current_seed();
memcpy(rng_state, this->data_.state.data(), states_size);

View File

@ -1,31 +1,24 @@
// Copyright © 2022 Apple Inc.
#include <ATen/mps/MPSGuardImpl.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSGuardImpl.h>
namespace at {
namespace mps {
void MPSGuardImpl::createEvent(
mpsEvent_t* event,
const EventFlag flag) const {
}
void MPSGuardImpl::createEvent(mpsEvent_t* event, const EventFlag flag) const {}
void MPSGuardImpl::destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept {
if (!event) return;
void MPSGuardImpl::destroyEvent(void* event, const DeviceIndex device_index) const noexcept {
if (!event)
return;
auto mps_event = static_cast<mpsEvent_t>(event);
mps_event->~MPSEvent();
}
}
void MPSGuardImpl::record(
void** event,
void MPSGuardImpl::record(void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const {
TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
"Event device index ",
device_index,
@ -36,22 +29,19 @@ namespace mps {
auto mps_event = static_cast<mpsEvent_t>(*event);
MPSStream mps_stream{stream};
mps_event->recordEvent(true);
}
void MPSGuardImpl::block(
void* event,
const Stream& stream) const {
}
void MPSGuardImpl::block(void* event, const Stream& stream) const {
auto mps_event = static_cast<mpsEvent_t>(event);
MPSStream mps_stream{stream};
mps_event->waitForEvent(true);
}
}
bool MPSGuardImpl::queryEvent(void* event) const {
bool MPSGuardImpl::queryEvent(void* event) const {
auto mps_event = static_cast<mpsEvent_t>(event);
return mps_event->queryEvent();
}
}
}
}

View File

@ -1,7 +1,7 @@
// Copyright © 2022 Apple Inc.
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSStream.h>
namespace at {
namespace mps {
@ -17,9 +17,9 @@ MPSStream::MPSStream(Stream stream) : _stream(stream) {
TORCH_CHECK(_stream.device_type() == DeviceType::MPS);
_serialQueue = dispatch_queue_create("metal gpu stream", nullptr);
_executionDescriptor = [MPSGraphExecutionDescriptor new];
_executionDescriptor.completionHandler = ^(NSDictionary<MPSGraphTensor *,
MPSGraphTensorData *> * resultsDictionary,
NSError * _Nullable error) { };
_executionDescriptor.completionHandler =
^(NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* resultsDictionary, NSError* _Nullable error) {
};
}
MPSStream::~MPSStream() {
@ -41,7 +41,7 @@ MPSCommandBuffer* MPSStream::commandBuffer() {
void MPSStream::synchronize(SyncType syncType) {
if (!_commandBuffer)
return;
switch(syncType) {
switch (syncType) {
case SyncType::NONE:
// typically in GPU to GPU copies we won't commit explicitly
break;
@ -115,25 +115,27 @@ void MPSStream::addCompletedHandler(MTLCommandBufferHandler block) {
});
}
void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType)
{
void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType) {
TORCH_INTERNAL_ASSERT(length >= offset);
if (length == 0) return;
if (length == 0)
return;
dispatch_sync(_serialQueue, ^() {
@autoreleasepool {
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
[blitEncoder fillBuffer:buffer
range:NSMakeRange(offset, length)
value:value];
[blitEncoder fillBuffer:buffer range:NSMakeRange(offset, length) value:value];
[blitEncoder endEncoding];
synchronize(syncType);
}
});
}
void MPSStream::copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
size_t length, size_t srcOffset, size_t dstOffset, SyncType syncType) {
void MPSStream::copy(id<MTLBuffer> srcBuffer,
id<MTLBuffer> dstBuffer,
size_t length,
size_t srcOffset,
size_t dstOffset,
SyncType syncType) {
dispatch_sync(_serialQueue, ^() {
@autoreleasepool {
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
@ -149,10 +151,14 @@ void MPSStream::copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
});
}
void MPSStream::copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer, size_t length,
size_t srcOffset, size_t dstOffset, bool non_blocking) {
copy(srcBuffer, dstBuffer, length, srcOffset, dstOffset,
!non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT);
void MPSStream::copy_and_sync(id<MTLBuffer> srcBuffer,
id<MTLBuffer> dstBuffer,
size_t length,
size_t srcOffset,
size_t dstOffset,
bool non_blocking) {
copy(
srcBuffer, dstBuffer, length, srcOffset, dstOffset, !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT);
}
void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType) {
@ -184,8 +190,7 @@ MPSStream* MPSStreamImpl::_stream = nullptr;
MPSStream* MPSStreamImpl::getInstance() {
if (_stream == nullptr) {
_stream =
new MPSStream(Stream(Stream::UNSAFE, c10::Device(DeviceType::MPS), 0));
_stream = new MPSStream(Stream(Stream::UNSAFE, c10::Device(DeviceType::MPS), 0));
}
return _stream;
}
@ -204,8 +209,8 @@ MPSStream* getDefaultMPSStream() {
// MPSEvent
//-----------------------------------------------------------------
MPSEvent::MPSEvent(bool deferInitialization) :
is_initialized(false), _signalCounter(0), _stream(nil), _event(nil), _listener(nil) {
MPSEvent::MPSEvent(bool deferInitialization)
: is_initialized(false), _signalCounter(0), _stream(nil), _event(nil), _listener(nil) {
if (!deferInitialization) {
initialize();
}
@ -256,8 +261,7 @@ void MPSEvent::waitForEvent(bool syncEvent) {
});
}
void MPSEvent::notifyEvent(MTLSharedEventNotificationBlock block)
{
void MPSEvent::notifyEvent(MTLSharedEventNotificationBlock block) {
if (!is_initialized)
initialize();
dispatch_sync(_stream->queue(), ^() {

View File

@ -1,7 +1,7 @@
// Copyright © 2022 Apple Inc.
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/native/mps/OperationUtils.h>
namespace at::native::mps {
@ -28,27 +28,31 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
case ScalarType::Bool:
return MPSDataTypeBool;
case ScalarType::Double:
TORCH_CHECK_TYPE(false, "Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. "
TORCH_CHECK_TYPE(false,
"Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. "
"Please use float32 instead.")
default:
TORCH_CHECK_TYPE(false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
TORCH_CHECK_TYPE(
false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
}
}
// #issue 104398441 sortWithTensor and argsortWithTensor has support of
// Int32, Half and Float32 types. These utilities are to help cast to these
// types.
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64) {
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const Tensor& input,
bool includesInt64) {
MPSDataType dataType = getMPSDataType(input.scalar_type());
bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
bool condition =
(dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
if (includesInt64) {
condition = condition && (dataType != MPSDataTypeInt64);
}
if (condition) {
dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
return [mpsGraph castTensor:inputTensor
toType:dataType
name:@"castInputTensor"];
return [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
}
return inputTensor;
}
@ -56,16 +60,18 @@ MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor,
// #issue 104398441 sortWithTensor and argsortWithTensor has support of
// Int32, Half and Float32 types. These utilities are to help cast from these
// types.
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64) {
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const Tensor& input,
bool includesInt64) {
MPSDataType dataType = getMPSDataType(input.scalar_type());
bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
bool condition =
(dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
if (includesInt64) {
condition = condition && (dataType != MPSDataTypeInt64);
}
if (condition) {
inputTensor = [mpsGraph castTensor:inputTensor
toType:dataType
name:@"castInputTensor"];
inputTensor = [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
}
return inputTensor;
}
@ -92,7 +98,8 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
case ScalarType::Bool:
return MPSDataTypeBool;
default:
TORCH_CHECK_TYPE(false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
TORCH_CHECK_TYPE(
false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
}
}
@ -148,7 +155,7 @@ std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type) {
NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
int64_t ndim = t.dim();
auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
for (const auto i: c10::irange(ndim)) {
for (const auto i : c10::irange(ndim)) {
axes[i] = [NSNumber numberWithInteger:i];
}
return axes;
@ -159,7 +166,7 @@ NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim)
IntArrayRef dimValues = dim.value();
int ndim = dimValues.size();
auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
for (const auto i: c10::irange(ndim)) {
for (const auto i : c10::irange(ndim)) {
axes[i] = [NSNumber numberWithInteger:dimValues[i]];
}
@ -171,7 +178,7 @@ NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim)
std::string getMPSShapeString(MPSShape* shape) {
std::string str;
for(NSNumber *elem in shape) {
for (NSNumber* elem in shape) {
str += std::to_string(elem.unsignedLongValue) + ",";
}
return str;
@ -186,7 +193,7 @@ std::string getArrayRefString(const IntArrayRef s) {
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype) {
std::string str;
// The key format per tensor would look like ":Float32[1,1,1,10]:"
for (const Tensor& tensor: tensors) {
for (const Tensor& tensor : tensors) {
str += ":";
if (tensor.defined()) {
str += getMPSTypeString(tensor.scalar_type(), short_dtype) + "[";
@ -216,7 +223,7 @@ MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format) {
const NSUInteger C = sizes[1];
const NSUInteger H = sizes[2];
const NSUInteger W = sizes[3];
return @[@(N), @(H), @(W), @(C)];
return @[ @(N), @(H), @(W), @(C) ];
}
const int sz = sizes.size();
const int sz_ = (sz > 0) ? sz : 1;
@ -232,27 +239,27 @@ MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format) {
}
void printTensorNDArray(const Tensor& t) {
if (!t.is_mps()) return;
if(t.numel() == 0) return;
if (!t.is_mps())
return;
if (t.numel() == 0)
return;
// Get shape and data type
auto selfShape = getMPSShape(t);
auto selfDType = getMPSDataType(t.scalar_type());
// Initialize data
id<MTLBuffer> selfBuf = getMTLBufferStorage(t);
MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf
shape:selfShape
MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf shape:selfShape
dataType:selfDType] autorelease];
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wobjc-method-access")
#if C10_CLANG_HAS_WARNING("-Wobjc-method-access")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wobjc-method-access")
#endif
#endif
[tdata printNDArray];
C10_CLANG_DIAGNOSTIC_POP()
}
MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType mpsType)
{
MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape* shape, MPSDataType mpsType) {
id<MTLBuffer> buffer = getMTLBufferStorage(tensor);
MPSGraphTensorData* tmpGraphTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buffer
shape:shape
@ -261,9 +268,12 @@ MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType
return [tmpGraphTensorData mpsndarray];
}
Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSShape *mpsShape,
bool gatherTensorData, MPSDataType dataType) : _tensor(src)
{
Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
const Tensor& src,
MPSShape* mpsShape,
bool gatherTensorData,
MPSDataType dataType)
: _tensor(src) {
TORCH_CHECK(src.is_mps(), "Placeholder storage has not been allocated on MPS device!");
// extract the pointer to MTLBuffer from the Tensor's storage
id<MTLBuffer> srcBuf = getMTLBufferStorage(src);
@ -285,8 +295,9 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSS
// if buffer size is zero in here, it's not a user error. It could be a missing check for
// tensor.numel() == 0 in our internal implementations of ops.
TORCH_INTERNAL_ASSERT([srcBuf length] > 0, "Placeholder tensor is empty!");
const MPSDataType mpsDataType = dataType != MPSDataTypeInvalid ? dataType :
_tensor.dim() == 0 ? getMPSScalarType(_tensor.scalar_type()) : getMPSDataType(_tensor.scalar_type());
const MPSDataType mpsDataType = dataType != MPSDataTypeInvalid ? dataType
: _tensor.dim() == 0 ? getMPSScalarType(_tensor.scalar_type())
: getMPSDataType(_tensor.scalar_type());
if (src.is_contiguous() && src.storage_offset() && sliceViewTensor) {
_value = getMPSGraphTensorDataForView(src, mpsShape, mpsDataType);
@ -295,34 +306,25 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSS
mpsShape = getMPSShape(_tensor);
}
_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
shape:mpsShape
dataType:mpsDataType] autorelease];
_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf shape:mpsShape dataType:mpsDataType] autorelease];
}
TORCH_INTERNAL_ASSERT(_value);
_placeholder = mpsGraphTensor;
}
MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph,
MPSStream* mpsStream,
const Tensor& tensor) {
MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor) {
auto mpsShape = getMPSShape(tensor);
auto dataType = getMPSDataType(tensor.scalar_type());
MPSGraphTensorData *result = nil;
MPSGraphTensorData* result = nil;
if (tensor.numel() > 0) {
id<MTLBuffer> buf = getMTLBufferStorage(tensor);
result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buf
shape:mpsShape
dataType:dataType]
autorelease];
result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buf shape:mpsShape dataType:dataType] autorelease];
} else {
// create empty NDArray
MPSNDArrayDescriptor *desc = [MPSNDArrayDescriptor descriptorWithDataType:dataType
shape:mpsShape];
MPSNDArray *emptyArray = [[[MPSNDArray alloc]
initWithDevice:mpsStream->device() descriptor:desc] autorelease];
MPSNDArrayDescriptor* desc = [MPSNDArrayDescriptor descriptorWithDataType:dataType shape:mpsShape];
MPSNDArray* emptyArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device() descriptor:desc] autorelease];
result = [[[MPSGraphTensorData alloc] initWithMPSNDArray:emptyArray] autorelease];
}
assert(result);
@ -332,30 +334,40 @@ MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph,
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type) {
switch (type) {
case ScalarType::Double:
case ScalarType::Float: return {.value.f = scalar.to<float>() , .size = sizeof(float) , .type = type};
case ScalarType::Half: return {.value.h = scalar.to<at::Half>(), .size = sizeof(short) , .type = type};
case ScalarType::Long: return {.value.i = scalar.to<int64_t>() , .size = sizeof(int64_t), .type = type};
case ScalarType::Int: return {.value.i = scalar.to<int32_t>() , .size = sizeof(int32_t), .type = type};
case ScalarType::Short: return {.value.i = scalar.to<int16_t>() , .size = sizeof(int16_t), .type = type};
case ScalarType::Char: return {.value.i = scalar.to<int8_t>() , .size = sizeof(int8_t) , .type = type};
case ScalarType::Byte: return {.value.i = scalar.to<uint8_t>() , .size = sizeof(uint8_t), .type = type};
case ScalarType::Bool: return {.value.b = scalar.to<bool>() , .size = sizeof(bool) , .type = type};
case ScalarType::Float:
return {.value.f = scalar.to<float>(), .size = sizeof(float), .type = type};
case ScalarType::Half:
return {.value.h = scalar.to<at::Half>(), .size = sizeof(short), .type = type};
case ScalarType::Long:
return {.value.i = scalar.to<int64_t>(), .size = sizeof(int64_t), .type = type};
case ScalarType::Int:
return {.value.i = scalar.to<int32_t>(), .size = sizeof(int32_t), .type = type};
case ScalarType::Short:
return {.value.i = scalar.to<int16_t>(), .size = sizeof(int16_t), .type = type};
case ScalarType::Char:
return {.value.i = scalar.to<int8_t>(), .size = sizeof(int8_t), .type = type};
case ScalarType::Byte:
return {.value.i = scalar.to<uint8_t>(), .size = sizeof(uint8_t), .type = type};
case ScalarType::Bool:
return {.value.b = scalar.to<bool>(), .size = sizeof(bool), .type = type};
default:
TORCH_INTERNAL_ASSERT(false, "Unsupported scalar type '", type, "' on MPS backend.");
}
}
MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar) {
MPSGraphTensorData *result = nullptr;
MPSGraphTensorData* result = nullptr;
// Scalar pools are only supported on devices with unified memory
if (mpsStream->device().hasUnifiedMemory) {
scalar.buffer = getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size);
result = [[[MPSGraphTensorData alloc] initWithMTLBuffer: scalar.getMTLBuffer()
shape: @[@1]
dataType: getMPSScalarType(scalar.type)] autorelease];
result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:scalar.getMTLBuffer()
shape:@[ @1 ]
dataType:getMPSScalarType(scalar.type)] autorelease];
} else {
MPSNDArrayDescriptor *tensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:getMPSScalarType(scalar.type) shape:@[@1]];
MPSNDArray *tensorNDArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device() descriptor:tensorDesc] autorelease];
MPSNDArrayDescriptor* tensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:getMPSScalarType(scalar.type)
shape:@[ @1 ]];
MPSNDArray* tensorNDArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device()
descriptor:tensorDesc] autorelease];
[tensorNDArray writeBytes:&scalar.value strideBytes:nil];
result = [[[MPSGraphTensorData alloc] initWithMPSNDArray:tensorNDArray] autorelease];
}
@ -371,58 +383,50 @@ MPSGraph* make_mps_graph() {
return mpsGraph;
}
MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType) {
return [mpsGraph placeholderWithShape:nil
dataType:dataType
name:nil];
MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType) {
return [mpsGraph placeholderWithShape:nil dataType:dataType name:nil];
}
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape) {
return [mpsGraph placeholderWithShape:mpsShape
dataType:dataType
name:nil];
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape) {
return [mpsGraph placeholderWithShape:mpsShape dataType:dataType name:nil];
}
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor) {
return [mpsGraph placeholderWithShape:getMPSShape(tensor)
dataType:getMPSScalarType(tensor.scalar_type())
name:nil];
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const Tensor& tensor) {
return [mpsGraph placeholderWithShape:getMPSShape(tensor) dataType:getMPSScalarType(tensor.scalar_type()) name:nil];
}
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType) {
return [mpsGraph placeholderWithShape:@[@1]
dataType:dataType
name:nil];
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType) {
return [mpsGraph placeholderWithShape:@[ @1 ] dataType:dataType name:nil];
}
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar) {
return [mpsGraph placeholderWithShape:@[@1]
dataType:getMPSScalarType(scalar.type())
name:nil];
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar) {
return [mpsGraph placeholderWithShape:@[ @1 ] dataType:getMPSScalarType(scalar.type()) name:nil];
}
// this is meant to suppress the availability warning on castTensor
// we pass ScalarType instead of MPSDataType to handle MPSDataTypeBoolean's availability too
MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) {
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) {
if ([tensor dataType] == toType) {
return tensor;
}
return [mpsGraph castTensor:tensor toType:toType name:@"castTensor"];
}
MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType) {
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, ScalarType toType) {
return [mpsGraph castTensor:tensor toType:getMPSScalarType(toType) name:@"castTensor"];
}
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor) {
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor) {
TORCH_INTERNAL_ASSERT(tensor.shape.count == 4, "Tensor must have 4 dimensions!");
return [mpsGraph transposeTensor:[mpsGraph transposeTensor:tensor dimension:3 withDimension:2 name:nil]
dimension:2 withDimension:1 name: nil];
dimension:2
withDimension:1
name:nil];
}
string get_mem_format_string(c10::MemoryFormat memory_format) {
string mem_format_key;
switch(memory_format) {
switch (memory_format) {
case at::MemoryFormat::Contiguous:
mem_format_key = "Contiguous";
break;
@ -439,11 +443,12 @@ string get_mem_format_string(c10::MemoryFormat memory_format) {
MPSGraphCache* MPSGraphCache::_instance_cache = nullptr;
class MPSGraphCacheCallback : public IMpsAllocatorCallback {
public:
MPSGraphCacheCallback() : graph_cache(MPSGraphCache::getInstance()) { }
public:
MPSGraphCacheCallback() : graph_cache(MPSGraphCache::getInstance()) {}
void executeMPSAllocatorCallback(void* ptr, EventType event) override { }
private:
void executeMPSAllocatorCallback(void* ptr, EventType event) override {}
private:
MPSGraphCache* graph_cache;
};

File diff suppressed because it is too large

View File

@ -1,52 +1,54 @@
// Copyright © 2022 Apple Inc.
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/native/mps/OperationUtils.h>
namespace at::native {
void set_kernel_params
(int64_t isizeH, int64_t isizeW,
int64_t osizeH, int64_t osizeW,
int64_t &strideH, int64_t &strideW,
int64_t &kernel_sizeH, int64_t &kernel_sizeW,
void set_kernel_params(int64_t isizeH,
int64_t isizeW,
int64_t osizeH,
int64_t osizeW,
int64_t& strideH,
int64_t& strideW,
int64_t& kernel_sizeH,
int64_t& kernel_sizeW,
bool check_avg_pooling = false) {
TORCH_CHECK((isizeH >= osizeH && isizeW >= osizeW) || (isizeH <= osizeH && isizeW <= osizeW),
"Adaptive pool MPS: Input height and width must both be greater than, "
"or equal to, or lesser than output height and width")
if(isizeH >= osizeH) {
if (isizeH >= osizeH) {
if (check_avg_pooling) {
TORCH_CHECK((isizeH % osizeH == 0 && isizeW % osizeW == 0),
"Adaptive pool MPS: input sizes must be divisible by output sizes.");
}
strideH = (int64_t) (isizeH / osizeH);
strideW = (int64_t) (isizeW / osizeW);
kernel_sizeH = isizeH - (osizeH-1) * strideH;
kernel_sizeW = isizeW - (osizeW-1) * strideW;
strideH = (int64_t)(isizeH / osizeH);
strideW = (int64_t)(isizeW / osizeW);
kernel_sizeH = isizeH - (osizeH - 1) * strideH;
kernel_sizeW = isizeW - (osizeW - 1) * strideW;
} else {
if (check_avg_pooling) {
TORCH_CHECK((osizeH % isizeH == 0 && osizeW % isizeW == 0),
"Adaptive pool MPS: output sizes must be divisible by input sizes.");
}
strideH = (int64_t) (osizeH / isizeH);
strideW = (int64_t) (osizeW / isizeW);
kernel_sizeH = osizeH - (isizeH-1) * strideH;
kernel_sizeW = osizeW - (isizeW-1) * strideW;
strideH = (int64_t)(osizeH / isizeH);
strideW = (int64_t)(osizeW / isizeW);
kernel_sizeH = osizeH - (isizeH - 1) * strideH;
kernel_sizeW = osizeW - (isizeW - 1) * strideW;
}
}
// Adaptive average pooling
Tensor& adaptive_avg_pool2d_out_mps
(const Tensor& input,
IntArrayRef output_size,
Tensor& output) {
Tensor& adaptive_avg_pool2d_out_mps(const Tensor& input, IntArrayRef output_size, Tensor& output) {
for (int64_t i = 1; i < input.ndimension(); i++) {
TORCH_CHECK(input.size(i) > 0,
"adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
"but input has sizes ", input.sizes(), " with dimension ", i, " being empty");
"but input has sizes ",
input.sizes(),
" with dimension ",
i,
" being empty");
}
int64_t isizeH = input.size(-2);
@ -57,12 +59,9 @@ Tensor& adaptive_avg_pool2d_out_mps
int64_t strideH = 0, strideW = 0;
int64_t kernel_sizeH = 0, kernel_sizeW = 0;
set_kernel_params(isizeH, isizeW,
osizeH, osizeW,
strideH, strideW,
kernel_sizeH, kernel_sizeW, true);
set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW, true);
if(isizeH >= osizeH) {
if (isizeH >= osizeH) {
output = at::avg_pool2d(input,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
@ -73,7 +72,7 @@ Tensor& adaptive_avg_pool2d_out_mps
} else {
Tensor phony_grad = at::ones_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto input_sizes = input.sizes();
std::vector<int64_t> phony_shape{input_sizes.begin(), input_sizes.end() -2};
std::vector<int64_t> phony_shape{input_sizes.begin(), input_sizes.end() - 2};
phony_shape.push_back(output_size[0]);
phony_shape.push_back(output_size[1]);
phony_grad.resize_(IntArrayRef(phony_shape));
@ -86,16 +85,13 @@ Tensor& adaptive_avg_pool2d_out_mps
true,
c10::nullopt);
// Multiply output by kernel size
output = at::mul(output, kernel_sizeH*kernel_sizeW);
output = at::mul(output, kernel_sizeH * kernel_sizeW);
}
return output;
}
Tensor adaptive_avg_pool2d_mps
(at::Tensor const& input,
IntArrayRef output_size) {
Tensor adaptive_avg_pool2d_mps(at::Tensor const& input, IntArrayRef output_size) {
IntArrayRef output_shape;
auto osizeH = output_size[0];
@ -103,7 +99,7 @@ Tensor adaptive_avg_pool2d_mps
std::vector<long long> out_dims = {};
if(input.ndimension() == 4) {
if (input.ndimension() == 4) {
auto sizeB = input.size(0);
auto sizeD = input.size(1);
@ -112,8 +108,7 @@ Tensor adaptive_avg_pool2d_mps
out_dims.push_back(osizeH);
out_dims.push_back(osizeW);
output_shape = IntArrayRef(out_dims);
}
else {
} else {
auto sizeD = input.size(0);
out_dims.push_back(sizeD);
out_dims.push_back(osizeH);
@ -122,21 +117,12 @@ Tensor adaptive_avg_pool2d_mps
}
const auto memory_format = input.suggest_memory_format();
Tensor output = at::native::empty_mps(
output_shape,
input.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
memory_format);
Tensor output =
at::native::empty_mps(output_shape, input.scalar_type(), c10::nullopt, kMPS, c10::nullopt, memory_format);
return adaptive_avg_pool2d_out_mps(input, output_size, output);
}
Tensor adaptive_avg_pool2d_backward_mps
(const Tensor& gradOutput,
const Tensor& input) {
Tensor adaptive_avg_pool2d_backward_mps(const Tensor& gradOutput, const Tensor& input) {
int64_t isizeH = input.size(-2);
int64_t isizeW = input.size(-1);
int64_t osizeH = gradOutput.size(-2);
@ -145,14 +131,11 @@ Tensor adaptive_avg_pool2d_backward_mps
int64_t strideH = 0, strideW = 0;
int64_t kernel_sizeH = 0, kernel_sizeW = 0;
set_kernel_params(isizeH, isizeW,
osizeH, osizeW,
strideH, strideW,
kernel_sizeH, kernel_sizeW, true);
set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW, true);
auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
if (gradInput.numel() != 0) {
if(isizeH >= osizeH) {
if (isizeH >= osizeH) {
gradInput = at::avg_pool2d_backward(gradOutput,
input,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
@ -169,7 +152,7 @@ Tensor adaptive_avg_pool2d_backward_mps
false,
true,
c10::nullopt);
gradInput = at::mul(gradInput, kernel_sizeH*kernel_sizeW);
gradInput = at::mul(gradInput, kernel_sizeH * kernel_sizeW);
}
}
@ -178,15 +161,15 @@ Tensor adaptive_avg_pool2d_backward_mps
// Adaptive max pooling
TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
(const Tensor& input,
IntArrayRef output_size,
const Tensor& output,
const Tensor& indices) {
(const Tensor& input, IntArrayRef output_size, const Tensor& output, const Tensor& indices) {
for (int64_t i = 1; i < input.ndimension(); i++) {
TORCH_CHECK(input.size(i) > 0,
"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
"but input has sizes ", input.sizes(), " with dimension ", i, " being "
"but input has sizes ",
input.sizes(),
" with dimension ",
i,
" being "
"empty");
}
@ -198,13 +181,11 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
int64_t strideH = 0, strideW = 0;
int64_t kernel_sizeH = 0, kernel_sizeW = 0;
set_kernel_params(isizeH, isizeW,
osizeH, osizeW,
strideH, strideW,
kernel_sizeH, kernel_sizeW);
set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW);
at::max_pool2d_with_indices_out(const_cast<Tensor&>(output),
const_cast<Tensor&>(indices), input,
const_cast<Tensor&>(indices),
input,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
IntArrayRef({0, 0}),
@ -213,11 +194,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
}
TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_mps)
(const Tensor& gradOutput,
const Tensor& input,
const Tensor& indices,
const Tensor& gradInput) {
(const Tensor& gradOutput, const Tensor& input, const Tensor& indices, const Tensor& gradInput) {
int64_t isizeH = input.size(-2);
int64_t isizeW = input.size(-1);
int64_t osizeH = gradOutput.size(-2);
@ -226,13 +203,11 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_mps)
int64_t strideH = 0, strideW = 0;
int64_t kernel_sizeH = 0, kernel_sizeW = 0;
set_kernel_params(isizeH, isizeW,
osizeH, osizeW,
strideH, strideW,
kernel_sizeH, kernel_sizeW);
set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW);
at::max_pool2d_with_indices_backward_out(const_cast<Tensor&>(gradInput),
gradOutput, input,
gradOutput,
input,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
IntArrayRef({0, 0}),

View File

@ -1,5 +1,5 @@
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/mps/OperationUtils.h>
namespace at::native {
namespace mps {
@ -124,10 +124,10 @@ static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device) {
return binaryLibrary;
}
NSError *error = nil;
MTLCompileOptions *options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion: MTLLanguageVersion2_3];
binaryLibrary = [device newLibraryWithSource:[NSString stringWithCString: METAL_BINARY encoding:NSASCIIStringEncoding]
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
binaryLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_BINARY encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(binaryLibrary, "Failed to create metal binary library, error: ", [[error description] UTF8String]);
@ -167,7 +167,7 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
const uint32_t nDim = iter.ndim();
constexpr uint32_t nOffsets = 3;
const uint32_t numThreads = iter.numel();
dispatch_sync(mpsStream->queue(), ^(){
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
NSError* error = nil;
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
@ -177,23 +177,25 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
std::vector<uint32_t> iterShapeData(iterShape.size());
std::vector<std::array<uint32_t, nOffsets>> strides(nDim);
for (const auto i: c10::irange(iterShape.size())) {
for (const auto i : c10::irange(iterShape.size())) {
TORCH_CHECK(i <= UINT32_MAX);
iterShapeData[i] = (uint32_t)(iterShape[i]);
}
for (const auto i: c10::irange(nDim)) {
for (const auto offset: c10::irange(nOffsets)) {
for (const auto i : c10::irange(nDim)) {
for (const auto offset : c10::irange(nOffsets)) {
strides[i][offset] = iter.strides(offset)[i];
}
}
id<MTLFunction> kernelDataOffsetsFunction = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO = [[device newComputePipelineStateWithFunction: kernelDataOffsetsFunction
error: &error] autorelease];
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength: numThreads * sizeof(simd_uint3)
options: 0] autorelease];
TORCH_CHECK(kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
id<MTLFunction> kernelDataOffsetsFunction =
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO =
[[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
options:0] autorelease];
TORCH_CHECK(
kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
[computeEncoder setComputePipelineState:kernelDataOffsetsPSO];
[computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0];
[computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1];
@ -206,8 +208,7 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
kernelOffsetsTGSize = numThreads;
MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: kernelOffsetsThreadGroupSize];
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:kernelOffsetsThreadGroupSize];
const std::string kernel = func_name + "_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> binaryPSO = binaryPipelineState(device, kernel);
@ -223,8 +224,7 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: threadGroupSize];
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];
[computeEncoder endEncoding];
mpsStream->commit(true);

View File

@ -4,34 +4,40 @@
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>
#include <c10/util/Optional.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <c10/util/Optional.h>
#include <torch/library.h>
namespace at::native {
namespace mps {
struct BinaryOpCachedGraph : public MPSCachedGraph
{
BinaryOpCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct BinaryOpCachedGraph : public MPSCachedGraph {
BinaryOpCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *primaryTensor = nil, *secondaryTensor = nil;
MPSGraphTensor *alphaTensor = nil, *outputTensor = nil;
};
typedef MPSGraphTensor* (^BinaryOpBlock)(BinaryOpCachedGraph*, MPSGraphTensor*, MPSGraphTensor*);
#define BinaryOpFn(graph, primary, secondary) MPSGraphTensor* (mps::BinaryOpCachedGraph* graph, MPSGraphTensor* primary, MPSGraphTensor* secondary)
#define BinaryOpFn(graph, primary, secondary) \
MPSGraphTensor*(mps::BinaryOpCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary)
// alpha is always 1.0 except when this function is called from add_sub_template()
void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha,
const Tensor& output_, std::string op_name, BinaryOpBlock binaryBlock)
{
TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte ),
void binaryOpTensor(const Tensor& self,
const Tensor& other,
const Scalar& alpha,
const Tensor& output_,
std::string op_name,
BinaryOpBlock binaryBlock) {
TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte),
"MPS support binary op with uint8 natively starting from macOS 13.0");
TORCH_CHECK(!(op_name == "power" && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS) &&
(self.scalar_type() == ScalarType::Long ||
(other.scalar_type() == ScalarType::Long && (self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))),
"MPS: ", op_name, " op with int64 input is supported natively starting from macOS 13.2");
(other.scalar_type() == ScalarType::Long &&
(self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))),
"MPS: ",
op_name,
" op with int64 input is supported natively starting from macOS 13.2");
MPSStream* mpsStream = getCurrentMPSStream();
const bool is_self_scalar = self.dim() == 0;
@ -79,16 +85,18 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = op_name + getTensorsStringKey({self, other, output_});
BinaryOpCachedGraph* cachedGraph = static_cast<BinaryOpCachedGraph *>(cache_->LookUp(key));
BinaryOpCachedGraph* cachedGraph = static_cast<BinaryOpCachedGraph*>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph* () {
BinaryOpCachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
BinaryOpCachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new BinaryOpCachedGraph(mpsGraph);
newCachedGraph->primaryTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(inputDataType), getMPSShape(self));
newCachedGraph->secondaryTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(otherDataType), getMPSShape(other));
newCachedGraph->primaryTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(inputDataType), getMPSShape(self));
newCachedGraph->secondaryTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(otherDataType), getMPSShape(other));
MPSGraphTensor* primaryCastTensor = newCachedGraph->primaryTensor;
MPSGraphTensor* secondaryCastTensor = newCachedGraph->secondaryTensor;
@ -101,8 +109,7 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
common_dtype = outputDataType;
// in boolean comparison ops with signed vs. unsigned integers, we always cast to the unsigned type
} else if (outputDataType == ScalarType::Bool &&
(inputDataType == ScalarType::Byte ||
otherDataType == ScalarType::Byte)) {
(inputDataType == ScalarType::Byte || otherDataType == ScalarType::Byte)) {
common_dtype = ScalarType::Byte;
}
}
@ -113,8 +120,8 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
}
newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor);
// Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to int32 tensor
// Output tensor should have been promoted but it remains an int32 tensor
// Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to
// int32 tensor Output tensor should have been promoted but it remains an int32 tensor
if (outputDataType != common_dtype ||
[newCachedGraph->outputTensor dataType] != getMPSDataType(outputDataType)) {
newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, outputDataType);
@ -122,10 +129,10 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
}
return newCachedGraph;
});
cachedGraph = static_cast<BinaryOpCachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<BinaryOpCachedGraph*>(tmpCachedGraph);
}
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
Placeholder selfPlaceholder;
Placeholder otherPlaceholder;
MPSScalar self_scalar;
@ -136,16 +143,22 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
self_scalar = getMPSScalar(self.item(), inputDataType);
feeds[cachedGraph->primaryTensor] = getMPSGraphTensorFromScalar(mpsStream, self_scalar);
} else {
selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self, /*mpsShape*/nil,
/*gatherTensorData=*/true, getMPSScalarType(inputDataType));
selfPlaceholder = Placeholder(cachedGraph->primaryTensor,
self,
/*mpsShape*/ nil,
/*gatherTensorData=*/true,
getMPSScalarType(inputDataType));
feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData();
}
if (is_other_scalar && !other.is_mps()) {
other_scalar = getMPSScalar(other.item(), otherDataType);
feeds[cachedGraph->secondaryTensor] = getMPSGraphTensorFromScalar(mpsStream, other_scalar);
} else {
otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other, /*mpsShape*/nil,
/*gatherTensorData=*/true, getMPSScalarType(otherDataType));
otherPlaceholder = Placeholder(cachedGraph->secondaryTensor,
other,
/*mpsShape*/ nil,
/*gatherTensorData=*/true,
getMPSScalarType(otherDataType));
feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
}
@ -156,9 +169,8 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
}
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, needsCopyToOutput ? output : output_);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);
if (needsCopyToOutput) {
@ -167,27 +179,28 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
}
}
void binaryOpScalar(const Tensor& self, const Scalar& other, const Scalar& alpha,
const Tensor& output, std::string op_name, BinaryOpBlock binaryBlock)
{
void binaryOpScalar(const Tensor& self,
const Scalar& other,
const Scalar& alpha,
const Tensor& output,
std::string op_name,
BinaryOpBlock binaryBlock) {
binaryOpTensor(self, wrapped_scalar_tensor(other), alpha, output, op_name, binaryBlock);
}
void div_mode_template(const Tensor& self, const Tensor& other,
void div_mode_template(const Tensor& self,
const Tensor& other,
c10::optional<c10::string_view> rounding_mode,
const Tensor& output, const string op_name)
{
if(rounding_mode.has_value() && *rounding_mode == "trunc"){
TORCH_CHECK(self.scalar_type() != ScalarType::Half,
"MPS: does not support trunc_divide op with float16 input");
const Tensor& output,
const string op_name) {
if (rounding_mode.has_value() && *rounding_mode == "trunc") {
TORCH_CHECK(self.scalar_type() != ScalarType::Half, "MPS: does not support trunc_divide op with float16 input");
}
BinaryOpBlock div_mode_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0;
if(!isFloatInput && rounding_mode.has_value() && (*rounding_mode == "floor" || *rounding_mode == "trunc")) {
primaryCastTensor = [mpsGraph castTensor:primaryCastTensor
toType:MPSDataTypeFloat32
name:@"primaryCastTensor"];
if (!isFloatInput && rounding_mode.has_value() && (*rounding_mode == "floor" || *rounding_mode == "trunc")) {
primaryCastTensor = [mpsGraph castTensor:primaryCastTensor toType:MPSDataTypeFloat32 name:@"primaryCastTensor"];
secondaryCastTensor = [mpsGraph castTensor:secondaryCastTensor
toType:MPSDataTypeFloat32
name:@"secondaryCastTensor"];
@ -207,9 +220,7 @@ void div_mode_template(const Tensor& self, const Tensor& other,
auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:truncTensor
secondaryTensor:secondaryCastTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
secondaryTensor:mulTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:mulTensor name:nil];
}
return truncTensor;
} else if (*rounding_mode == "floor") {
@ -218,20 +229,26 @@ void div_mode_template(const Tensor& self, const Tensor& other,
auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor
secondaryTensor:secondaryCastTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
secondaryTensor:mulTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:mulTensor name:nil];
}
return floorTensor;
}
assert(0 && "Invalid rounding mode\n");
return nullptr;
};
binaryOpTensor(self, other, Scalar(1.0), output, op_name + "_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""), div_mode_op_block);
binaryOpTensor(self,
other,
Scalar(1.0),
output,
op_name + "_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""),
div_mode_op_block);
}
void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output, std::string op_name)
{
void add_sub_template(const Tensor& self,
const Tensor& other,
const Scalar& alpha,
const Tensor& output,
std::string op_name) {
if (alpha.toDouble() == 0.0) {
if (!self.is_alias_of(output)) { // if inplace, no-op
const_cast<Tensor&>(output) = self.clone();
@ -251,60 +268,79 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp
// if alpha is 1.0, then we don't bother adding another multiply to graph
if (alpha_has_value) {
cachedGraph->alphaTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(other.scalar_type()), @[@1]);
cachedGraph->alphaTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(other.scalar_type()), @[ @1 ]);
secondaryTensor = [mpsGraph multiplicationWithPrimaryTensor:secondaryCastTensor
secondaryTensor:cachedGraph->alphaTensor
name:nil];
}
if (op_name == "add")
return [mpsGraph additionWithPrimaryTensor:primaryCastTensor
secondaryTensor:secondaryTensor
name:nil];
return [mpsGraph additionWithPrimaryTensor:primaryCastTensor secondaryTensor:secondaryTensor name:nil];
else
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
secondaryTensor:secondaryTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:secondaryTensor name:nil];
};
// add alpha's type to the key only if multiply was added to graph
binaryOpTensor(self, other, alpha, output, op_name + "_out_mps:" + (alpha_has_value ? getMPSTypeString(alpha.type()) : ""), add_sub_op_block);
binaryOpTensor(self,
other,
alpha,
output,
op_name + "_out_mps:" + (alpha_has_value ? getMPSTypeString(alpha.type()) : ""),
add_sub_op_block);
}
} // namespace mps
#define CREATE_MPS_BINARY_COMPARISON_OP_FUNC(func_out, func_stub, other_type) \
Tensor& func_out (const Tensor& self, const other_type& other, Tensor& output) { \
mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \
Tensor& func_out(const Tensor& self, const other_type& other, Tensor& output) { \
mps::binaryOp##other_type( \
self, \
other, \
Scalar(1.0), \
output, \
#func_stub, \
^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \
MPSGraph* mpsGraph = cachedGraph->graph(); \
return [mpsGraph func_stub##WithPrimaryTensor:mps::castMPSTensor(mpsGraph, primaryCastTensor, ScalarType::Bool) \
return [mpsGraph func_stub## \
WithPrimaryTensor:mps::castMPSTensor(mpsGraph, primaryCastTensor, ScalarType::Bool) \
secondaryTensor:mps::castMPSTensor(mpsGraph, secondaryCastTensor, ScalarType::Bool) \
name:nil]; }); \
name:nil]; \
}); \
return output; \
}
}
#define CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(func_out, func_stub, other_type) \
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const other_type& other, const Tensor& output) { \
TORCH_CHECK(!(self.scalar_type() == ScalarType::Long && \
std::string(#func_stub) == "atan2"), \
"MPS does not support ", #func_stub, " op with int64 input") \
mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \
TORCH_IMPL_FUNC(func_out)(const Tensor& self, const other_type& other, const Tensor& output) { \
TORCH_CHECK(!(self.scalar_type() == ScalarType::Long && std::string(#func_stub) == "atan2"), \
"MPS does not support ", \
#func_stub, \
" op with int64 input") \
mps::binaryOp##other_type(self, \
other, \
Scalar(1.0), \
output, \
#func_stub, \
^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \
MPSGraph* mpsGraph = cachedGraph->graph(); \
return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \
secondaryTensor:secondaryCastTensor \
name:nil]; }); \
}
name:nil]; \
}); \
}
// output of Boolean Ops will be cast to "MPSDataTypeBool" at the end of binaryOpTensor()
#define CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(func_out, func_stub, other_type) \
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const other_type& other, const Tensor& output) { \
mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \
TORCH_IMPL_FUNC(func_out)(const Tensor& self, const other_type& other, const Tensor& output) { \
mps::binaryOp##other_type(self, \
other, \
Scalar(1.0), \
output, \
#func_stub, \
^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \
MPSGraph* mpsGraph = cachedGraph->graph(); \
return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \
secondaryTensor:secondaryCastTensor \
name:nil]; }); \
}
name:nil]; \
}); \
}
// Boolean Binary Ops
CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(eq_scalar_out_mps, equal, Scalar);
@ -332,24 +368,24 @@ CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_and_out_mps, logicalAND, Tensor);
CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_or_out_mps, logicalOR, Tensor);
CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_xor_out_mps, logicalXOR, Tensor);
TORCH_IMPL_FUNC(div_out_mode_mps) (const Tensor& self, const Tensor& other, c10::optional<c10::string_view> rounding_mode, const Tensor& output) {
TORCH_IMPL_FUNC(div_out_mode_mps)
(const Tensor& self, const Tensor& other, c10::optional<c10::string_view> rounding_mode, const Tensor& output) {
mps::div_mode_template(self, other, rounding_mode, output, "div_mode_out");
}
TORCH_IMPL_FUNC(div_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) {
TORCH_IMPL_FUNC(div_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::div_mode_template(self, other, c10::nullopt, output, "div_out");
}
TORCH_IMPL_FUNC(add_out_mps) (const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
TORCH_IMPL_FUNC(add_out_mps)(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
mps::add_sub_template(self, other, alpha, output, "add");
}
TORCH_IMPL_FUNC(sub_out_mps) (const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
TORCH_IMPL_FUNC(sub_out_mps)(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
mps::add_sub_template(self, other, alpha, output, "sub");
}
TORCH_IMPL_FUNC(pow_Scalar_out_mps) (const Scalar& base, const Tensor& exp, const Tensor& out) {
TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const Tensor& out) {
if (base.equal(1.0)) {
out.fill_(1);
} else {
@ -386,21 +422,18 @@ Tensor& floor_divide_mps_(Tensor& self, const Tensor& other) {
return floor_divide_out_mps(self, other, self);
}
TORCH_IMPL_FUNC(remainder_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) {
TORCH_IMPL_FUNC(remainder_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::div_mode_template(self, other, "floor", output, "remainder_out_mps");
}
TORCH_IMPL_FUNC(fmod_mps_out) (const Tensor& self, const Tensor& other, const Tensor& output) {
TORCH_IMPL_FUNC(fmod_mps_out)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::div_mode_template(self, other, "trunc", output, "fmod_mps_out");
}
TORCH_IMPL_FUNC(hypot_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
{
TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0
shape:@[@1]
dataType:primaryCastTensor.dataType];
MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType];
MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor
secondaryTensor:twoTensor
name:nil]
@ -413,11 +446,11 @@ TORCH_IMPL_FUNC(hypot_out_mps) (const Tensor& self, const Tensor& other, const T
mps::binaryOpTensor(self, other, Scalar(1.0), output, "hypot_out_mps", hypot_op_block);
}
TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
{
TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil]
MPSGraphTensor* sumTensor =
[mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil]
secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil]
name:nil];
return [mpsGraph logarithmWithTensor:sumTensor name:nil];
@ -425,11 +458,11 @@ TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, con
mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp_out_mps", logaddexp_op_block);
}
TORCH_IMPL_FUNC(logaddexp2_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
{
TORCH_IMPL_FUNC(logaddexp2_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil]
MPSGraphTensor* sumTensor =
[mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil]
secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil]
name:nil];
return [mpsGraph logarithmBase2WithTensor:sumTensor name:nil];
@ -437,16 +470,12 @@ TORCH_IMPL_FUNC(logaddexp2_out_mps) (const Tensor& self, const Tensor& other, co
mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp2_out_mps", logaddexp2_op_block);
}
TORCH_IMPL_FUNC(xlogy_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) {
TORCH_IMPL_FUNC(xlogy_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock xlogy_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
shape:@[@1]
dataType:primaryCastTensor.dataType];
MPSGraphTensor* yIsNaNPredicateTensor = [mpsGraph isNaNWithTensor:secondaryCastTensor
name:nil];
MPSGraphTensor* logyTensor = [mpsGraph logarithmWithTensor:secondaryCastTensor
name:nil];
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType];
MPSGraphTensor* yIsNaNPredicateTensor = [mpsGraph isNaNWithTensor:secondaryCastTensor name:nil];
MPSGraphTensor* logyTensor = [mpsGraph logarithmWithTensor:secondaryCastTensor name:nil];
MPSGraphTensor* xlogyTensor = [mpsGraph multiplicationWithPrimaryTensor:primaryCastTensor
secondaryTensor:logyTensor
name:nil];

View File

@ -1,6 +1,6 @@
#include <ATen/ExpandUtils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/Resize.h>
#include <ATen/ExpandUtils.h>
#include <ATen/ops/logical_not_native.h>
#include <fmt/format.h>
#include <torch/library.h>
@ -86,7 +86,6 @@ kernel void bitwise_not(constant uint& length [[buffer(0)]],
}}
)METAL";
const std::string& getMetalType(const c10::ScalarType& t) {
// Mapping from c10::ScalarType to integral type that can be used for bitwise ops
// As bitwise ops sign-agnostic map signed/unsigned char and boolean to the same type
@ -112,7 +111,6 @@ const std::string& getMetalType(const c10::Scalar& s) {
return getMetalType(s.type());
}
static id<MTLLibrary> compileBitwiseOpsLibrary(id<MTLDevice> device,
const std::string& t1,
const std::string& t2,
@ -123,10 +121,11 @@ static id<MTLLibrary> compileBitwiseOpsLibrary(id<MTLDevice> device,
if (it != libMap.end()) {
return it->second;
}
NSError *error = nil;
MTLCompileOptions *options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion: MTLLanguageVersion2_3];
auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(BITWISE_OPS_TEMPLATE, t1, t2, t3).c_str()]
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
auto rc =
[device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(BITWISE_OPS_TEMPLATE, t1, t2, t3).c_str()]
options:options
error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
@ -134,7 +133,6 @@ static id<MTLLibrary> compileBitwiseOpsLibrary(id<MTLDevice> device,
return rc;
}
static id<MTLComputePipelineState> getCPLState(id<MTLDevice> device,
const std::string& t1,
const std::string& t2,
@ -146,38 +144,37 @@ static id<MTLComputePipelineState> getCPLState(id<MTLDevice> device,
if (it != cplMap.end()) {
return it->second;
}
NSError *error = nil;
NSError* error = nil;
auto library = compileBitwiseOpsLibrary(device, t1, t2, t3);
id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
TORCH_CHECK(func != nil, "Can't get function ", fname);
auto rc = [device newComputePipelineStateWithFunction:func error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
TORCH_CHECK(
rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
cplMap[key] = rc;
return rc;
}
void dispatch1DJob(id<MTLComputeCommandEncoder> commandEncoder, id<MTLComputePipelineState> cplState, uint32_t length)
{
void dispatch1DJob(id<MTLComputeCommandEncoder> commandEncoder, id<MTLComputePipelineState> cplState, uint32_t length) {
uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
auto size = MTLSizeMake(length, 1, 1);
auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1);
[commandEncoder dispatchThreads:size
threadsPerThreadgroup:threadGroupSize];
[commandEncoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
}
void handle_tensor_tensor_binary_op(const at::Tensor& self, const at::Tensor& other, at::Tensor& output, const std::string& kernel_name) {
void handle_tensor_tensor_binary_op(const at::Tensor& self,
const at::Tensor& other,
at::Tensor& output,
const std::string& kernel_name) {
using namespace at::mps;
MPSStream* stream = getCurrentMPSStream();
id<MTLComputePipelineState> cplState = getCPLState(MPSDevice::getInstance()->device(),
getMetalType(output),
getMetalType(self),
getMetalType(other),
kernel_name);
id<MTLComputePipelineState> cplState = getCPLState(
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(other), kernel_name);
uint32_t length = output.numel();
if (length == 0) {
return;
}
dispatch_sync(stream->queue(), ^(){
dispatch_sync(stream->queue(), ^() {
id<MTLCommandBuffer> buffer = stream->commandBuffer();
id<MTLComputeCommandEncoder> commandEncoder = [buffer computeCommandEncoder];
@ -188,29 +185,29 @@ void handle_tensor_tensor_binary_op(const at::Tensor& self, const at::Tensor& ot
[commandEncoder pushDebugGroup:[NSString stringWithFormat:@"Dispatch %s kernel", kernel_name.c_str()]];
[commandEncoder setComputePipelineState:cplState];
[commandEncoder setBytes:&length length:sizeof(length) atIndex:0];
[commandEncoder setBuffer:outBuf offset:output.storage_offset()*output.itemsize() atIndex:1];
[commandEncoder setBuffer:selfBuf offset:self.storage_offset()*self.itemsize() atIndex:2];
[commandEncoder setBuffer:otherBuf offset:other.storage_offset()*other.itemsize() atIndex:3];
[commandEncoder setBuffer:outBuf offset:output.storage_offset() * output.itemsize() atIndex:1];
[commandEncoder setBuffer:selfBuf offset:self.storage_offset() * self.itemsize() atIndex:2];
[commandEncoder setBuffer:otherBuf offset:other.storage_offset() * other.itemsize() atIndex:3];
dispatch1DJob(commandEncoder, cplState, length);
[commandEncoder endEncoding];
stream->commit(true);
});
}
void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& other, at::Tensor& output, const std::string& kernel_name) {
void handle_tensor_scalar_binary_op(const at::Tensor& self,
const at::Scalar& other,
at::Tensor& output,
const std::string& kernel_name) {
using namespace at::mps;
MPSStream* stream = getCurrentMPSStream();
id<MTLComputePipelineState> cplState = getCPLState(MPSDevice::getInstance()->device(),
getMetalType(output),
getMetalType(self),
getMetalType(other),
kernel_name);
id<MTLComputePipelineState> cplState = getCPLState(
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(other), kernel_name);
uint64_t sval = other.to<int64_t>();
uint32_t length = output.numel();
if (length == 0) {
return;
}
dispatch_sync(stream->queue(), ^(){
dispatch_sync(stream->queue(), ^() {
id<MTLCommandBuffer> buffer = stream->commandBuffer();
id<MTLComputeCommandEncoder> commandEncoder = [buffer computeCommandEncoder];
@ -220,8 +217,8 @@ void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& ot
[commandEncoder pushDebugGroup:[NSString stringWithFormat:@"Dispatch %s kernel", kernel_name.c_str()]];
[commandEncoder setComputePipelineState:cplState];
[commandEncoder setBytes:&length length:sizeof(length) atIndex:0];
[commandEncoder setBuffer:outBuf offset:output.storage_offset()*output.itemsize() atIndex:1];
[commandEncoder setBuffer:selfBuf offset:self.storage_offset()*self.itemsize() atIndex:2];
[commandEncoder setBuffer:outBuf offset:output.storage_offset() * output.itemsize() atIndex:1];
[commandEncoder setBuffer:selfBuf offset:self.storage_offset() * self.itemsize() atIndex:2];
[commandEncoder setBytes:&sval length:sizeof(sval) atIndex:3];
dispatch1DJob(commandEncoder, cplState, length);
[commandEncoder endEncoding];
@ -229,7 +226,10 @@ void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& ot
});
}
at::Tensor& _bitwise_op_out_mps (const at::Tensor& self, const at::Tensor& other, at::Tensor& output_, const std::string& op_name) {
at::Tensor& _bitwise_op_out_mps(const at::Tensor& self,
const at::Tensor& other,
at::Tensor& output_,
const std::string& op_name) {
using namespace at::mps;
const bool is_self_scalar = self.dim() == 0;
const bool is_other_scalar = other.dim() == 0;
@ -269,19 +269,19 @@ at::Tensor& _bitwise_op_out_mps (const at::Tensor& self, const at::Tensor& other
return output_;
}
at::Tensor& bitwise_and_out_mps (const at::Tensor& self, const at::Tensor& other, at::Tensor& output) {
at::Tensor& bitwise_and_out_mps(const at::Tensor& self, const at::Tensor& other, at::Tensor& output) {
return _bitwise_op_out_mps(self, other, output, "and");
}
at::Tensor& bitwise_or_out_mps (const at::Tensor& self, const at::Tensor& other, at::Tensor& output) {
at::Tensor& bitwise_or_out_mps(const at::Tensor& self, const at::Tensor& other, at::Tensor& output) {
return _bitwise_op_out_mps(self, other, output, "or");
}
at::Tensor& bitwise_xor_out_mps (const at::Tensor& self, const at::Tensor& other, at::Tensor& output) {
at::Tensor& bitwise_xor_out_mps(const at::Tensor& self, const at::Tensor& other, at::Tensor& output) {
return _bitwise_op_out_mps(self, other, output, "xor");
}
at::Tensor& bitwise_not_out_mps (const at::Tensor& self, at::Tensor& output_) {
at::Tensor& bitwise_not_out_mps(const at::Tensor& self, at::Tensor& output_) {
// Handle boolean tensor using logical not
if (self.scalar_type() == c10::ScalarType::Bool) {
return at::native::logical_not_out_mps(self, output_);
@ -310,12 +310,9 @@ at::Tensor& bitwise_not_out_mps (const at::Tensor& self, at::Tensor& output_) {
}
using namespace at::mps;
MPSStream* stream = getCurrentMPSStream();
id<MTLComputePipelineState> cplState = getCPLState(MPSDevice::getInstance()->device(),
getMetalType(output),
getMetalType(self),
getMetalType(self),
"bitwise_not");
dispatch_sync(stream->queue(), ^(){
id<MTLComputePipelineState> cplState = getCPLState(
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(self), "bitwise_not");
dispatch_sync(stream->queue(), ^() {
id<MTLCommandBuffer> buffer = stream->commandBuffer();
id<MTLComputeCommandEncoder> commandEncoder = [buffer computeCommandEncoder];
@ -325,8 +322,8 @@ at::Tensor& bitwise_not_out_mps (const at::Tensor& self, at::Tensor& output_) {
[commandEncoder pushDebugGroup:@"Dispatch bitwise_not kernel"];
[commandEncoder setComputePipelineState:cplState];
[commandEncoder setBytes:&length length:sizeof(length) atIndex:0];
[commandEncoder setBuffer:outBuf offset:output.storage_offset()*output.itemsize() atIndex:1];
[commandEncoder setBuffer:selfBuf offset:self.storage_offset()*self.itemsize() atIndex:2];
[commandEncoder setBuffer:outBuf offset:output.storage_offset() * output.itemsize() atIndex:1];
[commandEncoder setBuffer:selfBuf offset:self.storage_offset() * self.itemsize() atIndex:2];
dispatch1DJob(commandEncoder, cplState, length);
[commandEncoder endEncoding];
stream->commit(true);
@ -337,8 +334,6 @@ at::Tensor& bitwise_not_out_mps (const at::Tensor& self, at::Tensor& output_) {
return output_;
}
TORCH_LIBRARY_IMPL(aten, MPS, m) {
m.impl("bitwise_and.Tensor_out", bitwise_and_out_mps);
m.impl("bitwise_or.Tensor_out", bitwise_or_out_mps);

View File

@ -12,23 +12,16 @@
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#endif
namespace at::native {
Tensor dot_mps(
const Tensor &self,
const Tensor &other)
{
Tensor dot_mps(const Tensor& self, const Tensor& other) {
TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS: dot op doesn't support int64 input")
using namespace mps;
auto output = at::native::empty_mps({}, self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* selfTensor_ = nil;
MPSGraphTensor* otherTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
@ -40,45 +33,38 @@ Tensor dot_mps(
@autoreleasepool {
string key = "dot_mps" + getTensorsStringKey({self, other});
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
@autoreleasepool{
MPSGraph *mpsGraph = mps::make_mps_graph();
@autoreleasepool {
MPSGraph* mpsGraph = mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor *otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
MPSGraphTensor *castSelf = nil;
MPSGraphTensor *castOther = nil;
MPSGraphTensor* castSelf = nil;
MPSGraphTensor* castOther = nil;
if(self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte
|| self.scalar_type() == ScalarType::Char) {
castSelf = [mpsGraph castTensor:selfTensor
toType:MPSDataTypeInt32
name:@"castSelfTensor"];
castOther = [mpsGraph castTensor:otherTensor
toType:MPSDataTypeInt32
name:@"castOtherTensor"];
if (self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte ||
self.scalar_type() == ScalarType::Char) {
castSelf = [mpsGraph castTensor:selfTensor toType:MPSDataTypeInt32 name:@"castSelfTensor"];
castOther = [mpsGraph castTensor:otherTensor toType:MPSDataTypeInt32 name:@"castOtherTensor"];
} else {
castSelf = selfTensor;
castOther = otherTensor;
}
MPSGraphTensor *dot = [mpsGraph multiplicationWithPrimaryTensor: castSelf
secondaryTensor: castOther
name: @"multiplication"];
MPSGraphTensor* dot = [mpsGraph multiplicationWithPrimaryTensor:castSelf
secondaryTensor:castOther
name:@"multiplication"];
MPSGraphTensor *dotProductTensor = [mpsGraph reductionSumWithTensor: dot
axes: nil
name: @"dotProduct"];
MPSGraphTensor* dotProductTensor = [mpsGraph reductionSumWithTensor:dot axes:nil name:@"dotProduct"];
if(self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte
|| self.scalar_type() == ScalarType::Char)
if (self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte ||
self.scalar_type() == ScalarType::Char)
dotProductTensor = [mpsGraph castTensor:dotProductTensor
toType:getMPSDataType(self)
name:@"castDotProductTensor"];
@ -89,7 +75,7 @@ Tensor dot_mps(
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
@ -101,9 +87,8 @@ Tensor dot_mps(
otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData(),
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -111,14 +96,12 @@ Tensor dot_mps(
return output;
}
Tensor& addmv_out_mps_impl(
const Tensor &self,
const Tensor &mat,
const Tensor &vec,
Tensor& addmv_out_mps_impl(const Tensor& self,
const Tensor& mat,
const Tensor& vec,
const Scalar& beta_,
const Scalar& alpha_,
Tensor& result)
{
Tensor& result) {
using namespace mps;
TORCH_CHECK(mat.is_mps());
@ -129,38 +112,35 @@ Tensor& addmv_out_mps_impl(
c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
auto betaval = beta_.toComplexDouble();
struct CachedGraph : public mps::MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *selfTensor_ = nil;
MPSGraphTensor *matMulVecTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct CachedGraph : public mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* selfTensor_ = nil;
MPSGraphTensor* matMulVecTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
MPSStream *stream = at::mps::getCurrentMPSStream();
MPSStream* stream = at::mps::getCurrentMPSStream();
Tensor matMulVec = mm(mat, vec.unsqueeze(1)).squeeze(1);
@autoreleasepool {
string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec})
+ ":" + to_string(beta_.toDouble())
+ ":" + to_string(alpha_.toDouble());
string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" + to_string(beta_.toDouble()) +
":" + to_string(alpha_.toDouble());
CachedGraph* cachedGraph = nil;
if(!cachedGraph) {
if (!cachedGraph) {
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
@autoreleasepool{
MPSGraph *mpsGraph = mps::make_mps_graph();
@autoreleasepool {
MPSGraph* mpsGraph = mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor *matMulVecTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, matMulVec);
MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* matMulVecTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, matMulVec);
MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
// Intermediates for beta and alpha
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar: alpha_.toDouble()
dataType: getMPSScalarType(mat.scalar_type())];
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha_.toDouble()
dataType:getMPSScalarType(mat.scalar_type())];
// Intermediates for multiplying by beta and alpha
MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor:matMulVecTensor
@ -168,18 +148,17 @@ Tensor& addmv_out_mps_impl(
name:@"MM/alpha*(mat@vec)"];
newCachedGraph->outputTensor_ = productTimesAlphaTensor;
if (betaval != 0.0)
{
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar: beta_.toDouble()
dataType: getMPSScalarType(self.scalar_type())];
if (betaval != 0.0) {
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta_.toDouble()
dataType:getMPSScalarType(self.scalar_type())];
MPSGraphTensor* selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor: selfTensor
secondaryTensor: betaTensor
name: @"MM/beta*input"];
MPSGraphTensor* selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:selfTensor
secondaryTensor:betaTensor
name:@"MM/beta*input"];
MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor: productTimesAlphaTensor
secondaryTensor: selfTimesBetaTensor
name: @"MM/beta*input + alpha*(mat@vec)"];
MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor
secondaryTensor:selfTimesBetaTensor
name:@"MM/beta*input + alpha*(mat@vec)"];
newCachedGraph->outputTensor_ = outputTensor;
}
@ -189,23 +168,21 @@ Tensor& addmv_out_mps_impl(
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder matMulVecPlaceholder = Placeholder(cachedGraph->matMulVecTensor_, matMulVec);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =[NSMutableDictionary dictionary];
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
feeds[matMulVecPlaceholder.getMPSGraphTensor()] = matMulVecPlaceholder.getMPSGraphTensorData();
if (betaval != 0.0)
{
if (betaval != 0.0) {
Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData();
}
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -213,7 +190,13 @@ Tensor& addmv_out_mps_impl(
return result;
}
TORCH_IMPL_FUNC(addmv_out_mps)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_, const Tensor& result) {
TORCH_IMPL_FUNC(addmv_out_mps)
(const Tensor& self,
const Tensor& mat,
const Tensor& vec,
const Scalar& beta_,
const Scalar& alpha_,
const Tensor& result) {
addmv_out_mps_impl(self, mat, vec, beta_, alpha_, const_cast<Tensor&>(result));
}

View File

@ -18,26 +18,27 @@ Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* outputTensor_ = nil;
};
MPSGraphCache *cache_ = MPSGraphCache::getInstance();
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble());
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool{
MPSGraph *mpsGraph = make_mps_graph();
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
auto isBool = self.scalar_type() == c10::ScalarType::Bool;
auto isUInt8 = self.scalar_type() == c10::ScalarType::Byte;
auto dataType = !isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32;
auto dataType =
!isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32;
// constantWithScalar does not work for boolTypes on MacOS-12.[34]
// workaround by filing it as int8 tensor and than casting to bool
// See https://github.com/pytorch/pytorch/issues/82427
@ -47,17 +48,12 @@ Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble()
shape:getMPSShape(self)
dataType:dataType];
MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor
name:nil];
MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil];
if (isBool) {
outputTensor = [mpsGraph castTensor:outputTensor
toType:MPSDataTypeBool
name:@"constWithBool-workaround"];
outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"constWithBool-workaround"];
}
if (isUInt8) {
outputTensor = [mpsGraph castTensor:outputTensor
toType:MPSDataTypeUInt8
name:@"constWithUInt8-workaround"];
outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeUInt8 name:@"constWithUInt8-workaround"];
}
newCachedGraph->outputTensor_ = outputTensor;
@ -66,13 +62,11 @@ Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
});
}
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_,
needsCopyToOutput ? output : self,
nullptr, !needsCopyToOutput);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, needsCopyToOutput ? output : self, nullptr, !needsCopyToOutput);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), /*feeds*/ nil, results);
@ -109,7 +103,10 @@ Tensor& fill_scalar_mps(Tensor& self, const Scalar& value) {
}
Tensor& fill_tensor_mps_(Tensor& self, const Tensor& value) {
TORCH_CHECK(value.dim() == 0, "fill_ only supports 0-dimension value tensor but got tensor with ", value.dim(), " dimensions.");
TORCH_CHECK(value.dim() == 0,
"fill_ only supports 0-dimension value tensor but got tensor with ",
value.dim(),
" dimensions.");
Scalar scalar_value = value.item();
if (scalar_value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true)
return self;

View File

@ -2,39 +2,51 @@
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>
namespace at::native {
void fill_depthwise_conv_desc(MPSGraphDepthwiseConvolution3DOpDescriptor* descriptor_,
NSUInteger strideInX, NSUInteger strideInY,
NSUInteger dilationRateInX, NSUInteger dilationRateInY,
NSUInteger paddingHorizontal, NSUInteger paddingVertical,
c10::MemoryFormat memory_format, NSUInteger groups) {
descriptor_.strides = @[@1, [[NSNumber alloc] initWithInteger: strideInY],
[[NSNumber alloc] initWithInteger: strideInX]];
descriptor_.dilationRates = @[@1, [[NSNumber alloc] initWithInteger: dilationRateInY],
[[NSNumber alloc] initWithInteger: dilationRateInX]];
NSUInteger strideInX,
NSUInteger strideInY,
NSUInteger dilationRateInX,
NSUInteger dilationRateInY,
NSUInteger paddingHorizontal,
NSUInteger paddingVertical,
c10::MemoryFormat memory_format,
NSUInteger groups) {
descriptor_.strides =
@[ @1, [[NSNumber alloc] initWithInteger:strideInY], [[NSNumber alloc] initWithInteger:strideInX] ];
descriptor_.dilationRates =
@[ @1, [[NSNumber alloc] initWithInteger:dilationRateInY], [[NSNumber alloc] initWithInteger:dilationRateInX] ];
descriptor_.paddingStyle = MPSGraphPaddingStyleExplicit;
descriptor_.paddingValues = @[@0, @0, [[NSNumber alloc] initWithInteger: paddingVertical], [[NSNumber alloc]
initWithInteger: paddingVertical], [[NSNumber alloc]
initWithInteger: paddingHorizontal], [[NSNumber alloc]
initWithInteger: paddingHorizontal]];
descriptor_.paddingValues = @[
@0,
@0,
[[NSNumber alloc] initWithInteger:paddingVertical],
[[NSNumber alloc] initWithInteger:paddingVertical],
[[NSNumber alloc] initWithInteger:paddingHorizontal],
[[NSNumber alloc] initWithInteger:paddingHorizontal]
];
descriptor_.channelDimensionIndex = -3LL;
}
// Create convolution descriptor
void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
NSUInteger strideInX, NSUInteger strideInY,
NSUInteger dilationRateInX, NSUInteger dilationRateInY,
NSUInteger paddingHorizontal, NSUInteger paddingVertical,
c10::MemoryFormat memory_format, NSUInteger groups) {
NSUInteger strideInX,
NSUInteger strideInY,
NSUInteger dilationRateInX,
NSUInteger dilationRateInY,
NSUInteger paddingHorizontal,
NSUInteger paddingVertical,
c10::MemoryFormat memory_format,
NSUInteger groups) {
descriptor_.strideInX = strideInX;
descriptor_.strideInY = strideInY;
descriptor_.dilationRateInX = dilationRateInX;
@ -48,16 +60,15 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
descriptor_.paddingTop = paddingVertical;
descriptor_.paddingBottom = paddingVertical;
descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ?
MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutNHWC;
descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ? MPSGraphTensorNamedDataLayoutNCHW
: MPSGraphTensorNamedDataLayoutNHWC;
// PyTorch always uses OIHW memory layout for weights
descriptor_.weightsLayout = MPSGraphTensorNamedDataLayoutOIHW;
descriptor_.groups = groups;
}
Tensor _mps_convolution_impl(
const Tensor& input_t,
Tensor _mps_convolution_impl(const Tensor& input_t,
const Tensor& weight_t,
const c10::optional<Tensor>& bias_opt,
IntArrayRef padding,
@ -70,25 +81,22 @@ Tensor _mps_convolution_impl(
namespace native_mps = at::native::mps;
CheckedFrom c = "mps_convolution";
TensorArg input { input_t, "input", 1 },
weight { weight_t, "weight", 2 };
TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2};
checkAllSameType(c, {input, weight});
checkAllSameGPU(c, {input, weight});
bool bias_defined;
if(bias_opt == c10::nullopt)
if (bias_opt == c10::nullopt)
bias_defined = false;
else
bias_defined = bias_opt->defined();
auto memory_format = input_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
auto output_t = at::empty(
input_shape.has_value() ?
input_shape.value() :
conv_output_size(input->sizes(), weight->sizes(),
padding, stride, dilation),
auto output_t =
at::empty(input_shape.has_value() ? input_shape.value()
: conv_output_size(input->sizes(), weight->sizes(), padding, stride, dilation),
input->scalar_type(),
c10::nullopt,
kMPS,
@ -98,14 +106,13 @@ Tensor _mps_convolution_impl(
if (output_t.numel() == 0) {
return output_t;
}
TensorArg output{ output_t, "result", 0 };
TensorArg output{output_t, "result", 0};
convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);
// Derive from MPSCachedGraph
struct CachedGraph : public native_mps::MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public native_mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* biasTensor_ = nil;
MPSGraphTensor* weightTensor_ = nil;
@ -117,13 +124,12 @@ Tensor _mps_convolution_impl(
auto stream = at::mps::getCurrentMPSStream();
@autoreleasepool {
IntArrayRef bias_shape;
if(bias_defined)
if (bias_defined)
bias_shape = bias_opt.value().sizes();
string mem_format_key;
switch(memory_format) {
switch (memory_format) {
case at::MemoryFormat::Contiguous:
mem_format_key = "Contiguous";
break;
@ -135,76 +141,87 @@ Tensor _mps_convolution_impl(
}
string bias_shape_key;
if(bias_defined) {
if (bias_defined) {
bias_shape_key = to_string(bias_shape[0]);
} else {
bias_shape_key = "nobias";
}
string key = "mps_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":"
+ to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":"
+ to_string(padding[0]) + ":" + to_string(padding[1]) + ":"
+ to_string(groups) + ":" + mem_format_key
+ mps::getTensorsStringKey({input_t, weight_t}) + ":"
+ to_string(bias_defined) + ":" + bias_shape_key;
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
string key = "mps_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(dilation[0]) +
":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" +
to_string(groups) + ":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" +
to_string(bias_defined) + ":" + bias_shape_key;
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
MPSShape* inputShape = mps::getMPSShape(input_t, memory_format);
if(!cachedGraph) {
native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = native_mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphConvolution2DOpDescriptor *conv2dDescriptor_ =[[MPSGraphConvolution2DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor *depthWiseConv3dDescriptor_ = [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
[[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
MPSShape* weightShape = mps::getMPSShape(weight_t);
bool isDepthwiseConv = ((groups > 1 && (weightShape[1].intValue == 1)) &&
inputShape.count >= 4 && weightShape.count >= 4 && !is_channels_last);
if(isDepthwiseConv) {
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
memory_format, groups);
bool isDepthwiseConv = ((groups > 1 && (weightShape[1].intValue == 1)) && inputShape.count >= 4 &&
weightShape.count >= 4 && !is_channels_last);
if (isDepthwiseConv) {
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
memory_format,
groups);
} else {
fill_conv_desc(conv2dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
memory_format, groups);
fill_conv_desc(conv2dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
memory_format,
groups);
}
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(
mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
MPSGraphTensor* biasTensor = nil;
if(bias_defined) {
biasTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(bias_opt.value()));
if (bias_defined) {
biasTensor =
native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(bias_opt.value()));
}
MPSGraphTensor* outputTensor;
if(isDepthwiseConv) {
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor dimension:-3 withDimension:-4 name:nil];
outputTensor = [mpsGraph depthwiseConvolution3DWithSourceTensor: inputTensor
weightsTensor: weightTransposeTensor
descriptor: depthWiseConv3dDescriptor_
name: nil];
if (isDepthwiseConv) {
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
dimension:-3
withDimension:-4
name:nil];
outputTensor = [mpsGraph depthwiseConvolution3DWithSourceTensor:inputTensor
weightsTensor:weightTransposeTensor
descriptor:depthWiseConv3dDescriptor_
name:nil];
} else {
outputTensor = [mpsGraph convolution2DWithSourceTensor: inputTensor
weightsTensor: weightTensor
descriptor: conv2dDescriptor_
name: nil];
outputTensor = [mpsGraph convolution2DWithSourceTensor:inputTensor
weightsTensor:weightTensor
descriptor:conv2dDescriptor_
name:nil];
}
if (is_channels_last) {
outputTensor = mps::convertNHWCtoNCHW(mpsGraph, outputTensor);
}
if(bias_defined) {
outputTensor = [mpsGraph additionWithPrimaryTensor: outputTensor
secondaryTensor: biasTensor
name: nil];
if (bias_defined) {
outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor secondaryTensor:biasTensor name:nil];
}
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->weightTensor_ = weightTensor;
@ -213,27 +230,28 @@ Tensor _mps_convolution_impl(
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
auto weightsPlaceholder = native_mps::Placeholder(cachedGraph->weightTensor_, weight_t);
auto biasPlaceholder = native_mps::Placeholder();
// Reshape the bias to be broadcastable with output of conv2d
if(bias_defined)
biasPlaceholder = native_mps::Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1}));
if (bias_defined)
biasPlaceholder =
native_mps::Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1}));
auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, *output);
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [[[NSMutableDictionary alloc] initWithCapacity: 3] autorelease];
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
[[[NSMutableDictionary alloc] initWithCapacity:3] autorelease];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
feeds[weightsPlaceholder.getMPSGraphTensor()] = weightsPlaceholder.getMPSGraphTensorData();
if(bias_defined) {
if (bias_defined) {
feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData();
}
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -241,8 +259,7 @@ Tensor _mps_convolution_impl(
return *output;
}
Tensor _mps_convolution(
const Tensor& input_t,
Tensor _mps_convolution(const Tensor& input_t,
const Tensor& weight_t,
const c10::optional<Tensor>& bias_opt,
IntArrayRef padding,
@ -252,29 +269,32 @@ Tensor _mps_convolution(
return _mps_convolution_impl(input_t, weight_t, bias_opt, padding, stride, dilation, groups, c10::nullopt);
}
Tensor mps_convolution_backward_input(
IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
Tensor mps_convolution_backward_input(IntArrayRef input_size,
const Tensor& grad_output_t,
const Tensor& weight_t,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
bool bias_defined) {
namespace native_mps = at::native::mps;
using namespace mps;
TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
CheckedFrom c = "mps_convolution_backward_input";
TensorArg grad_output{ grad_output_t, "grad_output", 1 },
weight{ weight_t, "weight", 2 };
TensorArg grad_output{grad_output_t, "grad_output", 1}, weight{weight_t, "weight", 2};
checkAllSameType(c, {grad_output, weight});
checkAllSameGPU(c, {grad_output, weight});
auto memory_format = grad_output_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
auto grad_input_t = at::empty( input_size, grad_output_t.options(), c10::nullopt);
auto grad_input_t = at::empty(input_size, grad_output_t.options(), c10::nullopt);
// Avoid "grad_input" when this is being used as transposed convolution
TensorArg grad_input{ grad_input_t, "result", 0 };
TensorArg grad_input{grad_input_t, "result", 0};
convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);
// Derive from MPSCachedGraph
struct CachedGraph : public native_mps::MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public native_mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* weightTensor_ = nil;
MPSGraphTensor* gradInputTensor_ = nil;
@ -284,11 +304,10 @@ Tensor mps_convolution_backward_input(
// Add backward with input
@autoreleasepool {
MPSStream* stream = getCurrentMPSStream();
string mem_format_key;
switch(memory_format) {
switch (memory_format) {
case at::MemoryFormat::Contiguous:
mem_format_key = "Contiguous";
break;
@ -302,54 +321,67 @@ Tensor mps_convolution_backward_input(
MPSShape* gradOutputShape = getMPSShape(grad_output_t, memory_format);
MPSShape* mps_input_shape = getMPSShape(input_size);
NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
string key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":"
+ to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":"
+ to_string(padding[0]) + ":" + to_string(padding[1]) + ":"
+ to_string(groups) + ":" + mem_format_key
+ getTensorsStringKey({grad_output_t, weight_t}) + ":"
+ string([ns_shape_key UTF8String]);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
string key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" +
to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key +
getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = native_mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphConvolution2DOpDescriptor *conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor *depthWiseConv3dDescriptor_ = [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
[[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
MPSShape* weightOutputShape = mps::getMPSShape(weight_t);
// Depthwise conv is input feature channels = groups. So I in OIHW has to be 1.
bool isDepthwiseConv = ((groups > 1 && (weightOutputShape[1].intValue == 1)) &&
gradOutputShape.count >= 4 && weightOutputShape.count >= 4 && !is_channels_last);
bool isDepthwiseConv = ((groups > 1 && (weightOutputShape[1].intValue == 1)) && gradOutputShape.count >= 4 &&
weightOutputShape.count >= 4 && !is_channels_last);
if(isDepthwiseConv) {
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
at::MemoryFormat::Contiguous, groups);
if (isDepthwiseConv) {
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
at::MemoryFormat::Contiguous,
groups);
} else {
fill_conv_desc(conv2dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
at::MemoryFormat::Contiguous, groups);
fill_conv_desc(conv2dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
at::MemoryFormat::Contiguous,
groups);
}
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(
mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor;
if (is_channels_last) {
gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
}
MPSGraphTensor* gradInputTensor;
if(isDepthwiseConv) {
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor dimension:-3 withDimension:-4 name:nil];
gradInputTensor = [mpsGraph depthwiseConvolution3DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
if (isDepthwiseConv) {
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
dimension:-3
withDimension:-4
name:nil];
gradInputTensor =
[mpsGraph depthwiseConvolution3DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
weightsTensor:weightTransposeTensor
outputShape:mps_input_shape
descriptor:depthWiseConv3dDescriptor_
@ -368,30 +400,34 @@ Tensor mps_convolution_backward_input(
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input);
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
weightsPlaceholder.getMPSGraphTensor() : weightsPlaceholder.getMPSGraphTensorData(),
};
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
return *grad_input;
}
Tensor mps_convolution_backward_weights(
IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
const Tensor& grad_output_t,
const Tensor& input_t,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
bool bias_defined) {
namespace native_mps = at::native::mps;
using namespace mps;
TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
@ -403,27 +439,21 @@ Tensor mps_convolution_backward_weights(
// For uniformity with everything else, although it seems grad_weight
// would be unambiguous too.
TensorArg grad_output{ grad_output_t, "grad_output", 1 };
TensorArg input{ input_t, "input", 2};
TensorArg grad_output{grad_output_t, "grad_output", 1};
TensorArg input{input_t, "input", 2};
checkAllSameType(c, {grad_output, input});
checkAllSameGPU(c, {grad_output, input});
auto grad_weight_t = at::empty(
weight_size,
grad_output_t.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
c10::nullopt);
TensorArg grad_weight{ grad_weight_t, "result", 0 };
auto grad_weight_t =
at::empty(weight_size, grad_output_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
TensorArg grad_weight{grad_weight_t, "result", 0};
convolution_shape_check(c, input, grad_weight, grad_output, padding, stride, dilation, groups);
// Derive from MPSCachedGraph
struct CachedGraph : public native_mps::MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public native_mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* gradWeightTensor_ = nil;
@ -432,11 +462,10 @@ Tensor mps_convolution_backward_weights(
native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
@autoreleasepool {
MPSStream* stream = getCurrentMPSStream();
string mem_format_key;
switch(memory_format) {
switch (memory_format) {
case at::MemoryFormat::Contiguous:
mem_format_key = "Contiguous";
break;
@ -448,60 +477,75 @@ Tensor mps_convolution_backward_weights(
}
MPSShape* mps_weight_shape = getMPSShape(weight_size);
NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
string key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":"
+ to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":"
+ to_string(padding[0]) + ":" + to_string(padding[1]) + ":"
+ to_string(groups) + ":" + mem_format_key
+ getTensorsStringKey({grad_output_t, input_t}) + ":"
+ string([ns_shape_key UTF8String]);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
string key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" +
to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key +
getTensorsStringKey({grad_output_t, input_t}) + ":" + string([ns_shape_key UTF8String]);
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = native_mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphConvolution2DOpDescriptor *conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor *depthWiseConv3dDescriptor_ = [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
[[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
MPSShape* inputShape = mps::getMPSShape(input_t);
bool isDepthwiseConv = ((groups > 1 && (mps_weight_shape[1].intValue == 1)) && inputShape.count >= 4 && mps_weight_shape.count >= 4 && !is_channels_last);
bool isDepthwiseConv = ((groups > 1 && (mps_weight_shape[1].intValue == 1)) && inputShape.count >= 4 &&
mps_weight_shape.count >= 4 && !is_channels_last);
if(isDepthwiseConv) {
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
at::MemoryFormat::Contiguous, groups);
if (isDepthwiseConv) {
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
at::MemoryFormat::Contiguous,
groups);
} else {
fill_conv_desc(conv2dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
at::MemoryFormat::Contiguous, groups);
fill_conv_desc(conv2dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
at::MemoryFormat::Contiguous,
groups);
}
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(
mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor;
if (is_channels_last) {
gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
}
MPSGraphTensor* gradWeightTensor;
if(isDepthwiseConv) {
if (isDepthwiseConv) {
NSNumber* outputFeatChannelDim = mps_weight_shape[0];
MPSShape* weightShapeTranspose = @[@1, outputFeatChannelDim, mps_weight_shape[2], mps_weight_shape[3]];
MPSGraphTensor* gradWeightTensorTranspose = [mpsGraph depthwiseConvolution3DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
MPSShape* weightShapeTranspose = @[ @1, outputFeatChannelDim, mps_weight_shape[2], mps_weight_shape[3] ];
MPSGraphTensor* gradWeightTensorTranspose =
[mpsGraph depthwiseConvolution3DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
sourceTensor:inputTensor
outputShape:weightShapeTranspose
descriptor:depthWiseConv3dDescriptor_
name:nil];
gradWeightTensor = [mpsGraph transposeTensor:gradWeightTensorTranspose dimension:-3 withDimension:-4 name:nil];
gradWeightTensor = [mpsGraph transposeTensor:gradWeightTensorTranspose
dimension:-3
withDimension:-4
name:nil];
} else {
gradWeightTensor = [mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
gradWeightTensor =
[mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
sourceTensor:inputTensor
outputShape:mps_weight_shape
forwardConvolutionDescriptor:conv2dDescriptor_
@ -513,21 +557,20 @@ Tensor mps_convolution_backward_weights(
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
auto outputPlaceholder = Placeholder(cachedGraph->gradWeightTensor_, grad_weight_t);
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
};
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -535,10 +578,14 @@ Tensor mps_convolution_backward_weights(
return grad_weight_t;
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> mps_convolution_backward(
const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
std::array<bool,3> output_mask) {
std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_convolution_backward(const at::Tensor& input,
const at::Tensor& grad_output,
const at::Tensor& weight,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
std::array<bool, 3> output_mask) {
Tensor grad_input, grad_weight, grad_bias;
if (input.numel() == 0) {
if (output_mask[0]) {
@ -549,73 +596,85 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> mps_convolution_backward(
}
} else {
if (output_mask[0]) {
grad_input = mps_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]);
grad_input = mps_convolution_backward_input(
input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]);
}
if (output_mask[1]) {
grad_weight = mps_convolution_backward_weights(weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2]);
grad_weight = mps_convolution_backward_weights(
weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2]);
}
}
return std::tuple<Tensor,Tensor,Tensor>{grad_input, grad_weight, grad_bias};
return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}
Tensor mps_convolution_transpose_forward(
const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
auto input_size = conv_input_size(grad_output.sizes(), weight.sizes(),
padding, output_padding, stride, dilation, groups);
return mps_convolution_backward_input(input_size, grad_output, weight,
padding, stride, dilation, groups, false);
Tensor mps_convolution_transpose_forward(const Tensor& grad_output,
const Tensor& weight,
IntArrayRef padding,
IntArrayRef output_padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
auto input_size =
conv_input_size(grad_output.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups);
return mps_convolution_backward_input(input_size, grad_output, weight, padding, stride, dilation, groups, false);
}
Tensor _mps_convolution_transpose(
const Tensor& input_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
Tensor _mps_convolution_transpose(const Tensor& input_t,
const Tensor& weight_t,
IntArrayRef padding,
IntArrayRef output_padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
TORCH_CHECK(input_t.dim() < 5, "ConvTranspose 3D is not supported on MPS");
auto output_t = mps_convolution_transpose_forward(
input_t, weight_t, padding, output_padding, stride, dilation, groups);
auto output_t =
mps_convolution_transpose_forward(input_t, weight_t, padding, output_padding, stride, dilation, groups);
return output_t;
}
Tensor mps_convolution_transpose_backward_input(
const Tensor& grad_output_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, IntArrayRef input_shape)
{
return _mps_convolution_impl(
grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups, input_shape);
Tensor mps_convolution_transpose_backward_input(const Tensor& grad_output_t,
const Tensor& weight_t,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
IntArrayRef input_shape) {
return _mps_convolution_impl(grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups, input_shape);
}
Tensor mps_convolution_transpose_backward_weight(
IntArrayRef weight_size,
Tensor mps_convolution_transpose_backward_weight(IntArrayRef weight_size,
const Tensor& grad_output_t,
const Tensor& input_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
return mps_convolution_backward_weights(
weight_size, input_t, grad_output_t,
padding, stride, dilation, groups, false);
weight_size, input_t, grad_output_t, padding, stride, dilation, groups, false);
}
std::tuple<Tensor,Tensor> mps_convolution_transpose_backward(
const Tensor& input, const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
std::array<bool,2> output_mask) {
std::tuple<Tensor, Tensor> mps_convolution_transpose_backward(const Tensor& input,
const Tensor& grad_output,
const Tensor& weight,
IntArrayRef padding,
IntArrayRef output_padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
std::array<bool, 2> output_mask) {
Tensor grad_input, grad_weight;
if (output_mask[0]) {
grad_input = mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes());
grad_input =
mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes());
}
if (output_mask[1]) {
grad_weight = mps_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups);
grad_weight = mps_convolution_transpose_backward_weight(
weight.sizes(), grad_output, input, padding, stride, dilation, groups);
}
return std::tuple<Tensor,Tensor>{grad_input, grad_weight};
return std::tuple<Tensor, Tensor>{grad_input, grad_weight};
}
} // namespace at::native
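The forward path in this file sizes its output with `conv_output_size` (and the transpose path with `conv_input_size`), both from the `ATen/native/ConvUtils.h` header included above. As a reminder of the per-dimension arithmetic those helpers encode, a minimal standalone sketch (not the ATen code):

```
// Standard convolution output-size arithmetic, per spatial dimension.
#include <cstdint>
#include <cstdio>

int64_t conv_out_dim(int64_t in, int64_t kernel, int64_t pad, int64_t stride, int64_t dilation) {
  return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}

int main() {
  // e.g. a 224-wide input, 3x3 kernel, stride 2, padding 1, dilation 1 -> 112
  std::printf("%lld\n", static_cast<long long>(conv_out_dim(224, 3, 1, 2, 1)));
  return 0;
}
```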

View File

@ -6,10 +6,7 @@
namespace at::native {
namespace mps {
void* pageAlignedBlockPtr(
const void* ptr,
NSUInteger size,
NSUInteger* alignedBlockSize) {
void* pageAlignedBlockPtr(const void* ptr, NSUInteger size, NSUInteger* alignedBlockSize) {
uintptr_t address = (uintptr_t)ptr;
uintptr_t alignedAddress = address & ~(PAGE_SIZE - 1);
uintptr_t alignedEnd = ((address + size) + PAGE_SIZE - 1) & ~(PAGE_SIZE - 1);
@ -30,7 +27,7 @@ size_t compute_strided_size(const at::Tensor& t) {
if (t.numel() == 0) {
return 0;
}
for(const auto i: c10::irange(t.dim())) {
for (const auto i : c10::irange(t.dim())) {
assert(t.size(i) > 0);
rc += (t.size(i) - 1) * t.stride(i);
}
@ -43,13 +40,15 @@ bool is_strided_contiguous(const at::Tensor& t) {
// Copy sourceBuffer into destBuffer, casting sourceBuffer to src.scalar_type().
// The shapes and dtypes are taken from dst and src, but their storage pointers are not used.
void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
id<MTLBuffer> destBuffer, id<MTLBuffer> sourceBuffer, bool non_blocking = true) {
void copy_cast_mps(at::Tensor& dst,
const at::Tensor& src,
id<MTLBuffer> destBuffer,
id<MTLBuffer> sourceBuffer,
bool non_blocking = true) {
using namespace mps;
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
@ -64,11 +63,11 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
@autoreleasepool {
string key = "copy_cast_mps" + getTensorsStringKey({src, dst});
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
@ -85,23 +84,24 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
MPSGraphTensorData* srcData = [[[MPSGraphTensorData alloc]
initWithMTLBuffer:sourceBuffer shape:srcShape dataType:srcDType]
autorelease];
MPSGraphTensorData* dstData = [[[MPSGraphTensorData alloc]
initWithMTLBuffer:destBuffer shape:dstShape dataType:dstDType]
autorelease];
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{cachedGraph->inputTensor_: srcData};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{cachedGraph->outputTensor_: dstData};
stream->executeMPSGraph(cachedGraph->graph(), feeds, results, !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT_ADAPTIVE);
MPSGraphTensorData* srcData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:sourceBuffer
shape:srcShape
dataType:srcDType] autorelease];
MPSGraphTensorData* dstData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:destBuffer
shape:dstShape
dataType:dstDType] autorelease];
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{cachedGraph->inputTensor_ : srcData};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{cachedGraph->outputTensor_ : dstData};
stream->executeMPSGraph(
cachedGraph->graph(), feeds, results, !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT_ADAPTIVE);
}
}
static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
{
auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) {
auto sameMemFormat =
src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* stream = getCurrentMPSStream();
@ -181,15 +181,14 @@ static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool
}
// Copies tensor from cpu to mps backed by identical strided-contiguous data
static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
{
static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bool non_blocking) {
MPSStream* stream = getCurrentMPSStream();
id<MTLDevice> device = MPSDevice::getInstance()->device();
auto dst_byte_offset = dst.storage_offset() * dst.itemsize();
auto src_byte_offset = src.storage_offset() * src.itemsize();
id<MTLBuffer> destBuffer = getMTLBufferStorage(dst);
const size_t size_to_copy = src.nbytes();
const void* host_src = static_cast<char *>(src.storage().data()) + src_byte_offset;
const void* host_src = static_cast<char*>(src.storage().data()) + src_byte_offset;
TORCH_INTERNAL_ASSERT(src.dtype() == dst.dtype() && src.strides() == dst.strides() && is_strided_contiguous(src));
@ -210,8 +209,7 @@ static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bo
}
}
static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
{
static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) {
// Typecast to dst_ if needed and expand, which is a no-op
Tensor src = (src_.dtype() != dst_.dtype() ? src_.to(dst_.dtype()) : src_).expand_as(dst_);
@ -233,7 +231,7 @@ static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool n
dst = at::empty_like(src, at::device(at::kMPS));
}
copy_to_mps_stride_contig(dst, src, non_blocking && !needs_copy);
return needs_copy? dst_.copy_(dst) : dst_;
return needs_copy ? dst_.copy_(dst) : dst_;
}
void copy_blit_mps(void* dst, const void* src, size_t size) {
@ -241,8 +239,7 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
stream->copy_and_sync((id<MTLBuffer>)(src), (id<MTLBuffer>)(dst), size, 0, 0, true);
}
static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
{
static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) {
auto src_byte_offset = src_.storage_offset() * src_.itemsize();
auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();
@ -250,7 +247,8 @@ static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, boo
// gather into dst. This reduces the overhead of doing an additional blit for most cases
bool returnGatherOutput = dst_.is_contiguous();
Tensor src;
auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
auto sameMemFormat =
src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
const bool sameDataType = src_.dtype() == dst_.dtype();
if ((!src_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) ||
@ -301,8 +299,7 @@ static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, boo
return dst_;
}
at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
{
at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking) {
TORCH_CHECK(dst.defined(), "dst is undefined");
TORCH_CHECK(src.defined(), "src is undefined");
@ -328,20 +325,16 @@ at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
if (src.device().type() == at::kMPS && dst.device().type() == at::kMPS) {
return copy_kernel_mps(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
}
TORCH_INTERNAL_ASSERT(
src.device().type() == DeviceType::MPS,
"mps_copy_ is implemented only for *->MPS; MPS->*");
TORCH_INTERNAL_ASSERT(src.device().type() == DeviceType::MPS, "mps_copy_ is implemented only for *->MPS; MPS->*");
return dst;
}
} // namespace mps
Tensor _copy_from_and_resize_mps(const at::Tensor& self, const at::Tensor& dst)
{
Tensor _copy_from_and_resize_mps(const at::Tensor& self, const at::Tensor& dst) {
return mps::mps_copy_(const_cast<Tensor&>(dst), self, false);
}
Tensor _copy_from_mps(const at::Tensor& self, const at::Tensor& dst, bool non_blocking)
{
Tensor _copy_from_mps(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
return mps::mps_copy_(const_cast<Tensor&>(dst), self, non_blocking);
}
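`pageAlignedBlockPtr` near the top of this file rounds the start of a host range down to a page boundary and rounds its end up to the next one, which is the usual precondition for wrapping host memory in a no-copy Metal buffer (a hedged reading; the call site sits outside the hunks shown). A standalone sketch of the mask arithmetic, with a hypothetical 16 KiB page size:

```
// Mask arithmetic behind pageAlignedBlockPtr: round the start down and the end up to page boundaries.
#include <cstdint>
#include <cstdio>

int main() {
  const uintptr_t kPageSize = 16384; // hypothetical page size for the example (16 KiB)
  const uintptr_t address = 0x100003a20; // arbitrary, unaligned host pointer
  const uintptr_t size = 1000;

  const uintptr_t alignedAddress = address & ~(kPageSize - 1); // round start down
  const uintptr_t alignedEnd = ((address + size) + kPageSize - 1) & ~(kPageSize - 1); // round end up
  const uintptr_t alignedLength = alignedEnd - alignedAddress; // covers all of [address, address + size)

  std::printf("start 0x%llx, length %llu\n",
              static_cast<unsigned long long>(alignedAddress),
              static_cast<unsigned long long>(alignedLength));
  return 0;
}
```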

View File

@ -1,7 +1,7 @@
// Copyright © 2022 Apple Inc.
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/Cross.h>
#include <ATen/native/mps/OperationUtils.h>
namespace at::native {
@ -82,10 +82,10 @@ static id<MTLLibrary> compileCrossOpLibrary(id<MTLDevice> device) {
return crossLibrary;
}
NSError *error = nil;
MTLCompileOptions *options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion: MTLLanguageVersion2_3];
crossLibrary = [device newLibraryWithSource:[NSString stringWithCString: METAL_CROSS encoding:NSASCIIStringEncoding]
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
crossLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_CROSS encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(crossLibrary, "Failed to create metal cross library, error: ", [[error description] UTF8String]);
@ -133,7 +133,7 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,
const uint32_t nDim = iter.ndim();
constexpr uint32_t nOffsets = 3;
const uint32_t numThreads = iter.numel();
dispatch_sync(mpsStream->queue(), ^(){
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
NSError* error = nil;
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
@ -143,23 +143,25 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,
std::vector<uint32_t> iterShapeData(iterShape.size());
std::vector<std::array<uint32_t, nOffsets>> strides(nDim);
for (const auto i: c10::irange(iterShape.size())) {
for (const auto i : c10::irange(iterShape.size())) {
TORCH_CHECK(i <= UINT32_MAX);
iterShapeData[i] = (uint32_t)(iterShape[i]);
}
for (const auto i: c10::irange(nDim)) {
for (const auto offset: c10::irange(nOffsets)) {
for (const auto i : c10::irange(nDim)) {
for (const auto offset : c10::irange(nOffsets)) {
strides[i][offset] = iter.strides(offset)[i];
}
}
id<MTLFunction> kernelDataOffsetsFunction = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO = [[device newComputePipelineStateWithFunction: kernelDataOffsetsFunction
error: &error] autorelease];
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength: numThreads * sizeof(simd_uint3)
options: 0] autorelease];
TORCH_CHECK(kernelDataOffsetsPSO, "Failed to create pipeline state object, error: ", [[error description] UTF8String]);
id<MTLFunction> kernelDataOffsetsFunction =
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO =
[[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
options:0] autorelease];
TORCH_CHECK(
kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
[computeEncoder setComputePipelineState:kernelDataOffsetsPSO];
[computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0];
[computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1];
@ -172,8 +174,7 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,
kernelOffsetsTGSize = numThreads;
MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: kernelOffsetsThreadGroupSize];
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:kernelOffsetsThreadGroupSize];
id<MTLComputePipelineState> crossPSO = crossPipelineState(device, out.scalar_type());
[computeEncoder setComputePipelineState:crossPSO];
@ -191,8 +192,7 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: threadGroupSize];
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];
[computeEncoder endEncoding];
mpsStream->commit(true);
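The cross kernel above follows the standard Metal compute flow: compile a library from an embedded source string, build a compute pipeline state, bind buffers, and dispatch a 1-D grid. A self-contained Objective-C++ sketch of that flow with a trivial kernel (all names are local to this example and nothing is taken from ATen; error handling omitted):

```
// Minimal compile-from-source + dispatch sketch. Compile as Objective-C++ and link Metal/Foundation.
#import <Metal/Metal.h>
#include <cstdint>

static const char* kSource = R"MTL(
#include <metal_stdlib>
using namespace metal;
kernel void fill_index(device uint* out [[buffer(0)]], uint tid [[thread_position_in_grid]]) {
  out[tid] = tid;
}
)MTL";

int main() {
  @autoreleasepool {
    NSError* error = nil;
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();

    // 1. Compile the library from a source string and build a compute pipeline state.
    MTLCompileOptions* options = [MTLCompileOptions new];
    [options setLanguageVersion:MTLLanguageVersion2_3];
    id<MTLLibrary> library = [device newLibraryWithSource:[NSString stringWithUTF8String:kSource]
                                                  options:options
                                                    error:&error];
    id<MTLFunction> function = [library newFunctionWithName:@"fill_index"];
    id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:function error:&error];

    // 2. Bind an output buffer and dispatch a 1-D grid, clamping the threadgroup size to the pipeline limit.
    const NSUInteger numThreads = 1024;
    id<MTLBuffer> out = [device newBufferWithLength:numThreads * sizeof(uint32_t)
                                            options:MTLResourceStorageModeShared];
    id<MTLCommandQueue> queue = [device newCommandQueue];
    id<MTLCommandBuffer> commandBuffer = [queue commandBuffer];
    id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
    [encoder setComputePipelineState:pso];
    [encoder setBuffer:out offset:0 atIndex:0];
    NSUInteger tgSize = MIN(pso.maxTotalThreadsPerThreadgroup, numThreads);
    [encoder dispatchThreads:MTLSizeMake(numThreads, 1, 1) threadsPerThreadgroup:MTLSizeMake(tgSize, 1, 1)];
    [encoder endEncoding];

    // 3. Commit the command buffer and wait for completion.
    [commandBuffer commit];
    [commandBuffer waitUntilCompleted];
  }
  return 0;
}
```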

View File

@ -1,39 +1,40 @@
// Copyright © 2022 Apple Inc.
#include <ATen/native/Distributions.h>
#include <ATen/native/DistributionTemplates.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/mps/MPSGeneratorImpl.h>
#include <ATen/native/DistributionTemplates.h>
#include <ATen/native/Distributions.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
namespace at::native {
namespace mps {
struct RandomCachedGraph : public MPSCachedGraph
{
RandomCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) { }
struct RandomCachedGraph : public MPSCachedGraph {
RandomCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
// Only relevant for multinomial
MPSGraphTensor *probTensor = nil;
MPSGraphTensor *resultTensor = nil;
MPSGraphTensor *stateTensor = nil;
MPSGraphTensor* probTensor = nil;
MPSGraphTensor* resultTensor = nil;
MPSGraphTensor* stateTensor = nil;
// used for Normal distributions only
MPSGraphTensor *meanTensor = nil, *stdTensor = nil;
};
typedef MPSGraphTensor* (^RandomOpBlock)(RandomCachedGraph*, MPSGraphTensor*);
#define RandomOpFn(graph, randomTensor) MPSGraphTensor* (mps::RandomCachedGraph* graph, MPSGraphTensor* randomTensor)
#define RandomOpFn(graph, randomTensor) MPSGraphTensor*(mps::RandomCachedGraph * graph, MPSGraphTensor * randomTensor)
// for Uniform distributions with scalar from (val1) and to (val2) intervals
// for Normal distributions with scalar mean (val1) and std (val2) values
template<typename scalar_t>
Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
template <typename scalar_t>
Tensor& random_mps_impl(Tensor& self,
scalar_t val1,
scalar_t val2,
const c10::optional<Tensor>& mean_opt,
const c10::optional<Tensor>& std_opt,
MPSGraphRandomDistribution distribution,
c10::optional<Generator> gen,
std::string op_name, RandomOpBlock randomBlock)
{
std::string op_name,
RandomOpBlock randomBlock) {
if (self.numel() == 0) {
return self;
}
@ -46,13 +47,14 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
auto cachedGraph = cache_->LookUpAs<RandomCachedGraph>(key);
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^ MPSCachedGraph * () {
RandomCachedGraph *newCachedGraph = nil;
cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^MPSCachedGraph*() {
RandomCachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new RandomCachedGraph(mpsGraph);
newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@(at::mps::detail::PHILOX_STATE_N)]);
newCachedGraph->stateTensor =
mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]);
// FP16, FP32 and Int32 are the only data types supported for distributions on MPS backend.
const MPSDataType inputDataType = [&] {
@ -64,8 +66,8 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
}();
const MPSDataType outputDataType = (std::is_same<scalar_t, bool>::value) ? MPSDataTypeBool : inputDataType;
MPSGraphRandomOpDescriptor *desc = [MPSGraphRandomOpDescriptor descriptorWithDistribution: distribution
dataType: inputDataType];
MPSGraphRandomOpDescriptor* desc = [MPSGraphRandomOpDescriptor descriptorWithDistribution:distribution
dataType:inputDataType];
if (distribution == MPSGraphRandomDistributionUniform) {
if (inputDataType == MPSDataTypeInt32) {
desc.minInteger = static_cast<NSInteger>(val1);
@ -81,10 +83,10 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
// we don't use the output state tensor from the MPSGraph API as it requires reading back from GPU to CPU.
// Instead, we keep the Philox state in the MPSGenerator and use PyTorch's philox_engine to maintain
// the counters, and feed them to the graph manually
NSArray<MPSGraphTensor*> *resultTensors = [mpsGraph randomTensorWithShape: getMPSShape(self)
descriptor: desc
stateTensor: newCachedGraph->stateTensor
name: nil];
NSArray<MPSGraphTensor*>* resultTensors = [mpsGraph randomTensorWithShape:getMPSShape(self)
descriptor:desc
stateTensor:newCachedGraph->stateTensor
name:nil];
newCachedGraph->resultTensor = randomBlock ? randomBlock(newCachedGraph, resultTensors[0]) : resultTensors[0];
// results will be cast if self's scalar type isn't directly supported by MPS backend.
if (getMPSDataType(self) != outputDataType)
@ -94,19 +96,20 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
});
}
// feed the updated state values to the graph
MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(at::mps::detail::PHILOX_STATE_N)]];
MPSNDArray *stateNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: stateDesc] autorelease];
MPSNDArrayDescriptor* stateDesc =
[MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeInt32 shape:@[ @(at::mps::detail::PHILOX_STATE_N) ]];
MPSNDArray* stateNDArray = [[[MPSNDArray alloc] initWithDevice:stream->device() descriptor:stateDesc] autorelease];
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(mps_gen->mutex_);
// update the Philox state values on each run
mps_gen->update_philox_counters();
[stateNDArray writeBytes: mps_gen->state_data() strideBytes: nil];
[stateNDArray writeBytes:mps_gen->state_data() strideBytes:nil];
}
MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: stateNDArray] autorelease];
MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray:stateNDArray] autorelease];
Placeholder meanPlaceholder, stdPlaceholder;
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
feeds[cachedGraph->stateTensor] = stateTensorData;
if (cachedGraph->stdTensor) {
@ -121,7 +124,7 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
}
Placeholder outputPlaceholder = Placeholder(cachedGraph->resultTensor, self);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*> *results = @{
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(),
};
@ -131,12 +134,13 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
return self;
}
Tensor& normal_mps_impl(Tensor& self, double mean_s, double std_s,
Tensor& normal_mps_impl(Tensor& self,
double mean_s,
double std_s,
const c10::optional<Tensor>& mean_opt,
const c10::optional<Tensor>& std_opt,
c10::optional<Generator> gen,
std::string op_name)
{
std::string op_name) {
const Tensor& std_t = *(at::borrow_from_optional_tensor(std_opt));
const Tensor& mean_t = *(at::borrow_from_optional_tensor(mean_opt));
@ -153,39 +157,45 @@ Tensor& normal_mps_impl(Tensor& self, double mean_s, double std_s,
if (std_t.defined()) {
cachedGraph->stdTensor = mpsGraphRankedPlaceHolder(mpsGraph, std_t);
resultTensor = [mpsGraph multiplicationWithPrimaryTensor: randomTensor
secondaryTensor: cachedGraph->stdTensor
name: nil];
resultTensor = [mpsGraph multiplicationWithPrimaryTensor:randomTensor
secondaryTensor:cachedGraph->stdTensor
name:nil];
}
if (mean_t.defined()) {
cachedGraph->meanTensor = mpsGraphRankedPlaceHolder(mpsGraph, mean_t);
return [mpsGraph additionWithPrimaryTensor: resultTensor
secondaryTensor: cachedGraph->meanTensor
name: nil];
return [mpsGraph additionWithPrimaryTensor:resultTensor secondaryTensor:cachedGraph->meanTensor name:nil];
}
return resultTensor;
};
return random_mps_impl<double>(self, mean_s, std_s, mean_opt, std_opt,
MPSGraphRandomDistributionNormal, gen,
op_name + getTensorsStringKey({mean_t, std_t}), random_op_block);
return random_mps_impl<double>(self,
mean_s,
std_s,
mean_opt,
std_opt,
MPSGraphRandomDistributionNormal,
gen,
op_name + getTensorsStringKey({mean_t, std_t}),
random_op_block);
}
Tensor& bernoulli_mps_impl(Tensor& self, const Tensor& prob_t, c10::optional<Generator> gen, std::string op_name)
{
Tensor& bernoulli_mps_impl(Tensor& self, const Tensor& prob_t, c10::optional<Generator> gen, std::string op_name) {
TORCH_CHECK(prob_t.is_same_size(self), op_name, ": probability and self tensor should be of the same shape")
RandomOpBlock random_op_block = ^RandomOpFn(cachedGraph, randomTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
cachedGraph->stdTensor = mpsGraphRankedPlaceHolder(mpsGraph, prob_t);
return [mpsGraph lessThanWithPrimaryTensor: randomTensor
secondaryTensor: cachedGraph->stdTensor
name: nil];
return [mpsGraph lessThanWithPrimaryTensor:randomTensor secondaryTensor:cachedGraph->stdTensor name:nil];
};
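// lessThan(u, p) with u ~ U(0, 1) evaluates to true with probability p elementwise, which is exactly a
// Bernoulli(p) draw over the uniform base distribution requested below.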
// Bernoulli generates binary output so we use bool type
return mps::random_mps_impl<bool>(self, 0.0, 1.0, c10::nullopt, prob_t,
MPSGraphRandomDistributionUniform, gen,
op_name + getTensorsStringKey({prob_t}), random_op_block);
return mps::random_mps_impl<bool>(self,
0.0,
1.0,
c10::nullopt,
prob_t,
MPSGraphRandomDistributionUniform,
gen,
op_name + getTensorsStringKey({prob_t}),
random_op_block);
}
} // namespace mps
@ -196,15 +206,19 @@ Tensor& uniform_mps_(Tensor& self, double from, double to, c10::optional<Generat
const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to);
TORCH_CHECK((to - from) <= std::numeric_limits<scalar_t>::max(),
"uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()),
">::max(), but found to=", to, " and from=", from,
"uniform_ expects to-from <= std::numeric_limits<",
toString(self.scalar_type()),
">::max(), but found to=",
to,
" and from=",
from,
" which result in to-from to exceed the limit");
from = std::min(std::max(from, min), max);
to = std::max(std::min(to, max), min);
});
return mps::random_mps_impl<double>(self, from, to, c10::nullopt, c10::nullopt,
MPSGraphRandomDistributionUniform, gen, __func__, nullptr);
return mps::random_mps_impl<double>(
self, from, to, c10::nullopt, c10::nullopt, MPSGraphRandomDistributionUniform, gen, __func__, nullptr);
}
Tensor& normal_mps_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
@ -271,21 +285,34 @@ Tensor& random_mps_(Tensor& self, int64_t from, c10::optional<int64_t> to_opt, c
to = *to_opt;
TORCH_CHECK(from < to, "random_mps_ expects 'from' to be less than 'to', but got from=", from, " >= to=", to);
if (isFloatingType(input_dtype)) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_update_from_to", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_update_from_to", [&] {
from = templates::update_from<scalar_t>(from);
to = templates::update_to<scalar_t>(to);
TORCH_CHECK(from < to, "random_mps_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", from, " >= to=", to);
TORCH_CHECK(
from < to,
"random_mps_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=",
from,
" >= to=",
to);
});
templates::check_from_to_in_range(from, to - 1, self.dtype());
}
} else if (from != std::numeric_limits<int64_t>::lowest()) {
// [from, std::numeric_limits<int64_t>::max()]
if (isFloatingType(input_dtype)) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_from_to_range_calc", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_from_to_range_calc", [&] {
constexpr int64_t scalar_t_max = static_cast<int64_t>(1) << std::numeric_limits<scalar_t>::digits;
to = scalar_t_max > std::numeric_limits<int64_t>::max() ? std::numeric_limits<int64_t>::max() : static_cast<int64_t>(scalar_t_max);
to = scalar_t_max > std::numeric_limits<int64_t>::max() ? std::numeric_limits<int64_t>::max()
: static_cast<int64_t>(scalar_t_max);
from = templates::update_from<scalar_t>(from);
TORCH_CHECK(from < to, "random_mps_ expects 'from' casted to dtype to be less than or equal to 'to' casted to dtype, but got from=", from, " > to=", to);
TORCH_CHECK(
from < to,
"random_mps_ expects 'from' casted to dtype to be less than or equal to 'to' casted to dtype, but got from=",
from,
" > to=",
to);
});
} else if (isIntegralType(input_dtype, /*includeBool=*/true)) {
AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, input_dtype, "random_from_to_range_calc", [&] {
@ -295,13 +322,11 @@ Tensor& random_mps_(Tensor& self, int64_t from, c10::optional<int64_t> to_opt, c
to = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
}
});
}
else {
} else {
TORCH_CHECK(false, "random_mps_ handles only integral, floating-point and boolean types");
}
templates::check_from_to_in_range(from, to, self.dtype());
}
else {
} else {
// [std::numeric_limits<int64_t>::lowest(), std::numeric_limits<int64_t>::max()]
// range = 2^64
@ -309,8 +334,8 @@ Tensor& random_mps_(Tensor& self, int64_t from, c10::optional<int64_t> to_opt, c
TORCH_CHECK(false, "random_mps_ currently does not handle the lowest() -> max() range");
}
return mps::random_mps_impl<int64_t>(self, from, to - 1, c10::nullopt, c10::nullopt,
MPSGraphRandomDistributionUniform, gen, __func__, nullptr);
return mps::random_mps_impl<int64_t>(
self, from, to - 1, c10::nullopt, c10::nullopt, MPSGraphRandomDistributionUniform, gen, __func__, nullptr);
}
Tensor& random_mps_(Tensor& self, int64_t to, c10::optional<Generator> gen) {
@ -323,22 +348,23 @@ Tensor& exponential_mps_(Tensor& self, double lambda, c10::optional<Generator> g
mps::RandomOpBlock random_op_block = ^RandomOpFn(cachedGraph, randomTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar: 1.0f
dataType: randomTensor.dataType];
MPSGraphTensor* minusLambdaTensor = [mpsGraph constantWithScalar: -lambda
dataType: randomTensor.dataType];
MPSGraphTensor* subtractTensor = [mpsGraph subtractionWithPrimaryTensor: unitTensor
secondaryTensor: randomTensor
name: nil];
MPSGraphTensor* logTensor = [mpsGraph logarithmWithTensor: subtractTensor
name: nil];
return [mpsGraph divisionWithPrimaryTensor: logTensor
secondaryTensor: minusLambdaTensor
name: nil];
MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f dataType:randomTensor.dataType];
MPSGraphTensor* minusLambdaTensor = [mpsGraph constantWithScalar:-lambda dataType:randomTensor.dataType];
MPSGraphTensor* subtractTensor = [mpsGraph subtractionWithPrimaryTensor:unitTensor
secondaryTensor:randomTensor
name:nil];
MPSGraphTensor* logTensor = [mpsGraph logarithmWithTensor:subtractTensor name:nil];
return [mpsGraph divisionWithPrimaryTensor:logTensor secondaryTensor:minusLambdaTensor name:nil];
};
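// The block above is inverse-transform sampling: for u ~ U(0, 1), x = -log(1 - u) / lambda follows an
// Exponential(lambda) distribution, which is what the subtraction, logarithm and division compute.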
return mps::random_mps_impl<double>(self, 0.0, 1.0, c10::nullopt, c10::nullopt,
MPSGraphRandomDistributionUniform, gen,
"exponential_mps_:" + std::to_string(lambda), random_op_block);
return mps::random_mps_impl<double>(self,
0.0,
1.0,
c10::nullopt,
c10::nullopt,
MPSGraphRandomDistributionUniform,
gen,
"exponential_mps_:" + std::to_string(lambda),
random_op_block);
}
Tensor& randperm_out_mps(int64_t n, c10::optional<Generator> generator, Tensor& result) {
@ -354,9 +380,12 @@ Tensor& randperm_out_mps(int64_t n, c10::optional<Generator> generator, Tensor&
}
TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
TORCH_CHECK(!generator.has_value() ||
(generator.has_value() && result.device() == generator->device()),
"Expected a '", result.device(), "' generator device but found '", generator->device(), "'");
TORCH_CHECK(!generator.has_value() || (generator.has_value() && result.device() == generator->device()),
"Expected a '",
result.device(),
"' generator device but found '",
generator->device(),
"'");
check_supported_max_int_with_precision(n, result);
result.resize_({n});
@ -366,36 +395,34 @@ Tensor& randperm_out_mps(int64_t n, c10::optional<Generator> generator, Tensor&
mps::RandomOpBlock random_op_block = ^RandomOpFn(cachedGraph, randomTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* argsortTensor = [mpsGraph argSortWithTensor:randomTensor
axis:0
name:nil];
MPSGraphTensor* argsortTensor = [mpsGraph argSortWithTensor:randomTensor axis:0 name:nil];
if (result.scalar_type() != kInt) {
argsortTensor = [mpsGraph castTensor:argsortTensor
toType:mps::getMPSDataType(result)
name:@"castOutput"];
argsortTensor = [mpsGraph castTensor:argsortTensor toType:mps::getMPSDataType(result) name:@"castOutput"];
}
return argsortTensor;
};
return mps::random_mps_impl<int64_t>(result, 0.0, 1.0, c10::nullopt, c10::nullopt,
MPSGraphRandomDistributionUniform, generator,
"ranperm_out_mps:" + mps::getTensorsStringKey({result}), random_op_block);
return mps::random_mps_impl<int64_t>(result,
0.0,
1.0,
c10::nullopt,
c10::nullopt,
MPSGraphRandomDistributionUniform,
generator,
"ranperm_out_mps:" + mps::getTensorsStringKey({result}),
random_op_block);
}
Tensor& multinomial_with_replacement_mps_kernel(
const Tensor& self,
Tensor& multinomial_with_replacement_mps_kernel(const Tensor& self,
const int64_t n_sample,
c10::optional<Generator> generator,
Tensor& result) {
using namespace mps;
auto mps_gen = get_generator_or_default<MPSGeneratorImpl>(generator, at::mps::detail::getDefaultMPSGenerator());
int inputSize = self.dim();
int numDist =
inputSize == 1 ? 1 : self.size(0);
int numCategories =
inputSize == 1 ? self.size(0) : self.size(1);
int numDist = inputSize == 1 ? 1 : self.size(0);
int numCategories = inputSize == 1 ? self.size(0) : self.size(1);
// Restructure data for 2d
auto self_v = inputSize == 1 ? self.view({numDist, numCategories}) : self;
@ -408,24 +435,22 @@ Tensor& multinomial_with_replacement_mps_kernel(
string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + to_string(n_sample);
auto cachedGraph = cache_->LookUpAs<RandomCachedGraph>(key);
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^ MPSCachedGraph * () {
RandomCachedGraph *newCachedGraph = nil;
cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^MPSCachedGraph*() {
RandomCachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSShape* prob_shape = getMPSShape(self_v);
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new RandomCachedGraph(mpsGraph);
newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@7]);
newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @7 ]);
auto prob_dtype = getMPSDataType(self_v);
// This is probability weights
newCachedGraph->probTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self_v), prob_shape);
MPSGraphTensor *sumProbs = [mpsGraph reductionSumWithTensor:newCachedGraph->probTensor
axis:-1
name:nil];
MPSGraphTensor* sumProbs = [mpsGraph reductionSumWithTensor:newCachedGraph->probTensor axis:-1 name:nil];
MPSGraphTensor *normalizedProbs = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->probTensor
MPSGraphTensor* normalizedProbs = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->probTensor
secondaryTensor:sumProbs
name:nil];
@ -433,73 +458,67 @@ Tensor& multinomial_with_replacement_mps_kernel(
auto ns_numDist = [NSNumber numberWithInt:numDist];
auto ns_n_sample = [NSNumber numberWithInt:n_sample];
MPSGraphTensor *ones = [mpsGraph constantWithScalar:1.0f
shape:@[ns_numCategories, ns_numCategories]
MPSGraphTensor* ones = [mpsGraph constantWithScalar:1.0f
shape:@[ ns_numCategories, ns_numCategories ]
dataType:prob_dtype];
auto zeroTensor = [mpsGraph constantWithScalar: 0.0f
dataType: MPSDataTypeInt32];
auto minusOneTensor = [mpsGraph constantWithScalar: -1.0f
dataType: MPSDataTypeInt32];
auto zeroTensor = [mpsGraph constantWithScalar:0.0f dataType:MPSDataTypeInt32];
auto minusOneTensor = [mpsGraph constantWithScalar:-1.0f dataType:MPSDataTypeInt32];
MPSGraphTensor *upperTriangle = [mpsGraph bandPartWithTensor:ones
MPSGraphTensor* upperTriangle = [mpsGraph bandPartWithTensor:ones
numLowerTensor:zeroTensor
numUpperTensor:minusOneTensor
name:nil];
MPSGraphTensor *upperProbRange = [mpsGraph matrixMultiplicationWithPrimaryTensor:normalizedProbs
MPSGraphTensor* upperProbRange = [mpsGraph matrixMultiplicationWithPrimaryTensor:normalizedProbs
secondaryTensor:upperTriangle
name:nil];
MPSGraphTensor *lowerProbRange = [mpsGraph subtractionWithPrimaryTensor:upperProbRange
MPSGraphTensor* lowerProbRange = [mpsGraph subtractionWithPrimaryTensor:upperProbRange
secondaryTensor:normalizedProbs
name:nil];
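// upperProbRange is the inclusive per-category CDF (multiplying normalizedProbs by an upper-triangular
// ones matrix sums p_0..p_j), and lowerProbRange is the exclusive CDF; a uniform draw u is later binned
// into category j exactly when lowerProbRange[j] < u < upperProbRange[j], i.e. inverse-CDF sampling.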
upperProbRange = [mpsGraph reshapeTensor:upperProbRange
withShape:@[ns_numDist, @1, ns_numCategories]
withShape:@[ ns_numDist, @1, ns_numCategories ]
name:nil];
lowerProbRange = [mpsGraph reshapeTensor:lowerProbRange
withShape:@[ns_numDist, @1, ns_numCategories]
withShape:@[ ns_numDist, @1, ns_numCategories ]
name:nil];
MPSGraphRandomOpDescriptor *descriptor = [MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform
MPSGraphRandomOpDescriptor* descriptor =
[MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform
dataType:prob_dtype];
NSArray<MPSGraphTensor*> *generatorTensors = [mpsGraph randomTensorWithShape:@[ns_numDist, ns_n_sample, @1]
NSArray<MPSGraphTensor*>* generatorTensors = [mpsGraph randomTensorWithShape:@[ ns_numDist, ns_n_sample, @1 ]
descriptor:descriptor
stateTensor:newCachedGraph->stateTensor
name:nil];
MPSGraphTensor *randomTensor = generatorTensors[0];
MPSGraphTensor* randomTensor = generatorTensors[0];
auto broadcastShape = @[ns_numDist ,ns_n_sample, ns_numCategories];
auto broadcastShape = @[ ns_numDist, ns_n_sample, ns_numCategories ];
int broadcastShapeVals[3] = {numDist, static_cast<int>(n_sample), numCategories};
MPSGraphTensor *broadcastShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count]
shape:@[[NSNumber numberWithUnsignedInteger:broadcastShape.count]]
MPSGraphTensor* broadcastShapeTensor = [mpsGraph
constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count]
shape:@[ [NSNumber numberWithUnsignedInteger:broadcastShape.count] ]
dataType:MPSDataTypeUInt32];
MPSGraphTensor *samplesTensor = [mpsGraph broadcastTensor:randomTensor
toShape:broadcastShape
name:nil];
MPSGraphTensor *sampleAbove = [mpsGraph greaterThanWithPrimaryTensor:samplesTensor
MPSGraphTensor* samplesTensor = [mpsGraph broadcastTensor:randomTensor toShape:broadcastShape name:nil];
MPSGraphTensor* sampleAbove = [mpsGraph greaterThanWithPrimaryTensor:samplesTensor
secondaryTensor:lowerProbRange
name:nil];
MPSGraphTensor *sampleBelow = [mpsGraph lessThanWithPrimaryTensor:samplesTensor
MPSGraphTensor* sampleBelow = [mpsGraph lessThanWithPrimaryTensor:samplesTensor
secondaryTensor:upperProbRange
name:nil];
MPSGraphTensor *sampleWithin = [mpsGraph logicalANDWithPrimaryTensor:sampleAbove
MPSGraphTensor* sampleWithin = [mpsGraph logicalANDWithPrimaryTensor:sampleAbove
secondaryTensor:sampleBelow
name:nil];
MPSGraphTensor *sampleMask = [mpsGraph castTensor:sampleWithin
toType:MPSDataTypeInt32
name:@"sampleMask"];
MPSGraphTensor *categoriesTensor = [mpsGraph coordinateAlongAxis:-1
MPSGraphTensor* sampleMask = [mpsGraph castTensor:sampleWithin toType:MPSDataTypeInt32 name:@"sampleMask"];
MPSGraphTensor* categoriesTensor = [mpsGraph coordinateAlongAxis:-1
withShapeTensor:broadcastShapeTensor
name:nil];
MPSGraphTensor *binnedSamplesTensor = [mpsGraph multiplicationWithPrimaryTensor:categoriesTensor
MPSGraphTensor* binnedSamplesTensor = [mpsGraph multiplicationWithPrimaryTensor:categoriesTensor
secondaryTensor:sampleMask
name:nil];
MPSGraphTensor *reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor
axis:-1
name:nil];
MPSGraphTensor *reshapeTensor = [mpsGraph reshapeTensor:reducedTensor
withShape:@[ns_numDist ,ns_n_sample]
MPSGraphTensor* reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor axis:-1 name:nil];
MPSGraphTensor* reshapeTensor = [mpsGraph reshapeTensor:reducedTensor
withShape:@[ ns_numDist, ns_n_sample ]
name:nil];
newCachedGraph->resultTensor = [mpsGraph castTensor:reshapeTensor
toType:getMPSDataType(result)
@ -509,32 +528,31 @@ Tensor& multinomial_with_replacement_mps_kernel(
});
}
// update the Philox state values on each run of the same graph
MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(at::mps::detail::PHILOX_STATE_N)]];
MPSNDArray *stateNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: stateDesc] autorelease];
MPSNDArrayDescriptor* stateDesc =
[MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeInt32 shape:@[ @(at::mps::detail::PHILOX_STATE_N) ]];
MPSNDArray* stateNDArray = [[[MPSNDArray alloc] initWithDevice:stream->device() descriptor:stateDesc] autorelease];
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(mps_gen->mutex_);
// update the Philox state values on each run
mps_gen->update_philox_counters();
[stateNDArray writeBytes: mps_gen->state_data() strideBytes: nil];
[stateNDArray writeBytes:mps_gen->state_data() strideBytes:nil];
}
MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: stateNDArray] autorelease];
MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray:stateNDArray] autorelease];
auto probPlaceholder = Placeholder(cachedGraph->probTensor, self_v);
auto outputPlaceholder = Placeholder(cachedGraph->resultTensor, result_v);
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
cachedGraph->stateTensor : stateTensorData,
probPlaceholder.getMPSGraphTensor() : probPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
return result;
}
/* The largest consecutive integer representable in float32 (2^24) */
@ -545,27 +563,20 @@ Tensor& multinomial_out_mps(const Tensor& self,
bool with_replacement,
c10::optional<Generator> gen,
Tensor& result) {
TORCH_CHECK(
result.device() == self.device(),
"multinomial arguments must have the same device");
TORCH_CHECK(
self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim");
TORCH_CHECK(
at::isFloatingType(self.scalar_type()),
TORCH_CHECK(result.device() == self.device(), "multinomial arguments must have the same device");
TORCH_CHECK(self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim");
TORCH_CHECK(at::isFloatingType(self.scalar_type()),
"multinomial only supports floating-point dtypes for input, got: ",
self.scalar_type());
TORCH_CHECK(result.scalar_type() == ScalarType::Long,
"multinomial expects Long tensor out, got: ", result.scalar_type());
TORCH_CHECK(
result.scalar_type() == ScalarType::Long, "multinomial expects Long tensor out, got: ", result.scalar_type());
TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples");
int64_t n_categories = self.size(-1);
TORCH_CHECK(with_replacement || (n_sample <= n_categories),
"cannot sample n_sample > prob_dist.size(-1) samples without replacement");
// Since the index tensor is float, numCategories cannot exceed max
// float integer precision
TORCH_CHECK(
n_categories <= FLOAT32_MAX_CONSECUTIVE_INT,
"number of categories cannot exceed 2^24");
TORCH_CHECK(n_categories <= FLOAT32_MAX_CONSECUTIVE_INT, "number of categories cannot exceed 2^24");
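// 2^24 = 16,777,216 because float32 carries a 24-bit significand: every integer up to 2^24 is exactly
// representable, while 2^24 + 1 already rounds back to 2^24, so larger category indices could collide.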
if (self.dim() == 1) {
result.resize_({n_sample});
@ -583,19 +594,15 @@ Tensor& multinomial_out_mps(const Tensor& self,
if (!with_replacement || n_sample == 1) {
// Sanity checks on `self`.
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
TORCH_CHECK(
is_valid.to<bool>(),
"probability tensor contains either `inf`, `nan` or element < 0");
TORCH_CHECK(is_valid.to<bool>(), "probability tensor contains either `inf`, `nan` or element < 0");
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool zero_prob_condition;
if (self.dim() == 1){
if (self.dim() == 1) {
zero_prob_condition = (self.sum() == 0).item().to<bool>();
} else {
zero_prob_condition = (self.sum(1) == 0).sum().item().to<bool>();
}
TORCH_CHECK(
!zero_prob_condition,
"invalid multinomial distribution (sum of probabilities <= 0)");
TORCH_CHECK(!zero_prob_condition, "invalid multinomial distribution (sum of probabilities <= 0)");
// The algorithm is from gumbel softmax.
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)
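// Here -log(-log(eps)) is a standard Gumbel(0, 1) draw, and the index attaining the argmax of
// log-probabilities plus independent Gumbel noise is distributed according to softmax(logp), the
// Gumbel-max trick, so no explicit normalization of logp is needed.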
@ -625,11 +632,7 @@ Tensor& multinomial_out_mps(const Tensor& self,
return result;
}
Tensor multinomial_mps(
const Tensor& self,
int64_t n_sample,
bool with_replacement,
c10::optional<Generator> gen) {
Tensor multinomial_mps(const Tensor& self, int64_t n_sample, bool with_replacement, c10::optional<Generator> gen) {
Tensor result = at::empty({0}, self.options().dtype(kLong));
multinomial_out_mps(self, n_sample, with_replacement, gen, result);
return result;

@ -3,9 +3,8 @@
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>
#include <c10/util/Optional.h>
#include <torch/library.h>
// Steps to add op for MPS backend:
// 1. Register the op in aten/src/ATen/native/native_functions.yaml with the "MPS" dispatch key
@ -29,7 +28,6 @@
// g) Then call runMPSGraph() with input params and return the result.
//
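// For illustration only, a registration entry pairs the op schema with a per-backend dispatch table
// along these lines (the authoritative schema and kernel names live in native_functions.yaml itself):
//   - func: some_op.out(...) -> Tensor(a!)
//     dispatch:
//       MPS: some_op_out_mps
//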
namespace at::native {
Tensor& eye_out_mps(int64_t n, Tensor& result) {
@ -38,7 +36,6 @@ Tensor& eye_out_mps(int64_t n, Tensor& result) {
}
Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
// This is one example of boiler-plate error checking, taking after CPU/CUDA counterparts
TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n);
TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m);
@ -47,7 +44,7 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
result.zero_();
// Handle empty outputs
if(result.numel() == 0)
if (result.numel() == 0)
return result;
// Get MPS stream
@ -55,25 +52,24 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
MPSStream* stream = getCurrentMPSStream();
// Derive from MPSCachedGraph
// This structure is used to cache an MPSGraph with certain keys, so that we don't have to compile the same MPSGraph time and time again for the same operation
// The keys of this structure are based on the inputs and outputs needed for the operation
// Here, we don't have any input tensors, just an output tensor
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
// This structure is used to cache an MPSGraph with certain keys, so that we don't have to compile the same MPSGraph
// time and time again for the same operation The keys of this structure are based on the inputs and outputs needed
// for the operation Here, we don't have any input tensors, just an output tensor
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* outputTensor_ = nil;
};
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
// A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types etc match the earlier created MPSGraph
// A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types
// etc match the earlier created MPSGraph
string key = "eye_out_mps:" + getTensorsStringKey({result});
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if(!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
// Initialize graph
@ -84,11 +80,9 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
dataType:getMPSDataType(result)];
// Here we can call the MPSGraph API needed to execute the operation.
// The API details can be found here: https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph
MPSGraphTensor* outputTensor = [mpsGraph bandPartWithTensor:onesTensor
numLower:0
numUpper:0
name:nil];
// The API details can be found here:
// https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph
MPSGraphTensor* outputTensor = [mpsGraph bandPartWithTensor:onesTensor numLower:0 numUpper:0 name:nil];
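// Keeping zero sub-diagonals and zero super-diagonals of the all-ones matrix leaves only the main
// diagonal, so the band-part of onesTensor is exactly the (possibly rectangular) identity that eye
// is expected to produce.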
newCachedGraph->outputTensor_ = outputTensor;
}
return newCachedGraph;
@ -102,9 +96,8 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
// In this case, there are no inputs, so the feeds are nil
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
// Run the graph
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
@ -113,5 +106,4 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
return result;
}
} // namespace at::native
@ -1,12 +1,15 @@
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/GridSamplerUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
namespace at {
namespace native {
void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
void grid_sampler_2d_mps_impl(Tensor& output,
const Tensor& input,
const Tensor& grid,
int64_t interpolation_mode,
int64_t padding_mode,
bool align_corners) {
// Grid Sampler support has been added in macOS 13.1
#if defined(__MAC_13_2)
@ -18,35 +21,43 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor&
MPSGraphPaddingMode paddingMode;
auto memory_format = input.suggest_memory_format();
MPSGraphTensorNamedDataLayout inputTensorLayout =
(memory_format == at::MemoryFormat::Contiguous) ? MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutNHWC;
MPSGraphTensorNamedDataLayout inputTensorLayout = (memory_format == at::MemoryFormat::Contiguous)
? MPSGraphTensorNamedDataLayoutNCHW
: MPSGraphTensorNamedDataLayoutNHWC;
switch (static_cast<GridSamplerPadding>(padding_mode)) {
case GridSamplerPadding::Zeros:
paddingMode = MPSGraphPaddingModeZero; break;
paddingMode = MPSGraphPaddingModeZero;
break;
case GridSamplerPadding::Border:
TORCH_CHECK(false, "MPS: Unsupported Border padding mode"); break;
TORCH_CHECK(false, "MPS: Unsupported Border padding mode");
break;
case GridSamplerPadding::Reflection:
paddingMode = align_corners == true ? MPSGraphPaddingModeReflect : MPSGraphPaddingModeSymmetric; break;
paddingMode = align_corners == true ? MPSGraphPaddingModeReflect : MPSGraphPaddingModeSymmetric;
break;
default:
TORCH_CHECK(false, "MPS: Unrecognised Padding Mode: ", padding_mode);
}
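// The align_corners split above follows the usual reflect/symmetric convention: Reflect mirrors about
// the edge sample without repeating it, Symmetric repeats the edge sample, which is how grid_sample's
// reflection padding differs between align_corners=true and align_corners=false.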
switch (static_cast<GridSamplerInterpolation>(interpolation_mode)) {
case GridSamplerInterpolation::Bilinear:
samplingMode = MPSGraphResizeBilinear; break;
samplingMode = MPSGraphResizeBilinear;
break;
case GridSamplerInterpolation::Nearest:
samplingMode = MPSGraphResizeNearest; break;
samplingMode = MPSGraphResizeNearest;
break;
case GridSamplerInterpolation::Bicubic:
TORCH_CHECK(false, "MPS: Unsupported Bicubic interpolation"); break;
TORCH_CHECK(false, "MPS: Unsupported Bicubic interpolation");
break;
default:
TORCH_CHECK(false, "MPS: Unrecognised interpolation mode: ", interpolation_mode); break;
TORCH_CHECK(false, "MPS: Unrecognised interpolation mode: ", interpolation_mode);
break;
}
MPSStream *stream = getCurrentMPSStream();
MPSStream* stream = getCurrentMPSStream();
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* gridTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
@ -55,17 +66,13 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor&
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = "grid_sampler_2d_mps" +
getTensorsStringKey({input, grid}) +
":" + std::to_string(interpolation_mode) +
":" + std::to_string(padding_mode) +
":" + std::to_string(align_corners);
string key = "grid_sampler_2d_mps" + getTensorsStringKey({input, grid}) + ":" + std::to_string(interpolation_mode) +
":" + std::to_string(padding_mode) + ":" + std::to_string(align_corners);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
@ -75,27 +82,27 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor&
MPSGraphTensor* outputTensor = nil;
if (static_cast<GridSamplerInterpolation>(interpolation_mode) == GridSamplerInterpolation::Nearest) {
outputTensor = [mpsGraph sampleGridWithSourceTensor: inputTensor
coordinateTensor: gridTensor
layout: inputTensorLayout
normalizeCoordinates: TRUE
relativeCoordinates: FALSE
alignCorners: align_corners
paddingMode: paddingMode
nearestRoundingMode: MPSGraphResizeNearestRoundingModeRoundToEven
constantValue: 0.0f
name: nil];
outputTensor = [mpsGraph sampleGridWithSourceTensor:inputTensor
coordinateTensor:gridTensor
layout:inputTensorLayout
normalizeCoordinates:TRUE
relativeCoordinates:FALSE
alignCorners:align_corners
paddingMode:paddingMode
nearestRoundingMode:MPSGraphResizeNearestRoundingModeRoundToEven
constantValue:0.0f
name:nil];
} else {
outputTensor = [mpsGraph sampleGridWithSourceTensor: inputTensor
coordinateTensor: gridTensor
layout: inputTensorLayout
normalizeCoordinates: TRUE
relativeCoordinates: FALSE
alignCorners: align_corners
paddingMode: paddingMode
samplingMode: samplingMode
constantValue: 0.0f
name: nil];
outputTensor = [mpsGraph sampleGridWithSourceTensor:inputTensor
coordinateTensor:gridTensor
layout:inputTensorLayout
normalizeCoordinates:TRUE
relativeCoordinates:FALSE
alignCorners:align_corners
paddingMode:paddingMode
samplingMode:samplingMode
constantValue:0.0f
name:nil];
}
newCachedGraph->inputTensor_ = inputTensor;
@ -104,29 +111,29 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor&
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
Placeholder gridPlaceholder = Placeholder(cachedGraph->gridTensor_, grid);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
gridPlaceholder.getMPSGraphTensor() : gridPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
#endif // defined(__MAC_13_2)
}
Tensor grid_sampler_2d_mps(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
Tensor grid_sampler_2d_mps(const Tensor& input,
const Tensor& grid,
int64_t interpolation_mode,
int64_t padding_mode,
bool align_corners) {
#if defined(__MAC_13_2)
bool xcode_sdk_13_2_or_higher = true;
@ -138,17 +145,16 @@ Tensor grid_sampler_2d_mps(const Tensor& input, const Tensor& grid,
TORCH_WARN_ONCE("MPS: grid_sampler_2d op is supported natively starting from macOS 13.1. ",
"Falling back on CPU. This may have performance implications.");
return at::grid_sampler_2d(
input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners).clone().to("mps");
return at::grid_sampler_2d(input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners)
.clone()
.to("mps");
}
auto in_size = input.sizes();
auto grid_size = grid.sizes();
auto output = at::empty(
{in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options());
auto output = at::empty({in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options());
grid_sampler_2d_mps_impl(
output, input, grid, interpolation_mode, padding_mode, align_corners);
grid_sampler_2d_mps_impl(output, input, grid, interpolation_mode, padding_mode, align_corners);
return output;
}
@ -3,25 +3,25 @@
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/ceil_div.h>
#include <ATen/NativeFunctions.h>
#include <ATen/AccumulateType.h>
#include <ATen/ExpandUtils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/NativeFunctions.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/ceil_div.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/IndexKernel.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/operations/Indexing.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/Resize.h>
#include <ATen/AccumulateType.h>
#include <torch/library.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/IndexingUtils.h>
#include <c10/util/irange.h>
#include <c10/core/QScheme.h>
#include <c10/util/SmallVector.h>
#include <ATen/native/IndexKernel.h>
#include <c10/util/irange.h>
#include <torch/library.h>
#ifdef __OBJC__
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
@ -29,8 +29,7 @@
namespace at::native {
static
bool dispatchIndexKernel(TensorIteratorBase& iter,
static bool dispatchIndexKernel(TensorIteratorBase& iter,
IntArrayRef index_size,
IntArrayRef index_stride,
bool index_select,
@ -48,7 +47,7 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
MPSStream* mpsStream = getCurrentMPSStream();
id<MTLDevice> device = MPSDevice::getInstance()->device();
dispatch_sync(mpsStream->queue(), ^(){
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
NSError* error = nil;
constexpr uint32_t nOffsets = 3;
@ -59,13 +58,13 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
std::vector<uint32_t> iterShapeData(iterShape.size());
std::vector<std::array<uint32_t, nOffsets>> strides(nDim);
for (const auto i: c10::irange(iterShape.size())) {
for (const auto i : c10::irange(iterShape.size())) {
TORCH_CHECK(i <= UINT32_MAX);
iterShapeData[i] = (uint32_t)(iterShape[i]);
}
for (const auto i: c10::irange(nDim)) {
for (const auto offset: c10::irange(nOffsets)) {
for (const auto i : c10::irange(nDim)) {
for (const auto offset : c10::irange(nOffsets)) {
strides[i][offset] = iter.strides(offset)[i];
}
}
@ -73,12 +72,14 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
id<MTLFunction> kernelDataOffsetsFunction = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO = [[device newComputePipelineStateWithFunction: kernelDataOffsetsFunction
error: &error] autorelease];
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength: numThreads * sizeof(simd_uint3)
options: 0] autorelease];
TORCH_CHECK(kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
id<MTLFunction> kernelDataOffsetsFunction =
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO =
[[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
options:0] autorelease];
TORCH_CHECK(
kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
[computeEncoder setComputePipelineState:kernelDataOffsetsPSO];
[computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0];
@ -92,34 +93,34 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
kernelOffsetsTGSize = numThreads;
MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: kernelOffsetsThreadGroupSize];
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:kernelOffsetsThreadGroupSize];
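// First pass done: kernel_index_offsets has written one simd_uint3 per element into kernelDataOffsets,
// holding that element's byte offsets for the iterator's three operands, so the index kernel dispatched
// below can address arbitrarily strided tensors through a flat lookup (bound at buffer index 3).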
MTLFunctionConstantValues* constantValues = [[MTLFunctionConstantValues new] autorelease];
[constantValues setConstantValue: &num_indices type:MTLDataTypeUInt atIndex:0];
[constantValues setConstantValue:&num_indices type:MTLDataTypeUInt atIndex:0];
std::string indexFunction = getIndexFunctionName(inputTensor.scalar_type(), index_select, accumulate);
id<MTLFunction> indexKernelFunction = MPSDevice::getInstance()->metalIndexingFunction(indexFunction, constantValues);
id<MTLFunction> indexKernelFunction =
MPSDevice::getInstance()->metalIndexingFunction(indexFunction, constantValues);
id<MTLArgumentEncoder> argumentEncoder = [[indexKernelFunction newArgumentEncoderWithBufferIndex:0] autorelease];
NSUInteger argumentBufferLength = argumentEncoder.encodedLength;
id<MTLBuffer> indexAB = [[device newBufferWithLength:argumentBufferLength options:0] autorelease];
[argumentEncoder setArgumentBuffer:indexAB offset:0];
for (uint32_t idx = 0; idx < num_indices; idx++) {
const Tensor& indexTensor = iter.tensor(idx+2);
[argumentEncoder setBuffer: getMTLBufferStorage(indexTensor)
offset: indexTensor.storage_offset() * indexTensor.element_size()
atIndex: idx];
const Tensor& indexTensor = iter.tensor(idx + 2);
[argumentEncoder setBuffer:getMTLBufferStorage(indexTensor)
offset:indexTensor.storage_offset() * indexTensor.element_size()
atIndex:idx];
TORCH_CHECK(indexTensor.scalar_type() == ScalarType::Long, "index(): Expected dtype int64 for Index");
}
// FIXME: PSO needs to be cached
id<MTLComputePipelineState> indexSelectPSO = [[device newComputePipelineStateWithFunction: indexKernelFunction
error: &error] autorelease];
id<MTLComputePipelineState> indexSelectPSO = [[device newComputePipelineStateWithFunction:indexKernelFunction
error:&error] autorelease];
TORCH_CHECK(indexSelectPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
for (uint32_t idx = 0; idx < num_indices; idx++) {
const Tensor& indexTensor = iter.tensor(idx+2);
const Tensor& indexTensor = iter.tensor(idx + 2);
[computeEncoder useResource:getMTLBufferStorage(indexTensor) usage:MTLResourceUsageRead];
}
@ -129,15 +130,16 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
[computeEncoder setBytes:index_stride.data() length:sizeof(index_stride[0]) * index_stride.size() atIndex:2];
[computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3];
[computeEncoder setBuffer:inputBuffer offset:inputTensor.storage_offset() * inputTensor.element_size() atIndex:4];
[computeEncoder setBuffer:outputBuffer offset:outputTensor.storage_offset() * outputTensor.element_size() atIndex:5];
[computeEncoder setBuffer:outputBuffer
offset:outputTensor.storage_offset() * outputTensor.element_size()
atIndex:5];
NSUInteger tgSize = indexSelectPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > numThreads)
tgSize = numThreads;
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: threadGroupSize];
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];
[computeEncoder endEncoding];
mpsStream->synchronize(SyncType::COMMIT_AND_CONTINUE);
@ -147,7 +149,11 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
return true;
}
static void validateInputData(const TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, const std::string& op, bool accumulate) {
static void validateInputData(const TensorIteratorBase& iter,
IntArrayRef index_size,
IntArrayRef index_stride,
const std::string& op,
bool accumulate) {
using namespace mps;
int64_t num_indices = index_size.size();
@ -159,13 +165,11 @@ static void validateInputData(const TensorIteratorBase& iter, IntArrayRef index_
if (accumulate) {
// No atomic support for the rest of dtypes
TORCH_CHECK(inputTensor.scalar_type() == ScalarType::Float ||
inputTensor.scalar_type() == ScalarType::Int ||
TORCH_CHECK(inputTensor.scalar_type() == ScalarType::Float || inputTensor.scalar_type() == ScalarType::Int ||
inputTensor.scalar_type() == ScalarType::Bool);
} else {
TORCH_CHECK(c10::isIntegralType(inputTensor.scalar_type(), /*includesBool=*/true) ||
inputTensor.scalar_type() == ScalarType::Float ||
inputTensor.scalar_type() == ScalarType::Half,
inputTensor.scalar_type() == ScalarType::Float || inputTensor.scalar_type() == ScalarType::Half,
getMPSTypeString(inputTensor) + std::string(" not supported for index.Tensor_out"));
}
}
@ -186,41 +190,37 @@ void index_put_kernel_mps(TensorIterator& iter, IntArrayRef index_size, IntArray
}
}
static Tensor & masked_select_out_mps_impl(Tensor & result, const Tensor & self, const Tensor & mask) {
static Tensor& masked_select_out_mps_impl(Tensor& result, const Tensor& self, const Tensor& mask) {
NoNamesGuard guard;
TORCH_CHECK(mask.scalar_type() == ScalarType::Bool,
"masked_select: expected BoolTensor for mask");
TORCH_CHECK(mask.scalar_type() == ScalarType::Bool, "masked_select: expected BoolTensor for mask");
TORCH_CHECK(self.scalar_type() == result.scalar_type(),
"masked_select(): self and result must have the same scalar type");
auto mask_temp = (mask.dim() == 0)
? c10::MaybeOwned<Tensor>::owned(mask.unsqueeze(0))
: c10::MaybeOwned<Tensor>::borrowed(mask);
auto self_temp = (self.dim() == 0)
? c10::MaybeOwned<Tensor>::owned(self.unsqueeze(0))
: c10::MaybeOwned<Tensor>::borrowed(self);
auto mask_temp =
(mask.dim() == 0) ? c10::MaybeOwned<Tensor>::owned(mask.unsqueeze(0)) : c10::MaybeOwned<Tensor>::borrowed(mask);
auto self_temp =
(self.dim() == 0) ? c10::MaybeOwned<Tensor>::owned(self.unsqueeze(0)) : c10::MaybeOwned<Tensor>::borrowed(self);
// Cannot reassign to mask_temp and self_temp here! if they are
// owning and expand_outplace returns a borrow, the returned borrow
// would dangle.
auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp);
at::index_out(
result, *std::get<1>(mask_self_expanded),
at::index_out(result,
*std::get<1>(mask_self_expanded),
c10::List<c10::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))}));
return result;
}
static
Tensor nonzero_fallback(const Tensor& self) {
static Tensor nonzero_fallback(const Tensor& self) {
TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performance implications.");
return at::nonzero(self.to("cpu")).clone().to("mps");
}
Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){
Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
if (!is_macos_13_or_newer()) {
Tensor out_fallback = nonzero_fallback(self);
at::native::resize_output(out_, out_fallback.sizes());
@ -237,18 +237,22 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){
using namespace mps;
const uint32_t maxDimensions = 16;
TORCH_CHECK(self.numel() < std::numeric_limits<int>::max(), "nonzero is not supported for tensors with more than INT_MAX elements, \
TORCH_CHECK(self.numel() < std::numeric_limits<int>::max(),
"nonzero is not supported for tensors with more than INT_MAX elements, \
file a support request");
TORCH_CHECK(out_.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out_.dtype());
TORCH_CHECK(self.device() == out_.device(), "expected self and out to be on the same device, but got out on ",
out_.device(), " and self on ", self.device());
TORCH_CHECK(
out_.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out_.dtype());
TORCH_CHECK(self.device() == out_.device(),
"expected self and out to be on the same device, but got out on ",
out_.device(),
" and self on ",
self.device());
TORCH_CHECK(self.dim() <= maxDimensions, "nonzero is not supported for tensor with more than ", 16, " dimensions");
TORCH_CHECK(out_.is_mps());
MPSStream *stream = getCurrentMPSStream();
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSStream* stream = getCurrentMPSStream();
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
MPSGraphTensor* scatterDataTensor_ = nil;
@ -258,19 +262,14 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){
stream->synchronize(SyncType::COMMIT_AND_WAIT);
Tensor count_nonzero = at::empty({1}, self.options().dtype(kInt));
Tensor out = at::native::empty_mps(
{self.numel(), nDim == 0 ? 1 : nDim},
out_.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
c10::nullopt);
{self.numel(), nDim == 0 ? 1 : nDim}, out_.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
int64_t _apparentInputShape = 1;
for (auto dim : self.sizes()) {
_apparentInputShape *= dim;
}
MPSShape *apparentOutputShape = @[@(self.numel() * nDim)];
MPSShape *apparentInputShape = @[@(_apparentInputShape)];
MPSShape* apparentOutputShape = @[ @(self.numel() * nDim) ];
MPSShape* apparentInputShape = @[ @(_apparentInputShape) ];
// Pseudocode:
//
@ -284,67 +283,68 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = "nonzero_out_mps" + getTensorsStringKey(self);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSDataType inputDataType = getMPSDataType(self);
MPSShape* inputShape = getMPSShape(self);
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), apparentInputShape);
MPSGraphTensor *scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(out.scalar_type()));
MPSGraphTensor *zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputDataType];
MPSGraphTensor *oneTensor = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeInt32];
MPSGraphTensor *minusMaxDimTensor = [mpsGraph constantWithScalar:-maxDimensions dataType:MPSDataTypeInt32];
MPSGraphTensor *inputNotEqualToZeroTensor = [mpsGraph notEqualWithPrimaryTensor:inputTensor
MPSGraphTensor* inputTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), apparentInputShape);
MPSGraphTensor* scatterDataTensor =
mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(out.scalar_type()));
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputDataType];
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeInt32];
MPSGraphTensor* minusMaxDimTensor = [mpsGraph constantWithScalar:-maxDimensions dataType:MPSDataTypeInt32];
MPSGraphTensor* inputNotEqualToZeroTensor = [mpsGraph notEqualWithPrimaryTensor:inputTensor
secondaryTensor:zeroTensor
name:nil];
MPSGraphTensor *countNonzero = [mpsGraph reductionSumWithTensor:inputNotEqualToZeroTensor
axis:0
name:nil];
MPSGraphTensor *maskTensor = [mpsGraph castTensor:inputNotEqualToZeroTensor
MPSGraphTensor* countNonzero = [mpsGraph reductionSumWithTensor:inputNotEqualToZeroTensor axis:0 name:nil];
MPSGraphTensor* maskTensor = [mpsGraph castTensor:inputNotEqualToZeroTensor
toType:MPSDataTypeInt32
name:@"castToInt32"];
MPSGraphTensor *indicesTensor = [mpsGraph cumulativeSumWithTensor:maskTensor
axis:0
name:nil];
MPSGraphTensor *indicesMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:indicesTensor
MPSGraphTensor* indicesTensor = [mpsGraph cumulativeSumWithTensor:maskTensor axis:0 name:nil];
MPSGraphTensor* indicesMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:indicesTensor
secondaryTensor:oneTensor
name:nil];
MPSGraphTensor *maskedIndicesTensor = [mpsGraph selectWithPredicateTensor:inputNotEqualToZeroTensor
MPSGraphTensor* maskedIndicesTensor = [mpsGraph selectWithPredicateTensor:inputNotEqualToZeroTensor
truePredicateTensor:indicesMinusOneTensor
falsePredicateTensor:minusMaxDimTensor
name:nil];
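// cumsum(mask) - 1 assigns each nonzero element its zero-based rank among the nonzeros, which is the
// output row it belongs to, while zero elements are routed to the negative sentinel -maxDimensions;
// the scatter further down writes the coordinates into those rows (with per-dimension offsets added
// in the nDim > 1 branch).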
MPSGraphTensor *coordinatesTensor = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:0 withShape:inputShape name:nil]
withShape:@[@-1]
MPSGraphTensor* coordinatesTensor = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:0
withShape:inputShape
name:nil]
withShape:@[ @-1 ]
name:nil];
if (nDim > 1) {
NSMutableArray<MPSGraphTensor*> *maskedIndicesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
NSMutableArray<MPSGraphTensor*> *coordinatesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
NSMutableArray<MPSGraphTensor*>* maskedIndicesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
NSMutableArray<MPSGraphTensor*>* coordinatesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
MPSGraphTensor *constantRankTensor = [mpsGraph constantWithScalar:nDim
dataType:MPSDataTypeInt32];
MPSGraphTensor* constantRankTensor = [mpsGraph constantWithScalar:nDim dataType:MPSDataTypeInt32];
maskedIndicesTensorArray[0] = [mpsGraph multiplicationWithPrimaryTensor:maskedIndicesTensor
secondaryTensor:constantRankTensor
name:nil];
coordinatesTensorArray[0] = coordinatesTensor;
for (int i = 1; i < nDim; i++){
for (int i = 1; i < nDim; i++) {
maskedIndicesTensorArray[i] = [mpsGraph additionWithPrimaryTensor:maskedIndicesTensorArray[i - 1]
secondaryTensor:oneTensor
name:nil];
coordinatesTensorArray[i] = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:i withShape:inputShape name:nil]
withShape:@[@-1]
coordinatesTensorArray[i] = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:i
withShape:inputShape
name:nil]
withShape:@[ @-1 ]
name:nil];
}
maskedIndicesTensor = [mpsGraph concatTensors:maskedIndicesTensorArray dimension:0 interleave:YES name:nil];
coordinatesTensor = [mpsGraph concatTensors:coordinatesTensorArray dimension:0 interleave:YES name:nil];
}
MPSGraphTensor *outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor
MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor
updatesTensor:coordinatesTensor
indicesTensor:maskedIndicesTensor
axis:0
@ -358,7 +358,7 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, apparentInputShape);
@ -386,7 +386,7 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){
return out_;
}
Tensor nonzero_mps(const Tensor& self){
Tensor nonzero_mps(const Tensor& self) {
if (!is_macos_13_or_newer()) {
return nonzero_fallback(self);
}
@ -395,13 +395,13 @@ Tensor nonzero_mps(const Tensor& self){
return nonzero_out_mps(self, out);
}
Tensor masked_select_mps(const Tensor & self, const Tensor & mask) {
Tensor masked_select_mps(const Tensor& self, const Tensor& mask) {
namedinference::compute_broadcast_outnames(self, mask);
Tensor result = at::empty({0}, self.options());
return masked_select_out_mps_impl(result, self, mask);
}
Tensor & masked_select_out_mps(const Tensor & self, const Tensor & mask, Tensor & result) {
Tensor& masked_select_out_mps(const Tensor& self, const Tensor& mask, Tensor& result) {
namedinference::compute_broadcast_outnames(self, mask);
return masked_select_out_mps_impl(result, self, mask);
}
@ -409,27 +409,22 @@ Tensor & masked_select_out_mps(const Tensor & self, const Tensor & mask, Tensor
Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
using namespace mps;
Tensor result = at::native::empty_mps(
self.sizes(),
self.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
c10::nullopt);
Tensor result =
at::native::empty_mps(self.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
auto total_dims = self.dim();
// It wraps the dims and checks that there are no repeated dims
auto flip_dims_b = at::dim_list_to_bitset(dims, total_dims);
NSMutableArray<NSNumber*> * ns_dims = [[NSMutableArray<NSNumber*> new] autorelease];
NSMutableArray<NSNumber*>* ns_dims = [[NSMutableArray<NSNumber*> new] autorelease];
for (const auto i : c10::irange(total_dims)) {
if(flip_dims_b[i] && self.size(i) > 1 && self.stride(i) != 0) {
if (flip_dims_b[i] && self.size(i) > 1 && self.stride(i) != 0) {
[ns_dims addObject:[NSNumber numberWithInt:i]];
}
}
// Nothing to do, we return fast
if (dims.size() == 0 || self.numel() <=1) {
if (dims.size() == 0 || self.numel() <= 1) {
result.copy_(self);
return result;
}
@ -451,22 +446,20 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
}
@autoreleasepool {
NSString* ns_dims_key = [[ns_dims valueForKey:@"description"] componentsJoinedByString:@","];
// A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types etc match the earlier created MPSGraph
// A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types
// etc match the earlier created MPSGraph
string key = "flip_mps:" + getTensorsStringKey({self}) + ":" + string([ns_dims_key UTF8String]);
auto cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if(!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(self));
MPSGraphTensor* outputTensor = [mpsGraph reverseTensor:inputTensor
axes:ns_dims
name:nil];
MPSGraphTensor* outputTensor = [mpsGraph reverseTensor:inputTensor axes:ns_dims name:nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
}
@ -475,36 +468,31 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
}
// Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation
Placeholder inputPlaceholder = Placeholder(
cachedGraph->inputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/true, inputDataType);
Placeholder outputPlaceholder = Placeholder(
cachedGraph->outputTensor_, result, /*mpsShape*/nil, /*gatherTensorData=*/false, outputDataType);
Placeholder inputPlaceholder =
Placeholder(cachedGraph->inputTensor_, self, /*mpsShape*/ nil, /*gatherTensorData=*/true, inputDataType);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, result, /*mpsShape*/ nil, /*gatherTensorData=*/false, outputDataType);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
@{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
// Run the graph
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
return result;
}
TORCH_IMPL_FUNC(index_add_mps_out)(
const Tensor& self,
TORCH_IMPL_FUNC(index_add_mps_out)
(const Tensor& self,
int64_t dim,
const Tensor& index,
const Tensor& source,
const Scalar& alpha,
const Tensor& result) {
using namespace mps;
MPSStream* stream = getCurrentMPSStream();
dim = maybe_wrap_dim(dim, self.dim());
@ -515,9 +503,8 @@ TORCH_IMPL_FUNC(index_add_mps_out)(
TORCH_CHECK(source.scalar_type() != ScalarType::Long, "index_add(): Expected non int64 dtype for source.");
auto casted_type = isFloatingType(source.scalar_type()) ? ScalarType::Float : ScalarType::Int;
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* indexTensor_ = nil;
MPSGraphTensor* sourceTensor_ = nil;
@ -528,13 +515,12 @@ TORCH_IMPL_FUNC(index_add_mps_out)(
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = "index_add_mps_out" + getTensorsStringKey({self, index, source}) + ":" + std::to_string(dim);
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if(!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
@ -585,17 +571,14 @@ TORCH_IMPL_FUNC(index_add_mps_out)(
sourcePlaceholder.getMPSGraphTensor() : sourcePlaceholder.getMPSGraphTensorData(),
cachedGraph->alphaTensor_ : getMPSGraphTensorFromScalar(stream, alpha_scalar),
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
}
Tensor index_select_mps(const Tensor & self,
int64_t dim,
const Tensor & index) {
Tensor index_select_mps(const Tensor& self, int64_t dim, const Tensor& index) {
IntArrayRef input_shape = self.sizes();
auto num_input_dims = input_shape.size();
@ -606,7 +589,7 @@ Tensor index_select_mps(const Tensor & self,
std::vector<int64_t> shape_data(num_input_dims);
// Calculate new shape
for(auto i : c10::irange(num_input_dims)) {
for (auto i : c10::irange(num_input_dims)) {
if (i == dim) {
shape_data[i] = num_indices;
} else {
@ -616,33 +599,24 @@ Tensor index_select_mps(const Tensor & self,
IntArrayRef output_shape = IntArrayRef(shape_data.data(), num_input_dims);
Tensor result = at::native::empty_mps(
output_shape,
self.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
c10::nullopt);
Tensor result =
at::native::empty_mps(output_shape, self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
index_select_out_mps(self, dim, index, result);
return result;
}
Tensor& index_select_out_mps(const Tensor & self,
int64_t dim,
const Tensor & index,
Tensor & output) {
Tensor& index_select_out_mps(const Tensor& self, int64_t dim, const Tensor& index, Tensor& output) {
using namespace mps;
MPSStream* stream = getCurrentMPSStream();
dim = maybe_wrap_dim(dim, self.dim());
// Checks
TORCH_CHECK_INDEX(index.dim() <= 1, "index_select(): Index is supposed to be a vector");
TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_select(): Expected dtype int32 or int64 for index");
TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int,
"index_select(): Expected dtype int32 or int64 for index");
TORCH_CHECK(self.scalar_type() == output.scalar_type(),
"index_select(): self and output must have the same scalar type");
TORCH_CHECK(dim == 0 || dim < self.dim(),
"index_select(): Indexing dim ", dim, " is out of bounds of tensor");
TORCH_CHECK(dim == 0 || dim < self.dim(), "index_select(): Indexing dim ", dim, " is out of bounds of tensor");
// Empty index
if (index.numel() == 0) {
@ -650,15 +624,14 @@ Tensor& index_select_out_mps(const Tensor & self,
}
// Scalar input
if (self.dim() == 0 && self.numel() == 1){
if (self.dim() == 0 && self.numel() == 1) {
output.copy_(self);
return output;
}
// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* indexTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
@ -667,23 +640,20 @@ Tensor& index_select_out_mps(const Tensor & self,
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
auto inputType = getMPSDataType(self);
auto outputType = getMPSDataType(output);
if (inputType == MPSDataTypeUInt8 ||
(!is_macos_13_or_newer() && inputType == MPSDataTypeBool)) {
if (inputType == MPSDataTypeUInt8 || (!is_macos_13_or_newer() && inputType == MPSDataTypeBool)) {
inputType = MPSDataTypeInt8;
}
if (outputType == MPSDataTypeUInt8 ||
(!is_macos_13_or_newer() && outputType == MPSDataTypeBool)) {
if (outputType == MPSDataTypeUInt8 || (!is_macos_13_or_newer() && outputType == MPSDataTypeBool)) {
outputType = MPSDataTypeInt8;
}
@autoreleasepool {
string key = "index_select_out_mps" + getTensorsStringKey({self, index}) + ":" + std::to_string(dim);
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if(!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
@ -706,48 +676,55 @@ Tensor& index_select_out_mps(const Tensor & self,
});
}
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self,
/*mpsShape=*/nullptr, /*gatherTensorData=*/true, /*dataType=*/inputType);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_,
self,
/*mpsShape=*/nullptr,
/*gatherTensorData=*/true,
/*dataType=*/inputType);
Placeholder indexPlaceholder = Placeholder(cachedGraph->indexTensor_, index);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output,
/*mpsShape=*/nullptr, /*gatherTensorData=*/false, /*dataType=*/outputType);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_,
output,
/*mpsShape=*/nullptr,
/*gatherTensorData=*/false,
/*dataType=*/outputType);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
indexPlaceholder.getMPSGraphTensor() : indexPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
return output;
}
Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value) {
Tensor& masked_fill__mps(Tensor& self, const Tensor& mask, const Scalar& value) {
using namespace mps;
if (self.numel() == 0) {
return self;
}
TORCH_CHECK(self.device() == mask.device(), "expected self and mask to be on the same device, but got mask on ",
mask.device(), " and self on ", self.device());
TORCH_CHECK(self.device() == mask.device(),
"expected self and mask to be on the same device, but got mask on ",
mask.device(),
" and self on ",
self.device());
TORCH_CHECK(mask.scalar_type() == kByte || mask.scalar_type() == kBool,
"expected mask dtype to be Bool but got ", mask.scalar_type());
"expected mask dtype to be Bool but got ",
mask.scalar_type());
auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_");
c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_fill_");
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *maskTensor_ = nil;
MPSGraphTensor *valueTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* maskTensor_ = nil;
MPSGraphTensor* valueTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@ -770,10 +747,9 @@ Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value
@autoreleasepool {
string key = "masked_fill" + getTensorsStringKey({self, *b_mask}) + ":" + getMPSTypeString(value.type());
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if(!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
@ -786,9 +762,7 @@ Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value
MPSDataType valueType = getMPSScalarType(value.type());
MPSGraphTensor* castValueTensor = valueTensor;
if (valueType != inputDataType) {
castValueTensor = [mpsGraph castTensor:valueTensor
toType:inputDataType
name:@"castValueTensor"];
castValueTensor = [mpsGraph castTensor:valueTensor toType:inputDataType name:@"castValueTensor"];
}
MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:maskTensor
@ -805,12 +779,12 @@ Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value
});
}
Placeholder selfPlaceholder = Placeholder(
cachedGraph->inputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/true, inputDataType);
Placeholder maskPlaceholder = Placeholder(
cachedGraph->maskTensor_, *b_mask, /*mpsShape*/nil, /*gatherTensorData=*/true, maskDataType);
Placeholder outputPlaceholder = Placeholder(
cachedGraph->outputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/false, inputDataType);
Placeholder selfPlaceholder =
Placeholder(cachedGraph->inputTensor_, self, /*mpsShape*/ nil, /*gatherTensorData=*/true, inputDataType);
Placeholder maskPlaceholder =
Placeholder(cachedGraph->maskTensor_, *b_mask, /*mpsShape*/ nil, /*gatherTensorData=*/true, maskDataType);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, self, /*mpsShape*/ nil, /*gatherTensorData=*/false, inputDataType);
// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@ -819,9 +793,8 @@ Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value
cachedGraph->valueTensor_ : getMPSGraphTensorFromScalar(stream, valueScalar)
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -829,18 +802,18 @@ Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value
return self;
}
Tensor embedding_dense_backward_mps(
const Tensor & grad_, const Tensor & indices, int64_t num_weights,
int64_t padding_idx, bool scale_grad_by_freq)
{
Tensor embedding_dense_backward_mps(const Tensor& grad_,
const Tensor& indices,
int64_t num_weights,
int64_t padding_idx,
bool scale_grad_by_freq) {
// TODO: implement padding_idx & scale_grad_by_freq.
namespace native_mps = at::native::mps;
struct CachedGraph : public native_mps::MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *incomingGradTensor_ = nil;
MPSGraphTensor *indicesTensor_ = nil;
MPSGraphTensor *outgoingGradTensor_ = nil;
struct CachedGraph : public native_mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* incomingGradTensor_ = nil;
MPSGraphTensor* indicesTensor_ = nil;
MPSGraphTensor* outgoingGradTensor_ = nil;
};
native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
@ -854,12 +827,7 @@ Tensor embedding_dense_backward_mps(
int64_t D = incoming_gradient_shape[num_incoming_gradient_dims - 1];
c10::SmallVector<int64_t, 2> outgoing_gradient_shape{num_weights, D};
Tensor outgoing_gradient = at::native::empty_mps(
IntArrayRef(outgoing_gradient_shape),
grad_.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
c10::nullopt);
IntArrayRef(outgoing_gradient_shape), grad_.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
if (outgoing_gradient.numel() == 0) {
return outgoing_gradient;
@ -868,21 +836,24 @@ Tensor embedding_dense_backward_mps(
auto stream = at::mps::getCurrentMPSStream();
@autoreleasepool {
string key = "edb_mps:" + native_mps::getMPSTypeString(grad_) + ":indices" + std::to_string(num_indices_dims) + ":num_weights" + std::to_string(num_weights) + ":padding_idx" + std::to_string(padding_idx) + ":scaled" + std::to_string(scale_grad_by_freq);
string key = "edb_mps:" + native_mps::getMPSTypeString(grad_) + ":indices" + std::to_string(num_indices_dims) +
":num_weights" + std::to_string(num_weights) + ":padding_idx" + std::to_string(padding_idx) + ":scaled" +
std::to_string(scale_grad_by_freq);
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
// Initialize once if configuration not found in cache
if(!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ native_mps::MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^native_mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = native_mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor* incomingGradTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(grad_));
MPSGraphTensor* incomingGradTensor =
native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(grad_));
MPSGraphTensor* indicesTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(indices));
MPSGraphTensor* indicesTensor =
native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(indices));
MPSGraphTensor* reshapedIndicesTensor = indicesTensor;
@ -890,31 +861,27 @@ Tensor embedding_dense_backward_mps(
MPSDataType dataType = mps::getMPSDataType(grad_);
// issue 105486100, scatterNDWithUpdatesTensor produces wrong result for float16
if (dataType == MPSDataTypeFloat16) {
castGradTensor = [mpsGraph castTensor: incomingGradTensor
toType: MPSDataTypeFloat32
name: @"castGradTensor"];
castGradTensor = [mpsGraph castTensor:incomingGradTensor toType:MPSDataTypeFloat32 name:@"castGradTensor"];
}
if (num_indices_dims != 0) {
reshapedIndicesTensor = [mpsGraph expandDimsOfTensor: indicesTensor
axes: @[@-1]
name: nil];
reshapedIndicesTensor = [mpsGraph expandDimsOfTensor:indicesTensor axes:@[ @-1 ] name:nil];
}
auto outgoingGradTensor = [mpsGraph scatterNDWithUpdatesTensor: castGradTensor
indicesTensor: reshapedIndicesTensor
shape: native_mps::getMPSShape(IntArrayRef(outgoing_gradient_shape))
batchDimensions: 0
mode: MPSGraphScatterModeAdd
name: @"edb"];
auto outgoingGradTensor =
[mpsGraph scatterNDWithUpdatesTensor:castGradTensor
indicesTensor:reshapedIndicesTensor
shape:native_mps::getMPSShape(IntArrayRef(outgoing_gradient_shape))
batchDimensions:0
mode:MPSGraphScatterModeAdd
name:@"edb"];
if (dataType == MPSDataTypeFloat16) {
outgoingGradTensor = [mpsGraph castTensor: outgoingGradTensor
toType: MPSDataTypeFloat16
name: @"castGradTensor"];
outgoingGradTensor = [mpsGraph castTensor:outgoingGradTensor
toType:MPSDataTypeFloat16
name:@"castGradTensor"];
}
newCachedGraph->incomingGradTensor_ = incomingGradTensor;
newCachedGraph->indicesTensor_ = indicesTensor;
newCachedGraph->outgoingGradTensor_ = outgoingGradTensor;
}
return newCachedGraph;
});
@ -923,29 +890,30 @@ Tensor embedding_dense_backward_mps(
auto indicesPlaceholder = native_mps::Placeholder(cachedGraph->indicesTensor_, indices);
auto outgoingGradPlaceholder = native_mps::Placeholder(cachedGraph->outgoingGradTensor_, outgoing_gradient);
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
incomingGradPlaceholder.getMPSGraphTensor() : incomingGradPlaceholder.getMPSGraphTensorData(),
indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outgoingGradPlaceholder.getMPSGraphTensor() : outgoingGradPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outgoingGradPlaceholder.getMPSGraphTensor() : outgoingGradPlaceholder.getMPSGraphTensorData()};
native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
return outgoing_gradient;
}
Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Tensor & value) {
TORCH_CHECK(value.dim() == 0, "masked_fill_ only supports a 0-dimensional value tensor, but got tensor "
"with ", value.dim(), " dimension(s).");
Tensor& masked_fill__mps(Tensor& self, const Tensor& mask, const Tensor& value) {
TORCH_CHECK(value.dim() == 0,
"masked_fill_ only supports a 0-dimensional value tensor, but got tensor "
"with ",
value.dim(),
" dimension(s).");
return masked_fill__mps(self, mask, value.item());
}
Tensor & masked_scatter__mps(Tensor& self, const Tensor& mask, const Tensor& source) {
Tensor& masked_scatter__mps(Tensor& self, const Tensor& mask, const Tensor& source) {
at::assert_no_internal_overlap(self);
TORCH_CHECK(
self.scalar_type() == source.scalar_type(),
TORCH_CHECK(self.scalar_type() == source.scalar_type(),
"masked_scatter: expected self and source to have same dtypes but got",
self.scalar_type(),
" and ",
@ -958,25 +926,22 @@ Tensor & masked_scatter__mps(Tensor& self, const Tensor& mask, const Tensor& sou
TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool,
"masked_scatter: expected BoolTensor or ByteTensor for mask");
auto mask_temp = (mask.dim() == 0)
? c10::MaybeOwned<Tensor>::owned(mask.unsqueeze(0))
: c10::MaybeOwned<Tensor>::borrowed(mask);
auto self_temp = (self.dim() == 0)
? c10::MaybeOwned<Tensor>::owned(self.unsqueeze(0))
: c10::MaybeOwned<Tensor>::borrowed(self);
auto mask_temp =
(mask.dim() == 0) ? c10::MaybeOwned<Tensor>::owned(mask.unsqueeze(0)) : c10::MaybeOwned<Tensor>::borrowed(mask);
auto self_temp =
(self.dim() == 0) ? c10::MaybeOwned<Tensor>::owned(self.unsqueeze(0)) : c10::MaybeOwned<Tensor>::borrowed(self);
// Cannot reassign to mask_temp and self_temp here! if they are
// owning and expand_outplace returns a borrow, the returned borrow
// would dangle.
auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp);
auto indices = at::native::expandTensors(
*std::get<1>(mask_self_expanded),
c10::List<c10::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))})
);
auto indices =
at::native::expandTensors(*std::get<1>(mask_self_expanded),
c10::List<c10::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))}));
// next broadcast all index tensors together
try {
indices = at::expand_outplace(indices);
} catch (std::exception &e) {
} catch (std::exception& e) {
TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together");
}
@ -987,15 +952,10 @@ Tensor & masked_scatter__mps(Tensor& self, const Tensor& mask, const Tensor& sou
c10::List<c10::optional<Tensor>> final_indices;
final_indices.reserve(indices.size());
for (const auto index: indices) {
for (const auto index : indices) {
final_indices.push_back(index);
}
return at::index_put_out(
self,
*std::get<1>(mask_self_expanded),
final_indices,
source.resize_(indices[0].numel())
);
return at::index_put_out(self, *std::get<1>(mask_self_expanded), final_indices, source.resize_(indices[0].numel()));
}
REGISTER_DISPATCH(index_stub, &index_kernel_mps);

View File

@ -1,17 +1,16 @@
#include <ATen/ATen.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <torch/library.h>
#include <ATen/native/mps/OperationUtils.h>
#include <c10/util/Optional.h>
#include <torch/library.h>
namespace at::native {
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info)
{
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
TORCH_WARN_ONCE("torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU.");
TORCH_WARN_ONCE(
"torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU.");
auto cpu_info = at::empty({0}, kInt, c10::nullopt, kCPU, c10::nullopt, c10::nullopt);
auto cpu_result = result.clone().to("cpu");
at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu"));
@ -28,9 +27,8 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const
return;
}
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
@ -46,39 +44,33 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const
@autoreleasepool {
string key = "inv_out_mps" + getTensorsStringKey({A});
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph)
{
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor* inputTensor= mpsGraphRankedPlaceHolder(mpsGraph, A);
MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor: inputTensor
name: nil];
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, A);
MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor:inputTensor name:nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, isContiguous ? result : output);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
@{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
if (!isContiguous) {

View File

@ -6,17 +6,14 @@ namespace at::native {
using namespace mps;
Tensor _mps_linear(
const Tensor& input,
const Tensor& weight_arg,
const c10::optional<Tensor>& bias_opt) {
Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const c10::optional<Tensor>& bias_opt) {
// wT = transpose(weight);
// y=x*wT+b
auto weight = (weight_arg.dim() == 1) ? weight_arg.view({1, weight_arg.size(0)}) : weight_arg;
TORCH_CHECK(input.scalar_type() == ScalarType::Float ||
input.scalar_type() == ScalarType::Half, "MPS device does not support linear for non-float inputs");
TORCH_CHECK(input.scalar_type() == ScalarType::Float || input.scalar_type() == ScalarType::Half,
"MPS device does not support linear for non-float inputs");
const Tensor& bias = *(at::borrow_from_optional_tensor(bias_opt));
bool is_bias_defined = bias.defined();
@ -24,24 +21,19 @@ Tensor _mps_linear(
auto input_size = input.sizes();
std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
output_size.push_back(weight.size(0));
Tensor output = at::native::empty_mps(output_size,
input.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
input.suggest_memory_format());
Tensor output = at::native::empty_mps(
output_size, input.scalar_type(), c10::nullopt, kMPS, c10::nullopt, input.suggest_memory_format());
TORCH_CHECK(output.is_mps());
if(output.numel() == 0) {
if (output.numel() == 0) {
return output;
}
MPSStream *stream = getCurrentMPSStream();
MPSStream* stream = getCurrentMPSStream();
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* weightTensor_ = nil;
MPSGraphTensor* biasTensor_ = nil;
@ -51,14 +43,12 @@ Tensor _mps_linear(
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = "mps_linear" + getTensorsStringKey({input, weight, bias}) ;
string key = "mps_linear" + getTensorsStringKey({input, weight, bias});
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if(!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
@ -71,14 +61,11 @@ Tensor _mps_linear(
name:nil];
MPSGraphTensor* outputTensor = nil;
if (!is_bias_defined)
{
if (!is_bias_defined) {
outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:inputTensor
secondaryTensor:weightTransposeTensor
name:nil];
}
else
{
} else {
MPSGraphTensor* inputFlattened = inputTensor;
bool doReshape = false;
// workaround to improve the performance with 3D+ inputs
@ -94,7 +81,8 @@ Tensor _mps_linear(
MPSGraphTensor* biasedTensor = [mpsGraph additionWithPrimaryTensor:xMulWTTensor
secondaryTensor:newCachedGraph->biasTensor_
name:nil];
outputTensor = doReshape ? [mpsGraph reshapeTensor:biasedTensor withShape:getMPSShape(output_size) name:nil] : biasedTensor;
outputTensor = doReshape ? [mpsGraph reshapeTensor:biasedTensor withShape:getMPSShape(output_size) name:nil]
: biasedTensor;
}
newCachedGraph->inputTensor_ = inputTensor;
@ -110,89 +98,76 @@ Tensor _mps_linear(
Placeholder biasPlaceholder = Placeholder();
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =[NSMutableDictionary dictionary];
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData();
if (is_bias_defined) {
biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias);
feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData();
}
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
// Shave off '1' present at the end of the shape
if(weight_arg.dim() == 1) {
if (weight_arg.dim() == 1) {
// Number of elements in new output shape
auto output_sizes = output.sizes();
std::vector<int64_t> out_shape(output_sizes.begin(), output_sizes.end()-1);
std::vector<int64_t> out_shape(output_sizes.begin(), output_sizes.end() - 1);
return output.view(IntArrayRef(out_shape));
}
return output;
}
Tensor _mps_linear_backward_input(
IntArrayRef input_size,
const Tensor & grad_output,
const Tensor & weight)
{
TORCH_CHECK(grad_output.is_mps(),
"mps_linear_backward: grad_output needs to be mps layout");
TORCH_CHECK(weight.device().is_mps() &&
(weight.scalar_type() == kFloat || (weight.scalar_type() == kHalf)),
"mps_linear_backward: unsupported weights data type: ", weight.scalar_type());
Tensor _mps_linear_backward_input(IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight) {
TORCH_CHECK(grad_output.is_mps(), "mps_linear_backward: grad_output needs to be mps layout");
TORCH_CHECK(weight.device().is_mps() && (weight.scalar_type() == kFloat || (weight.scalar_type() == kHalf)),
"mps_linear_backward: unsupported weights data type: ",
weight.scalar_type());
TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double
|| grad_output.scalar_type() == ScalarType::Float
|| grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs");
TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double || grad_output.scalar_type() == ScalarType::Float ||
grad_output.scalar_type() == ScalarType::Half,
"MPS device does not support linear backward for non-float inputs");
const Tensor weight_reshaped = weight.is_contiguous() ? weight : weight.contiguous();
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *weightTensor_ = nil;
MPSGraphTensor *gradOutputTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* weightTensor_ = nil;
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
Tensor output = at::native::empty_mps(input_size,
grad_output.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
grad_output.suggest_memory_format());
Tensor output = at::native::empty_mps(
input_size, grad_output.scalar_type(), c10::nullopt, kMPS, c10::nullopt, grad_output.suggest_memory_format());
TORCH_CHECK(output.is_mps());
if (grad_output.numel() == 0) {
return output;
}
MPSGraphCache *cache_ = MPSGraphCache::getInstance();
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
MPSStream *stream= getCurrentMPSStream();
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
string key = "mps_linear_backward_input" + getTensorsStringKey({grad_output, weight_reshaped});
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if(!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph *mpsGraph = make_mps_graph();
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor *weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_reshaped);
MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_reshaped);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor *outputTensor =
[mpsGraph matrixMultiplicationWithPrimaryTensor: gradOutputTensor
secondaryTensor: weightTensor
name: nil];
MPSGraphTensor* outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:gradOutputTensor
secondaryTensor:weightTensor
name:nil];
newCachedGraph->weightTensor_ = weightTensor;
newCachedGraph->gradOutputTensor_ = gradOutputTensor;
@ -211,9 +186,8 @@ Tensor _mps_linear_backward_input(
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
@ -221,27 +195,27 @@ Tensor _mps_linear_backward_input(
}
}
std::tuple<Tensor, Tensor> _mps_linear_backward_weights(
const Tensor& grad_output, const Tensor& input, const Tensor& weight, bool bias_defined)
{
std::tuple<Tensor, Tensor> _mps_linear_backward_weights(const Tensor& grad_output,
const Tensor& input,
const Tensor& weight,
bool bias_defined) {
TORCH_CHECK(grad_output.is_mps() && input.is_mps(),
"_mps_linear_backward: grad_output and input needs to be mps layout");
TORCH_CHECK(grad_output.scalar_type() == ScalarType::Float ||
grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs");
TORCH_CHECK(grad_output.scalar_type() == ScalarType::Float || grad_output.scalar_type() == ScalarType::Half,
"MPS device does not support linear backward for non-float inputs");
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *weightTensor_ = nil;
MPSGraphTensor *gradOutputTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
MPSGraphTensor *biasTensor_ = nil;
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* weightTensor_ = nil;
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
MPSGraphTensor* biasTensor_ = nil;
};
auto grad_output_reshaped = grad_output.dim() != 2 ?
grad_output.reshape({-1, grad_output.size(grad_output.dim() - 1)}) : grad_output;
auto grad_output_reshaped =
grad_output.dim() != 2 ? grad_output.reshape({-1, grad_output.size(grad_output.dim() - 1)}) : grad_output;
auto input_reshaped = input.dim() != 2 ? input.reshape({-1, input.size(input.dim() - 1)}) : input;
TORCH_CHECK(grad_output_reshaped.is_mps());
@ -265,48 +239,41 @@ std::tuple<Tensor, Tensor> _mps_linear_backward_weights(
if (grad_output.numel() == 0) {
output.zero_();
bias.zero_();
return std::tuple<Tensor, Tensor>{ output, bias };
return std::tuple<Tensor, Tensor>{output, bias};
}
MPSGraphCache *cache_ = MPSGraphCache::getInstance();
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
MPSStream *stream= getCurrentMPSStream();
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
string key = "mps_linear_backward_weights:" + to_string(bias_defined) + ":" +
getTensorsStringKey({input_reshaped, weight, grad_output_reshaped});
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if(!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph *mpsGraph = make_mps_graph();
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped);
MPSGraphTensor *weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);
MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output_reshaped);
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped);
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output_reshaped);
MPSGraphTensor *gradOutputTransposeTensor =
[mpsGraph transposeTensor: gradOutputTensor
dimension: -1
withDimension: -2
name: nil];
MPSGraphTensor* gradOutputTransposeTensor = [mpsGraph transposeTensor:gradOutputTensor
dimension:-1
withDimension:-2
name:nil];
// grad_weight
MPSGraphTensor *outputTensor =
[mpsGraph matrixMultiplicationWithPrimaryTensor: gradOutputTransposeTensor
secondaryTensor: inputTensor
name: nil];
MPSGraphTensor *biasTensor = nil;
if (bias_defined)
{
MPSGraphTensor* outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:gradOutputTransposeTensor
secondaryTensor:inputTensor
name:nil];
MPSGraphTensor* biasTensor = nil;
if (bias_defined) {
// grad_bias
biasTensor = [mpsGraph reductionSumWithTensor: gradOutputTensor
axis: 0
name: nil];
biasTensor = [mpsGraph reductionSumWithTensor:gradOutputTensor axis:0 name:nil];
}
newCachedGraph->inputTensor_ = inputTensor;
@ -338,14 +305,14 @@ std::tuple<Tensor, Tensor> _mps_linear_backward_weights(
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
return std::tuple<Tensor, Tensor>{ output, bias };
return std::tuple<Tensor, Tensor>{output, bias};
}
}
std::tuple<Tensor, Tensor, Tensor> mps_linear_backward(
const Tensor& input, const Tensor& grad_output,
const Tensor& weight, std::array<bool,3> output_mask) {
std::tuple<Tensor, Tensor, Tensor> mps_linear_backward(const Tensor& input,
const Tensor& grad_output,
const Tensor& weight,
std::array<bool, 3> output_mask) {
Tensor grad_input, grad_weight, grad_bias;
if (output_mask[0]) {
grad_input = _mps_linear_backward_input(input.sizes(), grad_output, weight);

View File

@ -1,8 +1,8 @@
// Copyright © 2022 Apple Inc.
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mps/OperationUtils.h>
namespace at::native {
@ -15,19 +15,18 @@ static Tensor prepare_batch_matrix_by_transposing(const Tensor& tensor,
bool& transpose_tensor,
int64_t& ld_tensor,
bool transpose_result,
int64_t m, int64_t n) {
int64_t m,
int64_t n) {
IntArrayRef tensor_strides = tensor.strides();
Tensor tensor_;
int fast_dim = transpose_result ? 2 : 1;
int leading_dim = transpose_result ? 1 : 2;
if (tensor_strides[fast_dim] == 1 &&
(tensor_strides[leading_dim] >= std::max<int64_t>(1, m))) {
if (tensor_strides[fast_dim] == 1 && (tensor_strides[leading_dim] >= std::max<int64_t>(1, m))) {
transpose_tensor = false;
tensor_ = tensor;
ld_tensor = tensor_strides[leading_dim];
} else if ((tensor_strides[leading_dim] == 1) &&
(tensor_strides[fast_dim] >= std::max<int64_t>(1, n))) {
} else if ((tensor_strides[leading_dim] == 1) && (tensor_strides[fast_dim] >= std::max<int64_t>(1, n))) {
transpose_tensor = true;
tensor_ = tensor;
ld_tensor = tensor_strides[fast_dim];
@ -50,14 +49,13 @@ static Tensor prepare_batch_matrix_by_transposing(const Tensor& tensor,
* Helper functions to be used for mm/addmm for detecting the Transpositions
* when doing GEMM operations.
*/
void prepare_matrices_for_broadcasting(
const Tensor * bias,
const Tensor & self,
const Tensor & other,
const Scalar * beta,
bool * transpose_mat1_times_mat2,
bool & transpose_mat1,
bool & transpose_mat2) {
void prepare_matrices_for_broadcasting(const Tensor* bias,
const Tensor& self,
const Tensor& other,
const Scalar* beta,
bool* transpose_mat1_times_mat2,
bool& transpose_mat1,
bool& transpose_mat2) {
TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
if (bias && beta->toDouble() != 0.0f) {
TORCH_CHECK(bias->dim() == 2, "tensors must be 2-D");
@ -79,20 +77,14 @@ void prepare_matrices_for_broadcasting(
}
}
enum LinearAlgebraOpType {
ADDBMM_OP_TYPE,
BADDBMM_OP_TYPE
};
enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE };
Tensor& mm_out_mps_impl(
const Tensor& self,
const Tensor& other,
Tensor& output) {
Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
using namespace mps;
TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self.scalar_type() == ScalarType::Double
|| self.scalar_type() == ScalarType::Float
|| self.scalar_type() == ScalarType::Half, "MPS device does not support mm for non-float inputs");
TORCH_CHECK(self.scalar_type() == ScalarType::Double || self.scalar_type() == ScalarType::Float ||
self.scalar_type() == ScalarType::Half,
"MPS device does not support mm for non-float inputs");
TensorArg args[]{{output, "out", 0}, {self, "mat1", 1}, {other, "mat2", 2}};
checkAllSameGPU("mm", args);
@ -105,45 +97,39 @@ Tensor& mm_out_mps_impl(
return output;
}
struct CachedGraph : public mps::MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *selfTensor_ = nil;
MPSGraphTensor *otherTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct CachedGraph : public mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* selfTensor_ = nil;
MPSGraphTensor* otherTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
MPSStream* stream = getCurrentMPSStream();
mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
@autoreleasepool {
string key = "mm_out_mps_impl" + getTensorsStringKey({self, other});
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
@autoreleasepool{
MPSGraph *mpsGraph = mps::make_mps_graph();
@autoreleasepool {
MPSGraph* mpsGraph = mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor *selfTensor = nil;
MPSGraphTensor *otherTensor = nil;
MPSGraphTensor *outputTensor = nil;
if(self.numel() == 0 || other.numel() == 0) {
MPSGraphTensor* selfTensor = nil;
MPSGraphTensor* otherTensor = nil;
MPSGraphTensor* outputTensor = nil;
if (self.numel() == 0 || other.numel() == 0) {
outputTensor = [mpsGraph constantWithScalar:0.
shape:getMPSShape(output_sizes)
dataType:getMPSDataType(output)];
}
else {
} else {
selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfTensor
@ -157,11 +143,11 @@ Tensor& mm_out_mps_impl(
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder selfPlaceholder = Placeholder();
Placeholder otherPlaceholder = Placeholder();
if(!(self.numel() == 0 || other.numel() == 0)) {
if (!(self.numel() == 0 || other.numel() == 0)) {
selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other);
}
@ -169,15 +155,14 @@ Tensor& mm_out_mps_impl(
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
if(!(self.numel() == 0 || other.numel() == 0))
if (!(self.numel() == 0 || other.numel() == 0))
feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -185,26 +170,25 @@ Tensor& mm_out_mps_impl(
return output;
}
Tensor addr_mps(const Tensor& self,
const Tensor& vec1, const Tensor& vec2,
const Scalar& beta, const Scalar& alpha) {
Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
Tensor result = at::empty({0}, self.options());
addr_out_mps(self, vec1,vec2,beta,alpha,result);
addr_out_mps(self, vec1, vec2, beta, alpha, result);
return result;
}
Tensor& addr_out_mps(const Tensor& self,
const Tensor& vec1, const Tensor& vec2,
const Scalar& beta, const Scalar& alpha, Tensor &result) {
const Tensor& vec1,
const Tensor& vec2,
const Scalar& beta,
const Scalar& alpha,
Tensor& result) {
using namespace mps;
TORCH_CHECK(result.is_mps());
TORCH_CHECK(vec1.dim() == 1 && vec2.dim() == 1, "tensors must be 1-D");
TORCH_CHECK(vec1.scalar_type() == ScalarType::Double
|| vec1.scalar_type() == ScalarType::Float
|| vec1.scalar_type() == ScalarType::Half, "MPS device does not support addr for non-float input");
TORCH_CHECK(vec1.scalar_type() == ScalarType::Double || vec1.scalar_type() == ScalarType::Float ||
vec1.scalar_type() == ScalarType::Half,
"MPS device does not support addr for non-float input");
TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {vec1, "vec1", 2}, {vec2, "vec2", 3}};
checkAllSameGPU(__func__, args);
@ -239,37 +223,34 @@ Tensor& addr_out_mps(const Tensor& self,
MPSStream* stream = getCurrentMPSStream();
bool is_beta_non_zero = beta.toDouble() != 0.0;
MPSShape* inputShape = @[@(vec1.numel()), @(1)];
MPSShape* otherShape = @[@(1), @(vec2.numel())];
MPSShape* inputShape = @[ @(vec1.numel()), @(1) ];
MPSShape* otherShape = @[ @(1), @(vec2.numel()) ];
struct CachedGraph : public mps::MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *vec1Tensor_ = nil;
MPSGraphTensor *vec2Tensor_ = nil;
MPSGraphTensor *selfTensor_ = nil;
MPSGraphTensor *resultTensor_ = nil;
struct CachedGraph : public mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* vec1Tensor_ = nil;
MPSGraphTensor* vec2Tensor_ = nil;
MPSGraphTensor* selfTensor_ = nil;
MPSGraphTensor* resultTensor_ = nil;
};
mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
@autoreleasepool {
string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_})
+ ":" + to_string(beta.toDouble())
+ ":" + to_string(alpha.toDouble());
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" + to_string(beta.toDouble()) +
":" + to_string(alpha.toDouble());
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
@autoreleasepool{
MPSGraph *mpsGraph = mps::make_mps_graph();
@autoreleasepool {
MPSGraph* mpsGraph = mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor *t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1), inputShape);
MPSGraphTensor *t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2), otherShape);
MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *self_);
MPSGraphTensor* t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1), inputShape);
MPSGraphTensor* t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2), otherShape);
MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *self_);
// Intermediate as placeholder
MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:t1
@ -307,7 +288,7 @@ Tensor& addr_out_mps(const Tensor& self,
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder vec1Placeholder = Placeholder(cachedGraph->vec1Tensor_, vec1, inputShape);
@ -321,9 +302,8 @@ Tensor& addr_out_mps(const Tensor& self,
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()};
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -331,8 +311,7 @@ Tensor& addr_out_mps(const Tensor& self,
return result;
}
Tensor& addmm_out_mps_impl(
const Tensor& bias,
Tensor& addmm_out_mps_impl(const Tensor& bias,
const Tensor& self, // input
const Tensor& other, // weight
const Scalar& beta,
@ -342,9 +321,9 @@ Tensor& addmm_out_mps_impl(
TORCH_CHECK(output.is_mps());
TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self.scalar_type() == ScalarType::Double
|| self.scalar_type() == ScalarType::Float
|| self.scalar_type() == ScalarType::Half, "MPS device does not support addmm for non-float input");
TORCH_CHECK(self.scalar_type() == ScalarType::Double || self.scalar_type() == ScalarType::Float ||
self.scalar_type() == ScalarType::Half,
"MPS device does not support addmm for non-float input");
TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}};
checkAllSameGPU(__func__, args);
@ -382,58 +361,48 @@ Tensor& addmm_out_mps_impl(
bool transpose_mat2 = false;
bool is_beta_non_zero = beta.toDouble() != 0.0;
prepare_matrices_for_broadcasting(&(*bias_), self, other, &beta, &transpose_mat1_times_mat2, transpose_mat1, transpose_mat2);
prepare_matrices_for_broadcasting(
&(*bias_), self, other, &beta, &transpose_mat1_times_mat2, transpose_mat1, transpose_mat2);
struct CachedGraph : public mps::MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *selfTensor_ = nil;
MPSGraphTensor *otherTensor_ = nil;
MPSGraphTensor *biasTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct CachedGraph : public mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* selfTensor_ = nil;
MPSGraphTensor* otherTensor_ = nil;
MPSGraphTensor* biasTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
@autoreleasepool {
string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_})
+ ":" + to_string(transpose_mat1) + ":" + to_string(transpose_mat2)
+ ":" + to_string(beta.toDouble())
+ ":" + to_string(alpha.toDouble());
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + to_string(transpose_mat1) +
":" + to_string(transpose_mat2) + ":" + to_string(beta.toDouble()) + ":" + to_string(alpha.toDouble());
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
@autoreleasepool{
MPSGraph *mpsGraph = mps::make_mps_graph();
@autoreleasepool {
MPSGraph* mpsGraph = mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor *otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
MPSGraphTensor *biasTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *bias_);
MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
MPSGraphTensor* biasTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *bias_);
MPSGraphTensor* t1 = nil;
MPSGraphTensor* t2 = nil;
if(transpose_mat1)
t1 = [mpsGraph transposeTensor:selfTensor
dimension:-1
withDimension:-2
name:nil];
if (transpose_mat1)
t1 = [mpsGraph transposeTensor:selfTensor dimension:-1 withDimension:-2 name:nil];
else
t1 = selfTensor;
if(transpose_mat2)
t2 = [mpsGraph transposeTensor:otherTensor
dimension:-1
withDimension:-2
name:nil];
if (transpose_mat2)
t2 = [mpsGraph transposeTensor:otherTensor dimension:-1 withDimension:-2 name:nil];
else
t2 = otherTensor;
// TODO: Use alpha and beta here with fill_.Scalar and mul
// Intermediate as placeholder
MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:t1
@ -458,10 +427,7 @@ Tensor& addmm_out_mps_impl(
}
if (transpose_mat1_times_mat2)
biasTimesBetaTensor = [mpsGraph transposeTensor: biasTimesBetaTensor
dimension: -1
withDimension: -2
name: nil];
biasTimesBetaTensor = [mpsGraph transposeTensor:biasTimesBetaTensor dimension:-1 withDimension:-2 name:nil];
MPSGraphTensor* outputTensor = productTimesAlphaTensor;
if (is_beta_non_zero) {
@ -477,7 +443,7 @@ Tensor& addmm_out_mps_impl(
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
@ -491,9 +457,8 @@ Tensor& addmm_out_mps_impl(
biasPlaceholder.getMPSGraphTensor() : biasPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -501,16 +466,12 @@ Tensor& addmm_out_mps_impl(
return output;
}
Tensor& bmm_out_mps_impl(
const Tensor & batch1,
const Tensor & batch2,
Tensor & result) {
Tensor& bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tensor& result) {
using namespace mps;
TORCH_CHECK(batch1.scalar_type() == ScalarType::Double
|| batch1.scalar_type() == ScalarType::Float
|| batch1.scalar_type() == ScalarType::Half, "MPS device does not support bmm for non-float inputs");
TORCH_CHECK(batch1.scalar_type() == ScalarType::Double || batch1.scalar_type() == ScalarType::Float ||
batch1.scalar_type() == ScalarType::Half,
"MPS device does not support bmm for non-float inputs");
if (batch1.numel() == 0 || batch2.numel() == 0) {
result.zero_();
@ -519,31 +480,29 @@ Tensor& bmm_out_mps_impl(
MPSStream* stream = getCurrentMPSStream();
struct CachedGraph : public mps::MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *batch1Tensor_ = nil;
MPSGraphTensor *batch2Tensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct CachedGraph : public mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* batch1Tensor_ = nil;
MPSGraphTensor* batch2Tensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
@autoreleasepool {
string key = "bmm_out_mps_impl" + getTensorsStringKey({batch1, batch2});
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
@autoreleasepool{
MPSGraph *mpsGraph = mps::make_mps_graph();
@autoreleasepool {
MPSGraph* mpsGraph = mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor *batch1Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch1);
MPSGraphTensor *batch2Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch2);
MPSGraphTensor* batch1Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch1);
MPSGraphTensor* batch2Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch2);
MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:batch1Tensor
secondaryTensor:batch2Tensor
@ -555,7 +514,7 @@ Tensor& bmm_out_mps_impl(
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder batch1Placeholder = Placeholder(cachedGraph->batch1Tensor_, batch1);
Placeholder batch2Placeholder = Placeholder(cachedGraph->batch2Tensor_, batch2);
@ -566,9 +525,8 @@ Tensor& bmm_out_mps_impl(
batch2Placeholder.getMPSGraphTensor() : batch2Placeholder.getMPSGraphTensorData(),
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -576,13 +534,12 @@ Tensor& bmm_out_mps_impl(
return result;
}
Tensor& addbmm_or_baddbmm_out_mps_impl(
const Tensor & input,
const Tensor & batch1,
const Tensor & batch2,
const Scalar & beta,
const Scalar & alpha,
Tensor & result,
Tensor& addbmm_or_baddbmm_out_mps_impl(const Tensor& input,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha,
Tensor& result,
LinearAlgebraOpType opType) {
using namespace mps;
@ -591,22 +548,29 @@ Tensor& addbmm_or_baddbmm_out_mps_impl(
TORCH_CHECK(batch2.is_mps());
TORCH_CHECK(result.is_mps());
TORCH_CHECK(batch1.scalar_type() == ScalarType::Double
|| batch1.scalar_type() == ScalarType::Float
|| batch1.scalar_type() == ScalarType::Half, "MPS device does not support addbmm or baddbmm for non-float inputs");
TORCH_CHECK(batch1.scalar_type() == ScalarType::Double || batch1.scalar_type() == ScalarType::Float ||
batch1.scalar_type() == ScalarType::Half,
"MPS device does not support addbmm or baddbmm for non-float inputs");
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
TORCH_CHECK(batch1.size(0) == batch2.size(0),
"batch1 and batch2 must have same number of batches, got ",
batch1.size(0), " and ", batch2.size(0));
batch1.size(0),
" and ",
batch2.size(0));
TORCH_CHECK(batch1.size(2) == batch2.size(1),
"Incompatible matrix sizes for bmm (",
batch1.size(1), "x", batch1.size(2), " and ",
batch2.size(1), "x", batch2.size(2), ")");
batch1.size(1),
"x",
batch1.size(2),
" and ",
batch2.size(1),
"x",
batch2.size(2),
")");
if (opType == ADDBMM_OP_TYPE)
{
if (opType == ADDBMM_OP_TYPE) {
result.resize_as_(input);
const int64_t num_batches = batch1.size(0);
@ -619,42 +583,39 @@ Tensor& addbmm_or_baddbmm_out_mps_impl(
MPSStream* stream = getCurrentMPSStream();
struct CachedGraph : public mps::MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *batch1Tensor_ = nil;
MPSGraphTensor *batch2Tensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct CachedGraph : public mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* batch1Tensor_ = nil;
MPSGraphTensor* batch2Tensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
@autoreleasepool {
string key = (opType == ADDBMM_OP_TYPE) ? ("addbmm_out_mps_impl") : ("baddbmm_out_mps_impl");
key += getTensorsStringKey({batch1, batch2, input})
+ ":" + to_string(beta.toDouble())
+ ":" + to_string(alpha.toDouble());
key += getTensorsStringKey({batch1, batch2, input}) + ":" + to_string(beta.toDouble()) + ":" +
to_string(alpha.toDouble());
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
@autoreleasepool{
MPSGraph *mpsGraph = mps::make_mps_graph();
@autoreleasepool {
MPSGraph* mpsGraph = mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor *inputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, input);
MPSGraphTensor *batch1Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch1);
MPSGraphTensor *batch2Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch2);
MPSGraphTensor* inputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, input);
MPSGraphTensor* batch1Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch1);
MPSGraphTensor* batch2Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch2);
// Intermediates for beta and alpha
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar: beta.toDouble()
dataType: getMPSScalarType(input.scalar_type())];
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar: alpha.toDouble()
dataType: getMPSScalarType(batch1.scalar_type())];
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta.toDouble()
dataType:getMPSScalarType(input.scalar_type())];
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.toDouble()
dataType:getMPSScalarType(batch1.scalar_type())];
MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:batch1Tensor
secondaryTensor:batch2Tensor
@ -662,18 +623,19 @@ Tensor& addbmm_or_baddbmm_out_mps_impl(
MPSGraphTensor* reductionSumTensor = productTensor;
if (opType == ADDBMM_OP_TYPE) {
reductionSumTensor = [mpsGraph reductionSumWithTensor: productTensor
axis: 0
name: @"reductionSum(batch1@batch2)"];
reductionSumTensor = [mpsGraph reductionSumWithTensor:productTensor
axis:0
name:@"reductionSum(batch1@batch2)"];
}
// Intermediates for multiplying by beta and alpha
MPSGraphTensor* reductionSumTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor: reductionSumTensor
secondaryTensor: alphaTensor
name: @"alpha*(batch1@batch2)"];
MPSGraphTensor* biasTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor: inputTensor
secondaryTensor: betaTensor
name: @"beta*input"];
MPSGraphTensor* reductionSumTimesAlphaTensor =
[mpsGraph multiplicationWithPrimaryTensor:reductionSumTensor
secondaryTensor:alphaTensor
name:@"alpha*(batch1@batch2)"];
MPSGraphTensor* biasTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor
secondaryTensor:betaTensor
name:@"beta*input"];
MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor:reductionSumTimesAlphaTensor
secondaryTensor:biasTimesBetaTensor
@ -686,7 +648,7 @@ Tensor& addbmm_or_baddbmm_out_mps_impl(
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
Placeholder batch1Placeholder = Placeholder(cachedGraph->batch1Tensor_, batch1);
@ -699,9 +661,8 @@ Tensor& addbmm_or_baddbmm_out_mps_impl(
batch2Placeholder.getMPSGraphTensor() : batch2Placeholder.getMPSGraphTensorData(),
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -713,40 +674,67 @@ TORCH_IMPL_FUNC(mm_out_mps)(const Tensor& self, const Tensor& mat2, const Tensor
mm_out_mps_impl(self, mat2, const_cast<Tensor&>(result));
}
TORCH_IMPL_FUNC(addmm_out_mps)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
TORCH_IMPL_FUNC(addmm_out_mps)
(const Tensor& self,
const Tensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
const Tensor& result) {
addmm_out_mps_impl(self, mat1, mat2, beta, alpha, const_cast<Tensor&>(result));
}
TORCH_IMPL_FUNC(bmm_out_mps) (const Tensor & batch1, const Tensor & batch2, const Tensor & result) {
TORCH_IMPL_FUNC(bmm_out_mps)(const Tensor& batch1, const Tensor& batch2, const Tensor& result) {
bmm_out_mps_impl(batch1, batch2, const_cast<Tensor&>(result));
}
TORCH_IMPL_FUNC(baddbmm_out_mps) (const Tensor & self, const Tensor & batch1, const Tensor & batch2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
TORCH_IMPL_FUNC(baddbmm_out_mps)
(const Tensor& self,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha,
const Tensor& result) {
addbmm_or_baddbmm_out_mps_impl(self, batch1, batch2, beta, alpha, const_cast<Tensor&>(result), BADDBMM_OP_TYPE);
}
Tensor& addbmm_out_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, Tensor& result) {
Tensor& addbmm_out_mps(const Tensor& self,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha,
Tensor& result) {
auto b_self = expand_size(self, {batch1.size(1), batch2.size(2)}, "addbmm_out");
addbmm_or_baddbmm_out_mps_impl(*b_self, batch1, batch2, beta, alpha, result, ADDBMM_OP_TYPE);
return result;
}
Tensor addbmm_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
Tensor addbmm_mps(const Tensor& self,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha) {
Tensor result = at::empty({0}, self.options());
return addbmm_out_mps(self, batch1, batch2, beta, alpha, result);
}
Tensor &addbmm_mps_(Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
Tensor& addbmm_mps_(Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
return addbmm_out_mps(self, batch1, batch2, beta, alpha, self);
}
Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool upper, bool transpose, bool left, bool unitriangular, Tensor& out) {
Tensor& linalg_solve_triangular_mps_impl(const Tensor& A,
const Tensor& B,
bool upper,
bool transpose,
bool left,
bool unitriangular,
Tensor& out) {
using namespace mps;
checkInputsSolver(A, B, left, "linalg.solve_triangular");
Tensor A_t, B_t;
std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/nullptr);
std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/ nullptr);
at::native::resize_output(out, B_t.sizes());
if (A.numel() == 0 || B.numel() == 0 || out.numel() == 0) {
@ -768,7 +756,7 @@ Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool
MPSStream* mpsStream = getCurrentMPSStream();
id<MTLDevice> device = MPSDevice::getInstance()->device();
dispatch_sync(mpsStream->queue(), ^(){
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
uint64_t batchSize = A_.sizes().size() > 2 ? A_.size(0) : 1;
@ -779,7 +767,7 @@ Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool
uint64_t aElemSize = A_.element_size();
uint64_t bElemSize = B_.element_size();
MPSMatrixSolveTriangular *filter = [[[MPSMatrixSolveTriangular alloc] initWithDevice:device
MPSMatrixSolveTriangular* filter = [[[MPSMatrixSolveTriangular alloc] initWithDevice:device
right:!left
upper:upper
transpose:transpose
@ -794,22 +782,24 @@ Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool
rowBytes:aCols * aElemSize
matrixBytes:aRows * aCols * aElemSize
dataType:getMPSDataType(A_)];
MPSMatrixDescriptor* rightHandSideMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:bRows
MPSMatrixDescriptor* rightHandSideMatrixDesc =
[MPSMatrixDescriptor matrixDescriptorWithRows:bRows
columns:bCols
matrices:batchSize
rowBytes:bCols * bElemSize
matrixBytes:bRows * bCols * bElemSize
dataType:getMPSDataType(B_)];
for (const auto i: c10::irange(batchSize)) {
for (const auto i : c10::irange(batchSize)) {
const uint64_t aBatchOffset = i * aRows * aCols;
const uint64_t bBatchOffset = i * bRows * bCols;
MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer
offset:(A_t.storage_offset() + aBatchOffset) * aElemSize
descriptor:sourceMatrixDesc] autorelease];
MPSMatrix* rightHandSideMatrix = [[[MPSMatrix alloc] initWithBuffer:bBuffer
MPSMatrix* rightHandSideMatrix =
[[[MPSMatrix alloc] initWithBuffer:bBuffer
offset:(B_t.storage_offset() + bBatchOffset) * bElemSize
descriptor:rightHandSideMatrixDesc] autorelease];
MPSMatrix *solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:outBuffer
MPSMatrix* solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:outBuffer
offset:(out.storage_offset() + bBatchOffset) * bElemSize
descriptor:rightHandSideMatrixDesc] autorelease];
@ -824,7 +814,12 @@ Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool
return out;
}
Tensor& linalg_solve_triangular_mps_out( const Tensor& A, const Tensor& B, bool upper, bool left, bool unitriangular, Tensor& out) {
Tensor& linalg_solve_triangular_mps_out(const Tensor& A,
const Tensor& B,
bool upper,
bool left,
bool unitriangular,
Tensor& out) {
return linalg_solve_triangular_mps_impl(A, B, upper, /*transpose=*/false, left, unitriangular, out);
}
@ -834,7 +829,14 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper,
return out;
}
TORCH_IMPL_FUNC(triangular_solve_mps_out)(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular, const Tensor& result, const Tensor& clone_A) {
TORCH_IMPL_FUNC(triangular_solve_mps_out)
(const Tensor& self,
const Tensor& A,
bool upper,
bool transpose,
bool unitriangular,
const Tensor& result,
const Tensor& clone_A) {
clone_A.copy_(A);
Tensor out = empty_mps({0}, A.scalar_type(), c10::nullopt, kMPS, c10::nullopt, MemoryFormat::Contiguous);
linalg_solve_triangular_mps_impl(A, self, upper, transpose, /*left=*/true, unitriangular, out);

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -7,15 +7,18 @@ namespace at::native {
namespace mps {
// Pad operations (1D/2D/3D forward and backward)
Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef padding,
Tensor& pad_out_template(Tensor& output,
const Tensor& input_,
IntArrayRef padding,
const c10::optional<Tensor>& grad_output_opt,
MPSGraphPaddingMode mode, double constantValue, const string op_name)
{
const int padding_size = (int) padding.size();
MPSGraphPaddingMode mode,
double constantValue,
const string op_name) {
const int padding_size = (int)padding.size();
int padding_dim = padding_size / 2; // either 1D, 2D, or 3D
TORCH_CHECK(padding_size == 2 || padding_size == 4 || padding_size == 6,
"invalid padding argument of size ", padding_size);
TORCH_CHECK(
padding_size == 2 || padding_size == 4 || padding_size == 6, "invalid padding argument of size ", padding_size);
const Tensor& grad_output_ = *(at::borrow_from_optional_tensor(grad_output_opt));
const bool is_backward_pass = grad_output_.defined();
@ -23,8 +26,13 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
int64_t nbatch = 1;
int64_t ndims = input_.ndimension();
TORCH_CHECK(ndims >= (int64_t)padding_dim, "Length of pad should be no more than twice the number of "
"dimensions of the input. Pad length is ", padding_size, "while the input has ", ndims, "dimensions.");
TORCH_CHECK(ndims >= (int64_t)padding_dim,
"Length of pad should be no more than twice the number of "
"dimensions of the input. Pad length is ",
padding_size,
"while the input has ",
ndims,
"dimensions.");
// number of input dims with ConstantPad could be less than 2
int dim_w = padding_dim;
@ -36,7 +44,8 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
bool valid_dims = input_.size(1) != 0 && input_.size(padding_dim) != 0;
TORCH_CHECK((ndims == 1 + padding_dim && valid_dims) ||
(ndims == 2 + padding_dim && valid_dims && input_.size(1 + padding_dim) != 0),
"3D or 4D (batch mode) tensor expected for input, but got: ", input_);
"3D or 4D (batch mode) tensor expected for input, but got: ",
input_);
}
if (ndims == padding_dim) {
@ -73,8 +82,15 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
if (!is_backward_pass) {
TORCH_CHECK(output_w >= 1 || output_h >= padding_dim - 1,
"input (H: ", input_h, ", W: ", input_w, ") is too small. Calculated "
"output H: ", output_h, " W: ", output_w);
"input (H: ",
input_h,
", W: ",
input_w,
") is too small. Calculated "
"output H: ",
output_h,
" W: ",
output_w);
std::vector<int64_t> outputSizes;
if (mode == MPSGraphPaddingModeConstant) {
@ -83,7 +99,7 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
auto ori_padding_dim = padding_size / 2;
auto l_diff = ndims - ori_padding_dim;
for (size_t i = 0; i < (size_t)l_diff; i ++) {
for (size_t i = 0; i < (size_t)l_diff; i++) {
outputSizes.emplace_back(input_sizes[i]);
}
for (const auto i : c10::irange((size_t)ori_padding_dim)) {
@ -95,20 +111,38 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
// these checks aren't relevant for constant pad
TORCH_CHECK(pad_l < input_w && pad_r < input_w,
"Argument #4: Padding size should be less than the corresponding "
"input dimension, but got: padding (", pad_l, ", ", pad_r,
") at dimension ", dim_w, " of input ", ndims);
"input dimension, but got: padding (",
pad_l,
", ",
pad_r,
") at dimension ",
dim_w,
" of input ",
ndims);
if (padding_dim > 1) {
TORCH_CHECK(pad_t < input_h && pad_b < input_h,
"Argument #6: Padding size should be less than the corresponding "
"input dimension, but got: padding (", pad_t, ", ", pad_b,
") at dimension ", dim_h, " of input ", ndims);
"input dimension, but got: padding (",
pad_t,
", ",
pad_b,
") at dimension ",
dim_h,
" of input ",
ndims);
}
if (padding_dim > 2) {
TORCH_CHECK(pad_front < input_d && pad_back < input_d,
"Argument #8: Padding size should be less than the corresponding "
"input dimension, but got: padding (", pad_front, ", ", pad_back,
") at dimension ", dim_d, " of input ", ndims);
"input dimension, but got: padding (",
pad_front,
", ",
pad_back,
") at dimension ",
dim_d,
" of input ",
ndims);
}
outputSizes.insert(outputSizes.begin(), output_w);
if (padding_dim >= 2)
@ -133,10 +167,16 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
input = input_.contiguous();
} else {
TORCH_CHECK(output_w == grad_output_.size(dim_w),
"gradOutput width unexpected. Expected: ", output_w, ", Got: ", grad_output_.size(dim_w));
"gradOutput width unexpected. Expected: ",
output_w,
", Got: ",
grad_output_.size(dim_w));
if (padding_dim > 1) {
TORCH_CHECK(output_h == grad_output_.size(dim_h),
"gradOutput height unexpected. Expected: ", output_h, ", Got: ", grad_output_.size(dim_h));
"gradOutput height unexpected. Expected: ",
output_h,
", Got: ",
grad_output_.size(dim_h));
}
output.resize_as_(input);
if (output.numel() == 0 || grad_output_.numel() == 0)
@ -157,7 +197,7 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
const int64_t rightIdx = pdim * 2 + 1;
const int64_t padIdx = ndims - pdim - 1;
leftPadVec [padIdx] = @(padding[leftIdx]);
leftPadVec[padIdx] = @(padding[leftIdx]);
rightPadVec[padIdx] = @(padding[rightIdx]);
// workaround for negative padding issue in backward pass
if (is_backward_pass) {
@ -180,8 +220,8 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
}
}
}
MPSShape *leftPadding = [NSArray arrayWithObjects:leftPadVec.data() count:ndims];
MPSShape *rightPadding = [NSArray arrayWithObjects:rightPadVec.data() count:ndims];
MPSShape* leftPadding = [NSArray arrayWithObjects:leftPadVec.data() count:ndims];
MPSShape* rightPadding = [NSArray arrayWithObjects:rightPadVec.data() count:ndims];
MPSDataType dataType = getMPSScalarType(input.scalar_type());
// workaround for Bool type assert with Constant padding
@ -190,20 +230,20 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) { }
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
MPSGraphTensor *gradOutputTensor = nil;
MPSGraphTensor* gradOutputTensor = nil;
};
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" +
getArrayRefString(padding) + "]:" + std::to_string(constantValue);
string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" + getArrayRefString(padding) +
"]:" + std::to_string(constantValue);
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if(!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
@ -211,7 +251,7 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
const bool needsSlice = startMask != dims_mask || endMask != dims_mask;
if (!is_backward_pass) {
MPSGraphTensor *padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor
MPSGraphTensor* padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor
withPaddingMode:mode
leftPadding:leftPadding
rightPadding:rightPadding
@ -219,7 +259,8 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
name:nil];
// workaround for the right padding bug in Monterey
if (needsSlice) {
newCachedGraph->outputTensor = [mpsGraph sliceTensor:padTensor
newCachedGraph->outputTensor =
[mpsGraph sliceTensor:padTensor
starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
@ -232,7 +273,8 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
}
} else {
newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(grad_output));
MPSGraphTensor *padGradTensor = [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor
MPSGraphTensor* padGradTensor =
[mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor
sourceTensor:newCachedGraph->inputTensor
paddingMode:mode
leftPadding:leftPadding
@ -240,7 +282,8 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
name:nil];
// workaround for negative padding issue with padGradientWithIncomingGradientTensor()
if (needsSlice) {
newCachedGraph->outputTensor = [mpsGraph sliceGradientTensor:padGradTensor
newCachedGraph->outputTensor =
[mpsGraph sliceGradientTensor:padGradTensor
fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor name:nil]
starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
@ -259,17 +302,17 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
}
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input, nullptr, true, dataType);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output, nullptr, true, dataType);
Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder() :
Placeholder(cachedGraph->gradOutputTensor, grad_output, nullptr, true, dataType);
Placeholder gradOutputPlaceholder = !is_backward_pass
? Placeholder()
: Placeholder(cachedGraph->gradOutputTensor, grad_output, nullptr, true, dataType);
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
if (is_backward_pass) {
feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
}
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
return output;
@ -278,123 +321,156 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
// 1D Reflection and Replication Padding
TORCH_IMPL_FUNC(reflection_pad1d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output)
{
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
MPSGraphPaddingModeReflect, 0.0, "reflection_pad1d_out_mps");
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
mps::pad_out_template(const_cast<Tensor&>(output),
input,
padding,
c10::nullopt,
MPSGraphPaddingModeReflect,
0.0,
"reflection_pad1d_out_mps");
}
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
{
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
grad_input.resize_as_(input).zero_();
mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
MPSGraphPaddingModeReflect, 0.0, "reflection_pad1d_backward_out_mps");
mps::pad_out_template(const_cast<Tensor&>(grad_input),
input,
padding,
grad_output,
MPSGraphPaddingModeReflect,
0.0,
"reflection_pad1d_backward_out_mps");
}
TORCH_IMPL_FUNC(replication_pad1d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output)
{
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad1d_out_mps");
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
mps::pad_out_template(const_cast<Tensor&>(output),
input,
padding,
c10::nullopt,
MPSGraphPaddingModeClampToEdge,
0.0,
"replication_pad1d_out_mps");
}
TORCH_IMPL_FUNC(replication_pad1d_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
{
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
grad_input.resize_as_(input).zero_();
mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad1d_backward_out_mps");
mps::pad_out_template(const_cast<Tensor&>(grad_input),
input,
padding,
grad_output,
MPSGraphPaddingModeClampToEdge,
0.0,
"replication_pad1d_backward_out_mps");
}
// 2D Reflection and Replication Padding
Tensor& reflection_pad2d_out_mps(const Tensor& input, IntArrayRef padding, Tensor& output)
{
Tensor& reflection_pad2d_out_mps(const Tensor& input, IntArrayRef padding, Tensor& output) {
return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__);
}
Tensor reflection_pad2d_mps(const Tensor& input, IntArrayRef padding)
{
Tensor reflection_pad2d_mps(const Tensor& input, IntArrayRef padding) {
Tensor output = at::empty({0}, input.options());
return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__);
}
Tensor& reflection_pad2d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
{
Tensor& reflection_pad2d_backward_out_mps(const Tensor& grad_output,
const Tensor& input,
IntArrayRef padding,
Tensor& grad_input) {
grad_input.resize_as_(input).zero_();
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__);
}
Tensor reflection_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
{
Tensor reflection_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__);
}
TORCH_IMPL_FUNC(replication_pad2d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output)
{
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad2d_out_mps");
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
mps::pad_out_template(const_cast<Tensor&>(output),
input,
padding,
c10::nullopt,
MPSGraphPaddingModeClampToEdge,
0.0,
"replication_pad2d_out_mps");
}
Tensor& replication_pad2d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
{
Tensor& replication_pad2d_backward_out_mps(const Tensor& grad_output,
const Tensor& input,
IntArrayRef padding,
Tensor& grad_input) {
grad_input.resize_as_(input).zero_();
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}
Tensor replication_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
{
Tensor replication_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}
// 3D Reflection and Replication Padding
TORCH_IMPL_FUNC(reflection_pad3d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output)
{
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
MPSGraphPaddingModeReflect, 0.0, "reflection_pad3d_out_mps");
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
mps::pad_out_template(const_cast<Tensor&>(output),
input,
padding,
c10::nullopt,
MPSGraphPaddingModeReflect,
0.0,
"reflection_pad3d_out_mps");
}
TORCH_IMPL_FUNC(reflection_pad3d_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
{
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
grad_input.resize_as_(input).zero_();
mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
MPSGraphPaddingModeReflect, 0.0, "reflection_pad3d_backward_out_mps");
mps::pad_out_template(const_cast<Tensor&>(grad_input),
input,
padding,
grad_output,
MPSGraphPaddingModeReflect,
0.0,
"reflection_pad3d_backward_out_mps");
}
TORCH_IMPL_FUNC(replication_pad3d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output)
{
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad3d_out_mps");
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
mps::pad_out_template(const_cast<Tensor&>(output),
input,
padding,
c10::nullopt,
MPSGraphPaddingModeClampToEdge,
0.0,
"replication_pad3d_out_mps");
}
Tensor& replication_pad3d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
{
Tensor& replication_pad3d_backward_out_mps(const Tensor& grad_output,
const Tensor& input,
IntArrayRef padding,
Tensor& grad_input) {
grad_input.resize_as_(input).zero_();
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}
Tensor replication_pad3d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
{
Tensor replication_pad3d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}
// backward pass is exlicitly handled in autograd by negating the "pad" argument
Tensor constant_pad_nd_mps(const Tensor& self, IntArrayRef pad, const Scalar& value)
{
Tensor constant_pad_nd_mps(const Tensor& self, IntArrayRef pad, const Scalar& value) {
if (pad.size() > 6) {
TORCH_WARN_ONCE("MPS: The constant padding of more than 3 dimensions is not currently supported natively. ",
"It uses View Ops default implementation to run. This may have performance implications.");
return at::native::constant_pad_nd(self, pad, value);
}
Tensor output = at::empty({0}, self.options());
return mps::pad_out_template(output, self, pad, c10::nullopt, MPSGraphPaddingModeConstant, value.toDouble(), __func__);
return mps::pad_out_template(
output, self, pad, c10::nullopt, MPSGraphPaddingModeConstant, value.toDouble(), __func__);
}
} // namespace at::native

View File

@ -12,22 +12,20 @@ void addc_mul_div_out_mps(const Tensor& self,
const Scalar& value_opt, // default value = 1.0
const Tensor& output,
const bool is_div,
const string op_name)
{
const string op_name) {
if (value_opt.toDouble() == 0.0) {
output.copy_(self);
return;
}
if(output.numel() == 0) {
if (output.numel() == 0) {
return;
}
MPSStream* mpsStream = getCurrentMPSStream();
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
MPSGraphTensor *firstTensor = nil, *secondTensor = nil, *valueTensor = nil;
};
@ -39,10 +37,10 @@ void addc_mul_div_out_mps(const Tensor& self,
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), c10::promoteTypes(tensor1.scalar_type(), tensor2.scalar_type()));
ScalarType common_dtype =
c10::promoteTypes(self.scalar_type(), c10::promoteTypes(tensor1.scalar_type(), tensor2.scalar_type()));
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
@ -50,26 +48,27 @@ void addc_mul_div_out_mps(const Tensor& self,
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
newCachedGraph->firstTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor1);
newCachedGraph->secondTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor2);
newCachedGraph->valueTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), @[@1]);
newCachedGraph->valueTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), @[ @1 ]);
// the tensor to be optionally multiplied by value_scalar
MPSGraphTensor *multiplicandTensor = nil;
MPSGraphTensor* multiplicandTensor = nil;
auto firstTensor = castMPSTensor(mpsGraph, newCachedGraph->firstTensor, common_dtype);
auto secondTensor = castMPSTensor(mpsGraph, newCachedGraph->secondTensor, common_dtype);
if (is_div) {
multiplicandTensor = [mpsGraph divisionWithPrimaryTensor:firstTensor
secondaryTensor:secondTensor
name:nil];
multiplicandTensor = [mpsGraph divisionWithPrimaryTensor:firstTensor secondaryTensor:secondTensor name:nil];
} else {
multiplicandTensor = [mpsGraph multiplicationWithPrimaryTensor:firstTensor
secondaryTensor:secondTensor
name:nil];
}
// the tensor to be added to input_tensor
MPSGraphTensor *addendTensor = [mpsGraph multiplicationWithPrimaryTensor:multiplicandTensor
MPSGraphTensor* addendTensor = [mpsGraph
multiplicationWithPrimaryTensor:multiplicandTensor
secondaryTensor:castMPSTensor(mpsGraph, newCachedGraph->valueTensor, common_dtype)
name:nil];
auto outputTensor = [mpsGraph additionWithPrimaryTensor:castMPSTensor(mpsGraph, newCachedGraph->inputTensor, common_dtype)
auto outputTensor =
[mpsGraph additionWithPrimaryTensor:castMPSTensor(mpsGraph, newCachedGraph->inputTensor, common_dtype)
secondaryTensor:addendTensor
name:nil];
newCachedGraph->outputTensor = castMPSTensor(mpsGraph, outputTensor, output.scalar_type());
@ -93,9 +92,8 @@ void addc_mul_div_out_mps(const Tensor& self,
cachedGraph->valueTensor : getMPSGraphTensorFromScalar(mpsStream, value_scalar),
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);
}
@ -105,14 +103,12 @@ void addc_mul_div_out_mps(const Tensor& self,
// APIs exposed to at::native scope
TORCH_IMPL_FUNC(addcmul_out_mps)
(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output)
{
(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output) {
mps::addc_mul_div_out_mps(self, tensor1, tensor2, value, output, false, "addcmul_out_mps");
}
TORCH_IMPL_FUNC(addcdiv_out_mps)
(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output)
{
(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output) {
mps::addc_mul_div_out_mps(self, tensor1, tensor2, value, output, true, "addcdiv_out_mps");
}

View File

@ -1,14 +1,13 @@
// Copyright © 2022 Apple Inc.
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/native/mps/OperationUtils.h>
namespace at::native {
namespace mps {
struct PoolingCachedGraph : public MPSCachedGraph
{
PoolingCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct PoolingCachedGraph : public MPSCachedGraph {
PoolingCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor = nil;
MPSGraphTensor* outputTensor = nil;
MPSGraphTensor* indicesTensor = nil;
@ -17,24 +16,30 @@ struct PoolingCachedGraph : public MPSCachedGraph
};
typedef MPSGraphTensor* (^PoolingOpBlock)(PoolingCachedGraph&, MPSGraphPooling2DOpDescriptor*);
#define PoolingOpFn(graph, desc) MPSGraphTensor* (mps::PoolingCachedGraph& graph, MPSGraphPooling2DOpDescriptor* desc)
#define PoolingOpFn(graph, desc) MPSGraphTensor*(mps::PoolingCachedGraph & graph, MPSGraphPooling2DOpDescriptor * desc)
// Pooling ops (1D/2D forward and backward Max and Average pooling)
static void pool2d_template(const Tensor& input, const Tensor& output,
static void pool2d_template(const Tensor& input,
const Tensor& output,
const c10::optional<Tensor>& indices_opt,
const c10::optional<Tensor>& grad_output_opt,
IntArrayRef kernel_size, IntArrayRef stride,
IntArrayRef padding, IntArrayRef dilation,
bool ceil_mode, bool count_include_pad,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
bool count_include_pad,
const c10::optional<int64_t> divisor_override,
PoolingOpBlock poolingBlock, const c10::string& op_name)
{
PoolingOpBlock poolingBlock,
const c10::string& op_name) {
if (input.numel() == 0) {
return;
}
if (!is_macos_13_or_newer()) {
TORCH_CHECK(input.scalar_type() != ScalarType::Long,
"MPS: ", op_name, " op with int64 input is supported natively starting from macOS 13.0.");
"MPS: ",
op_name,
" op with int64 input is supported natively starting from macOS 13.0.");
}
const int64_t ndims = input.ndimension();
const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt));
@ -48,13 +53,17 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
// be incompatible with the PyTorch's global NCHW layout.
const auto memory_format = has_indices ? MemoryFormat::Contiguous : suggested_memory_format;
TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2, op_name,
TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
op_name,
": kernel_size must either be a single int, or a tuple of two ints")
TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2, op_name,
TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2,
op_name,
": stride must either be omitted, a single int, or a tuple of two ints")
TORCH_CHECK(padding.size() == 1 || padding.size() == 2, op_name,
TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
op_name,
": padding must be either be a single int, or a tuple of two ints");
TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2, op_name,
TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
op_name,
": dilation must be either a single int, or a tuple of two ints");
if (suggested_memory_format == at::MemoryFormat::ChannelsLast) {
@ -80,8 +89,21 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);
pool2d_shape_check(input, kH, kW, dH, dW, padH, padW, dilationH, dilationW,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
pool2d_shape_check(input,
kH,
kW,
dH,
dW,
padH,
padW,
dilationH,
dilationW,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
memory_format);
auto output_memory_format = output.suggest_memory_format();
// the output and indices are 'empty', so we could avoid unnecessary gatherView on empty tensors
@ -90,7 +112,7 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
indices.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
}
if (output.numel() == 0) {
std::vector<int64_t> outputSizes {nInputPlane, outputHeight, outputWidth};
std::vector<int64_t> outputSizes{nInputPlane, outputHeight, outputWidth};
if (ndims == 4) {
outputSizes.insert(outputSizes.begin(), nbatch);
}
@ -111,10 +133,9 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = op_name + getTensorsStringKey({input, indices, grad_output}) + ":K[" +
getArrayRefString(kernel_size) + "]:S[" + getArrayRefString(stride) + "]:P[" +
getArrayRefString(padding) + "]:D[" + getArrayRefString(dilation) + "]" +
(ceil_mode ? ":ceil" : "") + (count_include_pad ? ":include_pad" : "") +
string key = op_name + getTensorsStringKey({input, indices, grad_output}) + ":K[" + getArrayRefString(kernel_size) +
"]:S[" + getArrayRefString(stride) + "]:P[" + getArrayRefString(padding) + "]:D[" +
getArrayRefString(dilation) + "]" + (ceil_mode ? ":ceil" : "") + (count_include_pad ? ":include_pad" : "") +
(has_divisor ? ":divisor" : "") + ":" +
(suggested_memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
@ -123,44 +144,46 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
PoolingCachedGraph* cachedGraph = cache_->LookUpAs<PoolingCachedGraph>(key);
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<PoolingCachedGraph>(key, ^ MPSCachedGraph * () {
PoolingCachedGraph *newCachedGraph = nil;
cachedGraph = cache_->CreateCachedGraphAs<PoolingCachedGraph>(key, ^MPSCachedGraph*() {
PoolingCachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new PoolingCachedGraph(mpsGraph);
MPSGraphPooling2DOpDescriptor* desc = [MPSGraphPooling2DOpDescriptor
descriptorWithKernelWidth: kW
kernelHeight: kH
strideInX: dW
strideInY: dH
dilationRateInX: dilationW
dilationRateInY: dilationH
paddingLeft: padW
paddingRight: ceil_mode ? padW * dW : padW
paddingTop: padH
paddingBottom: ceil_mode ? padH * dH : padH
paddingStyle: MPSGraphPaddingStyleExplicit
dataLayout: memory_format == MemoryFormat::ChannelsLast ?
MPSGraphTensorNamedDataLayoutNHWC :
MPSGraphTensorNamedDataLayoutNCHW];
MPSGraphPooling2DOpDescriptor* desc =
[MPSGraphPooling2DOpDescriptor descriptorWithKernelWidth:kW
kernelHeight:kH
strideInX:dW
strideInY:dH
dilationRateInX:dilationW
dilationRateInY:dilationH
paddingLeft:padW
paddingRight:ceil_mode ? padW * dW : padW
paddingTop:padH
paddingBottom:ceil_mode ? padH * dH : padH
paddingStyle:MPSGraphPaddingStyleExplicit
dataLayout:memory_format == MemoryFormat::ChannelsLast
? MPSGraphTensorNamedDataLayoutNHWC
: MPSGraphTensorNamedDataLayoutNCHW];
desc.ceilMode = (padW == 0 && padH == 0) ? ceil_mode : false;
if (has_indices) {
desc.returnIndicesMode = MPSGraphPoolingReturnIndicesGlobalFlatten2D;
desc.returnIndicesDataType = MPSDataTypeInt32;
}
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input.scalar_type()), inputShape);
newCachedGraph->inputTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input.scalar_type()), inputShape);
if (is_backward_pass) {
newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output.scalar_type()), gradOutputShape);
newCachedGraph->gradOutputTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output.scalar_type()), gradOutputShape);
}
if (has_divisor) {
newCachedGraph->divisorTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeFloat32, @[@1]);
newCachedGraph->divisorTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeFloat32, @[ @1 ]);
}
MPSGraphTensor* outputTensor = poolingBlock(*newCachedGraph, desc);
// with desc.dataLayout = NHWC (i.e., ChannelsLast), the results need to be converted back to NCHW
newCachedGraph->outputTensor = memory_format == MemoryFormat::ChannelsLast ?
convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor;
newCachedGraph->outputTensor =
memory_format == MemoryFormat::ChannelsLast ? convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor;
}
return newCachedGraph;
});
@ -168,14 +191,16 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
MPSStream* mpsStream = getCurrentMPSStream();
// in case of ChannelsLast we don't perform gather() in placeholder to avoid implicit conversion to NCHW
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input, inputShape, memory_format != MemoryFormat::ChannelsLast);
Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder() :
Placeholder(cachedGraph->gradOutputTensor, grad_output,
gradOutputShape, memory_format != MemoryFormat::ChannelsLast);
Placeholder inputPlaceholder =
Placeholder(cachedGraph->inputTensor, input, inputShape, memory_format != MemoryFormat::ChannelsLast);
Placeholder gradOutputPlaceholder = !is_backward_pass
? Placeholder()
: Placeholder(
cachedGraph->gradOutputTensor, grad_output, gradOutputShape, memory_format != MemoryFormat::ChannelsLast);
Placeholder indicesPlaceholder = has_indices ? Placeholder(cachedGraph->indicesTensor, indices) : Placeholder();
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary *results = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* results = [[NSMutableDictionary new] autorelease];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
results[outputPlaceholder.getMPSGraphTensor()] = outputPlaceholder.getMPSGraphTensorData();
@ -192,7 +217,7 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
}
MPSScalar divisor_scalar;
if (cachedGraph->divisorTensor) {
const float divisor = float(kH * kW) / (float) divisor_override.value();
const float divisor = float(kH * kW) / (float)divisor_override.value();
divisor_scalar = getMPSScalar(divisor, ScalarType::Float);
feeds[cachedGraph->divisorTensor] = getMPSGraphTensorFromScalar(mpsStream, divisor_scalar);
}
@ -205,14 +230,17 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
}
}
static void avg_pool2d_template(const Tensor& input, const Tensor& output,
static void avg_pool2d_template(const Tensor& input,
const Tensor& output,
const c10::optional<Tensor>& grad_output_opt,
IntArrayRef kernel_size, IntArrayRef stride,
IntArrayRef padding, IntArrayRef dilation,
bool ceil_mode, bool count_include_pad,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
bool count_include_pad,
const c10::optional<int64_t> divisor_override,
const c10::string& op_name)
{
const c10::string& op_name) {
const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt));
const bool is_backward_pass = grad_output.defined();
const bool use_divisor = divisor_override.has_value() && divisor_override.value() != 0;
@ -226,12 +254,21 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
"not supported on MPS backend. ",
"Falling back on CPU. This may have performance implications.");
if (!is_backward_pass) {
const_cast<Tensor&>(output) = at::avg_pool2d(input.to("cpu"), kernel_size, stride, padding, ceil_mode,
count_include_pad, divisor_override).clone().to("mps");
const_cast<Tensor&>(output) =
at::avg_pool2d(input.to("cpu"), kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
.clone()
.to("mps");
} else {
const_cast<Tensor&>(output) = at::avg_pool2d_backward(grad_output.to("cpu"), input.to("cpu"),
kernel_size, stride, padding, ceil_mode, count_include_pad,
divisor_override).clone().to("mps");
const_cast<Tensor&>(output) = at::avg_pool2d_backward(grad_output.to("cpu"),
input.to("cpu"),
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override)
.clone()
.to("mps");
}
return;
}
@ -239,7 +276,7 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
const int64_t ndims = input.ndimension();
MPSShape *paddingShape = nil;
MPSShape* paddingShape = nil;
MPSGraphTensor* paddedTensor = cachedGraph.inputTensor;
// workaround for issue #103039644: mismatching MPS vs. CPU results
@ -249,14 +286,14 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
std::vector<NSNumber*> padVec(ndims, @(0));
padVec[ndims - 1] = @(padding.size() == 1 ? padding[0] : padding[1]);
padVec[ndims - 2] = @(ndims > 3 ? padding[0] : 0);
paddingShape = [NSArray arrayWithObjects: padVec.data() count:ndims];
paddedTensor = [mpsGraph padTensor: cachedGraph.inputTensor
withPaddingMode: MPSGraphPaddingModeZero
leftPadding: paddingShape
rightPadding: paddingShape
constantValue: 0.0
name: nil];
paddedTensor = [mpsGraph identityWithTensor: paddedTensor name: nil];
paddingShape = [NSArray arrayWithObjects:padVec.data() count:ndims];
paddedTensor = [mpsGraph padTensor:cachedGraph.inputTensor
withPaddingMode:MPSGraphPaddingModeZero
leftPadding:paddingShape
rightPadding:paddingShape
constantValue:0.0
name:nil];
paddedTensor = [mpsGraph identityWithTensor:paddedTensor name:nil];
} else {
desc.includeZeroPadToAverage = count_include_pad;
}
@ -265,35 +302,33 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
}
if (!is_backward_pass) {
MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DWithSourceTensor: paddedTensor
descriptor: desc
name: nil];
MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DWithSourceTensor:paddedTensor descriptor:desc name:nil];
if (cachedGraph.divisorTensor) {
// workaround: custom divisor isn't supported by MPS backend, so we scale manually
return [mpsGraph multiplicationWithPrimaryTensor: avgPoolTensor
secondaryTensor: cachedGraph.divisorTensor
name: nil];
return [mpsGraph multiplicationWithPrimaryTensor:avgPoolTensor
secondaryTensor:cachedGraph.divisorTensor
name:nil];
} else {
return avgPoolTensor;
}
} else { // backward pass
MPSGraphTensor* scaledGradTensor = cachedGraph.gradOutputTensor;
if (cachedGraph.divisorTensor) {
scaledGradTensor = [mpsGraph multiplicationWithPrimaryTensor: cachedGraph.gradOutputTensor
secondaryTensor: cachedGraph.divisorTensor
name: nil];
scaledGradTensor = [mpsGraph multiplicationWithPrimaryTensor:cachedGraph.gradOutputTensor
secondaryTensor:cachedGraph.divisorTensor
name:nil];
}
MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DGradientWithGradientTensor: scaledGradTensor
sourceTensor: paddedTensor
descriptor: desc
name: nil];
MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DGradientWithGradientTensor:scaledGradTensor
sourceTensor:paddedTensor
descriptor:desc
name:nil];
if (explicit_padding) {
return [mpsGraph padGradientWithIncomingGradientTensor: avgPoolTensor
sourceTensor: cachedGraph.inputTensor
paddingMode: MPSGraphPaddingModeZero
leftPadding: paddingShape
rightPadding: paddingShape
name: nil];
return [mpsGraph padGradientWithIncomingGradientTensor:avgPoolTensor
sourceTensor:cachedGraph.inputTensor
paddingMode:MPSGraphPaddingModeZero
leftPadding:paddingShape
rightPadding:paddingShape
name:nil];
} else {
return avgPoolTensor;
@ -301,59 +336,85 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
}
};
pool2d_template(input, output, c10::nullopt, grad_output_opt, kernel_size, stride,
padding, {1, 1}, ceil_mode, count_include_pad, divisor_override,
pooling_op_block, op_name);
pool2d_template(input,
output,
c10::nullopt,
grad_output_opt,
kernel_size,
stride,
padding,
{1, 1},
ceil_mode,
count_include_pad,
divisor_override,
pooling_op_block,
op_name);
}
} // namespace mps
Tensor mps_max_pool2d(
const Tensor& input,
Tensor mps_max_pool2d(const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode) {
Tensor output = at::empty({0}, input.options(), MemoryFormat::Contiguous);
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
return [mpsGraph maxPooling2DWithSourceTensor: cachedGraph.inputTensor
descriptor: desc
name: nil];
return [mpsGraph maxPooling2DWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil];
};
mps::pool2d_template(input, output, c10::nullopt, c10::nullopt, kernel_size, stride,
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d");
mps::pool2d_template(input,
output,
c10::nullopt,
c10::nullopt,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
false,
c10::nullopt,
pooling_op_block,
"max_pool2d");
return output;
}
Tensor mps_max_pool2d_backward(
const Tensor& grad_output,
Tensor mps_max_pool2d_backward(const Tensor& grad_output,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode) {
Tensor grad_input = at::empty(input.sizes(), input.options(), MemoryFormat::Contiguous);
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
return [mpsGraph maxPooling2DGradientWithGradientTensor: cachedGraph.gradOutputTensor
sourceTensor: cachedGraph.inputTensor
descriptor: desc
name: nil];
return [mpsGraph maxPooling2DGradientWithGradientTensor:cachedGraph.gradOutputTensor
sourceTensor:cachedGraph.inputTensor
descriptor:desc
name:nil];
};
mps::pool2d_template(input, grad_input, c10::nullopt, grad_output, kernel_size, stride,
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_backward");
mps::pool2d_template(input,
grad_input,
c10::nullopt,
grad_output,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
false,
c10::nullopt,
pooling_op_block,
"max_pool2d_backward");
return grad_input;
}
TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)(
const Tensor& input,
TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)
(const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
@ -361,27 +422,37 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)(
bool ceil_mode,
const Tensor& output,
const Tensor& indices) {
auto indices_memory_format = indices.suggest_memory_format();
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
NSArray<MPSGraphTensor*>* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor: cachedGraph.inputTensor
descriptor: desc
name: nil];
NSArray<MPSGraphTensor*>* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor:cachedGraph.inputTensor
descriptor:desc
name:nil];
cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long);
return poolOutputs[0];
};
mps::pool2d_template(input, output, indices, c10::nullopt, kernel_size, stride,
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_indices");
mps::pool2d_template(input,
output,
indices,
c10::nullopt,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
false,
c10::nullopt,
pooling_op_block,
"max_pool2d_indices");
if (indices_memory_format == MemoryFormat::ChannelsLast) {
const_cast<Tensor&>(indices) = indices.to(MemoryFormat::ChannelsLast);
}
}
TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)(
const Tensor& grad_output,
TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)
(const Tensor& grad_output,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
@ -390,20 +461,30 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)(
bool ceil_mode,
const Tensor& indices,
const Tensor& grad_input) {
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
return [mpsGraph maxPooling2DGradientWithGradientTensor: cachedGraph.gradOutputTensor
sourceTensor: cachedGraph.inputTensor
descriptor: desc
name: nil];
return [mpsGraph maxPooling2DGradientWithGradientTensor:cachedGraph.gradOutputTensor
sourceTensor:cachedGraph.inputTensor
descriptor:desc
name:nil];
};
mps::pool2d_template(input, grad_input, indices, grad_output, kernel_size, stride,
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_indices_backward");
mps::pool2d_template(input,
grad_input,
indices,
grad_output,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
false,
c10::nullopt,
pooling_op_block,
"max_pool2d_indices_backward");
}
TORCH_IMPL_FUNC(avg_pool2d_out_mps) (
const Tensor& input,
TORCH_IMPL_FUNC(avg_pool2d_out_mps)
(const Tensor& input,
int64_t kH,
int64_t kW,
int64_t dH,
@ -414,13 +495,21 @@ TORCH_IMPL_FUNC(avg_pool2d_out_mps) (
bool count_include_pad,
c10::optional<int64_t> divisor_override,
const Tensor& output) {
mps::avg_pool2d_template(input, output, c10::nullopt, {kH, kW}, {dH, dW}, {padH, padW},
{1, 1}, ceil_mode, count_include_pad, divisor_override, "avg_pool2d");
mps::avg_pool2d_template(input,
output,
c10::nullopt,
{kH, kW},
{dH, dW},
{padH, padW},
{1, 1},
ceil_mode,
count_include_pad,
divisor_override,
"avg_pool2d");
}
TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps) (
const Tensor& gradOutput,
TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps)
(const Tensor& gradOutput,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
@ -429,9 +518,17 @@ TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps) (
bool count_include_pad,
c10::optional<int64_t> divisor_override,
const Tensor& gradInput) {
mps::avg_pool2d_template(input, gradInput, gradOutput, kernel_size, stride, padding,
{1, 1}, ceil_mode, count_include_pad, divisor_override, "avg_pool2d_backward");
mps::avg_pool2d_template(input,
gradInput,
gradOutput,
kernel_size,
stride,
padding,
{1, 1},
ceil_mode,
count_include_pad,
divisor_override,
"avg_pool2d_backward");
}
} // namespace at::native
@ -1,9 +1,9 @@
// Copyright © 2022 Apple Inc.
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/AccumulateType.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
@ -15,37 +15,38 @@ namespace at::native {
namespace {
struct RangeCachedGraph : public mps::MPSCachedGraph {
API_AVAILABLE(macosx(12.3))
RangeCachedGraph(MPSGraph *mpsGraph, MPSDataType dataType, int32_t shapeVal, bool needsClamp = false, bool startLessEnd = false): MPSCachedGraph(mpsGraph) {
RangeCachedGraph(MPSGraph* mpsGraph,
MPSDataType dataType,
int32_t shapeVal,
bool needsClamp = false,
bool startLessEnd = false)
: MPSCachedGraph(mpsGraph) {
@autoreleasepool {
auto shapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:&shapeVal length:sizeof(int32_t)]
shape: @[@1]
shape:@[ @1 ]
dataType:MPSDataTypeInt32];
auto coordsTensor = [mpsGraph coordinateAlongAxis:0
withShapeTensor:shapeTensor
name:nil];
auto coordsTensor = [mpsGraph coordinateAlongAxis:0 withShapeTensor:shapeTensor name:nil];
coordsTensor = [mpsGraph castTensor:coordsTensor toType:dataType name:@"coords"];
startTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[@1]);
multiplyTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[@1]);
startTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[ @1 ]);
multiplyTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[ @1 ]);
auto scaledCoords = [mpsGraph multiplicationWithPrimaryTensor:coordsTensor
secondaryTensor:multiplyTensor
name:nil];
outputTensor = [mpsGraph additionWithPrimaryTensor:scaledCoords
secondaryTensor:startTensor
name:nil];
outputTensor = [mpsGraph additionWithPrimaryTensor:scaledCoords secondaryTensor:startTensor name:nil];
if (needsClamp) {
endTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[@1]);
endTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[ @1 ]);
outputTensor = [mpsGraph clampWithTensor:outputTensor
minValueTensor: startLessEnd? startTensor : endTensor
maxValueTensor: startLessEnd? endTensor : startTensor
name: nil];
minValueTensor:startLessEnd ? startTensor : endTensor
maxValueTensor:startLessEnd ? endTensor : startTensor
name:nil];
}
}
}
MPSGraphTensor *startTensor = nil;
MPSGraphTensor *endTensor = nil;
MPSGraphTensor *multiplyTensor = nil;
MPSGraphTensor *outputTensor = nil;
MPSGraphTensor* startTensor = nil;
MPSGraphTensor* endTensor = nil;
MPSGraphTensor* multiplyTensor = nil;
MPSGraphTensor* outputTensor = nil;
};
} // anonymous namespace
@ -59,17 +60,17 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste
double size_d;
if (std::is_same<scalar_t, int64_t>::value) {
size_d = std::ceil(static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>())
/ step.to<accscalar_t>());
size_d = std::ceil(static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>()) / step.to<accscalar_t>());
} else {
size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>())
/ step.to<double>());
size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>()) / step.to<double>());
}
TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
std::isfinite(static_cast<double>(xend)),
"unsupported range: ", xstart, " -> ", xend);
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) && std::isfinite(static_cast<double>(xend)),
"unsupported range: ",
xstart,
" -> ",
xend);
TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
"upper bound and larger bound inconsistent with step sign");
@ -79,11 +80,17 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste
int64_t numel = result.numel();
if (numel != size) {
if(numel > 0){
TORCH_WARN("The number of elements in the out tensor of shape ", result.sizes(),
" is ", numel, " which does not match the computed number of elements ", size,
if (numel > 0) {
TORCH_WARN("The number of elements in the out tensor of shape ",
result.sizes(),
" is ",
numel,
" which does not match the computed number of elements ",
size,
". Note that this may occur as a result of rounding error. "
"The out tensor will be resized to a tensor of shape (", size, ",).");
"The out tensor will be resized to a tensor of shape (",
size,
",).");
}
result.resize_({size});
}
@ -100,28 +107,27 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste
auto mpsDataType = getMPSDataType(result);
@autoreleasepool {
string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size);
auto cachedGraph = static_cast<RangeCachedGraph *>(cache_->LookUp(key));
auto cachedGraph = static_cast<RangeCachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
auto *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph *() {
auto* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
auto mpsGraph = make_mps_graph();
return new RangeCachedGraph(mpsGraph, mpsDataType, size);
});
cachedGraph = static_cast<RangeCachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<RangeCachedGraph*>(tmpCachedGraph);
}
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, r);
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
MPSScalar startScalar = getMPSScalar(start, result.scalar_type());
feeds[cachedGraph->startTensor] = getMPSGraphTensorFromScalar(stream, startScalar);
MPSScalar stepScalar = getMPSScalar(step, result.scalar_type());
feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, stepScalar);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
if(!is_contiguous) {
if (!is_contiguous) {
result.copy_(r);
}
});
@ -139,17 +145,17 @@ Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step
// double size_d = ((xend - xstart) / xstep) + 1;
double size_d;
if (std::is_same<scalar_t, int64_t>::value) {
size_d = static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>())
/ step.to<accscalar_t>() + 1;
size_d = static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>()) / step.to<accscalar_t>() + 1;
} else {
size_d = static_cast<double>(end.to<double>() - start.to<double>())
/ step.to<double>() + 1;
size_d = static_cast<double>(end.to<double>() - start.to<double>()) / step.to<double>() + 1;
}
TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
std::isfinite(static_cast<double>(xend)),
"unsupported range: ", xstart, " -> ", xend);
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) && std::isfinite(static_cast<double>(xend)),
"unsupported range: ",
xstart,
" -> ",
xend);
TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
"upper bound and larger bound inconsistent with step sign");
@ -171,28 +177,27 @@ Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step
auto mpsDataType = getMPSDataType(result);
@autoreleasepool {
string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size);
auto cachedGraph = static_cast<RangeCachedGraph *>(cache_->LookUp(key));
auto cachedGraph = static_cast<RangeCachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
auto *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph *() {
auto* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
auto mpsGraph = make_mps_graph();
return new RangeCachedGraph(mpsGraph, mpsDataType, size);
});
cachedGraph = static_cast<RangeCachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<RangeCachedGraph*>(tmpCachedGraph);
}
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, r);
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
MPSScalar startScalar = getMPSScalar(start, result.scalar_type());
feeds[cachedGraph->startTensor] = getMPSGraphTensorFromScalar(stream, startScalar);
MPSScalar stepScalar = getMPSScalar(step, result.scalar_type());
feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, stepScalar);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
if(!is_contiguous) {
if (!is_contiguous) {
result.copy_(r);
}
});
@ -222,28 +227,30 @@ Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps,
bool start_less_end = (start.to<double>() <= end.to<double>());
@autoreleasepool {
string key = "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps) + to_string(start_less_end);
RangeCachedGraph* cachedGraph = static_cast<RangeCachedGraph *>(cache_->LookUp(key));
string key =
"linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps) + to_string(start_less_end);
RangeCachedGraph* cachedGraph = static_cast<RangeCachedGraph*>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
RangeCachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
RangeCachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new RangeCachedGraph(mpsGraph, MPSDataTypeFloat32, steps, true, start_less_end);
if(getMPSDataType(result) != MPSDataTypeFloat32) {
newCachedGraph->outputTensor = [mpsGraph castTensor:newCachedGraph->outputTensor toType:getMPSDataType(result) name:@"output"];
if (getMPSDataType(result) != MPSDataTypeFloat32) {
newCachedGraph->outputTensor = [mpsGraph castTensor:newCachedGraph->outputTensor
toType:getMPSDataType(result)
name:@"output"];
}
}
return newCachedGraph;
});
cachedGraph = static_cast<RangeCachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<RangeCachedGraph*>(tmpCachedGraph);
}
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
auto multiply = (end.to<double>() - start.to<double>()) / ((double)steps - 1.0f);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, r);
@ -255,9 +262,8 @@ Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps,
MPSScalar multiplyScalar = getMPSScalar(multiply, ScalarType::Float);
feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, multiplyScalar);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
File diff suppressed because it is too large
@ -8,8 +8,8 @@
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Repeat.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>
#include <fmt/format.h>
#include <torch/library.h>
#ifdef __OBJC__
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
@ -19,8 +19,7 @@ namespace at::native {
Tensor permute_mps(const Tensor& self, IntArrayRef dims) {
auto nDims = self.dim();
TORCH_CHECK(dims.size() == (size_t)nDims,
"number of dims don't match in permute");
TORCH_CHECK(dims.size() == (size_t)nDims, "number of dims don't match in permute");
auto oldSizes = self.sizes();
auto oldStrides = self.strides();
DimVector newSizes(nDims);
@ -28,8 +27,7 @@ Tensor permute_mps(const Tensor& self, IntArrayRef dims) {
std::vector<bool> seen(nDims);
for (const auto i : c10::irange(nDims)) {
auto dim = maybe_wrap_dim(dims[i], nDims);
TORCH_CHECK(!seen[dim],
"repeated dim in permute");
TORCH_CHECK(!seen[dim], "repeated dim in permute");
seen[dim] = true;
newSizes[i] = oldSizes[dim];
newStrides[i] = oldStrides[dim];
@ -38,16 +36,14 @@ Tensor permute_mps(const Tensor& self, IntArrayRef dims) {
}
Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
using namespace mps;
TORCH_CHECK(repeats.size() >= (size_t)self.dim(),
"Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
// Add new leading dimensions to the tensor if the
@ -58,7 +54,7 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
padded_size.insert(padded_size.end(), self.sizes().begin(), self.sizes().end());
DimVector target_size(repeats.size());
bool zero_tensor = false;
for(const auto idx : c10::irange(repeats.size())) {
for (const auto idx : c10::irange(repeats.size())) {
if (repeats[idx] == 0) {
zero_tensor = true;
}
@ -68,7 +64,7 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
Tensor expanded_tensor = self.expand(padded_size);
Tensor result = at::empty(target_size, self.options());
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
if(zero_tensor || result.numel() == 0) {
if (zero_tensor || result.numel() == 0) {
return result;
}
@ -86,40 +82,37 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
@autoreleasepool {
string key = "repeat_mps:" + getTensorsStringKey(self) + ":" + getArrayRefString(repeats);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(expanded_tensor));
MPSGraphTensor* outputTensor = [mpsGraph tileTensor:inputTensor
withMultiplier:getMPSShape(repeats)
name:nil];
MPSGraphTensor* inputTensor =
mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(expanded_tensor));
MPSGraphTensor* outputTensor = [mpsGraph tileTensor:inputTensor withMultiplier:getMPSShape(repeats) name:nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder selfPlaceholder = Placeholder(
cachedGraph->inputTensor_, expanded_tensor, /*mpsShape=*/nil, /*gatherTensorData=*/true, inputDataType);
Placeholder outputPlaceholder = Placeholder(
cachedGraph->outputTensor_, result, /*mpsShape=*/nil, /*gatherTensorData*/false, outputDataType);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, result, /*mpsShape=*/nil, /*gatherTensorData*/ false, outputDataType);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
@{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -142,18 +135,18 @@ kernel void repeat_interleave(constant {0} * repeat_ptr [[buf
}}
)METAL_REPEAT";
static
id<MTLLibrary> compileRepeatInterleaveLib(id<MTLDevice> device, const std::string& t1) {
static id<MTLLibrary> compileRepeatInterleaveLib(id<MTLDevice> device, const std::string& t1) {
auto key = t1;
static std::unordered_map<std::string, id<MTLLibrary>> libMap;
auto it = libMap.find(key);
if (it != libMap.end()) {
return it->second;
}
NSError *error = nil;
MTLCompileOptions *options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion: MTLLanguageVersion2_3];
auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(METAL_REPEAT_INTERLEAVE, t1).c_str()]
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
auto rc =
[device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(METAL_REPEAT_INTERLEAVE, t1).c_str()]
options:options
error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
@ -161,8 +154,7 @@ id<MTLLibrary> compileRepeatInterleaveLib(id<MTLDevice> device, const std::strin
return rc;
}
static
id<MTLComputePipelineState> getPipelineState(id<MTLDevice> device, const std::string& t1) {
static id<MTLComputePipelineState> getPipelineState(id<MTLDevice> device, const std::string& t1) {
static std::string kernel = "repeat_interleave";
auto key = kernel + t1;
static std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
@ -170,19 +162,19 @@ id<MTLComputePipelineState> getPipelineState(id<MTLDevice> device, const std::st
if (it != cplMap.end()) {
return it->second;
}
NSError *error = nil;
NSError* error = nil;
auto library = compileRepeatInterleaveLib(device, t1);
id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
TORCH_CHECK(func != nil, "Can't get kernel ", kernel);
auto rc = [device newComputePipelineStateWithFunction:func error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
TORCH_CHECK(
rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
cplMap[key] = rc;
return rc;
}
template <typename index_t>
void computeRepeatIndices(
index_t* repeat_ptr,
void computeRepeatIndices(index_t* repeat_ptr,
int64_t* cumsum_ptr,
index_t* result_ptr,
int64_t size,
@ -208,7 +200,7 @@ void computeRepeatIndices(
id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
id<MTLComputePipelineState> pipelineState = getPipelineState(MPSDevice::getInstance()->device(), scalar_type);
[computeEncoder setComputePipelineState: pipelineState];
[computeEncoder setComputePipelineState:pipelineState];
[computeEncoder setBuffer:repeatBuffer offset:0 atIndex:0];
[computeEncoder setBuffer:cumsumBuffer offset:0 atIndex:1];
[computeEncoder setBuffer:resultBuffer offset:0 atIndex:2];
@ -233,12 +225,12 @@ Tensor repeat_interleave_mps(const Tensor& repeat_, c10::optional<int64_t> outpu
if (repeat.scalar_type() == kLong && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
// #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of output,
// which currently doesn't support int64_t as input. Casting internally the indices to int32_t.
TORCH_WARN_ONCE("MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3");
TORCH_WARN_ONCE(
"MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3");
repeat = repeat.to(kInt);
}
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(
repeat, output_size);
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
});
return output;
}

File diff suppressed because it is too large
@ -1,13 +1,13 @@
// Copyright © 2022 Apple Inc.
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/Copy.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>
#ifdef __OBJC__
@ -21,8 +21,12 @@ namespace at::native {
Scalar _local_scalar_dense_mps(const Tensor& self) {
Scalar r;
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_mps", [&] {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half,
at::ScalarType::Bool,
at::ScalarType::BFloat16,
self.scalar_type(),
"_local_scalar_dense_mps",
[&] {
Tensor output = at::empty_like(self, kCPU);
Tensor cpu_output = mps::mps_copy_(output, self, false);
@ -33,5 +37,4 @@ Scalar _local_scalar_dense_mps(const Tensor& self) {
return r;
}
} // namespace at::native

@ -5,12 +5,7 @@
namespace at::native {
TORCH_IMPL_FUNC(gather_out_mps)
(const Tensor & self_arg,
int64_t dim,
const Tensor & index,
bool sparse_grad,
const Tensor & output)
{
(const Tensor& self_arg, int64_t dim, const Tensor& index, bool sparse_grad, const Tensor& output) {
using namespace mps;
if (self_arg.numel() == 0 || index.numel() == 0) {
@ -20,14 +15,11 @@ TORCH_IMPL_FUNC(gather_out_mps)
dim = at::maybe_wrap_dim(dim, self.dim());
TORCH_CHECK(!sparse_grad, "sparse_grad not supported in MPS yet")
TORCH_CHECK(self.scalar_type() == output.scalar_type(),
"gather(): self and output must have the same scalar type");
TORCH_CHECK(dim >= 0 && dim < self.dim(),
"gather(): Indexing dim ", dim, " is out of bounds of tensor");
TORCH_CHECK(self.scalar_type() == output.scalar_type(), "gather(): self and output must have the same scalar type");
TORCH_CHECK(dim >= 0 && dim < self.dim(), "gather(): Indexing dim ", dim, " is out of bounds of tensor");
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* indexTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
@ -36,7 +28,6 @@ TORCH_IMPL_FUNC(gather_out_mps)
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
MPSShape* input_shape = getMPSShape(self);
MPSShape* index_shape = getMPSShape(index);
uint32_t num_input_dims = [input_shape count];
@ -47,8 +38,9 @@ TORCH_IMPL_FUNC(gather_out_mps)
bool needSlice = false;
for (const auto i : c10::irange(num_input_dims)) {
TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis")
if(i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue],
"Index dim must not exceed input dim except at gathering axis")
if (i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
needSlice = true;
}
auto input_type = getMPSDataType(self);
@ -60,11 +52,11 @@ TORCH_IMPL_FUNC(gather_out_mps)
output_type = MPSDataTypeInt8;
}
string key = "gather_out_mps" + getTensorsStringKey({self, index, output}) + ":" + std::to_string(dim);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
@ -76,10 +68,10 @@ TORCH_IMPL_FUNC(gather_out_mps)
MPSGraphTensor* getInput = inputTensor;
// Slice into the input tensor IF NEEDED
if(needSlice) {
NSMutableArray<NSNumber*> *starts = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
NSMutableArray<NSNumber*> *ends = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
NSMutableArray<NSNumber*> *strides = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
if (needSlice) {
NSMutableArray<NSNumber*>* starts = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
NSMutableArray<NSNumber*>* ends = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
NSMutableArray<NSNumber*>* strides = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
for (const auto i : c10::irange(num_input_dims)) {
// All strides are 1
@ -89,23 +81,19 @@ TORCH_IMPL_FUNC(gather_out_mps)
ends[i] = (i != dim) ? index_shape[i] : input_shape[i];
}
getInput = [mpsGraph sliceTensor:inputTensor
starts:starts
ends:ends
strides:strides
name:nil];
getInput = [mpsGraph sliceTensor:inputTensor starts:starts ends:ends strides:strides name:nil];
}
MPSGraphTensor* castIndexTensor = [mpsGraph castTensor:indexTensor
toType:MPSDataTypeInt32
name:(NSString * _Nonnull)nil];
name:(NSString* _Nonnull)nil];
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wobjc-method-access"
MPSGraphTensor* outputTensor = [mpsGraph gatherAlongAxis: (NSInteger) dim
withUpdatesTensor: getInput
indicesTensor: castIndexTensor
name: nil];
MPSGraphTensor* outputTensor = [mpsGraph gatherAlongAxis:(NSInteger)dim
withUpdatesTensor:getInput
indicesTensor:castIndexTensor
name:nil];
#pragma clang diagnostic pop
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->indexTensor_ = indexTensor;
@ -113,7 +101,7 @@ TORCH_IMPL_FUNC(gather_out_mps)
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape, true, input_type);
@ -124,23 +112,20 @@ TORCH_IMPL_FUNC(gather_out_mps)
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
indexPlaceholder.getMPSGraphTensor() : indexPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
}
void scatter_mps_general
(const Tensor& self_arg,
void scatter_mps_general(const Tensor& self_arg,
int64_t dim,
const Tensor& index,
const Tensor& src,
const Tensor& output,
string func_name,
const c10::string_view reduce)
{
const c10::string_view reduce) {
using namespace mps;
if (self_arg.numel() == 0 || index.numel() == 0 || src.numel() == 0) {
@ -151,12 +136,10 @@ void scatter_mps_general
TORCH_CHECK(self.scalar_type() == output.scalar_type() && output.scalar_type() == src.scalar_type(),
"scatter(): self, src and output must have the same scalar type");
TORCH_CHECK(dim >= 0 && dim < self.dim(),
"scatter(): Indexing dim ", dim, " is out of bounds of tensor");
TORCH_CHECK(dim >= 0 && dim < self.dim(), "scatter(): Indexing dim ", dim, " is out of bounds of tensor");
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* indexTensor_ = nil;
MPSGraphTensor* srcTensor_ = nil;
@ -166,7 +149,6 @@ void scatter_mps_general
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
MPSShape* input_shape = getMPSShape(self);
MPSShape* index_shape = getMPSShape(index);
MPSShape* src_shape = getMPSShape(src);
@ -174,7 +156,8 @@ void scatter_mps_general
uint32_t num_index_dims = [index_shape count];
uint32_t num_src_dims = [src_shape count];
TORCH_CHECK(num_input_dims == num_index_dims && num_index_dims == num_src_dims, "Input, index and src must have same rank")
TORCH_CHECK(num_input_dims == num_index_dims && num_index_dims == num_src_dims,
"Input, index and src must have same rank")
// Do we need to slice into the src tensor?
bool needSlice = false;
@ -182,11 +165,13 @@ void scatter_mps_general
bool needsCast = false;
for (const auto i : c10::irange(num_input_dims)) {
TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis")
TORCH_CHECK([index_shape[i] intValue] <= [src_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis")
if([index_shape[i] intValue] < [src_shape[i] intValue])
TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue],
"Index dim must not exceed input dim except at gathering axis")
TORCH_CHECK([index_shape[i] intValue] <= [src_shape[i] intValue],
"Index dim must not exceed input dim except at gathering axis")
if ([index_shape[i] intValue] < [src_shape[i] intValue])
needSlice = true;
if(i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
if (i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
inputNeedSlice = true;
}
TORCH_CHECK(reduce != "mean", "Scatter reduce mean mode not yet supported in MPS")
@ -197,11 +182,12 @@ void scatter_mps_general
needsCast = true;
}
string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" + std::string(reduce);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" +
std::string(reduce);
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
@ -229,9 +215,9 @@ void scatter_mps_general
// Slice into the src or input tensors IF NEEDED
if (needSlice || inputNeedSlice) {
NSMutableArray<NSNumber*> *starts = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
NSMutableArray<NSNumber*> *strides = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
NSMutableArray<NSNumber*> *ends_src = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
NSMutableArray<NSNumber*>* starts = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
NSMutableArray<NSNumber*>* strides = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
NSMutableArray<NSNumber*>* ends_src = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
for (const auto i : c10::irange(num_input_dims)) {
strides[i] = @1;
@ -240,11 +226,7 @@ void scatter_mps_general
scatterInputShape[i] = (i != dim) ? index_shape[i] : input_shape[i];
}
if (needSlice) {
slicedSrc = [mpsGraph sliceTensor:castSrcTensor
starts:starts
ends:ends_src
strides:strides
name:nil];
slicedSrc = [mpsGraph sliceTensor:castSrcTensor starts:starts ends:ends_src strides:strides name:nil];
}
if (inputNeedSlice) {
slicedInput = [mpsGraph sliceTensor:castInputTensor
@ -256,28 +238,29 @@ void scatter_mps_general
}
MPSGraphScatterMode scatter_mode = MPSGraphScatterModeSet;
if(reduce == "sum" || reduce == "add")
if (reduce == "sum" || reduce == "add")
scatter_mode = MPSGraphScatterModeAdd;
else if(reduce == "prod" || reduce == "multiply")
else if (reduce == "prod" || reduce == "multiply")
scatter_mode = MPSGraphScatterModeMul;
else if(reduce == "amax")
else if (reduce == "amax")
scatter_mode = MPSGraphScatterModeMax;
else if(reduce == "amin")
else if (reduce == "amin")
scatter_mode = MPSGraphScatterModeMin;
// Scatter this into the input with set mode
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wobjc-method-access"
MPSGraphTensor* scatterTensor = [mpsGraph scatterAlongAxis: (NSInteger) dim
withDataTensor: slicedInput
updatesTensor: slicedSrc
indicesTensor: castIndexTensor
mode: scatter_mode
name: nil];
MPSGraphTensor* scatterTensor = [mpsGraph scatterAlongAxis:(NSInteger)dim
withDataTensor:slicedInput
updatesTensor:slicedSrc
indicesTensor:castIndexTensor
mode:scatter_mode
name:nil];
#pragma clang diagnostic pop
if(inputNeedSlice) {
if (inputNeedSlice) {
// Make an array of scatter indices tensors
NSMutableArray<MPSGraphTensor*>* indicesTensors = [NSMutableArray<MPSGraphTensor*> arrayWithCapacity:num_input_dims];
NSMutableArray<MPSGraphTensor*>* indicesTensors =
[NSMutableArray<MPSGraphTensor*> arrayWithCapacity:num_input_dims];
// 1. Concatenate the coord tensors
// 2. Flatten the values
@ -289,18 +272,18 @@ void scatter_mps_general
shape_data[i] = {[scatterInputShape[i] intValue]};
}
MPSGraphTensor* scatterInputShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:shape_data.data() length:num_input_dims * sizeof(int)]
shape:@[[NSNumber numberWithUnsignedInt:num_input_dims]]
MPSGraphTensor* scatterInputShapeTensor =
[mpsGraph constantWithData:[NSData dataWithBytes:shape_data.data() length:num_input_dims * sizeof(int)]
shape:@[ [NSNumber numberWithUnsignedInt:num_input_dims] ]
dataType:MPSDataTypeInt32];
for (const auto i : c10::irange(num_input_dims)) {
MPSGraphTensor* axisTensor = [mpsGraph constantWithScalar:i
dataType:MPSDataTypeInt32];
MPSGraphTensor* scatter_currentIndexTensor = [mpsGraph coordinateAlongAxisTensor: axisTensor
withShapeTensor: scatterInputShapeTensor
name: nil];
MPSGraphTensor* axisTensor = [mpsGraph constantWithScalar:i dataType:MPSDataTypeInt32];
MPSGraphTensor* scatter_currentIndexTensor = [mpsGraph coordinateAlongAxisTensor:axisTensor
withShapeTensor:scatterInputShapeTensor
name:nil];
scatter_currentIndexTensor = [mpsGraph reshapeTensor:scatter_currentIndexTensor
withShape:@[@-1, @1]
withShape:@[ @-1, @1 ]
name:nil];
indicesTensors[i] = scatter_currentIndexTensor;
}
@ -309,9 +292,7 @@ void scatter_mps_general
dimension:(NSInteger)1
name:nil];
MPSGraphTensor* flatValuesTensor = [mpsGraph reshapeTensor:scatterTensor
withShape:@[@-1]
name:nil];
MPSGraphTensor* flatValuesTensor = [mpsGraph reshapeTensor:scatterTensor withShape:@[ @-1 ] name:nil];
outputTensor = [mpsGraph scatterNDWithDataTensor:castInputTensor
updatesTensor:flatValuesTensor
@ -325,11 +306,12 @@ void scatter_mps_general
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->srcTensor_ = srcTensor;
newCachedGraph->indexTensor_ = indexTensor;
newCachedGraph->outputTensor_ = needsCast ? castMPSTensor(mpsGraph, outputTensor, output.scalar_type()) : outputTensor;
newCachedGraph->outputTensor_ =
needsCast ? castMPSTensor(mpsGraph, outputTensor, output.scalar_type()) : outputTensor;
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape);
@ -342,41 +324,24 @@ void scatter_mps_general
srcPlaceholder.getMPSGraphTensor() : srcPlaceholder.getMPSGraphTensorData(),
indexPlaceholder.getMPSGraphTensor() : indexPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
}
TORCH_IMPL_FUNC(scatter_src_out_mps)
(const Tensor& self,
int64_t dim,
const Tensor& index,
const Tensor& src,
const Tensor& output) {
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const Tensor& output) {
scatter_mps_general(self, dim, index, src, output, "scatter_src_out_mps", "set");
}
TORCH_IMPL_FUNC(scatter_value_out_mps)
(const Tensor& self,
int64_t dim,
const Tensor& index,
const Scalar& value,
const Tensor& output) {
Tensor src = at::native::empty_mps(index.sizes(),
self.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
self.suggest_memory_format());
(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value, const Tensor& output) {
Tensor src = at::native::empty_mps(
index.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, self.suggest_memory_format());
src.fill_(value);
scatter_mps_general(self, dim, index, const_cast<Tensor&>(src), output, "scatter_value_out_mps", "set");
}
TORCH_IMPL_FUNC(scatter_reduce_out_mps)
@ -386,9 +351,7 @@ TORCH_IMPL_FUNC(scatter_reduce_out_mps)
const Tensor& src,
const c10::string_view reduce,
const Tensor& output) {
scatter_mps_general(self, dim, index, src, output, "scatter_reduce_out_mps", reduce);
}
TORCH_IMPL_FUNC(scatter_value_reduce_out_mps)
@ -398,25 +361,14 @@ TORCH_IMPL_FUNC(scatter_value_reduce_out_mps)
const Scalar& value,
const c10::string_view reduce,
const Tensor& output) {
Tensor src = at::native::empty_mps(index.sizes(),
self.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
self.suggest_memory_format());
Tensor src = at::native::empty_mps(
index.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, self.suggest_memory_format());
src.fill_(value);
scatter_mps_general(self, dim, index, const_cast<Tensor&>(src), output, "scatter_value_reduce_out_mps", reduce);
}
TORCH_IMPL_FUNC(scatter_add_mps_out)
(const Tensor& self,
int64_t dim,
const Tensor& index,
const Tensor& src,
const Tensor& output) {
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const Tensor& output) {
scatter_mps_general(self, dim, index, src, output, "scatter_add_mps_out", "add");
}

@ -2,10 +2,10 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/TensorShape.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
namespace at::native {
@ -27,21 +27,12 @@ std::vector<int64_t> getTopK0Shape(IntArrayRef sizes, const int64_t dim_) {
// topk
TORCH_IMPL_FUNC(topk_out_mps)
(const Tensor& self,
int64_t k,
int64_t dim_,
bool largest,
bool sorted,
const Tensor& values,
const Tensor& indices)
{
(const Tensor& self, int64_t k, int64_t dim_, bool largest, bool sorted, const Tensor& values, const Tensor& indices) {
using namespace mps;
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
TORCH_CHECK(
k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
"selected index k out of range");
TORCH_CHECK(k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1), "selected index k out of range");
if (!is_macos_13_or_newer() && (k>16)) {
if (!is_macos_13_or_newer() && (k > 16)) {
TORCH_WARN_ONCE("torch.topk support for k>16 by MPS on MacOS 13+, please upgrade");
Tensor cpu_indices = indices.clone().to("cpu");
Tensor cpu_values = values.clone().to("cpu");
@ -58,15 +49,13 @@ TORCH_IMPL_FUNC(topk_out_mps)
}
// Handle empty tensors
if (self.numel() == 0)
{
if (self.numel() == 0) {
values.copy_(self);
indices.copy_(values.toType(at::ScalarType::Long));
return;
}
// Handle k == 0 case. Needed because MPSGraph does not support k == 0.
if (k == 0)
{
if (k == 0) {
const auto out_shape = getTopK0Shape(self.sizes(), dim);
values.resize_(out_shape);
indices.copy_(values.toType(at::ScalarType::Long));
@ -75,7 +64,7 @@ TORCH_IMPL_FUNC(topk_out_mps)
MPSStream* stream = getCurrentMPSStream();
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *selfTensor = nil, *valuesTensor = nil, *indicesTensor = nil;
};
@ -85,14 +74,12 @@ TORCH_IMPL_FUNC(topk_out_mps)
// Input as placeholders
MPSShape* input_shape = getMPSShape(self);
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
string key = string("topk:") + [ns_shape_key UTF8String] + ":" +
getMPSTypeString(self) +
":k" + to_string(k) + ":dim" + to_string(dim_) +
":largest" + to_string(largest);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + to_string(k) +
":dim" + to_string(dim_) + ":largest" + to_string(largest);
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
@ -102,22 +89,19 @@ TORCH_IMPL_FUNC(topk_out_mps)
MPSGraphTensor* castInputTensor = newCachedGraph->selfTensor;
MPSDataType dataType = getMPSDataType(self);
// #issue 104398441 sortWithTensor and argsortWithTensor
if (dataType != MPSDataTypeInt32 &&
dataType != MPSDataTypeFloat32 &&
dataType != MPSDataTypeFloat16) {
if (dataType != MPSDataTypeInt32 && dataType != MPSDataTypeFloat32 && dataType != MPSDataTypeFloat16) {
dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
castInputTensor = [mpsGraph castTensor:newCachedGraph->selfTensor
toType:dataType
name:@"castInputTensor"];
}
MPSGraphTensor * sortedTensor = [mpsGraph sortWithTensor:castInputTensor
MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor
axis:(NSUInteger)dim
descending:largest
name:nil];
sortedTensor = [mpsGraph sliceTensor:sortedTensor
dimension:(NSUInteger)dim
start:((NSUInteger) 0)
length:k
start:((NSUInteger)0)length:k
name:nil];
MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor
axis:(NSInteger)dim
@ -125,8 +109,7 @@ TORCH_IMPL_FUNC(topk_out_mps)
name:@"argmax_out"];
argSortedTensor = [mpsGraph sliceTensor:argSortedTensor
dimension:dim
start:((NSUInteger) 0)
length:k
start:((NSUInteger)0)length:k
name:nil];
newCachedGraph->valuesTensor = sortedTensor;
newCachedGraph->indicesTensor = argSortedTensor;
@ -134,68 +117,54 @@ TORCH_IMPL_FUNC(topk_out_mps)
} else {
if ((dim_ != -1 && dim_ != self.dim() - 1) && (!largest)) {
// transpose and negate
MPSGraphTensor *transposedInput = [mpsGraph transposeTensor: newCachedGraph->selfTensor
dimension: (NSUInteger)self.dim()-1
withDimension: (NSUInteger)dim_
name: nil];
MPSGraphTensor * identity = [mpsGraph identityWithTensor: transposedInput
name: nil];
MPSGraphTensor * negatedTransposedInput = [mpsGraph negativeWithTensor:identity
name: nil];
NSArray<MPSGraphTensor *> * outputMPSGraphTensors = [mpsGraph
topKWithSourceTensor:negatedTransposedInput
k:((NSUInteger) k)
MPSGraphTensor* transposedInput = [mpsGraph transposeTensor:newCachedGraph->selfTensor
dimension:(NSUInteger)self.dim() - 1
withDimension:(NSUInteger)dim_
name:nil];
MPSGraphTensor* identity = [mpsGraph identityWithTensor:transposedInput name:nil];
MPSGraphTensor* negatedTransposedInput = [mpsGraph negativeWithTensor:identity name:nil];
NSArray<MPSGraphTensor*>* outputMPSGraphTensors = [mpsGraph topKWithSourceTensor:negatedTransposedInput
k:((NSUInteger)k)name:nil];
MPSGraphTensor* valuesNegatedTransposed = outputMPSGraphTensors[0];
MPSGraphTensor* indicesTransposed = outputMPSGraphTensors[1];
MPSGraphTensor* valuesNegated = [mpsGraph transposeTensor:valuesNegatedTransposed
dimension:(NSUInteger)self.dim() - 1
withDimension:(NSUInteger)dim_
name:nil];
newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated name:nil];
newCachedGraph->indicesTensor = [mpsGraph transposeTensor:indicesTransposed
dimension:(NSUInteger)self.dim() - 1
withDimension:(NSUInteger)dim_
name:nil];
MPSGraphTensor *valuesNegatedTransposed = outputMPSGraphTensors[0];
MPSGraphTensor *indicesTransposed = outputMPSGraphTensors[1];
MPSGraphTensor *valuesNegated = [mpsGraph transposeTensor: valuesNegatedTransposed
dimension: (NSUInteger)self.dim()-1
withDimension: (NSUInteger)dim_
name: nil];
newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated
name: nil];
newCachedGraph->indicesTensor = [mpsGraph transposeTensor: indicesTransposed
dimension: (NSUInteger)self.dim()-1
withDimension: (NSUInteger)dim_
name: nil];
} else if (dim_ != -1 && dim_ != self.dim() - 1) {
MPSGraphTensor *transposedInput = [mpsGraph transposeTensor: newCachedGraph->selfTensor
dimension: (NSUInteger)self.dim()-1
withDimension: (NSUInteger)dim_
name: nil];
MPSGraphTensor * identity = [mpsGraph identityWithTensor: transposedInput
name: nil];
NSArray<MPSGraphTensor *> * outputMPSGraphTensors = [mpsGraph
topKWithSourceTensor:identity
k:((NSUInteger) k)
MPSGraphTensor* transposedInput = [mpsGraph transposeTensor:newCachedGraph->selfTensor
dimension:(NSUInteger)self.dim() - 1
withDimension:(NSUInteger)dim_
name:nil];
MPSGraphTensor *valuesTransposed = outputMPSGraphTensors[0];
MPSGraphTensor *indicesTransposed = outputMPSGraphTensors[1];
MPSGraphTensor* identity = [mpsGraph identityWithTensor:transposedInput name:nil];
NSArray<MPSGraphTensor*>* outputMPSGraphTensors = [mpsGraph topKWithSourceTensor:identity
k:((NSUInteger)k)name:nil];
MPSGraphTensor* valuesTransposed = outputMPSGraphTensors[0];
MPSGraphTensor* indicesTransposed = outputMPSGraphTensors[1];
newCachedGraph->valuesTensor = [mpsGraph transposeTensor:valuesTransposed
dimension: (NSUInteger)self.dim()-1
withDimension: (NSUInteger)dim_
name: nil];
newCachedGraph->indicesTensor = [mpsGraph transposeTensor: indicesTransposed
dimension: (NSUInteger)self.dim()-1
withDimension: (NSUInteger)dim_
name: nil];
dimension:(NSUInteger)self.dim() - 1
withDimension:(NSUInteger)dim_
name:nil];
newCachedGraph->indicesTensor = [mpsGraph transposeTensor:indicesTransposed
dimension:(NSUInteger)self.dim() - 1
withDimension:(NSUInteger)dim_
name:nil];
} else if (!largest) {
// only negate
MPSGraphTensor *negatedInput = [mpsGraph negativeWithTensor:newCachedGraph->selfTensor
name: nil];
NSArray<MPSGraphTensor *> * outputMPSGraphTensors = [mpsGraph
topKWithSourceTensor:negatedInput
k:((NSUInteger) k)
name:nil];
MPSGraphTensor *valuesNegated = outputMPSGraphTensors[0];
newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated
name: nil];
MPSGraphTensor* negatedInput = [mpsGraph negativeWithTensor:newCachedGraph->selfTensor name:nil];
NSArray<MPSGraphTensor*>* outputMPSGraphTensors = [mpsGraph topKWithSourceTensor:negatedInput
k:((NSUInteger)k)name:nil];
MPSGraphTensor* valuesNegated = outputMPSGraphTensors[0];
newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated name:nil];
newCachedGraph->indicesTensor = outputMPSGraphTensors[1];
} else {
NSArray<MPSGraphTensor *> * outputMPSGraphTensors = [mpsGraph
topKWithSourceTensor:newCachedGraph->selfTensor
k:((NSUInteger) k)
name:nil];
NSArray<MPSGraphTensor*>* outputMPSGraphTensors =
[mpsGraph topKWithSourceTensor:newCachedGraph->selfTensor k:((NSUInteger)k)name:nil];
newCachedGraph->valuesTensor = outputMPSGraphTensors[0];
newCachedGraph->indicesTensor = outputMPSGraphTensors[1];
}
@ -210,29 +179,21 @@ TORCH_IMPL_FUNC(topk_out_mps)
Placeholder indicesPlaceholder = Placeholder(cachedGraph->indicesTensor, indices);
// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
feeds = @{
inputPlaceholder.getMPSGraphTensor() :
inputPlaceholder.getMPSGraphTensorData()
};
feeds = @{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
valuesPlaceholder.getMPSGraphTensor() :
valuesPlaceholder.getMPSGraphTensorData(),
indicesPlaceholder.getMPSGraphTensor() :
indicesPlaceholder.getMPSGraphTensorData()
valuesPlaceholder.getMPSGraphTensor() : valuesPlaceholder.getMPSGraphTensorData(),
indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData()
};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
}
void check_shape_except_dim(const Tensor &first, const Tensor &second,
int dimension, int index)
{
void check_shape_except_dim(const Tensor& first, const Tensor& second, int dimension, int index) {
int first_dims = first.dim();
int second_dims = second.dim();
TORCH_CHECK(first_dims == second_dims,
"Tensors must have same number of dimensions: got ", first_dims,
" and ", second_dims);
TORCH_CHECK(
first_dims == second_dims, "Tensors must have same number of dimensions: got ", first_dims, " and ", second_dims);
for (int dim = 0; dim < first_dims; dim++) {
if (dim == dimension) {
continue;
@ -240,15 +201,20 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
int64_t first_dim_size = at::native::size(first, dim);
int64_t second_dim_size = at::native::size(second, dim);
TORCH_CHECK(first_dim_size == second_dim_size,
"Sizes of tensors must match except in dimension ", dim, ". Got ",
static_cast<long long>(first_dim_size), " and ",
static_cast<long long>(second_dim_size), " (The offending index is ",
index, ")");
"Sizes of tensors must match except in dimension ",
dim,
". Got ",
static_cast<long long>(first_dim_size),
" and ",
static_cast<long long>(second_dim_size),
" (The offending index is ",
index,
")");
}
}
TORCH_IMPL_FUNC(cat_out_mps)
(const ITensorListRef& inputs,
(const ITensorListRef& inputs,
int64_t dimension,
int64_t valid,
bool all_contiguous,
@ -256,7 +222,6 @@ TORCH_IMPL_FUNC(cat_out_mps)
bool all_same_sizes_and_stride,
MemoryFormat memory_format,
const Tensor& out) {
using namespace mps;
if (out.numel() == 0) {
@ -271,13 +236,15 @@ TORCH_IMPL_FUNC(cat_out_mps)
auto lap = at::get_overlap_status(out, t);
TORCH_CHECK(lap != at::MemOverlapStatus::Partial && lap != at::MemOverlapStatus::Full,
"torch.cat(): unsupported operation: the input tensors cannot refer to any "
"of the output memory locations. Found overlap in input tensor ", idx);
"of the output memory locations. Found overlap in input tensor ",
idx);
idx++;
}
// Check for type promotion
TORCH_CHECK(canCast(out_dtype, out.scalar_type()),
"torch.cat(): input types can't be cast to the desired output type ", out.scalar_type());
TORCH_CHECK(inputs.size() > 0,"torch.cat(): invalid number of inputs ", inputs.size());
"torch.cat(): input types can't be cast to the desired output type ",
out.scalar_type());
TORCH_CHECK(inputs.size() > 0, "torch.cat(): invalid number of inputs ", inputs.size());
dimension = legacy_cat_wrap_dim(dimension, materialized_inputs);
TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension);
@ -288,9 +255,7 @@ TORCH_IMPL_FUNC(cat_out_mps)
// this behavior for backwards compatibility, but only for this specific size
// (i.e. other empty sizes are not skipped).
// FIXME: warn if this is the case
auto should_skip = [](const Tensor& t) {
return t.dim() == 1 && at::native::size(t, 0) == 0;
};
auto should_skip = [](const Tensor& t) { return t.dim() == 1 && at::native::size(t, 0) == 0; };
at::assert_no_internal_overlap(out);
Tensor notSkippedTensor;
@ -317,11 +282,15 @@ TORCH_IMPL_FUNC(cat_out_mps)
for (const Tensor& t : inputs) {
TORCH_CHECK(t.device() == notSkippedTensor.device(),
"torch.cat(): all input tensors must be on the same device. Received ",
t.device(), " and ", notSkippedTensor.device());
t.device(),
" and ",
notSkippedTensor.device());
}
TORCH_CHECK(out.device() == notSkippedTensor.device(),
"torch.cat(): all input tensors and out must be on the same device, but inputs are on ",
notSkippedTensor.device(), " and out is on ", out.device());
notSkippedTensor.device(),
" and out is on ",
out.device());
// TODO: For better performance by eliminating input tensor gathering and post transpose,
// TODO: it is better to keep the out tensor's memory format.
@ -354,23 +323,23 @@ TORCH_IMPL_FUNC(cat_out_mps)
}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
std::vector<MPSGraphTensor*> inputTensors_;
MPSGraphTensor* outputTensor_ = nil;
};
MPSGraphCache *cache_ = MPSGraphCache::getInstance();
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = "cat_out_mps:" + to_string(dimension) + getTensorsStringKey(input_tensors, /*short_dtype*/true) + ":" +
(memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
string key = "cat_out_mps:" + to_string(dimension) + getTensorsStringKey(input_tensors, /*short_dtype*/ true) +
":" + (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph *mpsGraph = make_mps_graph();
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
auto len_tensor_array = inputs.size() - skipped_tensor_indices.size();
@ -383,7 +352,8 @@ TORCH_IMPL_FUNC(cat_out_mps)
if (tensor.scalar_type() == kBool) {
scalar_type = MPSDataTypeInt8;
}
newCachedGraph->inputTensors_[idx] = mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, MemoryFormat::Contiguous));
newCachedGraph->inputTensors_[idx] =
mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, MemoryFormat::Contiguous));
if (tensor.scalar_type() != out_dtype) {
castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx]
toType:getMPSDataType(out_dtype)
@ -393,15 +363,12 @@ TORCH_IMPL_FUNC(cat_out_mps)
}
}
auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data()
count:len_tensor_array];
auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data() count:len_tensor_array];
MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
dimension:dimension // Maybe convert this from int64_t -> int32
name:nil];
if (getMPSDataType(out_dtype) == MPSDataTypeBool) {
outputTensor = [mpsGraph castTensor:outputTensor
toType:MPSDataTypeBool
name:@"outputTensor"];
outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"outputTensor"];
}
newCachedGraph->outputTensor_ = outputTensor;
}
@ -418,9 +385,11 @@ TORCH_IMPL_FUNC(cat_out_mps)
if (tensor.scalar_type() == kBool) {
scalar_type = MPSDataTypeInt8;
}
inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor,
inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx],
tensor,
getMPSShape(tensor, MemoryFormat::Contiguous),
/*gatherTensorData*/true, scalar_type);
/*gatherTensorData*/ true,
scalar_type);
t_idx++;
}
i++;
@ -430,16 +399,15 @@ TORCH_IMPL_FUNC(cat_out_mps)
if (!is_macos_13_or_newer() && out.scalar_type() == kBool) {
outputDataType = MPSDataTypeInt8;
}
Placeholder outputPlaceholder = Placeholder(
cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType);
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
for (auto& inputPlaceholder : inputPlaceholders) {
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
}
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
View File
@ -16,30 +16,26 @@
namespace at::native {
void get_shapes(MPSShape* input_shape_readonly,
NSMutableArray<NSNumber*>* &input_shape,
int num_input_dims, c10::MemoryFormat memory_format) {
NSMutableArray<NSNumber*>*& input_shape,
int num_input_dims,
c10::MemoryFormat memory_format) {
// Modify the shape
if(memory_format == at::MemoryFormat::Contiguous) {
for(int i = 0; i < num_input_dims; i++)
if (memory_format == at::MemoryFormat::Contiguous) {
for (int i = 0; i < num_input_dims; i++)
input_shape[i] = input_shape_readonly[i];
}
else { // ChannelsLast
} else { // ChannelsLast
auto num_channels = input_shape_readonly[1];
input_shape[0] = input_shape_readonly[0];
for(int i = 1; i < num_input_dims-1; i++)
input_shape[i] = input_shape_readonly[i+1];
input_shape[num_input_dims-1] = num_channels;
for (int i = 1; i < num_input_dims - 1; i++)
input_shape[i] = input_shape_readonly[i + 1];
input_shape[num_input_dims - 1] = num_channels;
}
}
// Note - Currently only supported for 4D image tensors
TORCH_IMPL_FUNC(softmax_mps_out)
(const Tensor& input_,
const int64_t dim,
const bool half_to_float,
const Tensor& output) {
(const Tensor& input_, const int64_t dim, const bool half_to_float, const Tensor& output) {
TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on MPS");
if (input_.numel() == 0) {
@ -49,25 +45,22 @@ TORCH_IMPL_FUNC(softmax_mps_out)
Tensor input;
if (input_.dim() == 0) {
input = input_.view(1);
}
else
} else
input = input_;
int64_t dim_ = maybe_wrap_dim(dim, input.dim());
TORCH_CHECK(
dim_ >= 0 && dim_ < input.dim(),
"Softmax:dim must be non-negative and less than input dimensions");
TORCH_CHECK(dim_ >= 0 && dim_ < input.dim(), "Softmax:dim must be non-negative and less than input dimensions");
const auto memory_format = input.suggest_memory_format();
// TORCH_CHECK(input.suggest_memory_format() == output.suggest_memory_format(), "Input and output memory format should match")
// TORCH_CHECK(input.suggest_memory_format() == output.suggest_memory_format(), "Input and output memory format should
// match")
using namespace mps;
MPSStream* stream = getCurrentMPSStream();
// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
@ -75,20 +68,20 @@ TORCH_IMPL_FUNC(softmax_mps_out)
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string mem_format_key = get_mem_format_string(memory_format);
MPSShape* input_shape_readonly = mps::getMPSShape(input);
int num_input_dims = [input_shape_readonly count];
// Check - Channels last implies 4d
TORCH_CHECK(memory_format != at::MemoryFormat::ChannelsLast || num_input_dims == 4, "ChannelsLast implies 4d tensor")
TORCH_CHECK(memory_format != at::MemoryFormat::ChannelsLast || num_input_dims == 4,
"ChannelsLast implies 4d tensor")
// Input shape changes based on memory format
NSMutableArray<NSNumber*>* input_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
get_shapes(input_shape_readonly, input_shape, num_input_dims, memory_format);
// Change dim
if(memory_format == at::MemoryFormat::ChannelsLast && dim_ > 0) {
switch(dim_) {
if (memory_format == at::MemoryFormat::ChannelsLast && dim_ > 0) {
switch (dim_) {
case 1:
dim_ = 3;
break;
@ -105,13 +98,13 @@ TORCH_IMPL_FUNC(softmax_mps_out)
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
string key = "softmax_mps_out:" + mem_format_key + ":" + getMPSTypeString(input) + ":"
+ [ns_shape_key UTF8String] + ":" + std::to_string(dim_);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
string key = "softmax_mps_out:" + mem_format_key + ":" + getMPSTypeString(input) + ":" + [ns_shape_key UTF8String] +
":" + std::to_string(dim_);
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
@ -120,28 +113,20 @@ TORCH_IMPL_FUNC(softmax_mps_out)
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape);
// passing selector of softMaxWithTensor on the mpsGraph object
MPSGraphTensor* outputTensor = [mpsGraph softMaxWithTensor:inputTensor
axis:(NSInteger)dim_
name:nil];
MPSGraphTensor* outputTensor = [mpsGraph softMaxWithTensor:inputTensor axis:(NSInteger)dim_ name:nil];
// Output needs to be contiguous format
if(memory_format == at::MemoryFormat::ChannelsLast) {
if (memory_format == at::MemoryFormat::ChannelsLast) {
auto N = input_shape[0];
auto H = input_shape[1];
auto W = input_shape[2];
auto C = input_shape[3];
outputTensor = [mpsGraph reshapeTensor:outputTensor
withShape:@[N, ([NSNumber numberWithInt:[H intValue]* [W intValue]]), C]
withShape:@[ N, ([NSNumber numberWithInt:[H intValue] * [W intValue]]), C ]
name:nil];
outputTensor = [mpsGraph transposeTensor:outputTensor
dimension:1
withDimension:2
name:nil];
outputTensor = [mpsGraph reshapeTensor:outputTensor
withShape:@[N, C, H, W]
name:nil];
outputTensor = [mpsGraph transposeTensor:outputTensor dimension:1 withDimension:2 name:nil];
outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:@[ N, C, H, W ] name:nil];
}
newCachedGraph->inputTensor_ = inputTensor;
@ -149,32 +134,24 @@ TORCH_IMPL_FUNC(softmax_mps_out)
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input, input_shape);
// This must be the Contiguous shape
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
@{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
}
TORCH_IMPL_FUNC(softmax_backward_mps_out)
(const Tensor& grad_,
const Tensor& output_,
int64_t dim,
ScalarType input_dtype,
const Tensor& grad_input) {
(const Tensor& grad_, const Tensor& output_, int64_t dim, ScalarType input_dtype, const Tensor& grad_input) {
if (output_.numel() == 0) {
return;
}
@ -182,29 +159,24 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
Tensor grad;
if (grad_.dim() == 0) {
grad = grad_.view(1);
}
else
} else
grad = grad_;
Tensor output;
if (output_.dim() == 0) {
output = output_.view(1);
}
else
} else
output = output_;
int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
TORCH_CHECK(
dim_ >= 0 && dim_ < grad.dim(),
"Grad:dim must be non-negative and less than input dimensions");
TORCH_CHECK(dim_ >= 0 && dim_ < grad.dim(), "Grad:dim must be non-negative and less than input dimensions");
using namespace mps;
MPSStream* stream = getCurrentMPSStream();
// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* softmaxTensor_ = nil;
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* gradInputTensor_ = nil;
@ -213,17 +185,16 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
MPSShape* grad_shape = mps::getMPSShape(grad);
NSString* ns_shape_key = [[grad_shape valueForKey:@"description"] componentsJoinedByString:@","];
string key = "softmax_backward_mps_out:" + getMPSTypeString(output) + ":"
+ [ns_shape_key UTF8String] + ":" + std::to_string(dim_);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
string key = "softmax_backward_mps_out:" + getMPSTypeString(output) + ":" + [ns_shape_key UTF8String] + ":" +
std::to_string(dim_);
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
@ -235,9 +206,7 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
MPSGraphTensor* mulTensor = [mpsGraph multiplicationWithPrimaryTensor:softmaxTensor
secondaryTensor:gradOutputTensor
name:nil];
MPSGraphTensor* mulSumTensor = [mpsGraph reductionSumWithTensor:mulTensor
axis:(NSInteger)dim_
name:nil];
MPSGraphTensor* mulSumTensor = [mpsGraph reductionSumWithTensor:mulTensor axis:(NSInteger)dim_ name:nil];
MPSGraphTensor* gradSubTensor = [mpsGraph subtractionWithPrimaryTensor:gradOutputTensor
secondaryTensor:mulSumTensor
name:nil];
@ -251,7 +220,7 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder softmaxPlaceholder = Placeholder(cachedGraph->softmaxTensor_, output, grad_shape);
@ -262,12 +231,10 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
softmaxPlaceholder.getMPSGraphTensor() : softmaxPlaceholder.getMPSGraphTensorData(),
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
}
View File
@ -2,10 +2,10 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/TensorShape.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
namespace at::native {
@ -42,7 +42,7 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
MPSStream* stream = getCurrentMPSStream();
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *selfTensor = nil, *valuesTensor = nil, *indicesTensor = nil;
};
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@ -50,19 +50,20 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
// Input as placeholders
MPSShape* input_shape = getMPSShape(self);
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) +
":dim" + to_string(dim) + ":descending" + to_string(descending);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" + to_string(dim) +
":descending" + to_string(descending);
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus);
MPSGraphTensor * sortedTensor = [mpsGraph sortWithTensor:castInputTensor
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus);
MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor
axis:(NSInteger)dim
descending:(BOOL)descending
name:@"sort_out"];
@ -88,14 +89,10 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
Placeholder indicesPlaceholder = Placeholder(cachedGraph->indicesTensor, indices);
// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
feeds = @{ inputPlaceholder.getMPSGraphTensor() :
inputPlaceholder.getMPSGraphTensorData()
};
feeds = @{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
valuesPlaceholder.getMPSGraphTensor() :
valuesPlaceholder.getMPSGraphTensorData(),
indicesPlaceholder.getMPSGraphTensor() :
indicesPlaceholder.getMPSGraphTensorData()
valuesPlaceholder.getMPSGraphTensor() : valuesPlaceholder.getMPSGraphTensorData(),
indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData()
};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
View File
@ -4,14 +4,11 @@
namespace at::native {
Tensor& bincount_mps_impl(const Tensor& self,
const Tensor& weights,
Tensor& output) {
Tensor& bincount_mps_impl(const Tensor& self, const Tensor& weights, Tensor& output) {
using namespace mps;
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* weightsTensor_ = nil;
MPSGraphTensor* scatterDataTensor_ = nil;
@ -24,37 +21,32 @@ Tensor& bincount_mps_impl(const Tensor& self,
@autoreleasepool {
string key = "bincount_mps_impl" + getTensorsStringKey({self, weights});
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
// Initialize graph
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor *scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(output.scalar_type()));
MPSGraphTensor* scatterDataTensor =
mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(output.scalar_type()));
MPSGraphTensor *updatesTensor = nil;
MPSGraphTensor* updatesTensor = nil;
if (has_weights) {
updatesTensor = mpsGraphRankedPlaceHolder(mpsGraph, weights);
}
else {
updatesTensor = [mpsGraph constantWithScalar:1.0f
shape:getMPSShape(self)
dataType:getMPSDataType(output)];
} else {
updatesTensor = [mpsGraph constantWithScalar:1.0f shape:getMPSShape(self) dataType:getMPSDataType(output)];
}
MPSGraphTensor *castedInputTensor = inputTensor;
MPSGraphTensor* castedInputTensor = inputTensor;
if (self.scalar_type() == kByte) {
castedInputTensor = [mpsGraph castTensor:inputTensor
toType:MPSDataTypeInt32
name:@"castInputTensor"];
castedInputTensor = [mpsGraph castTensor:inputTensor toType:MPSDataTypeInt32 name:@"castInputTensor"];
}
MPSGraphTensor *outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor
MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor
updatesTensor:updatesTensor
indicesTensor:castedInputTensor
axis:0
@ -70,7 +62,7 @@ Tensor& bincount_mps_impl(const Tensor& self,
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
// Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation
@ -80,17 +72,16 @@ Tensor& bincount_mps_impl(const Tensor& self,
Placeholder weightsPlaceholder = Placeholder();
// Create dictionary of inputs/feeds and outputs/results
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =[NSMutableDictionary dictionary];
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
feeds[scatterPlaceholder.getMPSGraphTensor()] = scatterPlaceholder.getMPSGraphTensorData();
if(has_weights) {
if (has_weights) {
weightsPlaceholder = Placeholder(cachedGraph->weightsTensor_, weights);
feeds[weightsPlaceholder.getMPSGraphTensor()] = weightsPlaceholder.getMPSGraphTensorData();
}
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
// Run the graph
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
@ -108,43 +99,32 @@ Tensor _bincount_mps(const Tensor& self, const c10::optional<Tensor>& weights_op
TORCH_CHECK(minlength >= 0, "minlength should be >= 0");
if (self.dim() == 1 && self.numel() == 0) {
return at::zeros(
{minlength},
kLong,
c10::nullopt /* layout */,
kMPS,
c10::nullopt /* pin_memory */);
return at::zeros({minlength}, kLong, c10::nullopt /* layout */, kMPS, c10::nullopt /* pin_memory */);
}
TORCH_CHECK(self.dim() == 1 && self.min().item<int64_t>() >= 0, "bincount only supports 1-d non-negative integral inputs.");
TORCH_CHECK(self.dim() == 1 && self.min().item<int64_t>() >= 0,
"bincount only supports 1-d non-negative integral inputs.");
bool has_weights = weights.defined();
TORCH_CHECK(!(has_weights && (weights.dim() != 1 || weights.size(0) != self.size(0))), "weights should be 1-d and have the same length as input");
TORCH_CHECK(!(has_weights && (weights.dim() != 1 || weights.size(0) != self.size(0))),
"weights should be 1-d and have the same length as input");
const int64_t nbins = std::max(self.max().item<int64_t>() + 1L, minlength);
Tensor output;
Tensor weights_ = weights;
if (has_weights) {
if(weights.scalar_type() != ScalarType::Float &&
weights.scalar_type() != ScalarType::Int &&
if (weights.scalar_type() != ScalarType::Float && weights.scalar_type() != ScalarType::Int &&
weights.scalar_type() != ScalarType::Half) {
// Scatter doesn't work for int8/int16 dtypes
weights_ = weights.to(kInt);
}
output = at::zeros(
{nbins},
output = at::zeros({nbins},
optTypeMetaToScalarType(weights_.options().dtype_opt()),
weights_.options().layout_opt(),
weights_.options().device_opt(),
weights_.options().pinned_memory_opt());
}
else {
output = at::zeros(
{nbins},
kLong,
c10::nullopt /* layout */,
kMPS,
c10::nullopt /* pin_memory */);
} else {
output = at::zeros({nbins}, kLong, c10::nullopt /* layout */, kMPS, c10::nullopt /* pin_memory */);
}
return bincount_mps_impl(self, weights_, output);
View File
@ -1,22 +1,20 @@
// Copyright © 2022 Apple Inc.
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/mps/OperationUtils.h>
namespace at::native {
namespace mps {
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
MPSGraphTensor *minTensor = nil, *maxTensor = nil;
};
void clamp_mps_graph(CachedGraph* cachedGraph, const Tensor& input_tensor)
{
MPSGraph *mpsGraph = cachedGraph->graph();
void clamp_mps_graph(CachedGraph* cachedGraph, const Tensor& input_tensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
cachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
@ -36,39 +34,30 @@ void clamp_mps_graph(CachedGraph* cachedGraph, const Tensor& input_tensor)
}
}
void check_min_max_dims(const OptionalTensorRef clamp_opt,
const Tensor& input_t,
string op_name) {
if(!clamp_opt->is_same_size(input_t)) {
void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor& input_t, string op_name) {
if (!clamp_opt->is_same_size(input_t)) {
auto num_clamp_dims = clamp_opt->dim();
auto num_input_dims = input_t.dim();
auto clamp_shape = clamp_opt->sizes();
auto input_shape = input_t.sizes();
TORCH_CHECK(num_clamp_dims <= num_input_dims, op_name + ": clamp tensor number of dims must not be greater than that of input tensor")
TORCH_CHECK(num_clamp_dims <= num_input_dims,
op_name + ": clamp tensor number of dims must not be greater than that of input tensor")
for(int i = 0; i < num_clamp_dims; i++)
for (int i = 0; i < num_clamp_dims; i++)
// One of the indices is allowed to be 1; will be handled by broadcast
TORCH_CHECK(clamp_shape[num_clamp_dims-1-i] == input_shape[num_input_dims-1-i] ||
clamp_shape[num_clamp_dims-1-i] == 1 ||
input_shape[num_input_dims-1-i] == 1,
TORCH_CHECK(clamp_shape[num_clamp_dims - 1 - i] == input_shape[num_input_dims - 1 - i] ||
clamp_shape[num_clamp_dims - 1 - i] == 1 || input_shape[num_input_dims - 1 - i] == 1,
op_name + ": clamp tensor trailing shape must match input tensor")
}
}
void fill_new_shape(int64_t num_input_dims,
int64_t num_clamp_dims,
int64_t *new_shape,
IntArrayRef clamp_shape) {
void fill_new_shape(int64_t num_input_dims, int64_t num_clamp_dims, int64_t* new_shape, IntArrayRef clamp_shape) {
// Extend the shape with ones to the left
int clamp_idx = 0;
for(int i = 0; i < num_input_dims; i++) {
if(i < num_input_dims - num_clamp_dims)
for (int i = 0; i < num_input_dims; i++) {
if (i < num_input_dims - num_clamp_dims)
new_shape[i] = 1;
else {
new_shape[i] = clamp_shape[clamp_idx];
@ -81,8 +70,7 @@ void clamp_tensor_out_mps(const Tensor& input_t,
const OptionalTensorRef min_opt,
const OptionalTensorRef max_opt,
const Tensor& output_t,
string op_name)
{
string op_name) {
const bool has_min = (min_opt.has_value() && min_opt->defined());
const bool has_max = (max_opt.has_value() && max_opt->defined());
@ -106,12 +94,12 @@ void clamp_tensor_out_mps(const Tensor& input_t,
std::vector<int64_t> new_min_arr(num_input_dims);
std::vector<int64_t> new_max_arr(num_input_dims);
if(has_min && num_min_dims < num_input_dims) {
if (has_min && num_min_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes());
new_min_shape = IntArrayRef(new_min_arr);
}
if(has_max && num_max_dims < num_input_dims) {
if (has_max && num_max_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes());
new_max_shape = IntArrayRef(new_max_arr);
}
@ -119,29 +107,28 @@ void clamp_tensor_out_mps(const Tensor& input_t,
Tensor min_opt_tensor;
Tensor max_opt_tensor;
if(has_min) {
if (has_min) {
min_opt_tensor = (num_min_dims < num_input_dims) ? (*min_opt).view(new_min_shape) : *min_opt;
}
if(has_max) {
if (has_max) {
max_opt_tensor = (num_max_dims < num_input_dims) ? (*max_opt).view(new_max_shape) : *max_opt;
}
@autoreleasepool {
// the optional min/max refs could affect how we build the cached graph
auto tensor_key = has_min ? (has_max ? getTensorsStringKey({input_t, min_opt_tensor, max_opt_tensor})
auto tensor_key = has_min
? (has_max ? getTensorsStringKey({input_t, min_opt_tensor, max_opt_tensor})
: getTensorsStringKey({input_t, min_opt_tensor}))
: (has_max ? getTensorsStringKey({input_t, max_opt_tensor})
: getTensorsStringKey({input_t}));
: (has_max ? getTensorsStringKey({input_t, max_opt_tensor}) : getTensorsStringKey({input_t}));
string key = op_name + (has_min ? "_min" : "") + (has_max ? "_max" : "")
+ "_tensor" + tensor_key;
string key = op_name + (has_min ? "_min" : "") + (has_max ? "_max" : "") + "_tensor" + tensor_key;
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
@ -156,13 +143,13 @@ void clamp_tensor_out_mps(const Tensor& input_t,
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_t);
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t);
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
if (has_min) {
auto minPlaceholder = Placeholder(cachedGraph->minTensor, min_opt_tensor);
@ -173,9 +160,8 @@ void clamp_tensor_out_mps(const Tensor& input_t,
feeds[maxPlaceholder.getMPSGraphTensor()] = maxPlaceholder.getMPSGraphTensorData();
}
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
@ -185,8 +171,7 @@ void clamp_scalar_out_mps(const Tensor& input_t,
const OptionalScalarRef min_opt,
const OptionalScalarRef max_opt,
const Tensor& output_t,
string op_name)
{
string op_name) {
using scalar_t = double;
const bool has_min = (min_opt.has_value());
@ -202,48 +187,47 @@ void clamp_scalar_out_mps(const Tensor& input_t,
max_scalar = max_opt.get().to<scalar_t>();
if (output_t.numel() == 0)
return ;
return;
@autoreleasepool {
// the optional min/max refs could affect how we build the cached graph
string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") + (has_max ? ("_max:" + to_string(max_scalar)) : "")
+ "_scalar:" + getTensorsStringKey({input_t});
string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") +
(has_max ? ("_max:" + to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
if (has_min)
newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar
shape:(mps::getMPSShape(input_t))
dataType:(mps::getMPSScalarType(input_t.scalar_type())) ];
newCachedGraph->minTensor = [mpsGraph
constantWithScalar:min_scalar
shape:(mps::getMPSShape(input_t))dataType:(mps::getMPSScalarType(input_t.scalar_type()))];
if (has_max)
newCachedGraph->maxTensor = [mpsGraph constantWithScalar:max_scalar
shape:(mps::getMPSShape(input_t))
dataType:(mps::getMPSScalarType(input_t.scalar_type())) ];
newCachedGraph->maxTensor = [mpsGraph
constantWithScalar:max_scalar
shape:(mps::getMPSShape(input_t))dataType:(mps::getMPSScalarType(input_t.scalar_type()))];
clamp_mps_graph(newCachedGraph, input_t);
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor , input_t);
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_t);
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t);
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
};
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
@ -253,51 +237,45 @@ void clamp_scalar_out_mps(const Tensor& input_t,
// APIs exposed to at::native scope
TORCH_IMPL_FUNC(clamp_Tensor_out_mps)
(const Tensor& input_t, const OptionalTensorRef min, const OptionalTensorRef max, const Tensor& output_t)
{
(const Tensor& input_t, const OptionalTensorRef min, const OptionalTensorRef max, const Tensor& output_t) {
mps::clamp_tensor_out_mps(input_t, min, max, output_t, __func__);
}
TORCH_IMPL_FUNC(clamp_out_mps)
(const Tensor& input_t, const OptionalScalarRef min, const OptionalScalarRef max, const Tensor& output_t)
{
(const Tensor& input_t, const OptionalScalarRef min, const OptionalScalarRef max, const Tensor& output_t) {
mps::clamp_scalar_out_mps(input_t, min, max, const_cast<Tensor&>(output_t), "clamp_out_mps");
}
TORCH_IMPL_FUNC(clamp_min_Tensor_out_mps)
(const Tensor& input_t, const Tensor& min, const Tensor& output_t)
{
(const Tensor& input_t, const Tensor& min, const Tensor& output_t) {
mps::clamp_tensor_out_mps(input_t, min, at::OptionalTensorRef(), output_t, __func__);
}
TORCH_IMPL_FUNC(clamp_min_out_mps)
(const Tensor& input_t, const Scalar& min, const Tensor& output_t)
{
(const Tensor& input_t, const Scalar& min, const Tensor& output_t) {
mps::clamp_scalar_out_mps(input_t, min, at::OptionalScalarRef(), output_t, __func__);
}
TORCH_IMPL_FUNC(clamp_max_Tensor_out_mps)
(const Tensor& input_t, const Tensor& max, const Tensor& output_t)
{
(const Tensor& input_t, const Tensor& max, const Tensor& output_t) {
mps::clamp_tensor_out_mps(input_t, at::OptionalTensorRef(), max, output_t, __func__);
}
TORCH_IMPL_FUNC(clamp_max_out_mps)
(const Tensor& input_t, const Scalar& max, const Tensor& output_t)
{
(const Tensor& input_t, const Scalar& max, const Tensor& output_t) {
mps::clamp_scalar_out_mps(input_t, at::OptionalScalarRef(), max, output_t, __func__);
}
Tensor& where_self_out_mps(const Tensor& condition,
const Tensor& self,
const Tensor& other,
Tensor& out) {
Tensor& where_self_out_mps(const Tensor& condition, const Tensor& self, const Tensor& other, Tensor& out) {
TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
if (condition.scalar_type() == ScalarType::Byte) {
TORCH_WARN_ONCE("where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.");
TORCH_WARN_ONCE(
"where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.");
} else {
TORCH_CHECK(condition.scalar_type() == ScalarType::Bool, "where expected condition to be a boolean tensor, but got a tensor with dtype ", condition.scalar_type());
TORCH_CHECK(condition.scalar_type() == ScalarType::Bool,
"where expected condition to be a boolean tensor, but got a tensor with dtype ",
condition.scalar_type());
}
Tensor cond_bool = condition.scalar_type() == ScalarType::Byte ? condition.to(ScalarType::Bool) : condition;
@ -305,13 +283,12 @@ Tensor& where_self_out_mps(const Tensor& condition,
MPSStream* stream = getCurrentMPSStream();
// Empty output
if(out.numel() == 0)
if (out.numel() == 0)
return out;
// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* conditionTensor_ = nil;
MPSGraphTensor* selfTensor_ = nil;
MPSGraphTensor* otherTensor_ = nil;
@ -338,21 +315,20 @@ Tensor& where_self_out_mps(const Tensor& condition,
}
@autoreleasepool {
string key = "where_self_out_mps:" + getTensorsStringKey({cond_bool, self, other});
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor* conditionTensor = mpsGraphRankedPlaceHolder(mpsGraph, conditionDataType, getMPSShape(cond_bool));
MPSGraphTensor* conditionTensor =
mpsGraphRankedPlaceHolder(mpsGraph, conditionDataType, getMPSShape(cond_bool));
MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, selfDataType, getMPSShape(self));
MPSGraphTensor* otherTensor = mpsGraphRankedPlaceHolder(mpsGraph, otherDataType, getMPSShape(other));
@ -368,15 +344,15 @@ Tensor& where_self_out_mps(const Tensor& condition,
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder conditionPlaceholder = Placeholder(
cachedGraph->conditionTensor_, cond_bool, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, conditionDataType);
Placeholder selfPlaceholder = Placeholder(
cachedGraph->selfTensor_, self, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, selfDataType);
Placeholder otherPlaceholder = Placeholder(
cachedGraph->otherTensor_, other, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, otherDataType);
Placeholder selfPlaceholder =
Placeholder(cachedGraph->selfTensor_, self, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, selfDataType);
Placeholder otherPlaceholder =
Placeholder(cachedGraph->otherTensor_, other, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, otherDataType);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@ -384,21 +360,16 @@ Tensor& where_self_out_mps(const Tensor& condition,
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
return out;
}
Tensor where_mps(const Tensor& condition,
const Tensor& self,
const Tensor& other) {
Tensor where_mps(const Tensor& condition, const Tensor& self, const Tensor& other) {
auto max_dim = std::max(condition.dim(), std::max(self.dim(), other.dim()));
// How many leading dimensions do we broadcast across for each Tensor?
@ -409,8 +380,7 @@ Tensor where_mps(const Tensor& condition,
std::vector<int64_t> out_arr(max_dim);
// Broadcasted output shape
for(int i = 0; i < max_dim; i++) {
for (int i = 0; i < max_dim; i++) {
// Use up the leading broadcast dimensions for each Tensor, then continue from the start of the "actual" shape
int64_t cond_idx = i < cond_num_implicit_ones ? 1 : (condition.size(i - cond_num_implicit_ones));
int64_t self_idx = i < self_num_implicit_ones ? 1 : (self.size(i - self_num_implicit_ones));
@ -418,21 +388,28 @@ Tensor where_mps(const Tensor& condition,
auto max_idx = std::max({cond_idx, self_idx, other_idx});
TORCH_CHECK(cond_idx == max_idx || cond_idx == 1 || (cond_idx == 0 && max_idx == 1), i, "'th index ", cond_idx, " of condition tensor does not match the other tensors")
TORCH_CHECK(self_idx == max_idx || self_idx == 1 || (self_idx == 0 && max_idx == 1), i, "'th index ", self_idx, " of x tensor does not match the other tensors")
TORCH_CHECK(other_idx == max_idx || other_idx == 1 || (other_idx == 0 && max_idx == 1), i, "'th index ", other_idx, " of x tensor does not match the other tensors")
TORCH_CHECK(cond_idx == max_idx || cond_idx == 1 || (cond_idx == 0 && max_idx == 1),
i,
"'th index ",
cond_idx,
" of condition tensor does not match the other tensors")
TORCH_CHECK(self_idx == max_idx || self_idx == 1 || (self_idx == 0 && max_idx == 1),
i,
"'th index ",
self_idx,
" of x tensor does not match the other tensors")
TORCH_CHECK(other_idx == max_idx || other_idx == 1 || (other_idx == 0 && max_idx == 1),
i,
"'th index ",
other_idx,
" of x tensor does not match the other tensors")
out_arr[i] = (cond_idx == 0 || self_idx == 0 || other_idx == 0) ? 0 : max_idx;
}
Tensor ret = empty_mps(IntArrayRef(out_arr),
self.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
self.suggest_memory_format());
Tensor ret = empty_mps(
IntArrayRef(out_arr), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, self.suggest_memory_format());
return where_self_out_mps(condition, self, other, ret);
}
Tensor& nan_to_num_out_mps(const Tensor& self,
@ -440,8 +417,11 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
c10::optional<double> pos_inf,
c10::optional<double> neg_inf,
Tensor& result) {
TORCH_CHECK(self.scalar_type() == result.scalar_type(), "nan_to_num: dtype of out: ",
result.scalar_type(), " should be same as input: ", self.scalar_type());
TORCH_CHECK(self.scalar_type() == result.scalar_type(),
"nan_to_num: dtype of out: ",
result.scalar_type(),
" should be same as input: ",
self.scalar_type());
if (result.numel() == 0) {
return result;
}
@ -452,7 +432,7 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
}
using namespace mps;
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* selfTensor = nil;
MPSGraphTensor* outputTensor = nil;
MPSGraphTensor* nanReplacementTensor = nil;
@ -467,25 +447,27 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
newCachedGraph->nanReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[@1]);
newCachedGraph->posInfReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[@1]);
newCachedGraph->negInfReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[@1]);
newCachedGraph->nanReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[ @1 ]);
newCachedGraph->posInfReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[ @1 ]);
newCachedGraph->negInfReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[ @1 ]);
MPSGraphTensor* nanFreeTensor = [mpsGraph selectWithPredicateTensor: [mpsGraph isNaNWithTensor: newCachedGraph->selfTensor name:nil]
truePredicateTensor: newCachedGraph->nanReplacementTensor
falsePredicateTensor: newCachedGraph->selfTensor
name: nil];
MPSGraphTensor* subZeroTensor = [mpsGraph lessThanWithPrimaryTensor: nanFreeTensor
secondaryTensor: [mpsGraph constantWithScalar: 0.0 dataType: self_dtype]
name: nil];
MPSGraphTensor* isInfTensor = [mpsGraph isInfiniteWithTensor: nanFreeTensor name:nil];
MPSGraphTensor* nanFreeTensor =
[mpsGraph selectWithPredicateTensor:[mpsGraph isNaNWithTensor:newCachedGraph->selfTensor name:nil]
truePredicateTensor:newCachedGraph->nanReplacementTensor
falsePredicateTensor:newCachedGraph->selfTensor
name:nil];
MPSGraphTensor* subZeroTensor = [mpsGraph lessThanWithPrimaryTensor:nanFreeTensor
secondaryTensor:[mpsGraph constantWithScalar:0.0
dataType:self_dtype]
name:nil];
MPSGraphTensor* isInfTensor = [mpsGraph isInfiniteWithTensor:nanFreeTensor name:nil];
// workaround for Monterey; On Ventura the output of lessThan() is always Boolean
if (subZeroTensor.dataType != MPSDataTypeBool) {
subZeroTensor = castMPSTensor(mpsGraph, subZeroTensor, kBool);
@ -493,17 +475,18 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
if (isInfTensor.dataType != MPSDataTypeBool) {
isInfTensor = castMPSTensor(mpsGraph, isInfTensor, kBool);
}
MPSGraphTensor* isNegInfTensor = [mpsGraph logicalANDWithPrimaryTensor: subZeroTensor
secondaryTensor: isInfTensor
name: nil];
MPSGraphTensor* negInfFreeTensor = [mpsGraph selectWithPredicateTensor: isNegInfTensor
truePredicateTensor: newCachedGraph->negInfReplacementTensor
falsePredicateTensor: nanFreeTensor
name: nil];
newCachedGraph->outputTensor = [mpsGraph selectWithPredicateTensor: [mpsGraph isInfiniteWithTensor: negInfFreeTensor name:nil]
truePredicateTensor: newCachedGraph->posInfReplacementTensor
falsePredicateTensor: negInfFreeTensor
name: nil];
MPSGraphTensor* isNegInfTensor = [mpsGraph logicalANDWithPrimaryTensor:subZeroTensor
secondaryTensor:isInfTensor
name:nil];
MPSGraphTensor* negInfFreeTensor = [mpsGraph selectWithPredicateTensor:isNegInfTensor
truePredicateTensor:newCachedGraph->negInfReplacementTensor
falsePredicateTensor:nanFreeTensor
name:nil];
newCachedGraph->outputTensor =
[mpsGraph selectWithPredicateTensor:[mpsGraph isInfiniteWithTensor:negInfFreeTensor name:nil]
truePredicateTensor:newCachedGraph->posInfReplacementTensor
falsePredicateTensor:negInfFreeTensor
name:nil];
}
return newCachedGraph;
});
@ -511,12 +494,10 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
MPSScalar nanReplacementScalar, posInfReplacementScalar, negInfReplacementScalar;
AT_DISPATCH_FLOATING_TYPES_AND(kHalf, self.scalar_type(), "nan_to_num_mps", [&]() {
scalar_t nan_replacement = static_cast<scalar_t>(nan.value_or(0.));
scalar_t pos_inf_replacement = pos_inf.has_value() ?
static_cast<scalar_t>(pos_inf.value()) :
std::numeric_limits<scalar_t>::max();
scalar_t neg_inf_replacement = neg_inf.has_value() ?
static_cast<scalar_t>(neg_inf.value()) :
std::numeric_limits<scalar_t>::lowest();
scalar_t pos_inf_replacement =
pos_inf.has_value() ? static_cast<scalar_t>(pos_inf.value()) : std::numeric_limits<scalar_t>::max();
scalar_t neg_inf_replacement =
neg_inf.has_value() ? static_cast<scalar_t>(neg_inf.value()) : std::numeric_limits<scalar_t>::lowest();
nanReplacementScalar = getMPSScalar(nan_replacement, self.scalar_type());
posInfReplacementScalar = getMPSScalar(pos_inf_replacement, self.scalar_type());
@ -533,9 +514,8 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
cachedGraph->posInfReplacementTensor : getMPSGraphTensorFromScalar(stream, posInfReplacementScalar),
cachedGraph->negInfReplacementTensor : getMPSGraphTensorFromScalar(stream, negInfReplacementScalar),
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
return result;
View File
@ -14,10 +14,7 @@
namespace at::native {
TORCH_IMPL_FUNC(triu_mps_out)
(const Tensor& self,
int64_t k,
const Tensor &output) {
(const Tensor& self, int64_t k, const Tensor& output) {
using namespace mps;
if (self.numel() == 0) {
@ -26,22 +23,21 @@ TORCH_IMPL_FUNC(triu_mps_out)
MPSStream* stream = getCurrentMPSStream();
// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = "triu_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
@ -50,12 +46,10 @@ TORCH_IMPL_FUNC(triu_mps_out)
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* outputTensor = nil;
MPSGraphTensor* minusOneTensor = [mpsGraph constantWithScalar:-1
dataType:MPSDataTypeInt32];
MPSGraphTensor* minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32];
if(k > 0) {
MPSGraphTensor* diagMinusOneTensor = [mpsGraph constantWithScalar:(k-1)
dataType:MPSDataTypeInt32];
if (k > 0) {
MPSGraphTensor* diagMinusOneTensor = [mpsGraph constantWithScalar:(k - 1) dataType:MPSDataTypeInt32];
MPSGraphTensor* complementTensor = [mpsGraph bandPartWithTensor:inputTensor
numLowerTensor:minusOneTensor
numUpperTensor:diagMinusOneTensor
@ -63,10 +57,8 @@ TORCH_IMPL_FUNC(triu_mps_out)
outputTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor
secondaryTensor:complementTensor
name:nil];
}
else {
MPSGraphTensor* minusDiagTensor = [mpsGraph constantWithScalar:(-k)
dataType:MPSDataTypeInt32];
} else {
MPSGraphTensor* minusDiagTensor = [mpsGraph constantWithScalar:(-k) dataType:MPSDataTypeInt32];
outputTensor = [mpsGraph bandPartWithTensor:inputTensor
numLowerTensor:minusDiagTensor
numUpperTensor:minusOneTensor
@ -78,29 +70,23 @@ TORCH_IMPL_FUNC(triu_mps_out)
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
@{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
}
TORCH_IMPL_FUNC(tril_mps_out)
(const Tensor& self,
int64_t k,
const Tensor &output) {
(const Tensor& self, int64_t k, const Tensor& output) {
using namespace mps;
if (self.numel() == 0) {
@ -109,22 +95,21 @@ TORCH_IMPL_FUNC(tril_mps_out)
MPSStream* stream = getCurrentMPSStream();
// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = "tril_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
@ -133,20 +118,16 @@ TORCH_IMPL_FUNC(tril_mps_out)
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* outputTensor = nil;
MPSGraphTensor* minusOneTensor = [mpsGraph constantWithScalar:-1
dataType:MPSDataTypeInt32];
MPSGraphTensor* minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32];
if(k >= 0) {
MPSGraphTensor* diagTensor = [mpsGraph constantWithScalar:k
dataType:MPSDataTypeInt32];
if (k >= 0) {
MPSGraphTensor* diagTensor = [mpsGraph constantWithScalar:k dataType:MPSDataTypeInt32];
outputTensor = [mpsGraph bandPartWithTensor:inputTensor
numLowerTensor:minusOneTensor
numUpperTensor:diagTensor
name:nil];
}
else {
MPSGraphTensor* negDiagMinusOneTensor = [mpsGraph constantWithScalar:(-k-1)
dataType:MPSDataTypeInt32];
} else {
MPSGraphTensor* negDiagMinusOneTensor = [mpsGraph constantWithScalar:(-k - 1) dataType:MPSDataTypeInt32];
MPSGraphTensor* complementTensor = [mpsGraph bandPartWithTensor:inputTensor
numLowerTensor:negDiagMinusOneTensor
numUpperTensor:minusOneTensor
@ -161,22 +142,19 @@ TORCH_IMPL_FUNC(tril_mps_out)
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
@{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
}
} // namespace at::native
View File
@ -1,7 +1,7 @@
// Copyright © 2022 Apple Inc.
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
namespace at::native {
namespace mps {
@ -9,14 +9,16 @@ namespace mps {
typedef MPSGraphTensor* (^UnaryOpBlock)(MPSGraph*, MPSGraphTensor*);
using is_noop_p = std::function<bool(const Tensor&)>;
bool is_empty_tensor(const Tensor& self) {
return self.numel() == 0;
}
void unary_op(const Tensor& self, const Tensor& output, std::string op_name, UnaryOpBlock unaryBlock, is_noop_p is_noop = is_empty_tensor)
{
TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte ),
void unary_op(const Tensor& self,
const Tensor& output,
std::string op_name,
UnaryOpBlock unaryBlock,
is_noop_p is_noop = is_empty_tensor) {
TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte),
"MPS support unary op with uint8 natively starting from macOS 13.0");
if (!output.is_same_size(self)) {
output.resize_(self.sizes());
@ -30,9 +32,9 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
string key = op_name + getTensorsStringKey({self, output});
auto cachedGraph = cache_->LookUpAs<MPSUnaryCachedGraph>(key);
if(!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<MPSUnaryCachedGraph>(key, ^ MPSCachedGraph* () {
MPSUnaryCachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<MPSUnaryCachedGraph>(key, ^MPSCachedGraph*() {
MPSUnaryCachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new MPSUnaryCachedGraph(mpsGraph);
@ -55,18 +57,15 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, /*mpsShape=*/nullptr, gatherTensorData);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, /*mpsShape=*/nullptr, false);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
@{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
}
MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor)
{
MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
// Rounding is a no-op for integral types, and also a reasonable workaround
// For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library`
// See https://github.com/pytorch/pytorch/issues/84995
@ -75,100 +74,91 @@ MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor)
return inputTensor;
}
if(!is_macos_13_or_newer()) {
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
dataType:inputTensor.dataType];
if (!is_macos_13_or_newer()) {
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType];
MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
secondaryTensor:zeroTensor
name:nil];
return [mpsGraph selectWithPredicateTensor:predicateTensor
truePredicateTensor:[mpsGraph ceilWithTensor :inputTensor name:nil]
truePredicateTensor:[mpsGraph ceilWithTensor:inputTensor name:nil]
falsePredicateTensor:[mpsGraph floorWithTensor:inputTensor name:nil]
name:nil];
} else {
return [mpsGraph truncateWithTensor:inputTensor
name:nil];
return [mpsGraph truncateWithTensor:inputTensor name:nil];
}
};
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
dataType:inputTensor.dataType];
MPSGraphTensor* addedTensor = [mpsGraph additionWithPrimaryTensor:inputTensor
secondaryTensor:oneTensor
name:nil];
return [mpsGraph logarithmWithTensor:addedTensor
name:nil];
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 dataType:inputTensor.dataType];
MPSGraphTensor* addedTensor = [mpsGraph additionWithPrimaryTensor:inputTensor secondaryTensor:oneTensor name:nil];
return [mpsGraph logarithmWithTensor:addedTensor name:nil];
}
} // namespace mps
TORCH_IMPL_FUNC(trunc_out_mps) (const Tensor& self, const Tensor& output) {
mps::unary_op(self, output, "trunc_out_mps",
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor)
{ return mps::trunc_tensor(mpsGraph, inputTensor); });
TORCH_IMPL_FUNC(trunc_out_mps)(const Tensor& self, const Tensor& output) {
mps::unary_op(self, output, "trunc_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
return mps::trunc_tensor(mpsGraph, inputTensor);
});
}
TORCH_IMPL_FUNC(signbit_out_mps) (const Tensor& self, const Tensor& output) {
mps::unary_op(self, output, "signbit_out_mps",
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
TORCH_IMPL_FUNC(signbit_out_mps)(const Tensor& self, const Tensor& output) {
mps::unary_op(self, output, "signbit_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
MPSGraphTensor* output;
// signbit is not implemented for int64 type.
// workaround for `Function signbitOp_i64 was not found in the library`
if ([inputTensor dataType] == MPSDataTypeInt64) {
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType];
output = [mpsGraph lessThanWithPrimaryTensor:inputTensor
secondaryTensor:zeroTensor
name:nil];
output = [mpsGraph lessThanWithPrimaryTensor:inputTensor secondaryTensor:zeroTensor name:nil];
} else {
output = [mpsGraph signbitWithTensor: inputTensor name: nil];
output = [mpsGraph signbitWithTensor:inputTensor name:nil];
}
return mps::castMPSTensor(mpsGraph, output, ScalarType::Bool);
});
}
TORCH_IMPL_FUNC(sign_out_mps) (const Tensor& self, const Tensor& output) {
mps::unary_op(self, output, "sign_out_mps",
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
TORCH_IMPL_FUNC(sign_out_mps)(const Tensor& self, const Tensor& output) {
mps::unary_op(self, output, "sign_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
// Sign op is not implemented in MPS as of MacOS13.0 beta, so simulate it using clamp
if ([inputTensor dataType] == MPSDataTypeInt64) {
return [mpsGraph clampWithTensor:inputTensor
minValueTensor:[mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt64]
maxValueTensor:[mpsGraph constantWithScalar:1 dataType:MPSDataTypeInt64]
name: nil];
name:nil];
}
return [mpsGraph signWithTensor: inputTensor name: nil];
return [mpsGraph signWithTensor:inputTensor name:nil];
});
}
#define CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC(func_out, func_stub) \
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& output) { \
mps::unary_op(self, output, #func_out, \
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) \
{ return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; }, \
[](const Tensor& t) -> bool { \
return t.numel() == 0 || isIntegralType(t.scalar_type(), true); \
}); \
}
TORCH_IMPL_FUNC(func_out)(const Tensor& self, const Tensor& output) { \
mps::unary_op( \
self, \
output, \
#func_out, \
^MPSGraphTensor*(MPSGraph * mpsGraph, MPSGraphTensor * inputTensor) { \
return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; \
}, \
[](const Tensor& t) -> bool { return t.numel() == 0 || isIntegralType(t.scalar_type(), true); }); \
}
CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC(ceil_out_mps, ceil)
CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC(floor_out_mps, floor)
CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC(round_out_mps, round)
#define CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(func_out, func_stub) \
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& output) { \
mps::unary_op(self, output, #func_out, \
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) \
{ return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; }); \
}
TORCH_IMPL_FUNC(func_out)(const Tensor& self, const Tensor& output) { \
mps::unary_op(self, output, #func_out, ^MPSGraphTensor*(MPSGraph * mpsGraph, MPSGraphTensor * inputTensor) { \
return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; \
}); \
}
#define CREATE_MPS_UNARY_TORCH_IMPL_FUNC(func_out, func_stub) \
Tensor& func_out(const Tensor& self, Tensor& output) { \
mps::unary_op(self, output, #func_out, \
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) \
{ return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; }); \
Tensor& func_out(const Tensor& self, Tensor& output) { \
mps::unary_op(self, output, #func_out, ^MPSGraphTensor*(MPSGraph * mpsGraph, MPSGraphTensor * inputTensor) { \
return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; \
}); \
return output; \
}
}
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp_out_mps, exponent)
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp2_out_mps, exponentBase2)
@ -195,81 +185,59 @@ CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(atanh_out_mps, atanh)
CREATE_MPS_UNARY_TORCH_IMPL_FUNC(abs_out_mps, absolute)
Tensor& logical_not_out_mps(const Tensor& self, Tensor& output)
{
Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
auto bool_self = self.to(ScalarType::Bool);
mps::unary_op(bool_self, output, "logical_not_out_mps", [](MPSGraph* mpsGraph, MPSGraphTensor* inputTensor){ return [mpsGraph notWithTensor:inputTensor name:nil];});
mps::unary_op(bool_self, output, "logical_not_out_mps", [](MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
return [mpsGraph notWithTensor:inputTensor name:nil];
});
return output;
}
TORCH_IMPL_FUNC(sigmoid_out_mps) (const Tensor& self, const Tensor& output)
{
TORCH_IMPL_FUNC(sigmoid_out_mps)(const Tensor& self, const Tensor& output) {
TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support sigmoid op with int64 input");
mps::unary_op(self, output, "sigmoid_out_mps",
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
mps::unary_op(self, output, "sigmoid_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
return [mpsGraph sigmoidWithTensor:inputTensor name:nil];
});
}
TORCH_IMPL_FUNC(log1p_out_mps) (const Tensor& self, const Tensor& output)
{
TORCH_IMPL_FUNC(log1p_out_mps)(const Tensor& self, const Tensor& output) {
TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support log1p op with int64 input");
mps::unary_op(self, output, "log1p_out_mps",
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
mps::unary_op(self, output, "log1p_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
return mps::log1p(mpsGraph, inputTensor);
});
}
TORCH_IMPL_FUNC(frac_out_mps) (const Tensor& self, const Tensor& output) {
TORCH_IMPL_FUNC(frac_out_mps)(const Tensor& self, const Tensor& output) {
TORCH_CHECK(isFloatingType(self.scalar_type()), "frac_out_mps is only implemented for floating types");
mps::unary_op(self, output, "frac_out_mps",
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
auto zeroTensor = [mpsGraph constantWithScalar:0.0
dataType:inputTensor.dataType];
auto predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
secondaryTensor:zeroTensor
name:nil];
mps::unary_op(self, output, "frac_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
auto zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType];
auto predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor secondaryTensor:zeroTensor name:nil];
auto truncTensor = [mpsGraph selectWithPredicateTensor:predicateTensor
truePredicateTensor:[mpsGraph ceilWithTensor :inputTensor name:nil]
truePredicateTensor:[mpsGraph ceilWithTensor:inputTensor name:nil]
falsePredicateTensor:[mpsGraph floorWithTensor:inputTensor name:nil]
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:inputTensor
secondaryTensor:truncTensor
name: nil];
return [mpsGraph subtractionWithPrimaryTensor:inputTensor secondaryTensor:truncTensor name:nil];
});
}
TORCH_IMPL_FUNC(expm1_out_mps) (const Tensor& self, const Tensor& output) {
mps::unary_op(self, output, "expm1_out_mps",
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
shape:@[@1]
dataType:inputTensor.dataType];
MPSGraphTensor* ePowTensor = [mpsGraph exponentWithTensor:inputTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:ePowTensor
secondaryTensor:oneTensor
name: nil];
TORCH_IMPL_FUNC(expm1_out_mps)(const Tensor& self, const Tensor& output) {
mps::unary_op(self, output, "expm1_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:inputTensor.dataType];
MPSGraphTensor* ePowTensor = [mpsGraph exponentWithTensor:inputTensor name:nil];
return [mpsGraph subtractionWithPrimaryTensor:ePowTensor secondaryTensor:oneTensor name:nil];
});
}
void logit_mps_impl(const Tensor& self, c10::optional<double> eps, Tensor& output, const std::string op_name) {
std::string key = op_name + ":[" + (eps.has_value() ? std::to_string(eps.value()) : "NULL") + "]";
mps::unary_op(self, output, key,
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
shape:@[@1]
dataType:inputTensor.dataType];
mps::unary_op(self, output, key, ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:inputTensor.dataType];
MPSGraphTensor* logitInputTensor;
if (eps.has_value()) {
MPSGraphTensor *lowTensor = [mpsGraph constantWithScalar:eps.value()
shape:@[@1]
dataType:inputTensor.dataType];
MPSGraphTensor *highTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor
secondaryTensor: lowTensor
name: nil];
MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps.value() shape:@[ @1 ] dataType:inputTensor.dataType];
MPSGraphTensor* highTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor secondaryTensor:lowTensor name:nil];
logitInputTensor = [mpsGraph clampWithTensor:inputTensor
minValueTensor:lowTensor
maxValueTensor:highTensor
@ -278,56 +246,43 @@ void logit_mps_impl(const Tensor& self, c10::optional<double> eps, Tensor& outpu
logitInputTensor = inputTensor;
}
MPSGraphTensor *oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor
secondaryTensor: logitInputTensor
name: nil];
MPSGraphTensor *outputTensor = [mpsGraph divisionWithPrimaryTensor:logitInputTensor
MPSGraphTensor* oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor
secondaryTensor:logitInputTensor
name:nil];
MPSGraphTensor* outputTensor = [mpsGraph divisionWithPrimaryTensor:logitInputTensor
secondaryTensor:oneMinusInputTensor
name:nil];
return [mpsGraph logarithmWithTensor:outputTensor
name:nil];
return [mpsGraph logarithmWithTensor:outputTensor name:nil];
});
}
Tensor& logit_out_mps(const Tensor& self,
c10::optional<double> eps,
Tensor& result) {
Tensor& logit_out_mps(const Tensor& self, c10::optional<double> eps, Tensor& result) {
logit_mps_impl(self, eps, result, "logit_out_mps");
return result;
}
Tensor logit_mps(const Tensor& self, c10::optional<double> eps) {
Tensor result = at::native::empty_mps(
self.sizes(),
ScalarType::Float,
c10::nullopt,
kMPS,
c10::nullopt,
c10::nullopt);
Tensor result =
at::native::empty_mps(self.sizes(), ScalarType::Float, c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
logit_mps_impl(self, eps, result, "logit_mps");
return result;
}
TORCH_IMPL_FUNC(logit_backward_out_mps) (
const Tensor& grad_output,
const Tensor& input,
c10::optional<double> eps,
const Tensor& grad_input)
{
TORCH_IMPL_FUNC(logit_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, c10::optional<double> eps, const Tensor& grad_input) {
using namespace mps;
// Empty output
if(grad_input.numel() == 0)
if (grad_input.numel() == 0)
return;
double eps_ = eps ? eps.value() : -1.0;
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *gradOutputTensor_ = nil;
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@ -335,14 +290,13 @@ TORCH_IMPL_FUNC(logit_backward_out_mps) (
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
std::string key = "logit_backward_out_mps:" + getTensorsStringKey({grad_output, input}) + ":" +
"[" + (eps.has_value() ? std::to_string(eps.value()) : "-1" ) + "]";
std::string key = "logit_backward_out_mps:" + getTensorsStringKey({grad_output, input}) + ":" + "[" +
(eps.has_value() ? std::to_string(eps.value()) : "-1") + "]";
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
@ -351,40 +305,32 @@ TORCH_IMPL_FUNC(logit_backward_out_mps) (
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* outputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_input);
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
shape:@[@1]
dataType:inputTensor.dataType];
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
shape:@[@1]
dataType:inputTensor.dataType];
MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps_
shape:@[@1]
dataType:inputTensor.dataType];
MPSGraphTensor *inputLessThanLowPredicateTensor = [mpsGraph lessThanWithPrimaryTensor: inputTensor
secondaryTensor: lowTensor
name: nil];
MPSGraphTensor *highTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor
secondaryTensor: lowTensor
name: nil];
MPSGraphTensor *inputGreaterThanHighPredicateTensor = [mpsGraph greaterThanWithPrimaryTensor: inputTensor
secondaryTensor: highTensor
name: nil];
MPSGraphTensor* outOfIntervalTensor = [mpsGraph logicalORWithPrimaryTensor: inputLessThanLowPredicateTensor
secondaryTensor: inputGreaterThanHighPredicateTensor
name: nil];
MPSGraphTensor *oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor
secondaryTensor: inputTensor
name: nil];
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:inputTensor.dataType];
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:inputTensor.dataType];
MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps_ shape:@[ @1 ] dataType:inputTensor.dataType];
MPSGraphTensor* inputLessThanLowPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
secondaryTensor:lowTensor
name:nil];
MPSGraphTensor* highTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor
secondaryTensor:lowTensor
name:nil];
MPSGraphTensor* inputGreaterThanHighPredicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor
secondaryTensor:highTensor
name:nil];
MPSGraphTensor* outOfIntervalTensor = [mpsGraph logicalORWithPrimaryTensor:inputLessThanLowPredicateTensor
secondaryTensor:inputGreaterThanHighPredicateTensor
name:nil];
MPSGraphTensor* oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor
secondaryTensor:inputTensor
name:nil];
outputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor
secondaryTensor:oneMinusInputTensor
name:nil];
outputTensor = [mpsGraph divisionWithPrimaryTensor:gradOutputTensor
secondaryTensor:outputTensor
outputTensor = [mpsGraph divisionWithPrimaryTensor:gradOutputTensor secondaryTensor:outputTensor name:nil];
outputTensor = [mpsGraph selectWithPredicateTensor:outOfIntervalTensor
truePredicateTensor:zeroTensor
falsePredicateTensor:outputTensor
name:nil];
outputTensor = [mpsGraph selectWithPredicateTensor: outOfIntervalTensor
truePredicateTensor: zeroTensor
falsePredicateTensor: outputTensor
name: nil];
newCachedGraph->gradOutputTensor_ = gradOutputTensor;
newCachedGraph->inputTensor_ = inputTensor;
@ -392,7 +338,7 @@ TORCH_IMPL_FUNC(logit_backward_out_mps) (
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
@ -403,25 +349,25 @@ TORCH_IMPL_FUNC(logit_backward_out_mps) (
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
}
TORCH_IMPL_FUNC(cumsum_out_mps)
(const Tensor& self,
int64_t dim,
c10::optional<ScalarType> dtype,
const Tensor& result) {
(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, const Tensor& result) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
auto nDims = self.dim();
auto wrapped_dim = maybe_wrap_dim(dim, nDims);
TORCH_CHECK(wrapped_dim >=0 && wrapped_dim < std::max(1LL, self.ndimension()), "Expected wrapped dim to be between 0 and ", self.ndimension(), " but got ", wrapped_dim , "(original dim is ", dim, ")");
TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()),
"Expected wrapped dim to be between 0 and ",
self.ndimension(),
" but got ",
wrapped_dim,
"(original dim is ",
dim,
")");
if (!is_macos_13_or_newer()) {
TORCH_WARN_ONCE("torch.cumsum supported by MPS on MacOS 13+, please upgrade");
auto cpu_result = self.to(at::Device(kCPU)).cumsum(dim, dtype);
@ -430,24 +376,22 @@ TORCH_IMPL_FUNC(cumsum_out_mps)
}
auto input = dtype.has_value() ? self.to(dtype.value()) : self;
// issue #103810551: cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to int32
// fixed in macOS 13.3
bool castInputData = (isIntegralType(input.scalar_type()) &&
input.scalar_type() != ScalarType::Int &&
// issue #103810551: cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to
// int32 fixed in macOS 13.3
bool castInputData = (isIntegralType(input.scalar_type()) && input.scalar_type() != ScalarType::Int &&
input.scalar_type() != ScalarType::Long);
TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long,
"MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3");
mps::unary_op(input, result, "cumsum_out_mp" + std::to_string(dim),
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
mps::unary_op(input,
result,
"cumsum_out_mp" + std::to_string(dim),
^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
if (castInputData) {
inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int);
}
auto rc = [mpsGraph cumulativeSumWithTensor: inputTensor
axis: dim
name: nil];
auto rc = [mpsGraph cumulativeSumWithTensor:inputTensor axis:dim name:nil];
if ((mps::getMPSDataType(result) != [rc dataType]) || castInputData) {
return mps::castMPSTensor(mpsGraph, rc, result.scalar_type());
}
View File
@ -1,15 +1,14 @@
// Copyright © 2022 Apple Inc.
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
namespace at::native {
namespace mps {
struct UniqueCachedGraph : public MPSCachedGraph
{
UniqueCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct UniqueCachedGraph : public MPSCachedGraph {
UniqueCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
MPSGraphTensor* inverseIndicesTensor_ = nil;
@ -17,139 +16,108 @@ struct UniqueCachedGraph : public MPSCachedGraph
MPSGraphTensor* lengthTensor_ = nil;
};
static std::string getUniqueKey(const ScalarType& dtype, const IntArrayRef& base_shape,
const bool return_inverse, const bool return_counts,
const bool consecutive, c10::optional<int64_t> dimOpt)
{
return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) +
"]:[" + (dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) +
"]:[" + to_string(return_counts) + "]:[" + to_string(consecutive) + "]";
static std::string getUniqueKey(const ScalarType& dtype,
const IntArrayRef& base_shape,
const bool return_inverse,
const bool return_counts,
const bool consecutive,
c10::optional<int64_t> dimOpt) {
return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) + "]:[" +
(dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) + "]:[" +
to_string(return_counts) + "]:[" + to_string(consecutive) + "]";
}
// dim arg not supported when non consecutive, ie sorted
std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self, UniqueCachedGraph *uniqueGraph, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional<int64_t> dimOpt) {
std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self,
UniqueCachedGraph* uniqueGraph,
const bool return_inverse,
const bool return_counts,
const bool consecutive,
c10::optional<int64_t> dimOpt) {
int64_t dim = dimOpt.has_value() ? maybe_wrap_dim(dimOpt.value(), self.dim()) : 0;
MPSGraph *graph = uniqueGraph->graph();
MPSGraphTensor *inputTensor = uniqueGraph->inputTensor_;
MPSShape *shape = [inputTensor shape];
MPSShape *destShape = shape;
MPSGraph* graph = uniqueGraph->graph();
MPSGraphTensor* inputTensor = uniqueGraph->inputTensor_;
MPSShape* shape = [inputTensor shape];
MPSShape* destShape = shape;
uint64_t length = [shape[dim] unsignedIntValue];
MPSDataType dataType = [inputTensor dataType];
MPSGraphTensor *resultTensor = nil;
MPSGraphTensor *inverseIndicesTensor = nil;
MPSGraphTensor *countTensor = nil;
MPSGraphTensor *lengthTensor = nil;
MPSGraphTensor* resultTensor = nil;
MPSGraphTensor* inverseIndicesTensor = nil;
MPSGraphTensor* countTensor = nil;
MPSGraphTensor* lengthTensor = nil;
if (length <= 1) {
// Trivial case, only 1 element everything is unique
resultTensor = inputTensor;
lengthTensor = [graph constantWithScalar:0.0f
dataType:MPSDataTypeInt32];
lengthTensor = [graph constantWithScalar:0.0f dataType:MPSDataTypeInt32];
if (return_inverse) {
inverseIndicesTensor = [graph constantWithScalar:0.0f
dataType:MPSDataTypeInt32];
inverseIndicesTensor = [graph constantWithScalar:0.0f dataType:MPSDataTypeInt32];
}
if (return_counts) {
countTensor = [graph constantWithScalar:1.0f
dataType:MPSDataTypeInt32];
countTensor = [graph constantWithScalar:1.0f dataType:MPSDataTypeInt32];
}
return {resultTensor, inverseIndicesTensor, countTensor, lengthTensor};
}
// #issue 104398441 sortWithTensor only supports following types, cast if necessary
if (dataType != MPSDataTypeInt32 &&
dataType != MPSDataTypeFloat32 &&
dataType != MPSDataTypeFloat16) {
if (dataType != MPSDataTypeInt32 && dataType != MPSDataTypeFloat32 && dataType != MPSDataTypeFloat16) {
dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
inputTensor = [graph castTensor:inputTensor
toType:dataType
name:@"castInputTensor"];
inputTensor = [graph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
}
bool needsFlatten = !(dimOpt.has_value() || [shape count] == 1);
if (needsFlatten) {
inputTensor = [graph reshapeTensor:inputTensor
withShape:@[@-1]
name:nil];
inputTensor = [graph reshapeTensor:inputTensor withShape:@[ @-1 ] name:nil];
length = 1;
for (const auto i: c10::irange([shape count])) {
for (const auto i : c10::irange([shape count])) {
if (c10::mul_overflows(length, [shape[i] unsignedIntValue], &length)) {
TORCH_CHECK(false, "RuntimeError: Tensor size overflow");
}
}
destShape = @[[NSNumber numberWithUnsignedInteger:length]];
destShape = @[ [NSNumber numberWithUnsignedInteger:length] ];
}
MPSGraphTensor *sortedInput = nil;
MPSGraphTensor* sortedInput = nil;
if (consecutive) {
sortedInput = inputTensor;
} else {
sortedInput = [graph sortWithTensor:inputTensor
axis:0
name:nil];
sortedInput = [graph sortWithTensor:inputTensor axis:0 name:nil];
}
MPSGraphTensor *frontNMinusOne = [graph sliceTensor:sortedInput
dimension:dim
start:0
length:length-1
name:nil];
MPSGraphTensor *backNMinusOne = [graph sliceTensor:sortedInput
dimension:dim
start:1
length:length-1
name:nil];
MPSGraphTensor *notEqualToPreviousElement = [graph notEqualWithPrimaryTensor:backNMinusOne
MPSGraphTensor* frontNMinusOne = [graph sliceTensor:sortedInput dimension:dim start:0 length:length - 1 name:nil];
MPSGraphTensor* backNMinusOne = [graph sliceTensor:sortedInput dimension:dim start:1 length:length - 1 name:nil];
MPSGraphTensor* notEqualToPreviousElement = [graph notEqualWithPrimaryTensor:backNMinusOne
secondaryTensor:frontNMinusOne
name:nil];
MPSGraphTensor *mask = [graph castTensor:notEqualToPreviousElement
toType:MPSDataTypeInt32
name:@"castMaskTensor"];
MPSGraphTensor* mask = [graph castTensor:notEqualToPreviousElement toType:MPSDataTypeInt32 name:@"castMaskTensor"];
// If comparing tensors, not scalars, check if entire tensor matches previous element using reductionOr over tensor
if (dimOpt.has_value() && [shape count] != 1) {
NSMutableArray *axes = [[NSMutableArray alloc] initWithCapacity:[shape count]-1];
NSMutableArray* axes = [[NSMutableArray alloc] initWithCapacity:[shape count] - 1];
for (const auto axis : c10::irange([shape count])) {
if (axis != dim) {
[axes addObject:[NSNumber numberWithUnsignedInteger:axis]];
}
}
mask = [graph reductionOrWithTensor:mask
axes:axes
name:nil];
mask = [graph squeezeTensor:mask
axes:axes
name:nil];
mask = [graph reductionOrWithTensor:mask axes:axes name:nil];
mask = [graph squeezeTensor:mask axes:axes name:nil];
[axes release];
}
MPSGraphTensor *scannedIndices = [graph cumulativeSumWithTensor:mask
axis:0
name:nil];
lengthTensor = [graph sliceTensor:scannedIndices
dimension:0
start:length-2
length:1
name:nil];
MPSGraphTensor* scannedIndices = [graph cumulativeSumWithTensor:mask axis:0 name:nil];
lengthTensor = [graph sliceTensor:scannedIndices dimension:0 start:length - 2 length:1 name:nil];
MPSGraphTensor *minusOneTensor = [graph constantWithScalar:-1.0f
dataType:MPSDataTypeInt32];
MPSGraphTensor *maskedIndices = [graph selectWithPredicateTensor:mask
MPSGraphTensor* minusOneTensor = [graph constantWithScalar:-1.0f dataType:MPSDataTypeInt32];
MPSGraphTensor* maskedIndices = [graph selectWithPredicateTensor:mask
truePredicateTensor:scannedIndices
falsePredicateTensor:minusOneTensor
name:nil];
MPSGraphTensor *zeroTensor = [graph constantWithScalar:0.0f
shape:@[@1]
dataType:MPSDataTypeInt32];
MPSGraphTensor *maskedIndicesWithHead = [graph concatTensors:@[zeroTensor, maskedIndices]
dimension:0
name:nil];
MPSGraphTensor *scannedIndicesWithHead = [graph concatTensors:@[zeroTensor, scannedIndices]
dimension:0
name:nil];
MPSGraphTensor* zeroTensor = [graph constantWithScalar:0.0f shape:@[ @1 ] dataType:MPSDataTypeInt32];
MPSGraphTensor* maskedIndicesWithHead = [graph concatTensors:@[ zeroTensor, maskedIndices ] dimension:0 name:nil];
MPSGraphTensor* scannedIndicesWithHead = [graph concatTensors:@[ zeroTensor, scannedIndices ] dimension:0 name:nil];
resultTensor = [graph scatterWithUpdatesTensor:sortedInput
indicesTensor:maskedIndicesWithHead
@ -159,41 +127,35 @@ std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self, UniqueCached
name:nil];
// Cast back if necessary
if ([uniqueGraph->inputTensor_ dataType] != dataType) {
resultTensor = [graph castTensor:resultTensor
toType:[uniqueGraph->inputTensor_ dataType]
name:@"castResultTensor"];
resultTensor = [graph castTensor:resultTensor toType:[uniqueGraph->inputTensor_ dataType] name:@"castResultTensor"];
}
// Compute optional returned tensors if requested
if(return_inverse) {
MPSGraphTensor *argSortedInput = nil;
if (return_inverse) {
MPSGraphTensor* argSortedInput = nil;
if (consecutive)
argSortedInput = [graph coordinateAlongAxis:0
withShape:@[[NSNumber numberWithUnsignedInteger:length]]
withShape:@[ [NSNumber numberWithUnsignedInteger:length] ]
name:nil];
else
argSortedInput = [graph argSortWithTensor:inputTensor
axis:0
name:nil];
argSortedInput = [graph argSortWithTensor:inputTensor axis:0 name:nil];
inverseIndicesTensor = [graph scatterWithUpdatesTensor:scannedIndicesWithHead
indicesTensor:argSortedInput
shape:@[[NSNumber numberWithUnsignedInteger:length]]
shape:@[ [NSNumber numberWithUnsignedInteger:length] ]
axis:0
mode:MPSGraphScatterModeAdd
name:nil];
if (needsFlatten)
inverseIndicesTensor = [graph reshapeTensor:inverseIndicesTensor
withShape:shape
name:nil];
inverseIndicesTensor = [graph reshapeTensor:inverseIndicesTensor withShape:shape name:nil];
}
if (return_counts) {
MPSGraphTensor *unitTensor = [graph constantWithScalar:1.0f
shape:@[[NSNumber numberWithUnsignedInteger:length]]
MPSGraphTensor* unitTensor = [graph constantWithScalar:1.0f
shape:@[ [NSNumber numberWithUnsignedInteger:length] ]
dataType:MPSDataTypeInt32];
countTensor = [graph scatterWithUpdatesTensor:unitTensor
indicesTensor:scannedIndicesWithHead
shape:@[[NSNumber numberWithUnsignedInteger:length]]
shape:@[ [NSNumber numberWithUnsignedInteger:length] ]
axis:0
mode:MPSGraphScatterModeAdd
name:nil];
@ -202,16 +164,19 @@ std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self, UniqueCached
return {resultTensor, inverseIndicesTensor, countTensor, lengthTensor};
}
static UniqueCachedGraph* getUniqueGraph(const Tensor& self, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional<int64_t> dim) {
static UniqueCachedGraph* getUniqueGraph(const Tensor& self,
const bool return_inverse,
const bool return_counts,
const bool consecutive,
c10::optional<int64_t> dim) {
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = getUniqueKey(self.scalar_type(), self.sizes(), return_inverse, return_counts, consecutive, dim);
UniqueCachedGraph* cachedGraph = static_cast<UniqueCachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
UniqueCachedGraph *newCachedGraph = nil;
UniqueCachedGraph* cachedGraph = static_cast<UniqueCachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
UniqueCachedGraph* newCachedGraph = nil;
@autoreleasepool {
// Initialize graph
@ -232,15 +197,20 @@ static UniqueCachedGraph* getUniqueGraph(const Tensor& self, const bool return_i
}
return newCachedGraph;
});
cachedGraph = static_cast<UniqueCachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<UniqueCachedGraph*>(tmpCachedGraph);
}
return cachedGraph;
}
}
void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor& output,
Tensor& inverse_indices, Tensor& counts, Tensor& length,
bool return_inverse, bool return_counts){
void runUniqueGraph(UniqueCachedGraph* uniqueGraph,
const Tensor& input,
Tensor& output,
Tensor& inverse_indices,
Tensor& counts,
Tensor& length,
bool return_inverse,
bool return_counts) {
Placeholder inputPlaceholder = Placeholder(uniqueGraph->inputTensor_, input);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
@ -249,10 +219,8 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor&
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = [NSMutableDictionary dictionary];
Placeholder outputPlaceholder = Placeholder(uniqueGraph->outputTensor_, output);
Placeholder lengthPlaceholder = Placeholder(uniqueGraph->lengthTensor_, length);
[results setObject:outputPlaceholder.getMPSGraphTensorData()
forKey:outputPlaceholder.getMPSGraphTensor()];
[results setObject:lengthPlaceholder.getMPSGraphTensorData()
forKey:lengthPlaceholder.getMPSGraphTensor()];
[results setObject:outputPlaceholder.getMPSGraphTensorData() forKey:outputPlaceholder.getMPSGraphTensor()];
[results setObject:lengthPlaceholder.getMPSGraphTensorData() forKey:lengthPlaceholder.getMPSGraphTensor()];
if (return_inverse) {
Placeholder inverseIndicesPlaceholder = Placeholder(uniqueGraph->inverseIndicesTensor_, inverse_indices);
[results setObject:inverseIndicesPlaceholder.getMPSGraphTensorData()
@ -260,8 +228,7 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor&
}
if (return_counts) {
Placeholder countsPlaceholder = Placeholder(uniqueGraph->countsTensor_, counts);
[results setObject:countsPlaceholder.getMPSGraphTensorData()
forKey:countsPlaceholder.getMPSGraphTensor()];
[results setObject:countsPlaceholder.getMPSGraphTensorData() forKey:countsPlaceholder.getMPSGraphTensor()];
}
// Run the graph
@ -271,9 +238,11 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor&
} // namespace mps
std::tuple<Tensor, Tensor, Tensor>
_unique_impl_mps(const Tensor& self, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional<int64_t> dimOpt) {
std::tuple<Tensor, Tensor, Tensor> _unique_impl_mps(const Tensor& self,
const bool return_inverse,
const bool return_counts,
const bool consecutive,
c10::optional<int64_t> dimOpt) {
const Tensor& input = self.contiguous();
// get flat output size
@ -303,7 +272,7 @@ _unique_impl_mps(const Tensor& self, const bool return_inverse, const bool retur
return std::make_tuple(output, inverse_indices, counts);
}
mps::UniqueCachedGraph *uniqueGraph = mps::getUniqueGraph(input, return_inverse, return_counts, consecutive, dimOpt);
mps::UniqueCachedGraph* uniqueGraph = mps::getUniqueGraph(input, return_inverse, return_counts, consecutive, dimOpt);
mps::runUniqueGraph(uniqueGraph, input, output, inverse_indices, counts, length, return_inverse, return_counts);
int64_t lengthScalar = length.item<int64_t>() + 1; // length actually holds max index, add 1
@ -316,17 +285,14 @@ _unique_impl_mps(const Tensor& self, const bool return_inverse, const bool retur
return std::make_tuple(output, inverse_indices, counts);
}
static
std::tuple<Tensor, Tensor, Tensor> castToMPS(std::tuple<Tensor, Tensor, Tensor> out) {
return std::make_tuple(
get<0>(out).to("mps"),
get<1>(out).to("mps"),
get<2>(out).to("mps"));
static std::tuple<Tensor, Tensor, Tensor> castToMPS(std::tuple<Tensor, Tensor, Tensor> out) {
return std::make_tuple(get<0>(out).to("mps"), get<1>(out).to("mps"), get<2>(out).to("mps"));
}
std::tuple<Tensor, Tensor, Tensor>
unique_consecutive_mps(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional<int64_t> dim) {
std::tuple<Tensor, Tensor, Tensor> unique_consecutive_mps(const Tensor& self,
const bool return_inverse,
const bool return_counts,
c10::optional<int64_t> dim) {
if (!is_macos_13_or_newer()) {
TORCH_WARN_ONCE("MPS: unique_consecutive op is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performace implications.");
@ -336,8 +302,10 @@ unique_consecutive_mps(const Tensor& self, const bool return_inverse, const bool
return _unique_impl_mps(self, return_inverse, return_counts, true, dim);
}
std::tuple<Tensor, Tensor, Tensor>
unique_dim_consecutive_mps(const Tensor& self, int64_t dim, const bool return_inverse, const bool return_counts) {
std::tuple<Tensor, Tensor, Tensor> unique_dim_consecutive_mps(const Tensor& self,
int64_t dim,
const bool return_inverse,
const bool return_counts) {
if (!is_macos_13_or_newer()) {
TORCH_WARN_ONCE("MPS: unique_dim_consecutive op is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performace implications.");
@ -347,8 +315,10 @@ unique_dim_consecutive_mps(const Tensor& self, int64_t dim, const bool return_in
return _unique_impl_mps(self, return_inverse, return_counts, true, c10::make_optional((int64_t)dim));
}
std::tuple<Tensor, Tensor, Tensor>
_unique2_mps(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
std::tuple<Tensor, Tensor, Tensor> _unique2_mps(const Tensor& self,
const bool sorted,
const bool return_inverse,
const bool return_counts) {
if (!is_macos_13_or_newer()) {
TORCH_WARN_ONCE("MPS: _unique2 op is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performace implications.");
View File
@ -1,8 +1,8 @@
// Copyright © 2023 Apple Inc.
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/UpSample.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
namespace at::native {
namespace mps {
@ -34,9 +34,8 @@ void upsample_out_template(const Tensor& input,
bool centerResults = false;
MPSGraphResizeMode resizeMode = MPSGraphResizeNearest;
MPSGraphResizeNearestRoundingMode nearestRoundingMode = MPSGraphResizeNearestRoundingModeFloor;
MPSGraphTensorNamedDataLayout dataLayout = input_dim.size() > 3 ?
MPSGraphTensorNamedDataLayoutNCHW :
MPSGraphTensorNamedDataLayoutCHW;
MPSGraphTensorNamedDataLayout dataLayout =
input_dim.size() > 3 ? MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutCHW;
if (resize_mode_str == "nearest") {
resizeMode = MPSGraphResizeNearest;
} else if (resize_mode_str == "bilinear") {
@ -63,9 +62,9 @@ void upsample_out_template(const Tensor& input,
input_size = input_size_opt.value();
}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
MPSGraphTensor *outputSizeTensor = nil;
MPSGraphTensor* outputSizeTensor = nil;
};
MPSStream* stream = getCurrentMPSStream();
@ -76,24 +75,24 @@ void upsample_out_template(const Tensor& input,
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if(!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
newCachedGraph->outputSizeTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@(2)]);
newCachedGraph->outputSizeTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(2) ]);
MPSGraphTensor* scaleOffsetTensor = nullptr;
MPSGraphTensor* inputSizeTensor = nullptr;
if (scale_w > 0.0) {
const float outScales[4] = {scale_h, scale_w, offset_y, offset_x};
scaleOffsetTensor = [mpsGraph constantWithData: [NSData dataWithBytes: outScales length: sizeof(outScales)]
shape: @[@4]
dataType: MPSDataTypeFloat32];
scaleOffsetTensor = [mpsGraph constantWithData:[NSData dataWithBytes:outScales length:sizeof(outScales)]
shape:@[ @4 ]
dataType:MPSDataTypeFloat32];
}
if (is_backward_pass) {
std::vector<NSNumber*> inputSizeVec(4);
@ -101,118 +100,119 @@ void upsample_out_template(const Tensor& input,
inputSizeVec[1] = @(input_size[1]);
inputSizeVec[2] = @(input_size[2]);
inputSizeVec[3] = @(input_dim.size() > 3 ? input_size[3] : 1);
inputSizeTensor = [mpsGraph constantWithScalar: 0
shape: [NSArray arrayWithObjects:inputSizeVec.data() count:input_dim.size()]
dataType: getMPSDataType(input)];
inputSizeTensor = [mpsGraph constantWithScalar:0
shape:[NSArray arrayWithObjects:inputSizeVec.data()
count:input_dim.size()]
dataType:getMPSDataType(input)];
}
if (is_macOS_13_0_or_newer) {
if (!is_backward_pass) {
if (scaleOffsetTensor && !align_corners) {
if (resizeMode == MPSGraphResizeNearest) {
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor: newCachedGraph->inputTensor
sizeTensor: newCachedGraph->outputSizeTensor
scaleOffsetTensor: scaleOffsetTensor
nearestRoundingMode: nearestRoundingMode
layout: dataLayout
name: nil];
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor:newCachedGraph->inputTensor
sizeTensor:newCachedGraph->outputSizeTensor
scaleOffsetTensor:scaleOffsetTensor
nearestRoundingMode:nearestRoundingMode
layout:dataLayout
name:nil];
} else { // bilinear forward
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor: newCachedGraph->inputTensor
sizeTensor: newCachedGraph->outputSizeTensor
scaleOffsetTensor: scaleOffsetTensor
layout: dataLayout
name: nil];
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor:newCachedGraph->inputTensor
sizeTensor:newCachedGraph->outputSizeTensor
scaleOffsetTensor:scaleOffsetTensor
layout:dataLayout
name:nil];
}
} else { // scaleOffsetTensor == nil || align_corners
if (resizeMode == MPSGraphResizeNearest) {
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor: newCachedGraph->inputTensor
sizeTensor: newCachedGraph->outputSizeTensor
nearestRoundingMode: nearestRoundingMode
centerResult: centerResults
alignCorners: align_corners
layout: dataLayout
name: nil];
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor:newCachedGraph->inputTensor
sizeTensor:newCachedGraph->outputSizeTensor
nearestRoundingMode:nearestRoundingMode
centerResult:centerResults
alignCorners:align_corners
layout:dataLayout
name:nil];
} else { // bilinear forward
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor: newCachedGraph->inputTensor
sizeTensor: newCachedGraph->outputSizeTensor
centerResult: centerResults
alignCorners: align_corners
layout: dataLayout
name: nil];
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor:newCachedGraph->inputTensor
sizeTensor:newCachedGraph->outputSizeTensor
centerResult:centerResults
alignCorners:align_corners
layout:dataLayout
name:nil];
}
}
} else { // is_backward_pass == true
if (scaleOffsetTensor && !align_corners) {
if (resizeMode == MPSGraphResizeNearest) {
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor: newCachedGraph->inputTensor
input: inputSizeTensor
scaleOffsetTensor: scaleOffsetTensor
nearestRoundingMode: nearestRoundingMode
layout: dataLayout
name: nil];
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor:newCachedGraph->inputTensor
input:inputSizeTensor
scaleOffsetTensor:scaleOffsetTensor
nearestRoundingMode:nearestRoundingMode
layout:dataLayout
name:nil];
} else { // bilinear backward
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor: newCachedGraph->inputTensor
input: inputSizeTensor
scaleOffsetTensor: scaleOffsetTensor
layout: dataLayout
name: nil];
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor:newCachedGraph->inputTensor
input:inputSizeTensor
scaleOffsetTensor:scaleOffsetTensor
layout:dataLayout
name:nil];
}
} else { // scaleOffsetTensor == nil || align_corners
if (resizeMode == MPSGraphResizeNearest) {
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor: newCachedGraph->inputTensor
input: inputSizeTensor
nearestRoundingMode: nearestRoundingMode
centerResult: centerResults
alignCorners: align_corners
layout: dataLayout
name: nil];
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor:newCachedGraph->inputTensor
input:inputSizeTensor
nearestRoundingMode:nearestRoundingMode
centerResult:centerResults
alignCorners:align_corners
layout:dataLayout
name:nil];
} else { // bilinear backward
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor: newCachedGraph->inputTensor
input: inputSizeTensor
centerResult: centerResults
alignCorners: align_corners
layout: dataLayout
name: nil];
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor:newCachedGraph->inputTensor
input:inputSizeTensor
centerResult:centerResults
alignCorners:align_corners
layout:dataLayout
name:nil];
}
}
}
} else { // if macOS version < 13.0 (for backwards compatibility)
if (!is_backward_pass) {
newCachedGraph->outputTensor = [mpsGraph resizeTensor: newCachedGraph->inputTensor
sizeTensor: newCachedGraph->outputSizeTensor
mode: resizeMode
centerResult: centerResults
alignCorners: align_corners
layout: dataLayout
name: nil];
newCachedGraph->outputTensor = [mpsGraph resizeTensor:newCachedGraph->inputTensor
sizeTensor:newCachedGraph->outputSizeTensor
mode:resizeMode
centerResult:centerResults
alignCorners:align_corners
layout:dataLayout
name:nil];
} else {
newCachedGraph->outputTensor = [mpsGraph resizeWithGradientTensor: newCachedGraph->inputTensor
input: inputSizeTensor
mode: resizeMode
centerResult: centerResults
alignCorners: align_corners
layout: dataLayout
name: nil];
newCachedGraph->outputTensor = [mpsGraph resizeWithGradientTensor:newCachedGraph->inputTensor
input:inputSizeTensor
mode:resizeMode
centerResult:centerResults
alignCorners:align_corners
layout:dataLayout
name:nil];
}
}
}
return newCachedGraph;
});
}
MPSNDArrayDescriptor *sizeDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(2)]];
MPSNDArray *sizeNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: sizeDesc] autorelease];
[sizeNDArray writeBytes: (int32_t[]) {(int32_t)output_height, (int32_t)output_width} strideBytes: nil];
MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: sizeNDArray] autorelease];
MPSNDArrayDescriptor* sizeDesc = [MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeInt32 shape:@[ @(2) ]];
MPSNDArray* sizeNDArray = [[[MPSNDArray alloc] initWithDevice:stream->device() descriptor:sizeDesc] autorelease];
[sizeNDArray writeBytes:(int32_t[]){(int32_t)output_height, (int32_t)output_width} strideBytes:nil];
MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray:sizeNDArray] autorelease];
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, out.has_storage() ? out : output, nil, false);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor, out.has_storage() ? out : output, nil, false);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
cachedGraph->outputSizeTensor : sizeTensorData,
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
if (out.has_storage()) {
@ -223,8 +223,7 @@ void upsample_out_template(const Tensor& input,
} // namespace mps
static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10::optional<double> scale)
{
static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10::optional<double> scale) {
static const bool is_macOS_13_0_or_newer = is_macos_13_or_newer();
if (!is_macOS_13_0_or_newer) {
// passing scale factors to MPS's resize APIs is not supported on macOS < 13
@ -236,7 +235,9 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
// is incompatible with PyTorch that uses floor(). So we fallback to CPU on Monterey.
// The nearest mode should work fine on Ventura.
} else if (resize_mode_str == "nearest" || resize_mode_str == "nearest-exact") {
TORCH_WARN_ONCE("MPS: '", resize_mode_str, "' mode upsampling is supported natively starting from macOS 13.0. ",
TORCH_WARN_ONCE("MPS: '",
resize_mode_str,
"' mode upsampling is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performance implications.");
return false;
}
@ -244,12 +245,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
return true;
}
TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) (
const Tensor& input,
IntArrayRef output_size,
c10::optional<double> scale,
const Tensor& output)
{
TORCH_IMPL_FUNC(upsample_nearest1d_out_mps)
(const Tensor& input, IntArrayRef output_size, c10::optional<double> scale, const Tensor& output) {
if (check_mps_compatibility("nearest", scale)) {
mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest");
} else {
@ -258,27 +255,23 @@ TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) (
}
}
TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_mps) (
const Tensor& grad_output,
TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_mps)
(const Tensor& grad_output,
IntArrayRef output_size,
IntArrayRef input_size,
c10::optional<double> scale,
const Tensor& grad_input)
{
const Tensor& grad_input) {
if (check_mps_compatibility("nearest", scale)) {
mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest");
} else {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<Tensor&>(grad_input) = at::upsample_nearest1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps");
const_cast<Tensor&>(grad_input) =
at::upsample_nearest1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps");
}
}
TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps) (
const Tensor& input,
IntArrayRef output_size,
c10::optional<double> scale,
const Tensor& output)
{
TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps)
(const Tensor& input, IntArrayRef output_size, c10::optional<double> scale, const Tensor& output) {
if (check_mps_compatibility("nearest-exact", scale)) {
mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest-exact");
} else {
@ -287,113 +280,123 @@ TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps) (
}
}
TORCH_IMPL_FUNC(_upsample_nearest_exact1d_backward_out_mps) (
const Tensor& grad_output,
TORCH_IMPL_FUNC(_upsample_nearest_exact1d_backward_out_mps)
(const Tensor& grad_output,
IntArrayRef output_size,
IntArrayRef input_size,
c10::optional<double> scale,
const Tensor& grad_input)
{
const Tensor& grad_input) {
if (check_mps_compatibility("nearest-exact", scale)) {
mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest-exact");
mps::upsample_out_template(
grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest-exact");
} else {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<Tensor&>(grad_input) = at::_upsample_nearest_exact1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps");
const_cast<Tensor&>(grad_input) =
at::_upsample_nearest_exact1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps");
}
}
TORCH_IMPL_FUNC(upsample_nearest2d_out_mps) (
const Tensor& input,
TORCH_IMPL_FUNC(upsample_nearest2d_out_mps)
(const Tensor& input,
IntArrayRef output_size,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
const Tensor& output)
{
const Tensor& output) {
if (check_mps_compatibility("nearest", scales_w)) {
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest");
} else {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<Tensor&>(output) = at::upsample_nearest2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps");
const_cast<Tensor&>(output) =
at::upsample_nearest2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps");
}
}
TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_mps) (
const Tensor& grad_output,
TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_mps)
(const Tensor& grad_output,
IntArrayRef output_size,
IntArrayRef input_size,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
const Tensor& grad_input)
{
const Tensor& grad_input) {
if (check_mps_compatibility("nearest", scales_w)) {
mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest");
} else {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<Tensor&>(grad_input) = at::upsample_nearest2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w).clone().to("mps");
const_cast<Tensor&>(grad_input) =
at::upsample_nearest2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w)
.clone()
.to("mps");
}
}
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps) (
const Tensor& input,
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps)
(const Tensor& input,
IntArrayRef output_size,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
const Tensor& output)
{
const Tensor& output) {
if (check_mps_compatibility("nearest-exact", scales_w)) {
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest-exact");
} else {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<Tensor&>(output) = at::_upsample_nearest_exact2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps");
const_cast<Tensor&>(output) =
at::_upsample_nearest_exact2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps");
}
}
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_mps) (
const Tensor& grad_output,
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_mps)
(const Tensor& grad_output,
IntArrayRef output_size,
IntArrayRef input_size,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
const Tensor& grad_input)
{
const Tensor& grad_input) {
if (check_mps_compatibility("nearest-exact", scales_w)) {
mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest-exact");
mps::upsample_out_template(
grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest-exact");
} else {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<Tensor&>(grad_input) = at::_upsample_nearest_exact2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w).clone().to("mps");
const_cast<Tensor&>(grad_input) =
at::_upsample_nearest_exact2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w)
.clone()
.to("mps");
}
}
TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps) (
const Tensor& input,
TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps)
(const Tensor& input,
IntArrayRef output_size,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
const Tensor& output)
{
const Tensor& output) {
if (check_mps_compatibility("bilinear", scales_w)) {
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, align_corners, "bilinear");
} else {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<Tensor&>(output) = at::upsample_bilinear2d(input.to("cpu"), output_size, align_corners, scales_h, scales_w).clone().to("mps");
const_cast<Tensor&>(output) =
at::upsample_bilinear2d(input.to("cpu"), output_size, align_corners, scales_h, scales_w).clone().to("mps");
}
}
TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps) (
const Tensor& grad_output,
TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps)
(const Tensor& grad_output,
IntArrayRef output_size,
IntArrayRef input_size,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
const Tensor& grad_input)
{
const Tensor& grad_input) {
if (check_mps_compatibility("bilinear", scales_w)) {
mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, align_corners, "bilinear");
mps::upsample_out_template(
grad_output, output_size, input_size, scales_h, scales_w, grad_input, align_corners, "bilinear");
} else {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<Tensor&>(grad_input) = at::upsample_bilinear2d_backward(grad_output.to("cpu"), output_size, input_size, align_corners, scales_h, scales_w).clone().to("mps");
const_cast<Tensor&>(grad_input) =
at::upsample_bilinear2d_backward(
grad_output.to("cpu"), output_size, input_size, align_corners, scales_h, scales_w)
.clone()
.to("mps");
}
}
@ -1,18 +1,17 @@
// Copyright © 2022 Apple Inc.
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/mps/IndexKernels.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mps/OperationUtils.h>
#include <fmt/format.h>
#include <torch/library.h>
#include <ATen/mps/IndexKernels.h>
namespace at::native {
namespace mps {
struct ViewCachedGraph : public MPSCachedGraph
{
ViewCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct ViewCachedGraph : public MPSCachedGraph {
ViewCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor = nil;
MPSGraphTensor* outputTensor = nil;
MPSGraphTensor* updatesTensor = nil;
@ -20,18 +19,20 @@ struct ViewCachedGraph : public MPSCachedGraph
std::vector<MPSGraphTensor*> strideTensors;
};
static std::string getStridedKey(const ScalarType& self_dtype, const ScalarType& updates_dtype, const IntArrayRef& base_shape,
const IntArrayRef& new_shape, const IntArrayRef& stride,
int64_t storage_offset, bool is_scatter)
{
static std::string getStridedKey(const ScalarType& self_dtype,
const ScalarType& updates_dtype,
const IntArrayRef& base_shape,
const IntArrayRef& new_shape,
const IntArrayRef& stride,
int64_t storage_offset,
bool is_scatter) {
std::string dtype_key = getMPSTypeString(self_dtype);
if (is_scatter) {
dtype_key += ":" + getMPSTypeString(updates_dtype);
}
return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" +
getArrayRefString(base_shape) + "]:[" + getArrayRefString(new_shape) + "]:[" +
getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]";
return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" + getArrayRefString(base_shape) + "]:[" +
getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]";
}
// initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op
@ -44,25 +45,26 @@ static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src,
const int64_t storage_offset = needsScatter ? output.storage_offset() : src.storage_offset();
const MPSDataType inputType = [cachedGraph->inputTensor dataType];
MPSShape *inputShape = [cachedGraph->inputTensor shape];
MPSShape *outputShape = needsScatter ? inputShape : getMPSShape(src);
MPSShape* inputShape = [cachedGraph->inputTensor shape];
MPSShape* outputShape = needsScatter ? inputShape : getMPSShape(src);
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
// in case of scatter, we use output tensor as input buffer and write the results back to the source buffer
feeds[cachedGraph->inputTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: needsScatter ? outputBuffer : sourceBuffer
shape: inputShape
dataType: inputType] autorelease];
feeds[cachedGraph->inputTensor] =
[[[MPSGraphTensorData alloc] initWithMTLBuffer:needsScatter ? outputBuffer : sourceBuffer
shape:inputShape
dataType:inputType] autorelease];
if (needsScatter) {
auto updatesType = getMPSScalarType(src.scalar_type());
if (updatesType == MPSDataTypeUInt8 || (updatesType == MPSDataTypeBool && !is_macos_13_or_newer())) {
updatesType = MPSDataTypeInt8;
}
feeds[cachedGraph->updatesTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer
shape: getMPSShape(src.numel())
dataType: updatesType] autorelease];
feeds[cachedGraph->updatesTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer:sourceBuffer
shape:getMPSShape(src.numel())
dataType:updatesType] autorelease];
}
MPSScalar storageOffsetScalar = getMPSScalar(storage_offset, ScalarType::Int);
feeds[cachedGraph->storageOffsetTensor] = getMPSGraphTensorFromScalar(stream, storageOffsetScalar);
@ -78,26 +80,24 @@ static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src,
if (outputType == MPSDataTypeUInt8 || (outputType == MPSDataTypeBool && !is_macos_13_or_newer())) {
outputType = MPSDataTypeInt8;
}
MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: outputBuffer
shape: outputShape
dataType: outputType] autorelease];
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
cachedGraph->outputTensor : outputTensorData
};
MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:outputBuffer
shape:outputShape
dataType:outputType] autorelease];
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{cachedGraph->outputTensor : outputTensorData};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
return output;
}
MPSGraphTensor *permuteTensor(MPSGraph *graph, MPSGraphTensor *inputTensor, NSArray *permuteOrder) {
MPSGraphTensor* permuteTensor(MPSGraph* graph, MPSGraphTensor* inputTensor, NSArray* permuteOrder) {
NSUInteger srcRank = [[inputTensor shape] count];
if (srcRank != [permuteOrder count]) {
return nil;
}
MPSGraphTensor *outputTensor = inputTensor;
MPSGraphTensor* outputTensor = inputTensor;
std::vector<NSUInteger> dimensionOrder(srcRank);
std::iota (std::begin(dimensionOrder), std::end(dimensionOrder), 0);
std::iota(std::begin(dimensionOrder), std::end(dimensionOrder), 0);
for (const auto i : c10::irange(srcRank)) {
NSUInteger axis = [permuteOrder[i] integerValue];
@ -106,28 +106,24 @@ MPSGraphTensor *permuteTensor(MPSGraph *graph, MPSGraphTensor *inputTensor, NSAr
NSUInteger axis2 = axisIter - dimensionOrder.begin();
iter_swap(dimensionOrder.begin() + i, axisIter);
outputTensor = [graph transposeTensor:outputTensor
dimension:axis1
withDimension:axis2
name:nil];
outputTensor = [graph transposeTensor:outputTensor dimension:axis1 withDimension:axis2 name:nil];
}
return outputTensor;
}
NSDictionary *getStrideToDimLengthOffsetDict(MPSGraphTensor *tensor, NSUInteger rank, NSUInteger offset) {
NSDictionary* getStrideToDimLengthOffsetDict(MPSGraphTensor* tensor, NSUInteger rank, NSUInteger offset) {
// Assuming input tensor has default strides
NSInteger stride = 1;
NSMutableDictionary *strideToDimLengthOffset = [[NSMutableDictionary alloc] init];
NSMutableDictionary* strideToDimLengthOffset = [[NSMutableDictionary alloc] init];
for (NSInteger srcDim = rank - 1; srcDim >= 0; srcDim--) {
NSUInteger size = [[tensor shape][srcDim] integerValue];
NSDictionary *entry =
@{
@"dim": [NSNumber numberWithInteger:srcDim],
@"length": [tensor shape][srcDim],
@"offset": [NSNumber numberWithInteger:offset % size] // offset is determined traversing backwards through stride
NSDictionary* entry = @{
@"dim" : [NSNumber numberWithInteger:srcDim],
@"length" : [tensor shape][srcDim],
@"offset" : [NSNumber numberWithInteger:offset % size] // offset is determined traversing backwards through stride
};
[strideToDimLengthOffset setValue:entry forKey:[NSString stringWithFormat:@"%ld",stride]];
[strideToDimLengthOffset setValue:entry forKey:[NSString stringWithFormat:@"%ld", stride]];
offset /= size;
stride *= size;
}
@ -135,14 +131,18 @@ NSDictionary *getStrideToDimLengthOffsetDict(MPSGraphTensor *tensor, NSUInteger
}
// Detect only expand dims, allows for duplicate strides
MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) {
MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph* graph,
MPSGraphTensor* inputTensor,
int dstRank,
const IntArrayRef& dstSizes,
const IntArrayRef& dstStrides,
int offset) {
NSUInteger srcRank = [[inputTensor shape] count];
// Not an expand dims
if (srcRank >= dstRank)
return nil;
NSMutableArray *expandAxes = [[NSMutableArray alloc] init];
NSMutableArray* expandAxes = [[NSMutableArray alloc] init];
BOOL isValidExpand = YES;
NSInteger currSrcDim = (NSInteger)srcRank - 1;
@ -173,11 +173,9 @@ MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph *graph, MPSGraphTensor
return nil;
}
MPSGraphTensor *expandTensor = inputTensor;
MPSGraphTensor* expandTensor = inputTensor;
if ([expandAxes count]) {
expandTensor = [graph expandDimsOfTensor:expandTensor
axes:expandAxes
name:nil];
expandTensor = [graph expandDimsOfTensor:expandTensor axes:expandAxes name:nil];
}
[expandAxes release];
@ -185,13 +183,18 @@ MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph *graph, MPSGraphTensor
}
// Detect contiguous reshapes, no slicing
MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) {
MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph* graph,
MPSGraphTensor* inputTensor,
int dstRank,
const IntArrayRef& dstSizes,
const IntArrayRef& dstStrides,
int offset) {
NSUInteger srcRank = [[inputTensor shape] count];
// Not a reshape
if (srcRank <= dstRank)
return nil;
NSMutableArray *dstShape = [[NSMutableArray alloc] init];
NSMutableArray* dstShape = [[NSMutableArray alloc] init];
BOOL isValidReshape = YES;
NSInteger srcDim = srcRank - 1;
@ -199,7 +202,7 @@ MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph *graph, MPSGraphTensor *i
for (NSInteger dstDim = dstRank - 1; dstDim >= 0 && isValidReshape; dstDim--) {
NSUInteger currDimLength = dstSizes[dstDim];
NSUInteger currStride = dstStrides[dstDim];
[dstShape insertObject:[NSNumber numberWithInteger:currDimLength] atIndex: 0];
[dstShape insertObject:[NSNumber numberWithInteger:currDimLength] atIndex:0];
NSUInteger targetDimLength = currDimLength;
NSUInteger currReshapeSize = 1;
@ -216,26 +219,28 @@ MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph *graph, MPSGraphTensor *i
}
isValidReshape &= (srcDim < 0);
MPSGraphTensor *outputTensor = nil;
MPSGraphTensor* outputTensor = nil;
if (isValidReshape)
outputTensor = [graph reshapeTensor: inputTensor
withShape: dstShape
name: nil];
outputTensor = [graph reshapeTensor:inputTensor withShape:dstShape name:nil];
[dstShape release];
return outputTensor;
}
MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) {
MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
MPSGraphTensor* inputTensor,
int dstRank,
const IntArrayRef& dstSizes,
const IntArrayRef& dstStrides,
int offset) {
// Duplicate strides cannot be done
{
BOOL allUnique = YES;
NSMutableSet *uniqueStrides = [[NSMutableSet alloc] init];
NSMutableSet* uniqueStrides = [[NSMutableSet alloc] init];
for (NSInteger dstDim = 0; (dstDim < dstRank) && allUnique; dstDim++) {
int stride = dstStrides[dstDim];
NSNumber *strideObj = [NSNumber numberWithInt:stride];
NSNumber* strideObj = [NSNumber numberWithInt:stride];
allUnique &= (stride == 0 || ![uniqueStrides containsObject:strideObj]);
[uniqueStrides addObject: strideObj];
[uniqueStrides addObject:strideObj];
}
[uniqueStrides release];
if (!allUnique)
@ -243,31 +248,31 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
// Skip for zero in dst shape
for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++)
if (dstSizes[dstDim] == 0) { return nil; }
if (dstSizes[dstDim] == 0) {
return nil;
}
}
// 1. Flatten the inputTensor if necessary
MPSGraphTensor *flatInputTensor = inputTensor;
MPSGraphTensor* flatInputTensor = inputTensor;
{
// Flatten inputs to remove duplicate strides.
NSMutableArray *squeezeAxes = [[NSMutableArray alloc] init];
for(NSUInteger srcDim = 1; srcDim < [[flatInputTensor shape] count]; srcDim++) {
NSMutableArray* squeezeAxes = [[NSMutableArray alloc] init];
for (NSUInteger srcDim = 1; srcDim < [[flatInputTensor shape] count]; srcDim++) {
if ([[flatInputTensor shape][srcDim] intValue] == 1)
[squeezeAxes addObject:[NSNumber numberWithInteger:srcDim]];
}
// We have to leave at least 1 dimension, if all input dims are 1
if ([squeezeAxes count])
flatInputTensor = [graph squeezeTensor:flatInputTensor
axes:squeezeAxes
name:nil];
flatInputTensor = [graph squeezeTensor:flatInputTensor axes:squeezeAxes name:nil];
[squeezeAxes release];
}
int srcRank = (int)[[flatInputTensor shape] count];
NSDictionary *srcStrideToDimLengthOffset = getStrideToDimLengthOffsetDict(flatInputTensor, srcRank, offset);
NSDictionary* srcStrideToDimLengthOffset = getStrideToDimLengthOffsetDict(flatInputTensor, srcRank, offset);
// Populate the dimension order, slice info, and broadcast info
NSMutableArray *dstDimOrder = [[NSMutableArray alloc] init];
NSMutableArray* dstDimOrder = [[NSMutableArray alloc] init];
std::vector<int32_t> dstDimToSliceLength(dstRank);
std::vector<int32_t> dstDimToSliceOffset(dstRank);
bool needsBroadcast = false;
@ -280,31 +285,33 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
dstDimToSliceOffset[dstDim] = 0;
} else {
// Find what dimension and native length was for the specified stride
NSDictionary *srcDimLengthOffset = srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%lld",dstStrides[dstDim]]];
NSDictionary* srcDimLengthOffset =
srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%lld", dstStrides[dstDim]]];
dstDimToSliceLength[dstDim] = dstSizes[dstDim];
dstDimToSliceOffset[dstDim] = [srcDimLengthOffset[@"offset"] intValue];
// Stride does not exist in source tensor, or the specified size is too long. Not possible
// TODO: Longer length with same stride + removal of dim(s) above this is a flatten/reshape. Consider adding support
// TODO: Longer length with same stride + removal of dim(s) above this is a flatten/reshape. Consider adding
// support
if (!srcDimLengthOffset ||
// the offset + length of destination should not be larger than source's length when slicing
dstDimToSliceOffset[dstDim] + dstDimToSliceLength[dstDim] > [srcDimLengthOffset[@"length"] intValue]) {
return nil;
}
// Get the src dimension corresponding to the requested stride
NSNumber *srcDim = srcDimLengthOffset[@"dim"];
NSNumber* srcDim = srcDimLengthOffset[@"dim"];
[dstDimOrder insertObject:srcDim atIndex:0];
}
}
}
// 2. Slice out any unused dimensions
NSMutableArray *missingSrcDims = [[NSMutableArray alloc] init];
MPSGraphTensor *slicedUnusedTensor = flatInputTensor;
NSMutableArray* missingSrcDims = [[NSMutableArray alloc] init];
MPSGraphTensor* slicedUnusedTensor = flatInputTensor;
{
// Find any src strides/dims that are not present in the dst
NSMutableArray *missingSrcStrides = [[NSMutableArray alloc] init];
NSMutableArray* missingSrcStrides = [[NSMutableArray alloc] init];
{
NSUInteger stride = 1;
for (NSInteger srcDim = [[flatInputTensor shape] count] - 1; srcDim >= 0; srcDim--) {
@ -317,8 +324,8 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
}
for (NSUInteger i = 0; i < [missingSrcStrides count]; i++) {
NSUInteger stride = [missingSrcStrides[i] integerValue];
NSDictionary *srcDimLengthOffset = srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%ld",stride]];
NSNumber *missingSrcDim = srcDimLengthOffset[@"dim"];
NSDictionary* srcDimLengthOffset = srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%ld", stride]];
NSNumber* missingSrcDim = srcDimLengthOffset[@"dim"];
[missingSrcDims addObject:missingSrcDim];
[dstDimOrder insertObject:missingSrcDim atIndex:0];
@ -332,35 +339,33 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
}
// 3. Transpose if necessary
MPSGraphTensor *transposedTensor = slicedUnusedTensor;
MPSGraphTensor* transposedTensor = slicedUnusedTensor;
{
// TODO: Use Transpose API
BOOL needsTranspose = NO;
for(NSUInteger dstDim = 0; dstDim < [dstDimOrder count] && !needsTranspose; dstDim++ )
for (NSUInteger dstDim = 0; dstDim < [dstDimOrder count] && !needsTranspose; dstDim++)
needsTranspose |= ([dstDimOrder[dstDim] intValue] != dstDim);
if (needsTranspose)
transposedTensor = permuteTensor(graph, transposedTensor, dstDimOrder);
}
// 4. Squeeze any unused dimensions following transpose
MPSGraphTensor *squeezedTensor = transposedTensor;
MPSGraphTensor* squeezedTensor = transposedTensor;
{
// Transpose the missing dims back
NSMutableArray *transposedMissingSrcDims = [[NSMutableArray alloc] init];
NSMutableArray* transposedMissingSrcDims = [[NSMutableArray alloc] init];
for (NSUInteger dstDim = 0; dstDim < [dstDimOrder count]; dstDim++) {
NSNumber *srcDim = dstDimOrder[dstDim];
NSNumber* srcDim = dstDimOrder[dstDim];
if ([missingSrcDims containsObject:srcDim])
[transposedMissingSrcDims addObject:[NSNumber numberWithInt:dstDim]];
}
if ([transposedMissingSrcDims count])
squeezedTensor = [graph squeezeTensor:squeezedTensor
axes:transposedMissingSrcDims
name:nil];
squeezedTensor = [graph squeezeTensor:squeezedTensor axes:transposedMissingSrcDims name:nil];
[transposedMissingSrcDims release];
}
// 5. Slice
MPSGraphTensor *slicedTensor = squeezedTensor;
MPSGraphTensor* slicedTensor = squeezedTensor;
{
NSUInteger currDstDim = 0;
for (NSUInteger dstDim = 0; dstDim < dstRank; dstDim++) {
@ -369,34 +374,26 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
int start = dstDimToSliceOffset[dstDim];
int length = dstDimToSliceLength[dstDim];
if (length != [[slicedTensor shape][currDstDim] intValue])
slicedTensor = [graph sliceTensor:slicedTensor
dimension:currDstDim
start:start
length:length
name:nil];
slicedTensor = [graph sliceTensor:slicedTensor dimension:currDstDim start:start length:length name:nil];
currDstDim++;
}
}
}
// 6. Expand then broadcast the source tensor
MPSGraphTensor *broadcastTensor = slicedTensor;
MPSGraphTensor* broadcastTensor = slicedTensor;
if (needsBroadcast) {
NSMutableArray *broadcastShape = [[NSMutableArray alloc] init];
NSMutableArray *expandAxes = [[NSMutableArray alloc] init];
for(NSInteger dstDim = 0; dstDim < dstRank; dstDim++) {
NSMutableArray* broadcastShape = [[NSMutableArray alloc] init];
NSMutableArray* expandAxes = [[NSMutableArray alloc] init];
for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++) {
[broadcastShape addObject:[NSNumber numberWithInt:dstSizes[dstDim]]];
if (dstStrides[dstDim] == 0)
[expandAxes addObject:[NSNumber numberWithInt:dstDim]];
}
if ([expandAxes count]) {
MPSGraphTensor *expandTensor = [graph expandDimsOfTensor:broadcastTensor
axes:expandAxes
name:nil];
broadcastTensor = [graph broadcastTensor:expandTensor
toShape:broadcastShape
name:nil];
MPSGraphTensor* expandTensor = [graph expandDimsOfTensor:broadcastTensor axes:expandAxes name:nil];
broadcastTensor = [graph broadcastTensor:expandTensor toShape:broadcastShape name:nil];
}
[broadcastShape release];
[expandAxes release];
@ -409,11 +406,16 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
return broadcastTensor;
}
MPSGraphTensor* asStridedLayer_pattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) {
MPSGraphTensor* asStridedLayer_pattern(MPSGraph* graph,
MPSGraphTensor* inputTensor,
int dstRank,
const IntArrayRef& dstSizes,
const IntArrayRef& dstStrides,
int offset) {
if (!dstRank)
return nil;
MPSGraphTensor *outputTensor = nil;
MPSGraphTensor* outputTensor = nil;
outputTensor = asStridedLayer_expandDimsPattern(graph, inputTensor, dstRank, dstSizes, dstStrides, offset);
if (!outputTensor)
outputTensor = asStridedLayer_reshapePattern(graph, inputTensor, dstRank, dstSizes, dstStrides, offset);
@ -423,8 +425,7 @@ MPSGraphTensor* asStridedLayer_pattern(MPSGraph *graph, MPSGraphTensor *inputTen
return outputTensor;
}
static
std::vector<int64_t> getViewShape(const Tensor& src, MPSShape *mpsShape, const bool squeeze) {
static std::vector<int64_t> getViewShape(const Tensor& src, MPSShape* mpsShape, const bool squeeze) {
bool hasMPSShape = (mpsShape != nil);
std::vector<int64_t> src_view_shape;
if (hasMPSShape) {
@ -459,7 +460,6 @@ std::vector<int64_t> getViewShape(const Tensor& src, MPSShape *mpsShape, const b
return src_view_shape;
}
std::vector<int64_t> getSqueezedBaseShape(const Tensor& src, IntArrayRef shape) {
std::vector<int64_t> src_base_shape;
for (const auto i : c10::irange(shape.size())) {
@ -471,8 +471,7 @@ std::vector<int64_t> getSqueezedBaseShape(const Tensor& src, IntArrayRef shape)
return src_base_shape;
}
bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
bool canSliceViewTensor(const Tensor& src, MPSShape* mpsShape) {
if (!src.is_contiguous()) {
return false;
}
@ -486,7 +485,7 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
return false;
}
for (const auto i: c10::irange(src_ndim_base)) {
for (const auto i : c10::irange(src_ndim_base)) {
if (src_view_shape[i] > src_base_shape[i]) {
return false;
}
@ -494,15 +493,15 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
return true;
}
MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType) {
MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape* mpsShape, const MPSDataType mpsDataType) {
IntArrayRef src_base_shape = getIMPSAllocator()->getBufferShape(src.storage().data());
size_t src_ndim_base = src_base_shape.size();
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape, false);
size_t src_ndim_view = src_view_shape.size();
MPSNDArray *srcTensorNDArrayView = nil;
MPSNDArrayDescriptor *srcTensorNDArrayDesc = nil;
MPSNDArray *srcTensorNDArray = nil;
MPSNDArray* srcTensorNDArrayView = nil;
MPSNDArrayDescriptor* srcTensorNDArrayDesc = nil;
MPSNDArray* srcTensorNDArray = nil;
id<MTLCommandBuffer> commandBuffer = getCurrentMPSStream()->commandBuffer();
int64_t base_idx = 0;
@ -537,18 +536,20 @@ MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mp
}
int64_t sliceOffset = src.storage_offset() / view_numel;
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - firstDimToSlice
[srcTensorNDArrayDesc
sliceDimension:src_ndim_base - 1 - firstDimToSlice
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice])}];
// Slice any remaining dimensions
for (const auto crtSliceOffset: c10::irange(firstDimToSlice + 1, src_base_shape.size())) {
for (const auto crtSliceOffset : c10::irange(firstDimToSlice + 1, src_base_shape.size())) {
if (src_view_shape[crtSliceOffset] != src_base_shape[crtSliceOffset]) {
if (crtSliceOffset == src_base_shape.size() - 1) {
sliceOffset = src.storage_offset() % src_base_shape[src_base_shape.size() - 1];
} else {
sliceOffset = (src.storage_offset() % view_numel) / (view_numel / src_base_shape[crtSliceOffset]);
}
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - crtSliceOffset
[srcTensorNDArrayDesc
sliceDimension:src_ndim_base - 1 - crtSliceOffset
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[crtSliceOffset])}];
}
}
@ -559,13 +560,15 @@ MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mp
return [[[MPSGraphTensorData alloc] initWithMPSNDArray:srcTensorNDArrayView] autorelease];
}
static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const IntArrayRef& size,
const IntArrayRef& stride, int64_t offset,
const IntArrayRef& base_shape, bool needsScatter,
MPSGraphTensor* updatesTensor)
{
static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph,
const IntArrayRef& size,
const IntArrayRef& stride,
int64_t offset,
const IntArrayRef& base_shape,
bool needsScatter,
MPSGraphTensor* updatesTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor *outputTensor = nil;
MPSGraphTensor* outputTensor = nil;
const size_t shape_size = size.size();
@autoreleasepool {
@ -575,87 +578,74 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const In
TORCH_CHECK(size[i] <= int_max);
sizeArray[i] = static_cast<int32_t>(size[i]);
}
NSData* shapeData = [NSData dataWithBytes: sizeArray.data()
length: shape_size * sizeof(int32_t)];
MPSGraphTensor* shapeTensor = [mpsGraph constantWithData: shapeData
shape: @[[NSNumber numberWithUnsignedInteger: shape_size]]
dataType: MPSDataTypeInt32];
NSData* shapeData = [NSData dataWithBytes:sizeArray.data() length:shape_size * sizeof(int32_t)];
MPSGraphTensor* shapeTensor = [mpsGraph constantWithData:shapeData
shape:@[ [NSNumber numberWithUnsignedInteger:shape_size] ]
dataType:MPSDataTypeInt32];
MPSGraphTensor* indicesTensor = nil;
// create stride Tensors for each rank of the input tensor
for (int i = 0; i < shape_size; i++) {
MPSGraphTensor* rangeTensor = [mpsGraph coordinateAlongAxis: (-i - 1)
withShapeTensor: shapeTensor
name: nil];
MPSGraphTensor* rangeTensor = [mpsGraph coordinateAlongAxis:(-i - 1) withShapeTensor:shapeTensor name:nil];
MPSGraphTensor* strideTensor = cachedGraph->strideTensors[shape_size - i - 1];
MPSGraphTensor* indexTensor = [mpsGraph multiplicationWithPrimaryTensor: rangeTensor
secondaryTensor: strideTensor
name: nil];
MPSGraphTensor* indexTensor = [mpsGraph multiplicationWithPrimaryTensor:rangeTensor
secondaryTensor:strideTensor
name:nil];
if (!indicesTensor) {
indicesTensor = indexTensor;
} else {
indicesTensor = [mpsGraph additionWithPrimaryTensor: indexTensor
secondaryTensor: indicesTensor
name: nil];
indicesTensor = [mpsGraph additionWithPrimaryTensor:indexTensor secondaryTensor:indicesTensor name:nil];
}
}
indicesTensor = [mpsGraph additionWithPrimaryTensor: indicesTensor
secondaryTensor: cachedGraph->storageOffsetTensor
name: nil];
MPSGraphTensor *inputTensor = cachedGraph->inputTensor;
indicesTensor = [mpsGraph additionWithPrimaryTensor:indicesTensor
secondaryTensor:cachedGraph->storageOffsetTensor
name:nil];
MPSGraphTensor* inputTensor = cachedGraph->inputTensor;
if (!needsScatter) {
MPSGraphTensor *outputTensor = asStridedLayer_pattern(mpsGraph, inputTensor, shape_size, size, stride, offset);
MPSGraphTensor* outputTensor = asStridedLayer_pattern(mpsGraph, inputTensor, shape_size, size, stride, offset);
if (outputTensor) {
return outputTensor;
}
}
MPSGraphTensor *reshapedInputTensor = [mpsGraph reshapeTensor: inputTensor
withShape: @[@-1]
name: nil];
MPSGraphTensor *reshapedIndicesTensor = [mpsGraph reshapeTensor: indicesTensor
withShape: @[@-1]
name: nil];
MPSGraphTensor* reshapedInputTensor = [mpsGraph reshapeTensor:inputTensor withShape:@[ @-1 ] name:nil];
MPSGraphTensor* reshapedIndicesTensor = [mpsGraph reshapeTensor:indicesTensor withShape:@[ @-1 ] name:nil];
if (needsScatter) {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wobjc-method-access"
MPSGraphTensor* scatteredTensor = [mpsGraph scatterAlongAxis: (NSInteger) 0
withDataTensor: reshapedInputTensor
updatesTensor: updatesTensor
indicesTensor: reshapedIndicesTensor
mode: MPSGraphScatterModeSet
name: nil];
MPSGraphTensor* scatteredTensor = [mpsGraph scatterAlongAxis:(NSInteger)0
withDataTensor:reshapedInputTensor
updatesTensor:updatesTensor
indicesTensor:reshapedIndicesTensor
mode:MPSGraphScatterModeSet
name:nil];
#pragma clang diagnostic pop
outputTensor = [mpsGraph reshapeTensor: scatteredTensor
withShape: getMPSShape(base_shape)
name: nil];
outputTensor = [mpsGraph reshapeTensor:scatteredTensor withShape:getMPSShape(base_shape) name:nil];
} else {
// Call gather to coalesce the needed values. Result will be of same shape as flattened indices tensor
MPSGraphTensor *gatheredTensor = [mpsGraph gatherWithUpdatesTensor: reshapedInputTensor
indicesTensor: reshapedIndicesTensor
axis: 0
batchDimensions: 0
name: nil];
MPSGraphTensor* gatheredTensor = [mpsGraph gatherWithUpdatesTensor:reshapedInputTensor
indicesTensor:reshapedIndicesTensor
axis:0
batchDimensions:0
name:nil];
// Reshape the data to desired size
outputTensor = [mpsGraph reshapeTensor: gatheredTensor
withShapeTensor: shapeTensor
name: nil];
outputTensor = [mpsGraph reshapeTensor:gatheredTensor withShapeTensor:shapeTensor name:nil];
}
}
return outputTensor;
}
static IntArrayRef updateTensorBaseShape(const Tensor& self)
{
static IntArrayRef updateTensorBaseShape(const Tensor& self) {
IntArrayRef base_shape = getIMPSAllocator()->getBufferShape(self.storage().data());
// if there's no base_shape stored in MPSAllocator, then infer it from tensor's size and store it
if (base_shape.size() == 0) {
// IntArrayRef wouldn't own the data, so we use a static storage
static const int64_t shape_1d = 1;
// self.sizes().size() could be zero
base_shape = self.sizes().size() ? self.sizes() :
((self.is_view() && self._base().sizes().size()) ? self._base().sizes() : IntArrayRef(&shape_1d, 1));
base_shape = self.sizes().size()
? self.sizes()
: ((self.is_view() && self._base().sizes().size()) ? self._base().sizes() : IntArrayRef(&shape_1d, 1));
// base_shape will be retained in MPSAllocator until buffer gets recycled
if (self.storage().data())
@ -681,18 +671,23 @@ static IntArrayRef updateTensorBaseShape(const Tensor& self)
// | / \ |
// | / \ |
// NonView T NonView T
static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &updates, IntArrayRef size, IntArrayRef stride, int64_t storage_offset, bool needsScatter)
{
static ViewCachedGraph* createViewGraph(const Tensor& self,
const Tensor& updates,
IntArrayRef size,
IntArrayRef stride,
int64_t storage_offset,
bool needsScatter) {
IntArrayRef base_shape = updateTensorBaseShape(self);
@autoreleasepool {
string key = getStridedKey(self.scalar_type(), updates.scalar_type(), base_shape, size, stride, storage_offset, needsScatter);
string key = getStridedKey(
self.scalar_type(), updates.scalar_type(), base_shape, size, stride, storage_offset, needsScatter);
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
ViewCachedGraph* cachedGraph = static_cast<ViewCachedGraph *>(cache_->LookUp(key));
ViewCachedGraph* cachedGraph = static_cast<ViewCachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
cachedGraph = static_cast<ViewCachedGraph *>(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
ViewCachedGraph *newCachedGraph = nil;
cachedGraph = static_cast<ViewCachedGraph*>(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
ViewCachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
MPSGraphTensor* updatesTensor = nil;
@ -706,9 +701,9 @@ static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &update
// Self is the input tensor we are creating view of
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(base_shape));
newCachedGraph->storageOffsetTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@1]);
newCachedGraph->storageOffsetTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @1 ]);
for (int i = 0; i < size.size(); i++) {
newCachedGraph->strideTensors.push_back(mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@1]));
newCachedGraph->strideTensors.push_back(mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @1 ]));
}
if (needsScatter) {
auto updatesType = getMPSScalarType(updates.scalar_type());
@ -718,12 +713,11 @@ static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &update
newCachedGraph->updatesTensor = mpsGraphRankedPlaceHolder(mpsGraph, updatesType, getMPSShape(self.numel()));
updatesTensor = newCachedGraph->updatesTensor;
if (inputType != updatesType) {
updatesTensor = [mpsGraph castTensor:updatesTensor
toType:inputType
name:@"castUpdatesTensor"];
updatesTensor = [mpsGraph castTensor:updatesTensor toType:inputType name:@"castUpdatesTensor"];
}
}
newCachedGraph->outputTensor = chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter, updatesTensor);
newCachedGraph->outputTensor =
chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter, updatesTensor);
}
return newCachedGraph;
}));
@ -732,11 +726,7 @@ static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &update
}
}
static
std::string getGatherScatterFunctionName(
ScalarType scalarType,
int64_t dim,
bool needsScatter) {
static std::string getGatherScatterFunctionName(ScalarType scalarType, int64_t dim, bool needsScatter) {
std::string kernelName = needsScatter ? "scatter" : "gather";
return kernelName + "_kernel_" + std::to_string(dim == 0 ? 1 : dim);
}
@ -759,8 +749,7 @@ const std::string& getGatherScatterScalarType(const Tensor& t) {
return it->second;
}
static
id<MTLLibrary> compileGatherScatterOpsLibrary(id<MTLDevice> device,
static id<MTLLibrary> compileGatherScatterOpsLibrary(id<MTLDevice> device,
const std::string& dtypeSrc,
const std::string& dtypeDst,
bool needsScatter) {
@ -770,13 +759,20 @@ id<MTLLibrary> compileGatherScatterOpsLibrary(id<MTLDevice> device,
if (it != _libCache.end()) {
return it->second;
}
NSError *error = nil;
MTLCompileOptions *options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion: MTLLanguageVersion2_3];
auto gatherScatterLib = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(needsScatter ? SCATTER_OPS_TEMPLATE : GATHER_OPS_TEMPLATE, dtypeSrc, dtypeDst).c_str()]
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
auto gatherScatterLib =
[device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(needsScatter ? SCATTER_OPS_TEMPLATE
: GATHER_OPS_TEMPLATE,
dtypeSrc,
dtypeDst)
.c_str()]
options:options
error:&error];
TORCH_CHECK(gatherScatterLib != nil && error == nil, "Failed to compile gather-scatter library, error: ", [[error description] UTF8String]);
TORCH_CHECK(gatherScatterLib != nil && error == nil,
"Failed to compile gather-scatter library, error: ",
[[error description] UTF8String]);
_libCache[key] = gatherScatterLib;
return gatherScatterLib;
}
@ -793,12 +789,13 @@ static id<MTLComputePipelineState> getPipelineState(id<MTLDevice> device,
return it->second;
}
NSError *error = nil;
NSError* error = nil;
id<MTLLibrary> library = compileGatherScatterOpsLibrary(device, dtypeSrc, dtypeDst, needsScatter);
id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
TORCH_CHECK(func, "Failed to load the Metal Shader function: ", kernel);
id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:func error:&error];
TORCH_CHECK(pso != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
TORCH_CHECK(
pso != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
_mtlPipelineCache[key] = pso;
return pso;
}
@ -814,8 +811,8 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
}
if (src.dim() > 5) {
ViewCachedGraph* cachedGraph = createViewGraph(src, dst, src.sizes(), src.strides(),
src.storage_offset(), /*needsScatter*/ false);
ViewCachedGraph* cachedGraph =
createViewGraph(src, dst, src.sizes(), src.strides(), src.storage_offset(), /*needsScatter*/ false);
return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false);
}
@ -824,7 +821,7 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
uint32_t numThreads = output.numel();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^(){
dispatch_sync(mpsStream->queue(), ^() {
id<MTLComputeCommandEncoder> computeEncoder = [mpsStream->commandBuffer() computeCommandEncoder];
std::string functionName = getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/false);
id<MTLComputePipelineState> gatherPSO = getPipelineState(MPSDevice::getInstance()->device(),
@ -846,7 +843,7 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
}
}
[computeEncoder setComputePipelineState: gatherPSO];
[computeEncoder setComputePipelineState:gatherPSO];
[computeEncoder setBuffer:getMTLBufferStorage(src) offset:src.storage_offset() * src.element_size() atIndex:0];
[computeEncoder setBuffer:outputBuffer offset:outputStorageOffset atIndex:1];
[computeEncoder setBytes:&src_sizes[0] length:sizeof(uint32_t) * kernel_size atIndex:2];
@ -868,11 +865,14 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
return (dst.has_storage()) ? dst : output;
}
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output){
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output) {
if (output.dim() > 5) {
ViewCachedGraph* cachedGraph = createViewGraph(output.is_complex() ? at::view_as_real(output) : output,
src, output.sizes(), output.strides(),
output.storage_offset(), /*needsScatter*/ true);
src,
output.sizes(),
output.strides(),
output.storage_offset(),
/*needsScatter*/ true);
return runViewGraph(cachedGraph, src, output, /*needsScatter*/ true);
}
if (src.numel() == 0 || output.numel() == 0) {
@ -884,11 +884,12 @@ Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output){
uint32_t numThreads = src.numel();
int64_t outputStorageOffset = output.storage_offset() * output.element_size();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^(){
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
std::string functionName = getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/true);
std::string functionName =
getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/true);
id<MTLComputePipelineState> scatterPSO = getPipelineState(MPSDevice::getInstance()->device(),
functionName,
getGatherScatterScalarType(src),
@ -908,7 +909,7 @@ Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output){
}
}
[computeEncoder setComputePipelineState: scatterPSO];
[computeEncoder setComputePipelineState:scatterPSO];
[computeEncoder setBuffer:sourceBuffer offset:src.storage_offset() * src.element_size() atIndex:0];
[computeEncoder setBuffer:outputBuffer offset:outputStorageOffset atIndex:1];
[computeEncoder setBytes:&output_sizes[0] length:sizeof(uint32_t) * kernel_size atIndex:2];
@ -934,16 +935,21 @@ Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output){
} // namespace mps
// implementation of as_strided() op
Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, IntArrayRef stride, c10::optional<int64_t> storage_offset_) {
Tensor as_strided_tensorimpl_mps(const Tensor& self,
IntArrayRef size,
IntArrayRef stride,
c10::optional<int64_t> storage_offset_) {
auto storage_offset = storage_offset_.value_or(self.storage_offset());
auto result = detail::make_tensor<TensorImpl>(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
auto result =
detail::make_tensor<TensorImpl>(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
setStrided(result, size, stride, storage_offset);
// creating the view graph will be deferred until gatherViewTensor() or scatterViewTensor() are called.
// In as_strided, we just update the base shape of the buffer in order to retrieve it later
// when we create/run the view graph.
IntArrayRef base_shape = mps::updateTensorBaseShape(self);
TORCH_INTERNAL_ASSERT(base_shape.size() > 0, "Failed to update the base shape of tensor's buffer at ", self.storage().data());
TORCH_INTERNAL_ASSERT(
base_shape.size() > 0, "Failed to update the base shape of tensor's buffer at ", self.storage().data());
return result;
}
@ -10,8 +10,7 @@ NS_ASSUME_NONNULL_BEGIN
+ (NSString*)cacheDirectory;
+ (BOOL)compileModel:(const std::string&)modelSpecs
modelID:(const std::string&)modelID;
+ (BOOL)compileModel:(const std::string&)modelSpecs modelID:(const std::string&)modelID;
+ (nullable MLModel*)loadModel:(const std::string&)modelID
backend:(const std::string&)backend