[2/N] Move c10::variant to std::variant (#109723)

This PR moves most of c10::variant calls to std::variant.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109723
Approved by: https://github.com/ezyang
This commit is contained in:
cyy
2023-09-24 02:47:43 +00:00
committed by PyTorch MergeBot
parent c13177f2cb
commit dee100945e
49 changed files with 337 additions and 341 deletions

View File

@ -3,7 +3,6 @@
#if AT_USE_JITERATOR()
#include <c10/util/variant.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/jit_utils.h>
@ -11,6 +10,7 @@
#include <ATen/native/cuda/JitLoops.cuh>
#include <string>
#include <variant>
#include <vector>
namespace at::native {
@ -93,7 +93,7 @@ static std::unique_ptr<OffsetCalculator<N>> make_unique_offset_calculator(
template <bool IS_INPUT>
struct OffsetCalculatorVariant {
#define DEFINE_CASE(index) std::unique_ptr<OffsetCalculator<index>>
using OffsetCalculatorTypes = c10::variant<
using OffsetCalculatorTypes = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
@ -113,7 +113,7 @@ struct OffsetCalculatorVariant {
}
void* data_ptr() {
return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
}
private:
@ -123,7 +123,7 @@ struct OffsetCalculatorVariant {
struct ArrayVariant {
// works for up to 8 input + 8 outputs
#define DEFINE_CASE(index) at::detail::Array<char*, index>, at::detail::Array<char*, index+8>
using ArrayTypes = c10::variant<
using ArrayTypes = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
@ -142,7 +142,7 @@ struct ArrayVariant {
TORCH_CHECK(false, "ArrayVariant is not implemented for ntensors = ", ntensors);
}
c10::visit([&](auto& a) {
std::visit([&](auto& a) {
for (auto i = 0; i < ntensors; ++i) {
a[i] = (char*)iter.data_ptr(i);
}
@ -150,7 +150,7 @@ struct ArrayVariant {
}
void* data_ptr() {
return c10::visit([](auto & a){ return static_cast<void*>(&a); }, array);
return std::visit([](auto & a){ return static_cast<void*>(&a); }, array);
}
private:
@ -159,7 +159,7 @@ private:
struct TrivialOffsetCalculatorVariant {
#define DEFINE_CASE(index) TrivialOffsetCalculator<index>
using TrivialOffsetCalculatorTypes = c10::variant<
using TrivialOffsetCalculatorTypes = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
@ -178,7 +178,7 @@ struct TrivialOffsetCalculatorVariant {
}
void* data_ptr() {
return c10::visit([](auto & v){ return static_cast<void*>(&v); }, v);
return std::visit([](auto & v){ return static_cast<void*>(&v); }, v);
}
private:
@ -187,7 +187,7 @@ private:
struct LoadWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::LoadWithCast<index>>
using LoadWithCastPtr = c10::variant<
using LoadWithCastPtr = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
@ -207,7 +207,7 @@ struct LoadWithCastVariant {
}
void* data_ptr() {
return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
}
private:
@ -216,7 +216,7 @@ private:
struct StoreWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::StoreWithCast<index>>
using StoreWithCastPtr = c10::variant<
using StoreWithCastPtr = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
@ -236,7 +236,7 @@ struct StoreWithCastVariant {
}
void* data_ptr() {
return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
}
private:

View File

@ -9,9 +9,6 @@
#include <c10/core/DispatchKey.h>
#include <ATen/core/function_schema.h>
#include <c10/util/Optional.h>
#include <c10/util/variant.h>
#include <unordered_map>
#include <mutex>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <ATen/functorch/Interpreter.h>
#include <ATen/functorch/VmapInterpreter.h>

View File

@ -3,11 +3,11 @@
#include <c10/macros/Macros.h>
#include <c10/util/StringUtil.h>
#include <c10/util/variant.h>
#include <cstddef>
#include <exception>
#include <string>
#include <variant>
#include <vector>
#if defined(_MSC_VER) && _MSC_VER <= 1900
@ -115,7 +115,7 @@ class C10_API Warning {
class C10_API UserWarning {};
class C10_API DeprecationWarning {};
using warning_variant_t = c10::variant<UserWarning, DeprecationWarning>;
using warning_variant_t = std::variant<UserWarning, DeprecationWarning>;
Warning(
warning_variant_t type,

View File

@ -4,6 +4,7 @@
#include <c10/util/Exception.h>
#include <c10/util/in_place.h>
#include <memory>
#include <type_traits>
namespace c10 {

View File

@ -1,6 +1,7 @@
#include <gtest/gtest.h>
#include <torch/torch.h>
#include <variant>
#include <test/cpp/api/support.h>
@ -13,7 +14,7 @@
}
TEST(EnumTest, AllEnums) {
c10::variant<
std::variant<
torch::enumtype::kLinear,
torch::enumtype::kConv1D,
torch::enumtype::kConv2D,

View File

@ -951,11 +951,11 @@ TEST(ExternalCall, JitCustomFusionOp) {
torch::jit::tensorexpr::BufHandle result_buf(
"nnc_add_mul_res_buf", output_shape, output_dtype);
const torch::jit::tensorexpr::BufHandle& a =
c10::get<torch::jit::tensorexpr::BufHandle>(inputs[0]);
std::get<torch::jit::tensorexpr::BufHandle>(inputs[0]);
const torch::jit::tensorexpr::BufHandle& b =
c10::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
std::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
const torch::jit::tensorexpr::BufHandle& c =
c10::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
std::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
torch::jit::tensorexpr::StmtPtr s =
torch::jit::tensorexpr::ExternalCall::make(
result_buf, external_func_name, {a, b, c}, {});

View File

@ -1667,7 +1667,7 @@ Tensor lowerNanToNum(
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto input_buf = c10::get<BufHandle>(inputs[0]);
auto input_buf = std::get<BufHandle>(inputs[0]);
auto e = Compute(
"custom_nan_to_num",
outputShape,

View File

@ -286,7 +286,7 @@ PyObject* map_warning_to_python_type(const c10::Warning& warning) {
return PyExc_DeprecationWarning;
}
};
return c10::visit(Visitor(), warning.type());
return std::visit(Visitor(), warning.type());
}
/// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification

View File

@ -1,10 +1,10 @@
#pragma once
#include <string>
#include <variant>
#include <ATen/core/Reduction.h>
#include <c10/util/Exception.h>
#include <c10/util/variant.h>
#include <torch/csrc/Export.h>
#define TORCH_ENUM_DECLARE(name) \
@ -42,7 +42,7 @@
//
// ```
// struct TORCH_API SomeOptions {
// typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
// typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
// reduction_t; SomeOptions(reduction_t reduction = torch::kMean) :
// reduction_(reduction) {}
//
@ -188,16 +188,16 @@ struct _compute_enum_name {
template <typename V>
std::string get_enum_name(V variant_enum) {
return c10::visit(enumtype::_compute_enum_name{}, variant_enum);
return std::visit(enumtype::_compute_enum_name{}, variant_enum);
}
template <typename V>
at::Reduction::Reduction reduction_get_enum(V variant_enum) {
if (c10::get_if<enumtype::kNone>(&variant_enum)) {
if (std::holds_alternative<enumtype::kNone>(variant_enum)) {
return at::Reduction::None;
} else if (c10::get_if<enumtype::kMean>(&variant_enum)) {
} else if (std::holds_alternative<enumtype::kMean>(variant_enum)) {
return at::Reduction::Mean;
} else if (c10::get_if<enumtype::kSum>(&variant_enum)) {
} else if (std::holds_alternative<enumtype::kSum>(variant_enum)) {
return at::Reduction::Sum;
} else {
TORCH_CHECK(

View File

@ -31,7 +31,7 @@ inline Tensor conv1d(
const Conv1dFuncOptions::padding_t& padding,
ExpandingArray<1> dilation,
int64_t groups) {
return c10::visit(
return std::visit(
[&](const auto& pad) {
return torch::conv1d(
input, weight, bias, stride, padding_unwrap(pad), dilation, groups);
@ -77,7 +77,7 @@ inline Tensor conv2d(
const Conv2dFuncOptions::padding_t& padding,
ExpandingArray<2> dilation,
int64_t groups) {
return c10::visit(
return std::visit(
[&](const auto& pad) {
return torch::conv2d(
input, weight, bias, stride, padding_unwrap(pad), dilation, groups);
@ -123,7 +123,7 @@ inline Tensor conv3d(
const Conv3dFuncOptions::padding_t& padding,
ExpandingArray<3> dilation,
int64_t groups) {
return c10::visit(
return std::visit(
[&](const auto& pad) {
return torch::conv3d(
input, weight, bias, stride, padding_unwrap(pad), dilation, groups);

View File

@ -135,11 +135,11 @@ inline Tensor embedding_bag(
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int mode_enum;
if (c10::get_if<enumtype::kSum>(&mode)) {
if (std::holds_alternative<enumtype::kSum>(mode)) {
mode_enum = 0;
} else if (c10::get_if<enumtype::kMean>(&mode)) {
} else if (std::holds_alternative<enumtype::kMean>(mode)) {
mode_enum = 1;
} else if (c10::get_if<enumtype::kMax>(&mode)) {
} else if (std::holds_alternative<enumtype::kMax>(mode)) {
mode_enum = 2;
TORCH_CHECK(
!scale_grad_by_freq,
@ -155,7 +155,7 @@ inline Tensor embedding_bag(
}
TORCH_CHECK(
!per_sample_weights_.defined() || c10::get_if<enumtype::kSum>(&mode),
!per_sample_weights_.defined() || std::get_if<enumtype::kSum>(&mode),
"embedding_bag: per_sample_weights was not null. ",
"per_sample_weights is only supported for mode='kSum' (got mode='",
torch::enumtype::get_enum_name(mode),

View File

@ -50,7 +50,7 @@ inline Tensor kl_div(
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
torch::Reduction::Reduction reduction_enum;
if (c10::get_if<enumtype::kMean>(&reduction)) {
if (std::holds_alternative<enumtype::kMean>(reduction)) {
TORCH_WARN(
"reduction: 'mean' divides the total loss by both the batch size and the support size."
"'batchmean' divides only by the batch size, and aligns with the KL div math definition."
@ -58,7 +58,7 @@ inline Tensor kl_div(
}
// special case for batchmean
if (c10::get_if<enumtype::kBatchMean>(&reduction)) {
if (std::holds_alternative<enumtype::kBatchMean>(reduction)) {
reduction_enum = torch::Reduction::Sum;
} else {
reduction_enum = enumtype::reduction_get_enum(reduction);
@ -66,7 +66,8 @@ inline Tensor kl_div(
auto reduced = torch::kl_div(input, target, reduction_enum, log_target);
if (c10::get_if<enumtype::kBatchMean>(&reduction) && input.dim() != 0) {
if (std::holds_alternative<enumtype::kBatchMean>(reduction) &&
input.dim() != 0) {
reduced = reduced / input.sizes()[0];
}
@ -531,11 +532,11 @@ inline Tensor multilabel_soft_margin_loss(
Tensor ret;
if (c10::get_if<enumtype::kNone>(&reduction)) {
if (std::holds_alternative<enumtype::kNone>(reduction)) {
ret = loss;
} else if (c10::get_if<enumtype::kMean>(&reduction)) {
} else if (std::holds_alternative<enumtype::kMean>(reduction)) {
ret = loss.mean();
} else if (c10::get_if<enumtype::kSum>(&reduction)) {
} else if (std::holds_alternative<enumtype::kSum>(reduction)) {
ret = loss.sum();
} else {
ret = input;
@ -661,11 +662,11 @@ inline Tensor triplet_margin_with_distance_loss(
auto loss = torch::clamp_min(dist_pos - dist_neg + margin, 0);
Tensor ret;
if (c10::get_if<enumtype::kNone>(&reduction)) {
if (std::holds_alternative<enumtype::kNone>(reduction)) {
ret = loss;
} else if (c10::get_if<enumtype::kMean>(&reduction)) {
} else if (std::holds_alternative<enumtype::kMean>(reduction)) {
ret = loss.mean();
} else if (c10::get_if<enumtype::kSum>(&reduction)) {
} else if (std::holds_alternative<enumtype::kSum>(reduction)) {
ret = loss.sum();
} else {
ret = anchor;

View File

@ -15,13 +15,13 @@ inline Tensor pad(
PadFuncOptions::mode_t mode,
double value) {
const auto mode_enum = [&] {
if (c10::get_if<enumtype::kConstant>(&mode)) {
if (std::holds_alternative<enumtype::kConstant>(mode)) {
return at::padding_mode::constant;
} else if (c10::get_if<enumtype::kReflect>(&mode)) {
} else if (std::holds_alternative<enumtype::kReflect>(mode)) {
return at::padding_mode::reflect;
} else if (c10::get_if<enumtype::kReplicate>(&mode)) {
} else if (std::holds_alternative<enumtype::kReplicate>(mode)) {
return at::padding_mode::replicate;
} else if (c10::get_if<enumtype::kCircular>(&mode)) {
} else if (std::holds_alternative<enumtype::kCircular>(mode)) {
return at::padding_mode::circular;
}
TORCH_CHECK(false, "Unrecognised padding mode");

View File

@ -86,8 +86,8 @@ inline Tensor interpolate(
c10::optional<bool> align_corners,
c10::optional<bool> recompute_scale_factor,
bool antialias) {
if (c10::get_if<enumtype::kNearest>(&mode) ||
c10::get_if<enumtype::kArea>(&mode)) {
if (std::holds_alternative<enumtype::kNearest>(mode) ||
std::get_if<enumtype::kArea>(&mode)) {
if (align_corners != c10::nullopt) {
TORCH_CHECK(
false,
@ -130,8 +130,8 @@ inline Tensor interpolate(
if (antialias &&
!(input.dim() == 4 &&
(c10::get_if<enumtype::kBilinear>(&mode) ||
c10::get_if<enumtype::kBicubic>(&mode)))) {
(std::get_if<enumtype::kBilinear>(&mode) ||
std::get_if<enumtype::kBicubic>(&mode)))) {
TORCH_CHECK(
false,
"Anti-alias option is only supported for bilinear and bicubic modes");
@ -139,65 +139,65 @@ inline Tensor interpolate(
auto closed_over_args =
std::make_tuple(input, size, scale_factor, recompute_scale_factor);
if (input.dim() == 3 && c10::get_if<enumtype::kNearest>(&mode)) {
if (input.dim() == 3 && std::get_if<enumtype::kNearest>(&mode)) {
return torch::upsample_nearest1d(
input,
_interp_output_size(1, std::move(closed_over_args)),
scale_factor_list.at(0));
} else if (input.dim() == 4 && c10::get_if<enumtype::kNearest>(&mode)) {
} else if (input.dim() == 4 && std::get_if<enumtype::kNearest>(&mode)) {
return torch::upsample_nearest2d(
input,
_interp_output_size(2, std::move(closed_over_args)),
scale_factor_list.at(0),
scale_factor_list.at(1));
} else if (input.dim() == 5 && c10::get_if<enumtype::kNearest>(&mode)) {
} else if (input.dim() == 5 && std::get_if<enumtype::kNearest>(&mode)) {
return torch::upsample_nearest3d(
input,
_interp_output_size(3, std::move(closed_over_args)),
scale_factor_list.at(0),
scale_factor_list.at(1),
scale_factor_list.at(2));
} else if (input.dim() == 3 && c10::get_if<enumtype::kNearestExact>(&mode)) {
} else if (input.dim() == 3 && std::get_if<enumtype::kNearestExact>(&mode)) {
return torch::_upsample_nearest_exact1d(
input,
_interp_output_size(1, std::move(closed_over_args)),
scale_factor_list.at(0));
} else if (input.dim() == 4 && c10::get_if<enumtype::kNearestExact>(&mode)) {
} else if (input.dim() == 4 && std::get_if<enumtype::kNearestExact>(&mode)) {
return torch::_upsample_nearest_exact2d(
input,
_interp_output_size(2, std::move(closed_over_args)),
scale_factor_list.at(0),
scale_factor_list.at(1));
} else if (input.dim() == 5 && c10::get_if<enumtype::kNearestExact>(&mode)) {
} else if (input.dim() == 5 && std::get_if<enumtype::kNearestExact>(&mode)) {
return torch::_upsample_nearest_exact3d(
input,
_interp_output_size(3, std::move(closed_over_args)),
scale_factor_list.at(0),
scale_factor_list.at(1),
scale_factor_list.at(2));
} else if (input.dim() == 3 && c10::get_if<enumtype::kArea>(&mode)) {
} else if (input.dim() == 3 && std::get_if<enumtype::kArea>(&mode)) {
return detail::adaptive_avg_pool1d(
input, _interp_output_size(1, std::move(closed_over_args)));
} else if (input.dim() == 4 && c10::get_if<enumtype::kArea>(&mode)) {
} else if (input.dim() == 4 && std::get_if<enumtype::kArea>(&mode)) {
return detail::adaptive_avg_pool2d(
input, _interp_output_size(2, std::move(closed_over_args)));
} else if (input.dim() == 5 && c10::get_if<enumtype::kArea>(&mode)) {
} else if (input.dim() == 5 && std::get_if<enumtype::kArea>(&mode)) {
return detail::adaptive_avg_pool3d(
input, _interp_output_size(3, std::move(closed_over_args)));
} else if (input.dim() == 3 && c10::get_if<enumtype::kLinear>(&mode)) {
} else if (input.dim() == 3 && std::get_if<enumtype::kLinear>(&mode)) {
TORCH_INTERNAL_ASSERT(align_corners != c10::nullopt);
return torch::upsample_linear1d(
input,
_interp_output_size(1, std::move(closed_over_args)),
*align_corners,
scale_factor_list.at(0));
} else if (input.dim() == 3 && c10::get_if<enumtype::kBilinear>(&mode)) {
} else if (input.dim() == 3 && std::get_if<enumtype::kBilinear>(&mode)) {
TORCH_CHECK(false, "Got 3D input, but bilinear mode needs 4D input");
} else if (input.dim() == 3 && c10::get_if<enumtype::kTrilinear>(&mode)) {
} else if (input.dim() == 3 && std::get_if<enumtype::kTrilinear>(&mode)) {
TORCH_CHECK(false, "Got 3D input, but trilinear mode needs 5D input");
} else if (input.dim() == 4 && c10::get_if<enumtype::kLinear>(&mode)) {
} else if (input.dim() == 4 && std::get_if<enumtype::kLinear>(&mode)) {
TORCH_CHECK(false, "Got 4D input, but linear mode needs 3D input");
} else if (input.dim() == 4 && c10::get_if<enumtype::kBilinear>(&mode)) {
} else if (input.dim() == 4 && std::get_if<enumtype::kBilinear>(&mode)) {
TORCH_INTERNAL_ASSERT(align_corners != c10::nullopt);
if (antialias) {
return torch::_upsample_bilinear2d_aa(
@ -213,13 +213,13 @@ inline Tensor interpolate(
*align_corners,
scale_factor_list.at(0),
scale_factor_list.at(1));
} else if (input.dim() == 4 && c10::get_if<enumtype::kTrilinear>(&mode)) {
} else if (input.dim() == 4 && std::get_if<enumtype::kTrilinear>(&mode)) {
TORCH_CHECK(false, "Got 4D input, but trilinear mode needs 5D input");
} else if (input.dim() == 5 && c10::get_if<enumtype::kLinear>(&mode)) {
} else if (input.dim() == 5 && std::get_if<enumtype::kLinear>(&mode)) {
TORCH_CHECK(false, "Got 5D input, but linear mode needs 3D input");
} else if (input.dim() == 5 && c10::get_if<enumtype::kBilinear>(&mode)) {
} else if (input.dim() == 5 && std::get_if<enumtype::kBilinear>(&mode)) {
TORCH_CHECK(false, "Got 5D input, but bilinear mode needs 4D input");
} else if (input.dim() == 5 && c10::get_if<enumtype::kTrilinear>(&mode)) {
} else if (input.dim() == 5 && std::get_if<enumtype::kTrilinear>(&mode)) {
TORCH_INTERNAL_ASSERT(align_corners != c10::nullopt);
return torch::upsample_trilinear3d(
input,
@ -228,7 +228,7 @@ inline Tensor interpolate(
scale_factor_list.at(0),
scale_factor_list.at(1),
scale_factor_list.at(2));
} else if (input.dim() == 4 && c10::get_if<enumtype::kBicubic>(&mode)) {
} else if (input.dim() == 4 && std::get_if<enumtype::kBicubic>(&mode)) {
TORCH_INTERNAL_ASSERT(align_corners != c10::nullopt);
if (antialias) {
return torch::_upsample_bicubic2d_aa(

View File

@ -63,17 +63,17 @@ inline Tensor grid_sample(
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t mode_enum, padding_mode_enum;
if (c10::get_if<enumtype::kBilinear>(&mode)) {
if (std::holds_alternative<enumtype::kBilinear>(mode)) {
mode_enum = 0;
} else if (c10::get_if<enumtype::kNearest>(&mode)) {
} else if (std::holds_alternative<enumtype::kNearest>(mode)) {
mode_enum = 1;
} else { /// mode == 'bicubic'
mode_enum = 2;
}
if (c10::get_if<enumtype::kZeros>(&padding_mode)) {
if (std::holds_alternative<enumtype::kZeros>(padding_mode)) {
padding_mode_enum = 0;
} else if (c10::get_if<enumtype::kBorder>(&padding_mode)) {
} else if (std::holds_alternative<enumtype::kBorder>(padding_mode)) {
padding_mode_enum = 1;
} else { /// padding_mode == 'reflection'
padding_mode_enum = 2;

View File

@ -8,7 +8,7 @@ namespace torch {
namespace nn {
namespace init {
using NonlinearityType = c10::variant<
using NonlinearityType = std::variant<
enumtype::kLinear,
enumtype::kConv1D,
enumtype::kConv2D,
@ -21,7 +21,7 @@ using NonlinearityType = c10::variant<
enumtype::kReLU,
enumtype::kLeakyReLU>;
using FanModeType = c10::variant<enumtype::kFanIn, enumtype::kFanOut>;
using FanModeType = std::variant<enumtype::kFanIn, enumtype::kFanOut>;
} // namespace init
} // namespace nn

View File

@ -42,7 +42,7 @@ class ConvNdImpl : public torch::nn::Cloneable<Derived> {
options.out_channels() % options.groups() == 0,
"out_channels must be divisible by groups");
c10::visit(
std::visit(
c10::overloaded(
[&](enumtype::kValid) {
_reversed_padding_repeated_twice.resize(2 * D);
@ -121,7 +121,7 @@ class ConvNdImpl : public torch::nn::Cloneable<Derived> {
<< "(" << options.in_channels() << ", " << options.out_channels()
<< ", kernel_size=" << options.kernel_size()
<< ", stride=" << options.stride();
c10::visit(
std::visit(
c10::overloaded(
[&](enumtype::kValid) { stream << ", padding='valid'"; },
[&](enumtype::kSame) { stream << ", padding='same'"; },
@ -143,7 +143,7 @@ class ConvNdImpl : public torch::nn::Cloneable<Derived> {
if (!options.bias()) {
stream << ", bias=" << std::boolalpha << false;
}
if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
stream << ", padding_mode="
<< enumtype::get_enum_name(options.padding_mode());
}
@ -276,7 +276,7 @@ class ConvTransposeNdImpl : public ConvNdImpl<D, Derived> {
explicit ConvTransposeNdImpl(detail::ConvNdOptions<D> options_)
: ConvNdImpl<D, Derived>(options_) {
TORCH_INTERNAL_ASSERT(
c10::holds_alternative<ExpandingArray<D>>(this->options.padding()),
std::holds_alternative<ExpandingArray<D>>(this->options.padding()),
"ConvTranspose padding cannot be a string");
}
@ -303,7 +303,7 @@ class ConvTransposeNdImpl : public ConvNdImpl<D, Derived> {
if (!this->options.bias()) {
stream << ", bias=" << std::boolalpha << false;
}
if (!c10::get_if<enumtype::kZeros>(&this->options.padding_mode())) {
if (!std::get_if<enumtype::kZeros>(&this->options.padding_mode())) {
stream << ", padding_mode="
<< enumtype::get_enum_name(this->options.padding_mode());
}
@ -312,7 +312,7 @@ class ConvTransposeNdImpl : public ConvNdImpl<D, Derived> {
protected:
const ExpandingArray<D>& padding() const {
return c10::get<ExpandingArray<D>>(this->options.padding());
return std::get<ExpandingArray<D>>(this->options.padding());
}
std::vector<int64_t> _output_padding(

View File

@ -11,7 +11,7 @@ namespace nn {
namespace detail {
typedef c10::variant<
typedef std::variant<
enumtype::kZeros,
enumtype::kReflect,
enumtype::kReplicate,
@ -20,7 +20,7 @@ typedef c10::variant<
template <size_t D>
using conv_padding_t =
c10::variant<ExpandingArray<D>, enumtype::kValid, enumtype::kSame>;
std::variant<ExpandingArray<D>, enumtype::kValid, enumtype::kSame>;
/// Options for a `D`-dimensional convolution or convolution transpose module.
template <size_t D>

View File

@ -101,7 +101,7 @@ struct TORCH_API EmbeddingFuncOptions {
// ============================================================================
typedef c10::variant<enumtype::kSum, enumtype::kMean, enumtype::kMax>
typedef std::variant<enumtype::kSum, enumtype::kMean, enumtype::kMax>
EmbeddingBagMode;
/// Options for the `EmbeddingBag` module.

View File

@ -1,6 +1,5 @@
#pragma once
#include <c10/util/variant.h>
#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/types.h>

View File

@ -15,7 +15,7 @@ namespace nn {
/// L1Loss model(L1LossOptions(torch::kNone));
/// ```
struct TORCH_API L1LossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
TORCH_OPTIONS_CTOR_VARIANT_ARG3(L1LossOptions, reduction, kNone, kMean, kSum)
@ -48,7 +48,7 @@ using L1LossFuncOptions = L1LossOptions;
/// model(KLDivLossOptions().reduction(torch::kNone).log_target(false));
/// ```
struct TORCH_API KLDivLossOptions {
typedef c10::variant<
typedef std::variant<
enumtype::kNone,
enumtype::kBatchMean,
enumtype::kSum,
@ -95,7 +95,7 @@ using KLDivFuncOptions = KLDivLossOptions;
/// MSELoss model(MSELossOptions(torch::kNone));
/// ```
struct TORCH_API MSELossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
TORCH_OPTIONS_CTOR_VARIANT_ARG3(MSELossOptions, reduction, kNone, kMean, kSum)
@ -128,7 +128,7 @@ using MSELossFuncOptions = MSELossOptions;
/// BCELoss model(BCELossOptions().reduction(torch::kNone).weight(weight));
/// ```
struct TORCH_API BCELossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
/// A manual rescaling weight given to the loss of each batch element.
@ -163,7 +163,7 @@ using BinaryCrossEntropyFuncOptions = BCELossOptions;
/// model(HingeEmbeddingLossOptions().margin(4).reduction(torch::kNone));
/// ```
struct TORCH_API HingeEmbeddingLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
/// Specifies the threshold for which the distance of a negative sample must
@ -197,7 +197,7 @@ using HingeEmbeddingLossFuncOptions = HingeEmbeddingLossOptions;
/// MultiMarginLoss model(MultiMarginLossOptions().margin(2).weight(weight));
/// ```
struct TORCH_API MultiMarginLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
/// Has a default value of :math:`1`. :math:`1` and :math:`2`
@ -242,7 +242,7 @@ using MultiMarginLossFuncOptions = MultiMarginLossOptions;
/// CosineEmbeddingLoss model(CosineEmbeddingLossOptions().margin(0.5));
/// ```
struct TORCH_API CosineEmbeddingLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
/// Specifies the threshold for which the distance of a negative sample must
@ -277,7 +277,7 @@ using CosineEmbeddingLossFuncOptions = CosineEmbeddingLossOptions;
/// MultiLabelMarginLoss model(MultiLabelMarginLossOptions(torch::kNone));
/// ```
struct TORCH_API MultiLabelMarginLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
TORCH_OPTIONS_CTOR_VARIANT_ARG3(
@ -318,7 +318,7 @@ using MultilabelMarginLossFuncOptions = MultiLabelMarginLossOptions;
/// SoftMarginLoss model(SoftMarginLossOptions(torch::kNone));
/// ```
struct TORCH_API SoftMarginLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
TORCH_OPTIONS_CTOR_VARIANT_ARG3(
@ -360,7 +360,7 @@ using SoftMarginLossFuncOptions = SoftMarginLossOptions;
/// model(MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight));
/// ```
struct TORCH_API MultiLabelSoftMarginLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
/// A manual rescaling weight given to each
@ -400,7 +400,7 @@ using MultilabelSoftMarginLossFuncOptions = MultiLabelSoftMarginLossOptions;
/// model(TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false));
/// ```
struct TORCH_API TripletMarginLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
/// Specifies the threshold for which the distance of a negative sample must
@ -442,7 +442,7 @@ using TripletMarginLossFuncOptions = TripletMarginLossOptions;
/// model(TripletMarginWithDistanceLossOptions().margin(3).swap(false));
/// ```
struct TORCH_API TripletMarginWithDistanceLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
typedef std::function<Tensor(const Tensor&, const Tensor&)>
distance_function_t;
@ -493,7 +493,7 @@ using TripletMarginWithDistanceLossFuncOptions =
/// model(CTCLossOptions().blank(42).zero_infinity(false).reduction(torch::kSum));
/// ```
struct TORCH_API CTCLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
/// blank label. Default `0`.
@ -530,7 +530,7 @@ using CTCLossFuncOptions = CTCLossOptions;
/// SmoothL1Loss model(SmoothL1LossOptions().reduction(torch::kNone).beta(0.5));
/// ```
struct TORCH_API SmoothL1LossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
TORCH_OPTIONS_CTOR_VARIANT_ARG3(
@ -573,7 +573,7 @@ using SmoothL1LossFuncOptions = SmoothL1LossOptions;
/// HuberLoss model(HuberLossOptions().reduction(torch::kNone).delta(0.5));
/// ```
struct TORCH_API HuberLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
TORCH_OPTIONS_CTOR_VARIANT_ARG3(
@ -617,7 +617,7 @@ using HuberLossFuncOptions = HuberLossOptions;
/// model(PoissonNLLLossOptions().log_input(false).full(true).eps(0.42).reduction(torch::kSum));
/// ```
struct TORCH_API PoissonNLLLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
/// if true the loss is computed as `exp(input) - target * input`,
@ -658,7 +658,7 @@ using PoissonNLLLossFuncOptions = PoissonNLLLossOptions;
/// model(MarginRankingLossOptions().margin(0.5).reduction(torch::kSum));
/// ```
struct TORCH_API MarginRankingLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
/// Has a default value of `0`.
@ -691,7 +691,7 @@ using MarginRankingLossFuncOptions = MarginRankingLossOptions;
/// NLLLoss model(NLLLossOptions().ignore_index(-100).reduction(torch::kMean));
/// ```
struct TORCH_API NLLLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
/// A manual rescaling weight given to each
@ -730,7 +730,7 @@ using NLLLossFuncOptions = NLLLossOptions;
/// model(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean));
/// ```
struct TORCH_API CrossEntropyLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
/// A manual rescaling weight given to each class. If given, has to be a
@ -770,7 +770,7 @@ using CrossEntropyFuncOptions = CrossEntropyLossOptions;
/// model(BCEWithLogitsLossOptions().reduction(torch::kNone).weight(weight));
/// ```
struct TORCH_API BCEWithLogitsLossOptions {
typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
reduction_t;
/// A manual rescaling weight given to the loss of each batch element.
/// If given, has to be a Tensor of size `nbatch`.

View File

@ -1,6 +1,5 @@
#pragma once
#include <c10/util/variant.h>
#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
@ -195,7 +194,7 @@ namespace functional {
/// 2}).mode(torch::kReplicate));
/// ```
struct TORCH_API PadFuncOptions {
typedef c10::variant<
typedef std::variant<
enumtype::kConstant,
enumtype::kReflect,
enumtype::kReplicate,

View File

@ -12,7 +12,7 @@ namespace detail {
/// Common options for RNN, LSTM and GRU modules.
struct TORCH_API RNNOptionsBase {
typedef c10::variant<
typedef std::variant<
enumtype::kLSTM,
enumtype::kGRU,
enumtype::kRNN_TANH,
@ -57,7 +57,7 @@ struct TORCH_API RNNOptionsBase {
/// 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh));
/// ```
struct TORCH_API RNNOptions {
typedef c10::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;
typedef std::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;
RNNOptions(int64_t input_size, int64_t hidden_size);
@ -182,7 +182,7 @@ struct TORCH_API RNNCellOptionsBase {
/// 10).bias(false).nonlinearity(torch::kReLU));
/// ```
struct TORCH_API RNNCellOptions {
typedef c10::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;
typedef std::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;
RNNCellOptions(int64_t input_size, int64_t hidden_size);

View File

@ -8,7 +8,7 @@
namespace torch {
namespace nn {
using activation_t = c10::variant<
using activation_t = std::variant<
enumtype::kReLU,
enumtype::kGELU,
std::function<Tensor(const Tensor&)>>;

View File

@ -1,6 +1,5 @@
#pragma once
#include <c10/util/variant.h>
#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
@ -28,7 +27,7 @@ struct TORCH_API UpsampleOptions {
/// the upsampling algorithm: one of "nearest", "linear", "bilinear",
/// "bicubic" and "trilinear". Default: "nearest"
typedef c10::variant<
typedef std::variant<
enumtype::kNearest,
enumtype::kLinear,
enumtype::kBilinear,
@ -55,7 +54,7 @@ namespace functional {
/// F::InterpolateFuncOptions().size(std::vector<int64_t>({4})).mode(torch::kNearest));
/// ```
struct TORCH_API InterpolateFuncOptions {
typedef c10::variant<
typedef std::variant<
enumtype::kNearest,
enumtype::kLinear,
enumtype::kBilinear,

View File

@ -18,8 +18,8 @@ namespace functional {
/// F::GridSampleFuncOptions().mode(torch::kBilinear).padding_mode(torch::kZeros).align_corners(true));
/// ```
struct TORCH_API GridSampleFuncOptions {
typedef c10::variant<enumtype::kBilinear, enumtype::kNearest> mode_t;
typedef c10::
typedef std::variant<enumtype::kBilinear, enumtype::kNearest> mode_t;
typedef std::
variant<enumtype::kZeros, enumtype::kBorder, enumtype::kReflection>
padding_mode_t;

View File

@ -47,7 +47,7 @@ double calculate_kaiming_std(
const auto gain = calculate_gain(nonlinearity, a);
double std = 0.0;
if (c10::get_if<enumtype::kFanIn>(&mode)) {
if (std::holds_alternative<enumtype::kFanIn>(mode)) {
std = gain / std::sqrt(fan.in);
} else {
std = gain / std::sqrt(fan.out);
@ -57,11 +57,11 @@ double calculate_kaiming_std(
} // namespace
double calculate_gain(NonlinearityType nonlinearity, double param) {
if (c10::get_if<enumtype::kTanh>(&nonlinearity)) {
if (std::holds_alternative<enumtype::kTanh>(nonlinearity)) {
return 5.0 / 3.0; // NOLINT
} else if (c10::get_if<enumtype::kReLU>(&nonlinearity)) {
} else if (std::holds_alternative<enumtype::kReLU>(nonlinearity)) {
return std::sqrt(2.0); // NOLINT
} else if (c10::get_if<enumtype::kLeakyReLU>(&nonlinearity)) {
} else if (std::holds_alternative<enumtype::kLeakyReLU>(nonlinearity)) {
return std::sqrt(2.0 / (1 + pow(param, 2))); // NOLINT
}

View File

@ -20,11 +20,13 @@ namespace F = torch::nn::functional;
static F::PadFuncOptions::mode_t _get_pad_mode_from_conv_padding_mode(
torch::nn::detail::conv_padding_mode_t conv_padding_mode) {
F::PadFuncOptions::mode_t pad_mode;
if (c10::get_if<torch::enumtype::kReflect>(&conv_padding_mode)) {
if (std::holds_alternative<torch::enumtype::kReflect>(conv_padding_mode)) {
pad_mode = torch::kReflect;
} else if (c10::get_if<torch::enumtype::kReplicate>(&conv_padding_mode)) {
} else if (std::holds_alternative<torch::enumtype::kReplicate>(
conv_padding_mode)) {
pad_mode = torch::kReplicate;
} else if (c10::get_if<torch::enumtype::kCircular>(&conv_padding_mode)) {
} else if (std::holds_alternative<torch::enumtype::kCircular>(
conv_padding_mode)) {
pad_mode = torch::kCircular;
} else {
TORCH_CHECK(
@ -52,7 +54,7 @@ Conv1dImpl::Conv1dImpl(Conv1dOptions options_)
.padding_mode(options_.padding_mode())) {}
Tensor Conv1dImpl::forward(const Tensor& input) {
if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
return F::detail::conv1d(
F::pad(
input,
@ -91,7 +93,7 @@ Conv2dImpl::Conv2dImpl(Conv2dOptions options_)
.padding_mode(options_.padding_mode())) {}
Tensor Conv2dImpl::_conv_forward(const Tensor& input, const Tensor& weight) {
if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
return F::detail::conv2d(
F::pad(
input,
@ -134,7 +136,7 @@ Conv3dImpl::Conv3dImpl(Conv3dOptions options_)
.padding_mode(options_.padding_mode())) {}
Tensor Conv3dImpl::forward(const Tensor& input) {
if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
return F::detail::conv3d(
F::pad(
input,
@ -247,7 +249,7 @@ ConvTranspose1dImpl::ConvTranspose1dImpl(ConvTranspose1dOptions options_)
Tensor ConvTranspose1dImpl::forward(
const Tensor& input,
const c10::optional<at::IntArrayRef>& output_size) {
if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
TORCH_CHECK(
false, "Only `zeros` padding mode is supported for ConvTranspose1d");
}
@ -284,7 +286,7 @@ ConvTranspose2dImpl::ConvTranspose2dImpl(ConvTranspose2dOptions options_)
Tensor ConvTranspose2dImpl::forward(
const Tensor& input,
const c10::optional<at::IntArrayRef>& output_size) {
if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
TORCH_CHECK(
false, "Only `zeros` padding mode is supported for ConvTranspose2d");
}
@ -321,7 +323,7 @@ ConvTranspose3dImpl::ConvTranspose3dImpl(ConvTranspose3dOptions options_)
Tensor ConvTranspose3dImpl::forward(
const Tensor& input,
const c10::optional<at::IntArrayRef>& output_size) {
if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
TORCH_CHECK(
false, "Only `zeros` padding mode is supported for ConvTranspose3d");
}

View File

@ -167,7 +167,7 @@ void EmbeddingBagImpl::pretty_print(std::ostream& stream) const {
if (options.sparse()) {
stream << ", sparse=" << std::boolalpha << options.sparse();
}
if (!c10::get_if<enumtype::kMean>(&options.mode())) {
if (!std::get_if<enumtype::kMean>(&options.mode())) {
stream << ", mode=" << torch::enumtype::get_enum_name(options.mode());
}
if (options.include_last_offset()) {

View File

@ -30,13 +30,13 @@ enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
static CuDNNMode get_cudnn_mode_for_rnn(
detail::RNNOptionsBase::rnn_options_base_mode_t mode) {
if (c10::get_if<enumtype::kRNN_RELU>(&mode)) {
if (std::holds_alternative<enumtype::kRNN_RELU>(mode)) {
return CuDNNMode::RNN_RELU;
} else if (c10::get_if<enumtype::kRNN_TANH>(&mode)) {
} else if (std::holds_alternative<enumtype::kRNN_TANH>(mode)) {
return CuDNNMode::RNN_TANH;
} else if (c10::get_if<enumtype::kLSTM>(&mode)) {
} else if (std::holds_alternative<enumtype::kLSTM>(mode)) {
return CuDNNMode::LSTM;
} else if (c10::get_if<enumtype::kGRU>(&mode)) {
} else if (std::holds_alternative<enumtype::kGRU>(mode)) {
return CuDNNMode::GRU;
} else {
TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(mode));
@ -94,19 +94,19 @@ void RNNImplBase<Derived>::reset() {
if (options_base.proj_size() > 0) {
TORCH_CHECK(
c10::get_if<enumtype::kLSTM>(&options_base.mode()),
std::get_if<enumtype::kLSTM>(&options_base.mode()),
"proj_size argument is only supported for LSTM, not RNN or GRU");
}
int64_t gate_size = 0;
if (c10::get_if<enumtype::kLSTM>(&options_base.mode())) {
if (std::holds_alternative<enumtype::kLSTM>(options_base.mode())) {
gate_size = 4 * options_base.hidden_size();
} else if (c10::get_if<enumtype::kGRU>(&options_base.mode())) {
} else if (std::holds_alternative<enumtype::kGRU>(options_base.mode())) {
gate_size = 3 * options_base.hidden_size();
// NOLINTNEXTLINE(bugprone-branch-clone)
} else if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
} else if (std::holds_alternative<enumtype::kRNN_TANH>(options_base.mode())) {
gate_size = options_base.hidden_size();
} else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
} else if (std::holds_alternative<enumtype::kRNN_RELU>(options_base.mode())) {
gate_size = options_base.hidden_size();
} else {
TORCH_CHECK(
@ -405,9 +405,9 @@ template class RNNImplBase<RNNImpl>;
static detail::RNNOptionsBase::rnn_options_base_mode_t
compute_rnn_options_base_mode(RNNOptions::nonlinearity_t nonlinearity) {
if (c10::get_if<enumtype::kTanh>(&nonlinearity)) {
if (std::holds_alternative<enumtype::kTanh>(nonlinearity)) {
return torch::kRNN_TANH;
} else if (c10::get_if<enumtype::kReLU>(&nonlinearity)) {
} else if (std::holds_alternative<enumtype::kReLU>(nonlinearity)) {
return torch::kRNN_RELU;
} else {
TORCH_CHECK(
@ -453,7 +453,7 @@ std::tuple<Tensor, Tensor> RNNImpl::forward_helper(
std::tuple<Tensor, Tensor> result;
if (!batch_sizes.defined()) {
if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
if (std::holds_alternative<enumtype::kRNN_TANH>(options_base.mode())) {
result = torch::rnn_tanh(
input,
hx,
@ -464,7 +464,8 @@ std::tuple<Tensor, Tensor> RNNImpl::forward_helper(
this->is_training(),
options_base.bidirectional(),
options_base.batch_first());
} else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
} else if (std::holds_alternative<enumtype::kRNN_RELU>(
options_base.mode())) {
result = torch::rnn_relu(
input,
hx,
@ -482,7 +483,7 @@ std::tuple<Tensor, Tensor> RNNImpl::forward_helper(
torch::enumtype::get_enum_name(options_base.mode()));
}
} else {
if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
if (std::holds_alternative<enumtype::kRNN_TANH>(options_base.mode())) {
result = torch::rnn_tanh(
input,
batch_sizes,
@ -493,7 +494,8 @@ std::tuple<Tensor, Tensor> RNNImpl::forward_helper(
options_base.dropout(),
this->is_training(),
options_base.bidirectional());
} else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
} else if (std::holds_alternative<enumtype::kRNN_RELU>(
options_base.mode())) {
result = torch::rnn_relu(
input,
batch_sizes,
@ -920,10 +922,10 @@ Tensor RNNCellImpl::forward(const Tensor& input, Tensor hx) {
r_hx = is_batched ? hx : hx.unsqueeze(0);
}
if (c10::get_if<enumtype::kTanh>(&options.nonlinearity())) {
if (std::holds_alternative<enumtype::kTanh>(options.nonlinearity())) {
ret = torch::rnn_tanh_cell(
r_input, r_hx, weight_ih, weight_hh, bias_ih, bias_hh);
} else if (c10::get_if<enumtype::kReLU>(&options.nonlinearity())) {
} else if (std::holds_alternative<enumtype::kReLU>(options.nonlinearity())) {
ret = torch::rnn_relu_cell(
r_input, r_hx, weight_ih, weight_hh, bias_ih, bias_hh);
} else {

View File

@ -71,14 +71,14 @@ Tensor TransformerEncoderLayerImpl::forward(
Tensor ret = norm1(src + dropout1(src2));
// feedforward
if (c10::get_if<enumtype::kGELU>(&options.activation())) {
if (std::holds_alternative<enumtype::kGELU>(options.activation())) {
src2 = linear2(dropout(F::gelu(linear1(ret))));
} else if (c10::get_if<enumtype::kReLU>(&options.activation())) {
} else if (std::holds_alternative<enumtype::kReLU>(options.activation())) {
src2 = linear2(dropout(F::relu(linear1(ret))));
} else if (c10::get_if<std::function<Tensor(const Tensor&)>>(
&options.activation())) {
} else if (std::holds_alternative<std::function<Tensor(const Tensor&)>>(
options.activation())) {
auto callable_activation =
*c10::get_if<std::function<Tensor(const Tensor&)>>(
*std::get_if<std::function<Tensor(const Tensor&)>>(
&options.activation());
src2 = linear2(dropout(callable_activation(linear1(ret))));
} else {
@ -198,14 +198,14 @@ Tensor TransformerDecoderLayerImpl::forward(
}
Tensor TransformerDecoderLayerImpl::activation(const Tensor& input) {
if (c10::get_if<enumtype::kGELU>(&options.activation())) {
if (std::holds_alternative<enumtype::kGELU>(options.activation())) {
return F::gelu(input);
} else if (c10::get_if<enumtype::kReLU>(&options.activation())) {
} else if (std::holds_alternative<enumtype::kReLU>(options.activation())) {
return F::relu(input);
} else if (c10::get_if<std::function<Tensor(const Tensor&)>>(
&options.activation())) {
} else if (std::holds_alternative<std::function<Tensor(const Tensor&)>>(
options.activation())) {
auto callable_activation =
*c10::get_if<std::function<Tensor(const Tensor&)>>(
*std::get_if<std::function<Tensor(const Tensor&)>>(
&options.activation());
return callable_activation(input);
} else {

View File

@ -25,15 +25,15 @@ void UpsampleImpl::pretty_print(std::ostream& stream) const {
Tensor UpsampleImpl::forward(const Tensor& input) {
F::InterpolateFuncOptions::mode_t mode;
if (c10::get_if<enumtype::kNearest>(&options.mode())) {
if (std::holds_alternative<enumtype::kNearest>(options.mode())) {
mode = torch::kNearest;
} else if (c10::get_if<enumtype::kLinear>(&options.mode())) {
} else if (std::holds_alternative<enumtype::kLinear>(options.mode())) {
mode = torch::kLinear;
} else if (c10::get_if<enumtype::kBilinear>(&options.mode())) {
} else if (std::holds_alternative<enumtype::kBilinear>(options.mode())) {
mode = torch::kBilinear;
} else if (c10::get_if<enumtype::kBicubic>(&options.mode())) {
} else if (std::holds_alternative<enumtype::kBicubic>(options.mode())) {
mode = torch::kBicubic;
} else if (c10::get_if<enumtype::kTrilinear>(&options.mode())) {
} else if (std::holds_alternative<enumtype::kTrilinear>(options.mode())) {
mode = torch::kTrilinear;
}

View File

@ -12,7 +12,6 @@
#include <c10/util/DimVector.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/util/variant.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>

View File

@ -173,13 +173,13 @@ bool symbolicShapeAnalysisTestModeEnabled() {
return symbolic_shape_analysis_test_mode;
}
using SSArgument = c10::variant<ShapeArguments, IValue>;
using SSArgument = std::variant<ShapeArguments, IValue>;
static std::ostream& operator<<(std::ostream& out, const SSArgument& sa) {
if (const IValue* iv = c10::get_if<IValue>(&sa)) {
if (const IValue* iv = std::get_if<IValue>(&sa)) {
out << *iv;
} else {
out << c10::get<ShapeArguments>(sa);
out << std::get<ShapeArguments>(sa);
}
return out;
}
@ -377,11 +377,11 @@ struct SymbolicShapeOpAnalyzer {
SSArgument& argument = inputs_[op_in_index];
Value* graph_in_var = shape_compute_graph_->inputs().at(op_in_index);
if (IValue* cur_val = c10::get_if<IValue>(&argument)) {
if (IValue* cur_val = std::get_if<IValue>(&argument)) {
GRAPH_DEBUG("Substituting constant input ", *cur_val);
replaceWithIValue(graph_in_var, *cur_val);
} else {
auto cur_arg = c10::get<ShapeArguments>(argument);
auto cur_arg = std::get<ShapeArguments>(argument);
if (cur_arg.has_dim()) {
graph_in_var->setType(ListType::ofInts());
}
@ -423,7 +423,7 @@ struct SymbolicShapeOpAnalyzer {
"Missing Arg for Shape Graph");
for (const auto index :
c10::irange(shape_compute_graph_->inputs().size())) {
auto shape_arguments = c10::get_if<ShapeArguments>(&inputs_[index]);
auto shape_arguments = std::get_if<ShapeArguments>(&inputs_[index]);
if (!shape_arguments || !shape_arguments->has_dim()) {
continue;
}
@ -1146,10 +1146,10 @@ calculateSymbolicShapesOnOp(
std::vector<SSArgument> ssa_args;
for (auto& arg : inputs) {
if (const IValue* ival = c10::get_if<IValue>(&arg)) {
if (const IValue* ival = std::get_if<IValue>(&arg)) {
ssa_args.emplace_back(*ival);
} else {
const c10::SymbolicShape* ss = c10::get_if<c10::SymbolicShape>(&arg);
const c10::SymbolicShape* ss = std::get_if<c10::SymbolicShape>(&arg);
ssa_args.emplace_back(ShapeArguments(*ss));
}
}

View File

@ -1,10 +1,10 @@
#pragma once
#include <c10/util/variant.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>
#include <unordered_map>
#include <utility>
#include <variant>
namespace torch {
namespace jit {
@ -49,7 +49,7 @@ PropagateShapesAndBuildLargeShapeComputeGraph(
TORCH_API bool setSymbolicShapeAnalysisTestMode(bool value);
TORCH_API bool symbolicShapeAnalysisTestModeEnabled();
using SSAInput = c10::variant<IValue, c10::SymbolicShape>;
using SSAInput = std::variant<IValue, c10::SymbolicShape>;
TORCH_API c10::optional<std::vector<c10::SymbolicShape>>
calculateSymbolicShapesOnOp(
const FunctionSchema* schema,

View File

@ -8,7 +8,7 @@
namespace torch {
namespace jit {
namespace {
using CanonicalArg = c10::variant<CanonicalizedSymbolicShape, IValue>;
using CanonicalArg = std::variant<CanonicalizedSymbolicShape, IValue>;
using CanonicalArgVec = std::vector<CanonicalArg>;
using CanonicalRet = std::vector<CanonicalizedSymbolicShape>;
using ShapeCacheKey = std::tuple<c10::OperatorName, CanonicalArgVec>;
@ -20,14 +20,14 @@ CanonicalArgVec cannonicalizeVec(
CanonicalArgVec canonical_args;
canonical_args.reserve(arg_vec.size());
for (auto& arg : arg_vec) {
if (const IValue* iv = c10::get_if<IValue>(&arg)) {
if (const IValue* iv = std::get_if<IValue>(&arg)) {
if (deep_copy) {
canonical_args.emplace_back(iv->deepcopy());
} else {
canonical_args.emplace_back(*iv);
}
} else {
auto& ss = c10::get<at::SymbolicShape>(arg);
auto& ss = std::get<at::SymbolicShape>(arg);
canonical_args.emplace_back(CanonicalizedSymbolicShape(ss, ss_map));
}
}
@ -57,7 +57,7 @@ struct ArgumentsHasher {
hash_val = at::hash_combine(std::hash<size_t>{}(arg_vec.size()), hash_val);
for (const CanonicalArg& arg : arg_vec) {
size_t cur_arg = 0;
if (const IValue* ival = c10::get_if<IValue>(&arg)) {
if (const IValue* ival = std::get_if<IValue>(&arg)) {
// IValue doesn't hash List (as Python doesn't), so we will do a custom
// list hash
if (ival->isList()) {
@ -70,7 +70,7 @@ struct ArgumentsHasher {
cur_arg = IValue::hash(ival);
}
} else {
cur_arg = c10::get<CanonicalizedSymbolicShape>(arg).hash();
cur_arg = std::get<CanonicalizedSymbolicShape>(arg).hash();
}
hash_val = at::hash_combine(hash_val, cur_arg);
}

View File

@ -5,7 +5,6 @@
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/FbcodeMaps.h>
#include <c10/util/variant.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/graph_node_list.h>
#include <torch/csrc/jit/ir/ir.h>
@ -912,7 +911,7 @@ class TORCH_API ProcessedNode {
// These should be noexcept, but some Android build is failing
// saying the noexcept specification doesn't match the calculated
// one. Maybe c10::variant is throwing it off?
// one. Maybe std::variant is throwing it off?
ProcessedNode(ProcessedNode&&) = default;
ProcessedNode(const ProcessedNode&) = delete;

View File

@ -1,4 +1,3 @@
#include <c10/util/variant.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <ATen/ExpandUtils.h>
@ -437,9 +436,9 @@ ArgValue TensorExprKernel::toArg(const torch::jit::Value* v) const {
}
if (vec.empty()) {
return BufList(); // Return arbitrarily typed vector
} else if (c10::get_if<BufHandle>(&vec[0])) {
} else if (std::get_if<BufHandle>(&vec[0])) {
return convertVecArgValue<BufHandle>(vec);
} else if (c10::get_if<int64_t>(&vec[0])) {
} else if (std::get_if<int64_t>(&vec[0])) {
return convertVecArgValue<int64_t>(vec);
}
throw unsupported_dtype();
@ -608,7 +607,7 @@ Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
argInputs.emplace_back(toArg(inp));
}
// handle optional bias
if (c10::get_if<ArgNone>(&argInputs[2])) {
if (std::get_if<ArgNone>(&argInputs[2])) {
Dtype dtype = outputType ? Dtype(*outputType) : kFloat;
std::vector<ExprHandle> biasShape;
biasShape.push_back(outputShape[1]);

View File

@ -1,6 +1,5 @@
#pragma once
#include <c10/util/variant.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
@ -59,23 +58,23 @@ std::vector<ExprHandle> computeIndicesToBroadcast(
const std::vector<ExprHandle>& inputSizes);
inline std::string getArgValueName(const ArgValue& a) {
if (c10::get_if<tensorexpr::BufHandle>(&a)) {
if (std::holds_alternative<tensorexpr::BufHandle>(a)) {
return "BufHandle";
} else if (c10::get_if<tensorexpr::VarHandle>(&a)) {
} else if (std::holds_alternative<tensorexpr::VarHandle>(a)) {
return "VarHandle";
} else if (c10::get_if<double>(&a)) {
} else if (std::holds_alternative<double>(a)) {
return "double";
} else if (c10::get_if<int64_t>(&a)) {
} else if (std::holds_alternative<int64_t>(a)) {
return "int64_t";
} else if (c10::get_if<bool>(&a)) {
} else if (std::holds_alternative<bool>(a)) {
return "bool";
} else if (c10::get_if<BufList>(&a)) {
} else if (std::holds_alternative<BufList>(a)) {
return "BufList";
} else if (c10::get_if<DoubleList>(&a)) {
} else if (std::holds_alternative<DoubleList>(a)) {
return "DoubleList";
} else if (c10::get_if<IntList>(&a)) {
} else if (std::holds_alternative<IntList>(a)) {
return "IntList";
} else if (c10::get_if<ArgNone>(&a)) {
} else if (std::holds_alternative<ArgNone>(a)) {
return "None";
} else {
throw std::runtime_error("ArgValue type not handled in string conversion");
@ -86,7 +85,7 @@ template <class T>
std::vector<T> convertVecArgValue(const std::vector<ArgValue>& v) {
std::vector<T> res;
for (auto& x : v) {
auto val = c10::get_if<T>(&x);
auto val = std::get_if<T>(&x);
if (val) {
res.push_back(*val);
} else {

View File

@ -517,11 +517,11 @@ int nnc_lowerings_lazy_registration() {
at::Device device) {
bool noMin = false;
bool noMax = false;
if (c10::get_if<ArgNone>(&inputs[1])) {
if (std::get_if<ArgNone>(&inputs[1])) {
noMin = true;
}
if (c10::get_if<ArgNone>(&inputs[2])) {
if (std::get_if<ArgNone>(&inputs[2])) {
noMax = true;
}
@ -583,7 +583,7 @@ int nnc_lowerings_lazy_registration() {
const c10::optional<ScalarType>& outputType,
at::Device device) {
// check if the activation is quantized
const BufHandle& x = c10::get<BufHandle>(inputs[0]);
const BufHandle& x = std::get<BufHandle>(inputs[0]);
if (x.node()->qscale()) {
return computeQuantizedSigmoidExternalCall(
inputs, outputShape, outputStrides, outputType, device);
@ -675,7 +675,7 @@ int nnc_lowerings_lazy_registration() {
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto A = c10::get<BufHandle>(inputs[0]);
auto A = std::get<BufHandle>(inputs[0]);
if (A.node()->qscale()) {
return computeQuantizedRelu(
inputs, outputShape, outputStrides, outputType, device);
@ -741,7 +741,7 @@ int nnc_lowerings_lazy_registration() {
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
const auto& kApproximate = c10::get<std::string>(inputs[1]);
const auto& kApproximate = std::get<std::string>(inputs[1]);
std::vector<ArgValue> operands = {inputs.front()};
if (at::native::get_gelutype_enum(kApproximate) ==
at::native::GeluType::Tanh) {
@ -987,7 +987,7 @@ int nnc_lowerings_lazy_registration() {
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
const BufHandle& rhs = c10::get<BufHandle>(inputs[1]);
const BufHandle& rhs = std::get<BufHandle>(inputs[1]);
auto dtype = rhs.dtype();
return computeOneOperand(
"aten_type_as",
@ -1708,7 +1708,7 @@ int nnc_lowerings_lazy_registration() {
// outputShape,
// [&](const std::vector<VarHandle>& axes) {
// int64_t dim =
// at::maybe_wrap_dim(c10::get<int64_t>(inputs[1]),
// at::maybe_wrap_dim(std::get<int64_t>(inputs[1]),
// axes.size());
// ExprHandle start = constant(inputs[2]);
// ExprHandle stride = constant(inputs[4]);
@ -1730,7 +1730,7 @@ int nnc_lowerings_lazy_registration() {
outputShape,
outputStrides,
[&](const std::vector<VarHandle>& axes) {
int64_t dim = c10::get<int64_t>(inputs[1]);
int64_t dim = std::get<int64_t>(inputs[1]);
if (dim < 0) {
if (axes.empty()) {
throw malformed_input("axes are zero handling unsqueeze");
@ -1749,7 +1749,7 @@ int nnc_lowerings_lazy_registration() {
}
}
return broadcast(c10::get<BufHandle>(inputs[0]), indices);
return broadcast(std::get<BufHandle>(inputs[0]), indices);
});
});
RegisterNNCLoweringsFunction aten_t(
@ -1776,7 +1776,7 @@ int nnc_lowerings_lazy_registration() {
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto A = c10::get<BufHandle>(inputs[0]);
auto A = std::get<BufHandle>(inputs[0]);
// Trivial case of 0-dim tensors: just a copy of the input
if (A.ndim() == 0) {
auto tensor = Compute(
@ -1793,7 +1793,7 @@ int nnc_lowerings_lazy_registration() {
}
return tensor;
}
auto permute_dims = c10::get<IntList>(inputs[1]);
auto permute_dims = std::get<IntList>(inputs[1]);
auto tensor = Compute(
"aten_permute",
outputShape,

View File

@ -2,7 +2,6 @@
// IR.
#pragma once
#include <c10/util/variant.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
@ -13,11 +12,11 @@ namespace torch {
namespace jit {
namespace tensorexpr {
using ArgNone = c10::monostate;
using ArgNone = std::monostate;
using BufList = std::vector<tensorexpr::BufHandle>;
using DoubleList = std::vector<double>;
using IntList = std::vector<int64_t>;
using ArgValue = c10::variant<
using ArgValue = std::variant<
tensorexpr::BufHandle,
tensorexpr::VarHandle,
double,

View File

@ -246,18 +246,18 @@ Tensor conv2d_depthwise(
}
static std::vector<int64_t> _pair_int(ArgValue v) {
if (auto t = c10::get_if<IntList>(&v)) {
if (auto t = std::get_if<IntList>(&v)) {
return {(*t)[0], (*t)[1]};
}
auto i = c10::get<int64_t>(v);
auto i = std::get<int64_t>(v);
return {i, i};
}
static std::vector<int64_t> _single_int_list(ArgValue v) {
if (auto t = c10::get_if<IntList>(&v)) {
if (auto t = std::get_if<IntList>(&v)) {
return {(*t)[0]};
}
auto i = c10::get<int64_t>(v);
auto i = std::get<int64_t>(v);
return {i};
}
@ -361,15 +361,15 @@ Tensor computeConv2d(
}
BufHandle ResultBuf("conv", outputShape, dtype);
const BufHandle& inp = c10::get<BufHandle>(inputs[0]);
const BufHandle& w = c10::get<BufHandle>(inputs[1]);
const BufHandle& b = c10::get<BufHandle>(inputs[2]);
const BufHandle& inp = std::get<BufHandle>(inputs[0]);
const BufHandle& w = std::get<BufHandle>(inputs[1]);
const BufHandle& b = std::get<BufHandle>(inputs[2]);
auto strides = _pair_int(inputs[3]);
auto padding = _pair_int(inputs[4]);
auto dilation = _pair_int(inputs[5]);
int groups = c10::get<int64_t>(inputs[6]);
int groups = std::get<int64_t>(inputs[6]);
auto inpInfo = getTensorInfo(inp);
auto wInfo = getTensorInfo(w);
@ -409,15 +409,15 @@ Tensor computeConv1d(
}
BufHandle ResultBuf("conv", outputShape, dtype);
const BufHandle& inp = c10::get<BufHandle>(inputs[0]);
const BufHandle& w = c10::get<BufHandle>(inputs[1]);
const BufHandle& b = c10::get<BufHandle>(inputs[2]);
const BufHandle& inp = std::get<BufHandle>(inputs[0]);
const BufHandle& w = std::get<BufHandle>(inputs[1]);
const BufHandle& b = std::get<BufHandle>(inputs[2]);
auto strides = _single_int_list(inputs[3]);
auto padding = _single_int_list(inputs[4]);
auto dilation = _single_int_list(inputs[5]);
int groups = c10::get<int64_t>(inputs[6]);
int groups = std::get<int64_t>(inputs[6]);
auto inpInfo = getTensorInfo(inp);
auto wInfo = getTensorInfo(w);
@ -443,8 +443,8 @@ Tensor computePrepackedConv2dClampRun(
}
BufHandle ResultBuf("prepacked_conv2d_clamp_run", outputShape, dtype);
const BufHandle& inp = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const BufHandle& inp = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
StmtPtr s = ExternalCall::make(
ResultBuf, "nnc_prepacked_conv2d_clamp_run", {inp, prepacked}, {});
return Tensor(ResultBuf.node(), s);
@ -462,8 +462,8 @@ Tensor computePrepackedLinearClampRun(
}
BufHandle ResultBuf("prepacked_linear_clamp_run", outputShape, dtype);
const BufHandle& inp = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const BufHandle& inp = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
StmtPtr s = ExternalCall::make(
ResultBuf, "nnc_prepacked_linear_clamp_run", {inp, prepacked}, {});
return Tensor(ResultBuf.node(), s);
@ -482,8 +482,8 @@ Tensor computeMkldnnPrepackedConvRun(
BufHandle ResultBuf(
"mkldnn_prepacked_conv_run", outputShape, outputStrides, dtype);
const BufHandle& inp = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const BufHandle& inp = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
StmtPtr s = ExternalCall::make(
ResultBuf, "nnc_mkldnn_prepacked_conv_run", {inp, prepacked}, {});
return Tensor(ResultBuf.node(), s);

View File

@ -16,8 +16,8 @@ Tensor computeMatmul(
dtype = Dtype(*outputType);
}
BufHandle ResultBuf("matmul", outputShape, dtype);
const BufHandle a = c10::get<BufHandle>(inputs[0]);
const BufHandle b = c10::get<BufHandle>(inputs[1]);
const BufHandle a = std::get<BufHandle>(inputs[0]);
const BufHandle b = std::get<BufHandle>(inputs[1]);
auto size_a = a.dims();
auto size_b = b.dims();
@ -68,11 +68,11 @@ Tensor computeAddMM(
ExternalCall::make(
ResultBuf,
"nnc_aten_addmm",
{c10::get<BufHandle>(inputs[0]),
c10::get<BufHandle>(inputs[1]),
c10::get<BufHandle>(inputs[2])},
{c10::get<int64_t>(inputs[3]),
c10::get<int64_t>(
{std::get<BufHandle>(inputs[0]),
std::get<BufHandle>(inputs[1]),
std::get<BufHandle>(inputs[2])},
{std::get<int64_t>(inputs[3]),
std::get<int64_t>(
inputs[4])})); // TODO: handle other dtypes of alpha and beta
}

View File

@ -249,7 +249,7 @@ std::vector<ExprHandle> broadcastShapes(
}
std::vector<ExprHandle> valueShape(const ArgValue& v) {
if (auto b = c10::get_if<tensorexpr::BufHandle>(&v)) {
if (auto b = std::get_if<tensorexpr::BufHandle>(&v)) {
return b->dims();
}
return {};
@ -258,14 +258,14 @@ std::vector<ExprHandle> valueShape(const ArgValue& v) {
ExprHandle tensorOrConstant(
const ArgValue& v,
const std::vector<ExprHandle>& axes) {
if (auto b = c10::get_if<BufHandle>(&v)) {
if (auto b = std::get_if<BufHandle>(&v)) {
return broadcast(*b, axes);
}
return constant(v);
}
ExprHandle scalarOrConstant(const ArgValue& v) {
if (auto vh = c10::get_if<VarHandle>(&v)) {
if (auto vh = std::get_if<VarHandle>(&v)) {
return *vh;
}
return constant(v);
@ -276,15 +276,15 @@ ExprHandle broadcast(BufHandle b, const std::vector<ExprHandle>& axes) {
}
ExprHandle constant(const ArgValue& v) {
if (auto s = c10::get_if<tensorexpr::VarHandle>(&v)) {
if (auto s = std::get_if<tensorexpr::VarHandle>(&v)) {
return *s;
} else if (auto d = c10::get_if<double>(&v)) {
} else if (auto d = std::get_if<double>(&v)) {
return DoubleImm::make(*d);
} else if (auto i = c10::get_if<int64_t>(&v)) {
} else if (auto i = std::get_if<int64_t>(&v)) {
return LongImm::make(*i);
} else if (auto b = c10::get_if<bool>(&v)) {
} else if (auto b = std::get_if<bool>(&v)) {
return BoolImm::make(*b);
} else if (c10::get_if<ArgNone>(&v)) {
} else if (std::get_if<ArgNone>(&v)) {
// This is just a placeholder so we don't throw. None-handling
// is operator-specific and should be handled properly in
// the operator-specific lowering code.
@ -327,10 +327,10 @@ Tensor computeChunk(
"prim_constantchunk",
outputShape,
[inputs](const std::vector<VarHandle>& axes) {
const auto& b = c10::get<BufHandle>(inputs[0]);
int64_t chunkIdx = c10::get<int64_t>(inputs[1]);
int64_t dim = c10::get<int64_t>(inputs[2]);
int64_t chunks = c10::get<int64_t>(inputs[3]);
const auto& b = std::get<BufHandle>(inputs[0]);
int64_t chunkIdx = std::get<int64_t>(inputs[1]);
int64_t dim = std::get<int64_t>(inputs[2]);
int64_t chunks = std::get<int64_t>(inputs[3]);
std::vector<ExprHandle> indices(axes.begin(), axes.end());
auto norm_dim = normalizeAndCheckIndex(dim, indices.size());
@ -357,7 +357,7 @@ Tensor computeTranspose(
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto A = c10::get<BufHandle>(inputs[0]);
auto A = std::get<BufHandle>(inputs[0]);
// Trivial case of 0-dim and 1-dim tensors: transpose is just a copy
if (A.ndim() <= 1) {
return Compute(
@ -369,8 +369,8 @@ Tensor computeTranspose(
});
}
// Usual case where transpose actually swaps dimensions
auto start_dim = at::maybe_wrap_dim(c10::get<int64_t>(inputs[1]), A.ndim());
auto to_dim = at::maybe_wrap_dim(c10::get<int64_t>(inputs[2]), A.ndim());
auto start_dim = at::maybe_wrap_dim(std::get<int64_t>(inputs[1]), A.ndim());
auto to_dim = at::maybe_wrap_dim(std::get<int64_t>(inputs[2]), A.ndim());
return Compute(
"aten_transpose", outputShape, [&](std::vector<VarHandle> axes) {
std::swap(axes[start_dim], axes[to_dim]);
@ -384,7 +384,7 @@ Tensor computeExpand(
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto A = c10::get<BufHandle>(inputs[0]);
auto A = std::get<BufHandle>(inputs[0]);
return Compute(
"aten_expand", outputShape, [&](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
@ -398,7 +398,7 @@ Tensor computeReshape(
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto A = c10::get<BufHandle>(inputs[0]);
auto A = std::get<BufHandle>(inputs[0]);
if (A.ndim() == 0) {
return Compute(
"aten_view", outputShape, [&](const std::vector<VarHandle>& axes) {
@ -513,7 +513,7 @@ static Tensor computeCatWoConditionals(
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides) {
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto input_list = c10::get<BufList>(inputs[0]);
auto input_list = std::get<BufList>(inputs[0]);
auto arg_dim = inputs[1];
auto cat_info = processCatList(input_list);
ScalarType high_type = cat_info.first;
@ -547,7 +547,7 @@ static Tensor computeCatWoConditionals(
output_buf, alloc<tensorexpr::Block>(std::vector<StmtPtr>({})));
}
int64_t concat_dim = c10::get<int64_t>(arg_dim);
int64_t concat_dim = std::get<int64_t>(arg_dim);
auto norm_concat_dim = normalizeAndCheckIndex(concat_dim, outputShape.size());
auto loop_order_fn = [&](const BufPtr& buf_) {
@ -628,7 +628,7 @@ Tensor computeCat(
return computeCatWoConditionals(inputs, outputShape, outputStrides);
}
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto inputList = c10::get<BufList>(inputs[0]);
auto inputList = std::get<BufList>(inputs[0]);
auto argDim = inputs[1];
auto catInfo = processCatList(inputList);
ScalarType highType = catInfo.first;
@ -642,7 +642,7 @@ Tensor computeCat(
return ExprHandle(0);
}
int64_t dim_ = c10::get<int64_t>(argDim);
int64_t dim_ = std::get<int64_t>(argDim);
auto dim = normalizeAndCheckIndex(dim_, axes.size());
// Promote input types.
// Note that we need to consider all inputs, including empty - they
@ -693,8 +693,8 @@ Tensor computeEmbedding(
}
BufHandle ResultBuf("emb", outputShape, dtype);
const BufHandle& w = c10::get<BufHandle>(inputs[0]);
const BufHandle& indices = c10::get<BufHandle>(inputs[1]);
const BufHandle& w = std::get<BufHandle>(inputs[0]);
const BufHandle& indices = std::get<BufHandle>(inputs[1]);
StmtPtr s =
ExternalCall::make(ResultBuf, "nnc_aten_embedding", {w, indices}, {});

View File

@ -14,11 +14,11 @@ Tensor computeBatchNorm(
bool hasWeight = true;
bool hasBias = true;
if (c10::get_if<ArgNone>(&inputs[1])) {
if (std::holds_alternative<ArgNone>(inputs[1])) {
hasWeight = false;
}
if (c10::get_if<ArgNone>(&inputs[2])) {
if (std::holds_alternative<ArgNone>(inputs[2])) {
hasBias = false;
}

View File

@ -11,10 +11,10 @@ namespace jit {
namespace tensorexpr {
namespace {
std::vector<int64_t> _pair_int(ArgValue v) {
if (auto t = c10::get_if<IntList>(&v)) {
if (auto t = std::get_if<IntList>(&v)) {
return {(*t)[0], (*t)[1]};
}
auto i = c10::get<int64_t>(v);
auto i = std::get<int64_t>(v);
return {i, i};
}
} // namespace
@ -161,7 +161,7 @@ Tensor computeQuantizePerTensor(
return Dtype(ScalarType::QUInt8);
}
throw malformed_input("Expected quantized dtype");
}(c10::get<int64_t>(inputs[3]));
}(std::get<int64_t>(inputs[3]));
ExprHandle e =
quant(tensorOrConstant(inputs[0], indices), dtype, qscale, qzero);
@ -183,14 +183,14 @@ Tensor computeQuantizedAdd(
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device) {
const BufHandle& QA = c10::get<BufHandle>(inputs[0]);
const BufHandle& QB = c10::get<BufHandle>(inputs[1]);
const BufHandle& QA = std::get<BufHandle>(inputs[0]);
const BufHandle& QB = std::get<BufHandle>(inputs[1]);
auto qa_scale = ExprHandle(QA.node()->qscale());
auto qa_zero = ExprHandle(QA.node()->qzero());
auto qb_scale = ExprHandle(QB.node()->qscale());
auto qb_zero = ExprHandle(QB.node()->qzero());
ExprHandle out_qscale = DoubleImm::make(c10::get<double>(inputs[2]));
ExprHandle out_qzero = LongImm::make(c10::get<int64_t>(inputs[3]));
ExprHandle out_qscale = DoubleImm::make(std::get<double>(inputs[2]));
ExprHandle out_qzero = LongImm::make(std::get<int64_t>(inputs[3]));
Dtype dequant_dtype = kFloat;
Dtype out_dtype = outputType ? Dtype(*outputType) : QA.dtype();
std::vector<VarPtr> vars;
@ -227,10 +227,10 @@ Tensor computeQuantizePerTensorExternalCall(
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
at::Device) {
const BufHandle& x = c10::get<BufHandle>(inputs[0]);
const auto qscale = c10::get<double>(inputs[1]);
const auto qzero = c10::get<int64_t>(inputs[2]);
const auto qdtype = c10::get<int64_t>(inputs[3]);
const BufHandle& x = std::get<BufHandle>(inputs[0]);
const auto qscale = std::get<double>(inputs[1]);
const auto qzero = std::get<int64_t>(inputs[2]);
const auto qdtype = std::get<int64_t>(inputs[3]);
const auto dtype = [](auto qdtype) {
if (static_cast<int64_t>(ScalarType::QInt8) == qdtype) {
@ -264,7 +264,7 @@ Tensor computeDequantizeExternalCall(
dtype = Dtype(*outputType);
}
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
const BufHandle& qx = std::get<BufHandle>(inputs[0]);
const int64_t qdtype = (int64_t)immQDType(qx);
BufHandle ResultBuf("dequantize", outputShape, dtype);
@ -290,12 +290,12 @@ Tensor computeQuantizedConv2dPrepack(
}
BufHandle ResultBuf("quantized_conv2d_prepack", outputShape, dtype);
const BufHandle& qw = c10::get<BufHandle>(inputs[0]);
const BufHandle& b = c10::get<BufHandle>(inputs[1]);
const BufHandle& qw = std::get<BufHandle>(inputs[0]);
const BufHandle& b = std::get<BufHandle>(inputs[1]);
auto strides = _pair_int(inputs[2]);
auto padding = _pair_int(inputs[3]);
auto dilation = _pair_int(inputs[4]);
int groups = c10::get<int64_t>(inputs[5]);
int groups = std::get<int64_t>(inputs[5]);
TORCH_INTERNAL_ASSERT(
qw.node()->qscale(),
buildErrorMessage(
@ -335,10 +335,10 @@ Tensor computeQuantizedConv1d(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const BufHandle& qx = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qx);
auto ResultBuf = makeQBufHandleChannelsLast(
@ -367,10 +367,10 @@ Tensor computeQuantizedConv2d(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const BufHandle& qx = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qx);
auto ResultBuf = makeQBufHandleChannelsLast(
@ -399,10 +399,10 @@ Tensor computeQuantizedConv2dRelu(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const BufHandle& qx = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qx);
auto ResultBuf = makeQBufHandleChannelsLast(
@ -431,10 +431,10 @@ Tensor computeQuantizedLinear(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const BufHandle& qx = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qx);
auto ResultBuf = makeQBufHandleContiguous(
@ -463,10 +463,10 @@ Tensor computeQuantizedLinearRelu(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const BufHandle& qx = std::get<BufHandle>(inputs[0]);
const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qx);
auto ResultBuf = makeQBufHandleContiguous(
@ -495,10 +495,10 @@ Tensor computeQuantizedAddExternalCall(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qa = c10::get<BufHandle>(inputs[0]);
const BufHandle& qb = c10::get<BufHandle>(inputs[1]);
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const BufHandle& qa = std::get<BufHandle>(inputs[0]);
const BufHandle& qb = std::get<BufHandle>(inputs[1]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qa);
const bool isQAChannelsLast = isChannelsLast(qa);
@ -539,10 +539,10 @@ Tensor computeQuantizedMul(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qa = c10::get<BufHandle>(inputs[0]);
const BufHandle& qb = c10::get<BufHandle>(inputs[1]);
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const BufHandle& qa = std::get<BufHandle>(inputs[0]);
const BufHandle& qb = std::get<BufHandle>(inputs[1]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qa);
auto ResultBuf = makeQBufHandleContiguous(
@ -570,8 +570,8 @@ Tensor computeQuantizedMulScalar(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qa = c10::get<BufHandle>(inputs[0]);
const auto scalar = c10::get<double>(inputs[1]);
const BufHandle& qa = std::get<BufHandle>(inputs[0]);
const auto scalar = std::get<double>(inputs[1]);
// Change to dtype based on outputType when dtype propagation implemented
const auto out_qdtype = immQDType(qa);
double scale1 = immQScale(qa);
@ -597,7 +597,7 @@ Tensor computeQuantizedRelu(
const c10::optional<ScalarType>& outputType,
// NOLINTNEXTLINE
at::Device device) {
const BufHandle& qa = c10::get<BufHandle>(inputs[0]);
const BufHandle& qa = std::get<BufHandle>(inputs[0]);
const auto out_qdtype = immQDType(qa);
const bool isQAChannelsLast = isChannelsLast(qa);
auto ResultBuf = isQAChannelsLast ? makeQBufHandleChannelsLast(
@ -629,12 +629,12 @@ Tensor computeQuantizedCat(
// NOLINTNEXTLINE
at::Device device) {
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto inputList = c10::get<BufList>(inputs[0]);
auto argDim = c10::get<int64_t>(inputs[1]);
auto inputList = std::get<BufList>(inputs[0]);
auto argDim = std::get<int64_t>(inputs[1]);
auto n = inputList.size();
// TODO: handle optional out_qscale, out_qzero
const auto out_qscale = c10::get<double>(inputs[2]);
const auto out_qzero = c10::get<int64_t>(inputs[3]);
const auto out_qscale = std::get<double>(inputs[2]);
const auto out_qzero = std::get<int64_t>(inputs[3]);
std::vector<BufHandle> args;
std::vector<ExprHandle> extra_args;
@ -669,7 +669,7 @@ Tensor computeDequantize(
if (outputType) {
dtype = Dtype(*outputType);
}
auto qx = c10::get<BufHandle>(inputs[0]);
auto qx = std::get<BufHandle>(inputs[0]);
TORCH_INTERNAL_ASSERT(
qx.node()->qscale(),
buildErrorMessage("Missing quantized scale for dequantize"));
@ -697,7 +697,7 @@ Tensor computeUpsampleNearest2d(
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device) {
auto A = c10::get<BufHandle>(inputs[0]);
auto A = std::get<BufHandle>(inputs[0]);
const auto& output_height = outputShape[2];
const auto& output_width = outputShape[3];
auto input_height = ExprHandle(A.dim(2));
@ -750,18 +750,18 @@ Tensor computeUpsampleNearest2dExternalCall(
}
int64_t output_size_h = -1;
int64_t output_size_w = -1;
if (auto output_sizes = c10::get_if<IntList>(&inputs[1])) {
if (auto output_sizes = std::get_if<IntList>(&inputs[1])) {
output_size_h = (*output_sizes)[0];
output_size_w = (*output_sizes)[1];
}
double scale_factor_h = -1.f;
double scale_factor_w = -1.f;
if (auto scale_factors = c10::get_if<DoubleList>(&inputs[2])) {
if (auto scale_factors = std::get_if<DoubleList>(&inputs[2])) {
scale_factor_h = (*scale_factors)[0];
scale_factor_w = (*scale_factors)[1];
}
const BufHandle& x = c10::get<BufHandle>(inputs[0]);
const BufHandle& x = std::get<BufHandle>(inputs[0]);
double qx_qscale = -1.f;
int64_t qx_qzero = -1l;
int64_t qx_qdtype = -1l;
@ -804,7 +804,7 @@ Tensor computeQuantizedSigmoidExternalCall(
// NOLINTNEXTLINE
const c10::optional<ScalarType>& outputType,
at::Device) {
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
const BufHandle& qx = std::get<BufHandle>(inputs[0]);
const auto out_qdtype = immQDType(qx);
const double out_qscale = 1.0f / 256.0f;

View File

@ -32,7 +32,7 @@ Tensor computeSum(
size_t rank = sizes.size();
if (inputs.size() > 2) {
if (auto emptyAxes = c10::get_if<BufList>(&inputs[1])) {
if (auto emptyAxes = std::get_if<BufList>(&inputs[1])) {
// If dim-array is an empty list, it will appear as BufList instead of
// IntList, and hence we need a special handling for it.
// In that case, we need to sum over all axes.
@ -40,7 +40,7 @@ Tensor computeSum(
axes.resize(rank);
std::iota(axes.begin(), axes.end(), 0);
} else if (rank > 0) {
auto nodeAxes = c10::get<IntList>(inputs[1]);
auto nodeAxes = std::get<IntList>(inputs[1]);
// Canonicalize axes: wrap around, sort and make unique.
for (auto axis : nodeAxes) {
axes.push_back(at::maybe_wrap_dim(axis, rank));
@ -48,7 +48,7 @@ Tensor computeSum(
std::sort(axes.begin(), axes.end());
axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
}
keepdim = c10::get<bool>(inputs[2]);
keepdim = std::get<bool>(inputs[2]);
} else {
axes.resize(rank);
std::iota(axes.begin(), axes.end(), 0);
@ -116,13 +116,13 @@ Tensor computeMean(
}
bool keepdim = false;
BufHandle ResultBuf("mean", outputShape, dtype);
BufHandle InputBuf = c10::get<BufHandle>(inputs[0]);
BufHandle InputBuf = std::get<BufHandle>(inputs[0]);
std::vector<ExprHandle> extra_args;
if (inputs.size() > 2) {
keepdim = c10::get<bool>(inputs[2]);
keepdim = std::get<bool>(inputs[2]);
}
if (auto mean_dims = c10::get_if<IntList>(&inputs[1])) {
if (auto mean_dims = std::get_if<IntList>(&inputs[1])) {
extra_args = c10::fmap<ExprHandle>(*mean_dims);
} else {
// When dims argument is not specified, reduce over all dimensions
@ -147,10 +147,10 @@ Tensor computeMax(
dtype = Dtype(*outputType);
}
BufHandle ResultBuf("max", outputShape, dtype);
BufHandle InputBuf = c10::get<BufHandle>(inputs[0]);
BufHandle InputBuf = std::get<BufHandle>(inputs[0]);
std::vector<ExprHandle> max_dims_expr;
auto max_dim = c10::get<int64_t>(inputs[1]);
auto keep_dim = c10::get<bool>(inputs[2]);
auto max_dim = std::get<int64_t>(inputs[1]);
auto keep_dim = std::get<bool>(inputs[2]);
return Tensor(
ResultBuf.node(),
ExternalCall::make(
@ -172,13 +172,13 @@ Tensor computeAdaptiveAvgPool2d(
}
BufHandle ResultBuf("adaptive_avgpool2d", outputShape, dtype);
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto out_size_param = c10::get<IntList>(inputs[1]);
auto out_size_param = std::get<IntList>(inputs[1]);
return Tensor(
ResultBuf.node(),
ExternalCall::make(
ResultBuf,
"nnc_aten_adaptive_avg_pool2d",
{c10::get<BufHandle>(inputs[0])},
{std::get<BufHandle>(inputs[0])},
c10::fmap<ExprHandle>(out_size_param)));
}

View File

@ -44,10 +44,10 @@ Tensor computeSoftmax(
// We do not handle None for dims (input 1) because that is supposed to
// be deprecated.
TORCH_INTERNAL_ASSERT(c10::get_if<int64_t>(&inputs[1]));
TORCH_INTERNAL_ASSERT(std::get_if<int64_t>(&inputs[1]));
int64_t rank = valueShape(inputs[0]).size();
size_t softmax_dim =
normalizeAndCheckIndex(c10::get<int64_t>(inputs[1]), rank);
normalizeAndCheckIndex(std::get<int64_t>(inputs[1]), rank);
std::vector<ExprHandle> non_softmax_dims;
for (size_t i = 0; i < outputShape.size(); ++i) {
if (i != softmax_dim) {
@ -93,10 +93,10 @@ Tensor computeSoftmax(
return new_indices;
};
auto inp_buf = c10::get<BufHandle>(inputs[0]);
auto inp_buf = std::get<BufHandle>(inputs[0]);
auto dtype = inp_buf.dtype();
if (auto d = c10::get_if<int64_t>(&inputs[2])) {
if (auto d = std::get_if<int64_t>(&inputs[2])) {
dtype = ToDtype(static_cast<ScalarType>(*d));
}

View File

@ -731,25 +731,25 @@ void initTensorExprBindings(PyObject* module) {
}))
.def(
"as_buf",
[](const ArgValue& self) { return c10::get<BufHandle>(self); })
[](const ArgValue& self) { return std::get<BufHandle>(self); })
.def(
"as_var",
[](const ArgValue& self) { return c10::get<VarHandle>(self); })
[](const ArgValue& self) { return std::get<VarHandle>(self); })
.def(
"as_float",
[](const ArgValue& self) { return c10::get<double>(self); })
[](const ArgValue& self) { return std::get<double>(self); })
.def(
"as_int",
[](const ArgValue& self) { return c10::get<int64_t>(self); })
.def("as_bool", [](const ArgValue& self) { return c10::get<bool>(self); })
[](const ArgValue& self) { return std::get<int64_t>(self); })
.def("as_bool", [](const ArgValue& self) { return std::get<bool>(self); })
.def(
"as_none",
[](const ArgValue& self) { return c10::get<ArgNone>(self); })
[](const ArgValue& self) { return std::get<ArgNone>(self); })
.def(
"as_buflist",
[](const ArgValue& self) { return c10::get<BufList>(self); })
[](const ArgValue& self) { return std::get<BufList>(self); })
.def("as_intlist", [](const ArgValue& self) {
return c10::get<IntList>(self);
return std::get<IntList>(self);
});
py::class_<c10::ScalarType>(te, "ScalarType");