mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add dict comprehension (#47774)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47774 Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D25615464 Pulled By: ansley fbshipit-source-id: 10bba6f70e812fa580cbbbf097e93de7142484cc
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ea4ccc730e
commit
d17dc37112
@ -58,7 +58,7 @@ namespace c10 {
|
||||
_(prim, ReturnStmt) \
|
||||
_(prim, BreakStmt) \
|
||||
_(prim, ContinueStmt) \
|
||||
_(prim, ListComprehensionScope) \
|
||||
_(prim, ComprehensionScope) \
|
||||
_(prim, Store) \
|
||||
_(prim, AutogradZero) \
|
||||
_(prim, AutogradAnyNonZero) \
|
||||
|
@ -331,6 +331,40 @@ class TestJit(JitTestCase):
|
||||
def dot(points, query, dim):
|
||||
return (points * query).sum(dim)
|
||||
|
||||
def test_dict_comprehension(self):
|
||||
def fn():
|
||||
return {i : chr(i + 65) for i in range(4)}
|
||||
self.checkScript(fn, ())
|
||||
|
||||
def test_dict_comprehension_with_type_annotation(self):
|
||||
def fn():
|
||||
d: Dict[int, str] = {i : chr(i + 65) for i in range(4)}
|
||||
return d
|
||||
self.checkScript(fn, ())
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, ""):
|
||||
with self.assertRaisesRegex(AssertionError, "Expected Dict "
|
||||
"type annotation for dict "
|
||||
"comprehension, found "
|
||||
"Tuple[int, str]"):
|
||||
@torch.jit.script
|
||||
def fn():
|
||||
d: Tuple[int, str] = {i : chr(i + 65) for i in range(4)}
|
||||
return d
|
||||
|
||||
def test_dict_comprehension_scope(self):
|
||||
def comprehension_can_access_outer_scope_variables():
|
||||
lst = ["foo", "bar", "baz"]
|
||||
return {l : len(l) for l in lst}
|
||||
|
||||
self.checkScript(comprehension_can_access_outer_scope_variables, ())
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "undefined value i"):
|
||||
@torch.jit.script
|
||||
def outer_scope_cannot_access_comprehension_variables():
|
||||
d = {i : chr(i + 65) for i in range(4)}
|
||||
i = i + 1
|
||||
|
||||
def test_constants_pkl(self):
|
||||
# This test asserts that the serialization archive includes a `constants.pkl`
|
||||
# file. This file is used by `torch.load` to determine whether a zip file
|
||||
|
@ -158,7 +158,7 @@ struct ControlFlowLoadStores {
|
||||
case prim::Store: {
|
||||
environment_stack->setVar(n->s(attr::name), n->input()->type());
|
||||
} break;
|
||||
case prim::ListComprehensionScope: {
|
||||
case prim::ComprehensionScope: {
|
||||
addControlFlowLoadStores(n->blocks().at(0));
|
||||
} break;
|
||||
}
|
||||
@ -205,7 +205,7 @@ struct EraseLoadStores {
|
||||
n->output()->replaceAllUsesWith(var);
|
||||
n->destroy();
|
||||
} break;
|
||||
case prim::ListComprehensionScope: {
|
||||
case prim::ComprehensionScope: {
|
||||
// writes within a local variable scope do not leak into
|
||||
// the rest of the graph
|
||||
auto body = n->blocks().at(0);
|
||||
|
@ -1283,7 +1283,7 @@ struct to_ir {
|
||||
// comprehension introduces it's own scope. no variable assigned
|
||||
// leaks into the rest of the graph
|
||||
Node* n =
|
||||
graph->insertNode(create(prim::ListComprehensionScope, lc.range(), 0));
|
||||
graph->insertNode(create(prim::ComprehensionScope, lc.range(), 0));
|
||||
auto* comprehension_block = n->addBlock();
|
||||
pushFrame(comprehension_block);
|
||||
WithInsertPoint guard(comprehension_block);
|
||||
@ -1302,6 +1302,52 @@ struct to_ir {
|
||||
return list_value;
|
||||
}
|
||||
|
||||
Value* emitDictComprehension(const DictComp& dc, const TypePtr& type_hint) {
|
||||
const auto loc = dc.range();
|
||||
const auto targets_list = List<Expr>::create(dc.range(), {dc.target()});
|
||||
const auto itrs = List<Expr>::create(dc.range(), {dc.iter()});
|
||||
|
||||
Value* dict_value =
|
||||
graph->insertNode(graph->create(prim::DictConstruct, 1))->output();
|
||||
// Set the default type to be Dict[Str, Tensor]
|
||||
dict_value->setType(DictType::create(StringType::get(), TensorType::get()));
|
||||
bool type_set = false;
|
||||
if (type_hint) {
|
||||
if (!type_hint->cast<DictType>()) {
|
||||
throw ErrorReport(loc)
|
||||
<< "Expected Dict type annotation for dict comprehension"
|
||||
", found "
|
||||
<< type_hint->repr_str();
|
||||
}
|
||||
dict_value->setType(type_hint);
|
||||
type_set = true;
|
||||
}
|
||||
|
||||
// A dict comprehension introduces its own scope. No variable assigned
|
||||
// may leak into the rest of the graph
|
||||
Node* n =
|
||||
graph->insertNode(create(prim::ComprehensionScope, dc.range(), 0));
|
||||
auto* comprehension_block = n->addBlock();
|
||||
pushFrame(comprehension_block);
|
||||
WithInsertPoint guard(comprehension_block);
|
||||
auto emit_body = [&]() {
|
||||
auto k = emitExpr(dc.key());
|
||||
auto v = emitExpr(dc.value());
|
||||
if (!type_set) {
|
||||
dict_value->setType(DictType::create(k->type(), v->type()));
|
||||
type_set = true;
|
||||
}
|
||||
NamedValue self = NamedValue(loc, "self", dict_value);
|
||||
NamedValue input_k = NamedValue(loc, "", k);
|
||||
NamedValue input_v = NamedValue(loc, "", v);
|
||||
emitBuiltinCall(
|
||||
loc, *graph, aten::_set_item, {self, input_k, input_v}, {});
|
||||
};
|
||||
emitFor(targets_list, itrs, loc, emit_body);
|
||||
popFrame();
|
||||
return dict_value;
|
||||
}
|
||||
|
||||
// Insert subtyping refinements
|
||||
void insertRefinements(const SourceRange& loc, const RefinementSet& ref) {
|
||||
for (const Refinement& r : ref.activeRefinements()) {
|
||||
@ -3397,6 +3443,10 @@ struct to_ir {
|
||||
auto lc = ListComp(tree);
|
||||
return emitListComprehension(lc, type_hint);
|
||||
} break;
|
||||
case TK_DICT_COMP: {
|
||||
auto dc = DictComp(tree);
|
||||
return emitDictComprehension(dc, type_hint);
|
||||
} break;
|
||||
default:
|
||||
throw ErrorReport(tree) << "Cannot emit expr for: " << tree;
|
||||
}
|
||||
|
@ -102,6 +102,7 @@ namespace jit {
|
||||
_(TK_ASSERT, "assert", "assert") \
|
||||
_(TK_DOTS, "dots", "...") \
|
||||
_(TK_LIST_COMP, "list comprehension", "") \
|
||||
_(TK_DICT_COMP, "dict comprehension", "") \
|
||||
_(TK_BREAK, "break", "break") \
|
||||
_(TK_CONTINUE, "continue", "continue") \
|
||||
_(TK_DELETE, "del", "del") \
|
||||
|
@ -144,6 +144,16 @@ struct ParserImpl {
|
||||
} break;
|
||||
case '{': {
|
||||
L.next();
|
||||
// If we have a dict literal, `keys` and `values` will store the keys
|
||||
// and values used in the object's construction. EDGE CASE: We have a
|
||||
// dict comprehension, so we'll get the first element of the dict
|
||||
// comprehension in `keys` and a list comprehension in `values`.
|
||||
// For example, `{i : chr(i + 65) for i in range(4)}` would give us
|
||||
// `i` in `keys` and `chr(i + 65) for i in range(4)` in `values`.
|
||||
// The optimal way of handling this case is to simply splice the new
|
||||
// dict comprehension together from the existing list comprehension.
|
||||
// Splicing prevents breaking changes to our API and does not require
|
||||
// the use of global variables.
|
||||
std::vector<Expr> keys;
|
||||
std::vector<Expr> values;
|
||||
auto range = L.cur().range;
|
||||
@ -155,10 +165,16 @@ struct ParserImpl {
|
||||
} while (L.nextIf(','));
|
||||
}
|
||||
L.expect('}');
|
||||
if (keys.size() == 1 && (*values.begin()).kind() == TK_LIST_COMP) {
|
||||
ListComp lc(*values.begin());
|
||||
prefix = DictComp::create(
|
||||
range, *keys.begin(), lc.elt(), lc.target(), lc.iter());
|
||||
} else {
|
||||
prefix = DictLiteral::create(
|
||||
range,
|
||||
List<Expr>::create(range, keys),
|
||||
List<Expr>::create(range, values));
|
||||
}
|
||||
} break;
|
||||
case TK_STRINGLITERAL: {
|
||||
prefix = parseConcatenatedStringLiterals();
|
||||
|
@ -309,6 +309,7 @@ struct Expr : public TreeView {
|
||||
case '^':
|
||||
case '|':
|
||||
case TK_LIST_COMP:
|
||||
case TK_DICT_COMP:
|
||||
case TK_DOTS:
|
||||
case TK_IN:
|
||||
case TK_WITH_ITEM:
|
||||
@ -579,6 +580,35 @@ struct ListComp : public Expr {
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: supports only single comprehension for now
|
||||
struct DictComp : public Expr {
|
||||
explicit DictComp(const TreeRef& tree) : Expr(tree) {
|
||||
tree->match(TK_DICT_COMP);
|
||||
}
|
||||
Expr key() const {
|
||||
return Expr(subtree(0));
|
||||
}
|
||||
Expr value() const {
|
||||
return Expr(subtree(1));
|
||||
}
|
||||
Expr target() const {
|
||||
return Expr(subtree(2));
|
||||
}
|
||||
Expr iter() const {
|
||||
return Expr(subtree(3));
|
||||
}
|
||||
// TODO: no ifs for now
|
||||
static DictComp create(
|
||||
const SourceRange& range,
|
||||
const Expr& key,
|
||||
const Expr& value,
|
||||
const Expr& target,
|
||||
const Expr& iter) {
|
||||
return DictComp(
|
||||
Compound::create(TK_DICT_COMP, range, {key, value, target, iter}));
|
||||
}
|
||||
};
|
||||
|
||||
struct Global : public Stmt {
|
||||
explicit Global(const TreeRef& tree) : Stmt(tree) {
|
||||
tree_->match(TK_GLOBAL);
|
||||
|
@ -352,6 +352,14 @@ void initTreeViewBindings(PyObject* module) {
|
||||
const Expr& iter) {
|
||||
return ListComp::create(range, elt, target, iter);
|
||||
}));
|
||||
py::class_<DictComp, Expr>(m, "DictComp")
|
||||
.def(py::init([](const SourceRange& range,
|
||||
const Expr& key,
|
||||
const Expr& value,
|
||||
const Expr& target,
|
||||
const Expr& iter) {
|
||||
return DictComp::create(range, key, value, target, iter);
|
||||
}));
|
||||
py::class_<ListLiteral, Expr>(m, "ListLiteral")
|
||||
.def(py::init([](const SourceRange& range, std::vector<Expr> args) {
|
||||
return ListLiteral::create(range, wrap_list(range, std::move(args)));
|
||||
|
@ -14,6 +14,7 @@ from torch._C._jit_tree_views import (
|
||||
ListLiteral, TupleLiteral, DictLiteral, Const,
|
||||
StringLiteral, ListComp, Attribute, BinOp, UnaryOp,
|
||||
SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
|
||||
DictComp,
|
||||
)
|
||||
from torch._utils_internal import get_source_lines_and_file
|
||||
|
||||
@ -810,18 +811,34 @@ class ExprBuilder(Builder):
|
||||
@staticmethod
|
||||
def build_ListComp(ctx, stmt):
|
||||
r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
|
||||
if (len(stmt.generators) > 1):
|
||||
raise NotSupportedError(r, "multiple comprehension generators not supported yet")
|
||||
if (len(stmt.generators) != 1):
|
||||
raise NotSupportedError(r, "Only a single generator is currently supported")
|
||||
|
||||
if (len(stmt.generators[0].ifs) != 0):
|
||||
raise NotSupportedError(r, "comprehension ifs not supported yet")
|
||||
raise NotSupportedError(r, "Comprehension ifs are not supported yet")
|
||||
|
||||
elt_expr = build_expr(ctx, stmt.elt)
|
||||
target_expr = build_expr(ctx, stmt.generators[0].target)
|
||||
|
||||
iter_expr = build_expr(ctx, stmt.generators[0].iter)
|
||||
|
||||
return ListComp(r, elt_expr, target_expr, iter_expr)
|
||||
|
||||
@staticmethod
|
||||
def build_DictComp(ctx, stmt):
|
||||
r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
|
||||
if (len(stmt.generators) != 1):
|
||||
raise NotSupportedError(r, "Only a single generator is currently supported")
|
||||
|
||||
if (len(stmt.generators[0].ifs) != 0):
|
||||
raise NotSupportedError(r, "Comprehension ifs are not supported yet")
|
||||
|
||||
key_expr = build_expr(ctx, stmt.key)
|
||||
value_expr = build_expr(ctx, stmt.value)
|
||||
target_expr = build_expr(ctx, stmt.generators[0].target)
|
||||
iter_expr = build_expr(ctx, stmt.generators[0].iter)
|
||||
|
||||
return DictComp(r, key_expr, value_expr, target_expr, iter_expr)
|
||||
|
||||
@staticmethod
|
||||
def build_Starred(ctx, expr):
|
||||
r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
|
||||
|
Reference in New Issue
Block a user