Support record_stream in dispatch mode (#99529)

Summary:
Issuing a `t.record_stream(s)` call while a `TorchDispatchMode` is active was throwing because PyTorch was unable to convert a c10::Stream back to a Python object. It's now fixed.

Fixes https://github.com/pytorch/pytorch/issues/94403

Test Plan: Added a unit test

Differential Revision: D45117566

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99529
Approved by: https://github.com/albanD
This commit is contained in:
Luca Wehrstedt
2023-04-21 07:17:19 +00:00
committed by PyTorch MergeBot
parent 0ac0d9d224
commit 24bf15fe8d
7 changed files with 95 additions and 0 deletions

View File

@ -634,6 +634,8 @@ py::object toPyObject(IValue ivalue) {
}
} else if (ivalue.isDevice()) {
return py::cast<py::object>(THPDevice_New(std::move(ivalue).toDevice()));
} else if (ivalue.isStream()) {
return py::cast(std::move(ivalue).toStream());
} else if (ivalue.isGenericDict()) {
auto dict = std::move(ivalue).toGenericDict();
py::dict py_dict;