Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00.
Summary: Fixes #9092.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9491
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9693
Differential Revision: D8946850
Pulled By: ezyang
fbshipit-source-id: bd816f459ab70f6b4a0983305a1ce341bb633707
Committed by: Facebook Github Bot
Parent: 9ee5133651
Commit: 53083b8353
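The change drops CMake's blanket CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS setting and instead puts explicit export annotations (AT_API, TH_API/TH_CPP_API, TORCH_API, CAFFE2_API) on the declarations that must cross a DLL boundary, so Windows shared builds export exactly the symbols that other libraries import. A minimal sketch of the dllexport/dllimport idiom these macros follow, using a hypothetical MYLIB_API macro and MYLIB_EXPORTS build flag rather than the macros actually defined in the tree:

// export_macros_sketch.h -- illustrative only; the real macros live in headers
// such as TH/THGeneral.h (TH_API/TH_CPP_API) and torch/csrc/WindowsTorchApiMacro.h.
#pragma once

#if defined(_WIN32)
#  if defined(MYLIB_EXPORTS)          // defined by the build system while compiling the DLL itself
#    define MYLIB_API __declspec(dllexport)
#  else                               // consumers of the DLL import the symbols
#    define MYLIB_API __declspec(dllimport)
#  endif
#else
#  define MYLIB_API                   // no-op on non-Windows platforms
#endif

// Annotate anything that must be visible across the DLL boundary.
MYLIB_API int add(int a, int b);      // exported free function

struct MYLIB_API Counter {            // exporting the class exports its member functions
  int next() { return ++value_; }
 private:
  int value_ = 0;
};

The diff below applies this same pattern, tagging the public ATen, TH, Caffe2, and torch/csrc declarations with the corresponding library-specific macros.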
@@ -49,6 +49,7 @@ torch/csrc/nn/THNN.cpp
 torch/csrc/nn/THNN.cwrap
 torch/lib/*.a*
 torch/lib/*.dll*
+torch/lib/*.exe*
 torch/lib/*.dylib*
 torch/lib/*.h
 torch/lib/*.lib
@@ -152,10 +152,6 @@ endif()
 # ---[ CMake scripts + modules
 list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
 
-if (MSVC AND ${BUILD_SHARED_LIBS})
-  set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
-endif()
-
 # ---[ CMake build directories
 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
@@ -1,4 +1,5 @@
+#include <ATen/optional.h>
 #include <ATen/Backtrace.h>
 
 #include <functional>
 #include <memory>
@@ -4,9 +4,11 @@
 #include <string>
 #include <typeinfo>
 
+#include <ATen/ATenGeneral.h>
+
 namespace at {
 /// Utility to demangle a C++ symbol name.
-std::string demangle(const char* name);
+AT_API std::string demangle(const char* name);
 
 /// Returns the printable name of the type.
 template <typename T>
@@ -19,7 +21,7 @@ inline const char* demangle_type() {
 #endif // __GXX_RTTI
 }
 
-std::string get_backtrace(
+AT_API std::string get_backtrace(
 size_t frames_to_skip = 0,
 size_t maximum_number_of_frames = 64,
 bool skip_python_frames = true);
@@ -111,8 +111,8 @@ struct Device {
 };
 } // namespace at
 
-std::ostream& operator<<(std::ostream& stream, at::Device::Type type);
-std::ostream& operator<<(std::ostream& stream, const at::Device& device);
+AT_API std::ostream& operator<<(std::ostream& stream, at::Device::Type type);
+AT_API std::ostream& operator<<(std::ostream& stream, const at::Device& device);
 
 namespace std {
 template<> struct hash<at::Device>
@@ -10,7 +10,7 @@
 namespace at {
 
 AT_API std::vector<int64_t> infer_size(IntList a, IntList b);
-std::tuple<std::vector<int64_t>, std::vector<int64_t> > inferExpandGeometry(
+AT_API std::tuple<std::vector<int64_t>, std::vector<int64_t> > inferExpandGeometry(
 IntList tensor_sizes, IntList tensor_strides, IntList sizes);
 
 // avoid copy-construction of Tensor by using a reference_wrapper.
@@ -35,8 +35,8 @@ namespace at {
 
 namespace detail {
 
-float halfbits2float(unsigned short bits);
-unsigned short float2halfbits(float value);
+AT_API float halfbits2float(unsigned short bits);
+AT_API unsigned short float2halfbits(float value);
 
 }
 
@@ -33,6 +33,8 @@
 #include <type_traits>
 #include <utility>
 
+#include <ATen/ATenGeneral.h>
+
 #if __GNUG__ && __GNUC__ < 5
 #define AT_IS_TRIVIALLY_COPYABLE(T) __has_trivial_copy(T)
 #else
@@ -57,7 +59,7 @@ static inline uint64_t NextPowerOf2(uint64_t A) {
 }
 
 /// This is all the non-templated stuff common to all SmallVectors.
-class SmallVectorBase {
+class AT_API SmallVectorBase {
 protected:
 void *BeginX, *EndX, *CapacityX;
 
@@ -5,7 +5,7 @@
 #include "ATen/Error.h"
 
 namespace at {
-struct SparseTensorImpl : public TensorImpl {
+struct AT_API SparseTensorImpl : public TensorImpl {
 // Stored in COO format, indices + values.
 
 // Ideal INVARIANTS:
@@ -19,7 +19,7 @@ namespace at {
 /// `torch::TensorOptions` subclass of this `TensorOptions`, which changes
 /// `type()` to return a variable type instead of a tensor type, such that
 /// variables are created inside factory methods, instead of tensors.
-struct TensorOptions {
+struct AT_API TensorOptions {
 TensorOptions() : TensorOptions(/*use_thread_local_default_options=*/true) {}
 
 /// Constructs the `TensorOptions` with defaults taken from the thread local
@@ -5,6 +5,8 @@
 
 #include "cuda_runtime_api.h"
 
+#include <ATen/ATenGeneral.h>
+
 /*
 * A CUDA stream interface with no CUDA build dependency.
 *
@@ -23,31 +25,31 @@ namespace detail {
 
 // Pointer-based API (for internal use)
 // Note: ATen/Context is preferred to work with streams safely
-CUDAStreamInternals* CUDAStream_getDefaultStreamOnDevice(int64_t device);
-CUDAStreamInternals* CUDAStream_getDefaultStream();
+AT_API CUDAStreamInternals* CUDAStream_getDefaultStreamOnDevice(int64_t device);
+AT_API CUDAStreamInternals* CUDAStream_getDefaultStream();
 
-CUDAStreamInternals* CUDAStream_createAndRetainWithOptions(int32_t flags, int32_t priority);
+AT_API CUDAStreamInternals* CUDAStream_createAndRetainWithOptions(int32_t flags, int32_t priority);
 
-CUDAStreamInternals* CUDAStream_getAndRetainCurrentStreamOnDevice(int64_t device);
-CUDAStreamInternals* CUDAStream_getAndRetainCurrentStream();
+AT_API CUDAStreamInternals* CUDAStream_getAndRetainCurrentStreamOnDevice(int64_t device);
+AT_API CUDAStreamInternals* CUDAStream_getAndRetainCurrentStream();
 
 // Note: these Unsafe gets should NEVER be used and are only here for legacy
 // purposes. Once those uses are gone they should be removed.
-CUDAStreamInternals* CUDAStream_getCurrentStreamOnDeviceUnsafe(int64_t device);
-CUDAStreamInternals* CUDAStream_getCurrentStreamUnsafe();
+AT_API CUDAStreamInternals* CUDAStream_getCurrentStreamOnDeviceUnsafe(int64_t device);
+AT_API CUDAStreamInternals* CUDAStream_getCurrentStreamUnsafe();
 
-void CUDAStream_setStreamOnDevice(int64_t device, CUDAStreamInternals* internals);
-void CUDAStream_uncheckedSetStreamOnDevice(
+AT_API void CUDAStream_setStreamOnDevice(int64_t device, CUDAStreamInternals* internals);
+AT_API void CUDAStream_uncheckedSetStreamOnDevice(
 int64_t device,
 CUDAStreamInternals* internals);
-void CUDAStream_setStream(CUDAStreamInternals* internals);
+AT_API void CUDAStream_setStream(CUDAStreamInternals* internals);
 
-cudaStream_t CUDAStream_stream(CUDAStreamInternals*);
-int64_t CUDAStream_device(CUDAStreamInternals*);
+AT_API cudaStream_t CUDAStream_stream(CUDAStreamInternals*);
+AT_API int64_t CUDAStream_device(CUDAStreamInternals*);
 
-bool CUDAStream_retain(CUDAStreamInternals*);
-void CUDAStream_free(CUDAStreamInternals*&);
-void CUDAStream_uncheckedFree(CUDAStreamInternals*&);
+AT_API bool CUDAStream_retain(CUDAStreamInternals*);
+AT_API void CUDAStream_free(CUDAStreamInternals*&);
+AT_API void CUDAStream_uncheckedFree(CUDAStreamInternals*&);
 
 } // namespace detail
 
@@ -71,10 +73,10 @@ struct CUDAStream {
 ~CUDAStream() { detail::CUDAStream_uncheckedFree(internals_); }
 
 // Copy constructor
-CUDAStream(const CUDAStream& other);
+AT_API CUDAStream(const CUDAStream& other);
 
 // Move constructor
-CUDAStream(CUDAStream&& other);
+AT_API CUDAStream(CUDAStream&& other);
 
 // Assignment operator
 CUDAStream& operator=(CUDAStream other) noexcept {
@@ -143,7 +143,7 @@ static inline ${return_type} ${api_name}(${formals}) {
 """)
 # add a native declaration for a native function
 NATIVE_DECLARATION = CodeTemplate("""\
-${return_type} ${native_type_method_dispatch}(${formals_with_defaults});
+AT_API ${return_type} ${native_type_method_dispatch}(${formals_with_defaults});
 """)
 
 # special method definition for factory functions in Functions.h
@@ -35,11 +35,14 @@
 #ifdef _WIN32
 # if defined(ATen_cpu_EXPORTS) || defined(caffe2_EXPORTS)
 # define TH_API TH_EXTERNC __declspec(dllexport)
+# define TH_CPP_API __declspec(dllexport)
 # else
 # define TH_API TH_EXTERNC __declspec(dllimport)
+# define TH_CPP_API __declspec(dllimport)
 # endif
 #else
 # define TH_API TH_EXTERNC
+# define TH_CPP_API
 #endif
 
 #ifdef _WIN32
@@ -37,7 +37,7 @@ struct THFinalizer {
 virtual ~THFinalizer() {};
 };
 
-struct THStorage
+struct TH_CPP_API THStorage
 {
 THStorage() = delete;
 THStorage(at::ScalarType, ptrdiff_t, at::DataPtr, at::Allocator*, char);
@@ -33,14 +33,14 @@
 // If it is not, you must report that the storage is dead.
 //
 
-ptrdiff_t THStorage_size(const THStorage *self);
+TH_API ptrdiff_t THStorage_size(const THStorage *self);
 
-void THStorage_setFlag(THStorage *storage, const char flag);
-void THStorage_clearFlag(THStorage *storage, const char flag);
-void THStorage_retain(THStorage *storage);
-void THStorage_resize(THStorage *storage, ptrdiff_t size);
-void THStorage_swap(THStorage *storage1, THStorage *storage2);
+TH_API void THStorage_setFlag(THStorage *storage, const char flag);
+TH_API void THStorage_clearFlag(THStorage *storage, const char flag);
+TH_API void THStorage_retain(THStorage *storage);
+TH_API void THStorage_resize(THStorage *storage, ptrdiff_t size);
+TH_API void THStorage_swap(THStorage *storage1, THStorage *storage2);
 
-void THStorage_weakRetain(THStorage *weak_storage);
-void THStorage_weakFree(THStorage *weak_storage);
-THStorage* THStorage_weakLock(THStorage *weak_storage);
+TH_API void THStorage_weakRetain(THStorage *weak_storage);
+TH_API void THStorage_weakFree(THStorage *weak_storage);
+TH_API THStorage* THStorage_weakLock(THStorage *weak_storage);
@@ -149,5 +149,5 @@ inline void THTensor_stealAndSetStoragePtr(THTensor* tensor, THStorage* storage)
 }
 
 TH_API void THTensor_free(THTensor *self);
-at::optional<std::vector<int64_t>> THTensor_compute_stride(at::IntList oldshape, at::IntList oldstride,
+TH_CPP_API at::optional<std::vector<int64_t>> THTensor_compute_stride(at::IntList oldshape, at::IntList oldstride,
 at::IntList newshape);
@@ -1,4 +1,5 @@
 #include "caffe2/utils/proto_wrap.h"
+#include "caffe2/core/common.h"
 
 #include <google/protobuf/stubs/common.h>
 #include <google/protobuf/generated_message_util.h>
@@ -8,7 +9,7 @@ namespace caffe {
 // Caffe wrapper functions for protobuf's GetEmptyStringAlreadyInited() function
 // used to avoid duplicated global variable in the case when protobuf
 // is built with hidden visibility.
-const ::std::string& GetEmptyStringAlreadyInited() {
+CAFFE2_API const ::std::string& GetEmptyStringAlreadyInited() {
 return ::google::protobuf::internal::GetEmptyStringAlreadyInited();
 }
 
@@ -19,7 +20,7 @@ namespace ONNX_NAMESPACE {
 // ONNX wrapper functions for protobuf's GetEmptyStringAlreadyInited() function
 // used to avoid duplicated global variable in the case when protobuf
 // is built with hidden visibility.
-const ::std::string& GetEmptyStringAlreadyInited() {
+CAFFE2_API const ::std::string& GetEmptyStringAlreadyInited() {
 return ::google::protobuf::internal::GetEmptyStringAlreadyInited();
 }
 
@@ -30,7 +31,7 @@ namespace caffe2 {
 // Caffe2 wrapper functions for protobuf's GetEmptyStringAlreadyInited() function
 // used to avoid duplicated global variable in the case when protobuf
 // is built with hidden visibility.
-const ::std::string& GetEmptyStringAlreadyInited() {
+CAFFE2_API const ::std::string& GetEmptyStringAlreadyInited() {
 return ::google::protobuf::internal::GetEmptyStringAlreadyInited();
 }
 
@@ -5,6 +5,7 @@
 #include <ATen/ATen.h>
 #include <ATen/TensorGeometry.h>
 
+#include "torch/csrc/THP_export.h"
 #include "torch/csrc/autograd/function.h"
 #include "torch/csrc/autograd/variable.h"
 #include "torch/csrc/autograd/saved_variable.h"
@@ -3,6 +3,7 @@
 // Engine implements backpropagation from output variables and their gradients
 // to "root" variables (variables created by the user with requires_grad=True).
 
+#include "torch/csrc/WindowsTorchApiMacro.h"
 #include "torch/csrc/autograd/function.h"
 #include "torch/csrc/autograd/input_buffer.h"
 #include "torch/csrc/autograd/anomaly_mode.h"
@@ -24,7 +25,7 @@ struct GraphTask;
 namespace torch { namespace autograd {
 // A single instance of this struct should be created through the whole process lifetime.
 // The worker thread creation logic and Engine's destructor rely on this.
-struct Engine {
+struct TORCH_API Engine {
 /// Returns a reference to a static `Engine` instance.
 static Engine& get_default_engine();
 
@@ -67,6 +68,6 @@ protected:
 
 // allow python_engine to override the default engine when it loads
 typedef Engine& (*EngineStub)(void);
-void set_default_engine_stub(EngineStub stub);
+TORCH_API void set_default_engine_stub(EngineStub stub);
 
 }} // namespace torch::autograd
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "torch/csrc/assertions.h"
+#include "torch/csrc/WindowsTorchApiMacro.h"
 #include "torch/csrc/autograd/edge.h"
 #include "torch/csrc/autograd/grad_mode.h"
 #include "torch/csrc/autograd/anomaly_mode.h"
@@ -84,7 +85,7 @@ void deleteFunction(Function* function);
 /// are created in one thread and `C` is created in a new thread, there are *no
 /// guarantees* w.r.t. the ordering of `C` relative to `A` or `B`.
 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-struct Function : std::enable_shared_from_this<Function> {
+struct TORCH_API Function : std::enable_shared_from_this<Function> {
 public:
 /// Construct a new `Function` with `num_inputs` inputs and the given
 /// `next_edges`. sequence_nr is a (currently THE) hint to prioritization
@@ -1,5 +1,6 @@
 #pragma once
 
+#include "torch/csrc/WindowsTorchApiMacro.h"
 #include "torch/csrc/autograd/function.h"
 #include "torch/csrc/autograd/variable.h"
 #include "torch/csrc/autograd/symbolic.h"
@@ -10,7 +11,7 @@
 
 namespace torch { namespace autograd {
 
-struct Error : public Function {
+struct TORCH_API Error : public Function {
 Error(std::string msg, edge_list&& next_edges)
 : Function(std::move(next_edges))
 , msg(std::move(msg)) {}
@@ -24,7 +25,7 @@ struct Error : public Function {
 };
 
 // Identity in forward, Error in backward. Used to implement @once_differentiable
-struct DelayedError : public Function {
+struct TORCH_API DelayedError : public Function {
 DelayedError(std::string msg, int num_inputs)
 : msg(std::move(msg)) {
 for (int i = 0; i < num_inputs; i++)
@@ -36,7 +37,7 @@ struct DelayedError : public Function {
 std::string msg;
 };
 
-struct GraphRoot : public Function {
+struct TORCH_API GraphRoot : public Function {
 GraphRoot(edge_list functions, variable_list inputs)
 : Function(std::move(functions)),
 outputs(std::move(inputs)) {}
@@ -1,5 +1,6 @@
 #pragma once
 
+#include "torch/csrc/WindowsTorchApiMacro.h"
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/utils/variadic.h>
@@ -18,12 +19,12 @@ using function_constructor = std::function<std::shared_ptr<Function>(edge_list&&
 * Wraps the tensor outputs in variables and creates the grad_fn and sets the
 * grad_fn if necessary.
 */
-variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
+TORCH_API variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
 function_constructor ctr);
 
 /// Checks that inputs contains exactly `args` items and that the first `required_args`
 /// items are not nullptr. If not specified, `required_args` defaults to `args`.
-void check_input_variables(const char* name, const variable_list& inputs, int args, int required_args=-1);
+TORCH_API void check_input_variables(const char* name, const variable_list& inputs, int args, int required_args=-1);
 
 struct ComputeRequiresGrad : IterArgs<ComputeRequiresGrad> {
 bool out = false;
@@ -1,15 +1,17 @@
 #pragma once
 
+#include "torch/csrc/WindowsTorchApiMacro.h"
+
 namespace torch { namespace autograd {
 
-struct GradMode {
+struct TORCH_API GradMode {
 static bool is_enabled();
 static void set_enabled(bool enabled);
 };
 
 // A RAII, thread local (!) guard that enables or disables grad mode upon
 // construction, and sets it back to the original value upon destruction.
-struct AutoGradMode {
+struct TORCH_API AutoGradMode {
 AutoGradMode(bool enabled) : prev_mode(GradMode::is_enabled()) {
 GradMode::set_enabled(enabled);
 }
@@ -15,6 +15,7 @@
 #include <forward_list>
 #include <tuple>
 #include "ATen/ATen.h"
+#include "torch/csrc/WindowsTorchApiMacro.h"
 #include "torch/csrc/cuda/cuda_check.h"
 #ifdef USE_CUDA
 #include "ATen/cuda/CUDAContext.h"
@@ -163,12 +164,12 @@ enum class ProfilerState {
 NVTX, // only emit NVTX markers
 };
 
-RangeEventList& getEventList();
-void mark(std::string name, bool include_cuda = true);
-void pushRange(std::string name);
-void popRange();
+TORCH_API RangeEventList& getEventList();
+TORCH_API void mark(std::string name, bool include_cuda = true);
+TORCH_API void pushRange(std::string name);
+TORCH_API void popRange();
 
-struct RecordFunction {
+struct TORCH_API RecordFunction {
 explicit RecordFunction(Function* fn);
 
 explicit RecordFunction(std::string name);
@@ -184,8 +185,8 @@ struct RecordFunction {
 using thread_event_lists = std::vector<std::vector<Event>>;
 // NOTE: changing profiler modes is **NOT THREAD SAFE**. You should ensure that
 // there no autograd functions are being executed when these function are used.
-void enableProfiler(ProfilerState state);
-thread_event_lists disableProfiler();
+TORCH_API void enableProfiler(ProfilerState state);
+TORCH_API thread_event_lists disableProfiler();
 
 } // namespace profiler
 }} // namespace torch::autograd
@@ -18,7 +18,7 @@ TORCH_API extern const char* ERR_BACKWARD_TWICE;
 
 /// A snapshot of a variable at a certain version. A `SavedVariable` stores
 /// enough information to reconstruct a variable from a certain point in time.
-class SavedVariable {
+class TORCH_API SavedVariable {
 public:
 SavedVariable() = default;
 SavedVariable(const Variable& variable, bool is_output);
@@ -77,7 +77,7 @@ struct Function;
 /// free function instead. To create a view variable, use `make_variable_view`.
 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-struct Variable : public at::Tensor {
+struct TORCH_API Variable : public at::Tensor {
 /// Default constructor.
 Variable() = default;
 
@@ -1,5 +1,6 @@
 #pragma once
 
+#include "torch/csrc/WindowsTorchApiMacro.h"
 #include "torch/csrc/jit/ir.h"
 
 #include <ATen/ATen.h>
@@ -84,11 +85,11 @@ struct Gradient {
 };
 // XXX: When calling this function, graph should have complete type information.
 // Use the shape analysis pass to fill in the gaps if it doesn't.
-Gradient differentiate(std::shared_ptr<Graph>& graph, const std::vector<bool>& requires_grad);
+TORCH_API Gradient differentiate(std::shared_ptr<Graph>& graph, const std::vector<bool>& requires_grad);
 
 // can we take a derivative of this node symbolically?
-bool isDifferentiable(Node * n);
-bool isDifferentiable(Graph & g);
-bool isZero(Value * v);
+TORCH_API bool isDifferentiable(Node * n);
+TORCH_API bool isDifferentiable(Graph & g);
+TORCH_API bool isZero(Value * v);
 
 }}
@@ -15,7 +15,7 @@ namespace torch { namespace jit {
 // file contents being the raw tensor data.
 using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;
 
-std::tuple<std::string, RawDataExportMap> ExportGraph(
+TORCH_API std::tuple<std::string, RawDataExportMap> ExportGraph(
 const std::shared_ptr<Graph>& graph,
 const std::vector<at::Tensor>& initializers,
 int64_t onnx_opset_version,
@@ -24,7 +24,7 @@ std::tuple<std::string, RawDataExportMap> ExportGraph(
 = ::torch::onnx::OperatorExportTypes::ONNX);
 
 // For testing purposes
-std::string PrettyPrintExportedGraph(
+TORCH_API std::string PrettyPrintExportedGraph(
 const std::shared_ptr<Graph>& graph,
 const std::vector<at::Tensor> & initializers,
 int64_t onnx_opset_version,
@@ -33,7 +33,7 @@ struct GraphExecutorState {
 };
 
 struct GraphExecutorImpl;
-struct GraphExecutor {
+struct TORCH_API GraphExecutor {
 GraphExecutor() {}
 GraphExecutor(std::shared_ptr<Graph> graph, bool optimize = true);
 // note: if not specified, symbolically_differentiable is computed from the graph.
@@ -51,17 +51,17 @@ private:
 
 // These passes need to run before it is valid to pass to the interpreter
 // regardless of whether sizes have been specialized or not.
-void runRequiredPasses(const std::shared_ptr<Graph>& g);
+TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);
 
 // specialize 'graph' to the types, sizes, and other properties described in spec
 // this prepares the graph for execution, including running runRequiredPasses,
 // but the execution only remains valid for tensors whose properties match spec
 // otherwise running the graph will have undefined results.
-void specializeToSpec(const std::shared_ptr<Graph>& graph, const ArgumentSpec& spec);
+TORCH_API void specializeToSpec(const std::shared_ptr<Graph>& graph, const ArgumentSpec& spec);
 
 // apply standard optimizations. if graphMustSupportVariables=false then
 // then the passes are allowed to modify the graph in ways that make it no longer
 // work with tensors that have requires_grad=True
-void runOptimization(std::shared_ptr<Graph> & graph, bool graphMustSupportVariables);
+TORCH_API void runOptimization(std::shared_ptr<Graph> & graph, bool graphMustSupportVariables);
 
 }}
@@ -4,6 +4,6 @@
 
 namespace torch { namespace jit {
 
-std::shared_ptr<Graph> ImportIRGraph(const std::string& serialized_graph, std::vector<at::Tensor> & initializers);
+TORCH_API std::shared_ptr<Graph> ImportIRGraph(const std::string& serialized_graph, std::vector<at::Tensor> & initializers);
 
 }}
@@ -5,6 +5,7 @@
 #include <unordered_map>
 #include <algorithm>
 
+#include "torch/csrc/WindowsTorchApiMacro.h"
 #include "torch/csrc/jit/generated/aten_interned_strings.h"
 
 namespace torch { namespace jit {
@@ -132,7 +133,7 @@ static const std::string domain_prefix = "org.pytorch.";
 // A Symbol is like an interned string, but with a little extra
 // structure; it is namespaced via SymbolNamespace and the resulting
 // intern pointers support efficient namespace testing.
-struct Symbol {
+struct TORCH_API Symbol {
 explicit constexpr Symbol() : value(0) {};
 explicit constexpr Symbol(unique_t uniq)
 : value(uniq) {}
@@ -3,6 +3,8 @@
 #include <vector>
 #include "ATen/optional.h"
 
+#include "torch/csrc/WindowsTorchApiMacro.h"
+
 namespace at {
 struct Tensor;
 }
@@ -22,7 +24,7 @@ struct TensorType;
 struct IValue;
 using Stack = std::vector<IValue>;
 
-struct Code {
+struct TORCH_API Code {
 Code()
 : pImpl(nullptr) {}
 Code(std::shared_ptr<Graph>& graph);
@@ -16,6 +16,7 @@
 #include "torch/csrc/utils/python_stub.h"
 
 #include "torch/csrc/assertions.h"
+#include "torch/csrc/WindowsTorchApiMacro.h"
 
 #include <ATen/ATen.h>
 #include "ATen/ArrayRef.h"
@@ -51,9 +52,9 @@ struct Node;
 // Tensor or an opaque Handle object, as determined by type().
 struct Value;
 
-std::ostream& operator<<(std::ostream & out, const Graph & g);
-std::ostream& operator<<(std::ostream & out, const Type & t);
-std::ostream& operator<<(std::ostream & out, const Node & t);
+TORCH_API std::ostream& operator<<(std::ostream & out, const Graph & g);
+TORCH_API std::ostream& operator<<(std::ostream & out, const Type & t);
+TORCH_API std::ostream& operator<<(std::ostream & out, const Node & t);
 
 // A list of nodes, with inputs and outputs
 struct Block;
@@ -195,7 +196,7 @@ public:
 bool hasUniqueName() const {
 return unique_name_ != "";
 }
-Value* setUniqueName(const std::string & name);
+TORCH_API Value* setUniqueName(const std::string & name);
 std::string uniqueName() const {
 if (hasUniqueName())
 return unique_name_;
@@ -813,7 +814,7 @@ struct Block {
 // to the inputs, nodes, and outputs of this block
 // value_map is used whenever a node in src references a free variable
 // in src to look up its corresponding value
-void cloneFrom(Block * src, std::function<Value*(Value*)> value_map);
+TORCH_API void cloneFrom(Block * src, std::function<Value*(Value*)> value_map);
 private:
 // should only be called in the constructor
 Node* initOutput(Node* p) {
@@ -1069,9 +1070,9 @@ public:
 }
 
 // Checks well-formedness and invariants of graph
-void lint() const;
+TORCH_API void lint() const;
 // for use in debugger
-void dump() const;
+TORCH_API void dump() const;
 
 ~Graph() {
 for (const Node * n : all_nodes)
@@ -1089,7 +1090,7 @@ public:
 }
 
 friend std::ostream& operator<<(std::ostream & out, const Graph & g);
-std::shared_ptr<Graph> copy();
+TORCH_API std::shared_ptr<Graph> copy();
 
 private:
 
@@ -1338,8 +1339,8 @@ struct PythonOp : public Node {
 
 };
 // patched in when python bindings are loaded
-PythonOp* allocPythonOp(Graph* g);
-void setAllocPythonOp(PythonOp* (*v)(Graph* g));
+TORCH_API PythonOp* allocPythonOp(Graph* g);
+TORCH_API void setAllocPythonOp(PythonOp* (*v)(Graph* g));
 
 inline Node* Graph::createPythonOp(
 THPObjectPtr&& pyobj,
@@ -1365,6 +1366,6 @@ inline const_graph_node_list_iterator Node::reverseIterator() const {
 return iterator().reverse();
 }
 
-void LintGraph(std::shared_ptr<Graph>& graph);
+TORCH_API void LintGraph(std::shared_ptr<Graph>& graph);
 
 }} // namespace torch::jit
@@ -13,7 +13,7 @@ FunctionSchema parseSchema(const std::string& decl);
 
 using OperationCreator = std::function<Operation(Node*)>;
 
-struct Operator {
+struct TORCH_API Operator {
 Operator(FunctionSchema schema, OperationCreator op, OperationCreator op_const_attributes = nullptr)
 : schema(std::move(schema))
 , op(std::move(op))
@@ -65,7 +65,7 @@ void registerOperator(Operator&& op);
 // XXX: this function is meant to be used with string literals only!
 Operator& sig(const char *signature_literal);
 
-struct RegisterOperators {
+struct TORCH_API RegisterOperators {
 RegisterOperators(std::vector<Operator> operators) {
 for(Operator& o : operators) {
 registerOperator(std::move(o));
@@ -4,6 +4,6 @@
 
 namespace torch { namespace jit {
 
-void BatchMM(std::shared_ptr<Graph>& graph);
+TORCH_API void BatchMM(std::shared_ptr<Graph>& graph);
 
 }}
@@ -4,6 +4,6 @@
 
 namespace torch { namespace jit {
 
-std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph);
+TORCH_API std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph);
 
 }}
@@ -4,6 +4,6 @@
 
 namespace torch { namespace jit {
 
-void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph);
+TORCH_API void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph);
 
 }}
@@ -8,6 +8,6 @@ struct Graph;
 // insert GraphExecutor nodes that group together
 // subgraphs that are differentiable by the jit's autodiff passes
 // threshold - minimum number of nodes that will appear in a block
-void CreateAutodiffSubgraphs(Graph & graph, size_t threshold = 2);
+TORCH_API void CreateAutodiffSubgraphs(Graph & graph, size_t threshold = 2);
 
 }}
@@ -4,7 +4,7 @@
 
 namespace torch { namespace jit {
 
-void EliminateDeadCode(const std::shared_ptr<Graph>& graph);
-void EliminateDeadCode(Block *block, bool recurse=true);
+TORCH_API void EliminateDeadCode(const std::shared_ptr<Graph>& graph);
+TORCH_API void EliminateDeadCode(Block *block, bool recurse=true);
 
 }}
@@ -1,4 +1,5 @@
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
 #include "torch/csrc/jit/passes/decompose_addmm.h"
 #include "torch/csrc/jit/symbolic_variable.h"
 #include "torch/csrc/jit/tensor_conversions.h"
 
@@ -9,6 +9,6 @@ namespace torch { namespace jit {
 //
 // In the future, if we need more passes like this, we should convert this
 // into a generic canonicalization pass.
-void DecomposeAddmm(const std::shared_ptr<Graph>& graph);
+TORCH_API void DecomposeAddmm(const std::shared_ptr<Graph>& graph);
 
 }}
@@ -15,6 +15,6 @@ namespace torch { namespace jit {
 // - prim::TensorToNum, and prim::NumToTensor nodes are erased.
 //
 // The pass assumes that DCE will be called sometime after.
-void EraseNumberTypes(const std::shared_ptr<Graph>& graph);
+TORCH_API void EraseNumberTypes(const std::shared_ptr<Graph>& graph);
 
 }}
@@ -6,6 +6,6 @@ namespace torch { namespace jit {
 
 // NB: Be sure to run DCE before fusion, because dead instructions
 // can prevent fusion opportunities from being exploited.
-void FuseGraph(std::shared_ptr<Graph>& graph);
+TORCH_API void FuseGraph(std::shared_ptr<Graph>& graph);
 
 }}
@@ -4,6 +4,6 @@
 
 namespace torch { namespace jit {
 
-void CheckInplace(std::shared_ptr<Graph>& graph);
+TORCH_API void CheckInplace(std::shared_ptr<Graph>& graph);
 
 }}
@@ -4,6 +4,6 @@
 
 namespace torch { namespace jit {
 
-void UnrollLoops(std::shared_ptr<Graph>& graph);
+TORCH_API void UnrollLoops(std::shared_ptr<Graph>& graph);
 
 }} // namespace torch::jit
@@ -10,6 +10,6 @@ namespace torch { namespace jit {
 // outputs = <original_computation>
 // else:
 // outputs = undefineds
-void LowerGradOf(Graph& graph);
+TORCH_API void LowerGradOf(Graph& graph);
 
 }}
@@ -4,6 +4,6 @@
 
 namespace torch { namespace jit {
 
-void LowerTuples(std::shared_ptr<Graph>& graph);
+TORCH_API void LowerTuples(std::shared_ptr<Graph>& graph);
 
 }}
@@ -5,7 +5,7 @@
 
 namespace torch { namespace jit {
 
-std::shared_ptr<Graph> ToONNX(std::shared_ptr<Graph>& state, ::torch::onnx::OperatorExportTypes operator_export_type);
-void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExportTypes operator_export_type, std::unordered_map<Value*, Value*> env);
+TORCH_API std::shared_ptr<Graph> ToONNX(std::shared_ptr<Graph>& state, ::torch::onnx::OperatorExportTypes operator_export_type);
+TORCH_API void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExportTypes operator_export_type, std::unordered_map<Value*, Value*> env);
 
 }}
@@ -4,6 +4,6 @@
 
 namespace torch { namespace jit {
 
-void PeepholeOptimize(std::shared_ptr<Graph>& graph);
+TORCH_API void PeepholeOptimize(std::shared_ptr<Graph>& graph);
 
 }}
@@ -1,3 +1,4 @@
 #include "torch/csrc/jit/passes/remove_expands.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
 
 namespace torch { namespace jit {
@@ -4,6 +4,6 @@
 
 namespace torch { namespace jit {
 
-void RemoveExpands(const std::shared_ptr<Graph>& graph);
+TORCH_API void RemoveExpands(const std::shared_ptr<Graph>& graph);
 
 }}
@@ -1,8 +1,10 @@
 #pragma once
 
+#include "torch/csrc/WindowsTorchApiMacro.h"
+
 namespace torch { namespace jit {
 struct Graph;
 struct ArgumentSpec;
-void PropagateInputShapes(Graph & graph, const ArgumentSpec & spec);
+TORCH_API void PropagateInputShapes(Graph & graph, const ArgumentSpec & spec);
 
 }}
@@ -11,6 +11,6 @@ namespace torch { namespace jit {
 // operations generated by the symbolic autodiff code and cleans up
 // AutogradAdds when possible. Outputs of other nodes are conservatively
 // marked Unknown and not optimized.
-void specializeUndef(Graph & g, const std::vector<bool>& defined);
+TORCH_API void specializeUndef(Graph & g, const std::vector<bool>& defined);
 
 }}
@@ -11,9 +11,9 @@ private:
 std::unordered_map<Value*, std::vector<Value*>> batch_map;
 public:
 static std::unordered_map<std::string, std::shared_ptr<Graph>> batch_operator_table;
-void toBatch(Block* block, Block* res_block);
+TORCH_API void toBatch(Block* block, Block* res_block);
 };
 
-std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph>& graph);
-void initRegisterBatchOpsBindings(PyObject* module);
+TORCH_API std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph>& graph);
+TORCH_API void initRegisterBatchOpsBindings(PyObject* module);
 }}
@@ -83,7 +83,7 @@ struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
 
 // most things in the environment are just simple value types
 // and not special python syntax sugar types
-struct SimpleValue : public SugaredValue {
+struct TORCH_API SimpleValue : public SugaredValue {
 SimpleValue(Value * value)
 : value(value) {}
 virtual std::string kind() const override {
@@ -101,7 +101,7 @@ private:
 Value* value;
 };
 
-struct BuiltinFunction : public SugaredValue {
+struct TORCH_API BuiltinFunction : public SugaredValue {
 BuiltinFunction(const std::string& name, at::optional<NamedValue> value)
 : name(name), value(std::move(value)) {}
 std::string name;
@@ -121,7 +121,7 @@ struct BuiltinFunction : public SugaredValue {
 };
 
 using Resolver = std::function<std::shared_ptr<SugaredValue>(const std::string& name)>;
-void defineMethodsInModule(
+TORCH_API void defineMethodsInModule(
 Module & m,
 const std::vector<Def>& definitions,
 const std::vector<Resolver>& resolvers, /* determines how we handle free variables in each definition*/
@@ -129,20 +129,20 @@ void defineMethodsInModule(
 );
 
 // same as above but parse the definitions from source
-void defineMethodsInModule(Module & m, const std::string& source, const Resolver& resolver, std::shared_ptr<SugaredValue> self);
-std::shared_ptr<Graph> compileFunction(Def def, const Resolver& resolver);
+TORCH_API void defineMethodsInModule(Module & m, const std::string& source, const Resolver& resolver, std::shared_ptr<SugaredValue> self);
+TORCH_API std::shared_ptr<Graph> compileFunction(Def def, const Resolver& resolver);
 
 // pack outputs of a function following python rules. If there is a single value return
 // a SimpleValue, otherwise pack all the values into a Tuple.
-std::shared_ptr<SugaredValue> packOutputs(Graph& g, at::ArrayRef<Value*> values);
-std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs);
-void ensureSizeMatches(SourceRange loc, size_t expected, size_t actual, const std::string& what);
-void ensureTensors(const SourceRange& range, at::ArrayRef<Value*> values);
+TORCH_API std::shared_ptr<SugaredValue> packOutputs(Graph& g, at::ArrayRef<Value*> values);
+TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs);
+TORCH_API void ensureSizeMatches(SourceRange loc, size_t expected, size_t actual, const std::string& what);
+TORCH_API void ensureTensors(const SourceRange& range, at::ArrayRef<Value*> values);
 
 // try to match a list if inputs and keyword 'attributes' to this schema,
 // if it works return the flat list of positional inputs to the call
 // if it returns nullopt, then failure_messages contains a good error report
-at::optional<std::vector<Value*>> tryMatchSchema(
+TORCH_API at::optional<std::vector<Value*>> tryMatchSchema(
 const FunctionSchema& schema,
 const SourceRange& loc,
 Graph& graph,
@@ -953,7 +953,7 @@ void testProto() {
 proto.set_producer_name("foo");
 }
 
-std::string runJITCPPTests() {
+TORCH_API std::string runJITCPPTests() {
 std::stringstream out;
 testIValue();
 testControlFlow();
@@ -2,6 +2,7 @@
 
 #include "torch/csrc/jit/ir.h"
 #include "torch/csrc/assertions.h"
+#include "torch/csrc/WindowsTorchApiMacro.h"
 #include "torch/csrc/utils/functional.h"
 #include "torch/csrc/utils/variadic.h"
 #include "torch/csrc/autograd/function_hook.h"
@@ -19,7 +20,7 @@ namespace torch { namespace jit { namespace tracer {
 using torch::autograd::Variable;
 using variable_list = std::vector<Variable>;
 
-struct TracingState : public std::enable_shared_from_this<TracingState> {
+struct TORCH_API TracingState : public std::enable_shared_from_this<TracingState> {
 TracingState();
 ~TracingState();
 
@@ -59,10 +60,10 @@ struct ArgumentStash {
 return stash.intlists.empty();
 }
 
-static void stashIntListElem(const std::string& arg_name,
+TORCH_API static void stashIntListElem(const std::string& arg_name,
 size_t size,
 size_t idx,
 const Variable& var);
 
 static bool hasIntList(const std::string& arg_name) {
 return stash.intlists.count(arg_name) > 0;
@@ -80,8 +81,8 @@ private:
 };
 
 // Retrieve or set the current tracing state. Returns a nullptr if tracing is disabled.
-const std::shared_ptr<TracingState>& getTracingState();
-void setTracingState(std::shared_ptr<TracingState> state);
+TORCH_API const std::shared_ptr<TracingState>& getTracingState();
+TORCH_API void setTracingState(std::shared_ptr<TracingState> state);
 
 inline bool isTracing() {
 return static_cast<bool>(getTracingState());
@@ -191,11 +192,11 @@ struct PreTraceInfo {
 Node *n;
 };
 
-PreTraceInfo preRecordTrace(Symbol op, at::ArrayRef<Variable> inputs);
-void postRecordTrace(const PreTraceInfo& info, at::ArrayRef<Variable> outputs);
+TORCH_API PreTraceInfo preRecordTrace(Symbol op, at::ArrayRef<Variable> inputs);
+TORCH_API void postRecordTrace(const PreTraceInfo& info, at::ArrayRef<Variable> outputs);
 
-void recordSourceLocation(Node* n);
-void setRecordSourceLocation(void (*v)(Node*));
+TORCH_API void recordSourceLocation(Node* n);
+TORCH_API void setRecordSourceLocation(void (*v)(Node*));
 
 // We must record the nodes of inputs before we actually carry out
 // the operation, because an inplace operation may destroy the information
@@ -221,6 +222,6 @@ PreTraceInfo makePreTraceInfo(at::ArrayRef<Variable> inputs, F ctor) {
 return info;
 }
 
-autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim);
+TORCH_API autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim);
 
 }}} // namespace torch::jit::tracer
@@ -2,6 +2,7 @@
 
 #include "torch/csrc/jit/interned_strings.h"
 #include "torch/csrc/assertions.h"
+#include "torch/csrc/WindowsTorchApiMacro.h"
 
 #include <ATen/ATen.h>
 
@@ -29,7 +30,7 @@ struct Type;
 using TypePtr = std::shared_ptr<Type>;
 
 
-struct Type : std::enable_shared_from_this<Type> {
+struct TORCH_API Type : std::enable_shared_from_this<Type> {
 
 private:
 TypeKind kind_;
@@ -89,7 +90,7 @@ inline bool operator!=(const Type & lhs, const Type & rhs) {
 }
 
 // This node represents a single Tensor value, with an unknown shape.
-struct DynamicType : public Type {
+struct TORCH_API DynamicType : public Type {
 DynamicType()
 : Type(TypeKind::DynamicType) {}
 bool operator==(const Type& rhs) const override {
@@ -106,7 +107,7 @@ struct DynamicType : public Type {
 struct TensorType;
 using TensorTypePtr = std::shared_ptr<TensorType>;
 // This node represents a single Tensor value with a specific size
-struct TensorType : public Type {
+struct TORCH_API TensorType : public Type {
 friend struct Type;
 TensorType(const at::Tensor& tensor)
 : Type(TypeKind::TensorType)
@@ -185,7 +186,7 @@ private:
 std::vector<int64_t> strides_;
 };
 
-struct ListType : public Type {
+struct TORCH_API ListType : public Type {
 friend struct Type;
 static const TypeKind Kind = TypeKind::ListType;
 ListType(TypePtr elem)
@@ -211,7 +212,7 @@ private:
 TypePtr elem;
 };
 
-struct TupleType : public Type {
+struct TORCH_API TupleType : public Type {
 friend struct Type;
 TupleType(std::vector<TypePtr> elements_)
 : Type(TypeKind::TupleType)
@@ -268,7 +269,7 @@ private:
 };
 
 // This node represents a Python number value
-struct NumberType : public Type {
+struct TORCH_API NumberType : public Type {
 NumberType()
 : Type(TypeKind::NumberType) {}
 bool operator==(const Type& rhs) const override {
@@ -283,7 +284,7 @@ struct NumberType : public Type {
 };
 
 // This node represents a Python float number value
-struct FloatType : public Type {
+struct TORCH_API FloatType : public Type {
 FloatType()
 : Type(TypeKind::FloatType) {}
 bool operator==(const Type& rhs) const override {
@@ -301,7 +302,7 @@ struct FloatType : public Type {
 };
 
 // This node represents a Python int number value
-struct IntType : public Type {
+struct TORCH_API IntType : public Type {
 IntType()
 : Type(TypeKind::IntType) {}
 bool operator==(const Type& rhs) const override {
@@ -319,6 +320,6 @@ struct IntType : public Type {
 };
 
 
-std::ostream& operator<<(std::ostream & out, const Type & t);
+TORCH_API std::ostream& operator<<(std::ostream & out, const Type & t);
 
 }} // namespace torch::jit
@@ -2,6 +2,7 @@
 
 #include "torch/csrc/onnx/onnx.npb.h"
 #include "torch/csrc/assertions.h"
+#include "torch/csrc/WindowsTorchApiMacro.h"
 
 #include <pb_encode.h>
 #include <ATen/ATen.h>
@@ -417,7 +418,7 @@ public:
 opset_import.emplace_back(ptr);
 return ptr;
 }
-void dump(std::ostream& stream, size_t indent = 0);
+TORCH_API void dump(std::ostream& stream, size_t indent = 0);
 std::string prettyPrint() {
 std::stringstream ss;
 dump(ss, 0);