c10::string_view -> std::string_view in aten (#141903)

D66560348 passes internally, but won't export, so I'm rebuilding here.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141903
Approved by: https://github.com/Skylion007

Committed by: PyTorch MergeBot
Parent: 8cb68b136f
Commit: b1bb860d3c
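The change is mechanical: every spelling of `c10::string_view` in ATen becomes `std::string_view`, the C++17 standard non-owning string type, with no behavioral change. A minimal sketch of the pattern (the `get_mode` helper is hypothetical, not part of ATen):

```cpp
#include <stdexcept>
#include <string_view>

enum class PaddingMode { Reflect, Replicate };

// Before this commit the parameter would have been spelled c10::string_view;
// only the type name changes, the call sites and semantics do not.
PaddingMode get_mode(std::string_view mode) {
  if (mode == "reflect") return PaddingMode::Reflect;
  if (mode == "replicate") return PaddingMode::Replicate;
  throw std::invalid_argument("unknown padding mode");
}
```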
@@ -86,7 +86,7 @@ void Context::setDeterministicFillUninitializedMemory(bool b) {
   _deterministic_fill_uninitialized_memory = b;
 }
 
-void Context::alertNotDeterministic(c10::string_view const& caller) {
+void Context::alertNotDeterministic(std::string_view const& caller) {
   if (globalContext().deterministicAlgorithms()) {
     if (globalContext().deterministicAlgorithmsWarnOnly()) {
       TORCH_WARN(
@@ -313,7 +313,7 @@ class TORCH_API Context {
   // }
 
   // Throws an error if `Context::deterministicAlgorithms()` is true
-  static void alertNotDeterministic(c10::string_view const& caller);
+  static void alertNotDeterministic(std::string_view const& caller);
 
   // Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
   // >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
@@ -1209,10 +1209,10 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   BINARY_POINTWISE(mul);
   BINARY_POINTWISE(div);
   {
-    using Binop = Tensor (*)(const Tensor&, const Tensor&, std::optional<c10::string_view>);
-    using Unop = Tensor (*)(const Tensor&, const Scalar&, std::optional<c10::string_view>);
-    m.impl("div.Tensor_mode", binary_pointwise_batching_rule<Binop, at::div, std::optional<c10::string_view>>);
-    m.impl("div.Scalar_mode", unwrap_and_call<Unop, at::div, const Scalar&, std::optional<c10::string_view>>);
+    using Binop = Tensor (*)(const Tensor&, const Tensor&, std::optional<std::string_view>);
+    using Unop = Tensor (*)(const Tensor&, const Scalar&, std::optional<std::string_view>);
+    m.impl("div.Tensor_mode", binary_pointwise_batching_rule<Binop, at::div, std::optional<std::string_view>>);
+    m.impl("div.Scalar_mode", unwrap_and_call<Unop, at::div, const Scalar&, std::optional<std::string_view>>);
   }
 
   // at::pow has three out-of-place overloads
@@ -41,7 +41,7 @@ inline size_t DictKeyHash::operator()(const IValue& ivalue) const {
   if (ivalue.isInt()) {
     return std::hash<int64_t>()(ivalue.toInt());
   } else if (ivalue.isString()) {
-    return std::hash<c10::string_view>()(ivalue.toStringView());
+    return std::hash<std::string_view>()(ivalue.toStringView());
   } else if (ivalue.isDouble()) {
    return std::hash<double>()(ivalue.toDouble());
   } else if (ivalue.isComplexDouble()) {
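The `DictKeyHash` hunk above relies on the standard library providing `std::hash<std::string_view>`, which is guaranteed since C++17, so no c10-specific hash specialization is needed. A self-contained sketch of the same hashing pattern:

```cpp
#include <cstddef>
#include <functional>
#include <string_view>

// Hash a string key the same way the isString() branch above does.
std::size_t hash_key(std::string_view key) {
  return std::hash<std::string_view>{}(key);
}
```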
@@ -54,7 +54,7 @@ class TORCH_API Blob final : public c10::intrusive_ptr_target {
   /**
    * Returns a printable typename of the blob.
    */
-  c10::string_view TypeName() const noexcept {
+  std::string_view TypeName() const noexcept {
     return meta_.name();
   }
 
@@ -87,7 +87,7 @@ using supported_primitive_arg_types = guts::typelist::typelist<
     int64_t,
     double,
     bool,
-    c10::string_view,
+    std::string_view,
     at::Tensor,
     at::Scalar,
     c10::QScheme,
@@ -220,7 +220,7 @@ struct assert_is_valid_input_type<
     std::enable_if_t<std::is_same_v<const char*, T>>> {
   static_assert(
       guts::false_t<T>::value,
-      "You tried to register a kernel with an unsupported input type: const char*. Please use c10::string_view instead.");
+      "You tried to register a kernel with an unsupported input type: const char*. Please use std::string_view instead.");
 };
 template <class T, bool AllowDeprecatedTypes>
 struct assert_is_valid_input_type<
@@ -357,7 +357,7 @@ struct assert_is_valid_output_type<
     std::enable_if_t<std::is_same_v<const char*, T>>> {
   static_assert(
       guts::false_t<T>::value,
-      "You tried to register a kernel with an unsupported output type: const char*. Please use c10::string_view instead.");
+      "You tried to register a kernel with an unsupported output type: const char*. Please use std::string_view instead.");
 };
 template <class T, bool AllowDeprecatedTypes>
 struct assert_is_valid_output_type<
@@ -699,7 +699,7 @@ struct TORCH_API IValue final {
   const std::string& toStringRef() const;
   std::optional<std::reference_wrapper<const std::string>> toOptionalStringRef()
       const;
-  c10::string_view toStringView() const;
+  std::string_view toStringView() const;
 
   // DoubleList
   bool isDoubleList() const;
@@ -309,7 +309,7 @@ struct TORCH_API ConstantString final : c10::intrusive_ptr_target {
   const std::string& string() const {
     return str_;
   }
-  c10::string_view string_view() const {
+  std::string_view string_view() const {
     return str_;
   }
 
@@ -1742,7 +1742,7 @@ DEFINE_TO(c10::impl::GenericList, toList)
 DEFINE_TO(c10::impl::GenericDict, toGenericDict)
 DEFINE_TO(c10::intrusive_ptr<ivalue::Tuple>, toTuple)
 DEFINE_TO(std::string, toStringRef)
-DEFINE_TO(c10::string_view, toStringView)
+DEFINE_TO(std::string_view, toStringView)
 DEFINE_TO(c10::intrusive_ptr<ivalue::Future>, toFuture)
 DEFINE_TO(c10::intrusive_ptr<ivalue::Await>, toAwait)
 DEFINE_TO(c10::intrusive_ptr<c10::RRefInterface>, toRRef)
@@ -1962,11 +1962,11 @@ inline T IValue::to() && {
 }
 
 template <>
-inline std::optional<c10::string_view> IValue::to() && {
+inline std::optional<std::string_view> IValue::to() && {
   // In the default implementation, the IValue is destroyed with std::move.
   // But if the unboxed type is std::optional<string_view> we cannot destroy
   // the IValue.
-  return generic_to(*this, _fake_type<std::optional<c10::string_view>>{});
+  return generic_to(*this, _fake_type<std::optional<std::string_view>>{});
 }
 
 template <typename T>
@@ -2390,7 +2390,7 @@ inline std::optional<std::reference_wrapper<const std::string>> IValue::
           ->string());
 }
 
-inline c10::string_view IValue::toStringView() const {
+inline std::string_view IValue::toStringView() const {
   AT_ASSERT(isString(), "Expected String but got ", tagKind());
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
       payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
@@ -174,13 +174,13 @@ The kernel function can take any of the following types as inputs or outputs:
 * `double` (note: `float` is not supported)
 * `int64_t` (note: other integer types like `int`, `uint64_t`, `int32_t`, `...` are not supported)
 * `bool`
-* `c10::string_view`
+* `std::string_view`
 * `at::Scalar` (this is a type that can hold either an integer or a floating point value)
 * `std::optional<T>` with T being any type from the list above
 
 The kernel function can take and return list inputs by using `torch::List<T>`. `T` must be one of the supported types from above excluding `at::Scalar`.
 
-The kernel function can take and return dicts by using `torch::Dict<Key, Value>`. `Key` must be `int64_t`, `c10::string_view`, `double` or `bool`, and `Value` must be from the list of supported types above excluding `at::Scalar`.
+The kernel function can take and return dicts by using `torch::Dict<Key, Value>`. `Key` must be `int64_t`, `std::string_view`, `double` or `bool`, and `Value` must be from the list of supported types above excluding `at::Scalar`.
 
 When taken as input, any of these types can be taken by value (i.e. `Tensor`) or by const-reference (i.e. `const Tensor&`). We recommend taking all arguments by value, even Tensors. They will be moved in, so there is no performance overhead.
 
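For illustration, a hedged sketch of registering a custom kernel whose argument list uses the table above; a `str` schema argument binds to `std::string_view` in C++. The `myns::scale_by_mode` op and its behavior are hypothetical, not part of ATen:

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical op: doubles the input when mode == "double", else returns it.
at::Tensor scale_by_mode(const at::Tensor& self, std::string_view mode) {
  return mode == "double" ? self * 2 : self;
}

TORCH_LIBRARY(myns, m) {
  m.def("scale_by_mode(Tensor self, str mode) -> Tensor");
  m.impl("scale_by_mode", scale_by_mode);
}
```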
@@ -29,7 +29,7 @@ struct TORCH_API TypeFactoryBase<c10::DynamicType> {
   }
   static c10::DynamicTypePtr createNamedTuple(
       const std::string& name,
-      const std::vector<c10::string_view>& fields,
+      const std::vector<std::string_view>& fields,
       const std::vector<c10::TypePtr>& types) {
     return std::make_shared<c10::DynamicType>(
         c10::DynamicType::Tag::Tuple,
@@ -159,7 +159,7 @@ static std::tuple<Tensor, std::optional<int64_t>> where_self_batch_rule(
 
 static std::tuple<Tensor, std::optional<int64_t>> gelu_backward_batch_rule(
     const Tensor& grad_out, std::optional<int64_t> grad_out_bdim, const Tensor& input, std::optional<int64_t> input_bdim,
-    c10::string_view approximate) {
+    std::string_view approximate) {
 
   // repeat the preprocessing from _binary_pointwise_batch_rule
   auto [grad_out_, input_]= _binary_pointwise_helper(grad_out, grad_out_bdim, input, input_bdim);
@@ -485,7 +485,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   using TensorScalarInplaceT = Tensor& (Tensor::*)(const Tensor&, const Scalar&) const;
   using ScalarScalarInplaceT = Tensor& (Tensor::*)(const Scalar&, const Scalar&) const;
   using TensorInplaceT = Tensor& (Tensor::*)(const Tensor&) const;
-  using TensorInplaceModeT = Tensor& (Tensor::*)(const Tensor&, std::optional<c10::string_view>) const;
+  using TensorInplaceModeT = Tensor& (Tensor::*)(const Tensor&, std::optional<std::string_view>) const;
   using ScalarInplaceT = Tensor& (Tensor::*)(const Scalar&) const;
   using CopyT = Tensor& (Tensor::*)(const Tensor&, bool) const;
 
@@ -499,7 +499,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   VMAP_SUPPORT2(mul_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::mul_>));
   VMAP_SUPPORT2(mul_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarInplaceT, &Tensor::mul_, const Scalar&>));
   VMAP_SUPPORT2(div_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::div_>));
-  VMAP_SUPPORT2(div_, Tensor_mode, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceModeT, &Tensor::div_, std::optional<c10::string_view>>));
+  VMAP_SUPPORT2(div_, Tensor_mode, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceModeT, &Tensor::div_, std::optional<std::string_view>>));
   VMAP_SUPPORT2(div_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarInplaceT, &Tensor::div_, const Scalar&>));
   VMAP_SUPPORT2(clamp_min_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::clamp_min_>));
   VMAP_SUPPORT2(clamp_max_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::clamp_max_>));
@@ -151,7 +151,7 @@ Tensor addmm_decomp(const Tensor& self, const Tensor& mat1, const Tensor& mat2,
   return at::add(self * beta, at::mm(mat1, mat2), alpha);
 }
 
-void _linalg_check_errors_batch_rule(const Tensor& info, std::optional<int64_t> info_bdim, c10::string_view api_name, bool is_matrix) {
+void _linalg_check_errors_batch_rule(const Tensor& info, std::optional<int64_t> info_bdim, std::string_view api_name, bool is_matrix) {
   auto info_ = moveBatchDimToFront(info, info_bdim);
   // Not a matrix means this is a batch of matrices
   at::_linalg_check_errors(info_, api_name, false);
@@ -421,7 +421,7 @@ std::optional<int64_t> batch_dim_if_not_empty(const Tensor& t) {
 
 fourOutputs linalg_lstsq_batch_rule(
     const Tensor& self, std::optional<int64_t> self_bdim, const Tensor& b, std::optional<int64_t> b_bdim,
-    std::optional<double> rcond, std::optional<c10::string_view> driver) {
+    std::optional<double> rcond, std::optional<std::string_view> driver) {
   TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) >= 2, "torch.linalg.lstsq: input must have at least 2 dimensions.");
   TORCH_CHECK(rankWithoutBatchDim(b, b_bdim) >= 1, "torch.linalg.lstsq: other must have at least 1 dimension.");
 
@@ -321,7 +321,7 @@ static std::tuple<Tensor, std::optional<int64_t>> searchsorted_batch_rule(
     std::optional<int64_t> self_bdim,
     bool out_int32,
     bool right,
-    std::optional<c10::string_view> side,
+    std::optional<std::string_view> side,
     const std::optional<Tensor>& sorter,
     std::optional<int64_t> sorter_bdim) {
   auto buckets_logical_rank = rankWithoutBatchDim(sorted_sequence, sorted_sequence_bdim);
@@ -774,7 +774,7 @@ std::tuple<Tensor, std::optional<int64_t>> scatter_reduce_batch_rule(
     int64_t dim,
     const Tensor& index, std::optional<int64_t> index_bdim,
     const Tensor& src, std::optional<int64_t> src_bdim,
-    const c10::string_view reduce) {
+    const std::string_view reduce) {
   return scatter_batch_rule(ATEN_FN2(scatter, reduce),
                             self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce);
 }
@@ -784,7 +784,7 @@ std::tuple<Tensor, std::optional<int64_t>> scatter_reduce_two_batch_rule(
     int64_t dim,
     const Tensor& index, std::optional<int64_t> index_bdim,
     const Tensor& src, std::optional<int64_t> src_bdim,
-    const c10::string_view reduce,
+    const std::string_view reduce,
     bool include_self) {
   return scatter_batch_rule(ATEN_FN2(scatter_reduce, two),
                             self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce, include_self);
@@ -795,7 +795,7 @@ std::tuple<Tensor, std::optional<int64_t>> scatter_reduce__two_batch_rule(
     int64_t dim,
     const Tensor& index, std::optional<int64_t> index_bdim,
     const Tensor& src, std::optional<int64_t> src_bdim,
-    const c10::string_view reduce,
+    const std::string_view reduce,
     bool include_self) {
   return scatter_batch_rule(ATEN_FN2(scatter_reduce_, two),
                             self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce, include_self);
@@ -806,7 +806,7 @@ std::tuple<Tensor, std::optional<int64_t>> scatter_value_reduce_batch_rule(
     int64_t dim,
     const Tensor& index, std::optional<int64_t> index_bdim,
     const Scalar& src,
-    const c10::string_view reduce) {
+    const std::string_view reduce) {
   return scatter_batch_rule(ATEN_FN2(scatter, value_reduce),
                             self, self_bdim, dim, index, index_bdim, src, reduce);
 }
@@ -226,12 +226,12 @@ TORCH_META_FUNC(softshrink_backward) (
   build_borrowing_binary_op(maybe_get_output(), grad, self);
 }
 
-TORCH_META_FUNC(gelu) (const Tensor & self, c10::string_view approximate) {
+TORCH_META_FUNC(gelu) (const Tensor & self, std::string_view approximate) {
   build_unary_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(gelu_backward) (
-  const Tensor& grad, const Tensor& self, c10::string_view approximate
+  const Tensor& grad, const Tensor& self, std::string_view approximate
 ) {
   build_borrowing_binary_op(maybe_get_output(), grad, self);
 }
@@ -387,7 +387,7 @@ static bool use_mkldnn(const Tensor& input) {
 #endif
 
 TORCH_IMPL_FUNC(gelu_out_cpu) (
-  const Tensor& self, c10::string_view approximate, const Tensor& result
+  const Tensor& self, std::string_view approximate, const Tensor& result
 ) {
 auto approximate_type = get_gelutype_enum(approximate);
 #if AT_MKLDNN_ENABLED()
@@ -412,7 +412,7 @@ auto approximate_type = get_gelutype_enum(approximate);
 }
 
 TORCH_IMPL_FUNC(gelu_backward_out_cpu) (
-  const Tensor& grad, const Tensor& self, c10::string_view approximate, const Tensor& grad_input
+  const Tensor& grad, const Tensor& self, std::string_view approximate, const Tensor& grad_input
 ) {
 auto approximate_type = get_gelutype_enum(approximate);
 #if AT_MKLDNN_ENABLED()
@@ -690,7 +690,7 @@ TORCH_META_FUNC(linalg_cholesky_ex)(const Tensor& A,
 }
 
 TORCH_META_FUNC(linalg_qr)(const Tensor& A,
-                           c10::string_view mode) {
+                           std::string_view mode) {
   at::native::checkIsMatrix(A, "linalg.qr");
   at::native::checkFloatingOrComplex(A, "linalg.qr");
   auto [compute_q, reduced_mode] = at::native::_parse_qr_mode(mode);
@@ -720,7 +720,7 @@ TORCH_META_FUNC(linalg_qr)(const Tensor& A,
 TORCH_META_FUNC(_linalg_svd)(const Tensor& A,
                              bool full_matrices,
                              bool compute_uv,
-                             std::optional<c10::string_view> driver) {
+                             std::optional<std::string_view> driver) {
   at::native::checkIsMatrix(A, "linalg.svd");
   at::native::checkFloatingOrComplex(A, "linalg.svd");
 
@@ -792,7 +792,7 @@ TORCH_META_FUNC(lu_unpack)(const Tensor& LU, const Tensor& pivots, bool unpack_d
 }
 
 TORCH_META_FUNC(_linalg_eigh)(const Tensor& A,
-                              c10::string_view uplo,
+                              std::string_view uplo,
                               bool compute_v) {
   at::native::squareCheckInputs(A, "linalg.eigh");
   at::native::checkUplo(uplo);
@@ -1558,7 +1558,7 @@ template<> void blasTriangularSolve<float>(char side, char uplo, char trans, cha
 
 void _linalg_check_errors(
     const Tensor& infos,
-    const c10::string_view api_name,
+    const std::string_view api_name,
     bool is_matrix) {
   TORCH_INTERNAL_ASSERT(infos.scalar_type() == kInt);
   TORCH_INTERNAL_ASSERT(infos.is_contiguous());
@@ -2425,7 +2425,7 @@ std::tuple<Tensor, Tensor> geqrf(const Tensor& input) {
 For further details, please see the LAPACK documentation for GEQRF and ORGQR.
 */
 TORCH_IMPL_FUNC(linalg_qr_out)(const Tensor& A,
-                               c10::string_view mode,
+                               std::string_view mode,
                                const Tensor & Q,
                                const Tensor & R) {
   auto m = A.size(-2);
@@ -2801,7 +2801,7 @@ DEFINE_DISPATCH(linalg_eigh_stub);
 */
 
 TORCH_IMPL_FUNC(_linalg_eigh_out)(const Tensor& A,
-                                  c10::string_view uplo,
+                                  std::string_view uplo,
                                   bool compute_v,
                                   const Tensor& L,
                                   const Tensor& V) {
@@ -2826,22 +2826,22 @@ TORCH_IMPL_FUNC(_linalg_eigh_out)(const Tensor& A,
   at::_linalg_check_errors(info, "linalg.eigh", /*is_matrix*/A.dim() == 2);
 }
 
-std::tuple<Tensor, Tensor> linalg_eigh(const Tensor& A, c10::string_view uplo) {
+std::tuple<Tensor, Tensor> linalg_eigh(const Tensor& A, std::string_view uplo) {
   // TODO (Good intro task) Implement linalg_eigh_ex_out
   return at::_linalg_eigh(A, uplo, /*compute_v*/true);
 }
 
-std::tuple<Tensor&, Tensor&> linalg_eigh_out(const Tensor& A, c10::string_view uplo, Tensor& L, Tensor& V) {
+std::tuple<Tensor&, Tensor&> linalg_eigh_out(const Tensor& A, std::string_view uplo, Tensor& L, Tensor& V) {
   return at::_linalg_eigh_out(L, V, A, uplo, /*compute_v=*/true);
 }
 
 
-Tensor linalg_eigvalsh(const Tensor& A, c10::string_view uplo) {
+Tensor linalg_eigvalsh(const Tensor& A, std::string_view uplo) {
   return std::get<0>(at::_linalg_eigh(A, uplo,
                      /*compute_v=*/_may_require_fw_or_bw_grad(A)));
 }
 
-Tensor& linalg_eigvalsh_out(const Tensor& A, c10::string_view uplo, Tensor& L) {
+Tensor& linalg_eigvalsh_out(const Tensor& A, std::string_view uplo, Tensor& L) {
   auto V = at::empty({0}, A.options());
   at::_linalg_eigh_out(L, V, A, uplo, /*compute_v=*/false);
   return L;
@@ -3197,7 +3197,7 @@ DEFINE_DISPATCH(svd_stub);
 TORCH_IMPL_FUNC(_linalg_svd_out)(const Tensor& A,
                                  const bool full_matrices,
                                  const bool compute_uv,
-                                 std::optional<c10::string_view> driver,
+                                 std::optional<std::string_view> driver,
                                  const Tensor & U,
                                  const Tensor & S,
                                  const Tensor & Vh) {
@@ -3246,7 +3246,7 @@ TORCH_IMPL_FUNC(_linalg_svd_out)(const Tensor& A,
 std::tuple<Tensor&, Tensor&, Tensor&>
 linalg_svd_out(const Tensor& A,
                bool full_matrices,
-               std::optional<c10::string_view> driver,
+               std::optional<std::string_view> driver,
                Tensor & U,
                Tensor & S,
                Tensor & Vh) {
@@ -3265,12 +3265,12 @@ linalg_svd_out(const Tensor& A,
 }
 
 std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& A, bool full_matrices,
-    std::optional<c10::string_view> driver) {
+    std::optional<std::string_view> driver) {
   return at::_linalg_svd(A, full_matrices, /*compute_uv=*/true, driver);
 }
 
 // See note in linalg_svd for why this function does not have an _ex variant
-Tensor& linalg_svdvals_out(const Tensor& A, std::optional<c10::string_view> driver, Tensor & S) {
+Tensor& linalg_svdvals_out(const Tensor& A, std::optional<std::string_view> driver, Tensor & S) {
   // Dummies
   auto U = at::empty({0}, A.options());
   auto Vh = at::empty({0}, A.options());
@@ -3278,7 +3278,7 @@ Tensor& linalg_svdvals_out(const Tensor& A, std::optional<c10::string_view> driv
   return S;
 }
 
-Tensor linalg_svdvals(const Tensor& A, std::optional<c10::string_view> driver) {
+Tensor linalg_svdvals(const Tensor& A, std::optional<std::string_view> driver) {
   return std::get<1>(at::_linalg_svd(A, /*full_matrices=*/false,
                      /*compute_uv=*/_may_require_fw_or_bw_grad(A),
                      /*driver=*/driver));
@@ -3538,7 +3538,7 @@ static void linalg_lstsq_out_info(
   }
 }
 
-static std::string get_default_lstsq_driver(std::optional<c10::string_view> driver, const Tensor& input) {
+static std::string get_default_lstsq_driver(std::optional<std::string_view> driver, const Tensor& input) {
  // if `driver` is empty, we set driver_str to "gels" if working with CUDA tensors,
  // otherwise to "gelsy" driver.
  std::string driver_str;
@@ -3548,7 +3548,7 @@ static std::string get_default_lstsq_driver(std::optional<c10::string_view> driv
    // convert `driver_str` to lower case inplace.
    std::transform(driver_str.begin(), driver_str.end(), driver_str.begin(),
      [](unsigned char c) { return std::tolower(c); });
-    static std::unordered_set<c10::string_view> allowed_drivers = {
+    static std::unordered_set<std::string_view> allowed_drivers = {
      "gels", "gelsy", "gelsd", "gelss"
    };
    if (input.device() == at::kCPU) {
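A note on the `allowed_drivers` set above: string literals convert to `std::string_view` at the call site, so the membership test stays allocation-free. A minimal sketch of the same validation idiom, outside of ATen:

```cpp
#include <string_view>
#include <unordered_set>

// Same pattern as get_default_lstsq_driver: a static set of string_view keys
// checked against an already-lowercased driver name.
bool is_allowed_driver(std::string_view driver) {
  static const std::unordered_set<std::string_view> allowed = {
      "gels", "gelsy", "gelsd", "gelss"};
  return allowed.count(driver) > 0;
}
```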
@@ -3575,7 +3575,7 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> linalg_lstsq_out(
     const Tensor& input,
     const Tensor& other,
     std::optional<double> rcond,
-    std::optional<c10::string_view> driver,
+    std::optional<std::string_view> driver,
     Tensor& solution,
     Tensor& residuals,
     Tensor& rank,
@@ -3739,7 +3739,7 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> linalg_lstsq_out(
 std::tuple<Tensor, Tensor, Tensor, Tensor> linalg_lstsq(
     const Tensor& input, const Tensor& other,
     std::optional<double> rcond,
-    std::optional<c10::string_view> driver) {
+    std::optional<std::string_view> driver) {
   Tensor solution = at::empty({0}, input.options());
   Tensor residuals = at::empty({0}, input.options().dtype(toRealValueType(input.scalar_type())));
   Tensor rank = at::empty({0}, input.options().dtype(at::kLong));
@@ -304,7 +304,7 @@ using svd_fn = void (*)(
     const Tensor& /*A*/,
     const bool /*full_matrices*/,
    const bool /*compute_uv*/,
-    const std::optional<c10::string_view>& /*driver*/,
+    const std::optional<std::string_view>& /*driver*/,
     const Tensor& /*U*/,
     const Tensor& /*S*/,
     const Tensor& /*Vh*/,
@@ -609,7 +609,7 @@ void apply_lstsq(const Tensor& A, Tensor& B, Tensor& rank, Tensor& singular_valu
 // This is a type and driver dispatching helper function for 'apply_lstsq'
 void lstsq_kernel(const Tensor& a, Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, double rcond, std::string driver_name) {
 
-  static auto driver_string_to_type = std::unordered_map<c10::string_view, LapackLstsqDriverType>({
+  static auto driver_string_to_type = std::unordered_map<std::string_view, LapackLstsqDriverType>({
     {"gels", at::native::LapackLstsqDriverType::Gels},
     {"gelsy", at::native::LapackLstsqDriverType::Gelsy},
     {"gelsd", at::native::LapackLstsqDriverType::Gelsd},
@@ -1087,7 +1087,7 @@ static void apply_svd(const Tensor& A,
 void svd_kernel(const Tensor& A,
                 const bool full_matrices,
                 const bool compute_uv,
-                const std::optional<c10::string_view>& driver,
+                const std::optional<std::string_view>& driver,
                 const Tensor& U,
                 const Tensor& S,
                 const Tensor& Vh,
@@ -173,7 +173,7 @@ TORCH_META_FUNC2(div, Tensor) (const Tensor& self, const Tensor& other) {
   build_borrowing_binary_float_op(maybe_get_output(), self, other);
 }
 
-TORCH_META_FUNC2(div, Tensor_mode) (const Tensor& self, const Tensor& other, std::optional<c10::string_view> rounding_mode) {
+TORCH_META_FUNC2(div, Tensor_mode) (const Tensor& self, const Tensor& other, std::optional<std::string_view> rounding_mode) {
   if (!rounding_mode.has_value()) {
     build_borrowing_binary_float_op(maybe_get_output(), self, other);
   // NOLINTNEXTLINE(bugprone-branch-clone)
@@ -448,7 +448,7 @@ TORCH_IMPL_FUNC(div_out) (const Tensor& self, const Tensor& other, const Tensor&
 }
 
 TORCH_IMPL_FUNC(div_out_mode) (
-  const Tensor& self, const Tensor& other, std::optional<c10::string_view> rounding_mode, const Tensor& result
+  const Tensor& self, const Tensor& other, std::optional<std::string_view> rounding_mode, const Tensor& result
 ) {
   if (!rounding_mode.has_value()) {
     div_true_stub(device_type(), *this);
@@ -896,11 +896,11 @@ Tensor& div_(Tensor& self, const Scalar& other) {
   return self.div_(wrapped_scalar_tensor(other)); // redispatch!
 }
 
-Tensor div(const Tensor& self, const Scalar& other, std::optional<c10::string_view> rounding_mode) {
+Tensor div(const Tensor& self, const Scalar& other, std::optional<std::string_view> rounding_mode) {
   return self.div(wrapped_scalar_tensor(other), std::move(rounding_mode)); // redispatch!
 }
 
-Tensor& div_(Tensor& self, const Scalar& other, std::optional<c10::string_view> rounding_mode) {
+Tensor& div_(Tensor& self, const Scalar& other, std::optional<std::string_view> rounding_mode) {
   return self.div_(wrapped_scalar_tensor(other), std::move(rounding_mode)); // redispatch!
 }
 
@@ -925,23 +925,23 @@ Tensor& divide_(Tensor& self, const Scalar& other) {
   return self.div_(other);
 }
 
-Tensor& divide_out(const Tensor& self, const Tensor& other, std::optional<c10::string_view> rounding_mode, Tensor& result) {
+Tensor& divide_out(const Tensor& self, const Tensor& other, std::optional<std::string_view> rounding_mode, Tensor& result) {
   return at::div_out(result, self, other, std::move(rounding_mode));
 }
 
-Tensor divide(const Tensor& self, const Tensor& other, std::optional<c10::string_view> rounding_mode) {
+Tensor divide(const Tensor& self, const Tensor& other, std::optional<std::string_view> rounding_mode) {
   return self.div(other, std::move(rounding_mode));
 }
 
-Tensor& divide_(Tensor& self, const Tensor& other, std::optional<c10::string_view> rounding_mode) {
+Tensor& divide_(Tensor& self, const Tensor& other, std::optional<std::string_view> rounding_mode) {
   return self.div_(other, std::move(rounding_mode));
 }
 
-Tensor divide(const Tensor& self, const Scalar& other, std::optional<c10::string_view> rounding_mode) {
+Tensor divide(const Tensor& self, const Scalar& other, std::optional<std::string_view> rounding_mode) {
   return self.div(other, std::move(rounding_mode));
 }
 
-Tensor& divide_(Tensor& self, const Scalar& other, std::optional<c10::string_view> rounding_mode) {
+Tensor& divide_(Tensor& self, const Scalar& other, std::optional<std::string_view> rounding_mode) {
   return self.div_(other, std::move(rounding_mode));
 }
 
@@ -146,7 +146,7 @@ Tensor& searchsorted_out_cpu(
     const Tensor& self,
     bool out_int32,
     bool right,
-    const std::optional<c10::string_view> side_opt,
+    const std::optional<std::string_view> side_opt,
     const std::optional<Tensor>& sorter_opt,
     Tensor& result) {
   // See [Note: hacky wrapper removal for optional tensor]
@@ -193,7 +193,7 @@ Tensor& searchsorted_out_cpu(
     const Scalar& self,
     bool out_int32,
     bool right,
-    const std::optional<c10::string_view> side_opt,
+    const std::optional<std::string_view> side_opt,
     const std::optional<Tensor>& sorter_opt,
     Tensor& result) {
   const Tensor& scalar_tensor = searchsorted_scalar_tensor(self, sorted_sequence.device());
@@ -205,7 +205,7 @@ Tensor searchsorted_cpu(
     const Tensor& self,
     bool out_int32,
     bool right,
-    const std::optional<c10::string_view> side_opt,
+    const std::optional<std::string_view> side_opt,
     const std::optional<Tensor>& sorter_opt) {
   ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
   c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type);
@@ -219,7 +219,7 @@ Tensor searchsorted_cpu(
     const Scalar& self,
     bool out_int32,
     bool right,
-    const std::optional<c10::string_view> side_opt,
+    const std::optional<std::string_view> side_opt,
     const std::optional<Tensor>& sorter_opt) {
   const Tensor& scalar_tensor = searchsorted_scalar_tensor(self, sorted_sequence.device());
   return searchsorted_cpu(sorted_sequence, scalar_tensor, out_int32, right, side_opt, sorter_opt);
@@ -107,10 +107,10 @@ inline void searchsorted_pre_check(
     const Tensor& output,
     const bool out_int32,
     const bool right,
-    const std::optional<c10::string_view> side_opt,
+    const std::optional<std::string_view> side_opt,
     const Tensor& sorter) {
   if (side_opt) {
-    const c10::string_view side = *side_opt;
+    const std::string_view side = *side_opt;
     TORCH_CHECK(side == "left" || side == "right", "torch.searchsorted(): side can only be 'left' or 'right' but ",
       "got ", side);
 
@@ -885,7 +885,7 @@ at::Tensor complex_convolution_mode(
     const at::Tensor& weight,
     const std::optional<at::Tensor>& bias_opt,
     c10::SymIntArrayRef stride,
-    c10::string_view padding,
+    std::string_view padding,
     c10::SymIntArrayRef dilation,
     const c10::SymInt& groups) {
   auto bias = bias_opt.value_or(Tensor());
@@ -1055,7 +1055,7 @@ static Tensor convolution_same(
 
 Tensor _convolution_mode_symint(
     const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
-    SymIntArrayRef stride, c10::string_view padding, SymIntArrayRef dilation,
+    SymIntArrayRef stride, std::string_view padding, SymIntArrayRef dilation,
     c10::SymInt groups) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
@@ -1073,7 +1073,7 @@ Tensor _convolution_mode_symint(
 
 at::Tensor conv1d_padding_symint(
     const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
-    c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation,
+    c10::SymIntArrayRef stride, std::string_view padding, c10::SymIntArrayRef dilation,
     c10::SymInt groups) {
   auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
   Tensor output;
@@ -1087,7 +1087,7 @@ at::Tensor conv1d_padding_symint(
 
 at::Tensor conv2d_padding_symint(
     const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
-    c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation,
+    c10::SymIntArrayRef stride, std::string_view padding, c10::SymIntArrayRef dilation,
     c10::SymInt groups) {
   auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
   Tensor output;
@@ -1101,7 +1101,7 @@ at::Tensor conv2d_padding_symint(
 
 at::Tensor conv3d_padding_symint(
     const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
-    c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation,
+    c10::SymIntArrayRef stride, std::string_view padding, c10::SymIntArrayRef dilation,
     c10::SymInt groups) {
   auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
   Tensor output;
@@ -11,7 +11,7 @@ enum class GeluType {
   END
 };
 
-inline GeluType get_gelutype_enum(const c10::string_view approximate) {
+inline GeluType get_gelutype_enum(const std::string_view approximate) {
   if (approximate == "none") {
     return GeluType::None;
   } else if (approximate == "tanh") {
@@ -252,7 +252,7 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra
 // If a path is specified, we reduce in the order specified by the path, else we
 // default to going left => right. The path is a list of indices processed the same
 // way as opt-einsum: https://optimized-einsum.readthedocs.io/en/stable/path_finding.html#format-of-the-path
-Tensor einsum(c10::string_view equation, TensorList operands, at::OptionalIntArrayRef path) {
+Tensor einsum(std::string_view equation, TensorList operands, at::OptionalIntArrayRef path) {
   TORCH_CHECK(!operands.empty(), "einsum(): must provide at least one operand");
   const auto num_ops = operands.size();
 
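A hedged usage sketch of the einsum signature shown in this hunk (it assumes the C++ API defaults the optional `path` argument, matching the left-to-right reduction described in the comment above):

```cpp
#include <ATen/ATen.h>

// Matrix multiply expressed as einsum; omitting the path keeps the default
// left => right contraction order.
at::Tensor matmul_via_einsum(const at::Tensor& a, const at::Tensor& b) {
  return at::einsum("ij,jk->ik", {a, b});
}
```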
@@ -452,7 +452,7 @@ std::tuple<Tensor, Tensor> get_atol_rtol(
     const Tensor& input,
     const std::optional<Tensor>& atol_opt,
     const std::optional<Tensor>& rtol_opt,
-    const c10::string_view function_name) {
+    const std::string_view function_name) {
   auto options = input.options();
   if (input.device().type() == kMetal || input.device().type() == kMPS) {
     options = options.dtype(ScalarType::Float);
@@ -2950,7 +2950,7 @@ Tensor& linalg_matrix_norm_out(
 // fro / nuc
 Tensor linalg_matrix_norm(
     const Tensor& A,
-    c10::string_view ord,
+    std::string_view ord,
     IntArrayRef dim,
     bool keepdim,
     std::optional<ScalarType> opt_dtype) {
@@ -2979,7 +2979,7 @@ Tensor linalg_matrix_norm(
 
 Tensor& linalg_matrix_norm_out(
     const Tensor& A,
-    c10::string_view ord,
+    std::string_view ord,
     IntArrayRef dim,
     bool keepdim,
     std::optional<ScalarType> opt_dtype,
@@ -3032,7 +3032,7 @@ Tensor& linalg_norm_out(const Tensor& X, const std::optional<Scalar>& opt_ord, O
 }
 
 // Frobenius and nuclear norms
-Tensor linalg_norm(const Tensor& X, c10::string_view ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype) {
+Tensor linalg_norm(const Tensor& X, std::string_view ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype) {
   if (opt_dim.has_value()) {
     TORCH_CHECK(opt_dim->size() == 1 || opt_dim ->size() == 2, "linalg.norm: If ",
                 "dim is specified, it mut be of length 1 or 2. Got ", *opt_dim);
@@ -3045,7 +3045,7 @@ Tensor linalg_norm(const Tensor& X, c10::string_view ord, OptionalIntArrayRef op
   return at::linalg_matrix_norm(X, ord, dim, keepdim, opt_dtype);
 }
 
-Tensor& linalg_norm_out(const Tensor& X, c10::string_view ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype, Tensor& result) {
+Tensor& linalg_norm_out(const Tensor& X, std::string_view ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype, Tensor& result) {
   checkSameDevice("linalg.norm", X, result);
   auto out = at::linalg_norm(X, ord, opt_dim, keepdim, opt_dtype);
   TORCH_CHECK(out.scalar_type() == result.scalar_type(),
@@ -3144,7 +3144,7 @@ Tensor& nuclear_norm_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tens
 
 
 // This function helps to dispatch norm computations depending on 'ord' of variant type
-static Tensor _linalg_cond_helper(const Tensor& self, std::variant<Scalar, c10::string_view> ord_variant) {
+static Tensor _linalg_cond_helper(const Tensor& self, std::variant<Scalar, std::string_view> ord_variant) {
   Tensor inverse, info;
   std::tie(inverse, info) = at::linalg_inv_ex(self);
   info.unsqueeze_(-1).unsqueeze_(-1);
@@ -3167,14 +3167,14 @@ static Tensor _linalg_cond_empty_matrix(const Tensor& self, c10::ScalarType dtyp
   return at::zeros(result_shape, options);
 }
 
-static void _linalg_cond_check_ord(std::variant<Scalar, c10::string_view> ord_variant) {
+static void _linalg_cond_check_ord(std::variant<Scalar, std::string_view> ord_variant) {
   if (ord_variant.index() == 0) {
     Scalar* ord = std::get_if<Scalar>(&ord_variant);
     double abs_ord = std::abs(ord->toDouble());
     TORCH_CHECK(abs_ord == 2.0 || abs_ord == 1.0 || abs_ord == INFINITY,
       "linalg.cond got an invalid norm type: ", ord->toDouble());
   } else if (ord_variant.index() == 1) {
-    c10::string_view* ord = std::get_if<c10::string_view>(&ord_variant);
+    std::string_view* ord = std::get_if<std::string_view>(&ord_variant);
     TORCH_CHECK(*ord == "fro" || *ord == "nuc",
       "linalg.cond got an invalid norm type: ", *ord);
   } else {
@@ -3190,7 +3190,7 @@ Tensor linalg_cond(const Tensor& self, const std::optional<Scalar>& opt_ord) {
   // The default case is using 2-norm
   Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2;
 
-  std::variant<Scalar, c10::string_view> ord_variant = ord;
+  std::variant<Scalar, std::string_view> ord_variant = ord;
   _linalg_cond_check_ord(ord_variant);
 
   // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input
@@ -3236,9 +3236,9 @@ Tensor& linalg_cond_out(const Tensor& self, const std::optional<Scalar>& opt_ord
 }
 
 // Frobenius or nuclear norms
-Tensor linalg_cond(const Tensor& self, c10::string_view ord) {
+Tensor linalg_cond(const Tensor& self, std::string_view ord) {
   squareCheckInputs(self, ("linalg.cond(ord=" + std::string(ord) + ")").c_str());
-  std::variant<Scalar, c10::string_view> ord_variant = ord;
+  std::variant<Scalar, std::string_view> ord_variant = ord;
   _linalg_cond_check_ord(ord_variant);
 
   // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input
@@ -3258,7 +3258,7 @@ Tensor linalg_cond(const Tensor& self, c10::string_view ord) {
 }
 
 // TODO: implement _out variant avoiding copy and using already allocated storage directly
-Tensor& linalg_cond_out(const Tensor& self, c10::string_view ord, Tensor& result) {
+Tensor& linalg_cond_out(const Tensor& self, std::string_view ord, Tensor& result) {
   checkSameDevice("linalg.cond", result, self);
   ScalarType real_dtype = toRealValueType(self.scalar_type());
   checkLinalgCompatibleDtype("linalg.cond", result.scalar_type(), real_dtype);
@@ -3526,7 +3526,7 @@ Tensor _weight_int8pack_mm_cpu(
 
 Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result) {
 #ifndef STRIP_ERROR_MESSAGES
-  static constexpr c10::string_view func_name = "int_mm_out_cpu";
+  static constexpr std::string_view func_name = "int_mm_out_cpu";
 #endif
   TORCH_CHECK(self.dim() == 2, func_name, ": Expected self to be of dimension 2 but got ", self.dim());
   TORCH_CHECK(mat2.dim() == 2, func_name, ": Expected mat2 to be of dimension 2 but got ", mat2.dim());
@@ -369,7 +369,7 @@ inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
 }
 
 // parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
-inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
+inline std::tuple<bool, bool> _parse_qr_mode(std::string_view mode) {
   bool compute_q;
   bool reduced;
   if (mode == "reduced") {
@@ -485,7 +485,7 @@ inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
 
 // This function checks whether the uplo argument input is valid
 // Allowed strings are "u", "U", "l", "L"
-inline void checkUplo(const c10::string_view uplo) {
+inline void checkUplo(const std::string_view uplo) {
   // To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char
   char uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
   TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'),
@@ -524,7 +524,7 @@ inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType ou
               out_name, " with dtype ", out_type);
 }
 
-inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) {
+inline void checkNotComplexTolerance(const Tensor& tol, const std::string_view f_name, const std::string_view tol_name) {
   TORCH_CHECK(!at::isComplexType(tol.scalar_type()),
               f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol.scalar_type());
 }
@@ -202,7 +202,7 @@ std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor
   return std::make_tuple(output, lengths_t);
 }
 
-Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value, const c10::string_view padding_side) {
+Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value, const std::string_view padding_side) {
   const int64_t sequences_size = sequences.size();
   TORCH_CHECK(sequences_size > 0, "received an empty list of sequences");
   TORCH_CHECK(padding_side == "left" || padding_side == "right",
@@ -189,7 +189,7 @@ Tensor _pad_circular_symint(const Tensor &self, c10::SymIntArrayRef padding) {
   return out;
 }
 
-static c10::string_view padding_mode_string(padding_mode m) {
+static std::string_view padding_mode_string(padding_mode m) {
   switch (m) {
     case padding_mode::reflect:
       return "reflect";
@@ -244,7 +244,7 @@ Tensor _pad_enum_symint(const Tensor &self, c10::SymIntArrayRef pad, int64_t mod
           "Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now");
 }
 
-Tensor pad_symint(const Tensor &self, c10::SymIntArrayRef pad, c10::string_view mode, std::optional<double> value) {
+Tensor pad_symint(const Tensor &self, c10::SymIntArrayRef pad, std::string_view mode, std::optional<double> value) {
   const auto mode_enum = [&] {
     if (mode == "reflect") {
       return at::padding_mode::reflect;
@@ -60,7 +60,7 @@ signature.
 - `int`. Think about this like a Python int. This is translated into a C++ argument of type `int64_t`.
 - `float`. Think about this like a Python `float`. It is translated into a C++ argument of type `double`.
 - `bool`
-- `str`. It is translated into a C++ argument of non-owning type `c10::string_view`
+- `str`. It is translated into a C++ argument of non-owning type `std::string_view`
 - `Scalar`. `Scalar` supports binding to any numerical types from Python, including integral types,
   floating point types, and zero dimensional tensors. `int` and `float` bind to the corresponding Python
   numerical types. However, you probably don't want to use `Scalar`;
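Because `str` arrives as a non-owning `std::string_view`, the view is only valid for the duration of the call; a kernel that keeps the string around must copy it. A minimal sketch (hypothetical helper, not ATen code):

```cpp
#include <string>
#include <string_view>

// Copy a schema `str` argument before retaining it beyond the call.
std::string normalize_mode(std::string_view mode) {
  return std::string(mode);  // take ownership; the view's buffer is caller-owned
}
```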
@@ -6,7 +6,7 @@ namespace at::native {
 
 enum class ReductionType {MAX, MEAN, MIN, SUM, PROD};
 
-inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
+inline ReductionType get_reduction_enum(const std::string_view& reduce) {
   if (reduce == "max" || reduce == "amax") {
     return ReductionType::MAX;
   } else if (reduce == "mean") {
@@ -23,7 +23,7 @@ inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
 }
 
 // used for `scatter_reduce`, old options for BC.
-inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
+inline ReductionType get_operator_enum(const std::string_view reduce, bool use_new_options) {
   if (use_new_options) {
     return get_reduction_enum(reduce);
   } else {
@@ -385,7 +385,7 @@ Tensor _segment_reduce_cpu_offsets_backward_kernel(
 
 Tensor segment_reduce_kernel(
     const Tensor& data,
-    c10::string_view reduce,
+    std::string_view reduce,
     const std::optional<Tensor>& lengths,
     const std::optional<Tensor>& indices,
     const std::optional<Tensor>& offsets,
@@ -485,7 +485,7 @@ Tensor _segment_reduce_backward_kernel(
     const Tensor& grad,
     const Tensor& output,
     const Tensor& data,
-    c10::string_view reduce,
+    std::string_view reduce,
     const std::optional<Tensor>& lengths,
     const std::optional<Tensor>& offsets,
     int64_t axis,
@@ -185,7 +185,7 @@ void quick_select_template(
 namespace {
 
 QUANTILE_INTERPOLATION_MODE get_quantile_interpolation_mode(
-    const c10::string_view interpolation) {
+    const std::string_view interpolation) {
   if (interpolation == "linear") {
     return QUANTILE_INTERPOLATION_MODE::LINEAR;
   } else if (interpolation == "lower") {
@@ -655,7 +655,7 @@ Tensor& quantile_out(
     const Tensor& q,
     std::optional<int64_t> dim,
     bool keepdim,
-    const c10::string_view interpolation,
+    const std::string_view interpolation,
     Tensor& out) {
   quantile_out_impl(
       out,
@@ -673,7 +673,7 @@ Tensor& quantile_out(
     double q,
     std::optional<int64_t> dim,
     bool keepdim,
-    const c10::string_view interpolation,
+    const std::string_view interpolation,
     Tensor& out) {
   TORCH_CHECK(
       q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
@@ -691,7 +691,7 @@ Tensor quantile(
     const Tensor& q,
     std::optional<int64_t> dim,
     bool keepdim,
-    const c10::string_view interpolation) {
+    const std::string_view interpolation) {
   return quantile_impl(
       self,
       q,
@@ -706,7 +706,7 @@ Tensor quantile(
     double q,
    std::optional<int64_t> dim,
     bool keepdim,
-    const c10::string_view interpolation) {
+    const std::string_view interpolation) {
   TORCH_CHECK(
       q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
   return at::native::quantile(
@@ -718,7 +718,7 @@ Tensor& nanquantile_out(
     const Tensor& q,
     std::optional<int64_t> dim,
     bool keepdim,
-    const c10::string_view interpolation,
+    const std::string_view interpolation,
     Tensor& out) {
   quantile_out_impl(
       out,
@@ -736,7 +736,7 @@ Tensor& nanquantile_out(
     double q,
     std::optional<int64_t> dim,
     bool keepdim,
-    const c10::string_view interpolation,
+    const std::string_view interpolation,
     Tensor& out) {
   TORCH_CHECK(
       q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
@@ -754,7 +754,7 @@ Tensor nanquantile(
     const Tensor& q,
     std::optional<int64_t> dim,
     bool keepdim,
-    const c10::string_view interpolation) {
+    const std::string_view interpolation) {
   return quantile_impl(
       self,
       q,
@@ -769,7 +769,7 @@ Tensor nanquantile(
     double q,
     std::optional<int64_t> dim,
     bool keepdim,
-    const c10::string_view interpolation) {
+    const std::string_view interpolation) {
   TORCH_CHECK(
       q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
   return at::native::nanquantile(
@@ -114,7 +114,7 @@ Tensor promote_tensor_fft(const Tensor& t, bool require_complex=false) {
 // Convert NumPy compatible normalization mode string to enum values
 // NOTE: NumPy's normalization modes have direction-specific meanings. For example,
 // "forward" translates to `by_n` for a forward transform and `none` for backward.
-fft_norm_mode norm_from_string(std::optional<c10::string_view> norm, bool forward) {
+fft_norm_mode norm_from_string(std::optional<std::string_view> norm, bool forward) {
   if (!norm || *norm == "backward") {
     return forward ? fft_norm_mode::none : fft_norm_mode::by_n;
   }
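To make the direction-specific mapping above concrete, a standalone sketch of the same logic, under the assumption that `fft_norm_mode` has the three members used here (`none`, `by_root_n`, `by_n`):

```cpp
#include <optional>
#include <stdexcept>
#include <string_view>

enum class fft_norm_mode { none, by_root_n, by_n };

// Mirrors norm_from_string above: "backward" (or no norm) and "forward" swap
// meanings with the transform direction; "ortho" is direction-agnostic.
fft_norm_mode norm_from_string_sketch(std::optional<std::string_view> norm,
                                      bool forward) {
  if (!norm || *norm == "backward")
    return forward ? fft_norm_mode::none : fft_norm_mode::by_n;
  if (*norm == "forward")
    return forward ? fft_norm_mode::by_n : fft_norm_mode::none;
  if (*norm == "ortho")
    return fft_norm_mode::by_root_n;
  throw std::invalid_argument("invalid normalization mode");
}
```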
@@ -158,7 +158,7 @@ Tensor resize_fft_input(Tensor x, IntArrayRef dims, SymIntArrayRef sizes) {
 }
 
 Tensor fft_r2c_maybe_out(
-    c10::string_view fname, const Tensor& out, const Tensor& input,
+    std::string_view fname, const Tensor& out, const Tensor& input,
     IntArrayRef dim, int64_t norm, bool onesided) {
   if (out.defined()) {
     TORCH_CHECK(out.is_complex(), fname,
@@ -170,7 +170,7 @@ Tensor fft_r2c_maybe_out(
 }
 
 Tensor fft_c2r_maybe_out(
-    c10::string_view fname, const Tensor& out, const Tensor& input,
+    std::string_view fname, const Tensor& out, const Tensor& input,
     IntArrayRef dim, int64_t norm, SymInt last_dim_size) {
   // Support out argument if defined, otherwise call functional
   // variant so autograd works properly.
@@ -184,7 +184,7 @@ Tensor fft_c2r_maybe_out(
 }
 
 Tensor fft_c2c_maybe_out(
-    c10::string_view fname, const Tensor& out, const Tensor& input,
+    std::string_view fname, const Tensor& out, const Tensor& input,
     IntArrayRef dim, int64_t norm, bool forward) {
   if (out.defined()) {
     TORCH_CHECK(out.is_complex(), fname,
@@ -196,9 +196,9 @@ Tensor fft_c2c_maybe_out(
 }
 
 // Complex to real FFT
-Tensor fft_c2r(c10::string_view function_name,
+Tensor fft_c2r(std::string_view function_name,
                Tensor out, Tensor input, std::optional<SymInt> n_opt,
-               int64_t unwrapped_dim, std::optional<c10::string_view> norm_str,
+               int64_t unwrapped_dim, std::optional<std::string_view> norm_str,
                bool forward) {
   TORCH_CHECK(!out.defined() || out.is_floating_point(), function_name,
               " expects a floating point output tensor, but got ", out.scalar_type());
@@ -220,9 +220,9 @@ Tensor fft_c2r(c10::string_view function_name,
 }
 
 // Real to complex FFT
-Tensor fft_r2c(c10::string_view function_name,
+Tensor fft_r2c(std::string_view function_name,
               Tensor out, Tensor input, std::optional<SymInt> n_opt,
-               int64_t unwrapped_dim, std::optional<c10::string_view> norm_str,
+               int64_t unwrapped_dim, std::optional<std::string_view> norm_str,
               bool forward, bool onesided) {
   TORCH_CHECK(!input.is_complex(), function_name,
               " expects a real input tensor, but got ", input.scalar_type());
@@ -255,9 +255,9 @@ Tensor fft_r2c(c10::string_view function_name,
 }
 
 // Complex to complex FFT
-Tensor fft_c2c(c10::string_view function_name,
+Tensor fft_c2c(std::string_view function_name,
                Tensor out, Tensor input, std::optional<SymInt> n_opt,
-               int64_t unwrapped_dim, std::optional<c10::string_view> norm_str,
+               int64_t unwrapped_dim, std::optional<std::string_view> norm_str,
                bool forward) {
   TORCH_CHECK(input.is_complex(), function_name,
               " expects a complex input tensor, but got ", input.scalar_type());
@@ -344,13 +344,13 @@ ShapeAndDims canonicalize_fft_shape_and_dim_args(
 
 // Complex to complex n-dimensional fft
 Tensor fftn_c2c(
-    c10::string_view function_name,
+    std::string_view function_name,
     Tensor out, const Tensor& input, SymIntArrayRef shape,
-    IntArrayRef dim, std::optional<c10::string_view> norm_str, bool forward) {
+    IntArrayRef dim, std::optional<std::string_view> norm_str, bool forward) {
   TORCH_CHECK(input.is_complex(), function_name, " expects a complex input tensor, but got", input.scalar_type());
   Tensor x = resize_fft_input(input, dim, shape);
   const auto norm = static_cast<int64_t>(norm_from_string(norm_str, forward));
-  constexpr c10::string_view fname = "fftn";
+  constexpr std::string_view fname = "fftn";
   return fft_c2c_maybe_out(fname, out, x, dim, norm, forward);
 }
 
@@ -358,14 +358,14 @@ Tensor fftn_c2c(
 
 // torch.fft.fft, analogous to NumPy's numpy.fft.fft
 Tensor fft_fft_symint(const Tensor& self, std::optional<SymInt> n, int64_t dim,
-                      std::optional<c10::string_view> norm) {
+                      std::optional<std::string_view> norm) {
   return self.is_complex() ?
     fft_c2c("fft", {}, self, n, dim, norm, /*forward=*/true) :
     fft_r2c("fft", {}, self, n, dim, norm, /*forward=*/true, /*onesided=*/false);
 }
 
 Tensor& fft_fft_symint_out(const Tensor& self, std::optional<SymInt> n,
-                           int64_t dim, std::optional<c10::string_view> norm, Tensor& out) {
+                           int64_t dim, std::optional<std::string_view> norm, Tensor& out) {
   if (self.is_complex()) {
     fft_c2c("fft", out, self, n, dim, norm, /*forward=*/true);
   } else {
@@ -375,14 +375,14 @@ Tensor& fft_fft_symint_out(const Tensor& self, std::optional<SymInt> n,
 }
 
 Tensor fft_ifft_symint(const Tensor& self, std::optional<SymInt> n, int64_t dim,
-                       std::optional<c10::string_view> norm) {
+                       std::optional<std::string_view> norm) {
   return self.is_complex() ?
     fft_c2c("ifft", {}, self, n, dim, norm, /*forward=*/false) :
     fft_r2c("ifft", {}, self, n, dim, norm, /*forward=*/false, /*onesided=*/false);
 }
 
 Tensor& fft_ifft_symint_out(const Tensor& self, std::optional<SymInt> n,
-                            int64_t dim, std::optional<c10::string_view> norm, Tensor& out) {
+                            int64_t dim, std::optional<std::string_view> norm, Tensor& out) {
   if (self.is_complex()) {
     fft_c2c("ifft", out, self, n, dim, norm, /*forward=*/false);
   } else {
@@ -392,52 +392,52 @@ Tensor& fft_ifft_symint_out(const Tensor& self, std::optional<SymInt> n,
 }
 
 Tensor fft_rfft_symint(const Tensor& self, std::optional<SymInt> n, int64_t dim,
-                       std::optional<c10::string_view> norm) {
+                       std::optional<std::string_view> norm) {
   return fft_r2c("rfft", {}, self, n, dim, norm, /*forward=*/true, /*onesided=*/true);
 }
 
 Tensor& fft_rfft_symint_out(const Tensor& self, std::optional<SymInt> n,
-                            int64_t dim, std::optional<c10::string_view> norm, Tensor& out) {
+                            int64_t dim, std::optional<std::string_view> norm, Tensor& out) {
   fft_r2c("rfft", out, self, n, dim, norm, /*forward=*/true, /*onesided=*/true);
   return out;
 }
 
 Tensor fft_irfft_symint(const Tensor& self, std::optional<SymInt> n, int64_t dim,
-                        std::optional<c10::string_view> norm) {
+                        std::optional<std::string_view> norm) {
   return fft_c2r("irfft", {}, self, n, dim, norm, /*forward=*/false);
 }
 
 Tensor& fft_irfft_symint_out(const Tensor& self, std::optional<SymInt> n,
-                             int64_t dim, std::optional<c10::string_view> norm, Tensor& out) {
+                             int64_t dim, std::optional<std::string_view> norm, Tensor& out) {
   fft_c2r("irfft", out, self, n, dim, norm, /*forward=*/false);
   return out;
 }
 
 Tensor fft_hfft_symint(const Tensor& self, std::optional<SymInt> n, int64_t dim,
-                       std::optional<c10::string_view> norm) {
+                       std::optional<std::string_view> norm) {
   return fft_c2r("hfft", {}, self, n, dim, norm, /*forward=*/true);
 }
 
 Tensor& fft_hfft_symint_out(const Tensor& self, std::optional<SymInt> n,
-                            int64_t dim, std::optional<c10::string_view> norm, Tensor& out) {
+                            int64_t dim, std::optional<std::string_view> norm, Tensor& out) {
   fft_c2r("hfft", out, self, n, dim, norm, /*forward=*/true);
   return out;
 }
 
 Tensor fft_ihfft_symint(const Tensor& self, std::optional<SymInt> n, int64_t dim,
-                        std::optional<c10::string_view> norm) {
+                        std::optional<std::string_view> norm) {
   return fft_r2c("ihfft", {}, self, n, dim, norm, /*forward=*/false, /*onesided=*/true);
 }
 
 Tensor& fft_ihfft_symint_out(const Tensor& self, std::optional<SymInt> n,
-                             int64_t dim, std::optional<c10::string_view> norm, Tensor& out) {
+                             int64_t dim, std::optional<std::string_view> norm, Tensor& out) {
   fft_r2c("ihfft", out, self, n, dim, norm, /*forward=*/false, /*onesided=*/true);
   return out;
 }
 
 Tensor fft_fftn_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
                        at::OptionalIntArrayRef dim,
-                       std::optional<c10::string_view> norm) {
+                       std::optional<std::string_view> norm) {
   auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
   // TODO: For real input, perform rfftn then mirror with conjugate symmetry
   Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
@@ -447,7 +447,7 @@ Tensor fft_fftn_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
 Tensor& fft_fftn_symint_out(const Tensor& self,
                             at::OptionalSymIntArrayRef s,
                             at::OptionalIntArrayRef dim,
-                            std::optional<c10::string_view> norm, Tensor& out) {
+                            std::optional<std::string_view> norm, Tensor& out) {
   auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
   // TODO: For real input, perform rfftn then mirror with conjugate symmetry
   Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
@@ -457,7 +457,7 @@ Tensor& fft_fftn_symint_out(const Tensor& self,
 
 Tensor fft_ifftn_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
                         at::OptionalIntArrayRef dim,
-                        std::optional<c10::string_view> norm) {
+                        std::optional<std::string_view> norm) {
   auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
   Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
   return fftn_c2c("ifftn", {}, input, desc.shape, desc.dim, norm, /*forward=*/false);
@@ -466,7 +466,7 @@ Tensor fft_ifftn_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
 Tensor& fft_ifftn_symint_out(const Tensor& self,
                              at::OptionalSymIntArrayRef s,
                              at::OptionalIntArrayRef dim,
-                             std::optional<c10::string_view> norm, Tensor& out) {
+                             std::optional<std::string_view> norm, Tensor& out) {
   auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
   Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
   fftn_c2c("ifftn", out, input, desc.shape, desc.dim, norm, /*forward=*/false);
@@ -476,33 +476,33 @@ Tensor& fft_ifftn_symint_out(const Tensor& self,
 static Tensor fft_rfftn_impl(Tensor out, const Tensor& self,
                              at::OptionalSymIntArrayRef s,
                              at::OptionalIntArrayRef dim,
-                             const std::optional<c10::string_view>& norm_str) {
+                             const std::optional<std::string_view>& norm_str) {
   TORCH_CHECK(!self.is_complex(), "rfftn expects a real-valued input tensor, but got ", self.scalar_type());
   auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
   TORCH_CHECK(!desc.shape.empty(), "rfftn must transform at least one axis");
|
||||
Tensor input = promote_tensor_fft(self, /*require_complex=*/false);
|
||||
Tensor x = resize_fft_input(input, desc.dim, desc.shape);
|
||||
const auto norm = static_cast<int64_t>(norm_from_string(norm_str, /*forward=*/true));
|
||||
constexpr c10::string_view fname = "rfftn";
|
||||
constexpr std::string_view fname = "rfftn";
|
||||
return fft_r2c_maybe_out(fname, out, x, desc.dim, norm, /*onesided=*/true);
|
||||
}
|
||||
|
||||
Tensor fft_rfftn_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
|
||||
at::OptionalIntArrayRef dim,
|
||||
std::optional<c10::string_view> norm_str) {
|
||||
std::optional<std::string_view> norm_str) {
|
||||
return fft_rfftn_impl({}, self, s, dim, norm_str);
|
||||
}
|
||||
|
||||
Tensor& fft_rfftn_symint_out(const Tensor& self,
|
||||
at::OptionalSymIntArrayRef s,
|
||||
at::OptionalIntArrayRef dim,
|
||||
std::optional<c10::string_view> norm_str, Tensor& out) {
|
||||
std::optional<std::string_view> norm_str, Tensor& out) {
|
||||
fft_rfftn_impl(out, self, s, dim, norm_str);
|
||||
return out;
|
||||
}
|
||||
|
||||
static ShapeAndDims canonicalize_fft_c2r_shape_and_dim_args(
|
||||
c10::string_view fname, const Tensor& self,
|
||||
std::string_view fname, const Tensor& self,
|
||||
const at::OptionalSymIntArrayRef& s,
|
||||
const at::OptionalIntArrayRef& dims,
|
||||
SymInt& last_dim_size) {
|
||||
@ -528,28 +528,28 @@ static ShapeAndDims canonicalize_fft_c2r_shape_and_dim_args(
|
||||
static Tensor fft_irfftn_impl(Tensor out, const Tensor& self,
|
||||
at::OptionalSymIntArrayRef s,
|
||||
at::OptionalIntArrayRef dim,
|
||||
const std::optional<c10::string_view>& norm_str) {
|
||||
const std::optional<std::string_view>& norm_str) {
|
||||
SymInt last_dim_size = 0;
|
||||
auto desc = canonicalize_fft_c2r_shape_and_dim_args(
|
||||
"irfftn", self, s, dim, last_dim_size);
|
||||
Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
|
||||
Tensor x = resize_fft_input(input, desc.dim, desc.shape);
|
||||
const auto norm = static_cast<int64_t>(norm_from_string(norm_str, /*forward=*/false));
|
||||
constexpr c10::string_view fname = "irfftn";
|
||||
constexpr std::string_view fname = "irfftn";
|
||||
return fft_c2r_maybe_out(fname, out, x, desc.dim, norm, last_dim_size);
|
||||
}
|
||||
|
||||
Tensor fft_irfftn_symint(const Tensor& self,
|
||||
at::OptionalSymIntArrayRef s,
|
||||
at::OptionalIntArrayRef dim,
|
||||
std::optional<c10::string_view> norm_str) {
|
||||
std::optional<std::string_view> norm_str) {
|
||||
return fft_irfftn_impl({}, self, s, dim, norm_str);
|
||||
}
|
||||
|
||||
Tensor& fft_irfftn_symint_out(const Tensor& self,
|
||||
at::OptionalSymIntArrayRef s,
|
||||
at::OptionalIntArrayRef dim,
|
||||
std::optional<c10::string_view> norm_str, Tensor& out) {
|
||||
std::optional<std::string_view> norm_str, Tensor& out) {
|
||||
fft_irfftn_impl(out, self, s, dim, norm_str);
|
||||
return out;
|
||||
}
|
||||
@ -558,9 +558,9 @@ static Tensor fft_hfftn_impl(
|
||||
const Tensor& self,
|
||||
at::OptionalSymIntArrayRef s,
|
||||
at::OptionalIntArrayRef dim,
|
||||
std::optional<c10::string_view> norm_str,
|
||||
std::optional<std::string_view> norm_str,
|
||||
const Tensor& out) {
|
||||
constexpr c10::string_view fname = "hfftn";
|
||||
constexpr std::string_view fname = "hfftn";
|
||||
SymInt last_dim_size = 0;
|
||||
auto desc = canonicalize_fft_c2r_shape_and_dim_args(
|
||||
fname, self, s, dim, last_dim_size);
|
||||
@ -586,14 +586,14 @@ Tensor fft_hfftn_symint(
|
||||
const Tensor& self,
|
||||
at::OptionalSymIntArrayRef s,
|
||||
at::OptionalIntArrayRef dim,
|
||||
std::optional<c10::string_view> norm) {
|
||||
std::optional<std::string_view> norm) {
|
||||
return fft_hfftn_impl(self, s, dim, norm, {});
|
||||
}
|
||||
|
||||
const Tensor& fft_hfftn_symint_out(
|
||||
const Tensor& self,
|
||||
at::OptionalSymIntArrayRef s,
|
||||
at::OptionalIntArrayRef dim, std::optional<c10::string_view> norm,
|
||||
at::OptionalIntArrayRef dim, std::optional<std::string_view> norm,
|
||||
const Tensor& out) {
|
||||
fft_hfftn_impl(self, s, dim, norm, out);
|
||||
return out;
|
||||
@ -603,9 +603,9 @@ static Tensor fft_ihfftn_impl(
|
||||
const Tensor& self,
|
||||
const at::OptionalSymIntArrayRef& s,
|
||||
const at::OptionalIntArrayRef& dim,
|
||||
const std::optional<c10::string_view>& norm_str,
|
||||
const std::optional<std::string_view>& norm_str,
|
||||
const Tensor& out) {
|
||||
constexpr c10::string_view fname = "ihfftn";
|
||||
constexpr std::string_view fname = "ihfftn";
|
||||
auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
|
||||
TORCH_CHECK(!desc.shape.empty(), "ihfftn must transform at least one axis");
|
||||
auto input = promote_tensor_fft(self, /*require_complex=*/false);
|
||||
@ -628,7 +628,7 @@ Tensor fft_ihfftn_symint(
|
||||
const Tensor& self,
|
||||
at::OptionalSymIntArrayRef s,
|
||||
at::OptionalIntArrayRef dim,
|
||||
std::optional<c10::string_view> norm) {
|
||||
std::optional<std::string_view> norm) {
|
||||
return fft_ihfftn_impl(self, s, dim, norm, {});
|
||||
}
|
||||
|
||||
@ -636,71 +636,71 @@ const Tensor& fft_ihfftn_symint_out(
|
||||
const Tensor& self,
|
||||
at::OptionalSymIntArrayRef s,
|
||||
at::OptionalIntArrayRef dim,
|
||||
std::optional<c10::string_view> norm,
|
||||
std::optional<std::string_view> norm,
|
||||
const Tensor& out) {
|
||||
fft_ihfftn_impl(self, s, dim, norm, out);
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor fft_fft2_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
|
||||
IntArrayRef dim, std::optional<c10::string_view> norm) {
|
||||
IntArrayRef dim, std::optional<std::string_view> norm) {
|
||||
return native::fft_fftn_symint(self, s, dim, std::move(norm));
|
||||
}
|
||||
|
||||
Tensor& fft_fft2_symint_out(const Tensor& self, at::OptionalSymIntArrayRef s,
|
||||
IntArrayRef dim, std::optional<c10::string_view> norm, Tensor& out) {
|
||||
IntArrayRef dim, std::optional<std::string_view> norm, Tensor& out) {
|
||||
return native::fft_fftn_symint_out(self, s, dim, std::move(norm), out);
|
||||
}
|
||||
|
||||
Tensor fft_ifft2_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
|
||||
IntArrayRef dim, std::optional<c10::string_view> norm) {
|
||||
IntArrayRef dim, std::optional<std::string_view> norm) {
|
||||
return native::fft_ifftn_symint(self, s, dim, std::move(norm));
|
||||
}
|
||||
|
||||
Tensor& fft_ifft2_symint_out(const Tensor& self, at::OptionalSymIntArrayRef s,
|
||||
IntArrayRef dim, std::optional<c10::string_view> norm, Tensor& out) {
|
||||
IntArrayRef dim, std::optional<std::string_view> norm, Tensor& out) {
|
||||
return native::fft_ifftn_symint_out(self, s, dim, std::move(norm), out);
|
||||
}
|
||||
|
||||
Tensor fft_rfft2_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
|
||||
IntArrayRef dim, std::optional<c10::string_view> norm) {
|
||||
IntArrayRef dim, std::optional<std::string_view> norm) {
|
||||
return native::fft_rfftn_symint(self, s, dim, std::move(norm));
|
||||
}
|
||||
|
||||
Tensor& fft_rfft2_symint_out(const Tensor& self, at::OptionalSymIntArrayRef s,
|
||||
IntArrayRef dim, std::optional<c10::string_view> norm, Tensor& out) {
|
||||
IntArrayRef dim, std::optional<std::string_view> norm, Tensor& out) {
|
||||
return native::fft_rfftn_symint_out(self, s, dim, std::move(norm), out);
|
||||
}
|
||||
|
||||
Tensor fft_irfft2_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
|
||||
IntArrayRef dim, std::optional<c10::string_view> norm) {
|
||||
IntArrayRef dim, std::optional<std::string_view> norm) {
|
||||
return native::fft_irfftn_symint(self, s, dim, std::move(norm));
|
||||
}
|
||||
|
||||
Tensor& fft_irfft2_symint_out(const Tensor& self, at::OptionalSymIntArrayRef s,
|
||||
IntArrayRef dim, std::optional<c10::string_view> norm, Tensor& out) {
|
||||
IntArrayRef dim, std::optional<std::string_view> norm, Tensor& out) {
|
||||
return native::fft_irfftn_symint_out(self, s, dim, std::move(norm), out);
|
||||
}
|
||||
|
||||
const Tensor& fft_hfft2_symint_out(
|
||||
const Tensor& self, at::OptionalSymIntArrayRef s, IntArrayRef dim,
|
||||
std::optional<c10::string_view> norm, const Tensor& out) {
|
||||
std::optional<std::string_view> norm, const Tensor& out) {
|
||||
return native::fft_hfftn_symint_out(self, s, dim, std::move(norm), out);
|
||||
}
|
||||
|
||||
Tensor fft_hfft2_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
|
||||
IntArrayRef dim, std::optional<c10::string_view> norm) {
|
||||
IntArrayRef dim, std::optional<std::string_view> norm) {
|
||||
return native::fft_hfftn_symint(self, s, dim, std::move(norm));
|
||||
}
|
||||
|
||||
const Tensor& fft_ihfft2_symint_out(
|
||||
const Tensor& self, at::OptionalSymIntArrayRef s, IntArrayRef dim,
|
||||
std::optional<c10::string_view> norm, const Tensor& out) {
|
||||
std::optional<std::string_view> norm, const Tensor& out) {
|
||||
return native::fft_ihfftn_symint_out(self, s, dim, std::move(norm), out);
|
||||
}
|
||||
|
||||
Tensor fft_ihfft2_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
|
||||
IntArrayRef dim, std::optional<c10::string_view> norm) {
|
||||
IntArrayRef dim, std::optional<std::string_view> norm) {
|
||||
return native::fft_ihfftn_symint(self, s, dim, std::move(norm));
|
||||
}
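The `norm` migration above is mechanical at call sites: a string literal binds to `std::optional<std::string_view>` exactly as it did to the c10 alias. A minimal sketch, assuming a hypothetical `norm_id` helper (the real dispatch lives in `norm_from_string`, not shown here):

#include <cstdint>
#include <optional>
#include <string_view>

// Hypothetical stand-in for norm_from_string: maps the user-facing norm
// strings to an integer code. std::nullopt means "backward", the default.
static int64_t norm_id(std::optional<std::string_view> norm) {
  if (!norm.has_value() || *norm == "backward") return 0;
  if (*norm == "forward") return 1;
  if (*norm == "ortho") return 2;
  return -1;  // invalid
}

int main() {
  // A const char* literal converts implicitly, so call sites are unchanged.
  return (norm_id("ortho") == 2 && norm_id(std::nullopt) == 0) ? 0 : 1;
}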

@ -825,7 +825,7 @@ static Stream& write_opt(Stream& SS, const std::optional<T>& value) {
 */
Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t> hop_lengthOpt,
    const std::optional<int64_t> win_lengthOpt, const std::optional<Tensor>& window_opt,
    const bool center, c10::string_view mode, const bool normalized,
    const bool center, std::string_view mode, const bool normalized,
    const std::optional<bool> onesidedOpt, const std::optional<bool> return_complexOpt) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> window_maybe_owned = at::borrow_from_optional_tensor(window_opt);
|
@ -191,7 +191,7 @@ void scatter_meta_impl(
    int64_t dim,
    const Tensor& index,
    const std::optional<Tensor>& src = std::nullopt,
    const std::optional<c10::string_view> reduce = std::nullopt) {
    const std::optional<std::string_view> reduce = std::nullopt) {
  int64_t wrapped_dim = at::maybe_wrap_dim(dim, self.dim());
  at::native::scatter_gather_dtype_check("scatter", self, index, src);
  at::native::scatter_shape_check(self, wrapped_dim, index, src);
@ -227,7 +227,7 @@ TORCH_META_FUNC2(scatter, reduce)
    int64_t dim,
    const Tensor& index,
    const Tensor& src,
    const c10::string_view reduce) {
    const std::string_view reduce) {
  TORCH_WARN_ONCE(
      "The reduce argument of torch.scatter with Tensor src is deprecated and will be removed ",
      "in a future PyTorch release. Use torch.scatter_reduce instead for more reduction options."
@ -240,7 +240,7 @@ TORCH_META_FUNC2(scatter, value_reduce)
    int64_t dim,
    const Tensor& index,
    const Scalar& src,
    const c10::string_view reduce) {
    const std::string_view reduce) {
  scatter_meta_impl(*this, self, dim, index, std::nullopt, reduce);
}

@ -254,7 +254,7 @@ TORCH_META_FUNC2(scatter_reduce, two)
    int64_t dim,
    const Tensor& index,
    const Tensor& src,
    const c10::string_view reduce,
    const std::string_view reduce,
    bool include_self) {
  (void) include_self;
  scatter_meta_impl</*use_new_options=*/true>(*this, self, dim, index, src, reduce);
@ -326,7 +326,7 @@ void index_func_meta_impl(
    int64_t dim,
    const Tensor& index,
    const Tensor& source,
    c10::string_view func) {
    std::string_view func) {
  auto numel = index.numel();

  TORCH_CHECK_INDEX(index.dim() <= 1, func, "_(): Index is supposed to be a vector, but got dim: ",
@ -387,7 +387,7 @@ TORCH_PRECOMPUTE_META_FUNC(index_reduce)
    int64_t dim,
    const Tensor& index,
    const Tensor& source,
    const c10::string_view reduce,
    const std::string_view reduce,
    bool include_self) {
  (void)include_self;
  TORCH_CHECK(reduce == "prod" || reduce == "mean" || reduce == "amax" || reduce == "amin",
@ -1212,7 +1212,7 @@ TORCH_IMPL_FUNC(index_reduce_cpu_out)
    int64_t dim,
    const Tensor& index,
    const Tensor& source,
    const c10::string_view reduce,
    const std::string_view reduce,
    bool include_input,
    const Tensor& result) {
  TORCH_WARN_ONCE("index_reduce() is in beta and the API may change at any time.");
@ -1796,7 +1796,7 @@ void scatter_impl(
    const Tensor& out,
    ReduceStub& reduce_stub,
    FillStub& fill_stub,
    const std::optional<c10::string_view> reduce = std::nullopt,
    const std::optional<std::string_view> reduce = std::nullopt,
    bool reduce_includes_self = true) {

  dim = at::maybe_wrap_dim(dim, self.dim());
@ -1865,7 +1865,7 @@ TORCH_IMPL_FUNC(scatter_reduce_out)
    int64_t dim,
    const Tensor& index,
    const Tensor& src,
    const c10::string_view reduce,
    const std::string_view reduce,
    const Tensor& out) {
  scatter_impl(self, dim, index, src, out,
      scatter_reduce_stub,
@ -1878,7 +1878,7 @@ TORCH_IMPL_FUNC(scatter_value_reduce_out)
    int64_t dim,
    const Tensor& index,
    const Scalar& value,
    const c10::string_view reduce,
    const std::string_view reduce,
    const Tensor& out) {
  scatter_impl(self, dim, index, value, out,
      scatter_scalar_reduce_stub,
@ -1919,7 +1919,7 @@ TORCH_IMPL_FUNC(scatter_reduce_two)
    int64_t dim,
    const Tensor& index,
    const Tensor& src,
    const c10::string_view reduce,
    const std::string_view reduce,
    bool include_self,
    const Tensor& out) {
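One reason this migration is a no-op for callers: std::string_view compares against string literals without allocating, so checks like the TORCH_CHECK on `reduce` above are unchanged. An illustrative, standard-library-only sketch (not the ATen code itself):

#include <cassert>
#include <string_view>

// Mirrors the index_reduce validation above: only these four reductions pass.
static bool is_valid_reduce(std::string_view reduce) {
  return reduce == "prod" || reduce == "mean" ||
         reduce == "amax" || reduce == "amin";
}

int main() {
  assert(is_valid_reduce("mean"));
  assert(!is_valid_reduce("median"));
  return 0;
}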
|
@ -421,28 +421,28 @@ void _assert_async_cpu(const Tensor& self) {
  TORCH_CHECK(native::is_nonzero(self), "Expected Tensor with single nonzero value, but got zero");
}

void _assert_async_msg_cpu(const Tensor& self, c10::string_view assert_msg) {
void _assert_async_msg_cpu(const Tensor& self, std::string_view assert_msg) {
  TORCH_CHECK(native::is_nonzero(self), assert_msg != "" ? assert_msg : "Assertion is failed");
}

void _assert_scalar(const Scalar& scalar, c10::string_view assert_msg) {
void _assert_scalar(const Scalar& scalar, std::string_view assert_msg) {
  TORCH_SYM_CHECK(scalar.toSymBool(), assert_msg != "" ? assert_msg : "Assertion is failed");
}

Tensor _functional_assert_scalar(const Scalar& scalar, c10::string_view assert_msg, const Tensor& dep_token) {
Tensor _functional_assert_scalar(const Scalar& scalar, std::string_view assert_msg, const Tensor& dep_token) {
  _assert_scalar(scalar, assert_msg);
  return dep_token.clone();
}

Tensor _functional_assert_async_msg_cpu(
    const Tensor& self,
    c10::string_view assert_msg,
    std::string_view assert_msg,
    const Tensor& dep_token) {
  _assert_async_msg_cpu(self, assert_msg);
  return dep_token.clone();
}

void _print(c10::string_view s) {
void _print(std::string_view s) {
  std::cout << s << "\n";
}
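`_print` needs no further adjustment because the standard library has streamed std::string_view since C++17. A self-contained sketch, assuming nothing beyond the standard headers:

#include <iostream>
#include <string_view>

// operator<<(std::ostream&, std::string_view) is standard since C++17,
// so a string_view parameter streams just like a std::string would.
static void print_view(std::string_view s) {
  std::cout << s << "\n";
}

int main() {
  print_view("hello from a string_view");
  return 0;
}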
|
@ -1714,7 +1714,7 @@ Tensor tensor_complex_backend(ArrayRef<T> values, const TensorOptions& options)
  return at::detail::tensor_complex_backend(values, options);
}

Tensor from_file(c10::string_view filename, std::optional<bool> shared, std::optional<int64_t> size,
Tensor from_file(std::string_view filename, std::optional<bool> shared, std::optional<int64_t> size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
|
@ -3614,7 +3614,7 @@ std::vector<Tensor> meshgrid(TensorList tensors) {
}

std::vector<Tensor> meshgrid(TensorList tensors,
    c10::string_view indexing) {
    std::string_view indexing) {
  int64_t size = tensors.size();
  TORCH_CHECK(size > 0, "meshgrid expects a non-empty TensorList");
|
@ -64,8 +64,8 @@ Tensor _test_optional_floatlist(
}

// Test default strings can handle escape sequences properly (although commas are broken)
Tensor _test_string_default(const Tensor& dummy, c10::string_view a, c10::string_view b) {
  const c10::string_view expect = "\"'\\";
Tensor _test_string_default(const Tensor& dummy, std::string_view a, std::string_view b) {
  const std::string_view expect = "\"'\\";
  TORCH_CHECK(a == expect, "Default A failed");
  TORCH_CHECK(b == expect, "Default B failed");
  return dummy;
@ -82,7 +82,7 @@ Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, int64_t b) {
}

// Overload b
Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, c10::string_view b) {
Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, std::string_view b) {
  TORCH_CHECK(a == 2);
  TORCH_CHECK(b == "2");
  return c10::scalar_to_tensor(2);
|
@ -94,13 +94,13 @@ std::tuple<Tensor, Tensor> log_sigmoid_forward_cuda(const Tensor& input) {
}

TORCH_IMPL_FUNC(gelu_out_cuda) (
    const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*result*/
    const Tensor& /*self*/, std::string_view approximate, const Tensor& /*result*/
) {
  GeluCUDAKernelImpl(*this, get_gelutype_enum(approximate));
}

TORCH_IMPL_FUNC(gelu_backward_out_cuda) (
    const Tensor& /*grad*/, const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*grad_input*/
    const Tensor& /*grad*/, const Tensor& /*self*/, std::string_view approximate, const Tensor& /*grad_input*/
) {
  GeluBackwardCUDAKernelImpl(*this, get_gelutype_enum(approximate));
}
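The `approximate` argument is a plain string dispatched to an enum; get_gelutype_enum is defined elsewhere in ATen, so the version below is only a hedged stand-in for the pattern:

#include <stdexcept>
#include <string_view>

enum class GeluType { None, Tanh };

// Hypothetical stand-in for get_gelutype_enum: "none" selects the exact
// formulation, "tanh" the fast approximation; anything else is rejected.
static GeluType gelu_type_from(std::string_view approximate) {
  if (approximate == "none") return GeluType::None;
  if (approximate == "tanh") return GeluType::Tanh;
  throw std::invalid_argument("approximate must be 'none' or 'tanh'");
}

int main() {
  return gelu_type_from("tanh") == GeluType::Tanh ? 0 : 1;
}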
|
@ -134,7 +134,7 @@ Tensor& searchsorted_out_cuda(
    const Tensor& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<std::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt,
    Tensor& result) {
  // See [Note: hacky wrapper removal for optional tensor]
@ -180,7 +180,7 @@ Tensor& searchsorted_out_cuda(
    const Scalar& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<std::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt,
    Tensor& result) {
  const Tensor& scalar_tensor = searchsorted_scalar_tensor(self, sorted_sequence.device());
@ -192,7 +192,7 @@ Tensor searchsorted_cuda(
    const Tensor& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<std::string_view> side_opt,
    const std::optional<Tensor>& sorter) {
  ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
  c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type);
@ -206,7 +206,7 @@ Tensor searchsorted_cuda(
    const Scalar& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<std::string_view> side_opt,
    const std::optional<Tensor>& sorter) {
  const Tensor& scalar_tensor = searchsorted_scalar_tensor(self, sorted_sequence.device());
  return searchsorted_cuda(sorted_sequence, scalar_tensor, out_int32, right, side_opt, sorter);
|
@ -1317,7 +1317,7 @@ TORCH_IMPL_FUNC(index_reduce_cuda_out)
    int64_t dim,
    const Tensor& index,
    const Tensor& source,
    const c10::string_view reduce,
    const std::string_view reduce,
    bool include_self,
    const Tensor& result) {
  TORCH_WARN_ONCE("index_reduce() is in beta and the API may change at any time.");
|
@ -98,7 +98,7 @@ void lazy_linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& i
void lazy_svd_kernel(const Tensor& A,
    const bool full_matrices,
    const bool compute_uv,
    const std::optional<c10::string_view>& driver,
    const std::optional<std::string_view>& driver,
    const Tensor& U,
    const Tensor& S,
    const Tensor& Vh,
|
@ -156,7 +156,7 @@ template<typename ElementInputA, typename ElementInputB>
Tensor
mixed_dtypes_linear_dispatch_bias_activation(
    const Tensor& input, const Tensor& weight, const Tensor& scale,
    const Tensor& bias, const c10::string_view& activation) {
    const Tensor& bias, const std::string_view& activation) {
  if (bias.numel() == 0) {
    if (activation == "none") {
      return mixed_dtypes_linear_cutlass<
@ -196,7 +196,7 @@ Tensor
_mixed_dtypes_linear(const Tensor& input, const Tensor& weight,
    const Tensor& scale,
    const std::optional<Tensor>& bias_opt,
    const std::optional<c10::string_view> activation_opt) {
    const std::optional<std::string_view> activation_opt) {
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
  TORCH_CHECK(false, "_mixed_dtypes_linear: not compiled for this platform");
  return Tensor{};
|
@ -581,7 +581,7 @@ __global__ void batch_norm_backward_elemt_kernel(

template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> get_packed_accessor(
    const Tensor& t, c10::string_view var_name) {
    const Tensor& t, std::string_view var_name) {
  constexpr auto expect_type = c10::CppTypeToScalarType<typename std::remove_const_t<scalar_t>>::value;
  const auto actual_type = t.scalar_type();
  TORCH_CHECK(actual_type == expect_type, "Expected ", var_name,
@ -591,7 +591,7 @@ static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> get_packed

template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_accessor_or_dummy(
    const Tensor& t, c10::string_view var_name) {
    const Tensor& t, std::string_view var_name) {
  if (!t.defined()) {
    const std::array<index_t, dim> zeros{{0}};
    return GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t>(nullptr, zeros.data(), zeros.data());
|
@ -117,7 +117,7 @@ __global__ void _assert_async_cuda_kernel(const c10::complex<double>* input, Msg
  CUDA_KERNEL_ASSERT_MSG(input[0] != c10::complex<double>(0, 0), msg.msg);
}

void _assert_async_msg_cuda(const Tensor& self_tensor, c10::string_view assert_msg) {
void _assert_async_msg_cuda(const Tensor& self_tensor, std::string_view assert_msg) {
  const TensorBase &self = get_tensor_base(self_tensor);
  auto n = self.numel();
  TORCH_CHECK(n != 0, "Boolean value of Tensor with no values is ambiguous");
|
@ -2210,7 +2210,7 @@ void svd_magma(const Tensor& A,
void svd_kernel(const Tensor& A,
    const bool full_matrices,
    const bool compute_uv,
    const std::optional<c10::string_view>& driver,
    const std::optional<std::string_view>& driver,
    const Tensor& U,
    const Tensor& S,
    const Tensor& Vh,
|
@ -641,7 +641,7 @@ std::string _format_non_converging_batches(const std::vector<int64_t>& batches)
void svd_cusolver(const Tensor& A,
    const bool full_matrices,
    const bool compute_uv,
    const std::optional<c10::string_view>& driver,
    const std::optional<std::string_view>& driver,
    const Tensor& U,
    const Tensor& S,
    const Tensor& V,
@ -655,7 +655,7 @@ void svd_cusolver(const Tensor& A,

  // The default heuristic is to use gesvdj driver
#ifdef USE_ROCM
  const auto driver_v = c10::string_view("gesvdj");
  const auto driver_v = std::string_view("gesvdj");
#else
  const auto driver_v = driver.value_or("gesvdj");
#endif
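The `driver.value_or("gesvdj")` idiom deserves a note: with std::optional<std::string_view>, value_or on a string literal is safe because the literal has static storage duration, whereas materializing a temporary std::string there would leave a dangling view. A sketch under those assumptions (hypothetical `pick_driver` helper, not the ATen code):

#include <optional>
#include <string_view>

// Safe: the "gesvdj" literal outlives the returned view. Do NOT replace the
// literal with std::string("gesvdj") -- the view would dangle on return.
static std::string_view pick_driver(const std::optional<std::string_view>& driver) {
  return driver.value_or("gesvdj");
}

int main() {
  return (pick_driver(std::nullopt) == "gesvdj" &&
          pick_driver(std::string_view("gesvd")) == "gesvd") ? 0 : 1;
}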
|
@ -61,7 +61,7 @@ void lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tenso

// entrance of calculations of `svd` using cusolver gesvdj and gesvdjBatched
void svd_cusolver(const Tensor& A, const bool full_matrices, const bool compute_uv,
    const std::optional<c10::string_view>& driver, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& info);
    const std::optional<std::string_view>& driver, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& info);

// entrance of calculations of `cholesky` using cusolver potrf and potrfBatched
void cholesky_helper_cusolver(const Tensor& input, bool upper, const Tensor& info);
|
@ -220,10 +220,10 @@ static Tensor _mkldnn_convolution(
    IntArrayRef dilation,
    int64_t groups,
    bool use_channels_last,
    c10::string_view attr = "none",
    std::string_view attr = "none",
    torch::List<std::optional<at::Scalar>> scalars =
        torch::List<std::optional<at::Scalar>>(),
    std::optional<c10::string_view> algorithm = std::nullopt) {
    std::optional<std::string_view> algorithm = std::nullopt) {
  ideep::attr_t op_attr = ideep::attr_t();
  if (attr != "none") {
    auto it = fusion_unary_attr_map().find(attr);
@ -304,9 +304,9 @@ Tensor mkldnn_convolution_pointwise(
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::string_view attr,
    std::string_view attr,
    torch::List<std::optional<at::Scalar>> scalars,
    std::optional<c10::string_view> algorithm) {
    std::optional<std::string_view> algorithm) {
  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  bool use_channels_last =
      weight_t.is_mkldnn() || mkldnn_conv_use_channels_last(input_t, weight_t);
@ -342,11 +342,11 @@ Tensor mkldnn_convolution_pointwise_binary(
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::string_view binary_attr,
    std::string_view binary_attr,
    std::optional<at::Scalar> alpha,
    std::optional<c10::string_view> unary_attr,
    std::optional<std::string_view> unary_attr,
    torch::List<std::optional<at::Scalar>> unary_scalars,
    std::optional<c10::string_view> unary_algorithm) {
    std::optional<std::string_view> unary_algorithm) {
  TORCH_CHECK(
      input_t.ndimension() == 4 || input_t.ndimension() == 5,
      "mkldnn_convolution_pointwise_binary: currently only support 2d and 3d")
@ -381,7 +381,7 @@ Tensor mkldnn_convolution_pointwise_binary(
      weight_t.is_mkldnn() || mkldnn_conv_use_channels_last(input_t, weight_t);
  bool can_be_fused = groups == 1 && use_channels_last;

  c10::string_view unary_attr_value = "none";
  std::string_view unary_attr_value = "none";
  ideep::algorithm unary_alg;
  if (unary_attr.has_value()) {
    auto it_unary = fusion_unary_alg_map().find(unary_attr.value());
@ -504,11 +504,11 @@ Tensor& mkldnn_convolution_pointwise_binary_(
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::string_view binary_attr,
    std::string_view binary_attr,
    std::optional<at::Scalar> alpha,
    std::optional<c10::string_view> unary_attr,
    std::optional<std::string_view> unary_attr,
    torch::List<std::optional<at::Scalar>> unary_scalars,
    std::optional<c10::string_view> unary_algorithm) {
    std::optional<std::string_view> unary_algorithm) {
  // other_t += convolution(...), other_t = unary(other_t)
  TORCH_CHECK(
      input_t.ndimension() == 4 || input_t.ndimension() == 5,
@ -625,10 +625,10 @@ Tensor _mkldnn_convolution_transpose(
    IntArrayRef dilation,
    int64_t groups,
    bool use_channels_last,
    c10::string_view attr = "none",
    std::string_view attr = "none",
    torch::List<std::optional<at::Scalar>> scalars =
        torch::List<std::optional<at::Scalar>>(),
    std::optional<c10::string_view> algorithm = std::nullopt) {
    std::optional<std::string_view> algorithm = std::nullopt) {
  ideep::attr_t op_attr = ideep::attr_t();
  if (attr != "none") {
    auto it = fusion_unary_attr_map().find(attr);
@ -720,9 +720,9 @@ Tensor mkldnn_convolution_transpose_pointwise_meta(
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::string_view attr,
    std::string_view attr,
    torch::List<std::optional<at::Scalar>> scalars,
    std::optional<c10::string_view> algorithm) {
    std::optional<std::string_view> algorithm) {

  std::vector<int64_t> weight_IOHW_sizes = _original_deconv_weight_size(weight_t, groups);
  int64_t dim = input_t.ndimension() - 2;
@ -867,9 +867,9 @@ Tensor mkldnn_convolution_transpose_pointwise(
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::string_view attr,
    std::string_view attr,
    torch::List<std::optional<at::Scalar>> scalars,
    std::optional<c10::string_view> algorithm) {
    std::optional<std::string_view> algorithm) {
  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  bool use_channels_last =
      weight_t.is_mkldnn() || mkldnn_conv_use_channels_last(input_t, weight_t);
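The defaulted trailing parameters above (`attr = "none"`, `algorithm = std::nullopt`) work because a std::string_view default can be initialized from a literal at compile time. A reduced sketch with hypothetical names, not the mkldnn code itself:

#include <optional>
#include <string_view>

// Hypothetical reduction of _mkldnn_convolution's trailing parameters:
// callers may omit both and get the "no post-op" behavior.
static int fused_op_count(std::string_view attr = "none",
                          std::optional<std::string_view> algorithm = std::nullopt) {
  int count = 0;
  if (attr != "none") ++count;
  if (algorithm.has_value()) ++count;
  return count;
}

int main() {
  return (fused_op_count() == 0 && fused_op_count("relu") == 1) ? 0 : 1;
}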
|
@ -15,9 +15,9 @@ C10_API Tensor mkldnn_convolution_pointwise(
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::string_view attr,
    std::string_view attr,
    torch::List<std::optional<at::Scalar>> scalars,
    std::optional<c10::string_view> algorithm);
    std::optional<std::string_view> algorithm);

C10_API Tensor mkldnn_convolution_pointwise_binary(
    const Tensor& input_t,
@ -28,11 +28,11 @@ C10_API Tensor mkldnn_convolution_pointwise_binary(
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::string_view binary_attr,
    std::string_view binary_attr,
    std::optional<at::Scalar> alpha,
    std::optional<c10::string_view> unary_attr,
    std::optional<std::string_view> unary_attr,
    torch::List<std::optional<at::Scalar>> unary_scalars,
    std::optional<c10::string_view> unary_algorithm);
    std::optional<std::string_view> unary_algorithm);

C10_API Tensor& mkldnn_convolution_pointwise_binary_(
    Tensor& other_t,
@ -43,11 +43,11 @@ C10_API Tensor& mkldnn_convolution_pointwise_binary_(
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::string_view binary_attr,
    std::string_view binary_attr,
    std::optional<at::Scalar> alpha,
    std::optional<c10::string_view> unary_attr,
    std::optional<std::string_view> unary_attr,
    torch::List<std::optional<at::Scalar>> unary_scalars,
    std::optional<c10::string_view> unary_algorithm);
    std::optional<std::string_view> unary_algorithm);

Tensor mkldnn_convolution_transpose_pointwise(
    const Tensor& input_t,
@ -58,9 +58,9 @@ Tensor mkldnn_convolution_transpose_pointwise(
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::string_view attr,
    std::string_view attr,
    torch::List<std::optional<at::Scalar>> scalars,
    std::optional<c10::string_view> algorithm);
    std::optional<std::string_view> algorithm);

} // namespace native
} // namespace at
|
@ -14,11 +14,11 @@

namespace at { namespace native {

Tensor mkldnn_gelu(const Tensor& input, c10::string_view approximate) {
Tensor mkldnn_gelu(const Tensor& input, std::string_view approximate) {
  TORCH_CHECK(false, "mkldnn_gelu: ATen not compiled with MKLDNN support");
}

Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input, c10::string_view approximate) {
Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input, std::string_view approximate) {
  TORCH_CHECK(false, "mkldnn_gelu_backward: ATen not compiled with MKLDNN support");
}

@ -31,7 +31,7 @@ Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input, c10:

namespace at::native {

Tensor mkldnn_gelu(const Tensor& input, c10::string_view approximate) {
Tensor mkldnn_gelu(const Tensor& input, std::string_view approximate) {
  if (input.scalar_type() == ScalarType::BFloat16) {
    TORCH_CHECK(mkldnn_bf16_device_check(),
        "mkldnn_gelu: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
@ -46,7 +46,7 @@ Tensor mkldnn_gelu(const Tensor& input, c10::string_view approximate) {
      input.options().device_opt());
}

Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input, c10::string_view approximate) {
Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input, std::string_view approximate) {
  TORCH_CHECK(get_gelutype_enum(approximate) == GeluType::None,
      "mkldnn_gelu_backward: fast, approximate gelu is not supported");
  const ideep::tensor& x = itensor_from_tensor(input);
|
@ -184,9 +184,9 @@ Tensor mkldnn_linear_pointwise(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    c10::string_view attr,
    std::string_view attr,
    c10::List<std::optional<at::Scalar>> scalars,
    std::optional<c10::string_view> algorithm) {
    std::optional<std::string_view> algorithm) {
  auto input = input_t.contiguous();
  auto input_size = input.sizes();

@ -259,7 +259,7 @@ Tensor mkldnn_linear_pointwise_binary(
    const Tensor& other_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    c10::string_view attr) {
    std::string_view attr) {
  c10::MaybeOwned<Tensor> bias_maybe_owned =
      at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;
|
@ -10,16 +10,16 @@ C10_API Tensor mkldnn_linear_pointwise(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    c10::string_view attr,
    std::string_view attr,
    c10::List<std::optional<at::Scalar>> scalars,
    std::optional<c10::string_view> algorithm);
    std::optional<std::string_view> algorithm);

C10_API Tensor mkldnn_linear_pointwise_binary(
    const Tensor& input_t,
    const Tensor& other_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    c10::string_view attr);
    std::string_view attr);

#if AT_MKL_ENABLED()
|
@ -80,13 +80,13 @@ void check_mkldnn_binary_fusion_inputs(

#define ATTR_FUNC(NAME) \
  [](torch::List<std::optional<at::Scalar>> scalars, \
      std::optional<c10::string_view> algorithm) { \
      std::optional<std::string_view> algorithm) { \
    return ideep::attr_t::fuse_##NAME(); \
  }

AttrFunction attr_func_leaky_relu =
    [](torch::List<std::optional<at::Scalar>> scalars,
        std::optional<c10::string_view> algorithm) {
        std::optional<std::string_view> algorithm) {
      TORCH_CHECK(
          scalars.size() == 1 &&
              scalars[0].get().toOptional<at::Scalar>().has_value(),
@ -98,7 +98,7 @@ AttrFunction attr_func_leaky_relu =

AttrFunction attr_func_hardtanh =
    [](torch::List<std::optional<at::Scalar>> scalars,
        std::optional<c10::string_view> algorithm) {
        std::optional<std::string_view> algorithm) {
      TORCH_CHECK(
          scalars.size() == 2 &&
              scalars[0].get().toOptional<at::Scalar>().has_value() &&
@ -113,7 +113,7 @@ AttrFunction attr_func_hardtanh =
    };

AttrFunction attr_func_gelu = [](torch::List<std::optional<at::Scalar>> scalars,
    std::optional<c10::string_view> algorithm) {
    std::optional<std::string_view> algorithm) {
  TORCH_CHECK(
      algorithm.has_value(),
      "gelu is expected to have one str input: algorithm");
@ -132,7 +132,7 @@ AttrFunction attr_func_gelu = [](torch::List<std::optional<at::Scalar>> scalars,

AttrFunction attr_func_hardsigmoid =
    [](torch::List<std::optional<at::Scalar>> scalars,
        std::optional<c10::string_view> algorithm) {
        std::optional<std::string_view> algorithm) {
      ideep::attr_t attr;
      ideep::post_ops po;
      po.append_eltwise(
@ -141,8 +141,8 @@ AttrFunction attr_func_hardsigmoid =
      return attr;
    };

const std::map<c10::string_view, AttrFunction>& fusion_unary_attr_map() {
  static const std::map<c10::string_view, AttrFunction> fusion_attr_map{
const std::map<std::string_view, AttrFunction>& fusion_unary_attr_map() {
  static const std::map<std::string_view, AttrFunction> fusion_attr_map{
      {"relu", ATTR_FUNC(relu)},
      {"sigmoid", ATTR_FUNC(sigmoid)},
      {"tanh", ATTR_FUNC(tanh)},
@ -156,15 +156,15 @@ const std::map<c10::string_view, AttrFunction>& fusion_unary_attr_map() {
  return fusion_attr_map;
}

const std::map<c10::string_view, ideep::algorithm>& fusion_unary_alg_map() {
  static const std::map<c10::string_view, ideep::algorithm> fusion_attr_map{
const std::map<std::string_view, ideep::algorithm>& fusion_unary_alg_map() {
  static const std::map<std::string_view, ideep::algorithm> fusion_attr_map{
      {"relu", {ideep::algorithm::eltwise_relu}},
  };
  return fusion_attr_map;
}

const std::map<c10::string_view, ideep::algorithm>& fusion_binary_alg_map() {
  static const std::map<c10::string_view, ideep::algorithm> fusion_attr_map{
const std::map<std::string_view, ideep::algorithm>& fusion_binary_alg_map() {
  static const std::map<std::string_view, ideep::algorithm> fusion_attr_map{
      {"add", {ideep::algorithm::binary_add}},
      {"sub", {ideep::algorithm::binary_sub}},
      {"mul", {ideep::algorithm::binary_mul}},
|
@ -74,13 +74,13 @@ inline Tensor may_convert_to_default_contiguous_strides(const Tensor& input) {

using AttrFunction = std::function<ideep::attr_t(
    torch::List<std::optional<at::Scalar>>,
    std::optional<c10::string_view>)>;
    std::optional<std::string_view>)>;

const std::map<c10::string_view, AttrFunction>& fusion_unary_attr_map();
const std::map<std::string_view, AttrFunction>& fusion_unary_attr_map();

const std::map<c10::string_view, ideep::algorithm>& fusion_unary_alg_map();
const std::map<std::string_view, ideep::algorithm>& fusion_unary_alg_map();

const std::map<c10::string_view, ideep::algorithm>& fusion_binary_alg_map();
const std::map<std::string_view, ideep::algorithm>& fusion_binary_alg_map();

#endif // AT_MKLDNN_ENABLED()
}
|
@ -368,9 +368,9 @@ class Attr {
};

static inline void construct_attr_for_unary(
    const c10::string_view& unary_post_op,
    const std::string_view& unary_post_op,
    const torch::List<std::optional<at::Scalar>>& unary_post_op_args,
    const c10::string_view& unary_post_op_algorithm,
    const std::string_view& unary_post_op_algorithm,
    at::native::onednn::Attr& attr) {
  if (unary_post_op == "relu") {
    attr = attr.append_post_eltwise(
@ -406,13 +406,13 @@ static inline void construct_attr_for_unary(
}

static inline void construct_attr_by_post_op(
    const c10::string_view& binary_post_op,
    const std::string_view& binary_post_op,
    double binary_alpha,
    double input1_scale,
    int64_t input1_zero_point,
    const c10::string_view& unary_post_op,
    const std::string_view& unary_post_op,
    const torch::List<std::optional<at::Scalar>>& unary_post_op_args,
    const c10::string_view& unary_post_op_algorithm,
    const std::string_view& unary_post_op_algorithm,
    at::native::onednn::Attr& attr) {
  bool is_none_post_op =
      (binary_post_op == "none" && unary_post_op == "none"); // not post-ops
|
@ -67,11 +67,11 @@ at::Tensor quantized_convolution(
    double accum_scale,
    int64_t accum_zero_point,
    c10::optional<c10::ScalarType> output_dtype,
    c10::optional<c10::string_view> binary_attr,
    c10::optional<std::string_view> binary_attr,
    c10::optional<at::Scalar> binary_alpha,
    c10::optional<c10::string_view> unary_attr,
    c10::optional<std::string_view> unary_attr,
    torch::List<c10::optional<at::Scalar>> unary_scalars,
    c10::optional<c10::string_view> unary_algorithm) {
    c10::optional<std::string_view> unary_algorithm) {
  Attr attr =
      Attr(/*q_scale=*/1.0 / inv_output_scale, /*zp=*/output_zero_point);
|
@ -127,10 +127,10 @@ at::Tensor quantized_convolution(
    double accum_scale,
    int64_t accum_zero_point,
    c10::optional<c10::ScalarType> output_dtype,
    c10::optional<c10::string_view> binary_attr,
    c10::optional<std::string_view> binary_attr,
    c10::optional<at::Scalar> binary_alpha,
    c10::optional<c10::string_view> unary_attr,
    c10::optional<std::string_view> unary_attr,
    torch::List<c10::optional<at::Scalar>> unary_scalars,
    c10::optional<c10::string_view> unary_algorithm);
    c10::optional<std::string_view> unary_algorithm);

} // namespace at::native::onednn
|
@ -39,9 +39,9 @@ class QConvoneDNNXPU final {
    double inv_output_scale,
    int64_t output_zero_point,
    c10::optional<c10::ScalarType> output_dtype,
    c10::string_view attr,
    std::string_view attr,
    torch::List<c10::optional<at::Scalar>> scalars,
    c10::optional<c10::string_view> algorithm) {
    c10::optional<std::string_view> algorithm) {
  if (act.dim() == 3 || act.dim() == 5) {
    TORCH_CHECK(
        attr == "none",
|
@ -662,7 +662,7 @@ static MPSGraphTensor* tanh(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
  return erfTensor;
}

TORCH_IMPL_FUNC(gelu_out_mps)(const Tensor& self, c10::string_view approximate, const Tensor& output) {
TORCH_IMPL_FUNC(gelu_out_mps)(const Tensor& self, std::string_view approximate, const Tensor& output) {
  using namespace mps;
  using CachedGraph = MPSUnaryCachedGraph;
  TORCH_CHECK(output.is_mps());
@ -711,7 +711,7 @@ TORCH_IMPL_FUNC(gelu_out_mps)(const Tensor& self, c10::string_view approximate,
}

TORCH_IMPL_FUNC(gelu_backward_out_mps)
(const Tensor& grad, const Tensor& self, c10::string_view approximate, const Tensor& grad_input) {
(const Tensor& grad, const Tensor& self, std::string_view approximate, const Tensor& grad_input) {
  using namespace mps;
  using CachedGraph = MPSUnaryGradCachedGraph;
|
@ -199,7 +199,7 @@ static void binaryOpScalar(const Tensor& self,

static void div_mode_template(const Tensor& self,
    const Tensor& other,
    std::optional<c10::string_view> rounding_mode,
    std::optional<std::string_view> rounding_mode,
    const Tensor& output,
    const string op_name) {
  if (rounding_mode.has_value() && *rounding_mode == "trunc") {
@ -405,7 +405,7 @@ TORCH_IMPL_FUNC(atan2_out_mps)(const Tensor& self, const Tensor& other, const Te
}

TORCH_IMPL_FUNC(div_out_mode_mps)
(const Tensor& self, const Tensor& other, std::optional<c10::string_view> rounding_mode, const Tensor& output) {
(const Tensor& self, const Tensor& other, std::optional<std::string_view> rounding_mode, const Tensor& output) {
  mps::div_mode_template(self, other, rounding_mode, output, "div_mode_out");
}
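div's rounding_mode illustrates the optional-string pattern: std::nullopt means true division, otherwise "trunc" or "floor". A hedged sketch of the dispatch shape, not the MPS kernel itself:

#include <optional>
#include <string_view>

// 0 = true division, 1 = trunc, 2 = floor, -1 = invalid mode.
static int mode_code(std::optional<std::string_view> rounding_mode) {
  if (!rounding_mode.has_value()) return 0;
  if (*rounding_mode == "trunc") return 1;
  if (*rounding_mode == "floor") return 2;
  return -1;
}

int main() {
  return (mode_code(std::nullopt) == 0 && mode_code("trunc") == 1 &&
          mode_code("floor") == 2) ? 0 : 1;
}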
|
@ -75,7 +75,7 @@ Tensor& searchsorted_out_mps(const Tensor& sorted_sequence,
    const Tensor& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<std::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt,
    Tensor& result) {
  // See [Note: hacky wrapper removal for optional tensor]
@ -120,7 +120,7 @@ Tensor& searchsorted_out_mps(const Tensor& sorted_sequence,
    const Scalar& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<std::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt,
    Tensor& result) {
  const Tensor& scalar_tensor = mps::wrapped_scalar_tensor_mps(self, sorted_sequence.device());
@ -131,7 +131,7 @@ Tensor searchsorted_mps(const Tensor& sorted_sequence,
    const Tensor& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<std::string_view> side_opt,
    const std::optional<Tensor>& sorter) {
  ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
  c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type);
@ -144,7 +144,7 @@ Tensor searchsorted_mps(const Tensor& sorted_sequence,
    const Scalar& self,
    bool out_int32,
    bool right,
    const std::optional<c10::string_view> side_opt,
    const std::optional<std::string_view> side_opt,
    const std::optional<Tensor>& sorter) {
  const Tensor& scalar_tensor = mps::wrapped_scalar_tensor_mps(self, sorted_sequence.device());
  return searchsorted_mps(sorted_sequence, scalar_tensor, out_int32, right, side_opt, sorter);
|
@ -114,7 +114,7 @@ static void scatter_mps_general(const Tensor& self_arg,
    const Tensor& src,
    const Tensor& output,
    string func_name,
    const c10::string_view reduce) {
    const std::string_view reduce) {
  using namespace mps;

  if (self_arg.numel() == 0 || index.numel() == 0 || src.numel() == 0) {
@ -317,7 +317,7 @@ TORCH_IMPL_FUNC(scatter_reduce_out_mps)
    int64_t dim,
    const Tensor& index,
    const Tensor& src,
    const c10::string_view reduce,
    const std::string_view reduce,
    const Tensor& output) {
  scatter_mps_general(self, dim, index, src, output, "scatter_reduce_out_mps", reduce);
}
@ -327,7 +327,7 @@ TORCH_IMPL_FUNC(scatter_value_reduce_out_mps)
    int64_t dim,
    const Tensor& index,
    const Scalar& value,
    const c10::string_view reduce,
    const std::string_view reduce,
    const Tensor& output) {
  Tensor src =
      at::empty(index.sizes(), self.scalar_type(), std::nullopt, kMPS, std::nullopt, self.suggest_memory_format());
|
@ -47,7 +47,7 @@ static void upsample_out_template(const Tensor& input,
    std::optional<double> scale_w_opt,
    const Tensor& output,
    bool align_corners,
    const c10::string_view resize_mode_str) {
    const std::string_view resize_mode_str) {
  if (input.numel() == 0) {
    return;
  }
|
@ -174,7 +174,7 @@ Tensor _nested_select_backward_symint(
  return nt_grad;
}

Tensor gelu_backwards_nested(const Tensor& grad, const Tensor& self, c10::string_view approximate){
Tensor gelu_backwards_nested(const Tensor& grad, const Tensor& self, std::string_view approximate){
  auto partial_gelu_backward = [approximate](auto && PH1, auto && PH2) { return at::gelu_backward(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), approximate); };
  return map_nt_binary(grad, self, partial_gelu_backward);
}
|
@ -15,7 +15,7 @@ namespace {
inline void check_nested_tensor_matrix_constraints(
    const Tensor& nested_tensor,
    const Tensor& dense_matrix,
    c10::string_view caller) {
    std::string_view caller) {
  auto* nt_input = get_nested_tensor_impl(nested_tensor);
  TORCH_INTERNAL_ASSERT(nt_input != nullptr);
  TORCH_CHECK(
@ -59,7 +59,7 @@ Tensor nested_linear(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias_opt) {
  check_nested_tensor_matrix_constraints(input, weight, c10::string_view{"Linear"});
  check_nested_tensor_matrix_constraints(input, weight, std::string_view{"Linear"});
  auto* nt_input = get_nested_tensor_impl(input);
  const Tensor& input_buffer = nt_input->get_buffer();
  Tensor result_buffer =
@ -73,7 +73,7 @@ Tensor nested_linear(
}

Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other) {
  check_nested_tensor_matrix_constraints(self, other, c10::string_view{"Matmul"});
  check_nested_tensor_matrix_constraints(self, other, std::string_view{"Matmul"});
  auto* nt_self = get_nested_tensor_impl_or_null(self);
  const Tensor& self_buffer = nt_self->get_buffer();
  Tensor result_buffer =
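The `std::string_view{"Linear"}` spelling above is an explicit, constexpr, non-allocating construction; the implicit conversion from a literal works equally well. A tiny sketch with a hypothetical helper:

#include <cstddef>
#include <string_view>

// Hypothetical analogue of the caller tag passed to the constraint check.
static std::size_t tag_length(std::string_view caller) {
  return caller.size();
}

int main() {
  // Both spellings bind the same way; neither allocates.
  return (tag_length(std::string_view{"Linear"}) == 6 &&
          tag_length("Matmul") == 6) ? 0 : 1;
}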
|
@ -140,7 +140,7 @@ Tensor& NestedTensor_relu_(Tensor& self) {
  return self;
}

Tensor& NestedTensor_gelu_(Tensor& self, c10::string_view approximate) {
Tensor& NestedTensor_gelu_(Tensor& self, std::string_view approximate) {
  auto self_ptr = get_nested_tensor_impl(self);
  check_numel_equals_buffer_size(self_ptr);
  auto buffer = self_ptr->get_buffer();
@ -148,7 +148,7 @@ Tensor& NestedTensor_gelu_(Tensor& self, c10::string_view approximate) {
  return self;
}

Tensor NestedTensor_gelu(const Tensor& self, c10::string_view approximate) {
Tensor NestedTensor_gelu(const Tensor& self, std::string_view approximate) {
  return map_nt(
      self,
      [approximate](const Tensor& buffer) {
|
@ -310,14 +310,14 @@ struct PackedConvWeightsOnednn : public ConvPackedParamsBase<kSpatialDim> {
namespace onednn_utils {

inline ideep::attr_t create_attr_by_post_op(
    const c10::string_view& binary_post_op,
    const std::string_view& binary_post_op,
    double binary_alpha,
    double input1_scale,
    int64_t input1_zero_point,
    const ideep::tensor::desc& input1_desc,
    const c10::string_view& unary_post_op,
    const std::string_view& unary_post_op,
    const torch::List<std::optional<at::Scalar>>& unary_post_op_args,
    const c10::string_view& unary_post_op_algorithm) {
    const std::string_view& unary_post_op_algorithm) {
  using ideep::tensor;
  if (binary_post_op == "none") {
    if (unary_post_op == "relu") {
|
@ -1404,11 +1404,11 @@ static at::Tensor _quantized_convolution_onednn(
|
||||
double accum_scale,
|
||||
int64_t accum_zero_point,
|
||||
std::optional<c10::ScalarType> output_dtype,
|
||||
std::optional<c10::string_view> binary_attr,
|
||||
std::optional<std::string_view> binary_attr,
|
||||
std::optional<at::Scalar> binary_alpha,
|
||||
std::optional<c10::string_view> unary_attr,
|
||||
std::optional<std::string_view> unary_attr,
|
||||
torch::List<std::optional<at::Scalar>> unary_scalars,
|
||||
std::optional<c10::string_view> unary_algorithm) {
|
||||
std::optional<std::string_view> unary_algorithm) {
|
||||
/*********************************/
|
||||
/* Checks */
|
||||
/*********************************/
|
||||
@ -1754,9 +1754,9 @@ namespace at::native {
|
||||
double output_scale,
|
||||
int64_t output_zero_point,
|
||||
std::optional<c10::ScalarType> output_dtype,
|
||||
c10::string_view attr,
|
||||
std::string_view attr,
|
||||
torch::List<std::optional<at::Scalar>> scalars,
|
||||
std::optional<c10::string_view> algorithm) {
|
||||
std::optional<std::string_view> algorithm) {
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
|
||||
if (act.dim() == 3 || act.dim() == 5) {
|
||||
@ -1805,9 +1805,9 @@ namespace at::native {
|
||||
double output_scale,
|
||||
int64_t output_zero_point,
|
||||
std::optional<c10::ScalarType> output_dtype,
|
||||
c10::string_view attr,
|
||||
std::string_view attr,
|
||||
torch::List<std::optional<at::Scalar>> scalars,
|
||||
std::optional<c10::string_view> algorithm) {
|
||||
std::optional<std::string_view> algorithm) {
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
|
||||
"onednn int8 linear: act scale/zp size should be 1");
|
||||
@ -1844,11 +1844,11 @@ namespace at::native {
    std::optional<c10::ScalarType> output_dtype,
    double accum_scale,
    int64_t accum_zero_point,
-    c10::string_view binary_attr,
+    std::string_view binary_attr,
    std::optional<at::Scalar> alpha,
-    std::optional<c10::string_view> unary_attr,
+    std::optional<std::string_view> unary_attr,
    torch::List<std::optional<at::Scalar>> unary_scalars,
-    std::optional<c10::string_view> unary_algorithm) {
+    std::optional<std::string_view> unary_algorithm) {
#if AT_MKLDNN_ENABLED()
  // Conv2D post op check
  TORCH_CHECK(
@ -1897,11 +1897,11 @@ namespace at::native {
    std::optional<c10::ScalarType> output_dtype,
    double accum_scale,
    int64_t accum_zero_point,
-    c10::string_view binary_attr,
+    std::string_view binary_attr,
    std::optional<at::Scalar> alpha,
-    std::optional<c10::string_view> unary_attr,
+    std::optional<std::string_view> unary_attr,
    torch::List<std::optional<at::Scalar>> unary_scalars,
-    std::optional<c10::string_view> unary_algorithm) {
+    std::optional<std::string_view> unary_algorithm) {

  TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
      "onednn int8 linear: act scale/zp size should be 1");
@ -23,9 +23,9 @@ class QConvoneDNN final {
      double output_scale,
      int64_t output_zero_point,
      std::optional<c10::ScalarType> output_dtype,
-      c10::string_view attr,
+      std::string_view attr,
      torch::List<std::optional<at::Scalar>> scalars,
-      std::optional<c10::string_view> algorithm);
+      std::optional<std::string_view> algorithm);

  C10_API static at::Tensor run_pointwise_tensor(
      at::Tensor act, // contains quantized values but not QTensor
@ -42,9 +42,9 @@ class QConvoneDNN final {
      double output_scale,
      int64_t output_zero_point,
      std::optional<c10::ScalarType> output_dtype,
-      c10::string_view attr,
+      std::string_view attr,
      torch::List<std::optional<at::Scalar>> scalars,
-      std::optional<c10::string_view> algorithm);
+      std::optional<std::string_view> algorithm);

  C10_API static at::Tensor run_pointwise_binary(
      at::Tensor act, // contains quantized values but not QTensor
@ -64,11 +64,11 @@ class QConvoneDNN final {
      std::optional<c10::ScalarType> output_dtype,
      double accum_scale,
      int64_t accum_zero_point,
-      c10::string_view binary_attr,
+      std::string_view binary_attr,
      std::optional<at::Scalar> alpha,
-      std::optional<c10::string_view> unary_attr,
+      std::optional<std::string_view> unary_attr,
      torch::List<std::optional<at::Scalar>> unary_scalars,
-      std::optional<c10::string_view> unary_algorithm);
+      std::optional<std::string_view> unary_algorithm);

  C10_API static at::Tensor run_pointwise_binary_tensor(
      at::Tensor act, // contains quantized values but not QTensor
@ -88,11 +88,11 @@ class QConvoneDNN final {
      std::optional<c10::ScalarType> output_dtype,
      double accum_scale,
      int64_t accum_zero_point,
-      c10::string_view binary_attr,
+      std::string_view binary_attr,
      std::optional<at::Scalar> alpha,
-      std::optional<c10::string_view> unary_attr,
+      std::optional<std::string_view> unary_attr,
      torch::List<std::optional<at::Scalar>> unary_scalars,
-      std::optional<c10::string_view> unary_algorithm);
+      std::optional<std::string_view> unary_algorithm);

};
@ -12,13 +12,13 @@ namespace at::native {

DEFINE_DISPATCH(qgelu_stub);

-Tensor gelu_quantized_cpu(const Tensor& qx, c10::string_view approximate) {
+Tensor gelu_quantized_cpu(const Tensor& qx, std::string_view approximate) {
  Tensor qy;
  qgelu_stub(qx.device().type(), qx, qy, get_gelutype_enum(approximate));
  return qy;
}

-Tensor& gelu_quantized_cpu_(Tensor& self, c10::string_view approximate) {
+Tensor& gelu_quantized_cpu_(Tensor& self, std::string_view approximate) {
  Tensor qy = gelu_quantized_cpu(self, approximate);
  // This can be optimized in a future PR if it becomes a bottleneck.
  self.copy_(qy);
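These functions are the dispatch targets reached when at::gelu is called on a quantized CPU tensor. A hedged usage sketch; the scale and zero point are arbitrary values chosen for illustration:

#include <ATen/ATen.h>

void quantized_gelu_sketch() {
  at::Tensor x = at::randn({8});
  // Quantize to qint8 with illustrative quantization parameters.
  at::Tensor qx = at::quantize_per_tensor(x, /*scale=*/0.1, /*zero_point=*/0, at::kQInt8);
  // Routed to gelu_quantized_cpu for a quantized CPU input.
  at::Tensor qy = at::gelu(qx);
}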
@ -924,11 +924,11 @@ static at::Tensor linear_int8_with_onednn_weight(
    std::optional<at::Tensor> other, // extra input for binary post-op
    double other_scale,
    int64_t other_zero_point,
-    const c10::string_view& binary_post_op, // e.g. "none", "sum", "add"
+    const std::string_view& binary_post_op, // e.g. "none", "sum", "add"
    double binary_alpha,
-    const c10::string_view& unary_post_op, // e.g. "none", "relu"
+    const std::string_view& unary_post_op, // e.g. "none", "relu"
    torch::List<std::optional<at::Scalar>>& unary_post_op_args,
-    c10::string_view& unary_post_op_algorithm) {
+    std::string_view& unary_post_op_algorithm) {
  using ideep::tensor;
  const int64_t dim = input.dim();
  TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte,
@ -1114,14 +1114,14 @@ namespace at::native {
    double output_scale,
    int64_t output_zero_point,
    std::optional<c10::ScalarType> output_dtype,
-    c10::string_view post_op_name,
+    std::string_view post_op_name,
    torch::List<std::optional<at::Scalar>> post_op_args,
-    c10::string_view post_op_algorithm) {
+    std::string_view post_op_algorithm) {
#if AT_MKLDNN_ENABLED()
  TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
      "onednn int8 linear: act scale/zp size should be 1");
  static std::optional<at::Tensor> other = std::nullopt;
-  static const c10::string_view binary_post_op = "none";
+  static const std::string_view binary_post_op = "none";
  return linear_int8_with_onednn_weight(
      act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
      onednn_weight, weight_scales, weight_zero_points,
@ -1148,11 +1148,11 @@ namespace at::native {
    std::optional<c10::ScalarType> output_dtype,
    double other_scale,
    int64_t other_zero_point,
-    c10::string_view binary_post_op, // e.g. "none", "sum", "add"
+    std::string_view binary_post_op, // e.g. "none", "sum", "add"
    double binary_alpha,
-    c10::string_view unary_post_op, // e.g. "none", "relu"
+    std::string_view unary_post_op, // e.g. "none", "relu"
    torch::List<std::optional<at::Scalar>> unary_post_op_args,
-    c10::string_view unary_post_op_algorithm) {
+    std::string_view unary_post_op_algorithm) {
#if AT_MKLDNN_ENABLED()
  TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
      "onednn int8 linear: act scale/zp size should be 1");
@ -1268,12 +1268,12 @@ class QLinearOnednn final {
    double output_scale,
    int64_t output_zero_point,
    std::optional<c10::ScalarType> output_dtype,
-    c10::string_view post_op_name,
+    std::string_view post_op_name,
    torch::List<std::optional<at::Scalar>> post_op_args,
-    c10::string_view post_op_algorithm) {
+    std::string_view post_op_algorithm) {
#if AT_MKLDNN_ENABLED()
  static std::optional<at::Tensor> other = std::nullopt;
-  static const c10::string_view binary_post_op = "none";
+  static const std::string_view binary_post_op = "none";
  return linear_int8_with_onednn_weight(
      act, act_scale, act_zero_point,
      onednn_weight, weight_scales, weight_zero_points,
@ -1300,11 +1300,11 @@ class QLinearOnednn final {
    std::optional<c10::ScalarType> output_dtype,
    double other_scale,
    int64_t other_zero_point,
-    c10::string_view binary_post_op, // e.g. "none", "sum", "add"
+    std::string_view binary_post_op, // e.g. "none", "sum", "add"
    double binary_alpha,
-    c10::string_view unary_post_op, // e.g. "none", "relu"
+    std::string_view unary_post_op, // e.g. "none", "relu"
    torch::List<std::optional<at::Scalar>> unary_post_op_args,
-    c10::string_view unary_post_op_algorithm) {
+    std::string_view unary_post_op_algorithm) {
#if AT_MKLDNN_ENABLED()
  return linear_int8_with_onednn_weight(
      act, act_scale, act_zero_point,
@ -17,9 +17,9 @@ class QLinearOnednn final {
      double output_scale,
      int64_t output_zero_point,
      std::optional<c10::ScalarType> output_dtype,
-      c10::string_view post_op_name,
+      std::string_view post_op_name,
      torch::List<std::optional<at::Scalar>> post_op_args,
-      c10::string_view post_op_algorithm);
+      std::string_view post_op_algorithm);

  C10_API static Tensor run_pointwise_binary_tensor(
      Tensor act, // int8 CPU tensor, not QTensor
@ -35,11 +35,11 @@ C10_API static Tensor run_pointwise_binary_tensor(
      std::optional<c10::ScalarType> output_dtype,
      double other_scale,
      int64_t other_zero_point,
-      c10::string_view binary_post_op, // e.g. "none", "sum", "add"
+      std::string_view binary_post_op, // e.g. "none", "sum", "add"
      double binary_alpha,
-      c10::string_view unary_post_op, // e.g. "none", "relu"
+      std::string_view unary_post_op, // e.g. "none", "relu"
      torch::List<std::optional<at::Scalar>> unary_post_op_args,
-      c10::string_view unary_post_op_algorithm);
+      std::string_view unary_post_op_algorithm);
};

} // namespace at::native
@ -7,7 +7,7 @@ namespace at::native {
// this kernel is currently implemented with dequantize -> fp32 gelu -> quantize, which is not equivalent to int8 gelu
// It might be possible to write a variant of the int8 gelu that's equivalent to dequantize -> fp32 cuda gelu kernel -> quantize,
// which can be a topic for future work.
-Tensor gelu_quantized_cuda(const Tensor& qx, c10::string_view approximate) {
+Tensor gelu_quantized_cuda(const Tensor& qx, std::string_view approximate) {
  (void)approximate; // suppress unused variable lint warning
  if (qx.numel() == 0) {
    return Tensor{};
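The comment above spells out the current strategy. Roughly, it is the following composition, assuming per-tensor quantization; this is a sketch of the idea, not the kernel's actual code:

at::Tensor gelu_via_fp32_sketch(const at::Tensor& qx, std::string_view approximate) {
  at::Tensor fp = qx.dequantize();           // int8 -> fp32
  at::Tensor y = at::gelu(fp, approximate);  // fp32 gelu kernel
  return at::quantize_per_tensor(            // fp32 -> int8, reusing qx's qparams
      y, qx.q_scale(), qx.q_zero_point(), qx.scalar_type());
}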
@ -1359,7 +1359,7 @@ Tensor _sparse_csr_prod_cpu(const Tensor& input, IntArrayRef dims_to_reduce, boo
std::tuple<Tensor, Tensor> _sparse_mm_reduce_impl_sparse_csr_cpu(
    const Tensor& self,
    const Tensor& other,
-    const c10::string_view reduce) {
+    const std::string_view reduce) {

  auto layout = self.layout();
  TORCH_CHECK(layout == kSparseCsr,
@ -1411,7 +1411,7 @@ std::tuple<Tensor, Tensor> _sparse_mm_reduce_impl_backward_sparse_csr_cpu(
    const Tensor& self,
    const Tensor& grad_out,
    const Tensor& other,
-    const c10::string_view reduce,
+    const std::string_view reduce,
    const Tensor& arg_out,
    std::array<bool, 2> output_mask) {
@ -21,7 +21,7 @@ inline bool _is_sparse_and_zero(const Tensor& self) {
  return false;
}

-inline void _check_is_cpu(const Tensor& self, c10::string_view name) {
+inline void _check_is_cpu(const Tensor& self, std::string_view name) {
  TORCH_CHECK(
      self.is_cpu(),
      "Expected all tensors to be on the same device. addmm expected '",
@ -31,7 +31,7 @@ inline void _check_is_cpu(const Tensor& self, c10::string_view name) {
      " tensor");
}

-inline void _check_is_cuda(const Tensor& self, c10::string_view name) {
+inline void _check_is_cuda(const Tensor& self, std::string_view name) {
  TORCH_CHECK(
      self.is_cuda(),
      "Expected all tensors to be on the same device. addmm expected '",
@ -41,7 +41,7 @@ inline void _check_is_cuda(const Tensor& self, c10::string_view name) {
      " tensor");
}

-inline void _check_dim(const Tensor& self, int64_t target_dim, c10::string_view name) {
+inline void _check_dim(const Tensor& self, int64_t target_dim, std::string_view name) {
  if (target_dim == 2) {
    TORCH_CHECK(
        self.dim() == target_dim,
@ -220,7 +220,7 @@ static SparseTensor& coalesce_(SparseTensor& tensor) {
// div(SparseTensor, Scalar)
// --------------------------------------------------------------------

-SparseTensor& div_out_sparse_zerodim(const SparseTensor& t, const Tensor& value, std::optional<c10::string_view> rounding_mode, SparseTensor& r) {
+SparseTensor& div_out_sparse_zerodim(const SparseTensor& t, const Tensor& value, std::optional<std::string_view> rounding_mode, SparseTensor& r) {
  TORCH_CHECK(value.dim() == 0, "Sparse division requires a scalar or ",
      "zero-dim dense tensor divisor (got shape ", value.sizes(), " for divisor)");
  TORCH_CHECK(!value.is_sparse(), "Sparse division requires a scalar or ",
@ -270,7 +270,7 @@ Tensor& div_sparse_(Tensor& self, const Tensor& value) {
  return div_out_sparse_zerodim(self, value, self);
}

-Tensor div_sparse(const Tensor& self, const Tensor& value, std::optional<c10::string_view> rounding_mode) {
+Tensor div_sparse(const Tensor& self, const Tensor& value, std::optional<std::string_view> rounding_mode) {
  auto commonDtype = at::result_type(self, value);
  if (c10::isIntegralType(commonDtype, /*includeBool=*/true) && !rounding_mode.has_value()) {
    commonDtype = typeMetaToScalarType(at::get_default_dtype());
@ -279,7 +279,7 @@ Tensor div_sparse(const Tensor& self, const Tensor& value, std::optional<c10::st
  return div_out_sparse_zerodim(self, value, std::move(rounding_mode), result);
}

-Tensor& div_sparse_(Tensor& self, const Tensor& value, std::optional<c10::string_view> rounding_mode) {
+Tensor& div_sparse_(Tensor& self, const Tensor& value, std::optional<std::string_view> rounding_mode) {
  return div_out_sparse_zerodim(self, value, std::move(rounding_mode), self);
}
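The rounding_mode string takes the same values as dense division ("trunc" or "floor"), and std::nullopt requests true division. A usage sketch, assuming a sparse COO input tensor; the function and variable names are illustrative:

void sparse_div_sketch(const at::Tensor& sparse_input) {
  // A zero-dim dense divisor, as the TORCH_CHECK above requires.
  at::Tensor divisor = at::scalar_tensor(3.0);
  // The literal converts implicitly to std::optional<std::string_view>.
  at::Tensor floored = at::div(sparse_input, divisor, /*rounding_mode=*/"floor");
  at::Tensor exact = at::div(sparse_input, divisor, std::nullopt);
}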
@ -1393,7 +1393,7 @@ SparseTensor& _sparse_mm_out(const SparseTensor& sparse,
  return at::addmm_out(result, t, sparse, dense, 0, 1); // redispatch!
}

-Tensor _sparse_mm(const Tensor& mat1, const Tensor& mat2, const c10::string_view reduce) {
+Tensor _sparse_mm(const Tensor& mat1, const Tensor& mat2, const std::string_view reduce) {
  // result: out, arg_out
  auto result = at::_sparse_mm_reduce_impl(mat1, mat2, reduce);
  return std::get<0>(result);
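A usage sketch for this reduce overload; the reduce argument accepts strings such as "sum" or "mean", and a sparse CSR first operand is assumed here:

void sparse_mm_reduce_sketch(const at::Tensor& csr, const at::Tensor& dense) {
  // The literal binds to the std::string_view `reduce` parameter.
  at::Tensor out = at::_sparse_mm(csr, dense, /*reduce=*/"sum");
}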
@ -507,7 +507,7 @@ template <
    bool EnableActivationSiLU>
Tensor two_four_sgemm_dispatch_layouts_bias_activation(
    const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
-    const Tensor& meta, const c10::string_view& activation) {
+    const Tensor& meta, const std::string_view& activation) {
  // Perform dispatching.
  if constexpr (EnableActivationNone) {
    if (activation == "none") {
@ -601,7 +601,7 @@ Tensor two_four_sgemm_dispatch_layouts_bias_activation(
Tensor _sparse_semi_structured_linear(
    const Tensor& input, const Tensor& weight,
    const Tensor& meta, const std::optional<Tensor>& bias_opt,
-    const std::optional<c10::string_view> activation_opt,
+    const std::optional<std::string_view> activation_opt,
    const std::optional<c10::ScalarType> out_dtype_opt) {
  TORCH_WARN_ONCE("_sparse_semi_structured_linear is deprecated and will be "
      "removed in a future PyTorch release. Please use "
@ -277,7 +277,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> sparse_semi_structured_tile_t
// <packed, packed_meta_reordered, packed_trans, packed_trans_meta_reorderd, threads_masks>
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _sparse_semi_structured_tile(
    const Tensor& input,
-    c10::string_view algorithm,
+    std::string_view algorithm,
    bool use_cutlass)
{
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
@ -100,7 +100,7 @@ struct IntArrayRefCaster<TargetType, 4> {


template<int Rank = 4>
-aotriton::TensorView<Rank> mk_aotensor(const at::Tensor& q, c10::string_view tensor_name)
+aotriton::TensorView<Rank> mk_aotensor(const at::Tensor& q, std::string_view tensor_name)
{
  const auto strides = q.strides();
  int real_rank = strides.size();
@ -112,7 +112,7 @@ inline bool try_broadcast_param_size(
    const c10::SymInt q_size,
    const c10::SymInt k_size,
    const c10::SymInt v_size,
-    c10::string_view param_name,
+    std::string_view param_name,
    bool debug) {
  auto max_size = std::max({q_size, k_size, v_size});
  if ((q_size != max_size && q_size != 1) ||
@ -140,7 +140,7 @@ inline bool try_broadcast_param_size(

inline bool check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
    at::Tensor const& param,
-    c10::string_view param_name,
+    std::string_view param_name,
    bool debug) {
  const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param);
  const at::Tensor& sizes = nt_tensor_impl->get_nested_sizes();
@ -501,7 +501,7 @@ Tensor& activation_scalar_(
  return self_arg;
}

-Tensor gelu(const Tensor& self, c10::string_view approximate) {
+Tensor gelu(const Tensor& self, std::string_view approximate) {
  TORCH_CHECK(
      approximate == "tanh", "Vulkan: gelu only supported for tanh type");
  Scalar kBetaVec = M_SQRT2 * M_2_SQRTPI * 0.5;
@ -521,7 +521,7 @@ Tensor gelu(const Tensor& self, c10::string_view approximate) {
  return ops::activation_scalar(self, scalar, VK_KERNEL(gelu_tanh));
}

-Tensor& gelu_(Tensor& self, c10::string_view approximate) {
+Tensor& gelu_(Tensor& self, std::string_view approximate) {
  TORCH_CHECK(
      approximate == "tanh", "Vulkan: gelu only supported for tanh type");
  Scalar kBetaVec = M_SQRT2 * M_2_SQRTPI * 0.5;
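Because of the TORCH_CHECK above, Vulkan callers must request the tanh approximation explicitly. A sketch, assuming a Vulkan-capable build; the function name is illustrative:

void vulkan_gelu_sketch(const at::Tensor& cpu_input) {
  at::Tensor v = cpu_input.vulkan();     // copy to the Vulkan backend
  at::Tensor out = at::gelu(v, "tanh");  // "none" (the default) would fail the check
}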
@ -8,7 +8,7 @@ namespace at {
namespace {

// Verifies the requested type is the same as the Tensor's type.
-void check_type(const TensorBase& tensor, ScalarType type, c10::string_view type_name) {
+void check_type(const TensorBase& tensor, ScalarType type, std::string_view type_name) {
  TORCH_CHECK(
      tensor.scalar_type() == type
      || (isQIntType(tensor.scalar_type())