[JIT] Add Type::repr_str to return human-readable str (#39544)

Summary:
Clearly expressing a type is inferred by PyTorch instead of explicitly annotated by user makes many error messages more user-friendly

Currently Type has two string conversion methods. str() for IR printing and python_str() for serialization and error message generation. If we want to include more information in type printing while maintaining serialization/deserialization correctness, we need to split python_str() into annotation_str() and repr_str().

annotation_str is solely responsible for serialization, it strictly matches format of python type annotation. repr_str() is responsible for generating a human-readable error message that includes information like "this type is inferred, not explicitly annotated"

Closes https://github.com/pytorch/pytorch/issues/39449
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39544

Differential Revision: D21978759

Pulled By: gmagogsfm

fbshipit-source-id: 733566f5a62e748b5ca4bb3c5943ebb6d5b664d0
This commit is contained in:
Yanan Cao
2020-06-10 11:59:01 -07:00
committed by Facebook GitHub Bot
parent 4e892bd99c
commit c22bbb2124
37 changed files with 238 additions and 224 deletions

View File

@ -75,7 +75,7 @@ struct Argument {
}
return c10::str(
"Expected a value of type '",
type()->python_str(),
type()->repr_str(),
"' for argument '",
name(),
"' but instead found type '",

View File

@ -190,7 +190,7 @@ inline void FunctionSchema::checkArg(
TORCH_CHECK(
false,
formatTypeMismatchMsg(
argument, value.type()->python_str(), pos));
argument, value.type()->repr_str(), pos));
}
}

View File

@ -18,7 +18,7 @@ bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs) {
namespace ivalue {
// This is in ivalue.cpp because we need to access Type::python_str, which
// This is in ivalue.cpp because we need to access Type::annotation_str, which
// is declared in jit_type.h
void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) {
// NB: doing pointer comparison here
@ -26,9 +26,9 @@ void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) {
// Type's, this needs to be changed!
TORCH_CHECK(actual_type == expected_type,
"Tried to convert an IValue of type ",
actual_type->python_str(),
actual_type->repr_str(),
" to custom class type ",
expected_type->python_str());
expected_type->repr_str());
}
CAFFE2_API c10::intrusive_ptr<ConstantString> ConstantString::create(
@ -291,7 +291,7 @@ std::ostream& printMaybeAnnotatedList(
auto list_elem_type = the_list.type()->expect<ListType>()->getElementType();
if (the_list.toListRef().size() == 0 ||
!elementTypeCanBeInferredFromMembers(list_elem_type)) {
out << "annotate(" << the_list.type()->python_str() << ", ";
out << "annotate(" << the_list.type()->annotation_str() << ", ";
printList(out, the_list.toListRef(), "[", "]", formatter);
out << ")";
return out;
@ -332,7 +332,7 @@ std::ostream& printMaybeAnnotatedDict(
auto value_type = the_dict.type()->cast<DictType>()->getValueType();
if (the_dict.toGenericDict().size() == 0 ||
!elementTypeCanBeInferredFromMembers(value_type)) {
out << "annotate(" << the_dict.type()->python_str() << ",";
out << "annotate(" << the_dict.type()->annotation_str() << ",";
printDict(out, the_dict.toGenericDict(), formatter) << ")";
} else {
return printDict(out, the_dict.toGenericDict(), formatter);

View File

@ -68,8 +68,8 @@ struct Type;
using TypePtr = std::shared_ptr<Type>;
using ConstTypePtr = std::shared_ptr<const Type>;
// Use this to customize how a Type is printed using `python_str()`. If
// c10::nullopt is returned, `python_str()` falls through to its default
// Use this to customize how a Type is printed using `annotation_str()`. If
// c10::nullopt is returned, `annotation_str()` falls through to its default
// implementation.
using TypePrinter =
std::function<c10::optional<std::string>(const ConstTypePtr&)>;
@ -81,7 +81,7 @@ struct CAFFE2_API Type : std::enable_shared_from_this<Type> {
protected:
Type(TypeKind kind) : kind_(kind) {}
virtual std::string python_str_impl(TypePrinter printer) const {
virtual std::string annotation_str_impl(TypePrinter printer) const {
return str();
}
@ -94,7 +94,7 @@ struct CAFFE2_API Type : std::enable_shared_from_this<Type> {
// if this returns false and the why_not stream is non-null, it contains
// additional details that describe why this is not a subtype of 'rhs'.
// This additional information should only contain details that are not obvious
// from the python_str() that describes the type. For instance it is clear that `int <: str` is false
// from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false
// but not clear why `Foo <: InterfaceBar` might be false.
virtual bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const;
virtual bool is_module() const;
@ -111,19 +111,26 @@ struct CAFFE2_API Type : std::enable_shared_from_this<Type> {
//
// Takes a custom printer that users can pass in to customize the output of
// this method.
std::string python_str(TypePrinter printer) const {
std::string annotation_str(TypePrinter printer) const {
if (printer) {
// the printer can return nullopt to fall through to the default impl
if (auto renamed = printer(shared_from_this())) {
return *renamed;
}
}
return python_str_impl(printer);
return annotation_str_impl(printer);
}
std::string python_str() const {
std::string annotation_str() const {
// Overload instead of define a default value for `printer` to help
// debuggers out.
return python_str(nullptr);
return annotation_str(nullptr);
}
// Returns a human readable string that includes additional information like
// "type is inferred rather than explictly defined" to help construct more
// user-friendly messages.
virtual std::string repr_str() const {
return annotation_str();
}
TypeKind kind() const {
@ -306,9 +313,9 @@ struct CAFFE2_API OptionalType
private:
OptionalType(TypePtr elem) : SingleElementType(elem) {}
std::string python_str_impl(TypePrinter printer = nullptr) const override {
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
std::stringstream ss;
ss << "Optional[" << getElementType()->python_str(printer) << "]";
ss << "Optional[" << getElementType()->annotation_str(printer) << "]";
return ss.str();
}
};
@ -641,6 +648,10 @@ struct CAFFE2_API TensorType : public Type {
std::string str() const override;
std::string repr_str() const override {
return str() + (isInferredType() ? " (inferred)" : "");
}
c10::optional<size_t> numel() const {
size_t prod = 1;
const auto& shape = sizes();
@ -852,9 +863,9 @@ struct CAFFE2_API ListType
private:
ListType(TypePtr elem) : SingleElementType(elem) {}
std::string python_str_impl(TypePrinter printer = nullptr) const override {
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
std::stringstream ss;
ss << "List[" << getElementType()->python_str(printer) << "]";
ss << "List[" << getElementType()->annotation_str(printer) << "]";
return ss.str();
}
};
@ -928,10 +939,10 @@ struct CAFFE2_API DictType : public Type {
has_free_variables(
key->hasFreeVariables() || value->hasFreeVariables()) {}
std::string python_str_impl(TypePrinter printer = nullptr) const override {
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
std::stringstream ss;
ss << "Dict[" << getKeyType()->python_str(printer) << ", "
<< getValueType()->python_str(printer) << "]";
ss << "Dict[" << getKeyType()->annotation_str(printer) << ", "
<< getValueType()->annotation_str(printer) << "]";
return ss.str();
}
@ -964,9 +975,9 @@ struct CAFFE2_API FutureType
private:
FutureType(TypePtr elem) : SingleElementType(elem) {}
std::string python_str_impl(TypePrinter printer = nullptr) const override {
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
std::stringstream ss;
ss << "Future[" << getElementType()->python_str(printer) << "]";
ss << "Future[" << getElementType()->annotation_str(printer) << "]";
return ss.str();
}
};
@ -996,9 +1007,9 @@ struct CAFFE2_API RRefType
private:
RRefType(TypePtr elem) : SingleElementType(elem) {}
std::string python_str_impl(TypePrinter printer = nullptr) const override {
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
std::stringstream ss;
ss << "RRef[" << getElementType()->python_str(printer) << "]";
ss << "RRef[" << getElementType()->annotation_str(printer) << "]";
return ss.str();
}
};
@ -1105,7 +1116,7 @@ struct CAFFE2_API TupleType : public NamedType {
return true;
}
std::string python_str_impl(TypePrinter printer = nullptr) const override;
std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
std::vector<TypePtr> elements_;
bool has_free_variables_;
@ -1135,7 +1146,7 @@ struct CAFFE2_API NumberType : public Type {
protected:
NumberType(TypeKind kind = TypeKind::NumberType) : Type(kind) {}
std::string python_str_impl(TypePrinter printer = nullptr) const override {
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
return "number"; // technically not a valid python type, but
// we need to use it when parsing back in annotations
// for implicit conversions
@ -1164,7 +1175,7 @@ struct CAFFE2_API FloatType : public NumberType {
private:
FloatType() : NumberType(TypeKind::FloatType) {}
std::string python_str_impl(TypePrinter printer = nullptr) const override {
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
return "float";
}
};
@ -1191,7 +1202,7 @@ struct CAFFE2_API IntType : public NumberType {
private:
IntType() : NumberType(TypeKind::IntType) {}
std::string python_str_impl(TypePrinter printer = nullptr) const override {
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
return "int";
}
};
@ -1229,9 +1240,9 @@ struct CAFFE2_API StringType : public Type {
}
std::string str() const override {
// we only use "str" (not "string") in both FunctionSchema and script
return python_str();
return annotation_str();
}
std::string python_str_impl(TypePrinter printer = nullptr) const override {
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
return "str";
}
static const TypeKind Kind = TypeKind::StringType;
@ -1266,7 +1277,7 @@ struct CAFFE2_API FunctionType : public NamedType {
private:
FunctionType(torch::jit::Function* function);
std::string python_str_impl(TypePrinter printer = nullptr) const override {
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
const auto& n = name().value();
return n.qualifiedName();
}
@ -1749,7 +1760,7 @@ struct CAFFE2_API ClassType : public NamedType {
}
std::string str() const override {
return python_str();
return annotation_str();
}
const std::vector<torch::jit::Function*>& methods() const;
@ -1773,7 +1784,7 @@ struct CAFFE2_API ClassType : public NamedType {
auto type = findAttribute(name);
TORCH_CHECK(
type,
python_str(),
repr_str(),
" does not have an attribute with name '",
name,
"'");
@ -1815,7 +1826,7 @@ struct CAFFE2_API ClassType : public NamedType {
}
TORCH_CHECK(
false,
python_str(),
repr_str(),
" does not have an attribute with name '",
name,
"'");
@ -1867,9 +1878,9 @@ struct CAFFE2_API ClassType : public NamedType {
TypePtr atype = getAttribute(*slot_idx);
TORCH_CHECK(
ty->isSubtypeOf(atype),
ty->python_str(),
ty->repr_str(),
" is not compatible with the type ",
atype->python_str(),
atype->repr_str(),
" for the field '",
name,
"'");
@ -1904,7 +1915,7 @@ struct CAFFE2_API ClassType : public NamedType {
}
TORCH_CHECK(
false,
python_str(),
repr_str(),
" does not have constant field with the name '",
name,
"'");
@ -2014,7 +2025,7 @@ struct CAFFE2_API ClassType : public NamedType {
std::weak_ptr<CompilationUnit> cu,
bool is_module);
std::string python_str_impl(TypePrinter printer = nullptr) const override {
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
const auto& n = name().value();
return n.qualifiedName();
}
@ -2095,7 +2106,7 @@ struct CAFFE2_API InterfaceType : public NamedType {
const InterfaceType& rhs,
std::ostream* why_not);
std::string python_str_impl(TypePrinter printer = nullptr) const override {
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
return name()->qualifiedName();
}

View File

@ -268,9 +268,9 @@ c10::optional<TypePtr> unifyTypeList(
auto maybe_unified = unifyTypes(ret_type, elements.at(i));
if (!maybe_unified) {
why_not << "Could not unify type list since element " << i << " of type "
<< elements.at(i)->python_str()
<< elements.at(i)->repr_str()
<< " did not match the types before it ("
<< ret_type->python_str() << ")";
<< ret_type->repr_str() << ")";
return c10::nullopt;
}
ret_type = maybe_unified.value();
@ -300,8 +300,8 @@ MatchTypeReturn matchTypeVariables(
}
std::stringstream ss;
ss << "Type variable '" << vt->name() << "' previously matched to type "
<< it->second->python_str() << " is matched to type "
<< actual->python_str();
<< it->second->repr_str() << " is matched to type "
<< actual->repr_str();
return ss.str();
} else if (auto lt_formal = formal->cast<ListType>()) {
if (auto lt_actual = actual->cast<ListType>()) {
@ -322,8 +322,8 @@ MatchTypeReturn matchTypeVariables(
}
std::stringstream ss;
ss << "Cannot match " << lt_formal->python_str() << " to "
<< actual->python_str();
ss << "Cannot match " << lt_formal->repr_str() << " to "
<< actual->repr_str();
return ss.str();
} else if (auto tp_formal = formal->cast<TupleType>()) {
if (auto tp_actual = actual->cast<TupleType>()) {
@ -340,7 +340,7 @@ MatchTypeReturn matchTypeVariables(
return MatchTypeReturn::Success();
} else {
std::stringstream ss;
ss << "Cannot match a tuple to " << actual->python_str();
ss << "Cannot match a tuple to " << actual->repr_str();
return MatchTypeReturn(ss.str());
}
} else if (auto lt_formal = formal->cast<FutureType>()) {
@ -353,7 +353,7 @@ MatchTypeReturn matchTypeVariables(
return MatchTypeReturn::Success();
} else {
std::stringstream ss;
ss << "Cannot match a future to " << actual->python_str();
ss << "Cannot match a future to " << actual->repr_str();
return ss.str();
}
} else if (auto lt_formal = formal->cast<RRefType>()) {
@ -366,7 +366,7 @@ MatchTypeReturn matchTypeVariables(
return MatchTypeReturn::Success();
} else {
std::stringstream ss;
ss << "Cannot match a rref to " << actual->python_str();
ss << "Cannot match a rref to " << actual->repr_str();
return ss.str();
}
} else if (auto opt_formal = formal->cast<OptionalType>()) {
@ -403,12 +403,12 @@ MatchTypeReturn matchTypeVariables(
return MatchTypeReturn::Success();
} else {
std::stringstream ss;
ss << "Cannot match a dict to " << actual->python_str();
ss << "Cannot match a dict to " << actual->repr_str();
return ss.str();
}
}
AT_ERROR("Unhandled free variable container: ", formal->python_str());
AT_ERROR("Unhandled free variable container: ", formal->repr_str());
}
// change return types like List[List[t]] into List[List[int]]
@ -761,7 +761,7 @@ std::string TupleType::str() const {
}
return ss.str();
}
std::string TupleType::python_str_impl(TypePrinter printer) const {
std::string TupleType::annotation_str_impl(TypePrinter printer) const {
std::stringstream ss;
if (schema_ && name()) {
ss << name()->qualifiedName();
@ -770,7 +770,7 @@ std::string TupleType::python_str_impl(TypePrinter printer) const {
for(size_t i = 0; i < elements().size(); ++i) {
if(i > 0)
ss << ", ";
ss << elements()[i]->python_str(printer);
ss << elements()[i]->annotation_str(printer);
}
ss << "]";
}
@ -978,7 +978,7 @@ void ClassType::addMethod(torch::jit::Function* method) {
"Can't redefine method: ",
method->name(),
" on class: ",
python_str());
repr_str());
methods_.push_back(method);
}
@ -997,7 +997,7 @@ torch::jit::Function& ClassType::getMethod(const std::string& name) const {
"Couldn't find method: '",
name,
"' on class: '",
python_str(),
repr_str(),
"'");
return *method;
}
@ -1019,7 +1019,7 @@ void ClassType::unsafeRemoveMethod(const std::string& name) {
"Can't delete undefined method ",
name,
" on class: ",
python_str());
repr_str());
}
ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
@ -1047,8 +1047,8 @@ bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
// Module Interface Type but the Class Type is not a Module Class Type
if (!is_module() && iface->is_module()) {
if (why_not) {
*why_not << "Class '" << python_str() << "' is not a subtype of "
<< "the module interface '" << rhs->python_str()
*why_not << "Class '" << repr_str() << "' is not a subtype of "
<< "the module interface '" << rhs->repr_str()
<< "' , only ScriptModule class can be subtype of module"
<< " interface.\n";
}
@ -1058,8 +1058,8 @@ bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
auto self_method = findMethod(schema.name());
if (!self_method) {
if (why_not) {
*why_not << "Class '" << python_str() << "' does not have method '"
<< schema.name() << "' but '" << rhs->python_str()
*why_not << "Class '" << repr_str() << "' does not have method '"
<< schema.name() << "' but '" << rhs->repr_str()
<< "' does.\n";
}
return false;
@ -1067,9 +1067,9 @@ bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
if (!self_method->getSchema().isSubtypeOf(
schema, /*is_method=*/true, why_not)) {
if (why_not) {
*why_not << "Method on class '" << python_str()
*why_not << "Method on class '" << repr_str()
<< "' (1) is not compatible with interface '"
<< rhs->python_str() << "' (2)\n"
<< rhs->repr_str() << "' (2)\n"
<< " (1) " << self_method->getSchema() << "\n"
<< " (2) " << schema << "\n";
}
@ -1091,8 +1091,8 @@ bool InterfaceType::isSubTypeImpl(
std::ostream* why_not) {
if (!lhs.is_module() && rhs.is_module()) {
if (why_not) {
*why_not << "Interface '" << lhs.python_str() << "' is not a subtype of "
<< "the module interface '" << rhs.python_str() << "'.\n";
*why_not << "Interface '" << lhs.repr_str() << "' is not a subtype of "
<< "the module interface '" << rhs.repr_str() << "'.\n";
}
return false;
}
@ -1100,17 +1100,17 @@ bool InterfaceType::isSubTypeImpl(
auto self_schema = lhs.getMethod(schema.name());
if (!self_schema) {
if (why_not) {
*why_not << "Interface '" << lhs.python_str()
*why_not << "Interface '" << lhs.repr_str()
<< "' does not have method '" << schema.name() << "' but interface '"
<< rhs.python_str() << "' does.\n";
<< rhs.repr_str() << "' does.\n";
}
return false;
}
if (!self_schema->isSubtypeOf(schema, /*is_method=*/true, why_not)) {
if (why_not) {
*why_not << "Method on interface '" << lhs.python_str()
*why_not << "Method on interface '" << lhs.repr_str()
<< "' (1) is not compatible with interface '"
<< rhs.python_str() << "' (2)\n"
<< rhs.repr_str() << "' (2)\n"
<< " (1) " << *self_schema << "\n"
<< " (2) " << schema << "\n";
return false;
@ -1178,7 +1178,7 @@ void ClassType::checkNotExist(const std::string& name, const std::string& what)
" '",
name,
"' to ",
python_str(),
repr_str(),
" but a constant field of the same name already exists with value ",
constantValues_[i]);
}
@ -1192,9 +1192,9 @@ void ClassType::checkNotExist(const std::string& name, const std::string& what)
" '",
name,
"' to ",
python_str(),
repr_str(),
" but an attribute field of the same name already exists with type ",
attributes_[i].getType()->python_str());
attributes_[i].getType()->repr_str());
}
}
@ -1264,7 +1264,7 @@ IValue ClassType::getConstant(const std::string& name) const {
const auto& v = findConstant(name);
TORCH_CHECK(
v.has_value(),
python_str(),
repr_str(),
" does not have a constant field with name '",
name,
"'");
@ -1275,7 +1275,7 @@ IValue ClassType::getConstant(size_t slot) const {
TORCH_INTERNAL_ASSERT(constantNames_.size() == constantValues_.size());
TORCH_CHECK(
slot < constantValues_.size(),
python_str(),
repr_str(),
" does not have a constant slot of index ",
slot);
return constantValues_[slot];
@ -1336,9 +1336,9 @@ void checkNoAny(const Type& base, const char* what, const std::string& attrname,
" '",
attrname,
"' of type ",
attrtype->python_str(),
attrtype->repr_str(),
" to '",
base.python_str(),
base.repr_str(),
"' but it contains an Any type. Any types cannot be members of modules, classes, or named tuples.");
}

View File

@ -19,12 +19,12 @@ TEST(TypeCustomPrinter, Basic) {
// Tensor types should be rewritten
torch::Tensor iv = torch::rand({2, 3});
const auto type = TensorType::create(iv);
EXPECT_EQ(type->python_str(), "Tensor");
EXPECT_EQ(type->python_str(printer), "CustomTensor");
EXPECT_EQ(type->annotation_str(), "Tensor");
EXPECT_EQ(type->annotation_str(printer), "CustomTensor");
// Unrelated types shoudl not be affected
const auto intType = IntType::create();
EXPECT_EQ(intType->python_str(printer), intType->python_str());
EXPECT_EQ(intType->annotation_str(printer), intType->annotation_str());
}
TEST(TypeCustomPrinter, ContainedTypes) {
@ -40,14 +40,14 @@ TEST(TypeCustomPrinter, ContainedTypes) {
// Contained types should work
const auto tupleType = TupleType::create({type, IntType::get(), type});
EXPECT_EQ(tupleType->python_str(), "Tuple[Tensor, int, Tensor]");
EXPECT_EQ(tupleType->annotation_str(), "Tuple[Tensor, int, Tensor]");
EXPECT_EQ(
tupleType->python_str(printer), "Tuple[CustomTensor, int, CustomTensor]");
tupleType->annotation_str(printer), "Tuple[CustomTensor, int, CustomTensor]");
const auto dictType = DictType::create(IntType::get(), type);
EXPECT_EQ(dictType->python_str(printer), "Dict[int, CustomTensor]");
EXPECT_EQ(dictType->annotation_str(printer), "Dict[int, CustomTensor]");
const auto listType = ListType::create(tupleType);
EXPECT_EQ(
listType->python_str(printer),
listType->annotation_str(printer),
"List[Tuple[CustomTensor, int, CustomTensor]]");
}
@ -67,11 +67,11 @@ TEST(TypeCustomPrinter, NamedTuples) {
const auto namedTupleType = TupleType::createNamed(
"my.named.tuple", {"foo", "bar"}, {type, IntType::get()});
EXPECT_EQ(namedTupleType->python_str(printer), "Rewritten");
EXPECT_EQ(namedTupleType->annotation_str(printer), "Rewritten");
// Put it inside another tuple, should still work
const auto outerTupleType = TupleType::create({IntType::get(), namedTupleType});
EXPECT_EQ(outerTupleType->python_str(printer), "Tuple[int, Rewritten]");
EXPECT_EQ(outerTupleType->annotation_str(printer), "Tuple[int, Rewritten]");
}
static TypePtr importType(

View File

@ -27,7 +27,7 @@ void testUnifyTypes() {
TORCH_INTERNAL_ASSERT(out);
std::stringstream ss;
ss << (*out)->python_str();
ss << (*out)->annotation_str();
testing::FileCheck()
.check("Optional[Tuple[Optional[int], Optional[int]]]")
->run(ss.str());

View File

@ -14,20 +14,20 @@ void testMobileTypeParser() {
std::string int_ps("int");
auto int_tp = c10::parseType(int_ps);
std::string int_tps = int_tp->python_str();
std::string int_tps = int_tp->annotation_str();
ASSERT_EQ(int_ps, int_tps);
std::string tuple_ps(
"Tuple[str, Optional[float], Dict[str, List[Tensor]], int]");
auto tuple_tp = c10::parseType(tuple_ps);
std::string tuple_tps = tuple_tp->python_str();
std::string tuple_tps = tuple_tp->annotation_str();
ASSERT_EQ(tuple_ps, tuple_tps);
std::string tuple_space_ps(
"Tuple[ str, Optional[float], Dict[str, List[Tensor ]] , int]");
auto tuple_space_tp = c10::parseType(tuple_space_ps);
// tuple_space_tps should not have weird white spaces
std::string tuple_space_tps = tuple_space_tp->python_str();
std::string tuple_space_tps = tuple_space_tp->annotation_str();
ASSERT_EQ(tuple_ps, tuple_space_tps);
std::string typo_token("List[tensor]");

View File

@ -17623,7 +17623,8 @@ a")
def foo(a):
return a
with self.assertRaisesRegex(RuntimeError, "Inferred \'a\' to be of type \'Tensor"):
with self.assertRaisesRegex(RuntimeError, (r"Expected a value of type \'Tensor \(inferred\)\'"
r"[\S\s]*Inferred \'a\' to be of type \'Tensor\'")):
foo(1)
def test_type_comments_in_body(self):

View File

@ -70,7 +70,7 @@ TypePtr tryInferTypeWithTypeHint(
type_hint_ptr != nullptr &&
module.value().type()->isSubtypeOfExt(
type_hint_ptr, &subtype_check_msg),
module.value().type()->python_str(),
module.value().type()->repr_str(),
" is not a subtype of the type hint: ",
type_qualified_name.qualifiedName(),
", did you pass a valid interface type?\n",

View File

@ -336,12 +336,12 @@ c10::intrusive_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
TORCH_INTERNAL_ASSERT(
ownerRRef->type()->isSubtypeOf(TensorType::get()),
"Expect OwnerRRef to be a sub-type of TensorType, but got ",
ownerRRef->type()->python_str());
ownerRRef->type()->repr_str());
} else {
TORCH_INTERNAL_ASSERT(
ownerRRef->type() == type,
"OwnerRRef type is ",
ownerRRef->type()->python_str(),
ownerRRef->type()->repr_str(),
", expected type is ",
type);
}

View File

@ -40,11 +40,11 @@ struct TORCH_API Object {
TORCH_CHECK(
v.type()->isSubtypeOf(expected),
"Expected a value of type '",
expected->python_str(),
expected->repr_str(),
"' for field '",
name,
"', but found '",
v.type()->python_str(),
v.type()->repr_str(),
"'");
_ivalue()->setSlot(*slot, std::move(v));
} else {
@ -61,7 +61,7 @@ struct TORCH_API Object {
}
TORCH_CHECK(
false,
_ivalue()->type()->python_str(),
_ivalue()->type()->repr_str(),
" does not have a field with name '",
name,
"'");

View File

@ -258,13 +258,13 @@ void ConcreteModuleType::dump() const {
}
std::cout << "\nAttributes: \n";
for (const auto& pr : data_.attributes_) {
std::cout << "\t" << pr.key() << ": " << pr.value().type_->python_str()
std::cout << "\t" << pr.key() << ": " << pr.value().type_->annotation_str()
<< "\n";
}
std::cout << "\nSubmodules: \n";
for (const auto& info : data_.modules_) {
std::cout << "\t" << info.name_ << ": "
<< info.meta_->getJitType()->python_str() << "\n";
<< info.meta_->getJitType()->annotation_str() << "\n";
}
std::cout << "\nOverloads: \n";
for (const auto& pr : data_.overloads_) {
@ -273,7 +273,7 @@ void ConcreteModuleType::dump() const {
std::string isPoisoned = data_.isPoisoned_ ? "true" : "false";
std::cout << "isPoisoned: " << isPoisoned << "\n";
if (jitType_) {
std::cout << "jit type: " << jitType_->python_str() << "\n";
std::cout << "jit type: " << jitType_->annotation_str() << "\n";
}
}

View File

@ -374,9 +374,9 @@ struct Environment {
if (!as_simple_value->type()->isSubtypeOfExt(parent_type, &why_not)) {
auto error = ErrorReport(loc);
error << "Variable '" << name << "' previously has type "
<< simple_parent->type()->python_str()
<< simple_parent->type()->repr_str()
<< " but is now being assigned to a value of type "
<< as_simple_value->type()->python_str();
<< as_simple_value->type()->repr_str();
// Special-cased error msg if we're trying to assign to a tensor list.
if (simple_parent->type()->kind() == TypeKind::ListType &&
@ -397,9 +397,9 @@ struct Environment {
if (!as_simple_value->type()->isSubtypeOf(annotated_type)) {
throw ErrorReport(loc)
<< "Variable '" << name << "' is annotated with type "
<< annotated_type->python_str()
<< annotated_type->repr_str()
<< " but is being assigned to a value of type "
<< as_simple_value->type()->python_str();
<< as_simple_value->type()->repr_str();
}
insertStore(name, loc, std::move(as_simple_value), annotated_type);
} else {
@ -972,8 +972,8 @@ struct to_ir {
if (!result->type()->isSubtypeOf(result_type)) {
throw ErrorReport(stmt.range())
<< "Return value was annotated as having type "
<< result_type->python_str() << " but is actually of type "
<< result->type()->python_str();
<< result_type->repr_str() << " but is actually of type "
<< result->type()->repr_str();
}
} else {
result_type = def_stack_.back().merged_return_type_;
@ -984,9 +984,9 @@ struct to_ir {
if (!merged_result_type) {
throw ErrorReport(stmt.range())
<< "Previous return statement returned a value of type "
<< result_type->python_str()
<< result_type->repr_str()
<< " but this return statement returns a value of type "
<< result->type()->python_str();
<< result->type()->repr_str();
}
result_type = merged_result_type.value();
}
@ -1208,7 +1208,7 @@ struct to_ir {
throw ErrorReport(loc)
<< "Expected list type annotation for list comprehension"
", found "
<< type_hint->python_str();
<< type_hint->repr_str();
}
list_value->setType(type_hint);
type_set = true;
@ -1326,8 +1326,8 @@ struct to_ir {
auto unified = unifyTypes(true_type, false_type);
if (!unified) {
throw ErrorReport(range)
<< "if-expression's true branch has type " << true_type->python_str()
<< " but false branch has type " << false_type->python_str();
<< "if-expression's true branch has type " << true_type->repr_str()
<< " but false branch has type " << false_type->repr_str();
}
// Add op outputs
@ -1342,13 +1342,13 @@ struct to_ir {
out = asSimple(bool_cast->call(loc, method, {v}, {}, 0));
} catch (...) {
throw ErrorReport(loc) << "Could not cast value of type "
<< v->type()->python_str() << " to bool";
<< v->type()->repr_str() << " to bool";
}
// cast value not response for checking output type
if (!out->type()->isSubtypeOf(BoolType::get())) {
throw ErrorReport(loc)
<< "expected a bool expression for condition but found "
<< out->type()->python_str();
<< out->type()->repr_str();
}
return out;
}
@ -1507,8 +1507,8 @@ struct to_ir {
if (!unified) {
ErrorReport error(loc);
error << "Type mismatch: " << x << " is set to type "
<< tv->type()->python_str() << " in the true branch"
<< " and type " << fv->type()->python_str()
<< tv->type()->repr_str() << " in the true branch"
<< " and type " << fv->type()->repr_str()
<< " in the false branch";
if (save_true->findInParentFrame(x) ||
save_false->findInParentFrame(x)) {
@ -1922,7 +1922,7 @@ struct to_ir {
magic_method_name = out_of_place_method_name;
} else {
throw ErrorReport(stmt.range())
<< "Cannot emit inplace op on " << type->python_str()
<< "Cannot emit inplace op on " << type->repr_str()
<< " since it does not define an " << in_place_method_name << " or "
<< out_of_place_method_name << " method";
}
@ -1954,7 +1954,7 @@ struct to_ir {
const TypePtr type = sliceable->type();
if (subscriptExprs.size() != 1) {
throw ErrorReport(subscriptExprs)
<< "Sliced expression not yet supported for " << type->python_str()
<< "Sliced expression not yet supported for " << type->repr_str()
<< " augmented assignment. "
<< "File a bug if you want this";
}
@ -1968,7 +1968,7 @@ struct to_ir {
if (elemType == nullptr) {
throw ErrorReport(lhs)
<< type->python_str() << " does not support augmented assignment.";
<< type->repr_str() << " does not support augmented assignment.";
}
const auto idxValue = emitExpr(subscriptExprs[0]);
const auto containerArg =
@ -2541,8 +2541,8 @@ struct to_ir {
std::stringstream why_not;
if (!expr->type()->isSubtypeOfExt(type, &why_not)) {
throw ErrorReport(apply.inputs())
<< "expected an expression of type " << type->python_str()
<< " but found " << expr->type()->python_str() << "\n"
<< "expected an expression of type " << type->repr_str()
<< " but found " << expr->type()->repr_str() << "\n"
<< why_not.str();
}
@ -3050,7 +3050,7 @@ struct to_ir {
// If the type hint was not a List[T] throw an error
throw ErrorReport(tree)
<< "Expected a List type hint but instead got "
<< type_hint->python_str();
<< type_hint->repr_str();
}
} else if (!values.empty()) {
std::stringstream ss;
@ -3068,8 +3068,8 @@ struct to_ir {
if (!v->type()->isSubtypeOfExt(elem_type, &ss)) {
throw ErrorReport(tree)
<< "Lists must contain only a single type, expected: "
<< elem_type->python_str() << " but found "
<< v->type()->python_str() << " instead.\n"
<< elem_type->repr_str() << " but found "
<< v->type()->repr_str() << " instead.\n"
<< ss.str();
}
}
@ -3120,8 +3120,8 @@ struct to_ir {
throw ErrorReport(trees[i])
<< "Dict " << what
<< " must contain only a single type, expected: "
<< type->python_str() << " but found "
<< values[i]->type()->python_str() << " instead.\n"
<< type->repr_str() << " but found "
<< values[i]->type()->repr_str() << " instead.\n"
<< ss.str();
}
}
@ -3329,7 +3329,7 @@ struct to_ir {
} else {
throw ErrorReport(loc)
<< "Unsupported operation: indexing tensor with unsupported index type '"
<< index->type()->python_str()
<< index->type()->repr_str()
<< "'. Only ints, slices, lists and tensors are supported";
}
};
@ -3485,7 +3485,7 @@ struct to_ir {
if (elems.size() == 0 ||
!convertibleToList(tuple_typ, ListType::create(elems[0]))) {
throw ErrorReport(loc)
<< "Cannot index into a " << tuple_typ->python_str()
<< "Cannot index into a " << tuple_typ->repr_str()
<< " with a non-integer literal because we cannot resolve the output type";
}
output_type = elems[0];

View File

@ -148,8 +148,8 @@ static Value* tryMatchArgument(
matchTypeVariables(arg.type(), value->type(), type_env);
if (!matched.success()) {
if (failure_messages) {
err() << "Could not match type " << value->type()->python_str() << " to "
<< arg.type()->python_str() << " in argument '" << arg.name()
err() << "Could not match type " << value->type()->repr_str() << " to "
<< arg.type()->repr_str() << " in argument '" << arg.name()
<< "': " << matched.reason() << ".\n";
}
return nullptr;
@ -157,9 +157,9 @@ static Value* tryMatchArgument(
const auto concrete_type = tryEvalTypeVariables(arg.type(), type_env);
if (!concrete_type) {
if (failure_messages) {
err() << "Type variables in type " << arg.type()->python_str()
err() << "Type variables in type " << arg.type()->repr_str()
<< " could not be inferred from actual type "
<< value->type()->python_str();
<< value->type()->repr_str();
}
return nullptr;
}
@ -172,7 +172,7 @@ static Value* tryMatchArgument(
concrete_type, /*why_not=*/(failure_messages) ? &ss : nullptr)) {
if (failure_messages) {
auto& ostream = err()
<< arg.formatTypeMismatchMsg(value->type()->python_str());
<< arg.formatTypeMismatchMsg(value->type()->repr_str());
if (auto pt = value->type()->cast<TensorType>()) {
if (pt->isInferredType()) {
@ -414,7 +414,7 @@ static c10::optional<MatchedSchema> tryMatchSchema(
auto return_types = fmap(returns, [&](const Argument& r) {
TypePtr result = tryEvalTypeVariables(r.type(), type_env);
TORCH_INTERNAL_ASSERT(
result, r.type()->python_str(), " has unbound type variables.");
result, r.type()->repr_str(), " has unbound type variables.");
return result;
});
// Codegen does not support return of namedtuples with undefined field names.

View File

@ -70,7 +70,7 @@ bool SimpleValue::hasAttr(
auto class_type = value_->type()->cast<ClassType>();
if (!class_type) {
throw ErrorReport(loc) << "hasattr's first argument must be an object, got "
<< value_->type()->python_str() << " instead";
<< value_->type()->repr_str() << " instead";
}
return class_type->hasMethod(field) || class_type->hasAttribute(field) ||
@ -176,7 +176,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
ErrorReport report(loc);
report << "Tried to access nonexistent attribute or method '" << field
<< "' of type '" << value_->type()->python_str() << "'.";
<< "' of type '" << value_->type()->repr_str() << "'.";
if (value_->type()->kind() == ClassType::Kind) {
report << " Did you forget to initialize an attribute in __init__()?";
}
@ -205,7 +205,7 @@ std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
graph->insertNode(graph->createListUnpack(value_, *size_hint));
return fmap(unpack->outputs(), make_simple_value);
}
throw ErrorReport(loc) << value_->type()->python_str()
throw ErrorReport(loc) << value_->type()->repr_str()
<< " cannot be used as a tuple";
}
@ -232,8 +232,7 @@ void SimpleValue::setAttr(
const auto classType = value_->type()->cast<ClassType>();
if (!classType) {
throw ErrorReport(loc) << "Tried to set an attribute: " << field
<< " on a non-class: "
<< value_->type()->python_str();
<< " on a non-class: " << value_->type()->repr_str();
}
auto expectedType = classType->findAttribute(field);
if (!expectedType) {
@ -255,7 +254,7 @@ void SimpleValue::setAttr(
throw ErrorReport(loc)
<< "Assignment to attribute '" << field
<< "' cannot be of a type that contains class "
<< "'" << classType->python_str() << "'.\n"
<< "'" << classType->repr_str() << "'.\n"
<< "Classes that recursively contain instances of themselves"
<< " are not yet supported";
}
@ -283,8 +282,8 @@ void SimpleValue::setAttr(
const auto newType = newValue->type();
if (!newType->isSubtypeOf(expectedType)) {
throw ErrorReport(loc) << "Wrong type for attribute assignment. Expected "
<< expectedType->python_str() << " but got "
<< newType->python_str();
<< expectedType->repr_str() << " but got "
<< newType->repr_str();
}
auto& g = *m.graph();
@ -341,7 +340,7 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) {
val_type->isSubtypeOf(TensorType::get())) {
return g.insert(aten::len, {val}, {}, loc);
} else {
throw ErrorReport(loc) << "'" << val_type->python_str() << "'"
throw ErrorReport(loc) << "'" << val_type->repr_str() << "'"
<< " object is not iterable";
}
}
@ -367,7 +366,7 @@ SugaredValuePtr SimpleValue::getitem(
} else if (auto class_type = val_type->cast<ClassType>()) {
return attr(loc, m, "__getitem__")->call(loc, m, {idx}, {}, 1);
} else {
throw ErrorReport(loc) << "'" << val_type->python_str() << "'"
throw ErrorReport(loc) << "'" << val_type->repr_str() << "'"
<< " object is not subscriptable";
}
}
@ -393,7 +392,7 @@ SugaredValuePtr SimpleValue::iter(const SourceRange& loc, Function& m) {
}
return std::make_shared<SugaredTupleValue>(tup_sugared);
} else {
throw ErrorReport(loc) << "'" << type->python_str() << "'"
throw ErrorReport(loc) << "'" << type->repr_str() << "'"
<< " object is not iterable";
}
}
@ -407,7 +406,7 @@ RangeValue::RangeValue(
auto typ = inputs[i]->type();
if (!typ->cast<IntType>()) {
throw ErrorReport(loc)
<< "all inputs of range must be ints, found " << typ->python_str()
<< "all inputs of range must be ints, found " << typ->repr_str()
<< " in argument " << c10::guts::to_string(i);
}
}

View File

@ -147,7 +147,7 @@ struct TORCH_API SimpleValue : public SugaredValue {
SimpleValue(Value* value) : value_(value) {}
std::string kind() const override {
std::stringstream ss;
ss << "value of type '" << value_->type()->python_str() << "'";
ss << "value of type '" << value_->type()->annotation_str() << "'";
return ss.str();
}
Value* asValue(const SourceRange& range, Function& m) override {

View File

@ -367,7 +367,7 @@ static IValue addInput(
AT_ERROR(
"Only tensors or (possibly nested) dict or tuples of tensors can be "
"inputs to traced functions. Got ",
type->python_str());
type->repr_str());
}
}

View File

@ -1081,9 +1081,9 @@ void AliasDb::replaceWithNewValue(Value* existing, Value* new_value) {
*unshapedType(existing->type()) == *unshapedType(new_value->type()),
"Types must be strictly equal if you are replacing aliasing information. ",
"Got existing: '",
existing->type()->python_str(),
existing->type()->repr_str(),
"', new_value: '",
new_value->type()->python_str(),
new_value->type()->repr_str(),
"'");
if (!isMutableTypeInternal(existing)) {
return;
@ -1099,9 +1099,9 @@ void AliasDb::copyValue(Value* from, Value* to) {
*unshapedType(from->type()) == *unshapedType(to->type()),
"Types must be strictly equal if you are copying aliasing information. ",
"Got from: '",
from->type()->python_str(),
from->type()->repr_str(),
"', to: '",
to->type()->python_str(),
to->type()->repr_str(),
"'");
if (!isMutableTypeInternal(to)) {
return;
@ -1557,8 +1557,8 @@ void Lint(const AliasDb* db) {
auto it = db->elementMap_.find(v);
if (it == db->elementMap_.end()) {
failed = true;
ss << "Value %" << v->debugName() << " of type "
<< v->type()->python_str() << " wasn't found in the element map.\n"
ss << "Value %" << v->debugName() << " of type " << v->type()->repr_str()
<< " wasn't found in the element map.\n"
<< "It was defined in " << *v->node();
}
}

View File

@ -1599,9 +1599,9 @@ Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef<Value*> values) {
TORCH_CHECK(
v->type()->isSubtypeOf(elem_type),
"Expected a list element that subtypes '",
elem_type->python_str(),
elem_type->repr_str(),
"' but got an element of type '",
v->type()->python_str(),
v->type()->repr_str(),
"'");
}
n->output()->setType(ListType::create(elem_type));
@ -1714,7 +1714,7 @@ Value* Graph::insertToList(Value* v, TypePtr type) {
} else {
TORCH_CHECK(
false,
ptr->python_str(),
ptr->repr_str(),
" is not one of the supported element types for tolist: int, float, bool");
}

View File

@ -368,10 +368,9 @@ void IRParser::parseOperator(Block* b) {
if (!schema_return_type->hasFreeVariables() &&
!v.type->isSubtypeOf(schema_return_type)) {
throw ErrorReport(source_range)
<< "Annotated type " << v.type->python_str()
<< "Annotated type " << v.type->repr_str()
<< " does not match schema type "
<< schema_return_type->python_str() << " for operator "
<< *schema;
<< schema_return_type->repr_str() << " for operator " << *schema;
}
vmap[v.name]->setType(v.type);
}

View File

@ -136,7 +136,7 @@ static std::vector<IValue> loadTensors(const std::vector<Slot>& slots) {
getCustomClass(
"__torch__.torch.classes.quantized.LinearPackedParamsBase")),
"Unknown type ",
type->python_str(),
type->repr_str(),
" encountered in graph lowering. This type is not supported in ONNX export.");
result.emplace_back(
script::Object(obj.toObject()).run_method("__getstate__"));

View File

@ -351,9 +351,9 @@ inline InferredType tryToInferContainerType(py::handle input) {
if (!unified_key) {
return InferredType(c10::str(
"Dictionary inputs to traced functions must have consistent type. Found ",
key_type->python_str(),
key_type->repr_str(),
" and ",
(entry_key_type_match.type())->python_str()));
(entry_key_type_match.type())->repr_str()));
}
// Try to infer the value type and unify it with the existing one
@ -366,9 +366,9 @@ inline InferredType tryToInferContainerType(py::handle input) {
if (!unified_value) {
return InferredType(c10::str(
"Dictionary inputs to traced functions must have consistent type. Found ",
value_type->python_str(),
value_type->repr_str(),
" and ",
(entry_value_type_match.type())->python_str()));
(entry_value_type_match.type())->repr_str()));
}
key_type = *unified_key;
@ -395,9 +395,9 @@ inline InferredType tryToInferContainerType(py::handle input) {
if (!unified_type) {
return InferredType(c10::str(
"List inputs to traced functions must have consistent element type. Found ",
element_type->python_str(),
element_type->repr_str(),
" and ",
(element_type_match.type())->python_str()));
(element_type_match.type())->repr_str()));
}
element_type = *unified_type;
}
@ -453,7 +453,7 @@ inline Stack toTraceableStack(const py::tuple& inputs) {
TORCH_CHECK(
isTraceableType(info.type()),
"Type '",
info.type()->python_str(),
info.type()->repr_str(),
"' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and"
" Tuples of Tensors can be traced");
return info.toTuple()->elements();
@ -542,7 +542,7 @@ inline IValue toIValue(
"Object ",
py::str(obj),
" had a different number of elements than type ",
type->python_str()));
type->repr_str()));
}
std::vector<IValue> values;
values.reserve(tuple_size);
@ -669,7 +669,7 @@ inline IValue toIValue(
"Object ",
py::str(obj),
" is not compatible with interface ",
interfaceType->python_str(),
interfaceType->repr_str(),
"\n",
why_not.str()));
}
@ -694,7 +694,7 @@ inline IValue toIValue(
return py::cast<double>(obj);
} else {
throw py::cast_error(
c10::str("Cannot cast ", py::str(obj), " to ", type->python_str()));
c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str()));
}
}
case TypeKind::RRefType: {
@ -726,7 +726,7 @@ inline IValue toIValue(
break;
}
throw py::cast_error(c10::str(
"toIValue() cannot handle converting to type: ", type->python_str()));
"toIValue() cannot handle converting to type: ", type->repr_str()));
}
// Small wrapper around getting the type name string from Python to make

View File

@ -18,7 +18,7 @@ py::object ScriptClass::__call__(py::args args, py::kwargs kwargs) {
fmt::format(
"Custom C++ class: '{}' does not have an '__init__' method bound. "
"Did you forget to add '.def(torch::init<...>)' to its registration?",
instance.type()->python_str()));
instance.type()->repr_str()));
Method init_method(instance._ivalue(), init_fn);
invokeScriptMethodFromPython(init_method, std::move(args), std::move(kwargs));
return py::cast(instance);

View File

@ -641,7 +641,7 @@ void initPythonIRBindings(PyObject* module_) {
using ::c10::Type;
py::class_<Type, std::shared_ptr<Type>>(m, "Type")
.def("__repr__", [](Type& t) { return t.python_str(); })
.def("__repr__", [](Type& t) { return t.annotation_str(); })
.def(
"str",
[](Type& t) {

View File

@ -671,7 +671,7 @@ TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc) {
TORCH_CHECK(
type->isSubtypeOf(tt),
"Can't to redefine NamedTuple: ",
tt->python_str());
tt->repr_str());
return type;
}
get_python_cu()->register_type(tt);

View File

@ -260,7 +260,7 @@ FunctionSchema getSchemaWithNameAndDefaults(
c10::optional<IValue> value = tryCalculateDefaultParam(arg, it->second);
if (!value) {
ErrorReport error(range);
error << "Expected a default value of type " << arg.type()->python_str()
error << "Expected a default value of type " << arg.type()->repr_str()
<< " on parameter \"" << arg.name() << "\".";
if (arg.is_inferred_type()) {
error << "Because \"" << arg.name()
@ -799,7 +799,7 @@ void initJitScriptBindings(PyObject* module) {
TORCH_INTERNAL_ASSERT(
setstate_schema.arguments().size() == 2,
"__setstate__ method for class ",
class_type->python_str(),
class_type->repr_str(),
" must have exactly 2 arguments!");
auto state_type = setstate_schema.arguments().at(1).type();
(*setstate_method)(Stack{toIValue(state, state_type)});

View File

@ -158,7 +158,7 @@ IValue tensorToListRecursive(
} else {
TORCH_CHECK(
false,
ty->python_str(),
ty->repr_str(),
" is not one of the supported types for tolist: int, float, bool");
}
}

View File

@ -1319,16 +1319,16 @@ Function* checkSortSchema(const c10::TypePtr& list_element_type) {
return method;
}
}
error_str << "To sort a list of " << class_type->python_str()
error_str << "To sort a list of " << class_type->repr_str()
<< " it must define a "
<< "__lt__ method with two inputs of type "
<< class_type->python_str() << " that "
<< class_type->repr_str() << " that "
<< "returns a bool";
} else {
error_str << "To sort a list of " << list_element_type->python_str()
error_str << "To sort a list of " << list_element_type->repr_str()
<< " must be of Tensors, ints, floats, bools or "
<< "a User Defined Class that defines the __lt__ compare method"
<< ", got list of " << list_element_type->python_str() << "\n";
<< ", got list of " << list_element_type->repr_str() << "\n";
}
throw std::runtime_error(error_str.str());
}

View File

@ -31,7 +31,7 @@ void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) {
elem_type != BoolType::get()) {
std::stringstream error;
error << "Input must be of ints, floats, or bools, "
<< "got " << elem_type->python_str();
<< "got " << elem_type->repr_str();
// special case empty list torch.tensor([])
if (elem_type->isSubtypeOf(TensorType::get())) {
if (empty_list) {
@ -220,11 +220,11 @@ int createTensorFromList(Stack& stack) {
tensor.numel() == 0) {
TORCH_WARN(
"Creating a tensor from an empty ",
elem_type->python_str(),
elem_type->repr_str(),
"list will create a tensor of default floating point type (currently ",
default_type,
") in python but a tensor of type ",
elem_type->python_str(),
elem_type->repr_str(),
" in torchscript.\n",
"Pass in a dtype argument to ensure consistent behavior");
}

View File

@ -111,7 +111,7 @@ c10::IValue getFunctionTuple(const Function& func) {
std::vector<IValue> types;
types.reserve(code.type_table().size());
for (const TypePtr& t : code.type_table()) {
types.emplace_back(t->python_str());
types.emplace_back(t->annotation_str());
}
// since the register location is embedded into the bytecode, pass the

View File

@ -50,7 +50,7 @@ void postSetStateValidate(const IValue& v) {
"The field '{}' was left uninitialized after '__setstate__', "
"but expected a value of type '{}'",
attrName,
attrType->python_str()));
attrType->repr_str()));
}
}
}

View File

@ -356,7 +356,7 @@ Module ScriptModuleDeserializer::LEGACY_convertModule(
module_type->getAttributeName(i),
"' was left unitialized after __setstate__, but expected a ",
"value of type '",
v.type()->python_str(),
v.type()->repr_str(),
"'");
}
}

View File

@ -497,7 +497,7 @@ void Pickler::endTypeTag(const IValue& ivalue) {
// Push the dict type
TORCH_INTERNAL_ASSERT(ivalue.type());
pushString(ivalue.type()->python_str());
pushString(ivalue.type()->annotation_str());
// Pop the dict and type into a tuple
push<PickleOpCode>(PickleOpCode::TUPLE2);
@ -658,7 +658,7 @@ bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
TORCH_CHECK(
set_schema.returns().at(0).type()->isSubtypeOf(NoneType::get()),
"'__setstate__' must return None, but found value of type",
set_schema.returns().at(0).type()->python_str());
set_schema.returns().at(0).type()->annotation_str());
// Check that the return type of __getstate__ matches the input to
// __setstate__
@ -668,9 +668,9 @@ bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
TORCH_CHECK(
get_type->isSubtypeOf(set_type),
"'__getstate__'s return type (",
get_type->python_str(),
get_type->annotation_str(),
") does not match '__setstate__'s argument type (",
set_type->python_str(),
set_type->annotation_str(),
")");
return true;

View File

@ -500,7 +500,7 @@ struct PythonPrintImpl {
indent();
body_ << useOf(lhs[i]);
if (requiresAnnotation(lhs[i], rhs[i])) {
body_ << ": " << lhs[i]->type()->python_str(type_printer_);
body_ << ": " << lhs[i]->type()->annotation_str(type_printer_);
}
body_ << " = " << useOf(rhs[i]) << "\n";
}
@ -769,7 +769,7 @@ struct PythonPrintImpl {
if (i > 0) {
body_ << ", ";
}
body_ << useOf(v) << ": " << v->type()->python_str(type_printer_);
body_ << useOf(v) << ": " << v->type()->annotation_str(type_printer_);
}
body_ << "):\n";
printBody(graph->block());
@ -804,7 +804,7 @@ struct PythonPrintImpl {
if (v.isTuple() && v.type()->expect<TupleType>()->schema()) {
// print the namedtuple constructor and let rest of tuple printing
// continue
ss << v.type()->expect<TupleType>()->python_str(type_printer_);
ss << v.type()->expect<TupleType>()->annotation_str(type_printer_);
}
return false;
};
@ -857,14 +857,14 @@ struct PythonPrintImpl {
} break;
case prim::Uninitialized: {
stmt << "uninitialized("
<< node->output()->type()->python_str(type_printer_) << ")";
<< node->output()->type()->annotation_str(type_printer_) << ")";
} break;
case prim::Constant: {
if (node->outputs().size() == 1 &&
node->output()->type()->kind() == TypeKind::FunctionType) {
auto fn = node->output()->type()->expect<FunctionType>();
registerDependency(fn);
stmt << fn->python_str(type_printer_);
stmt << fn->annotation_str(type_printer_);
} else if (!node->mustBeNone()) {
IValue v = toIValue(node->output()).value();
printConstant(stmt, v);
@ -875,8 +875,9 @@ struct PythonPrintImpl {
case aten::ScalarImplicit:
case aten::FloatImplicit:
case aten::IntImplicit: {
stmt << "annotate(" << node->output()->type()->python_str(type_printer_)
<< ", " << useOf(node->input()) << ")";
stmt << "annotate("
<< node->output()->type()->annotation_str(type_printer_) << ", "
<< useOf(node->input()) << ")";
} break;
case aten::Int: {
printValueList(stmt, node->inputs(), "int(", ")");
@ -902,7 +903,7 @@ struct PythonPrintImpl {
case prim::TupleConstruct: {
if (auto qualname =
node->output()->type()->expect<TupleType>()->name()) {
stmt << node->output()->type()->python_str(type_printer_);
stmt << node->output()->type()->annotation_str(type_printer_);
}
printValueList(
stmt, node->inputs(), "(", node->inputs().size() == 1 ? ",)" : ")");
@ -922,13 +923,14 @@ struct PythonPrintImpl {
// what type is supposed to be inside them
if (node->inputs().size() == 0) {
stmt << "annotate("
<< node->output()->type()->python_str(type_printer_) << ", [])";
<< node->output()->type()->annotation_str(type_printer_)
<< ", [])";
// If we can't infer the type based on what's inside, explicitly
// annotate it to disambiguate.
// This happens for List[Tensor] vs. List[Optional[Tensor]]
} else if (!elementTypeCanBeInferredFromMembers(elem_type)) {
stmt << "annotate("
<< node->output()->type()->python_str(type_printer_) << ", ";
<< node->output()->type()->annotation_str(type_printer_) << ", ";
printValueList(stmt, node->inputs(), "[", "]");
stmt << ")";
// Otherwise just print a list
@ -947,7 +949,7 @@ struct PythonPrintImpl {
!elementTypeCanBeInferredFromMembers(dict_type->getKeyType()) ||
!elementTypeCanBeInferredFromMembers(dict_type->getValueType())) {
stmt << "annotate("
<< node->output()->type()->python_str(type_printer_) << ", ";
<< node->output()->type()->annotation_str(type_printer_) << ", ";
printDict(stmt, node->inputs());
stmt << ")";
// Otherwise just print a dict
@ -957,8 +959,8 @@ struct PythonPrintImpl {
} break;
case prim::CreateObject: {
const auto classType = node->output()->type()->expect<ClassType>();
stmt << classType->python_str(type_printer_) << ".__new__("
<< classType->python_str(type_printer_) << ")";
stmt << classType->annotation_str(type_printer_) << ".__new__("
<< classType->annotation_str(type_printer_) << ")";
} break;
case prim::GetAttr: {
const auto obj = node->inputs().at(0);
@ -1013,8 +1015,8 @@ struct PythonPrintImpl {
if (node->input()->type()->isSubtypeOf(NoneType::get()) ||
node->input()->mustBeNone()) {
auto input_type = OptionalType::create(node->output()->type());
stmt << "annotate(" << input_type->python_str(type_printer_) << ", "
<< useOf(node->input()) << ")";
stmt << "annotate(" << input_type->annotation_str(type_printer_)
<< ", " << useOf(node->input()) << ")";
} else {
stmt << useOf(node->input());
}
@ -1027,14 +1029,14 @@ struct PythonPrintImpl {
case prim::unchecked_unwrap_optional:
case prim::unchecked_cast: {
stmt << "unchecked_cast("
<< node->output()->type()->python_str(type_printer_) << ", "
<< node->output()->type()->annotation_str(type_printer_) << ", "
<< useOf(node->input()) << ")";
} break;
case prim::isinstance: {
stmt << "isinstance(" << useOf(node->input()) << ", ";
const auto& types = node->tys(attr::types);
if (types.size() == 1) {
stmt << types.at(0)->python_str(type_printer_);
stmt << types.at(0)->annotation_str(type_printer_);
} else {
// check multiple things, e.g. (str, list, int)
stmt << "(";
@ -1043,7 +1045,7 @@ struct PythonPrintImpl {
if (!first) {
stmt << ", ";
}
stmt << typ->python_str(type_printer_);
stmt << typ->annotation_str(type_printer_);
first = false;
}
stmt << ")";
@ -1051,8 +1053,8 @@ struct PythonPrintImpl {
stmt << ")";
} break;
case prim::tolist: {
stmt << "annotate(" << node->output()->type()->python_str(type_printer_)
<< ", ";
stmt << "annotate("
<< node->output()->type()->annotation_str(type_printer_) << ", ";
stmt << useOf(node->input(0)) << ".tolist()"
<< ")";
} break;
@ -1172,11 +1174,11 @@ struct PythonPrintImpl {
// the flag print_first_argument_type determines when to do this
body_ << arg_name;
if (print_first_argument_type) {
body_ << ": " << arg.type()->python_str(type_printer_);
body_ << ": " << arg.type()->annotation_str(type_printer_);
}
} else {
body_ << ",\n " << arg_name << ": "
<< arg.type()->python_str(type_printer_);
<< arg.type()->annotation_str(type_printer_);
}
if (arg.default_value()) {
printDefaultValue(arg, body_, *arg.default_value());
@ -1184,7 +1186,8 @@ struct PythonPrintImpl {
assignValue(*param_it++, arg_name);
}
body_ << ") -> " << schema.returns().at(0).type()->python_str(type_printer_)
body_ << ") -> "
<< schema.returns().at(0).type()->annotation_str(type_printer_)
<< ":\n";
printBody(graph.block());
}
@ -1275,12 +1278,12 @@ struct PythonPrintImpl {
// Print out a direct manipulation of the annotations dict, like:
// __annotations__["0"] = SomeType
body_ << "__annotations__["
<< "\"" << name << "\"] = " << type->python_str(type_printer_)
<< "\n";
<< "\"" << name
<< "\"] = " << type->annotation_str(type_printer_) << "\n";
} else {
// Otherwise: just emit a python 3 attribute annotation, like:
// foo : SomeType
body_ << name << " : " << type->python_str(type_printer_) << "\n";
body_ << name << " : " << type->annotation_str(type_printer_) << "\n";
}
}
@ -1291,7 +1294,7 @@ struct PythonPrintImpl {
indent();
body_ << name << " : "
<< "Final[" << v.type()->python_str(type_printer_) << "] = ";
<< "Final[" << v.type()->annotation_str(type_printer_) << "] = ";
auto ss = std::make_shared<TaggedStringStream>(&source_range_stack_);
printConstant(*ss, v);
body_ << ss->str() << "\n";
@ -1319,7 +1322,7 @@ struct PythonPrintImpl {
TORCH_INTERNAL_ASSERT(attr.type());
indent();
body_ << attr.name() << " : "
<< attr.type()->python_str(type_printer_) << "\n";
<< attr.type()->annotation_str(type_printer_) << "\n";
}
}
} else if (auto interfaceType = type->cast<InterfaceType>()) {
@ -1342,11 +1345,12 @@ struct PythonPrintImpl {
auto type = arg.type();
registerClassDependencies(type);
body_ << ", " << arg.name() << ": "
<< type->python_str(type_printer_);
<< type->annotation_str(type_printer_);
}
auto return_type = method.returns().at(0).type();
registerClassDependencies(return_type);
body_ << ") -> " << return_type->python_str(type_printer_) << ":\n";
body_ << ") -> " << return_type->annotation_str(type_printer_)
<< ":\n";
indent();
body_ << " pass\n";
}

View File

@ -149,7 +149,7 @@ void restoreContainerTypeTags(IValue& ivalue, TypePtr type) {
} else if (auto list_type = type->cast<ListType>()) {
ivalue.toList().unsafeSetElementType(list_type->getElementType());
} else {
AT_ERROR("Unknown type for tag restoration: " + type->python_str());
AT_ERROR("Unknown type for tag restoration: " + type->annotation_str());
}
}

View File

@ -202,7 +202,7 @@ class class_ {
TORCH_CHECK(
*first_arg_type == *classTypePtr,
"self argument of __getstate__ must be the custom class type. Got ",
first_arg_type->python_str());
first_arg_type->repr_str());
TORCH_CHECK(
getstate_schema.returns().size() == 1,
"__getstate__ should return exactly one value for serialization. Got: ",
@ -214,9 +214,9 @@ class class_ {
(*arg_type == *ser_type),
"__setstate__'s argument should be the same type as the "
"return value of __getstate__. Got ",
arg_type->python_str(),
arg_type->repr_str(),
" but expected ",
ser_type->python_str());
ser_type->repr_str());
return *this;
}