[MPS] Add slow version of kthvalue (#161817)

The implementation heavily borrows its logic from `topk`.

Since kthvalue is non-deterministic when the input contains repeated values, and the randomly generated test inputs produce quite a lot of ties, the CPU-vs-MPS consistency test no longer compares indices element-for-element: it checks the values for equality and validates the indices indirectly, by gathering the input at the returned indices and comparing the result against the returned values.
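A minimal sketch of the tie-breaking problem and of the indirect index check (the tensor below is illustrative, not taken from the test suite):

    import torch

    # With ties, every position holding the kth value is a legal answer,
    # so two backends may legitimately return different indices.
    x = torch.tensor([1.0, 2.0, 2.0, 3.0])
    value, index = torch.kthvalue(x, 2)
    assert value == 2.0       # the kth value itself is deterministic
    assert x[index] == value  # indices are validated by lookup, not by equality with CPU indices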
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161817
Approved by: https://github.com/dcci
Nikita Shulga
2025-08-29 16:25:46 -07:00
committed by PyTorch MergeBot
parent c1e504ec2f
commit 7c30a9d7fc
4 changed files with 114 additions and 1 deletion


@@ -2,6 +2,7 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/MemoryOverlap.h>
 #include <ATen/WrapDimUtils.h>
+#include <ATen/native/SortingUtils.h>
 #include <ATen/native/TensorShape.h>
 #include <ATen/native/TypeProperties.h>
 #include <ATen/native/mps/MPSGraphVenturaOps.h>
@@ -11,10 +12,85 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/kthvalue_native.h>
 #include <ATen/ops/sort.h>
 #include <ATen/ops/sort_native.h>
 #endif
 
 namespace at::native {
+
+namespace {
+void kthvalue_out_mps_impl(const Tensor& self, int64_t k, int64_t dim, Tensor& values, Tensor& indices) {
+  using namespace mps;
+  if (self.dim() == 0 && self.numel() == 1) {
+    values.copy_(self);
+    indices.zero_();
+    return;
+  }
+  // Handle empty tensors
+  if (self.numel() == 0) {
+    values.copy_(self);
+    indices.copy_(values.toType(at::ScalarType::Long));
+    return;
+  }
+  // issue #154890, raising error to prevent crash within MPSGraph until
+  // workaround is implemented.
+  TORCH_CHECK(self.dim() - dim <= 4, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890");
+  auto stream = getCurrentMPSStream();
+  struct CachedGraph : public MPSCachedGraph {
+    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
+    MPSGraphTensor *selfTensor = nil, *valuesTensor = nil, *indicesTensor = nil;
+  };
+  // MPSGraph kthvalue is always sorted.
+  @autoreleasepool {
+    // Input as placeholders
+    MPSShape* input_shape = getMPSShape(self);
+    NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
+    std::string key = std::string("kthvalue:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" +
+        std::to_string(k) + ":dim" + std::to_string(dim);
+    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
+      newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
+      MPSGraphTensor* castInputTensor = newCachedGraph->selfTensor;
+      MPSDataType dataType = getMPSDataType(self);
+      // #issue 104398441 sortWithTensor and argsortWithTensor
+      if (dataType != MPSDataTypeInt32 && dataType != MPSDataTypeFloat32 && dataType != MPSDataTypeFloat16) {
+        dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
+        castInputTensor = [mpsGraph castTensor:newCachedGraph->selfTensor toType:dataType name:@"castInputTensor"];
+      }
+      MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor
+                                                         axis:(NSUInteger)dim
+                                                   descending:false
+                                                         name:nil];
+      sortedTensor = [mpsGraph sliceTensor:sortedTensor
+                                 dimension:(NSUInteger)dim
+                                     start:((NSUInteger)k - 1)
+                                    length:1
+                                      name:nil];
+      MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor
+                                                               axis:(NSInteger)dim
+                                                         descending:false
+                                                               name:@"kthvalue_out"];
+      argSortedTensor = [mpsGraph sliceTensor:argSortedTensor
+                                    dimension:dim
+                                        start:((NSUInteger)k - 1)
+                                       length:1
+                                         name:nil];
+      newCachedGraph->valuesTensor = sortedTensor;
+      newCachedGraph->indicesTensor = argSortedTensor;
+    });
+    Placeholder inputPlaceholder = Placeholder(cachedGraph->selfTensor, self);
+    // Outputs as placeholders
+    Placeholder valuesPlaceholder = Placeholder(cachedGraph->valuesTensor, values);
+    Placeholder indicesPlaceholder = Placeholder(cachedGraph->indicesTensor, indices);
+    // Create dictionary of inputs and outputs
+    auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
+    auto results = dictionaryFromPlaceholders(valuesPlaceholder, indicesPlaceholder);
+    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+  }
+}
+} // anonymous namespace
+
 // sort
 TORCH_IMPL_FUNC(sort_stable_out_mps)
@@ -81,4 +157,31 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
     runMPSGraph(stream, cachedGraph->graph(), feeds, results);
   }
 }
+
+std::tuple<Tensor&, Tensor&> kthvalue_out_mps(const Tensor& self,
+                                              int64_t k,
+                                              int64_t dim_,
+                                              bool keepdim,
+                                              Tensor& values,
+                                              Tensor& indices) {
+  // See note [Writing Nondeterministic Operations]
+  // If there are duplicate elements of the kth value, the procedure for choosing which
+  // of the duplicates to use for the indices output is nondeterministic.
+  at::globalContext().alertNotDeterministic("kthvalue MPS");
+  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
+  int64_t slicesize = self.dim() == 0 ? 1 : self.size(dim);
+  TORCH_CHECK(k >= 1 && k <= slicesize, "kthvalue(): selected number k out of range for dimension ", dim);
+  at::assert_no_overlap(self, values);
+  _reduction_with_indices_allocate_or_resize_output(values, indices, self, dim, keepdim);
+  kthvalue_out_mps_impl(self, k, dim, values, indices);
+  if (!keepdim) {
+    values.squeeze_(dim);
+    indices.squeeze_(dim);
+  }
+  return std::forward_as_tuple(values, indices);
+}
+
 } // namespace at::native
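Because the wrapper calls alertNotDeterministic, requesting deterministic algorithms should make the op raise rather than silently return backend-dependent indices. A quick sketch of the expected behavior, assuming an MPS-enabled build:

    import torch

    if torch.backends.mps.is_available():
        x = torch.randn(4, 8, device="mps")
        torch.use_deterministic_algorithms(True)
        try:
            torch.kthvalue(x, k=2, dim=1)  # should raise: op is flagged non-deterministic
        except RuntimeError as err:
            print(err)  # message names "kthvalue MPS"
        finally:
            torch.use_deterministic_algorithms(False)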


@@ -3289,6 +3289,7 @@
   dispatch:
     CPU: kthvalue_out_cpu
     CUDA: kthvalue_out_cuda
+    MPS: kthvalue_out_mps
 
 - func: kthvalue.dimname(Tensor self, SymInt k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
   variants: function, method


@@ -12303,6 +12303,15 @@ class TestConsistency(TestCaseMPS):
            if op.name in "grid_sampler_3d":
                atol, rtol = 1e-4, 1e-4
 
+            if op.name == "kthvalue":
+                self.assertEqual(cpu_out[0], mps_out[0], atol=atol, rtol=rtol)
+                # kthvalue is non-deterministic if input has repeated values
+                dim = cpu_args[2] if len(cpu_args) > 2 else -1
+                keep_dim = cpu_args[3] if len(cpu_args) > 3 else False
+                values = torch.gather(mps_sample.input, dim, mps_out[1] if keep_dim else mps_out[1].unsqueeze(dim))
+                self.assertEqual(values if keep_dim else values.squeeze(dim), mps_out[0])
+                continue
+
            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
 
        @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)


@@ -317,7 +317,7 @@ if torch.backends.mps.is_available():
        "index_reducemean": None,
        "index_reduceamax": None,
        "index_reduceamin": None,
-        "kthvalue": None,
+        # "kthvalue": None,
        "lcm": None,
        "linalg.cond": None,
        "linalg.eigh": None,