[2/N] Change #include <c10/util/Optional.h> to #include <optional> (#130236)

Follows  #128301. The changes were made by grep and sed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130236
Approved by: https://github.com/ezyang
This commit is contained in:
cyy
2024-07-09 03:17:24 +00:00
committed by PyTorch MergeBot
parent d1e0653fad
commit 29861779ce
105 changed files with 361 additions and 356 deletions

View File

@ -222,7 +222,7 @@ c10::intrusive_ptr<c10::TensorImpl> CPUGeneratorImpl::get_state() const {
static const size_t size = sizeof(CPUGeneratorImplState);
static_assert(std::is_standard_layout_v<CPUGeneratorImplState>, "CPUGeneratorImplState is not a PODType");
auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
auto rng_state = state_tensor.data_ptr();
// accumulate generator data to be copied into byte tensor

View File

@ -59,7 +59,7 @@ class TORCH_API Context {
}
}
const AcceleratorHooksInterface& getAcceleratorHooksInterface(
std::optional<c10::DeviceType> opt_device_type = c10::nullopt) {
std::optional<c10::DeviceType> opt_device_type = std::nullopt) {
c10::DeviceType device_type = opt_device_type.has_value()
? opt_device_type.value()
: at::getAccelerator(true).value();
@ -407,7 +407,7 @@ class TORCH_API Context {
bool release_original_weights = false;
#endif
bool display_vmap_fallback_warnings_ = false;
std::optional<at::QEngine> quantized_engine = c10::nullopt;
std::optional<at::QEngine> quantized_engine = std::nullopt;
bool enable_sparse_tensor_invariant_checks = false;
bool allow_fp16_reduction_cpu = false;

View File

@ -17,14 +17,14 @@ namespace at {
/// Return the Device of a Tensor, if the Tensor is defined.
inline std::optional<Device> device_of(const Tensor& t) {
if (t.defined()) {
return c10::make_optional(t.device());
return std::make_optional(t.device());
} else {
return c10::nullopt;
return std::nullopt;
}
}
inline std::optional<Device> device_of(const std::optional<Tensor>& t) {
return t.has_value() ? device_of(t.value()) : c10::nullopt;
return t.has_value() ? device_of(t.value()) : std::nullopt;
}
/// Return the Device of a TensorList, if the list is non-empty and
@ -34,7 +34,7 @@ inline std::optional<Device> device_of(ITensorListRef t) {
if (!t.empty()) {
return device_of(t.front());
} else {
return c10::nullopt;
return std::nullopt;
}
}

View File

@ -76,7 +76,7 @@ TORCH_API TensorBase empty_cpu(
IntArrayRef size,
ScalarType dtype,
bool pin_memory = false,
std::optional<c10::MemoryFormat> memory_format_opt = c10::nullopt);
std::optional<c10::MemoryFormat> memory_format_opt = std::nullopt);
TORCH_API TensorBase empty_cpu(
IntArrayRef size,
@ -110,7 +110,7 @@ TORCH_API TensorBase empty_strided_cpu(
TORCH_API TensorBase empty_meta(
IntArrayRef size,
ScalarType dtype,
std::optional<c10::MemoryFormat> memory_format_opt = c10::nullopt);
std::optional<c10::MemoryFormat> memory_format_opt = std::nullopt);
TORCH_API TensorBase empty_meta(
IntArrayRef size,

View File

@ -321,8 +321,8 @@ Tensor FunctionalInverses::_nested_get_values_inverse(const Tensor& base, const
auto max_seqlen = at::_nested_get_max_seqlen(base);
auto nt = at::_nested_view_from_jagged(
mutated_view, offsets, dummy, lengths, ragged_idx,
(min_seqlen.defined() ? c10::optional<Tensor>(min_seqlen) : c10::nullopt),
(max_seqlen.defined() ? c10::optional<Tensor>(max_seqlen) : c10::nullopt));
(min_seqlen.defined() ? c10::optional<Tensor>(min_seqlen) : std::nullopt),
(max_seqlen.defined() ? c10::optional<Tensor>(max_seqlen) : std::nullopt));
if (inverse_return_mode != InverseReturnMode::NeverView) {
return nt;

View File

@ -531,9 +531,9 @@ Tensor to_functional_tensor(const Tensor& tensor) {
}
std::optional<Tensor> to_functional_tensor(const std::optional<Tensor>& tensor) {
if (tensor.has_value()) {
return c10::make_optional<Tensor>(to_functional_tensor(*tensor));
return std::make_optional<Tensor>(to_functional_tensor(*tensor));
}
return c10::nullopt;
return std::nullopt;
}
c10::List<::std::optional<Tensor>> to_functional_tensor(const c10::List<::std::optional<Tensor>>& t_list) {
c10::List<::std::optional<Tensor>> outputs;
@ -569,9 +569,9 @@ Tensor from_functional_tensor(const Tensor& tensor, bool assert_functional) {
}
std::optional<Tensor> from_functional_tensor(const std::optional<Tensor>& t, bool assert_functional) {
if (t.has_value()) {
return c10::make_optional<Tensor>(from_functional_tensor(*t, assert_functional));
return std::make_optional<Tensor>(from_functional_tensor(*t, assert_functional));
}
return c10::nullopt;
return std::nullopt;
}
std::vector<Tensor> from_functional_tensor(ITensorListRef t_list) {
std::vector<Tensor> outputs;

View File

@ -217,7 +217,7 @@ static at::Tensor lift_fresh_functionalize_copy(const at::Tensor & self) {
// we will end up hitting PreDispatch stack first. So, we should
// directly redispatch to the functionalize key manually.
static auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::clone", "").typed<at::Tensor(const at::Tensor &, std::optional<at::MemoryFormat>)>();
return op.redispatch(c10::DispatchKeySet({c10::DispatchKey::Functionalize}), self, c10::nullopt);
return op.redispatch(c10::DispatchKeySet({c10::DispatchKey::Functionalize}), self, std::nullopt);
}
at::AutoDispatchSkipFunctionalize guard;

View File

@ -173,7 +173,7 @@ NestedTensorImpl::NestedTensorImpl(
nested_sizes_(std::move(nested_sizes)),
nested_strides_(std::move(nested_strides)),
storage_offsets_(std::move(storage_offsets)),
opt_sizes_(c10::nullopt) {
opt_sizes_(std::nullopt) {
C10_LOG_API_USAGE_ONCE("torch.NestedTensor");
TORCH_WARN_ONCE(
"The PyTorch API of nested tensors is in prototype stage and will change "
@ -230,7 +230,7 @@ NestedTensorImpl::NestedTensorImpl(
nested_sizes_(std::move(nested_sizes)),
nested_strides_(std::move(nested_strides)),
storage_offsets_(std::move(storage_offsets)),
opt_sizes_(c10::nullopt) {
opt_sizes_(std::nullopt) {
validate_nested_tensor_metadata(nested_sizes_, nested_strides_, storage_offsets_);
refresh_dim();
set_custom_sizes_strides(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
@ -239,11 +239,11 @@ NestedTensorImpl::NestedTensorImpl(
std::optional<int64_t> NestedTensorImpl::opt_size(int64_t d) const {
if (C10_UNLIKELY(!opt_sizes_.has_value())) {
// Cache the metadata to avoid recomputing it each time.
opt_sizes_ = c10::make_optional(construct_opt_sizes(nested_sizes_));
opt_sizes_ = std::make_optional(construct_opt_sizes(nested_sizes_));
}
d = at::maybe_wrap_dim(d, dim(), false);
if ((*opt_sizes_)[d] == -1) {
return c10::nullopt;
return std::nullopt;
}
return (*opt_sizes_)[d];
}

View File

@ -27,7 +27,7 @@ Tensor scalar_tensor_static(const Scalar& s, std::optional<ScalarType> dtype_opt
at::tracer::impl::NoTracerDispatchMode tracer_guard;
at::AutoDispatchBelowAutograd mode;
Tensor result = at::detail::empty_cpu(
{}, dtype_opt, c10::nullopt, device_opt, c10::nullopt, c10::nullopt);
{}, dtype_opt, std::nullopt, device_opt, std::nullopt, std::nullopt);
scalar_fill(result, s);
return result;
}

View File

@ -29,7 +29,7 @@ constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1);
enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };
constexpr c10::nullopt_t None = c10::nullopt;
constexpr std::nullopt_t None = std::nullopt;
struct TORCH_API EllipsisIndexType final {
EllipsisIndexType() = default;
@ -39,9 +39,9 @@ TORCH_API extern const EllipsisIndexType Ellipsis;
struct TORCH_API Slice final {
public:
Slice(
std::optional<c10::SymInt> start_index = c10::nullopt,
std::optional<c10::SymInt> stop_index = c10::nullopt,
std::optional<c10::SymInt> step_index = c10::nullopt) {
std::optional<c10::SymInt> start_index = std::nullopt,
std::optional<c10::SymInt> stop_index = std::nullopt,
std::optional<c10::SymInt> step_index = std::nullopt) {
if (!step_index.has_value()) {
step_ = c10::SymInt(1);
} else {
@ -110,7 +110,7 @@ TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
// `torch.tensor([1, 2])`) | `torch::tensor({1, 2})`
struct TORCH_API TensorIndex final {
// Case 1: `at::indexing::None`
TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {}
TensorIndex(std::nullopt_t) : type_(TensorIndexType::None) {}
// Case 2: "..." / `at::indexing::Ellipsis`
TensorIndex(at::indexing::EllipsisIndexType)
@ -530,7 +530,7 @@ inline Tensor applySlicing(
auto& obj = indices[i];
// See NOTE [nested tensor size for indexing]
std::optional<SymIntArrayRef> result_sizes = result.is_nested()
? std::optional<SymIntArrayRef>(c10::nullopt)
? std::optional<SymIntArrayRef>(std::nullopt)
: std::optional<SymIntArrayRef>(result.sym_sizes());
result = handleDimInMultiDimIndexing(
/*prev_dim_result=*/result,
@ -606,7 +606,7 @@ inline Tensor get_item(
// as null may need to be changed after we reach a better solution for nested
// tensor size
std::optional<SymIntArrayRef> self_sizes = self.is_nested()
? std::optional<SymIntArrayRef>(c10::nullopt)
? std::optional<SymIntArrayRef>(std::nullopt)
: std::optional<SymIntArrayRef>(self.sym_sizes());
// handle simple types: integers, slices, none, ellipsis, bool

View File

@ -171,7 +171,7 @@ TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef sha
// This will bypass all shape checking in the TensorIterator. Kernels which call this method
// are expected to check shapes before calling `add_owned_input` or `add_owned_output`.
TORCH_CHECK(!resize_outputs_, "resize_outputs() must be called before declare_static_shape(...)")
static_shape_ = c10::make_optional(DimVector(shape));
static_shape_ = std::make_optional(DimVector(shape));
return *this;
}

View File

@ -147,7 +147,7 @@ struct TORCH_API OperandInfo {
/// promotion target_dtype value can become different from tensor's dtype
/// also, during type promotion target_dtype and device can be set for an
/// undefined tensor so that tensor can be properly constructed later.
std::optional<Device> device = c10::nullopt;
std::optional<Device> device = std::nullopt;
ScalarType target_dtype = ScalarType::Undefined;
// Caches dtype of the tensor, because scalar_type is an expensive operation
// If dtype of the tensor is changed (e.g. as a result of type promotion or in
@ -971,9 +971,9 @@ class TORCH_API TensorIteratorConfig final {
int num_outputs_ = 0;
int num_inputs_ = 0;
std::optional<DimVector> static_shape_ = c10::nullopt;
std::optional<ScalarType> static_dtype_ = c10::nullopt;
std::optional<Device> static_device_ = c10::nullopt;
std::optional<DimVector> static_shape_ = std::nullopt;
std::optional<ScalarType> static_dtype_ = std::nullopt;
std::optional<Device> static_device_ = std::nullopt;
bool check_mem_overlap_ = true;
bool allow_cpu_scalars_ = false;
bool is_reduction_ = false;

View File

@ -380,7 +380,7 @@ inline std::optional<ResultVec> computeStride_impl(
view_d--;
}
if (view_numel != tensor_numel) {
return c10::nullopt;
return std::nullopt;
}
if (tensor_d > 0) {
chunk_base_stride = oldstride[tensor_d - 1];
@ -390,7 +390,7 @@ inline std::optional<ResultVec> computeStride_impl(
}
}
if (view_d != -1) {
return c10::nullopt;
return std::nullopt;
}
return newstride;
}

View File

@ -304,7 +304,7 @@ inline std::optional<Tensor> cached_cast(
if (arg.has_value()) {
return cached_cast(to_type, *arg, device_type);
} else {
return c10::nullopt;
return std::nullopt;
}
}

View File

@ -1,7 +1,7 @@
#include <c10/core/Allocator.h>
#include <c10/util/Optional.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/llvmMathExtras.h>
#include <optional>
#include <deque>
#include <mutex>
@ -258,7 +258,6 @@ struct CachingHostAllocatorImpl {
}
virtual void process_events() {
while (true) {
// Avoid calling cudaEventDestroy while holding a mutex, so move
// intermediate events out of the lock into this object.
@ -350,7 +349,7 @@ struct CachingHostAllocatorImpl {
template <typename T>
struct CachingHostAllocatorInterface : public at::Allocator {
CachingHostAllocatorInterface() :impl_(std::make_unique<T>()) {}
CachingHostAllocatorInterface() : impl_(std::make_unique<T>()) {}
at::DataPtr allocate(size_t size) override {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for allocate");

View File

@ -7,7 +7,7 @@ check_tensor_options_and_extract_memory_format(
const TensorOptions& options,
std::optional<MemoryFormat> memory_format) {
TORCH_CHECK(
options.requires_grad_opt() == c10::nullopt ||
options.requires_grad_opt() == std::nullopt ||
options.requires_grad_opt().value() == false,
"Operators taking TensorOptions cannot take a TensorOptions with "
"options.requires_grad set as true. This isn't implemented yet.");

View File

@ -6,7 +6,7 @@
#include <c10/util/TypeList.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/order_preserving_flat_hash_map.h>
#include <c10/util/Optional.h>
#include <optional>
#include <ATen/core/TensorBody.h>
#include <ATen/core/jit_type_base.h>

View File

@ -57,7 +57,7 @@ Dimname Dimname::wildcard() {
return result;
}
optional<Dimname> Dimname::unify(Dimname other) const {
std::optional<Dimname> Dimname::unify(Dimname other) const {
if (other.type() == NameType::WILDCARD) {
return *this;
}
@ -67,7 +67,7 @@ optional<Dimname> Dimname::unify(Dimname other) const {
if (name_ == other.symbol()) {
return *this;
}
return c10::nullopt;
return std::nullopt;
}
bool Dimname::matches(Dimname other) const {

View File

@ -2,7 +2,7 @@
#include <ATen/core/symbol.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
#include <optional>
#include <ostream>
namespace at {

View File

@ -5,12 +5,12 @@
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>
#include <c10/util/MathConstants.h>
#include <c10/util/Optional.h>
#include <c10/macros/Macros.h>
#include <type_traits>
#include <limits>
#include <cmath>
#include <limits>
#include <optional>
#include <type_traits>
/**
* Distributions kernel adapted from THRandom.cpp

View File

@ -6,7 +6,7 @@ namespace at {
static std::mutex _generator_mutex_lock;
std::optional<GeneratorFuncType>& GetGeneratorPrivate() {
static std::optional<GeneratorFuncType> generator_privateuse1 = c10::nullopt;
static std::optional<GeneratorFuncType> generator_privateuse1 = std::nullopt;
return generator_privateuse1;
}

View File

@ -8,7 +8,7 @@
#include <c10/util/TypeList.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
#include <optional>
#include <vector>
namespace at {

View File

@ -1127,7 +1127,7 @@ TEST(ListTest, canAccessStringByReference) {
}
TEST(ListTest, canAccessOptionalStringByReference) {
List<std::optional<std::string>> list({"one", "two", c10::nullopt});
List<std::optional<std::string>> list({"one", "two", std::nullopt});
const auto& listRef = list;
static_assert(
std::is_same_v<decltype(listRef[1]), std::optional<std::reference_wrapper<const std::string>>>,

View File

@ -4,9 +4,9 @@
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>
#include <cstdint>
#include <optional>
#include <string>
namespace c10 {

View File

@ -33,7 +33,7 @@ struct StashTLSOnEntryGuard {
public:
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
StashTLSOnEntryGuard(): saved_(tls_on_entry.value()) {
tls_on_entry = c10::nullopt;
tls_on_entry = std::nullopt;
}
~StashTLSOnEntryGuard() {
@ -124,7 +124,7 @@ void preDispatchFallback(const c10::OperatorHandle& op, c10::DispatchKeySet disp
namespace at::impl {
RestorePythonTLSSnapshot::RestorePythonTLSSnapshot() : saved_(safe_get_tls_on_entry()), guard_(safe_get_tls_on_entry()) {
tls_on_entry = c10::nullopt;
tls_on_entry = std::nullopt;
}
RestorePythonTLSSnapshot::~RestorePythonTLSSnapshot() {
@ -143,7 +143,7 @@ MaybeSetTLSOnEntryGuard::MaybeSetTLSOnEntryGuard() {
MaybeSetTLSOnEntryGuard::~MaybeSetTLSOnEntryGuard() {
if (value_set_) {
TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
tls_on_entry = c10::nullopt;
tls_on_entry = std::nullopt;
}
}

View File

@ -16,7 +16,7 @@
#include <c10/util/ExclusivelyOwned.h>
#include <c10/util/ExclusivelyOwnedTensorTraits.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/Optional.h>
#include <optional>
#include <c10/util/intrusive_ptr.h>
#include <ATen/core/NamedTensor.h>
@ -147,7 +147,7 @@ class TORCH_API TensorBase {
const TensorBase& fill_(const c10::Scalar& scalar) const;
const TensorBase& zero_() const;
TensorBase to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, std::optional<at::MemoryFormat> memory_format=c10::nullopt) const;
TensorBase to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, std::optional<at::MemoryFormat> memory_format=std::nullopt) const;
bool is_complex() const {
return at::isComplexType(this->scalar_type());
@ -712,7 +712,7 @@ class TORCH_API TensorBase {
/// // f requires grad, has no operation creating it
/// @endcode
/// \fn void backward(const Tensor & gradient={}, std::optional<bool> retain_graph=c10::nullopt, bool create_graph=false, std::optional<TensorList> inputs=c10::nullopt) const;
/// \fn void backward(const Tensor & gradient={}, std::optional<bool> retain_graph=std::nullopt, bool create_graph=false, std::optional<TensorList> inputs=std::nullopt) const;
///
/// Computes the gradient of current tensor with respect to graph leaves.
///

View File

@ -1,16 +1,17 @@
#pragma once
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/ArrayRef.h>
#include <torch/library.h>
#include <optional>
namespace at::impl {
TORCH_API bool tensor_has_dispatch(const at::Tensor& t);
TORCH_API bool tensorlist_has_dispatch(at::ITensorListRef li);
TORCH_API bool tensorlist_has_dispatch(const c10::List<std::optional<at::Tensor>>& li);
TORCH_API bool tensorlist_has_dispatch(
const c10::List<std::optional<at::Tensor>>& li);
using c10::impl::dispatch_mode_enabled;
}
} // namespace at::impl

View File

@ -72,12 +72,12 @@ inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(c10::SymIn
template <>
inline typename remove_symint<std::optional<c10::SymInt>>::type unpackSymInt(std::optional<c10::SymInt> x) {
return x.has_value() ? c10::make_optional(x->guard_int(__FILE__, __LINE__)) : c10::nullopt;
return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__)) : std::nullopt;
}
template <>
inline typename remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(at::OptionalSymIntArrayRef x) {
return x.has_value() ? c10::make_optional(C10_AS_INTARRAYREF_SLOW(*x)) : c10::nullopt;
return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x)) : std::nullopt;
}
template<class Return, class... Args>

View File

@ -180,7 +180,7 @@ void boxed_func_for_outofplace_multi_op(const OperatorHandle& /*opHandle*/, Stac
// functional
void expectBoxedCallingWithReturnWorks(const KernelFunction& func) {
called_with_args = c10::nullopt;
called_with_args = std::nullopt;
vector<IValue> stack {3, 4};
OperatorHandle dummy = makeDummyOperatorHandle();
@ -194,7 +194,7 @@ void expectBoxedCallingWithReturnWorks(const KernelFunction& func) {
}
void expectBoxedCallingWithoutReturnWorks(const KernelFunction& func) {
called_with_args = c10::nullopt;
called_with_args = std::nullopt;
vector<IValue> stack {3, 4};
OperatorHandle dummy = makeDummyOperatorHandle();
@ -206,7 +206,7 @@ void expectBoxedCallingWithoutReturnWorks(const KernelFunction& func) {
}
void expectBoxedCallingWithMultiReturnWorks(const KernelFunction& func) {
called_with_args = c10::nullopt;
called_with_args = std::nullopt;
vector<IValue> stack {3, 4};
OperatorHandle dummy = makeDummyOperatorHandle();
@ -284,7 +284,7 @@ void expectOutOfPlaceMultiBoxedCallingWorks(const KernelFunction& func) {
// make an unboxed call to a kernel that returns a single value.
//
void expectUnboxedCallingWithReturnWorks(const KernelFunction& func) {
called_with_args = c10::nullopt;
called_with_args = std::nullopt;
OperatorHandle dummy = makeDummyOperatorHandle();
int64_t result = func.call<int64_t, int64_t, int64_t>(dummy, CPU_TEST_SET, 3, 4);
@ -297,7 +297,7 @@ void expectUnboxedCallingWithReturnWorks(const KernelFunction& func) {
// make an unboxed call to a kernel that returns nothing.
//
void expectUnboxedCallingWithoutReturnWorks(const KernelFunction& func) {
called_with_args = c10::nullopt;
called_with_args = std::nullopt;
OperatorHandle dummy = makeDummyOperatorHandle();
func.call<void, int64_t, int64_t>(dummy, CPU_TEST_SET, 3, 4);
@ -310,7 +310,7 @@ void expectUnboxedCallingWithoutReturnWorks(const KernelFunction& func) {
// When calling unboxed, multiple values are returned as a tuple.
//
void expectUnboxedCallingWithMultiReturnWorks(const KernelFunction& func) {
called_with_args = c10::nullopt;
called_with_args = std::nullopt;
OperatorHandle dummy = makeDummyOperatorHandle();
auto result = func.call<std::tuple<int64_t, int64_t>, int64_t, int64_t>(dummy, CPU_TEST_SET, 3, 4);

View File

@ -793,9 +793,9 @@ TEST(OperatorRegistrationTestLegacyFunctionBasedKernel, givenFallbackKernelWitho
EXPECT_EQ(4, outputs[0].toInt());
}
std::optional<Tensor> called_arg2 = c10::nullopt;
std::optional<int64_t> called_arg3 = c10::nullopt;
std::optional<std::string> called_arg4 = c10::nullopt;
std::optional<Tensor> called_arg2 = std::nullopt;
std::optional<int64_t> called_arg3 = std::nullopt;
std::optional<std::string> called_arg4 = std::nullopt;
void kernelWithOptInputWithoutOutput(Tensor arg1, const std::optional<Tensor>& arg2, std::optional<int64_t> arg3, std::optional<std::string> arg4) {
called = true;

View File

@ -550,9 +550,9 @@ TEST(OperatorRegistrationTestFunctionBasedKernel, givenFallbackKernelWithoutTens
EXPECT_EQ(4, outputs[0].toInt());
}
std::optional<Tensor> called_arg2 = c10::nullopt;
std::optional<int64_t> called_arg3 = c10::nullopt;
std::optional<std::string> called_arg4 = c10::nullopt;
std::optional<Tensor> called_arg2 = std::nullopt;
std::optional<int64_t> called_arg3 = std::nullopt;
std::optional<std::string> called_arg4 = std::nullopt;
void kernelWithOptInputWithoutOutput(Tensor arg1, const std::optional<Tensor>& arg2, std::optional<int64_t> arg3, std::optional<std::string> arg4) {
called = true;

View File

@ -732,9 +732,9 @@ TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenFallbackKernelWithout
TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) {
bool called = false;
std::optional<Tensor> called_arg2 = c10::nullopt;
std::optional<int64_t> called_arg3 = c10::nullopt;
std::optional<std::string> called_arg4 = c10::nullopt;
std::optional<Tensor> called_arg2 = std::nullopt;
std::optional<int64_t> called_arg3 = std::nullopt;
std::optional<std::string> called_arg4 = std::nullopt;
auto registrar = RegisterOperators().op(
"_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()",
@ -771,9 +771,9 @@ TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernelWithOptionalInp
TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) {
bool called = false;
std::optional<Tensor> called_arg2 = c10::nullopt;
std::optional<int64_t> called_arg3 = c10::nullopt;
std::optional<std::string> called_arg4 = c10::nullopt;
std::optional<Tensor> called_arg2 = std::nullopt;
std::optional<int64_t> called_arg3 = std::nullopt;
std::optional<std::string> called_arg4 = std::nullopt;
auto registrar = RegisterOperators().op(
"_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> Tensor?",

View File

@ -466,9 +466,9 @@ TEST(OperatorRegistrationTestLambdaBasedKernel, givenFallbackKernelWithoutTensor
EXPECT_EQ(4, outputs[0].toInt());
}
std::optional<Tensor> called_arg2 = c10::nullopt;
std::optional<int64_t> called_arg3 = c10::nullopt;
std::optional<std::string> called_arg4 = c10::nullopt;
std::optional<Tensor> called_arg2 = std::nullopt;
std::optional<int64_t> called_arg3 = std::nullopt;
std::optional<std::string> called_arg4 = std::nullopt;
TEST(OperatorRegistrationTestLambdaBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators().op(

View File

@ -668,9 +668,9 @@ TEST(OperatorRegistrationTestFunctorBasedKernel, givenFallbackKernelWithoutTenso
EXPECT_EQ(4, outputs[0].toInt());
}
std::optional<Tensor> called_arg2 = c10::nullopt;
std::optional<int64_t> called_arg3 = c10::nullopt;
std::optional<std::string> called_arg4 = c10::nullopt;
std::optional<Tensor> called_arg2 = std::nullopt;
std::optional<int64_t> called_arg3 = std::nullopt;
std::optional<std::string> called_arg4 = std::nullopt;
struct KernelWithOptInputWithoutOutput final : OperatorKernel {
void operator()(Tensor arg1, const std::optional<Tensor>& arg2, std::optional<int64_t> arg3, std::optional<std::string> arg4) {

View File

@ -631,7 +631,7 @@ std::optional<IValue> ClassType::findConstant(const std::string& name) const {
}
if (pos >= constantNames_.size()) {
return c10::nullopt;
return std::nullopt;
}
return constantValues_[pos];
}
@ -659,7 +659,7 @@ std::optional<ClassType::Property> ClassType::getProperty(const std::string& nam
}
}
return c10::nullopt;
return std::nullopt;
}
void ClassType::addProperty(const std::string& name, torch::jit::Function* getter, torch::jit::Function* setter) {
@ -676,7 +676,7 @@ std::optional<size_t> ClassType::findConstantSlot(const std::string& name) const
}
slot++;
}
return c10::nullopt;
return std::nullopt;
}
const std::string& ClassType::getConstantName(size_t slot) const {

View File

@ -4,7 +4,7 @@
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type_base.h>
#include <c10/util/Optional.h>
#include <optional>
namespace torch::jit {
@ -160,7 +160,7 @@ struct TORCH_API ClassType : public NamedType {
}
slot++;
}
return c10::nullopt;
return std::nullopt;
}
size_t getAttributeSlot(const std::string& name) const {
if (auto r = findAttributeSlot(name)) {

View File

@ -80,7 +80,7 @@ std::optional<OperatorHandle> Dispatcher::findOp(const OperatorName& overload_na
return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::optional<OperatorHandle> {
auto found = operatorLookupTable.find(overload_name);
if (found == operatorLookupTable.end()) {
return c10::nullopt;
return std::nullopt;
}
return found->second;
});
@ -93,7 +93,7 @@ void Dispatcher::waitForDef(const FunctionSchema& schema) {
using namespace std::chrono_literals;
std::unique_lock<std::mutex> lock(guard_->mutex);
bool r = cond_var_.wait_for(lock, 2s, [&]{
return findOp(schema.operator_name()) != c10::nullopt;
return findOp(schema.operator_name()) != std::nullopt;
});
TORCH_INTERNAL_ASSERT(r,
"Expected main interpreter to define ", schema.operator_name(),
@ -127,7 +127,7 @@ std::optional<OperatorHandle> Dispatcher::findSchema(const OperatorName& overloa
if (it->hasSchema()) {
return it;
} else {
return c10::nullopt;
return std::nullopt;
}
} else {
return it;
@ -164,7 +164,7 @@ const std::vector<OperatorName> Dispatcher::getAllOpNames() {
// are done
OperatorHandle Dispatcher::findOrRegisterName_(const OperatorName& op_name) {
const auto found = findOp(op_name);
if (found != c10::nullopt) {
if (found != std::nullopt) {
return *found;
}
@ -279,7 +279,7 @@ std::optional<std::pair<const char*, const char*>> Dispatcher::getPyStub(Operato
std::lock_guard<std::mutex> lock(guard_->mutex);
auto found = pythonModulesSingleton().find(op_name);
if (found == pythonModulesSingleton().end()) {
return c10::nullopt;
return std::nullopt;
}
return found->second;
}

View File

@ -97,7 +97,7 @@ void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug,
void OperatorEntry::deregisterSchema() {
TORCH_INTERNAL_ASSERT(schema_.has_value());
schema_ = c10::nullopt;
schema_ = std::nullopt;
dispatchKeyExtractor_.deregisterSchema();
}

View File

@ -3,7 +3,6 @@
#include <ATen/core/function_schema.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/Optional.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/PyHandleCache.h>
#include <c10/core/SafePyObject.h>
@ -16,8 +15,9 @@
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
#include <ATen/core/enum_tag.h>
#include <list>
#include <optional>
#include <array>
#include <list>
#ifdef C10_MOBILE
#define C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY

View File

@ -5,7 +5,7 @@
#include <type_traits>
#include <ATen/core/jit_type_base.h>
#include <c10/util/Optional.h>
#include <optional>
namespace c10 {

View File

@ -69,7 +69,7 @@ bool FunctionSchema::canAliasTypeSetsAlias(const std::optional<AliasTypeSet> &lh
std::optional<AliasTypeSet> FunctionSchema::getAliasTypeSetContainedTypes(const std::optional<AliasTypeSet> &aliasTypeSet) const {
if (!aliasTypeSet) {
return c10::nullopt;
return std::nullopt;
}
std::unordered_set<TypePtr> containedTypes;
std::stack<TypePtr> typeStack;
@ -114,7 +114,7 @@ std::optional<AliasTypeSet> FunctionSchema::mapTypeToAliasTypeSet(const TypePtr&
}
}
if (mutable_types.empty()) {
return c10::nullopt;
return std::nullopt;
}
return mutable_types;
}
@ -135,12 +135,12 @@ std::optional<AliasTypeSet> FunctionSchema::mapTypeToAliasTypeSet(const TypePtr&
}
}
if (mutable_types.empty()) {
return c10::nullopt;
return std::nullopt;
}
return {AliasTypeSet{TupleType::create(std::move(mutable_types))}};
}
default:
return c10::nullopt;
return std::nullopt;
}
}

View File

@ -29,20 +29,20 @@ struct Argument {
Argument(
std::string name = "",
const TypePtr& type = nullptr,
std::optional<int32_t> N = c10::nullopt,
std::optional<IValue> default_value = c10::nullopt,
std::optional<int32_t> N = std::nullopt,
std::optional<IValue> default_value = std::nullopt,
bool kwarg_only = false,
std::optional<AliasInfo> alias_info = c10::nullopt)
std::optional<AliasInfo> alias_info = std::nullopt)
: Argument(std::move(name), type, type, N, std::move(default_value), kwarg_only, std::move(alias_info)) {}
Argument(
std::string name,
TypePtr fake_type,
TypePtr real_type,
std::optional<int32_t> N = c10::nullopt,
std::optional<IValue> default_value = c10::nullopt,
std::optional<int32_t> N = std::nullopt,
std::optional<IValue> default_value = std::nullopt,
bool kwarg_only = false,
std::optional<AliasInfo> alias_info = c10::nullopt)
std::optional<AliasInfo> alias_info = std::nullopt)
: name_(std::move(name)),
type_(fake_type ? std::move(fake_type) : TensorType::get()),
real_type_(real_type ? std::move(real_type) : type_),
@ -150,7 +150,7 @@ struct Argument {
N_,
default_value_,
kwarg_only_,
alias_info_ ? std::optional<AliasInfo>(*alias_info_) : c10::nullopt);
alias_info_ ? std::optional<AliasInfo>(*alias_info_) : std::nullopt);
}
// this function checks whether this Argument is backward compatible with
@ -397,7 +397,7 @@ struct TORCH_API FunctionSchema {
bool is_mutable(c10::string_view name) const {
std::optional<int> index = argumentIndexWithName(name);
TORCH_INTERNAL_ASSERT(
index != c10::nullopt, "Schema has no argument named ", name);
index != std::nullopt, "Schema has no argument named ", name);
return is_mutable({c10::SchemaArgType::input, static_cast<size_t>(*index)});
}
@ -436,7 +436,7 @@ struct TORCH_API FunctionSchema {
if(name == arguments()[i].name())
return i;
}
return c10::nullopt;
return std::nullopt;
}
FunctionSchema cloneWithName(std::string name, std::string overload_name) const {
return FunctionSchema(
@ -470,8 +470,8 @@ struct TORCH_API FunctionSchema {
std::string formatTypeMismatchMsg(
const Argument& expected,
const std::string& actual_type,
std::optional<size_t> position = c10::nullopt,
std::optional<std::string> value = c10::nullopt) const;
std::optional<size_t> position = std::nullopt,
std::optional<std::string> value = std::nullopt) const;
FunctionSchema cloneWithRemappedTypes(
const std::function<TypePtr(TypePtr)> type_map) const;

View File

@ -820,7 +820,7 @@ struct TORCH_API IValue final {
IValue(std::optional<T> v);
template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
IValue(c10::OptionalArrayRef<T> v);
IValue(c10::nullopt_t);
IValue(std::nullopt_t);
// ClassType
IValue(c10::intrusive_ptr<ivalue::Object> v);
@ -1145,10 +1145,10 @@ struct TORCH_API IValue final {
// TODO: There are several places that recurse over IValue. This is fragile.
// This visitor should be used to recurse over ivalues.
void visit(const std::function<bool(const IValue&)>& visitor) const;
IValue deepcopy(std::optional<at::Device> device = c10::nullopt) const;
IValue deepcopy(std::optional<at::Device> device = std::nullopt) const;
IValue deepcopy(
HashIdentityIValueMap& memo,
std::optional<at::Device> device = c10::nullopt) const;
std::optional<at::Device> device = std::nullopt) const;
private:
static c10::intrusive_ptr_target* null_to_undefined_tensor(
@ -1523,24 +1523,24 @@ struct TORCH_API WeakTypePtr {
struct WeakOrStrongCompilationUnit {
explicit WeakOrStrongCompilationUnit(
std::shared_ptr<torch::jit::CompilationUnit> shared_cu)
: strong_ptr_(std::move(shared_cu)), weak_ptr_(c10::nullopt) {}
: strong_ptr_(std::move(shared_cu)), weak_ptr_(std::nullopt) {}
explicit WeakOrStrongCompilationUnit(
std::weak_ptr<torch::jit::CompilationUnit> weak_cu)
: strong_ptr_(c10::nullopt), weak_ptr_(std::move(weak_cu)) {}
: strong_ptr_(std::nullopt), weak_ptr_(std::move(weak_cu)) {}
std::shared_ptr<torch::jit::CompilationUnit> getStrongRefOrThrow() const {
TORCH_INTERNAL_ASSERT(strong_ptr_ != c10::nullopt);
TORCH_INTERNAL_ASSERT(strong_ptr_ != std::nullopt);
return *strong_ptr_;
}
std::weak_ptr<torch::jit::CompilationUnit> getWeakRefOrThrow() const {
TORCH_INTERNAL_ASSERT(weak_ptr_ != c10::nullopt);
TORCH_INTERNAL_ASSERT(weak_ptr_ != std::nullopt);
return *weak_ptr_;
}
bool holdingStrongRef() const {
return strong_ptr_ != c10::nullopt;
return strong_ptr_ != std::nullopt;
}
bool holdingEmptyStrongRef() const {

View File

@ -2,6 +2,7 @@
#include <condition_variable>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>
@ -909,7 +910,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
using WeakStorage = c10::weak_intrusive_ptr<c10::StorageImpl>;
void markCompleted(
IValue value,
std::optional<std::vector<WeakStorage>> storages = c10::nullopt) {
std::optional<std::vector<WeakStorage>> storages = std::nullopt) {
// Start by performing all steps that can throw, before setting any field.
// Do this before even acquiring the mutex, because extractStorages might
// acquire the GIL, which could lead to a lock inversion with our mutex.
@ -1586,11 +1587,11 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
c10::intrusive_ptr<Object> copy() const;
c10::intrusive_ptr<Object> deepcopy(
std::optional<at::Device> device = c10::nullopt) const;
std::optional<at::Device> device = std::nullopt) const;
c10::intrusive_ptr<Object> deepcopy(
IValue::HashIdentityIValueMap& memo,
std::optional<at::Device> device = c10::nullopt) const;
std::optional<at::Device> device = std::nullopt) const;
bool is_weak_compilation_ref() const {
return !type_.holds_strong_ref();
@ -1613,7 +1614,7 @@ struct ivalue::PyObjectHolder : c10::intrusive_ptr_target {
public:
virtual PyObject* getPyObject() = 0;
virtual c10::InferredType tryToInferType() = 0;
virtual IValue toIValue(const TypePtr& type, std::optional<int32_t> N = c10::nullopt) = 0;
virtual IValue toIValue(const TypePtr& type, std::optional<int32_t> N = std::nullopt) = 0;
virtual std::string toStr() = 0;
virtual std::vector<at::Tensor> extractTensors() = 0;
@ -1911,7 +1912,7 @@ std::unordered_map<K, V> generic_to(
template <typename T>
std::optional<T> generic_to(IValue ivalue, _fake_type<std::optional<T>>) {
if (ivalue.isNone()) {
return c10::nullopt;
return std::nullopt;
}
return std::move(ivalue).to<T>();
}
@ -2280,7 +2281,7 @@ inline IValue::IValue(std::optional<T> v) : IValue() {
}
}
inline IValue::IValue(c10::nullopt_t) : IValue() {}
inline IValue::IValue(std::nullopt_t) : IValue() {}
inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
: tag(Tag::Object) {
@ -2363,7 +2364,7 @@ inline const std::string& IValue::toStringRef() const {
inline std::optional<std::reference_wrapper<const std::string>> IValue::
toOptionalStringRef() const {
if (isNone()) {
return c10::nullopt;
return std::nullopt;
}
AT_ASSERT(isString(), "Expected optional<string> but got ", tagKind());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(

View File

@ -8,7 +8,7 @@
#include <ATen/core/type_factory.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/TypeList.h>
#include <c10/util/Optional.h>
#include <optional>
#include <c10/core/SymFloat.h>
#include <c10/core/SymBool.h>
#include <c10/core/Device.h>
@ -187,7 +187,7 @@ struct OptionalType;
using OptionalTypePtr = std::shared_ptr<OptionalType>;
// This type represents an optional type. There is one `Optional` for
// each element type. `Optional[T]` can accept both `T` and
// `None`(`c10::nullopt` in C++)
// `None`(`std::nullopt` in C++)
// Subtype hierarchy for Optional:
// - Optional[T] <: Optional[R] iff T <: R
// - T <: Optional[R] if T <: R
@ -372,10 +372,10 @@ inline ShapeSymbol merge_primitive(
// dims, partially known and fully known shapes are all supported.
struct TORCH_API SymbolicShape {
// Unranked shape constructor.
SymbolicShape() : dims_(c10::nullopt) {}
SymbolicShape() : dims_(std::nullopt) {}
// Known rank but unknown dimentions.
SymbolicShape(std::optional<size_t> rank) : dims_(c10::nullopt) {
SymbolicShape(std::optional<size_t> rank) : dims_(std::nullopt) {
if(!rank) {
return;
}
@ -432,7 +432,7 @@ struct TORCH_API SymbolicShape {
// Returns rank or nullopt in case of unranked shape.
std::optional<size_t> rank() const {
if(!dims_) {
return c10::nullopt;
return std::nullopt;
}
return dims_->size();
}
@ -443,7 +443,7 @@ struct TORCH_API SymbolicShape {
std::optional<std::vector<bool>> symbolicDims() const {
if (!dims_) {
return c10::nullopt;
return std::nullopt;
}
auto symbolic_dims = std::vector<bool>();
for (const ShapeSymbol& s : *dims_) {
@ -505,7 +505,7 @@ struct VaryingShape {
VaryingShape(c10::ArrayRef<T> vec)
: VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {}
VaryingShape(std::optional<size_t> size = c10::nullopt) : dims_(c10::nullopt) {
VaryingShape(std::optional<size_t> size = std::nullopt) : dims_(std::nullopt) {
if (size) {
dims_ = ListOfOptionalElements(*size);
}
@ -528,7 +528,7 @@ struct VaryingShape {
std::optional<size_t> size() const {
if (!dims_) {
return c10::nullopt;
return std::nullopt;
}
const auto& dims = dims_.value();
return dims.size();
@ -542,13 +542,13 @@ struct VaryingShape {
std::optional<std::vector<T>> concrete_sizes() const {
if (!dims_) {
return c10::nullopt;
return std::nullopt;
}
std::vector<T> sizes;
sizes.reserve(dims_.value().size());
for (auto d : *dims_) {
if (!d) {
return c10::nullopt;
return std::nullopt;
}
sizes.push_back(d.value());
}
@ -780,7 +780,7 @@ struct TORCH_API TensorType : public SharedType {
TensorTypePtr withPossiblyUndefined() {
auto r = clone();
r->undefined_ = c10::nullopt;
r->undefined_ = std::nullopt;
return r;
}
@ -854,9 +854,9 @@ struct TORCH_API TensorType : public SharedType {
// with `withUndefined`
// This will also mean that `undefined` tensors will fail
// `subtypeOf(TensorType::get())` check
// undefined_ may become `c10::nullopt` if the tensor was observed to be both
// undefined_ may become `std::nullopt` if the tensor was observed to be both
// defined and undefined. However, no tensor type starts out with
// `undefined_` set to `c10::nullopt`
// `undefined_` set to `std::nullopt`
std::optional<bool> undefined_;
// Represents whether or not this type was inferred.
bool is_inferred_ = false;
@ -1161,7 +1161,7 @@ struct TORCH_API TupleType : public NamedType {
std::vector<TypePtr> types) {
return TupleTypePtr(new TupleType(
std::move(types),
c10::nullopt,
std::nullopt,
nullptr)); // NOLINT(modernize-make-shared)
}
static TupleTypePtr create() {
@ -1739,7 +1739,7 @@ inline TypePtr TensorType::fromNumberType(const Type& typ) {
} else if (typ.isSubtypeOf(*BoolType::get())) {
return TensorType::createContiguous(at::kBool, at::kCPU, {});
} else if (typ.kind() == NumberType::Kind) {
return TensorType::create(c10::nullopt, at::kCPU, {}, c10::nullopt);
return TensorType::create(std::nullopt, at::kCPU, {}, std::nullopt);
}
TORCH_CHECK(false, "Unknown number type: ", typ.str());
}
@ -1755,7 +1755,7 @@ inline std::optional<c10::ScalarType> tryScalarTypeFromJitType(const Type& type)
} else if (type == *BoolType::get()) {
return at::ScalarType::Bool;
}
return c10::nullopt;
return std::nullopt;
}
inline at::ScalarType scalarTypeFromJitType(const Type& type) {
@ -2040,7 +2040,7 @@ struct getMaybeFakeTypePtr_<c10::Dict<K, V>, fake> final {
};
template <class T, bool fake>
struct getMaybeFakeTypePtr_<at::optional<T>, fake> final {
struct getMaybeFakeTypePtr_<std::optional<T>, fake> final {
static const auto& call() {
static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
// The "per optional<T>" static singleton needs to live in a .cpp file,
@ -2131,7 +2131,7 @@ struct MatchTypeReturn {
private:
MatchTypeReturn()
: reason_(c10::nullopt) {}
: reason_(std::nullopt) {}
std::optional<std::string> reason_; // is there is no match, this contains the reason
};

View File

@ -14,7 +14,7 @@
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <optional>
namespace c10 {
@ -73,7 +73,7 @@ struct Type;
struct SharedType;
// Use this to customize how a Type is printed using `annotation_str()`. If
// c10::nullopt is returned, `annotation_str()` falls through to its default
// std::nullopt is returned, `annotation_str()` falls through to its default
// implementation.
using TypePrinter = std::function<std::optional<std::string>(const Type&)>;

View File

@ -61,7 +61,7 @@ void Library::reset() {
Library::Library(Kind kind, std::string ns, std::optional<c10::DispatchKey> k, const char* file, uint32_t line)
: kind_(kind)
, ns_(ns == "_" ? c10::nullopt : c10::make_optional(std::move(ns)))
, ns_(ns == "_" ? std::nullopt : std::make_optional(std::move(ns)))
, dispatch_key_(k.value_or(CatchAll) == CatchAll ? std::optional<c10::DispatchKey>() : k)
, file_(file)
, line_(line)

View File

@ -88,7 +88,7 @@ std::optional<std::string> findSchemaDifferences(
}
// no differences found
return c10::nullopt;
return std::nullopt;
}
} // namespace c10

View File

@ -71,7 +71,7 @@ c10::FunctionSchema RegisterOperators::inferSchemaFromKernels_(
opName,
" because there is no kernel specified.");
std::optional<FunctionSchema> inferred_schema = c10::nullopt;
std::optional<FunctionSchema> inferred_schema = std::nullopt;
for (const auto& kernel : options.kernels) {
if (nullptr != kernel.inferred_function_schema.get()) {
if (!inferred_schema.has_value()) {

View File

@ -76,7 +76,7 @@ public:
// internal-only for registering stack based catch-all kernels
template<KernelFunction::BoxedKernelFunction* kernel_func>
Options&& catchAllKernel() && {
return std::move(*this).kernel(c10::nullopt, KernelFunction::makeFromBoxedFunction<kernel_func>(), nullopt, nullptr);
return std::move(*this).kernel(std::nullopt, KernelFunction::makeFromBoxedFunction<kernel_func>(), nullopt, nullptr);
}
// internal only for registering caffe2 ops
@ -215,7 +215,7 @@ public:
static_assert(std::is_constructible<KernelFunctor, ConstructorParameters...>::value, "Wrong argument list for constructor of kernel functor. The arguments to kernel<Functor>(arguments...) must match one of the constructors of Functor.");
return std::move(*this).kernel(
c10::nullopt,
std::nullopt,
KernelFunction::makeFromUnboxedFunctor<false, KernelFunctor>(std::make_unique<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...)),
impl::CppSignature::make<KernelFunctor>(),
detail::inferFunctionSchemaFromFunctor<KernelFunctor>()
@ -272,7 +272,7 @@ public:
static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr");
return std::move(*this).kernel(
c10::nullopt,
std::nullopt,
KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)),
impl::CppSignature::make<FuncType>(),
// TODO Do schema inference without relying on WrapFunctionIntoFunctor
@ -302,7 +302,7 @@ public:
TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr");
return std::move(*this).kernel(
c10::nullopt,
std::nullopt,
KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func),
impl::CppSignature::make<FuncType>(),
// TODO Do schema inference without relying on WrapFunctionIntoFunctor
@ -384,7 +384,7 @@ public:
static_assert(guts::is_stateless_lambda<std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel<Functor>() instead.");
return std::move(*this).kernel(
c10::nullopt,
std::nullopt,
KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(lambda)),
impl::CppSignature::make<Lambda>(),
// TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
@ -410,18 +410,18 @@ public:
}
Options()
: schemaOrName_(c10::nullopt)
: schemaOrName_(std::nullopt)
, kernels()
, aliasAnalysisKind_(c10::nullopt)
, aliasAnalysisKind_(std::nullopt)
{}
// KernelRegistrationConfig accumulates all information from the config
// parameters passed to a RegisterOperators::op() call into one object.
struct KernelRegistrationConfig final {
KernelRegistrationConfig()
: dispatch_key(c10::nullopt)
: dispatch_key(std::nullopt)
, func()
, cpp_signature(c10::nullopt)
, cpp_signature(std::nullopt)
, inferred_function_schema(nullptr)
{}
@ -522,7 +522,7 @@ public:
op(const std::string& schemaOrName, FuncType* func, Options&& options = RegisterOperators::options()) && {
constexpr bool AllowLegacyTypes = true;
return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
c10::nullopt,
std::nullopt,
KernelFunction::makeFromUnboxedRuntimeFunction<AllowLegacyTypes>(func),
impl::CppSignature::make<FuncType>(),
// TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
@ -553,7 +553,7 @@ public:
constexpr bool AllowLegacyTypes = true;
return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
c10::nullopt,
std::nullopt,
KernelFunction::makeFromUnboxedLambda<AllowLegacyTypes>(std::forward<Lambda>(lambda)),
impl::CppSignature::make<Lambda>(),
// TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
@ -570,7 +570,7 @@ public:
constexpr bool AllowLegacyTypes = true;
return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
c10::nullopt,
std::nullopt,
KernelFunction::makeFromUnboxedLambda<AllowLegacyTypes>(std::forward<Lambda>(lambda)),
impl::CppSignature::make<Lambda>(),
// TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor

View File

@ -909,28 +909,28 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
// optional types (with has_value() == false)
testArgTypes<std::optional<double>>::test(
std::optional<double>(c10::nullopt), [] (const std::optional<double>& v) {EXPECT_FALSE(v.has_value());},
std::optional<double>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
std::optional<double>(std::nullopt), [] (const std::optional<double>& v) {EXPECT_FALSE(v.has_value());},
std::optional<double>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
"(float? a) -> float?");
testArgTypes<std::optional<int64_t>>::test(
std::optional<int64_t>(c10::nullopt), [] (const std::optional<int64_t>& v) {EXPECT_FALSE(v.has_value());},
std::optional<int64_t>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
std::optional<int64_t>(std::nullopt), [] (const std::optional<int64_t>& v) {EXPECT_FALSE(v.has_value());},
std::optional<int64_t>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
"(int? a) -> int?");
testArgTypes<std::optional<bool>>::test(
std::optional<bool>(c10::nullopt), [] (const std::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
std::optional<bool>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
std::optional<bool>(std::nullopt), [] (const std::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
std::optional<bool>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
"(bool? a) -> bool?");
testArgTypes<std::optional<bool>>::test(
std::optional<bool>(c10::nullopt), [] (const std::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
std::optional<bool>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
std::optional<bool>(std::nullopt), [] (const std::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
std::optional<bool>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
"(bool? a) -> bool?");
testArgTypes<std::optional<std::string>>::test(
std::optional<std::string>(c10::nullopt), [] (const std::optional<std::string>& v) {EXPECT_FALSE(v.has_value());},
std::optional<std::string>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
std::optional<std::string>(std::nullopt), [] (const std::optional<std::string>& v) {EXPECT_FALSE(v.has_value());},
std::optional<std::string>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
"(str? a) -> str?");
testArgTypes<std::optional<Tensor>>::test(
std::optional<Tensor>(c10::nullopt), [] (const std::optional<Tensor>& v) {EXPECT_FALSE(v.has_value());},
std::optional<Tensor>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
std::optional<Tensor>(std::nullopt), [] (const std::optional<Tensor>& v) {EXPECT_FALSE(v.has_value());},
std::optional<Tensor>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
"(Tensor? a) -> Tensor?");
@ -1136,8 +1136,8 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
// Test optional of list (with nullopt)
testArgTypes<std::optional<c10::List<int64_t>>>::test(
std::optional<c10::List<int64_t>>(c10::nullopt), [] (const std::optional<c10::List<int64_t>>& v) {EXPECT_FALSE(v.has_value());},
std::optional<c10::List<int64_t>>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
std::optional<c10::List<int64_t>>(std::nullopt), [] (const std::optional<c10::List<int64_t>>& v) {EXPECT_FALSE(v.has_value());},
std::optional<c10::List<int64_t>>(std::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
"(int[]? a) -> int[]?");
// Test optional of list (with empty list)
@ -1160,8 +1160,8 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
// Test list of optional (with values)
testArgTypes<c10::List<::std::optional<int64_t>>>::test(
c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({3, c10::nullopt, 2})), [] (const c10::List<::std::optional<int64_t>>& v) {expectListEquals<std::optional<int64_t>>({3, c10::nullopt, 2}, v);},
c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({3, c10::nullopt, 2})), [] (const IValue& v) {expectListEquals<std::optional<int64_t>>({3, c10::nullopt, 2}, v.to<c10::List<::std::optional<int64_t>>>());},
c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({3, std::nullopt, 2})), [] (const c10::List<::std::optional<int64_t>>& v) {expectListEquals<std::optional<int64_t>>({3, std::nullopt, 2}, v);},
c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({3, std::nullopt, 2})), [] (const IValue& v) {expectListEquals<std::optional<int64_t>>({3, std::nullopt, 2}, v.to<c10::List<::std::optional<int64_t>>>());},
"(int?[] a) -> int?[]");
// dict types
@ -2141,7 +2141,7 @@ TEST(OperatorRegistrationTest, callKernelsWithDispatchKeySetConvention_mixedCall
TEST(OperatorRegistrationTest, getRegistrationsForDispatchKey) {
// should return every registered op
auto all_ops = Dispatcher::singleton().getRegistrationsForDispatchKey(c10::nullopt);
auto all_ops = Dispatcher::singleton().getRegistrationsForDispatchKey(std::nullopt);
// should return every registered op with a cpu kernel
auto cpu_ops = Dispatcher::singleton().getRegistrationsForDispatchKey(c10::DispatchKey::CPU);
ASSERT_TRUE(all_ops.size() > 0);

View File

@ -2,11 +2,11 @@
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/string_view.h>
#include <optional>
#include <ostream>
#include <string>
#include <utility>
#include <ostream>
namespace c10 {
@ -26,9 +26,9 @@ struct OperatorName final {
std::optional<c10::string_view> getNamespace() const {
auto pos = name.find("::");
if (pos == std::string::npos) {
return c10::nullopt;
return std::nullopt;
} else {
return c10::make_optional(c10::string_view(name.data(), pos));
return std::make_optional(c10::string_view(name.data(), pos));
}
}
@ -39,7 +39,8 @@ struct OperatorName final {
const auto old_name_size = name.size();
name.resize(ns_len + 2 + old_name_size);
// Shift current value of name to the end of the new space.
name.replace(name.size() - old_name_size, old_name_size, name, 0, old_name_size);
name.replace(
name.size() - old_name_size, old_name_size, name, 0, old_name_size);
name.replace(0, ns_len, ns, ns_len);
name[ns_len] = ':';
name[ns_len + 1] = ':';
@ -56,7 +57,9 @@ struct OperatorName final {
struct OperatorNameView final {
c10::string_view name;
c10::string_view overload_name;
constexpr OperatorNameView(c10::string_view name, c10::string_view overload_name)
constexpr OperatorNameView(
c10::string_view name,
c10::string_view overload_name)
: name(name), overload_name(overload_name) {}
// Parses strings like "foo.overload" and also "foo"
constexpr static OperatorNameView parse(c10::string_view full_name) {
@ -83,10 +86,11 @@ TORCH_API std::ostream& operator<<(std::ostream&, const OperatorName&);
} // namespace c10
namespace std {
template <>
struct hash<::c10::OperatorName> {
template <>
struct hash<::c10::OperatorName> {
size_t operator()(const ::c10::OperatorName& x) const {
return std::hash<std::string>()(x.name) ^ (~ std::hash<std::string>()(x.overload_name));
return std::hash<std::string>()(x.name) ^
(~std::hash<std::string>()(x.overload_name));
}
};
}
};
} // namespace std

View File

@ -350,7 +350,7 @@ VaryingShape<int64_t> TensorType::sizes() const {
// we turn symbolic shapes into unknowns
return ss.is_static()
? std::optional<int64_t>(ss.static_size())
: c10::nullopt;
: std::nullopt;
}));
}
@ -456,7 +456,7 @@ TensorTypePtr TensorType::createContiguous(
device,
VaryingShape<int64_t>(sizes),
VaryingShape<int64_t>(strides),
c10::nullopt);
std::nullopt);
}
const SymbolicShape& TensorType::symbolic_sizes() const {

View File

@ -403,14 +403,14 @@ static std::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t
auto tuple1 = t1->castRaw<TupleType>();
auto tuple2 = t2->castRaw<TupleType>();
if (tuple1->elements().size() != tuple2->elements().size()) {
return c10::nullopt;
return std::nullopt;
}
std::vector<TypePtr> elements;
for (size_t i = 0; i < tuple1->elements().size(); i++) {
if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i), default_to_union)) {
elements.push_back(*std::move(elem));
} else {
return c10::nullopt;
return std::nullopt;
}
}
return static_cast<TypePtr>(TupleType::create(std::move(elements)));
@ -443,7 +443,7 @@ static std::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t
return type_hint;
}
return c10::nullopt;
return std::nullopt;
}
std::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_union, const TypePtr& type_hint) {
@ -463,7 +463,7 @@ std::optional<TypePtr> unifyTypeList(
const TypePtr& type_hint) {
if (elements.empty()) {
why_not << "Cannot get unified type from empty list";
return c10::nullopt;
return std::nullopt;
}
TypePtr ret_type = elements.at(0);
@ -474,7 +474,7 @@ std::optional<TypePtr> unifyTypeList(
<< elements.at(i)->repr_str()
<< " did not match the types before it ("
<< ret_type->repr_str() << ")";
return c10::nullopt;
return std::nullopt;
}
ret_type = *maybe_unified;
}

View File

@ -50,7 +50,7 @@ std::optional<TypePtr> subtractTypeSetFrom(std::vector<TypePtr>& to_subtract, Ar
});
if (types.empty()) {
return c10::nullopt;
return std::nullopt;
} else if (types.size() == 1) {
return types[0];
} else {
@ -98,7 +98,7 @@ void filterDuplicateSubtypes(std::vector<TypePtr>* types) {
// `Optional` could prevent us from coalescing other types
if ((t1->isSubtypeOf(*NoneType::get()) && !t2->isSubtypeOf(*NoneType::get()))
|| (!t1->isSubtypeOf(*NoneType::get()) && t2->isSubtypeOf(*NoneType::get()))) {
return c10::nullopt;
return std::nullopt;
} else {
return unifyTypes(t1, t2, /*default_to_union=*/false);
}
@ -278,7 +278,7 @@ std::optional<TypePtr> UnionType::subtractTypeSet(std::vector<TypePtr>& to_subtr
std::optional<TypePtr> UnionType::toOptional() const {
if (!canHoldType(*NoneType::get())) {
return c10::nullopt;
return std::nullopt;
}
std::vector<TypePtr> copied_types = this->containedTypes().vec();
@ -286,7 +286,7 @@ std::optional<TypePtr> UnionType::toOptional() const {
auto maybe_opt = UnionType::create(std::move(copied_types));
if (maybe_opt->kind() == UnionType::Kind) {
return c10::nullopt;
return std::nullopt;
} else {
return maybe_opt;
}

View File

@ -326,7 +326,7 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;
auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
auto rng_state = state_tensor.data_ptr<uint8_t>();
auto current_seed = this->current_seed();
auto offset = static_cast<int64_t>(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic<int64_t>

View File

@ -100,7 +100,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
// default generator is always registered
auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
c10::nullopt, cuda::detail::getDefaultCUDAGenerator());
std::nullopt, cuda::detail::getDefaultCUDAGenerator());
gen->register_graph(this);
for (auto& [generator_state, wholegraph_increments] :

View File

@ -3,7 +3,7 @@
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/Generator.h>
#include <c10/util/Optional.h>
#include <optional>
// TODO: No need to have this whole header, we can just put it all in
// the cpp file

View File

@ -149,7 +149,7 @@ static std::tuple<Tensor,optional<int64_t>> linspace_Tensor_Tensor_batch_rule(
std::optional<at::Layout> layout,
std::optional<at::Device> device,
std::optional<bool> pin_memory){
return linspace_logspace_batch_rule_helper(start, start_bdim, end, end_bdim, steps, c10::nullopt, dtype, layout, device, pin_memory);
return linspace_logspace_batch_rule_helper(start, start_bdim, end, end_bdim, steps, std::nullopt, dtype, layout, device, pin_memory);
}
static std::tuple<Tensor,optional<int64_t>> linspace_Tensor_Scalar_batch_rule(
@ -162,7 +162,7 @@ static std::tuple<Tensor,optional<int64_t>> linspace_Tensor_Scalar_batch_rule(
std::optional<bool> pin_memory){
auto end_t = at::native::wrapped_scalar_tensor(end, start.device());
return linspace_logspace_batch_rule_helper(start, start_bdim, end_t, c10::nullopt, steps, c10::nullopt, dtype, layout, device, pin_memory);
return linspace_logspace_batch_rule_helper(start, start_bdim, end_t, std::nullopt, steps, std::nullopt, dtype, layout, device, pin_memory);
}
static std::tuple<Tensor,optional<int64_t>> linspace_Scalar_Tensor_batch_rule(
@ -175,7 +175,7 @@ static std::tuple<Tensor,optional<int64_t>> linspace_Scalar_Tensor_batch_rule(
std::optional<bool> pin_memory){
auto start_t = at::native::wrapped_scalar_tensor(start, end.device());
return linspace_logspace_batch_rule_helper(start_t, c10::nullopt, end, end_bdim, steps, c10::nullopt, dtype, layout, device, pin_memory);
return linspace_logspace_batch_rule_helper(start_t, std::nullopt, end, end_bdim, steps, std::nullopt, dtype, layout, device, pin_memory);
}
static std::tuple<Tensor,optional<int64_t>> logspace_Tensor_Tensor_batch_rule(
@ -187,7 +187,7 @@ static std::tuple<Tensor,optional<int64_t>> logspace_Tensor_Tensor_batch_rule(
std::optional<at::Layout> layout,
std::optional<at::Device> device,
std::optional<bool> pin_memory){
return linspace_logspace_batch_rule_helper(start, start_bdim, end, end_bdim, steps, c10::make_optional(base), dtype, layout, device, pin_memory);
return linspace_logspace_batch_rule_helper(start, start_bdim, end, end_bdim, steps, std::make_optional(base), dtype, layout, device, pin_memory);
}
static std::tuple<Tensor,optional<int64_t>> logspace_Tensor_Scalar_batch_rule(
@ -201,7 +201,7 @@ static std::tuple<Tensor,optional<int64_t>> logspace_Tensor_Scalar_batch_rule(
std::optional<bool> pin_memory){
auto end_t = at::native::wrapped_scalar_tensor(end, start.device());
return linspace_logspace_batch_rule_helper(start, start_bdim, end_t, c10::nullopt, steps, c10::make_optional(base), dtype, layout, device, pin_memory);
return linspace_logspace_batch_rule_helper(start, start_bdim, end_t, std::nullopt, steps, std::make_optional(base), dtype, layout, device, pin_memory);
}
static std::tuple<Tensor,optional<int64_t>> logspace_Scalar_Tensor_batch_rule(
@ -215,7 +215,7 @@ static std::tuple<Tensor,optional<int64_t>> logspace_Scalar_Tensor_batch_rule(
std::optional<bool> pin_memory){
auto start_t = at::native::wrapped_scalar_tensor(start, end.device());
return linspace_logspace_batch_rule_helper(start_t, c10::nullopt, end, end_bdim, steps, c10::make_optional(base), dtype, layout, device, pin_memory);
return linspace_logspace_batch_rule_helper(start_t, std::nullopt, end, end_bdim, steps, std::make_optional(base), dtype, layout, device, pin_memory);
}
static bool _has_same_storage_numel_batch_rule(const Tensor& a, const Tensor& b) {

View File

@ -8,7 +8,7 @@
#include <ATen/functorch/Macros.h>
#include <c10/core/DispatchKey.h>
#include <ATen/core/function_schema.h>
#include <c10/util/Optional.h>
#include <optional>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <ATen/functorch/Interpreter.h>
#include <ATen/functorch/VmapInterpreter.h>

View File

@ -3,7 +3,7 @@
#include <ATen/functorch/Macros.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/Optional.h>
#include <optional>
#include <bitset>
#include <utility>
#include <variant>
@ -149,7 +149,7 @@ struct Interpreter {
}
void clearSavedLocalDispatchKeySet() {
TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
savedLocalDispatchKeySet_ = c10::nullopt;
savedLocalDispatchKeySet_ = std::nullopt;
}
c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const {
TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());

View File

@ -536,7 +536,7 @@ Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
// we'll just slice the tensor to get a Tensor of shape [0] to pass to at::cat.
std::vector<Tensor> tensors_to_cat;
tensors_to_cat.reserve(tensors.size());
std::optional<int64_t> bdim_size = c10::nullopt;
std::optional<int64_t> bdim_size = std::nullopt;
// find the bdim size. Might not exist if all BatchedTensors should be skipped
// by cat's special case.
@ -573,7 +573,7 @@ Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
}
auto new_dim = bdim_size.has_value() ? dim + 1 : dim;
std::optional<int64_t> new_bdim = bdim_size.has_value() ? c10::make_optional((int64_t)0) : nullopt;
std::optional<int64_t> new_bdim = bdim_size.has_value() ? std::make_optional((int64_t)0) : nullopt;
auto result = at::cat(tensors_to_cat, new_dim);
return makeBatched(result, new_bdim, get_current_level());
}

View File

@ -5,7 +5,7 @@
#include <ATen/core/Generator.h>
#include <ATen/core/PhiloxRNGEngine.h>
#include <c10/core/GeneratorImpl.h>
#include <c10/util/Optional.h>
#include <optional>
namespace at {
namespace mps::detail {

View File

@ -68,7 +68,7 @@ c10::intrusive_ptr<c10::TensorImpl> MPSGeneratorImpl::get_state() const {
static const size_t total_size = states_size + seed_size;
auto state_tensor = at::detail::empty_cpu(
{(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
{(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
auto rng_state = state_tensor.data_ptr<uint8_t>();
auto current_seed = this->current_seed();
memcpy(rng_state, this->data_.state.data(), states_size);

View File

@ -5,7 +5,7 @@
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/Generator.h>
#include <ATen/mps/MPSEvent.h>
#include <c10/util/Optional.h>
#include <optional>
namespace at::mps {

View File

@ -1470,7 +1470,7 @@ void batch_norm_elemt_channels_last_cuda_template(
const at::Tensor& shift, // bias of BN
const at::Tensor& mean,
const at::Tensor& inv_std,
const at::optional<at::Tensor>& z = c10::nullopt, // bias after BN
const std::optional<at::Tensor>& z = c10::nullopt, // bias after BN
const bool fuse_relu = false) {
const auto stride = input.sizes()[1];
const auto reduction_size = input.numel() / stride;

View File

@ -272,7 +272,7 @@ Tensor rms_norm(
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
auto bias_opt = at::optional<Tensor>();
auto bias_opt = std::optional<Tensor>();
const Tensor& bias = *at::borrow_from_optional_tensor(bias_opt);
(void) _check_layer_norm_inputs(input, normalized_shape, weight, bias);

View File

@ -110,7 +110,7 @@ inline Tensor from_blob(
IntArrayRef strides,
const std::function<void(void*)>& deleter,
const TensorOptions& options = {},
const std::optional<Device> target_device = c10::nullopt) {
const std::optional<Device> target_device = std::nullopt) {
return for_blob(data, sizes)
.strides(strides)
.deleter(deleter)
@ -126,7 +126,7 @@ inline Tensor from_blob(
int64_t storage_offset,
const std::function<void(void*)>& deleter,
const TensorOptions& options = {},
const std::optional<Device> target_device = c10::nullopt) {
const std::optional<Device> target_device = std::nullopt) {
return for_blob(data, sizes)
.strides(strides)
.storage_offset(storage_offset)
@ -141,7 +141,7 @@ inline Tensor from_blob(
IntArrayRef sizes,
std::function<void(void*)> deleter,
const TensorOptions& options = {},
const std::optional<Device> target_device = c10::nullopt) {
const std::optional<Device> target_device = std::nullopt) {
return for_blob(data, sizes)
.deleter(std::move(deleter))
.options(options)

View File

@ -48,7 +48,7 @@ std::optional<RecordFunctionCallback> extractCallback(
CallbackHandle handle) {
auto it = findCallback(entries, handle);
if (it == entries.end()) {
return c10::nullopt;
return std::nullopt;
}
auto out = it->callback_;
entries.erase(it);
@ -313,7 +313,7 @@ StepCallbacks CacheEntry::getActiveCallbacks() {
std::optional<StepCallbacks> CacheEntry::getActiveCallbacksUnlessEmpty() {
getActiveCallbacksImpl();
if (C10_LIKELY(active_callbacks_.empty())) {
return c10::nullopt;
return std::nullopt;
}
return active_callbacks_;
}
@ -589,7 +589,7 @@ std::optional<OperatorName> RecordFunction::operator_name() const {
return std::visit(
c10::overloaded(
[&](const std::string&) -> std::optional<OperatorName> {
return c10::nullopt;
return std::nullopt;
},
[](const schema_ref_t schema) -> std::optional<OperatorName> {
return schema.get().operator_name();
@ -601,7 +601,7 @@ std::optional<c10::FunctionSchema> RecordFunction::operator_schema() const {
return std::visit(
c10::overloaded(
[&](const std::string&) -> std::optional<c10::FunctionSchema> {
return c10::nullopt;
return std::nullopt;
},
[](const schema_ref_t schema) -> std::optional<c10::FunctionSchema> {
return schema.get();

View File

@ -2,7 +2,7 @@
#include <ATen/Dimname.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <optional>
using at::NameType;
using at::Symbol;
@ -49,7 +49,7 @@ TEST(DimnameTest, createNormalName) {
static void check_unify_and_match(
const std::string& dimname,
const std::string& other,
at::optional<const std::string> expected) {
std::optional<const std::string> expected) {
auto dimname1 = Dimname::fromSymbol(Symbol::dimname(dimname));
auto dimname2 = Dimname::fromSymbol(Symbol::dimname(other));
auto result = dimname1.unify(dimname2);
@ -69,5 +69,5 @@ TEST(DimnameTest, unifyAndMatch) {
check_unify_and_match("a", "*", "a");
check_unify_and_match("*", "a", "a");
check_unify_and_match("*", "*", "*");
check_unify_and_match("a", "b", c10::nullopt);
check_unify_and_match("a", "b", std::nullopt);
}

View File

@ -80,9 +80,9 @@ TEST(NamedTensorTest, internalSetNamesInplace) {
ASSERT_TRUE(dimnames_equal(retrieved_names, names));
// Drop names
at::internal_set_names_inplace(tensor, at::nullopt);
at::internal_set_names_inplace(tensor, std::nullopt);
ASSERT_TRUE(tensor.get_named_tensor_meta() == nullptr);
ASSERT_TRUE(tensor.opt_names() == at::nullopt);
ASSERT_TRUE(tensor.opt_names() == std::nullopt);
}
TEST(NamedTensorTest, empty) {
@ -93,10 +93,10 @@ TEST(NamedTensorTest, empty) {
std::vector<Dimname> names = { N, C, H, W };
auto tensor = at::empty({});
ASSERT_EQ(tensor.opt_names(), at::nullopt);
ASSERT_EQ(tensor.opt_names(), std::nullopt);
tensor = at::empty({1, 2, 3});
ASSERT_EQ(tensor.opt_names(), at::nullopt);
ASSERT_EQ(tensor.opt_names(), std::nullopt);
tensor = at::empty({1, 2, 3, 4}, names);
ASSERT_TRUE(dimnames_equal(tensor.opt_names().value(), names));

View File

@ -6,7 +6,7 @@
#include <ATen/native/DistributionTemplates.h>
#include <ATen/native/cpu/DistributionTemplates.h>
#include <torch/library.h>
#include <c10/util/Optional.h>
#include <optional>
#include <torch/all.h>
#include <stdexcept>
@ -194,7 +194,7 @@ TEST_F(RNGTest, Random) {
TEST_F(RNGTest, Random64bits) {
auto gen = at::make_generator<TestCPUGenerator>(std::numeric_limits<uint64_t>::max());
auto actual = torch::empty({1}, torch::kInt64);
actual.random_(std::numeric_limits<int64_t>::min(), c10::nullopt, gen);
actual.random_(std::numeric_limits<int64_t>::min(), std::nullopt, gen);
ASSERT_EQ(static_cast<uint64_t>(actual[0].item<int64_t>()), std::numeric_limits<uint64_t>::max());
}

View File

@ -173,7 +173,7 @@ TEST(RandomPermutationTest, TestIslandShuffle) {
bool shuffled2 = false;
for (int i = 0; i < 100; i++) {
cudaDeviceSynchronize();
std::optional<at::Generator> gen = c10::nullopt;
std::optional<at::Generator> gen = std::nullopt;
randperm_handle_duplicate_keys(keys, values, 8, 5, gen);
cudaDeviceSynchronize();
std::vector<int> slice1 = {values[0], values[1], values[2]};

View File

@ -2,7 +2,7 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h>
#include <optional>
#include <assert.h>

View File

@ -89,7 +89,7 @@ TEST(TestStream, GetAndSetTest) {
ASSERT_EQ_CUDA(curStream, defaultStream);
}
void thread_fun(at::optional<at::cuda::CUDAStream>& cur_thread_stream) {
void thread_fun(std::optional<at::cuda::CUDAStream>& cur_thread_stream) {
auto new_stream = at::cuda::getStreamFromPool();
at::cuda::setCurrentCUDAStream(new_stream);
cur_thread_stream = {at::cuda::getCurrentCUDAStream()};
@ -99,7 +99,7 @@ void thread_fun(at::optional<at::cuda::CUDAStream>& cur_thread_stream) {
// Ensures streams are thread local
TEST(TestStream, MultithreadGetAndSetTest) {
if (!at::cuda::is_available()) return;
at::optional<at::cuda::CUDAStream> s0, s1;
std::optional<at::cuda::CUDAStream> s0, s1;
std::thread t0{thread_fun, std::ref(s0)};
std::thread t1{thread_fun, std::ref(s1)};

View File

@ -44,7 +44,7 @@ Tensor empty_strided_override(
std::optional<c10::Device> device,
std::optional<bool> pin_memory) {
return empty_override(fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, c10::nullopt);
return empty_override(fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, std::nullopt);
}
TORCH_LIBRARY_IMPL(aten, MAIA, m) {

View File

@ -3,7 +3,7 @@
#include <ATen/Tensor.h>
#include <ATen/native/TensorIterator.h>
#include <torch/library.h>
#include <c10/util/Optional.h>
#include <optional>
#include <torch/all.h>
#include <stdexcept>
@ -75,7 +75,7 @@ void test_random_from_to(const at::Device& device) {
};
tos = {
1L,
static_cast<::std::optional<int64_t>>(c10::nullopt)
static_cast<::std::optional<int64_t>>(::std::nullopt)
};
} else if constexpr (::std::is_signed_v<T>) {
constexpr int64_t min_from = _min_from<T>();
@ -90,7 +90,7 @@ void test_random_from_to(const at::Device& device) {
::std::optional<int64_t>(0L),
::std::optional<int64_t>(42L),
::std::optional<int64_t>(max_to),
static_cast<::std::optional<int64_t>>(c10::nullopt)
static_cast<::std::optional<int64_t>>(::std::nullopt)
};
} else {
froms = {
@ -100,7 +100,7 @@ void test_random_from_to(const at::Device& device) {
tos = {
::std::optional<int64_t>(42L),
::std::optional<int64_t>(max_to),
static_cast<::std::optional<int64_t>>(c10::nullopt)
static_cast<::std::optional<int64_t>>(::std::nullopt)
};
}

View File

@ -25,7 +25,7 @@ inline bool CheckStrideIndices(const Tensor& t, at::MemoryFormat format) {
// testing computeStrideProps with `IValue ival(t)` somehow doesn't work on CI
// with onnx; The function works fine within, but stride properties is somehow
// altered in ival->type()->cast<TensorType>();
auto tt = TensorType::create(c10::nullopt, c10::nullopt, t.sizes(), t.strides(), c10::nullopt);
auto tt = TensorType::create(std::nullopt, std::nullopt, t.sizes(), t.strides(), std::nullopt);
TORCH_INTERNAL_ASSERT(tt->stride_properties().isComplete(), "complete stride properties is needed for the test");
auto index_iter = stride_indices.begin();
@ -53,11 +53,11 @@ TEST(StridePropertiesTest, ZeroStrideIndicesEagerConsistencyTest) {
auto permuted_tensor = at::rand({6, 3, 1, 5, 2}).permute({0, 3, 2, 1, 4}); // permute dim-1 & dim-3
auto tensor = permuted_tensor.expand({6, 5, 4, 3, 2}); // expand dim-2
auto temp = TensorType::create(c10::nullopt, c10::nullopt, tensor.sizes(), tensor.strides(), c10::nullopt);
auto temp = TensorType::create(std::nullopt, std::nullopt, tensor.sizes(), tensor.strides(), std::nullopt);
// TensorIterator would preserve stride order, this is the eager reference
auto eager_tensor = tensor.relu();
auto ref_type = TensorType::create(c10::nullopt, c10::nullopt, eager_tensor.sizes(), eager_tensor.strides(), c10::nullopt);
auto ref_type = TensorType::create(std::nullopt, std::nullopt, eager_tensor.sizes(), eager_tensor.strides(), std::nullopt);
TORCH_INTERNAL_ASSERT(temp->stride_properties().isComplete() &&
temp->stride_properties().isComplete(), "complete stride properties is needed for the test");
@ -80,7 +80,7 @@ TEST(StridePropertiesTest, SlicedStrideIndicesTest) {
// Sliced tensor shouldn't have changed stride order
Tensor t = at::rand({16, 4}).slice(1, 0, 4, 4);
auto temp = TensorType::create(c10::nullopt, c10::nullopt, t.sizes(), t.strides(), c10::nullopt);
auto temp = TensorType::create(std::nullopt, std::nullopt, t.sizes(), t.strides(), std::nullopt);
TORCH_INTERNAL_ASSERT(temp->stride_properties().isComplete() &&
temp->stride_properties().isComplete(), "complete stride properties is needed for the test");
std::vector<size_t> stride_indices(2);

View File

@ -13,7 +13,7 @@ TEST(TypeCustomPrinter, Basic) {
if (auto tensorType = t.cast<TensorType>()) {
return "CustomTensor";
}
return c10::nullopt;
return std::nullopt;
};
// Tensor types should be rewritten
@ -33,7 +33,7 @@ TEST(TypeCustomPrinter, ContainedTypes) {
if (auto tensorType = t.cast<TensorType>()) {
return "CustomTensor";
}
return c10::nullopt;
return std::nullopt;
};
torch::Tensor iv = torch::rand({2, 3});
const auto type = TensorType::create(iv);
@ -60,7 +60,7 @@ TEST(TypeCustomPrinter, NamedTuples) {
return "Rewritten";
}
}
return c10::nullopt;
return std::nullopt;
};
torch::Tensor iv = torch::rand({2, 3});
const auto type = TensorType::create(iv);

View File

@ -1493,7 +1493,7 @@ void test_conv2d_context(
const auto prepack_vulkan = callOpByName(
"vulkan_prepack::create_conv2d_context",
"",
weight, bias, stride, padding, dilation, groups, c10::nullopt, c10::nullopt);
weight, bias, stride, padding, dilation, groups, std::nullopt, std::nullopt);
const auto vulkan_output = callOpByName(
"vulkan_prepack::run_conv2d_context",
@ -1534,7 +1534,7 @@ void test_backwards_compatible_conv2d_context(
const auto prepack_vulkan = callOpByName(
"vulkan_prepack::conv2d_clamp_prepack",
"",
weight, bias, stride, padding, dilation, groups, c10::nullopt, c10::nullopt);
weight, bias, stride, padding, dilation, groups, std::nullopt, std::nullopt);
const auto vulkan_output = callOpByName(
"vulkan_prepack::conv2d_clamp_run",
@ -1576,7 +1576,7 @@ void test_transposed_conv2d_context(
const auto prepack_vulkan = callOpByName(
"vulkan_prepack::create_tconv2d_context",
"",
weight, bias, stride, padding, output_padding, dilation, groups, c10::nullopt, c10::nullopt);
weight, bias, stride, padding, output_padding, dilation, groups, std::nullopt, std::nullopt);
const auto vulkan_output = callOpByName(
"vulkan_prepack::run_tconv2d_context",
@ -2136,7 +2136,7 @@ TEST_F(VulkanAPITest, conv2d_clamp_after_div) {
const auto prepack_cpu = callOpByName(
"prepacked::conv2d_clamp_prepack",
"",
weight, bias, stride, padding, dilation, groups, 0.0f, c10::nullopt)[0];
weight, bias, stride, padding, dilation, groups, 0.0f, std::nullopt)[0];
const auto out_cpu = callOpByName(
"prepacked::conv2d_clamp_run",
@ -2147,7 +2147,7 @@ TEST_F(VulkanAPITest, conv2d_clamp_after_div) {
const auto prepack_vk = callOpByName(
"vulkan_prepack::create_conv2d_context",
"",
weight, bias, stride, padding, dilation, groups, 0.0f, c10::nullopt)[0];
weight, bias, stride, padding, dilation, groups, 0.0f, std::nullopt)[0];
const auto out_vk = callOpByName(
"vulkan_prepack::run_conv2d_context",

View File

@ -1852,8 +1852,8 @@ static void test_quantized_conv_transpose2d(
output_padding,
dilation,
groups,
c10::nullopt,
c10::nullopt);
std::nullopt,
std::nullopt);
const auto input_vk_q = at::quantize_per_tensor(
input.vulkan(), input_scale, input_zero_point, c10::ScalarType::QUInt8);
@ -2661,8 +2661,8 @@ void test_quantized_conv2d(
padding,
dilation,
groups,
c10::nullopt,
c10::nullopt);
std::nullopt,
std::nullopt);
const auto vulkan_output = callOpByName(
"vulkan_prepack::run_qconv2d_context",
"",

View File

@ -233,7 +233,7 @@ TEST(TestXNNPackOps, TestConvolution2dMultiThreaded) {
{weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat));
auto context = at::native::xnnpack::XNNPackConv2dOpContext::create_context(
std::move(weights_cpu), std::move(bias_cpu), {1, 1}, {2, 2}, {1, 1}, groups, c10::nullopt, c10::nullopt);
std::move(weights_cpu), std::move(bias_cpu), {1, 1}, {2, 2}, {1, 1}, groups, std::nullopt, std::nullopt);
std::atomic<int64_t> count{0};
int64_t num_workers = 5;
std::mutex lock;

View File

@ -104,10 +104,10 @@ c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
auto state_tensor = at::detail::empty_cpu(
{static_cast<int64_t>(total_size)},
ScalarType::Byte,
c10::nullopt,
c10::nullopt,
c10::nullopt,
c10::nullopt);
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt);
auto rng_state = state_tensor.data_ptr<uint8_t>();
auto current_seed = this->current_seed();
auto offset = this->philox_offset_per_thread();

View File

@ -87,7 +87,7 @@ static void THPEvent_dealloc(THPEvent* self) {
static PyObject* THPEvent_get_device(THPEvent* self, void* unused) {
HANDLE_TH_ERRORS
at::optional<at::Device> device = self->event.device();
std::optional<at::Device> device = self->event.device();
if (!device) {
Py_RETURN_NONE;
}

View File

@ -1220,7 +1220,7 @@ int THPVariable_set_names(PyObject* self, PyObject* names, void* unused) {
}
const auto& var = THPVariable_Unpack(self);
if (names == Py_None) {
at::internal_set_names_inplace(var, at::nullopt);
at::internal_set_names_inplace(var, std::nullopt);
} else {
TORCH_CHECK(
THPUtils_checkDimnameList(names),

View File

@ -114,7 +114,7 @@ static PyObject* THCPEvent_get_cuda_event(THCPEvent* self, void* unused) {
static PyObject* THCPEvent_get_device(THCPEvent* self, void* unused) {
HANDLE_TH_ERRORS
at::optional<at::Device> device = self->cuda_event.device();
std::optional<at::Device> device = self->cuda_event.device();
if (!device) {
Py_RETURN_NONE;
}

View File

@ -291,7 +291,7 @@ ArrayRef<ncclComm_t> get_communicators(TensorList inputs) {
static inline void check_tensor(
const at::Tensor& input,
const at::optional<at::Tensor>& output,
const std::optional<at::Tensor>& output,
int input_multiplier,
int output_multiplier,
int64_t ref_numel,
@ -396,8 +396,8 @@ void check_inputs(
check_tensor(
input,
i == static_cast<std::remove_cv_t<decltype(i)>>(root)
? at::optional<at::Tensor>{output}
: at::nullopt,
? std::optional<at::Tensor>{output}
: std::nullopt,
input_multiplier,
output_multiplier,
numel,

View File

@ -653,7 +653,7 @@ struct NCCLTraceBuffer {
std::chrono::milliseconds timeout_ms,
bool isP2P) {
if (!enabled_) {
return c10::nullopt;
return std::nullopt;
}
auto traceback =
torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
@ -765,7 +765,7 @@ struct NCCLTraceBuffer {
bool can_compute_duration = false;
Event* startEvent = nullptr;
Event* endEvent = nullptr;
std::optional<float> duration = c10::nullopt;
std::optional<float> duration = std::nullopt;
std::unique_lock<std::mutex> guard(mutex_);

View File

@ -7,6 +7,7 @@
#include <torch/csrc/distributed/c10d/socket.h>
#include <cstring>
#include <optional>
#include <system_error>
#include <utility>
#include <vector>
@ -37,7 +38,6 @@ C10_DIAGNOSTIC_POP()
#include <torch/csrc/distributed/c10d/logging.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Optional.h>
namespace c10d::detail {
namespace {
@ -142,7 +142,7 @@ class SocketImpl {
explicit SocketImpl(
Handle hnd,
c10::optional<::addrinfo> remote = c10::nullopt) noexcept
std::optional<::addrinfo> remote = std::nullopt) noexcept
: hnd_{hnd}, remote_(remote) {}
SocketImpl(const SocketImpl& other) = delete;
@ -181,7 +181,7 @@ class SocketImpl {
return hnd_;
}
const c10::optional<::addrinfo>& remote() const noexcept {
const std::optional<::addrinfo>& remote() const noexcept {
return remote_;
}
@ -191,7 +191,7 @@ class SocketImpl {
bool setSocketFlag(int level, int optname, bool value) noexcept;
Handle hnd_;
const c10::optional<::addrinfo> remote_;
const std::optional<::addrinfo> remote_;
};
} // namespace c10d::detail

View File

@ -60,30 +60,30 @@ int64_t store(std::shared_ptr<Graph> graph) {
}
// XXX: Does not grab mutex
static at::optional<KernelSpec*> nolock_retrieve(
static std::optional<KernelSpec*> nolock_retrieve(
KernelCacheImpl& cache,
const int64_t key) {
auto it = cache.specMap_.find(key);
if (it == cache.specMap_.end())
return at::nullopt;
return std::nullopt;
return &(it->second);
}
at::optional<KernelSpec*> retrieve(const int64_t key) {
std::optional<KernelSpec*> retrieve(const int64_t key) {
auto& cache = getKernelCache();
std::lock_guard<std::mutex> guard{cache.mutex_};
return nolock_retrieve(cache, key);
}
// precondition: graph has been normalized via normalizeGraphForCache
at::optional<KernelSpec*> lookupGraph(std::shared_ptr<Graph> graph) {
std::optional<KernelSpec*> lookupGraph(std::shared_ptr<Graph> graph) {
auto& cache = getKernelCache();
std::string repr = graph->toString(false);
std::lock_guard<std::mutex> guard{cache.mutex_};
auto it = cache.graphToKey_.find(repr);
if (it == cache.graphToKey_.end())
return at::nullopt;
return std::nullopt;
return nolock_retrieve(cache, it->second);
}

View File

@ -22,10 +22,10 @@ TORCH_API std::shared_ptr<Graph> normalizeGraphForCache(
TORCH_API int64_t store(std::shared_ptr<Graph> graph);
// Given a graph, find a KernelSpec based on it
TORCH_API at::optional<KernelSpec*> lookupGraph(std::shared_ptr<Graph> graph);
TORCH_API std::optional<KernelSpec*> lookupGraph(std::shared_ptr<Graph> graph);
// Returns the graph corresponding to the given key (if it exists)
TORCH_API at::optional<KernelSpec*> retrieve(const int64_t key);
TORCH_API std::optional<KernelSpec*> retrieve(const int64_t key);
// Returns the size of the fusion key -> KernelSpec cache.
// Only used for testing.

View File

@ -502,42 +502,42 @@ struct Environment {
{"len",
makeMagic(
"__len__",
std::make_shared<BuiltinFunction>(aten::len, at::nullopt))},
std::make_shared<BuiltinFunction>(aten::len, std::nullopt))},
{"hex",
makeMagic(
"__hex__",
std::make_shared<BuiltinFunction>(aten::hex, at::nullopt))},
std::make_shared<BuiltinFunction>(aten::hex, std::nullopt))},
{"oct",
makeMagic(
"__oct__",
std::make_shared<BuiltinFunction>(aten::oct, at::nullopt))},
std::make_shared<BuiltinFunction>(aten::oct, std::nullopt))},
{"round",
makeMagic(
"__round__",
std::make_shared<BuiltinFunction>(aten::round, at::nullopt))},
{"hash", std::make_shared<BuiltinFunction>(aten::hash, at::nullopt)},
{"id", std::make_shared<BuiltinFunction>(prim::id, at::nullopt)},
{"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)},
{"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)},
{"abs", std::make_shared<BuiltinFunction>(prim::abs, at::nullopt)},
{"all", std::make_shared<BuiltinFunction>(aten::all, at::nullopt)},
{"any", std::make_shared<BuiltinFunction>(aten::any, at::nullopt)},
std::make_shared<BuiltinFunction>(aten::round, std::nullopt))},
{"hash", std::make_shared<BuiltinFunction>(aten::hash, std::nullopt)},
{"id", std::make_shared<BuiltinFunction>(prim::id, std::nullopt)},
{"min", std::make_shared<BuiltinFunction>(prim::min, std::nullopt)},
{"max", std::make_shared<BuiltinFunction>(prim::max, std::nullopt)},
{"abs", std::make_shared<BuiltinFunction>(prim::abs, std::nullopt)},
{"all", std::make_shared<BuiltinFunction>(aten::all, std::nullopt)},
{"any", std::make_shared<BuiltinFunction>(aten::any, std::nullopt)},
{"divmod",
std::make_shared<BuiltinFunction>(aten::divmod, at::nullopt)},
{"sum", std::make_shared<BuiltinFunction>(aten::sum, at::nullopt)},
std::make_shared<BuiltinFunction>(aten::divmod, std::nullopt)},
{"sum", std::make_shared<BuiltinFunction>(aten::sum, std::nullopt)},
{"list", SpecialFormValue::create(prim::list)},
{"dict", SpecialFormValue::create(prim::dict)},
{"ord", std::make_shared<BuiltinFunction>(aten::ord, at::nullopt)},
{"chr", std::make_shared<BuiltinFunction>(aten::chr, at::nullopt)},
{"bin", std::make_shared<BuiltinFunction>(aten::bin, at::nullopt)},
{"pow", std::make_shared<BuiltinFunction>(aten::pow, at::nullopt)},
{"ord", std::make_shared<BuiltinFunction>(aten::ord, std::nullopt)},
{"chr", std::make_shared<BuiltinFunction>(aten::chr, std::nullopt)},
{"bin", std::make_shared<BuiltinFunction>(aten::bin, std::nullopt)},
{"pow", std::make_shared<BuiltinFunction>(aten::pow, std::nullopt)},
{"range", SpecialFormValue::create(prim::range)},
{"zip", SpecialFormValue::create(prim::zip)},
{"enumerate", SpecialFormValue::create(prim::enumerate)},
{"rangelist",
std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
std::make_shared<BuiltinFunction>(prim::rangelist, std::nullopt)},
{"sorted",
std::make_shared<BuiltinFunction>(aten::sorted, at::nullopt)},
std::make_shared<BuiltinFunction>(aten::sorted, std::nullopt)},
// Only AssertionError is bound so that we can use it from emitAssert,
// all other exceptions should be resolved at the Python level
{"AssertionError",
@ -2945,7 +2945,7 @@ struct to_ir {
args.push_back(rhs);
makeMagic(
"__setitem__",
std::make_shared<BuiltinFunction>(aten::_set_item, at::nullopt))
std::make_shared<BuiltinFunction>(aten::_set_item, std::nullopt))
->call(stmtRange, method, args, {}, 0);
}
}
@ -4110,7 +4110,7 @@ struct to_ir {
auto val =
asSimple(makeMagic(
magicMethod,
std::make_shared<BuiltinFunction>(opSymbol, at::nullopt))
std::make_shared<BuiltinFunction>(opSymbol, std::nullopt))
->call(tree->range(), method, named_values, {}, 0));
// if we emitted the unary op and not some other overloaded function,
@ -4362,7 +4362,7 @@ struct to_ir {
return asSimple(
makeMagic(
overload, std::make_shared<BuiltinFunction>(kind, at::nullopt))
overload, std::make_shared<BuiltinFunction>(kind, std::nullopt))
->call(tree->range(), method, named_values, {}, 0));
}
@ -4790,7 +4790,7 @@ struct to_ir {
}
if (sliceable->type()->cast<TupleType>()) {
std::vector<at::optional<NamedValue>> tuple_args;
std::vector<std::optional<NamedValue>> tuple_args;
// since we are only dealing with tuple slicing, we try to keep
// tuple args separate for now
tuple_args.reserve(3);
@ -5170,7 +5170,7 @@ struct to_ir {
Value* emitTupleSlice(
const SourceRange& loc,
const NamedValue& tuple_val,
const std::vector<at::optional<NamedValue>>& tuple_args) {
const std::vector<std::optional<NamedValue>>& tuple_args) {
auto tuple_type = tuple_val.value(*graph)->type()->expect<TupleType>();
int64_t tuple_len = tuple_type->elements().size();
auto beg_val = tuple_args[0];
@ -5224,7 +5224,7 @@ struct to_ir {
auto s_tuple_val =
sv->asTupleValue(val_range, method)->asValue(val_range, method);
const SliceExpr& slice = SliceExpr(subscript_exprs[0]);
std::vector<at::optional<NamedValue>> tuple_args;
std::vector<std::optional<NamedValue>> tuple_args;
tuple_args.reserve(3);
if (slice.start().present()) {
auto begin = NamedValue(

View File

@ -226,7 +226,7 @@ std::optional<std::string> ScriptTypeParser::parseBaseTypeName(
}
} break;
}
return at::nullopt;
return std::nullopt;
}
TypePtr ScriptTypeParser::parseTypeFromExpr(const Expr& expr) const {

View File

@ -305,7 +305,7 @@ struct TORCH_API SugaredTupleValue : public SugaredValue {
};
struct TORCH_API BuiltinModule : public SugaredValue {
BuiltinModule(std::string name, std::optional<int64_t> version = at::nullopt)
BuiltinModule(std::string name, std::optional<int64_t> version = std::nullopt)
: name(std::move(name)), version(version) {}
std::string kind() const override {
@ -514,8 +514,8 @@ struct TORCH_API CastValue : public BuiltinFunction {
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override {
if (args.size() == 1 && kwargs.empty()) {
auto len_op = std::make_shared<BuiltinFunction>(aten::len, at::nullopt);
auto gt_op = std::make_shared<BuiltinFunction>(aten::gt, at::nullopt);
auto len_op = std::make_shared<BuiltinFunction>(aten::len, std::nullopt);
auto gt_op = std::make_shared<BuiltinFunction>(aten::gt, std::nullopt);
auto zero = m.graph()->insertConstant(0);
auto v = args[0].value(*m.graph());

View File

@ -18,7 +18,7 @@ Function::Function(c10::QualifiedName name) : name_(std::move(name)) {}
Function::Function(
c10::QualifiedName name,
Code code,
at::optional<c10::FunctionSchema> schema)
std::optional<c10::FunctionSchema> schema)
: name_(std::move(name)),
code_(std::move(code)),
schema_(std::move(schema)) {}

View File

@ -21,7 +21,7 @@ class TORCH_API Function : public torch::jit::Function {
Function(
c10::QualifiedName name,
Code code,
at::optional<c10::FunctionSchema> schema);
std::optional<c10::FunctionSchema> schema);
void run(Stack& stack) override;
at::IValue operator()(Stack& stack);
void ensure_defined() override {}
@ -72,7 +72,7 @@ class TORCH_API Function : public torch::jit::Function {
private:
c10::QualifiedName name_;
Code code_;
at::optional<c10::FunctionSchema> schema_; // (byte-code version 4+)
std::optional<c10::FunctionSchema> schema_; // (byte-code version 4+)
};
std::optional<std::function<void(Stack&)>> makeOperatorFunction(

View File

@ -415,7 +415,7 @@ struct GraphFuser {
return group;
}
at::optional<Node*> tryFuse(Node* consumer, Value* producer) {
std::optional<Node*> tryFuse(Node* consumer, Value* producer) {
// this handles cases where producer can be moved _into_ the fusion group of
// consumer.
// TODO: extend to fusion of consumer into _producer's_ fusion blob
@ -431,13 +431,13 @@ struct GraphFuser {
aliasDb_->moveBeforeTopologicallyValid(producer->node(), consumer);
if (!shouldFuse) {
return at::nullopt;
return std::nullopt;
}
if ((consumer->inputs().size() + consumer->outputs().size() +
producer->node()->inputs().size() +
producer->node()->outputs().size()) > subgraph_arg_limit_) {
return at::nullopt;
return std::nullopt;
}
auto group = consumer;

View File

@ -15,18 +15,18 @@ using namespace ::c10::onnx;
namespace {
at::optional<Node*> FindFusibleListUnpack(Node* n) {
std::optional<Node*> FindFusibleListUnpack(Node* n) {
// 1. number of outputs is restricted to 1.
// 2. output is only used by prim::ListUnpack.
if (n->outputs().size() != 1) {
return at::nullopt;
return std::nullopt;
}
if (n->output()->uses().size() != 1) {
return at::nullopt;
return std::nullopt;
}
auto listUnpackNode = n->output()->uses()[0].user;
if (listUnpackNode->kind() != prim::ListUnpack) {
return at::nullopt;
return std::nullopt;
}
return listUnpackNode;
}

View File

@ -548,7 +548,7 @@ class ShapePropagator : public PropertyPropBase {
list_type = input_base_type->cast<ListType>();
}
at::optional<at::ScalarType> default_type =
std::optional<at::ScalarType> default_type =
tryScalarTypeFromJitType(*input_base_type);
if (auto grad_index = node->schema().argumentIndexWithName("dtype")) {
auto inp = toIValue(node->inputs().at(*grad_index));
@ -1195,7 +1195,7 @@ class ShapePropagator : public PropertyPropBase {
static const register_formula_for reduce_ops_with_opt_dtype{
{"aten::mean(Tensor self, *, int? dtype) -> Tensor"},
[](Node* node) -> type_vec_t {
at::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
std::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
if (auto type = node->input(0)->type()->cast<TensorType>()) {
auto ret = type->withDim(0);
if (maybe_dtype_option && !maybe_dtype_option->isNone()) {
@ -1223,7 +1223,7 @@ class ShapePropagator : public PropertyPropBase {
[](Node* node) -> type_vec_t {
if (auto type = node->input(0)->type()->cast<TensorType>()) {
type = type->withDim(0);
at::optional<IValue> maybe_dtype_option =
std::optional<IValue> maybe_dtype_option =
node->get(attr::dtype);
if (maybe_dtype_option && !maybe_dtype_option->isNone()) {
return {
@ -1348,7 +1348,7 @@ class ShapePropagator : public PropertyPropBase {
},
[](Node* node) -> type_vec_t {
auto maybe_keepdim = node->get<bool>(attr::keepdim);
at::optional<IValue> opt_dtype = node->get(attr::dtype);
std::optional<IValue> opt_dtype = node->get(attr::dtype);
return reduce_op_handler(
node,
/*num_reduce_dim=*/*maybe_keepdim ? 0 : 1,
@ -1370,7 +1370,7 @@ class ShapePropagator : public PropertyPropBase {
"aten::cumsum(Tensor self, int dim, *, int? dtype) -> Tensor",
"aten::log_softmax(Tensor self, int dim, int? dtype) -> Tensor"},
[](Node* node) -> type_vec_t {
at::optional<IValue> opt_dtype = node->get(attr::dtype);
std::optional<IValue> opt_dtype = node->get(attr::dtype);
return reduce_op_handler(
node,
/*num_reduce_dim=*/0,
@ -1389,7 +1389,7 @@ class ShapePropagator : public PropertyPropBase {
static const register_formula_for register_softmax{
{"aten::softmax(Tensor self, int dim, int? dtype) -> Tensor"},
[](Node* node) -> type_vec_t {
at::optional<IValue> opt_dtype = node->get(attr::dtype);
std::optional<IValue> opt_dtype = node->get(attr::dtype);
return reduce_op_handler(
node,
/*num_reduced_dim=*/0,
@ -1399,18 +1399,18 @@ class ShapePropagator : public PropertyPropBase {
static const auto factory_with_ndim =
[](Node* node, int dim, at::ScalarType default_dtype) -> type_vec_t {
at::optional<IValue> maybe_layout_option = node->get(attr::layout);
std::optional<IValue> maybe_layout_option = node->get(attr::layout);
if (!maybe_layout_option)
return {};
at::optional<IValue> maybe_device_option = node->get(attr::device);
std::optional<IValue> maybe_device_option = node->get(attr::device);
if (!maybe_device_option)
return {};
auto device =
(maybe_device_option->isNone() ? at::kCPU
: maybe_device_option->toDevice());
at::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
std::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
if (!maybe_dtype_option)
return {};
auto dtype =
@ -1427,11 +1427,11 @@ class ShapePropagator : public PropertyPropBase {
auto in_type = tt->scalarType();
auto in_dev = tt->device();
at::optional<IValue> maybe_layout_option = node->get(attr::layout);
std::optional<IValue> maybe_layout_option = node->get(attr::layout);
if (!maybe_layout_option)
return {};
at::optional<IValue> maybe_device_option = node->get(attr::device);
std::optional<IValue> maybe_device_option = node->get(attr::device);
if (!maybe_device_option)
return {};
@ -1439,7 +1439,7 @@ class ShapePropagator : public PropertyPropBase {
in_dev = maybe_device_option->toDevice();
}
at::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
std::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
if (!maybe_dtype_option)
return {};

View File

@ -280,7 +280,7 @@ void checkMutableFunctionDefault(
FunctionSchema getSchemaWithNameAndDefaults(
const SourceRange& range,
const FunctionSchema& schema,
const at::optional<std::string>& new_name,
const std::optional<std::string>& new_name,
const FunctionDefaults& default_args) {
std::vector<Argument> new_args;
for (auto& arg : schema.arguments()) {
@ -1796,7 +1796,7 @@ void initJitScriptBindings(PyObject* module) {
method.setSchema(getSchemaWithNameAndDefaults(
defs_it->range(),
method.getSchema(),
at::nullopt,
std::nullopt,
default_it->second));
++defs_it;
}
@ -2277,7 +2277,7 @@ void initJitScriptBindings(PyObject* module) {
method.setSchema(getSchemaWithNameAndDefaults(
defs_it->range(),
method.getSchema(),
at::nullopt,
std::nullopt,
*defaults_it));
++defs_it;
++defaults_it;

Some files were not shown because too many files have changed in this diff Show More