mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use torch:: instead of at:: in all C++ APIs (#13523)
Summary: In TorchScript and C++ extensions we currently advocate a mix of `torch::` and `at::` namespace usage. In the C++ frontend I had instead exported all symbols from `at::` and some from `c10::` into the `torch::` namespace. This is far, far easier for users to understand, and also avoid bugs around creating tensors vs. variables. The same should from now on be true for the TorchScript C++ API (for running and loading models) and all C++ extensions. Note that since we're just talking about typedefs, this change does not break any existing code. Once this lands I will update stuff in `pytorch/tutorials` too. zdevito ezyang gchanan Pull Request resolved: https://github.com/pytorch/pytorch/pull/13523 Differential Revision: D12942787 Pulled By: goldsborough fbshipit-source-id: 76058936bd8707b33d9e5bbc2d0705fc3d820763
This commit is contained in:
committed by
Facebook Github Bot
parent
be424de869
commit
393ad6582d
@ -61,15 +61,15 @@ a taste of this interface:
|
||||
#include <torch/csrc/autograd/variable.h>
|
||||
#include <torch/csrc/autograd/function.h>
|
||||
|
||||
at::Tensor a = torch::ones({2, 2}, at::requires_grad());
|
||||
at::Tensor b = torch::randn({2, 2});
|
||||
torch::Tensor a = torch::ones({2, 2}, torch::requires_grad());
|
||||
torch::Tensor b = torch::randn({2, 2});
|
||||
auto c = a + b;
|
||||
c.backward(); // a.grad() will now hold the gradient of c w.r.t. a.
|
||||
|
||||
The ``at::Tensor`` class in ATen is not differentiable by default. To add the
|
||||
differentiability of tensors the autograd API provides, you must use tensor
|
||||
factory functions from the `torch::` namespace instead of the `at` namespace.
|
||||
For example, while a tensor created with `at::ones` will not be differentiable,
|
||||
For example, while a tensor created with `torch::ones` will not be differentiable,
|
||||
a tensor created with `torch::ones` will be.
|
||||
|
||||
C++ Frontend
|
||||
|
@ -6,7 +6,7 @@ configuration files required to depend on PyTorch. We call this distribution
|
||||
*LibTorch*, and you can download ZIP archives containing the latest LibTorch
|
||||
distribution on `our website <https://pytorch.org/get-started/locally/>`_. Below
|
||||
is a small example of writing a minimal application that depends on LibTorch
|
||||
and uses the ``at::Tensor`` class which comes with the PyTorch C++ API.
|
||||
and uses the ``torch::Tensor`` class which comes with the PyTorch C++ API.
|
||||
|
||||
Minimal Example
|
||||
---------------
|
||||
@ -37,7 +37,7 @@ this:
|
||||
target_link_libraries(example-app "${TORCH_LIBRARIES}")
|
||||
set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
|
||||
|
||||
The implementation of our example will simply create a new `at::Tensor` and
|
||||
The implementation of our example will simply create a new `torch::Tensor` and
|
||||
print it:
|
||||
|
||||
.. code-block:: cpp
|
||||
@ -46,7 +46,7 @@ print it:
|
||||
#include <iostream>
|
||||
|
||||
int main() {
|
||||
at::Tensor tensor = torch::rand({2, 3});
|
||||
torch::Tensor tensor = torch::rand({2, 3});
|
||||
std::cout << tensor << std::endl;
|
||||
}
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#include <torch/nn/cursor.h>
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <test/cpp/api/support.h>
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <torch/data.h>
|
||||
#include <torch/data/detail/sequencers.h>
|
||||
#include <torch/serialize.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <test/cpp/api/support.h>
|
||||
|
||||
|
@ -7,7 +7,7 @@
|
||||
#include <torch/optim/adam.h>
|
||||
#include <torch/optim/optimizer.h>
|
||||
#include <torch/optim/sgd.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <test/cpp/api/support.h>
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/jit.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <torch/csrc/utils/tempfile.h>
|
||||
#include <torch/nn/init.h>
|
||||
#include <torch/nn/modules/linear.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <test/cpp/api/support.h>
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/nn/modules/linear.h>
|
||||
#include <torch/nn/modules/rnn.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <test/cpp/api/support.h>
|
||||
@ -54,8 +54,8 @@ TEST_F(ModuleTest, ZeroGrad) {
|
||||
TEST_F(ModuleTest, ZeroGradWithUndefined) {
|
||||
struct TestModule : torch::nn::Module {
|
||||
TestModule() {
|
||||
x = register_parameter("x", torch::ones(5, at::requires_grad()));
|
||||
y = register_parameter("y", torch::ones(5, at::requires_grad()));
|
||||
x = register_parameter("x", torch::ones(5, torch::requires_grad()));
|
||||
y = register_parameter("y", torch::ones(5, torch::requires_grad()));
|
||||
}
|
||||
torch::Tensor x, y;
|
||||
};
|
||||
@ -194,7 +194,7 @@ TEST_F(ModuleTest, Conversion_MultiCUDA) {
|
||||
ASSERT_EQ(parameter->device().type(), torch::Device::Type::CUDA);
|
||||
ASSERT_EQ(parameter->device().index(), 0);
|
||||
}
|
||||
module->to({at::kCUDA, 1});
|
||||
module->to({torch::kCUDA, 1});
|
||||
for (auto& parameter : module->parameters()) {
|
||||
ASSERT_EQ(parameter->device().type(), torch::Device::Type::CUDA);
|
||||
ASSERT_EQ(parameter->device().index(), 1);
|
||||
|
@ -7,7 +7,7 @@
|
||||
#include <torch/nn/modules/embedding.h>
|
||||
#include <torch/nn/modules/functional.h>
|
||||
#include <torch/nn/modules/linear.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <test/cpp/api/support.h>
|
||||
|
@ -5,7 +5,7 @@
|
||||
#include <torch/nn/modules/linear.h>
|
||||
#include <torch/nn/modules/sequential.h>
|
||||
#include <torch/optim.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <test/cpp/api/optim_baseline.h>
|
||||
|
@ -1,6 +1,6 @@
|
||||
// @generated from test/cpp/api/optim_baseline.py
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
|
@ -9,7 +9,7 @@ import torch.optim
|
||||
|
||||
|
||||
HEADER = """
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
|
@ -5,7 +5,7 @@
|
||||
#include <torch/nn/modules/linear.h>
|
||||
#include <torch/nn/parallel/data_parallel.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <test/cpp/api/support.h>
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <torch/nn/modules/linear.h>
|
||||
#include <torch/nn/modules/rnn.h>
|
||||
#include <torch/optim/adam.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <test/cpp/api/support.h>
|
||||
|
@ -7,7 +7,7 @@
|
||||
#include <torch/nn/modules/linear.h>
|
||||
#include <torch/nn/modules/rnn.h>
|
||||
#include <torch/nn/modules/sequential.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
@ -6,7 +6,7 @@
|
||||
#include <torch/optim/optimizer.h>
|
||||
#include <torch/optim/sgd.h>
|
||||
#include <torch/serialize.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <test/cpp/api/support.h>
|
||||
|
@ -5,7 +5,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/nn/cloneable.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <string>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <ATen/Context.h>
|
||||
#include <ATen/Functions.h>
|
||||
|
@ -5,10 +5,10 @@
|
||||
// into one shared library.
|
||||
void sigmoid_add_cuda(const float* x, const float* y, float* output, int size);
|
||||
|
||||
at::Tensor sigmoid_add(at::Tensor x, at::Tensor y) {
|
||||
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
|
||||
AT_CHECK(x.type().is_cuda(), "x must be a CUDA tensor");
|
||||
AT_CHECK(y.type().is_cuda(), "y must be a CUDA tensor");
|
||||
auto output = at::zeros_like(x);
|
||||
auto output = torch::zeros_like(x);
|
||||
sigmoid_add_cuda(
|
||||
x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
|
||||
return output;
|
||||
|
@ -1,53 +1,53 @@
|
||||
/*
|
||||
* CuDNN ReLU extension. Simple function but contains the general structure of
|
||||
* most CuDNN extensions:
|
||||
* 1) Check arguments. at::check* functions provide a standard way to validate
|
||||
* input and provide pretty errors.
|
||||
* 2) Create descriptors. Most CuDNN functions require creating and setting a
|
||||
* variety of descriptors.
|
||||
* 3) Apply the CuDNN function.
|
||||
* 4) Destroy your descriptors.
|
||||
* 5) Return something (optional).
|
||||
* 1) Check arguments. torch::check* functions provide a standard way to
|
||||
* validate input and provide pretty errors. 2) Create descriptors. Most CuDNN
|
||||
* functions require creating and setting a variety of descriptors. 3) Apply the
|
||||
* CuDNN function. 4) Destroy your descriptors. 5) Return something (optional).
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <ATen/cudnn/Descriptors.h> // for TensorDescriptor
|
||||
#include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK
|
||||
#include <ATen/cudnn/Descriptors.h> // for TensorDescriptor
|
||||
#include <ATen/cudnn/Handle.h> // for getCudnnHandle
|
||||
|
||||
// Name of function in python module and name used for error messages by
|
||||
// at::check* functions.
|
||||
// torch::check* functions.
|
||||
const char* cudnn_relu_name = "cudnn_relu";
|
||||
|
||||
// Check arguments to cudnn_relu
|
||||
void cudnn_relu_check(const at::Tensor& inputs, const at::Tensor& outputs) {
|
||||
void cudnn_relu_check(
|
||||
const torch::Tensor& inputs,
|
||||
const torch::Tensor& outputs) {
|
||||
// Create TensorArgs. These record the names and positions of each tensor as a
|
||||
// parameter.
|
||||
at::TensorArg arg_inputs(inputs, "inputs", 0);
|
||||
at::TensorArg arg_outputs(outputs, "outputs", 1);
|
||||
torch::TensorArg arg_inputs(inputs, "inputs", 0);
|
||||
torch::TensorArg arg_outputs(outputs, "outputs", 1);
|
||||
// Check arguments. No need to return anything. These functions with throw an
|
||||
// error if they fail. Messages are populated using information from
|
||||
// TensorArgs.
|
||||
at::checkContiguous(cudnn_relu_name, arg_inputs);
|
||||
at::checkScalarType(cudnn_relu_name, arg_inputs, at::kFloat);
|
||||
at::checkBackend(cudnn_relu_name, arg_inputs.tensor, at::Backend::CUDA);
|
||||
at::checkContiguous(cudnn_relu_name, arg_outputs);
|
||||
at::checkScalarType(cudnn_relu_name, arg_outputs, at::kFloat);
|
||||
at::checkBackend(cudnn_relu_name, arg_outputs.tensor, at::Backend::CUDA);
|
||||
at::checkSameSize(cudnn_relu_name, arg_inputs, arg_outputs);
|
||||
torch::checkContiguous(cudnn_relu_name, arg_inputs);
|
||||
torch::checkScalarType(cudnn_relu_name, arg_inputs, torch::kFloat);
|
||||
torch::checkBackend(cudnn_relu_name, arg_inputs.tensor, torch::Backend::CUDA);
|
||||
torch::checkContiguous(cudnn_relu_name, arg_outputs);
|
||||
torch::checkScalarType(cudnn_relu_name, arg_outputs, torch::kFloat);
|
||||
torch::checkBackend(
|
||||
cudnn_relu_name, arg_outputs.tensor, torch::Backend::CUDA);
|
||||
torch::checkSameSize(cudnn_relu_name, arg_inputs, arg_outputs);
|
||||
}
|
||||
|
||||
void cudnn_relu(const at::Tensor& inputs, const at::Tensor& outputs) {
|
||||
void cudnn_relu(const torch::Tensor& inputs, const torch::Tensor& outputs) {
|
||||
// Most CuDNN extensions will follow a similar pattern.
|
||||
// Step 1: Check inputs. This will throw an error if inputs are invalid, so no
|
||||
// need to check return codes here.
|
||||
cudnn_relu_check(inputs, outputs);
|
||||
// Step 2: Create descriptors
|
||||
cudnnHandle_t cuDnn = at::native::getCudnnHandle();
|
||||
cudnnHandle_t cuDnn = torch::native::getCudnnHandle();
|
||||
// Note: 4 is minimum dim for a TensorDescriptor. Input and output are same
|
||||
// size and type and contiguous, so one descriptor is sufficient.
|
||||
at::native::TensorDescriptor input_tensor_desc(inputs, 4);
|
||||
torch::native::TensorDescriptor input_tensor_desc(inputs, 4);
|
||||
cudnnActivationDescriptor_t activationDesc;
|
||||
// Note: Always check return value of cudnn functions using CUDNN_CHECK
|
||||
AT_CUDNN_CHECK(cudnnCreateActivationDescriptor(&activationDesc));
|
||||
|
@ -3,15 +3,15 @@
|
||||
struct Doubler {
|
||||
Doubler(int A, int B) {
|
||||
tensor_ =
|
||||
torch::ones({A, B}, torch::dtype(torch::kDouble).requires_grad(true));
|
||||
torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
|
||||
}
|
||||
at::Tensor forward() {
|
||||
torch::Tensor forward() {
|
||||
return tensor_ * 2;
|
||||
}
|
||||
at::Tensor get() const {
|
||||
torch::Tensor get() const {
|
||||
return tensor_;
|
||||
}
|
||||
|
||||
private:
|
||||
at::Tensor tensor_;
|
||||
torch::Tensor tensor_;
|
||||
};
|
||||
|
@ -1,26 +1,26 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
at::Tensor sigmoid_add(at::Tensor x, at::Tensor y) {
|
||||
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
|
||||
return x.sigmoid() + y.sigmoid();
|
||||
}
|
||||
|
||||
struct MatrixMultiplier {
|
||||
MatrixMultiplier(int A, int B) {
|
||||
tensor_ =
|
||||
torch::ones({A, B}, torch::dtype(torch::kDouble).requires_grad(true));
|
||||
torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
|
||||
}
|
||||
at::Tensor forward(at::Tensor weights) {
|
||||
torch::Tensor forward(torch::Tensor weights) {
|
||||
return tensor_.mm(weights);
|
||||
}
|
||||
at::Tensor get() const {
|
||||
torch::Tensor get() const {
|
||||
return tensor_;
|
||||
}
|
||||
|
||||
private:
|
||||
at::Tensor tensor_;
|
||||
torch::Tensor tensor_;
|
||||
};
|
||||
|
||||
bool function_taking_optional(c10::optional<at::Tensor> tensor) {
|
||||
bool function_taking_optional(c10::optional<torch::Tensor> tensor) {
|
||||
return tensor.has_value();
|
||||
}
|
||||
|
||||
|
@ -5,11 +5,11 @@
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
||||
std::vector<at::Tensor> custom_op(
|
||||
at::Tensor tensor,
|
||||
std::vector<torch::Tensor> custom_op(
|
||||
torch::Tensor tensor,
|
||||
double scalar,
|
||||
int64_t repeat) {
|
||||
std::vector<at::Tensor> output;
|
||||
std::vector<torch::Tensor> output;
|
||||
output.reserve(repeat);
|
||||
for (int64_t i = 0; i < repeat; ++i) {
|
||||
output.push_back(tensor * scalar);
|
||||
|
@ -15,7 +15,7 @@
|
||||
# endif
|
||||
// clang-format on
|
||||
|
||||
CUSTOM_OP_API std::vector<at::Tensor> custom_op(
|
||||
at::Tensor tensor,
|
||||
CUSTOM_OP_API std::vector<torch::Tensor> custom_op(
|
||||
torch::Tensor tensor,
|
||||
double scalar,
|
||||
int64_t repeat);
|
||||
|
@ -33,7 +33,7 @@ void get_operator_from_registry_and_execute() {
|
||||
torch::jit::Stack stack;
|
||||
torch::jit::push(stack, torch::ones(5), 2.0, 3);
|
||||
op->getOperation()(stack);
|
||||
std::vector<at::Tensor> output;
|
||||
std::vector<torch::Tensor> output;
|
||||
torch::jit::pop(stack, output);
|
||||
|
||||
const auto manual = custom_op(torch::ones(5), 2.0, 3);
|
||||
@ -99,19 +99,19 @@ void test_move_to_device(const std::string& path_to_exported_script_module) {
|
||||
torch::jit::load(path_to_exported_script_module);
|
||||
AT_ASSERT(module != nullptr);
|
||||
|
||||
helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
|
||||
helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
|
||||
return tensor.device().is_cpu();
|
||||
});
|
||||
|
||||
module->to(at::kCUDA);
|
||||
module->to(torch::kCUDA);
|
||||
|
||||
helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
|
||||
helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
|
||||
return tensor.device().is_cuda();
|
||||
});
|
||||
|
||||
module->to(at::kCPU);
|
||||
module->to(torch::kCPU);
|
||||
|
||||
helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
|
||||
helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
|
||||
return tensor.device().is_cpu();
|
||||
});
|
||||
}
|
||||
@ -121,16 +121,16 @@ void test_move_to_dtype(const std::string& path_to_exported_script_module) {
|
||||
torch::jit::load(path_to_exported_script_module);
|
||||
AT_ASSERT(module != nullptr);
|
||||
|
||||
module->to(at::kInt);
|
||||
module->to(torch::kInt);
|
||||
|
||||
helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
|
||||
return tensor.dtype() == at::kInt;
|
||||
helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
|
||||
return tensor.dtype() == torch::kInt;
|
||||
});
|
||||
|
||||
module->to(at::kDouble);
|
||||
module->to(torch::kDouble);
|
||||
|
||||
helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
|
||||
return tensor.dtype() == at::kDouble;
|
||||
helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
|
||||
return tensor.dtype() == torch::kDouble;
|
||||
});
|
||||
}
|
||||
|
||||
@ -147,7 +147,7 @@ int main(int argc, const char* argv[]) {
|
||||
test_argument_checking_for_serialized_modules(path_to_exported_script_module);
|
||||
test_move_to_dtype(path_to_exported_script_module);
|
||||
|
||||
if (at::globalContext().getNumGPUs() > 0) {
|
||||
if (torch::globalContext().getNumGPUs() > 0) {
|
||||
test_move_to_device(path_to_exported_script_module);
|
||||
}
|
||||
|
||||
|
@ -149,7 +149,7 @@ class TestCppExtension(common.TestCase):
|
||||
|
||||
def test_inline_jit_compile_extension_with_functions_as_list(self):
|
||||
cpp_source = '''
|
||||
at::Tensor tanh_add(at::Tensor x, at::Tensor y) {
|
||||
torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
|
||||
return x.tanh() + y.tanh();
|
||||
}
|
||||
'''
|
||||
@ -170,7 +170,7 @@ class TestCppExtension(common.TestCase):
|
||||
|
||||
def test_inline_jit_compile_extension_with_functions_as_dict(self):
|
||||
cpp_source = '''
|
||||
at::Tensor tanh_add(at::Tensor x, at::Tensor y) {
|
||||
torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
|
||||
return x.tanh() + y.tanh();
|
||||
}
|
||||
'''
|
||||
@ -186,14 +186,14 @@ class TestCppExtension(common.TestCase):
|
||||
|
||||
def test_inline_jit_compile_extension_multiple_sources_and_no_functions(self):
|
||||
cpp_source1 = '''
|
||||
at::Tensor sin_add(at::Tensor x, at::Tensor y) {
|
||||
torch::Tensor sin_add(torch::Tensor x, torch::Tensor y) {
|
||||
return x.sin() + y.sin();
|
||||
}
|
||||
'''
|
||||
|
||||
cpp_source2 = '''
|
||||
#include <torch/extension.h>
|
||||
at::Tensor sin_add(at::Tensor x, at::Tensor y);
|
||||
torch::Tensor sin_add(torch::Tensor x, torch::Tensor y);
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("sin_add", &sin_add, "sin(x) + sin(y)");
|
||||
}
|
||||
@ -224,8 +224,8 @@ class TestCppExtension(common.TestCase):
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor cos_add(at::Tensor x, at::Tensor y) {
|
||||
auto output = at::zeros_like(x);
|
||||
torch::Tensor cos_add(torch::Tensor x, torch::Tensor y) {
|
||||
auto output = torch::zeros_like(x);
|
||||
const int threads = 1024;
|
||||
const int blocks = (output.numel() + threads - 1) / threads;
|
||||
cos_add_kernel<<<blocks, threads>>>(x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
|
||||
@ -234,7 +234,7 @@ class TestCppExtension(common.TestCase):
|
||||
'''
|
||||
|
||||
# Here, the C++ source need only declare the function signature.
|
||||
cpp_source = 'at::Tensor cos_add(at::Tensor x, at::Tensor y);'
|
||||
cpp_source = 'torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);'
|
||||
|
||||
module = torch.utils.cpp_extension.load_inline(
|
||||
name='inline_jit_extension_cuda',
|
||||
@ -258,7 +258,7 @@ class TestCppExtension(common.TestCase):
|
||||
|
||||
def test_lenient_flag_handling_in_jit_extensions(self):
|
||||
cpp_source = '''
|
||||
at::Tensor tanh_add(at::Tensor x, at::Tensor y) {
|
||||
torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
|
||||
return x.tanh() + y.tanh();
|
||||
}
|
||||
'''
|
||||
@ -303,8 +303,8 @@ class TestCppExtension(common.TestCase):
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor half_test(at::Tensor input) {
|
||||
auto output = at::empty(1, input.options().dtype(at::kFloat));
|
||||
torch::Tensor half_test(torch::Tensor input) {
|
||||
auto output = torch::empty(1, input.options().dtype(torch::kFloat));
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] {
|
||||
half_test_kernel<scalar_t><<<1, 1>>>(
|
||||
input.data<scalar_t>(),
|
||||
@ -316,7 +316,7 @@ class TestCppExtension(common.TestCase):
|
||||
|
||||
module = torch.utils.cpp_extension.load_inline(
|
||||
name='half_test_extension',
|
||||
cpp_sources='at::Tensor half_test(at::Tensor input);',
|
||||
cpp_sources='torch::Tensor half_test(torch::Tensor input);',
|
||||
cuda_sources=cuda_source,
|
||||
functions=['half_test'],
|
||||
verbose=True)
|
||||
|
@ -6,5 +6,5 @@
|
||||
#include <torch/nn.h>
|
||||
#include <torch/optim.h>
|
||||
#include <torch/serialize.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
@ -6,7 +6,7 @@
|
||||
#include <torch/data/iterator.h>
|
||||
#include <torch/data/samplers/random.h>
|
||||
#include <torch/data/worker_exception.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <torch/csrc/utils/memory.h>
|
||||
#include <torch/csrc/utils/variadic.h>
|
||||
|
@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/arg.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <cstddef>
|
||||
|
@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/data/example.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <ATen/core/ArrayRef.h>
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/data/datasets/base.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <ATen/core/ArrayRef.h>
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#include <torch/data/datasets/base.h>
|
||||
#include <torch/data/example.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#include <torch/data/datasets/base.h>
|
||||
#include <torch/data/example.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/data/detail/queue.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
namespace torch {
|
||||
namespace data {
|
||||
|
@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/utils/variadic.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/data/samplers/base.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/data/samplers/base.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#include <torch/data/samplers/base.h>
|
||||
#include <torch/data/samplers/custom_batch_request.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#include <torch/data/example.h>
|
||||
#include <torch/data/transforms/collate.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#include <torch/data/example.h>
|
||||
#include <torch/data/transforms/base.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/utils/variadic.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
|
@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <ATen/OptionsGuard.h>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <iterator>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
namespace torch {
|
||||
namespace nn {
|
||||
|
@ -4,7 +4,7 @@
|
||||
#include <torch/nn/cursor.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/serialize/archive.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <torch/detail/static.h>
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <torch/csrc/autograd/variable.h>
|
||||
#include <torch/csrc/utils/memory.h>
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#include <torch/nn/cloneable.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <torch/expanding_array.h>
|
||||
#include <torch/nn/cloneable.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#include <torch/nn/cloneable.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#include <torch/nn/cloneable.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <torch/csrc/utils/variadic.h>
|
||||
#include <torch/nn/cloneable.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <torch/nn/cloneable.h>
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <torch/nn/cloneable.h>
|
||||
#include <torch/nn/modules/dropout.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
@ -5,7 +5,7 @@
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/nn/modules/any.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <torch/cuda.h>
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <torch/csrc/autograd/functions/comm.h>
|
||||
#include <torch/csrc/cuda/comm.h>
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#include <torch/arg.h>
|
||||
#include <torch/serialize/archive.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <torch/csrc/utils/variadic.h>
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/optim/optimizer.h>
|
||||
#include <torch/optim/serialize.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
@ -4,7 +4,7 @@
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/optim/optimizer.h>
|
||||
#include <torch/optim/serialize.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/serialize/archive.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <torch/arg.h>
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/optim/optimizer.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <utility>
|
||||
|
@ -1,11 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/detail/static.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
|
@ -13,9 +13,6 @@ using namespace at; // NOLINT
|
||||
using c10::optional;
|
||||
using c10::nullopt;
|
||||
|
||||
using c10::optional;
|
||||
using c10::nullopt;
|
||||
|
||||
using Dtype = at::ScalarType;
|
||||
|
||||
/// Fixed width dtypes.
|
@ -1,7 +1,7 @@
|
||||
#include <torch/data/datasets/mnist.h>
|
||||
|
||||
#include <torch/data/example.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <torch/data/samplers/random.h>
|
||||
#include <torch/serialize/archive.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <torch/data/samplers/sequential.h>
|
||||
#include <torch/serialize/archive.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <torch/data/samplers/stream.h>
|
||||
#include <torch/serialize/archive.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include <torch/nn/cursor.h>
|
||||
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <torch/nn/init.h>
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include <torch/nn/modules/batchnorm.h>
|
||||
|
||||
#include <torch/cuda.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include <torch/nn/modules/conv.h>
|
||||
|
||||
#include <torch/expanding_array.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <cmath>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <torch/nn/modules/dropout.h>
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <torch/nn/modules/embedding.h>
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <cstddef>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <torch/nn/modules/functional.h>
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <torch/nn/modules/linear.h>
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <cmath>
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include <torch/nn/modules/rnn.h>
|
||||
|
||||
#include <torch/nn/modules/dropout.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
#include <torch/nn/cursor.h>
|
||||
#include <torch/serialize/archive.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include <torch/optim/serialize.h>
|
||||
|
||||
#include <torch/serialize/archive.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
@ -4,7 +4,7 @@
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/optim/optimizer.h>
|
||||
#include <torch/optim/serialize.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <torch/serialize/input-archive.h>
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <torch/csrc/jit/import.h>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <torch/serialize/output-archive.h>
|
||||
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <torch/csrc/jit/export.h>
|
||||
|
@ -1,4 +1,4 @@
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/types.h>
|
||||
#include <torch/serialize/archive.h>
|
||||
|
||||
namespace torch {
|
||||
|
@ -423,7 +423,6 @@ private:
|
||||
template<TypeKind K, typename T>
|
||||
struct SingleElementType : public Type {
|
||||
static const TypeKind Kind = K;
|
||||
static constexpr bool is_singleton = true;
|
||||
TypePtr getElementType() const {
|
||||
return elem;
|
||||
}
|
||||
@ -488,9 +487,6 @@ struct FutureType;
|
||||
using FutureTypePtr = std::shared_ptr<FutureType>;
|
||||
|
||||
struct TORCH_API FutureType : public Type {
|
||||
// It's not exactly a singleton, but there should be exactly once instance of
|
||||
// Future[T] for every T
|
||||
static constexpr bool is_singleton = true;
|
||||
friend struct Type;
|
||||
template<typename ... T>
|
||||
static FutureTypePtr create(TypePtr elem) {
|
||||
|
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/api/include/torch/types.h>
|
||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
#include <torch/csrc/jit/custom_operator.h>
|
||||
#include <torch/csrc/jit/import.h>
|
||||
|
@ -633,13 +633,13 @@ def load_inline(name,
|
||||
as its docstring.
|
||||
|
||||
The sources in ``cuda_sources`` are concatenated into a separate ``.cu``
|
||||
file and prepended with ``ATen/ATen.h``, ``cuda.h`` and ``cuda_runtime.h``
|
||||
includes. The ``.cpp`` and ``.cu`` files are compiled separately, but
|
||||
ultimately linked into a single library. Note that no bindings are
|
||||
generated for functions in ``cuda_sources`` per se. To bind to a CUDA
|
||||
kernel, you must create a C++ function that calls it, and either declare or
|
||||
define this C++ function in one of the ``cpp_sources`` (and include its
|
||||
name in ``functions``).
|
||||
file and prepended with ``torch/types.h``, ``cuda.h`` and
|
||||
``cuda_runtime.h`` includes. The ``.cpp`` and ``.cu`` files are compiled
|
||||
separately, but ultimately linked into a single library. Note that no
|
||||
bindings are generated for functions in ``cuda_sources`` per se. To bind
|
||||
to a CUDA kernel, you must create a C++ function that calls it, and either
|
||||
declare or define this C++ function in one of the ``cpp_sources`` (and
|
||||
include its name in ``functions``).
|
||||
|
||||
See :func:`load` for a description of arguments omitted below.
|
||||
|
||||
@ -702,7 +702,7 @@ def load_inline(name,
|
||||
sources = [cpp_source_path]
|
||||
|
||||
if cuda_sources:
|
||||
cuda_sources.insert(0, '#include <ATen/ATen.h>')
|
||||
cuda_sources.insert(0, '#include <torch/types.h>')
|
||||
cuda_sources.insert(1, '#include <cuda.h>')
|
||||
cuda_sources.insert(2, '#include <cuda_runtime.h>')
|
||||
|
||||
|
Reference in New Issue
Block a user