Summary:
This diff fixes two issues that come up when testing a tgif-published pt2 model remote net:

1) Updates isSameDevice to handle the meta device, avoiding this error:
```
what(): Unsupported device typemeta and meta
Exception raised from isSameDevice at fbcode/caffe2/torch/nativert/executor/PlacementUtils.cpp:20
```

2) Updates the xl weight v2 loading logic in Weights.cpp to handle non-TBE xl weights. Today we enforce that the old and new weights are on the same device when replacing a weight with ModelRunnerAdapter.setAttr(). However, the way we replace non-TBE xl weights is to find any weights on the "meta" device and replace them with the correct weight, on its real device, from the xl_weights folder. The old and new weights will therefore always be on different devices, so the device check is invalid. We likely haven't run into this so far because non-TBE xl weights had not been thoroughly tested until now.

Test Plan:
Run the MRS model merge net, which uses non-TBE xl weights. Confirm that before change #1 we get the error:
```
Unsupported device typemeta and meta
```
Then, after change #1 but before change #2, we get:
```
what(): Mismatched device for merge.user_tower.linear.weight: meta vs cpu
Exception raised from validateValue at fbcode/caffe2/torch/nativert/executor/Weights.cpp:374
```
After both changes, the run is successful.

Command:
```
MODEL_ENTITY_ID=921242082 SNAPSHOT_ID=1269 module_name=merge SAMPLE_INPUT_DIR=/data/users/georgiaphillips/models/921242082/${SNAPSHOT_ID}/${module_name}_archive/package/data/sample_inputs buck2 run mode/dev-nosan -c fbcode.nvcc_arch=h100,a100 -c fbcode.enable_gpu_sections=true caffe2/torch/fb/model_transform/fx2trt/packaging:load_net_predictor -- --loadMode=Benchmark --inputNetFile=/data/users/$USER/models/${MODEL_ENTITY_ID}/${SNAPSHOT_ID}/${MODEL_ENTITY_ID}_${SNAPSHOT_ID}.predictor.${module_name} --moduleName=${module_name} --submodToDevice="merge|cuda0" --benchmarkEnableProfiling=false --disableStaticRuntime=true --doNotRandomizeSampleInputs=true --benchmarkDontRebatchSamples=true --pytorch_predictor_sigmoid_static_dispatch_enable=false --pytorch_predictor_sigmoid_graph_passes_enable=false --sampleInputFilePath=${SAMPLE_INPUT_DIR}/${module_name}.pt
```

Rollback Plan:

Differential Revision: D80713052

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162842
Approved by: https://github.com/henryoier
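The replacement flow that motivates change #2 can be pictured with a minimal sketch. This is illustrative only, not the production ModelRunnerAdapter code: `weights`, `realWeight`, and the weight name are assumed for the example, while `Weights::at` and the three-argument `Weights::setValue` are the real APIs defined in the file below.

```
// Sketch: replacing a non-TBE xl weight. At load time the weight has no
// serialized payload, so Weights materializes it as an at::kMeta placeholder.
// The real tensor later arrives from the xl_weights folder on an actual
// device (cpu/cuda), so the old and new devices always differ and the
// device check must be skipped (change #2).
const std::string name = "merge.user_tower.linear.weight"; // example name
if (weights.at(name).device().is_meta()) {
  weights.setValue(name, realWeight, /*skipDeviceCheck=*/true);
}
```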
464 lines
15 KiB
C++
#include <c10/util/Logging.h>
#include <utility>

#include <torch/csrc/export/pt2_archive_constants.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/import_read.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/nativert/executor/Weights.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/scalar_tensor.h>
#endif

#include <caffe2/serialize/inline_container.h>

namespace torch::nativert {

WeightVersion Weights::globalVersion_ = 0;

Weights::Weights(
    const Graph* graph,
    const std::optional<std::unordered_map<std::string, c10::IValue>>&
        stateDict,
    const std::optional<std::unordered_map<std::string, c10::IValue>>&
        constants)
    : graph_(graph),
      weightsMeta_(graph->weightsMeta()),
      version_(globalVersion_++) {
  if (stateDict.has_value()) {
    loadStateDict(stateDict.value());
  }
  if (constants.has_value()) {
    for (const auto& [name, value] : constants.value()) {
      if (value.isTensor()) {
        allValues_[name] = value.toTensor();
      } else if (value.isCustomClass()) {
        customObjs_[name] = value;
      } else {
        TORCH_CHECK(false, "Unknown constant type: ", value.tagKind());
      }
    }
  }
}

Weights::Weights(
    const Graph* graph,
    std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
    const std::unordered_map<std::string, std::string>& stateDictPaths,
    std::string_view stateDictPathPrefix,
    const std::unordered_map<std::string, std::string>& constantPaths,
    std::string_view constantPathPrefix,
    std::function<bool(const std::string&)> skipSizeCheck,
    std::function<bool(const std::string&)> skipDtypeCheck)
    : graph_(graph),
      weightsMeta_(graph->weightsMeta()),
      version_(globalVersion_++),
      skipSizeCheck_(std::move(skipSizeCheck)),
      skipDtypeCheck_(std::move(skipDtypeCheck)) {
  auto loadAndInsert =
      [&](const std::string& tensorName,
          std::string_view pathPrefix,
          const std::unordered_map<std::string, std::string>& tensorPaths,
          bool isUsed) {
        auto pathIt = tensorPaths.find(tensorName);
        TORCH_CHECK(
            pathIt != tensorPaths.end(),
            "Couldn't find ",
            tensorName,
            " in tensorPaths");

        const std::string tensorPath = std::string{pathPrefix} + pathIt->second;
        VLOG(1) << "Loading weight from: " << tensorPath;
        TORCH_CHECK(
            pytorchStreamReader->hasRecord(tensorPath),
            tensorPath,
            " not found");

        auto [tensorData, tensorDataSize] =
            pytorchStreamReader->getRecord(tensorPath);

        // TODO: We now have two copies of metadata for weights: one in the
        // model definition /models/<model_name>.json, another in
        // /extra/xl_weights/<model_name>_model_param_config.json.
        // Currently, we only use the metadata from the model definition.
        std::optional<TensorMeta> tensorMeta;
        if (weightsMeta_.find(tensorName) != weightsMeta_.end()) {
          tensorMeta = weightsMeta_.at(tensorName);
        } else {
          TORCH_CHECK(false, "Tensor meta not found for: ", tensorName);
        }

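        // Change #2 context (see summary above): non-TBE xl weights are
        // serialized without payload, so the branch below materializes them
        // as meta-device placeholders; the real tensors are swapped in later
        // via setValue() with skipDeviceCheck=true.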
        if (tensorDataSize == 0 && tensorMeta->numel() > 0) {
          VLOG(1) << "Tensor " << tensorName
                  << " has no data; creating it on the meta device";
          allValues_[tensorName] = at::empty_strided(
              tensorMeta->sizes(),
              tensorMeta->strides(),
              tensorMeta->asTensorOptions().device(at::kMeta));
          return;
        }

        if (!isUsed) {
          VLOG(1) << "Tensor " << tensorName << " is not used during inference";
          auto targetDevice = tensorMeta->device();
          allValues_[tensorName] =
              at::scalar_tensor(0, at::TensorOptions().device(targetDevice));
          return;
        }

        size_t bytesPerEntry =
            c10::scalarTypeToTypeMeta(tensorMeta->dtype()).itemsize();
        auto device = tensorData.device();
        auto storage = c10::Storage(
            c10::Storage::use_byte_size_t(),
            at::detail::computeStorageNbytes(
                tensorMeta->sizes(), tensorMeta->strides(), bytesPerEntry),
            std::move(tensorData), // ownership is transferred
            nullptr,
            false);
        const auto tensorOptions = at::TensorOptions(device)
                                       .dtype(tensorMeta->dtype())
                                       .requires_grad(false);
        auto tensor =
            at::empty({0}, tensorOptions)
                .set_(storage, 0, tensorMeta->sizes(), tensorMeta->strides());

        auto targetDevice = tensorMeta->device();
        VLOG(1) << "Loading weight " << tensorName << " on " << targetDevice;
        if (!isSameDevice(targetDevice, tensor.device())) {
          tensor = tensor.to(targetDevice);
        }

        allValues_[tensorName] = tensor;
      };

  auto loadAndInsertParamsBuffers = [&](const auto& tensorName, bool isUsed) {
    return loadAndInsert(
        std::string(tensorName), stateDictPathPrefix, stateDictPaths, isUsed);
  };

  size_t weightIndex = 0;
  bool isUsed = true;
  const auto& weightValues = graph->weightValues();

  for (const auto& tensorName : graph->signature().parameters()) {
    isUsed = !weightValues[weightIndex]->users().empty();
    if (!isUsed) {
      unusedWeights_.insert(std::string(tensorName));
    }
    loadAndInsertParamsBuffers(tensorName, isUsed);
    weightIndex++;
  }
  for (const auto& tensorName : graph->signature().buffers()) {
    isUsed = !weightValues[weightIndex]->users().empty();
    if (!isUsed) {
      unusedWeights_.insert(std::string(tensorName));
    }
    loadAndInsertParamsBuffers(tensorName, isUsed);
    weightIndex++;
  }

  // Load tensor constants and custom object constants. They are both stored
  // in the same directory in the archive, i.e. "extra/constants/": tensor
  // constants are prefixed with "tensor_", custom objects with "custom_obj_".
  auto loadConstants = [&](const auto& constants) {
    for (const auto& constantName : constants) {
      auto pathIt = constantPaths.find(std::string(constantName));
      TORCH_CHECK(
          pathIt != constantPaths.end(),
          "Couldn't find ",
          constantName,
          " in constantPaths");
      auto& fileName = pathIt->second;

      if (c10::starts_with(
              fileName,
              torch::_export::archive_spec::TENSOR_CONSTANT_FILENAME_PREFIX)) {
        // tensor constants
        isUsed = !weightValues[weightIndex]->users().empty();
        if (!isUsed) {
          unusedWeights_.insert(std::string(constantName));
        }
        loadAndInsert(
            std::string(constantName),
            constantPathPrefix,
            constantPaths,
            isUsed);
        weightIndex++;
      } else {
        TORCH_CHECK(false, "Unknown constant path: ", fileName);
      }
    }
  };
  loadConstants(graph->signature().nonPersistentBuffers());
  loadConstants(graph->signature().tensorConstants());

  // custom object constants
  for (const auto& customObjName : graph->signature().customObjs()) {
    auto pathIt = constantPaths.find(std::string(customObjName));
    TORCH_CHECK(
        pathIt != constantPaths.end(),
        "Couldn't find ",
        customObjName,
        " in constantPaths");
    auto& fileName = pathIt->second;

    if (!c10::starts_with(
            fileName,
            torch::_export::archive_spec::CUSTOM_OBJ_FILENAME_PREFIX)) {
      TORCH_CHECK(false, "Unknown constant path: ", fileName);
    }
    std::string customObjPath = std::string{constantPathPrefix} + fileName;
    LOG(INFO) << "Loading custom object from: " << customObjPath;

    TORCH_CHECK(
        pytorchStreamReader->hasRecord(customObjPath),
        customObjPath,
        " not found");

    const auto& [customObjData, customObjDataSize] =
        pytorchStreamReader->getRecord(customObjPath);

    const char* customObjDataPtr =
        reinterpret_cast<const char*>(customObjData.get());
    std::string customObjBytes(
        customObjDataPtr, customObjDataPtr + customObjDataSize);

    c10::IValue customObj = torch::jit::pickle_load_obj(customObjBytes);
    TORCH_CHECK(
        customObj.isCustomClass(), "Custom object is not a custom class");
    TORCH_CHECK(!customObj.isNone(), "Custom object is None");
    customObjs_[std::string(customObjName)] = std::move(customObj);
    customObjsPaths_[customObjPath] = std::string(customObjName);
  }
}

std::unordered_map<std::string, at::Tensor> Weights::parameters() const {
  std::unordered_map<std::string, at::Tensor> result;
  for (const auto& name : graph_->signature().parameters()) {
    result.emplace(name, allValues_.at(std::string(name)));
  }
  return result;
}

std::unordered_map<std::string, at::Tensor> Weights::buffers() const {
  std::unordered_map<std::string, at::Tensor> result;
  for (const auto& name : graph_->signature().buffers()) {
    result.emplace(name, allValues_.at(std::string(name)));
  }
  return result;
}

std::unordered_map<std::string, at::Tensor> Weights::attributes() const {
  return allValues_;
}

at::Tensor Weights::at(const std::string& name) const {
  auto it = allValues_.find(name);
  if (it != allValues_.end()) {
    return it->second;
  }

  TORCH_CHECK(false, name, " not found in Weights ", toString());
}

at::Tensor& Weights::at(const std::string& name) {
  auto it = allValues_.find(name);
  if (it != allValues_.end()) {
    return it->second;
  }

  TORCH_CHECK(false, name, " not found in Weights ", toString());
}

bool Weights::contains(const std::string& name) const {
  return allValues_.find(name) != allValues_.end();
}

c10::IValue Weights::getCustomObj(const std::string& name) const {
  auto it = customObjs_.find(name);
  if (it != customObjs_.end()) {
    return it->second;
  }

  TORCH_CHECK(false, "Custom object ", name, " not found in Weights");
}

c10::IValue Weights::getCustomObjByFileName(const std::string& name) const {
  auto it = customObjsPaths_.find(name);
  TORCH_CHECK(
      it != customObjsPaths_.end(),
      "Custom object with file name ",
      name,
      " not found in Weights");
  const std::string obj_name = it->second;
  return getCustomObj(obj_name);
}

void Weights::loadStateDict(
    const std::unordered_map<std::string, c10::IValue>& stateDict) {
  auto validateAndInsert = [&](const std::string& name) {
    auto stateDictIt = stateDict.find(name);
    TORCH_CHECK(
        stateDictIt != stateDict.end(),
        "Couldn't find ",
        name,
        " in stateDict");

    // Verify that the tensor matches the tensorMeta
    auto it = weightsMeta_.find(name);
    TORCH_CHECK(
        it != weightsMeta_.end(), "Couldn't find ", name, " in weightsMeta");

    auto targetDevice = it->second.device();
    auto tensor = stateDictIt->second.toTensor().to(targetDevice);

    TORCH_CHECK(tensor.sizes() == it->second.sizes());
    TORCH_CHECK(tensor.dtype() == it->second.dtype());

    allValues_.emplace(name, tensor);
  };

  for (const auto& name : graph_->signature().parameters()) {
    validateAndInsert(std::string(name));
  }
  for (const auto& name : graph_->signature().buffers()) {
    validateAndInsert(std::string(name));
  }
  // TensorConstants_ not filled !!
}

void Weights::validateValue(const std::string& name, const at::Tensor& newValue)
    const {
  validateValue(name, newValue, /*skipDeviceCheck=*/false);
}

void Weights::validateValue(
    const std::string& name,
    const at::Tensor& newValue,
    bool skipDeviceCheck) const {
  auto& weightMeta = weightsMeta_.at(name);

  TORCH_CHECK(
      weightMeta.sizes() == newValue.sizes() ||
          (skipSizeCheck_ && skipSizeCheck_(name)) ||
          unusedWeights_.find(name) != unusedWeights_.end(),
      "Mismatched sizes for ",
      name,
      ": ",
      weightMeta.sizes(),
      " vs ",
      newValue.sizes());
  TORCH_CHECK(
      weightMeta.dtype() == newValue.dtype() ||
          (skipDtypeCheck_ && skipDtypeCheck_(name)) ||
          unusedWeights_.find(name) != unusedWeights_.end(),
      "Mismatched dtype for ",
      name,
      ": ",
      weightMeta.dtype(),
      " vs ",
      newValue.dtype());

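  // Devices are compared via isSameDevice(), which after change #1 also
  // handles meta devices. The whole check is skipped (skipDeviceCheck=true)
  // when a meta-device placeholder is being replaced by its real xl weight,
  // since the old and new devices necessarily differ there (change #2).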
  if (!skipDeviceCheck) {
    auto targetDevice = weightMeta.device();
    if (targetDevice.is_cpu() && targetDevice.has_index()) {
      LOG(WARNING) << "Target device is cpu but has index: " << targetDevice;
    }
    TORCH_CHECK(
        isSameDevice(targetDevice, newValue.device()),
        "Mismatched device for ",
        name,
        ": ",
        targetDevice,
        " vs ",
        newValue.device());
  }
}

void Weights::setValue(const std::string& name, const at::Tensor& newValue) {
  setValue(name, newValue, /*skipDeviceCheck=*/false);
}

void Weights::setValue(
    const std::string& name,
    const at::Tensor& newValue,
    bool skipDeviceCheck) {
  if (allValues_.find(name) != allValues_.end()) {
    validateValue(name, newValue, skipDeviceCheck);
  } else {
    LOG(WARNING) << name << " is not found in the registered weights";
  }

  allValues_[name] = newValue;
}

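// Note: setValue() rebinds the stored tensor (and may change its device when
// the device check is skipped), whereas updateValue() below copies data in
// place via copy_() and so preserves the existing tensor's device and storage.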
void Weights::updateValue(const std::string& name, const at::Tensor& newValue) {
  auto it = allValues_.find(name);
  TORCH_CHECK(
      it != allValues_.end(), name, " not found in Weights ", toString());
  validateValue(name, newValue);

  it->second.copy_(newValue);
}

void Weights::updateValues(
    const std::unordered_map<std::string, at::Tensor>& newValues) {
  for (auto& [name, newValue] : newValues) {
    updateValue(name, newValue);
  }
}

std::string Weights::toString() const {
  std::stringstream ss;
  ss << '[';
  for (const auto& [name, _] : allValues_) {
    ss << name << ", ";
  }
  ss << ']';
  ss << '[';
  for (const auto& [name, _] : customObjs_) {
    ss << name << ", ";
  }
  ss << ']';
  return ss.str();
}

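// Weights still on the meta device after loading (e.g. xl weight placeholders
// whose real tensors were never supplied) are only warned about below, not
// treated as errors; missing or undefined weights do fail the check.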
void Weights::validateAllWeightsLoaded() {
  auto checkNames = [&](const auto& names) {
    for (const auto& name : names) {
      if (unusedWeights_.find(std::string(name)) != unusedWeights_.end()) {
        continue;
      }
      auto it = allValues_.find(std::string(name));
      TORCH_CHECK(it != allValues_.end(), "Missing weight: ", name);
      TORCH_CHECK(it->second.defined(), "Weight not defined: ", name);
      if (it->second.device().is_meta()) {
        LOG(WARNING) << "Weight is on meta device: " << name;
      }
    }
  };
  checkNames(graph_->signature().parameters());
  checkNames(graph_->signature().buffers());
  checkNames(graph_->signature().nonPersistentBuffers());
  checkNames(graph_->signature().tensorConstants());
}

void Weights::updateFoldedConst(std::string_view name, c10::IValue tensor) {
  foldedConstsMap_[std::string{name}] = std::move(tensor);
}

const std::unordered_map<std::string, c10::IValue>& Weights::getFoldedConsts()
    const {
  return foldedConstsMap_;
}

} // namespace torch::nativert