[Bootcamp] Add option for flatbuffer loader to copy memory to individual tensors (#76986)
Summary: Add an option for the flatbuffer loader to copy memory into individual tensors, so the underlying buffer can be freed without waiting for all tensor runs to complete.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76986
Approved by: https://github.com/qihqi
Committed by: PyTorch MergeBot
Parent: 0b0611c223
Commit: bd573389f6
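In copy mode each tensor owns a private copy of its storage, so the serialized buffer can be dropped as soon as parsing finishes. A condensed sketch of the pattern the new tests below exercise, assuming a `mobile::Module bc` and `inputs` prepared as in those tests (`parse_mobile_module` is the test-only helper updated in this diff):

```cpp
// Serialize, then reload with should_copy_tensor_memory = true.
auto buff = save_mobile_module_to_bytes(bc);
mobile::Module bc2 = parse_mobile_module(
    buff.data(), buff.size(), /*should_copy_tensor_memory=*/true);

// Every tensor now owns a copy of its data, so the flatbuffer can be
// released immediately instead of being kept alive for the module's lifetime.
buff = flatbuffers::DetachedBuffer();

auto out = bc2.forward(inputs); // still valid after the buffer is gone
```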
test/cpp/jit/test_flatbuffer.cpp
@@ -31,9 +31,13 @@
 namespace torch {
 namespace jit {
 
-mobile::Module parse_mobile_module(void* data, size_t) {
+mobile::Module parse_mobile_module(
+    void* data,
+    size_t,
+    bool should_copy_tensor_memory = false) {
   auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
-  return initialize_mobile_module(flatbuffer_module);
+  return initialize_mobile_module(
+      flatbuffer_module, c10::nullopt, should_copy_tensor_memory);
 }
 
 TEST(FlatbufferTest, UpsampleNearest2d) {
@@ -64,6 +68,37 @@ TEST(FlatbufferTest, UpsampleNearest2d) {
   ASSERT_TRUE(resd2.equal(refd));
 }
 
+TEST(FlatbufferTest, UpsampleNearest2dWithCopyTensorMemory) {
+  Module m("m");
+  m.define(R"(
+    def forward(self, input: Tensor, scale:float):
+      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
+  )");
+
+  std::vector<IValue> inputs;
+  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
+  inputs.emplace_back(at::Scalar(2.0));
+  auto ref = m.forward(inputs);
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  IValue res;
+  res = bc.forward(inputs);
+
+  auto resd = res.toTensor();
+  auto refd = ref.toTensor();
+  ASSERT_TRUE(resd.equal(refd));
+
+  auto buff = save_mobile_module_to_bytes(bc);
+  mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size(), true);
+
+  buff = flatbuffers::DetachedBuffer();
+
+  auto res2 = bc2.forward(inputs);
+  auto resd2 = res2.toTensor();
+  ASSERT_TRUE(resd2.equal(refd));
+}
+
 TEST(FlatbufferTest, CheckAttrAccess) {
   Module m("m");
   m.register_attribute("mobile_optimized", BoolType::get(), true);
@@ -242,6 +277,50 @@ TEST(FlatbufferTest, Conv) {
       outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
 }
 
+TEST(FlatbufferTest, ConvWithCopyTensorMemory) {
+  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
+  if (s && strcmp(s, "1") == 0)
+    return;
+
+  std::vector<torch::jit::IValue> inputs;
+
+  Module m("m");
+  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
+  m.register_parameter("bias", torch::ones({20}), false);
+  m.define(R"(
+    def forward(self, input):
+      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
+  )");
+
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
+  inputs.push_back(torch::ones({1, 1, 28, 28}));
+
+  auto outputref = m.forward(inputs).toTensor();
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  IValue res;
+  for (int i = 0; i < 3; ++i) {
+    res = bc.get_method("forward")(inputs);
+  }
+  auto output = res.toTensor();
+  AT_ASSERT(outputref.dim() == output.dim());
+  AT_ASSERT(
+      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
+
+  auto buff = save_mobile_module_to_bytes(bc);
+  mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size(), true);
+  buff = flatbuffers::DetachedBuffer();
+
+  for (int i = 0; i < 3; ++i) {
+    res = bc2.get_method("forward")(inputs);
+  }
+  output = res.toTensor();
+  AT_ASSERT(outputref.dim() == output.dim());
+  AT_ASSERT(
+      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
+}
+
 TEST(FlatbufferTest, Inline) {
   Module m("m");
   m.define(R"JIT(
@@ -267,6 +346,32 @@ TEST(FlatbufferTest, Inline) {
   AT_ASSERT(output.toTensor().item<float>() == 7.0);
 }
 
+TEST(FlatbufferTest, InlineWithCopyTensorMemory) {
+  Module m("m");
+  m.define(R"JIT(
+  def foo1(self, x):
+      return x + 1
+
+  def foo2(self, x):
+      return self.foo1(x) + 2
+
+  def foo3(self, x):
+      return self.foo2(x) + 3
+  )JIT");
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  std::vector<torch::jit::IValue> inputs({torch::ones({})});
+  auto output = bc.get_method("foo3")(inputs);
+  AT_ASSERT(output.toTensor().item<float>() == 7.0);
+
+  auto buff = save_mobile_module_to_bytes(bc);
+  mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size(), true);
+  buff = flatbuffers::DetachedBuffer();
+  std::vector<torch::jit::IValue> inputs2({torch::ones({})});
+  output = bc2.get_method("foo3")(inputs2);
+  AT_ASSERT(output.toTensor().item<float>() == 7.0);
+}
+
 TEST(FlatbufferTest, Tuple) {
   Module m("m");
   m.define(R"JIT(
torch/csrc/jit/mobile/flatbuffer_loader.cpp
@@ -583,8 +583,16 @@ c10::Storage FlatbufferLoader::getStorage(uint32_t index) {
   if (!storage_loaded_[index]) {
    auto* storage = module_->storage_data()->GetMutableObject(index);
    size_t size = storage->data()->size();
-    void* ptr = static_cast<void*>(storage->mutable_data()->data());
-    at::DataPtr data(ptr, ptr, deleteNothing2, DeviceType::CPU);
+
+    at::DataPtr data;
+    if (should_copy_tensor_memory_) {
+      auto* allocator = at::GetCPUAllocator();
+      data = allocator->allocate(size);
+      memcpy(data.get(), storage->data()->data(), size);
+    } else {
+      void* ptr = static_cast<void*>(storage->mutable_data()->data());
+      data = at::DataPtr(ptr, ptr, deleteNothing2, DeviceType::CPU);
+    }
    storages_[index] =
        c10::Storage(c10::Storage::use_byte_size_t(), size, std::move(data));
    storage_loaded_[index] = true;
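The two branches above differ only in who owns the bytes. A minimal standalone sketch of the same two ownership modes (not the loader itself; `src` and `size` stand in for the flatbuffer storage record):

```cpp
#include <cstring>
#include <c10/core/CPUAllocator.h>
#include <c10/core/Storage.h>

// Sketch: build a c10::Storage over `size` bytes at `src`, either by
// copying (the storage owns fresh memory) or by aliasing (the storage
// borrows `src` via a no-op deleter, so `src` must outlive all tensors
// derived from it).
c10::Storage make_storage(void* src, size_t size, bool copy) {
  c10::DataPtr data;
  if (copy) {
    auto* allocator = c10::GetCPUAllocator();
    data = allocator->allocate(size);
    std::memcpy(data.get(), src, size);
  } else {
    // Mirrors the deleteNothing2 path: the deleter intentionally does nothing.
    data = c10::DataPtr(src, src, [](void*) {}, c10::DeviceType::CPU);
  }
  return c10::Storage(c10::Storage::use_byte_size_t(), size, std::move(data));
}
```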
@@ -678,8 +686,11 @@ mobile::Module parse_and_initialize_mobile_module(
 
 mobile::Module initialize_mobile_module(
     mobile::serialization::Module* flatbuffer_module,
-    c10::optional<at::Device>) {
-  mobile::Module m = FlatbufferLoader().parseModule(flatbuffer_module);
+    c10::optional<at::Device>,
+    bool should_copy_tensor_memory) {
+  auto flatbufferLoader = FlatbufferLoader();
+  flatbufferLoader.setShouldCopyTensorMemory(should_copy_tensor_memory);
+  mobile::Module m = flatbufferLoader.parseModule(flatbuffer_module);
   return m;
 }
 
torch/csrc/jit/mobile/flatbuffer_loader.h
@@ -32,7 +32,8 @@ using ExtraFilesMap = std::unordered_map<std::string, std::string>;
 // This function does step 3 described above.
 TORCH_API mobile::Module initialize_mobile_module(
     mobile::serialization::Module* flatbuffer_module,
-    c10::optional<at::Device> device = c10::nullopt);
+    c10::optional<at::Device> device = c10::nullopt,
+    bool should_copy_tensor_memory = false);
 
 // Parse a mobile::Module from raw bytes.
 // ownership of data is shared to the returned Module.
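For callers outside the tests, copy mode is opt-in through the added default parameter; a sketch, assuming `data` points at a loaded flatbuffer:

```cpp
// Sketch: opting into copy mode via the public entry point.
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
mobile::Module m = initialize_mobile_module(
    flatbuffer_module,
    c10::nullopt,                        // device
    /*should_copy_tensor_memory=*/true); // tensors copy their storage
```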
@@ -109,6 +110,14 @@ class TORCH_API FlatbufferLoader {
     return module_;
   }
 
+  bool getShouldCopyTensorMemory() {
+    return should_copy_tensor_memory_;
+  }
+
+  void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
+    should_copy_tensor_memory_ = should_copy_tensor_memory;
+  }
+
   std::shared_ptr<mobile::CompilationUnit> mcu_;
   std::shared_ptr<CompilationUnit> cu_;
 
@@ -131,6 +140,7 @@ class TORCH_API FlatbufferLoader {
   TypeResolver type_resolver_ = nullptr;
   mobile::serialization::Module* module_ = nullptr;
   bool module_parsed_ = false;
+  bool should_copy_tensor_memory_ = false;
 };
 
 } // namespace jit