[JIT] Add torch._C.ScriptDict (#52659)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52659

**Summary**
This commit adds `torch._C.ScriptDict`, a dictionary type that has reference
semantics across the Python/TorchScript boundary. That is, modifications
made to instances of `torch._C.ScriptDict` in TorchScript are visible in
Python even when it is not returned from the function. Instances can be
constructed by passing an instance of a Python dictionary to
`torch.jit.script`. In the case of an empty dictionary, its type is
assumed to be `Dict[str, Tensor]` to be consistent with the handling of
empty dictionaries in TorchScript source code.

`torch._C.ScriptDict` is implemented using a modified version of pybind's `stl_bind.h`-style bindings attached to `ScriptDict`, `ScriptDictIterator` and `ScriptDictKeyIterator`, wrapper classes around `c10::impl::GenericDict` and `c10::impl::GenericDict::iterator`. These bindings allow instances of `torch._C.ScriptDict` to be used as if it were a regular `dict` Python. Reference semantics are achieved by simply retrieving the `IValue` contained in `ScriptDict` in `toIValue` (invoked when converting Python arguments to `IValues` before calling TorchScript code).

**Test Plan**
This commit adds `TestScriptDict` to `test_list_dict.py`, a set of tests
that check that all of the common dictionary operations are supported
and that instances have reference semantics across the
Python/TorchScript boundary.

Differential Revision:
D27211605
D27211605

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Pulled By: SplitInfinity

fbshipit-source-id: 446d4e5328375791aa73eb9e8b04dfe3465af960
This commit is contained in:
Meghan Lele
2021-05-27 10:24:11 -07:00
committed by Facebook GitHub Bot
parent 95b1bc1009
commit b14c3205fd
11 changed files with 612 additions and 10 deletions

View File

@ -0,0 +1,201 @@
#include <ATen/core/ivalue.h>
#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/python/python_dict.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <sstream>
#include <stdexcept>
namespace torch {
namespace jit {
IValue ScriptDictIterator::next() {
if (iter_ == end_) {
throw py::stop_iteration();
}
// Since this is the iterator for .items(), the current key and value
// should be returned as a tuple.
IValue result = c10::ivalue::Tuple::create({iter_->key(), iter_->value()});
// Advance the iterator for next time.
iter_++;
return result;
}
IValue ScriptDictKeyIterator::next() {
if (iter_ == end_) {
throw py::stop_iteration();
}
// Since this is the iterator for .keys() and __iter__(), return only the key.
IValue result = iter_->key();
// Advance the iterator for next time.
iter_++;
return result;
}
void initScriptDictBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
py::class_<ScriptDictKeyIterator>(m, "ScriptDictKeyIterator")
.def(
"__next__",
[](ScriptDictKeyIterator& iter) {
auto result = iter.next();
return toPyObject(result);
})
.def("__iter__", [](ScriptDictKeyIterator& iter) { return iter; });
py::class_<ScriptDictIterator>(m, "ScriptDictIterator")
.def(
"__next__",
[](ScriptDictIterator& iter) {
auto result = iter.next();
return toPyObject(result);
})
.def("__iter__", [](ScriptDictIterator& iter) { return iter; });
py::class_<ScriptDict, std::shared_ptr<ScriptDict>>(m, "ScriptDict")
.def(py::init([](py::dict dict) {
TypePtr type = nullptr;
if (dict.size() > 0) {
// If the source dictionary is nonempty, try to infer its type.
auto inferred_type = tryToInferType(dict);
if (!inferred_type.success()) {
std::stringstream ss;
ss << "Unable to infer type of dictionary: "
<< inferred_type.reason();
throw JITException(ss.str());
}
type = inferred_type.type();
} else {
// If is empty, assume the type is Dict[str, Tensor] as is done in
// TorchScript code.
type = DictType::create(StringType::get(), TensorType::getInferred());
}
auto data = toIValue(std::move(dict), type);
return std::make_shared<ScriptDict>(data);
}))
.def(
"__repr__",
[](const std::shared_ptr<ScriptDict>& self) {
return toPyObject(self->repr());
})
.def(
"__bool__",
[](const std::shared_ptr<ScriptDict>& self) {
return toPyObject(self->toBool());
})
.def(
"__len__",
[](const std::shared_ptr<ScriptDict>& self) {
return toPyObject(self->len());
})
.def(
"__contains__",
[](const std::shared_ptr<ScriptDict>& self, py::object key) {
try {
return toPyObject(self->contains(
toIValue(std::move(key), self->type()->getKeyType())));
} catch (const py::cast_error& e) {
throw py::key_error();
}
})
.def(
"__getitem__",
[](const std::shared_ptr<ScriptDict>& self, py::object key) {
IValue value;
// Convert key to IValue.
try {
value = toIValue(std::move(key), self->type()->getKeyType());
} catch (const py::cast_error& e) {
// It would be nice to throw py::type_error here but py::key_error
// needs to be thrown for parity with eager mode.
throw py::key_error();
}
// Call getItem on self.
try {
value = self->getItem(value);
} catch (const std::out_of_range& e) { // Key doesn't exist.
throw py::key_error();
}
return toPyObject(std::move(value));
},
py::return_value_policy::
reference_internal) // Return value is a reference to an object
// that resides in the ScriptDict
.def(
"__setitem__",
[](const std::shared_ptr<ScriptDict>& self,
py::object key,
py::object value) {
IValue key_ivalue, value_ivalue;
// Try to convert the key to an IValue.
try {
key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
} catch (const py::cast_error& e) {
throw py::type_error();
}
// Try to convert the value to an IValue.
try {
value_ivalue =
toIValue(std::move(value), self->type()->getValueType());
} catch (const py::cast_error& e) {
throw py::type_error();
}
self->setItem(key_ivalue, value_ivalue);
})
.def(
"__delitem__",
[](const std::shared_ptr<ScriptDict>& self, py::object key) {
IValue key_ivalue;
// Try to convert the key to an IValue.
try {
key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
} catch (const py::cast_error& e) {
throw py::type_error();
}
// If removed = false, that means the key didn't exist in the
// dictionary.
bool removed = self->delItem(key_ivalue);
if (!removed) {
throw py::key_error();
}
})
.def(
"__iter__",
[](const std::shared_ptr<ScriptDict>& self) { return self->iter(); },
py::keep_alive<0, 1>()) // ScriptDict needs to be alive at least as
// long as the iterator
.def(
"items",
[](const std::shared_ptr<ScriptDict>& self) { return self->items(); },
py::keep_alive<0, 1>()) // ScriptDict needs to be alive at least as
// long as the iterator
.def(
"keys",
[](const std::shared_ptr<ScriptDict>& self) { return self->iter(); },
py::keep_alive<0, 1>()); // ScriptDict needs to be alive at least as
// long as the iterator
}
} // namespace jit
} // namespace torch