mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Initial torchbind support in PT2 (#117697)
This PR adds the bare minimum functionality to get torchbind working in an e2e testable way on PT2. It implements: * ProxyTensor support * Simple torch.export support (proxytensor-only path, e.g. non-strict). * add some tests exercising the path. Because all this is not fully baked, I hide the functionality behind a feature flag (`enable_torchbind_tracing()`) so it does not affect regular users for now. Still on the agenda: * Dynamo support * Actual FakeMode support * Mutability support Hoping to get this first bit in as a standalone, as it will unblock some more extensive experimentation/testing going on internally. Differential Revision: [D51825372](https://our.internmc.facebook.com/intern/diff/D51825372/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/117697 Approved by: https://github.com/SherlockNoMad
This commit is contained in:
@ -46,6 +46,9 @@ struct Foo : torch::CustomClassHolder {
|
||||
int64_t add(int64_t z) {
|
||||
return (x + y) * z;
|
||||
}
|
||||
at::Tensor add_tensor(at::Tensor z) {
|
||||
return (x + y) * z;
|
||||
}
|
||||
void increment(int64_t z) {
|
||||
this->x += z;
|
||||
this->y += z;
|
||||
@ -317,8 +320,18 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
|
||||
.def("info", &Foo::info)
|
||||
.def("increment", &Foo::increment)
|
||||
.def("add", &Foo::add)
|
||||
.def("add_tensor", &Foo::add_tensor)
|
||||
.def("__eq__", &Foo::eq)
|
||||
.def("combine", &Foo::combine);
|
||||
.def("combine", &Foo::combine)
|
||||
.def_pickle(
|
||||
[](c10::intrusive_ptr<Foo> self) { // __getstate__
|
||||
return std::vector<int64_t>{self->x, self->y};
|
||||
},
|
||||
[](std::vector<int64_t> state) { // __setstate__
|
||||
return c10::make_intrusive<Foo>(state[0], state[1]);
|
||||
});
|
||||
m.def(
|
||||
"takes_foo(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
|
||||
|
||||
m.class_<FooGetterSetter>("_FooGetterSetter")
|
||||
.def(torch::init<int64_t, int64_t>())
|
||||
@ -436,4 +449,15 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
|
||||
});
|
||||
}
|
||||
|
||||
at::Tensor takes_foo(c10::intrusive_ptr<Foo> foo, at::Tensor x) {
|
||||
return foo->add_tensor(x);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(_TorchScriptTesting, CPU, m) {
|
||||
m.impl("takes_foo", takes_foo);
|
||||
}
|
||||
TORCH_LIBRARY_IMPL(_TorchScriptTesting, Meta, m) {
|
||||
m.impl("takes_foo", &takes_foo);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Reference in New Issue
Block a user