Mirror of https://github.com/pytorch/pytorch.git
Add dict comprehension (#47774)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47774

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D25615464

Pulled By: ansley

fbshipit-source-id: 10bba6f70e812fa580cbbbf097e93de7142484cc
Committed by: Facebook GitHub Bot
Parent: ea4ccc730e
Commit: d17dc37112
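For context, a minimal usage sketch of what this change enables in TorchScript. The function name and values are illustrative; `chr` and a dict comprehension over `range` are exactly what the new tests below exercise:

```python
from typing import Dict

import torch

@torch.jit.script
def letters(n: int) -> Dict[int, str]:
    # Dict comprehensions now compile under torch.jit.script; previously the
    # TorchScript frontend only supported list comprehensions.
    return {i: chr(i + 65) for i in range(n)}

print(letters(4))  # {0: 'A', 1: 'B', 2: 'C', 3: 'D'}
```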
@@ -58,7 +58,7 @@ namespace c10 {
   _(prim, ReturnStmt) \
   _(prim, BreakStmt) \
   _(prim, ContinueStmt) \
-  _(prim, ListComprehensionScope) \
+  _(prim, ComprehensionScope) \
   _(prim, Store) \
   _(prim, AutogradZero) \
   _(prim, AutogradAnyNonZero) \
@@ -331,6 +331,40 @@ class TestJit(JitTestCase):
         def dot(points, query, dim):
             return (points * query).sum(dim)
 
+    def test_dict_comprehension(self):
+        def fn():
+            return {i : chr(i + 65) for i in range(4)}
+        self.checkScript(fn, ())
+
+    def test_dict_comprehension_with_type_annotation(self):
+        def fn():
+            d: Dict[int, str] = {i : chr(i + 65) for i in range(4)}
+            return d
+        self.checkScript(fn, ())
+
+        with self.assertRaisesRegex(RuntimeError, ""):
+            with self.assertRaisesRegex(AssertionError, "Expected Dict "
+                                                        "type annotation for dict "
+                                                        "comprehension, found "
+                                                        "Tuple[int, str]"):
+                @torch.jit.script
+                def fn():
+                    d: Tuple[int, str] = {i : chr(i + 65) for i in range(4)}
+                    return d
+
+    def test_dict_comprehension_scope(self):
+        def comprehension_can_access_outer_scope_variables():
+            lst = ["foo", "bar", "baz"]
+            return {l : len(l) for l in lst}
+
+        self.checkScript(comprehension_can_access_outer_scope_variables, ())
+
+        with self.assertRaisesRegex(RuntimeError, "undefined value i"):
+            @torch.jit.script
+            def outer_scope_cannot_access_comprehension_variables():
+                d = {i : chr(i + 65) for i in range(4)}
+                i = i + 1
+
     def test_constants_pkl(self):
         # This test asserts that the serialization archive includes a `constants.pkl`
         # file. This file is used by `torch.load` to determine whether a zip file
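A standalone variant of the annotated test above, runnable outside the test suite (names are illustrative). As the test shows, a non-Dict annotation such as `Tuple[int, str]` is rejected with the new "Expected Dict type annotation for dict comprehension" error:

```python
from typing import Dict

import torch

def make_table() -> Dict[int, str]:
    # Local type annotation on the comprehension result, as in
    # test_dict_comprehension_with_type_annotation above.
    d: Dict[int, str] = {i: chr(i + 65) for i in range(4)}
    return d

scripted = torch.jit.script(make_table)
assert scripted() == make_table() == {0: "A", 1: "B", 2: "C", 3: "D"}
```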
@@ -158,7 +158,7 @@ struct ControlFlowLoadStores {
       case prim::Store: {
         environment_stack->setVar(n->s(attr::name), n->input()->type());
       } break;
-      case prim::ListComprehensionScope: {
+      case prim::ComprehensionScope: {
         addControlFlowLoadStores(n->blocks().at(0));
       } break;
     }
@@ -205,7 +205,7 @@ struct EraseLoadStores {
         n->output()->replaceAllUsesWith(var);
         n->destroy();
       } break;
-      case prim::ListComprehensionScope: {
+      case prim::ComprehensionScope: {
         // writes within a local variable scope do not leak into
         // the rest of the graph
         auto body = n->blocks().at(0);
@@ -1283,7 +1283,7 @@ struct to_ir {
     // comprehension introduces it's own scope. no variable assigned
     // leaks into the rest of the graph
     Node* n =
-        graph->insertNode(create(prim::ListComprehensionScope, lc.range(), 0));
+        graph->insertNode(create(prim::ComprehensionScope, lc.range(), 0));
     auto* comprehension_block = n->addBlock();
     pushFrame(comprehension_block);
     WithInsertPoint guard(comprehension_block);
@@ -1302,6 +1302,52 @@ struct to_ir {
     return list_value;
   }
 
+  Value* emitDictComprehension(const DictComp& dc, const TypePtr& type_hint) {
+    const auto loc = dc.range();
+    const auto targets_list = List<Expr>::create(dc.range(), {dc.target()});
+    const auto itrs = List<Expr>::create(dc.range(), {dc.iter()});
+
+    Value* dict_value =
+        graph->insertNode(graph->create(prim::DictConstruct, 1))->output();
+    // Set the default type to be Dict[Str, Tensor]
+    dict_value->setType(DictType::create(StringType::get(), TensorType::get()));
+    bool type_set = false;
+    if (type_hint) {
+      if (!type_hint->cast<DictType>()) {
+        throw ErrorReport(loc)
+            << "Expected Dict type annotation for dict comprehension"
+               ", found "
+            << type_hint->repr_str();
+      }
+      dict_value->setType(type_hint);
+      type_set = true;
+    }
+
+    // A dict comprehension introduces its own scope. No variable assigned
+    // may leak into the rest of the graph
+    Node* n =
+        graph->insertNode(create(prim::ComprehensionScope, dc.range(), 0));
+    auto* comprehension_block = n->addBlock();
+    pushFrame(comprehension_block);
+    WithInsertPoint guard(comprehension_block);
+    auto emit_body = [&]() {
+      auto k = emitExpr(dc.key());
+      auto v = emitExpr(dc.value());
+      if (!type_set) {
+        dict_value->setType(DictType::create(k->type(), v->type()));
+        type_set = true;
+      }
+      NamedValue self = NamedValue(loc, "self", dict_value);
+      NamedValue input_k = NamedValue(loc, "", k);
+      NamedValue input_v = NamedValue(loc, "", v);
+      emitBuiltinCall(
+          loc, *graph, aten::_set_item, {self, input_k, input_v}, {});
+    };
+    emitFor(targets_list, itrs, loc, emit_body);
+    popFrame();
+    return dict_value;
+  }
+
   // Insert subtyping refinements
   void insertRefinements(const SourceRange& loc, const RefinementSet& ref) {
     for (const Refinement& r : ref.activeRefinements()) {
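Conceptually, `emitDictComprehension` lowers a dict comprehension into an empty `prim::DictConstruct` plus a for-loop that inserts each key/value pair via `aten::_set_item`, with the key and value expressions emitted inside a `prim::ComprehensionScope` block. A rough eager-mode sketch of that lowering for the example used in the tests (not the actual IR; names are illustrative):

```python
def desugared(n: int):
    d = {}                 # prim::DictConstruct; typed Dict[str, Tensor] by default,
                           # or from the annotation / the first key-value pair
    for i in range(n):     # emitFor over the comprehension's iterator
        k = i              # emitExpr(dc.key())   -- evaluated inside the
        v = chr(i + 65)    # emitExpr(dc.value()) -- ComprehensionScope block
        d[k] = v           # aten::_set_item(self=d, k, v)
    return d

assert desugared(4) == {i: chr(i + 65) for i in range(4)}
```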
@@ -3397,6 +3443,10 @@ struct to_ir {
       auto lc = ListComp(tree);
       return emitListComprehension(lc, type_hint);
     } break;
+    case TK_DICT_COMP: {
+      auto dc = DictComp(tree);
+      return emitDictComprehension(dc, type_hint);
+    } break;
     default:
       throw ErrorReport(tree) << "Cannot emit expr for: " << tree;
   }
@@ -102,6 +102,7 @@ namespace jit {
   _(TK_ASSERT, "assert", "assert") \
   _(TK_DOTS, "dots", "...") \
   _(TK_LIST_COMP, "list comprehension", "") \
+  _(TK_DICT_COMP, "dict comprehension", "") \
   _(TK_BREAK, "break", "break") \
   _(TK_CONTINUE, "continue", "continue") \
   _(TK_DELETE, "del", "del") \
@@ -144,6 +144,16 @@ struct ParserImpl {
     } break;
     case '{': {
       L.next();
+      // If we have a dict literal, `keys` and `values` will store the keys
+      // and values used in the object's construction. EDGE CASE: We have a
+      // dict comprehension, so we'll get the first element of the dict
+      // comprehension in `keys` and a list comprehension in `values`.
+      // For example, `{i : chr(i + 65) for i in range(4)}` would give us
+      // `i` in `keys` and `chr(i + 65) for i in range(4)` in `values`.
+      // The optimal way of handling this case is to simply splice the new
+      // dict comprehension together from the existing list comprehension.
+      // Splicing prevents breaking changes to our API and does not require
+      // the use of global variables.
       std::vector<Expr> keys;
       std::vector<Expr> values;
       auto range = L.cur().range;
@@ -155,10 +165,16 @@ struct ParserImpl {
         } while (L.nextIf(','));
       }
       L.expect('}');
-      prefix = DictLiteral::create(
-          range,
-          List<Expr>::create(range, keys),
-          List<Expr>::create(range, values));
+      if (keys.size() == 1 && (*values.begin()).kind() == TK_LIST_COMP) {
+        ListComp lc(*values.begin());
+        prefix = DictComp::create(
+            range, *keys.begin(), lc.elt(), lc.target(), lc.iter());
+      } else {
+        prefix = DictLiteral::create(
+            range,
+            List<Expr>::create(range, keys),
+            List<Expr>::create(range, values));
+      }
     } break;
     case TK_STRINGLITERAL: {
       prefix = parseConcatenatedStringLiterals();
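The splice described in the comment above can be modeled in a few lines: when the `{` production ends up with exactly one key and a single value that parsed as a list comprehension, the parser rebuilds a DictComp from the list comprehension's pieces instead of emitting a DictLiteral. A hypothetical Python model of that decision (the classes below are stand-ins for the real C++ tree views, not part of the PR):

```python
from dataclasses import dataclass

@dataclass
class ListComp:      # stand-in for the ListComp tree view: (elt, target, iter)
    elt: str
    target: str
    iter: str

@dataclass
class DictComp:      # stand-in for the new DictComp tree view
    key: str
    value: str
    target: str
    iter: str

def splice(keys, values):
    # Mirrors the new branch in ParserImpl: one key whose lone "value" is a
    # list comprehension means the braces expression is a dict comprehension.
    if len(keys) == 1 and isinstance(values[0], ListComp):
        lc = values[0]
        return DictComp(key=keys[0], value=lc.elt, target=lc.target, iter=lc.iter)
    return dict(zip(keys, values))  # ordinary dict literal (simplified)

# Parsing `{i : chr(i + 65) for i in range(4)}` leaves `i` in keys and the
# list comprehension `chr(i + 65) for i in range(4)` in values:
node = splice(["i"], [ListComp(elt="chr(i + 65)", target="i", iter="range(4)")])
assert isinstance(node, DictComp)
```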
@@ -309,6 +309,7 @@ struct Expr : public TreeView {
       case '^':
       case '|':
       case TK_LIST_COMP:
+      case TK_DICT_COMP:
       case TK_DOTS:
       case TK_IN:
       case TK_WITH_ITEM:
@@ -579,6 +580,35 @@ struct ListComp : public Expr {
   }
 };
 
+// TODO: supports only single comprehension for now
+struct DictComp : public Expr {
+  explicit DictComp(const TreeRef& tree) : Expr(tree) {
+    tree->match(TK_DICT_COMP);
+  }
+  Expr key() const {
+    return Expr(subtree(0));
+  }
+  Expr value() const {
+    return Expr(subtree(1));
+  }
+  Expr target() const {
+    return Expr(subtree(2));
+  }
+  Expr iter() const {
+    return Expr(subtree(3));
+  }
+  // TODO: no ifs for now
+  static DictComp create(
+      const SourceRange& range,
+      const Expr& key,
+      const Expr& value,
+      const Expr& target,
+      const Expr& iter) {
+    return DictComp(
+        Compound::create(TK_DICT_COMP, range, {key, value, target, iter}));
+  }
+};
+
 struct Global : public Stmt {
   explicit Global(const TreeRef& tree) : Stmt(tree) {
     tree_->match(TK_GLOBAL);
@@ -352,6 +352,14 @@ void initTreeViewBindings(PyObject* module) {
                        const Expr& iter) {
         return ListComp::create(range, elt, target, iter);
       }));
+  py::class_<DictComp, Expr>(m, "DictComp")
+      .def(py::init([](const SourceRange& range,
+                       const Expr& key,
+                       const Expr& value,
+                       const Expr& target,
+                       const Expr& iter) {
+        return DictComp::create(range, key, value, target, iter);
+      }));
   py::class_<ListLiteral, Expr>(m, "ListLiteral")
       .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
         return ListLiteral::create(range, wrap_list(range, std::move(args)));
@@ -14,6 +14,7 @@ from torch._C._jit_tree_views import (
     ListLiteral, TupleLiteral, DictLiteral, Const,
     StringLiteral, ListComp, Attribute, BinOp, UnaryOp,
     SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
+    DictComp,
 )
 from torch._utils_internal import get_source_lines_and_file
 
@@ -810,18 +811,34 @@ class ExprBuilder(Builder):
     @staticmethod
     def build_ListComp(ctx, stmt):
         r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
-        if (len(stmt.generators) > 1):
-            raise NotSupportedError(r, "multiple comprehension generators not supported yet")
+        if (len(stmt.generators) != 1):
+            raise NotSupportedError(r, "Only a single generator is currently supported")
+
         if (len(stmt.generators[0].ifs) != 0):
-            raise NotSupportedError(r, "comprehension ifs not supported yet")
+            raise NotSupportedError(r, "Comprehension ifs are not supported yet")
+
         elt_expr = build_expr(ctx, stmt.elt)
         target_expr = build_expr(ctx, stmt.generators[0].target)
 
         iter_expr = build_expr(ctx, stmt.generators[0].iter)
 
         return ListComp(r, elt_expr, target_expr, iter_expr)
 
+    @staticmethod
+    def build_DictComp(ctx, stmt):
+        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
+        if (len(stmt.generators) != 1):
+            raise NotSupportedError(r, "Only a single generator is currently supported")
+
+        if (len(stmt.generators[0].ifs) != 0):
+            raise NotSupportedError(r, "Comprehension ifs are not supported yet")
+
+        key_expr = build_expr(ctx, stmt.key)
+        value_expr = build_expr(ctx, stmt.value)
+        target_expr = build_expr(ctx, stmt.generators[0].target)
+        iter_expr = build_expr(ctx, stmt.generators[0].iter)
+
+        return DictComp(r, key_expr, value_expr, target_expr, iter_expr)
+
     @staticmethod
     def build_Starred(ctx, expr):
         r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
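`build_DictComp` consumes the standard `ast.DictComp` node produced by Python's parser; `.key`, `.value`, and the single generator's `.target` and `.iter` are the four sub-expressions forwarded to the new `DictComp` tree view. A quick check of those fields with the example used throughout this PR:

```python
import ast

node = ast.parse("{i: chr(i + 65) for i in range(4)}", mode="eval").body
assert isinstance(node, ast.DictComp)
# Exactly one generator and no ifs, matching the NotSupportedError checks above.
assert len(node.generators) == 1 and not node.generators[0].ifs

print(ast.dump(node.key))                   # Name(id='i', ...)
print(ast.dump(node.value))                 # Call(func=Name(id='chr'), ...)
print(ast.dump(node.generators[0].target))  # Name(id='i', ...)
print(ast.dump(node.generators[0].iter))    # Call(func=Name(id='range'), ...)
```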