mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: In TorchScript and C++ extensions we currently advocate a mix of `torch::` and `at::` namespace usage. In the C++ frontend I had instead exported all symbols from `at::` and some from `c10::` into the `torch::` namespace. This is far, far easier for users to understand, and also avoids bugs around creating tensors vs. variables. The same should from now on be true for the TorchScript C++ API (for running and loading models) and all C++ extensions. Note that since we're just talking about typedefs, this change does not break any existing code. Once this lands I will update stuff in `pytorch/tutorials` too. zdevito ezyang gchanan Pull Request resolved: https://github.com/pytorch/pytorch/pull/13523 Differential Revision: D12942787 Pulled By: goldsborough fbshipit-source-id: 76058936bd8707b33d9e5bbc2d0705fc3d820763
232 lines
7.5 KiB
C++
232 lines
7.5 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <torch/csrc/autograd/functions/comm.h>
|
|
#include <torch/nn/module.h>
|
|
#include <torch/nn/modules/linear.h>
|
|
#include <torch/nn/parallel/data_parallel.h>
|
|
#include <torch/nn/pimpl.h>
|
|
#include <torch/types.h>
|
|
|
|
#include <test/cpp/api/support.h>
|
|
|
|
#include <iostream>
|
|
#include <memory>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
using namespace torch::autograd;
|
|
using namespace torch::nn;
|
|
|
|
// Test fixture for the parallel-module tests below; all behavior (e.g. RNG
// seeding) is inherited unchanged from torch::test::SeedingFixture.
struct ParallelTest : public torch::test::SeedingFixture {};
|
|
|
|
// Scatters a 10-element CPU tensor across two GPUs and checks that:
//  * it is split into two halves of 5 elements,
//  * concatenating the halves back on the CPU reproduces the input,
//  * gradients flow back through the scatter to the CPU input.
TEST_F(ParallelTest, DifferentiableScatter_MultiCUDA) {
  const torch::Device first_gpu(torch::kCUDA, 0);
  const torch::Device second_gpu(torch::kCUDA, 1);
  Scatter scatter({first_gpu, second_gpu});

  auto input = torch::ones(10, torch::requires_grad(true));
  auto shards = scatter.apply({input});

  ASSERT_EQ(shards.size(), 2);
  ASSERT_EQ(shards[0].size(0), 5);
  ASSERT_EQ(shards[1].size(0), 5);

  // Bringing both shards back to the CPU and concatenating them must yield
  // the original tensor.
  auto reassembled =
      torch::cat({shards[0].to(torch::kCPU), shards[1].to(torch::kCPU)});
  ASSERT_TRUE(reassembled.allclose(input));

  // Combine the shards on the second GPU and backpropagate; the gradient must
  // land on the (CPU) input that was scattered.
  torch::Tensor combined = shards[0].to({torch::kCUDA, 1}) + shards[1];
  combined.backward();

  const auto& input_grad = input.grad();
  ASSERT_TRUE(input_grad.defined());
  ASSERT_TRUE(input_grad.device().is_cpu());
  ASSERT_EQ(input_grad.sum().item<int32_t>(), 10);
}
|
|
|
|
// Gathers two 5-element tensors living on different GPUs onto GPU 1 and
// checks the result's shape/placement, the order of the gathered pieces, and
// that gradients are routed back to each source tensor on its own device.
TEST_F(ParallelTest, DifferentiableGather_MultiCUDA) {
  const torch::Device destination(torch::kCUDA, 1);
  Gather gather(destination);

  auto on_gpu0 =
      torch::ones(5, torch::requires_grad(true).device(torch::kCUDA, 0));
  auto on_gpu1 =
      torch::ones(5, torch::requires_grad(true).device(torch::kCUDA, 1));

  auto gathered = gather.apply({on_gpu0, on_gpu1});
  ASSERT_EQ(gathered.size(), 1);
  torch::Tensor merged = gathered.front();

  ASSERT_EQ(merged.size(0), 10);
  ASSERT_EQ(merged.device(), torch::Device(torch::kCUDA, 1));

  // The first half must come from GPU 0's tensor, the second from GPU 1's.
  auto halves = merged.chunk(2);
  ASSERT_TRUE(halves[0].to({torch::kCUDA, 0}).allclose(on_gpu0));
  ASSERT_TRUE(halves[1].allclose(on_gpu1));

  merged.backward();

  // Each source receives its gradient on the device it originally lived on.
  ASSERT_TRUE(on_gpu0.grad().defined());
  ASSERT_EQ(on_gpu0.grad().device(), torch::Device(torch::kCUDA, 0));
  ASSERT_EQ(on_gpu0.grad().sum().item<int32_t>(), 5);

  ASSERT_TRUE(on_gpu1.grad().defined());
  ASSERT_EQ(on_gpu1.grad().device(), torch::Device(torch::kCUDA, 1));
  ASSERT_EQ(on_gpu1.grad().sum().item<int32_t>(), 5);
}
|
|
|
|
// Replicates a Linear module onto two GPUs and verifies, per replica, that
// its parameters live on the matching GPU and, once the replica is moved back
// to the CPU, carry the same values as the original but in distinct storage.
TEST_F(ParallelTest, Replicate_MultiCUDA) {
  Linear linear(3, 4);
  auto replicas = parallel::replicate(
      linear, {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});
  ASSERT_EQ(replicas.size(), 2);

  auto original_parameters = linear->parameters();

  // Replica `r` was placed on CUDA device `r`; the checks run replica by
  // replica, in order, so a failure aborts the test at the same point.
  for (int r = 0; r < 2; ++r) {
    auto replica_parameters = replicas[r]->parameters();
    for (auto& parameter : replica_parameters) {
      ASSERT_EQ(parameter->device(), torch::Device(torch::kCUDA, r));
    }

    // Move the replica to the CPU so its values can be compared directly
    // with the original module's parameters.
    replicas[r]->to(torch::kCPU);
    ASSERT_EQ(replica_parameters.size(), original_parameters.size());
    for (size_t i = 0; i < original_parameters.size(); ++i) {
      ASSERT_TRUE(replica_parameters[i]->allclose(*original_parameters[i]));
      // Equal values, but the replica must not alias the original's memory.
      ASSERT_TRUE(
          replica_parameters[i]->data<float>() !=
          original_parameters[i]->data<float>());
    }
  }
}
|
|
|
|
// Runs three identically-initialized Linear modules (CPU, GPU 0, GPU 1) via
// parallel_apply and checks that each output stays on its module's device
// and that all three produce the same values.
TEST_F(ParallelTest, ParallelApply_MultiCUDA) {
  Linear on_cpu(3, 4);

  // Deep copies of the CPU module, moved to each GPU, so every module starts
  // with identical weights.
  Linear on_gpu0(std::dynamic_pointer_cast<LinearImpl>(on_cpu->clone()));
  on_gpu0->to({torch::kCUDA, 0});
  Linear on_gpu1(std::dynamic_pointer_cast<LinearImpl>(on_cpu->clone()));
  on_gpu1->to({torch::kCUDA, 1});

  std::vector<Linear> modules = {on_cpu, on_gpu0, on_gpu1};
  std::vector<torch::Tensor> inputs = {
      torch::ones({2, 3}),
      torch::ones({2, 3}, torch::device({torch::kCUDA, 0})),
      torch::ones({2, 3}, torch::device({torch::kCUDA, 1}))};

  auto results = parallel::parallel_apply(modules, inputs);
  ASSERT_EQ(results.size(), 3);

  // The CPU result serves as the reference for the two GPU results.
  const torch::Tensor& reference = results[0];
  ASSERT_TRUE(reference.device().is_cpu());

  ASSERT_EQ(results[1].device(), torch::Device(torch::kCUDA, 0));
  ASSERT_TRUE(results[1].to(torch::kCPU).allclose(reference));

  ASSERT_EQ(results[2].device(), torch::Device(torch::kCUDA, 1));
  ASSERT_TRUE(results[2].to(torch::kCPU).allclose(reference));
}
|
|
|
|
// Checks that parallel_apply honors an explicit per-module device list: each
// module's output ends up on the corresponding requested device.
TEST_F(ParallelTest, ParallelApplyWithDifferentOutputDevice_MultiCUDA) {
  // Ignores its input and returns a fixed int32 tensor, so only result
  // placement is under test.
  struct M : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return torch::ones({5}, torch::dtype(torch::kInt32));
    }
  };

  // Three independent modules, each paired with its own (distinct) input.
  std::vector<std::shared_ptr<M>> modules;
  std::vector<torch::Tensor> inputs;
  for (int k = 0; k < 3; ++k) {
    modules.push_back(std::make_shared<M>());
  }
  for (int k = 0; k < 3; ++k) {
    inputs.push_back(torch::empty({}));
  }

  std::vector<torch::Device> targets = {
      {torch::kCUDA, 1}, {torch::kCUDA, 0}, {torch::kCPU}};

  auto results = parallel::parallel_apply(modules, inputs, targets);
  ASSERT_EQ(results.size(), 3);

  ASSERT_TRUE(results[0].device().is_cuda());
  ASSERT_EQ(results[0].device(), torch::Device(torch::kCUDA, 1));

  ASSERT_TRUE(results[1].device().is_cuda());
  ASSERT_EQ(results[1].device(), torch::Device(torch::kCUDA, 0));

  ASSERT_TRUE(results[2].device().is_cpu());
}
|
|
|
|
// An exception thrown inside a replica's forward() must be propagated out of
// data_parallel with its original message intact.
TEST_F(ParallelTest, ParallelApplyRethrowsException_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      throw std::runtime_error("Badness!");
    }
  };

  auto faulty = std::make_shared<M>();
  auto batch = torch::ones({10, 3});
  ASSERT_THROWS_WITH(parallel::data_parallel(faulty, batch), "Badness!");
}
|
|
|
|
// data_parallel must place its result on the requested output device, both
// when scattering over all devices and in the single-device (no
// scatter/gather) path; in the latter, intermediates stay on the replica's
// device.
TEST_F(
    ParallelTest,
    DataParallelPlacesTheOutputOnTheRequestedDevice_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      // Intermediate tensors should be on the replica's current device.
      intermediate_tensor = torch::rand(5);
      // The returned tensor should be on the output device.
      return torch::ones(3);
    }
    torch::Tensor intermediate_tensor;
  };

  auto module = std::make_shared<M>();
  auto batch = torch::ones({10, 3});
  const torch::Device requested_output(torch::kCUDA, 1);

  {
    // Default device list: the input is scattered over the available GPUs.
    auto result = parallel::data_parallel(
        module,
        batch,
        /*devices=*/torch::nullopt,
        /*output_device=*/requested_output);
    ASSERT_TRUE(result.defined());
    ASSERT_TRUE(result.device().is_cuda());
    ASSERT_EQ(result.device().index(), 1);
  }
  {
    // Verify for the single-device case (where we don't scatter/gather).
    auto result = parallel::data_parallel(
        module,
        batch,
        /*devices=*/std::vector<torch::Device>{torch::Device(torch::kCUDA, 0)},
        /*output_device=*/requested_output);

    const auto& scratch = module->intermediate_tensor;
    ASSERT_TRUE(scratch.defined());
    ASSERT_TRUE(scratch.device().is_cuda());
    ASSERT_EQ(scratch.device().index(), 0);

    ASSERT_TRUE(result.defined());
    ASSERT_TRUE(result.device().is_cuda());
    ASSERT_EQ(result.device().index(), 1);
  }
}
|
|
|
|
// data_parallel with no explicit device list should run one replica per
// visible CUDA device. Each replica returns the index of its default device.
// NOTE(review): this relies on data_parallel making each replica's device the
// default while forward() runs — confirm against the implementation.
// The gathered output should therefore be [0, 1, ..., device_count - 1].
TEST_F(ParallelTest, DataParallelUsesAllAvailableCUDADevices_CUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      return torch::tensor(torch::getDefaultTensorOptions().device().index());
    }
  };

  auto m = std::make_shared<M>();
  auto input = torch::ones({10, 3});
  auto output = parallel::data_parallel(m, input);

  // device_count() is unsigned while numel()/item<int32_t>() are signed.
  // Compare in a common signed type so gtest's ASSERT_EQ does not trip
  // -Wsign-compare (fatal under -Werror builds).
  const auto device_count =
      static_cast<int64_t>(torch::cuda::device_count());
  ASSERT_EQ(output.numel(), device_count);
  for (int64_t i = 0; i < device_count; ++i) {
    ASSERT_EQ(output[i].item<int32_t>(), static_cast<int32_t>(i));
  }
}
|