[MPS] Move torch.cat impl to Metal (#165373)
After this change, all of the cases tested in [this performance measurement script](10de64c5ac/cat/perf0.py) take either roughly the same runtime or less.
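The linked script is not reproduced here; as a rough illustration only, the sketch below shows the kind of CPU-vs-MPS timing comparison the tables report, with assumed shapes, iteration counts, and helper names (it is not `perf0.py`).

```python
# Hedged sketch of a CPU-vs-MPS timing comparison for torch.cat.
# Shapes, iteration count, and helper names are assumptions for illustration.
import timeit
import torch

def time_cat(tensors, device, dim=0, iters=100):
    ts = [t.to(device) for t in tensors]
    torch.cat(ts, dim=dim)  # warm-up so launch/compile overhead is not measured
    if device == "mps":
        torch.mps.synchronize()
    def run():
        torch.cat(ts, dim=dim)
        if device == "mps":
            torch.mps.synchronize()  # wait for the GPU before stopping the clock
    return timeit.timeit(run, number=iters) / iters * 1e3  # ms per call

tensors = [torch.randn(1000000, 3, 2), torch.randn(1000000, 3, 2)]
cpu_ms = time_cat(tensors, "cpu")
mps_ms = time_cat(tensors, "mps")
print(f"cpu {cpu_ms:.6f} ms, mps {mps_ms:.6f} ms, speedup {cpu_ms / mps_ms:.2f}")
```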
Before:
```
idx: cpu time, mps time, speedup, op, args, kwargs
-----------------------------------------
0: 0.000857 ms, 0.016098 ms, 0.05, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': -1}
1: 0.000858 ms, 0.014861 ms, 0.06, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': 1}
2: 0.000806 ms, 0.015145 ms, 0.05, cat, [[tensor(shape[10, 5]), tensor(shape[5, 5])]], {'dim': 0}
3: 0.000829 ms, 0.015355 ms, 0.05, cat, [[tensor(shape[1, 2, 3]), tensor(shape[1, 2, 3])]], {'dim': -2}
4: 0.000591 ms, 0.000582 ms, 1.02, cat, [[tensor(shape[0]), tensor(shape[0])]], {'dim': 0}
5: 0.001076 ms, 0.022387 ms, 0.05, cat, [[tensor(shape[0]), tensor(shape[5, 5])]], {'dim': 1}
6: 0.000708 ms, 0.022300 ms, 0.03, cat, [[tensor(shape[0, 5]), tensor(shape[5, 5])]], {'dim': 0}
7: 0.000640 ms, 0.014367 ms, 0.04, cat, [[tensor(shape[1]), tensor(shape[1])]], {}
8: 0.000777 ms, 0.027506 ms, 0.03, cat, [[tensor(shape[2, 2, 2, 2])], 1], {}
9: 0.003383 ms, 0.269277 ms, 0.01, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
10: 0.526138 ms, 0.650852 ms, 0.81, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
11: 0.444091 ms, 0.628630 ms, 0.71, cat, "[[tensor(shape[1, 3, 2]), tensor(shape[2, 3, 2]), tensor(shape[3, 3, 2]), tensor(shape[1, 3, 2]), te...", {'dim': 0}
12: 2.011870 ms, 0.989525 ms, 2.03, cat, [[tensor(shape[1000000, 3, 2]), tensor(shape[1000000, 3, 2])]], {'dim': 0}
13: 3.100653 ms, 0.948178 ms, 3.27, cat, [[tensor(shape[3, 1000000, 2]), tensor(shape[3, 1000000, 2])]], {'dim': 1}
14: 3.112174 ms, 0.954174 ms, 3.26, cat, [[tensor(shape[3, 2, 1000000]), tensor(shape[3, 2, 1000000])]], {'dim': 2}
```
After:
```
idx: cpu time, mps time, speedup, op, args, kwargs
-----------------------------------------
0: 0.000790 ms, 0.013111 ms, 0.06, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': -1}
1: 0.000800 ms, 0.014419 ms, 0.06, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': 1}
2: 0.000748 ms, 0.010019 ms, 0.07, cat, [[tensor(shape[10, 5]), tensor(shape[5, 5])]], {'dim': 0}
3: 0.000767 ms, 0.010063 ms, 0.08, cat, [[tensor(shape[1, 2, 3]), tensor(shape[1, 2, 3])]], {'dim': -2}
4: 0.000591 ms, 0.000591 ms, 1.00, cat, [[tensor(shape[0]), tensor(shape[0])]], {'dim': 0}
5: 0.001220 ms, 0.009763 ms, 0.12, cat, [[tensor(shape[0]), tensor(shape[5, 5])]], {'dim': 1}
6: 0.000739 ms, 0.006203 ms, 0.12, cat, [[tensor(shape[0, 5]), tensor(shape[5, 5])]], {'dim': 0}
7: 0.000647 ms, 0.009905 ms, 0.07, cat, [[tensor(shape[1]), tensor(shape[1])]], {}
8: 0.000753 ms, 0.007818 ms, 0.10, cat, [[tensor(shape[2, 2, 2, 2])], 1], {}
9: 0.003823 ms, 0.192723 ms, 0.02, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
10: 0.576564 ms, 0.733920 ms, 0.79, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
11: 0.462957 ms, 0.692799 ms, 0.67, cat, "[[tensor(shape[1, 3, 2]), tensor(shape[2, 3, 2]), tensor(shape[3, 3, 2]), tensor(shape[1, 3, 2]), te...", {'dim': 0}
12: 2.017181 ms, 0.968345 ms, 2.08, cat, [[tensor(shape[1000000, 3, 2]), tensor(shape[1000000, 3, 2])]], {'dim': 0}
13: 3.203508 ms, 0.986382 ms, 3.25, cat, [[tensor(shape[3, 1000000, 2]), tensor(shape[3, 1000000, 2])]], {'dim': 1}
14: 3.181249 ms, 1.007773 ms, 3.16, cat, [[tensor(shape[3, 2, 1000000]), tensor(shape[3, 2, 1000000])]], {'dim': 2}
```
Fixes #165350
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165373
Approved by: https://github.com/kulinseth, https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: d2c82bafb7
Commit: e0fe37fa68
Params struct header:

```diff
@@ -1,16 +1,16 @@
 #pragma once
 #include <c10/metal/common.h>
 
-template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
-struct CatLargeSharedParams {
+template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim>
+struct CatSharedParams {
   int32_t ndim;
   int32_t cat_dim;
   ::c10::metal::array<idx_type_t, N> output_strides;
   ::c10::metal::array<idx_type_t, N> output_sizes;
 };
 
-template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
-struct CatLargeInputParams {
+template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim>
+struct CatInputParams {
   idx_type_t cat_dim_offset;
   idx_type_t input_element_offset;
   ::c10::metal::array<idx_type_t, N> input_strides;
```
Metal kernel source:

```diff
@@ -6,12 +6,12 @@
 using namespace metal;
 using namespace c10::metal;
 
-template <typename T_in, typename T_out>
-kernel void cat_large(
+template <typename I, typename T_in, typename T_out>
+kernel void cat(
     constant T_in* input [[buffer(0)]],
     device T_out* output [[buffer(1)]],
-    constant CatLargeSharedParams<>& shared_params [[buffer(2)]],
-    constant CatLargeInputParams<>& input_params [[buffer(3)]],
+    constant CatSharedParams<I>& shared_params [[buffer(2)]],
+    constant CatInputParams<I>& input_params [[buffer(3)]],
     uint tid [[thread_position_in_grid]]) {
   auto ndim = shared_params.ndim;
   auto cat_dim = shared_params.cat_dim;
@@ -23,9 +23,9 @@ kernel void cat_large(
   constant auto& input_strides = input_params.input_strides;
   constant auto& input_sizes = input_params.input_sizes;
 
-  auto input_element_idx = static_cast<int64_t>(tid) + input_element_offset;
-  int64_t input_offset = 0;
-  int64_t output_offset = 0;
+  auto input_element_idx = static_cast<I>(tid) + input_element_offset;
+  I input_offset = 0;
+  I output_offset = 0;
 
   for (auto dim = ndim - 1; dim >= 0; dim--) {
     auto dim_size = input_sizes[dim];
@@ -42,41 +42,45 @@ kernel void cat_large(
   output[output_offset] = static_cast<T_out>(input[input_offset]);
 }
 
-#define REGISTER_CAT_LARGE_OP(T_in, T_out)                           \
-  template [[host_name("cat_large_" #T_in "_" #T_out)]]              \
-  kernel void cat_large<T_in, T_out>(                                \
-      constant T_in * input [[buffer(0)]],                           \
-      device T_out * output [[buffer(1)]],                           \
-      constant CatLargeSharedParams<> & shared_params [[buffer(2)]], \
-      constant CatLargeInputParams<> & input_params [[buffer(3)]],   \
+#define REGISTER_CAT_OP(I, T_in, T_out)                              \
+  template [[host_name("cat_" #I "_" #T_in "_" #T_out)]]             \
+  kernel void cat<I, T_in, T_out>(                                   \
+      constant T_in * input [[buffer(0)]],                           \
+      device T_out * output [[buffer(1)]],                           \
+      constant CatSharedParams<I> & shared_params [[buffer(2)]],     \
+      constant CatInputParams<I> & input_params [[buffer(3)]],       \
       uint tid [[thread_position_in_grid]]);
 
-#define REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(T_out) \
-  REGISTER_CAT_LARGE_OP(float, T_out);               \
-  REGISTER_CAT_LARGE_OP(half, T_out);                \
-  REGISTER_CAT_LARGE_OP(bfloat, T_out);              \
-  REGISTER_CAT_LARGE_OP(int, T_out);                 \
-  REGISTER_CAT_LARGE_OP(uint, T_out);                \
-  REGISTER_CAT_LARGE_OP(long, T_out);                \
-  REGISTER_CAT_LARGE_OP(ulong, T_out);               \
-  REGISTER_CAT_LARGE_OP(short, T_out);               \
-  REGISTER_CAT_LARGE_OP(ushort, T_out);              \
-  REGISTER_CAT_LARGE_OP(char, T_out);                \
-  REGISTER_CAT_LARGE_OP(uchar, T_out);               \
-  REGISTER_CAT_LARGE_OP(bool, T_out);
+#define REGISTER_CAT_OP_ALL_INPUT_TYPES(I, T_out) \
+  REGISTER_CAT_OP(I, float, T_out);               \
+  REGISTER_CAT_OP(I, half, T_out);                \
+  REGISTER_CAT_OP(I, bfloat, T_out);              \
+  REGISTER_CAT_OP(I, int, T_out);                 \
+  REGISTER_CAT_OP(I, uint, T_out);                \
+  REGISTER_CAT_OP(I, long, T_out);                \
+  REGISTER_CAT_OP(I, ulong, T_out);               \
+  REGISTER_CAT_OP(I, short, T_out);               \
+  REGISTER_CAT_OP(I, ushort, T_out);              \
+  REGISTER_CAT_OP(I, char, T_out);                \
+  REGISTER_CAT_OP(I, uchar, T_out);               \
+  REGISTER_CAT_OP(I, bool, T_out);
 
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(float);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(half);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bfloat);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(int);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uint);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(long);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ulong);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(short);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ushort);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(char);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uchar);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bool);
+#define REGISTER_CAT_FOR_INDEX_TYPE(I)        \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, float);  \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, half);   \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bfloat); \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, int);    \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uint);   \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, long);   \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ulong);  \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, short);  \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ushort); \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, char);   \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uchar);  \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bool);   \
+                                              \
+  REGISTER_CAT_OP(I, float2, float2);         \
+  REGISTER_CAT_OP(I, half2, half2);
 
-REGISTER_CAT_LARGE_OP(float2, float2);
-REGISTER_CAT_LARGE_OP(half2, half2);
+REGISTER_CAT_FOR_INDEX_TYPE(int64_t);
+REGISTER_CAT_FOR_INDEX_TYPE(int32_t);
```
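For context on what the kernel above computes: each thread handles one element of one input, decomposes its linear element index using the input sizes and strides, and writes to the output at the same coordinates shifted by `cat_dim_offset` along the concatenation dimension. The middle of the loop body is elided by the hunk above, so the Python sketch below of that index math is an assumption based on the visible parameters, not a transcription of the kernel.

```python
# Hedged Python model of the per-thread index math the Metal kernel appears
# to perform; the actual loop body is not shown in the diff above.
import torch

def cat_element_offsets(tid, input_element_offset, cat_dim, cat_dim_offset,
                        input_sizes, input_strides, output_strides, ndim):
    # Decompose the linear element index into per-dimension coordinates,
    # innermost dimension first, accumulating input and output element offsets.
    input_element_idx = tid + input_element_offset
    input_offset = 0
    output_offset = 0
    for dim in range(ndim - 1, -1, -1):
        dim_size = input_sizes[dim]
        dim_idx = input_element_idx % dim_size
        input_element_idx //= dim_size
        input_offset += dim_idx * input_strides[dim]
        # Shift only the concatenation dimension by this input's offset.
        out_idx = dim_idx + (cat_dim_offset if dim == cat_dim else 0)
        output_offset += out_idx * output_strides[dim]
    return input_offset, output_offset

# Tiny check: element 3 of the second input of cat([x, y], dim=0).
x, y = torch.arange(6).reshape(2, 3), torch.arange(6, 12).reshape(2, 3)
out = torch.cat([x, y], dim=0)
in_off, out_off = cat_element_offsets(
    tid=3, input_element_offset=0, cat_dim=0, cat_dim_offset=x.size(0),
    input_sizes=list(y.shape), input_strides=list(y.stride()),
    output_strides=list(out.stride()), ndim=y.dim())
assert out.flatten()[out_off] == y.flatten()[in_off]
```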
Host-side Objective-C++ (cat_out_mps):

```diff
@@ -3,6 +3,7 @@
 #include <ATen/MemoryOverlap.h>
 #include <ATen/WrapDimUtils.h>
 #include <ATen/mps/MPSProfiler.h>
+#include <ATen/native/Pool.h>
 #include <ATen/native/TensorShape.h>
 #include <ATen/native/TypeProperties.h>
 #include <ATen/native/mps/OperationUtils.h>
@@ -69,29 +70,40 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in
   }
 }
 
-// This implementation of cat is used only if one of the inputs or the output is
-// too large to use MPSGraph.
+template <typename T>
+std::string get_type_str();
+
+template <>
+std::string get_type_str<int64_t>() {
+  return "int64_t";
+}
+
+template <>
+std::string get_type_str<int32_t>() {
+  return "int32_t";
+}
+
 // NOTE: `output` is expected to already have the correct size.
-static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
-  CatLargeSharedParams shared_params;
+template <typename idx_type_t>
+static void cat_out_mps_impl(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
+  CatSharedParams<idx_type_t> shared_params;
 
   shared_params.ndim = output.dim();
   shared_params.cat_dim = dimension;
 
   for (const auto dim : c10::irange(output.dim())) {
-    shared_params.output_strides[dim] = output.stride(dim);
-    shared_params.output_sizes[dim] = output.size(dim);
+    shared_params.output_strides[dim] = safe_downcast<idx_type_t, int64_t>(output.stride(dim));
+    shared_params.output_sizes[dim] = safe_downcast<idx_type_t, int64_t>(output.size(dim));
   }
 
-  int64_t cat_dim_offset = 0;
+  idx_type_t cat_dim_offset = 0;
   size_t input_idx = 0;
   MPSStream* stream = getCurrentMPSStream();
 
-  // Launch a separate kernels for each input. This will produce some overhead,
-  // but that should be relatively minimal since at least one of the inputs is
-  // very large. In order to launch only one kernel to process all inputs, we
-  // would have to copy all the input tensor data into a packed buffer, which
-  // would not be ideal.
+  // Launch a separate kernels for each input. This will produce some overhead.
+  // In order to launch only one kernel to process all inputs, we would have to
+  // copy all the input tensor data into a packed buffer, which would not be
+  // ideal.
   for (const Tensor& input : inputs) {
     if (input.numel() == 0) {
       continue;
@@ -104,21 +116,23 @@ static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimen
 
     for (int64_t numel_remaining = input.numel(); numel_remaining > 0; numel_remaining -= max_num_threads) {
       auto num_threads = std::min(max_num_threads, numel_remaining);
-      CatLargeInputParams input_params;
+      CatInputParams<idx_type_t> input_params;
 
-      input_params.cat_dim_offset = cat_dim_offset;
-      input_params.input_element_offset = input.numel() - numel_remaining;
+      input_params.cat_dim_offset = safe_downcast<idx_type_t, int64_t>(cat_dim_offset);
+      input_params.input_element_offset = safe_downcast<idx_type_t, int64_t>(input.numel() - numel_remaining);
 
       for (const auto dim : c10::irange(input.dim())) {
-        input_params.input_strides[dim] = input.stride(dim);
-        input_params.input_sizes[dim] = input.size(dim);
+        input_params.input_strides[dim] = safe_downcast<idx_type_t, int64_t>(input.stride(dim));
+        input_params.input_sizes[dim] = safe_downcast<idx_type_t, int64_t>(input.size(dim));
       }
 
       dispatch_sync_with_rethrow(stream->queue(), ^() {
         @autoreleasepool {
           id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
-          auto pipeline_state = lib.getPipelineStateForFunc(
-              fmt::format("cat_large_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output)));
+          auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("cat_{}_{}_{}",
+                                                                        get_type_str<idx_type_t>(),
+                                                                        scalarToMetalTypeString(input),
+                                                                        scalarToMetalTypeString(output)));
           getMPSProfiler().beginProfileKernel(pipeline_state, "cat", {input});
           [computeEncoder setComputePipelineState:pipeline_state];
           mtl_setArgs(computeEncoder, input, output, shared_params, input_params);
@@ -294,13 +308,6 @@ TORCH_IMPL_FUNC(cat_out_mps)
               " and out is on ",
               out.device());
 
-  // TODO: For better performance by eliminating input tensor gathering and post transpose,
-  // TODO: it is better to keep the out tensor's memory format.
-  // TODO: dimension needs to be recomputed as:
-  // TODO: dim = 0 --> dim = 0; dim = 1 or 2 --> dim = out.dim()- dim; otherwise dim = dim-1
-  if (needsGather(out)) {
-    out.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
-  }
   std::vector<int64_t> size(notSkippedTensor.sizes().vec());
 
   // Compute size of the result in the cat dimension
@@ -331,82 +338,9 @@ TORCH_IMPL_FUNC(cat_out_mps)
   has_large_tensor |= isTooLargeForMPSGraph(out);
 
   if (has_large_tensor) {
-    return mps::cat_out_large_tensor_mps(materialized_inputs, dimension, out);
-  }
-
-  struct CachedGraph : public MPSCachedGraph {
-    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
-    std::vector<MPSGraphTensor*> inputTensors_;
-    MPSGraphTensor* outputTensor_ = nil;
-  };
-
-  @autoreleasepool {
-    std::string key = "cat_out_mps:" + std::to_string(dimension) + ":" +
-        (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
-    if (!all_same_dtype) {
-      key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride);
-    } else {
-      key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + std::to_string(inputs.size());
-    }
-    for (auto idx : skipped_tensor_indices) {
-      key += "," + std::to_string(idx);
-    }
-
-    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
-      auto len_tensor_array = inputs.size() - skipped_tensor_indices.size();
-      std::vector<MPSGraphTensor*> castInputTensors(len_tensor_array);
-      newCachedGraph->inputTensors_.reserve(len_tensor_array);
-
-      for (const auto idx : c10::irange(len_tensor_array)) {
-        const Tensor& tensor = input_tensors[idx];
-        auto scalar_type = getMPSScalarType(tensor.scalar_type());
-        if (tensor.scalar_type() == kBool) {
-          scalar_type = MPSDataTypeInt8;
-        }
-        newCachedGraph->inputTensors_[idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, scalar_type);
-        if (tensor.scalar_type() != out_dtype) {
-          castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx]
-                                                toType:getMPSDataType(out_dtype)
-                                                  name:@"castInput"];
-        } else {
-          castInputTensors[idx] = newCachedGraph->inputTensors_[idx];
-        }
-      }
-
-      auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data() count:len_tensor_array];
-      MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
-                                                   dimension:dimension // Maybe convert this from int64_t -> int32
-                                                        name:nil];
-      if (getMPSDataType(out_dtype) == MPSDataTypeBool) {
-        outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"outputTensor"];
-      }
-      newCachedGraph->outputTensor_ = outputTensor;
-    });
-
-    std::vector<Placeholder> inputPlaceholders;
-    int i = 0;
-    int t_idx = 0;
-    for (const Tensor& tensor : materialized_inputs) {
-      if (std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) {
-        auto scalar_type = getMPSScalarType(tensor.scalar_type());
-        if (tensor.scalar_type() == kBool) {
-          scalar_type = MPSDataTypeInt8;
-        }
-        inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor, nullptr, true, scalar_type);
-        t_idx++;
-      }
-      i++;
-    }
-
-    auto outputDataType = getMPSScalarType(out.scalar_type());
-    Placeholder outputPlaceholder =
-        Placeholder(cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType);
-
-    NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
-    for (auto& inputPlaceholder : inputPlaceholders) {
-      feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
-    }
-    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
+    return mps::cat_out_mps_impl<int64_t>(materialized_inputs, dimension, out);
+  } else {
+    return mps::cat_out_mps_impl<int32_t>(materialized_inputs, dimension, out);
   }
 }
 
```
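The dispatch at the end shows the key change: the MPSGraph concat path is removed, and both branches now call the Metal implementation, differing only in index width (64-bit indexing when some tensor is too large for MPSGraph, 32-bit otherwise). The sketch below is only a rough model of that decision; the `2**31 - 1` threshold is an assumption and is not the actual `isTooLargeForMPSGraph()` logic.

```python
# Hedged sketch of the index-width selection made above.
import torch

INT32_MAX = 2**31 - 1

def choose_index_type(tensors):
    # Assume "large" means more elements than a 32-bit index can address;
    # the real check (isTooLargeForMPSGraph) is not reproduced here.
    if any(t.numel() > INT32_MAX for t in tensors):
        return "int64_t"
    return "int32_t"

# Example: small inputs get the cheaper 32-bit indexing kernel variant.
a, b = torch.randn(5, 5), torch.randn(5, 5)
out = torch.empty(10, 5)
print(choose_index_type([a, b, out]))  # -> int32_t
```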