[EZ][BE] Move array def to c10/metal/common.h (#157746)

And use proper type aliasing instead of weird _ARRAY_NS

Also use `uint64_t` instead of `ulong`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157746
Approved by: https://github.com/Skylion007, https://github.com/dcci
This commit is contained in:
Nikita Shulga
2025-07-07 17:14:10 -07:00
committed by PyTorch MergeBot
parent a4c7e7f983
commit 0b73f7c871
4 changed files with 33 additions and 39 deletions

View File

@ -1,12 +1,5 @@
#pragma once
#ifndef __METAL__
#include <array>
#define _ARRAY_NS std
#else
#include <metal_array>
#define _ARRAY_NS metal
#endif
#include <c10/metal/common.h>
// N is the maximum allowed number of dimensions in the input and outputs. The
// maximum allowed pooling dimensions is N-2, because the input may have up to 2
@ -16,25 +9,25 @@ template <unsigned N = 5>
struct PoolingParams {
int32_t dims;
int32_t pooling_dims;
_ARRAY_NS::array<int64_t, N> input_sizes;
_ARRAY_NS::array<int64_t, N> input_strides;
_ARRAY_NS::array<int64_t, N> output_sizes;
_ARRAY_NS::array<int64_t, N> output_strides;
_ARRAY_NS::array<int64_t, N> indices_sizes;
_ARRAY_NS::array<int64_t, N> indices_strides;
_ARRAY_NS::array<int64_t, N - 2> kernel_size;
_ARRAY_NS::array<int64_t, N - 2> stride;
_ARRAY_NS::array<int64_t, N - 2> padding;
_ARRAY_NS::array<int64_t, N - 2> dilation;
::c10::metal::array<int64_t, N> input_sizes;
::c10::metal::array<int64_t, N> input_strides;
::c10::metal::array<int64_t, N> output_sizes;
::c10::metal::array<int64_t, N> output_strides;
::c10::metal::array<int64_t, N> indices_sizes;
::c10::metal::array<int64_t, N> indices_strides;
::c10::metal::array<int64_t, N - 2> kernel_size;
::c10::metal::array<int64_t, N - 2> stride;
::c10::metal::array<int64_t, N - 2> padding;
::c10::metal::array<int64_t, N - 2> dilation;
};
template <unsigned N = 5>
struct PoolingBackwardParams {
int32_t dims;
int32_t pooling_dims;
_ARRAY_NS::array<int64_t, N> grad_input_sizes;
_ARRAY_NS::array<int64_t, N> grad_input_strides;
_ARRAY_NS::array<int64_t, N> grad_output_sizes;
_ARRAY_NS::array<int64_t, N> grad_output_strides;
_ARRAY_NS::array<int64_t, N> indices_strides;
::c10::metal::array<int64_t, N> grad_input_sizes;
::c10::metal::array<int64_t, N> grad_input_strides;
::c10::metal::array<int64_t, N> grad_output_sizes;
::c10::metal::array<int64_t, N> grad_output_strides;
::c10::metal::array<int64_t, N> indices_strides;
};

View File

@ -1,20 +1,12 @@
#pragma once
#ifndef __METAL__
#include <array>
using ulong = unsigned long;
#define _ARRAY_NS std
#else
#include <metal_array>
#define _ARRAY_NS metal
#endif
#include <c10/metal/common.h>
template <unsigned N = 5>
struct UpsampleParams {
_ARRAY_NS::array<ulong, N> input_strides;
_ARRAY_NS::array<ulong, N> input_sizes;
_ARRAY_NS::array<ulong, N> output_strides;
_ARRAY_NS::array<ulong, N> output_sizes;
_ARRAY_NS::array<float, N - 2> scales;
::c10::metal::array<uint64_t, N> input_strides;
::c10::metal::array<uint64_t, N> input_sizes;
::c10::metal::array<uint64_t, N> output_strides;
::c10::metal::array<uint64_t, N> output_sizes;
::c10::metal::array<float, N - 2> scales;
bool align_corners;
};

View File

@ -66,7 +66,7 @@ template <typename scalar_t>
scalar_t upsample_get_value_bounded(
constant scalar_t* data,
uint3 dim,
array<ulong, 5> strides,
::metal::array<ulong, 5> strides,
uint n,
uint c,
uint z,
@ -131,7 +131,7 @@ template <typename scalar_t>
void upsample_increment_value_bounded(
device AtomicType_t<scalar_t>* data,
uint3 dim,
array<ulong, 5> strides,
::metal::array<ulong, 5> strides,
uint n,
uint c,
uint z,

View File

@ -2,8 +2,10 @@
// Set of global constants that could be shareable between CPU and Metal code
#ifdef __METAL__
#include <metal_array>
#define C10_METAL_CONSTEXPR constant constexpr
#else
#include <array>
#define C10_METAL_CONSTEXPR constexpr
#endif
@ -37,6 +39,13 @@
namespace c10 {
namespace metal {
C10_METAL_CONSTEXPR unsigned max_ndim = 16;
#ifdef __METAL__
template <typename T, unsigned N>
using array = ::metal::array<T, N>;
#else
template <typename T, unsigned N>
using array = std::array<T, N>;
#endif
enum class ScalarType {
#define _DEFINE_ENUM_VAL_(_v, _n) _v = _n,