mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
C++ changes toward libtorch and libcaffe2 unification (#19554)
Summary: * adds TORCH_API and AT_CUDA_API in places * refactor code generation Python logic to separate caffe2/torch outputs * fix hip and asan * remove profiler_cuda from hip * fix gcc warnings for enums * Fix PythonOp::Kind Pull Request resolved: https://github.com/pytorch/pytorch/pull/19554 Differential Revision: D15082727 Pulled By: kostmo fbshipit-source-id: 83a8a99717f025ab44b29608848928d76b3147a4
This commit is contained in:
committed by
Facebook Github Bot
parent
9d180e602f
commit
8f0603b128
@ -6,7 +6,7 @@
|
||||
|
||||
#ifdef _WIN32
|
||||
#if !defined(AT_CORE_STATIC_WINDOWS)
|
||||
# if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS) || defined(CAFFE2_CUDA_BUILD_MAIN_LIB)
|
||||
# if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS) || defined(caffe2_hip_EXPORTS) || defined(CAFFE2_CUDA_BUILD_MAIN_LIB)
|
||||
# define AT_CUDA_API __declspec(dllexport)
|
||||
# else
|
||||
# define AT_CUDA_API __declspec(dllimport)
|
||||
@ -15,7 +15,7 @@
|
||||
# define AT_CUDA_API
|
||||
#endif
|
||||
#elif defined(__GNUC__)
|
||||
#if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS)
|
||||
#if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS) || defined(caffe2_hip_EXPORTS)
|
||||
#define AT_CUDA_API __attribute__((__visibility__("default")))
|
||||
#else
|
||||
#define AT_CUDA_API
|
||||
|
@ -22,7 +22,7 @@
|
||||
#endif
|
||||
|
||||
// ROCM hcc doesn't work well with using std:: in kernel functions
|
||||
#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__)
|
||||
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#define compat_pow c10::cuda::compat::pow
|
||||
#else
|
||||
|
@ -1,4 +1,4 @@
|
||||
#!/bin/bash
|
||||
#!/bin/bash -xe
|
||||
##############################################################################
|
||||
# Example command to build the iOS target.
|
||||
##############################################################################
|
||||
@ -7,8 +7,6 @@
|
||||
# using ios-cmake. This is very similar to the android-cmake - see
|
||||
# build_android.sh for more details.
|
||||
|
||||
set -e
|
||||
|
||||
CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)"
|
||||
|
||||
# Build protobuf from third_party so we have a host protoc binary.
|
||||
|
@ -2,6 +2,8 @@
|
||||
#include <gtest/gtest.h>
|
||||
#endif
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
|
||||
// To add a new test file:
|
||||
// 1. Add a test_foo.h file in this directory
|
||||
// 2. include test_base.h
|
||||
|
@ -194,14 +194,41 @@ def gen_autograd(aten_path, out, autograd_dir):
|
||||
gen_variable_type(out, aten_decls, template_path)
|
||||
|
||||
# Generate Functions.h/cpp
|
||||
from .gen_autograd_functions import gen_autograd_functions
|
||||
gen_autograd_functions(
|
||||
from .gen_autograd_functions import gen_autograd_functions_lib
|
||||
gen_autograd_functions_lib(
|
||||
out, autograd_functions, template_path)
|
||||
|
||||
# Load deprecated signatures
|
||||
deprecated = load_deprecated_signatures(
|
||||
aten_decls, os.path.join(autograd_dir, 'deprecated.yaml'))
|
||||
|
||||
# Generate variable_factories.h
|
||||
from .gen_variable_factories import gen_variable_factories
|
||||
gen_variable_factories(out, aten_decls, template_path)
|
||||
|
||||
|
||||
def gen_autograd_python(aten_path, out, autograd_dir):
|
||||
|
||||
# TODO Deduplicate these four variable assignments
|
||||
|
||||
aten_decls = load_aten_declarations(aten_path)
|
||||
|
||||
# Parse and load derivatives.yaml
|
||||
from .load_derivatives import load_derivatives
|
||||
autograd_functions = load_derivatives(
|
||||
os.path.join(autograd_dir, 'derivatives.yaml'), aten_decls)
|
||||
|
||||
template_path = os.path.join(autograd_dir, 'templates')
|
||||
|
||||
# Load deprecated signatures
|
||||
deprecated = load_deprecated_signatures(
|
||||
aten_decls, os.path.join(autograd_dir, 'deprecated.yaml'))
|
||||
|
||||
# Generate python_functions.h/cpp
|
||||
from .gen_autograd_functions import gen_autograd_functions_python
|
||||
gen_autograd_functions_python(
|
||||
out, autograd_functions, template_path)
|
||||
|
||||
# Generate Python bindings
|
||||
from . import gen_python_functions
|
||||
gen_python_functions.gen_py_variable_methods(
|
||||
@ -211,10 +238,6 @@ def gen_autograd(aten_path, out, autograd_dir):
|
||||
gen_python_functions.gen_py_nn_functions(
|
||||
out, aten_decls, template_path)
|
||||
|
||||
# Generate variable_factories.h
|
||||
from .gen_variable_factories import gen_variable_factories
|
||||
gen_variable_factories(out, aten_decls, template_path)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -4,13 +4,14 @@
|
||||
# Functions.h/cpp: subclasses of autograd::Function
|
||||
# python_functions.h/cpp: Python bindings for the above classes
|
||||
#
|
||||
import os
|
||||
import re
|
||||
from .utils import nested_dict, CodeTemplate, write
|
||||
from .gen_autograd import VIEW_FUNCTIONS
|
||||
from .utils import IDENT_REGEX
|
||||
|
||||
FUNCTION_DECLARATION = CodeTemplate("""\
|
||||
struct ${op} : public ${superclass} {
|
||||
struct TORCH_API ${op} : public ${superclass} {
|
||||
using ${superclass}::${superclass};
|
||||
variable_list apply(variable_list&& grads) override;
|
||||
std::string name() const override { return "${op}"; }
|
||||
@ -81,18 +82,21 @@ if (should_compute_output({ ${idx_ranges} })) {
|
||||
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
|
||||
|
||||
|
||||
def gen_autograd_functions(out, autograd_functions, template_path):
|
||||
def gen_autograd_functions_lib(out, autograd_functions, template_path):
|
||||
gen_autograd_functions(out, autograd_functions, template_path, "Functions")
|
||||
|
||||
|
||||
def gen_autograd_functions_python(out, autograd_functions, template_path):
|
||||
gen_autograd_functions(out, autograd_functions, template_path, "python_functions")
|
||||
|
||||
|
||||
def gen_autograd_functions(out, autograd_functions, template_path, file_basename):
|
||||
"""Functions.h and Functions.cpp body
|
||||
|
||||
These contain the auto-generated subclasses of torch::autograd::Function
|
||||
for every differentiable torch function.
|
||||
"""
|
||||
|
||||
FUNCTIONS_H = CodeTemplate.from_file(template_path + '/Functions.h')
|
||||
FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/Functions.cpp')
|
||||
PY_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_functions.h')
|
||||
PY_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_functions.cpp')
|
||||
|
||||
function_definitions = []
|
||||
function_declarations = []
|
||||
py_function_initializers = []
|
||||
@ -110,10 +114,10 @@ def gen_autograd_functions(out, autograd_functions, template_path):
|
||||
'py_function_initializers': py_function_initializers,
|
||||
}
|
||||
|
||||
write(out, 'Functions.h', FUNCTIONS_H, top_env)
|
||||
write(out, 'Functions.cpp', FUNCTIONS_CPP, top_env)
|
||||
write(out, 'python_functions.h', PY_FUNCTIONS_H, top_env)
|
||||
write(out, 'python_functions.cpp', PY_FUNCTIONS_CPP, top_env)
|
||||
for suffix in [".h", ".cpp"]:
|
||||
f = file_basename + suffix
|
||||
templated_output = CodeTemplate.from_file(os.path.join(template_path, f))
|
||||
write(out, f, templated_output, top_env)
|
||||
|
||||
|
||||
def process_function(func):
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include "torch/csrc/autograd/function.h"
|
||||
#include "torch/csrc/autograd/variable.h"
|
||||
#include "torch/csrc/autograd/saved_variable.h"
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
namespace torch { namespace autograd { namespace generated {
|
||||
|
||||
|
@ -38,6 +38,7 @@ libtorch_sources = [
|
||||
"torch/csrc/autograd/anomaly_mode.cpp",
|
||||
"torch/csrc/autograd/engine.cpp",
|
||||
"torch/csrc/autograd/function.cpp",
|
||||
"torch/csrc/autograd/function_hook.cpp",
|
||||
"torch/csrc/autograd/functions/accumulate_grad.cpp",
|
||||
"torch/csrc/autograd/functions/basic_ops.cpp",
|
||||
"torch/csrc/autograd/functions/tensor.cpp",
|
||||
@ -107,6 +108,7 @@ libtorch_sources = [
|
||||
"torch/csrc/jit/script/schema_matching.cpp",
|
||||
"torch/csrc/jit/script/class_type.cpp",
|
||||
"torch/csrc/jit/script/parser.cpp",
|
||||
"torch/csrc/jit/script/jit_exception.cpp",
|
||||
"torch/csrc/jit/testing/file_check.cpp",
|
||||
"torch/csrc/jit/import_source.cpp",
|
||||
"torch/csrc/jit/hooks_for_testing.cpp",
|
||||
|
@ -19,61 +19,35 @@ def all_generator_source():
|
||||
return sorted(r)
|
||||
|
||||
|
||||
inputs = [
|
||||
'torch/lib/THNN.h',
|
||||
'torch/lib/THCUNN.h',
|
||||
'torch/share/ATen/Declarations.yaml',
|
||||
'tools/autograd/derivatives.yaml',
|
||||
'tools/autograd/deprecated.yaml',
|
||||
]
|
||||
|
||||
outputs = [
|
||||
'torch/csrc/autograd/generated/Functions.cpp',
|
||||
'torch/csrc/autograd/generated/Functions.h',
|
||||
'torch/csrc/autograd/generated/python_functions.cpp',
|
||||
'torch/csrc/autograd/generated/python_functions.h',
|
||||
'torch/csrc/autograd/generated/python_nn_functions.cpp',
|
||||
'torch/csrc/autograd/generated/python_nn_functions.h',
|
||||
'torch/csrc/autograd/generated/python_nn_functions_dispatch.h',
|
||||
'torch/csrc/autograd/generated/python_variable_methods.cpp',
|
||||
'torch/csrc/autograd/generated/python_variable_methods_dispatch.h',
|
||||
'torch/csrc/autograd/generated/variable_factories.h',
|
||||
'torch/csrc/autograd/generated/VariableType_0.cpp',
|
||||
'torch/csrc/autograd/generated/VariableType_1.cpp',
|
||||
'torch/csrc/autograd/generated/VariableType_2.cpp',
|
||||
'torch/csrc/autograd/generated/VariableType_3.cpp',
|
||||
'torch/csrc/autograd/generated/VariableType_4.cpp',
|
||||
'torch/csrc/autograd/generated/VariableType.h',
|
||||
'torch/csrc/jit/generated/register_aten_ops_0.cpp',
|
||||
'torch/csrc/jit/generated/register_aten_ops_1.cpp',
|
||||
'torch/csrc/jit/generated/register_aten_ops_2.cpp',
|
||||
]
|
||||
|
||||
|
||||
def generate_code(ninja_global=None,
|
||||
declarations_path=None,
|
||||
nn_path=None,
|
||||
install_dir=None):
|
||||
install_dir=None,
|
||||
subset=None):
|
||||
# cwrap depends on pyyaml, so we can't import it earlier
|
||||
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.insert(0, root)
|
||||
from tools.autograd.gen_autograd import gen_autograd
|
||||
from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python
|
||||
from tools.jit.gen_jit_dispatch import gen_jit_dispatch
|
||||
|
||||
from tools.nnwrap import generate_wrappers as generate_nn_wrappers
|
||||
|
||||
# Build THNN/THCUNN.cwrap and then THNN/THCUNN.cpp. These are primarily
|
||||
# used by the legacy NN bindings.
|
||||
generate_nn_wrappers(nn_path, install_dir, 'tools/cwrap/plugins/templates')
|
||||
|
||||
# Build ATen based Variable classes
|
||||
autograd_gen_dir = install_dir or 'torch/csrc/autograd/generated'
|
||||
jit_gen_dir = install_dir or 'torch/csrc/jit/generated'
|
||||
for d in (autograd_gen_dir, jit_gen_dir):
|
||||
if not os.path.exists(d):
|
||||
os.makedirs(d)
|
||||
gen_autograd(declarations_path or DECLARATIONS_PATH, autograd_gen_dir, 'tools/autograd')
|
||||
gen_jit_dispatch(declarations_path or DECLARATIONS_PATH, jit_gen_dir, 'tools/jit/templates')
|
||||
|
||||
if subset == "pybindings" or not subset:
|
||||
# Build THNN/THCUNN.cwrap and then THNN/THCUNN.cpp. These are primarily
|
||||
# used by the legacy NN bindings.
|
||||
from tools.nnwrap import generate_wrappers as generate_nn_wrappers
|
||||
generate_nn_wrappers(nn_path, install_dir, 'tools/cwrap/plugins/templates')
|
||||
|
||||
gen_autograd_python(declarations_path or DECLARATIONS_PATH, autograd_gen_dir, 'tools/autograd')
|
||||
|
||||
if subset == "libtorch" or not subset:
|
||||
gen_autograd(declarations_path or DECLARATIONS_PATH, autograd_gen_dir, 'tools/autograd')
|
||||
gen_jit_dispatch(declarations_path or DECLARATIONS_PATH, jit_gen_dir, 'tools/jit/templates')
|
||||
|
||||
|
||||
def main():
|
||||
@ -82,11 +56,18 @@ def main():
|
||||
parser.add_argument('--nn-path')
|
||||
parser.add_argument('--ninja-global')
|
||||
parser.add_argument('--install_dir')
|
||||
parser.add_argument(
|
||||
'--subset',
|
||||
help='Subset of source files to generate. Can be "libtorch" or "pybindings". Generates both when omitted.'
|
||||
)
|
||||
options = parser.parse_args()
|
||||
generate_code(options.ninja_global,
|
||||
options.declarations_path,
|
||||
options.nn_path,
|
||||
options.install_dir)
|
||||
generate_code(
|
||||
options.ninja_global,
|
||||
options.declarations_path,
|
||||
options.nn_path,
|
||||
options.install_dir,
|
||||
options.subset,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -104,6 +104,7 @@ set(TORCH_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/autograd/anomaly_mode.cpp
|
||||
${TORCH_SRC_DIR}/csrc/autograd/engine.cpp
|
||||
${TORCH_SRC_DIR}/csrc/autograd/function.cpp
|
||||
${TORCH_SRC_DIR}/csrc/autograd/function_hook.cpp
|
||||
${TORCH_SRC_DIR}/csrc/autograd/functions/accumulate_grad.cpp
|
||||
${TORCH_SRC_DIR}/csrc/autograd/functions/basic_ops.cpp
|
||||
${TORCH_SRC_DIR}/csrc/autograd/functions/tensor.cpp
|
||||
@ -190,6 +191,7 @@ set(TORCH_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/script/jit_exception.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
|
||||
@ -230,6 +232,14 @@ if (USE_CUDA)
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
if (USE_ROCM)
|
||||
list(APPEND TORCH_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
if (NOT NO_API)
|
||||
list(APPEND TORCH_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/api/src/cuda.cpp
|
||||
@ -543,7 +553,6 @@ if (BUILD_PYTHON)
|
||||
${TORCH_SRC_DIR}/csrc/utils/structseq.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_apply.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_dtypes.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_layouts.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_list.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_new.cpp
|
||||
@ -615,7 +624,6 @@ if (BUILD_PYTHON)
|
||||
${TORCH_SRC_DIR}/csrc/cuda/Stream.cpp
|
||||
${TORCH_SRC_DIR}/csrc/cuda/Event.cpp
|
||||
${TORCH_SRC_DIR}/csrc/cuda/utils.cpp
|
||||
${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
|
||||
${TORCH_SRC_DIR}/csrc/cuda/python_comm.cpp
|
||||
${TORCH_SRC_DIR}/csrc/cuda/serialization.cpp
|
||||
${TORCH_SRC_DIR}/csrc/nn/THCUNN.cpp
|
||||
@ -662,7 +670,6 @@ if (BUILD_PYTHON)
|
||||
${TORCH_SRC_DIR}/csrc/cuda/Stream.cpp
|
||||
${TORCH_SRC_DIR}/csrc/cuda/Event.cpp
|
||||
${TORCH_SRC_DIR}/csrc/cuda/utils.cpp
|
||||
${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
|
||||
${TORCH_SRC_DIR}/csrc/cuda/python_comm.cpp
|
||||
${TORCH_SRC_DIR}/csrc/cuda/serialization.cpp
|
||||
${TORCH_SRC_DIR}/csrc/nn/THCUNN.cpp
|
||||
|
@ -29,17 +29,17 @@ class Sampler {
|
||||
/// Resets the `Sampler`'s internal state.
|
||||
/// Typically called before a new epoch.
|
||||
/// Optionally, accepts a new size when resetting the sampler.
|
||||
TORCH_API virtual void reset(optional<size_t> new_size) = 0;
|
||||
virtual void reset(optional<size_t> new_size) = 0;
|
||||
|
||||
/// Returns the next index if possible, or an empty optional if the
|
||||
/// sampler is exhausted for this epoch.
|
||||
TORCH_API virtual optional<BatchRequest> next(size_t batch_size) = 0;
|
||||
virtual optional<BatchRequest> next(size_t batch_size) = 0;
|
||||
|
||||
/// Serializes the `Sampler` to the `archive`.
|
||||
TORCH_API virtual void save(serialize::OutputArchive& archive) const = 0;
|
||||
virtual void save(serialize::OutputArchive& archive) const = 0;
|
||||
|
||||
/// Deserializes the `Sampler` from the `archive`.
|
||||
TORCH_API virtual void load(serialize::InputArchive& archive) = 0;
|
||||
virtual void load(serialize::InputArchive& archive) = 0;
|
||||
};
|
||||
|
||||
} // namespace samplers
|
||||
|
@ -25,7 +25,7 @@ namespace samplers {
|
||||
template <typename BatchRequest = std::vector<size_t>>
|
||||
class DistributedSampler : public Sampler<BatchRequest> {
|
||||
public:
|
||||
TORCH_API DistributedSampler(
|
||||
DistributedSampler(
|
||||
size_t size,
|
||||
size_t num_replicas = 1,
|
||||
size_t rank = 0,
|
||||
@ -64,28 +64,28 @@ class DistributedSampler : public Sampler<BatchRequest> {
|
||||
|
||||
/// Select samples randomly. The sampling order is shuffled at each `reset()`
|
||||
/// call.
|
||||
class DistributedRandomSampler : public DistributedSampler<> {
|
||||
class TORCH_API DistributedRandomSampler : public DistributedSampler<> {
|
||||
public:
|
||||
TORCH_API DistributedRandomSampler(
|
||||
DistributedRandomSampler(
|
||||
size_t size,
|
||||
size_t num_replicas = 1,
|
||||
size_t rank = 0,
|
||||
bool allow_duplicates = true);
|
||||
|
||||
/// Resets the `DistributedRandomSampler` to a new set of indices.
|
||||
TORCH_API void reset(optional<size_t> new_size = nullopt) override;
|
||||
void reset(optional<size_t> new_size = nullopt) override;
|
||||
|
||||
/// Returns the next batch of indices.
|
||||
TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
|
||||
optional<std::vector<size_t>> next(size_t batch_size) override;
|
||||
|
||||
/// Serializes the `DistributedRandomSampler` to the `archive`.
|
||||
TORCH_API void save(serialize::OutputArchive& archive) const override;
|
||||
void save(serialize::OutputArchive& archive) const override;
|
||||
|
||||
/// Deserializes the `DistributedRandomSampler` from the `archive`.
|
||||
TORCH_API void load(serialize::InputArchive& archive) override;
|
||||
void load(serialize::InputArchive& archive) override;
|
||||
|
||||
/// Returns the current index of the `DistributedRandomSampler`.
|
||||
TORCH_API size_t index() const noexcept;
|
||||
size_t index() const noexcept;
|
||||
|
||||
private:
|
||||
void populate_indices();
|
||||
@ -97,28 +97,28 @@ class DistributedRandomSampler : public DistributedSampler<> {
|
||||
};
|
||||
|
||||
/// Select samples sequentially.
|
||||
class DistributedSequentialSampler : public DistributedSampler<> {
|
||||
class TORCH_API DistributedSequentialSampler : public DistributedSampler<> {
|
||||
public:
|
||||
TORCH_API DistributedSequentialSampler(
|
||||
DistributedSequentialSampler(
|
||||
size_t size,
|
||||
size_t num_replicas = 1,
|
||||
size_t rank = 0,
|
||||
bool allow_duplicates = true);
|
||||
|
||||
/// Resets the `DistributedSequentialSampler` to a new set of indices.
|
||||
TORCH_API void reset(optional<size_t> new_size = nullopt) override;
|
||||
void reset(optional<size_t> new_size = nullopt) override;
|
||||
|
||||
/// Returns the next batch of indices.
|
||||
TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
|
||||
optional<std::vector<size_t>> next(size_t batch_size) override;
|
||||
|
||||
/// Serializes the `DistributedSequentialSampler` to the `archive`.
|
||||
TORCH_API void save(serialize::OutputArchive& archive) const override;
|
||||
void save(serialize::OutputArchive& archive) const override;
|
||||
|
||||
/// Deserializes the `DistributedSequentialSampler` from the `archive`.
|
||||
TORCH_API void load(serialize::InputArchive& archive) override;
|
||||
void load(serialize::InputArchive& archive) override;
|
||||
|
||||
/// Returns the current index of the `DistributedSequentialSampler`.
|
||||
TORCH_API size_t index() const noexcept;
|
||||
size_t index() const noexcept;
|
||||
|
||||
private:
|
||||
void populate_indices();
|
||||
|
@ -19,31 +19,33 @@ namespace data {
|
||||
namespace samplers {
|
||||
|
||||
/// A `Sampler` that returns random indices.
|
||||
class RandomSampler : public Sampler<> {
|
||||
class TORCH_API RandomSampler : public Sampler<> {
|
||||
public:
|
||||
/// Constructs a `RandomSampler` with a size and dtype for the stored indices.
|
||||
///
|
||||
/// The constructor will eagerly allocate all required indices, which is the
|
||||
/// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored
|
||||
/// indices. You can change it to influence memory usage.
|
||||
TORCH_API explicit RandomSampler(
|
||||
explicit RandomSampler(
|
||||
int64_t size,
|
||||
Dtype index_dtype = torch::kInt64);
|
||||
|
||||
~RandomSampler() override;
|
||||
|
||||
/// Resets the `RandomSampler` to a new set of indices.
|
||||
TORCH_API void reset(optional<size_t> new_size = nullopt) override;
|
||||
void reset(optional<size_t> new_size = nullopt) override;
|
||||
|
||||
/// Returns the next batch of indices.
|
||||
TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
|
||||
optional<std::vector<size_t>> next(size_t batch_size) override;
|
||||
|
||||
/// Serializes the `RandomSampler` to the `archive`.
|
||||
TORCH_API void save(serialize::OutputArchive& archive) const override;
|
||||
void save(serialize::OutputArchive& archive) const override;
|
||||
|
||||
/// Deserializes the `RandomSampler` from the `archive`.
|
||||
TORCH_API void load(serialize::InputArchive& archive) override;
|
||||
void load(serialize::InputArchive& archive) override;
|
||||
|
||||
/// Returns the current index of the `RandomSampler`.
|
||||
TORCH_API size_t index() const noexcept;
|
||||
size_t index() const noexcept;
|
||||
|
||||
private:
|
||||
Tensor indices_;
|
||||
|
@ -19,26 +19,26 @@ namespace data {
|
||||
namespace samplers {
|
||||
|
||||
/// A `Sampler` that returns indices sequentially.
|
||||
class SequentialSampler : public Sampler<> {
|
||||
class TORCH_API SequentialSampler : public Sampler<> {
|
||||
public:
|
||||
/// Creates a `SequentialSampler` that will return indices in the range
|
||||
/// `0...size - 1`.
|
||||
TORCH_API explicit SequentialSampler(size_t size);
|
||||
explicit SequentialSampler(size_t size);
|
||||
|
||||
/// Resets the `SequentialSampler` to zero.
|
||||
TORCH_API void reset(optional<size_t> new_size = nullopt) override;
|
||||
void reset(optional<size_t> new_size = nullopt) override;
|
||||
|
||||
/// Returns the next batch of indices.
|
||||
TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
|
||||
optional<std::vector<size_t>> next(size_t batch_size) override;
|
||||
|
||||
/// Serializes the `SequentialSampler` to the `archive`.
|
||||
TORCH_API void save(serialize::OutputArchive& archive) const override;
|
||||
void save(serialize::OutputArchive& archive) const override;
|
||||
|
||||
/// Deserializes the `SequentialSampler` from the `archive`.
|
||||
TORCH_API void load(serialize::InputArchive& archive) override;
|
||||
void load(serialize::InputArchive& archive) override;
|
||||
|
||||
/// Returns the current index of the `SequentialSampler`.
|
||||
TORCH_API size_t index() const noexcept;
|
||||
size_t index() const noexcept;
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
|
@ -32,26 +32,26 @@ struct TORCH_API BatchSize : public CustomBatchRequest {
|
||||
/// The major feature of the `StreamSampler` is that it does not return
|
||||
/// particular indices, but instead only the number of elements to fetch from
|
||||
/// the dataset. The dataset has to decide how to produce those elements.
|
||||
class StreamSampler : public Sampler<BatchSize> {
|
||||
class TORCH_API StreamSampler : public Sampler<BatchSize> {
|
||||
public:
|
||||
/// Constructs the `StreamSampler` with the number of individual examples that
|
||||
/// should be fetched until the sampler is exhausted.
|
||||
TORCH_API explicit StreamSampler(size_t epoch_size);
|
||||
explicit StreamSampler(size_t epoch_size);
|
||||
|
||||
/// Resets the internal state of the sampler.
|
||||
TORCH_API void reset(optional<size_t> new_size = nullopt) override;
|
||||
void reset(optional<size_t> new_size = nullopt) override;
|
||||
|
||||
/// Returns a `BatchSize` object with the number of elements to fetch in the
|
||||
/// next batch. This number is the minimum of the supplied `batch_size` and
|
||||
/// the difference between the `epoch_size` and the current index. If the
|
||||
/// `epoch_size` has been reached, returns an empty optional.
|
||||
TORCH_API optional<BatchSize> next(size_t batch_size) override;
|
||||
optional<BatchSize> next(size_t batch_size) override;
|
||||
|
||||
/// Serializes the `StreamSampler` to the `archive`.
|
||||
TORCH_API void save(serialize::OutputArchive& archive) const override;
|
||||
void save(serialize::OutputArchive& archive) const override;
|
||||
|
||||
/// Deserializes the `StreamSampler` from the `archive`.
|
||||
TORCH_API void load(serialize::InputArchive& archive) override;
|
||||
void load(serialize::InputArchive& archive) override;
|
||||
|
||||
private:
|
||||
size_t examples_retrieved_so_far_ = 0;
|
||||
|
@ -5,6 +5,8 @@
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
||||
|
@ -4,6 +4,8 @@
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
||||
|
@ -163,4 +163,4 @@ size_t DistributedSequentialSampler::index() const noexcept {
|
||||
|
||||
} // namespace samplers
|
||||
} // namespace data
|
||||
} // namespace torch
|
||||
} // namespace torch
|
||||
|
@ -12,6 +12,8 @@ namespace samplers {
|
||||
RandomSampler::RandomSampler(int64_t size, Dtype index_dtype)
|
||||
: indices_(torch::randperm(size, index_dtype)) {}
|
||||
|
||||
RandomSampler::~RandomSampler() = default;
|
||||
|
||||
void RandomSampler::reset(optional<size_t> new_size) {
|
||||
// This allocates a new chunk of memory every time (just FYI). It should be
|
||||
// amortized over the entire epoch hopefully.
|
||||
|
@ -4,4 +4,6 @@ namespace torch { namespace autograd {
|
||||
|
||||
bool AnomalyMode::_enabled = false;
|
||||
|
||||
AnomalyMetadata::~AnomalyMetadata() = default;
|
||||
|
||||
}}
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
struct AnomalyMode {
|
||||
struct TORCH_API AnomalyMode {
|
||||
static bool is_enabled() {
|
||||
return _enabled;
|
||||
}
|
||||
@ -13,12 +13,12 @@ struct AnomalyMode {
|
||||
}
|
||||
|
||||
private:
|
||||
TORCH_API static bool _enabled;
|
||||
static bool _enabled;
|
||||
};
|
||||
|
||||
|
||||
struct AnomalyMetadata {
|
||||
virtual ~AnomalyMetadata() = default;
|
||||
struct TORCH_API AnomalyMetadata {
|
||||
virtual ~AnomalyMetadata();
|
||||
virtual void store_stack() = 0;
|
||||
virtual void print_stack() = 0;
|
||||
};
|
||||
|
8
torch/csrc/autograd/function_hook.cpp
Normal file
8
torch/csrc/autograd/function_hook.cpp
Normal file
@ -0,0 +1,8 @@
|
||||
#include <torch/csrc/autograd/function_hook.h>
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
FunctionPreHook::~FunctionPreHook() = default;
|
||||
FunctionPostHook::~FunctionPostHook() = default;
|
||||
|
||||
}} // namespace torch::autograd
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
// A hook that's called on gradients
|
||||
|
||||
@ -9,13 +10,13 @@ namespace torch { namespace autograd {
|
||||
struct Variable;
|
||||
using variable_list = std::vector<Variable>;
|
||||
|
||||
struct FunctionPreHook {
|
||||
virtual ~FunctionPreHook() = default;
|
||||
struct TORCH_API FunctionPreHook {
|
||||
virtual ~FunctionPreHook();
|
||||
virtual variable_list operator()(const variable_list& grads) = 0;
|
||||
};
|
||||
|
||||
struct FunctionPostHook {
|
||||
virtual ~FunctionPostHook() = default;
|
||||
struct TORCH_API FunctionPostHook {
|
||||
virtual ~FunctionPostHook();
|
||||
virtual variable_list operator()(
|
||||
const variable_list& outputs /* grad_inputs */,
|
||||
const variable_list& inputs /* grad_outputs */) = 0;
|
||||
|
@ -2,10 +2,11 @@
|
||||
|
||||
#include <torch/csrc/autograd/function.h>
|
||||
#include <torch/csrc/autograd/variable.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
struct AccumulateGrad : public Function {
|
||||
struct TORCH_API AccumulateGrad : public Function {
|
||||
explicit AccumulateGrad(Variable variable_);
|
||||
|
||||
variable_list apply(variable_list&& grads) override;
|
||||
|
@ -28,6 +28,8 @@ Scatter::Scatter(
|
||||
streams_(streams),
|
||||
unsqueeze_scalars_(unsqueeze_scalars) {}
|
||||
|
||||
Scatter::~Scatter() {}
|
||||
|
||||
variable_list Scatter::apply(variable_list&& inputs) {
|
||||
AT_ASSERT(inputs.size() == 1);
|
||||
auto& input = inputs.front();
|
||||
@ -65,6 +67,8 @@ variable_list Scatter::apply(variable_list&& inputs) {
|
||||
Gather::Gather(const at::Device& destination_device, int64_t dim)
|
||||
: destination_device_(destination_device), dim_(dim) {}
|
||||
|
||||
Gather::~Gather() {}
|
||||
|
||||
variable_list Gather::apply(variable_list&& inputs) {
|
||||
bool all_are_zero_dim = true;
|
||||
for (const auto& input : inputs) {
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/ATenCUDAGeneral.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
@ -13,6 +14,7 @@
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
|
||||
//TODO: change it to TORCH_API when we merge the libs
|
||||
struct TORCH_API Scatter : public Function {
|
||||
explicit Scatter(
|
||||
std::vector<at::Device> devices,
|
||||
@ -21,6 +23,7 @@ struct TORCH_API Scatter : public Function {
|
||||
const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>& streams =
|
||||
c10::nullopt,
|
||||
bool unsqueeze_scalars = false);
|
||||
~Scatter() override;
|
||||
|
||||
variable_list apply(variable_list&& inputs) override;
|
||||
|
||||
@ -33,6 +36,7 @@ struct TORCH_API Scatter : public Function {
|
||||
|
||||
struct TORCH_API Gather : public Function {
|
||||
explicit Gather(const at::Device& destination_device, int64_t dim = 0);
|
||||
~Gather() override;
|
||||
|
||||
variable_list apply(variable_list&& inputs) override;
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <torch/csrc/autograd/function.h>
|
||||
#include <torch/csrc/autograd/variable.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
#include <ATen/TensorGeometry.h>
|
||||
#include <ATen/core/DeprecatedTypeProperties.h>
|
||||
@ -12,7 +13,7 @@
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
struct CopyBackwards : public Function {
|
||||
struct TORCH_API CopyBackwards : public Function {
|
||||
variable_list apply(variable_list&& grads) override;
|
||||
|
||||
at::DeprecatedTypeProperties *src_type = nullptr; // initialized for safety.
|
||||
@ -26,7 +27,7 @@ struct CopyBackwards : public Function {
|
||||
// grad_fn is updated to become a `CopySlice` wrapping the backward of the
|
||||
// in-place operation.
|
||||
// See NOTE [ Autograd View Variables ].
|
||||
struct CopySlices : public Function {
|
||||
struct TORCH_API CopySlices : public Function {
|
||||
CopySlices(
|
||||
const Variable& base_var,
|
||||
at::TensorGeometry view_,
|
||||
|
@ -12,7 +12,7 @@ constexpr CUDAStubs* default_stubs_addr = &default_stubs;
|
||||
// static initialization calls which may invoke registerCUDAMethods
|
||||
static CUDAStubs* cuda_stubs = default_stubs_addr;
|
||||
|
||||
TORCH_API void registerCUDAMethods(CUDAStubs* stubs) {
|
||||
void registerCUDAMethods(CUDAStubs* stubs) {
|
||||
cuda_stubs = stubs;
|
||||
}
|
||||
|
||||
|
@ -91,11 +91,27 @@ inline int64_t getTime() {
|
||||
#endif
|
||||
}
|
||||
|
||||
enum class EventKind : uint16_t {
|
||||
// Old GCC versions generate warnings incorrectly
|
||||
// see https://stackoverflow.com/questions/2463113/g-c0x-enum-class-compiler-warnings
|
||||
#ifndef _MSC_VER
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wattributes"
|
||||
#endif
|
||||
enum class TORCH_API ProfilerState {
|
||||
Disabled,
|
||||
CPU, // CPU-only profiling
|
||||
CUDA, // CPU + CUDA events
|
||||
NVTX, // only emit NVTX markers
|
||||
};
|
||||
|
||||
enum class TORCH_API EventKind : uint16_t {
|
||||
Mark,
|
||||
PushRange,
|
||||
PopRange
|
||||
};
|
||||
#ifndef _MSC_VER
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
struct TORCH_API Event final {
|
||||
Event(EventKind kind, StringView name, uint16_t thread_id, bool record_cuda)
|
||||
@ -183,13 +199,6 @@ struct RangeEventList {
|
||||
std::forward_list<block_type> blocks;
|
||||
};
|
||||
|
||||
enum class ProfilerState {
|
||||
Disabled,
|
||||
CPU, // CPU-only profiling
|
||||
CUDA, // CPU + CUDA events
|
||||
NVTX, // only emit NVTX markers
|
||||
};
|
||||
|
||||
TORCH_API RangeEventList& getEventList();
|
||||
TORCH_API void mark(std::string name, bool include_cuda = true);
|
||||
TORCH_API void pushRange(std::string name);
|
||||
|
@ -1,6 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <ATen/cuda/ATenCUDAGeneral.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
@ -11,11 +13,11 @@ namespace torch { namespace cuda {
|
||||
|
||||
using tensor_list2d = std::vector<std::vector<at::Tensor>>;
|
||||
|
||||
std::vector<at::Tensor> broadcast(const at::Tensor& tensor, at::IntArrayRef devices);
|
||||
tensor_list2d broadcast_coalesced(at::TensorList tensors, at::IntArrayRef devices,
|
||||
TORCH_API std::vector<at::Tensor> broadcast(const at::Tensor& tensor, at::IntArrayRef devices);
|
||||
TORCH_API tensor_list2d broadcast_coalesced(at::TensorList tensors, at::IntArrayRef devices,
|
||||
size_t buffer_size);
|
||||
|
||||
std::vector<at::Tensor> scatter(
|
||||
TORCH_API std::vector<at::Tensor> scatter(
|
||||
const at::Tensor& tensor,
|
||||
at::IntArrayRef devices,
|
||||
const c10::optional<std::vector<int64_t>>& chunk_sizes = c10::nullopt,
|
||||
@ -23,7 +25,7 @@ std::vector<at::Tensor> scatter(
|
||||
const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>& streams =
|
||||
c10::nullopt);
|
||||
|
||||
at::Tensor gather(
|
||||
TORCH_API at::Tensor gather(
|
||||
at::TensorList tensors,
|
||||
int64_t dim,
|
||||
c10::optional<int32_t> destination_index);
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <torch/csrc/jit/ir.h>
|
||||
#include <torch/csrc/jit/variable_tensor_list.h>
|
||||
#include <torch/csrc/utils/hash.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
@ -133,7 +134,7 @@ struct ArgumentSpec {
|
||||
// ArgumentSpecCreator takes an initial graph and comes up with a set
|
||||
// of simple instructions to compute the ArgumentSpec given a set of
|
||||
// input tensors.
|
||||
struct ArgumentSpecCreator {
|
||||
struct TORCH_API ArgumentSpecCreator {
|
||||
// instructs acts on a stack of a list of input IValues
|
||||
// at the beginning the stack contains a single list of the inputs to the
|
||||
// function the ENTER_ instructs descend into subobjects and push new lists
|
||||
|
@ -84,7 +84,7 @@ struct Graph;
|
||||
|
||||
// We special case Graph attributes like this because we want to ensure that
|
||||
// Graph::copy() is called when we clone() these attributes.
|
||||
struct GraphAttr : public AttributeValue {
|
||||
struct TORCH_API GraphAttr : public AttributeValue {
|
||||
using ConstructorType = std::shared_ptr<Graph>;
|
||||
using ValueType = std::shared_ptr<Graph>;
|
||||
GraphAttr(Symbol name, ConstructorType value_)
|
||||
@ -92,7 +92,7 @@ struct GraphAttr : public AttributeValue {
|
||||
ValueType& value() {
|
||||
return value_;
|
||||
}
|
||||
TORCH_API Ptr clone() const override;
|
||||
Ptr clone() const override;
|
||||
AttributeKind kind() const override {
|
||||
return AttributeKind::g;
|
||||
}
|
||||
@ -101,7 +101,7 @@ struct GraphAttr : public AttributeValue {
|
||||
std::shared_ptr<Graph> value_;
|
||||
};
|
||||
|
||||
struct GraphsAttr : public AttributeValue {
|
||||
struct TORCH_API GraphsAttr : public AttributeValue {
|
||||
using ConstructorType = std::vector<std::shared_ptr<Graph>>;
|
||||
using ValueType = std::vector<std::shared_ptr<Graph>>;
|
||||
GraphsAttr(Symbol name, ConstructorType value_)
|
||||
@ -112,7 +112,7 @@ struct GraphsAttr : public AttributeValue {
|
||||
AttributeKind kind() const override {
|
||||
return AttributeKind::gs;
|
||||
}
|
||||
TORCH_API std::unique_ptr<AttributeValue> clone() const override;
|
||||
std::unique_ptr<AttributeValue> clone() const override;
|
||||
|
||||
private:
|
||||
ValueType value_;
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/utils/disallow_copy.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
@ -11,11 +12,11 @@ namespace cpu {
|
||||
struct DynamicLibrary {
|
||||
TH_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);
|
||||
|
||||
DynamicLibrary(const char* name);
|
||||
TORCH_API DynamicLibrary(const char* name);
|
||||
|
||||
void* sym(const char* name);
|
||||
TORCH_API void* sym(const char* name);
|
||||
|
||||
~DynamicLibrary();
|
||||
TORCH_API ~DynamicLibrary();
|
||||
|
||||
private:
|
||||
void* handle = nullptr;
|
||||
|
@ -6,19 +6,19 @@ namespace jit {
|
||||
|
||||
static std::function<void(std::shared_ptr<script::Module> module)>
|
||||
emit_module_callback;
|
||||
TORCH_API void didFinishEmitModule(std::shared_ptr<script::Module> module) {
|
||||
void didFinishEmitModule(std::shared_ptr<script::Module> module) {
|
||||
if (emit_module_callback) {
|
||||
emit_module_callback(std::move(module));
|
||||
}
|
||||
}
|
||||
static std::function<void(std::shared_ptr<script::Function> fn)>
|
||||
emit_function_callback;
|
||||
TORCH_API void didFinishEmitFunction(std::shared_ptr<script::Function> fn) {
|
||||
void didFinishEmitFunction(std::shared_ptr<script::Function> fn) {
|
||||
if (emit_function_callback) {
|
||||
emit_function_callback(fn);
|
||||
}
|
||||
}
|
||||
TORCH_API void setEmitHooks(
|
||||
void setEmitHooks(
|
||||
std::function<void(std::shared_ptr<script::Module> module)> for_mod,
|
||||
std::function<void(std::shared_ptr<script::Function> for_fn)> for_fn) {
|
||||
emit_module_callback = std::move(for_mod);
|
||||
|
@ -43,6 +43,7 @@
|
||||
#include <torch/csrc/jit/script/python_tree_views.h>
|
||||
#include <torch/csrc/jit/tracer.h>
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
#include <caffe2/serialize/inline_container.h>
|
||||
|
||||
#include <ATen/core/function_schema.h>
|
||||
@ -84,7 +85,7 @@ void runJITCPPTests(bool runCuda) {
|
||||
AT_ERROR("JIT tests not yet supported on Windows");
|
||||
}
|
||||
#else
|
||||
void runJITCPPTests(bool runCuda);
|
||||
CAFFE2_API void runJITCPPTests(bool runCuda);
|
||||
#endif
|
||||
|
||||
void initJITBindings(PyObject* module) {
|
||||
|
@ -47,11 +47,11 @@ struct TORCH_API Code {
|
||||
};
|
||||
|
||||
struct InterpreterState {
|
||||
InterpreterState(const Code& code);
|
||||
void run(Stack& stack);
|
||||
TORCH_API InterpreterState(const Code& code);
|
||||
TORCH_API void run(Stack& stack);
|
||||
c10::intrusive_ptr<Future> runAsync(Stack& stack);
|
||||
c10::intrusive_ptr<Future> getFuture();
|
||||
~InterpreterState();
|
||||
TORCH_API ~InterpreterState();
|
||||
|
||||
private:
|
||||
InterpreterState(c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl);
|
||||
|
@ -33,10 +33,6 @@ static constexpr topo_position_t kMidPoint = 0;
|
||||
// - 2^(64-n) is the maximum number of appends to the end without reindex
|
||||
static constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */;
|
||||
|
||||
// Sigh, see
|
||||
// https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char
|
||||
constexpr Symbol PythonOp::Kind;
|
||||
|
||||
static void printValueRef(std::ostream& out, const Value* n) {
|
||||
out << "%" << n->uniqueName();
|
||||
}
|
||||
|
@ -231,7 +231,7 @@ struct Value {
|
||||
TORCH_API Value* copyMetadata(Value* from);
|
||||
};
|
||||
|
||||
struct Node {
|
||||
struct TORCH_API Node {
|
||||
TH_DISALLOW_COPY_AND_ASSIGN(Node);
|
||||
friend struct Graph;
|
||||
friend struct Block;
|
||||
@ -259,7 +259,7 @@ struct Node {
|
||||
topo_position_t topo_position_ = 0;
|
||||
|
||||
protected:
|
||||
TORCH_API Node(Graph* graph_, NodeKind kind_); // defined after graph
|
||||
Node(Graph* graph_, NodeKind kind_); // defined after graph
|
||||
public:
|
||||
// each node but Return/Param
|
||||
// is associated with exactly one place in the node list...
|
||||
@ -358,7 +358,7 @@ struct Node {
|
||||
return false;
|
||||
}
|
||||
|
||||
TORCH_API void replaceAllUsesWith(Node* n);
|
||||
void replaceAllUsesWith(Node* n);
|
||||
|
||||
// lots of things like chunk have a single input or single output, so we have
|
||||
// a helper to make accessing it easier
|
||||
@ -399,10 +399,10 @@ struct Node {
|
||||
bool is_constant(Symbol name) const {
|
||||
return static_cast<bool>(get(name));
|
||||
}
|
||||
TORCH_API bool mustBeNone() const;
|
||||
bool mustBeNone() const;
|
||||
|
||||
TORCH_API bool isNondeterministic() const;
|
||||
TORCH_API bool hasSideEffects() const;
|
||||
bool isNondeterministic() const;
|
||||
bool hasSideEffects() const;
|
||||
|
||||
// Graphs
|
||||
|
||||
@ -424,11 +424,11 @@ struct Node {
|
||||
// Given: %3 = f(%1, %2)
|
||||
// Execute: %3.addInput(%4)
|
||||
// Result: %3 = f(%1, %2, %4)
|
||||
TORCH_API Value* addInput(Value* value);
|
||||
Value* addInput(Value* value);
|
||||
|
||||
// Add 'value' as an input to 'this' at the specified position in the
|
||||
// arguments. Returns the added value for ease of chaining.
|
||||
TORCH_API Value* insertInput(size_t i, Value* value);
|
||||
Value* insertInput(size_t i, Value* value);
|
||||
|
||||
// Replace the input of 'this' at position 'i' with
|
||||
// 'newValue', returning the old node.
|
||||
@ -436,7 +436,7 @@ struct Node {
|
||||
// Given: %3 = f(%1, %2)
|
||||
// Execute: %3.replaceInput(1, %4)
|
||||
// Result: %3 = f(%1, %4)
|
||||
TORCH_API Value* replaceInput(size_t i, Value* newValue);
|
||||
Value* replaceInput(size_t i, Value* newValue);
|
||||
|
||||
// Replace all occurrences of 'from' in the inputs of this
|
||||
// node with 'to'. Corresponds to llvm's replaceUsesOfWith.
|
||||
@ -444,16 +444,16 @@ struct Node {
|
||||
// Given: %3 = f(%1, %2, %1)
|
||||
// Execute: %3.replaceInputWith(%1, %4)
|
||||
// Result: %3 = f(%4, %2, %4)
|
||||
TORCH_API void replaceInputWith(Value* from, Value* to);
|
||||
void replaceInputWith(Value* from, Value* to);
|
||||
|
||||
TORCH_API Value* addOutput();
|
||||
Value* addOutput();
|
||||
|
||||
TORCH_API Value* insertOutput(size_t i);
|
||||
Value* insertOutput(size_t i);
|
||||
|
||||
TORCH_API void eraseOutput(size_t i);
|
||||
void eraseOutput(size_t i);
|
||||
|
||||
TORCH_API Block* addBlock();
|
||||
TORCH_API void eraseBlock(size_t i);
|
||||
Block* addBlock();
|
||||
void eraseBlock(size_t i);
|
||||
|
||||
// Each Node can have a list of subblocks. These are used to define structured
|
||||
// nested control flow operators such as If and Loop.
|
||||
@ -482,10 +482,10 @@ struct Node {
|
||||
}
|
||||
|
||||
// Is 'this' before 'n' in the topological order?
|
||||
TORCH_API bool isBefore(const Node* n) const;
|
||||
bool isBefore(const Node* n) const;
|
||||
|
||||
// Is 'this' after 'n' in the topological order?
|
||||
TORCH_API bool isAfter(const Node* n) const;
|
||||
bool isAfter(const Node* n) const;
|
||||
|
||||
// Insert unattached 'this' node before 'n' in the topological order.
|
||||
// Returns this (for chaining).
|
||||
@ -497,7 +497,7 @@ struct Node {
|
||||
// Result: %3 = f(%1, %2)
|
||||
// %5 = h(%1)
|
||||
// %4 = g(%3)
|
||||
TORCH_API Node* insertBefore(Node* n);
|
||||
Node* insertBefore(Node* n);
|
||||
|
||||
// Insert unattached 'this' node after 'n' in the topological order.
|
||||
// Returns this (for chaining).
|
||||
@ -509,7 +509,7 @@ struct Node {
|
||||
// Result: %3 = f(%1, %2)
|
||||
// %4 = g(%3)
|
||||
// %5 = h(%1)
|
||||
TORCH_API Node* insertAfter(Node* n);
|
||||
Node* insertAfter(Node* n);
|
||||
|
||||
// Move 'this' (already in the graph) after 'n' in the topological order.
|
||||
//
|
||||
@ -522,7 +522,7 @@ struct Node {
|
||||
// Result: %3 = g(%1)
|
||||
// %2 = f(%1)
|
||||
//
|
||||
TORCH_API void moveAfter(Node* n);
|
||||
void moveAfter(Node* n);
|
||||
|
||||
// Move a node 'n' (already in the graph) before 'this' in the topological
|
||||
// order.
|
||||
@ -535,7 +535,7 @@ struct Node {
|
||||
// Execute: %3.moveBefore(%2)
|
||||
// Result: %3 = g(%1)
|
||||
// %2 = f(%1)
|
||||
TORCH_API void moveBefore(Node* n);
|
||||
void moveBefore(Node* n);
|
||||
|
||||
// Remove the input at 'i' from this node.
|
||||
//
|
||||
@ -545,14 +545,14 @@ struct Node {
|
||||
// Given: %3 = f(%1, %2)
|
||||
// Execute: %3.removeInput(1)
|
||||
// Result: %3 = f(%1)
|
||||
TORCH_API void removeInput(size_t i);
|
||||
void removeInput(size_t i);
|
||||
|
||||
// Remove all inputs from a node.
|
||||
//
|
||||
// Given: %3 = f(%1, %2)
|
||||
// Execute: %3.removeAllInputs()
|
||||
// Result: %3 = f()
|
||||
TORCH_API void removeAllInputs();
|
||||
void removeAllInputs();
|
||||
|
||||
// iterators of the node list starting at this node
|
||||
// useful for resuming a search starting at this node
|
||||
@ -577,7 +577,7 @@ struct Node {
|
||||
// %3 = g(%1)
|
||||
// Execute: %2.destroy()
|
||||
// Result: %3 = g(%1)
|
||||
TORCH_API void destroy();
|
||||
void destroy();
|
||||
|
||||
// Dynamically cast this node to the subclass indicated by the
|
||||
// template variable, returning nullptr if the cast is invalid..
|
||||
@ -604,11 +604,11 @@ struct Node {
|
||||
}
|
||||
|
||||
// XXX: this function is meant to be used with string literals only!
|
||||
TORCH_API bool matches(
|
||||
bool matches(
|
||||
const char* signature_literal,
|
||||
at::ArrayRef<Symbol> const_inputs = {}) const;
|
||||
|
||||
TORCH_API const FunctionSchema& schema() const {
|
||||
const FunctionSchema& schema() const {
|
||||
if (!schema_) {
|
||||
findSchema();
|
||||
}
|
||||
@ -782,15 +782,15 @@ struct Node {
|
||||
bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
|
||||
|
||||
std::pair<Value*, const Argument&> findInput(Symbol name);
|
||||
TORCH_API void findSchema() const;
|
||||
void findSchema() const;
|
||||
// Lookup iterator in use list of _input i_ that corresponds to its use of
|
||||
// _this_
|
||||
TORCH_API use_list::iterator findUseForInput(size_t i);
|
||||
use_list::iterator findUseForInput(size_t i);
|
||||
|
||||
// remove the use of input i, this sets input i to nullptr, but
|
||||
// is only used internally to Node before setting it to a new value
|
||||
// or erasing the entry from the list.
|
||||
TORCH_API Value* dropInput(size_t i);
|
||||
Value* dropInput(size_t i);
|
||||
|
||||
bool inBlockList() const {
|
||||
if (next() == nullptr) {
|
||||
@ -799,8 +799,8 @@ struct Node {
|
||||
return next() != nullptr;
|
||||
}
|
||||
|
||||
TORCH_API void removeFromList();
|
||||
TORCH_API void lint() const;
|
||||
void removeFromList();
|
||||
void lint() const;
|
||||
|
||||
void assignTopoPosition();
|
||||
|
||||
@ -818,7 +818,7 @@ struct Node {
|
||||
// 'this' will be allocated with s->allocNewInstance(g) so it should have
|
||||
// the same concrete type as 's'
|
||||
//
|
||||
TORCH_API virtual void cloneFrom(Node* s);
|
||||
virtual void cloneFrom(Node* s);
|
||||
};
|
||||
|
||||
struct Block {
|
||||
@ -1262,9 +1262,7 @@ struct ProfileOp : public Node {
|
||||
// which is not included in libtorch.so. We still include some bits and pieces
|
||||
// of PythonOp here to enable writing simple passes generically. In general,
|
||||
// python-aware bits need to be moved to the descendant classes.
|
||||
struct PythonOp : public Node {
|
||||
static constexpr Symbol Kind = ::c10::prim::PythonOp;
|
||||
|
||||
struct TORCH_API PythonOp : public Node {
|
||||
using Node::Node;
|
||||
|
||||
// should this Python function be skipped over when exported (i.e. for
|
||||
|
@ -2,6 +2,8 @@
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include <caffe2/proto/caffe2_pb.h>
|
||||
#include <torch/csrc/jit/ir.h>
|
||||
#include <unordered_map>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
@ -15,7 +16,7 @@ namespace jit {
|
||||
* \p Prefix can be used for appending some string to every operator name (e.g.
|
||||
* we can add "caffe2::").
|
||||
*/
|
||||
void convertNetDefToIR(
|
||||
TORCH_API void convertNetDefToIR(
|
||||
const caffe2::NetDef& net,
|
||||
Graph* graph,
|
||||
std::unordered_map<std::string, Value*>* valueMapPtr = nullptr,
|
||||
@ -34,7 +35,7 @@ void convertNetDefToIR(
|
||||
* TODO: We might need to do a better job at preserving names of the variables,
|
||||
* especially external_inputs/external_outputs.
|
||||
*/
|
||||
void convertIRToNetDef(
|
||||
TORCH_API void convertIRToNetDef(
|
||||
caffe2::NetDef* net,
|
||||
const Graph& graph,
|
||||
const std::string& prefix = "");
|
||||
|
@ -19,9 +19,9 @@ namespace jit {
|
||||
// A pass modifies a Graph in place.
|
||||
using Pass = std::function<void(std::shared_ptr<Graph>&)>;
|
||||
|
||||
std::vector<Pass>& getCustomPasses();
|
||||
TORCH_API std::vector<Pass>& getCustomPasses();
|
||||
|
||||
struct RegisterPass {
|
||||
struct TORCH_API RegisterPass {
|
||||
RegisterPass(Pass p);
|
||||
};
|
||||
|
||||
|
@ -1265,7 +1265,7 @@ c10::optional<const Node*> AliasDb::getLastWildcard() const {
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_API bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
|
||||
bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
|
||||
// WARNING: by adding a case to this list, you are asserting that you have
|
||||
// added a case for the unschematized node in AliasDb::analyze
|
||||
const static std::unordered_set<Symbol> handled = {
|
||||
|
@ -42,7 +42,7 @@ class AliasDb {
|
||||
|
||||
// Does `n` write to an alias of one of the values in `vs`?
|
||||
// if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks
|
||||
bool writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks = false)
|
||||
TORCH_API bool writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks = false)
|
||||
const;
|
||||
|
||||
// Does `a` and `b` potentially share a memory location or do either
|
||||
@ -56,7 +56,7 @@ class AliasDb {
|
||||
const at::ArrayRef<Value*>& b) const;
|
||||
|
||||
// Do `a` and `b` potentially share a memory location?
|
||||
bool mayAlias(const Value* a, const Value* b) const;
|
||||
TORCH_API bool mayAlias(const Value* a, const Value* b) const;
|
||||
// Do any values in group `a` potentially share a memory location with any
|
||||
// value in group `b`? i.e. may they overlap?
|
||||
//
|
||||
@ -104,7 +104,7 @@ class AliasDb {
|
||||
}
|
||||
|
||||
// Do any nodes write to an alias set inputed/outputed by `n`?
|
||||
bool hasWriters(const Node* n) const;
|
||||
TORCH_API bool hasWriters(const Node* n) const;
|
||||
|
||||
// Move 'n' (already in the graph) after 'movePoint' in the topological order.
|
||||
//
|
||||
@ -115,8 +115,8 @@ class AliasDb {
|
||||
//
|
||||
// Returns `false` if it's impossible to move `n` after `MovePoint` without
|
||||
// violating dependencies, otherwise executes the move and returns `true`
|
||||
bool moveAfterTopologicallyValid(Node* n, Node* movePoint);
|
||||
bool moveBeforeTopologicallyValid(Node* n, Node* movePoint);
|
||||
TORCH_API bool moveAfterTopologicallyValid(Node* n, Node* movePoint);
|
||||
TORCH_API bool moveBeforeTopologicallyValid(Node* n, Node* movePoint);
|
||||
|
||||
bool couldMoveAfterTopologically(Node* n, Node* movePoint);
|
||||
bool couldMoveBeforeTopologically(Node* n, Node* movePoint);
|
||||
@ -146,7 +146,7 @@ class AliasDb {
|
||||
void getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks = false) const;
|
||||
void getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
|
||||
// Do any nodes write to `v`s memory location?
|
||||
bool hasWriters(const Value* v) const;
|
||||
TORCH_API bool hasWriters(const Value* v) const;
|
||||
// Register the fact that `n` writes to `v`.
|
||||
void registerWrite(const Value* v, Node* n);
|
||||
// Get all the values that `n` reads from.
|
||||
@ -163,7 +163,7 @@ class AliasDb {
|
||||
* Wildcard methods
|
||||
*/
|
||||
// is `v` a wildcard?
|
||||
bool isWildcard(const Value* v) const;
|
||||
TORCH_API bool isWildcard(const Value* v) const;
|
||||
// Register `v` as a wildcard value.
|
||||
void setWildcard(const Value* v);
|
||||
// Get all nodes that write to a wildcard value.
|
||||
|
@ -855,7 +855,7 @@ struct PythonPrintPass {
|
||||
// Prints the RHS value of a Node, e.g. `aten.add(x, y)`
|
||||
void printRHS(std::ostream& stmt, Node* node) {
|
||||
switch (node->kind()) {
|
||||
case PythonOp::Kind: {
|
||||
case prim::PythonOp: {
|
||||
auto value = static_cast<const PythonOp*>(node);
|
||||
if (enforce_importable_) {
|
||||
throw script::ErrorReport(node->getSourceLocation())
|
||||
@ -1111,7 +1111,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
TORCH_API void PythonPrint(
|
||||
void PythonPrint(
|
||||
std::ostream& out,
|
||||
const script::Function& func,
|
||||
bool is_method,
|
||||
@ -1124,7 +1124,7 @@ TORCH_API void PythonPrint(
|
||||
pp.print(out);
|
||||
}
|
||||
|
||||
TORCH_API void PythonPrint(
|
||||
void PythonPrint(
|
||||
std::ostream& out,
|
||||
const script::CompilationUnit& cu,
|
||||
bool is_method,
|
||||
@ -1137,7 +1137,7 @@ TORCH_API void PythonPrint(
|
||||
pp.print(out);
|
||||
}
|
||||
|
||||
TORCH_API void PythonPrint(
|
||||
void PythonPrint(
|
||||
std::ostream& out,
|
||||
const ClassTypePtr& classType,
|
||||
std::vector<at::Tensor>& tensor_table,
|
||||
@ -1148,7 +1148,7 @@ TORCH_API void PythonPrint(
|
||||
pp.print(out);
|
||||
}
|
||||
|
||||
TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {
|
||||
bool printerHasSpecialCaseFor(Symbol sym) {
|
||||
// WARNING: by adding a value to this set, you are asserting
|
||||
// that you have also added special handling of this symbol to
|
||||
// the printer above. Not adding handling will cause import and export
|
||||
|
@ -6,6 +6,8 @@
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
@ -30,17 +32,17 @@ struct Value;
|
||||
class MemoryDAG {
|
||||
public:
|
||||
// Make `from` point at `to`.
|
||||
void makePointerTo(Element* from, Element* to);
|
||||
TORCH_API void makePointerTo(Element* from, Element* to);
|
||||
|
||||
void addToContainedElements(Element* contained, Element* container);
|
||||
|
||||
// Make a fresh element (i.e. an element that doesn't point to anything) and
|
||||
// return it.
|
||||
Element* makeFreshValue(const Value* v);
|
||||
TORCH_API Element* makeFreshValue(const Value* v);
|
||||
|
||||
// Do `a` and `b` potentially share a memory location?
|
||||
bool mayAlias(const Element* a, const Element* b) const;
|
||||
bool mayAlias(Element* a, Element* b) const;
|
||||
TORCH_API bool mayAlias(Element* a, Element* b) const;
|
||||
|
||||
// Does a hold reference to any memory that is stored in elem, or vice versa?
|
||||
bool mayContainAlias(const Element* a, const Element* b) const;
|
||||
@ -124,7 +126,7 @@ struct Element {
|
||||
std::unordered_set<Element*> contained_elements;
|
||||
|
||||
// Return the unique memory locations that `Element` might represent.
|
||||
std::unordered_set<const Element*> getMemoryLocations() const;
|
||||
TORCH_API std::unordered_set<const Element*> getMemoryLocations() const;
|
||||
// We do path compression to make repeated memory location queries faster.
|
||||
// An empty cache means it is invalidated (it can never be empty otherwise,
|
||||
// since every element must point to at least one memory location).
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/ir.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
@ -17,16 +18,16 @@ namespace SubgraphUtils {
|
||||
// `n` is destroyed.
|
||||
//
|
||||
// Returns the new subgraph node.
|
||||
Node* createSingletonSubgraph(Node* n, Symbol subgraphKind);
|
||||
TORCH_API Node* createSingletonSubgraph(Node* n, Symbol subgraphKind);
|
||||
|
||||
// Merge a node into a subgraph node. If `toMerge` is also a subgraph, the
|
||||
// subgraphs are merged.
|
||||
// `toMerge` is destroyed.
|
||||
void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode);
|
||||
TORCH_API void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode);
|
||||
|
||||
// Move nodes from a subgraph node to the outer graph.
|
||||
// `subgraphNode` is destroyed.
|
||||
void unmergeSubgraph(Node* subgraphNode);
|
||||
TORCH_API void unmergeSubgraph(Node* subgraphNode);
|
||||
|
||||
// Convenience function
|
||||
std::shared_ptr<Graph> getSubgraph(Node* n);
|
||||
|
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
@ -21,7 +22,7 @@ struct ProfilingRecord {
|
||||
ProfilingRecord(const ProfilingRecord&) = delete;
|
||||
ProfilingRecord(ProfilingRecord&&) noexcept = delete;
|
||||
static ProfiledTensorTypePtr toProfiledTensorTypePtr(const IValue& ival);
|
||||
static std::unique_ptr<ProfilingRecord> instrumentGraph(
|
||||
TORCH_API static std::unique_ptr<ProfilingRecord> instrumentGraph(
|
||||
const std::shared_ptr<Graph>& graph);
|
||||
|
||||
std::shared_ptr<Graph> profiled_graph_;
|
||||
|
@ -19,6 +19,8 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
Symbol ConcretePythonOp::Kind = prim::PythonOp;
|
||||
|
||||
using c10::Type;
|
||||
|
||||
std::string getPythonName(const PyObject* obj_) {
|
||||
|
@ -11,6 +11,8 @@ void initPythonIRBindings(PyObject* module);
|
||||
// execute a Python function, used for Ops we can't optimize but that we want to
|
||||
// optimize around
|
||||
struct ConcretePythonOp : public PythonOp {
|
||||
static Symbol Kind;
|
||||
|
||||
ConcretePythonOp(Graph* graph) : PythonOp(graph, ::c10::prim::PythonOp) {}
|
||||
ConcretePythonOp* init(
|
||||
THPObjectPtr&& pyobj,
|
||||
|
@ -100,7 +100,7 @@ struct BuiltinFunctionRegistry {
|
||||
std::unordered_map<Symbol, std::vector<Function*>> builtins_by_name;
|
||||
};
|
||||
|
||||
TORCH_API const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
|
||||
const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
|
||||
static BuiltinFunctionRegistry registry;
|
||||
return registry.getAllBuiltinFunctionsFor(name);
|
||||
}
|
||||
|
9
torch/csrc/jit/script/jit_exception.cpp
Normal file
9
torch/csrc/jit/script/jit_exception.cpp
Normal file
@ -0,0 +1,9 @@
|
||||
#include <torch/csrc/jit/script/jit_exception.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
JITException::JITException(const std::string& msg) : std::runtime_error(msg) {}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
@ -2,12 +2,13 @@
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
struct JITException : public std::runtime_error {
|
||||
JITException() = default;
|
||||
explicit JITException(const std::string& msg) : std::runtime_error(msg) {}
|
||||
struct TORCH_API JITException : public std::runtime_error {
|
||||
explicit JITException(const std::string& msg);
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include <c10/util/C++17.h>
|
||||
#include <torch/csrc/jit/source_range.h>
|
||||
#include <torch/csrc/jit/script/strtod.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <ATen/core/Macros.h>
|
||||
#include <algorithm>
|
||||
#include <clocale>
|
||||
|
@ -1,4 +1,4 @@
|
||||
#include "torch/csrc/jit/script/logging.h"
|
||||
#include <torch/csrc/jit/script/logging.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
@ -17,7 +17,7 @@ void LockingLogger::addStatValue(const std::string& stat_name, int64_t val) {
|
||||
raw_counter.count++;
|
||||
}
|
||||
|
||||
TORCH_API int64_t LockingLogger::getCounterValue(const std::string& name) const {
|
||||
int64_t LockingLogger::getCounterValue(const std::string& name) const {
|
||||
std::unique_lock<std::mutex> lk(m);
|
||||
if (!raw_counters.count(name)) {
|
||||
return 0;
|
||||
|
@ -37,12 +37,12 @@ class NoopLogger : public LoggerBase {
|
||||
//
|
||||
// NOTE: this is not written in a scalable way and should probably only be used
|
||||
// in the single-threaded case or for testing.
|
||||
class LockingLogger : public LoggerBase {
|
||||
class TORCH_API LockingLogger : public LoggerBase {
|
||||
public:
|
||||
TORCH_API void addStatValue(const std::string& stat_name, int64_t val) override;
|
||||
TORCH_API virtual int64_t getCounterValue(const std::string& name) const;
|
||||
void addStatValue(const std::string& stat_name, int64_t val) override;
|
||||
virtual int64_t getCounterValue(const std::string& name) const;
|
||||
enum class AggregationType { SUM, AVG };
|
||||
TORCH_API void setAggregationType(
|
||||
void setAggregationType(
|
||||
const std::string& stat_name,
|
||||
AggregationType type);
|
||||
~LockingLogger() {}
|
||||
|
@ -22,7 +22,7 @@ namespace script {
|
||||
|
||||
enum NoneStatus { ALWAYS, MAYBE, NEVER };
|
||||
|
||||
struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
|
||||
struct TORCH_API SugaredValue : public std::enable_shared_from_this<SugaredValue> {
|
||||
// what is this node? for error reporting (e.g. Module, python function)
|
||||
virtual std::string kind() const = 0;
|
||||
|
||||
|
@ -39,7 +39,7 @@ void badArgType(const T& v) {
|
||||
thread_local std::shared_ptr<TracingState> tracing_state;
|
||||
} // namespace detail
|
||||
|
||||
TORCH_API std::function<void()> pauseTracing() {
|
||||
std::function<void()> pauseTracing() {
|
||||
// NOLINTNEXTLINE
|
||||
std::shared_ptr<tracer::TracingState> state = getTracingState();
|
||||
tracer::setTracingState(nullptr);
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include <torch/csrc/generic/utils.cpp>
|
||||
#include <TH/THGenerateHalfType.h>
|
||||
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <torch/csrc/generic/utils.cpp>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/functional.h>
|
||||
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <utility>
|
||||
|
||||
@ -59,16 +59,16 @@ struct TensorGroup {
|
||||
// enough tensors for all data types until the size_limit, and then split
|
||||
// the accumulated tensors into different groups by data types, therefore:
|
||||
// it will output: {{tensor_a}, {tensor_b}, {tensor_c}}
|
||||
std::vector<TensorGroup> take_tensors(
|
||||
TORCH_API std::vector<TensorGroup> take_tensors(
|
||||
at::TensorList tensors,
|
||||
size_t size_limit,
|
||||
bool fine_grained = false);
|
||||
|
||||
void reorder_tensors_like(std::vector<at::Tensor>& tensors, at::TensorList order);
|
||||
TORCH_API void reorder_tensors_like(std::vector<at::Tensor>& tensors, at::TensorList order);
|
||||
|
||||
std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(at::TensorList tensors);
|
||||
TORCH_API std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(at::TensorList tensors);
|
||||
|
||||
std::vector<at::Tensor> unflatten_sparse_tensors(
|
||||
TORCH_API std::vector<at::Tensor> unflatten_sparse_tensors(
|
||||
const at::Tensor& flat_indices,
|
||||
const at::Tensor& flat_values,
|
||||
at::TensorList tensors);
|
||||
|
Reference in New Issue
Block a user