Add dict comprehension (#47774)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47774

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D25615464

Pulled By: ansley

fbshipit-source-id: 10bba6f70e812fa580cbbbf097e93de7142484cc
This commit is contained in:
Ansley Ussery
2020-12-17 15:18:04 -08:00
committed by Facebook GitHub Bot
parent ea4ccc730e
commit d17dc37112
9 changed files with 168 additions and 12 deletions

View File

@ -58,7 +58,7 @@ namespace c10 {
_(prim, ReturnStmt) \ _(prim, ReturnStmt) \
_(prim, BreakStmt) \ _(prim, BreakStmt) \
_(prim, ContinueStmt) \ _(prim, ContinueStmt) \
_(prim, ListComprehensionScope) \ _(prim, ComprehensionScope) \
_(prim, Store) \ _(prim, Store) \
_(prim, AutogradZero) \ _(prim, AutogradZero) \
_(prim, AutogradAnyNonZero) \ _(prim, AutogradAnyNonZero) \

View File

@ -331,6 +331,40 @@ class TestJit(JitTestCase):
def dot(points, query, dim): def dot(points, query, dim):
return (points * query).sum(dim) return (points * query).sum(dim)
def test_dict_comprehension(self):
    # Smallest case: an un-annotated dict comprehension over range(),
    # mapping 0..3 to the letters 'A'..'D', run through checkScript.
    def fn():
        return {idx : chr(idx + 65) for idx in range(4)}

    self.checkScript(fn, ())
def test_dict_comprehension_with_type_annotation(self):
    # A Dict annotation on the assignment target is accepted.
    def fn():
        d: Dict[int, str] = {i : chr(i + 65) for i in range(4)}
        return d
    self.checkScript(fn, ())

    # A non-Dict annotation on a dict comprehension must be rejected when the
    # function is scripted. The failure is raised as a RuntimeError, so assert
    # on that type directly and verify the message. Nesting an
    # assertRaisesRegex(AssertionError, ...) inside an
    # assertRaisesRegex(RuntimeError, "") would let any RuntimeError pass
    # without ever checking its text. The brackets are escaped so they are
    # matched literally instead of being read as a regex character class.
    with self.assertRaisesRegex(
            RuntimeError,
            r"Expected Dict type annotation for dict comprehension, "
            r"found Tuple\[int, str\]"):
        @torch.jit.script
        def fn():
            d: Tuple[int, str] = {i : chr(i + 65) for i in range(4)}
            return d
def test_dict_comprehension_scope(self):
    # Names bound outside the comprehension are readable from inside it.
    def comprehension_can_access_outer_scope_variables():
        words = ["foo", "bar", "baz"]
        return {word : len(word) for word in words}

    self.checkScript(comprehension_can_access_outer_scope_variables, ())

    # The loop variable `i` is local to the comprehension, so reading it
    # afterwards must be reported as an undefined value when scripting.
    with self.assertRaisesRegex(RuntimeError, "undefined value i"):
        @torch.jit.script
        def outer_scope_cannot_access_comprehension_variables():
            d = {i : chr(i + 65) for i in range(4)}
            i = i + 1
def test_constants_pkl(self): def test_constants_pkl(self):
# This test asserts that the serialization archive includes a `constants.pkl` # This test asserts that the serialization archive includes a `constants.pkl`
# file. This file is used by `torch.load` to determine whether a zip file # file. This file is used by `torch.load` to determine whether a zip file

View File

@ -158,7 +158,7 @@ struct ControlFlowLoadStores {
case prim::Store: { case prim::Store: {
environment_stack->setVar(n->s(attr::name), n->input()->type()); environment_stack->setVar(n->s(attr::name), n->input()->type());
} break; } break;
case prim::ListComprehensionScope: { case prim::ComprehensionScope: {
addControlFlowLoadStores(n->blocks().at(0)); addControlFlowLoadStores(n->blocks().at(0));
} break; } break;
} }
@ -205,7 +205,7 @@ struct EraseLoadStores {
n->output()->replaceAllUsesWith(var); n->output()->replaceAllUsesWith(var);
n->destroy(); n->destroy();
} break; } break;
case prim::ListComprehensionScope: { case prim::ComprehensionScope: {
// writes within a local variable scope do not leak into // writes within a local variable scope do not leak into
// the rest of the graph // the rest of the graph
auto body = n->blocks().at(0); auto body = n->blocks().at(0);

View File

@ -1283,7 +1283,7 @@ struct to_ir {
// comprehension introduces its own scope. no variable assigned // comprehension introduces its own scope. no variable assigned
// leaks into the rest of the graph // leaks into the rest of the graph
Node* n = Node* n =
graph->insertNode(create(prim::ListComprehensionScope, lc.range(), 0)); graph->insertNode(create(prim::ComprehensionScope, lc.range(), 0));
auto* comprehension_block = n->addBlock(); auto* comprehension_block = n->addBlock();
pushFrame(comprehension_block); pushFrame(comprehension_block);
WithInsertPoint guard(comprehension_block); WithInsertPoint guard(comprehension_block);
@ -1302,6 +1302,52 @@ struct to_ir {
return list_value; return list_value;
} }
// Emits the IR for a dict comprehension `{key : value for target in iter}`
// and returns the Value holding the resulting dict.
//
// An empty prim::DictConstruct is inserted first. A prim::ComprehensionScope
// node with a single block is then inserted, and the loop over
// `dc.target()` / `dc.iter()` is emitted inside that block through emitFor.
// The loop-body callback evaluates the key and value expressions and stores
// them into the dict with aten::_set_item.
//
// Dict type selection:
//  - if `type_hint` is non-null it must be a DictType (otherwise an
//    ErrorReport is thrown) and it becomes the dict's type;
//  - otherwise the dict starts as Dict[str, Tensor] and is re-typed from the
//    key/value expression types the first time the loop-body callback runs.
Value* emitDictComprehension(const DictComp& dc, const TypePtr& type_hint) {
  const auto loc = dc.range();
  const auto targets_list = List<Expr>::create(loc, {dc.target()});
  const auto itrs = List<Expr>::create(loc, {dc.iter()});

  // The dict is created before (and outside of) the comprehension scope.
  Node* dict_node = graph->create(prim::DictConstruct, 1);
  Value* dict_value = graph->insertNode(dict_node)->output();

  // Placeholder type, Dict[str, Tensor], used until a better one is known.
  dict_value->setType(DictType::create(StringType::get(), TensorType::get()));

  bool type_set = static_cast<bool>(type_hint);
  if (type_set) {
    if (!type_hint->cast<DictType>()) {
      throw ErrorReport(loc)
          << "Expected Dict type annotation for dict comprehension"
             ", found "
          << type_hint->repr_str();
    }
    dict_value->setType(type_hint);
  }

  // A dict comprehension introduces its own scope. No variable assigned
  // may leak into the rest of the graph.
  auto* comprehension_block =
      graph->insertNode(create(prim::ComprehensionScope, loc, 0))->addBlock();
  pushFrame(comprehension_block);
  WithInsertPoint guard(comprehension_block);

  auto emit_body = [&]() {
    auto key_value = emitExpr(dc.key());
    auto mapped_value = emitExpr(dc.value());

    // Without an annotation, take the dict type from the emitted key/value.
    if (!type_set) {
      dict_value->setType(
          DictType::create(key_value->type(), mapped_value->type()));
      type_set = true;
    }

    NamedValue self(loc, "self", dict_value);
    NamedValue input_k(loc, "", key_value);
    NamedValue input_v(loc, "", mapped_value);
    emitBuiltinCall(
        loc, *graph, aten::_set_item, {self, input_k, input_v}, {});
  };
  emitFor(targets_list, itrs, loc, emit_body);

  popFrame();
  return dict_value;
}
// Insert subtyping refinements // Insert subtyping refinements
void insertRefinements(const SourceRange& loc, const RefinementSet& ref) { void insertRefinements(const SourceRange& loc, const RefinementSet& ref) {
for (const Refinement& r : ref.activeRefinements()) { for (const Refinement& r : ref.activeRefinements()) {
@ -3397,6 +3443,10 @@ struct to_ir {
auto lc = ListComp(tree); auto lc = ListComp(tree);
return emitListComprehension(lc, type_hint); return emitListComprehension(lc, type_hint);
} break; } break;
// Dict comprehension: wrap the tree in a DictComp view and hand it, together
// with the type hint, to emitDictComprehension.
case TK_DICT_COMP: {
auto dc = DictComp(tree);
return emitDictComprehension(dc, type_hint);
} break;
default: default:
throw ErrorReport(tree) << "Cannot emit expr for: " << tree; throw ErrorReport(tree) << "Cannot emit expr for: " << tree;
} }

View File

@ -102,6 +102,7 @@ namespace jit {
_(TK_ASSERT, "assert", "assert") \ _(TK_ASSERT, "assert", "assert") \
_(TK_DOTS, "dots", "...") \ _(TK_DOTS, "dots", "...") \
_(TK_LIST_COMP, "list comprehension", "") \ _(TK_LIST_COMP, "list comprehension", "") \
_(TK_DICT_COMP, "dict comprehension", "") \
_(TK_BREAK, "break", "break") \ _(TK_BREAK, "break", "break") \
_(TK_CONTINUE, "continue", "continue") \ _(TK_CONTINUE, "continue", "continue") \
_(TK_DELETE, "del", "del") \ _(TK_DELETE, "del", "del") \

View File

@ -144,6 +144,16 @@ struct ParserImpl {
} break; } break;
case '{': { case '{': {
L.next(); L.next();
// If we have a dict literal, `keys` and `values` will store the keys
// and values used in the object's construction. EDGE CASE: if we
// instead have a dict comprehension, `keys` holds only the key
// expression and `values` holds a single list comprehension.
// For example, `{i : chr(i + 65) for i in range(4)}` would give us
// `i` in `keys` and `chr(i + 65) for i in range(4)` in `values`.
// The optimal way of handling this case is to simply splice the new
// dict comprehension together from the parts (element, target and
// iterable) of that existing list comprehension.
// Splicing prevents breaking changes to our API and does not require
// the use of global variables.
std::vector<Expr> keys; std::vector<Expr> keys;
std::vector<Expr> values; std::vector<Expr> values;
auto range = L.cur().range; auto range = L.cur().range;
@ -155,10 +165,16 @@ struct ParserImpl {
} while (L.nextIf(',')); } while (L.nextIf(','));
} }
L.expect('}'); L.expect('}');
prefix = DictLiteral::create( if (keys.size() == 1 && (*values.begin()).kind() == TK_LIST_COMP) {
range, ListComp lc(*values.begin());
List<Expr>::create(range, keys), prefix = DictComp::create(
List<Expr>::create(range, values)); range, *keys.begin(), lc.elt(), lc.target(), lc.iter());
} else {
prefix = DictLiteral::create(
range,
List<Expr>::create(range, keys),
List<Expr>::create(range, values));
}
} break; } break;
case TK_STRINGLITERAL: { case TK_STRINGLITERAL: {
prefix = parseConcatenatedStringLiterals(); prefix = parseConcatenatedStringLiterals();

View File

@ -309,6 +309,7 @@ struct Expr : public TreeView {
case '^': case '^':
case '|': case '|':
case TK_LIST_COMP: case TK_LIST_COMP:
case TK_DICT_COMP:
case TK_DOTS: case TK_DOTS:
case TK_IN: case TK_IN:
case TK_WITH_ITEM: case TK_WITH_ITEM:
@ -579,6 +580,35 @@ struct ListComp : public Expr {
} }
}; };
// TODO: supports only single comprehension for now
struct DictComp : public Expr {
explicit DictComp(const TreeRef& tree) : Expr(tree) {
tree->match(TK_DICT_COMP);
}
Expr key() const {
return Expr(subtree(0));
}
Expr value() const {
return Expr(subtree(1));
}
Expr target() const {
return Expr(subtree(2));
}
Expr iter() const {
return Expr(subtree(3));
}
// TODO: no ifs for now
static DictComp create(
const SourceRange& range,
const Expr& key,
const Expr& value,
const Expr& target,
const Expr& iter) {
return DictComp(
Compound::create(TK_DICT_COMP, range, {key, value, target, iter}));
}
};
struct Global : public Stmt { struct Global : public Stmt {
explicit Global(const TreeRef& tree) : Stmt(tree) { explicit Global(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_GLOBAL); tree_->match(TK_GLOBAL);

View File

@ -352,6 +352,14 @@ void initTreeViewBindings(PyObject* module) {
const Expr& iter) { const Expr& iter) {
return ListComp::create(range, elt, target, iter); return ListComp::create(range, elt, target, iter);
})); }));
// Python binding for DictComp. The constructor takes
// (range, key, value, target, iter) and forwards them to DictComp::create.
py::class_<DictComp, Expr>(m, "DictComp")
.def(py::init([](const SourceRange& range,
const Expr& key,
const Expr& value,
const Expr& target,
const Expr& iter) {
return DictComp::create(range, key, value, target, iter);
}));
py::class_<ListLiteral, Expr>(m, "ListLiteral") py::class_<ListLiteral, Expr>(m, "ListLiteral")
.def(py::init([](const SourceRange& range, std::vector<Expr> args) { .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
return ListLiteral::create(range, wrap_list(range, std::move(args))); return ListLiteral::create(range, wrap_list(range, std::move(args)));

View File

@ -14,6 +14,7 @@ from torch._C._jit_tree_views import (
ListLiteral, TupleLiteral, DictLiteral, Const, ListLiteral, TupleLiteral, DictLiteral, Const,
StringLiteral, ListComp, Attribute, BinOp, UnaryOp, StringLiteral, ListComp, Attribute, BinOp, UnaryOp,
SliceExpr, Subscript, TernaryIf, With, WithItem, Property, SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
DictComp,
) )
from torch._utils_internal import get_source_lines_and_file from torch._utils_internal import get_source_lines_and_file
@ -810,18 +811,34 @@ class ExprBuilder(Builder):
@staticmethod @staticmethod
def build_ListComp(ctx, stmt): def build_ListComp(ctx, stmt):
r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset) r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
if (len(stmt.generators) > 1): if (len(stmt.generators) != 1):
raise NotSupportedError(r, "multiple comprehension generators not supported yet") raise NotSupportedError(r, "Only a single generator is currently supported")
if (len(stmt.generators[0].ifs) != 0): if (len(stmt.generators[0].ifs) != 0):
raise NotSupportedError(r, "comprehension ifs not supported yet") raise NotSupportedError(r, "Comprehension ifs are not supported yet")
elt_expr = build_expr(ctx, stmt.elt) elt_expr = build_expr(ctx, stmt.elt)
target_expr = build_expr(ctx, stmt.generators[0].target) target_expr = build_expr(ctx, stmt.generators[0].target)
iter_expr = build_expr(ctx, stmt.generators[0].iter) iter_expr = build_expr(ctx, stmt.generators[0].iter)
return ListComp(r, elt_expr, target_expr, iter_expr) return ListComp(r, elt_expr, target_expr, iter_expr)
@staticmethod
def build_DictComp(ctx, stmt):
    # Build a DictComp tree view from a dict-comprehension statement node.
    # Exactly one generator, with no `if` filters, is supported; anything
    # else raises NotSupportedError pointing at the statement's position.
    r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)

    generators = stmt.generators
    if len(generators) != 1:
        raise NotSupportedError(r, "Only a single generator is currently supported")

    generator = generators[0]
    if len(generator.ifs) != 0:
        raise NotSupportedError(r, "Comprehension ifs are not supported yet")

    # Sub-expressions are built in the order key, value, target, iter.
    return DictComp(
        r,
        build_expr(ctx, stmt.key),
        build_expr(ctx, stmt.value),
        build_expr(ctx, generator.target),
        build_expr(ctx, generator.iter),
    )
@staticmethod @staticmethod
def build_Starred(ctx, expr): def build_Starred(ctx, expr):
r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1) r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)