Add dict comprehension (#47774)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47774

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D25615464

Pulled By: ansley

fbshipit-source-id: 10bba6f70e812fa580cbbbf097e93de7142484cc
This commit is contained in:
Ansley Ussery
2020-12-17 15:18:04 -08:00
committed by Facebook GitHub Bot
parent ea4ccc730e
commit d17dc37112
9 changed files with 168 additions and 12 deletions

View File

@ -58,7 +58,7 @@ namespace c10 {
_(prim, ReturnStmt) \
_(prim, BreakStmt) \
_(prim, ContinueStmt) \
_(prim, ListComprehensionScope) \
_(prim, ComprehensionScope) \
_(prim, Store) \
_(prim, AutogradZero) \
_(prim, AutogradAnyNonZero) \

View File

@ -331,6 +331,40 @@ class TestJit(JitTestCase):
def dot(points, query, dim):
return (points * query).sum(dim)
def test_dict_comprehension(self):
    # A plain dict comprehension without any type annotation must give the
    # same result in eager Python and after scripting.
    def fn():
        return {i : chr(i + 65) for i in range(4)}

    no_inputs = ()
    self.checkScript(fn, no_inputs)
def test_dict_comprehension_with_type_annotation(self):
    # A dict comprehension assigned to a variable annotated with a matching
    # Dict type must script and agree with eager Python.
    def fn():
        d: Dict[int, str] = {i : chr(i + 65) for i in range(4)}
        return d
    self.checkScript(fn, ())

    # A non-Dict annotation must be rejected while scripting. The failure is
    # expected as a RuntimeError (like the "undefined value i" check in
    # test_dict_comprehension_scope), so the message is matched on that
    # exception directly. The brackets are escaped because an unescaped
    # "[int, str]" is a regex character class and would not match the
    # literal text.
    with self.assertRaisesRegex(RuntimeError, r"Expected Dict "
                                r"type annotation for dict "
                                r"comprehension, found "
                                r"Tuple\[int, str\]"):
        @torch.jit.script
        def fn():
            d: Tuple[int, str] = {i : chr(i + 65) for i in range(4)}
            return d
def test_dict_comprehension_scope(self):
    # Names defined before the comprehension are readable inside it.
    def comprehension_can_access_outer_scope_variables():
        lst = ["foo", "bar", "baz"]
        return {l : len(l) for l in lst}

    self.checkScript(comprehension_can_access_outer_scope_variables, ())

    # The loop variable of the comprehension must not be visible after the
    # comprehension, so using `i` afterwards has to fail while scripting.
    expected_message = "undefined value i"
    with self.assertRaisesRegex(RuntimeError, expected_message):
        @torch.jit.script
        def outer_scope_cannot_access_comprehension_variables():
            d = {i : chr(i + 65) for i in range(4)}
            i = i + 1
def test_constants_pkl(self):
# This test asserts that the serialization archive includes a `constants.pkl`
# file. This file is used by `torch.load` to determine whether a zip file

View File

@ -158,7 +158,7 @@ struct ControlFlowLoadStores {
case prim::Store: {
environment_stack->setVar(n->s(attr::name), n->input()->type());
} break;
case prim::ListComprehensionScope: {
case prim::ComprehensionScope: {
addControlFlowLoadStores(n->blocks().at(0));
} break;
}
@ -205,7 +205,7 @@ struct EraseLoadStores {
n->output()->replaceAllUsesWith(var);
n->destroy();
} break;
case prim::ListComprehensionScope: {
case prim::ComprehensionScope: {
// writes within a local variable scope do not leak into
// the rest of the graph
auto body = n->blocks().at(0);

View File

@ -1283,7 +1283,7 @@ struct to_ir {
// comprehension introduces its own scope. no variable assigned
// leaks into the rest of the graph
Node* n =
graph->insertNode(create(prim::ListComprehensionScope, lc.range(), 0));
graph->insertNode(create(prim::ComprehensionScope, lc.range(), 0));
auto* comprehension_block = n->addBlock();
pushFrame(comprehension_block);
WithInsertPoint guard(comprehension_block);
@ -1302,6 +1302,52 @@ struct to_ir {
return list_value;
}
// Emits IR for a dict comprehension `{key: value for target in iter}`.
//
// A prim::DictConstruct node is inserted first; its output is then filled by
// aten::_set_item calls emitted as the body of the generated for loop, which
// lives inside a prim::ComprehensionScope block.
//
// `type_hint`, when present, must be a Dict type (an ErrorReport is thrown
// otherwise) and becomes the type of the result. Without a hint the result
// type is built from the types of the emitted key and value expressions.
// Returns the Value holding the constructed dict.
Value* emitDictComprehension(const DictComp& dc, const TypePtr& type_hint) {
  const auto loc = dc.range();
  const auto loop_targets = List<Expr>::create(loc, {dc.target()});
  const auto loop_iterables = List<Expr>::create(loc, {dc.iter()});

  Node* construct_node =
      graph->insertNode(graph->create(prim::DictConstruct, 1));
  Value* dict_value = construct_node->output();

  // Start from Dict[str, Tensor]; replaced below once a better type is known
  dict_value->setType(DictType::create(StringType::get(), TensorType::get()));

  // Only Dict annotations make sense for a dict comprehension
  if (type_hint && !type_hint->cast<DictType>()) {
    throw ErrorReport(loc)
        << "Expected Dict type annotation for dict comprehension, found "
        << type_hint->repr_str();
  }

  bool type_resolved = false;
  if (type_hint) {
    dict_value->setType(type_hint);
    type_resolved = true;
  }

  // A dict comprehension introduces its own scope. No variable assigned
  // may leak into the rest of the graph
  Node* scope_node =
      graph->insertNode(create(prim::ComprehensionScope, loc, 0));
  auto* scope_block = scope_node->addBlock();
  pushFrame(scope_block);
  WithInsertPoint guard(scope_block);

  auto emit_body = [&]() {
    // The key is emitted before the value
    auto key = emitExpr(dc.key());
    auto val = emitExpr(dc.value());

    // Without an annotation, the emitted key/value types define the dict type
    if (!type_resolved) {
      dict_value->setType(DictType::create(key->type(), val->type()));
      type_resolved = true;
    }

    NamedValue self(loc, "self", dict_value);
    NamedValue key_arg(loc, "", key);
    NamedValue val_arg(loc, "", val);
    emitBuiltinCall(
        loc, *graph, aten::_set_item, {self, key_arg, val_arg}, {});
  };
  emitFor(loop_targets, loop_iterables, loc, emit_body);

  popFrame();
  return dict_value;
}
// Insert subtyping refinements
void insertRefinements(const SourceRange& loc, const RefinementSet& ref) {
for (const Refinement& r : ref.activeRefinements()) {
@ -3397,6 +3443,10 @@ struct to_ir {
auto lc = ListComp(tree);
return emitListComprehension(lc, type_hint);
} break;
case TK_DICT_COMP: {
auto dc = DictComp(tree);
return emitDictComprehension(dc, type_hint);
} break;
default:
throw ErrorReport(tree) << "Cannot emit expr for: " << tree;
}

View File

@ -102,6 +102,7 @@ namespace jit {
_(TK_ASSERT, "assert", "assert") \
_(TK_DOTS, "dots", "...") \
_(TK_LIST_COMP, "list comprehension", "") \
_(TK_DICT_COMP, "dict comprehension", "") \
_(TK_BREAK, "break", "break") \
_(TK_CONTINUE, "continue", "continue") \
_(TK_DELETE, "del", "del") \

View File

@ -144,6 +144,16 @@ struct ParserImpl {
} break;
case '{': {
L.next();
// If we have a dict literal, `keys` and `values` will store the keys
// and values used in the object's construction. EDGE CASE: We have a
// dict comprehension, so we'll get the first element of the dict
// comprehension in `keys` and a list comprehension in `values`.
// For example, `{i : chr(i + 65) for i in range(4)}` would give us
// `i` in `keys` and `chr(i + 65) for i in range(4)` in `values`.
// The optimal way of handling this case is to simply splice the new
// dict comprehension together from the existing list comprehension.
// Splicing prevents breaking changes to our API and does not require
// the use of global variables.
std::vector<Expr> keys;
std::vector<Expr> values;
auto range = L.cur().range;
@ -155,10 +165,16 @@ struct ParserImpl {
} while (L.nextIf(','));
}
L.expect('}');
if (keys.size() == 1 && (*values.begin()).kind() == TK_LIST_COMP) {
ListComp lc(*values.begin());
prefix = DictComp::create(
range, *keys.begin(), lc.elt(), lc.target(), lc.iter());
} else {
prefix = DictLiteral::create(
range,
List<Expr>::create(range, keys),
List<Expr>::create(range, values));
}
} break;
case TK_STRINGLITERAL: {
prefix = parseConcatenatedStringLiterals();

View File

@ -309,6 +309,7 @@ struct Expr : public TreeView {
case '^':
case '|':
case TK_LIST_COMP:
case TK_DICT_COMP:
case TK_DOTS:
case TK_IN:
case TK_WITH_ITEM:
@ -579,6 +580,35 @@ struct ListComp : public Expr {
}
};
// TODO: supports only single comprehension for now
struct DictComp : public Expr {
explicit DictComp(const TreeRef& tree) : Expr(tree) {
tree->match(TK_DICT_COMP);
}
Expr key() const {
return Expr(subtree(0));
}
Expr value() const {
return Expr(subtree(1));
}
Expr target() const {
return Expr(subtree(2));
}
Expr iter() const {
return Expr(subtree(3));
}
// TODO: no ifs for now
static DictComp create(
const SourceRange& range,
const Expr& key,
const Expr& value,
const Expr& target,
const Expr& iter) {
return DictComp(
Compound::create(TK_DICT_COMP, range, {key, value, target, iter}));
}
};
struct Global : public Stmt {
explicit Global(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_GLOBAL);

View File

@ -352,6 +352,14 @@ void initTreeViewBindings(PyObject* module) {
const Expr& iter) {
return ListComp::create(range, elt, target, iter);
}));
py::class_<DictComp, Expr>(m, "DictComp")
.def(py::init([](const SourceRange& range,
const Expr& key,
const Expr& value,
const Expr& target,
const Expr& iter) {
return DictComp::create(range, key, value, target, iter);
}));
py::class_<ListLiteral, Expr>(m, "ListLiteral")
.def(py::init([](const SourceRange& range, std::vector<Expr> args) {
return ListLiteral::create(range, wrap_list(range, std::move(args)));

View File

@ -14,6 +14,7 @@ from torch._C._jit_tree_views import (
ListLiteral, TupleLiteral, DictLiteral, Const,
StringLiteral, ListComp, Attribute, BinOp, UnaryOp,
SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
DictComp,
)
from torch._utils_internal import get_source_lines_and_file
@ -810,18 +811,34 @@ class ExprBuilder(Builder):
# Builds a ListComp tree view from the comprehension node `stmt`, using the
# element expression plus the target and iterable of stmt.generators[0].
# Raises NotSupportedError for an unsupported number of generators or when
# the generator carries `if` clauses.
# NOTE(review): the generator-count check and the `ifs` raise each appear
# twice with different wording; this looks like the old and new revisions of
# the same lines shown together -- confirm which version is current.
@staticmethod
def build_ListComp(ctx, stmt):
r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
if (len(stmt.generators) > 1):
raise NotSupportedError(r, "multiple comprehension generators not supported yet")
if (len(stmt.generators) != 1):
raise NotSupportedError(r, "Only a single generator is currently supported")
if (len(stmt.generators[0].ifs) != 0):
raise NotSupportedError(r, "comprehension ifs not supported yet")
raise NotSupportedError(r, "Comprehension ifs are not supported yet")
elt_expr = build_expr(ctx, stmt.elt)
target_expr = build_expr(ctx, stmt.generators[0].target)
iter_expr = build_expr(ctx, stmt.generators[0].iter)
return ListComp(r, elt_expr, target_expr, iter_expr)
@staticmethod
def build_DictComp(ctx, stmt):
    # Builds a DictComp tree view from the comprehension node `stmt`.
    # Exactly one generator without `if` clauses is accepted; anything else
    # raises NotSupportedError at the range of the comprehension.
    r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)

    generators = stmt.generators
    if len(generators) != 1:
        raise NotSupportedError(r, "Only a single generator is currently supported")

    generator = generators[0]
    if len(generator.ifs) != 0:
        raise NotSupportedError(r, "Comprehension ifs are not supported yet")

    # Sub-expressions are built in the order key, value, target, iterable
    return DictComp(
        r,
        build_expr(ctx, stmt.key),
        build_expr(ctx, stmt.value),
        build_expr(ctx, generator.target),
        build_expr(ctx, generator.iter),
    )
@staticmethod
def build_Starred(ctx, expr):
r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)