mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] [3/N] Use nested namespaces (#110314)
Mostly in torch/csrc/jit/runtime and in `ATen/cuda/` Pull Request resolved: https://github.com/pytorch/pytorch/pull/110314 Approved by: https://github.com/seemethere
This commit is contained in:
committed by
PyTorch MergeBot
parent
8745d2d4f2
commit
ad8aef0f98
@ -3,8 +3,7 @@
|
|||||||
|
|
||||||
#include <c10/util/flat_hash_map.h>
|
#include <c10/util/flat_hash_map.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::caching {
|
||||||
namespace caching {
|
|
||||||
|
|
||||||
|
|
||||||
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
|
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
|
||||||
@ -45,5 +44,4 @@ size_t adjusted_use_count(const at::Tensor& t) {
|
|||||||
return t.use_count() - (is_cached_tensor(t) ? 1 : 0);
|
return t.use_count() - (is_cached_tensor(t) ? 1 : 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
} // namespace at::caching
|
||||||
}
|
|
||||||
|
@ -2,8 +2,7 @@
|
|||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::caching {
|
||||||
namespace caching {
|
|
||||||
|
|
||||||
// Some systems (just cudagraphs currently) will persist a static tensor output
|
// Some systems (just cudagraphs currently) will persist a static tensor output
|
||||||
// whose TensorImpl does not change across iterations. For these tensors caching
|
// whose TensorImpl does not change across iterations. For these tensors caching
|
||||||
@ -22,5 +21,4 @@ TORCH_API void set_cached_tensors_enabled(bool enable);
|
|||||||
// count of tensors with hooks.
|
// count of tensors with hooks.
|
||||||
TORCH_API size_t adjusted_use_count(const at::Tensor& t);
|
TORCH_API size_t adjusted_use_count(const at::Tensor& t);
|
||||||
|
|
||||||
} // namespace caching
|
} // namespace at::caching
|
||||||
} // namespace at
|
|
||||||
|
@ -6,8 +6,7 @@
|
|||||||
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
namespace at {
|
namespace at::detail {
|
||||||
namespace detail {
|
|
||||||
namespace {
|
namespace {
|
||||||
c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
|
c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
|
||||||
if (pin_memory) {
|
if (pin_memory) {
|
||||||
@ -441,4 +440,4 @@ TensorBase empty_strided_symint_meta(
|
|||||||
options.pinned_memory_opt());
|
options.pinned_memory_opt());
|
||||||
}
|
}
|
||||||
|
|
||||||
}} // namespace at::detail
|
} // namespace at::detail
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
#include <ATen/core/TensorBase.h>
|
#include <ATen/core/TensorBase.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::detail {
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
inline void check_size_nonnegative(ArrayRef<int64_t> size) {
|
inline void check_size_nonnegative(ArrayRef<int64_t> size) {
|
||||||
for (const auto& x : size) {
|
for (const auto& x : size) {
|
||||||
@ -158,5 +157,4 @@ TORCH_API TensorBase empty_strided_symint_meta(
|
|||||||
SymIntArrayRef stride,
|
SymIntArrayRef stride,
|
||||||
const TensorOptions& options);
|
const TensorOptions& options);
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace at::detail
|
||||||
} // namespace at
|
|
||||||
|
@ -9,7 +9,7 @@ namespace internal {
|
|||||||
TensorBase expand_slow_path(const TensorBase &self, IntArrayRef size) {
|
TensorBase expand_slow_path(const TensorBase &self, IntArrayRef size) {
|
||||||
return OptionalTensorRef(self)->expand(size);
|
return OptionalTensorRef(self)->expand(size);
|
||||||
}
|
}
|
||||||
}
|
} // namespace internal
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// NOTE: are_expandable did a similar check, please keep them sync if change is needed
|
// NOTE: are_expandable did a similar check, please keep them sync if change is needed
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <ATen/FuncTorchTLS.h>
|
#include <ATen/FuncTorchTLS.h>
|
||||||
|
|
||||||
namespace at { namespace functorch {
|
namespace at::functorch {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -28,4 +28,4 @@ std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}}
|
} // namespace at::functorch
|
||||||
|
@ -3,8 +3,7 @@
|
|||||||
#include <c10/macros/Macros.h>
|
#include <c10/macros/Macros.h>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
namespace at {
|
namespace at::functorch {
|
||||||
namespace functorch {
|
|
||||||
|
|
||||||
// NOTE [functorch TLS in pytorch/pytorch]
|
// NOTE [functorch TLS in pytorch/pytorch]
|
||||||
//
|
//
|
||||||
@ -44,5 +43,4 @@ TORCH_API void setFuncTorchTLS(
|
|||||||
// get a mutable reference to the functorch tls
|
// get a mutable reference to the functorch tls
|
||||||
TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor();
|
TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor();
|
||||||
|
|
||||||
} // namespace functorch
|
} // namespace at::functorch
|
||||||
} // namespace at
|
|
||||||
|
@ -6,8 +6,7 @@
|
|||||||
#include <ATen/WrapDimUtilsMulti.h>
|
#include <ATen/WrapDimUtilsMulti.h>
|
||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
namespace at {
|
namespace at::functionalization {
|
||||||
namespace functionalization {
|
|
||||||
|
|
||||||
// This logic is similar to autograd code for view backwards calls.
|
// This logic is similar to autograd code for view backwards calls.
|
||||||
// We can't easily share it though, because (eventually) these functions
|
// We can't easily share it though, because (eventually) these functions
|
||||||
@ -348,5 +347,4 @@ Tensor FunctionalInverses::alias_copy_inverse(const Tensor& base, const Tensor&
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // functionalization
|
} // namespace at::functionalization
|
||||||
} // at
|
|
||||||
|
@ -6,8 +6,7 @@
|
|||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace at {
|
namespace at::functionalization {
|
||||||
namespace functionalization {
|
|
||||||
|
|
||||||
ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
|
ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
|
||||||
if (out_idx == this->out_index) return *this;
|
if (out_idx == this->out_index) return *this;
|
||||||
@ -122,5 +121,4 @@ bool FunctionalStorageImpl::apply_updates() {
|
|||||||
return any_updates;
|
return any_updates;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace functionalization
|
} // namespace at::functionalization
|
||||||
} // namespace at
|
|
||||||
|
@ -2,8 +2,7 @@
|
|||||||
|
|
||||||
#include <ATen/Tensor.h>
|
#include <ATen/Tensor.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::functionalization {
|
||||||
namespace functionalization {
|
|
||||||
|
|
||||||
// See Note [Functionalization Pass In Core]
|
// See Note [Functionalization Pass In Core]
|
||||||
|
|
||||||
@ -117,5 +116,4 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
|
|||||||
bool frozen_ = false;
|
bool frozen_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace functionalization
|
} // namespace at::functionalization
|
||||||
} // namespace at
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
#include <ATen/LegacyVmapMode.h>
|
#include <ATen/LegacyVmapMode.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::impl {
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
thread_local int64_t VmapMode_current_vmap_level = 0;
|
thread_local int64_t VmapMode_current_vmap_level = 0;
|
||||||
|
|
||||||
@ -24,5 +23,4 @@ int64_t VmapMode::decrement_nesting() {
|
|||||||
}
|
}
|
||||||
return VmapMode_current_vmap_level;
|
return VmapMode_current_vmap_level;
|
||||||
}
|
}
|
||||||
} // namespace impl
|
} // namespace at::impl
|
||||||
} // namespace at
|
|
||||||
|
@ -2,8 +2,7 @@
|
|||||||
|
|
||||||
#include <c10/core/impl/LocalDispatchKeySet.h>
|
#include <c10/core/impl/LocalDispatchKeySet.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::impl {
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
// VmapMode contains a thread local count of how many nested vmaps
|
// VmapMode contains a thread local count of how many nested vmaps
|
||||||
// we are currently inside. That number is known as the `vmap level`.
|
// we are currently inside. That number is known as the `vmap level`.
|
||||||
@ -24,5 +23,4 @@ struct TORCH_API VmapMode {
|
|||||||
static int64_t decrement_nesting();
|
static int64_t decrement_nesting();
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace impl
|
} // namespace at::impl
|
||||||
} // namespace at
|
|
||||||
|
@ -10,8 +10,7 @@
|
|||||||
#include <c10/util/Metaprogramming.h>
|
#include <c10/util/Metaprogramming.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::native {
|
||||||
namespace native {
|
|
||||||
struct NestedTensorImpl;
|
struct NestedTensorImpl;
|
||||||
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
|
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
|
||||||
int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor);
|
int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor);
|
||||||
@ -276,5 +275,4 @@ inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) {
|
|||||||
return get_nested_tensor_impl(tensor)->get_nested_sizes();
|
return get_nested_tensor_impl(tensor)->get_nested_sizes();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace native
|
} // namespace at::native
|
||||||
} // namespace at
|
|
||||||
|
@ -8,8 +8,7 @@
|
|||||||
|
|
||||||
#define INTRA_OP_PARALLEL
|
#define INTRA_OP_PARALLEL
|
||||||
|
|
||||||
namespace at {
|
namespace at::internal {
|
||||||
namespace internal {
|
|
||||||
|
|
||||||
TORCH_API void invoke_parallel(
|
TORCH_API void invoke_parallel(
|
||||||
const int64_t begin,
|
const int64_t begin,
|
||||||
@ -17,6 +16,4 @@ TORCH_API void invoke_parallel(
|
|||||||
const int64_t grain_size,
|
const int64_t grain_size,
|
||||||
const std::function<void(int64_t, int64_t)>& f);
|
const std::function<void(int64_t, int64_t)>& f);
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace at::internal
|
||||||
|
|
||||||
} // namespace at
|
|
||||||
|
@ -15,8 +15,7 @@
|
|||||||
|
|
||||||
#define INTRA_OP_PARALLEL
|
#define INTRA_OP_PARALLEL
|
||||||
|
|
||||||
namespace at {
|
namespace at::internal {
|
||||||
namespace internal {
|
|
||||||
|
|
||||||
template <typename F>
|
template <typename F>
|
||||||
inline void invoke_parallel(
|
inline void invoke_parallel(
|
||||||
@ -50,5 +49,4 @@ inline void invoke_parallel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace at::internal
|
||||||
} // namespace at
|
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
#include <ATen/PythonTorchFunctionTLS.h>
|
#include <ATen/PythonTorchFunctionTLS.h>
|
||||||
#include <c10/core/TensorImpl.h>
|
#include <c10/core/TensorImpl.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::impl {
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
static thread_local PythonTorchFunctionTLS pythonTorchFunctionState;
|
static thread_local PythonTorchFunctionTLS pythonTorchFunctionState;
|
||||||
|
|
||||||
@ -47,5 +46,4 @@ bool torch_function_mode_enabled() {
|
|||||||
PythonTorchFunctionTLS::stack_len() > 0;
|
PythonTorchFunctionTLS::stack_len() > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace impl
|
} // namespace at::impl
|
||||||
} // namespace at
|
|
||||||
|
@ -3,8 +3,7 @@
|
|||||||
#include <c10/core/SafePyObject.h>
|
#include <c10/core/SafePyObject.h>
|
||||||
#include <c10/macros/Macros.h>
|
#include <c10/macros/Macros.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::impl {
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED };
|
enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED };
|
||||||
|
|
||||||
@ -32,5 +31,4 @@ struct TORCH_API PythonTorchFunctionTLS {
|
|||||||
|
|
||||||
TORCH_API bool torch_function_mode_enabled();
|
TORCH_API bool torch_function_mode_enabled();
|
||||||
|
|
||||||
} // namespace impl
|
} // namespace at::impl
|
||||||
} // namespace at
|
|
||||||
|
@ -9,8 +9,7 @@
|
|||||||
#include <ATen/ops/scalar_tensor.h>
|
#include <ATen/ops/scalar_tensor.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace at {
|
namespace at::detail {
|
||||||
namespace detail {
|
|
||||||
// When filling a number to 1-element CPU tensor, we want to skip
|
// When filling a number to 1-element CPU tensor, we want to skip
|
||||||
// everything but manipulate data ptr directly.
|
// everything but manipulate data ptr directly.
|
||||||
// Ideally this fast pass should be implemented in TensorIterator,
|
// Ideally this fast pass should be implemented in TensorIterator,
|
||||||
@ -21,8 +20,7 @@ TORCH_API Tensor scalar_tensor_static(
|
|||||||
const Scalar& s,
|
const Scalar& s,
|
||||||
c10::optional<ScalarType> dtype_opt,
|
c10::optional<ScalarType> dtype_opt,
|
||||||
c10::optional<Device> device_opt);
|
c10::optional<Device> device_opt);
|
||||||
} // namespace detail
|
} // namespace at::detail
|
||||||
} // namespace at
|
|
||||||
|
|
||||||
// This is in the c10 namespace because we use ADL to find the functions in it.
|
// This is in the c10 namespace because we use ADL to find the functions in it.
|
||||||
namespace c10 {
|
namespace c10 {
|
||||||
@ -60,8 +58,7 @@ inline at::Tensor scalar_to_tensor(
|
|||||||
|
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
|
||||||
namespace at {
|
namespace at::native {
|
||||||
namespace native {
|
|
||||||
|
|
||||||
inline Tensor wrapped_scalar_tensor(
|
inline Tensor wrapped_scalar_tensor(
|
||||||
const Scalar& scalar,
|
const Scalar& scalar,
|
||||||
@ -71,5 +68,4 @@ inline Tensor wrapped_scalar_tensor(
|
|||||||
return tensor;
|
return tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace native
|
} // namespace at::native
|
||||||
} // namespace at
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
#include <ATen/SequenceNumber.h>
|
#include <ATen/SequenceNumber.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::sequence_number {
|
||||||
namespace sequence_number {
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
thread_local uint64_t sequence_nr_ = 0;
|
thread_local uint64_t sequence_nr_ = 0;
|
||||||
@ -15,5 +14,4 @@ uint64_t get_and_increment() {
|
|||||||
return sequence_nr_++;
|
return sequence_nr_++;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace sequence_number
|
} // namespace at::sequence_number
|
||||||
} // namespace at
|
|
||||||
|
@ -3,14 +3,11 @@
|
|||||||
#include <c10/macros/Export.h>
|
#include <c10/macros/Export.h>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
namespace at {
|
|
||||||
|
|
||||||
// A simple thread local enumeration, used to link forward and backward pass
|
// A simple thread local enumeration, used to link forward and backward pass
|
||||||
// ops and is used by autograd and observers framework
|
// ops and is used by autograd and observers framework
|
||||||
namespace sequence_number {
|
namespace at::sequence_number {
|
||||||
|
|
||||||
TORCH_API uint64_t peek();
|
TORCH_API uint64_t peek();
|
||||||
TORCH_API uint64_t get_and_increment();
|
TORCH_API uint64_t get_and_increment();
|
||||||
|
|
||||||
} // namespace sequence_number
|
} // namespace at::sequence_number
|
||||||
} // namespace at
|
|
||||||
|
@ -22,8 +22,7 @@
|
|||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
namespace at {
|
namespace at::indexing {
|
||||||
namespace indexing {
|
|
||||||
|
|
||||||
const int64_t INDEX_MIN = c10::SymInt::min_representable_int();
|
const int64_t INDEX_MIN = c10::SymInt::min_representable_int();
|
||||||
const int64_t INDEX_MAX = -(INDEX_MIN + 1);
|
const int64_t INDEX_MAX = -(INDEX_MIN + 1);
|
||||||
@ -728,5 +727,4 @@ static inline void set_item(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace indexing
|
} // namespace at::indexing
|
||||||
} // namespace at
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
#include <ATen/WrapDimUtils.h>
|
#include <ATen/WrapDimUtils.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
|
|
||||||
namespace at { namespace namedinference {
|
namespace at::namedinference {
|
||||||
|
|
||||||
|
|
||||||
Dimname TensorName::toDimname() const {
|
Dimname TensorName::toDimname() const {
|
||||||
@ -126,4 +126,4 @@ std::vector<Dimname> TensorNames::toDimnameVec() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}} // namespace at::namedinference
|
} // namespace at::namedinference
|
||||||
|
@ -2,8 +2,7 @@
|
|||||||
|
|
||||||
#include <ATen/WrapDimUtils.h>
|
#include <ATen/WrapDimUtils.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::namedinference {
|
||||||
namespace namedinference {
|
|
||||||
|
|
||||||
// TensorName and TensorNames are wrappers around Dimname and DimnameList
|
// TensorName and TensorNames are wrappers around Dimname and DimnameList
|
||||||
// that contain helper functions to make writing name inference rules easier.
|
// that contain helper functions to make writing name inference rules easier.
|
||||||
@ -71,5 +70,4 @@ struct TORCH_API TensorNames {
|
|||||||
TensorNameVec names_;
|
TensorNameVec names_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace namedinference
|
} // namespace at::namedinference
|
||||||
} // namespace at
|
|
||||||
|
@ -4,8 +4,7 @@
|
|||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
namespace at {
|
namespace at::impl {
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
static thread_local ThreadLocalPythonObjects py_objects;
|
static thread_local ThreadLocalPythonObjects py_objects;
|
||||||
|
|
||||||
@ -32,5 +31,4 @@ const ThreadLocalPythonObjects& ThreadLocalPythonObjects::get_state() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
} // namespace at::impl
|
||||||
}
|
|
||||||
|
@ -4,8 +4,7 @@
|
|||||||
#include <c10/macros/Macros.h>
|
#include <c10/macros/Macros.h>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
namespace at {
|
namespace at::impl {
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
struct TORCH_API ThreadLocalPythonObjects {
|
struct TORCH_API ThreadLocalPythonObjects {
|
||||||
static void set(const std::string& key, std::shared_ptr<SafePyObject> value);
|
static void set(const std::string& key, std::shared_ptr<SafePyObject> value);
|
||||||
@ -19,5 +18,4 @@ struct TORCH_API ThreadLocalPythonObjects {
|
|||||||
std::unordered_map<std::string, std::shared_ptr<c10::SafePyObject>> obj_dict_;
|
std::unordered_map<std::string, std::shared_ptr<c10::SafePyObject>> obj_dict_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace impl
|
} // namespace at::impl
|
||||||
} // namespace at
|
|
||||||
|
@ -108,12 +108,10 @@
|
|||||||
// guard is essentially no-op when the master `setTracingState()` switch is
|
// guard is essentially no-op when the master `setTracingState()` switch is
|
||||||
// off.
|
// off.
|
||||||
|
|
||||||
namespace at {
|
|
||||||
// TODO: move this from `at::` to `jit::torch::` after
|
// TODO: move this from `at::` to `jit::torch::` after
|
||||||
// `aten/src/ATen/cpp_custom_type_hack.h` is removed.
|
// `aten/src/ATen/cpp_custom_type_hack.h` is removed.
|
||||||
|
|
||||||
namespace tracer {
|
namespace at::tracer::impl {
|
||||||
namespace impl {
|
|
||||||
|
|
||||||
static inline bool is_dispatch_enabled() {
|
static inline bool is_dispatch_enabled() {
|
||||||
return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) &&
|
return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) &&
|
||||||
@ -131,6 +129,4 @@ struct NoTracerDispatchMode {
|
|||||||
c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer};
|
c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer};
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace impl
|
} // namespace at::tracer::impl
|
||||||
} // namespace tracer
|
|
||||||
} // namespace at
|
|
||||||
|
@ -5,8 +5,7 @@
|
|||||||
#include <ATen/CachedTensorUtils.h>
|
#include <ATen/CachedTensorUtils.h>
|
||||||
#include <c10/util/flat_hash_map.h>
|
#include <c10/util/flat_hash_map.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::autocast {
|
||||||
namespace autocast {
|
|
||||||
|
|
||||||
bool is_enabled() {
|
bool is_enabled() {
|
||||||
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCUDA);
|
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCUDA);
|
||||||
@ -518,5 +517,4 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace autocast
|
} // namespace at::autocast
|
||||||
} // namespace at
|
|
||||||
|
@ -8,8 +8,7 @@
|
|||||||
#include <c10/core/impl/LocalDispatchKeySet.h>
|
#include <c10/core/impl/LocalDispatchKeySet.h>
|
||||||
#include <c10/util/intrusive_ptr.h>
|
#include <c10/util/intrusive_ptr.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::autocast {
|
||||||
namespace autocast {
|
|
||||||
|
|
||||||
TORCH_API bool is_enabled();
|
TORCH_API bool is_enabled();
|
||||||
TORCH_API void set_enabled(bool enabled);
|
TORCH_API void set_enabled(bool enabled);
|
||||||
@ -537,8 +536,7 @@ wouldn't try to get clever about it Therefore, for the moment, this is all
|
|||||||
copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
|
copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
|
||||||
********************************************************************************************************************/
|
********************************************************************************************************************/
|
||||||
|
|
||||||
} // namespace autocast
|
} // namespace at::autocast
|
||||||
} // namespace at
|
|
||||||
|
|
||||||
#define ADD_NS(RAW_OP) at::RAW_OP
|
#define ADD_NS(RAW_OP) at::RAW_OP
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@
|
|||||||
// TODO: No need to have this whole header, we can just put it all in
|
// TODO: No need to have this whole header, we can just put it all in
|
||||||
// the cpp file
|
// the cpp file
|
||||||
|
|
||||||
namespace at { namespace cuda { namespace detail {
|
namespace at::cuda::detail {
|
||||||
|
|
||||||
// Set the callback to initialize Magma, which is set by
|
// Set the callback to initialize Magma, which is set by
|
||||||
// torch_cuda_cu. This indirection is required so magma_init is called
|
// torch_cuda_cu. This indirection is required so magma_init is called
|
||||||
@ -51,4 +51,4 @@ struct CUDAHooks : public at::CUDAHooksInterface {
|
|||||||
void deviceSynchronize(DeviceIndex device_index) const override;
|
void deviceSynchronize(DeviceIndex device_index) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
}}} // at::cuda::detail
|
} // at::cuda::detail
|
||||||
|
@ -23,7 +23,7 @@
|
|||||||
|
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
|
||||||
namespace at { namespace cuda { namespace {
|
namespace at::cuda { namespace {
|
||||||
|
|
||||||
template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
|
template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
|
||||||
struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {
|
struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {
|
||||||
@ -148,4 +148,4 @@ struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThread
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
}}} // namespace at::cuda::detail::<anonymous>
|
}} // namespace at::cuda::detail::<anonymous>
|
||||||
|
@ -4,9 +4,7 @@
|
|||||||
#include <ATen/cuda/detail/TensorInfo.cuh>
|
#include <ATen/cuda/detail/TensorInfo.cuh>
|
||||||
#include <ATen/native/CanUse32BitIndexMath.h>
|
#include <ATen/native/CanUse32BitIndexMath.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::cuda::detail {
|
||||||
namespace cuda {
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t);
|
TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t);
|
||||||
using at::native::canUse32BitIndexMath;
|
using at::native::canUse32BitIndexMath;
|
||||||
@ -27,6 +25,4 @@ getTensorInfo(const at::TensorBase &t) {
|
|||||||
t.data_ptr<scalar>(), dims, sz, st);
|
t.data_ptr<scalar>(), dims, sz, st);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // detail
|
} // namespace at::cuda::detail
|
||||||
} // cuda
|
|
||||||
} // at
|
|
||||||
|
@ -5,9 +5,7 @@
|
|||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace at {
|
namespace at::cuda::detail {
|
||||||
namespace cuda {
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
// A utility class to implement integer division by multiplication, given a fixed
|
// A utility class to implement integer division by multiplication, given a fixed
|
||||||
// divisor.
|
// divisor.
|
||||||
@ -123,4 +121,4 @@ struct IntDivider<unsigned int> {
|
|||||||
unsigned int shift; // Shift amounts.
|
unsigned int shift; // Shift amounts.
|
||||||
};
|
};
|
||||||
|
|
||||||
}}} // namespace at::cuda::detail
|
} // namespace at::cuda::detail
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
|
||||||
namespace at { namespace cuda { namespace detail {
|
namespace at::cuda::detail {
|
||||||
|
|
||||||
// CUDA: grid stride looping
|
// CUDA: grid stride looping
|
||||||
//
|
//
|
||||||
@ -34,4 +34,4 @@ inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block=CUDA_
|
|||||||
return static_cast<int>(block_num);
|
return static_cast<int>(block_num);
|
||||||
}
|
}
|
||||||
|
|
||||||
}}} // namespace at::cuda::detail
|
} // namespace at::cuda::detail
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
#include <ATen/detail/CUDAHooksInterface.h>
|
#include <ATen/detail/CUDAHooksInterface.h>
|
||||||
namespace at { namespace cuda {
|
namespace at::cuda {
|
||||||
// Forward-declares at::cuda::NVRTC
|
// Forward-declares at::cuda::NVRTC
|
||||||
struct NVRTC;
|
struct NVRTC;
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
extern NVRTC lazyNVRTC;
|
extern NVRTC lazyNVRTC;
|
||||||
}
|
} // namespace detail
|
||||||
|
|
||||||
}} // at::cuda::detail
|
} // namespace at::cuda
|
||||||
|
@ -2,9 +2,7 @@
|
|||||||
|
|
||||||
#include <ATen/CollapseDims.h>
|
#include <ATen/CollapseDims.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at::cuda::detail {
|
||||||
namespace cuda {
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
#define MAX_TENSORINFO_DIMS 25
|
#define MAX_TENSORINFO_DIMS 25
|
||||||
|
|
||||||
@ -115,6 +113,4 @@ struct IndexToOffset<T, IndexType, -1> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // detail
|
} // namespace at::cuda::detail
|
||||||
} // cuda
|
|
||||||
} // at
|
|
||||||
|
@ -2,9 +2,7 @@
|
|||||||
// Eager mode clients should not include this file directly, instead,
|
// Eager mode clients should not include this file directly, instead,
|
||||||
// they should #include <ATen/cuda/CUDAGraphsUtils.cuh>, which has a #pragma once.
|
// they should #include <ATen/cuda/CUDAGraphsUtils.cuh>, which has a #pragma once.
|
||||||
|
|
||||||
namespace at {
|
namespace at::cuda::philox {
|
||||||
namespace cuda {
|
|
||||||
namespace philox {
|
|
||||||
|
|
||||||
// In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether
|
// In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether
|
||||||
// that instance was created with graph capture underway or not.
|
// that instance was created with graph capture underway or not.
|
||||||
@ -27,6 +25,4 @@ unpack(at::PhiloxCudaState arg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace philox
|
} // namespace at::cuda::philox
|
||||||
} // namespace cuda
|
|
||||||
} // namespace at
|
|
||||||
|
@ -1240,21 +1240,15 @@ void THCPGraph_init(PyObject* module);
|
|||||||
|
|
||||||
#ifdef USE_CUDA
|
#ifdef USE_CUDA
|
||||||
PyMethodDef* THCPModule_methods();
|
PyMethodDef* THCPModule_methods();
|
||||||
namespace torch {
|
namespace torch::cuda {
|
||||||
namespace cuda {
|
|
||||||
|
|
||||||
void initModule(PyObject* module);
|
void initModule(PyObject* module);
|
||||||
|
} // namespace torch::cuda
|
||||||
}
|
|
||||||
} // namespace torch
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef USE_ITT
|
#ifdef USE_ITT
|
||||||
namespace torch {
|
namespace torch::profiler {
|
||||||
namespace profiler {
|
|
||||||
void initIttBindings(PyObject* module);
|
void initIttBindings(PyObject* module);
|
||||||
} // namespace profiler
|
} // namespace torch::profiler
|
||||||
} // namespace torch
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static std::vector<PyMethodDef> methods;
|
static std::vector<PyMethodDef> methods;
|
||||||
|
@ -22,8 +22,7 @@
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
struct Def;
|
struct Def;
|
||||||
struct Property;
|
struct Property;
|
||||||
@ -349,5 +348,4 @@ namespace script {
|
|||||||
// of the public API; new code should not use this type alias.
|
// of the public API; new code should not use this type alias.
|
||||||
using CompilationUnit = ::torch::jit::CompilationUnit;
|
using CompilationUnit = ::torch::jit::CompilationUnit;
|
||||||
} // namespace script
|
} // namespace script
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -4,8 +4,7 @@
|
|||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
#include <torch/csrc/jit/runtime/graph_executor.h>
|
#include <torch/csrc/jit/runtime/graph_executor.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
struct TORCH_API GraphFunction : public Function {
|
struct TORCH_API GraphFunction : public Function {
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||||
@ -192,5 +191,4 @@ TORCH_API GraphFunction* tryToGraphFunction(Function&) noexcept;
|
|||||||
TORCH_API GraphFunction& toGraphFunction(Function&);
|
TORCH_API GraphFunction& toGraphFunction(Function&);
|
||||||
TORCH_API const GraphFunction& toGraphFunction(const Function&);
|
TORCH_API const GraphFunction& toGraphFunction(const Function&);
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -6,8 +6,7 @@
|
|||||||
#include <torch/csrc/api/include/torch/imethod.h>
|
#include <torch/csrc/api/include/torch/imethod.h>
|
||||||
#include <torch/csrc/jit/api/function_impl.h>
|
#include <torch/csrc/jit/api/function_impl.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
|
using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
|
||||||
|
|
||||||
@ -79,5 +78,4 @@ namespace script {
|
|||||||
using Method = ::torch::jit::Method;
|
using Method = ::torch::jit::Method;
|
||||||
} // namespace script
|
} // namespace script
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -32,8 +32,7 @@
|
|||||||
// modules and their methods into flattened graphs which don't have any
|
// modules and their methods into flattened graphs which don't have any
|
||||||
// function calls.
|
// function calls.
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
using ::c10::Argument;
|
using ::c10::Argument;
|
||||||
using ::c10::FunctionSchema;
|
using ::c10::FunctionSchema;
|
||||||
@ -676,5 +675,4 @@ using Module = ::torch::jit::Module;
|
|||||||
using ExtraFilesMap = ::torch::jit::ExtraFilesMap;
|
using ExtraFilesMap = ::torch::jit::ExtraFilesMap;
|
||||||
} // namespace script
|
} // namespace script
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -7,8 +7,7 @@
|
|||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
struct Resolver;
|
struct Resolver;
|
||||||
using ResolverPtr = std::shared_ptr<Resolver>;
|
using ResolverPtr = std::shared_ptr<Resolver>;
|
||||||
@ -196,5 +195,4 @@ namespace script {
|
|||||||
// of the public API; new code should not use this type alias.
|
// of the public API; new code should not use this type alias.
|
||||||
using Object = ::torch::jit::Object;
|
using Object = ::torch::jit::Object;
|
||||||
} // namespace script
|
} // namespace script
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -6,8 +6,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
static UpgradersMap upgradersMap;
|
static UpgradersMap upgradersMap;
|
||||||
|
|
||||||
@ -84,5 +83,4 @@ void test_only_remove_upgraders(
|
|||||||
upgradersMap.test_only_remove_content(content);
|
upgradersMap.test_only_remove_content(content);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -5,8 +5,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
class UpgradersMap {
|
class UpgradersMap {
|
||||||
public:
|
public:
|
||||||
@ -44,5 +43,4 @@ TORCH_API void test_only_populate_upgraders(
|
|||||||
TORCH_API void test_only_remove_upgraders(
|
TORCH_API void test_only_remove_upgraders(
|
||||||
const std::unordered_map<std::string, std::string>& content);
|
const std::unordered_map<std::string, std::string>& content);
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -11,8 +11,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
static std::unordered_map<std::string, std::string> kUpgradersEntryMap({
|
static std::unordered_map<std::string, std::string> kUpgradersEntryMap({
|
||||||
{"logspace_0_8", R"SCRIPT(
|
{"logspace_0_8", R"SCRIPT(
|
||||||
@ -150,5 +149,4 @@ std::unordered_map<std::string, std::string> get_upgraders_entry_map() {
|
|||||||
return kUpgradersEntryMap;
|
return kUpgradersEntryMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -4,8 +4,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
TORCH_API void populate_upgraders_graph_map();
|
TORCH_API void populate_upgraders_graph_map();
|
||||||
|
|
||||||
@ -18,5 +17,4 @@ std::shared_ptr<Graph> create_upgrader_graph(
|
|||||||
const std::string& upgrader_name,
|
const std::string& upgrader_name,
|
||||||
const std::string& upgrader_body);
|
const std::string& upgrader_body);
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -8,8 +8,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
c10::optional<UpgraderEntry> findUpgrader(
|
c10::optional<UpgraderEntry> findUpgrader(
|
||||||
const std::vector<UpgraderEntry>& upgraders_for_schema,
|
const std::vector<UpgraderEntry>& upgraders_for_schema,
|
||||||
@ -95,5 +94,4 @@ std::vector<UpgraderRange> getUpgradersRangeForOp(const std::string& name) {
|
|||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -5,8 +5,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
struct UpgraderRange {
|
struct UpgraderRange {
|
||||||
int min_version;
|
int min_version;
|
||||||
@ -48,5 +47,4 @@ TORCH_API uint64_t getMaxOperatorVersion();
|
|||||||
TORCH_API std::vector<UpgraderRange> getUpgradersRangeForOp(
|
TORCH_API std::vector<UpgraderRange> getUpgradersRangeForOp(
|
||||||
const std::string& name);
|
const std::string& name);
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -5,8 +5,7 @@
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
// this flag is used to make sure the elements in the version map
|
// this flag is used to make sure the elements in the version map
|
||||||
// are sorted according to when the upgraders are introduced.
|
// are sorted according to when the upgraders are introduced.
|
||||||
@ -130,5 +129,4 @@ bool get_version_calculator_flag() {
|
|||||||
return calculatePackageVersionBasedOnUpgraders;
|
return calculatePackageVersionBasedOnUpgraders;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -4,8 +4,7 @@
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
struct UpgraderEntry {
|
struct UpgraderEntry {
|
||||||
int bumped_at_version;
|
int bumped_at_version;
|
||||||
@ -31,5 +30,4 @@ TORCH_API void test_only_remove_entry(const std::string& op_name);
|
|||||||
|
|
||||||
TORCH_API void test_only_reset_flag();
|
TORCH_API void test_only_reset_flag();
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -2,10 +2,8 @@
|
|||||||
|
|
||||||
#include <torch/csrc/utils/pybind.h>
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
void initJITBindings(PyObject* module);
|
void initJITBindings(PyObject* module);
|
||||||
|
|
||||||
}
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -6,8 +6,7 @@
|
|||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
inline c10::optional<Module> as_module(py::handle obj) {
|
inline c10::optional<Module> as_module(py::handle obj) {
|
||||||
static py::handle ScriptModule =
|
static py::handle ScriptModule =
|
||||||
@ -33,5 +32,4 @@ inline c10::optional<Object> as_object(py::handle obj) {
|
|||||||
return c10::nullopt;
|
return c10::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -18,8 +18,7 @@
|
|||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
// This is a variant of shared_ptr that "sees through" a wrapper.
|
// This is a variant of shared_ptr that "sees through" a wrapper.
|
||||||
// We use it to convert Value, Node, Block and node to "wrapped" Python
|
// We use it to convert Value, Node, Block and node to "wrapped" Python
|
||||||
@ -64,13 +63,11 @@ class unwrapping_shared_ptr {
|
|||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
|
||||||
PYBIND11_DECLARE_HOLDER_TYPE(T, torch::jit::unwrapping_shared_ptr<T>, true);
|
PYBIND11_DECLARE_HOLDER_TYPE(T, torch::jit::unwrapping_shared_ptr<T>, true);
|
||||||
|
|
||||||
namespace pybind11 {
|
namespace pybind11::detail {
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
#define CREATE_UNWRAPPING_CASTER(Class) \
|
#define CREATE_UNWRAPPING_CASTER(Class) \
|
||||||
template <> \
|
template <> \
|
||||||
@ -110,12 +107,6 @@ CREATE_UNWRAPPING_CASTER(torch::jit::Block);
|
|||||||
|
|
||||||
#undef CREATE_UNWRAPPING_CASTER
|
#undef CREATE_UNWRAPPING_CASTER
|
||||||
|
|
||||||
} // namespace detail
|
|
||||||
} // namespace pybind11
|
|
||||||
|
|
||||||
namespace pybind11 {
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct type_caster<torch::jit::IValue> {
|
struct type_caster<torch::jit::IValue> {
|
||||||
public:
|
public:
|
||||||
@ -207,11 +198,9 @@ struct type_caster<std::vector<torch::jit::Node*>> : ListCasterBase {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace pybind11::detail
|
||||||
} // namespace pybind11
|
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
static inline py::tuple tuple_tail(const py::tuple& tup) {
|
static inline py::tuple tuple_tail(const py::tuple& tup) {
|
||||||
py::tuple r(tup.size() - 1);
|
py::tuple r(tup.size() - 1);
|
||||||
@ -221,5 +210,4 @@ static inline py::tuple tuple_tail(const py::tuple& tup) {
|
|||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -53,8 +53,7 @@
|
|||||||
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
|
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
using ResolutionCallback = std::function<py::object(std::string)>;
|
using ResolutionCallback = std::function<py::object(std::string)>;
|
||||||
|
|
||||||
@ -1111,5 +1110,4 @@ TORCH_PYTHON_API py::object _get_operation_for_overload_or_packet(
|
|||||||
bool is_overload,
|
bool is_overload,
|
||||||
c10::optional<c10::DispatchKey> dk = c10::nullopt);
|
c10::optional<c10::DispatchKey> dk = c10::nullopt);
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -10,9 +10,7 @@
|
|||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit::python {
|
||||||
namespace jit {
|
|
||||||
namespace python {
|
|
||||||
|
|
||||||
struct IODescriptor {
|
struct IODescriptor {
|
||||||
struct VariableMetadata {
|
struct VariableMetadata {
|
||||||
@ -118,6 +116,4 @@ PyObject* unflatten(
|
|||||||
at::ArrayRef<autograd::Variable> vars,
|
at::ArrayRef<autograd::Variable> vars,
|
||||||
const IODescriptor& structure);
|
const IODescriptor& structure);
|
||||||
|
|
||||||
} // namespace python
|
} // namespace torch::jit::python
|
||||||
} // namespace jit
|
|
||||||
} // namespace torch
|
|
||||||
|
@ -4,8 +4,7 @@
|
|||||||
#include <torch/csrc/utils/pybind.h>
|
#include <torch/csrc/utils/pybind.h>
|
||||||
#include <torch/custom_class.h>
|
#include <torch/custom_class.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
void initPythonCustomClassBindings(PyObject* module);
|
void initPythonCustomClassBindings(PyObject* module);
|
||||||
|
|
||||||
@ -18,5 +17,4 @@ struct ScriptClass {
|
|||||||
c10::StrongTypePtr class_type_;
|
c10::StrongTypePtr class_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -5,8 +5,7 @@
|
|||||||
#include <ATen/core/jit_type.h>
|
#include <ATen/core/jit_type.h>
|
||||||
#include <torch/csrc/utils/pybind.h>
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
void initScriptDictBindings(PyObject* module);
|
void initScriptDictBindings(PyObject* module);
|
||||||
|
|
||||||
@ -124,5 +123,4 @@ class ScriptDict final {
|
|||||||
c10::impl::GenericDict dict_;
|
c10::impl::GenericDict dict_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -3,8 +3,7 @@
|
|||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
#include <torch/csrc/utils/object_ptr.h>
|
#include <torch/csrc/utils/object_ptr.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
void initPythonIRBindings(PyObject* module);
|
void initPythonIRBindings(PyObject* module);
|
||||||
|
|
||||||
@ -48,5 +47,4 @@ struct ConcretePythonOp : public PythonOp {
|
|||||||
void lint_python() const override;
|
void lint_python() const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -7,8 +7,7 @@
|
|||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
namespace c10 {
|
namespace c10::ivalue {
|
||||||
namespace ivalue {
|
|
||||||
|
|
||||||
// concrete ivalue Holder that hold a py::object
|
// concrete ivalue Holder that hold a py::object
|
||||||
struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder {
|
struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder {
|
||||||
@ -95,5 +94,4 @@ struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder {
|
|||||||
py::object py_obj_;
|
py::object py_obj_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ivalue
|
} // namespace c10::ivalue
|
||||||
} // namespace c10
|
|
||||||
|
@ -10,8 +10,7 @@
|
|||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
void initScriptListBindings(PyObject* module);
|
void initScriptListBindings(PyObject* module);
|
||||||
|
|
||||||
@ -226,5 +225,4 @@ class ScriptList final {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -10,8 +10,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
std::string typeString(py::handle h);
|
std::string typeString(py::handle h);
|
||||||
|
|
||||||
@ -374,5 +373,4 @@ struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue {
|
|||||||
size_t n_binders) override;
|
size_t n_binders) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -8,8 +8,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
struct Module;
|
struct Module;
|
||||||
|
|
||||||
@ -43,5 +42,4 @@ std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracing(
|
|||||||
Module* self = nullptr,
|
Module* self = nullptr,
|
||||||
const std::vector<std::string>& argument_names = {});
|
const std::vector<std::string>& argument_names = {});
|
||||||
} // namespace tracer
|
} // namespace tracer
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -2,10 +2,8 @@
|
|||||||
|
|
||||||
#include <torch/csrc/python_headers.h>
|
#include <torch/csrc/python_headers.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
void initTreeViewBindings(PyObject* module);
|
void initTreeViewBindings(PyObject* module);
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
#include <torch/csrc/jit/python/pybind.h>
|
#include <torch/csrc/jit/python/pybind.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
void initJitScriptBindings(PyObject* module);
|
void initJitScriptBindings(PyObject* module);
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
#include <torch/csrc/Export.h>
|
#include <torch/csrc/Export.h>
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
TORCH_API void setGraphExecutorOptimize(bool o);
|
TORCH_API void setGraphExecutorOptimize(bool o);
|
||||||
TORCH_API bool getGraphExecutorOptimize();
|
TORCH_API bool getGraphExecutorOptimize();
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
#include <torch/csrc/Export.h>
|
#include <torch/csrc/Export.h>
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
TORCH_API void setUTF8DecodingIgnore(bool o);
|
TORCH_API void setUTF8DecodingIgnore(bool o);
|
||||||
TORCH_API bool getUTF8DecodingIgnore();
|
TORCH_API bool getUTF8DecodingIgnore();
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -3,8 +3,7 @@
|
|||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
void ArgumentSpecCreator::scan(
|
void ArgumentSpecCreator::scan(
|
||||||
const TypePtr& typ,
|
const TypePtr& typ,
|
||||||
@ -287,5 +286,4 @@ void ArgumentSpecCreator::specializeTypes(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -15,8 +15,7 @@ C10_CLANG_DIAGNOSTIC_PUSH()
|
|||||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
|
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
// GraphExecutor creates specializations of Graphs for different
|
// GraphExecutor creates specializations of Graphs for different
|
||||||
// dimensionalitities and types of inputs.
|
// dimensionalitities and types of inputs.
|
||||||
@ -467,8 +466,7 @@ inline c10::optional<int8_t> convertOptional(
|
|||||||
: c10::optional<int8_t>{};
|
: c10::optional<int8_t>{};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
|
||||||
namespace std {
|
namespace std {
|
||||||
|
|
||||||
|
@ -16,8 +16,7 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
using value_map = std::unordered_map<Value*, Value*>;
|
using value_map = std::unordered_map<Value*, Value*>;
|
||||||
using value_set = std::unordered_set<Value*>;
|
using value_set = std::unordered_set<Value*>;
|
||||||
@ -868,5 +867,4 @@ Gradient differentiate(std::shared_ptr<Graph>& graph) {
|
|||||||
UpdateDifferentiableGraphRequiresGrad(grad_desc.f, false);
|
UpdateDifferentiableGraphRequiresGrad(grad_desc.f, false);
|
||||||
return grad_desc;
|
return grad_desc;
|
||||||
}
|
}
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -8,8 +8,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
using value_list = std::vector<Value*>;
|
using value_list = std::vector<Value*>;
|
||||||
// clang-format off
|
// clang-format off
|
||||||
@ -94,5 +93,4 @@ TORCH_API bool isDifferentiable(const Node* n);
|
|||||||
TORCH_API bool isDifferentiable(Graph& g);
|
TORCH_API bool isDifferentiable(Graph& g);
|
||||||
TORCH_API bool isZero(Value* v);
|
TORCH_API bool isZero(Value* v);
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -4,8 +4,7 @@
|
|||||||
#include <torch/csrc/jit/frontend/schema_matching.h>
|
#include <torch/csrc/jit/frontend/schema_matching.h>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
// Calculates the number of args that need to be passed in.
|
// Calculates the number of args that need to be passed in.
|
||||||
// Less args may be needed if defaults are provided.
|
// Less args may be needed if defaults are provided.
|
||||||
@ -67,5 +66,4 @@ inline std::pair<int64_t, int64_t> CalculateNecessaryArgs(
|
|||||||
return std::make_pair(0, num_out);
|
return std::make_pair(0, num_out);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -4,8 +4,7 @@
|
|||||||
#include <ATen/core/stack.h>
|
#include <ATen/core/stack.h>
|
||||||
#include <torch/csrc/jit/runtime/operator.h>
|
#include <torch/csrc/jit/runtime/operator.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
/// Registration class for new operators. Effectively calls
|
/// Registration class for new operators. Effectively calls
|
||||||
/// `torch::jit::registerOperator` for every supplied operator, but allows doing
|
/// `torch::jit::registerOperator` for every supplied operator, but allows doing
|
||||||
@ -28,5 +27,4 @@ struct TORCH_API RegisterOperators {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -17,8 +17,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
namespace {
|
namespace {
|
||||||
std::mutex lock;
|
std::mutex lock;
|
||||||
|
|
||||||
@ -213,5 +212,4 @@ Function* GetDecompositionExecutor(const char* schema_literal) {
|
|||||||
return GetDecompositionExecutor(schema);
|
return GetDecompositionExecutor(schema);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -5,8 +5,7 @@
|
|||||||
#include <torch/csrc/Export.h>
|
#include <torch/csrc/Export.h>
|
||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
TORCH_API c10::optional<std::shared_ptr<Graph>> GetDecomposition(
|
TORCH_API c10::optional<std::shared_ptr<Graph>> GetDecomposition(
|
||||||
const FunctionSchema& schema);
|
const FunctionSchema& schema);
|
||||||
@ -31,5 +30,4 @@ TORCH_API void run_jit_decomposition(
|
|||||||
|
|
||||||
TORCH_API bool has_jit_decomposition(const FunctionSchema& schema);
|
TORCH_API bool has_jit_decomposition(const FunctionSchema& schema);
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -10,8 +10,7 @@
|
|||||||
#include <torch/csrc/jit/runtime/decomposition_registry_util.h>
|
#include <torch/csrc/jit/runtime/decomposition_registry_util.h>
|
||||||
#include <torch/csrc/jit/runtime/operator.h>
|
#include <torch/csrc/jit/runtime/operator.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
const std::string decomp_funcs =
|
const std::string decomp_funcs =
|
||||||
R"(def var_decomposition(input: Tensor,
|
R"(def var_decomposition(input: Tensor,
|
||||||
@ -104,5 +103,4 @@ const OperatorMap<std::string>& GetDecompositionMapping() {
|
|||||||
return decomposition_mapping;
|
return decomposition_mapping;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -3,12 +3,10 @@
|
|||||||
#include <torch/csrc/Export.h>
|
#include <torch/csrc/Export.h>
|
||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
TORCH_API const std::string& GetSerializedDecompositions();
|
TORCH_API const std::string& GetSerializedDecompositions();
|
||||||
|
|
||||||
TORCH_API const OperatorMap<std::string>& GetDecompositionMapping();
|
TORCH_API const OperatorMap<std::string>& GetDecompositionMapping();
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -2,8 +2,7 @@
|
|||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
struct ExceptionMessage {
|
struct ExceptionMessage {
|
||||||
ExceptionMessage(const std::exception& e) : e_(e) {}
|
ExceptionMessage(const std::exception& e) : e_(e) {}
|
||||||
@ -27,5 +26,4 @@ inline std::ostream& operator<<(
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -53,8 +53,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
EnableProfilingGuard::EnableProfilingGuard() {
|
EnableProfilingGuard::EnableProfilingGuard() {
|
||||||
auto& executor_mode = getExecutorMode();
|
auto& executor_mode = getExecutorMode();
|
||||||
@ -1062,5 +1061,4 @@ Node* replaceBlockWithFallbackGraph(Block* b, ArrayRef<Value*> inputs) {
|
|||||||
return fallback;
|
return fallback;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -11,8 +11,7 @@
|
|||||||
|
|
||||||
C10_DECLARE_bool(torch_jit_enable_new_executor);
|
C10_DECLARE_bool(torch_jit_enable_new_executor);
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
struct GraphExecutorState;
|
struct GraphExecutorState;
|
||||||
struct Code;
|
struct Code;
|
||||||
|
|
||||||
@ -140,5 +139,4 @@ GraphExecutor* getDifferentiableGraphOpExecutor(Operation& op);
|
|||||||
// with less plumbing.
|
// with less plumbing.
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -27,8 +27,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
void packGradient(const Gradient& gradient, Node* dnode);
|
void packGradient(const Gradient& gradient, Node* dnode);
|
||||||
bool needsGradient(const std::shared_ptr<const Graph>& graph);
|
bool needsGradient(const std::shared_ptr<const Graph>& graph);
|
||||||
@ -111,5 +110,4 @@ struct GraphExecutorImplBase {
|
|||||||
std::mutex compile_mutex;
|
std::mutex compile_mutex;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
// This class facilitates depth-first iteration over all nodes in a graph.
|
// This class facilitates depth-first iteration over all nodes in a graph.
|
||||||
class DepthFirstGraphNodeIterator {
|
class DepthFirstGraphNodeIterator {
|
||||||
@ -145,5 +144,4 @@ class DepthFirstGraphNodeIterator {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -3,8 +3,7 @@
|
|||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
static std::ostream& operator<<(std::ostream& out, OpCode op) {
|
static std::ostream& operator<<(std::ostream& out, OpCode op) {
|
||||||
switch (op) {
|
switch (op) {
|
||||||
#define OP_STRING(x, _) \
|
#define OP_STRING(x, _) \
|
||||||
@ -93,5 +92,4 @@ bool isOpSupportedInMobile(OpCode op) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -4,8 +4,7 @@
|
|||||||
#include <typeinfo>
|
#include <typeinfo>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
// instruction look like:
|
// instruction look like:
|
||||||
// op_code X, N
|
// op_code X, N
|
||||||
// meaning of X, N depend on the op:
|
// meaning of X, N depend on the op:
|
||||||
@ -98,5 +97,4 @@ char const* toString(OpCode op);
|
|||||||
OpCode parseOpCode(const char* str);
|
OpCode parseOpCode(const char* str);
|
||||||
std::ostream& operator<<(std::ostream& out, Instruction inst);
|
std::ostream& operator<<(std::ostream& out, Instruction inst);
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -50,8 +50,7 @@ C10_DEFINE_bool(
|
|||||||
false,
|
false,
|
||||||
"enable rethrowing caught exception");
|
"enable rethrowing caught exception");
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
using CodeImpl = interpreter::CodeImpl;
|
using CodeImpl = interpreter::CodeImpl;
|
||||||
|
|
||||||
@ -1202,5 +1201,4 @@ void InterpreterContinuation::operator()() {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -26,8 +26,7 @@ struct IValue;
|
|||||||
struct OperatorName;
|
struct OperatorName;
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
// The interpreter run Graphs with Tensor inputs and Tensor outputs
|
// The interpreter run Graphs with Tensor inputs and Tensor outputs
|
||||||
// a separate component in the autograd handles unwrapping and wrapping
|
// a separate component in the autograd handles unwrapping and wrapping
|
||||||
@ -163,7 +162,6 @@ TORCH_API at::TensorTypePtr tensorTypeInCurrentExecutionContext(
|
|||||||
TORCH_API std::vector<StackEntry> currentCallstack();
|
TORCH_API std::vector<StackEntry> currentCallstack();
|
||||||
TORCH_API std::vector<std::string> currentModuleHierarchy();
|
TORCH_API std::vector<std::string> currentModuleHierarchy();
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
|
||||||
C10_CLANG_DIAGNOSTIC_POP()
|
C10_CLANG_DIAGNOSTIC_POP()
|
||||||
|
@ -4,9 +4,7 @@
|
|||||||
|
|
||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit::interpreter {
|
||||||
namespace jit {
|
|
||||||
namespace interpreter {
|
|
||||||
/*
|
/*
|
||||||
This is an optimization that reduces the number of store/load/move nodes needed
|
This is an optimization that reduces the number of store/load/move nodes needed
|
||||||
by recognizing that parts of the graph are simple trees like a*x + b*y. When
|
by recognizing that parts of the graph are simple trees like a*x + b*y. When
|
||||||
@ -105,6 +103,4 @@ struct CanEmitInline {
|
|||||||
std::unordered_map<Node*, bool> can_emit_inline_;
|
std::unordered_map<Node*, bool> can_emit_inline_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace interpreter
|
} // namespace torch::jit::interpreter
|
||||||
} // namespace jit
|
|
||||||
} // namespace torch
|
|
||||||
|
@ -14,8 +14,7 @@
|
|||||||
#include <torch/csrc/jit/runtime/instruction.h>
|
#include <torch/csrc/jit/runtime/instruction.h>
|
||||||
#include <torch/csrc/jit/runtime/interpreter/preprocess_graph.h>
|
#include <torch/csrc/jit/runtime/interpreter/preprocess_graph.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& out, Instruction inst);
|
std::ostream& operator<<(std::ostream& out, Instruction inst);
|
||||||
|
|
||||||
@ -997,5 +996,4 @@ struct MobileCodeImpl : CodeImpl {
|
|||||||
};
|
};
|
||||||
|
|
||||||
} // namespace interpreter
|
} // namespace interpreter
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -1,15 +1,11 @@
|
|||||||
#include <torch/csrc/jit/runtime/interpreter/frame.h>
|
#include <torch/csrc/jit/runtime/interpreter/frame.h>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit::interpreter {
|
||||||
namespace jit {
|
|
||||||
namespace interpreter {
|
|
||||||
|
|
||||||
/* static */ size_t Frame::genId() {
|
/* static */ size_t Frame::genId() {
|
||||||
static std::atomic<size_t> numFrames{0};
|
static std::atomic<size_t> numFrames{0};
|
||||||
return numFrames.fetch_add(1, std::memory_order_relaxed);
|
return numFrames.fetch_add(1, std::memory_order_relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace interpreter
|
} // namespace torch::jit::interpreter
|
||||||
} // namespace jit
|
|
||||||
} // namespace torch
|
|
||||||
|
@ -6,9 +6,7 @@
|
|||||||
#include <torch/csrc/jit/runtime/interpreter/code_impl.h>
|
#include <torch/csrc/jit/runtime/interpreter/code_impl.h>
|
||||||
#include <torch/csrc/jit/runtime/profiling_record.h>
|
#include <torch/csrc/jit/runtime/profiling_record.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit::interpreter {
|
||||||
namespace jit {
|
|
||||||
namespace interpreter {
|
|
||||||
|
|
||||||
// A Frame captures function's state
|
// A Frame captures function's state
|
||||||
// (e.g. `pc` and `base_pointer`)
|
// (e.g. `pc` and `base_pointer`)
|
||||||
@ -39,6 +37,4 @@ struct Frame {
|
|||||||
static size_t genId();
|
static size_t genId();
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace interpreter
|
} // namespace torch::jit::interpreter
|
||||||
} // namespace jit
|
|
||||||
} // namespace torch
|
|
||||||
|
@ -3,9 +3,7 @@
|
|||||||
#include <torch/csrc/jit/frontend/schema_matching.h>
|
#include <torch/csrc/jit/frontend/schema_matching.h>
|
||||||
#include <torch/csrc/jit/runtime/interpreter/can_emit_inline.h>
|
#include <torch/csrc/jit/runtime/interpreter/can_emit_inline.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit::interpreter {
|
||||||
namespace jit {
|
|
||||||
namespace interpreter {
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -213,6 +211,4 @@ PreprocessGraph::PreprocessGraph(Graph& g) : graph(g.copy()) {
|
|||||||
insertLastUses(*graph);
|
insertLastUses(*graph);
|
||||||
can_emit_inline = std::move(CanEmitInline(*graph.get()).can_emit_inline_);
|
can_emit_inline = std::move(CanEmitInline(*graph.get()).can_emit_inline_);
|
||||||
}
|
}
|
||||||
} // namespace interpreter
|
} // namespace torch::jit::interpreter
|
||||||
} // namespace jit
|
|
||||||
} // namespace torch
|
|
||||||
|
@ -5,9 +5,7 @@
|
|||||||
|
|
||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit::interpreter {
|
||||||
namespace jit {
|
|
||||||
namespace interpreter {
|
|
||||||
|
|
||||||
// pre-processing that happens once per graph
|
// pre-processing that happens once per graph
|
||||||
struct PreprocessGraph {
|
struct PreprocessGraph {
|
||||||
@ -18,6 +16,4 @@ struct PreprocessGraph {
|
|||||||
std::unordered_map<Node*, bool> can_emit_inline;
|
std::unordered_map<Node*, bool> can_emit_inline;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace interpreter
|
} // namespace torch::jit::interpreter
|
||||||
} // namespace jit
|
|
||||||
} // namespace torch
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
#include <torch/csrc/jit/runtime/jit_exception.h>
|
#include <torch/csrc/jit/runtime/jit_exception.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
static thread_local std::string caughtOriginalMsg = "";
|
static thread_local std::string caughtOriginalMsg = "";
|
||||||
static thread_local std::string caughtPythonClassName = "";
|
static thread_local std::string caughtPythonClassName = "";
|
||||||
@ -28,5 +27,4 @@ void JITException::setCaughtPythonClassName(
|
|||||||
caughtPythonClassName = pythonClassName;
|
caughtPythonClassName = pythonClassName;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -6,8 +6,7 @@
|
|||||||
#include <torch/csrc/Export.h>
|
#include <torch/csrc/Export.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
struct TORCH_API JITException : public std::runtime_error {
|
struct TORCH_API JITException : public std::runtime_error {
|
||||||
explicit JITException(
|
explicit JITException(
|
||||||
@ -36,5 +35,4 @@ struct TORCH_API JITException : public std::runtime_error {
|
|||||||
c10::optional<std::string> original_msg_;
|
c10::optional<std::string> original_msg_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -17,9 +17,7 @@
|
|||||||
#include <torch/csrc/jit/runtime/profiling_record.h>
|
#include <torch/csrc/jit/runtime/profiling_record.h>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
|
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -315,5 +313,4 @@ std::shared_ptr<Graph> TraceGraph(std::shared_ptr<Graph> graph, Stack& stack) {
|
|||||||
GRAPH_DUMP("Traced graph:", td.traced_graph_);
|
GRAPH_DUMP("Traced graph:", td.traced_graph_);
|
||||||
return td.traced_graph_;
|
return td.traced_graph_;
|
||||||
}
|
}
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
TORCH_API std::shared_ptr<Graph> TraceGraph(
|
TORCH_API std::shared_ptr<Graph> TraceGraph(
|
||||||
std::shared_ptr<Graph> graph,
|
std::shared_ptr<Graph> graph,
|
||||||
Stack& stack);
|
Stack& stack);
|
||||||
}
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -5,9 +5,7 @@
|
|||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit::logging {
|
||||||
namespace jit {
|
|
||||||
namespace logging {
|
|
||||||
|
|
||||||
// TODO: multi-scale histogram for this thing
|
// TODO: multi-scale histogram for this thing
|
||||||
|
|
||||||
@ -68,6 +66,4 @@ void recordDurationSince(const std::string& name, const JITTimePoint& tp) {
|
|||||||
logging::getLogger()->addStatValue(name, seconds);
|
logging::getLogger()->addStatValue(name, seconds);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace logging
|
} // namespace torch::jit::logging
|
||||||
} // namespace jit
|
|
||||||
} // namespace torch
|
|
||||||
|
@ -7,9 +7,7 @@
|
|||||||
|
|
||||||
#include <torch/csrc/Export.h>
|
#include <torch/csrc/Export.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit::logging {
|
||||||
namespace jit {
|
|
||||||
namespace logging {
|
|
||||||
|
|
||||||
class LoggerBase {
|
class LoggerBase {
|
||||||
public:
|
public:
|
||||||
@ -85,6 +83,4 @@ inline std::vector<const char*> allRuntimeCounters() {
|
|||||||
|
|
||||||
} // namespace runtime_counters
|
} // namespace runtime_counters
|
||||||
|
|
||||||
} // namespace logging
|
} // namespace torch::jit::logging
|
||||||
} // namespace jit
|
|
||||||
} // namespace torch
|
|
||||||
|
@ -10,8 +10,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
using OperatorMap =
|
using OperatorMap =
|
||||||
@ -447,5 +446,4 @@ std::string canonicalSchemaString(const FunctionSchema& schema) {
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -23,8 +23,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
struct Node;
|
struct Node;
|
||||||
using ::c10::Argument;
|
using ::c10::Argument;
|
||||||
@ -323,5 +322,4 @@ c10::optional<Operator> OperatorGenerator(
|
|||||||
alias_analysis));
|
alias_analysis));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
@ -2,10 +2,8 @@
|
|||||||
|
|
||||||
#include <ATen/core/dispatch/OperatorOptions.h>
|
#include <ATen/core/dispatch/OperatorOptions.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch::jit {
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
using AliasAnalysisKind = c10::AliasAnalysisKind;
|
using AliasAnalysisKind = c10::AliasAnalysisKind;
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace torch::jit
|
||||||
} // namespace torch
|
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user