From c22bbb212436c3ac44dbc68d0d25da64c6572057 Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Wed, 10 Jun 2020 11:59:01 -0700 Subject: [PATCH] [JIT] Add Type::repr_str to return human-readable str (#39544) Summary: Clearly expressing a type is inferred by PyTorch instead of explicitly annotated by user makes many error messages more user-friendly Currently Type has two string conversion methods. str() for IR printing and python_str() for serialization and error message generation. If we want to include more information in type printing while maintaining serialization/deserialization correctness, we need to split python_str() into annotation_str() and repr_str(). annotation_str is solely responsible for serialization, it strictly matches format of python type annotation. repr_str() is responsible for generating a human-readable error message that includes information like "this type is inferred, not explicitly annotated" Closes https://github.com/pytorch/pytorch/issues/39449 Pull Request resolved: https://github.com/pytorch/pytorch/pull/39544 Differential Revision: D21978759 Pulled By: gmagogsfm fbshipit-source-id: 733566f5a62e748b5ca4bb3c5943ebb6d5b664d0 --- aten/src/ATen/core/function_schema.h | 2 +- aten/src/ATen/core/function_schema_inl.h | 2 +- aten/src/ATen/core/ivalue.cpp | 10 +-- aten/src/ATen/core/jit_type.h | 79 +++++++++++-------- aten/src/ATen/core/type.cpp | 70 ++++++++-------- aten/src/ATen/test/type_test.cpp | 18 ++--- test/cpp/jit/test_jit_type.cpp | 2 +- test/cpp/jit/test_mobile_type_parser.cpp | 6 +- test/test_jit.py | 3 +- torch/csrc/distributed/rpc/py_rref.cpp | 2 +- torch/csrc/distributed/rpc/rref_context.cpp | 4 +- torch/csrc/jit/api/object.h | 6 +- .../jit/frontend/concrete_module_type.cpp | 6 +- torch/csrc/jit/frontend/ir_emitter.cpp | 54 ++++++------- torch/csrc/jit/frontend/schema_matching.cpp | 12 +-- torch/csrc/jit/frontend/sugared_value.cpp | 23 +++--- torch/csrc/jit/frontend/sugared_value.h | 2 +- torch/csrc/jit/frontend/tracer.cpp | 2 +- torch/csrc/jit/ir/alias_analysis.cpp | 12 +-- torch/csrc/jit/ir/ir.cpp | 6 +- torch/csrc/jit/ir/irparser.cpp | 5 +- torch/csrc/jit/passes/lower_graph.cpp | 2 +- torch/csrc/jit/python/pybind_utils.h | 22 +++--- torch/csrc/jit/python/python_custom_class.cpp | 2 +- torch/csrc/jit/python/python_ir.cpp | 2 +- .../csrc/jit/python/python_sugared_value.cpp | 2 +- torch/csrc/jit/python/script_init.cpp | 4 +- torch/csrc/jit/runtime/register_ops_utils.cpp | 2 +- .../jit/runtime/register_prim_ops_fulljit.cpp | 8 +- .../csrc/jit/runtime/register_special_ops.cpp | 6 +- .../csrc/jit/serialization/export_module.cpp | 2 +- torch/csrc/jit/serialization/import.cpp | 2 +- .../csrc/jit/serialization/import_legacy.cpp | 2 +- torch/csrc/jit/serialization/pickler.cpp | 8 +- torch/csrc/jit/serialization/python_print.cpp | 64 ++++++++------- torch/csrc/jit/serialization/unpickler.cpp | 2 +- torch/custom_class.h | 6 +- 37 files changed, 238 insertions(+), 224 deletions(-) diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index daf28227f202..ed716809f591 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -75,7 +75,7 @@ struct Argument { } return c10::str( "Expected a value of type '", - type()->python_str(), + type()->repr_str(), "' for argument '", name(), "' but instead found type '", diff --git a/aten/src/ATen/core/function_schema_inl.h b/aten/src/ATen/core/function_schema_inl.h index 225ee85e6de0..bc9a68fbad3f 100644 --- a/aten/src/ATen/core/function_schema_inl.h +++ b/aten/src/ATen/core/function_schema_inl.h @@ -190,7 +190,7 @@ inline void FunctionSchema::checkArg( TORCH_CHECK( false, formatTypeMismatchMsg( - argument, value.type()->python_str(), pos)); + argument, value.type()->repr_str(), pos)); } } diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 2724a6e1d937..0d95059ccf95 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -18,7 +18,7 @@ bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs) { namespace ivalue { -// This is in ivalue.cpp because we need to access Type::python_str, which +// This is in ivalue.cpp because we need to access Type::annotation_str, which // is declared in jit_type.h void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) { // NB: doing pointer comparison here @@ -26,9 +26,9 @@ void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) { // Type's, this needs to be changed! TORCH_CHECK(actual_type == expected_type, "Tried to convert an IValue of type ", - actual_type->python_str(), + actual_type->repr_str(), " to custom class type ", - expected_type->python_str()); + expected_type->repr_str()); } CAFFE2_API c10::intrusive_ptr ConstantString::create( @@ -291,7 +291,7 @@ std::ostream& printMaybeAnnotatedList( auto list_elem_type = the_list.type()->expect()->getElementType(); if (the_list.toListRef().size() == 0 || !elementTypeCanBeInferredFromMembers(list_elem_type)) { - out << "annotate(" << the_list.type()->python_str() << ", "; + out << "annotate(" << the_list.type()->annotation_str() << ", "; printList(out, the_list.toListRef(), "[", "]", formatter); out << ")"; return out; @@ -332,7 +332,7 @@ std::ostream& printMaybeAnnotatedDict( auto value_type = the_dict.type()->cast()->getValueType(); if (the_dict.toGenericDict().size() == 0 || !elementTypeCanBeInferredFromMembers(value_type)) { - out << "annotate(" << the_dict.type()->python_str() << ","; + out << "annotate(" << the_dict.type()->annotation_str() << ","; printDict(out, the_dict.toGenericDict(), formatter) << ")"; } else { return printDict(out, the_dict.toGenericDict(), formatter); diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 4b7c73994ac6..d8f4a73df427 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -68,8 +68,8 @@ struct Type; using TypePtr = std::shared_ptr; using ConstTypePtr = std::shared_ptr; -// Use this to customize how a Type is printed using `python_str()`. If -// c10::nullopt is returned, `python_str()` falls through to its default +// Use this to customize how a Type is printed using `annotation_str()`. If +// c10::nullopt is returned, `annotation_str()` falls through to its default // implementation. using TypePrinter = std::function(const ConstTypePtr&)>; @@ -81,7 +81,7 @@ struct CAFFE2_API Type : std::enable_shared_from_this { protected: Type(TypeKind kind) : kind_(kind) {} - virtual std::string python_str_impl(TypePrinter printer) const { + virtual std::string annotation_str_impl(TypePrinter printer) const { return str(); } @@ -94,7 +94,7 @@ struct CAFFE2_API Type : std::enable_shared_from_this { // if this returns false and the why_not stream is non-null, it contains // additional details that describe why this is not a subtype of 'rhs'. // This additional information should only contain details that are not obvious - // from the python_str() that describes the type. For instance it is clear that `int <: str` is false + // from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false // but not clear why `Foo <: InterfaceBar` might be false. virtual bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const; virtual bool is_module() const; @@ -111,19 +111,26 @@ struct CAFFE2_API Type : std::enable_shared_from_this { // // Takes a custom printer that users can pass in to customize the output of // this method. - std::string python_str(TypePrinter printer) const { + std::string annotation_str(TypePrinter printer) const { if (printer) { // the printer can return nullopt to fall through to the default impl if (auto renamed = printer(shared_from_this())) { return *renamed; } } - return python_str_impl(printer); + return annotation_str_impl(printer); } - std::string python_str() const { + std::string annotation_str() const { // Overload instead of define a default value for `printer` to help // debuggers out. - return python_str(nullptr); + return annotation_str(nullptr); + } + + // Returns a human readable string that includes additional information like + // "type is inferred rather than explictly defined" to help construct more + // user-friendly messages. + virtual std::string repr_str() const { + return annotation_str(); } TypeKind kind() const { @@ -306,9 +313,9 @@ struct CAFFE2_API OptionalType private: OptionalType(TypePtr elem) : SingleElementType(elem) {} - std::string python_str_impl(TypePrinter printer = nullptr) const override { + std::string annotation_str_impl(TypePrinter printer = nullptr) const override { std::stringstream ss; - ss << "Optional[" << getElementType()->python_str(printer) << "]"; + ss << "Optional[" << getElementType()->annotation_str(printer) << "]"; return ss.str(); } }; @@ -641,6 +648,10 @@ struct CAFFE2_API TensorType : public Type { std::string str() const override; + std::string repr_str() const override { + return str() + (isInferredType() ? " (inferred)" : ""); + } + c10::optional numel() const { size_t prod = 1; const auto& shape = sizes(); @@ -852,9 +863,9 @@ struct CAFFE2_API ListType private: ListType(TypePtr elem) : SingleElementType(elem) {} - std::string python_str_impl(TypePrinter printer = nullptr) const override { + std::string annotation_str_impl(TypePrinter printer = nullptr) const override { std::stringstream ss; - ss << "List[" << getElementType()->python_str(printer) << "]"; + ss << "List[" << getElementType()->annotation_str(printer) << "]"; return ss.str(); } }; @@ -928,10 +939,10 @@ struct CAFFE2_API DictType : public Type { has_free_variables( key->hasFreeVariables() || value->hasFreeVariables()) {} - std::string python_str_impl(TypePrinter printer = nullptr) const override { + std::string annotation_str_impl(TypePrinter printer = nullptr) const override { std::stringstream ss; - ss << "Dict[" << getKeyType()->python_str(printer) << ", " - << getValueType()->python_str(printer) << "]"; + ss << "Dict[" << getKeyType()->annotation_str(printer) << ", " + << getValueType()->annotation_str(printer) << "]"; return ss.str(); } @@ -964,9 +975,9 @@ struct CAFFE2_API FutureType private: FutureType(TypePtr elem) : SingleElementType(elem) {} - std::string python_str_impl(TypePrinter printer = nullptr) const override { + std::string annotation_str_impl(TypePrinter printer = nullptr) const override { std::stringstream ss; - ss << "Future[" << getElementType()->python_str(printer) << "]"; + ss << "Future[" << getElementType()->annotation_str(printer) << "]"; return ss.str(); } }; @@ -996,9 +1007,9 @@ struct CAFFE2_API RRefType private: RRefType(TypePtr elem) : SingleElementType(elem) {} - std::string python_str_impl(TypePrinter printer = nullptr) const override { + std::string annotation_str_impl(TypePrinter printer = nullptr) const override { std::stringstream ss; - ss << "RRef[" << getElementType()->python_str(printer) << "]"; + ss << "RRef[" << getElementType()->annotation_str(printer) << "]"; return ss.str(); } }; @@ -1105,7 +1116,7 @@ struct CAFFE2_API TupleType : public NamedType { return true; } - std::string python_str_impl(TypePrinter printer = nullptr) const override; + std::string annotation_str_impl(TypePrinter printer = nullptr) const override; std::vector elements_; bool has_free_variables_; @@ -1135,7 +1146,7 @@ struct CAFFE2_API NumberType : public Type { protected: NumberType(TypeKind kind = TypeKind::NumberType) : Type(kind) {} - std::string python_str_impl(TypePrinter printer = nullptr) const override { + std::string annotation_str_impl(TypePrinter printer = nullptr) const override { return "number"; // technically not a valid python type, but // we need to use it when parsing back in annotations // for implicit conversions @@ -1164,7 +1175,7 @@ struct CAFFE2_API FloatType : public NumberType { private: FloatType() : NumberType(TypeKind::FloatType) {} - std::string python_str_impl(TypePrinter printer = nullptr) const override { + std::string annotation_str_impl(TypePrinter printer = nullptr) const override { return "float"; } }; @@ -1191,7 +1202,7 @@ struct CAFFE2_API IntType : public NumberType { private: IntType() : NumberType(TypeKind::IntType) {} - std::string python_str_impl(TypePrinter printer = nullptr) const override { + std::string annotation_str_impl(TypePrinter printer = nullptr) const override { return "int"; } }; @@ -1229,9 +1240,9 @@ struct CAFFE2_API StringType : public Type { } std::string str() const override { // we only use "str" (not "string") in both FunctionSchema and script - return python_str(); + return annotation_str(); } - std::string python_str_impl(TypePrinter printer = nullptr) const override { + std::string annotation_str_impl(TypePrinter printer = nullptr) const override { return "str"; } static const TypeKind Kind = TypeKind::StringType; @@ -1266,7 +1277,7 @@ struct CAFFE2_API FunctionType : public NamedType { private: FunctionType(torch::jit::Function* function); - std::string python_str_impl(TypePrinter printer = nullptr) const override { + std::string annotation_str_impl(TypePrinter printer = nullptr) const override { const auto& n = name().value(); return n.qualifiedName(); } @@ -1749,7 +1760,7 @@ struct CAFFE2_API ClassType : public NamedType { } std::string str() const override { - return python_str(); + return annotation_str(); } const std::vector& methods() const; @@ -1773,7 +1784,7 @@ struct CAFFE2_API ClassType : public NamedType { auto type = findAttribute(name); TORCH_CHECK( type, - python_str(), + repr_str(), " does not have an attribute with name '", name, "'"); @@ -1815,7 +1826,7 @@ struct CAFFE2_API ClassType : public NamedType { } TORCH_CHECK( false, - python_str(), + repr_str(), " does not have an attribute with name '", name, "'"); @@ -1867,9 +1878,9 @@ struct CAFFE2_API ClassType : public NamedType { TypePtr atype = getAttribute(*slot_idx); TORCH_CHECK( ty->isSubtypeOf(atype), - ty->python_str(), + ty->repr_str(), " is not compatible with the type ", - atype->python_str(), + atype->repr_str(), " for the field '", name, "'"); @@ -1904,7 +1915,7 @@ struct CAFFE2_API ClassType : public NamedType { } TORCH_CHECK( false, - python_str(), + repr_str(), " does not have constant field with the name '", name, "'"); @@ -2014,7 +2025,7 @@ struct CAFFE2_API ClassType : public NamedType { std::weak_ptr cu, bool is_module); - std::string python_str_impl(TypePrinter printer = nullptr) const override { + std::string annotation_str_impl(TypePrinter printer = nullptr) const override { const auto& n = name().value(); return n.qualifiedName(); } @@ -2095,7 +2106,7 @@ struct CAFFE2_API InterfaceType : public NamedType { const InterfaceType& rhs, std::ostream* why_not); - std::string python_str_impl(TypePrinter printer = nullptr) const override { + std::string annotation_str_impl(TypePrinter printer = nullptr) const override { return name()->qualifiedName(); } diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index a2d257d2a091..dbf1224c28dd 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -268,9 +268,9 @@ c10::optional unifyTypeList( auto maybe_unified = unifyTypes(ret_type, elements.at(i)); if (!maybe_unified) { why_not << "Could not unify type list since element " << i << " of type " - << elements.at(i)->python_str() + << elements.at(i)->repr_str() << " did not match the types before it (" - << ret_type->python_str() << ")"; + << ret_type->repr_str() << ")"; return c10::nullopt; } ret_type = maybe_unified.value(); @@ -300,8 +300,8 @@ MatchTypeReturn matchTypeVariables( } std::stringstream ss; ss << "Type variable '" << vt->name() << "' previously matched to type " - << it->second->python_str() << " is matched to type " - << actual->python_str(); + << it->second->repr_str() << " is matched to type " + << actual->repr_str(); return ss.str(); } else if (auto lt_formal = formal->cast()) { if (auto lt_actual = actual->cast()) { @@ -322,8 +322,8 @@ MatchTypeReturn matchTypeVariables( } std::stringstream ss; - ss << "Cannot match " << lt_formal->python_str() << " to " - << actual->python_str(); + ss << "Cannot match " << lt_formal->repr_str() << " to " + << actual->repr_str(); return ss.str(); } else if (auto tp_formal = formal->cast()) { if (auto tp_actual = actual->cast()) { @@ -340,7 +340,7 @@ MatchTypeReturn matchTypeVariables( return MatchTypeReturn::Success(); } else { std::stringstream ss; - ss << "Cannot match a tuple to " << actual->python_str(); + ss << "Cannot match a tuple to " << actual->repr_str(); return MatchTypeReturn(ss.str()); } } else if (auto lt_formal = formal->cast()) { @@ -353,7 +353,7 @@ MatchTypeReturn matchTypeVariables( return MatchTypeReturn::Success(); } else { std::stringstream ss; - ss << "Cannot match a future to " << actual->python_str(); + ss << "Cannot match a future to " << actual->repr_str(); return ss.str(); } } else if (auto lt_formal = formal->cast()) { @@ -366,7 +366,7 @@ MatchTypeReturn matchTypeVariables( return MatchTypeReturn::Success(); } else { std::stringstream ss; - ss << "Cannot match a rref to " << actual->python_str(); + ss << "Cannot match a rref to " << actual->repr_str(); return ss.str(); } } else if (auto opt_formal = formal->cast()) { @@ -403,12 +403,12 @@ MatchTypeReturn matchTypeVariables( return MatchTypeReturn::Success(); } else { std::stringstream ss; - ss << "Cannot match a dict to " << actual->python_str(); + ss << "Cannot match a dict to " << actual->repr_str(); return ss.str(); } } - AT_ERROR("Unhandled free variable container: ", formal->python_str()); + AT_ERROR("Unhandled free variable container: ", formal->repr_str()); } // change return types like List[List[t]] into List[List[int]] @@ -761,7 +761,7 @@ std::string TupleType::str() const { } return ss.str(); } -std::string TupleType::python_str_impl(TypePrinter printer) const { +std::string TupleType::annotation_str_impl(TypePrinter printer) const { std::stringstream ss; if (schema_ && name()) { ss << name()->qualifiedName(); @@ -770,7 +770,7 @@ std::string TupleType::python_str_impl(TypePrinter printer) const { for(size_t i = 0; i < elements().size(); ++i) { if(i > 0) ss << ", "; - ss << elements()[i]->python_str(printer); + ss << elements()[i]->annotation_str(printer); } ss << "]"; } @@ -978,7 +978,7 @@ void ClassType::addMethod(torch::jit::Function* method) { "Can't redefine method: ", method->name(), " on class: ", - python_str()); + repr_str()); methods_.push_back(method); } @@ -997,7 +997,7 @@ torch::jit::Function& ClassType::getMethod(const std::string& name) const { "Couldn't find method: '", name, "' on class: '", - python_str(), + repr_str(), "'"); return *method; } @@ -1019,7 +1019,7 @@ void ClassType::unsafeRemoveMethod(const std::string& name) { "Can't delete undefined method ", name, " on class: ", - python_str()); + repr_str()); } ClassTypePtr ClassType::refine(at::ArrayRef refined_slots) const { @@ -1047,8 +1047,8 @@ bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const { // Module Interface Type but the Class Type is not a Module Class Type if (!is_module() && iface->is_module()) { if (why_not) { - *why_not << "Class '" << python_str() << "' is not a subtype of " - << "the module interface '" << rhs->python_str() + *why_not << "Class '" << repr_str() << "' is not a subtype of " + << "the module interface '" << rhs->repr_str() << "' , only ScriptModule class can be subtype of module" << " interface.\n"; } @@ -1058,8 +1058,8 @@ bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const { auto self_method = findMethod(schema.name()); if (!self_method) { if (why_not) { - *why_not << "Class '" << python_str() << "' does not have method '" - << schema.name() << "' but '" << rhs->python_str() + *why_not << "Class '" << repr_str() << "' does not have method '" + << schema.name() << "' but '" << rhs->repr_str() << "' does.\n"; } return false; @@ -1067,9 +1067,9 @@ bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const { if (!self_method->getSchema().isSubtypeOf( schema, /*is_method=*/true, why_not)) { if (why_not) { - *why_not << "Method on class '" << python_str() + *why_not << "Method on class '" << repr_str() << "' (1) is not compatible with interface '" - << rhs->python_str() << "' (2)\n" + << rhs->repr_str() << "' (2)\n" << " (1) " << self_method->getSchema() << "\n" << " (2) " << schema << "\n"; } @@ -1091,8 +1091,8 @@ bool InterfaceType::isSubTypeImpl( std::ostream* why_not) { if (!lhs.is_module() && rhs.is_module()) { if (why_not) { - *why_not << "Interface '" << lhs.python_str() << "' is not a subtype of " - << "the module interface '" << rhs.python_str() << "'.\n"; + *why_not << "Interface '" << lhs.repr_str() << "' is not a subtype of " + << "the module interface '" << rhs.repr_str() << "'.\n"; } return false; } @@ -1100,17 +1100,17 @@ bool InterfaceType::isSubTypeImpl( auto self_schema = lhs.getMethod(schema.name()); if (!self_schema) { if (why_not) { - *why_not << "Interface '" << lhs.python_str() + *why_not << "Interface '" << lhs.repr_str() << "' does not have method '" << schema.name() << "' but interface '" - << rhs.python_str() << "' does.\n"; + << rhs.repr_str() << "' does.\n"; } return false; } if (!self_schema->isSubtypeOf(schema, /*is_method=*/true, why_not)) { if (why_not) { - *why_not << "Method on interface '" << lhs.python_str() + *why_not << "Method on interface '" << lhs.repr_str() << "' (1) is not compatible with interface '" - << rhs.python_str() << "' (2)\n" + << rhs.repr_str() << "' (2)\n" << " (1) " << *self_schema << "\n" << " (2) " << schema << "\n"; return false; @@ -1178,7 +1178,7 @@ void ClassType::checkNotExist(const std::string& name, const std::string& what) " '", name, "' to ", - python_str(), + repr_str(), " but a constant field of the same name already exists with value ", constantValues_[i]); } @@ -1192,9 +1192,9 @@ void ClassType::checkNotExist(const std::string& name, const std::string& what) " '", name, "' to ", - python_str(), + repr_str(), " but an attribute field of the same name already exists with type ", - attributes_[i].getType()->python_str()); + attributes_[i].getType()->repr_str()); } } @@ -1264,7 +1264,7 @@ IValue ClassType::getConstant(const std::string& name) const { const auto& v = findConstant(name); TORCH_CHECK( v.has_value(), - python_str(), + repr_str(), " does not have a constant field with name '", name, "'"); @@ -1275,7 +1275,7 @@ IValue ClassType::getConstant(size_t slot) const { TORCH_INTERNAL_ASSERT(constantNames_.size() == constantValues_.size()); TORCH_CHECK( slot < constantValues_.size(), - python_str(), + repr_str(), " does not have a constant slot of index ", slot); return constantValues_[slot]; @@ -1336,9 +1336,9 @@ void checkNoAny(const Type& base, const char* what, const std::string& attrname, " '", attrname, "' of type ", - attrtype->python_str(), + attrtype->repr_str(), " to '", - base.python_str(), + base.repr_str(), "' but it contains an Any type. Any types cannot be members of modules, classes, or named tuples."); } diff --git a/aten/src/ATen/test/type_test.cpp b/aten/src/ATen/test/type_test.cpp index eb0c361fa0da..c5959bdfeb8c 100644 --- a/aten/src/ATen/test/type_test.cpp +++ b/aten/src/ATen/test/type_test.cpp @@ -19,12 +19,12 @@ TEST(TypeCustomPrinter, Basic) { // Tensor types should be rewritten torch::Tensor iv = torch::rand({2, 3}); const auto type = TensorType::create(iv); - EXPECT_EQ(type->python_str(), "Tensor"); - EXPECT_EQ(type->python_str(printer), "CustomTensor"); + EXPECT_EQ(type->annotation_str(), "Tensor"); + EXPECT_EQ(type->annotation_str(printer), "CustomTensor"); // Unrelated types shoudl not be affected const auto intType = IntType::create(); - EXPECT_EQ(intType->python_str(printer), intType->python_str()); + EXPECT_EQ(intType->annotation_str(printer), intType->annotation_str()); } TEST(TypeCustomPrinter, ContainedTypes) { @@ -40,14 +40,14 @@ TEST(TypeCustomPrinter, ContainedTypes) { // Contained types should work const auto tupleType = TupleType::create({type, IntType::get(), type}); - EXPECT_EQ(tupleType->python_str(), "Tuple[Tensor, int, Tensor]"); + EXPECT_EQ(tupleType->annotation_str(), "Tuple[Tensor, int, Tensor]"); EXPECT_EQ( - tupleType->python_str(printer), "Tuple[CustomTensor, int, CustomTensor]"); + tupleType->annotation_str(printer), "Tuple[CustomTensor, int, CustomTensor]"); const auto dictType = DictType::create(IntType::get(), type); - EXPECT_EQ(dictType->python_str(printer), "Dict[int, CustomTensor]"); + EXPECT_EQ(dictType->annotation_str(printer), "Dict[int, CustomTensor]"); const auto listType = ListType::create(tupleType); EXPECT_EQ( - listType->python_str(printer), + listType->annotation_str(printer), "List[Tuple[CustomTensor, int, CustomTensor]]"); } @@ -67,11 +67,11 @@ TEST(TypeCustomPrinter, NamedTuples) { const auto namedTupleType = TupleType::createNamed( "my.named.tuple", {"foo", "bar"}, {type, IntType::get()}); - EXPECT_EQ(namedTupleType->python_str(printer), "Rewritten"); + EXPECT_EQ(namedTupleType->annotation_str(printer), "Rewritten"); // Put it inside another tuple, should still work const auto outerTupleType = TupleType::create({IntType::get(), namedTupleType}); - EXPECT_EQ(outerTupleType->python_str(printer), "Tuple[int, Rewritten]"); + EXPECT_EQ(outerTupleType->annotation_str(printer), "Tuple[int, Rewritten]"); } static TypePtr importType( diff --git a/test/cpp/jit/test_jit_type.cpp b/test/cpp/jit/test_jit_type.cpp index f5dd613791e7..16c69ccd05fd 100644 --- a/test/cpp/jit/test_jit_type.cpp +++ b/test/cpp/jit/test_jit_type.cpp @@ -27,7 +27,7 @@ void testUnifyTypes() { TORCH_INTERNAL_ASSERT(out); std::stringstream ss; - ss << (*out)->python_str(); + ss << (*out)->annotation_str(); testing::FileCheck() .check("Optional[Tuple[Optional[int], Optional[int]]]") ->run(ss.str()); diff --git a/test/cpp/jit/test_mobile_type_parser.cpp b/test/cpp/jit/test_mobile_type_parser.cpp index b1700c534f64..989d16794bd2 100644 --- a/test/cpp/jit/test_mobile_type_parser.cpp +++ b/test/cpp/jit/test_mobile_type_parser.cpp @@ -14,20 +14,20 @@ void testMobileTypeParser() { std::string int_ps("int"); auto int_tp = c10::parseType(int_ps); - std::string int_tps = int_tp->python_str(); + std::string int_tps = int_tp->annotation_str(); ASSERT_EQ(int_ps, int_tps); std::string tuple_ps( "Tuple[str, Optional[float], Dict[str, List[Tensor]], int]"); auto tuple_tp = c10::parseType(tuple_ps); - std::string tuple_tps = tuple_tp->python_str(); + std::string tuple_tps = tuple_tp->annotation_str(); ASSERT_EQ(tuple_ps, tuple_tps); std::string tuple_space_ps( "Tuple[ str, Optional[float], Dict[str, List[Tensor ]] , int]"); auto tuple_space_tp = c10::parseType(tuple_space_ps); // tuple_space_tps should not have weird white spaces - std::string tuple_space_tps = tuple_space_tp->python_str(); + std::string tuple_space_tps = tuple_space_tp->annotation_str(); ASSERT_EQ(tuple_ps, tuple_space_tps); std::string typo_token("List[tensor]"); diff --git a/test/test_jit.py b/test/test_jit.py index db3f0367ff18..d6b44a979183 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -17623,7 +17623,8 @@ a") def foo(a): return a - with self.assertRaisesRegex(RuntimeError, "Inferred \'a\' to be of type \'Tensor"): + with self.assertRaisesRegex(RuntimeError, (r"Expected a value of type \'Tensor \(inferred\)\'" + r"[\S\s]*Inferred \'a\' to be of type \'Tensor\'")): foo(1) def test_type_comments_in_body(self): diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp index 739198fe51eb..125ed89aaf19 100644 --- a/torch/csrc/distributed/rpc/py_rref.cpp +++ b/torch/csrc/distributed/rpc/py_rref.cpp @@ -70,7 +70,7 @@ TypePtr tryInferTypeWithTypeHint( type_hint_ptr != nullptr && module.value().type()->isSubtypeOfExt( type_hint_ptr, &subtype_check_msg), - module.value().type()->python_str(), + module.value().type()->repr_str(), " is not a subtype of the type hint: ", type_qualified_name.qualifiedName(), ", did you pass a valid interface type?\n", diff --git a/torch/csrc/distributed/rpc/rref_context.cpp b/torch/csrc/distributed/rpc/rref_context.cpp index 48cc39f24c62..d1bbbe7d88c6 100644 --- a/torch/csrc/distributed/rpc/rref_context.cpp +++ b/torch/csrc/distributed/rpc/rref_context.cpp @@ -336,12 +336,12 @@ c10::intrusive_ptr RRefContext::getOrCreateOwnerRRef( TORCH_INTERNAL_ASSERT( ownerRRef->type()->isSubtypeOf(TensorType::get()), "Expect OwnerRRef to be a sub-type of TensorType, but got ", - ownerRRef->type()->python_str()); + ownerRRef->type()->repr_str()); } else { TORCH_INTERNAL_ASSERT( ownerRRef->type() == type, "OwnerRRef type is ", - ownerRRef->type()->python_str(), + ownerRRef->type()->repr_str(), ", expected type is ", type); } diff --git a/torch/csrc/jit/api/object.h b/torch/csrc/jit/api/object.h index 99c99cddeba8..305c254ad1c0 100644 --- a/torch/csrc/jit/api/object.h +++ b/torch/csrc/jit/api/object.h @@ -40,11 +40,11 @@ struct TORCH_API Object { TORCH_CHECK( v.type()->isSubtypeOf(expected), "Expected a value of type '", - expected->python_str(), + expected->repr_str(), "' for field '", name, "', but found '", - v.type()->python_str(), + v.type()->repr_str(), "'"); _ivalue()->setSlot(*slot, std::move(v)); } else { @@ -61,7 +61,7 @@ struct TORCH_API Object { } TORCH_CHECK( false, - _ivalue()->type()->python_str(), + _ivalue()->type()->repr_str(), " does not have a field with name '", name, "'"); diff --git a/torch/csrc/jit/frontend/concrete_module_type.cpp b/torch/csrc/jit/frontend/concrete_module_type.cpp index 4243d1cb8218..8613f04037b3 100644 --- a/torch/csrc/jit/frontend/concrete_module_type.cpp +++ b/torch/csrc/jit/frontend/concrete_module_type.cpp @@ -258,13 +258,13 @@ void ConcreteModuleType::dump() const { } std::cout << "\nAttributes: \n"; for (const auto& pr : data_.attributes_) { - std::cout << "\t" << pr.key() << ": " << pr.value().type_->python_str() + std::cout << "\t" << pr.key() << ": " << pr.value().type_->annotation_str() << "\n"; } std::cout << "\nSubmodules: \n"; for (const auto& info : data_.modules_) { std::cout << "\t" << info.name_ << ": " - << info.meta_->getJitType()->python_str() << "\n"; + << info.meta_->getJitType()->annotation_str() << "\n"; } std::cout << "\nOverloads: \n"; for (const auto& pr : data_.overloads_) { @@ -273,7 +273,7 @@ void ConcreteModuleType::dump() const { std::string isPoisoned = data_.isPoisoned_ ? "true" : "false"; std::cout << "isPoisoned: " << isPoisoned << "\n"; if (jitType_) { - std::cout << "jit type: " << jitType_->python_str() << "\n"; + std::cout << "jit type: " << jitType_->annotation_str() << "\n"; } } diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index c17d148bb0d4..347044909938 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -374,9 +374,9 @@ struct Environment { if (!as_simple_value->type()->isSubtypeOfExt(parent_type, &why_not)) { auto error = ErrorReport(loc); error << "Variable '" << name << "' previously has type " - << simple_parent->type()->python_str() + << simple_parent->type()->repr_str() << " but is now being assigned to a value of type " - << as_simple_value->type()->python_str(); + << as_simple_value->type()->repr_str(); // Special-cased error msg if we're trying to assign to a tensor list. if (simple_parent->type()->kind() == TypeKind::ListType && @@ -397,9 +397,9 @@ struct Environment { if (!as_simple_value->type()->isSubtypeOf(annotated_type)) { throw ErrorReport(loc) << "Variable '" << name << "' is annotated with type " - << annotated_type->python_str() + << annotated_type->repr_str() << " but is being assigned to a value of type " - << as_simple_value->type()->python_str(); + << as_simple_value->type()->repr_str(); } insertStore(name, loc, std::move(as_simple_value), annotated_type); } else { @@ -972,8 +972,8 @@ struct to_ir { if (!result->type()->isSubtypeOf(result_type)) { throw ErrorReport(stmt.range()) << "Return value was annotated as having type " - << result_type->python_str() << " but is actually of type " - << result->type()->python_str(); + << result_type->repr_str() << " but is actually of type " + << result->type()->repr_str(); } } else { result_type = def_stack_.back().merged_return_type_; @@ -984,9 +984,9 @@ struct to_ir { if (!merged_result_type) { throw ErrorReport(stmt.range()) << "Previous return statement returned a value of type " - << result_type->python_str() + << result_type->repr_str() << " but this return statement returns a value of type " - << result->type()->python_str(); + << result->type()->repr_str(); } result_type = merged_result_type.value(); } @@ -1208,7 +1208,7 @@ struct to_ir { throw ErrorReport(loc) << "Expected list type annotation for list comprehension" ", found " - << type_hint->python_str(); + << type_hint->repr_str(); } list_value->setType(type_hint); type_set = true; @@ -1326,8 +1326,8 @@ struct to_ir { auto unified = unifyTypes(true_type, false_type); if (!unified) { throw ErrorReport(range) - << "if-expression's true branch has type " << true_type->python_str() - << " but false branch has type " << false_type->python_str(); + << "if-expression's true branch has type " << true_type->repr_str() + << " but false branch has type " << false_type->repr_str(); } // Add op outputs @@ -1342,13 +1342,13 @@ struct to_ir { out = asSimple(bool_cast->call(loc, method, {v}, {}, 0)); } catch (...) { throw ErrorReport(loc) << "Could not cast value of type " - << v->type()->python_str() << " to bool"; + << v->type()->repr_str() << " to bool"; } // cast value not response for checking output type if (!out->type()->isSubtypeOf(BoolType::get())) { throw ErrorReport(loc) << "expected a bool expression for condition but found " - << out->type()->python_str(); + << out->type()->repr_str(); } return out; } @@ -1507,8 +1507,8 @@ struct to_ir { if (!unified) { ErrorReport error(loc); error << "Type mismatch: " << x << " is set to type " - << tv->type()->python_str() << " in the true branch" - << " and type " << fv->type()->python_str() + << tv->type()->repr_str() << " in the true branch" + << " and type " << fv->type()->repr_str() << " in the false branch"; if (save_true->findInParentFrame(x) || save_false->findInParentFrame(x)) { @@ -1922,7 +1922,7 @@ struct to_ir { magic_method_name = out_of_place_method_name; } else { throw ErrorReport(stmt.range()) - << "Cannot emit inplace op on " << type->python_str() + << "Cannot emit inplace op on " << type->repr_str() << " since it does not define an " << in_place_method_name << " or " << out_of_place_method_name << " method"; } @@ -1954,7 +1954,7 @@ struct to_ir { const TypePtr type = sliceable->type(); if (subscriptExprs.size() != 1) { throw ErrorReport(subscriptExprs) - << "Sliced expression not yet supported for " << type->python_str() + << "Sliced expression not yet supported for " << type->repr_str() << " augmented assignment. " << "File a bug if you want this"; } @@ -1968,7 +1968,7 @@ struct to_ir { if (elemType == nullptr) { throw ErrorReport(lhs) - << type->python_str() << " does not support augmented assignment."; + << type->repr_str() << " does not support augmented assignment."; } const auto idxValue = emitExpr(subscriptExprs[0]); const auto containerArg = @@ -2541,8 +2541,8 @@ struct to_ir { std::stringstream why_not; if (!expr->type()->isSubtypeOfExt(type, &why_not)) { throw ErrorReport(apply.inputs()) - << "expected an expression of type " << type->python_str() - << " but found " << expr->type()->python_str() << "\n" + << "expected an expression of type " << type->repr_str() + << " but found " << expr->type()->repr_str() << "\n" << why_not.str(); } @@ -3050,7 +3050,7 @@ struct to_ir { // If the type hint was not a List[T] throw an error throw ErrorReport(tree) << "Expected a List type hint but instead got " - << type_hint->python_str(); + << type_hint->repr_str(); } } else if (!values.empty()) { std::stringstream ss; @@ -3068,8 +3068,8 @@ struct to_ir { if (!v->type()->isSubtypeOfExt(elem_type, &ss)) { throw ErrorReport(tree) << "Lists must contain only a single type, expected: " - << elem_type->python_str() << " but found " - << v->type()->python_str() << " instead.\n" + << elem_type->repr_str() << " but found " + << v->type()->repr_str() << " instead.\n" << ss.str(); } } @@ -3120,8 +3120,8 @@ struct to_ir { throw ErrorReport(trees[i]) << "Dict " << what << " must contain only a single type, expected: " - << type->python_str() << " but found " - << values[i]->type()->python_str() << " instead.\n" + << type->repr_str() << " but found " + << values[i]->type()->repr_str() << " instead.\n" << ss.str(); } } @@ -3329,7 +3329,7 @@ struct to_ir { } else { throw ErrorReport(loc) << "Unsupported operation: indexing tensor with unsupported index type '" - << index->type()->python_str() + << index->type()->repr_str() << "'. Only ints, slices, lists and tensors are supported"; } }; @@ -3485,7 +3485,7 @@ struct to_ir { if (elems.size() == 0 || !convertibleToList(tuple_typ, ListType::create(elems[0]))) { throw ErrorReport(loc) - << "Cannot index into a " << tuple_typ->python_str() + << "Cannot index into a " << tuple_typ->repr_str() << " with a non-integer literal because we cannot resolve the output type"; } output_type = elems[0]; diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index fd08f8b02252..4ac10e99e8d0 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -148,8 +148,8 @@ static Value* tryMatchArgument( matchTypeVariables(arg.type(), value->type(), type_env); if (!matched.success()) { if (failure_messages) { - err() << "Could not match type " << value->type()->python_str() << " to " - << arg.type()->python_str() << " in argument '" << arg.name() + err() << "Could not match type " << value->type()->repr_str() << " to " + << arg.type()->repr_str() << " in argument '" << arg.name() << "': " << matched.reason() << ".\n"; } return nullptr; @@ -157,9 +157,9 @@ static Value* tryMatchArgument( const auto concrete_type = tryEvalTypeVariables(arg.type(), type_env); if (!concrete_type) { if (failure_messages) { - err() << "Type variables in type " << arg.type()->python_str() + err() << "Type variables in type " << arg.type()->repr_str() << " could not be inferred from actual type " - << value->type()->python_str(); + << value->type()->repr_str(); } return nullptr; } @@ -172,7 +172,7 @@ static Value* tryMatchArgument( concrete_type, /*why_not=*/(failure_messages) ? &ss : nullptr)) { if (failure_messages) { auto& ostream = err() - << arg.formatTypeMismatchMsg(value->type()->python_str()); + << arg.formatTypeMismatchMsg(value->type()->repr_str()); if (auto pt = value->type()->cast()) { if (pt->isInferredType()) { @@ -414,7 +414,7 @@ static c10::optional tryMatchSchema( auto return_types = fmap(returns, [&](const Argument& r) { TypePtr result = tryEvalTypeVariables(r.type(), type_env); TORCH_INTERNAL_ASSERT( - result, r.type()->python_str(), " has unbound type variables."); + result, r.type()->repr_str(), " has unbound type variables."); return result; }); // Codegen does not support return of namedtuples with undefined field names. diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp index 98fa631b9562..da77d325da82 100644 --- a/torch/csrc/jit/frontend/sugared_value.cpp +++ b/torch/csrc/jit/frontend/sugared_value.cpp @@ -70,7 +70,7 @@ bool SimpleValue::hasAttr( auto class_type = value_->type()->cast(); if (!class_type) { throw ErrorReport(loc) << "hasattr's first argument must be an object, got " - << value_->type()->python_str() << " instead"; + << value_->type()->repr_str() << " instead"; } return class_type->hasMethod(field) || class_type->hasAttribute(field) || @@ -176,7 +176,7 @@ std::shared_ptr SimpleValue::attr( ErrorReport report(loc); report << "Tried to access nonexistent attribute or method '" << field - << "' of type '" << value_->type()->python_str() << "'."; + << "' of type '" << value_->type()->repr_str() << "'."; if (value_->type()->kind() == ClassType::Kind) { report << " Did you forget to initialize an attribute in __init__()?"; } @@ -205,7 +205,7 @@ std::vector> SimpleValue::asTuple( graph->insertNode(graph->createListUnpack(value_, *size_hint)); return fmap(unpack->outputs(), make_simple_value); } - throw ErrorReport(loc) << value_->type()->python_str() + throw ErrorReport(loc) << value_->type()->repr_str() << " cannot be used as a tuple"; } @@ -232,8 +232,7 @@ void SimpleValue::setAttr( const auto classType = value_->type()->cast(); if (!classType) { throw ErrorReport(loc) << "Tried to set an attribute: " << field - << " on a non-class: " - << value_->type()->python_str(); + << " on a non-class: " << value_->type()->repr_str(); } auto expectedType = classType->findAttribute(field); if (!expectedType) { @@ -255,7 +254,7 @@ void SimpleValue::setAttr( throw ErrorReport(loc) << "Assignment to attribute '" << field << "' cannot be of a type that contains class " - << "'" << classType->python_str() << "'.\n" + << "'" << classType->repr_str() << "'.\n" << "Classes that recursively contain instances of themselves" << " are not yet supported"; } @@ -283,8 +282,8 @@ void SimpleValue::setAttr( const auto newType = newValue->type(); if (!newType->isSubtypeOf(expectedType)) { throw ErrorReport(loc) << "Wrong type for attribute assignment. Expected " - << expectedType->python_str() << " but got " - << newType->python_str(); + << expectedType->repr_str() << " but got " + << newType->repr_str(); } auto& g = *m.graph(); @@ -341,7 +340,7 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) { val_type->isSubtypeOf(TensorType::get())) { return g.insert(aten::len, {val}, {}, loc); } else { - throw ErrorReport(loc) << "'" << val_type->python_str() << "'" + throw ErrorReport(loc) << "'" << val_type->repr_str() << "'" << " object is not iterable"; } } @@ -367,7 +366,7 @@ SugaredValuePtr SimpleValue::getitem( } else if (auto class_type = val_type->cast()) { return attr(loc, m, "__getitem__")->call(loc, m, {idx}, {}, 1); } else { - throw ErrorReport(loc) << "'" << val_type->python_str() << "'" + throw ErrorReport(loc) << "'" << val_type->repr_str() << "'" << " object is not subscriptable"; } } @@ -393,7 +392,7 @@ SugaredValuePtr SimpleValue::iter(const SourceRange& loc, Function& m) { } return std::make_shared(tup_sugared); } else { - throw ErrorReport(loc) << "'" << type->python_str() << "'" + throw ErrorReport(loc) << "'" << type->repr_str() << "'" << " object is not iterable"; } } @@ -407,7 +406,7 @@ RangeValue::RangeValue( auto typ = inputs[i]->type(); if (!typ->cast()) { throw ErrorReport(loc) - << "all inputs of range must be ints, found " << typ->python_str() + << "all inputs of range must be ints, found " << typ->repr_str() << " in argument " << c10::guts::to_string(i); } } diff --git a/torch/csrc/jit/frontend/sugared_value.h b/torch/csrc/jit/frontend/sugared_value.h index 00a7ae3e421d..790aa471d56f 100644 --- a/torch/csrc/jit/frontend/sugared_value.h +++ b/torch/csrc/jit/frontend/sugared_value.h @@ -147,7 +147,7 @@ struct TORCH_API SimpleValue : public SugaredValue { SimpleValue(Value* value) : value_(value) {} std::string kind() const override { std::stringstream ss; - ss << "value of type '" << value_->type()->python_str() << "'"; + ss << "value of type '" << value_->type()->annotation_str() << "'"; return ss.str(); } Value* asValue(const SourceRange& range, Function& m) override { diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index 0a490501b755..4740c7c30710 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -367,7 +367,7 @@ static IValue addInput( AT_ERROR( "Only tensors or (possibly nested) dict or tuples of tensors can be " "inputs to traced functions. Got ", - type->python_str()); + type->repr_str()); } } diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index f78f307783fa..f97b4137e10c 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -1081,9 +1081,9 @@ void AliasDb::replaceWithNewValue(Value* existing, Value* new_value) { *unshapedType(existing->type()) == *unshapedType(new_value->type()), "Types must be strictly equal if you are replacing aliasing information. ", "Got existing: '", - existing->type()->python_str(), + existing->type()->repr_str(), "', new_value: '", - new_value->type()->python_str(), + new_value->type()->repr_str(), "'"); if (!isMutableTypeInternal(existing)) { return; @@ -1099,9 +1099,9 @@ void AliasDb::copyValue(Value* from, Value* to) { *unshapedType(from->type()) == *unshapedType(to->type()), "Types must be strictly equal if you are copying aliasing information. ", "Got from: '", - from->type()->python_str(), + from->type()->repr_str(), "', to: '", - to->type()->python_str(), + to->type()->repr_str(), "'"); if (!isMutableTypeInternal(to)) { return; @@ -1557,8 +1557,8 @@ void Lint(const AliasDb* db) { auto it = db->elementMap_.find(v); if (it == db->elementMap_.end()) { failed = true; - ss << "Value %" << v->debugName() << " of type " - << v->type()->python_str() << " wasn't found in the element map.\n" + ss << "Value %" << v->debugName() << " of type " << v->type()->repr_str() + << " wasn't found in the element map.\n" << "It was defined in " << *v->node(); } } diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index b5232eb47be6..628330971e1f 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1599,9 +1599,9 @@ Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef values) { TORCH_CHECK( v->type()->isSubtypeOf(elem_type), "Expected a list element that subtypes '", - elem_type->python_str(), + elem_type->repr_str(), "' but got an element of type '", - v->type()->python_str(), + v->type()->repr_str(), "'"); } n->output()->setType(ListType::create(elem_type)); @@ -1714,7 +1714,7 @@ Value* Graph::insertToList(Value* v, TypePtr type) { } else { TORCH_CHECK( false, - ptr->python_str(), + ptr->repr_str(), " is not one of the supported element types for tolist: int, float, bool"); } diff --git a/torch/csrc/jit/ir/irparser.cpp b/torch/csrc/jit/ir/irparser.cpp index 3e93c3c538ea..ca6ab74741de 100644 --- a/torch/csrc/jit/ir/irparser.cpp +++ b/torch/csrc/jit/ir/irparser.cpp @@ -368,10 +368,9 @@ void IRParser::parseOperator(Block* b) { if (!schema_return_type->hasFreeVariables() && !v.type->isSubtypeOf(schema_return_type)) { throw ErrorReport(source_range) - << "Annotated type " << v.type->python_str() + << "Annotated type " << v.type->repr_str() << " does not match schema type " - << schema_return_type->python_str() << " for operator " - << *schema; + << schema_return_type->repr_str() << " for operator " << *schema; } vmap[v.name]->setType(v.type); } diff --git a/torch/csrc/jit/passes/lower_graph.cpp b/torch/csrc/jit/passes/lower_graph.cpp index c19aee6fb0b0..50ee4856bd0a 100644 --- a/torch/csrc/jit/passes/lower_graph.cpp +++ b/torch/csrc/jit/passes/lower_graph.cpp @@ -136,7 +136,7 @@ static std::vector loadTensors(const std::vector& slots) { getCustomClass( "__torch__.torch.classes.quantized.LinearPackedParamsBase")), "Unknown type ", - type->python_str(), + type->repr_str(), " encountered in graph lowering. This type is not supported in ONNX export."); result.emplace_back( script::Object(obj.toObject()).run_method("__getstate__")); diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index ace1b3aec06d..87a534d67ada 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -351,9 +351,9 @@ inline InferredType tryToInferContainerType(py::handle input) { if (!unified_key) { return InferredType(c10::str( "Dictionary inputs to traced functions must have consistent type. Found ", - key_type->python_str(), + key_type->repr_str(), " and ", - (entry_key_type_match.type())->python_str())); + (entry_key_type_match.type())->repr_str())); } // Try to infer the value type and unify it with the existing one @@ -366,9 +366,9 @@ inline InferredType tryToInferContainerType(py::handle input) { if (!unified_value) { return InferredType(c10::str( "Dictionary inputs to traced functions must have consistent type. Found ", - value_type->python_str(), + value_type->repr_str(), " and ", - (entry_value_type_match.type())->python_str())); + (entry_value_type_match.type())->repr_str())); } key_type = *unified_key; @@ -395,9 +395,9 @@ inline InferredType tryToInferContainerType(py::handle input) { if (!unified_type) { return InferredType(c10::str( "List inputs to traced functions must have consistent element type. Found ", - element_type->python_str(), + element_type->repr_str(), " and ", - (element_type_match.type())->python_str())); + (element_type_match.type())->repr_str())); } element_type = *unified_type; } @@ -453,7 +453,7 @@ inline Stack toTraceableStack(const py::tuple& inputs) { TORCH_CHECK( isTraceableType(info.type()), "Type '", - info.type()->python_str(), + info.type()->repr_str(), "' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and" " Tuples of Tensors can be traced"); return info.toTuple()->elements(); @@ -542,7 +542,7 @@ inline IValue toIValue( "Object ", py::str(obj), " had a different number of elements than type ", - type->python_str())); + type->repr_str())); } std::vector values; values.reserve(tuple_size); @@ -669,7 +669,7 @@ inline IValue toIValue( "Object ", py::str(obj), " is not compatible with interface ", - interfaceType->python_str(), + interfaceType->repr_str(), "\n", why_not.str())); } @@ -694,7 +694,7 @@ inline IValue toIValue( return py::cast(obj); } else { throw py::cast_error( - c10::str("Cannot cast ", py::str(obj), " to ", type->python_str())); + c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str())); } } case TypeKind::RRefType: { @@ -726,7 +726,7 @@ inline IValue toIValue( break; } throw py::cast_error(c10::str( - "toIValue() cannot handle converting to type: ", type->python_str())); + "toIValue() cannot handle converting to type: ", type->repr_str())); } // Small wrapper around getting the type name string from Python to make diff --git a/torch/csrc/jit/python/python_custom_class.cpp b/torch/csrc/jit/python/python_custom_class.cpp index 7afb8577ab7d..7bc6b4e59f2f 100644 --- a/torch/csrc/jit/python/python_custom_class.cpp +++ b/torch/csrc/jit/python/python_custom_class.cpp @@ -18,7 +18,7 @@ py::object ScriptClass::__call__(py::args args, py::kwargs kwargs) { fmt::format( "Custom C++ class: '{}' does not have an '__init__' method bound. " "Did you forget to add '.def(torch::init<...>)' to its registration?", - instance.type()->python_str())); + instance.type()->repr_str())); Method init_method(instance._ivalue(), init_fn); invokeScriptMethodFromPython(init_method, std::move(args), std::move(kwargs)); return py::cast(instance); diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index 33008c6d11ac..81b4807ae58f 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -641,7 +641,7 @@ void initPythonIRBindings(PyObject* module_) { using ::c10::Type; py::class_>(m, "Type") - .def("__repr__", [](Type& t) { return t.python_str(); }) + .def("__repr__", [](Type& t) { return t.annotation_str(); }) .def( "str", [](Type& t) { diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index c086dca06c12..068b2e7818a5 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -671,7 +671,7 @@ TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc) { TORCH_CHECK( type->isSubtypeOf(tt), "Can't to redefine NamedTuple: ", - tt->python_str()); + tt->repr_str()); return type; } get_python_cu()->register_type(tt); diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index a611a2e673b1..7ee84944f464 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -260,7 +260,7 @@ FunctionSchema getSchemaWithNameAndDefaults( c10::optional value = tryCalculateDefaultParam(arg, it->second); if (!value) { ErrorReport error(range); - error << "Expected a default value of type " << arg.type()->python_str() + error << "Expected a default value of type " << arg.type()->repr_str() << " on parameter \"" << arg.name() << "\"."; if (arg.is_inferred_type()) { error << "Because \"" << arg.name() @@ -799,7 +799,7 @@ void initJitScriptBindings(PyObject* module) { TORCH_INTERNAL_ASSERT( setstate_schema.arguments().size() == 2, "__setstate__ method for class ", - class_type->python_str(), + class_type->repr_str(), " must have exactly 2 arguments!"); auto state_type = setstate_schema.arguments().at(1).type(); (*setstate_method)(Stack{toIValue(state, state_type)}); diff --git a/torch/csrc/jit/runtime/register_ops_utils.cpp b/torch/csrc/jit/runtime/register_ops_utils.cpp index c932376f2b88..c303956dadfe 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.cpp +++ b/torch/csrc/jit/runtime/register_ops_utils.cpp @@ -158,7 +158,7 @@ IValue tensorToListRecursive( } else { TORCH_CHECK( false, - ty->python_str(), + ty->repr_str(), " is not one of the supported types for tolist: int, float, bool"); } } diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp index 711725c96960..088b218b6866 100644 --- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp @@ -1319,16 +1319,16 @@ Function* checkSortSchema(const c10::TypePtr& list_element_type) { return method; } } - error_str << "To sort a list of " << class_type->python_str() + error_str << "To sort a list of " << class_type->repr_str() << " it must define a " << "__lt__ method with two inputs of type " - << class_type->python_str() << " that " + << class_type->repr_str() << " that " << "returns a bool"; } else { - error_str << "To sort a list of " << list_element_type->python_str() + error_str << "To sort a list of " << list_element_type->repr_str() << " must be of Tensors, ints, floats, bools or " << "a User Defined Class that defines the __lt__ compare method" - << ", got list of " << list_element_type->python_str() << "\n"; + << ", got list of " << list_element_type->repr_str() << "\n"; } throw std::runtime_error(error_str.str()); } diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp index 4ffd3f3865b8..244e24b941a0 100644 --- a/torch/csrc/jit/runtime/register_special_ops.cpp +++ b/torch/csrc/jit/runtime/register_special_ops.cpp @@ -31,7 +31,7 @@ void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) { elem_type != BoolType::get()) { std::stringstream error; error << "Input must be of ints, floats, or bools, " - << "got " << elem_type->python_str(); + << "got " << elem_type->repr_str(); // special case empty list torch.tensor([]) if (elem_type->isSubtypeOf(TensorType::get())) { if (empty_list) { @@ -220,11 +220,11 @@ int createTensorFromList(Stack& stack) { tensor.numel() == 0) { TORCH_WARN( "Creating a tensor from an empty ", - elem_type->python_str(), + elem_type->repr_str(), "list will create a tensor of default floating point type (currently ", default_type, ") in python but a tensor of type ", - elem_type->python_str(), + elem_type->repr_str(), " in torchscript.\n", "Pass in a dtype argument to ensure consistent behavior"); } diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index 0f337f069643..d26987f0cf6f 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -111,7 +111,7 @@ c10::IValue getFunctionTuple(const Function& func) { std::vector types; types.reserve(code.type_table().size()); for (const TypePtr& t : code.type_table()) { - types.emplace_back(t->python_str()); + types.emplace_back(t->annotation_str()); } // since the register location is embedded into the bytecode, pass the diff --git a/torch/csrc/jit/serialization/import.cpp b/torch/csrc/jit/serialization/import.cpp index d6ed2b2704ad..de22957bb1bb 100644 --- a/torch/csrc/jit/serialization/import.cpp +++ b/torch/csrc/jit/serialization/import.cpp @@ -50,7 +50,7 @@ void postSetStateValidate(const IValue& v) { "The field '{}' was left uninitialized after '__setstate__', " "but expected a value of type '{}'", attrName, - attrType->python_str())); + attrType->repr_str())); } } } diff --git a/torch/csrc/jit/serialization/import_legacy.cpp b/torch/csrc/jit/serialization/import_legacy.cpp index 69a59ba01e8a..df6041f01ddc 100644 --- a/torch/csrc/jit/serialization/import_legacy.cpp +++ b/torch/csrc/jit/serialization/import_legacy.cpp @@ -356,7 +356,7 @@ Module ScriptModuleDeserializer::LEGACY_convertModule( module_type->getAttributeName(i), "' was left unitialized after __setstate__, but expected a ", "value of type '", - v.type()->python_str(), + v.type()->repr_str(), "'"); } } diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 83dc84bcef13..552ff8c73952 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -497,7 +497,7 @@ void Pickler::endTypeTag(const IValue& ivalue) { // Push the dict type TORCH_INTERNAL_ASSERT(ivalue.type()); - pushString(ivalue.type()->python_str()); + pushString(ivalue.type()->annotation_str()); // Pop the dict and type into a tuple push(PickleOpCode::TUPLE2); @@ -658,7 +658,7 @@ bool checkHasValidSetGetState(const std::shared_ptr& cls) { TORCH_CHECK( set_schema.returns().at(0).type()->isSubtypeOf(NoneType::get()), "'__setstate__' must return None, but found value of type", - set_schema.returns().at(0).type()->python_str()); + set_schema.returns().at(0).type()->annotation_str()); // Check that the return type of __getstate__ matches the input to // __setstate__ @@ -668,9 +668,9 @@ bool checkHasValidSetGetState(const std::shared_ptr& cls) { TORCH_CHECK( get_type->isSubtypeOf(set_type), "'__getstate__'s return type (", - get_type->python_str(), + get_type->annotation_str(), ") does not match '__setstate__'s argument type (", - set_type->python_str(), + set_type->annotation_str(), ")"); return true; diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index 080f20838e3c..8ceab37d20dd 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -500,7 +500,7 @@ struct PythonPrintImpl { indent(); body_ << useOf(lhs[i]); if (requiresAnnotation(lhs[i], rhs[i])) { - body_ << ": " << lhs[i]->type()->python_str(type_printer_); + body_ << ": " << lhs[i]->type()->annotation_str(type_printer_); } body_ << " = " << useOf(rhs[i]) << "\n"; } @@ -769,7 +769,7 @@ struct PythonPrintImpl { if (i > 0) { body_ << ", "; } - body_ << useOf(v) << ": " << v->type()->python_str(type_printer_); + body_ << useOf(v) << ": " << v->type()->annotation_str(type_printer_); } body_ << "):\n"; printBody(graph->block()); @@ -804,7 +804,7 @@ struct PythonPrintImpl { if (v.isTuple() && v.type()->expect()->schema()) { // print the namedtuple constructor and let rest of tuple printing // continue - ss << v.type()->expect()->python_str(type_printer_); + ss << v.type()->expect()->annotation_str(type_printer_); } return false; }; @@ -857,14 +857,14 @@ struct PythonPrintImpl { } break; case prim::Uninitialized: { stmt << "uninitialized(" - << node->output()->type()->python_str(type_printer_) << ")"; + << node->output()->type()->annotation_str(type_printer_) << ")"; } break; case prim::Constant: { if (node->outputs().size() == 1 && node->output()->type()->kind() == TypeKind::FunctionType) { auto fn = node->output()->type()->expect(); registerDependency(fn); - stmt << fn->python_str(type_printer_); + stmt << fn->annotation_str(type_printer_); } else if (!node->mustBeNone()) { IValue v = toIValue(node->output()).value(); printConstant(stmt, v); @@ -875,8 +875,9 @@ struct PythonPrintImpl { case aten::ScalarImplicit: case aten::FloatImplicit: case aten::IntImplicit: { - stmt << "annotate(" << node->output()->type()->python_str(type_printer_) - << ", " << useOf(node->input()) << ")"; + stmt << "annotate(" + << node->output()->type()->annotation_str(type_printer_) << ", " + << useOf(node->input()) << ")"; } break; case aten::Int: { printValueList(stmt, node->inputs(), "int(", ")"); @@ -902,7 +903,7 @@ struct PythonPrintImpl { case prim::TupleConstruct: { if (auto qualname = node->output()->type()->expect()->name()) { - stmt << node->output()->type()->python_str(type_printer_); + stmt << node->output()->type()->annotation_str(type_printer_); } printValueList( stmt, node->inputs(), "(", node->inputs().size() == 1 ? ",)" : ")"); @@ -922,13 +923,14 @@ struct PythonPrintImpl { // what type is supposed to be inside them if (node->inputs().size() == 0) { stmt << "annotate(" - << node->output()->type()->python_str(type_printer_) << ", [])"; + << node->output()->type()->annotation_str(type_printer_) + << ", [])"; // If we can't infer the type based on what's inside, explicitly // annotate it to disambiguate. // This happens for List[Tensor] vs. List[Optional[Tensor]] } else if (!elementTypeCanBeInferredFromMembers(elem_type)) { stmt << "annotate(" - << node->output()->type()->python_str(type_printer_) << ", "; + << node->output()->type()->annotation_str(type_printer_) << ", "; printValueList(stmt, node->inputs(), "[", "]"); stmt << ")"; // Otherwise just print a list @@ -947,7 +949,7 @@ struct PythonPrintImpl { !elementTypeCanBeInferredFromMembers(dict_type->getKeyType()) || !elementTypeCanBeInferredFromMembers(dict_type->getValueType())) { stmt << "annotate(" - << node->output()->type()->python_str(type_printer_) << ", "; + << node->output()->type()->annotation_str(type_printer_) << ", "; printDict(stmt, node->inputs()); stmt << ")"; // Otherwise just print a dict @@ -957,8 +959,8 @@ struct PythonPrintImpl { } break; case prim::CreateObject: { const auto classType = node->output()->type()->expect(); - stmt << classType->python_str(type_printer_) << ".__new__(" - << classType->python_str(type_printer_) << ")"; + stmt << classType->annotation_str(type_printer_) << ".__new__(" + << classType->annotation_str(type_printer_) << ")"; } break; case prim::GetAttr: { const auto obj = node->inputs().at(0); @@ -1013,8 +1015,8 @@ struct PythonPrintImpl { if (node->input()->type()->isSubtypeOf(NoneType::get()) || node->input()->mustBeNone()) { auto input_type = OptionalType::create(node->output()->type()); - stmt << "annotate(" << input_type->python_str(type_printer_) << ", " - << useOf(node->input()) << ")"; + stmt << "annotate(" << input_type->annotation_str(type_printer_) + << ", " << useOf(node->input()) << ")"; } else { stmt << useOf(node->input()); } @@ -1027,14 +1029,14 @@ struct PythonPrintImpl { case prim::unchecked_unwrap_optional: case prim::unchecked_cast: { stmt << "unchecked_cast(" - << node->output()->type()->python_str(type_printer_) << ", " + << node->output()->type()->annotation_str(type_printer_) << ", " << useOf(node->input()) << ")"; } break; case prim::isinstance: { stmt << "isinstance(" << useOf(node->input()) << ", "; const auto& types = node->tys(attr::types); if (types.size() == 1) { - stmt << types.at(0)->python_str(type_printer_); + stmt << types.at(0)->annotation_str(type_printer_); } else { // check multiple things, e.g. (str, list, int) stmt << "("; @@ -1043,7 +1045,7 @@ struct PythonPrintImpl { if (!first) { stmt << ", "; } - stmt << typ->python_str(type_printer_); + stmt << typ->annotation_str(type_printer_); first = false; } stmt << ")"; @@ -1051,8 +1053,8 @@ struct PythonPrintImpl { stmt << ")"; } break; case prim::tolist: { - stmt << "annotate(" << node->output()->type()->python_str(type_printer_) - << ", "; + stmt << "annotate(" + << node->output()->type()->annotation_str(type_printer_) << ", "; stmt << useOf(node->input(0)) << ".tolist()" << ")"; } break; @@ -1172,11 +1174,11 @@ struct PythonPrintImpl { // the flag print_first_argument_type determines when to do this body_ << arg_name; if (print_first_argument_type) { - body_ << ": " << arg.type()->python_str(type_printer_); + body_ << ": " << arg.type()->annotation_str(type_printer_); } } else { body_ << ",\n " << arg_name << ": " - << arg.type()->python_str(type_printer_); + << arg.type()->annotation_str(type_printer_); } if (arg.default_value()) { printDefaultValue(arg, body_, *arg.default_value()); @@ -1184,7 +1186,8 @@ struct PythonPrintImpl { assignValue(*param_it++, arg_name); } - body_ << ") -> " << schema.returns().at(0).type()->python_str(type_printer_) + body_ << ") -> " + << schema.returns().at(0).type()->annotation_str(type_printer_) << ":\n"; printBody(graph.block()); } @@ -1275,12 +1278,12 @@ struct PythonPrintImpl { // Print out a direct manipulation of the annotations dict, like: // __annotations__["0"] = SomeType body_ << "__annotations__[" - << "\"" << name << "\"] = " << type->python_str(type_printer_) - << "\n"; + << "\"" << name + << "\"] = " << type->annotation_str(type_printer_) << "\n"; } else { // Otherwise: just emit a python 3 attribute annotation, like: // foo : SomeType - body_ << name << " : " << type->python_str(type_printer_) << "\n"; + body_ << name << " : " << type->annotation_str(type_printer_) << "\n"; } } @@ -1291,7 +1294,7 @@ struct PythonPrintImpl { indent(); body_ << name << " : " - << "Final[" << v.type()->python_str(type_printer_) << "] = "; + << "Final[" << v.type()->annotation_str(type_printer_) << "] = "; auto ss = std::make_shared(&source_range_stack_); printConstant(*ss, v); body_ << ss->str() << "\n"; @@ -1319,7 +1322,7 @@ struct PythonPrintImpl { TORCH_INTERNAL_ASSERT(attr.type()); indent(); body_ << attr.name() << " : " - << attr.type()->python_str(type_printer_) << "\n"; + << attr.type()->annotation_str(type_printer_) << "\n"; } } } else if (auto interfaceType = type->cast()) { @@ -1342,11 +1345,12 @@ struct PythonPrintImpl { auto type = arg.type(); registerClassDependencies(type); body_ << ", " << arg.name() << ": " - << type->python_str(type_printer_); + << type->annotation_str(type_printer_); } auto return_type = method.returns().at(0).type(); registerClassDependencies(return_type); - body_ << ") -> " << return_type->python_str(type_printer_) << ":\n"; + body_ << ") -> " << return_type->annotation_str(type_printer_) + << ":\n"; indent(); body_ << " pass\n"; } diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index 235c7ee5fafa..d85cce76ed00 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -149,7 +149,7 @@ void restoreContainerTypeTags(IValue& ivalue, TypePtr type) { } else if (auto list_type = type->cast()) { ivalue.toList().unsafeSetElementType(list_type->getElementType()); } else { - AT_ERROR("Unknown type for tag restoration: " + type->python_str()); + AT_ERROR("Unknown type for tag restoration: " + type->annotation_str()); } } diff --git a/torch/custom_class.h b/torch/custom_class.h index eb723071ef3e..6251450240e1 100644 --- a/torch/custom_class.h +++ b/torch/custom_class.h @@ -202,7 +202,7 @@ class class_ { TORCH_CHECK( *first_arg_type == *classTypePtr, "self argument of __getstate__ must be the custom class type. Got ", - first_arg_type->python_str()); + first_arg_type->repr_str()); TORCH_CHECK( getstate_schema.returns().size() == 1, "__getstate__ should return exactly one value for serialization. Got: ", @@ -214,9 +214,9 @@ class class_ { (*arg_type == *ser_type), "__setstate__'s argument should be the same type as the " "return value of __getstate__. Got ", - arg_type->python_str(), + arg_type->repr_str(), " but expected ", - ser_type->python_str()); + ser_type->repr_str()); return *this; }