[MPS] Support large tensors in torch.cat (#164416)

Fixes #164415

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164416
Approved by: https://github.com/malfet
Committed by: PyTorch MergeBot
Parent commit: 684df93975
This commit: 83cbba8759
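For context, the call below is the kind of concatenation this commit enables on MPS: previously torch.cat could not handle an input (or output) whose dimension or viewed storage exceeded 2**31 - 1 elements, the limit MPSGraph can address. A minimal illustrative repro, modeled on the new test_cat_large_tensor test in this diff (shapes and dtype are chosen for illustration and are not part of the diff):

import torch

# int8 keeps the footprint near 2 GiB per tensor; a Mac with enough free
# memory is assumed.
a = torch.ones((1, (1 << 31) + 11, 1), dtype=torch.int8, device="mps")
b = torch.ones((1, 100, 1), dtype=torch.int8, device="mps")

# With this commit, inputs this large are routed to the new cat_large Metal
# kernel instead of MPSGraph.
out = torch.cat([a, b], dim=1)
print(out.shape)  # torch.Size([1, 2147483759, 1])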
@@ -99,6 +99,9 @@ Tensor getTensorView(const Tensor& t, MPSShape* shape);
MPSShape* getMPSShape(const TensorBase& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);

// Determines whether a tensor is too large to use MPSGraph
bool isTooLargeForMPSGraph(const Tensor& tensor, bool useMPSStridedAPI = true);

static inline id<MTLBuffer> getMTLBufferStorage(const TensorBase& tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

@@ -439,6 +439,22 @@ static void check_mps_shape(MPSShape* shape) {
  }
}

bool isTooLargeForMPSGraph(const Tensor& tensor, bool useMPSStridedAPI) {
  static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
  if ((!tensor.is_contiguous() || tensor.storage_offset()) && useMPSStridedAPI && is_macOS_15_0_or_newer) {
    auto storage_numel = tensor.storage().nbytes() / tensor.element_size() - tensor.storage_offset();
    if (storage_numel > std::numeric_limits<int32_t>::max()) {
      return true;
    }
  }
  for (auto size : tensor.sizes()) {
    if (size > std::numeric_limits<int32_t>::max()) {
      return true;
    }
  }
  return false;
}

MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes, MPSShape* strides) {
  id<MTLBuffer> srcBuf = getMTLBufferStorage(t);
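For intuition, the new isTooLargeForMPSGraph check above can be paraphrased in Python roughly as follows. This is an illustration only, not a PyTorch API, and the macOS-version gate on the strided path is omitted:

import torch

INT32_MAX = 2**31 - 1

def too_large_for_mps_graph(t: torch.Tensor, use_strided_api: bool = True) -> bool:
    # Strided or offset views go through an API that can only address
    # INT32_MAX elements of the underlying storage.
    if (not t.is_contiguous() or t.storage_offset()) and use_strided_api:
        storage_numel = t.untyped_storage().nbytes() // t.element_size() - t.storage_offset()
        if storage_numel > INT32_MAX:
            return True
    # Any single dimension larger than INT32_MAX is also unsupported.
    return any(size > INT32_MAX for size in t.shape)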
aten/src/ATen/native/mps/kernels/Shape.h (new file, 18 lines)
@@ -0,0 +1,18 @@
#pragma once
#include <c10/metal/common.h>

template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
struct CatLargeSharedParams {
  int32_t ndim;
  int32_t cat_dim;
  ::c10::metal::array<idx_type_t, N> output_strides;
  ::c10::metal::array<idx_type_t, N> output_sizes;
};

template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
struct CatLargeInputParams {
  idx_type_t cat_dim_offset;
  idx_type_t input_element_offset;
  ::c10::metal::array<idx_type_t, N> input_strides;
  ::c10::metal::array<idx_type_t, N> input_sizes;
};
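The two structs split the launch parameters along the same lines as the host code later in this commit: CatLargeSharedParams describes the output layout once per cat call, while a fresh CatLargeInputParams is filled for every (input, chunk) kernel launch. A rough Python dataclass rendering for orientation (illustrative only; the real fields are fixed-size c10::metal::array members of max_ndim entries):

from dataclasses import dataclass, field

@dataclass
class CatLargeSharedParamsSketch:
    # Output layout, shared by every kernel launch of one cat call.
    ndim: int = 0
    cat_dim: int = 0
    output_strides: list = field(default_factory=list)
    output_sizes: list = field(default_factory=list)

@dataclass
class CatLargeInputParamsSketch:
    # Per-input, per-chunk data: where this input lands along cat_dim and
    # which linear element the chunk starts at.
    cat_dim_offset: int = 0
    input_element_offset: int = 0
    input_strides: list = field(default_factory=list)
    input_sizes: list = field(default_factory=list)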
aten/src/ATen/native/mps/kernels/Shape.metal (new file, 82 lines)
@@ -0,0 +1,82 @@
#include <ATen/native/mps/kernels/Shape.h>
#include <c10/metal/utils.h>
#include <metal_array>
#include <metal_stdlib>

using namespace metal;
using namespace c10::metal;

template <typename T_in, typename T_out>
kernel void cat_large(
    constant T_in* input [[buffer(0)]],
    device T_out* output [[buffer(1)]],
    constant CatLargeSharedParams<>& shared_params [[buffer(2)]],
    constant CatLargeInputParams<>& input_params [[buffer(3)]],
    uint tid [[thread_position_in_grid]]) {
  auto ndim = shared_params.ndim;
  auto cat_dim = shared_params.cat_dim;
  constant auto& output_strides = shared_params.output_strides;
  constant auto& output_sizes = shared_params.output_sizes;

  auto cat_dim_offset = input_params.cat_dim_offset;
  auto input_element_offset = input_params.input_element_offset;
  constant auto& input_strides = input_params.input_strides;
  constant auto& input_sizes = input_params.input_sizes;

  auto input_element_idx = static_cast<int64_t>(tid) + input_element_offset;
  int64_t input_offset = 0;
  int64_t output_offset = 0;

  for (auto dim = ndim - 1; dim >= 0; dim--) {
    auto dim_size = input_sizes[dim];
    auto input_dim_idx = input_element_idx % dim_size;
    auto output_dim_idx =
        input_dim_idx + ((dim == cat_dim) ? cat_dim_offset : 0);

    input_offset += input_strides[dim] * input_dim_idx;
    output_offset += output_strides[dim] * output_dim_idx;

    input_element_idx = input_element_idx / dim_size;
  }

  output[output_offset] = static_cast<T_out>(input[input_offset]);
}

#define REGISTER_CAT_LARGE_OP(T_in, T_out)                           \
  template [[host_name("cat_large_" #T_in "_" #T_out)]]              \
  kernel void cat_large<T_in, T_out>(                                \
      constant T_in * input [[buffer(0)]],                           \
      device T_out * output [[buffer(1)]],                           \
      constant CatLargeSharedParams<> & shared_params [[buffer(2)]], \
      constant CatLargeInputParams<> & input_params [[buffer(3)]],   \
      uint tid [[thread_position_in_grid]]);

#define REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(T_out) \
  REGISTER_CAT_LARGE_OP(float, T_out);               \
  REGISTER_CAT_LARGE_OP(half, T_out);                \
  REGISTER_CAT_LARGE_OP(bfloat, T_out);              \
  REGISTER_CAT_LARGE_OP(int, T_out);                 \
  REGISTER_CAT_LARGE_OP(uint, T_out);                \
  REGISTER_CAT_LARGE_OP(long, T_out);                \
  REGISTER_CAT_LARGE_OP(ulong, T_out);               \
  REGISTER_CAT_LARGE_OP(short, T_out);               \
  REGISTER_CAT_LARGE_OP(ushort, T_out);              \
  REGISTER_CAT_LARGE_OP(char, T_out);                \
  REGISTER_CAT_LARGE_OP(uchar, T_out);               \
  REGISTER_CAT_LARGE_OP(bool, T_out);

REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(float);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(half);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bfloat);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(int);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uint);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(long);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ulong);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(short);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ushort);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(char);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uchar);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bool);

REGISTER_CAT_LARGE_OP(float2, float2);
REGISTER_CAT_LARGE_OP(half2, half2);
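Each thread of the kernel above handles one element: it decomposes its linear index into per-dimension indices from the innermost dimension outward, shifts the index along cat_dim by this input's offset into the output, and applies the input/output strides. The same arithmetic in plain Python (an illustrative helper, not part of the commit):

def cat_large_index_map(tid, input_element_offset, cat_dim, cat_dim_offset,
                        input_sizes, input_strides, output_strides):
    # Mirror of the offset computation in the cat_large Metal kernel (sketch).
    input_element_idx = tid + input_element_offset
    input_offset = 0
    output_offset = 0
    # Walk dimensions from innermost to outermost, exactly like the Metal loop.
    for dim in range(len(input_sizes) - 1, -1, -1):
        dim_size = input_sizes[dim]
        input_dim_idx = input_element_idx % dim_size
        output_dim_idx = input_dim_idx + (cat_dim_offset if dim == cat_dim else 0)
        input_offset += input_strides[dim] * input_dim_idx
        output_offset += output_strides[dim] * output_dim_idx
        input_element_idx //= dim_size
    return input_offset, output_offset

For example, with input_sizes = [1, 4, 1], contiguous input_strides = [4, 1, 1], output_strides = [14, 1, 1], cat_dim = 1 and cat_dim_offset = 10, thread 2 reads input element 2 and writes output element 12.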
@@ -2,9 +2,13 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/MemoryOverlap.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/TensorShape.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/kernels/Shape.h>

#include <fmt/format.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>

@@ -16,6 +20,13 @@
#endif

namespace at::native {

#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/Shape_metallib.h>
#endif

namespace mps {

// Produces a shape with the `dim` dimension set to 0.

@@ -57,6 +68,70 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in
              ")");
  }
}

// This implementation of cat is used only if one of the inputs or the output is
// too large to use MPSGraph.
// NOTE: `output` is expected to already have the correct size.
static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
  CatLargeSharedParams shared_params;

  shared_params.ndim = output.dim();
  shared_params.cat_dim = dimension;

  for (const auto dim : c10::irange(output.dim())) {
    shared_params.output_strides[dim] = output.stride(dim);
    shared_params.output_sizes[dim] = output.size(dim);
  }

  int64_t cat_dim_offset = 0;
  size_t input_idx = 0;
  MPSStream* stream = getCurrentMPSStream();

  // Launch a separate kernel for each input. This will produce some overhead,
  // but that should be relatively minimal since at least one of the inputs is
  // very large. In order to launch only one kernel to process all inputs, we
  // would have to copy all the input tensor data into a packed buffer, which
  // would not be ideal.
  for (const Tensor& input : inputs) {
    if (input.numel() == 0) {
      continue;
    }

    // Metal can only launch up to MAX_INT threads at one time. If the input has
    // more than that number of elements, launch multiple kernels with different
    // offsets into the data.
    const int64_t max_num_threads = static_cast<int64_t>(std::numeric_limits<int32_t>::max());

    for (int64_t numel_remaining = input.numel(); numel_remaining > 0; numel_remaining -= max_num_threads) {
      auto num_threads = std::min(max_num_threads, numel_remaining);
      CatLargeInputParams input_params;

      input_params.cat_dim_offset = cat_dim_offset;
      input_params.input_element_offset = input.numel() - numel_remaining;

      for (const auto dim : c10::irange(input.dim())) {
        input_params.input_strides[dim] = input.stride(dim);
        input_params.input_sizes[dim] = input.size(dim);
      }

      dispatch_sync_with_rethrow(stream->queue(), ^() {
        @autoreleasepool {
          id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
          auto pipeline_state = lib.getPipelineStateForFunc(
              fmt::format("cat_large_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output)));
          getMPSProfiler().beginProfileKernel(pipeline_state, "cat", {input});
          [computeEncoder setComputePipelineState:pipeline_state];
          mtl_setArgs(computeEncoder, input, output, shared_params, input_params);
          mtl_dispatch1DJob(computeEncoder, pipeline_state, num_threads);
          getMPSProfiler().endProfileKernel(pipeline_state);
        }
      });
    }

    cat_dim_offset += input.size(dimension);
    input_idx++;
  }
}
} // namespace mps

// topk

@@ -231,7 +306,11 @@ TORCH_IMPL_FUNC(cat_out_mps)
  // Compute size of the result in the cat dimension
  int64_t cat_dim_size = 0;
  idx = 0;
  bool has_large_tensor = false;
  for (const Tensor& tensor : materialized_inputs) {
    if (isTooLargeForMPSGraph(tensor)) {
      has_large_tensor |= true;
    }
    if (!should_skip(tensor)) {
      // TODO: Factor out `check_shape_except_dim`
      check_shape_except_dim(notSkippedTensor, tensor, dimension, idx);

@@ -249,6 +328,12 @@ TORCH_IMPL_FUNC(cat_out_mps)
    return;
  }

  has_large_tensor |= isTooLargeForMPSGraph(out);

  if (has_large_tensor) {
    return mps::cat_out_large_tensor_mps(materialized_inputs, dimension, out);
  }

  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    std::vector<MPSGraphTensor*> inputTensors_;
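Because a single Metal dispatch is capped at INT32_MAX threads, cat_out_large_tensor_mps above walks each input in chunks and tells the kernel where each chunk starts via input_element_offset. The chunking loop, restated as a small Python sketch (launch_kernel is a hypothetical stand-in for the actual Metal dispatch):

INT32_MAX = 2**31 - 1

def dispatch_in_chunks(numel, launch_kernel):
    # Sketch of the chunking loop in cat_out_large_tensor_mps.
    # launch_kernel(num_threads, element_offset) stands in for the real
    # dispatch; each chunk covers at most INT32_MAX elements.
    remaining = numel
    while remaining > 0:
        num_threads = min(INT32_MAX, remaining)
        element_offset = numel - remaining  # becomes input_element_offset
        launch_kernel(num_threads, element_offset)
        remaining -= INT32_MAX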
@@ -80,6 +80,9 @@ if not torch.backends.mps.is_available():

total_memory = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"]))

MPS_UNSUPPORTED_TYPES = [torch.double, torch.cdouble]
MPS_DTYPES = [t for t in get_all_dtypes() if t not in MPS_UNSUPPORTED_TYPES]

# Determine whether to enable MPS memory leak check (uses same code as CUDA).
TEST_MPS_MEM_LEAK_CHECK = os.getenv('PYTORCH_TEST_MPS_MEM_LEAK_CHECK', '0') == '1'

@@ -3637,6 +3640,70 @@ class TestMPS(TestCaseMPS):
        # TODO: enable memory format test
        # self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous())

    # Skip if a test needs more memory than the system has.
    def _skip_if_exceeds_total_memory(self, required_memory):
        if total_memory < required_memory:
            self.skipTest(
                f"Needs {required_memory / (1024**3):0.01f} GiB RAM, "
                f"but only {total_memory / (1024**3):0.01f} GiB is available.")

    @parametrize("dtype", MPS_DTYPES)
    def test_cat_large_tensor(self, dtype):
        a_shape = (1, 11 + (1 << 31), 1)
        b_shape = (1, 100, 1)

        # Assume up to 1% extra overhead memory might be required.
        required_memory = 1.01 * (math.prod(a_shape) + math.prod(a_shape)) * dtype.itemsize
        self._skip_if_exceeds_total_memory(required_memory)

        a_cpu = make_tensor((1,), dtype=dtype, device='cpu').expand(a_shape)
        b_cpu = make_tensor(b_shape, dtype=dtype, device='cpu')
        r_cpu = torch.cat([a_cpu, b_cpu], dim=1)

        # Pick a subset of output elements to compare, because comparing all of
        # them takes too long.
        rand_indices = torch.randint(0, a_cpu.shape[1] + b_cpu.shape[1], (10_000,))
        r_cpu_part0 = r_cpu[:, rand_indices, :].clone()
        r_cpu_part1 = r_cpu[:, -200:, :].clone()
        r_cpu_part2 = r_cpu[:, :200, :].clone()

        # Delete the CPU result to free up memory for the MPS run.
        del r_cpu

        a_mps = (
            torch.empty(0, dtype=dtype, device='mps')
            .set_(a_cpu.untyped_storage().mps())
            .as_strided(size=a_cpu.size(), stride=a_cpu.stride())
        )
        b_mps = b_cpu.to('mps')

        try:
            r_mps = torch.cat([a_mps, b_mps], dim=1)

        except RuntimeError as e:
            if "Invalid buffer size" in str(e):
                self.skipTest(f"Exceeds max buffer size for MPS: {str(e)}.")
            raise e

        self.assertEqual(r_mps[:, rand_indices, :], r_cpu_part0)
        self.assertEqual(r_mps[:, -200:, :], r_cpu_part1)
        self.assertEqual(r_mps[:, :200, :], r_cpu_part2)

    def test_large_tensor_to_string(self):
        shape = (2, 1 << 31)

        # Assume up to 1% extra overhead memory might be required.
        required_memory = 1.01 * 2 * math.prod(shape)
        self._skip_if_exceeds_total_memory(required_memory)

        self.assertEqual(
            str(torch.ones(shape, dtype=torch.int8, device='mps')),
            (
                "tensor([[1, 1, 1, ..., 1, 1, 1],\n"
                "        [1, 1, 1, ..., 1, 1, 1]], device='mps:0', dtype=torch.int8)"
            ),
        )

    # See https://github.com/pytorch/pytorch/issues/152701
    def test_jacfwd_cat(self):
        def fn(x, y):

@@ -12167,9 +12234,6 @@ class TestNoRegression(TestCase):
        self.assertEqual(x2.device.type, "mps")


MPS_UNSUPPORTED_TYPES = [torch.double, torch.cdouble]
MPS_DTYPES = [t for t in get_all_dtypes() if t not in MPS_UNSUPPORTED_TYPES]

MPS_GRAD_DTYPES = [torch.float32, torch.float16]