mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
[JIT] Add torch._C.ScriptDict (#52659)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52659 **Summary** This commit adds `torch._C.ScriptDict`, a dictionary type that has reference semantics across the Python/TorchScript boundary. That is, modifications made to instances of `torch._C.ScriptDict` in TorchScript are visible in Python even when it is not returned from the function. Instances can be constructed by passing an instance of a Python dictionary to `torch.jit.script`. In the case of an empty dictionary, its type is assumed to be `Dict[str, Tensor]` to be consistent with the handling of empty dictionaries in TorchScript source code. `torch._C.ScriptDict` is implemented using a modified version of pybind's `stl_bind.h`-style bindings attached to `ScriptDict`, `ScriptDictIterator` and `ScriptDictKeyIterator`, wrapper classes around `c10::impl::GenericDict` and `c10::impl::GenericDict::iterator`. These bindings allow instances of `torch._C.ScriptDict` to be used as if it were a regular `dict` Python. Reference semantics are achieved by simply retrieving the `IValue` contained in `ScriptDict` in `toIValue` (invoked when converting Python arguments to `IValues` before calling TorchScript code). **Test Plan** This commit adds `TestScriptDict` to `test_list_dict.py`, a set of tests that check that all of the common dictionary operations are supported and that instances have reference semantics across the Python/TorchScript boundary. Differential Revision: D27211605 D27211605 Test Plan: Imported from OSS Reviewed By: gmagogsfm Pulled By: SplitInfinity fbshipit-source-id: 446d4e5328375791aa73eb9e8b04dfe3465af960
This commit is contained in:
committed by
Facebook GitHub Bot
parent
95b1bc1009
commit
b14c3205fd
201
torch/csrc/jit/python/python_dict.cpp
Normal file
201
torch/csrc/jit/python/python_dict.cpp
Normal file
@ -0,0 +1,201 @@
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <pybind11/cast.h>
|
||||
#include <pybind11/detail/common.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
#include <torch/csrc/jit/python/python_dict.h>
|
||||
#include <torch/csrc/jit/runtime/jit_exception.h>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
IValue ScriptDictIterator::next() {
|
||||
if (iter_ == end_) {
|
||||
throw py::stop_iteration();
|
||||
}
|
||||
|
||||
// Since this is the iterator for .items(), the current key and value
|
||||
// should be returned as a tuple.
|
||||
IValue result = c10::ivalue::Tuple::create({iter_->key(), iter_->value()});
|
||||
|
||||
// Advance the iterator for next time.
|
||||
iter_++;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
IValue ScriptDictKeyIterator::next() {
|
||||
if (iter_ == end_) {
|
||||
throw py::stop_iteration();
|
||||
}
|
||||
|
||||
// Since this is the iterator for .keys() and __iter__(), return only the key.
|
||||
IValue result = iter_->key();
|
||||
|
||||
// Advance the iterator for next time.
|
||||
iter_++;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void initScriptDictBindings(PyObject* module) {
|
||||
auto m = py::handle(module).cast<py::module>();
|
||||
|
||||
py::class_<ScriptDictKeyIterator>(m, "ScriptDictKeyIterator")
|
||||
.def(
|
||||
"__next__",
|
||||
[](ScriptDictKeyIterator& iter) {
|
||||
auto result = iter.next();
|
||||
return toPyObject(result);
|
||||
})
|
||||
.def("__iter__", [](ScriptDictKeyIterator& iter) { return iter; });
|
||||
|
||||
py::class_<ScriptDictIterator>(m, "ScriptDictIterator")
|
||||
.def(
|
||||
"__next__",
|
||||
[](ScriptDictIterator& iter) {
|
||||
auto result = iter.next();
|
||||
return toPyObject(result);
|
||||
})
|
||||
.def("__iter__", [](ScriptDictIterator& iter) { return iter; });
|
||||
|
||||
py::class_<ScriptDict, std::shared_ptr<ScriptDict>>(m, "ScriptDict")
|
||||
.def(py::init([](py::dict dict) {
|
||||
TypePtr type = nullptr;
|
||||
|
||||
if (dict.size() > 0) {
|
||||
// If the source dictionary is nonempty, try to infer its type.
|
||||
auto inferred_type = tryToInferType(dict);
|
||||
|
||||
if (!inferred_type.success()) {
|
||||
std::stringstream ss;
|
||||
ss << "Unable to infer type of dictionary: "
|
||||
<< inferred_type.reason();
|
||||
throw JITException(ss.str());
|
||||
}
|
||||
|
||||
type = inferred_type.type();
|
||||
} else {
|
||||
// If is empty, assume the type is Dict[str, Tensor] as is done in
|
||||
// TorchScript code.
|
||||
type = DictType::create(StringType::get(), TensorType::getInferred());
|
||||
}
|
||||
|
||||
auto data = toIValue(std::move(dict), type);
|
||||
return std::make_shared<ScriptDict>(data);
|
||||
}))
|
||||
.def(
|
||||
"__repr__",
|
||||
[](const std::shared_ptr<ScriptDict>& self) {
|
||||
return toPyObject(self->repr());
|
||||
})
|
||||
.def(
|
||||
"__bool__",
|
||||
[](const std::shared_ptr<ScriptDict>& self) {
|
||||
return toPyObject(self->toBool());
|
||||
})
|
||||
.def(
|
||||
"__len__",
|
||||
[](const std::shared_ptr<ScriptDict>& self) {
|
||||
return toPyObject(self->len());
|
||||
})
|
||||
.def(
|
||||
"__contains__",
|
||||
[](const std::shared_ptr<ScriptDict>& self, py::object key) {
|
||||
try {
|
||||
return toPyObject(self->contains(
|
||||
toIValue(std::move(key), self->type()->getKeyType())));
|
||||
} catch (const py::cast_error& e) {
|
||||
throw py::key_error();
|
||||
}
|
||||
})
|
||||
.def(
|
||||
"__getitem__",
|
||||
[](const std::shared_ptr<ScriptDict>& self, py::object key) {
|
||||
IValue value;
|
||||
|
||||
// Convert key to IValue.
|
||||
try {
|
||||
value = toIValue(std::move(key), self->type()->getKeyType());
|
||||
} catch (const py::cast_error& e) {
|
||||
// It would be nice to throw py::type_error here but py::key_error
|
||||
// needs to be thrown for parity with eager mode.
|
||||
throw py::key_error();
|
||||
}
|
||||
|
||||
// Call getItem on self.
|
||||
try {
|
||||
value = self->getItem(value);
|
||||
} catch (const std::out_of_range& e) { // Key doesn't exist.
|
||||
throw py::key_error();
|
||||
}
|
||||
|
||||
return toPyObject(std::move(value));
|
||||
},
|
||||
py::return_value_policy::
|
||||
reference_internal) // Return value is a reference to an object
|
||||
// that resides in the ScriptDict
|
||||
.def(
|
||||
"__setitem__",
|
||||
[](const std::shared_ptr<ScriptDict>& self,
|
||||
py::object key,
|
||||
py::object value) {
|
||||
IValue key_ivalue, value_ivalue;
|
||||
|
||||
// Try to convert the key to an IValue.
|
||||
try {
|
||||
key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
|
||||
} catch (const py::cast_error& e) {
|
||||
throw py::type_error();
|
||||
}
|
||||
|
||||
// Try to convert the value to an IValue.
|
||||
try {
|
||||
value_ivalue =
|
||||
toIValue(std::move(value), self->type()->getValueType());
|
||||
} catch (const py::cast_error& e) {
|
||||
throw py::type_error();
|
||||
}
|
||||
|
||||
self->setItem(key_ivalue, value_ivalue);
|
||||
})
|
||||
.def(
|
||||
"__delitem__",
|
||||
[](const std::shared_ptr<ScriptDict>& self, py::object key) {
|
||||
IValue key_ivalue;
|
||||
|
||||
// Try to convert the key to an IValue.
|
||||
try {
|
||||
key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
|
||||
} catch (const py::cast_error& e) {
|
||||
throw py::type_error();
|
||||
}
|
||||
|
||||
// If removed = false, that means the key didn't exist in the
|
||||
// dictionary.
|
||||
bool removed = self->delItem(key_ivalue);
|
||||
|
||||
if (!removed) {
|
||||
throw py::key_error();
|
||||
}
|
||||
})
|
||||
.def(
|
||||
"__iter__",
|
||||
[](const std::shared_ptr<ScriptDict>& self) { return self->iter(); },
|
||||
py::keep_alive<0, 1>()) // ScriptDict needs to be alive at least as
|
||||
// long as the iterator
|
||||
.def(
|
||||
"items",
|
||||
[](const std::shared_ptr<ScriptDict>& self) { return self->items(); },
|
||||
py::keep_alive<0, 1>()) // ScriptDict needs to be alive at least as
|
||||
// long as the iterator
|
||||
.def(
|
||||
"keys",
|
||||
[](const std::shared_ptr<ScriptDict>& self) { return self->iter(); },
|
||||
py::keep_alive<0, 1>()); // ScriptDict needs to be alive at least as
|
||||
// long as the iterator
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
Reference in New Issue
Block a user