[BE][MPS] Add MPS to clang format (#96562)
I'm getting tired of asking to add a space after `if` and all that jazz, so let's have the linter do that. Add a section for the Objective-C language, where the column width is extended to 120 characters and `AlignAfterOpenBracket` is set to `Align`. All `.mm` changes in this PR were made by running the linter as follows:

```
lintrunner --take CLANGFORMAT --all-files --apply-patches
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96562
Approved by: https://github.com/seemethere, https://github.com/janeyx99, https://github.com/ZainRizvi, https://github.com/izaitsevfb, https://github.com/PaliC, https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: a7689e73f6
Commit: 4242e698a3
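To illustrate the new Objective-C style (`ColumnLimit: 120` with `AlignAfterOpenBracket: Align`), here is a minimal sketch of the wrapping clang-format produces for a long Metal message send. The helper function and variable names are illustrative only and not part of this change; the Metal/Foundation calls mirror the `newLibraryWithSource:options:error:` usage that gets reformatted in `MPSDevice.mm` below, and the exact line breaks depend on the rest of the `.clang-format` configuration.

```objc
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

// Hypothetical helper (illustrative only): compiles a Metal library from shader source.
static id<MTLLibrary> compileShaderLibrary(id<MTLDevice> device, const char* source) {
  NSError* error = nil;
  MTLCompileOptions* options = [MTLCompileOptions new];
  [options setFastMathEnabled:YES];
  // Under AlignAfterOpenBracket: Align, wrapped selector arguments line up on their
  // colons; the call is only broken once it exceeds the 120-column limit.
  return [device newLibraryWithSource:[NSString stringWithCString:source encoding:NSASCIIStringEncoding]
                              options:options
                                error:&error];
}
```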
@ -60,9 +60,6 @@ MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
@ -85,4 +82,11 @@ SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
---
Language: ObjC
ColumnLimit: 120
AlignAfterOpenBracket: Align
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
...
@ -49,6 +49,8 @@ init_command = [
code = 'CLANGFORMAT'
include_patterns = [
'aten/src/ATen/*.h',
'aten/src/ATen/mps/**/*.mm',
'aten/src/ATen/native/mps/**/*.mm',
'aten/src/ATen/native/vulkan/**/*.h',
'aten/src/ATen/native/vulkan/**/*.cpp',
'c10/**/*.h',
@ -1,10 +1,10 @@
// Copyright © 2022 Apple Inc.

#include <ATen/CPUFunctions.h>
#include <ATen/EmptyTensor.h>
#include <ATen/mps/MPSAllocator.h>
#include <c10/core/Allocator.h>
#include <c10/core/Storage.h>
#include <ATen/CPUFunctions.h>
#include <ATen/EmptyTensor.h>
#include <iostream>

namespace at {
@ -19,25 +19,26 @@ uint64_t HeapBlock::heap_counter = 0;

void MPSHeapAllocatorImpl::init_allocator() {
// debug verbosity flags (see DebugVerbosity enum)
static const char *verbosity_str = getenv("PYTORCH_DEBUG_MPS_ALLOCATOR");
static const char* verbosity_str = getenv("PYTORCH_DEBUG_MPS_ALLOCATOR");
m_debug_verbosity = verbosity_str ? strtol(verbosity_str, nullptr, 0) : DebugVerbosity::SILENT;

static const char *high_watermark_ratio_str = getenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO");
const double high_watermark_ratio = high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) :
default_high_watermark_ratio;
static const char* high_watermark_ratio_str = getenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO");
const double high_watermark_ratio =
high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) : default_high_watermark_ratio;
setHighWatermarkRatio(high_watermark_ratio);

const double default_low_watermark_ratio = m_device.hasUnifiedMemory ? default_low_watermark_ratio_unified :
default_low_watermark_ratio_discrete;
static const char *low_watermark_ratio_str = getenv("PYTORCH_MPS_LOW_WATERMARK_RATIO");
const double low_watermark_ratio = low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio;
const double default_low_watermark_ratio =
m_device.hasUnifiedMemory ? default_low_watermark_ratio_unified : default_low_watermark_ratio_discrete;
static const char* low_watermark_ratio_str = getenv("PYTORCH_MPS_LOW_WATERMARK_RATIO");
const double low_watermark_ratio =
low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio;
setLowWatermarkRatio(low_watermark_ratio);
}

void MPSHeapAllocatorImpl::setHighWatermarkRatio(double ratio) {
TORCH_CHECK(ratio >= 0.0 && ratio <= default_high_watermark_upper_bound, "invalid high watermark ratio ", ratio);
m_max_total_allowed_size = (ratio == 0.0) ? std::numeric_limits<size_t>::max() :
static_cast<size_t>(ratio * (double)max_device_size());
m_max_total_allowed_size =
(ratio == 0.0) ? std::numeric_limits<size_t>::max() : static_cast<size_t>(ratio * (double)max_device_size());
if (m_debug_verbosity & DebugVerbosity::PROFILING) {
std::cerr << "\nHigh watermark memory allocation limit: "
<< (ratio == 0.0 ? "unlimited" : format_size(m_max_total_allowed_size)) << "\n";
@ -47,11 +48,12 @@ void MPSHeapAllocatorImpl::setHighWatermarkRatio(double ratio) {

void MPSHeapAllocatorImpl::setLowWatermarkRatio(double ratio) {
// used for comparison with lower_watermark_ratio
const double high_watermark_limit = m_high_watermark_ratio == 0.0 ? default_high_watermark_upper_bound : m_high_watermark_ratio;
const double high_watermark_limit =
m_high_watermark_ratio == 0.0 ? default_high_watermark_upper_bound : m_high_watermark_ratio;
TORCH_CHECK(ratio >= 0.0 && ratio <= high_watermark_limit, "invalid low watermark ratio ", ratio);
// we use this to detect if there's memory pressure
m_low_watermark_limit = (ratio == 0.0) ? std::numeric_limits<size_t>::max() :
static_cast<size_t>(ratio * (double)max_device_size());
m_low_watermark_limit =
(ratio == 0.0) ? std::numeric_limits<size_t>::max() : static_cast<size_t>(ratio * (double)max_device_size());
if (m_debug_verbosity & DebugVerbosity::PROFILING) {
std::cerr << "Low watermark memory allocation limit: "
<< (ratio == 0.0 ? "unlimited" : format_size(m_low_watermark_limit)) << "\n";
@ -61,7 +63,7 @@ void MPSHeapAllocatorImpl::setLowWatermarkRatio(double ratio) {

HeapBlock* MPSHeapAllocatorImpl::get_free_heap(AllocParams& params) {
BufferPool& pool = *params.pool;
HeapBlock *heap_block = nullptr;
HeapBlock* heap_block = nullptr;
HeapBlock search_key(params.size());

auto it = pool.heaps.lower_bound(&search_key);
@ -69,10 +71,8 @@ HeapBlock* MPSHeapAllocatorImpl::get_free_heap(AllocParams& params) {
|
||||
heap_block = HeapBlock::createHeapBlock(params, pool.device, pool.usage);
|
||||
if (heap_block) {
|
||||
if (m_debug_verbosity & DebugVerbosity::ALLOCATIONS) {
|
||||
std::cerr << "\nAllocated "
|
||||
<< ((pool.usage & UsageFlags::SHARED) ? "shared " : "private ")
|
||||
<< " heap #" << heap_block->heap_id
|
||||
<< " of size " << format_size(heap_block->size.total)
|
||||
std::cerr << "\nAllocated " << ((pool.usage & UsageFlags::SHARED) ? "shared " : "private ") << " heap #"
|
||||
<< heap_block->heap_id << " of size " << format_size(heap_block->size.total)
|
||||
<< " (#heaps: " << (pool.heaps.size() + 1)
|
||||
<< ", current allocated: " << format_size(current_allocated_size()) << ")\n";
|
||||
}
|
||||
@ -91,7 +91,7 @@ bool MPSHeapAllocatorImpl::alloc_buffer(AllocParams& params) {
|
||||
current_allocated_size() + params.size() > m_max_total_allowed_size) {
|
||||
return false;
|
||||
}
|
||||
HeapBlock *heap = get_free_heap(params);
|
||||
HeapBlock* heap = get_free_heap(params);
|
||||
if (!heap) {
|
||||
return false; // this will cause releasing pool buffers to free up memory
|
||||
}
|
||||
@ -109,17 +109,14 @@ bool MPSHeapAllocatorImpl::alloc_buffer(AllocParams& params) {
|
||||
pool.n_buffers++;
|
||||
|
||||
if ((m_debug_verbosity & DebugVerbosity::ALLOCATIONS) &&
|
||||
(!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
|
||||
std::cerr << "Allocated "
|
||||
<< ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "")
|
||||
<< " buffer #" << params.buffer_block->buf_id
|
||||
<< " of size " << format_size(params.size())
|
||||
<< " at " << params.buffer_block->buffer
|
||||
<< " from heap #" << heap->heap_id
|
||||
(!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
|
||||
std::cerr << "Allocated " << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") << " buffer #"
|
||||
<< params.buffer_block->buf_id << " of size " << format_size(params.size()) << " at "
|
||||
<< params.buffer_block->buffer << " from heap #" << heap->heap_id
|
||||
<< " (requested: " << format_size(params.requested_size)
|
||||
<< ", heap: " << format_size(heap->size.available)
|
||||
<< ", total: " << format_size(m_total_allocated_memory) << ")\n";
|
||||
<< ", heap: " << format_size(heap->size.available) << ", total: " << format_size(m_total_allocated_memory)
|
||||
<< ")\n";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -158,8 +155,8 @@ bool MPSHeapAllocatorImpl::get_free_buffer(AllocParams& params) {
|
||||
// this will skip unnecessary garbage collection as we'll reuse the newly released space
|
||||
params.has_memory_pressure = false;
|
||||
} else if (params.has_memory_pressure) {
|
||||
// the oversized buffer is busy and not reusable at the moment. So release it (and potentially its heap container)
|
||||
// in allocator, and ARC will later free up its backing memory when the busy command buffer finishes.
|
||||
// the oversized buffer is busy and not reusable at the moment. So release it (and potentially its heap
|
||||
// container) in allocator, and ARC will later free up its backing memory when the busy command buffer finishes.
|
||||
release_buffer(buffer_block, true);
|
||||
} else {
|
||||
// only if there's no memory pressure, we'll reuse the oversized buffer
|
||||
@ -176,16 +173,13 @@ bool MPSHeapAllocatorImpl::get_free_buffer(AllocParams& params) {
|
||||
pool.available_size -= params.buffer_block->size;
|
||||
|
||||
if ((m_debug_verbosity & DebugVerbosity::RECYCLES) &&
|
||||
(!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
|
||||
std::cerr << "Reusing "
|
||||
<< ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "")
|
||||
<< " buffer #" << params.buffer_block->buf_id
|
||||
<< " of size " << format_size(params.buffer_block->size)
|
||||
<< " at " << params.buffer_block->buffer
|
||||
<< " (requested: " << format_size(params.requested_size)
|
||||
<< ", use#: " << params.buffer_block->use_count + 1
|
||||
<< ", retain#: " << params.buffer_block->retainCount() << ")\n";
|
||||
(!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
|
||||
std::cerr << "Reusing " << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") << " buffer #"
|
||||
<< params.buffer_block->buf_id << " of size " << format_size(params.buffer_block->size) << " at "
|
||||
<< params.buffer_block->buffer << " (requested: " << format_size(params.requested_size)
|
||||
<< ", use#: " << params.buffer_block->use_count + 1 << ", retain#: " << params.buffer_block->retainCount()
|
||||
<< ")\n";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -214,7 +208,8 @@ BufferBlock* MPSHeapAllocatorImpl::alloc_buffer_block(size_t size, uint32_t usag
|
||||
alloc_buffer(params) ||
|
||||
// Callbacks might release more memory (eg. by forcing a GC in the host language) thus
|
||||
// we can retry getting a free buffer in the pool, before trying to alloc again.
|
||||
(trigger_memory_callbacks(nullptr, IMpsAllocatorCallback::EventType::ALLOCATION_FAILED) && get_free_buffer(params)) ||
|
||||
(trigger_memory_callbacks(nullptr, IMpsAllocatorCallback::EventType::ALLOCATION_FAILED) &&
|
||||
get_free_buffer(params)) ||
|
||||
// Free enough available cached blocks to satisfy alloc and retry alloc.
|
||||
(release_available_cached_buffers(params) && alloc_buffer(params)) ||
|
||||
// Free all cached buffers and retry alloc.
|
||||
@ -229,16 +224,30 @@ BufferBlock* MPSHeapAllocatorImpl::alloc_buffer_block(size_t size, uint32_t usag
|
||||
// chunk of requested size couldn't be found.
|
||||
if (!block_found || !buffer_block) {
|
||||
if (m_high_watermark_ratio > 0.0) {
|
||||
TORCH_CHECK(false, "MPS backend out of memory (MPS allocated: ", format_size(m_total_allocated_memory),
|
||||
", other allocations: ", format_size(current_allocated_size() - m_total_allocated_memory),
|
||||
", max allowed: ", format_size(m_max_total_allowed_size), "). Tried to allocate ", format_size(alloc_size),
|
||||
" on ", ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
|
||||
" pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).");
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"MPS backend out of memory (MPS allocated: ",
|
||||
format_size(m_total_allocated_memory),
|
||||
", other allocations: ",
|
||||
format_size(current_allocated_size() - m_total_allocated_memory),
|
||||
", max allowed: ",
|
||||
format_size(m_max_total_allowed_size),
|
||||
"). Tried to allocate ",
|
||||
format_size(alloc_size),
|
||||
" on ",
|
||||
((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
|
||||
" pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).");
|
||||
} else {
|
||||
TORCH_CHECK(false, "MPS backend out of memory (MPS allocated: ", format_size(m_total_allocated_memory),
|
||||
", other allocations: ", format_size(current_allocated_size() - m_total_allocated_memory),
|
||||
"). Tried to allocate ", format_size(alloc_size),
|
||||
" on ", ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"), " pool.");
|
||||
TORCH_CHECK(false,
|
||||
"MPS backend out of memory (MPS allocated: ",
|
||||
format_size(m_total_allocated_memory),
|
||||
", other allocations: ",
|
||||
format_size(current_allocated_size() - m_total_allocated_memory),
|
||||
"). Tried to allocate ",
|
||||
format_size(alloc_size),
|
||||
" on ",
|
||||
((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
|
||||
" pool.");
|
||||
}
|
||||
}
|
||||
buffer_block->in_use = true;
|
||||
@ -270,7 +279,7 @@ BufferBlock* MPSHeapAllocatorImpl::get_allocated_buffer_block(void* ptr) {
|
||||
}
|
||||
|
||||
bool MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove_empty_heap) {
|
||||
HeapBlock *heap_block = buffer_block->heap;
|
||||
HeapBlock* heap_block = buffer_block->heap;
|
||||
BufferPool& pool = *heap_block->pool;
|
||||
m_total_allocated_memory -= buffer_block->size;
|
||||
pool.allocated_size -= buffer_block->size;
|
||||
@ -283,13 +292,10 @@ bool MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove
|
||||
uint32_t retainCount = heap_block->releaseMTLBuffer(buffer_block->buffer);
|
||||
|
||||
if ((m_debug_verbosity & DebugVerbosity::RELEASES) &&
|
||||
(!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
|
||||
std::cerr << "Released buffer #" << buffer_block->buf_id
|
||||
<< " of size " << format_size(buffer_block->size)
|
||||
<< " from heap #" << heap_block->heap_id
|
||||
<< " (heap size: " << format_size(heap_block->size.available)
|
||||
<< ", use#: " << buffer_block->use_count
|
||||
<< ", retain#: " << retainCount
|
||||
(!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
|
||||
std::cerr << "Released buffer #" << buffer_block->buf_id << " of size " << format_size(buffer_block->size)
|
||||
<< " from heap #" << heap_block->heap_id << " (heap size: " << format_size(heap_block->size.available)
|
||||
<< ", use#: " << buffer_block->use_count << ", retain#: " << retainCount
|
||||
<< ", gc#: " << buffer_block->gc_count << ")\n";
|
||||
}
|
||||
delete buffer_block;
|
||||
@ -298,10 +304,9 @@ bool MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove
|
||||
pool.heaps_pending_update.erase(heap_block);
|
||||
retainCount = heap_block->releaseMTLHeap();
|
||||
if (m_debug_verbosity & DebugVerbosity::RELEASES) {
|
||||
std::cerr << "Released heap #" << heap_block->heap_id
|
||||
<< " of size " << format_size(heap_block->size.total)
|
||||
<< " (current allocated: " << format_size(current_allocated_size())
|
||||
<< ", retain#: " << retainCount << ")\n";
|
||||
std::cerr << "Released heap #" << heap_block->heap_id << " of size " << format_size(heap_block->size.total)
|
||||
<< " (current allocated: " << format_size(current_allocated_size()) << ", retain#: " << retainCount
|
||||
<< ")\n";
|
||||
}
|
||||
delete heap_block;
|
||||
return true;
|
||||
@ -312,7 +317,7 @@ bool MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove
|
||||
if (retainCount > 1) {
|
||||
pool.heaps_pending_update.insert(heap_block);
|
||||
m_mutex.unlock();
|
||||
m_stream->addCompletedHandler(^(id <MTLCommandBuffer>) {
|
||||
m_stream->addCompletedHandler(^(id<MTLCommandBuffer>) {
|
||||
std::lock_guard<std::recursive_mutex> lock(m_mutex);
|
||||
// check if the heap block still exists
|
||||
if (pool.heaps_pending_update.find(heap_block) != pool.heaps_pending_update.end()) {
|
||||
@ -333,13 +338,11 @@ void MPSHeapAllocatorImpl::release_buffers(BufferPool& pool) {
|
||||
return;
|
||||
}
|
||||
if ((m_debug_verbosity & DebugVerbosity::RELEASES)) {
|
||||
std::cerr << "Releasing " << pool.buffers.size()
|
||||
<< " buffers from "
|
||||
<< ((pool.usage & UsageFlags::SMALL ) ? "small " : "large ")
|
||||
std::cerr << "Releasing " << pool.buffers.size() << " buffers from "
|
||||
<< ((pool.usage & UsageFlags::SMALL) ? "small " : "large ")
|
||||
<< ((pool.usage & UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< ((pool.usage & UsageFlags::SCALAR) ? " scalar" : "")
|
||||
<< " pool (total size: " << format_size(pool.allocated_size)
|
||||
<< ", #buffers: " << pool.n_buffers << ")\n";
|
||||
<< " pool (total size: " << format_size(pool.allocated_size) << ", #buffers: " << pool.n_buffers << ")\n";
|
||||
}
|
||||
auto it = pool.buffers.begin();
|
||||
while (it != pool.buffers.end()) {
|
||||
@ -381,10 +384,8 @@ bool MPSHeapAllocatorImpl::release_available_cached_buffers(AllocParams& params)
|
||||
|
||||
bool MPSHeapAllocatorImpl::release_cached_buffers() {
|
||||
if (m_debug_verbosity >= DebugVerbosity::PROFILING) {
|
||||
std::cerr << "Attempting to release cached buffers (MPS allocated: "
|
||||
<< format_size(m_total_allocated_memory)
|
||||
<< ", other allocations: "
|
||||
<< format_size(current_allocated_size() - m_total_allocated_memory) << ")\n";
|
||||
std::cerr << "Attempting to release cached buffers (MPS allocated: " << format_size(m_total_allocated_memory)
|
||||
<< ", other allocations: " << format_size(current_allocated_size() - m_total_allocated_memory) << ")\n";
|
||||
}
|
||||
// before releasing the buffers make sure the command buffer has finished.
|
||||
// we need to release the lock temporarily as synchronizing may cause deadlock with completion handlers.
|
||||
@ -445,11 +446,10 @@ void MPSHeapAllocatorImpl::garbage_collect_cached_buffers(AllocParams& params) {
|
||||
}
|
||||
}
|
||||
if (m_debug_verbosity & DebugVerbosity::RELEASES) {
|
||||
std::cerr << "Garbage collected " << freed_count
|
||||
<< " buffers from large "
|
||||
std::cerr << "Garbage collected " << freed_count << " buffers from large "
|
||||
<< ((pool.usage & UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< " pool (total reclaimed: " << format_size(gc_reclaimed)
|
||||
<< ", #buffers: " << pool.buffers.size() << ")\n";
|
||||
<< " pool (total reclaimed: " << format_size(gc_reclaimed) << ", #buffers: " << pool.buffers.size()
|
||||
<< ")\n";
|
||||
}
|
||||
}
|
||||
|
||||
@ -464,7 +464,7 @@ id<MTLBuffer> MPSHeapAllocatorImpl::malloc(size_t size, uint32_t usage) {
|
||||
bool MPSHeapAllocatorImpl::isSharedBuffer(void* ptr) {
|
||||
std::lock_guard<std::recursive_mutex> lock(m_mutex);
|
||||
|
||||
BufferBlock *buffer_block = get_allocated_buffer_block(ptr);
|
||||
BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
|
||||
// it's OK for the buffer_block to not exist yet
|
||||
return buffer_block && (buffer_block->heap->pool->usage & UsageFlags::SHARED);
|
||||
}
|
||||
@ -487,9 +487,9 @@ id<MTLBuffer> MPSHeapAllocatorImpl::allocScalarBufferWithValue(void* value, size
|
||||
ssize_t MPSHeapAllocatorImpl::getUnalignedBufferSize(void* ptr) {
|
||||
std::lock_guard<std::recursive_mutex> lock(m_mutex);
|
||||
|
||||
BufferBlock *buffer_block = get_allocated_buffer_block(ptr);
|
||||
BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
|
||||
if (buffer_block) {
|
||||
return (ssize_t) buffer_block->requested_size;
|
||||
return (ssize_t)buffer_block->requested_size;
|
||||
}
|
||||
// -1 indicates the passed buffer pointer wasn't found
|
||||
return -1;
|
||||
@ -498,7 +498,7 @@ ssize_t MPSHeapAllocatorImpl::getUnalignedBufferSize(void* ptr) {
|
||||
void MPSHeapAllocatorImpl::setBufferShape(void* ptr, const IntArrayRef& shape) {
|
||||
std::lock_guard<std::recursive_mutex> lock(m_mutex);
|
||||
|
||||
BufferBlock *buffer_block = get_allocated_buffer_block(ptr);
|
||||
BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
|
||||
TORCH_INTERNAL_ASSERT(buffer_block, "failed to find the buffer ", ptr);
|
||||
// note that the IntArrayRef doesn't own the underlying data, and the backing
|
||||
// memory for shape data must persist as long as the buffer is in use.
|
||||
@ -509,7 +509,7 @@ void MPSHeapAllocatorImpl::setBufferShape(void* ptr, const IntArrayRef& shape) {
|
||||
IntArrayRef MPSHeapAllocatorImpl::getBufferShape(void* ptr) {
|
||||
std::lock_guard<std::recursive_mutex> lock(m_mutex);
|
||||
|
||||
BufferBlock *buffer_block = get_allocated_buffer_block(ptr);
|
||||
BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
|
||||
if (buffer_block && buffer_block->shape.size() > 0) {
|
||||
return IntArrayRef{buffer_block->shape};
|
||||
}
|
||||
@ -517,7 +517,7 @@ IntArrayRef MPSHeapAllocatorImpl::getBufferShape(void* ptr) {
|
||||
}
|
||||
|
||||
void MPSHeapAllocatorImpl::free(void* ptr) {
|
||||
BufferBlock *buffer_block = nullptr;
|
||||
BufferBlock* buffer_block = nullptr;
|
||||
{
|
||||
std::lock_guard<std::recursive_mutex> lock(m_mutex);
|
||||
|
||||
@ -531,7 +531,7 @@ void MPSHeapAllocatorImpl::free(void* ptr) {
|
||||
}
|
||||
// we sync the scalar pool manually with completion handler at the time buffer is
|
||||
// freed when the MPSScalar instance goes our of scope
|
||||
m_stream->addCompletedHandler(^(id <MTLCommandBuffer>) {
|
||||
m_stream->addCompletedHandler(^(id<MTLCommandBuffer>) {
|
||||
std::lock_guard<std::recursive_mutex> lock(m_mutex);
|
||||
free_buffer(buffer_block);
|
||||
});
|
||||
@ -555,10 +555,15 @@ inline std::string MPSHeapAllocatorImpl::format_size(uint64_t size) const {
|
||||
std::ostringstream os;
|
||||
os.precision(2);
|
||||
os << std::fixed;
|
||||
if (size <= 1024UL) { os << size << " bytes"; }
|
||||
else if (size <= 1048576UL) { os << ((float) size / 1024.0) << " KB"; }
|
||||
else if (size <= 1073741824UL) { os << ((float) size / 1048576.0) << " MB"; }
|
||||
else { os << ((float) size / 1073741824.0) << " GB"; }
|
||||
if (size <= 1024UL) {
|
||||
os << size << " bytes";
|
||||
} else if (size <= 1048576UL) {
|
||||
os << ((float)size / 1024.0) << " KB";
|
||||
} else if (size <= 1073741824UL) {
|
||||
os << ((float)size / 1048576.0) << " MB";
|
||||
} else {
|
||||
os << ((float)size / 1073741824.0) << " GB";
|
||||
}
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@ -574,16 +579,13 @@ HeapAllocator::MPSHeapAllocatorImpl& _getAllocImpl() {
|
||||
|
||||
// MPS allocator struct to be registered with Pytorch
|
||||
struct TORCH_API MPSAllocator final : public IMPSAllocator {
|
||||
public:
|
||||
explicit MPSAllocator(uint32_t Usage) :
|
||||
m_has_unified_memory(_getAllocImpl().Device().hasUnifiedMemory), m_usage(Usage)
|
||||
{
|
||||
public:
|
||||
explicit MPSAllocator(uint32_t Usage)
|
||||
: m_has_unified_memory(_getAllocImpl().Device().hasUnifiedMemory), m_usage(Usage) {
|
||||
if (_getAllocImpl().getDebugVerbosity()) {
|
||||
if (!(m_usage & HeapAllocator::UsageFlags::SHARED) || m_has_unified_memory) {
|
||||
std::cerr << "Initializing "
|
||||
<< ((m_usage & HeapAllocator::UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< " heap allocator on "
|
||||
<< (m_has_unified_memory ? "unified" : "discrete")
|
||||
std::cerr << "Initializing " << ((m_usage & HeapAllocator::UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< " heap allocator on " << (m_has_unified_memory ? "unified" : "discrete")
|
||||
<< " device memory of size "
|
||||
<< _getAllocImpl().format_size(_getAllocImpl().Device().recommendedMaxWorkingSetSize) << "\n";
|
||||
}
|
||||
@ -593,34 +595,64 @@ public:
|
||||
~MPSAllocator() override {
|
||||
_getAllocImpl().emptyCache();
|
||||
}
|
||||
DeleterFnPtr raw_deleter() const override { return &Delete; }
|
||||
DeleterFnPtr raw_deleter() const override {
|
||||
return &Delete;
|
||||
}
|
||||
|
||||
DataPtr allocate(const size_t nbytes) const override {
|
||||
__block id<MTLBuffer> buf = nbytes > 0 ? _getAllocImpl().malloc(nbytes, m_usage) : nullptr;
|
||||
return { buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
|
||||
return {buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
|
||||
}
|
||||
|
||||
// implementation of IMPSAllocator interface
|
||||
DataPtr allocScalarBufferWithValue(void *value, size_t size) const override {
|
||||
DataPtr allocScalarBufferWithValue(void* value, size_t size) const override {
|
||||
id<MTLBuffer> buf = _getAllocImpl().allocScalarBufferWithValue(value, size);
|
||||
return { buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
|
||||
return {buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
|
||||
}
|
||||
bool isSharedBuffer(void* ptr) const override {
|
||||
return _getAllocImpl().isSharedBuffer(ptr);
|
||||
}
|
||||
bool isSharedStorageSupported() const override {
|
||||
return m_has_unified_memory;
|
||||
}
|
||||
void emptyCache() const override {
|
||||
_getAllocImpl().emptyCache();
|
||||
}
|
||||
ssize_t getUnalignedBufferSize(void* ptr) const override {
|
||||
return _getAllocImpl().getUnalignedBufferSize(ptr);
|
||||
}
|
||||
IntArrayRef getBufferShape(void* ptr) const override {
|
||||
return _getAllocImpl().getBufferShape(ptr);
|
||||
}
|
||||
void setBufferShape(void* ptr, const IntArrayRef& shape) const override {
|
||||
_getAllocImpl().setBufferShape(ptr, shape);
|
||||
}
|
||||
size_t getTotalAllocatedMemory() const override {
|
||||
return _getAllocImpl().getTotalAllocatedMemory();
|
||||
}
|
||||
size_t getCurrentAllocatedMemory() const override {
|
||||
return _getAllocImpl().getCurrentAllocatedMemory();
|
||||
}
|
||||
size_t getDriverAllocatedMemory() const override {
|
||||
return _getAllocImpl().getDriverAllocatedMemory();
|
||||
}
|
||||
ssize_t getLowWatermarkValue() const override {
|
||||
return _getAllocImpl().getLowWatermarkValue();
|
||||
}
|
||||
size_t getLowWatermarkLimit() const override {
|
||||
return _getAllocImpl().getLowWatermarkLimit();
|
||||
}
|
||||
size_t getHighWatermarkLimit() const override {
|
||||
return _getAllocImpl().getHighWatermarkLimit();
|
||||
}
|
||||
void setLowWatermarkRatio(double ratio) const override {
|
||||
_getAllocImpl().setLowWatermarkRatio(ratio);
|
||||
}
|
||||
void setHighWatermarkRatio(double ratio) const override {
|
||||
_getAllocImpl().setHighWatermarkRatio(ratio);
|
||||
}
|
||||
bool isSharedBuffer(void* ptr) const override { return _getAllocImpl().isSharedBuffer(ptr); }
|
||||
bool isSharedStorageSupported() const override { return m_has_unified_memory; }
|
||||
void emptyCache() const override { _getAllocImpl().emptyCache(); }
|
||||
ssize_t getUnalignedBufferSize(void* ptr) const override { return _getAllocImpl().getUnalignedBufferSize(ptr); }
|
||||
IntArrayRef getBufferShape(void* ptr) const override { return _getAllocImpl().getBufferShape(ptr); }
|
||||
void setBufferShape(void* ptr, const IntArrayRef& shape) const override { _getAllocImpl().setBufferShape(ptr, shape); }
|
||||
size_t getTotalAllocatedMemory() const override { return _getAllocImpl().getTotalAllocatedMemory(); }
|
||||
size_t getCurrentAllocatedMemory() const override { return _getAllocImpl().getCurrentAllocatedMemory(); }
|
||||
size_t getDriverAllocatedMemory() const override { return _getAllocImpl().getDriverAllocatedMemory(); }
|
||||
ssize_t getLowWatermarkValue() const override { return _getAllocImpl().getLowWatermarkValue(); }
|
||||
size_t getLowWatermarkLimit() const override { return _getAllocImpl().getLowWatermarkLimit(); }
|
||||
size_t getHighWatermarkLimit() const override { return _getAllocImpl().getHighWatermarkLimit(); }
|
||||
void setLowWatermarkRatio(double ratio) const override { _getAllocImpl().setLowWatermarkRatio(ratio); }
|
||||
void setHighWatermarkRatio(double ratio) const override { _getAllocImpl().setHighWatermarkRatio(ratio); }
|
||||
|
||||
private:
|
||||
private:
|
||||
bool m_has_unified_memory;
|
||||
uint32_t m_usage;
|
||||
|
||||
@ -662,15 +694,13 @@ namespace native {
|
||||
// Pinned memory will be helpful on Apple Silicon Macs with Unified memory as we
|
||||
// will be able to use SharedStorageMode for MTLBuffer allocations. This will
|
||||
// avoid extra copies on DataLoading operations.
|
||||
bool is_pinned_mps(const Tensor& self, c10::optional<Device> device)
|
||||
{
|
||||
bool is_pinned_mps(const Tensor& self, c10::optional<Device> device) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_mps());
|
||||
return at::mps::_getSharedAllocator().isSharedBuffer(self.storage().data());
|
||||
}
|
||||
|
||||
// torch.pin_memory() implementation
|
||||
Tensor _pin_memory_mps(const Tensor& self, c10::optional<Device> device)
|
||||
{
|
||||
Tensor _pin_memory_mps(const Tensor& self, c10::optional<Device> device) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_mps());
|
||||
auto* shared_allocator = at::mps::getIMPSAllocator(true);
|
||||
TORCH_CHECK(shared_allocator, "unable to pin memory on a non-unified memory device");
|
||||
|
@ -2,10 +2,10 @@
|
||||
|
||||
#include <c10/util/CallOnce.h>
|
||||
|
||||
#include <ATen/mps/IndexKernels.h>
|
||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||
#include <ATen/mps/MPSDevice.h>
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||
#include <ATen/mps/IndexKernels.h>
|
||||
|
||||
namespace at {
|
||||
namespace mps {
|
||||
@ -23,9 +23,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
|
||||
}
|
||||
|
||||
MPSDevice* MPSDevice::getInstance() {
|
||||
c10::call_once(mpsdev_init, [] {
|
||||
mps_device = std::unique_ptr<MPSDevice>(new MPSDevice());
|
||||
});
|
||||
c10::call_once(mpsdev_init, [] { mps_device = std::unique_ptr<MPSDevice>(new MPSDevice()); });
|
||||
return mps_device.get();
|
||||
}
|
||||
|
||||
@ -33,25 +31,31 @@ id<MTLFunction> MPSDevice::metalIndexingFunction(const std::string& kernel, MTLF
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
|
||||
NSError* error = nil;
|
||||
if (!_mtl_indexing_library) {
|
||||
MTLCompileOptions *options = [MTLCompileOptions new];
|
||||
[options setLanguageVersion: getMetalLanguageVersion(_mtl_device)];
|
||||
[options setFastMathEnabled: YES];
|
||||
_mtl_indexing_library = [_mtl_device newLibraryWithSource: [NSString stringWithCString: mps::indexing_metal_shaders encoding:NSASCIIStringEncoding]
|
||||
options: options
|
||||
error: &error];
|
||||
MTLCompileOptions* options = [MTLCompileOptions new];
|
||||
[options setLanguageVersion:getMetalLanguageVersion(_mtl_device)];
|
||||
[options setFastMathEnabled:YES];
|
||||
_mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders
|
||||
encoding:NSASCIIStringEncoding]
|
||||
options:options
|
||||
error:&error];
|
||||
TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]);
|
||||
}
|
||||
|
||||
id<MTLFunction> indexFunction = nil;
|
||||
if (constantValues) {
|
||||
indexFunction = [[_mtl_indexing_library newFunctionWithName: [NSString stringWithUTF8String: kernel.c_str()]
|
||||
constantValues: constantValues
|
||||
error: &error] autorelease];
|
||||
indexFunction = [[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]
|
||||
constantValues:constantValues
|
||||
error:&error] autorelease];
|
||||
} else {
|
||||
indexFunction = [[_mtl_indexing_library newFunctionWithName: [NSString stringWithUTF8String: kernel.c_str()]] autorelease];
|
||||
indexFunction =
|
||||
[[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease];
|
||||
}
|
||||
|
||||
TORCH_CHECK(indexFunction, "Failed to create specialized function state object: ", kernel, ", error: ", [[error description] UTF8String]);
|
||||
TORCH_CHECK(indexFunction,
|
||||
"Failed to create specialized function state object: ",
|
||||
kernel,
|
||||
", error: ",
|
||||
[[error description] UTF8String]);
|
||||
|
||||
return indexFunction;
|
||||
}
|
||||
@ -63,49 +67,52 @@ MPSDevice::~MPSDevice() {
|
||||
_mtl_indexing_library = nil;
|
||||
}
|
||||
|
||||
MPSDevice::MPSDevice(): _mtl_device(nil), _mtl_indexing_library(nil) {
|
||||
MPSDevice::MPSDevice() : _mtl_device(nil), _mtl_indexing_library(nil) {
|
||||
// Check that MacOS 12.3+ version of MPS framework is available
|
||||
// Create the MPSGraph and check method introduced in 12.3+
|
||||
// which is used by MPS backend.
|
||||
id mpsCD = NSClassFromString(@"MPSGraph");
|
||||
|
||||
if ([mpsCD instancesRespondToSelector:@selector(LSTMWithSourceTensor:
|
||||
recurrentWeight:
|
||||
inputWeight:
|
||||
bias:
|
||||
initState:
|
||||
initCell:
|
||||
descriptor:
|
||||
name:)] == NO) {
|
||||
if ([mpsCD instancesRespondToSelector:@selector
|
||||
(LSTMWithSourceTensor:recurrentWeight:inputWeight:bias:initState:initCell:descriptor:name:)] == NO) {
|
||||
return;
|
||||
}
|
||||
|
||||
NSArray* devices = [MTLCopyAllDevices() autorelease];
|
||||
for (unsigned long i = 0 ; i < [devices count] ; i++) {
|
||||
id<MTLDevice> device = devices[i];
|
||||
if(![device isLowPower]) { // exclude Intel GPUs
|
||||
for (unsigned long i = 0; i < [devices count]; i++) {
|
||||
id<MTLDevice> device = devices[i];
|
||||
if (![device isLowPower]) { // exclude Intel GPUs
|
||||
_mtl_device = [device retain];
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
|
||||
|
||||
}
|
||||
|
||||
bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
|
||||
id mpsCD = NSClassFromString(@"MPSGraph");
|
||||
static bool _macos_13_0_plus = [mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == YES;
|
||||
static bool _macos_13_1_plus = [mpsCD instancesRespondToSelector:@selector(
|
||||
sampleGridWithSourceTensor:coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode:samplingMode:constantValue:name:)] == YES;
|
||||
static bool _macos_13_2_plus = [mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES;
|
||||
static bool _macos_13_0_plus = [mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:
|
||||
axis:name:)] == YES;
|
||||
static bool _macos_13_1_plus =
|
||||
[mpsCD instancesRespondToSelector:@selector
|
||||
(sampleGridWithSourceTensor:
|
||||
coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode
|
||||
:samplingMode:constantValue:name:)] == YES;
|
||||
static bool _macos_13_2_plus =
|
||||
[mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES;
|
||||
static bool _macos_13_3_plus = [_mtl_device respondsToSelector:@selector(maximumConcurrentCompilationTaskCount)];
|
||||
|
||||
switch (version) {
|
||||
case MacOSVersion::MACOS_VER_13_0_PLUS: return _macos_13_0_plus;
|
||||
case MacOSVersion::MACOS_VER_13_1_PLUS: return _macos_13_1_plus;
|
||||
case MacOSVersion::MACOS_VER_13_2_PLUS: return _macos_13_2_plus;
|
||||
case MacOSVersion::MACOS_VER_13_3_PLUS: return _macos_13_3_plus;
|
||||
default: return false;
|
||||
case MacOSVersion::MACOS_VER_13_0_PLUS:
|
||||
return _macos_13_0_plus;
|
||||
case MacOSVersion::MACOS_VER_13_1_PLUS:
|
||||
return _macos_13_1_plus;
|
||||
case MacOSVersion::MACOS_VER_13_2_PLUS:
|
||||
return _macos_13_2_plus;
|
||||
case MacOSVersion::MACOS_VER_13_3_PLUS:
|
||||
return _macos_13_3_plus;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4,40 +4,46 @@
|
||||
|
||||
namespace at {
|
||||
|
||||
void mps_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack)
|
||||
{
|
||||
TORCH_WARN_ONCE("The operator '", op.schema().operator_name(), "' is not currently supported ",
|
||||
void mps_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
TORCH_WARN_ONCE("The operator '",
|
||||
op.schema().operator_name(),
|
||||
"' is not currently supported ",
|
||||
"on the MPS backend and will fall back to run on the CPU.",
|
||||
" This may have performance implications.");
|
||||
native::cpu_fallback(op, stack);
|
||||
}
|
||||
|
||||
void mps_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack)
|
||||
{
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "The operator '", op.schema().operator_name(), "' is not currently implemented ",
|
||||
void mps_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack){TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"The operator '",
|
||||
op.schema().operator_name(),
|
||||
"' is not currently implemented ",
|
||||
"for the MPS device. If you want this op to be added in priority during the prototype ",
|
||||
"phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. ",
|
||||
"As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` ",
|
||||
"to use the CPU as a fallback for this op. WARNING: this will be slower than running natively ",
|
||||
"on MPS.")
|
||||
}
|
||||
|
||||
"on MPS.")}
|
||||
|
||||
// This dispatch should never be called for tensor on MPS but is frequently called
|
||||
// If one of them are on CPU
|
||||
Tensor slow_conv2d_forward_mps(
|
||||
const Tensor &self,
|
||||
const Tensor &weight,
|
||||
IntArrayRef kernel_size,
|
||||
const c10::optional<Tensor> &bias,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding) {
|
||||
TORCH_CHECK(self.device() == weight.device(), __func__, ": input(device='", self.device(), "') and weight(device=", weight.device(), "') must be on the same device");
|
||||
TORCH_INTERNAL_ASSERT(false, __func__, " should not be called for both tensors on MPS device");
|
||||
Tensor slow_conv2d_forward_mps(const Tensor& self,
|
||||
const Tensor& weight,
|
||||
IntArrayRef kernel_size,
|
||||
const c10::optional<Tensor>& bias,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding) {
|
||||
TORCH_CHECK(self.device() == weight.device(),
|
||||
__func__,
|
||||
": input(device='",
|
||||
self.device(),
|
||||
"') and weight(device=",
|
||||
weight.device(),
|
||||
"') must be on the same device");
|
||||
TORCH_INTERNAL_ASSERT(false, __func__, " should not be called for both tensors on MPS device");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(_, MPS, m) {
|
||||
static const char *enable_mps_fallback = getenv("PYTORCH_ENABLE_MPS_FALLBACK");
|
||||
static const char* enable_mps_fallback = getenv("PYTORCH_ENABLE_MPS_FALLBACK");
|
||||
if (!enable_mps_fallback || std::stoi(enable_mps_fallback) == 0) {
|
||||
m.fallback(torch::CppFunction::makeFromBoxedFunction<&mps_error_fallback>());
|
||||
} else {
|
||||
|
@ -23,8 +23,9 @@ Generator createMPSGenerator(uint64_t seed_val) {
|
||||
} // namespace mps
|
||||
|
||||
MPSGeneratorImpl::MPSGeneratorImpl(uint64_t seed_in)
|
||||
: c10::GeneratorImpl{Device(DeviceType::MPS), DispatchKeySet(c10::DispatchKey::MPS)},
|
||||
data_({.seed = seed_in}), engine_(seed_in, 0, 0) { }
|
||||
: c10::GeneratorImpl{Device(DeviceType::MPS), DispatchKeySet(c10::DispatchKey::MPS)},
|
||||
data_({.seed = seed_in}),
|
||||
engine_(seed_in, 0, 0) {}
|
||||
|
||||
void MPSGeneratorImpl::set_current_seed(uint64_t seed) {
|
||||
data_.seed = seed;
|
||||
@ -60,7 +61,8 @@ c10::intrusive_ptr<c10::TensorImpl> MPSGeneratorImpl::get_state() const {
|
||||
static const size_t seed_size = sizeof(uint64_t);
|
||||
static const size_t total_size = states_size + seed_size;
|
||||
|
||||
auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
|
||||
auto state_tensor = at::detail::empty_cpu(
|
||||
{(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
|
||||
auto rng_state = state_tensor.data_ptr<uint8_t>();
|
||||
auto current_seed = this->current_seed();
|
||||
memcpy(rng_state, this->data_.state.data(), states_size);
|
||||
|
@ -1,57 +1,47 @@
|
||||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
#include <ATen/mps/MPSGuardImpl.h>
|
||||
#include <ATen/mps/MPSDevice.h>
|
||||
#include <ATen/mps/MPSGuardImpl.h>
|
||||
|
||||
namespace at {
|
||||
namespace mps {
|
||||
|
||||
void MPSGuardImpl::createEvent(
|
||||
mpsEvent_t* event,
|
||||
const EventFlag flag) const {
|
||||
}
|
||||
void MPSGuardImpl::createEvent(mpsEvent_t* event, const EventFlag flag) const {}
|
||||
|
||||
void MPSGuardImpl::destroyEvent(
|
||||
void* event,
|
||||
const DeviceIndex device_index) const noexcept {
|
||||
if (!event) return;
|
||||
auto mps_event = static_cast<mpsEvent_t>(event);
|
||||
mps_event->~MPSEvent();
|
||||
void MPSGuardImpl::destroyEvent(void* event, const DeviceIndex device_index) const noexcept {
|
||||
if (!event)
|
||||
return;
|
||||
auto mps_event = static_cast<mpsEvent_t>(event);
|
||||
mps_event->~MPSEvent();
|
||||
}
|
||||
|
||||
}
|
||||
void MPSGuardImpl::record(void** event,
|
||||
const Stream& stream,
|
||||
const DeviceIndex device_index,
|
||||
const EventFlag flag) const {
|
||||
TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
|
||||
"Event device index ",
|
||||
device_index,
|
||||
" does not match recording stream's device index ",
|
||||
stream.device_index(),
|
||||
".");
|
||||
|
||||
void MPSGuardImpl::record(
|
||||
void** event,
|
||||
const Stream& stream,
|
||||
const DeviceIndex device_index,
|
||||
const EventFlag flag) const {
|
||||
auto mps_event = static_cast<mpsEvent_t>(*event);
|
||||
MPSStream mps_stream{stream};
|
||||
mps_event->recordEvent(true);
|
||||
}
|
||||
|
||||
TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
|
||||
"Event device index ",
|
||||
device_index,
|
||||
" does not match recording stream's device index ",
|
||||
stream.device_index(),
|
||||
".");
|
||||
void MPSGuardImpl::block(void* event, const Stream& stream) const {
|
||||
auto mps_event = static_cast<mpsEvent_t>(event);
|
||||
MPSStream mps_stream{stream};
|
||||
|
||||
auto mps_event = static_cast<mpsEvent_t>(*event);
|
||||
MPSStream mps_stream{stream};
|
||||
mps_event->recordEvent(true);
|
||||
}
|
||||
mps_event->waitForEvent(true);
|
||||
}
|
||||
|
||||
void MPSGuardImpl::block(
|
||||
void* event,
|
||||
const Stream& stream) const {
|
||||
|
||||
auto mps_event = static_cast<mpsEvent_t>(event);
|
||||
MPSStream mps_stream{stream};
|
||||
|
||||
mps_event->waitForEvent(true);
|
||||
}
|
||||
|
||||
bool MPSGuardImpl::queryEvent(void* event) const {
|
||||
auto mps_event = static_cast<mpsEvent_t>(event);
|
||||
return mps_event->queryEvent();
|
||||
}
|
||||
bool MPSGuardImpl::queryEvent(void* event) const {
|
||||
auto mps_event = static_cast<mpsEvent_t>(event);
|
||||
return mps_event->queryEvent();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
|
||||
namespace at {
|
||||
namespace mps {
|
||||
@ -17,9 +17,9 @@ MPSStream::MPSStream(Stream stream) : _stream(stream) {
|
||||
TORCH_CHECK(_stream.device_type() == DeviceType::MPS);
|
||||
_serialQueue = dispatch_queue_create("metal gpu stream", nullptr);
|
||||
_executionDescriptor = [MPSGraphExecutionDescriptor new];
|
||||
_executionDescriptor.completionHandler = ^(NSDictionary<MPSGraphTensor *,
|
||||
MPSGraphTensorData *> * resultsDictionary,
|
||||
NSError * _Nullable error) { };
|
||||
_executionDescriptor.completionHandler =
|
||||
^(NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* resultsDictionary, NSError* _Nullable error) {
|
||||
};
|
||||
}
|
||||
|
||||
MPSStream::~MPSStream() {
|
||||
@ -41,7 +41,7 @@ MPSCommandBuffer* MPSStream::commandBuffer() {
|
||||
void MPSStream::synchronize(SyncType syncType) {
|
||||
if (!_commandBuffer)
|
||||
return;
|
||||
switch(syncType) {
|
||||
switch (syncType) {
|
||||
case SyncType::NONE:
|
||||
// typically in GPU to GPU copies we won't commit explicitly
|
||||
break;
|
||||
@ -108,32 +108,34 @@ void MPSStream::_flush(bool commitAndWait) const {
|
||||
}
|
||||
|
||||
void MPSStream::addCompletedHandler(MTLCommandBufferHandler block) {
|
||||
dispatch_sync(_serialQueue, ^() {
|
||||
dispatch_sync(_serialQueue, ^() {
|
||||
@autoreleasepool {
|
||||
[commandBuffer() addCompletedHandler:block];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType)
|
||||
{
|
||||
void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType) {
|
||||
TORCH_INTERNAL_ASSERT(length >= offset);
|
||||
if (length == 0) return;
|
||||
if (length == 0)
|
||||
return;
|
||||
dispatch_sync(_serialQueue, ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
|
||||
|
||||
[blitEncoder fillBuffer:buffer
|
||||
range:NSMakeRange(offset, length)
|
||||
value:value];
|
||||
[blitEncoder fillBuffer:buffer range:NSMakeRange(offset, length) value:value];
|
||||
[blitEncoder endEncoding];
|
||||
synchronize(syncType);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void MPSStream::copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
|
||||
size_t length, size_t srcOffset, size_t dstOffset, SyncType syncType) {
|
||||
void MPSStream::copy(id<MTLBuffer> srcBuffer,
|
||||
id<MTLBuffer> dstBuffer,
|
||||
size_t length,
|
||||
size_t srcOffset,
|
||||
size_t dstOffset,
|
||||
SyncType syncType) {
|
||||
dispatch_sync(_serialQueue, ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
|
||||
@ -149,10 +151,14 @@ void MPSStream::copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
|
||||
});
|
||||
}
|
||||
|
||||
void MPSStream::copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer, size_t length,
|
||||
size_t srcOffset, size_t dstOffset, bool non_blocking) {
|
||||
copy(srcBuffer, dstBuffer, length, srcOffset, dstOffset,
|
||||
!non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT);
|
||||
void MPSStream::copy_and_sync(id<MTLBuffer> srcBuffer,
|
||||
id<MTLBuffer> dstBuffer,
|
||||
size_t length,
|
||||
size_t srcOffset,
|
||||
size_t dstOffset,
|
||||
bool non_blocking) {
|
||||
copy(
|
||||
srcBuffer, dstBuffer, length, srcOffset, dstOffset, !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT);
|
||||
}
|
||||
|
||||
void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType) {
|
||||
@ -173,7 +179,7 @@ void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDicti
|
||||
resultsDictionary:results
|
||||
executionDescriptor:_executionDescriptor];
|
||||
#endif
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------
|
||||
@ -184,8 +190,7 @@ MPSStream* MPSStreamImpl::_stream = nullptr;
|
||||
|
||||
MPSStream* MPSStreamImpl::getInstance() {
|
||||
if (_stream == nullptr) {
|
||||
_stream =
|
||||
new MPSStream(Stream(Stream::UNSAFE, c10::Device(DeviceType::MPS), 0));
|
||||
_stream = new MPSStream(Stream(Stream::UNSAFE, c10::Device(DeviceType::MPS), 0));
|
||||
}
|
||||
return _stream;
|
||||
}
|
||||
@ -204,8 +209,8 @@ MPSStream* getDefaultMPSStream() {
|
||||
// MPSEvent
|
||||
//-----------------------------------------------------------------
|
||||
|
||||
MPSEvent::MPSEvent(bool deferInitialization) :
|
||||
is_initialized(false), _signalCounter(0), _stream(nil), _event(nil), _listener(nil) {
|
||||
MPSEvent::MPSEvent(bool deferInitialization)
|
||||
: is_initialized(false), _signalCounter(0), _stream(nil), _event(nil), _listener(nil) {
|
||||
if (!deferInitialization) {
|
||||
initialize();
|
||||
}
|
||||
@ -256,8 +261,7 @@ void MPSEvent::waitForEvent(bool syncEvent) {
|
||||
});
|
||||
}
|
||||
|
||||
void MPSEvent::notifyEvent(MTLSharedEventNotificationBlock block)
|
||||
{
|
||||
void MPSEvent::notifyEvent(MTLSharedEventNotificationBlock block) {
|
||||
if (!is_initialized)
|
||||
initialize();
|
||||
dispatch_sync(_stream->queue(), ^() {
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
|
||||
namespace at::native::mps {
|
||||
|
||||
@ -28,27 +28,31 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
|
||||
case ScalarType::Bool:
|
||||
return MPSDataTypeBool;
|
||||
case ScalarType::Double:
|
||||
TORCH_CHECK_TYPE(false, "Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. "
|
||||
TORCH_CHECK_TYPE(false,
|
||||
"Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. "
|
||||
"Please use float32 instead.")
|
||||
default:
|
||||
TORCH_CHECK_TYPE(false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
|
||||
TORCH_CHECK_TYPE(
|
||||
false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
|
||||
}
|
||||
}
|
||||
|
||||
// #issue 104398441 sortWithTensor and argsortWithTensor has support of
|
||||
// Int32, Half and Float32 types. These utilities are to help cast to these
|
||||
// types.
|
||||
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64) {
|
||||
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
|
||||
MPSGraphTensor* inputTensor,
|
||||
const Tensor& input,
|
||||
bool includesInt64) {
|
||||
MPSDataType dataType = getMPSDataType(input.scalar_type());
|
||||
bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
|
||||
bool condition =
|
||||
(dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
|
||||
if (includesInt64) {
|
||||
condition = condition && (dataType != MPSDataTypeInt64);
|
||||
}
|
||||
if (condition) {
|
||||
dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
|
||||
return [mpsGraph castTensor:inputTensor
|
||||
toType:dataType
|
||||
name:@"castInputTensor"];
|
||||
return [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
|
||||
}
|
||||
return inputTensor;
|
||||
}
|
||||
@ -56,16 +60,18 @@ MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor,
|
||||
// #issue 104398441 sortWithTensor and argsortWithTensor has support of
|
||||
// Int32, Half and Float32 types. These utilities are to help cast from these
|
||||
// types.
|
||||
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64) {
|
||||
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
|
||||
MPSGraphTensor* inputTensor,
|
||||
const Tensor& input,
|
||||
bool includesInt64) {
|
||||
MPSDataType dataType = getMPSDataType(input.scalar_type());
|
||||
bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
|
||||
bool condition =
|
||||
(dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
|
||||
if (includesInt64) {
|
||||
condition = condition && (dataType != MPSDataTypeInt64);
|
||||
}
|
||||
if (condition) {
|
||||
inputTensor = [mpsGraph castTensor:inputTensor
|
||||
toType:dataType
|
||||
name:@"castInputTensor"];
|
||||
inputTensor = [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
|
||||
}
|
||||
return inputTensor;
|
||||
}
|
||||
@ -92,7 +98,8 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
|
||||
case ScalarType::Bool:
|
||||
return MPSDataTypeBool;
|
||||
default:
|
||||
TORCH_CHECK_TYPE(false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
TORCH_CHECK_TYPE(
false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
}
}

@ -148,7 +155,7 @@ std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type) {
NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
int64_t ndim = t.dim();
auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
for (const auto i: c10::irange(ndim)) {
for (const auto i : c10::irange(ndim)) {
axes[i] = [NSNumber numberWithInteger:i];
}
return axes;
@ -159,7 +166,7 @@ NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim)
IntArrayRef dimValues = dim.value();
int ndim = dimValues.size();
auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
for (const auto i: c10::irange(ndim)) {
for (const auto i : c10::irange(ndim)) {
axes[i] = [NSNumber numberWithInteger:dimValues[i]];
}

@ -170,11 +177,11 @@ NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim)
}

std::string getMPSShapeString(MPSShape* shape) {
std::string str;
for(NSNumber *elem in shape) {
str += std::to_string(elem.unsignedLongValue) + ",";
}
return str;
std::string str;
for (NSNumber* elem in shape) {
str += std::to_string(elem.unsignedLongValue) + ",";
}
return str;
}

std::string getArrayRefString(const IntArrayRef s) {
@ -184,25 +191,25 @@ std::string getArrayRefString(const IntArrayRef s) {
}

std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype) {
std::string str;
// The key format per tensor would look like ":Float32[1,1,1,10]:"
for (const Tensor& tensor: tensors) {
str += ":";
if (tensor.defined()) {
str += getMPSTypeString(tensor.scalar_type(), short_dtype) + "[";
// if tensor is a scalar
if (tensor.dim() == 0) {
str += "Scalar";
} else {
const NSString* ns_shape_key = [[getMPSShape(tensor) valueForKey:@"description"] componentsJoinedByString:@","];
str += std::string(ns_shape_key.UTF8String);
}
str += "]";
std::string str;
// The key format per tensor would look like ":Float32[1,1,1,10]:"
for (const Tensor& tensor : tensors) {
str += ":";
if (tensor.defined()) {
str += getMPSTypeString(tensor.scalar_type(), short_dtype) + "[";
// if tensor is a scalar
if (tensor.dim() == 0) {
str += "Scalar";
} else {
str += "Undefined";
const NSString* ns_shape_key = [[getMPSShape(tensor) valueForKey:@"description"] componentsJoinedByString:@","];
str += std::string(ns_shape_key.UTF8String);
}
str += "]";
} else {
str += "Undefined";
}
return str;
}
return str;
}

MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format) {
@ -216,7 +223,7 @@ MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format) {
const NSUInteger C = sizes[1];
const NSUInteger H = sizes[2];
const NSUInteger W = sizes[3];
return @[@(N), @(H), @(W), @(C)];
return @[ @(N), @(H), @(W), @(C) ];
}
const int sz = sizes.size();
const int sz_ = (sz > 0) ? sz : 1;
@ -232,27 +239,27 @@ MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format) {
}

void printTensorNDArray(const Tensor& t) {
if (!t.is_mps()) return;
if(t.numel() == 0) return;
if (!t.is_mps())
return;
if (t.numel() == 0)
return;
// Get shape and data type
auto selfShape = getMPSShape(t);
auto selfDType = getMPSDataType(t.scalar_type());

// Initialize data
id<MTLBuffer> selfBuf = getMTLBufferStorage(t);
MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf
shape:selfShape
dataType:selfDType] autorelease];
MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf shape:selfShape
dataType:selfDType] autorelease];
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wobjc-method-access")
#if C10_CLANG_HAS_WARNING("-Wobjc-method-access")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wobjc-method-access")
#endif
#endif
[tdata printNDArray];
C10_CLANG_DIAGNOSTIC_POP()
}

MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType mpsType)
{
MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape* shape, MPSDataType mpsType) {
id<MTLBuffer> buffer = getMTLBufferStorage(tensor);
MPSGraphTensorData* tmpGraphTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buffer
shape:shape
@ -261,16 +268,19 @@ MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType
return [tmpGraphTensorData mpsndarray];
}

Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSShape *mpsShape,
bool gatherTensorData, MPSDataType dataType) : _tensor(src)
{
Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
const Tensor& src,
MPSShape* mpsShape,
bool gatherTensorData,
MPSDataType dataType)
: _tensor(src) {
TORCH_CHECK(src.is_mps(), "Placeholder storage has not been allocated on MPS device!");
// extract the pointer to MTLBuffer from the Tensor's storage
id<MTLBuffer> srcBuf = getMTLBufferStorage(src);
bool sliceViewTensor = canSliceViewTensor(src, mpsShape);
// a view tensor could be contiguous (e.g., slice ops) or non-contiguous (e.g., transpose())
if ((!src.is_contiguous() || (src.storage_offset() && !sliceViewTensor)) && gatherTensorData) {
Tensor emptyShell = Tensor();
Tensor emptyShell = Tensor();
// use "_tensor" from Placeholder to retain view's output during its usage in other ops
_tensor = gatherViewTensor(src, emptyShell);
if (!_tensor.has_storage()) {
@ -285,8 +295,9 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSS
// if buffer size is zero in here, it's not a user error. It could be a missing check for
// tensor.numel() == 0 in our internal implementations of ops.
TORCH_INTERNAL_ASSERT([srcBuf length] > 0, "Placeholder tensor is empty!");
const MPSDataType mpsDataType = dataType != MPSDataTypeInvalid ? dataType :
_tensor.dim() == 0 ? getMPSScalarType(_tensor.scalar_type()) : getMPSDataType(_tensor.scalar_type());
const MPSDataType mpsDataType = dataType != MPSDataTypeInvalid ? dataType
: _tensor.dim() == 0 ? getMPSScalarType(_tensor.scalar_type())
: getMPSDataType(_tensor.scalar_type());

if (src.is_contiguous() && src.storage_offset() && sliceViewTensor) {
_value = getMPSGraphTensorDataForView(src, mpsShape, mpsDataType);
@ -295,34 +306,25 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSS
mpsShape = getMPSShape(_tensor);
}

_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
shape:mpsShape
dataType:mpsDataType] autorelease];
_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf shape:mpsShape dataType:mpsDataType] autorelease];
}

TORCH_INTERNAL_ASSERT(_value);
_placeholder = mpsGraphTensor;
}

MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph,
MPSStream* mpsStream,
const Tensor& tensor) {
MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor) {
auto mpsShape = getMPSShape(tensor);
auto dataType = getMPSDataType(tensor.scalar_type());

MPSGraphTensorData *result = nil;
MPSGraphTensorData* result = nil;
if (tensor.numel() > 0) {
id<MTLBuffer> buf = getMTLBufferStorage(tensor);
result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buf
shape:mpsShape
dataType:dataType]
autorelease];
result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buf shape:mpsShape dataType:dataType] autorelease];
} else {
// create empty NDArray
MPSNDArrayDescriptor *desc = [MPSNDArrayDescriptor descriptorWithDataType:dataType
shape:mpsShape];
MPSNDArray *emptyArray = [[[MPSNDArray alloc]
initWithDevice:mpsStream->device() descriptor:desc] autorelease];
MPSNDArrayDescriptor* desc = [MPSNDArrayDescriptor descriptorWithDataType:dataType shape:mpsShape];
MPSNDArray* emptyArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device() descriptor:desc] autorelease];
result = [[[MPSGraphTensorData alloc] initWithMPSNDArray:emptyArray] autorelease];
}
assert(result);
@ -332,30 +334,40 @@ MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph,
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type) {
switch (type) {
case ScalarType::Double:
case ScalarType::Float: return {.value.f = scalar.to<float>() , .size = sizeof(float) , .type = type};
case ScalarType::Half: return {.value.h = scalar.to<at::Half>(), .size = sizeof(short) , .type = type};
case ScalarType::Long: return {.value.i = scalar.to<int64_t>() , .size = sizeof(int64_t), .type = type};
case ScalarType::Int: return {.value.i = scalar.to<int32_t>() , .size = sizeof(int32_t), .type = type};
case ScalarType::Short: return {.value.i = scalar.to<int16_t>() , .size = sizeof(int16_t), .type = type};
case ScalarType::Char: return {.value.i = scalar.to<int8_t>() , .size = sizeof(int8_t) , .type = type};
case ScalarType::Byte: return {.value.i = scalar.to<uint8_t>() , .size = sizeof(uint8_t), .type = type};
case ScalarType::Bool: return {.value.b = scalar.to<bool>() , .size = sizeof(bool) , .type = type};
case ScalarType::Float:
return {.value.f = scalar.to<float>(), .size = sizeof(float), .type = type};
case ScalarType::Half:
return {.value.h = scalar.to<at::Half>(), .size = sizeof(short), .type = type};
case ScalarType::Long:
return {.value.i = scalar.to<int64_t>(), .size = sizeof(int64_t), .type = type};
case ScalarType::Int:
return {.value.i = scalar.to<int32_t>(), .size = sizeof(int32_t), .type = type};
case ScalarType::Short:
return {.value.i = scalar.to<int16_t>(), .size = sizeof(int16_t), .type = type};
case ScalarType::Char:
return {.value.i = scalar.to<int8_t>(), .size = sizeof(int8_t), .type = type};
case ScalarType::Byte:
return {.value.i = scalar.to<uint8_t>(), .size = sizeof(uint8_t), .type = type};
case ScalarType::Bool:
return {.value.b = scalar.to<bool>(), .size = sizeof(bool), .type = type};
default:
TORCH_INTERNAL_ASSERT(false, "Unsupported scalar type '", type, "' on MPS backend.");
}
}

MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar) {
MPSGraphTensorData *result = nullptr;
MPSGraphTensorData* result = nullptr;
// Scalar pools are only supported on devices with unified memory
if (mpsStream->device().hasUnifiedMemory) {
scalar.buffer = getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size);
result = [[[MPSGraphTensorData alloc] initWithMTLBuffer: scalar.getMTLBuffer()
shape: @[@1]
dataType: getMPSScalarType(scalar.type)] autorelease];
result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:scalar.getMTLBuffer()
shape:@[ @1 ]
dataType:getMPSScalarType(scalar.type)] autorelease];
} else {
MPSNDArrayDescriptor *tensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:getMPSScalarType(scalar.type) shape:@[@1]];
MPSNDArray *tensorNDArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device() descriptor:tensorDesc] autorelease];
MPSNDArrayDescriptor* tensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:getMPSScalarType(scalar.type)
shape:@[ @1 ]];
MPSNDArray* tensorNDArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device()
descriptor:tensorDesc] autorelease];
[tensorNDArray writeBytes:&scalar.value strideBytes:nil];
result = [[[MPSGraphTensorData alloc] initWithMPSNDArray:tensorNDArray] autorelease];
}
@ -371,58 +383,50 @@ MPSGraph* make_mps_graph() {
return mpsGraph;
}

MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType) {
return [mpsGraph placeholderWithShape:nil
dataType:dataType
name:nil];
MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType) {
return [mpsGraph placeholderWithShape:nil dataType:dataType name:nil];
}

MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape) {
return [mpsGraph placeholderWithShape:mpsShape
dataType:dataType
name:nil];
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape) {
return [mpsGraph placeholderWithShape:mpsShape dataType:dataType name:nil];
}

MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor) {
return [mpsGraph placeholderWithShape:getMPSShape(tensor)
dataType:getMPSScalarType(tensor.scalar_type())
name:nil];
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const Tensor& tensor) {
return [mpsGraph placeholderWithShape:getMPSShape(tensor) dataType:getMPSScalarType(tensor.scalar_type()) name:nil];
}

MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType) {
return [mpsGraph placeholderWithShape:@[@1]
dataType:dataType
name:nil];
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType) {
return [mpsGraph placeholderWithShape:@[ @1 ] dataType:dataType name:nil];
}

MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar) {
return [mpsGraph placeholderWithShape:@[@1]
dataType:getMPSScalarType(scalar.type())
name:nil];
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar) {
return [mpsGraph placeholderWithShape:@[ @1 ] dataType:getMPSScalarType(scalar.type()) name:nil];
}

// this is meant to suppress the availability warning on castTensor
// we pass ScalarType instead of MPSDataType to handle MPSDataTypeBoolean's availability too
MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) {
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) {
if ([tensor dataType] == toType) {
return tensor;
}
return [mpsGraph castTensor:tensor toType:toType name:@"castTensor"];
}

MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType) {
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, ScalarType toType) {
return [mpsGraph castTensor:tensor toType:getMPSScalarType(toType) name:@"castTensor"];
}

MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor) {
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor) {
TORCH_INTERNAL_ASSERT(tensor.shape.count == 4, "Tensor must have 4 dimensions!");
return [mpsGraph transposeTensor:[mpsGraph transposeTensor:tensor dimension:3 withDimension:2 name:nil]
dimension:2 withDimension:1 name: nil];
dimension:2
withDimension:1
name:nil];
}

string get_mem_format_string(c10::MemoryFormat memory_format) {
string mem_format_key;
switch(memory_format) {
switch (memory_format) {
case at::MemoryFormat::Contiguous:
mem_format_key = "Contiguous";
break;
@ -439,11 +443,12 @@ string get_mem_format_string(c10::MemoryFormat memory_format) {
MPSGraphCache* MPSGraphCache::_instance_cache = nullptr;

class MPSGraphCacheCallback : public IMpsAllocatorCallback {
public:
MPSGraphCacheCallback() : graph_cache(MPSGraphCache::getInstance()) { }
public:
MPSGraphCacheCallback() : graph_cache(MPSGraphCache::getInstance()) {}

void executeMPSAllocatorCallback(void* ptr, EventType event) override { }
private:
void executeMPSAllocatorCallback(void* ptr, EventType event) override {}

private:
MPSGraphCache* graph_cache;
};

File diff suppressed because it is too large
@ -1,52 +1,54 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {

void set_kernel_params
(int64_t isizeH, int64_t isizeW,
int64_t osizeH, int64_t osizeW,
int64_t &strideH, int64_t &strideW,
int64_t &kernel_sizeH, int64_t &kernel_sizeW,
bool check_avg_pooling = false) {

void set_kernel_params(int64_t isizeH,
int64_t isizeW,
int64_t osizeH,
int64_t osizeW,
int64_t& strideH,
int64_t& strideW,
int64_t& kernel_sizeH,
int64_t& kernel_sizeW,
bool check_avg_pooling = false) {
TORCH_CHECK((isizeH >= osizeH && isizeW >= osizeW) || (isizeH <= osizeH && isizeW <= osizeW),
"Adaptive pool MPS: Input height and width must both be greater than, "
"or equal to, or lesser than output height and width")

if(isizeH >= osizeH) {
if (isizeH >= osizeH) {
if (check_avg_pooling) {
TORCH_CHECK((isizeH % osizeH == 0 && isizeW % osizeW == 0),
"Adaptive pool MPS: input sizes must be divisible by output sizes.");
"Adaptive pool MPS: input sizes must be divisible by output sizes.");
}
strideH = (int64_t) (isizeH / osizeH);
strideW = (int64_t) (isizeW / osizeW);
kernel_sizeH = isizeH - (osizeH-1) * strideH;
kernel_sizeW = isizeW - (osizeW-1) * strideW;
strideH = (int64_t)(isizeH / osizeH);
strideW = (int64_t)(isizeW / osizeW);
kernel_sizeH = isizeH - (osizeH - 1) * strideH;
kernel_sizeW = isizeW - (osizeW - 1) * strideW;
} else {
if (check_avg_pooling) {
TORCH_CHECK((osizeH % isizeH == 0 && osizeW % isizeW == 0),
"Adaptive pool MPS: output sizes must be divisible by input sizes.");
}
strideH = (int64_t) (osizeH / isizeH);
strideW = (int64_t) (osizeW / isizeW);
kernel_sizeH = osizeH - (isizeH-1) * strideH;
kernel_sizeW = osizeW - (isizeW-1) * strideW;
strideH = (int64_t)(osizeH / isizeH);
strideW = (int64_t)(osizeW / isizeW);
kernel_sizeH = osizeH - (isizeH - 1) * strideH;
kernel_sizeW = osizeW - (isizeW - 1) * strideW;
}
}

// Adaptive average pooling
Tensor& adaptive_avg_pool2d_out_mps
(const Tensor& input,
IntArrayRef output_size,
Tensor& output) {

Tensor& adaptive_avg_pool2d_out_mps(const Tensor& input, IntArrayRef output_size, Tensor& output) {
for (int64_t i = 1; i < input.ndimension(); i++) {
TORCH_CHECK(input.size(i) > 0,
"adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
"but input has sizes ", input.sizes(), " with dimension ", i, " being empty");
"adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
"but input has sizes ",
input.sizes(),
" with dimension ",
i,
" being empty");
}

int64_t isizeH = input.size(-2);
@ -57,45 +59,39 @@ Tensor& adaptive_avg_pool2d_out_mps
int64_t strideH = 0, strideW = 0;
int64_t kernel_sizeH = 0, kernel_sizeW = 0;

set_kernel_params(isizeH, isizeW,
osizeH, osizeW,
strideH, strideW,
kernel_sizeH, kernel_sizeW, true);
set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW, true);

if(isizeH >= osizeH) {
output = at::avg_pool2d(input,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
IntArrayRef({0, 0}),
false,
true,
c10::nullopt);
} else {
if (isizeH >= osizeH) {
output = at::avg_pool2d(input,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
IntArrayRef({0, 0}),
false,
true,
c10::nullopt);
} else {
Tensor phony_grad = at::ones_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto input_sizes = input.sizes();
std::vector<int64_t> phony_shape{input_sizes.begin(), input_sizes.end() -2};
std::vector<int64_t> phony_shape{input_sizes.begin(), input_sizes.end() - 2};
phony_shape.push_back(output_size[0]);
phony_shape.push_back(output_size[1]);
phony_grad.resize_(IntArrayRef(phony_shape));
output = at::avg_pool2d_backward(input,
phony_grad,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
IntArrayRef({0, 0}),
false,
true,
c10::nullopt);
output = at::avg_pool2d_backward(input,
phony_grad,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
IntArrayRef({0, 0}),
false,
true,
c10::nullopt);
// Multiply output by kernel size
output = at::mul(output, kernel_sizeH*kernel_sizeW);
output = at::mul(output, kernel_sizeH * kernel_sizeW);
}

return output;
}

Tensor adaptive_avg_pool2d_mps
(at::Tensor const& input,
IntArrayRef output_size) {

Tensor adaptive_avg_pool2d_mps(at::Tensor const& input, IntArrayRef output_size) {
IntArrayRef output_shape;

auto osizeH = output_size[0];
@ -103,7 +99,7 @@ Tensor adaptive_avg_pool2d_mps

std::vector<long long> out_dims = {};

if(input.ndimension() == 4) {
if (input.ndimension() == 4) {
auto sizeB = input.size(0);
auto sizeD = input.size(1);

@ -112,8 +108,7 @@ Tensor adaptive_avg_pool2d_mps
out_dims.push_back(osizeH);
out_dims.push_back(osizeW);
output_shape = IntArrayRef(out_dims);
}
else {
} else {
auto sizeD = input.size(0);
out_dims.push_back(sizeD);
out_dims.push_back(osizeH);
@ -122,21 +117,12 @@ Tensor adaptive_avg_pool2d_mps
}

const auto memory_format = input.suggest_memory_format();
Tensor output = at::native::empty_mps(
output_shape,
input.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
memory_format);
Tensor output =
at::native::empty_mps(output_shape, input.scalar_type(), c10::nullopt, kMPS, c10::nullopt, memory_format);
return adaptive_avg_pool2d_out_mps(input, output_size, output);

}

Tensor adaptive_avg_pool2d_backward_mps
(const Tensor& gradOutput,
const Tensor& input) {

Tensor adaptive_avg_pool2d_backward_mps(const Tensor& gradOutput, const Tensor& input) {
int64_t isizeH = input.size(-2);
int64_t isizeW = input.size(-1);
int64_t osizeH = gradOutput.size(-2);
@ -145,14 +131,11 @@ Tensor adaptive_avg_pool2d_backward_mps
int64_t strideH = 0, strideW = 0;
int64_t kernel_sizeH = 0, kernel_sizeW = 0;

set_kernel_params(isizeH, isizeW,
osizeH, osizeW,
strideH, strideW,
kernel_sizeH, kernel_sizeW, true);
set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW, true);

auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
if (gradInput.numel() != 0) {
if(isizeH >= osizeH) {
if (isizeH >= osizeH) {
gradInput = at::avg_pool2d_backward(gradOutput,
input,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
@ -163,13 +146,13 @@ Tensor adaptive_avg_pool2d_backward_mps
c10::nullopt);
} else {
gradInput = at::avg_pool2d(gradOutput,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
IntArrayRef({0, 0}),
false,
true,
c10::nullopt);
gradInput = at::mul(gradInput, kernel_sizeH*kernel_sizeW);
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
IntArrayRef({0, 0}),
false,
true,
c10::nullopt);
gradInput = at::mul(gradInput, kernel_sizeH * kernel_sizeW);
}
}

@ -178,16 +161,16 @@ Tensor adaptive_avg_pool2d_backward_mps

// Adaptive max pooling
TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
(const Tensor& input,
IntArrayRef output_size,
const Tensor& output,
const Tensor& indices) {

(const Tensor& input, IntArrayRef output_size, const Tensor& output, const Tensor& indices) {
for (int64_t i = 1; i < input.ndimension(); i++) {
TORCH_CHECK(input.size(i) > 0,
"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
"but input has sizes ", input.sizes(), " with dimension ", i, " being "
"empty");
"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
"but input has sizes ",
input.sizes(),
" with dimension ",
i,
" being "
"empty");
}

int64_t isizeH = input.size(-2);
@ -198,13 +181,11 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
int64_t strideH = 0, strideW = 0;
int64_t kernel_sizeH = 0, kernel_sizeW = 0;

set_kernel_params(isizeH, isizeW,
osizeH, osizeW,
strideH, strideW,
kernel_sizeH, kernel_sizeW);
set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW);

at::max_pool2d_with_indices_out(const_cast<Tensor&>(output),
const_cast<Tensor&>(indices), input,
const_cast<Tensor&>(indices),
input,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
IntArrayRef({0, 0}),
@ -213,11 +194,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
}

TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_mps)
(const Tensor& gradOutput,
const Tensor& input,
const Tensor& indices,
const Tensor& gradInput) {

(const Tensor& gradOutput, const Tensor& input, const Tensor& indices, const Tensor& gradInput) {
int64_t isizeH = input.size(-2);
int64_t isizeW = input.size(-1);
int64_t osizeH = gradOutput.size(-2);
@ -226,13 +203,11 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_mps)
int64_t strideH = 0, strideW = 0;
int64_t kernel_sizeH = 0, kernel_sizeW = 0;

set_kernel_params(isizeH, isizeW,
osizeH, osizeW,
strideH, strideW,
kernel_sizeH, kernel_sizeW);
set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW);

at::max_pool2d_with_indices_backward_out(const_cast<Tensor&>(gradInput),
gradOutput, input,
gradOutput,
input,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
IntArrayRef({0, 0}),
@ -1,5 +1,5 @@
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {
namespace mps {
@ -124,10 +124,10 @@ static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device) {
return binaryLibrary;
}

NSError *error = nil;
MTLCompileOptions *options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion: MTLLanguageVersion2_3];
binaryLibrary = [device newLibraryWithSource:[NSString stringWithCString: METAL_BINARY encoding:NSASCIIStringEncoding]
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
binaryLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_BINARY encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(binaryLibrary, "Failed to create metal binary library, error: ", [[error description] UTF8String]);
@ -159,15 +159,15 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
Tensor other = iter.input(1);
Tensor out = iter.output();

id<MTLBuffer> inputBuffer = getMTLBufferStorage(input);
id<MTLBuffer> otherBuffer = getMTLBufferStorage(other);
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input);
id<MTLBuffer> otherBuffer = getMTLBufferStorage(other);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(out);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const uint32_t nDim = iter.ndim();
constexpr uint32_t nOffsets = 3;
const uint32_t numThreads = iter.numel();
dispatch_sync(mpsStream->queue(), ^(){
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
NSError* error = nil;
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
@ -177,23 +177,25 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
std::vector<uint32_t> iterShapeData(iterShape.size());
std::vector<std::array<uint32_t, nOffsets>> strides(nDim);

for (const auto i: c10::irange(iterShape.size())) {
for (const auto i : c10::irange(iterShape.size())) {
TORCH_CHECK(i <= UINT32_MAX);
iterShapeData[i] = (uint32_t)(iterShape[i]);
}

for (const auto i: c10::irange(nDim)) {
for (const auto offset: c10::irange(nOffsets)) {
strides[i][offset] = iter.strides(offset)[i];
for (const auto i : c10::irange(nDim)) {
for (const auto offset : c10::irange(nOffsets)) {
strides[i][offset] = iter.strides(offset)[i];
}
}

id<MTLFunction> kernelDataOffsetsFunction = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO = [[device newComputePipelineStateWithFunction: kernelDataOffsetsFunction
error: &error] autorelease];
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength: numThreads * sizeof(simd_uint3)
options: 0] autorelease];
TORCH_CHECK(kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
id<MTLFunction> kernelDataOffsetsFunction =
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO =
[[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
options:0] autorelease];
TORCH_CHECK(
kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
[computeEncoder setComputePipelineState:kernelDataOffsetsPSO];
[computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0];
[computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1];
@ -203,28 +205,26 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {

NSUInteger kernelOffsetsTGSize = kernelDataOffsetsPSO.maxTotalThreadsPerThreadgroup;
if (kernelOffsetsTGSize > numThreads)
kernelOffsetsTGSize = numThreads;
kernelOffsetsTGSize = numThreads;

MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: kernelOffsetsThreadGroupSize];
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:kernelOffsetsThreadGroupSize];

const std::string kernel = func_name + "_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> binaryPSO = binaryPipelineState(device, kernel);
[computeEncoder setComputePipelineState:binaryPSO];
[computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0];
[computeEncoder setBuffer:otherBuffer offset:other.storage_offset() * other.element_size() atIndex:1];
[computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0];
[computeEncoder setBuffer:otherBuffer offset:other.storage_offset() * other.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:out.storage_offset() * out.element_size() atIndex:2];
[computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3];

NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > numThreads) {
tgSize = numThreads;
tgSize = numThreads;
}

MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: threadGroupSize];
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];

[computeEncoder endEncoding];
mpsStream->commit(true);
@ -234,22 +234,22 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
} // namespace mps

void fmax_mps_kernel(TensorIteratorBase& iter) {
if (isFloatingType(iter.common_dtype())) {
mps::binary_mps_impl(iter, "fmax");
} else {
at::maximum_out(const_cast<Tensor&>(iter.output()), iter.input(0), iter.input(1));
}
if (isFloatingType(iter.common_dtype())) {
mps::binary_mps_impl(iter, "fmax");
} else {
at::maximum_out(const_cast<Tensor&>(iter.output()), iter.input(0), iter.input(1));
}
}
void fmin_mps_kernel(TensorIteratorBase& iter) {
if (isFloatingType(iter.common_dtype())) {
mps::binary_mps_impl(iter, "fmin");
} else {
at::minimum_out(const_cast<Tensor&>(iter.output()), iter.input(0), iter.input(1));
}
if (isFloatingType(iter.common_dtype())) {
mps::binary_mps_impl(iter, "fmin");
} else {
at::minimum_out(const_cast<Tensor&>(iter.output()), iter.input(0), iter.input(1));
}
}

void copysign_mps_kernel(TensorIteratorBase& iter) {
mps::binary_mps_impl(iter, "copysign");
mps::binary_mps_impl(iter, "copysign");
}

REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel);
@ -4,34 +4,40 @@
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>
#include <c10/util/Optional.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <c10/util/Optional.h>
#include <torch/library.h>

namespace at::native {
namespace mps {

struct BinaryOpCachedGraph : public MPSCachedGraph
{
BinaryOpCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct BinaryOpCachedGraph : public MPSCachedGraph {
BinaryOpCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *primaryTensor = nil, *secondaryTensor = nil;
MPSGraphTensor *alphaTensor = nil, *outputTensor = nil;
};

typedef MPSGraphTensor* (^BinaryOpBlock)(BinaryOpCachedGraph*, MPSGraphTensor*, MPSGraphTensor*);
#define BinaryOpFn(graph, primary, secondary) MPSGraphTensor* (mps::BinaryOpCachedGraph* graph, MPSGraphTensor* primary, MPSGraphTensor* secondary)
#define BinaryOpFn(graph, primary, secondary) \
MPSGraphTensor*(mps::BinaryOpCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary)

// alpha is always 1.0 except when this function is called from add_sub_template()
void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha,
const Tensor& output_, std::string op_name, BinaryOpBlock binaryBlock)
{
TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte ),
void binaryOpTensor(const Tensor& self,
const Tensor& other,
const Scalar& alpha,
const Tensor& output_,
std::string op_name,
BinaryOpBlock binaryBlock) {
TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte),
"MPS support binary op with uint8 natively starting from macOS 13.0");
TORCH_CHECK(!(op_name == "power" && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS) &&
(self.scalar_type() == ScalarType::Long ||
(other.scalar_type() == ScalarType::Long && (self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))),
"MPS: ", op_name, " op with int64 input is supported natively starting from macOS 13.2");
(self.scalar_type() == ScalarType::Long ||
(other.scalar_type() == ScalarType::Long &&
(self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))),
"MPS: ",
op_name,
" op with int64 input is supported natively starting from macOS 13.2");
MPSStream* mpsStream = getCurrentMPSStream();

const bool is_self_scalar = self.dim() == 0;
@ -39,7 +45,7 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha

auto new_size = at::infer_size(self.sizes(), other.sizes());
if (!output_.sizes().equals(new_size)) {
output_.resize_(new_size);
output_.resize_(new_size);
}

// it's possible to receive empty tensors here
@ -53,7 +59,7 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
if (!output_.is_contiguous()) {
output = output_.contiguous();
needsCopyToOutput = true;
// else, determine if this is an in-place operation on a view output
// else, determine if this is an in-place operation on a view output
} else if (output_.is_view() && (self.is_alias_of(output_) || other.is_alias_of(output_))) {
output = at::native::empty_mps(output_.sizes(), output_.scalar_type(), c10::nullopt, kMPS);
needsCopyToOutput = true;
@ -79,18 +85,20 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = op_name + getTensorsStringKey({self, other, output_});
BinaryOpCachedGraph* cachedGraph = static_cast<BinaryOpCachedGraph *>(cache_->LookUp(key));
BinaryOpCachedGraph* cachedGraph = static_cast<BinaryOpCachedGraph*>(cache_->LookUp(key));

if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph* () {
BinaryOpCachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
BinaryOpCachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new BinaryOpCachedGraph(mpsGraph);
newCachedGraph->primaryTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(inputDataType), getMPSShape(self));
newCachedGraph->secondaryTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(otherDataType), getMPSShape(other));
newCachedGraph->primaryTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(inputDataType), getMPSShape(self));
newCachedGraph->secondaryTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(otherDataType), getMPSShape(other));

MPSGraphTensor* primaryCastTensor = newCachedGraph->primaryTensor;
MPSGraphTensor* primaryCastTensor = newCachedGraph->primaryTensor;
MPSGraphTensor* secondaryCastTensor = newCachedGraph->secondaryTensor;

// this type inference is only required at the time of graph creation
@ -99,10 +107,9 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
// integer inputs must be cast to float, if output is float
if (isFloatingType(outputDataType)) {
common_dtype = outputDataType;
// in boolean comparison ops with signed vs. unsigned integers, we always cast to the unsigned type
// in boolean comparison ops with signed vs. unsigned integers, we always cast to the unsigned type
} else if (outputDataType == ScalarType::Bool &&
(inputDataType == ScalarType::Byte ||
otherDataType == ScalarType::Byte)) {
(inputDataType == ScalarType::Byte || otherDataType == ScalarType::Byte)) {
common_dtype = ScalarType::Byte;
}
}
@ -113,19 +120,19 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
}
newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor);
// Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to int32 tensor
// Output tensor should have been promoted but it remains an int32 tensor
// Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to
// int32 tensor Output tensor should have been promoted but it remains an int32 tensor
if (outputDataType != common_dtype ||
[newCachedGraph->outputTensor dataType] != getMPSDataType(outputDataType)) {
[newCachedGraph->outputTensor dataType] != getMPSDataType(outputDataType)) {
newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, outputDataType);
}
}
return newCachedGraph;
});
cachedGraph = static_cast<BinaryOpCachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<BinaryOpCachedGraph*>(tmpCachedGraph);
}

NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
Placeholder selfPlaceholder;
Placeholder otherPlaceholder;
MPSScalar self_scalar;
@ -136,16 +143,22 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
self_scalar = getMPSScalar(self.item(), inputDataType);
feeds[cachedGraph->primaryTensor] = getMPSGraphTensorFromScalar(mpsStream, self_scalar);
} else {
selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self, /*mpsShape*/nil,
/*gatherTensorData=*/true, getMPSScalarType(inputDataType));
selfPlaceholder = Placeholder(cachedGraph->primaryTensor,
self,
/*mpsShape*/ nil,
/*gatherTensorData=*/true,
getMPSScalarType(inputDataType));
feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData();
}
if (is_other_scalar && !other.is_mps()) {
other_scalar = getMPSScalar(other.item(), otherDataType);
feeds[cachedGraph->secondaryTensor] = getMPSGraphTensorFromScalar(mpsStream, other_scalar);
} else {
otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other, /*mpsShape*/nil,
/*gatherTensorData=*/true, getMPSScalarType(otherDataType));
otherPlaceholder = Placeholder(cachedGraph->secondaryTensor,
other,
/*mpsShape*/ nil,
/*gatherTensorData=*/true,
getMPSScalarType(otherDataType));
feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
}

@ -156,9 +169,8 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
}

Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, needsCopyToOutput ? output : output_);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);

if (needsCopyToOutput) {
@ -167,34 +179,35 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
}
}

void binaryOpScalar(const Tensor& self, const Scalar& other, const Scalar& alpha,
const Tensor& output, std::string op_name, BinaryOpBlock binaryBlock)
{
void binaryOpScalar(const Tensor& self,
const Scalar& other,
const Scalar& alpha,
const Tensor& output,
std::string op_name,
BinaryOpBlock binaryBlock) {
binaryOpTensor(self, wrapped_scalar_tensor(other), alpha, output, op_name, binaryBlock);
}

void div_mode_template(const Tensor& self, const Tensor& other,
void div_mode_template(const Tensor& self,
const Tensor& other,
c10::optional<c10::string_view> rounding_mode,
const Tensor& output, const string op_name)
{
if(rounding_mode.has_value() && *rounding_mode == "trunc"){
TORCH_CHECK(self.scalar_type() != ScalarType::Half,
"MPS: does not support trunc_divide op with float16 input");
const Tensor& output,
const string op_name) {
if (rounding_mode.has_value() && *rounding_mode == "trunc") {
TORCH_CHECK(self.scalar_type() != ScalarType::Half, "MPS: does not support trunc_divide op with float16 input");
}
BinaryOpBlock div_mode_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0;
if(!isFloatInput && rounding_mode.has_value() && (*rounding_mode == "floor" || *rounding_mode == "trunc")) {
primaryCastTensor = [mpsGraph castTensor:primaryCastTensor
toType:MPSDataTypeFloat32
name:@"primaryCastTensor"];
if (!isFloatInput && rounding_mode.has_value() && (*rounding_mode == "floor" || *rounding_mode == "trunc")) {
primaryCastTensor = [mpsGraph castTensor:primaryCastTensor toType:MPSDataTypeFloat32 name:@"primaryCastTensor"];
secondaryCastTensor = [mpsGraph castTensor:secondaryCastTensor
toType:MPSDataTypeFloat32
name:@"secondaryCastTensor"];
}
MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
secondaryTensor:secondaryCastTensor
name:nil];
MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
secondaryTensor:secondaryCastTensor
name:nil];
// Rounding is a no-op for integral types, and also a reasonable workaround
// For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library`
// See https://github.com/pytorch/pytorch/issues/84995
@ -202,14 +215,12 @@ void div_mode_template(const Tensor& self, const Tensor& other,
if (!rounding_mode.has_value() || !isFloatOutput) {
return divTensor;
} else if (*rounding_mode == "trunc") {
auto truncTensor = trunc_tensor(mpsGraph, divTensor);
auto truncTensor = trunc_tensor(mpsGraph, divTensor);
if (op_name == "fmod_mps_out") {
auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:truncTensor
secondaryTensor:secondaryCastTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
secondaryTensor:mulTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:mulTensor name:nil];
}
return truncTensor;
} else if (*rounding_mode == "floor") {
@ -218,22 +229,28 @@ void div_mode_template(const Tensor& self, const Tensor& other,
auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor
secondaryTensor:secondaryCastTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
secondaryTensor:mulTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:mulTensor name:nil];
}
return floorTensor;
}
assert(0 && "Invalid rounding mode\n");
return nullptr;
};
binaryOpTensor(self, other, Scalar(1.0), output, op_name + "_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""), div_mode_op_block);
binaryOpTensor(self,
other,
Scalar(1.0),
output,
op_name + "_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""),
div_mode_op_block);
}

void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output, std::string op_name)
{
void add_sub_template(const Tensor& self,
const Tensor& other,
const Scalar& alpha,
const Tensor& output,
std::string op_name) {
if (alpha.toDouble() == 0.0) {
if (!self.is_alias_of(output)) { // if inplace, no-op
if (!self.is_alias_of(output)) { // if inplace, no-op
const_cast<Tensor&>(output) = self.clone();
}
return;
@ -251,60 +268,79 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp

// if alpha is 1.0, then we don't bother adding another multiply to graph
if (alpha_has_value) {
cachedGraph->alphaTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(other.scalar_type()), @[@1]);
cachedGraph->alphaTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(other.scalar_type()), @[ @1 ]);
secondaryTensor = [mpsGraph multiplicationWithPrimaryTensor:secondaryCastTensor
secondaryTensor:cachedGraph->alphaTensor
name:nil];
}
if (op_name == "add")
return [mpsGraph additionWithPrimaryTensor:primaryCastTensor
secondaryTensor:secondaryTensor
name:nil];
return [mpsGraph additionWithPrimaryTensor:primaryCastTensor secondaryTensor:secondaryTensor name:nil];
else
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
secondaryTensor:secondaryTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:secondaryTensor name:nil];
};
// add alpha's type to the key only if multiply was added to graph
binaryOpTensor(self, other, alpha, output, op_name + "_out_mps:" + (alpha_has_value ? getMPSTypeString(alpha.type()) : ""), add_sub_op_block);
binaryOpTensor(self,
other,
alpha,
output,
op_name + "_out_mps:" + (alpha_has_value ? getMPSTypeString(alpha.type()) : ""),
add_sub_op_block);
}

} // namespace mps

#define CREATE_MPS_BINARY_COMPARISON_OP_FUNC(func_out, func_stub, other_type) \
Tensor& func_out (const Tensor& self, const other_type& other, Tensor& output) { \
mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \
^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \
MPSGraph* mpsGraph = cachedGraph->graph(); \
return [mpsGraph func_stub##WithPrimaryTensor:mps::castMPSTensor(mpsGraph, primaryCastTensor, ScalarType::Bool) \
secondaryTensor:mps::castMPSTensor(mpsGraph, secondaryCastTensor, ScalarType::Bool) \
name:nil]; }); \
return output; \
}
#define CREATE_MPS_BINARY_COMPARISON_OP_FUNC(func_out, func_stub, other_type) \
Tensor& func_out(const Tensor& self, const other_type& other, Tensor& output) { \
mps::binaryOp##other_type( \
self, \
other, \
Scalar(1.0), \
output, \
#func_stub, \
^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \
MPSGraph* mpsGraph = cachedGraph->graph(); \
return [mpsGraph func_stub## \
WithPrimaryTensor:mps::castMPSTensor(mpsGraph, primaryCastTensor, ScalarType::Bool) \
secondaryTensor:mps::castMPSTensor(mpsGraph, secondaryCastTensor, ScalarType::Bool) \
name:nil]; \
}); \
return output; \
}

#define CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(func_out, func_stub, other_type) \
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const other_type& other, const Tensor& output) { \
TORCH_CHECK(!(self.scalar_type() == ScalarType::Long && \
std::string(#func_stub) == "atan2"), \
"MPS does not support ", #func_stub, " op with int64 input") \
mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \
^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \
MPSGraph* mpsGraph = cachedGraph->graph(); \
return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \
secondaryTensor:secondaryCastTensor \
name:nil]; }); \
}
#define CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(func_out, func_stub, other_type) \
TORCH_IMPL_FUNC(func_out)(const Tensor& self, const other_type& other, const Tensor& output) { \
TORCH_CHECK(!(self.scalar_type() == ScalarType::Long && std::string(#func_stub) == "atan2"), \
"MPS does not support ", \
#func_stub, \
" op with int64 input") \
mps::binaryOp##other_type(self, \
other, \
Scalar(1.0), \
output, \
#func_stub, \
^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \
MPSGraph* mpsGraph = cachedGraph->graph(); \
return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \
secondaryTensor:secondaryCastTensor \
name:nil]; \
}); \
}

// output of Boolean Ops will be cast to "MPSDataTypeBool" at the end of binaryOpTensor()
#define CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(func_out, func_stub, other_type) \
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const other_type& other, const Tensor& output) { \
mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \
^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \
MPSGraph* mpsGraph = cachedGraph->graph(); \
return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \
secondaryTensor:secondaryCastTensor \
name:nil]; }); \
}
#define CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(func_out, func_stub, other_type) \
TORCH_IMPL_FUNC(func_out)(const Tensor& self, const other_type& other, const Tensor& output) { \
mps::binaryOp##other_type(self, \
other, \
Scalar(1.0), \
output, \
#func_stub, \
^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \
MPSGraph* mpsGraph = cachedGraph->graph(); \
return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \
secondaryTensor:secondaryCastTensor \
name:nil]; \
}); \
}

// Boolean Binary Ops
CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(eq_scalar_out_mps, equal, Scalar);
@ -332,24 +368,24 @@ CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_and_out_mps, logicalAND, Tensor);
CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_or_out_mps, logicalOR, Tensor);
CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_xor_out_mps, logicalXOR, Tensor);


TORCH_IMPL_FUNC(div_out_mode_mps) (const Tensor& self, const Tensor& other, c10::optional<c10::string_view> rounding_mode, const Tensor& output) {
TORCH_IMPL_FUNC(div_out_mode_mps)
(const Tensor& self, const Tensor& other, c10::optional<c10::string_view> rounding_mode, const Tensor& output) {
mps::div_mode_template(self, other, rounding_mode, output, "div_mode_out");
}

TORCH_IMPL_FUNC(div_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) {
TORCH_IMPL_FUNC(div_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::div_mode_template(self, other, c10::nullopt, output, "div_out");
}

TORCH_IMPL_FUNC(add_out_mps) (const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
TORCH_IMPL_FUNC(add_out_mps)(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
mps::add_sub_template(self, other, alpha, output, "add");
}

TORCH_IMPL_FUNC(sub_out_mps) (const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
TORCH_IMPL_FUNC(sub_out_mps)(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
mps::add_sub_template(self, other, alpha, output, "sub");
}

TORCH_IMPL_FUNC(pow_Scalar_out_mps) (const Scalar& base, const Tensor& exp, const Tensor& out) {
TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const Tensor& out) {
if (base.equal(1.0)) {
out.fill_(1);
} else {
@ -386,21 +422,18 @@ Tensor& floor_divide_mps_(Tensor& self, const Tensor& other) {
return floor_divide_out_mps(self, other, self);
}

TORCH_IMPL_FUNC(remainder_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) {
TORCH_IMPL_FUNC(remainder_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::div_mode_template(self, other, "floor", output, "remainder_out_mps");
}

TORCH_IMPL_FUNC(fmod_mps_out) (const Tensor& self, const Tensor& other, const Tensor& output) {
TORCH_IMPL_FUNC(fmod_mps_out)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::div_mode_template(self, other, "trunc", output, "fmod_mps_out");
}

TORCH_IMPL_FUNC(hypot_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
{
TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0
shape:@[@1]
dataType:primaryCastTensor.dataType];
MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType];
MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor
secondaryTensor:twoTensor
name:nil]
@ -413,46 +446,42 @@ TORCH_IMPL_FUNC(hypot_out_mps) (const Tensor& self, const Tensor& other, const T
mps::binaryOpTensor(self, other, Scalar(1.0), output, "hypot_out_mps", hypot_op_block);
}

TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
{
TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil]
secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil]
name:nil];
MPSGraphTensor* sumTensor =
[mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil]
secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil]
name:nil];
return [mpsGraph logarithmWithTensor:sumTensor name:nil];
};
mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp_out_mps", logaddexp_op_block);
}

TORCH_IMPL_FUNC(logaddexp2_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
{
mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
TORCH_IMPL_FUNC(logaddexp2_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil]
secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil]
name:nil];
MPSGraphTensor* sumTensor =
[mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil]
secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil]
name:nil];
return [mpsGraph logarithmBase2WithTensor:sumTensor name:nil];
};
mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp2_out_mps", logaddexp2_op_block);
}

TORCH_IMPL_FUNC(xlogy_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) {
TORCH_IMPL_FUNC(xlogy_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock xlogy_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
shape:@[@1]
dataType:primaryCastTensor.dataType];
MPSGraphTensor* yIsNaNPredicateTensor = [mpsGraph isNaNWithTensor:secondaryCastTensor
name:nil];
MPSGraphTensor* logyTensor = [mpsGraph logarithmWithTensor:secondaryCastTensor
name:nil];
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType];
MPSGraphTensor* yIsNaNPredicateTensor = [mpsGraph isNaNWithTensor:secondaryCastTensor name:nil];
|
||||
MPSGraphTensor* logyTensor = [mpsGraph logarithmWithTensor:secondaryCastTensor name:nil];
|
||||
MPSGraphTensor* xlogyTensor = [mpsGraph multiplicationWithPrimaryTensor:primaryCastTensor
|
||||
secondaryTensor:logyTensor
|
||||
name:nil];
|
||||
MPSGraphTensor* xEqualZeroPredicateTensor = [mpsGraph equalWithPrimaryTensor:primaryCastTensor
|
||||
secondaryTensor:zeroTensor
|
||||
name:nil];
|
||||
secondaryTensor:zeroTensor
|
||||
name:nil];
|
||||
MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:xEqualZeroPredicateTensor
|
||||
truePredicateTensor:zeroTensor
|
||||
falsePredicateTensor:xlogyTensor
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/ops/logical_not_native.h>
|
||||
#include <fmt/format.h>
|
||||
#include <torch/library.h>
|
||||
@ -86,17 +86,16 @@ kernel void bitwise_not(constant uint& length [[buffer(0)]],
|
||||
}}
|
||||
)METAL";
|
||||
|
||||
|
||||
const std::string& getMetalType(const c10::ScalarType& t) {
|
||||
// Mapping from c10::ScalarType to integral type that can be used for bitwise ops
|
||||
// As bitwise ops sign-agnostic map signed/unsigned char and boolean to the same type
|
||||
static std::unordered_map<c10::ScalarType, std::string> scalar_to_metal_type = {
|
||||
{c10::ScalarType::Long, "long"},
|
||||
{c10::ScalarType::Int, "int"},
|
||||
{c10::ScalarType::Short, "short"},
|
||||
{c10::ScalarType::Byte, "char"},
|
||||
{c10::ScalarType::Char, "char"},
|
||||
{c10::ScalarType::Bool, "char"},
|
||||
{c10::ScalarType::Long, "long"},
|
||||
{c10::ScalarType::Int, "int"},
|
||||
{c10::ScalarType::Short, "short"},
|
||||
{c10::ScalarType::Byte, "char"},
|
||||
{c10::ScalarType::Char, "char"},
|
||||
{c10::ScalarType::Bool, "char"},
|
||||
};
|
||||
|
||||
auto it = scalar_to_metal_type.find(t);
|
||||
@ -112,7 +111,6 @@ const std::string& getMetalType(const c10::Scalar& s) {
|
||||
return getMetalType(s.type());
|
||||
}
|
||||
|
||||
|
||||
static id<MTLLibrary> compileBitwiseOpsLibrary(id<MTLDevice> device,
|
||||
const std::string& t1,
|
||||
const std::string& t2,
|
||||
@ -123,61 +121,60 @@ static id<MTLLibrary> compileBitwiseOpsLibrary(id<MTLDevice> device,
|
||||
if (it != libMap.end()) {
|
||||
return it->second;
|
||||
}
|
||||
NSError *error = nil;
|
||||
MTLCompileOptions *options = [[MTLCompileOptions new] autorelease];
|
||||
[options setLanguageVersion: MTLLanguageVersion2_3];
|
||||
auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(BITWISE_OPS_TEMPLATE, t1, t2, t3).c_str()]
|
||||
options:options
|
||||
error:&error];
|
||||
TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
|
||||
libMap[key] = rc;
|
||||
return rc;
|
||||
NSError* error = nil;
|
||||
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
|
||||
[options setLanguageVersion:MTLLanguageVersion2_3];
|
||||
auto rc =
|
||||
[device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(BITWISE_OPS_TEMPLATE, t1, t2, t3).c_str()]
|
||||
options:options
|
||||
error:&error];
|
||||
TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
|
||||
libMap[key] = rc;
|
||||
return rc;
|
||||
}
|
||||
|
||||
|
||||
static id<MTLComputePipelineState> getCPLState(id<MTLDevice> device,
|
||||
const std::string& t1,
|
||||
const std::string& t2,
|
||||
const std::string& t3,
|
||||
const std::string& fname) {
|
||||
const std::string& t1,
|
||||
const std::string& t2,
|
||||
const std::string& t3,
|
||||
const std::string& fname) {
|
||||
auto key = t1 + t2 + t3 + fname;
|
||||
static std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
|
||||
auto it = cplMap.find(key);
|
||||
if (it != cplMap.end()) {
|
||||
return it->second;
|
||||
return it->second;
|
||||
}
|
||||
NSError *error = nil;
|
||||
NSError* error = nil;
|
||||
auto library = compileBitwiseOpsLibrary(device, t1, t2, t3);
|
||||
id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
|
||||
TORCH_CHECK(func != nil, "Can't get function ", fname);
|
||||
auto rc = [device newComputePipelineStateWithFunction:func error:&error];
|
||||
TORCH_CHECK(rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
|
||||
cplMap[key] = rc;
|
||||
TORCH_CHECK(
|
||||
rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
|
||||
cplMap[key] = rc;
|
||||
return rc;
|
||||
}
|
||||
|
||||
void dispatch1DJob(id<MTLComputeCommandEncoder> commandEncoder, id<MTLComputePipelineState> cplState, uint32_t length)
|
||||
{
|
||||
void dispatch1DJob(id<MTLComputeCommandEncoder> commandEncoder, id<MTLComputePipelineState> cplState, uint32_t length) {
|
||||
uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
|
||||
auto size = MTLSizeMake(length, 1, 1);
|
||||
auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1);
|
||||
[commandEncoder dispatchThreads:size
|
||||
threadsPerThreadgroup:threadGroupSize];
|
||||
[commandEncoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
|
||||
}
|
||||
|
||||
void handle_tensor_tensor_binary_op(const at::Tensor& self, const at::Tensor& other, at::Tensor& output, const std::string& kernel_name) {
|
||||
void handle_tensor_tensor_binary_op(const at::Tensor& self,
|
||||
const at::Tensor& other,
|
||||
at::Tensor& output,
|
||||
const std::string& kernel_name) {
|
||||
using namespace at::mps;
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
id<MTLComputePipelineState> cplState = getCPLState(MPSDevice::getInstance()->device(),
|
||||
getMetalType(output),
|
||||
getMetalType(self),
|
||||
getMetalType(other),
|
||||
kernel_name);
|
||||
id<MTLComputePipelineState> cplState = getCPLState(
|
||||
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(other), kernel_name);
|
||||
uint32_t length = output.numel();
|
||||
if (length == 0) {
|
||||
return;
|
||||
}
|
||||
dispatch_sync(stream->queue(), ^(){
|
||||
dispatch_sync(stream->queue(), ^() {
|
||||
id<MTLCommandBuffer> buffer = stream->commandBuffer();
|
||||
id<MTLComputeCommandEncoder> commandEncoder = [buffer computeCommandEncoder];
|
||||
|
||||
@ -188,29 +185,29 @@ void handle_tensor_tensor_binary_op(const at::Tensor& self, const at::Tensor& ot
|
||||
[commandEncoder pushDebugGroup:[NSString stringWithFormat:@"Dispatch %s kernel", kernel_name.c_str()]];
|
||||
[commandEncoder setComputePipelineState:cplState];
|
||||
[commandEncoder setBytes:&length length:sizeof(length) atIndex:0];
|
||||
[commandEncoder setBuffer:outBuf offset:output.storage_offset()*output.itemsize() atIndex:1];
|
||||
[commandEncoder setBuffer:selfBuf offset:self.storage_offset()*self.itemsize() atIndex:2];
|
||||
[commandEncoder setBuffer:otherBuf offset:other.storage_offset()*other.itemsize() atIndex:3];
|
||||
[commandEncoder setBuffer:outBuf offset:output.storage_offset() * output.itemsize() atIndex:1];
|
||||
[commandEncoder setBuffer:selfBuf offset:self.storage_offset() * self.itemsize() atIndex:2];
|
||||
[commandEncoder setBuffer:otherBuf offset:other.storage_offset() * other.itemsize() atIndex:3];
|
||||
dispatch1DJob(commandEncoder, cplState, length);
|
||||
[commandEncoder endEncoding];
|
||||
stream->commit(true);
|
||||
});
|
||||
}
|
||||
|
||||
void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& other, at::Tensor& output, const std::string& kernel_name) {
|
||||
void handle_tensor_scalar_binary_op(const at::Tensor& self,
|
||||
const at::Scalar& other,
|
||||
at::Tensor& output,
|
||||
const std::string& kernel_name) {
|
||||
using namespace at::mps;
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
id<MTLComputePipelineState> cplState = getCPLState(MPSDevice::getInstance()->device(),
|
||||
getMetalType(output),
|
||||
getMetalType(self),
|
||||
getMetalType(other),
|
||||
kernel_name);
|
||||
id<MTLComputePipelineState> cplState = getCPLState(
|
||||
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(other), kernel_name);
|
||||
uint64_t sval = other.to<int64_t>();
|
||||
uint32_t length = output.numel();
|
||||
if (length == 0) {
|
||||
return;
|
||||
}
|
||||
dispatch_sync(stream->queue(), ^(){
|
||||
dispatch_sync(stream->queue(), ^() {
|
||||
id<MTLCommandBuffer> buffer = stream->commandBuffer();
|
||||
id<MTLComputeCommandEncoder> commandEncoder = [buffer computeCommandEncoder];
|
||||
|
||||
@ -220,8 +217,8 @@ void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& ot
|
||||
[commandEncoder pushDebugGroup:[NSString stringWithFormat:@"Dispatch %s kernel", kernel_name.c_str()]];
|
||||
[commandEncoder setComputePipelineState:cplState];
|
||||
[commandEncoder setBytes:&length length:sizeof(length) atIndex:0];
|
||||
[commandEncoder setBuffer:outBuf offset:output.storage_offset()*output.itemsize() atIndex:1];
|
||||
[commandEncoder setBuffer:selfBuf offset:self.storage_offset()*self.itemsize() atIndex:2];
|
||||
[commandEncoder setBuffer:outBuf offset:output.storage_offset() * output.itemsize() atIndex:1];
|
||||
[commandEncoder setBuffer:selfBuf offset:self.storage_offset() * self.itemsize() atIndex:2];
|
||||
[commandEncoder setBytes:&sval length:sizeof(sval) atIndex:3];
|
||||
dispatch1DJob(commandEncoder, cplState, length);
|
||||
[commandEncoder endEncoding];
|
||||
@ -229,7 +226,10 @@ void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& ot
|
||||
});
|
||||
}
|
||||
|
||||
at::Tensor& _bitwise_op_out_mps (const at::Tensor& self, const at::Tensor& other, at::Tensor& output_, const std::string& op_name) {
|
||||
at::Tensor& _bitwise_op_out_mps(const at::Tensor& self,
|
||||
const at::Tensor& other,
|
||||
at::Tensor& output_,
|
||||
const std::string& op_name) {
|
||||
using namespace at::mps;
|
||||
const bool is_self_scalar = self.dim() == 0;
|
||||
const bool is_other_scalar = other.dim() == 0;
|
||||
@ -264,24 +264,24 @@ at::Tensor& _bitwise_op_out_mps (const at::Tensor& self, const at::Tensor& other
|
||||
fmt::format("bitwise_{}_tensor", op_name));
|
||||
}
|
||||
if (needs_output_copy) {
|
||||
output_.copy_(output);
|
||||
output_.copy_(output);
|
||||
}
|
||||
return output_;
|
||||
}
|
||||
|
||||
at::Tensor& bitwise_and_out_mps (const at::Tensor& self, const at::Tensor& other, at::Tensor& output) {
|
||||
return _bitwise_op_out_mps(self, other, output, "and");
|
||||
at::Tensor& bitwise_and_out_mps(const at::Tensor& self, const at::Tensor& other, at::Tensor& output) {
|
||||
return _bitwise_op_out_mps(self, other, output, "and");
|
||||
}
|
||||
|
||||
at::Tensor& bitwise_or_out_mps (const at::Tensor& self, const at::Tensor& other, at::Tensor& output) {
|
||||
return _bitwise_op_out_mps(self, other, output, "or");
|
||||
at::Tensor& bitwise_or_out_mps(const at::Tensor& self, const at::Tensor& other, at::Tensor& output) {
|
||||
return _bitwise_op_out_mps(self, other, output, "or");
|
||||
}
|
||||
|
||||
at::Tensor& bitwise_xor_out_mps (const at::Tensor& self, const at::Tensor& other, at::Tensor& output) {
|
||||
return _bitwise_op_out_mps(self, other, output, "xor");
|
||||
at::Tensor& bitwise_xor_out_mps(const at::Tensor& self, const at::Tensor& other, at::Tensor& output) {
|
||||
return _bitwise_op_out_mps(self, other, output, "xor");
|
||||
}
|
||||
|
||||
at::Tensor& bitwise_not_out_mps (const at::Tensor& self, at::Tensor& output_) {
|
||||
at::Tensor& bitwise_not_out_mps(const at::Tensor& self, at::Tensor& output_) {
|
||||
// Handle boolean tensor using logical not
|
||||
if (self.scalar_type() == c10::ScalarType::Bool) {
|
||||
return at::native::logical_not_out_mps(self, output_);
|
||||
@ -310,12 +310,9 @@ at::Tensor& bitwise_not_out_mps (const at::Tensor& self, at::Tensor& output_) {
|
||||
}
|
||||
using namespace at::mps;
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
id<MTLComputePipelineState> cplState = getCPLState(MPSDevice::getInstance()->device(),
|
||||
getMetalType(output),
|
||||
getMetalType(self),
|
||||
getMetalType(self),
|
||||
"bitwise_not");
|
||||
dispatch_sync(stream->queue(), ^(){
|
||||
id<MTLComputePipelineState> cplState = getCPLState(
|
||||
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(self), "bitwise_not");
|
||||
dispatch_sync(stream->queue(), ^() {
|
||||
id<MTLCommandBuffer> buffer = stream->commandBuffer();
|
||||
id<MTLComputeCommandEncoder> commandEncoder = [buffer computeCommandEncoder];
|
||||
|
||||
@ -325,20 +322,18 @@ at::Tensor& bitwise_not_out_mps (const at::Tensor& self, at::Tensor& output_) {
|
||||
[commandEncoder pushDebugGroup:@"Dispatch bitwise_not kernel"];
|
||||
[commandEncoder setComputePipelineState:cplState];
|
||||
[commandEncoder setBytes:&length length:sizeof(length) atIndex:0];
|
||||
[commandEncoder setBuffer:outBuf offset:output.storage_offset()*output.itemsize() atIndex:1];
|
||||
[commandEncoder setBuffer:selfBuf offset:self.storage_offset()*self.itemsize() atIndex:2];
|
||||
[commandEncoder setBuffer:outBuf offset:output.storage_offset() * output.itemsize() atIndex:1];
|
||||
[commandEncoder setBuffer:selfBuf offset:self.storage_offset() * self.itemsize() atIndex:2];
|
||||
dispatch1DJob(commandEncoder, cplState, length);
|
||||
[commandEncoder endEncoding];
|
||||
stream->commit(true);
|
||||
});
|
||||
if (needs_output_copy) {
|
||||
output_.copy_(output);
|
||||
output_.copy_(output);
|
||||
}
|
||||
return output_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, MPS, m) {
|
||||
m.impl("bitwise_and.Tensor_out", bitwise_and_out_mps);
|
||||
m.impl("bitwise_or.Tensor_out", bitwise_or_out_mps);
|
||||
|
@ -12,26 +12,19 @@
|
||||
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
#endif
|
||||
|
||||
|
||||
namespace at::native {
|
||||
|
||||
|
||||
Tensor dot_mps(
|
||||
const Tensor &self,
|
||||
const Tensor &other)
|
||||
{
|
||||
|
||||
Tensor dot_mps(const Tensor& self, const Tensor& other) {
|
||||
TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS: dot op doesn't support int64 input")
|
||||
|
||||
using namespace mps;
|
||||
auto output = at::native::empty_mps({}, self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* selfTensor_ = nil;
|
||||
MPSGraphTensor* otherTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* selfTensor_ = nil;
|
||||
MPSGraphTensor* otherTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
};
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@ -40,45 +33,38 @@ Tensor dot_mps(
|
||||
@autoreleasepool {
|
||||
string key = "dot_mps" + getTensorsStringKey({self, other});
|
||||
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
if(!cachedGraph) {
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool{
|
||||
MPSGraph *mpsGraph = mps::make_mps_graph();
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = mps::make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
MPSGraphTensor *otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
|
||||
MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
MPSGraphTensor* otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
|
||||
|
||||
MPSGraphTensor *castSelf = nil;
|
||||
MPSGraphTensor *castOther = nil;
|
||||
MPSGraphTensor* castSelf = nil;
|
||||
MPSGraphTensor* castOther = nil;
|
||||
|
||||
if(self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte
|
||||
|| self.scalar_type() == ScalarType::Char) {
|
||||
castSelf = [mpsGraph castTensor:selfTensor
|
||||
toType:MPSDataTypeInt32
|
||||
name:@"castSelfTensor"];
|
||||
castOther = [mpsGraph castTensor:otherTensor
|
||||
toType:MPSDataTypeInt32
|
||||
name:@"castOtherTensor"];
|
||||
if (self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte ||
|
||||
self.scalar_type() == ScalarType::Char) {
|
||||
castSelf = [mpsGraph castTensor:selfTensor toType:MPSDataTypeInt32 name:@"castSelfTensor"];
|
||||
castOther = [mpsGraph castTensor:otherTensor toType:MPSDataTypeInt32 name:@"castOtherTensor"];
|
||||
} else {
|
||||
castSelf = selfTensor;
|
||||
castOther = otherTensor;
|
||||
}
|
||||
|
||||
MPSGraphTensor *dot = [mpsGraph multiplicationWithPrimaryTensor: castSelf
|
||||
secondaryTensor: castOther
|
||||
name: @"multiplication"];
|
||||
MPSGraphTensor* dot = [mpsGraph multiplicationWithPrimaryTensor:castSelf
|
||||
secondaryTensor:castOther
|
||||
name:@"multiplication"];
|
||||
|
||||
MPSGraphTensor *dotProductTensor = [mpsGraph reductionSumWithTensor: dot
|
||||
axes: nil
|
||||
name: @"dotProduct"];
|
||||
MPSGraphTensor* dotProductTensor = [mpsGraph reductionSumWithTensor:dot axes:nil name:@"dotProduct"];
|
||||
|
||||
if(self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte
|
||||
|| self.scalar_type() == ScalarType::Char)
|
||||
if (self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte ||
|
||||
self.scalar_type() == ScalarType::Char)
|
||||
dotProductTensor = [mpsGraph castTensor:dotProductTensor
|
||||
toType:getMPSDataType(self)
|
||||
name:@"castDotProductTensor"];
|
||||
@ -89,7 +75,7 @@ Tensor dot_mps(
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
|
||||
Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
|
||||
@ -101,9 +87,8 @@ Tensor dot_mps(
|
||||
otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData(),
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -111,14 +96,12 @@ Tensor dot_mps(
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor& addmv_out_mps_impl(
|
||||
const Tensor &self,
|
||||
const Tensor &mat,
|
||||
const Tensor &vec,
|
||||
const Scalar& beta_,
|
||||
const Scalar& alpha_,
|
||||
Tensor& result)
|
||||
{
|
||||
Tensor& addmv_out_mps_impl(const Tensor& self,
|
||||
const Tensor& mat,
|
||||
const Tensor& vec,
|
||||
const Scalar& beta_,
|
||||
const Scalar& alpha_,
|
||||
Tensor& result) {
|
||||
using namespace mps;
|
||||
|
||||
TORCH_CHECK(mat.is_mps());
|
||||
@ -129,38 +112,35 @@ Tensor& addmv_out_mps_impl(
|
||||
c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
|
||||
auto betaval = beta_.toComplexDouble();
|
||||
|
||||
struct CachedGraph : public mps::MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor *selfTensor_ = nil;
|
||||
MPSGraphTensor *matMulVecTensor_ = nil;
|
||||
MPSGraphTensor *outputTensor_ = nil;
|
||||
struct CachedGraph : public mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* selfTensor_ = nil;
|
||||
MPSGraphTensor* matMulVecTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
};
|
||||
mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
|
||||
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
|
||||
|
||||
MPSStream *stream = at::mps::getCurrentMPSStream();
|
||||
MPSStream* stream = at::mps::getCurrentMPSStream();
|
||||
Tensor matMulVec = mm(mat, vec.unsqueeze(1)).squeeze(1);
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec})
|
||||
+ ":" + to_string(beta_.toDouble())
|
||||
+ ":" + to_string(alpha_.toDouble());
|
||||
string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" + to_string(beta_.toDouble()) +
|
||||
":" + to_string(alpha_.toDouble());
|
||||
CachedGraph* cachedGraph = nil;
|
||||
if(!cachedGraph) {
|
||||
if (!cachedGraph) {
|
||||
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool{
|
||||
MPSGraph *mpsGraph = mps::make_mps_graph();
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = mps::make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor *matMulVecTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, matMulVec);
|
||||
MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
MPSGraphTensor* matMulVecTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, matMulVec);
|
||||
MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
|
||||
// Intermediates for beta and alpha
|
||||
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar: alpha_.toDouble()
|
||||
dataType: getMPSScalarType(mat.scalar_type())];
|
||||
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha_.toDouble()
|
||||
dataType:getMPSScalarType(mat.scalar_type())];
|
||||
|
||||
// Intermediates for multiplying by beta and alpha
|
||||
MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor:matMulVecTensor
|
||||
@ -168,18 +148,17 @@ Tensor& addmv_out_mps_impl(
|
||||
name:@"MM/alpha*(mat@vec)"];
|
||||
newCachedGraph->outputTensor_ = productTimesAlphaTensor;
|
||||
|
||||
if (betaval != 0.0)
|
||||
{
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar: beta_.toDouble()
|
||||
dataType: getMPSScalarType(self.scalar_type())];
|
||||
if (betaval != 0.0) {
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta_.toDouble()
|
||||
dataType:getMPSScalarType(self.scalar_type())];
|
||||
|
||||
MPSGraphTensor* selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor: selfTensor
|
||||
secondaryTensor: betaTensor
|
||||
name: @"MM/beta*input"];
|
||||
MPSGraphTensor* selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:selfTensor
|
||||
secondaryTensor:betaTensor
|
||||
name:@"MM/beta*input"];
|
||||
|
||||
MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor: productTimesAlphaTensor
|
||||
secondaryTensor: selfTimesBetaTensor
|
||||
name: @"MM/beta*input + alpha*(mat@vec)"];
|
||||
MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor
|
||||
secondaryTensor:selfTimesBetaTensor
|
||||
name:@"MM/beta*input + alpha*(mat@vec)"];
|
||||
|
||||
newCachedGraph->outputTensor_ = outputTensor;
|
||||
}
|
||||
@ -189,23 +168,21 @@ Tensor& addmv_out_mps_impl(
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
|
||||
Placeholder matMulVecPlaceholder = Placeholder(cachedGraph->matMulVecTensor_, matMulVec);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
|
||||
|
||||
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =[NSMutableDictionary dictionary];
|
||||
feeds[matMulVecPlaceholder.getMPSGraphTensor()] = matMulVecPlaceholder.getMPSGraphTensorData();
|
||||
if (betaval != 0.0)
|
||||
{
|
||||
Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
|
||||
feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData();
|
||||
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
|
||||
feeds[matMulVecPlaceholder.getMPSGraphTensor()] = matMulVecPlaceholder.getMPSGraphTensorData();
|
||||
if (betaval != 0.0) {
|
||||
Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
|
||||
feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData();
|
||||
}
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -213,7 +190,13 @@ Tensor& addmv_out_mps_impl(
|
||||
return result;
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(addmv_out_mps)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_, const Tensor& result) {
|
||||
TORCH_IMPL_FUNC(addmv_out_mps)
|
||||
(const Tensor& self,
|
||||
const Tensor& mat,
|
||||
const Tensor& vec,
|
||||
const Scalar& beta_,
|
||||
const Scalar& alpha_,
|
||||
const Tensor& result) {
|
||||
addmv_out_mps_impl(self, mat, vec, beta_, alpha_, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
|
@ -18,26 +18,27 @@ Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
|
||||
}
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
};
|
||||
|
||||
MPSGraphCache *cache_ = MPSGraphCache::getInstance();
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble());
|
||||
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool{
|
||||
MPSGraph *mpsGraph = make_mps_graph();
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
auto isBool = self.scalar_type() == c10::ScalarType::Bool;
|
||||
auto isUInt8 = self.scalar_type() == c10::ScalarType::Byte;
|
||||
auto dataType = !isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32;
|
||||
auto dataType =
|
||||
!isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32;
|
||||
// constantWithScalar does not work for boolTypes on MacOS-12.[34]
|
||||
// workaround by filing it as int8 tensor and than casting to bool
|
||||
// See https://github.com/pytorch/pytorch/issues/82427
|
||||
@ -47,17 +48,12 @@ Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
|
||||
MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble()
|
||||
shape:getMPSShape(self)
|
||||
dataType:dataType];
|
||||
MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor
|
||||
name:nil];
|
||||
MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil];
|
||||
if (isBool) {
|
||||
outputTensor = [mpsGraph castTensor:outputTensor
|
||||
toType:MPSDataTypeBool
|
||||
name:@"constWithBool-workaround"];
|
||||
outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"constWithBool-workaround"];
|
||||
}
|
||||
if (isUInt8) {
|
||||
outputTensor = [mpsGraph castTensor:outputTensor
|
||||
toType:MPSDataTypeUInt8
|
||||
name:@"constWithUInt8-workaround"];
|
||||
outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeUInt8 name:@"constWithUInt8-workaround"];
|
||||
}
|
||||
|
||||
newCachedGraph->outputTensor_ = outputTensor;
|
||||
@ -66,13 +62,11 @@ Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
|
||||
});
|
||||
}
|
||||
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_,
|
||||
needsCopyToOutput ? output : self,
|
||||
nullptr, !needsCopyToOutput);
|
||||
Placeholder outputPlaceholder =
|
||||
Placeholder(cachedGraph->outputTensor_, needsCopyToOutput ? output : self, nullptr, !needsCopyToOutput);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), /*feeds*/ nil, results);
|
||||
|
||||
@ -109,7 +103,10 @@ Tensor& fill_scalar_mps(Tensor& self, const Scalar& value) {
|
||||
}
|
||||
|
||||
Tensor& fill_tensor_mps_(Tensor& self, const Tensor& value) {
|
||||
TORCH_CHECK(value.dim() == 0, "fill_ only supports 0-dimension value tensor but got tensor with ", value.dim(), " dimensions.");
|
||||
TORCH_CHECK(value.dim() == 0,
|
||||
"fill_ only supports 0-dimension value tensor but got tensor with ",
|
||||
value.dim(),
|
||||
" dimensions.");
|
||||
Scalar scalar_value = value.item();
|
||||
if (scalar_value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true)
|
||||
return self;
|
||||
|
@ -2,39 +2,51 @@
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/ConvUtils.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
namespace at::native {
|
||||
|
||||
void fill_depthwise_conv_desc(MPSGraphDepthwiseConvolution3DOpDescriptor* descriptor_,
|
||||
NSUInteger strideInX, NSUInteger strideInY,
|
||||
NSUInteger dilationRateInX, NSUInteger dilationRateInY,
|
||||
NSUInteger paddingHorizontal, NSUInteger paddingVertical,
|
||||
c10::MemoryFormat memory_format, NSUInteger groups) {
|
||||
descriptor_.strides = @[@1, [[NSNumber alloc] initWithInteger: strideInY],
|
||||
[[NSNumber alloc] initWithInteger: strideInX]];
|
||||
descriptor_.dilationRates = @[@1, [[NSNumber alloc] initWithInteger: dilationRateInY],
|
||||
[[NSNumber alloc] initWithInteger: dilationRateInX]];
|
||||
NSUInteger strideInX,
|
||||
NSUInteger strideInY,
|
||||
NSUInteger dilationRateInX,
|
||||
NSUInteger dilationRateInY,
|
||||
NSUInteger paddingHorizontal,
|
||||
NSUInteger paddingVertical,
|
||||
c10::MemoryFormat memory_format,
|
||||
NSUInteger groups) {
|
||||
descriptor_.strides =
|
||||
@[ @1, [[NSNumber alloc] initWithInteger:strideInY], [[NSNumber alloc] initWithInteger:strideInX] ];
|
||||
descriptor_.dilationRates =
|
||||
@[ @1, [[NSNumber alloc] initWithInteger:dilationRateInY], [[NSNumber alloc] initWithInteger:dilationRateInX] ];
|
||||
|
||||
descriptor_.paddingStyle = MPSGraphPaddingStyleExplicit;
|
||||
descriptor_.paddingValues = @[@0, @0, [[NSNumber alloc] initWithInteger: paddingVertical], [[NSNumber alloc]
|
||||
initWithInteger: paddingVertical], [[NSNumber alloc]
|
||||
initWithInteger: paddingHorizontal], [[NSNumber alloc]
|
||||
initWithInteger: paddingHorizontal]];
|
||||
descriptor_.paddingValues = @[
|
||||
@0,
|
||||
@0,
|
||||
[[NSNumber alloc] initWithInteger:paddingVertical],
|
||||
[[NSNumber alloc] initWithInteger:paddingVertical],
|
||||
[[NSNumber alloc] initWithInteger:paddingHorizontal],
|
||||
[[NSNumber alloc] initWithInteger:paddingHorizontal]
|
||||
];
|
||||
descriptor_.channelDimensionIndex = -3LL;
|
||||
}
|
||||
|
||||
// Create convolution descriptor
|
||||
void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
|
||||
NSUInteger strideInX, NSUInteger strideInY,
|
||||
NSUInteger dilationRateInX, NSUInteger dilationRateInY,
|
||||
NSUInteger paddingHorizontal, NSUInteger paddingVertical,
|
||||
c10::MemoryFormat memory_format, NSUInteger groups) {
|
||||
NSUInteger strideInX,
|
||||
NSUInteger strideInY,
|
||||
NSUInteger dilationRateInX,
|
||||
NSUInteger dilationRateInY,
|
||||
NSUInteger paddingHorizontal,
|
||||
NSUInteger paddingVertical,
|
||||
c10::MemoryFormat memory_format,
|
||||
NSUInteger groups) {
|
||||
descriptor_.strideInX = strideInX;
|
||||
descriptor_.strideInY = strideInY;
|
||||
descriptor_.dilationRateInX = dilationRateInX;
|
||||
@ -48,64 +60,59 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
|
||||
descriptor_.paddingTop = paddingVertical;
|
||||
descriptor_.paddingBottom = paddingVertical;
|
||||
|
||||
descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ?
|
||||
MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutNHWC;
|
||||
descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ? MPSGraphTensorNamedDataLayoutNCHW
|
||||
: MPSGraphTensorNamedDataLayoutNHWC;
|
||||
|
||||
// PyTorch always uses OIHW memory layout for weights
|
||||
descriptor_.weightsLayout = MPSGraphTensorNamedDataLayoutOIHW;
|
||||
descriptor_.groups = groups;
|
||||
}
|
||||
|
||||
Tensor _mps_convolution_impl(
|
||||
const Tensor& input_t,
|
||||
const Tensor& weight_t,
|
||||
const c10::optional<Tensor>& bias_opt,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef dilation,
|
||||
int64_t groups,
|
||||
c10::optional<IntArrayRef> input_shape) {
|
||||
Tensor _mps_convolution_impl(const Tensor& input_t,
|
||||
const Tensor& weight_t,
|
||||
const c10::optional<Tensor>& bias_opt,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef dilation,
|
||||
int64_t groups,
|
||||
c10::optional<IntArrayRef> input_shape) {
|
||||
TORCH_CHECK(input_t.dim() < 5, "Conv3D is not supported on MPS");
|
||||
TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types");
|
||||
|
||||
namespace native_mps = at::native::mps;
|
||||
CheckedFrom c = "mps_convolution";
|
||||
TensorArg input { input_t, "input", 1 },
|
||||
weight { weight_t, "weight", 2 };
|
||||
TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2};
|
||||
checkAllSameType(c, {input, weight});
|
||||
checkAllSameGPU(c, {input, weight});
|
||||
|
||||
bool bias_defined;
|
||||
|
||||
if(bias_opt == c10::nullopt)
|
||||
if (bias_opt == c10::nullopt)
|
||||
bias_defined = false;
|
||||
else
|
||||
bias_defined = bias_opt->defined();
|
||||
bias_defined = bias_opt->defined();
|
||||
|
||||
auto memory_format = input_t.suggest_memory_format();
|
||||
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
|
||||
auto output_t = at::empty(
|
||||
input_shape.has_value() ?
|
||||
input_shape.value() :
|
||||
conv_output_size(input->sizes(), weight->sizes(),
|
||||
padding, stride, dilation),
|
||||
input->scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
c10::nullopt);
|
||||
auto output_t =
|
||||
at::empty(input_shape.has_value() ? input_shape.value()
|
||||
: conv_output_size(input->sizes(), weight->sizes(), padding, stride, dilation),
|
||||
input->scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
c10::nullopt);
|
||||
|
||||
if (output_t.numel() == 0) {
|
||||
return output_t;
|
||||
}
|
||||
TensorArg output{ output_t, "result", 0 };
|
||||
TensorArg output{output_t, "result", 0};
|
||||
|
||||
convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);
|
||||
|
||||
// Derive from MPSCachedGraph
|
||||
struct CachedGraph : public native_mps::MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
struct CachedGraph : public native_mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* biasTensor_ = nil;
|
||||
MPSGraphTensor* weightTensor_ = nil;
|
||||
@ -117,13 +124,12 @@ Tensor _mps_convolution_impl(
|
||||
auto stream = at::mps::getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
IntArrayRef bias_shape;
|
||||
if(bias_defined)
|
||||
if (bias_defined)
|
||||
bias_shape = bias_opt.value().sizes();
|
||||
|
||||
string mem_format_key;
|
||||
switch(memory_format) {
|
||||
switch (memory_format) {
|
||||
case at::MemoryFormat::Contiguous:
|
||||
mem_format_key = "Contiguous";
|
||||
break;
|
||||
@ -135,76 +141,87 @@ Tensor _mps_convolution_impl(
|
||||
}
|
||||
|
||||
string bias_shape_key;
|
||||
if(bias_defined) {
|
||||
if (bias_defined) {
|
||||
bias_shape_key = to_string(bias_shape[0]);
|
||||
} else {
|
||||
bias_shape_key = "nobias";
|
||||
}
|
||||
|
||||
string key = "mps_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":"
|
||||
+ to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":"
|
||||
+ to_string(padding[0]) + ":" + to_string(padding[1]) + ":"
|
||||
+ to_string(groups) + ":" + mem_format_key
|
||||
+ mps::getTensorsStringKey({input_t, weight_t}) + ":"
|
||||
+ to_string(bias_defined) + ":" + bias_shape_key;
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
string key = "mps_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(dilation[0]) +
|
||||
":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" +
|
||||
to_string(groups) + ":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" +
|
||||
to_string(bias_defined) + ":" + bias_shape_key;
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
MPSShape* inputShape = mps::getMPSShape(input_t, memory_format);
|
||||
if(!cachedGraph) {
|
||||
native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
|
||||
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
if (!cachedGraph) {
|
||||
native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = native_mps::make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphConvolution2DOpDescriptor *conv2dDescriptor_ =[[MPSGraphConvolution2DOpDescriptor new] autorelease];
|
||||
MPSGraphDepthwiseConvolution3DOpDescriptor *depthWiseConv3dDescriptor_ = [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
|
||||
MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
|
||||
MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
|
||||
[[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
|
||||
MPSShape* weightShape = mps::getMPSShape(weight_t);
|
||||
bool isDepthwiseConv = ((groups > 1 && (weightShape[1].intValue == 1)) &&
|
||||
inputShape.count >= 4 && weightShape.count >= 4 && !is_channels_last);
|
||||
if(isDepthwiseConv) {
|
||||
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, stride[1], stride[0],
|
||||
dilation[1], dilation[0],
|
||||
padding[1], padding[0],
|
||||
memory_format, groups);
|
||||
bool isDepthwiseConv = ((groups > 1 && (weightShape[1].intValue == 1)) && inputShape.count >= 4 &&
|
||||
weightShape.count >= 4 && !is_channels_last);
|
||||
if (isDepthwiseConv) {
|
||||
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
|
||||
stride[1],
|
||||
stride[0],
|
||||
dilation[1],
|
||||
dilation[0],
|
||||
padding[1],
|
||||
padding[0],
|
||||
memory_format,
|
||||
groups);
|
||||
} else {
|
||||
fill_conv_desc(conv2dDescriptor_, stride[1], stride[0],
|
||||
dilation[1], dilation[0],
|
||||
padding[1], padding[0],
|
||||
memory_format, groups);
|
||||
fill_conv_desc(conv2dDescriptor_,
|
||||
stride[1],
|
||||
stride[0],
|
||||
dilation[1],
|
||||
dilation[0],
|
||||
padding[1],
|
||||
padding[0],
|
||||
memory_format,
|
||||
groups);
|
||||
}
|
||||
|
||||
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);
|
||||
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);
|
||||
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
|
||||
|
||||
MPSGraphTensor* biasTensor = nil;
|
||||
if(bias_defined) {
|
||||
biasTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(bias_opt.value()));
|
||||
if (bias_defined) {
|
||||
biasTensor =
|
||||
native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(bias_opt.value()));
|
||||
}
|
||||
|
||||
MPSGraphTensor* outputTensor;
|
||||
if(isDepthwiseConv) {
|
||||
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor dimension:-3 withDimension:-4 name:nil];
|
||||
outputTensor = [mpsGraph depthwiseConvolution3DWithSourceTensor: inputTensor
|
||||
weightsTensor: weightTransposeTensor
|
||||
descriptor: depthWiseConv3dDescriptor_
|
||||
name: nil];
|
||||
if (isDepthwiseConv) {
|
||||
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
|
||||
dimension:-3
|
||||
withDimension:-4
|
||||
name:nil];
|
||||
outputTensor = [mpsGraph depthwiseConvolution3DWithSourceTensor:inputTensor
|
||||
weightsTensor:weightTransposeTensor
|
||||
descriptor:depthWiseConv3dDescriptor_
|
||||
name:nil];
|
||||
} else {
|
||||
outputTensor = [mpsGraph convolution2DWithSourceTensor: inputTensor
|
||||
weightsTensor: weightTensor
|
||||
descriptor: conv2dDescriptor_
|
||||
name: nil];
|
||||
outputTensor = [mpsGraph convolution2DWithSourceTensor:inputTensor
|
||||
weightsTensor:weightTensor
|
||||
descriptor:conv2dDescriptor_
|
||||
name:nil];
|
||||
}
|
||||
|
||||
if (is_channels_last) {
|
||||
outputTensor = mps::convertNHWCtoNCHW(mpsGraph, outputTensor);
|
||||
}
|
||||
|
||||
if(bias_defined) {
|
||||
outputTensor = [mpsGraph additionWithPrimaryTensor: outputTensor
|
||||
secondaryTensor: biasTensor
|
||||
name: nil];
|
||||
if (bias_defined) {
|
||||
outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor secondaryTensor:biasTensor name:nil];
|
||||
}
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
newCachedGraph->weightTensor_ = weightTensor;
|
||||
@ -213,27 +230,28 @@ Tensor _mps_convolution_impl(
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
|
||||
auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
|
||||
auto weightsPlaceholder = native_mps::Placeholder(cachedGraph->weightTensor_, weight_t);
|
||||
auto biasPlaceholder = native_mps::Placeholder();
|
||||
// Reshape the bias to be broadcastable with output of conv2d
|
||||
if(bias_defined)
|
||||
biasPlaceholder = native_mps::Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1}));
|
||||
if (bias_defined)
|
||||
biasPlaceholder =
|
||||
native_mps::Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1}));
|
||||
auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, *output);
|
||||
|
||||
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [[[NSMutableDictionary alloc] initWithCapacity: 3] autorelease];
|
||||
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
|
||||
[[[NSMutableDictionary alloc] initWithCapacity:3] autorelease];
|
||||
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
|
||||
feeds[weightsPlaceholder.getMPSGraphTensor()] = weightsPlaceholder.getMPSGraphTensorData();
|
||||
if(bias_defined) {
|
||||
if (bias_defined) {
|
||||
feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData();
|
||||
}
|
||||
|
||||
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -241,40 +259,42 @@ Tensor _mps_convolution_impl(
|
||||
return *output;
|
||||
}
|
||||
|
||||
Tensor _mps_convolution(
|
||||
const Tensor& input_t,
|
||||
const Tensor& weight_t,
|
||||
const c10::optional<Tensor>& bias_opt,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef dilation,
|
||||
int64_t groups) {
|
||||
return _mps_convolution_impl(input_t, weight_t, bias_opt, padding, stride, dilation, groups, c10::nullopt);
|
||||
Tensor _mps_convolution(const Tensor& input_t,
|
||||
const Tensor& weight_t,
|
||||
const c10::optional<Tensor>& bias_opt,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef dilation,
|
||||
int64_t groups) {
|
||||
return _mps_convolution_impl(input_t, weight_t, bias_opt, padding, stride, dilation, groups, c10::nullopt);
|
||||
}
|
||||
|
||||
Tensor mps_convolution_backward_input(
|
||||
IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
|
||||
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
|
||||
Tensor mps_convolution_backward_input(IntArrayRef input_size,
|
||||
const Tensor& grad_output_t,
|
||||
const Tensor& weight_t,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef dilation,
|
||||
int64_t groups,
|
||||
bool bias_defined) {
|
||||
namespace native_mps = at::native::mps;
|
||||
using namespace mps;
|
||||
TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
|
||||
CheckedFrom c = "mps_convolution_backward_input";
|
||||
TensorArg grad_output{ grad_output_t, "grad_output", 1 },
|
||||
weight{ weight_t, "weight", 2 };
|
||||
TensorArg grad_output{grad_output_t, "grad_output", 1}, weight{weight_t, "weight", 2};
|
||||
checkAllSameType(c, {grad_output, weight});
|
||||
checkAllSameGPU(c, {grad_output, weight});
|
||||
auto memory_format = grad_output_t.suggest_memory_format();
|
||||
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
|
||||
auto grad_input_t = at::empty( input_size, grad_output_t.options(), c10::nullopt);
|
||||
auto grad_input_t = at::empty(input_size, grad_output_t.options(), c10::nullopt);
|
||||
|
||||
// Avoid "grad_input" when this is being used as transposed convolution
|
||||
TensorArg grad_input{ grad_input_t, "result", 0 };
|
||||
TensorArg grad_input{grad_input_t, "result", 0};
|
||||
convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);
|
||||
|
||||
// Derive from MPSCachedGraph
|
||||
struct CachedGraph : public native_mps::MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
struct CachedGraph : public native_mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* gradOutputTensor_ = nil;
|
||||
MPSGraphTensor* weightTensor_ = nil;
|
||||
MPSGraphTensor* gradInputTensor_ = nil;
|
||||
@ -284,11 +304,10 @@ Tensor mps_convolution_backward_input(
|
||||
|
||||
// Add backward with input
|
||||
@autoreleasepool {
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
string mem_format_key;
|
||||
switch(memory_format) {
|
||||
switch (memory_format) {
|
||||
case at::MemoryFormat::Contiguous:
|
||||
mem_format_key = "Contiguous";
|
||||
break;
|
||||
@ -302,64 +321,77 @@ Tensor mps_convolution_backward_input(
|
||||
MPSShape* gradOutputShape = getMPSShape(grad_output_t, memory_format);
|
||||
MPSShape* mps_input_shape = getMPSShape(input_size);
|
||||
NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":"
|
||||
+ to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":"
|
||||
+ to_string(padding[0]) + ":" + to_string(padding[1]) + ":"
|
||||
+ to_string(groups) + ":" + mem_format_key
|
||||
+ getTensorsStringKey({grad_output_t, weight_t}) + ":"
|
||||
+ string([ns_shape_key UTF8String]);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
|
||||
if(!cachedGraph) {
|
||||
native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
string key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" +
to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key +
getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));

if (!cachedGraph) {
native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;

@autoreleasepool {
MPSGraph* mpsGraph = native_mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

MPSGraphConvolution2DOpDescriptor *conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor *depthWiseConv3dDescriptor_ = [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
[[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];

MPSShape* weightOutputShape = mps::getMPSShape(weight_t);
// Depthwise conv is input feature channels = groups. So I in OIHW has to be 1.
bool isDepthwiseConv = ((groups > 1 && (weightOutputShape[1].intValue == 1)) &&
gradOutputShape.count >= 4 && weightOutputShape.count >= 4 && !is_channels_last);
bool isDepthwiseConv = ((groups > 1 && (weightOutputShape[1].intValue == 1)) && gradOutputShape.count >= 4 &&
weightOutputShape.count >= 4 && !is_channels_last);

if(isDepthwiseConv) {
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
at::MemoryFormat::Contiguous, groups);
if (isDepthwiseConv) {
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
at::MemoryFormat::Contiguous,
groups);
} else {
fill_conv_desc(conv2dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
at::MemoryFormat::Contiguous, groups);
fill_conv_desc(conv2dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
at::MemoryFormat::Contiguous,
groups);
}

MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(
mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);

MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor;
if (is_channels_last) {
gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
}
MPSGraphTensor* gradInputTensor;
if(isDepthwiseConv) {
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor dimension:-3 withDimension:-4 name:nil];
gradInputTensor = [mpsGraph depthwiseConvolution3DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
weightsTensor:weightTransposeTensor
outputShape:mps_input_shape
descriptor:depthWiseConv3dDescriptor_
name:nil];
if (isDepthwiseConv) {
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
dimension:-3
withDimension:-4
name:nil];
gradInputTensor =
[mpsGraph depthwiseConvolution3DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
weightsTensor:weightTransposeTensor
outputShape:mps_input_shape
descriptor:depthWiseConv3dDescriptor_
name:nil];
} else {
gradInputTensor = [mpsGraph convolution2DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
weightsTensor:weightTensor
outputShape:mps_input_shape
forwardConvolutionDescriptor:conv2dDescriptor_
name:nil];
gradInputTensor = [mpsGraph convolution2DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
weightsTensor:weightTensor
outputShape:mps_input_shape
forwardConvolutionDescriptor:conv2dDescriptor_
name:nil];
}

newCachedGraph->gradOutputTensor_ = gradOutputTensor;
@ -368,30 +400,34 @@ Tensor mps_convolution_backward_input(
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}

auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input);

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
weightsPlaceholder.getMPSGraphTensor() : weightsPlaceholder.getMPSGraphTensorData(),
};

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
return *grad_input;
}

Tensor mps_convolution_backward_weights(
IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
const Tensor& grad_output_t,
const Tensor& input_t,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
bool bias_defined) {
namespace native_mps = at::native::mps;
using namespace mps;
TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
@ -403,27 +439,21 @@ Tensor mps_convolution_backward_weights(

// For uniformity with everything else, although it seems grad_weight
// would be unambiguous too.
TensorArg grad_output{ grad_output_t, "grad_output", 1 };
TensorArg input{ input_t, "input", 2};
TensorArg grad_output{grad_output_t, "grad_output", 1};
TensorArg input{input_t, "input", 2};

checkAllSameType(c, {grad_output, input});
checkAllSameGPU(c, {grad_output, input});

auto grad_weight_t = at::empty(
weight_size,
grad_output_t.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
c10::nullopt);
TensorArg grad_weight{ grad_weight_t, "result", 0 };
auto grad_weight_t =
at::empty(weight_size, grad_output_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
TensorArg grad_weight{grad_weight_t, "result", 0};

convolution_shape_check(c, input, grad_weight, grad_output, padding, stride, dilation, groups);

// Derive from MPSCachedGraph
struct CachedGraph : public native_mps::MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public native_mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* gradWeightTensor_ = nil;
@ -432,11 +462,10 @@ Tensor mps_convolution_backward_weights(
native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();

@autoreleasepool {

MPSStream* stream = getCurrentMPSStream();

string mem_format_key;
switch(memory_format) {
switch (memory_format) {
case at::MemoryFormat::Contiguous:
mem_format_key = "Contiguous";
break;
@ -448,64 +477,79 @@ Tensor mps_convolution_backward_weights(
}
MPSShape* mps_weight_shape = getMPSShape(weight_size);
NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
string key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":"
+ to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":"
+ to_string(padding[0]) + ":" + to_string(padding[1]) + ":"
+ to_string(groups) + ":" + mem_format_key
+ getTensorsStringKey({grad_output_t, input_t}) + ":"
+ string([ns_shape_key UTF8String]);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));

if(!cachedGraph) {
native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
string key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" +
to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key +
getTensorsStringKey({grad_output_t, input_t}) + ":" + string([ns_shape_key UTF8String]);
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));

if (!cachedGraph) {
native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;

@autoreleasepool {
MPSGraph* mpsGraph = native_mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

MPSGraphConvolution2DOpDescriptor *conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor *depthWiseConv3dDescriptor_ = [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
[[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
MPSShape* inputShape = mps::getMPSShape(input_t);
bool isDepthwiseConv = ((groups > 1 && (mps_weight_shape[1].intValue == 1)) && inputShape.count >= 4 && mps_weight_shape.count >= 4 && !is_channels_last);
bool isDepthwiseConv = ((groups > 1 && (mps_weight_shape[1].intValue == 1)) && inputShape.count >= 4 &&
mps_weight_shape.count >= 4 && !is_channels_last);

if(isDepthwiseConv) {
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
at::MemoryFormat::Contiguous, groups);
if (isDepthwiseConv) {
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
at::MemoryFormat::Contiguous,
groups);
} else {
fill_conv_desc(conv2dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
at::MemoryFormat::Contiguous, groups);
fill_conv_desc(conv2dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
at::MemoryFormat::Contiguous,
groups);
}

MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(
mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);

MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor;
if (is_channels_last) {
gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
}

MPSGraphTensor* gradWeightTensor;
if(isDepthwiseConv) {
NSNumber* outputFeatChannelDim = mps_weight_shape[0];
MPSShape* weightShapeTranspose = @[@1, outputFeatChannelDim, mps_weight_shape[2], mps_weight_shape[3]];
MPSGraphTensor* gradWeightTensorTranspose = [mpsGraph depthwiseConvolution3DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
sourceTensor:inputTensor
outputShape:weightShapeTranspose
descriptor:depthWiseConv3dDescriptor_
name:nil];
gradWeightTensor = [mpsGraph transposeTensor:gradWeightTensorTranspose dimension:-3 withDimension:-4 name:nil];
if (isDepthwiseConv) {
NSNumber* outputFeatChannelDim = mps_weight_shape[0];
MPSShape* weightShapeTranspose = @[ @1, outputFeatChannelDim, mps_weight_shape[2], mps_weight_shape[3] ];
MPSGraphTensor* gradWeightTensorTranspose =
[mpsGraph depthwiseConvolution3DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
sourceTensor:inputTensor
outputShape:weightShapeTranspose
descriptor:depthWiseConv3dDescriptor_
name:nil];
gradWeightTensor = [mpsGraph transposeTensor:gradWeightTensorTranspose
dimension:-3
withDimension:-4
name:nil];
} else {
gradWeightTensor = [mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
sourceTensor:inputTensor
outputShape:mps_weight_shape
forwardConvolutionDescriptor:conv2dDescriptor_
name:nil];
gradWeightTensor =
[mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
sourceTensor:inputTensor
outputShape:mps_weight_shape
forwardConvolutionDescriptor:conv2dDescriptor_
name:nil];
}
newCachedGraph->gradOutputTensor_ = gradOutputTensor;
newCachedGraph->inputTensor_ = inputTensor;
@ -513,21 +557,20 @@ Tensor mps_convolution_backward_weights(
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}

auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
auto outputPlaceholder = Placeholder(cachedGraph->gradWeightTensor_, grad_weight_t);

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
};

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -535,10 +578,14 @@ Tensor mps_convolution_backward_weights(
return grad_weight_t;
}

std::tuple<at::Tensor,at::Tensor,at::Tensor> mps_convolution_backward(
const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
std::array<bool,3> output_mask) {
std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_convolution_backward(const at::Tensor& input,
const at::Tensor& grad_output,
const at::Tensor& weight,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
std::array<bool, 3> output_mask) {
Tensor grad_input, grad_weight, grad_bias;
if (input.numel() == 0) {
if (output_mask[0]) {
@ -549,73 +596,85 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> mps_convolution_backward(
}
} else {
if (output_mask[0]) {
grad_input = mps_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]);
grad_input = mps_convolution_backward_input(
input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]);
}
if (output_mask[1]) {
grad_weight = mps_convolution_backward_weights(weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2]);
grad_weight = mps_convolution_backward_weights(
weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2]);
}
}

return std::tuple<Tensor,Tensor,Tensor>{grad_input, grad_weight, grad_bias};
return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

Tensor mps_convolution_transpose_forward(
const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
auto input_size = conv_input_size(grad_output.sizes(), weight.sizes(),
padding, output_padding, stride, dilation, groups);
return mps_convolution_backward_input(input_size, grad_output, weight,
padding, stride, dilation, groups, false);
Tensor mps_convolution_transpose_forward(const Tensor& grad_output,
const Tensor& weight,
IntArrayRef padding,
IntArrayRef output_padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
auto input_size =
conv_input_size(grad_output.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups);
return mps_convolution_backward_input(input_size, grad_output, weight, padding, stride, dilation, groups, false);
}

Tensor _mps_convolution_transpose(
const Tensor& input_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups) {
Tensor _mps_convolution_transpose(const Tensor& input_t,
const Tensor& weight_t,
IntArrayRef padding,
IntArrayRef output_padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
TORCH_CHECK(input_t.dim() < 5, "ConvTranspose 3D is not supported on MPS");

auto output_t = mps_convolution_transpose_forward(
input_t, weight_t, padding, output_padding, stride, dilation, groups);
auto output_t =
mps_convolution_transpose_forward(input_t, weight_t, padding, output_padding, stride, dilation, groups);
return output_t;

}

Tensor mps_convolution_transpose_backward_input(
const Tensor& grad_output_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, IntArrayRef input_shape)
{
return _mps_convolution_impl(
grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups, input_shape);
Tensor mps_convolution_transpose_backward_input(const Tensor& grad_output_t,
const Tensor& weight_t,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
IntArrayRef input_shape) {
return _mps_convolution_impl(grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups, input_shape);
}

Tensor mps_convolution_transpose_backward_weight(
IntArrayRef weight_size,
const Tensor& grad_output_t,
const Tensor& input_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
Tensor mps_convolution_transpose_backward_weight(IntArrayRef weight_size,
const Tensor& grad_output_t,
const Tensor& input_t,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
return mps_convolution_backward_weights(
weight_size, input_t, grad_output_t,
padding, stride, dilation, groups, false);
weight_size, input_t, grad_output_t, padding, stride, dilation, groups, false);
}


std::tuple<Tensor,Tensor> mps_convolution_transpose_backward(
const Tensor& input, const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
std::array<bool,2> output_mask) {
std::tuple<Tensor, Tensor> mps_convolution_transpose_backward(const Tensor& input,
const Tensor& grad_output,
const Tensor& weight,
IntArrayRef padding,
IntArrayRef output_padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
std::array<bool, 2> output_mask) {
Tensor grad_input, grad_weight;
if (output_mask[0]) {
grad_input = mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes());
grad_input =
mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes());
}
if (output_mask[1]) {
grad_weight = mps_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups);
grad_weight = mps_convolution_transpose_backward_weight(
weight.sizes(), grad_output, input, padding, stride, dilation, groups);
}

return std::tuple<Tensor,Tensor>{grad_input, grad_weight};
return std::tuple<Tensor, Tensor>{grad_input, grad_weight};
}


} // namespace at::native
@ -6,10 +6,7 @@
namespace at::native {
namespace mps {

void* pageAlignedBlockPtr(
const void* ptr,
NSUInteger size,
NSUInteger* alignedBlockSize) {
void* pageAlignedBlockPtr(const void* ptr, NSUInteger size, NSUInteger* alignedBlockSize) {
uintptr_t address = (uintptr_t)ptr;
uintptr_t alignedAddress = address & ~(PAGE_SIZE - 1);
uintptr_t alignedEnd = ((address + size) + PAGE_SIZE - 1) & ~(PAGE_SIZE - 1);
@ -26,15 +23,15 @@ void* pageAlignedBlockPtr(
* Computes number of elements one needs to transfer to preserve all the elements
*/
size_t compute_strided_size(const at::Tensor& t) {
size_t rc = 1;
if (t.numel() == 0) {
return 0;
}
for(const auto i: c10::irange(t.dim())) {
assert(t.size(i) > 0);
rc += (t.size(i) - 1) * t.stride(i);
}
return rc;
size_t rc = 1;
if (t.numel() == 0) {
return 0;
}
for (const auto i : c10::irange(t.dim())) {
assert(t.size(i) > 0);
rc += (t.size(i) - 1) * t.stride(i);
}
return rc;
}

bool is_strided_contiguous(const at::Tensor& t) {
@ -43,13 +40,15 @@ bool is_strided_contiguous(const at::Tensor& t) {

// Copy sourceBuffer into destBuffer, casting sourceBuffer to src.scalar_type().
// The shapes and dtypes are taken from dst and src, but their storage pointers are not used.
void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
id<MTLBuffer> destBuffer, id<MTLBuffer> sourceBuffer, bool non_blocking = true) {
void copy_cast_mps(at::Tensor& dst,
const at::Tensor& src,
id<MTLBuffer> destBuffer,
id<MTLBuffer> sourceBuffer,
bool non_blocking = true) {
using namespace mps;

struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
@ -64,11 +63,11 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,

@autoreleasepool {
string key = "copy_cast_mps" + getTensorsStringKey({src, dst});
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));

if (!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
@ -85,23 +84,24 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
MPSGraphTensorData* srcData = [[[MPSGraphTensorData alloc]
initWithMTLBuffer:sourceBuffer shape:srcShape dataType:srcDType]
autorelease];
MPSGraphTensorData* dstData = [[[MPSGraphTensorData alloc]
initWithMTLBuffer:destBuffer shape:dstShape dataType:dstDType]
autorelease];
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{cachedGraph->inputTensor_: srcData};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{cachedGraph->outputTensor_: dstData};
stream->executeMPSGraph(cachedGraph->graph(), feeds, results, !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT_ADAPTIVE);
MPSGraphTensorData* srcData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:sourceBuffer
shape:srcShape
dataType:srcDType] autorelease];
MPSGraphTensorData* dstData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:destBuffer
shape:dstShape
dataType:dstDType] autorelease];
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{cachedGraph->inputTensor_ : srcData};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{cachedGraph->outputTensor_ : dstData};
stream->executeMPSGraph(
cachedGraph->graph(), feeds, results, !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT_ADAPTIVE);
}
}

static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
{
auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) {
auto sameMemFormat =
src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());

id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* stream = getCurrentMPSStream();
@ -152,8 +152,8 @@ static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool
needsBlit = false;
tmpBuffer = destBuffer;
} else if (src.element_size() < dst.element_size()) {
tmp = at::native::empty_mps(dst.sizes(), dst.scalar_type(), c10::nullopt, kMPS);
tmpBuffer = getMTLBufferStorage(tmp);
tmp = at::native::empty_mps(dst.sizes(), dst.scalar_type(), c10::nullopt, kMPS);
tmpBuffer = getMTLBufferStorage(tmp);
}
}

@ -181,15 +181,14 @@ static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool
}

// Copies tensor from cpu to mps backed by identical strided-contiguous data
static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
{
static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bool non_blocking) {
MPSStream* stream = getCurrentMPSStream();
id<MTLDevice> device = MPSDevice::getInstance()->device();
auto dst_byte_offset = dst.storage_offset() * dst.itemsize();
auto src_byte_offset = src.storage_offset() * src.itemsize();
id<MTLBuffer> destBuffer = getMTLBufferStorage(dst);
const size_t size_to_copy = src.nbytes();
const void* host_src = static_cast<char *>(src.storage().data()) + src_byte_offset;
const void* host_src = static_cast<char*>(src.storage().data()) + src_byte_offset;

TORCH_INTERNAL_ASSERT(src.dtype() == dst.dtype() && src.strides() == dst.strides() && is_strided_contiguous(src));

@ -201,17 +200,16 @@ static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bo
void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size_to_copy, &alignedLength);
sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
length:alignedLength
options:options
deallocator:nil];
length:alignedLength
options:options
deallocator:nil];

stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
[sourceBuffer release];
}
}

static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
{
static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) {
// Typecast to dst_ if needed and expand, which is a no-op
Tensor src = (src_.dtype() != dst_.dtype() ? src_.to(dst_.dtype()) : src_).expand_as(dst_);

@ -233,7 +231,7 @@ static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool n
dst = at::empty_like(src, at::device(at::kMPS));
}
copy_to_mps_stride_contig(dst, src, non_blocking && !needs_copy);
return needs_copy? dst_.copy_(dst) : dst_;
return needs_copy ? dst_.copy_(dst) : dst_;
}

void copy_blit_mps(void* dst, const void* src, size_t size) {
@ -241,8 +239,7 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
stream->copy_and_sync((id<MTLBuffer>)(src), (id<MTLBuffer>)(dst), size, 0, 0, true);
}

static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
{
static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) {
auto src_byte_offset = src_.storage_offset() * src_.itemsize();
auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();

@ -250,7 +247,8 @@ static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, boo
// gather into dst. This reduces the overhead of doing an additional blit for most cases
bool returnGatherOutput = dst_.is_contiguous();
Tensor src;
auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
auto sameMemFormat =
src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
const bool sameDataType = src_.dtype() == dst_.dtype();

if ((!src_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) ||
@ -290,19 +288,18 @@ static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, boo
stream->copy(sourceBuffer, destBuffer, src.nbytes(), src_byte_offset, dst_byte_offset);
} else {
if (dst_byte_offset) {
auto tmp = at::native::empty_mps(dst_.sizes(), dst_.scalar_type(), c10::nullopt, kMPS);
auto tmpBuffer = getMTLBufferStorage(tmp);
copy_cast_mps(tmp, src, tmpBuffer, sourceBuffer);
stream->copy(tmpBuffer, destBuffer, dst_.nbytes(), 0, dst_byte_offset);
auto tmp = at::native::empty_mps(dst_.sizes(), dst_.scalar_type(), c10::nullopt, kMPS);
auto tmpBuffer = getMTLBufferStorage(tmp);
copy_cast_mps(tmp, src, tmpBuffer, sourceBuffer);
stream->copy(tmpBuffer, destBuffer, dst_.nbytes(), 0, dst_byte_offset);
} else {
copy_cast_mps(dst_, src, destBuffer, sourceBuffer);
copy_cast_mps(dst_, src, destBuffer, sourceBuffer);
}
}
return dst_;
}

at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
{
at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking) {
TORCH_CHECK(dst.defined(), "dst is undefined");
TORCH_CHECK(src.defined(), "src is undefined");

@ -328,20 +325,16 @@ at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
if (src.device().type() == at::kMPS && dst.device().type() == at::kMPS) {
return copy_kernel_mps(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
}
TORCH_INTERNAL_ASSERT(
src.device().type() == DeviceType::MPS,
"mps_copy_ is implemented only for *->MPS; MPS->*");
TORCH_INTERNAL_ASSERT(src.device().type() == DeviceType::MPS, "mps_copy_ is implemented only for *->MPS; MPS->*");
return dst;
}
} // namespace mps

Tensor _copy_from_and_resize_mps(const at::Tensor& self, const at::Tensor& dst)
{
Tensor _copy_from_and_resize_mps(const at::Tensor& self, const at::Tensor& dst) {
return mps::mps_copy_(const_cast<Tensor&>(dst), self, false);
}

Tensor _copy_from_mps(const at::Tensor& self, const at::Tensor& dst, bool non_blocking)
{
Tensor _copy_from_mps(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
return mps::mps_copy_(const_cast<Tensor&>(dst), self, non_blocking);
}

@ -1,7 +1,7 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/Cross.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {

@ -82,12 +82,12 @@ static id<MTLLibrary> compileCrossOpLibrary(id<MTLDevice> device) {
return crossLibrary;
}

NSError *error = nil;
MTLCompileOptions *options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion: MTLLanguageVersion2_3];
crossLibrary = [device newLibraryWithSource:[NSString stringWithCString: METAL_CROSS encoding:NSASCIIStringEncoding]
options:options
error:&error];
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
crossLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_CROSS encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(crossLibrary, "Failed to create metal cross library, error: ", [[error description] UTF8String]);
return crossLibrary;
}
@ -115,25 +115,25 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,
TORCH_CHECK(input.dtype() != at::kDouble, "float64 is not supported on MPS");

auto iter = TensorIteratorConfig()
.add_output(out)
.add_input(input)
.add_input(other)
.resize_outputs(false)
.declare_static_shape(out.sizes(), /*squash_dims=*/dim)
.build();
.add_output(out)
.add_input(input)
.add_input(other)
.resize_outputs(false)
.declare_static_shape(out.sizes(), /*squash_dims=*/dim)
.build();

id<MTLBuffer> inputBuffer = getMTLBufferStorage(input);
id<MTLBuffer> otherBuffer = getMTLBufferStorage(other);
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input);
id<MTLBuffer> otherBuffer = getMTLBufferStorage(other);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(out);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const int64_t out_dim_stride = out.stride(dim);
const int64_t out_dim_stride = out.stride(dim);
const int64_t input_dim_stride = input.stride(dim);
const int64_t other_dim_stride = other.stride(dim);
const uint32_t nDim = iter.ndim();
constexpr uint32_t nOffsets = 3;
const uint32_t numThreads = iter.numel();
dispatch_sync(mpsStream->queue(), ^(){
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
NSError* error = nil;
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
@ -143,23 +143,25 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,
std::vector<uint32_t> iterShapeData(iterShape.size());
std::vector<std::array<uint32_t, nOffsets>> strides(nDim);

for (const auto i: c10::irange(iterShape.size())) {
for (const auto i : c10::irange(iterShape.size())) {
TORCH_CHECK(i <= UINT32_MAX);
iterShapeData[i] = (uint32_t)(iterShape[i]);
}

for (const auto i: c10::irange(nDim)) {
for (const auto offset: c10::irange(nOffsets)) {
strides[i][offset] = iter.strides(offset)[i];
for (const auto i : c10::irange(nDim)) {
for (const auto offset : c10::irange(nOffsets)) {
strides[i][offset] = iter.strides(offset)[i];
}
}

id<MTLFunction> kernelDataOffsetsFunction = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO = [[device newComputePipelineStateWithFunction: kernelDataOffsetsFunction
error: &error] autorelease];
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength: numThreads * sizeof(simd_uint3)
options: 0] autorelease];
TORCH_CHECK(kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
id<MTLFunction> kernelDataOffsetsFunction =
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO =
[[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
options:0] autorelease];
TORCH_CHECK(
kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
[computeEncoder setComputePipelineState:kernelDataOffsetsPSO];
[computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0];
[computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1];
@ -169,30 +171,28 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,

NSUInteger kernelOffsetsTGSize = kernelDataOffsetsPSO.maxTotalThreadsPerThreadgroup;
if (kernelOffsetsTGSize > numThreads)
kernelOffsetsTGSize = numThreads;
kernelOffsetsTGSize = numThreads;

MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: kernelOffsetsThreadGroupSize];
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:kernelOffsetsThreadGroupSize];

id<MTLComputePipelineState> crossPSO = crossPipelineState(device, out.scalar_type());
[computeEncoder setComputePipelineState:crossPSO];
[computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0];
[computeEncoder setBuffer:otherBuffer offset:other.storage_offset() * other.element_size() atIndex:1];
[computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0];
[computeEncoder setBuffer:otherBuffer offset:other.storage_offset() * other.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:out.storage_offset() * out.element_size() atIndex:2];
[computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3];
[computeEncoder setBytes:&out_dim_stride length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&out_dim_stride length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&input_dim_stride length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&other_dim_stride length:sizeof(int64_t) atIndex:6];

NSUInteger tgSize = crossPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > numThreads) {
tgSize = numThreads;
tgSize = numThreads;
}

MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: threadGroupSize];
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];

[computeEncoder endEncoding];
mpsStream->commit(true);

@ -1,39 +1,40 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/Distributions.h>
#include <ATen/native/DistributionTemplates.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/mps/MPSGeneratorImpl.h>
#include <ATen/native/DistributionTemplates.h>
#include <ATen/native/Distributions.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {
namespace mps {

struct RandomCachedGraph : public MPSCachedGraph
{
RandomCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) { }
struct RandomCachedGraph : public MPSCachedGraph {
RandomCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
// Only relevant for multinomial
MPSGraphTensor *probTensor = nil;
MPSGraphTensor *resultTensor = nil;
MPSGraphTensor *stateTensor = nil;
MPSGraphTensor* probTensor = nil;
MPSGraphTensor* resultTensor = nil;
MPSGraphTensor* stateTensor = nil;
// used for Normal distributions only
MPSGraphTensor *meanTensor = nil, *stdTensor = nil;
};

typedef MPSGraphTensor* (^RandomOpBlock)(RandomCachedGraph*, MPSGraphTensor*);
#define RandomOpFn(graph, randomTensor) MPSGraphTensor* (mps::RandomCachedGraph* graph, MPSGraphTensor* randomTensor)
#define RandomOpFn(graph, randomTensor) MPSGraphTensor*(mps::RandomCachedGraph * graph, MPSGraphTensor * randomTensor)

// for Uniform distributions with scalar from (val1) and to (val2) intervals
// for Normal distributions with scalar mean (val1) and std (val2) values
template<typename scalar_t>
Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
template <typename scalar_t>
Tensor& random_mps_impl(Tensor& self,
scalar_t val1,
scalar_t val2,
const c10::optional<Tensor>& mean_opt,
const c10::optional<Tensor>& std_opt,
MPSGraphRandomDistribution distribution,
c10::optional<Generator> gen,
std::string op_name, RandomOpBlock randomBlock)
{
std::string op_name,
RandomOpBlock randomBlock) {
if (self.numel() == 0) {
return self;
}
@ -46,13 +47,14 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
auto cachedGraph = cache_->LookUpAs<RandomCachedGraph>(key);

if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^ MPSCachedGraph * () {
RandomCachedGraph *newCachedGraph = nil;
cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^MPSCachedGraph*() {
RandomCachedGraph* newCachedGraph = nil;

@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new RandomCachedGraph(mpsGraph);
newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@(at::mps::detail::PHILOX_STATE_N)]);
newCachedGraph->stateTensor =
mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]);

// FP16, FP32 and Int32 are the only data types supported for distributions on MPS backend.
const MPSDataType inputDataType = [&] {
@ -64,8 +66,8 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
}();
const MPSDataType outputDataType = (std::is_same<scalar_t, bool>::value) ? MPSDataTypeBool : inputDataType;

MPSGraphRandomOpDescriptor *desc = [MPSGraphRandomOpDescriptor descriptorWithDistribution: distribution
dataType: inputDataType];
MPSGraphRandomOpDescriptor* desc = [MPSGraphRandomOpDescriptor descriptorWithDistribution:distribution
dataType:inputDataType];
if (distribution == MPSGraphRandomDistributionUniform) {
if (inputDataType == MPSDataTypeInt32) {
desc.minInteger = static_cast<NSInteger>(val1);
@ -81,10 +83,10 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
// we don't use the output state tensor from the MPSGraph API as it requires reading back from GPU to CPU.
// Instead, we keep the Philox state in the MPSGenerator and use the PyTorch's philox_engine to maintain
// the counters, and feed them to the graph manually
NSArray<MPSGraphTensor*> *resultTensors = [mpsGraph randomTensorWithShape: getMPSShape(self)
descriptor: desc
stateTensor: newCachedGraph->stateTensor
name: nil];
NSArray<MPSGraphTensor*>* resultTensors = [mpsGraph randomTensorWithShape:getMPSShape(self)
descriptor:desc
stateTensor:newCachedGraph->stateTensor
name:nil];
newCachedGraph->resultTensor = randomBlock ? randomBlock(newCachedGraph, resultTensors[0]) : resultTensors[0];
// results will be cast if self's scalar type isn't directly supported by MPS backend.
if (getMPSDataType(self) != outputDataType)
@ -94,19 +96,20 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
});
}
// feed the updated state values to the graph
MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(at::mps::detail::PHILOX_STATE_N)]];
MPSNDArray *stateNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: stateDesc] autorelease];
MPSNDArrayDescriptor* stateDesc =
[MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeInt32 shape:@[ @(at::mps::detail::PHILOX_STATE_N) ]];
MPSNDArray* stateNDArray = [[[MPSNDArray alloc] initWithDevice:stream->device() descriptor:stateDesc] autorelease];
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(mps_gen->mutex_);
// update the Philox state values on each run
mps_gen->update_philox_counters();
[stateNDArray writeBytes: mps_gen->state_data() strideBytes: nil];
[stateNDArray writeBytes:mps_gen->state_data() strideBytes:nil];
}
MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: stateNDArray] autorelease];
MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray:stateNDArray] autorelease];

Placeholder meanPlaceholder, stdPlaceholder;
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
feeds[cachedGraph->stateTensor] = stateTensorData;

if (cachedGraph->stdTensor) {
@ -121,7 +124,7 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
}

Placeholder outputPlaceholder = Placeholder(cachedGraph->resultTensor, self);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*> *results = @{
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(),
};

@ -131,13 +134,14 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
return self;
}

Tensor& normal_mps_impl(Tensor& self, double mean_s, double std_s,
Tensor& normal_mps_impl(Tensor& self,
double mean_s,
double std_s,
const c10::optional<Tensor>& mean_opt,
const c10::optional<Tensor>& std_opt,
c10::optional<Generator> gen,
std::string op_name)
{
const Tensor& std_t = *(at::borrow_from_optional_tensor(std_opt));
std::string op_name) {
const Tensor& std_t = *(at::borrow_from_optional_tensor(std_opt));
const Tensor& mean_t = *(at::borrow_from_optional_tensor(mean_opt));

TORCH_CHECK(std_s >= 0.0, op_name, " expects std >= 0.0, but found std=", std_s);
@ -153,39 +157,45 @@ Tensor& normal_mps_impl(Tensor& self, double mean_s, double std_s,

if (std_t.defined()) {
cachedGraph->stdTensor = mpsGraphRankedPlaceHolder(mpsGraph, std_t);
resultTensor = [mpsGraph multiplicationWithPrimaryTensor: randomTensor
secondaryTensor: cachedGraph->stdTensor
name: nil];
resultTensor = [mpsGraph multiplicationWithPrimaryTensor:randomTensor
secondaryTensor:cachedGraph->stdTensor
name:nil];
}
if (mean_t.defined()) {
cachedGraph->meanTensor = mpsGraphRankedPlaceHolder(mpsGraph, mean_t);
return [mpsGraph additionWithPrimaryTensor: resultTensor
secondaryTensor: cachedGraph->meanTensor
name: nil];
return [mpsGraph additionWithPrimaryTensor:resultTensor secondaryTensor:cachedGraph->meanTensor name:nil];
}
return resultTensor;
};
return random_mps_impl<double>(self, mean_s, std_s, mean_opt, std_opt,
MPSGraphRandomDistributionNormal, gen,
op_name + getTensorsStringKey({mean_t, std_t}), random_op_block);

return random_mps_impl<double>(self,
mean_s,
std_s,
mean_opt,
std_opt,
MPSGraphRandomDistributionNormal,
gen,
op_name + getTensorsStringKey({mean_t, std_t}),
random_op_block);
}

Tensor& bernoulli_mps_impl(Tensor& self, const Tensor& prob_t, c10::optional<Generator> gen, std::string op_name)
{
Tensor& bernoulli_mps_impl(Tensor& self, const Tensor& prob_t, c10::optional<Generator> gen, std::string op_name) {
TORCH_CHECK(prob_t.is_same_size(self), op_name, ": probability and self tensor should be of the same shape")

RandomOpBlock random_op_block = ^RandomOpFn(cachedGraph, randomTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
cachedGraph->stdTensor = mpsGraphRankedPlaceHolder(mpsGraph, prob_t);
return [mpsGraph lessThanWithPrimaryTensor: randomTensor
secondaryTensor: cachedGraph->stdTensor
name: nil];
return [mpsGraph lessThanWithPrimaryTensor:randomTensor secondaryTensor:cachedGraph->stdTensor name:nil];
};
// Bernoulli generates binary output so we use bool type
return mps::random_mps_impl<bool>(self, 0.0, 1.0, c10::nullopt, prob_t,
MPSGraphRandomDistributionUniform, gen,
op_name + getTensorsStringKey({prob_t}), random_op_block);
return mps::random_mps_impl<bool>(self,
0.0,
1.0,
c10::nullopt,
prob_t,
MPSGraphRandomDistributionUniform,
gen,
op_name + getTensorsStringKey({prob_t}),
random_op_block);
}

} // namespace mps
@ -196,15 +206,19 @@ Tensor& uniform_mps_(Tensor& self, double from, double to, c10::optional<Generat
const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to);
TORCH_CHECK((to - from) <= std::numeric_limits<scalar_t>::max(),
"uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()),
">::max(), but found to=", to, " and from=", from,
" which result in to-from to exceed the limit");
"uniform_ expects to-from <= std::numeric_limits<",
toString(self.scalar_type()),
">::max(), but found to=",
to,
" and from=",
from,
" which result in to-from to exceed the limit");
from = std::min(std::max(from, min), max);
to = std::max(std::min(to, max), min);
});

return mps::random_mps_impl<double>(self, from, to, c10::nullopt, c10::nullopt,
MPSGraphRandomDistributionUniform, gen, __func__, nullptr);
return mps::random_mps_impl<double>(
self, from, to, c10::nullopt, c10::nullopt, MPSGraphRandomDistributionUniform, gen, __func__, nullptr);
}

Tensor& normal_mps_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
@ -248,7 +262,7 @@ Tensor& normal_mps_out(const Tensor& mean, const Tensor& std, c10::optional<Gene

Tensor& bernoulli_out_mps(const Tensor& p_, c10::optional<Generator> gen, Tensor& result) {
result.resize_(p_.sizes());
return mps::bernoulli_mps_impl(result, p_, gen, __func__);
return mps::bernoulli_mps_impl(result, p_, gen, __func__);
}

Tensor& bernoulli_mps_(Tensor& self, double p, c10::optional<Generator> gen) {
@ -271,22 +285,35 @@ Tensor& random_mps_(Tensor& self, int64_t from, c10::optional<int64_t> to_opt, c
to = *to_opt;
TORCH_CHECK(from < to, "random_mps_ expects 'from' to be less than 'to', but got from=", from, " >= to=", to);
if (isFloatingType(input_dtype)) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_update_from_to", [&] {
from = templates::update_from<scalar_t>(from);
to = templates::update_to<scalar_t>(to);
TORCH_CHECK(from < to, "random_mps_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", from, " >= to=", to);
});
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_update_from_to", [&] {
from = templates::update_from<scalar_t>(from);
to = templates::update_to<scalar_t>(to);
TORCH_CHECK(
from < to,
"random_mps_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=",
from,
" >= to=",
to);
});
templates::check_from_to_in_range(from, to - 1, self.dtype());
}
} else if (from != std::numeric_limits<int64_t>::lowest()) {
// [from, std::numeric_limits<int64_t>::max()]
if (isFloatingType(input_dtype)) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_from_to_range_calc", [&] {
constexpr int64_t scalar_t_max = static_cast<int64_t>(1) << std::numeric_limits<scalar_t>::digits;
to = scalar_t_max > std::numeric_limits<int64_t>::max() ? std::numeric_limits<int64_t>::max() : static_cast<int64_t>(scalar_t_max);
from = templates::update_from<scalar_t>(from);
TORCH_CHECK(from < to, "random_mps_ expects 'from' casted to dtype to be less than or equal to 'to' casted to dtype, but got from=", from, " > to=", to);
});
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_from_to_range_calc", [&] {
constexpr int64_t scalar_t_max = static_cast<int64_t>(1) << std::numeric_limits<scalar_t>::digits;
to = scalar_t_max > std::numeric_limits<int64_t>::max() ? std::numeric_limits<int64_t>::max()
: static_cast<int64_t>(scalar_t_max);
from = templates::update_from<scalar_t>(from);
TORCH_CHECK(
from < to,
"random_mps_ expects 'from' casted to dtype to be less than or equal to 'to' casted to dtype, but got from=",
from,
" > to=",
to);
});
} else if (isIntegralType(input_dtype, /*includeBool=*/true)) {
AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, input_dtype, "random_from_to_range_calc", [&] {
if (std::is_same<scalar_t, bool>::value) {
@ -295,13 +322,11 @@ Tensor& random_mps_(Tensor& self, int64_t from, c10::optional<int64_t> to_opt, c
to = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
}
});
}
else {
} else {
TORCH_CHECK(false, "random_mps_ handles only integral, floating-point and boolean types");
}
templates::check_from_to_in_range(from, to, self.dtype());
}
else {
} else {
// [std::numeric_limits<int64_t>::lowest(), std::numeric_limits<int64_t>::max()]
// range = 2^64

@ -309,8 +334,8 @@ Tensor& random_mps_(Tensor& self, int64_t from, c10::optional<int64_t> to_opt, c
TORCH_CHECK(false, "random_mps_ currently does not handle the lowest() -> max() range");
}

return mps::random_mps_impl<int64_t>(self, from, to - 1, c10::nullopt, c10::nullopt,
MPSGraphRandomDistributionUniform, gen, __func__, nullptr);
return mps::random_mps_impl<int64_t>(
self, from, to - 1, c10::nullopt, c10::nullopt, MPSGraphRandomDistributionUniform, gen, __func__, nullptr);
}

Tensor& random_mps_(Tensor& self, int64_t to, c10::optional<Generator> gen) {
@ -323,22 +348,23 @@ Tensor& exponential_mps_(Tensor& self, double lambda, c10::optional<Generator> g
|
||||
|
||||
mps::RandomOpBlock random_op_block = ^RandomOpFn(cachedGraph, randomTensor) {
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar: 1.0f
|
||||
dataType: randomTensor.dataType];
|
||||
MPSGraphTensor* minusLambdaTensor = [mpsGraph constantWithScalar: -lambda
|
||||
dataType: randomTensor.dataType];
|
||||
MPSGraphTensor* subtractTensor = [mpsGraph subtractionWithPrimaryTensor: unitTensor
|
||||
secondaryTensor: randomTensor
|
||||
name: nil];
|
||||
MPSGraphTensor* logTensor = [mpsGraph logarithmWithTensor: subtractTensor
|
||||
name: nil];
|
||||
return [mpsGraph divisionWithPrimaryTensor: logTensor
|
||||
secondaryTensor: minusLambdaTensor
|
||||
name: nil];
|
||||
MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f dataType:randomTensor.dataType];
|
||||
MPSGraphTensor* minusLambdaTensor = [mpsGraph constantWithScalar:-lambda dataType:randomTensor.dataType];
|
||||
MPSGraphTensor* subtractTensor = [mpsGraph subtractionWithPrimaryTensor:unitTensor
|
||||
secondaryTensor:randomTensor
|
||||
name:nil];
|
||||
MPSGraphTensor* logTensor = [mpsGraph logarithmWithTensor:subtractTensor name:nil];
|
||||
return [mpsGraph divisionWithPrimaryTensor:logTensor secondaryTensor:minusLambdaTensor name:nil];
|
||||
};
|
||||
return mps::random_mps_impl<double>(self, 0.0, 1.0, c10::nullopt, c10::nullopt,
|
||||
MPSGraphRandomDistributionUniform, gen,
|
||||
"exponential_mps_:" + std::to_string(lambda), random_op_block);
|
||||
return mps::random_mps_impl<double>(self,
|
||||
0.0,
|
||||
1.0,
|
||||
c10::nullopt,
|
||||
c10::nullopt,
|
||||
MPSGraphRandomDistributionUniform,
|
||||
gen,
|
||||
"exponential_mps_:" + std::to_string(lambda),
|
||||
random_op_block);
|
||||
}
Tensor& randperm_out_mps(int64_t n, c10::optional<Generator> generator, Tensor& result) {
@ -354,9 +380,12 @@ Tensor& randperm_out_mps(int64_t n, c10::optional<Generator> generator, Tensor&
}

TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
TORCH_CHECK(!generator.has_value() ||
(generator.has_value() && result.device() == generator->device()),
"Expected a '", result.device(), "' generator device but found '", generator->device(), "'");
TORCH_CHECK(!generator.has_value() || (generator.has_value() && result.device() == generator->device()),
"Expected a '",
result.device(),
"' generator device but found '",
generator->device(),
"'");
check_supported_max_int_with_precision(n, result);

result.resize_({n});
@ -366,36 +395,34 @@ Tensor& randperm_out_mps(int64_t n, c10::optional<Generator> generator, Tensor&
|
||||
|
||||
mps::RandomOpBlock random_op_block = ^RandomOpFn(cachedGraph, randomTensor) {
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
MPSGraphTensor* argsortTensor = [mpsGraph argSortWithTensor:randomTensor
|
||||
axis:0
|
||||
name:nil];
|
||||
MPSGraphTensor* argsortTensor = [mpsGraph argSortWithTensor:randomTensor axis:0 name:nil];
|
||||
if (result.scalar_type() != kInt) {
|
||||
argsortTensor = [mpsGraph castTensor:argsortTensor
|
||||
toType:mps::getMPSDataType(result)
|
||||
name:@"castOutput"];
|
||||
argsortTensor = [mpsGraph castTensor:argsortTensor toType:mps::getMPSDataType(result) name:@"castOutput"];
|
||||
}
|
||||
return argsortTensor;
|
||||
};
|
||||
|
||||
return mps::random_mps_impl<int64_t>(result, 0.0, 1.0, c10::nullopt, c10::nullopt,
|
||||
MPSGraphRandomDistributionUniform, generator,
|
||||
"ranperm_out_mps:" + mps::getTensorsStringKey({result}), random_op_block);
|
||||
return mps::random_mps_impl<int64_t>(result,
|
||||
0.0,
|
||||
1.0,
|
||||
c10::nullopt,
|
||||
c10::nullopt,
|
||||
MPSGraphRandomDistributionUniform,
|
||||
generator,
|
||||
"ranperm_out_mps:" + mps::getTensorsStringKey({result}),
|
||||
random_op_block);
|
||||
}
|
||||
|
||||
Tensor& multinomial_with_replacement_mps_kernel(
|
||||
const Tensor& self,
|
||||
const int64_t n_sample,
|
||||
c10::optional<Generator> generator,
|
||||
Tensor& result) {
|
||||
|
||||
Tensor& multinomial_with_replacement_mps_kernel(const Tensor& self,
|
||||
const int64_t n_sample,
|
||||
c10::optional<Generator> generator,
|
||||
Tensor& result) {
|
||||
using namespace mps;
|
||||
|
||||
auto mps_gen = get_generator_or_default<MPSGeneratorImpl>(generator, at::mps::detail::getDefaultMPSGenerator());
|
||||
int inputSize = self.dim();
|
||||
int numDist =
|
||||
inputSize == 1 ? 1 : self.size(0);
|
||||
int numCategories =
|
||||
inputSize == 1 ? self.size(0) : self.size(1);
|
||||
int numDist = inputSize == 1 ? 1 : self.size(0);
|
||||
int numCategories = inputSize == 1 ? self.size(0) : self.size(1);
|
||||
|
||||
// Restructure data for 2d
|
||||
auto self_v = inputSize == 1 ? self.view({numDist, numCategories}) : self;
|
||||
@ -408,24 +435,22 @@ Tensor& multinomial_with_replacement_mps_kernel(
|
||||
string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + to_string(n_sample);
|
||||
auto cachedGraph = cache_->LookUpAs<RandomCachedGraph>(key);
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^ MPSCachedGraph * () {
|
||||
RandomCachedGraph *newCachedGraph = nil;
|
||||
cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^MPSCachedGraph*() {
|
||||
RandomCachedGraph* newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
MPSShape* prob_shape = getMPSShape(self_v);
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new RandomCachedGraph(mpsGraph);
|
||||
newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@7]);
|
||||
newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @7 ]);
|
||||
|
||||
auto prob_dtype = getMPSDataType(self_v);
|
||||
|
||||
// This is probability weights
|
||||
newCachedGraph->probTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self_v), prob_shape);
|
||||
|
||||
MPSGraphTensor *sumProbs = [mpsGraph reductionSumWithTensor:newCachedGraph->probTensor
|
||||
axis:-1
|
||||
name:nil];
|
||||
MPSGraphTensor* sumProbs = [mpsGraph reductionSumWithTensor:newCachedGraph->probTensor axis:-1 name:nil];
|
||||
|
||||
MPSGraphTensor *normalizedProbs = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->probTensor
|
||||
MPSGraphTensor* normalizedProbs = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->probTensor
|
||||
secondaryTensor:sumProbs
|
||||
name:nil];
|
||||
|
||||
@ -433,139 +458,125 @@ Tensor& multinomial_with_replacement_mps_kernel(
|
||||
auto ns_numDist = [NSNumber numberWithInt:numDist];
|
||||
auto ns_n_sample = [NSNumber numberWithInt:n_sample];
|
||||
|
||||
MPSGraphTensor *ones = [mpsGraph constantWithScalar:1.0f
|
||||
shape:@[ns_numCategories, ns_numCategories]
|
||||
MPSGraphTensor* ones = [mpsGraph constantWithScalar:1.0f
|
||||
shape:@[ ns_numCategories, ns_numCategories ]
|
||||
dataType:prob_dtype];
|
||||
auto zeroTensor = [mpsGraph constantWithScalar: 0.0f
|
||||
dataType: MPSDataTypeInt32];
|
||||
auto minusOneTensor = [mpsGraph constantWithScalar: -1.0f
|
||||
dataType: MPSDataTypeInt32];
|
||||
auto zeroTensor = [mpsGraph constantWithScalar:0.0f dataType:MPSDataTypeInt32];
|
||||
auto minusOneTensor = [mpsGraph constantWithScalar:-1.0f dataType:MPSDataTypeInt32];
|
||||
|
||||
MPSGraphTensor *upperTriangle = [mpsGraph bandPartWithTensor:ones
|
||||
MPSGraphTensor* upperTriangle = [mpsGraph bandPartWithTensor:ones
|
||||
numLowerTensor:zeroTensor
|
||||
numUpperTensor:minusOneTensor
|
||||
name:nil];
|
||||
MPSGraphTensor *upperProbRange = [mpsGraph matrixMultiplicationWithPrimaryTensor:normalizedProbs
|
||||
MPSGraphTensor* upperProbRange = [mpsGraph matrixMultiplicationWithPrimaryTensor:normalizedProbs
|
||||
secondaryTensor:upperTriangle
|
||||
name:nil];
|
||||
|
||||
MPSGraphTensor *lowerProbRange = [mpsGraph subtractionWithPrimaryTensor:upperProbRange
|
||||
MPSGraphTensor* lowerProbRange = [mpsGraph subtractionWithPrimaryTensor:upperProbRange
|
||||
secondaryTensor:normalizedProbs
|
||||
name:nil];
|
||||
|
||||
upperProbRange = [mpsGraph reshapeTensor:upperProbRange
|
||||
withShape:@[ns_numDist, @1, ns_numCategories]
|
||||
withShape:@[ ns_numDist, @1, ns_numCategories ]
|
||||
name:nil];
|
||||
lowerProbRange = [mpsGraph reshapeTensor:lowerProbRange
|
||||
withShape:@[ns_numDist, @1, ns_numCategories]
|
||||
withShape:@[ ns_numDist, @1, ns_numCategories ]
|
||||
name:nil];
|
||||
|
||||
MPSGraphRandomOpDescriptor *descriptor = [MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform
|
||||
dataType:prob_dtype];
|
||||
NSArray<MPSGraphTensor*> *generatorTensors = [mpsGraph randomTensorWithShape:@[ns_numDist, ns_n_sample, @1]
|
||||
MPSGraphRandomOpDescriptor* descriptor =
|
||||
[MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform
|
||||
dataType:prob_dtype];
|
||||
NSArray<MPSGraphTensor*>* generatorTensors = [mpsGraph randomTensorWithShape:@[ ns_numDist, ns_n_sample, @1 ]
|
||||
descriptor:descriptor
|
||||
stateTensor:newCachedGraph->stateTensor
|
||||
name:nil];
|
||||
MPSGraphTensor *randomTensor = generatorTensors[0];
|
||||
MPSGraphTensor* randomTensor = generatorTensors[0];
|
||||
|
||||
auto broadcastShape = @[ns_numDist ,ns_n_sample, ns_numCategories];
|
||||
auto broadcastShape = @[ ns_numDist, ns_n_sample, ns_numCategories ];
|
||||
int broadcastShapeVals[3] = {numDist, static_cast<int>(n_sample), numCategories};
|
||||
MPSGraphTensor *broadcastShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count]
|
||||
shape:@[[NSNumber numberWithUnsignedInteger:broadcastShape.count]]
|
||||
dataType:MPSDataTypeUInt32];
|
||||
MPSGraphTensor* broadcastShapeTensor = [mpsGraph
|
||||
constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count]
|
||||
shape:@[ [NSNumber numberWithUnsignedInteger:broadcastShape.count] ]
|
||||
dataType:MPSDataTypeUInt32];
|
||||
|
||||
MPSGraphTensor *samplesTensor = [mpsGraph broadcastTensor:randomTensor
|
||||
toShape:broadcastShape
|
||||
name:nil];
|
||||
MPSGraphTensor *sampleAbove = [mpsGraph greaterThanWithPrimaryTensor:samplesTensor
|
||||
MPSGraphTensor* samplesTensor = [mpsGraph broadcastTensor:randomTensor toShape:broadcastShape name:nil];
|
||||
MPSGraphTensor* sampleAbove = [mpsGraph greaterThanWithPrimaryTensor:samplesTensor
|
||||
secondaryTensor:lowerProbRange
|
||||
name:nil];
|
||||
MPSGraphTensor *sampleBelow = [mpsGraph lessThanWithPrimaryTensor:samplesTensor
|
||||
MPSGraphTensor* sampleBelow = [mpsGraph lessThanWithPrimaryTensor:samplesTensor
|
||||
secondaryTensor:upperProbRange
|
||||
name:nil];
|
||||
MPSGraphTensor *sampleWithin = [mpsGraph logicalANDWithPrimaryTensor:sampleAbove
|
||||
secondaryTensor:sampleBelow
|
||||
name:nil];
|
||||
MPSGraphTensor *sampleMask = [mpsGraph castTensor:sampleWithin
|
||||
toType:MPSDataTypeInt32
|
||||
name:@"sampleMask"];
|
||||
MPSGraphTensor *categoriesTensor = [mpsGraph coordinateAlongAxis:-1
|
||||
MPSGraphTensor* sampleWithin = [mpsGraph logicalANDWithPrimaryTensor:sampleAbove
|
||||
secondaryTensor:sampleBelow
|
||||
name:nil];
|
||||
MPSGraphTensor* sampleMask = [mpsGraph castTensor:sampleWithin toType:MPSDataTypeInt32 name:@"sampleMask"];
|
||||
MPSGraphTensor* categoriesTensor = [mpsGraph coordinateAlongAxis:-1
|
||||
withShapeTensor:broadcastShapeTensor
|
||||
name:nil];
|
||||
MPSGraphTensor *binnedSamplesTensor = [mpsGraph multiplicationWithPrimaryTensor:categoriesTensor
|
||||
secondaryTensor:sampleMask
|
||||
name:nil];
|
||||
MPSGraphTensor *reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor
|
||||
axis:-1
|
||||
name:nil];
|
||||
MPSGraphTensor *reshapeTensor = [mpsGraph reshapeTensor:reducedTensor
|
||||
withShape:@[ns_numDist ,ns_n_sample]
|
||||
name:nil];
|
||||
MPSGraphTensor* binnedSamplesTensor = [mpsGraph multiplicationWithPrimaryTensor:categoriesTensor
|
||||
secondaryTensor:sampleMask
|
||||
name:nil];
|
||||
MPSGraphTensor* reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor axis:-1 name:nil];
|
||||
MPSGraphTensor* reshapeTensor = [mpsGraph reshapeTensor:reducedTensor
|
||||
withShape:@[ ns_numDist, ns_n_sample ]
|
||||
name:nil];
|
||||
newCachedGraph->resultTensor = [mpsGraph castTensor:reshapeTensor
|
||||
toType:getMPSDataType(result)
|
||||
name:@"resultTensor"];
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
});
|
||||
}
|
||||
// update the Philox state values on each run of the same graph
|
||||
MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(at::mps::detail::PHILOX_STATE_N)]];
|
||||
MPSNDArray *stateNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: stateDesc] autorelease];
|
||||
MPSNDArrayDescriptor* stateDesc =
|
||||
[MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeInt32 shape:@[ @(at::mps::detail::PHILOX_STATE_N) ]];
|
||||
MPSNDArray* stateNDArray = [[[MPSNDArray alloc] initWithDevice:stream->device() descriptor:stateDesc] autorelease];
|
||||
{
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(mps_gen->mutex_);
|
||||
// update the Philox state values on each run
|
||||
mps_gen->update_philox_counters();
|
||||
[stateNDArray writeBytes: mps_gen->state_data() strideBytes: nil];
|
||||
[stateNDArray writeBytes:mps_gen->state_data() strideBytes:nil];
|
||||
}
|
||||
MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: stateNDArray] autorelease];
|
||||
MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray:stateNDArray] autorelease];
|
||||
|
||||
auto probPlaceholder = Placeholder(cachedGraph->probTensor, self_v);
|
||||
auto outputPlaceholder = Placeholder(cachedGraph->resultTensor, result_v);
|
||||
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
cachedGraph->stateTensor : stateTensorData,
|
||||
probPlaceholder.getMPSGraphTensor() : probPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
/* The largest consecutive integer representable in float32 (2^24) */
|
||||
constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = 1 << (FLT_MANT_DIG);
|
||||
|
||||
Tensor& multinomial_out_mps(const Tensor& self,
|
||||
int64_t n_sample,
|
||||
bool with_replacement,
|
||||
c10::optional<Generator> gen,
|
||||
Tensor& result) {
|
||||
|
||||
int64_t n_sample,
|
||||
bool with_replacement,
|
||||
c10::optional<Generator> gen,
|
||||
Tensor& result) {
|
||||
TORCH_CHECK(result.device() == self.device(), "multinomial arguments must have the same device");
|
||||
TORCH_CHECK(self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim");
|
||||
TORCH_CHECK(at::isFloatingType(self.scalar_type()),
|
||||
"multinomial only supports floating-point dtypes for input, got: ",
|
||||
self.scalar_type());
|
||||
TORCH_CHECK(
|
||||
result.device() == self.device(),
|
||||
"multinomial arguments must have the same device");
|
||||
TORCH_CHECK(
|
||||
self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim");
|
||||
TORCH_CHECK(
|
||||
at::isFloatingType(self.scalar_type()),
|
||||
"multinomial only supports floating-point dtypes for input, got: ",
|
||||
self.scalar_type());
|
||||
TORCH_CHECK(result.scalar_type() == ScalarType::Long,
|
||||
"multinomial expects Long tensor out, got: ", result.scalar_type());
|
||||
result.scalar_type() == ScalarType::Long, "multinomial expects Long tensor out, got: ", result.scalar_type());
|
||||
TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples");
|
||||
int64_t n_categories = self.size(-1);
|
||||
TORCH_CHECK(with_replacement || (n_sample <= n_categories),
|
||||
"cannot sample n_sample > prob_dist.size(-1) samples without replacement");
|
||||
"cannot sample n_sample > prob_dist.size(-1) samples without replacement");
|
||||
// Since the index tensor is float, numCategories cannot exceed max
|
||||
// float integer precision
|
||||
TORCH_CHECK(
|
||||
n_categories <= FLOAT32_MAX_CONSECUTIVE_INT,
|
||||
"number of categories cannot exceed 2^24");
|
||||
TORCH_CHECK(n_categories <= FLOAT32_MAX_CONSECUTIVE_INT, "number of categories cannot exceed 2^24");
|
||||
|
||||
if (self.dim() == 1) {
|
||||
result.resize_({n_sample});
|
||||
@ -583,19 +594,15 @@ Tensor& multinomial_out_mps(const Tensor& self,
|
||||
if (!with_replacement || n_sample == 1) {
|
||||
// Sanity checks on `self`.
|
||||
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
|
||||
TORCH_CHECK(
|
||||
is_valid.to<bool>(),
|
||||
"probability tensor contains either `inf`, `nan` or element < 0");
|
||||
TORCH_CHECK(is_valid.to<bool>(), "probability tensor contains either `inf`, `nan` or element < 0");
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
bool zero_prob_condition;
|
||||
if (self.dim() == 1){
|
||||
if (self.dim() == 1) {
|
||||
zero_prob_condition = (self.sum() == 0).item().to<bool>();
|
||||
} else {
|
||||
zero_prob_condition = (self.sum(1) == 0).sum().item().to<bool>();
|
||||
}
|
||||
TORCH_CHECK(
|
||||
!zero_prob_condition,
|
||||
"invalid multinomial distribution (sum of probabilities <= 0)");
|
||||
TORCH_CHECK(!zero_prob_condition, "invalid multinomial distribution (sum of probabilities <= 0)");
|
||||
|
||||
// The algorithm is from gumbel softmax.
|
||||
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)
|
||||
@ -625,11 +632,7 @@ Tensor& multinomial_out_mps(const Tensor& self,
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor multinomial_mps(
|
||||
const Tensor& self,
|
||||
int64_t n_sample,
|
||||
bool with_replacement,
|
||||
c10::optional<Generator> gen) {
|
||||
Tensor multinomial_mps(const Tensor& self, int64_t n_sample, bool with_replacement, c10::optional<Generator> gen) {
|
||||
Tensor result = at::empty({0}, self.options().dtype(kLong));
|
||||
multinomial_out_mps(self, n_sample, with_replacement, gen, result);
|
||||
return result;
|
||||
|
@ -3,9 +3,8 @@
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>
#include <c10/util/Optional.h>

#include <torch/library.h>

// Steps to add op for MPS backend:
// 1. Register the op in aten/src/ATen/native/native_functions.yaml with the "MPS" dispatch key
@ -29,7 +28,6 @@
// g) Then call runMPSGraph() with input params and return the result.
//


namespace at::native {

Tensor& eye_out_mps(int64_t n, Tensor& result) {
@ -38,7 +36,6 @@ Tensor& eye_out_mps(int64_t n, Tensor& result) {
}

Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {

// This is one example of boiler-plate error checking, taking after CPU/CUDA counterparts
TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n);
TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m);
@ -47,7 +44,7 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
|
||||
result.zero_();
|
||||
|
||||
// Handle empty outputs
|
||||
if(result.numel() == 0)
|
||||
if (result.numel() == 0)
|
||||
return result;
|
||||
|
||||
// Get MPS stream
|
||||
@ -55,25 +52,24 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
// Derive from MPSCachedGraph
|
||||
// This structure is used to cache an MPSGraph with certain keys, so that we don't have to compile the same MPSGraph time and time again for the same operation
|
||||
// The keys of this structure are based on the inputs and outputs needed for the operation
|
||||
// Here, we don't have any input tensors, just an output tensor
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
// This structure is used to cache an MPSGraph with certain keys, so that we don't have to compile the same MPSGraph
|
||||
// time and time again for the same operation The keys of this structure are based on the inputs and outputs needed
|
||||
// for the operation Here, we don't have any input tensors, just an output tensor
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
};
|
||||
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
// A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types etc match the earlier created MPSGraph
|
||||
// A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types
|
||||
// etc match the earlier created MPSGraph
|
||||
string key = "eye_out_mps:" + getTensorsStringKey({result});
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
if(!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
|
||||
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
// Initialize graph
|
||||
@ -84,11 +80,9 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
|
||||
dataType:getMPSDataType(result)];
|
||||
|
||||
// Here we can call the MPSGraph API needed to execute the operation.
|
||||
// The API details can be found here: https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph
|
||||
MPSGraphTensor* outputTensor = [mpsGraph bandPartWithTensor:onesTensor
|
||||
numLower:0
|
||||
numUpper:0
|
||||
name:nil];
|
||||
// The API details can be found here:
|
||||
// https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph
|
||||
MPSGraphTensor* outputTensor = [mpsGraph bandPartWithTensor:onesTensor numLower:0 numUpper:0 name:nil];
|
||||
newCachedGraph->outputTensor_ = outputTensor;
|
||||
}
|
||||
return newCachedGraph;
|
||||
@ -102,9 +96,8 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
|
||||
// In this case, there are no inputs, so the feeds are nil
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
// Run the graph
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
@ -113,5 +106,4 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -1,12 +1,15 @@
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/GridSamplerUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at {
namespace native {

void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
void grid_sampler_2d_mps_impl(Tensor& output,
const Tensor& input,
const Tensor& grid,
int64_t interpolation_mode,
int64_t padding_mode,
bool align_corners) {
// Grid Sampler support has been added in macOS 13.1
#if defined(__MAC_13_2)
@ -18,35 +21,43 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor&
|
||||
MPSGraphPaddingMode paddingMode;
|
||||
|
||||
auto memory_format = input.suggest_memory_format();
|
||||
MPSGraphTensorNamedDataLayout inputTensorLayout =
|
||||
(memory_format == at::MemoryFormat::Contiguous) ? MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutNHWC;
|
||||
MPSGraphTensorNamedDataLayout inputTensorLayout = (memory_format == at::MemoryFormat::Contiguous)
|
||||
? MPSGraphTensorNamedDataLayoutNCHW
|
||||
: MPSGraphTensorNamedDataLayoutNHWC;
|
||||
|
||||
switch (static_cast<GridSamplerPadding>(padding_mode)) {
|
||||
case GridSamplerPadding::Zeros:
|
||||
paddingMode = MPSGraphPaddingModeZero; break;
|
||||
paddingMode = MPSGraphPaddingModeZero;
|
||||
break;
|
||||
case GridSamplerPadding::Border:
|
||||
TORCH_CHECK(false, "MPS: Unsupported Border padding mode"); break;
|
||||
TORCH_CHECK(false, "MPS: Unsupported Border padding mode");
|
||||
break;
|
||||
case GridSamplerPadding::Reflection:
|
||||
paddingMode = align_corners == true ? MPSGraphPaddingModeReflect : MPSGraphPaddingModeSymmetric; break;
|
||||
paddingMode = align_corners == true ? MPSGraphPaddingModeReflect : MPSGraphPaddingModeSymmetric;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "MPS: Unrecognised Padding Mode: ", padding_mode);
|
||||
}
|
||||
|
||||
switch (static_cast<GridSamplerInterpolation>(interpolation_mode)) {
|
||||
case GridSamplerInterpolation::Bilinear:
|
||||
samplingMode = MPSGraphResizeBilinear; break;
|
||||
samplingMode = MPSGraphResizeBilinear;
|
||||
break;
|
||||
case GridSamplerInterpolation::Nearest:
|
||||
samplingMode = MPSGraphResizeNearest; break;
|
||||
samplingMode = MPSGraphResizeNearest;
|
||||
break;
|
||||
case GridSamplerInterpolation::Bicubic:
|
||||
TORCH_CHECK(false, "MPS: Unsupported Bicubic interpolation"); break;
|
||||
TORCH_CHECK(false, "MPS: Unsupported Bicubic interpolation");
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "MPS: Unrecognised interpolation mode: ", interpolation_mode); break;
|
||||
}
|
||||
TORCH_CHECK(false, "MPS: Unrecognised interpolation mode: ", interpolation_mode);
|
||||
break;
|
||||
}
|
||||
|
||||
MPSStream *stream = getCurrentMPSStream();
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* gridTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
@ -55,17 +66,13 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor&
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "grid_sampler_2d_mps" +
|
||||
getTensorsStringKey({input, grid}) +
|
||||
":" + std::to_string(interpolation_mode) +
|
||||
":" + std::to_string(padding_mode) +
|
||||
":" + std::to_string(align_corners);
|
||||
string key = "grid_sampler_2d_mps" + getTensorsStringKey({input, grid}) + ":" + std::to_string(interpolation_mode) +
|
||||
":" + std::to_string(padding_mode) + ":" + std::to_string(align_corners);
|
||||
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
if(!cachedGraph) {
|
||||
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
|
||||
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
@ -75,27 +82,27 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor&
|
||||
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
if (static_cast<GridSamplerInterpolation>(interpolation_mode) == GridSamplerInterpolation::Nearest) {
|
||||
outputTensor = [mpsGraph sampleGridWithSourceTensor: inputTensor
|
||||
coordinateTensor: gridTensor
|
||||
layout: inputTensorLayout
|
||||
normalizeCoordinates: TRUE
|
||||
relativeCoordinates: FALSE
|
||||
alignCorners: align_corners
|
||||
paddingMode: paddingMode
|
||||
nearestRoundingMode: MPSGraphResizeNearestRoundingModeRoundToEven
|
||||
constantValue: 0.0f
|
||||
name: nil];
|
||||
outputTensor = [mpsGraph sampleGridWithSourceTensor:inputTensor
|
||||
coordinateTensor:gridTensor
|
||||
layout:inputTensorLayout
|
||||
normalizeCoordinates:TRUE
|
||||
relativeCoordinates:FALSE
|
||||
alignCorners:align_corners
|
||||
paddingMode:paddingMode
|
||||
nearestRoundingMode:MPSGraphResizeNearestRoundingModeRoundToEven
|
||||
constantValue:0.0f
|
||||
name:nil];
|
||||
} else {
|
||||
outputTensor = [mpsGraph sampleGridWithSourceTensor: inputTensor
|
||||
coordinateTensor: gridTensor
|
||||
layout: inputTensorLayout
|
||||
normalizeCoordinates: TRUE
|
||||
relativeCoordinates: FALSE
|
||||
alignCorners: align_corners
|
||||
paddingMode: paddingMode
|
||||
samplingMode: samplingMode
|
||||
constantValue: 0.0f
|
||||
name: nil];
|
||||
outputTensor = [mpsGraph sampleGridWithSourceTensor:inputTensor
|
||||
coordinateTensor:gridTensor
|
||||
layout:inputTensorLayout
|
||||
normalizeCoordinates:TRUE
|
||||
relativeCoordinates:FALSE
|
||||
alignCorners:align_corners
|
||||
paddingMode:paddingMode
|
||||
samplingMode:samplingMode
|
||||
constantValue:0.0f
|
||||
name:nil];
|
||||
}
|
||||
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
@ -104,29 +111,29 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor&
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
|
||||
Placeholder gridPlaceholder = Placeholder(cachedGraph->gridTensor_, grid);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
|
||||
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
|
||||
gridPlaceholder.getMPSGraphTensor() : gridPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
#endif // defined(__MAC_13_2)
|
||||
}
|
||||
|
||||
Tensor grid_sampler_2d_mps(const Tensor& input, const Tensor& grid,
|
||||
int64_t interpolation_mode, int64_t padding_mode,
|
||||
Tensor grid_sampler_2d_mps(const Tensor& input,
|
||||
const Tensor& grid,
|
||||
int64_t interpolation_mode,
|
||||
int64_t padding_mode,
|
||||
bool align_corners) {
|
||||
#if defined(__MAC_13_2)
|
||||
bool xcode_sdk_13_2_or_higher = true;
|
||||
@ -138,17 +145,16 @@ Tensor grid_sampler_2d_mps(const Tensor& input, const Tensor& grid,
|
||||
TORCH_WARN_ONCE("MPS: grid_sampler_2d op is supported natively starting from macOS 13.1. ",
|
||||
"Falling back on CPU. This may have performance implications.");
|
||||
|
||||
return at::grid_sampler_2d(
|
||||
input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners).clone().to("mps");
|
||||
return at::grid_sampler_2d(input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners)
|
||||
.clone()
|
||||
.to("mps");
|
||||
}
|
||||
|
||||
auto in_size = input.sizes();
|
||||
auto grid_size = grid.sizes();
|
||||
auto output = at::empty(
|
||||
{in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options());
|
||||
auto output = at::empty({in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options());
|
||||
|
||||
grid_sampler_2d_mps_impl(
|
||||
output, input, grid, interpolation_mode, padding_mode, align_corners);
|
||||
grid_sampler_2d_mps_impl(output, input, grid, interpolation_mode, padding_mode, align_corners);
|
||||
return output;
|
||||
}
|
||||
|
||||
|
File diff suppressed because it is too large
@ -1,90 +1,82 @@
#include <ATen/ATen.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <torch/library.h>
#include <ATen/native/mps/OperationUtils.h>
#include <c10/util/Optional.h>

#include <torch/library.h>

namespace at::native {

TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info)
{
|
||||
TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
|
||||
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
|
||||
TORCH_WARN_ONCE("torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU.");
|
||||
auto cpu_info = at::empty({0}, kInt, c10::nullopt, kCPU, c10::nullopt, c10::nullopt);
|
||||
auto cpu_result = result.clone().to("cpu");
|
||||
at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu"));
|
||||
info.copy_(cpu_info);
|
||||
result.copy_(cpu_result);
|
||||
return;
|
||||
}
|
||||
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
|
||||
TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
|
||||
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
|
||||
TORCH_WARN_ONCE(
|
||||
"torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU.");
|
||||
auto cpu_info = at::empty({0}, kInt, c10::nullopt, kCPU, c10::nullopt, c10::nullopt);
|
||||
auto cpu_result = result.clone().to("cpu");
|
||||
at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu"));
|
||||
info.copy_(cpu_info);
|
||||
result.copy_(cpu_result);
|
||||
return;
|
||||
}
|
||||
|
||||
using namespace mps;
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
info.zero_();
|
||||
using namespace mps;
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
info.zero_();
|
||||
|
||||
if (A.numel() == 0) {
|
||||
return;
|
||||
}
|
||||
if (A.numel() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
};
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
};
|
||||
|
||||
Tensor output = result;
|
||||
bool isContiguous = true;
|
||||
if (!result.is_contiguous()) {
|
||||
output = result.contiguous();
|
||||
isContiguous = false;
|
||||
}
|
||||
Tensor output = result;
|
||||
bool isContiguous = true;
|
||||
if (!result.is_contiguous()) {
|
||||
output = result.contiguous();
|
||||
isContiguous = false;
|
||||
}
|
||||
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "inv_out_mps" + getTensorsStringKey({A});
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
if(!cachedGraph)
|
||||
{
|
||||
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
|
||||
@autoreleasepool {
|
||||
string key = "inv_out_mps" + getTensorsStringKey({A});
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, A);
|
||||
MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor:inputTensor name:nil];
|
||||
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
MPSGraphTensor* inputTensor= mpsGraphRankedPlaceHolder(mpsGraph, A);
|
||||
MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor: inputTensor
|
||||
name: nil];
|
||||
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
newCachedGraph->outputTensor_ = outputTensor;
|
||||
}
|
||||
|
||||
return newCachedGraph;
|
||||
|
||||
});
|
||||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
newCachedGraph->outputTensor_ = outputTensor;
|
||||
}
|
||||
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, isContiguous ? result : output);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
if (!isContiguous) {
|
||||
result.copy_(output);
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, isContiguous ? result : output);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
|
||||
@{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
if (!isContiguous) {
|
||||
result.copy_(output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -6,17 +6,14 @@ namespace at::native {

using namespace mps;

Tensor _mps_linear(
const Tensor& input,
const Tensor& weight_arg,
const c10::optional<Tensor>& bias_opt) {
Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const c10::optional<Tensor>& bias_opt) {
// wT = transpose(weight);
// y=x*wT+b

auto weight = (weight_arg.dim() == 1) ? weight_arg.view({1, weight_arg.size(0)}) : weight_arg;

TORCH_CHECK(input.scalar_type() == ScalarType::Float ||
input.scalar_type() == ScalarType::Half, "MPS device does not support linear for non-float inputs");
TORCH_CHECK(input.scalar_type() == ScalarType::Float || input.scalar_type() == ScalarType::Half,
"MPS device does not support linear for non-float inputs");

const Tensor& bias = *(at::borrow_from_optional_tensor(bias_opt));
bool is_bias_defined = bias.defined();
@ -24,24 +21,19 @@ Tensor _mps_linear(
|
||||
auto input_size = input.sizes();
|
||||
std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
|
||||
output_size.push_back(weight.size(0));
|
||||
Tensor output = at::native::empty_mps(output_size,
|
||||
input.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
input.suggest_memory_format());
|
||||
Tensor output = at::native::empty_mps(
|
||||
output_size, input.scalar_type(), c10::nullopt, kMPS, c10::nullopt, input.suggest_memory_format());
|
||||
|
||||
TORCH_CHECK(output.is_mps());
|
||||
|
||||
if(output.numel() == 0) {
|
||||
if (output.numel() == 0) {
|
||||
return output;
|
||||
}
|
||||
|
||||
MPSStream *stream = getCurrentMPSStream();
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* weightTensor_ = nil;
|
||||
MPSGraphTensor* biasTensor_ = nil;
|
||||
@ -51,14 +43,12 @@ Tensor _mps_linear(
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "mps_linear" + getTensorsStringKey({input, weight, bias}) ;
|
||||
string key = "mps_linear" + getTensorsStringKey({input, weight, bias});
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
if(!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
|
||||
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
@ -71,14 +61,11 @@ Tensor _mps_linear(
|
||||
name:nil];
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
|
||||
if (!is_bias_defined)
|
||||
{
|
||||
if (!is_bias_defined) {
|
||||
outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:weightTransposeTensor
|
||||
name:nil];
|
||||
}
|
||||
else
|
||||
{
|
||||
} else {
|
||||
MPSGraphTensor* inputFlattened = inputTensor;
|
||||
bool doReshape = false;
|
||||
// workaround to improve the performance with 3D+ inputs
|
||||
@ -92,9 +79,10 @@ Tensor _mps_linear(
|
||||
secondaryTensor:weightTransposeTensor
|
||||
name:nil];
|
||||
MPSGraphTensor* biasedTensor = [mpsGraph additionWithPrimaryTensor:xMulWTTensor
|
||||
secondaryTensor:newCachedGraph->biasTensor_
|
||||
name:nil];
|
||||
outputTensor = doReshape ? [mpsGraph reshapeTensor:biasedTensor withShape:getMPSShape(output_size) name:nil] : biasedTensor;
|
||||
secondaryTensor:newCachedGraph->biasTensor_
|
||||
name:nil];
|
||||
outputTensor = doReshape ? [mpsGraph reshapeTensor:biasedTensor withShape:getMPSShape(output_size) name:nil]
|
||||
: biasedTensor;
|
||||
}
|
||||
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
@ -110,89 +98,76 @@ Tensor _mps_linear(
|
||||
Placeholder biasPlaceholder = Placeholder();
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
|
||||
|
||||
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =[NSMutableDictionary dictionary];
|
||||
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
|
||||
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
|
||||
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
|
||||
feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData();
|
||||
if (is_bias_defined) {
|
||||
biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias);
|
||||
feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData();
|
||||
}
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
|
||||
// Shave off '1' present at the end of the shape
|
||||
if(weight_arg.dim() == 1) {
|
||||
if (weight_arg.dim() == 1) {
|
||||
// Number of elements in new output shape
|
||||
auto output_sizes = output.sizes();
|
||||
std::vector<int64_t> out_shape(output_sizes.begin(), output_sizes.end()-1);
|
||||
std::vector<int64_t> out_shape(output_sizes.begin(), output_sizes.end() - 1);
|
||||
return output.view(IntArrayRef(out_shape));
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor _mps_linear_backward_input(
|
||||
IntArrayRef input_size,
|
||||
const Tensor & grad_output,
|
||||
const Tensor & weight)
|
||||
{
|
||||
TORCH_CHECK(grad_output.is_mps(),
|
||||
"mps_linear_backward: grad_output needs to be mps layout");
|
||||
TORCH_CHECK(weight.device().is_mps() &&
|
||||
(weight.scalar_type() == kFloat || (weight.scalar_type() == kHalf)),
|
||||
"mps_linear_backward: unsupported weights data type: ", weight.scalar_type());
|
||||
Tensor _mps_linear_backward_input(IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight) {
|
||||
TORCH_CHECK(grad_output.is_mps(), "mps_linear_backward: grad_output needs to be mps layout");
|
||||
TORCH_CHECK(weight.device().is_mps() && (weight.scalar_type() == kFloat || (weight.scalar_type() == kHalf)),
|
||||
"mps_linear_backward: unsupported weights data type: ",
|
||||
weight.scalar_type());
|
||||
|
||||
TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double
|
||||
|| grad_output.scalar_type() == ScalarType::Float
|
||||
|| grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs");
|
||||
TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double || grad_output.scalar_type() == ScalarType::Float ||
|
||||
grad_output.scalar_type() == ScalarType::Half,
|
||||
"MPS device does not support linear backward for non-float inputs");
|
||||
|
||||
const Tensor weight_reshaped = weight.is_contiguous() ? weight : weight.contiguous();
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor *weightTensor_ = nil;
|
||||
MPSGraphTensor *gradOutputTensor_ = nil;
|
||||
MPSGraphTensor *outputTensor_ = nil;
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* weightTensor_ = nil;
|
||||
MPSGraphTensor* gradOutputTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
};
|
||||
|
||||
Tensor output = at::native::empty_mps(input_size,
|
||||
grad_output.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
grad_output.suggest_memory_format());
|
||||
Tensor output = at::native::empty_mps(
|
||||
input_size, grad_output.scalar_type(), c10::nullopt, kMPS, c10::nullopt, grad_output.suggest_memory_format());
|
||||
TORCH_CHECK(output.is_mps());
|
||||
if (grad_output.numel() == 0) {
|
||||
return output;
|
||||
}
|
||||
|
||||
MPSGraphCache *cache_ = MPSGraphCache::getInstance();
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
MPSStream *stream= getCurrentMPSStream();
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
string key = "mps_linear_backward_input" + getTensorsStringKey({grad_output, weight_reshaped});
|
||||
string key = "mps_linear_backward_input" + getTensorsStringKey({grad_output, weight_reshaped});
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
if(!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
MPSGraph *mpsGraph = make_mps_graph();
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor *weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_reshaped);
|
||||
MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_reshaped);
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
|
||||
MPSGraphTensor *outputTensor =
|
||||
[mpsGraph matrixMultiplicationWithPrimaryTensor: gradOutputTensor
|
||||
secondaryTensor: weightTensor
|
||||
name: nil];
|
||||
MPSGraphTensor* outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:gradOutputTensor
|
||||
secondaryTensor:weightTensor
|
||||
name:nil];
|
||||
|
||||
newCachedGraph->weightTensor_ = weightTensor;
|
||||
newCachedGraph->gradOutputTensor_ = gradOutputTensor;
|
||||
@ -211,9 +186,8 @@ Tensor _mps_linear_backward_input(
|
||||
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
|
||||
@ -221,27 +195,27 @@ Tensor _mps_linear_backward_input(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> _mps_linear_backward_weights(
|
||||
const Tensor& grad_output, const Tensor& input, const Tensor& weight, bool bias_defined)
|
||||
{
|
||||
std::tuple<Tensor, Tensor> _mps_linear_backward_weights(const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
bool bias_defined) {
|
||||
TORCH_CHECK(grad_output.is_mps() && input.is_mps(),
|
||||
"_mps_linear_backward: grad_output and input needs to be mps layout");
|
||||
"_mps_linear_backward: grad_output and input needs to be mps layout");
|
||||
|
||||
TORCH_CHECK(grad_output.scalar_type() == ScalarType::Float ||
|
||||
grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs");
|
||||
TORCH_CHECK(grad_output.scalar_type() == ScalarType::Float || grad_output.scalar_type() == ScalarType::Half,
|
||||
"MPS device does not support linear backward for non-float inputs");
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor *inputTensor_ = nil;
|
||||
MPSGraphTensor *weightTensor_ = nil;
|
||||
MPSGraphTensor *gradOutputTensor_ = nil;
|
||||
MPSGraphTensor *outputTensor_ = nil;
|
||||
MPSGraphTensor *biasTensor_ = nil;
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* weightTensor_ = nil;
|
||||
MPSGraphTensor* gradOutputTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
MPSGraphTensor* biasTensor_ = nil;
|
||||
};
|
||||
|
||||
auto grad_output_reshaped = grad_output.dim() != 2 ?
|
||||
grad_output.reshape({-1, grad_output.size(grad_output.dim() - 1)}) : grad_output;
|
||||
auto grad_output_reshaped =
|
||||
grad_output.dim() != 2 ? grad_output.reshape({-1, grad_output.size(grad_output.dim() - 1)}) : grad_output;
|
||||
auto input_reshaped = input.dim() != 2 ? input.reshape({-1, input.size(input.dim() - 1)}) : input;
|
||||
|
||||
TORCH_CHECK(grad_output_reshaped.is_mps());
|
||||
@ -254,59 +228,52 @@ std::tuple<Tensor, Tensor> _mps_linear_backward_weights(
|
||||
c10::nullopt,
|
||||
grad_output.suggest_memory_format());
|
||||
Tensor bias = at::native::empty_mps({grad_output_reshaped.size(1)},
|
||||
grad_output.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
grad_output.suggest_memory_format());
|
||||
grad_output.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
grad_output.suggest_memory_format());
|
||||
TORCH_CHECK(output.is_mps());
|
||||
TORCH_CHECK(bias.is_mps());
|
||||
|
||||
if (grad_output.numel() == 0) {
|
||||
output.zero_();
|
||||
bias.zero_();
|
||||
return std::tuple<Tensor, Tensor>{ output, bias };
|
||||
return std::tuple<Tensor, Tensor>{output, bias};
|
||||
}
|
||||
MPSGraphCache *cache_ = MPSGraphCache::getInstance();
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
MPSStream *stream= getCurrentMPSStream();
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
string key = "mps_linear_backward_weights:" + to_string(bias_defined) + ":" +
|
||||
getTensorsStringKey({input_reshaped, weight, grad_output_reshaped});
|
||||
string key = "mps_linear_backward_weights:" + to_string(bias_defined) + ":" +
|
||||
getTensorsStringKey({input_reshaped, weight, grad_output_reshaped});
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
if(!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
MPSGraph *mpsGraph = make_mps_graph();
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped);
|
||||
MPSGraphTensor *weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);
|
||||
MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output_reshaped);
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped);
|
||||
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output_reshaped);
|
||||
|
||||
MPSGraphTensor *gradOutputTransposeTensor =
|
||||
[mpsGraph transposeTensor: gradOutputTensor
|
||||
dimension: -1
|
||||
withDimension: -2
|
||||
name: nil];
|
||||
MPSGraphTensor* gradOutputTransposeTensor = [mpsGraph transposeTensor:gradOutputTensor
|
||||
dimension:-1
|
||||
withDimension:-2
|
||||
name:nil];
|
||||
|
||||
// grad_weight
|
||||
MPSGraphTensor *outputTensor =
|
||||
[mpsGraph matrixMultiplicationWithPrimaryTensor: gradOutputTransposeTensor
|
||||
secondaryTensor: inputTensor
|
||||
name: nil];
|
||||
MPSGraphTensor *biasTensor = nil;
|
||||
if (bias_defined)
|
||||
{
|
||||
// grad_bias
|
||||
biasTensor = [mpsGraph reductionSumWithTensor: gradOutputTensor
|
||||
axis: 0
|
||||
name: nil];
|
||||
|
||||
MPSGraphTensor* outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:gradOutputTransposeTensor
|
||||
secondaryTensor:inputTensor
|
||||
name:nil];
|
||||
MPSGraphTensor* biasTensor = nil;
|
||||
if (bias_defined) {
|
||||
// grad_bias
|
||||
biasTensor = [mpsGraph reductionSumWithTensor:gradOutputTensor axis:0 name:nil];
|
||||
}
|
||||
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
@ -338,14 +305,14 @@ std::tuple<Tensor, Tensor> _mps_linear_backward_weights(
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
|
||||
return std::tuple<Tensor, Tensor>{ output, bias };
|
||||
return std::tuple<Tensor, Tensor>{output, bias};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> mps_linear_backward(
|
||||
const Tensor& input, const Tensor& grad_output,
|
||||
const Tensor& weight, std::array<bool,3> output_mask) {
|
||||
std::tuple<Tensor, Tensor, Tensor> mps_linear_backward(const Tensor& input,
|
||||
const Tensor& grad_output,
|
||||
const Tensor& weight,
|
||||
std::array<bool, 3> output_mask) {
|
||||
Tensor grad_input, grad_weight, grad_bias;
|
||||
if (output_mask[0]) {
|
||||
grad_input = _mps_linear_backward_input(input.sizes(), grad_output, weight);
|
||||
|
@ -1,8 +1,8 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {

@ -12,22 +12,21 @@ namespace at::native {
*/
|
||||
|
||||
static Tensor prepare_batch_matrix_by_transposing(const Tensor& tensor,
|
||||
bool& transpose_tensor,
|
||||
int64_t& ld_tensor,
|
||||
bool transpose_result,
|
||||
int64_t m, int64_t n) {
|
||||
bool& transpose_tensor,
|
||||
int64_t& ld_tensor,
|
||||
bool transpose_result,
|
||||
int64_t m,
|
||||
int64_t n) {
|
||||
IntArrayRef tensor_strides = tensor.strides();
|
||||
Tensor tensor_;
|
||||
int fast_dim = transpose_result ? 2 : 1;
|
||||
int leading_dim = transpose_result ? 1 : 2;
|
||||
|
||||
if (tensor_strides[fast_dim] == 1 &&
|
||||
(tensor_strides[leading_dim] >= std::max<int64_t>(1, m))) {
|
||||
if (tensor_strides[fast_dim] == 1 && (tensor_strides[leading_dim] >= std::max<int64_t>(1, m))) {
|
||||
transpose_tensor = false;
|
||||
tensor_ = tensor;
|
||||
ld_tensor = tensor_strides[leading_dim];
|
||||
} else if ((tensor_strides[leading_dim] == 1) &&
|
||||
(tensor_strides[fast_dim] >= std::max<int64_t>(1, n))) {
|
||||
} else if ((tensor_strides[leading_dim] == 1) && (tensor_strides[fast_dim] >= std::max<int64_t>(1, n))) {
|
||||
transpose_tensor = true;
|
||||
tensor_ = tensor;
|
||||
ld_tensor = tensor_strides[fast_dim];
|
||||
@ -50,14 +49,13 @@ static Tensor prepare_batch_matrix_by_transposing(const Tensor& tensor,
|
||||
* Helper functions to be used for mm/addmm for detecting the Transpositions
|
||||
* when doing GEMM operations.
|
||||
*/
|
||||
void prepare_matrices_for_broadcasting(
|
||||
const Tensor * bias,
|
||||
const Tensor & self,
|
||||
const Tensor & other,
|
||||
const Scalar * beta,
|
||||
bool * transpose_mat1_times_mat2,
|
||||
bool & transpose_mat1,
|
||||
bool & transpose_mat2) {
|
||||
void prepare_matrices_for_broadcasting(const Tensor* bias,
|
||||
const Tensor& self,
|
||||
const Tensor& other,
|
||||
const Scalar* beta,
|
||||
bool* transpose_mat1_times_mat2,
|
||||
bool& transpose_mat1,
|
||||
bool& transpose_mat2) {
|
||||
TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
|
||||
if (bias && beta->toDouble() != 0.0f) {
|
||||
TORCH_CHECK(bias->dim() == 2, "tensors must be 2-D");
|
||||
@ -79,20 +77,14 @@ void prepare_matrices_for_broadcasting(
|
||||
}
|
||||
}
|
||||
|
||||
enum LinearAlgebraOpType {
|
||||
ADDBMM_OP_TYPE,
|
||||
BADDBMM_OP_TYPE
|
||||
};
|
||||
enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE };
|
||||
|
||||
Tensor& mm_out_mps_impl(
|
||||
const Tensor& self,
|
||||
const Tensor& other,
|
||||
Tensor& output) {
|
||||
Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
|
||||
using namespace mps;
|
||||
TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
|
||||
TORCH_CHECK(self.scalar_type() == ScalarType::Double
|
||||
|| self.scalar_type() == ScalarType::Float
|
||||
|| self.scalar_type() == ScalarType::Half, "MPS device does not support mm for non-float inputs");
|
||||
TORCH_CHECK(self.scalar_type() == ScalarType::Double || self.scalar_type() == ScalarType::Float ||
|
||||
self.scalar_type() == ScalarType::Half,
|
||||
"MPS device does not support mm for non-float inputs");
|
||||
|
||||
TensorArg args[]{{output, "out", 0}, {self, "mat1", 1}, {other, "mat2", 2}};
|
||||
checkAllSameGPU("mm", args);
|
||||
@ -105,47 +97,41 @@ Tensor& mm_out_mps_impl(
|
||||
return output;
|
||||
}
|
||||
|
||||
struct CachedGraph : public mps::MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor *selfTensor_ = nil;
|
||||
MPSGraphTensor *otherTensor_ = nil;
|
||||
MPSGraphTensor *outputTensor_ = nil;
|
||||
struct CachedGraph : public mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* selfTensor_ = nil;
|
||||
MPSGraphTensor* otherTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
};
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
|
||||
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
string key = "mm_out_mps_impl" + getTensorsStringKey({self, other});
|
||||
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
if(!cachedGraph) {
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool{
|
||||
MPSGraph *mpsGraph = mps::make_mps_graph();
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = mps::make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor *selfTensor = nil;
|
||||
MPSGraphTensor *otherTensor = nil;
|
||||
MPSGraphTensor *outputTensor = nil;
|
||||
|
||||
if(self.numel() == 0 || other.numel() == 0) {
|
||||
MPSGraphTensor* selfTensor = nil;
|
||||
MPSGraphTensor* otherTensor = nil;
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
|
||||
if (self.numel() == 0 || other.numel() == 0) {
|
||||
outputTensor = [mpsGraph constantWithScalar:0.
|
||||
shape:getMPSShape(output_sizes)
|
||||
dataType:getMPSDataType(output)];
|
||||
|
||||
}
|
||||
else {
|
||||
dataType:getMPSDataType(output)];
|
||||
|
||||
} else {
|
||||
selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
|
||||
otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
|
||||
outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfTensor
|
||||
secondaryTensor:otherTensor
|
||||
name:nil];
|
||||
@ -157,11 +143,11 @@ Tensor& mm_out_mps_impl(
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
Placeholder selfPlaceholder = Placeholder();
|
||||
Placeholder otherPlaceholder = Placeholder();
|
||||
if(!(self.numel() == 0 || other.numel() == 0)) {
|
||||
if (!(self.numel() == 0 || other.numel() == 0)) {
|
||||
selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
|
||||
otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other);
|
||||
}
|
||||
@ -169,15 +155,14 @@ Tensor& mm_out_mps_impl(
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
|
||||
|
||||
if(!(self.numel() == 0 || other.numel() == 0))
|
||||
if (!(self.numel() == 0 || other.numel() == 0))
|
||||
feeds = @{
|
||||
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
|
||||
otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -185,26 +170,25 @@ Tensor& mm_out_mps_impl(
|
||||
return output;
|
||||
}
|
||||
|
||||
|
||||
Tensor addr_mps(const Tensor& self,
|
||||
const Tensor& vec1, const Tensor& vec2,
|
||||
const Scalar& beta, const Scalar& alpha) {
|
||||
Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
|
||||
Tensor result = at::empty({0}, self.options());
|
||||
addr_out_mps(self, vec1,vec2,beta,alpha,result);
|
||||
addr_out_mps(self, vec1, vec2, beta, alpha, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
Tensor& addr_out_mps(const Tensor& self,
|
||||
const Tensor& vec1, const Tensor& vec2,
|
||||
const Scalar& beta, const Scalar& alpha, Tensor &result) {
|
||||
const Tensor& vec1,
|
||||
const Tensor& vec2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
Tensor& result) {
|
||||
using namespace mps;
|
||||
|
||||
TORCH_CHECK(result.is_mps());
|
||||
TORCH_CHECK(vec1.dim() == 1 && vec2.dim() == 1, "tensors must be 1-D");
|
||||
TORCH_CHECK(vec1.scalar_type() == ScalarType::Double
|
||||
|| vec1.scalar_type() == ScalarType::Float
|
||||
|| vec1.scalar_type() == ScalarType::Half, "MPS device does not support addr for non-float input");
|
||||
TORCH_CHECK(vec1.scalar_type() == ScalarType::Double || vec1.scalar_type() == ScalarType::Float ||
|
||||
vec1.scalar_type() == ScalarType::Half,
|
||||
"MPS device does not support addr for non-float input");
|
||||
|
||||
TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {vec1, "vec1", 2}, {vec2, "vec2", 3}};
|
||||
checkAllSameGPU(__func__, args);
|
||||
@ -239,37 +223,34 @@ Tensor& addr_out_mps(const Tensor& self,
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
bool is_beta_non_zero = beta.toDouble() != 0.0;
|
||||
MPSShape* inputShape = @[@(vec1.numel()), @(1)];
|
||||
MPSShape* otherShape = @[@(1), @(vec2.numel())];
|
||||
MPSShape* inputShape = @[ @(vec1.numel()), @(1) ];
|
||||
MPSShape* otherShape = @[ @(1), @(vec2.numel()) ];
|
||||
|
||||
struct CachedGraph : public mps::MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor *vec1Tensor_ = nil;
|
||||
MPSGraphTensor *vec2Tensor_ = nil;
|
||||
MPSGraphTensor *selfTensor_ = nil;
|
||||
MPSGraphTensor *resultTensor_ = nil;
|
||||
struct CachedGraph : public mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* vec1Tensor_ = nil;
|
||||
MPSGraphTensor* vec2Tensor_ = nil;
|
||||
MPSGraphTensor* selfTensor_ = nil;
|
||||
MPSGraphTensor* resultTensor_ = nil;
|
||||
};
|
||||
|
||||
mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
|
||||
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_})
|
||||
+ ":" + to_string(beta.toDouble())
|
||||
+ ":" + to_string(alpha.toDouble());
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
if(!cachedGraph) {
|
||||
string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" + to_string(beta.toDouble()) +
|
||||
":" + to_string(alpha.toDouble());
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool{
|
||||
MPSGraph *mpsGraph = mps::make_mps_graph();
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = mps::make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor *t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1), inputShape);
|
||||
MPSGraphTensor *t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2), otherShape);
|
||||
MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *self_);
|
||||
MPSGraphTensor* t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1), inputShape);
|
||||
MPSGraphTensor* t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2), otherShape);
|
||||
MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *self_);
|
||||
|
||||
// Intermediate as placeholder
|
||||
MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:t1
|
||||
@ -280,7 +261,7 @@ Tensor& addr_out_mps(const Tensor& self,
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta.toDouble()
|
||||
dataType:getMPSScalarType((*self_).scalar_type())];
|
||||
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.toDouble()
|
||||
dataType:getMPSScalarType(vec1.scalar_type())];
|
||||
dataType:getMPSScalarType(vec1.scalar_type())];
|
||||
|
||||
// Intermediates for multiplying by beta and alpha
|
||||
MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor:productTensor
|
||||
@ -298,7 +279,7 @@ Tensor& addr_out_mps(const Tensor& self,
|
||||
resultTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor
|
||||
secondaryTensor:selfTimesBetaTensor
|
||||
name:@"MM/beta*input+alpha*(vec1@vec2)"];
|
||||
}
|
||||
}
|
||||
|
||||
newCachedGraph->vec1Tensor_ = t1;
|
||||
newCachedGraph->vec2Tensor_ = t2;
|
||||
@ -307,7 +288,7 @@ Tensor& addr_out_mps(const Tensor& self,
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
|
||||
Placeholder vec1Placeholder = Placeholder(cachedGraph->vec1Tensor_, vec1, inputShape);
|
||||
@ -321,9 +302,8 @@ Tensor& addr_out_mps(const Tensor& self,
|
||||
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -331,20 +311,19 @@ Tensor& addr_out_mps(const Tensor& self,
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor& addmm_out_mps_impl(
|
||||
const Tensor& bias,
|
||||
const Tensor& self, // input
|
||||
const Tensor& other, // weight
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
Tensor& output) {
|
||||
Tensor& addmm_out_mps_impl(const Tensor& bias,
|
||||
const Tensor& self, // input
|
||||
const Tensor& other, // weight
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
Tensor& output) {
|
||||
using namespace mps;
|
||||
|
||||
TORCH_CHECK(output.is_mps());
|
||||
TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
|
||||
TORCH_CHECK(self.scalar_type() == ScalarType::Double
|
||||
|| self.scalar_type() == ScalarType::Float
|
||||
|| self.scalar_type() == ScalarType::Half, "MPS device does not support addmm for non-float input");
|
||||
TORCH_CHECK(self.scalar_type() == ScalarType::Double || self.scalar_type() == ScalarType::Float ||
|
||||
self.scalar_type() == ScalarType::Half,
|
||||
"MPS device does not support addmm for non-float input");
|
||||
|
||||
TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}};
|
||||
checkAllSameGPU(__func__, args);
|
||||
@ -378,62 +357,52 @@ Tensor& addmm_out_mps_impl(
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
bool transpose_mat1_times_mat2 = false;
|
||||
bool transpose_mat1 = false;
|
||||
bool transpose_mat2 = false;
|
||||
bool is_beta_non_zero = beta.toDouble() != 0.0;
|
||||
bool transpose_mat1 = false;
|
||||
bool transpose_mat2 = false;
|
||||
bool is_beta_non_zero = beta.toDouble() != 0.0;
|
||||
|
||||
prepare_matrices_for_broadcasting(&(*bias_), self, other, &beta, &transpose_mat1_times_mat2, transpose_mat1, transpose_mat2);
|
||||
prepare_matrices_for_broadcasting(
|
||||
&(*bias_), self, other, &beta, &transpose_mat1_times_mat2, transpose_mat1, transpose_mat2);
|
||||
|
||||
struct CachedGraph : public mps::MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor *selfTensor_ = nil;
|
||||
MPSGraphTensor *otherTensor_ = nil;
|
||||
MPSGraphTensor *biasTensor_ = nil;
|
||||
MPSGraphTensor *outputTensor_ = nil;
|
||||
struct CachedGraph : public mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* selfTensor_ = nil;
|
||||
MPSGraphTensor* otherTensor_ = nil;
|
||||
MPSGraphTensor* biasTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
};
|
||||
|
||||
mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
|
||||
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_})
|
||||
+ ":" + to_string(transpose_mat1) + ":" + to_string(transpose_mat2)
|
||||
+ ":" + to_string(beta.toDouble())
|
||||
+ ":" + to_string(alpha.toDouble());
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
if(!cachedGraph) {
|
||||
string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + to_string(transpose_mat1) +
|
||||
":" + to_string(transpose_mat2) + ":" + to_string(beta.toDouble()) + ":" + to_string(alpha.toDouble());
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool{
|
||||
MPSGraph *mpsGraph = mps::make_mps_graph();
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = mps::make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
MPSGraphTensor *otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
|
||||
MPSGraphTensor *biasTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *bias_);
|
||||
MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
MPSGraphTensor* otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
|
||||
MPSGraphTensor* biasTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *bias_);
|
||||
|
||||
MPSGraphTensor* t1 = nil;
|
||||
MPSGraphTensor* t2 = nil;
|
||||
|
||||
if(transpose_mat1)
|
||||
t1 = [mpsGraph transposeTensor:selfTensor
|
||||
dimension:-1
|
||||
withDimension:-2
|
||||
name:nil];
|
||||
if (transpose_mat1)
|
||||
t1 = [mpsGraph transposeTensor:selfTensor dimension:-1 withDimension:-2 name:nil];
|
||||
else
|
||||
t1 = selfTensor;
|
||||
|
||||
if(transpose_mat2)
|
||||
t2 = [mpsGraph transposeTensor:otherTensor
|
||||
dimension:-1
|
||||
withDimension:-2
|
||||
name:nil];
|
||||
if (transpose_mat2)
|
||||
t2 = [mpsGraph transposeTensor:otherTensor dimension:-1 withDimension:-2 name:nil];
|
||||
else
|
||||
t2 = otherTensor;
|
||||
|
||||
|
||||
// TODO: Use alpha and beta here with fill_.Scalar and mul
|
||||
// Intermediate as placeholder
|
||||
MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:t1
|
||||
@ -444,7 +413,7 @@ Tensor& addmm_out_mps_impl(
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta.toDouble()
|
||||
dataType:getMPSScalarType((*bias_).scalar_type())];
|
||||
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.toDouble()
|
||||
dataType:getMPSScalarType(self.scalar_type())];
|
||||
dataType:getMPSScalarType(self.scalar_type())];
|
||||
|
||||
// Intermediates for multiplying by beta and alpha
|
||||
MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor:productTensor
|
||||
@ -458,17 +427,14 @@ Tensor& addmm_out_mps_impl(
|
||||
}
|
||||
|
||||
if (transpose_mat1_times_mat2)
|
||||
biasTimesBetaTensor = [mpsGraph transposeTensor: biasTimesBetaTensor
|
||||
dimension: -1
|
||||
withDimension: -2
|
||||
name: nil];
|
||||
biasTimesBetaTensor = [mpsGraph transposeTensor:biasTimesBetaTensor dimension:-1 withDimension:-2 name:nil];
|
||||
|
||||
MPSGraphTensor* outputTensor = productTimesAlphaTensor;
|
||||
if (is_beta_non_zero) {
|
||||
outputTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor
|
||||
secondaryTensor:biasTimesBetaTensor
|
||||
name:@"MM/beta*input + alpha*(mat1@mat2)"];
|
||||
}
|
||||
}
|
||||
|
||||
newCachedGraph->selfTensor_ = selfTensor;
|
||||
newCachedGraph->otherTensor_ = otherTensor;
|
||||
@ -477,7 +443,7 @@ Tensor& addmm_out_mps_impl(
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
|
||||
Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
|
||||
@ -491,9 +457,8 @@ Tensor& addmm_out_mps_impl(
|
||||
biasPlaceholder.getMPSGraphTensor() : biasPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -501,16 +466,12 @@ Tensor& addmm_out_mps_impl(
|
||||
return output;
|
||||
}
|
||||
|
||||
|
||||
Tensor& bmm_out_mps_impl(
|
||||
const Tensor & batch1,
|
||||
const Tensor & batch2,
|
||||
Tensor & result) {
|
||||
Tensor& bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tensor& result) {
|
||||
using namespace mps;
|
||||
|
||||
TORCH_CHECK(batch1.scalar_type() == ScalarType::Double
|
||||
|| batch1.scalar_type() == ScalarType::Float
|
||||
|| batch1.scalar_type() == ScalarType::Half, "MPS device does not support bmm for non-float inputs");
|
||||
TORCH_CHECK(batch1.scalar_type() == ScalarType::Double || batch1.scalar_type() == ScalarType::Float ||
|
||||
batch1.scalar_type() == ScalarType::Half,
|
||||
"MPS device does not support bmm for non-float inputs");
|
||||
|
||||
if (batch1.numel() == 0 || batch2.numel() == 0) {
|
||||
result.zero_();
|
||||
@ -519,31 +480,29 @@ Tensor& bmm_out_mps_impl(
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
struct CachedGraph : public mps::MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor *batch1Tensor_ = nil;
|
||||
MPSGraphTensor *batch2Tensor_ = nil;
|
||||
MPSGraphTensor *outputTensor_ = nil;
|
||||
struct CachedGraph : public mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* batch1Tensor_ = nil;
|
||||
MPSGraphTensor* batch2Tensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
};
|
||||
|
||||
mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
|
||||
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "bmm_out_mps_impl" + getTensorsStringKey({batch1, batch2});
|
||||
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
if(!cachedGraph) {
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool{
|
||||
MPSGraph *mpsGraph = mps::make_mps_graph();
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = mps::make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor *batch1Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch1);
|
||||
MPSGraphTensor *batch2Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch2);
|
||||
MPSGraphTensor* batch1Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch1);
|
||||
MPSGraphTensor* batch2Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch2);
|
||||
|
||||
MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:batch1Tensor
|
||||
secondaryTensor:batch2Tensor
|
||||
@ -555,7 +514,7 @@ Tensor& bmm_out_mps_impl(
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
Placeholder batch1Placeholder = Placeholder(cachedGraph->batch1Tensor_, batch1);
|
||||
Placeholder batch2Placeholder = Placeholder(cachedGraph->batch2Tensor_, batch2);
|
||||
@ -566,9 +525,8 @@ Tensor& bmm_out_mps_impl(
|
||||
batch2Placeholder.getMPSGraphTensor() : batch2Placeholder.getMPSGraphTensorData(),
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -576,14 +534,13 @@ Tensor& bmm_out_mps_impl(
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor& addbmm_or_baddbmm_out_mps_impl(
|
||||
const Tensor & input,
|
||||
const Tensor & batch1,
|
||||
const Tensor & batch2,
|
||||
const Scalar & beta,
|
||||
const Scalar & alpha,
|
||||
Tensor & result,
|
||||
LinearAlgebraOpType opType) {
|
||||
Tensor& addbmm_or_baddbmm_out_mps_impl(const Tensor& input,
|
||||
const Tensor& batch1,
|
||||
const Tensor& batch2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
Tensor& result,
|
||||
LinearAlgebraOpType opType) {
|
||||
using namespace mps;
|
||||
|
||||
TORCH_CHECK(input.is_mps());
|
||||
@ -591,22 +548,29 @@ Tensor& addbmm_or_baddbmm_out_mps_impl(
|
||||
TORCH_CHECK(batch2.is_mps());
|
||||
TORCH_CHECK(result.is_mps());
|
||||
|
||||
TORCH_CHECK(batch1.scalar_type() == ScalarType::Double
|
||||
|| batch1.scalar_type() == ScalarType::Float
|
||||
|| batch1.scalar_type() == ScalarType::Half, "MPS device does not support addbmm or baddbmm for non-float inputs");
|
||||
TORCH_CHECK(batch1.scalar_type() == ScalarType::Double || batch1.scalar_type() == ScalarType::Float ||
|
||||
batch1.scalar_type() == ScalarType::Half,
|
||||
"MPS device does not support addbmm or baddbmm for non-float inputs");
|
||||
|
||||
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
|
||||
TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
|
||||
TORCH_CHECK(batch1.size(0) == batch2.size(0),
|
||||
"batch1 and batch2 must have same number of batches, got ",
|
||||
batch1.size(0), " and ", batch2.size(0));
|
||||
"batch1 and batch2 must have same number of batches, got ",
|
||||
batch1.size(0),
|
||||
" and ",
|
||||
batch2.size(0));
|
||||
TORCH_CHECK(batch1.size(2) == batch2.size(1),
|
||||
"Incompatible matrix sizes for bmm (",
|
||||
batch1.size(1), "x", batch1.size(2), " and ",
|
||||
batch2.size(1), "x", batch2.size(2), ")");
|
||||
"Incompatible matrix sizes for bmm (",
|
||||
batch1.size(1),
|
||||
"x",
|
||||
batch1.size(2),
|
||||
" and ",
|
||||
batch2.size(1),
|
||||
"x",
|
||||
batch2.size(2),
|
||||
")");
|
||||
|
||||
if (opType == ADDBMM_OP_TYPE)
|
||||
{
|
||||
if (opType == ADDBMM_OP_TYPE) {
|
||||
result.resize_as_(input);
|
||||
|
||||
const int64_t num_batches = batch1.size(0);
|
||||
@ -619,42 +583,39 @@ Tensor& addbmm_or_baddbmm_out_mps_impl(
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
struct CachedGraph : public mps::MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor *inputTensor_ = nil;
|
||||
MPSGraphTensor *batch1Tensor_ = nil;
|
||||
MPSGraphTensor *batch2Tensor_ = nil;
|
||||
MPSGraphTensor *outputTensor_ = nil;
|
||||
struct CachedGraph : public mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* batch1Tensor_ = nil;
|
||||
MPSGraphTensor* batch2Tensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
};
|
||||
|
||||
mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
|
||||
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = (opType == ADDBMM_OP_TYPE) ? ("addbmm_out_mps_impl") : ("baddbmm_out_mps_impl");
|
||||
key += getTensorsStringKey({batch1, batch2, input})
|
||||
+ ":" + to_string(beta.toDouble())
|
||||
+ ":" + to_string(alpha.toDouble());
|
||||
key += getTensorsStringKey({batch1, batch2, input}) + ":" + to_string(beta.toDouble()) + ":" +
|
||||
to_string(alpha.toDouble());
|
||||
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
if(!cachedGraph) {
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool{
|
||||
MPSGraph *mpsGraph = mps::make_mps_graph();
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = mps::make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor *inputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
MPSGraphTensor *batch1Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch1);
|
||||
MPSGraphTensor *batch2Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch2);
|
||||
MPSGraphTensor* inputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
MPSGraphTensor* batch1Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch1);
|
||||
MPSGraphTensor* batch2Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch2);
|
||||
|
||||
// Intermediates for beta and alpha
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar: beta.toDouble()
|
||||
dataType: getMPSScalarType(input.scalar_type())];
|
||||
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar: alpha.toDouble()
|
||||
dataType: getMPSScalarType(batch1.scalar_type())];
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta.toDouble()
|
||||
dataType:getMPSScalarType(input.scalar_type())];
|
||||
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.toDouble()
|
||||
dataType:getMPSScalarType(batch1.scalar_type())];
|
||||
|
||||
MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:batch1Tensor
|
||||
secondaryTensor:batch2Tensor
|
||||
@ -662,46 +623,46 @@ Tensor& addbmm_or_baddbmm_out_mps_impl(
|
||||
|
||||
MPSGraphTensor* reductionSumTensor = productTensor;
|
||||
if (opType == ADDBMM_OP_TYPE) {
|
||||
reductionSumTensor = [mpsGraph reductionSumWithTensor: productTensor
|
||||
axis: 0
|
||||
name: @"reductionSum(batch1@batch2)"];
|
||||
reductionSumTensor = [mpsGraph reductionSumWithTensor:productTensor
|
||||
axis:0
|
||||
name:@"reductionSum(batch1@batch2)"];
|
||||
}
|
||||
|
||||
// Intermediates for multiplying by beta and alpha
|
||||
MPSGraphTensor* reductionSumTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor: reductionSumTensor
|
||||
secondaryTensor: alphaTensor
|
||||
name: @"alpha*(batch1@batch2)"];
|
||||
MPSGraphTensor* biasTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor: inputTensor
|
||||
secondaryTensor: betaTensor
|
||||
name: @"beta*input"];
|
||||
MPSGraphTensor* reductionSumTimesAlphaTensor =
|
||||
[mpsGraph multiplicationWithPrimaryTensor:reductionSumTensor
|
||||
secondaryTensor:alphaTensor
|
||||
name:@"alpha*(batch1@batch2)"];
|
||||
MPSGraphTensor* biasTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:betaTensor
|
||||
name:@"beta*input"];
|
||||
|
||||
MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor:reductionSumTimesAlphaTensor
|
||||
secondaryTensor:biasTimesBetaTensor
|
||||
name:@"beta*input + alpha*(batch1@batch2)"];
|
||||
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
newCachedGraph->batch1Tensor_ = batch1Tensor;
|
||||
newCachedGraph->batch2Tensor_ = batch2Tensor;
|
||||
newCachedGraph->outputTensor_ = outputTensor;
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
|
||||
Placeholder batch1Placeholder = Placeholder(cachedGraph->batch1Tensor_, batch1);
|
||||
Placeholder batch2Placeholder = Placeholder(cachedGraph->batch2Tensor_, batch2);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
|
||||
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
|
||||
batch1Placeholder.getMPSGraphTensor() : batch1Placeholder.getMPSGraphTensorData(),
|
||||
batch2Placeholder.getMPSGraphTensor() : batch2Placeholder.getMPSGraphTensorData(),
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -713,40 +674,67 @@ TORCH_IMPL_FUNC(mm_out_mps)(const Tensor& self, const Tensor& mat2, const Tensor
|
||||
mm_out_mps_impl(self, mat2, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(addmm_out_mps)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
|
||||
TORCH_IMPL_FUNC(addmm_out_mps)
|
||||
(const Tensor& self,
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
const Tensor& result) {
|
||||
addmm_out_mps_impl(self, mat1, mat2, beta, alpha, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(bmm_out_mps) (const Tensor & batch1, const Tensor & batch2, const Tensor & result) {
|
||||
TORCH_IMPL_FUNC(bmm_out_mps)(const Tensor& batch1, const Tensor& batch2, const Tensor& result) {
|
||||
bmm_out_mps_impl(batch1, batch2, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(baddbmm_out_mps) (const Tensor & self, const Tensor & batch1, const Tensor & batch2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
|
||||
TORCH_IMPL_FUNC(baddbmm_out_mps)
|
||||
(const Tensor& self,
|
||||
const Tensor& batch1,
|
||||
const Tensor& batch2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
const Tensor& result) {
|
||||
addbmm_or_baddbmm_out_mps_impl(self, batch1, batch2, beta, alpha, const_cast<Tensor&>(result), BADDBMM_OP_TYPE);
|
||||
}
|
||||
|
||||
Tensor& addbmm_out_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, Tensor& result) {
|
||||
Tensor& addbmm_out_mps(const Tensor& self,
|
||||
const Tensor& batch1,
|
||||
const Tensor& batch2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
Tensor& result) {
|
||||
auto b_self = expand_size(self, {batch1.size(1), batch2.size(2)}, "addbmm_out");
|
||||
|
||||
addbmm_or_baddbmm_out_mps_impl(*b_self, batch1, batch2, beta, alpha, result, ADDBMM_OP_TYPE);
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor addbmm_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
|
||||
Tensor addbmm_mps(const Tensor& self,
|
||||
const Tensor& batch1,
|
||||
const Tensor& batch2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha) {
|
||||
Tensor result = at::empty({0}, self.options());
|
||||
return addbmm_out_mps(self, batch1, batch2, beta, alpha, result);
|
||||
}
|
||||
|
||||
Tensor &addbmm_mps_(Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
|
||||
Tensor& addbmm_mps_(Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
|
||||
return addbmm_out_mps(self, batch1, batch2, beta, alpha, self);
|
||||
}
|
||||
|
||||
Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool upper, bool transpose, bool left, bool unitriangular, Tensor& out) {
|
||||
Tensor& linalg_solve_triangular_mps_impl(const Tensor& A,
|
||||
const Tensor& B,
|
||||
bool upper,
|
||||
bool transpose,
|
||||
bool left,
|
||||
bool unitriangular,
|
||||
Tensor& out) {
|
||||
using namespace mps;
|
||||
|
||||
checkInputsSolver(A, B, left, "linalg.solve_triangular");
|
||||
Tensor A_t, B_t;
|
||||
std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/nullptr);
|
||||
std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/ nullptr);
|
||||
at::native::resize_output(out, B_t.sizes());
|
||||
|
||||
if (A.numel() == 0 || B.numel() == 0 || out.numel() == 0) {
|
||||
@ -768,7 +756,7 @@ Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
id<MTLDevice> device = MPSDevice::getInstance()->device();
|
||||
|
||||
dispatch_sync(mpsStream->queue(), ^(){
|
||||
dispatch_sync(mpsStream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
|
||||
uint64_t batchSize = A_.sizes().size() > 2 ? A_.size(0) : 1;
|
||||
@ -779,7 +767,7 @@ Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool
|
||||
uint64_t aElemSize = A_.element_size();
|
||||
uint64_t bElemSize = B_.element_size();
|
||||
|
||||
MPSMatrixSolveTriangular *filter = [[[MPSMatrixSolveTriangular alloc] initWithDevice:device
|
||||
MPSMatrixSolveTriangular* filter = [[[MPSMatrixSolveTriangular alloc] initWithDevice:device
|
||||
right:!left
|
||||
upper:upper
|
||||
transpose:transpose
|
||||
@ -794,22 +782,24 @@ Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool
|
||||
rowBytes:aCols * aElemSize
|
||||
matrixBytes:aRows * aCols * aElemSize
|
||||
dataType:getMPSDataType(A_)];
|
||||
MPSMatrixDescriptor* rightHandSideMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:bRows
|
||||
columns:bCols
|
||||
matrices:batchSize
|
||||
rowBytes:bCols * bElemSize
|
||||
matrixBytes:bRows * bCols * bElemSize
|
||||
dataType:getMPSDataType(B_)];
|
||||
for (const auto i: c10::irange(batchSize)) {
|
||||
MPSMatrixDescriptor* rightHandSideMatrixDesc =
|
||||
[MPSMatrixDescriptor matrixDescriptorWithRows:bRows
|
||||
columns:bCols
|
||||
matrices:batchSize
|
||||
rowBytes:bCols * bElemSize
|
||||
matrixBytes:bRows * bCols * bElemSize
|
||||
dataType:getMPSDataType(B_)];
|
||||
for (const auto i : c10::irange(batchSize)) {
|
||||
const uint64_t aBatchOffset = i * aRows * aCols;
|
||||
const uint64_t bBatchOffset = i * bRows * bCols;
|
||||
MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer
|
||||
offset:(A_t.storage_offset() + aBatchOffset) * aElemSize
|
||||
descriptor:sourceMatrixDesc] autorelease];
|
||||
MPSMatrix* rightHandSideMatrix = [[[MPSMatrix alloc] initWithBuffer:bBuffer
|
||||
offset:(B_t.storage_offset() + bBatchOffset) * bElemSize
|
||||
descriptor:rightHandSideMatrixDesc] autorelease];
|
||||
MPSMatrix *solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:outBuffer
|
||||
MPSMatrix* rightHandSideMatrix =
|
||||
[[[MPSMatrix alloc] initWithBuffer:bBuffer
|
||||
offset:(B_t.storage_offset() + bBatchOffset) * bElemSize
|
||||
descriptor:rightHandSideMatrixDesc] autorelease];
|
||||
MPSMatrix* solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:outBuffer
|
||||
offset:(out.storage_offset() + bBatchOffset) * bElemSize
|
||||
descriptor:rightHandSideMatrixDesc] autorelease];
|
||||
|
||||
@ -824,7 +814,12 @@ Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor& linalg_solve_triangular_mps_out( const Tensor& A, const Tensor& B, bool upper, bool left, bool unitriangular, Tensor& out) {
|
||||
Tensor& linalg_solve_triangular_mps_out(const Tensor& A,
|
||||
const Tensor& B,
|
||||
bool upper,
|
||||
bool left,
|
||||
bool unitriangular,
|
||||
Tensor& out) {
|
||||
return linalg_solve_triangular_mps_impl(A, B, upper, /*transpose=*/false, left, unitriangular, out);
|
||||
}
|
||||
|
||||
@ -834,7 +829,14 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper,
|
||||
return out;
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(triangular_solve_mps_out)(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular, const Tensor& result, const Tensor& clone_A) {
|
||||
TORCH_IMPL_FUNC(triangular_solve_mps_out)
|
||||
(const Tensor& self,
|
||||
const Tensor& A,
|
||||
bool upper,
|
||||
bool transpose,
|
||||
bool unitriangular,
|
||||
const Tensor& result,
|
||||
const Tensor& clone_A) {
|
||||
clone_A.copy_(A);
|
||||
Tensor out = empty_mps({0}, A.scalar_type(), c10::nullopt, kMPS, c10::nullopt, MemoryFormat::Contiguous);
|
||||
linalg_solve_triangular_mps_impl(A, self, upper, transpose, /*left=*/true, unitriangular, out);
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -7,15 +7,18 @@ namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
// Pad operations (1D/2D/3D forward and backward)
|
||||
Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef padding,
|
||||
Tensor& pad_out_template(Tensor& output,
|
||||
const Tensor& input_,
|
||||
IntArrayRef padding,
|
||||
const c10::optional<Tensor>& grad_output_opt,
|
||||
MPSGraphPaddingMode mode, double constantValue, const string op_name)
|
||||
{
|
||||
const int padding_size = (int) padding.size();
|
||||
MPSGraphPaddingMode mode,
|
||||
double constantValue,
|
||||
const string op_name) {
|
||||
const int padding_size = (int)padding.size();
|
||||
int padding_dim = padding_size / 2; // either 1D, 2D, or 3D
|
||||
|
||||
TORCH_CHECK(padding_size == 2 || padding_size == 4 || padding_size == 6,
|
||||
"invalid padding argument of size ", padding_size);
|
||||
TORCH_CHECK(
|
||||
padding_size == 2 || padding_size == 4 || padding_size == 6, "invalid padding argument of size ", padding_size);
|
||||
|
||||
const Tensor& grad_output_ = *(at::borrow_from_optional_tensor(grad_output_opt));
|
||||
const bool is_backward_pass = grad_output_.defined();
|
||||
@ -23,8 +26,13 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
int64_t nbatch = 1;
|
||||
int64_t ndims = input_.ndimension();
|
||||
|
||||
TORCH_CHECK(ndims >= (int64_t)padding_dim, "Length of pad should be no more than twice the number of "
|
||||
"dimensions of the input. Pad length is ", padding_size, "while the input has ", ndims, "dimensions.");
|
||||
TORCH_CHECK(ndims >= (int64_t)padding_dim,
|
||||
"Length of pad should be no more than twice the number of "
|
||||
"dimensions of the input. Pad length is ",
|
||||
padding_size,
|
||||
"while the input has ",
|
||||
ndims,
|
||||
"dimensions.");
|
||||
|
||||
// number of input dims with ConstantPad could be less than 2
|
||||
int dim_w = padding_dim;
|
||||
@ -35,8 +43,9 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
if (!is_backward_pass && mode != MPSGraphPaddingModeConstant && ndims > padding_dim) {
|
||||
bool valid_dims = input_.size(1) != 0 && input_.size(padding_dim) != 0;
|
||||
TORCH_CHECK((ndims == 1 + padding_dim && valid_dims) ||
|
||||
(ndims == 2 + padding_dim && valid_dims && input_.size(1 + padding_dim) != 0),
|
||||
"3D or 4D (batch mode) tensor expected for input, but got: ", input_);
|
||||
(ndims == 2 + padding_dim && valid_dims && input_.size(1 + padding_dim) != 0),
|
||||
"3D or 4D (batch mode) tensor expected for input, but got: ",
|
||||
input_);
|
||||
}
|
||||
|
||||
if (ndims == padding_dim) {
|
||||
@ -59,11 +68,11 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
int64_t pad_t = padding_size > 2 ? padding[2] : 0;
|
||||
int64_t pad_b = padding_size > 2 ? padding[3] : 0;
|
||||
int64_t pad_front = padding_size > 4 ? padding[4] : 0;
|
||||
int64_t pad_back = padding_size > 4 ? padding[5] : 0;
|
||||
int64_t pad_back = padding_size > 4 ? padding[5] : 0;
|
||||
|
||||
int64_t nplane = input_.size(dim_slices);
|
||||
int64_t input_w = input_.size(dim_w);
|
||||
int64_t output_w = input_w + pad_l + pad_r;
|
||||
int64_t output_w = input_w + pad_l + pad_r;
|
||||
int64_t input_h = padding_dim > 1 ? input_.size(dim_h) : 0;
|
||||
int64_t output_h = padding_dim > 1 ? input_h + pad_t + pad_b : 0;
|
||||
int64_t input_d = padding_dim > 2 ? input_.size(dim_d) : 0;
|
||||
@ -73,8 +82,15 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
|
||||
if (!is_backward_pass) {
|
||||
TORCH_CHECK(output_w >= 1 || output_h >= padding_dim - 1,
|
||||
"input (H: ", input_h, ", W: ", input_w, ") is too small. Calculated "
|
||||
"output H: ", output_h, " W: ", output_w);
|
||||
"input (H: ",
|
||||
input_h,
|
||||
", W: ",
|
||||
input_w,
|
||||
") is too small. Calculated "
|
||||
"output H: ",
|
||||
output_h,
|
||||
" W: ",
|
||||
output_w);
|
||||
|
||||
std::vector<int64_t> outputSizes;
|
||||
if (mode == MPSGraphPaddingModeConstant) {
|
||||
@ -83,7 +99,7 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
auto ori_padding_dim = padding_size / 2;
|
||||
auto l_diff = ndims - ori_padding_dim;
|
||||
|
||||
for (size_t i = 0; i < (size_t)l_diff; i ++) {
|
||||
for (size_t i = 0; i < (size_t)l_diff; i++) {
|
||||
outputSizes.emplace_back(input_sizes[i]);
|
||||
}
|
||||
for (const auto i : c10::irange((size_t)ori_padding_dim)) {
|
||||
@ -94,21 +110,39 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
} else {
|
||||
// these checks aren't relevant for constant pad
|
||||
TORCH_CHECK(pad_l < input_w && pad_r < input_w,
|
||||
"Argument #4: Padding size should be less than the corresponding "
|
||||
"input dimension, but got: padding (", pad_l, ", ", pad_r,
|
||||
") at dimension ", dim_w, " of input ", ndims);
|
||||
"Argument #4: Padding size should be less than the corresponding "
|
||||
"input dimension, but got: padding (",
|
||||
pad_l,
|
||||
", ",
|
||||
pad_r,
|
||||
") at dimension ",
|
||||
dim_w,
|
||||
" of input ",
|
||||
ndims);
|
||||
|
||||
if (padding_dim > 1) {
|
||||
TORCH_CHECK(pad_t < input_h && pad_b < input_h,
|
||||
"Argument #6: Padding size should be less than the corresponding "
|
||||
"input dimension, but got: padding (", pad_t, ", ", pad_b,
|
||||
") at dimension ", dim_h, " of input ", ndims);
|
||||
"Argument #6: Padding size should be less than the corresponding "
|
||||
"input dimension, but got: padding (",
|
||||
pad_t,
|
||||
", ",
|
||||
pad_b,
|
||||
") at dimension ",
|
||||
dim_h,
|
||||
" of input ",
|
||||
ndims);
|
||||
}
|
||||
if (padding_dim > 2) {
|
||||
TORCH_CHECK(pad_front < input_d && pad_back < input_d,
|
||||
"Argument #8: Padding size should be less than the corresponding "
|
||||
"input dimension, but got: padding (", pad_front, ", ", pad_back,
|
||||
") at dimension ", dim_d, " of input ", ndims);
|
||||
"Argument #8: Padding size should be less than the corresponding "
|
||||
"input dimension, but got: padding (",
|
||||
pad_front,
|
||||
", ",
|
||||
pad_back,
|
||||
") at dimension ",
|
||||
dim_d,
|
||||
" of input ",
|
||||
ndims);
|
||||
}
|
||||
outputSizes.insert(outputSizes.begin(), output_w);
|
||||
if (padding_dim >= 2)
|
||||
@ -133,10 +167,16 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
input = input_.contiguous();
|
||||
} else {
|
||||
TORCH_CHECK(output_w == grad_output_.size(dim_w),
|
||||
"gradOutput width unexpected. Expected: ", output_w, ", Got: ", grad_output_.size(dim_w));
|
||||
"gradOutput width unexpected. Expected: ",
|
||||
output_w,
|
||||
", Got: ",
|
||||
grad_output_.size(dim_w));
|
||||
if (padding_dim > 1) {
|
||||
TORCH_CHECK(output_h == grad_output_.size(dim_h),
|
||||
"gradOutput height unexpected. Expected: ", output_h, ", Got: ", grad_output_.size(dim_h));
|
||||
"gradOutput height unexpected. Expected: ",
|
||||
output_h,
|
||||
", Got: ",
|
||||
grad_output_.size(dim_h));
|
||||
}
|
||||
output.resize_as_(input);
|
||||
if (output.numel() == 0 || grad_output_.numel() == 0)
|
||||
@ -153,11 +193,11 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
std::vector<NSNumber*> stridesVec(ndims, @(1));
|
||||
|
||||
for (int64_t pdim = 0; pdim < padding_size / 2; pdim++) {
|
||||
const int64_t leftIdx = pdim * 2;
|
||||
const int64_t leftIdx = pdim * 2;
|
||||
const int64_t rightIdx = pdim * 2 + 1;
|
||||
const int64_t padIdx = ndims - pdim - 1;
|
||||
|
||||
leftPadVec [padIdx] = @(padding[leftIdx]);
|
||||
leftPadVec[padIdx] = @(padding[leftIdx]);
|
||||
rightPadVec[padIdx] = @(padding[rightIdx]);
|
||||
// workaround for negative padding issue in backward pass
|
||||
if (is_backward_pass) {
|
||||
@ -171,7 +211,7 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
endsVec[padIdx] = @(input.size(padIdx) + padding[rightIdx]);
|
||||
endMask &= ~(1U << padIdx);
|
||||
}
|
||||
// workaround for the right padding bug in Monterey
|
||||
// workaround for the right padding bug in Monterey
|
||||
} else if (!is_macos_13_or_newer()) {
|
||||
if (padding[rightIdx] == 1 && padding[leftIdx] == 0) {
|
||||
rightPadVec[padIdx] = @(2);
|
||||
@ -180,8 +220,8 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
}
|
||||
}
|
||||
}
|
||||
MPSShape *leftPadding = [NSArray arrayWithObjects:leftPadVec.data() count:ndims];
|
||||
MPSShape *rightPadding = [NSArray arrayWithObjects:rightPadVec.data() count:ndims];
|
||||
MPSShape* leftPadding = [NSArray arrayWithObjects:leftPadVec.data() count:ndims];
|
||||
MPSShape* rightPadding = [NSArray arrayWithObjects:rightPadVec.data() count:ndims];
|
||||
|
||||
MPSDataType dataType = getMPSScalarType(input.scalar_type());
|
||||
// workaround for Bool type assert with Constant padding
|
||||
@ -190,20 +230,20 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
}
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) { }
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
|
||||
MPSGraphTensor *gradOutputTensor = nil;
|
||||
MPSGraphTensor* gradOutputTensor = nil;
|
||||
};
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" +
|
||||
getArrayRefString(padding) + "]:" + std::to_string(constantValue);
|
||||
string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" + getArrayRefString(padding) +
|
||||
"]:" + std::to_string(constantValue);
|
||||
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
if(!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
@ -211,7 +251,7 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
const bool needsSlice = startMask != dims_mask || endMask != dims_mask;
|
||||
|
||||
if (!is_backward_pass) {
|
||||
MPSGraphTensor *padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor
|
||||
MPSGraphTensor* padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor
|
||||
withPaddingMode:mode
|
||||
leftPadding:leftPadding
|
||||
rightPadding:rightPadding
|
||||
@ -219,36 +259,39 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
name:nil];
|
||||
// workaround for the right padding bug in Monterey
|
||||
if (needsSlice) {
|
||||
newCachedGraph->outputTensor = [mpsGraph sliceTensor:padTensor
|
||||
starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
|
||||
ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
|
||||
strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
|
||||
startMask:startMask
|
||||
endMask:endMask
|
||||
squeezeMask:0
|
||||
name:nil];
|
||||
newCachedGraph->outputTensor =
|
||||
[mpsGraph sliceTensor:padTensor
|
||||
starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
|
||||
ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
|
||||
strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
|
||||
startMask:startMask
|
||||
endMask:endMask
|
||||
squeezeMask:0
|
||||
name:nil];
|
||||
} else {
|
||||
newCachedGraph->outputTensor = padTensor;
|
||||
}
|
||||
} else {
|
||||
newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(grad_output));
|
||||
MPSGraphTensor *padGradTensor = [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor
|
||||
sourceTensor:newCachedGraph->inputTensor
|
||||
paddingMode:mode
|
||||
leftPadding:leftPadding
|
||||
rightPadding:rightPadding
|
||||
name:nil];
|
||||
MPSGraphTensor* padGradTensor =
|
||||
[mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor
|
||||
sourceTensor:newCachedGraph->inputTensor
|
||||
paddingMode:mode
|
||||
leftPadding:leftPadding
|
||||
rightPadding:rightPadding
|
||||
name:nil];
|
||||
// workaround for negative padding issue with padGradientWithIncomingGradientTensor()
|
||||
if (needsSlice) {
|
||||
newCachedGraph->outputTensor = [mpsGraph sliceGradientTensor:padGradTensor
|
||||
fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor name:nil]
|
||||
starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
|
||||
ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
|
||||
strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
|
||||
startMask:startMask
|
||||
endMask:endMask
|
||||
squeezeMask:0
|
||||
name:nil];
|
||||
newCachedGraph->outputTensor =
|
||||
[mpsGraph sliceGradientTensor:padGradTensor
|
||||
fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor name:nil]
|
||||
starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
|
||||
ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
|
||||
strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
|
||||
startMask:startMask
|
||||
endMask:endMask
|
||||
squeezeMask:0
|
||||
name:nil];
|
||||
} else {
|
||||
newCachedGraph->outputTensor = padGradTensor;
|
||||
}
|
||||
@ -257,19 +300,19 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
return newCachedGraph;
|
||||
});
|
||||
}
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input, nullptr, true, dataType);
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input, nullptr, true, dataType);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output, nullptr, true, dataType);
|
||||
Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder() :
|
||||
Placeholder(cachedGraph->gradOutputTensor, grad_output, nullptr, true, dataType);
|
||||
Placeholder gradOutputPlaceholder = !is_backward_pass
|
||||
? Placeholder()
|
||||
: Placeholder(cachedGraph->gradOutputTensor, grad_output, nullptr, true, dataType);
|
||||
|
||||
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
|
||||
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
|
||||
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
|
||||
if (is_backward_pass) {
|
||||
feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
|
||||
}
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
return output;
|
||||
@ -278,123 +321,156 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi

// 1D Reflection and Replication Padding
TORCH_IMPL_FUNC(reflection_pad1d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output)
{
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
MPSGraphPaddingModeReflect, 0.0, "reflection_pad1d_out_mps");
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
mps::pad_out_template(const_cast<Tensor&>(output),
input,
padding,
c10::nullopt,
MPSGraphPaddingModeReflect,
0.0,
"reflection_pad1d_out_mps");
}

TORCH_IMPL_FUNC(reflection_pad1d_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
{
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
grad_input.resize_as_(input).zero_();
mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
MPSGraphPaddingModeReflect, 0.0, "reflection_pad1d_backward_out_mps");
mps::pad_out_template(const_cast<Tensor&>(grad_input),
input,
padding,
grad_output,
MPSGraphPaddingModeReflect,
0.0,
"reflection_pad1d_backward_out_mps");
}

TORCH_IMPL_FUNC(replication_pad1d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output)
{
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad1d_out_mps");
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
mps::pad_out_template(const_cast<Tensor&>(output),
input,
padding,
c10::nullopt,
MPSGraphPaddingModeClampToEdge,
0.0,
"replication_pad1d_out_mps");
}

TORCH_IMPL_FUNC(replication_pad1d_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
{
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
grad_input.resize_as_(input).zero_();
mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad1d_backward_out_mps");
mps::pad_out_template(const_cast<Tensor&>(grad_input),
input,
padding,
grad_output,
MPSGraphPaddingModeClampToEdge,
0.0,
"replication_pad1d_backward_out_mps");
}

// 2D Reflection and Replication Padding
|
||||
Tensor& reflection_pad2d_out_mps(const Tensor& input, IntArrayRef padding, Tensor& output)
|
||||
{
|
||||
Tensor& reflection_pad2d_out_mps(const Tensor& input, IntArrayRef padding, Tensor& output) {
|
||||
return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__);
|
||||
}
|
||||
|
||||
Tensor reflection_pad2d_mps(const Tensor& input, IntArrayRef padding)
|
||||
{
|
||||
Tensor reflection_pad2d_mps(const Tensor& input, IntArrayRef padding) {
|
||||
Tensor output = at::empty({0}, input.options());
|
||||
return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__);
|
||||
}
|
||||
|
||||
Tensor& reflection_pad2d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
|
||||
{
|
||||
Tensor& reflection_pad2d_backward_out_mps(const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
IntArrayRef padding,
|
||||
Tensor& grad_input) {
|
||||
grad_input.resize_as_(input).zero_();
|
||||
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__);
|
||||
}
|
||||
|
||||
Tensor reflection_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
|
||||
{
|
||||
Tensor reflection_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
|
||||
auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__);
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(replication_pad2d_out_mps)
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output)
|
||||
{
|
||||
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
|
||||
MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad2d_out_mps");
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
|
||||
mps::pad_out_template(const_cast<Tensor&>(output),
|
||||
input,
|
||||
padding,
|
||||
c10::nullopt,
|
||||
MPSGraphPaddingModeClampToEdge,
|
||||
0.0,
|
||||
"replication_pad2d_out_mps");
|
||||
}
|
||||
|
||||
Tensor& replication_pad2d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
|
||||
{
|
||||
Tensor& replication_pad2d_backward_out_mps(const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
IntArrayRef padding,
|
||||
Tensor& grad_input) {
|
||||
grad_input.resize_as_(input).zero_();
|
||||
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
|
||||
}
|
||||
|
||||
Tensor replication_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
|
||||
{
|
||||
Tensor replication_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
|
||||
auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
|
||||
}
|
||||
|
||||
// 3D Reflection and Replication Padding
|
||||
TORCH_IMPL_FUNC(reflection_pad3d_out_mps)
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output)
|
||||
{
|
||||
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
|
||||
MPSGraphPaddingModeReflect, 0.0, "reflection_pad3d_out_mps");
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
|
||||
mps::pad_out_template(const_cast<Tensor&>(output),
|
||||
input,
|
||||
padding,
|
||||
c10::nullopt,
|
||||
MPSGraphPaddingModeReflect,
|
||||
0.0,
|
||||
"reflection_pad3d_out_mps");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(reflection_pad3d_backward_out_mps)
|
||||
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
|
||||
{
|
||||
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
|
||||
grad_input.resize_as_(input).zero_();
|
||||
mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
|
||||
MPSGraphPaddingModeReflect, 0.0, "reflection_pad3d_backward_out_mps");
|
||||
mps::pad_out_template(const_cast<Tensor&>(grad_input),
|
||||
input,
|
||||
padding,
|
||||
grad_output,
|
||||
MPSGraphPaddingModeReflect,
|
||||
0.0,
|
||||
"reflection_pad3d_backward_out_mps");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(replication_pad3d_out_mps)
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output)
|
||||
{
|
||||
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
|
||||
MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad3d_out_mps");
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
|
||||
mps::pad_out_template(const_cast<Tensor&>(output),
|
||||
input,
|
||||
padding,
|
||||
c10::nullopt,
|
||||
MPSGraphPaddingModeClampToEdge,
|
||||
0.0,
|
||||
"replication_pad3d_out_mps");
|
||||
}
|
||||
|
||||
Tensor& replication_pad3d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
|
||||
{
|
||||
Tensor& replication_pad3d_backward_out_mps(const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
IntArrayRef padding,
|
||||
Tensor& grad_input) {
|
||||
grad_input.resize_as_(input).zero_();
|
||||
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
|
||||
}
|
||||
|
||||
Tensor replication_pad3d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
|
||||
{
|
||||
Tensor replication_pad3d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
|
||||
auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
|
||||
}
|
||||
|
||||
// backward pass is explicitly handled in autograd by negating the "pad" argument
Tensor constant_pad_nd_mps(const Tensor& self, IntArrayRef pad, const Scalar& value)
{
Tensor constant_pad_nd_mps(const Tensor& self, IntArrayRef pad, const Scalar& value) {
if (pad.size() > 6) {
TORCH_WARN_ONCE("MPS: The constant padding of more than 3 dimensions is not currently supported natively. ",
"It uses View Ops default implementation to run. This may have performance implications.");
return at::native::constant_pad_nd(self, pad, value);
}
Tensor output = at::empty({0}, self.options());
return mps::pad_out_template(output, self, pad, c10::nullopt, MPSGraphPaddingModeConstant, value.toDouble(), __func__);
return mps::pad_out_template(
output, self, pad, c10::nullopt, MPSGraphPaddingModeConstant, value.toDouble(), __func__);
}

} // namespace at::native
@ -7,27 +7,25 @@ namespace at::native {
namespace mps {

void addc_mul_div_out_mps(const Tensor& self,
const Tensor& tensor1,
const Tensor& tensor2,
const Scalar& value_opt, // default value = 1.0
const Tensor& output,
const bool is_div,
const string op_name)
{
const Tensor& tensor1,
const Tensor& tensor2,
const Scalar& value_opt, // default value = 1.0
const Tensor& output,
const bool is_div,
const string op_name) {
if (value_opt.toDouble() == 0.0) {
output.copy_(self);
return;
}

if(output.numel() == 0) {
if (output.numel() == 0) {
return;
}

MPSStream* mpsStream = getCurrentMPSStream();

struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
MPSGraphTensor *firstTensor = nil, *secondTensor = nil, *valueTensor = nil;
};
@ -39,42 +37,43 @@ void addc_mul_div_out_mps(const Tensor& self,
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
ScalarType common_dtype =
|
||||
c10::promoteTypes(self.scalar_type(), c10::promoteTypes(tensor1.scalar_type(), tensor2.scalar_type()));
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), c10::promoteTypes(tensor1.scalar_type(), tensor2.scalar_type()));
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
newCachedGraph->firstTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor1);
|
||||
newCachedGraph->secondTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor2);
|
||||
newCachedGraph->valueTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), @[ @1 ]);
|
||||
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
newCachedGraph->firstTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor1);
|
||||
newCachedGraph->secondTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor2);
|
||||
newCachedGraph->valueTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), @[@1]);
|
||||
|
||||
// the tensor to be optionally multiplied by value_scalar
|
||||
MPSGraphTensor *multiplicandTensor = nil;
|
||||
auto firstTensor = castMPSTensor(mpsGraph, newCachedGraph->firstTensor, common_dtype);
|
||||
auto secondTensor = castMPSTensor(mpsGraph, newCachedGraph->secondTensor, common_dtype);
|
||||
if (is_div) {
|
||||
multiplicandTensor = [mpsGraph divisionWithPrimaryTensor:firstTensor
|
||||
secondaryTensor:secondTensor
|
||||
name:nil];
|
||||
} else {
|
||||
multiplicandTensor = [mpsGraph multiplicationWithPrimaryTensor:firstTensor
|
||||
secondaryTensor:secondTensor
|
||||
name:nil];
|
||||
}
|
||||
// the tensor to be added to input_tensor
|
||||
MPSGraphTensor *addendTensor = [mpsGraph multiplicationWithPrimaryTensor:multiplicandTensor
|
||||
secondaryTensor:castMPSTensor(mpsGraph, newCachedGraph->valueTensor, common_dtype)
|
||||
name:nil];
|
||||
auto outputTensor = [mpsGraph additionWithPrimaryTensor:castMPSTensor(mpsGraph, newCachedGraph->inputTensor, common_dtype)
|
||||
secondaryTensor:addendTensor
|
||||
name:nil];
|
||||
newCachedGraph->outputTensor = castMPSTensor(mpsGraph, outputTensor, output.scalar_type());
|
||||
// the tensor to be optionally multiplied by value_scalar
|
||||
MPSGraphTensor* multiplicandTensor = nil;
|
||||
auto firstTensor = castMPSTensor(mpsGraph, newCachedGraph->firstTensor, common_dtype);
|
||||
auto secondTensor = castMPSTensor(mpsGraph, newCachedGraph->secondTensor, common_dtype);
|
||||
if (is_div) {
|
||||
multiplicandTensor = [mpsGraph divisionWithPrimaryTensor:firstTensor secondaryTensor:secondTensor name:nil];
|
||||
} else {
|
||||
multiplicandTensor = [mpsGraph multiplicationWithPrimaryTensor:firstTensor
|
||||
secondaryTensor:secondTensor
|
||||
name:nil];
|
||||
}
|
||||
return newCachedGraph;
|
||||
// the tensor to be added to input_tensor
|
||||
MPSGraphTensor* addendTensor = [mpsGraph
|
||||
multiplicationWithPrimaryTensor:multiplicandTensor
|
||||
secondaryTensor:castMPSTensor(mpsGraph, newCachedGraph->valueTensor, common_dtype)
|
||||
name:nil];
|
||||
auto outputTensor =
|
||||
[mpsGraph additionWithPrimaryTensor:castMPSTensor(mpsGraph, newCachedGraph->inputTensor, common_dtype)
|
||||
secondaryTensor:addendTensor
|
||||
name:nil];
|
||||
newCachedGraph->outputTensor = castMPSTensor(mpsGraph, outputTensor, output.scalar_type());
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
}
|
||||
|
||||
@ -93,9 +92,8 @@ void addc_mul_div_out_mps(const Tensor& self,
|
||||
cachedGraph->valueTensor : getMPSGraphTensorFromScalar(mpsStream, value_scalar),
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -105,14 +103,12 @@ void addc_mul_div_out_mps(const Tensor& self,
|
||||
|
||||
// APIs exposed to at::native scope
|
||||
TORCH_IMPL_FUNC(addcmul_out_mps)
|
||||
(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output)
|
||||
{
|
||||
(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output) {
|
||||
mps::addc_mul_div_out_mps(self, tensor1, tensor2, value, output, false, "addcmul_out_mps");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(addcdiv_out_mps)
|
||||
(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output)
|
||||
{
|
||||
(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output) {
|
||||
mps::addc_mul_div_out_mps(self, tensor1, tensor2, value, output, true, "addcdiv_out_mps");
|
||||
}
|
||||
|
||||
|
@ -1,14 +1,13 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {
namespace mps {

struct PoolingCachedGraph : public MPSCachedGraph
{
PoolingCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct PoolingCachedGraph : public MPSCachedGraph {
PoolingCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor = nil;
MPSGraphTensor* outputTensor = nil;
MPSGraphTensor* indicesTensor = nil;
@ -17,24 +16,30 @@ struct PoolingCachedGraph : public MPSCachedGraph
|
||||
};
|
||||
|
||||
typedef MPSGraphTensor* (^PoolingOpBlock)(PoolingCachedGraph&, MPSGraphPooling2DOpDescriptor*);
|
||||
#define PoolingOpFn(graph, desc) MPSGraphTensor* (mps::PoolingCachedGraph& graph, MPSGraphPooling2DOpDescriptor* desc)
|
||||
#define PoolingOpFn(graph, desc) MPSGraphTensor*(mps::PoolingCachedGraph & graph, MPSGraphPooling2DOpDescriptor * desc)
|
||||
|
||||
// Pooling ops (1D/2D forward and backward Max and Average pooling)
|
||||
static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
static void pool2d_template(const Tensor& input,
|
||||
const Tensor& output,
|
||||
const c10::optional<Tensor>& indices_opt,
|
||||
const c10::optional<Tensor>& grad_output_opt,
|
||||
IntArrayRef kernel_size, IntArrayRef stride,
|
||||
IntArrayRef padding, IntArrayRef dilation,
|
||||
bool ceil_mode, bool count_include_pad,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode,
|
||||
bool count_include_pad,
|
||||
const c10::optional<int64_t> divisor_override,
|
||||
PoolingOpBlock poolingBlock, const c10::string& op_name)
|
||||
{
|
||||
PoolingOpBlock poolingBlock,
|
||||
const c10::string& op_name) {
|
||||
if (input.numel() == 0) {
|
||||
return;
|
||||
}
|
||||
if (!is_macos_13_or_newer()) {
|
||||
TORCH_CHECK(input.scalar_type() != ScalarType::Long,
|
||||
"MPS: ", op_name, " op with int64 input is supported natively starting from macOS 13.0.");
|
||||
"MPS: ",
|
||||
op_name,
|
||||
" op with int64 input is supported natively starting from macOS 13.0.");
|
||||
}
|
||||
const int64_t ndims = input.ndimension();
|
||||
const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt));
|
||||
@ -48,14 +53,18 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
// be incompatible with the PyTorch's global NCHW layout.
|
||||
const auto memory_format = has_indices ? MemoryFormat::Contiguous : suggested_memory_format;
|
||||
|
||||
TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2, op_name,
|
||||
": kernel_size must either be a single int, or a tuple of two ints")
|
||||
TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2, op_name,
|
||||
": stride must either be omitted, a single int, or a tuple of two ints")
|
||||
TORCH_CHECK(padding.size() == 1 || padding.size() == 2, op_name,
|
||||
": padding must be either be a single int, or a tuple of two ints");
|
||||
TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2, op_name,
|
||||
": dilation must be either a single int, or a tuple of two ints");
|
||||
TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
|
||||
op_name,
|
||||
": kernel_size must either be a single int, or a tuple of two ints")
|
||||
TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2,
|
||||
op_name,
|
||||
": stride must either be omitted, a single int, or a tuple of two ints")
|
||||
TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
|
||||
op_name,
|
||||
": padding must be either be a single int, or a tuple of two ints");
|
||||
TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
|
||||
op_name,
|
||||
": dilation must be either a single int, or a tuple of two ints");
|
||||
|
||||
if (suggested_memory_format == at::MemoryFormat::ChannelsLast) {
|
||||
TORCH_CHECK(ndims == 4, "non-empty 4D (batch mode) tensor expected for input with channels_last layout");
|
||||
@ -80,8 +89,21 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
|
||||
const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);
|
||||
|
||||
pool2d_shape_check(input, kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
|
||||
pool2d_shape_check(input,
|
||||
kH,
|
||||
kW,
|
||||
dH,
|
||||
dW,
|
||||
padH,
|
||||
padW,
|
||||
dilationH,
|
||||
dilationW,
|
||||
nInputPlane,
|
||||
inputHeight,
|
||||
inputWidth,
|
||||
outputHeight,
|
||||
outputWidth,
|
||||
memory_format);
|
||||
|
||||
auto output_memory_format = output.suggest_memory_format();
|
||||
// the output and indices are 'empty', so we could avoid unnecessary gatherView on empty tensors
|
||||
@ -90,7 +112,7 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
indices.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
|
||||
}
|
||||
if (output.numel() == 0) {
|
||||
std::vector<int64_t> outputSizes {nInputPlane, outputHeight, outputWidth};
|
||||
std::vector<int64_t> outputSizes{nInputPlane, outputHeight, outputWidth};
|
||||
if (ndims == 4) {
|
||||
outputSizes.insert(outputSizes.begin(), nbatch);
|
||||
}
|
||||
@ -111,56 +133,57 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + getTensorsStringKey({input, indices, grad_output}) + ":K[" +
|
||||
getArrayRefString(kernel_size) + "]:S[" + getArrayRefString(stride) + "]:P[" +
|
||||
getArrayRefString(padding) + "]:D[" + getArrayRefString(dilation) + "]" +
|
||||
(ceil_mode ? ":ceil" : "") + (count_include_pad ? ":include_pad" : "") +
|
||||
(has_divisor ? ":divisor" : "") + ":" +
|
||||
(suggested_memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
|
||||
string key = op_name + getTensorsStringKey({input, indices, grad_output}) + ":K[" + getArrayRefString(kernel_size) +
|
||||
"]:S[" + getArrayRefString(stride) + "]:P[" + getArrayRefString(padding) + "]:D[" +
|
||||
getArrayRefString(dilation) + "]" + (ceil_mode ? ":ceil" : "") + (count_include_pad ? ":include_pad" : "") +
|
||||
(has_divisor ? ":divisor" : "") + ":" +
|
||||
(suggested_memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
|
||||
|
||||
MPSShape* inputShape = getMPSShape(input, memory_format);
|
||||
MPSShape* gradOutputShape = is_backward_pass ? getMPSShape(grad_output, memory_format) : nullptr;
|
||||
PoolingCachedGraph* cachedGraph = cache_->LookUpAs<PoolingCachedGraph>(key);
|
||||
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<PoolingCachedGraph>(key, ^ MPSCachedGraph * () {
|
||||
PoolingCachedGraph *newCachedGraph = nil;
|
||||
cachedGraph = cache_->CreateCachedGraphAs<PoolingCachedGraph>(key, ^MPSCachedGraph*() {
|
||||
PoolingCachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new PoolingCachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphPooling2DOpDescriptor* desc = [MPSGraphPooling2DOpDescriptor
|
||||
descriptorWithKernelWidth: kW
|
||||
kernelHeight: kH
|
||||
strideInX: dW
|
||||
strideInY: dH
|
||||
dilationRateInX: dilationW
|
||||
dilationRateInY: dilationH
|
||||
paddingLeft: padW
|
||||
paddingRight: ceil_mode ? padW * dW : padW
|
||||
paddingTop: padH
|
||||
paddingBottom: ceil_mode ? padH * dH : padH
|
||||
paddingStyle: MPSGraphPaddingStyleExplicit
|
||||
dataLayout: memory_format == MemoryFormat::ChannelsLast ?
|
||||
MPSGraphTensorNamedDataLayoutNHWC :
|
||||
MPSGraphTensorNamedDataLayoutNCHW];
|
||||
MPSGraphPooling2DOpDescriptor* desc =
|
||||
[MPSGraphPooling2DOpDescriptor descriptorWithKernelWidth:kW
|
||||
kernelHeight:kH
|
||||
strideInX:dW
|
||||
strideInY:dH
|
||||
dilationRateInX:dilationW
|
||||
dilationRateInY:dilationH
|
||||
paddingLeft:padW
|
||||
paddingRight:ceil_mode ? padW * dW : padW
|
||||
paddingTop:padH
|
||||
paddingBottom:ceil_mode ? padH * dH : padH
|
||||
paddingStyle:MPSGraphPaddingStyleExplicit
|
||||
dataLayout:memory_format == MemoryFormat::ChannelsLast
|
||||
? MPSGraphTensorNamedDataLayoutNHWC
|
||||
: MPSGraphTensorNamedDataLayoutNCHW];
|
||||
desc.ceilMode = (padW == 0 && padH == 0) ? ceil_mode : false;
|
||||
if (has_indices) {
|
||||
desc.returnIndicesMode = MPSGraphPoolingReturnIndicesGlobalFlatten2D;
|
||||
desc.returnIndicesDataType = MPSDataTypeInt32;
|
||||
}
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input.scalar_type()), inputShape);
|
||||
newCachedGraph->inputTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input.scalar_type()), inputShape);
|
||||
if (is_backward_pass) {
|
||||
newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output.scalar_type()), gradOutputShape);
|
||||
newCachedGraph->gradOutputTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output.scalar_type()), gradOutputShape);
|
||||
}
|
||||
if (has_divisor) {
|
||||
newCachedGraph->divisorTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeFloat32, @[@1]);
|
||||
newCachedGraph->divisorTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeFloat32, @[ @1 ]);
|
||||
}
|
||||
MPSGraphTensor* outputTensor = poolingBlock(*newCachedGraph, desc);
|
||||
// with desc.dataLayout = NHWC (i.e., ChannelsLast), the results need to be converted back to NCHW
|
||||
newCachedGraph->outputTensor = memory_format == MemoryFormat::ChannelsLast ?
|
||||
convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor;
|
||||
newCachedGraph->outputTensor =
|
||||
memory_format == MemoryFormat::ChannelsLast ? convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor;
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
@ -168,14 +191,16 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
// in case of ChannelsLast we don't perform gather() in placeholder to avoid implicit conversion to NCHW
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input, inputShape, memory_format != MemoryFormat::ChannelsLast);
|
||||
Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder() :
|
||||
Placeholder(cachedGraph->gradOutputTensor, grad_output,
|
||||
gradOutputShape, memory_format != MemoryFormat::ChannelsLast);
|
||||
Placeholder inputPlaceholder =
|
||||
Placeholder(cachedGraph->inputTensor, input, inputShape, memory_format != MemoryFormat::ChannelsLast);
|
||||
Placeholder gradOutputPlaceholder = !is_backward_pass
|
||||
? Placeholder()
|
||||
: Placeholder(
|
||||
cachedGraph->gradOutputTensor, grad_output, gradOutputShape, memory_format != MemoryFormat::ChannelsLast);
|
||||
Placeholder indicesPlaceholder = has_indices ? Placeholder(cachedGraph->indicesTensor, indices) : Placeholder();
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);
|
||||
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
|
||||
NSMutableDictionary *results = [[NSMutableDictionary new] autorelease];
|
||||
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
|
||||
NSMutableDictionary* results = [[NSMutableDictionary new] autorelease];
|
||||
|
||||
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
|
||||
results[outputPlaceholder.getMPSGraphTensor()] = outputPlaceholder.getMPSGraphTensorData();
|
||||
@ -192,7 +217,7 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
}
|
||||
MPSScalar divisor_scalar;
|
||||
if (cachedGraph->divisorTensor) {
|
||||
const float divisor = float(kH * kW) / (float) divisor_override.value();
|
||||
const float divisor = float(kH * kW) / (float)divisor_override.value();
|
||||
divisor_scalar = getMPSScalar(divisor, ScalarType::Float);
|
||||
feeds[cachedGraph->divisorTensor] = getMPSGraphTensorFromScalar(mpsStream, divisor_scalar);
|
||||
}
|
||||
@ -205,14 +230,17 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
}
|
||||
}
|
||||
|
||||
static void avg_pool2d_template(const Tensor& input, const Tensor& output,
|
||||
static void avg_pool2d_template(const Tensor& input,
|
||||
const Tensor& output,
|
||||
const c10::optional<Tensor>& grad_output_opt,
|
||||
IntArrayRef kernel_size, IntArrayRef stride,
|
||||
IntArrayRef padding, IntArrayRef dilation,
|
||||
bool ceil_mode, bool count_include_pad,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode,
|
||||
bool count_include_pad,
|
||||
const c10::optional<int64_t> divisor_override,
|
||||
const c10::string& op_name)
|
||||
{
|
||||
const c10::string& op_name) {
|
||||
const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt));
|
||||
const bool is_backward_pass = grad_output.defined();
|
||||
const bool use_divisor = divisor_override.has_value() && divisor_override.value() != 0;
|
||||
@ -226,12 +254,21 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
|
||||
"not supported on MPS backend. ",
|
||||
"Falling back on CPU. This may have performance implications.");
|
||||
if (!is_backward_pass) {
|
||||
const_cast<Tensor&>(output) = at::avg_pool2d(input.to("cpu"), kernel_size, stride, padding, ceil_mode,
|
||||
count_include_pad, divisor_override).clone().to("mps");
|
||||
const_cast<Tensor&>(output) =
|
||||
at::avg_pool2d(input.to("cpu"), kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
|
||||
.clone()
|
||||
.to("mps");
|
||||
} else {
|
||||
const_cast<Tensor&>(output) = at::avg_pool2d_backward(grad_output.to("cpu"), input.to("cpu"),
|
||||
kernel_size, stride, padding, ceil_mode, count_include_pad,
|
||||
divisor_override).clone().to("mps");
|
||||
const_cast<Tensor&>(output) = at::avg_pool2d_backward(grad_output.to("cpu"),
|
||||
input.to("cpu"),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
divisor_override)
|
||||
.clone()
|
||||
.to("mps");
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -239,7 +276,7 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
|
||||
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
|
||||
MPSGraph* mpsGraph = cachedGraph.graph();
|
||||
const int64_t ndims = input.ndimension();
|
||||
MPSShape *paddingShape = nil;
|
||||
MPSShape* paddingShape = nil;
|
||||
MPSGraphTensor* paddedTensor = cachedGraph.inputTensor;
|
||||
|
||||
// workaround for issue #103039644: mismatching MPS vs. CPU results
|
||||
@ -249,14 +286,14 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
|
||||
std::vector<NSNumber*> padVec(ndims, @(0));
|
||||
padVec[ndims - 1] = @(padding.size() == 1 ? padding[0] : padding[1]);
|
||||
padVec[ndims - 2] = @(ndims > 3 ? padding[0] : 0);
|
||||
paddingShape = [NSArray arrayWithObjects: padVec.data() count:ndims];
|
||||
paddedTensor = [mpsGraph padTensor: cachedGraph.inputTensor
|
||||
withPaddingMode: MPSGraphPaddingModeZero
|
||||
leftPadding: paddingShape
|
||||
rightPadding: paddingShape
|
||||
constantValue: 0.0
|
||||
name: nil];
|
||||
paddedTensor = [mpsGraph identityWithTensor: paddedTensor name: nil];
|
||||
paddingShape = [NSArray arrayWithObjects:padVec.data() count:ndims];
|
||||
paddedTensor = [mpsGraph padTensor:cachedGraph.inputTensor
|
||||
withPaddingMode:MPSGraphPaddingModeZero
|
||||
leftPadding:paddingShape
|
||||
rightPadding:paddingShape
|
||||
constantValue:0.0
|
||||
name:nil];
|
||||
paddedTensor = [mpsGraph identityWithTensor:paddedTensor name:nil];
|
||||
} else {
|
||||
desc.includeZeroPadToAverage = count_include_pad;
|
||||
}
|
||||
@ -265,35 +302,33 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
|
||||
}
|
||||
|
||||
if (!is_backward_pass) {
|
||||
MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DWithSourceTensor: paddedTensor
|
||||
descriptor: desc
|
||||
name: nil];
|
||||
MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DWithSourceTensor:paddedTensor descriptor:desc name:nil];
|
||||
if (cachedGraph.divisorTensor) {
|
||||
// workaround: custom divisor isn't supported by MPS backend, so we scale manually
|
||||
return [mpsGraph multiplicationWithPrimaryTensor: avgPoolTensor
|
||||
secondaryTensor: cachedGraph.divisorTensor
|
||||
name: nil];
|
||||
return [mpsGraph multiplicationWithPrimaryTensor:avgPoolTensor
|
||||
secondaryTensor:cachedGraph.divisorTensor
|
||||
name:nil];
|
||||
} else {
|
||||
return avgPoolTensor;
|
||||
}
|
||||
} else { // backward pass
|
||||
MPSGraphTensor* scaledGradTensor = cachedGraph.gradOutputTensor;
|
||||
if (cachedGraph.divisorTensor) {
|
||||
scaledGradTensor = [mpsGraph multiplicationWithPrimaryTensor: cachedGraph.gradOutputTensor
|
||||
secondaryTensor: cachedGraph.divisorTensor
|
||||
name: nil];
|
||||
scaledGradTensor = [mpsGraph multiplicationWithPrimaryTensor:cachedGraph.gradOutputTensor
|
||||
secondaryTensor:cachedGraph.divisorTensor
|
||||
name:nil];
|
||||
}
|
||||
MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DGradientWithGradientTensor: scaledGradTensor
|
||||
sourceTensor: paddedTensor
|
||||
descriptor: desc
|
||||
name: nil];
|
||||
MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DGradientWithGradientTensor:scaledGradTensor
|
||||
sourceTensor:paddedTensor
|
||||
descriptor:desc
|
||||
name:nil];
|
||||
if (explicit_padding) {
|
||||
return [mpsGraph padGradientWithIncomingGradientTensor: avgPoolTensor
|
||||
sourceTensor: cachedGraph.inputTensor
|
||||
paddingMode: MPSGraphPaddingModeZero
|
||||
leftPadding: paddingShape
|
||||
rightPadding: paddingShape
|
||||
name: nil];
|
||||
return [mpsGraph padGradientWithIncomingGradientTensor:avgPoolTensor
|
||||
sourceTensor:cachedGraph.inputTensor
|
||||
paddingMode:MPSGraphPaddingModeZero
|
||||
leftPadding:paddingShape
|
||||
rightPadding:paddingShape
|
||||
name:nil];
|
||||
|
||||
} else {
|
||||
return avgPoolTensor;
|
||||
@ -301,137 +336,199 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
|
||||
}
|
||||
};
|
||||
|
||||
pool2d_template(input, output, c10::nullopt, grad_output_opt, kernel_size, stride,
|
||||
padding, {1, 1}, ceil_mode, count_include_pad, divisor_override,
|
||||
pooling_op_block, op_name);
|
||||
pool2d_template(input,
|
||||
output,
|
||||
c10::nullopt,
|
||||
grad_output_opt,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
{1, 1},
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
divisor_override,
|
||||
pooling_op_block,
|
||||
op_name);
|
||||
}
|
||||
|
||||
} // namespace mps
|
||||
|
||||
Tensor mps_max_pool2d(
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode) {
|
||||
|
||||
Tensor mps_max_pool2d(const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode) {
|
||||
Tensor output = at::empty({0}, input.options(), MemoryFormat::Contiguous);
|
||||
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
|
||||
MPSGraph* mpsGraph = cachedGraph.graph();
|
||||
return [mpsGraph maxPooling2DWithSourceTensor: cachedGraph.inputTensor
|
||||
descriptor: desc
|
||||
name: nil];
|
||||
return [mpsGraph maxPooling2DWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil];
|
||||
};
|
||||
mps::pool2d_template(input, output, c10::nullopt, c10::nullopt, kernel_size, stride,
|
||||
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d");
|
||||
mps::pool2d_template(input,
|
||||
output,
|
||||
c10::nullopt,
|
||||
c10::nullopt,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
false,
|
||||
c10::nullopt,
|
||||
pooling_op_block,
|
||||
"max_pool2d");
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor mps_max_pool2d_backward(
|
||||
const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode) {
|
||||
|
||||
Tensor mps_max_pool2d_backward(const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode) {
|
||||
Tensor grad_input = at::empty(input.sizes(), input.options(), MemoryFormat::Contiguous);
|
||||
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
|
||||
MPSGraph* mpsGraph = cachedGraph.graph();
|
||||
return [mpsGraph maxPooling2DGradientWithGradientTensor: cachedGraph.gradOutputTensor
|
||||
sourceTensor: cachedGraph.inputTensor
|
||||
descriptor: desc
|
||||
name: nil];
|
||||
return [mpsGraph maxPooling2DGradientWithGradientTensor:cachedGraph.gradOutputTensor
|
||||
sourceTensor:cachedGraph.inputTensor
|
||||
descriptor:desc
|
||||
name:nil];
|
||||
};
|
||||
mps::pool2d_template(input, grad_input, c10::nullopt, grad_output, kernel_size, stride,
|
||||
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_backward");
|
||||
mps::pool2d_template(input,
|
||||
grad_input,
|
||||
c10::nullopt,
|
||||
grad_output,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
false,
|
||||
c10::nullopt,
|
||||
pooling_op_block,
|
||||
"max_pool2d_backward");
|
||||
|
||||
return grad_input;
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)(
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode,
|
||||
const Tensor& output,
|
||||
const Tensor& indices) {
|
||||
|
||||
TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)
|
||||
(const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode,
|
||||
const Tensor& output,
|
||||
const Tensor& indices) {
|
||||
auto indices_memory_format = indices.suggest_memory_format();
|
||||
|
||||
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
|
||||
MPSGraph* mpsGraph = cachedGraph.graph();
|
||||
NSArray<MPSGraphTensor*>* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor: cachedGraph.inputTensor
|
||||
descriptor: desc
|
||||
name: nil];
|
||||
NSArray<MPSGraphTensor*>* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor:cachedGraph.inputTensor
|
||||
descriptor:desc
|
||||
name:nil];
|
||||
cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long);
|
||||
return poolOutputs[0];
|
||||
};
|
||||
mps::pool2d_template(input, output, indices, c10::nullopt, kernel_size, stride,
|
||||
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_indices");
|
||||
mps::pool2d_template(input,
|
||||
output,
|
||||
indices,
|
||||
c10::nullopt,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
false,
|
||||
c10::nullopt,
|
||||
pooling_op_block,
|
||||
"max_pool2d_indices");
|
||||
|
||||
if (indices_memory_format == MemoryFormat::ChannelsLast) {
|
||||
const_cast<Tensor&>(indices) = indices.to(MemoryFormat::ChannelsLast);
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)(
|
||||
const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode,
|
||||
const Tensor& indices,
|
||||
const Tensor& grad_input) {
|
||||
|
||||
TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)
|
||||
(const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode,
|
||||
const Tensor& indices,
|
||||
const Tensor& grad_input) {
|
||||
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
|
||||
MPSGraph* mpsGraph = cachedGraph.graph();
|
||||
return [mpsGraph maxPooling2DGradientWithGradientTensor: cachedGraph.gradOutputTensor
|
||||
sourceTensor: cachedGraph.inputTensor
|
||||
descriptor: desc
|
||||
name: nil];
|
||||
return [mpsGraph maxPooling2DGradientWithGradientTensor:cachedGraph.gradOutputTensor
|
||||
sourceTensor:cachedGraph.inputTensor
|
||||
descriptor:desc
|
||||
name:nil];
|
||||
};
|
||||
mps::pool2d_template(input, grad_input, indices, grad_output, kernel_size, stride,
|
||||
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_indices_backward");
|
||||
mps::pool2d_template(input,
|
||||
grad_input,
|
||||
indices,
|
||||
grad_output,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
false,
|
||||
c10::nullopt,
|
||||
pooling_op_block,
|
||||
"max_pool2d_indices_backward");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(avg_pool2d_out_mps) (
|
||||
const Tensor& input,
|
||||
int64_t kH,
|
||||
int64_t kW,
|
||||
int64_t dH,
|
||||
int64_t dW,
|
||||
int64_t padH,
|
||||
int64_t padW,
|
||||
bool ceil_mode,
|
||||
bool count_include_pad,
|
||||
c10::optional<int64_t> divisor_override,
|
||||
const Tensor& output) {
|
||||
|
||||
mps::avg_pool2d_template(input, output, c10::nullopt, {kH, kW}, {dH, dW}, {padH, padW},
|
||||
{1, 1}, ceil_mode, count_include_pad, divisor_override, "avg_pool2d");
|
||||
TORCH_IMPL_FUNC(avg_pool2d_out_mps)
|
||||
(const Tensor& input,
|
||||
int64_t kH,
|
||||
int64_t kW,
|
||||
int64_t dH,
|
||||
int64_t dW,
|
||||
int64_t padH,
|
||||
int64_t padW,
|
||||
bool ceil_mode,
|
||||
bool count_include_pad,
|
||||
c10::optional<int64_t> divisor_override,
|
||||
const Tensor& output) {
|
||||
mps::avg_pool2d_template(input,
|
||||
output,
|
||||
c10::nullopt,
|
||||
{kH, kW},
|
||||
{dH, dW},
|
||||
{padH, padW},
|
||||
{1, 1},
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
divisor_override,
|
||||
"avg_pool2d");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps) (
|
||||
const Tensor& gradOutput,
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
bool ceil_mode,
|
||||
bool count_include_pad,
|
||||
c10::optional<int64_t> divisor_override,
|
||||
const Tensor& gradInput) {
|
||||
|
||||
mps::avg_pool2d_template(input, gradInput, gradOutput, kernel_size, stride, padding,
|
||||
{1, 1}, ceil_mode, count_include_pad, divisor_override, "avg_pool2d_backward");
|
||||
TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps)
|
||||
(const Tensor& gradOutput,
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
bool ceil_mode,
|
||||
bool count_include_pad,
|
||||
c10::optional<int64_t> divisor_override,
|
||||
const Tensor& gradInput) {
|
||||
mps::avg_pool2d_template(input,
|
||||
gradInput,
|
||||
gradOutput,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
{1, 1},
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
divisor_override,
|
||||
"avg_pool2d_backward");
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -1,9 +1,9 @@
// Copyright © 2022 Apple Inc.

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/AccumulateType.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
@ -15,37 +15,38 @@ namespace at::native {
namespace {
struct RangeCachedGraph : public mps::MPSCachedGraph {
|
||||
API_AVAILABLE(macosx(12.3))
|
||||
RangeCachedGraph(MPSGraph *mpsGraph, MPSDataType dataType, int32_t shapeVal, bool needsClamp = false, bool startLessEnd = false): MPSCachedGraph(mpsGraph) {
|
||||
RangeCachedGraph(MPSGraph* mpsGraph,
|
||||
MPSDataType dataType,
|
||||
int32_t shapeVal,
|
||||
bool needsClamp = false,
|
||||
bool startLessEnd = false)
|
||||
: MPSCachedGraph(mpsGraph) {
|
||||
@autoreleasepool {
|
||||
auto shapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:&shapeVal length:sizeof(int32_t)]
|
||||
shape: @[@1]
|
||||
shape:@[ @1 ]
|
||||
dataType:MPSDataTypeInt32];
|
||||
auto coordsTensor = [mpsGraph coordinateAlongAxis:0
|
||||
withShapeTensor:shapeTensor
|
||||
name:nil];
|
||||
auto coordsTensor = [mpsGraph coordinateAlongAxis:0 withShapeTensor:shapeTensor name:nil];
|
||||
coordsTensor = [mpsGraph castTensor:coordsTensor toType:dataType name:@"coords"];
|
||||
|
||||
startTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[@1]);
|
||||
multiplyTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[@1]);
|
||||
startTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[ @1 ]);
|
||||
multiplyTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[ @1 ]);
|
||||
auto scaledCoords = [mpsGraph multiplicationWithPrimaryTensor:coordsTensor
|
||||
secondaryTensor:multiplyTensor
|
||||
name:nil];
|
||||
outputTensor = [mpsGraph additionWithPrimaryTensor:scaledCoords
|
||||
secondaryTensor:startTensor
|
||||
name:nil];
|
||||
outputTensor = [mpsGraph additionWithPrimaryTensor:scaledCoords secondaryTensor:startTensor name:nil];
|
||||
if (needsClamp) {
|
||||
endTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[@1]);
|
||||
endTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[ @1 ]);
|
||||
outputTensor = [mpsGraph clampWithTensor:outputTensor
|
||||
minValueTensor: startLessEnd? startTensor : endTensor
|
||||
maxValueTensor: startLessEnd? endTensor : startTensor
|
||||
name: nil];
|
||||
minValueTensor:startLessEnd ? startTensor : endTensor
|
||||
maxValueTensor:startLessEnd ? endTensor : startTensor
|
||||
name:nil];
|
||||
}
|
||||
}
|
||||
}
|
||||
MPSGraphTensor *startTensor = nil;
|
||||
MPSGraphTensor *endTensor = nil;
|
||||
MPSGraphTensor *multiplyTensor = nil;
|
||||
MPSGraphTensor *outputTensor = nil;
|
||||
MPSGraphTensor* startTensor = nil;
|
||||
MPSGraphTensor* endTensor = nil;
|
||||
MPSGraphTensor* multiplyTensor = nil;
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
@ -59,31 +60,37 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste
|
||||
|
||||
double size_d;
|
||||
if (std::is_same<scalar_t, int64_t>::value) {
|
||||
size_d = std::ceil(static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>())
|
||||
/ step.to<accscalar_t>());
|
||||
size_d = std::ceil(static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>()) / step.to<accscalar_t>());
|
||||
} else {
|
||||
size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>())
|
||||
/ step.to<double>());
|
||||
size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>()) / step.to<double>());
|
||||
}
|
||||
|
||||
TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
|
||||
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
|
||||
std::isfinite(static_cast<double>(xend)),
|
||||
"unsupported range: ", xstart, " -> ", xend);
|
||||
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) && std::isfinite(static_cast<double>(xend)),
|
||||
"unsupported range: ",
|
||||
xstart,
|
||||
" -> ",
|
||||
xend);
|
||||
TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
|
||||
"upper bound and larger bound inconsistent with step sign");
|
||||
"upper bound and larger bound inconsistent with step sign");
|
||||
|
||||
TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
|
||||
"invalid size, possible overflow?");
|
||||
"invalid size, possible overflow?");
|
||||
int64_t size = static_cast<int64_t>(size_d);
|
||||
int64_t numel = result.numel();
|
||||
|
||||
if (numel != size) {
|
||||
if(numel > 0){
|
||||
TORCH_WARN("The number of elements in the out tensor of shape ", result.sizes(),
|
||||
" is ", numel, " which does not match the computed number of elements ", size,
|
||||
". Note that this may occur as a result of rounding error. "
|
||||
"The out tensor will be resized to a tensor of shape (", size, ",).");
|
||||
if (numel > 0) {
|
||||
TORCH_WARN("The number of elements in the out tensor of shape ",
|
||||
result.sizes(),
|
||||
" is ",
|
||||
numel,
|
||||
" which does not match the computed number of elements ",
|
||||
size,
|
||||
". Note that this may occur as a result of rounding error. "
|
||||
"The out tensor will be resized to a tensor of shape (",
|
||||
size,
|
||||
",).");
|
||||
}
|
||||
result.resize_({size});
|
||||
}
|
||||
@ -100,28 +107,27 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste
|
||||
auto mpsDataType = getMPSDataType(result);
|
||||
@autoreleasepool {
|
||||
string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size);
|
||||
auto cachedGraph = static_cast<RangeCachedGraph *>(cache_->LookUp(key));
|
||||
auto cachedGraph = static_cast<RangeCachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
auto *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph *() {
|
||||
auto* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
auto mpsGraph = make_mps_graph();
|
||||
return new RangeCachedGraph(mpsGraph, mpsDataType, size);
|
||||
});
|
||||
cachedGraph = static_cast<RangeCachedGraph *>(tmpCachedGraph);
|
||||
cachedGraph = static_cast<RangeCachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, r);
|
||||
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
|
||||
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
|
||||
MPSScalar startScalar = getMPSScalar(start, result.scalar_type());
|
||||
feeds[cachedGraph->startTensor] = getMPSGraphTensorFromScalar(stream, startScalar);
|
||||
MPSScalar stepScalar = getMPSScalar(step, result.scalar_type());
|
||||
feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, stepScalar);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
|
||||
if(!is_contiguous) {
|
||||
if (!is_contiguous) {
|
||||
result.copy_(r);
|
||||
}
|
||||
});
|
||||
@ -139,22 +145,22 @@ Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step
|
||||
// double size_d = ((xend - xstart) / xstep) + 1;
|
||||
double size_d;
|
||||
if (std::is_same<scalar_t, int64_t>::value) {
|
||||
size_d = static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>())
|
||||
/ step.to<accscalar_t>() + 1;
|
||||
size_d = static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>()) / step.to<accscalar_t>() + 1;
|
||||
} else {
|
||||
size_d = static_cast<double>(end.to<double>() - start.to<double>())
|
||||
/ step.to<double>() + 1;
|
||||
size_d = static_cast<double>(end.to<double>() - start.to<double>()) / step.to<double>() + 1;
|
||||
}
|
||||
|
||||
TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
|
||||
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
|
||||
std::isfinite(static_cast<double>(xend)),
|
||||
"unsupported range: ", xstart, " -> ", xend);
|
||||
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) && std::isfinite(static_cast<double>(xend)),
|
||||
"unsupported range: ",
|
||||
xstart,
|
||||
" -> ",
|
||||
xend);
|
||||
TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
|
||||
"upper bound and larger bound inconsistent with step sign");
|
||||
"upper bound and larger bound inconsistent with step sign");
|
||||
|
||||
TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
|
||||
"invalid size, possible overflow?");
|
||||
"invalid size, possible overflow?");
|
||||
|
||||
int64_t size = static_cast<int64_t>(size_d);
|
||||
|
||||
@ -171,28 +177,27 @@ Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step
|
||||
auto mpsDataType = getMPSDataType(result);
|
||||
@autoreleasepool {
|
||||
string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size);
|
||||
auto cachedGraph = static_cast<RangeCachedGraph *>(cache_->LookUp(key));
|
||||
auto cachedGraph = static_cast<RangeCachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
auto *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph *() {
|
||||
auto* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
auto mpsGraph = make_mps_graph();
|
||||
return new RangeCachedGraph(mpsGraph, mpsDataType, size);
|
||||
});
|
||||
cachedGraph = static_cast<RangeCachedGraph *>(tmpCachedGraph);
|
||||
cachedGraph = static_cast<RangeCachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, r);
|
||||
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
|
||||
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
|
||||
MPSScalar startScalar = getMPSScalar(start, result.scalar_type());
|
||||
feeds[cachedGraph->startTensor] = getMPSGraphTensorFromScalar(stream, startScalar);
|
||||
MPSScalar stepScalar = getMPSScalar(step, result.scalar_type());
|
||||
feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, stepScalar);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
|
||||
if(!is_contiguous) {
|
||||
if (!is_contiguous) {
|
||||
result.copy_(r);
|
||||
}
|
||||
});
|
||||
@@ -222,28 +227,30 @@ Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps,
bool start_less_end = (start.to<double>() <= end.to<double>());

@autoreleasepool {
string key = "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps) + to_string(start_less_end);
RangeCachedGraph* cachedGraph = static_cast<RangeCachedGraph *>(cache_->LookUp(key));
string key =
"linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps) + to_string(start_less_end);
RangeCachedGraph* cachedGraph = static_cast<RangeCachedGraph*>(cache_->LookUp(key));

if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {

RangeCachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
RangeCachedGraph* newCachedGraph = nil;

@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new RangeCachedGraph(mpsGraph, MPSDataTypeFloat32, steps, true, start_less_end);

if(getMPSDataType(result) != MPSDataTypeFloat32) {
newCachedGraph->outputTensor = [mpsGraph castTensor:newCachedGraph->outputTensor toType:getMPSDataType(result) name:@"output"];
if (getMPSDataType(result) != MPSDataTypeFloat32) {
newCachedGraph->outputTensor = [mpsGraph castTensor:newCachedGraph->outputTensor
toType:getMPSDataType(result)
name:@"output"];
}
}
return newCachedGraph;
});
cachedGraph = static_cast<RangeCachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<RangeCachedGraph*>(tmpCachedGraph);
}

NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
auto multiply = (end.to<double>() - start.to<double>()) / ((double)steps - 1.0f);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, r);

@@ -255,9 +262,8 @@ Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps,
MPSScalar multiplyScalar = getMPSScalar(multiply, ScalarType::Float);
feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, multiplyScalar);

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
File diff suppressed because it is too large
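Every hunk in the two factory kernels above reflows the same look-up-or-build flow around MPSGraphCache. As a minimal illustrative sketch (not part of this commit), with a hypothetical ExampleCachedGraph subclass and a caller-supplied key, that flow is roughly:

```
// Minimal sketch of the cache flow shown above; ExampleCachedGraph and the commented-out
// buildGraph step are hypothetical, the MPSGraphCache/MPSCachedGraph calls are the ones in the diff.
struct ExampleCachedGraph : public MPSCachedGraph {
  ExampleCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* outputTensor_ = nil;
};

static ExampleCachedGraph* lookUpOrBuild(const std::string& key) {
  MPSGraphCache* cache = MPSGraphCache::getInstance();
  auto* cachedGraph = static_cast<ExampleCachedGraph*>(cache->LookUp(key));
  if (!cachedGraph) {
    auto* tmp = cache->CreateCachedGraph(key, ^MPSCachedGraph*() {
      auto* newCachedGraph = new ExampleCachedGraph(make_mps_graph());
      // populate newCachedGraph's MPSGraph tensors here
      return newCachedGraph;
    });
    cachedGraph = static_cast<ExampleCachedGraph*>(tmp);
  }
  return cachedGraph;
}
```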
@@ -8,8 +8,8 @@
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Repeat.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>
#include <fmt/format.h>
#include <torch/library.h>

#ifdef __OBJC__
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
@@ -19,8 +19,7 @@ namespace at::native {

Tensor permute_mps(const Tensor& self, IntArrayRef dims) {
auto nDims = self.dim();
TORCH_CHECK(dims.size() == (size_t)nDims,
"number of dims don't match in permute");
TORCH_CHECK(dims.size() == (size_t)nDims, "number of dims don't match in permute");
auto oldSizes = self.sizes();
auto oldStrides = self.strides();
DimVector newSizes(nDims);
@@ -28,8 +27,7 @@ Tensor permute_mps(const Tensor& self, IntArrayRef dims) {
std::vector<bool> seen(nDims);
for (const auto i : c10::irange(nDims)) {
auto dim = maybe_wrap_dim(dims[i], nDims);
TORCH_CHECK(!seen[dim],
"repeated dim in permute");
TORCH_CHECK(!seen[dim], "repeated dim in permute");
seen[dim] = true;
newSizes[i] = oldSizes[dim];
newStrides[i] = oldStrides[dim];
@@ -38,16 +36,14 @@ Tensor permute_mps(const Tensor& self, IntArrayRef dims) {
}
Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {

using namespace mps;

TORCH_CHECK(repeats.size() >= (size_t)self.dim(),
"Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
"Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};

// Add new leading dimensions to the tensor if the
@@ -58,7 +54,7 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
padded_size.insert(padded_size.end(), self.sizes().begin(), self.sizes().end());
DimVector target_size(repeats.size());
bool zero_tensor = false;
for(const auto idx : c10::irange(repeats.size())) {
for (const auto idx : c10::irange(repeats.size())) {
if (repeats[idx] == 0) {
zero_tensor = true;
}
@@ -68,7 +64,7 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
Tensor expanded_tensor = self.expand(padded_size);
Tensor result = at::empty(target_size, self.options());
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
if(zero_tensor || result.numel() == 0) {
if (zero_tensor || result.numel() == 0) {
return result;
}
@@ -76,50 +72,47 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
auto inputDataType = getMPSDataType(expanded_tensor);
auto outputDataType = getMPSDataType(result);
if (!is_macos_13_or_newer()) {
if (expanded_tensor.scalar_type() == kBool) {
if (expanded_tensor.scalar_type() == kBool) {
inputDataType = MPSDataTypeInt8;
}
if (result.scalar_type() == kBool) {
}
if (result.scalar_type() == kBool) {
outputDataType = MPSDataTypeInt8;
}
}
}

@autoreleasepool {
string key = "repeat_mps:" + getTensorsStringKey(self) + ":" + getArrayRefString(repeats);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));

if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;

@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(expanded_tensor));
MPSGraphTensor* outputTensor = [mpsGraph tileTensor:inputTensor
withMultiplier:getMPSShape(repeats)
name:nil];
MPSGraphTensor* inputTensor =
mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(expanded_tensor));
MPSGraphTensor* outputTensor = [mpsGraph tileTensor:inputTensor withMultiplier:getMPSShape(repeats) name:nil];

newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}

Placeholder selfPlaceholder = Placeholder(
cachedGraph->inputTensor_, expanded_tensor, /*mpsShape=*/nil, /*gatherTensorData=*/true, inputDataType);
Placeholder outputPlaceholder = Placeholder(
cachedGraph->outputTensor_, result, /*mpsShape=*/nil, /*gatherTensorData*/false, outputDataType);
cachedGraph->inputTensor_, expanded_tensor, /*mpsShape=*/nil, /*gatherTensorData=*/true, inputDataType);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, result, /*mpsShape=*/nil, /*gatherTensorData*/ false, outputDataType);

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
@{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@@ -142,51 +135,50 @@ kernel void repeat_interleave(constant {0} * repeat_ptr [[buf
}}
)METAL_REPEAT";

static
id<MTLLibrary> compileRepeatInterleaveLib(id<MTLDevice> device, const std::string& t1) {
static id<MTLLibrary> compileRepeatInterleaveLib(id<MTLDevice> device, const std::string& t1) {
auto key = t1;
static std::unordered_map<std::string, id<MTLLibrary>> libMap;
auto it = libMap.find(key);
if (it != libMap.end()) {
return it->second;
}
NSError *error = nil;
MTLCompileOptions *options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion: MTLLanguageVersion2_3];
auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(METAL_REPEAT_INTERLEAVE, t1).c_str()]
options:options
error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
libMap[key] = rc;
return rc;
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
auto rc =
[device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(METAL_REPEAT_INTERLEAVE, t1).c_str()]
options:options
error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
libMap[key] = rc;
return rc;
}

static
id<MTLComputePipelineState> getPipelineState(id<MTLDevice> device, const std::string& t1) {
static id<MTLComputePipelineState> getPipelineState(id<MTLDevice> device, const std::string& t1) {
static std::string kernel = "repeat_interleave";
auto key = kernel + t1;
static std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
auto it = cplMap.find(key);
if (it != cplMap.end()) {
return it->second;
return it->second;
}
NSError *error = nil;
NSError* error = nil;
auto library = compileRepeatInterleaveLib(device, t1);
id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
TORCH_CHECK(func != nil, "Can't get kernel ", kernel);
auto rc = [device newComputePipelineStateWithFunction:func error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
TORCH_CHECK(
rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
cplMap[key] = rc;
return rc;
}
template <typename index_t>
void computeRepeatIndices(
index_t* repeat_ptr,
int64_t* cumsum_ptr,
index_t* result_ptr,
int64_t size,
int64_t result_size) {
void computeRepeatIndices(index_t* repeat_ptr,
int64_t* cumsum_ptr,
index_t* result_ptr,
int64_t size,
int64_t result_size) {
id<MTLBuffer> repeatBuffer = reinterpret_cast<id<MTLBuffer>>(repeat_ptr);
id<MTLBuffer> cumsumBuffer = reinterpret_cast<id<MTLBuffer>>(cumsum_ptr);
id<MTLBuffer> resultBuffer = reinterpret_cast<id<MTLBuffer>>(result_ptr);
@@ -208,7 +200,7 @@ void computeRepeatIndices(
id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
id<MTLComputePipelineState> pipelineState = getPipelineState(MPSDevice::getInstance()->device(), scalar_type);

[computeEncoder setComputePipelineState: pipelineState];
[computeEncoder setComputePipelineState:pipelineState];
[computeEncoder setBuffer:repeatBuffer offset:0 atIndex:0];
[computeEncoder setBuffer:cumsumBuffer offset:0 atIndex:1];
[computeEncoder setBuffer:resultBuffer offset:0 atIndex:2];
@@ -216,7 +208,7 @@ void computeRepeatIndices(
MTLSize gridSize = MTLSizeMake(size, 1, 1);
NSUInteger threadsPerThreadgroup_ = pipelineState.maxTotalThreadsPerThreadgroup;
if (threadsPerThreadgroup_ > size) {
threadsPerThreadgroup_ = size;
threadsPerThreadgroup_ = size;
}
MTLSize threadsPerThreadgroup = MTLSizeMake(threadsPerThreadgroup_, 1, 1);

@@ -233,14 +225,14 @@ Tensor repeat_interleave_mps(const Tensor& repeat_, c10::optional<int64_t> outpu
if (repeat.scalar_type() == kLong && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
// #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of output,
// which currently doesn't support int64_t as input. Casting internally the indices to int32_t.
TORCH_WARN_ONCE("MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3");
TORCH_WARN_ONCE(
"MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3");
repeat = repeat.to(kInt);
}
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(
repeat, output_size);
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
});
return output;
}

} // namespace at::native
} // namespace at::native
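The hunks above only reflow the compile-and-dispatch helpers; as a rough sketch (not part of this commit, and with a hypothetical single-buffer kernel named "example_kernel"), the Metal side of that pattern looks like this:

```
// Illustrative sketch of the encode/dispatch pattern used by computeRepeatIndices above.
// commandBuffer, pso (pipeline state) and buffer are assumed to come from the caller.
static void dispatchExampleKernel(id<MTLCommandBuffer> commandBuffer,
                                  id<MTLComputePipelineState> pso,
                                  id<MTLBuffer> buffer,
                                  NSUInteger size) {
  id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
  [encoder setComputePipelineState:pso];
  [encoder setBuffer:buffer offset:0 atIndex:0];
  // Clamp the threadgroup size to the grid so very small launches stay valid.
  NSUInteger tg = std::min<NSUInteger>(pso.maxTotalThreadsPerThreadgroup, size);
  [encoder dispatchThreads:MTLSizeMake(size, 1, 1) threadsPerThreadgroup:MTLSizeMake(tg, 1, 1)];
  [encoder endEncoding];
}
```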
File diff suppressed because it is too large
@@ -1,13 +1,13 @@

// Copyright © 2022 Apple Inc.

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/NativeFunctions.h>

#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/Copy.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>

#ifdef __OBJC__
@@ -21,17 +21,20 @@ namespace at::native {
Scalar _local_scalar_dense_mps(const Tensor& self) {
Scalar r;

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_mps", [&] {
Tensor output = at::empty_like(self, kCPU);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half,
at::ScalarType::Bool,
at::ScalarType::BFloat16,
self.scalar_type(),
"_local_scalar_dense_mps",
[&] {
Tensor output = at::empty_like(self, kCPU);

Tensor cpu_output = mps::mps_copy_(output, self, false);
scalar_t value = *cpu_output.data_ptr<scalar_t>();
r = Scalar(value);
});
Tensor cpu_output = mps::mps_copy_(output, self, false);
scalar_t value = *cpu_output.data_ptr<scalar_t>();
r = Scalar(value);
});

return r;
}

} // namespace at::native
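The hunk above only re-wraps the dispatch call. As a rough, simplified sketch (not part of this commit), the AT_DISPATCH_* macro boils down to a per-dtype switch that binds `scalar_t` to the concrete element type before running the lambda body; one case of that expansion looks roughly like:

```
// Simplified sketch of the macro expansion for a single dtype; the real macro covers
// every listed type (Half, Bool, BFloat16, all standard and complex types).
switch (self.scalar_type()) {
  case at::ScalarType::Float: {
    using scalar_t = float;
    Tensor output = at::empty_like(self, kCPU);
    Tensor cpu_output = mps::mps_copy_(output, self, false);  // copy the MPS value to CPU
    r = Scalar(*cpu_output.data_ptr<scalar_t>());             // read the single element
    break;
  }
  // ... one case per remaining dtype ...
  default:
    TORCH_CHECK(false, "_local_scalar_dense_mps: unsupported dtype");
}
```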
@ -5,12 +5,7 @@
|
||||
namespace at::native {
|
||||
|
||||
TORCH_IMPL_FUNC(gather_out_mps)
|
||||
(const Tensor & self_arg,
|
||||
int64_t dim,
|
||||
const Tensor & index,
|
||||
bool sparse_grad,
|
||||
const Tensor & output)
|
||||
{
|
||||
(const Tensor& self_arg, int64_t dim, const Tensor& index, bool sparse_grad, const Tensor& output) {
|
||||
using namespace mps;
|
||||
|
||||
if (self_arg.numel() == 0 || index.numel() == 0) {
|
||||
@ -20,14 +15,11 @@ TORCH_IMPL_FUNC(gather_out_mps)
|
||||
dim = at::maybe_wrap_dim(dim, self.dim());
|
||||
|
||||
TORCH_CHECK(!sparse_grad, "sparse_grad not supported in MPS yet")
|
||||
TORCH_CHECK(self.scalar_type() == output.scalar_type(),
|
||||
"gather(): self and output must have the same scalar type");
|
||||
TORCH_CHECK(dim >= 0 && dim < self.dim(),
|
||||
"gather(): Indexing dim ", dim, " is out of bounds of tensor");
|
||||
TORCH_CHECK(self.scalar_type() == output.scalar_type(), "gather(): self and output must have the same scalar type");
|
||||
TORCH_CHECK(dim >= 0 && dim < self.dim(), "gather(): Indexing dim ", dim, " is out of bounds of tensor");
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* indexTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
@ -36,7 +28,6 @@ TORCH_IMPL_FUNC(gather_out_mps)
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
MPSShape* input_shape = getMPSShape(self);
|
||||
MPSShape* index_shape = getMPSShape(index);
|
||||
uint32_t num_input_dims = [input_shape count];
|
||||
@ -47,24 +38,25 @@ TORCH_IMPL_FUNC(gather_out_mps)
|
||||
bool needSlice = false;
|
||||
|
||||
for (const auto i : c10::irange(num_input_dims)) {
|
||||
TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis")
|
||||
if(i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
|
||||
TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue],
|
||||
"Index dim must not exceed input dim except at gathering axis")
|
||||
if (i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
|
||||
needSlice = true;
|
||||
}
|
||||
auto input_type = getMPSDataType(self);
|
||||
auto output_type = getMPSDataType(output);
|
||||
if (input_type == MPSDataTypeUInt8 || ((input_type == MPSDataTypeBool && !is_macos_13_or_newer()))) {
|
||||
if (input_type == MPSDataTypeUInt8 || ((input_type == MPSDataTypeBool && !is_macos_13_or_newer()))) {
|
||||
input_type = MPSDataTypeInt8;
|
||||
}
|
||||
if (output_type == MPSDataTypeUInt8 || ((output_type == MPSDataTypeBool && !is_macos_13_or_newer()))) {
|
||||
if (output_type == MPSDataTypeUInt8 || ((output_type == MPSDataTypeBool && !is_macos_13_or_newer()))) {
|
||||
output_type = MPSDataTypeInt8;
|
||||
}
|
||||
string key = "gather_out_mps" + getTensorsStringKey({self, index, output}) + ":" + std::to_string(dim);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
|
||||
if(!cachedGraph) {
|
||||
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
if (!cachedGraph) {
|
||||
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
@ -76,10 +68,10 @@ TORCH_IMPL_FUNC(gather_out_mps)
|
||||
MPSGraphTensor* getInput = inputTensor;
|
||||
|
||||
// Slice into the input tensor IF NEEDED
|
||||
if(needSlice) {
|
||||
NSMutableArray<NSNumber*> *starts = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
|
||||
NSMutableArray<NSNumber*> *ends = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
|
||||
NSMutableArray<NSNumber*> *strides = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
|
||||
if (needSlice) {
|
||||
NSMutableArray<NSNumber*>* starts = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
|
||||
NSMutableArray<NSNumber*>* ends = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
|
||||
NSMutableArray<NSNumber*>* strides = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
|
||||
|
||||
for (const auto i : c10::irange(num_input_dims)) {
|
||||
// All strides are 1
|
||||
@ -89,23 +81,19 @@ TORCH_IMPL_FUNC(gather_out_mps)
|
||||
ends[i] = (i != dim) ? index_shape[i] : input_shape[i];
|
||||
}
|
||||
|
||||
getInput = [mpsGraph sliceTensor:inputTensor
|
||||
starts:starts
|
||||
ends:ends
|
||||
strides:strides
|
||||
name:nil];
|
||||
getInput = [mpsGraph sliceTensor:inputTensor starts:starts ends:ends strides:strides name:nil];
|
||||
}
|
||||
|
||||
MPSGraphTensor* castIndexTensor = [mpsGraph castTensor:indexTensor
|
||||
toType:MPSDataTypeInt32
|
||||
name:(NSString * _Nonnull)nil];
|
||||
name:(NSString* _Nonnull)nil];
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wobjc-method-access"
|
||||
MPSGraphTensor* outputTensor = [mpsGraph gatherAlongAxis: (NSInteger) dim
|
||||
withUpdatesTensor: getInput
|
||||
indicesTensor: castIndexTensor
|
||||
name: nil];
|
||||
MPSGraphTensor* outputTensor = [mpsGraph gatherAlongAxis:(NSInteger)dim
|
||||
withUpdatesTensor:getInput
|
||||
indicesTensor:castIndexTensor
|
||||
name:nil];
|
||||
#pragma clang diagnostic pop
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
newCachedGraph->indexTensor_ = indexTensor;
|
||||
@ -113,7 +101,7 @@ TORCH_IMPL_FUNC(gather_out_mps)
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
|
||||
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape, true, input_type);
|
||||
@ -124,23 +112,20 @@ TORCH_IMPL_FUNC(gather_out_mps)
|
||||
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
|
||||
indexPlaceholder.getMPSGraphTensor() : indexPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
}
|
||||
|
||||
void scatter_mps_general
|
||||
(const Tensor& self_arg,
|
||||
int64_t dim,
|
||||
const Tensor& index,
|
||||
const Tensor& src,
|
||||
const Tensor& output,
|
||||
string func_name,
|
||||
const c10::string_view reduce)
|
||||
{
|
||||
void scatter_mps_general(const Tensor& self_arg,
|
||||
int64_t dim,
|
||||
const Tensor& index,
|
||||
const Tensor& src,
|
||||
const Tensor& output,
|
||||
string func_name,
|
||||
const c10::string_view reduce) {
|
||||
using namespace mps;
|
||||
|
||||
if (self_arg.numel() == 0 || index.numel() == 0 || src.numel() == 0) {
|
||||
@ -151,12 +136,10 @@ void scatter_mps_general
|
||||
|
||||
TORCH_CHECK(self.scalar_type() == output.scalar_type() && output.scalar_type() == src.scalar_type(),
|
||||
"scatter(): self, src and output must have the same scalar type");
|
||||
TORCH_CHECK(dim >= 0 && dim < self.dim(),
|
||||
"scatter(): Indexing dim ", dim, " is out of bounds of tensor");
|
||||
TORCH_CHECK(dim >= 0 && dim < self.dim(), "scatter(): Indexing dim ", dim, " is out of bounds of tensor");
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* indexTensor_ = nil;
|
||||
MPSGraphTensor* srcTensor_ = nil;
|
||||
@ -166,7 +149,6 @@ void scatter_mps_general
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
MPSShape* input_shape = getMPSShape(self);
|
||||
MPSShape* index_shape = getMPSShape(index);
|
||||
MPSShape* src_shape = getMPSShape(src);
|
||||
@ -174,7 +156,8 @@ void scatter_mps_general
|
||||
uint32_t num_index_dims = [index_shape count];
|
||||
uint32_t num_src_dims = [src_shape count];
|
||||
|
||||
TORCH_CHECK(num_input_dims == num_index_dims && num_index_dims == num_src_dims, "Input, index and src must have same rank")
|
||||
TORCH_CHECK(num_input_dims == num_index_dims && num_index_dims == num_src_dims,
|
||||
"Input, index and src must have same rank")
|
||||
|
||||
// Do we need to slice into the src tensor?
|
||||
bool needSlice = false;
|
||||
@ -182,11 +165,13 @@ void scatter_mps_general
|
||||
bool needsCast = false;
|
||||
|
||||
for (const auto i : c10::irange(num_input_dims)) {
|
||||
TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis")
|
||||
TORCH_CHECK([index_shape[i] intValue] <= [src_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis")
|
||||
if([index_shape[i] intValue] < [src_shape[i] intValue])
|
||||
TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue],
|
||||
"Index dim must not exceed input dim except at gathering axis")
|
||||
TORCH_CHECK([index_shape[i] intValue] <= [src_shape[i] intValue],
|
||||
"Index dim must not exceed input dim except at gathering axis")
|
||||
if ([index_shape[i] intValue] < [src_shape[i] intValue])
|
||||
needSlice = true;
|
||||
if(i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
|
||||
if (i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
|
||||
inputNeedSlice = true;
|
||||
}
|
||||
TORCH_CHECK(reduce != "mean", "Scatter reduce mean mode not yet supported in MPS")
|
||||
@ -197,11 +182,12 @@ void scatter_mps_general
|
||||
needsCast = true;
|
||||
}
|
||||
|
||||
string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" + std::string(reduce);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
if(!cachedGraph) {
|
||||
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" +
|
||||
std::string(reduce);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
@ -209,7 +195,7 @@ void scatter_mps_general
|
||||
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index);
|
||||
MPSGraphTensor* srcTensor = mpsGraphRankedPlaceHolder(mpsGraph, src);
|
||||
MPSGraphTensor* srcTensor = mpsGraphRankedPlaceHolder(mpsGraph, src);
|
||||
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
MPSGraphTensor* castSrcTensor = srcTensor;
|
||||
@ -229,9 +215,9 @@ void scatter_mps_general
|
||||
|
||||
// Slice into the src or input tensors IF NEEDED
|
||||
if (needSlice || inputNeedSlice) {
|
||||
NSMutableArray<NSNumber*> *starts = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
|
||||
NSMutableArray<NSNumber*> *strides = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
|
||||
NSMutableArray<NSNumber*> *ends_src = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
|
||||
NSMutableArray<NSNumber*>* starts = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
|
||||
NSMutableArray<NSNumber*>* strides = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
|
||||
NSMutableArray<NSNumber*>* ends_src = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
|
||||
|
||||
for (const auto i : c10::irange(num_input_dims)) {
|
||||
strides[i] = @1;
|
||||
@ -240,44 +226,41 @@ void scatter_mps_general
|
||||
scatterInputShape[i] = (i != dim) ? index_shape[i] : input_shape[i];
|
||||
}
|
||||
if (needSlice) {
|
||||
slicedSrc = [mpsGraph sliceTensor:castSrcTensor
|
||||
starts:starts
|
||||
ends:ends_src
|
||||
strides:strides
|
||||
name:nil];
|
||||
slicedSrc = [mpsGraph sliceTensor:castSrcTensor starts:starts ends:ends_src strides:strides name:nil];
|
||||
}
|
||||
if (inputNeedSlice) {
|
||||
slicedInput = [mpsGraph sliceTensor:castInputTensor
|
||||
starts:starts
|
||||
ends:scatterInputShape
|
||||
strides:strides
|
||||
name:nil];
|
||||
starts:starts
|
||||
ends:scatterInputShape
|
||||
strides:strides
|
||||
name:nil];
|
||||
}
|
||||
}
|
||||
MPSGraphScatterMode scatter_mode = MPSGraphScatterModeSet;
|
||||
|
||||
if(reduce == "sum" || reduce == "add")
|
||||
if (reduce == "sum" || reduce == "add")
|
||||
scatter_mode = MPSGraphScatterModeAdd;
|
||||
else if(reduce == "prod" || reduce == "multiply")
|
||||
else if (reduce == "prod" || reduce == "multiply")
|
||||
scatter_mode = MPSGraphScatterModeMul;
|
||||
else if(reduce == "amax")
|
||||
else if (reduce == "amax")
|
||||
scatter_mode = MPSGraphScatterModeMax;
|
||||
else if(reduce == "amin")
|
||||
else if (reduce == "amin")
|
||||
scatter_mode = MPSGraphScatterModeMin;
|
||||
|
||||
// Scatter this into the input with set mode
|
||||
// Scatter this into the input with set mode
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wobjc-method-access"
|
||||
MPSGraphTensor* scatterTensor = [mpsGraph scatterAlongAxis: (NSInteger) dim
|
||||
withDataTensor: slicedInput
|
||||
updatesTensor: slicedSrc
|
||||
indicesTensor: castIndexTensor
|
||||
mode: scatter_mode
|
||||
name: nil];
|
||||
MPSGraphTensor* scatterTensor = [mpsGraph scatterAlongAxis:(NSInteger)dim
|
||||
withDataTensor:slicedInput
|
||||
updatesTensor:slicedSrc
|
||||
indicesTensor:castIndexTensor
|
||||
mode:scatter_mode
|
||||
name:nil];
|
||||
#pragma clang diagnostic pop
|
||||
if(inputNeedSlice) {
|
||||
if (inputNeedSlice) {
|
||||
// Make an array of scatter indices tensors
|
||||
NSMutableArray<MPSGraphTensor*>* indicesTensors = [NSMutableArray<MPSGraphTensor*> arrayWithCapacity:num_input_dims];
|
||||
NSMutableArray<MPSGraphTensor*>* indicesTensors =
|
||||
[NSMutableArray<MPSGraphTensor*> arrayWithCapacity:num_input_dims];
|
||||
|
||||
// 1. Concatenate the coord tensors
|
||||
// 2. Flatten the values
|
||||
@ -289,18 +272,18 @@ void scatter_mps_general
|
||||
shape_data[i] = {[scatterInputShape[i] intValue]};
|
||||
}
|
||||
|
||||
MPSGraphTensor* scatterInputShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:shape_data.data() length:num_input_dims * sizeof(int)]
|
||||
shape:@[[NSNumber numberWithUnsignedInt:num_input_dims]]
|
||||
dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* scatterInputShapeTensor =
|
||||
[mpsGraph constantWithData:[NSData dataWithBytes:shape_data.data() length:num_input_dims * sizeof(int)]
|
||||
shape:@[ [NSNumber numberWithUnsignedInt:num_input_dims] ]
|
||||
dataType:MPSDataTypeInt32];
|
||||
|
||||
for (const auto i : c10::irange(num_input_dims)) {
|
||||
MPSGraphTensor* axisTensor = [mpsGraph constantWithScalar:i
|
||||
dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* scatter_currentIndexTensor = [mpsGraph coordinateAlongAxisTensor: axisTensor
|
||||
withShapeTensor: scatterInputShapeTensor
|
||||
name: nil];
|
||||
MPSGraphTensor* axisTensor = [mpsGraph constantWithScalar:i dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* scatter_currentIndexTensor = [mpsGraph coordinateAlongAxisTensor:axisTensor
|
||||
withShapeTensor:scatterInputShapeTensor
|
||||
name:nil];
|
||||
scatter_currentIndexTensor = [mpsGraph reshapeTensor:scatter_currentIndexTensor
|
||||
withShape:@[@-1, @1]
|
||||
withShape:@[ @-1, @1 ]
|
||||
name:nil];
|
||||
indicesTensors[i] = scatter_currentIndexTensor;
|
||||
}
|
||||
@ -309,9 +292,7 @@ void scatter_mps_general
|
||||
dimension:(NSInteger)1
|
||||
name:nil];
|
||||
|
||||
MPSGraphTensor* flatValuesTensor = [mpsGraph reshapeTensor:scatterTensor
|
||||
withShape:@[@-1]
|
||||
name:nil];
|
||||
MPSGraphTensor* flatValuesTensor = [mpsGraph reshapeTensor:scatterTensor withShape:@[ @-1 ] name:nil];
|
||||
|
||||
outputTensor = [mpsGraph scatterNDWithDataTensor:castInputTensor
|
||||
updatesTensor:flatValuesTensor
|
||||
@ -325,11 +306,12 @@ void scatter_mps_general
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
newCachedGraph->srcTensor_ = srcTensor;
|
||||
newCachedGraph->indexTensor_ = indexTensor;
|
||||
newCachedGraph->outputTensor_ = needsCast ? castMPSTensor(mpsGraph, outputTensor, output.scalar_type()) : outputTensor;
|
||||
newCachedGraph->outputTensor_ =
|
||||
needsCast ? castMPSTensor(mpsGraph, outputTensor, output.scalar_type()) : outputTensor;
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
|
||||
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape);
|
||||
@ -342,41 +324,24 @@ void scatter_mps_general
|
||||
srcPlaceholder.getMPSGraphTensor() : srcPlaceholder.getMPSGraphTensorData(),
|
||||
indexPlaceholder.getMPSGraphTensor() : indexPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(scatter_src_out_mps)
|
||||
(const Tensor& self,
|
||||
int64_t dim,
|
||||
const Tensor& index,
|
||||
const Tensor& src,
|
||||
const Tensor& output) {
|
||||
|
||||
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const Tensor& output) {
|
||||
scatter_mps_general(self, dim, index, src, output, "scatter_src_out_mps", "set");
|
||||
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(scatter_value_out_mps)
|
||||
(const Tensor& self,
|
||||
int64_t dim,
|
||||
const Tensor& index,
|
||||
const Scalar& value,
|
||||
const Tensor& output) {
|
||||
|
||||
Tensor src = at::native::empty_mps(index.sizes(),
|
||||
self.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
self.suggest_memory_format());
|
||||
(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value, const Tensor& output) {
|
||||
Tensor src = at::native::empty_mps(
|
||||
index.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, self.suggest_memory_format());
|
||||
src.fill_(value);
|
||||
scatter_mps_general(self, dim, index, const_cast<Tensor&>(src), output, "scatter_value_out_mps", "set");
|
||||
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(scatter_reduce_out_mps)
|
||||
@ -386,9 +351,7 @@ TORCH_IMPL_FUNC(scatter_reduce_out_mps)
|
||||
const Tensor& src,
|
||||
const c10::string_view reduce,
|
||||
const Tensor& output) {
|
||||
|
||||
scatter_mps_general(self, dim, index, src, output, "scatter_reduce_out_mps", reduce);
|
||||
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(scatter_value_reduce_out_mps)
|
||||
@ -398,25 +361,14 @@ TORCH_IMPL_FUNC(scatter_value_reduce_out_mps)
|
||||
const Scalar& value,
|
||||
const c10::string_view reduce,
|
||||
const Tensor& output) {
|
||||
|
||||
Tensor src = at::native::empty_mps(index.sizes(),
|
||||
self.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
self.suggest_memory_format());
|
||||
Tensor src = at::native::empty_mps(
|
||||
index.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, self.suggest_memory_format());
|
||||
src.fill_(value);
|
||||
scatter_mps_general(self, dim, index, const_cast<Tensor&>(src), output, "scatter_value_reduce_out_mps", reduce);
|
||||
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(scatter_add_mps_out)
|
||||
(const Tensor& self,
|
||||
int64_t dim,
|
||||
const Tensor& index,
|
||||
const Tensor& src,
|
||||
const Tensor& output) {
|
||||
|
||||
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const Tensor& output) {
|
||||
scatter_mps_general(self, dim, index, src, output, "scatter_add_mps_out", "add");
|
||||
}
|
||||
|
||||
|
@@ -2,10 +2,10 @@

#include <ATen/MemoryOverlap.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/TensorShape.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {
@ -27,21 +27,12 @@ std::vector<int64_t> getTopK0Shape(IntArrayRef sizes, const int64_t dim_) {
|
||||
|
||||
// topk
|
||||
TORCH_IMPL_FUNC(topk_out_mps)
|
||||
(const Tensor& self,
|
||||
int64_t k,
|
||||
int64_t dim_,
|
||||
bool largest,
|
||||
bool sorted,
|
||||
const Tensor& values,
|
||||
const Tensor& indices)
|
||||
{
|
||||
(const Tensor& self, int64_t k, int64_t dim_, bool largest, bool sorted, const Tensor& values, const Tensor& indices) {
|
||||
using namespace mps;
|
||||
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
|
||||
TORCH_CHECK(
|
||||
k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
|
||||
"selected index k out of range");
|
||||
TORCH_CHECK(k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1), "selected index k out of range");
|
||||
|
||||
if (!is_macos_13_or_newer() && (k>16)) {
|
||||
if (!is_macos_13_or_newer() && (k > 16)) {
|
||||
TORCH_WARN_ONCE("torch.topk support for k>16 by MPS on MacOS 13+, please upgrade");
|
||||
Tensor cpu_indices = indices.clone().to("cpu");
|
||||
Tensor cpu_values = values.clone().to("cpu");
|
||||
@ -52,31 +43,29 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
||||
}
|
||||
|
||||
if (self.dim() == 0 && self.numel() == 1) {
|
||||
values.copy_(self);
|
||||
indices.zero_();
|
||||
return;
|
||||
values.copy_(self);
|
||||
indices.zero_();
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle empty tensors
|
||||
if (self.numel() == 0)
|
||||
{
|
||||
values.copy_(self);
|
||||
indices.copy_(values.toType(at::ScalarType::Long));
|
||||
return;
|
||||
if (self.numel() == 0) {
|
||||
values.copy_(self);
|
||||
indices.copy_(values.toType(at::ScalarType::Long));
|
||||
return;
|
||||
}
|
||||
// Handle k == 0 case. Needed because MPSGraph does not support k == 0.
|
||||
if (k == 0)
|
||||
{
|
||||
const auto out_shape = getTopK0Shape(self.sizes(), dim);
|
||||
values.resize_(out_shape);
|
||||
indices.copy_(values.toType(at::ScalarType::Long));
|
||||
return;
|
||||
if (k == 0) {
|
||||
const auto out_shape = getTopK0Shape(self.sizes(), dim);
|
||||
values.resize_(out_shape);
|
||||
indices.copy_(values.toType(at::ScalarType::Long));
|
||||
return;
|
||||
}
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor *selfTensor = nil, *valuesTensor = nil, *indicesTensor = nil;
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor *selfTensor = nil, *valuesTensor = nil, *indicesTensor = nil;
|
||||
};
|
||||
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
@ -85,154 +74,126 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
||||
// Input as placeholders
|
||||
MPSShape* input_shape = getMPSShape(self);
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string key = string("topk:") + [ns_shape_key UTF8String] + ":" +
|
||||
getMPSTypeString(self) +
|
||||
":k" + to_string(k) + ":dim" + to_string(dim_) +
|
||||
":largest" + to_string(largest);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
if(!cachedGraph) {
|
||||
cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + to_string(k) +
|
||||
":dim" + to_string(dim_) + ":largest" + to_string(largest);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
|
||||
|
||||
if (is_macos_13_or_newer()) {
|
||||
MPSGraphTensor* castInputTensor = newCachedGraph->selfTensor;
|
||||
MPSDataType dataType = getMPSDataType(self);
|
||||
// #issue 104398441 sortWithTensor and argsortWithTensor
|
||||
if (dataType != MPSDataTypeInt32 &&
|
||||
dataType != MPSDataTypeFloat32 &&
|
||||
dataType != MPSDataTypeFloat16) {
|
||||
dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
|
||||
castInputTensor = [mpsGraph castTensor:newCachedGraph->selfTensor
|
||||
toType:dataType
|
||||
name:@"castInputTensor"];
|
||||
}
|
||||
MPSGraphTensor * sortedTensor = [mpsGraph sortWithTensor:castInputTensor
|
||||
axis:(NSUInteger)dim
|
||||
descending:largest
|
||||
name:nil];
|
||||
sortedTensor = [mpsGraph sliceTensor:sortedTensor
|
||||
dimension:(NSUInteger)dim
|
||||
start:((NSUInteger) 0)
|
||||
length:k
|
||||
name:nil];
|
||||
MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor
|
||||
axis:(NSInteger)dim
|
||||
descending:largest
|
||||
name:@"argmax_out"];
|
||||
argSortedTensor = [mpsGraph sliceTensor:argSortedTensor
|
||||
dimension:dim
|
||||
start:((NSUInteger) 0)
|
||||
length:k
|
||||
name:nil];
|
||||
newCachedGraph->valuesTensor = sortedTensor;
|
||||
newCachedGraph->indicesTensor = argSortedTensor;
|
||||
|
||||
} else {
|
||||
if ((dim_ != -1 && dim_ != self.dim() - 1) && (!largest)) {
|
||||
// transpose and negate
|
||||
MPSGraphTensor *transposedInput = [mpsGraph transposeTensor: newCachedGraph->selfTensor
|
||||
dimension: (NSUInteger)self.dim()-1
|
||||
withDimension: (NSUInteger)dim_
|
||||
name: nil];
|
||||
MPSGraphTensor * identity = [mpsGraph identityWithTensor: transposedInput
|
||||
name: nil];
|
||||
MPSGraphTensor * negatedTransposedInput = [mpsGraph negativeWithTensor:identity
|
||||
name: nil];
|
||||
NSArray<MPSGraphTensor *> * outputMPSGraphTensors = [mpsGraph
|
||||
topKWithSourceTensor:negatedTransposedInput
|
||||
k:((NSUInteger) k)
|
||||
name:nil];
|
||||
MPSGraphTensor *valuesNegatedTransposed = outputMPSGraphTensors[0];
|
||||
MPSGraphTensor *indicesTransposed = outputMPSGraphTensors[1];
|
||||
MPSGraphTensor *valuesNegated = [mpsGraph transposeTensor: valuesNegatedTransposed
|
||||
dimension: (NSUInteger)self.dim()-1
|
||||
withDimension: (NSUInteger)dim_
|
||||
name: nil];
|
||||
newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated
|
||||
name: nil];
|
||||
newCachedGraph->indicesTensor = [mpsGraph transposeTensor: indicesTransposed
|
||||
dimension: (NSUInteger)self.dim()-1
|
||||
withDimension: (NSUInteger)dim_
|
||||
name: nil];
|
||||
} else if (dim_ != -1 && dim_ != self.dim() - 1) {
|
||||
MPSGraphTensor *transposedInput = [mpsGraph transposeTensor: newCachedGraph->selfTensor
|
||||
dimension: (NSUInteger)self.dim()-1
|
||||
withDimension: (NSUInteger)dim_
|
||||
name: nil];
|
||||
MPSGraphTensor * identity = [mpsGraph identityWithTensor: transposedInput
|
||||
name: nil];
|
||||
NSArray<MPSGraphTensor *> * outputMPSGraphTensors = [mpsGraph
|
||||
topKWithSourceTensor:identity
|
||||
k:((NSUInteger) k)
|
||||
name:nil];
|
||||
MPSGraphTensor *valuesTransposed = outputMPSGraphTensors[0];
|
||||
MPSGraphTensor *indicesTransposed = outputMPSGraphTensors[1];
|
||||
newCachedGraph->valuesTensor = [mpsGraph transposeTensor:valuesTransposed
|
||||
dimension: (NSUInteger)self.dim()-1
|
||||
withDimension: (NSUInteger)dim_
|
||||
name: nil];
|
||||
newCachedGraph->indicesTensor = [mpsGraph transposeTensor: indicesTransposed
|
||||
dimension: (NSUInteger)self.dim()-1
|
||||
withDimension: (NSUInteger)dim_
|
||||
name: nil];
|
||||
} else if (!largest) {
|
||||
// only negate
|
||||
MPSGraphTensor *negatedInput = [mpsGraph negativeWithTensor:newCachedGraph->selfTensor
|
||||
name: nil];
|
||||
NSArray<MPSGraphTensor *> * outputMPSGraphTensors = [mpsGraph
|
||||
topKWithSourceTensor:negatedInput
|
||||
k:((NSUInteger) k)
|
||||
name:nil];
|
||||
MPSGraphTensor *valuesNegated = outputMPSGraphTensors[0];
|
||||
newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated
|
||||
name: nil];
|
||||
newCachedGraph->indicesTensor = outputMPSGraphTensors[1];
|
||||
} else {
|
||||
NSArray<MPSGraphTensor *> * outputMPSGraphTensors = [mpsGraph
|
||||
topKWithSourceTensor:newCachedGraph->selfTensor
|
||||
k:((NSUInteger) k)
|
||||
name:nil];
|
||||
newCachedGraph->valuesTensor = outputMPSGraphTensors[0];
|
||||
newCachedGraph->indicesTensor = outputMPSGraphTensors[1];
|
||||
}
|
||||
if (is_macos_13_or_newer()) {
|
||||
MPSGraphTensor* castInputTensor = newCachedGraph->selfTensor;
|
||||
MPSDataType dataType = getMPSDataType(self);
|
||||
// #issue 104398441 sortWithTensor and argsortWithTensor
|
||||
if (dataType != MPSDataTypeInt32 && dataType != MPSDataTypeFloat32 && dataType != MPSDataTypeFloat16) {
|
||||
dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
|
||||
castInputTensor = [mpsGraph castTensor:newCachedGraph->selfTensor
|
||||
toType:dataType
|
||||
name:@"castInputTensor"];
|
||||
}
|
||||
MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor
|
||||
axis:(NSUInteger)dim
|
||||
descending:largest
|
||||
name:nil];
|
||||
sortedTensor = [mpsGraph sliceTensor:sortedTensor
|
||||
dimension:(NSUInteger)dim
|
||||
start:((NSUInteger)0)length:k
|
||||
name:nil];
|
||||
MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor
|
||||
axis:(NSInteger)dim
|
||||
descending:largest
|
||||
name:@"argmax_out"];
|
||||
argSortedTensor = [mpsGraph sliceTensor:argSortedTensor
|
||||
dimension:dim
|
||||
start:((NSUInteger)0)length:k
|
||||
name:nil];
|
||||
newCachedGraph->valuesTensor = sortedTensor;
|
||||
newCachedGraph->indicesTensor = argSortedTensor;
|
||||
|
||||
} else {
|
||||
if ((dim_ != -1 && dim_ != self.dim() - 1) && (!largest)) {
|
||||
// transpose and negate
|
||||
MPSGraphTensor* transposedInput = [mpsGraph transposeTensor:newCachedGraph->selfTensor
|
||||
dimension:(NSUInteger)self.dim() - 1
|
||||
withDimension:(NSUInteger)dim_
|
||||
name:nil];
|
||||
MPSGraphTensor* identity = [mpsGraph identityWithTensor:transposedInput name:nil];
|
||||
MPSGraphTensor* negatedTransposedInput = [mpsGraph negativeWithTensor:identity name:nil];
|
||||
NSArray<MPSGraphTensor*>* outputMPSGraphTensors = [mpsGraph topKWithSourceTensor:negatedTransposedInput
|
||||
k:((NSUInteger)k)name:nil];
|
||||
MPSGraphTensor* valuesNegatedTransposed = outputMPSGraphTensors[0];
|
||||
MPSGraphTensor* indicesTransposed = outputMPSGraphTensors[1];
|
||||
MPSGraphTensor* valuesNegated = [mpsGraph transposeTensor:valuesNegatedTransposed
|
||||
dimension:(NSUInteger)self.dim() - 1
|
||||
withDimension:(NSUInteger)dim_
|
||||
name:nil];
|
||||
newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated name:nil];
|
||||
newCachedGraph->indicesTensor = [mpsGraph transposeTensor:indicesTransposed
|
||||
dimension:(NSUInteger)self.dim() - 1
|
||||
withDimension:(NSUInteger)dim_
|
||||
name:nil];
|
||||
} else if (dim_ != -1 && dim_ != self.dim() - 1) {
|
||||
MPSGraphTensor* transposedInput = [mpsGraph transposeTensor:newCachedGraph->selfTensor
|
||||
dimension:(NSUInteger)self.dim() - 1
|
||||
withDimension:(NSUInteger)dim_
|
||||
name:nil];
|
||||
MPSGraphTensor* identity = [mpsGraph identityWithTensor:transposedInput name:nil];
|
||||
NSArray<MPSGraphTensor*>* outputMPSGraphTensors = [mpsGraph topKWithSourceTensor:identity
|
||||
k:((NSUInteger)k)name:nil];
|
||||
MPSGraphTensor* valuesTransposed = outputMPSGraphTensors[0];
|
||||
MPSGraphTensor* indicesTransposed = outputMPSGraphTensors[1];
|
||||
newCachedGraph->valuesTensor = [mpsGraph transposeTensor:valuesTransposed
|
||||
dimension:(NSUInteger)self.dim() - 1
|
||||
withDimension:(NSUInteger)dim_
|
||||
name:nil];
|
||||
newCachedGraph->indicesTensor = [mpsGraph transposeTensor:indicesTransposed
|
||||
dimension:(NSUInteger)self.dim() - 1
|
||||
withDimension:(NSUInteger)dim_
|
||||
name:nil];
|
||||
} else if (!largest) {
|
||||
// only negate
|
||||
MPSGraphTensor* negatedInput = [mpsGraph negativeWithTensor:newCachedGraph->selfTensor name:nil];
|
||||
NSArray<MPSGraphTensor*>* outputMPSGraphTensors = [mpsGraph topKWithSourceTensor:negatedInput
|
||||
k:((NSUInteger)k)name:nil];
|
||||
MPSGraphTensor* valuesNegated = outputMPSGraphTensors[0];
|
||||
newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated name:nil];
|
||||
newCachedGraph->indicesTensor = outputMPSGraphTensors[1];
|
||||
} else {
|
||||
NSArray<MPSGraphTensor*>* outputMPSGraphTensors =
|
||||
[mpsGraph topKWithSourceTensor:newCachedGraph->selfTensor k:((NSUInteger)k)name:nil];
|
||||
newCachedGraph->valuesTensor = outputMPSGraphTensors[0];
|
||||
newCachedGraph->indicesTensor = outputMPSGraphTensors[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
return newCachedGraph;
|
||||
}));
|
||||
}
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->selfTensor, self);
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->selfTensor, self);
|
||||
// Outputs as placeholders
|
||||
Placeholder valuesPlaceholder = Placeholder(cachedGraph->valuesTensor, values);
|
||||
Placeholder indicesPlaceholder = Placeholder(cachedGraph->indicesTensor, indices);
|
||||
// Create dictionary of inputs and outputs
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
|
||||
feeds = @{
|
||||
inputPlaceholder.getMPSGraphTensor() :
|
||||
inputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
|
||||
feeds = @{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
valuesPlaceholder.getMPSGraphTensor() :
|
||||
valuesPlaceholder.getMPSGraphTensorData(),
|
||||
indicesPlaceholder.getMPSGraphTensor() :
|
||||
indicesPlaceholder.getMPSGraphTensorData()
|
||||
valuesPlaceholder.getMPSGraphTensor() : valuesPlaceholder.getMPSGraphTensorData(),
|
||||
indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
}
|
||||
|
||||
void check_shape_except_dim(const Tensor &first, const Tensor &second,
|
||||
int dimension, int index)
|
||||
{
|
||||
void check_shape_except_dim(const Tensor& first, const Tensor& second, int dimension, int index) {
|
||||
int first_dims = first.dim();
|
||||
int second_dims = second.dim();
|
||||
TORCH_CHECK(first_dims == second_dims,
|
||||
"Tensors must have same number of dimensions: got ", first_dims,
|
||||
" and ", second_dims);
|
||||
TORCH_CHECK(
|
||||
first_dims == second_dims, "Tensors must have same number of dimensions: got ", first_dims, " and ", second_dims);
|
||||
for (int dim = 0; dim < first_dims; dim++) {
|
||||
if (dim == dimension) {
|
||||
continue;
|
||||
@ -240,23 +201,27 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
|
||||
int64_t first_dim_size = at::native::size(first, dim);
|
||||
int64_t second_dim_size = at::native::size(second, dim);
|
||||
TORCH_CHECK(first_dim_size == second_dim_size,
|
||||
"Sizes of tensors must match except in dimension ", dim, ". Got ",
|
||||
static_cast<long long>(first_dim_size), " and ",
|
||||
static_cast<long long>(second_dim_size), " (The offending index is ",
|
||||
index, ")");
|
||||
"Sizes of tensors must match except in dimension ",
|
||||
dim,
|
||||
". Got ",
|
||||
static_cast<long long>(first_dim_size),
|
||||
" and ",
|
||||
static_cast<long long>(second_dim_size),
|
||||
" (The offending index is ",
|
||||
index,
|
||||
")");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(cat_out_mps)
|
||||
(const ITensorListRef& inputs,
|
||||
int64_t dimension,
|
||||
int64_t valid,
|
||||
bool all_contiguous,
|
||||
bool all_same_dtype,
|
||||
bool all_same_sizes_and_stride,
|
||||
MemoryFormat memory_format,
|
||||
const Tensor& out) {
|
||||
|
||||
(const ITensorListRef& inputs,
|
||||
int64_t dimension,
|
||||
int64_t valid,
|
||||
bool all_contiguous,
|
||||
bool all_same_dtype,
|
||||
bool all_same_sizes_and_stride,
|
||||
MemoryFormat memory_format,
|
||||
const Tensor& out) {
|
||||
using namespace mps;
|
||||
|
||||
if (out.numel() == 0) {
|
||||
@ -270,14 +235,16 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
TORCH_CHECK(t.dim() > 0, "zero-dimensional tensor (at position ", idx, ") cannot be concatenated");
auto lap = at::get_overlap_status(out, t);
TORCH_CHECK(lap != at::MemOverlapStatus::Partial && lap != at::MemOverlapStatus::Full,
"torch.cat(): unsupported operation: the input tensors cannot refer to any "
"of the output memory locations. Found overlap in input tensor ", idx);
"torch.cat(): unsupported operation: the input tensors cannot refer to any "
"of the output memory locations. Found overlap in input tensor ",
idx);
idx++;
}
// Check for type promotion
TORCH_CHECK(canCast(out_dtype, out.scalar_type()),
"torch.cat(): input types can't be cast to the desired output type ", out.scalar_type());
TORCH_CHECK(inputs.size() > 0,"torch.cat(): invalid number of inputs ", inputs.size());
"torch.cat(): input types can't be cast to the desired output type ",
out.scalar_type());
TORCH_CHECK(inputs.size() > 0, "torch.cat(): invalid number of inputs ", inputs.size());

dimension = legacy_cat_wrap_dim(dimension, materialized_inputs);
TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension);
@ -288,9 +255,7 @@ TORCH_IMPL_FUNC(cat_out_mps)
// this behavior for backwards compatibility, but only for this specific size
// (i.e. other empty sizes are not skipped).
// FIXME: warn if this is the case
auto should_skip = [](const Tensor& t) {
return t.dim() == 1 && at::native::size(t, 0) == 0;
};
auto should_skip = [](const Tensor& t) { return t.dim() == 1 && at::native::size(t, 0) == 0; };
at::assert_no_internal_overlap(out);

Tensor notSkippedTensor;
@ -317,11 +282,15 @@ TORCH_IMPL_FUNC(cat_out_mps)
for (const Tensor& t : inputs) {
TORCH_CHECK(t.device() == notSkippedTensor.device(),
"torch.cat(): all input tensors must be on the same device. Received ",
t.device(), " and ", notSkippedTensor.device());
t.device(),
" and ",
notSkippedTensor.device());
}
TORCH_CHECK(out.device() == notSkippedTensor.device(),
"torch.cat(): all input tensors and out must be on the same device, but inputs are on ",
notSkippedTensor.device(), " and out is on ", out.device());
notSkippedTensor.device(),
" and out is on ",
out.device());

// TODO: For better performance by eliminating input tensor gathering and post transpose,
// TODO: it is better to keep the out tensor's memory format.
@ -354,23 +323,23 @@ TORCH_IMPL_FUNC(cat_out_mps)
}

struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
std::vector<MPSGraphTensor*> inputTensors_;
MPSGraphTensor* outputTensor_ = nil;
};
MPSGraphCache *cache_ = MPSGraphCache::getInstance();
MPSGraphCache* cache_ = MPSGraphCache::getInstance();

@autoreleasepool {
string key = "cat_out_mps:" + to_string(dimension) + getTensorsStringKey(input_tensors, /*short_dtype*/true) + ":" +
(memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
string key = "cat_out_mps:" + to_string(dimension) + getTensorsStringKey(input_tensors, /*short_dtype*/ true) +
":" + (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");

CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;

@autoreleasepool {
MPSGraph *mpsGraph = make_mps_graph();
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

auto len_tensor_array = inputs.size() - skipped_tensor_indices.size();
@ -383,7 +352,8 @@ TORCH_IMPL_FUNC(cat_out_mps)
if (tensor.scalar_type() == kBool) {
scalar_type = MPSDataTypeInt8;
}
newCachedGraph->inputTensors_[idx] = mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, MemoryFormat::Contiguous));
newCachedGraph->inputTensors_[idx] =
mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, MemoryFormat::Contiguous));
if (tensor.scalar_type() != out_dtype) {
castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx]
toType:getMPSDataType(out_dtype)
@ -393,15 +363,12 @@ TORCH_IMPL_FUNC(cat_out_mps)
}
}

auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data()
count:len_tensor_array];
auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data() count:len_tensor_array];
MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
dimension:dimension // Maybe convert this from int64_t -> int32
name:nil];
if (getMPSDataType(out_dtype) == MPSDataTypeBool) {
outputTensor = [mpsGraph castTensor:outputTensor
toType:MPSDataTypeBool
name:@"outputTensor"];
outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"outputTensor"];
}
newCachedGraph->outputTensor_ = outputTensor;
}
@ -418,9 +385,11 @@ TORCH_IMPL_FUNC(cat_out_mps)
if (tensor.scalar_type() == kBool) {
scalar_type = MPSDataTypeInt8;
}
inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor,
inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx],
tensor,
getMPSShape(tensor, MemoryFormat::Contiguous),
/*gatherTensorData*/true, scalar_type);
/*gatherTensorData*/ true,
scalar_type);
t_idx++;
}
i++;
@ -430,16 +399,15 @@ TORCH_IMPL_FUNC(cat_out_mps)
if (!is_macos_13_or_newer() && out.scalar_type() == kBool) {
outputDataType = MPSDataTypeInt8;
}
Placeholder outputPlaceholder = Placeholder(
cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType);

NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
for (auto& inputPlaceholder : inputPlaceholders) {
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
}
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
@ -16,30 +16,26 @@
namespace at::native {

void get_shapes(MPSShape* input_shape_readonly,
NSMutableArray<NSNumber*>* &input_shape,
int num_input_dims, c10::MemoryFormat memory_format) {
NSMutableArray<NSNumber*>*& input_shape,
int num_input_dims,
c10::MemoryFormat memory_format) {
// Modify the shape
if(memory_format == at::MemoryFormat::Contiguous) {
for(int i = 0; i < num_input_dims; i++)
if (memory_format == at::MemoryFormat::Contiguous) {
for (int i = 0; i < num_input_dims; i++)
input_shape[i] = input_shape_readonly[i];
}
else { // ChannelsLast
} else { // ChannelsLast
auto num_channels = input_shape_readonly[1];
input_shape[0] = input_shape_readonly[0];
for(int i = 1; i < num_input_dims-1; i++)
input_shape[i] = input_shape_readonly[i+1];
input_shape[num_input_dims-1] = num_channels;
for (int i = 1; i < num_input_dims - 1; i++)
input_shape[i] = input_shape_readonly[i + 1];
input_shape[num_input_dims - 1] = num_channels;
}
}

// Note - Currently only supported for 4D image tensors

TORCH_IMPL_FUNC(softmax_mps_out)
(const Tensor& input_,
const int64_t dim,
const bool half_to_float,
const Tensor& output) {

(const Tensor& input_, const int64_t dim, const bool half_to_float, const Tensor& output) {
TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on MPS");

if (input_.numel() == 0) {
@ -49,25 +45,22 @@ TORCH_IMPL_FUNC(softmax_mps_out)
Tensor input;
if (input_.dim() == 0) {
input = input_.view(1);
}
else
} else
input = input_;

int64_t dim_ = maybe_wrap_dim(dim, input.dim());
TORCH_CHECK(
dim_ >= 0 && dim_ < input.dim(),
"Softmax:dim must be non-negative and less than input dimensions");
TORCH_CHECK(dim_ >= 0 && dim_ < input.dim(), "Softmax:dim must be non-negative and less than input dimensions");

const auto memory_format = input.suggest_memory_format();
// TORCH_CHECK(input.suggest_memory_format() == output.suggest_memory_format(), "Input and output memory format should match")
// TORCH_CHECK(input.suggest_memory_format() == output.suggest_memory_format(), "Input and output memory format should
// match")

using namespace mps;
MPSStream* stream = getCurrentMPSStream();

// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
@ -75,20 +68,20 @@ TORCH_IMPL_FUNC(softmax_mps_out)
MPSGraphCache* cache_ = MPSGraphCache::getInstance();

@autoreleasepool {

string mem_format_key = get_mem_format_string(memory_format);
MPSShape* input_shape_readonly = mps::getMPSShape(input);
int num_input_dims = [input_shape_readonly count];
// Check - Channels last implies 4d
TORCH_CHECK(memory_format != at::MemoryFormat::ChannelsLast || num_input_dims == 4, "ChannelsLast implies 4d tensor")
TORCH_CHECK(memory_format != at::MemoryFormat::ChannelsLast || num_input_dims == 4,
"ChannelsLast implies 4d tensor")
// Input shape changes based on memory format
NSMutableArray<NSNumber*>* input_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];

get_shapes(input_shape_readonly, input_shape, num_input_dims, memory_format);

// Change dim
if(memory_format == at::MemoryFormat::ChannelsLast && dim_ > 0) {
switch(dim_) {
if (memory_format == at::MemoryFormat::ChannelsLast && dim_ > 0) {
switch (dim_) {
case 1:
dim_ = 3;
break;
@ -105,13 +98,13 @@ TORCH_IMPL_FUNC(softmax_mps_out)

NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];

string key = "softmax_mps_out:" + mem_format_key + ":" + getMPSTypeString(input) + ":"
+ [ns_shape_key UTF8String] + ":" + std::to_string(dim_);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
string key = "softmax_mps_out:" + mem_format_key + ":" + getMPSTypeString(input) + ":" + [ns_shape_key UTF8String] +
":" + std::to_string(dim_);
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));

if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;

@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
@ -120,28 +113,20 @@ TORCH_IMPL_FUNC(softmax_mps_out)
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape);

// passing selector of softMaxWithTensor on the mpsGraph object
MPSGraphTensor* outputTensor = [mpsGraph softMaxWithTensor:inputTensor
axis:(NSInteger)dim_
name:nil];
MPSGraphTensor* outputTensor = [mpsGraph softMaxWithTensor:inputTensor axis:(NSInteger)dim_ name:nil];

// Output needs to be contiguous format
if(memory_format == at::MemoryFormat::ChannelsLast) {
if (memory_format == at::MemoryFormat::ChannelsLast) {
auto N = input_shape[0];
auto H = input_shape[1];
auto W = input_shape[2];
auto C = input_shape[3];

outputTensor = [mpsGraph reshapeTensor:outputTensor
withShape:@[N, ([NSNumber numberWithInt:[H intValue]* [W intValue]]), C]
withShape:@[ N, ([NSNumber numberWithInt:[H intValue] * [W intValue]]), C ]
name:nil];
outputTensor = [mpsGraph transposeTensor:outputTensor
dimension:1
withDimension:2
name:nil];
outputTensor = [mpsGraph reshapeTensor:outputTensor
withShape:@[N, C, H, W]
name:nil];

outputTensor = [mpsGraph transposeTensor:outputTensor dimension:1 withDimension:2 name:nil];
outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:@[ N, C, H, W ] name:nil];
}

newCachedGraph->inputTensor_ = inputTensor;
@ -149,32 +134,24 @@ TORCH_IMPL_FUNC(softmax_mps_out)
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}

Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input, input_shape);
// This must be the Contiguous shape
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
@{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}

}

TORCH_IMPL_FUNC(softmax_backward_mps_out)
(const Tensor& grad_,
const Tensor& output_,
int64_t dim,
ScalarType input_dtype,
const Tensor& grad_input) {

(const Tensor& grad_, const Tensor& output_, int64_t dim, ScalarType input_dtype, const Tensor& grad_input) {
if (output_.numel() == 0) {
return;
}
@ -182,29 +159,24 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
Tensor grad;
if (grad_.dim() == 0) {
grad = grad_.view(1);
}
else
} else
grad = grad_;

Tensor output;
if (output_.dim() == 0) {
output = output_.view(1);
}
else
} else
output = output_;

int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
TORCH_CHECK(
dim_ >= 0 && dim_ < grad.dim(),
"Grad:dim must be non-negative and less than input dimensions");
TORCH_CHECK(dim_ >= 0 && dim_ < grad.dim(), "Grad:dim must be non-negative and less than input dimensions");

using namespace mps;
MPSStream* stream = getCurrentMPSStream();

// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* softmaxTensor_ = nil;
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* gradInputTensor_ = nil;
@ -213,17 +185,16 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
MPSGraphCache* cache_ = MPSGraphCache::getInstance();

@autoreleasepool {

MPSShape* grad_shape = mps::getMPSShape(grad);
NSString* ns_shape_key = [[grad_shape valueForKey:@"description"] componentsJoinedByString:@","];

string key = "softmax_backward_mps_out:" + getMPSTypeString(output) + ":"
+ [ns_shape_key UTF8String] + ":" + std::to_string(dim_);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
string key = "softmax_backward_mps_out:" + getMPSTypeString(output) + ":" + [ns_shape_key UTF8String] + ":" +
std::to_string(dim_);
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));

if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;

@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
@ -235,9 +206,7 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
MPSGraphTensor* mulTensor = [mpsGraph multiplicationWithPrimaryTensor:softmaxTensor
secondaryTensor:gradOutputTensor
name:nil];
MPSGraphTensor* mulSumTensor = [mpsGraph reductionSumWithTensor:mulTensor
axis:(NSInteger)dim_
name:nil];
MPSGraphTensor* mulSumTensor = [mpsGraph reductionSumWithTensor:mulTensor axis:(NSInteger)dim_ name:nil];
MPSGraphTensor* gradSubTensor = [mpsGraph subtractionWithPrimaryTensor:gradOutputTensor
secondaryTensor:mulSumTensor
name:nil];
@ -251,7 +220,7 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}

Placeholder softmaxPlaceholder = Placeholder(cachedGraph->softmaxTensor_, output, grad_shape);
@ -262,12 +231,10 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
softmaxPlaceholder.getMPSGraphTensor() : softmaxPlaceholder.getMPSGraphTensorData(),
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);

}
}
@ -2,10 +2,10 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/TensorShape.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {

@ -42,60 +42,57 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)

MPSStream* stream = getCurrentMPSStream();
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *selfTensor = nil, *valuesTensor = nil, *indicesTensor = nil;
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *selfTensor = nil, *valuesTensor = nil, *indicesTensor = nil;
};
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
// Input as placeholders
MPSShape* input_shape = getMPSShape(self);
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) +
":dim" + to_string(dim) + ":descending" + to_string(descending);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" + to_string(dim) +
":descending" + to_string(descending);
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);

MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus);
MPSGraphTensor * sortedTensor = [mpsGraph sortWithTensor:castInputTensor
axis:(NSInteger)dim
descending:(BOOL)descending
name:@"sort_out"];
if ([sortedTensor dataType] != getMPSDataType(values)) {
sortedTensor = castMPSTensor(mpsGraph, sortedTensor, values.scalar_type());
}
MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor
axis:(NSInteger)dim
descending:(BOOL)descending
name:@"argsort_out"];
if ([argSortedTensor dataType] != getMPSDataType(indices)) {
argSortedTensor = castMPSTensor(mpsGraph, argSortedTensor, indices.scalar_type());
}
newCachedGraph->valuesTensor = sortedTensor;
newCachedGraph->indicesTensor = argSortedTensor;
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus);
MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor
axis:(NSInteger)dim
descending:(BOOL)descending
name:@"sort_out"];
if ([sortedTensor dataType] != getMPSDataType(values)) {
sortedTensor = castMPSTensor(mpsGraph, sortedTensor, values.scalar_type());
}
MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor
axis:(NSInteger)dim
descending:(BOOL)descending
name:@"argsort_out"];
if ([argSortedTensor dataType] != getMPSDataType(indices)) {
argSortedTensor = castMPSTensor(mpsGraph, argSortedTensor, indices.scalar_type());
}
newCachedGraph->valuesTensor = sortedTensor;
newCachedGraph->indicesTensor = argSortedTensor;
}
return newCachedGraph;
}));
}
Placeholder inputPlaceholder = Placeholder(cachedGraph->selfTensor, self);
Placeholder inputPlaceholder = Placeholder(cachedGraph->selfTensor, self);
// Outputs as placeholders
Placeholder valuesPlaceholder = Placeholder(cachedGraph->valuesTensor, values);
Placeholder indicesPlaceholder = Placeholder(cachedGraph->indicesTensor, indices);
// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
feeds = @{ inputPlaceholder.getMPSGraphTensor() :
inputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
feeds = @{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
valuesPlaceholder.getMPSGraphTensor() :
valuesPlaceholder.getMPSGraphTensorData(),
indicesPlaceholder.getMPSGraphTensor() :
indicesPlaceholder.getMPSGraphTensorData()
valuesPlaceholder.getMPSGraphTensor() : valuesPlaceholder.getMPSGraphTensorData(),
indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData()
};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
@ -4,14 +4,11 @@
namespace at::native {

Tensor& bincount_mps_impl(const Tensor& self,
const Tensor& weights,
Tensor& output) {
Tensor& bincount_mps_impl(const Tensor& self, const Tensor& weights, Tensor& output) {
using namespace mps;

struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* weightsTensor_ = nil;
MPSGraphTensor* scatterDataTensor_ = nil;
@ -24,42 +21,37 @@ Tensor& bincount_mps_impl(const Tensor& self,

@autoreleasepool {
string key = "bincount_mps_impl" + getTensorsStringKey({self, weights});
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {

CachedGraph *newCachedGraph = nil;
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;

@autoreleasepool {
// Initialize graph
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor *scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(output.scalar_type()));
MPSGraphTensor* scatterDataTensor =
mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(output.scalar_type()));

MPSGraphTensor *updatesTensor = nil;
MPSGraphTensor* updatesTensor = nil;
if (has_weights) {
updatesTensor = mpsGraphRankedPlaceHolder(mpsGraph, weights);
}
else {
updatesTensor = [mpsGraph constantWithScalar:1.0f
shape:getMPSShape(self)
dataType:getMPSDataType(output)];
} else {
updatesTensor = [mpsGraph constantWithScalar:1.0f shape:getMPSShape(self) dataType:getMPSDataType(output)];
}

MPSGraphTensor *castedInputTensor = inputTensor;
MPSGraphTensor* castedInputTensor = inputTensor;
if (self.scalar_type() == kByte) {
castedInputTensor = [mpsGraph castTensor:inputTensor
toType:MPSDataTypeInt32
name:@"castInputTensor"];
castedInputTensor = [mpsGraph castTensor:inputTensor toType:MPSDataTypeInt32 name:@"castInputTensor"];
}

MPSGraphTensor *outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor
MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor
updatesTensor:updatesTensor
indicesTensor:castedInputTensor
axis:0
mode:MPSGraphScatterModeAdd
name:nil];
axis:0
mode:MPSGraphScatterModeAdd
name:nil];

newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
@ -70,7 +62,7 @@ Tensor& bincount_mps_impl(const Tensor& self,
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}

// Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation
@ -80,17 +72,16 @@ Tensor& bincount_mps_impl(const Tensor& self,
Placeholder weightsPlaceholder = Placeholder();

// Create dictionary of inputs/feeds and outputs/results
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =[NSMutableDictionary dictionary];
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
feeds[scatterPlaceholder.getMPSGraphTensor()] = scatterPlaceholder.getMPSGraphTensorData();
if(has_weights) {
if (has_weights) {
weightsPlaceholder = Placeholder(cachedGraph->weightsTensor_, weights);
feeds[weightsPlaceholder.getMPSGraphTensor()] = weightsPlaceholder.getMPSGraphTensorData();
}

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

// Run the graph
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
@ -108,43 +99,32 @@ Tensor _bincount_mps(const Tensor& self, const c10::optional<Tensor>& weights_op
TORCH_CHECK(minlength >= 0, "minlength should be >= 0");

if (self.dim() == 1 && self.numel() == 0) {
return at::zeros(
{minlength},
kLong,
c10::nullopt /* layout */,
kMPS,
c10::nullopt /* pin_memory */);
return at::zeros({minlength}, kLong, c10::nullopt /* layout */, kMPS, c10::nullopt /* pin_memory */);
}
TORCH_CHECK(self.dim() == 1 && self.min().item<int64_t>() >= 0, "bincount only supports 1-d non-negative integral inputs.");
TORCH_CHECK(self.dim() == 1 && self.min().item<int64_t>() >= 0,
"bincount only supports 1-d non-negative integral inputs.");

bool has_weights = weights.defined();
TORCH_CHECK(!(has_weights && (weights.dim() != 1 || weights.size(0) != self.size(0))), "weights should be 1-d and have the same length as input");
TORCH_CHECK(!(has_weights && (weights.dim() != 1 || weights.size(0) != self.size(0))),
"weights should be 1-d and have the same length as input");

const int64_t nbins = std::max(self.max().item<int64_t>() + 1L, minlength);
Tensor output;

Tensor weights_ = weights;
if (has_weights) {
if(weights.scalar_type() != ScalarType::Float &&
weights.scalar_type() != ScalarType::Int &&
weights.scalar_type() != ScalarType::Half) {
// Scatter doesn't work for int8/int16 dtypes
weights_ = weights.to(kInt);
if (weights.scalar_type() != ScalarType::Float && weights.scalar_type() != ScalarType::Int &&
weights.scalar_type() != ScalarType::Half) {
// Scatter doesn't work for int8/int16 dtypes
weights_ = weights.to(kInt);
}
output = at::zeros(
{nbins},
optTypeMetaToScalarType(weights_.options().dtype_opt()),
weights_.options().layout_opt(),
weights_.options().device_opt(),
weights_.options().pinned_memory_opt());
}
else {
output = at::zeros(
{nbins},
kLong,
c10::nullopt /* layout */,
kMPS,
c10::nullopt /* pin_memory */);
output = at::zeros({nbins},
optTypeMetaToScalarType(weights_.options().dtype_opt()),
weights_.options().layout_opt(),
weights_.options().device_opt(),
weights_.options().pinned_memory_opt());
} else {
output = at::zeros({nbins}, kLong, c10::nullopt /* layout */, kMPS, c10::nullopt /* pin_memory */);
}

return bincount_mps_impl(self, weights_, output);
@ -1,303 +1,281 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {
namespace mps {

struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
MPSGraphTensor *minTensor = nil, *maxTensor = nil;
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
MPSGraphTensor *minTensor = nil, *maxTensor = nil;
};

void clamp_mps_graph(CachedGraph* cachedGraph, const Tensor& input_tensor)
{
MPSGraph *mpsGraph = cachedGraph->graph();
void clamp_mps_graph(CachedGraph* cachedGraph, const Tensor& input_tensor) {
MPSGraph* mpsGraph = cachedGraph->graph();

cachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
cachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);

if (cachedGraph->minTensor && cachedGraph->maxTensor) {
cachedGraph->outputTensor = [mpsGraph clampWithTensor:cachedGraph->inputTensor
minValueTensor:cachedGraph->minTensor
maxValueTensor:cachedGraph->maxTensor
name:nil];
} else if (cachedGraph->maxTensor) {
cachedGraph->outputTensor = [mpsGraph minimumWithPrimaryTensor:cachedGraph->inputTensor
secondaryTensor:cachedGraph->maxTensor
name:nil];
} else if (cachedGraph->minTensor) {
cachedGraph->outputTensor = [mpsGraph maximumWithPrimaryTensor:cachedGraph->inputTensor
secondaryTensor:cachedGraph->minTensor
name:nil];
}
if (cachedGraph->minTensor && cachedGraph->maxTensor) {
cachedGraph->outputTensor = [mpsGraph clampWithTensor:cachedGraph->inputTensor
minValueTensor:cachedGraph->minTensor
maxValueTensor:cachedGraph->maxTensor
name:nil];
} else if (cachedGraph->maxTensor) {
cachedGraph->outputTensor = [mpsGraph minimumWithPrimaryTensor:cachedGraph->inputTensor
secondaryTensor:cachedGraph->maxTensor
name:nil];
} else if (cachedGraph->minTensor) {
cachedGraph->outputTensor = [mpsGraph maximumWithPrimaryTensor:cachedGraph->inputTensor
secondaryTensor:cachedGraph->minTensor
name:nil];
}
}

void check_min_max_dims(const OptionalTensorRef clamp_opt,
const Tensor& input_t,
string op_name) {
void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor& input_t, string op_name) {
if (!clamp_opt->is_same_size(input_t)) {
auto num_clamp_dims = clamp_opt->dim();
auto num_input_dims = input_t.dim();

if(!clamp_opt->is_same_size(input_t)) {
auto clamp_shape = clamp_opt->sizes();
auto input_shape = input_t.sizes();

auto num_clamp_dims = clamp_opt->dim();
auto num_input_dims = input_t.dim();
TORCH_CHECK(num_clamp_dims <= num_input_dims,
op_name + ": clamp tensor number of dims must not be greater than that of input tensor")

auto clamp_shape = clamp_opt->sizes();
auto input_shape = input_t.sizes();

TORCH_CHECK(num_clamp_dims <= num_input_dims, op_name + ": clamp tensor number of dims must not be greater than that of input tensor")

for(int i = 0; i < num_clamp_dims; i++)
// One of the indices is allowed to be 1; will be handled by broadcast
TORCH_CHECK(clamp_shape[num_clamp_dims-1-i] == input_shape[num_input_dims-1-i] ||
clamp_shape[num_clamp_dims-1-i] == 1 ||
input_shape[num_input_dims-1-i] == 1,
op_name + ": clamp tensor trailing shape must match input tensor")

}
for (int i = 0; i < num_clamp_dims; i++)
// One of the indices is allowed to be 1; will be handled by broadcast
TORCH_CHECK(clamp_shape[num_clamp_dims - 1 - i] == input_shape[num_input_dims - 1 - i] ||
clamp_shape[num_clamp_dims - 1 - i] == 1 || input_shape[num_input_dims - 1 - i] == 1,
op_name + ": clamp tensor trailing shape must match input tensor")
}
}

void fill_new_shape(int64_t num_input_dims,
int64_t num_clamp_dims,
int64_t *new_shape,
IntArrayRef clamp_shape) {

// Extend the shape with ones to the left
int clamp_idx = 0;
for(int i = 0; i < num_input_dims; i++) {
if(i < num_input_dims - num_clamp_dims)
new_shape[i] = 1;
else {
new_shape[i] = clamp_shape[clamp_idx];
clamp_idx++;
}
void fill_new_shape(int64_t num_input_dims, int64_t num_clamp_dims, int64_t* new_shape, IntArrayRef clamp_shape) {
// Extend the shape with ones to the left
int clamp_idx = 0;
for (int i = 0; i < num_input_dims; i++) {
if (i < num_input_dims - num_clamp_dims)
new_shape[i] = 1;
else {
new_shape[i] = clamp_shape[clamp_idx];
clamp_idx++;
}
}
}
void clamp_tensor_out_mps(const Tensor& input_t,
const OptionalTensorRef min_opt,
const OptionalTensorRef max_opt,
const Tensor& output_t,
string op_name)
{
const bool has_min = (min_opt.has_value() && min_opt->defined());
const bool has_max = (max_opt.has_value() && max_opt->defined());
string op_name) {
const bool has_min = (min_opt.has_value() && min_opt->defined());
const bool has_max = (max_opt.has_value() && max_opt->defined());

TORCH_CHECK(has_min || has_max, op_name + ": either min, max or both tensors must be defined")
if (has_min)
check_min_max_dims(min_opt, input_t, op_name);
TORCH_CHECK(has_min || has_max, op_name + ": either min, max or both tensors must be defined")
if (has_min)
check_min_max_dims(min_opt, input_t, op_name);

if (has_max)
check_min_max_dims(max_opt, input_t, op_name);
if (has_max)
check_min_max_dims(max_opt, input_t, op_name);

if (output_t.numel() == 0)
return;
if (output_t.numel() == 0)
return;

IntArrayRef new_min_shape;
IntArrayRef new_max_shape;
IntArrayRef new_min_shape;
IntArrayRef new_max_shape;

auto num_min_dims = min_opt->dim();
auto num_max_dims = max_opt->dim();
auto num_input_dims = input_t.dim();
auto num_min_dims = min_opt->dim();
auto num_max_dims = max_opt->dim();
auto num_input_dims = input_t.dim();

std::vector<int64_t> new_min_arr(num_input_dims);
std::vector<int64_t> new_max_arr(num_input_dims);
std::vector<int64_t> new_min_arr(num_input_dims);
std::vector<int64_t> new_max_arr(num_input_dims);

if(has_min && num_min_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes());
new_min_shape = IntArrayRef(new_min_arr);
}
if (has_min && num_min_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes());
new_min_shape = IntArrayRef(new_min_arr);
}

if(has_max && num_max_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes());
new_max_shape = IntArrayRef(new_max_arr);
}
if (has_max && num_max_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes());
new_max_shape = IntArrayRef(new_max_arr);
}

Tensor min_opt_tensor;
Tensor max_opt_tensor;
Tensor min_opt_tensor;
Tensor max_opt_tensor;

if(has_min) {
min_opt_tensor = (num_min_dims < num_input_dims) ? (*min_opt).view(new_min_shape) : *min_opt;
}
if(has_max) {
max_opt_tensor = (num_max_dims < num_input_dims) ? (*max_opt).view(new_max_shape) : *max_opt;
}
if (has_min) {
min_opt_tensor = (num_min_dims < num_input_dims) ? (*min_opt).view(new_min_shape) : *min_opt;
}
if (has_max) {
max_opt_tensor = (num_max_dims < num_input_dims) ? (*max_opt).view(new_max_shape) : *max_opt;
}

@autoreleasepool {
// the optional min/max refs could affect how we build the cached graph
@autoreleasepool {
// the optional min/max refs could affect how we build the cached graph

auto tensor_key = has_min ? (has_max ? getTensorsStringKey({input_t, min_opt_tensor, max_opt_tensor})
: getTensorsStringKey({input_t, min_opt_tensor}))
: (has_max ? getTensorsStringKey({input_t, max_opt_tensor})
: getTensorsStringKey({input_t}));
auto tensor_key = has_min
? (has_max ? getTensorsStringKey({input_t, min_opt_tensor, max_opt_tensor})
: getTensorsStringKey({input_t, min_opt_tensor}))
: (has_max ? getTensorsStringKey({input_t, max_opt_tensor}) : getTensorsStringKey({input_t}));

string key = op_name + (has_min ? "_min" : "") + (has_max ? "_max" : "")
+ "_tensor" + tensor_key;
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
string key = op_name + (has_min ? "_min" : "") + (has_max ? "_max" : "") + "_tensor" + tensor_key;
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));

if (!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;

@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

if (has_min)
newCachedGraph->minTensor = mpsGraphRankedPlaceHolder(mpsGraph, min_opt_tensor);
if (has_max)
newCachedGraph->maxTensor = mpsGraphRankedPlaceHolder(mpsGraph, max_opt_tensor);
if (has_min)
newCachedGraph->minTensor = mpsGraphRankedPlaceHolder(mpsGraph, min_opt_tensor);
if (has_max)
newCachedGraph->maxTensor = mpsGraphRankedPlaceHolder(mpsGraph, max_opt_tensor);

clamp_mps_graph(newCachedGraph, input_t);
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
clamp_mps_graph(newCachedGraph, input_t);
}

auto inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_t);
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t);

NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
if (has_min) {
auto minPlaceholder = Placeholder(cachedGraph->minTensor, min_opt_tensor);
feeds[minPlaceholder.getMPSGraphTensor()] = minPlaceholder.getMPSGraphTensorData();
}
if (has_max) {
auto maxPlaceholder = Placeholder(cachedGraph->maxTensor, max_opt_tensor);
feeds[maxPlaceholder.getMPSGraphTensor()] = maxPlaceholder.getMPSGraphTensorData();
}

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};

runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}

auto inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_t);
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t);

NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
if (has_min) {
auto minPlaceholder = Placeholder(cachedGraph->minTensor, min_opt_tensor);
feeds[minPlaceholder.getMPSGraphTensor()] = minPlaceholder.getMPSGraphTensorData();
}
if (has_max) {
auto maxPlaceholder = Placeholder(cachedGraph->maxTensor, max_opt_tensor);
feeds[maxPlaceholder.getMPSGraphTensor()] = maxPlaceholder.getMPSGraphTensorData();
}

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
}
void clamp_scalar_out_mps(const Tensor& input_t,
const OptionalScalarRef min_opt,
const OptionalScalarRef max_opt,
const Tensor& output_t,
string op_name)
{
using scalar_t = double;
const OptionalScalarRef min_opt,
const OptionalScalarRef max_opt,
const Tensor& output_t,
string op_name) {
using scalar_t = double;

const bool has_min = (min_opt.has_value());
const bool has_max = (max_opt.has_value());
TORCH_CHECK(has_min || has_max, op_name + ": either min, max or both scalars must be defined")
const bool has_min = (min_opt.has_value());
const bool has_max = (max_opt.has_value());
TORCH_CHECK(has_min || has_max, op_name + ": either min, max or both scalars must be defined")

scalar_t min_scalar = std::numeric_limits<scalar_t>::infinity();
scalar_t max_scalar = -std::numeric_limits<scalar_t>::infinity();
scalar_t min_scalar = std::numeric_limits<scalar_t>::infinity();
scalar_t max_scalar = -std::numeric_limits<scalar_t>::infinity();

if (has_min)
min_scalar = min_opt.get().to<scalar_t>();
if (has_max)
max_scalar = max_opt.get().to<scalar_t>();
if (has_min)
min_scalar = min_opt.get().to<scalar_t>();
if (has_max)
max_scalar = max_opt.get().to<scalar_t>();

if (output_t.numel() == 0)
return ;
if (output_t.numel() == 0)
return;

@autoreleasepool {
// the optional min/max refs could affect how we build the cached graph
string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") + (has_max ? ("_max:" + to_string(max_scalar)) : "")
+ "_scalar:" + getTensorsStringKey({input_t});
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
@autoreleasepool {
// the optional min/max refs could affect how we build the cached graph
string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") +
(has_max ? ("_max:" + to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));

if (!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;

@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

if (has_min)
newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar
shape:(mps::getMPSShape(input_t))
dataType:(mps::getMPSScalarType(input_t.scalar_type())) ];
if (has_max)
newCachedGraph->maxTensor = [mpsGraph constantWithScalar:max_scalar
shape:(mps::getMPSShape(input_t))
dataType:(mps::getMPSScalarType(input_t.scalar_type())) ];
if (has_min)
newCachedGraph->minTensor = [mpsGraph
constantWithScalar:min_scalar
shape:(mps::getMPSShape(input_t))dataType:(mps::getMPSScalarType(input_t.scalar_type()))];
if (has_max)
newCachedGraph->maxTensor = [mpsGraph
constantWithScalar:max_scalar
shape:(mps::getMPSShape(input_t))dataType:(mps::getMPSScalarType(input_t.scalar_type()))];

clamp_mps_graph(newCachedGraph, input_t);
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
clamp_mps_graph(newCachedGraph, input_t);
}

auto inputPlaceholder = Placeholder(cachedGraph->inputTensor , input_t);
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t);

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
};
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};

runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}

auto inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_t);
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t);

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
}

} // namespace mps

// APIs exposed to at::native scope
TORCH_IMPL_FUNC(clamp_Tensor_out_mps)
(const Tensor& input_t, const OptionalTensorRef min, const OptionalTensorRef max, const Tensor& output_t)
{
mps::clamp_tensor_out_mps(input_t, min, max, output_t, __func__);
(const Tensor& input_t, const OptionalTensorRef min, const OptionalTensorRef max, const Tensor& output_t) {
mps::clamp_tensor_out_mps(input_t, min, max, output_t, __func__);
}

TORCH_IMPL_FUNC(clamp_out_mps)
(const Tensor& input_t, const OptionalScalarRef min, const OptionalScalarRef max, const Tensor& output_t)
{
mps::clamp_scalar_out_mps(input_t, min, max, const_cast<Tensor&>(output_t), "clamp_out_mps");
(const Tensor& input_t, const OptionalScalarRef min, const OptionalScalarRef max, const Tensor& output_t) {
mps::clamp_scalar_out_mps(input_t, min, max, const_cast<Tensor&>(output_t), "clamp_out_mps");
}

TORCH_IMPL_FUNC(clamp_min_Tensor_out_mps)
(const Tensor& input_t, const Tensor& min, const Tensor& output_t)
{
mps::clamp_tensor_out_mps(input_t, min, at::OptionalTensorRef(), output_t, __func__);
(const Tensor& input_t, const Tensor& min, const Tensor& output_t) {
mps::clamp_tensor_out_mps(input_t, min, at::OptionalTensorRef(), output_t, __func__);
}

TORCH_IMPL_FUNC(clamp_min_out_mps)
(const Tensor& input_t, const Scalar& min, const Tensor& output_t)
{
mps::clamp_scalar_out_mps(input_t, min, at::OptionalScalarRef(), output_t, __func__);
(const Tensor& input_t, const Scalar& min, const Tensor& output_t) {
mps::clamp_scalar_out_mps(input_t, min, at::OptionalScalarRef(), output_t, __func__);
}

TORCH_IMPL_FUNC(clamp_max_Tensor_out_mps)
(const Tensor& input_t, const Tensor& max, const Tensor& output_t)
{
mps::clamp_tensor_out_mps(input_t, at::OptionalTensorRef(), max, output_t, __func__);
(const Tensor& input_t, const Tensor& max, const Tensor& output_t) {
mps::clamp_tensor_out_mps(input_t, at::OptionalTensorRef(), max, output_t, __func__);
}

TORCH_IMPL_FUNC(clamp_max_out_mps)
(const Tensor& input_t, const Scalar& max, const Tensor& output_t)
{
mps::clamp_scalar_out_mps(input_t, at::OptionalScalarRef(), max, output_t, __func__);
(const Tensor& input_t, const Scalar& max, const Tensor& output_t) {
mps::clamp_scalar_out_mps(input_t, at::OptionalScalarRef(), max, output_t, __func__);
}
Tensor& where_self_out_mps(const Tensor& condition,
|
||||
const Tensor& self,
|
||||
const Tensor& other,
|
||||
Tensor& out) {
|
||||
Tensor& where_self_out_mps(const Tensor& condition, const Tensor& self, const Tensor& other, Tensor& out) {
|
||||
TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
|
||||
|
||||
if (condition.scalar_type() == ScalarType::Byte) {
|
||||
TORCH_WARN_ONCE("where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.");
|
||||
TORCH_WARN_ONCE(
|
||||
"where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.");
|
||||
} else {
|
||||
TORCH_CHECK(condition.scalar_type() == ScalarType::Bool, "where expected condition to be a boolean tensor, but got a tensor with dtype ", condition.scalar_type());
|
||||
TORCH_CHECK(condition.scalar_type() == ScalarType::Bool,
|
||||
"where expected condition to be a boolean tensor, but got a tensor with dtype ",
|
||||
condition.scalar_type());
|
||||
}
|
||||
Tensor cond_bool = condition.scalar_type() == ScalarType::Byte ? condition.to(ScalarType::Bool) : condition;
|
||||
|
||||
@ -305,13 +283,12 @@ Tensor& where_self_out_mps(const Tensor& condition,
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
// Empty output
|
||||
if(out.numel() == 0)
|
||||
if (out.numel() == 0)
|
||||
return out;
|
||||
|
||||
// Derive from MPSCachedGraph
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* conditionTensor_ = nil;
|
||||
MPSGraphTensor* selfTensor_ = nil;
|
||||
MPSGraphTensor* otherTensor_ = nil;
|
||||
@ -326,57 +303,56 @@ Tensor& where_self_out_mps(const Tensor& condition,
|
||||
// Workaround for `selectWithPredicateTensor` on macOS Monterey where bool data type may cause a hang
|
||||
// The issue is fixed in macOS Ventura (13.0)
|
||||
if (!is_macos_13_or_newer()) {
|
||||
if (condition.scalar_type() == kBool) {
|
||||
if (condition.scalar_type() == kBool) {
|
||||
conditionDataType = MPSDataTypeInt8;
|
||||
}
|
||||
if (self.scalar_type() == kBool) {
|
||||
}
|
||||
if (self.scalar_type() == kBool) {
|
||||
selfDataType = MPSDataTypeInt8;
|
||||
}
|
||||
if (other.scalar_type() == kBool) {
|
||||
}
|
||||
if (other.scalar_type() == kBool) {
|
||||
otherDataType = MPSDataTypeInt8;
}
}
}

@autoreleasepool {
string key = "where_self_out_mps:" + getTensorsStringKey({cond_bool, self, other});
[cache lookup and graph construction for where_self_out_mps — condition/self/other placeholders feeding a selectWithPredicateTensor: call — plus the Placeholder setup, reflowed by clang-format: pointer stars attached to the type, space after `if`, ObjC message arguments aligned after the open bracket within the 120-column limit]
@ -384,21 +360,16 @@ Tensor& where_self_out_mps(const Tensor& condition,
[feeds/results dictionaries and the runMPSGraph(stream, cachedGraph->graph(), feeds, results) call reflowed]
}

return out;
}
Tensor where_mps(const Tensor& condition, const Tensor& self, const Tensor& other) {
auto max_dim = std::max(condition.dim(), std::max(self.dim(), other.dim()));

// How many leading dimensions do we broadcast across for each Tensor?
@ -409,8 +380,7 @@ Tensor where_mps(const Tensor& condition,
std::vector<int64_t> out_arr(max_dim);

// Broadcasted output shape
for (int i = 0; i < max_dim; i++) {
// Use up the leading broadcast dimensions for each Tensor, then continue from the start of the "actual" shape
int64_t cond_idx = i < cond_num_implicit_ones ? 1 : (condition.size(i - cond_num_implicit_ones));
int64_t self_idx = i < self_num_implicit_ones ? 1 : (self.size(i - self_num_implicit_ones));
@ -418,21 +388,28 @@ Tensor where_mps(const Tensor& condition,
auto max_idx = std::max({cond_idx, self_idx, other_idx});
[the three TORCH_CHECK broadcast-compatibility checks on cond_idx / self_idx / other_idx re-wrapped one argument per line by clang-format]
out_arr[i] = (cond_idx == 0 || self_idx == 0 || other_idx == 0) ? 0 : max_idx;
}

Tensor ret = empty_mps(
IntArrayRef(out_arr), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, self.suggest_memory_format());
return where_self_out_mps(condition, self, other, ret);
}
Tensor& nan_to_num_out_mps(const Tensor& self,
@ -440,8 +417,11 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
[TORCH_CHECK that the dtype of out matches the input dtype, re-wrapped one argument per line; early return when result.numel() == 0]
@ -452,7 +432,7 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
[CachedGraph struct definition reformatted onto a single brace style]
@ -467,25 +447,27 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
[graph construction: nan/posInf/negInf replacement placeholders combined through isNaNWithTensor:, lessThanWithPrimaryTensor:, isInfiniteWithTensor: and selectWithPredicateTensor:, keeping the Monterey workaround that casts the lessThan()/isInfinite() outputs to Boolean; ObjC message arguments re-aligned after the open bracket]
@ -493,34 +475,33 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
[logicalAND/select chain producing the output tensor, followed by the AT_DISPATCH_FLOATING_TYPES_AND(kHalf, ...) block that builds the nan / pos_inf / neg_inf replacement scalars (defaults: 0, numeric max, numeric lowest), reflowed onto single conditional expressions]
@ -528,14 +509,13 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
[feeds/results dictionaries and the runMPSGraph call reflowed; the function returns result]
}

@ -14,10 +14,7 @@
namespace at::native {
TORCH_IMPL_FUNC(triu_mps_out)
(const Tensor& self, int64_t k, const Tensor& output) {
using namespace mps;

if (self.numel() == 0) {
@ -26,22 +23,21 @@ TORCH_IMPL_FUNC(triu_mps_out)
[CachedGraph struct and cache lookup for the "triu_mps_out" key reformatted]
@ -50,12 +46,10 @@ TORCH_IMPL_FUNC(triu_mps_out)
[graph construction: for k > 0 the upper triangle is obtained by subtracting a bandPartWithTensor: complement (numUpper = k - 1) from the input; otherwise bandPartWithTensor: with numLower = -k is used directly; the constant and band-part calls are collapsed onto single lines]
@ -63,10 +57,8 @@ TORCH_IMPL_FUNC(triu_mps_out)
@ -78,29 +70,23 @@ TORCH_IMPL_FUNC(triu_mps_out)
[placeholder setup, feeds/results dictionaries and runMPSGraph call reflowed]
}
}
TORCH_IMPL_FUNC(tril_mps_out)
(const Tensor& self, int64_t k, const Tensor& output) {
using namespace mps;

if (self.numel() == 0) {
@ -109,22 +95,21 @@ TORCH_IMPL_FUNC(tril_mps_out)
[CachedGraph struct and cache lookup for the "tril_mps_out" key reformatted]
@ -133,20 +118,16 @@ TORCH_IMPL_FUNC(tril_mps_out)
[graph construction: for k >= 0 bandPartWithTensor: keeps the lower triangle directly (numUpper = k); for k < 0 the result is the input minus a bandPartWithTensor: complement (numLower = -k - 1)]
@ -161,22 +142,19 @@ TORCH_IMPL_FUNC(tril_mps_out)
[placeholder setup, feeds/results dictionaries and runMPSGraph call reflowed]
}
}

} // namespace at::native
@ -1,7 +1,7 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {
namespace mps {
@ -9,14 +9,16 @@ namespace mps {
typedef MPSGraphTensor* (^UnaryOpBlock)(MPSGraph*, MPSGraphTensor*);
using is_noop_p = std::function<bool(const Tensor&)>;

bool is_empty_tensor(const Tensor& self) {
return self.numel() == 0;
}

[unary_op(self, output, op_name, unaryBlock, is_noop): signature re-wrapped one parameter per line; TORCH_CHECK that uint8 unary ops need macOS 13+, output resize, cache lookup, graph construction and the feeds/results dictionaries all reflowed by clang-format]
@ -30,9 +32,9 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
@ -55,18 +57,15 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una

MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
// Rounding is a no-op for integral types, and also a reasonable workaround
// For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library`
// See https://github.com/pytorch/pytorch/issues/84995
@ -75,100 +74,91 @@ MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor)
[trunc_tensor: on pre-macOS 13 systems truncation is emulated with a ceil/floor select on the sign of the input, otherwise truncateWithTensor: is used; the log1p helper computes log(x + 1); both are collapsed onto single-line message sends]

} // namespace mps

[TORCH_IMPL_FUNC(trunc_out_mps), (signbit_out_mps) and (sign_out_mps): block literals rewritten in the ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {...} style; signbit falls back to a lessThan-zero comparison for int64 (workaround for `Function signbitOp_i64 was not found in the library`), and sign is simulated with clampWithTensor: for int64]

[CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC, CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC and CREATE_MPS_UNARY_TORCH_IMPL_FUNC macros re-indented with aligned continuation backslashes; the ceil/floor/round and exp/exp2 instantiations are unchanged]
@ -195,139 +185,104 @@ CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(atanh_out_mps, atanh)

CREATE_MPS_UNARY_TORCH_IMPL_FUNC(abs_out_mps, absolute)

[logical_not_out_mps, sigmoid_out_mps, log1p_out_mps, frac_out_mps and expm1_out_mps: the int64 TORCH_CHECKs are kept and the unary_op blocks are collapsed onto the new single-argument-list style]
void logit_mps_impl(const Tensor& self, c10::optional<double> eps, Tensor& output, const std::string op_name) {
std::string key = op_name + ":[" + (eps.has_value() ? std::to_string(eps.value()) : "NULL") + "]";

[unary_op block: the input is optionally clamped to [eps, 1 - eps], then log(x / (1 - x)) is computed via subtraction, division and logarithm tensors; message sends reflowed to the new alignment]
}

Tensor& logit_out_mps(const Tensor& self, c10::optional<double> eps, Tensor& result) {
logit_mps_impl(self, eps, result, "logit_out_mps");
return result;
}

Tensor logit_mps(const Tensor& self, c10::optional<double> eps) {
Tensor result =
at::native::empty_mps(self.sizes(), ScalarType::Float, c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
logit_mps_impl(self, eps, result, "logit_mps");
return result;
}

TORCH_IMPL_FUNC(logit_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, c10::optional<double> eps, const Tensor& grad_input) {
using namespace mps;

// Empty output
if (grad_input.numel() == 0)
return;

double eps_ = eps ? eps.value() : -1.0;

[CachedGraph struct definition reformatted]
@ -335,14 +290,13 @@ TORCH_IMPL_FUNC(logit_backward_out_mps) (
[cache key built from the tensors and eps; cache lookup reformatted]
@ -351,40 +305,32 @@ TORCH_IMPL_FUNC(logit_backward_out_mps) (
[graph construction: grad_input = grad_output / (x * (1 - x)), zeroed outside the [eps, 1 - eps] interval via lessThan/greaterThan/logicalOR/select tensors; constants and message sends re-aligned]
@ -392,7 +338,7 @@ TORCH_IMPL_FUNC(logit_backward_out_mps) (
@ -403,25 +349,25 @@ TORCH_IMPL_FUNC(logit_backward_out_mps) (
[placeholders, feeds/results dictionaries and runMPSGraph call reflowed]
}
}
TORCH_IMPL_FUNC(cumsum_out_mps)
(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, const Tensor& result) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
auto nDims = self.dim();
auto wrapped_dim = maybe_wrap_dim(dim, nDims);
[TORCH_CHECK on the wrapped dim re-wrapped one argument per line; on pre-macOS 13 systems the op warns once ("torch.cumsum supported by MPS on MacOS 13+, please upgrade") and falls back to computing cumsum on the CPU]
@ -430,29 +376,27 @@ TORCH_IMPL_FUNC(cumsum_out_mps)
auto input = dtype.has_value() ? self.to(dtype.value()) : self;

// issue #103810551: cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to
// int32 fixed in macOS 13.3
bool castInputData = (isIntegralType(input.scalar_type()) && input.scalar_type() != ScalarType::Int &&
input.scalar_type() != ScalarType::Long);

TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long,
"MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3");

[unary_op call: the input is optionally cast to Int32, run through cumulativeSumWithTensor:axis:name:, and cast back to the result dtype when needed; the block is reflowed to the new multi-line argument style]
}

} // namespace at::native
@ -1,15 +1,14 @@
|
||||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/mps/MPSGraphVenturaOps.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/mps/MPSGraphVenturaOps.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
struct UniqueCachedGraph : public MPSCachedGraph
|
||||
{
|
||||
UniqueCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
struct UniqueCachedGraph : public MPSCachedGraph {
|
||||
UniqueCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
MPSGraphTensor* inverseIndicesTensor_ = nil;
|
||||
@ -17,230 +16,201 @@ struct UniqueCachedGraph : public MPSCachedGraph
|
||||
MPSGraphTensor* lengthTensor_ = nil;
|
||||
};
|
||||
|
||||
static std::string getUniqueKey(const ScalarType& dtype, const IntArrayRef& base_shape,
|
||||
const bool return_inverse, const bool return_counts,
|
||||
const bool consecutive, c10::optional<int64_t> dimOpt)
|
||||
{
|
||||
return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) +
|
||||
"]:[" + (dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) +
|
||||
"]:[" + to_string(return_counts) + "]:[" + to_string(consecutive) + "]";
|
||||
static std::string getUniqueKey(const ScalarType& dtype,
|
||||
const IntArrayRef& base_shape,
|
||||
const bool return_inverse,
|
||||
const bool return_counts,
|
||||
const bool consecutive,
|
||||
c10::optional<int64_t> dimOpt) {
|
||||
return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) + "]:[" +
|
||||
(dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) + "]:[" +
|
||||
to_string(return_counts) + "]:[" + to_string(consecutive) + "]";
|
||||
}
|
||||
|
||||
// dim arg not supported when non consecutive, ie sorted
|
||||
std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self, UniqueCachedGraph *uniqueGraph, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional<int64_t> dimOpt) {
|
||||
std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self,
|
||||
UniqueCachedGraph* uniqueGraph,
|
||||
const bool return_inverse,
|
||||
const bool return_counts,
|
||||
const bool consecutive,
|
||||
c10::optional<int64_t> dimOpt) {
|
||||
int64_t dim = dimOpt.has_value() ? maybe_wrap_dim(dimOpt.value(), self.dim()) : 0;
|
||||
|
||||
MPSGraph *graph = uniqueGraph->graph();
|
||||
MPSGraphTensor *inputTensor = uniqueGraph->inputTensor_;
|
||||
MPSShape *shape = [inputTensor shape];
|
||||
MPSShape *destShape = shape;
|
||||
MPSGraph* graph = uniqueGraph->graph();
|
||||
MPSGraphTensor* inputTensor = uniqueGraph->inputTensor_;
|
||||
MPSShape* shape = [inputTensor shape];
|
||||
MPSShape* destShape = shape;
|
||||
uint64_t length = [shape[dim] unsignedIntValue];
|
||||
MPSDataType dataType = [inputTensor dataType];
|
||||
|
||||
MPSGraphTensor *resultTensor = nil;
|
||||
MPSGraphTensor *inverseIndicesTensor = nil;
|
||||
MPSGraphTensor *countTensor = nil;
|
||||
MPSGraphTensor *lengthTensor = nil;
|
||||
MPSGraphTensor* resultTensor = nil;
|
||||
MPSGraphTensor* inverseIndicesTensor = nil;
|
||||
MPSGraphTensor* countTensor = nil;
|
||||
MPSGraphTensor* lengthTensor = nil;
|
||||
if (length <= 1) {
|
||||
// Trivial case, only 1 element everything is unique
|
||||
resultTensor = inputTensor;
|
||||
lengthTensor = [graph constantWithScalar:0.0f
|
||||
dataType:MPSDataTypeInt32];
|
||||
lengthTensor = [graph constantWithScalar:0.0f dataType:MPSDataTypeInt32];
|
||||
if (return_inverse) {
|
||||
inverseIndicesTensor = [graph constantWithScalar:0.0f
|
||||
dataType:MPSDataTypeInt32];
|
||||
inverseIndicesTensor = [graph constantWithScalar:0.0f dataType:MPSDataTypeInt32];
|
||||
}
|
||||
if (return_counts) {
|
||||
countTensor = [graph constantWithScalar:1.0f
|
||||
dataType:MPSDataTypeInt32];
|
||||
countTensor = [graph constantWithScalar:1.0f dataType:MPSDataTypeInt32];
|
||||
}
|
||||
return {resultTensor, inverseIndicesTensor, countTensor, lengthTensor};
|
||||
}
|
||||
|
||||
// #issue 104398441 sortWithTensor only supports following types, cast if necessary
|
||||
if (dataType != MPSDataTypeInt32 &&
|
||||
dataType != MPSDataTypeFloat32 &&
|
||||
dataType != MPSDataTypeFloat16) {
|
||||
if (dataType != MPSDataTypeInt32 && dataType != MPSDataTypeFloat32 && dataType != MPSDataTypeFloat16) {
|
||||
dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
|
||||
inputTensor = [graph castTensor:inputTensor
|
||||
toType:dataType
|
||||
name:@"castInputTensor"];
|
||||
inputTensor = [graph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
|
||||
}
|
||||
|
||||
bool needsFlatten = !(dimOpt.has_value() || [shape count] == 1);
|
||||
if (needsFlatten) {
|
||||
inputTensor = [graph reshapeTensor:inputTensor
|
||||
withShape:@[@-1]
|
||||
name:nil];
|
||||
inputTensor = [graph reshapeTensor:inputTensor withShape:@[ @-1 ] name:nil];
|
||||
length = 1;
|
||||
for (const auto i: c10::irange([shape count])) {
|
||||
for (const auto i : c10::irange([shape count])) {
|
||||
if (c10::mul_overflows(length, [shape[i] unsignedIntValue], &length)) {
|
||||
TORCH_CHECK(false, "RuntimeError: Tensor size overflow");
|
||||
}
|
||||
}
|
||||
|
||||
destShape = @[[NSNumber numberWithUnsignedInteger:length]];
|
||||
destShape = @[ [NSNumber numberWithUnsignedInteger:length] ];
|
||||
}
|
||||
|
||||
MPSGraphTensor *sortedInput = nil;
|
||||
MPSGraphTensor* sortedInput = nil;
|
||||
if (consecutive) {
|
||||
sortedInput = inputTensor;
|
||||
} else {
|
||||
sortedInput = [graph sortWithTensor:inputTensor
|
||||
axis:0
|
||||
name:nil];
|
||||
sortedInput = [graph sortWithTensor:inputTensor axis:0 name:nil];
|
||||
}
|
||||
|
||||
MPSGraphTensor *frontNMinusOne = [graph sliceTensor:sortedInput
dimension:dim
start:0
length:length-1
name:nil];
MPSGraphTensor *backNMinusOne = [graph sliceTensor:sortedInput
dimension:dim
start:1
length:length-1
name:nil];
MPSGraphTensor *notEqualToPreviousElement = [graph notEqualWithPrimaryTensor:backNMinusOne
MPSGraphTensor* frontNMinusOne = [graph sliceTensor:sortedInput dimension:dim start:0 length:length - 1 name:nil];
MPSGraphTensor* backNMinusOne = [graph sliceTensor:sortedInput dimension:dim start:1 length:length - 1 name:nil];
MPSGraphTensor* notEqualToPreviousElement = [graph notEqualWithPrimaryTensor:backNMinusOne
secondaryTensor:frontNMinusOne
name:nil];
MPSGraphTensor *mask = [graph castTensor:notEqualToPreviousElement
toType:MPSDataTypeInt32
name:@"castMaskTensor"];
MPSGraphTensor* mask = [graph castTensor:notEqualToPreviousElement toType:MPSDataTypeInt32 name:@"castMaskTensor"];

// If comparing tensors, not scalars, check if entire tensor matches previous element using reductionOr over tensor
if (dimOpt.has_value() && [shape count] != 1) {
NSMutableArray *axes = [[NSMutableArray alloc] initWithCapacity:[shape count]-1];
NSMutableArray* axes = [[NSMutableArray alloc] initWithCapacity:[shape count] - 1];
for (const auto axis : c10::irange([shape count])) {
if (axis != dim) {
[axes addObject:[NSNumber numberWithUnsignedInteger:axis]];
}
}
mask = [graph reductionOrWithTensor:mask
axes:axes
name:nil];
mask = [graph squeezeTensor:mask
axes:axes
name:nil];
mask = [graph reductionOrWithTensor:mask axes:axes name:nil];
mask = [graph squeezeTensor:mask axes:axes name:nil];
[axes release];
}

MPSGraphTensor *scannedIndices = [graph cumulativeSumWithTensor:mask
axis:0
name:nil];
lengthTensor = [graph sliceTensor:scannedIndices
dimension:0
start:length-2
length:1
name:nil];
MPSGraphTensor* scannedIndices = [graph cumulativeSumWithTensor:mask axis:0 name:nil];
lengthTensor = [graph sliceTensor:scannedIndices dimension:0 start:length - 2 length:1 name:nil];
|
||||
|
||||
MPSGraphTensor *minusOneTensor = [graph constantWithScalar:-1.0f
|
||||
dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor *maskedIndices = [graph selectWithPredicateTensor:mask
|
||||
MPSGraphTensor* minusOneTensor = [graph constantWithScalar:-1.0f dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* maskedIndices = [graph selectWithPredicateTensor:mask
|
||||
truePredicateTensor:scannedIndices
|
||||
falsePredicateTensor:minusOneTensor
|
||||
name:nil];
|
||||
|
||||
MPSGraphTensor *zeroTensor = [graph constantWithScalar:0.0f
|
||||
shape:@[@1]
|
||||
dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor *maskedIndicesWithHead = [graph concatTensors:@[zeroTensor, maskedIndices]
|
||||
dimension:0
|
||||
name:nil];
|
||||
MPSGraphTensor *scannedIndicesWithHead = [graph concatTensors:@[zeroTensor, scannedIndices]
|
||||
dimension:0
|
||||
name:nil];
|
||||
MPSGraphTensor* zeroTensor = [graph constantWithScalar:0.0f shape:@[ @1 ] dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* maskedIndicesWithHead = [graph concatTensors:@[ zeroTensor, maskedIndices ] dimension:0 name:nil];
|
||||
MPSGraphTensor* scannedIndicesWithHead = [graph concatTensors:@[ zeroTensor, scannedIndices ] dimension:0 name:nil];
|
||||
|
||||
resultTensor = [graph scatterWithUpdatesTensor:sortedInput
|
||||
indicesTensor:maskedIndicesWithHead
|
||||
shape:destShape
|
||||
axis:dim
|
||||
mode:MPSGraphScatterModeSet
|
||||
name:nil];
|
||||
indicesTensor:maskedIndicesWithHead
|
||||
shape:destShape
|
||||
axis:dim
|
||||
mode:MPSGraphScatterModeSet
|
||||
name:nil];
|
||||
// Cast back if necessary
|
||||
if ([uniqueGraph->inputTensor_ dataType] != dataType) {
|
||||
resultTensor = [graph castTensor:resultTensor
|
||||
toType:[uniqueGraph->inputTensor_ dataType]
|
||||
name:@"castResultTensor"];
|
||||
resultTensor = [graph castTensor:resultTensor toType:[uniqueGraph->inputTensor_ dataType] name:@"castResultTensor"];
|
||||
}
|
||||
|
||||
// Compute optional returned tensors if requested
|
||||
if(return_inverse) {
|
||||
MPSGraphTensor *argSortedInput = nil;
|
||||
if (return_inverse) {
|
||||
MPSGraphTensor* argSortedInput = nil;
|
||||
if (consecutive)
|
||||
argSortedInput = [graph coordinateAlongAxis:0
|
||||
withShape:@[[NSNumber numberWithUnsignedInteger:length]]
|
||||
withShape:@[ [NSNumber numberWithUnsignedInteger:length] ]
|
||||
name:nil];
|
||||
else
|
||||
argSortedInput = [graph argSortWithTensor:inputTensor
|
||||
axis:0
|
||||
name:nil];
|
||||
argSortedInput = [graph argSortWithTensor:inputTensor axis:0 name:nil];
|
||||
inverseIndicesTensor = [graph scatterWithUpdatesTensor:scannedIndicesWithHead
|
||||
indicesTensor:argSortedInput
|
||||
shape:@[[NSNumber numberWithUnsignedInteger:length]]
|
||||
axis:0
|
||||
mode:MPSGraphScatterModeAdd
|
||||
name:nil];
|
||||
if (needsFlatten)
|
||||
inverseIndicesTensor = [graph reshapeTensor:inverseIndicesTensor
|
||||
withShape:shape
|
||||
name:nil];
|
||||
}
|
||||
|
||||
if (return_counts) {
|
||||
MPSGraphTensor *unitTensor = [graph constantWithScalar:1.0f
|
||||
shape:@[[NSNumber numberWithUnsignedInteger:length]]
|
||||
dataType:MPSDataTypeInt32];
|
||||
countTensor = [graph scatterWithUpdatesTensor:unitTensor
|
||||
indicesTensor:scannedIndicesWithHead
|
||||
shape:@[[NSNumber numberWithUnsignedInteger:length]]
|
||||
indicesTensor:argSortedInput
|
||||
shape:@[ [NSNumber numberWithUnsignedInteger:length] ]
|
||||
axis:0
|
||||
mode:MPSGraphScatterModeAdd
|
||||
name:nil];
|
||||
if (needsFlatten)
|
||||
inverseIndicesTensor = [graph reshapeTensor:inverseIndicesTensor withShape:shape name:nil];
|
||||
}
|
||||
|
||||
if (return_counts) {
|
||||
MPSGraphTensor* unitTensor = [graph constantWithScalar:1.0f
|
||||
shape:@[ [NSNumber numberWithUnsignedInteger:length] ]
|
||||
dataType:MPSDataTypeInt32];
|
||||
countTensor = [graph scatterWithUpdatesTensor:unitTensor
|
||||
indicesTensor:scannedIndicesWithHead
|
||||
shape:@[ [NSNumber numberWithUnsignedInteger:length] ]
|
||||
axis:0
|
||||
mode:MPSGraphScatterModeAdd
|
||||
name:nil];
|
||||
}
|
||||
|
||||
return {resultTensor, inverseIndicesTensor, countTensor, lengthTensor};
|
||||
}
|
||||
|
||||
static UniqueCachedGraph* getUniqueGraph(const Tensor& self, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional<int64_t> dim) {
|
||||
static UniqueCachedGraph* getUniqueGraph(const Tensor& self,
|
||||
const bool return_inverse,
|
||||
const bool return_counts,
|
||||
const bool consecutive,
|
||||
c10::optional<int64_t> dim) {
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = getUniqueKey(self.scalar_type(), self.sizes(), return_inverse, return_counts, consecutive, dim);
|
||||
UniqueCachedGraph* cachedGraph = static_cast<UniqueCachedGraph *>(cache_->LookUp(key));
|
||||
if(!cachedGraph) {
|
||||
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
|
||||
UniqueCachedGraph* cachedGraph = static_cast<UniqueCachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
UniqueCachedGraph* newCachedGraph = nil;
|
||||
|
||||
UniqueCachedGraph *newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
// Initialize graph
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new UniqueCachedGraph(mpsGraph);
|
||||
|
||||
@autoreleasepool {
|
||||
// Initialize graph
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new UniqueCachedGraph(mpsGraph);
|
||||
// Workaround for MPSShaderLibrary bug
|
||||
// TODO: Remove once https://github.com/pytorch/pytorch/issues/82305 is resolved
|
||||
auto inputType = getMPSScalarType(self.scalar_type());
|
||||
newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(self.sizes()));
|
||||
|
||||
// Workaround for MPSShaderLibrary bug
|
||||
// TODO: Remove once https://github.com/pytorch/pytorch/issues/82305 is resolved
|
||||
auto inputType = getMPSScalarType(self.scalar_type());
|
||||
newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(self.sizes()));
|
||||
auto outputTensors = buildUniqueGraph(self, newCachedGraph, return_inverse, return_counts, consecutive, dim);
|
||||
|
||||
auto outputTensors = buildUniqueGraph(self, newCachedGraph, return_inverse, return_counts, consecutive, dim);
|
||||
|
||||
newCachedGraph->outputTensor_ = outputTensors[0];
|
||||
newCachedGraph->inverseIndicesTensor_ = outputTensors[1];
|
||||
newCachedGraph->countsTensor_ = outputTensors[2];
|
||||
newCachedGraph->lengthTensor_ = outputTensors[3];
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
cachedGraph = static_cast<UniqueCachedGraph *>(tmpCachedGraph);
|
||||
}
|
||||
newCachedGraph->outputTensor_ = outputTensors[0];
|
||||
newCachedGraph->inverseIndicesTensor_ = outputTensors[1];
|
||||
newCachedGraph->countsTensor_ = outputTensors[2];
|
||||
newCachedGraph->lengthTensor_ = outputTensors[3];
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
cachedGraph = static_cast<UniqueCachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
return cachedGraph;
|
||||
}
|
||||
}
|
||||
|
||||
void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor& output,
|
||||
Tensor& inverse_indices, Tensor& counts, Tensor& length,
|
||||
bool return_inverse, bool return_counts){
|
||||
void runUniqueGraph(UniqueCachedGraph* uniqueGraph,
|
||||
const Tensor& input,
|
||||
Tensor& output,
|
||||
Tensor& inverse_indices,
|
||||
Tensor& counts,
|
||||
Tensor& length,
|
||||
bool return_inverse,
|
||||
bool return_counts) {
|
||||
Placeholder inputPlaceholder = Placeholder(uniqueGraph->inputTensor_, input);
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
|
||||
@ -249,10 +219,8 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor&
|
||||
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = [NSMutableDictionary dictionary];
|
||||
Placeholder outputPlaceholder = Placeholder(uniqueGraph->outputTensor_, output);
|
||||
Placeholder lengthPlaceholder = Placeholder(uniqueGraph->lengthTensor_, length);
|
||||
[results setObject:outputPlaceholder.getMPSGraphTensorData()
|
||||
forKey:outputPlaceholder.getMPSGraphTensor()];
|
||||
[results setObject:lengthPlaceholder.getMPSGraphTensorData()
|
||||
forKey:lengthPlaceholder.getMPSGraphTensor()];
|
||||
[results setObject:outputPlaceholder.getMPSGraphTensorData() forKey:outputPlaceholder.getMPSGraphTensor()];
|
||||
[results setObject:lengthPlaceholder.getMPSGraphTensorData() forKey:lengthPlaceholder.getMPSGraphTensor()];
|
||||
if (return_inverse) {
|
||||
Placeholder inverseIndicesPlaceholder = Placeholder(uniqueGraph->inverseIndicesTensor_, inverse_indices);
|
||||
[results setObject:inverseIndicesPlaceholder.getMPSGraphTensorData()
|
||||
@ -260,8 +228,7 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor&
|
||||
}
|
||||
if (return_counts) {
|
||||
Placeholder countsPlaceholder = Placeholder(uniqueGraph->countsTensor_, counts);
|
||||
[results setObject:countsPlaceholder.getMPSGraphTensorData()
|
||||
forKey:countsPlaceholder.getMPSGraphTensor()];
|
||||
[results setObject:countsPlaceholder.getMPSGraphTensorData() forKey:countsPlaceholder.getMPSGraphTensor()];
|
||||
}
|
||||
|
||||
// Run the graph
|
||||
@ -271,9 +238,11 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor&
|
||||
|
||||
} // namespace mps

std::tuple<Tensor, Tensor, Tensor>
_unique_impl_mps(const Tensor& self, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional<int64_t> dimOpt) {

std::tuple<Tensor, Tensor, Tensor> _unique_impl_mps(const Tensor& self,
const bool return_inverse,
const bool return_counts,
const bool consecutive,
c10::optional<int64_t> dimOpt) {
const Tensor& input = self.contiguous();

// get flat output size
@ -303,7 +272,7 @@ _unique_impl_mps(const Tensor& self, const bool return_inverse, const bool retur
return std::make_tuple(output, inverse_indices, counts);
}

mps::UniqueCachedGraph *uniqueGraph = mps::getUniqueGraph(input, return_inverse, return_counts, consecutive, dimOpt);
mps::UniqueCachedGraph* uniqueGraph = mps::getUniqueGraph(input, return_inverse, return_counts, consecutive, dimOpt);
mps::runUniqueGraph(uniqueGraph, input, output, inverse_indices, counts, length, return_inverse, return_counts);

int64_t lengthScalar = length.item<int64_t>() + 1; // length actually holds max index, add 1
@ -316,17 +285,14 @@ _unique_impl_mps(const Tensor& self, const bool return_inverse, const bool retur
return std::make_tuple(output, inverse_indices, counts);
}

static
std::tuple<Tensor, Tensor, Tensor> castToMPS(std::tuple<Tensor, Tensor, Tensor> out) {
return std::make_tuple(
get<0>(out).to("mps"),
get<1>(out).to("mps"),
get<2>(out).to("mps"));
static std::tuple<Tensor, Tensor, Tensor> castToMPS(std::tuple<Tensor, Tensor, Tensor> out) {
return std::make_tuple(get<0>(out).to("mps"), get<1>(out).to("mps"), get<2>(out).to("mps"));
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor>
|
||||
unique_consecutive_mps(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional<int64_t> dim) {
|
||||
std::tuple<Tensor, Tensor, Tensor> unique_consecutive_mps(const Tensor& self,
|
||||
const bool return_inverse,
|
||||
const bool return_counts,
|
||||
c10::optional<int64_t> dim) {
|
||||
if (!is_macos_13_or_newer()) {
|
||||
TORCH_WARN_ONCE("MPS: unique_consecutive op is supported natively starting from macOS 13.0. ",
|
||||
"Falling back on CPU. This may have performace implications.");
|
||||
@ -336,8 +302,10 @@ unique_consecutive_mps(const Tensor& self, const bool return_inverse, const bool
|
||||
return _unique_impl_mps(self, return_inverse, return_counts, true, dim);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor>
|
||||
unique_dim_consecutive_mps(const Tensor& self, int64_t dim, const bool return_inverse, const bool return_counts) {
|
||||
std::tuple<Tensor, Tensor, Tensor> unique_dim_consecutive_mps(const Tensor& self,
|
||||
int64_t dim,
|
||||
const bool return_inverse,
|
||||
const bool return_counts) {
|
||||
if (!is_macos_13_or_newer()) {
|
||||
TORCH_WARN_ONCE("MPS: unique_dim_consecutive op is supported natively starting from macOS 13.0. ",
|
||||
"Falling back on CPU. This may have performace implications.");
|
||||
@ -347,8 +315,10 @@ unique_dim_consecutive_mps(const Tensor& self, int64_t dim, const bool return_in
|
||||
return _unique_impl_mps(self, return_inverse, return_counts, true, c10::make_optional((int64_t)dim));
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor>
|
||||
_unique2_mps(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
|
||||
std::tuple<Tensor, Tensor, Tensor> _unique2_mps(const Tensor& self,
|
||||
const bool sorted,
|
||||
const bool return_inverse,
|
||||
const bool return_counts) {
|
||||
if (!is_macos_13_or_newer()) {
|
||||
TORCH_WARN_ONCE("MPS: _unique2 op is supported natively starting from macOS 13.0. ",
|
||||
"Falling back on CPU. This may have performace implications.");
|
||||
|
@ -1,8 +1,8 @@
// Copyright © 2023 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/UpSample.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {
namespace mps {
@ -20,7 +20,7 @@ void upsample_out_template(const Tensor& input,
if (input.numel() == 0) {
return;
}
const auto input_dim = input.sizes();
const auto input_dim = input.sizes();
if (input_dim.size() <= 3) {
native::upsample_1d_common_check(input.sizes(), output_size);
} else {
@ -34,9 +34,8 @@ void upsample_out_template(const Tensor& input,
bool centerResults = false;
MPSGraphResizeMode resizeMode = MPSGraphResizeNearest;
MPSGraphResizeNearestRoundingMode nearestRoundingMode = MPSGraphResizeNearestRoundingModeFloor;
MPSGraphTensorNamedDataLayout dataLayout = input_dim.size() > 3 ?
MPSGraphTensorNamedDataLayoutNCHW :
MPSGraphTensorNamedDataLayoutCHW;
MPSGraphTensorNamedDataLayout dataLayout =
input_dim.size() > 3 ? MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutCHW;
if (resize_mode_str == "nearest") {
resizeMode = MPSGraphResizeNearest;
} else if (resize_mode_str == "bilinear") {
@ -50,7 +49,7 @@ void upsample_out_template(const Tensor& input,
}

const bool is_macOS_13_0_or_newer = is_macos_13_or_newer();
const int64_t output_width = output_size.size() > 1 ? output_size[1] : output_size[0];
const int64_t output_width = output_size.size() > 1 ? output_size[1] : output_size[0];
const int64_t output_height = output_size.size() > 1 ? output_size[0] : 1;
const float scale_w = (scale_w_opt.value_or(0.) > 0.) ? static_cast<float>(scale_w_opt.value()) : 0.;
const float scale_h = (scale_h_opt.value_or(0.) > 0.) ? static_cast<float>(scale_h_opt.value()) : 1.;
|
||||
@ -63,37 +62,37 @@ void upsample_out_template(const Tensor& input,
|
||||
input_size = input_size_opt.value();
|
||||
}
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
|
||||
MPSGraphTensor *outputSizeTensor = nil;
|
||||
MPSGraphTensor* outputSizeTensor = nil;
|
||||
};
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "upsample_" + std::string(resize_mode_str) + (align_corners ? "_aligned_corners" : "") +
|
||||
getTensorsStringKey({input}) + ":[" + to_string(scale_h) + "," + to_string(scale_w) + "]:[" +
|
||||
(is_backward_pass ? getArrayRefString(input_size) : "Undefined") + "]";
|
||||
getTensorsStringKey({input}) + ":[" + to_string(scale_h) + "," + to_string(scale_w) + "]:[" +
|
||||
(is_backward_pass ? getArrayRefString(input_size) : "Undefined") + "]";
|
||||
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
if(!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
|
||||
CachedGraph *newCachedGraph = nil;
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
newCachedGraph->outputSizeTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@(2)]);
|
||||
newCachedGraph->outputSizeTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(2) ]);
|
||||
|
||||
MPSGraphTensor* scaleOffsetTensor = nullptr;
|
||||
MPSGraphTensor* inputSizeTensor = nullptr;
|
||||
|
||||
if (scale_w > 0.0) {
|
||||
const float outScales[4] = {scale_h, scale_w, offset_y, offset_x};
|
||||
scaleOffsetTensor = [mpsGraph constantWithData: [NSData dataWithBytes: outScales length: sizeof(outScales)]
|
||||
shape: @[@4]
|
||||
dataType: MPSDataTypeFloat32];
|
||||
scaleOffsetTensor = [mpsGraph constantWithData:[NSData dataWithBytes:outScales length:sizeof(outScales)]
|
||||
shape:@[ @4 ]
|
||||
dataType:MPSDataTypeFloat32];
|
||||
}
|
||||
if (is_backward_pass) {
|
||||
std::vector<NSNumber*> inputSizeVec(4);
|
||||
@ -101,118 +100,119 @@ void upsample_out_template(const Tensor& input,
|
||||
inputSizeVec[1] = @(input_size[1]);
|
||||
inputSizeVec[2] = @(input_size[2]);
|
||||
inputSizeVec[3] = @(input_dim.size() > 3 ? input_size[3] : 1);
|
||||
inputSizeTensor = [mpsGraph constantWithScalar: 0
|
||||
shape: [NSArray arrayWithObjects:inputSizeVec.data() count:input_dim.size()]
|
||||
dataType: getMPSDataType(input)];
|
||||
inputSizeTensor = [mpsGraph constantWithScalar:0
|
||||
shape:[NSArray arrayWithObjects:inputSizeVec.data()
|
||||
count:input_dim.size()]
|
||||
dataType:getMPSDataType(input)];
|
||||
}
|
||||
if (is_macOS_13_0_or_newer) {
|
||||
if (!is_backward_pass) {
|
||||
if (scaleOffsetTensor && !align_corners) {
|
||||
if (resizeMode == MPSGraphResizeNearest) {
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor: newCachedGraph->inputTensor
|
||||
sizeTensor: newCachedGraph->outputSizeTensor
|
||||
scaleOffsetTensor: scaleOffsetTensor
|
||||
nearestRoundingMode: nearestRoundingMode
|
||||
layout: dataLayout
|
||||
name: nil];
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor:newCachedGraph->inputTensor
|
||||
sizeTensor:newCachedGraph->outputSizeTensor
|
||||
scaleOffsetTensor:scaleOffsetTensor
|
||||
nearestRoundingMode:nearestRoundingMode
|
||||
layout:dataLayout
|
||||
name:nil];
|
||||
} else { // bilinear forward
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor: newCachedGraph->inputTensor
|
||||
sizeTensor: newCachedGraph->outputSizeTensor
|
||||
scaleOffsetTensor: scaleOffsetTensor
|
||||
layout: dataLayout
|
||||
name: nil];
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor:newCachedGraph->inputTensor
|
||||
sizeTensor:newCachedGraph->outputSizeTensor
|
||||
scaleOffsetTensor:scaleOffsetTensor
|
||||
layout:dataLayout
|
||||
name:nil];
|
||||
}
|
||||
} else { // scaleOffsetTensor == nil || align_corners
|
||||
if (resizeMode == MPSGraphResizeNearest) {
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor: newCachedGraph->inputTensor
|
||||
sizeTensor: newCachedGraph->outputSizeTensor
|
||||
nearestRoundingMode: nearestRoundingMode
|
||||
centerResult: centerResults
|
||||
alignCorners: align_corners
|
||||
layout: dataLayout
|
||||
name: nil];
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor:newCachedGraph->inputTensor
|
||||
sizeTensor:newCachedGraph->outputSizeTensor
|
||||
nearestRoundingMode:nearestRoundingMode
|
||||
centerResult:centerResults
|
||||
alignCorners:align_corners
|
||||
layout:dataLayout
|
||||
name:nil];
|
||||
} else { // bilinear forward
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor: newCachedGraph->inputTensor
|
||||
sizeTensor: newCachedGraph->outputSizeTensor
|
||||
centerResult: centerResults
|
||||
alignCorners: align_corners
|
||||
layout: dataLayout
|
||||
name: nil];
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor:newCachedGraph->inputTensor
|
||||
sizeTensor:newCachedGraph->outputSizeTensor
|
||||
centerResult:centerResults
|
||||
alignCorners:align_corners
|
||||
layout:dataLayout
|
||||
name:nil];
|
||||
}
|
||||
}
|
||||
} else { // is_backward_pass == true
|
||||
if (scaleOffsetTensor && !align_corners) {
|
||||
if (resizeMode == MPSGraphResizeNearest) {
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor: newCachedGraph->inputTensor
|
||||
input: inputSizeTensor
|
||||
scaleOffsetTensor: scaleOffsetTensor
|
||||
nearestRoundingMode: nearestRoundingMode
|
||||
layout: dataLayout
|
||||
name: nil];
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor:newCachedGraph->inputTensor
|
||||
input:inputSizeTensor
|
||||
scaleOffsetTensor:scaleOffsetTensor
|
||||
nearestRoundingMode:nearestRoundingMode
|
||||
layout:dataLayout
|
||||
name:nil];
|
||||
} else { // bilinear backward
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor: newCachedGraph->inputTensor
|
||||
input: inputSizeTensor
|
||||
scaleOffsetTensor: scaleOffsetTensor
|
||||
layout: dataLayout
|
||||
name: nil];
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor:newCachedGraph->inputTensor
|
||||
input:inputSizeTensor
|
||||
scaleOffsetTensor:scaleOffsetTensor
|
||||
layout:dataLayout
|
||||
name:nil];
|
||||
}
|
||||
} else { // scaleOffsetTensor == nil || align_corners
|
||||
if (resizeMode == MPSGraphResizeNearest) {
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor: newCachedGraph->inputTensor
|
||||
input: inputSizeTensor
|
||||
nearestRoundingMode: nearestRoundingMode
|
||||
centerResult: centerResults
|
||||
alignCorners: align_corners
|
||||
layout: dataLayout
|
||||
name: nil];
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor:newCachedGraph->inputTensor
|
||||
input:inputSizeTensor
|
||||
nearestRoundingMode:nearestRoundingMode
|
||||
centerResult:centerResults
|
||||
alignCorners:align_corners
|
||||
layout:dataLayout
|
||||
name:nil];
|
||||
} else { // bilinear backward
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor: newCachedGraph->inputTensor
|
||||
input: inputSizeTensor
|
||||
centerResult: centerResults
|
||||
alignCorners: align_corners
|
||||
layout: dataLayout
|
||||
name: nil];
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor:newCachedGraph->inputTensor
|
||||
input:inputSizeTensor
|
||||
centerResult:centerResults
|
||||
alignCorners:align_corners
|
||||
layout:dataLayout
|
||||
name:nil];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else { // if macOS version < 13.0 (for backwards compatibility)
|
||||
if (!is_backward_pass) {
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeTensor: newCachedGraph->inputTensor
|
||||
sizeTensor: newCachedGraph->outputSizeTensor
|
||||
mode: resizeMode
|
||||
centerResult: centerResults
|
||||
alignCorners: align_corners
|
||||
layout: dataLayout
|
||||
name: nil];
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeTensor:newCachedGraph->inputTensor
|
||||
sizeTensor:newCachedGraph->outputSizeTensor
|
||||
mode:resizeMode
|
||||
centerResult:centerResults
|
||||
alignCorners:align_corners
|
||||
layout:dataLayout
|
||||
name:nil];
|
||||
} else {
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeWithGradientTensor: newCachedGraph->inputTensor
|
||||
input: inputSizeTensor
|
||||
mode: resizeMode
|
||||
centerResult: centerResults
|
||||
alignCorners: align_corners
|
||||
layout: dataLayout
|
||||
name: nil];
|
||||
newCachedGraph->outputTensor = [mpsGraph resizeWithGradientTensor:newCachedGraph->inputTensor
|
||||
input:inputSizeTensor
|
||||
mode:resizeMode
|
||||
centerResult:centerResults
|
||||
alignCorners:align_corners
|
||||
layout:dataLayout
|
||||
name:nil];
|
||||
}
|
||||
}
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
}
|
||||
MPSNDArrayDescriptor *sizeDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(2)]];
|
||||
MPSNDArray *sizeNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: sizeDesc] autorelease];
|
||||
[sizeNDArray writeBytes: (int32_t[]) {(int32_t)output_height, (int32_t)output_width} strideBytes: nil];
|
||||
MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: sizeNDArray] autorelease];
|
||||
MPSNDArrayDescriptor* sizeDesc = [MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeInt32 shape:@[ @(2) ]];
|
||||
MPSNDArray* sizeNDArray = [[[MPSNDArray alloc] initWithDevice:stream->device() descriptor:sizeDesc] autorelease];
|
||||
[sizeNDArray writeBytes:(int32_t[]){(int32_t)output_height, (int32_t)output_width} strideBytes:nil];
|
||||
MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray:sizeNDArray] autorelease];
|
||||
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, out.has_storage() ? out : output, nil, false);
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
|
||||
Placeholder outputPlaceholder =
|
||||
Placeholder(cachedGraph->outputTensor, out.has_storage() ? out : output, nil, false);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
|
||||
cachedGraph->outputSizeTensor : sizeTensorData,
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
|
||||
cachedGraph->outputSizeTensor : sizeTensorData,
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
|
||||
if (out.has_storage()) {
|
||||
@ -223,8 +223,7 @@ void upsample_out_template(const Tensor& input,
|
||||
|
||||
} // namespace mps

static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10::optional<double> scale)
{
static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10::optional<double> scale) {
static const bool is_macOS_13_0_or_newer = is_macos_13_or_newer();
if (!is_macOS_13_0_or_newer) {
// passing scale factors to MPS's resize APIs is not supported on macOS < 13
@ -232,11 +231,13 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
TORCH_WARN_ONCE("MPS: passing scale factor to upsample ops is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performance implications.");
return false;
// nearest mode on Monterey uses round() to compute source indices which
// is incompatible with PyTorch that uses floor(). So we fallback to CPU on Monterey.
// The nearest mode should work fine on Ventura.
// nearest mode on Monterey uses round() to compute source indices which
// is incompatible with PyTorch that uses floor(). So we fallback to CPU on Monterey.
// The nearest mode should work fine on Ventura.
} else if (resize_mode_str == "nearest" || resize_mode_str == "nearest-exact") {
TORCH_WARN_ONCE("MPS: '", resize_mode_str, "' mode upsampling is supported natively starting from macOS 13.0. ",
TORCH_WARN_ONCE("MPS: '",
resize_mode_str,
"' mode upsampling is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performance implications.");
return false;
}
@ -244,12 +245,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
return true;
}
|
||||
|
||||
TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) (
|
||||
const Tensor& input,
|
||||
IntArrayRef output_size,
|
||||
c10::optional<double> scale,
|
||||
const Tensor& output)
|
||||
{
|
||||
TORCH_IMPL_FUNC(upsample_nearest1d_out_mps)
|
||||
(const Tensor& input, IntArrayRef output_size, c10::optional<double> scale, const Tensor& output) {
|
||||
if (check_mps_compatibility("nearest", scale)) {
|
||||
mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest");
|
||||
} else {
|
||||
@ -258,27 +255,23 @@ TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) (
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_mps) (
|
||||
const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
c10::optional<double> scale,
|
||||
const Tensor& grad_input)
|
||||
{
|
||||
TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_mps)
|
||||
(const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
c10::optional<double> scale,
|
||||
const Tensor& grad_input) {
|
||||
if (check_mps_compatibility("nearest", scale)) {
|
||||
mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(grad_input) = at::upsample_nearest1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps");
|
||||
const_cast<Tensor&>(grad_input) =
|
||||
at::upsample_nearest1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps) (
|
||||
const Tensor& input,
|
||||
IntArrayRef output_size,
|
||||
c10::optional<double> scale,
|
||||
const Tensor& output)
|
||||
{
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps)
|
||||
(const Tensor& input, IntArrayRef output_size, c10::optional<double> scale, const Tensor& output) {
|
||||
if (check_mps_compatibility("nearest-exact", scale)) {
|
||||
mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest-exact");
|
||||
} else {
|
||||
@ -287,113 +280,123 @@ TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps) (
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact1d_backward_out_mps) (
|
||||
const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
c10::optional<double> scale,
|
||||
const Tensor& grad_input)
|
||||
{
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact1d_backward_out_mps)
|
||||
(const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
c10::optional<double> scale,
|
||||
const Tensor& grad_input) {
|
||||
if (check_mps_compatibility("nearest-exact", scale)) {
|
||||
mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest-exact");
|
||||
mps::upsample_out_template(
|
||||
grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest-exact");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(grad_input) = at::_upsample_nearest_exact1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps");
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(grad_input) =
|
||||
at::_upsample_nearest_exact1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(upsample_nearest2d_out_mps) (
|
||||
const Tensor& input,
|
||||
IntArrayRef output_size,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& output)
|
||||
{
|
||||
TORCH_IMPL_FUNC(upsample_nearest2d_out_mps)
|
||||
(const Tensor& input,
|
||||
IntArrayRef output_size,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& output) {
|
||||
if (check_mps_compatibility("nearest", scales_w)) {
|
||||
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(output) = at::upsample_nearest2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps");
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(output) =
|
||||
at::upsample_nearest2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_mps) (
|
||||
const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& grad_input)
|
||||
{
|
||||
TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_mps)
|
||||
(const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& grad_input) {
|
||||
if (check_mps_compatibility("nearest", scales_w)) {
|
||||
mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(grad_input) = at::upsample_nearest2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w).clone().to("mps");
|
||||
const_cast<Tensor&>(grad_input) =
|
||||
at::upsample_nearest2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w)
|
||||
.clone()
|
||||
.to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps) (
|
||||
const Tensor& input,
|
||||
IntArrayRef output_size,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& output)
|
||||
{
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps)
|
||||
(const Tensor& input,
|
||||
IntArrayRef output_size,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& output) {
|
||||
if (check_mps_compatibility("nearest-exact", scales_w)) {
|
||||
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest-exact");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(output) = at::_upsample_nearest_exact2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps");
|
||||
const_cast<Tensor&>(output) =
|
||||
at::_upsample_nearest_exact2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_mps) (
|
||||
const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& grad_input)
|
||||
{
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_mps)
|
||||
(const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& grad_input) {
|
||||
if (check_mps_compatibility("nearest-exact", scales_w)) {
|
||||
mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest-exact");
|
||||
mps::upsample_out_template(
|
||||
grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest-exact");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(grad_input) = at::_upsample_nearest_exact2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w).clone().to("mps");
|
||||
const_cast<Tensor&>(grad_input) =
|
||||
at::_upsample_nearest_exact2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w)
|
||||
.clone()
|
||||
.to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps) (
|
||||
const Tensor& input,
|
||||
IntArrayRef output_size,
|
||||
bool align_corners,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& output)
|
||||
{
|
||||
TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps)
|
||||
(const Tensor& input,
|
||||
IntArrayRef output_size,
|
||||
bool align_corners,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& output) {
|
||||
if (check_mps_compatibility("bilinear", scales_w)) {
|
||||
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, align_corners, "bilinear");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(output) = at::upsample_bilinear2d(input.to("cpu"), output_size, align_corners, scales_h, scales_w).clone().to("mps");
|
||||
const_cast<Tensor&>(output) =
|
||||
at::upsample_bilinear2d(input.to("cpu"), output_size, align_corners, scales_h, scales_w).clone().to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps) (
|
||||
const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
bool align_corners,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& grad_input)
|
||||
{
|
||||
TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps)
|
||||
(const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
bool align_corners,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& grad_input) {
|
||||
if (check_mps_compatibility("bilinear", scales_w)) {
|
||||
mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, align_corners, "bilinear");
|
||||
mps::upsample_out_template(
|
||||
grad_output, output_size, input_size, scales_h, scales_w, grad_input, align_corners, "bilinear");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(grad_input) = at::upsample_bilinear2d_backward(grad_output.to("cpu"), output_size, input_size, align_corners, scales_h, scales_w).clone().to("mps");
|
||||
const_cast<Tensor&>(grad_input) =
|
||||
at::upsample_bilinear2d_backward(
|
||||
grad_output.to("cpu"), output_size, input_size, align_corners, scales_h, scales_w)
|
||||
.clone()
|
||||
.to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,18 +1,17 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/mps/IndexKernels.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mps/OperationUtils.h>
#include <fmt/format.h>
#include <torch/library.h>
#include <ATen/mps/IndexKernels.h>

namespace at::native {
namespace mps {

struct ViewCachedGraph : public MPSCachedGraph
{
ViewCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
struct ViewCachedGraph : public MPSCachedGraph {
ViewCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor = nil;
MPSGraphTensor* outputTensor = nil;
MPSGraphTensor* updatesTensor = nil;
@ -20,18 +19,20 @@ struct ViewCachedGraph : public MPSCachedGraph
std::vector<MPSGraphTensor*> strideTensors;
};

static std::string getStridedKey(const ScalarType& self_dtype, const ScalarType& updates_dtype, const IntArrayRef& base_shape,
const IntArrayRef& new_shape, const IntArrayRef& stride,
int64_t storage_offset, bool is_scatter)
{
static std::string getStridedKey(const ScalarType& self_dtype,
const ScalarType& updates_dtype,
const IntArrayRef& base_shape,
const IntArrayRef& new_shape,
const IntArrayRef& stride,
int64_t storage_offset,
bool is_scatter) {
std::string dtype_key = getMPSTypeString(self_dtype);
if (is_scatter) {
dtype_key += ":" + getMPSTypeString(updates_dtype);
}

return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" +
getArrayRefString(base_shape) + "]:[" + getArrayRefString(new_shape) + "]:[" +
getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]";
return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" + getArrayRefString(base_shape) + "]:[" +
getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]";
}
|
||||
|
||||
// initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op
|
||||
@ -39,30 +40,31 @@ static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src,
|
||||
const id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
|
||||
const id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
|
||||
|
||||
const IntArrayRef& strides = needsScatter ? output.strides() : src.strides();
|
||||
const IntArrayRef& sizes = needsScatter ? output.sizes() : src.sizes();
|
||||
const IntArrayRef& strides = needsScatter ? output.strides() : src.strides();
|
||||
const IntArrayRef& sizes = needsScatter ? output.sizes() : src.sizes();
|
||||
const int64_t storage_offset = needsScatter ? output.storage_offset() : src.storage_offset();
|
||||
const MPSDataType inputType = [cachedGraph->inputTensor dataType];
|
||||
const MPSDataType inputType = [cachedGraph->inputTensor dataType];
|
||||
|
||||
MPSShape *inputShape = [cachedGraph->inputTensor shape];
|
||||
MPSShape *outputShape = needsScatter ? inputShape : getMPSShape(src);
|
||||
MPSShape* inputShape = [cachedGraph->inputTensor shape];
|
||||
MPSShape* outputShape = needsScatter ? inputShape : getMPSShape(src);
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
@autoreleasepool {
|
||||
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
|
||||
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
|
||||
// in case of scatter, we use output tensor as input buffer and write the results back to the source buffer
|
||||
feeds[cachedGraph->inputTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: needsScatter ? outputBuffer : sourceBuffer
|
||||
shape: inputShape
|
||||
dataType: inputType] autorelease];
|
||||
feeds[cachedGraph->inputTensor] =
|
||||
[[[MPSGraphTensorData alloc] initWithMTLBuffer:needsScatter ? outputBuffer : sourceBuffer
|
||||
shape:inputShape
|
||||
dataType:inputType] autorelease];
|
||||
if (needsScatter) {
|
||||
auto updatesType = getMPSScalarType(src.scalar_type());
|
||||
if (updatesType == MPSDataTypeUInt8 || (updatesType == MPSDataTypeBool && !is_macos_13_or_newer())) {
|
||||
updatesType = MPSDataTypeInt8;
|
||||
}
|
||||
|
||||
feeds[cachedGraph->updatesTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer
|
||||
shape: getMPSShape(src.numel())
|
||||
dataType: updatesType] autorelease];
|
||||
feeds[cachedGraph->updatesTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer:sourceBuffer
|
||||
shape:getMPSShape(src.numel())
|
||||
dataType:updatesType] autorelease];
|
||||
}
|
||||
MPSScalar storageOffsetScalar = getMPSScalar(storage_offset, ScalarType::Int);
|
||||
feeds[cachedGraph->storageOffsetTensor] = getMPSGraphTensorFromScalar(stream, storageOffsetScalar);
|
||||
@ -75,59 +77,53 @@ static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src,
|
||||
// Workaround for MPSShaderLibrary bug in macOS Monterey
|
||||
// This is fixed in macOS Ventura
|
||||
auto outputType = getMPSScalarType(output.scalar_type());
|
||||
if (outputType == MPSDataTypeUInt8 || (outputType == MPSDataTypeBool && !is_macos_13_or_newer())) {
|
||||
outputType = MPSDataTypeInt8;
|
||||
if (outputType == MPSDataTypeUInt8 || (outputType == MPSDataTypeBool && !is_macos_13_or_newer())) {
|
||||
outputType = MPSDataTypeInt8;
|
||||
}
|
||||
MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: outputBuffer
|
||||
shape: outputShape
|
||||
dataType: outputType] autorelease];
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
cachedGraph->outputTensor : outputTensorData
|
||||
};
|
||||
MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:outputBuffer
|
||||
shape:outputShape
|
||||
dataType:outputType] autorelease];
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{cachedGraph->outputTensor : outputTensorData};
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
MPSGraphTensor *permuteTensor(MPSGraph *graph, MPSGraphTensor *inputTensor, NSArray *permuteOrder) {
|
||||
MPSGraphTensor* permuteTensor(MPSGraph* graph, MPSGraphTensor* inputTensor, NSArray* permuteOrder) {
|
||||
NSUInteger srcRank = [[inputTensor shape] count];
|
||||
if (srcRank != [permuteOrder count]) {
|
||||
return nil;
|
||||
}
|
||||
|
||||
MPSGraphTensor *outputTensor = inputTensor;
|
||||
MPSGraphTensor* outputTensor = inputTensor;
|
||||
std::vector<NSUInteger> dimensionOrder(srcRank);
|
||||
std::iota (std::begin(dimensionOrder), std::end(dimensionOrder), 0);
|
||||
std::iota(std::begin(dimensionOrder), std::end(dimensionOrder), 0);
|
||||
|
||||
for (const auto i : c10::irange(srcRank)) {
|
||||
for (const auto i : c10::irange(srcRank)) {
|
||||
NSUInteger axis = [permuteOrder[i] integerValue];
|
||||
auto axisIter = std::find(dimensionOrder.begin(), dimensionOrder.end(), axis);
|
||||
NSUInteger axis1 = i;
|
||||
NSUInteger axis2 = axisIter - dimensionOrder.begin();
|
||||
iter_swap(dimensionOrder.begin() + i, axisIter);
|
||||
|
||||
outputTensor = [graph transposeTensor:outputTensor
|
||||
dimension:axis1
|
||||
withDimension:axis2
|
||||
name:nil];
|
||||
outputTensor = [graph transposeTensor:outputTensor dimension:axis1 withDimension:axis2 name:nil];
|
||||
}
|
||||
|
||||
return outputTensor;
|
||||
}
|
||||
|
||||
NSDictionary *getStrideToDimLengthOffsetDict(MPSGraphTensor *tensor, NSUInteger rank, NSUInteger offset) {
|
||||
NSDictionary* getStrideToDimLengthOffsetDict(MPSGraphTensor* tensor, NSUInteger rank, NSUInteger offset) {
|
||||
// Assuming input tensor has default strides
|
||||
NSInteger stride = 1;
|
||||
NSMutableDictionary *strideToDimLengthOffset = [[NSMutableDictionary alloc] init];
|
||||
NSMutableDictionary* strideToDimLengthOffset = [[NSMutableDictionary alloc] init];
|
||||
for (NSInteger srcDim = rank - 1; srcDim >= 0; srcDim--) {
|
||||
NSUInteger size = [[tensor shape][srcDim] integerValue];
|
||||
NSDictionary *entry =
|
||||
@{
|
||||
@"dim": [NSNumber numberWithInteger:srcDim],
|
||||
@"length": [tensor shape][srcDim],
|
||||
@"offset": [NSNumber numberWithInteger:offset % size] // offset is determined traversing backwards through stride
|
||||
NSDictionary* entry = @{
|
||||
@"dim" : [NSNumber numberWithInteger:srcDim],
|
||||
@"length" : [tensor shape][srcDim],
|
||||
@"offset" : [NSNumber numberWithInteger:offset % size] // offset is determined traversing backwards through stride
|
||||
};
|
||||
[strideToDimLengthOffset setValue:entry forKey:[NSString stringWithFormat:@"%ld",stride]];
|
||||
[strideToDimLengthOffset setValue:entry forKey:[NSString stringWithFormat:@"%ld", stride]];
|
||||
offset /= size;
|
||||
stride *= size;
|
||||
}
|
||||
@ -135,14 +131,18 @@ NSDictionary *getStrideToDimLengthOffsetDict(MPSGraphTensor *tensor, NSUInteger
|
||||
}
|
||||
|
||||
// Detect only expand dims, allows for duplicate strides
|
||||
MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) {
|
||||
|
||||
MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph* graph,
|
||||
MPSGraphTensor* inputTensor,
|
||||
int dstRank,
|
||||
const IntArrayRef& dstSizes,
|
||||
const IntArrayRef& dstStrides,
|
||||
int offset) {
|
||||
NSUInteger srcRank = [[inputTensor shape] count];
|
||||
// Not an expand dims
|
||||
if (srcRank >= dstRank)
|
||||
return nil;
|
||||
|
||||
NSMutableArray *expandAxes = [[NSMutableArray alloc] init];
|
||||
NSMutableArray* expandAxes = [[NSMutableArray alloc] init];
|
||||
|
||||
BOOL isValidExpand = YES;
|
||||
NSInteger currSrcDim = (NSInteger)srcRank - 1;
|
||||
@ -152,7 +152,7 @@ MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph *graph, MPSGraphTensor
|
||||
NSUInteger currStride = dstStrides[dstDim];
|
||||
NSUInteger currSrcDimLength = currSrcDim >= 0 ? [[inputTensor shape][currSrcDim] integerValue] : 1;
|
||||
|
||||
NSUInteger targetDimLength = currSrcDimLength;
|
||||
NSUInteger targetDimLength = currSrcDimLength;
|
||||
if (currDimLength != targetDimLength) {
|
||||
targetDimLength = 1;
|
||||
}
|
||||
@ -173,11 +173,9 @@ MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph *graph, MPSGraphTensor
|
||||
return nil;
|
||||
}
|
||||
|
||||
MPSGraphTensor *expandTensor = inputTensor;
|
||||
MPSGraphTensor* expandTensor = inputTensor;
|
||||
if ([expandAxes count]) {
|
||||
expandTensor = [graph expandDimsOfTensor:expandTensor
|
||||
axes:expandAxes
|
||||
name:nil];
|
||||
expandTensor = [graph expandDimsOfTensor:expandTensor axes:expandAxes name:nil];
|
||||
}
|
||||
[expandAxes release];
|
||||
|
||||
@ -185,13 +183,18 @@ MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph *graph, MPSGraphTensor
|
||||
}
|
||||
|
||||
// Detect contiguous reshapes, no slicing
|
||||
MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) {
|
||||
MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph* graph,
|
||||
MPSGraphTensor* inputTensor,
|
||||
int dstRank,
|
||||
const IntArrayRef& dstSizes,
|
||||
const IntArrayRef& dstStrides,
|
||||
int offset) {
|
||||
NSUInteger srcRank = [[inputTensor shape] count];
|
||||
// Not a reshape
|
||||
if (srcRank <= dstRank)
|
||||
return nil;
|
||||
|
||||
NSMutableArray *dstShape = [[NSMutableArray alloc] init];
|
||||
NSMutableArray* dstShape = [[NSMutableArray alloc] init];
|
||||
|
||||
BOOL isValidReshape = YES;
|
||||
NSInteger srcDim = srcRank - 1;
|
||||
@ -199,7 +202,7 @@ MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
for (NSInteger dstDim = dstRank - 1; dstDim >= 0 && isValidReshape; dstDim--) {
|
||||
NSUInteger currDimLength = dstSizes[dstDim];
|
||||
NSUInteger currStride = dstStrides[dstDim];
|
||||
[dstShape insertObject:[NSNumber numberWithInteger:currDimLength] atIndex: 0];
|
||||
[dstShape insertObject:[NSNumber numberWithInteger:currDimLength] atIndex:0];
|
||||
|
||||
NSUInteger targetDimLength = currDimLength;
|
||||
NSUInteger currReshapeSize = 1;
|
||||
@ -216,26 +219,28 @@ MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
}
|
||||
isValidReshape &= (srcDim < 0);
|
||||
|
||||
MPSGraphTensor *outputTensor = nil;
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
if (isValidReshape)
|
||||
outputTensor = [graph reshapeTensor: inputTensor
|
||||
withShape: dstShape
|
||||
name: nil];
|
||||
outputTensor = [graph reshapeTensor:inputTensor withShape:dstShape name:nil];
|
||||
[dstShape release];
|
||||
return outputTensor;
|
||||
}
|
||||
|
||||
MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) {
|
||||
|
||||
MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
|
||||
MPSGraphTensor* inputTensor,
|
||||
int dstRank,
|
||||
const IntArrayRef& dstSizes,
|
||||
const IntArrayRef& dstStrides,
|
||||
int offset) {
|
||||
// Duplicate strides cannot be done
|
||||
{
|
||||
BOOL allUnique = YES;
|
||||
NSMutableSet *uniqueStrides = [[NSMutableSet alloc] init];
|
||||
NSMutableSet* uniqueStrides = [[NSMutableSet alloc] init];
|
||||
for (NSInteger dstDim = 0; (dstDim < dstRank) && allUnique; dstDim++) {
|
||||
int stride = dstStrides[dstDim];
|
||||
NSNumber *strideObj = [NSNumber numberWithInt:stride];
|
||||
NSNumber* strideObj = [NSNumber numberWithInt:stride];
|
||||
allUnique &= (stride == 0 || ![uniqueStrides containsObject:strideObj]);
|
||||
[uniqueStrides addObject: strideObj];
|
||||
[uniqueStrides addObject:strideObj];
|
||||
}
|
||||
[uniqueStrides release];
|
||||
if (!allUnique)
|
||||
@ -243,31 +248,31 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
|
||||
// Skip for zero in dst shape
|
||||
for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++)
|
||||
if (dstSizes[dstDim] == 0) { return nil; }
|
||||
if (dstSizes[dstDim] == 0) {
|
||||
return nil;
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Flatten the inputTensor if necessary
|
||||
MPSGraphTensor *flatInputTensor = inputTensor;
|
||||
MPSGraphTensor* flatInputTensor = inputTensor;
|
||||
{
|
||||
// Flatten inputs to remove duplicate strides.
|
||||
NSMutableArray *squeezeAxes = [[NSMutableArray alloc] init];
|
||||
for(NSUInteger srcDim = 1; srcDim < [[flatInputTensor shape] count]; srcDim++) {
|
||||
if ([[flatInputTensor shape][srcDim] intValue] == 1)
|
||||
[squeezeAxes addObject:[NSNumber numberWithInteger:srcDim]];
|
||||
NSMutableArray* squeezeAxes = [[NSMutableArray alloc] init];
|
||||
for (NSUInteger srcDim = 1; srcDim < [[flatInputTensor shape] count]; srcDim++) {
|
||||
if ([[flatInputTensor shape][srcDim] intValue] == 1)
|
||||
[squeezeAxes addObject:[NSNumber numberWithInteger:srcDim]];
|
||||
}
|
||||
// We have to leave at least 1 dimension, if all input dims are 1
|
||||
if ([squeezeAxes count])
|
||||
flatInputTensor = [graph squeezeTensor:flatInputTensor
|
||||
axes:squeezeAxes
|
||||
name:nil];
|
||||
flatInputTensor = [graph squeezeTensor:flatInputTensor axes:squeezeAxes name:nil];
|
||||
[squeezeAxes release];
|
||||
}
|
||||
|
||||
int srcRank = (int)[[flatInputTensor shape] count];
|
||||
NSDictionary *srcStrideToDimLengthOffset = getStrideToDimLengthOffsetDict(flatInputTensor, srcRank, offset);
|
||||
NSDictionary* srcStrideToDimLengthOffset = getStrideToDimLengthOffsetDict(flatInputTensor, srcRank, offset);
|
||||
|
||||
// Populate the dimension order, slice info, and broadcast info
|
||||
NSMutableArray *dstDimOrder = [[NSMutableArray alloc] init];
|
||||
NSMutableArray* dstDimOrder = [[NSMutableArray alloc] init];
|
||||
std::vector<int32_t> dstDimToSliceLength(dstRank);
|
||||
std::vector<int32_t> dstDimToSliceOffset(dstRank);
|
||||
bool needsBroadcast = false;
|
||||
@@ -280,31 +285,33 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
dstDimToSliceOffset[dstDim] = 0;
|
||||
} else {
|
||||
// Find what dimension and native length was for the specified stride
|
||||
NSDictionary *srcDimLengthOffset = srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%lld",dstStrides[dstDim]]];
|
||||
NSDictionary* srcDimLengthOffset =
|
||||
srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%lld", dstStrides[dstDim]]];
|
||||
|
||||
dstDimToSliceLength[dstDim] = dstSizes[dstDim];
|
||||
dstDimToSliceOffset[dstDim] = [srcDimLengthOffset[@"offset"] intValue];
|
||||
|
||||
// Stride does not exist in source tensor, or the specified size is too long. Not possible
|
||||
// TODO: Longer length with same stride + removal of dim(s) above this is a flatten/reshape. Consider adding support
|
||||
// TODO: Longer length with same stride + removal of dim(s) above this is a flatten/reshape. Consider adding
|
||||
// support
|
||||
if (!srcDimLengthOffset ||
|
||||
// the offset + length of destination should not be larger than source's length when slicing
|
||||
dstDimToSliceOffset[dstDim] + dstDimToSliceLength[dstDim] > [srcDimLengthOffset[@"length"] intValue]) {
|
||||
return nil;
|
||||
}
|
||||
// Get the src dimension corresponding to the requested stride
|
||||
NSNumber *srcDim = srcDimLengthOffset[@"dim"];
|
||||
NSNumber* srcDim = srcDimLengthOffset[@"dim"];
|
||||
[dstDimOrder insertObject:srcDim atIndex:0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Slice out any unused dimensions
|
||||
NSMutableArray *missingSrcDims = [[NSMutableArray alloc] init];
|
||||
MPSGraphTensor *slicedUnusedTensor = flatInputTensor;
|
||||
NSMutableArray* missingSrcDims = [[NSMutableArray alloc] init];
|
||||
MPSGraphTensor* slicedUnusedTensor = flatInputTensor;
|
||||
{
|
||||
// Find any src strides/dims that are not present in the dst
|
||||
NSMutableArray *missingSrcStrides = [[NSMutableArray alloc] init];
|
||||
NSMutableArray* missingSrcStrides = [[NSMutableArray alloc] init];
|
||||
{
|
||||
NSUInteger stride = 1;
|
||||
for (NSInteger srcDim = [[flatInputTensor shape] count] - 1; srcDim >= 0; srcDim--) {
|
||||
@@ -317,8 +324,8 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
}
|
||||
for (NSUInteger i = 0; i < [missingSrcStrides count]; i++) {
|
||||
NSUInteger stride = [missingSrcStrides[i] integerValue];
|
||||
NSDictionary *srcDimLengthOffset = srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%ld",stride]];
|
||||
NSNumber *missingSrcDim = srcDimLengthOffset[@"dim"];
|
||||
NSDictionary* srcDimLengthOffset = srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%ld", stride]];
|
||||
NSNumber* missingSrcDim = srcDimLengthOffset[@"dim"];
|
||||
[missingSrcDims addObject:missingSrcDim];
|
||||
[dstDimOrder insertObject:missingSrcDim atIndex:0];
|
||||
|
||||
@@ -332,35 +339,33 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
}
|
||||
|
||||
// 3. Transpose if necessary
|
||||
MPSGraphTensor *transposedTensor = slicedUnusedTensor;
|
||||
MPSGraphTensor* transposedTensor = slicedUnusedTensor;
|
||||
{
|
||||
// TODO: Use Transpose API
|
||||
BOOL needsTranspose = NO;
|
||||
for(NSUInteger dstDim = 0; dstDim < [dstDimOrder count] && !needsTranspose; dstDim++ )
|
||||
for (NSUInteger dstDim = 0; dstDim < [dstDimOrder count] && !needsTranspose; dstDim++)
|
||||
needsTranspose |= ([dstDimOrder[dstDim] intValue] != dstDim);
|
||||
if (needsTranspose)
|
||||
transposedTensor = permuteTensor(graph, transposedTensor, dstDimOrder);
|
||||
}
|
||||
|
||||
// 4. Squeeze any unused dimensions following transpose
|
||||
MPSGraphTensor *squeezedTensor = transposedTensor;
|
||||
MPSGraphTensor* squeezedTensor = transposedTensor;
|
||||
{
|
||||
// Transpose the missing dims back
|
||||
NSMutableArray *transposedMissingSrcDims = [[NSMutableArray alloc] init];
|
||||
NSMutableArray* transposedMissingSrcDims = [[NSMutableArray alloc] init];
|
||||
for (NSUInteger dstDim = 0; dstDim < [dstDimOrder count]; dstDim++) {
|
||||
NSNumber *srcDim = dstDimOrder[dstDim];
|
||||
NSNumber* srcDim = dstDimOrder[dstDim];
|
||||
if ([missingSrcDims containsObject:srcDim])
|
||||
[transposedMissingSrcDims addObject:[NSNumber numberWithInt:dstDim]];
|
||||
}
|
||||
if ([transposedMissingSrcDims count])
|
||||
squeezedTensor = [graph squeezeTensor:squeezedTensor
|
||||
axes:transposedMissingSrcDims
|
||||
name:nil];
|
||||
squeezedTensor = [graph squeezeTensor:squeezedTensor axes:transposedMissingSrcDims name:nil];
|
||||
[transposedMissingSrcDims release];
|
||||
}
|
||||
|
||||
// 5. Slice
|
||||
MPSGraphTensor *slicedTensor = squeezedTensor;
|
||||
MPSGraphTensor* slicedTensor = squeezedTensor;
|
||||
{
|
||||
NSUInteger currDstDim = 0;
|
||||
for (NSUInteger dstDim = 0; dstDim < dstRank; dstDim++) {
|
||||
@@ -369,34 +374,26 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
int start = dstDimToSliceOffset[dstDim];
|
||||
int length = dstDimToSliceLength[dstDim];
|
||||
if (length != [[slicedTensor shape][currDstDim] intValue])
|
||||
slicedTensor = [graph sliceTensor:slicedTensor
|
||||
dimension:currDstDim
|
||||
start:start
|
||||
length:length
|
||||
name:nil];
|
||||
slicedTensor = [graph sliceTensor:slicedTensor dimension:currDstDim start:start length:length name:nil];
|
||||
currDstDim++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Expand then broadcast the source tensor
|
||||
MPSGraphTensor *broadcastTensor = slicedTensor;
|
||||
MPSGraphTensor* broadcastTensor = slicedTensor;
|
||||
if (needsBroadcast) {
|
||||
NSMutableArray *broadcastShape = [[NSMutableArray alloc] init];
|
||||
NSMutableArray *expandAxes = [[NSMutableArray alloc] init];
|
||||
for(NSInteger dstDim = 0; dstDim < dstRank; dstDim++) {
|
||||
NSMutableArray* broadcastShape = [[NSMutableArray alloc] init];
|
||||
NSMutableArray* expandAxes = [[NSMutableArray alloc] init];
|
||||
for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++) {
|
||||
[broadcastShape addObject:[NSNumber numberWithInt:dstSizes[dstDim]]];
|
||||
if (dstStrides[dstDim] == 0)
|
||||
[expandAxes addObject:[NSNumber numberWithInt:dstDim]];
|
||||
}
|
||||
|
||||
if ([expandAxes count]) {
|
||||
MPSGraphTensor *expandTensor = [graph expandDimsOfTensor:broadcastTensor
|
||||
axes:expandAxes
|
||||
name:nil];
|
||||
broadcastTensor = [graph broadcastTensor:expandTensor
|
||||
toShape:broadcastShape
|
||||
name:nil];
|
||||
MPSGraphTensor* expandTensor = [graph expandDimsOfTensor:broadcastTensor axes:expandAxes name:nil];
|
||||
broadcastTensor = [graph broadcastTensor:expandTensor toShape:broadcastShape name:nil];
|
||||
}
|
||||
[broadcastShape release];
|
||||
[expandAxes release];
|
||||
@@ -409,11 +406,16 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
return broadcastTensor;
|
||||
}
|
||||
|
||||
MPSGraphTensor* asStridedLayer_pattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) {
|
||||
MPSGraphTensor* asStridedLayer_pattern(MPSGraph* graph,
|
||||
MPSGraphTensor* inputTensor,
|
||||
int dstRank,
|
||||
const IntArrayRef& dstSizes,
|
||||
const IntArrayRef& dstStrides,
|
||||
int offset) {
|
||||
if (!dstRank)
|
||||
return nil;
|
||||
|
||||
MPSGraphTensor *outputTensor = nil;
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
outputTensor = asStridedLayer_expandDimsPattern(graph, inputTensor, dstRank, dstSizes, dstStrides, offset);
|
||||
if (!outputTensor)
|
||||
outputTensor = asStridedLayer_reshapePattern(graph, inputTensor, dstRank, dstSizes, dstStrides, offset);
|
||||
@@ -423,8 +425,7 @@ MPSGraphTensor* asStridedLayer_pattern(MPSGraph *graph, MPSGraphTensor *inputTen
|
||||
return outputTensor;
|
||||
}
|
||||
|
||||
static
|
||||
std::vector<int64_t> getViewShape(const Tensor& src, MPSShape *mpsShape, const bool squeeze) {
|
||||
static std::vector<int64_t> getViewShape(const Tensor& src, MPSShape* mpsShape, const bool squeeze) {
|
||||
bool hasMPSShape = (mpsShape != nil);
|
||||
std::vector<int64_t> src_view_shape;
|
||||
if (hasMPSShape) {
|
||||
@@ -459,7 +460,6 @@ std::vector<int64_t> getViewShape(const Tensor& src, MPSShape *mpsShape, const b
|
||||
return src_view_shape;
|
||||
}
|
||||
|
||||
|
||||
std::vector<int64_t> getSqueezedBaseShape(const Tensor& src, IntArrayRef shape) {
|
||||
std::vector<int64_t> src_base_shape;
|
||||
for (const auto i : c10::irange(shape.size())) {
|
||||
@@ -471,8 +471,7 @@ std::vector<int64_t> getSqueezedBaseShape(const Tensor& src, IntArrayRef shape)
|
||||
return src_base_shape;
|
||||
}
|
||||
|
||||
|
||||
bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
|
||||
bool canSliceViewTensor(const Tensor& src, MPSShape* mpsShape) {
|
||||
if (!src.is_contiguous()) {
|
||||
return false;
|
||||
}
|
||||
@@ -486,23 +485,23 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto i: c10::irange(src_ndim_base)) {
|
||||
if (src_view_shape[i] > src_base_shape[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (const auto i : c10::irange(src_ndim_base)) {
|
||||
if (src_view_shape[i] > src_base_shape[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType) {
|
||||
MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape* mpsShape, const MPSDataType mpsDataType) {
|
||||
IntArrayRef src_base_shape = getIMPSAllocator()->getBufferShape(src.storage().data());
|
||||
size_t src_ndim_base = src_base_shape.size();
|
||||
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape, false);
|
||||
size_t src_ndim_view = src_view_shape.size();
|
||||
|
||||
MPSNDArray *srcTensorNDArrayView = nil;
|
||||
MPSNDArrayDescriptor *srcTensorNDArrayDesc = nil;
|
||||
MPSNDArray *srcTensorNDArray = nil;
|
||||
MPSNDArray* srcTensorNDArrayView = nil;
|
||||
MPSNDArrayDescriptor* srcTensorNDArrayDesc = nil;
|
||||
MPSNDArray* srcTensorNDArray = nil;
|
||||
id<MTLCommandBuffer> commandBuffer = getCurrentMPSStream()->commandBuffer();
|
||||
int64_t base_idx = 0;
|
||||
|
||||
@@ -537,19 +536,21 @@ MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mp
|
||||
}
|
||||
|
||||
int64_t sliceOffset = src.storage_offset() / view_numel;
|
||||
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - firstDimToSlice
|
||||
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice])}];
|
||||
[srcTensorNDArrayDesc
|
||||
sliceDimension:src_ndim_base - 1 - firstDimToSlice
|
||||
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice])}];
|
||||
|
||||
// Slice any remaining dimensions
|
||||
for (const auto crtSliceOffset: c10::irange(firstDimToSlice + 1, src_base_shape.size())) {
|
||||
for (const auto crtSliceOffset : c10::irange(firstDimToSlice + 1, src_base_shape.size())) {
|
||||
if (src_view_shape[crtSliceOffset] != src_base_shape[crtSliceOffset]) {
|
||||
if (crtSliceOffset == src_base_shape.size() - 1) {
|
||||
sliceOffset = src.storage_offset() % src_base_shape[src_base_shape.size() - 1];
|
||||
} else {
|
||||
sliceOffset = (src.storage_offset() % view_numel) / (view_numel / src_base_shape[crtSliceOffset]);
|
||||
}
|
||||
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - crtSliceOffset
|
||||
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[crtSliceOffset])}];
|
||||
[srcTensorNDArrayDesc
|
||||
sliceDimension:src_ndim_base - 1 - crtSliceOffset
|
||||
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[crtSliceOffset])}];
|
||||
}
|
||||
}
|
||||
srcTensorNDArrayView = [srcTensorNDArray arrayViewWithCommandBuffer:commandBuffer
|
||||
@@ -559,13 +560,15 @@ MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mp
|
||||
return [[[MPSGraphTensorData alloc] initWithMPSNDArray:srcTensorNDArrayView] autorelease];
|
||||
}
|
||||
|
||||
static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const IntArrayRef& size,
const IntArrayRef& stride, int64_t offset,
const IntArrayRef& base_shape, bool needsScatter,
MPSGraphTensor* updatesTensor)
{
static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph,
const IntArrayRef& size,
const IntArrayRef& stride,
int64_t offset,
const IntArrayRef& base_shape,
bool needsScatter,
MPSGraphTensor* updatesTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor *outputTensor = nil;
MPSGraphTensor* outputTensor = nil;
const size_t shape_size = size.size();

@autoreleasepool {
@@ -575,87 +578,74 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const In
|
||||
TORCH_CHECK(size[i] <= int_max);
|
||||
sizeArray[i] = static_cast<int32_t>(size[i]);
|
||||
}
|
||||
NSData* shapeData = [NSData dataWithBytes: sizeArray.data()
|
||||
length: shape_size * sizeof(int32_t)];
|
||||
MPSGraphTensor* shapeTensor = [mpsGraph constantWithData: shapeData
|
||||
shape: @[[NSNumber numberWithUnsignedInteger: shape_size]]
|
||||
dataType: MPSDataTypeInt32];
|
||||
NSData* shapeData = [NSData dataWithBytes:sizeArray.data() length:shape_size * sizeof(int32_t)];
|
||||
MPSGraphTensor* shapeTensor = [mpsGraph constantWithData:shapeData
|
||||
shape:@[ [NSNumber numberWithUnsignedInteger:shape_size] ]
|
||||
dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* indicesTensor = nil;
|
||||
// create stride Tensors for each rank of the input tensor
|
||||
for (int i = 0; i < shape_size; i++) {
|
||||
MPSGraphTensor* rangeTensor = [mpsGraph coordinateAlongAxis: (-i - 1)
|
||||
withShapeTensor: shapeTensor
|
||||
name: nil];
|
||||
MPSGraphTensor* rangeTensor = [mpsGraph coordinateAlongAxis:(-i - 1) withShapeTensor:shapeTensor name:nil];
|
||||
MPSGraphTensor* strideTensor = cachedGraph->strideTensors[shape_size - i - 1];
|
||||
MPSGraphTensor* indexTensor = [mpsGraph multiplicationWithPrimaryTensor: rangeTensor
|
||||
secondaryTensor: strideTensor
|
||||
name: nil];
|
||||
MPSGraphTensor* indexTensor = [mpsGraph multiplicationWithPrimaryTensor:rangeTensor
|
||||
secondaryTensor:strideTensor
|
||||
name:nil];
|
||||
if (!indicesTensor) {
|
||||
indicesTensor = indexTensor;
|
||||
} else {
|
||||
indicesTensor = [mpsGraph additionWithPrimaryTensor: indexTensor
|
||||
secondaryTensor: indicesTensor
|
||||
name: nil];
|
||||
indicesTensor = [mpsGraph additionWithPrimaryTensor:indexTensor secondaryTensor:indicesTensor name:nil];
|
||||
}
|
||||
}
|
||||
|
||||
indicesTensor = [mpsGraph additionWithPrimaryTensor: indicesTensor
|
||||
secondaryTensor: cachedGraph->storageOffsetTensor
|
||||
name: nil];
|
||||
MPSGraphTensor *inputTensor = cachedGraph->inputTensor;
|
||||
indicesTensor = [mpsGraph additionWithPrimaryTensor:indicesTensor
|
||||
secondaryTensor:cachedGraph->storageOffsetTensor
|
||||
name:nil];
|
||||
MPSGraphTensor* inputTensor = cachedGraph->inputTensor;
|
||||
|
||||
if (!needsScatter) {
|
||||
MPSGraphTensor *outputTensor = asStridedLayer_pattern(mpsGraph, inputTensor, shape_size, size, stride, offset);
|
||||
MPSGraphTensor* outputTensor = asStridedLayer_pattern(mpsGraph, inputTensor, shape_size, size, stride, offset);
|
||||
if (outputTensor) {
|
||||
return outputTensor;
|
||||
}
|
||||
}
|
||||
|
||||
MPSGraphTensor *reshapedInputTensor = [mpsGraph reshapeTensor: inputTensor
|
||||
withShape: @[@-1]
|
||||
name: nil];
|
||||
MPSGraphTensor *reshapedIndicesTensor = [mpsGraph reshapeTensor: indicesTensor
|
||||
withShape: @[@-1]
|
||||
name: nil];
|
||||
MPSGraphTensor* reshapedInputTensor = [mpsGraph reshapeTensor:inputTensor withShape:@[ @-1 ] name:nil];
|
||||
MPSGraphTensor* reshapedIndicesTensor = [mpsGraph reshapeTensor:indicesTensor withShape:@[ @-1 ] name:nil];
|
||||
if (needsScatter) {
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wobjc-method-access"
|
||||
MPSGraphTensor* scatteredTensor = [mpsGraph scatterAlongAxis: (NSInteger) 0
|
||||
withDataTensor: reshapedInputTensor
|
||||
updatesTensor: updatesTensor
|
||||
indicesTensor: reshapedIndicesTensor
|
||||
mode: MPSGraphScatterModeSet
|
||||
name: nil];
|
||||
MPSGraphTensor* scatteredTensor = [mpsGraph scatterAlongAxis:(NSInteger)0
|
||||
withDataTensor:reshapedInputTensor
|
||||
updatesTensor:updatesTensor
|
||||
indicesTensor:reshapedIndicesTensor
|
||||
mode:MPSGraphScatterModeSet
|
||||
name:nil];
|
||||
#pragma clang diagnostic pop
|
||||
outputTensor = [mpsGraph reshapeTensor: scatteredTensor
|
||||
withShape: getMPSShape(base_shape)
|
||||
name: nil];
|
||||
outputTensor = [mpsGraph reshapeTensor:scatteredTensor withShape:getMPSShape(base_shape) name:nil];
|
||||
} else {
|
||||
// Call gather to coalesce the needed values. Result will be of same shape as flattened indices tensor
|
||||
MPSGraphTensor *gatheredTensor = [mpsGraph gatherWithUpdatesTensor: reshapedInputTensor
|
||||
indicesTensor: reshapedIndicesTensor
|
||||
axis: 0
|
||||
batchDimensions: 0
|
||||
name: nil];
|
||||
MPSGraphTensor* gatheredTensor = [mpsGraph gatherWithUpdatesTensor:reshapedInputTensor
|
||||
indicesTensor:reshapedIndicesTensor
|
||||
axis:0
|
||||
batchDimensions:0
|
||||
name:nil];
|
||||
// Reshape the data to desired size
|
||||
outputTensor = [mpsGraph reshapeTensor: gatheredTensor
|
||||
withShapeTensor: shapeTensor
|
||||
name: nil];
|
||||
outputTensor = [mpsGraph reshapeTensor:gatheredTensor withShapeTensor:shapeTensor name:nil];
|
||||
}
|
||||
}
|
||||
return outputTensor;
|
||||
}
|
||||
|
||||
static IntArrayRef updateTensorBaseShape(const Tensor& self)
|
||||
{
|
||||
static IntArrayRef updateTensorBaseShape(const Tensor& self) {
|
||||
IntArrayRef base_shape = getIMPSAllocator()->getBufferShape(self.storage().data());
|
||||
// if there's no base_shape stored in MPSAllocator, then infer it from tensor's size and store it
|
||||
if (base_shape.size() == 0) {
|
||||
// IntArrayRef wouldn't own the data, so we use a static storage
|
||||
static const int64_t shape_1d = 1;
|
||||
// self.sizes().size() could be zero
|
||||
base_shape = self.sizes().size() ? self.sizes() :
|
||||
((self.is_view() && self._base().sizes().size()) ? self._base().sizes() : IntArrayRef(&shape_1d, 1));
|
||||
base_shape = self.sizes().size()
|
||||
? self.sizes()
|
||||
: ((self.is_view() && self._base().sizes().size()) ? self._base().sizes() : IntArrayRef(&shape_1d, 1));
|
||||
|
||||
// base_shape will be retained in MPSAllocator until buffer gets recycled
|
||||
if (self.storage().data())
|
||||
@@ -681,49 +671,53 @@ static IntArrayRef updateTensorBaseShape(const Tensor& self)
|
||||
// | / \ |
|
||||
// | / \ |
|
||||
// NonView T NonView T
|
||||
static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &updates, IntArrayRef size, IntArrayRef stride, int64_t storage_offset, bool needsScatter)
|
||||
{
|
||||
static ViewCachedGraph* createViewGraph(const Tensor& self,
|
||||
const Tensor& updates,
|
||||
IntArrayRef size,
|
||||
IntArrayRef stride,
|
||||
int64_t storage_offset,
|
||||
bool needsScatter) {
|
||||
IntArrayRef base_shape = updateTensorBaseShape(self);
|
||||
|
||||
@autoreleasepool {
|
||||
string key = getStridedKey(self.scalar_type(), updates.scalar_type(), base_shape, size, stride, storage_offset, needsScatter);
|
||||
string key = getStridedKey(
|
||||
self.scalar_type(), updates.scalar_type(), base_shape, size, stride, storage_offset, needsScatter);
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
ViewCachedGraph* cachedGraph = static_cast<ViewCachedGraph *>(cache_->LookUp(key));
|
||||
ViewCachedGraph* cachedGraph = static_cast<ViewCachedGraph*>(cache_->LookUp(key));
|
||||
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = static_cast<ViewCachedGraph *>(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
|
||||
ViewCachedGraph *newCachedGraph = nil;
|
||||
cachedGraph = static_cast<ViewCachedGraph*>(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
ViewCachedGraph* newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
MPSGraphTensor* updatesTensor = nil;
|
||||
newCachedGraph = new ViewCachedGraph(mpsGraph);
|
||||
// Workaround for MPSShaderLibrary bug in macOS Monterey
|
||||
// This is fixed in macOS Ventura
|
||||
auto inputType = getMPSScalarType(self.scalar_type());
|
||||
if (inputType == MPSDataTypeUInt8 || (inputType == MPSDataTypeBool && !is_macos_13_or_newer())) {
|
||||
inputType = MPSDataTypeInt8;
|
||||
}
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
MPSGraphTensor* updatesTensor = nil;
|
||||
newCachedGraph = new ViewCachedGraph(mpsGraph);
|
||||
// Workaround for MPSShaderLibrary bug in macOS Monterey
|
||||
// This is fixed in macOS Ventura
|
||||
auto inputType = getMPSScalarType(self.scalar_type());
|
||||
if (inputType == MPSDataTypeUInt8 || (inputType == MPSDataTypeBool && !is_macos_13_or_newer())) {
|
||||
inputType = MPSDataTypeInt8;
|
||||
}
|
||||
|
||||
// Self is the input tensor we are creating view of
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(base_shape));
|
||||
newCachedGraph->storageOffsetTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@1]);
|
||||
for (int i = 0; i < size.size(); i++) {
|
||||
newCachedGraph->strideTensors.push_back(mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@1]));
|
||||
// Self is the input tensor we are creating view of
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(base_shape));
|
||||
newCachedGraph->storageOffsetTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @1 ]);
|
||||
for (int i = 0; i < size.size(); i++) {
|
||||
newCachedGraph->strideTensors.push_back(mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @1 ]));
|
||||
}
|
||||
if (needsScatter) {
|
||||
auto updatesType = getMPSScalarType(updates.scalar_type());
|
||||
if (updatesType == MPSDataTypeUInt8 || (updatesType == MPSDataTypeBool && !is_macos_13_or_newer())) {
|
||||
updatesType = MPSDataTypeInt8;
|
||||
}
|
||||
if (needsScatter) {
|
||||
auto updatesType = getMPSScalarType(updates.scalar_type());
|
||||
if (updatesType == MPSDataTypeUInt8 || (updatesType == MPSDataTypeBool && !is_macos_13_or_newer())) {
|
||||
updatesType = MPSDataTypeInt8;
|
||||
}
|
||||
newCachedGraph->updatesTensor = mpsGraphRankedPlaceHolder(mpsGraph, updatesType, getMPSShape(self.numel()));
|
||||
updatesTensor = newCachedGraph->updatesTensor;
|
||||
if (inputType != updatesType) {
|
||||
updatesTensor = [mpsGraph castTensor:updatesTensor
|
||||
toType:inputType
|
||||
name:@"castUpdatesTensor"];
|
||||
}
|
||||
newCachedGraph->updatesTensor = mpsGraphRankedPlaceHolder(mpsGraph, updatesType, getMPSShape(self.numel()));
|
||||
updatesTensor = newCachedGraph->updatesTensor;
|
||||
if (inputType != updatesType) {
|
||||
updatesTensor = [mpsGraph castTensor:updatesTensor toType:inputType name:@"castUpdatesTensor"];
|
||||
}
|
||||
newCachedGraph->outputTensor = chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter, updatesTensor);
|
||||
}
|
||||
newCachedGraph->outputTensor =
|
||||
chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter, updatesTensor);
|
||||
}
|
||||
return newCachedGraph;
|
||||
}));
|
||||
@@ -732,11 +726,7 @@ static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &update
|
||||
}
|
||||
}
|
||||
|
||||
static
std::string getGatherScatterFunctionName(
ScalarType scalarType,
int64_t dim,
bool needsScatter) {
static std::string getGatherScatterFunctionName(ScalarType scalarType, int64_t dim, bool needsScatter) {
std::string kernelName = needsScatter ? "scatter" : "gather";
return kernelName + "_kernel_" + std::to_string(dim == 0 ? 1 : dim);
}
@@ -744,14 +734,14 @@ std::string getGatherScatterFunctionName(
const std::string& getGatherScatterScalarType(const Tensor& t) {
auto scalar_type = t.scalar_type();
static std::unordered_map<c10::ScalarType, std::string> scalarToMetalType = {
{c10::ScalarType::Float, "float"},
{c10::ScalarType::Half, "half"},
{c10::ScalarType::Long, "long"},
{c10::ScalarType::Int, "int"},
{c10::ScalarType::Short, "short"},
{c10::ScalarType::Char, "char"},
{c10::ScalarType::Byte, "uchar"},
{c10::ScalarType::Bool, "bool"},
{c10::ScalarType::Float, "float"},
{c10::ScalarType::Half, "half"},
{c10::ScalarType::Long, "long"},
{c10::ScalarType::Int, "int"},
{c10::ScalarType::Short, "short"},
{c10::ScalarType::Char, "char"},
{c10::ScalarType::Byte, "uchar"},
{c10::ScalarType::Bool, "bool"},
};

auto it = scalarToMetalType.find(scalar_type);
@@ -759,24 +749,30 @@ const std::string& getGatherScatterScalarType(const Tensor& t) {
return it->second;
}

static
|
||||
id<MTLLibrary> compileGatherScatterOpsLibrary(id<MTLDevice> device,
|
||||
const std::string& dtypeSrc,
|
||||
const std::string& dtypeDst,
|
||||
bool needsScatter) {
|
||||
static id<MTLLibrary> compileGatherScatterOpsLibrary(id<MTLDevice> device,
|
||||
const std::string& dtypeSrc,
|
||||
const std::string& dtypeDst,
|
||||
bool needsScatter) {
|
||||
auto key = std::to_string(needsScatter) + dtypeSrc + dtypeDst;
|
||||
static std::unordered_map<std::string, id<MTLLibrary>> _libCache;
|
||||
auto it = _libCache.find(key);
|
||||
if (it != _libCache.end()) {
|
||||
return it->second;
|
||||
}
|
||||
NSError *error = nil;
|
||||
MTLCompileOptions *options = [[MTLCompileOptions new] autorelease];
|
||||
[options setLanguageVersion: MTLLanguageVersion2_3];
|
||||
auto gatherScatterLib = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(needsScatter ? SCATTER_OPS_TEMPLATE : GATHER_OPS_TEMPLATE, dtypeSrc, dtypeDst).c_str()]
|
||||
options:options
|
||||
error:&error];
|
||||
TORCH_CHECK(gatherScatterLib != nil && error == nil, "Failed to compile gather-scatter library, error: ", [[error description] UTF8String]);
|
||||
NSError* error = nil;
|
||||
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
|
||||
[options setLanguageVersion:MTLLanguageVersion2_3];
|
||||
auto gatherScatterLib =
|
||||
[device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(needsScatter ? SCATTER_OPS_TEMPLATE
|
||||
: GATHER_OPS_TEMPLATE,
|
||||
dtypeSrc,
|
||||
dtypeDst)
|
||||
.c_str()]
|
||||
options:options
|
||||
error:&error];
|
||||
TORCH_CHECK(gatherScatterLib != nil && error == nil,
|
||||
"Failed to compile gather-scatter library, error: ",
|
||||
[[error description] UTF8String]);
|
||||
_libCache[key] = gatherScatterLib;
|
||||
return gatherScatterLib;
|
||||
}
|
||||
@@ -790,15 +786,16 @@ static id<MTLComputePipelineState> getPipelineState(id<MTLDevice> device,
|
||||
static std::unordered_map<std::string, id<MTLComputePipelineState>> _mtlPipelineCache;
|
||||
auto it = _mtlPipelineCache.find(key);
|
||||
if (it != _mtlPipelineCache.end()) {
|
||||
return it->second;
|
||||
return it->second;
|
||||
}
|
||||
|
||||
NSError *error = nil;
|
||||
NSError* error = nil;
|
||||
id<MTLLibrary> library = compileGatherScatterOpsLibrary(device, dtypeSrc, dtypeDst, needsScatter);
|
||||
id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
|
||||
TORCH_CHECK(func, "Failed to load the Metal Shader function: ", kernel);
|
||||
id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:func error:&error];
|
||||
TORCH_CHECK(pso != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
|
||||
TORCH_CHECK(
|
||||
pso != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
|
||||
_mtlPipelineCache[key] = pso;
|
||||
return pso;
|
||||
}
|
||||
@@ -814,8 +811,8 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
|
||||
}
|
||||
|
||||
if (src.dim() > 5) {
|
||||
ViewCachedGraph* cachedGraph = createViewGraph(src, dst, src.sizes(), src.strides(),
|
||||
src.storage_offset(), /*needsScatter*/ false);
|
||||
ViewCachedGraph* cachedGraph =
|
||||
createViewGraph(src, dst, src.sizes(), src.strides(), src.storage_offset(), /*needsScatter*/ false);
|
||||
return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false);
|
||||
}
|
||||
|
||||
@@ -824,7 +821,7 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
|
||||
uint32_t numThreads = output.numel();
|
||||
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
dispatch_sync(mpsStream->queue(), ^(){
|
||||
dispatch_sync(mpsStream->queue(), ^() {
|
||||
id<MTLComputeCommandEncoder> computeEncoder = [mpsStream->commandBuffer() computeCommandEncoder];
|
||||
std::string functionName = getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/false);
|
||||
id<MTLComputePipelineState> gatherPSO = getPipelineState(MPSDevice::getInstance()->device(),
|
||||
@@ -846,7 +843,7 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
|
||||
}
|
||||
}
|
||||
|
||||
[computeEncoder setComputePipelineState: gatherPSO];
|
||||
[computeEncoder setComputePipelineState:gatherPSO];
|
||||
[computeEncoder setBuffer:getMTLBufferStorage(src) offset:src.storage_offset() * src.element_size() atIndex:0];
|
||||
[computeEncoder setBuffer:outputBuffer offset:outputStorageOffset atIndex:1];
|
||||
[computeEncoder setBytes:&src_sizes[0] length:sizeof(uint32_t) * kernel_size atIndex:2];
|
||||
@@ -856,7 +853,7 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
|
||||
MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
|
||||
NSUInteger threadsPerThreadgroup_ = gatherPSO.maxTotalThreadsPerThreadgroup;
|
||||
if (threadsPerThreadgroup_ > numThreads) {
|
||||
threadsPerThreadgroup_ = numThreads;
|
||||
threadsPerThreadgroup_ = numThreads;
|
||||
}
|
||||
|
||||
MTLSize threadsPerThreadgroup = MTLSizeMake(threadsPerThreadgroup_, 1, 1);
|
||||
@@ -868,11 +865,14 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
|
||||
return (dst.has_storage()) ? dst : output;
|
||||
}
|
||||
|
||||
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output){
|
||||
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output) {
|
||||
if (output.dim() > 5) {
|
||||
ViewCachedGraph* cachedGraph = createViewGraph(output.is_complex() ? at::view_as_real(output) : output,
|
||||
src, output.sizes(), output.strides(),
|
||||
output.storage_offset(), /*needsScatter*/ true);
|
||||
ViewCachedGraph* cachedGraph = createViewGraph(output.is_complex() ? at::view_as_real(output) : output,
|
||||
src,
|
||||
output.sizes(),
|
||||
output.strides(),
|
||||
output.storage_offset(),
|
||||
/*needsScatter*/ true);
|
||||
return runViewGraph(cachedGraph, src, output, /*needsScatter*/ true);
|
||||
}
|
||||
if (src.numel() == 0 || output.numel() == 0) {
|
||||
@@ -884,11 +884,12 @@ Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output){
|
||||
uint32_t numThreads = src.numel();
|
||||
int64_t outputStorageOffset = output.storage_offset() * output.element_size();
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
dispatch_sync(mpsStream->queue(), ^(){
|
||||
dispatch_sync(mpsStream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
|
||||
id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
|
||||
std::string functionName = getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/true);
|
||||
std::string functionName =
|
||||
getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/true);
|
||||
id<MTLComputePipelineState> scatterPSO = getPipelineState(MPSDevice::getInstance()->device(),
|
||||
functionName,
|
||||
getGatherScatterScalarType(src),
|
||||
@@ -908,7 +909,7 @@ Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output){
|
||||
}
|
||||
}
|
||||
|
||||
[computeEncoder setComputePipelineState: scatterPSO];
|
||||
[computeEncoder setComputePipelineState:scatterPSO];
|
||||
[computeEncoder setBuffer:sourceBuffer offset:src.storage_offset() * src.element_size() atIndex:0];
|
||||
[computeEncoder setBuffer:outputBuffer offset:outputStorageOffset atIndex:1];
|
||||
[computeEncoder setBytes:&output_sizes[0] length:sizeof(uint32_t) * kernel_size atIndex:2];
|
||||
@@ -934,16 +935,21 @@ Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output){
} // namespace mps

// implementation of as_strided() op
Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, IntArrayRef stride, c10::optional<int64_t> storage_offset_) {
Tensor as_strided_tensorimpl_mps(const Tensor& self,
IntArrayRef size,
IntArrayRef stride,
c10::optional<int64_t> storage_offset_) {
auto storage_offset = storage_offset_.value_or(self.storage_offset());
auto result = detail::make_tensor<TensorImpl>(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
auto result =
detail::make_tensor<TensorImpl>(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
setStrided(result, size, stride, storage_offset);

// creating the view graph will be deferred until gatherViewTensor() or scatterViewTensor() are called.
// In as_strided, we just update the base shape of the buffer in order to retrieve it later
// when we create/run the view graph.
IntArrayRef base_shape = mps::updateTensorBaseShape(self);
TORCH_INTERNAL_ASSERT(base_shape.size() > 0, "Failed to update the base shape of tensor's buffer at ", self.storage().data());
TORCH_INTERNAL_ASSERT(
base_shape.size() > 0, "Failed to update the base shape of tensor's buffer at ", self.storage().data());

return result;
}

@@ -10,8 +10,7 @@ NS_ASSUME_NONNULL_BEGIN

+ (NSString*)cacheDirectory;

+ (BOOL)compileModel:(const std::string&)modelSpecs
modelID:(const std::string&)modelID;
+ (BOOL)compileModel:(const std::string&)modelSpecs modelID:(const std::string&)modelID;

+ (nullable MLModel*)loadModel:(const std::string&)modelID
backend:(const std::string&)backend