mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Recently we made it possible to serialize ExportedPrograms with fake parameters/buffers/etc. The serialization regime was kind of wacky; basically we serialized a stub and reassembled the FakeTensor using metadata that we had stashed elsewhere in the Graph state. This was bad for a few reasons: - Storing the metadata separately from the actual serialized object caused situations where you could have one but not the other. An example case is if you had a FakeTensor contained inside a TorchBind object—there was no obvious place to store the metadata for this. This actually happens—TensorQueue in fbgemm does this. - It created an annoying cycle: we had to deserialize the Graph's tensor metadata in order to deserialize (potentially faked) constants, but we needed constants in order to deserialize the Graph. This fixes all that. The basic idea is to patch the reducer function for FakeTensor at serialization time, and serialize a copy of the FakeTensor metadata. We are already policing BC for the TensorMeta schema struct so it's not a net increase in the BC surface. As a bonus, I fixed a weird bug with torchbind tracing where we were accidentally reinterpreting a torch.ScriptObject as a torch.ScriptModule (which was the root cause of some weird behavior @bahuang was seeing last week). Differential Revision: [D53601251](https://our.internmc.facebook.com/intern/diff/D53601251/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/119531 Approved by: https://github.com/zhxchen17
85 lines
2.3 KiB
C++
85 lines
2.3 KiB
C++
#pragma once
|
|
|
|
#include <ATen/core/function.h>
|
|
#include <ATen/core/ivalue.h>
|
|
#include <ATen/core/stack.h>
|
|
#include <torch/csrc/api/include/torch/imethod.h>
|
|
#include <torch/csrc/jit/api/function_impl.h>
|
|
|
|
namespace torch::jit {
|
|
|
|
using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
|
|
|
|
// A method in a module, e.g. f in:
|
|
//
|
|
// class M(ScriptModule):
|
|
// @script_method
|
|
// def f(self, x):
|
|
// ...
|
|
// Note: because Method/Module are exposed to python these
|
|
// classes use python method naming conventions
|
|
struct TORCH_API Method : public torch::IMethod {
|
|
Method(ObjectPtr owner, Function* function);
|
|
|
|
// the module that contains this method.
|
|
Module owner() const;
|
|
// the raw objectptr that owns this method, for when the method is owned by a
|
|
// torchbind object.
|
|
ObjectPtr raw_owner() const;
|
|
void run(Stack& stack);
|
|
void run(Stack&& stack) {
|
|
run(stack);
|
|
}
|
|
|
|
c10::IValue operator()(
|
|
std::vector<c10::IValue> stack,
|
|
const Kwargs& kwargs = Kwargs()) const override;
|
|
|
|
// Run method async. Invocation on this function would invokes a JIT
|
|
// interpreter that executes ops inline, one by one, on caller's thread. A
|
|
// model can utilize async op, i.e. `fork`, to launch an asynchronous task
|
|
// which will be launched on provided `taskLauncher`.
|
|
c10::intrusive_ptr<c10::ivalue::Future> run_async(
|
|
std::vector<c10::IValue> stack,
|
|
const Kwargs& kwargs = Kwargs(),
|
|
TaskLauncher taskLauncher = at::launch);
|
|
|
|
std::shared_ptr<Graph> graph() const {
|
|
return toGraphFunction(*function_).graph();
|
|
}
|
|
|
|
const std::string& name() const override {
|
|
return function_->name();
|
|
}
|
|
|
|
size_t num_inputs() const {
|
|
return function_->num_inputs();
|
|
}
|
|
|
|
GraphExecutor& get_executor() {
|
|
return toGraphFunction(*function_).get_executor();
|
|
}
|
|
|
|
Function& function() const {
|
|
return *function_;
|
|
}
|
|
|
|
private:
|
|
void setArgumentNames(std::vector<std::string>&) const override;
|
|
|
|
// Methods are uniqued onwed by a single module. This raw pointer allows
|
|
// looking up the module.
|
|
ObjectPtr owner_;
|
|
|
|
// Underlying unbound function
|
|
Function* function_;
|
|
};
|
|
|
|
namespace script {
|
|
// We once had a `script::` namespace that was deleted. This is for backcompat
|
|
// of the public API; new code should not use this type alias.
|
|
using Method = ::torch::jit::Method;
|
|
} // namespace script
|
|
|
|
} // namespace torch::jit
|