mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Remove unneeded std::make_optional (#141567)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/141567 Approved by: https://github.com/albanD
This commit is contained in:
@ -17,7 +17,7 @@ namespace at {
|
||||
/// Return the Device of a Tensor, if the Tensor is defined.
|
||||
inline std::optional<Device> device_of(const Tensor& t) {
|
||||
if (t.defined()) {
|
||||
return std::make_optional(t.device());
|
||||
return t.device();
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
@ -531,7 +531,7 @@ Tensor to_functional_tensor(const Tensor& tensor) {
|
||||
}
|
||||
std::optional<Tensor> to_functional_tensor(const std::optional<Tensor>& tensor) {
|
||||
if (tensor.has_value()) {
|
||||
return std::make_optional<Tensor>(to_functional_tensor(*tensor));
|
||||
return to_functional_tensor(*tensor);
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
@ -569,7 +569,7 @@ Tensor from_functional_tensor(const Tensor& tensor, bool assert_functional) {
|
||||
}
|
||||
std::optional<Tensor> from_functional_tensor(const std::optional<Tensor>& t, bool assert_functional) {
|
||||
if (t.has_value()) {
|
||||
return std::make_optional<Tensor>(from_functional_tensor(*t, assert_functional));
|
||||
return from_functional_tensor(*t, assert_functional);
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
@ -239,7 +239,7 @@ NestedTensorImpl::NestedTensorImpl(
|
||||
std::optional<int64_t> NestedTensorImpl::opt_size(int64_t d) const {
|
||||
if (C10_UNLIKELY(!opt_sizes_.has_value())) {
|
||||
// Cache the metadata to avoid recomputing it each time.
|
||||
opt_sizes_ = std::make_optional(construct_opt_sizes(nested_sizes_));
|
||||
opt_sizes_ = construct_opt_sizes(nested_sizes_);
|
||||
}
|
||||
d = at::maybe_wrap_dim(d, dim(), false);
|
||||
if ((*opt_sizes_)[d] == -1) {
|
||||
|
@ -171,7 +171,7 @@ TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef sha
|
||||
// This will bypass all shape checking in the TensorIterator. Kernels which call this method
|
||||
// are expected to check shapes before calling `add_owned_input` or `add_owned_output`.
|
||||
TORCH_CHECK(!resize_outputs_, "resize_outputs() must be called before declare_static_shape(...)")
|
||||
static_shape_ = std::make_optional(DimVector(shape));
|
||||
static_shape_ = DimVector(shape);
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -30,7 +30,7 @@ struct OperatorName final {
|
||||
if (pos == std::string::npos) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return std::make_optional(std::string_view(name.data(), pos));
|
||||
return std::string_view(name.data(), pos);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -187,7 +187,7 @@ static std::tuple<Tensor, std::optional<int64_t>> logspace_Tensor_Tensor_batch_r
|
||||
std::optional<at::Layout> layout,
|
||||
std::optional<at::Device> device,
|
||||
std::optional<bool> pin_memory){
|
||||
return linspace_logspace_batch_rule_helper(start, start_bdim, end, end_bdim, steps, std::make_optional(base), dtype, layout, device, pin_memory);
|
||||
return linspace_logspace_batch_rule_helper(start, start_bdim, end, end_bdim, steps, base, dtype, layout, device, pin_memory);
|
||||
}
|
||||
|
||||
static std::tuple<Tensor, std::optional<int64_t>> logspace_Tensor_Scalar_batch_rule(
|
||||
@ -201,7 +201,7 @@ static std::tuple<Tensor, std::optional<int64_t>> logspace_Tensor_Scalar_batch_r
|
||||
std::optional<bool> pin_memory){
|
||||
|
||||
auto end_t = at::native::wrapped_scalar_tensor(end, start.device());
|
||||
return linspace_logspace_batch_rule_helper(start, start_bdim, end_t, std::nullopt, steps, std::make_optional(base), dtype, layout, device, pin_memory);
|
||||
return linspace_logspace_batch_rule_helper(start, start_bdim, end_t, std::nullopt, steps, base, dtype, layout, device, pin_memory);
|
||||
}
|
||||
|
||||
static std::tuple<Tensor, std::optional<int64_t>> logspace_Scalar_Tensor_batch_rule(
|
||||
@ -215,7 +215,7 @@ static std::tuple<Tensor, std::optional<int64_t>> logspace_Scalar_Tensor_batch_r
|
||||
std::optional<bool> pin_memory){
|
||||
|
||||
auto start_t = at::native::wrapped_scalar_tensor(start, end.device());
|
||||
return linspace_logspace_batch_rule_helper(start_t, std::nullopt, end, end_bdim, steps, std::make_optional(base), dtype, layout, device, pin_memory);
|
||||
return linspace_logspace_batch_rule_helper(start_t, std::nullopt, end, end_bdim, steps, base, dtype, layout, device, pin_memory);
|
||||
}
|
||||
|
||||
static bool _has_same_storage_numel_batch_rule(const Tensor& a, const Tensor& b) {
|
||||
|
@ -337,7 +337,7 @@ struct OptionalHIPStreamGuardMasqueradingAsCUDA {
|
||||
std::optional<HIPStreamMasqueradingAsCUDA> original_stream() const {
|
||||
auto r = guard_.original_stream();
|
||||
if (r.has_value()) {
|
||||
return std::make_optional(HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value()));
|
||||
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value());
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
@ -346,7 +346,7 @@ struct OptionalHIPStreamGuardMasqueradingAsCUDA {
|
||||
std::optional<HIPStreamMasqueradingAsCUDA> current_stream() const {
|
||||
auto r = guard_.current_stream();
|
||||
if (r.has_value()) {
|
||||
return std::make_optional(HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value()));
|
||||
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value());
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
@ -1981,7 +1981,7 @@ std::tuple<Tensor, Tensor> var_mean(
|
||||
const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
|
||||
return at::var_mean(
|
||||
self, /*dim=*/at::OptionalIntArrayRef(dim),
|
||||
/*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0),
|
||||
/*correction=*/Scalar(unbiased ? 1 : 0),
|
||||
keepdim);
|
||||
}
|
||||
|
||||
@ -1989,20 +1989,20 @@ std::tuple<Tensor, Tensor> std_mean(
|
||||
const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
|
||||
return at::std_mean(
|
||||
self, /*dim=*/at::OptionalIntArrayRef(dim),
|
||||
/*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0),
|
||||
/*correction=*/Scalar(unbiased ? 1 : 0),
|
||||
keepdim);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> std_mean(const Tensor& self, bool unbiased) {
|
||||
return at::std_mean(
|
||||
self, /*dim=*/std::nullopt,
|
||||
/*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0));
|
||||
/*correction=*/Scalar(unbiased ? 1 : 0));
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> var_mean(const Tensor& self, bool unbiased) {
|
||||
return at::var_mean(
|
||||
self, /*dim=*/std::nullopt,
|
||||
/*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0));
|
||||
/*correction=*/Scalar(unbiased ? 1 : 0));
|
||||
}
|
||||
std::tuple<Tensor&, Tensor&> var_mean_out(
|
||||
Tensor& result1, Tensor& result2, const Tensor& self, IntArrayRef dim,
|
||||
@ -2037,36 +2037,36 @@ std::tuple<Tensor, Tensor> std_mean(
|
||||
Tensor var(const Tensor& self, bool unbiased) {
|
||||
return at::var(
|
||||
self, /*dim=*/std::nullopt,
|
||||
/*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0));
|
||||
/*correction=*/Scalar(unbiased ? 1 : 0));
|
||||
}
|
||||
|
||||
Tensor var(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
|
||||
return at::var(
|
||||
self, /*dim=*/at::OptionalIntArrayRef(dim),
|
||||
/*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0),
|
||||
/*correction=*/Scalar(unbiased ? 1 : 0),
|
||||
keepdim);
|
||||
}
|
||||
|
||||
Tensor& var_out(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, Tensor& result) {
|
||||
return at::var_out(
|
||||
result, self, /*dim=*/at::OptionalIntArrayRef(dim),
|
||||
/*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0),
|
||||
/*correction=*/Scalar(unbiased ? 1 : 0),
|
||||
keepdim);
|
||||
}
|
||||
|
||||
Tensor std(const Tensor& self, bool unbiased) {
|
||||
return at::std(
|
||||
self, /*dim=*/std::nullopt, /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0));
|
||||
self, /*dim=*/std::nullopt, /*correction=*/Scalar(unbiased ? 1 : 0));
|
||||
}
|
||||
|
||||
Tensor std(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
|
||||
return at::std(self, dim,
|
||||
/*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0), keepdim);
|
||||
/*correction=*/Scalar(unbiased ? 1 : 0), keepdim);
|
||||
}
|
||||
|
||||
Tensor& std_out(const Tensor& self, at::OptionalIntArrayRef opt_dim, bool unbiased, bool keepdim, Tensor& result) {
|
||||
return at::std_out(result, self, opt_dim,
|
||||
/*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0), keepdim);
|
||||
/*correction=*/Scalar(unbiased ? 1 : 0), keepdim);
|
||||
}
|
||||
|
||||
Tensor std(const Tensor& self, at::OptionalIntArrayRef dim,
|
||||
|
@ -167,7 +167,7 @@ inline void setStrided(
|
||||
|
||||
/* storage offset */
|
||||
TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
|
||||
self_->set_sizes_and_strides(size, stride, std::make_optional(storage_offset));
|
||||
self_->set_sizes_and_strides(size, stride, storage_offset);
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -2013,8 +2013,8 @@ Tensor to_sparse_bsc(const Tensor& self, IntArrayRef blocksize, std::optional<in
|
||||
|
||||
Tensor to_meta(const Tensor& tensor) {
|
||||
auto out = at::native::empty_strided_meta_symint(tensor.sym_sizes(), tensor.sym_strides(), \
|
||||
/*dtype=*/std::make_optional(tensor.scalar_type()), /*layout=*/std::make_optional(tensor.layout()), \
|
||||
/*device=*/std::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/std::nullopt);
|
||||
/*dtype=*/tensor.scalar_type(), /*layout=*/tensor.layout(), \
|
||||
/*device=*/c10::Device(c10::kMeta), /*pin_memory=*/std::nullopt);
|
||||
// needs to handle wrapped numbers, so dtype promotion works properly.
|
||||
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
|
||||
out.unsafeGetTensorImpl()->set_wrapped_number(true);
|
||||
|
@ -1179,7 +1179,7 @@ inline void setStridedUnchecked(
|
||||
ArrayRef<T> stride,
|
||||
T&& storage_offset) {
|
||||
auto* self_ = self.unsafeGetTensorImpl();
|
||||
self_->set_sizes_and_strides(size, stride, std::make_optional(std::forward<T>(storage_offset)));
|
||||
self_->set_sizes_and_strides(size, stride, std::forward<T>(storage_offset));
|
||||
}
|
||||
|
||||
Tensor as_strided_tensorimpl_meta_symint(const Tensor& self, SymIntArrayRef sym_size, SymIntArrayRef sym_stride, std::optional<c10::SymInt> sym_storage_offset_) {
|
||||
|
@ -82,7 +82,7 @@ ContextConv create(
|
||||
|
||||
return ContextConv{
|
||||
std::move(packed_weight),
|
||||
bias.has_value() ? std::make_optional(*bias) : std::nullopt,
|
||||
bias,
|
||||
{padding_expanded.begin(), padding_expanded.end()},
|
||||
{stride_expanded.begin(), stride_expanded.end()},
|
||||
{dilation_expanded.begin(), dilation_expanded.end()},
|
||||
@ -276,6 +276,6 @@ Tensor conv_run(
|
||||
return op_context->run(input);
|
||||
}
|
||||
|
||||
} // namespace at
|
||||
} // namespace at::native::mkldnn::internal::convolution
|
||||
|
||||
#endif // AT_MKLDNN_ENABLED()
|
||||
|
@ -1159,8 +1159,7 @@ Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, st
|
||||
return inputTensorPNorm;
|
||||
};
|
||||
|
||||
std::optional<IntArrayRef> inputBroadcastSize =
|
||||
std::make_optional(makeArrayRef(tensor1_view.data(), tensor1_view.size()));
|
||||
IntArrayRef inputBroadcastSize = makeArrayRef(tensor1_view.data(), tensor1_view.size());
|
||||
impl_func_norm_mps(x1,
|
||||
x2,
|
||||
OptionalScalarRef(p),
|
||||
|
@ -293,7 +293,7 @@ std::tuple<Tensor, Tensor, Tensor> unique_dim_consecutive_mps(const Tensor& self
|
||||
int64_t dim,
|
||||
const bool return_inverse,
|
||||
const bool return_counts) {
|
||||
return _unique_impl_mps(self, return_inverse, return_counts, true, std::make_optional((int64_t)dim));
|
||||
return _unique_impl_mps(self, return_inverse, return_counts, true, dim);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> _unique2_mps(const Tensor& self,
|
||||
|
@ -56,13 +56,13 @@ inline bool has_internal_overlap_helper(const at::Tensor t) {
|
||||
inline Tensor to_meta(const Tensor& t) {
|
||||
if (!t.defined()) return t;
|
||||
return at::native::empty_strided_meta_symint(t.sym_sizes(), t.sym_strides(),
|
||||
/*dtype=*/std::make_optional(t.scalar_type()), /*layout=*/std::make_optional(t.layout()),
|
||||
/*device=*/std::make_optional(c10::Device(kMeta)), /*pin_memory=*/std::nullopt);
|
||||
/*dtype=*/t.scalar_type(), /*layout=*/t.layout(),
|
||||
/*device=*/c10::Device(kMeta), /*pin_memory=*/std::nullopt);
|
||||
}
|
||||
|
||||
inline std::optional<Tensor> to_meta(const std::optional<Tensor>& t) {
|
||||
if (t.has_value()) {
|
||||
return std::make_optional<Tensor>(to_meta(*t));
|
||||
return to_meta(*t);
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
@ -68,7 +68,7 @@ class C10_API SymBool {
|
||||
|
||||
std::optional<bool> maybe_as_bool() const {
|
||||
if (!is_heap_allocated()) {
|
||||
return std::make_optional(data_);
|
||||
return data_;
|
||||
}
|
||||
return toSymNodeImplUnowned()->constant_bool();
|
||||
}
|
||||
|
@ -232,7 +232,7 @@ class C10_API SymInt {
|
||||
|
||||
std::optional<int64_t> maybe_as_int() const {
|
||||
if (!is_heap_allocated()) {
|
||||
return std::make_optional(data_);
|
||||
return data_;
|
||||
}
|
||||
auto* node = toSymNodeImplUnowned();
|
||||
if (auto c = node->constant_int()) {
|
||||
|
@ -69,9 +69,8 @@ normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) {
|
||||
for (const auto& s : strides) {
|
||||
stride_nodes.emplace_back(s.wrap_node(base));
|
||||
}
|
||||
return std::make_optional(
|
||||
std::tuple<SymNode, std::vector<SymNode>, std::vector<SymNode>>(
|
||||
std::move(base), std::move(size_nodes), std::move(stride_nodes)));
|
||||
return std::tuple<SymNode, std::vector<SymNode>, std::vector<SymNode>>(
|
||||
std::move(base), std::move(size_nodes), std::move(stride_nodes));
|
||||
}
|
||||
|
||||
// Special treatment because of numel
|
||||
|
@ -438,7 +438,7 @@ struct C10_API TensorOptions {
|
||||
std::optional<MemoryFormat> optional_memory_format) const noexcept {
|
||||
TensorOptions merged = *this;
|
||||
if (optional_memory_format.has_value()) {
|
||||
merged.set_memory_format(*optional_memory_format);
|
||||
merged.set_memory_format(optional_memory_format);
|
||||
}
|
||||
return merged;
|
||||
}
|
||||
|
@ -112,7 +112,7 @@ struct C10_API PyObjectSlot {
|
||||
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return std::make_optional(_unchecked_untagged_pyobj());
|
||||
return _unchecked_untagged_pyobj();
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
|
@ -246,7 +246,7 @@ struct OptionalCUDAStreamGuard {
|
||||
std::optional<CUDAStream> original_stream() const {
|
||||
auto r = guard_.original_stream();
|
||||
if (r.has_value()) {
|
||||
return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value()));
|
||||
return CUDAStream(CUDAStream::UNCHECKED, r.value());
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
@ -258,7 +258,7 @@ struct OptionalCUDAStreamGuard {
|
||||
std::optional<CUDAStream> current_stream() const {
|
||||
auto r = guard_.current_stream();
|
||||
if (r.has_value()) {
|
||||
return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value()));
|
||||
return CUDAStream(CUDAStream::UNCHECKED, r.value());
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
@ -36,7 +36,6 @@ TEST(OptionalDeviceGuard, ResetDeviceDifferentDeviceType) {
|
||||
g.reset_device(Device(DeviceType::HIP, 2), &hip_impl);
|
||||
ASSERT_EQ(FakeGuardImpl<DeviceType::CUDA>::getDeviceIndex(), 0);
|
||||
ASSERT_EQ(FakeGuardImpl<DeviceType::HIP>::getDeviceIndex(), 2);
|
||||
ASSERT_EQ(g.current_device(), std::make_optional(Device(DeviceType::HIP, 2)));
|
||||
ASSERT_EQ(
|
||||
g.original_device(), std::make_optional(Device(DeviceType::HIP, 0)));
|
||||
ASSERT_EQ(g.current_device(), Device(DeviceType::HIP, 2));
|
||||
ASSERT_EQ(g.original_device(), Device(DeviceType::HIP, 0));
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ TEST(InlineDeviceGuard, Constructor) {
|
||||
/*
|
||||
{
|
||||
// Optional constructor
|
||||
TestGuard g(std::make_optional(dev(i)));
|
||||
TestGuard g(dev(i));
|
||||
test_body(g);
|
||||
}
|
||||
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), init_i);
|
||||
@ -136,7 +136,7 @@ TEST(InlineOptionalDeviceGuard, Constructor) {
|
||||
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), init_i);
|
||||
{
|
||||
// Optional constructor
|
||||
MaybeTestGuard g(std::make_optional(dev(i)));
|
||||
MaybeTestGuard g(dev(i));
|
||||
test_body(g);
|
||||
}
|
||||
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), init_i);
|
||||
@ -170,12 +170,12 @@ TEST(InlineOptionalDeviceGuard, SetDevice) {
|
||||
MaybeTestGuard g;
|
||||
DeviceIndex i = 1;
|
||||
g.set_device(dev(i));
|
||||
ASSERT_EQ(g.original_device(), std::make_optional(dev(init_i)));
|
||||
ASSERT_EQ(g.current_device(), std::make_optional(dev(i)));
|
||||
ASSERT_EQ(g.original_device(), dev(init_i));
|
||||
ASSERT_EQ(g.current_device(), dev(i));
|
||||
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i);
|
||||
g.set_device(dev(i));
|
||||
ASSERT_EQ(g.original_device(), std::make_optional(dev(init_i)));
|
||||
ASSERT_EQ(g.current_device(), std::make_optional(dev(i)));
|
||||
ASSERT_EQ(g.original_device(), dev(init_i));
|
||||
ASSERT_EQ(g.current_device(), dev(i));
|
||||
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i);
|
||||
}
|
||||
|
||||
@ -185,11 +185,11 @@ TEST(InlineOptionalDeviceGuard, SetIndex) {
|
||||
DeviceIndex i = 1;
|
||||
MaybeTestGuard g;
|
||||
g.set_index(i);
|
||||
ASSERT_EQ(g.original_device(), std::make_optional(dev(init_i)));
|
||||
ASSERT_EQ(g.current_device(), std::make_optional(dev(i)));
|
||||
ASSERT_EQ(g.original_device(), dev(init_i));
|
||||
ASSERT_EQ(g.current_device(), dev(i));
|
||||
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i);
|
||||
g.set_index(i);
|
||||
ASSERT_EQ(g.original_device(), std::make_optional(dev(init_i)));
|
||||
ASSERT_EQ(g.current_device(), std::make_optional(dev(i)));
|
||||
ASSERT_EQ(g.original_device(), dev(init_i));
|
||||
ASSERT_EQ(g.current_device(), dev(i));
|
||||
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i);
|
||||
}
|
||||
|
@ -1,7 +1,11 @@
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/core/Stream.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/core/impl/FakeGuardImpl.h>
|
||||
#include <c10/core/impl/InlineStreamGuard.h>
|
||||
#include <vector>
|
||||
|
||||
using namespace c10;
|
||||
using namespace c10::impl;
|
||||
@ -10,7 +14,7 @@ constexpr auto TestDeviceType = DeviceType::CUDA;
|
||||
using TestGuardImpl = FakeGuardImpl<TestDeviceType>;
|
||||
|
||||
static Device dev(DeviceIndex index) {
|
||||
return Device(TestDeviceType, index);
|
||||
return Device{TestDeviceType, index};
|
||||
}
|
||||
|
||||
static Stream stream(DeviceIndex index, StreamId sid) {
|
||||
@ -109,19 +113,19 @@ TEST(InlineOptionalStreamGuard, Constructor) {
|
||||
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 1);
|
||||
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 2);
|
||||
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0);
|
||||
ASSERT_EQ(g.original_stream(), std::make_optional(stream(0, 0)));
|
||||
ASSERT_EQ(g.current_stream(), std::make_optional(stream(1, 2)));
|
||||
ASSERT_EQ(g.original_stream(), stream(0, 0));
|
||||
ASSERT_EQ(g.current_stream(), stream(1, 2));
|
||||
}
|
||||
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 0);
|
||||
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 0);
|
||||
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0);
|
||||
{
|
||||
OptionalTestGuard g(std::make_optional(stream(1, 2)));
|
||||
OptionalTestGuard g(stream(1, 2));
|
||||
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 1);
|
||||
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 2);
|
||||
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0);
|
||||
ASSERT_EQ(g.original_stream(), std::make_optional(stream(0, 0)));
|
||||
ASSERT_EQ(g.current_stream(), std::make_optional(stream(1, 2)));
|
||||
ASSERT_EQ(g.original_stream(), stream(0, 0));
|
||||
ASSERT_EQ(g.current_stream(), stream(1, 2));
|
||||
}
|
||||
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 0);
|
||||
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 0);
|
||||
@ -146,8 +150,8 @@ TEST(InlineOptionalStreamGuard, ResetStreamSameDevice) {
|
||||
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 1);
|
||||
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 3);
|
||||
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0);
|
||||
ASSERT_EQ(g.original_stream(), std::make_optional(stream(0, 0)));
|
||||
ASSERT_EQ(g.current_stream(), std::make_optional(stream(1, 3)));
|
||||
ASSERT_EQ(g.original_stream(), stream(0, 0));
|
||||
ASSERT_EQ(g.current_stream(), stream(1, 3));
|
||||
}
|
||||
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 0);
|
||||
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 0);
|
||||
@ -164,8 +168,8 @@ TEST(InlineOptionalStreamGuard, ResetStreamDifferentDevice) {
|
||||
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(2), 3);
|
||||
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 0);
|
||||
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0);
|
||||
ASSERT_EQ(g.original_stream(), std::make_optional(stream(0, 0)));
|
||||
ASSERT_EQ(g.current_stream(), std::make_optional(stream(2, 3)));
|
||||
ASSERT_EQ(g.original_stream(), stream(0, 0));
|
||||
ASSERT_EQ(g.current_stream(), stream(2, 3));
|
||||
}
|
||||
ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 0);
|
||||
ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(2), 0);
|
||||
|
@ -154,7 +154,7 @@ static bool THPStorage_isPreservable(THPStorage* self) {
|
||||
|
||||
if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
|
||||
getPyInterpreter(), /*ignore_hermetic_tls=*/true) !=
|
||||
std::make_optional((PyObject*)self)) {
|
||||
(PyObject*)self) {
|
||||
return false;
|
||||
}
|
||||
if (storage.use_count() <= 1) {
|
||||
|
@ -344,7 +344,7 @@ bool isResurrectable(THPVariable* self) {
|
||||
// Check if this is hermetic. If it is, no resurrection.
|
||||
if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
|
||||
getPyInterpreter(), /*ignore_hermetic_tls=*/false) !=
|
||||
std::make_optional((PyObject*)self)) {
|
||||
(PyObject*)self) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
@ -452,7 +452,7 @@ static int THPVariable_subclass_clear(THPVariable* self) {
|
||||
if (!self->cdata.unsafeIsBorrowed() &&
|
||||
tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
|
||||
getPyInterpreter(), /*ignore_hermetic_tls=*/false) ==
|
||||
std::make_optional((PyObject*)self)) {
|
||||
(PyObject*)self) {
|
||||
// TODO: empirically, on OS X this assert appears to be untrue
|
||||
// In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn
|
||||
// distributed/rpc/test_process_group_agent.py
|
||||
|
@ -57,7 +57,7 @@ class TensorpipeCpuConverter : public TensorpipeDeviceTypeConverter {
|
||||
|
||||
message.tensors.push_back(std::move(tensor));
|
||||
|
||||
return std::make_optional(std::move(storageData));
|
||||
return storageData;
|
||||
} else {
|
||||
tensorpipe::CpuBuffer buffer;
|
||||
buffer.ptr = static_cast<char*>(storage.mutable_data());
|
||||
|
@ -2104,20 +2104,19 @@ std::vector<Value*> inlineCallTo(
|
||||
if (to_replace->kind() == prim::CallMethod) {
|
||||
auto class_type_ptr = to_replace->input(0)->type()->cast<c10::ClassType>();
|
||||
if (to_replace->input(0)->node()->kind() == prim::GetAttr) {
|
||||
module_instance_info = std::make_optional(ModuleInstanceInfo(
|
||||
class_type_ptr, to_replace->input(0)->node()->s(attr::name)));
|
||||
module_instance_info = ModuleInstanceInfo(
|
||||
class_type_ptr, to_replace->input(0)->node()->s(attr::name));
|
||||
} else if (
|
||||
!to_replace->owningGraph()->inputs().empty() &&
|
||||
to_replace->input(0) == to_replace->owningGraph()->inputs()[0]) {
|
||||
// This CallMethod must correspond to method of the same object
|
||||
// to which this graph belongs.
|
||||
module_instance_info =
|
||||
std::make_optional(ModuleInstanceInfo(class_type_ptr, "SELF"));
|
||||
module_instance_info = ModuleInstanceInfo(class_type_ptr, "SELF");
|
||||
} else {
|
||||
// Not sure if it is possible to come here ever.
|
||||
// TODO: Remove this else. Or add assert
|
||||
module_instance_info = std::make_optional(
|
||||
ModuleInstanceInfo(class_type_ptr, "INSTANCE_NAME_UNKNOWN"));
|
||||
module_instance_info =
|
||||
ModuleInstanceInfo(class_type_ptr, "INSTANCE_NAME_UNKNOWN");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -88,7 +88,7 @@ void Module::unsafeCopyMethod(
|
||||
std::optional<Method> Module::find_method(const std::string& basename) const {
|
||||
for (const auto& fn : cu_->methods()) {
|
||||
if (fn->name() == basename) {
|
||||
return std::make_optional<Method>(Method(this, fn.get()));
|
||||
return Method(this, fn.get());
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
@ -138,7 +138,7 @@ void slot_named_params_recurse(
|
||||
auto slots = obj->slots();
|
||||
size_t nslots = slots.size();
|
||||
for (const auto i : c10::irange(nslots)) {
|
||||
auto slot = slots[i];
|
||||
const auto& slot = slots[i];
|
||||
std::string name = parent_name.empty() ? parent_name : parent_name + ".";
|
||||
name += obj->type()->getAttributeName(i);
|
||||
// TODO: Fix this filter. Requires_grad is not the appropriate
|
||||
|
@ -35,8 +35,7 @@ std::optional<size_t> ConstantValueMap::GetRank(const std::string& tensorName) {
|
||||
}
|
||||
|
||||
void ConstantValueMap::SetAllGraphInputsStatic(bool all_static) {
|
||||
ConstantValueMap::getInstance().allGraphInputsStatic =
|
||||
std::make_optional(all_static);
|
||||
ConstantValueMap::getInstance().allGraphInputsStatic = all_static;
|
||||
}
|
||||
|
||||
std::optional<bool> ConstantValueMap::GetAllGraphInputsStatic() {
|
||||
|
@ -1695,14 +1695,12 @@ void initJITBindings(PyObject* module) {
|
||||
c10::DispatchKey dk_,
|
||||
const py::args& args,
|
||||
const py::kwargs& kwargs) {
|
||||
std::optional<c10::DispatchKey> dk =
|
||||
std::make_optional(dk_);
|
||||
ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
|
||||
return _get_operation_for_overload_or_packet(
|
||||
{op}, symbol, args, kwargs, /*is_overload*/ true, dk);
|
||||
{op}, symbol, args, kwargs, /*is_overload*/ true, dk_);
|
||||
});
|
||||
return std::make_optional(
|
||||
py::make_tuple(func, func_dk, py::cast(op->getTags().vec())));
|
||||
return py::make_tuple(
|
||||
func, func_dk, py::cast(op->getTags().vec()));
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
|
@ -1044,16 +1044,13 @@ REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator {
|
||||
LogAndDumpSchema(n);
|
||||
return nullptr;
|
||||
}
|
||||
std::optional<float> clamp = std::nullopt;
|
||||
std::optional<double> clamp = std::nullopt;
|
||||
if (n->inputs()[1]->node()->kind() == prim::Constant) {
|
||||
auto clamp_d = toIValue(n->inputs()[1])->toOptional<double>();
|
||||
clamp = clamp_d
|
||||
? std::make_optional<float>(static_cast<float>(clamp_d.value()))
|
||||
: std::nullopt;
|
||||
clamp = clamp_d;
|
||||
}
|
||||
auto te = clamp ? createLogit() : nullptr;
|
||||
float clamp_value = clamp ? *clamp : 0.0f;
|
||||
return [te, clamp_value](ProcessedNode* p_node) {
|
||||
return [te, clamp](ProcessedNode* p_node) {
|
||||
const auto& in0_t = p_node->Input(0).toTensor();
|
||||
if (p_node->Output(0).isNone()) {
|
||||
p_node->Output(0) = create_empty_from(in0_t);
|
||||
@ -1068,7 +1065,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator {
|
||||
}
|
||||
at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
|
||||
int64_t nn = in0_t.numel();
|
||||
float c = clamp_value;
|
||||
float c = clamp.value() ? static_cast<float>(clamp.value()) : 0;
|
||||
te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn, &c});
|
||||
};
|
||||
})
|
||||
|
@ -77,7 +77,7 @@ std::optional<BackendDevice> GetBackendDevice(const at::Tensor& tensor) {
|
||||
std::optional<BackendDevice> GetBackendDevice(
|
||||
const std::optional<c10::Device>& device) {
|
||||
if (device) {
|
||||
return std::make_optional(atenDeviceToBackendDevice(*device));
|
||||
return atenDeviceToBackendDevice(*device);
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
@ -1274,16 +1274,16 @@ std::vector<Shape> compute_shape_select_scatter(
|
||||
auto self_meta = at::native::empty_strided_meta_symint(
|
||||
self.sym_sizes(),
|
||||
self.sym_strides(),
|
||||
/*dtype=*/::std::make_optional(self.scalar_type()),
|
||||
/*layout=*/::std::make_optional(self.layout()),
|
||||
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
|
||||
/*dtype=*/self.scalar_type(),
|
||||
/*layout=*/self.layout(),
|
||||
/*device=*/c10::Device(c10::kMeta),
|
||||
/*pin_memory=*/::std::nullopt);
|
||||
auto src_meta = at::native::empty_strided_meta_symint(
|
||||
src.sym_sizes(),
|
||||
src.sym_strides(),
|
||||
/*dtype=*/::std::make_optional(src.scalar_type()),
|
||||
/*layout=*/::std::make_optional(src.layout()),
|
||||
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
|
||||
/*dtype=*/src.scalar_type(),
|
||||
/*layout=*/src.layout(),
|
||||
/*device=*/c10::Device(c10::kMeta),
|
||||
/*pin_memory=*/::std::nullopt);
|
||||
auto out_meta = at::compositeexplicitautogradnonfunctional::select_scatter(
|
||||
self_meta, src_meta, dim, index);
|
||||
@ -1299,16 +1299,16 @@ std::vector<Shape> compute_shape_diagonal_scatter(
|
||||
auto self_meta = at::native::empty_strided_meta_symint(
|
||||
self.sym_sizes(),
|
||||
self.sym_strides(),
|
||||
/*dtype=*/::std::make_optional(self.scalar_type()),
|
||||
/*layout=*/::std::make_optional(self.layout()),
|
||||
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
|
||||
/*dtype=*/self.scalar_type(),
|
||||
/*layout=*/self.layout(),
|
||||
/*device=*/c10::Device(c10::kMeta),
|
||||
/*pin_memory=*/::std::nullopt);
|
||||
auto src_meta = at::native::empty_strided_meta_symint(
|
||||
src.sym_sizes(),
|
||||
src.sym_strides(),
|
||||
/*dtype=*/::std::make_optional(src.scalar_type()),
|
||||
/*layout=*/::std::make_optional(src.layout()),
|
||||
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
|
||||
/*dtype=*/src.scalar_type(),
|
||||
/*layout=*/src.layout(),
|
||||
/*device=*/c10::Device(c10::kMeta),
|
||||
/*pin_memory=*/::std::nullopt);
|
||||
auto out_meta = at::compositeexplicitautogradnonfunctional::diagonal_scatter(
|
||||
self_meta, src_meta, offset, dim1, dim2);
|
||||
@ -1325,16 +1325,16 @@ std::vector<Shape> compute_shape_slice_scatter_symint(
|
||||
auto self_meta = at::native::empty_strided_meta_symint(
|
||||
self.sym_sizes(),
|
||||
self.sym_strides(),
|
||||
/*dtype=*/::std::make_optional(self.scalar_type()),
|
||||
/*layout=*/::std::make_optional(self.layout()),
|
||||
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
|
||||
/*dtype=*/self.scalar_type(),
|
||||
/*layout=*/self.layout(),
|
||||
/*device=*/c10::Device(c10::kMeta),
|
||||
/*pin_memory=*/::std::nullopt);
|
||||
auto src_meta = at::native::empty_strided_meta_symint(
|
||||
src.sym_sizes(),
|
||||
src.sym_strides(),
|
||||
/*dtype=*/::std::make_optional(src.scalar_type()),
|
||||
/*layout=*/::std::make_optional(src.layout()),
|
||||
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
|
||||
/*dtype=*/src.scalar_type(),
|
||||
/*layout=*/src.layout(),
|
||||
/*device=*/c10::Device(c10::kMeta),
|
||||
/*pin_memory=*/::std::nullopt);
|
||||
auto out_meta =
|
||||
at::compositeexplicitautogradnonfunctional::slice_scatter_symint(
|
||||
@ -1356,16 +1356,16 @@ std::vector<Shape> compute_shape_as_strided_scatter_symint(
|
||||
auto self_meta = at::native::empty_strided_meta_symint(
|
||||
self.sym_sizes(),
|
||||
self.sym_strides(),
|
||||
/*dtype=*/::std::make_optional(self.scalar_type()),
|
||||
/*layout=*/::std::make_optional(self.layout()),
|
||||
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
|
||||
/*dtype=*/self.scalar_type(),
|
||||
/*layout=*/self.layout(),
|
||||
/*device=*/c10::Device(c10::kMeta),
|
||||
/*pin_memory=*/::std::nullopt);
|
||||
auto src_meta = at::native::empty_strided_meta_symint(
|
||||
src.sym_sizes(),
|
||||
src.sym_strides(),
|
||||
/*dtype=*/::std::make_optional(src.scalar_type()),
|
||||
/*layout=*/::std::make_optional(src.layout()),
|
||||
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
|
||||
/*dtype=*/src.scalar_type(),
|
||||
/*layout=*/src.layout(),
|
||||
/*device=*/c10::Device(c10::kMeta),
|
||||
/*pin_memory=*/::std::nullopt);
|
||||
auto out_meta =
|
||||
at::compositeexplicitautogradnonfunctional::as_strided_scatter_symint(
|
||||
|
@ -143,8 +143,8 @@ at::Tensor to_meta(const at::Tensor& tensor) {
|
||||
// undefined tensors can't be converted to the meta device, since they don't have sizes/strides
|
||||
if (!tensor.defined()) return tensor;
|
||||
auto out = at::native::empty_strided_meta_symint(tensor.sym_sizes(), tensor.sym_strides(), \
|
||||
/*dtype=*/std::make_optional(tensor.scalar_type()), /*layout=*/std::make_optional(tensor.layout()), \
|
||||
/*device=*/std::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/std::nullopt);
|
||||
/*dtype=*/tensor.scalar_type(), /*layout=*/tensor.layout(), \
|
||||
/*device=*/c10::Device(c10::kMeta), /*pin_memory=*/std::nullopt);
|
||||
// needs to handle wrapped numbers, so dtype promotion works properly.
|
||||
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
|
||||
out.unsafeGetTensorImpl()->set_wrapped_number(true);
|
||||
|
Reference in New Issue
Block a user