mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[nativert] Add OSS version of ModelRunner (#159268)
Summary: Implement a ModelRunner from scratch with the minimum features for OSS only Test Plan: test_export -r NativeRT Rollback Plan: Differential Revision: D78979812 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159268 Approved by: https://github.com/dolpm
This commit is contained in:
committed by
PyTorch MergeBot
parent
c0c24b61ff
commit
8460131087
@ -669,6 +669,16 @@ void Unpickler::readGlobal(
|
||||
// See [NOTE] skip_next_read_global
|
||||
this->skip_next_read_global--;
|
||||
if (this->skip_next_read_global == 1) {
|
||||
if (module_name == "torch" && class_name == "Tensor") {
|
||||
// This is a special case when we are unpickling a subclassed tensor
|
||||
// with type torch.nn.Buffer. We didn't frequently run into this because
|
||||
// torch.nn.Buffer is introduced later in PyTorch 2 and this type IValue
|
||||
// will not be used in C++.
|
||||
rebuildTensor(false);
|
||||
stack_.emplace_back(int64_t(globals_.size() - 1));
|
||||
this->skip_next_read_global = 0;
|
||||
return;
|
||||
}
|
||||
// Pass through to the correct handler
|
||||
} else if (this->skip_next_read_global == 0) {
|
||||
// Corresponds to the type of `Tensor` being unpickled
|
||||
@ -773,6 +783,10 @@ void Unpickler::readGlobal(
|
||||
// Unpickle a Tensor with Python attributes or
|
||||
// a Subclassed Tensor.
|
||||
rebuildTensorFromTypeV2();
|
||||
} else if (
|
||||
module_name == "torch._utils" && (class_name == "_rebuild_parameter")) {
|
||||
// Unpickle a Parameter
|
||||
rebuildParameter();
|
||||
} else if (
|
||||
module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") {
|
||||
rebuildSparseTensor();
|
||||
@ -1024,6 +1038,18 @@ void Unpickler::rebuildTensorFromTypeV2() {
|
||||
});
|
||||
}
|
||||
|
||||
void Unpickler::rebuildParameter() {
|
||||
globals_.emplace_back([this] {
|
||||
auto args = pop(stack_).toTuple();
|
||||
size_t tup_idx = 0;
|
||||
const auto args_elems = args->elements();
|
||||
auto result = args_elems.at(tup_idx++).toTensor();
|
||||
auto requires_grad = args_elems.at(tup_idx++).toBool();
|
||||
result.requires_grad_(requires_grad);
|
||||
stack_.emplace_back(std::move(result));
|
||||
});
|
||||
}
|
||||
|
||||
#ifdef USE_RPC
|
||||
void Unpickler::rebuildRRef() {
|
||||
globals_.emplace_back([this] {
|
||||
|
Reference in New Issue
Block a user