[export] Add effect token to export (#121424)

Following the creation of effect tokens (https://github.com/pytorch/pytorch/pull/120296), we want to now add support for these tokens in export because the calling/returning convention has changed. The inputs are now `(tokens, params, buffers, constants, user_inputs)` and the outputs are `(tokens, buffer_mutations, user_mutations, user_outputs)`. The graph looks something like:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %attr : [num_users=2] = placeholder[target=attr]
    %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
    %with_effects : [num_users=2] = call_function[target=torch._higher_order_ops.effects.with_effects](args = (%arg0_1, _TorchScriptTesting.takes_foo.default, %attr, %arg1_1), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 1), kwargs = {})
    %with_effects_1 : [num_users=2] = call_function[target=torch._higher_order_ops.effects.with_effects](args = (%getitem, _TorchScriptTesting.takes_foo.default, %attr, %getitem_1), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects_1, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects_1, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg1_1, %getitem_3), kwargs = {})
    return (getitem_2, add)
```

During unlifting, we will first remove the tokens and with_effect calls using the `remove_effect_tokens` pass. (cc @SherlockNoMad on the pass to remove tokens). This is so that this won't change the calling conventions when retracing. The graph after unlifting looks something like:
```
graph():
    %attr_1 : [num_users=2] = get_attr[target=attr]
    %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
    %takes_foo_default_1 : [num_users=1] = call_function[target=torch.ops._TorchScriptTesting.takes_foo.default](args = (%attr_1, %arg1_1), kwargs = {})
    %takes_foo_default : [num_users=1] = call_function[target=torch.ops._TorchScriptTesting.takes_foo.default](args = (%attr_1, %takes_foo_default_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg1_1, %takes_foo_default), kwargs = {})
    return (add,)
```

Serialization support will be added in a followup.
Note: tokens only affect custom ops that take in ScriptObjects, not ScriptObject methods yet.

Differential Revision: [D54639390](https://our.internmc.facebook.com/intern/diff/D54639390)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121424
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
angelayi
2024-03-08 10:29:12 -08:00
committed by PyTorch MergeBot
parent eb3919944d
commit e8836759d0
9 changed files with 278 additions and 27 deletions

View File

@ -342,6 +342,10 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
});
m.def(
"takes_foo(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
m.def(
"takes_foo_list_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor[]");
m.def(
"takes_foo_tuple_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> (Tensor, Tensor)");
m.class_<FooGetterSetter>("_FooGetterSetter")
.def(torch::init<int64_t, int64_t>())
@ -476,11 +480,37 @@ at::Tensor takes_foo(c10::intrusive_ptr<Foo> foo, at::Tensor x) {
return foo->add_tensor(x);
}
std::vector<at::Tensor> takes_foo_list_return(
c10::intrusive_ptr<Foo> foo,
at::Tensor x) {
std::vector<at::Tensor> result;
result.reserve(3);
auto a = foo->add_tensor(x);
auto b = foo->add_tensor(a);
auto c = foo->add_tensor(b);
result.push_back(a);
result.push_back(b);
result.push_back(c);
return result;
}
std::tuple<at::Tensor, at::Tensor> takes_foo_tuple_return(
c10::intrusive_ptr<Foo> foo,
at::Tensor x) {
auto a = foo->add_tensor(x);
auto b = foo->add_tensor(a);
return std::make_tuple(a, b);
}
TORCH_LIBRARY_IMPL(_TorchScriptTesting, CPU, m) {
m.impl("takes_foo", takes_foo);
m.impl("takes_foo_list_return", takes_foo_list_return);
m.impl("takes_foo_tuple_return", takes_foo_tuple_return);
}
TORCH_LIBRARY_IMPL(_TorchScriptTesting, Meta, m) {
m.impl("takes_foo", &takes_foo);
m.impl("takes_foo_list_return", takes_foo_list_return);
m.impl("takes_foo_tuple_return", takes_foo_tuple_return);
}
} // namespace