Compare commits

...

11 Commits

SHA1 Message Date
2ce56de80e Remove few xfails 2025-05-22 15:18:11 -07:00
34cd5614c5 Fix lint 2025-05-22 15:16:53 -07:00
15d7f6ac2b clean up 2025-05-22 17:32:21 -04:00
7b80b3fd13 Apply suggestions from code review 2025-05-22 14:19:29 -07:00
b0b1902739 Update aten/src/ATen/native/mps/operations/Pooling.mm 2025-05-22 14:19:06 -07:00
1d29dc5d9c fix test_max_pool3d 2025-05-22 17:04:15 -04:00
fe518636a6 update 2025-05-22 16:33:09 -04:00
765dd32545 One is expected to return Tensor by reference from function 2025-05-22 13:16:58 -07:00
b9ca9918ba [BE] Do not call explicit constructor (Compiler should do the work for you) 2025-05-22 13:16:18 -07:00
003540fcb6 Fix build 2025-05-22 13:12:38 -07:00
a7f788143e [MPS] Implement max_pool3d_with_indices 2025-05-22 15:59:53 -04:00
4 changed files with 305 additions and 6 deletions
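
Taken together, these commits add an MPS implementation of max_pool3d_with_indices and register it for dispatch. A minimal usage sketch of what the change enables, assuming a macOS build with an MPS device (shapes are illustrative):

import torch

x = torch.randn(2, 8, 16, 16, 16, device="mps")   # (N, C, D, H, W)
pool = torch.nn.MaxPool3d(kernel_size=2, return_indices=True)
out, idx = pool(x)            # previously unsupported on the MPS backend (see the removed xfails below)
print(out.shape, idx.shape)   # torch.Size([2, 8, 8, 8, 8]) for both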

aten/src/ATen/native/mps/operations/Pooling.mm (View File)

@@ -15,6 +15,7 @@
#include <ATen/ops/max_pool2d_native.h>
#include <ATen/ops/max_pool2d_with_indices_backward_native.h>
#include <ATen/ops/max_pool2d_with_indices_native.h>
#include <ATen/ops/max_pool3d_with_indices_native.h>
#endif
namespace at::native {
@@ -32,6 +33,10 @@ struct PoolingCachedGraph : public MPSCachedGraph {
typedef MPSGraphTensor* (^PoolingOpBlock)(PoolingCachedGraph&, MPSGraphPooling2DOpDescriptor*);
#define PoolingOpFn(graph, desc) MPSGraphTensor*(mps::PoolingCachedGraph & graph, MPSGraphPooling2DOpDescriptor * desc)
typedef MPSGraphTensor* (^Pooling4dOpBlock)(PoolingCachedGraph&, MPSGraphPooling4DOpDescriptor*);
#define Pooling4dOpFn(graph, desc) \
MPSGraphTensor*(mps::PoolingCachedGraph & graph, MPSGraphPooling4DOpDescriptor * desc)
// Pooling ops (1D/2D forward and backward Max and Average pooling)
static void pool2d_template(const Tensor& input,
const Tensor& output,
@@ -240,6 +245,204 @@ static void pool2d_template(const Tensor& input,
}
}
static void pool3d_template(const Tensor& input,
const Tensor& output,
const std::optional<Tensor>& indices_opt,
const std::optional<Tensor>& grad_output_opt,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
bool count_include_pad,
const std::optional<int64_t> divisor_override,
Pooling4dOpBlock poolingBlock,
const std::string& op_name) {
const int64_t ndims = input.ndimension();
const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt));
const Tensor& indices = *(at::borrow_from_optional_tensor(indices_opt));
const bool is_backward_pass = grad_output.defined();
const bool has_indices = indices.defined();
const bool has_divisor = divisor_override.has_value() && divisor_override.value() != 0;
TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
op_name,
": kernel_size must either be a single int, or a tuple of three ints")
TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 3,
op_name,
": stride must either be omitted, a single int, or a tuple of three ints")
TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
op_name,
": padding must either be a single int, or a tuple of three ints");
TORCH_CHECK(dilation.size() == 1 || dilation.size() == 3,
op_name,
": dilation must be either a single int, or a tuple of three ints");
const auto memory_format = input.suggest_memory_format();
TORCH_CHECK((ndims == 4 || ndims == 5), "non-empty 4D or 5D (batch mode) tensor expected for input");
TORCH_CHECK(memory_format == at::MemoryFormat::Contiguous, "MPS pool3d supports only Contiguous memory format");
int padD = safe_downcast<int, int64_t>(padding[0]);
int padH = padding.size() == 1 ? padD : safe_downcast<int, int64_t>(padding[1]);
int padW = padding.size() == 1 ? padD : safe_downcast<int, int64_t>(padding[2]);
const int kD = safe_downcast<int, int64_t>(kernel_size[0]);
const int kH = kernel_size.size() == 1 ? kD : safe_downcast<int, int64_t>(kernel_size[1]);
const int kW = kernel_size.size() == 1 ? kD : safe_downcast<int, int64_t>(kernel_size[2]);
const int dD = stride.empty() ? kD : safe_downcast<int, int64_t>(stride[0]);
const int dH = stride.empty() ? kH : stride.size() == 1 ? dD : safe_downcast<int, int64_t>(stride[1]);
const int dW = stride.empty() ? kW : stride.size() == 1 ? dD : safe_downcast<int, int64_t>(stride[2]);
const int dilationD = safe_downcast<int, int64_t>(dilation[0]);
const int dilationH = dilation.size() == 1 ? dilationD : safe_downcast<int, int64_t>(dilation[1]);
const int dilationW = dilation.size() == 1 ? dilationD : safe_downcast<int, int64_t>(dilation[2]);
const int64_t nbatch = ndims == 5 ? input.size(-5) : 1;
const int64_t nInputPlane = input.size(-4);
const int64_t inputDepth = input.size(-3);
const int64_t inputHeight = input.size(-2);
const int64_t inputWidth = input.size(-1);
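// Output spatial sizes follow the standard pooling formula:
//   out = floor_or_ceil((in + 2*pad - dilation*(k - 1) - 1) / stride) + 1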
const int64_t outputDepth = pooling_output_shape<int64_t>(inputDepth, kD, padD, dD, dilationD, ceil_mode);
const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);
pool3d_shape_check(input,
nInputPlane,
kD,
kH,
kW,
dD,
dH,
dW,
padD,
padH,
padW,
dilationD,
dilationH,
dilationW,
inputDepth,
inputHeight,
inputWidth,
outputDepth,
outputHeight,
outputWidth,
"pool3d");
std::vector<int64_t> outputSizes{nInputPlane, outputDepth, outputHeight, outputWidth};
if (ndims == 5) {
outputSizes.insert(outputSizes.begin(), nbatch);
}
output.resize_(outputSizes);
indices.resize_(outputSizes);
if (input.numel() == 0) {
return;
}
if (output.numel() == 0 || (is_backward_pass && grad_output.numel() == 0)) {
return;
}
// when both ceil_mode and count_include_pad are True
if (count_include_pad && ceil_mode) {
padD = padH = padW = 0;
}
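// MPSGraph has no dedicated 3D pooling op, so add a leading singleton dimension and use its 4D pooling over (T=1, D, H, W).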
input.unsqueeze_(-4);
output.unsqueeze_(-4);
indices.unsqueeze_(-4);
int padT = 0;
const int kT = 1;
const int dT = 1;
const int dilationT = 1;
@autoreleasepool {
std::string key = op_name + getTensorsStringKey({input, indices, grad_output}) + ":K[" +
getArrayRefString(kernel_size) + "]:S[" + getArrayRefString(stride) + "]:P[" + getArrayRefString(padding) +
"]:D[" + getArrayRefString(dilation) + "]" + (ceil_mode ? ":ceil" : "") +
(count_include_pad ? ":include_pad" : "") + (has_divisor ? ":divisor" : "");
MPSShape* inputShape = getMPSShape(input, memory_format);
MPSShape* gradOutputShape = is_backward_pass ? getMPSShape(grad_output, memory_format) : nullptr;
auto cachedGraph = LookUpOrCreateCachedGraph<PoolingCachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
MPSGraphPooling4DOpDescriptor* desc = [[MPSGraphPooling4DOpDescriptor alloc] init];
desc.kernelSizes = @[ @(kT), @(kD), @(kH), @(kW) ];
desc.strides = @[ @(dT), @(dD), @(dH), @(dW) ];
desc.dilationRates = @[ @(dilationT), @(dilationD), @(dilationH), @(dilationW) ];
desc.paddingValues = @[ @(padT), @(padT), @(padD), @(padD), @(padH), @(padH), @(padW), @(padW) ];
desc.paddingStyle = MPSGraphPaddingStyleExplicit;
desc.ceilMode = (padD == 0 && padW == 0 && padH == 0) ? ceil_mode : false;
if (has_indices) {
desc.returnIndicesMode = MPSGraphPoolingReturnIndicesGlobalFlatten4D;
desc.returnIndicesDataType = MPSDataTypeInt32;
}
newCachedGraph->inputTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input.scalar_type()), inputShape);
if (is_backward_pass) {
newCachedGraph->gradOutputTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output.scalar_type()), gradOutputShape);
}
if (has_divisor) {
newCachedGraph->divisorTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeFloat32, @[ @1 ]);
}
MPSGraphTensor* outputTensor = poolingBlock(*newCachedGraph, desc);
newCachedGraph->outputTensor = outputTensor;
});
MPSStream* mpsStream = getCurrentMPSStream();
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor,
input,
inputShape,
/*gatherTensorData=*/true,
MPSDataTypeInvalid,
/*useMPSStridedAPI=*/false);
Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder()
: Placeholder(cachedGraph->gradOutputTensor,
grad_output,
gradOutputShape,
/*gatherTensorData=*/true,
MPSDataTypeInvalid,
/*useMPSStridedAPI=*/false);
Placeholder indicesPlaceholder = has_indices
? Placeholder(
cachedGraph->indicesTensor, indices, nullptr, true, MPSDataTypeInvalid, /*useMPSStridedAPI=*/false)
: Placeholder();
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor, output, nullptr, false, MPSDataTypeInvalid, false);
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary* results = [[NSMutableDictionary new] autorelease];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
results[outputPlaceholder.getMPSGraphTensor()] = outputPlaceholder.getMPSGraphTensorData();
if (cachedGraph->gradOutputTensor) {
feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
}
if (cachedGraph->indicesTensor) {
if (is_backward_pass) {
feeds[indicesPlaceholder.getMPSGraphTensor()] = indicesPlaceholder.getMPSGraphTensorData();
} else {
results[indicesPlaceholder.getMPSGraphTensor()] = indicesPlaceholder.getMPSGraphTensorData();
}
}
MPSScalar divisor_scalar;
if (cachedGraph->divisorTensor) {
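// MPSGraph's average pooling divides by the kernel volume, so rescale by kernel_volume / divisor_override to honor a custom divisor.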
const float divisor = float(kD * kH * kW) / (float)divisor_override.value();
divisor_scalar = getMPSScalar(divisor, ScalarType::Float);
feeds[cachedGraph->divisorTensor] = getMPSGraphTensorFromScalar(mpsStream, divisor_scalar);
}
runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);
}
input.squeeze_(-4);
output.squeeze_(-4);
indices.squeeze_(-4);
}
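
The outputDepth/Height/Width values above come from the standard pooling shape formula. A minimal Python sketch of that formula (the function name is illustrative, not an ATen API), handy for sanity-checking the sizes exercised by the tests below:

import math

def pool_out_size(in_size, k, stride=None, pad=0, dilation=1, ceil_mode=False):
    # out = floor_or_ceil((in + 2*pad - dilation*(k - 1) - 1) / stride) + 1
    stride = k if stride is None else stride
    rounder = math.ceil if ceil_mode else math.floor
    out = rounder((in_size + 2 * pad - dilation * (k - 1) - 1) / stride) + 1
    if ceil_mode and (out - 1) * stride >= in_size + pad:
        out -= 1  # the last window must start inside the input (or its left padding)
    return out

assert pool_out_size(32, 4) == 8              # 32^3 input, kernel 4 (stride defaults to kernel)
assert pool_out_size(32, 4, dilation=2) == 7  # matches the dilation case in test_max_pool3d
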
static void avg_pool2d_template(const Tensor& input,
const Tensor& output,
const std::optional<Tensor>& grad_output_opt,
@@ -493,6 +696,53 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)
"max_pool2d_indices_backward");
}
std::tuple<Tensor&, Tensor&> max_pool3d_with_indices_out_mps(const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
Tensor& output,
Tensor& indices) {
mps::Pooling4dOpBlock pooling_op_block = ^Pooling4dOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
NSArray<MPSGraphTensor*>* poolOutputs = [mpsGraph maxPooling4DReturnIndicesWithSourceTensor:cachedGraph.inputTensor
descriptor:desc
name:nil];
cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long);
return poolOutputs[0];
};
mps::pool3d_template(input,
output,
indices,
std::nullopt,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
false,
std::nullopt,
pooling_op_block,
"max_pool3d_indices");
return {output, indices};
}
std::tuple<Tensor, Tensor> max_pool3d_with_indices_mps(const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode) {
Tensor output = at::empty({0}, input.options());
Tensor indices = at::empty({0}, input.options().dtype(kLong));
max_pool3d_with_indices_out_mps(input, kernel_size, stride, padding, dilation, ceil_mode, output, indices);
return {output, indices};
}
TORCH_IMPL_FUNC(avg_pool2d_out_mps)
(const Tensor& input,
int64_t kH,

aten/src/ATen/native/native_functions.yaml (View File)

@@ -12434,6 +12434,7 @@
dispatch:
CPU: max_pool3d_with_indices_out_cpu
CUDA: max_pool3d_with_indices_out_cuda
MPS: max_pool3d_with_indices_out_mps
# Return: (Tensor output, Tensor indices)
- func: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
@@ -12441,6 +12442,7 @@
dispatch:
CPU: max_pool3d_with_indices_cpu
CUDA: max_pool3d_with_indices_cuda
MPS: max_pool3d_with_indices_mps
tags: core
- func: max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
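
With these dispatch entries in place, the existing ATen schema resolves to the new MPS kernels. A quick way to exercise the op directly (sketch; assumes an MPS-enabled build):

import torch

x = torch.randn(1, 4, 8, 8, 8, device="mps")
out, idx = torch.ops.aten.max_pool3d_with_indices(x, [2, 2, 2])  # stride defaults to kernel_size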

View File

@@ -855,7 +855,6 @@ torch.cuda.synchronize()
inp = torch.randn(16, 0, 20, 32, device=device)
avgpool(inp)
@expectedFailureMPS # max_pool3d_with_indices not supported on MPS
def test_pooling_shape(self, device):
"""Test the output shape calculation for pooling functions"""
@@ -1939,7 +1938,6 @@ torch.cuda.synchronize()
helper(nn.AdaptiveAvgPool2d((2**6, 2**6)))
@dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
@expectedFailureMPS
@dtypes(torch.float)
def test_pool_invalid_size(self, device, dtype):
for op in ("max", "avg"):

test/test_mps.py (View File)

@@ -1386,9 +1386,7 @@ class TestMPS(TestCaseMPS):
# Test forward maxpool2d
def test_max_pool2d(self):
def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=False, test_ties=False):
cpu_x = None
if (test_ties):
if test_ties:
cpu_x = torch.ones(shape, device='cpu', dtype=torch.float, requires_grad=True)
else:
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
@@ -1397,7 +1395,7 @@ class TestMPS(TestCaseMPS):
pool = torch.nn.MaxPool2d(kernel_size=ks, padding=padding, dilation=dilation,
ceil_mode=ceil_mode, return_indices=return_indices)
if (return_indices is False):
if not return_indices:
y = pool(x)
ref_y = pool(cpu_x)
@@ -1458,6 +1456,57 @@ class TestMPS(TestCaseMPS):
helper((1, 1000, 1, 4), ks=(1, 4), ceil_mode=True,
return_indices=True, test_ties=test_ties) # test for max_pool1d
# Test forward maxpool3d
def test_max_pool3d(self):
def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=False, test_ties=False):
if test_ties:
cpu_x = torch.ones(shape, device='cpu', dtype=torch.float, requires_grad=True)
else:
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
x = cpu_x.detach().clone().to('mps').requires_grad_()
pool = torch.nn.MaxPool3d(kernel_size=ks, padding=padding, dilation=dilation,
ceil_mode=ceil_mode, return_indices=return_indices)
if not return_indices:
y = pool(x)
ref_y = pool(cpu_x)
self.assertEqual(y, ref_y)
else:
y, idx = pool(x)
ref_y, ref_idx = pool(cpu_x)
self.assertEqual(y, ref_y)
self.assertEqual(idx, ref_idx)
# Test with no batch dimension
helper((8, 4, 4, 4), ks=2)
helper((2, 8, 4, 4, 4), ks=2)
helper((1, 10, 32, 32, 32), ks=4)
# Test padding
helper((1, 10, 32, 32, 32), ks=4, padding=1)
# Test dilation
helper((1, 10, 32, 32, 32), ks=4, dilation=2)
# Test ceil mode
helper((1, 10, 32, 32, 32), ks=4, ceil_mode=True)
# Test return indices
for test_ties in [False, True]:
# Test with no batch dimension
helper((8, 4, 4, 4), ks=2, return_indices=True, test_ties=test_ties)
# Test with empty input
helper((0, 8, 4, 4, 4), ks=2, return_indices=True, test_ties=test_ties)
helper((2, 8, 4, 4, 4), ks=2, return_indices=True, test_ties=test_ties)
helper((1, 10, 32, 32, 32), ks=4, return_indices=True, test_ties=test_ties)
# Test padding
helper((1, 10, 32, 32, 32), ks=4, padding=1, return_indices=True, test_ties=test_ties)
# Test dilation
helper((1, 10, 32, 32, 32), ks=4, dilation=2, return_indices=True, test_ties=test_ties)
# Test ceil mode
helper((1, 10, 32, 32, 32), ks=4, ceil_mode=True, return_indices=True, test_ties=test_ties)
def test_adaptive_avg_pool2d_output_size_one(self):
def helper(size, memory_format):
x = torch.randint(1, 10, size, dtype=torch.float, device='mps', requires_grad=True)