Initial torchbind support in PT2 (#117697)

This PR adds the bare minimum functionality to get torchbind working in an e2e testable way on PT2.

It implements:
* ProxyTensor support
* Simple torch.export support (proxytensor-only path, e.g. non-strict).
* add some tests exercising the path.

Because all this is not fully baked, I hide the functionality behind a feature flag (`enable_torchbind_tracing()`) so it does not affect regular users for now.

Still on the agenda:
* Dynamo support
* Actual FakeMode support
* Mutability support

Hoping to get this first bit in as a standalone, as it will unblock some more extensive experimentation/testing going on internally.

Differential Revision: [D51825372](https://our.internmc.facebook.com/intern/diff/D51825372/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117697
Approved by: https://github.com/SherlockNoMad
This commit is contained in:
suo
2024-01-18 17:34:43 -08:00
committed by PyTorch MergeBot
parent c51a4e64c0
commit 4057d005ff
14 changed files with 296 additions and 22 deletions

View File

@ -46,6 +46,9 @@ struct Foo : torch::CustomClassHolder {
int64_t add(int64_t z) {
return (x + y) * z;
}
at::Tensor add_tensor(at::Tensor z) {
return (x + y) * z;
}
void increment(int64_t z) {
this->x += z;
this->y += z;
@ -317,8 +320,18 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
.def("info", &Foo::info)
.def("increment", &Foo::increment)
.def("add", &Foo::add)
.def("add_tensor", &Foo::add_tensor)
.def("__eq__", &Foo::eq)
.def("combine", &Foo::combine);
.def("combine", &Foo::combine)
.def_pickle(
[](c10::intrusive_ptr<Foo> self) { // __getstate__
return std::vector<int64_t>{self->x, self->y};
},
[](std::vector<int64_t> state) { // __setstate__
return c10::make_intrusive<Foo>(state[0], state[1]);
});
m.def(
"takes_foo(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
m.class_<FooGetterSetter>("_FooGetterSetter")
.def(torch::init<int64_t, int64_t>())
@ -436,4 +449,15 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
});
}
at::Tensor takes_foo(c10::intrusive_ptr<Foo> foo, at::Tensor x) {
return foo->add_tensor(x);
}
TORCH_LIBRARY_IMPL(_TorchScriptTesting, CPU, m) {
m.impl("takes_foo", takes_foo);
}
TORCH_LIBRARY_IMPL(_TorchScriptTesting, Meta, m) {
m.impl("takes_foo", &takes_foo);
}
} // namespace