mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Dispatch outer bin edges selection function (#101792)
Dispatch the selection function to prevent using `is_mps()` in `Histogram.cpp`. <!-- copilot:summary --> ### <samp>🤖 Generated by Copilot at b329a02</samp> This pull request refactors and implements the logic for inferring the bin edges of histograms from the input tensor for different device types. It introduces a dispatch stub `histogram_select_outer_bin_edges_stub` and moves the device-specific code to separate files, such as `HistogramKernel.cpp` and `HistogramKernel.mm`. This improves the modularity and readability of the histogram functions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/101792 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
217a8b4697
commit
99e87bb6a0
@ -65,6 +65,7 @@ namespace at { namespace native {
|
||||
|
||||
DEFINE_DISPATCH(histogramdd_stub);
|
||||
DEFINE_DISPATCH(histogramdd_linear_stub);
|
||||
DEFINE_DISPATCH(histogram_select_outer_bin_edges_stub);
|
||||
|
||||
namespace {
|
||||
|
||||
@ -153,22 +154,6 @@ void histogramdd_prepare_out(const Tensor& input, TensorList bins,
|
||||
histogramdd_prepare_out(input, bin_ct, hist, bin_edges);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
void infer_bin_edges_from_input(const Tensor& input, const int64_t N,
|
||||
std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
|
||||
// Calls aminmax on input with dim=0, reducing all but the innermost dimension of input.
|
||||
Tensor min, max;
|
||||
std::tie(min, max) = aminmax(input, 0);
|
||||
|
||||
TORCH_INTERNAL_ASSERT(min.is_contiguous() && max.is_contiguous());
|
||||
|
||||
const scalar_t *min_data = min.data_ptr<scalar_t>();
|
||||
std::copy(min_data, min_data + N, leftmost_edges.begin());
|
||||
|
||||
const scalar_t *max_data = max.data_ptr<scalar_t>();
|
||||
std::copy(max_data, max_data + N, rightmost_edges.begin());
|
||||
}
|
||||
|
||||
/* Determines the outermost bin edges. For simplicity when calling into aminmax,
|
||||
* assumes that input has already been reshaped to (M, N).
|
||||
*/
|
||||
@ -192,19 +177,8 @@ select_outer_bin_edges(const Tensor& input, c10::optional<c10::ArrayRef<double>>
|
||||
}
|
||||
} else if (input.numel() > 0) {
|
||||
// non-empty input
|
||||
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "histogramdd", [&]() {
|
||||
if (input.is_mps()) {
|
||||
Tensor min, max;
|
||||
std::tie(min, max) = at::aminmax(input, 0);
|
||||
|
||||
for (const auto i : c10::irange(N)) {
|
||||
leftmost_edges[i] = min[i].item().to<scalar_t>();
|
||||
rightmost_edges[i] = max[i].item().to<scalar_t>();
|
||||
}
|
||||
} else {
|
||||
infer_bin_edges_from_input<scalar_t>(input, N, leftmost_edges, rightmost_edges);
|
||||
}
|
||||
});
|
||||
histogram_select_outer_bin_edges_stub(input.device().type(), input, N, leftmost_edges, rightmost_edges);
|
||||
}
|
||||
|
||||
for (const auto dim : c10::irange(N)) {
|
||||
|
@ -7,8 +7,10 @@ namespace at::native {
|
||||
|
||||
using histogramdd_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const TensorList&);
|
||||
using histogramdd_linear_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const TensorList&, bool);
|
||||
using histogram_select_outer_bin_edges_fn = void(*)(const Tensor& input, const int64_t N, std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges);
|
||||
|
||||
DECLARE_DISPATCH(histogramdd_fn, histogramdd_stub);
|
||||
DECLARE_DISPATCH(histogramdd_linear_fn, histogramdd_linear_stub);
|
||||
DECLARE_DISPATCH(histogram_select_outer_bin_edges_fn, histogram_select_outer_bin_edges_stub);
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/aminmax.h>
|
||||
#include <ATen/ops/sum.h>
|
||||
#include <ATen/ops/zeros.h>
|
||||
#include <ATen/ops/zeros_like_ops.h>
|
||||
@ -282,10 +283,33 @@ static void histogramdd_linear_kernel_impl(const Tensor& self, const c10::option
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
void infer_bin_edges_from_input(const Tensor& input, const int64_t N,
|
||||
std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
|
||||
// Calls aminmax on input with dim=0, reducing all but the innermost dimension of input.
|
||||
Tensor min, max;
|
||||
std::tie(min, max) = aminmax(input, 0);
|
||||
|
||||
TORCH_INTERNAL_ASSERT(min.is_contiguous() && max.is_contiguous());
|
||||
|
||||
const scalar_t *min_data = min.data_ptr<scalar_t>();
|
||||
std::copy(min_data, min_data + N, leftmost_edges.begin());
|
||||
|
||||
const scalar_t *max_data = max.data_ptr<scalar_t>();
|
||||
std::copy(max_data, max_data + N, rightmost_edges.begin());
|
||||
}
|
||||
|
||||
static void histogram_select_outer_bin_edges_impl(const Tensor& input, const int64_t N,
|
||||
std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
|
||||
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "histogramdd", [&]() {
|
||||
infer_bin_edges_from_input<scalar_t>(input, N, leftmost_edges, rightmost_edges);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
REGISTER_DISPATCH(histogramdd_stub, &histogramdd_kernel_impl);
|
||||
|
||||
REGISTER_DISPATCH(histogramdd_linear_stub, &histogramdd_linear_kernel_impl);
|
||||
REGISTER_DISPATCH(histogram_select_outer_bin_edges_stub, &histogram_select_outer_bin_edges_impl);
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/aminmax.h>
|
||||
#include <ATen/ops/sum.h>
|
||||
#endif
|
||||
|
||||
@ -396,6 +397,20 @@ static void histogramdd_linear_kernel(const Tensor& self,
|
||||
}
|
||||
}
|
||||
|
||||
static void histogram_select_outer_bin_edges_kernel(const Tensor& input,
|
||||
const int64_t N,
|
||||
std::vector<double>& leftmost_edges,
|
||||
std::vector<double>& rightmost_edges) {
|
||||
Tensor min, max;
|
||||
std::tie(min, max) = at::aminmax(input, 0);
|
||||
|
||||
for (const auto i : c10::irange(N)) {
|
||||
leftmost_edges[i] = min[i].item().to<double>();
|
||||
rightmost_edges[i] = max[i].item().to<double>();
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(histogramdd_stub, &histogramdd_kernel);
|
||||
REGISTER_DISPATCH(histogramdd_linear_stub, &histogramdd_linear_kernel);
|
||||
REGISTER_DISPATCH(histogram_select_outer_bin_edges_stub, &histogram_select_outer_bin_edges_kernel);
|
||||
} // namespace at::native
|
||||
|
Reference in New Issue
Block a user