mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[JIT] Add Type::repr_str to return human-readable str (#39544)
Summary: Clearly expressing a type is inferred by PyTorch instead of explicitly annotated by user makes many error messages more user-friendly Currently Type has two string conversion methods. str() for IR printing and python_str() for serialization and error message generation. If we want to include more information in type printing while maintaining serialization/deserialization correctness, we need to split python_str() into annotation_str() and repr_str(). annotation_str is solely responsible for serialization, it strictly matches format of python type annotation. repr_str() is responsible for generating a human-readable error message that includes information like "this type is inferred, not explicitly annotated" Closes https://github.com/pytorch/pytorch/issues/39449 Pull Request resolved: https://github.com/pytorch/pytorch/pull/39544 Differential Revision: D21978759 Pulled By: gmagogsfm fbshipit-source-id: 733566f5a62e748b5ca4bb3c5943ebb6d5b664d0
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4e892bd99c
commit
c22bbb2124
@ -500,7 +500,7 @@ struct PythonPrintImpl {
|
||||
indent();
|
||||
body_ << useOf(lhs[i]);
|
||||
if (requiresAnnotation(lhs[i], rhs[i])) {
|
||||
body_ << ": " << lhs[i]->type()->python_str(type_printer_);
|
||||
body_ << ": " << lhs[i]->type()->annotation_str(type_printer_);
|
||||
}
|
||||
body_ << " = " << useOf(rhs[i]) << "\n";
|
||||
}
|
||||
@ -769,7 +769,7 @@ struct PythonPrintImpl {
|
||||
if (i > 0) {
|
||||
body_ << ", ";
|
||||
}
|
||||
body_ << useOf(v) << ": " << v->type()->python_str(type_printer_);
|
||||
body_ << useOf(v) << ": " << v->type()->annotation_str(type_printer_);
|
||||
}
|
||||
body_ << "):\n";
|
||||
printBody(graph->block());
|
||||
@ -804,7 +804,7 @@ struct PythonPrintImpl {
|
||||
if (v.isTuple() && v.type()->expect<TupleType>()->schema()) {
|
||||
// print the namedtuple constructor and let rest of tuple printing
|
||||
// continue
|
||||
ss << v.type()->expect<TupleType>()->python_str(type_printer_);
|
||||
ss << v.type()->expect<TupleType>()->annotation_str(type_printer_);
|
||||
}
|
||||
return false;
|
||||
};
|
||||
@ -857,14 +857,14 @@ struct PythonPrintImpl {
|
||||
} break;
|
||||
case prim::Uninitialized: {
|
||||
stmt << "uninitialized("
|
||||
<< node->output()->type()->python_str(type_printer_) << ")";
|
||||
<< node->output()->type()->annotation_str(type_printer_) << ")";
|
||||
} break;
|
||||
case prim::Constant: {
|
||||
if (node->outputs().size() == 1 &&
|
||||
node->output()->type()->kind() == TypeKind::FunctionType) {
|
||||
auto fn = node->output()->type()->expect<FunctionType>();
|
||||
registerDependency(fn);
|
||||
stmt << fn->python_str(type_printer_);
|
||||
stmt << fn->annotation_str(type_printer_);
|
||||
} else if (!node->mustBeNone()) {
|
||||
IValue v = toIValue(node->output()).value();
|
||||
printConstant(stmt, v);
|
||||
@ -875,8 +875,9 @@ struct PythonPrintImpl {
|
||||
case aten::ScalarImplicit:
|
||||
case aten::FloatImplicit:
|
||||
case aten::IntImplicit: {
|
||||
stmt << "annotate(" << node->output()->type()->python_str(type_printer_)
|
||||
<< ", " << useOf(node->input()) << ")";
|
||||
stmt << "annotate("
|
||||
<< node->output()->type()->annotation_str(type_printer_) << ", "
|
||||
<< useOf(node->input()) << ")";
|
||||
} break;
|
||||
case aten::Int: {
|
||||
printValueList(stmt, node->inputs(), "int(", ")");
|
||||
@ -902,7 +903,7 @@ struct PythonPrintImpl {
|
||||
case prim::TupleConstruct: {
|
||||
if (auto qualname =
|
||||
node->output()->type()->expect<TupleType>()->name()) {
|
||||
stmt << node->output()->type()->python_str(type_printer_);
|
||||
stmt << node->output()->type()->annotation_str(type_printer_);
|
||||
}
|
||||
printValueList(
|
||||
stmt, node->inputs(), "(", node->inputs().size() == 1 ? ",)" : ")");
|
||||
@ -922,13 +923,14 @@ struct PythonPrintImpl {
|
||||
// what type is supposed to be inside them
|
||||
if (node->inputs().size() == 0) {
|
||||
stmt << "annotate("
|
||||
<< node->output()->type()->python_str(type_printer_) << ", [])";
|
||||
<< node->output()->type()->annotation_str(type_printer_)
|
||||
<< ", [])";
|
||||
// If we can't infer the type based on what's inside, explicitly
|
||||
// annotate it to disambiguate.
|
||||
// This happens for List[Tensor] vs. List[Optional[Tensor]]
|
||||
} else if (!elementTypeCanBeInferredFromMembers(elem_type)) {
|
||||
stmt << "annotate("
|
||||
<< node->output()->type()->python_str(type_printer_) << ", ";
|
||||
<< node->output()->type()->annotation_str(type_printer_) << ", ";
|
||||
printValueList(stmt, node->inputs(), "[", "]");
|
||||
stmt << ")";
|
||||
// Otherwise just print a list
|
||||
@ -947,7 +949,7 @@ struct PythonPrintImpl {
|
||||
!elementTypeCanBeInferredFromMembers(dict_type->getKeyType()) ||
|
||||
!elementTypeCanBeInferredFromMembers(dict_type->getValueType())) {
|
||||
stmt << "annotate("
|
||||
<< node->output()->type()->python_str(type_printer_) << ", ";
|
||||
<< node->output()->type()->annotation_str(type_printer_) << ", ";
|
||||
printDict(stmt, node->inputs());
|
||||
stmt << ")";
|
||||
// Otherwise just print a dict
|
||||
@ -957,8 +959,8 @@ struct PythonPrintImpl {
|
||||
} break;
|
||||
case prim::CreateObject: {
|
||||
const auto classType = node->output()->type()->expect<ClassType>();
|
||||
stmt << classType->python_str(type_printer_) << ".__new__("
|
||||
<< classType->python_str(type_printer_) << ")";
|
||||
stmt << classType->annotation_str(type_printer_) << ".__new__("
|
||||
<< classType->annotation_str(type_printer_) << ")";
|
||||
} break;
|
||||
case prim::GetAttr: {
|
||||
const auto obj = node->inputs().at(0);
|
||||
@ -1013,8 +1015,8 @@ struct PythonPrintImpl {
|
||||
if (node->input()->type()->isSubtypeOf(NoneType::get()) ||
|
||||
node->input()->mustBeNone()) {
|
||||
auto input_type = OptionalType::create(node->output()->type());
|
||||
stmt << "annotate(" << input_type->python_str(type_printer_) << ", "
|
||||
<< useOf(node->input()) << ")";
|
||||
stmt << "annotate(" << input_type->annotation_str(type_printer_)
|
||||
<< ", " << useOf(node->input()) << ")";
|
||||
} else {
|
||||
stmt << useOf(node->input());
|
||||
}
|
||||
@ -1027,14 +1029,14 @@ struct PythonPrintImpl {
|
||||
case prim::unchecked_unwrap_optional:
|
||||
case prim::unchecked_cast: {
|
||||
stmt << "unchecked_cast("
|
||||
<< node->output()->type()->python_str(type_printer_) << ", "
|
||||
<< node->output()->type()->annotation_str(type_printer_) << ", "
|
||||
<< useOf(node->input()) << ")";
|
||||
} break;
|
||||
case prim::isinstance: {
|
||||
stmt << "isinstance(" << useOf(node->input()) << ", ";
|
||||
const auto& types = node->tys(attr::types);
|
||||
if (types.size() == 1) {
|
||||
stmt << types.at(0)->python_str(type_printer_);
|
||||
stmt << types.at(0)->annotation_str(type_printer_);
|
||||
} else {
|
||||
// check multiple things, e.g. (str, list, int)
|
||||
stmt << "(";
|
||||
@ -1043,7 +1045,7 @@ struct PythonPrintImpl {
|
||||
if (!first) {
|
||||
stmt << ", ";
|
||||
}
|
||||
stmt << typ->python_str(type_printer_);
|
||||
stmt << typ->annotation_str(type_printer_);
|
||||
first = false;
|
||||
}
|
||||
stmt << ")";
|
||||
@ -1051,8 +1053,8 @@ struct PythonPrintImpl {
|
||||
stmt << ")";
|
||||
} break;
|
||||
case prim::tolist: {
|
||||
stmt << "annotate(" << node->output()->type()->python_str(type_printer_)
|
||||
<< ", ";
|
||||
stmt << "annotate("
|
||||
<< node->output()->type()->annotation_str(type_printer_) << ", ";
|
||||
stmt << useOf(node->input(0)) << ".tolist()"
|
||||
<< ")";
|
||||
} break;
|
||||
@ -1172,11 +1174,11 @@ struct PythonPrintImpl {
|
||||
// the flag print_first_argument_type determines when to do this
|
||||
body_ << arg_name;
|
||||
if (print_first_argument_type) {
|
||||
body_ << ": " << arg.type()->python_str(type_printer_);
|
||||
body_ << ": " << arg.type()->annotation_str(type_printer_);
|
||||
}
|
||||
} else {
|
||||
body_ << ",\n " << arg_name << ": "
|
||||
<< arg.type()->python_str(type_printer_);
|
||||
<< arg.type()->annotation_str(type_printer_);
|
||||
}
|
||||
if (arg.default_value()) {
|
||||
printDefaultValue(arg, body_, *arg.default_value());
|
||||
@ -1184,7 +1186,8 @@ struct PythonPrintImpl {
|
||||
assignValue(*param_it++, arg_name);
|
||||
}
|
||||
|
||||
body_ << ") -> " << schema.returns().at(0).type()->python_str(type_printer_)
|
||||
body_ << ") -> "
|
||||
<< schema.returns().at(0).type()->annotation_str(type_printer_)
|
||||
<< ":\n";
|
||||
printBody(graph.block());
|
||||
}
|
||||
@ -1275,12 +1278,12 @@ struct PythonPrintImpl {
|
||||
// Print out a direct manipulation of the annotations dict, like:
|
||||
// __annotations__["0"] = SomeType
|
||||
body_ << "__annotations__["
|
||||
<< "\"" << name << "\"] = " << type->python_str(type_printer_)
|
||||
<< "\n";
|
||||
<< "\"" << name
|
||||
<< "\"] = " << type->annotation_str(type_printer_) << "\n";
|
||||
} else {
|
||||
// Otherwise: just emit a python 3 attribute annotation, like:
|
||||
// foo : SomeType
|
||||
body_ << name << " : " << type->python_str(type_printer_) << "\n";
|
||||
body_ << name << " : " << type->annotation_str(type_printer_) << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
@ -1291,7 +1294,7 @@ struct PythonPrintImpl {
|
||||
|
||||
indent();
|
||||
body_ << name << " : "
|
||||
<< "Final[" << v.type()->python_str(type_printer_) << "] = ";
|
||||
<< "Final[" << v.type()->annotation_str(type_printer_) << "] = ";
|
||||
auto ss = std::make_shared<TaggedStringStream>(&source_range_stack_);
|
||||
printConstant(*ss, v);
|
||||
body_ << ss->str() << "\n";
|
||||
@ -1319,7 +1322,7 @@ struct PythonPrintImpl {
|
||||
TORCH_INTERNAL_ASSERT(attr.type());
|
||||
indent();
|
||||
body_ << attr.name() << " : "
|
||||
<< attr.type()->python_str(type_printer_) << "\n";
|
||||
<< attr.type()->annotation_str(type_printer_) << "\n";
|
||||
}
|
||||
}
|
||||
} else if (auto interfaceType = type->cast<InterfaceType>()) {
|
||||
@ -1342,11 +1345,12 @@ struct PythonPrintImpl {
|
||||
auto type = arg.type();
|
||||
registerClassDependencies(type);
|
||||
body_ << ", " << arg.name() << ": "
|
||||
<< type->python_str(type_printer_);
|
||||
<< type->annotation_str(type_printer_);
|
||||
}
|
||||
auto return_type = method.returns().at(0).type();
|
||||
registerClassDependencies(return_type);
|
||||
body_ << ") -> " << return_type->python_str(type_printer_) << ":\n";
|
||||
body_ << ") -> " << return_type->annotation_str(type_printer_)
|
||||
<< ":\n";
|
||||
indent();
|
||||
body_ << " pass\n";
|
||||
}
|
||||
|
Reference in New Issue
Block a user