Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[MPS] Add embedding_bag forward pass (#163012)
Part of #162270

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163012
Approved by: https://github.com/kulinseth, https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: 167ad09be5
Commit: 5236007806
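Note: with this change, torch.nn.functional.embedding_bag can run its forward pass on MPS tensors. A minimal, illustrative check against the CPU path (shapes and values here are hypothetical, not taken from the PR's tests):

import torch
import torch.nn.functional as F

if torch.backends.mps.is_available():
    weight = torch.randn(10, 4)
    indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
    offsets = torch.tensor([0, 4])  # two bags: indices[0:4] and indices[4:8]
    cpu_out = F.embedding_bag(indices, weight, offsets, mode="mean")
    mps_out = F.embedding_bag(indices.to("mps"), weight.to("mps"), offsets.to("mps"), mode="mean")
    torch.testing.assert_close(cpu_out, mps_out.cpu())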
@ -1,3 +1,4 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <cstdint>
aten/src/ATen/native/mps/kernels/EmbeddingBag.h (new file, 25 lines)
@ -0,0 +1,25 @@
#pragma once
#include <c10/metal/common.h>

#ifdef __METAL__
enum class EmbeddingBagMode { SUM = 0, MEAN, MAX };
#else
#include <ATen/native/EmbeddingBag.h>
using at::native::EmbeddingBagMode;
#endif

template <typename idx_type_t = uint32_t>
struct EmbeddingBagParams {
  ::c10::metal::array<idx_type_t, 2> weight_strides;
  ::c10::metal::array<idx_type_t, 2> output_strides;
  ::c10::metal::array<idx_type_t, 2> max_indices_strides;

  idx_type_t per_sample_weights_strides;

  idx_type_t num_indices;
  idx_type_t num_bags;
  idx_type_t feature_size;

  EmbeddingBagMode mode;
  int64_t padding_idx;
};
aten/src/ATen/native/mps/kernels/EmbeddingBag.metal (new file, 212 lines)
@ -0,0 +1,212 @@
#include <ATen/native/mps/kernels/EmbeddingBag.h>
#include <c10/metal/utils.h>
#include <metal_array>
#include <metal_stdlib>

using namespace metal;
using namespace c10::metal;

template <EmbeddingBagMode M, typename T>
struct ReductionOpInit {
  inline opmath_t<T> operator()() {
    return 0;
  }
};

template <typename T>
struct ReductionOpInit<EmbeddingBagMode::MAX, T> {
  inline opmath_t<T> operator()() {
    return static_cast<opmath_t<T>>(-INFINITY);
  }
};

template <EmbeddingBagMode M, typename T>
struct ReductionOp {
  inline opmath_t<T> operator()(
      T weight_val,
      opmath_t<T> out_val,
      uint32_t per_sample_weights_index,
      constant T* per_sample_weights,
      uint32_t per_sample_weights_strides);
};

template <typename T>
struct ReductionOp<EmbeddingBagMode::SUM, T> {
  inline opmath_t<T> operator()(
      T weight_val,
      opmath_t<T> out_val,
      uint32_t per_sample_weights_index,
      constant T* per_sample_weights,
      uint32_t per_sample_weights_strides) {
    if (per_sample_weights_strides) {
      T per_sample_weight = per_sample_weights
          [per_sample_weights_strides * per_sample_weights_index];
      return static_cast<opmath_t<T>>(per_sample_weight) *
          static_cast<opmath_t<T>>(weight_val) +
          out_val;
    } else {
      return static_cast<opmath_t<T>>(weight_val) + out_val;
    }
  }
};

template <typename T>
struct ReductionOp<EmbeddingBagMode::MEAN, T> {
  inline opmath_t<T> operator()(
      T weight_val,
      opmath_t<T> out_val,
      uint32_t,
      constant T*,
      uint32_t) {
    return static_cast<opmath_t<T>>(weight_val) + out_val;
  }
};

template <typename T>
struct ReductionOp<EmbeddingBagMode::MAX, T> {
  inline opmath_t<T> operator()(
      T weight_val,
      opmath_t<T> out_val,
      uint32_t,
      constant T*,
      uint32_t) {
    return max(static_cast<opmath_t<T>>(weight_val), out_val);
  }
};

template <EmbeddingBagMode M, typename T>
struct ReductionOpFinal {
  inline T operator()(opmath_t<T> val, uint32_t) {
    return static_cast<T>(val);
  }
};

template <typename T>
struct ReductionOpFinal<EmbeddingBagMode::MEAN, T> {
  inline T operator()(opmath_t<T> val, uint32_t count) {
    auto out = val / count;
    return static_cast<T>((count == 0) ? 0 : out);
  }
};

template <typename T>
struct ReductionOpFinal<EmbeddingBagMode::MAX, T> {
  inline T operator()(opmath_t<T> val, uint32_t count) {
    return static_cast<T>((count == 0) ? 0 : val);
  }
};

template <EmbeddingBagMode M, typename T, typename I>
void embedding_bag_impl(
    constant T* weight,
    constant I* indices,
    constant I* offsets,
    constant T* per_sample_weights,
    device T* output,
    device I* offset2bag,
    device I* bag_size,
    device I* max_indices,
    constant EmbeddingBagParams<uint32_t>& params,
    uint tid) {
  auto num_indices = params.num_indices;
  auto num_bags = params.num_bags;
  auto feature_size = params.feature_size;
  auto padding_idx = params.padding_idx;
  auto per_sample_weights_strides = params.per_sample_weights_strides;
  constant auto& output_strides = params.output_strides;
  constant auto& weight_strides = params.weight_strides;
  constant auto& max_indices_strides = params.max_indices_strides;

  auto bag_idx = tid / feature_size;
  auto feature_idx = tid % feature_size;

  output += bag_idx * output_strides[0] + feature_idx * output_strides[1];

  uint32_t offsets_end = min(bag_idx + 1, num_bags - 1);
  bool is_last_bag = bag_idx + 1 == num_bags;
  uint32_t indices_start = static_cast<uint32_t>(offsets[bag_idx]);
  uint32_t indices_end = is_last_bag * (num_indices) +
      (!is_last_bag) * (static_cast<uint32_t>(offsets[offsets_end]));

  auto out_val = ReductionOpInit<M, T>()();

  uint32_t bag_size_ = 0;

  for (uint32_t indices_idx = indices_start; indices_idx < indices_end;
       indices_idx++) {
    I weight_idx = indices[indices_idx];
    bool pad = (weight_idx == padding_idx);
    T weight_val = weight
        [static_cast<uint32_t>(weight_idx) * weight_strides[0] +
         feature_idx * weight_strides[1]];

    bag_size_ += static_cast<uint32_t>(!pad);

    auto tmp_val = ReductionOp<M, T>()(
        weight_val,
        out_val,
        indices_idx,
        per_sample_weights,
        per_sample_weights_strides);

    out_val = pad ? out_val : tmp_val;
  }

  *output = ReductionOpFinal<M, T>()(out_val, bag_size_);
}

#define DISPATCH_IMPL(MODE)        \
  return embedding_bag_impl<MODE>( \
      weight,                      \
      indices,                     \
      offsets,                     \
      per_sample_weights,          \
      output,                      \
      offset2bag,                  \
      bag_size,                    \
      max_indices,                 \
      params,                      \
      tid)

template <typename T, typename I>
kernel void embedding_bag(
    constant T* weight [[buffer(0)]],
    constant I* indices [[buffer(1)]],
    constant I* offsets [[buffer(2)]],
    constant T* per_sample_weights [[buffer(3)]],
    device T* output [[buffer(4)]],
    device I* offset2bag [[buffer(5)]],
    device I* bag_size [[buffer(6)]],
    device I* max_indices [[buffer(7)]],
    constant EmbeddingBagParams<uint32_t>& params [[buffer(8)]],
    uint tid [[thread_position_in_grid]]) {
  switch (params.mode) {
    case EmbeddingBagMode::SUM:
      DISPATCH_IMPL(EmbeddingBagMode::SUM);
    case EmbeddingBagMode::MEAN:
      DISPATCH_IMPL(EmbeddingBagMode::MEAN);
    case EmbeddingBagMode::MAX:
      DISPATCH_IMPL(EmbeddingBagMode::MAX);
  }
}

#define REGISTER_EMBEDDING_BAG_OP(T, I)                              \
  template [[host_name("embedding_bag_" #T "_" #I)]]                 \
  kernel void embedding_bag<T, I>(                                   \
      constant T * weight [[buffer(0)]],                             \
      constant I * indices [[buffer(1)]],                            \
      constant I * offsets [[buffer(2)]],                            \
      constant T * per_sample_weights [[buffer(3)]],                 \
      device T * output [[buffer(4)]],                               \
      device I * offset2bag [[buffer(5)]],                           \
      device I * bag_size [[buffer(6)]],                             \
      device I * max_indices [[buffer(7)]],                          \
      constant EmbeddingBagParams<uint32_t> & params [[buffer(8)]],  \
      uint tid [[thread_position_in_grid]]);

REGISTER_EMBEDDING_BAG_OP(float, int);
REGISTER_EMBEDDING_BAG_OP(float, long);
REGISTER_EMBEDDING_BAG_OP(half, int);
REGISTER_EMBEDDING_BAG_OP(half, long);
REGISTER_EMBEDDING_BAG_OP(bfloat, int);
REGISTER_EMBEDDING_BAG_OP(bfloat, long);
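For reference, the per-bag reduction the kernel above computes (one thread per output element) can be described by the following Python sketch. It assumes include_last_offset=False; the helper name and structure are illustrative, not part of the diff:

import torch

def embedding_bag_reference(weight, indices, offsets, mode="sum",
                            per_sample_weights=None, padding_idx=-1):
    # Mirrors the Metal kernel's semantics: SUM/MEAN accumulate from 0, MAX from -inf;
    # rows equal to padding_idx are skipped and do not count toward the bag size.
    num_bags = offsets.numel()
    feature_size = weight.size(1)
    out = torch.zeros(num_bags, feature_size, dtype=weight.dtype)
    for b in range(num_bags):
        start = int(offsets[b])
        end = int(offsets[b + 1]) if b + 1 < num_bags else indices.numel()
        acc = torch.full((feature_size,), float("-inf")) if mode == "max" else torch.zeros(feature_size)
        count = 0
        for pos in range(start, end):
            idx = int(indices[pos])
            if idx == padding_idx:
                continue
            row = weight[idx].float()
            if mode == "sum" and per_sample_weights is not None:
                row = row * per_sample_weights[pos].float()
            acc = torch.maximum(acc, row) if mode == "max" else acc + row
            count += 1
        if mode == "mean":
            acc = acc / count if count else torch.zeros_like(acc)
        elif count == 0:
            acc = torch.zeros_like(acc)  # empty bags produce zeros
        out[b] = acc.to(weight.dtype)
    return out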
aten/src/ATen/native/mps/operations/EmbeddingBag.mm (new file, 179 lines)
@ -0,0 +1,179 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/EmbeddingBag.h>
#include <ATen/native/Pool.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/kernels/EmbeddingBag.h>

#include <fmt/format.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_embedding_bag_forward_only_native.h>
#include <ATen/ops/_embedding_bag_native.h>
#include <ATen/ops/empty.h>
#endif

namespace at::native {

#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/EmbeddingBag_metallib.h>
#endif

namespace {

std::pair<Tensor, Tensor> promoteIndicesAndOffsets(const Tensor& indices, const Tensor& offsets) {
  const auto commonType = promoteTypes(offsets.scalar_type(), indices.scalar_type());
  return {indices.scalar_type() == commonType ? indices : indices.toType(commonType),
          offsets.scalar_type() == commonType ? offsets : offsets.toType(commonType)};
}

} // namespace

namespace mps {

static std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_mps_impl(
    const Tensor& weight,
    const Tensor& indices_,
    const Tensor& offsets_,
    const bool scale_grad_by_freq,
    const int64_t mode,
    bool sparse,
    const std::optional<Tensor>& per_sample_weights_opt,
    bool include_last_offset,
    int64_t padding_idx) {
  TORCH_CHECK(indices_.dim() == 1, "input has to be a 1D Tensor, but got Tensor of dimension ", indices_.dim());
  if (indices_.dim() == 1) {
    TORCH_CHECK(offsets_.dim() == 1, "offsets has to be a 1D Tensor, but got Tensor of dimension ", offsets_.dim());
  }
  TORCH_CHECK(weight.dim() == 2, "weight has to be a 2D Tensor, but got Tensor of dimension ", weight.dim());

  Tensor indices, offsets;
  std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_);
  auto indices_arg = TensorArg(indices, "indices", 1);
  checkScalarTypes("embedding_bag_mps", indices_arg, {kLong, kInt});
  auto offsets_arg = TensorArg(offsets, "offsets", 1);
  checkScalarTypes("embedding_bag_mps", offsets_arg, {kLong, kInt});
  checkSameType("embedding_bag_mps", indices_arg, offsets_arg);
  auto weight_arg = TensorArg(weight, "weight", 1);

  int64_t num_indices = indices.size(0);
  int64_t num_bags = offsets.size(0);
  if (include_last_offset) {
    num_bags -= 1;
  }
  int64_t feature_size = weight.size(1);

  auto bag_size = at::empty(offsets.sizes(), indices.options());
  auto offset2bag = at::empty({indices.size(0)}, indices.options());
  auto output = at::empty({num_bags, feature_size}, weight.options());

  Tensor max_indices;

  if (mode == EmbeddingBagMode::MAX) {
    max_indices = at::empty({num_bags, feature_size}, indices.options());
  } else {
    max_indices = at::empty({0}, indices.options());
  }

  EmbeddingBagParams<uint32_t> params;

  for (const auto dim : c10::irange(weight.dim())) {
    params.weight_strides[dim] = safe_downcast<uint32_t, int64_t>(weight.stride(dim));
    params.output_strides[dim] = safe_downcast<uint32_t, int64_t>(output.stride(dim));

    if (mode == EmbeddingBagMode::MAX) {
      params.max_indices_strides[dim] = safe_downcast<uint32_t, int64_t>(max_indices.stride(dim));
    }
  }

  bool use_per_sample_weights = per_sample_weights_opt.has_value() && per_sample_weights_opt->defined();
  params.per_sample_weights_strides = use_per_sample_weights ? per_sample_weights_opt->stride(0) : 0;

  params.num_indices = num_indices;
  params.num_bags = num_bags;
  params.feature_size = feature_size;
  params.mode = static_cast<EmbeddingBagMode>(mode);
  params.padding_idx = padding_idx;

  auto num_threads = output.numel();
  MPSStream* stream = getCurrentMPSStream();

  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
      auto pipeline_state = lib.getPipelineStateForFunc(
          fmt::format("embedding_bag_{}_{}", scalarToMetalTypeString(weight), scalarToMetalTypeString(indices)));

      getMPSProfiler().beginProfileKernel(pipeline_state, "embedding_bag", {weight, indices, offsets});
      [computeEncoder setComputePipelineState:pipeline_state];
      mtl_setArgs(computeEncoder,
                  weight,
                  indices,
                  offsets,
                  use_per_sample_weights ? per_sample_weights_opt : std::nullopt,
                  output,
                  offset2bag,
                  bag_size,
                  max_indices,
                  params);

      mtl_dispatch1DJob(computeEncoder, pipeline_state, num_threads);
      getMPSProfiler().endProfileKernel(pipeline_state);
    }
  });

  return std::tuple<Tensor, Tensor, Tensor, Tensor>(
      std::move(output), std::move(offset2bag), std::move(bag_size), std::move(max_indices));
}

} // namespace mps

std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_mps(const Tensor& weight,
                                                              const Tensor& indices,
                                                              const Tensor& offsets,
                                                              const bool scale_grad_by_freq,
                                                              const int64_t mode,
                                                              bool sparse,
                                                              const std::optional<Tensor>& per_sample_weights_opt,
                                                              bool include_last_offset,
                                                              int64_t padding_idx) {
  return mps::_embedding_bag_mps_impl(weight,
                                      indices,
                                      offsets,
                                      scale_grad_by_freq,
                                      mode,
                                      sparse,
                                      per_sample_weights_opt,
                                      include_last_offset,
                                      padding_idx);
}

std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_forward_only_mps(
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const bool scale_grad_by_freq,
    const int64_t mode,
    bool sparse,
    const std::optional<Tensor>& per_sample_weights_opt,
    bool include_last_offset,
    int64_t padding_idx) {
  return _embedding_bag_mps(weight,
                            indices,
                            offsets,
                            scale_grad_by_freq,
                            mode,
                            sparse,
                            per_sample_weights_opt,
                            include_last_offset,
                            padding_idx);
}

} // namespace at::native
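The host wrapper above sizes the launch from the output tensor: one thread per (bag, feature) element, with num_bags reduced by one when include_last_offset is set. A rough Python sketch of that geometry (names are illustrative, not from the diff):

def launch_geometry(offsets_len: int, feature_size: int, include_last_offset: bool):
    # num_bags comes from the offsets length; the 1D grid covers every
    # element of the [num_bags, feature_size] output tensor.
    num_bags = offsets_len - 1 if include_last_offset else offsets_len
    num_threads = num_bags * feature_size
    # Inside the kernel, thread tid maps to:
    #   bag_idx = tid // feature_size, feature_idx = tid % feature_size
    return num_bags, num_threads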
@ -2351,6 +2351,7 @@
  dispatch:
    CPU: _embedding_bag_forward_only_cpu
    CUDA: _embedding_bag_forward_only_cuda
    MPS: _embedding_bag_forward_only_mps
  autogen: _embedding_bag_forward_only.out

- func: _rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor)

@ -2372,6 +2373,7 @@
  dispatch:
    CPU: _embedding_bag_cpu
    CUDA: _embedding_bag_cuda
    MPS: _embedding_bag_mps
  autogen: _embedding_bag.out
  tags: core
@ -7277,11 +7277,8 @@ GPU_TEST_FAILURES = {
}

MPS_TEST_FAILURES = {
    # aten::_embedding_bag is not currently implemented for the MPS device.
    # aten::_embedding_bag backward is not currently implemented for the MPS device.
    "test_embedding_bag": fail_mps(),
    # aten::_embedding_bag is not currently implemented for the MPS device.
    "test_misc_1_max_autotune_False": fail_mps(),
    "test_misc_1_max_autotune_True": fail_mps(),
    # aten::_scaled_dot_product_efficient_attention is not currently implemented for the MPS device.
    "test_scaled_dot_product_efficient_attention": fail_mps(),
    # aten::_int_mm is not implemented for MPS backend
@ -15,6 +15,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__adaptive_avg_pool2d(AtenTensorH
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__adaptive_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__embedding_bag_forward_only(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
@ -348,7 +348,6 @@ if torch.backends.mps.is_available():
    "nn.functional.interpolatearea": None,
    "nn.functional.interpolatebicubic": [torch.uint8],
    "nn.functional.ctc_loss": None,
    "nn.functional.embedding_bag": None,
    "nn.functional.multi_margin_loss": None,
    "nn.functional.multilabel_margin_loss": None,
    "nn.functional.pdist": None,

@ -740,6 +739,7 @@ if torch.backends.mps.is_available():
    "equal": [torch.float16, torch.float32],
    # 'float' object is not iterable
    "item": [torch.float16, torch.float32],
    "nn.functional.embedding_bag": None,
    # "smooth_l1_backward_cpu_out" not implemented for 'Half'
    "nn.functional.smooth_l1_loss": [torch.float16],
    # cpu error: grad requires non-empty inputs