Change behavior of clone to clone to a device (#9609)

Summary:
ebetica made me aware that `nn::Module::clone()` always clones to the current device (usually CPU) instead of preserving the device of each parameter. This PR changes the signature of `clone` from

`shared_ptr<Module> clone()`

to

`shared_ptr<Module> clone(optional<Device> device = nullopt)`

with semantics of:

1. If a `device` is given, all parameters/buffers are moved to that device,
2. If no `device` is supplied (default), parameters/buffers retain their device.

ezyang apaszke ebetica
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9609

Differential Revision: D8957367

Pulled By: goldsborough

fbshipit-source-id: 0d409ae645ed2b8d97d6fc060240de2f3d4bc6c8
This commit is contained in:
Peter Goldsborough
2018-07-23 14:49:18 -07:00
committed by Facebook Github Bot
parent 31ba2f15e1
commit d05a8145c5
8 changed files with 117 additions and 30 deletions

View File

@ -184,7 +184,8 @@ TEST_CASE("module/clone") {
SECTION(
"a module that overrides clone() does not throw when clone() is called ") {
// Minimal module overriding clone(); the test only checks that calling
// clone() on such a module does not throw, so returning nullptr is fine.
// NOTE(review): this is a rendered diff — the next line is the OLD
// (removed) signature, and the two lines after it are the NEW (added)
// signature taking an optional device; they are not both part of one build.
struct Cloneable : Module {
std::shared_ptr<Module> clone() const override {
std::shared_ptr<Module> clone(
at::optional<torch::Device> device = at::nullopt) const override {
return nullptr;
}
};
@ -299,6 +300,56 @@ TEST_CASE("module/clone") {
}
}
// Verifies the device semantics of Module::clone():
//  1. clone() with no argument keeps parameters/buffers on their device;
//  2. clone(device) moves every parameter/buffer to that device.
// Tagged [cuda] because both sections place tensors on CUDA devices.
TEST_CASE("module/clone-to-device", "[cuda]") {
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      reset();
    }
    // Re-registers all submodules and buffers; invoked by Cloneable.
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }
    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };
  // Asserts that every parameter and buffer of `module` lives on `expected`.
  auto require_all_on = [](Module& module, const torch::Device& expected) {
    for (const auto& p : module.parameters()) {
      REQUIRE(p->device().type() == expected.type());
      REQUIRE(p->device().index() == expected.index());
    }
    for (const auto& b : module.buffers()) {
      REQUIRE(b->device().type() == expected.type());
      REQUIRE(b->device().index() == expected.index());
    }
  };
  SECTION("Cloning preserves the device of parameters/buffers") {
    TestModule original;
    torch::Device cuda0(torch::kCUDA, 0);
    original.to(cuda0);
    // No device argument: the clone should stay on CUDA:0.
    auto cloned = original.clone();
    require_all_on(*cloned, cuda0);
  }
  SECTION(
      "Cloning to a particular device places all parameters/buffers there") {
    TestModule original;
    torch::Device cuda1(torch::kCUDA, 1);
    // `original` still lives on the CPU; clone(device) must move everything.
    auto cloned = original.clone(cuda1);
    require_all_on(*cloned, cuda1);
  }
}
TEST_CASE("module/parameters") {
torch::manual_seed(0);
struct TestModule : Module {