Compare commits

...

1 Commits

3 changed files with 93 additions and 11 deletions

View File

@ -93,7 +93,7 @@ TORCH_LIBRARY_IMPL(aten, MPS, m) {
m.impl("linalg_svd.U", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("col2im", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("_slow_conv2d_forward", slow_conv2d_forward_mps);
m.impl("upsample_nearest3d.vec", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
// m.impl("upsample_nearest3d.vec", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
}
} // namespace at

View File

@ -34,6 +34,8 @@
#include <ATen/ops/upsample_nearest2d_backward.h>
#include <ATen/ops/upsample_nearest2d_backward_native.h>
#include <ATen/ops/upsample_nearest2d_native.h>
#include <ATen/ops/upsample_nearest3d.h>
#include <ATen/ops/upsample_nearest3d_native.h>
#endif
namespace at::native {
namespace mps {
@ -54,8 +56,10 @@ static void upsample_out_template(const Tensor& input,
const auto input_dim = input.sizes();
if (input_dim.size() <= 3) {
native::upsample_1d_common_check(input.sizes(), output_size);
} else {
} else if (input_dim.size() <= 4) {
native::upsample_2d_common_check(input.sizes(), output_size);
} else {
native::upsample_3d_common_check(input.sizes(), output_size);
}
Tensor out;
if (needsGather(output)) {
@ -67,6 +71,7 @@ static void upsample_out_template(const Tensor& input,
MPSGraphResizeNearestRoundingMode nearestRoundingMode = MPSGraphResizeNearestRoundingModeFloor;
MPSGraphTensorNamedDataLayout dataLayout =
input_dim.size() > 3 ? MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutCHW;
dataLayout = MPSGraphTensorNamedDataLayoutNCHW;
if (resize_mode_str == "nearest") {
resizeMode = MPSGraphResizeNearest;
} else if (resize_mode_str == "bilinear") {
@ -95,6 +100,7 @@ static void upsample_out_template(const Tensor& input,
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
MPSGraphTensor* outputSizeTensor = nil;
MPSGraphTensor* outputSizeTensor1 = nil;
};
MPSStream* stream = getCurrentMPSStream();
@ -105,16 +111,25 @@ static void upsample_out_template(const Tensor& input,
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
newCachedGraph->outputSizeTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(2) ]);
newCachedGraph->outputSizeTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(output_size.size()) ]);
MPSGraphTensor* scaleOffsetTensor = nullptr;
MPSGraphTensor* scaleOffsetTensor2 = nullptr;
MPSGraphTensor* inputSizeTensor = nullptr;
if (scale_w > 0.0) {
const float outScales[4] = {scale_h, scale_w, offset_y, offset_x};
for(auto elem: outScales){
std::cout<<elem<<", ";
}
std::cout<<std::endl;
scaleOffsetTensor = [mpsGraph constantWithData:[NSData dataWithBytes:outScales length:sizeof(outScales)]
shape:@[ @4 ]
dataType:MPSDataTypeFloat32];
const float outScales2[4] = {1, 2, 0, 0};
scaleOffsetTensor2 = [mpsGraph constantWithData:[NSData dataWithBytes:outScales2 length:sizeof(outScales)]
shape:@[ @4 ]
dataType:MPSDataTypeFloat32];
}
if (is_backward_pass) {
std::vector<NSNumber*> inputSizeVec(4);
@ -130,12 +145,67 @@ static void upsample_out_template(const Tensor& input,
if (!is_backward_pass) {
if (scaleOffsetTensor && !align_corners) {
if (resizeMode == MPSGraphResizeNearest) {
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor:newCachedGraph->inputTensor
sizeTensor:newCachedGraph->outputSizeTensor
scaleOffsetTensor:scaleOffsetTensor
nearestRoundingMode:nearestRoundingMode
layout:dataLayout
name:nil];
// Volumetric case
if(input_dim.size()==5) {
// Assuming H,W,D
auto firstSizeTensor = [mpsGraph sliceTensor:newCachedGraph->outputSizeTensor
dimension:0
start:0
length:2
name:nil];
const int secondSizeVal[2] = {6, 6};
auto secondSizeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:secondSizeVal length:sizeof(secondSizeVal)]
shape:@[ @2 ]
dataType:MPSDataTypeInt32];
// auto secondSizeTensor = [mpsGraph sliceTensor:newCachedGraph->outputSizeTensor
// dimension:0
// start:1
// length:2
// name:nil];
std::cout<<[[secondSizeTensor debugDescription] UTF8String]<<std::endl;
auto reshapedInput = [mpsGraph reshapeTensor:newCachedGraph->inputTensor
withShape:@[@(input_dim[0]*input_dim[1]), @(input_dim[2]), @(input_dim[3]), @(input_dim[4])]
name:nil];
// auto intermediateTensor = [mpsGraph resizeNearestWithTensor:newCachedGraph->inputTensor
// sizeTensor:firstSizeTensor
// nearestRoundingMode:nearestRoundingMode
// centerResult:centerResults
// alignCorners:align_corners
// name:nil];
//
// intermediateTensor = [mpsGraph resizeNearestWithTensor:intermediateTensor
// sizeTensor:secondSizeTensor
// nearestRoundingMode:nearestRoundingMode
// centerResult:centerResults
// alignCorners:align_corners
// name:nil];
auto intermediateTensor = [mpsGraph resizeNearestWithTensor:reshapedInput
sizeTensor:firstSizeTensor
scaleOffsetTensor:scaleOffsetTensor
nearestRoundingMode:nearestRoundingMode
layout:MPSGraphTensorNamedDataLayoutNHWC
name:nil];
intermediateTensor = [mpsGraph resizeNearestWithTensor:intermediateTensor
sizeTensor:secondSizeTensor
scaleOffsetTensor:scaleOffsetTensor2
nearestRoundingMode:nearestRoundingMode
layout:MPSGraphTensorNamedDataLayoutNCHW
name:nil];
newCachedGraph->outputTensor = [mpsGraph reshapeTensor:intermediateTensor
withShape:@[@(input_dim[0]), @(input_dim[1]), @(output_size[0]), @(output_size[1]), @(output_size[2])]
name:nil];
std::cout<<[[newCachedGraph->outputTensor debugDescription] UTF8String]<<std::endl;
} else {
newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor:newCachedGraph->inputTensor
sizeTensor:newCachedGraph->outputSizeTensor
scaleOffsetTensor:scaleOffsetTensor
nearestRoundingMode:nearestRoundingMode
layout:dataLayout
name:nil];
}
} else { // bilinear forward
newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor:newCachedGraph->inputTensor
sizeTensor:newCachedGraph->outputSizeTensor
@ -197,11 +267,12 @@ static void upsample_out_template(const Tensor& input,
}
}
});
MPSNDArrayDescriptor* sizeDesc = [MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeInt32 shape:@[ @(2) ]];
MPSNDArrayDescriptor* sizeDesc = [MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeInt32 shape:@[ @(output_size.size()) ]];
MPSNDArray* sizeNDArray = [[[MPSNDArray alloc] initWithDevice:stream->device() descriptor:sizeDesc] autorelease];
[sizeNDArray writeBytes:(int32_t[]){(int32_t)output_height, (int32_t)output_width} strideBytes:nil];
[sizeNDArray writeBytes:(int32_t[]){6, 6, 6} strideBytes:nil];
MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray:sizeNDArray] autorelease];
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor, out.has_storage() ? out : output, nil, false);
@ -398,6 +469,16 @@ TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_mps)
grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest-exact");
}
// Defines upsample_nearest3d_out_mps: nearest-neighbour upsampling for the 3d
// (volumetric) op, forwarded to the shared mps::upsample_out_template helper
// with the "nearest" resize mode and align_corners set to false.
//
// NOTE(review): scales_d is accepted but never read in this body; only
// scales_h and scales_w are forwarded, so an explicitly supplied depth scale
// has no effect on this call. Confirm whether upsample_out_template is meant
// to derive depth scaling from output_size alone, or whether a depth-scale
// parameter still needs to be plumbed through.
TORCH_IMPL_FUNC(upsample_nearest3d_out_mps)
(const Tensor& input,
IntArrayRef output_size,
std::optional<double> scales_d,
std::optional<double> scales_h,
std::optional<double> scales_w,
const Tensor& output) {
// The third argument is left as std::nullopt — presumably the input size that
// only the backward pass needs; verify against the template's signature.
mps::upsample_out_template(input, output_size, std::nullopt, scales_h, scales_w, output, false, "nearest");
}
TORCH_IMPL_FUNC(upsample_linear1d_out_mps)
(const Tensor& input, IntArrayRef output_size, bool align_corners, std::optional<double> scale, const Tensor& output) {
mps::upsample_out_template(input, output_size, std::nullopt, std::nullopt, scale, output, align_corners, "bilinear");

View File

@ -12911,6 +12911,7 @@
dispatch:
CPU: upsample_nearest3d_out_cpu
CUDA: upsample_nearest3d_out_cuda
MPS: upsample_nearest3d_out_mps
- func: _upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn