mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Add torch._library.register_fake_class to fakify torchBind class (#122622)
This PR only adds abstract class registration logic without touching existing tests so they still trace with real script object. The added tests are only for registration APIs and test error messages. Our design is that the abstract implementation should be in Python. This is much better in terms of usability. But this also has implications for custom op that takes script object as input, which is detailed later in this stack. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122622 Approved by: https://github.com/zou3519 ghstack dependencies: #122619, #122620, #122621
This commit is contained in:
@ -194,6 +194,10 @@ struct TensorQueue : torch::CustomClassHolder {
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
std::vector<at::Tensor> get_raw_queue() {
|
||||
std::vector<at::Tensor> raw_queue(queue_.begin(), queue_.end());
|
||||
return raw_queue;
|
||||
}
|
||||
|
||||
private:
|
||||
std::deque<at::Tensor> queue_;
|
||||
@ -563,6 +567,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
|
||||
.def("top", &TensorQueue::top)
|
||||
.def("size", &TensorQueue::size)
|
||||
.def("clone_queue", &TensorQueue::clone_queue)
|
||||
.def("get_raw_queue", &TensorQueue::get_raw_queue)
|
||||
.def_pickle(
|
||||
// __getstate__
|
||||
[](const c10::intrusive_ptr<TensorQueue>& self)
|
||||
|
Reference in New Issue
Block a user