mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use torch:: instead of at:: in all C++ APIs (#13523)
Summary: In TorchScript and C++ extensions we currently advocate a mix of `torch::` and `at::` namespace usage. In the C++ frontend I had instead exported all symbols from `at::` and some from `c10::` into the `torch::` namespace. This is far, far easier for users to understand, and also avoid bugs around creating tensors vs. variables. The same should from now on be true for the TorchScript C++ API (for running and loading models) and all C++ extensions. Note that since we're just talking about typedefs, this change does not break any existing code. Once this lands I will update stuff in `pytorch/tutorials` too. zdevito ezyang gchanan Pull Request resolved: https://github.com/pytorch/pytorch/pull/13523 Differential Revision: D12942787 Pulled By: goldsborough fbshipit-source-id: 76058936bd8707b33d9e5bbc2d0705fc3d820763
This commit is contained in:
committed by
Facebook Github Bot
parent
be424de869
commit
393ad6582d
@ -61,15 +61,15 @@ a taste of this interface:
|
|||||||
#include <torch/csrc/autograd/variable.h>
|
#include <torch/csrc/autograd/variable.h>
|
||||||
#include <torch/csrc/autograd/function.h>
|
#include <torch/csrc/autograd/function.h>
|
||||||
|
|
||||||
at::Tensor a = torch::ones({2, 2}, at::requires_grad());
|
torch::Tensor a = torch::ones({2, 2}, torch::requires_grad());
|
||||||
at::Tensor b = torch::randn({2, 2});
|
torch::Tensor b = torch::randn({2, 2});
|
||||||
auto c = a + b;
|
auto c = a + b;
|
||||||
c.backward(); // a.grad() will now hold the gradient of c w.r.t. a.
|
c.backward(); // a.grad() will now hold the gradient of c w.r.t. a.
|
||||||
|
|
||||||
The ``at::Tensor`` class in ATen is not differentiable by default. To add the
|
The ``at::Tensor`` class in ATen is not differentiable by default. To add the
|
||||||
differentiability of tensors the autograd API provides, you must use tensor
|
differentiability of tensors the autograd API provides, you must use tensor
|
||||||
factory functions from the `torch::` namespace instead of the `at` namespace.
|
factory functions from the `torch::` namespace instead of the `at` namespace.
|
||||||
For example, while a tensor created with `at::ones` will not be differentiable,
|
For example, while a tensor created with `torch::ones` will not be differentiable,
|
||||||
a tensor created with `torch::ones` will be.
|
a tensor created with `torch::ones` will be.
|
||||||
|
|
||||||
C++ Frontend
|
C++ Frontend
|
||||||
|
@ -6,7 +6,7 @@ configuration files required to depend on PyTorch. We call this distribution
|
|||||||
*LibTorch*, and you can download ZIP archives containing the latest LibTorch
|
*LibTorch*, and you can download ZIP archives containing the latest LibTorch
|
||||||
distribution on `our website <https://pytorch.org/get-started/locally/>`_. Below
|
distribution on `our website <https://pytorch.org/get-started/locally/>`_. Below
|
||||||
is a small example of writing a minimal application that depends on LibTorch
|
is a small example of writing a minimal application that depends on LibTorch
|
||||||
and uses the ``at::Tensor`` class which comes with the PyTorch C++ API.
|
and uses the ``torch::Tensor`` class which comes with the PyTorch C++ API.
|
||||||
|
|
||||||
Minimal Example
|
Minimal Example
|
||||||
---------------
|
---------------
|
||||||
@ -37,7 +37,7 @@ this:
|
|||||||
target_link_libraries(example-app "${TORCH_LIBRARIES}")
|
target_link_libraries(example-app "${TORCH_LIBRARIES}")
|
||||||
set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
|
set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
|
||||||
|
|
||||||
The implementation of our example will simply create a new `at::Tensor` and
|
The implementation of our example will simply create a new `torch::Tensor` and
|
||||||
print it:
|
print it:
|
||||||
|
|
||||||
.. code-block:: cpp
|
.. code-block:: cpp
|
||||||
@ -46,7 +46,7 @@ print it:
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
at::Tensor tensor = torch::rand({2, 3});
|
torch::Tensor tensor = torch::rand({2, 3});
|
||||||
std::cout << tensor << std::endl;
|
std::cout << tensor << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include <torch/nn/cursor.h>
|
#include <torch/nn/cursor.h>
|
||||||
#include <torch/nn/module.h>
|
#include <torch/nn/module.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <test/cpp/api/support.h>
|
#include <test/cpp/api/support.h>
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <torch/data.h>
|
#include <torch/data.h>
|
||||||
#include <torch/data/detail/sequencers.h>
|
#include <torch/data/detail/sequencers.h>
|
||||||
#include <torch/serialize.h>
|
#include <torch/serialize.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <test/cpp/api/support.h>
|
#include <test/cpp/api/support.h>
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
#include <torch/optim/adam.h>
|
#include <torch/optim/adam.h>
|
||||||
#include <torch/optim/optimizer.h>
|
#include <torch/optim/optimizer.h>
|
||||||
#include <torch/optim/sgd.h>
|
#include <torch/optim/sgd.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <test/cpp/api/support.h>
|
#include <test/cpp/api/support.h>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include <torch/jit.h>
|
#include <torch/jit.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <torch/csrc/utils/tempfile.h>
|
#include <torch/csrc/utils/tempfile.h>
|
||||||
#include <torch/nn/init.h>
|
#include <torch/nn/init.h>
|
||||||
#include <torch/nn/modules/linear.h>
|
#include <torch/nn/modules/linear.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <test/cpp/api/support.h>
|
#include <test/cpp/api/support.h>
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <torch/nn/module.h>
|
#include <torch/nn/module.h>
|
||||||
#include <torch/nn/modules/linear.h>
|
#include <torch/nn/modules/linear.h>
|
||||||
#include <torch/nn/modules/rnn.h>
|
#include <torch/nn/modules/rnn.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <test/cpp/api/support.h>
|
#include <test/cpp/api/support.h>
|
||||||
@ -54,8 +54,8 @@ TEST_F(ModuleTest, ZeroGrad) {
|
|||||||
TEST_F(ModuleTest, ZeroGradWithUndefined) {
|
TEST_F(ModuleTest, ZeroGradWithUndefined) {
|
||||||
struct TestModule : torch::nn::Module {
|
struct TestModule : torch::nn::Module {
|
||||||
TestModule() {
|
TestModule() {
|
||||||
x = register_parameter("x", torch::ones(5, at::requires_grad()));
|
x = register_parameter("x", torch::ones(5, torch::requires_grad()));
|
||||||
y = register_parameter("y", torch::ones(5, at::requires_grad()));
|
y = register_parameter("y", torch::ones(5, torch::requires_grad()));
|
||||||
}
|
}
|
||||||
torch::Tensor x, y;
|
torch::Tensor x, y;
|
||||||
};
|
};
|
||||||
@ -194,7 +194,7 @@ TEST_F(ModuleTest, Conversion_MultiCUDA) {
|
|||||||
ASSERT_EQ(parameter->device().type(), torch::Device::Type::CUDA);
|
ASSERT_EQ(parameter->device().type(), torch::Device::Type::CUDA);
|
||||||
ASSERT_EQ(parameter->device().index(), 0);
|
ASSERT_EQ(parameter->device().index(), 0);
|
||||||
}
|
}
|
||||||
module->to({at::kCUDA, 1});
|
module->to({torch::kCUDA, 1});
|
||||||
for (auto& parameter : module->parameters()) {
|
for (auto& parameter : module->parameters()) {
|
||||||
ASSERT_EQ(parameter->device().type(), torch::Device::Type::CUDA);
|
ASSERT_EQ(parameter->device().type(), torch::Device::Type::CUDA);
|
||||||
ASSERT_EQ(parameter->device().index(), 1);
|
ASSERT_EQ(parameter->device().index(), 1);
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
#include <torch/nn/modules/embedding.h>
|
#include <torch/nn/modules/embedding.h>
|
||||||
#include <torch/nn/modules/functional.h>
|
#include <torch/nn/modules/functional.h>
|
||||||
#include <torch/nn/modules/linear.h>
|
#include <torch/nn/modules/linear.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <test/cpp/api/support.h>
|
#include <test/cpp/api/support.h>
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
#include <torch/nn/modules/linear.h>
|
#include <torch/nn/modules/linear.h>
|
||||||
#include <torch/nn/modules/sequential.h>
|
#include <torch/nn/modules/sequential.h>
|
||||||
#include <torch/optim.h>
|
#include <torch/optim.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <test/cpp/api/optim_baseline.h>
|
#include <test/cpp/api/optim_baseline.h>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
// @generated from test/cpp/api/optim_baseline.py
|
// @generated from test/cpp/api/optim_baseline.py
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ import torch.optim
|
|||||||
|
|
||||||
|
|
||||||
HEADER = """
|
HEADER = """
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
#include <torch/nn/modules/linear.h>
|
#include <torch/nn/modules/linear.h>
|
||||||
#include <torch/nn/parallel/data_parallel.h>
|
#include <torch/nn/parallel/data_parallel.h>
|
||||||
#include <torch/nn/pimpl.h>
|
#include <torch/nn/pimpl.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <test/cpp/api/support.h>
|
#include <test/cpp/api/support.h>
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <torch/nn/modules/linear.h>
|
#include <torch/nn/modules/linear.h>
|
||||||
#include <torch/nn/modules/rnn.h>
|
#include <torch/nn/modules/rnn.h>
|
||||||
#include <torch/optim/adam.h>
|
#include <torch/optim/adam.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <test/cpp/api/support.h>
|
#include <test/cpp/api/support.h>
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
#include <torch/nn/modules/linear.h>
|
#include <torch/nn/modules/linear.h>
|
||||||
#include <torch/nn/modules/rnn.h>
|
#include <torch/nn/modules/rnn.h>
|
||||||
#include <torch/nn/modules/sequential.h>
|
#include <torch/nn/modules/sequential.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
#include <torch/optim/optimizer.h>
|
#include <torch/optim/optimizer.h>
|
||||||
#include <torch/optim/sgd.h>
|
#include <torch/optim/sgd.h>
|
||||||
#include <torch/serialize.h>
|
#include <torch/serialize.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <test/cpp/api/support.h>
|
#include <test/cpp/api/support.h>
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include <torch/nn/cloneable.h>
|
#include <torch/nn/cloneable.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <ATen/Context.h>
|
#include <ATen/Context.h>
|
||||||
#include <ATen/Functions.h>
|
#include <ATen/Functions.h>
|
||||||
|
@ -5,10 +5,10 @@
|
|||||||
// into one shared library.
|
// into one shared library.
|
||||||
void sigmoid_add_cuda(const float* x, const float* y, float* output, int size);
|
void sigmoid_add_cuda(const float* x, const float* y, float* output, int size);
|
||||||
|
|
||||||
at::Tensor sigmoid_add(at::Tensor x, at::Tensor y) {
|
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
|
||||||
AT_CHECK(x.type().is_cuda(), "x must be a CUDA tensor");
|
AT_CHECK(x.type().is_cuda(), "x must be a CUDA tensor");
|
||||||
AT_CHECK(y.type().is_cuda(), "y must be a CUDA tensor");
|
AT_CHECK(y.type().is_cuda(), "y must be a CUDA tensor");
|
||||||
auto output = at::zeros_like(x);
|
auto output = torch::zeros_like(x);
|
||||||
sigmoid_add_cuda(
|
sigmoid_add_cuda(
|
||||||
x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
|
x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
|
||||||
return output;
|
return output;
|
||||||
|
@ -1,53 +1,53 @@
|
|||||||
/*
|
/*
|
||||||
* CuDNN ReLU extension. Simple function but contains the general structure of
|
* CuDNN ReLU extension. Simple function but contains the general structure of
|
||||||
* most CuDNN extensions:
|
* most CuDNN extensions:
|
||||||
* 1) Check arguments. at::check* functions provide a standard way to validate
|
* 1) Check arguments. torch::check* functions provide a standard way to
|
||||||
* input and provide pretty errors.
|
* validate input and provide pretty errors. 2) Create descriptors. Most CuDNN
|
||||||
* 2) Create descriptors. Most CuDNN functions require creating and setting a
|
* functions require creating and setting a variety of descriptors. 3) Apply the
|
||||||
* variety of descriptors.
|
* CuDNN function. 4) Destroy your descriptors. 5) Return something (optional).
|
||||||
* 3) Apply the CuDNN function.
|
|
||||||
* 4) Destroy your descriptors.
|
|
||||||
* 5) Return something (optional).
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include <ATen/cudnn/Descriptors.h> // for TensorDescriptor
|
|
||||||
#include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK
|
#include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK
|
||||||
|
#include <ATen/cudnn/Descriptors.h> // for TensorDescriptor
|
||||||
#include <ATen/cudnn/Handle.h> // for getCudnnHandle
|
#include <ATen/cudnn/Handle.h> // for getCudnnHandle
|
||||||
|
|
||||||
// Name of function in python module and name used for error messages by
|
// Name of function in python module and name used for error messages by
|
||||||
// at::check* functions.
|
// torch::check* functions.
|
||||||
const char* cudnn_relu_name = "cudnn_relu";
|
const char* cudnn_relu_name = "cudnn_relu";
|
||||||
|
|
||||||
// Check arguments to cudnn_relu
|
// Check arguments to cudnn_relu
|
||||||
void cudnn_relu_check(const at::Tensor& inputs, const at::Tensor& outputs) {
|
void cudnn_relu_check(
|
||||||
|
const torch::Tensor& inputs,
|
||||||
|
const torch::Tensor& outputs) {
|
||||||
// Create TensorArgs. These record the names and positions of each tensor as a
|
// Create TensorArgs. These record the names and positions of each tensor as a
|
||||||
// parameter.
|
// parameter.
|
||||||
at::TensorArg arg_inputs(inputs, "inputs", 0);
|
torch::TensorArg arg_inputs(inputs, "inputs", 0);
|
||||||
at::TensorArg arg_outputs(outputs, "outputs", 1);
|
torch::TensorArg arg_outputs(outputs, "outputs", 1);
|
||||||
// Check arguments. No need to return anything. These functions with throw an
|
// Check arguments. No need to return anything. These functions with throw an
|
||||||
// error if they fail. Messages are populated using information from
|
// error if they fail. Messages are populated using information from
|
||||||
// TensorArgs.
|
// TensorArgs.
|
||||||
at::checkContiguous(cudnn_relu_name, arg_inputs);
|
torch::checkContiguous(cudnn_relu_name, arg_inputs);
|
||||||
at::checkScalarType(cudnn_relu_name, arg_inputs, at::kFloat);
|
torch::checkScalarType(cudnn_relu_name, arg_inputs, torch::kFloat);
|
||||||
at::checkBackend(cudnn_relu_name, arg_inputs.tensor, at::Backend::CUDA);
|
torch::checkBackend(cudnn_relu_name, arg_inputs.tensor, torch::Backend::CUDA);
|
||||||
at::checkContiguous(cudnn_relu_name, arg_outputs);
|
torch::checkContiguous(cudnn_relu_name, arg_outputs);
|
||||||
at::checkScalarType(cudnn_relu_name, arg_outputs, at::kFloat);
|
torch::checkScalarType(cudnn_relu_name, arg_outputs, torch::kFloat);
|
||||||
at::checkBackend(cudnn_relu_name, arg_outputs.tensor, at::Backend::CUDA);
|
torch::checkBackend(
|
||||||
at::checkSameSize(cudnn_relu_name, arg_inputs, arg_outputs);
|
cudnn_relu_name, arg_outputs.tensor, torch::Backend::CUDA);
|
||||||
|
torch::checkSameSize(cudnn_relu_name, arg_inputs, arg_outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void cudnn_relu(const at::Tensor& inputs, const at::Tensor& outputs) {
|
void cudnn_relu(const torch::Tensor& inputs, const torch::Tensor& outputs) {
|
||||||
// Most CuDNN extensions will follow a similar pattern.
|
// Most CuDNN extensions will follow a similar pattern.
|
||||||
// Step 1: Check inputs. This will throw an error if inputs are invalid, so no
|
// Step 1: Check inputs. This will throw an error if inputs are invalid, so no
|
||||||
// need to check return codes here.
|
// need to check return codes here.
|
||||||
cudnn_relu_check(inputs, outputs);
|
cudnn_relu_check(inputs, outputs);
|
||||||
// Step 2: Create descriptors
|
// Step 2: Create descriptors
|
||||||
cudnnHandle_t cuDnn = at::native::getCudnnHandle();
|
cudnnHandle_t cuDnn = torch::native::getCudnnHandle();
|
||||||
// Note: 4 is minimum dim for a TensorDescriptor. Input and output are same
|
// Note: 4 is minimum dim for a TensorDescriptor. Input and output are same
|
||||||
// size and type and contiguous, so one descriptor is sufficient.
|
// size and type and contiguous, so one descriptor is sufficient.
|
||||||
at::native::TensorDescriptor input_tensor_desc(inputs, 4);
|
torch::native::TensorDescriptor input_tensor_desc(inputs, 4);
|
||||||
cudnnActivationDescriptor_t activationDesc;
|
cudnnActivationDescriptor_t activationDesc;
|
||||||
// Note: Always check return value of cudnn functions using CUDNN_CHECK
|
// Note: Always check return value of cudnn functions using CUDNN_CHECK
|
||||||
AT_CUDNN_CHECK(cudnnCreateActivationDescriptor(&activationDesc));
|
AT_CUDNN_CHECK(cudnnCreateActivationDescriptor(&activationDesc));
|
||||||
|
@ -3,15 +3,15 @@
|
|||||||
struct Doubler {
|
struct Doubler {
|
||||||
Doubler(int A, int B) {
|
Doubler(int A, int B) {
|
||||||
tensor_ =
|
tensor_ =
|
||||||
torch::ones({A, B}, torch::dtype(torch::kDouble).requires_grad(true));
|
torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
|
||||||
}
|
}
|
||||||
at::Tensor forward() {
|
torch::Tensor forward() {
|
||||||
return tensor_ * 2;
|
return tensor_ * 2;
|
||||||
}
|
}
|
||||||
at::Tensor get() const {
|
torch::Tensor get() const {
|
||||||
return tensor_;
|
return tensor_;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
at::Tensor tensor_;
|
torch::Tensor tensor_;
|
||||||
};
|
};
|
||||||
|
@ -1,26 +1,26 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
at::Tensor sigmoid_add(at::Tensor x, at::Tensor y) {
|
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
|
||||||
return x.sigmoid() + y.sigmoid();
|
return x.sigmoid() + y.sigmoid();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct MatrixMultiplier {
|
struct MatrixMultiplier {
|
||||||
MatrixMultiplier(int A, int B) {
|
MatrixMultiplier(int A, int B) {
|
||||||
tensor_ =
|
tensor_ =
|
||||||
torch::ones({A, B}, torch::dtype(torch::kDouble).requires_grad(true));
|
torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
|
||||||
}
|
}
|
||||||
at::Tensor forward(at::Tensor weights) {
|
torch::Tensor forward(torch::Tensor weights) {
|
||||||
return tensor_.mm(weights);
|
return tensor_.mm(weights);
|
||||||
}
|
}
|
||||||
at::Tensor get() const {
|
torch::Tensor get() const {
|
||||||
return tensor_;
|
return tensor_;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
at::Tensor tensor_;
|
torch::Tensor tensor_;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool function_taking_optional(c10::optional<at::Tensor> tensor) {
|
bool function_taking_optional(c10::optional<torch::Tensor> tensor) {
|
||||||
return tensor.has_value();
|
return tensor.has_value();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,11 +5,11 @@
|
|||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
std::vector<at::Tensor> custom_op(
|
std::vector<torch::Tensor> custom_op(
|
||||||
at::Tensor tensor,
|
torch::Tensor tensor,
|
||||||
double scalar,
|
double scalar,
|
||||||
int64_t repeat) {
|
int64_t repeat) {
|
||||||
std::vector<at::Tensor> output;
|
std::vector<torch::Tensor> output;
|
||||||
output.reserve(repeat);
|
output.reserve(repeat);
|
||||||
for (int64_t i = 0; i < repeat; ++i) {
|
for (int64_t i = 0; i < repeat; ++i) {
|
||||||
output.push_back(tensor * scalar);
|
output.push_back(tensor * scalar);
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
# endif
|
# endif
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
CUSTOM_OP_API std::vector<at::Tensor> custom_op(
|
CUSTOM_OP_API std::vector<torch::Tensor> custom_op(
|
||||||
at::Tensor tensor,
|
torch::Tensor tensor,
|
||||||
double scalar,
|
double scalar,
|
||||||
int64_t repeat);
|
int64_t repeat);
|
||||||
|
@ -33,7 +33,7 @@ void get_operator_from_registry_and_execute() {
|
|||||||
torch::jit::Stack stack;
|
torch::jit::Stack stack;
|
||||||
torch::jit::push(stack, torch::ones(5), 2.0, 3);
|
torch::jit::push(stack, torch::ones(5), 2.0, 3);
|
||||||
op->getOperation()(stack);
|
op->getOperation()(stack);
|
||||||
std::vector<at::Tensor> output;
|
std::vector<torch::Tensor> output;
|
||||||
torch::jit::pop(stack, output);
|
torch::jit::pop(stack, output);
|
||||||
|
|
||||||
const auto manual = custom_op(torch::ones(5), 2.0, 3);
|
const auto manual = custom_op(torch::ones(5), 2.0, 3);
|
||||||
@ -99,19 +99,19 @@ void test_move_to_device(const std::string& path_to_exported_script_module) {
|
|||||||
torch::jit::load(path_to_exported_script_module);
|
torch::jit::load(path_to_exported_script_module);
|
||||||
AT_ASSERT(module != nullptr);
|
AT_ASSERT(module != nullptr);
|
||||||
|
|
||||||
helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
|
helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
|
||||||
return tensor.device().is_cpu();
|
return tensor.device().is_cpu();
|
||||||
});
|
});
|
||||||
|
|
||||||
module->to(at::kCUDA);
|
module->to(torch::kCUDA);
|
||||||
|
|
||||||
helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
|
helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
|
||||||
return tensor.device().is_cuda();
|
return tensor.device().is_cuda();
|
||||||
});
|
});
|
||||||
|
|
||||||
module->to(at::kCPU);
|
module->to(torch::kCPU);
|
||||||
|
|
||||||
helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
|
helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
|
||||||
return tensor.device().is_cpu();
|
return tensor.device().is_cpu();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -121,16 +121,16 @@ void test_move_to_dtype(const std::string& path_to_exported_script_module) {
|
|||||||
torch::jit::load(path_to_exported_script_module);
|
torch::jit::load(path_to_exported_script_module);
|
||||||
AT_ASSERT(module != nullptr);
|
AT_ASSERT(module != nullptr);
|
||||||
|
|
||||||
module->to(at::kInt);
|
module->to(torch::kInt);
|
||||||
|
|
||||||
helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
|
helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
|
||||||
return tensor.dtype() == at::kInt;
|
return tensor.dtype() == torch::kInt;
|
||||||
});
|
});
|
||||||
|
|
||||||
module->to(at::kDouble);
|
module->to(torch::kDouble);
|
||||||
|
|
||||||
helpers::check_all_parameters(*module, [](const at::Tensor& tensor) {
|
helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
|
||||||
return tensor.dtype() == at::kDouble;
|
return tensor.dtype() == torch::kDouble;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -147,7 +147,7 @@ int main(int argc, const char* argv[]) {
|
|||||||
test_argument_checking_for_serialized_modules(path_to_exported_script_module);
|
test_argument_checking_for_serialized_modules(path_to_exported_script_module);
|
||||||
test_move_to_dtype(path_to_exported_script_module);
|
test_move_to_dtype(path_to_exported_script_module);
|
||||||
|
|
||||||
if (at::globalContext().getNumGPUs() > 0) {
|
if (torch::globalContext().getNumGPUs() > 0) {
|
||||||
test_move_to_device(path_to_exported_script_module);
|
test_move_to_device(path_to_exported_script_module);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,7 +149,7 @@ class TestCppExtension(common.TestCase):
|
|||||||
|
|
||||||
def test_inline_jit_compile_extension_with_functions_as_list(self):
|
def test_inline_jit_compile_extension_with_functions_as_list(self):
|
||||||
cpp_source = '''
|
cpp_source = '''
|
||||||
at::Tensor tanh_add(at::Tensor x, at::Tensor y) {
|
torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
|
||||||
return x.tanh() + y.tanh();
|
return x.tanh() + y.tanh();
|
||||||
}
|
}
|
||||||
'''
|
'''
|
||||||
@ -170,7 +170,7 @@ class TestCppExtension(common.TestCase):
|
|||||||
|
|
||||||
def test_inline_jit_compile_extension_with_functions_as_dict(self):
|
def test_inline_jit_compile_extension_with_functions_as_dict(self):
|
||||||
cpp_source = '''
|
cpp_source = '''
|
||||||
at::Tensor tanh_add(at::Tensor x, at::Tensor y) {
|
torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
|
||||||
return x.tanh() + y.tanh();
|
return x.tanh() + y.tanh();
|
||||||
}
|
}
|
||||||
'''
|
'''
|
||||||
@ -186,14 +186,14 @@ class TestCppExtension(common.TestCase):
|
|||||||
|
|
||||||
def test_inline_jit_compile_extension_multiple_sources_and_no_functions(self):
|
def test_inline_jit_compile_extension_multiple_sources_and_no_functions(self):
|
||||||
cpp_source1 = '''
|
cpp_source1 = '''
|
||||||
at::Tensor sin_add(at::Tensor x, at::Tensor y) {
|
torch::Tensor sin_add(torch::Tensor x, torch::Tensor y) {
|
||||||
return x.sin() + y.sin();
|
return x.sin() + y.sin();
|
||||||
}
|
}
|
||||||
'''
|
'''
|
||||||
|
|
||||||
cpp_source2 = '''
|
cpp_source2 = '''
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
at::Tensor sin_add(at::Tensor x, at::Tensor y);
|
torch::Tensor sin_add(torch::Tensor x, torch::Tensor y);
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
m.def("sin_add", &sin_add, "sin(x) + sin(y)");
|
m.def("sin_add", &sin_add, "sin(x) + sin(y)");
|
||||||
}
|
}
|
||||||
@ -224,8 +224,8 @@ class TestCppExtension(common.TestCase):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor cos_add(at::Tensor x, at::Tensor y) {
|
torch::Tensor cos_add(torch::Tensor x, torch::Tensor y) {
|
||||||
auto output = at::zeros_like(x);
|
auto output = torch::zeros_like(x);
|
||||||
const int threads = 1024;
|
const int threads = 1024;
|
||||||
const int blocks = (output.numel() + threads - 1) / threads;
|
const int blocks = (output.numel() + threads - 1) / threads;
|
||||||
cos_add_kernel<<<blocks, threads>>>(x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
|
cos_add_kernel<<<blocks, threads>>>(x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
|
||||||
@ -234,7 +234,7 @@ class TestCppExtension(common.TestCase):
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
# Here, the C++ source need only declare the function signature.
|
# Here, the C++ source need only declare the function signature.
|
||||||
cpp_source = 'at::Tensor cos_add(at::Tensor x, at::Tensor y);'
|
cpp_source = 'torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);'
|
||||||
|
|
||||||
module = torch.utils.cpp_extension.load_inline(
|
module = torch.utils.cpp_extension.load_inline(
|
||||||
name='inline_jit_extension_cuda',
|
name='inline_jit_extension_cuda',
|
||||||
@ -258,7 +258,7 @@ class TestCppExtension(common.TestCase):
|
|||||||
|
|
||||||
def test_lenient_flag_handling_in_jit_extensions(self):
|
def test_lenient_flag_handling_in_jit_extensions(self):
|
||||||
cpp_source = '''
|
cpp_source = '''
|
||||||
at::Tensor tanh_add(at::Tensor x, at::Tensor y) {
|
torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
|
||||||
return x.tanh() + y.tanh();
|
return x.tanh() + y.tanh();
|
||||||
}
|
}
|
||||||
'''
|
'''
|
||||||
@ -303,8 +303,8 @@ class TestCppExtension(common.TestCase):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor half_test(at::Tensor input) {
|
torch::Tensor half_test(torch::Tensor input) {
|
||||||
auto output = at::empty(1, input.options().dtype(at::kFloat));
|
auto output = torch::empty(1, input.options().dtype(torch::kFloat));
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] {
|
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] {
|
||||||
half_test_kernel<scalar_t><<<1, 1>>>(
|
half_test_kernel<scalar_t><<<1, 1>>>(
|
||||||
input.data<scalar_t>(),
|
input.data<scalar_t>(),
|
||||||
@ -316,7 +316,7 @@ class TestCppExtension(common.TestCase):
|
|||||||
|
|
||||||
module = torch.utils.cpp_extension.load_inline(
|
module = torch.utils.cpp_extension.load_inline(
|
||||||
name='half_test_extension',
|
name='half_test_extension',
|
||||||
cpp_sources='at::Tensor half_test(at::Tensor input);',
|
cpp_sources='torch::Tensor half_test(torch::Tensor input);',
|
||||||
cuda_sources=cuda_source,
|
cuda_sources=cuda_source,
|
||||||
functions=['half_test'],
|
functions=['half_test'],
|
||||||
verbose=True)
|
verbose=True)
|
||||||
|
@ -6,5 +6,5 @@
|
|||||||
#include <torch/nn.h>
|
#include <torch/nn.h>
|
||||||
#include <torch/optim.h>
|
#include <torch/optim.h>
|
||||||
#include <torch/serialize.h>
|
#include <torch/serialize.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
#include <torch/data/iterator.h>
|
#include <torch/data/iterator.h>
|
||||||
#include <torch/data/samplers/random.h>
|
#include <torch/data/samplers/random.h>
|
||||||
#include <torch/data/worker_exception.h>
|
#include <torch/data/worker_exception.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <torch/csrc/utils/memory.h>
|
#include <torch/csrc/utils/memory.h>
|
||||||
#include <torch/csrc/utils/variadic.h>
|
#include <torch/csrc/utils/variadic.h>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/arg.h>
|
#include <torch/arg.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/data/example.h>
|
#include <torch/data/example.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <ATen/core/ArrayRef.h>
|
#include <ATen/core/ArrayRef.h>
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/data/datasets/base.h>
|
#include <torch/data/datasets/base.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <ATen/core/ArrayRef.h>
|
#include <ATen/core/ArrayRef.h>
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include <torch/data/datasets/base.h>
|
#include <torch/data/datasets/base.h>
|
||||||
#include <torch/data/example.h>
|
#include <torch/data/example.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include <torch/data/datasets/base.h>
|
#include <torch/data/datasets/base.h>
|
||||||
#include <torch/data/example.h>
|
#include <torch/data/example.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/data/detail/queue.h>
|
#include <torch/data/detail/queue.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
#include <c10/util/Optional.h>
|
#include <c10/util/Optional.h>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace data {
|
namespace data {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/csrc/utils/variadic.h>
|
#include <torch/csrc/utils/variadic.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/data/samplers/base.h>
|
#include <torch/data/samplers/base.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/data/samplers/base.h>
|
#include <torch/data/samplers/base.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include <torch/data/samplers/base.h>
|
#include <torch/data/samplers/base.h>
|
||||||
#include <torch/data/samplers/custom_batch_request.h>
|
#include <torch/data/samplers/custom_batch_request.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include <torch/data/example.h>
|
#include <torch/data/example.h>
|
||||||
#include <torch/data/transforms/collate.h>
|
#include <torch/data/transforms/collate.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include <torch/data/example.h>
|
#include <torch/data/example.h>
|
||||||
#include <torch/data/transforms/base.h>
|
#include <torch/data/transforms/base.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/csrc/utils/variadic.h>
|
#include <torch/csrc/utils/variadic.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/nn/module.h>
|
#include <torch/nn/module.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <ATen/OptionsGuard.h>
|
#include <ATen/OptionsGuard.h>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace nn {
|
namespace nn {
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
#include <torch/nn/cursor.h>
|
#include <torch/nn/cursor.h>
|
||||||
#include <torch/nn/pimpl.h>
|
#include <torch/nn/pimpl.h>
|
||||||
#include <torch/serialize/archive.h>
|
#include <torch/serialize/archive.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <torch/detail/static.h>
|
#include <torch/detail/static.h>
|
||||||
#include <torch/nn/module.h>
|
#include <torch/nn/module.h>
|
||||||
#include <torch/nn/pimpl.h>
|
#include <torch/nn/pimpl.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <torch/csrc/autograd/variable.h>
|
#include <torch/csrc/autograd/variable.h>
|
||||||
#include <torch/csrc/utils/memory.h>
|
#include <torch/csrc/utils/memory.h>
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include <torch/nn/cloneable.h>
|
#include <torch/nn/cloneable.h>
|
||||||
#include <torch/nn/pimpl.h>
|
#include <torch/nn/pimpl.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <torch/expanding_array.h>
|
#include <torch/expanding_array.h>
|
||||||
#include <torch/nn/cloneable.h>
|
#include <torch/nn/cloneable.h>
|
||||||
#include <torch/nn/pimpl.h>
|
#include <torch/nn/pimpl.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include <torch/nn/cloneable.h>
|
#include <torch/nn/cloneable.h>
|
||||||
#include <torch/nn/pimpl.h>
|
#include <torch/nn/pimpl.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include <torch/nn/cloneable.h>
|
#include <torch/nn/cloneable.h>
|
||||||
#include <torch/nn/pimpl.h>
|
#include <torch/nn/pimpl.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <torch/csrc/utils/variadic.h>
|
#include <torch/csrc/utils/variadic.h>
|
||||||
#include <torch/nn/cloneable.h>
|
#include <torch/nn/cloneable.h>
|
||||||
#include <torch/nn/pimpl.h>
|
#include <torch/nn/pimpl.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <torch/nn/cloneable.h>
|
#include <torch/nn/cloneable.h>
|
||||||
#include <torch/nn/module.h>
|
#include <torch/nn/module.h>
|
||||||
#include <torch/nn/pimpl.h>
|
#include <torch/nn/pimpl.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <torch/nn/cloneable.h>
|
#include <torch/nn/cloneable.h>
|
||||||
#include <torch/nn/modules/dropout.h>
|
#include <torch/nn/modules/dropout.h>
|
||||||
#include <torch/nn/pimpl.h>
|
#include <torch/nn/pimpl.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
#include <torch/nn/module.h>
|
#include <torch/nn/module.h>
|
||||||
#include <torch/nn/modules/any.h>
|
#include <torch/nn/modules/any.h>
|
||||||
#include <torch/nn/pimpl.h>
|
#include <torch/nn/pimpl.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <torch/cuda.h>
|
#include <torch/cuda.h>
|
||||||
#include <torch/nn/module.h>
|
#include <torch/nn/module.h>
|
||||||
#include <torch/nn/pimpl.h>
|
#include <torch/nn/pimpl.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <torch/csrc/autograd/functions/comm.h>
|
#include <torch/csrc/autograd/functions/comm.h>
|
||||||
#include <torch/csrc/cuda/comm.h>
|
#include <torch/csrc/cuda/comm.h>
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include <torch/arg.h>
|
#include <torch/arg.h>
|
||||||
#include <torch/serialize/archive.h>
|
#include <torch/serialize/archive.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <torch/csrc/utils/variadic.h>
|
#include <torch/csrc/utils/variadic.h>
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <torch/nn/pimpl.h>
|
#include <torch/nn/pimpl.h>
|
||||||
#include <torch/optim/optimizer.h>
|
#include <torch/optim/optimizer.h>
|
||||||
#include <torch/optim/serialize.h>
|
#include <torch/optim/serialize.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
#include <torch/nn/module.h>
|
#include <torch/nn/module.h>
|
||||||
#include <torch/optim/optimizer.h>
|
#include <torch/optim/optimizer.h>
|
||||||
#include <torch/optim/serialize.h>
|
#include <torch/optim/serialize.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/serialize/archive.h>
|
#include <torch/serialize/archive.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <torch/arg.h>
|
#include <torch/arg.h>
|
||||||
#include <torch/nn/module.h>
|
#include <torch/nn/module.h>
|
||||||
#include <torch/optim/optimizer.h>
|
#include <torch/optim/optimizer.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/detail/static.h>
|
#include <torch/detail/static.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <torch/csrc/python_headers.h>
|
#include <torch/csrc/python_headers.h>
|
||||||
#include <torch/csrc/utils/pybind.h>
|
#include <torch/csrc/utils/pybind.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -13,9 +13,6 @@ using namespace at; // NOLINT
|
|||||||
using c10::optional;
|
using c10::optional;
|
||||||
using c10::nullopt;
|
using c10::nullopt;
|
||||||
|
|
||||||
using c10::optional;
|
|
||||||
using c10::nullopt;
|
|
||||||
|
|
||||||
using Dtype = at::ScalarType;
|
using Dtype = at::ScalarType;
|
||||||
|
|
||||||
/// Fixed width dtypes.
|
/// Fixed width dtypes.
|
@ -1,7 +1,7 @@
|
|||||||
#include <torch/data/datasets/mnist.h>
|
#include <torch/data/datasets/mnist.h>
|
||||||
|
|
||||||
#include <torch/data/example.h>
|
#include <torch/data/example.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <torch/data/samplers/random.h>
|
#include <torch/data/samplers/random.h>
|
||||||
#include <torch/serialize/archive.h>
|
#include <torch/serialize/archive.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <torch/data/samplers/sequential.h>
|
#include <torch/data/samplers/sequential.h>
|
||||||
#include <torch/serialize/archive.h>
|
#include <torch/serialize/archive.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <torch/data/samplers/stream.h>
|
#include <torch/data/samplers/stream.h>
|
||||||
#include <torch/serialize/archive.h>
|
#include <torch/serialize/archive.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#include <torch/nn/cursor.h>
|
#include <torch/nn/cursor.h>
|
||||||
|
|
||||||
#include <torch/nn/module.h>
|
#include <torch/nn/module.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <torch/nn/init.h>
|
#include <torch/nn/init.h>
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#include <torch/nn/modules/batchnorm.h>
|
#include <torch/nn/modules/batchnorm.h>
|
||||||
|
|
||||||
#include <torch/cuda.h>
|
#include <torch/cuda.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#include <torch/nn/modules/conv.h>
|
#include <torch/nn/modules/conv.h>
|
||||||
|
|
||||||
#include <torch/expanding_array.h>
|
#include <torch/expanding_array.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <torch/nn/modules/dropout.h>
|
#include <torch/nn/modules/dropout.h>
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <torch/nn/modules/embedding.h>
|
#include <torch/nn/modules/embedding.h>
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <torch/nn/modules/functional.h>
|
#include <torch/nn/modules/functional.h>
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <torch/nn/modules/linear.h>
|
#include <torch/nn/modules/linear.h>
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#include <torch/nn/modules/rnn.h>
|
#include <torch/nn/modules/rnn.h>
|
||||||
|
|
||||||
#include <torch/nn/modules/dropout.h>
|
#include <torch/nn/modules/dropout.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||||
#include <torch/nn/cursor.h>
|
#include <torch/nn/cursor.h>
|
||||||
#include <torch/serialize/archive.h>
|
#include <torch/serialize/archive.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#include <torch/optim/serialize.h>
|
#include <torch/optim/serialize.h>
|
||||||
|
|
||||||
#include <torch/serialize/archive.h>
|
#include <torch/serialize/archive.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
#include <torch/nn/pimpl.h>
|
#include <torch/nn/pimpl.h>
|
||||||
#include <torch/optim/optimizer.h>
|
#include <torch/optim/optimizer.h>
|
||||||
#include <torch/optim/serialize.h>
|
#include <torch/optim/serialize.h>
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <torch/serialize/input-archive.h>
|
#include <torch/serialize/input-archive.h>
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <torch/csrc/jit/import.h>
|
#include <torch/csrc/jit/import.h>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <torch/serialize/output-archive.h>
|
#include <torch/serialize/output-archive.h>
|
||||||
|
|
||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/utils.h>
|
#include <torch/utils.h>
|
||||||
|
|
||||||
#include <torch/csrc/jit/export.h>
|
#include <torch/csrc/jit/export.h>
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
#include <torch/tensor.h>
|
#include <torch/types.h>
|
||||||
#include <torch/serialize/archive.h>
|
#include <torch/serialize/archive.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
|
@ -423,7 +423,6 @@ private:
|
|||||||
template<TypeKind K, typename T>
|
template<TypeKind K, typename T>
|
||||||
struct SingleElementType : public Type {
|
struct SingleElementType : public Type {
|
||||||
static const TypeKind Kind = K;
|
static const TypeKind Kind = K;
|
||||||
static constexpr bool is_singleton = true;
|
|
||||||
TypePtr getElementType() const {
|
TypePtr getElementType() const {
|
||||||
return elem;
|
return elem;
|
||||||
}
|
}
|
||||||
@ -488,9 +487,6 @@ struct FutureType;
|
|||||||
using FutureTypePtr = std::shared_ptr<FutureType>;
|
using FutureTypePtr = std::shared_ptr<FutureType>;
|
||||||
|
|
||||||
struct TORCH_API FutureType : public Type {
|
struct TORCH_API FutureType : public Type {
|
||||||
// It's not exactly a singleton, but there should be exactly once instance of
|
|
||||||
// Future[T] for every T
|
|
||||||
static constexpr bool is_singleton = true;
|
|
||||||
friend struct Type;
|
friend struct Type;
|
||||||
template<typename ... T>
|
template<typename ... T>
|
||||||
static FutureTypePtr create(TypePtr elem) {
|
static FutureTypePtr create(TypePtr elem) {
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/csrc/api/include/torch/types.h>
|
||||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||||
#include <torch/csrc/jit/custom_operator.h>
|
#include <torch/csrc/jit/custom_operator.h>
|
||||||
#include <torch/csrc/jit/import.h>
|
#include <torch/csrc/jit/import.h>
|
||||||
|
@ -633,13 +633,13 @@ def load_inline(name,
|
|||||||
as its docstring.
|
as its docstring.
|
||||||
|
|
||||||
The sources in ``cuda_sources`` are concatenated into a separate ``.cu``
|
The sources in ``cuda_sources`` are concatenated into a separate ``.cu``
|
||||||
file and prepended with ``ATen/ATen.h``, ``cuda.h`` and ``cuda_runtime.h``
|
file and prepended with ``torch/types.h``, ``cuda.h`` and
|
||||||
includes. The ``.cpp`` and ``.cu`` files are compiled separately, but
|
``cuda_runtime.h`` includes. The ``.cpp`` and ``.cu`` files are compiled
|
||||||
ultimately linked into a single library. Note that no bindings are
|
separately, but ultimately linked into a single library. Note that no
|
||||||
generated for functions in ``cuda_sources`` per se. To bind to a CUDA
|
bindings are generated for functions in ``cuda_sources`` per se. To bind
|
||||||
kernel, you must create a C++ function that calls it, and either declare or
|
to a CUDA kernel, you must create a C++ function that calls it, and either
|
||||||
define this C++ function in one of the ``cpp_sources`` (and include its
|
declare or define this C++ function in one of the ``cpp_sources`` (and
|
||||||
name in ``functions``).
|
include its name in ``functions``).
|
||||||
|
|
||||||
See :func:`load` for a description of arguments omitted below.
|
See :func:`load` for a description of arguments omitted below.
|
||||||
|
|
||||||
@ -702,7 +702,7 @@ def load_inline(name,
|
|||||||
sources = [cpp_source_path]
|
sources = [cpp_source_path]
|
||||||
|
|
||||||
if cuda_sources:
|
if cuda_sources:
|
||||||
cuda_sources.insert(0, '#include <ATen/ATen.h>')
|
cuda_sources.insert(0, '#include <torch/types.h>')
|
||||||
cuda_sources.insert(1, '#include <cuda.h>')
|
cuda_sources.insert(1, '#include <cuda.h>')
|
||||||
cuda_sources.insert(2, '#include <cuda_runtime.h>')
|
cuda_sources.insert(2, '#include <cuda_runtime.h>')
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user