[12/N] Apply clang-tidy and fix warnings in headers of torch/csrc (#116486)

This PR follows #116751.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116486
Approved by: https://github.com/albanD
This commit is contained in:
cyy
2024-01-10 08:48:14 +00:00
committed by PyTorch MergeBot
parent 90df7c008a
commit 20f769544c
50 changed files with 120 additions and 244 deletions

View File

@ -145,7 +145,7 @@ extern PyObject *THPException_FatalError, *THPException_LinAlgError,
// Throwing this exception means that the python error flags have been already
// set and control should be immediately returned to the interpreter.
struct python_error : public std::exception {
python_error() : type(nullptr), value(nullptr), traceback(nullptr) {}
python_error() {}
python_error(const python_error& other)
: type(other.type),
@ -244,9 +244,9 @@ struct python_error : public std::exception {
PyErr_Restore(type, value, traceback);
}
PyObject* type;
PyObject* value;
PyObject* traceback;
PyObject* type{nullptr};
PyObject* value{nullptr};
PyObject* traceback{nullptr};
// Message to return to the user when 'what()' is invoked.
std::string message;

View File

@ -12,10 +12,7 @@
#include <ATen/ATen.h>
#include <torch/csrc/autograd/generated/Functions.h>
namespace torch {
namespace autograd {
namespace generated {
namespace details {
namespace torch::autograd::generated::details {
extern const char* kCudnnDoubleBackwardMsg;
@ -1101,7 +1098,4 @@ mkldnn_rnn_layer_differentiable_backward(
Tensor values_backward(const Tensor& grad, const Tensor& self);
} // namespace details
} // namespace generated
} // namespace autograd
} // namespace torch
} // namespace torch::autograd::generated::details

View File

@ -3,10 +3,8 @@
#include <c10/core/InferenceMode.h>
#include <torch/csrc/Export.h>
namespace torch {
namespace autograd {
namespace torch::autograd {
using InferenceMode = c10::InferenceMode;
}
} // namespace torch

View File

@ -17,14 +17,9 @@
#include <torch/csrc/autograd/jit_decomp_interface.h>
#include <torch/csrc/utils/variadic.h>
#include <array>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
@ -117,8 +112,8 @@ inline void rebase_history(Variable& var, std::shared_ptr<Node> grad_fn) {
}
inline void rebase_history(
std::vector<Variable>&& vars,
std::shared_ptr<Node> grad_fn) {
const std::vector<Variable>& vars,
const std::shared_ptr<Node>& grad_fn) {
if (grad_fn) {
for (auto& var : vars) {
if (var.defined()) {
@ -137,6 +132,7 @@ inline void increment_version(const at::Tensor& t) {
struct Flatten : IterArgs<Flatten> {
Flatten(variable_list& out) : out(out) {}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
variable_list& out;
void operator()(const at::Tensor& x) {
out.emplace_back(x);

View File

@ -4,8 +4,7 @@
#include <memory>
#include <string>
namespace torch {
namespace autograd {
namespace torch::autograd {
// forward declaration of Node from function.h
struct Node;
@ -69,5 +68,4 @@ struct TORCH_API AnomalyMetadata {
std::shared_ptr<Node> parent_;
};
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -2,8 +2,7 @@
#include <torch/csrc/autograd/variable.h>
namespace torch {
namespace autograd {
namespace torch::autograd {
/// Computes the sum of gradients of given tensors with respect to graph leaves.
///
@ -102,5 +101,4 @@ TORCH_API uint64_t enter_dual_level();
TORCH_API void exit_dual_level(uint64_t level);
} // namespace forward_ad
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -2,8 +2,7 @@
#include <torch/library.h>
namespace torch {
namespace autograd {
namespace torch::autograd {
// Default DispatchKey::Autograd fallback for built-in operators.
// Can be registered for custom operators.
@ -30,5 +29,4 @@ enum class AutogradFallbackMode {
TORCH_API void setAutogradFallbackMode(AutogradFallbackMode mode);
TORCH_API AutogradFallbackMode getAutogradFallbackMode();
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -3,8 +3,7 @@
#include <functional>
#include <memory>
namespace torch {
namespace autograd {
namespace torch::autograd {
using hooks_list =
std::vector<std::function<at::TensorBase(const at::TensorBase&)>>;
@ -27,5 +26,4 @@ struct CppFunctionSingleTensorPreHook : public FunctionPreHook {
size_t value_idx_;
};
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -8,8 +8,7 @@
#include <torch/csrc/autograd/variable.h>
#include <vector>
namespace torch {
namespace autograd {
namespace torch::autograd {
using optional_variable_list = std::vector<c10::optional<Variable>>;
using _jvp_fn_t = std::function<variable_list(variable_list, variable_list)>;
@ -97,7 +96,7 @@ struct TORCH_API Function {
// the parameter X.
template <typename X = T, typename... Args>
static auto apply(Args&&... args)
-> std::enable_if_t<std::is_same<X, T>::value, forward_t<X, Args...>>;
-> std::enable_if_t<std::is_same_v<X, T>, forward_t<X, Args...>>;
};
/// Context to save information during `forward` that can be accessed in
@ -228,8 +227,8 @@ inline void extract_vars(
}
template <typename T>
typename std::enable_if<std::is_same<T, variable_list>::value, T>::type
to_output_type(std::vector<c10::optional<Variable>>& output_list) {
std::enable_if_t<std::is_same_v<T, variable_list>, T> to_output_type(
std::vector<c10::optional<Variable>>& output_list) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
variable_list result;
std::transform(
@ -241,8 +240,8 @@ to_output_type(std::vector<c10::optional<Variable>>& output_list) {
}
template <typename T>
typename std::enable_if<std::is_same<T, Variable>::value, T>::type
to_output_type(std::vector<c10::optional<Variable>>& output_list) {
std::enable_if_t<std::is_same_v<T, Variable>, T> to_output_type(
std::vector<c10::optional<Variable>>& output_list) {
return *output_list[0];
}
@ -264,7 +263,7 @@ inline std::vector<c10::optional<Variable>> to_optional(variable_list& output) {
template <class T>
template <typename X, typename... Args>
auto Function<T>::apply(Args&&... args)
-> std::enable_if_t<std::is_same<X, T>::value, forward_t<X, Args...>> {
-> std::enable_if_t<std::is_same_v<X, T>, forward_t<X, Args...>> {
const auto& functorch_tls = at::functorch::functorchTLSAccessor();
if (functorch_tls) {
// Function support for functorch is handled in Python.
@ -434,5 +433,4 @@ void CppNode<T>::set_ctx_grad_fn(const std::shared_ptr<Node>& node) {
ctx_.grad_fn_ = node;
}
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -6,8 +6,7 @@
#include <c10/util/hash.h>
namespace torch {
namespace autograd {
namespace torch::autograd {
struct Node;
@ -38,8 +37,7 @@ struct Edge {
/// The identifier of a particular input to the function.
uint32_t input_nr;
};
} // namespace autograd
} // namespace torch
} // namespace torch::autograd
// The idiomatic way of enabling use of a custom type as the key of hash
// containers in C++11. This method removes the requirement of having to pass

View File

@ -27,14 +27,11 @@
#include <utility>
#include <vector>
namespace torch {
namespace autograd {
namespace torch::autograd {
struct ReadyQueue;
}
} // namespace torch
namespace torch {
namespace autograd {
namespace torch::autograd {
// Maximum reentrant backward depth before switching to a new thread
// This limit is based on the TSAN's deadlock detector, where it will
@ -291,5 +288,4 @@ struct TORCH_API Engine {
using EngineStub = Engine& (*)();
TORCH_API void set_default_engine_stub(EngineStub stub);
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -3,8 +3,7 @@
#include <ATen/core/Tensor.h>
#include <unordered_set>
namespace torch {
namespace autograd {
namespace torch::autograd {
// [ Using ForwardGrad ]
// ForwardGrad needs to be a shared_ptr to satisfy constraints of its inner
@ -208,5 +207,4 @@ struct TORCH_API ForwardGrad : std::enable_shared_from_this<ForwardGrad> {
mutable std::mutex mutex_;
};
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -24,8 +24,7 @@
#include <utility>
#include <vector>
namespace torch {
namespace autograd {
namespace torch::autograd {
struct Edge;
struct FunctionPostHook;
@ -757,5 +756,4 @@ struct TypeAndSize {
at::TensorOptions options;
};
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -12,8 +12,7 @@ class SwapSavedVariables;
// A hook that's called on gradients
namespace torch {
namespace autograd {
namespace torch::autograd {
using Variable = at::Tensor;
using variable_list = std::vector<Variable>;
@ -62,5 +61,4 @@ struct TORCH_API PostAccumulateGradHook {
}
};
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -3,11 +3,9 @@
#include <ATen/core/grad_mode.h>
#include <torch/csrc/Export.h>
namespace torch {
namespace autograd {
namespace torch::autograd {
using GradMode = at::GradMode;
using AutoGradMode = at::AutoGradMode;
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -6,8 +6,7 @@
#include <torch/csrc/autograd/utils/warnings.h>
#include <vector>
namespace torch {
namespace autograd {
namespace torch::autograd {
using edge_list = std::vector<Edge>;
struct ReadyQueue;
@ -239,5 +238,4 @@ TORCH_API std::vector<Node*> get_current_graph_task_execution_order();
TORCH_API int get_current_graph_task_id();
void add_node_to_current_graph_task_exec_info(Node* fn);
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -13,8 +13,7 @@
#include <c10/util/Optional.h>
#include <torch/csrc/autograd/variable.h>
namespace torch {
namespace autograd {
namespace torch::autograd {
struct InputBuffer {
explicit InputBuffer(size_t size) : buffer(size) {}
@ -44,5 +43,4 @@ struct InputBuffer {
std::vector<Variable> buffer;
};
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -22,8 +22,7 @@
#include <cstdint>
#include <utility>
namespace torch {
namespace autograd {
namespace torch::autograd {
using SymIntSmallVec = c10::SmallVector<c10::SymInt, c10::kDimVectorStaticSize>;
using MetadataShape = std::variant<SymIntSmallVec, at::Tensor>;
@ -109,5 +108,4 @@ struct TORCH_API InputMetadata {
bool is_nested_ = false;
bool was_default_constructed_ = true;
};
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -27,9 +27,7 @@
// For extra context, see VariableHooksInterface.h, where a similar technique
// is used
namespace torch {
namespace autograd {
namespace impl {
namespace torch::autograd::impl {
struct TORCH_API JitDecompInterface {
virtual ~JitDecompInterface() = default;
@ -49,6 +47,4 @@ struct TORCH_API JitDecompRegisterer {
}
};
} // namespace impl
} // namespace autograd
} // namespace torch
} // namespace torch::autograd::impl

View File

@ -9,16 +9,15 @@
#include <torch/csrc/profiler/util.h>
namespace torch {
namespace profiler {
namespace impl {
namespace profiler::impl {
struct Result;
namespace kineto {
struct ActivityTraceWrapper;
} // namespace kineto
} // namespace impl
} // namespace profiler
namespace autograd {
namespace profiler {
} // namespace profiler::impl
namespace autograd::profiler {
using experimental_event_t = std::shared_ptr<torch::profiler::impl::Result>;
using extra_meta_t = std::unordered_map<std::string, std::string>;
@ -177,16 +176,13 @@ TORCH_API void prepareProfiler(
const torch::profiler::impl::ProfilerConfig& config,
const std::set<torch::profiler::impl::ActivityType>& activities);
} // namespace profiler
} // namespace autograd
} // namespace autograd::profiler
namespace profiler {
namespace impl {
namespace profiler::impl {
// Experimental.
TORCH_API void _reportVulkanEventToProfiler(vulkan_id_t id);
} // namespace impl
} // namespace profiler
} // namespace profiler::impl
} // namespace torch

View File

@ -15,8 +15,7 @@
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
namespace torch {
namespace autograd {
namespace torch::autograd {
struct Node;
@ -413,5 +412,4 @@ struct TORCH_API TLSLegacyProfilerGuard {
};
} // namespace profiler
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -1,13 +1,7 @@
#pragma once
namespace torch {
namespace autograd {
namespace profiler {
namespace python_tracer {
namespace torch::autograd::profiler::python_tracer {
void init();
}
} // namespace profiler
} // namespace autograd
} // namespace torch

View File

@ -4,13 +4,11 @@
PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused);
void THPAutograd_initFunctions();
namespace torch {
namespace autograd {
namespace torch::autograd {
PyMethodDef* python_functions();
}
} // namespace torch
#include <torch/csrc/autograd/python_engine.h>
#include <torch/csrc/autograd/python_function.h>

View File

@ -8,8 +8,7 @@
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/utils/object_ptr.h>
namespace torch {
namespace autograd {
namespace torch::autograd {
struct THPCppFunction {
PyObject_HEAD std::shared_ptr<Node> cdata;
@ -103,5 +102,4 @@ PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata);
bool THPCppFunction_Check(PyObject* obj);
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -7,9 +7,7 @@
bool THPEngine_initModule(PyObject* module);
namespace torch {
namespace autograd {
namespace python {
namespace torch::autograd::python {
struct PythonEngine : public Engine {
static Engine& get_python_engine();
@ -43,6 +41,4 @@ struct PythonEngine : public Engine {
PythonEngine();
};
} // namespace python
} // namespace autograd
} // namespace torch
} // namespace torch::autograd::python

View File

@ -2,8 +2,6 @@
#include <torch/csrc/python_headers.h>
namespace torch {
namespace autograd {
namespace torch::autograd {
void initEnumTag(PyObject* module);
}
} // namespace torch

View File

@ -1,9 +1,7 @@
#pragma once
namespace torch {
namespace autograd {
namespace torch::autograd {
void initFFTFunctions(PyObject* module);
}
} // namespace torch

View File

@ -16,13 +16,11 @@
#include <utility>
#include <vector>
namespace torch {
namespace jit {
namespace torch::jit {
struct Graph;
}
} // namespace torch
namespace torch {
namespace autograd {
namespace torch::autograd {
// A Function which is implemented by a Python object (i.e., a THPFunction).
// Calls to 'apply' are forwarded to the Python method implementation.
@ -71,8 +69,7 @@ inline bool ensure_tuple(THPObjectPtr& obj) {
return true;
}
} // namespace autograd
} // namespace torch
} // namespace torch::autograd
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct THPFunction {

View File

@ -8,8 +8,7 @@ namespace torch::dynamo::autograd {
class SwapSavedVariables;
} // namespace torch::dynamo::autograd
namespace torch {
namespace autograd {
namespace torch::autograd {
struct PyFunctionTensorPreHook : public FunctionPreHook {
PyFunctionTensorPreHook(PyObject* dict, size_t value_idx);
@ -53,5 +52,4 @@ struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook {
PyObject* dict;
};
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -5,10 +5,8 @@
#include <torch/csrc/python_headers.h>
namespace torch {
namespace autograd {
namespace torch::autograd {
void init_legacy_variable(PyObject* module);
}
} // namespace torch

View File

@ -1,9 +1,7 @@
#pragma once
namespace torch {
namespace autograd {
namespace torch::autograd {
void initLinalgFunctions(PyObject* module);
}
} // namespace torch

View File

@ -1,11 +1,9 @@
#pragma once
namespace torch {
namespace autograd {
namespace torch::autograd {
PyMethodDef* get_nested_functions_manual();
void initNestedFunctions(PyObject* module);
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -1,9 +1,7 @@
#pragma once
namespace torch {
namespace autograd {
namespace torch::autograd {
void initNNFunctions(PyObject* module);
}
} // namespace torch

View File

@ -10,8 +10,7 @@
namespace py = pybind11;
namespace torch {
namespace autograd {
namespace torch::autograd {
struct PySavedVariableHooks : public SavedVariableHooks {
PySavedVariableHooks(py::function& pack_hook, py::function& unpack_hook);
@ -31,5 +30,4 @@ struct PyDefaultSavedVariableHooks {
static std::unique_ptr<SavedVariableHooks> get_hooks();
};
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -1,9 +1,7 @@
#pragma once
namespace torch {
namespace autograd {
namespace torch::autograd {
void initSparseFunctions(PyObject* module);
}
} // namespace torch

View File

@ -1,9 +1,7 @@
#pragma once
namespace torch {
namespace autograd {
namespace torch::autograd {
void initSpecialFunctions(PyObject* module);
}
} // namespace torch

View File

@ -2,8 +2,7 @@
#include <vector>
namespace torch {
namespace autograd {
namespace torch::autograd {
extern PyObject* THPVariableFunctionsModule;
@ -25,5 +24,4 @@ inline PyObject* TypeError_to_NotImplemented_(
void initTorchFunctions();
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -6,8 +6,7 @@
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_symnode.h>
namespace torch {
namespace autograd {
namespace torch::autograd {
struct UnpackedSlice {
c10::SymInt start;
@ -100,5 +99,4 @@ Variable valueToTensor(
PyObject* value,
const at::Device& device);
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -3,9 +3,7 @@
#include <c10/util/Optional.h>
#include <torch/custom_class.h>
namespace torch {
namespace autograd {
namespace profiler {
namespace torch::autograd::profiler {
struct PythonRecordFunction : public torch::CustomClassHolder {
at::RecordFunction record;
@ -26,6 +24,4 @@ TORCH_API c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_new
const c10::intrusive_ptr<PythonRecordFunction>& record,
const c10::intrusive_ptr<c10::ivalue::Future>& fut);
} // namespace profiler
} // namespace autograd
} // namespace torch
} // namespace torch::autograd::profiler

View File

@ -9,8 +9,7 @@
#include <cstdint>
#include <memory>
namespace torch {
namespace autograd {
namespace torch::autograd {
using Variable = at::Tensor;
struct Node;
@ -119,5 +118,4 @@ class TORCH_API SavedVariable {
std::unique_ptr<SavedVariableHooks>&& hooks,
const Variable& data);
};
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -2,8 +2,7 @@
#include <ATen/core/Tensor.h>
namespace torch {
namespace autograd {
namespace torch::autograd {
struct TORCH_API SavedVariableHooks {
virtual void call_pack_hook(const at::Tensor& tensor) = 0;
@ -11,5 +10,4 @@ struct TORCH_API SavedVariableHooks {
virtual ~SavedVariableHooks() = default;
};
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -4,8 +4,7 @@
#include <torch/csrc/onnx/onnx.h>
#include <vector>
namespace torch {
namespace autograd {
namespace torch::autograd {
struct SymbolicContext {
jit::Block* block;
@ -15,5 +14,4 @@ struct symbolic_unconvertible : public std::runtime_error {
using std::runtime_error::runtime_error;
};
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -21,8 +21,7 @@
#include <utility>
#include <vector>
namespace torch {
namespace autograd {
namespace torch::autograd {
/// `Variable` is exactly the same as `Tensor` (i.e. we have `using Variable =
/// at::Tensor`). This means you can perform all the usual mathematical and
@ -33,8 +32,7 @@ namespace autograd {
/// is to eliminate the `Variable` class in the near future.
using Variable = at::Tensor;
} // namespace autograd
} // namespace torch
} // namespace torch::autograd
// The following are all internal APIs and should not be shown in libtorch docs.
// Therefore, we wrap the following code with `#ifndef DOXYGEN_SHOULD_SKIP_THIS
@ -42,8 +40,7 @@ using Variable = at::Tensor;
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace torch {
namespace autograd {
namespace torch::autograd {
/// Check if this type is supported by the autograd engine.
/// If you change this, update the doc at the top of the
@ -861,7 +858,6 @@ namespace utils {
TORCH_API bool has_same_meta(const Variable& base, const Variable& other);
} // namespace utils
} // namespace autograd
} // namespace torch
} // namespace torch::autograd
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

View File

@ -1,12 +1,7 @@
#include <pybind11/pybind11.h>
#include <torch/csrc/utils/pybind.h>
#include <Python.h>
namespace torch {
namespace functorch {
namespace impl {
namespace torch::functorch::impl {
void initFuncTorchBindings(PyObject* module);
}
} // namespace functorch
} // namespace torch

View File

@ -11,8 +11,7 @@
#include <sstream>
#include <unordered_map>
namespace torch {
namespace jit {
namespace torch::jit {
class SourceRangeUnpickler;
struct SourceRange;
@ -444,8 +443,7 @@ using SourceRangeRecords = std::vector<TaggedRange>;
using SourceRangeTagMap =
std::unordered_map<SourceRange, int64_t, SourceRangeHasher>;
} // namespace jit
} // namespace torch
} // namespace torch::jit
namespace std {
template <>

View File

@ -17,8 +17,7 @@
#include <unordered_map>
#include <vector>
namespace torch {
namespace jit {
namespace torch::jit {
struct Node;
struct Value;
struct Graph;
@ -382,13 +381,13 @@ TORCH_API void ensureUniqueIfOutOfPlaced(
template <
typename T,
typename = torch::enable_if_t<(
!std::is_convertible<torch::decay_t<T>, at::TensorList>::value &&
!std::is_convertible<torch::decay_t<T>, c10::List<at::Tensor>>::value &&
!std::is_convertible<torch::decay_t<T>, at::Tensor>::value &&
!std::is_convertible<
torch::decay_t<T>,
c10::intrusive_ptr<c10::ivalue::Object>>::value)>>
typename = torch::enable_if_t<
(!std::is_convertible_v<torch::decay_t<T>, at::TensorList> &&
!std::is_convertible_v<torch::decay_t<T>, c10::List<at::Tensor>> &&
!std::is_convertible_v<torch::decay_t<T>, at::Tensor> &&
!std::is_convertible_v<
torch::decay_t<T>,
c10::intrusive_ptr<c10::ivalue::Object>>)>>
void addOutput(Node* node, T&&) {
AT_ERROR(
"Found an unsupported argument type ",
@ -410,5 +409,4 @@ TORCH_API autograd::Variable getSizeOf(
TORCH_API autograd::Variable getNumelOf(const autograd::Variable& var);
} // namespace tracer
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -28,8 +28,7 @@ PYBIND11_DECLARE_HOLDER_TYPE(T, c10::intrusive_ptr<T>, true);
PYBIND11_DECLARE_HOLDER_TYPE(T, c10::SingletonOrSharedTypePtr<T>);
PYBIND11_DECLARE_HOLDER_TYPE(T, c10::SingletonTypePtr<T>, true);
namespace pybind11 {
namespace detail {
namespace pybind11::detail {
// torch.Tensor <-> at::Tensor conversions (without unwrapping)
template <>
@ -324,11 +323,9 @@ struct type_caster<c10::complex<T>> {
}
};
} // namespace detail
} // namespace pybind11
} // namespace pybind11::detail
namespace torch {
namespace impl {
namespace torch::impl {
// Use this function if you have a C++ object that is used from both C++
// and Python contexts, and you need its GIL to be released when you
@ -384,5 +381,4 @@ inline void destroy_without_gil(T* ptr) {
}
}
} // namespace impl
} // namespace torch
} // namespace torch::impl

View File

@ -4,11 +4,9 @@
#include <torch/csrc/Export.h>
#include <torch/csrc/utils/python_stub.h>
namespace torch {
namespace utils {
namespace torch::utils {
void initializeMemoryFormats();
TORCH_PYTHON_API PyObject* getTHPMemoryFormat(c10::MemoryFormat);
} // namespace utils
} // namespace torch
} // namespace torch::utils

View File

@ -3,8 +3,7 @@
#include <ATen/core/Tensor.h>
#include <torch/csrc/python_headers.h>
namespace torch {
namespace utils {
namespace torch::utils {
PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force = false);
at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable = true);
@ -23,5 +22,4 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj);
void validate_numpy_for_dlpack_deleter_bug();
bool is_numpy_dlpack_deleter_bugged();
} // namespace utils
} // namespace torch
} // namespace torch::utils

View File

@ -73,13 +73,13 @@ struct MakeIndices<0, Is...> {
//===----------------------------------------------------------------------===//
template <bool value, typename T = void>
using enable_if_t = typename std::enable_if<value, T>::type;
using enable_if_t = std::enable_if_t<value, T>;
template <bool value, typename T = void>
using disable_if_t = enable_if_t<!value, T>;
template <typename T>
using decay_t = typename std::decay<T>::type;
using decay_t = std::decay_t<T>;
namespace detail {
template <bool...>
@ -112,7 +112,7 @@ using enable_if_all_of_t = enable_if_t<all_of<values...>::value>;
template <typename T, typename... Ts>
using disable_if_contains_t =
enable_if_all_of_t<(!std::is_same<T, decay_t<Ts>>::value)...>;
enable_if_all_of_t<(!std::is_same_v<T, decay_t<Ts>>)...>;
template <typename Function, typename... Ts>
void apply(Function function, Ts&&... ts) {