mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[EZ][BE] Move array def to c10/metal/common.h
(#157746)
And use proper type aliasing instead of the weird `_ARRAY_NS` macro.
Also use `uint64_t` instead of `ulong`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157746
Approved by: https://github.com/Skylion007, https://github.com/dcci
This commit is contained in:
committed by
PyTorch MergeBot
parent
a4c7e7f983
commit
0b73f7c871
@ -1,12 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef __METAL__
|
||||
#include <array>
|
||||
#define _ARRAY_NS std
|
||||
#else
|
||||
#include <metal_array>
|
||||
#define _ARRAY_NS metal
|
||||
#endif
|
||||
#include <c10/metal/common.h>
|
||||
|
||||
// N is the maximum allowed number of dimensions in the input and outputs. The
|
||||
// maximum allowed pooling dimensions is N-2, because the input may have up to 2
|
||||
@ -16,25 +9,25 @@ template <unsigned N = 5>
|
||||
struct PoolingParams {
|
||||
int32_t dims;
|
||||
int32_t pooling_dims;
|
||||
_ARRAY_NS::array<int64_t, N> input_sizes;
|
||||
_ARRAY_NS::array<int64_t, N> input_strides;
|
||||
_ARRAY_NS::array<int64_t, N> output_sizes;
|
||||
_ARRAY_NS::array<int64_t, N> output_strides;
|
||||
_ARRAY_NS::array<int64_t, N> indices_sizes;
|
||||
_ARRAY_NS::array<int64_t, N> indices_strides;
|
||||
_ARRAY_NS::array<int64_t, N - 2> kernel_size;
|
||||
_ARRAY_NS::array<int64_t, N - 2> stride;
|
||||
_ARRAY_NS::array<int64_t, N - 2> padding;
|
||||
_ARRAY_NS::array<int64_t, N - 2> dilation;
|
||||
::c10::metal::array<int64_t, N> input_sizes;
|
||||
::c10::metal::array<int64_t, N> input_strides;
|
||||
::c10::metal::array<int64_t, N> output_sizes;
|
||||
::c10::metal::array<int64_t, N> output_strides;
|
||||
::c10::metal::array<int64_t, N> indices_sizes;
|
||||
::c10::metal::array<int64_t, N> indices_strides;
|
||||
::c10::metal::array<int64_t, N - 2> kernel_size;
|
||||
::c10::metal::array<int64_t, N - 2> stride;
|
||||
::c10::metal::array<int64_t, N - 2> padding;
|
||||
::c10::metal::array<int64_t, N - 2> dilation;
|
||||
};
|
||||
|
||||
template <unsigned N = 5>
|
||||
struct PoolingBackwardParams {
|
||||
int32_t dims;
|
||||
int32_t pooling_dims;
|
||||
_ARRAY_NS::array<int64_t, N> grad_input_sizes;
|
||||
_ARRAY_NS::array<int64_t, N> grad_input_strides;
|
||||
_ARRAY_NS::array<int64_t, N> grad_output_sizes;
|
||||
_ARRAY_NS::array<int64_t, N> grad_output_strides;
|
||||
_ARRAY_NS::array<int64_t, N> indices_strides;
|
||||
::c10::metal::array<int64_t, N> grad_input_sizes;
|
||||
::c10::metal::array<int64_t, N> grad_input_strides;
|
||||
::c10::metal::array<int64_t, N> grad_output_sizes;
|
||||
::c10::metal::array<int64_t, N> grad_output_strides;
|
||||
::c10::metal::array<int64_t, N> indices_strides;
|
||||
};
|
||||
|
@ -1,20 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef __METAL__
|
||||
#include <array>
|
||||
using ulong = unsigned long;
|
||||
#define _ARRAY_NS std
|
||||
#else
|
||||
#include <metal_array>
|
||||
#define _ARRAY_NS metal
|
||||
#endif
|
||||
#include <c10/metal/common.h>
|
||||
|
||||
template <unsigned N = 5>
|
||||
struct UpsampleParams {
|
||||
_ARRAY_NS::array<ulong, N> input_strides;
|
||||
_ARRAY_NS::array<ulong, N> input_sizes;
|
||||
_ARRAY_NS::array<ulong, N> output_strides;
|
||||
_ARRAY_NS::array<ulong, N> output_sizes;
|
||||
_ARRAY_NS::array<float, N - 2> scales;
|
||||
::c10::metal::array<uint64_t, N> input_strides;
|
||||
::c10::metal::array<uint64_t, N> input_sizes;
|
||||
::c10::metal::array<uint64_t, N> output_strides;
|
||||
::c10::metal::array<uint64_t, N> output_sizes;
|
||||
::c10::metal::array<float, N - 2> scales;
|
||||
bool align_corners;
|
||||
};
|
||||
|
@ -66,7 +66,7 @@ template <typename scalar_t>
|
||||
scalar_t upsample_get_value_bounded(
|
||||
constant scalar_t* data,
|
||||
uint3 dim,
|
||||
array<ulong, 5> strides,
|
||||
::metal::array<ulong, 5> strides,
|
||||
uint n,
|
||||
uint c,
|
||||
uint z,
|
||||
@ -131,7 +131,7 @@ template <typename scalar_t>
|
||||
void upsample_increment_value_bounded(
|
||||
device AtomicType_t<scalar_t>* data,
|
||||
uint3 dim,
|
||||
array<ulong, 5> strides,
|
||||
::metal::array<ulong, 5> strides,
|
||||
uint n,
|
||||
uint c,
|
||||
uint z,
|
||||
|
@ -2,8 +2,10 @@
|
||||
// Set of global constants that could be shareable between CPU and Metal code
|
||||
|
||||
#ifdef __METAL__
|
||||
#include <metal_array>
|
||||
#define C10_METAL_CONSTEXPR constant constexpr
|
||||
#else
|
||||
#include <array>
|
||||
#define C10_METAL_CONSTEXPR constexpr
|
||||
#endif
|
||||
|
||||
@ -37,6 +39,13 @@
|
||||
namespace c10 {
|
||||
namespace metal {
|
||||
C10_METAL_CONSTEXPR unsigned max_ndim = 16;
|
||||
#ifdef __METAL__
|
||||
template <typename T, unsigned N>
|
||||
using array = ::metal::array<T, N>;
|
||||
#else
|
||||
template <typename T, unsigned N>
|
||||
using array = std::array<T, N>;
|
||||
#endif
|
||||
|
||||
enum class ScalarType {
|
||||
#define _DEFINE_ENUM_VAL_(_v, _n) _v = _n,
|
||||
|
Reference in New Issue
Block a user