Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
C++ changes toward libtorch and libcaffe2 unification (#19554)
Summary:
* adds TORCH_API and AT_CUDA_API in places
* refactor code generation Python logic to separate caffe2/torch outputs
* fix hip and asan
* remove profiler_cuda from hip
* fix gcc warnings for enums
* Fix PythonOp::Kind

Pull Request resolved: https://github.com/pytorch/pytorch/pull/19554
Differential Revision: D15082727
Pulled By: kostmo
fbshipit-source-id: 83a8a99717f025ab44b29608848928d76b3147a4
Committed by: Facebook Github Bot
Parent: 9d180e602f
Commit: 8f0603b128
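Most of the C++ changes in this commit attach export annotations (TORCH_API, AT_CUDA_API, CAFFE2_API) to classes and functions so their symbols remain visible across the shared-library boundary once the libtorch and libcaffe2 builds are unified. As a rough sketch of what such an export macro does (MYLIB_API and MYLIB_BUILD_MAIN_LIB below are placeholder names, not the actual PyTorch macros): on Windows the symbol must be dllexport while the library itself is built and dllimport for its consumers, while GCC/Clang use default visibility.

// Illustrative sketch only; MYLIB_API / MYLIB_BUILD_MAIN_LIB are placeholder names.
#if defined(_WIN32)
#  if defined(MYLIB_BUILD_MAIN_LIB)   // defined only while compiling the library itself
#    define MYLIB_API __declspec(dllexport)
#  else
#    define MYLIB_API __declspec(dllimport)
#  endif
#elif defined(__GNUC__)
#  define MYLIB_API __attribute__((__visibility__("default")))
#else
#  define MYLIB_API
#endif

// Anything that must be reachable from outside the shared library carries the macro:
struct MYLIB_API AnomalyMode {
  static bool is_enabled();
};

The first two hunks below extend this kind of conditional for AT_CUDA_API so that HIP builds (caffe2_hip_EXPORTS) export the same symbols as CUDA builds.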
@@ -6,7 +6,7 @@
 #ifdef _WIN32
 #if !defined(AT_CORE_STATIC_WINDOWS)
-# if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS) || defined(CAFFE2_CUDA_BUILD_MAIN_LIB)
+# if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS) || defined(caffe2_hip_EXPORTS) || defined(CAFFE2_CUDA_BUILD_MAIN_LIB)
 # define AT_CUDA_API __declspec(dllexport)
 # else
 # define AT_CUDA_API __declspec(dllimport)
@@ -15,7 +15,7 @@
 # define AT_CUDA_API
 #endif
 #elif defined(__GNUC__)
-#if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS)
+#if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS) || defined(caffe2_hip_EXPORTS)
 #define AT_CUDA_API __attribute__((__visibility__("default")))
 #else
 #define AT_CUDA_API
@@ -22,7 +22,7 @@
 #endif
 
 // ROCM hcc doesn't work well with using std:: in kernel functions
-#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__)
+#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
 #include <c10/cuda/CUDAMathCompat.h>
 #define compat_pow c10::cuda::compat::pow
 #else
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/bin/bash -xe
 ##############################################################################
 # Example command to build the iOS target.
 ##############################################################################
@@ -7,8 +7,6 @@
 # using ios-cmake. This is very similar to the android-cmake - see
 # build_android.sh for more details.
 
-set -e
-
 CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)"
 
 # Build protobuf from third_party so we have a host protoc binary.
@@ -2,6 +2,8 @@
 #include <gtest/gtest.h>
 #endif
 
+#include <c10/macros/Export.h>
+
 // To add a new test file:
 // 1. Add a test_foo.h file in this directory
 // 2. include test_base.h
@@ -194,14 +194,41 @@ def gen_autograd(aten_path, out, autograd_dir):
     gen_variable_type(out, aten_decls, template_path)
 
     # Generate Functions.h/cpp
-    from .gen_autograd_functions import gen_autograd_functions
-    gen_autograd_functions(
+    from .gen_autograd_functions import gen_autograd_functions_lib
+    gen_autograd_functions_lib(
         out, autograd_functions, template_path)
 
     # Load deprecated signatures
     deprecated = load_deprecated_signatures(
         aten_decls, os.path.join(autograd_dir, 'deprecated.yaml'))
 
+    # Generate variable_factories.h
+    from .gen_variable_factories import gen_variable_factories
+    gen_variable_factories(out, aten_decls, template_path)
+
+
+def gen_autograd_python(aten_path, out, autograd_dir):
+
+    # TODO Deduplicate these four variable assignments
+
+    aten_decls = load_aten_declarations(aten_path)
+
+    # Parse and load derivatives.yaml
+    from .load_derivatives import load_derivatives
+    autograd_functions = load_derivatives(
+        os.path.join(autograd_dir, 'derivatives.yaml'), aten_decls)
+
+    template_path = os.path.join(autograd_dir, 'templates')
+
+    # Load deprecated signatures
+    deprecated = load_deprecated_signatures(
+        aten_decls, os.path.join(autograd_dir, 'deprecated.yaml'))
+
+    # Generate Functions.h/cpp
+    from .gen_autograd_functions import gen_autograd_functions_python
+    gen_autograd_functions_python(
+        out, autograd_functions, template_path)
+
     # Generate Python bindings
     from . import gen_python_functions
     gen_python_functions.gen_py_variable_methods(
@@ -211,10 +238,6 @@ def gen_autograd(aten_path, out, autograd_dir):
     gen_python_functions.gen_py_nn_functions(
         out, aten_decls, template_path)
 
-    # Generate variable_factories.h
-    from .gen_variable_factories import gen_variable_factories
-    gen_variable_factories(out, aten_decls, template_path)
-
 
 def main():
     parser = argparse.ArgumentParser(
@@ -4,13 +4,14 @@
 # Functions.h/cpp: subclasses of autograd::Function
 # python_functions.h/cpp: Python bindings for the above classes
 #
+import os
 import re
 from .utils import nested_dict, CodeTemplate, write
 from .gen_autograd import VIEW_FUNCTIONS
 from .utils import IDENT_REGEX
 
 FUNCTION_DECLARATION = CodeTemplate("""\
-struct ${op} : public ${superclass} {
+struct TORCH_API ${op} : public ${superclass} {
   using ${superclass}::${superclass};
   variable_list apply(variable_list&& grads) override;
   std::string name() const override { return "${op}"; }
@@ -81,18 +82,21 @@ if (should_compute_output({ ${idx_ranges} })) {
 UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
 
 
-def gen_autograd_functions(out, autograd_functions, template_path):
+def gen_autograd_functions_lib(out, autograd_functions, template_path):
+    gen_autograd_functions(out, autograd_functions, template_path, "Functions")
+
+
+def gen_autograd_functions_python(out, autograd_functions, template_path):
+    gen_autograd_functions(out, autograd_functions, template_path, "python_functions")
+
+
+def gen_autograd_functions(out, autograd_functions, template_path, file_basename):
     """Functions.h and Functions.cpp body
 
     These contain the auto-generated subclasses of torch::autograd::Function
    for each every differentiable torch function.
     """
 
-    FUNCTIONS_H = CodeTemplate.from_file(template_path + '/Functions.h')
-    FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/Functions.cpp')
-    PY_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_functions.h')
-    PY_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_functions.cpp')
-
     function_definitions = []
     function_declarations = []
     py_function_initializers = []
@@ -110,10 +114,10 @@ def gen_autograd_functions(out, autograd_functions, template_path):
         'py_function_initializers': py_function_initializers,
     }
 
-    write(out, 'Functions.h', FUNCTIONS_H, top_env)
-    write(out, 'Functions.cpp', FUNCTIONS_CPP, top_env)
-    write(out, 'python_functions.h', PY_FUNCTIONS_H, top_env)
-    write(out, 'python_functions.cpp', PY_FUNCTIONS_CPP, top_env)
+    for suffix in [".h", ".cpp"]:
+        f = file_basename + suffix
+        templated_output = CodeTemplate.from_file(os.path.join(template_path, f))
+        write(out, f, templated_output, top_env)
 
 
 def process_function(func):
@@ -10,6 +10,7 @@
 #include "torch/csrc/autograd/function.h"
 #include "torch/csrc/autograd/variable.h"
 #include "torch/csrc/autograd/saved_variable.h"
+#include <torch/csrc/WindowsTorchApiMacro.h>
 
 namespace torch { namespace autograd { namespace generated {
 
@@ -38,6 +38,7 @@ libtorch_sources = [
     "torch/csrc/autograd/anomaly_mode.cpp",
     "torch/csrc/autograd/engine.cpp",
     "torch/csrc/autograd/function.cpp",
+    "torch/csrc/autograd/function_hook.cpp",
     "torch/csrc/autograd/functions/accumulate_grad.cpp",
     "torch/csrc/autograd/functions/basic_ops.cpp",
     "torch/csrc/autograd/functions/tensor.cpp",
@@ -107,6 +108,7 @@ libtorch_sources = [
     "torch/csrc/jit/script/schema_matching.cpp",
     "torch/csrc/jit/script/class_type.cpp",
     "torch/csrc/jit/script/parser.cpp",
+    "torch/csrc/jit/script/jit_exception.cpp",
     "torch/csrc/jit/testing/file_check.cpp",
    "torch/csrc/jit/import_source.cpp",
     "torch/csrc/jit/hooks_for_testing.cpp",
@@ -19,59 +19,33 @@ def all_generator_source():
     return sorted(r)
 
 
-inputs = [
-    'torch/lib/THNN.h',
-    'torch/lib/THCUNN.h',
-    'torch/share/ATen/Declarations.yaml',
-    'tools/autograd/derivatives.yaml',
-    'tools/autograd/deprecated.yaml',
-]
-
-outputs = [
-    'torch/csrc/autograd/generated/Functions.cpp',
-    'torch/csrc/autograd/generated/Functions.h',
-    'torch/csrc/autograd/generated/python_functions.cpp',
-    'torch/csrc/autograd/generated/python_functions.h',
-    'torch/csrc/autograd/generated/python_nn_functions.cpp',
-    'torch/csrc/autograd/generated/python_nn_functions.h',
-    'torch/csrc/autograd/generated/python_nn_functions_dispatch.h',
-    'torch/csrc/autograd/generated/python_variable_methods.cpp',
-    'torch/csrc/autograd/generated/python_variable_methods_dispatch.h',
-    'torch/csrc/autograd/generated/variable_factories.h',
-    'torch/csrc/autograd/generated/VariableType_0.cpp',
-    'torch/csrc/autograd/generated/VariableType_1.cpp',
-    'torch/csrc/autograd/generated/VariableType_2.cpp',
-    'torch/csrc/autograd/generated/VariableType_3.cpp',
-    'torch/csrc/autograd/generated/VariableType_4.cpp',
-    'torch/csrc/autograd/generated/VariableType.h',
-    'torch/csrc/jit/generated/register_aten_ops_0.cpp',
-    'torch/csrc/jit/generated/register_aten_ops_1.cpp',
-    'torch/csrc/jit/generated/register_aten_ops_2.cpp',
-]
-
-
 def generate_code(ninja_global=None,
                   declarations_path=None,
                   nn_path=None,
-                  install_dir=None):
+                  install_dir=None,
+                  subset=None):
     # cwrap depends on pyyaml, so we can't import it earlier
     root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
     sys.path.insert(0, root)
-    from tools.autograd.gen_autograd import gen_autograd
+    from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python
     from tools.jit.gen_jit_dispatch import gen_jit_dispatch
 
-    from tools.nnwrap import generate_wrappers as generate_nn_wrappers
-
-    # Build THNN/THCUNN.cwrap and then THNN/THCUNN.cpp. These are primarily
-    # used by the legacy NN bindings.
-    generate_nn_wrappers(nn_path, install_dir, 'tools/cwrap/plugins/templates')
-
     # Build ATen based Variable classes
     autograd_gen_dir = install_dir or 'torch/csrc/autograd/generated'
     jit_gen_dir = install_dir or 'torch/csrc/jit/generated'
     for d in (autograd_gen_dir, jit_gen_dir):
         if not os.path.exists(d):
             os.makedirs(d)
 
-    gen_autograd(declarations_path or DECLARATIONS_PATH, autograd_gen_dir, 'tools/autograd')
-    gen_jit_dispatch(declarations_path or DECLARATIONS_PATH, jit_gen_dir, 'tools/jit/templates')
+    if subset == "pybindings" or not subset:
+        # Build THNN/THCUNN.cwrap and then THNN/THCUNN.cpp. These are primarily
+        # used by the legacy NN bindings.
+        from tools.nnwrap import generate_wrappers as generate_nn_wrappers
+        generate_nn_wrappers(nn_path, install_dir, 'tools/cwrap/plugins/templates')
+
+        gen_autograd_python(declarations_path or DECLARATIONS_PATH, autograd_gen_dir, 'tools/autograd')
+
+    if subset == "libtorch" or not subset:
+        gen_autograd(declarations_path or DECLARATIONS_PATH, autograd_gen_dir, 'tools/autograd')
+        gen_jit_dispatch(declarations_path or DECLARATIONS_PATH, jit_gen_dir, 'tools/jit/templates')
@@ -82,11 +56,18 @@ def main():
     parser.add_argument('--nn-path')
     parser.add_argument('--ninja-global')
     parser.add_argument('--install_dir')
+    parser.add_argument(
+        '--subset',
+        help='Subset of source files to generate. Can be "libtorch" or "pybindings". Generates both when omitted.'
+    )
     options = parser.parse_args()
-    generate_code(options.ninja_global,
-                  options.declarations_path,
-                  options.nn_path,
-                  options.install_dir)
+    generate_code(
+        options.ninja_global,
+        options.declarations_path,
+        options.nn_path,
+        options.install_dir,
+        options.subset,
+    )
 
 
 if __name__ == "__main__":
@@ -104,6 +104,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/autograd/anomaly_mode.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/engine.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/function.cpp
+  ${TORCH_SRC_DIR}/csrc/autograd/function_hook.cpp
  ${TORCH_SRC_DIR}/csrc/autograd/functions/accumulate_grad.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/functions/basic_ops.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/functions/tensor.cpp
@@ -190,6 +191,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/script/jit_exception.cpp
   ${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
   ${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp
   ${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
@@ -230,6 +232,14 @@ if (USE_CUDA)
   )
 endif()
 
+
+if (USE_ROCM)
+  list(APPEND TORCH_SRCS
+    ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
+  )
+endif()
+
+
 if (NOT NO_API)
   list(APPEND TORCH_SRCS
     ${TORCH_SRC_DIR}/csrc/api/src/cuda.cpp
@@ -543,7 +553,6 @@ if (BUILD_PYTHON)
     ${TORCH_SRC_DIR}/csrc/utils/structseq.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_apply.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_dtypes.cpp
-    ${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_layouts.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_list.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_new.cpp
@@ -615,7 +624,6 @@ if (BUILD_PYTHON)
     ${TORCH_SRC_DIR}/csrc/cuda/Stream.cpp
     ${TORCH_SRC_DIR}/csrc/cuda/Event.cpp
     ${TORCH_SRC_DIR}/csrc/cuda/utils.cpp
-    ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
     ${TORCH_SRC_DIR}/csrc/cuda/python_comm.cpp
     ${TORCH_SRC_DIR}/csrc/cuda/serialization.cpp
     ${TORCH_SRC_DIR}/csrc/nn/THCUNN.cpp
@@ -662,7 +670,6 @@ if (BUILD_PYTHON)
     ${TORCH_SRC_DIR}/csrc/cuda/Stream.cpp
     ${TORCH_SRC_DIR}/csrc/cuda/Event.cpp
     ${TORCH_SRC_DIR}/csrc/cuda/utils.cpp
-    ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
     ${TORCH_SRC_DIR}/csrc/cuda/python_comm.cpp
     ${TORCH_SRC_DIR}/csrc/cuda/serialization.cpp
     ${TORCH_SRC_DIR}/csrc/nn/THCUNN.cpp
|
|||||||
/// Resets the `Sampler`'s internal state.
|
/// Resets the `Sampler`'s internal state.
|
||||||
/// Typically called before a new epoch.
|
/// Typically called before a new epoch.
|
||||||
/// Optionally, accepts a new size when reseting the sampler.
|
/// Optionally, accepts a new size when reseting the sampler.
|
||||||
TORCH_API virtual void reset(optional<size_t> new_size) = 0;
|
virtual void reset(optional<size_t> new_size) = 0;
|
||||||
|
|
||||||
/// Returns the next index if possible, or an empty optional if the
|
/// Returns the next index if possible, or an empty optional if the
|
||||||
/// sampler is exhausted for this epoch.
|
/// sampler is exhausted for this epoch.
|
||||||
TORCH_API virtual optional<BatchRequest> next(size_t batch_size) = 0;
|
virtual optional<BatchRequest> next(size_t batch_size) = 0;
|
||||||
|
|
||||||
/// Serializes the `Sampler` to the `archive`.
|
/// Serializes the `Sampler` to the `archive`.
|
||||||
TORCH_API virtual void save(serialize::OutputArchive& archive) const = 0;
|
virtual void save(serialize::OutputArchive& archive) const = 0;
|
||||||
|
|
||||||
/// Deserializes the `Sampler` from the `archive`.
|
/// Deserializes the `Sampler` from the `archive`.
|
||||||
TORCH_API virtual void load(serialize::InputArchive& archive) = 0;
|
virtual void load(serialize::InputArchive& archive) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace samplers
|
} // namespace samplers
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ namespace samplers {
 template <typename BatchRequest = std::vector<size_t>>
 class DistributedSampler : public Sampler<BatchRequest> {
  public:
-  TORCH_API DistributedSampler(
+  DistributedSampler(
       size_t size,
       size_t num_replicas = 1,
       size_t rank = 0,
@@ -64,28 +64,28 @@ class DistributedSampler : public Sampler<BatchRequest> {
 
 /// Select samples randomly. The sampling order is shuffled at each `reset()`
 /// call.
-class DistributedRandomSampler : public DistributedSampler<> {
+class TORCH_API DistributedRandomSampler : public DistributedSampler<> {
  public:
-  TORCH_API DistributedRandomSampler(
+  DistributedRandomSampler(
       size_t size,
       size_t num_replicas = 1,
       size_t rank = 0,
       bool allow_duplicates = true);
 
   /// Resets the `DistributedRandomSampler` to a new set of indices.
-  TORCH_API void reset(optional<size_t> new_size = nullopt) override;
+  void reset(optional<size_t> new_size = nullopt) override;
 
   /// Returns the next batch of indices.
-  TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
+  optional<std::vector<size_t>> next(size_t batch_size) override;
 
   /// Serializes the `DistributedRandomSampler` to the `archive`.
-  TORCH_API void save(serialize::OutputArchive& archive) const override;
+  void save(serialize::OutputArchive& archive) const override;
 
   /// Deserializes the `DistributedRandomSampler` from the `archive`.
-  TORCH_API void load(serialize::InputArchive& archive) override;
+  void load(serialize::InputArchive& archive) override;
 
   /// Returns the current index of the `DistributedRandomSampler`.
-  TORCH_API size_t index() const noexcept;
+  size_t index() const noexcept;
 
  private:
   void populate_indices();
@@ -97,28 +97,28 @@ class DistributedRandomSampler : public DistributedSampler<> {
 };
 
 /// Select samples sequentially.
-class DistributedSequentialSampler : public DistributedSampler<> {
+class TORCH_API DistributedSequentialSampler : public DistributedSampler<> {
  public:
-  TORCH_API DistributedSequentialSampler(
+  DistributedSequentialSampler(
      size_t size,
      size_t num_replicas = 1,
      size_t rank = 0,
      bool allow_duplicates = true);
 
   /// Resets the `DistributedSequentialSampler` to a new set of indices.
-  TORCH_API void reset(optional<size_t> new_size = nullopt) override;
+  void reset(optional<size_t> new_size = nullopt) override;
 
   /// Returns the next batch of indices.
-  TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
+  optional<std::vector<size_t>> next(size_t batch_size) override;
 
   /// Serializes the `DistributedSequentialSampler` to the `archive`.
-  TORCH_API void save(serialize::OutputArchive& archive) const override;
+  void save(serialize::OutputArchive& archive) const override;
 
   /// Deserializes the `DistributedSequentialSampler` from the `archive`.
-  TORCH_API void load(serialize::InputArchive& archive) override;
+  void load(serialize::InputArchive& archive) override;
 
   /// Returns the current index of the `DistributedSequentialSampler`.
-  TORCH_API size_t index() const noexcept;
+  size_t index() const noexcept;
 
  private:
   void populate_indices();
@@ -19,31 +19,33 @@ namespace data {
 namespace samplers {
 
 /// A `Sampler` that returns random indices.
-class RandomSampler : public Sampler<> {
+class TORCH_API RandomSampler : public Sampler<> {
  public:
   /// Constructs a `RandomSampler` with a size and dtype for the stored indices.
   ///
   /// The constructor will eagerly allocate all required indices, which is the
   /// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored
   /// indices. You can change it to influence memory usage.
-  TORCH_API explicit RandomSampler(
+  explicit RandomSampler(
       int64_t size,
       Dtype index_dtype = torch::kInt64);
 
+  ~RandomSampler() override;
+
   /// Resets the `RandomSampler` to a new set of indices.
-  TORCH_API void reset(optional<size_t> new_size = nullopt) override;
+  void reset(optional<size_t> new_size = nullopt) override;
 
   /// Returns the next batch of indices.
-  TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
+  optional<std::vector<size_t>> next(size_t batch_size) override;
 
   /// Serializes the `RandomSampler` to the `archive`.
-  TORCH_API void save(serialize::OutputArchive& archive) const override;
+  void save(serialize::OutputArchive& archive) const override;
 
   /// Deserializes the `RandomSampler` from the `archive`.
-  TORCH_API void load(serialize::InputArchive& archive) override;
+  void load(serialize::InputArchive& archive) override;
 
   /// Returns the current index of the `RandomSampler`.
-  TORCH_API size_t index() const noexcept;
+  size_t index() const noexcept;
 
  private:
   Tensor indices_;
@@ -19,26 +19,26 @@ namespace data {
 namespace samplers {
 
 /// A `Sampler` that returns indices sequentially.
-class SequentialSampler : public Sampler<> {
+class TORCH_API SequentialSampler : public Sampler<> {
  public:
   /// Creates a `SequentialSampler` that will return indices in the range
   /// `0...size - 1`.
-  TORCH_API explicit SequentialSampler(size_t size);
+  explicit SequentialSampler(size_t size);
 
   /// Resets the `SequentialSampler` to zero.
-  TORCH_API void reset(optional<size_t> new_size = nullopt) override;
+  void reset(optional<size_t> new_size = nullopt) override;
 
   /// Returns the next batch of indices.
-  TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
+  optional<std::vector<size_t>> next(size_t batch_size) override;
 
   /// Serializes the `SequentialSampler` to the `archive`.
-  TORCH_API void save(serialize::OutputArchive& archive) const override;
+  void save(serialize::OutputArchive& archive) const override;
 
   /// Deserializes the `SequentialSampler` from the `archive`.
-  TORCH_API void load(serialize::InputArchive& archive) override;
+  void load(serialize::InputArchive& archive) override;
 
   /// Returns the current index of the `SequentialSampler`.
-  TORCH_API size_t index() const noexcept;
+  size_t index() const noexcept;
 
  private:
   size_t size_;
@@ -32,26 +32,26 @@ struct TORCH_API BatchSize : public CustomBatchRequest {
 /// The major feature of the `StreamSampler` is that it does not return
 /// particular indices, but instead only the number of elements to fetch from
 /// the dataset. The dataset has to decide how to produce those elements.
-class StreamSampler : public Sampler<BatchSize> {
+class TORCH_API StreamSampler : public Sampler<BatchSize> {
  public:
   /// Constructs the `StreamSampler` with the number of individual examples that
   /// should be fetched until the sampler is exhausted.
-  TORCH_API explicit StreamSampler(size_t epoch_size);
+  explicit StreamSampler(size_t epoch_size);
 
   /// Resets the internal state of the sampler.
-  TORCH_API void reset(optional<size_t> new_size = nullopt) override;
+  void reset(optional<size_t> new_size = nullopt) override;
 
   /// Returns a `BatchSize` object with the number of elements to fetch in the
   /// next batch. This number is the minimum of the supplied `batch_size` and
   /// the difference between the `epoch_size` and the current index. If the
   /// `epoch_size` has been reached, returns an empty optional.
-  TORCH_API optional<BatchSize> next(size_t batch_size) override;
+  optional<BatchSize> next(size_t batch_size) override;
 
   /// Serializes the `StreamSampler` to the `archive`.
-  TORCH_API void save(serialize::OutputArchive& archive) const override;
+  void save(serialize::OutputArchive& archive) const override;
 
   /// Deserializes the `StreamSampler` from the `archive`.
-  TORCH_API void load(serialize::InputArchive& archive) override;
+  void load(serialize::InputArchive& archive) override;
 
  private:
   size_t examples_retrieved_so_far_ = 0;
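The sampler hunks above move the export annotation from individual virtual member functions onto the class declaration itself. A minimal sketch of the two patterns, using a hypothetical EXAMPLE_API macro in place of TORCH_API: annotating the class is the more common approach for polymorphic types, since one annotation covers the vtable and every member instead of decorating each function separately.

// EXAMPLE_API is a stand-in used only for this sketch; PyTorch uses TORCH_API.
#ifndef EXAMPLE_API
#define EXAMPLE_API
#endif

// Before: each virtual member exported individually.
struct SamplerBefore {
  virtual ~SamplerBefore() = default;
  EXAMPLE_API virtual void reset() = 0;
  EXAMPLE_API virtual void save() const = 0;
};

// After: one class-level annotation exports the whole type.
struct EXAMPLE_API SamplerAfter {
  virtual ~SamplerAfter() = default;
  virtual void reset() = 0;
  virtual void save() const = 0;
};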
@@ -5,6 +5,8 @@
 #include <torch/nn/pimpl.h>
 #include <torch/types.h>
 
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
 #include <cstddef>
 #include <vector>
 
@@ -4,6 +4,8 @@
 #include <torch/nn/pimpl.h>
 #include <torch/types.h>
 
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
 #include <cstddef>
 #include <vector>
 
@@ -12,6 +12,8 @@ namespace samplers {
 RandomSampler::RandomSampler(int64_t size, Dtype index_dtype)
     : indices_(torch::randperm(size, index_dtype)) {}
 
+RandomSampler::~RandomSampler() = default;
+
 void RandomSampler::reset(optional<size_t> new_size) {
   // This allocates a new chunk of memory every time (just FYI). It should be
   // amortized over the entire epoch hopefully.
@@ -4,4 +4,6 @@ namespace torch { namespace autograd {
 
 bool AnomalyMode::_enabled = false;
 
+AnomalyMetadata::~AnomalyMetadata() = default;
+
 }}
@@ -4,7 +4,7 @@
 
 namespace torch { namespace autograd {
 
-struct AnomalyMode {
+struct TORCH_API AnomalyMode {
   static bool is_enabled() {
     return _enabled;
   }
@@ -13,12 +13,12 @@ struct AnomalyMode {
   }
 
  private:
-  TORCH_API static bool _enabled;
+  static bool _enabled;
 };
 
 
-struct AnomalyMetadata {
-  virtual ~AnomalyMetadata() = default;
+struct TORCH_API AnomalyMetadata {
+  virtual ~AnomalyMetadata();
   virtual void store_stack() = 0;
   virtual void print_stack() = 0;
 };
torch/csrc/autograd/function_hook.cpp (new file)
@@ -0,0 +1,8 @@
+#include <torch/csrc/autograd/function_hook.h>
+
+namespace torch { namespace autograd {
+
+FunctionPreHook::~FunctionPreHook() = default;
+FunctionPostHook::~FunctionPostHook() = default;
+
+}} // namespace torch::autograd
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <vector>
+#include <torch/csrc/WindowsTorchApiMacro.h>
 
 // A hook that's called on gradients
 
@@ -9,13 +10,13 @@ namespace torch { namespace autograd {
 struct Variable;
 using variable_list = std::vector<Variable>;
 
-struct FunctionPreHook {
-  virtual ~FunctionPreHook() = default;
+struct TORCH_API FunctionPreHook {
+  virtual ~FunctionPreHook();
   virtual variable_list operator()(const variable_list& grads) = 0;
 };
 
-struct FunctionPostHook {
-  virtual ~FunctionPostHook() = default;
+struct TORCH_API FunctionPostHook {
+  virtual ~FunctionPostHook();
   virtual variable_list operator()(
       const variable_list& outputs /* grad_inputs */,
       const variable_list& inputs /* grad_outputs */) = 0;
@@ -2,10 +2,11 @@
 
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
 
 namespace torch { namespace autograd {
 
-struct AccumulateGrad : public Function {
+struct TORCH_API AccumulateGrad : public Function {
   explicit AccumulateGrad(Variable variable_);
 
   variable_list apply(variable_list&& grads) override;
@@ -28,6 +28,8 @@ Scatter::Scatter(
       streams_(streams),
       unsqueeze_scalars_(unsqueeze_scalars) {}
 
+Scatter::~Scatter() {}
+
 variable_list Scatter::apply(variable_list&& inputs) {
   AT_ASSERT(inputs.size() == 1);
   auto& input = inputs.front();
@@ -65,6 +67,8 @@ variable_list Scatter::apply(variable_list&& inputs) {
 Gather::Gather(const at::Device& destination_device, int64_t dim)
     : destination_device_(destination_device), dim_(dim) {}
 
+Gather::~Gather() {}
+
 variable_list Gather::apply(variable_list&& inputs) {
   bool all_are_zero_dim = true;
   for (const auto& input : inputs) {
@@ -6,6 +6,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/ATenCUDAGeneral.h>
 
 #include <cstddef>
 #include <vector>
@@ -13,6 +14,7 @@
 namespace torch {
 namespace autograd {
 
+//TODO: change it to TORCH_API when we merge the libs
 struct TORCH_API Scatter : public Function {
   explicit Scatter(
       std::vector<at::Device> devices,
@@ -21,6 +23,7 @@ struct TORCH_API Scatter : public Function {
       const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>& streams =
           c10::nullopt,
       bool unsqueeze_scalars = false);
+  ~Scatter() override;
 
   variable_list apply(variable_list&& inputs) override;
 
@@ -33,6 +36,7 @@ struct TORCH_API Scatter : public Function {
 
 struct TORCH_API Gather : public Function {
   explicit Gather(const at::Device& destination_device, int64_t dim = 0);
+  ~Gather() override;
 
   variable_list apply(variable_list&& inputs) override;
 
@@ -2,6 +2,7 @@
 
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <ATen/TensorGeometry.h>
 #include <ATen/core/DeprecatedTypeProperties.h>
@@ -12,7 +13,7 @@
 
 namespace torch { namespace autograd {
 
-struct CopyBackwards : public Function {
+struct TORCH_API CopyBackwards : public Function {
   variable_list apply(variable_list&& grads) override;
 
   at::DeprecatedTypeProperties *src_type = nullptr; // initialized for safety.
@@ -26,7 +27,7 @@ struct CopyBackwards : public Function {
 // grad_fn is updated to become a `CopySlice` wrapping the backward of the
 // in-place operation.
 // See NOTE [ Autograd View Variables ].
-struct CopySlices : public Function {
+struct TORCH_API CopySlices : public Function {
   CopySlices(
       const Variable& base_var,
       at::TensorGeometry view_,
@@ -12,7 +12,7 @@ constexpr CUDAStubs* default_stubs_addr = &default_stubs;
 // static initialization calls which may invoke registerCUDAMethods
 static CUDAStubs* cuda_stubs = default_stubs_addr;
 
-TORCH_API void registerCUDAMethods(CUDAStubs* stubs) {
+void registerCUDAMethods(CUDAStubs* stubs) {
   cuda_stubs = stubs;
 }
 
@@ -91,11 +91,27 @@ inline int64_t getTime() {
 #endif
 }
 
-enum class EventKind : uint16_t {
+// Old GCC versions generate warnings incorrectly
+// see https://stackoverflow.com/questions/2463113/g-c0x-enum-class-compiler-warnings
+#ifndef _MSC_VER
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wattributes"
+#endif
+enum class TORCH_API ProfilerState {
+  Disabled,
+  CPU, // CPU-only profiling
+  CUDA, // CPU + CUDA events
+  NVTX, // only emit NVTX markers
+};
+
+enum class TORCH_API EventKind : uint16_t {
   Mark,
   PushRange,
   PopRange
 };
+#ifndef _MSC_VER
+# pragma GCC diagnostic pop
+#endif
 
 struct TORCH_API Event final {
   Event(EventKind kind, StringView name, uint16_t thread_id, bool record_cuda)
@@ -183,13 +199,6 @@ struct RangeEventList {
   std::forward_list<block_type> blocks;
 };
 
-enum class ProfilerState {
-  Disabled,
-  CPU, // CPU-only profiling
-  CUDA, // CPU + CUDA events
-  NVTX, // only emit NVTX markers
-};
-
 TORCH_API RangeEventList& getEventList();
 TORCH_API void mark(std::string name, bool include_cuda = true);
 TORCH_API void pushRange(std::string name);
@@ -1,6 +1,8 @@
 #pragma once
 
 #include <ATen/ATen.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <ATen/cuda/ATenCUDAGeneral.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/util/Optional.h>
 
@@ -11,11 +13,11 @@ namespace torch { namespace cuda {
 
 using tensor_list2d = std::vector<std::vector<at::Tensor>>;
 
-std::vector<at::Tensor> broadcast(const at::Tensor& tensor, at::IntArrayRef devices);
-tensor_list2d broadcast_coalesced(at::TensorList tensors, at::IntArrayRef devices,
+TORCH_API std::vector<at::Tensor> broadcast(const at::Tensor& tensor, at::IntArrayRef devices);
+TORCH_API tensor_list2d broadcast_coalesced(at::TensorList tensors, at::IntArrayRef devices,
                                   size_t buffer_size);
 
-std::vector<at::Tensor> scatter(
+TORCH_API std::vector<at::Tensor> scatter(
     const at::Tensor& tensor,
     at::IntArrayRef devices,
     const c10::optional<std::vector<int64_t>>& chunk_sizes = c10::nullopt,
@@ -23,7 +25,7 @@ std::vector<at::Tensor> scatter(
     const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>& streams =
         c10::nullopt);
 
-at::Tensor gather(
+TORCH_API at::Tensor gather(
     at::TensorList tensors,
     int64_t dim,
     c10::optional<int32_t> destination_index);
@@ -6,6 +6,7 @@
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/variable_tensor_list.h>
 #include <torch/csrc/utils/hash.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <iostream>
 #include <vector>
 
@@ -133,7 +134,7 @@ struct ArgumentSpec {
 // ArgumentSpecCreator takes an initial graph and comes up with a set
 // of simple instructions to compute the ArgumentSpec given a set of
 // input tensors.
-struct ArgumentSpecCreator {
+struct TORCH_API ArgumentSpecCreator {
   // instructs acts on a stack of a list of input IValues
   // at the beginning the stack contains a single list of the inputs to the
   // function the ENTER_ instructs descend into subobjects and push new lists
@@ -84,7 +84,7 @@ struct Graph;
 
 // We special case Graph attributes like this because we want to ensure that
 // Graph::copy() is called when we clone() these attributes.
-struct GraphAttr : public AttributeValue {
+struct TORCH_API GraphAttr : public AttributeValue {
   using ConstructorType = std::shared_ptr<Graph>;
   using ValueType = std::shared_ptr<Graph>;
   GraphAttr(Symbol name, ConstructorType value_)
@@ -92,7 +92,7 @@ struct GraphAttr : public AttributeValue {
   ValueType& value() {
     return value_;
   }
-  TORCH_API Ptr clone() const override;
+  Ptr clone() const override;
   AttributeKind kind() const override {
     return AttributeKind::g;
   }
@@ -101,7 +101,7 @@ struct GraphAttr : public AttributeValue {
   std::shared_ptr<Graph> value_;
 };
 
-struct GraphsAttr : public AttributeValue {
+struct TORCH_API GraphsAttr : public AttributeValue {
   using ConstructorType = std::vector<std::shared_ptr<Graph>>;
   using ValueType = std::vector<std::shared_ptr<Graph>>;
   GraphsAttr(Symbol name, ConstructorType value_)
@@ -112,7 +112,7 @@ struct GraphsAttr : public AttributeValue {
   AttributeKind kind() const override {
     return AttributeKind::gs;
   }
-  TORCH_API std::unique_ptr<AttributeValue> clone() const override;
+  std::unique_ptr<AttributeValue> clone() const override;
 
  private:
   ValueType value_;
@@ -2,6 +2,7 @@
 
 #include <c10/util/Exception.h>
 #include <torch/csrc/utils/disallow_copy.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
 
 namespace torch {
 namespace jit {
@@ -11,11 +12,11 @@ namespace cpu {
 struct DynamicLibrary {
   TH_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);
 
-  DynamicLibrary(const char* name);
+  TORCH_API DynamicLibrary(const char* name);
 
-  void* sym(const char* name);
+  TORCH_API void* sym(const char* name);
 
-  ~DynamicLibrary();
+  TORCH_API ~DynamicLibrary();
 
  private:
   void* handle = nullptr;
@@ -6,19 +6,19 @@ namespace jit {
 
 static std::function<void(std::shared_ptr<script::Module> module)>
     emit_module_callback;
-TORCH_API void didFinishEmitModule(std::shared_ptr<script::Module> module) {
+void didFinishEmitModule(std::shared_ptr<script::Module> module) {
   if (emit_module_callback) {
     emit_module_callback(std::move(module));
   }
 }
 static std::function<void(std::shared_ptr<script::Function> fn)>
     emit_function_callback;
-TORCH_API void didFinishEmitFunction(std::shared_ptr<script::Function> fn) {
+void didFinishEmitFunction(std::shared_ptr<script::Function> fn) {
   if (emit_function_callback) {
     emit_function_callback(fn);
   }
 }
-TORCH_API void setEmitHooks(
+void setEmitHooks(
     std::function<void(std::shared_ptr<script::Module> module)> for_mod,
     std::function<void(std::shared_ptr<script::Function> for_fn)> for_fn) {
   emit_module_callback = std::move(for_mod);
@@ -43,6 +43,7 @@
 #include <torch/csrc/jit/script/python_tree_views.h>
 #include <torch/csrc/jit/tracer.h>
 
+#include <c10/macros/Export.h>
 #include <caffe2/serialize/inline_container.h>
 
 #include <ATen/core/function_schema.h>
@@ -84,7 +85,7 @@ void runJITCPPTests(bool runCuda) {
   AT_ERROR("JIT tests not yet supported on Windows");
 }
 #else
-void runJITCPPTests(bool runCuda);
+CAFFE2_API void runJITCPPTests(bool runCuda);
 #endif
 
 void initJITBindings(PyObject* module) {
@@ -47,11 +47,11 @@ struct TORCH_API Code {
 };
 
 struct InterpreterState {
-  InterpreterState(const Code& code);
-  void run(Stack& stack);
+  TORCH_API InterpreterState(const Code& code);
+  TORCH_API void run(Stack& stack);
   c10::intrusive_ptr<Future> runAsync(Stack& stack);
   c10::intrusive_ptr<Future> getFuture();
-  ~InterpreterState();
+  TORCH_API ~InterpreterState();
 
  private:
   InterpreterState(c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl);
@@ -33,10 +33,6 @@ static constexpr topo_position_t kMidPoint = 0;
 // - 2^(64-n) is the maximum number of appends to the end without reindex
 static constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */;
 
-// Sigh, see
-// https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char
-constexpr Symbol PythonOp::Kind;
-
 static void printValueRef(std::ostream& out, const Value* n) {
   out << "%" << n->uniqueName();
 }
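This hunk removes the out-of-line definition that the linked Stack Overflow workaround required: before C++17, a `static constexpr` member that is ODR-used needs exactly one namespace-scope definition, or the link fails with "undefined reference". The patch sidesteps the issue by turning the kind into a plain static member defined in a .cpp (see the ConcretePythonOp hunks further down). A minimal standalone sketch of the underlying pitfall, with illustrative names not taken from the patch:

#include <iostream>

struct Op {
  static constexpr int Kind = 42;  // pre-C++17, the in-class initializer is only a declaration
};

// Without this single out-of-line definition, binding a reference to Op::Kind
// (an ODR-use) fails to link with "undefined reference to Op::Kind" before C++17.
constexpr int Op::Kind;

int main() {
  const int& k = Op::Kind;  // ODR-use: requires the definition above
  std::cout << k << '\n';
}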
@@ -231,7 +231,7 @@ struct Value {
   TORCH_API Value* copyMetadata(Value* from);
 };
 
-struct Node {
+struct TORCH_API Node {
   TH_DISALLOW_COPY_AND_ASSIGN(Node);
   friend struct Graph;
   friend struct Block;
@@ -259,7 +259,7 @@ struct Node {
   topo_position_t topo_position_ = 0;
 
  protected:
-  TORCH_API Node(Graph* graph_, NodeKind kind_); // defined after graph
+  Node(Graph* graph_, NodeKind kind_); // defined after graph
 public:
   // each node but Return/Param
   // is associated with exactly one place in the node list...
@@ -358,7 +358,7 @@ struct Node {
     return false;
   }
 
-  TORCH_API void replaceAllUsesWith(Node* n);
+  void replaceAllUsesWith(Node* n);
 
   // lots of things like chunk have a single input or single output, so we have
   // a helper to make accessing it easier
@@ -399,10 +399,10 @@ struct Node {
   bool is_constant(Symbol name) const {
     return static_cast<bool>(get(name));
   }
-  TORCH_API bool mustBeNone() const;
+  bool mustBeNone() const;
 
-  TORCH_API bool isNondeterministic() const;
+  bool isNondeterministic() const;
-  TORCH_API bool hasSideEffects() const;
+  bool hasSideEffects() const;
 
   // Graphs
 
@@ -424,11 +424,11 @@ struct Node {
   // Given: %3 = f(%1, %2)
   // Execute: %3.addInput(%4)
   // Result: %3 = f(%1, %2, %4)
-  TORCH_API Value* addInput(Value* value);
+  Value* addInput(Value* value);
 
   // Add 'value' as an input to 'this' at the specified position in the
   // arguments. Returns the added value for ease of chaining.
-  TORCH_API Value* insertInput(size_t i, Value* value);
+  Value* insertInput(size_t i, Value* value);
 
   // Replace the input of 'this' at position 'i' with
   // 'newValue', returning the old node.
@@ -436,7 +436,7 @@ struct Node {
   // Given: %3 = f(%1, %2)
   // Execute: %3.replaceInput(1, %4)
   // Result: %3 = f(%1, %4)
-  TORCH_API Value* replaceInput(size_t i, Value* newValue);
+  Value* replaceInput(size_t i, Value* newValue);
 
   // Replace all occurrences of 'from' in the inputs of this
   // node with 'to'. Corresponds to llvm's replaceUsesOfWith.
@@ -444,16 +444,16 @@ struct Node {
   // Given: %3 = f(%1, %2, %1)
   // Execute: %3.replaceInputWith(%1, %4)
   // Result: %3 = f(%4, %2, %4)
-  TORCH_API void replaceInputWith(Value* from, Value* to);
+  void replaceInputWith(Value* from, Value* to);
 
-  TORCH_API Value* addOutput();
+  Value* addOutput();
 
-  TORCH_API Value* insertOutput(size_t i);
+  Value* insertOutput(size_t i);
 
-  TORCH_API void eraseOutput(size_t i);
+  void eraseOutput(size_t i);
 
-  TORCH_API Block* addBlock();
+  Block* addBlock();
-  TORCH_API void eraseBlock(size_t i);
+  void eraseBlock(size_t i);
 
   // Each Node can have a list of subblocks. These are used to define structured
   // nested control flow operators such as If and Loop.
@@ -482,10 +482,10 @@ struct Node {
   }
 
   // Is 'this' before 'n' in the topological order?
-  TORCH_API bool isBefore(const Node* n) const;
+  bool isBefore(const Node* n) const;
 
   // Is 'this' after 'n' in the topological order?
-  TORCH_API bool isAfter(const Node* n) const;
+  bool isAfter(const Node* n) const;
 
   // Insert unattached 'this' node before 'n' in the topological order.
   // Returns this (for chaining).
@@ -497,7 +497,7 @@ struct Node {
   // Result: %3 = f(%1, %2)
   //         %5 = h(%1)
   //         %4 = g(%3)
-  TORCH_API Node* insertBefore(Node* n);
+  Node* insertBefore(Node* n);
 
   // Insert unattached 'this' node after 'n' in the topological order.
   // Returns this (for chaining).
@@ -509,7 +509,7 @@ struct Node {
   // Result: %3 = f(%1, %2)
   //         %4 = g(%3)
   //         %5 = h(%1)
-  TORCH_API Node* insertAfter(Node* n);
+  Node* insertAfter(Node* n);
 
   // Move 'this' (already in the graph) after 'n' in the topological order.
   //
@@ -522,7 +522,7 @@ struct Node {
   // Result: %3 = g(%1)
   //         %2 = f(%1)
   //
-  TORCH_API void moveAfter(Node* n);
+  void moveAfter(Node* n);
 
   // Move a node 'n' (already in the graph) before 'this' in the topological
   // order.
@@ -535,7 +535,7 @@ struct Node {
   // Execute: %3.moveBefore(%2)
   // Result: %3 = g(%1)
   //         %2 = f(%1)
-  TORCH_API void moveBefore(Node* n);
+  void moveBefore(Node* n);
 
   // Remove the input at 'i' from this node.
   //
@@ -545,14 +545,14 @@ struct Node {
   // Given: %3 = f(%1, %2)
   // Execute: %3.removeInput(1)
   // Result: %3 = f(%1)
-  TORCH_API void removeInput(size_t i);
+  void removeInput(size_t i);
 
   // Remove all inputs from a node.
   //
   // Given: %3 = f(%1, %2)
   // Execute: %3.removeAllInputs()
   // Result: %3 = f()
-  TORCH_API void removeAllInputs();
+  void removeAllInputs();
 
   // iterators of the node list starting at this node
   // useful for resuming a search starting at this node
@@ -577,7 +577,7 @@ struct Node {
   //        %3 = g(%1)
   // Execute: %2.destroy()
   // Result: %3 = g(%1)
-  TORCH_API void destroy();
+  void destroy();
 
   // Dynamically cast this node to the subclass indicated by the
   // template variable, returning nullptr if the cast is invalid..
@@ -604,11 +604,11 @@ struct Node {
   }
 
   // XXX: this function is meant to be used with string literals only!
-  TORCH_API bool matches(
+  bool matches(
       const char* signature_literal,
       at::ArrayRef<Symbol> const_inputs = {}) const;
 
-  TORCH_API const FunctionSchema& schema() const {
+  const FunctionSchema& schema() const {
     if (!schema_) {
       findSchema();
     }
@@ -782,15 +782,15 @@ struct Node {
   bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
 
   std::pair<Value*, const Argument&> findInput(Symbol name);
-  TORCH_API void findSchema() const;
+  void findSchema() const;
   // Lookup iterator in use list of _input i_ that corresponds to its use of
   // _this_
-  TORCH_API use_list::iterator findUseForInput(size_t i);
+  use_list::iterator findUseForInput(size_t i);
 
   // remove the use of input i, this sets input i to nullptr, but
   // is only used internally to Node before setting it to a new value
   // or erasing the entry from the list.
-  TORCH_API Value* dropInput(size_t i);
+  Value* dropInput(size_t i);
 
   bool inBlockList() const {
     if (next() == nullptr) {
@@ -799,8 +799,8 @@ struct Node {
     return next() != nullptr;
   }
 
-  TORCH_API void removeFromList();
+  void removeFromList();
-  TORCH_API void lint() const;
+  void lint() const;
 
   void assignTopoPosition();
 
@@ -818,7 +818,7 @@ struct Node {
   // 'this' will be allocated with s->allocNewInstance(g) so it should have
   // the same concrete type as 's'
   //
-  TORCH_API virtual void cloneFrom(Node* s);
+  virtual void cloneFrom(Node* s);
 };
 
 struct Block {
@@ -1262,9 +1262,7 @@ struct ProfileOp : public Node {
 // which is not included in libtorch.so. We still include some bits and pieces
 // of PythonOp here to enable writing simple passes generically. In general,
 // python-aware bits need to be moved to the descendant classes.
-struct PythonOp : public Node {
+struct TORCH_API PythonOp : public Node {
-  static constexpr Symbol Kind = ::c10::prim::PythonOp;
-
   using Node::Node;
 
   // should this Python function be skipped over when exported (i.e. for
@@ -2,6 +2,8 @@
 #include <string>
 #include <unordered_map>
 
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
 namespace torch {
 namespace jit {
 
@@ -2,6 +2,7 @@
 #include <caffe2/proto/caffe2_pb.h>
 #include <torch/csrc/jit/ir.h>
 #include <unordered_map>
+#include <torch/csrc/WindowsTorchApiMacro.h>
 
 namespace torch {
 namespace jit {
@@ -15,7 +16,7 @@ namespace jit {
  * \p Prefix can be used for appending some string to every operator name (e.g.
  * we can add "caffe2::").
  */
-void convertNetDefToIR(
+TORCH_API void convertNetDefToIR(
     const caffe2::NetDef& net,
     Graph* graph,
     std::unordered_map<std::string, Value*>* valueMapPtr = nullptr,
@@ -34,7 +35,7 @@ void convertNetDefToIR(
  * TODO: We might need to do a better job at preserving names of the variables,
  * especially external_inputs/external_outputs.
  */
-void convertIRToNetDef(
+TORCH_API void convertIRToNetDef(
     caffe2::NetDef* net,
     const Graph& graph,
     const std::string& prefix = "");
@@ -19,9 +19,9 @@ namespace jit {
 // A pass modifies a Graph in place.
 using Pass = std::function<void(std::shared_ptr<Graph>&)>;
 
-std::vector<Pass>& getCustomPasses();
+TORCH_API std::vector<Pass>& getCustomPasses();
 
-struct RegisterPass {
+struct TORCH_API RegisterPass {
   RegisterPass(Pass p);
 };
 
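Exporting getCustomPasses() and RegisterPass matters because out-of-tree code registers passes at static-initialization time through exactly this pair. A compilable sketch of the idiom, simplified and with Graph as a stand-in type rather than the real torch::jit::Graph:

#include <functional>
#include <memory>
#include <vector>

struct Graph {};  // stand-in for the real IR graph type
using Pass = std::function<void(std::shared_ptr<Graph>&)>;

// Meyers-singleton registry: safe to call from static initializers.
std::vector<Pass>& getCustomPasses() {
  static std::vector<Pass> passes;
  return passes;
}

struct RegisterPass {
  explicit RegisterPass(Pass p) {
    getCustomPasses().push_back(std::move(p));
  }
};

// A client translation unit registers its pass with a namespace-scope object;
// the optimizer later walks getCustomPasses() and runs each entry on the graph.
static RegisterPass reg([](std::shared_ptr<Graph>& g) { /* mutate g in place */ });

int main() {
  auto g = std::make_shared<Graph>();
  for (const Pass& p : getCustomPasses()) {
    p(g);  // runs the registered lambda
  }
}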
@@ -1265,7 +1265,7 @@ c10::optional<const Node*> AliasDb::getLastWildcard() const {
   }
 }
 
-TORCH_API bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
+bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
   // WARNING: by adding a case to this list, you are asserting that you have
   // added a case for the unschematized node in AliasDb::analyze
   const static std::unordered_set<Symbol> handled = {
@@ -42,7 +42,7 @@ class AliasDb {
 
   // Does `n` write to an alias of one of the values in `vs`?
   // if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks
-  bool writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks = false)
+  TORCH_API bool writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks = false)
       const;
 
   // Does `a` and `b` potentially share a memory location or do either
@@ -56,7 +56,7 @@ class AliasDb {
       const at::ArrayRef<Value*>& b) const;
 
   // Do `a` and `b` potentially share a memory location?
-  bool mayAlias(const Value* a, const Value* b) const;
+  TORCH_API bool mayAlias(const Value* a, const Value* b) const;
   // Do any values in group `a` potentially share a memory location with any
   // value in group `b`? i.e. may they overlap?
   //
@@ -104,7 +104,7 @@ class AliasDb {
   }
 
   // Do any nodes write to an alias set inputed/outputed by `n`?
-  bool hasWriters(const Node* n) const;
+  TORCH_API bool hasWriters(const Node* n) const;
 
   // Move 'n' (already in the graph) after 'movePoint' in the topological order.
   //
@@ -115,8 +115,8 @@ class AliasDb {
   //
   // Returns `false` if it's impossible to move `n` after `MovePoint` without
   // violating dependencies, otherwise executes the move and returns `true`
-  bool moveAfterTopologicallyValid(Node* n, Node* movePoint);
+  TORCH_API bool moveAfterTopologicallyValid(Node* n, Node* movePoint);
-  bool moveBeforeTopologicallyValid(Node* n, Node* movePoint);
+  TORCH_API bool moveBeforeTopologicallyValid(Node* n, Node* movePoint);
 
   bool couldMoveAfterTopologically(Node* n, Node* movePoint);
   bool couldMoveBeforeTopologically(Node* n, Node* movePoint);
@@ -146,7 +146,7 @@ class AliasDb {
   void getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks = false) const;
   void getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
   // Do any nodes write to `v`s memory location?
-  bool hasWriters(const Value* v) const;
+  TORCH_API bool hasWriters(const Value* v) const;
   // Register the fact that `n` writes to `v`.
   void registerWrite(const Value* v, Node* n);
   // Get all the values that `n` reads from.
@@ -163,7 +163,7 @@ class AliasDb {
    * Wildcard methods
    */
   // is `v` a wildcard?
-  bool isWildcard(const Value* v) const;
+  TORCH_API bool isWildcard(const Value* v) const;
   // Register `v` as a wildcard value.
   void setWildcard(const Value* v);
   // Get all nodes that write to a wildcard value.
@@ -855,7 +855,7 @@ struct PythonPrintPass {
   // Prints the RHS value of a Node, e.g. `aten.add(x, y)`
   void printRHS(std::ostream& stmt, Node* node) {
     switch (node->kind()) {
-      case PythonOp::Kind: {
+      case prim::PythonOp: {
         auto value = static_cast<const PythonOp*>(node);
         if (enforce_importable_) {
           throw script::ErrorReport(node->getSourceLocation())
@@ -1111,7 +1111,7 @@ public:
   }
 };
 
-TORCH_API void PythonPrint(
+void PythonPrint(
     std::ostream& out,
     const script::Function& func,
     bool is_method,
@@ -1124,7 +1124,7 @@ TORCH_API void PythonPrint(
   pp.print(out);
 }
 
-TORCH_API void PythonPrint(
+void PythonPrint(
     std::ostream& out,
     const script::CompilationUnit& cu,
     bool is_method,
@@ -1137,7 +1137,7 @@ TORCH_API void PythonPrint(
   pp.print(out);
 }
 
-TORCH_API void PythonPrint(
+void PythonPrint(
     std::ostream& out,
     const ClassTypePtr& classType,
     std::vector<at::Tensor>& tensor_table,
@@ -1148,7 +1148,7 @@ TORCH_API void PythonPrint(
   pp.print(out);
 }
 
-TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {
+bool printerHasSpecialCaseFor(Symbol sym) {
   // WARNING: by adding a value to this set, you are asserting
   // that you have also added special handling of this symbol to
   // the printer above. Not adding handling will cause import and export
@@ -6,6 +6,8 @@
 #include <unordered_set>
 #include <vector>
 
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
 namespace torch {
 namespace jit {
 
@@ -30,17 +32,17 @@ struct Value;
 class MemoryDAG {
  public:
   // Make `from` point at `to`.
-  void makePointerTo(Element* from, Element* to);
+  TORCH_API void makePointerTo(Element* from, Element* to);
 
   void addToContainedElements(Element* contained, Element* container);
 
   // Make a fresh element (i.e. an element that doesn't point to anything) and
   // return it.
-  Element* makeFreshValue(const Value* v);
+  TORCH_API Element* makeFreshValue(const Value* v);
 
   // Do `a` and `b` potentially share a memory location?
   bool mayAlias(const Element* a, const Element* b) const;
-  bool mayAlias(Element* a, Element* b) const;
+  TORCH_API bool mayAlias(Element* a, Element* b) const;
 
   // Does a hold reference to any memory that is stored in elem, or vice versa?
   bool mayContainAlias(const Element* a, const Element* b) const;
@@ -124,7 +126,7 @@ struct Element {
   std::unordered_set<Element*> contained_elements;
 
   // Return the unique memory locations that `Element` might represent.
-  std::unordered_set<const Element*> getMemoryLocations() const;
+  TORCH_API std::unordered_set<const Element*> getMemoryLocations() const;
   // We do path compression to make repeated memory location queries faster.
   // An empty cache means it is invalidated (it can never be empty otherwise,
   // since every element must point to at least one memory location).
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <torch/csrc/jit/ir.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
 
 namespace torch {
 namespace jit {
@@ -17,16 +18,16 @@ namespace SubgraphUtils {
 // `n` is destroyed.
 //
 // Returns the new subgraph node.
-Node* createSingletonSubgraph(Node* n, Symbol subgraphKind);
+TORCH_API Node* createSingletonSubgraph(Node* n, Symbol subgraphKind);
 
 // Merge a node into a subgraph node. If `toMerge` is also a subgraph, the
 // subgraphs are merged.
 // `toMerge` is destroyed.
-void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode);
+TORCH_API void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode);
 
 // Move nodes from a subgraph node to the outer graph.
 // `subgraphNode` is destroyed.
-void unmergeSubgraph(Node* subgraphNode);
+TORCH_API void unmergeSubgraph(Node* subgraphNode);
 
 // Convenience function
 std::shared_ptr<Graph> getSubgraph(Node* n);
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <ATen/ATen.h>
 #include <ATen/core/ivalue.h>
 #include <ATen/core/jit_type.h>
@@ -21,7 +22,7 @@ struct ProfilingRecord {
   ProfilingRecord(const ProfilingRecord&) = delete;
   ProfilingRecord(ProfilingRecord&&) noexcept = delete;
   static ProfiledTensorTypePtr toProfiledTensorTypePtr(const IValue& ival);
-  static std::unique_ptr<ProfilingRecord> instrumentGraph(
+  TORCH_API static std::unique_ptr<ProfilingRecord> instrumentGraph(
      const std::shared_ptr<Graph>& graph);
 
   std::shared_ptr<Graph> profiled_graph_;
@@ -19,6 +19,8 @@
 namespace torch {
 namespace jit {
 
+Symbol ConcretePythonOp::Kind = prim::PythonOp;
+
 using c10::Type;
 
 std::string getPythonName(const PyObject* obj_) {
@@ -11,6 +11,8 @@ void initPythonIRBindings(PyObject* module);
 // execute a Python function, used for Ops we can't optimize but that we want to
 // optimize around
 struct ConcretePythonOp : public PythonOp {
+  static Symbol Kind;
+
   ConcretePythonOp(Graph* graph) : PythonOp(graph, ::c10::prim::PythonOp) {}
   ConcretePythonOp* init(
       THPObjectPtr&& pyobj,
@@ -100,7 +100,7 @@ struct BuiltinFunctionRegistry {
   std::unordered_map<Symbol, std::vector<Function*>> builtins_by_name;
 };
 
-TORCH_API const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
+const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
   static BuiltinFunctionRegistry registry;
   return registry.getAllBuiltinFunctionsFor(name);
 }
torch/csrc/jit/script/jit_exception.cpp (new file, 9 lines)
@@ -0,0 +1,9 @@
+#include <torch/csrc/jit/script/jit_exception.h>
+
+namespace torch {
+namespace jit {
+
+JITException::JITException(const std::string& msg) : std::runtime_error(msg) {}
+
+} // namespace jit
+} // namespace torch
@@ -2,12 +2,13 @@
 
 #include <stdexcept>
 
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
 namespace torch {
 namespace jit {
 
-struct JITException : public std::runtime_error {
-  JITException() = default;
-  explicit JITException(const std::string& msg) : std::runtime_error(msg) {}
+struct TORCH_API JITException : public std::runtime_error {
+  explicit JITException(const std::string& msg);
 };
 
 } // namespace jit
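Moving the JITException constructor body out of the header and into the new jit_exception.cpp pairs naturally with the TORCH_API annotation: the exported class gets exactly one home translation unit for its non-inline members. A minimal sketch of the same header/.cpp split, with hypothetical file and type names and the export macro elided:

// --- my_exception.h (hypothetical) ---
#include <stdexcept>
#include <string>

struct MyException : std::runtime_error {
  explicit MyException(const std::string& msg);  // declaration only
};

// --- my_exception.cpp (hypothetical) ---
// The single out-of-line definition lives in the library that exports the type.
MyException::MyException(const std::string& msg) : std::runtime_error(msg) {}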
@@ -3,6 +3,7 @@
 #include <c10/util/C++17.h>
 #include <torch/csrc/jit/source_range.h>
 #include <torch/csrc/jit/script/strtod.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <ATen/core/Macros.h>
 #include <algorithm>
 #include <clocale>
@@ -1,4 +1,4 @@
-#include "torch/csrc/jit/script/logging.h"
+#include <torch/csrc/jit/script/logging.h>
 
 #include <atomic>
 #include <mutex>
@@ -17,7 +17,7 @@ void LockingLogger::addStatValue(const std::string& stat_name, int64_t val) {
   raw_counter.count++;
 }
 
-TORCH_API int64_t LockingLogger::getCounterValue(const std::string& name) const {
+int64_t LockingLogger::getCounterValue(const std::string& name) const {
   std::unique_lock<std::mutex> lk(m);
   if (!raw_counters.count(name)) {
     return 0;
@@ -37,12 +37,12 @@ class NoopLogger : public LoggerBase {
 //
 // NOTE: this is not written in a scalable way and should probably only be used
 // in the single-threaded case or for testing.
-class LockingLogger : public LoggerBase {
+class TORCH_API LockingLogger : public LoggerBase {
  public:
-  TORCH_API void addStatValue(const std::string& stat_name, int64_t val) override;
+  void addStatValue(const std::string& stat_name, int64_t val) override;
-  TORCH_API virtual int64_t getCounterValue(const std::string& name) const;
+  virtual int64_t getCounterValue(const std::string& name) const;
   enum class AggregationType { SUM, AVG };
-  TORCH_API void setAggregationType(
+  void setAggregationType(
       const std::string& stat_name,
       AggregationType type);
   ~LockingLogger() {}
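Here the per-member TORCH_API markers are folded into a single class-level annotation: marking the class exports every member it declares, so repeating the macro on each method is redundant. A small sketch of the equivalence (illustrative only; EXPORT_API is a hypothetical stand-in for the real macro, stubbed out so the snippet compiles on its own):

// Hypothetical macro; a real build would define it to __declspec(dllexport)/
// dllimport on MSVC or __attribute__((visibility("default"))) elsewhere.
#ifndef EXPORT_API
#define EXPORT_API
#endif

// Before: each entry point annotated individually.
class CounterA {
 public:
  EXPORT_API void add(int v);
  EXPORT_API int value() const;
};

// After: one annotation on the class covers both members.
class EXPORT_API CounterB {
 public:
  void add(int v);
  int value() const;
};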
@@ -22,7 +22,7 @@ namespace script {
 
 enum NoneStatus { ALWAYS, MAYBE, NEVER };
 
-struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
+struct TORCH_API SugaredValue : public std::enable_shared_from_this<SugaredValue> {
   // what is this node? for error reporting (e.g. Module, python function)
   virtual std::string kind() const = 0;
 
@@ -39,7 +39,7 @@ void badArgType(const T& v) {
 thread_local std::shared_ptr<TracingState> tracing_state;
 } // namespace detail
 
-TORCH_API std::function<void()> pauseTracing() {
+std::function<void()> pauseTracing() {
   // NOLINTNEXTLINE
   std::shared_ptr<tracer::TracingState> state = getTracingState();
   tracer::setTracingState(nullptr);
@@ -17,6 +17,7 @@
 #include <torch/csrc/generic/utils.cpp>
 #include <TH/THGenerateHalfType.h>
 
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/generic/utils.cpp>
 #include <TH/THGenerateBoolType.h>
 
@@ -1,7 +1,7 @@
 #pragma once
 
 #include <ATen/core/functional.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <ATen/ATen.h>
 #include <utility>
 
@@ -59,16 +59,16 @@ struct TensorGroup {
 // enough tensors for all data types until the size_limit, and then split
 // the accumulated tensors into different groups by data types, therefore:
 // it will output: {{tensor_a}, {tensor_b}, {tensor_c}}
-std::vector<TensorGroup> take_tensors(
+TORCH_API std::vector<TensorGroup> take_tensors(
     at::TensorList tensors,
    size_t size_limit,
    bool fine_grained = false);
 
-void reorder_tensors_like(std::vector<at::Tensor>& tensors, at::TensorList order);
+TORCH_API void reorder_tensors_like(std::vector<at::Tensor>& tensors, at::TensorList order);
 
-std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(at::TensorList tensors);
+TORCH_API std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(at::TensorList tensors);
 
-std::vector<at::Tensor> unflatten_sparse_tensors(
+TORCH_API std::vector<at::Tensor> unflatten_sparse_tensors(
     const at::Tensor& flat_indices,
    const at::Tensor& flat_values,
    at::TensorList tensors);