Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)

[BE][MPS] Add MPS to clang format (#96562)

I'm getting tired of asking for a space after `if` and all that jazz, so let the linter do it. This adds a section for the Objective-C language, where the column width is extended to 120 characters and `AlignAfterOpenBracket` is set to `Align`. All `.mm` changes in this PR were made by running the linter as follows:

```
lintrunner --take CLANGFORMAT --all-files --apply-patches
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96562
Approved by: https://github.com/seemethere, https://github.com/janeyx99, https://github.com/ZainRizvi, https://github.com/izaitsevfb, https://github.com/PaliC, https://github.com/albanD

Committed by PyTorch MergeBot
parent a7689e73f6
commit 4242e698a3
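
For illustration, here is a minimal Objective-C++ sketch of the kind of reformatting the new rules produce (the snippet is hypothetical and not part of this PR; the function and strings are made up). With `ColumnLimit: 120` and `AlignAfterOpenBracket: Align`, message-send arguments that no longer fit on one line are kept aligned after the opening `[`, matching the `.mm` hunks further below.

```
// Hypothetical example, not from this PR: shows the style enforced by the new ObjC
// clang-format section (ColumnLimit: 120, AlignAfterOpenBracket: Align).
#import <Foundation/Foundation.h>

static NSString* greeting(NSString* name, NSString* backend) {
  // Hand-wrapped code such as:
  //   NSString* s = [NSString stringWithFormat: @"Hello %@, welcome to the %@ backend",
  //       name, backend];
  // is rewritten by `lintrunner --take CLANGFORMAT --apply-patches` roughly as below,
  // with the continuation arguments aligned under the first argument after the bracket:
  NSString* s = [NSString stringWithFormat:@"Hello %@, welcome to the %@ backend of PyTorch; happy hacking on Metal!",
                                           name,
                                           backend];
  return s;
}
```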
@ -60,9 +6,6 @@ MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
@ -85,4 +82,11 @@ SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
---
Language: ObjC
ColumnLimit: 120
AlignAfterOpenBracket: Align
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
...
@ -49,6 +49,8 @@ init_command = [
code = 'CLANGFORMAT'
include_patterns = [
    'aten/src/ATen/*.h',
    'aten/src/ATen/mps/**/*.mm',
    'aten/src/ATen/native/mps/**/*.mm',
    'aten/src/ATen/native/vulkan/**/*.h',
    'aten/src/ATen/native/vulkan/**/*.cpp',
    'c10/**/*.h',
@ -1,10 +1,10 @@
// Copyright © 2022 Apple Inc.

#include <ATen/CPUFunctions.h>
#include <ATen/EmptyTensor.h>
#include <ATen/mps/MPSAllocator.h>
#include <c10/core/Allocator.h>
#include <c10/core/Storage.h>
#include <ATen/CPUFunctions.h>
#include <ATen/EmptyTensor.h>
#include <iostream>

namespace at {
@ -23,21 +23,22 @@ void MPSHeapAllocatorImpl::init_allocator() {
|
||||
m_debug_verbosity = verbosity_str ? strtol(verbosity_str, nullptr, 0) : DebugVerbosity::SILENT;
|
||||
|
||||
static const char* high_watermark_ratio_str = getenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO");
|
||||
const double high_watermark_ratio = high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) :
|
||||
default_high_watermark_ratio;
|
||||
const double high_watermark_ratio =
|
||||
high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) : default_high_watermark_ratio;
|
||||
setHighWatermarkRatio(high_watermark_ratio);
|
||||
|
||||
const double default_low_watermark_ratio = m_device.hasUnifiedMemory ? default_low_watermark_ratio_unified :
|
||||
default_low_watermark_ratio_discrete;
|
||||
const double default_low_watermark_ratio =
|
||||
m_device.hasUnifiedMemory ? default_low_watermark_ratio_unified : default_low_watermark_ratio_discrete;
|
||||
static const char* low_watermark_ratio_str = getenv("PYTORCH_MPS_LOW_WATERMARK_RATIO");
|
||||
const double low_watermark_ratio = low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio;
|
||||
const double low_watermark_ratio =
|
||||
low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio;
|
||||
setLowWatermarkRatio(low_watermark_ratio);
|
||||
}
|
||||
|
||||
void MPSHeapAllocatorImpl::setHighWatermarkRatio(double ratio) {
|
||||
TORCH_CHECK(ratio >= 0.0 && ratio <= default_high_watermark_upper_bound, "invalid high watermark ratio ", ratio);
|
||||
m_max_total_allowed_size = (ratio == 0.0) ? std::numeric_limits<size_t>::max() :
|
||||
static_cast<size_t>(ratio * (double)max_device_size());
|
||||
m_max_total_allowed_size =
|
||||
(ratio == 0.0) ? std::numeric_limits<size_t>::max() : static_cast<size_t>(ratio * (double)max_device_size());
|
||||
if (m_debug_verbosity & DebugVerbosity::PROFILING) {
|
||||
std::cerr << "\nHigh watermark memory allocation limit: "
|
||||
<< (ratio == 0.0 ? "unlimited" : format_size(m_max_total_allowed_size)) << "\n";
|
||||
@ -47,11 +48,12 @@ void MPSHeapAllocatorImpl::setHighWatermarkRatio(double ratio) {
|
||||
|
||||
void MPSHeapAllocatorImpl::setLowWatermarkRatio(double ratio) {
|
||||
// used for comparison with lower_watermark_ratio
|
||||
const double high_watermark_limit = m_high_watermark_ratio == 0.0 ? default_high_watermark_upper_bound : m_high_watermark_ratio;
|
||||
const double high_watermark_limit =
|
||||
m_high_watermark_ratio == 0.0 ? default_high_watermark_upper_bound : m_high_watermark_ratio;
|
||||
TORCH_CHECK(ratio >= 0.0 && ratio <= high_watermark_limit, "invalid low watermark ratio ", ratio);
|
||||
// we use this to detect if there's memory pressure
|
||||
m_low_watermark_limit = (ratio == 0.0) ? std::numeric_limits<size_t>::max() :
|
||||
static_cast<size_t>(ratio * (double)max_device_size());
|
||||
m_low_watermark_limit =
|
||||
(ratio == 0.0) ? std::numeric_limits<size_t>::max() : static_cast<size_t>(ratio * (double)max_device_size());
|
||||
if (m_debug_verbosity & DebugVerbosity::PROFILING) {
|
||||
std::cerr << "Low watermark memory allocation limit: "
|
||||
<< (ratio == 0.0 ? "unlimited" : format_size(m_low_watermark_limit)) << "\n";
|
||||
@ -69,10 +71,8 @@ HeapBlock* MPSHeapAllocatorImpl::get_free_heap(AllocParams& params) {
|
||||
heap_block = HeapBlock::createHeapBlock(params, pool.device, pool.usage);
|
||||
if (heap_block) {
|
||||
if (m_debug_verbosity & DebugVerbosity::ALLOCATIONS) {
|
||||
std::cerr << "\nAllocated "
|
||||
<< ((pool.usage & UsageFlags::SHARED) ? "shared " : "private ")
|
||||
<< " heap #" << heap_block->heap_id
|
||||
<< " of size " << format_size(heap_block->size.total)
|
||||
std::cerr << "\nAllocated " << ((pool.usage & UsageFlags::SHARED) ? "shared " : "private ") << " heap #"
|
||||
<< heap_block->heap_id << " of size " << format_size(heap_block->size.total)
|
||||
<< " (#heaps: " << (pool.heaps.size() + 1)
|
||||
<< ", current allocated: " << format_size(current_allocated_size()) << ")\n";
|
||||
}
|
||||
@ -110,16 +110,13 @@ bool MPSHeapAllocatorImpl::alloc_buffer(AllocParams& params) {
|
||||
|
||||
if ((m_debug_verbosity & DebugVerbosity::ALLOCATIONS) &&
|
||||
(!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
|
||||
std::cerr << "Allocated "
|
||||
<< ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "")
|
||||
<< " buffer #" << params.buffer_block->buf_id
|
||||
<< " of size " << format_size(params.size())
|
||||
<< " at " << params.buffer_block->buffer
|
||||
<< " from heap #" << heap->heap_id
|
||||
std::cerr << "Allocated " << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") << " buffer #"
|
||||
<< params.buffer_block->buf_id << " of size " << format_size(params.size()) << " at "
|
||||
<< params.buffer_block->buffer << " from heap #" << heap->heap_id
|
||||
<< " (requested: " << format_size(params.requested_size)
|
||||
<< ", heap: " << format_size(heap->size.available)
|
||||
<< ", total: " << format_size(m_total_allocated_memory) << ")\n";
|
||||
<< ", heap: " << format_size(heap->size.available) << ", total: " << format_size(m_total_allocated_memory)
|
||||
<< ")\n";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -158,8 +155,8 @@ bool MPSHeapAllocatorImpl::get_free_buffer(AllocParams& params) {
|
||||
// this will skip unnecessary garbage collection as we'll reuse the newly released space
|
||||
params.has_memory_pressure = false;
|
||||
} else if (params.has_memory_pressure) {
|
||||
// the oversized buffer is busy and not reusable at the moment. So release it (and potentially its heap container)
|
||||
// in allocator, and ARC will later free up its backing memory when the busy command buffer finishes.
|
||||
// the oversized buffer is busy and not reusable at the moment. So release it (and potentially its heap
|
||||
// container) in allocator, and ARC will later free up its backing memory when the busy command buffer finishes.
|
||||
release_buffer(buffer_block, true);
|
||||
} else {
|
||||
// only if there's no memory pressure, we'll reuse the oversized buffer
|
||||
@ -177,15 +174,12 @@ bool MPSHeapAllocatorImpl::get_free_buffer(AllocParams& params) {
|
||||
|
||||
if ((m_debug_verbosity & DebugVerbosity::RECYCLES) &&
|
||||
(!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
|
||||
std::cerr << "Reusing "
|
||||
<< ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "")
|
||||
<< " buffer #" << params.buffer_block->buf_id
|
||||
<< " of size " << format_size(params.buffer_block->size)
|
||||
<< " at " << params.buffer_block->buffer
|
||||
<< " (requested: " << format_size(params.requested_size)
|
||||
<< ", use#: " << params.buffer_block->use_count + 1
|
||||
<< ", retain#: " << params.buffer_block->retainCount() << ")\n";
|
||||
std::cerr << "Reusing " << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") << " buffer #"
|
||||
<< params.buffer_block->buf_id << " of size " << format_size(params.buffer_block->size) << " at "
|
||||
<< params.buffer_block->buffer << " (requested: " << format_size(params.requested_size)
|
||||
<< ", use#: " << params.buffer_block->use_count + 1 << ", retain#: " << params.buffer_block->retainCount()
|
||||
<< ")\n";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -214,7 +208,8 @@ BufferBlock* MPSHeapAllocatorImpl::alloc_buffer_block(size_t size, uint32_t usag
|
||||
alloc_buffer(params) ||
|
||||
// Callbacks might release more memory (eg. by forcing a GC in the host language) thus
|
||||
// we can retry getting a free buffer in the pool, before trying to alloc again.
|
||||
(trigger_memory_callbacks(nullptr, IMpsAllocatorCallback::EventType::ALLOCATION_FAILED) && get_free_buffer(params)) ||
|
||||
(trigger_memory_callbacks(nullptr, IMpsAllocatorCallback::EventType::ALLOCATION_FAILED) &&
|
||||
get_free_buffer(params)) ||
|
||||
// Free enough available cached blocks to satisfy alloc and retry alloc.
|
||||
(release_available_cached_buffers(params) && alloc_buffer(params)) ||
|
||||
// Free all cached buffers and retry alloc.
|
||||
@ -229,16 +224,30 @@ BufferBlock* MPSHeapAllocatorImpl::alloc_buffer_block(size_t size, uint32_t usag
|
||||
// chunk of requested size couldn't be found.
|
||||
if (!block_found || !buffer_block) {
|
||||
if (m_high_watermark_ratio > 0.0) {
|
||||
TORCH_CHECK(false, "MPS backend out of memory (MPS allocated: ", format_size(m_total_allocated_memory),
|
||||
", other allocations: ", format_size(current_allocated_size() - m_total_allocated_memory),
|
||||
", max allowed: ", format_size(m_max_total_allowed_size), "). Tried to allocate ", format_size(alloc_size),
|
||||
" on ", ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"MPS backend out of memory (MPS allocated: ",
|
||||
format_size(m_total_allocated_memory),
|
||||
", other allocations: ",
|
||||
format_size(current_allocated_size() - m_total_allocated_memory),
|
||||
", max allowed: ",
|
||||
format_size(m_max_total_allowed_size),
|
||||
"). Tried to allocate ",
|
||||
format_size(alloc_size),
|
||||
" on ",
|
||||
((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
|
||||
" pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).");
|
||||
} else {
|
||||
TORCH_CHECK(false, "MPS backend out of memory (MPS allocated: ", format_size(m_total_allocated_memory),
|
||||
", other allocations: ", format_size(current_allocated_size() - m_total_allocated_memory),
|
||||
"). Tried to allocate ", format_size(alloc_size),
|
||||
" on ", ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"), " pool.");
|
||||
TORCH_CHECK(false,
|
||||
"MPS backend out of memory (MPS allocated: ",
|
||||
format_size(m_total_allocated_memory),
|
||||
", other allocations: ",
|
||||
format_size(current_allocated_size() - m_total_allocated_memory),
|
||||
"). Tried to allocate ",
|
||||
format_size(alloc_size),
|
||||
" on ",
|
||||
((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
|
||||
" pool.");
|
||||
}
|
||||
}
|
||||
buffer_block->in_use = true;
|
||||
@ -284,12 +293,9 @@ bool MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove
|
||||
|
||||
if ((m_debug_verbosity & DebugVerbosity::RELEASES) &&
|
||||
(!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
|
||||
std::cerr << "Released buffer #" << buffer_block->buf_id
|
||||
<< " of size " << format_size(buffer_block->size)
|
||||
<< " from heap #" << heap_block->heap_id
|
||||
<< " (heap size: " << format_size(heap_block->size.available)
|
||||
<< ", use#: " << buffer_block->use_count
|
||||
<< ", retain#: " << retainCount
|
||||
std::cerr << "Released buffer #" << buffer_block->buf_id << " of size " << format_size(buffer_block->size)
|
||||
<< " from heap #" << heap_block->heap_id << " (heap size: " << format_size(heap_block->size.available)
|
||||
<< ", use#: " << buffer_block->use_count << ", retain#: " << retainCount
|
||||
<< ", gc#: " << buffer_block->gc_count << ")\n";
|
||||
}
|
||||
delete buffer_block;
|
||||
@ -298,10 +304,9 @@ bool MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove
|
||||
pool.heaps_pending_update.erase(heap_block);
|
||||
retainCount = heap_block->releaseMTLHeap();
|
||||
if (m_debug_verbosity & DebugVerbosity::RELEASES) {
|
||||
std::cerr << "Released heap #" << heap_block->heap_id
|
||||
<< " of size " << format_size(heap_block->size.total)
|
||||
<< " (current allocated: " << format_size(current_allocated_size())
|
||||
<< ", retain#: " << retainCount << ")\n";
|
||||
std::cerr << "Released heap #" << heap_block->heap_id << " of size " << format_size(heap_block->size.total)
|
||||
<< " (current allocated: " << format_size(current_allocated_size()) << ", retain#: " << retainCount
|
||||
<< ")\n";
|
||||
}
|
||||
delete heap_block;
|
||||
return true;
|
||||
@ -333,13 +338,11 @@ void MPSHeapAllocatorImpl::release_buffers(BufferPool& pool) {
|
||||
return;
|
||||
}
|
||||
if ((m_debug_verbosity & DebugVerbosity::RELEASES)) {
|
||||
std::cerr << "Releasing " << pool.buffers.size()
|
||||
<< " buffers from "
|
||||
std::cerr << "Releasing " << pool.buffers.size() << " buffers from "
|
||||
<< ((pool.usage & UsageFlags::SMALL) ? "small " : "large ")
|
||||
<< ((pool.usage & UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< ((pool.usage & UsageFlags::SCALAR) ? " scalar" : "")
|
||||
<< " pool (total size: " << format_size(pool.allocated_size)
|
||||
<< ", #buffers: " << pool.n_buffers << ")\n";
|
||||
<< " pool (total size: " << format_size(pool.allocated_size) << ", #buffers: " << pool.n_buffers << ")\n";
|
||||
}
|
||||
auto it = pool.buffers.begin();
|
||||
while (it != pool.buffers.end()) {
|
||||
@ -381,10 +384,8 @@ bool MPSHeapAllocatorImpl::release_available_cached_buffers(AllocParams& params)
|
||||
|
||||
bool MPSHeapAllocatorImpl::release_cached_buffers() {
|
||||
if (m_debug_verbosity >= DebugVerbosity::PROFILING) {
|
||||
std::cerr << "Attempting to release cached buffers (MPS allocated: "
|
||||
<< format_size(m_total_allocated_memory)
|
||||
<< ", other allocations: "
|
||||
<< format_size(current_allocated_size() - m_total_allocated_memory) << ")\n";
|
||||
std::cerr << "Attempting to release cached buffers (MPS allocated: " << format_size(m_total_allocated_memory)
|
||||
<< ", other allocations: " << format_size(current_allocated_size() - m_total_allocated_memory) << ")\n";
|
||||
}
|
||||
// before releasing the buffers make sure the command buffer has finished.
|
||||
// we need to release the lock temporarily as synchronizing may cause deadlock with completion handlers.
|
||||
@ -445,11 +446,10 @@ void MPSHeapAllocatorImpl::garbage_collect_cached_buffers(AllocParams& params) {
|
||||
}
|
||||
}
|
||||
if (m_debug_verbosity & DebugVerbosity::RELEASES) {
|
||||
std::cerr << "Garbage collected " << freed_count
|
||||
<< " buffers from large "
|
||||
std::cerr << "Garbage collected " << freed_count << " buffers from large "
|
||||
<< ((pool.usage & UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< " pool (total reclaimed: " << format_size(gc_reclaimed)
|
||||
<< ", #buffers: " << pool.buffers.size() << ")\n";
|
||||
<< " pool (total reclaimed: " << format_size(gc_reclaimed) << ", #buffers: " << pool.buffers.size()
|
||||
<< ")\n";
|
||||
}
|
||||
}
|
||||
|
||||
@ -555,10 +555,15 @@ inline std::string MPSHeapAllocatorImpl::format_size(uint64_t size) const {
|
||||
std::ostringstream os;
|
||||
os.precision(2);
|
||||
os << std::fixed;
|
||||
if (size <= 1024UL) { os << size << " bytes"; }
|
||||
else if (size <= 1048576UL) { os << ((float) size / 1024.0) << " KB"; }
|
||||
else if (size <= 1073741824UL) { os << ((float) size / 1048576.0) << " MB"; }
|
||||
else { os << ((float) size / 1073741824.0) << " GB"; }
|
||||
if (size <= 1024UL) {
|
||||
os << size << " bytes";
|
||||
} else if (size <= 1048576UL) {
|
||||
os << ((float)size / 1024.0) << " KB";
|
||||
} else if (size <= 1073741824UL) {
|
||||
os << ((float)size / 1048576.0) << " MB";
|
||||
} else {
|
||||
os << ((float)size / 1073741824.0) << " GB";
|
||||
}
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@ -575,15 +580,12 @@ HeapAllocator::MPSHeapAllocatorImpl& _getAllocImpl() {
|
||||
// MPS allocator struct to be registered with Pytorch
|
||||
struct TORCH_API MPSAllocator final : public IMPSAllocator {
|
||||
public:
|
||||
explicit MPSAllocator(uint32_t Usage) :
|
||||
m_has_unified_memory(_getAllocImpl().Device().hasUnifiedMemory), m_usage(Usage)
|
||||
{
|
||||
explicit MPSAllocator(uint32_t Usage)
|
||||
: m_has_unified_memory(_getAllocImpl().Device().hasUnifiedMemory), m_usage(Usage) {
|
||||
if (_getAllocImpl().getDebugVerbosity()) {
|
||||
if (!(m_usage & HeapAllocator::UsageFlags::SHARED) || m_has_unified_memory) {
|
||||
std::cerr << "Initializing "
|
||||
<< ((m_usage & HeapAllocator::UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< " heap allocator on "
|
||||
<< (m_has_unified_memory ? "unified" : "discrete")
|
||||
std::cerr << "Initializing " << ((m_usage & HeapAllocator::UsageFlags::SHARED) ? "shared" : "private")
|
||||
<< " heap allocator on " << (m_has_unified_memory ? "unified" : "discrete")
|
||||
<< " device memory of size "
|
||||
<< _getAllocImpl().format_size(_getAllocImpl().Device().recommendedMaxWorkingSetSize) << "\n";
|
||||
}
|
||||
@ -593,7 +595,9 @@ public:
|
||||
~MPSAllocator() override {
|
||||
_getAllocImpl().emptyCache();
|
||||
}
|
||||
DeleterFnPtr raw_deleter() const override { return &Delete; }
|
||||
DeleterFnPtr raw_deleter() const override {
|
||||
return &Delete;
|
||||
}
|
||||
|
||||
DataPtr allocate(const size_t nbytes) const override {
|
||||
__block id<MTLBuffer> buf = nbytes > 0 ? _getAllocImpl().malloc(nbytes, m_usage) : nullptr;
|
||||
@ -605,20 +609,48 @@ public:
|
||||
id<MTLBuffer> buf = _getAllocImpl().allocScalarBufferWithValue(value, size);
|
||||
return {buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
|
||||
}
|
||||
bool isSharedBuffer(void* ptr) const override { return _getAllocImpl().isSharedBuffer(ptr); }
|
||||
bool isSharedStorageSupported() const override { return m_has_unified_memory; }
|
||||
void emptyCache() const override { _getAllocImpl().emptyCache(); }
|
||||
ssize_t getUnalignedBufferSize(void* ptr) const override { return _getAllocImpl().getUnalignedBufferSize(ptr); }
|
||||
IntArrayRef getBufferShape(void* ptr) const override { return _getAllocImpl().getBufferShape(ptr); }
|
||||
void setBufferShape(void* ptr, const IntArrayRef& shape) const override { _getAllocImpl().setBufferShape(ptr, shape); }
|
||||
size_t getTotalAllocatedMemory() const override { return _getAllocImpl().getTotalAllocatedMemory(); }
|
||||
size_t getCurrentAllocatedMemory() const override { return _getAllocImpl().getCurrentAllocatedMemory(); }
|
||||
size_t getDriverAllocatedMemory() const override { return _getAllocImpl().getDriverAllocatedMemory(); }
|
||||
ssize_t getLowWatermarkValue() const override { return _getAllocImpl().getLowWatermarkValue(); }
|
||||
size_t getLowWatermarkLimit() const override { return _getAllocImpl().getLowWatermarkLimit(); }
|
||||
size_t getHighWatermarkLimit() const override { return _getAllocImpl().getHighWatermarkLimit(); }
|
||||
void setLowWatermarkRatio(double ratio) const override { _getAllocImpl().setLowWatermarkRatio(ratio); }
|
||||
void setHighWatermarkRatio(double ratio) const override { _getAllocImpl().setHighWatermarkRatio(ratio); }
|
||||
bool isSharedBuffer(void* ptr) const override {
|
||||
return _getAllocImpl().isSharedBuffer(ptr);
|
||||
}
|
||||
bool isSharedStorageSupported() const override {
|
||||
return m_has_unified_memory;
|
||||
}
|
||||
void emptyCache() const override {
|
||||
_getAllocImpl().emptyCache();
|
||||
}
|
||||
ssize_t getUnalignedBufferSize(void* ptr) const override {
|
||||
return _getAllocImpl().getUnalignedBufferSize(ptr);
|
||||
}
|
||||
IntArrayRef getBufferShape(void* ptr) const override {
|
||||
return _getAllocImpl().getBufferShape(ptr);
|
||||
}
|
||||
void setBufferShape(void* ptr, const IntArrayRef& shape) const override {
|
||||
_getAllocImpl().setBufferShape(ptr, shape);
|
||||
}
|
||||
size_t getTotalAllocatedMemory() const override {
|
||||
return _getAllocImpl().getTotalAllocatedMemory();
|
||||
}
|
||||
size_t getCurrentAllocatedMemory() const override {
|
||||
return _getAllocImpl().getCurrentAllocatedMemory();
|
||||
}
|
||||
size_t getDriverAllocatedMemory() const override {
|
||||
return _getAllocImpl().getDriverAllocatedMemory();
|
||||
}
|
||||
ssize_t getLowWatermarkValue() const override {
|
||||
return _getAllocImpl().getLowWatermarkValue();
|
||||
}
|
||||
size_t getLowWatermarkLimit() const override {
|
||||
return _getAllocImpl().getLowWatermarkLimit();
|
||||
}
|
||||
size_t getHighWatermarkLimit() const override {
|
||||
return _getAllocImpl().getHighWatermarkLimit();
|
||||
}
|
||||
void setLowWatermarkRatio(double ratio) const override {
|
||||
_getAllocImpl().setLowWatermarkRatio(ratio);
|
||||
}
|
||||
void setHighWatermarkRatio(double ratio) const override {
|
||||
_getAllocImpl().setHighWatermarkRatio(ratio);
|
||||
}
|
||||
|
||||
private:
|
||||
bool m_has_unified_memory;
|
||||
@ -662,15 +694,13 @@ namespace native {
|
||||
// Pinned memory will be helpful on Apple Silicon Macs with Unified memory as we
|
||||
// will be able to use SharedStorageMode for MTLBuffer allocations. This will
|
||||
// avoid extra copies on DataLoading operations.
|
||||
bool is_pinned_mps(const Tensor& self, c10::optional<Device> device)
|
||||
{
|
||||
bool is_pinned_mps(const Tensor& self, c10::optional<Device> device) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_mps());
|
||||
return at::mps::_getSharedAllocator().isSharedBuffer(self.storage().data());
|
||||
}
|
||||
|
||||
// torch.pin_memory() implementation
|
||||
Tensor _pin_memory_mps(const Tensor& self, c10::optional<Device> device)
|
||||
{
|
||||
Tensor _pin_memory_mps(const Tensor& self, c10::optional<Device> device) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_mps());
|
||||
auto* shared_allocator = at::mps::getIMPSAllocator(true);
|
||||
TORCH_CHECK(shared_allocator, "unable to pin memory on a non-unified memory device");
|
||||
|
@ -2,10 +2,10 @@

#include <c10/util/CallOnce.h>

#include <ATen/mps/IndexKernels.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/IndexKernels.h>

namespace at {
namespace mps {
@ -23,9 +23,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
|
||||
}
|
||||
|
||||
MPSDevice* MPSDevice::getInstance() {
|
||||
c10::call_once(mpsdev_init, [] {
|
||||
mps_device = std::unique_ptr<MPSDevice>(new MPSDevice());
|
||||
});
|
||||
c10::call_once(mpsdev_init, [] { mps_device = std::unique_ptr<MPSDevice>(new MPSDevice()); });
|
||||
return mps_device.get();
|
||||
}
|
||||
|
||||
@ -36,7 +34,8 @@ id<MTLFunction> MPSDevice::metalIndexingFunction(const std::string& kernel, MTLF
|
||||
MTLCompileOptions* options = [MTLCompileOptions new];
|
||||
[options setLanguageVersion:getMetalLanguageVersion(_mtl_device)];
|
||||
[options setFastMathEnabled:YES];
|
||||
_mtl_indexing_library = [_mtl_device newLibraryWithSource: [NSString stringWithCString: mps::indexing_metal_shaders encoding:NSASCIIStringEncoding]
|
||||
_mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders
|
||||
encoding:NSASCIIStringEncoding]
|
||||
options:options
|
||||
error:&error];
|
||||
TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]);
|
||||
@ -48,10 +47,15 @@ id<MTLFunction> MPSDevice::metalIndexingFunction(const std::string& kernel, MTLF
|
||||
constantValues:constantValues
|
||||
error:&error] autorelease];
|
||||
} else {
|
||||
indexFunction = [[_mtl_indexing_library newFunctionWithName: [NSString stringWithUTF8String: kernel.c_str()]] autorelease];
|
||||
indexFunction =
|
||||
[[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease];
|
||||
}
|
||||
|
||||
TORCH_CHECK(indexFunction, "Failed to create specialized function state object: ", kernel, ", error: ", [[error description] UTF8String]);
|
||||
TORCH_CHECK(indexFunction,
|
||||
"Failed to create specialized function state object: ",
|
||||
kernel,
|
||||
", error: ",
|
||||
[[error description] UTF8String]);
|
||||
|
||||
return indexFunction;
|
||||
}
|
||||
@ -69,14 +73,8 @@ MPSDevice::MPSDevice(): _mtl_device(nil), _mtl_indexing_library(nil) {
|
||||
// which is used by MPS backend.
|
||||
id mpsCD = NSClassFromString(@"MPSGraph");
|
||||
|
||||
if ([mpsCD instancesRespondToSelector:@selector(LSTMWithSourceTensor:
|
||||
recurrentWeight:
|
||||
inputWeight:
|
||||
bias:
|
||||
initState:
|
||||
initCell:
|
||||
descriptor:
|
||||
name:)] == NO) {
|
||||
if ([mpsCD instancesRespondToSelector:@selector
|
||||
(LSTMWithSourceTensor:recurrentWeight:inputWeight:bias:initState:initCell:descriptor:name:)] == NO) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -89,23 +87,32 @@ MPSDevice::MPSDevice(): _mtl_device(nil), _mtl_indexing_library(nil) {
|
||||
}
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
|
||||
|
||||
}
|
||||
|
||||
bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
|
||||
id mpsCD = NSClassFromString(@"MPSGraph");
|
||||
static bool _macos_13_0_plus = [mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == YES;
|
||||
static bool _macos_13_1_plus = [mpsCD instancesRespondToSelector:@selector(
|
||||
sampleGridWithSourceTensor:coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode:samplingMode:constantValue:name:)] == YES;
|
||||
static bool _macos_13_2_plus = [mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES;
|
||||
static bool _macos_13_0_plus = [mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:
|
||||
axis:name:)] == YES;
|
||||
static bool _macos_13_1_plus =
|
||||
[mpsCD instancesRespondToSelector:@selector
|
||||
(sampleGridWithSourceTensor:
|
||||
coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode
|
||||
:samplingMode:constantValue:name:)] == YES;
|
||||
static bool _macos_13_2_plus =
|
||||
[mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES;
|
||||
static bool _macos_13_3_plus = [_mtl_device respondsToSelector:@selector(maximumConcurrentCompilationTaskCount)];
|
||||
|
||||
switch (version) {
|
||||
case MacOSVersion::MACOS_VER_13_0_PLUS: return _macos_13_0_plus;
|
||||
case MacOSVersion::MACOS_VER_13_1_PLUS: return _macos_13_1_plus;
|
||||
case MacOSVersion::MACOS_VER_13_2_PLUS: return _macos_13_2_plus;
|
||||
case MacOSVersion::MACOS_VER_13_3_PLUS: return _macos_13_3_plus;
|
||||
default: return false;
|
||||
case MacOSVersion::MACOS_VER_13_0_PLUS:
|
||||
return _macos_13_0_plus;
|
||||
case MacOSVersion::MACOS_VER_13_1_PLUS:
|
||||
return _macos_13_1_plus;
|
||||
case MacOSVersion::MACOS_VER_13_2_PLUS:
|
||||
return _macos_13_2_plus;
|
||||
case MacOSVersion::MACOS_VER_13_3_PLUS:
|
||||
return _macos_13_3_plus;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4,35 +4,41 @@
|
||||
|
||||
namespace at {
|
||||
|
||||
void mps_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack)
|
||||
{
|
||||
TORCH_WARN_ONCE("The operator '", op.schema().operator_name(), "' is not currently supported ",
|
||||
void mps_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
TORCH_WARN_ONCE("The operator '",
|
||||
op.schema().operator_name(),
|
||||
"' is not currently supported ",
|
||||
"on the MPS backend and will fall back to run on the CPU.",
|
||||
" This may have performance implications.");
|
||||
native::cpu_fallback(op, stack);
|
||||
}
|
||||
|
||||
void mps_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack)
|
||||
{
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "The operator '", op.schema().operator_name(), "' is not currently implemented ",
|
||||
void mps_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack){TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"The operator '",
|
||||
op.schema().operator_name(),
|
||||
"' is not currently implemented ",
|
||||
"for the MPS device. If you want this op to be added in priority during the prototype ",
|
||||
"phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. ",
|
||||
"As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` ",
|
||||
"to use the CPU as a fallback for this op. WARNING: this will be slower than running natively ",
|
||||
"on MPS.")
|
||||
}
|
||||
|
||||
"on MPS.")}
|
||||
|
||||
// This dispatch should never be called for tensor on MPS but is frequently called
|
||||
// If one of them are on CPU
|
||||
Tensor slow_conv2d_forward_mps(
|
||||
const Tensor &self,
|
||||
Tensor slow_conv2d_forward_mps(const Tensor& self,
|
||||
const Tensor& weight,
|
||||
IntArrayRef kernel_size,
|
||||
const c10::optional<Tensor>& bias,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding) {
|
||||
TORCH_CHECK(self.device() == weight.device(), __func__, ": input(device='", self.device(), "') and weight(device=", weight.device(), "') must be on the same device");
|
||||
TORCH_CHECK(self.device() == weight.device(),
|
||||
__func__,
|
||||
": input(device='",
|
||||
self.device(),
|
||||
"') and weight(device=",
|
||||
weight.device(),
|
||||
"') must be on the same device");
|
||||
TORCH_INTERNAL_ASSERT(false, __func__, " should not be called for both tensors on MPS device");
|
||||
}
|
||||
|
||||
|
@ -24,7 +24,8 @@ Generator createMPSGenerator(uint64_t seed_val) {
|
||||
|
||||
MPSGeneratorImpl::MPSGeneratorImpl(uint64_t seed_in)
|
||||
: c10::GeneratorImpl{Device(DeviceType::MPS), DispatchKeySet(c10::DispatchKey::MPS)},
|
||||
data_({.seed = seed_in}), engine_(seed_in, 0, 0) { }
|
||||
data_({.seed = seed_in}),
|
||||
engine_(seed_in, 0, 0) {}
|
||||
|
||||
void MPSGeneratorImpl::set_current_seed(uint64_t seed) {
|
||||
data_.seed = seed;
|
||||
@ -60,7 +61,8 @@ c10::intrusive_ptr<c10::TensorImpl> MPSGeneratorImpl::get_state() const {
|
||||
static const size_t seed_size = sizeof(uint64_t);
|
||||
static const size_t total_size = states_size + seed_size;
|
||||
|
||||
auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
|
||||
auto state_tensor = at::detail::empty_cpu(
|
||||
{(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
|
||||
auto rng_state = state_tensor.data_ptr<uint8_t>();
|
||||
auto current_seed = this->current_seed();
|
||||
memcpy(rng_state, this->data_.state.data(), states_size);
|
||||
|
@ -1,31 +1,24 @@
|
||||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
#include <ATen/mps/MPSGuardImpl.h>
|
||||
#include <ATen/mps/MPSDevice.h>
|
||||
#include <ATen/mps/MPSGuardImpl.h>
|
||||
|
||||
namespace at {
|
||||
namespace mps {
|
||||
|
||||
void MPSGuardImpl::createEvent(
|
||||
mpsEvent_t* event,
|
||||
const EventFlag flag) const {
|
||||
}
|
||||
void MPSGuardImpl::createEvent(mpsEvent_t* event, const EventFlag flag) const {}
|
||||
|
||||
void MPSGuardImpl::destroyEvent(
|
||||
void* event,
|
||||
const DeviceIndex device_index) const noexcept {
|
||||
if (!event) return;
|
||||
void MPSGuardImpl::destroyEvent(void* event, const DeviceIndex device_index) const noexcept {
|
||||
if (!event)
|
||||
return;
|
||||
auto mps_event = static_cast<mpsEvent_t>(event);
|
||||
mps_event->~MPSEvent();
|
||||
|
||||
}
|
||||
|
||||
void MPSGuardImpl::record(
|
||||
void** event,
|
||||
void MPSGuardImpl::record(void** event,
|
||||
const Stream& stream,
|
||||
const DeviceIndex device_index,
|
||||
const EventFlag flag) const {
|
||||
|
||||
TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
|
||||
"Event device index ",
|
||||
device_index,
|
||||
@ -38,10 +31,7 @@ namespace mps {
|
||||
mps_event->recordEvent(true);
|
||||
}
|
||||
|
||||
void MPSGuardImpl::block(
|
||||
void* event,
|
||||
const Stream& stream) const {
|
||||
|
||||
void MPSGuardImpl::block(void* event, const Stream& stream) const {
|
||||
auto mps_event = static_cast<mpsEvent_t>(event);
|
||||
MPSStream mps_stream{stream};
|
||||
|
||||
|
@ -1,7 +1,7 @@
// Copyright © 2022 Apple Inc.

#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSStream.h>

namespace at {
namespace mps {
@ -17,9 +17,9 @@ MPSStream::MPSStream(Stream stream) : _stream(stream) {
|
||||
TORCH_CHECK(_stream.device_type() == DeviceType::MPS);
|
||||
_serialQueue = dispatch_queue_create("metal gpu stream", nullptr);
|
||||
_executionDescriptor = [MPSGraphExecutionDescriptor new];
|
||||
_executionDescriptor.completionHandler = ^(NSDictionary<MPSGraphTensor *,
|
||||
MPSGraphTensorData *> * resultsDictionary,
|
||||
NSError * _Nullable error) { };
|
||||
_executionDescriptor.completionHandler =
|
||||
^(NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* resultsDictionary, NSError* _Nullable error) {
|
||||
};
|
||||
}
|
||||
|
||||
MPSStream::~MPSStream() {
|
||||
@ -115,25 +115,27 @@ void MPSStream::addCompletedHandler(MTLCommandBufferHandler block) {
|
||||
});
|
||||
}
|
||||
|
||||
void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType)
|
||||
{
|
||||
void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType) {
|
||||
TORCH_INTERNAL_ASSERT(length >= offset);
|
||||
if (length == 0) return;
|
||||
if (length == 0)
|
||||
return;
|
||||
dispatch_sync(_serialQueue, ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
|
||||
|
||||
[blitEncoder fillBuffer:buffer
|
||||
range:NSMakeRange(offset, length)
|
||||
value:value];
|
||||
[blitEncoder fillBuffer:buffer range:NSMakeRange(offset, length) value:value];
|
||||
[blitEncoder endEncoding];
|
||||
synchronize(syncType);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void MPSStream::copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
|
||||
size_t length, size_t srcOffset, size_t dstOffset, SyncType syncType) {
|
||||
void MPSStream::copy(id<MTLBuffer> srcBuffer,
|
||||
id<MTLBuffer> dstBuffer,
|
||||
size_t length,
|
||||
size_t srcOffset,
|
||||
size_t dstOffset,
|
||||
SyncType syncType) {
|
||||
dispatch_sync(_serialQueue, ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
|
||||
@ -149,10 +151,14 @@ void MPSStream::copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
|
||||
});
|
||||
}
|
||||
|
||||
void MPSStream::copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer, size_t length,
|
||||
size_t srcOffset, size_t dstOffset, bool non_blocking) {
|
||||
copy(srcBuffer, dstBuffer, length, srcOffset, dstOffset,
|
||||
!non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT);
|
||||
void MPSStream::copy_and_sync(id<MTLBuffer> srcBuffer,
|
||||
id<MTLBuffer> dstBuffer,
|
||||
size_t length,
|
||||
size_t srcOffset,
|
||||
size_t dstOffset,
|
||||
bool non_blocking) {
|
||||
copy(
|
||||
srcBuffer, dstBuffer, length, srcOffset, dstOffset, !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT);
|
||||
}
|
||||
|
||||
void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType) {
|
||||
@ -184,8 +190,7 @@ MPSStream* MPSStreamImpl::_stream = nullptr;
|
||||
|
||||
MPSStream* MPSStreamImpl::getInstance() {
|
||||
if (_stream == nullptr) {
|
||||
_stream =
|
||||
new MPSStream(Stream(Stream::UNSAFE, c10::Device(DeviceType::MPS), 0));
|
||||
_stream = new MPSStream(Stream(Stream::UNSAFE, c10::Device(DeviceType::MPS), 0));
|
||||
}
|
||||
return _stream;
|
||||
}
|
||||
@ -204,8 +209,8 @@ MPSStream* getDefaultMPSStream() {
|
||||
// MPSEvent
|
||||
//-----------------------------------------------------------------
|
||||
|
||||
MPSEvent::MPSEvent(bool deferInitialization) :
|
||||
is_initialized(false), _signalCounter(0), _stream(nil), _event(nil), _listener(nil) {
|
||||
MPSEvent::MPSEvent(bool deferInitialization)
|
||||
: is_initialized(false), _signalCounter(0), _stream(nil), _event(nil), _listener(nil) {
|
||||
if (!deferInitialization) {
|
||||
initialize();
|
||||
}
|
||||
@ -256,8 +261,7 @@ void MPSEvent::waitForEvent(bool syncEvent) {
|
||||
});
|
||||
}
|
||||
|
||||
void MPSEvent::notifyEvent(MTLSharedEventNotificationBlock block)
|
||||
{
|
||||
void MPSEvent::notifyEvent(MTLSharedEventNotificationBlock block) {
|
||||
if (!is_initialized)
|
||||
initialize();
|
||||
dispatch_sync(_stream->queue(), ^() {
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
|
||||
namespace at::native::mps {
|
||||
|
||||
@ -28,27 +28,31 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
|
||||
case ScalarType::Bool:
|
||||
return MPSDataTypeBool;
|
||||
case ScalarType::Double:
|
||||
TORCH_CHECK_TYPE(false, "Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. "
|
||||
TORCH_CHECK_TYPE(false,
|
||||
"Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. "
|
||||
"Please use float32 instead.")
|
||||
default:
|
||||
TORCH_CHECK_TYPE(false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
|
||||
TORCH_CHECK_TYPE(
|
||||
false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
|
||||
}
|
||||
}
|
||||
|
||||
// #issue 104398441 sortWithTensor and argsortWithTensor has support of
|
||||
// Int32, Half and Float32 types. These utilities are to help cast to these
|
||||
// types.
|
||||
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64) {
|
||||
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
|
||||
MPSGraphTensor* inputTensor,
|
||||
const Tensor& input,
|
||||
bool includesInt64) {
|
||||
MPSDataType dataType = getMPSDataType(input.scalar_type());
|
||||
bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
|
||||
bool condition =
|
||||
(dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
|
||||
if (includesInt64) {
|
||||
condition = condition && (dataType != MPSDataTypeInt64);
|
||||
}
|
||||
if (condition) {
|
||||
dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
|
||||
return [mpsGraph castTensor:inputTensor
|
||||
toType:dataType
|
||||
name:@"castInputTensor"];
|
||||
return [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
|
||||
}
|
||||
return inputTensor;
|
||||
}
|
||||
@ -56,16 +60,18 @@ MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor,
|
||||
// #issue 104398441 sortWithTensor and argsortWithTensor has support of
|
||||
// Int32, Half and Float32 types. These utilities are to help cast from these
|
||||
// types.
|
||||
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64) {
|
||||
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
|
||||
MPSGraphTensor* inputTensor,
|
||||
const Tensor& input,
|
||||
bool includesInt64) {
|
||||
MPSDataType dataType = getMPSDataType(input.scalar_type());
|
||||
bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
|
||||
bool condition =
|
||||
(dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
|
||||
if (includesInt64) {
|
||||
condition = condition && (dataType != MPSDataTypeInt64);
|
||||
}
|
||||
if (condition) {
|
||||
inputTensor = [mpsGraph castTensor:inputTensor
|
||||
toType:dataType
|
||||
name:@"castInputTensor"];
|
||||
inputTensor = [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
|
||||
}
|
||||
return inputTensor;
|
||||
}
|
||||
@ -92,7 +98,8 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
|
||||
case ScalarType::Bool:
|
||||
return MPSDataTypeBool;
|
||||
default:
|
||||
TORCH_CHECK_TYPE(false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
|
||||
TORCH_CHECK_TYPE(
|
||||
false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
|
||||
}
|
||||
}
|
||||
|
||||
@ -232,16 +239,17 @@ MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format) {
|
||||
}
|
||||
|
||||
void printTensorNDArray(const Tensor& t) {
|
||||
if (!t.is_mps()) return;
|
||||
if(t.numel() == 0) return;
|
||||
if (!t.is_mps())
|
||||
return;
|
||||
if (t.numel() == 0)
|
||||
return;
|
||||
// Get shape and data type
|
||||
auto selfShape = getMPSShape(t);
|
||||
auto selfDType = getMPSDataType(t.scalar_type());
|
||||
|
||||
// Initialize data
|
||||
id<MTLBuffer> selfBuf = getMTLBufferStorage(t);
|
||||
MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf
|
||||
shape:selfShape
|
||||
MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf shape:selfShape
|
||||
dataType:selfDType] autorelease];
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wobjc-method-access")
|
||||
@ -251,8 +259,7 @@ void printTensorNDArray(const Tensor& t) {
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
}
|
||||
|
||||
MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType mpsType)
|
||||
{
|
||||
MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape* shape, MPSDataType mpsType) {
|
||||
id<MTLBuffer> buffer = getMTLBufferStorage(tensor);
|
||||
MPSGraphTensorData* tmpGraphTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buffer
|
||||
shape:shape
|
||||
@ -261,9 +268,12 @@ MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType
|
||||
return [tmpGraphTensorData mpsndarray];
|
||||
}
|
||||
|
||||
Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSShape *mpsShape,
|
||||
bool gatherTensorData, MPSDataType dataType) : _tensor(src)
|
||||
{
|
||||
Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
|
||||
const Tensor& src,
|
||||
MPSShape* mpsShape,
|
||||
bool gatherTensorData,
|
||||
MPSDataType dataType)
|
||||
: _tensor(src) {
|
||||
TORCH_CHECK(src.is_mps(), "Placeholder storage has not been allocated on MPS device!");
|
||||
// extract the pointer to MTLBuffer from the Tensor's storage
|
||||
id<MTLBuffer> srcBuf = getMTLBufferStorage(src);
|
||||
@ -285,8 +295,9 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSS
|
||||
// if buffer size is zero in here, it's not a user error. It could be a missing check for
|
||||
// tensor.numel() == 0 in our internal implementations of ops.
|
||||
TORCH_INTERNAL_ASSERT([srcBuf length] > 0, "Placeholder tensor is empty!");
|
||||
const MPSDataType mpsDataType = dataType != MPSDataTypeInvalid ? dataType :
|
||||
_tensor.dim() == 0 ? getMPSScalarType(_tensor.scalar_type()) : getMPSDataType(_tensor.scalar_type());
|
||||
const MPSDataType mpsDataType = dataType != MPSDataTypeInvalid ? dataType
|
||||
: _tensor.dim() == 0 ? getMPSScalarType(_tensor.scalar_type())
|
||||
: getMPSDataType(_tensor.scalar_type());
|
||||
|
||||
if (src.is_contiguous() && src.storage_offset() && sliceViewTensor) {
|
||||
_value = getMPSGraphTensorDataForView(src, mpsShape, mpsDataType);
|
||||
@ -295,34 +306,25 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSS
|
||||
mpsShape = getMPSShape(_tensor);
|
||||
}
|
||||
|
||||
_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
|
||||
shape:mpsShape
|
||||
dataType:mpsDataType] autorelease];
|
||||
_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf shape:mpsShape dataType:mpsDataType] autorelease];
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(_value);
|
||||
_placeholder = mpsGraphTensor;
|
||||
}
|
||||
|
||||
MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph,
|
||||
MPSStream* mpsStream,
|
||||
const Tensor& tensor) {
|
||||
MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor) {
|
||||
auto mpsShape = getMPSShape(tensor);
|
||||
auto dataType = getMPSDataType(tensor.scalar_type());
|
||||
|
||||
MPSGraphTensorData* result = nil;
|
||||
if (tensor.numel() > 0) {
|
||||
id<MTLBuffer> buf = getMTLBufferStorage(tensor);
|
||||
result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buf
|
||||
shape:mpsShape
|
||||
dataType:dataType]
|
||||
autorelease];
|
||||
result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buf shape:mpsShape dataType:dataType] autorelease];
|
||||
} else {
|
||||
// create empty NDArray
|
||||
MPSNDArrayDescriptor *desc = [MPSNDArrayDescriptor descriptorWithDataType:dataType
|
||||
shape:mpsShape];
|
||||
MPSNDArray *emptyArray = [[[MPSNDArray alloc]
|
||||
initWithDevice:mpsStream->device() descriptor:desc] autorelease];
|
||||
MPSNDArrayDescriptor* desc = [MPSNDArrayDescriptor descriptorWithDataType:dataType shape:mpsShape];
|
||||
MPSNDArray* emptyArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device() descriptor:desc] autorelease];
|
||||
result = [[[MPSGraphTensorData alloc] initWithMPSNDArray:emptyArray] autorelease];
|
||||
}
|
||||
assert(result);
|
||||
@ -332,14 +334,22 @@ MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph,
|
||||
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type) {
|
||||
switch (type) {
|
||||
case ScalarType::Double:
|
||||
case ScalarType::Float: return {.value.f = scalar.to<float>() , .size = sizeof(float) , .type = type};
|
||||
case ScalarType::Half: return {.value.h = scalar.to<at::Half>(), .size = sizeof(short) , .type = type};
|
||||
case ScalarType::Long: return {.value.i = scalar.to<int64_t>() , .size = sizeof(int64_t), .type = type};
|
||||
case ScalarType::Int: return {.value.i = scalar.to<int32_t>() , .size = sizeof(int32_t), .type = type};
|
||||
case ScalarType::Short: return {.value.i = scalar.to<int16_t>() , .size = sizeof(int16_t), .type = type};
|
||||
case ScalarType::Char: return {.value.i = scalar.to<int8_t>() , .size = sizeof(int8_t) , .type = type};
|
||||
case ScalarType::Byte: return {.value.i = scalar.to<uint8_t>() , .size = sizeof(uint8_t), .type = type};
|
||||
case ScalarType::Bool: return {.value.b = scalar.to<bool>() , .size = sizeof(bool) , .type = type};
|
||||
case ScalarType::Float:
|
||||
return {.value.f = scalar.to<float>(), .size = sizeof(float), .type = type};
|
||||
case ScalarType::Half:
|
||||
return {.value.h = scalar.to<at::Half>(), .size = sizeof(short), .type = type};
|
||||
case ScalarType::Long:
|
||||
return {.value.i = scalar.to<int64_t>(), .size = sizeof(int64_t), .type = type};
|
||||
case ScalarType::Int:
|
||||
return {.value.i = scalar.to<int32_t>(), .size = sizeof(int32_t), .type = type};
|
||||
case ScalarType::Short:
|
||||
return {.value.i = scalar.to<int16_t>(), .size = sizeof(int16_t), .type = type};
|
||||
case ScalarType::Char:
|
||||
return {.value.i = scalar.to<int8_t>(), .size = sizeof(int8_t), .type = type};
|
||||
case ScalarType::Byte:
|
||||
return {.value.i = scalar.to<uint8_t>(), .size = sizeof(uint8_t), .type = type};
|
||||
case ScalarType::Bool:
|
||||
return {.value.b = scalar.to<bool>(), .size = sizeof(bool), .type = type};
|
||||
default:
|
||||
TORCH_INTERNAL_ASSERT(false, "Unsupported scalar type '", type, "' on MPS backend.");
|
||||
}
|
||||
@ -354,8 +364,10 @@ MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar&
|
||||
shape:@[ @1 ]
|
||||
dataType:getMPSScalarType(scalar.type)] autorelease];
|
||||
} else {
|
||||
MPSNDArrayDescriptor *tensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:getMPSScalarType(scalar.type) shape:@[@1]];
|
||||
MPSNDArray *tensorNDArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device() descriptor:tensorDesc] autorelease];
|
||||
MPSNDArrayDescriptor* tensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:getMPSScalarType(scalar.type)
|
||||
shape:@[ @1 ]];
|
||||
MPSNDArray* tensorNDArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device()
|
||||
descriptor:tensorDesc] autorelease];
|
||||
[tensorNDArray writeBytes:&scalar.value strideBytes:nil];
|
||||
result = [[[MPSGraphTensorData alloc] initWithMPSNDArray:tensorNDArray] autorelease];
|
||||
}
|
||||
@ -372,33 +384,23 @@ MPSGraph* make_mps_graph() {
|
||||
}
|
||||
|
||||
MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType) {
|
||||
return [mpsGraph placeholderWithShape:nil
|
||||
dataType:dataType
|
||||
name:nil];
|
||||
return [mpsGraph placeholderWithShape:nil dataType:dataType name:nil];
|
||||
}
|
||||
|
||||
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape) {
|
||||
return [mpsGraph placeholderWithShape:mpsShape
|
||||
dataType:dataType
|
||||
name:nil];
|
||||
return [mpsGraph placeholderWithShape:mpsShape dataType:dataType name:nil];
|
||||
}
|
||||
|
||||
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const Tensor& tensor) {
|
||||
return [mpsGraph placeholderWithShape:getMPSShape(tensor)
|
||||
dataType:getMPSScalarType(tensor.scalar_type())
|
||||
name:nil];
|
||||
return [mpsGraph placeholderWithShape:getMPSShape(tensor) dataType:getMPSScalarType(tensor.scalar_type()) name:nil];
|
||||
}
|
||||
|
||||
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType) {
|
||||
return [mpsGraph placeholderWithShape:@[@1]
|
||||
dataType:dataType
|
||||
name:nil];
|
||||
return [mpsGraph placeholderWithShape:@[ @1 ] dataType:dataType name:nil];
|
||||
}
|
||||
|
||||
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar) {
|
||||
return [mpsGraph placeholderWithShape:@[@1]
|
||||
dataType:getMPSScalarType(scalar.type())
|
||||
name:nil];
|
||||
return [mpsGraph placeholderWithShape:@[ @1 ] dataType:getMPSScalarType(scalar.type()) name:nil];
|
||||
}
|
||||
|
||||
// this is meant to suppress the availability warning on castTensor
|
||||
@ -417,7 +419,9 @@ MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, Scalar
|
||||
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor) {
|
||||
TORCH_INTERNAL_ASSERT(tensor.shape.count == 4, "Tensor must have 4 dimensions!");
|
||||
return [mpsGraph transposeTensor:[mpsGraph transposeTensor:tensor dimension:3 withDimension:2 name:nil]
|
||||
dimension:2 withDimension:1 name: nil];
|
||||
dimension:2
|
||||
withDimension:1
|
||||
name:nil];
|
||||
}
|
||||
|
||||
string get_mem_format_string(c10::MemoryFormat memory_format) {
|
||||
@ -443,6 +447,7 @@ public:
|
||||
MPSGraphCacheCallback() : graph_cache(MPSGraphCache::getInstance()) {}
|
||||
|
||||
void executeMPSAllocatorCallback(void* ptr, EventType event) override {}
|
||||
|
||||
private:
|
||||
MPSGraphCache* graph_cache;
|
||||
};
|
||||
|
(File diff suppressed because it is too large.)
@ -1,17 +1,19 @@
|
||||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/Pool.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
|
||||
namespace at::native {
|
||||
|
||||
void set_kernel_params
|
||||
(int64_t isizeH, int64_t isizeW,
|
||||
int64_t osizeH, int64_t osizeW,
|
||||
int64_t &strideH, int64_t &strideW,
|
||||
int64_t &kernel_sizeH, int64_t &kernel_sizeW,
|
||||
void set_kernel_params(int64_t isizeH,
|
||||
int64_t isizeW,
|
||||
int64_t osizeH,
|
||||
int64_t osizeW,
|
||||
int64_t& strideH,
|
||||
int64_t& strideW,
|
||||
int64_t& kernel_sizeH,
|
||||
int64_t& kernel_sizeW,
|
||||
bool check_avg_pooling = false) {
|
||||
|
||||
TORCH_CHECK((isizeH >= osizeH && isizeW >= osizeW) || (isizeH <= osizeH && isizeW <= osizeW),
|
||||
"Adaptive pool MPS: Input height and width must both be greater than, "
|
||||
"or equal to, or lesser than output height and width")
|
||||
@ -38,15 +40,15 @@ void set_kernel_params
|
||||
}
|
||||
|
||||
// Adaptive average pooling
|
||||
Tensor& adaptive_avg_pool2d_out_mps
|
||||
(const Tensor& input,
|
||||
IntArrayRef output_size,
|
||||
Tensor& output) {
|
||||
|
||||
Tensor& adaptive_avg_pool2d_out_mps(const Tensor& input, IntArrayRef output_size, Tensor& output) {
|
||||
for (int64_t i = 1; i < input.ndimension(); i++) {
|
||||
TORCH_CHECK(input.size(i) > 0,
|
||||
"adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
|
||||
"but input has sizes ", input.sizes(), " with dimension ", i, " being empty");
|
||||
"but input has sizes ",
|
||||
input.sizes(),
|
||||
" with dimension ",
|
||||
i,
|
||||
" being empty");
|
||||
}
|
||||
|
||||
int64_t isizeH = input.size(-2);
|
||||
@ -57,10 +59,7 @@ Tensor& adaptive_avg_pool2d_out_mps
|
||||
int64_t strideH = 0, strideW = 0;
|
||||
int64_t kernel_sizeH = 0, kernel_sizeW = 0;
|
||||
|
||||
set_kernel_params(isizeH, isizeW,
|
||||
osizeH, osizeW,
|
||||
strideH, strideW,
|
||||
kernel_sizeH, kernel_sizeW, true);
|
||||
set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW, true);
|
||||
|
||||
if (isizeH >= osizeH) {
output = at::avg_pool2d(input,
@ -92,10 +91,7 @@ Tensor& adaptive_avg_pool2d_out_mps
return output;
}

Tensor adaptive_avg_pool2d_mps
(at::Tensor const& input,
IntArrayRef output_size) {

Tensor adaptive_avg_pool2d_mps(at::Tensor const& input, IntArrayRef output_size) {
IntArrayRef output_shape;

auto osizeH = output_size[0];
@ -112,8 +108,7 @@ Tensor adaptive_avg_pool2d_mps
out_dims.push_back(osizeH);
out_dims.push_back(osizeW);
output_shape = IntArrayRef(out_dims);
}
else {
} else {
auto sizeD = input.size(0);
out_dims.push_back(sizeD);
out_dims.push_back(osizeH);
@ -122,21 +117,12 @@ Tensor adaptive_avg_pool2d_mps
}

const auto memory_format = input.suggest_memory_format();
Tensor output = at::native::empty_mps(
output_shape,
input.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
memory_format);
Tensor output =
at::native::empty_mps(output_shape, input.scalar_type(), c10::nullopt, kMPS, c10::nullopt, memory_format);
return adaptive_avg_pool2d_out_mps(input, output_size, output);

}

Tensor adaptive_avg_pool2d_backward_mps
(const Tensor& gradOutput,
const Tensor& input) {

Tensor adaptive_avg_pool2d_backward_mps(const Tensor& gradOutput, const Tensor& input) {
int64_t isizeH = input.size(-2);
int64_t isizeW = input.size(-1);
int64_t osizeH = gradOutput.size(-2);
@ -145,10 +131,7 @@ Tensor adaptive_avg_pool2d_backward_mps
int64_t strideH = 0, strideW = 0;
int64_t kernel_sizeH = 0, kernel_sizeW = 0;

set_kernel_params(isizeH, isizeW,
osizeH, osizeW,
strideH, strideW,
kernel_sizeH, kernel_sizeW, true);
set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW, true);

auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
if (gradInput.numel() != 0) {
@ -178,15 +161,15 @@ Tensor adaptive_avg_pool2d_backward_mps
// Adaptive max pooling
TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
(const Tensor& input,
IntArrayRef output_size,
const Tensor& output,
const Tensor& indices) {

(const Tensor& input, IntArrayRef output_size, const Tensor& output, const Tensor& indices) {
for (int64_t i = 1; i < input.ndimension(); i++) {
TORCH_CHECK(input.size(i) > 0,
"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
"but input has sizes ", input.sizes(), " with dimension ", i, " being "
"but input has sizes ",
input.sizes(),
" with dimension ",
i,
" being "
"empty");
}

@ -198,13 +181,11 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
int64_t strideH = 0, strideW = 0;
int64_t kernel_sizeH = 0, kernel_sizeW = 0;

set_kernel_params(isizeH, isizeW,
osizeH, osizeW,
strideH, strideW,
kernel_sizeH, kernel_sizeW);
set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW);

at::max_pool2d_with_indices_out(const_cast<Tensor&>(output),
const_cast<Tensor&>(indices), input,
const_cast<Tensor&>(indices),
input,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
IntArrayRef({0, 0}),
@ -213,11 +194,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
}

TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_mps)
(const Tensor& gradOutput,
const Tensor& input,
const Tensor& indices,
const Tensor& gradInput) {

(const Tensor& gradOutput, const Tensor& input, const Tensor& indices, const Tensor& gradInput) {
int64_t isizeH = input.size(-2);
int64_t isizeW = input.size(-1);
int64_t osizeH = gradOutput.size(-2);
@ -226,13 +203,11 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_mps)
int64_t strideH = 0, strideW = 0;
int64_t kernel_sizeH = 0, kernel_sizeW = 0;

set_kernel_params(isizeH, isizeW,
osizeH, osizeW,
strideH, strideW,
kernel_sizeH, kernel_sizeW);
set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW);

at::max_pool2d_with_indices_backward_out(const_cast<Tensor&>(gradInput),
gradOutput, input,
gradOutput,
input,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
IntArrayRef({0, 0}),
@ -1,5 +1,5 @@
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {
namespace mps {
@ -188,12 +188,14 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
}
}

id<MTLFunction> kernelDataOffsetsFunction = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO = [[device newComputePipelineStateWithFunction: kernelDataOffsetsFunction
error: &error] autorelease];
id<MTLFunction> kernelDataOffsetsFunction =
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO =
[[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
options:0] autorelease];
TORCH_CHECK(kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
TORCH_CHECK(
kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
[computeEncoder setComputePipelineState:kernelDataOffsetsPSO];
[computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0];
[computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1];
@ -206,8 +208,7 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
kernelOffsetsTGSize = numThreads;

MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: kernelOffsetsThreadGroupSize];
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:kernelOffsetsThreadGroupSize];

const std::string kernel = func_name + "_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> binaryPSO = binaryPipelineState(device, kernel);
@ -223,8 +224,7 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
}

MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: threadGroupSize];
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];

[computeEncoder endEncoding];
mpsStream->commit(true);
@ -4,34 +4,40 @@
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>
#include <c10/util/Optional.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <c10/util/Optional.h>
#include <torch/library.h>

namespace at::native {
namespace mps {

struct BinaryOpCachedGraph : public MPSCachedGraph
{
struct BinaryOpCachedGraph : public MPSCachedGraph {
BinaryOpCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *primaryTensor = nil, *secondaryTensor = nil;
MPSGraphTensor *alphaTensor = nil, *outputTensor = nil;
};

typedef MPSGraphTensor* (^BinaryOpBlock)(BinaryOpCachedGraph*, MPSGraphTensor*, MPSGraphTensor*);
#define BinaryOpFn(graph, primary, secondary) MPSGraphTensor* (mps::BinaryOpCachedGraph* graph, MPSGraphTensor* primary, MPSGraphTensor* secondary)
#define BinaryOpFn(graph, primary, secondary) \
MPSGraphTensor*(mps::BinaryOpCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary)

// alpha is always 1.0 except when this function is called from add_sub_template()
void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha,
const Tensor& output_, std::string op_name, BinaryOpBlock binaryBlock)
{
void binaryOpTensor(const Tensor& self,
const Tensor& other,
const Scalar& alpha,
const Tensor& output_,
std::string op_name,
BinaryOpBlock binaryBlock) {
TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte),
"MPS support binary op with uint8 natively starting from macOS 13.0");
TORCH_CHECK(!(op_name == "power" && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS) &&
(self.scalar_type() == ScalarType::Long ||
(other.scalar_type() == ScalarType::Long && (self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))),
"MPS: ", op_name, " op with int64 input is supported natively starting from macOS 13.2");
(other.scalar_type() == ScalarType::Long &&
(self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))),
"MPS: ",
op_name,
" op with int64 input is supported natively starting from macOS 13.2");
MPSStream* mpsStream = getCurrentMPSStream();
const bool is_self_scalar = self.dim() == 0;
@ -87,8 +93,10 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new BinaryOpCachedGraph(mpsGraph);
newCachedGraph->primaryTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(inputDataType), getMPSShape(self));
newCachedGraph->secondaryTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(otherDataType), getMPSShape(other));
newCachedGraph->primaryTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(inputDataType), getMPSShape(self));
newCachedGraph->secondaryTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(otherDataType), getMPSShape(other));

MPSGraphTensor* primaryCastTensor = newCachedGraph->primaryTensor;
MPSGraphTensor* secondaryCastTensor = newCachedGraph->secondaryTensor;
@ -101,8 +109,7 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
common_dtype = outputDataType;
// in boolean comparison ops with signed vs. unsigned integers, we always cast to the unsigned type
} else if (outputDataType == ScalarType::Bool &&
(inputDataType == ScalarType::Byte ||
otherDataType == ScalarType::Byte)) {
(inputDataType == ScalarType::Byte || otherDataType == ScalarType::Byte)) {
common_dtype = ScalarType::Byte;
}
}
@ -113,8 +120,8 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
}
newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor);
// Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to int32 tensor
// Output tensor should have been promoted but it remains an int32 tensor
// Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to
// int32 tensor Output tensor should have been promoted but it remains an int32 tensor
if (outputDataType != common_dtype ||
[newCachedGraph->outputTensor dataType] != getMPSDataType(outputDataType)) {
newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, outputDataType);
@ -136,16 +143,22 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
self_scalar = getMPSScalar(self.item(), inputDataType);
feeds[cachedGraph->primaryTensor] = getMPSGraphTensorFromScalar(mpsStream, self_scalar);
} else {
selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self, /*mpsShape*/nil,
/*gatherTensorData=*/true, getMPSScalarType(inputDataType));
selfPlaceholder = Placeholder(cachedGraph->primaryTensor,
self,
/*mpsShape*/ nil,
/*gatherTensorData=*/true,
getMPSScalarType(inputDataType));
feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData();
}
if (is_other_scalar && !other.is_mps()) {
other_scalar = getMPSScalar(other.item(), otherDataType);
feeds[cachedGraph->secondaryTensor] = getMPSGraphTensorFromScalar(mpsStream, other_scalar);
} else {
otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other, /*mpsShape*/nil,
/*gatherTensorData=*/true, getMPSScalarType(otherDataType));
otherPlaceholder = Placeholder(cachedGraph->secondaryTensor,
other,
/*mpsShape*/ nil,
/*gatherTensorData=*/true,
getMPSScalarType(otherDataType));
feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
}
@ -156,9 +169,8 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
}

Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, needsCopyToOutput ? output : output_);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);

if (needsCopyToOutput) {
@ -167,27 +179,28 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
}
}

void binaryOpScalar(const Tensor& self, const Scalar& other, const Scalar& alpha,
const Tensor& output, std::string op_name, BinaryOpBlock binaryBlock)
{
void binaryOpScalar(const Tensor& self,
const Scalar& other,
const Scalar& alpha,
const Tensor& output,
std::string op_name,
BinaryOpBlock binaryBlock) {
binaryOpTensor(self, wrapped_scalar_tensor(other), alpha, output, op_name, binaryBlock);
}

void div_mode_template(const Tensor& self, const Tensor& other,
void div_mode_template(const Tensor& self,
const Tensor& other,
c10::optional<c10::string_view> rounding_mode,
const Tensor& output, const string op_name)
{
const Tensor& output,
const string op_name) {
if (rounding_mode.has_value() && *rounding_mode == "trunc") {
TORCH_CHECK(self.scalar_type() != ScalarType::Half,
"MPS: does not support trunc_divide op with float16 input");
TORCH_CHECK(self.scalar_type() != ScalarType::Half, "MPS: does not support trunc_divide op with float16 input");
}
BinaryOpBlock div_mode_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0;
if (!isFloatInput && rounding_mode.has_value() && (*rounding_mode == "floor" || *rounding_mode == "trunc")) {
primaryCastTensor = [mpsGraph castTensor:primaryCastTensor
toType:MPSDataTypeFloat32
name:@"primaryCastTensor"];
primaryCastTensor = [mpsGraph castTensor:primaryCastTensor toType:MPSDataTypeFloat32 name:@"primaryCastTensor"];
secondaryCastTensor = [mpsGraph castTensor:secondaryCastTensor
toType:MPSDataTypeFloat32
name:@"secondaryCastTensor"];
@ -207,9 +220,7 @@ void div_mode_template(const Tensor& self, const Tensor& other,
auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:truncTensor
secondaryTensor:secondaryCastTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
secondaryTensor:mulTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:mulTensor name:nil];
}
return truncTensor;
} else if (*rounding_mode == "floor") {
@ -218,20 +229,26 @@ void div_mode_template(const Tensor& self, const Tensor& other,
auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor
secondaryTensor:secondaryCastTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
secondaryTensor:mulTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:mulTensor name:nil];
}
return floorTensor;
}
assert(0 && "Invalid rounding mode\n");
return nullptr;
};
binaryOpTensor(self, other, Scalar(1.0), output, op_name + "_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""), div_mode_op_block);
binaryOpTensor(self,
other,
Scalar(1.0),
output,
op_name + "_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""),
div_mode_op_block);
}
void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output, std::string op_name)
{
void add_sub_template(const Tensor& self,
const Tensor& other,
const Scalar& alpha,
const Tensor& output,
std::string op_name) {
if (alpha.toDouble() == 0.0) {
if (!self.is_alias_of(output)) { // if inplace, no-op
const_cast<Tensor&>(output) = self.clone();
@ -257,53 +274,72 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp
name:nil];
}
if (op_name == "add")
return [mpsGraph additionWithPrimaryTensor:primaryCastTensor
secondaryTensor:secondaryTensor
name:nil];
return [mpsGraph additionWithPrimaryTensor:primaryCastTensor secondaryTensor:secondaryTensor name:nil];
else
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
secondaryTensor:secondaryTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:secondaryTensor name:nil];
};
// add alpha's type to the key only if multiply was added to graph
binaryOpTensor(self, other, alpha, output, op_name + "_out_mps:" + (alpha_has_value ? getMPSTypeString(alpha.type()) : ""), add_sub_op_block);
binaryOpTensor(self,
other,
alpha,
output,
op_name + "_out_mps:" + (alpha_has_value ? getMPSTypeString(alpha.type()) : ""),
add_sub_op_block);
}

} // namespace mps

#define CREATE_MPS_BINARY_COMPARISON_OP_FUNC(func_out, func_stub, other_type) \
Tensor& func_out(const Tensor& self, const other_type& other, Tensor& output) { \
mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \
mps::binaryOp##other_type( \
self, \
other, \
Scalar(1.0), \
output, \
#func_stub, \
^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \
MPSGraph* mpsGraph = cachedGraph->graph(); \
return [mpsGraph func_stub##WithPrimaryTensor:mps::castMPSTensor(mpsGraph, primaryCastTensor, ScalarType::Bool) \
return [mpsGraph func_stub## \
WithPrimaryTensor:mps::castMPSTensor(mpsGraph, primaryCastTensor, ScalarType::Bool) \
secondaryTensor:mps::castMPSTensor(mpsGraph, secondaryCastTensor, ScalarType::Bool) \
name:nil]; }); \
name:nil]; \
}); \
return output; \
}

#define CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(func_out, func_stub, other_type) \
TORCH_IMPL_FUNC(func_out)(const Tensor& self, const other_type& other, const Tensor& output) { \
TORCH_CHECK(!(self.scalar_type() == ScalarType::Long && \
std::string(#func_stub) == "atan2"), \
"MPS does not support ", #func_stub, " op with int64 input") \
mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \
TORCH_CHECK(!(self.scalar_type() == ScalarType::Long && std::string(#func_stub) == "atan2"), \
"MPS does not support ", \
#func_stub, \
" op with int64 input") \
mps::binaryOp##other_type(self, \
other, \
Scalar(1.0), \
output, \
#func_stub, \
^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \
MPSGraph* mpsGraph = cachedGraph->graph(); \
return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \
secondaryTensor:secondaryCastTensor \
name:nil]; }); \
name:nil]; \
}); \
}
// output of Boolean Ops will be cast to "MPSDataTypeBool" at the end of binaryOpTensor()
#define CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(func_out, func_stub, other_type) \
TORCH_IMPL_FUNC(func_out)(const Tensor& self, const other_type& other, const Tensor& output) { \
mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \
mps::binaryOp##other_type(self, \
other, \
Scalar(1.0), \
output, \
#func_stub, \
^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \
MPSGraph* mpsGraph = cachedGraph->graph(); \
return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \
secondaryTensor:secondaryCastTensor \
name:nil]; }); \
name:nil]; \
}); \
}

// Boolean Binary Ops
@ -332,8 +368,8 @@ CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_and_out_mps, logicalAND, Tensor);
CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_or_out_mps, logicalOR, Tensor);
CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_xor_out_mps, logicalXOR, Tensor);

TORCH_IMPL_FUNC(div_out_mode_mps) (const Tensor& self, const Tensor& other, c10::optional<c10::string_view> rounding_mode, const Tensor& output) {
TORCH_IMPL_FUNC(div_out_mode_mps)
(const Tensor& self, const Tensor& other, c10::optional<c10::string_view> rounding_mode, const Tensor& output) {
mps::div_mode_template(self, other, rounding_mode, output, "div_mode_out");
}

@ -394,13 +430,10 @@ TORCH_IMPL_FUNC(fmod_mps_out) (const Tensor& self, const Tensor& other, const Te
mps::div_mode_template(self, other, "trunc", output, "fmod_mps_out");
}

TORCH_IMPL_FUNC(hypot_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
{
TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0
shape:@[@1]
dataType:primaryCastTensor.dataType];
MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType];
MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor
secondaryTensor:twoTensor
name:nil]
@ -413,11 +446,11 @@ TORCH_IMPL_FUNC(hypot_out_mps) (const Tensor& self, const Tensor& other, const T
mps::binaryOpTensor(self, other, Scalar(1.0), output, "hypot_out_mps", hypot_op_block);
}

TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
{
TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil]
MPSGraphTensor* sumTensor =
[mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil]
secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil]
name:nil];
return [mpsGraph logarithmWithTensor:sumTensor name:nil];
@ -425,11 +458,11 @@ TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, con
mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp_out_mps", logaddexp_op_block);
}

TORCH_IMPL_FUNC(logaddexp2_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
{
TORCH_IMPL_FUNC(logaddexp2_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil]
MPSGraphTensor* sumTensor =
[mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil]
secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil]
name:nil];
return [mpsGraph logarithmBase2WithTensor:sumTensor name:nil];
@ -440,13 +473,9 @@ TORCH_IMPL_FUNC(logaddexp2_out_mps) (const Tensor& self, const Tensor& other, co
TORCH_IMPL_FUNC(xlogy_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock xlogy_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
shape:@[@1]
dataType:primaryCastTensor.dataType];
MPSGraphTensor* yIsNaNPredicateTensor = [mpsGraph isNaNWithTensor:secondaryCastTensor
name:nil];
MPSGraphTensor* logyTensor = [mpsGraph logarithmWithTensor:secondaryCastTensor
name:nil];
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType];
MPSGraphTensor* yIsNaNPredicateTensor = [mpsGraph isNaNWithTensor:secondaryCastTensor name:nil];
MPSGraphTensor* logyTensor = [mpsGraph logarithmWithTensor:secondaryCastTensor name:nil];
MPSGraphTensor* xlogyTensor = [mpsGraph multiplicationWithPrimaryTensor:primaryCastTensor
secondaryTensor:logyTensor
name:nil];
@ -1,6 +1,6 @@
#include <ATen/ExpandUtils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/Resize.h>
#include <ATen/ExpandUtils.h>
#include <ATen/ops/logical_not_native.h>
#include <fmt/format.h>
#include <torch/library.h>
@ -86,7 +86,6 @@ kernel void bitwise_not(constant uint& length [[buffer(0)]],
}}
)METAL";

const std::string& getMetalType(const c10::ScalarType& t) {
// Mapping from c10::ScalarType to integral type that can be used for bitwise ops
// As bitwise ops sign-agnostic map signed/unsigned char and boolean to the same type
@ -112,7 +111,6 @@ const std::string& getMetalType(const c10::Scalar& s) {
return getMetalType(s.type());
}

static id<MTLLibrary> compileBitwiseOpsLibrary(id<MTLDevice> device,
const std::string& t1,
const std::string& t2,
@ -126,7 +124,8 @@ static id<MTLLibrary> compileBitwiseOpsLibrary(id<MTLDevice> device,
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(BITWISE_OPS_TEMPLATE, t1, t2, t3).c_str()]
auto rc =
[device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(BITWISE_OPS_TEMPLATE, t1, t2, t3).c_str()]
options:options
error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
@ -134,7 +133,6 @@ static id<MTLLibrary> compileBitwiseOpsLibrary(id<MTLDevice> device,
return rc;
}

static id<MTLComputePipelineState> getCPLState(id<MTLDevice> device,
const std::string& t1,
const std::string& t2,
@ -151,28 +149,27 @@ static id<MTLComputePipelineState> getCPLState(id<MTLDevice> device,
id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
TORCH_CHECK(func != nil, "Can't get function ", fname);
auto rc = [device newComputePipelineStateWithFunction:func error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
TORCH_CHECK(
rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
cplMap[key] = rc;
return rc;
}

void dispatch1DJob(id<MTLComputeCommandEncoder> commandEncoder, id<MTLComputePipelineState> cplState, uint32_t length)
{
void dispatch1DJob(id<MTLComputeCommandEncoder> commandEncoder, id<MTLComputePipelineState> cplState, uint32_t length) {
uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
auto size = MTLSizeMake(length, 1, 1);
auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1);
[commandEncoder dispatchThreads:size
threadsPerThreadgroup:threadGroupSize];
[commandEncoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
}
void handle_tensor_tensor_binary_op(const at::Tensor& self, const at::Tensor& other, at::Tensor& output, const std::string& kernel_name) {
void handle_tensor_tensor_binary_op(const at::Tensor& self,
const at::Tensor& other,
at::Tensor& output,
const std::string& kernel_name) {
using namespace at::mps;
MPSStream* stream = getCurrentMPSStream();
id<MTLComputePipelineState> cplState = getCPLState(MPSDevice::getInstance()->device(),
getMetalType(output),
getMetalType(self),
getMetalType(other),
kernel_name);
id<MTLComputePipelineState> cplState = getCPLState(
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(other), kernel_name);
uint32_t length = output.numel();
if (length == 0) {
return;
@ -197,14 +194,14 @@ void handle_tensor_tensor_binary_op(const at::Tensor& self, const at::Tensor& ot
});
}

void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& other, at::Tensor& output, const std::string& kernel_name) {
void handle_tensor_scalar_binary_op(const at::Tensor& self,
const at::Scalar& other,
at::Tensor& output,
const std::string& kernel_name) {
using namespace at::mps;
MPSStream* stream = getCurrentMPSStream();
id<MTLComputePipelineState> cplState = getCPLState(MPSDevice::getInstance()->device(),
getMetalType(output),
getMetalType(self),
getMetalType(other),
kernel_name);
id<MTLComputePipelineState> cplState = getCPLState(
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(other), kernel_name);
uint64_t sval = other.to<int64_t>();
uint32_t length = output.numel();
if (length == 0) {
@ -229,7 +226,10 @@ void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& ot
});
}

at::Tensor& _bitwise_op_out_mps (const at::Tensor& self, const at::Tensor& other, at::Tensor& output_, const std::string& op_name) {
at::Tensor& _bitwise_op_out_mps(const at::Tensor& self,
const at::Tensor& other,
at::Tensor& output_,
const std::string& op_name) {
using namespace at::mps;
const bool is_self_scalar = self.dim() == 0;
const bool is_other_scalar = other.dim() == 0;
@ -310,11 +310,8 @@ at::Tensor& bitwise_not_out_mps (const at::Tensor& self, at::Tensor& output_) {
}
using namespace at::mps;
MPSStream* stream = getCurrentMPSStream();
id<MTLComputePipelineState> cplState = getCPLState(MPSDevice::getInstance()->device(),
getMetalType(output),
getMetalType(self),
getMetalType(self),
"bitwise_not");
id<MTLComputePipelineState> cplState = getCPLState(
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(self), "bitwise_not");
dispatch_sync(stream->queue(), ^() {
id<MTLCommandBuffer> buffer = stream->commandBuffer();
id<MTLComputeCommandEncoder> commandEncoder = [buffer computeCommandEncoder];
@ -337,8 +334,6 @@ at::Tensor& bitwise_not_out_mps (const at::Tensor& self, at::Tensor& output_) {
return output_;
}

TORCH_LIBRARY_IMPL(aten, MPS, m) {
m.impl("bitwise_and.Tensor_out", bitwise_and_out_mps);
m.impl("bitwise_or.Tensor_out", bitwise_or_out_mps);
@ -12,22 +12,15 @@
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#endif

namespace at::native {

Tensor dot_mps(
const Tensor &self,
const Tensor &other)
{

Tensor dot_mps(const Tensor& self, const Tensor& other) {
TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS: dot op doesn't support int64 input")

using namespace mps;
auto output = at::native::empty_mps({}, self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);

struct CachedGraph : public MPSCachedGraph
{
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* selfTensor_ = nil;
MPSGraphTensor* otherTensor_ = nil;
@ -42,7 +35,6 @@ Tensor dot_mps(

CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
if (!cachedGraph) {

mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;

@ -56,14 +48,10 @@ Tensor dot_mps(
MPSGraphTensor* castSelf = nil;
MPSGraphTensor* castOther = nil;

if(self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte
|| self.scalar_type() == ScalarType::Char) {
castSelf = [mpsGraph castTensor:selfTensor
toType:MPSDataTypeInt32
name:@"castSelfTensor"];
castOther = [mpsGraph castTensor:otherTensor
toType:MPSDataTypeInt32
name:@"castOtherTensor"];
if (self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte ||
self.scalar_type() == ScalarType::Char) {
castSelf = [mpsGraph castTensor:selfTensor toType:MPSDataTypeInt32 name:@"castSelfTensor"];
castOther = [mpsGraph castTensor:otherTensor toType:MPSDataTypeInt32 name:@"castOtherTensor"];
} else {
castSelf = selfTensor;
castOther = otherTensor;
@ -73,12 +61,10 @@ Tensor dot_mps(
secondaryTensor:castOther
name:@"multiplication"];

MPSGraphTensor *dotProductTensor = [mpsGraph reductionSumWithTensor: dot
axes: nil
name: @"dotProduct"];
MPSGraphTensor* dotProductTensor = [mpsGraph reductionSumWithTensor:dot axes:nil name:@"dotProduct"];

if(self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte
|| self.scalar_type() == ScalarType::Char)
if (self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte ||
self.scalar_type() == ScalarType::Char)
dotProductTensor = [mpsGraph castTensor:dotProductTensor
toType:getMPSDataType(self)
name:@"castDotProductTensor"];
@ -101,9 +87,8 @@ Tensor dot_mps(
otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData(),
};

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -111,14 +96,12 @@ Tensor dot_mps(
return output;
}
Tensor& addmv_out_mps_impl(
const Tensor &self,
Tensor& addmv_out_mps_impl(const Tensor& self,
const Tensor& mat,
const Tensor& vec,
const Scalar& beta_,
const Scalar& alpha_,
Tensor& result)
{
Tensor& result) {
using namespace mps;

TORCH_CHECK(mat.is_mps());
@ -129,8 +112,7 @@ Tensor& addmv_out_mps_impl(
c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
auto betaval = beta_.toComplexDouble();

struct CachedGraph : public mps::MPSCachedGraph
{
struct CachedGraph : public mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* selfTensor_ = nil;
MPSGraphTensor* matMulVecTensor_ = nil;
@ -142,12 +124,10 @@ Tensor& addmv_out_mps_impl(
Tensor matMulVec = mm(mat, vec.unsqueeze(1)).squeeze(1);

@autoreleasepool {
string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec})
+ ":" + to_string(beta_.toDouble())
+ ":" + to_string(alpha_.toDouble());
string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" + to_string(beta_.toDouble()) +
":" + to_string(alpha_.toDouble());
CachedGraph* cachedGraph = nil;
if (!cachedGraph) {

mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
CachedGraph* newCachedGraph = nil;

@ -168,8 +148,7 @@ Tensor& addmv_out_mps_impl(
name:@"MM/alpha*(mat@vec)"];
newCachedGraph->outputTensor_ = productTimesAlphaTensor;

if (betaval != 0.0)
{
if (betaval != 0.0) {
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta_.toDouble()
dataType:getMPSScalarType(self.scalar_type())];

@ -197,15 +176,13 @@ Tensor& addmv_out_mps_impl(

NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
feeds[matMulVecPlaceholder.getMPSGraphTensor()] = matMulVecPlaceholder.getMPSGraphTensorData();
if (betaval != 0.0)
{
if (betaval != 0.0) {
Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData();
}

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -213,7 +190,13 @@ Tensor& addmv_out_mps_impl(
return result;
}

TORCH_IMPL_FUNC(addmv_out_mps)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_, const Tensor& result) {
TORCH_IMPL_FUNC(addmv_out_mps)
(const Tensor& self,
const Tensor& mat,
const Tensor& vec,
const Scalar& beta_,
const Scalar& alpha_,
const Tensor& result) {
addmv_out_mps_impl(self, mat, vec, beta_, alpha_, const_cast<Tensor&>(result));
}
@ -37,7 +37,8 @@ Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
newCachedGraph = new CachedGraph(mpsGraph);
auto isBool = self.scalar_type() == c10::ScalarType::Bool;
auto isUInt8 = self.scalar_type() == c10::ScalarType::Byte;
auto dataType = !isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32;
auto dataType =
!isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32;
// constantWithScalar does not work for boolTypes on MacOS-12.[34]
// workaround by filing it as int8 tensor and than casting to bool
// See https://github.com/pytorch/pytorch/issues/82427
@ -47,17 +48,12 @@ Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble()
shape:getMPSShape(self)
dataType:dataType];
MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor
name:nil];
MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil];
if (isBool) {
outputTensor = [mpsGraph castTensor:outputTensor
toType:MPSDataTypeBool
name:@"constWithBool-workaround"];
outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"constWithBool-workaround"];
}
if (isUInt8) {
outputTensor = [mpsGraph castTensor:outputTensor
toType:MPSDataTypeUInt8
name:@"constWithUInt8-workaround"];
outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeUInt8 name:@"constWithUInt8-workaround"];
}

newCachedGraph->outputTensor_ = outputTensor;
@ -66,13 +62,11 @@ Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
});
}

Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_,
needsCopyToOutput ? output : self,
nullptr, !needsCopyToOutput);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, needsCopyToOutput ? output : self, nullptr, !needsCopyToOutput);

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), /*feeds*/ nil, results);

@ -109,7 +103,10 @@ Tensor& fill_scalar_mps(Tensor& self, const Scalar& value) {
}

Tensor& fill_tensor_mps_(Tensor& self, const Tensor& value) {
TORCH_CHECK(value.dim() == 0, "fill_ only supports 0-dimension value tensor but got tensor with ", value.dim(), " dimensions.");
TORCH_CHECK(value.dim() == 0,
"fill_ only supports 0-dimension value tensor but got tensor with ",
value.dim(),
" dimensions.");
Scalar scalar_value = value.item();
if (scalar_value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true)
return self;
@ -2,39 +2,51 @@

#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>

namespace at::native {

void fill_depthwise_conv_desc(MPSGraphDepthwiseConvolution3DOpDescriptor* descriptor_,
NSUInteger strideInX, NSUInteger strideInY,
NSUInteger dilationRateInX, NSUInteger dilationRateInY,
NSUInteger paddingHorizontal, NSUInteger paddingVertical,
c10::MemoryFormat memory_format, NSUInteger groups) {
descriptor_.strides = @[@1, [[NSNumber alloc] initWithInteger: strideInY],
[[NSNumber alloc] initWithInteger: strideInX]];
descriptor_.dilationRates = @[@1, [[NSNumber alloc] initWithInteger: dilationRateInY],
[[NSNumber alloc] initWithInteger: dilationRateInX]];
NSUInteger strideInX,
NSUInteger strideInY,
NSUInteger dilationRateInX,
NSUInteger dilationRateInY,
NSUInteger paddingHorizontal,
NSUInteger paddingVertical,
c10::MemoryFormat memory_format,
NSUInteger groups) {
descriptor_.strides =
@[ @1, [[NSNumber alloc] initWithInteger:strideInY], [[NSNumber alloc] initWithInteger:strideInX] ];
descriptor_.dilationRates =
@[ @1, [[NSNumber alloc] initWithInteger:dilationRateInY], [[NSNumber alloc] initWithInteger:dilationRateInX] ];

descriptor_.paddingStyle = MPSGraphPaddingStyleExplicit;
descriptor_.paddingValues = @[@0, @0, [[NSNumber alloc] initWithInteger: paddingVertical], [[NSNumber alloc]
initWithInteger: paddingVertical], [[NSNumber alloc]
initWithInteger: paddingHorizontal], [[NSNumber alloc]
initWithInteger: paddingHorizontal]];
descriptor_.paddingValues = @[
@0,
@0,
[[NSNumber alloc] initWithInteger:paddingVertical],
[[NSNumber alloc] initWithInteger:paddingVertical],
[[NSNumber alloc] initWithInteger:paddingHorizontal],
[[NSNumber alloc] initWithInteger:paddingHorizontal]
];
descriptor_.channelDimensionIndex = -3LL;
}

// Create convolution descriptor
void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
NSUInteger strideInX, NSUInteger strideInY,
NSUInteger dilationRateInX, NSUInteger dilationRateInY,
NSUInteger paddingHorizontal, NSUInteger paddingVertical,
c10::MemoryFormat memory_format, NSUInteger groups) {
NSUInteger strideInX,
NSUInteger strideInY,
NSUInteger dilationRateInX,
NSUInteger dilationRateInY,
NSUInteger paddingHorizontal,
NSUInteger paddingVertical,
c10::MemoryFormat memory_format,
NSUInteger groups) {
descriptor_.strideInX = strideInX;
descriptor_.strideInY = strideInY;
descriptor_.dilationRateInX = dilationRateInX;
@ -48,16 +60,15 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
descriptor_.paddingTop = paddingVertical;
descriptor_.paddingBottom = paddingVertical;
descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ?
MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutNHWC;
descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ? MPSGraphTensorNamedDataLayoutNCHW
: MPSGraphTensorNamedDataLayoutNHWC;

// PyTorch always uses OIHW memory layout for weights
descriptor_.weightsLayout = MPSGraphTensorNamedDataLayoutOIHW;
descriptor_.groups = groups;
}

Tensor _mps_convolution_impl(
const Tensor& input_t,
Tensor _mps_convolution_impl(const Tensor& input_t,
const Tensor& weight_t,
const c10::optional<Tensor>& bias_opt,
IntArrayRef padding,
@ -70,8 +81,7 @@ Tensor _mps_convolution_impl(

namespace native_mps = at::native::mps;
CheckedFrom c = "mps_convolution";
TensorArg input { input_t, "input", 1 },
weight { weight_t, "weight", 2 };
TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2};
checkAllSameType(c, {input, weight});
checkAllSameGPU(c, {input, weight});

@ -84,11 +94,9 @@ Tensor _mps_convolution_impl(

auto memory_format = input_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
auto output_t = at::empty(
input_shape.has_value() ?
input_shape.value() :
conv_output_size(input->sizes(), weight->sizes(),
padding, stride, dilation),
auto output_t =
at::empty(input_shape.has_value() ? input_shape.value()
: conv_output_size(input->sizes(), weight->sizes(), padding, stride, dilation),
input->scalar_type(),
c10::nullopt,
kMPS,
@ -103,8 +111,7 @@ Tensor _mps_convolution_impl(
convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);

// Derive from MPSCachedGraph
struct CachedGraph : public native_mps::MPSCachedGraph
{
struct CachedGraph : public native_mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* biasTensor_ = nil;
@ -117,7 +124,6 @@ Tensor _mps_convolution_impl(
auto stream = at::mps::getCurrentMPSStream();

@autoreleasepool {

IntArrayRef bias_shape;
if (bias_defined)
bias_shape = bias_opt.value().sizes();
@ -141,17 +147,14 @@ Tensor _mps_convolution_impl(
bias_shape_key = "nobias";
}

string key = "mps_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":"
+ to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":"
+ to_string(padding[0]) + ":" + to_string(padding[1]) + ":"
+ to_string(groups) + ":" + mem_format_key
+ mps::getTensorsStringKey({input_t, weight_t}) + ":"
+ to_string(bias_defined) + ":" + bias_shape_key;
string key = "mps_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(dilation[0]) +
":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" +
to_string(groups) + ":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" +
to_string(bias_defined) + ":" + bias_shape_key;
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
MPSShape* inputShape = mps::getMPSShape(input_t, memory_format);
if (!cachedGraph) {
native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() {

CachedGraph* newCachedGraph = nil;

@autoreleasepool {
@ -159,33 +162,49 @@ Tensor _mps_convolution_impl(
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor *depthWiseConv3dDescriptor_ = [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
[[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
MPSShape* weightShape = mps::getMPSShape(weight_t);
bool isDepthwiseConv = ((groups > 1 && (weightShape[1].intValue == 1)) &&
inputShape.count >= 4 && weightShape.count >= 4 && !is_channels_last);
bool isDepthwiseConv = ((groups > 1 && (weightShape[1].intValue == 1)) && inputShape.count >= 4 &&
weightShape.count >= 4 && !is_channels_last);
if (isDepthwiseConv) {
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
memory_format, groups);
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
memory_format,
groups);
} else {
fill_conv_desc(conv2dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
memory_format, groups);
fill_conv_desc(conv2dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
memory_format,
groups);
}

MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(
mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);

MPSGraphTensor* biasTensor = nil;
if (bias_defined) {
biasTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(bias_opt.value()));
biasTensor =
native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(bias_opt.value()));
}
MPSGraphTensor* outputTensor;
if (isDepthwiseConv) {
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor dimension:-3 withDimension:-4 name:nil];
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
dimension:-3
withDimension:-4
name:nil];
outputTensor = [mpsGraph depthwiseConvolution3DWithSourceTensor:inputTensor
weightsTensor:weightTransposeTensor
descriptor:depthWiseConv3dDescriptor_
@ -202,9 +221,7 @@ Tensor _mps_convolution_impl(
}

if (bias_defined) {
outputTensor = [mpsGraph additionWithPrimaryTensor: outputTensor
secondaryTensor: biasTensor
name: nil];
outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor secondaryTensor:biasTensor name:nil];
}
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->weightTensor_ = weightTensor;
@ -221,19 +238,20 @@ Tensor _mps_convolution_impl(
auto biasPlaceholder = native_mps::Placeholder();
// Reshape the bias to be broadcastable with output of conv2d
if (bias_defined)
biasPlaceholder = native_mps::Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1}));
biasPlaceholder =
native_mps::Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1}));
auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, *output);

NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [[[NSMutableDictionary alloc] initWithCapacity: 3] autorelease];
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
[[[NSMutableDictionary alloc] initWithCapacity:3] autorelease];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
feeds[weightsPlaceholder.getMPSGraphTensor()] = weightsPlaceholder.getMPSGraphTensorData();
if (bias_defined) {
feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData();
}

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -241,8 +259,7 @@ Tensor _mps_convolution_impl(
return *output;
}

Tensor _mps_convolution(
const Tensor& input_t,
Tensor _mps_convolution(const Tensor& input_t,
const Tensor& weight_t,
const c10::optional<Tensor>& bias_opt,
IntArrayRef padding,
@ -252,15 +269,19 @@ Tensor _mps_convolution(
return _mps_convolution_impl(input_t, weight_t, bias_opt, padding, stride, dilation, groups, c10::nullopt);
}
Tensor mps_convolution_backward_input(
IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
Tensor mps_convolution_backward_input(IntArrayRef input_size,
const Tensor& grad_output_t,
const Tensor& weight_t,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
bool bias_defined) {
namespace native_mps = at::native::mps;
using namespace mps;
TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
CheckedFrom c = "mps_convolution_backward_input";
TensorArg grad_output{ grad_output_t, "grad_output", 1 },
weight{ weight_t, "weight", 2 };
TensorArg grad_output{grad_output_t, "grad_output", 1}, weight{weight_t, "weight", 2};
checkAllSameType(c, {grad_output, weight});
checkAllSameGPU(c, {grad_output, weight});
auto memory_format = grad_output_t.suggest_memory_format();
@ -272,8 +293,7 @@ Tensor mps_convolution_backward_input(
convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);

// Derive from MPSCachedGraph
struct CachedGraph : public native_mps::MPSCachedGraph
{
struct CachedGraph : public native_mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* weightTensor_ = nil;
@ -284,7 +304,6 @@ Tensor mps_convolution_backward_input(

// Add backward with input
@autoreleasepool {

MPSStream* stream = getCurrentMPSStream();

string mem_format_key;
@ -302,17 +321,14 @@ Tensor mps_convolution_backward_input(
MPSShape* gradOutputShape = getMPSShape(grad_output_t, memory_format);
MPSShape* mps_input_shape = getMPSShape(input_size);
NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
string key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":"
+ to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":"
+ to_string(padding[0]) + ":" + to_string(padding[1]) + ":"
+ to_string(groups) + ":" + mem_format_key
+ getTensorsStringKey({grad_output_t, weight_t}) + ":"
+ string([ns_shape_key UTF8String]);
string key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" +
to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key +
getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));

if (!cachedGraph) {
native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() {

CachedGraph* newCachedGraph = nil;

@autoreleasepool {
@ -320,26 +336,38 @@ Tensor mps_convolution_backward_input(
newCachedGraph = new CachedGraph(mpsGraph);

MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor *depthWiseConv3dDescriptor_ = [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
[[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];

MPSShape* weightOutputShape = mps::getMPSShape(weight_t);
// Depthwise conv is input feature channels = groups. So I in OIHW has to be 1.
bool isDepthwiseConv = ((groups > 1 && (weightOutputShape[1].intValue == 1)) &&
gradOutputShape.count >= 4 && weightOutputShape.count >= 4 && !is_channels_last);
bool isDepthwiseConv = ((groups > 1 && (weightOutputShape[1].intValue == 1)) && gradOutputShape.count >= 4 &&
weightOutputShape.count >= 4 && !is_channels_last);
if (isDepthwiseConv) {
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
at::MemoryFormat::Contiguous, groups);
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
at::MemoryFormat::Contiguous,
groups);
} else {
fill_conv_desc(conv2dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
at::MemoryFormat::Contiguous, groups);
fill_conv_desc(conv2dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
at::MemoryFormat::Contiguous,
groups);
}

MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(
mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);

MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor;
@ -348,8 +376,12 @@ Tensor mps_convolution_backward_input(
}
MPSGraphTensor* gradInputTensor;
if (isDepthwiseConv) {
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor dimension:-3 withDimension:-4 name:nil];
gradInputTensor = [mpsGraph depthwiseConvolution3DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
dimension:-3
withDimension:-4
name:nil];
gradInputTensor =
[mpsGraph depthwiseConvolution3DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
weightsTensor:weightTransposeTensor
outputShape:mps_input_shape
descriptor:depthWiseConv3dDescriptor_
@ -380,18 +412,22 @@ Tensor mps_convolution_backward_input(
weightsPlaceholder.getMPSGraphTensor() : weightsPlaceholder.getMPSGraphTensorData(),
};

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
return *grad_input;
}
Tensor mps_convolution_backward_weights(
IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
const Tensor& grad_output_t,
const Tensor& input_t,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
bool bias_defined) {
namespace native_mps = at::native::mps;
using namespace mps;
TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
@ -409,20 +445,14 @@ Tensor mps_convolution_backward_weights(
checkAllSameType(c, {grad_output, input});
checkAllSameGPU(c, {grad_output, input});

auto grad_weight_t = at::empty(
weight_size,
grad_output_t.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
c10::nullopt);
auto grad_weight_t =
at::empty(weight_size, grad_output_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
TensorArg grad_weight{grad_weight_t, "result", 0};

convolution_shape_check(c, input, grad_weight, grad_output, padding, stride, dilation, groups);

// Derive from MPSCachedGraph
struct CachedGraph : public native_mps::MPSCachedGraph
{
struct CachedGraph : public native_mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* inputTensor_ = nil;
@ -432,7 +462,6 @@ Tensor mps_convolution_backward_weights(
native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();

@autoreleasepool {

MPSStream* stream = getCurrentMPSStream();

string mem_format_key;
@ -448,17 +477,14 @@ Tensor mps_convolution_backward_weights(
}
MPSShape* mps_weight_shape = getMPSShape(weight_size);
NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
string key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":"
+ to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":"
+ to_string(padding[0]) + ":" + to_string(padding[1]) + ":"
+ to_string(groups) + ":" + mem_format_key
+ getTensorsStringKey({grad_output_t, input_t}) + ":"
+ string([ns_shape_key UTF8String]);
string key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" +
to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key +
getTensorsStringKey({grad_output_t, input_t}) + ":" + string([ns_shape_key UTF8String]);
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));

if (!cachedGraph) {
native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() {

CachedGraph* newCachedGraph = nil;

@autoreleasepool {
@ -466,23 +492,36 @@ Tensor mps_convolution_backward_weights(
newCachedGraph = new CachedGraph(mpsGraph);

MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor *depthWiseConv3dDescriptor_ = [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
[[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
MPSShape* inputShape = mps::getMPSShape(input_t);
bool isDepthwiseConv = ((groups > 1 && (mps_weight_shape[1].intValue == 1)) && inputShape.count >= 4 && mps_weight_shape.count >= 4 && !is_channels_last);
bool isDepthwiseConv = ((groups > 1 && (mps_weight_shape[1].intValue == 1)) && inputShape.count >= 4 &&
mps_weight_shape.count >= 4 && !is_channels_last);
if (isDepthwiseConv) {
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
at::MemoryFormat::Contiguous, groups);
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
at::MemoryFormat::Contiguous,
groups);
} else {
fill_conv_desc(conv2dDescriptor_, stride[1], stride[0],
dilation[1], dilation[0],
padding[1], padding[0],
at::MemoryFormat::Contiguous, groups);
fill_conv_desc(conv2dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
at::MemoryFormat::Contiguous,
groups);
}

MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(
mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);

MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor;
@ -494,14 +533,19 @@ Tensor mps_convolution_backward_weights(
if (isDepthwiseConv) {
NSNumber* outputFeatChannelDim = mps_weight_shape[0];
MPSShape* weightShapeTranspose = @[ @1, outputFeatChannelDim, mps_weight_shape[2], mps_weight_shape[3] ];
MPSGraphTensor* gradWeightTensorTranspose = [mpsGraph depthwiseConvolution3DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
MPSGraphTensor* gradWeightTensorTranspose =
[mpsGraph depthwiseConvolution3DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
sourceTensor:inputTensor
outputShape:weightShapeTranspose
descriptor:depthWiseConv3dDescriptor_
name:nil];
gradWeightTensor = [mpsGraph transposeTensor:gradWeightTensorTranspose dimension:-3 withDimension:-4 name:nil];
gradWeightTensor = [mpsGraph transposeTensor:gradWeightTensorTranspose
dimension:-3
withDimension:-4
name:nil];
} else {
gradWeightTensor = [mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
gradWeightTensor =
[mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
sourceTensor:inputTensor
outputShape:mps_weight_shape
forwardConvolutionDescriptor:conv2dDescriptor_
@ -525,9 +569,8 @@ Tensor mps_convolution_backward_weights(
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
};

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
@ -535,9 +578,13 @@ Tensor mps_convolution_backward_weights(
return grad_weight_t;
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> mps_convolution_backward(
const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_convolution_backward(const at::Tensor& input,
const at::Tensor& grad_output,
const at::Tensor& weight,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
std::array<bool, 3> output_mask) {
Tensor grad_input, grad_weight, grad_bias;
if (input.numel() == 0) {
@ -549,73 +596,85 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> mps_convolution_backward(
}
} else {
if (output_mask[0]) {
grad_input = mps_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]);
grad_input = mps_convolution_backward_input(
input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]);
}
if (output_mask[1]) {
grad_weight = mps_convolution_backward_weights(weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2]);
grad_weight = mps_convolution_backward_weights(
weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2]);
}
}

return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

Tensor mps_convolution_transpose_forward(
const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
auto input_size = conv_input_size(grad_output.sizes(), weight.sizes(),
padding, output_padding, stride, dilation, groups);
return mps_convolution_backward_input(input_size, grad_output, weight,
padding, stride, dilation, groups, false);
Tensor mps_convolution_transpose_forward(const Tensor& grad_output,
const Tensor& weight,
IntArrayRef padding,
IntArrayRef output_padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
auto input_size =
conv_input_size(grad_output.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups);
return mps_convolution_backward_input(input_size, grad_output, weight, padding, stride, dilation, groups, false);
}
Tensor _mps_convolution_transpose(
const Tensor& input_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
Tensor _mps_convolution_transpose(const Tensor& input_t,
const Tensor& weight_t,
IntArrayRef padding,
IntArrayRef output_padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
TORCH_CHECK(input_t.dim() < 5, "ConvTranspose 3D is not supported on MPS");

auto output_t = mps_convolution_transpose_forward(
input_t, weight_t, padding, output_padding, stride, dilation, groups);
auto output_t =
mps_convolution_transpose_forward(input_t, weight_t, padding, output_padding, stride, dilation, groups);
return output_t;

}

Tensor mps_convolution_transpose_backward_input(
const Tensor& grad_output_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, IntArrayRef input_shape)
{
return _mps_convolution_impl(
grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups, input_shape);
Tensor mps_convolution_transpose_backward_input(const Tensor& grad_output_t,
const Tensor& weight_t,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
IntArrayRef input_shape) {
return _mps_convolution_impl(grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups, input_shape);
}

Tensor mps_convolution_transpose_backward_weight(
IntArrayRef weight_size,
Tensor mps_convolution_transpose_backward_weight(IntArrayRef weight_size,
const Tensor& grad_output_t,
const Tensor& input_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
return mps_convolution_backward_weights(
weight_size, input_t, grad_output_t,
padding, stride, dilation, groups, false);
weight_size, input_t, grad_output_t, padding, stride, dilation, groups, false);
}


std::tuple<Tensor,Tensor> mps_convolution_transpose_backward(
const Tensor& input, const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
std::tuple<Tensor, Tensor> mps_convolution_transpose_backward(const Tensor& input,
const Tensor& grad_output,
const Tensor& weight,
IntArrayRef padding,
IntArrayRef output_padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
std::array<bool, 2> output_mask) {
Tensor grad_input, grad_weight;
if (output_mask[0]) {
grad_input = mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes());
grad_input =
mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes());
}
if (output_mask[1]) {
grad_weight = mps_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups);
grad_weight = mps_convolution_transpose_backward_weight(
weight.sizes(), grad_output, input, padding, stride, dilation, groups);
}

return std::tuple<Tensor, Tensor>{grad_input, grad_weight};
}


} // namespace at::native
@ -6,10 +6,7 @@
namespace at::native {
namespace mps {

void* pageAlignedBlockPtr(
const void* ptr,
NSUInteger size,
NSUInteger* alignedBlockSize) {
void* pageAlignedBlockPtr(const void* ptr, NSUInteger size, NSUInteger* alignedBlockSize) {
uintptr_t address = (uintptr_t)ptr;
uintptr_t alignedAddress = address & ~(PAGE_SIZE - 1);
uintptr_t alignedEnd = ((address + size) + PAGE_SIZE - 1) & ~(PAGE_SIZE - 1);
@ -43,12 +40,14 @@ bool is_strided_contiguous(const at::Tensor& t) {

// Copy sourceBuffer into destBuffer, casting sourceBuffer to src.scalar_type().
// The shapes and dtypes are taken from dst and src, but their storage pointers are not used.
void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
id<MTLBuffer> destBuffer, id<MTLBuffer> sourceBuffer, bool non_blocking = true) {
void copy_cast_mps(at::Tensor& dst,
const at::Tensor& src,
id<MTLBuffer> destBuffer,
id<MTLBuffer> sourceBuffer,
bool non_blocking = true) {
using namespace mps;

struct CachedGraph : public MPSCachedGraph
{
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
@ -87,21 +86,22 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
});
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
MPSGraphTensorData* srcData = [[[MPSGraphTensorData alloc]
initWithMTLBuffer:sourceBuffer shape:srcShape dataType:srcDType]
autorelease];
MPSGraphTensorData* dstData = [[[MPSGraphTensorData alloc]
initWithMTLBuffer:destBuffer shape:dstShape dataType:dstDType]
autorelease];
MPSGraphTensorData* srcData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:sourceBuffer
shape:srcShape
dataType:srcDType] autorelease];
MPSGraphTensorData* dstData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:destBuffer
shape:dstShape
dataType:dstDType] autorelease];
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{cachedGraph->inputTensor_ : srcData};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{cachedGraph->outputTensor_ : dstData};
stream->executeMPSGraph(cachedGraph->graph(), feeds, results, !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT_ADAPTIVE);
stream->executeMPSGraph(
cachedGraph->graph(), feeds, results, !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT_ADAPTIVE);
}
}

static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
{
auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) {
auto sameMemFormat =
src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());

id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* stream = getCurrentMPSStream();
@ -181,8 +181,7 @@ static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool
}

// Copies tensor from cpu to mps backed by identical strided-contiguous data
static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
{
static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bool non_blocking) {
MPSStream* stream = getCurrentMPSStream();
id<MTLDevice> device = MPSDevice::getInstance()->device();
auto dst_byte_offset = dst.storage_offset() * dst.itemsize();
@ -210,8 +209,7 @@ static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bo
}
}

static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
{
static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) {
// Typecast to dst_ if needed and expand, which is a no-op
Tensor src = (src_.dtype() != dst_.dtype() ? src_.to(dst_.dtype()) : src_).expand_as(dst_);

@ -241,8 +239,7 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
stream->copy_and_sync((id<MTLBuffer>)(src), (id<MTLBuffer>)(dst), size, 0, 0, true);
}

static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
{
static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) {
auto src_byte_offset = src_.storage_offset() * src_.itemsize();
auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();

@ -250,7 +247,8 @@ static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, boo
// gather into dst. This reduces the overhead of doing an additional blit for most cases
bool returnGatherOutput = dst_.is_contiguous();
Tensor src;
auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
auto sameMemFormat =
src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
const bool sameDataType = src_.dtype() == dst_.dtype();

if ((!src_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) ||
@ -301,8 +299,7 @@ static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, boo
return dst_;
}

at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
{
at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking) {
TORCH_CHECK(dst.defined(), "dst is undefined");
TORCH_CHECK(src.defined(), "src is undefined");

@ -328,20 +325,16 @@ at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
if (src.device().type() == at::kMPS && dst.device().type() == at::kMPS) {
return copy_kernel_mps(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
}
TORCH_INTERNAL_ASSERT(
src.device().type() == DeviceType::MPS,
"mps_copy_ is implemented only for *->MPS; MPS->*");
TORCH_INTERNAL_ASSERT(src.device().type() == DeviceType::MPS, "mps_copy_ is implemented only for *->MPS; MPS->*");
return dst;
}
} // namespace mps

Tensor _copy_from_and_resize_mps(const at::Tensor& self, const at::Tensor& dst)
{
Tensor _copy_from_and_resize_mps(const at::Tensor& self, const at::Tensor& dst) {
return mps::mps_copy_(const_cast<Tensor&>(dst), self, false);
}

Tensor _copy_from_mps(const at::Tensor& self, const at::Tensor& dst, bool non_blocking)
{
Tensor _copy_from_mps(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
return mps::mps_copy_(const_cast<Tensor&>(dst), self, non_blocking);
}
@ -1,7 +1,7 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/Cross.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {

@ -154,12 +154,14 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,
}
}

id<MTLFunction> kernelDataOffsetsFunction = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO = [[device newComputePipelineStateWithFunction: kernelDataOffsetsFunction
error: &error] autorelease];
id<MTLFunction> kernelDataOffsetsFunction =
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO =
[[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
options:0] autorelease];
TORCH_CHECK(kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
TORCH_CHECK(
kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
[computeEncoder setComputePipelineState:kernelDataOffsetsPSO];
[computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0];
[computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1];
@ -172,8 +174,7 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,
kernelOffsetsTGSize = numThreads;

MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: kernelOffsetsThreadGroupSize];
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:kernelOffsetsThreadGroupSize];

id<MTLComputePipelineState> crossPSO = crossPipelineState(device, out.scalar_type());
[computeEncoder setComputePipelineState:crossPSO];
@ -191,8 +192,7 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,
}

MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: threadGroupSize];
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];

[computeEncoder endEncoding];
mpsStream->commit(true);
@ -1,17 +1,16 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/Distributions.h>
#include <ATen/native/DistributionTemplates.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/mps/MPSGeneratorImpl.h>
#include <ATen/native/DistributionTemplates.h>
#include <ATen/native/Distributions.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {
namespace mps {

struct RandomCachedGraph : public MPSCachedGraph
{
struct RandomCachedGraph : public MPSCachedGraph {
RandomCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
// Only relevant for multinomial
MPSGraphTensor* probTensor = nil;
@ -27,13 +26,15 @@ typedef MPSGraphTensor* (^RandomOpBlock)(RandomCachedGraph*, MPSGraphTensor*);
// for Uniform distributions with scalar from (val1) and to (val2) intervals
// for Normal distributions with scalar mean (val1) and std (val2) values
template <typename scalar_t>
Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
Tensor& random_mps_impl(Tensor& self,
scalar_t val1,
scalar_t val2,
const c10::optional<Tensor>& mean_opt,
const c10::optional<Tensor>& std_opt,
MPSGraphRandomDistribution distribution,
c10::optional<Generator> gen,
std::string op_name, RandomOpBlock randomBlock)
{
std::string op_name,
RandomOpBlock randomBlock) {
if (self.numel() == 0) {
return self;
}
@ -52,7 +53,8 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new RandomCachedGraph(mpsGraph);
newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@(at::mps::detail::PHILOX_STATE_N)]);
newCachedGraph->stateTensor =
mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]);

// FP16, FP32 and Int32 are the only data types supported for distributions on MPS backend.
const MPSDataType inputDataType = [&] {
@ -94,7 +96,8 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
});
}
// feed the updated state values to the graph
MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(at::mps::detail::PHILOX_STATE_N)]];
MPSNDArrayDescriptor* stateDesc =
[MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeInt32 shape:@[ @(at::mps::detail::PHILOX_STATE_N) ]];
MPSNDArray* stateNDArray = [[[MPSNDArray alloc] initWithDevice:stream->device() descriptor:stateDesc] autorelease];
{
// See Note [Acquire lock when using random generators]
@ -131,12 +134,13 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
return self;
}

Tensor& normal_mps_impl(Tensor& self, double mean_s, double std_s,
Tensor& normal_mps_impl(Tensor& self,
double mean_s,
double std_s,
const c10::optional<Tensor>& mean_opt,
const c10::optional<Tensor>& std_opt,
c10::optional<Generator> gen,
std::string op_name)
{
std::string op_name) {
const Tensor& std_t = *(at::borrow_from_optional_tensor(std_opt));
const Tensor& mean_t = *(at::borrow_from_optional_tensor(mean_opt));

@ -159,33 +163,39 @@ Tensor& normal_mps_impl(Tensor& self, double mean_s, double std_s,
}
if (mean_t.defined()) {
cachedGraph->meanTensor = mpsGraphRankedPlaceHolder(mpsGraph, mean_t);
return [mpsGraph additionWithPrimaryTensor: resultTensor
secondaryTensor: cachedGraph->meanTensor
name: nil];
return [mpsGraph additionWithPrimaryTensor:resultTensor secondaryTensor:cachedGraph->meanTensor name:nil];
}
return resultTensor;
};
return random_mps_impl<double>(self, mean_s, std_s, mean_opt, std_opt,
MPSGraphRandomDistributionNormal, gen,
op_name + getTensorsStringKey({mean_t, std_t}), random_op_block);

return random_mps_impl<double>(self,
mean_s,
std_s,
mean_opt,
std_opt,
MPSGraphRandomDistributionNormal,
gen,
op_name + getTensorsStringKey({mean_t, std_t}),
random_op_block);
}
Tensor& bernoulli_mps_impl(Tensor& self, const Tensor& prob_t, c10::optional<Generator> gen, std::string op_name)
{
Tensor& bernoulli_mps_impl(Tensor& self, const Tensor& prob_t, c10::optional<Generator> gen, std::string op_name) {
TORCH_CHECK(prob_t.is_same_size(self), op_name, ": probability and self tensor should be of the same shape")

RandomOpBlock random_op_block = ^RandomOpFn(cachedGraph, randomTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
cachedGraph->stdTensor = mpsGraphRankedPlaceHolder(mpsGraph, prob_t);
return [mpsGraph lessThanWithPrimaryTensor: randomTensor
secondaryTensor: cachedGraph->stdTensor
name: nil];
return [mpsGraph lessThanWithPrimaryTensor:randomTensor secondaryTensor:cachedGraph->stdTensor name:nil];
};
// Bernoulli generates binary output so we use bool type
return mps::random_mps_impl<bool>(self, 0.0, 1.0, c10::nullopt, prob_t,
MPSGraphRandomDistributionUniform, gen,
op_name + getTensorsStringKey({prob_t}), random_op_block);
return mps::random_mps_impl<bool>(self,
0.0,
1.0,
c10::nullopt,
prob_t,
MPSGraphRandomDistributionUniform,
gen,
op_name + getTensorsStringKey({prob_t}),
random_op_block);
}

} // namespace mps
@ -196,15 +206,19 @@ Tensor& uniform_mps_(Tensor& self, double from, double to, c10::optional<Generat
const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to);
TORCH_CHECK((to - from) <= std::numeric_limits<scalar_t>::max(),
"uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()),
">::max(), but found to=", to, " and from=", from,
"uniform_ expects to-from <= std::numeric_limits<",
toString(self.scalar_type()),
">::max(), but found to=",
to,
" and from=",
from,
" which result in to-from to exceed the limit");
from = std::min(std::max(from, min), max);
to = std::max(std::min(to, max), min);
});

return mps::random_mps_impl<double>(self, from, to, c10::nullopt, c10::nullopt,
MPSGraphRandomDistributionUniform, gen, __func__, nullptr);
return mps::random_mps_impl<double>(
self, from, to, c10::nullopt, c10::nullopt, MPSGraphRandomDistributionUniform, gen, __func__, nullptr);
}

Tensor& normal_mps_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
@ -271,21 +285,34 @@ Tensor& random_mps_(Tensor& self, int64_t from, c10::optional<int64_t> to_opt, c
to = *to_opt;
TORCH_CHECK(from < to, "random_mps_ expects 'from' to be less than 'to', but got from=", from, " >= to=", to);
if (isFloatingType(input_dtype)) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_update_from_to", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_update_from_to", [&] {
from = templates::update_from<scalar_t>(from);
to = templates::update_to<scalar_t>(to);
TORCH_CHECK(from < to, "random_mps_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", from, " >= to=", to);
TORCH_CHECK(
from < to,
"random_mps_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=",
from,
" >= to=",
to);
});
templates::check_from_to_in_range(from, to - 1, self.dtype());
}
} else if (from != std::numeric_limits<int64_t>::lowest()) {
// [from, std::numeric_limits<int64_t>::max()]
if (isFloatingType(input_dtype)) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_from_to_range_calc", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_from_to_range_calc", [&] {
constexpr int64_t scalar_t_max = static_cast<int64_t>(1) << std::numeric_limits<scalar_t>::digits;
to = scalar_t_max > std::numeric_limits<int64_t>::max() ? std::numeric_limits<int64_t>::max() : static_cast<int64_t>(scalar_t_max);
to = scalar_t_max > std::numeric_limits<int64_t>::max() ? std::numeric_limits<int64_t>::max()
: static_cast<int64_t>(scalar_t_max);
from = templates::update_from<scalar_t>(from);
TORCH_CHECK(from < to, "random_mps_ expects 'from' casted to dtype to be less than or equal to 'to' casted to dtype, but got from=", from, " > to=", to);
TORCH_CHECK(
from < to,
"random_mps_ expects 'from' casted to dtype to be less than or equal to 'to' casted to dtype, but got from=",
from,
" > to=",
to);
});
} else if (isIntegralType(input_dtype, /*includeBool=*/true)) {
AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, input_dtype, "random_from_to_range_calc", [&] {
@ -295,13 +322,11 @@ Tensor& random_mps_(Tensor& self, int64_t from, c10::optional<int64_t> to_opt, c
to = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
}
});
}
else {
} else {
TORCH_CHECK(false, "random_mps_ handles only integral, floating-point and boolean types");
}
templates::check_from_to_in_range(from, to, self.dtype());
}
else {
} else {
// [std::numeric_limits<int64_t>::lowest(), std::numeric_limits<int64_t>::max()]
// range = 2^64

@ -309,8 +334,8 @@ Tensor& random_mps_(Tensor& self, int64_t from, c10::optional<int64_t> to_opt, c
TORCH_CHECK(false, "random_mps_ currently does not handle the lowest() -> max() range");
}

return mps::random_mps_impl<int64_t>(self, from, to - 1, c10::nullopt, c10::nullopt,
MPSGraphRandomDistributionUniform, gen, __func__, nullptr);
return mps::random_mps_impl<int64_t>(
self, from, to - 1, c10::nullopt, c10::nullopt, MPSGraphRandomDistributionUniform, gen, __func__, nullptr);
}
Tensor& random_mps_(Tensor& self, int64_t to, c10::optional<Generator> gen) {
@ -323,22 +348,23 @@ Tensor& exponential_mps_(Tensor& self, double lambda, c10::optional<Generator> g

mps::RandomOpBlock random_op_block = ^RandomOpFn(cachedGraph, randomTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar: 1.0f
dataType: randomTensor.dataType];
MPSGraphTensor* minusLambdaTensor = [mpsGraph constantWithScalar: -lambda
dataType: randomTensor.dataType];
MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f dataType:randomTensor.dataType];
MPSGraphTensor* minusLambdaTensor = [mpsGraph constantWithScalar:-lambda dataType:randomTensor.dataType];
MPSGraphTensor* subtractTensor = [mpsGraph subtractionWithPrimaryTensor:unitTensor
secondaryTensor:randomTensor
name:nil];
MPSGraphTensor* logTensor = [mpsGraph logarithmWithTensor: subtractTensor
name: nil];
return [mpsGraph divisionWithPrimaryTensor: logTensor
secondaryTensor: minusLambdaTensor
name: nil];
MPSGraphTensor* logTensor = [mpsGraph logarithmWithTensor:subtractTensor name:nil];
return [mpsGraph divisionWithPrimaryTensor:logTensor secondaryTensor:minusLambdaTensor name:nil];
};
return mps::random_mps_impl<double>(self, 0.0, 1.0, c10::nullopt, c10::nullopt,
MPSGraphRandomDistributionUniform, gen,
"exponential_mps_:" + std::to_string(lambda), random_op_block);
return mps::random_mps_impl<double>(self,
0.0,
1.0,
c10::nullopt,
c10::nullopt,
MPSGraphRandomDistributionUniform,
gen,
"exponential_mps_:" + std::to_string(lambda),
random_op_block);
}

Tensor& randperm_out_mps(int64_t n, c10::optional<Generator> generator, Tensor& result) {
@ -354,9 +380,12 @@ Tensor& randperm_out_mps(int64_t n, c10::optional<Generator> generator, Tensor&
}

TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
TORCH_CHECK(!generator.has_value() ||
(generator.has_value() && result.device() == generator->device()),
"Expected a '", result.device(), "' generator device but found '", generator->device(), "'");
TORCH_CHECK(!generator.has_value() || (generator.has_value() && result.device() == generator->device()),
"Expected a '",
result.device(),
"' generator device but found '",
generator->device(),
"'");
check_supported_max_int_with_precision(n, result);

result.resize_({n});
@ -366,36 +395,34 @@ Tensor& randperm_out_mps(int64_t n, c10::optional<Generator> generator, Tensor&

mps::RandomOpBlock random_op_block = ^RandomOpFn(cachedGraph, randomTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* argsortTensor = [mpsGraph argSortWithTensor:randomTensor
axis:0
name:nil];
MPSGraphTensor* argsortTensor = [mpsGraph argSortWithTensor:randomTensor axis:0 name:nil];
if (result.scalar_type() != kInt) {
argsortTensor = [mpsGraph castTensor:argsortTensor
toType:mps::getMPSDataType(result)
name:@"castOutput"];
argsortTensor = [mpsGraph castTensor:argsortTensor toType:mps::getMPSDataType(result) name:@"castOutput"];
}
return argsortTensor;
};

return mps::random_mps_impl<int64_t>(result, 0.0, 1.0, c10::nullopt, c10::nullopt,
MPSGraphRandomDistributionUniform, generator,
"ranperm_out_mps:" + mps::getTensorsStringKey({result}), random_op_block);
return mps::random_mps_impl<int64_t>(result,
0.0,
1.0,
c10::nullopt,
c10::nullopt,
MPSGraphRandomDistributionUniform,
generator,
"ranperm_out_mps:" + mps::getTensorsStringKey({result}),
random_op_block);
}
Tensor& multinomial_with_replacement_mps_kernel(
const Tensor& self,
Tensor& multinomial_with_replacement_mps_kernel(const Tensor& self,
const int64_t n_sample,
c10::optional<Generator> generator,
Tensor& result) {

using namespace mps;

auto mps_gen = get_generator_or_default<MPSGeneratorImpl>(generator, at::mps::detail::getDefaultMPSGenerator());
int inputSize = self.dim();
int numDist =
inputSize == 1 ? 1 : self.size(0);
int numCategories =
inputSize == 1 ? self.size(0) : self.size(1);
int numDist = inputSize == 1 ? 1 : self.size(0);
int numCategories = inputSize == 1 ? self.size(0) : self.size(1);

// Restructure data for 2d
auto self_v = inputSize == 1 ? self.view({numDist, numCategories}) : self;
@ -421,9 +448,7 @@ Tensor& multinomial_with_replacement_mps_kernel(
// This is probability weights
newCachedGraph->probTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self_v), prob_shape);

MPSGraphTensor *sumProbs = [mpsGraph reductionSumWithTensor:newCachedGraph->probTensor
axis:-1
name:nil];
MPSGraphTensor* sumProbs = [mpsGraph reductionSumWithTensor:newCachedGraph->probTensor axis:-1 name:nil];

MPSGraphTensor* normalizedProbs = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->probTensor
secondaryTensor:sumProbs
@ -436,10 +461,8 @@ Tensor& multinomial_with_replacement_mps_kernel(
MPSGraphTensor* ones = [mpsGraph constantWithScalar:1.0f
shape:@[ ns_numCategories, ns_numCategories ]
dataType:prob_dtype];
auto zeroTensor = [mpsGraph constantWithScalar: 0.0f
dataType: MPSDataTypeInt32];
auto minusOneTensor = [mpsGraph constantWithScalar: -1.0f
dataType: MPSDataTypeInt32];
auto zeroTensor = [mpsGraph constantWithScalar:0.0f dataType:MPSDataTypeInt32];
auto minusOneTensor = [mpsGraph constantWithScalar:-1.0f dataType:MPSDataTypeInt32];

MPSGraphTensor* upperTriangle = [mpsGraph bandPartWithTensor:ones
numLowerTensor:zeroTensor
@ -460,7 +483,8 @@ Tensor& multinomial_with_replacement_mps_kernel(
withShape:@[ ns_numDist, @1, ns_numCategories ]
name:nil];

MPSGraphRandomOpDescriptor *descriptor = [MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform
MPSGraphRandomOpDescriptor* descriptor =
[MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform
dataType:prob_dtype];
NSArray<MPSGraphTensor*>* generatorTensors = [mpsGraph randomTensorWithShape:@[ ns_numDist, ns_n_sample, @1 ]
descriptor:descriptor
@ -470,13 +494,12 @@ Tensor& multinomial_with_replacement_mps_kernel(

auto broadcastShape = @[ ns_numDist, ns_n_sample, ns_numCategories ];
int broadcastShapeVals[3] = {numDist, static_cast<int>(n_sample), numCategories};
MPSGraphTensor *broadcastShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count]
MPSGraphTensor* broadcastShapeTensor = [mpsGraph
constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count]
shape:@[ [NSNumber numberWithUnsignedInteger:broadcastShape.count] ]
dataType:MPSDataTypeUInt32];

MPSGraphTensor *samplesTensor = [mpsGraph broadcastTensor:randomTensor
toShape:broadcastShape
name:nil];
MPSGraphTensor* samplesTensor = [mpsGraph broadcastTensor:randomTensor toShape:broadcastShape name:nil];
MPSGraphTensor* sampleAbove = [mpsGraph greaterThanWithPrimaryTensor:samplesTensor
secondaryTensor:lowerProbRange
name:nil];
@ -486,18 +509,14 @@ Tensor& multinomial_with_replacement_mps_kernel(
MPSGraphTensor* sampleWithin = [mpsGraph logicalANDWithPrimaryTensor:sampleAbove
secondaryTensor:sampleBelow
name:nil];
MPSGraphTensor *sampleMask = [mpsGraph castTensor:sampleWithin
toType:MPSDataTypeInt32
name:@"sampleMask"];
MPSGraphTensor* sampleMask = [mpsGraph castTensor:sampleWithin toType:MPSDataTypeInt32 name:@"sampleMask"];
MPSGraphTensor* categoriesTensor = [mpsGraph coordinateAlongAxis:-1
withShapeTensor:broadcastShapeTensor
name:nil];
MPSGraphTensor* binnedSamplesTensor = [mpsGraph multiplicationWithPrimaryTensor:categoriesTensor
secondaryTensor:sampleMask
name:nil];
MPSGraphTensor *reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor
axis:-1
name:nil];
MPSGraphTensor* reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor axis:-1 name:nil];
MPSGraphTensor* reshapeTensor = [mpsGraph reshapeTensor:reducedTensor
withShape:@[ ns_numDist, ns_n_sample ]
name:nil];
@ -509,7 +528,8 @@ Tensor& multinomial_with_replacement_mps_kernel(
});
}
// update the Philox state values on each run of the same graph
MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(at::mps::detail::PHILOX_STATE_N)]];
MPSNDArrayDescriptor* stateDesc =
[MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeInt32 shape:@[ @(at::mps::detail::PHILOX_STATE_N) ]];
MPSNDArray* stateNDArray = [[[MPSNDArray alloc] initWithDevice:stream->device() descriptor:stateDesc] autorelease];
{
// See Note [Acquire lock when using random generators]
@ -526,15 +546,13 @@ Tensor& multinomial_with_replacement_mps_kernel(
cachedGraph->stateTensor : stateTensorData,
probPlaceholder.getMPSGraphTensor() : probPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}

return result;

}
/* The largest consecutive integer representable in float32 (2^24) */
@ -545,27 +563,20 @@ Tensor& multinomial_out_mps(const Tensor& self,
bool with_replacement,
c10::optional<Generator> gen,
Tensor& result) {

TORCH_CHECK(
result.device() == self.device(),
"multinomial arguments must have the same device");
TORCH_CHECK(
self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim");
TORCH_CHECK(
at::isFloatingType(self.scalar_type()),
TORCH_CHECK(result.device() == self.device(), "multinomial arguments must have the same device");
TORCH_CHECK(self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim");
TORCH_CHECK(at::isFloatingType(self.scalar_type()),
"multinomial only supports floating-point dtypes for input, got: ",
self.scalar_type());
TORCH_CHECK(result.scalar_type() == ScalarType::Long,
"multinomial expects Long tensor out, got: ", result.scalar_type());
TORCH_CHECK(
result.scalar_type() == ScalarType::Long, "multinomial expects Long tensor out, got: ", result.scalar_type());
TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples");
int64_t n_categories = self.size(-1);
TORCH_CHECK(with_replacement || (n_sample <= n_categories),
"cannot sample n_sample > prob_dist.size(-1) samples without replacement");
// Since the index tensor is float, numCategories cannot exceed max
// float integer precision
TORCH_CHECK(
n_categories <= FLOAT32_MAX_CONSECUTIVE_INT,
"number of categories cannot exceed 2^24");
TORCH_CHECK(n_categories <= FLOAT32_MAX_CONSECUTIVE_INT, "number of categories cannot exceed 2^24");

if (self.dim() == 1) {
result.resize_({n_sample});
@ -583,9 +594,7 @@ Tensor& multinomial_out_mps(const Tensor& self,
if (!with_replacement || n_sample == 1) {
// Sanity checks on `self`.
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
TORCH_CHECK(
is_valid.to<bool>(),
"probability tensor contains either `inf`, `nan` or element < 0");
TORCH_CHECK(is_valid.to<bool>(), "probability tensor contains either `inf`, `nan` or element < 0");
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool zero_prob_condition;
if (self.dim() == 1) {
@ -593,9 +602,7 @@ Tensor& multinomial_out_mps(const Tensor& self,
} else {
zero_prob_condition = (self.sum(1) == 0).sum().item().to<bool>();
}
TORCH_CHECK(
!zero_prob_condition,
"invalid multinomial distribution (sum of probabilities <= 0)");
TORCH_CHECK(!zero_prob_condition, "invalid multinomial distribution (sum of probabilities <= 0)");

// The algorithm is from gumbel softmax.
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)
@ -625,11 +632,7 @@ Tensor& multinomial_out_mps(const Tensor& self,
return result;
}

Tensor multinomial_mps(
const Tensor& self,
int64_t n_sample,
bool with_replacement,
c10::optional<Generator> gen) {
Tensor multinomial_mps(const Tensor& self, int64_t n_sample, bool with_replacement, c10::optional<Generator> gen) {
Tensor result = at::empty({0}, self.options().dtype(kLong));
multinomial_out_mps(self, n_sample, with_replacement, gen, result);
return result;
@ -3,9 +3,8 @@
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>
#include <c10/util/Optional.h>

#include <torch/library.h>

// Steps to add op for MPS backend:
// 1. Register the op in aten/src/ATen/native/native_functions.yaml with the "MPS" dispatch key
@ -29,7 +28,6 @@
// g) Then call runMPSGraph() with input params and return the result.
//

namespace at::native {

Tensor& eye_out_mps(int64_t n, Tensor& result) {
@ -38,7 +36,6 @@ Tensor& eye_out_mps(int64_t n, Tensor& result) {
}

Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {

// This is one example of boiler-plate error checking, taking after CPU/CUDA counterparts
TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n);
TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m);
@ -55,11 +52,10 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
MPSStream* stream = getCurrentMPSStream();

// Derive from MPSCachedGraph
// This structure is used to cache an MPSGraph with certain keys, so that we don't have to compile the same MPSGraph time and time again for the same operation
// The keys of this structure are based on the inputs and outputs needed for the operation
// Here, we don't have any input tensors, just an output tensor
struct CachedGraph : public MPSCachedGraph
{
// This structure is used to cache an MPSGraph with certain keys, so that we don't have to compile the same MPSGraph
// time and time again for the same operation The keys of this structure are based on the inputs and outputs needed
// for the operation Here, we don't have any input tensors, just an output tensor
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* outputTensor_ = nil;
};
@ -67,12 +63,12 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
MPSGraphCache* cache_ = MPSGraphCache::getInstance();

@autoreleasepool {
// A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types etc match the earlier created MPSGraph
// A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types
// etc match the earlier created MPSGraph
string key = "eye_out_mps:" + getTensorsStringKey({result});
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {

CachedGraph* newCachedGraph = nil;

@autoreleasepool {
@ -84,11 +80,9 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
dataType:getMPSDataType(result)];

// Here we can call the MPSGraph API needed to execute the operation.
// The API details can be found here: https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph
MPSGraphTensor* outputTensor = [mpsGraph bandPartWithTensor:onesTensor
numLower:0
numUpper:0
name:nil];
// The API details can be found here:
// https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph
MPSGraphTensor* outputTensor = [mpsGraph bandPartWithTensor:onesTensor numLower:0 numUpper:0 name:nil];
newCachedGraph->outputTensor_ = outputTensor;
}
return newCachedGraph;
@ -102,9 +96,8 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
// In this case, there are no inputs, so the feeds are nil
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

// Run the graph
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
@ -113,5 +106,4 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
return result;
}

} // namespace at::native
|
@ -1,12 +1,15 @@
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/GridSamplerUtils.h>
|
||||
#include <ATen/native/mps/MPSGraphVenturaOps.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor& grid,
|
||||
int64_t interpolation_mode, int64_t padding_mode,
|
||||
void grid_sampler_2d_mps_impl(Tensor& output,
|
||||
const Tensor& input,
|
||||
const Tensor& grid,
|
||||
int64_t interpolation_mode,
|
||||
int64_t padding_mode,
|
||||
bool align_corners) {
|
||||
// Grid Sampler support has been added in macOS 13.1
|
||||
#if defined(__MAC_13_2)
|
||||
@ -18,29 +21,37 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor&
|
||||
MPSGraphPaddingMode paddingMode;
|
||||
|
||||
auto memory_format = input.suggest_memory_format();
|
||||
MPSGraphTensorNamedDataLayout inputTensorLayout =
|
||||
(memory_format == at::MemoryFormat::Contiguous) ? MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutNHWC;
|
||||
MPSGraphTensorNamedDataLayout inputTensorLayout = (memory_format == at::MemoryFormat::Contiguous)
|
||||
? MPSGraphTensorNamedDataLayoutNCHW
|
||||
: MPSGraphTensorNamedDataLayoutNHWC;
|
||||
|
||||
switch (static_cast<GridSamplerPadding>(padding_mode)) {
|
||||
case GridSamplerPadding::Zeros:
|
||||
paddingMode = MPSGraphPaddingModeZero; break;
|
||||
paddingMode = MPSGraphPaddingModeZero;
|
||||
break;
|
||||
case GridSamplerPadding::Border:
|
||||
TORCH_CHECK(false, "MPS: Unsupported Border padding mode"); break;
|
||||
TORCH_CHECK(false, "MPS: Unsupported Border padding mode");
|
||||
break;
|
||||
case GridSamplerPadding::Reflection:
|
||||
paddingMode = align_corners == true ? MPSGraphPaddingModeReflect : MPSGraphPaddingModeSymmetric; break;
|
||||
paddingMode = align_corners == true ? MPSGraphPaddingModeReflect : MPSGraphPaddingModeSymmetric;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "MPS: Unrecognised Padding Mode: ", padding_mode);
|
||||
}
|
||||
|
||||
switch (static_cast<GridSamplerInterpolation>(interpolation_mode)) {
|
||||
case GridSamplerInterpolation::Bilinear:
|
||||
samplingMode = MPSGraphResizeBilinear; break;
|
||||
samplingMode = MPSGraphResizeBilinear;
|
||||
break;
|
||||
case GridSamplerInterpolation::Nearest:
|
||||
samplingMode = MPSGraphResizeNearest; break;
|
||||
samplingMode = MPSGraphResizeNearest;
|
||||
break;
|
||||
case GridSamplerInterpolation::Bicubic:
|
||||
TORCH_CHECK(false, "MPS: Unsupported Bicubic interpolation"); break;
|
||||
TORCH_CHECK(false, "MPS: Unsupported Bicubic interpolation");
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "MPS: Unrecognised interpolation mode: ", interpolation_mode); break;
|
||||
TORCH_CHECK(false, "MPS: Unrecognised interpolation mode: ", interpolation_mode);
|
||||
break;
|
||||
}
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
@ -55,16 +66,12 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor&
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "grid_sampler_2d_mps" +
|
||||
getTensorsStringKey({input, grid}) +
|
||||
":" + std::to_string(interpolation_mode) +
|
||||
":" + std::to_string(padding_mode) +
|
||||
":" + std::to_string(align_corners);
|
||||
string key = "grid_sampler_2d_mps" + getTensorsStringKey({input, grid}) + ":" + std::to_string(interpolation_mode) +
|
||||
":" + std::to_string(padding_mode) + ":" + std::to_string(align_corners);
|
||||
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
@ -111,22 +118,22 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor&
|
||||
Placeholder gridPlaceholder = Placeholder(cachedGraph->gridTensor_, grid);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
|
||||
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
|
||||
gridPlaceholder.getMPSGraphTensor() : gridPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
#endif // defined(__MAC_13_2)
|
||||
}
|
||||
|
||||
Tensor grid_sampler_2d_mps(const Tensor& input, const Tensor& grid,
|
||||
int64_t interpolation_mode, int64_t padding_mode,
|
||||
Tensor grid_sampler_2d_mps(const Tensor& input,
|
||||
const Tensor& grid,
|
||||
int64_t interpolation_mode,
|
||||
int64_t padding_mode,
|
||||
bool align_corners) {
|
||||
#if defined(__MAC_13_2)
|
||||
bool xcode_sdk_13_2_or_higher = true;
|
||||
@ -138,17 +145,16 @@ Tensor grid_sampler_2d_mps(const Tensor& input, const Tensor& grid,
|
||||
TORCH_WARN_ONCE("MPS: grid_sampler_2d op is supported natively starting from macOS 13.1. ",
|
||||
"Falling back on CPU. This may have performance implications.");
|
||||
|
||||
return at::grid_sampler_2d(
|
||||
input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners).clone().to("mps");
|
||||
return at::grid_sampler_2d(input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners)
|
||||
.clone()
|
||||
.to("mps");
|
||||
}
|
||||
|
||||
auto in_size = input.sizes();
|
||||
auto grid_size = grid.sizes();
|
||||
auto output = at::empty(
|
||||
{in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options());
|
||||
auto output = at::empty({in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options());
|
||||
|
||||
grid_sampler_2d_mps_impl(
|
||||
output, input, grid, interpolation_mode, padding_mode, align_corners);
|
||||
grid_sampler_2d_mps_impl(output, input, grid, interpolation_mode, padding_mode, align_corners);
|
||||
return output;
|
||||
}
|
||||
|
||||
|
@ -3,25 +3,25 @@
#include <ATen/Tensor.h>
#include <ATen/Utils.h>

#include <ATen/ceil_div.h>
#include <ATen/NativeFunctions.h>
#include <ATen/AccumulateType.h>
#include <ATen/ExpandUtils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/NativeFunctions.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/ceil_div.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/IndexKernel.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/operations/Indexing.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/Resize.h>
#include <ATen/AccumulateType.h>
#include <torch/library.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/IndexingUtils.h>
#include <c10/util/irange.h>
#include <c10/core/QScheme.h>
#include <c10/util/SmallVector.h>
#include <ATen/native/IndexKernel.h>
#include <c10/util/irange.h>
#include <torch/library.h>

#ifdef __OBJC__
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
@ -29,8 +29,7 @@
|
||||
|
||||
namespace at::native {
|
||||
|
||||
static
|
||||
bool dispatchIndexKernel(TensorIteratorBase& iter,
|
||||
static bool dispatchIndexKernel(TensorIteratorBase& iter,
|
||||
IntArrayRef index_size,
|
||||
IntArrayRef index_stride,
|
||||
bool index_select,
|
||||
@ -73,12 +72,14 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
|
||||
MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
|
||||
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
|
||||
id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
|
||||
id<MTLFunction> kernelDataOffsetsFunction = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
|
||||
id<MTLComputePipelineState> kernelDataOffsetsPSO = [[device newComputePipelineStateWithFunction: kernelDataOffsetsFunction
|
||||
error: &error] autorelease];
|
||||
id<MTLFunction> kernelDataOffsetsFunction =
|
||||
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
|
||||
id<MTLComputePipelineState> kernelDataOffsetsPSO =
|
||||
[[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
|
||||
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
|
||||
options:0] autorelease];
|
||||
TORCH_CHECK(kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
|
||||
TORCH_CHECK(
|
||||
kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
|
||||
|
||||
[computeEncoder setComputePipelineState:kernelDataOffsetsPSO];
|
||||
[computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0];
|
||||
@ -92,14 +93,14 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
|
||||
kernelOffsetsTGSize = numThreads;
|
||||
|
||||
MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1);
|
||||
[computeEncoder dispatchThreads: gridSize
|
||||
threadsPerThreadgroup: kernelOffsetsThreadGroupSize];
|
||||
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:kernelOffsetsThreadGroupSize];
|
||||
|
||||
MTLFunctionConstantValues* constantValues = [[MTLFunctionConstantValues new] autorelease];
|
||||
[constantValues setConstantValue:&num_indices type:MTLDataTypeUInt atIndex:0];
|
||||
|
||||
std::string indexFunction = getIndexFunctionName(inputTensor.scalar_type(), index_select, accumulate);
|
||||
id<MTLFunction> indexKernelFunction = MPSDevice::getInstance()->metalIndexingFunction(indexFunction, constantValues);
|
||||
id<MTLFunction> indexKernelFunction =
|
||||
MPSDevice::getInstance()->metalIndexingFunction(indexFunction, constantValues);
|
||||
id<MTLArgumentEncoder> argumentEncoder = [[indexKernelFunction newArgumentEncoderWithBufferIndex:0] autorelease];
|
||||
NSUInteger argumentBufferLength = argumentEncoder.encodedLength;
|
||||
id<MTLBuffer> indexAB = [[device newBufferWithLength:argumentBufferLength options:0] autorelease];
|
||||
@ -129,15 +130,16 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
|
||||
[computeEncoder setBytes:index_stride.data() length:sizeof(index_stride[0]) * index_stride.size() atIndex:2];
|
||||
[computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3];
|
||||
[computeEncoder setBuffer:inputBuffer offset:inputTensor.storage_offset() * inputTensor.element_size() atIndex:4];
|
||||
[computeEncoder setBuffer:outputBuffer offset:outputTensor.storage_offset() * outputTensor.element_size() atIndex:5];
|
||||
[computeEncoder setBuffer:outputBuffer
|
||||
offset:outputTensor.storage_offset() * outputTensor.element_size()
|
||||
atIndex:5];
|
||||
|
||||
NSUInteger tgSize = indexSelectPSO.maxTotalThreadsPerThreadgroup;
|
||||
if (tgSize > numThreads)
|
||||
tgSize = numThreads;
|
||||
|
||||
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
|
||||
[computeEncoder dispatchThreads: gridSize
|
||||
threadsPerThreadgroup: threadGroupSize];
|
||||
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];
|
||||
|
||||
[computeEncoder endEncoding];
|
||||
mpsStream->synchronize(SyncType::COMMIT_AND_CONTINUE);
|
||||
@ -147,7 +149,11 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
|
||||
return true;
|
||||
}
|
||||
|
||||
static void validateInputData(const TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, const std::string& op, bool accumulate) {
|
||||
static void validateInputData(const TensorIteratorBase& iter,
|
||||
IntArrayRef index_size,
|
||||
IntArrayRef index_stride,
|
||||
const std::string& op,
|
||||
bool accumulate) {
|
||||
using namespace mps;
|
||||
|
||||
int64_t num_indices = index_size.size();
|
||||
@ -159,13 +165,11 @@ static void validateInputData(const TensorIteratorBase& iter, IntArrayRef index_
|
||||
|
||||
if (accumulate) {
|
||||
// No atomic support for the rest of dtypes
|
||||
TORCH_CHECK(inputTensor.scalar_type() == ScalarType::Float ||
|
||||
inputTensor.scalar_type() == ScalarType::Int ||
|
||||
TORCH_CHECK(inputTensor.scalar_type() == ScalarType::Float || inputTensor.scalar_type() == ScalarType::Int ||
|
||||
inputTensor.scalar_type() == ScalarType::Bool);
|
||||
} else {
|
||||
TORCH_CHECK(c10::isIntegralType(inputTensor.scalar_type(), /*includesBool=*/true) ||
|
||||
inputTensor.scalar_type() == ScalarType::Float ||
|
||||
inputTensor.scalar_type() == ScalarType::Half,
|
||||
inputTensor.scalar_type() == ScalarType::Float || inputTensor.scalar_type() == ScalarType::Half,
|
||||
getMPSTypeString(inputTensor) + std::string(" not supported for index.Tensor_out"));
|
||||
}
|
||||
}
|
||||
@ -189,31 +193,27 @@ void index_put_kernel_mps(TensorIterator& iter, IntArrayRef index_size, IntArray
|
||||
static Tensor& masked_select_out_mps_impl(Tensor& result, const Tensor& self, const Tensor& mask) {
|
||||
NoNamesGuard guard;
|
||||
|
||||
TORCH_CHECK(mask.scalar_type() == ScalarType::Bool,
|
||||
"masked_select: expected BoolTensor for mask");
|
||||
TORCH_CHECK(mask.scalar_type() == ScalarType::Bool, "masked_select: expected BoolTensor for mask");
|
||||
TORCH_CHECK(self.scalar_type() == result.scalar_type(),
|
||||
"masked_select(): self and result must have the same scalar type");
|
||||
|
||||
auto mask_temp = (mask.dim() == 0)
|
||||
? c10::MaybeOwned<Tensor>::owned(mask.unsqueeze(0))
|
||||
: c10::MaybeOwned<Tensor>::borrowed(mask);
|
||||
auto self_temp = (self.dim() == 0)
|
||||
? c10::MaybeOwned<Tensor>::owned(self.unsqueeze(0))
|
||||
: c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
auto mask_temp =
|
||||
(mask.dim() == 0) ? c10::MaybeOwned<Tensor>::owned(mask.unsqueeze(0)) : c10::MaybeOwned<Tensor>::borrowed(mask);
|
||||
auto self_temp =
|
||||
(self.dim() == 0) ? c10::MaybeOwned<Tensor>::owned(self.unsqueeze(0)) : c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
|
||||
// Cannot reassign to mask_temp and self_temp here! if they are
|
||||
// owning and expand_outplace returns a borrow, the returned borrow
|
||||
// would dangle.
|
||||
auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp);
|
||||
at::index_out(
|
||||
result, *std::get<1>(mask_self_expanded),
|
||||
at::index_out(result,
|
||||
*std::get<1>(mask_self_expanded),
|
||||
c10::List<c10::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))}));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static
|
||||
Tensor nonzero_fallback(const Tensor& self) {
|
||||
static Tensor nonzero_fallback(const Tensor& self) {
|
||||
TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 13.0. ",
|
||||
"Falling back on CPU. This may have performance implications.");
|
||||
|
||||
@ -237,17 +237,21 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){
|
||||
using namespace mps;
|
||||
const uint32_t maxDimensions = 16;
|
||||
|
||||
TORCH_CHECK(self.numel() < std::numeric_limits<int>::max(), "nonzero is not supported for tensors with more than INT_MAX elements, \
|
||||
TORCH_CHECK(self.numel() < std::numeric_limits<int>::max(),
|
||||
"nonzero is not supported for tensors with more than INT_MAX elements, \
|
||||
file a support request");
|
||||
TORCH_CHECK(out_.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out_.dtype());
|
||||
TORCH_CHECK(self.device() == out_.device(), "expected self and out to be on the same device, but got out on ",
|
||||
out_.device(), " and self on ", self.device());
|
||||
TORCH_CHECK(
|
||||
out_.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out_.dtype());
|
||||
TORCH_CHECK(self.device() == out_.device(),
|
||||
"expected self and out to be on the same device, but got out on ",
|
||||
out_.device(),
|
||||
" and self on ",
|
||||
self.device());
|
||||
TORCH_CHECK(self.dim() <= maxDimensions, "nonzero is not supported for tensor with more than ", 16, " dimensions");
|
||||
TORCH_CHECK(out_.is_mps());
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
@ -258,12 +262,7 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){
|
||||
stream->synchronize(SyncType::COMMIT_AND_WAIT);
|
||||
Tensor count_nonzero = at::empty({1}, self.options().dtype(kInt));
|
||||
Tensor out = at::native::empty_mps(
|
||||
{self.numel(), nDim == 0 ? 1 : nDim},
|
||||
out_.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
c10::nullopt);
|
||||
{self.numel(), nDim == 0 ? 1 : nDim}, out_.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
|
||||
|
||||
int64_t _apparentInputShape = 1;
|
||||
for (auto dim : self.sizes()) {
|
||||
@ -295,23 +294,21 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), apparentInputShape);
|
||||
MPSGraphTensor *scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(out.scalar_type()));
|
||||
MPSGraphTensor* inputTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), apparentInputShape);
|
||||
MPSGraphTensor* scatterDataTensor =
|
||||
mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(out.scalar_type()));
|
||||
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputDataType];
|
||||
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* minusMaxDimTensor = [mpsGraph constantWithScalar:-maxDimensions dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* inputNotEqualToZeroTensor = [mpsGraph notEqualWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:zeroTensor
|
||||
name:nil];
|
||||
MPSGraphTensor *countNonzero = [mpsGraph reductionSumWithTensor:inputNotEqualToZeroTensor
|
||||
axis:0
|
||||
name:nil];
|
||||
MPSGraphTensor* countNonzero = [mpsGraph reductionSumWithTensor:inputNotEqualToZeroTensor axis:0 name:nil];
|
||||
MPSGraphTensor* maskTensor = [mpsGraph castTensor:inputNotEqualToZeroTensor
|
||||
toType:MPSDataTypeInt32
|
||||
name:@"castToInt32"];
|
||||
MPSGraphTensor *indicesTensor = [mpsGraph cumulativeSumWithTensor:maskTensor
|
||||
axis:0
|
||||
name:nil];
|
||||
MPSGraphTensor* indicesTensor = [mpsGraph cumulativeSumWithTensor:maskTensor axis:0 name:nil];
|
||||
MPSGraphTensor* indicesMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:indicesTensor
|
||||
secondaryTensor:oneTensor
|
||||
name:nil];
|
||||
@ -319,15 +316,16 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){
|
||||
truePredicateTensor:indicesMinusOneTensor
|
||||
falsePredicateTensor:minusMaxDimTensor
|
||||
name:nil];
|
||||
MPSGraphTensor *coordinatesTensor = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:0 withShape:inputShape name:nil]
|
||||
MPSGraphTensor* coordinatesTensor = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:0
|
||||
withShape:inputShape
|
||||
name:nil]
|
||||
withShape:@[ @-1 ]
|
||||
name:nil];
|
||||
if (nDim > 1) {
|
||||
NSMutableArray<MPSGraphTensor*>* maskedIndicesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
|
||||
NSMutableArray<MPSGraphTensor*>* coordinatesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
|
||||
|
||||
MPSGraphTensor *constantRankTensor = [mpsGraph constantWithScalar:nDim
|
||||
dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* constantRankTensor = [mpsGraph constantWithScalar:nDim dataType:MPSDataTypeInt32];
|
||||
maskedIndicesTensorArray[0] = [mpsGraph multiplicationWithPrimaryTensor:maskedIndicesTensor
|
||||
secondaryTensor:constantRankTensor
|
||||
name:nil];
|
||||
@ -336,7 +334,9 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){
|
||||
maskedIndicesTensorArray[i] = [mpsGraph additionWithPrimaryTensor:maskedIndicesTensorArray[i - 1]
|
||||
secondaryTensor:oneTensor
|
||||
name:nil];
|
||||
coordinatesTensorArray[i] = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:i withShape:inputShape name:nil]
|
||||
coordinatesTensorArray[i] = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:i
|
||||
withShape:inputShape
|
||||
name:nil]
|
||||
withShape:@[ @-1 ]
|
||||
name:nil];
|
||||
}
|
||||
@ -409,13 +409,8 @@ Tensor & masked_select_out_mps(const Tensor & self, const Tensor & mask, Tensor
|
||||
Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
|
||||
using namespace mps;
|
||||
|
||||
Tensor result = at::native::empty_mps(
|
||||
self.sizes(),
|
||||
self.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
c10::nullopt);
|
||||
Tensor result =
|
||||
at::native::empty_mps(self.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
|
||||
|
||||
auto total_dims = self.dim();
|
||||
// It wraps the dims and checks that there are no repeated dims
|
||||
@ -451,12 +446,12 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
|
||||
}
|
||||
@autoreleasepool {
|
||||
NSString* ns_dims_key = [[ns_dims valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
// A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types etc match the earlier created MPSGraph
|
||||
// A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types
|
||||
// etc match the earlier created MPSGraph
|
||||
string key = "flip_mps:" + getTensorsStringKey({self}) + ":" + string([ns_dims_key UTF8String]);
|
||||
auto cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
@ -464,9 +459,7 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(self));
|
||||
MPSGraphTensor* outputTensor = [mpsGraph reverseTensor:inputTensor
|
||||
axes:ns_dims
|
||||
name:nil];
|
||||
MPSGraphTensor* outputTensor = [mpsGraph reverseTensor:inputTensor axes:ns_dims name:nil];
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
newCachedGraph->outputTensor_ = outputTensor;
|
||||
}
|
||||
@ -475,36 +468,31 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
|
||||
}
|
||||
|
||||
// Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation
|
||||
Placeholder inputPlaceholder = Placeholder(
|
||||
cachedGraph->inputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/true, inputDataType);
|
||||
Placeholder outputPlaceholder = Placeholder(
|
||||
cachedGraph->outputTensor_, result, /*mpsShape*/nil, /*gatherTensorData=*/false, outputDataType);
|
||||
Placeholder inputPlaceholder =
|
||||
Placeholder(cachedGraph->inputTensor_, self, /*mpsShape*/ nil, /*gatherTensorData=*/true, inputDataType);
|
||||
Placeholder outputPlaceholder =
|
||||
Placeholder(cachedGraph->outputTensor_, result, /*mpsShape*/ nil, /*gatherTensorData=*/false, outputDataType);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
|
||||
@{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
// Run the graph
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(index_add_mps_out)(
|
||||
const Tensor& self,
|
||||
TORCH_IMPL_FUNC(index_add_mps_out)
|
||||
(const Tensor& self,
|
||||
int64_t dim,
|
||||
const Tensor& index,
|
||||
const Tensor& source,
|
||||
const Scalar& alpha,
|
||||
const Tensor& result) {
|
||||
|
||||
using namespace mps;
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
dim = maybe_wrap_dim(dim, self.dim());
|
||||
@ -515,8 +503,7 @@ TORCH_IMPL_FUNC(index_add_mps_out)(
|
||||
TORCH_CHECK(source.scalar_type() != ScalarType::Long, "index_add(): Expected non int64 dtype for source.");
|
||||
auto casted_type = isFloatingType(source.scalar_type()) ? ScalarType::Float : ScalarType::Int;
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* indexTensor_ = nil;
|
||||
@ -528,7 +515,6 @@ TORCH_IMPL_FUNC(index_add_mps_out)(
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
string key = "index_add_mps_out" + getTensorsStringKey({self, index, source}) + ":" + std::to_string(dim);
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
|
||||
@ -585,17 +571,14 @@ TORCH_IMPL_FUNC(index_add_mps_out)(
|
||||
sourcePlaceholder.getMPSGraphTensor() : sourcePlaceholder.getMPSGraphTensorData(),
|
||||
cachedGraph->alphaTensor_ : getMPSGraphTensorFromScalar(stream, alpha_scalar),
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor index_select_mps(const Tensor & self,
|
||||
int64_t dim,
|
||||
const Tensor & index) {
|
||||
Tensor index_select_mps(const Tensor& self, int64_t dim, const Tensor& index) {
|
||||
IntArrayRef input_shape = self.sizes();
|
||||
auto num_input_dims = input_shape.size();
|
||||
|
||||
@ -616,33 +599,24 @@ Tensor index_select_mps(const Tensor & self,
|
||||
|
||||
IntArrayRef output_shape = IntArrayRef(shape_data.data(), num_input_dims);
|
||||
|
||||
Tensor result = at::native::empty_mps(
|
||||
output_shape,
|
||||
self.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
c10::nullopt);
|
||||
Tensor result =
|
||||
at::native::empty_mps(output_shape, self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
|
||||
|
||||
index_select_out_mps(self, dim, index, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor& index_select_out_mps(const Tensor & self,
|
||||
int64_t dim,
|
||||
const Tensor & index,
|
||||
Tensor & output) {
|
||||
|
||||
Tensor& index_select_out_mps(const Tensor& self, int64_t dim, const Tensor& index, Tensor& output) {
|
||||
using namespace mps;
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
dim = maybe_wrap_dim(dim, self.dim());
|
||||
// Checks
|
||||
TORCH_CHECK_INDEX(index.dim() <= 1, "index_select(): Index is supposed to be a vector");
|
||||
TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_select(): Expected dtype int32 or int64 for index");
|
||||
TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int,
|
||||
"index_select(): Expected dtype int32 or int64 for index");
|
||||
TORCH_CHECK(self.scalar_type() == output.scalar_type(),
|
||||
"index_select(): self and output must have the same scalar type");
|
||||
TORCH_CHECK(dim == 0 || dim < self.dim(),
|
||||
"index_select(): Indexing dim ", dim, " is out of bounds of tensor");
|
||||
TORCH_CHECK(dim == 0 || dim < self.dim(), "index_select(): Indexing dim ", dim, " is out of bounds of tensor");
|
||||
|
||||
// Empty index
|
||||
if (index.numel() == 0) {
|
||||
@ -656,8 +630,7 @@ Tensor& index_select_out_mps(const Tensor & self,
|
||||
}
|
||||
|
||||
// Derive from MPSCachedGraph
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* indexTensor_ = nil;
|
||||
@ -667,17 +640,14 @@ Tensor& index_select_out_mps(const Tensor & self,
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
auto inputType = getMPSDataType(self);
|
||||
auto outputType = getMPSDataType(output);
|
||||
if (inputType == MPSDataTypeUInt8 ||
|
||||
(!is_macos_13_or_newer() && inputType == MPSDataTypeBool)) {
|
||||
if (inputType == MPSDataTypeUInt8 || (!is_macos_13_or_newer() && inputType == MPSDataTypeBool)) {
|
||||
inputType = MPSDataTypeInt8;
|
||||
}
|
||||
if (outputType == MPSDataTypeUInt8 ||
|
||||
(!is_macos_13_or_newer() && outputType == MPSDataTypeBool)) {
|
||||
if (outputType == MPSDataTypeUInt8 || (!is_macos_13_or_newer() && outputType == MPSDataTypeBool)) {
|
||||
outputType = MPSDataTypeInt8;
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
string key = "index_select_out_mps" + getTensorsStringKey({self, index}) + ":" + std::to_string(dim);
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
|
||||
@ -706,25 +676,29 @@ Tensor& index_select_out_mps(const Tensor & self,
|
||||
});
|
||||
}
|
||||
|
||||
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self,
|
||||
/*mpsShape=*/nullptr, /*gatherTensorData=*/true, /*dataType=*/inputType);
|
||||
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_,
|
||||
self,
|
||||
/*mpsShape=*/nullptr,
|
||||
/*gatherTensorData=*/true,
|
||||
/*dataType=*/inputType);
|
||||
Placeholder indexPlaceholder = Placeholder(cachedGraph->indexTensor_, index);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output,
|
||||
/*mpsShape=*/nullptr, /*gatherTensorData=*/false, /*dataType=*/outputType);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_,
|
||||
output,
|
||||
/*mpsShape=*/nullptr,
|
||||
/*gatherTensorData=*/false,
|
||||
/*dataType=*/outputType);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
|
||||
indexPlaceholder.getMPSGraphTensor() : indexPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
|
||||
return output;
|
||||
|
||||
}
|
||||
|
||||
Tensor& masked_fill__mps(Tensor& self, const Tensor& mask, const Scalar& value) {
|
||||
@ -733,16 +707,19 @@ Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value
|
||||
if (self.numel() == 0) {
|
||||
return self;
|
||||
}
|
||||
TORCH_CHECK(self.device() == mask.device(), "expected self and mask to be on the same device, but got mask on ",
|
||||
mask.device(), " and self on ", self.device());
|
||||
TORCH_CHECK(self.device() == mask.device(),
|
||||
"expected self and mask to be on the same device, but got mask on ",
|
||||
mask.device(),
|
||||
" and self on ",
|
||||
self.device());
|
||||
TORCH_CHECK(mask.scalar_type() == kByte || mask.scalar_type() == kBool,
|
||||
"expected mask dtype to be Bool but got ", mask.scalar_type());
|
||||
"expected mask dtype to be Bool but got ",
|
||||
mask.scalar_type());
|
||||
auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_");
|
||||
|
||||
c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_fill_");
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* maskTensor_ = nil;
|
||||
@ -772,7 +749,6 @@ Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
@ -786,9 +762,7 @@ Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value
|
||||
MPSDataType valueType = getMPSScalarType(value.type());
|
||||
MPSGraphTensor* castValueTensor = valueTensor;
|
||||
if (valueType != inputDataType) {
|
||||
castValueTensor = [mpsGraph castTensor:valueTensor
|
||||
toType:inputDataType
|
||||
name:@"castValueTensor"];
|
||||
castValueTensor = [mpsGraph castTensor:valueTensor toType:inputDataType name:@"castValueTensor"];
|
||||
}
|
||||
|
||||
MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:maskTensor
|
||||
@ -805,12 +779,12 @@ Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value
|
||||
});
|
||||
}
|
||||
|
||||
Placeholder selfPlaceholder = Placeholder(
|
||||
cachedGraph->inputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/true, inputDataType);
|
||||
Placeholder maskPlaceholder = Placeholder(
|
||||
cachedGraph->maskTensor_, *b_mask, /*mpsShape*/nil, /*gatherTensorData=*/true, maskDataType);
|
||||
Placeholder outputPlaceholder = Placeholder(
|
||||
cachedGraph->outputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/false, inputDataType);
|
||||
Placeholder selfPlaceholder =
|
||||
Placeholder(cachedGraph->inputTensor_, self, /*mpsShape*/ nil, /*gatherTensorData=*/true, inputDataType);
|
||||
Placeholder maskPlaceholder =
|
||||
Placeholder(cachedGraph->maskTensor_, *b_mask, /*mpsShape*/ nil, /*gatherTensorData=*/true, maskDataType);
|
||||
Placeholder outputPlaceholder =
|
||||
Placeholder(cachedGraph->outputTensor_, self, /*mpsShape*/ nil, /*gatherTensorData=*/false, inputDataType);
|
||||
|
||||
// Create dictionary of inputs and outputs
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
@ -819,9 +793,8 @@ Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value
|
||||
cachedGraph->valueTensor_ : getMPSGraphTensorFromScalar(stream, valueScalar)
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -829,14 +802,14 @@ Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor embedding_dense_backward_mps(
|
||||
const Tensor & grad_, const Tensor & indices, int64_t num_weights,
|
||||
int64_t padding_idx, bool scale_grad_by_freq)
|
||||
{
|
||||
Tensor embedding_dense_backward_mps(const Tensor& grad_,
|
||||
const Tensor& indices,
|
||||
int64_t num_weights,
|
||||
int64_t padding_idx,
|
||||
bool scale_grad_by_freq) {
|
||||
// TODO: implement padding_idx & scale_grad_by_freq.
|
||||
namespace native_mps = at::native::mps;
|
||||
struct CachedGraph : public native_mps::MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public native_mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* incomingGradTensor_ = nil;
|
||||
MPSGraphTensor* indicesTensor_ = nil;
|
||||
@ -854,12 +827,7 @@ Tensor embedding_dense_backward_mps(
|
||||
int64_t D = incoming_gradient_shape[num_incoming_gradient_dims - 1];
|
||||
c10::SmallVector<int64_t, 2> outgoing_gradient_shape{num_weights, D};
|
||||
Tensor outgoing_gradient = at::native::empty_mps(
|
||||
IntArrayRef(outgoing_gradient_shape),
|
||||
grad_.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
c10::nullopt);
|
||||
IntArrayRef(outgoing_gradient_shape), grad_.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
|
||||
|
||||
if (outgoing_gradient.numel() == 0) {
|
||||
return outgoing_gradient;
|
||||
@ -868,21 +836,24 @@ Tensor embedding_dense_backward_mps(
|
||||
auto stream = at::mps::getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "edb_mps:" + native_mps::getMPSTypeString(grad_) + ":indices" + std::to_string(num_indices_dims) + ":num_weights" + std::to_string(num_weights) + ":padding_idx" + std::to_string(padding_idx) + ":scaled" + std::to_string(scale_grad_by_freq);
|
||||
string key = "edb_mps:" + native_mps::getMPSTypeString(grad_) + ":indices" + std::to_string(num_indices_dims) +
|
||||
":num_weights" + std::to_string(num_weights) + ":padding_idx" + std::to_string(padding_idx) + ":scaled" +
|
||||
std::to_string(scale_grad_by_freq);
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
// Initialize once if configuration not found in cache
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^native_mps::MPSCachedGraph*() {
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = native_mps::make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor* incomingGradTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(grad_));
|
||||
MPSGraphTensor* incomingGradTensor =
|
||||
native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(grad_));
|
||||
|
||||
MPSGraphTensor* indicesTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(indices));
|
||||
MPSGraphTensor* indicesTensor =
|
||||
native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(indices));
|
||||
|
||||
MPSGraphTensor* reshapedIndicesTensor = indicesTensor;
|
||||
|
||||
@ -890,17 +861,14 @@ Tensor embedding_dense_backward_mps(
|
||||
MPSDataType dataType = mps::getMPSDataType(grad_);
|
||||
// issue 105486100, scatterNDWithUpdatesTensor produces wrong result for float16
|
||||
if (dataType == MPSDataTypeFloat16) {
|
||||
castGradTensor = [mpsGraph castTensor: incomingGradTensor
|
||||
toType: MPSDataTypeFloat32
|
||||
name: @"castGradTensor"];
|
||||
castGradTensor = [mpsGraph castTensor:incomingGradTensor toType:MPSDataTypeFloat32 name:@"castGradTensor"];
|
||||
}
|
||||
if (num_indices_dims != 0) {
|
||||
reshapedIndicesTensor = [mpsGraph expandDimsOfTensor: indicesTensor
|
||||
axes: @[@-1]
|
||||
name: nil];
|
||||
reshapedIndicesTensor = [mpsGraph expandDimsOfTensor:indicesTensor axes:@[ @-1 ] name:nil];
|
||||
}
|
||||
|
||||
auto outgoingGradTensor = [mpsGraph scatterNDWithUpdatesTensor: castGradTensor
|
||||
auto outgoingGradTensor =
|
||||
[mpsGraph scatterNDWithUpdatesTensor:castGradTensor
|
||||
indicesTensor:reshapedIndicesTensor
|
||||
shape:native_mps::getMPSShape(IntArrayRef(outgoing_gradient_shape))
|
||||
batchDimensions:0
|
||||
@ -914,7 +882,6 @@ Tensor embedding_dense_backward_mps(
|
||||
newCachedGraph->incomingGradTensor_ = incomingGradTensor;
|
||||
newCachedGraph->indicesTensor_ = indicesTensor;
|
||||
newCachedGraph->outgoingGradTensor_ = outgoingGradTensor;
|
||||
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
@ -928,24 +895,25 @@ Tensor embedding_dense_backward_mps(
|
||||
indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
|
||||
outgoingGradPlaceholder.getMPSGraphTensor() : outgoingGradPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outgoingGradPlaceholder.getMPSGraphTensor() : outgoingGradPlaceholder.getMPSGraphTensorData()};
|
||||
native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
return outgoing_gradient;
|
||||
}
|
||||
|
||||
Tensor& masked_fill__mps(Tensor& self, const Tensor& mask, const Tensor& value) {
|
||||
TORCH_CHECK(value.dim() == 0, "masked_fill_ only supports a 0-dimensional value tensor, but got tensor "
|
||||
"with ", value.dim(), " dimension(s).");
|
||||
TORCH_CHECK(value.dim() == 0,
|
||||
"masked_fill_ only supports a 0-dimensional value tensor, but got tensor "
|
||||
"with ",
|
||||
value.dim(),
|
||||
" dimension(s).");
|
||||
return masked_fill__mps(self, mask, value.item());
|
||||
}
|
||||
|
||||
Tensor& masked_scatter__mps(Tensor& self, const Tensor& mask, const Tensor& source) {
|
||||
at::assert_no_internal_overlap(self);
|
||||
TORCH_CHECK(
|
||||
self.scalar_type() == source.scalar_type(),
|
||||
TORCH_CHECK(self.scalar_type() == source.scalar_type(),
|
||||
"masked_scatter: expected self and source to have same dtypes but got",
|
||||
self.scalar_type(),
|
||||
" and ",
|
||||
@ -958,21 +926,18 @@ Tensor & masked_scatter__mps(Tensor& self, const Tensor& mask, const Tensor& sou
|
||||
TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool,
|
||||
"masked_scatter: expected BoolTensor or ByteTensor for mask");
|
||||
|
||||
auto mask_temp = (mask.dim() == 0)
|
||||
? c10::MaybeOwned<Tensor>::owned(mask.unsqueeze(0))
|
||||
: c10::MaybeOwned<Tensor>::borrowed(mask);
|
||||
auto self_temp = (self.dim() == 0)
|
||||
? c10::MaybeOwned<Tensor>::owned(self.unsqueeze(0))
|
||||
: c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
auto mask_temp =
|
||||
(mask.dim() == 0) ? c10::MaybeOwned<Tensor>::owned(mask.unsqueeze(0)) : c10::MaybeOwned<Tensor>::borrowed(mask);
|
||||
auto self_temp =
|
||||
(self.dim() == 0) ? c10::MaybeOwned<Tensor>::owned(self.unsqueeze(0)) : c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
|
||||
// Cannot reassign to mask_temp and self_temp here! if they are
|
||||
// owning and expand_outplace returns a borrow, the returned borrow
|
||||
// would dangle.
|
||||
auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp);
|
||||
auto indices = at::native::expandTensors(
|
||||
*std::get<1>(mask_self_expanded),
|
||||
c10::List<c10::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))})
|
||||
);
|
||||
auto indices =
|
||||
at::native::expandTensors(*std::get<1>(mask_self_expanded),
|
||||
c10::List<c10::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))}));
|
||||
// next broadcast all index tensors together
|
||||
try {
|
||||
indices = at::expand_outplace(indices);
|
||||
@ -990,12 +955,7 @@ Tensor & masked_scatter__mps(Tensor& self, const Tensor& mask, const Tensor& sou
|
||||
for (const auto index : indices) {
|
||||
final_indices.push_back(index);
|
||||
}
|
||||
return at::index_put_out(
|
||||
self,
|
||||
*std::get<1>(mask_self_expanded),
|
||||
final_indices,
|
||||
source.resize_(indices[0].numel())
|
||||
);
|
||||
return at::index_put_out(self, *std::get<1>(mask_self_expanded), final_indices, source.resize_(indices[0].numel()));
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(index_stub, &index_kernel_mps);
|
||||
|
@ -1,17 +1,16 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/mps/MPSGraphVenturaOps.h>
|
||||
#include <torch/library.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
#include <torch/library.h>
|
||||
|
||||
namespace at::native {
|
||||
|
||||
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info)
|
||||
{
|
||||
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
|
||||
TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
|
||||
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
|
||||
TORCH_WARN_ONCE("torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU.");
|
||||
TORCH_WARN_ONCE(
|
||||
"torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU.");
|
||||
auto cpu_info = at::empty({0}, kInt, c10::nullopt, kCPU, c10::nullopt, c10::nullopt);
|
||||
auto cpu_result = result.clone().to("cpu");
|
||||
at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu"));
|
||||
@ -28,8 +27,7 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const
|
||||
return;
|
||||
}
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
@ -47,24 +45,20 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const
|
||||
@autoreleasepool {
|
||||
string key = "inv_out_mps" + getTensorsStringKey({A});
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if(!cachedGraph)
|
||||
{
|
||||
if (!cachedGraph) {
|
||||
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, A);
|
||||
MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor: inputTensor
|
||||
name: nil];
|
||||
MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor:inputTensor name:nil];
|
||||
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
newCachedGraph->outputTensor_ = outputTensor;
|
||||
}
|
||||
|
||||
return newCachedGraph;
|
||||
|
||||
});
|
||||
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
|
||||
}
|
||||
@ -72,13 +66,11 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, isContiguous ? result : output);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
|
||||
@{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
if (!isContiguous) {
|
||||
|
@ -6,17 +6,14 @@ namespace at::native {
|
||||
|
||||
using namespace mps;
|
||||
|
||||
Tensor _mps_linear(
|
||||
const Tensor& input,
|
||||
const Tensor& weight_arg,
|
||||
const c10::optional<Tensor>& bias_opt) {
|
||||
Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const c10::optional<Tensor>& bias_opt) {
|
||||
// wT = transpose(weight);
|
||||
// y=x*wT+b
|
||||
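For reference, the forward computation the two comments above describe (with $x$ the input, $W$ the weight that gets transposed, and $b$ the optional bias) is:

$$ y = x\,W^{\top} + b $$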
|
||||
auto weight = (weight_arg.dim() == 1) ? weight_arg.view({1, weight_arg.size(0)}) : weight_arg;
|
||||
|
||||
TORCH_CHECK(input.scalar_type() == ScalarType::Float ||
|
||||
input.scalar_type() == ScalarType::Half, "MPS device does not support linear for non-float inputs");
|
||||
TORCH_CHECK(input.scalar_type() == ScalarType::Float || input.scalar_type() == ScalarType::Half,
|
||||
"MPS device does not support linear for non-float inputs");
|
||||
|
||||
const Tensor& bias = *(at::borrow_from_optional_tensor(bias_opt));
|
||||
bool is_bias_defined = bias.defined();
|
||||
@ -24,12 +21,8 @@ Tensor _mps_linear(
|
||||
auto input_size = input.sizes();
|
||||
std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
|
||||
output_size.push_back(weight.size(0));
|
||||
Tensor output = at::native::empty_mps(output_size,
|
||||
input.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
input.suggest_memory_format());
|
||||
Tensor output = at::native::empty_mps(
|
||||
output_size, input.scalar_type(), c10::nullopt, kMPS, c10::nullopt, input.suggest_memory_format());
|
||||
|
||||
TORCH_CHECK(output.is_mps());
|
||||
|
||||
@ -39,8 +32,7 @@ Tensor _mps_linear(
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* weightTensor_ = nil;
|
||||
@ -55,10 +47,8 @@ Tensor _mps_linear(
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
@ -71,14 +61,11 @@ Tensor _mps_linear(
|
||||
name:nil];
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
|
||||
if (!is_bias_defined)
|
||||
{
|
||||
if (!is_bias_defined) {
|
||||
outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:weightTransposeTensor
|
||||
name:nil];
|
||||
}
|
||||
else
|
||||
{
|
||||
} else {
|
||||
MPSGraphTensor* inputFlattened = inputTensor;
|
||||
bool doReshape = false;
|
||||
// workaround to improve the performance with 3D+ inputs
|
||||
@ -94,7 +81,8 @@ Tensor _mps_linear(
|
||||
MPSGraphTensor* biasedTensor = [mpsGraph additionWithPrimaryTensor:xMulWTTensor
|
||||
secondaryTensor:newCachedGraph->biasTensor_
|
||||
name:nil];
|
||||
outputTensor = doReshape ? [mpsGraph reshapeTensor:biasedTensor withShape:getMPSShape(output_size) name:nil] : biasedTensor;
|
||||
outputTensor = doReshape ? [mpsGraph reshapeTensor:biasedTensor withShape:getMPSShape(output_size) name:nil]
|
||||
: biasedTensor;
|
||||
}
|
||||
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
@ -117,9 +105,8 @@ Tensor _mps_linear(
|
||||
biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias);
|
||||
feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData();
|
||||
}
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -134,37 +121,27 @@ Tensor _mps_linear(
return output;
}

Tensor _mps_linear_backward_input(
IntArrayRef input_size,
const Tensor & grad_output,
const Tensor & weight)
{
TORCH_CHECK(grad_output.is_mps(),
"mps_linear_backward: grad_output needs to be mps layout");
TORCH_CHECK(weight.device().is_mps() &&
(weight.scalar_type() == kFloat || (weight.scalar_type() == kHalf)),
"mps_linear_backward: unsupported weights data type: ", weight.scalar_type());
Tensor _mps_linear_backward_input(IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight) {
TORCH_CHECK(grad_output.is_mps(), "mps_linear_backward: grad_output needs to be mps layout");
TORCH_CHECK(weight.device().is_mps() && (weight.scalar_type() == kFloat || (weight.scalar_type() == kHalf)),
"mps_linear_backward: unsupported weights data type: ",
weight.scalar_type());

TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double
|| grad_output.scalar_type() == ScalarType::Float
|| grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs");
TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double || grad_output.scalar_type() == ScalarType::Float ||
grad_output.scalar_type() == ScalarType::Half,
"MPS device does not support linear backward for non-float inputs");

const Tensor weight_reshaped = weight.is_contiguous() ? weight : weight.contiguous();

struct CachedGraph : public MPSCachedGraph
{
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* weightTensor_ = nil;
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};

Tensor output = at::native::empty_mps(input_size,
grad_output.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
grad_output.suggest_memory_format());
Tensor output = at::native::empty_mps(
input_size, grad_output.scalar_type(), c10::nullopt, kMPS, c10::nullopt, grad_output.suggest_memory_format());
TORCH_CHECK(output.is_mps());
if (grad_output.numel() == 0) {
return output;
@ -175,7 +152,6 @@ Tensor _mps_linear_backward_input(
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
string key = "mps_linear_backward_input" + getTensorsStringKey({grad_output, weight_reshaped});
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
if (!cachedGraph) {
|
||||
@ -189,8 +165,7 @@ Tensor _mps_linear_backward_input(
|
||||
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_reshaped);
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
|
||||
MPSGraphTensor *outputTensor =
|
||||
[mpsGraph matrixMultiplicationWithPrimaryTensor: gradOutputTensor
|
||||
MPSGraphTensor* outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:gradOutputTensor
|
||||
secondaryTensor:weightTensor
|
||||
name:nil];
|
||||
|
||||
@ -211,9 +186,8 @@ Tensor _mps_linear_backward_input(
|
||||
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
|
||||
@ -221,17 +195,17 @@ Tensor _mps_linear_backward_input(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> _mps_linear_backward_weights(
|
||||
const Tensor& grad_output, const Tensor& input, const Tensor& weight, bool bias_defined)
|
||||
{
|
||||
std::tuple<Tensor, Tensor> _mps_linear_backward_weights(const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
bool bias_defined) {
|
||||
TORCH_CHECK(grad_output.is_mps() && input.is_mps(),
|
||||
"_mps_linear_backward: grad_output and input needs to be mps layout");
|
||||
|
||||
TORCH_CHECK(grad_output.scalar_type() == ScalarType::Float ||
|
||||
grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs");
|
||||
TORCH_CHECK(grad_output.scalar_type() == ScalarType::Float || grad_output.scalar_type() == ScalarType::Half,
|
||||
"MPS device does not support linear backward for non-float inputs");
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* weightTensor_ = nil;
|
||||
@ -240,8 +214,8 @@ std::tuple<Tensor, Tensor> _mps_linear_backward_weights(
|
||||
MPSGraphTensor* biasTensor_ = nil;
|
||||
};
|
||||
|
||||
auto grad_output_reshaped = grad_output.dim() != 2 ?
|
||||
grad_output.reshape({-1, grad_output.size(grad_output.dim() - 1)}) : grad_output;
|
||||
auto grad_output_reshaped =
|
||||
grad_output.dim() != 2 ? grad_output.reshape({-1, grad_output.size(grad_output.dim() - 1)}) : grad_output;
|
||||
auto input_reshaped = input.dim() != 2 ? input.reshape({-1, input.size(input.dim() - 1)}) : input;
|
||||
|
||||
TORCH_CHECK(grad_output_reshaped.is_mps());
|
||||
@ -272,7 +246,6 @@ std::tuple<Tensor, Tensor> _mps_linear_backward_weights(
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
string key = "mps_linear_backward_weights:" + to_string(bias_defined) + ":" +
|
||||
getTensorsStringKey({input_reshaped, weight, grad_output_reshaped});
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
@ -288,25 +261,19 @@ std::tuple<Tensor, Tensor> _mps_linear_backward_weights(
|
||||
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output_reshaped);
|
||||
|
||||
MPSGraphTensor *gradOutputTransposeTensor =
|
||||
[mpsGraph transposeTensor: gradOutputTensor
|
||||
MPSGraphTensor* gradOutputTransposeTensor = [mpsGraph transposeTensor:gradOutputTensor
|
||||
dimension:-1
|
||||
withDimension:-2
|
||||
name:nil];
|
||||
|
||||
// grad_weight
|
||||
MPSGraphTensor *outputTensor =
|
||||
[mpsGraph matrixMultiplicationWithPrimaryTensor: gradOutputTransposeTensor
|
||||
MPSGraphTensor* outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:gradOutputTransposeTensor
|
||||
secondaryTensor:inputTensor
|
||||
name:nil];
|
||||
MPSGraphTensor* biasTensor = nil;
|
||||
if (bias_defined)
|
||||
{
|
||||
if (bias_defined) {
|
||||
// grad_bias
|
||||
biasTensor = [mpsGraph reductionSumWithTensor: gradOutputTensor
|
||||
axis: 0
|
||||
name: nil];
|
||||
|
||||
biasTensor = [mpsGraph reductionSumWithTensor:gradOutputTensor axis:0 name:nil];
|
||||
}
|
||||
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
@ -342,10 +309,10 @@ std::tuple<Tensor, Tensor> _mps_linear_backward_weights(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> mps_linear_backward(
|
||||
const Tensor& input, const Tensor& grad_output,
|
||||
const Tensor& weight, std::array<bool,3> output_mask) {
|
||||
std::tuple<Tensor, Tensor, Tensor> mps_linear_backward(const Tensor& input,
|
||||
const Tensor& grad_output,
|
||||
const Tensor& weight,
|
||||
std::array<bool, 3> output_mask) {
|
||||
Tensor grad_input, grad_weight, grad_bias;
|
||||
if (output_mask[0]) {
|
||||
grad_input = _mps_linear_backward_input(input.sizes(), grad_output, weight);
@ -1,8 +1,8 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {

@ -15,19 +15,18 @@ static Tensor prepare_batch_matrix_by_transposing(const Tensor& tensor,
|
||||
bool& transpose_tensor,
|
||||
int64_t& ld_tensor,
|
||||
bool transpose_result,
|
||||
int64_t m, int64_t n) {
|
||||
int64_t m,
|
||||
int64_t n) {
|
||||
IntArrayRef tensor_strides = tensor.strides();
|
||||
Tensor tensor_;
|
||||
int fast_dim = transpose_result ? 2 : 1;
|
||||
int leading_dim = transpose_result ? 1 : 2;
|
||||
|
||||
if (tensor_strides[fast_dim] == 1 &&
|
||||
(tensor_strides[leading_dim] >= std::max<int64_t>(1, m))) {
|
||||
if (tensor_strides[fast_dim] == 1 && (tensor_strides[leading_dim] >= std::max<int64_t>(1, m))) {
|
||||
transpose_tensor = false;
|
||||
tensor_ = tensor;
|
||||
ld_tensor = tensor_strides[leading_dim];
|
||||
} else if ((tensor_strides[leading_dim] == 1) &&
|
||||
(tensor_strides[fast_dim] >= std::max<int64_t>(1, n))) {
|
||||
} else if ((tensor_strides[leading_dim] == 1) && (tensor_strides[fast_dim] >= std::max<int64_t>(1, n))) {
|
||||
transpose_tensor = true;
|
||||
tensor_ = tensor;
|
||||
ld_tensor = tensor_strides[fast_dim];
|
||||
@ -50,8 +49,7 @@ static Tensor prepare_batch_matrix_by_transposing(const Tensor& tensor,
|
||||
* Helper functions to be used for mm/addmm for detecting the Transpositions
|
||||
* when doing GEMM operations.
|
||||
*/
|
||||
void prepare_matrices_for_broadcasting(
|
||||
const Tensor * bias,
|
||||
void prepare_matrices_for_broadcasting(const Tensor* bias,
|
||||
const Tensor& self,
|
||||
const Tensor& other,
|
||||
const Scalar* beta,
|
||||
@ -79,20 +77,14 @@ void prepare_matrices_for_broadcasting(
|
||||
}
|
||||
}
|
||||
|
||||
enum LinearAlgebraOpType {
|
||||
ADDBMM_OP_TYPE,
|
||||
BADDBMM_OP_TYPE
|
||||
};
|
||||
enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE };
|
||||
|
||||
Tensor& mm_out_mps_impl(
|
||||
const Tensor& self,
|
||||
const Tensor& other,
|
||||
Tensor& output) {
|
||||
Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
|
||||
using namespace mps;
|
||||
TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
|
||||
TORCH_CHECK(self.scalar_type() == ScalarType::Double
|
||||
|| self.scalar_type() == ScalarType::Float
|
||||
|| self.scalar_type() == ScalarType::Half, "MPS device does not support mm for non-float inputs");
|
||||
TORCH_CHECK(self.scalar_type() == ScalarType::Double || self.scalar_type() == ScalarType::Float ||
|
||||
self.scalar_type() == ScalarType::Half,
|
||||
"MPS device does not support mm for non-float inputs");
|
||||
|
||||
TensorArg args[]{{output, "out", 0}, {self, "mat1", 1}, {other, "mat2", 2}};
|
||||
checkAllSameGPU("mm", args);
|
||||
@ -105,8 +97,7 @@ Tensor& mm_out_mps_impl(
|
||||
return output;
}

struct CachedGraph : public mps::MPSCachedGraph
{
struct CachedGraph : public mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* selfTensor_ = nil;
MPSGraphTensor* otherTensor_ = nil;
@ -118,12 +109,10 @@ Tensor& mm_out_mps_impl(
|
||||
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
string key = "mm_out_mps_impl" + getTensorsStringKey({self, other});
|
||||
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
|
||||
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@ -136,14 +125,11 @@ Tensor& mm_out_mps_impl(
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
|
||||
if (self.numel() == 0 || other.numel() == 0) {
|
||||
|
||||
outputTensor = [mpsGraph constantWithScalar:0.
|
||||
shape:getMPSShape(output_sizes)
|
||||
dataType:getMPSDataType(output)];
|
||||
|
||||
}
|
||||
else {
|
||||
|
||||
} else {
|
||||
selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
|
||||
outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfTensor
|
||||
@ -175,9 +161,8 @@ Tensor& mm_out_mps_impl(
|
||||
otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -185,26 +170,25 @@ Tensor& mm_out_mps_impl(
|
||||
return output;
}


Tensor addr_mps(const Tensor& self,
const Tensor& vec1, const Tensor& vec2,
const Scalar& beta, const Scalar& alpha) {
Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
Tensor result = at::empty({0}, self.options());
addr_out_mps(self, vec1, vec2, beta, alpha, result);
return result;
}


Tensor& addr_out_mps(const Tensor& self,
|
||||
const Tensor& vec1, const Tensor& vec2,
|
||||
const Scalar& beta, const Scalar& alpha, Tensor &result) {
|
||||
const Tensor& vec1,
|
||||
const Tensor& vec2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
Tensor& result) {
|
||||
using namespace mps;
|
||||
|
||||
TORCH_CHECK(result.is_mps());
|
||||
TORCH_CHECK(vec1.dim() == 1 && vec2.dim() == 1, "tensors must be 1-D");
|
||||
TORCH_CHECK(vec1.scalar_type() == ScalarType::Double
|
||||
|| vec1.scalar_type() == ScalarType::Float
|
||||
|| vec1.scalar_type() == ScalarType::Half, "MPS device does not support addr for non-float input");
|
||||
TORCH_CHECK(vec1.scalar_type() == ScalarType::Double || vec1.scalar_type() == ScalarType::Float ||
|
||||
vec1.scalar_type() == ScalarType::Half,
|
||||
"MPS device does not support addr for non-float input");
|
||||
|
||||
TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {vec1, "vec1", 2}, {vec2, "vec2", 3}};
|
||||
checkAllSameGPU(__func__, args);
|
||||
@ -242,8 +226,7 @@ Tensor& addr_out_mps(const Tensor& self,
|
||||
MPSShape* inputShape = @[ @(vec1.numel()), @(1) ];
|
||||
MPSShape* otherShape = @[ @(1), @(vec2.numel()) ];
|
||||
|
||||
struct CachedGraph : public mps::MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* vec1Tensor_ = nil;
|
||||
MPSGraphTensor* vec2Tensor_ = nil;
|
||||
@ -254,12 +237,10 @@ Tensor& addr_out_mps(const Tensor& self,
|
||||
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_})
|
||||
+ ":" + to_string(beta.toDouble())
|
||||
+ ":" + to_string(alpha.toDouble());
|
||||
string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" + to_string(beta.toDouble()) +
|
||||
":" + to_string(alpha.toDouble());
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
|
||||
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@ -321,9 +302,8 @@ Tensor& addr_out_mps(const Tensor& self,
|
||||
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -331,8 +311,7 @@ Tensor& addr_out_mps(const Tensor& self,
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor& addmm_out_mps_impl(
|
||||
const Tensor& bias,
|
||||
Tensor& addmm_out_mps_impl(const Tensor& bias,
|
||||
const Tensor& self, // input
|
||||
const Tensor& other, // weight
|
||||
const Scalar& beta,
|
||||
@ -342,9 +321,9 @@ Tensor& addmm_out_mps_impl(
|
||||
|
||||
TORCH_CHECK(output.is_mps());
|
||||
TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
|
||||
TORCH_CHECK(self.scalar_type() == ScalarType::Double
|
||||
|| self.scalar_type() == ScalarType::Float
|
||||
|| self.scalar_type() == ScalarType::Half, "MPS device does not support addmm for non-float input");
|
||||
TORCH_CHECK(self.scalar_type() == ScalarType::Double || self.scalar_type() == ScalarType::Float ||
|
||||
self.scalar_type() == ScalarType::Half,
|
||||
"MPS device does not support addmm for non-float input");
|
||||
|
||||
TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}};
|
||||
checkAllSameGPU(__func__, args);
|
||||
@ -382,10 +361,10 @@ Tensor& addmm_out_mps_impl(
|
||||
bool transpose_mat2 = false;
|
||||
bool is_beta_non_zero = beta.toDouble() != 0.0;
|
||||
|
||||
prepare_matrices_for_broadcasting(&(*bias_), self, other, &beta, &transpose_mat1_times_mat2, transpose_mat1, transpose_mat2);
|
||||
prepare_matrices_for_broadcasting(
|
||||
&(*bias_), self, other, &beta, &transpose_mat1_times_mat2, transpose_mat1, transpose_mat2);
|
||||
|
||||
struct CachedGraph : public mps::MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* selfTensor_ = nil;
|
||||
MPSGraphTensor* otherTensor_ = nil;
|
||||
@ -396,13 +375,10 @@ Tensor& addmm_out_mps_impl(
|
||||
mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_})
|
||||
+ ":" + to_string(transpose_mat1) + ":" + to_string(transpose_mat2)
|
||||
+ ":" + to_string(beta.toDouble())
|
||||
+ ":" + to_string(alpha.toDouble());
|
||||
string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + to_string(transpose_mat1) +
|
||||
":" + to_string(transpose_mat2) + ":" + to_string(beta.toDouble()) + ":" + to_string(alpha.toDouble());
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
|
||||
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@ -418,22 +394,15 @@ Tensor& addmm_out_mps_impl(
|
||||
MPSGraphTensor* t2 = nil;
|
||||
|
||||
if (transpose_mat1)
|
||||
t1 = [mpsGraph transposeTensor:selfTensor
|
||||
dimension:-1
|
||||
withDimension:-2
|
||||
name:nil];
|
||||
t1 = [mpsGraph transposeTensor:selfTensor dimension:-1 withDimension:-2 name:nil];
|
||||
else
|
||||
t1 = selfTensor;
|
||||
|
||||
if (transpose_mat2)
|
||||
t2 = [mpsGraph transposeTensor:otherTensor
|
||||
dimension:-1
|
||||
withDimension:-2
|
||||
name:nil];
|
||||
t2 = [mpsGraph transposeTensor:otherTensor dimension:-1 withDimension:-2 name:nil];
|
||||
else
|
||||
t2 = otherTensor;
|
||||
|
||||
|
||||
// TODO: Use alpha and beta here with fill_.Scalar and mul
|
||||
// Intermediate as placeholder
|
||||
MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:t1
|
||||
@ -458,10 +427,7 @@ Tensor& addmm_out_mps_impl(
|
||||
}
|
||||
|
||||
if (transpose_mat1_times_mat2)
|
||||
biasTimesBetaTensor = [mpsGraph transposeTensor: biasTimesBetaTensor
|
||||
dimension: -1
|
||||
withDimension: -2
|
||||
name: nil];
|
||||
biasTimesBetaTensor = [mpsGraph transposeTensor:biasTimesBetaTensor dimension:-1 withDimension:-2 name:nil];
|
||||
|
||||
MPSGraphTensor* outputTensor = productTimesAlphaTensor;
|
||||
if (is_beta_non_zero) {
|
||||
@ -491,9 +457,8 @@ Tensor& addmm_out_mps_impl(
|
||||
biasPlaceholder.getMPSGraphTensor() : biasPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -501,16 +466,12 @@ Tensor& addmm_out_mps_impl(
|
||||
return output;
|
||||
}
|
||||
|
||||
|
||||
Tensor& bmm_out_mps_impl(
|
||||
const Tensor & batch1,
|
||||
const Tensor & batch2,
|
||||
Tensor & result) {
|
||||
Tensor& bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tensor& result) {
|
||||
using namespace mps;
|
||||
|
||||
TORCH_CHECK(batch1.scalar_type() == ScalarType::Double
|
||||
|| batch1.scalar_type() == ScalarType::Float
|
||||
|| batch1.scalar_type() == ScalarType::Half, "MPS device does not support bmm for non-float inputs");
|
||||
TORCH_CHECK(batch1.scalar_type() == ScalarType::Double || batch1.scalar_type() == ScalarType::Float ||
|
||||
batch1.scalar_type() == ScalarType::Half,
|
||||
"MPS device does not support bmm for non-float inputs");
|
||||
|
||||
if (batch1.numel() == 0 || batch2.numel() == 0) {
|
||||
result.zero_();
|
||||
@ -519,8 +480,7 @@ Tensor& bmm_out_mps_impl(
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
struct CachedGraph : public mps::MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* batch1Tensor_ = nil;
|
||||
MPSGraphTensor* batch2Tensor_ = nil;
|
||||
@ -534,7 +494,6 @@ Tensor& bmm_out_mps_impl(
|
||||
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
|
||||
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@ -566,9 +525,8 @@ Tensor& bmm_out_mps_impl(
|
||||
batch2Placeholder.getMPSGraphTensor() : batch2Placeholder.getMPSGraphTensorData(),
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -576,8 +534,7 @@ Tensor& bmm_out_mps_impl(
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor& addbmm_or_baddbmm_out_mps_impl(
|
||||
const Tensor & input,
|
||||
Tensor& addbmm_or_baddbmm_out_mps_impl(const Tensor& input,
|
||||
const Tensor& batch1,
|
||||
const Tensor& batch2,
|
||||
const Scalar& beta,
|
||||
@ -591,22 +548,29 @@ Tensor& addbmm_or_baddbmm_out_mps_impl(
|
||||
TORCH_CHECK(batch2.is_mps());
|
||||
TORCH_CHECK(result.is_mps());
|
||||
|
||||
TORCH_CHECK(batch1.scalar_type() == ScalarType::Double
|
||||
|| batch1.scalar_type() == ScalarType::Float
|
||||
|| batch1.scalar_type() == ScalarType::Half, "MPS device does not support addbmm or baddbmm for non-float inputs");
|
||||
TORCH_CHECK(batch1.scalar_type() == ScalarType::Double || batch1.scalar_type() == ScalarType::Float ||
|
||||
batch1.scalar_type() == ScalarType::Half,
|
||||
"MPS device does not support addbmm or baddbmm for non-float inputs");
|
||||
|
||||
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
|
||||
TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
|
||||
TORCH_CHECK(batch1.size(0) == batch2.size(0),
|
||||
"batch1 and batch2 must have same number of batches, got ",
|
||||
batch1.size(0), " and ", batch2.size(0));
|
||||
batch1.size(0),
|
||||
" and ",
|
||||
batch2.size(0));
|
||||
TORCH_CHECK(batch1.size(2) == batch2.size(1),
|
||||
"Incompatible matrix sizes for bmm (",
|
||||
batch1.size(1), "x", batch1.size(2), " and ",
|
||||
batch2.size(1), "x", batch2.size(2), ")");
|
||||
batch1.size(1),
|
||||
"x",
|
||||
batch1.size(2),
|
||||
" and ",
|
||||
batch2.size(1),
|
||||
"x",
|
||||
batch2.size(2),
|
||||
")");
|
||||
|
||||
if (opType == ADDBMM_OP_TYPE)
|
||||
{
|
||||
if (opType == ADDBMM_OP_TYPE) {
|
||||
result.resize_as_(input);
|
||||
|
||||
const int64_t num_batches = batch1.size(0);
|
||||
@ -619,8 +583,7 @@ Tensor& addbmm_or_baddbmm_out_mps_impl(
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
struct CachedGraph : public mps::MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* batch1Tensor_ = nil;
|
||||
@ -632,13 +595,11 @@ Tensor& addbmm_or_baddbmm_out_mps_impl(
|
||||
|
||||
@autoreleasepool {
|
||||
string key = (opType == ADDBMM_OP_TYPE) ? ("addbmm_out_mps_impl") : ("baddbmm_out_mps_impl");
|
||||
key += getTensorsStringKey({batch1, batch2, input})
|
||||
+ ":" + to_string(beta.toDouble())
|
||||
+ ":" + to_string(alpha.toDouble());
|
||||
key += getTensorsStringKey({batch1, batch2, input}) + ":" + to_string(beta.toDouble()) + ":" +
|
||||
to_string(alpha.toDouble());
|
||||
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
|
||||
mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@ -668,7 +629,8 @@ Tensor& addbmm_or_baddbmm_out_mps_impl(
|
||||
}
|
||||
|
||||
// Intermediates for multiplying by beta and alpha
|
||||
MPSGraphTensor* reductionSumTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor: reductionSumTensor
|
||||
MPSGraphTensor* reductionSumTimesAlphaTensor =
|
||||
[mpsGraph multiplicationWithPrimaryTensor:reductionSumTensor
|
||||
secondaryTensor:alphaTensor
|
||||
name:@"alpha*(batch1@batch2)"];
|
||||
MPSGraphTensor* biasTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor
|
||||
@ -699,9 +661,8 @@ Tensor& addbmm_or_baddbmm_out_mps_impl(
|
||||
batch2Placeholder.getMPSGraphTensor() : batch2Placeholder.getMPSGraphTensorData(),
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -713,7 +674,13 @@ TORCH_IMPL_FUNC(mm_out_mps)(const Tensor& self, const Tensor& mat2, const Tensor
mm_out_mps_impl(self, mat2, const_cast<Tensor&>(result));
}

TORCH_IMPL_FUNC(addmm_out_mps)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
TORCH_IMPL_FUNC(addmm_out_mps)
(const Tensor& self,
const Tensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
const Tensor& result) {
addmm_out_mps_impl(self, mat1, mat2, beta, alpha, const_cast<Tensor&>(result));
}

@ -721,18 +688,33 @@ TORCH_IMPL_FUNC(bmm_out_mps) (const Tensor & batch1, const Tensor & batch2, cons
|
||||
bmm_out_mps_impl(batch1, batch2, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(baddbmm_out_mps) (const Tensor & self, const Tensor & batch1, const Tensor & batch2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
|
||||
TORCH_IMPL_FUNC(baddbmm_out_mps)
|
||||
(const Tensor& self,
|
||||
const Tensor& batch1,
|
||||
const Tensor& batch2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
const Tensor& result) {
|
||||
addbmm_or_baddbmm_out_mps_impl(self, batch1, batch2, beta, alpha, const_cast<Tensor&>(result), BADDBMM_OP_TYPE);
|
||||
}
|
||||
|
||||
Tensor& addbmm_out_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, Tensor& result) {
|
||||
Tensor& addbmm_out_mps(const Tensor& self,
|
||||
const Tensor& batch1,
|
||||
const Tensor& batch2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
Tensor& result) {
|
||||
auto b_self = expand_size(self, {batch1.size(1), batch2.size(2)}, "addbmm_out");
|
||||
|
||||
addbmm_or_baddbmm_out_mps_impl(*b_self, batch1, batch2, beta, alpha, result, ADDBMM_OP_TYPE);
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor addbmm_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
|
||||
Tensor addbmm_mps(const Tensor& self,
|
||||
const Tensor& batch1,
|
||||
const Tensor& batch2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha) {
|
||||
Tensor result = at::empty({0}, self.options());
|
||||
return addbmm_out_mps(self, batch1, batch2, beta, alpha, result);
|
||||
}
|
||||
@ -741,7 +723,13 @@ Tensor &addbmm_mps_(Tensor& self, const Tensor& batch1, const Tensor& batch2, co
|
||||
return addbmm_out_mps(self, batch1, batch2, beta, alpha, self);
|
||||
}
|
||||
|
||||
Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool upper, bool transpose, bool left, bool unitriangular, Tensor& out) {
|
||||
Tensor& linalg_solve_triangular_mps_impl(const Tensor& A,
|
||||
const Tensor& B,
|
||||
bool upper,
|
||||
bool transpose,
|
||||
bool left,
|
||||
bool unitriangular,
|
||||
Tensor& out) {
|
||||
using namespace mps;
|
||||
|
||||
checkInputsSolver(A, B, left, "linalg.solve_triangular");
|
||||
@ -794,7 +782,8 @@ Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool
|
||||
rowBytes:aCols * aElemSize
|
||||
matrixBytes:aRows * aCols * aElemSize
|
||||
dataType:getMPSDataType(A_)];
|
||||
MPSMatrixDescriptor* rightHandSideMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:bRows
|
||||
MPSMatrixDescriptor* rightHandSideMatrixDesc =
|
||||
[MPSMatrixDescriptor matrixDescriptorWithRows:bRows
|
||||
columns:bCols
|
||||
matrices:batchSize
|
||||
rowBytes:bCols * bElemSize
|
||||
@ -806,7 +795,8 @@ Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool
|
||||
MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer
|
||||
offset:(A_t.storage_offset() + aBatchOffset) * aElemSize
|
||||
descriptor:sourceMatrixDesc] autorelease];
|
||||
MPSMatrix* rightHandSideMatrix = [[[MPSMatrix alloc] initWithBuffer:bBuffer
|
||||
MPSMatrix* rightHandSideMatrix =
|
||||
[[[MPSMatrix alloc] initWithBuffer:bBuffer
|
||||
offset:(B_t.storage_offset() + bBatchOffset) * bElemSize
|
||||
descriptor:rightHandSideMatrixDesc] autorelease];
|
||||
MPSMatrix* solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:outBuffer
|
||||
@ -824,7 +814,12 @@ Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor& linalg_solve_triangular_mps_out( const Tensor& A, const Tensor& B, bool upper, bool left, bool unitriangular, Tensor& out) {
|
||||
Tensor& linalg_solve_triangular_mps_out(const Tensor& A,
|
||||
const Tensor& B,
|
||||
bool upper,
|
||||
bool left,
|
||||
bool unitriangular,
|
||||
Tensor& out) {
|
||||
return linalg_solve_triangular_mps_impl(A, B, upper, /*transpose=*/false, left, unitriangular, out);
|
||||
}
|
||||
|
||||
@ -834,7 +829,14 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper,
|
||||
return out;
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(triangular_solve_mps_out)(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular, const Tensor& result, const Tensor& clone_A) {
|
||||
TORCH_IMPL_FUNC(triangular_solve_mps_out)
|
||||
(const Tensor& self,
|
||||
const Tensor& A,
|
||||
bool upper,
|
||||
bool transpose,
|
||||
bool unitriangular,
|
||||
const Tensor& result,
|
||||
const Tensor& clone_A) {
|
||||
clone_A.copy_(A);
|
||||
Tensor out = empty_mps({0}, A.scalar_type(), c10::nullopt, kMPS, c10::nullopt, MemoryFormat::Contiguous);
|
||||
linalg_solve_triangular_mps_impl(A, self, upper, transpose, /*left=*/true, unitriangular, out);
|
||||
File diff suppressed because it is too large
@ -2,12 +2,12 @@

#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/native/layer_norm.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>

namespace at::native {
@ -16,14 +16,14 @@ void get_shapes(MPSShape* input_shape_readonly,
|
||||
NSMutableArray<NSNumber*>*& input_shape,
|
||||
NSMutableArray<NSNumber*>*& new_mean_shape,
|
||||
NSMutableArray<NSNumber*>*& axes,
|
||||
int num_input_dims, c10::MemoryFormat memory_format,
|
||||
int num_input_dims,
|
||||
c10::MemoryFormat memory_format,
|
||||
bool isBackward) {
|
||||
// Modify the shape
|
||||
if (memory_format == at::MemoryFormat::Contiguous) {
|
||||
for (int i = 0; i < num_input_dims; i++)
|
||||
input_shape[i] = input_shape_readonly[i];
|
||||
}
|
||||
else { // ChannelsLast
|
||||
} else { // ChannelsLast
|
||||
auto num_channels = input_shape_readonly[1];
|
||||
input_shape[0] = input_shape_readonly[0];
|
||||
for (int i = 1; i < num_input_dims - 1; i++)
|
||||
@ -37,8 +37,7 @@ void get_shapes(MPSShape* input_shape_readonly,
|
||||
new_mean_shape[1] = input_shape_readonly[1];
|
||||
for (int i = 2; i < num_input_dims; i++)
|
||||
new_mean_shape[i] = @1;
|
||||
}
|
||||
else if(memory_format == at::MemoryFormat::ChannelsLast) {
|
||||
} else if (memory_format == at::MemoryFormat::ChannelsLast) {
|
||||
for (int i = 0; i < num_input_dims - 1; i++)
|
||||
new_mean_shape[i] = @1;
|
||||
new_mean_shape[num_input_dims - 1] = input_shape[num_input_dims - 1];
|
||||
@ -49,28 +48,26 @@ void get_shapes(MPSShape* input_shape_readonly,
|
||||
axes[0] = @0;
|
||||
for (int i = 2; i < num_input_dims; i++)
|
||||
axes[i - 1] = [NSNumber numberWithInt:i];
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
for (int i = 0; i < num_input_dims - 1; i++)
|
||||
axes[i] = [NSNumber numberWithInt:i];
|
||||
}
|
||||
}
|
||||
// Inverse standard deviation now becomes variance (without epsilon)
std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out
(const Tensor& self,
std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out(const Tensor& self,
const c10::optional<Tensor>& weight_opt,
const c10::optional<Tensor>& bias_opt,
const c10::optional<Tensor>& running_mean_opt,
const c10::optional<Tensor>& running_var_opt,
bool train, double momentum, double epsilon,
bool train,
double momentum,
double epsilon,
Tensor& output,
Tensor& save_mean,
Tensor& save_var) {

namespace native_mps = at::native::mps;
struct CachedGraph : public native_mps::MPSCachedGraph
{
struct CachedGraph : public native_mps::MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* weightTensor_ = nil;
|
||||
@ -98,11 +95,11 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out
|
||||
auto memory_format = self.suggest_memory_format();
|
||||
|
||||
if (output.numel() == 0) {
|
||||
return std::tuple<Tensor&, Tensor&, Tensor&>(output, save_mean, save_var);;
|
||||
return std::tuple<Tensor&, Tensor&, Tensor&>(output, save_mean, save_var);
|
||||
;
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
string mem_format_key;
|
||||
switch (memory_format) {
|
||||
case at::MemoryFormat::Contiguous:
|
||||
@ -130,13 +127,14 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out
|
||||
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
|
||||
string key = "batch_norm_mps_out:" + mem_format_key + ":" + std::to_string(epsilon) + ":"
|
||||
+ std::to_string(momentum) + ":" + std::to_string(train) + ":"
|
||||
+ std::to_string(has_running_mean) + ":"
|
||||
+ std::to_string(has_weight) + ":" + std::to_string(has_bias) + ":"
|
||||
+ [ns_shape_key UTF8String] + ":"
|
||||
+ native_mps::getTensorsStringKey({
|
||||
self, weight_opt.value_or(Tensor()), bias_opt.value_or(Tensor()), running_mean_opt.value_or(Tensor()), running_var_opt.value_or(Tensor())});
|
||||
string key = "batch_norm_mps_out:" + mem_format_key + ":" + std::to_string(epsilon) + ":" +
|
||||
std::to_string(momentum) + ":" + std::to_string(train) + ":" + std::to_string(has_running_mean) + ":" +
|
||||
std::to_string(has_weight) + ":" + std::to_string(has_bias) + ":" + [ns_shape_key UTF8String] + ":" +
|
||||
native_mps::getTensorsStringKey({self,
|
||||
weight_opt.value_or(Tensor()),
|
||||
bias_opt.value_or(Tensor()),
|
||||
running_mean_opt.value_or(Tensor()),
|
||||
running_var_opt.value_or(Tensor())});
|
||||
auto input_mps_dtype = native_mps::getMPSDataType(self);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
|
||||
@ -155,7 +153,6 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out
|
||||
|
||||
if (!cachedGraph) {
|
||||
native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() {
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
@ -166,15 +163,19 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out
|
||||
MPSGraphTensor* weightTensor = nil;
|
||||
// Should have shape of mean
|
||||
if (has_weight)
|
||||
weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(weight_opt.value()), new_mean_shape);
|
||||
weightTensor = native_mps::mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, native_mps::getMPSDataType(weight_opt.value()), new_mean_shape);
|
||||
MPSGraphTensor* biasTensor = nil;
|
||||
if (has_bias)
|
||||
biasTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(bias_opt.value()), new_mean_shape);
|
||||
biasTensor = native_mps::mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, native_mps::getMPSDataType(bias_opt.value()), new_mean_shape);
|
||||
MPSGraphTensor* runningMeanTensor = nil;
|
||||
MPSGraphTensor* runningVarTensor = nil;
|
||||
if (has_running_mean) {
|
||||
runningMeanTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(running_mean_opt.value()), new_mean_shape);
|
||||
runningVarTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(running_var_opt.value()), new_mean_shape);
|
||||
runningMeanTensor = native_mps::mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, native_mps::getMPSDataType(running_mean_opt.value()), new_mean_shape);
|
||||
runningVarTensor = native_mps::mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, native_mps::getMPSDataType(running_var_opt.value()), new_mean_shape);
|
||||
}
|
||||
|
||||
// Mean and inv std tensors to be saved and returned
|
||||
@ -207,12 +208,8 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out
|
||||
|
||||
if (train) {
|
||||
// Compute mean and variance of the current batch
|
||||
MPSGraphTensor* batchMeanTensor = [mpsGraph meanOfTensor:inputTensor
|
||||
axes:axes
|
||||
name:nil];
|
||||
MPSGraphTensor* batchVarianceTensor = [mpsGraph varianceOfTensor:inputTensor
|
||||
axes:axes
|
||||
name:nil];
|
||||
MPSGraphTensor* batchMeanTensor = [mpsGraph meanOfTensor:inputTensor axes:axes name:nil];
|
||||
MPSGraphTensor* batchVarianceTensor = [mpsGraph varianceOfTensor:inputTensor axes:axes name:nil];
|
||||
varTensor = batchVarianceTensor;
|
||||
if (has_running_mean) {
|
||||
// TODO: This is not the formula used in PyTorch, is this OK? Seems more robust
|
||||
@ -260,20 +257,15 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out
|
||||
secondaryTensor:epsilonTensor
|
||||
name:@"varianceEps"];
|
||||
|
||||
MPSGraphTensor *sqrtVariance = [mpsGraph squareRootWithTensor:varianceEps
|
||||
name:@"sqrtVariance"];
|
||||
scaledInverseSqrtVariance = [mpsGraph reciprocalWithTensor:sqrtVariance
|
||||
name:nil];
|
||||
MPSGraphTensor* sqrtVariance = [mpsGraph squareRootWithTensor:varianceEps name:@"sqrtVariance"];
|
||||
scaledInverseSqrtVariance = [mpsGraph reciprocalWithTensor:sqrtVariance name:nil];
|
||||
// Update saved mean and inverse std tensor
|
||||
saveMeanTensor = batchMeanTensor;
|
||||
saveVarTensor = scaledInverseSqrtVariance;
|
||||
}
|
||||
else { // Test
|
||||
} else { // Test
|
||||
TORCH_CHECK(has_running_mean);
|
||||
saveMeanTensor = [mpsGraph identityWithTensor:runningMeanTensor
|
||||
name:nil];
|
||||
saveVarTensor = [mpsGraph identityWithTensor:runningVarTensor
|
||||
name:nil];
|
||||
saveMeanTensor = [mpsGraph identityWithTensor:runningMeanTensor name:nil];
|
||||
saveVarTensor = [mpsGraph identityWithTensor:runningVarTensor name:nil];
|
||||
varTensor = saveVarTensor;
|
||||
}
|
||||
|
||||
@ -287,12 +279,8 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out
|
||||
name:nil];
|
||||
|
||||
// Reshape saved mean and var to fit output
|
||||
saveMeanTensor = [mpsGraph reshapeTensor:saveMeanTensor
|
||||
withShape:@[new_mean_shape[channelsDim]]
|
||||
name:nil];
|
||||
saveVarTensor = [mpsGraph reshapeTensor:saveVarTensor
|
||||
withShape:@[new_mean_shape[channelsDim]]
|
||||
name:nil];
|
||||
saveMeanTensor = [mpsGraph reshapeTensor:saveMeanTensor withShape:@[ new_mean_shape[channelsDim] ] name:nil];
|
||||
saveVarTensor = [mpsGraph reshapeTensor:saveVarTensor withShape:@[ new_mean_shape[channelsDim] ] name:nil];
|
||||
|
||||
if (train && has_running_mean) {
|
||||
// Running stats inplace update
|
||||
@ -330,16 +318,20 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out
|
||||
auto runningMeanPlaceholder = native_mps::Placeholder();
|
||||
auto runningVarPlaceholder = native_mps::Placeholder();
|
||||
if (has_running_mean) {
|
||||
runningMeanPlaceholder = native_mps::Placeholder(cachedGraph->runningMeanTensor_, running_mean_opt.value(), new_mean_shape);
|
||||
runningVarPlaceholder = native_mps::Placeholder(cachedGraph->runningVarTensor_, running_var_opt.value(), new_mean_shape);
|
||||
runningMeanPlaceholder =
|
||||
native_mps::Placeholder(cachedGraph->runningMeanTensor_, running_mean_opt.value(), new_mean_shape);
|
||||
runningVarPlaceholder =
|
||||
native_mps::Placeholder(cachedGraph->runningVarTensor_, running_var_opt.value(), new_mean_shape);
|
||||
}
|
||||
|
||||
auto runningMeanInplaceUpdatePlaceholder = native_mps::Placeholder();
|
||||
auto runningVarInplaceUpdatePlaceholder = native_mps::Placeholder();
|
||||
|
||||
if (train && has_running_mean) {
|
||||
runningMeanInplaceUpdatePlaceholder = native_mps::Placeholder(cachedGraph->runningMeanInplaceUpdate_, running_mean_opt.value());
|
||||
runningVarInplaceUpdatePlaceholder = native_mps::Placeholder(cachedGraph->runningVarInplaceUpdate_, running_var_opt.value());
|
||||
runningMeanInplaceUpdatePlaceholder =
|
||||
native_mps::Placeholder(cachedGraph->runningMeanInplaceUpdate_, running_mean_opt.value());
|
||||
runningVarInplaceUpdatePlaceholder =
|
||||
native_mps::Placeholder(cachedGraph->runningVarInplaceUpdate_, running_var_opt.value());
|
||||
}
|
||||
|
||||
auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output, input_shape, false);
|
||||
@ -364,12 +356,13 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out
|
||||
|
||||
// If train and has_running_mean, add updated running mean to the output
|
||||
if (train && has_running_mean) {
|
||||
results[runningMeanInplaceUpdatePlaceholder.getMPSGraphTensor()] = runningMeanInplaceUpdatePlaceholder.getMPSGraphTensorData();
|
||||
results[runningVarInplaceUpdatePlaceholder.getMPSGraphTensor()] = runningVarInplaceUpdatePlaceholder.getMPSGraphTensorData();
|
||||
results[runningMeanInplaceUpdatePlaceholder.getMPSGraphTensor()] =
|
||||
runningMeanInplaceUpdatePlaceholder.getMPSGraphTensorData();
|
||||
results[runningVarInplaceUpdatePlaceholder.getMPSGraphTensor()] =
|
||||
runningVarInplaceUpdatePlaceholder.getMPSGraphTensorData();
|
||||
}
|
||||
|
||||
native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
|
||||
}
|
||||
|
||||
if (!train) {
|
||||
@ -379,8 +372,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out
|
||||
return std::tuple<Tensor&, Tensor&, Tensor&>(output, save_mean, save_var);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> batch_norm_mps
|
||||
(const Tensor& self,
|
||||
std::tuple<Tensor, Tensor, Tensor> batch_norm_mps(const Tensor& self,
|
||||
const c10::optional<Tensor>& weight_opt,
|
||||
const c10::optional<Tensor>& bias_opt,
|
||||
const c10::optional<Tensor>& running_mean_opt,
|
||||
@ -388,21 +380,14 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_mps
|
||||
bool train,
|
||||
double momentum,
|
||||
double epsilon) {
|
||||
|
||||
const auto memory_format = self.suggest_memory_format();
|
||||
|
||||
auto output = at::native::empty_mps(
|
||||
self.sizes(),
|
||||
self.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
memory_format);
|
||||
auto output =
|
||||
at::native::empty_mps(self.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, memory_format);
|
||||
|
||||
int64_t n_input = self.size(1);
|
||||
|
||||
auto save_mean = at::native::empty_mps(
|
||||
{n_input},
|
||||
auto save_mean = at::native::empty_mps({n_input},
|
||||
self.scalar_type(),
|
||||
// TODO: Accumulate type?
|
||||
// at::toAccumulateType(self.scalar_type(), /*is_cuda=*/false),
|
||||
@ -410,8 +395,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_mps
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
c10::nullopt);
|
||||
auto save_var = at::native::empty_mps(
|
||||
{n_input},
|
||||
auto save_var = at::native::empty_mps({n_input},
|
||||
self.scalar_type(),
|
||||
// TODO: Accumulate type?
|
||||
// at::toAccumulateType(self.scalar_type(), /*is_cuda=*/false),
|
||||
@ -420,8 +404,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_mps
|
||||
c10::nullopt,
|
||||
c10::nullopt);
|
||||
|
||||
at::native::batch_norm_mps_out(
|
||||
self,
|
||||
at::native::batch_norm_mps_out(self,
|
||||
weight_opt,
|
||||
bias_opt,
|
||||
running_mean_opt,
|
||||
@ -435,8 +418,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_mps
|
||||
return std::make_tuple(output, save_mean, save_var);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_mps
|
||||
(const Tensor& self,
|
||||
std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_mps(const Tensor& self,
|
||||
const c10::optional<Tensor>& weight_opt,
|
||||
const c10::optional<Tensor>& bias_opt,
|
||||
Tensor& running_mean,
|
||||
@ -444,43 +426,44 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_mps
|
||||
bool train,
|
||||
double momentum,
|
||||
double epsilon) {
|
||||
|
||||
return batch_norm_mps(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_no_stats_mps
|
||||
(const Tensor& self,
|
||||
std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_no_stats_mps(const Tensor& self,
|
||||
const c10::optional<Tensor>& weight_opt,
|
||||
const c10::optional<Tensor>& bias_opt,
|
||||
bool train,
|
||||
double momentum,
|
||||
double epsilon) {
|
||||
|
||||
return batch_norm_mps(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon);
|
||||
}
|
||||
|
||||
std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_mps_out
|
||||
(const Tensor& self,
|
||||
std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_mps_out(const Tensor& self,
|
||||
const c10::optional<Tensor>& weight_opt,
|
||||
const c10::optional<Tensor>& bias_opt,
|
||||
Tensor& running_mean,
|
||||
Tensor& running_var,
|
||||
bool train, double momentum, double epsilon,
|
||||
bool train,
|
||||
double momentum,
|
||||
double epsilon,
|
||||
Tensor& output,
|
||||
Tensor& save_mean,
|
||||
Tensor& save_var) {
|
||||
return batch_norm_mps_out(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon, output, save_mean, save_var);
|
||||
return batch_norm_mps_out(
|
||||
self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon, output, save_mean, save_var);
|
||||
}
|
||||
|
||||
std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_mps_out
|
||||
(const Tensor& self,
|
||||
std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_mps_out(const Tensor& self,
|
||||
const c10::optional<Tensor>& weight_opt,
|
||||
const c10::optional<Tensor>& bias_opt,
|
||||
bool train, double momentum, double epsilon,
|
||||
bool train,
|
||||
double momentum,
|
||||
double epsilon,
|
||||
Tensor& output,
|
||||
Tensor& save_mean,
|
||||
Tensor& save_var) {
|
||||
return batch_norm_mps_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_var);
|
||||
return batch_norm_mps_out(
|
||||
self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_var);
|
||||
}
|
||||
|
||||
string get_mem_string(c10::MemoryFormat memory_format) {
|
||||
@ -500,8 +483,7 @@ string get_mem_string(c10::MemoryFormat memory_format) {
}

// Batch norm backward
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps
(const Tensor& grad_out,
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps(const Tensor& grad_out,
const Tensor& input,
const c10::optional<Tensor>& weight_opt,
const c10::optional<Tensor>& running_mean_opt,
@ -511,7 +493,6 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps
|
||||
bool train,
|
||||
double epsilon,
|
||||
std::array<bool, 3> grad_input_mask) {
|
||||
|
||||
Tensor grad_input;
|
||||
Tensor grad_weight;
|
||||
Tensor grad_bias;
|
||||
@ -519,12 +500,8 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps
|
||||
const auto memory_format = input.suggest_memory_format();
|
||||
|
||||
if (grad_input_mask[0]) {
|
||||
grad_input = at::native::empty_mps(input.sizes(),
|
||||
input.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
memory_format);
|
||||
grad_input =
|
||||
at::native::empty_mps(input.sizes(), input.scalar_type(), c10::nullopt, kMPS, c10::nullopt, memory_format);
|
||||
}
|
||||
// Assuming that if grad_input_mask of weight is 1, then the weight is available
|
||||
if (grad_input_mask[1]) {
|
||||
@ -547,8 +524,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps
|
||||
namespace native_mps = at::native::mps;
|
||||
|
||||
// Derive from MPSCachedGraph
|
||||
struct CachedGraph : public native_mps::MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public native_mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* gradOutputTensor_ = nil;
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
@ -580,7 +556,6 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
string mem_format_key;
|
||||
switch (memory_format) {
|
||||
case at::MemoryFormat::Contiguous:
|
||||
@ -605,17 +580,14 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps
|
||||
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
|
||||
string key = "batch_norm_backward_mps:" + mem_format_key + ":" + std::to_string(epsilon) + ":"
|
||||
+ std::to_string(train) + ":"
|
||||
+ std::to_string(has_running_mean) + ":"
|
||||
+ std::to_string(has_weight) + ":"
|
||||
+ [ns_shape_key UTF8String] + ":" + native_mps::getMPSTypeString(input);
|
||||
string key = "batch_norm_backward_mps:" + mem_format_key + ":" + std::to_string(epsilon) + ":" +
|
||||
std::to_string(train) + ":" + std::to_string(has_running_mean) + ":" + std::to_string(has_weight) + ":" +
|
||||
[ns_shape_key UTF8String] + ":" + native_mps::getMPSTypeString(input);
|
||||
auto input_mps_dtype = native_mps::getMPSDataType(input);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
|
||||
if (!cachedGraph) {
|
||||
native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() {
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
@ -625,25 +597,32 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps
|
||||
// NCHW - Channels dim is 1
|
||||
int channelsDim = 1;
|
||||
|
||||
MPSGraphTensor* inputTensorOriginal = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_mps_dtype, input_shape);
|
||||
MPSGraphTensor* inputTensorOriginal =
|
||||
native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_mps_dtype, input_shape);
|
||||
// Shape is the ORIGINAL NCHW shape
|
||||
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(grad_out), input_shape_readonly);
|
||||
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, native_mps::getMPSDataType(grad_out), input_shape_readonly);
|
||||
MPSGraphTensor* weightTensor = nil;
|
||||
if (has_weight)
|
||||
weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(weight_opt.value()), new_mean_shape);
|
||||
weightTensor = native_mps::mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, native_mps::getMPSDataType(weight_opt.value()), new_mean_shape);
|
||||
MPSGraphTensor* runningMeanTensor = nil;
|
||||
MPSGraphTensor* runningVarTensor = nil;
|
||||
if (has_running_mean) {
|
||||
runningMeanTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(running_mean_opt.value()), new_mean_shape);
|
||||
runningVarTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(running_var_opt.value()), new_mean_shape);
|
||||
runningMeanTensor = native_mps::mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, native_mps::getMPSDataType(running_mean_opt.value()), new_mean_shape);
|
||||
runningVarTensor = native_mps::mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, native_mps::getMPSDataType(running_var_opt.value()), new_mean_shape);
|
||||
}
|
||||
|
||||
// Mean and inv std tensors to be saved and returned
|
||||
MPSGraphTensor* saveMeanTensor = nil;
|
||||
MPSGraphTensor* saveVarTensor = nil;
|
||||
if (has_save_mean) {
|
||||
saveMeanTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(save_mean_opt.value()), new_mean_shape);
|
||||
saveVarTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(save_var_opt.value()), new_mean_shape);
|
||||
saveMeanTensor = native_mps::mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, native_mps::getMPSDataType(save_mean_opt.value()), new_mean_shape);
|
||||
saveVarTensor = native_mps::mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, native_mps::getMPSDataType(save_var_opt.value()), new_mean_shape);
|
||||
}
|
||||
|
||||
MPSGraphTensor* gradInputTensor = nil;
|
||||
@ -663,21 +642,15 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps
|
||||
inputTensor = [mpsGraph reshapeTensor:inputTensorOriginal
|
||||
withShape:@[ N, ([NSNumber numberWithInt:[H intValue] * [W intValue]]), C ]
|
||||
name:nil];
|
||||
inputTensor = [mpsGraph transposeTensor:inputTensor
|
||||
dimension:1
|
||||
withDimension:2
|
||||
name:nil];
|
||||
inputTensor = [mpsGraph reshapeTensor:inputTensor
|
||||
withShape:@[N, C, H, W]
|
||||
name:nil];
|
||||
inputTensor = [mpsGraph transposeTensor:inputTensor dimension:1 withDimension:2 name:nil];
|
||||
inputTensor = [mpsGraph reshapeTensor:inputTensor withShape:@[ N, C, H, W ] name:nil];
|
||||
}
|
||||
|
||||
if (train) {
// Use save_mean and save_var
MPSGraphTensor* epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon dataType:input_mps_dtype];
MPSGraphTensor* revertSaveVarTensor = saveVarTensor;
revertSaveVarTensor = [mpsGraph reciprocalWithTensor: revertSaveVarTensor
name: nil];
revertSaveVarTensor = [mpsGraph reciprocalWithTensor:revertSaveVarTensor name:nil];
revertSaveVarTensor = [mpsGraph multiplicationWithPrimaryTensor:revertSaveVarTensor
secondaryTensor:revertSaveVarTensor
name:nil];
@ -711,32 +684,26 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps
epsilon:(float)epsilon
name:nil];
}
}
else {
} else {
// Use running mean and running var
MPSGraphTensor* rsqrtTensor = nil;
MPSGraphTensor* epsilonTensor = nil;
if (grad_input_mask[1]) {
epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon
shape:@[@1]
dataType:input_mps_dtype];
epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon shape:@[ @1 ] dataType:input_mps_dtype];
MPSGraphTensor* xMinusMean = [mpsGraph subtractionWithPrimaryTensor:inputTensor
secondaryTensor:runningMeanTensor
name:nil];
MPSGraphTensor* varianceEpsTensor = [mpsGraph additionWithPrimaryTensor:runningVarTensor
secondaryTensor:epsilonTensor
name:nil];
rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor
name:nil];
rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor name:nil];
MPSGraphTensor* bnForwardTensor = [mpsGraph multiplicationWithPrimaryTensor:xMinusMean
secondaryTensor:rsqrtTensor
name:nil];
MPSGraphTensor* gradBnMulTensor = [mpsGraph multiplicationWithPrimaryTensor:bnForwardTensor
secondaryTensor:gradOutputTensor
name:nil];
gradWeightTensor = [mpsGraph reductionSumWithTensor:gradBnMulTensor
axes:axes
name:nil];
gradWeightTensor = [mpsGraph reductionSumWithTensor:gradBnMulTensor axes:axes name:nil];
}
if (grad_input_mask[2]) {
gradBiasTensor = [mpsGraph normalizationBetaGradientWithIncomingGradientTensor:gradOutputTensor
@ -745,20 +712,16 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps
name:nil];
}
if (grad_input_mask[0]) {

MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0
shape:input_shape_readonly
dataType:input_mps_dtype];
if (!epsilonTensor)
epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon
shape:@[@1]
dataType:input_mps_dtype];
epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon shape:@[ @1 ] dataType:input_mps_dtype];
if (!rsqrtTensor) {
MPSGraphTensor* varianceEpsTensor = [mpsGraph additionWithPrimaryTensor:runningVarTensor
secondaryTensor:epsilonTensor
name:nil];
rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor
name:nil];
rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor name:nil];
}

gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:unitTensor
@ -796,16 +759,12 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps
|
||||
auto W = input_shape[2];
|
||||
auto C = input_shape[3];
|
||||
|
||||
gradInputTensorFinal = [mpsGraph reshapeTensor:gradInputTensor
|
||||
gradInputTensorFinal =
|
||||
[mpsGraph reshapeTensor:gradInputTensor
|
||||
withShape:@[ N, C, ([NSNumber numberWithInt:[H intValue] * [W intValue]]) ]
|
||||
name:nil];
|
||||
gradInputTensorFinal = [mpsGraph transposeTensor:gradInputTensorFinal
|
||||
dimension:1
|
||||
withDimension:2
|
||||
name:nil];
|
||||
gradInputTensorFinal = [mpsGraph reshapeTensor:gradInputTensorFinal
|
||||
withShape:@[N, H, W, C]
|
||||
name:nil];
|
||||
gradInputTensorFinal = [mpsGraph transposeTensor:gradInputTensorFinal dimension:1 withDimension:2 name:nil];
|
||||
gradInputTensorFinal = [mpsGraph reshapeTensor:gradInputTensorFinal withShape:@[ N, H, W, C ] name:nil];
|
||||
}
|
||||
|
||||
newCachedGraph->gradOutputTensor_ = gradOutputTensor;
|
||||
@ -825,20 +784,24 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps
|
||||
}
|
||||
|
||||
auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input, input_shape);
|
||||
auto gradOutputPlaceholder = native_mps::Placeholder(cachedGraph->gradOutputTensor_, grad_out, input_shape_readonly);
|
||||
auto gradOutputPlaceholder =
|
||||
native_mps::Placeholder(cachedGraph->gradOutputTensor_, grad_out, input_shape_readonly);
|
||||
auto weightPlaceholder = native_mps::Placeholder();
|
||||
if (has_weight)
|
||||
weightPlaceholder = native_mps::Placeholder(cachedGraph->weightTensor_, weight_opt.value(), new_mean_shape);
|
||||
auto runningMeanPlaceholder = native_mps::Placeholder();
|
||||
auto runningVarPlaceholder = native_mps::Placeholder();
|
||||
if (has_running_mean) {
|
||||
runningMeanPlaceholder = native_mps::Placeholder(cachedGraph->runningMeanTensor_, running_mean_opt.value(), new_mean_shape);
|
||||
runningVarPlaceholder = native_mps::Placeholder(cachedGraph->runningVarTensor_, running_var_opt.value(), new_mean_shape);
|
||||
runningMeanPlaceholder =
|
||||
native_mps::Placeholder(cachedGraph->runningMeanTensor_, running_mean_opt.value(), new_mean_shape);
|
||||
runningVarPlaceholder =
|
||||
native_mps::Placeholder(cachedGraph->runningVarTensor_, running_var_opt.value(), new_mean_shape);
|
||||
}
|
||||
auto saveMeanPlaceholder = native_mps::Placeholder();
|
||||
auto saveVarPlaceholder = native_mps::Placeholder();
|
||||
if (has_save_mean) {
|
||||
saveMeanPlaceholder = native_mps::Placeholder(cachedGraph->saveMeanTensor_, save_mean_opt.value(), new_mean_shape);
|
||||
saveMeanPlaceholder =
|
||||
native_mps::Placeholder(cachedGraph->saveMeanTensor_, save_mean_opt.value(), new_mean_shape);
|
||||
saveVarPlaceholder = native_mps::Placeholder(cachedGraph->saveVarTensor_, save_var_opt.value(), new_mean_shape);
|
||||
}
|
||||
|
||||
@ -848,7 +811,8 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps
|
||||
auto gradWeightPlaceholder = native_mps::Placeholder();
|
||||
if (grad_input_mask[1])
|
||||
gradWeightPlaceholder = native_mps::Placeholder(cachedGraph->gradWeightTensor_, grad_weight);
|
||||
auto gradBiasPlaceholder = native_mps::Placeholder();;
|
||||
auto gradBiasPlaceholder = native_mps::Placeholder();
|
||||
;
|
||||
if (grad_input_mask[2])
|
||||
gradBiasPlaceholder = native_mps::Placeholder(cachedGraph->gradBiasTensor_, grad_bias);
|
||||
|
||||
@ -875,21 +839,17 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps
|
||||
results[gradBiasPlaceholder.getMPSGraphTensor()] = gradBiasPlaceholder.getMPSGraphTensorData();
|
||||
|
||||
native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
|
||||
}
|
||||
|
||||
return std::make_tuple(grad_input, grad_weight, grad_bias);
|
||||
|
||||
}
|
||||
|
||||
// Layer norm forward for MPS
|
||||
std::tuple<Tensor, Tensor, Tensor> layer_norm_mps(
|
||||
const Tensor& input,
|
||||
std::tuple<Tensor, Tensor, Tensor> layer_norm_mps(const Tensor& input,
|
||||
IntArrayRef normalized_shape,
|
||||
const c10::optional<Tensor>& weight_opt,
|
||||
const c10::optional<Tensor>& bias_opt,
|
||||
double eps) {
|
||||
|
||||
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
|
||||
const Tensor& weight = *weight_maybe_owned;
|
||||
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
|
||||
@ -910,9 +870,14 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_mps(
|
||||
// entire channel/plane with the affine option, Layer Normalization applies
|
||||
// per-element scale and bias. E.g. For input {N, C, H, W}, weight for
|
||||
// batchnorm has shape {C} while weight for layernorm has shape {H, W} or {W}.
|
||||
auto outputs = at::native_batch_norm(
|
||||
input_reshaped, /*weight=*/{}, /*bias=*/{}, /*running_mean=*/{},
|
||||
/*running_var=*/{}, /*training=*/true, /*momentum=*/0, eps);
|
||||
auto outputs = at::native_batch_norm(input_reshaped,
|
||||
/*weight=*/{},
|
||||
/*bias=*/{},
|
||||
/*running_mean=*/{},
|
||||
/*running_var=*/{},
|
||||
/*training=*/true,
|
||||
/*momentum=*/0,
|
||||
eps);
|
||||
at::Tensor out = std::get<0>(outputs);
|
||||
out = out.view(input_shape);
|
||||
if (weight.defined() && bias.defined()) {
|
||||
@ -938,8 +903,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_mps(
|
||||
return std::make_tuple(out, mean, variance);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
const Tensor& grad_out,
|
||||
std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(const Tensor& grad_out,
|
||||
const Tensor& input,
|
||||
IntArrayRef normalized_shape,
|
||||
const Tensor& mean,
|
||||
@ -947,12 +911,9 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
const c10::optional<Tensor>& weight_opt /* optional */,
|
||||
const c10::optional<Tensor>& bias_opt /* optional */,
|
||||
std::array<bool, 3> grad_input_mask) {
|
||||
|
||||
c10::MaybeOwned<Tensor> weight_maybe_owned =
|
||||
at::borrow_from_optional_tensor(weight_opt);
|
||||
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
|
||||
const Tensor& weight = *weight_maybe_owned;
|
||||
c10::MaybeOwned<Tensor> bias_maybe_owned =
|
||||
at::borrow_from_optional_tensor(bias_opt);
|
||||
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
|
||||
const Tensor& bias = *bias_maybe_owned;
|
||||
|
||||
auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
|
||||
@ -967,8 +928,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
Tensor grad_weight;
|
||||
Tensor grad_bias;
|
||||
if (grad_input_mask[0]) {
|
||||
grad_input = at::native::empty_like(
|
||||
*X,
|
||||
grad_input = at::native::empty_like(*X,
|
||||
c10::nullopt /* dtype */,
|
||||
c10::nullopt /* layout */,
|
||||
kMPS /* device */,
|
||||
@ -976,15 +936,13 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
at::MemoryFormat::Contiguous);
|
||||
}
|
||||
if (grad_input_mask[1]) {
|
||||
grad_weight = M > 0 ? at::native::empty_like(
|
||||
*gamma,
|
||||
grad_weight = M > 0 ? at::native::empty_like(*gamma,
|
||||
c10::nullopt /* dtype */,
|
||||
c10::nullopt /* layout */,
|
||||
kMPS /* device */,
|
||||
c10::nullopt /* pin_memory */,
|
||||
at::MemoryFormat::Contiguous)
|
||||
: at::native::zeros_like(
|
||||
*gamma,
|
||||
: at::native::zeros_like(*gamma,
|
||||
c10::nullopt /* dtype */,
|
||||
c10::nullopt /* layout */,
|
||||
kMPS /* device */,
|
||||
@ -992,15 +950,13 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
at::MemoryFormat::Contiguous);
|
||||
}
|
||||
if (grad_input_mask[2]) {
|
||||
grad_bias = M > 0 ? at::native::empty_like(
|
||||
*beta,
|
||||
grad_bias = M > 0 ? at::native::empty_like(*beta,
|
||||
c10::nullopt /* dtype */,
|
||||
c10::nullopt /* layout */,
|
||||
kMPS /* device */,
|
||||
c10::nullopt /* pin_memory */,
|
||||
at::MemoryFormat::Contiguous)
|
||||
: at::native::zeros_like(
|
||||
*beta,
|
||||
: at::native::zeros_like(*beta,
|
||||
c10::nullopt /* dtype */,
|
||||
c10::nullopt /* layout */,
|
||||
kMPS /* device */,
|
||||
@ -1008,12 +964,10 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
at::MemoryFormat::Contiguous);
|
||||
}
|
||||
if (M > 0) {
|
||||
|
||||
namespace native_mps = at::native::mps;
|
||||
|
||||
// Derive from MPSCachedGraph
|
||||
struct CachedGraph : public native_mps::MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public native_mps::MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* gradOutputTensor_ = nil;
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
@ -1038,7 +992,6 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
// const auto memory_format = input.suggest_memory_format();
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
MPSShape* input_shape = mps::getMPSShape(*X);
|
||||
MPSShape* gamma_shape = mps::getMPSShape(normalized_shape);
|
||||
|
||||
@ -1066,7 +1019,8 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
|
||||
// Shape of mean to do "batch norm" backward
|
||||
// This is [1, M, [1,1,1..1]]
|
||||
NSMutableArray<NSNumber*>* bn_mean_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:(num_normalized_dims+2)];
|
||||
NSMutableArray<NSNumber*>* bn_mean_shape =
|
||||
[NSMutableArray<NSNumber*> arrayWithCapacity:(num_normalized_dims + 2)];
|
||||
bn_mean_shape[0] = [NSNumber numberWithInt:1];
|
||||
bn_mean_shape[1] = [NSNumber numberWithInt:M];
|
||||
for (int i = 0; i < num_normalized_dims; i++)
|
||||
@ -1074,23 +1028,21 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
|
||||
// Shape of gamma to multiply with "batch norm" backward
|
||||
// This is [1, 1, -1]
|
||||
NSMutableArray<NSNumber*>* bn_gamma_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:(num_normalized_dims+2)];
|
||||
NSMutableArray<NSNumber*>* bn_gamma_shape =
|
||||
[NSMutableArray<NSNumber*> arrayWithCapacity:(num_normalized_dims + 2)];
|
||||
bn_gamma_shape[0] = [NSNumber numberWithInt:1];
|
||||
bn_gamma_shape[1] = [NSNumber numberWithInt:1];
|
||||
for (int i = 0; i < num_normalized_dims; i++)
|
||||
bn_gamma_shape[i + 2] = input_shape[i + num_channel_dims];
|
||||
|
||||
string key = "layer_norm_backward_mps:"
|
||||
+ std::to_string(has_weight) + ":"
|
||||
+ native_mps::getArrayRefString(normalized_shape) + ":"
|
||||
+ native_mps::getArrayRefString((*X).sizes()) + ":"
|
||||
+ native_mps::getMPSTypeString(*X);
|
||||
string key = "layer_norm_backward_mps:" + std::to_string(has_weight) + ":" +
|
||||
native_mps::getArrayRefString(normalized_shape) + ":" + native_mps::getArrayRefString((*X).sizes()) + ":" +
|
||||
native_mps::getMPSTypeString(*X);
|
||||
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
|
||||
if (!cachedGraph) {
|
||||
native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() {
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
@ -1121,46 +1073,31 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
MPSGraphTensor* gradBnMulTensor = [mpsGraph multiplicationWithPrimaryTensor:bnForwardTensor
|
||||
secondaryTensor:gradOutputTensor
|
||||
name:nil];
|
||||
gradWeightTensor = [mpsGraph reductionSumWithTensor:gradBnMulTensor
|
||||
axes:gamma_axes
|
||||
name:nil];
|
||||
gradWeightTensor = [mpsGraph reductionSumWithTensor:gradBnMulTensor axes:gamma_axes name:nil];
|
||||
}
|
||||
if (grad_input_mask[2]) {
|
||||
gradBiasTensor = [mpsGraph reductionSumWithTensor:gradOutputTensor
|
||||
axes:gamma_axes
|
||||
name:nil];
|
||||
gradBiasTensor = [mpsGraph reductionSumWithTensor:gradOutputTensor axes:gamma_axes name:nil];
|
||||
}
|
||||
if (grad_input_mask[0]) {
|
||||
|
||||
// Reshape input to [1, M, -1]
|
||||
// Reshape mean and rstd to [1, M, -1]
|
||||
// Reshape gamma to [1, 1, -1] (-1 has N dims)
|
||||
|
||||
MPSGraphTensor* bnInputTensor = [mpsGraph reshapeTensor:inputTensor
|
||||
withShape:bn_shape
|
||||
name:nil];
|
||||
MPSGraphTensor* bnInputTensor = [mpsGraph reshapeTensor:inputTensor withShape:bn_shape name:nil];
|
||||
MPSGraphTensor* bnGradOutputTensor = [mpsGraph reshapeTensor:gradOutputTensor
|
||||
withShape:bn_shape
|
||||
name:nil];
|
||||
// Do this at the end
|
||||
if (has_weight) {
|
||||
MPSGraphTensor* bnGammaTensor = [mpsGraph reshapeTensor:weightTensor
|
||||
withShape:bn_gamma_shape
|
||||
name:nil];
|
||||
MPSGraphTensor* bnGammaTensor = [mpsGraph reshapeTensor:weightTensor withShape:bn_gamma_shape name:nil];
|
||||
bnGradOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:bnGradOutputTensor
|
||||
secondaryTensor:bnGammaTensor
|
||||
name:nil];
|
||||
}
|
||||
MPSGraphTensor* bnMeanTensor = [mpsGraph reshapeTensor:meanTensor
|
||||
withShape:bn_mean_shape
|
||||
name:nil];
|
||||
MPSGraphTensor* bnRstdTensor = [mpsGraph reshapeTensor:rstdTensor
|
||||
withShape:bn_mean_shape
|
||||
name:nil];
|
||||
MPSGraphTensor* bnMeanTensor = [mpsGraph reshapeTensor:meanTensor withShape:bn_mean_shape name:nil];
|
||||
MPSGraphTensor* bnRstdTensor = [mpsGraph reshapeTensor:rstdTensor withShape:bn_mean_shape name:nil];
|
||||
|
||||
MPSGraphTensor* mulTensor = [mpsGraph constantWithScalar:N
|
||||
shape:@[@1]
|
||||
dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* mulTensor = [mpsGraph constantWithScalar:N shape:@[ @1 ] dataType:MPSDataTypeInt32];
|
||||
|
||||
MPSGraphTensor* numberToReduceTensor = mulTensor;
|
||||
|
||||
@ -1168,8 +1105,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
toType:bnInputTensor.dataType
|
||||
name:@"cast2Tensor"];
|
||||
|
||||
MPSGraphTensor* sizeReciprocalTensor = [mpsGraph reciprocalWithTensor:cast2Tensor
|
||||
name:nil];
|
||||
MPSGraphTensor* sizeReciprocalTensor = [mpsGraph reciprocalWithTensor:cast2Tensor name:nil];
|
||||
|
||||
// TODO: Reduce redundant computation
|
||||
MPSGraphTensor* xMinusMean = [mpsGraph subtractionWithPrimaryTensor:bnInputTensor
|
||||
@ -1184,13 +1120,9 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
secondaryTensor:normalizedTensor
|
||||
name:nil];
|
||||
|
||||
MPSGraphTensor* gammaGradient = [mpsGraph reductionSumWithTensor:bnGradMulTensor
|
||||
axes:bn_axes
|
||||
name:nil];
|
||||
MPSGraphTensor* gammaGradient = [mpsGraph reductionSumWithTensor:bnGradMulTensor axes:bn_axes name:nil];
|
||||
|
||||
MPSGraphTensor* betaGradient = [mpsGraph reductionSumWithTensor:bnGradOutputTensor
|
||||
axes:bn_axes
|
||||
name:nil];
|
||||
MPSGraphTensor* betaGradient = [mpsGraph reductionSumWithTensor:bnGradOutputTensor axes:bn_axes name:nil];
|
||||
|
||||
MPSGraphTensor* gradient1 = [mpsGraph multiplicationWithPrimaryTensor:bnGradOutputTensor
|
||||
secondaryTensor:bnRstdTensor
|
||||
@ -1201,8 +1133,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
name:nil];
|
||||
|
||||
// reverseVariance is square of rstd
|
||||
MPSGraphTensor* reverseVariance = [mpsGraph squareWithTensor:bnRstdTensor
|
||||
name:nil];
|
||||
MPSGraphTensor* reverseVariance = [mpsGraph squareWithTensor:bnRstdTensor name:nil];
|
||||
MPSGraphTensor* gradient2_2 = [mpsGraph multiplicationWithPrimaryTensor:gammaGradient
|
||||
secondaryTensor:reverseVariance
|
||||
name:nil];
|
||||
@ -1227,21 +1158,14 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
secondaryTensor:gradient3
|
||||
name:nil];
|
||||
|
||||
gradInputTensor = [mpsGraph reshapeTensor:gradient
|
||||
withShape:input_shape
|
||||
name:nil];
|
||||
|
||||
gradInputTensor = [mpsGraph reshapeTensor:gradient withShape:input_shape name:nil];
|
||||
}
|
||||
|
||||
if (grad_input_mask[1]) {
|
||||
gradWeightTensor = [mpsGraph reshapeTensor:gradWeightTensor
|
||||
withShape:gamma_shape
|
||||
name:nil];
|
||||
gradWeightTensor = [mpsGraph reshapeTensor:gradWeightTensor withShape:gamma_shape name:nil];
|
||||
}
|
||||
if (grad_input_mask[2]) {
|
||||
gradBiasTensor = [mpsGraph reshapeTensor:gradBiasTensor
|
||||
withShape:gamma_shape
|
||||
name:nil];
|
||||
gradBiasTensor = [mpsGraph reshapeTensor:gradBiasTensor withShape:gamma_shape name:nil];
|
||||
}
|
||||
|
||||
newCachedGraph->gradOutputTensor_ = gradOutputTensor;
|
||||
@ -1272,7 +1196,8 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
|
||||
auto gradWeightPlaceholder = native_mps::Placeholder();
|
||||
if (grad_input_mask[1])
|
||||
gradWeightPlaceholder = native_mps::Placeholder(cachedGraph->gradWeightTensor_, grad_weight);
|
||||
auto gradBiasPlaceholder = native_mps::Placeholder();;
|
||||
auto gradBiasPlaceholder = native_mps::Placeholder();
|
||||
;
|
||||
if (grad_input_mask[2])
|
||||
gradBiasPlaceholder = native_mps::Placeholder(cachedGraph->gradBiasTensor_, grad_bias);
|
||||
|
||||
@ -1293,12 +1218,9 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
results[gradBiasPlaceholder.getMPSGraphTensor()] = gradBiasPlaceholder.getMPSGraphTensorData();

native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);

}

}
return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias));

}

} // namespace at::native

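Note: layer_norm_mps in the hunks above reduces layer norm to a training-mode batch norm over a [1, M, -1] reshape of the input, then restores the original shape. A minimal standalone sketch of that reduction, assuming libtorch is available (the helper name is hypothetical and not part of this change):

```
#include <torch/torch.h>

// Layer norm over the trailing dims, computed as a training-mode batch norm
// over a [1, M, -1] view and reshaped back, mirroring layer_norm_mps above.
torch::Tensor layer_norm_via_batch_norm(const torch::Tensor& input, int64_t M, double eps) {
  auto original_shape = input.sizes().vec();
  auto reshaped = input.reshape({1, M, -1});
  auto outs = at::native_batch_norm(reshaped,
                                    /*weight=*/{},
                                    /*bias=*/{},
                                    /*running_mean=*/{},
                                    /*running_var=*/{},
                                    /*training=*/true,
                                    /*momentum=*/0,
                                    eps);
  return std::get<0>(outs).view(original_shape);
}
```
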
@ -7,15 +7,18 @@ namespace at::native {
namespace mps {

// Pad operations (1D/2D/3D forward and backward)
Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef padding,
Tensor& pad_out_template(Tensor& output,
const Tensor& input_,
IntArrayRef padding,
const c10::optional<Tensor>& grad_output_opt,
MPSGraphPaddingMode mode, double constantValue, const string op_name)
{
MPSGraphPaddingMode mode,
double constantValue,
const string op_name) {
const int padding_size = (int)padding.size();
int padding_dim = padding_size / 2; // either 1D, 2D, or 3D

TORCH_CHECK(padding_size == 2 || padding_size == 4 || padding_size == 6,
"invalid padding argument of size ", padding_size);
TORCH_CHECK(
padding_size == 2 || padding_size == 4 || padding_size == 6, "invalid padding argument of size ", padding_size);

const Tensor& grad_output_ = *(at::borrow_from_optional_tensor(grad_output_opt));
const bool is_backward_pass = grad_output_.defined();
@ -23,8 +26,13 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
int64_t nbatch = 1;
int64_t ndims = input_.ndimension();

TORCH_CHECK(ndims >= (int64_t)padding_dim, "Length of pad should be no more than twice the number of "
"dimensions of the input. Pad length is ", padding_size, "while the input has ", ndims, "dimensions.");
TORCH_CHECK(ndims >= (int64_t)padding_dim,
"Length of pad should be no more than twice the number of "
"dimensions of the input. Pad length is ",
padding_size,
"while the input has ",
ndims,
"dimensions.");

// number of input dims with ConstantPad could be less than 2
|
||||
int dim_w = padding_dim;
|
||||
@ -36,7 +44,8 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
bool valid_dims = input_.size(1) != 0 && input_.size(padding_dim) != 0;
|
||||
TORCH_CHECK((ndims == 1 + padding_dim && valid_dims) ||
|
||||
(ndims == 2 + padding_dim && valid_dims && input_.size(1 + padding_dim) != 0),
|
||||
"3D or 4D (batch mode) tensor expected for input, but got: ", input_);
|
||||
"3D or 4D (batch mode) tensor expected for input, but got: ",
|
||||
input_);
|
||||
}
|
||||
|
||||
if (ndims == padding_dim) {
|
||||
@ -73,8 +82,15 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
|
||||
if (!is_backward_pass) {
|
||||
TORCH_CHECK(output_w >= 1 || output_h >= padding_dim - 1,
|
||||
"input (H: ", input_h, ", W: ", input_w, ") is too small. Calculated "
|
||||
"output H: ", output_h, " W: ", output_w);
|
||||
"input (H: ",
|
||||
input_h,
|
||||
", W: ",
|
||||
input_w,
|
||||
") is too small. Calculated "
|
||||
"output H: ",
|
||||
output_h,
|
||||
" W: ",
|
||||
output_w);
|
||||
|
||||
std::vector<int64_t> outputSizes;
|
||||
if (mode == MPSGraphPaddingModeConstant) {
|
||||
@ -95,20 +111,38 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
// these checks aren't relevant for constant pad
|
||||
TORCH_CHECK(pad_l < input_w && pad_r < input_w,
|
||||
"Argument #4: Padding size should be less than the corresponding "
|
||||
"input dimension, but got: padding (", pad_l, ", ", pad_r,
|
||||
") at dimension ", dim_w, " of input ", ndims);
|
||||
"input dimension, but got: padding (",
|
||||
pad_l,
|
||||
", ",
|
||||
pad_r,
|
||||
") at dimension ",
|
||||
dim_w,
|
||||
" of input ",
|
||||
ndims);
|
||||
|
||||
if (padding_dim > 1) {
|
||||
TORCH_CHECK(pad_t < input_h && pad_b < input_h,
|
||||
"Argument #6: Padding size should be less than the corresponding "
|
||||
"input dimension, but got: padding (", pad_t, ", ", pad_b,
|
||||
") at dimension ", dim_h, " of input ", ndims);
|
||||
"input dimension, but got: padding (",
|
||||
pad_t,
|
||||
", ",
|
||||
pad_b,
|
||||
") at dimension ",
|
||||
dim_h,
|
||||
" of input ",
|
||||
ndims);
|
||||
}
|
||||
if (padding_dim > 2) {
|
||||
TORCH_CHECK(pad_front < input_d && pad_back < input_d,
|
||||
"Argument #8: Padding size should be less than the corresponding "
|
||||
"input dimension, but got: padding (", pad_front, ", ", pad_back,
|
||||
") at dimension ", dim_d, " of input ", ndims);
|
||||
"input dimension, but got: padding (",
|
||||
pad_front,
|
||||
", ",
|
||||
pad_back,
|
||||
") at dimension ",
|
||||
dim_d,
|
||||
" of input ",
|
||||
ndims);
|
||||
}
|
||||
outputSizes.insert(outputSizes.begin(), output_w);
|
||||
if (padding_dim >= 2)
|
||||
@ -133,10 +167,16 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
input = input_.contiguous();
|
||||
} else {
|
||||
TORCH_CHECK(output_w == grad_output_.size(dim_w),
|
||||
"gradOutput width unexpected. Expected: ", output_w, ", Got: ", grad_output_.size(dim_w));
|
||||
"gradOutput width unexpected. Expected: ",
|
||||
output_w,
|
||||
", Got: ",
|
||||
grad_output_.size(dim_w));
|
||||
if (padding_dim > 1) {
|
||||
TORCH_CHECK(output_h == grad_output_.size(dim_h),
|
||||
"gradOutput height unexpected. Expected: ", output_h, ", Got: ", grad_output_.size(dim_h));
|
||||
"gradOutput height unexpected. Expected: ",
|
||||
output_h,
|
||||
", Got: ",
|
||||
grad_output_.size(dim_h));
|
||||
}
|
||||
output.resize_as_(input);
|
||||
if (output.numel() == 0 || grad_output_.numel() == 0)
|
||||
@ -197,8 +237,8 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" +
|
||||
getArrayRefString(padding) + "]:" + std::to_string(constantValue);
|
||||
string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" + getArrayRefString(padding) +
|
||||
"]:" + std::to_string(constantValue);
|
||||
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
if (!cachedGraph) {
|
||||
@ -219,7 +259,8 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
name:nil];
|
||||
// workaround for the right padding bug in Monterey
|
||||
if (needsSlice) {
|
||||
newCachedGraph->outputTensor = [mpsGraph sliceTensor:padTensor
|
||||
newCachedGraph->outputTensor =
|
||||
[mpsGraph sliceTensor:padTensor
|
||||
starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
|
||||
ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
|
||||
strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
|
||||
@ -232,7 +273,8 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
}
|
||||
} else {
|
||||
newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(grad_output));
|
||||
MPSGraphTensor *padGradTensor = [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor
|
||||
MPSGraphTensor* padGradTensor =
|
||||
[mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor
|
||||
sourceTensor:newCachedGraph->inputTensor
|
||||
paddingMode:mode
|
||||
leftPadding:leftPadding
|
||||
@ -240,7 +282,8 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
name:nil];
|
||||
// workaround for negative padding issue with padGradientWithIncomingGradientTensor()
|
||||
if (needsSlice) {
|
||||
newCachedGraph->outputTensor = [mpsGraph sliceGradientTensor:padGradTensor
|
||||
newCachedGraph->outputTensor =
|
||||
[mpsGraph sliceGradientTensor:padGradTensor
|
||||
fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor name:nil]
|
||||
starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
|
||||
ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
|
||||
@ -259,17 +302,17 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
}
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input, nullptr, true, dataType);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output, nullptr, true, dataType);
|
||||
Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder() :
|
||||
Placeholder(cachedGraph->gradOutputTensor, grad_output, nullptr, true, dataType);
|
||||
Placeholder gradOutputPlaceholder = !is_backward_pass
|
||||
? Placeholder()
|
||||
: Placeholder(cachedGraph->gradOutputTensor, grad_output, nullptr, true, dataType);
|
||||
|
||||
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
|
||||
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
|
||||
if (is_backward_pass) {
|
||||
feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
|
||||
}
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
return output;
|
||||
@ -278,123 +321,156 @@ Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef paddi
|
||||
|
||||
// 1D Reflection and Replication Padding
|
||||
TORCH_IMPL_FUNC(reflection_pad1d_out_mps)
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output)
|
||||
{
|
||||
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
|
||||
MPSGraphPaddingModeReflect, 0.0, "reflection_pad1d_out_mps");
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
|
||||
mps::pad_out_template(const_cast<Tensor&>(output),
|
||||
input,
|
||||
padding,
|
||||
c10::nullopt,
|
||||
MPSGraphPaddingModeReflect,
|
||||
0.0,
|
||||
"reflection_pad1d_out_mps");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_mps)
|
||||
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
|
||||
{
|
||||
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
|
||||
grad_input.resize_as_(input).zero_();
|
||||
mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
|
||||
MPSGraphPaddingModeReflect, 0.0, "reflection_pad1d_backward_out_mps");
|
||||
mps::pad_out_template(const_cast<Tensor&>(grad_input),
|
||||
input,
|
||||
padding,
|
||||
grad_output,
|
||||
MPSGraphPaddingModeReflect,
|
||||
0.0,
|
||||
"reflection_pad1d_backward_out_mps");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(replication_pad1d_out_mps)
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output)
|
||||
{
|
||||
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
|
||||
MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad1d_out_mps");
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
|
||||
mps::pad_out_template(const_cast<Tensor&>(output),
|
||||
input,
|
||||
padding,
|
||||
c10::nullopt,
|
||||
MPSGraphPaddingModeClampToEdge,
|
||||
0.0,
|
||||
"replication_pad1d_out_mps");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(replication_pad1d_backward_out_mps)
|
||||
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
|
||||
{
|
||||
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
|
||||
grad_input.resize_as_(input).zero_();
|
||||
mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
|
||||
MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad1d_backward_out_mps");
|
||||
mps::pad_out_template(const_cast<Tensor&>(grad_input),
|
||||
input,
|
||||
padding,
|
||||
grad_output,
|
||||
MPSGraphPaddingModeClampToEdge,
|
||||
0.0,
|
||||
"replication_pad1d_backward_out_mps");
|
||||
}
|
||||
|
||||
// 2D Reflection and Replication Padding
|
||||
Tensor& reflection_pad2d_out_mps(const Tensor& input, IntArrayRef padding, Tensor& output)
|
||||
{
|
||||
Tensor& reflection_pad2d_out_mps(const Tensor& input, IntArrayRef padding, Tensor& output) {
|
||||
return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__);
|
||||
}
|
||||
|
||||
Tensor reflection_pad2d_mps(const Tensor& input, IntArrayRef padding)
|
||||
{
|
||||
Tensor reflection_pad2d_mps(const Tensor& input, IntArrayRef padding) {
|
||||
Tensor output = at::empty({0}, input.options());
|
||||
return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__);
|
||||
}
|
||||
|
||||
Tensor& reflection_pad2d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
|
||||
{
|
||||
Tensor& reflection_pad2d_backward_out_mps(const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
IntArrayRef padding,
|
||||
Tensor& grad_input) {
|
||||
grad_input.resize_as_(input).zero_();
|
||||
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__);
|
||||
}
|
||||
|
||||
Tensor reflection_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
|
||||
{
|
||||
Tensor reflection_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
|
||||
auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__);
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(replication_pad2d_out_mps)
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output)
|
||||
{
|
||||
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
|
||||
MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad2d_out_mps");
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
|
||||
mps::pad_out_template(const_cast<Tensor&>(output),
|
||||
input,
|
||||
padding,
|
||||
c10::nullopt,
|
||||
MPSGraphPaddingModeClampToEdge,
|
||||
0.0,
|
||||
"replication_pad2d_out_mps");
|
||||
}
|
||||
|
||||
Tensor& replication_pad2d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
|
||||
{
|
||||
Tensor& replication_pad2d_backward_out_mps(const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
IntArrayRef padding,
|
||||
Tensor& grad_input) {
|
||||
grad_input.resize_as_(input).zero_();
|
||||
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
|
||||
}
|
||||
|
||||
Tensor replication_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
|
||||
{
|
||||
Tensor replication_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
|
||||
auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
|
||||
}
|
||||
|
||||
// 3D Reflection and Replication Padding
|
||||
TORCH_IMPL_FUNC(reflection_pad3d_out_mps)
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output)
|
||||
{
|
||||
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
|
||||
MPSGraphPaddingModeReflect, 0.0, "reflection_pad3d_out_mps");
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
|
||||
mps::pad_out_template(const_cast<Tensor&>(output),
|
||||
input,
|
||||
padding,
|
||||
c10::nullopt,
|
||||
MPSGraphPaddingModeReflect,
|
||||
0.0,
|
||||
"reflection_pad3d_out_mps");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(reflection_pad3d_backward_out_mps)
|
||||
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
|
||||
{
|
||||
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
|
||||
grad_input.resize_as_(input).zero_();
|
||||
mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
|
||||
MPSGraphPaddingModeReflect, 0.0, "reflection_pad3d_backward_out_mps");
|
||||
mps::pad_out_template(const_cast<Tensor&>(grad_input),
|
||||
input,
|
||||
padding,
|
||||
grad_output,
|
||||
MPSGraphPaddingModeReflect,
|
||||
0.0,
|
||||
"reflection_pad3d_backward_out_mps");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(replication_pad3d_out_mps)
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output)
|
||||
{
|
||||
mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
|
||||
MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad3d_out_mps");
|
||||
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
|
||||
mps::pad_out_template(const_cast<Tensor&>(output),
|
||||
input,
|
||||
padding,
|
||||
c10::nullopt,
|
||||
MPSGraphPaddingModeClampToEdge,
|
||||
0.0,
|
||||
"replication_pad3d_out_mps");
|
||||
}
|
||||
|
||||
Tensor& replication_pad3d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
|
||||
{
|
||||
Tensor& replication_pad3d_backward_out_mps(const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
IntArrayRef padding,
|
||||
Tensor& grad_input) {
|
||||
grad_input.resize_as_(input).zero_();
|
||||
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
|
||||
}
|
||||
|
||||
Tensor replication_pad3d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
{
Tensor replication_pad3d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}

// backward pass is explicitly handled in autograd by negating the "pad" argument
Tensor constant_pad_nd_mps(const Tensor& self, IntArrayRef pad, const Scalar& value)
{
Tensor constant_pad_nd_mps(const Tensor& self, IntArrayRef pad, const Scalar& value) {
if (pad.size() > 6) {
TORCH_WARN_ONCE("MPS: The constant padding of more than 3 dimensions is not currently supported natively. ",
"It uses View Ops default implementation to run. This may have performance implications.");
return at::native::constant_pad_nd(self, pad, value);
}
Tensor output = at::empty({0}, self.options());
return mps::pad_out_template(output, self, pad, c10::nullopt, MPSGraphPaddingModeConstant, value.toDouble(), __func__);
return mps::pad_out_template(
output, self, pad, c10::nullopt, MPSGraphPaddingModeConstant, value.toDouble(), __func__);
}

} // namespace at::native

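For orientation, every 1D/2D/3D reflection, replication, and constant pad entry point above funnels into pad_out_template with only a padding mode and a constant value. A rough sketch of that dispatch under hypothetical names (the enum and helper below are illustrative, not the MPSGraph API):

```
#include <string>

// reflection_* -> Reflect, replication_* -> ClampToEdge, constant_pad_nd -> Constant(value),
// mirroring how the wrappers above call pad_out_template.
enum class PadMode { Reflect, ClampToEdge, Constant };

struct PadRequest {
  PadMode mode;
  double constant_value; // only meaningful for PadMode::Constant
};

PadRequest classify_pad_op(const std::string& op_name, double value = 0.0) {
  if (op_name.rfind("reflection", 0) == 0) {
    return {PadMode::Reflect, 0.0};
  }
  if (op_name.rfind("replication", 0) == 0) {
    return {PadMode::ClampToEdge, 0.0};
  }
  return {PadMode::Constant, value}; // constant_pad_nd_mps passes value.toDouble()
}
```
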
@ -12,8 +12,7 @@ void addc_mul_div_out_mps(const Tensor& self,
const Scalar& value_opt, // default value = 1.0
const Tensor& output,
const bool is_div,
const string op_name)
{
const string op_name) {
if (value_opt.toDouble() == 0.0) {
output.copy_(self);
return;
@ -25,8 +24,7 @@ void addc_mul_div_out_mps(const Tensor& self,

MPSStream* mpsStream = getCurrentMPSStream();

struct CachedGraph : public MPSCachedGraph
{
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
MPSGraphTensor *firstTensor = nil, *secondTensor = nil, *valueTensor = nil;
@ -40,9 +38,9 @@ void addc_mul_div_out_mps(const Tensor& self,
|
||||
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), c10::promoteTypes(tensor1.scalar_type(), tensor2.scalar_type()));
|
||||
ScalarType common_dtype =
|
||||
c10::promoteTypes(self.scalar_type(), c10::promoteTypes(tensor1.scalar_type(), tensor2.scalar_type()));
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
@ -50,26 +48,27 @@ void addc_mul_div_out_mps(const Tensor& self,
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
newCachedGraph->firstTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor1);
|
||||
newCachedGraph->secondTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor2);
|
||||
newCachedGraph->valueTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), @[@1]);
|
||||
newCachedGraph->valueTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), @[ @1 ]);
|
||||
|
||||
// the tensor to be optionally multiplied by value_scalar
|
||||
MPSGraphTensor* multiplicandTensor = nil;
|
||||
auto firstTensor = castMPSTensor(mpsGraph, newCachedGraph->firstTensor, common_dtype);
|
||||
auto secondTensor = castMPSTensor(mpsGraph, newCachedGraph->secondTensor, common_dtype);
|
||||
if (is_div) {
|
||||
multiplicandTensor = [mpsGraph divisionWithPrimaryTensor:firstTensor
|
||||
secondaryTensor:secondTensor
|
||||
name:nil];
|
||||
multiplicandTensor = [mpsGraph divisionWithPrimaryTensor:firstTensor secondaryTensor:secondTensor name:nil];
|
||||
} else {
|
||||
multiplicandTensor = [mpsGraph multiplicationWithPrimaryTensor:firstTensor
|
||||
secondaryTensor:secondTensor
|
||||
name:nil];
|
||||
}
|
||||
// the tensor to be added to input_tensor
|
||||
MPSGraphTensor *addendTensor = [mpsGraph multiplicationWithPrimaryTensor:multiplicandTensor
|
||||
MPSGraphTensor* addendTensor = [mpsGraph
|
||||
multiplicationWithPrimaryTensor:multiplicandTensor
|
||||
secondaryTensor:castMPSTensor(mpsGraph, newCachedGraph->valueTensor, common_dtype)
|
||||
name:nil];
|
||||
auto outputTensor = [mpsGraph additionWithPrimaryTensor:castMPSTensor(mpsGraph, newCachedGraph->inputTensor, common_dtype)
|
||||
auto outputTensor =
|
||||
[mpsGraph additionWithPrimaryTensor:castMPSTensor(mpsGraph, newCachedGraph->inputTensor, common_dtype)
|
||||
secondaryTensor:addendTensor
|
||||
name:nil];
|
||||
newCachedGraph->outputTensor = castMPSTensor(mpsGraph, outputTensor, output.scalar_type());
|
||||
@ -93,9 +92,8 @@ void addc_mul_div_out_mps(const Tensor& self,
|
||||
cachedGraph->valueTensor : getMPSGraphTensorFromScalar(mpsStream, value_scalar),
|
||||
};
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -105,14 +103,12 @@ void addc_mul_div_out_mps(const Tensor& self,
|
||||
|
||||
// APIs exposed to at::native scope
TORCH_IMPL_FUNC(addcmul_out_mps)
(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output)
{
(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output) {
mps::addc_mul_div_out_mps(self, tensor1, tensor2, value, output, false, "addcmul_out_mps");
}

TORCH_IMPL_FUNC(addcdiv_out_mps)
(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output)
{
(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output) {
mps::addc_mul_div_out_mps(self, tensor1, tensor2, value, output, true, "addcdiv_out_mps");
}

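For reference, addc_mul_div_out_mps above builds a graph computing out = self + value * (tensor1 * tensor2) for addcmul and out = self + value * (tensor1 / tensor2) for addcdiv. A minimal eager-mode sketch of the same arithmetic, assuming libtorch (the helper name is hypothetical):

```
#include <torch/torch.h>

// Reference for the addcmul/addcdiv arithmetic that the MPS graph above encodes.
torch::Tensor addc_reference(const torch::Tensor& self,
                             const torch::Tensor& tensor1,
                             const torch::Tensor& tensor2,
                             double value,
                             bool is_div) {
  auto combined = is_div ? tensor1 / tensor2 : tensor1 * tensor2;
  return self + combined * value;
}
```
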
@ -1,13 +1,12 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {
namespace mps {

struct PoolingCachedGraph : public MPSCachedGraph
{
struct PoolingCachedGraph : public MPSCachedGraph {
PoolingCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor = nil;
MPSGraphTensor* outputTensor = nil;
@ -20,21 +19,27 @@ typedef MPSGraphTensor* (^PoolingOpBlock)(PoolingCachedGraph&, MPSGraphPooling2D
|
||||
#define PoolingOpFn(graph, desc) MPSGraphTensor*(mps::PoolingCachedGraph & graph, MPSGraphPooling2DOpDescriptor * desc)
|
||||
|
||||
// Pooling ops (1D/2D forward and backward Max and Average pooling)
|
||||
static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
static void pool2d_template(const Tensor& input,
|
||||
const Tensor& output,
|
||||
const c10::optional<Tensor>& indices_opt,
|
||||
const c10::optional<Tensor>& grad_output_opt,
|
||||
IntArrayRef kernel_size, IntArrayRef stride,
|
||||
IntArrayRef padding, IntArrayRef dilation,
|
||||
bool ceil_mode, bool count_include_pad,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode,
|
||||
bool count_include_pad,
|
||||
const c10::optional<int64_t> divisor_override,
|
||||
PoolingOpBlock poolingBlock, const c10::string& op_name)
|
||||
{
|
||||
PoolingOpBlock poolingBlock,
|
||||
const c10::string& op_name) {
|
||||
if (input.numel() == 0) {
|
||||
return;
|
||||
}
|
||||
if (!is_macos_13_or_newer()) {
|
||||
TORCH_CHECK(input.scalar_type() != ScalarType::Long,
|
||||
"MPS: ", op_name, " op with int64 input is supported natively starting from macOS 13.0.");
|
||||
"MPS: ",
|
||||
op_name,
|
||||
" op with int64 input is supported natively starting from macOS 13.0.");
|
||||
}
|
||||
const int64_t ndims = input.ndimension();
|
||||
const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt));
|
||||
@ -48,13 +53,17 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
// be incompatible with the PyTorch's global NCHW layout.
|
||||
const auto memory_format = has_indices ? MemoryFormat::Contiguous : suggested_memory_format;
|
||||
|
||||
TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2, op_name,
|
||||
TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
|
||||
op_name,
|
||||
": kernel_size must either be a single int, or a tuple of two ints")
|
||||
TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2, op_name,
|
||||
TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2,
|
||||
op_name,
|
||||
": stride must either be omitted, a single int, or a tuple of two ints")
|
||||
TORCH_CHECK(padding.size() == 1 || padding.size() == 2, op_name,
|
||||
TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
|
||||
op_name,
|
||||
": padding must be either be a single int, or a tuple of two ints");
|
||||
TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2, op_name,
|
||||
TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
|
||||
op_name,
|
||||
": dilation must be either a single int, or a tuple of two ints");
|
||||
|
||||
if (suggested_memory_format == at::MemoryFormat::ChannelsLast) {
|
||||
@ -80,8 +89,21 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
|
||||
const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);
|
||||
|
||||
pool2d_shape_check(input, kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
|
||||
pool2d_shape_check(input,
|
||||
kH,
|
||||
kW,
|
||||
dH,
|
||||
dW,
|
||||
padH,
|
||||
padW,
|
||||
dilationH,
|
||||
dilationW,
|
||||
nInputPlane,
|
||||
inputHeight,
|
||||
inputWidth,
|
||||
outputHeight,
|
||||
outputWidth,
|
||||
memory_format);
|
||||
|
||||
auto output_memory_format = output.suggest_memory_format();
|
||||
// the output and indices are 'empty', so we could avoid unnecessary gatherView on empty tensors
|
||||
@ -111,10 +133,9 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + getTensorsStringKey({input, indices, grad_output}) + ":K[" +
|
||||
getArrayRefString(kernel_size) + "]:S[" + getArrayRefString(stride) + "]:P[" +
|
||||
getArrayRefString(padding) + "]:D[" + getArrayRefString(dilation) + "]" +
|
||||
(ceil_mode ? ":ceil" : "") + (count_include_pad ? ":include_pad" : "") +
|
||||
string key = op_name + getTensorsStringKey({input, indices, grad_output}) + ":K[" + getArrayRefString(kernel_size) +
|
||||
"]:S[" + getArrayRefString(stride) + "]:P[" + getArrayRefString(padding) + "]:D[" +
|
||||
getArrayRefString(dilation) + "]" + (ceil_mode ? ":ceil" : "") + (count_include_pad ? ":include_pad" : "") +
|
||||
(has_divisor ? ":divisor" : "") + ":" +
|
||||
(suggested_memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
|
||||
|
||||
@ -130,8 +151,8 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new PoolingCachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphPooling2DOpDescriptor* desc = [MPSGraphPooling2DOpDescriptor
|
||||
descriptorWithKernelWidth: kW
|
||||
MPSGraphPooling2DOpDescriptor* desc =
|
||||
[MPSGraphPooling2DOpDescriptor descriptorWithKernelWidth:kW
|
||||
kernelHeight:kH
|
||||
strideInX:dW
|
||||
strideInY:dH
|
||||
@ -142,25 +163,27 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
paddingTop:padH
|
||||
paddingBottom:ceil_mode ? padH * dH : padH
|
||||
paddingStyle:MPSGraphPaddingStyleExplicit
|
||||
dataLayout: memory_format == MemoryFormat::ChannelsLast ?
|
||||
MPSGraphTensorNamedDataLayoutNHWC :
|
||||
MPSGraphTensorNamedDataLayoutNCHW];
|
||||
dataLayout:memory_format == MemoryFormat::ChannelsLast
|
||||
? MPSGraphTensorNamedDataLayoutNHWC
|
||||
: MPSGraphTensorNamedDataLayoutNCHW];
|
||||
desc.ceilMode = (padW == 0 && padH == 0) ? ceil_mode : false;
|
||||
if (has_indices) {
|
||||
desc.returnIndicesMode = MPSGraphPoolingReturnIndicesGlobalFlatten2D;
|
||||
desc.returnIndicesDataType = MPSDataTypeInt32;
|
||||
}
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input.scalar_type()), inputShape);
|
||||
newCachedGraph->inputTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input.scalar_type()), inputShape);
|
||||
if (is_backward_pass) {
|
||||
newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output.scalar_type()), gradOutputShape);
|
||||
newCachedGraph->gradOutputTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output.scalar_type()), gradOutputShape);
|
||||
}
|
||||
if (has_divisor) {
|
||||
newCachedGraph->divisorTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeFloat32, @[ @1 ]);
|
||||
}
|
||||
MPSGraphTensor* outputTensor = poolingBlock(*newCachedGraph, desc);
|
||||
// with desc.dataLayout = NHWC (i.e., ChannelsLast), the results need to be converted back to NCHW
|
||||
newCachedGraph->outputTensor = memory_format == MemoryFormat::ChannelsLast ?
|
||||
convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor;
|
||||
newCachedGraph->outputTensor =
|
||||
memory_format == MemoryFormat::ChannelsLast ? convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor;
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
@ -168,10 +191,12 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
// in case of ChannelsLast we don't perform gather() in placeholder to avoid implicit conversion to NCHW
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input, inputShape, memory_format != MemoryFormat::ChannelsLast);
|
||||
Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder() :
|
||||
Placeholder(cachedGraph->gradOutputTensor, grad_output,
|
||||
gradOutputShape, memory_format != MemoryFormat::ChannelsLast);
|
||||
Placeholder inputPlaceholder =
|
||||
Placeholder(cachedGraph->inputTensor, input, inputShape, memory_format != MemoryFormat::ChannelsLast);
|
||||
Placeholder gradOutputPlaceholder = !is_backward_pass
|
||||
? Placeholder()
|
||||
: Placeholder(
|
||||
cachedGraph->gradOutputTensor, grad_output, gradOutputShape, memory_format != MemoryFormat::ChannelsLast);
|
||||
Placeholder indicesPlaceholder = has_indices ? Placeholder(cachedGraph->indicesTensor, indices) : Placeholder();
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);
|
||||
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
|
||||
@ -205,14 +230,17 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
|
||||
}
|
||||
}
|
||||
|
||||
static void avg_pool2d_template(const Tensor& input, const Tensor& output,
|
||||
static void avg_pool2d_template(const Tensor& input,
|
||||
const Tensor& output,
|
||||
const c10::optional<Tensor>& grad_output_opt,
|
||||
IntArrayRef kernel_size, IntArrayRef stride,
|
||||
IntArrayRef padding, IntArrayRef dilation,
|
||||
bool ceil_mode, bool count_include_pad,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode,
|
||||
bool count_include_pad,
|
||||
const c10::optional<int64_t> divisor_override,
|
||||
const c10::string& op_name)
|
||||
{
|
||||
const c10::string& op_name) {
|
||||
const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt));
|
||||
const bool is_backward_pass = grad_output.defined();
|
||||
const bool use_divisor = divisor_override.has_value() && divisor_override.value() != 0;
|
||||
@ -226,12 +254,21 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
|
||||
"not supported on MPS backend. ",
|
||||
"Falling back on CPU. This may have performance implications.");
|
||||
if (!is_backward_pass) {
|
||||
const_cast<Tensor&>(output) = at::avg_pool2d(input.to("cpu"), kernel_size, stride, padding, ceil_mode,
|
||||
count_include_pad, divisor_override).clone().to("mps");
|
||||
const_cast<Tensor&>(output) =
|
||||
at::avg_pool2d(input.to("cpu"), kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
|
||||
.clone()
|
||||
.to("mps");
|
||||
} else {
|
||||
const_cast<Tensor&>(output) = at::avg_pool2d_backward(grad_output.to("cpu"), input.to("cpu"),
|
||||
kernel_size, stride, padding, ceil_mode, count_include_pad,
|
||||
divisor_override).clone().to("mps");
|
||||
const_cast<Tensor&>(output) = at::avg_pool2d_backward(grad_output.to("cpu"),
|
||||
input.to("cpu"),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
divisor_override)
|
||||
.clone()
|
||||
.to("mps");
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -265,9 +302,7 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
|
||||
}
|
||||
|
||||
if (!is_backward_pass) {
|
||||
MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DWithSourceTensor: paddedTensor
|
||||
descriptor: desc
|
||||
name: nil];
|
||||
MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DWithSourceTensor:paddedTensor descriptor:desc name:nil];
|
||||
if (cachedGraph.divisorTensor) {
|
||||
// workaround: custom divisor isn't supported by MPS backend, so we scale manually
|
||||
return [mpsGraph multiplicationWithPrimaryTensor:avgPoolTensor
|
||||
@ -301,43 +336,58 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
|
||||
}
|
||||
};
|
||||
|
||||
pool2d_template(input, output, c10::nullopt, grad_output_opt, kernel_size, stride,
|
||||
padding, {1, 1}, ceil_mode, count_include_pad, divisor_override,
|
||||
pooling_op_block, op_name);
|
||||
pool2d_template(input,
|
||||
output,
|
||||
c10::nullopt,
|
||||
grad_output_opt,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
{1, 1},
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
divisor_override,
|
||||
pooling_op_block,
|
||||
op_name);
|
||||
}
|
||||
|
||||
} // namespace mps
|
||||
|
||||
Tensor mps_max_pool2d(
|
||||
const Tensor& input,
|
||||
Tensor mps_max_pool2d(const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode) {
|
||||
|
||||
Tensor output = at::empty({0}, input.options(), MemoryFormat::Contiguous);
|
||||
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
|
||||
MPSGraph* mpsGraph = cachedGraph.graph();
|
||||
return [mpsGraph maxPooling2DWithSourceTensor: cachedGraph.inputTensor
|
||||
descriptor: desc
|
||||
name: nil];
|
||||
return [mpsGraph maxPooling2DWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil];
|
||||
};
|
||||
mps::pool2d_template(input, output, c10::nullopt, c10::nullopt, kernel_size, stride,
|
||||
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d");
|
||||
mps::pool2d_template(input,
|
||||
output,
|
||||
c10::nullopt,
|
||||
c10::nullopt,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
false,
|
||||
c10::nullopt,
|
||||
pooling_op_block,
|
||||
"max_pool2d");
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor mps_max_pool2d_backward(
|
||||
const Tensor& grad_output,
|
||||
Tensor mps_max_pool2d_backward(const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode) {
|
||||
|
||||
Tensor grad_input = at::empty(input.sizes(), input.options(), MemoryFormat::Contiguous);
|
||||
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
|
||||
MPSGraph* mpsGraph = cachedGraph.graph();
|
||||
@@ -346,14 +396,25 @@ Tensor mps_max_pool2d_backward(
|
||||
descriptor:desc
|
||||
name:nil];
|
||||
};
|
||||
mps::pool2d_template(input, grad_input, c10::nullopt, grad_output, kernel_size, stride,
|
||||
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_backward");
|
||||
mps::pool2d_template(input,
|
||||
grad_input,
|
||||
c10::nullopt,
|
||||
grad_output,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
false,
|
||||
c10::nullopt,
|
||||
pooling_op_block,
|
||||
"max_pool2d_backward");
|
||||
|
||||
return grad_input;
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)(
|
||||
const Tensor& input,
|
||||
TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)
|
||||
(const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
@@ -361,7 +422,6 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)(
|
||||
bool ceil_mode,
|
||||
const Tensor& output,
|
||||
const Tensor& indices) {
|
||||
|
||||
auto indices_memory_format = indices.suggest_memory_format();
|
||||
|
||||
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
|
||||
@ -372,16 +432,27 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)(
|
||||
cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long);
|
||||
return poolOutputs[0];
|
||||
};
|
||||
mps::pool2d_template(input, output, indices, c10::nullopt, kernel_size, stride,
|
||||
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_indices");
|
||||
mps::pool2d_template(input,
|
||||
output,
|
||||
indices,
|
||||
c10::nullopt,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
false,
|
||||
c10::nullopt,
|
||||
pooling_op_block,
|
||||
"max_pool2d_indices");
|
||||
|
||||
if (indices_memory_format == MemoryFormat::ChannelsLast) {
|
||||
const_cast<Tensor&>(indices) = indices.to(MemoryFormat::ChannelsLast);
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)(
|
||||
const Tensor& grad_output,
|
||||
TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)
|
||||
(const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
@@ -390,7 +461,6 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)(
|
||||
bool ceil_mode,
|
||||
const Tensor& indices,
|
||||
const Tensor& grad_input) {
|
||||
|
||||
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
|
||||
MPSGraph* mpsGraph = cachedGraph.graph();
|
||||
return [mpsGraph maxPooling2DGradientWithGradientTensor:cachedGraph.gradOutputTensor
|
||||
@@ -398,12 +468,23 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)(
|
||||
descriptor:desc
|
||||
name:nil];
|
||||
};
|
||||
mps::pool2d_template(input, grad_input, indices, grad_output, kernel_size, stride,
|
||||
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_indices_backward");
|
||||
mps::pool2d_template(input,
|
||||
grad_input,
|
||||
indices,
|
||||
grad_output,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
false,
|
||||
c10::nullopt,
|
||||
pooling_op_block,
|
||||
"max_pool2d_indices_backward");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(avg_pool2d_out_mps) (
|
||||
const Tensor& input,
|
||||
TORCH_IMPL_FUNC(avg_pool2d_out_mps)
|
||||
(const Tensor& input,
|
||||
int64_t kH,
|
||||
int64_t kW,
|
||||
int64_t dH,
|
||||
@@ -414,13 +495,21 @@ TORCH_IMPL_FUNC(avg_pool2d_out_mps) (
|
||||
bool count_include_pad,
|
||||
c10::optional<int64_t> divisor_override,
|
||||
const Tensor& output) {
|
||||
|
||||
mps::avg_pool2d_template(input, output, c10::nullopt, {kH, kW}, {dH, dW}, {padH, padW},
|
||||
{1, 1}, ceil_mode, count_include_pad, divisor_override, "avg_pool2d");
|
||||
mps::avg_pool2d_template(input,
|
||||
output,
|
||||
c10::nullopt,
|
||||
{kH, kW},
|
||||
{dH, dW},
|
||||
{padH, padW},
|
||||
{1, 1},
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
divisor_override,
|
||||
"avg_pool2d");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps) (
|
||||
const Tensor& gradOutput,
|
||||
TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps)
|
||||
(const Tensor& gradOutput,
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
@@ -429,9 +518,17 @@ TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps) (
|
||||
bool count_include_pad,
|
||||
c10::optional<int64_t> divisor_override,
|
||||
const Tensor& gradInput) {
|
||||
|
||||
mps::avg_pool2d_template(input, gradInput, gradOutput, kernel_size, stride, padding,
|
||||
{1, 1}, ceil_mode, count_include_pad, divisor_override, "avg_pool2d_backward");
|
||||
mps::avg_pool2d_template(input,
|
||||
gradInput,
|
||||
gradOutput,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
{1, 1},
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
divisor_override,
|
||||
"avg_pool2d_backward");
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
@@ -1,9 +1,9 @@
|
||||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/detail/FunctionTraits.h>
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
@@ -15,14 +15,17 @@ namespace at::native {
|
||||
namespace {
|
||||
struct RangeCachedGraph : public mps::MPSCachedGraph {
|
||||
API_AVAILABLE(macosx(12.3))
|
||||
RangeCachedGraph(MPSGraph *mpsGraph, MPSDataType dataType, int32_t shapeVal, bool needsClamp = false, bool startLessEnd = false): MPSCachedGraph(mpsGraph) {
|
||||
RangeCachedGraph(MPSGraph* mpsGraph,
|
||||
MPSDataType dataType,
|
||||
int32_t shapeVal,
|
||||
bool needsClamp = false,
|
||||
bool startLessEnd = false)
|
||||
: MPSCachedGraph(mpsGraph) {
|
||||
@autoreleasepool {
|
||||
auto shapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:&shapeVal length:sizeof(int32_t)]
|
||||
shape:@[ @1 ]
|
||||
dataType:MPSDataTypeInt32];
|
||||
auto coordsTensor = [mpsGraph coordinateAlongAxis:0
|
||||
withShapeTensor:shapeTensor
|
||||
name:nil];
|
||||
auto coordsTensor = [mpsGraph coordinateAlongAxis:0 withShapeTensor:shapeTensor name:nil];
|
||||
coordsTensor = [mpsGraph castTensor:coordsTensor toType:dataType name:@"coords"];
|
||||
|
||||
startTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[ @1 ]);
|
||||
@@ -30,9 +33,7 @@ struct RangeCachedGraph : public mps::MPSCachedGraph {
|
||||
auto scaledCoords = [mpsGraph multiplicationWithPrimaryTensor:coordsTensor
|
||||
secondaryTensor:multiplyTensor
|
||||
name:nil];
|
||||
outputTensor = [mpsGraph additionWithPrimaryTensor:scaledCoords
|
||||
secondaryTensor:startTensor
|
||||
name:nil];
|
||||
outputTensor = [mpsGraph additionWithPrimaryTensor:scaledCoords secondaryTensor:startTensor name:nil];
|
||||
if (needsClamp) {
|
||||
endTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[ @1 ]);
|
||||
outputTensor = [mpsGraph clampWithTensor:outputTensor
|
||||
@@ -59,17 +60,17 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste
|
||||
|
||||
double size_d;
|
||||
if (std::is_same<scalar_t, int64_t>::value) {
|
||||
size_d = std::ceil(static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>())
|
||||
/ step.to<accscalar_t>());
|
||||
size_d = std::ceil(static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>()) / step.to<accscalar_t>());
|
||||
} else {
|
||||
size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>())
|
||||
/ step.to<double>());
|
||||
size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>()) / step.to<double>());
|
||||
}
|
||||
|
||||
TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
|
||||
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
|
||||
std::isfinite(static_cast<double>(xend)),
|
||||
"unsupported range: ", xstart, " -> ", xend);
|
||||
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) && std::isfinite(static_cast<double>(xend)),
|
||||
"unsupported range: ",
|
||||
xstart,
|
||||
" -> ",
|
||||
xend);
|
||||
TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
|
||||
"upper bound and larger bound inconsistent with step sign");
|
||||
|
||||
@@ -80,10 +81,16 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste
|
||||
|
||||
if (numel != size) {
|
||||
if (numel > 0) {
|
||||
TORCH_WARN("The number of elements in the out tensor of shape ", result.sizes(),
|
||||
" is ", numel, " which does not match the computed number of elements ", size,
|
||||
TORCH_WARN("The number of elements in the out tensor of shape ",
|
||||
result.sizes(),
|
||||
" is ",
|
||||
numel,
|
||||
" which does not match the computed number of elements ",
|
||||
size,
|
||||
". Note that this may occur as a result of rounding error. "
|
||||
"The out tensor will be resized to a tensor of shape (", size, ",).");
|
||||
"The out tensor will be resized to a tensor of shape (",
|
||||
size,
|
||||
",).");
|
||||
}
|
||||
result.resize_({size});
|
||||
}
|
||||
@@ -115,9 +122,8 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste
|
||||
MPSScalar stepScalar = getMPSScalar(step, result.scalar_type());
|
||||
feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, stepScalar);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
|
||||
@@ -139,17 +145,17 @@ Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step
|
||||
// double size_d = ((xend - xstart) / xstep) + 1;
|
||||
double size_d;
|
||||
if (std::is_same<scalar_t, int64_t>::value) {
|
||||
size_d = static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>())
|
||||
/ step.to<accscalar_t>() + 1;
|
||||
size_d = static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>()) / step.to<accscalar_t>() + 1;
|
||||
} else {
|
||||
size_d = static_cast<double>(end.to<double>() - start.to<double>())
|
||||
/ step.to<double>() + 1;
|
||||
size_d = static_cast<double>(end.to<double>() - start.to<double>()) / step.to<double>() + 1;
|
||||
}
|
||||
|
||||
TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
|
||||
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
|
||||
std::isfinite(static_cast<double>(xend)),
|
||||
"unsupported range: ", xstart, " -> ", xend);
|
||||
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) && std::isfinite(static_cast<double>(xend)),
|
||||
"unsupported range: ",
|
||||
xstart,
|
||||
" -> ",
|
||||
xend);
|
||||
TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
|
||||
"upper bound and larger bound inconsistent with step sign");
|
||||
|
||||
@@ -186,9 +192,8 @@ Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step
|
||||
MPSScalar stepScalar = getMPSScalar(step, result.scalar_type());
|
||||
feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, stepScalar);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
|
||||
@@ -222,12 +227,12 @@ Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps,
|
||||
bool start_less_end = (start.to<double>() <= end.to<double>());
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps) + to_string(start_less_end);
|
||||
string key =
|
||||
"linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps) + to_string(start_less_end);
|
||||
RangeCachedGraph* cachedGraph = static_cast<RangeCachedGraph*>(cache_->LookUp(key));
|
||||
|
||||
if (!cachedGraph) {
|
||||
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
|
||||
RangeCachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
@@ -235,7 +240,9 @@ Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps,
|
||||
newCachedGraph = new RangeCachedGraph(mpsGraph, MPSDataTypeFloat32, steps, true, start_less_end);
|
||||
|
||||
if (getMPSDataType(result) != MPSDataTypeFloat32) {
|
||||
newCachedGraph->outputTensor = [mpsGraph castTensor:newCachedGraph->outputTensor toType:getMPSDataType(result) name:@"output"];
|
||||
newCachedGraph->outputTensor = [mpsGraph castTensor:newCachedGraph->outputTensor
|
||||
toType:getMPSDataType(result)
|
||||
name:@"output"];
|
||||
}
|
||||
}
|
||||
return newCachedGraph;
|
||||
@@ -255,9 +262,8 @@ Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps,
|
||||
MPSScalar multiplyScalar = getMPSScalar(multiply, ScalarType::Float);
|
||||
feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, multiplyScalar);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
|
||||
|
File diff suppressed because it is too large
@@ -8,8 +8,8 @@
|
||||
#include <ATen/native/LinearAlgebraUtils.h>
|
||||
#include <ATen/native/Repeat.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <torch/library.h>
|
||||
#include <fmt/format.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#ifdef __OBJC__
|
||||
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
@@ -19,8 +19,7 @@ namespace at::native {
|
||||
|
||||
Tensor permute_mps(const Tensor& self, IntArrayRef dims) {
|
||||
auto nDims = self.dim();
|
||||
TORCH_CHECK(dims.size() == (size_t)nDims,
|
||||
"number of dims don't match in permute");
|
||||
TORCH_CHECK(dims.size() == (size_t)nDims, "number of dims don't match in permute");
|
||||
auto oldSizes = self.sizes();
|
||||
auto oldStrides = self.strides();
|
||||
DimVector newSizes(nDims);
|
||||
@@ -28,8 +27,7 @@ Tensor permute_mps(const Tensor& self, IntArrayRef dims) {
|
||||
std::vector<bool> seen(nDims);
|
||||
for (const auto i : c10::irange(nDims)) {
|
||||
auto dim = maybe_wrap_dim(dims[i], nDims);
|
||||
TORCH_CHECK(!seen[dim],
|
||||
"repeated dim in permute");
|
||||
TORCH_CHECK(!seen[dim], "repeated dim in permute");
|
||||
seen[dim] = true;
|
||||
newSizes[i] = oldSizes[dim];
|
||||
newStrides[i] = oldStrides[dim];
|
||||
@@ -38,13 +36,11 @@ Tensor permute_mps(const Tensor& self, IntArrayRef dims) {
|
||||
}
|
||||
|
||||
Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
|
||||
|
||||
using namespace mps;
|
||||
|
||||
TORCH_CHECK(repeats.size() >= (size_t)self.dim(),
|
||||
"Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
@@ -96,10 +92,9 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(expanded_tensor));
|
||||
MPSGraphTensor* outputTensor = [mpsGraph tileTensor:inputTensor
|
||||
withMultiplier:getMPSShape(repeats)
|
||||
name:nil];
|
||||
MPSGraphTensor* inputTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(expanded_tensor));
|
||||
MPSGraphTensor* outputTensor = [mpsGraph tileTensor:inputTensor withMultiplier:getMPSShape(repeats) name:nil];
|
||||
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
newCachedGraph->outputTensor_ = outputTensor;
|
||||
@@ -111,15 +106,13 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
|
||||
|
||||
Placeholder selfPlaceholder = Placeholder(
|
||||
cachedGraph->inputTensor_, expanded_tensor, /*mpsShape=*/nil, /*gatherTensorData=*/true, inputDataType);
|
||||
Placeholder outputPlaceholder = Placeholder(
|
||||
cachedGraph->outputTensor_, result, /*mpsShape=*/nil, /*gatherTensorData*/false, outputDataType);
|
||||
Placeholder outputPlaceholder =
|
||||
Placeholder(cachedGraph->outputTensor_, result, /*mpsShape=*/nil, /*gatherTensorData*/ false, outputDataType);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
|
||||
@{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@@ -142,8 +135,7 @@ kernel void repeat_interleave(constant {0} * repeat_ptr [[buf
|
||||
}}
|
||||
)METAL_REPEAT";
|
||||
|
||||
static
|
||||
id<MTLLibrary> compileRepeatInterleaveLib(id<MTLDevice> device, const std::string& t1) {
|
||||
static id<MTLLibrary> compileRepeatInterleaveLib(id<MTLDevice> device, const std::string& t1) {
|
||||
auto key = t1;
|
||||
static std::unordered_map<std::string, id<MTLLibrary>> libMap;
|
||||
auto it = libMap.find(key);
|
||||
@@ -153,7 +145,8 @@ id<MTLLibrary> compileRepeatInterleaveLib(id<MTLDevice> device, const std::strin
|
||||
NSError* error = nil;
|
||||
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
|
||||
[options setLanguageVersion:MTLLanguageVersion2_3];
|
||||
auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(METAL_REPEAT_INTERLEAVE, t1).c_str()]
|
||||
auto rc =
|
||||
[device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(METAL_REPEAT_INTERLEAVE, t1).c_str()]
|
||||
options:options
|
||||
error:&error];
|
||||
TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
|
||||
@@ -161,8 +154,7 @@ id<MTLLibrary> compileRepeatInterleaveLib(id<MTLDevice> device, const std::strin
|
||||
return rc;
|
||||
}
|
||||
|
||||
static
|
||||
id<MTLComputePipelineState> getPipelineState(id<MTLDevice> device, const std::string& t1) {
|
||||
static id<MTLComputePipelineState> getPipelineState(id<MTLDevice> device, const std::string& t1) {
|
||||
static std::string kernel = "repeat_interleave";
|
||||
auto key = kernel + t1;
|
||||
static std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
|
||||
@@ -175,14 +167,14 @@ id<MTLComputePipelineState> getPipelineState(id<MTLDevice> device, const std::st
|
||||
id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
|
||||
TORCH_CHECK(func != nil, "Can't get kernel ", kernel);
|
||||
auto rc = [device newComputePipelineStateWithFunction:func error:&error];
|
||||
TORCH_CHECK(rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
|
||||
TORCH_CHECK(
|
||||
rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
|
||||
cplMap[key] = rc;
|
||||
return rc;
|
||||
}
|
||||
|
||||
template <typename index_t>
|
||||
void computeRepeatIndices(
|
||||
index_t* repeat_ptr,
|
||||
void computeRepeatIndices(index_t* repeat_ptr,
|
||||
int64_t* cumsum_ptr,
|
||||
index_t* result_ptr,
|
||||
int64_t size,
|
||||
@@ -233,12 +225,12 @@ Tensor repeat_interleave_mps(const Tensor& repeat_, c10::optional<int64_t> outpu
|
||||
if (repeat.scalar_type() == kLong && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
|
||||
// #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of output,
|
||||
// which currently doesn't support int64_t as input. Casting internally the indices to int32_t.
|
||||
TORCH_WARN_ONCE("MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3");
|
||||
TORCH_WARN_ONCE(
|
||||
"MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3");
|
||||
repeat = repeat.to(kInt);
|
||||
}
|
||||
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
|
||||
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(
|
||||
repeat, output_size);
|
||||
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
|
||||
});
|
||||
return output;
|
||||
}
|
||||
|
@@ -4,9 +4,9 @@
|
||||
#include <ATen/MemoryOverlap.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/native/RNN.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
#include <ATen/native/RNN.h>
|
||||
#include <ATen/native/TypeProperties.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#import <MetalPerformanceShadersGraph/MPSGraphRNNOps.h>
|
||||
@@ -30,31 +30,26 @@ std::vector<long long> getTensorShape(MPSGraphTensor* mpsTensor) {
|
||||
*/
|
||||
static std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*>
|
||||
getMPSTensorsFromPytorchTensors(MPSGraph* mpsGraph,
|
||||
MPSGraphTensor* stateTensor, MPSGraphTensor* cellStateTensor,
|
||||
MPSGraphTensor* stateTensor,
|
||||
MPSGraphTensor* cellStateTensor,
|
||||
NSMutableArray<MPSGraphTensor*>* recurrentKernelWeightsList,
|
||||
NSMutableArray<MPSGraphTensor*>* kernelWeightsList,
|
||||
NSMutableArray<MPSGraphTensor*>* kernelBiasList,
|
||||
NSMutableArray<MPSGraphTensor*>* recurrentBiasList,
|
||||
bool has_biases, bool bidirectional, size_t layer_no) {
|
||||
bool has_biases,
|
||||
bool bidirectional,
|
||||
size_t layer_no) {
|
||||
MPSGraphTensor* biasTensor_ = nil;
|
||||
MPSGraphTensor *stateTensor_ = nil, *cellStateTensor_ = nil;
|
||||
MPSGraphTensor *recurrentWeight_ = nil, *inputWeight_ = nil;
|
||||
|
||||
if (bidirectional) {
|
||||
stateTensor_ = [mpsGraph sliceTensor:stateTensor
|
||||
dimension:0
|
||||
start:layer_no * 2
|
||||
length:2
|
||||
name:nil];
|
||||
stateTensor_ = [mpsGraph sliceTensor:stateTensor dimension:0 start:layer_no * 2 length:2 name:nil];
|
||||
// [2, N, H] -> [N, 2, H]
|
||||
stateTensor_ = [mpsGraph transposeTensor:stateTensor_ dimension:0 withDimension:1 name:nil];
|
||||
// [N, 2, H] -> [N, 2 * H]
|
||||
stateTensor_ = [mpsGraph flatten2DTensor:stateTensor_ axis:1 name:nil];
|
||||
cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
|
||||
dimension:0
|
||||
start:layer_no * 2
|
||||
length:2
|
||||
name:nil];
|
||||
cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor dimension:0 start:layer_no * 2 length:2 name:nil];
|
||||
cellStateTensor_ = [mpsGraph transposeTensor:cellStateTensor_ dimension:0 withDimension:1 name:nil];
|
||||
cellStateTensor_ = [mpsGraph flatten2DTensor:cellStateTensor_ axis:1 name:nil];
|
||||
|
||||
@@ -62,14 +57,11 @@ static std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*, MPSGraphTen
|
||||
concatTensor:[mpsGraph expandDimsOfTensor:recurrentKernelWeightsList[layer_no * 2] axis:0 name:nil]
|
||||
withTensor:[mpsGraph expandDimsOfTensor:recurrentKernelWeightsList[layer_no * 2 + 1] axis:0 name:nil]
|
||||
dimension:0
|
||||
name: nil
|
||||
];
|
||||
inputWeight_ = [mpsGraph
|
||||
concatTensor: kernelWeightsList[layer_no * 2]
|
||||
name:nil];
|
||||
inputWeight_ = [mpsGraph concatTensor:kernelWeightsList[layer_no * 2]
|
||||
withTensor:kernelWeightsList[layer_no * 2 + 1]
|
||||
dimension:0
|
||||
name: nil
|
||||
];
|
||||
name:nil];
|
||||
if (has_biases) {
|
||||
auto biasTensorFwd_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no * 2]
|
||||
secondaryTensor:recurrentBiasList[layer_no * 2]
|
||||
@@ -81,16 +73,8 @@ static std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTen
|
||||
biasTensor_ = [mpsGraph concatTensor:biasTensorFwd_ withTensor:biasTensorBack_ dimension:0 name:nil];
|
||||
}
|
||||
} else {
|
||||
stateTensor_ = [mpsGraph sliceTensor:stateTensor
|
||||
dimension:0
|
||||
start:layer_no
|
||||
length:1
|
||||
name:nil];
|
||||
cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
|
||||
dimension:0
|
||||
start:layer_no
|
||||
length:1
|
||||
name:nil];
|
||||
stateTensor_ = [mpsGraph sliceTensor:stateTensor dimension:0 start:layer_no length:1 name:nil];
|
||||
cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor dimension:0 start:layer_no length:1 name:nil];
|
||||
recurrentWeight_ = recurrentKernelWeightsList[layer_no];
|
||||
inputWeight_ = kernelWeightsList[layer_no];
|
||||
if (has_biases) {
|
||||
@@ -102,7 +86,15 @@ static std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTen
|
||||
return std::make_tuple(stateTensor_, cellStateTensor_, recurrentWeight_, inputWeight_, biasTensor_);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input,
|
||||
TensorList hx,
|
||||
TensorList params,
|
||||
bool has_biases,
|
||||
int64_t num_layers,
|
||||
double dropout_p,
|
||||
bool train,
|
||||
bool bidirectional,
|
||||
bool batch_first) {
|
||||
using namespace mps;
|
||||
|
||||
// Projections are not currently supported, raise an error if needed
|
||||
@@ -144,28 +136,36 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tenso
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "lstm_" + getTensorsStringKey({input, hx[0], hx[1]}) + getMPSTypeString(input) + "_num_layers_" + std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) + "_has_biases_" + std::to_string(has_biases) + "_dropout_" + std::to_string(dropout_p) + "_batch_first_" + std::to_string(batch_first);
|
||||
string key = "lstm_" + getTensorsStringKey({input, hx[0], hx[1]}) + getMPSTypeString(input) + "_num_layers_" +
|
||||
std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) + "_has_biases_" +
|
||||
std::to_string(has_biases) + "_dropout_" + std::to_string(dropout_p) + "_batch_first_" +
|
||||
std::to_string(batch_first);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
NSMutableArray<MPSGraphTensor*>* kernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()];
|
||||
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()];
|
||||
NSMutableArray<MPSGraphTensor*>* recurrentKernelWeightsList =
|
||||
[[NSMutableArray alloc] initWithCapacity:params.size()];
|
||||
NSMutableArray<MPSGraphTensor*>* kernelBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
|
||||
NSMutableArray<MPSGraphTensor*>* recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
|
||||
NSMutableArray<MPSGraphTensor*>* layersOutputsList = [[NSMutableArray alloc] initWithCapacity:num_layers];
|
||||
|
||||
for (const auto i : c10::irange(total_layers)) {
|
||||
[kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(kernel_weights[i]))];
|
||||
[recurrentKernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(recurrent_kernel_weights[i]))];
|
||||
[kernelWeightsList
|
||||
addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(kernel_weights[i]))];
|
||||
[recurrentKernelWeightsList
|
||||
addObject:mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, getMPSDataType(input), getMPSShape(recurrent_kernel_weights[i]))];
|
||||
if (has_biases) {
|
||||
[kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(biases[i]))];
|
||||
[recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(recurrent_biases[i]))];
|
||||
[kernelBiasList
|
||||
addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(biases[i]))];
|
||||
[recurrentBiasList addObject:mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, getMPSDataType(input), getMPSShape(recurrent_biases[i]))];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,14 +176,16 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tenso
|
||||
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(input));
|
||||
MPSGraphTensor* stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(hx[0]));
|
||||
MPSGraphTensor* cellStateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(hx[1]));
|
||||
std::vector<MPSGraphTensor*> inputTensors = {inputTensor, stateTensor, cellStateTensor,};
|
||||
MPSGraphTensor* cellStateTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(hx[1]));
|
||||
std::vector<MPSGraphTensor*> inputTensors = {
|
||||
inputTensor,
|
||||
stateTensor,
|
||||
cellStateTensor,
|
||||
};
|
||||
|
||||
if (batch_first) {
|
||||
inputTensor = [mpsGraph transposeTensor:inputTensor
|
||||
dimension:0
|
||||
withDimension:1
|
||||
name:nil];
|
||||
inputTensor = [mpsGraph transposeTensor:inputTensor dimension:0 withDimension:1 name:nil];
|
||||
}
|
||||
|
||||
MPSGraphTensor* inputTensor_ = inputTensor;
|
||||
@@ -191,17 +193,23 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tenso
|
||||
NSMutableArray<MPSGraphTensor*>* outputStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
|
||||
NSMutableArray<MPSGraphTensor*>* outputCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
|
||||
NSMutableArray<MPSGraphTensor*>* outputZStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
|
||||
NSMutableArray<MPSGraphTensor*>* outputCellStateFwdArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
|
||||
NSMutableArray<MPSGraphTensor*>* outputCellStateFwdArray =
|
||||
[[NSMutableArray alloc] initWithCapacity:num_layers];
|
||||
for (int i = 0; i < num_layers; i++) {
|
||||
auto tensorsData = getMPSTensorsFromPytorchTensors(mpsGraph, stateTensor, cellStateTensor,
|
||||
recurrentKernelWeightsList, kernelWeightsList,
|
||||
kernelBiasList, recurrentBiasList, has_biases,
|
||||
bidirectional, i);
|
||||
auto tensorsData = getMPSTensorsFromPytorchTensors(mpsGraph,
|
||||
stateTensor,
|
||||
cellStateTensor,
|
||||
recurrentKernelWeightsList,
|
||||
kernelWeightsList,
|
||||
kernelBiasList,
|
||||
recurrentBiasList,
|
||||
has_biases,
|
||||
bidirectional,
|
||||
i);
|
||||
MPSGraphTensor *stateTensor_ = std::get<0>(tensorsData), *cellStateTensor_ = std::get<1>(tensorsData);
|
||||
MPSGraphTensor *recurrentWeight_ = std::get<2>(tensorsData), *inputWeight_ = std::get<3>(tensorsData);
|
||||
MPSGraphTensor* biasTensor_ = std::get<4>(tensorsData);
|
||||
|
||||
|
||||
outputs = [mpsGraph LSTMWithSourceTensor:inputTensor_
|
||||
recurrentWeight:recurrentWeight_
|
||||
inputWeight:inputWeight_
|
||||
@@ -215,15 +223,10 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tenso
|
||||
// no need to keep the final layer output copy as it is
|
||||
// returned anyway and not used in backprop
|
||||
if (i != num_layers - 1) {
|
||||
[layersOutputsList addObject:[mpsGraph expandDimsOfTensor:inputTensor_
|
||||
axis:0
|
||||
name:nil]];
|
||||
[layersOutputsList addObject:[mpsGraph expandDimsOfTensor:inputTensor_ axis:0 name:nil]];
|
||||
}
|
||||
if (dropout_p > 0.0 && train && (i != num_layers - 1)) {
|
||||
inputTensor_ = [mpsGraph dropoutTensor:inputTensor_
|
||||
rate:dropout_p
|
||||
name:nil];
|
||||
|
||||
inputTensor_ = [mpsGraph dropoutTensor:inputTensor_ rate:dropout_p name:nil];
|
||||
}
|
||||
|
||||
if (bidirectional) {
|
||||
@@ -231,54 +234,71 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tenso
|
||||
auto stateLastT = [mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:-1 length:1 name:nil];
|
||||
auto stateFirstT = [mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:0 length:1 name:nil];
|
||||
// [1, N, H] ([1, N, 0:H])
|
||||
auto stateForward = [mpsGraph sliceTensor:stateLastT dimension: -1 start:0 length:hx[0].sizes()[2] name:nil];
|
||||
auto stateForward = [mpsGraph sliceTensor:stateLastT
|
||||
dimension:-1
|
||||
start:0
|
||||
length:hx[0].sizes()[2]
|
||||
name:nil];
|
||||
// [1, N, H] ([1, N, H:2H])
|
||||
auto stateBack = [mpsGraph sliceTensor:stateFirstT dimension: -1 start:hx[0].sizes()[2] length:hx[0].sizes()[2] name:nil];
|
||||
auto stateBack = [mpsGraph sliceTensor:stateFirstT
|
||||
dimension:-1
|
||||
start:hx[0].sizes()[2]
|
||||
length:hx[0].sizes()[2]
|
||||
name:nil];
|
||||
[outputStateArray addObject:stateForward];
|
||||
[outputStateArray addObject:stateBack];
|
||||
|
||||
auto cellStateLastT = [mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:-1 length:1 name:nil];
|
||||
auto cellStateFirstT = [mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:0 length:1 name:nil];
|
||||
auto cellStateForward = [mpsGraph sliceTensor:cellStateLastT dimension: -1 start:0 length:hx[1].sizes()[2] name:nil];
|
||||
auto cellStateBack = [mpsGraph sliceTensor:cellStateFirstT dimension: -1 start:hx[1].sizes()[2] length:hx[1].sizes()[2] name:nil];
|
||||
auto cellStateLastT = [mpsGraph sliceTensor:[outputs objectAtIndex:1]
|
||||
dimension:0
|
||||
start:-1
|
||||
length:1
|
||||
name:nil];
|
||||
auto cellStateFirstT = [mpsGraph sliceTensor:[outputs objectAtIndex:1]
|
||||
dimension:0
|
||||
start:0
|
||||
length:1
|
||||
name:nil];
|
||||
auto cellStateForward = [mpsGraph sliceTensor:cellStateLastT
|
||||
dimension:-1
|
||||
start:0
|
||||
length:hx[1].sizes()[2]
|
||||
name:nil];
|
||||
auto cellStateBack = [mpsGraph sliceTensor:cellStateFirstT
|
||||
dimension:-1
|
||||
start:hx[1].sizes()[2]
|
||||
length:hx[1].sizes()[2]
|
||||
name:nil];
|
||||
[outputCellStateArray addObject:cellStateForward];
|
||||
[outputCellStateArray addObject:cellStateBack];
|
||||
} else {
|
||||
[outputStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:-1 length:1 name:nil]];
|
||||
[outputCellStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:-1 length:1 name:nil]];
|
||||
[outputStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:0]
|
||||
dimension:0
|
||||
start:-1
|
||||
length:1
|
||||
name:nil]];
|
||||
[outputCellStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:1]
|
||||
dimension:0
|
||||
start:-1
|
||||
length:1
|
||||
name:nil]];
|
||||
}
|
||||
[outputCellStateFwdArray addObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:1]
|
||||
axis:0
|
||||
name:nil]];
|
||||
[outputZStateArray addObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:2]
|
||||
axis:0
|
||||
name:nil]];
|
||||
[outputCellStateFwdArray addObject:[mpsGraph expandDimsOfTensor:[outputs objectAtIndex:1] axis:0 name:nil]];
|
||||
[outputZStateArray addObject:[mpsGraph expandDimsOfTensor:[outputs objectAtIndex:2] axis:0 name:nil]];
|
||||
}
|
||||
|
||||
MPSGraphTensor* outputTensor = inputTensor_;
|
||||
if (batch_first) {
|
||||
outputTensor = [mpsGraph transposeTensor:outputTensor
|
||||
dimension:0
|
||||
withDimension:1
|
||||
name:nil];
|
||||
outputTensor = [mpsGraph transposeTensor:outputTensor dimension:0 withDimension:1 name:nil];
|
||||
}
|
||||
MPSGraphTensor* outputStates = [mpsGraph concatTensors:outputStateArray
|
||||
dimension:0
|
||||
name:nil];
|
||||
MPSGraphTensor* outputCellStates = [mpsGraph concatTensors:outputCellStateArray
|
||||
dimension:0
|
||||
name:nil];
|
||||
MPSGraphTensor* outputZStates = [mpsGraph concatTensors:outputZStateArray
|
||||
dimension:0
|
||||
name:nil];
|
||||
MPSGraphTensor* outputCellStatesFwd = [mpsGraph concatTensors:outputCellStateFwdArray
|
||||
dimension:0
|
||||
name:nil];
|
||||
MPSGraphTensor* layersOutputs = (num_layers > 1)
|
||||
? [mpsGraph concatTensors:layersOutputsList dimension:0 name:nil]
|
||||
: nil;
|
||||
MPSGraphTensor* outputStates = [mpsGraph concatTensors:outputStateArray dimension:0 name:nil];
|
||||
MPSGraphTensor* outputCellStates = [mpsGraph concatTensors:outputCellStateArray dimension:0 name:nil];
|
||||
MPSGraphTensor* outputZStates = [mpsGraph concatTensors:outputZStateArray dimension:0 name:nil];
|
||||
MPSGraphTensor* outputCellStatesFwd = [mpsGraph concatTensors:outputCellStateFwdArray dimension:0 name:nil];
|
||||
MPSGraphTensor* layersOutputs =
|
||||
(num_layers > 1) ? [mpsGraph concatTensors:layersOutputsList dimension:0 name:nil] : nil;
|
||||
|
||||
std::vector<MPSGraphTensor*> outputTensors = {outputTensor, outputStates, outputCellStates, outputZStates, outputCellStatesFwd, layersOutputs};
|
||||
std::vector<MPSGraphTensor*> outputTensors = {
|
||||
outputTensor, outputStates, outputCellStates, outputZStates, outputCellStatesFwd, layersOutputs};
|
||||
newCachedGraph->inputTensors_ = inputTensors;
|
||||
newCachedGraph->outputTensors_ = outputTensors;
|
||||
newCachedGraph->kernelWeightsList_ = kernelWeightsList;
|
||||
@@ -299,7 +319,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tenso
|
||||
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [[[NSMutableDictionary alloc] init] autorelease];
|
||||
for (const auto i : c10::irange(total_layers)) {
|
||||
Placeholder kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
|
||||
Placeholder recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
|
||||
Placeholder recurrentKernelWeight =
|
||||
Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
|
||||
[feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()];
|
||||
[feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()];
|
||||
if (has_biases) {
|
||||
@@ -316,7 +337,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tenso
|
||||
[feeds setObject:selfState.getMPSGraphTensorData() forKey:selfState.getMPSGraphTensor()];
|
||||
[feeds setObject:selfCellState.getMPSGraphTensorData() forKey:selfCellState.getMPSGraphTensor()];
|
||||
|
||||
|
||||
auto dims = getTensorShape(cachedGraph->outputTensors_[0]);
|
||||
Tensor output = at::empty(IntArrayRef(dims), input.options());
|
||||
Tensor hy = at::empty_like(hx[0], input.options());
|
||||
@@ -351,7 +371,21 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tenso
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(const Tensor& grad_y, const c10::optional<Tensor>& grad_hy_opt, const c10::optional<Tensor>& grad_cy_opt, const Tensor& z_state, const Tensor& cell_state_fwd, const Tensor& input, const Tensor& layersOutputs, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
|
||||
std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(const Tensor& grad_y,
|
||||
const c10::optional<Tensor>& grad_hy_opt,
|
||||
const c10::optional<Tensor>& grad_cy_opt,
|
||||
const Tensor& z_state,
|
||||
const Tensor& cell_state_fwd,
|
||||
const Tensor& input,
|
||||
const Tensor& layersOutputs,
|
||||
TensorList hx,
|
||||
TensorList params,
|
||||
bool has_biases,
|
||||
int64_t num_layers,
|
||||
double dropout_p,
|
||||
bool train,
|
||||
bool bidirectional,
|
||||
bool batch_first) {
|
||||
using namespace mps;
|
||||
const Tensor& grad_hy_r = c10::value_or_else(grad_hy_opt, [] { return Tensor(); });
|
||||
const Tensor& grad_cy_r = c10::value_or_else(grad_cy_opt, [] { return Tensor(); });
|
||||
@@ -395,53 +429,69 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
// Get stream
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
@autoreleasepool {
|
||||
|
||||
string key = "lstm_backward_" + getTensorsStringKey({input, z_state, cell_state_fwd, grad_y, grad_cy, grad_hy})+ getMPSTypeString(input) + "_num_layers_" + std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) + "_has_biases_" + std::to_string(has_biases) + "_batch_first_" + std::to_string(batch_first);
|
||||
string key = "lstm_backward_" + getTensorsStringKey({input, z_state, cell_state_fwd, grad_y, grad_cy, grad_hy}) +
|
||||
getMPSTypeString(input) + "_num_layers_" + std::to_string(num_layers) + "_bidirectional_" +
|
||||
std::to_string(bidirectional) + "_has_biases_" + std::to_string(has_biases) + "_batch_first_" +
|
||||
std::to_string(batch_first);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
NSMutableArray<MPSGraphTensor*>* kernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()];
|
||||
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()];
|
||||
NSMutableArray<MPSGraphTensor*>* recurrentKernelWeightsList =
|
||||
[[NSMutableArray alloc] initWithCapacity:params.size()];
|
||||
NSMutableArray<MPSGraphTensor*>* kernelBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
|
||||
NSMutableArray<MPSGraphTensor*>* recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
|
||||
|
||||
for (const auto i : c10::irange(total_layers)) {
|
||||
[kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(kernel_weights[i]))];
|
||||
[recurrentKernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(recurrent_kernel_weights[i]))];
|
||||
[kernelWeightsList
|
||||
addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(kernel_weights[i]))];
|
||||
[recurrentKernelWeightsList
|
||||
addObject:mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, getMPSDataType(input), getMPSShape(recurrent_kernel_weights[i]))];
|
||||
if (has_biases) {
|
||||
[kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(biases[i]))];
|
||||
[recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(recurrent_biases[i]))];
|
||||
[kernelBiasList
|
||||
addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(biases[i]))];
|
||||
[recurrentBiasList addObject:mpsGraphRankedPlaceHolder(
|
||||
mpsGraph, getMPSDataType(input), getMPSShape(recurrent_biases[i]))];
|
||||
}
|
||||
}
|
||||
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(input));
|
||||
MPSGraphTensor* stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(hx[0]));
|
||||
MPSGraphTensor* cellStateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(hx[1]));
|
||||
MPSGraphTensor* zStateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(z_state));
|
||||
MPSGraphTensor* gradientTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_y), getMPSShape(grad_y));
|
||||
MPSGraphTensor* gradientCyTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_cy), getMPSShape(grad_cy));
|
||||
MPSGraphTensor* gradientHyTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_hy), getMPSShape(grad_hy));
|
||||
MPSGraphTensor* cellStateFwdTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(cell_state_fwd), getMPSShape(cell_state_fwd));
|
||||
MPSGraphTensor* layersOutputsTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(layersOutputs), getMPSShape(layersOutputs));
|
||||
MPSGraphTensor* cellStateTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(hx[1]));
|
||||
MPSGraphTensor* zStateTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(z_state));
|
||||
MPSGraphTensor* gradientTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_y), getMPSShape(grad_y));
|
||||
MPSGraphTensor* gradientCyTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_cy), getMPSShape(grad_cy));
|
||||
MPSGraphTensor* gradientHyTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_hy), getMPSShape(grad_hy));
|
||||
MPSGraphTensor* cellStateFwdTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(cell_state_fwd), getMPSShape(cell_state_fwd));
|
||||
MPSGraphTensor* layersOutputsTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(layersOutputs), getMPSShape(layersOutputs));
|
||||
|
||||
std::vector<MPSGraphTensor*> inputs = {inputTensor, stateTensor, cellStateTensor, gradientTensor, zStateTensor, cellStateFwdTensor, gradientHyTensor, gradientCyTensor, layersOutputsTensor};
|
||||
std::vector<MPSGraphTensor*> inputs = {inputTensor,
|
||||
stateTensor,
|
||||
cellStateTensor,
|
||||
gradientTensor,
|
||||
zStateTensor,
|
||||
cellStateFwdTensor,
|
||||
gradientHyTensor,
|
||||
gradientCyTensor,
|
||||
layersOutputsTensor};
|
||||
|
||||
if (batch_first) {
|
||||
inputTensor = [mpsGraph transposeTensor: inputTensor
|
||||
dimension: 0
|
||||
withDimension: 1
|
||||
name: nil];
|
||||
inputTensor = [mpsGraph transposeTensor:inputTensor dimension:0 withDimension:1 name:nil];
|
||||
|
||||
gradientTensor = [mpsGraph transposeTensor: gradientTensor
|
||||
dimension: 0
|
||||
withDimension: 1
|
||||
name: nil];
|
||||
gradientTensor = [mpsGraph transposeTensor:gradientTensor dimension:0 withDimension:1 name:nil];
|
||||
}
|
||||
|
||||
newCachedGraph->recurrentKernelWeightsList_ = recurrentKernelWeightsList;
|
||||
@@ -468,62 +518,43 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
auto hidden_size = hx[0].sizes()[2];
|
||||
|
||||
for (int i = num_layers - 1; i >= 0; i--) {
|
||||
MPSGraphTensor* zState = [mpsGraph sliceTensor:zStateTensor
|
||||
dimension:0
|
||||
start:i
|
||||
length:1
|
||||
name:nil];
|
||||
zState = [mpsGraph squeezeTensor:zState
|
||||
axis:0
|
||||
name:nil];
|
||||
MPSGraphTensor* zState = [mpsGraph sliceTensor:zStateTensor dimension:0 start:i length:1 name:nil];
|
||||
zState = [mpsGraph squeezeTensor:zState axis:0 name:nil];
|
||||
MPSGraphTensor* cellStateFwd = [mpsGraph sliceTensor:cellStateFwdTensor
|
||||
dimension:0
|
||||
start:i
|
||||
length:1
|
||||
name:nil];
|
||||
cellStateFwd = [mpsGraph squeezeTensor:cellStateFwd
|
||||
axis:0
|
||||
name:nil];
|
||||
auto tensorsData = getMPSTensorsFromPytorchTensors(mpsGraph, stateTensor, cellStateTensor,
|
||||
recurrentKernelWeightsList, kernelWeightsList,
|
||||
kernelBiasList, recurrentBiasList, has_biases,
|
||||
bidirectional, i);
|
||||
cellStateFwd = [mpsGraph squeezeTensor:cellStateFwd axis:0 name:nil];
|
||||
auto tensorsData = getMPSTensorsFromPytorchTensors(mpsGraph,
|
||||
stateTensor,
|
||||
cellStateTensor,
|
||||
recurrentKernelWeightsList,
|
||||
kernelWeightsList,
|
||||
kernelBiasList,
|
||||
recurrentBiasList,
|
||||
has_biases,
|
||||
bidirectional,
|
||||
i);
|
||||
MPSGraphTensor *stateTensor_ = std::get<0>(tensorsData), *cellStateTensor_ = std::get<1>(tensorsData);
|
||||
MPSGraphTensor *recurrentWeight_ = std::get<2>(tensorsData), *inputWeight_ = std::get<3>(tensorsData);
|
||||
MPSGraphTensor* biasTensor_ = std::get<4>(tensorsData);
|
||||
|
||||
MPSGraphTensor *gradientHyTensor_ = nil, *gradientCyTensor_ = nil;
|
||||
if (bidirectional) {
|
||||
gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor
|
||||
dimension:0
|
||||
start:i * 2
|
||||
length:2
|
||||
name:nil];
|
||||
gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor dimension:0 start:i * 2 length:2 name:nil];
|
||||
// [2, N, H] -> [N, 2, H]
|
||||
gradientHyTensor_ = [mpsGraph transposeTensor:gradientHyTensor_ dimension:0 withDimension:1 name:nil];
|
||||
// [N, 2, H] -> [N, 2 * H]
|
||||
gradientHyTensor_ = [mpsGraph flatten2DTensor:gradientHyTensor_ axis:1 name:nil];
|
||||
|
||||
|
||||
gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor
|
||||
dimension:0
|
||||
start:i * 2
|
||||
length:2
|
||||
name:nil];
|
||||
gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor dimension:0 start:i * 2 length:2 name:nil];
|
||||
gradientCyTensor_ = [mpsGraph transposeTensor:gradientCyTensor_ dimension:0 withDimension:1 name:nil];
|
||||
gradientCyTensor_ = [mpsGraph flatten2DTensor:gradientCyTensor_ axis:1 name:nil];
|
||||
} else {
|
||||
gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor
|
||||
dimension:0
|
||||
start:i
|
||||
length:1
|
||||
name:nil];
|
||||
gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor dimension:0 start:i length:1 name:nil];
|
||||
|
||||
gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor
|
||||
dimension:0
|
||||
start:i
|
||||
length:1
|
||||
name:nil];
|
||||
gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor dimension:0 start:i length:1 name:nil];
|
||||
}
|
||||
|
||||
MPSGraphTensor* iterationInputTensor_ = nil;
|
||||
@@ -538,9 +569,7 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
start:i - num_layers
|
||||
length:1
|
||||
name:nil];
|
||||
iterationInputTensor_ = [mpsGraph squeezeTensor:iterationInputTensor_
|
||||
axis:0
|
||||
name: nil];
|
||||
iterationInputTensor_ = [mpsGraph squeezeTensor:iterationInputTensor_ axis:0 name:nil];
|
||||
}
|
||||
|
||||
outputs = [mpsGraph LSTMGradientsWithSourceTensor:iterationInputTensor_
|
||||
@@ -599,9 +628,7 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
// has shape [1, 1, 8H] vs [8H] as should be
|
||||
// so, squeeze these two first dimensions
|
||||
auto gradBiasBidirectional = [outputs objectAtIndex:outputIter++];
|
||||
gradBiasBidirectional = [mpsGraph squeezeTensor: gradBiasBidirectional
|
||||
axes: @[@0, @1]
|
||||
name: nil];
|
||||
gradBiasBidirectional = [mpsGraph squeezeTensor:gradBiasBidirectional axes:@[ @0, @1 ] name:nil];
|
||||
auto gradBiasFwd = [mpsGraph sliceTensor:gradBiasBidirectional
|
||||
dimension:0
|
||||
start:0
|
||||
@@ -644,8 +671,10 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
length:hidden_size
|
||||
name:nil];
|
||||
|
||||
[gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:gradCellStateBack axis:0 name:nil] atIndex:0];
|
||||
[gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:gradCellStateFwd axis:0 name:nil] atIndex:0];
|
||||
[gradCellStateArray insertObject:[mpsGraph expandDimsOfTensor:gradCellStateBack axis:0 name:nil]
|
||||
atIndex:0];
|
||||
[gradCellStateArray insertObject:[mpsGraph expandDimsOfTensor:gradCellStateFwd axis:0 name:nil]
|
||||
atIndex:0];
|
||||
} else {
|
||||
int outputIter = 1;
|
||||
[gradRecWeightsArray insertObject:[outputs objectAtIndex:outputIter++] atIndex:0];
|
||||
@@ -653,8 +682,14 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
if (has_biases) {
|
||||
[gradBiasArray insertObject:[outputs objectAtIndex:outputIter++] atIndex:0];
|
||||
}
|
||||
[gradStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:outputIter++] axis:0 name:nil] atIndex:0];
|
||||
[gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:outputIter++] axis:0 name:nil] atIndex:0];
|
||||
[gradStateArray insertObject:[mpsGraph expandDimsOfTensor:[outputs objectAtIndex:outputIter++]
|
||||
axis:0
|
||||
name:nil]
|
||||
atIndex:0];
|
||||
[gradCellStateArray insertObject:[mpsGraph expandDimsOfTensor:[outputs objectAtIndex:outputIter++]
|
||||
axis:0
|
||||
name:nil]
|
||||
atIndex:0];
|
||||
}
|
||||
}
|
||||
if (batch_first) {
|
||||
@@ -696,8 +731,10 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
[feeds setObject:statePlaceholder.getMPSGraphTensorData() forKey:statePlaceholder.getMPSGraphTensor()];
|
||||
[feeds setObject:cellStatePlaceholder.getMPSGraphTensorData() forKey:cellStatePlaceholder.getMPSGraphTensor()];
|
||||
[feeds setObject:zStatePlaceholder.getMPSGraphTensorData() forKey:zStatePlaceholder.getMPSGraphTensor()];
|
||||
[feeds setObject:cellStateFwdPlaceholder.getMPSGraphTensorData() forKey:cellStateFwdPlaceholder.getMPSGraphTensor()];
|
||||
[feeds setObject:layersOutputsPlaceholder.getMPSGraphTensorData() forKey:layersOutputsPlaceholder.getMPSGraphTensor()];
|
||||
[feeds setObject:cellStateFwdPlaceholder.getMPSGraphTensorData()
|
||||
forKey:cellStateFwdPlaceholder.getMPSGraphTensor()];
|
||||
[feeds setObject:layersOutputsPlaceholder.getMPSGraphTensorData()
|
||||
forKey:layersOutputsPlaceholder.getMPSGraphTensor()];
|
||||
|
||||
NSMutableArray<MPSGraphTensor*>* kernelWeightsList = cachedGraph->kernelWeightsList_;
|
||||
NSMutableArray<MPSGraphTensor*>* recurrentKernelWeightsList = cachedGraph->recurrentKernelWeightsList_;
|
||||
@@ -706,7 +743,8 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
|
||||
for (const auto i : c10::irange(total_layers)) {
|
||||
Placeholder kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
|
||||
Placeholder recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
|
||||
Placeholder recurrentKernelWeight =
|
||||
Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
|
||||
[feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()];
|
||||
[feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()];
|
||||
if (has_biases) {
|
||||
@@ -721,10 +759,10 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
Tensor grad_state_out = at::empty_like(hx[0]);
|
||||
Tensor grad_cell_state_out = at::empty_like(hx[1]);
|
||||
|
||||
|
||||
std::vector<Tensor> grad_hx = {grad_state_out, grad_cell_state_out};
|
||||
|
||||
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *results = [[[NSMutableDictionary alloc] init] autorelease];
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
[[[NSMutableDictionary alloc] init] autorelease];
NSMutableArray<MPSGraphTensor*>* gradRecWeightsArray = cachedGraph->gradRecWeights_;
NSMutableArray<MPSGraphTensor*>* gradWeightsArray = cachedGraph->gradWeights_;
NSMutableArray<MPSGraphTensor*>* gradBiasArray = cachedGraph->gradBias_;
@ -736,7 +774,8 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
Placeholder gradCellStatePlaceholder = Placeholder(gradCellState, grad_cell_state_out);
Placeholder outputPlaceholder = Placeholder(gradOutput, output_out);
[results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()];
[results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() forKey:gradCellStatePlaceholder.getMPSGraphTensor()];
[results setObject:gradCellStatePlaceholder.getMPSGraphTensorData()
forKey:gradCellStatePlaceholder.getMPSGraphTensor()];
[results setObject:outputPlaceholder.getMPSGraphTensorData() forKey:outputPlaceholder.getMPSGraphTensor()];
Placeholder gradRecWeightsPlaceholder, gradWeightsPlaceholder, gradBiasPlaceholder;
@ -752,8 +791,10 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex:i], grad_rec_weights);
gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex:i], grad_weights);
[results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData() forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()];
[results setObject:gradWeightsPlaceholder.getMPSGraphTensorData() forKey:gradWeightsPlaceholder.getMPSGraphTensor()];
[results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData()
forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()];
[results setObject:gradWeightsPlaceholder.getMPSGraphTensorData()
forKey:gradWeightsPlaceholder.getMPSGraphTensor()];
if (has_biases) {
Tensor grad_bias = at::empty((kernel_weights[i].size(0)), kernel_weights[i].options());
@ -773,7 +814,6 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
return std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>>(output_out, grad_hx, weights);
}
}
@ -1,13 +1,13 @@
// Copyright © 2022 Apple Inc.
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/Copy.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>
#ifdef __OBJC__
@ -21,8 +21,12 @@ namespace at::native {
Scalar _local_scalar_dense_mps(const Tensor& self) {
Scalar r;
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_mps", [&] {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half,
at::ScalarType::Bool,
at::ScalarType::BFloat16,
self.scalar_type(),
"_local_scalar_dense_mps",
[&] {
Tensor output = at::empty_like(self, kCPU);
Tensor cpu_output = mps::mps_copy_(output, self, false);
@ -33,5 +37,4 @@ Scalar _local_scalar_dense_mps(const Tensor& self) {
return r;
}
} // namespace at::native
@ -5,12 +5,7 @@
namespace at::native {
TORCH_IMPL_FUNC(gather_out_mps)
(const Tensor & self_arg,
int64_t dim,
const Tensor & index,
bool sparse_grad,
const Tensor & output)
{
(const Tensor& self_arg, int64_t dim, const Tensor& index, bool sparse_grad, const Tensor& output) {
using namespace mps;
if (self_arg.numel() == 0 || index.numel() == 0) {
@ -20,13 +15,10 @@ TORCH_IMPL_FUNC(gather_out_mps)
dim = at::maybe_wrap_dim(dim, self.dim());
TORCH_CHECK(!sparse_grad, "sparse_grad not supported in MPS yet")
TORCH_CHECK(self.scalar_type() == output.scalar_type(),
"gather(): self and output must have the same scalar type");
TORCH_CHECK(dim >= 0 && dim < self.dim(),
"gather(): Indexing dim ", dim, " is out of bounds of tensor");
TORCH_CHECK(self.scalar_type() == output.scalar_type(), "gather(): self and output must have the same scalar type");
TORCH_CHECK(dim >= 0 && dim < self.dim(), "gather(): Indexing dim ", dim, " is out of bounds of tensor");
struct CachedGraph : public MPSCachedGraph
{
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* indexTensor_ = nil;
@ -36,7 +28,6 @@ TORCH_IMPL_FUNC(gather_out_mps)
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
MPSShape* input_shape = getMPSShape(self);
MPSShape* index_shape = getMPSShape(index);
uint32_t num_input_dims = [input_shape count];
@ -47,7 +38,8 @@ TORCH_IMPL_FUNC(gather_out_mps)
bool needSlice = false;
for (const auto i : c10::irange(num_input_dims)) {
TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis")
TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue],
"Index dim must not exceed input dim except at gathering axis")
if (i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
needSlice = true;
}
@ -89,11 +81,7 @@ TORCH_IMPL_FUNC(gather_out_mps)
ends[i] = (i != dim) ? index_shape[i] : input_shape[i];
}
getInput = [mpsGraph sliceTensor:inputTensor
starts:starts
ends:ends
strides:strides
name:nil];
getInput = [mpsGraph sliceTensor:inputTensor starts:starts ends:ends strides:strides name:nil];
}
MPSGraphTensor* castIndexTensor = [mpsGraph castTensor:indexTensor
@ -124,23 +112,20 @@ TORCH_IMPL_FUNC(gather_out_mps)
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
indexPlaceholder.getMPSGraphTensor() : indexPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
}
void scatter_mps_general
(const Tensor& self_arg,
void scatter_mps_general(const Tensor& self_arg,
int64_t dim,
const Tensor& index,
const Tensor& src,
const Tensor& output,
string func_name,
const c10::string_view reduce)
{
const c10::string_view reduce) {
using namespace mps;
if (self_arg.numel() == 0 || index.numel() == 0 || src.numel() == 0) {
@ -151,11 +136,9 @@ void scatter_mps_general
TORCH_CHECK(self.scalar_type() == output.scalar_type() && output.scalar_type() == src.scalar_type(),
|
||||
"scatter(): self, src and output must have the same scalar type");
|
||||
TORCH_CHECK(dim >= 0 && dim < self.dim(),
|
||||
"scatter(): Indexing dim ", dim, " is out of bounds of tensor");
|
||||
TORCH_CHECK(dim >= 0 && dim < self.dim(), "scatter(): Indexing dim ", dim, " is out of bounds of tensor");
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* indexTensor_ = nil;
|
||||
@ -166,7 +149,6 @@ void scatter_mps_general
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
MPSShape* input_shape = getMPSShape(self);
|
||||
MPSShape* index_shape = getMPSShape(index);
|
||||
MPSShape* src_shape = getMPSShape(src);
|
||||
@ -174,7 +156,8 @@ void scatter_mps_general
|
||||
uint32_t num_index_dims = [index_shape count];
|
||||
uint32_t num_src_dims = [src_shape count];
|
||||
|
||||
TORCH_CHECK(num_input_dims == num_index_dims && num_index_dims == num_src_dims, "Input, index and src must have same rank")
|
||||
TORCH_CHECK(num_input_dims == num_index_dims && num_index_dims == num_src_dims,
|
||||
"Input, index and src must have same rank")
|
||||
|
||||
// Do we need to slice into the src tensor?
|
||||
bool needSlice = false;
|
||||
@ -182,8 +165,10 @@ void scatter_mps_general
|
||||
bool needsCast = false;
|
||||
|
||||
for (const auto i : c10::irange(num_input_dims)) {
|
||||
TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis")
|
||||
TORCH_CHECK([index_shape[i] intValue] <= [src_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis")
|
||||
TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue],
|
||||
"Index dim must not exceed input dim except at gathering axis")
|
||||
TORCH_CHECK([index_shape[i] intValue] <= [src_shape[i] intValue],
|
||||
"Index dim must not exceed input dim except at gathering axis")
|
||||
if ([index_shape[i] intValue] < [src_shape[i] intValue])
|
||||
needSlice = true;
|
||||
if (i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
|
||||
@ -197,7 +182,8 @@ void scatter_mps_general
|
||||
needsCast = true;
|
||||
}
|
||||
|
||||
string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" + std::string(reduce);
|
||||
string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" +
|
||||
std::string(reduce);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
@ -240,11 +226,7 @@ void scatter_mps_general
|
||||
scatterInputShape[i] = (i != dim) ? index_shape[i] : input_shape[i];
|
||||
}
|
||||
if (needSlice) {
|
||||
slicedSrc = [mpsGraph sliceTensor:castSrcTensor
|
||||
starts:starts
|
||||
ends:ends_src
|
||||
strides:strides
|
||||
name:nil];
|
||||
slicedSrc = [mpsGraph sliceTensor:castSrcTensor starts:starts ends:ends_src strides:strides name:nil];
|
||||
}
|
||||
if (inputNeedSlice) {
|
||||
slicedInput = [mpsGraph sliceTensor:castInputTensor
|
||||
@ -277,7 +259,8 @@ void scatter_mps_general
|
||||
#pragma clang diagnostic pop
|
||||
if (inputNeedSlice) {
|
||||
// Make an array of scatter indices tensors
|
||||
NSMutableArray<MPSGraphTensor*>* indicesTensors = [NSMutableArray<MPSGraphTensor*> arrayWithCapacity:num_input_dims];
|
||||
NSMutableArray<MPSGraphTensor*>* indicesTensors =
|
||||
[NSMutableArray<MPSGraphTensor*> arrayWithCapacity:num_input_dims];
|
||||
|
||||
// 1. Concatenate the coord tensors
|
||||
// 2. Flatten the values
|
||||
@ -289,13 +272,13 @@ void scatter_mps_general
|
||||
shape_data[i] = {[scatterInputShape[i] intValue]};
|
||||
}
|
||||
|
||||
MPSGraphTensor* scatterInputShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:shape_data.data() length:num_input_dims * sizeof(int)]
|
||||
MPSGraphTensor* scatterInputShapeTensor =
|
||||
[mpsGraph constantWithData:[NSData dataWithBytes:shape_data.data() length:num_input_dims * sizeof(int)]
|
||||
shape:@[ [NSNumber numberWithUnsignedInt:num_input_dims] ]
|
||||
dataType:MPSDataTypeInt32];
|
||||
|
||||
for (const auto i : c10::irange(num_input_dims)) {
|
||||
MPSGraphTensor* axisTensor = [mpsGraph constantWithScalar:i
|
||||
dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* axisTensor = [mpsGraph constantWithScalar:i dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* scatter_currentIndexTensor = [mpsGraph coordinateAlongAxisTensor:axisTensor
|
||||
withShapeTensor:scatterInputShapeTensor
|
||||
name:nil];
|
||||
@ -309,9 +292,7 @@ void scatter_mps_general
|
||||
dimension:(NSInteger)1
|
||||
name:nil];
|
||||
|
||||
MPSGraphTensor* flatValuesTensor = [mpsGraph reshapeTensor:scatterTensor
|
||||
withShape:@[@-1]
|
||||
name:nil];
|
||||
MPSGraphTensor* flatValuesTensor = [mpsGraph reshapeTensor:scatterTensor withShape:@[ @-1 ] name:nil];
|
||||
|
||||
outputTensor = [mpsGraph scatterNDWithDataTensor:castInputTensor
|
||||
updatesTensor:flatValuesTensor
|
||||
@ -325,7 +306,8 @@ void scatter_mps_general
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
newCachedGraph->srcTensor_ = srcTensor;
|
||||
newCachedGraph->indexTensor_ = indexTensor;
|
||||
newCachedGraph->outputTensor_ = needsCast ? castMPSTensor(mpsGraph, outputTensor, output.scalar_type()) : outputTensor;
|
||||
newCachedGraph->outputTensor_ =
|
||||
needsCast ? castMPSTensor(mpsGraph, outputTensor, output.scalar_type()) : outputTensor;
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
@ -342,41 +324,24 @@ void scatter_mps_general
|
||||
srcPlaceholder.getMPSGraphTensor() : srcPlaceholder.getMPSGraphTensorData(),
|
||||
indexPlaceholder.getMPSGraphTensor() : indexPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(scatter_src_out_mps)
|
||||
(const Tensor& self,
|
||||
int64_t dim,
|
||||
const Tensor& index,
|
||||
const Tensor& src,
|
||||
const Tensor& output) {
|
||||
|
||||
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const Tensor& output) {
|
||||
scatter_mps_general(self, dim, index, src, output, "scatter_src_out_mps", "set");
|
||||
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(scatter_value_out_mps)
|
||||
(const Tensor& self,
|
||||
int64_t dim,
|
||||
const Tensor& index,
|
||||
const Scalar& value,
|
||||
const Tensor& output) {
|
||||
|
||||
Tensor src = at::native::empty_mps(index.sizes(),
|
||||
self.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
self.suggest_memory_format());
|
||||
(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value, const Tensor& output) {
|
||||
Tensor src = at::native::empty_mps(
|
||||
index.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, self.suggest_memory_format());
|
||||
src.fill_(value);
|
||||
scatter_mps_general(self, dim, index, const_cast<Tensor&>(src), output, "scatter_value_out_mps", "set");
|
||||
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(scatter_reduce_out_mps)
|
||||
@ -386,9 +351,7 @@ TORCH_IMPL_FUNC(scatter_reduce_out_mps)
|
||||
const Tensor& src,
|
||||
const c10::string_view reduce,
|
||||
const Tensor& output) {
|
||||
|
||||
scatter_mps_general(self, dim, index, src, output, "scatter_reduce_out_mps", reduce);
|
||||
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(scatter_value_reduce_out_mps)
|
||||
@ -398,25 +361,14 @@ TORCH_IMPL_FUNC(scatter_value_reduce_out_mps)
|
||||
const Scalar& value,
|
||||
const c10::string_view reduce,
|
||||
const Tensor& output) {
|
||||
|
||||
Tensor src = at::native::empty_mps(index.sizes(),
|
||||
self.scalar_type(),
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
self.suggest_memory_format());
|
||||
Tensor src = at::native::empty_mps(
|
||||
index.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, self.suggest_memory_format());
|
||||
src.fill_(value);
|
||||
scatter_mps_general(self, dim, index, const_cast<Tensor&>(src), output, "scatter_value_reduce_out_mps", reduce);
|
||||
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(scatter_add_mps_out)
|
||||
(const Tensor& self,
|
||||
int64_t dim,
|
||||
const Tensor& index,
|
||||
const Tensor& src,
|
||||
const Tensor& output) {
|
||||
|
||||
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const Tensor& output) {
|
||||
scatter_mps_general(self, dim, index, src, output, "scatter_add_mps_out", "add");
|
||||
}
|
||||
|
||||
|
@ -2,10 +2,10 @@
|
||||
|
||||
#include <ATen/MemoryOverlap.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#include <ATen/native/TypeProperties.h>
|
||||
#include <ATen/native/TensorShape.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/TypeProperties.h>
|
||||
#include <ATen/native/mps/MPSGraphVenturaOps.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
|
||||
namespace at::native {
|
||||
|
||||
@ -27,19 +27,10 @@ std::vector<int64_t> getTopK0Shape(IntArrayRef sizes, const int64_t dim_) {
|
||||
|
||||
// topk
|
||||
TORCH_IMPL_FUNC(topk_out_mps)
|
||||
(const Tensor& self,
|
||||
int64_t k,
|
||||
int64_t dim_,
|
||||
bool largest,
|
||||
bool sorted,
|
||||
const Tensor& values,
|
||||
const Tensor& indices)
|
||||
{
|
||||
(const Tensor& self, int64_t k, int64_t dim_, bool largest, bool sorted, const Tensor& values, const Tensor& indices) {
|
||||
using namespace mps;
|
||||
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
|
||||
TORCH_CHECK(
|
||||
k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
|
||||
"selected index k out of range");
|
||||
TORCH_CHECK(k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1), "selected index k out of range");
|
||||
|
||||
if (!is_macos_13_or_newer() && (k > 16)) {
|
||||
TORCH_WARN_ONCE("torch.topk support for k>16 by MPS on MacOS 13+, please upgrade");
|
||||
@ -58,15 +49,13 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
||||
}
|
||||
|
||||
// Handle empty tensors
|
||||
if (self.numel() == 0)
|
||||
{
|
||||
if (self.numel() == 0) {
|
||||
values.copy_(self);
|
||||
indices.copy_(values.toType(at::ScalarType::Long));
|
||||
return;
|
||||
}
|
||||
// Handle k == 0 case. Needed because MPSGraph does not support k == 0.
|
||||
if (k == 0)
|
||||
{
|
||||
if (k == 0) {
|
||||
const auto out_shape = getTopK0Shape(self.sizes(), dim);
|
||||
values.resize_(out_shape);
|
||||
indices.copy_(values.toType(at::ScalarType::Long));
|
||||
@ -85,10 +74,8 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
||||
// Input as placeholders
|
||||
MPSShape* input_shape = getMPSShape(self);
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string key = string("topk:") + [ns_shape_key UTF8String] + ":" +
|
||||
getMPSTypeString(self) +
|
||||
":k" + to_string(k) + ":dim" + to_string(dim_) +
|
||||
":largest" + to_string(largest);
|
||||
string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + to_string(k) +
|
||||
":dim" + to_string(dim_) + ":largest" + to_string(largest);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
@ -102,9 +89,7 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
||||
MPSGraphTensor* castInputTensor = newCachedGraph->selfTensor;
|
||||
MPSDataType dataType = getMPSDataType(self);
|
||||
// #issue 104398441 sortWithTensor and argsortWithTensor
|
||||
if (dataType != MPSDataTypeInt32 &&
|
||||
dataType != MPSDataTypeFloat32 &&
|
||||
dataType != MPSDataTypeFloat16) {
|
||||
if (dataType != MPSDataTypeInt32 && dataType != MPSDataTypeFloat32 && dataType != MPSDataTypeFloat16) {
|
||||
dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
|
||||
castInputTensor = [mpsGraph castTensor:newCachedGraph->selfTensor
|
||||
toType:dataType
|
||||
@ -116,8 +101,7 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
||||
name:nil];
|
||||
sortedTensor = [mpsGraph sliceTensor:sortedTensor
|
||||
dimension:(NSUInteger)dim
|
||||
start:((NSUInteger) 0)
|
||||
length:k
|
||||
start:((NSUInteger)0)length:k
|
||||
name:nil];
|
||||
MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor
|
||||
axis:(NSInteger)dim
|
||||
@ -125,8 +109,7 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
||||
name:@"argmax_out"];
|
||||
argSortedTensor = [mpsGraph sliceTensor:argSortedTensor
|
||||
dimension:dim
|
||||
start:((NSUInteger) 0)
|
||||
length:k
|
||||
start:((NSUInteger)0)length:k
|
||||
name:nil];
|
||||
newCachedGraph->valuesTensor = sortedTensor;
|
||||
newCachedGraph->indicesTensor = argSortedTensor;
|
||||
@ -138,22 +121,17 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
||||
dimension:(NSUInteger)self.dim() - 1
|
||||
withDimension:(NSUInteger)dim_
|
||||
name:nil];
|
||||
MPSGraphTensor * identity = [mpsGraph identityWithTensor: transposedInput
|
||||
name: nil];
|
||||
MPSGraphTensor * negatedTransposedInput = [mpsGraph negativeWithTensor:identity
|
||||
name: nil];
|
||||
NSArray<MPSGraphTensor *> * outputMPSGraphTensors = [mpsGraph
|
||||
topKWithSourceTensor:negatedTransposedInput
|
||||
k:((NSUInteger) k)
|
||||
name:nil];
|
||||
MPSGraphTensor* identity = [mpsGraph identityWithTensor:transposedInput name:nil];
|
||||
MPSGraphTensor* negatedTransposedInput = [mpsGraph negativeWithTensor:identity name:nil];
|
||||
NSArray<MPSGraphTensor*>* outputMPSGraphTensors = [mpsGraph topKWithSourceTensor:negatedTransposedInput
|
||||
k:((NSUInteger)k)name:nil];
|
||||
MPSGraphTensor* valuesNegatedTransposed = outputMPSGraphTensors[0];
|
||||
MPSGraphTensor* indicesTransposed = outputMPSGraphTensors[1];
|
||||
MPSGraphTensor* valuesNegated = [mpsGraph transposeTensor:valuesNegatedTransposed
|
||||
dimension:(NSUInteger)self.dim() - 1
|
||||
withDimension:(NSUInteger)dim_
|
||||
name:nil];
|
||||
newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated
|
||||
name: nil];
|
||||
newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated name:nil];
|
||||
newCachedGraph->indicesTensor = [mpsGraph transposeTensor:indicesTransposed
|
||||
dimension:(NSUInteger)self.dim() - 1
|
||||
withDimension:(NSUInteger)dim_
|
||||
@ -163,12 +141,9 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
||||
dimension:(NSUInteger)self.dim() - 1
|
||||
withDimension:(NSUInteger)dim_
|
||||
name:nil];
|
||||
MPSGraphTensor * identity = [mpsGraph identityWithTensor: transposedInput
|
||||
name: nil];
|
||||
NSArray<MPSGraphTensor *> * outputMPSGraphTensors = [mpsGraph
|
||||
topKWithSourceTensor:identity
|
||||
k:((NSUInteger) k)
|
||||
name:nil];
|
||||
MPSGraphTensor* identity = [mpsGraph identityWithTensor:transposedInput name:nil];
|
||||
NSArray<MPSGraphTensor*>* outputMPSGraphTensors = [mpsGraph topKWithSourceTensor:identity
|
||||
k:((NSUInteger)k)name:nil];
|
||||
MPSGraphTensor* valuesTransposed = outputMPSGraphTensors[0];
|
||||
MPSGraphTensor* indicesTransposed = outputMPSGraphTensors[1];
|
||||
newCachedGraph->valuesTensor = [mpsGraph transposeTensor:valuesTransposed
|
||||
@ -181,21 +156,15 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
||||
name:nil];
|
||||
} else if (!largest) {
|
||||
// only negate
|
||||
MPSGraphTensor *negatedInput = [mpsGraph negativeWithTensor:newCachedGraph->selfTensor
|
||||
name: nil];
|
||||
NSArray<MPSGraphTensor *> * outputMPSGraphTensors = [mpsGraph
|
||||
topKWithSourceTensor:negatedInput
|
||||
k:((NSUInteger) k)
|
||||
name:nil];
|
||||
MPSGraphTensor* negatedInput = [mpsGraph negativeWithTensor:newCachedGraph->selfTensor name:nil];
|
||||
NSArray<MPSGraphTensor*>* outputMPSGraphTensors = [mpsGraph topKWithSourceTensor:negatedInput
|
||||
k:((NSUInteger)k)name:nil];
|
||||
MPSGraphTensor* valuesNegated = outputMPSGraphTensors[0];
|
||||
newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated
|
||||
name: nil];
|
||||
newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated name:nil];
|
||||
newCachedGraph->indicesTensor = outputMPSGraphTensors[1];
|
||||
} else {
|
||||
NSArray<MPSGraphTensor *> * outputMPSGraphTensors = [mpsGraph
|
||||
topKWithSourceTensor:newCachedGraph->selfTensor
|
||||
k:((NSUInteger) k)
|
||||
name:nil];
|
||||
NSArray<MPSGraphTensor*>* outputMPSGraphTensors =
|
||||
[mpsGraph topKWithSourceTensor:newCachedGraph->selfTensor k:((NSUInteger)k)name:nil];
|
||||
newCachedGraph->valuesTensor = outputMPSGraphTensors[0];
|
||||
newCachedGraph->indicesTensor = outputMPSGraphTensors[1];
|
||||
}
|
||||
@ -210,29 +179,21 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
||||
Placeholder indicesPlaceholder = Placeholder(cachedGraph->indicesTensor, indices);
|
||||
// Create dictionary of inputs and outputs
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
|
||||
feeds = @{
|
||||
inputPlaceholder.getMPSGraphTensor() :
|
||||
inputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
feeds = @{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
valuesPlaceholder.getMPSGraphTensor() :
|
||||
valuesPlaceholder.getMPSGraphTensorData(),
|
||||
indicesPlaceholder.getMPSGraphTensor() :
|
||||
indicesPlaceholder.getMPSGraphTensorData()
|
||||
valuesPlaceholder.getMPSGraphTensor() : valuesPlaceholder.getMPSGraphTensorData(),
|
||||
indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
}
|
||||
|
||||
void check_shape_except_dim(const Tensor &first, const Tensor &second,
|
||||
int dimension, int index)
|
||||
{
|
||||
void check_shape_except_dim(const Tensor& first, const Tensor& second, int dimension, int index) {
|
||||
int first_dims = first.dim();
|
||||
int second_dims = second.dim();
|
||||
TORCH_CHECK(first_dims == second_dims,
|
||||
"Tensors must have same number of dimensions: got ", first_dims,
|
||||
" and ", second_dims);
|
||||
TORCH_CHECK(
|
||||
first_dims == second_dims, "Tensors must have same number of dimensions: got ", first_dims, " and ", second_dims);
|
||||
for (int dim = 0; dim < first_dims; dim++) {
|
||||
if (dim == dimension) {
|
||||
continue;
|
||||
@ -240,10 +201,15 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
|
||||
int64_t first_dim_size = at::native::size(first, dim);
|
||||
int64_t second_dim_size = at::native::size(second, dim);
|
||||
TORCH_CHECK(first_dim_size == second_dim_size,
|
||||
"Sizes of tensors must match except in dimension ", dim, ". Got ",
|
||||
static_cast<long long>(first_dim_size), " and ",
|
||||
static_cast<long long>(second_dim_size), " (The offending index is ",
|
||||
index, ")");
|
||||
"Sizes of tensors must match except in dimension ",
|
||||
dim,
|
||||
". Got ",
|
||||
static_cast<long long>(first_dim_size),
|
||||
" and ",
|
||||
static_cast<long long>(second_dim_size),
|
||||
" (The offending index is ",
|
||||
index,
|
||||
")");
|
||||
}
|
||||
}
|
||||
|
||||
@ -256,7 +222,6 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
bool all_same_sizes_and_stride,
|
||||
MemoryFormat memory_format,
|
||||
const Tensor& out) {
|
||||
|
||||
using namespace mps;
|
||||
|
||||
if (out.numel() == 0) {
|
||||
@ -271,12 +236,14 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
auto lap = at::get_overlap_status(out, t);
|
||||
TORCH_CHECK(lap != at::MemOverlapStatus::Partial && lap != at::MemOverlapStatus::Full,
|
||||
"torch.cat(): unsupported operation: the input tensors cannot refer to any "
|
||||
"of the output memory locations. Found overlap in input tensor ", idx);
|
||||
"of the output memory locations. Found overlap in input tensor ",
|
||||
idx);
|
||||
idx++;
|
||||
}
|
||||
// Check for type promotion
|
||||
TORCH_CHECK(canCast(out_dtype, out.scalar_type()),
|
||||
"torch.cat(): input types can't be cast to the desired output type ", out.scalar_type());
|
||||
"torch.cat(): input types can't be cast to the desired output type ",
|
||||
out.scalar_type());
|
||||
TORCH_CHECK(inputs.size() > 0, "torch.cat(): invalid number of inputs ", inputs.size());
|
||||
|
||||
dimension = legacy_cat_wrap_dim(dimension, materialized_inputs);
|
||||
@ -288,9 +255,7 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
// this behavior for backwards compatibility, but only for this specific size
|
||||
// (i.e. other empty sizes are not skipped).
|
||||
// FIXME: warn if this is the case
|
||||
auto should_skip = [](const Tensor& t) {
|
||||
return t.dim() == 1 && at::native::size(t, 0) == 0;
|
||||
};
|
||||
auto should_skip = [](const Tensor& t) { return t.dim() == 1 && at::native::size(t, 0) == 0; };
|
||||
at::assert_no_internal_overlap(out);
|
||||
|
||||
Tensor notSkippedTensor;
|
||||
@ -317,11 +282,15 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
for (const Tensor& t : inputs) {
|
||||
TORCH_CHECK(t.device() == notSkippedTensor.device(),
|
||||
"torch.cat(): all input tensors must be on the same device. Received ",
|
||||
t.device(), " and ", notSkippedTensor.device());
|
||||
t.device(),
|
||||
" and ",
|
||||
notSkippedTensor.device());
|
||||
}
|
||||
TORCH_CHECK(out.device() == notSkippedTensor.device(),
|
||||
"torch.cat(): all input tensors and out must be on the same device, but inputs are on ",
|
||||
notSkippedTensor.device(), " and out is on ", out.device());
|
||||
notSkippedTensor.device(),
|
||||
" and out is on ",
|
||||
out.device());
|
||||
|
||||
// TODO: For better performance by eliminating input tensor gathering and post transpose,
|
||||
// TODO: it is better to keep the out tensor's memory format.
|
||||
@ -361,8 +330,8 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "cat_out_mps:" + to_string(dimension) + getTensorsStringKey(input_tensors, /*short_dtype*/true) + ":" +
|
||||
(memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
|
||||
string key = "cat_out_mps:" + to_string(dimension) + getTensorsStringKey(input_tensors, /*short_dtype*/ true) +
|
||||
":" + (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
|
||||
|
||||
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
|
||||
if (!cachedGraph) {
|
||||
@ -383,7 +352,8 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
if (tensor.scalar_type() == kBool) {
|
||||
scalar_type = MPSDataTypeInt8;
|
||||
}
|
||||
newCachedGraph->inputTensors_[idx] = mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, MemoryFormat::Contiguous));
|
||||
newCachedGraph->inputTensors_[idx] =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, MemoryFormat::Contiguous));
|
||||
if (tensor.scalar_type() != out_dtype) {
|
||||
castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx]
|
||||
toType:getMPSDataType(out_dtype)
|
||||
@ -393,15 +363,12 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
}
|
||||
}
|
||||
|
||||
auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data()
|
||||
count:len_tensor_array];
|
||||
auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data() count:len_tensor_array];
|
||||
MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
|
||||
dimension:dimension // Maybe convert this from int64_t -> int32
|
||||
name:nil];
|
||||
if (getMPSDataType(out_dtype) == MPSDataTypeBool) {
|
||||
outputTensor = [mpsGraph castTensor:outputTensor
|
||||
toType:MPSDataTypeBool
|
||||
name:@"outputTensor"];
|
||||
outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"outputTensor"];
|
||||
}
|
||||
newCachedGraph->outputTensor_ = outputTensor;
|
||||
}
|
||||
@ -418,9 +385,11 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
if (tensor.scalar_type() == kBool) {
|
||||
scalar_type = MPSDataTypeInt8;
|
||||
}
|
||||
inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor,
|
||||
inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx],
|
||||
tensor,
|
||||
getMPSShape(tensor, MemoryFormat::Contiguous),
|
||||
/*gatherTensorData*/true, scalar_type);
|
||||
/*gatherTensorData*/ true,
|
||||
scalar_type);
|
||||
t_idx++;
|
||||
}
|
||||
i++;
|
||||
@ -430,16 +399,15 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
if (!is_macos_13_or_newer() && out.scalar_type() == kBool) {
|
||||
outputDataType = MPSDataTypeInt8;
|
||||
}
|
||||
Placeholder outputPlaceholder = Placeholder(
|
||||
cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType);
|
||||
Placeholder outputPlaceholder =
|
||||
Placeholder(cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType);
|
||||
|
||||
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
|
||||
for (auto& inputPlaceholder : inputPlaceholders) {
|
||||
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
|
||||
}
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
|
@ -17,13 +17,13 @@ namespace at::native {
|
||||
|
||||
void get_shapes(MPSShape* input_shape_readonly,
|
||||
NSMutableArray<NSNumber*>*& input_shape,
|
||||
int num_input_dims, c10::MemoryFormat memory_format) {
|
||||
int num_input_dims,
|
||||
c10::MemoryFormat memory_format) {
|
||||
// Modify the shape
|
||||
if (memory_format == at::MemoryFormat::Contiguous) {
|
||||
for (int i = 0; i < num_input_dims; i++)
|
||||
input_shape[i] = input_shape_readonly[i];
|
||||
}
|
||||
else { // ChannelsLast
|
||||
} else { // ChannelsLast
|
||||
auto num_channels = input_shape_readonly[1];
|
||||
input_shape[0] = input_shape_readonly[0];
|
||||
for (int i = 1; i < num_input_dims - 1; i++)
|
||||
@ -35,11 +35,7 @@ void get_shapes(MPSShape* input_shape_readonly,
|
||||
// Note - Currently only supported for 4D image tensors
|
||||
|
||||
TORCH_IMPL_FUNC(softmax_mps_out)
|
||||
(const Tensor& input_,
|
||||
const int64_t dim,
|
||||
const bool half_to_float,
|
||||
const Tensor& output) {
|
||||
|
||||
(const Tensor& input_, const int64_t dim, const bool half_to_float, const Tensor& output) {
|
||||
TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on MPS");
|
||||
|
||||
if (input_.numel() == 0) {
|
||||
@ -49,24 +45,21 @@ TORCH_IMPL_FUNC(softmax_mps_out)
|
||||
Tensor input;
|
||||
if (input_.dim() == 0) {
|
||||
input = input_.view(1);
|
||||
}
|
||||
else
|
||||
} else
|
||||
input = input_;
|
||||
|
||||
int64_t dim_ = maybe_wrap_dim(dim, input.dim());
|
||||
TORCH_CHECK(
|
||||
dim_ >= 0 && dim_ < input.dim(),
|
||||
"Softmax:dim must be non-negative and less than input dimensions");
|
||||
TORCH_CHECK(dim_ >= 0 && dim_ < input.dim(), "Softmax:dim must be non-negative and less than input dimensions");
|
||||
|
||||
const auto memory_format = input.suggest_memory_format();
|
||||
// TORCH_CHECK(input.suggest_memory_format() == output.suggest_memory_format(), "Input and output memory format should match")
|
||||
// TORCH_CHECK(input.suggest_memory_format() == output.suggest_memory_format(), "Input and output memory format should
|
||||
// match")
|
||||
|
||||
using namespace mps;
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
// Derive from MPSCachedGraph
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
@ -75,12 +68,12 @@ TORCH_IMPL_FUNC(softmax_mps_out)
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
string mem_format_key = get_mem_format_string(memory_format);
|
||||
MPSShape* input_shape_readonly = mps::getMPSShape(input);
|
||||
int num_input_dims = [input_shape_readonly count];
|
||||
// Check - Channels last implies 4d
|
||||
TORCH_CHECK(memory_format != at::MemoryFormat::ChannelsLast || num_input_dims == 4, "ChannelsLast implies 4d tensor")
|
||||
TORCH_CHECK(memory_format != at::MemoryFormat::ChannelsLast || num_input_dims == 4,
|
||||
"ChannelsLast implies 4d tensor")
|
||||
// Input shape changes based on memory format
|
||||
NSMutableArray<NSNumber*>* input_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
|
||||
|
||||
@ -105,8 +98,8 @@ TORCH_IMPL_FUNC(softmax_mps_out)
|
||||
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
|
||||
string key = "softmax_mps_out:" + mem_format_key + ":" + getMPSTypeString(input) + ":"
|
||||
+ [ns_shape_key UTF8String] + ":" + std::to_string(dim_);
|
||||
string key = "softmax_mps_out:" + mem_format_key + ":" + getMPSTypeString(input) + ":" + [ns_shape_key UTF8String] +
|
||||
":" + std::to_string(dim_);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
|
||||
if (!cachedGraph) {
|
||||
@ -120,9 +113,7 @@ TORCH_IMPL_FUNC(softmax_mps_out)
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape);
|
||||
|
||||
// passing selector of softMaxWithTensor on the mpsGraph object
|
||||
MPSGraphTensor* outputTensor = [mpsGraph softMaxWithTensor:inputTensor
|
||||
axis:(NSInteger)dim_
|
||||
name:nil];
|
||||
MPSGraphTensor* outputTensor = [mpsGraph softMaxWithTensor:inputTensor axis:(NSInteger)dim_ name:nil];
|
||||
|
||||
// Output needs to be contiguous format
|
||||
if (memory_format == at::MemoryFormat::ChannelsLast) {
|
||||
@ -134,14 +125,8 @@ TORCH_IMPL_FUNC(softmax_mps_out)
|
||||
outputTensor = [mpsGraph reshapeTensor:outputTensor
|
||||
withShape:@[ N, ([NSNumber numberWithInt:[H intValue] * [W intValue]]), C ]
|
||||
name:nil];
|
||||
outputTensor = [mpsGraph transposeTensor:outputTensor
|
||||
dimension:1
|
||||
withDimension:2
|
||||
name:nil];
|
||||
outputTensor = [mpsGraph reshapeTensor:outputTensor
|
||||
withShape:@[N, C, H, W]
|
||||
name:nil];
|
||||
|
||||
outputTensor = [mpsGraph transposeTensor:outputTensor dimension:1 withDimension:2 name:nil];
|
||||
outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:@[ N, C, H, W ] name:nil];
|
||||
}
|
||||
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
@ -156,25 +141,17 @@ TORCH_IMPL_FUNC(softmax_mps_out)
|
||||
// This must be the Contiguous shape
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
|
||||
@{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(softmax_backward_mps_out)
|
||||
(const Tensor& grad_,
|
||||
const Tensor& output_,
|
||||
int64_t dim,
|
||||
ScalarType input_dtype,
|
||||
const Tensor& grad_input) {
|
||||
|
||||
(const Tensor& grad_, const Tensor& output_, int64_t dim, ScalarType input_dtype, const Tensor& grad_input) {
|
||||
if (output_.numel() == 0) {
|
||||
return;
|
||||
}
|
||||
@ -182,28 +159,23 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
|
||||
Tensor grad;
|
||||
if (grad_.dim() == 0) {
|
||||
grad = grad_.view(1);
|
||||
}
|
||||
else
|
||||
} else
|
||||
grad = grad_;
|
||||
|
||||
Tensor output;
|
||||
if (output_.dim() == 0) {
|
||||
output = output_.view(1);
|
||||
}
|
||||
else
|
||||
} else
|
||||
output = output_;
|
||||
|
||||
int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
|
||||
TORCH_CHECK(
|
||||
dim_ >= 0 && dim_ < grad.dim(),
|
||||
"Grad:dim must be non-negative and less than input dimensions");
|
||||
TORCH_CHECK(dim_ >= 0 && dim_ < grad.dim(), "Grad:dim must be non-negative and less than input dimensions");
|
||||
|
||||
using namespace mps;
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
// Derive from MPSCachedGraph
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* softmaxTensor_ = nil;
|
||||
MPSGraphTensor* gradOutputTensor_ = nil;
|
||||
@ -213,12 +185,11 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
MPSShape* grad_shape = mps::getMPSShape(grad);
|
||||
NSString* ns_shape_key = [[grad_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
|
||||
string key = "softmax_backward_mps_out:" + getMPSTypeString(output) + ":"
|
||||
+ [ns_shape_key UTF8String] + ":" + std::to_string(dim_);
|
||||
string key = "softmax_backward_mps_out:" + getMPSTypeString(output) + ":" + [ns_shape_key UTF8String] + ":" +
|
||||
std::to_string(dim_);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
|
||||
if (!cachedGraph) {
|
||||
@ -235,9 +206,7 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
|
||||
MPSGraphTensor* mulTensor = [mpsGraph multiplicationWithPrimaryTensor:softmaxTensor
|
||||
secondaryTensor:gradOutputTensor
|
||||
name:nil];
|
||||
MPSGraphTensor* mulSumTensor = [mpsGraph reductionSumWithTensor:mulTensor
|
||||
axis:(NSInteger)dim_
|
||||
name:nil];
|
||||
MPSGraphTensor* mulSumTensor = [mpsGraph reductionSumWithTensor:mulTensor axis:(NSInteger)dim_ name:nil];
|
||||
MPSGraphTensor* gradSubTensor = [mpsGraph subtractionWithPrimaryTensor:gradOutputTensor
|
||||
secondaryTensor:mulSumTensor
|
||||
name:nil];
|
||||
@ -262,12 +231,10 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
|
||||
softmaxPlaceholder.getMPSGraphTensor() : softmaxPlaceholder.getMPSGraphTensorData(),
|
||||
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,10 +2,10 @@
|
||||
|
||||
#include <ATen/MemoryOverlap.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#include <ATen/native/TypeProperties.h>
|
||||
#include <ATen/native/TensorShape.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/TypeProperties.h>
|
||||
#include <ATen/native/mps/MPSGraphVenturaOps.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
|
||||
namespace at::native {
|
||||
|
||||
@ -50,8 +50,8 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
|
||||
// Input as placeholders
|
||||
MPSShape* input_shape = getMPSShape(self);
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) +
|
||||
":dim" + to_string(dim) + ":descending" + to_string(descending);
|
||||
string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" + to_string(dim) +
|
||||
":descending" + to_string(descending);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
@ -61,7 +61,8 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
|
||||
|
||||
MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus);
|
||||
MPSGraphTensor* castInputTensor =
|
||||
castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus);
|
||||
MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor
|
||||
axis:(NSInteger)dim
|
||||
descending:(BOOL)descending
|
||||
@ -88,14 +89,10 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
|
||||
Placeholder indicesPlaceholder = Placeholder(cachedGraph->indicesTensor, indices);
|
||||
// Create dictionary of inputs and outputs
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
|
||||
feeds = @{ inputPlaceholder.getMPSGraphTensor() :
|
||||
inputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
feeds = @{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
valuesPlaceholder.getMPSGraphTensor() :
|
||||
valuesPlaceholder.getMPSGraphTensorData(),
|
||||
indicesPlaceholder.getMPSGraphTensor() :
|
||||
indicesPlaceholder.getMPSGraphTensorData()
|
||||
valuesPlaceholder.getMPSGraphTensor() : valuesPlaceholder.getMPSGraphTensorData(),
|
||||
indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
|
@ -4,13 +4,10 @@
|
||||
|
||||
namespace at::native {
|
||||
|
||||
Tensor& bincount_mps_impl(const Tensor& self,
|
||||
const Tensor& weights,
|
||||
Tensor& output) {
|
||||
Tensor& bincount_mps_impl(const Tensor& self, const Tensor& weights, Tensor& output) {
|
||||
using namespace mps;
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* weightsTensor_ = nil;
|
||||
@ -27,7 +24,6 @@ Tensor& bincount_mps_impl(const Tensor& self,
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
@ -35,23 +31,19 @@ Tensor& bincount_mps_impl(const Tensor& self,
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
MPSGraphTensor *scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(output.scalar_type()));
|
||||
MPSGraphTensor* scatterDataTensor =
|
||||
mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(output.scalar_type()));
|
||||
|
||||
MPSGraphTensor* updatesTensor = nil;
|
||||
if (has_weights) {
|
||||
updatesTensor = mpsGraphRankedPlaceHolder(mpsGraph, weights);
|
||||
}
|
||||
else {
|
||||
updatesTensor = [mpsGraph constantWithScalar:1.0f
|
||||
shape:getMPSShape(self)
|
||||
dataType:getMPSDataType(output)];
|
||||
} else {
|
||||
updatesTensor = [mpsGraph constantWithScalar:1.0f shape:getMPSShape(self) dataType:getMPSDataType(output)];
|
||||
}
|
||||
|
||||
MPSGraphTensor* castedInputTensor = inputTensor;
|
||||
if (self.scalar_type() == kByte) {
|
||||
castedInputTensor = [mpsGraph castTensor:inputTensor
|
||||
toType:MPSDataTypeInt32
|
||||
name:@"castInputTensor"];
|
||||
castedInputTensor = [mpsGraph castTensor:inputTensor toType:MPSDataTypeInt32 name:@"castInputTensor"];
|
||||
}
|
||||
|
||||
MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor
|
||||
@ -88,9 +80,8 @@ Tensor& bincount_mps_impl(const Tensor& self,
|
||||
feeds[weightsPlaceholder.getMPSGraphTensor()] = weightsPlaceholder.getMPSGraphTensorData();
|
||||
}
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
// Run the graph
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
@ -108,43 +99,32 @@ Tensor _bincount_mps(const Tensor& self, const c10::optional<Tensor>& weights_op
|
||||
TORCH_CHECK(minlength >= 0, "minlength should be >= 0");
|
||||
|
||||
if (self.dim() == 1 && self.numel() == 0) {
|
||||
return at::zeros(
|
||||
{minlength},
|
||||
kLong,
|
||||
c10::nullopt /* layout */,
|
||||
kMPS,
|
||||
c10::nullopt /* pin_memory */);
|
||||
return at::zeros({minlength}, kLong, c10::nullopt /* layout */, kMPS, c10::nullopt /* pin_memory */);
|
||||
}
|
||||
TORCH_CHECK(self.dim() == 1 && self.min().item<int64_t>() >= 0, "bincount only supports 1-d non-negative integral inputs.");
|
||||
TORCH_CHECK(self.dim() == 1 && self.min().item<int64_t>() >= 0,
|
||||
"bincount only supports 1-d non-negative integral inputs.");
|
||||
|
||||
bool has_weights = weights.defined();
|
||||
TORCH_CHECK(!(has_weights && (weights.dim() != 1 || weights.size(0) != self.size(0))), "weights should be 1-d and have the same length as input");
|
||||
TORCH_CHECK(!(has_weights && (weights.dim() != 1 || weights.size(0) != self.size(0))),
|
||||
"weights should be 1-d and have the same length as input");
|
||||
|
||||
const int64_t nbins = std::max(self.max().item<int64_t>() + 1L, minlength);
|
||||
Tensor output;
|
||||
|
||||
Tensor weights_ = weights;
|
||||
if (has_weights) {
|
||||
if(weights.scalar_type() != ScalarType::Float &&
|
||||
weights.scalar_type() != ScalarType::Int &&
|
||||
if (weights.scalar_type() != ScalarType::Float && weights.scalar_type() != ScalarType::Int &&
|
||||
weights.scalar_type() != ScalarType::Half) {
|
||||
// Scatter doesn't work for int8/int16 dtypes
|
||||
weights_ = weights.to(kInt);
|
||||
}
|
||||
output = at::zeros(
|
||||
{nbins},
|
||||
output = at::zeros({nbins},
|
||||
optTypeMetaToScalarType(weights_.options().dtype_opt()),
|
||||
weights_.options().layout_opt(),
|
||||
weights_.options().device_opt(),
|
||||
weights_.options().pinned_memory_opt());
|
||||
}
|
||||
else {
|
||||
output = at::zeros(
|
||||
{nbins},
|
||||
kLong,
|
||||
c10::nullopt /* layout */,
|
||||
kMPS,
|
||||
c10::nullopt /* pin_memory */);
|
||||
} else {
|
||||
output = at::zeros({nbins}, kLong, c10::nullopt /* layout */, kMPS, c10::nullopt /* pin_memory */);
|
||||
}
|
||||
|
||||
return bincount_mps_impl(self, weights_, output);
|
||||
|
@ -1,21 +1,19 @@
|
||||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/TensorCompare.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/TensorCompare.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
|
||||
MPSGraphTensor *minTensor = nil, *maxTensor = nil;
|
||||
};
|
||||
|
||||
void clamp_mps_graph(CachedGraph* cachedGraph, const Tensor& input_tensor)
|
||||
{
|
||||
void clamp_mps_graph(CachedGraph* cachedGraph, const Tensor& input_tensor) {
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
|
||||
cachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
|
||||
@ -36,35 +34,26 @@ void clamp_mps_graph(CachedGraph* cachedGraph, const Tensor& input_tensor)
|
||||
}
|
||||
}
|
||||
|
||||
void check_min_max_dims(const OptionalTensorRef clamp_opt,
|
||||
const Tensor& input_t,
|
||||
string op_name) {
|
||||
|
||||
void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor& input_t, string op_name) {
|
||||
if (!clamp_opt->is_same_size(input_t)) {
|
||||
|
||||
auto num_clamp_dims = clamp_opt->dim();
|
||||
auto num_input_dims = input_t.dim();
|
||||
|
||||
auto clamp_shape = clamp_opt->sizes();
|
||||
auto input_shape = input_t.sizes();
|
||||
|
||||
TORCH_CHECK(num_clamp_dims <= num_input_dims, op_name + ": clamp tensor number of dims must not be greater than that of input tensor")
|
||||
TORCH_CHECK(num_clamp_dims <= num_input_dims,
|
||||
op_name + ": clamp tensor number of dims must not be greater than that of input tensor")
|
||||
|
||||
for (int i = 0; i < num_clamp_dims; i++)
|
||||
// One of the indices is allowed to be 1; will be handled by broadcast
|
||||
TORCH_CHECK(clamp_shape[num_clamp_dims - 1 - i] == input_shape[num_input_dims - 1 - i] ||
|
||||
clamp_shape[num_clamp_dims-1-i] == 1 ||
|
||||
input_shape[num_input_dims-1-i] == 1,
|
||||
clamp_shape[num_clamp_dims - 1 - i] == 1 || input_shape[num_input_dims - 1 - i] == 1,
|
||||
op_name + ": clamp tensor trailing shape must match input tensor")
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
void fill_new_shape(int64_t num_input_dims,
|
||||
int64_t num_clamp_dims,
|
||||
int64_t *new_shape,
|
||||
IntArrayRef clamp_shape) {
|
||||
|
||||
void fill_new_shape(int64_t num_input_dims, int64_t num_clamp_dims, int64_t* new_shape, IntArrayRef clamp_shape) {
|
||||
// Extend the shape with ones to the left
|
||||
int clamp_idx = 0;
|
||||
for (int i = 0; i < num_input_dims; i++) {
|
||||
@ -81,8 +70,7 @@ void clamp_tensor_out_mps(const Tensor& input_t,
|
||||
const OptionalTensorRef min_opt,
|
||||
const OptionalTensorRef max_opt,
|
||||
const Tensor& output_t,
|
||||
string op_name)
|
||||
{
|
||||
string op_name) {
|
||||
const bool has_min = (min_opt.has_value() && min_opt->defined());
|
||||
const bool has_max = (max_opt.has_value() && max_opt->defined());
|
||||
|
||||
@ -129,13 +117,12 @@ void clamp_tensor_out_mps(const Tensor& input_t,
|
||||
@autoreleasepool {
|
||||
// the optional min/max refs could affect how we build the cached graph
|
||||
|
||||
auto tensor_key = has_min ? (has_max ? getTensorsStringKey({input_t, min_opt_tensor, max_opt_tensor})
|
||||
auto tensor_key = has_min
|
||||
? (has_max ? getTensorsStringKey({input_t, min_opt_tensor, max_opt_tensor})
|
||||
: getTensorsStringKey({input_t, min_opt_tensor}))
|
||||
: (has_max ? getTensorsStringKey({input_t, max_opt_tensor})
|
||||
: getTensorsStringKey({input_t}));
|
||||
: (has_max ? getTensorsStringKey({input_t, max_opt_tensor}) : getTensorsStringKey({input_t}));
|
||||
|
||||
string key = op_name + (has_min ? "_min" : "") + (has_max ? "_max" : "")
|
||||
+ "_tensor" + tensor_key;
|
||||
string key = op_name + (has_min ? "_min" : "") + (has_max ? "_max" : "") + "_tensor" + tensor_key;
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
|
||||
@ -173,9 +160,8 @@ void clamp_tensor_out_mps(const Tensor& input_t,
|
||||
feeds[maxPlaceholder.getMPSGraphTensor()] = maxPlaceholder.getMPSGraphTensorData();
|
||||
}
|
||||
|
||||
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
|
||||
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
@ -185,8 +171,7 @@ void clamp_scalar_out_mps(const Tensor& input_t,
const OptionalScalarRef min_opt,
const OptionalScalarRef max_opt,
const Tensor& output_t,
string op_name)
{
string op_name) {
using scalar_t = double;

const bool has_min = (min_opt.has_value());
@ -206,8 +191,8 @@ void clamp_scalar_out_mps(const Tensor& input_t,

@autoreleasepool {
// the optional min/max refs could affect how we build the cached graph
string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") + (has_max ? ("_max:" + to_string(max_scalar)) : "")
+ "_scalar:" + getTensorsStringKey({input_t});
string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") +
(has_max ? ("_max:" + to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));

@ -220,13 +205,13 @@ void clamp_scalar_out_mps(const Tensor& input_t,
newCachedGraph = new CachedGraph(mpsGraph);

if (has_min)
newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar
shape:(mps::getMPSShape(input_t))
dataType:(mps::getMPSScalarType(input_t.scalar_type())) ];
newCachedGraph->minTensor = [mpsGraph
constantWithScalar:min_scalar
shape:(mps::getMPSShape(input_t))dataType:(mps::getMPSScalarType(input_t.scalar_type()))];
if (has_max)
newCachedGraph->maxTensor = [mpsGraph constantWithScalar:max_scalar
shape:(mps::getMPSShape(input_t))
dataType:(mps::getMPSScalarType(input_t.scalar_type())) ];
newCachedGraph->maxTensor = [mpsGraph
constantWithScalar:max_scalar
shape:(mps::getMPSShape(input_t))dataType:(mps::getMPSScalarType(input_t.scalar_type()))];

clamp_mps_graph(newCachedGraph, input_t);
}
@ -241,9 +226,8 @@ void clamp_scalar_out_mps(const Tensor& input_t,
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
};
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
@ -253,51 +237,45 @@ void clamp_scalar_out_mps(const Tensor& input_t,

// APIs exposed to at::native scope
TORCH_IMPL_FUNC(clamp_Tensor_out_mps)
(const Tensor& input_t, const OptionalTensorRef min, const OptionalTensorRef max, const Tensor& output_t)
{
(const Tensor& input_t, const OptionalTensorRef min, const OptionalTensorRef max, const Tensor& output_t) {
mps::clamp_tensor_out_mps(input_t, min, max, output_t, __func__);
}

TORCH_IMPL_FUNC(clamp_out_mps)
(const Tensor& input_t, const OptionalScalarRef min, const OptionalScalarRef max, const Tensor& output_t)
{
(const Tensor& input_t, const OptionalScalarRef min, const OptionalScalarRef max, const Tensor& output_t) {
mps::clamp_scalar_out_mps(input_t, min, max, const_cast<Tensor&>(output_t), "clamp_out_mps");
}

TORCH_IMPL_FUNC(clamp_min_Tensor_out_mps)
(const Tensor& input_t, const Tensor& min, const Tensor& output_t)
{
(const Tensor& input_t, const Tensor& min, const Tensor& output_t) {
mps::clamp_tensor_out_mps(input_t, min, at::OptionalTensorRef(), output_t, __func__);
}

TORCH_IMPL_FUNC(clamp_min_out_mps)
(const Tensor& input_t, const Scalar& min, const Tensor& output_t)
{
(const Tensor& input_t, const Scalar& min, const Tensor& output_t) {
mps::clamp_scalar_out_mps(input_t, min, at::OptionalScalarRef(), output_t, __func__);
}

TORCH_IMPL_FUNC(clamp_max_Tensor_out_mps)
(const Tensor& input_t, const Tensor& max, const Tensor& output_t)
{
(const Tensor& input_t, const Tensor& max, const Tensor& output_t) {
mps::clamp_tensor_out_mps(input_t, at::OptionalTensorRef(), max, output_t, __func__);
}

TORCH_IMPL_FUNC(clamp_max_out_mps)
(const Tensor& input_t, const Scalar& max, const Tensor& output_t)
{
(const Tensor& input_t, const Scalar& max, const Tensor& output_t) {
mps::clamp_scalar_out_mps(input_t, at::OptionalScalarRef(), max, output_t, __func__);
}

Tensor& where_self_out_mps(const Tensor& condition,
const Tensor& self,
const Tensor& other,
Tensor& out) {
Tensor& where_self_out_mps(const Tensor& condition, const Tensor& self, const Tensor& other, Tensor& out) {
TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());

if (condition.scalar_type() == ScalarType::Byte) {
TORCH_WARN_ONCE("where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.");
TORCH_WARN_ONCE(
"where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.");
} else {
TORCH_CHECK(condition.scalar_type() == ScalarType::Bool, "where expected condition to be a boolean tensor, but got a tensor with dtype ", condition.scalar_type());
TORCH_CHECK(condition.scalar_type() == ScalarType::Bool,
"where expected condition to be a boolean tensor, but got a tensor with dtype ",
condition.scalar_type());
}
Tensor cond_bool = condition.scalar_type() == ScalarType::Byte ? condition.to(ScalarType::Bool) : condition;
@ -309,8 +287,7 @@ Tensor& where_self_out_mps(const Tensor& condition,
return out;

// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph
{
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* conditionTensor_ = nil;
MPSGraphTensor* selfTensor_ = nil;
@ -338,21 +315,20 @@ Tensor& where_self_out_mps(const Tensor& condition,
}

@autoreleasepool {

string key = "where_self_out_mps:" + getTensorsStringKey({cond_bool, self, other});

CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));

if (!cachedGraph) {
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {

CachedGraph* newCachedGraph = nil;

@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

MPSGraphTensor* conditionTensor = mpsGraphRankedPlaceHolder(mpsGraph, conditionDataType, getMPSShape(cond_bool));
MPSGraphTensor* conditionTensor =
mpsGraphRankedPlaceHolder(mpsGraph, conditionDataType, getMPSShape(cond_bool));
MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, selfDataType, getMPSShape(self));
MPSGraphTensor* otherTensor = mpsGraphRankedPlaceHolder(mpsGraph, otherDataType, getMPSShape(other));

@ -373,10 +349,10 @@ Tensor& where_self_out_mps(const Tensor& condition,

Placeholder conditionPlaceholder = Placeholder(
cachedGraph->conditionTensor_, cond_bool, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, conditionDataType);
Placeholder selfPlaceholder = Placeholder(
cachedGraph->selfTensor_, self, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, selfDataType);
Placeholder otherPlaceholder = Placeholder(
cachedGraph->otherTensor_, other, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, otherDataType);
Placeholder selfPlaceholder =
Placeholder(cachedGraph->selfTensor_, self, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, selfDataType);
Placeholder otherPlaceholder =
Placeholder(cachedGraph->otherTensor_, other, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, otherDataType);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@ -384,21 +360,16 @@ Tensor& where_self_out_mps(const Tensor& condition,
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);

}

return out;
}
Tensor where_mps(const Tensor& condition,
const Tensor& self,
const Tensor& other) {

Tensor where_mps(const Tensor& condition, const Tensor& self, const Tensor& other) {
auto max_dim = std::max(condition.dim(), std::max(self.dim(), other.dim()));

// How many leading dimensions do we broadcast across for each Tensor?
@ -410,7 +381,6 @@ Tensor where_mps(const Tensor& condition,

// Broadcasted output shape
for (int i = 0; i < max_dim; i++) {

// Use up the leading broadcast dimensions for each Tensor, then continue from the start of the "actual" shape
int64_t cond_idx = i < cond_num_implicit_ones ? 1 : (condition.size(i - cond_num_implicit_ones));
int64_t self_idx = i < self_num_implicit_ones ? 1 : (self.size(i - self_num_implicit_ones));
@ -418,21 +388,28 @@ Tensor where_mps(const Tensor& condition,

auto max_idx = std::max({cond_idx, self_idx, other_idx});

TORCH_CHECK(cond_idx == max_idx || cond_idx == 1 || (cond_idx == 0 && max_idx == 1), i, "'th index ", cond_idx, " of condition tensor does not match the other tensors")
TORCH_CHECK(self_idx == max_idx || self_idx == 1 || (self_idx == 0 && max_idx == 1), i, "'th index ", self_idx, " of x tensor does not match the other tensors")
TORCH_CHECK(other_idx == max_idx || other_idx == 1 || (other_idx == 0 && max_idx == 1), i, "'th index ", other_idx, " of x tensor does not match the other tensors")
TORCH_CHECK(cond_idx == max_idx || cond_idx == 1 || (cond_idx == 0 && max_idx == 1),
i,
"'th index ",
cond_idx,
" of condition tensor does not match the other tensors")
TORCH_CHECK(self_idx == max_idx || self_idx == 1 || (self_idx == 0 && max_idx == 1),
i,
"'th index ",
self_idx,
" of x tensor does not match the other tensors")
TORCH_CHECK(other_idx == max_idx || other_idx == 1 || (other_idx == 0 && max_idx == 1),
i,
"'th index ",
other_idx,
" of x tensor does not match the other tensors")

out_arr[i] = (cond_idx == 0 || self_idx == 0 || other_idx == 0) ? 0 : max_idx;
}

Tensor ret = empty_mps(IntArrayRef(out_arr),
self.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
self.suggest_memory_format());
Tensor ret = empty_mps(
IntArrayRef(out_arr), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, self.suggest_memory_format());
return where_self_out_mps(condition, self, other, ret);

}
Tensor& nan_to_num_out_mps(const Tensor& self,
@ -440,8 +417,11 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
c10::optional<double> pos_inf,
c10::optional<double> neg_inf,
Tensor& result) {
TORCH_CHECK(self.scalar_type() == result.scalar_type(), "nan_to_num: dtype of out: ",
result.scalar_type(), " should be same as input: ", self.scalar_type());
TORCH_CHECK(self.scalar_type() == result.scalar_type(),
"nan_to_num: dtype of out: ",
result.scalar_type(),
" should be same as input: ",
self.scalar_type());
if (result.numel() == 0) {
return result;
}
@ -478,12 +458,14 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
newCachedGraph->posInfReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[ @1 ]);
newCachedGraph->negInfReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[ @1 ]);

MPSGraphTensor* nanFreeTensor = [mpsGraph selectWithPredicateTensor: [mpsGraph isNaNWithTensor: newCachedGraph->selfTensor name:nil]
MPSGraphTensor* nanFreeTensor =
[mpsGraph selectWithPredicateTensor:[mpsGraph isNaNWithTensor:newCachedGraph->selfTensor name:nil]
truePredicateTensor:newCachedGraph->nanReplacementTensor
falsePredicateTensor:newCachedGraph->selfTensor
name:nil];
MPSGraphTensor* subZeroTensor = [mpsGraph lessThanWithPrimaryTensor:nanFreeTensor
secondaryTensor: [mpsGraph constantWithScalar: 0.0 dataType: self_dtype]
secondaryTensor:[mpsGraph constantWithScalar:0.0
dataType:self_dtype]
name:nil];
MPSGraphTensor* isInfTensor = [mpsGraph isInfiniteWithTensor:nanFreeTensor name:nil];
// workaround for Monterey; On Ventura the output of lessThan() is always Boolean
@ -500,7 +482,8 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
truePredicateTensor:newCachedGraph->negInfReplacementTensor
falsePredicateTensor:nanFreeTensor
name:nil];
newCachedGraph->outputTensor = [mpsGraph selectWithPredicateTensor: [mpsGraph isInfiniteWithTensor: negInfFreeTensor name:nil]
newCachedGraph->outputTensor =
[mpsGraph selectWithPredicateTensor:[mpsGraph isInfiniteWithTensor:negInfFreeTensor name:nil]
truePredicateTensor:newCachedGraph->posInfReplacementTensor
falsePredicateTensor:negInfFreeTensor
name:nil];
@ -511,12 +494,10 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
MPSScalar nanReplacementScalar, posInfReplacementScalar, negInfReplacementScalar;
AT_DISPATCH_FLOATING_TYPES_AND(kHalf, self.scalar_type(), "nan_to_num_mps", [&]() {
scalar_t nan_replacement = static_cast<scalar_t>(nan.value_or(0.));
scalar_t pos_inf_replacement = pos_inf.has_value() ?
static_cast<scalar_t>(pos_inf.value()) :
std::numeric_limits<scalar_t>::max();
scalar_t neg_inf_replacement = neg_inf.has_value() ?
static_cast<scalar_t>(neg_inf.value()) :
std::numeric_limits<scalar_t>::lowest();
scalar_t pos_inf_replacement =
pos_inf.has_value() ? static_cast<scalar_t>(pos_inf.value()) : std::numeric_limits<scalar_t>::max();
scalar_t neg_inf_replacement =
neg_inf.has_value() ? static_cast<scalar_t>(neg_inf.value()) : std::numeric_limits<scalar_t>::lowest();

nanReplacementScalar = getMPSScalar(nan_replacement, self.scalar_type());
posInfReplacementScalar = getMPSScalar(pos_inf_replacement, self.scalar_type());
@ -533,9 +514,8 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
cachedGraph->posInfReplacementTensor : getMPSGraphTensorFromScalar(stream, posInfReplacementScalar),
cachedGraph->negInfReplacementTensor : getMPSGraphTensorFromScalar(stream, negInfReplacementScalar),
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
return result;
@ -14,10 +14,7 @@
namespace at::native {

TORCH_IMPL_FUNC(triu_mps_out)
(const Tensor& self,
int64_t k,
const Tensor &output) {

(const Tensor& self, int64_t k, const Tensor& output) {
using namespace mps;

if (self.numel() == 0) {
@ -26,8 +23,7 @@ TORCH_IMPL_FUNC(triu_mps_out)
MPSStream* stream = getCurrentMPSStream();

// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph
{
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
@ -50,12 +46,10 @@ TORCH_IMPL_FUNC(triu_mps_out)
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* outputTensor = nil;

MPSGraphTensor* minusOneTensor = [mpsGraph constantWithScalar:-1
dataType:MPSDataTypeInt32];
MPSGraphTensor* minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32];

if (k > 0) {
MPSGraphTensor* diagMinusOneTensor = [mpsGraph constantWithScalar:(k-1)
dataType:MPSDataTypeInt32];
MPSGraphTensor* diagMinusOneTensor = [mpsGraph constantWithScalar:(k - 1) dataType:MPSDataTypeInt32];
MPSGraphTensor* complementTensor = [mpsGraph bandPartWithTensor:inputTensor
numLowerTensor:minusOneTensor
numUpperTensor:diagMinusOneTensor
@ -63,10 +57,8 @@ TORCH_IMPL_FUNC(triu_mps_out)
outputTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor
secondaryTensor:complementTensor
name:nil];
}
else {
MPSGraphTensor* minusDiagTensor = [mpsGraph constantWithScalar:(-k)
dataType:MPSDataTypeInt32];
} else {
MPSGraphTensor* minusDiagTensor = [mpsGraph constantWithScalar:(-k) dataType:MPSDataTypeInt32];
outputTensor = [mpsGraph bandPartWithTensor:inputTensor
numLowerTensor:minusDiagTensor
numUpperTensor:minusOneTensor
@ -84,23 +76,17 @@ TORCH_IMPL_FUNC(triu_mps_out)
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
@{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}

}
TORCH_IMPL_FUNC(tril_mps_out)
(const Tensor& self,
int64_t k,
const Tensor &output) {

(const Tensor& self, int64_t k, const Tensor& output) {
using namespace mps;

if (self.numel() == 0) {
@ -109,8 +95,7 @@ TORCH_IMPL_FUNC(tril_mps_out)
MPSStream* stream = getCurrentMPSStream();

// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph
{
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
@ -133,20 +118,16 @@ TORCH_IMPL_FUNC(tril_mps_out)
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* outputTensor = nil;

MPSGraphTensor* minusOneTensor = [mpsGraph constantWithScalar:-1
dataType:MPSDataTypeInt32];
MPSGraphTensor* minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32];

if (k >= 0) {
MPSGraphTensor* diagTensor = [mpsGraph constantWithScalar:k
dataType:MPSDataTypeInt32];
MPSGraphTensor* diagTensor = [mpsGraph constantWithScalar:k dataType:MPSDataTypeInt32];
outputTensor = [mpsGraph bandPartWithTensor:inputTensor
numLowerTensor:minusOneTensor
numUpperTensor:diagTensor
name:nil];
}
else {
MPSGraphTensor* negDiagMinusOneTensor = [mpsGraph constantWithScalar:(-k-1)
dataType:MPSDataTypeInt32];
} else {
MPSGraphTensor* negDiagMinusOneTensor = [mpsGraph constantWithScalar:(-k - 1) dataType:MPSDataTypeInt32];
MPSGraphTensor* complementTensor = [mpsGraph bandPartWithTensor:inputTensor
numLowerTensor:negDiagMinusOneTensor
numUpperTensor:minusOneTensor
@ -167,16 +148,13 @@ TORCH_IMPL_FUNC(tril_mps_out)
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
@{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}

}

} // namespace at::native
@ -1,7 +1,7 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {
namespace mps {
@ -9,13 +9,15 @@ namespace mps {
typedef MPSGraphTensor* (^UnaryOpBlock)(MPSGraph*, MPSGraphTensor*);
using is_noop_p = std::function<bool(const Tensor&)>;

bool is_empty_tensor(const Tensor& self) {
return self.numel() == 0;
}

void unary_op(const Tensor& self, const Tensor& output, std::string op_name, UnaryOpBlock unaryBlock, is_noop_p is_noop = is_empty_tensor)
{
void unary_op(const Tensor& self,
const Tensor& output,
std::string op_name,
UnaryOpBlock unaryBlock,
is_noop_p is_noop = is_empty_tensor) {
TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte),
"MPS support unary op with uint8 natively starting from macOS 13.0");
if (!output.is_same_size(self)) {
@ -55,18 +57,15 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una

Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, /*mpsShape=*/nullptr, gatherTensorData);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, /*mpsShape=*/nullptr, false);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
@{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
}
MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor)
|
||||
{
|
||||
MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
// Rounding is a no-op for integral types, and also a reasonable workaround
|
||||
// For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library`
|
||||
// See https://github.com/pytorch/pytorch/issues/84995
|
||||
@ -76,8 +75,7 @@ MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor)
|
||||
}
|
||||
|
||||
if (!is_macos_13_or_newer()) {
|
||||
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
|
||||
dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:zeroTensor
|
||||
name:nil];
|
||||
@ -86,40 +84,32 @@ MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor)
|
||||
falsePredicateTensor:[mpsGraph floorWithTensor:inputTensor name:nil]
|
||||
name:nil];
|
||||
} else {
|
||||
return [mpsGraph truncateWithTensor:inputTensor
|
||||
name:nil];
|
||||
return [mpsGraph truncateWithTensor:inputTensor name:nil];
|
||||
}
|
||||
};
|
||||
|
||||
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
|
||||
dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* addedTensor = [mpsGraph additionWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:oneTensor
|
||||
name:nil];
|
||||
return [mpsGraph logarithmWithTensor:addedTensor
|
||||
name:nil];
|
||||
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* addedTensor = [mpsGraph additionWithPrimaryTensor:inputTensor secondaryTensor:oneTensor name:nil];
|
||||
return [mpsGraph logarithmWithTensor:addedTensor name:nil];
|
||||
}
|
||||
|
||||
} // namespace mps
|
||||
|
||||
TORCH_IMPL_FUNC(trunc_out_mps)(const Tensor& self, const Tensor& output) {
|
||||
mps::unary_op(self, output, "trunc_out_mps",
|
||||
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor)
|
||||
{ return mps::trunc_tensor(mpsGraph, inputTensor); });
|
||||
mps::unary_op(self, output, "trunc_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
return mps::trunc_tensor(mpsGraph, inputTensor);
|
||||
});
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(signbit_out_mps)(const Tensor& self, const Tensor& output) {
|
||||
mps::unary_op(self, output, "signbit_out_mps",
|
||||
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
mps::unary_op(self, output, "signbit_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
MPSGraphTensor* output;
|
||||
// signbit is not implemented for int64 type.
|
||||
// workaround for `Function signbitOp_i64 was not found in the library`
|
||||
if ([inputTensor dataType] == MPSDataTypeInt64) {
|
||||
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType];
|
||||
output = [mpsGraph lessThanWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:zeroTensor
|
||||
name:nil];
|
||||
output = [mpsGraph lessThanWithPrimaryTensor:inputTensor secondaryTensor:zeroTensor name:nil];
|
||||
} else {
|
||||
output = [mpsGraph signbitWithTensor:inputTensor name:nil];
|
||||
}
|
||||
@ -128,8 +118,7 @@ TORCH_IMPL_FUNC(signbit_out_mps) (const Tensor& self, const Tensor& output) {
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(sign_out_mps)(const Tensor& self, const Tensor& output) {
|
||||
mps::unary_op(self, output, "sign_out_mps",
|
||||
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
mps::unary_op(self, output, "sign_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
// Sign op is not implemented in MPS as of MacOS13.0 beta, so simulate it using clamp
|
||||
if ([inputTensor dataType] == MPSDataTypeInt64) {
|
||||
return [mpsGraph clampWithTensor:inputTensor
|
||||
@ -143,12 +132,14 @@ TORCH_IMPL_FUNC(sign_out_mps) (const Tensor& self, const Tensor& output) {
|
||||
|
||||
#define CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC(func_out, func_stub) \
|
||||
TORCH_IMPL_FUNC(func_out)(const Tensor& self, const Tensor& output) { \
|
||||
mps::unary_op(self, output, #func_out, \
|
||||
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) \
|
||||
{ return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; }, \
|
||||
[](const Tensor& t) -> bool { \
|
||||
return t.numel() == 0 || isIntegralType(t.scalar_type(), true); \
|
||||
}); \
|
||||
mps::unary_op( \
|
||||
self, \
|
||||
output, \
|
||||
#func_out, \
|
||||
^MPSGraphTensor*(MPSGraph * mpsGraph, MPSGraphTensor * inputTensor) { \
|
||||
return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; \
|
||||
}, \
|
||||
[](const Tensor& t) -> bool { return t.numel() == 0 || isIntegralType(t.scalar_type(), true); }); \
|
||||
}
|
||||
CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC(ceil_out_mps, ceil)
|
||||
CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC(floor_out_mps, floor)
|
||||
@ -156,20 +147,19 @@ CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC(round_out_mps, round)
|
||||
|
||||
#define CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(func_out, func_stub) \
|
||||
TORCH_IMPL_FUNC(func_out)(const Tensor& self, const Tensor& output) { \
|
||||
mps::unary_op(self, output, #func_out, \
|
||||
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) \
|
||||
{ return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; }); \
|
||||
mps::unary_op(self, output, #func_out, ^MPSGraphTensor*(MPSGraph * mpsGraph, MPSGraphTensor * inputTensor) { \
|
||||
return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; \
|
||||
}); \
|
||||
}
|
||||
|
||||
#define CREATE_MPS_UNARY_TORCH_IMPL_FUNC(func_out, func_stub) \
|
||||
Tensor& func_out(const Tensor& self, Tensor& output) { \
|
||||
mps::unary_op(self, output, #func_out, \
|
||||
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) \
|
||||
{ return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; }); \
|
||||
mps::unary_op(self, output, #func_out, ^MPSGraphTensor*(MPSGraph * mpsGraph, MPSGraphTensor * inputTensor) { \
|
||||
return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; \
|
||||
}); \
|
||||
return output; \
|
||||
}
|
||||
|
||||
|
||||
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp_out_mps, exponent)
|
||||
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp2_out_mps, exponentBase2)
|
||||
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(reciprocal_out_mps, reciprocal)
|
||||
@ -195,81 +185,59 @@ CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(atanh_out_mps, atanh)
|
||||
|
||||
CREATE_MPS_UNARY_TORCH_IMPL_FUNC(abs_out_mps, absolute)
|
||||
|
||||
Tensor& logical_not_out_mps(const Tensor& self, Tensor& output)
|
||||
{
|
||||
Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
|
||||
auto bool_self = self.to(ScalarType::Bool);
|
||||
mps::unary_op(bool_self, output, "logical_not_out_mps", [](MPSGraph* mpsGraph, MPSGraphTensor* inputTensor){ return [mpsGraph notWithTensor:inputTensor name:nil];});
|
||||
mps::unary_op(bool_self, output, "logical_not_out_mps", [](MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
return [mpsGraph notWithTensor:inputTensor name:nil];
|
||||
});
|
||||
return output;
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(sigmoid_out_mps) (const Tensor& self, const Tensor& output)
|
||||
{
|
||||
TORCH_IMPL_FUNC(sigmoid_out_mps)(const Tensor& self, const Tensor& output) {
|
||||
TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support sigmoid op with int64 input");
|
||||
mps::unary_op(self, output, "sigmoid_out_mps",
|
||||
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
mps::unary_op(self, output, "sigmoid_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
return [mpsGraph sigmoidWithTensor:inputTensor name:nil];
|
||||
});
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(log1p_out_mps) (const Tensor& self, const Tensor& output)
|
||||
{
|
||||
TORCH_IMPL_FUNC(log1p_out_mps)(const Tensor& self, const Tensor& output) {
|
||||
TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support log1p op with int64 input");
|
||||
mps::unary_op(self, output, "log1p_out_mps",
|
||||
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
mps::unary_op(self, output, "log1p_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
return mps::log1p(mpsGraph, inputTensor);
|
||||
});
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(frac_out_mps)(const Tensor& self, const Tensor& output) {
|
||||
TORCH_CHECK(isFloatingType(self.scalar_type()), "frac_out_mps is only implemented for floating types");
|
||||
mps::unary_op(self, output, "frac_out_mps",
|
||||
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
auto zeroTensor = [mpsGraph constantWithScalar:0.0
|
||||
dataType:inputTensor.dataType];
|
||||
auto predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:zeroTensor
|
||||
name:nil];
|
||||
mps::unary_op(self, output, "frac_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
auto zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType];
|
||||
auto predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor secondaryTensor:zeroTensor name:nil];
|
||||
auto truncTensor = [mpsGraph selectWithPredicateTensor:predicateTensor
|
||||
truePredicateTensor:[mpsGraph ceilWithTensor:inputTensor name:nil]
|
||||
falsePredicateTensor:[mpsGraph floorWithTensor:inputTensor name:nil]
|
||||
name:nil];
|
||||
return [mpsGraph subtractionWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:truncTensor
|
||||
name: nil];
|
||||
return [mpsGraph subtractionWithPrimaryTensor:inputTensor secondaryTensor:truncTensor name:nil];
|
||||
});
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(expm1_out_mps)(const Tensor& self, const Tensor& output) {
|
||||
mps::unary_op(self, output, "expm1_out_mps",
|
||||
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
|
||||
shape:@[@1]
|
||||
dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* ePowTensor = [mpsGraph exponentWithTensor:inputTensor
|
||||
name:nil];
|
||||
return [mpsGraph subtractionWithPrimaryTensor:ePowTensor
|
||||
secondaryTensor:oneTensor
|
||||
name: nil];
|
||||
mps::unary_op(self, output, "expm1_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* ePowTensor = [mpsGraph exponentWithTensor:inputTensor name:nil];
|
||||
return [mpsGraph subtractionWithPrimaryTensor:ePowTensor secondaryTensor:oneTensor name:nil];
|
||||
});
|
||||
}
|
||||
|
||||
void logit_mps_impl(const Tensor& self, c10::optional<double> eps, Tensor& output, const std::string op_name) {
|
||||
std::string key = op_name + ":[" + (eps.has_value() ? std::to_string(eps.value()) : "NULL") + "]";
|
||||
|
||||
mps::unary_op(self, output, key,
|
||||
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
|
||||
shape:@[@1]
|
||||
dataType:inputTensor.dataType];
|
||||
mps::unary_op(self, output, key, ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* logitInputTensor;
|
||||
|
||||
if (eps.has_value()) {
|
||||
MPSGraphTensor *lowTensor = [mpsGraph constantWithScalar:eps.value()
|
||||
shape:@[@1]
|
||||
dataType:inputTensor.dataType];
|
||||
MPSGraphTensor *highTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor
|
||||
secondaryTensor: lowTensor
|
||||
name: nil];
|
||||
MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps.value() shape:@[ @1 ] dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* highTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor secondaryTensor:lowTensor name:nil];
|
||||
logitInputTensor = [mpsGraph clampWithTensor:inputTensor
|
||||
minValueTensor:lowTensor
|
||||
maxValueTensor:highTensor
|
||||
@ -284,36 +252,24 @@ void logit_mps_impl(const Tensor& self, c10::optional<double> eps, Tensor& outpu
|
||||
MPSGraphTensor* outputTensor = [mpsGraph divisionWithPrimaryTensor:logitInputTensor
|
||||
secondaryTensor:oneMinusInputTensor
|
||||
name:nil];
|
||||
return [mpsGraph logarithmWithTensor:outputTensor
|
||||
name:nil];
|
||||
return [mpsGraph logarithmWithTensor:outputTensor name:nil];
|
||||
});
|
||||
}
|
||||
|
||||
Tensor& logit_out_mps(const Tensor& self,
|
||||
c10::optional<double> eps,
|
||||
Tensor& result) {
|
||||
Tensor& logit_out_mps(const Tensor& self, c10::optional<double> eps, Tensor& result) {
|
||||
logit_mps_impl(self, eps, result, "logit_out_mps");
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor logit_mps(const Tensor& self, c10::optional<double> eps) {
|
||||
Tensor result = at::native::empty_mps(
|
||||
self.sizes(),
|
||||
ScalarType::Float,
|
||||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
c10::nullopt);
|
||||
Tensor result =
|
||||
at::native::empty_mps(self.sizes(), ScalarType::Float, c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
|
||||
logit_mps_impl(self, eps, result, "logit_mps");
|
||||
return result;
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(logit_backward_out_mps) (
|
||||
const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
c10::optional<double> eps,
|
||||
const Tensor& grad_input)
|
||||
{
|
||||
TORCH_IMPL_FUNC(logit_backward_out_mps)
|
||||
(const Tensor& grad_output, const Tensor& input, c10::optional<double> eps, const Tensor& grad_input) {
|
||||
using namespace mps;
|
||||
|
||||
// Empty output
|
||||
@ -322,8 +278,7 @@ TORCH_IMPL_FUNC(logit_backward_out_mps) (
|
||||
|
||||
double eps_ = eps ? eps.value() : -1.0;
|
||||
|
||||
struct CachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* gradOutputTensor_ = nil;
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
@ -335,13 +290,12 @@ TORCH_IMPL_FUNC(logit_backward_out_mps) (
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
std::string key = "logit_backward_out_mps:" + getTensorsStringKey({grad_output, input}) + ":" +
|
||||
"[" + (eps.has_value() ? std::to_string(eps.value()) : "-1" ) + "]";
|
||||
std::string key = "logit_backward_out_mps:" + getTensorsStringKey({grad_output, input}) + ":" + "[" +
|
||||
(eps.has_value() ? std::to_string(eps.value()) : "-1") + "]";
|
||||
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
|
||||
CachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
@ -351,15 +305,9 @@ TORCH_IMPL_FUNC(logit_backward_out_mps) (
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
MPSGraphTensor* outputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_input);
|
||||
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
|
||||
shape:@[@1]
|
||||
dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
|
||||
shape:@[@1]
|
||||
dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps_
|
||||
shape:@[@1]
|
||||
dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps_ shape:@[ @1 ] dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* inputLessThanLowPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:lowTensor
|
||||
name:nil];
|
||||
@ -378,9 +326,7 @@ TORCH_IMPL_FUNC(logit_backward_out_mps) (
|
||||
outputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:oneMinusInputTensor
|
||||
name:nil];
|
||||
outputTensor = [mpsGraph divisionWithPrimaryTensor:gradOutputTensor
|
||||
secondaryTensor:outputTensor
|
||||
name:nil];
|
||||
outputTensor = [mpsGraph divisionWithPrimaryTensor:gradOutputTensor secondaryTensor:outputTensor name:nil];
|
||||
outputTensor = [mpsGraph selectWithPredicateTensor:outOfIntervalTensor
|
||||
truePredicateTensor:zeroTensor
|
||||
falsePredicateTensor:outputTensor
|
||||
@ -403,25 +349,25 @@ TORCH_IMPL_FUNC(logit_backward_out_mps) (
|
||||
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
|
||||
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()};
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
TORCH_IMPL_FUNC(cumsum_out_mps)
(const Tensor& self,
int64_t dim,
c10::optional<ScalarType> dtype,
const Tensor& result) {

(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, const Tensor& result) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
auto nDims = self.dim();
auto wrapped_dim = maybe_wrap_dim(dim, nDims);
TORCH_CHECK(wrapped_dim >=0 && wrapped_dim < std::max(1LL, self.ndimension()), "Expected wrapped dim to be between 0 and ", self.ndimension(), " but got ", wrapped_dim , "(original dim is ", dim, ")");
TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()),
"Expected wrapped dim to be between 0 and ",
self.ndimension(),
" but got ",
wrapped_dim,
"(original dim is ",
dim,
")");
if (!is_macos_13_or_newer()) {
TORCH_WARN_ONCE("torch.cumsum supported by MPS on MacOS 13+, please upgrade");
auto cpu_result = self.to(at::Device(kCPU)).cumsum(dim, dtype);
@ -430,24 +376,22 @@ TORCH_IMPL_FUNC(cumsum_out_mps)
}
auto input = dtype.has_value() ? self.to(dtype.value()) : self;

// issue #103810551: cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to int32
// fixed in macOS 13.3
bool castInputData = (isIntegralType(input.scalar_type()) &&
input.scalar_type() != ScalarType::Int &&
// issue #103810551: cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to
// int32 fixed in macOS 13.3
bool castInputData = (isIntegralType(input.scalar_type()) && input.scalar_type() != ScalarType::Int &&
input.scalar_type() != ScalarType::Long);

TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long,
"MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3");

mps::unary_op(input, result, "cumsum_out_mp" + std::to_string(dim),
mps::unary_op(input,
result,
"cumsum_out_mp" + std::to_string(dim),
^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {

if (castInputData) {
inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int);
}
auto rc = [mpsGraph cumulativeSumWithTensor: inputTensor
axis: dim
name: nil];
auto rc = [mpsGraph cumulativeSumWithTensor:inputTensor axis:dim name:nil];
if ((mps::getMPSDataType(result) != [rc dataType]) || castInputData) {
return mps::castMPSTensor(mpsGraph, rc, result.scalar_type());
}
@ -1,14 +1,13 @@
|
||||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/mps/MPSGraphVenturaOps.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/mps/MPSGraphVenturaOps.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
struct UniqueCachedGraph : public MPSCachedGraph
|
||||
{
|
||||
struct UniqueCachedGraph : public MPSCachedGraph {
|
||||
UniqueCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
MPSGraphTensor* inputTensor_ = nil;
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
@ -17,17 +16,24 @@ struct UniqueCachedGraph : public MPSCachedGraph
|
||||
MPSGraphTensor* lengthTensor_ = nil;
|
||||
};
|
||||
|
||||
static std::string getUniqueKey(const ScalarType& dtype, const IntArrayRef& base_shape,
|
||||
const bool return_inverse, const bool return_counts,
|
||||
const bool consecutive, c10::optional<int64_t> dimOpt)
|
||||
{
|
||||
return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) +
|
||||
"]:[" + (dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) +
|
||||
"]:[" + to_string(return_counts) + "]:[" + to_string(consecutive) + "]";
|
||||
static std::string getUniqueKey(const ScalarType& dtype,
|
||||
const IntArrayRef& base_shape,
|
||||
const bool return_inverse,
|
||||
const bool return_counts,
|
||||
const bool consecutive,
|
||||
c10::optional<int64_t> dimOpt) {
|
||||
return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) + "]:[" +
|
||||
(dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) + "]:[" +
|
||||
to_string(return_counts) + "]:[" + to_string(consecutive) + "]";
|
||||
}
|
||||
|
||||
// dim arg not supported when non consecutive, ie sorted
|
||||
std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self, UniqueCachedGraph *uniqueGraph, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional<int64_t> dimOpt) {
|
||||
std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self,
|
||||
UniqueCachedGraph* uniqueGraph,
|
||||
const bool return_inverse,
|
||||
const bool return_counts,
|
||||
const bool consecutive,
|
||||
c10::optional<int64_t> dimOpt) {
|
||||
int64_t dim = dimOpt.has_value() ? maybe_wrap_dim(dimOpt.value(), self.dim()) : 0;
|
||||
|
||||
MPSGraph* graph = uniqueGraph->graph();
|
||||
@ -44,34 +50,25 @@ std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self, UniqueCached
|
||||
if (length <= 1) {
|
||||
// Trivial case, only 1 element everything is unique
|
||||
resultTensor = inputTensor;
|
||||
lengthTensor = [graph constantWithScalar:0.0f
|
||||
dataType:MPSDataTypeInt32];
|
||||
lengthTensor = [graph constantWithScalar:0.0f dataType:MPSDataTypeInt32];
|
||||
if (return_inverse) {
|
||||
inverseIndicesTensor = [graph constantWithScalar:0.0f
|
||||
dataType:MPSDataTypeInt32];
|
||||
inverseIndicesTensor = [graph constantWithScalar:0.0f dataType:MPSDataTypeInt32];
|
||||
}
|
||||
if (return_counts) {
|
||||
countTensor = [graph constantWithScalar:1.0f
|
||||
dataType:MPSDataTypeInt32];
|
||||
countTensor = [graph constantWithScalar:1.0f dataType:MPSDataTypeInt32];
|
||||
}
|
||||
return {resultTensor, inverseIndicesTensor, countTensor, lengthTensor};
|
||||
}
|
||||
|
||||
// #issue 104398441 sortWithTensor only supports following types, cast if necessary
|
||||
if (dataType != MPSDataTypeInt32 &&
|
||||
dataType != MPSDataTypeFloat32 &&
|
||||
dataType != MPSDataTypeFloat16) {
|
||||
if (dataType != MPSDataTypeInt32 && dataType != MPSDataTypeFloat32 && dataType != MPSDataTypeFloat16) {
|
||||
dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
|
||||
inputTensor = [graph castTensor:inputTensor
|
||||
toType:dataType
|
||||
name:@"castInputTensor"];
|
||||
inputTensor = [graph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
|
||||
}
|
||||
|
||||
bool needsFlatten = !(dimOpt.has_value() || [shape count] == 1);
|
||||
if (needsFlatten) {
|
||||
inputTensor = [graph reshapeTensor:inputTensor
|
||||
withShape:@[@-1]
|
||||
name:nil];
|
||||
inputTensor = [graph reshapeTensor:inputTensor withShape:@[ @-1 ] name:nil];
|
||||
length = 1;
|
||||
for (const auto i : c10::irange([shape count])) {
|
||||
if (c10::mul_overflows(length, [shape[i] unsignedIntValue], &length)) {
|
||||
@ -86,27 +83,15 @@ std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self, UniqueCached
|
||||
if (consecutive) {
|
||||
sortedInput = inputTensor;
|
||||
} else {
|
||||
sortedInput = [graph sortWithTensor:inputTensor
|
||||
axis:0
|
||||
name:nil];
|
||||
sortedInput = [graph sortWithTensor:inputTensor axis:0 name:nil];
|
||||
}
|
||||
|
||||
MPSGraphTensor *frontNMinusOne = [graph sliceTensor:sortedInput
|
||||
dimension:dim
|
||||
start:0
|
||||
length:length-1
|
||||
name:nil];
|
||||
MPSGraphTensor *backNMinusOne = [graph sliceTensor:sortedInput
|
||||
dimension:dim
|
||||
start:1
|
||||
length:length-1
|
||||
name:nil];
|
||||
MPSGraphTensor* frontNMinusOne = [graph sliceTensor:sortedInput dimension:dim start:0 length:length - 1 name:nil];
|
||||
MPSGraphTensor* backNMinusOne = [graph sliceTensor:sortedInput dimension:dim start:1 length:length - 1 name:nil];
|
||||
MPSGraphTensor* notEqualToPreviousElement = [graph notEqualWithPrimaryTensor:backNMinusOne
|
||||
secondaryTensor:frontNMinusOne
|
||||
name:nil];
|
||||
MPSGraphTensor *mask = [graph castTensor:notEqualToPreviousElement
|
||||
toType:MPSDataTypeInt32
|
||||
name:@"castMaskTensor"];
|
||||
MPSGraphTensor* mask = [graph castTensor:notEqualToPreviousElement toType:MPSDataTypeInt32 name:@"castMaskTensor"];
|
||||
|
||||
// If comparing tensors, not scalars, check if entire tensor matches previos element using reductionOr over tensor
|
||||
if (dimOpt.has_value() && [shape count] != 1) {
|
||||
@ -116,40 +101,23 @@ std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self, UniqueCached
|
||||
[axes addObject:[NSNumber numberWithUnsignedInteger:axis]];
|
||||
}
|
||||
}
|
||||
mask = [graph reductionOrWithTensor:mask
|
||||
axes:axes
|
||||
name:nil];
|
||||
mask = [graph squeezeTensor:mask
|
||||
axes:axes
|
||||
name:nil];
|
||||
mask = [graph reductionOrWithTensor:mask axes:axes name:nil];
|
||||
mask = [graph squeezeTensor:mask axes:axes name:nil];
|
||||
[axes release];
|
||||
}
|
||||
|
||||
MPSGraphTensor *scannedIndices = [graph cumulativeSumWithTensor:mask
|
||||
axis:0
|
||||
name:nil];
|
||||
lengthTensor = [graph sliceTensor:scannedIndices
|
||||
dimension:0
|
||||
start:length-2
|
||||
length:1
|
||||
name:nil];
|
||||
MPSGraphTensor* scannedIndices = [graph cumulativeSumWithTensor:mask axis:0 name:nil];
|
||||
lengthTensor = [graph sliceTensor:scannedIndices dimension:0 start:length - 2 length:1 name:nil];
|
||||
|
||||
MPSGraphTensor *minusOneTensor = [graph constantWithScalar:-1.0f
|
||||
dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* minusOneTensor = [graph constantWithScalar:-1.0f dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* maskedIndices = [graph selectWithPredicateTensor:mask
|
||||
truePredicateTensor:scannedIndices
|
||||
falsePredicateTensor:minusOneTensor
|
||||
name:nil];
|
||||
|
||||
MPSGraphTensor *zeroTensor = [graph constantWithScalar:0.0f
|
||||
shape:@[@1]
|
||||
dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor *maskedIndicesWithHead = [graph concatTensors:@[zeroTensor, maskedIndices]
|
||||
dimension:0
|
||||
name:nil];
|
||||
MPSGraphTensor *scannedIndicesWithHead = [graph concatTensors:@[zeroTensor, scannedIndices]
|
||||
dimension:0
|
||||
name:nil];
|
||||
MPSGraphTensor* zeroTensor = [graph constantWithScalar:0.0f shape:@[ @1 ] dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* maskedIndicesWithHead = [graph concatTensors:@[ zeroTensor, maskedIndices ] dimension:0 name:nil];
|
||||
MPSGraphTensor* scannedIndicesWithHead = [graph concatTensors:@[ zeroTensor, scannedIndices ] dimension:0 name:nil];
|
||||
|
||||
resultTensor = [graph scatterWithUpdatesTensor:sortedInput
|
||||
indicesTensor:maskedIndicesWithHead
|
||||
@ -159,9 +127,7 @@ std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self, UniqueCached
|
||||
name:nil];
|
||||
// Cast back if necessary
|
||||
if ([uniqueGraph->inputTensor_ dataType] != dataType) {
|
||||
resultTensor = [graph castTensor:resultTensor
|
||||
toType:[uniqueGraph->inputTensor_ dataType]
|
||||
name:@"castResultTensor"];
|
||||
resultTensor = [graph castTensor:resultTensor toType:[uniqueGraph->inputTensor_ dataType] name:@"castResultTensor"];
|
||||
}
|
||||
|
||||
// Compute optional returned tensors if requested
|
||||
@ -172,9 +138,7 @@ std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self, UniqueCached
|
||||
withShape:@[ [NSNumber numberWithUnsignedInteger:length] ]
|
||||
name:nil];
|
||||
else
|
||||
argSortedInput = [graph argSortWithTensor:inputTensor
|
||||
axis:0
|
||||
name:nil];
|
||||
argSortedInput = [graph argSortWithTensor:inputTensor axis:0 name:nil];
|
||||
inverseIndicesTensor = [graph scatterWithUpdatesTensor:scannedIndicesWithHead
|
||||
indicesTensor:argSortedInput
|
||||
shape:@[ [NSNumber numberWithUnsignedInteger:length] ]
|
||||
@ -182,9 +146,7 @@ std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self, UniqueCached
|
||||
mode:MPSGraphScatterModeAdd
|
||||
name:nil];
|
||||
if (needsFlatten)
|
||||
inverseIndicesTensor = [graph reshapeTensor:inverseIndicesTensor
|
||||
withShape:shape
|
||||
name:nil];
|
||||
inverseIndicesTensor = [graph reshapeTensor:inverseIndicesTensor withShape:shape name:nil];
|
||||
}
|
||||
|
||||
if (return_counts) {
|
||||
@ -202,7 +164,11 @@ std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self, UniqueCached
|
||||
return {resultTensor, inverseIndicesTensor, countTensor, lengthTensor};
|
||||
}
|
||||
|
||||
static UniqueCachedGraph* getUniqueGraph(const Tensor& self, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional<int64_t> dim) {
|
||||
static UniqueCachedGraph* getUniqueGraph(const Tensor& self,
|
||||
const bool return_inverse,
|
||||
const bool return_counts,
|
||||
const bool consecutive,
|
||||
c10::optional<int64_t> dim) {
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
@ -210,7 +176,6 @@ static UniqueCachedGraph* getUniqueGraph(const Tensor& self, const bool return_i
|
||||
UniqueCachedGraph* cachedGraph = static_cast<UniqueCachedGraph*>(cache_->LookUp(key));
|
||||
if (!cachedGraph) {
|
||||
MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
|
||||
|
||||
UniqueCachedGraph* newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
@ -238,9 +203,14 @@ static UniqueCachedGraph* getUniqueGraph(const Tensor& self, const bool return_i
|
||||
}
|
||||
}
|
||||
|
||||
void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor& output,
Tensor& inverse_indices, Tensor& counts, Tensor& length,
bool return_inverse, bool return_counts){
void runUniqueGraph(UniqueCachedGraph* uniqueGraph,
const Tensor& input,
Tensor& output,
Tensor& inverse_indices,
Tensor& counts,
Tensor& length,
bool return_inverse,
bool return_counts) {
Placeholder inputPlaceholder = Placeholder(uniqueGraph->inputTensor_, input);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
@ -249,10 +219,8 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor&
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = [NSMutableDictionary dictionary];
Placeholder outputPlaceholder = Placeholder(uniqueGraph->outputTensor_, output);
Placeholder lengthPlaceholder = Placeholder(uniqueGraph->lengthTensor_, length);
[results setObject:outputPlaceholder.getMPSGraphTensorData()
forKey:outputPlaceholder.getMPSGraphTensor()];
[results setObject:lengthPlaceholder.getMPSGraphTensorData()
forKey:lengthPlaceholder.getMPSGraphTensor()];
[results setObject:outputPlaceholder.getMPSGraphTensorData() forKey:outputPlaceholder.getMPSGraphTensor()];
[results setObject:lengthPlaceholder.getMPSGraphTensorData() forKey:lengthPlaceholder.getMPSGraphTensor()];
if (return_inverse) {
Placeholder inverseIndicesPlaceholder = Placeholder(uniqueGraph->inverseIndicesTensor_, inverse_indices);
[results setObject:inverseIndicesPlaceholder.getMPSGraphTensorData()
@ -260,8 +228,7 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor&
}
if (return_counts) {
Placeholder countsPlaceholder = Placeholder(uniqueGraph->countsTensor_, counts);
[results setObject:countsPlaceholder.getMPSGraphTensorData()
forKey:countsPlaceholder.getMPSGraphTensor()];
[results setObject:countsPlaceholder.getMPSGraphTensorData() forKey:countsPlaceholder.getMPSGraphTensor()];
}

// Run the graph
@ -271,9 +238,11 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor&

} // namespace mps

std::tuple<Tensor, Tensor, Tensor>
_unique_impl_mps(const Tensor& self, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional<int64_t> dimOpt) {

std::tuple<Tensor, Tensor, Tensor> _unique_impl_mps(const Tensor& self,
const bool return_inverse,
const bool return_counts,
const bool consecutive,
c10::optional<int64_t> dimOpt) {
const Tensor& input = self.contiguous();

// get flat output size
@ -316,17 +285,14 @@ _unique_impl_mps(const Tensor& self, const bool return_inverse, const bool retur
return std::make_tuple(output, inverse_indices, counts);
}


static
std::tuple<Tensor, Tensor, Tensor> castToMPS(std::tuple<Tensor, Tensor, Tensor> out) {
return std::make_tuple(
get<0>(out).to("mps"),
get<1>(out).to("mps"),
get<2>(out).to("mps"));
static std::tuple<Tensor, Tensor, Tensor> castToMPS(std::tuple<Tensor, Tensor, Tensor> out) {
return std::make_tuple(get<0>(out).to("mps"), get<1>(out).to("mps"), get<2>(out).to("mps"));
}

std::tuple<Tensor, Tensor, Tensor>
unique_consecutive_mps(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional<int64_t> dim) {
std::tuple<Tensor, Tensor, Tensor> unique_consecutive_mps(const Tensor& self,
const bool return_inverse,
const bool return_counts,
c10::optional<int64_t> dim) {
if (!is_macos_13_or_newer()) {
TORCH_WARN_ONCE("MPS: unique_consecutive op is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performace implications.");
@ -336,8 +302,10 @@ unique_consecutive_mps(const Tensor& self, const bool return_inverse, const bool
return _unique_impl_mps(self, return_inverse, return_counts, true, dim);
}

std::tuple<Tensor, Tensor, Tensor>
unique_dim_consecutive_mps(const Tensor& self, int64_t dim, const bool return_inverse, const bool return_counts) {
std::tuple<Tensor, Tensor, Tensor> unique_dim_consecutive_mps(const Tensor& self,
int64_t dim,
const bool return_inverse,
const bool return_counts) {
if (!is_macos_13_or_newer()) {
TORCH_WARN_ONCE("MPS: unique_dim_consecutive op is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performace implications.");
@ -347,8 +315,10 @@ unique_dim_consecutive_mps(const Tensor& self, int64_t dim, const bool return_in
return _unique_impl_mps(self, return_inverse, return_counts, true, c10::make_optional((int64_t)dim));
}

std::tuple<Tensor, Tensor, Tensor>
_unique2_mps(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
std::tuple<Tensor, Tensor, Tensor> _unique2_mps(const Tensor& self,
const bool sorted,
const bool return_inverse,
const bool return_counts) {
if (!is_macos_13_or_newer()) {
TORCH_WARN_ONCE("MPS: _unique2 op is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performace implications.");

@ -1,8 +1,8 @@
// Copyright © 2023 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/UpSample.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {
namespace mps {
@ -34,9 +34,8 @@ void upsample_out_template(const Tensor& input,
|
||||
bool centerResults = false;
|
||||
MPSGraphResizeMode resizeMode = MPSGraphResizeNearest;
|
||||
MPSGraphResizeNearestRoundingMode nearestRoundingMode = MPSGraphResizeNearestRoundingModeFloor;
|
||||
MPSGraphTensorNamedDataLayout dataLayout = input_dim.size() > 3 ?
|
||||
MPSGraphTensorNamedDataLayoutNCHW :
|
||||
MPSGraphTensorNamedDataLayoutCHW;
|
||||
MPSGraphTensorNamedDataLayout dataLayout =
|
||||
input_dim.size() > 3 ? MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutCHW;
|
||||
if (resize_mode_str == "nearest") {
|
||||
resizeMode = MPSGraphResizeNearest;
|
||||
} else if (resize_mode_str == "bilinear") {
|
||||
@ -102,7 +101,8 @@ void upsample_out_template(const Tensor& input,
|
||||
inputSizeVec[2] = @(input_size[2]);
|
||||
inputSizeVec[3] = @(input_dim.size() > 3 ? input_size[3] : 1);
|
||||
inputSizeTensor = [mpsGraph constantWithScalar:0
|
||||
shape: [NSArray arrayWithObjects:inputSizeVec.data() count:input_dim.size()]
|
||||
shape:[NSArray arrayWithObjects:inputSizeVec.data()
|
||||
count:input_dim.size()]
|
||||
dataType:getMPSDataType(input)];
|
||||
}
|
||||
if (is_macOS_13_0_or_newer) {
|
||||
@ -204,15 +204,15 @@ void upsample_out_template(const Tensor& input,
|
||||
MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray:sizeNDArray] autorelease];
|
||||
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, out.has_storage() ? out : output, nil, false);
|
||||
Placeholder outputPlaceholder =
|
||||
Placeholder(cachedGraph->outputTensor, out.has_storage() ? out : output, nil, false);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
|
||||
cachedGraph->outputSizeTensor : sizeTensorData,
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
||||
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
|
||||
if (out.has_storage()) {
|
||||
@ -223,8 +223,7 @@ void upsample_out_template(const Tensor& input,
|
||||
|
||||
} // namespace mps
|
||||
|
||||
static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10::optional<double> scale)
|
||||
{
|
||||
static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10::optional<double> scale) {
|
||||
static const bool is_macOS_13_0_or_newer = is_macos_13_or_newer();
|
||||
if (!is_macOS_13_0_or_newer) {
|
||||
// passing scale factors to MPS's resize APIs is not supported on macOS < 13
|
||||
@ -236,7 +235,9 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
|
||||
// is incompatible with PyTorch that uses floor(). So we fallback to CPU on Monterey.
|
||||
// The nearest mode should work fine on Ventura.
|
||||
} else if (resize_mode_str == "nearest" || resize_mode_str == "nearest-exact") {
|
||||
TORCH_WARN_ONCE("MPS: '", resize_mode_str, "' mode upsampling is supported natively starting from macOS 13.0. ",
|
||||
TORCH_WARN_ONCE("MPS: '",
|
||||
resize_mode_str,
|
||||
"' mode upsampling is supported natively starting from macOS 13.0. ",
|
||||
"Falling back on CPU. This may have performance implications.");
|
||||
return false;
|
||||
}
|
||||
@ -244,12 +245,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
|
||||
return true;
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) (
|
||||
const Tensor& input,
|
||||
IntArrayRef output_size,
|
||||
c10::optional<double> scale,
|
||||
const Tensor& output)
|
||||
{
|
||||
TORCH_IMPL_FUNC(upsample_nearest1d_out_mps)
|
||||
(const Tensor& input, IntArrayRef output_size, c10::optional<double> scale, const Tensor& output) {
|
||||
if (check_mps_compatibility("nearest", scale)) {
|
||||
mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest");
|
||||
} else {
|
||||
@ -258,27 +255,23 @@ TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) (
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_mps) (
|
||||
const Tensor& grad_output,
|
||||
TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_mps)
|
||||
(const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
c10::optional<double> scale,
|
||||
const Tensor& grad_input)
|
||||
{
|
||||
const Tensor& grad_input) {
|
||||
if (check_mps_compatibility("nearest", scale)) {
|
||||
mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(grad_input) = at::upsample_nearest1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps");
|
||||
const_cast<Tensor&>(grad_input) =
|
||||
at::upsample_nearest1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps) (
|
||||
const Tensor& input,
|
||||
IntArrayRef output_size,
|
||||
c10::optional<double> scale,
|
||||
const Tensor& output)
|
||||
{
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps)
|
||||
(const Tensor& input, IntArrayRef output_size, c10::optional<double> scale, const Tensor& output) {
|
||||
if (check_mps_compatibility("nearest-exact", scale)) {
|
||||
mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest-exact");
|
||||
} else {
|
||||
@ -287,113 +280,123 @@ TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps) (
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact1d_backward_out_mps) (
|
||||
const Tensor& grad_output,
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact1d_backward_out_mps)
|
||||
(const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
c10::optional<double> scale,
|
||||
const Tensor& grad_input)
|
||||
{
|
||||
const Tensor& grad_input) {
|
||||
if (check_mps_compatibility("nearest-exact", scale)) {
|
||||
mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest-exact");
|
||||
mps::upsample_out_template(
|
||||
grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest-exact");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(grad_input) = at::_upsample_nearest_exact1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps");
|
||||
const_cast<Tensor&>(grad_input) =
|
||||
at::_upsample_nearest_exact1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(upsample_nearest2d_out_mps) (
|
||||
const Tensor& input,
|
||||
TORCH_IMPL_FUNC(upsample_nearest2d_out_mps)
|
||||
(const Tensor& input,
|
||||
IntArrayRef output_size,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& output)
|
||||
{
|
||||
const Tensor& output) {
|
||||
if (check_mps_compatibility("nearest", scales_w)) {
|
||||
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(output) = at::upsample_nearest2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps");
|
||||
const_cast<Tensor&>(output) =
|
||||
at::upsample_nearest2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_mps) (
|
||||
const Tensor& grad_output,
|
||||
TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_mps)
|
||||
(const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& grad_input)
|
||||
{
|
||||
const Tensor& grad_input) {
|
||||
if (check_mps_compatibility("nearest", scales_w)) {
|
||||
mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(grad_input) = at::upsample_nearest2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w).clone().to("mps");
|
||||
const_cast<Tensor&>(grad_input) =
|
||||
at::upsample_nearest2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w)
|
||||
.clone()
|
||||
.to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps) (
|
||||
const Tensor& input,
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps)
|
||||
(const Tensor& input,
|
||||
IntArrayRef output_size,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& output)
|
||||
{
|
||||
const Tensor& output) {
|
||||
if (check_mps_compatibility("nearest-exact", scales_w)) {
|
||||
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest-exact");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(output) = at::_upsample_nearest_exact2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps");
|
||||
const_cast<Tensor&>(output) =
|
||||
at::_upsample_nearest_exact2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_mps) (
|
||||
const Tensor& grad_output,
|
||||
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_mps)
|
||||
(const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& grad_input)
|
||||
{
|
||||
const Tensor& grad_input) {
|
||||
if (check_mps_compatibility("nearest-exact", scales_w)) {
|
||||
mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest-exact");
|
||||
mps::upsample_out_template(
|
||||
grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest-exact");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(grad_input) = at::_upsample_nearest_exact2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w).clone().to("mps");
|
||||
const_cast<Tensor&>(grad_input) =
|
||||
at::_upsample_nearest_exact2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w)
|
||||
.clone()
|
||||
.to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps) (
|
||||
const Tensor& input,
|
||||
TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps)
|
||||
(const Tensor& input,
|
||||
IntArrayRef output_size,
|
||||
bool align_corners,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& output)
|
||||
{
|
||||
const Tensor& output) {
|
||||
if (check_mps_compatibility("bilinear", scales_w)) {
|
||||
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, align_corners, "bilinear");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(output) = at::upsample_bilinear2d(input.to("cpu"), output_size, align_corners, scales_h, scales_w).clone().to("mps");
|
||||
const_cast<Tensor&>(output) =
|
||||
at::upsample_bilinear2d(input.to("cpu"), output_size, align_corners, scales_h, scales_w).clone().to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps) (
|
||||
const Tensor& grad_output,
|
||||
TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps)
|
||||
(const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
bool align_corners,
|
||||
c10::optional<double> scales_h,
|
||||
c10::optional<double> scales_w,
|
||||
const Tensor& grad_input)
|
||||
{
|
||||
const Tensor& grad_input) {
|
||||
if (check_mps_compatibility("bilinear", scales_w)) {
|
||||
mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, align_corners, "bilinear");
|
||||
mps::upsample_out_template(
|
||||
grad_output, output_size, input_size, scales_h, scales_w, grad_input, align_corners, "bilinear");
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(grad_input) = at::upsample_bilinear2d_backward(grad_output.to("cpu"), output_size, input_size, align_corners, scales_h, scales_w).clone().to("mps");
|
||||
const_cast<Tensor&>(grad_input) =
|
||||
at::upsample_bilinear2d_backward(
|
||||
grad_output.to("cpu"), output_size, input_size, align_corners, scales_h, scales_w)
|
||||
.clone()
|
||||
.to("mps");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,17 +1,16 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/mps/IndexKernels.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mps/OperationUtils.h>
#include <fmt/format.h>
#include <torch/library.h>
#include <ATen/mps/IndexKernels.h>

namespace at::native {
namespace mps {

struct ViewCachedGraph : public MPSCachedGraph
{
struct ViewCachedGraph : public MPSCachedGraph {
ViewCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor = nil;
MPSGraphTensor* outputTensor = nil;
@ -20,18 +19,20 @@ struct ViewCachedGraph : public MPSCachedGraph
|
||||
std::vector<MPSGraphTensor*> strideTensors;
|
||||
};
|
||||
|
||||
static std::string getStridedKey(const ScalarType& self_dtype, const ScalarType& updates_dtype, const IntArrayRef& base_shape,
|
||||
const IntArrayRef& new_shape, const IntArrayRef& stride,
|
||||
int64_t storage_offset, bool is_scatter)
|
||||
{
|
||||
static std::string getStridedKey(const ScalarType& self_dtype,
|
||||
const ScalarType& updates_dtype,
|
||||
const IntArrayRef& base_shape,
|
||||
const IntArrayRef& new_shape,
|
||||
const IntArrayRef& stride,
|
||||
int64_t storage_offset,
|
||||
bool is_scatter) {
|
||||
std::string dtype_key = getMPSTypeString(self_dtype);
|
||||
if (is_scatter) {
|
||||
dtype_key += ":" + getMPSTypeString(updates_dtype);
|
||||
}
|
||||
|
||||
return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" +
|
||||
getArrayRefString(base_shape) + "]:[" + getArrayRefString(new_shape) + "]:[" +
|
||||
getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]";
|
||||
return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" + getArrayRefString(base_shape) + "]:[" +
|
||||
getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]";
|
||||
}
|
||||
|
||||
// initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op
|
||||
@ -51,7 +52,8 @@ static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src,
|
||||
@autoreleasepool {
|
||||
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
|
||||
// in case of scatter, we use output tensor as input buffer and write the results back to the source buffer
|
||||
feeds[cachedGraph->inputTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: needsScatter ? outputBuffer : sourceBuffer
|
||||
feeds[cachedGraph->inputTensor] =
|
||||
[[[MPSGraphTensorData alloc] initWithMTLBuffer:needsScatter ? outputBuffer : sourceBuffer
|
||||
shape:inputShape
|
||||
dataType:inputType] autorelease];
|
||||
if (needsScatter) {
|
||||
@ -81,9 +83,7 @@ static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src,
|
||||
MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:outputBuffer
|
||||
shape:outputShape
|
||||
dataType:outputType] autorelease];
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
cachedGraph->outputTensor : outputTensorData
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{cachedGraph->outputTensor : outputTensorData};
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
return output;
|
||||
@ -106,10 +106,7 @@ MPSGraphTensor *permuteTensor(MPSGraph *graph, MPSGraphTensor *inputTensor, NSAr
|
||||
NSUInteger axis2 = axisIter - dimensionOrder.begin();
|
||||
iter_swap(dimensionOrder.begin() + i, axisIter);
|
||||
|
||||
outputTensor = [graph transposeTensor:outputTensor
|
||||
dimension:axis1
|
||||
withDimension:axis2
|
||||
name:nil];
|
||||
outputTensor = [graph transposeTensor:outputTensor dimension:axis1 withDimension:axis2 name:nil];
|
||||
}
|
||||
|
||||
return outputTensor;
|
||||
@ -121,8 +118,7 @@ NSDictionary *getStrideToDimLengthOffsetDict(MPSGraphTensor *tensor, NSUInteger
|
||||
NSMutableDictionary* strideToDimLengthOffset = [[NSMutableDictionary alloc] init];
|
||||
for (NSInteger srcDim = rank - 1; srcDim >= 0; srcDim--) {
|
||||
NSUInteger size = [[tensor shape][srcDim] integerValue];
|
||||
NSDictionary *entry =
|
||||
@{
|
||||
NSDictionary* entry = @{
|
||||
@"dim" : [NSNumber numberWithInteger:srcDim],
|
||||
@"length" : [tensor shape][srcDim],
|
||||
@"offset" : [NSNumber numberWithInteger:offset % size] // offset is determined traversing backwards through stride
|
||||
@ -135,8 +131,12 @@ NSDictionary *getStrideToDimLengthOffsetDict(MPSGraphTensor *tensor, NSUInteger
|
||||
}
|
||||
|
||||
// Detect only expand dims, allows for duplicate strides
|
||||
MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) {
|
||||
|
||||
MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph* graph,
|
||||
MPSGraphTensor* inputTensor,
|
||||
int dstRank,
|
||||
const IntArrayRef& dstSizes,
|
||||
const IntArrayRef& dstStrides,
|
||||
int offset) {
|
||||
NSUInteger srcRank = [[inputTensor shape] count];
|
||||
// Not an expand dims
|
||||
if (srcRank >= dstRank)
|
||||
@ -175,9 +175,7 @@ MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph *graph, MPSGraphTensor
|
||||
|
||||
MPSGraphTensor* expandTensor = inputTensor;
|
||||
if ([expandAxes count]) {
|
||||
expandTensor = [graph expandDimsOfTensor:expandTensor
|
||||
axes:expandAxes
|
||||
name:nil];
|
||||
expandTensor = [graph expandDimsOfTensor:expandTensor axes:expandAxes name:nil];
|
||||
}
|
||||
[expandAxes release];
|
||||
|
||||
@ -185,7 +183,12 @@ MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph *graph, MPSGraphTensor
|
||||
}
|
||||
|
||||
// Detect contiguous reshapes, no slicing
|
||||
MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) {
|
||||
MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph* graph,
|
||||
MPSGraphTensor* inputTensor,
|
||||
int dstRank,
|
||||
const IntArrayRef& dstSizes,
|
||||
const IntArrayRef& dstStrides,
|
||||
int offset) {
|
||||
NSUInteger srcRank = [[inputTensor shape] count];
|
||||
// Not a reshape
|
||||
if (srcRank <= dstRank)
|
||||
@ -218,15 +221,17 @@ MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
if (isValidReshape)
|
||||
outputTensor = [graph reshapeTensor: inputTensor
|
||||
withShape: dstShape
|
||||
name: nil];
|
||||
outputTensor = [graph reshapeTensor:inputTensor withShape:dstShape name:nil];
|
||||
[dstShape release];
|
||||
return outputTensor;
|
||||
}
|
||||
|
||||
MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) {
|
||||
|
||||
MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
|
||||
MPSGraphTensor* inputTensor,
|
||||
int dstRank,
|
||||
const IntArrayRef& dstSizes,
|
||||
const IntArrayRef& dstStrides,
|
||||
int offset) {
|
||||
// Duplicate strides cannot be done
|
||||
{
|
||||
BOOL allUnique = YES;
|
||||
@ -243,7 +248,9 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
|
||||
// Skip for zero in dst shape
|
||||
for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++)
|
||||
if (dstSizes[dstDim] == 0) { return nil; }
|
||||
if (dstSizes[dstDim] == 0) {
|
||||
return nil;
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Flatten the inputTensor if necessary
|
||||
@ -257,9 +264,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
}
|
||||
// We have to leave at least 1 dimension, if all input dims are 1
|
||||
if ([squeezeAxes count])
|
||||
flatInputTensor = [graph squeezeTensor:flatInputTensor
|
||||
axes:squeezeAxes
|
||||
name:nil];
|
||||
flatInputTensor = [graph squeezeTensor:flatInputTensor axes:squeezeAxes name:nil];
|
||||
[squeezeAxes release];
|
||||
}
|
||||
|
||||
@ -280,13 +285,15 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
dstDimToSliceOffset[dstDim] = 0;
|
||||
} else {
|
||||
// Find what dimension and native length was for the specified stride
|
||||
NSDictionary *srcDimLengthOffset = srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%lld",dstStrides[dstDim]]];
|
||||
NSDictionary* srcDimLengthOffset =
|
||||
srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%lld", dstStrides[dstDim]]];
|
||||
|
||||
dstDimToSliceLength[dstDim] = dstSizes[dstDim];
|
||||
dstDimToSliceOffset[dstDim] = [srcDimLengthOffset[@"offset"] intValue];
|
||||
|
||||
// Stride does not exist in source tensor, or the specified size is too long. Not possible
|
||||
// TODO: Longer length with same stride + removal of dim(s) above this is a flatten/reshape. Consider adding support
|
||||
// TODO: Longer length with same stride + removal of dim(s) above this is a flatten/reshape. Consider adding
|
||||
// support
|
||||
if (!srcDimLengthOffset ||
|
||||
// the offset + length of destination should not be larger than source's length when slicing
|
||||
dstDimToSliceOffset[dstDim] + dstDimToSliceLength[dstDim] > [srcDimLengthOffset[@"length"] intValue]) {
|
||||
@ -353,9 +360,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
[transposedMissingSrcDims addObject:[NSNumber numberWithInt:dstDim]];
|
||||
}
|
||||
if ([transposedMissingSrcDims count])
|
||||
squeezedTensor = [graph squeezeTensor:squeezedTensor
|
||||
axes:transposedMissingSrcDims
|
||||
name:nil];
|
||||
squeezedTensor = [graph squeezeTensor:squeezedTensor axes:transposedMissingSrcDims name:nil];
|
||||
[transposedMissingSrcDims release];
|
||||
}
|
||||
|
||||
@ -369,11 +374,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
int start = dstDimToSliceOffset[dstDim];
|
||||
int length = dstDimToSliceLength[dstDim];
|
||||
if (length != [[slicedTensor shape][currDstDim] intValue])
|
||||
slicedTensor = [graph sliceTensor:slicedTensor
|
||||
dimension:currDstDim
|
||||
start:start
|
||||
length:length
|
||||
name:nil];
|
||||
slicedTensor = [graph sliceTensor:slicedTensor dimension:currDstDim start:start length:length name:nil];
|
||||
currDstDim++;
|
||||
}
|
||||
}
|
||||
@ -391,12 +392,8 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
}
|
||||
|
||||
if ([expandAxes count]) {
|
||||
MPSGraphTensor *expandTensor = [graph expandDimsOfTensor:broadcastTensor
|
||||
axes:expandAxes
|
||||
name:nil];
|
||||
broadcastTensor = [graph broadcastTensor:expandTensor
|
||||
toShape:broadcastShape
|
||||
name:nil];
|
||||
MPSGraphTensor* expandTensor = [graph expandDimsOfTensor:broadcastTensor axes:expandAxes name:nil];
|
||||
broadcastTensor = [graph broadcastTensor:expandTensor toShape:broadcastShape name:nil];
|
||||
}
|
||||
[broadcastShape release];
|
||||
[expandAxes release];
|
||||
@ -409,7 +406,12 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *i
|
||||
return broadcastTensor;
|
||||
}
|
||||
|
||||
MPSGraphTensor* asStridedLayer_pattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) {
|
||||
MPSGraphTensor* asStridedLayer_pattern(MPSGraph* graph,
|
||||
MPSGraphTensor* inputTensor,
|
||||
int dstRank,
|
||||
const IntArrayRef& dstSizes,
|
||||
const IntArrayRef& dstStrides,
|
||||
int offset) {
|
||||
if (!dstRank)
|
||||
return nil;
|
||||
|
||||
@ -423,8 +425,7 @@ MPSGraphTensor* asStridedLayer_pattern(MPSGraph *graph, MPSGraphTensor *inputTen
|
||||
return outputTensor;
|
||||
}
|
||||
|
||||
static
|
||||
std::vector<int64_t> getViewShape(const Tensor& src, MPSShape *mpsShape, const bool squeeze) {
|
||||
static std::vector<int64_t> getViewShape(const Tensor& src, MPSShape* mpsShape, const bool squeeze) {
|
||||
bool hasMPSShape = (mpsShape != nil);
|
||||
std::vector<int64_t> src_view_shape;
|
||||
if (hasMPSShape) {
|
||||
@ -459,7 +460,6 @@ std::vector<int64_t> getViewShape(const Tensor& src, MPSShape *mpsShape, const b
|
||||
return src_view_shape;
|
||||
}
|
||||
|
||||
|
||||
std::vector<int64_t> getSqueezedBaseShape(const Tensor& src, IntArrayRef shape) {
|
||||
std::vector<int64_t> src_base_shape;
|
||||
for (const auto i : c10::irange(shape.size())) {
|
||||
@ -471,7 +471,6 @@ std::vector<int64_t> getSqueezedBaseShape(const Tensor& src, IntArrayRef shape)
|
||||
return src_base_shape;
|
||||
}
|
||||
|
||||
|
||||
bool canSliceViewTensor(const Tensor& src, MPSShape* mpsShape) {
|
||||
if (!src.is_contiguous()) {
|
||||
return false;
|
||||
@ -537,7 +536,8 @@ MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mp
|
||||
}
|
||||
|
||||
int64_t sliceOffset = src.storage_offset() / view_numel;
|
||||
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - firstDimToSlice
|
||||
[srcTensorNDArrayDesc
|
||||
sliceDimension:src_ndim_base - 1 - firstDimToSlice
|
||||
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice])}];
|
||||
|
||||
// Slice any remaining dimensions
|
||||
@ -548,7 +548,8 @@ MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mp
|
||||
} else {
|
||||
sliceOffset = (src.storage_offset() % view_numel) / (view_numel / src_base_shape[crtSliceOffset]);
|
||||
}
|
||||
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - crtSliceOffset
|
||||
[srcTensorNDArrayDesc
|
||||
sliceDimension:src_ndim_base - 1 - crtSliceOffset
|
||||
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[crtSliceOffset])}];
|
||||
}
|
||||
}
|
||||
@ -559,11 +560,13 @@ MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mp
|
||||
return [[[MPSGraphTensorData alloc] initWithMPSNDArray:srcTensorNDArrayView] autorelease];
|
||||
}
|
||||
|
||||
static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const IntArrayRef& size,
|
||||
const IntArrayRef& stride, int64_t offset,
|
||||
const IntArrayRef& base_shape, bool needsScatter,
|
||||
MPSGraphTensor* updatesTensor)
|
||||
{
|
||||
static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph,
|
||||
const IntArrayRef& size,
|
||||
const IntArrayRef& stride,
|
||||
int64_t offset,
|
||||
const IntArrayRef& base_shape,
|
||||
bool needsScatter,
|
||||
MPSGraphTensor* updatesTensor) {
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
const size_t shape_size = size.size();
|
||||
@ -575,17 +578,14 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const In
|
||||
TORCH_CHECK(size[i] <= int_max);
|
||||
sizeArray[i] = static_cast<int32_t>(size[i]);
|
||||
}
|
||||
NSData* shapeData = [NSData dataWithBytes: sizeArray.data()
|
||||
length: shape_size * sizeof(int32_t)];
|
||||
NSData* shapeData = [NSData dataWithBytes:sizeArray.data() length:shape_size * sizeof(int32_t)];
|
||||
MPSGraphTensor* shapeTensor = [mpsGraph constantWithData:shapeData
|
||||
shape:@[ [NSNumber numberWithUnsignedInteger:shape_size] ]
|
||||
dataType:MPSDataTypeInt32];
|
||||
MPSGraphTensor* indicesTensor = nil;
|
||||
// create stride Tensors for each rank of the input tensor
|
||||
for (int i = 0; i < shape_size; i++) {
|
||||
MPSGraphTensor* rangeTensor = [mpsGraph coordinateAlongAxis: (-i - 1)
|
||||
withShapeTensor: shapeTensor
|
||||
name: nil];
|
||||
MPSGraphTensor* rangeTensor = [mpsGraph coordinateAlongAxis:(-i - 1) withShapeTensor:shapeTensor name:nil];
|
||||
MPSGraphTensor* strideTensor = cachedGraph->strideTensors[shape_size - i - 1];
|
||||
MPSGraphTensor* indexTensor = [mpsGraph multiplicationWithPrimaryTensor:rangeTensor
|
||||
secondaryTensor:strideTensor
|
||||
@ -593,9 +593,7 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const In
|
||||
if (!indicesTensor) {
|
||||
indicesTensor = indexTensor;
|
||||
} else {
|
||||
indicesTensor = [mpsGraph additionWithPrimaryTensor: indexTensor
|
||||
secondaryTensor: indicesTensor
|
||||
name: nil];
|
||||
indicesTensor = [mpsGraph additionWithPrimaryTensor:indexTensor secondaryTensor:indicesTensor name:nil];
|
||||
}
|
||||
}
|
||||
|
||||
@ -611,12 +609,8 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const In
|
||||
}
|
||||
}
|
||||
|
||||
MPSGraphTensor *reshapedInputTensor = [mpsGraph reshapeTensor: inputTensor
|
||||
withShape: @[@-1]
|
||||
name: nil];
|
||||
MPSGraphTensor *reshapedIndicesTensor = [mpsGraph reshapeTensor: indicesTensor
|
||||
withShape: @[@-1]
|
||||
name: nil];
|
||||
MPSGraphTensor* reshapedInputTensor = [mpsGraph reshapeTensor:inputTensor withShape:@[ @-1 ] name:nil];
|
||||
MPSGraphTensor* reshapedIndicesTensor = [mpsGraph reshapeTensor:indicesTensor withShape:@[ @-1 ] name:nil];
|
||||
if (needsScatter) {
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wobjc-method-access"
|
||||
@ -627,9 +621,7 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const In
|
||||
mode:MPSGraphScatterModeSet
|
||||
name:nil];
|
||||
#pragma clang diagnostic pop
|
||||
outputTensor = [mpsGraph reshapeTensor: scatteredTensor
|
||||
withShape: getMPSShape(base_shape)
|
||||
name: nil];
|
||||
outputTensor = [mpsGraph reshapeTensor:scatteredTensor withShape:getMPSShape(base_shape) name:nil];
|
||||
} else {
|
||||
// Call gather to coalesce the needed values. Result will be of same shape as flattened indices tensor
|
||||
MPSGraphTensor* gatheredTensor = [mpsGraph gatherWithUpdatesTensor:reshapedInputTensor
|
||||
@ -638,24 +630,22 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const In
|
||||
batchDimensions:0
|
||||
name:nil];
|
||||
// Reshape the data to desired size
|
||||
outputTensor = [mpsGraph reshapeTensor: gatheredTensor
|
||||
withShapeTensor: shapeTensor
|
||||
name: nil];
|
||||
outputTensor = [mpsGraph reshapeTensor:gatheredTensor withShapeTensor:shapeTensor name:nil];
|
||||
}
|
||||
}
|
||||
return outputTensor;
|
||||
}
|
||||
|
||||
static IntArrayRef updateTensorBaseShape(const Tensor& self)
|
||||
{
|
||||
static IntArrayRef updateTensorBaseShape(const Tensor& self) {
|
||||
IntArrayRef base_shape = getIMPSAllocator()->getBufferShape(self.storage().data());
|
||||
// if there's no base_shape stored in MPSAllocator, then infer it from tensor's size and store it
|
||||
if (base_shape.size() == 0) {
|
||||
// IntArrayRef wouldn't own the data, so we use a static storage
|
||||
static const int64_t shape_1d = 1;
|
||||
// self.sizes().size() could be zero
|
||||
base_shape = self.sizes().size() ? self.sizes() :
|
||||
((self.is_view() && self._base().sizes().size()) ? self._base().sizes() : IntArrayRef(&shape_1d, 1));
|
||||
base_shape = self.sizes().size()
|
||||
? self.sizes()
|
||||
: ((self.is_view() && self._base().sizes().size()) ? self._base().sizes() : IntArrayRef(&shape_1d, 1));
|
||||
|
||||
// base_shape will be retained in MPSAllocator until buffer gets recycled
|
||||
if (self.storage().data())
|
||||
@ -681,12 +671,17 @@ static IntArrayRef updateTensorBaseShape(const Tensor& self)
|
||||
// | / \ |
|
||||
// | / \ |
|
||||
// NonView T NonView T
|
||||
static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &updates, IntArrayRef size, IntArrayRef stride, int64_t storage_offset, bool needsScatter)
|
||||
{
|
||||
static ViewCachedGraph* createViewGraph(const Tensor& self,
|
||||
const Tensor& updates,
|
||||
IntArrayRef size,
|
||||
IntArrayRef stride,
|
||||
int64_t storage_offset,
|
||||
bool needsScatter) {
|
||||
IntArrayRef base_shape = updateTensorBaseShape(self);
|
||||
|
||||
@autoreleasepool {
|
||||
string key = getStridedKey(self.scalar_type(), updates.scalar_type(), base_shape, size, stride, storage_offset, needsScatter);
|
||||
string key = getStridedKey(
|
||||
self.scalar_type(), updates.scalar_type(), base_shape, size, stride, storage_offset, needsScatter);
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
ViewCachedGraph* cachedGraph = static_cast<ViewCachedGraph*>(cache_->LookUp(key));
|
||||
|
||||
@ -718,12 +713,11 @@ static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &update
|
||||
newCachedGraph->updatesTensor = mpsGraphRankedPlaceHolder(mpsGraph, updatesType, getMPSShape(self.numel()));
|
||||
updatesTensor = newCachedGraph->updatesTensor;
|
||||
if (inputType != updatesType) {
|
||||
updatesTensor = [mpsGraph castTensor:updatesTensor
|
||||
toType:inputType
|
||||
name:@"castUpdatesTensor"];
|
||||
updatesTensor = [mpsGraph castTensor:updatesTensor toType:inputType name:@"castUpdatesTensor"];
|
||||
}
|
||||
}
|
||||
newCachedGraph->outputTensor = chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter, updatesTensor);
|
||||
newCachedGraph->outputTensor =
|
||||
chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter, updatesTensor);
|
||||
}
|
||||
return newCachedGraph;
|
||||
}));
|
||||
@ -732,11 +726,7 @@ static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &update
|
||||
}
|
||||
}
|
||||
|
||||
static
|
||||
std::string getGatherScatterFunctionName(
|
||||
ScalarType scalarType,
|
||||
int64_t dim,
|
||||
bool needsScatter) {
|
||||
static std::string getGatherScatterFunctionName(ScalarType scalarType, int64_t dim, bool needsScatter) {
|
||||
std::string kernelName = needsScatter ? "scatter" : "gather";
|
||||
return kernelName + "_kernel_" + std::to_string(dim == 0 ? 1 : dim);
|
||||
}
|
||||
@ -759,8 +749,7 @@ const std::string& getGatherScatterScalarType(const Tensor& t) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
static
|
||||
id<MTLLibrary> compileGatherScatterOpsLibrary(id<MTLDevice> device,
|
||||
static id<MTLLibrary> compileGatherScatterOpsLibrary(id<MTLDevice> device,
|
||||
const std::string& dtypeSrc,
|
||||
const std::string& dtypeDst,
|
||||
bool needsScatter) {
|
||||
@ -773,10 +762,17 @@ id<MTLLibrary> compileGatherScatterOpsLibrary(id<MTLDevice> device,
|
||||
NSError* error = nil;
|
||||
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
|
||||
[options setLanguageVersion:MTLLanguageVersion2_3];
|
||||
auto gatherScatterLib = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(needsScatter ? SCATTER_OPS_TEMPLATE : GATHER_OPS_TEMPLATE, dtypeSrc, dtypeDst).c_str()]
|
||||
auto gatherScatterLib =
|
||||
[device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(needsScatter ? SCATTER_OPS_TEMPLATE
|
||||
: GATHER_OPS_TEMPLATE,
|
||||
dtypeSrc,
|
||||
dtypeDst)
|
||||
.c_str()]
|
||||
options:options
|
||||
error:&error];
|
||||
TORCH_CHECK(gatherScatterLib != nil && error == nil, "Failed to compile gather-scatter library, error: ", [[error description] UTF8String]);
|
||||
TORCH_CHECK(gatherScatterLib != nil && error == nil,
|
||||
"Failed to compile gather-scatter library, error: ",
|
||||
[[error description] UTF8String]);
|
||||
_libCache[key] = gatherScatterLib;
|
||||
return gatherScatterLib;
|
||||
}
|
||||
@ -798,7 +794,8 @@ static id<MTLComputePipelineState> getPipelineState(id<MTLDevice> device,
|
||||
id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
|
||||
TORCH_CHECK(func, "Failed to load the Metal Shader function: ", kernel);
|
||||
id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:func error:&error];
|
||||
TORCH_CHECK(pso != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
|
||||
TORCH_CHECK(
|
||||
pso != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
|
||||
_mtlPipelineCache[key] = pso;
|
||||
return pso;
|
||||
}
|
||||
@ -814,8 +811,8 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
|
||||
}
|
||||
|
||||
if (src.dim() > 5) {
|
||||
ViewCachedGraph* cachedGraph = createViewGraph(src, dst, src.sizes(), src.strides(),
|
||||
src.storage_offset(), /*needsScatter*/ false);
|
||||
ViewCachedGraph* cachedGraph =
|
||||
createViewGraph(src, dst, src.sizes(), src.strides(), src.storage_offset(), /*needsScatter*/ false);
|
||||
return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false);
|
||||
}
|
||||
|
||||
@ -871,8 +868,11 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
|
||||
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output) {
|
||||
if (output.dim() > 5) {
|
||||
ViewCachedGraph* cachedGraph = createViewGraph(output.is_complex() ? at::view_as_real(output) : output,
|
||||
src, output.sizes(), output.strides(),
|
||||
output.storage_offset(), /*needsScatter*/ true);
|
||||
src,
|
||||
output.sizes(),
|
||||
output.strides(),
|
||||
output.storage_offset(),
|
||||
/*needsScatter*/ true);
|
||||
return runViewGraph(cachedGraph, src, output, /*needsScatter*/ true);
|
||||
}
|
||||
if (src.numel() == 0 || output.numel() == 0) {
|
||||
@ -888,7 +888,8 @@ Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output){
|
||||
@autoreleasepool {
|
||||
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
|
||||
id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
|
||||
std::string functionName = getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/true);
|
||||
std::string functionName =
|
||||
getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/true);
|
||||
id<MTLComputePipelineState> scatterPSO = getPipelineState(MPSDevice::getInstance()->device(),
|
||||
functionName,
|
||||
getGatherScatterScalarType(src),
|
||||
@ -934,16 +935,21 @@ Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output){
|
||||
} // namespace mps
|
||||
|
||||
// implementation of as_strided() op
|
||||
Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, IntArrayRef stride, c10::optional<int64_t> storage_offset_) {
|
||||
Tensor as_strided_tensorimpl_mps(const Tensor& self,
|
||||
IntArrayRef size,
|
||||
IntArrayRef stride,
|
||||
c10::optional<int64_t> storage_offset_) {
|
||||
auto storage_offset = storage_offset_.value_or(self.storage_offset());
|
||||
auto result = detail::make_tensor<TensorImpl>(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
|
||||
auto result =
|
||||
detail::make_tensor<TensorImpl>(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
|
||||
setStrided(result, size, stride, storage_offset);
|
||||
|
||||
// creating the view graph will be deferred until gatherViewTensor() or scatterViewTensor() are called.
|
||||
// In as_strided, we just update the base shape of the buffer in order to retrieve it later
|
||||
// when we create/run the view graph.
|
||||
IntArrayRef base_shape = mps::updateTensorBaseShape(self);
|
||||
TORCH_INTERNAL_ASSERT(base_shape.size() > 0, "Failed to update the base shape of tensor's buffer at ", self.storage().data());
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
base_shape.size() > 0, "Failed to update the base shape of tensor's buffer at ", self.storage().data());
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -10,8 +10,7 @@ NS_ASSUME_NONNULL_BEGIN

+ (NSString*)cacheDirectory;

+ (BOOL)compileModel:(const std::string&)modelSpecs
modelID:(const std::string&)modelID;
+ (BOOL)compileModel:(const std::string&)modelSpecs modelID:(const std::string&)modelID;

+ (nullable MLModel*)loadModel:(const std::string&)modelID
backend:(const std::string&)backend