mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-23 06:34:55 +08:00
Change behavior of clone to clone to a device (#9609)
Summary: ebetica made me aware that `nn::Module::clone()` always clones to the current device (usually CPU) instead of preserving the device of each parameter. This PR changes the signature of `clone` from `shared_ptr<Module> clone()` to `shared_ptr<Module> clone(optional<Device> device = nullopt)` with semantics of: 1. If a `device` is given, all parameters/buffers are moved to that device, 2. If no `device` is supplied (default), parameters/buffers retain their device. ezyang apaszke ebetica Pull Request resolved: https://github.com/pytorch/pytorch/pull/9609 Differential Revision: D8957367 Pulled By: goldsborough fbshipit-source-id: 0d409ae645ed2b8d97d6fc060240de2f3d4bc6c8
This commit is contained in:
committed by
Facebook Github Bot
parent
31ba2f15e1
commit
d05a8145c5
@ -184,7 +184,8 @@ TEST_CASE("module/clone") {
|
||||
SECTION(
|
||||
"a module that overrides clone() does not throw when clone() is called ") {
|
||||
struct Cloneable : Module {
|
||||
std::shared_ptr<Module> clone() const override {
|
||||
std::shared_ptr<Module> clone(
|
||||
at::optional<torch::Device> device = at::nullopt) const override {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
@ -299,6 +300,56 @@ TEST_CASE("module/clone") {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("module/clone-to-device", "[cuda]") {
|
||||
struct TestModule : public Cloneable<TestModule> {
|
||||
TestModule() {
|
||||
reset();
|
||||
}
|
||||
void reset() override {
|
||||
l1 = register_module("l1", Linear(10, 3));
|
||||
l2 = register_module("l2", Linear(3, 5));
|
||||
l3 = register_module("l3", Linear(5, 100));
|
||||
buffer = register_buffer("buf", torch::ones({2, 2}));
|
||||
}
|
||||
|
||||
Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
|
||||
torch::Tensor buffer;
|
||||
};
|
||||
|
||||
SECTION("Cloning preserves the device of parameters/buffers") {
|
||||
TestModule m;
|
||||
torch::Device device(torch::kCUDA, 0);
|
||||
|
||||
m.to(device);
|
||||
|
||||
auto clone = m.clone();
|
||||
for (const auto& parameter : clone->parameters()) {
|
||||
REQUIRE(parameter->device().type() == device.type());
|
||||
REQUIRE(parameter->device().index() == device.index());
|
||||
}
|
||||
for (const auto& buffer : clone->buffers()) {
|
||||
REQUIRE(buffer->device().type() == device.type());
|
||||
REQUIRE(buffer->device().index() == device.index());
|
||||
}
|
||||
}
|
||||
|
||||
SECTION(
|
||||
"Cloning to a particular device places all parameters/buffers there") {
|
||||
TestModule m;
|
||||
torch::Device device(torch::kCUDA, 1);
|
||||
// everything is on CPU here
|
||||
auto clone = m.clone(device);
|
||||
for (const auto& parameter : clone->parameters()) {
|
||||
REQUIRE(parameter->device().type() == device.type());
|
||||
REQUIRE(parameter->device().index() == device.index());
|
||||
}
|
||||
for (const auto& buffer : clone->buffers()) {
|
||||
REQUIRE(buffer->device().type() == device.type());
|
||||
REQUIRE(buffer->device().index() == device.index());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("module/parameters") {
|
||||
torch::manual_seed(0);
|
||||
struct TestModule : Module {
|
||||
|
Reference in New Issue
Block a user