Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[2/N] Move c10::variant to std::variant (#109723)
This PR moves most of the c10::variant usages to std::variant.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109723
Approved by: https://github.com/ezyang
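For orientation, the mechanical substitutions applied throughout the diff below can be summarized in a small self-contained C++17 sketch. The `kSum`/`kMean` tag types and the `mode_t` alias are illustrative stand-ins, not the real torch::enumtype types:

```cpp
// Minimal C++17 sketch of the c10 -> std migration applied in this diff.
// Tag types and mode_t are hypothetical stand-ins, not PyTorch APIs.
#include <cassert>
#include <string>
#include <type_traits>
#include <variant>

struct kSum {};
struct kMean {};

using mode_t = std::variant<kSum, kMean>;  // was: c10::variant<kSum, kMean>

std::string name(const mode_t& mode) {
  // was: c10::visit(visitor, mode)
  return std::visit(
      [](auto&& tag) -> std::string {
        if constexpr (std::is_same_v<std::decay_t<decltype(tag)>, kSum>) {
          return "Sum";
        } else {
          return "Mean";
        }
      },
      mode);
}

int main() {
  mode_t mode = kMean{};
  // was: c10::get_if<kMean>(&mode) used purely as a boolean test
  assert(std::holds_alternative<kMean>(mode));
  // was: c10::get_if<kSum>(&mode) where the pointer itself is needed
  assert(std::get_if<kSum>(&mode) == nullptr);
  // was: c10::get<kMean>(mode)
  (void)std::get<kMean>(mode);
  assert(name(mode) == "Mean");
  return 0;
}
```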
@@ -3,7 +3,6 @@
#if AT_USE_JITERATOR()
-#include <c10/util/variant.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/jit_utils.h>
@@ -11,6 +10,7 @@
#include <ATen/native/cuda/JitLoops.cuh>

#include <string>
+#include <variant>
#include <vector>

namespace at::native {
@@ -93,7 +93,7 @@ static std::unique_ptr<OffsetCalculator<N>> make_unique_offset_calculator(
template <bool IS_INPUT>
struct OffsetCalculatorVariant {
#define DEFINE_CASE(index) std::unique_ptr<OffsetCalculator<index>>
-using OffsetCalculatorTypes = c10::variant<
+using OffsetCalculatorTypes = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
@@ -113,7 +113,7 @@ struct OffsetCalculatorVariant {
}

void* data_ptr() {
-return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
+return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
}

private:
@@ -123,7 +123,7 @@ struct OffsetCalculatorVariant {
struct ArrayVariant {
// works for up to 8 input + 8 outputs
#define DEFINE_CASE(index) at::detail::Array<char*, index>, at::detail::Array<char*, index+8>
-using ArrayTypes = c10::variant<
+using ArrayTypes = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
@@ -142,7 +142,7 @@ struct ArrayVariant {
TORCH_CHECK(false, "ArrayVariant is not implemented for ntensors = ", ntensors);
}

-c10::visit([&](auto& a) {
+std::visit([&](auto& a) {
for (auto i = 0; i < ntensors; ++i) {
a[i] = (char*)iter.data_ptr(i);
}
@@ -150,7 +150,7 @@ struct ArrayVariant {
}

void* data_ptr() {
-return c10::visit([](auto & a){ return static_cast<void*>(&a); }, array);
+return std::visit([](auto & a){ return static_cast<void*>(&a); }, array);
}

private:
@@ -159,7 +159,7 @@ private:

struct TrivialOffsetCalculatorVariant {
#define DEFINE_CASE(index) TrivialOffsetCalculator<index>
-using TrivialOffsetCalculatorTypes = c10::variant<
+using TrivialOffsetCalculatorTypes = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
@@ -178,7 +178,7 @@ struct TrivialOffsetCalculatorVariant {
}

void* data_ptr() {
-return c10::visit([](auto & v){ return static_cast<void*>(&v); }, v);
+return std::visit([](auto & v){ return static_cast<void*>(&v); }, v);
}

private:
@@ -187,7 +187,7 @@ private:

struct LoadWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::LoadWithCast<index>>
-using LoadWithCastPtr = c10::variant<
+using LoadWithCastPtr = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
@@ -207,7 +207,7 @@ struct LoadWithCastVariant {
}

void* data_ptr() {
-return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
+return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
}

private:
@@ -216,7 +216,7 @@ private:

struct StoreWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::StoreWithCast<index>>
-using StoreWithCastPtr = c10::variant<
+using StoreWithCastPtr = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
@@ -236,7 +236,7 @@ struct StoreWithCastVariant {
}

void* data_ptr() {
-return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
+return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
}

private:
@@ -9,9 +9,6 @@
#include <c10/core/DispatchKey.h>
#include <ATen/core/function_schema.h>
#include <c10/util/Optional.h>
-#include <c10/util/variant.h>
#include <unordered_map>
#include <mutex>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <ATen/functorch/Interpreter.h>
#include <ATen/functorch/VmapInterpreter.h>
@@ -3,11 +3,11 @@

#include <c10/macros/Macros.h>
#include <c10/util/StringUtil.h>
-#include <c10/util/variant.h>

#include <cstddef>
#include <exception>
#include <string>
+#include <variant>
#include <vector>

#if defined(_MSC_VER) && _MSC_VER <= 1900
@@ -115,7 +115,7 @@ class C10_API Warning {
class C10_API UserWarning {};
class C10_API DeprecationWarning {};

-using warning_variant_t = c10::variant<UserWarning, DeprecationWarning>;
+using warning_variant_t = std::variant<UserWarning, DeprecationWarning>;

Warning(
warning_variant_t type,
@@ -4,6 +4,7 @@
#include <c10/util/Exception.h>
#include <c10/util/in_place.h>

#include <memory>
#include <type_traits>

namespace c10 {
@@ -1,6 +1,7 @@
#include <gtest/gtest.h>

#include <torch/torch.h>
+#include <variant>

#include <test/cpp/api/support.h>

@@ -13,7 +14,7 @@
}

TEST(EnumTest, AllEnums) {
-c10::variant<
+std::variant<
torch::enumtype::kLinear,
torch::enumtype::kConv1D,
torch::enumtype::kConv2D,
@@ -951,11 +951,11 @@ TEST(ExternalCall, JitCustomFusionOp) {
torch::jit::tensorexpr::BufHandle result_buf(
"nnc_add_mul_res_buf", output_shape, output_dtype);
const torch::jit::tensorexpr::BufHandle& a =
-c10::get<torch::jit::tensorexpr::BufHandle>(inputs[0]);
+std::get<torch::jit::tensorexpr::BufHandle>(inputs[0]);
const torch::jit::tensorexpr::BufHandle& b =
-c10::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
+std::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
const torch::jit::tensorexpr::BufHandle& c =
-c10::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
+std::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
torch::jit::tensorexpr::StmtPtr s =
torch::jit::tensorexpr::ExternalCall::make(
result_buf, external_func_name, {a, b, c}, {});
@@ -1667,7 +1667,7 @@ Tensor lowerNanToNum(
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
-auto input_buf = c10::get<BufHandle>(inputs[0]);
+auto input_buf = std::get<BufHandle>(inputs[0]);
auto e = Compute(
"custom_nan_to_num",
outputShape,
@@ -286,7 +286,7 @@ PyObject* map_warning_to_python_type(const c10::Warning& warning) {
return PyExc_DeprecationWarning;
}
};
-return c10::visit(Visitor(), warning.type());
+return std::visit(Visitor(), warning.type());
}

/// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification
@@ -1,10 +1,10 @@
#pragma once

#include <string>
+#include <variant>

#include <ATen/core/Reduction.h>
#include <c10/util/Exception.h>
-#include <c10/util/variant.h>
#include <torch/csrc/Export.h>

#define TORCH_ENUM_DECLARE(name) \
@@ -42,7 +42,7 @@
//
// ```
// struct TORCH_API SomeOptions {
-//   typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+//   typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
//   reduction_t; SomeOptions(reduction_t reduction = torch::kMean) :
//   reduction_(reduction) {}
//
@@ -188,16 +188,16 @@ struct _compute_enum_name {

template <typename V>
std::string get_enum_name(V variant_enum) {
-return c10::visit(enumtype::_compute_enum_name{}, variant_enum);
+return std::visit(enumtype::_compute_enum_name{}, variant_enum);
}

template <typename V>
at::Reduction::Reduction reduction_get_enum(V variant_enum) {
-if (c10::get_if<enumtype::kNone>(&variant_enum)) {
+if (std::holds_alternative<enumtype::kNone>(variant_enum)) {
return at::Reduction::None;
-} else if (c10::get_if<enumtype::kMean>(&variant_enum)) {
+} else if (std::holds_alternative<enumtype::kMean>(variant_enum)) {
return at::Reduction::Mean;
-} else if (c10::get_if<enumtype::kSum>(&variant_enum)) {
+} else if (std::holds_alternative<enumtype::kSum>(variant_enum)) {
return at::Reduction::Sum;
} else {
TORCH_CHECK(
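The reduction_get_enum hunk above also shows the only non-mechanical part of the migration: a c10::get_if(&v) call used purely as a truth test becomes std::holds_alternative(v), while call sites that keep using the returned pointer stay on std::get_if(&v). A self-contained sketch of that options-and-dispatch pattern follows; kNone/kMean/kSum and Reduction are stand-ins for the torch::enumtype tags and at::Reduction::Reduction, not the real types:

```cpp
// Standalone C++17 sketch of the reduction dispatch shown above.
#include <stdexcept>
#include <variant>

struct kNone {};
struct kMean {};
struct kSum {};

enum class Reduction { None, Mean, Sum };

// Mirrors the documented SomeOptions::reduction_t alias, now on std::variant.
using reduction_t = std::variant<kNone, kMean, kSum>;

Reduction reduction_get_enum(const reduction_t& reduction) {
  // Each branch only asks "is this alternative active?", so
  // std::holds_alternative(v) replaces c10::get_if(&v) here.
  if (std::holds_alternative<kNone>(reduction)) {
    return Reduction::None;
  } else if (std::holds_alternative<kMean>(reduction)) {
    return Reduction::Mean;
  } else if (std::holds_alternative<kSum>(reduction)) {
    return Reduction::Sum;
  }
  throw std::runtime_error("unexpected reduction");
}

int main() {
  reduction_t r = kSum{};
  return reduction_get_enum(r) == Reduction::Sum ? 0 : 1;
}
```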
@@ -31,7 +31,7 @@ inline Tensor conv1d(
const Conv1dFuncOptions::padding_t& padding,
ExpandingArray<1> dilation,
int64_t groups) {
-return c10::visit(
+return std::visit(
[&](const auto& pad) {
return torch::conv1d(
input, weight, bias, stride, padding_unwrap(pad), dilation, groups);
@@ -77,7 +77,7 @@ inline Tensor conv2d(
const Conv2dFuncOptions::padding_t& padding,
ExpandingArray<2> dilation,
int64_t groups) {
-return c10::visit(
+return std::visit(
[&](const auto& pad) {
return torch::conv2d(
input, weight, bias, stride, padding_unwrap(pad), dilation, groups);
@@ -123,7 +123,7 @@ inline Tensor conv3d(
const Conv3dFuncOptions::padding_t& padding,
ExpandingArray<3> dilation,
int64_t groups) {
-return c10::visit(
+return std::visit(
[&](const auto& pad) {
return torch::conv3d(
input, weight, bias, stride, padding_unwrap(pad), dilation, groups);
@@ -135,11 +135,11 @@ inline Tensor embedding_bag(

// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int mode_enum;
-if (c10::get_if<enumtype::kSum>(&mode)) {
+if (std::holds_alternative<enumtype::kSum>(mode)) {
mode_enum = 0;
-} else if (c10::get_if<enumtype::kMean>(&mode)) {
+} else if (std::holds_alternative<enumtype::kMean>(mode)) {
mode_enum = 1;
-} else if (c10::get_if<enumtype::kMax>(&mode)) {
+} else if (std::holds_alternative<enumtype::kMax>(mode)) {
mode_enum = 2;
TORCH_CHECK(
!scale_grad_by_freq,
@@ -155,7 +155,7 @@ inline Tensor embedding_bag(
}

TORCH_CHECK(
-!per_sample_weights_.defined() || c10::get_if<enumtype::kSum>(&mode),
+!per_sample_weights_.defined() || std::get_if<enumtype::kSum>(&mode),
"embedding_bag: per_sample_weights was not null. ",
"per_sample_weights is only supported for mode='kSum' (got mode='",
torch::enumtype::get_enum_name(mode),
@@ -50,7 +50,7 @@ inline Tensor kl_div(
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
torch::Reduction::Reduction reduction_enum;

-if (c10::get_if<enumtype::kMean>(&reduction)) {
+if (std::holds_alternative<enumtype::kMean>(reduction)) {
TORCH_WARN(
"reduction: 'mean' divides the total loss by both the batch size and the support size."
"'batchmean' divides only by the batch size, and aligns with the KL div math definition."
@@ -58,7 +58,7 @@ inline Tensor kl_div(
}

// special case for batchmean
-if (c10::get_if<enumtype::kBatchMean>(&reduction)) {
+if (std::holds_alternative<enumtype::kBatchMean>(reduction)) {
reduction_enum = torch::Reduction::Sum;
} else {
reduction_enum = enumtype::reduction_get_enum(reduction);
@@ -66,7 +66,8 @@ inline Tensor kl_div(

auto reduced = torch::kl_div(input, target, reduction_enum, log_target);

-if (c10::get_if<enumtype::kBatchMean>(&reduction) && input.dim() != 0) {
+if (std::holds_alternative<enumtype::kBatchMean>(reduction) &&
+    input.dim() != 0) {
reduced = reduced / input.sizes()[0];
}

@@ -531,11 +532,11 @@ inline Tensor multilabel_soft_margin_loss(

Tensor ret;

-if (c10::get_if<enumtype::kNone>(&reduction)) {
+if (std::holds_alternative<enumtype::kNone>(reduction)) {
ret = loss;
-} else if (c10::get_if<enumtype::kMean>(&reduction)) {
+} else if (std::holds_alternative<enumtype::kMean>(reduction)) {
ret = loss.mean();
-} else if (c10::get_if<enumtype::kSum>(&reduction)) {
+} else if (std::holds_alternative<enumtype::kSum>(reduction)) {
ret = loss.sum();
} else {
ret = input;
@@ -661,11 +662,11 @@ inline Tensor triplet_margin_with_distance_loss(
auto loss = torch::clamp_min(dist_pos - dist_neg + margin, 0);

Tensor ret;
-if (c10::get_if<enumtype::kNone>(&reduction)) {
+if (std::holds_alternative<enumtype::kNone>(reduction)) {
ret = loss;
-} else if (c10::get_if<enumtype::kMean>(&reduction)) {
+} else if (std::holds_alternative<enumtype::kMean>(reduction)) {
ret = loss.mean();
-} else if (c10::get_if<enumtype::kSum>(&reduction)) {
+} else if (std::holds_alternative<enumtype::kSum>(reduction)) {
ret = loss.sum();
} else {
ret = anchor;
@@ -15,13 +15,13 @@ inline Tensor pad(
PadFuncOptions::mode_t mode,
double value) {
const auto mode_enum = [&] {
-if (c10::get_if<enumtype::kConstant>(&mode)) {
+if (std::holds_alternative<enumtype::kConstant>(mode)) {
return at::padding_mode::constant;
-} else if (c10::get_if<enumtype::kReflect>(&mode)) {
+} else if (std::holds_alternative<enumtype::kReflect>(mode)) {
return at::padding_mode::reflect;
-} else if (c10::get_if<enumtype::kReplicate>(&mode)) {
+} else if (std::holds_alternative<enumtype::kReplicate>(mode)) {
return at::padding_mode::replicate;
-} else if (c10::get_if<enumtype::kCircular>(&mode)) {
+} else if (std::holds_alternative<enumtype::kCircular>(mode)) {
return at::padding_mode::circular;
}
TORCH_CHECK(false, "Unrecognised padding mode");
@@ -86,8 +86,8 @@ inline Tensor interpolate(
c10::optional<bool> align_corners,
c10::optional<bool> recompute_scale_factor,
bool antialias) {
-if (c10::get_if<enumtype::kNearest>(&mode) ||
-    c10::get_if<enumtype::kArea>(&mode)) {
+if (std::holds_alternative<enumtype::kNearest>(mode) ||
+    std::get_if<enumtype::kArea>(&mode)) {
if (align_corners != c10::nullopt) {
TORCH_CHECK(
false,
@@ -130,8 +130,8 @@ inline Tensor interpolate(

if (antialias &&
!(input.dim() == 4 &&
-  (c10::get_if<enumtype::kBilinear>(&mode) ||
-   c10::get_if<enumtype::kBicubic>(&mode)))) {
+  (std::get_if<enumtype::kBilinear>(&mode) ||
+   std::get_if<enumtype::kBicubic>(&mode)))) {
TORCH_CHECK(
false,
"Anti-alias option is only supported for bilinear and bicubic modes");
@@ -139,65 +139,65 @@ inline Tensor interpolate(

auto closed_over_args =
std::make_tuple(input, size, scale_factor, recompute_scale_factor);
-if (input.dim() == 3 && c10::get_if<enumtype::kNearest>(&mode)) {
+if (input.dim() == 3 && std::get_if<enumtype::kNearest>(&mode)) {
return torch::upsample_nearest1d(
input,
_interp_output_size(1, std::move(closed_over_args)),
scale_factor_list.at(0));
-} else if (input.dim() == 4 && c10::get_if<enumtype::kNearest>(&mode)) {
+} else if (input.dim() == 4 && std::get_if<enumtype::kNearest>(&mode)) {
return torch::upsample_nearest2d(
input,
_interp_output_size(2, std::move(closed_over_args)),
scale_factor_list.at(0),
scale_factor_list.at(1));
-} else if (input.dim() == 5 && c10::get_if<enumtype::kNearest>(&mode)) {
+} else if (input.dim() == 5 && std::get_if<enumtype::kNearest>(&mode)) {
return torch::upsample_nearest3d(
input,
_interp_output_size(3, std::move(closed_over_args)),
scale_factor_list.at(0),
scale_factor_list.at(1),
scale_factor_list.at(2));
-} else if (input.dim() == 3 && c10::get_if<enumtype::kNearestExact>(&mode)) {
+} else if (input.dim() == 3 && std::get_if<enumtype::kNearestExact>(&mode)) {
return torch::_upsample_nearest_exact1d(
input,
_interp_output_size(1, std::move(closed_over_args)),
scale_factor_list.at(0));
-} else if (input.dim() == 4 && c10::get_if<enumtype::kNearestExact>(&mode)) {
+} else if (input.dim() == 4 && std::get_if<enumtype::kNearestExact>(&mode)) {
return torch::_upsample_nearest_exact2d(
input,
_interp_output_size(2, std::move(closed_over_args)),
scale_factor_list.at(0),
scale_factor_list.at(1));
-} else if (input.dim() == 5 && c10::get_if<enumtype::kNearestExact>(&mode)) {
+} else if (input.dim() == 5 && std::get_if<enumtype::kNearestExact>(&mode)) {
return torch::_upsample_nearest_exact3d(
input,
_interp_output_size(3, std::move(closed_over_args)),
scale_factor_list.at(0),
scale_factor_list.at(1),
scale_factor_list.at(2));
-} else if (input.dim() == 3 && c10::get_if<enumtype::kArea>(&mode)) {
+} else if (input.dim() == 3 && std::get_if<enumtype::kArea>(&mode)) {
return detail::adaptive_avg_pool1d(
input, _interp_output_size(1, std::move(closed_over_args)));
-} else if (input.dim() == 4 && c10::get_if<enumtype::kArea>(&mode)) {
+} else if (input.dim() == 4 && std::get_if<enumtype::kArea>(&mode)) {
return detail::adaptive_avg_pool2d(
input, _interp_output_size(2, std::move(closed_over_args)));
-} else if (input.dim() == 5 && c10::get_if<enumtype::kArea>(&mode)) {
+} else if (input.dim() == 5 && std::get_if<enumtype::kArea>(&mode)) {
return detail::adaptive_avg_pool3d(
input, _interp_output_size(3, std::move(closed_over_args)));
-} else if (input.dim() == 3 && c10::get_if<enumtype::kLinear>(&mode)) {
+} else if (input.dim() == 3 && std::get_if<enumtype::kLinear>(&mode)) {
TORCH_INTERNAL_ASSERT(align_corners != c10::nullopt);
return torch::upsample_linear1d(
input,
_interp_output_size(1, std::move(closed_over_args)),
*align_corners,
scale_factor_list.at(0));
-} else if (input.dim() == 3 && c10::get_if<enumtype::kBilinear>(&mode)) {
+} else if (input.dim() == 3 && std::get_if<enumtype::kBilinear>(&mode)) {
TORCH_CHECK(false, "Got 3D input, but bilinear mode needs 4D input");
-} else if (input.dim() == 3 && c10::get_if<enumtype::kTrilinear>(&mode)) {
+} else if (input.dim() == 3 && std::get_if<enumtype::kTrilinear>(&mode)) {
TORCH_CHECK(false, "Got 3D input, but trilinear mode needs 5D input");
-} else if (input.dim() == 4 && c10::get_if<enumtype::kLinear>(&mode)) {
+} else if (input.dim() == 4 && std::get_if<enumtype::kLinear>(&mode)) {
TORCH_CHECK(false, "Got 4D input, but linear mode needs 3D input");
-} else if (input.dim() == 4 && c10::get_if<enumtype::kBilinear>(&mode)) {
+} else if (input.dim() == 4 && std::get_if<enumtype::kBilinear>(&mode)) {
TORCH_INTERNAL_ASSERT(align_corners != c10::nullopt);
if (antialias) {
return torch::_upsample_bilinear2d_aa(
@@ -213,13 +213,13 @@ inline Tensor interpolate(
*align_corners,
scale_factor_list.at(0),
scale_factor_list.at(1));
-} else if (input.dim() == 4 && c10::get_if<enumtype::kTrilinear>(&mode)) {
+} else if (input.dim() == 4 && std::get_if<enumtype::kTrilinear>(&mode)) {
TORCH_CHECK(false, "Got 4D input, but trilinear mode needs 5D input");
-} else if (input.dim() == 5 && c10::get_if<enumtype::kLinear>(&mode)) {
+} else if (input.dim() == 5 && std::get_if<enumtype::kLinear>(&mode)) {
TORCH_CHECK(false, "Got 5D input, but linear mode needs 3D input");
-} else if (input.dim() == 5 && c10::get_if<enumtype::kBilinear>(&mode)) {
+} else if (input.dim() == 5 && std::get_if<enumtype::kBilinear>(&mode)) {
TORCH_CHECK(false, "Got 5D input, but bilinear mode needs 4D input");
-} else if (input.dim() == 5 && c10::get_if<enumtype::kTrilinear>(&mode)) {
+} else if (input.dim() == 5 && std::get_if<enumtype::kTrilinear>(&mode)) {
TORCH_INTERNAL_ASSERT(align_corners != c10::nullopt);
return torch::upsample_trilinear3d(
input,
@@ -228,7 +228,7 @@ inline Tensor interpolate(
scale_factor_list.at(0),
scale_factor_list.at(1),
scale_factor_list.at(2));
-} else if (input.dim() == 4 && c10::get_if<enumtype::kBicubic>(&mode)) {
+} else if (input.dim() == 4 && std::get_if<enumtype::kBicubic>(&mode)) {
TORCH_INTERNAL_ASSERT(align_corners != c10::nullopt);
if (antialias) {
return torch::_upsample_bicubic2d_aa(
@@ -63,17 +63,17 @@ inline Tensor grid_sample(
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t mode_enum, padding_mode_enum;

-if (c10::get_if<enumtype::kBilinear>(&mode)) {
+if (std::holds_alternative<enumtype::kBilinear>(mode)) {
mode_enum = 0;
-} else if (c10::get_if<enumtype::kNearest>(&mode)) {
+} else if (std::holds_alternative<enumtype::kNearest>(mode)) {
mode_enum = 1;
} else { /// mode == 'bicubic'
mode_enum = 2;
}

-if (c10::get_if<enumtype::kZeros>(&padding_mode)) {
+if (std::holds_alternative<enumtype::kZeros>(padding_mode)) {
padding_mode_enum = 0;
-} else if (c10::get_if<enumtype::kBorder>(&padding_mode)) {
+} else if (std::holds_alternative<enumtype::kBorder>(padding_mode)) {
padding_mode_enum = 1;
} else { /// padding_mode == 'reflection'
padding_mode_enum = 2;
@@ -8,7 +8,7 @@ namespace torch {
namespace nn {
namespace init {

-using NonlinearityType = c10::variant<
+using NonlinearityType = std::variant<
enumtype::kLinear,
enumtype::kConv1D,
enumtype::kConv2D,
@@ -21,7 +21,7 @@ using NonlinearityType = c10::variant<
enumtype::kReLU,
enumtype::kLeakyReLU>;

-using FanModeType = c10::variant<enumtype::kFanIn, enumtype::kFanOut>;
+using FanModeType = std::variant<enumtype::kFanIn, enumtype::kFanOut>;

} // namespace init
} // namespace nn
@@ -42,7 +42,7 @@ class ConvNdImpl : public torch::nn::Cloneable<Derived> {
options.out_channels() % options.groups() == 0,
"out_channels must be divisible by groups");

-c10::visit(
+std::visit(
c10::overloaded(
[&](enumtype::kValid) {
_reversed_padding_repeated_twice.resize(2 * D);
@@ -121,7 +121,7 @@ class ConvNdImpl : public torch::nn::Cloneable<Derived> {
<< "(" << options.in_channels() << ", " << options.out_channels()
<< ", kernel_size=" << options.kernel_size()
<< ", stride=" << options.stride();
-c10::visit(
+std::visit(
c10::overloaded(
[&](enumtype::kValid) { stream << ", padding='valid'"; },
[&](enumtype::kSame) { stream << ", padding='same'"; },
@@ -143,7 +143,7 @@ class ConvNdImpl : public torch::nn::Cloneable<Derived> {
if (!options.bias()) {
stream << ", bias=" << std::boolalpha << false;
}
-if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
+if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
stream << ", padding_mode="
<< enumtype::get_enum_name(options.padding_mode());
}
@@ -276,7 +276,7 @@ class ConvTransposeNdImpl : public ConvNdImpl<D, Derived> {
explicit ConvTransposeNdImpl(detail::ConvNdOptions<D> options_)
: ConvNdImpl<D, Derived>(options_) {
TORCH_INTERNAL_ASSERT(
-c10::holds_alternative<ExpandingArray<D>>(this->options.padding()),
+std::holds_alternative<ExpandingArray<D>>(this->options.padding()),
"ConvTranspose padding cannot be a string");
}

@@ -303,7 +303,7 @@ class ConvTransposeNdImpl : public ConvNdImpl<D, Derived> {
if (!this->options.bias()) {
stream << ", bias=" << std::boolalpha << false;
}
-if (!c10::get_if<enumtype::kZeros>(&this->options.padding_mode())) {
+if (!std::get_if<enumtype::kZeros>(&this->options.padding_mode())) {
stream << ", padding_mode="
<< enumtype::get_enum_name(this->options.padding_mode());
}
@@ -312,7 +312,7 @@ class ConvTransposeNdImpl : public ConvNdImpl<D, Derived> {

protected:
const ExpandingArray<D>& padding() const {
-return c10::get<ExpandingArray<D>>(this->options.padding());
+return std::get<ExpandingArray<D>>(this->options.padding());
}

std::vector<int64_t> _output_padding(
@@ -11,7 +11,7 @@ namespace nn {

namespace detail {

-typedef c10::variant<
+typedef std::variant<
enumtype::kZeros,
enumtype::kReflect,
enumtype::kReplicate,
@@ -20,7 +20,7 @@ typedef c10::variant<

template <size_t D>
using conv_padding_t =
-c10::variant<ExpandingArray<D>, enumtype::kValid, enumtype::kSame>;
+std::variant<ExpandingArray<D>, enumtype::kValid, enumtype::kSame>;

/// Options for a `D`-dimensional convolution or convolution transpose module.
template <size_t D>

@@ -101,7 +101,7 @@ struct TORCH_API EmbeddingFuncOptions {

// ============================================================================

-typedef c10::variant<enumtype::kSum, enumtype::kMean, enumtype::kMax>
+typedef std::variant<enumtype::kSum, enumtype::kMean, enumtype::kMax>
EmbeddingBagMode;

/// Options for the `EmbeddingBag` module.
@@ -1,6 +1,5 @@
#pragma once

-#include <c10/util/variant.h>
#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/types.h>

@@ -15,7 +15,7 @@ namespace nn {
/// L1Loss model(L1LossOptions(torch::kNone));
/// ```
struct TORCH_API L1LossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

TORCH_OPTIONS_CTOR_VARIANT_ARG3(L1LossOptions, reduction, kNone, kMean, kSum)
@@ -48,7 +48,7 @@ using L1LossFuncOptions = L1LossOptions;
/// model(KLDivLossOptions().reduction(torch::kNone).log_target(false));
/// ```
struct TORCH_API KLDivLossOptions {
-typedef c10::variant<
+typedef std::variant<
enumtype::kNone,
enumtype::kBatchMean,
enumtype::kSum,
@@ -95,7 +95,7 @@ using KLDivFuncOptions = KLDivLossOptions;
/// MSELoss model(MSELossOptions(torch::kNone));
/// ```
struct TORCH_API MSELossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

TORCH_OPTIONS_CTOR_VARIANT_ARG3(MSELossOptions, reduction, kNone, kMean, kSum)
@@ -128,7 +128,7 @@ using MSELossFuncOptions = MSELossOptions;
/// BCELoss model(BCELossOptions().reduction(torch::kNone).weight(weight));
/// ```
struct TORCH_API BCELossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

/// A manual rescaling weight given to the loss of each batch element.
@@ -163,7 +163,7 @@ using BinaryCrossEntropyFuncOptions = BCELossOptions;
/// model(HingeEmbeddingLossOptions().margin(4).reduction(torch::kNone));
/// ```
struct TORCH_API HingeEmbeddingLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

/// Specifies the threshold for which the distance of a negative sample must
@@ -197,7 +197,7 @@ using HingeEmbeddingLossFuncOptions = HingeEmbeddingLossOptions;
/// MultiMarginLoss model(MultiMarginLossOptions().margin(2).weight(weight));
/// ```
struct TORCH_API MultiMarginLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

/// Has a default value of :math:`1`. :math:`1` and :math:`2`
@@ -242,7 +242,7 @@ using MultiMarginLossFuncOptions = MultiMarginLossOptions;
/// CosineEmbeddingLoss model(CosineEmbeddingLossOptions().margin(0.5));
/// ```
struct TORCH_API CosineEmbeddingLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

/// Specifies the threshold for which the distance of a negative sample must
@@ -277,7 +277,7 @@ using CosineEmbeddingLossFuncOptions = CosineEmbeddingLossOptions;
/// MultiLabelMarginLoss model(MultiLabelMarginLossOptions(torch::kNone));
/// ```
struct TORCH_API MultiLabelMarginLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

TORCH_OPTIONS_CTOR_VARIANT_ARG3(
@@ -318,7 +318,7 @@ using MultilabelMarginLossFuncOptions = MultiLabelMarginLossOptions;
/// SoftMarginLoss model(SoftMarginLossOptions(torch::kNone));
/// ```
struct TORCH_API SoftMarginLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

TORCH_OPTIONS_CTOR_VARIANT_ARG3(
@@ -360,7 +360,7 @@ using SoftMarginLossFuncOptions = SoftMarginLossOptions;
/// model(MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight));
/// ```
struct TORCH_API MultiLabelSoftMarginLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

/// A manual rescaling weight given to each
@@ -400,7 +400,7 @@ using MultilabelSoftMarginLossFuncOptions = MultiLabelSoftMarginLossOptions;
/// model(TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false));
/// ```
struct TORCH_API TripletMarginLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

/// Specifies the threshold for which the distance of a negative sample must
@@ -442,7 +442,7 @@ using TripletMarginLossFuncOptions = TripletMarginLossOptions;
/// model(TripletMarginWithDistanceLossOptions().margin(3).swap(false));
/// ```
struct TORCH_API TripletMarginWithDistanceLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
typedef std::function<Tensor(const Tensor&, const Tensor&)>
distance_function_t;
@@ -493,7 +493,7 @@ using TripletMarginWithDistanceLossFuncOptions =
/// model(CTCLossOptions().blank(42).zero_infinity(false).reduction(torch::kSum));
/// ```
struct TORCH_API CTCLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

/// blank label. Default `0`.
@@ -530,7 +530,7 @@ using CTCLossFuncOptions = CTCLossOptions;
/// SmoothL1Loss model(SmoothL1LossOptions().reduction(torch::kNone).beta(0.5));
/// ```
struct TORCH_API SmoothL1LossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

TORCH_OPTIONS_CTOR_VARIANT_ARG3(
@@ -573,7 +573,7 @@ using SmoothL1LossFuncOptions = SmoothL1LossOptions;
/// HuberLoss model(HuberLossOptions().reduction(torch::kNone).delta(0.5));
/// ```
struct TORCH_API HuberLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

TORCH_OPTIONS_CTOR_VARIANT_ARG3(
@@ -617,7 +617,7 @@ using HuberLossFuncOptions = HuberLossOptions;
/// model(PoissonNLLLossOptions().log_input(false).full(true).eps(0.42).reduction(torch::kSum));
/// ```
struct TORCH_API PoissonNLLLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

/// if true the loss is computed as `exp(input) - target * input`,
@@ -658,7 +658,7 @@ using PoissonNLLLossFuncOptions = PoissonNLLLossOptions;
/// model(MarginRankingLossOptions().margin(0.5).reduction(torch::kSum));
/// ```
struct TORCH_API MarginRankingLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

/// Has a default value of `0`.
@@ -691,7 +691,7 @@ using MarginRankingLossFuncOptions = MarginRankingLossOptions;
/// NLLLoss model(NLLLossOptions().ignore_index(-100).reduction(torch::kMean));
/// ```
struct TORCH_API NLLLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

/// A manual rescaling weight given to each
@@ -730,7 +730,7 @@ using NLLLossFuncOptions = NLLLossOptions;
/// model(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean));
/// ```
struct TORCH_API CrossEntropyLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;

/// A manual rescaling weight given to each class. If given, has to be a
@@ -770,7 +770,7 @@ using CrossEntropyFuncOptions = CrossEntropyLossOptions;
/// model(BCEWithLogitsLossOptions().reduction(torch::kNone).weight(weight));
/// ```
struct TORCH_API BCEWithLogitsLossOptions {
-typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
/// A manual rescaling weight given to the loss of each batch element.
/// If given, has to be a Tensor of size `nbatch`.
@@ -1,6 +1,5 @@
#pragma once

-#include <c10/util/variant.h>
#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
@@ -195,7 +194,7 @@ namespace functional {
/// 2}).mode(torch::kReplicate));
/// ```
struct TORCH_API PadFuncOptions {
-typedef c10::variant<
+typedef std::variant<
enumtype::kConstant,
enumtype::kReflect,
enumtype::kReplicate,

@@ -12,7 +12,7 @@ namespace detail {

/// Common options for RNN, LSTM and GRU modules.
struct TORCH_API RNNOptionsBase {
-typedef c10::variant<
+typedef std::variant<
enumtype::kLSTM,
enumtype::kGRU,
enumtype::kRNN_TANH,
@@ -57,7 +57,7 @@ struct TORCH_API RNNOptionsBase {
/// 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh));
/// ```
struct TORCH_API RNNOptions {
-typedef c10::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;
+typedef std::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;

RNNOptions(int64_t input_size, int64_t hidden_size);

@@ -182,7 +182,7 @@ struct TORCH_API RNNCellOptionsBase {
/// 10).bias(false).nonlinearity(torch::kReLU));
/// ```
struct TORCH_API RNNCellOptions {
-typedef c10::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;
+typedef std::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;

RNNCellOptions(int64_t input_size, int64_t hidden_size);

@@ -8,7 +8,7 @@
namespace torch {
namespace nn {

-using activation_t = c10::variant<
+using activation_t = std::variant<
enumtype::kReLU,
enumtype::kGELU,
std::function<Tensor(const Tensor&)>>;
@@ -1,6 +1,5 @@
#pragma once

-#include <c10/util/variant.h>
#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
@@ -28,7 +27,7 @@ struct TORCH_API UpsampleOptions {

/// the upsampling algorithm: one of "nearest", "linear", "bilinear",
/// "bicubic" and "trilinear". Default: "nearest"
-typedef c10::variant<
+typedef std::variant<
enumtype::kNearest,
enumtype::kLinear,
enumtype::kBilinear,
@@ -55,7 +54,7 @@ namespace functional {
/// F::InterpolateFuncOptions().size(std::vector<int64_t>({4})).mode(torch::kNearest));
/// ```
struct TORCH_API InterpolateFuncOptions {
-typedef c10::variant<
+typedef std::variant<
enumtype::kNearest,
enumtype::kLinear,
enumtype::kBilinear,

@@ -18,8 +18,8 @@ namespace functional {
/// F::GridSampleFuncOptions().mode(torch::kBilinear).padding_mode(torch::kZeros).align_corners(true));
/// ```
struct TORCH_API GridSampleFuncOptions {
-typedef c10::variant<enumtype::kBilinear, enumtype::kNearest> mode_t;
-typedef c10::
+typedef std::variant<enumtype::kBilinear, enumtype::kNearest> mode_t;
+typedef std::
variant<enumtype::kZeros, enumtype::kBorder, enumtype::kReflection>
padding_mode_t;
@@ -47,7 +47,7 @@ double calculate_kaiming_std(
const auto gain = calculate_gain(nonlinearity, a);
double std = 0.0;

-if (c10::get_if<enumtype::kFanIn>(&mode)) {
+if (std::holds_alternative<enumtype::kFanIn>(mode)) {
std = gain / std::sqrt(fan.in);
} else {
std = gain / std::sqrt(fan.out);
@@ -57,11 +57,11 @@ double calculate_kaiming_std(
} // namespace

double calculate_gain(NonlinearityType nonlinearity, double param) {
-if (c10::get_if<enumtype::kTanh>(&nonlinearity)) {
+if (std::holds_alternative<enumtype::kTanh>(nonlinearity)) {
return 5.0 / 3.0; // NOLINT
-} else if (c10::get_if<enumtype::kReLU>(&nonlinearity)) {
+} else if (std::holds_alternative<enumtype::kReLU>(nonlinearity)) {
return std::sqrt(2.0); // NOLINT
-} else if (c10::get_if<enumtype::kLeakyReLU>(&nonlinearity)) {
+} else if (std::holds_alternative<enumtype::kLeakyReLU>(nonlinearity)) {
return std::sqrt(2.0 / (1 + pow(param, 2))); // NOLINT
}
@@ -20,11 +20,13 @@ namespace F = torch::nn::functional;
static F::PadFuncOptions::mode_t _get_pad_mode_from_conv_padding_mode(
torch::nn::detail::conv_padding_mode_t conv_padding_mode) {
F::PadFuncOptions::mode_t pad_mode;
-if (c10::get_if<torch::enumtype::kReflect>(&conv_padding_mode)) {
+if (std::holds_alternative<torch::enumtype::kReflect>(conv_padding_mode)) {
pad_mode = torch::kReflect;
-} else if (c10::get_if<torch::enumtype::kReplicate>(&conv_padding_mode)) {
+} else if (std::holds_alternative<torch::enumtype::kReplicate>(
+    conv_padding_mode)) {
pad_mode = torch::kReplicate;
-} else if (c10::get_if<torch::enumtype::kCircular>(&conv_padding_mode)) {
+} else if (std::holds_alternative<torch::enumtype::kCircular>(
+    conv_padding_mode)) {
pad_mode = torch::kCircular;
} else {
TORCH_CHECK(
@@ -52,7 +54,7 @@ Conv1dImpl::Conv1dImpl(Conv1dOptions options_)
.padding_mode(options_.padding_mode())) {}

Tensor Conv1dImpl::forward(const Tensor& input) {
-if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
+if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
return F::detail::conv1d(
F::pad(
input,
@@ -91,7 +93,7 @@ Conv2dImpl::Conv2dImpl(Conv2dOptions options_)
.padding_mode(options_.padding_mode())) {}

Tensor Conv2dImpl::_conv_forward(const Tensor& input, const Tensor& weight) {
-if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
+if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
return F::detail::conv2d(
F::pad(
input,
@@ -134,7 +136,7 @@ Conv3dImpl::Conv3dImpl(Conv3dOptions options_)
.padding_mode(options_.padding_mode())) {}

Tensor Conv3dImpl::forward(const Tensor& input) {
-if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
+if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
return F::detail::conv3d(
F::pad(
input,
@@ -247,7 +249,7 @@ ConvTranspose1dImpl::ConvTranspose1dImpl(ConvTranspose1dOptions options_)
Tensor ConvTranspose1dImpl::forward(
const Tensor& input,
const c10::optional<at::IntArrayRef>& output_size) {
-if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
+if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
TORCH_CHECK(
false, "Only `zeros` padding mode is supported for ConvTranspose1d");
}
@@ -284,7 +286,7 @@ ConvTranspose2dImpl::ConvTranspose2dImpl(ConvTranspose2dOptions options_)
Tensor ConvTranspose2dImpl::forward(
const Tensor& input,
const c10::optional<at::IntArrayRef>& output_size) {
-if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
+if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
TORCH_CHECK(
false, "Only `zeros` padding mode is supported for ConvTranspose2d");
}
@@ -321,7 +323,7 @@ ConvTranspose3dImpl::ConvTranspose3dImpl(ConvTranspose3dOptions options_)
Tensor ConvTranspose3dImpl::forward(
const Tensor& input,
const c10::optional<at::IntArrayRef>& output_size) {
-if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
+if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
TORCH_CHECK(
false, "Only `zeros` padding mode is supported for ConvTranspose3d");
}
@@ -167,7 +167,7 @@ void EmbeddingBagImpl::pretty_print(std::ostream& stream) const {
if (options.sparse()) {
stream << ", sparse=" << std::boolalpha << options.sparse();
}
-if (!c10::get_if<enumtype::kMean>(&options.mode())) {
+if (!std::get_if<enumtype::kMean>(&options.mode())) {
stream << ", mode=" << torch::enumtype::get_enum_name(options.mode());
}
if (options.include_last_offset()) {
@@ -30,13 +30,13 @@ enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };

static CuDNNMode get_cudnn_mode_for_rnn(
detail::RNNOptionsBase::rnn_options_base_mode_t mode) {
-if (c10::get_if<enumtype::kRNN_RELU>(&mode)) {
+if (std::holds_alternative<enumtype::kRNN_RELU>(mode)) {
return CuDNNMode::RNN_RELU;
-} else if (c10::get_if<enumtype::kRNN_TANH>(&mode)) {
+} else if (std::holds_alternative<enumtype::kRNN_TANH>(mode)) {
return CuDNNMode::RNN_TANH;
-} else if (c10::get_if<enumtype::kLSTM>(&mode)) {
+} else if (std::holds_alternative<enumtype::kLSTM>(mode)) {
return CuDNNMode::LSTM;
-} else if (c10::get_if<enumtype::kGRU>(&mode)) {
+} else if (std::holds_alternative<enumtype::kGRU>(mode)) {
return CuDNNMode::GRU;
} else {
TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(mode));
@@ -94,19 +94,19 @@ void RNNImplBase<Derived>::reset() {

if (options_base.proj_size() > 0) {
TORCH_CHECK(
-c10::get_if<enumtype::kLSTM>(&options_base.mode()),
+std::get_if<enumtype::kLSTM>(&options_base.mode()),
"proj_size argument is only supported for LSTM, not RNN or GRU");
}

int64_t gate_size = 0;
-if (c10::get_if<enumtype::kLSTM>(&options_base.mode())) {
+if (std::holds_alternative<enumtype::kLSTM>(options_base.mode())) {
gate_size = 4 * options_base.hidden_size();
-} else if (c10::get_if<enumtype::kGRU>(&options_base.mode())) {
+} else if (std::holds_alternative<enumtype::kGRU>(options_base.mode())) {
gate_size = 3 * options_base.hidden_size();
// NOLINTNEXTLINE(bugprone-branch-clone)
-} else if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
+} else if (std::holds_alternative<enumtype::kRNN_TANH>(options_base.mode())) {
gate_size = options_base.hidden_size();
-} else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
+} else if (std::holds_alternative<enumtype::kRNN_RELU>(options_base.mode())) {
gate_size = options_base.hidden_size();
} else {
TORCH_CHECK(
@@ -405,9 +405,9 @@ template class RNNImplBase<RNNImpl>;

static detail::RNNOptionsBase::rnn_options_base_mode_t
compute_rnn_options_base_mode(RNNOptions::nonlinearity_t nonlinearity) {
-if (c10::get_if<enumtype::kTanh>(&nonlinearity)) {
+if (std::holds_alternative<enumtype::kTanh>(nonlinearity)) {
return torch::kRNN_TANH;
-} else if (c10::get_if<enumtype::kReLU>(&nonlinearity)) {
+} else if (std::holds_alternative<enumtype::kReLU>(nonlinearity)) {
return torch::kRNN_RELU;
} else {
TORCH_CHECK(
@@ -453,7 +453,7 @@ std::tuple<Tensor, Tensor> RNNImpl::forward_helper(

std::tuple<Tensor, Tensor> result;
if (!batch_sizes.defined()) {
-if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
+if (std::holds_alternative<enumtype::kRNN_TANH>(options_base.mode())) {
result = torch::rnn_tanh(
input,
hx,
@@ -464,7 +464,8 @@ std::tuple<Tensor, Tensor> RNNImpl::forward_helper(
this->is_training(),
options_base.bidirectional(),
options_base.batch_first());
-} else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
+} else if (std::holds_alternative<enumtype::kRNN_RELU>(
+    options_base.mode())) {
result = torch::rnn_relu(
input,
hx,
@@ -482,7 +483,7 @@ std::tuple<Tensor, Tensor> RNNImpl::forward_helper(
torch::enumtype::get_enum_name(options_base.mode()));
}
} else {
-if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
+if (std::holds_alternative<enumtype::kRNN_TANH>(options_base.mode())) {
result = torch::rnn_tanh(
input,
batch_sizes,
@@ -493,7 +494,8 @@ std::tuple<Tensor, Tensor> RNNImpl::forward_helper(
options_base.dropout(),
this->is_training(),
options_base.bidirectional());
-} else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
+} else if (std::holds_alternative<enumtype::kRNN_RELU>(
+    options_base.mode())) {
result = torch::rnn_relu(
input,
batch_sizes,
@@ -920,10 +922,10 @@ Tensor RNNCellImpl::forward(const Tensor& input, Tensor hx) {
r_hx = is_batched ? hx : hx.unsqueeze(0);
}

-if (c10::get_if<enumtype::kTanh>(&options.nonlinearity())) {
+if (std::holds_alternative<enumtype::kTanh>(options.nonlinearity())) {
ret = torch::rnn_tanh_cell(
r_input, r_hx, weight_ih, weight_hh, bias_ih, bias_hh);
-} else if (c10::get_if<enumtype::kReLU>(&options.nonlinearity())) {
+} else if (std::holds_alternative<enumtype::kReLU>(options.nonlinearity())) {
ret = torch::rnn_relu_cell(
r_input, r_hx, weight_ih, weight_hh, bias_ih, bias_hh);
} else {
@@ -71,14 +71,14 @@ Tensor TransformerEncoderLayerImpl::forward(
Tensor ret = norm1(src + dropout1(src2));

// feedforward
-if (c10::get_if<enumtype::kGELU>(&options.activation())) {
+if (std::holds_alternative<enumtype::kGELU>(options.activation())) {
src2 = linear2(dropout(F::gelu(linear1(ret))));
-} else if (c10::get_if<enumtype::kReLU>(&options.activation())) {
+} else if (std::holds_alternative<enumtype::kReLU>(options.activation())) {
src2 = linear2(dropout(F::relu(linear1(ret))));
-} else if (c10::get_if<std::function<Tensor(const Tensor&)>>(
-    &options.activation())) {
+} else if (std::holds_alternative<std::function<Tensor(const Tensor&)>>(
+    options.activation())) {
auto callable_activation =
-    *c10::get_if<std::function<Tensor(const Tensor&)>>(
+    *std::get_if<std::function<Tensor(const Tensor&)>>(
        &options.activation());
src2 = linear2(dropout(callable_activation(linear1(ret))));
} else {
@@ -198,14 +198,14 @@ Tensor TransformerDecoderLayerImpl::forward(
}

Tensor TransformerDecoderLayerImpl::activation(const Tensor& input) {
-if (c10::get_if<enumtype::kGELU>(&options.activation())) {
+if (std::holds_alternative<enumtype::kGELU>(options.activation())) {
return F::gelu(input);
-} else if (c10::get_if<enumtype::kReLU>(&options.activation())) {
+} else if (std::holds_alternative<enumtype::kReLU>(options.activation())) {
return F::relu(input);
-} else if (c10::get_if<std::function<Tensor(const Tensor&)>>(
-    &options.activation())) {
+} else if (std::holds_alternative<std::function<Tensor(const Tensor&)>>(
+    options.activation())) {
auto callable_activation =
-    *c10::get_if<std::function<Tensor(const Tensor&)>>(
+    *std::get_if<std::function<Tensor(const Tensor&)>>(
        &options.activation());
return callable_activation(input);
} else {
@@ -25,15 +25,15 @@ void UpsampleImpl::pretty_print(std::ostream& stream) const {

Tensor UpsampleImpl::forward(const Tensor& input) {
F::InterpolateFuncOptions::mode_t mode;
-if (c10::get_if<enumtype::kNearest>(&options.mode())) {
+if (std::holds_alternative<enumtype::kNearest>(options.mode())) {
mode = torch::kNearest;
-} else if (c10::get_if<enumtype::kLinear>(&options.mode())) {
+} else if (std::holds_alternative<enumtype::kLinear>(options.mode())) {
mode = torch::kLinear;
-} else if (c10::get_if<enumtype::kBilinear>(&options.mode())) {
+} else if (std::holds_alternative<enumtype::kBilinear>(options.mode())) {
mode = torch::kBilinear;
-} else if (c10::get_if<enumtype::kBicubic>(&options.mode())) {
+} else if (std::holds_alternative<enumtype::kBicubic>(options.mode())) {
mode = torch::kBicubic;
-} else if (c10::get_if<enumtype::kTrilinear>(&options.mode())) {
+} else if (std::holds_alternative<enumtype::kTrilinear>(options.mode())) {
mode = torch::kTrilinear;
}
@@ -12,7 +12,6 @@
#include <c10/util/DimVector.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
-#include <c10/util/variant.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -173,13 +173,13 @@ bool symbolicShapeAnalysisTestModeEnabled() {
return symbolic_shape_analysis_test_mode;
}

-using SSArgument = c10::variant<ShapeArguments, IValue>;
+using SSArgument = std::variant<ShapeArguments, IValue>;

static std::ostream& operator<<(std::ostream& out, const SSArgument& sa) {
-if (const IValue* iv = c10::get_if<IValue>(&sa)) {
+if (const IValue* iv = std::get_if<IValue>(&sa)) {
out << *iv;
} else {
-out << c10::get<ShapeArguments>(sa);
+out << std::get<ShapeArguments>(sa);
}
return out;
}
@@ -377,11 +377,11 @@ struct SymbolicShapeOpAnalyzer {
SSArgument& argument = inputs_[op_in_index];
Value* graph_in_var = shape_compute_graph_->inputs().at(op_in_index);

-if (IValue* cur_val = c10::get_if<IValue>(&argument)) {
+if (IValue* cur_val = std::get_if<IValue>(&argument)) {
GRAPH_DEBUG("Substituting constant input ", *cur_val);
replaceWithIValue(graph_in_var, *cur_val);
} else {
-auto cur_arg = c10::get<ShapeArguments>(argument);
+auto cur_arg = std::get<ShapeArguments>(argument);
if (cur_arg.has_dim()) {
graph_in_var->setType(ListType::ofInts());
}
@@ -423,7 +423,7 @@ struct SymbolicShapeOpAnalyzer {
"Missing Arg for Shape Graph");
for (const auto index :
c10::irange(shape_compute_graph_->inputs().size())) {
-auto shape_arguments = c10::get_if<ShapeArguments>(&inputs_[index]);
+auto shape_arguments = std::get_if<ShapeArguments>(&inputs_[index]);
if (!shape_arguments || !shape_arguments->has_dim()) {
continue;
}
@@ -1146,10 +1146,10 @@ calculateSymbolicShapesOnOp(

std::vector<SSArgument> ssa_args;
for (auto& arg : inputs) {
-if (const IValue* ival = c10::get_if<IValue>(&arg)) {
+if (const IValue* ival = std::get_if<IValue>(&arg)) {
ssa_args.emplace_back(*ival);
} else {
-const c10::SymbolicShape* ss = c10::get_if<c10::SymbolicShape>(&arg);
+const c10::SymbolicShape* ss = std::get_if<c10::SymbolicShape>(&arg);
ssa_args.emplace_back(ShapeArguments(*ss));
}
}
@@ -1,10 +1,10 @@
#pragma once

-#include <c10/util/variant.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>
#include <unordered_map>
#include <utility>
+#include <variant>

namespace torch {
namespace jit {
@@ -49,7 +49,7 @@ PropagateShapesAndBuildLargeShapeComputeGraph(
TORCH_API bool setSymbolicShapeAnalysisTestMode(bool value);
TORCH_API bool symbolicShapeAnalysisTestModeEnabled();

-using SSAInput = c10::variant<IValue, c10::SymbolicShape>;
+using SSAInput = std::variant<IValue, c10::SymbolicShape>;
TORCH_API c10::optional<std::vector<c10::SymbolicShape>>
calculateSymbolicShapesOnOp(
const FunctionSchema* schema,
@@ -8,7 +8,7 @@
namespace torch {
namespace jit {
namespace {
-using CanonicalArg = c10::variant<CanonicalizedSymbolicShape, IValue>;
+using CanonicalArg = std::variant<CanonicalizedSymbolicShape, IValue>;
using CanonicalArgVec = std::vector<CanonicalArg>;
using CanonicalRet = std::vector<CanonicalizedSymbolicShape>;
using ShapeCacheKey = std::tuple<c10::OperatorName, CanonicalArgVec>;
@@ -20,14 +20,14 @@ CanonicalArgVec cannonicalizeVec(
CanonicalArgVec canonical_args;
canonical_args.reserve(arg_vec.size());
for (auto& arg : arg_vec) {
-if (const IValue* iv = c10::get_if<IValue>(&arg)) {
+if (const IValue* iv = std::get_if<IValue>(&arg)) {
if (deep_copy) {
canonical_args.emplace_back(iv->deepcopy());
} else {
canonical_args.emplace_back(*iv);
}
} else {
-auto& ss = c10::get<at::SymbolicShape>(arg);
+auto& ss = std::get<at::SymbolicShape>(arg);
canonical_args.emplace_back(CanonicalizedSymbolicShape(ss, ss_map));
}
}
@@ -57,7 +57,7 @@ struct ArgumentsHasher {
hash_val = at::hash_combine(std::hash<size_t>{}(arg_vec.size()), hash_val);
for (const CanonicalArg& arg : arg_vec) {
size_t cur_arg = 0;
-if (const IValue* ival = c10::get_if<IValue>(&arg)) {
+if (const IValue* ival = std::get_if<IValue>(&arg)) {
// IValue doesn't hash List (as Python doesn't), so we will do a custom
// list hash
if (ival->isList()) {
@@ -70,7 +70,7 @@ struct ArgumentsHasher {
cur_arg = IValue::hash(ival);
}
} else {
-cur_arg = c10::get<CanonicalizedSymbolicShape>(arg).hash();
+cur_arg = std::get<CanonicalizedSymbolicShape>(arg).hash();
}
hash_val = at::hash_combine(hash_val, cur_arg);
}
@ -5,7 +5,6 @@
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include <c10/util/FbcodeMaps.h>
|
||||
#include <c10/util/variant.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/ir/graph_node_list.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
@ -912,7 +911,7 @@ class TORCH_API ProcessedNode {
|
||||
|
||||
// These should be noexcept, but some Android build is failing
|
||||
// saying the noexcept specification doesn't match the calculated
|
||||
// one. Maybe c10::variant is throwing it off?
|
||||
// one. Maybe std::variant is throwing it off?
|
||||
ProcessedNode(ProcessedNode&&) = default;
|
||||
|
||||
ProcessedNode(const ProcessedNode&) = delete;
|
||||
|
@ -1,4 +1,3 @@
|
||||
#include <c10/util/variant.h>
|
||||
#include <torch/csrc/jit/tensorexpr/kernel.h>
|
||||
|
||||
#include <ATen/ExpandUtils.h>
|
||||
@ -437,9 +436,9 @@ ArgValue TensorExprKernel::toArg(const torch::jit::Value* v) const {
|
||||
}
|
||||
if (vec.empty()) {
|
||||
return BufList(); // Return arbitrarily typed vector
|
||||
} else if (c10::get_if<BufHandle>(&vec[0])) {
|
||||
} else if (std::get_if<BufHandle>(&vec[0])) {
|
||||
return convertVecArgValue<BufHandle>(vec);
|
||||
} else if (c10::get_if<int64_t>(&vec[0])) {
|
||||
} else if (std::get_if<int64_t>(&vec[0])) {
|
||||
return convertVecArgValue<int64_t>(vec);
|
||||
}
|
||||
throw unsupported_dtype();
|
||||
@ -608,7 +607,7 @@ Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
|
||||
argInputs.emplace_back(toArg(inp));
|
||||
}
|
||||
// handle optional bias
|
||||
if (c10::get_if<ArgNone>(&argInputs[2])) {
|
||||
if (std::get_if<ArgNone>(&argInputs[2])) {
|
||||
Dtype dtype = outputType ? Dtype(*outputType) : kFloat;
|
||||
std::vector<ExprHandle> biasShape;
|
||||
biasShape.push_back(outputShape[1]);
|
||||
|
@ -1,6 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/variant.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
|
||||
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
|
||||
@ -59,23 +58,23 @@ std::vector<ExprHandle> computeIndicesToBroadcast(
|
||||
const std::vector<ExprHandle>& inputSizes);
|
||||
|
||||
inline std::string getArgValueName(const ArgValue& a) {
|
||||
if (c10::get_if<tensorexpr::BufHandle>(&a)) {
|
||||
if (std::holds_alternative<tensorexpr::BufHandle>(a)) {
|
||||
return "BufHandle";
|
||||
} else if (c10::get_if<tensorexpr::VarHandle>(&a)) {
|
||||
} else if (std::holds_alternative<tensorexpr::VarHandle>(a)) {
|
||||
return "VarHandle";
|
||||
} else if (c10::get_if<double>(&a)) {
|
||||
} else if (std::holds_alternative<double>(a)) {
|
||||
return "double";
|
||||
} else if (c10::get_if<int64_t>(&a)) {
|
||||
} else if (std::holds_alternative<int64_t>(a)) {
|
||||
return "int64_t";
|
||||
} else if (c10::get_if<bool>(&a)) {
|
||||
} else if (std::holds_alternative<bool>(a)) {
|
||||
return "bool";
|
||||
} else if (c10::get_if<BufList>(&a)) {
|
||||
} else if (std::holds_alternative<BufList>(a)) {
|
||||
return "BufList";
|
||||
} else if (c10::get_if<DoubleList>(&a)) {
|
||||
} else if (std::holds_alternative<DoubleList>(a)) {
|
||||
return "DoubleList";
|
||||
} else if (c10::get_if<IntList>(&a)) {
|
||||
} else if (std::holds_alternative<IntList>(a)) {
|
||||
return "IntList";
|
||||
} else if (c10::get_if<ArgNone>(&a)) {
|
||||
} else if (std::holds_alternative<ArgNone>(a)) {
|
||||
return "None";
|
||||
} else {
|
||||
throw std::runtime_error("ArgValue type not handled in string conversion");
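As an illustration (not part of this diff): getArgValueName above now uses std::holds_alternative, which answers "is this the active alternative?" without taking the variant's address or producing a pointer. A small sketch of that pattern, with an invented variant rather than the real ArgValue, might be:

    #include <cstdint>
    #include <string>
    #include <variant>

    using Arg = std::variant<std::monostate, int64_t, double, bool>;

    // holds_alternative only reports whether the type is active; get_if would
    // additionally hand back a pointer, which this function never needs.
    std::string argTypeName(const Arg& a) {
      if (std::holds_alternative<std::monostate>(a)) return "None";
      if (std::holds_alternative<int64_t>(a)) return "int64_t";
      if (std::holds_alternative<double>(a)) return "double";
      return "bool";
    }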
@ -86,7 +85,7 @@ template <class T>
std::vector<T> convertVecArgValue(const std::vector<ArgValue>& v) {
std::vector<T> res;
for (auto& x : v) {
auto val = c10::get_if<T>(&x);
auto val = std::get_if<T>(&x);
if (val) {
res.push_back(*val);
} else {

@ -517,11 +517,11 @@ int nnc_lowerings_lazy_registration() {
at::Device device) {
bool noMin = false;
bool noMax = false;
if (c10::get_if<ArgNone>(&inputs[1])) {
if (std::get_if<ArgNone>(&inputs[1])) {
noMin = true;
}

if (c10::get_if<ArgNone>(&inputs[2])) {
if (std::get_if<ArgNone>(&inputs[2])) {
noMax = true;
}

@ -583,7 +583,7 @@ int nnc_lowerings_lazy_registration() {
const c10::optional<ScalarType>& outputType,
at::Device device) {
// check if the activation is quantized
const BufHandle& x = c10::get<BufHandle>(inputs[0]);
const BufHandle& x = std::get<BufHandle>(inputs[0]);
if (x.node()->qscale()) {
return computeQuantizedSigmoidExternalCall(
inputs, outputShape, outputStrides, outputType, device);
@ -675,7 +675,7 @@ int nnc_lowerings_lazy_registration() {
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto A = c10::get<BufHandle>(inputs[0]);
auto A = std::get<BufHandle>(inputs[0]);
if (A.node()->qscale()) {
return computeQuantizedRelu(
inputs, outputShape, outputStrides, outputType, device);
@ -741,7 +741,7 @@ int nnc_lowerings_lazy_registration() {
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
const auto& kApproximate = c10::get<std::string>(inputs[1]);
const auto& kApproximate = std::get<std::string>(inputs[1]);
std::vector<ArgValue> operands = {inputs.front()};
if (at::native::get_gelutype_enum(kApproximate) ==
at::native::GeluType::Tanh) {
@ -987,7 +987,7 @@ int nnc_lowerings_lazy_registration() {
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
const BufHandle& rhs = c10::get<BufHandle>(inputs[1]);
const BufHandle& rhs = std::get<BufHandle>(inputs[1]);
auto dtype = rhs.dtype();
return computeOneOperand(
"aten_type_as",
@ -1708,7 +1708,7 @@ int nnc_lowerings_lazy_registration() {
// outputShape,
// [&](const std::vector<VarHandle>& axes) {
// int64_t dim =
// at::maybe_wrap_dim(c10::get<int64_t>(inputs[1]),
// at::maybe_wrap_dim(std::get<int64_t>(inputs[1]),
// axes.size());
// ExprHandle start = constant(inputs[2]);
// ExprHandle stride = constant(inputs[4]);
@ -1730,7 +1730,7 @@ int nnc_lowerings_lazy_registration() {
outputShape,
outputStrides,
[&](const std::vector<VarHandle>& axes) {
int64_t dim = c10::get<int64_t>(inputs[1]);
int64_t dim = std::get<int64_t>(inputs[1]);
if (dim < 0) {
if (axes.empty()) {
throw malformed_input("axes are zero handling unsqueeze");
@ -1749,7 +1749,7 @@ int nnc_lowerings_lazy_registration() {
}
}

return broadcast(c10::get<BufHandle>(inputs[0]), indices);
return broadcast(std::get<BufHandle>(inputs[0]), indices);
});
});
RegisterNNCLoweringsFunction aten_t(
@ -1776,7 +1776,7 @@ int nnc_lowerings_lazy_registration() {
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto A = c10::get<BufHandle>(inputs[0]);
auto A = std::get<BufHandle>(inputs[0]);
// Trivial case of 0-dim tensors: just a copy of the input
if (A.ndim() == 0) {
auto tensor = Compute(
@ -1793,7 +1793,7 @@ int nnc_lowerings_lazy_registration() {
}
return tensor;
}
auto permute_dims = c10::get<IntList>(inputs[1]);
auto permute_dims = std::get<IntList>(inputs[1]);
auto tensor = Compute(
"aten_permute",
outputShape,

@ -2,7 +2,6 @@
// IR.
#pragma once

#include <c10/util/variant.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
@ -13,11 +12,11 @@ namespace torch {
namespace jit {
namespace tensorexpr {

using ArgNone = c10::monostate;
using ArgNone = std::monostate;
using BufList = std::vector<tensorexpr::BufHandle>;
using DoubleList = std::vector<double>;
using IntList = std::vector<int64_t>;
using ArgValue = c10::variant<
using ArgValue = std::variant<
tensorexpr::BufHandle,
tensorexpr::VarHandle,
double,

@ -246,18 +246,18 @@ Tensor conv2d_depthwise(
}

static std::vector<int64_t> _pair_int(ArgValue v) {
if (auto t = c10::get_if<IntList>(&v)) {
if (auto t = std::get_if<IntList>(&v)) {
return {(*t)[0], (*t)[1]};
}
auto i = c10::get<int64_t>(v);
auto i = std::get<int64_t>(v);
return {i, i};
}

static std::vector<int64_t> _single_int_list(ArgValue v) {
if (auto t = c10::get_if<IntList>(&v)) {
if (auto t = std::get_if<IntList>(&v)) {
return {(*t)[0]};
}
auto i = c10::get<int64_t>(v);
auto i = std::get<int64_t>(v);
return {i};
}

@ -361,15 +361,15 @@ Tensor computeConv2d(
}

BufHandle ResultBuf("conv", outputShape, dtype);
const BufHandle& inp = c10::get<BufHandle>(inputs[0]);
const BufHandle& w = c10::get<BufHandle>(inputs[1]);
const BufHandle& b = c10::get<BufHandle>(inputs[2]);
const BufHandle& inp = std::get<BufHandle>(inputs[0]);
const BufHandle& w = std::get<BufHandle>(inputs[1]);
const BufHandle& b = std::get<BufHandle>(inputs[2]);

auto strides = _pair_int(inputs[3]);
auto padding = _pair_int(inputs[4]);
auto dilation = _pair_int(inputs[5]);

int groups = c10::get<int64_t>(inputs[6]);
int groups = std::get<int64_t>(inputs[6]);

auto inpInfo = getTensorInfo(inp);
auto wInfo = getTensorInfo(w);
@ -409,15 +409,15 @@ Tensor computeConv1d(
}

BufHandle ResultBuf("conv", outputShape, dtype);
const BufHandle& inp = c10::get<BufHandle>(inputs[0]);
const BufHandle& w = c10::get<BufHandle>(inputs[1]);
const BufHandle& b = c10::get<BufHandle>(inputs[2]);
const BufHandle& inp = std::get<BufHandle>(inputs[0]);
const BufHandle& w = std::get<BufHandle>(inputs[1]);
const BufHandle& b = std::get<BufHandle>(inputs[2]);

auto strides = _single_int_list(inputs[3]);
auto padding = _single_int_list(inputs[4]);
auto dilation = _single_int_list(inputs[5]);

int groups = c10::get<int64_t>(inputs[6]);
int groups = std::get<int64_t>(inputs[6]);

auto inpInfo = getTensorInfo(inp);
auto wInfo = getTensorInfo(w);
@ -443,8 +443,8 @@ Tensor computePrepackedConv2dClampRun(
}

BufHandle ResultBuf("prepacked_conv2d_clamp_run", outputShape, dtype);
const BufHandle& inp = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const BufHandle& inp = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
StmtPtr s = ExternalCall::make(
ResultBuf, "nnc_prepacked_conv2d_clamp_run", {inp, prepacked}, {});
return Tensor(ResultBuf.node(), s);
@ -462,8 +462,8 @@ Tensor computePrepackedLinearClampRun(
}

BufHandle ResultBuf("prepacked_linear_clamp_run", outputShape, dtype);
const BufHandle& inp = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const BufHandle& inp = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
StmtPtr s = ExternalCall::make(
ResultBuf, "nnc_prepacked_linear_clamp_run", {inp, prepacked}, {});
return Tensor(ResultBuf.node(), s);
@ -482,8 +482,8 @@ Tensor computeMkldnnPrepackedConvRun(

BufHandle ResultBuf(
"mkldnn_prepacked_conv_run", outputShape, outputStrides, dtype);
const BufHandle& inp = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const BufHandle& inp = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
StmtPtr s = ExternalCall::make(
ResultBuf, "nnc_mkldnn_prepacked_conv_run", {inp, prepacked}, {});
return Tensor(ResultBuf.node(), s);

@ -16,8 +16,8 @@ Tensor computeMatmul(
dtype = Dtype(*outputType);
}
BufHandle ResultBuf("matmul", outputShape, dtype);
const BufHandle a = c10::get<BufHandle>(inputs[0]);
const BufHandle b = c10::get<BufHandle>(inputs[1]);
const BufHandle a = std::get<BufHandle>(inputs[0]);
const BufHandle b = std::get<BufHandle>(inputs[1]);

auto size_a = a.dims();
auto size_b = b.dims();
@ -68,11 +68,11 @@ Tensor computeAddMM(
ExternalCall::make(
ResultBuf,
"nnc_aten_addmm",
{c10::get<BufHandle>(inputs[0]),
c10::get<BufHandle>(inputs[1]),
c10::get<BufHandle>(inputs[2])},
{c10::get<int64_t>(inputs[3]),
c10::get<int64_t>(
{std::get<BufHandle>(inputs[0]),
std::get<BufHandle>(inputs[1]),
std::get<BufHandle>(inputs[2])},
{std::get<int64_t>(inputs[3]),
std::get<int64_t>(
inputs[4])})); // TODO: handle other dtypes of alpha and beta
}

@ -249,7 +249,7 @@ std::vector<ExprHandle> broadcastShapes(
}

std::vector<ExprHandle> valueShape(const ArgValue& v) {
if (auto b = c10::get_if<tensorexpr::BufHandle>(&v)) {
if (auto b = std::get_if<tensorexpr::BufHandle>(&v)) {
return b->dims();
}
return {};
@ -258,14 +258,14 @@ std::vector<ExprHandle> valueShape(const ArgValue& v) {
ExprHandle tensorOrConstant(
const ArgValue& v,
const std::vector<ExprHandle>& axes) {
if (auto b = c10::get_if<BufHandle>(&v)) {
if (auto b = std::get_if<BufHandle>(&v)) {
return broadcast(*b, axes);
}
return constant(v);
}

ExprHandle scalarOrConstant(const ArgValue& v) {
if (auto vh = c10::get_if<VarHandle>(&v)) {
if (auto vh = std::get_if<VarHandle>(&v)) {
return *vh;
}
return constant(v);
@ -276,15 +276,15 @@ ExprHandle broadcast(BufHandle b, const std::vector<ExprHandle>& axes) {
}

ExprHandle constant(const ArgValue& v) {
if (auto s = c10::get_if<tensorexpr::VarHandle>(&v)) {
if (auto s = std::get_if<tensorexpr::VarHandle>(&v)) {
return *s;
} else if (auto d = c10::get_if<double>(&v)) {
} else if (auto d = std::get_if<double>(&v)) {
return DoubleImm::make(*d);
} else if (auto i = c10::get_if<int64_t>(&v)) {
} else if (auto i = std::get_if<int64_t>(&v)) {
return LongImm::make(*i);
} else if (auto b = c10::get_if<bool>(&v)) {
} else if (auto b = std::get_if<bool>(&v)) {
return BoolImm::make(*b);
} else if (c10::get_if<ArgNone>(&v)) {
} else if (std::get_if<ArgNone>(&v)) {
// This is just a placeholder so we don't throw. None-handling
// is operator-specific and should be handled properly in
// the operator-specific lowering code.
@ -327,10 +327,10 @@ Tensor computeChunk(
"prim_constantchunk",
outputShape,
[inputs](const std::vector<VarHandle>& axes) {
const auto& b = c10::get<BufHandle>(inputs[0]);
int64_t chunkIdx = c10::get<int64_t>(inputs[1]);
int64_t dim = c10::get<int64_t>(inputs[2]);
int64_t chunks = c10::get<int64_t>(inputs[3]);
const auto& b = std::get<BufHandle>(inputs[0]);
int64_t chunkIdx = std::get<int64_t>(inputs[1]);
int64_t dim = std::get<int64_t>(inputs[2]);
int64_t chunks = std::get<int64_t>(inputs[3]);
std::vector<ExprHandle> indices(axes.begin(), axes.end());

auto norm_dim = normalizeAndCheckIndex(dim, indices.size());
@ -357,7 +357,7 @@ Tensor computeTranspose(
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto A = c10::get<BufHandle>(inputs[0]);
auto A = std::get<BufHandle>(inputs[0]);
// Trivial case of 0-dim and 1-dim tensors: transpose is just a copy
if (A.ndim() <= 1) {
return Compute(
@ -369,8 +369,8 @@ Tensor computeTranspose(
});
}
// Usual case where transpose actually swaps dimensions
auto start_dim = at::maybe_wrap_dim(c10::get<int64_t>(inputs[1]), A.ndim());
auto to_dim = at::maybe_wrap_dim(c10::get<int64_t>(inputs[2]), A.ndim());
auto start_dim = at::maybe_wrap_dim(std::get<int64_t>(inputs[1]), A.ndim());
auto to_dim = at::maybe_wrap_dim(std::get<int64_t>(inputs[2]), A.ndim());
return Compute(
"aten_transpose", outputShape, [&](std::vector<VarHandle> axes) {
std::swap(axes[start_dim], axes[to_dim]);
@ -384,7 +384,7 @@ Tensor computeExpand(
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto A = c10::get<BufHandle>(inputs[0]);
auto A = std::get<BufHandle>(inputs[0]);
return Compute(
"aten_expand", outputShape, [&](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
@ -398,7 +398,7 @@ Tensor computeReshape(
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto A = c10::get<BufHandle>(inputs[0]);
auto A = std::get<BufHandle>(inputs[0]);
if (A.ndim() == 0) {
return Compute(
"aten_view", outputShape, [&](const std::vector<VarHandle>& axes) {
@ -513,7 +513,7 @@ static Tensor computeCatWoConditionals(
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides) {
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto input_list = c10::get<BufList>(inputs[0]);
auto input_list = std::get<BufList>(inputs[0]);
auto arg_dim = inputs[1];
auto cat_info = processCatList(input_list);
ScalarType high_type = cat_info.first;
@ -547,7 +547,7 @@ static Tensor computeCatWoConditionals(
output_buf, alloc<tensorexpr::Block>(std::vector<StmtPtr>({})));
}

int64_t concat_dim = c10::get<int64_t>(arg_dim);
int64_t concat_dim = std::get<int64_t>(arg_dim);
auto norm_concat_dim = normalizeAndCheckIndex(concat_dim, outputShape.size());

auto loop_order_fn = [&](const BufPtr& buf_) {
@ -628,7 +628,7 @@ Tensor computeCat(
return computeCatWoConditionals(inputs, outputShape, outputStrides);
}
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto inputList = c10::get<BufList>(inputs[0]);
auto inputList = std::get<BufList>(inputs[0]);
auto argDim = inputs[1];
auto catInfo = processCatList(inputList);
ScalarType highType = catInfo.first;
@ -642,7 +642,7 @@ Tensor computeCat(
return ExprHandle(0);
}

int64_t dim_ = c10::get<int64_t>(argDim);
int64_t dim_ = std::get<int64_t>(argDim);
auto dim = normalizeAndCheckIndex(dim_, axes.size());
// Promote input types.
// Note that we need to consider all inputs, including empty - they
@ -693,8 +693,8 @@ Tensor computeEmbedding(
}

BufHandle ResultBuf("emb", outputShape, dtype);
const BufHandle& w = c10::get<BufHandle>(inputs[0]);
const BufHandle& indices = c10::get<BufHandle>(inputs[1]);
const BufHandle& w = std::get<BufHandle>(inputs[0]);
const BufHandle& indices = std::get<BufHandle>(inputs[1]);

StmtPtr s =
ExternalCall::make(ResultBuf, "nnc_aten_embedding", {w, indices}, {});

@ -14,11 +14,11 @@ Tensor computeBatchNorm(
bool hasWeight = true;
bool hasBias = true;

if (c10::get_if<ArgNone>(&inputs[1])) {
if (std::holds_alternative<ArgNone>(inputs[1])) {
hasWeight = false;
}

if (c10::get_if<ArgNone>(&inputs[2])) {
if (std::holds_alternative<ArgNone>(inputs[2])) {
hasBias = false;
}
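As an illustration (not part of this diff): ArgNone is now std::monostate, and the hunk above detects an omitted weight/bias argument with std::holds_alternative. A minimal sketch of that "monostate as the None state" idiom, with made-up names rather than the real ArgValue types, could be:

    #include <cstdint>
    #include <optional>
    #include <variant>

    using None = std::monostate;
    using MaybeScale = std::variant<None, double, int64_t>;

    // A default-constructed variant holds its first alternative, so putting
    // std::monostate first gives the variant a natural "not provided" state.
    std::optional<double> asScale(const MaybeScale& v) {
      if (std::holds_alternative<None>(v)) {
        return std::nullopt;  // argument was omitted
      }
      if (const double* d = std::get_if<double>(&v)) {
        return *d;
      }
      return static_cast<double>(std::get<int64_t>(v));
    }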

@ -11,10 +11,10 @@ namespace jit {
namespace tensorexpr {
namespace {
std::vector<int64_t> _pair_int(ArgValue v) {
if (auto t = c10::get_if<IntList>(&v)) {
if (auto t = std::get_if<IntList>(&v)) {
return {(*t)[0], (*t)[1]};
}
auto i = c10::get<int64_t>(v);
auto i = std::get<int64_t>(v);
return {i, i};
}
} // namespace
@ -161,7 +161,7 @@ Tensor computeQuantizePerTensor(
return Dtype(ScalarType::QUInt8);
}
throw malformed_input("Expected quantized dtype");
}(c10::get<int64_t>(inputs[3]));
}(std::get<int64_t>(inputs[3]));

ExprHandle e =
quant(tensorOrConstant(inputs[0], indices), dtype, qscale, qzero);
@ -183,14 +183,14 @@ Tensor computeQuantizedAdd(
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device) {
const BufHandle& QA = c10::get<BufHandle>(inputs[0]);
const BufHandle& QB = c10::get<BufHandle>(inputs[1]);
const BufHandle& QA = std::get<BufHandle>(inputs[0]);
const BufHandle& QB = std::get<BufHandle>(inputs[1]);
auto qa_scale = ExprHandle(QA.node()->qscale());
auto qa_zero = ExprHandle(QA.node()->qzero());
auto qb_scale = ExprHandle(QB.node()->qscale());
auto qb_zero = ExprHandle(QB.node()->qzero());
ExprHandle out_qscale = DoubleImm::make(c10::get<double>(inputs[2]));
ExprHandle out_qzero = LongImm::make(c10::get<int64_t>(inputs[3]));
ExprHandle out_qscale = DoubleImm::make(std::get<double>(inputs[2]));
ExprHandle out_qzero = LongImm::make(std::get<int64_t>(inputs[3]));
Dtype dequant_dtype = kFloat;
Dtype out_dtype = outputType ? Dtype(*outputType) : QA.dtype();
std::vector<VarPtr> vars;
@ -227,10 +227,10 @@ Tensor computeQuantizePerTensorExternalCall(
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
at::Device) {
const BufHandle& x = c10::get<BufHandle>(inputs[0]);
const auto qscale = c10::get<double>(inputs[1]);
const auto qzero = c10::get<int64_t>(inputs[2]);
const auto qdtype = c10::get<int64_t>(inputs[3]);
const BufHandle& x = std::get<BufHandle>(inputs[0]);
const auto qscale = std::get<double>(inputs[1]);
const auto qzero = std::get<int64_t>(inputs[2]);
const auto qdtype = std::get<int64_t>(inputs[3]);

const auto dtype = [](auto qdtype) {
if (static_cast<int64_t>(ScalarType::QInt8) == qdtype) {
@ -264,7 +264,7 @@ Tensor computeDequantizeExternalCall(
dtype = Dtype(*outputType);
}

const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
const BufHandle& qx = std::get<BufHandle>(inputs[0]);
const int64_t qdtype = (int64_t)immQDType(qx);

BufHandle ResultBuf("dequantize", outputShape, dtype);
@ -290,12 +290,12 @@ Tensor computeQuantizedConv2dPrepack(
}

BufHandle ResultBuf("quantized_conv2d_prepack", outputShape, dtype);
const BufHandle& qw = c10::get<BufHandle>(inputs[0]);
const BufHandle& b = c10::get<BufHandle>(inputs[1]);
const BufHandle& qw = std::get<BufHandle>(inputs[0]);
const BufHandle& b = std::get<BufHandle>(inputs[1]);
auto strides = _pair_int(inputs[2]);
auto padding = _pair_int(inputs[3]);
auto dilation = _pair_int(inputs[4]);
int groups = c10::get<int64_t>(inputs[5]);
int groups = std::get<int64_t>(inputs[5]);
TORCH_INTERNAL_ASSERT(
qw.node()->qscale(),
buildErrorMessage(
@ -335,10 +335,10 @@ Tensor computeQuantizedConv1d(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const BufHandle& qx = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qx);
auto ResultBuf = makeQBufHandleChannelsLast(
@ -367,10 +367,10 @@ Tensor computeQuantizedConv2d(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const BufHandle& qx = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qx);
auto ResultBuf = makeQBufHandleChannelsLast(
@ -399,10 +399,10 @@ Tensor computeQuantizedConv2dRelu(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const BufHandle& qx = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qx);
auto ResultBuf = makeQBufHandleChannelsLast(
@ -431,10 +431,10 @@ Tensor computeQuantizedLinear(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const BufHandle& qx = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qx);
auto ResultBuf = makeQBufHandleContiguous(
@ -463,10 +463,10 @@ Tensor computeQuantizedLinearRelu(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const BufHandle& qx = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qx);
auto ResultBuf = makeQBufHandleContiguous(
@ -495,10 +495,10 @@ Tensor computeQuantizedAddExternalCall(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qa = c10::get<BufHandle>(inputs[0]);
const BufHandle& qb = c10::get<BufHandle>(inputs[1]);
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const BufHandle& qa = std::get<BufHandle>(inputs[0]);
const BufHandle& qb = std::get<BufHandle>(inputs[1]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qa);
const bool isQAChannelsLast = isChannelsLast(qa);
@ -539,10 +539,10 @@ Tensor computeQuantizedMul(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qa = c10::get<BufHandle>(inputs[0]);
const BufHandle& qb = c10::get<BufHandle>(inputs[1]);
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const BufHandle& qa = std::get<BufHandle>(inputs[0]);
const BufHandle& qb = std::get<BufHandle>(inputs[1]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qa);
auto ResultBuf = makeQBufHandleContiguous(
@ -570,8 +570,8 @@ Tensor computeQuantizedMulScalar(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qa = c10::get<BufHandle>(inputs[0]);
const auto scalar = c10::get<double>(inputs[1]);
const BufHandle& qa = std::get<BufHandle>(inputs[0]);
const auto scalar = std::get<double>(inputs[1]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qa);
double scale1 = immQScale(qa);
@ -597,7 +597,7 @@ Tensor computeQuantizedRelu(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qa = c10::get<BufHandle>(inputs[0]);
const BufHandle& qa = std::get<BufHandle>(inputs[0]);
const auto out_qdtype = immQDType(qa);
const bool isQAChannelsLast = isChannelsLast(qa);
auto ResultBuf = isQAChannelsLast ? makeQBufHandleChannelsLast(
@ -629,12 +629,12 @@ Tensor computeQuantizedCat(
// NOLINTNEXTLINE
at::Device device) {
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto inputList = c10::get<BufList>(inputs[0]);
auto argDim = c10::get<int64_t>(inputs[1]);
auto inputList = std::get<BufList>(inputs[0]);
auto argDim = std::get<int64_t>(inputs[1]);
auto n = inputList.size();
// TODO: handle optional out_qscale, out_qzero
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);

std::vector<BufHandle> args;
std::vector<ExprHandle> extra_args;
@ -669,7 +669,7 @@ Tensor computeDequantize(
if (outputType) {
dtype = Dtype(*outputType);
}
auto qx = c10::get<BufHandle>(inputs[0]);
auto qx = std::get<BufHandle>(inputs[0]);
TORCH_INTERNAL_ASSERT(
qx.node()->qscale(),
buildErrorMessage("Missing quantized scale for dequantize"));
@ -697,7 +697,7 @@ Tensor computeUpsampleNearest2d(
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device) {
auto A = c10::get<BufHandle>(inputs[0]);
auto A = std::get<BufHandle>(inputs[0]);
const auto& output_height = outputShape[2];
const auto& output_width = outputShape[3];
auto input_height = ExprHandle(A.dim(2));
@ -750,18 +750,18 @@ Tensor computeUpsampleNearest2dExternalCall(
}
int64_t output_size_h = -1;
int64_t output_size_w = -1;
if (auto output_sizes = c10::get_if<IntList>(&inputs[1])) {
if (auto output_sizes = std::get_if<IntList>(&inputs[1])) {
output_size_h = (*output_sizes)[0];
output_size_w = (*output_sizes)[1];
}

double scale_factor_h = -1.f;
double scale_factor_w = -1.f;
if (auto scale_factors = c10::get_if<DoubleList>(&inputs[2])) {
if (auto scale_factors = std::get_if<DoubleList>(&inputs[2])) {
scale_factor_h = (*scale_factors)[0];
scale_factor_w = (*scale_factors)[1];
}
const BufHandle& x = c10::get<BufHandle>(inputs[0]);
const BufHandle& x = std::get<BufHandle>(inputs[0]);
double qx_qscale = -1.f;
int64_t qx_qzero = -1l;
int64_t qx_qdtype = -1l;
@ -804,7 +804,7 @@ Tensor computeQuantizedSigmoidExternalCall(
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
at::Device) {
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
const BufHandle& qx = std::get<BufHandle>(inputs[0]);

const auto out_qdtype = immQDType(qx);
const double out_qscale = 1.0f / 256.0f;

@ -32,7 +32,7 @@ Tensor computeSum(

size_t rank = sizes.size();
if (inputs.size() > 2) {
if (auto emptyAxes = c10::get_if<BufList>(&inputs[1])) {
if (auto emptyAxes = std::get_if<BufList>(&inputs[1])) {
// If dim-array is an empty list, it will appear as BufList instead of
// IntList, and hence we need a special handling for it.
// In that case, we need to sum over all axes.
@ -40,7 +40,7 @@ Tensor computeSum(
axes.resize(rank);
std::iota(axes.begin(), axes.end(), 0);
} else if (rank > 0) {
auto nodeAxes = c10::get<IntList>(inputs[1]);
auto nodeAxes = std::get<IntList>(inputs[1]);
// Canonicalize axes: wrap around, sort and make unique.
for (auto axis : nodeAxes) {
axes.push_back(at::maybe_wrap_dim(axis, rank));
@ -48,7 +48,7 @@ Tensor computeSum(
std::sort(axes.begin(), axes.end());
axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
}
keepdim = c10::get<bool>(inputs[2]);
keepdim = std::get<bool>(inputs[2]);
} else {
axes.resize(rank);
std::iota(axes.begin(), axes.end(), 0);
@ -116,13 +116,13 @@ Tensor computeMean(
}
bool keepdim = false;
BufHandle ResultBuf("mean", outputShape, dtype);
BufHandle InputBuf = c10::get<BufHandle>(inputs[0]);
BufHandle InputBuf = std::get<BufHandle>(inputs[0]);
std::vector<ExprHandle> extra_args;
if (inputs.size() > 2) {
keepdim = c10::get<bool>(inputs[2]);
keepdim = std::get<bool>(inputs[2]);
}

if (auto mean_dims = c10::get_if<IntList>(&inputs[1])) {
if (auto mean_dims = std::get_if<IntList>(&inputs[1])) {
extra_args = c10::fmap<ExprHandle>(*mean_dims);
} else {
// When dims argument is not specified, reduce over all dimensions
@ -147,10 +147,10 @@ Tensor computeMax(
dtype = Dtype(*outputType);
}
BufHandle ResultBuf("max", outputShape, dtype);
BufHandle InputBuf = c10::get<BufHandle>(inputs[0]);
BufHandle InputBuf = std::get<BufHandle>(inputs[0]);
std::vector<ExprHandle> max_dims_expr;
auto max_dim = c10::get<int64_t>(inputs[1]);
auto keep_dim = c10::get<bool>(inputs[2]);
auto max_dim = std::get<int64_t>(inputs[1]);
auto keep_dim = std::get<bool>(inputs[2]);
return Tensor(
ResultBuf.node(),
ExternalCall::make(
@ -172,13 +172,13 @@ Tensor computeAdaptiveAvgPool2d(
}
BufHandle ResultBuf("adaptive_avgpool2d", outputShape, dtype);
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto out_size_param = c10::get<IntList>(inputs[1]);
auto out_size_param = std::get<IntList>(inputs[1]);
return Tensor(
ResultBuf.node(),
ExternalCall::make(
ResultBuf,
"nnc_aten_adaptive_avg_pool2d",
{c10::get<BufHandle>(inputs[0])},
{std::get<BufHandle>(inputs[0])},
c10::fmap<ExprHandle>(out_size_param)));
}

@ -44,10 +44,10 @@ Tensor computeSoftmax(

// We do not handle None for dims (input 1) because that is supposed to
// be deprecated.
TORCH_INTERNAL_ASSERT(c10::get_if<int64_t>(&inputs[1]));
TORCH_INTERNAL_ASSERT(std::get_if<int64_t>(&inputs[1]));
int64_t rank = valueShape(inputs[0]).size();
size_t softmax_dim =
normalizeAndCheckIndex(c10::get<int64_t>(inputs[1]), rank);
normalizeAndCheckIndex(std::get<int64_t>(inputs[1]), rank);
std::vector<ExprHandle> non_softmax_dims;
for (size_t i = 0; i < outputShape.size(); ++i) {
if (i != softmax_dim) {
@ -93,10 +93,10 @@ Tensor computeSoftmax(
return new_indices;
};

auto inp_buf = c10::get<BufHandle>(inputs[0]);
auto inp_buf = std::get<BufHandle>(inputs[0]);

auto dtype = inp_buf.dtype();
if (auto d = c10::get_if<int64_t>(&inputs[2])) {
if (auto d = std::get_if<int64_t>(&inputs[2])) {
dtype = ToDtype(static_cast<ScalarType>(*d));
}

@ -731,25 +731,25 @@ void initTensorExprBindings(PyObject* module) {
}))
.def(
"as_buf",
[](const ArgValue& self) { return c10::get<BufHandle>(self); })
[](const ArgValue& self) { return std::get<BufHandle>(self); })
.def(
"as_var",
[](const ArgValue& self) { return c10::get<VarHandle>(self); })
[](const ArgValue& self) { return std::get<VarHandle>(self); })
.def(
"as_float",
[](const ArgValue& self) { return c10::get<double>(self); })
[](const ArgValue& self) { return std::get<double>(self); })
.def(
"as_int",
[](const ArgValue& self) { return c10::get<int64_t>(self); })
.def("as_bool", [](const ArgValue& self) { return c10::get<bool>(self); })
[](const ArgValue& self) { return std::get<int64_t>(self); })
.def("as_bool", [](const ArgValue& self) { return std::get<bool>(self); })
.def(
"as_none",
[](const ArgValue& self) { return c10::get<ArgNone>(self); })
[](const ArgValue& self) { return std::get<ArgNone>(self); })
.def(
"as_buflist",
[](const ArgValue& self) { return c10::get<BufList>(self); })
[](const ArgValue& self) { return std::get<BufList>(self); })
.def("as_intlist", [](const ArgValue& self) {
return c10::get<IntList>(self);
return std::get<IntList>(self);
});

py::class_<c10::ScalarType>(te, "ScalarType");