mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[nativert] Move Weights to PyTorch core (#155156)
Summary: Moves Weights class to PyTorch core Torch Native Runtime RFC: pytorch/rfcs#72 README: https://github.com/pytorch/pytorch/blob/main/torch/nativert/OVERVIEW.md Test Plan: buck2 run mode/dev-nosan caffe2/test/cpp/nativert:weights_test Differential Revision: D75973156 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155156 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
6fb6293159
commit
9b4a748e29
@ -596,6 +596,7 @@ libtorch_nativert_sources = [
|
||||
"torch/nativert/graph/TensorMeta.cpp",
|
||||
"torch/nativert/executor/Placement.cpp",
|
||||
"torch/nativert/executor/PlacementUtils.cpp",
|
||||
"torch/nativert/executor/Weights.cpp",
|
||||
"torch/nativert/executor/memory/FunctionSchema.cpp",
|
||||
"torch/nativert/common/FileUtil.cpp",
|
||||
]
|
||||
|
@ -9,6 +9,7 @@ set(NATIVERT_TEST_SRCS
|
||||
${TORCH_ROOT}/torch/nativert/graph/Graph.cpp
|
||||
${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp
|
||||
${TORCH_ROOT}/torch/nativert/executor/PlacementUtils.cpp
|
||||
${TORCH_ROOT}/torch/nativert/executor/Weights.cpp
|
||||
${TORCH_ROOT}/torch/nativert/common/FileUtil.cpp
|
||||
${TORCH_ROOT}/torch/nativert/executor/memory/FunctionSchema.cpp
|
||||
)
|
||||
|
92
test/cpp/nativert/test_weights.cpp
Normal file
92
test/cpp/nativert/test_weights.cpp
Normal file
@ -0,0 +1,92 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <torch/csrc/jit/serialization/pickle.h>
|
||||
#include <torch/custom_class.h>
|
||||
#include <torch/torch.h>
|
||||
#include <memory>
|
||||
|
||||
#include <torch/nativert/executor/Placement.h>
|
||||
#include <torch/nativert/executor/Weights.h>
|
||||
#include <torch/nativert/graph/Graph.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
class WeightsTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
static constexpr std::string_view source =
|
||||
R"(graph(%foo, %bar, %baz):
|
||||
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
|
||||
return(%o2, %baz)
|
||||
)";
|
||||
graph = stringToGraph(source);
|
||||
placement = std::make_unique<Placement>(c10::Device(c10::DeviceType::CPU));
|
||||
}
|
||||
std::shared_ptr<Graph> graph;
|
||||
std::unique_ptr<Placement> placement;
|
||||
};
|
||||
// Constructing Weights from an empty state dict must leave it with no
// parameters and no buffers, and lookups of unknown names must report false.
TEST_F(WeightsTest, ConstructEmptyStateDict) {
  const std::unordered_map<std::string, c10::IValue> emptyStateDict{};
  Weights weights(graph.get(), emptyStateDict, *placement);

  // Nothing was loaded, so both views are empty.
  const auto params = weights.parameters();
  EXPECT_TRUE(params.empty());
  const auto bufs = weights.buffers();
  EXPECT_TRUE(bufs.empty());

  EXPECT_FALSE(weights.contains("non_existent_weight"));
}
|
||||
// A tensor stored through setValue() must be visible through contains() and
// at(), with the same sizes as the tensor that was stored.
TEST_F(WeightsTest, SetAndGetValue) {
  const std::unordered_map<std::string, c10::IValue> emptyStateDict{};
  Weights weights(graph.get(), emptyStateDict, *placement);

  const std::string weightName = "added_weight";
  const at::Tensor expected = at::ones({2, 2});
  weights.setValue(weightName, expected);

  EXPECT_TRUE(weights.contains(weightName));
  EXPECT_EQ(weights.at(weightName).sizes(), expected.sizes());
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
|
||||
using namespace ::testing;
|
||||
struct ContainsTensorDict : torch::CustomClassHolder {
|
||||
explicit ContainsTensorDict(at::Tensor t) : t_(t) {}
|
||||
|
||||
explicit ContainsTensorDict(c10::Dict<std::string, at::Tensor> dict) {
|
||||
t_ = dict.at(std::string("init_tensor"));
|
||||
}
|
||||
|
||||
c10::Dict<std::string, at::Tensor> serialize() const {
|
||||
c10::Dict<std::string, at::Tensor> dict;
|
||||
dict.insert(std::string("init_tensor"), t_);
|
||||
return dict;
|
||||
}
|
||||
|
||||
at::Tensor t_;
|
||||
};
|
||||
|
||||
static auto reg =
|
||||
torch::class_<ContainsTensorDict>("testing", "ContainsTensorDict")
|
||||
.def(torch::init<at::Tensor>())
|
||||
.def_pickle(
|
||||
// __getstate__
|
||||
[](const c10::intrusive_ptr<ContainsTensorDict>& self)
|
||||
-> c10::Dict<std::string, at::Tensor> {
|
||||
return self->serialize();
|
||||
},
|
||||
// __setstate__
|
||||
[](c10::Dict<std::string, at::Tensor> data)
|
||||
-> c10::intrusive_ptr<ContainsTensorDict> {
|
||||
return c10::make_intrusive<ContainsTensorDict>(std::move(data));
|
||||
});
|
||||
|
||||
// Round-trips a custom object holding a tensor through pickle_save /
// pickle_load_obj and checks that the tensor survives unchanged.
TEST(CustomWeightsTest, TestCustomObjWithContainedTensor) {
  // Save
  const auto original = torch::tensor({1, 2, 3});
  auto customObj = c10::make_intrusive<ContainsTensorDict>(original);
  const auto bytes = torch::jit::pickle_save(c10::IValue(std::move(customObj)));

  // Load
  const auto loadedCustomObj =
      torch::jit::pickle_load_obj(std::string(bytes.begin(), bytes.end()));

  // Fatal assertion: the cast and dereference below are only meaningful when
  // the loaded value really is an object, so stop the test here otherwise.
  ASSERT_TRUE(loadedCustomObj.isObject());

  const auto restored =
      loadedCustomObj.to<c10::intrusive_ptr<ContainsTensorDict>>();
  EXPECT_EQ(restored->t_[0].item<int>(), 1);
  // Compare the whole tensor, not only its first element.
  EXPECT_TRUE(torch::equal(restored->t_, original));
}
|
439
torch/nativert/executor/Weights.cpp
Normal file
439
torch/nativert/executor/Weights.cpp
Normal file
@ -0,0 +1,439 @@
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include <utility>
|
||||
|
||||
#include <torch/csrc/export/pt2_archive_constants.h>
|
||||
#include <torch/csrc/jit/serialization/import.h>
|
||||
#include <torch/csrc/jit/serialization/import_read.h>
|
||||
#include <torch/csrc/jit/serialization/pickle.h>
|
||||
#include <torch/nativert/executor/Weights.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#else
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_strided.h>
|
||||
#include <ATen/ops/scalar_tensor.h>
|
||||
#endif
|
||||
|
||||
#include <c10/util/string_view.h>
|
||||
#include <caffe2/serialize/inline_container.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
// Process-wide counter; each Weights constructor reads it for version_ and
// then post-increments it.
WeightVersion Weights::globalVersion_{0};
|
||||
|
||||
// Builds a Weights object for `graph`. When `stateDict` holds a value, its
// tensors are loaded immediately via loadStateDict(); otherwise the object
// starts out with no values.
Weights::Weights(
    const Graph* graph,
    const std::optional<std::unordered_map<std::string, c10::IValue>>&
        stateDict,
    Placement placement)
    : graph_(graph),
      weightsMeta_(graph->weightsMeta()),
      placement_(std::move(placement)),
      version_(globalVersion_++) {
  if (!stateDict) {
    return;
  }
  loadStateDict(*stateDict);
}
|
||||
|
||||
// Builds a Weights object by reading every parameter, buffer, tensor constant
// and custom object named in the graph signature out of a model archive.
//
// - pytorchStreamReader: reader for the archive.
// - stateDictPaths / stateDictPathPrefix: per-name file path (and common
//   prefix) for parameters and buffers.
// - constantPaths / constantPathPrefix: per-name file path (and common prefix)
//   for non-persistent buffers, tensor constants and custom objects.
// - placement: used to map each weight's recorded device to a target device.
// - skipSizeCheck / skipDtypeCheck: stored for later use by validateValue().
//
// Fails with TORCH_CHECK when a name has no path, a record is missing from the
// archive, a tensor has no metadata, a constant file name has an unexpected
// prefix, or the signature lists more weights than the graph has weight
// values.
Weights::Weights(
    const Graph* graph,
    std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
    const std::unordered_map<std::string, std::string>& stateDictPaths,
    std::string_view stateDictPathPrefix,
    const std::unordered_map<std::string, std::string>& constantPaths,
    std::string_view constantPathPrefix,
    Placement placement,
    std::function<bool(const std::string&)> skipSizeCheck,
    std::function<bool(const std::string&)> skipDtypeCheck)
    : graph_(graph),
      weightsMeta_(graph->weightsMeta()),
      placement_(std::move(placement)),
      version_(globalVersion_++),
      skipSizeCheck_(std::move(skipSizeCheck)),
      skipDtypeCheck_(std::move(skipDtypeCheck)) {
  // Reads the record for `tensorName` and stores the resulting tensor in
  // allValues_.
  auto loadAndInsert =
      [&](const std::string& tensorName,
          std::string_view pathPrefix,
          const std::unordered_map<std::string, std::string>& tensorPaths,
          bool isUsed) {
        const auto pathIt = tensorPaths.find(tensorName);
        TORCH_CHECK(
            pathIt != tensorPaths.end(),
            "Couldn't find ",
            tensorName,
            " in tensorPaths");

        const std::string tensorPath = std::string{pathPrefix} + pathIt->second;
        VLOG(1) << "Loading weight from: " << tensorPath;
        TORCH_CHECK(
            pytorchStreamReader->hasRecord(tensorPath),
            tensorPath,
            " not found");

        auto [tensorData, tensorDataSize] =
            pytorchStreamReader->getRecord(tensorPath);

        // TODO: We now have two copies of metadata for weights, one in
        // model definition /models/<model_name>.json, another in
        // /extra/xl_weights/<model_name>_model_param_config.json
        // Currently, we only use the metadata from model definition.
        const auto metaIt = weightsMeta_.find(tensorName);
        TORCH_CHECK(
            metaIt != weightsMeta_.end(),
            "Tensor meta not found for: ",
            tensorName);
        const TensorMeta& tensorMeta = metaIt->second;

        // An empty record for a non-empty tensor: keep only the shape by
        // creating the tensor on the meta device.
        if (tensorDataSize == 0 && tensorMeta.numel() > 0) {
          VLOG(1) << "Tensor " << tensorName
                  << " does not have data and create on Meta device";
          allValues_[tensorName] = at::empty_strided(
              tensorMeta.sizes(),
              tensorMeta.strides(),
              tensorMeta.asTensorOptions().device(at::kMeta));
          return;
        }

        // Weights without users get a scalar placeholder instead of the real
        // data.
        if (!isUsed) {
          VLOG(1) << "Tensor " << tensorName << " is not used during inference";
          auto targetDevice = placement_.getMappedDevice(tensorMeta.device());
          allValues_[tensorName] =
              at::scalar_tensor(0, at::TensorOptions().device(targetDevice));
          return;
        }

        // NOTE(review): the storage size is derived from the metadata only;
        // tensorDataSize is not compared against it here. Confirm the archive
        // writer guarantees the record is at least that large.
        size_t bytesPerEntry =
            c10::scalarTypeToTypeMeta(tensorMeta.dtype()).itemsize();
        auto device = tensorData.device();
        auto storage = c10::Storage(
            c10::Storage::use_byte_size_t(),
            at::detail::computeStorageNbytes(
                tensorMeta.sizes(), tensorMeta.strides(), bytesPerEntry),
            std::move(tensorData), // ownership is transferred
            nullptr,
            false);
        const auto tensorOptions = at::TensorOptions(device)
                                       .dtype(tensorMeta.dtype())
                                       .requires_grad(false);
        auto tensor =
            at::empty({0}, tensorOptions)
                .set_(storage, 0, tensorMeta.sizes(), tensorMeta.strides());

        auto targetDevice = placement_.getMappedDevice(tensorMeta.device());
        VLOG(1) << "Loading weight " << tensorName << " on " << targetDevice;
        if (!isSameDevice(targetDevice, tensor.device())) {
          tensor = tensor.to(targetDevice);
        }

        allValues_[tensorName] = tensor;
      };

  // Weight values are consumed in signature order: parameters, buffers,
  // non-persistent buffers, then tensor constants.
  const auto& weightValues = graph->weightValues();
  size_t weightIndex = 0;

  // Reports whether the next weight value has users, records unused names in
  // unusedWeights_, and advances the running index. The bounds check guards
  // against a signature that lists more weights than the graph provides.
  auto consumeNextWeight = [&](const std::string& name) -> bool {
    TORCH_CHECK(
        weightIndex < weightValues.size(),
        "Weight index ",
        weightIndex,
        " for ",
        name,
        " is out of range; graph has ",
        weightValues.size(),
        " weight values");
    const bool used = !weightValues[weightIndex]->users().empty();
    if (!used) {
      unusedWeights_.insert(name);
    }
    ++weightIndex;
    return used;
  };

  // Parameters and buffers share the state-dict path map and prefix.
  auto loadParamsOrBuffers = [&](const auto& names) {
    for (const auto& tensorName : names) {
      const std::string name(tensorName);
      const bool used = consumeNextWeight(name);
      loadAndInsert(name, stateDictPathPrefix, stateDictPaths, used);
    }
  };
  loadParamsOrBuffers(graph->signature().parameters());
  loadParamsOrBuffers(graph->signature().buffers());

  // Load tensor constants and custom object constants, they are both stored
  // in the same directory in the archive, i.e. "extra/constants/" tensor
  // constants are prefixed with "tensor_" custom objects are prefixed with
  // "custom_obj_"
  auto loadConstants = [&](const auto& constants) {
    for (const auto& constantName : constants) {
      const std::string name(constantName);
      const auto pathIt = constantPaths.find(name);
      TORCH_CHECK(
          pathIt != constantPaths.end(),
          "Couldn't find ",
          constantName,
          " in constantPaths");
      const auto& fileName = pathIt->second;

      // Only tensor constants are expected in these lists.
      TORCH_CHECK(
          c10::starts_with(
              fileName,
              torch::_export::archive_spec::TENSOR_CONSTANT_FILENAME_PREFIX),
          "Unknown constant path: ",
          fileName);

      const bool used = consumeNextWeight(name);
      loadAndInsert(name, constantPathPrefix, constantPaths, used);
    }
  };
  loadConstants(graph->signature().nonPersistentBuffers());
  loadConstants(graph->signature().tensorConstants());

  // custom object constants
  for (const auto& customObjName : graph->signature().customObjs()) {
    const auto pathIt = constantPaths.find(std::string(customObjName));
    TORCH_CHECK(
        pathIt != constantPaths.end(),
        "Couldn't find ",
        customObjName,
        " in constantPaths");
    const auto& fileName = pathIt->second;

    TORCH_CHECK(
        c10::starts_with(
            fileName,
            torch::_export::archive_spec::CUSTOM_OBJ_FILENAME_PREFIX),
        "Unknown constant path: ",
        fileName);

    std::string customObjPath = std::string{constantPathPrefix} + fileName;
    LOG(INFO) << "Loading custom object from: " << customObjPath;

    TORCH_CHECK(
        pytorchStreamReader->hasRecord(customObjPath),
        customObjPath,
        " not found");

    const auto& [customObjData, customObjDataSize] =
        pytorchStreamReader->getRecord(customObjPath);

    // Copy the raw record into a string so it can be unpickled.
    const char* customObjDataPtr =
        reinterpret_cast<const char*>(customObjData.get());
    std::string customObjBytes(
        customObjDataPtr, customObjDataPtr + customObjDataSize);

    c10::IValue customObj = torch::jit::pickle_load_obj(customObjBytes);
    TORCH_CHECK(
        customObj.isCustomClass(), "Custom object is not a custom class");
    TORCH_CHECK(!customObj.isNone(), "Custom object is None");
    customObjs_[std::string(customObjName)] = std::move(customObj);
    // Remember which name was loaded from this file for lookups by path.
    customObjsPaths_[customObjPath] = std::string(customObjName);
  }
}
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> Weights::parameters() const {
|
||||
std::unordered_map<std::string, at::Tensor> result;
|
||||
for (const auto& name : graph_->signature().parameters()) {
|
||||
result.emplace(name, allValues_.at(std::string(name)));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> Weights::buffers() const {
|
||||
std::unordered_map<std::string, at::Tensor> result;
|
||||
for (const auto& name : graph_->signature().buffers()) {
|
||||
result.emplace(name, allValues_.at(std::string(name)));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> Weights::attributes() const {
|
||||
return allValues_;
|
||||
}
|
||||
|
||||
// Returns the tensor stored under `name`. Fails with TORCH_CHECK, listing the
// known names, when there is no such entry.
at::Tensor Weights::at(const std::string& name) const {
  const auto found = allValues_.find(name);
  TORCH_CHECK(
      found != allValues_.end(), name, " not found in Weights ", toString());
  return found->second;
}
|
||||
|
||||
// Returns a mutable reference to the tensor stored under `name`. Fails with
// TORCH_CHECK, listing the known names, when there is no such entry.
at::Tensor& Weights::at(const std::string& name) {
  const auto found = allValues_.find(name);
  TORCH_CHECK(
      found != allValues_.end(), name, " not found in Weights ", toString());
  return found->second;
}
|
||||
|
||||
bool Weights::contains(const std::string& name) const {
|
||||
return allValues_.find(name) != allValues_.end();
|
||||
}
|
||||
|
||||
// Returns the custom object stored under `name`; fails with TORCH_CHECK when
// no such object exists.
c10::IValue Weights::getCustomObj(const std::string& name) const {
  const auto found = customObjs_.find(name);
  TORCH_CHECK(
      found != customObjs_.end(),
      "Custom objects ",
      name,
      " not found in Weights");
  return found->second;
}
|
||||
|
||||
// Resolves an archive file path to the custom-object name recorded for it and
// returns that object. Fails with TORCH_CHECK when the path is unknown.
c10::IValue Weights::getCustomObjByFileName(const std::string& name) const {
  const auto pathEntry = customObjsPaths_.find(name);
  if (pathEntry == customObjsPaths_.end()) {
    TORCH_CHECK(
        false,
        "Custom objects with file name ",
        name,
        " not found in Weights");
  }
  return getCustomObj(pathEntry->second);
}
|
||||
|
||||
void Weights::loadStateDict(
|
||||
const std::unordered_map<std::string, c10::IValue>& stateDict) {
|
||||
auto validateAndInsert = [&](const std::string& name) {
|
||||
auto stateDictIt = stateDict.find(name);
|
||||
TORCH_CHECK(
|
||||
stateDictIt != stateDict.end(),
|
||||
"Couldn't find ",
|
||||
name,
|
||||
" in stateDict");
|
||||
|
||||
// Verify that the tensor matches the tensorMeta
|
||||
auto it = weightsMeta_.find(name);
|
||||
TORCH_CHECK(
|
||||
it != weightsMeta_.end(), "Couldn't find ", name, " in weightsMeta");
|
||||
|
||||
auto targetDevice = placement_.getMappedDevice(it->second.device());
|
||||
auto tensor = stateDictIt->second.toTensor().to(targetDevice);
|
||||
|
||||
TORCH_CHECK(tensor.sizes() == it->second.sizes());
|
||||
TORCH_CHECK(tensor.dtype() == it->second.dtype());
|
||||
|
||||
allValues_.emplace(name, tensor);
|
||||
};
|
||||
|
||||
for (const auto& name : graph_->signature().parameters()) {
|
||||
validateAndInsert(std::string(name));
|
||||
}
|
||||
for (const auto& name : graph_->signature().buffers()) {
|
||||
validateAndInsert(std::string(name));
|
||||
}
|
||||
// TensorConstants_ not filled !!
|
||||
}
|
||||
|
||||
// Checks that `newValue` is an acceptable value for the weight `name`.
//
// - Sizes must match the recorded metadata unless skipSizeCheck_ returns true
//   for `name` or the weight is in unusedWeights_.
// - Dtype must match unless skipDtypeCheck_ returns true for `name` or the
//   weight is in unusedWeights_.
// - The tensor's device must be the one placement_ maps the recorded device
//   to; this check has no exemptions.
//
// Fails with TORCH_CHECK on any violation, including when `name` has no
// recorded metadata.
void Weights::validateValue(const std::string& name, const at::Tensor& newValue)
    const {
  // Explicit lookup so a missing entry produces a descriptive error instead
  // of the bare std::out_of_range raised by unordered_map::at.
  const auto metaIt = weightsMeta_.find(name);
  TORCH_CHECK(
      metaIt != weightsMeta_.end(),
      "Couldn't find ",
      name,
      " in weightsMeta");
  const auto& weightMeta = metaIt->second;

  // Looked up once; unused weights are exempt from size and dtype checks.
  const bool isUnused = unusedWeights_.find(name) != unusedWeights_.end();

  TORCH_CHECK(
      weightMeta.sizes() == newValue.sizes() ||
          (skipSizeCheck_ && skipSizeCheck_(name)) || isUnused,
      "Mismatched sizes for ",
      name,
      ": ",
      weightMeta.sizes(),
      " vs ",
      newValue.sizes());
  TORCH_CHECK(
      weightMeta.dtype() == newValue.dtype() ||
          (skipDtypeCheck_ && skipDtypeCheck_(name)) || isUnused,
      "Mismatched dtype for ",
      name,
      ": ",
      weightMeta.dtype(),
      " vs ",
      newValue.dtype());

  const auto targetDevice = placement_.getMappedDevice(weightMeta.device());
  if (targetDevice.is_cpu() && targetDevice.has_index()) {
    LOG(WARNING) << "Target device is cpu but has index: " << targetDevice;
  }
  TORCH_CHECK(
      isSameDevice(targetDevice, newValue.device()),
      "Mismatched device for ",
      name,
      ": ",
      targetDevice,
      " vs ",
      newValue.device());
}
|
||||
|
||||
// Stores `newValue` under `name`, replacing any existing tensor. A name that
// is already present is validated first; an unknown name only produces a
// warning and is then added.
void Weights::setValue(const std::string& name, const at::Tensor& newValue) {
  const bool alreadyRegistered = allValues_.count(name) != 0;
  if (!alreadyRegistered) {
    LOG(WARNING) << name << " is not found in the registered weights";
  } else {
    validateValue(name, newValue);
  }

  allValues_.insert_or_assign(name, newValue);
}
|
||||
|
||||
// Copies `newValue` into the tensor already stored under `name` (in place,
// via copy_). Fails with TORCH_CHECK when `name` is unknown, and validates the
// new value before copying.
void Weights::updateValue(const std::string& name, const at::Tensor& newValue) {
  const auto existing = allValues_.find(name);
  if (existing == allValues_.end()) {
    TORCH_CHECK(false, name, " not found in Weights ", toString());
  }
  validateValue(name, newValue);

  at::Tensor& destination = existing->second;
  destination.copy_(newValue);
}
|
||||
|
||||
void Weights::updateValues(
|
||||
const std::unordered_map<std::string, at::Tensor>& newValues) {
|
||||
for (auto& [name, newValue] : newValues) {
|
||||
updateValue(name, newValue);
|
||||
}
|
||||
}
|
||||
|
||||
std::string Weights::toString() const {
|
||||
std::stringstream ss;
|
||||
ss << '[';
|
||||
for (const auto& [name, _] : allValues_) {
|
||||
ss << name << ", ";
|
||||
}
|
||||
ss << ']';
|
||||
ss << '[';
|
||||
for (const auto& [name, _] : customObjs_) {
|
||||
ss << name << ", ";
|
||||
}
|
||||
ss << ']';
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
void Weights::validateAllWeightsLoaded() {
|
||||
auto checkNames = [&](const auto& names) {
|
||||
for (const auto& name : names) {
|
||||
if (unusedWeights_.find(std::string(name)) != unusedWeights_.end()) {
|
||||
continue;
|
||||
}
|
||||
auto it = allValues_.find(std::string(name));
|
||||
TORCH_CHECK(it != allValues_.end(), "Missing weight: ", name);
|
||||
TORCH_CHECK(it->second.defined(), "Weight not defined: ", name);
|
||||
if (it->second.device().is_meta()) {
|
||||
LOG(WARNING) << "Weight is on meta device: " << name;
|
||||
}
|
||||
}
|
||||
};
|
||||
checkNames(graph_->signature().parameters());
|
||||
checkNames(graph_->signature().buffers());
|
||||
checkNames(graph_->signature().nonPersistentBuffers());
|
||||
checkNames(graph_->signature().tensorConstants());
|
||||
}
|
||||
|
||||
// Stores `tensor` in the folded-constants map under `name`, replacing any
// previous value for that name.
void Weights::updateFoldedConst(std::string_view name, c10::IValue tensor) {
  foldedConstsMap_.insert_or_assign(std::string(name), std::move(tensor));
}
|
||||
|
||||
const std::unordered_map<std::string, c10::IValue>& Weights::getFoldedConsts()
|
||||
const {
|
||||
return foldedConstsMap_;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
145
torch/nativert/executor/Weights.h
Normal file
145
torch/nativert/executor/Weights.h
Normal file
@ -0,0 +1,145 @@
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include <c10/util/FbcodeMaps.h>
|
||||
#include <c10/util/Logging.h>
|
||||
#include <caffe2/serialize/inline_container.h>
|
||||
#include <torch/nativert/executor/Placement.h>
|
||||
|
||||
#include <torch/nativert/graph/Graph.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
using WeightVersion = int;
|
||||
/**
|
||||
* @brief A class that manages the weights of a graph, providing functionality
|
||||
* to load, access, and manipulate them.
|
||||
*
|
||||
* It is responsible for handling the parameters, buffers, and constants
|
||||
* associated with a graph It provides mechanisms to load weights from
|
||||
* serialized data, access and modify them, and performs necessary validation
|
||||
* checks.
|
||||
*/
|
||||
class Weights {
|
||||
public:
|
||||
explicit Weights(
|
||||
const Graph* graph,
|
||||
const std::optional<std::unordered_map<std::string, c10::IValue>>&
|
||||
stateDict = std::nullopt,
|
||||
Placement placement = Placement());
|
||||
|
||||
// Arguments
|
||||
// - pytorchStreamReader: the reader for the model archive
|
||||
// - stateDictPath: a map from parameter/buffer/constant name to file path in
|
||||
// the archive
|
||||
// - stateDictPathPrefix: a prefix that will be prepended to paths in
|
||||
// stateDictPathPrefix
|
||||
// - constantPaths: a map from constant name to file path in the archive
|
||||
// - constantPathPrefix: a prefix that will be prepended to paths in
|
||||
// constantPathPrefix
|
||||
// - placement: the device placement of the weights, default to follow the
|
||||
// original device in the weight's metadata
|
||||
explicit Weights(
|
||||
const Graph* graph,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader,
|
||||
const std::unordered_map<std::string, std::string>& stateDictPaths,
|
||||
std::string_view stateDictPathPrefix,
|
||||
const std::unordered_map<std::string, std::string>& constantPaths,
|
||||
std::string_view constantPathPrefix,
|
||||
Placement placement = Placement(),
|
||||
std::function<bool(const std::string&)> skipSizeCheck = {},
|
||||
std::function<bool(const std::string&)> skipDtypeCheck = {});
|
||||
|
||||
at::Tensor at(const std::string& name) const;
|
||||
at::Tensor& at(const std::string& name);
|
||||
bool contains(const std::string& name) const;
|
||||
c10::IValue getCustomObj(const std::string& name) const;
|
||||
c10::IValue getCustomObjByFileName(const std::string& name) const;
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> parameters() const;
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> buffers() const;
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> attributes() const;
|
||||
|
||||
void loadStateDict(
|
||||
const std::unordered_map<std::string, c10::IValue>& stateDict);
|
||||
|
||||
/*
|
||||
* Replace the value stored at the weight with name "name".
|
||||
*/
|
||||
void setValue(const std::string& name, const at::Tensor& newValue);
|
||||
|
||||
/*
|
||||
* Update the value stored at the weight with name "name".
|
||||
* This is done in-place.
|
||||
*/
|
||||
void updateValue(const std::string& name, const at::Tensor& newValue);
|
||||
|
||||
void updateValues(
|
||||
const std::unordered_map<std::string, at::Tensor>& newValues);
|
||||
|
||||
void validateValue(const std::string& name, const at::Tensor& newValue) const;
|
||||
|
||||
void validateAllWeightsLoaded();
|
||||
|
||||
void updateFoldedConst(std::string_view name, c10::IValue tensor);
|
||||
|
||||
const std::unordered_map<std::string, c10::IValue>& getFoldedConsts() const;
|
||||
|
||||
C10_ALWAYS_INLINE const c10::FastMap<std::string, c10::IValue>&
|
||||
getConstFoldedValues() const {
|
||||
return constFoldedValues_;
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE void setConstFoldedValue(
|
||||
const std::string& n,
|
||||
c10::IValue iv) {
|
||||
constFoldedValues_.insert_or_assign(n, std::move(iv));
|
||||
}
|
||||
|
||||
std::string toString() const;
|
||||
|
||||
WeightVersion version() const {
|
||||
return version_;
|
||||
}
|
||||
|
||||
private:
|
||||
const Graph* graph_;
|
||||
const std::unordered_map<std::string, TensorMeta>& weightsMeta_;
|
||||
Placement placement_;
|
||||
|
||||
// keys are parameter/buffer/constant names, not graph input names!
|
||||
std::unordered_map<std::string, at::Tensor> allValues_;
|
||||
|
||||
std::unordered_map<std::string, c10::IValue> customObjs_;
|
||||
|
||||
// contains CustomClassHolder map from a file name to an arbitray
|
||||
// key in customObjs_ that hold the loaded content of the file.
|
||||
// This is used in AOTIDelegateExecutor.
|
||||
std::unordered_map<std::string, std::string> customObjsPaths_;
|
||||
|
||||
// The liftcycle of folded consts should be tied with the weights from which
|
||||
// it was derived. The ordering of the constant should be consistent with
|
||||
// the output order of const graph.
|
||||
std::vector<c10::IValue> foldedConsts_;
|
||||
std::unordered_map<std::string, c10::IValue> foldedConstsMap_;
|
||||
|
||||
c10::FastMap<std::string, c10::IValue> constFoldedValues_;
|
||||
|
||||
// unique version number for this instance of weight
|
||||
const WeightVersion version_;
|
||||
|
||||
// every instance of Weight has a unique version number
|
||||
static WeightVersion globalVersion_;
|
||||
|
||||
std::function<bool(const std::string&)> skipSizeCheck_ = {};
|
||||
std::function<bool(const std::string&)> skipDtypeCheck_ = {};
|
||||
|
||||
// save the names of unused weights
|
||||
std::unordered_set<std::string> unusedWeights_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
Reference in New Issue
Block a user