[MPS] Dispatch outer bin edges selection function (#101792)

Dispatch the selection function to prevent using `is_mps()` in `Histogram.cpp`.

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at b329a02</samp>

This pull request refactors and implements the logic for inferring the bin edges of histograms from the input tensor for different device types. It introduces a dispatch stub `histogram_select_outer_bin_edges_stub` and moves the device-specific code to separate files, such as `HistogramKernel.cpp` and `HistogramKernel.mm`. This improves the modularity and readability of the histogram functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101792
Approved by: https://github.com/albanD
This commit is contained in:
Li-Huai (Allan) Lin
2023-06-27 08:24:42 +00:00
committed by PyTorch MergeBot
parent 217a8b4697
commit 99e87bb6a0
4 changed files with 44 additions and 29 deletions

View File

@ -65,6 +65,7 @@ namespace at { namespace native {
DEFINE_DISPATCH(histogramdd_stub);
DEFINE_DISPATCH(histogramdd_linear_stub);
DEFINE_DISPATCH(histogram_select_outer_bin_edges_stub);
namespace {
@ -153,22 +154,6 @@ void histogramdd_prepare_out(const Tensor& input, TensorList bins,
histogramdd_prepare_out(input, bin_ct, hist, bin_edges);
}
template<typename scalar_t>
void infer_bin_edges_from_input(const Tensor& input, const int64_t N,
std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
// Calls aminmax on input with dim=0, reducing all but the innermost dimension of input.
Tensor min, max;
std::tie(min, max) = aminmax(input, 0);
TORCH_INTERNAL_ASSERT(min.is_contiguous() && max.is_contiguous());
const scalar_t *min_data = min.data_ptr<scalar_t>();
std::copy(min_data, min_data + N, leftmost_edges.begin());
const scalar_t *max_data = max.data_ptr<scalar_t>();
std::copy(max_data, max_data + N, rightmost_edges.begin());
}
/* Determines the outermost bin edges. For simplicity when calling into aminmax,
* assumes that input has already been reshaped to (M, N).
*/
@ -192,19 +177,8 @@ select_outer_bin_edges(const Tensor& input, c10::optional<c10::ArrayRef<double>>
}
} else if (input.numel() > 0) {
// non-empty input
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "histogramdd", [&]() {
if (input.is_mps()) {
Tensor min, max;
std::tie(min, max) = at::aminmax(input, 0);
for (const auto i : c10::irange(N)) {
leftmost_edges[i] = min[i].item().to<scalar_t>();
rightmost_edges[i] = max[i].item().to<scalar_t>();
}
} else {
infer_bin_edges_from_input<scalar_t>(input, N, leftmost_edges, rightmost_edges);
}
});
histogram_select_outer_bin_edges_stub(input.device().type(), input, N, leftmost_edges, rightmost_edges);
}
for (const auto dim : c10::irange(N)) {

View File

@ -7,8 +7,10 @@ namespace at::native {
using histogramdd_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const TensorList&);
using histogramdd_linear_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const TensorList&, bool);
using histogram_select_outer_bin_edges_fn = void(*)(const Tensor& input, const int64_t N, std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges);
DECLARE_DISPATCH(histogramdd_fn, histogramdd_stub);
DECLARE_DISPATCH(histogramdd_linear_fn, histogramdd_linear_stub);
DECLARE_DISPATCH(histogram_select_outer_bin_edges_fn, histogram_select_outer_bin_edges_stub);
} // namespace at::native

View File

@ -10,6 +10,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/aminmax.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like_ops.h>
@ -282,10 +283,33 @@ static void histogramdd_linear_kernel_impl(const Tensor& self, const c10::option
}
}
template<typename scalar_t>
void infer_bin_edges_from_input(const Tensor& input, const int64_t N,
std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
// Calls aminmax on input with dim=0, reducing all but the innermost dimension of input.
Tensor min, max;
std::tie(min, max) = aminmax(input, 0);
TORCH_INTERNAL_ASSERT(min.is_contiguous() && max.is_contiguous());
const scalar_t *min_data = min.data_ptr<scalar_t>();
std::copy(min_data, min_data + N, leftmost_edges.begin());
const scalar_t *max_data = max.data_ptr<scalar_t>();
std::copy(max_data, max_data + N, rightmost_edges.begin());
}
static void histogram_select_outer_bin_edges_impl(const Tensor& input, const int64_t N,
std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "histogramdd", [&]() {
infer_bin_edges_from_input<scalar_t>(input, N, leftmost_edges, rightmost_edges);
});
}
} // namespace
REGISTER_DISPATCH(histogramdd_stub, &histogramdd_kernel_impl);
REGISTER_DISPATCH(histogramdd_linear_stub, &histogramdd_linear_kernel_impl);
REGISTER_DISPATCH(histogram_select_outer_bin_edges_stub, &histogram_select_outer_bin_edges_impl);
} // namespace at::native

View File

@ -8,6 +8,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/aminmax.h>
#include <ATen/ops/sum.h>
#endif
@ -396,6 +397,20 @@ static void histogramdd_linear_kernel(const Tensor& self,
}
}
static void histogram_select_outer_bin_edges_kernel(const Tensor& input,
const int64_t N,
std::vector<double>& leftmost_edges,
std::vector<double>& rightmost_edges) {
Tensor min, max;
std::tie(min, max) = at::aminmax(input, 0);
for (const auto i : c10::irange(N)) {
leftmost_edges[i] = min[i].item().to<double>();
rightmost_edges[i] = max[i].item().to<double>();
}
}
REGISTER_DISPATCH(histogramdd_stub, &histogramdd_kernel);
REGISTER_DISPATCH(histogramdd_linear_stub, &histogramdd_linear_kernel);
REGISTER_DISPATCH(histogram_select_outer_bin_edges_stub, &histogram_select_outer_bin_edges_kernel);
} // namespace at::native