Compare commits

...

4 Commits

2 changed files with 154 additions and 10 deletions

View File

@@ -40,6 +40,8 @@
#include <ATen/ops/xlogy_native.h>
#endif
#include <cmath>
namespace at::native {
namespace mps {
@@ -253,6 +255,131 @@ static void div_mode_template(const Tensor& self,
div_mode_op_block);
}
// Tiled element-wise kernel for outputs too large for a single MPSGraph
// dispatch (> 2^32 elements): the batch dimension is split into chunks of at
// most floor(2^32 / (rows * cols)) matrices and each chunk is executed as its
// own MPSNDArray-backed graph over views into the original Metal buffers.
//
// NOTE(review): despite the name, the cached graph below always performs
// addition -- `alpha` is ignored and `op_name` only participates in the cache
// key. Extend before routing "sub"/"lerp" callers through this path.
//
// Assumes `self`/`output` are (batch, rows, cols) (or 2-D, treated as
// batch == 1) and `other` is a single matrix broadcast along the batch
// dimension -- TODO confirm against the dispatch logic in the caller.
static const Tensor& tiled_add_sub_lerp(const Tensor& self,
                                        const Tensor& other,
                                        const Scalar& alpha,
                                        const Tensor& output,
                                        std::string op_name) {
  id<MTLBuffer> selfBuffer = getMTLBufferStorage(self);
  id<MTLBuffer> otherBuffer = getMTLBufferStorage(other);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
  MPSStream* mpsStream = getCurrentMPSStream();
  @autoreleasepool {
    mpsStream->endKernelCoalescing();
    // Treat a 2-D input as a batch of one matrix.
    uint64_t originalBatchSize = self.sizes().size() > 2 ? self.size(-3) : 1;
    uint64_t selfRows = self.size(-2);
    uint64_t otherRows = other.size(-2);
    uint64_t outputRows = output.size(-2);
    uint64_t selfCols = self.size(-1);
    uint64_t otherCols = other.size(-1);
    uint64_t outputCols = output.size(-1);
    uint64_t selfElemSize = self.element_size();
    uint64_t otherElemSize = other.element_size();
    uint64_t outputElemSize = output.element_size();
    MPSDataType dtype = getMPSDataType(self);

    // Largest batch chunk whose flattened element count stays within the
    // 2^32-element dispatch limit. Exact integer arithmetic replaces the
    // previous floor(pow(2, 32) / elemInMatrix) double round-trip.
    // NOTE(review): elemInMatrix == 0 (empty matrix) or > 2^32 (a single
    // matrix too large to tile by batch) would divide by zero below; callers
    // are expected to rule both out -- TODO assert explicitly.
    uint64_t elemInMatrix = outputRows * outputCols;
    uint64_t largestSupportedBatchSize = (1ULL << 32) / elemInMatrix;
    uint64_t batchSize = std::min(largestSupportedBatchSize, originalBatchSize);
    uint64_t lastBatchSize = originalBatchSize % batchSize;

    MPSShape* selfShape = @[ @(batchSize), @(selfRows), @(selfCols) ];
    MPSShape* otherShape = @[ @1, @(otherRows), @(otherCols) ];
    MPSShape* outputShape = @[ @(batchSize), @(outputRows), @(outputCols) ];
    auto selfDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:selfShape];
    selfDesc_.preferPackedRows = true;
    auto otherDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:otherShape];
    otherDesc_.preferPackedRows = true;
    auto outputDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:outputShape];
    outputDesc_.preferPackedRows = true;

    // When the batch does not divide evenly, the final iteration uses smaller
    // descriptors covering only the remainder.
    MPSNDArrayDescriptor *selfDescLastBatch_, *otherDescLastBatch_, *outputDescLastBatch_;
    if (lastBatchSize != 0) {
      selfDescLastBatch_ =
          [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:@[ @(lastBatchSize), @(selfRows), @(selfCols) ]];
      selfDescLastBatch_.preferPackedRows = true;
      otherDescLastBatch_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype
                                                                   shape:@[ @1, @(otherRows), @(otherCols) ]];
      otherDescLastBatch_.preferPackedRows = true;
      outputDescLastBatch_ =
          [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:@[ @(lastBatchSize), @(outputRows), @(outputCols) ]];
      outputDescLastBatch_.preferPackedRows = true;
    }

    // Integer ceil-division; the former ceil(float(...) / ...) lost precision
    // for batch counts above 2^24.
    uint64_t requiredIterations = (originalBatchSize + batchSize - 1) / batchSize;
    auto selfDesc = selfDesc_;
    auto otherDesc = otherDesc_;
    auto outputDesc = outputDesc_;
    for (const auto i : c10::irange(requiredIterations)) {
      @autoreleasepool {
        if (i == requiredIterations - 1 && lastBatchSize != 0) {
          selfDesc = selfDescLastBatch_;
          otherDesc = otherDescLastBatch_;
          outputDesc = outputDescLastBatch_;
        }
        const uint64_t selfArrayOffset = i * batchSize * selfRows * selfCols;
        // `other` is broadcast along the batch: every chunk reads the same matrix.
        const uint64_t otherArrayOffset = 0;
        const uint64_t outputArrayOffset = i * batchSize * outputRows * outputCols;
        auto selfNDArray = [[[MPSNDArray alloc] initWithBuffer:selfBuffer
                                                        offset:(self.storage_offset() + selfArrayOffset) * selfElemSize
                                                    descriptor:selfDesc] autorelease];
        auto otherNDArray =
            [[[MPSNDArray alloc] initWithBuffer:otherBuffer
                                         offset:(other.storage_offset() + otherArrayOffset) * otherElemSize
                                     descriptor:otherDesc] autorelease];
        auto outputNDArray =
            [[[MPSNDArray alloc] initWithBuffer:outputBuffer
                                         offset:(output.storage_offset() + outputArrayOffset) * outputElemSize
                                     descriptor:outputDesc] autorelease];
        // Cache key includes the shapes so full chunks and the remainder chunk
        // resolve to distinct graphs.
        string key = op_name + getMPSShapeString([selfDesc getShape]) + getMPSShapeString([otherDesc getShape]) +
            getMPSShapeString([outputDesc getShape]);
        auto cachedGraph = LookUpOrCreateCachedGraph<BinaryOpCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
          // TODO(review): hard-coded addition; `op_name`/`alpha` are not consulted.
          MPSGraphTensor* selfTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
          MPSGraphTensor* otherTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(other));
          newCachedGraph->primaryTensor = selfTensor;
          newCachedGraph->secondaryTensor = otherTensor;
          auto outputTensor = [mpsGraph additionWithPrimaryTensor:selfTensor secondaryTensor:otherTensor name:nil];
          newCachedGraph->outputTensor = outputTensor;
        });
        Placeholder selfPlaceholder = Placeholder(cachedGraph->primaryTensor, selfNDArray);
        Placeholder otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, otherNDArray);
        Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, outputNDArray);
        auto feeds = dictionaryFromPlaceholders(selfPlaceholder, otherPlaceholder);
        auto results = dictionaryFromPlaceholders(outputPlaceholder);
        mpsStream->executeMPSGraph(cachedGraph->graph(), feeds, results, SyncType::COMMIT);
      }
    }
  }
  return output;
}
static void add_sub_lerp_template(const Tensor& self,
const Tensor& other,
const Scalar& alpha,
@@ -414,16 +541,18 @@ TORCH_IMPL_FUNC(div_out_mps)(const Tensor& self, const Tensor& other, const Tens
}
// MPS implementation of `torch.add(self, other, alpha=alpha, out=output)`.
TORCH_IMPL_FUNC(add_out_mps)(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
  if ((isComplexType(self.scalar_type()) || isComplexType(other.scalar_type())) && !alpha.isComplex() &&
      !mps::supportsComplex()) {
    // Complex add with non-complex alpha is just add over views
    return mps::add_sub_lerp_template(mps::legacy_complex_as_view(self),
                                      mps::legacy_complex_as_view(other),
                                      alpha,
                                      mps::legacy_complex_as_view(output),
                                      "add");
  }
  // Route through the tiled kernel only when a single MPSGraph dispatch
  // cannot address the whole output (> 2^32 elements).
  // NOTE(review): tiled_add_sub_lerp currently ignores `alpha` (it performs a
  // plain addition) -- confirm callers on this path always pass alpha == 1,
  // or teach the tiled kernel to apply alpha, before shipping.
  constexpr int64_t kMaxElementsPerDispatch = 1LL << 32;
  if (output.numel() > kMaxElementsPerDispatch) {
    mps::tiled_add_sub_lerp(self, other, alpha, output, "add");
    return;
  }
  mps::add_sub_lerp_template(self, other, alpha, output, "add");
}
TORCH_IMPL_FUNC(sub_out_mps)(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {

View File

@@ -1929,6 +1929,21 @@ class TestMPS(TestCaseMPS):
self.assertEqual(output_cpu, output_mps, atol=tol, rtol=tol)
self.assertEqual(output_cpu.size(), output_mps.size())
@xfailIf(product_version < 15.0)
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_large_add(self, dtype):
    # Exercise the tiled MPS add path: a batched LHS whose element count
    # exceeds a single dispatch, broadcast against a 2-D RHS.
    lhs = torch.randn(11, 20064, 20064, dtype=dtype)
    rhs = torch.randn(20064, 20064, dtype=dtype)
    expected = torch.add(lhs, rhs)
    actual = torch.add(lhs.to('mps'), rhs.to('mps'))
    # Using the low precision comparison for FP16
    tolerance = 1e-2 if dtype == torch.float16 else None
    self.assertEqual(expected, actual, atol=tolerance, rtol=tolerance)
    self.assertEqual(expected.size(), actual.size())
def test_addr(self):
A = torch.ones(5, 10).to("mps")