From 9b4a748e29a720d0fade7e1298a68cc36cfd5f5e Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Mon, 9 Jun 2025 05:49:32 +0000 Subject: [PATCH] [nativert] Move Weights to PyTorch core (#155156) Summary: Moves Weights class to PyTorch core Torch Native Runtime RFC: pytorch/rfcs#72 README: https://github.com/pytorch/pytorch/blob/main/torch/nativert/OVERVIEW.md Test Plan: buck2 run mode/dev-nosan caffe2/test/cpp/nativert:weights_test Differential Revision: D75973156 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155156 Approved by: https://github.com/zhxchen17 --- build_variables.bzl | 1 + test/cpp/nativert/CMakeLists.txt | 1 + test/cpp/nativert/test_weights.cpp | 92 ++++++ torch/nativert/executor/Weights.cpp | 439 ++++++++++++++++++++++++++++ torch/nativert/executor/Weights.h | 145 +++++++++ 5 files changed, 678 insertions(+) create mode 100644 test/cpp/nativert/test_weights.cpp create mode 100644 torch/nativert/executor/Weights.cpp create mode 100644 torch/nativert/executor/Weights.h diff --git a/build_variables.bzl b/build_variables.bzl index 75d532450316..1f47ca15f56e 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -596,6 +596,7 @@ libtorch_nativert_sources = [ "torch/nativert/graph/TensorMeta.cpp", "torch/nativert/executor/Placement.cpp", "torch/nativert/executor/PlacementUtils.cpp", + "torch/nativert/executor/Weights.cpp", "torch/nativert/executor/memory/FunctionSchema.cpp", "torch/nativert/common/FileUtil.cpp", ] diff --git a/test/cpp/nativert/CMakeLists.txt b/test/cpp/nativert/CMakeLists.txt index 40ca66801b1a..0ea0048062da 100644 --- a/test/cpp/nativert/CMakeLists.txt +++ b/test/cpp/nativert/CMakeLists.txt @@ -9,6 +9,7 @@ set(NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/graph/Graph.cpp ${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp ${TORCH_ROOT}/torch/nativert/executor/PlacementUtils.cpp + ${TORCH_ROOT}/torch/nativert/executor/Weights.cpp ${TORCH_ROOT}/torch/nativert/common/FileUtil.cpp ${TORCH_ROOT}/torch/nativert/executor/memory/FunctionSchema.cpp ) diff --git a/test/cpp/nativert/test_weights.cpp b/test/cpp/nativert/test_weights.cpp new file mode 100644 index 000000000000..43d05d5ad887 --- /dev/null +++ b/test/cpp/nativert/test_weights.cpp @@ -0,0 +1,92 @@ +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::nativert { +class WeightsTest : public ::testing::Test { + protected: + void SetUp() override { + static constexpr std::string_view source = + R"(graph(%foo, %bar, %baz): +%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) +return(%o2, %baz) +)"; + graph = stringToGraph(source); + placement = std::make_unique(c10::Device(c10::DeviceType::CPU)); + } + std::shared_ptr graph; + std::unique_ptr placement; +}; +TEST_F(WeightsTest, ConstructEmptyStateDict) { + std::unordered_map stateDict; + Weights weights(graph.get(), stateDict, *placement); + // Check that weights are initialized correctly + EXPECT_TRUE(weights.parameters().empty()); + EXPECT_TRUE(weights.buffers().empty()); + EXPECT_FALSE(weights.contains("non_existent_weight")); +} +TEST_F(WeightsTest, SetAndGetValue) { + std::unordered_map stateDict; + Weights weights(graph.get(), stateDict, *placement); + at::Tensor tensor = at::ones({2, 2}); + weights.setValue("added_weight", tensor); + EXPECT_TRUE(weights.contains("added_weight")); + EXPECT_EQ(weights.at("added_weight").sizes(), tensor.sizes()); +} + +} // namespace torch::nativert + +using namespace ::testing; +struct ContainsTensorDict : torch::CustomClassHolder { + explicit ContainsTensorDict(at::Tensor t) : t_(t) {} + + explicit ContainsTensorDict(c10::Dict dict) { + t_ = dict.at(std::string("init_tensor")); + } + + c10::Dict serialize() const { + c10::Dict dict; + dict.insert(std::string("init_tensor"), t_); + return dict; + } + + at::Tensor t_; +}; + +static auto reg = + torch::class_("testing", "ContainsTensorDict") + .def(torch::init()) + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr& self) + -> c10::Dict { + return self->serialize(); + }, + // __setstate__ + [](c10::Dict data) + -> c10::intrusive_ptr { + return c10::make_intrusive(std::move(data)); + }); + +TEST(CustomWeightsTest, TestCustomObjWithContainedTensor) { + // Save + auto customObj = + c10::make_intrusive(torch::tensor({1, 2, 3})); + const auto bytes = torch::jit::pickle_save(c10::IValue(std::move(customObj))); + + // Load + const auto loadedCustomObj = + torch::jit::pickle_load_obj(std::string{bytes.begin(), bytes.end()}); + EXPECT_TRUE(loadedCustomObj.isObject()); + EXPECT_EQ( + loadedCustomObj.to>() + ->t_[0] + .item(), + 1); +} diff --git a/torch/nativert/executor/Weights.cpp b/torch/nativert/executor/Weights.cpp new file mode 100644 index 000000000000..44a29d95eb67 --- /dev/null +++ b/torch/nativert/executor/Weights.cpp @@ -0,0 +1,439 @@ + +#include +#include + +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#endif + +#include +#include + +namespace torch::nativert { + +WeightVersion Weights::globalVersion_ = 0; + +Weights::Weights( + const Graph* graph, + const std::optional>& + stateDict, + Placement placement) + : graph_(graph), + weightsMeta_(graph->weightsMeta()), + placement_(std::move(placement)), + version_(globalVersion_++) { + if (stateDict.has_value()) { + loadStateDict(stateDict.value()); + } +} + +Weights::Weights( + const Graph* graph, + std::shared_ptr pytorchStreamReader, + const std::unordered_map& stateDictPaths, + std::string_view stateDictPathPrefix, + const std::unordered_map& constantPaths, + std::string_view constantPathPrefix, + Placement placement, + std::function skipSizeCheck, + std::function skipDtypeCheck) + : graph_(graph), + weightsMeta_(graph->weightsMeta()), + placement_(std::move(placement)), + version_(globalVersion_++), + skipSizeCheck_(std::move(skipSizeCheck)), + skipDtypeCheck_(std::move(skipDtypeCheck)) { + auto loadAndInsert = + [&](const std::string& tensorName, + std::string_view pathPrefix, + const std::unordered_map& tensorPaths, + bool isUsed) { + auto pathIt = tensorPaths.find(tensorName); + TORCH_CHECK( + pathIt != tensorPaths.end(), + "Couldn't find ", + tensorName, + " in tensorPaths"); + + const std::string tensorPath = std::string{pathPrefix} + pathIt->second; + VLOG(1) << "Loading weight from: " << tensorPath; + TORCH_CHECK( + pytorchStreamReader->hasRecord(tensorPath), + tensorPath, + " not found"); + + auto [tensorData, tensorDataSize] = + pytorchStreamReader->getRecord(tensorPath); + + // TODO: We now have two copies of metadata for weights, one in + // model definition /models/.json, another in + // /extra/xl_weights/_model_param_config.json + // Currently, we only use the metadata from model definition. + std::optional tensorMeta; + if (weightsMeta_.find(tensorName) != weightsMeta_.end()) { + tensorMeta = weightsMeta_.at(tensorName); + } else { + TORCH_CHECK(false, "Tensor meta not found for: ", tensorName); + } + + if (tensorDataSize == 0 && tensorMeta->numel() > 0) { + VLOG(1) << "Tensor " << tensorName + << " does not have data and create on Meta device"; + allValues_[tensorName] = at::empty_strided( + tensorMeta->sizes(), + tensorMeta->strides(), + tensorMeta->asTensorOptions().device(at::kMeta)); + return; + } + + if (!isUsed) { + VLOG(1) << "Tensor " << tensorName << " is not used during inference"; + auto targetDevice = placement_.getMappedDevice(tensorMeta->device()); + allValues_[tensorName] = + at::scalar_tensor(0, at::TensorOptions().device(targetDevice)); + return; + } + + size_t bytesPerEntry = + c10::scalarTypeToTypeMeta(tensorMeta->dtype()).itemsize(); + auto device = tensorData.device(); + auto storage = c10::Storage( + c10::Storage::use_byte_size_t(), + at::detail::computeStorageNbytes( + tensorMeta->sizes(), tensorMeta->strides(), bytesPerEntry), + std::move(tensorData), // ownership is transferred + nullptr, + false); + const auto tensorOptions = at::TensorOptions(device) + .dtype(tensorMeta->dtype()) + .requires_grad(false); + auto tensor = + at::empty({0}, tensorOptions) + .set_(storage, 0, tensorMeta->sizes(), tensorMeta->strides()); + + auto targetDevice = placement_.getMappedDevice(tensorMeta->device()); + VLOG(1) << "Loading weight " << tensorName << " on " << targetDevice; + if (!isSameDevice(targetDevice, tensor.device())) { + tensor = tensor.to(targetDevice); + } + + allValues_[tensorName] = tensor; + }; + + auto loadAndInsertParamsBuffers = [&](const auto& tensorName, bool isUsed) { + return loadAndInsert( + std::string(tensorName), stateDictPathPrefix, stateDictPaths, isUsed); + }; + + size_t weightIndex = 0; + bool isUsed = true; + const auto& weightValues = graph->weightValues(); + + for (const auto& tensorName : graph->signature().parameters()) { + isUsed = !weightValues[weightIndex]->users().empty(); + if (!isUsed) { + unusedWeights_.insert(std::string(tensorName)); + } + loadAndInsertParamsBuffers(tensorName, isUsed); + weightIndex++; + } + for (const auto& tensorName : graph->signature().buffers()) { + isUsed = !weightValues[weightIndex]->users().empty(); + if (!isUsed) { + unusedWeights_.insert(std::string(tensorName)); + } + loadAndInsertParamsBuffers(tensorName, isUsed); + weightIndex++; + } + + // Load tensor constants and custom object constants, they are both stored + // in the same directory in the archive, i.e. "extra/constants/" tensor + // constants are prefixed with "tensor_" custom objects are prefixed with + // "custom_obj_" + auto loadConstants = [&](const auto& constants) { + for (const auto& constantName : constants) { + auto pathIt = constantPaths.find(std::string(constantName)); + TORCH_CHECK( + pathIt != constantPaths.end(), + "Couldn't find ", + constantName, + " in constantPaths"); + auto& fileName = pathIt->second; + + if (c10::starts_with( + fileName, + torch::_export::archive_spec::TENSOR_CONSTANT_FILENAME_PREFIX)) { + // tensor constants + isUsed = !weightValues[weightIndex]->users().empty(); + if (!isUsed) { + unusedWeights_.insert(std::string(constantName)); + } + loadAndInsert( + std::string(constantName), + constantPathPrefix, + constantPaths, + isUsed); + weightIndex++; + } else { + TORCH_CHECK(false, "Unknown constant path: ", fileName); + } + } + }; + loadConstants(graph->signature().nonPersistentBuffers()); + loadConstants(graph->signature().tensorConstants()); + + // custom object constants + for (const auto& customObjName : graph->signature().customObjs()) { + auto pathIt = constantPaths.find(std::string(customObjName)); + TORCH_CHECK( + pathIt != constantPaths.end(), + "Couldn't find ", + customObjName, + " in constantPaths"); + auto& fileName = pathIt->second; + + if (!c10::starts_with( + fileName, + torch::_export::archive_spec::CUSTOM_OBJ_FILENAME_PREFIX)) { + TORCH_CHECK(false, "Unknown constant path: ", fileName); + } + std::string customObjPath = std::string{constantPathPrefix} + fileName; + LOG(INFO) << "Loading custom object from: " << customObjPath; + + TORCH_CHECK( + pytorchStreamReader->hasRecord(customObjPath), + customObjPath, + " not found"); + + const auto& [customObjData, customObjDataSize] = + pytorchStreamReader->getRecord(customObjPath); + + const char* customObjDataPtr = + reinterpret_cast(customObjData.get()); + std::string customObjBytes( + customObjDataPtr, customObjDataPtr + customObjDataSize); + + c10::IValue customObj = torch::jit::pickle_load_obj(customObjBytes); + TORCH_CHECK( + customObj.isCustomClass(), "Custom object is not a custom class"); + TORCH_CHECK(!customObj.isNone(), "Custom object is None"); + customObjs_[std::string(customObjName)] = std::move(customObj); + customObjsPaths_[customObjPath] = std::string(customObjName); + } +} + +std::unordered_map Weights::parameters() const { + std::unordered_map result; + for (const auto& name : graph_->signature().parameters()) { + result.emplace(name, allValues_.at(std::string(name))); + } + return result; +} + +std::unordered_map Weights::buffers() const { + std::unordered_map result; + for (const auto& name : graph_->signature().buffers()) { + result.emplace(name, allValues_.at(std::string(name))); + } + return result; +} + +std::unordered_map Weights::attributes() const { + return allValues_; +} + +at::Tensor Weights::at(const std::string& name) const { + auto it = allValues_.find(name); + if (it != allValues_.end()) { + return it->second; + } + + TORCH_CHECK(false, name, " not found in Weights ", toString()); +} + +at::Tensor& Weights::at(const std::string& name) { + auto it = allValues_.find(name); + if (it != allValues_.end()) { + return it->second; + } + + TORCH_CHECK(false, name, " not found in Weights ", toString()); +} + +bool Weights::contains(const std::string& name) const { + return allValues_.find(name) != allValues_.end(); +} + +c10::IValue Weights::getCustomObj(const std::string& name) const { + auto it = customObjs_.find(name); + if (it != customObjs_.end()) { + return it->second; + } + + TORCH_CHECK(false, "Custom objects ", name, " not found in Weights"); +} + +c10::IValue Weights::getCustomObjByFileName(const std::string& name) const { + auto it = customObjsPaths_.find(name); + TORCH_CHECK( + it != customObjsPaths_.end(), + "Custom objects with file name ", + name, + " not found in Weights"); + const std::string obj_name = it->second; + return getCustomObj(obj_name); +} + +void Weights::loadStateDict( + const std::unordered_map& stateDict) { + auto validateAndInsert = [&](const std::string& name) { + auto stateDictIt = stateDict.find(name); + TORCH_CHECK( + stateDictIt != stateDict.end(), + "Couldn't find ", + name, + " in stateDict"); + + // Verify that the tensor matches the tensorMeta + auto it = weightsMeta_.find(name); + TORCH_CHECK( + it != weightsMeta_.end(), "Couldn't find ", name, " in weightsMeta"); + + auto targetDevice = placement_.getMappedDevice(it->second.device()); + auto tensor = stateDictIt->second.toTensor().to(targetDevice); + + TORCH_CHECK(tensor.sizes() == it->second.sizes()); + TORCH_CHECK(tensor.dtype() == it->second.dtype()); + + allValues_.emplace(name, tensor); + }; + + for (const auto& name : graph_->signature().parameters()) { + validateAndInsert(std::string(name)); + } + for (const auto& name : graph_->signature().buffers()) { + validateAndInsert(std::string(name)); + } + // TensorConstants_ not filled !! +} + +void Weights::validateValue(const std::string& name, const at::Tensor& newValue) + const { + auto& weightMeta = weightsMeta_.at(name); + + TORCH_CHECK( + weightMeta.sizes() == newValue.sizes() || + (skipSizeCheck_ && skipSizeCheck_(name)) || + unusedWeights_.find(name) != unusedWeights_.end(), + "Mismatched sizes for ", + name, + ": ", + weightMeta.sizes(), + " vs ", + newValue.sizes()); + TORCH_CHECK( + weightMeta.dtype() == newValue.dtype() || + (skipDtypeCheck_ && skipDtypeCheck_(name)) || + unusedWeights_.find(name) != unusedWeights_.end(), + "Mismatched dtype for ", + name, + ": ", + weightMeta.dtype(), + " vs ", + newValue.dtype()); + + auto targetDevice = placement_.getMappedDevice(weightMeta.device()); + if (targetDevice.is_cpu() && targetDevice.has_index()) { + LOG(WARNING) << "Target device is cpu but has index: " << targetDevice; + } + TORCH_CHECK( + isSameDevice(targetDevice, newValue.device()), + "Mismatched device for ", + name, + ": ", + targetDevice, + " vs ", + newValue.device()); +} + +void Weights::setValue(const std::string& name, const at::Tensor& newValue) { + if (allValues_.find(name) != allValues_.end()) { + validateValue(name, newValue); + } else { + LOG(WARNING) << name << " is not found in the registered weights"; + } + + allValues_[name] = newValue; +} + +void Weights::updateValue(const std::string& name, const at::Tensor& newValue) { + auto it = allValues_.find(name); + TORCH_CHECK( + it != allValues_.end(), name, " not found in Weights ", toString()); + validateValue(name, newValue); + + it->second.copy_(newValue); +} + +void Weights::updateValues( + const std::unordered_map& newValues) { + for (auto& [name, newValue] : newValues) { + updateValue(name, newValue); + } +} + +std::string Weights::toString() const { + std::stringstream ss; + ss << '['; + for (const auto& [name, _] : allValues_) { + ss << name << ", "; + } + ss << ']'; + ss << '['; + for (const auto& [name, _] : customObjs_) { + ss << name << ", "; + } + ss << ']'; + return ss.str(); +} + +void Weights::validateAllWeightsLoaded() { + auto checkNames = [&](const auto& names) { + for (const auto& name : names) { + if (unusedWeights_.find(std::string(name)) != unusedWeights_.end()) { + continue; + } + auto it = allValues_.find(std::string(name)); + TORCH_CHECK(it != allValues_.end(), "Missing weight: ", name); + TORCH_CHECK(it->second.defined(), "Weight not defined: ", name); + if (it->second.device().is_meta()) { + LOG(WARNING) << "Weight is on meta device: " << name; + } + } + }; + checkNames(graph_->signature().parameters()); + checkNames(graph_->signature().buffers()); + checkNames(graph_->signature().nonPersistentBuffers()); + checkNames(graph_->signature().tensorConstants()); +} + +void Weights::updateFoldedConst(std::string_view name, c10::IValue tensor) { + foldedConstsMap_[std::string{name}] = std::move(tensor); +} + +const std::unordered_map& Weights::getFoldedConsts() + const { + return foldedConstsMap_; +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/Weights.h b/torch/nativert/executor/Weights.h new file mode 100644 index 000000000000..19f40114e2a2 --- /dev/null +++ b/torch/nativert/executor/Weights.h @@ -0,0 +1,145 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include + +namespace torch::nativert { + +using WeightVersion = int; +/** + * @brief A class that manages the weights of a graph, providing functionality + * to load, access, and manipulate them. + * + * It is responsible for handling the parameters, buffers, and constants + * associated with a graph It provides mechanisms to load weights from + * serialized data, access and modify them, and performs necessary validation + * checks. + */ +class Weights { + public: + explicit Weights( + const Graph* graph, + const std::optional>& + stateDict = std::nullopt, + Placement placement = Placement()); + + // Arguments + // - pytorchStreamReader: the reader for the model archive + // - stateDictPath: a map from parameter/buffer/constant name to file path in + // the archive + // - stateDictPathPrefix: a prefix that will be prepended to paths in + // stateDictPathPrefix + // - constantPaths: a map from constant name to file path in the archive + // - constantPathPrefix: a prefix that will be prepended to paths in + // constantPathPrefix + // - placement: the device placement of the weights, default to follow the + // original device in the weight's metadata + explicit Weights( + const Graph* graph, + std::shared_ptr + pytorchStreamReader, + const std::unordered_map& stateDictPaths, + std::string_view stateDictPathPrefix, + const std::unordered_map& constantPaths, + std::string_view constantPathPrefix, + Placement placement = Placement(), + std::function skipSizeCheck = {}, + std::function skipDtypeCheck = {}); + + at::Tensor at(const std::string& name) const; + at::Tensor& at(const std::string& name); + bool contains(const std::string& name) const; + c10::IValue getCustomObj(const std::string& name) const; + c10::IValue getCustomObjByFileName(const std::string& name) const; + + std::unordered_map parameters() const; + + std::unordered_map buffers() const; + + std::unordered_map attributes() const; + + void loadStateDict( + const std::unordered_map& stateDict); + + /* + * Replace the value stored at the weight with name "name". + */ + void setValue(const std::string& name, const at::Tensor& newValue); + + /* + * Update the value stored at the weight with name "name". + * This is done in-place. + */ + void updateValue(const std::string& name, const at::Tensor& newValue); + + void updateValues( + const std::unordered_map& newValues); + + void validateValue(const std::string& name, const at::Tensor& newValue) const; + + void validateAllWeightsLoaded(); + + void updateFoldedConst(std::string_view name, c10::IValue tensor); + + const std::unordered_map& getFoldedConsts() const; + + C10_ALWAYS_INLINE const c10::FastMap& + getConstFoldedValues() const { + return constFoldedValues_; + } + + C10_ALWAYS_INLINE void setConstFoldedValue( + const std::string& n, + c10::IValue iv) { + constFoldedValues_.insert_or_assign(n, std::move(iv)); + } + + std::string toString() const; + + WeightVersion version() const { + return version_; + } + + private: + const Graph* graph_; + const std::unordered_map& weightsMeta_; + Placement placement_; + + // keys are parameter/buffer/constant names, not graph input names! + std::unordered_map allValues_; + + std::unordered_map customObjs_; + + // contains CustomClassHolder map from a file name to an arbitray + // key in customObjs_ that hold the loaded content of the file. + // This is used in AOTIDelegateExecutor. + std::unordered_map customObjsPaths_; + + // The liftcycle of folded consts should be tied with the weights from which + // it was derived. The ordering of the constant should be consistent with + // the output order of const graph. + std::vector foldedConsts_; + std::unordered_map foldedConstsMap_; + + c10::FastMap constFoldedValues_; + + // unique version number for this instance of weight + const WeightVersion version_; + + // every instance of Weight has a unique version number + static WeightVersion globalVersion_; + + std::function skipSizeCheck_ = {}; + std::function skipDtypeCheck_ = {}; + + // save the names of unused weights + std::unordered_set unusedWeights_; +}; + +} // namespace torch::nativert