Files
pytorch/torch/nativert/executor/DelegateExecutor.h
Kevin Fu 04349f9ee5 [PT2]: Skip AOTI Weight Loading during Init (#158416)
Summary: AOTI already embeds the weights in its .so file, so there is no need to load them again during the initial load. This also lets lowered modules carry different sets of weights on different hardware.
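
As a rough sketch of the intent (hypothetical class name; this is not the code changed in this PR), an AOTI-backed delegate can make the initial weight load a no-op, because the compiled .so already embeds its own copy of the weights:

```
// Hypothetical sketch only -- illustrative, not the actual AOTI delegate.
class AOTIDelegateExecutor : public DelegateExecutor {
 public:
  void initWeights(std::shared_ptr<Weights> weights) override {
    // The weights are already baked into the AOTI-compiled .so, so skip
    // loading them here. Explicit weight updates later still go through
    // processWeights()/commitWeights().
    (void)weights;
  }
  // processWeights(), commitWeights(), and run() omitted for brevity.
};
```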

Test Plan:
```
MODEL_TYPE=ads_mtml_offsite_cvr_oba_optout_dedicated_model
MODEL_ENTITY_ID=895279202
SNAPSHOT_ID=0
MODULE=merge

buck2 run mode/dev-nosan -c fbcode.nvcc_arch=a100,h100 -c fbcode.enable_gpu_sections=true fbcode//caffe2/torch/fb/model_transform/fx2trt/packaging:load_net_predictor -- --loadMode=Benchmark --inputNetFile=/data/users/$USER/models/${MODEL_ENTITY_ID}/${SNAPSHOT_ID}/${MODEL_ENTITY_ID}_${SNAPSHOT_ID}.predictor.disagg.gpu.${MODULE} --moduleName ${MODULE} --predictor-hardware-type 1 --submodToDevice ""  --benchmarkDontRebatchSamples=true --benchmarkNumIterations 1000
```

Rollback Plan:

Differential Revision: D78383881

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158416
Approved by: https://github.com/henryoier, https://github.com/SherlockNoMad
2025-07-17 06:47:47 +00:00

#pragma once

#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

#include <ATen/core/Tensor.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
#include <torch/nativert/executor/Weights.h>

namespace torch::nativert {

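// Extracts the archive entries under `targetPath` from `packageReader` into a
// temporary folder on disk and returns that folder's path. (Descriptive
// comment inferred from the signature.)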
std::string extractToTemporaryFolder(
    caffe2::serialize::PyTorchStreamReader& packageReader,
    const std::string& targetPath);
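
// Factory signature for creating the ProxyExecutor a delegate uses to run
// fallback (non-compiled) ops. The exact parameter semantics are
// backend-specific. (Descriptive comment inferred from the types.)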
using MakeProxyExecutorFn =
    std::function<std::unique_ptr<torch::aot_inductor::ProxyExecutor>(
        const std::string&,
        bool,
        std::optional<std::unordered_map<std::string, c10::IValue>>)>;

// This is the extension point for delegation backends.
class DelegateExecutor {
 public:
  virtual ~DelegateExecutor() {}

  // The runtime calls processWeights() to pass the weights to the delegate
  // backend. Typically, a backend performs some form of validation and
  // processing, such as constant folding. The processed weights stay in an
  // inactive state until commitWeights() is called.
  //
  // Weight tensors are co-owned by the runtime and the delegate backend.
  // On the regular inference run() path, neither the runtime nor the delegate
  // backend may modify the weight tensors.
  // To support in-place weight updates, weight tensors are exposed through
  // ModelRunner::getWeights() to an external caller, which can then modify
  // them in place. Such mutations immediately affect the weight tensors held
  // by the delegate backend.
  // When a weight tensor is no longer used by the delegate backend, the
  // backend must release it by decrementing its refcount. The runtime likewise
  // releases its refcount when a weight tensor is no longer active. The
  // underlying storage of a weight tensor is freed once its refcount reaches
  // zero.
  virtual void processWeights(std::shared_ptr<Weights> weights) = 0;

  // This call activates the processed weights.
  virtual void commitWeights() = 0;
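
  // Called at the initial load with the starting weights. A backend whose
  // compiled artifact already embeds its weights (such as AOTI) may skip
  // loading them here.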
  virtual void initWeights(std::shared_ptr<Weights> weights) = 0;

  virtual std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) = 0;
};

} // namespace torch::nativert
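
To make the weight-lifecycle contract above concrete, here is a minimal sketch of a backend implementing the interface. All names below (`EchoDelegateExecutor` and its fields) are invented for illustration; a real backend such as the AOTI delegate does considerably more work in each step.

```
#include <torch/nativert/executor/DelegateExecutor.h>

namespace torch::nativert {

// Toy backend: stages weights in processWeights(), activates them in
// commitWeights(), and simply echoes its inputs in run(). Illustrative only.
class EchoDelegateExecutor : public DelegateExecutor {
 public:
  void initWeights(std::shared_ptr<Weights> weights) override {
    // Initial load: adopt the weights directly as the active set.
    active_ = std::move(weights);
  }

  void processWeights(std::shared_ptr<Weights> weights) override {
    // Validation / constant folding would happen here; the result stays
    // inactive until the runtime commits it.
    staged_ = std::move(weights);
  }

  void commitWeights() override {
    // Activate the staged weights. Dropping the old shared_ptr releases
    // this backend's refcount on the previous weight set.
    active_ = std::move(staged_);
  }

  std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) override {
    // A real backend would execute its compiled artifact here.
    return inputs;
  }

 private:
  std::shared_ptr<Weights> active_; // committed (active) weights, co-owned
  std::shared_ptr<Weights> staged_; // processed but not yet committed
};

} // namespace torch::nativert
```

The staged/active pair mirrors the two-phase processWeights()/commitWeights() protocol: the pointer swap in commitWeights() is what makes an in-flight weight update atomic from the backend's point of view.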