[MPS] Add slow version of kthvalue (#161817)

The implementation heavily borrows its logic from `topk`. Because the op is non-deterministic, the CPU-vs-MPS indices comparison in the consistency test is replaced with a simple equality check: the values are compared directly against the CPU reference, and the indices are validated by gathering from the input tensor, since the randomly generated inputs picked by default contain quite a lot of repeated values, so indices can legitimately differ between backends.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161817
Approved by: https://github.com/dcci
Committed by: PyTorch MergeBot
Parent: c1e504ec2f
Commit: 7c30a9d7fc
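As a rough illustration of the behaviour this change targets (not part of the PR), the sketch below assumes an MPS-enabled PyTorch build and an available "mps" device, and mirrors the gather-based comparison described above: values must match the CPU reference, while indices are only checked indirectly, because they may legitimately differ when the k-th value is repeated.

import torch

# Hypothetical usage sketch; tensor shapes and k are arbitrary.
x_cpu = torch.randn(4, 8)
x_mps = x_cpu.to("mps")

k, dim = 3, 1
cpu_vals, cpu_idx = torch.kthvalue(x_cpu, k, dim=dim)
mps_vals, mps_idx = torch.kthvalue(x_mps, k, dim=dim)

# Values must agree with the CPU reference.
torch.testing.assert_close(cpu_vals, mps_vals.cpu())

# Indices may differ between backends when the k-th value is duplicated,
# so verify them indirectly: gathering with the returned indices must
# reproduce the returned values.
gathered = torch.gather(x_mps, dim, mps_idx.unsqueeze(dim)).squeeze(dim)
torch.testing.assert_close(gathered.cpu(), mps_vals.cpu())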
@@ -2,6 +2,7 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/MemoryOverlap.h>
 #include <ATen/WrapDimUtils.h>
+#include <ATen/native/SortingUtils.h>
 #include <ATen/native/TensorShape.h>
 #include <ATen/native/TypeProperties.h>
 #include <ATen/native/mps/MPSGraphVenturaOps.h>
@@ -11,10 +12,85 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/kthvalue_native.h>
 #include <ATen/ops/sort.h>
 #include <ATen/ops/sort_native.h>
 #endif
 
 namespace at::native {
+
+namespace {
+
+void kthvalue_out_mps_impl(const Tensor& self, int64_t k, int64_t dim, Tensor& values, Tensor& indices) {
+  using namespace mps;
+  if (self.dim() == 0 && self.numel() == 1) {
+    values.copy_(self);
+    indices.zero_();
+    return;
+  }
+  // Handle empty tensors
+  if (self.numel() == 0) {
+    values.copy_(self);
+    indices.copy_(values.toType(at::ScalarType::Long));
+    return;
+  }
+  // issue #154890, raising error to prevent crash within MPSGraph until
+  // workaround is implemented.
+  TORCH_CHECK(self.dim() - dim <= 4, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890");
+
+  auto stream = getCurrentMPSStream();
+  struct CachedGraph : public MPSCachedGraph {
+    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
+    MPSGraphTensor *selfTensor = nil, *valuesTensor = nil, *indicesTensor = nil;
+  };
+
+  // MPSGraph kthvalue is always sorted.
+  @autoreleasepool {
+    // Input as placeholders
+    MPSShape* input_shape = getMPSShape(self);
+    NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
+    std::string key = std::string("kthvalue:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" +
+        std::to_string(k) + ":dim" + std::to_string(dim);
+    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
+      newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
+
+      MPSGraphTensor* castInputTensor = newCachedGraph->selfTensor;
+      MPSDataType dataType = getMPSDataType(self);
+      // #issue 104398441 sortWithTensor and argsortWithTensor
+      if (dataType != MPSDataTypeInt32 && dataType != MPSDataTypeFloat32 && dataType != MPSDataTypeFloat16) {
+        dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
+        castInputTensor = [mpsGraph castTensor:newCachedGraph->selfTensor toType:dataType name:@"castInputTensor"];
+      }
+      MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor
+                                                         axis:(NSUInteger)dim
+                                                   descending:false
+                                                         name:nil];
+      sortedTensor = [mpsGraph sliceTensor:sortedTensor
+                                 dimension:(NSUInteger)dim
+                                     start:((NSUInteger)k - 1)
+                                    length:1
+                                      name:nil];
+      MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor
+                                                               axis:(NSInteger)dim
+                                                         descending:false
+                                                               name:@"kthvalue_out"];
+      argSortedTensor = [mpsGraph sliceTensor:argSortedTensor
+                                    dimension:dim
+                                        start:((NSUInteger)k - 1)
+                                       length:1
+                                         name:nil];
+      newCachedGraph->valuesTensor = sortedTensor;
+      newCachedGraph->indicesTensor = argSortedTensor;
+    });
+    Placeholder inputPlaceholder = Placeholder(cachedGraph->selfTensor, self);
+    // Outputs as placeholders
+    Placeholder valuesPlaceholder = Placeholder(cachedGraph->valuesTensor, values);
+    Placeholder indicesPlaceholder = Placeholder(cachedGraph->indicesTensor, indices);
+    // Create dictionary of inputs and outputs
+    auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
+    auto results = dictionaryFromPlaceholders(valuesPlaceholder, indicesPlaceholder);
+    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+  }
+}
+
+} // anonymous namespace
 
 // sort
 TORCH_IMPL_FUNC(sort_stable_out_mps)
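For orientation, the graph built by `kthvalue_out_mps_impl` corresponds to the following eager-mode sketch (plain PyTorch rather than MPSGraph calls; `kthvalue_reference` is a hypothetical helper name, not part of the PR): sort and argsort along `dim`, then slice out the element at position `k - 1`. The length-1 dimension is later squeezed by the out-variant wrapper when `keepdim` is false.

import torch

def kthvalue_reference(x: torch.Tensor, k: int, dim: int):
    # Eager-mode analogue of the MPSGraph construction above:
    # sortWithTensor / argSortWithTensor followed by sliceTensor(start=k-1, length=1).
    sorted_vals = torch.sort(x, dim=dim, descending=False).values
    sorted_idx = torch.argsort(x, dim=dim, descending=False)
    values = sorted_vals.narrow(dim, k - 1, 1)   # keep a length-1 slice along dim
    indices = sorted_idx.narrow(dim, k - 1, 1)
    return values, indices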
@@ -81,4 +157,31 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
     runMPSGraph(stream, cachedGraph->graph(), feeds, results);
   }
 }
+
+std::tuple<Tensor&, Tensor&> kthvalue_out_mps(const Tensor& self,
+                                              int64_t k,
+                                              int64_t dim_,
+                                              bool keepdim,
+                                              Tensor& values,
+                                              Tensor& indices) {
+  // See note [Writing Nondeterministic Operations]
+  // If there are duplicate elements of the kth value, the procedure for choosing which
+  // of the duplicates to use for the indices output is nondeterministic.
+  at::globalContext().alertNotDeterministic("kthvalue MPS");
+
+  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
+  int64_t slicesize = self.dim() == 0 ? 1 : self.size(dim);
+  TORCH_CHECK(k >= 1 && k <= slicesize, "kthvalue(): selected number k out of range for dimension ", dim);
+  at::assert_no_overlap(self, values);
+  _reduction_with_indices_allocate_or_resize_output(values, indices, self, dim, keepdim);
+
+  kthvalue_out_mps_impl(self, k, dim, values, indices);
+
+  if (!keepdim) {
+    values.squeeze_(dim);
+    indices.squeeze_(dim);
+  }
+
+  return std::forward_as_tuple(values, indices);
+}
 } // namespace at::native
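Since the wrapper calls `alertNotDeterministic("kthvalue MPS")`, the op is expected to be rejected when deterministic algorithms are enforced. A hedged sketch of that interaction (assuming an MPS device; the exact error text is not guaranteed):

import torch

x = torch.randn(16, device="mps")

torch.use_deterministic_algorithms(True)
try:
    # Expected to raise because the MPS kthvalue kernel is flagged nondeterministic.
    torch.kthvalue(x, 2)
except RuntimeError as err:
    print("rejected as nondeterministic:", err)
finally:
    torch.use_deterministic_algorithms(False)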
@@ -3289,6 +3289,7 @@
   dispatch:
     CPU: kthvalue_out_cpu
     CUDA: kthvalue_out_cuda
+    MPS: kthvalue_out_mps
 
 - func: kthvalue.dimname(Tensor self, SymInt k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
   variants: function, method
@@ -12303,6 +12303,15 @@ class TestConsistency(TestCaseMPS):
             if op.name in "grid_sampler_3d":
                 atol, rtol = 1e-4, 1e-4
 
+            if op.name == "kthvalue":
+                self.assertEqual(cpu_out[0], mps_out[0], atol=atol, rtol=rtol)
+                # kthvalue is non-deterministic if input has repeated values
+                dim = cpu_args[2] if len(cpu_args) > 2 else -1
+                keep_dim = cpu_args[3] if len(cpu_args) > 3 else False
+                values = torch.gather(mps_sample.input, dim, mps_out[1] if keep_dim else mps_out[1].unsqueeze(dim))
+                self.assertEqual(values if keep_dim else values.squeeze(dim), mps_out[0])
+                continue
+
             self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
 
     @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
@@ -317,7 +317,7 @@ if torch.backends.mps.is_available():
         "index_reducemean": None,
         "index_reduceamax": None,
         "index_reduceamin": None,
-        "kthvalue": None,
+        # "kthvalue": None,
         "lcm": None,
         "linalg.cond": None,
         "linalg.eigh": None,