[JIT] Add Type::repr_str to return human-readable str (#39544)

Summary:
Clearly expressing a type is inferred by PyTorch instead of explicitly annotated by user makes many error messages more user-friendly

Currently Type has two string conversion methods. str() for IR printing and python_str() for serialization and error message generation. If we want to include more information in type printing while maintaining serialization/deserialization correctness, we need to split python_str() into annotation_str() and repr_str().

annotation_str is solely responsible for serialization, it strictly matches format of python type annotation. repr_str() is responsible for generating a human-readable error message that includes information like "this type is inferred, not explicitly annotated"

Closes https://github.com/pytorch/pytorch/issues/39449
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39544

Differential Revision: D21978759

Pulled By: gmagogsfm

fbshipit-source-id: 733566f5a62e748b5ca4bb3c5943ebb6d5b664d0
This commit is contained in:
Yanan Cao
2020-06-10 11:59:01 -07:00
committed by Facebook GitHub Bot
parent 4e892bd99c
commit c22bbb2124
37 changed files with 238 additions and 224 deletions

View File

@ -500,7 +500,7 @@ struct PythonPrintImpl {
indent();
body_ << useOf(lhs[i]);
if (requiresAnnotation(lhs[i], rhs[i])) {
body_ << ": " << lhs[i]->type()->python_str(type_printer_);
body_ << ": " << lhs[i]->type()->annotation_str(type_printer_);
}
body_ << " = " << useOf(rhs[i]) << "\n";
}
@ -769,7 +769,7 @@ struct PythonPrintImpl {
if (i > 0) {
body_ << ", ";
}
body_ << useOf(v) << ": " << v->type()->python_str(type_printer_);
body_ << useOf(v) << ": " << v->type()->annotation_str(type_printer_);
}
body_ << "):\n";
printBody(graph->block());
@ -804,7 +804,7 @@ struct PythonPrintImpl {
if (v.isTuple() && v.type()->expect<TupleType>()->schema()) {
// print the namedtuple constructor and let rest of tuple printing
// continue
ss << v.type()->expect<TupleType>()->python_str(type_printer_);
ss << v.type()->expect<TupleType>()->annotation_str(type_printer_);
}
return false;
};
@ -857,14 +857,14 @@ struct PythonPrintImpl {
} break;
case prim::Uninitialized: {
stmt << "uninitialized("
<< node->output()->type()->python_str(type_printer_) << ")";
<< node->output()->type()->annotation_str(type_printer_) << ")";
} break;
case prim::Constant: {
if (node->outputs().size() == 1 &&
node->output()->type()->kind() == TypeKind::FunctionType) {
auto fn = node->output()->type()->expect<FunctionType>();
registerDependency(fn);
stmt << fn->python_str(type_printer_);
stmt << fn->annotation_str(type_printer_);
} else if (!node->mustBeNone()) {
IValue v = toIValue(node->output()).value();
printConstant(stmt, v);
@ -875,8 +875,9 @@ struct PythonPrintImpl {
case aten::ScalarImplicit:
case aten::FloatImplicit:
case aten::IntImplicit: {
stmt << "annotate(" << node->output()->type()->python_str(type_printer_)
<< ", " << useOf(node->input()) << ")";
stmt << "annotate("
<< node->output()->type()->annotation_str(type_printer_) << ", "
<< useOf(node->input()) << ")";
} break;
case aten::Int: {
printValueList(stmt, node->inputs(), "int(", ")");
@ -902,7 +903,7 @@ struct PythonPrintImpl {
case prim::TupleConstruct: {
if (auto qualname =
node->output()->type()->expect<TupleType>()->name()) {
stmt << node->output()->type()->python_str(type_printer_);
stmt << node->output()->type()->annotation_str(type_printer_);
}
printValueList(
stmt, node->inputs(), "(", node->inputs().size() == 1 ? ",)" : ")");
@ -922,13 +923,14 @@ struct PythonPrintImpl {
// what type is supposed to be inside them
if (node->inputs().size() == 0) {
stmt << "annotate("
<< node->output()->type()->python_str(type_printer_) << ", [])";
<< node->output()->type()->annotation_str(type_printer_)
<< ", [])";
// If we can't infer the type based on what's inside, explicitly
// annotate it to disambiguate.
// This happens for List[Tensor] vs. List[Optional[Tensor]]
} else if (!elementTypeCanBeInferredFromMembers(elem_type)) {
stmt << "annotate("
<< node->output()->type()->python_str(type_printer_) << ", ";
<< node->output()->type()->annotation_str(type_printer_) << ", ";
printValueList(stmt, node->inputs(), "[", "]");
stmt << ")";
// Otherwise just print a list
@ -947,7 +949,7 @@ struct PythonPrintImpl {
!elementTypeCanBeInferredFromMembers(dict_type->getKeyType()) ||
!elementTypeCanBeInferredFromMembers(dict_type->getValueType())) {
stmt << "annotate("
<< node->output()->type()->python_str(type_printer_) << ", ";
<< node->output()->type()->annotation_str(type_printer_) << ", ";
printDict(stmt, node->inputs());
stmt << ")";
// Otherwise just print a dict
@ -957,8 +959,8 @@ struct PythonPrintImpl {
} break;
case prim::CreateObject: {
const auto classType = node->output()->type()->expect<ClassType>();
stmt << classType->python_str(type_printer_) << ".__new__("
<< classType->python_str(type_printer_) << ")";
stmt << classType->annotation_str(type_printer_) << ".__new__("
<< classType->annotation_str(type_printer_) << ")";
} break;
case prim::GetAttr: {
const auto obj = node->inputs().at(0);
@ -1013,8 +1015,8 @@ struct PythonPrintImpl {
if (node->input()->type()->isSubtypeOf(NoneType::get()) ||
node->input()->mustBeNone()) {
auto input_type = OptionalType::create(node->output()->type());
stmt << "annotate(" << input_type->python_str(type_printer_) << ", "
<< useOf(node->input()) << ")";
stmt << "annotate(" << input_type->annotation_str(type_printer_)
<< ", " << useOf(node->input()) << ")";
} else {
stmt << useOf(node->input());
}
@ -1027,14 +1029,14 @@ struct PythonPrintImpl {
case prim::unchecked_unwrap_optional:
case prim::unchecked_cast: {
stmt << "unchecked_cast("
<< node->output()->type()->python_str(type_printer_) << ", "
<< node->output()->type()->annotation_str(type_printer_) << ", "
<< useOf(node->input()) << ")";
} break;
case prim::isinstance: {
stmt << "isinstance(" << useOf(node->input()) << ", ";
const auto& types = node->tys(attr::types);
if (types.size() == 1) {
stmt << types.at(0)->python_str(type_printer_);
stmt << types.at(0)->annotation_str(type_printer_);
} else {
// check multiple things, e.g. (str, list, int)
stmt << "(";
@ -1043,7 +1045,7 @@ struct PythonPrintImpl {
if (!first) {
stmt << ", ";
}
stmt << typ->python_str(type_printer_);
stmt << typ->annotation_str(type_printer_);
first = false;
}
stmt << ")";
@ -1051,8 +1053,8 @@ struct PythonPrintImpl {
stmt << ")";
} break;
case prim::tolist: {
stmt << "annotate(" << node->output()->type()->python_str(type_printer_)
<< ", ";
stmt << "annotate("
<< node->output()->type()->annotation_str(type_printer_) << ", ";
stmt << useOf(node->input(0)) << ".tolist()"
<< ")";
} break;
@ -1172,11 +1174,11 @@ struct PythonPrintImpl {
// the flag print_first_argument_type determines when to do this
body_ << arg_name;
if (print_first_argument_type) {
body_ << ": " << arg.type()->python_str(type_printer_);
body_ << ": " << arg.type()->annotation_str(type_printer_);
}
} else {
body_ << ",\n " << arg_name << ": "
<< arg.type()->python_str(type_printer_);
<< arg.type()->annotation_str(type_printer_);
}
if (arg.default_value()) {
printDefaultValue(arg, body_, *arg.default_value());
@ -1184,7 +1186,8 @@ struct PythonPrintImpl {
assignValue(*param_it++, arg_name);
}
body_ << ") -> " << schema.returns().at(0).type()->python_str(type_printer_)
body_ << ") -> "
<< schema.returns().at(0).type()->annotation_str(type_printer_)
<< ":\n";
printBody(graph.block());
}
@ -1275,12 +1278,12 @@ struct PythonPrintImpl {
// Print out a direct manipulation of the annotations dict, like:
// __annotations__["0"] = SomeType
body_ << "__annotations__["
<< "\"" << name << "\"] = " << type->python_str(type_printer_)
<< "\n";
<< "\"" << name
<< "\"] = " << type->annotation_str(type_printer_) << "\n";
} else {
// Otherwise: just emit a python 3 attribute annotation, like:
// foo : SomeType
body_ << name << " : " << type->python_str(type_printer_) << "\n";
body_ << name << " : " << type->annotation_str(type_printer_) << "\n";
}
}
@ -1291,7 +1294,7 @@ struct PythonPrintImpl {
indent();
body_ << name << " : "
<< "Final[" << v.type()->python_str(type_printer_) << "] = ";
<< "Final[" << v.type()->annotation_str(type_printer_) << "] = ";
auto ss = std::make_shared<TaggedStringStream>(&source_range_stack_);
printConstant(*ss, v);
body_ << ss->str() << "\n";
@ -1319,7 +1322,7 @@ struct PythonPrintImpl {
TORCH_INTERNAL_ASSERT(attr.type());
indent();
body_ << attr.name() << " : "
<< attr.type()->python_str(type_printer_) << "\n";
<< attr.type()->annotation_str(type_printer_) << "\n";
}
}
} else if (auto interfaceType = type->cast<InterfaceType>()) {
@ -1342,11 +1345,12 @@ struct PythonPrintImpl {
auto type = arg.type();
registerClassDependencies(type);
body_ << ", " << arg.name() << ": "
<< type->python_str(type_printer_);
<< type->annotation_str(type_printer_);
}
auto return_type = method.returns().at(0).type();
registerClassDependencies(return_type);
body_ << ") -> " << return_type->python_str(type_printer_) << ":\n";
body_ << ") -> " << return_type->annotation_str(type_printer_)
<< ":\n";
indent();
body_ << " pass\n";
}