Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[PyTorch] Redirect c10::optional to std::optional (#101995)

We have C++17 now! I am intentionally dropping the `c10::optional<c10::ArrayRef>` size optimization. It was intended to improve dispatch, but thanks to D34602980 / #70864 we don't use `optional<ArrayRef>` in function arguments anymore anyway.

Differential Revision: [D46079028](https://our.internmc.facebook.com/intern/diff/D46079028/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101995
Approved by: https://github.com/malfet, https://github.com/Skylion007, https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 013675ff59
Commit: 165f4f6ccf
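The change itself is mostly mechanical: c10::optional and its helpers stop being a hand-rolled implementation and become aliases for the std:: equivalents. A minimal sketch of the shape of the redirect (illustrative only; the real c10/util/Optional.h, whose full diff is suppressed further down, carries additional compatibility glue):

#include <optional>

namespace c10 {
// Each c10 name simply re-exports the standard one, so existing call sites
// such as c10::optional<int>, c10::nullopt, and c10::make_optional(x) keep compiling.
using std::bad_optional_access;
using std::make_optional;
using std::nullopt;
using std::nullopt_t;
using std::optional;
} // namespace c10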
@@ -4,8 +4,6 @@
 // Forward declarations of core ATen types used in dispatch functions
 namespace c10 {

-template<typename T>
-class optional;
 template<typename T>
 class List;
 template<typename T>
@@ -26,6 +26,7 @@ at::Generator GetGeneratorForPrivateuse1(c10::DeviceIndex device_index) {
 "Please register a generator to the PrivateUse1 dispatch key, \
 using the REGISTER_GENERATOR_PRIVATEUSE1 macro.");

+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 return GetGeneratorPrivate().value()(device_index);
 }

@@ -240,6 +240,7 @@ TEST(IOptTensorListRefTest, Boxed_Iterate) {
 for (const auto t : list) {
 EXPECT_EQ(boxed[i].has_value(), t.has_value());
 if (t.has_value()) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 EXPECT_TRUE((*boxed[i]).is_same(*t));
 }
 i++;
@@ -1136,8 +1136,10 @@ TEST(ListTest, canAccessOptionalStringByReference) {
 c10::optional<std::string> str2 = list[2];
 decltype(auto) strRef1 = listRef[1];
 decltype(auto) strRef2 = listRef[2];
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 EXPECT_EQ("two", str1.value());
 EXPECT_FALSE(str2.has_value());
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 EXPECT_EQ("two", strRef1.value().get());
 EXPECT_FALSE(strRef2.has_value());
 }
@@ -31,6 +31,7 @@ constexpr c10::DispatchKeySet after_Python_keyset = c10::DispatchKeySet(c10::Dis
 // This guard assumes that tls_on_entry has a value.
 struct StashTLSOnEntryGuard {
 public:
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 StashTLSOnEntryGuard(): saved_(tls_on_entry.value()) {
 tls_on_entry = c10::nullopt;
 }
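Most hunks in this commit follow the pattern above: a // NOLINTNEXTLINE(bugprone-unchecked-optional-access) comment is added in front of .value() or operator* calls whose surrounding logic already guarantees the optional is engaged. The clang-tidy check presumably starts firing here because it recognizes std::optional but did not know about the old c10 type. A standalone sketch of the two ways such a site can be handled (assumed example, not taken from the diff):

#include <optional>
#include <string>

int length_or_zero(const std::optional<std::string>& s) {
  // Preferred: make the emptiness check visible to the analyzer.
  if (s.has_value()) {
    return static_cast<int>(s->size());
  }
  return 0;
}

int length_checked_elsewhere(const std::optional<std::string>& s) {
  // When the invariant is established elsewhere (e.g. by an earlier TORCH_CHECK),
  // the access is suppressed explicitly instead of restructuring the code:
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  return static_cast<int>(s->size());
}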
@@ -3,10 +3,11 @@
 namespace c10 {
 namespace impl {

-void common_device_check_failure(optional<Device>& common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
+void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
 TORCH_CHECK(false,
 "Expected all tensors to be on the same device, but "
-"found at least two devices, ", common_device.value(), " and ", tensor.device(), "! "
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
+"found at least two devices, ", common_device, " and ", tensor.device(), "! "
 "(when checking argument for argument ", argName, " in method ", methodName, ")");
 }

@@ -87,9 +87,11 @@ std::string ClassType::getForwardPreHookErrorMessage(int pre_hook_idx) const {
 pre_hook_name + "(self, input: Tuple[" + input_types + "])";
 std::string return_string =
 "This error occurred while scripting the forward pre-hook '" +
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 pre_hook_name + "' on module '" + name()->name() +
 "'. If you did not want to script this pre-hook remove it from the "
 "original NN module before scripting. Pre-hooks for module '" +
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 name()->name() + "' are expected to have the following signature: "
 + pre_hook_schema + " with a return type of either 'None'" +
 single_output + " or 'Tuple[" + input_types + "]'.";
@@ -112,6 +114,7 @@ std::string ClassType::getForwardHookErrorMessage(int hook_idx) const {
 input_types + "], output: " + output_types + ")";
 std::string return_string =
 "This error occurred while scripting the forward hook '"
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 + hook_name + "' on module " + name()->name() +
 ". If you did not want to script this hook remove it from" +
 " the original NN module before scripting. This hook was" +
@@ -191,6 +194,7 @@ void ClassType::checkForwardPreHookSchema(
 const FunctionSchema& pre_hook_schema) const {
 const torch::jit::Function* pre_hook = forward_pre_hooks_[pre_hook_idx];
 std::string hook_id =
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 "Pre-hook '" + pre_hook->name() + "' on module '" + name()->name() + "' ";
 std::string pre_hook_err_msg = getForwardPreHookErrorMessage(pre_hook_idx) + "\n";

@@ -287,6 +291,7 @@ void ClassType::checkForwardHookSchema(
 const FunctionSchema& hook_schema) const {
 const torch::jit::Function* hook = forward_hooks_[hook_idx];
 std::string hook_id =
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 "Hook '" + hook->name() + "' on module '" + name()->name() + "' ";
 std::string hook_err_msg = getForwardHookErrorMessage(hook_idx) + "\n";
 // Hooks are expecting three inputs: self, a Tuple containing the non-self
@@ -68,6 +68,7 @@ static std::unordered_map<std::string, at::ClassTypePtr>& customClasses() {

 void registerCustomClass(at::ClassTypePtr class_type) {
 TORCH_INTERNAL_ASSERT(class_type->name());
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 auto name = class_type->name()->qualifiedName();
 TORCH_CHECK(
 !customClasses().count(name),
@@ -96,6 +97,7 @@ const std::unordered_set<std::string> getAllCustomClassesNames() {

 bool isCustomClass(const c10::IValue& v) {
 return v.isObject() && v.toObject()->type()->name() &&
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 getCustomClass(v.toObject()->type()->name()->qualifiedName());
 }

@@ -261,6 +261,7 @@ TypePtr DynamicType::fallback() const {
 std::vector<c10::string_view> fields;
 fields.reserve(arguments_.elems.size());
 for (const auto& elem : arguments_.elems) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 fields.emplace_back(*elem.label);
 }
 return TupleType::createNamed(*name_, fields, fallbacks);
@@ -290,6 +291,7 @@ TypePtr DynamicType::fallback() const {
 case Tag::Storage:
 return StorageType::get();
 case Tag::Var:
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 return VarType::create(*name_);
 case Tag::AnyClass:
 return AnyClassType::get();
@@ -985,6 +985,7 @@ void IValue::reportToTensorTypeError() const {
 }

 std::string ivalue::Object::name() const {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 return type()->name()->qualifiedName();
 }

@@ -56,7 +56,7 @@ CppFunction::~CppFunction() = default;
 Library::Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line)
 : kind_(kind)
 , ns_(ns == "_" ? c10::nullopt : c10::make_optional(std::move(ns)))
-, dispatch_key_(k.value_or(CatchAll) == CatchAll ? c10::nullopt : k)
+, dispatch_key_(k.value_or(CatchAll) == CatchAll ? c10::optional<c10::DispatchKey>() : k)
 , file_(file)
 , line_(line)
 {
@@ -66,6 +66,7 @@ Library::Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, c
 // don't register a library
 registrars_.emplace_back(
 c10::Dispatcher::singleton().registerLibrary(
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 *ns_, debugString(file_, line_)
 )
 );
@@ -195,15 +196,18 @@ at::OperatorName Library::_parseNameForLib(const char* name_str) const {
 // This is a copy paste of Library::_impl
 if (ns_opt.has_value()) {
 // See Note [Redundancy in registration code is OK]
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 TORCH_CHECK(*ns_opt == *ns_,
 IMPL_PRELUDE,
 "Explicitly provided namespace (", *ns_opt, ") in operator name "
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 "does not match namespace of enclosing ", toString(kind_), " block (", *ns_, "). "
 "Move this definition to the ", toString(kind_), " block corresponding to this namespace "
 "(and consider deleting the namespace from your schema string.) ",
 ERROR_CONTEXT
 );
 } else {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 bool b = name.setNamespaceIfNotSet(ns_->c_str());
 TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
 }
@@ -43,7 +43,7 @@
 namespace c10 {
 namespace impl {

-TORCH_API void common_device_check_failure(optional<Device>& common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName);
+TORCH_API void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName);

 inline void check_and_update_common_device(optional<Device>& common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
 // TODO: Remove this once the following issue is addressed:
@@ -58,7 +58,7 @@ inline void check_and_update_common_device(optional<Device>& common_device, cons
 }

 if (C10_UNLIKELY(common_device != tensor.device())) {
-common_device_check_failure(common_device, tensor, methodName, argName);
+common_device_check_failure(*common_device, tensor, methodName, argName);
 }
 }

@@ -76,6 +76,7 @@ std::ostream& operator<<(std::ostream& out, const VaryingShape<T>& vs) {
 out << ", ";
 }
 if (vs[i].has_value()) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 out << vs[i].value();
 } else {
 out << "*";
@@ -281,18 +282,23 @@ TensorTypePtr TensorType::create(
 c10::optional<bool> undefined, bool tensor_contiguity) {
 if(strides.concrete_sizes() && strides.concrete_sizes().has_value()){
 // handles case where strides are set
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 TORCH_INTERNAL_ASSERT(sizes.concrete_sizes()->size() == strides.concrete_sizes()->size());
 auto sprops = strides.concrete_sizes().has_value()
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 ? computeStrideProps(*sizes.concrete_sizes(), *strides.concrete_sizes(), tensor_contiguity)
 : VaryingShape<Stride>();
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 auto symbol_sizes = SymbolicShape(*sizes.concrete_sizes());
 return TensorType::create(
 scalar_type, device, symbol_sizes, sprops, requires_grad, undefined);
 } else {
 // strides are all null, but still have number of strides equal to number of ranks
 TORCH_INTERNAL_ASSERT(sizes.sizes() && sizes.size());
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 auto symbol_sizes = SymbolicShape(*sizes.sizes());
 return TensorType::create(
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 scalar_type, device, symbol_sizes, VaryingShape<Stride>(*sizes.size()), requires_grad, undefined);
 }
 }
@@ -338,6 +344,7 @@ VaryingShape<int64_t> TensorType::sizes() const {
 return VaryingShape<int64_t>();
 }
 return VaryingShape<int64_t>(
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 fmap(*sizes_.sizes(), [](ShapeSymbol ss) {
 // we turn symbolic shapes into unknowns
 return ss.is_static()
@@ -185,6 +185,7 @@ OptionalType::OptionalType(const TypePtr& contained)
 } else {
 std::vector<TypePtr> to_subtract{NoneType::get()};
 auto without_none = subtractTypeSetFrom(to_subtract, types_);
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 contained_ = UnionType::create({*without_none});
 }
 has_free_variables_ = contained_->hasFreeVariables();
@@ -1,6 +1,7 @@
 #pragma once

 #include <c10/util/Optional.h>
+#include <c10/util/string_view.h>
 #include <ATen/Config.h>
 #include <ATen/native/DispatchStub.h>

@@ -1,6 +1,7 @@
 #pragma once

 #include <ATen/native/DispatchStub.h>
+#include <c10/util/ArrayRef.h>
 #include <c10/util/Optional.h>

 namespace c10 {
@@ -609,7 +609,7 @@ at::Tensor _qconv_prepack_onednn(

 auto packed_weight = at::native::new_with_itensor_mkldnn(
 std::move(exp_wgt),
-optTypeMetaToScalarType(weight_copy.options().dtype_opt()),
+c10::optTypeMetaToScalarType(weight_copy.options().dtype_opt()),
 weight_copy.options().device_opt());

 return packed_weight;
@@ -298,7 +298,7 @@ inline at::Tensor pack_weight_to_onednn_tensor(
 expected_weight.feed_from(wei);
 auto packed_weight = at::native::new_with_itensor_mkldnn(
 std::move(expected_weight),
-optTypeMetaToScalarType(weight.options().dtype_opt()),
+c10::optTypeMetaToScalarType(weight.options().dtype_opt()),
 weight.options().device_opt());
 return packed_weight;
 }
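A likely reason the two calls above (and the similar ones later in the diff) gain an explicit c10:: qualifier: optTypeMetaToScalarType lives in namespace c10, and the unqualified calls used to resolve through argument-dependent lookup because the argument type c10::optional<caffe2::TypeMeta> was itself in namespace c10. With c10::optional now an alias for std::optional, the argument's associated namespaces are std and caffe2, so ADL no longer finds the function and the call has to be qualified. A toy illustration with made-up names:

#include <optional>

namespace meta_ns { struct TypeMetaLike {}; }  // hypothetical stand-ins, not PyTorch types

namespace lib {
template <typename T> struct optional_like {};                   // library-owned wrapper
int convert(optional_like<meta_ns::TypeMetaLike>) { return 1; }
int convert(std::optional<meta_ns::TypeMetaLike>) { return 2; }
}  // namespace lib

int main() {
  lib::optional_like<meta_ns::TypeMetaLike> a;
  convert(a);       // OK: the wrapper lives in lib, so ADL searches namespace lib
  std::optional<meta_ns::TypeMetaLike> b;
  // convert(b);    // would not compile: ADL searches std and meta_ns, not lib
  lib::convert(b);  // explicit qualification, mirroring the added c10:: prefix
  return 0;
}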
@@ -20,11 +20,13 @@ bool _ge(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
 }
 TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
 }
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 if (rhs->constant_int() && *rhs->constant_int() <= 2) {
 return true;
 }
 TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
 } else if (rhs->singleton_int()) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 if (lhs->constant_int() && *lhs->constant_int() < 2) {
 return false;
 }
@@ -301,6 +301,7 @@ void TensorImpl::throw_cannot_call_with_symbolic(const char* meth) const {

 void TensorImpl::throw_storage_access_error() const {
 if (extra_meta_ && extra_meta_->custom_storage_error_msg_) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 TORCH_CHECK(false, *extra_meta_->custom_storage_error_msg_);
 }
 TORCH_CHECK_NOT_IMPLEMENTED(
@@ -309,6 +310,7 @@ void TensorImpl::throw_storage_access_error() const {

 void TensorImpl::throw_data_ptr_access_error() const {
 if (extra_meta_ && extra_meta_->custom_data_ptr_error_msg_) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 TORCH_CHECK(false, *extra_meta_->custom_data_ptr_error_msg_);
 }
 TORCH_CHECK(
@@ -32,6 +32,7 @@ std::ostream& operator<<(std::ostream& stream, const TensorOptions& options) {
 // default
 stream << ", memory_format=";
 if (options.has_memory_format()) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 stream << *options.memory_format_opt();
 } else {
 stream << "(nullopt)";
@@ -45,7 +45,8 @@ const std::shared_ptr<SafePyObject> TorchDispatchModeTLS::pop_stack() {
 static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
 i >= 0;
 --i) {
-if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) {
+if (torchDispatchModeState.infra_modes_[i].has_value()) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 out = std::move(torchDispatchModeState.infra_modes_[i].value());
 torchDispatchModeState.infra_modes_[i] = c10::nullopt;
 break;
@@ -65,7 +66,8 @@ TorchDispatchModeTLS::pop_highest_infra_mode() {
 for (int64_t i = static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
 i >= 0;
 --i) {
-if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) {
+if (torchDispatchModeState.infra_modes_[i].has_value()) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 auto out_mode = torchDispatchModeState.infra_modes_[i].value();
 torchDispatchModeState.infra_modes_[i] = c10::nullopt;
 if (!any_modes_set()) {
@@ -94,8 +96,9 @@ const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_stack_at(
 auto curr_idx = idx;
 for (const auto i :
 c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
-if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) {
+if (torchDispatchModeState.infra_modes_[i].has_value()) {
 if (curr_idx == 0) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 return torchDispatchModeState.infra_modes_[i].value();
 }
 curr_idx -= 1;
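The TorchDispatchModeTLS hunks above also swap comparisons against c10::nullopt for has_value(); the two spellings are equivalent for std::optional, the latter just states the intent directly instead of relying on the relational operators. A minimal check (standalone example, not from the diff):

#include <cassert>
#include <optional>

int main() {
  std::optional<int> mode;  // disengaged
  assert((mode != std::nullopt) == mode.has_value());
  mode = 3;                 // engaged
  assert((mode != std::nullopt) == mode.has_value());
  return 0;
}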
@@ -7,6 +7,8 @@
 #include <cstdint>
 #include <string>

+#include <c10/util/ArrayRef.h>
+
 namespace {

 using testing::Eq;
@@ -56,13 +58,6 @@ using OptionalTypes = ::testing::Types<
 // Non-trivial destructor.
 std::string>;

-// This assert is also in Optional.cpp; including here too to make it
-// more likely that we'll remember to port this optimization over when
-// we move to std::optional.
-static_assert(
-sizeof(c10::optional<c10::IntArrayRef>) == sizeof(c10::IntArrayRef),
-"c10::optional<IntArrayRef> should be size-optimized");
-
 TYPED_TEST_SUITE(OptionalTest, OptionalTypes);

 TYPED_TEST(OptionalTest, Empty) {
@@ -71,7 +66,7 @@ TYPED_TEST(OptionalTest, Empty) {
 EXPECT_FALSE((bool)empty);
 EXPECT_FALSE(empty.has_value());

-// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access,hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
 EXPECT_THROW(empty.value(), c10::bad_optional_access);
 }

@@ -94,7 +89,9 @@ TYPED_TEST(OptionalTest, Initialized) {
 EXPECT_TRUE((bool)opt);
 EXPECT_TRUE(opt.has_value());

+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 EXPECT_EQ(opt.value(), val);
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 EXPECT_EQ(*opt, val);
 }
 }
@@ -1,17 +1 @@
-#include <c10/util/ArrayRef.h>
 #include <c10/util/Optional.h>
-
-#include <type_traits>
-
-static_assert(
-C10_IS_TRIVIALLY_COPYABLE(c10::optional<int>),
-"c10::optional<int> should be trivially copyable");
-static_assert(
-C10_IS_TRIVIALLY_COPYABLE(c10::optional<bool>),
-"c10::optional<bool> should be trivially copyable");
-static_assert(
-C10_IS_TRIVIALLY_COPYABLE(c10::optional<c10::IntArrayRef>),
-"c10::optional<IntArrayRef> should be trivially copyable");
-static_assert(
-sizeof(c10::optional<c10::IntArrayRef>) == sizeof(c10::IntArrayRef),
-"c10::optional<IntArrayRef> should be size-optimized");
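The deleted asserts above document what is being given up: the old c10::optional specialized its storage so that optional<IntArrayRef> stayed the same size as IntArrayRef, which std::optional cannot match because it keeps a separate engagement flag alongside the payload. A small probe of that size difference, using std::string_view as a stand-in for a pointer-plus-length view type like ArrayRef (illustrative assumption, not code from the repository):

#include <cstdio>
#include <optional>
#include <string_view>

int main() {
  std::printf("view: %zu bytes, std::optional<view>: %zu bytes\n",
              sizeof(std::string_view), sizeof(std::optional<std::string_view>));
  // Typically 16 vs 24 bytes on a 64-bit platform: the flag plus padding is the
  // cost the commit message calls acceptable now that optional<ArrayRef> is no
  // longer passed through dispatch-critical function arguments.
  return 0;
}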
c10/util/Optional.h: 1237 lines changed (file diff suppressed because it is too large).
@@ -4,20 +4,9 @@

 namespace c10 {

-struct in_place_t {
-explicit in_place_t() = default;
-};
-
-template <std::size_t I>
-struct in_place_index_t {
-explicit in_place_index_t() = default;
-};
-
-template <typename T>
-struct in_place_type_t {
-explicit in_place_type_t() = default;
-};
-
-constexpr in_place_t in_place{};
+using std::in_place;
+using std::in_place_index_t;
+using std::in_place_t;
+using std::in_place_type_t;

 } // namespace c10
@@ -483,7 +483,8 @@ TEST(ShapeAnalysisTest, TestShapeMultipleReturns) {

 auto res =
 calculateSymbolicShapesOnOp(max_dim_op, {ss1, const_int, false_ival});
-c10::SymbolicShape expected_res = c10::SymbolicShape({sym_dim});
+c10::SymbolicShape expected_res =
+c10::SymbolicShape(std::vector<c10::optional<int64_t>>{sym_dim});
 assertShapeEqual(res->at(0), expected_res);
 // res0 and res1 should share the same symbolic symbol
 EXPECT_EQ(res->at(0), res->at(1));
@@ -77,6 +77,7 @@ PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) {
 throw python_error();
 PyTuple_SET_ITEM(ret.get(), i, py_size_tensor);
 } else {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 PyTuple_SET_ITEM(ret.get(), i, THPUtils_packInt64(*m));
 }
 }
@@ -1528,9 +1528,11 @@ std::tuple<Tensor, Tensor, Tensor> sparse_sampled_addmm_backward(
 return std::make_tuple(
 self_requires_grad ? maybe_multiply(grad, beta.conj()) : Tensor{},
 mat1_requires_grad
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 ? maybe_multiply(grad_projected.mm(mat2->mH()), alpha.conj())
 : Tensor{},
 mat2_requires_grad
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 ? maybe_multiply(mat1->mH().mm(grad_projected), alpha.conj())
 : Tensor{});
 }
@@ -2263,9 +2265,12 @@ Tensor binary_cross_entropy_target_backward(
 }

 if (isDefined(weight)) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 if (!isTensorSubclassLike(weight.value())) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 grad_target.mul_(weight.value());
 } else {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 grad_target = grad_target * weight.value();
 }
 }
@@ -2287,7 +2292,11 @@ Tensor binary_cross_entropy_double_backward_target(
 auto res = -grad * grad_output;

 if (isDefined(weight)) {
-res = isTensorSubclassLike(weight.value()) ? res.mul(weight.value())
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
+res = isTensorSubclassLike(weight.value())
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
+? res.mul(weight.value())
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 : res.mul_(weight.value());
 }

@@ -2329,6 +2338,7 @@ Tensor binary_cross_entropy_with_logits_backward(
 Tensor grad_input;
 if (isDefined(pos_weight)) {
 // pos_weight might need to be broadcasted, thus mul(target) is not inplace.
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 auto t = pos_weight->mul(target);
 grad_input = at::areAnyTensorSubclassLike({input, target}) ||
 at::GradMode::is_enabled()
@@ -2348,9 +2358,12 @@ Tensor binary_cross_entropy_with_logits_backward(
 }

 if (isDefined(weight)) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 if (at::isTensorSubclassLike(*weight) || at::GradMode::is_enabled()) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 grad_input = grad_input.mul(*weight);
 } else {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 grad_input.mul_(*weight);
 }
 }
@@ -2375,12 +2388,15 @@ Tensor binary_cross_entropy_with_logits_target_backward(

 Tensor grad_target;
 if (isDefined(pos_weight)) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 if (areAnyTensorSubclassLike({*pos_weight, grad_output})) {
 grad_target = at::log_sigmoid(-self)
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 .sub(at::log_sigmoid(self).mul(*pos_weight))
 .mul(grad_output);
 } else {
 grad_target = at::log_sigmoid(-self)
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 .sub_(at::log_sigmoid(self).mul_(*pos_weight))
 .mul_(grad_output);
 }
@@ -2389,9 +2405,12 @@ Tensor binary_cross_entropy_with_logits_target_backward(
 }

 if (isDefined(weight)) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 if (at::isTensorSubclassLike(*weight)) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 grad_target = grad_target.mul(*weight);
 } else {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 grad_target.mul_(*weight);
 }
 }
@@ -2467,9 +2486,12 @@ Tensor binary_cross_entropy_double_backward(
 }

 if (isDefined(weight)) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 if (!isTensorSubclassLike(*weight)) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 gI *= *weight;
 } else {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 gI = gI.mul(*weight);
 }
 }
@@ -2496,9 +2518,12 @@ Tensor binary_cross_entropy_double_backward_grad_output(
 }

 if (isDefined(weight)) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 if (!isTensorSubclassLike(*weight)) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 ggO *= *weight;
 } else {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 ggO = ggO.mul(*weight);
 }
 }
@@ -3165,6 +3190,7 @@ Tensor as_strided_backward(

 // Step (2): use output geometry to scatter gradients into storage
 if (out_maybe_overlap) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 auto out_indices = flatten_full_indices->as_strided_symint(
 out_sizes_, out_strides_, out_effective_offset);
 storage.index_add_(0, out_indices.reshape(-1), grad.reshape(-1));
@@ -3179,6 +3205,7 @@ Tensor as_strided_backward(
 if (inp_maybe_overlap) {
 auto count = at::zeros_like(storage, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
 auto inp_indices =
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 flatten_full_indices
 ->as_strided_symint(inp_sizes_, inp_strides_, inp_effective_offset)
 .reshape(-1);
@@ -4600,6 +4627,7 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
 Tensor gamma_expanded;
 Tensor ggG_expanded, ggB_expanded;
 if (affine) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 gamma_expanded = expand_as_dim1(*gamma, input);
 if (ggG.defined()) {
 ggG_expanded = expand_as_dim1(ggG, input);
@@ -4749,6 +4777,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
 Tensor gamma_expanded;
 Tensor ggG_expanded, ggB_expanded;
 if (affine) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 gamma_expanded = gamma->reshape({1, N});
 if (ggG.defined()) {
 ggG_expanded = ggG.reshape({1, N});
@@ -4832,6 +4861,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
 gG = first_bwd_fn_grad_input(
 ggI_expanded, at::ones({}, sigma2_eps_neg_1_2.options()));
 gG = (gO * gG).sum(0);
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 gG = gG.reshape_as(*gamma);
 }

@@ -4895,6 +4925,7 @@ infinitely_differentiable_native_group_norm_backward(
 if (grad_input_mask[0]) {
 Tensor gamma_tensor;
 if (isDefined(gamma)) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 gamma_tensor = gamma->reshape_symint({1, G, D, 1});
 }
 const Tensor var =
@@ -4961,12 +4992,15 @@ std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(
 if (grad_out.defined()) {
 if (grad_mask[0])
 grad_i1 =
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 at::_trilinear(grad_out, *i2, *i3, sumdim, expand2, expand3, expand1);
 if (grad_mask[1])
 grad_i2 =
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 at::_trilinear(*i1, grad_out, *i3, expand1, sumdim, expand3, expand2);
 if (grad_mask[2])
 grad_i3 =
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 at::_trilinear(*i1, *i2, grad_out, expand1, expand2, sumdim, expand3);
 }
 return std::tuple<Tensor, Tensor, Tensor>(grad_i1, grad_i2, grad_i3);
@@ -6089,11 +6123,14 @@ static Tensor _affine_jvp(
 TORCH_INTERNAL_ASSERT(input_p.has_value() == weight_p.defined());
 if (weight_p.defined()) {
 if (areAnyTensorSubclassLike(
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 {input_p.value(), input_t, weight_p, weight_t}) ||
 input_t._is_zerotensor() || weight_t._is_zerotensor()) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 input_t = input_t * weight_p + input_p.value() * weight_t;
 } else {
 input_t *= weight_p;
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 auto temp = input_p.value();
 temp *= weight_t;
 input_t += temp;
@@ -380,6 +380,7 @@ static optional_variable_list _process_backward_mode_ad(
 continue;
 }

+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 Variable var = raw_outputs[i].value();

 auto out_tensor_impl = var.unsafeGetTensorImpl();
@@ -708,6 +708,7 @@ void GraphTask::exec_post_processing() {
 // If leaf_stream.device_index() happens to be for a new device,
 // operator* on the c10::nullopt should throw an error.
 const auto caller_current_stream =
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 *caller_current_streams_[leaf_stream.device_index()];

 if (caller_current_stream != leaf_stream) {
@@ -1078,6 +1078,7 @@ static PyObject* get_dispatch_mode(PyObject* _unused, PyObject* arg) {
 if (maybe_mode == c10::nullopt) {
 Py_RETURN_NONE;
 }
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 auto* r = maybe_mode.value()->ptr(getPyInterpreter());
 Py_INCREF(r);
 return r;
@@ -1093,6 +1094,7 @@ static PyObject* unset_dispatch_mode(PyObject* _unused, PyObject* arg) {
 if (maybe_mode == c10::nullopt) {
 Py_RETURN_NONE;
 }
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 auto* r = maybe_mode.value()->ptr(getPyInterpreter());
 Py_INCREF(r);
 return r;
@@ -160,6 +160,7 @@ void InputBuffer::add(

 TORCH_INTERNAL_ASSERT(device_of(var));
 c10::optional<c10::Stream> opt_accumulate_stream = c10::nullopt;
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 if (device_of(var)->is_cuda()) {
 const auto on_producer =
 opt_producer_stream && device_of(var) == opt_producer_stream->device();
@@ -190,6 +191,7 @@ void InputBuffer::add(
 opt_sync_stream = opt_producer_stream;
 } else {
 // (5)
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 opt_accumulate_stream = guard.getDefaultStream(*device_of(var));
 }
 if (opt_sync_stream && (opt_accumulate_stream != opt_sync_stream)) {
@@ -465,6 +465,7 @@ ExtraFields<EventType::PyCall>::args_t ValueCache::load<
 OptimizerInfo info{
 key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
 return {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 /*frame_state_=*/std::get<CallType::PyCall>(state_).at(*cache.location_),
 /*module_info_=*/c10::nullopt,
 /*optimizer_info_=*/std::move(info)};
@@ -543,8 +543,10 @@ static void _wrap_outputs(
 PyTuple_SetItem(outputs, i, obj);
 } else {
 if (is_executable) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 self->output_info.emplace_back(*wrapped_outputs[i]);
 }
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 PyTuple_SetItem(outputs, i, THPVariable_Wrap(*wrapped_outputs[i]));
 }
 }
@@ -751,7 +751,7 @@ void Reducer::all_reduce_local_used_map() {
 // local_used_map_
 auto local_used_map_tmp = at::native::empty_like(
 local_used_map_,
-optTypeMetaToScalarType(local_used_map_.options().dtype_opt()),
+c10::optTypeMetaToScalarType(local_used_map_.options().dtype_opt()),
 local_used_map_.options().layout_opt(),
 local_used_map_.options().device_opt(),
 true /* pinned_memory */);
@@ -770,7 +770,7 @@ void Reducer::all_reduce_local_used_map() {
 // the pin memory step.
 auto local_used_map_tmp = at::native::empty_like(
 local_used_map_,
-optTypeMetaToScalarType(local_used_map_.options().dtype_opt()),
+c10::optTypeMetaToScalarType(local_used_map_.options().dtype_opt()),
 local_used_map_.options().layout_opt(),
 local_used_map_.options().device_opt());
 local_used_map_tmp.copy_(local_used_map_);
@@ -217,6 +217,7 @@ int64_t dlevel(const Tensor& tensor) {
 if (!wrapped->is_alive()) {
 return -1;
 }
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 return wrapped->level().value();
 }

@@ -340,6 +341,7 @@ static int64_t maybe_get_level(const Tensor& tensor) {
 auto* wrapped = maybeGetTensorWrapper(tensor);
 if (wrapped) {
 if (wrapped->level()) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 return *wrapped->level();
 }
 // TODO: this is a weird special case...
@@ -264,7 +264,7 @@ Operation createUnaryOp(
 TORCH_INTERNAL_ASSERT(it_empty.get_desc() == a_it.get_desc());
 out = at::native::new_with_itensor_mkldnn(
 std::move(it_empty),
-optTypeMetaToScalarType(a.options().dtype_opt()),
+c10::optTypeMetaToScalarType(a.options().dtype_opt()),
 a.options().device_opt());

 out_raw_data = at::native::itensor_from_mkldnn(out).get_data_handle();
@@ -379,13 +379,13 @@ Operation BroadOp(const Node* node) {
 auto a_options = exp_a.options();
 auto a_out = at::native::new_with_itensor_mkldnn(
 std::move(a_it),
-optTypeMetaToScalarType(a_options.dtype_opt()),
+c10::optTypeMetaToScalarType(a_options.dtype_opt()),
 a_options.device_opt());
 push(stack, a_out);
 auto b_options = exp_b.options();
 auto b_out = at::native::new_with_itensor_mkldnn(
 std::move(b_it),
-optTypeMetaToScalarType(b_options.dtype_opt()),
+c10::optTypeMetaToScalarType(b_options.dtype_opt()),
 b_options.device_opt());
 push(stack, b_out);
 };
@@ -544,7 +544,7 @@ jit::RegisterOperators reg_fut_ops({
 stack,
 at::native::empty_mkldnn(
 o,
-optTypeMetaToScalarType(input.options().dtype_opt()),
+c10::optTypeMetaToScalarType(input.options().dtype_opt()),
 input.options().layout_opt(),
 input.options().device_opt(),
 input.options().pinned_memory_opt()));
@@ -576,7 +576,7 @@ jit::RegisterOperators reg_fut_ops({
 Tensor self = pop(stack).toTensor();
 auto out = at::native::empty_mkldnn(
 self.sizes(),
-optTypeMetaToScalarType(self.options().dtype_opt()),
+c10::optTypeMetaToScalarType(self.options().dtype_opt()),
 self.options().layout_opt(),
 self.options().device_opt(),
 self.options().pinned_memory_opt());
@@ -280,7 +280,7 @@ class TORCH_API TensorExprKernel {
 c10::optional<bool> pinned_memory;

 UnpackedTensorOptions(const c10::TensorOptions& opts)
-: dtype(optTypeMetaToScalarType(opts.dtype_opt())),
+: dtype(c10::optTypeMetaToScalarType(opts.dtype_opt())),
 layout(opts.layout_opt()),
 device(opts.device_opt()),
 pinned_memory(opts.pinned_memory_opt()) {}
@@ -129,6 +129,7 @@ void calculateUniqueTensorIDs(
 ska::flat_hash_set<AllocationID> tensor_set;
 for (const auto& t : tensors) {
 if (t.impl_ != NoTensorImpl) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 tensor_set.insert(*t.allocation_id_ref_.get());
 }
 }
@@ -156,6 +157,7 @@ void calculateUniqueTensorIDs(
 continue;
 }

+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 const auto allocation_id = *t.allocation_id_ref_.get();
 const auto it = impl_map.insert({t.impl_, allocation_id}).first;

@@ -187,6 +189,7 @@ void calculateUniqueTensorIDs(
 // Write back to Tensor IDs.
 // --------------------------------------------------------------------------
 for (const auto& t : tensors) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 const auto id = id_map.at(*t.allocation_id_ref_.get());
 t.id_ref_.get().emplace(TensorID(id));
 }
@@ -125,6 +125,7 @@ std::vector<FileLineFunc> prepareCallstack(
 auto line =
 src->starting_line_no() + src->lineno_for_offset(range.start());
 entries.emplace_back(
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 FileLineFunc{*(src->filename()), line, entry.filename});
 }
 }
@@ -15,9 +15,11 @@ void check_out_type_matches(
 if (scalarType_is_none && !layout && device_is_none) { // common case
 return;
 }
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 if (!scalarType_is_none && result.scalar_type() != scalarType.value()) {
 AT_ERROR(
 "dtype ",
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 *scalarType,
 " does not match dtype of out parameter (",
 result.scalar_type(),
@@ -31,9 +33,11 @@ void check_out_type_matches(
 result.layout(),
 ")");
 }
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 if (!device_is_none && result.device().type() != device.value().type()) {
 AT_ERROR(
 "device type ",
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 device->type(),
 " does not match device type of out parameter (",
 result.device().type(),
@@ -46,6 +46,7 @@ py::handle type_caster<c10::SymInt>::cast(
 }
 } else {
 auto m = si.maybe_as_int();
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 return py::cast(*m).release();
 }
 }
@@ -324,11 +324,6 @@ struct type_caster<c10::complex<T>> {
 }
 };

-// Pybind11 bindings for our optional.
-// http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
-template <typename T>
-struct type_caster<c10::optional<T>> : optional_caster<c10::optional<T>> {};
-
 } // namespace detail
 } // namespace pybind11

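The removed specialization above falls out of the same redirect: once c10::optional is std::optional, the stock caster that ships in <pybind11/stl.h> already covers it, so a separate type_caster<c10::optional<T>> would be redundant (or an outright redefinition). A minimal sketch of the stock behaviour (assumes pybind11; the module and function names are made up):

#include <optional>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>  // provides type_caster<std::optional<T>>

namespace py = pybind11;

std::optional<int> half_if_even(int x) {
  if (x % 2 == 0) return x / 2;
  return std::nullopt;  // surfaces as None on the Python side
}

PYBIND11_MODULE(example_module, m) {
  m.def("half_if_even", &half_if_even);
}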
@@ -12,7 +12,7 @@ struct RAIIContextManager {

 void enter() {
 auto emplace = [&](Args... args) {
-return guard_.emplace(std::forward<Args>(args)...);
+guard_.emplace(std::forward<Args>(args)...);
 };
 std::apply(std::move(emplace), args_);
 }
@@ -20,8 +20,9 @@ void SchemaInfo::addArgumentValues(
 "Schema does not have enough arguments for value list");

 for (size_t i = 0; i < value_list.size(); i++) {
-if (value_list[i] != c10::nullopt) {
-value_map_[schema_.arguments()[i].name()] = *(value_list[i]);
+if (value_list[i].has_value()) {
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
+value_map_[schema_.arguments()[i].name()] = *value_list[i];
 alias_maps_current_ = false;
 }
 }
@@ -107,7 +108,7 @@ bool SchemaInfo::has_argument(c10::string_view name) {
 bool SchemaInfo::is_mutable(c10::string_view name) {
 c10::optional<int> index = schema_.argumentIndexWithName(name);
 TORCH_INTERNAL_ASSERT(
-index != c10::nullopt, "Schema has no argument named ", name);
+index.has_value(), "Schema has no argument named ", name);

 return is_mutable({c10::SchemaArgType::input, static_cast<size_t>(*index)});
 }
@@ -194,6 +194,7 @@ ScalarType infer_scalar_type(PyObject* obj) {
 return *scalarType;
 }
 }
+// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 return *scalarType;
 }
 AT_ERROR("Could not infer dtype of ", Py_TYPE(obj)->tp_name);