C++ changes toward libtorch and libcaffe2 unification (#19554)

Summary:
* add TORCH_API and AT_CUDA_API in places
* refactor the Python code-generation logic to separate
  caffe2/torch outputs
* fix HIP and ASAN
* remove profiler_cuda from HIP
* fix gcc warnings for enums
* fix PythonOp::Kind
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19554

Differential Revision: D15082727

Pulled By: kostmo

fbshipit-source-id: 83a8a99717f025ab44b29608848928d76b3147a4
This commit is contained in:
Karl Ostmo
2019-04-26 01:26:49 -07:00
committed by Facebook Github Bot
parent 9d180e602f
commit 8f0603b128
59 changed files with 309 additions and 231 deletions

View File

@ -6,7 +6,7 @@
#ifdef _WIN32 #ifdef _WIN32
#if !defined(AT_CORE_STATIC_WINDOWS) #if !defined(AT_CORE_STATIC_WINDOWS)
# if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS) || defined(CAFFE2_CUDA_BUILD_MAIN_LIB) # if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS) || defined(caffe2_hip_EXPORTS) || defined(CAFFE2_CUDA_BUILD_MAIN_LIB)
# define AT_CUDA_API __declspec(dllexport) # define AT_CUDA_API __declspec(dllexport)
# else # else
# define AT_CUDA_API __declspec(dllimport) # define AT_CUDA_API __declspec(dllimport)
@ -15,7 +15,7 @@
# define AT_CUDA_API # define AT_CUDA_API
#endif #endif
#elif defined(__GNUC__) #elif defined(__GNUC__)
#if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS) #if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS) || defined(caffe2_hip_EXPORTS)
#define AT_CUDA_API __attribute__((__visibility__("default"))) #define AT_CUDA_API __attribute__((__visibility__("default")))
#else #else
#define AT_CUDA_API #define AT_CUDA_API

View File

@ -22,7 +22,7 @@
#endif #endif
// ROCM hcc doesn't work well with using std:: in kernel functions // ROCM hcc doesn't work well with using std:: in kernel functions
#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__) #if defined(__CUDA_ARCH__) || defined(__HIPCC__)
#include <c10/cuda/CUDAMathCompat.h> #include <c10/cuda/CUDAMathCompat.h>
#define compat_pow c10::cuda::compat::pow #define compat_pow c10::cuda::compat::pow
#else #else

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/bin/bash -xe
############################################################################## ##############################################################################
# Example command to build the iOS target. # Example command to build the iOS target.
############################################################################## ##############################################################################
@ -7,8 +7,6 @@
# using ios-cmake. This is very similar to the android-cmake - see # using ios-cmake. This is very similar to the android-cmake - see
# build_android.sh for more details. # build_android.sh for more details.
set -e
CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)"
# Build protobuf from third_party so we have a host protoc binary. # Build protobuf from third_party so we have a host protoc binary.

View File

@ -2,6 +2,8 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#endif #endif
#include <c10/macros/Export.h>
// To add a new test file: // To add a new test file:
// 1. Add a test_foo.h file in this directory // 1. Add a test_foo.h file in this directory
// 2. include test_base.h // 2. include test_base.h

View File

@ -194,14 +194,41 @@ def gen_autograd(aten_path, out, autograd_dir):
gen_variable_type(out, aten_decls, template_path) gen_variable_type(out, aten_decls, template_path)
# Generate Functions.h/cpp # Generate Functions.h/cpp
from .gen_autograd_functions import gen_autograd_functions from .gen_autograd_functions import gen_autograd_functions_lib
gen_autograd_functions( gen_autograd_functions_lib(
out, autograd_functions, template_path) out, autograd_functions, template_path)
# Load deprecated signatures # Load deprecated signatures
deprecated = load_deprecated_signatures( deprecated = load_deprecated_signatures(
aten_decls, os.path.join(autograd_dir, 'deprecated.yaml')) aten_decls, os.path.join(autograd_dir, 'deprecated.yaml'))
# Generate variable_factories.h
from .gen_variable_factories import gen_variable_factories
gen_variable_factories(out, aten_decls, template_path)
def gen_autograd_python(aten_path, out, autograd_dir):
# TODO Deduplicate these four variable assignments
aten_decls = load_aten_declarations(aten_path)
# Parse and load derivatives.yaml
from .load_derivatives import load_derivatives
autograd_functions = load_derivatives(
os.path.join(autograd_dir, 'derivatives.yaml'), aten_decls)
template_path = os.path.join(autograd_dir, 'templates')
# Load deprecated signatures
deprecated = load_deprecated_signatures(
aten_decls, os.path.join(autograd_dir, 'deprecated.yaml'))
# Generate Functions.h/cpp
from .gen_autograd_functions import gen_autograd_functions_python
gen_autograd_functions_python(
out, autograd_functions, template_path)
# Generate Python bindings # Generate Python bindings
from . import gen_python_functions from . import gen_python_functions
gen_python_functions.gen_py_variable_methods( gen_python_functions.gen_py_variable_methods(
@ -211,10 +238,6 @@ def gen_autograd(aten_path, out, autograd_dir):
gen_python_functions.gen_py_nn_functions( gen_python_functions.gen_py_nn_functions(
out, aten_decls, template_path) out, aten_decls, template_path)
# Generate variable_factories.h
from .gen_variable_factories import gen_variable_factories
gen_variable_factories(out, aten_decls, template_path)
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(

View File

@ -4,13 +4,14 @@
# Functions.h/cpp: subclasses of autograd::Function # Functions.h/cpp: subclasses of autograd::Function
# python_functions.h/cpp: Python bindings for the above classes # python_functions.h/cpp: Python bindings for the above classes
# #
import os
import re import re
from .utils import nested_dict, CodeTemplate, write from .utils import nested_dict, CodeTemplate, write
from .gen_autograd import VIEW_FUNCTIONS from .gen_autograd import VIEW_FUNCTIONS
from .utils import IDENT_REGEX from .utils import IDENT_REGEX
FUNCTION_DECLARATION = CodeTemplate("""\ FUNCTION_DECLARATION = CodeTemplate("""\
struct ${op} : public ${superclass} { struct TORCH_API ${op} : public ${superclass} {
using ${superclass}::${superclass}; using ${superclass}::${superclass};
variable_list apply(variable_list&& grads) override; variable_list apply(variable_list&& grads) override;
std::string name() const override { return "${op}"; } std::string name() const override { return "${op}"; }
@ -81,18 +82,21 @@ if (should_compute_output({ ${idx_ranges} })) {
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
def gen_autograd_functions(out, autograd_functions, template_path): def gen_autograd_functions_lib(out, autograd_functions, template_path):
gen_autograd_functions(out, autograd_functions, template_path, "Functions")
def gen_autograd_functions_python(out, autograd_functions, template_path):
gen_autograd_functions(out, autograd_functions, template_path, "python_functions")
def gen_autograd_functions(out, autograd_functions, template_path, file_basename):
"""Functions.h and Functions.cpp body """Functions.h and Functions.cpp body
These contain the auto-generated subclasses of torch::autograd::Function These contain the auto-generated subclasses of torch::autograd::Function
for every differentiable torch function. for every differentiable torch function.
""" """
FUNCTIONS_H = CodeTemplate.from_file(template_path + '/Functions.h')
FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/Functions.cpp')
PY_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_functions.h')
PY_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_functions.cpp')
function_definitions = [] function_definitions = []
function_declarations = [] function_declarations = []
py_function_initializers = [] py_function_initializers = []
@ -110,10 +114,10 @@ def gen_autograd_functions(out, autograd_functions, template_path):
'py_function_initializers': py_function_initializers, 'py_function_initializers': py_function_initializers,
} }
write(out, 'Functions.h', FUNCTIONS_H, top_env) for suffix in [".h", ".cpp"]:
write(out, 'Functions.cpp', FUNCTIONS_CPP, top_env) f = file_basename + suffix
write(out, 'python_functions.h', PY_FUNCTIONS_H, top_env) templated_output = CodeTemplate.from_file(os.path.join(template_path, f))
write(out, 'python_functions.cpp', PY_FUNCTIONS_CPP, top_env) write(out, f, templated_output, top_env)
def process_function(func): def process_function(func):

View File

@ -10,6 +10,7 @@
#include "torch/csrc/autograd/function.h" #include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h" #include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/saved_variable.h" #include "torch/csrc/autograd/saved_variable.h"
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch { namespace autograd { namespace generated { namespace torch { namespace autograd { namespace generated {

View File

@ -38,6 +38,7 @@ libtorch_sources = [
"torch/csrc/autograd/anomaly_mode.cpp", "torch/csrc/autograd/anomaly_mode.cpp",
"torch/csrc/autograd/engine.cpp", "torch/csrc/autograd/engine.cpp",
"torch/csrc/autograd/function.cpp", "torch/csrc/autograd/function.cpp",
"torch/csrc/autograd/function_hook.cpp",
"torch/csrc/autograd/functions/accumulate_grad.cpp", "torch/csrc/autograd/functions/accumulate_grad.cpp",
"torch/csrc/autograd/functions/basic_ops.cpp", "torch/csrc/autograd/functions/basic_ops.cpp",
"torch/csrc/autograd/functions/tensor.cpp", "torch/csrc/autograd/functions/tensor.cpp",
@ -107,6 +108,7 @@ libtorch_sources = [
"torch/csrc/jit/script/schema_matching.cpp", "torch/csrc/jit/script/schema_matching.cpp",
"torch/csrc/jit/script/class_type.cpp", "torch/csrc/jit/script/class_type.cpp",
"torch/csrc/jit/script/parser.cpp", "torch/csrc/jit/script/parser.cpp",
"torch/csrc/jit/script/jit_exception.cpp",
"torch/csrc/jit/testing/file_check.cpp", "torch/csrc/jit/testing/file_check.cpp",
"torch/csrc/jit/import_source.cpp", "torch/csrc/jit/import_source.cpp",
"torch/csrc/jit/hooks_for_testing.cpp", "torch/csrc/jit/hooks_for_testing.cpp",

View File

@ -19,59 +19,33 @@ def all_generator_source():
return sorted(r) return sorted(r)
inputs = [
'torch/lib/THNN.h',
'torch/lib/THCUNN.h',
'torch/share/ATen/Declarations.yaml',
'tools/autograd/derivatives.yaml',
'tools/autograd/deprecated.yaml',
]
outputs = [
'torch/csrc/autograd/generated/Functions.cpp',
'torch/csrc/autograd/generated/Functions.h',
'torch/csrc/autograd/generated/python_functions.cpp',
'torch/csrc/autograd/generated/python_functions.h',
'torch/csrc/autograd/generated/python_nn_functions.cpp',
'torch/csrc/autograd/generated/python_nn_functions.h',
'torch/csrc/autograd/generated/python_nn_functions_dispatch.h',
'torch/csrc/autograd/generated/python_variable_methods.cpp',
'torch/csrc/autograd/generated/python_variable_methods_dispatch.h',
'torch/csrc/autograd/generated/variable_factories.h',
'torch/csrc/autograd/generated/VariableType_0.cpp',
'torch/csrc/autograd/generated/VariableType_1.cpp',
'torch/csrc/autograd/generated/VariableType_2.cpp',
'torch/csrc/autograd/generated/VariableType_3.cpp',
'torch/csrc/autograd/generated/VariableType_4.cpp',
'torch/csrc/autograd/generated/VariableType.h',
'torch/csrc/jit/generated/register_aten_ops_0.cpp',
'torch/csrc/jit/generated/register_aten_ops_1.cpp',
'torch/csrc/jit/generated/register_aten_ops_2.cpp',
]
def generate_code(ninja_global=None, def generate_code(ninja_global=None,
declarations_path=None, declarations_path=None,
nn_path=None, nn_path=None,
install_dir=None): install_dir=None,
subset=None):
# cwrap depends on pyyaml, so we can't import it earlier # cwrap depends on pyyaml, so we can't import it earlier
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, root) sys.path.insert(0, root)
from tools.autograd.gen_autograd import gen_autograd from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python
from tools.jit.gen_jit_dispatch import gen_jit_dispatch from tools.jit.gen_jit_dispatch import gen_jit_dispatch
from tools.nnwrap import generate_wrappers as generate_nn_wrappers
# Build THNN/THCUNN.cwrap and then THNN/THCUNN.cpp. These are primarily
# used by the legacy NN bindings.
generate_nn_wrappers(nn_path, install_dir, 'tools/cwrap/plugins/templates')
# Build ATen based Variable classes # Build ATen based Variable classes
autograd_gen_dir = install_dir or 'torch/csrc/autograd/generated' autograd_gen_dir = install_dir or 'torch/csrc/autograd/generated'
jit_gen_dir = install_dir or 'torch/csrc/jit/generated' jit_gen_dir = install_dir or 'torch/csrc/jit/generated'
for d in (autograd_gen_dir, jit_gen_dir): for d in (autograd_gen_dir, jit_gen_dir):
if not os.path.exists(d): if not os.path.exists(d):
os.makedirs(d) os.makedirs(d)
if subset == "pybindings" or not subset:
# Build THNN/THCUNN.cwrap and then THNN/THCUNN.cpp. These are primarily
# used by the legacy NN bindings.
from tools.nnwrap import generate_wrappers as generate_nn_wrappers
generate_nn_wrappers(nn_path, install_dir, 'tools/cwrap/plugins/templates')
gen_autograd_python(declarations_path or DECLARATIONS_PATH, autograd_gen_dir, 'tools/autograd')
if subset == "libtorch" or not subset:
gen_autograd(declarations_path or DECLARATIONS_PATH, autograd_gen_dir, 'tools/autograd') gen_autograd(declarations_path or DECLARATIONS_PATH, autograd_gen_dir, 'tools/autograd')
gen_jit_dispatch(declarations_path or DECLARATIONS_PATH, jit_gen_dir, 'tools/jit/templates') gen_jit_dispatch(declarations_path or DECLARATIONS_PATH, jit_gen_dir, 'tools/jit/templates')
@ -82,11 +56,18 @@ def main():
parser.add_argument('--nn-path') parser.add_argument('--nn-path')
parser.add_argument('--ninja-global') parser.add_argument('--ninja-global')
parser.add_argument('--install_dir') parser.add_argument('--install_dir')
parser.add_argument(
'--subset',
help='Subset of source files to generate. Can be "libtorch" or "pybindings". Generates both when omitted.'
)
options = parser.parse_args() options = parser.parse_args()
generate_code(options.ninja_global, generate_code(
options.ninja_global,
options.declarations_path, options.declarations_path,
options.nn_path, options.nn_path,
options.install_dir) options.install_dir,
options.subset,
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -104,6 +104,7 @@ set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/autograd/anomaly_mode.cpp ${TORCH_SRC_DIR}/csrc/autograd/anomaly_mode.cpp
${TORCH_SRC_DIR}/csrc/autograd/engine.cpp ${TORCH_SRC_DIR}/csrc/autograd/engine.cpp
${TORCH_SRC_DIR}/csrc/autograd/function.cpp ${TORCH_SRC_DIR}/csrc/autograd/function.cpp
${TORCH_SRC_DIR}/csrc/autograd/function_hook.cpp
${TORCH_SRC_DIR}/csrc/autograd/functions/accumulate_grad.cpp ${TORCH_SRC_DIR}/csrc/autograd/functions/accumulate_grad.cpp
${TORCH_SRC_DIR}/csrc/autograd/functions/basic_ops.cpp ${TORCH_SRC_DIR}/csrc/autograd/functions/basic_ops.cpp
${TORCH_SRC_DIR}/csrc/autograd/functions/tensor.cpp ${TORCH_SRC_DIR}/csrc/autograd/functions/tensor.cpp
@ -190,6 +191,7 @@ set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp ${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp ${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
${TORCH_SRC_DIR}/csrc/jit/script/jit_exception.cpp
${TORCH_SRC_DIR}/csrc/jit/tracer.cpp ${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp ${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp ${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
@ -230,6 +232,14 @@ if (USE_CUDA)
) )
endif() endif()
if (USE_ROCM)
list(APPEND TORCH_SRCS
${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
)
endif()
if (NOT NO_API) if (NOT NO_API)
list(APPEND TORCH_SRCS list(APPEND TORCH_SRCS
${TORCH_SRC_DIR}/csrc/api/src/cuda.cpp ${TORCH_SRC_DIR}/csrc/api/src/cuda.cpp
@ -543,7 +553,6 @@ if (BUILD_PYTHON)
${TORCH_SRC_DIR}/csrc/utils/structseq.cpp ${TORCH_SRC_DIR}/csrc/utils/structseq.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_apply.cpp ${TORCH_SRC_DIR}/csrc/utils/tensor_apply.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_dtypes.cpp ${TORCH_SRC_DIR}/csrc/utils/tensor_dtypes.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_layouts.cpp ${TORCH_SRC_DIR}/csrc/utils/tensor_layouts.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_list.cpp ${TORCH_SRC_DIR}/csrc/utils/tensor_list.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_new.cpp ${TORCH_SRC_DIR}/csrc/utils/tensor_new.cpp
@ -615,7 +624,6 @@ if (BUILD_PYTHON)
${TORCH_SRC_DIR}/csrc/cuda/Stream.cpp ${TORCH_SRC_DIR}/csrc/cuda/Stream.cpp
${TORCH_SRC_DIR}/csrc/cuda/Event.cpp ${TORCH_SRC_DIR}/csrc/cuda/Event.cpp
${TORCH_SRC_DIR}/csrc/cuda/utils.cpp ${TORCH_SRC_DIR}/csrc/cuda/utils.cpp
${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
${TORCH_SRC_DIR}/csrc/cuda/python_comm.cpp ${TORCH_SRC_DIR}/csrc/cuda/python_comm.cpp
${TORCH_SRC_DIR}/csrc/cuda/serialization.cpp ${TORCH_SRC_DIR}/csrc/cuda/serialization.cpp
${TORCH_SRC_DIR}/csrc/nn/THCUNN.cpp ${TORCH_SRC_DIR}/csrc/nn/THCUNN.cpp
@ -662,7 +670,6 @@ if (BUILD_PYTHON)
${TORCH_SRC_DIR}/csrc/cuda/Stream.cpp ${TORCH_SRC_DIR}/csrc/cuda/Stream.cpp
${TORCH_SRC_DIR}/csrc/cuda/Event.cpp ${TORCH_SRC_DIR}/csrc/cuda/Event.cpp
${TORCH_SRC_DIR}/csrc/cuda/utils.cpp ${TORCH_SRC_DIR}/csrc/cuda/utils.cpp
${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
${TORCH_SRC_DIR}/csrc/cuda/python_comm.cpp ${TORCH_SRC_DIR}/csrc/cuda/python_comm.cpp
${TORCH_SRC_DIR}/csrc/cuda/serialization.cpp ${TORCH_SRC_DIR}/csrc/cuda/serialization.cpp
${TORCH_SRC_DIR}/csrc/nn/THCUNN.cpp ${TORCH_SRC_DIR}/csrc/nn/THCUNN.cpp

View File

@ -29,17 +29,17 @@ class Sampler {
/// Resets the `Sampler`'s internal state. /// Resets the `Sampler`'s internal state.
/// Typically called before a new epoch. /// Typically called before a new epoch.
/// Optionally, accepts a new size when resetting the sampler. /// Optionally, accepts a new size when resetting the sampler.
TORCH_API virtual void reset(optional<size_t> new_size) = 0; virtual void reset(optional<size_t> new_size) = 0;
/// Returns the next index if possible, or an empty optional if the /// Returns the next index if possible, or an empty optional if the
/// sampler is exhausted for this epoch. /// sampler is exhausted for this epoch.
TORCH_API virtual optional<BatchRequest> next(size_t batch_size) = 0; virtual optional<BatchRequest> next(size_t batch_size) = 0;
/// Serializes the `Sampler` to the `archive`. /// Serializes the `Sampler` to the `archive`.
TORCH_API virtual void save(serialize::OutputArchive& archive) const = 0; virtual void save(serialize::OutputArchive& archive) const = 0;
/// Deserializes the `Sampler` from the `archive`. /// Deserializes the `Sampler` from the `archive`.
TORCH_API virtual void load(serialize::InputArchive& archive) = 0; virtual void load(serialize::InputArchive& archive) = 0;
}; };
} // namespace samplers } // namespace samplers

View File

@ -25,7 +25,7 @@ namespace samplers {
template <typename BatchRequest = std::vector<size_t>> template <typename BatchRequest = std::vector<size_t>>
class DistributedSampler : public Sampler<BatchRequest> { class DistributedSampler : public Sampler<BatchRequest> {
public: public:
TORCH_API DistributedSampler( DistributedSampler(
size_t size, size_t size,
size_t num_replicas = 1, size_t num_replicas = 1,
size_t rank = 0, size_t rank = 0,
@ -64,28 +64,28 @@ class DistributedSampler : public Sampler<BatchRequest> {
/// Select samples randomly. The sampling order is shuffled at each `reset()` /// Select samples randomly. The sampling order is shuffled at each `reset()`
/// call. /// call.
class DistributedRandomSampler : public DistributedSampler<> { class TORCH_API DistributedRandomSampler : public DistributedSampler<> {
public: public:
TORCH_API DistributedRandomSampler( DistributedRandomSampler(
size_t size, size_t size,
size_t num_replicas = 1, size_t num_replicas = 1,
size_t rank = 0, size_t rank = 0,
bool allow_duplicates = true); bool allow_duplicates = true);
/// Resets the `DistributedRandomSampler` to a new set of indices. /// Resets the `DistributedRandomSampler` to a new set of indices.
TORCH_API void reset(optional<size_t> new_size = nullopt) override; void reset(optional<size_t> new_size = nullopt) override;
/// Returns the next batch of indices. /// Returns the next batch of indices.
TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override; optional<std::vector<size_t>> next(size_t batch_size) override;
/// Serializes the `DistributedRandomSampler` to the `archive`. /// Serializes the `DistributedRandomSampler` to the `archive`.
TORCH_API void save(serialize::OutputArchive& archive) const override; void save(serialize::OutputArchive& archive) const override;
/// Deserializes the `DistributedRandomSampler` from the `archive`. /// Deserializes the `DistributedRandomSampler` from the `archive`.
TORCH_API void load(serialize::InputArchive& archive) override; void load(serialize::InputArchive& archive) override;
/// Returns the current index of the `DistributedRandomSampler`. /// Returns the current index of the `DistributedRandomSampler`.
TORCH_API size_t index() const noexcept; size_t index() const noexcept;
private: private:
void populate_indices(); void populate_indices();
@ -97,28 +97,28 @@ class DistributedRandomSampler : public DistributedSampler<> {
}; };
/// Select samples sequentially. /// Select samples sequentially.
class DistributedSequentialSampler : public DistributedSampler<> { class TORCH_API DistributedSequentialSampler : public DistributedSampler<> {
public: public:
TORCH_API DistributedSequentialSampler( DistributedSequentialSampler(
size_t size, size_t size,
size_t num_replicas = 1, size_t num_replicas = 1,
size_t rank = 0, size_t rank = 0,
bool allow_duplicates = true); bool allow_duplicates = true);
/// Resets the `DistributedSequentialSampler` to a new set of indices. /// Resets the `DistributedSequentialSampler` to a new set of indices.
TORCH_API void reset(optional<size_t> new_size = nullopt) override; void reset(optional<size_t> new_size = nullopt) override;
/// Returns the next batch of indices. /// Returns the next batch of indices.
TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override; optional<std::vector<size_t>> next(size_t batch_size) override;
/// Serializes the `DistributedSequentialSampler` to the `archive`. /// Serializes the `DistributedSequentialSampler` to the `archive`.
TORCH_API void save(serialize::OutputArchive& archive) const override; void save(serialize::OutputArchive& archive) const override;
/// Deserializes the `DistributedSequentialSampler` from the `archive`. /// Deserializes the `DistributedSequentialSampler` from the `archive`.
TORCH_API void load(serialize::InputArchive& archive) override; void load(serialize::InputArchive& archive) override;
/// Returns the current index of the `DistributedSequentialSampler`. /// Returns the current index of the `DistributedSequentialSampler`.
TORCH_API size_t index() const noexcept; size_t index() const noexcept;
private: private:
void populate_indices(); void populate_indices();

View File

@ -19,31 +19,33 @@ namespace data {
namespace samplers { namespace samplers {
/// A `Sampler` that returns random indices. /// A `Sampler` that returns random indices.
class RandomSampler : public Sampler<> { class TORCH_API RandomSampler : public Sampler<> {
public: public:
/// Constructs a `RandomSampler` with a size and dtype for the stored indices. /// Constructs a `RandomSampler` with a size and dtype for the stored indices.
/// ///
/// The constructor will eagerly allocate all required indices, which is the /// The constructor will eagerly allocate all required indices, which is the
/// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored /// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored
/// indices. You can change it to influence memory usage. /// indices. You can change it to influence memory usage.
TORCH_API explicit RandomSampler( explicit RandomSampler(
int64_t size, int64_t size,
Dtype index_dtype = torch::kInt64); Dtype index_dtype = torch::kInt64);
~RandomSampler() override;
/// Resets the `RandomSampler` to a new set of indices. /// Resets the `RandomSampler` to a new set of indices.
TORCH_API void reset(optional<size_t> new_size = nullopt) override; void reset(optional<size_t> new_size = nullopt) override;
/// Returns the next batch of indices. /// Returns the next batch of indices.
TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override; optional<std::vector<size_t>> next(size_t batch_size) override;
/// Serializes the `RandomSampler` to the `archive`. /// Serializes the `RandomSampler` to the `archive`.
TORCH_API void save(serialize::OutputArchive& archive) const override; void save(serialize::OutputArchive& archive) const override;
/// Deserializes the `RandomSampler` from the `archive`. /// Deserializes the `RandomSampler` from the `archive`.
TORCH_API void load(serialize::InputArchive& archive) override; void load(serialize::InputArchive& archive) override;
/// Returns the current index of the `RandomSampler`. /// Returns the current index of the `RandomSampler`.
TORCH_API size_t index() const noexcept; size_t index() const noexcept;
private: private:
Tensor indices_; Tensor indices_;

View File

@ -19,26 +19,26 @@ namespace data {
namespace samplers { namespace samplers {
/// A `Sampler` that returns indices sequentially. /// A `Sampler` that returns indices sequentially.
class SequentialSampler : public Sampler<> { class TORCH_API SequentialSampler : public Sampler<> {
public: public:
/// Creates a `SequentialSampler` that will return indices in the range /// Creates a `SequentialSampler` that will return indices in the range
/// `0...size - 1`. /// `0...size - 1`.
TORCH_API explicit SequentialSampler(size_t size); explicit SequentialSampler(size_t size);
/// Resets the `SequentialSampler` to zero. /// Resets the `SequentialSampler` to zero.
TORCH_API void reset(optional<size_t> new_size = nullopt) override; void reset(optional<size_t> new_size = nullopt) override;
/// Returns the next batch of indices. /// Returns the next batch of indices.
TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override; optional<std::vector<size_t>> next(size_t batch_size) override;
/// Serializes the `SequentialSampler` to the `archive`. /// Serializes the `SequentialSampler` to the `archive`.
TORCH_API void save(serialize::OutputArchive& archive) const override; void save(serialize::OutputArchive& archive) const override;
/// Deserializes the `SequentialSampler` from the `archive`. /// Deserializes the `SequentialSampler` from the `archive`.
TORCH_API void load(serialize::InputArchive& archive) override; void load(serialize::InputArchive& archive) override;
/// Returns the current index of the `SequentialSampler`. /// Returns the current index of the `SequentialSampler`.
TORCH_API size_t index() const noexcept; size_t index() const noexcept;
private: private:
size_t size_; size_t size_;

View File

@ -32,26 +32,26 @@ struct TORCH_API BatchSize : public CustomBatchRequest {
/// The major feature of the `StreamSampler` is that it does not return /// The major feature of the `StreamSampler` is that it does not return
/// particular indices, but instead only the number of elements to fetch from /// particular indices, but instead only the number of elements to fetch from
/// the dataset. The dataset has to decide how to produce those elements. /// the dataset. The dataset has to decide how to produce those elements.
class StreamSampler : public Sampler<BatchSize> { class TORCH_API StreamSampler : public Sampler<BatchSize> {
public: public:
/// Constructs the `StreamSampler` with the number of individual examples that /// Constructs the `StreamSampler` with the number of individual examples that
/// should be fetched until the sampler is exhausted. /// should be fetched until the sampler is exhausted.
TORCH_API explicit StreamSampler(size_t epoch_size); explicit StreamSampler(size_t epoch_size);
/// Resets the internal state of the sampler. /// Resets the internal state of the sampler.
TORCH_API void reset(optional<size_t> new_size = nullopt) override; void reset(optional<size_t> new_size = nullopt) override;
/// Returns a `BatchSize` object with the number of elements to fetch in the /// Returns a `BatchSize` object with the number of elements to fetch in the
/// next batch. This number is the minimum of the supplied `batch_size` and /// next batch. This number is the minimum of the supplied `batch_size` and
/// the difference between the `epoch_size` and the current index. If the /// the difference between the `epoch_size` and the current index. If the
/// `epoch_size` has been reached, returns an empty optional. /// `epoch_size` has been reached, returns an empty optional.
TORCH_API optional<BatchSize> next(size_t batch_size) override; optional<BatchSize> next(size_t batch_size) override;
/// Serializes the `StreamSampler` to the `archive`. /// Serializes the `StreamSampler` to the `archive`.
TORCH_API void save(serialize::OutputArchive& archive) const override; void save(serialize::OutputArchive& archive) const override;
/// Deserializes the `StreamSampler` from the `archive`. /// Deserializes the `StreamSampler` from the `archive`.
TORCH_API void load(serialize::InputArchive& archive) override; void load(serialize::InputArchive& archive) override;
private: private:
size_t examples_retrieved_so_far_ = 0; size_t examples_retrieved_so_far_ = 0;

View File

@ -5,6 +5,8 @@
#include <torch/nn/pimpl.h> #include <torch/nn/pimpl.h>
#include <torch/types.h> #include <torch/types.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <cstddef> #include <cstddef>
#include <vector> #include <vector>

View File

@ -4,6 +4,8 @@
#include <torch/nn/pimpl.h> #include <torch/nn/pimpl.h>
#include <torch/types.h> #include <torch/types.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <cstddef> #include <cstddef>
#include <vector> #include <vector>

View File

@ -12,6 +12,8 @@ namespace samplers {
RandomSampler::RandomSampler(int64_t size, Dtype index_dtype) RandomSampler::RandomSampler(int64_t size, Dtype index_dtype)
: indices_(torch::randperm(size, index_dtype)) {} : indices_(torch::randperm(size, index_dtype)) {}
RandomSampler::~RandomSampler() = default;
void RandomSampler::reset(optional<size_t> new_size) { void RandomSampler::reset(optional<size_t> new_size) {
// This allocates a new chunk of memory every time (just FYI). It should be // This allocates a new chunk of memory every time (just FYI). It should be
// amortized over the entire epoch hopefully. // amortized over the entire epoch hopefully.

View File

@ -4,4 +4,6 @@ namespace torch { namespace autograd {
bool AnomalyMode::_enabled = false; bool AnomalyMode::_enabled = false;
AnomalyMetadata::~AnomalyMetadata() = default;
}} }}

View File

@ -4,7 +4,7 @@
namespace torch { namespace autograd { namespace torch { namespace autograd {
struct AnomalyMode { struct TORCH_API AnomalyMode {
static bool is_enabled() { static bool is_enabled() {
return _enabled; return _enabled;
} }
@ -13,12 +13,12 @@ struct AnomalyMode {
} }
private: private:
TORCH_API static bool _enabled; static bool _enabled;
}; };
struct AnomalyMetadata { struct TORCH_API AnomalyMetadata {
virtual ~AnomalyMetadata() = default; virtual ~AnomalyMetadata();
virtual void store_stack() = 0; virtual void store_stack() = 0;
virtual void print_stack() = 0; virtual void print_stack() = 0;
}; };

View File

@ -0,0 +1,8 @@
#include <torch/csrc/autograd/function_hook.h>
namespace torch { namespace autograd {
FunctionPreHook::~FunctionPreHook() = default;
FunctionPostHook::~FunctionPostHook() = default;
}} // namespace torch::autograd

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <vector> #include <vector>
#include <torch/csrc/WindowsTorchApiMacro.h>
// A hook that's called on gradients // A hook that's called on gradients
@ -9,13 +10,13 @@ namespace torch { namespace autograd {
struct Variable; struct Variable;
using variable_list = std::vector<Variable>; using variable_list = std::vector<Variable>;
struct FunctionPreHook { struct TORCH_API FunctionPreHook {
virtual ~FunctionPreHook() = default; virtual ~FunctionPreHook();
virtual variable_list operator()(const variable_list& grads) = 0; virtual variable_list operator()(const variable_list& grads) = 0;
}; };
struct FunctionPostHook { struct TORCH_API FunctionPostHook {
virtual ~FunctionPostHook() = default; virtual ~FunctionPostHook();
virtual variable_list operator()( virtual variable_list operator()(
const variable_list& outputs /* grad_inputs */, const variable_list& outputs /* grad_inputs */,
const variable_list& inputs /* grad_outputs */) = 0; const variable_list& inputs /* grad_outputs */) = 0;

View File

@ -2,10 +2,11 @@
#include <torch/csrc/autograd/function.h> #include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h> #include <torch/csrc/autograd/variable.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch { namespace autograd { namespace torch { namespace autograd {
struct AccumulateGrad : public Function { struct TORCH_API AccumulateGrad : public Function {
explicit AccumulateGrad(Variable variable_); explicit AccumulateGrad(Variable variable_);
variable_list apply(variable_list&& grads) override; variable_list apply(variable_list&& grads) override;

View File

@ -28,6 +28,8 @@ Scatter::Scatter(
streams_(streams), streams_(streams),
unsqueeze_scalars_(unsqueeze_scalars) {} unsqueeze_scalars_(unsqueeze_scalars) {}
Scatter::~Scatter() {}
variable_list Scatter::apply(variable_list&& inputs) { variable_list Scatter::apply(variable_list&& inputs) {
AT_ASSERT(inputs.size() == 1); AT_ASSERT(inputs.size() == 1);
auto& input = inputs.front(); auto& input = inputs.front();
@ -65,6 +67,8 @@ variable_list Scatter::apply(variable_list&& inputs) {
Gather::Gather(const at::Device& destination_device, int64_t dim) Gather::Gather(const at::Device& destination_device, int64_t dim)
: destination_device_(destination_device), dim_(dim) {} : destination_device_(destination_device), dim_(dim) {}
Gather::~Gather() {}
variable_list Gather::apply(variable_list&& inputs) { variable_list Gather::apply(variable_list&& inputs) {
bool all_are_zero_dim = true; bool all_are_zero_dim = true;
for (const auto& input : inputs) { for (const auto& input : inputs) {

View File

@ -6,6 +6,7 @@
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <cstddef> #include <cstddef>
#include <vector> #include <vector>
@ -13,6 +14,7 @@
namespace torch { namespace torch {
namespace autograd { namespace autograd {
//TODO: change it to TORCH_API when we merge the libs
struct TORCH_API Scatter : public Function { struct TORCH_API Scatter : public Function {
explicit Scatter( explicit Scatter(
std::vector<at::Device> devices, std::vector<at::Device> devices,
@ -21,6 +23,7 @@ struct TORCH_API Scatter : public Function {
const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>& streams = const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>& streams =
c10::nullopt, c10::nullopt,
bool unsqueeze_scalars = false); bool unsqueeze_scalars = false);
~Scatter() override;
variable_list apply(variable_list&& inputs) override; variable_list apply(variable_list&& inputs) override;
@ -33,6 +36,7 @@ struct TORCH_API Scatter : public Function {
struct TORCH_API Gather : public Function { struct TORCH_API Gather : public Function {
explicit Gather(const at::Device& destination_device, int64_t dim = 0); explicit Gather(const at::Device& destination_device, int64_t dim = 0);
~Gather() override;
variable_list apply(variable_list&& inputs) override; variable_list apply(variable_list&& inputs) override;

View File

@ -2,6 +2,7 @@
#include <torch/csrc/autograd/function.h> #include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h> #include <torch/csrc/autograd/variable.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/TensorGeometry.h> #include <ATen/TensorGeometry.h>
#include <ATen/core/DeprecatedTypeProperties.h> #include <ATen/core/DeprecatedTypeProperties.h>
@ -12,7 +13,7 @@
namespace torch { namespace autograd { namespace torch { namespace autograd {
struct CopyBackwards : public Function { struct TORCH_API CopyBackwards : public Function {
variable_list apply(variable_list&& grads) override; variable_list apply(variable_list&& grads) override;
at::DeprecatedTypeProperties *src_type = nullptr; // initialized for safety. at::DeprecatedTypeProperties *src_type = nullptr; // initialized for safety.
@ -26,7 +27,7 @@ struct CopyBackwards : public Function {
// grad_fn is updated to become a `CopySlice` wrapping the backward of the // grad_fn is updated to become a `CopySlice` wrapping the backward of the
// in-place operation. // in-place operation.
// See NOTE [ Autograd View Variables ]. // See NOTE [ Autograd View Variables ].
struct CopySlices : public Function { struct TORCH_API CopySlices : public Function {
CopySlices( CopySlices(
const Variable& base_var, const Variable& base_var,
at::TensorGeometry view_, at::TensorGeometry view_,

View File

@ -12,7 +12,7 @@ constexpr CUDAStubs* default_stubs_addr = &default_stubs;
// static initialization calls which may invoke registerCUDAMethods // static initialization calls which may invoke registerCUDAMethods
static CUDAStubs* cuda_stubs = default_stubs_addr; static CUDAStubs* cuda_stubs = default_stubs_addr;
TORCH_API void registerCUDAMethods(CUDAStubs* stubs) { void registerCUDAMethods(CUDAStubs* stubs) {
cuda_stubs = stubs; cuda_stubs = stubs;
} }

View File

@ -91,11 +91,27 @@ inline int64_t getTime() {
#endif #endif
} }
enum class EventKind : uint16_t { // Old GCC versions generate warnings incorrectly
// see https://stackoverflow.com/questions/2463113/g-c0x-enum-class-compiler-warnings
#ifndef _MSC_VER
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wattributes"
#endif
enum class TORCH_API ProfilerState {
Disabled,
CPU, // CPU-only profiling
CUDA, // CPU + CUDA events
NVTX, // only emit NVTX markers
};
enum class TORCH_API EventKind : uint16_t {
Mark, Mark,
PushRange, PushRange,
PopRange PopRange
}; };
#ifndef _MSC_VER
# pragma GCC diagnostic pop
#endif
struct TORCH_API Event final { struct TORCH_API Event final {
Event(EventKind kind, StringView name, uint16_t thread_id, bool record_cuda) Event(EventKind kind, StringView name, uint16_t thread_id, bool record_cuda)
@ -183,13 +199,6 @@ struct RangeEventList {
std::forward_list<block_type> blocks; std::forward_list<block_type> blocks;
}; };
enum class ProfilerState {
Disabled,
CPU, // CPU-only profiling
CUDA, // CPU + CUDA events
NVTX, // only emit NVTX markers
};
TORCH_API RangeEventList& getEventList(); TORCH_API RangeEventList& getEventList();
TORCH_API void mark(std::string name, bool include_cuda = true); TORCH_API void mark(std::string name, bool include_cuda = true);
TORCH_API void pushRange(std::string name); TORCH_API void pushRange(std::string name);

View File

@ -1,6 +1,8 @@
#pragma once #pragma once
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h> #include <c10/util/Optional.h>
@ -11,11 +13,11 @@ namespace torch { namespace cuda {
using tensor_list2d = std::vector<std::vector<at::Tensor>>; using tensor_list2d = std::vector<std::vector<at::Tensor>>;
std::vector<at::Tensor> broadcast(const at::Tensor& tensor, at::IntArrayRef devices); TORCH_API std::vector<at::Tensor> broadcast(const at::Tensor& tensor, at::IntArrayRef devices);
tensor_list2d broadcast_coalesced(at::TensorList tensors, at::IntArrayRef devices, TORCH_API tensor_list2d broadcast_coalesced(at::TensorList tensors, at::IntArrayRef devices,
size_t buffer_size); size_t buffer_size);
std::vector<at::Tensor> scatter( TORCH_API std::vector<at::Tensor> scatter(
const at::Tensor& tensor, const at::Tensor& tensor,
at::IntArrayRef devices, at::IntArrayRef devices,
const c10::optional<std::vector<int64_t>>& chunk_sizes = c10::nullopt, const c10::optional<std::vector<int64_t>>& chunk_sizes = c10::nullopt,
@ -23,7 +25,7 @@ std::vector<at::Tensor> scatter(
const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>& streams = const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>& streams =
c10::nullopt); c10::nullopt);
at::Tensor gather( TORCH_API at::Tensor gather(
at::TensorList tensors, at::TensorList tensors,
int64_t dim, int64_t dim,
c10::optional<int32_t> destination_index); c10::optional<int32_t> destination_index);

View File

@ -6,6 +6,7 @@
#include <torch/csrc/jit/ir.h> #include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/variable_tensor_list.h> #include <torch/csrc/jit/variable_tensor_list.h>
#include <torch/csrc/utils/hash.h> #include <torch/csrc/utils/hash.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
@ -133,7 +134,7 @@ struct ArgumentSpec {
// ArgumentSpecCreator takes an initial graph and comes up with a set // ArgumentSpecCreator takes an initial graph and comes up with a set
// of simple instructions to compute the ArgumentSpec given a set of // of simple instructions to compute the ArgumentSpec given a set of
// input tensors. // input tensors.
struct ArgumentSpecCreator { struct TORCH_API ArgumentSpecCreator {
// instructs acts on a stack of a list of input IValues // instructs acts on a stack of a list of input IValues
// at the beginning the stack contains a single list of the inputs to the // at the beginning the stack contains a single list of the inputs to the
// function the ENTER_ instructs descend into subobjects and push new lists // function the ENTER_ instructs descend into subobjects and push new lists

View File

@ -84,7 +84,7 @@ struct Graph;
// We special case Graph attributes like this because we want to ensure that // We special case Graph attributes like this because we want to ensure that
// Graph::copy() is called when we clone() these attributes. // Graph::copy() is called when we clone() these attributes.
struct GraphAttr : public AttributeValue { struct TORCH_API GraphAttr : public AttributeValue {
using ConstructorType = std::shared_ptr<Graph>; using ConstructorType = std::shared_ptr<Graph>;
using ValueType = std::shared_ptr<Graph>; using ValueType = std::shared_ptr<Graph>;
GraphAttr(Symbol name, ConstructorType value_) GraphAttr(Symbol name, ConstructorType value_)
@ -92,7 +92,7 @@ struct GraphAttr : public AttributeValue {
ValueType& value() { ValueType& value() {
return value_; return value_;
} }
TORCH_API Ptr clone() const override; Ptr clone() const override;
AttributeKind kind() const override { AttributeKind kind() const override {
return AttributeKind::g; return AttributeKind::g;
} }
@ -101,7 +101,7 @@ struct GraphAttr : public AttributeValue {
std::shared_ptr<Graph> value_; std::shared_ptr<Graph> value_;
}; };
struct GraphsAttr : public AttributeValue { struct TORCH_API GraphsAttr : public AttributeValue {
using ConstructorType = std::vector<std::shared_ptr<Graph>>; using ConstructorType = std::vector<std::shared_ptr<Graph>>;
using ValueType = std::vector<std::shared_ptr<Graph>>; using ValueType = std::vector<std::shared_ptr<Graph>>;
GraphsAttr(Symbol name, ConstructorType value_) GraphsAttr(Symbol name, ConstructorType value_)
@ -112,7 +112,7 @@ struct GraphsAttr : public AttributeValue {
AttributeKind kind() const override { AttributeKind kind() const override {
return AttributeKind::gs; return AttributeKind::gs;
} }
TORCH_API std::unique_ptr<AttributeValue> clone() const override; std::unique_ptr<AttributeValue> clone() const override;
private: private:
ValueType value_; ValueType value_;

View File

@ -2,6 +2,7 @@
#include <c10/util/Exception.h> #include <c10/util/Exception.h>
#include <torch/csrc/utils/disallow_copy.h> #include <torch/csrc/utils/disallow_copy.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch { namespace torch {
namespace jit { namespace jit {
@ -11,11 +12,11 @@ namespace cpu {
struct DynamicLibrary { struct DynamicLibrary {
TH_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary); TH_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);
DynamicLibrary(const char* name); TORCH_API DynamicLibrary(const char* name);
void* sym(const char* name); TORCH_API void* sym(const char* name);
~DynamicLibrary(); TORCH_API ~DynamicLibrary();
private: private:
void* handle = nullptr; void* handle = nullptr;

View File

@ -6,19 +6,19 @@ namespace jit {
static std::function<void(std::shared_ptr<script::Module> module)> static std::function<void(std::shared_ptr<script::Module> module)>
emit_module_callback; emit_module_callback;
TORCH_API void didFinishEmitModule(std::shared_ptr<script::Module> module) { void didFinishEmitModule(std::shared_ptr<script::Module> module) {
if (emit_module_callback) { if (emit_module_callback) {
emit_module_callback(std::move(module)); emit_module_callback(std::move(module));
} }
} }
static std::function<void(std::shared_ptr<script::Function> fn)> static std::function<void(std::shared_ptr<script::Function> fn)>
emit_function_callback; emit_function_callback;
TORCH_API void didFinishEmitFunction(std::shared_ptr<script::Function> fn) { void didFinishEmitFunction(std::shared_ptr<script::Function> fn) {
if (emit_function_callback) { if (emit_function_callback) {
emit_function_callback(fn); emit_function_callback(fn);
} }
} }
TORCH_API void setEmitHooks( void setEmitHooks(
std::function<void(std::shared_ptr<script::Module> module)> for_mod, std::function<void(std::shared_ptr<script::Module> module)> for_mod,
std::function<void(std::shared_ptr<script::Function> for_fn)> for_fn) { std::function<void(std::shared_ptr<script::Function> for_fn)> for_fn) {
emit_module_callback = std::move(for_mod); emit_module_callback = std::move(for_mod);

View File

@ -43,6 +43,7 @@
#include <torch/csrc/jit/script/python_tree_views.h> #include <torch/csrc/jit/script/python_tree_views.h>
#include <torch/csrc/jit/tracer.h> #include <torch/csrc/jit/tracer.h>
#include <c10/macros/Export.h>
#include <caffe2/serialize/inline_container.h> #include <caffe2/serialize/inline_container.h>
#include <ATen/core/function_schema.h> #include <ATen/core/function_schema.h>
@ -84,7 +85,7 @@ void runJITCPPTests(bool runCuda) {
AT_ERROR("JIT tests not yet supported on Windows"); AT_ERROR("JIT tests not yet supported on Windows");
} }
#else #else
void runJITCPPTests(bool runCuda); CAFFE2_API void runJITCPPTests(bool runCuda);
#endif #endif
void initJITBindings(PyObject* module) { void initJITBindings(PyObject* module) {

View File

@ -47,11 +47,11 @@ struct TORCH_API Code {
}; };
struct InterpreterState { struct InterpreterState {
InterpreterState(const Code& code); TORCH_API InterpreterState(const Code& code);
void run(Stack& stack); TORCH_API void run(Stack& stack);
c10::intrusive_ptr<Future> runAsync(Stack& stack); c10::intrusive_ptr<Future> runAsync(Stack& stack);
c10::intrusive_ptr<Future> getFuture(); c10::intrusive_ptr<Future> getFuture();
~InterpreterState(); TORCH_API ~InterpreterState();
private: private:
InterpreterState(c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl); InterpreterState(c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl);

View File

@ -33,10 +33,6 @@ static constexpr topo_position_t kMidPoint = 0;
// - 2^(64-n) is the maximum number of appends to the end without reindex // - 2^(64-n) is the maximum number of appends to the end without reindex
static constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */; static constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */;
// Sigh, see
// https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char
constexpr Symbol PythonOp::Kind;
static void printValueRef(std::ostream& out, const Value* n) { static void printValueRef(std::ostream& out, const Value* n) {
out << "%" << n->uniqueName(); out << "%" << n->uniqueName();
} }

View File

@ -231,7 +231,7 @@ struct Value {
TORCH_API Value* copyMetadata(Value* from); TORCH_API Value* copyMetadata(Value* from);
}; };
struct Node { struct TORCH_API Node {
TH_DISALLOW_COPY_AND_ASSIGN(Node); TH_DISALLOW_COPY_AND_ASSIGN(Node);
friend struct Graph; friend struct Graph;
friend struct Block; friend struct Block;
@ -259,7 +259,7 @@ struct Node {
topo_position_t topo_position_ = 0; topo_position_t topo_position_ = 0;
protected: protected:
TORCH_API Node(Graph* graph_, NodeKind kind_); // defined after graph Node(Graph* graph_, NodeKind kind_); // defined after graph
public: public:
// each node but Return/Param // each node but Return/Param
// is associated with exactly one place in the node list... // is associated with exactly one place in the node list...
@ -358,7 +358,7 @@ struct Node {
return false; return false;
} }
TORCH_API void replaceAllUsesWith(Node* n); void replaceAllUsesWith(Node* n);
// lots of things like chunk have a single input or single output, so we have // lots of things like chunk have a single input or single output, so we have
// a helper to make accessing it easier // a helper to make accessing it easier
@ -399,10 +399,10 @@ struct Node {
bool is_constant(Symbol name) const { bool is_constant(Symbol name) const {
return static_cast<bool>(get(name)); return static_cast<bool>(get(name));
} }
TORCH_API bool mustBeNone() const; bool mustBeNone() const;
TORCH_API bool isNondeterministic() const; bool isNondeterministic() const;
TORCH_API bool hasSideEffects() const; bool hasSideEffects() const;
// Graphs // Graphs
@ -424,11 +424,11 @@ struct Node {
// Given: %3 = f(%1, %2) // Given: %3 = f(%1, %2)
// Execute: %3.addInput(%4) // Execute: %3.addInput(%4)
// Result: %3 = f(%1, %2, %4) // Result: %3 = f(%1, %2, %4)
TORCH_API Value* addInput(Value* value); Value* addInput(Value* value);
// Add 'value' as an input to 'this' at the specified position in the // Add 'value' as an input to 'this' at the specified position in the
// arguments. Returns the added value for ease of chaining. // arguments. Returns the added value for ease of chaining.
TORCH_API Value* insertInput(size_t i, Value* value); Value* insertInput(size_t i, Value* value);
// Replace the input of 'this' at position 'i' with // Replace the input of 'this' at position 'i' with
// 'newValue', returning the old node. // 'newValue', returning the old node.
@ -436,7 +436,7 @@ struct Node {
// Given: %3 = f(%1, %2) // Given: %3 = f(%1, %2)
// Execute: %3.replaceInput(1, %4) // Execute: %3.replaceInput(1, %4)
// Result: %3 = f(%1, %4) // Result: %3 = f(%1, %4)
TORCH_API Value* replaceInput(size_t i, Value* newValue); Value* replaceInput(size_t i, Value* newValue);
// Replace all occurrences of 'from' in the inputs of this // Replace all occurrences of 'from' in the inputs of this
// node with 'to'. Corresponds to llvm's replaceUsesOfWith. // node with 'to'. Corresponds to llvm's replaceUsesOfWith.
@ -444,16 +444,16 @@ struct Node {
// Given: %3 = f(%1, %2, %1) // Given: %3 = f(%1, %2, %1)
// Execute: %3.replaceInputWith(%1, %4) // Execute: %3.replaceInputWith(%1, %4)
// Result: %3 = f(%4, %2, %4) // Result: %3 = f(%4, %2, %4)
TORCH_API void replaceInputWith(Value* from, Value* to); void replaceInputWith(Value* from, Value* to);
TORCH_API Value* addOutput(); Value* addOutput();
TORCH_API Value* insertOutput(size_t i); Value* insertOutput(size_t i);
TORCH_API void eraseOutput(size_t i); void eraseOutput(size_t i);
TORCH_API Block* addBlock(); Block* addBlock();
TORCH_API void eraseBlock(size_t i); void eraseBlock(size_t i);
// Each Node can have a list of subblocks. These are used to define structured // Each Node can have a list of subblocks. These are used to define structured
// nested control flow operators such as If and Loop. // nested control flow operators such as If and Loop.
@ -482,10 +482,10 @@ struct Node {
} }
// Is 'this' before 'n' in the topological order? // Is 'this' before 'n' in the topological order?
TORCH_API bool isBefore(const Node* n) const; bool isBefore(const Node* n) const;
// Is 'this' after 'n' in the topological order? // Is 'this' after 'n' in the topological order?
TORCH_API bool isAfter(const Node* n) const; bool isAfter(const Node* n) const;
// Insert unattached 'this' node before 'n' in the topological order. // Insert unattached 'this' node before 'n' in the topological order.
// Returns this (for chaining). // Returns this (for chaining).
@ -497,7 +497,7 @@ struct Node {
// Result: %3 = f(%1, %2) // Result: %3 = f(%1, %2)
// %5 = h(%1) // %5 = h(%1)
// %4 = g(%3) // %4 = g(%3)
TORCH_API Node* insertBefore(Node* n); Node* insertBefore(Node* n);
// Insert unattached 'this' node after 'n' in the topological order. // Insert unattached 'this' node after 'n' in the topological order.
// Returns this (for chaining). // Returns this (for chaining).
@ -509,7 +509,7 @@ struct Node {
// Result: %3 = f(%1, %2) // Result: %3 = f(%1, %2)
// %4 = g(%3) // %4 = g(%3)
// %5 = h(%1) // %5 = h(%1)
TORCH_API Node* insertAfter(Node* n); Node* insertAfter(Node* n);
// Move 'this' (already in the graph) after 'n' in the topological order. // Move 'this' (already in the graph) after 'n' in the topological order.
// //
@ -522,7 +522,7 @@ struct Node {
// Result: %3 = g(%1) // Result: %3 = g(%1)
// %2 = f(%1) // %2 = f(%1)
// //
TORCH_API void moveAfter(Node* n); void moveAfter(Node* n);
// Move a node 'n' (already in the graph) before 'this' in the topological // Move a node 'n' (already in the graph) before 'this' in the topological
// order. // order.
@ -535,7 +535,7 @@ struct Node {
// Execute: %3.moveBefore(%2) // Execute: %3.moveBefore(%2)
// Result: %3 = g(%1) // Result: %3 = g(%1)
// %2 = f(%1) // %2 = f(%1)
TORCH_API void moveBefore(Node* n); void moveBefore(Node* n);
// Remove the input at 'i' from this node. // Remove the input at 'i' from this node.
// //
@ -545,14 +545,14 @@ struct Node {
// Given: %3 = f(%1, %2) // Given: %3 = f(%1, %2)
// Execute: %3.removeInput(1) // Execute: %3.removeInput(1)
// Result: %3 = f(%1) // Result: %3 = f(%1)
TORCH_API void removeInput(size_t i); void removeInput(size_t i);
// Remove all inputs from a node. // Remove all inputs from a node.
// //
// Given: %3 = f(%1, %2) // Given: %3 = f(%1, %2)
// Execute: %3.removeAllInputs() // Execute: %3.removeAllInputs()
// Result: %3 = f() // Result: %3 = f()
TORCH_API void removeAllInputs(); void removeAllInputs();
// iterators of the node list starting at this node // iterators of the node list starting at this node
// useful for resuming a search starting at this node // useful for resuming a search starting at this node
@ -577,7 +577,7 @@ struct Node {
// %3 = g(%1) // %3 = g(%1)
// Execute: %2.destroy() // Execute: %2.destroy()
// Result: %3 = g(%1) // Result: %3 = g(%1)
TORCH_API void destroy(); void destroy();
// Dynamically cast this node to the subclass indicated by the // Dynamically cast this node to the subclass indicated by the
// template variable, returning nullptr if the cast is invalid.. // template variable, returning nullptr if the cast is invalid..
@ -604,11 +604,11 @@ struct Node {
} }
// XXX: this function is meant to be used with string literals only! // XXX: this function is meant to be used with string literals only!
TORCH_API bool matches( bool matches(
const char* signature_literal, const char* signature_literal,
at::ArrayRef<Symbol> const_inputs = {}) const; at::ArrayRef<Symbol> const_inputs = {}) const;
TORCH_API const FunctionSchema& schema() const { const FunctionSchema& schema() const {
if (!schema_) { if (!schema_) {
findSchema(); findSchema();
} }
@ -782,15 +782,15 @@ struct Node {
bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const; bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
std::pair<Value*, const Argument&> findInput(Symbol name); std::pair<Value*, const Argument&> findInput(Symbol name);
TORCH_API void findSchema() const; void findSchema() const;
// Lookup iterator in use list of _input i_ that corresponds to its use of // Lookup iterator in use list of _input i_ that corresponds to its use of
// _this_ // _this_
TORCH_API use_list::iterator findUseForInput(size_t i); use_list::iterator findUseForInput(size_t i);
// remove the use of input i, this sets input i to nullptr, but // remove the use of input i, this sets input i to nullptr, but
// is only used internally to Node before setting it to a new value // is only used internally to Node before setting it to a new value
// or erasing the entry from the list. // or erasing the entry from the list.
TORCH_API Value* dropInput(size_t i); Value* dropInput(size_t i);
bool inBlockList() const { bool inBlockList() const {
if (next() == nullptr) { if (next() == nullptr) {
@ -799,8 +799,8 @@ struct Node {
return next() != nullptr; return next() != nullptr;
} }
TORCH_API void removeFromList(); void removeFromList();
TORCH_API void lint() const; void lint() const;
void assignTopoPosition(); void assignTopoPosition();
@ -818,7 +818,7 @@ struct Node {
// 'this' will be allocated with s->allocNewInstance(g) so it should have // 'this' will be allocated with s->allocNewInstance(g) so it should have
// the same concrete type as 's' // the same concrete type as 's'
// //
TORCH_API virtual void cloneFrom(Node* s); virtual void cloneFrom(Node* s);
}; };
struct Block { struct Block {
@ -1262,9 +1262,7 @@ struct ProfileOp : public Node {
// which is not included in libtorch.so. We still include some bits and pieces // which is not included in libtorch.so. We still include some bits and pieces
// of PythonOp here to enable writing simple passes generically. In general, // of PythonOp here to enable writing simple passes generically. In general,
// python-aware bits need to be moved to the descendant classes. // python-aware bits need to be moved to the descendant classes.
struct PythonOp : public Node { struct TORCH_API PythonOp : public Node {
static constexpr Symbol Kind = ::c10::prim::PythonOp;
using Node::Node; using Node::Node;
// should this Python function be skipped over when exported (i.e. for // should this Python function be skipped over when exported (i.e. for

View File

@ -2,6 +2,8 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch { namespace torch {
namespace jit { namespace jit {

View File

@ -2,6 +2,7 @@
#include <caffe2/proto/caffe2_pb.h> #include <caffe2/proto/caffe2_pb.h>
#include <torch/csrc/jit/ir.h> #include <torch/csrc/jit/ir.h>
#include <unordered_map> #include <unordered_map>
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch { namespace torch {
namespace jit { namespace jit {
@ -15,7 +16,7 @@ namespace jit {
* \p Prefix can be used for appending some string to every operator name (e.g. * \p Prefix can be used for appending some string to every operator name (e.g.
* we can add "caffe2::"). * we can add "caffe2::").
*/ */
void convertNetDefToIR( TORCH_API void convertNetDefToIR(
const caffe2::NetDef& net, const caffe2::NetDef& net,
Graph* graph, Graph* graph,
std::unordered_map<std::string, Value*>* valueMapPtr = nullptr, std::unordered_map<std::string, Value*>* valueMapPtr = nullptr,
@ -34,7 +35,7 @@ void convertNetDefToIR(
* TODO: We might need to do a better job at preserving names of the variables, * TODO: We might need to do a better job at preserving names of the variables,
* especially external_inputs/external_outputs. * especially external_inputs/external_outputs.
*/ */
void convertIRToNetDef( TORCH_API void convertIRToNetDef(
caffe2::NetDef* net, caffe2::NetDef* net,
const Graph& graph, const Graph& graph,
const std::string& prefix = ""); const std::string& prefix = "");

View File

@ -19,9 +19,9 @@ namespace jit {
// A pass modifies a Graph in place. // A pass modifies a Graph in place.
using Pass = std::function<void(std::shared_ptr<Graph>&)>; using Pass = std::function<void(std::shared_ptr<Graph>&)>;
std::vector<Pass>& getCustomPasses(); TORCH_API std::vector<Pass>& getCustomPasses();
struct RegisterPass { struct TORCH_API RegisterPass {
RegisterPass(Pass p); RegisterPass(Pass p);
}; };

View File

@ -1265,7 +1265,7 @@ c10::optional<const Node*> AliasDb::getLastWildcard() const {
} }
} }
TORCH_API bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
// WARNING: by adding a case to this list, you are asserting that you have // WARNING: by adding a case to this list, you are asserting that you have
// added a case for the unschematized node in AliasDb::analyze // added a case for the unschematized node in AliasDb::analyze
const static std::unordered_set<Symbol> handled = { const static std::unordered_set<Symbol> handled = {

View File

@ -42,7 +42,7 @@ class AliasDb {
// Does `n` write to an alias of one of the values in `vs`? // Does `n` write to an alias of one of the values in `vs`?
// if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks // if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks
bool writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks = false) TORCH_API bool writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks = false)
const; const;
// Does `a` and `b` potentially share a memory location or do either // Does `a` and `b` potentially share a memory location or do either
@ -56,7 +56,7 @@ class AliasDb {
const at::ArrayRef<Value*>& b) const; const at::ArrayRef<Value*>& b) const;
// Do `a` and `b` potentially share a memory location? // Do `a` and `b` potentially share a memory location?
bool mayAlias(const Value* a, const Value* b) const; TORCH_API bool mayAlias(const Value* a, const Value* b) const;
// Do any values in group `a` potentially share a memory location with any // Do any values in group `a` potentially share a memory location with any
// value in group `b`? i.e. may they overlap? // value in group `b`? i.e. may they overlap?
// //
@ -104,7 +104,7 @@ class AliasDb {
} }
// Do any nodes write to an alias set inputed/outputed by `n`? // Do any nodes write to an alias set inputed/outputed by `n`?
bool hasWriters(const Node* n) const; TORCH_API bool hasWriters(const Node* n) const;
// Move 'n' (already in the graph) after 'movePoint' in the topological order. // Move 'n' (already in the graph) after 'movePoint' in the topological order.
// //
@ -115,8 +115,8 @@ class AliasDb {
// //
// Returns `false` if it's impossible to move `n` after `MovePoint` without // Returns `false` if it's impossible to move `n` after `MovePoint` without
// violating dependencies, otherwise executes the move and returns `true` // violating dependencies, otherwise executes the move and returns `true`
bool moveAfterTopologicallyValid(Node* n, Node* movePoint); TORCH_API bool moveAfterTopologicallyValid(Node* n, Node* movePoint);
bool moveBeforeTopologicallyValid(Node* n, Node* movePoint); TORCH_API bool moveBeforeTopologicallyValid(Node* n, Node* movePoint);
bool couldMoveAfterTopologically(Node* n, Node* movePoint); bool couldMoveAfterTopologically(Node* n, Node* movePoint);
bool couldMoveBeforeTopologically(Node* n, Node* movePoint); bool couldMoveBeforeTopologically(Node* n, Node* movePoint);
@ -146,7 +146,7 @@ class AliasDb {
void getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks = false) const; void getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks = false) const;
void getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const; void getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
// Do any nodes write to `v`s memory location? // Do any nodes write to `v`s memory location?
bool hasWriters(const Value* v) const; TORCH_API bool hasWriters(const Value* v) const;
// Register the fact that `n` writes to `v`. // Register the fact that `n` writes to `v`.
void registerWrite(const Value* v, Node* n); void registerWrite(const Value* v, Node* n);
// Get all the values that `n` reads from. // Get all the values that `n` reads from.
@ -163,7 +163,7 @@ class AliasDb {
* Wildcard methods * Wildcard methods
*/ */
// is `v` a wildcard? // is `v` a wildcard?
bool isWildcard(const Value* v) const; TORCH_API bool isWildcard(const Value* v) const;
// Register `v` as a wildcard value. // Register `v` as a wildcard value.
void setWildcard(const Value* v); void setWildcard(const Value* v);
// Get all nodes that write to a wildcard value. // Get all nodes that write to a wildcard value.

View File

@ -855,7 +855,7 @@ struct PythonPrintPass {
// Prints the RHS value of a Node, e.g. `aten.add(x, y)` // Prints the RHS value of a Node, e.g. `aten.add(x, y)`
void printRHS(std::ostream& stmt, Node* node) { void printRHS(std::ostream& stmt, Node* node) {
switch (node->kind()) { switch (node->kind()) {
case PythonOp::Kind: { case prim::PythonOp: {
auto value = static_cast<const PythonOp*>(node); auto value = static_cast<const PythonOp*>(node);
if (enforce_importable_) { if (enforce_importable_) {
throw script::ErrorReport(node->getSourceLocation()) throw script::ErrorReport(node->getSourceLocation())
@ -1111,7 +1111,7 @@ public:
} }
}; };
TORCH_API void PythonPrint( void PythonPrint(
std::ostream& out, std::ostream& out,
const script::Function& func, const script::Function& func,
bool is_method, bool is_method,
@ -1124,7 +1124,7 @@ TORCH_API void PythonPrint(
pp.print(out); pp.print(out);
} }
TORCH_API void PythonPrint( void PythonPrint(
std::ostream& out, std::ostream& out,
const script::CompilationUnit& cu, const script::CompilationUnit& cu,
bool is_method, bool is_method,
@ -1137,7 +1137,7 @@ TORCH_API void PythonPrint(
pp.print(out); pp.print(out);
} }
TORCH_API void PythonPrint( void PythonPrint(
std::ostream& out, std::ostream& out,
const ClassTypePtr& classType, const ClassTypePtr& classType,
std::vector<at::Tensor>& tensor_table, std::vector<at::Tensor>& tensor_table,
@ -1148,7 +1148,7 @@ TORCH_API void PythonPrint(
pp.print(out); pp.print(out);
} }
TORCH_API bool printerHasSpecialCaseFor(Symbol sym) { bool printerHasSpecialCaseFor(Symbol sym) {
// WARNING: by adding a value to this set, you are asserting // WARNING: by adding a value to this set, you are asserting
// that you have also added special handling of this symbol to // that you have also added special handling of this symbol to
// the printer above. Not adding handling will cause import and export // the printer above. Not adding handling will cause import and export

View File

@ -6,6 +6,8 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch { namespace torch {
namespace jit { namespace jit {
@ -30,17 +32,17 @@ struct Value;
class MemoryDAG { class MemoryDAG {
public: public:
// Make `from` point at `to`. // Make `from` point at `to`.
void makePointerTo(Element* from, Element* to); TORCH_API void makePointerTo(Element* from, Element* to);
void addToContainedElements(Element* contained, Element* container); void addToContainedElements(Element* contained, Element* container);
// Make a fresh element (i.e. an element that doesn't point to anything) and // Make a fresh element (i.e. an element that doesn't point to anything) and
// return it. // return it.
Element* makeFreshValue(const Value* v); TORCH_API Element* makeFreshValue(const Value* v);
// Do `a` and `b` potentially share a memory location? // Do `a` and `b` potentially share a memory location?
bool mayAlias(const Element* a, const Element* b) const; bool mayAlias(const Element* a, const Element* b) const;
bool mayAlias(Element* a, Element* b) const; TORCH_API bool mayAlias(Element* a, Element* b) const;
// Does a hold reference to any memory that is stored in elem, or vice versa? // Does a hold reference to any memory that is stored in elem, or vice versa?
bool mayContainAlias(const Element* a, const Element* b) const; bool mayContainAlias(const Element* a, const Element* b) const;
@ -124,7 +126,7 @@ struct Element {
std::unordered_set<Element*> contained_elements; std::unordered_set<Element*> contained_elements;
// Return the unique memory locations that `Element` might represent. // Return the unique memory locations that `Element` might represent.
std::unordered_set<const Element*> getMemoryLocations() const; TORCH_API std::unordered_set<const Element*> getMemoryLocations() const;
// We do path compression to make repeated memory location queries faster. // We do path compression to make repeated memory location queries faster.
// An empty cache means it is invalidated (it can never be empty otherwise, // An empty cache means it is invalidated (it can never be empty otherwise,
// since every element must point to at least one memory location). // since every element must point to at least one memory location).

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <torch/csrc/jit/ir.h> #include <torch/csrc/jit/ir.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch { namespace torch {
namespace jit { namespace jit {
@ -17,16 +18,16 @@ namespace SubgraphUtils {
// `n` is destroyed. // `n` is destroyed.
// //
// Returns the new subgraph node. // Returns the new subgraph node.
Node* createSingletonSubgraph(Node* n, Symbol subgraphKind); TORCH_API Node* createSingletonSubgraph(Node* n, Symbol subgraphKind);
// Merge a node into a subgraph node. If `toMerge` is also a subgraph, the // Merge a node into a subgraph node. If `toMerge` is also a subgraph, the
// subgraphs are merged. // subgraphs are merged.
// `toMerge` is destroyed. // `toMerge` is destroyed.
void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode); TORCH_API void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode);
// Move nodes from a subgraph node to the outer graph. // Move nodes from a subgraph node to the outer graph.
// `subgraphNode` is destroyed. // `subgraphNode` is destroyed.
void unmergeSubgraph(Node* subgraphNode); TORCH_API void unmergeSubgraph(Node* subgraphNode);
// Convenience function // Convenience function
std::shared_ptr<Graph> getSubgraph(Node* n); std::shared_ptr<Graph> getSubgraph(Node* n);

View File

@ -1,5 +1,6 @@
#pragma once #pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/core/ivalue.h> #include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h> #include <ATen/core/jit_type.h>
@ -21,7 +22,7 @@ struct ProfilingRecord {
ProfilingRecord(const ProfilingRecord&) = delete; ProfilingRecord(const ProfilingRecord&) = delete;
ProfilingRecord(ProfilingRecord&&) noexcept = delete; ProfilingRecord(ProfilingRecord&&) noexcept = delete;
static ProfiledTensorTypePtr toProfiledTensorTypePtr(const IValue& ival); static ProfiledTensorTypePtr toProfiledTensorTypePtr(const IValue& ival);
static std::unique_ptr<ProfilingRecord> instrumentGraph( TORCH_API static std::unique_ptr<ProfilingRecord> instrumentGraph(
const std::shared_ptr<Graph>& graph); const std::shared_ptr<Graph>& graph);
std::shared_ptr<Graph> profiled_graph_; std::shared_ptr<Graph> profiled_graph_;

View File

@ -19,6 +19,8 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
Symbol ConcretePythonOp::Kind = prim::PythonOp;
using c10::Type; using c10::Type;
std::string getPythonName(const PyObject* obj_) { std::string getPythonName(const PyObject* obj_) {

View File

@ -11,6 +11,8 @@ void initPythonIRBindings(PyObject* module);
// execute a Python function, used for Ops we can't optimize but that we want to // execute a Python function, used for Ops we can't optimize but that we want to
// optimize around // optimize around
struct ConcretePythonOp : public PythonOp { struct ConcretePythonOp : public PythonOp {
static Symbol Kind;
ConcretePythonOp(Graph* graph) : PythonOp(graph, ::c10::prim::PythonOp) {} ConcretePythonOp(Graph* graph) : PythonOp(graph, ::c10::prim::PythonOp) {}
ConcretePythonOp* init( ConcretePythonOp* init(
THPObjectPtr&& pyobj, THPObjectPtr&& pyobj,

View File

@ -100,7 +100,7 @@ struct BuiltinFunctionRegistry {
std::unordered_map<Symbol, std::vector<Function*>> builtins_by_name; std::unordered_map<Symbol, std::vector<Function*>> builtins_by_name;
}; };
TORCH_API const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) { const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
static BuiltinFunctionRegistry registry; static BuiltinFunctionRegistry registry;
return registry.getAllBuiltinFunctionsFor(name); return registry.getAllBuiltinFunctionsFor(name);
} }

View File

@ -0,0 +1,9 @@
#include <torch/csrc/jit/script/jit_exception.h>
namespace torch {
namespace jit {
JITException::JITException(const std::string& msg) : std::runtime_error(msg) {}
} // namespace jit
} // namespace torch

View File

@ -2,12 +2,13 @@
#include <stdexcept> #include <stdexcept>
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch { namespace torch {
namespace jit { namespace jit {
struct JITException : public std::runtime_error { struct TORCH_API JITException : public std::runtime_error {
JITException() = default; explicit JITException(const std::string& msg);
explicit JITException(const std::string& msg) : std::runtime_error(msg) {}
}; };
} // namespace jit } // namespace jit

View File

@ -3,6 +3,7 @@
#include <c10/util/C++17.h> #include <c10/util/C++17.h>
#include <torch/csrc/jit/source_range.h> #include <torch/csrc/jit/source_range.h>
#include <torch/csrc/jit/script/strtod.h> #include <torch/csrc/jit/script/strtod.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/core/Macros.h> #include <ATen/core/Macros.h>
#include <algorithm> #include <algorithm>
#include <clocale> #include <clocale>

View File

@ -1,4 +1,4 @@
#include "torch/csrc/jit/script/logging.h" #include <torch/csrc/jit/script/logging.h>
#include <atomic> #include <atomic>
#include <mutex> #include <mutex>
@ -17,7 +17,7 @@ void LockingLogger::addStatValue(const std::string& stat_name, int64_t val) {
raw_counter.count++; raw_counter.count++;
} }
TORCH_API int64_t LockingLogger::getCounterValue(const std::string& name) const { int64_t LockingLogger::getCounterValue(const std::string& name) const {
std::unique_lock<std::mutex> lk(m); std::unique_lock<std::mutex> lk(m);
if (!raw_counters.count(name)) { if (!raw_counters.count(name)) {
return 0; return 0;

View File

@ -37,12 +37,12 @@ class NoopLogger : public LoggerBase {
// //
// NOTE: this is not written in a scalable way and should probably only be used // NOTE: this is not written in a scalable way and should probably only be used
// in the single-threaded case or for testing. // in the single-threaded case or for testing.
class LockingLogger : public LoggerBase { class TORCH_API LockingLogger : public LoggerBase {
public: public:
TORCH_API void addStatValue(const std::string& stat_name, int64_t val) override; void addStatValue(const std::string& stat_name, int64_t val) override;
TORCH_API virtual int64_t getCounterValue(const std::string& name) const; virtual int64_t getCounterValue(const std::string& name) const;
enum class AggregationType { SUM, AVG }; enum class AggregationType { SUM, AVG };
TORCH_API void setAggregationType( void setAggregationType(
const std::string& stat_name, const std::string& stat_name,
AggregationType type); AggregationType type);
~LockingLogger() {} ~LockingLogger() {}

View File

@ -22,7 +22,7 @@ namespace script {
enum NoneStatus { ALWAYS, MAYBE, NEVER }; enum NoneStatus { ALWAYS, MAYBE, NEVER };
struct SugaredValue : public std::enable_shared_from_this<SugaredValue> { struct TORCH_API SugaredValue : public std::enable_shared_from_this<SugaredValue> {
// what is this node? for error reporting (e.g. Module, python function) // what is this node? for error reporting (e.g. Module, python function)
virtual std::string kind() const = 0; virtual std::string kind() const = 0;

View File

@ -39,7 +39,7 @@ void badArgType(const T& v) {
thread_local std::shared_ptr<TracingState> tracing_state; thread_local std::shared_ptr<TracingState> tracing_state;
} // namespace detail } // namespace detail
TORCH_API std::function<void()> pauseTracing() { std::function<void()> pauseTracing() {
// NOLINTNEXTLINE // NOLINTNEXTLINE
std::shared_ptr<tracer::TracingState> state = getTracingState(); std::shared_ptr<tracer::TracingState> state = getTracingState();
tracer::setTracingState(nullptr); tracer::setTracingState(nullptr);

View File

@ -17,6 +17,7 @@
#include <torch/csrc/generic/utils.cpp> #include <torch/csrc/generic/utils.cpp>
#include <TH/THGenerateHalfType.h> #include <TH/THGenerateHalfType.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/generic/utils.cpp> #include <torch/csrc/generic/utils.cpp>
#include <TH/THGenerateBoolType.h> #include <TH/THGenerateBoolType.h>

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <ATen/core/functional.h> #include <ATen/core/functional.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <utility> #include <utility>
@ -59,16 +59,16 @@ struct TensorGroup {
// enough tensors for all data types until the size_limit, and then split // enough tensors for all data types until the size_limit, and then split
// the accumulated tensors into different groups by data types, therefore: // the accumulated tensors into different groups by data types, therefore:
// it will output: {{tensor_a}, {tensor_b}, {tensor_c}} // it will output: {{tensor_a}, {tensor_b}, {tensor_c}}
std::vector<TensorGroup> take_tensors( TORCH_API std::vector<TensorGroup> take_tensors(
at::TensorList tensors, at::TensorList tensors,
size_t size_limit, size_t size_limit,
bool fine_grained = false); bool fine_grained = false);
void reorder_tensors_like(std::vector<at::Tensor>& tensors, at::TensorList order); TORCH_API void reorder_tensors_like(std::vector<at::Tensor>& tensors, at::TensorList order);
std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(at::TensorList tensors); TORCH_API std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(at::TensorList tensors);
std::vector<at::Tensor> unflatten_sparse_tensors( TORCH_API std::vector<at::Tensor> unflatten_sparse_tensors(
const at::Tensor& flat_indices, const at::Tensor& flat_indices,
const at::Tensor& flat_values, const at::Tensor& flat_values,
at::TensorList tensors); at::TensorList tensors);