Compare commits

...

2 Commits

Author SHA1 Message Date
74b5dc9316 Will that help 2024-12-13 20:46:33 -08:00
b8f534f021 [MPS] Add col2im shader
As a native path for a frequently requested torch.unfold
2024-12-13 20:46:33 -08:00
4 changed files with 178 additions and 13 deletions

View File

@@ -91,7 +91,6 @@ TORCH_LIBRARY_IMPL(aten, MPS, m) {
m.impl("embedding_renorm_", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("linalg_svd", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("linalg_svd.U", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("col2im", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("_slow_conv2d_forward", slow_conv2d_forward_mps);
m.impl("upsample_nearest3d.vec", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
}

View File

@@ -1,5 +1,8 @@
#include <metal_stdlib>
// Heavily inspired by
// https://github.com/pytorch/pytorch/blob/09519eb19/aten/src/ATen/native/cuda/im2col.cuh#L51
template <typename T>
void im2col_kernel(
constant T* input,
@@ -60,8 +63,72 @@ kernel void im2col(
output_strides.y);
}
#define INSTANTIATE_IM2COL(DTYPE) \
// col2im ("fold"): the inverse of im2col. Each thread computes ONE output
// pixel by gathering and summing every entry of the column matrix that an
// im2col pass would have copied this pixel into (overlapping sliding windows
// accumulate). Mirrors the CUDA col2im referenced at the top of this file.
template <typename T>
kernel void col2im(
constant T* inputData [[buffer(0)]],
device T* outputData [[buffer(1)]],
constant uint4& kernel_dilation [[buffer(2)]],
constant int4& padding_stride [[buffer(3)]],
constant ulong4& input_strides [[buffer(4)]],
constant ulong4& output_strides [[buffer(5)]],
constant long4& input_sizes [[buffer(6)]],
uint3 thread_index [[thread_position_in_grid]]) {
// thread_index is (output_length, output_channels, input_batch)
const auto N = thread_index.z;
const auto C = thread_index.y;
const auto L = thread_index.x;
// NOTE: the host side packs the output spatial extents into otherwise
// unused lanes: width rides in input_strides.w, height in input_sizes.w
// (see col2im_out_mps_template, which builds these arrays).
const auto output_width = input_strides.w;
const auto output_height = input_sizes.w;
// Unpack the packed vec4 kernel parameters into named scalars.
const int64_t pad_height = padding_stride.y;
const int64_t pad_width = padding_stride.x;
const int64_t dilation_height = kernel_dilation.w;
const int64_t dilation_width = kernel_dilation.z;
const int64_t kernel_width = kernel_dilation.x;
const int64_t kernel_height = kernel_dilation.y;
const int64_t stride_width = padding_stride.z;
const int64_t stride_height = padding_stride.w;
// Number of sliding-window positions per spatial axis (standard
// convolution output-size formula).
const int64_t height_col = (output_height + 2 * pad_height - (dilation_height * (kernel_height - 1) + 1)) / stride_height + 1;
const int64_t width_col = (output_width + 2 * pad_width - (dilation_width * (kernel_width - 1) + 1)) / stride_width + 1;
T val = static_cast<T>(0);
// Image-space coordinate of this thread's pixel, shifted into the padded
// frame so window-coverage tests below are simple divisions.
const int64_t w_im = L % output_width + pad_width;
const int64_t h_im = L / output_width + pad_height;
const int64_t kernel_extent_w = (kernel_width - 1) * dilation_width + 1;
const int64_t kernel_extent_h = (kernel_height - 1) * dilation_height + 1;
// Range of column positions (h_col, w_col) whose kernel footprint can
// cover (h_im, w_im); everything outside contributes nothing.
const int64_t w_col_start = (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_width + 1;
const int64_t w_col_end = metal::min(w_im / stride_width + 1, width_col);
const int64_t h_col_start = (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_height + 1;
const int64_t h_col_end = metal::min(h_im / stride_height + 1, height_col);
for (int64_t h_col = h_col_start; h_col < h_col_end; h_col += 1) {
for (int64_t w_col = w_col_start; w_col < w_col_end; w_col += 1) {
int64_t h_k = (h_im - h_col * stride_height);
int64_t w_k = (w_im - w_col * stride_width);
// Only kernel taps that land exactly on the dilation grid touch this
// pixel; off-grid offsets are skipped.
if (h_k % dilation_height == 0 && w_k % dilation_width == 0) {
h_k /= dilation_height;
w_k /= dilation_width;
// Row-major index into the (C*kH*kW, height_col*width_col) column
// matrix for channel C, tap (h_k, w_k), position (h_col, w_col).
int64_t data_col_index =
(((C * kernel_height + h_k) * kernel_width + w_k) * height_col +
h_col) *
width_col +
w_col;
val += inputData[data_col_index];
}
}
}
// output_strides.w / .z are the batch / channel strides; L is added raw,
// which assumes the two spatial dims are contiguous — TODO confirm the
// host always hands us a contiguous (or resized-fresh) output here.
outputData[N*output_strides.w + C * output_strides.z + L ] = val;
}
#define INSTANTIATE_IM2COL_COL2IM(DTYPE) \
template [[host_name("im2col_" #DTYPE)]] kernel void im2col<DTYPE>( \
constant DTYPE * inputData [[buffer(0)]], \
device DTYPE * outputData [[buffer(1)]], \
constant uint4 & kernel_dilation [[buffer(2)]], \
constant int4 & padding_stride [[buffer(3)]], \
constant ulong4 & input_strides [[buffer(4)]], \
constant ulong4 & output_strides [[buffer(5)]], \
constant long4 & input_sizes [[buffer(6)]], \
uint3 thread_index [[thread_position_in_grid]]); \
template [[host_name("col2im_" #DTYPE)]] kernel void col2im<DTYPE>( \
constant DTYPE * inputData [[buffer(0)]], \
device DTYPE * outputData [[buffer(1)]], \
constant uint4 & kernel_dilation [[buffer(2)]], \
@@ -71,11 +138,11 @@ kernel void im2col(
constant long4 & input_sizes [[buffer(6)]], \
uint3 thread_index [[thread_position_in_grid]])
INSTANTIATE_IM2COL(bool);
INSTANTIATE_IM2COL(float);
INSTANTIATE_IM2COL(float2);
INSTANTIATE_IM2COL(half);
INSTANTIATE_IM2COL(half2);
INSTANTIATE_IM2COL_COL2IM(bool);
INSTANTIATE_IM2COL_COL2IM(float);
INSTANTIATE_IM2COL_COL2IM(float2);
INSTANTIATE_IM2COL_COL2IM(half);
INSTANTIATE_IM2COL_COL2IM(half2);
#if __METAL_VERSION__ >= 310
INSTANTIATE_IM2COL(bfloat);
INSTANTIATE_IM2COL_COL2IM(bfloat);
#endif

View File

@@ -27,11 +27,8 @@ static void im2col_out_mps_template(Tensor& output,
IntArrayRef padding,
IntArrayRef stride) {
TORCH_CHECK(kernel_size.size() == 2, "It is expected kernel_size equals to 2, but got size ", kernel_size.size());
TORCH_CHECK(dilation.size() == 2, "It is expected dilation equals to 2, but got size ", dilation.size());
TORCH_CHECK(padding.size() == 2, "It is expected padding equals to 2, but got size ", padding.size());
TORCH_CHECK(stride.size() == 2, "It is expected stride equals to 2, but got size ", stride.size());
const auto kernel_height = kernel_size[0];
@@ -65,7 +62,6 @@ static void im2col_out_mps_template(Tensor& output,
output.resize_({batch_size, n_output_plane, output_length});
auto stream = getCurrentMPSStream();
auto device = MPSDevice::getInstance()->device();
auto im2colPSO = lib.getPipelineStateForFunc("im2col_" + mps::scalarToMetalTypeString(input));
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
@@ -79,7 +75,7 @@ static void im2col_out_mps_template(Tensor& output,
static_cast<int32_t>(stride_height)};
std::array<int64_t, 4> input_sizes = {input_width, input_height, n_input_plane, batch_size};
std::array<int64_t, 4> input_strides = {input.stride(3), input.stride(2), input.stride(1), input.stride(0)};
std::array<int64_t, 4> output_strides = {output.stride(2), output.stride(1), output.stride(0), output_width};
std::array<int64_t, 4> output_strides = {output.stride(3), output.stride(2), output.stride(1), output.stride(0)};
getMPSProfiler().beginProfileKernel(im2colPSO, "im2col", {input, output});
auto computeEncoder = stream->commandEncoder();
[computeEncoder setComputePipelineState:im2colPSO];
@@ -95,7 +91,84 @@ static void im2col_out_mps_template(Tensor& output,
}
}
// col2im (torch.nn.functional.fold) on MPS: scatter-accumulates the sliding
// blocks stored in the column matrix `input_` back into an image of spatial
// size `output_size`, writing the result into `output`.
//
// input_ is (N, C*kH*kW, L) — or unbatched (C*kH*kW, L), in which case a
// batch dim is added for the kernel launch and squeezed off again at the end.
// output is resized to (N, C, output_height, output_width).
//
// Fixes vs. previous revision: the debug-only GPU capture
// (startCapture/stopCapture) left over from development is removed and the
// profiler hooks are enabled, matching im2col_out_mps_template above; the
// unused local `input_batch_stride` is dropped.
void col2im_out_mps_template(Tensor& output,
                             const Tensor& input_,
                             IntArrayRef output_size,
                             IntArrayRef kernel_size,
                             IntArrayRef dilation,
                             IntArrayRef padding,
                             IntArrayRef stride) {
  TensorArg input_arg{input_, "input", 1};
  TensorArg output_arg{output, "output", 2};
  checkAllSameGPU(__func__, {input_arg, output_arg});
  TORCH_CHECK(output_size.size() == 2, "It is expected output_size equals to 2, but got size ", output_size.size());
  TORCH_CHECK(kernel_size.size() == 2, "It is expected kernel_size equals to 2, but got size ", kernel_size.size());
  TORCH_CHECK(dilation.size() == 2, "It is expected dilation equals to 2, but got size ", dilation.size());
  TORCH_CHECK(padding.size() == 2, "It is expected padding equals to 2, but got size ", padding.size());
  TORCH_CHECK(stride.size() == 2, "It is expected stride equals to 2, but got size ", stride.size());
  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];
  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  Tensor input = input_.contiguous();
  bool batched_input = true;
  if (input.dim() == 2) {
    // Force batch
    batched_input = false;
    input = input.unsqueeze(0);
  }
  int64_t batch_size = input.size(0);
  int64_t n_input_plane = input.size(1);
  int64_t input_width = input.size(2);
  // Each group of kH*kW input planes folds into one output channel.
  int64_t n_output_plane = n_input_plane / (kernel_width * kernel_height);
  output.resize_({batch_size, n_output_plane, output_height, output_width});
  auto stream = getCurrentMPSStream();
  auto col2imPSO = lib.getPipelineStateForFunc("col2im_" + mps::scalarToMetalTypeString(input));
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      std::array<int32_t, 4> kernel_dilation = {static_cast<int32_t>(kernel_width),
                                                static_cast<int32_t>(kernel_height),
                                                static_cast<int32_t>(dilation_width),
                                                static_cast<int32_t>(dilation_height)};
      std::array<int32_t, 4> padding_stride = {static_cast<int32_t>(pad_width),
                                               static_cast<int32_t>(pad_height),
                                               static_cast<int32_t>(stride_width),
                                               static_cast<int32_t>(stride_height)};
      // The shader reads output_height from input_sizes.w and output_width
      // from input_strides.w — those lanes are repurposed to keep the same
      // buffer layout as the im2col kernel.
      std::array<int64_t, 4> input_sizes = {output_height, input_width, n_input_plane, batch_size};
      std::array<int64_t, 4> input_strides = {output_width, input.stride(2), input.stride(1), input.stride(0)};
      std::array<int64_t, 4> output_strides = {output.stride(3), output.stride(2), output.stride(1), output.stride(0)};
      getMPSProfiler().beginProfileKernel(col2imPSO, "col2im", {input, output});
      auto computeEncoder = stream->commandEncoder();
      [computeEncoder setComputePipelineState:col2imPSO];
      mtl_setArgs(
          computeEncoder, input, output, kernel_dilation, padding_stride, input_strides, output_strides, input_sizes);
      // One thread per output pixel: grid is (H*W, C_out, N).
      [computeEncoder dispatchThreads:MTLSizeMake(output_width * output_height, n_output_plane, batch_size)
                threadsPerThreadgroup:MTLSizeMake(64, 1, 1)];
      getMPSProfiler().endProfileKernel(col2imPSO);
    }
  });
  if (!batched_input) {
    output = output.squeeze(0);
  }
}
} // anonymous namespace
Tensor& im2col_out_mps(const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef dilation,
@@ -115,4 +188,28 @@ Tensor im2col_mps(const Tensor& input,
im2col_out_mps_template(output, input, kernel_size, dilation, padding, stride);
return output;
}
// out-variant entry point registered for aten::col2im.out on MPS (see the
// native_functions.yaml hunk below); forwards to the shared template, which
// resizes `output` itself, and returns it for chaining.
Tensor& col2im_out_mps(const Tensor& input,
IntArrayRef output_size,
IntArrayRef kernel_size,
IntArrayRef dilation,
IntArrayRef padding,
IntArrayRef stride,
Tensor& output) {
col2im_out_mps_template(output, input, output_size, kernel_size, dilation, padding, stride);
return output;
}
// Functional entry point registered for aten::col2im on MPS. empty_like only
// supplies a tensor with the right dtype/device — the template resize_()s it
// to the real (N, C, H, W) result shape before the kernel runs.
Tensor col2im_mps(const Tensor& input,
IntArrayRef output_size,
IntArrayRef kernel_size,
IntArrayRef dilation,
IntArrayRef padding,
IntArrayRef stride) {
Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
col2im_out_mps_template(output, input, output_size, kernel_size, dilation, padding, stride);
return output;
}
} // namespace at::native

View File

@@ -13109,12 +13109,14 @@
dispatch:
CPU: col2im_out_cpu
CUDA: col2im_out_cuda
MPS: col2im_out_mps
- func: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor
python_module: nn
dispatch:
CPU: col2im_cpu
CUDA: col2im_cuda
MPS: col2im_mps
tags: core
- func: column_stack(Tensor[] tensors) -> Tensor