Use torch:: instead of at:: in all C++ APIs (#13523)

Summary:
In TorchScript and C++ extensions we currently advocate a mix of `torch::` and `at::` namespace usage. In the C++ frontend I had instead exported all symbols from `at::` and some from `c10::` into the `torch::` namespace. This is far, far easier for users to understand, and also avoids bugs around creating tensors vs. variables. From now on, the same should be true for the TorchScript C++ API (for loading and running models) and all C++ extensions.
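To make the convention concrete, here is a minimal sketch of the usage this change standardizes on (assuming the usual `<torch/torch.h>` umbrella header; the specific factory calls are only illustrative). Factory functions from the `torch::` namespace create autograd-aware tensors (variables), so `requires_grad` behaves the way users expect:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // torch:: factories produce variables, so gradients can flow through them.
  torch::Tensor a = torch::ones({2, 2}, torch::requires_grad());
  torch::Tensor b = torch::randn({2, 2});

  auto c = (a + b).sum();
  c.backward();  // a.grad() now holds the gradient of c w.r.t. a.

  std::cout << a.grad() << std::endl;
  return 0;
}
```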

Note that since we're just talking about typedefs, this change does not break any existing code.

Once this lands I will update the corresponding code in `pytorch/tutorials` too.

zdevito ezyang gchanan
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13523

Differential Revision: D12942787

Pulled By: goldsborough

fbshipit-source-id: 76058936bd8707b33d9e5bbc2d0705fc3d820763
Author: Peter Goldsborough
Date: 2018-11-06 14:28:20 -08:00
Committed by: Facebook Github Bot
Parent: be424de869
Commit: 393ad6582d
90 changed files with 158 additions and 164 deletions

View File

@@ -61,15 +61,15 @@ a taste of this interface:
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/autograd/function.h>
-at::Tensor a = torch::ones({2, 2}, at::requires_grad());
-at::Tensor b = torch::randn({2, 2});
+torch::Tensor a = torch::ones({2, 2}, torch::requires_grad());
+torch::Tensor b = torch::randn({2, 2});
 auto c = a + b;
 c.backward(); // a.grad() will now hold the gradient of c w.r.t. a.
 The ``at::Tensor`` class in ATen is not differentiable by default. To add the
 differentiability of tensors the autograd API provides, you must use tensor
 factory functions from the `torch::` namespace instead of the `at` namespace.
-For example, while a tensor created with `at::ones` will not be differentiable,
+For example, while a tensor created with `torch::ones` will not be differentiable,
 a tensor created with `torch::ones` will be.
 C++ Frontend

View File

@@ -6,7 +6,7 @@ configuration files required to depend on PyTorch. We call this distribution
 *LibTorch*, and you can download ZIP archives containing the latest LibTorch
 distribution on `our website <https://pytorch.org/get-started/locally/>`_. Below
 is a small example of writing a minimal application that depends on LibTorch
-and uses the ``at::Tensor`` class which comes with the PyTorch C++ API.
+and uses the ``torch::Tensor`` class which comes with the PyTorch C++ API.
 Minimal Example
 ---------------
@@ -37,7 +37,7 @@ this:
 target_link_libraries(example-app "${TORCH_LIBRARIES}")
 set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
-The implementation of our example will simply create a new `at::Tensor` and
+The implementation of our example will simply create a new `torch::Tensor` and
 print it:
 .. code-block:: cpp
@@ -46,7 +46,7 @@ print it:
 #include <iostream>
 int main() {
-at::Tensor tensor = torch::rand({2, 3});
+torch::Tensor tensor = torch::rand({2, 3});
 std::cout << tensor << std::endl;
 }

View File

@@ -2,7 +2,7 @@
 #include <torch/nn/cursor.h>
 #include <torch/nn/module.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <test/cpp/api/support.h>

View File

@@ -3,7 +3,7 @@
 #include <torch/data.h>
 #include <torch/data/detail/sequencers.h>
 #include <torch/serialize.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <test/cpp/api/support.h>

View File

@@ -7,7 +7,7 @@
 #include <torch/optim/adam.h>
 #include <torch/optim/optimizer.h>
 #include <torch/optim/sgd.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <test/cpp/api/support.h>

View File

@@ -1,7 +1,7 @@
 #include <gtest/gtest.h>
 #include <torch/jit.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <string>

View File

@@ -3,7 +3,7 @@
 #include <torch/csrc/utils/tempfile.h>
 #include <torch/nn/init.h>
 #include <torch/nn/modules/linear.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <test/cpp/api/support.h>

View File

@@ -3,7 +3,7 @@
 #include <torch/nn/module.h>
 #include <torch/nn/modules/linear.h>
 #include <torch/nn/modules/rnn.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <test/cpp/api/support.h>
@@ -54,8 +54,8 @@ TEST_F(ModuleTest, ZeroGrad) {
 TEST_F(ModuleTest, ZeroGradWithUndefined) {
 struct TestModule : torch::nn::Module {
 TestModule() {
-x = register_parameter("x", torch::ones(5, at::requires_grad()));
-y = register_parameter("y", torch::ones(5, at::requires_grad()));
+x = register_parameter("x", torch::ones(5, torch::requires_grad()));
+y = register_parameter("y", torch::ones(5, torch::requires_grad()));
 }
 torch::Tensor x, y;
 };
@@ -194,7 +194,7 @@ TEST_F(ModuleTest, Conversion_MultiCUDA) {
 ASSERT_EQ(parameter->device().type(), torch::Device::Type::CUDA);
 ASSERT_EQ(parameter->device().index(), 0);
 }
-module->to({at::kCUDA, 1});
+module->to({torch::kCUDA, 1});
 for (auto& parameter : module->parameters()) {
 ASSERT_EQ(parameter->device().type(), torch::Device::Type::CUDA);
 ASSERT_EQ(parameter->device().index(), 1);

View File

@@ -7,7 +7,7 @@
 #include <torch/nn/modules/embedding.h>
 #include <torch/nn/modules/functional.h>
 #include <torch/nn/modules/linear.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <test/cpp/api/support.h>

View File

@@ -5,7 +5,7 @@
 #include <torch/nn/modules/linear.h>
 #include <torch/nn/modules/sequential.h>
 #include <torch/optim.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <test/cpp/api/optim_baseline.h>

View File

@@ -1,6 +1,6 @@
 // @generated from test/cpp/api/optim_baseline.py
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <vector>

View File

@@ -9,7 +9,7 @@ import torch.optim
 HEADER = """
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <vector>

View File

@@ -5,7 +5,7 @@
 #include <torch/nn/modules/linear.h>
 #include <torch/nn/parallel/data_parallel.h>
 #include <torch/nn/pimpl.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <test/cpp/api/support.h>

View File

@@ -3,7 +3,7 @@
 #include <torch/nn/modules/linear.h>
 #include <torch/nn/modules/rnn.h>
 #include <torch/optim/adam.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <test/cpp/api/support.h>

View File

@@ -7,7 +7,7 @@
 #include <torch/nn/modules/linear.h>
 #include <torch/nn/modules/rnn.h>
 #include <torch/nn/modules/sequential.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <algorithm>

View File

@@ -6,7 +6,7 @@
 #include <torch/optim/optimizer.h>
 #include <torch/optim/sgd.h>
 #include <torch/serialize.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <test/cpp/api/support.h>

View File

@@ -5,7 +5,7 @@
 #include <gtest/gtest.h>
 #include <torch/nn/cloneable.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <string>

View File

@@ -1,6 +1,6 @@
 #include <gtest/gtest.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <ATen/ATen.h>

View File

@@ -1,6 +1,6 @@
 #include <gtest/gtest.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <ATen/Context.h>
 #include <ATen/Functions.h>

View File

@@ -5,10 +5,10 @@
 // into one shared library.
 void sigmoid_add_cuda(const float* x, const float* y, float* output, int size);
-at::Tensor sigmoid_add(at::Tensor x, at::Tensor y) {
+torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
 AT_CHECK(x.type().is_cuda(), "x must be a CUDA tensor");
 AT_CHECK(y.type().is_cuda(), "y must be a CUDA tensor");
-auto output = at::zeros_like(x);
+auto output = torch::zeros_like(x);
 sigmoid_add_cuda(
 x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
 return output;

View File

@@ -1,53 +1,53 @@
 /*
 * CuDNN ReLU extension. Simple function but contains the general structure of
 * most CuDNN extensions:
-* 1) Check arguments. at::check* functions provide a standard way to validate
-* input and provide pretty errors.
-* 2) Create descriptors. Most CuDNN functions require creating and setting a
-* variety of descriptors.
-* 3) Apply the CuDNN function.
-* 4) Destroy your descriptors.
-* 5) Return something (optional).
+* 1) Check arguments. torch::check* functions provide a standard way to
+* validate input and provide pretty errors. 2) Create descriptors. Most CuDNN
+* functions require creating and setting a variety of descriptors. 3) Apply the
+* CuDNN function. 4) Destroy your descriptors. 5) Return something (optional).
 */
 #include <torch/extension.h>
-#include <ATen/cudnn/Descriptors.h> // for TensorDescriptor
 #include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK
+#include <ATen/cudnn/Descriptors.h> // for TensorDescriptor
 #include <ATen/cudnn/Handle.h> // for getCudnnHandle
 // Name of function in python module and name used for error messages by
-// at::check* functions.
+// torch::check* functions.
 const char* cudnn_relu_name = "cudnn_relu";
 // Check arguments to cudnn_relu
-void cudnn_relu_check(const at::Tensor& inputs, const at::Tensor& outputs) {
+void cudnn_relu_check(
+const torch::Tensor& inputs,
+const torch::Tensor& outputs) {
 // Create TensorArgs. These record the names and positions of each tensor as a
 // parameter.
-at::TensorArg arg_inputs(inputs, "inputs", 0);
-at::TensorArg arg_outputs(outputs, "outputs", 1);
+torch::TensorArg arg_inputs(inputs, "inputs", 0);
+torch::TensorArg arg_outputs(outputs, "outputs", 1);
 // Check arguments. No need to return anything. These functions with throw an
 // error if they fail. Messages are populated using information from
 // TensorArgs.
-at::checkContiguous(cudnn_relu_name, arg_inputs);
-at::checkScalarType(cudnn_relu_name, arg_inputs, at::kFloat);
-at::checkBackend(cudnn_relu_name, arg_inputs.tensor, at::Backend::CUDA);
-at::checkContiguous(cudnn_relu_name, arg_outputs);
-at::checkScalarType(cudnn_relu_name, arg_outputs, at::kFloat);
-at::checkBackend(cudnn_relu_name, arg_outputs.tensor, at::Backend::CUDA);
-at::checkSameSize(cudnn_relu_name, arg_inputs, arg_outputs);
+torch::checkContiguous(cudnn_relu_name, arg_inputs);
+torch::checkScalarType(cudnn_relu_name, arg_inputs, torch::kFloat);
+torch::checkBackend(cudnn_relu_name, arg_inputs.tensor, torch::Backend::CUDA);
+torch::checkContiguous(cudnn_relu_name, arg_outputs);
+torch::checkScalarType(cudnn_relu_name, arg_outputs, torch::kFloat);
+torch::checkBackend(
+cudnn_relu_name, arg_outputs.tensor, torch::Backend::CUDA);
+torch::checkSameSize(cudnn_relu_name, arg_inputs, arg_outputs);
 }
-void cudnn_relu(const at::Tensor& inputs, const at::Tensor& outputs) {
+void cudnn_relu(const torch::Tensor& inputs, const torch::Tensor& outputs) {
 // Most CuDNN extensions will follow a similar pattern.
 // Step 1: Check inputs. This will throw an error if inputs are invalid, so no
 // need to check return codes here.
 cudnn_relu_check(inputs, outputs);
 // Step 2: Create descriptors
-cudnnHandle_t cuDnn = at::native::getCudnnHandle();
+cudnnHandle_t cuDnn = torch::native::getCudnnHandle();
 // Note: 4 is minimum dim for a TensorDescriptor. Input and output are same
 // size and type and contiguous, so one descriptor is sufficient.
-at::native::TensorDescriptor input_tensor_desc(inputs, 4);
+torch::native::TensorDescriptor input_tensor_desc(inputs, 4);
 cudnnActivationDescriptor_t activationDesc;
 // Note: Always check return value of cudnn functions using CUDNN_CHECK
 AT_CUDNN_CHECK(cudnnCreateActivationDescriptor(&activationDesc));

View File

@@ -3,15 +3,15 @@
 struct Doubler {
 Doubler(int A, int B) {
 tensor_ =
-torch::ones({A, B}, torch::dtype(torch::kDouble).requires_grad(true));
+torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
 }
-at::Tensor forward() {
+torch::Tensor forward() {
 return tensor_ * 2;
 }
-at::Tensor get() const {
+torch::Tensor get() const {
 return tensor_;
 }
 private:
-at::Tensor tensor_;
+torch::Tensor tensor_;
 };

View File

@@ -1,26 +1,26 @@
 #include <torch/extension.h>
-at::Tensor sigmoid_add(at::Tensor x, at::Tensor y) {
+torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
 return x.sigmoid() + y.sigmoid();
 }
 struct MatrixMultiplier {
 MatrixMultiplier(int A, int B) {
 tensor_ =
-torch::ones({A, B}, torch::dtype(torch::kDouble).requires_grad(true));
+torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
 }
-at::Tensor forward(at::Tensor weights) {
+torch::Tensor forward(torch::Tensor weights) {
 return tensor_.mm(weights);
 }
-at::Tensor get() const {
+torch::Tensor get() const {
 return tensor_;
 }
 private:
-at::Tensor tensor_;
+torch::Tensor tensor_;
 };
-bool function_taking_optional(c10::optional<at::Tensor> tensor) {
+bool function_taking_optional(c10::optional<torch::Tensor> tensor) {
 return tensor.has_value();
 }

View File

@@ -5,11 +5,11 @@
 #include <cstddef>
 #include <vector>
-std::vector<at::Tensor> custom_op(
-at::Tensor tensor,
+std::vector<torch::Tensor> custom_op(
+torch::Tensor tensor,
 double scalar,
 int64_t repeat) {
-std::vector<at::Tensor> output;
+std::vector<torch::Tensor> output;
 output.reserve(repeat);
 for (int64_t i = 0; i < repeat; ++i) {
 output.push_back(tensor * scalar);

View File

@@ -15,7 +15,7 @@
 # endif
 // clang-format on
-CUSTOM_OP_API std::vector<at::Tensor> custom_op(
-at::Tensor tensor,
+CUSTOM_OP_API std::vector<torch::Tensor> custom_op(
+torch::Tensor tensor,
 double scalar,
 int64_t repeat);

View File

@@ -33,7 +33,7 @@ void get_operator_from_registry_and_execute() {
 torch::jit::Stack stack;
 torch::jit::push(stack, torch::ones(5), 2.0, 3);
 op->getOperation()(stack);
-std::vector<at::Tensor> output;
+std::vector<torch::Tensor> output;
 torch::jit::pop(stack, output);
 const auto manual = custom_op(torch::ones(5), 2.0, 3);
@@ -99,19 +99,19 @@ void test_move_to_device(const std::string& path_to_exported_script_module) {
 torch::jit::load(path_to_exported_script_module);
 AT_ASSERT(module != nullptr);
-helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
+helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
 return tensor.device().is_cpu();
 });
-module->to(at::kCUDA);
-helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
+module->to(torch::kCUDA);
+helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
 return tensor.device().is_cuda();
 });
-module->to(at::kCPU);
-helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
+module->to(torch::kCPU);
+helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
 return tensor.device().is_cpu();
 });
 }
@@ -121,16 +121,16 @@ void test_move_to_dtype(const std::string& path_to_exported_script_module) {
 torch::jit::load(path_to_exported_script_module);
 AT_ASSERT(module != nullptr);
-module->to(at::kInt);
-helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
-return tensor.dtype() == at::kInt;
+module->to(torch::kInt);
+helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
+return tensor.dtype() == torch::kInt;
 });
-module->to(at::kDouble);
-helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
-return tensor.dtype() == at::kDouble;
+module->to(torch::kDouble);
+helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
+return tensor.dtype() == torch::kDouble;
 });
 }
@@ -147,7 +147,7 @@ int main(int argc, const char* argv[]) {
 test_argument_checking_for_serialized_modules(path_to_exported_script_module);
 test_move_to_dtype(path_to_exported_script_module);
-if (at::globalContext().getNumGPUs() > 0) {
+if (torch::globalContext().getNumGPUs() > 0) {
 test_move_to_device(path_to_exported_script_module);
 }

View File

@@ -149,7 +149,7 @@ class TestCppExtension(common.TestCase):
 def test_inline_jit_compile_extension_with_functions_as_list(self):
 cpp_source = '''
-at::Tensor tanh_add(at::Tensor x, at::Tensor y) {
+torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
 return x.tanh() + y.tanh();
 }
 '''
@@ -170,7 +170,7 @@ class TestCppExtension(common.TestCase):
 def test_inline_jit_compile_extension_with_functions_as_dict(self):
 cpp_source = '''
-at::Tensor tanh_add(at::Tensor x, at::Tensor y) {
+torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
 return x.tanh() + y.tanh();
 }
 '''
@@ -186,14 +186,14 @@ class TestCppExtension(common.TestCase):
 def test_inline_jit_compile_extension_multiple_sources_and_no_functions(self):
 cpp_source1 = '''
-at::Tensor sin_add(at::Tensor x, at::Tensor y) {
+torch::Tensor sin_add(torch::Tensor x, torch::Tensor y) {
 return x.sin() + y.sin();
 }
 '''
 cpp_source2 = '''
 #include <torch/extension.h>
-at::Tensor sin_add(at::Tensor x, at::Tensor y);
+torch::Tensor sin_add(torch::Tensor x, torch::Tensor y);
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 m.def("sin_add", &sin_add, "sin(x) + sin(y)");
 }
@@ -224,8 +224,8 @@ class TestCppExtension(common.TestCase):
 }
 }
-at::Tensor cos_add(at::Tensor x, at::Tensor y) {
-auto output = at::zeros_like(x);
+torch::Tensor cos_add(torch::Tensor x, torch::Tensor y) {
+auto output = torch::zeros_like(x);
 const int threads = 1024;
 const int blocks = (output.numel() + threads - 1) / threads;
 cos_add_kernel<<<blocks, threads>>>(x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
@@ -234,7 +234,7 @@ class TestCppExtension(common.TestCase):
 '''
 # Here, the C++ source need only declare the function signature.
-cpp_source = 'at::Tensor cos_add(at::Tensor x, at::Tensor y);'
+cpp_source = 'torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);'
 module = torch.utils.cpp_extension.load_inline(
 name='inline_jit_extension_cuda',
@@ -258,7 +258,7 @@ class TestCppExtension(common.TestCase):
 def test_lenient_flag_handling_in_jit_extensions(self):
 cpp_source = '''
-at::Tensor tanh_add(at::Tensor x, at::Tensor y) {
+torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
 return x.tanh() + y.tanh();
 }
 '''
@@ -303,8 +303,8 @@ class TestCppExtension(common.TestCase):
 }
 }
-at::Tensor half_test(at::Tensor input) {
-auto output = at::empty(1, input.options().dtype(at::kFloat));
+torch::Tensor half_test(torch::Tensor input) {
+auto output = torch::empty(1, input.options().dtype(torch::kFloat));
 AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] {
 half_test_kernel<scalar_t><<<1, 1>>>(
 input.data<scalar_t>(),
@@ -316,7 +316,7 @@ class TestCppExtension(common.TestCase):
 module = torch.utils.cpp_extension.load_inline(
 name='half_test_extension',
-cpp_sources='at::Tensor half_test(at::Tensor input);',
+cpp_sources='torch::Tensor half_test(torch::Tensor input);',
 cuda_sources=cuda_source,
 functions=['half_test'],
 verbose=True)

View File

@@ -6,5 +6,5 @@
 #include <torch/nn.h>
 #include <torch/optim.h>
 #include <torch/serialize.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>

View File

@@ -6,7 +6,7 @@
 #include <torch/data/iterator.h>
 #include <torch/data/samplers/random.h>
 #include <torch/data/worker_exception.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/csrc/utils/memory.h>
 #include <torch/csrc/utils/variadic.h>

View File

@@ -1,7 +1,7 @@
 #pragma once
 #include <torch/arg.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <chrono>
 #include <cstddef>

View File

@@ -1,7 +1,7 @@
 #pragma once
 #include <torch/data/example.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <ATen/core/ArrayRef.h>

View File

@@ -1,7 +1,7 @@
 #pragma once
 #include <torch/data/datasets/base.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <ATen/core/ArrayRef.h>

View File

@@ -2,7 +2,7 @@
 #include <torch/data/datasets/base.h>
 #include <torch/data/example.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstddef>
 #include <string>

View File

@@ -2,7 +2,7 @@
 #include <torch/data/datasets/base.h>
 #include <torch/data/example.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstddef>
 #include <vector>

View File

@@ -1,7 +1,7 @@
 #pragma once
 #include <torch/data/detail/queue.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Optional.h>

View File

@@ -1,6 +1,6 @@
 #pragma once
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <c10/util/Exception.h>

View File

@@ -1,6 +1,6 @@
 #pragma once
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <algorithm>
 #include <cstddef>

View File

@@ -1,6 +1,6 @@
 #pragma once
-#include <torch/tensor.h>
+#include <torch/types.h>
 namespace torch {
 namespace data {

View File

@@ -1,7 +1,7 @@
 #pragma once
 #include <torch/csrc/utils/variadic.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <c10/util/Exception.h>

View File

@@ -1,6 +1,6 @@
 #pragma once
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstddef>
 #include <vector>

View File

@@ -1,7 +1,7 @@
 #pragma once
 #include <torch/data/samplers/base.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstddef>
 #include <vector>

View File

@@ -1,7 +1,7 @@
 #pragma once
 #include <torch/data/samplers/base.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstddef>
 #include <vector>

View File

@@ -2,7 +2,7 @@
 #include <torch/data/samplers/base.h>
 #include <torch/data/samplers/custom_batch_request.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstddef>

View File

@@ -1,6 +1,6 @@
 #pragma once
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <utility>
 #include <vector>

View File

@@ -2,7 +2,7 @@
 #include <torch/data/example.h>
 #include <torch/data/transforms/collate.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <utility>
 #include <vector>

View File

@@ -2,7 +2,7 @@
 #include <torch/data/example.h>
 #include <torch/data/transforms/base.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <functional>
 #include <utility>

View File

@@ -1,7 +1,7 @@
 #pragma once
 #include <torch/csrc/utils/variadic.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstdint>
 #include <type_traits>

View File

@@ -1,7 +1,7 @@
 #pragma once
 #include <torch/nn/module.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <ATen/OptionsGuard.h>

View File

@@ -1,6 +1,6 @@
 #pragma once
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstddef>
 #include <iterator>

View File

@@ -1,6 +1,6 @@
 #pragma once
-#include <torch/tensor.h>
+#include <torch/types.h>
 namespace torch {
 namespace nn {

View File

@@ -4,7 +4,7 @@
 #include <torch/nn/cursor.h>
 #include <torch/nn/pimpl.h>
 #include <torch/serialize/archive.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <ATen/ATen.h>

View File

@@ -3,7 +3,7 @@
 #include <torch/detail/static.h>
 #include <torch/nn/module.h>
 #include <torch/nn/pimpl.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/utils/memory.h>

View File

@@ -2,7 +2,7 @@
 #include <torch/nn/cloneable.h>
 #include <torch/nn/pimpl.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstdint>

View File

@@ -3,7 +3,7 @@
 #include <torch/expanding_array.h>
 #include <torch/nn/cloneable.h>
 #include <torch/nn/pimpl.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstddef>
 #include <vector>

View File

@@ -2,7 +2,7 @@
 #include <torch/nn/cloneable.h>
 #include <torch/nn/pimpl.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstddef>
 #include <vector>

View File

@@ -2,7 +2,7 @@
 #include <torch/nn/cloneable.h>
 #include <torch/nn/pimpl.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstddef>
 #include <vector>

View File

@@ -3,7 +3,7 @@
 #include <torch/csrc/utils/variadic.h>
 #include <torch/nn/cloneable.h>
 #include <torch/nn/pimpl.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <functional>
 #include <utility>

View File

@@ -3,7 +3,7 @@
 #include <torch/nn/cloneable.h>
 #include <torch/nn/module.h>
 #include <torch/nn/pimpl.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstddef>
 #include <vector>

View File

@@ -3,7 +3,7 @@
 #include <torch/nn/cloneable.h>
 #include <torch/nn/modules/dropout.h>
 #include <torch/nn/pimpl.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <ATen/ATen.h>
 #include <c10/util/Exception.h>

View File

@@ -5,7 +5,7 @@
 #include <torch/nn/module.h>
 #include <torch/nn/modules/any.h>
 #include <torch/nn/pimpl.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <c10/util/Exception.h>

View File

@@ -3,7 +3,7 @@
 #include <torch/cuda.h>
 #include <torch/nn/module.h>
 #include <torch/nn/pimpl.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/csrc/autograd/functions/comm.h>
 #include <torch/csrc/cuda/comm.h>

View File

@@ -2,7 +2,7 @@
 #include <torch/arg.h>
 #include <torch/serialize/archive.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/csrc/utils/variadic.h>

View File

@@ -3,7 +3,7 @@
 #include <torch/nn/pimpl.h>
 #include <torch/optim/optimizer.h>
 #include <torch/optim/serialize.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <utility>
 #include <vector>

View File

@@ -4,7 +4,7 @@
 #include <torch/nn/module.h>
 #include <torch/optim/optimizer.h>
 #include <torch/optim/serialize.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <functional>
 #include <memory>

View File

@@ -1,7 +1,7 @@
 #pragma once
 #include <torch/serialize/archive.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstddef>
 #include <cstdint>

View File

@@ -3,7 +3,7 @@
 #include <torch/arg.h>
 #include <torch/nn/module.h>
 #include <torch/optim/optimizer.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstddef>
 #include <utility>

View File

@@ -1,11 +1,11 @@
 #pragma once
 #include <torch/detail/static.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/csrc/python_headers.h>
 #include <torch/csrc/utils/pybind.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <iterator>
 #include <string>

View File

@@ -13,9 +13,6 @@ using namespace at; // NOLINT
 using c10::optional;
 using c10::nullopt;
-using c10::optional;
-using c10::nullopt;
 using Dtype = at::ScalarType;
 /// Fixed width dtypes.

View File

@@ -1,7 +1,7 @@
 #include <torch/data/datasets/mnist.h>
 #include <torch/data/example.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <c10/util/Exception.h>

View File

@@ -1,6 +1,6 @@
 #include <torch/data/samplers/random.h>
 #include <torch/serialize/archive.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <algorithm>
 #include <cstddef>

View File

@@ -1,6 +1,6 @@
 #include <torch/data/samplers/sequential.h>
 #include <torch/serialize/archive.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <algorithm>
 #include <cstddef>

View File

@@ -1,6 +1,6 @@
 #include <torch/data/samplers/stream.h>
 #include <torch/serialize/archive.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <c10/util/Exception.h>

View File

@@ -1,7 +1,7 @@
 #include <torch/nn/cursor.h>
 #include <torch/nn/module.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <algorithm>
 #include <cstdint>

View File

@@ -1,6 +1,6 @@
 #include <torch/nn/init.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <ATen/ATen.h>

View File

@@ -1,7 +1,7 @@
 #include <torch/nn/modules/batchnorm.h>
 #include <torch/cuda.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <c10/util/Exception.h>

View File

@@ -1,7 +1,7 @@
 #include <torch/nn/modules/conv.h>
 #include <torch/expanding_array.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <cmath>

View File

@@ -1,6 +1,6 @@
 #include <torch/nn/modules/dropout.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <c10/util/Exception.h>

View File

@@ -1,6 +1,6 @@
 #include <torch/nn/modules/embedding.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <cstddef>

View File

@@ -1,6 +1,6 @@
 #include <torch/nn/modules/functional.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <functional>
 #include <utility>

View File

@@ -1,6 +1,6 @@
 #include <torch/nn/modules/linear.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <cmath>

View File

@@ -1,7 +1,7 @@
 #include <torch/nn/modules/rnn.h>
 #include <torch/nn/modules/dropout.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <c10/util/Exception.h>

View File

@@ -3,7 +3,7 @@
 #include <torch/csrc/autograd/generated/variable_factories.h>
 #include <torch/nn/cursor.h>
 #include <torch/serialize/archive.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <string>
 #include <utility>

View File

@@ -1,7 +1,7 @@
 #include <torch/optim/serialize.h>
 #include <torch/serialize/archive.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <cstddef>
 #include <cstdint>

View File

@@ -4,7 +4,7 @@
 #include <torch/nn/pimpl.h>
 #include <torch/optim/optimizer.h>
 #include <torch/optim/serialize.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <ATen/ATen.h>

View File

@@ -1,6 +1,6 @@
 #include <torch/serialize/input-archive.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <torch/csrc/jit/import.h>

View File

@@ -1,6 +1,6 @@
 #include <torch/serialize/output-archive.h>
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/utils.h>
 #include <torch/csrc/jit/export.h>

View File

@@ -1,4 +1,4 @@
-#include <torch/tensor.h>
+#include <torch/types.h>
 #include <torch/serialize/archive.h>
 namespace torch {

View File

@@ -423,7 +423,6 @@ private:
 template<TypeKind K, typename T>
 struct SingleElementType : public Type {
 static const TypeKind Kind = K;
-static constexpr bool is_singleton = true;
 TypePtr getElementType() const {
 return elem;
 }
@@ -488,9 +487,6 @@ struct FutureType;
 using FutureTypePtr = std::shared_ptr<FutureType>;
 struct TORCH_API FutureType : public Type {
-// It's not exactly a singleton, but there should be exactly once instance of
-// Future[T] for every T
-static constexpr bool is_singleton = true;
 friend struct Type;
 template<typename ... T>
 static FutureTypePtr create(TypePtr elem) {

View File

@@ -1,5 +1,6 @@
 #pragma once
+#include <torch/csrc/api/include/torch/types.h>
 #include <torch/csrc/autograd/generated/variable_factories.h>
 #include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/import.h>

View File

@@ -633,13 +633,13 @@ def load_inline(name,
 as its docstring.
 The sources in ``cuda_sources`` are concatenated into a separate ``.cu``
-file and prepended with ``ATen/ATen.h``, ``cuda.h`` and ``cuda_runtime.h``
-includes. The ``.cpp`` and ``.cu`` files are compiled separately, but
-ultimately linked into a single library. Note that no bindings are
-generated for functions in ``cuda_sources`` per se. To bind to a CUDA
-kernel, you must create a C++ function that calls it, and either declare or
-define this C++ function in one of the ``cpp_sources`` (and include its
-name in ``functions``).
+file and prepended with ``torch/types.h``, ``cuda.h`` and
+``cuda_runtime.h`` includes. The ``.cpp`` and ``.cu`` files are compiled
+separately, but ultimately linked into a single library. Note that no
+bindings are generated for functions in ``cuda_sources`` per se. To bind
+to a CUDA kernel, you must create a C++ function that calls it, and either
+declare or define this C++ function in one of the ``cpp_sources`` (and
+include its name in ``functions``).
 See :func:`load` for a description of arguments omitted below.
@@ -702,7 +702,7 @@ def load_inline(name,
 sources = [cpp_source_path]
 if cuda_sources:
-cuda_sources.insert(0, '#include <ATen/ATen.h>')
+cuda_sources.insert(0, '#include <torch/types.h>')
 cuda_sources.insert(1, '#include <cuda.h>')
 cuda_sources.insert(2, '#include <cuda_runtime.h>')