mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:30:26 +08:00
[JIT] Add Type::repr_str to return human-readable str (#39544)
Summary: Clearly expressing a type is inferred by PyTorch instead of explicitly annotated by user makes many error messages more user-friendly Currently Type has two string conversion methods. str() for IR printing and python_str() for serialization and error message generation. If we want to include more information in type printing while maintaining serialization/deserialization correctness, we need to split python_str() into annotation_str() and repr_str(). annotation_str is solely responsible for serialization, it strictly matches format of python type annotation. repr_str() is responsible for generating a human-readable error message that includes information like "this type is inferred, not explicitly annotated" Closes https://github.com/pytorch/pytorch/issues/39449 Pull Request resolved: https://github.com/pytorch/pytorch/pull/39544 Differential Revision: D21978759 Pulled By: gmagogsfm fbshipit-source-id: 733566f5a62e748b5ca4bb3c5943ebb6d5b664d0
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4e892bd99c
commit
c22bbb2124
@ -75,7 +75,7 @@ struct Argument {
|
||||
}
|
||||
return c10::str(
|
||||
"Expected a value of type '",
|
||||
type()->python_str(),
|
||||
type()->repr_str(),
|
||||
"' for argument '",
|
||||
name(),
|
||||
"' but instead found type '",
|
||||
|
@ -190,7 +190,7 @@ inline void FunctionSchema::checkArg(
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
formatTypeMismatchMsg(
|
||||
argument, value.type()->python_str(), pos));
|
||||
argument, value.type()->repr_str(), pos));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -18,7 +18,7 @@ bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs) {
|
||||
|
||||
namespace ivalue {
|
||||
|
||||
// This is in ivalue.cpp because we need to access Type::python_str, which
|
||||
// This is in ivalue.cpp because we need to access Type::annotation_str, which
|
||||
// is declared in jit_type.h
|
||||
void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) {
|
||||
// NB: doing pointer comparison here
|
||||
@ -26,9 +26,9 @@ void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) {
|
||||
// Type's, this needs to be changed!
|
||||
TORCH_CHECK(actual_type == expected_type,
|
||||
"Tried to convert an IValue of type ",
|
||||
actual_type->python_str(),
|
||||
actual_type->repr_str(),
|
||||
" to custom class type ",
|
||||
expected_type->python_str());
|
||||
expected_type->repr_str());
|
||||
}
|
||||
|
||||
CAFFE2_API c10::intrusive_ptr<ConstantString> ConstantString::create(
|
||||
@ -291,7 +291,7 @@ std::ostream& printMaybeAnnotatedList(
|
||||
auto list_elem_type = the_list.type()->expect<ListType>()->getElementType();
|
||||
if (the_list.toListRef().size() == 0 ||
|
||||
!elementTypeCanBeInferredFromMembers(list_elem_type)) {
|
||||
out << "annotate(" << the_list.type()->python_str() << ", ";
|
||||
out << "annotate(" << the_list.type()->annotation_str() << ", ";
|
||||
printList(out, the_list.toListRef(), "[", "]", formatter);
|
||||
out << ")";
|
||||
return out;
|
||||
@ -332,7 +332,7 @@ std::ostream& printMaybeAnnotatedDict(
|
||||
auto value_type = the_dict.type()->cast<DictType>()->getValueType();
|
||||
if (the_dict.toGenericDict().size() == 0 ||
|
||||
!elementTypeCanBeInferredFromMembers(value_type)) {
|
||||
out << "annotate(" << the_dict.type()->python_str() << ",";
|
||||
out << "annotate(" << the_dict.type()->annotation_str() << ",";
|
||||
printDict(out, the_dict.toGenericDict(), formatter) << ")";
|
||||
} else {
|
||||
return printDict(out, the_dict.toGenericDict(), formatter);
|
||||
|
@ -68,8 +68,8 @@ struct Type;
|
||||
using TypePtr = std::shared_ptr<Type>;
|
||||
using ConstTypePtr = std::shared_ptr<const Type>;
|
||||
|
||||
// Use this to customize how a Type is printed using `python_str()`. If
|
||||
// c10::nullopt is returned, `python_str()` falls through to its default
|
||||
// Use this to customize how a Type is printed using `annotation_str()`. If
|
||||
// c10::nullopt is returned, `annotation_str()` falls through to its default
|
||||
// implementation.
|
||||
using TypePrinter =
|
||||
std::function<c10::optional<std::string>(const ConstTypePtr&)>;
|
||||
@ -81,7 +81,7 @@ struct CAFFE2_API Type : std::enable_shared_from_this<Type> {
|
||||
protected:
|
||||
Type(TypeKind kind) : kind_(kind) {}
|
||||
|
||||
virtual std::string python_str_impl(TypePrinter printer) const {
|
||||
virtual std::string annotation_str_impl(TypePrinter printer) const {
|
||||
return str();
|
||||
}
|
||||
|
||||
@ -94,7 +94,7 @@ struct CAFFE2_API Type : std::enable_shared_from_this<Type> {
|
||||
// if this returns false and the why_not stream is non-null, it contains
|
||||
// additional details that describe why this is not a subtype of 'rhs'.
|
||||
// This additional information should only contain details that are not obvious
|
||||
// from the python_str() that describes the type. For instance it is clear that `int <: str` is false
|
||||
// from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false
|
||||
// but not clear why `Foo <: InterfaceBar` might be false.
|
||||
virtual bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const;
|
||||
virtual bool is_module() const;
|
||||
@ -111,19 +111,26 @@ struct CAFFE2_API Type : std::enable_shared_from_this<Type> {
|
||||
//
|
||||
// Takes a custom printer that users can pass in to customize the output of
|
||||
// this method.
|
||||
std::string python_str(TypePrinter printer) const {
|
||||
std::string annotation_str(TypePrinter printer) const {
|
||||
if (printer) {
|
||||
// the printer can return nullopt to fall through to the default impl
|
||||
if (auto renamed = printer(shared_from_this())) {
|
||||
return *renamed;
|
||||
}
|
||||
}
|
||||
return python_str_impl(printer);
|
||||
return annotation_str_impl(printer);
|
||||
}
|
||||
std::string python_str() const {
|
||||
std::string annotation_str() const {
|
||||
// Overload instead of define a default value for `printer` to help
|
||||
// debuggers out.
|
||||
return python_str(nullptr);
|
||||
return annotation_str(nullptr);
|
||||
}
|
||||
|
||||
// Returns a human readable string that includes additional information like
|
||||
// "type is inferred rather than explictly defined" to help construct more
|
||||
// user-friendly messages.
|
||||
virtual std::string repr_str() const {
|
||||
return annotation_str();
|
||||
}
|
||||
|
||||
TypeKind kind() const {
|
||||
@ -306,9 +313,9 @@ struct CAFFE2_API OptionalType
|
||||
private:
|
||||
OptionalType(TypePtr elem) : SingleElementType(elem) {}
|
||||
|
||||
std::string python_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "Optional[" << getElementType()->python_str(printer) << "]";
|
||||
ss << "Optional[" << getElementType()->annotation_str(printer) << "]";
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
@ -641,6 +648,10 @@ struct CAFFE2_API TensorType : public Type {
|
||||
|
||||
std::string str() const override;
|
||||
|
||||
std::string repr_str() const override {
|
||||
return str() + (isInferredType() ? " (inferred)" : "");
|
||||
}
|
||||
|
||||
c10::optional<size_t> numel() const {
|
||||
size_t prod = 1;
|
||||
const auto& shape = sizes();
|
||||
@ -852,9 +863,9 @@ struct CAFFE2_API ListType
|
||||
private:
|
||||
ListType(TypePtr elem) : SingleElementType(elem) {}
|
||||
|
||||
std::string python_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "List[" << getElementType()->python_str(printer) << "]";
|
||||
ss << "List[" << getElementType()->annotation_str(printer) << "]";
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
@ -928,10 +939,10 @@ struct CAFFE2_API DictType : public Type {
|
||||
has_free_variables(
|
||||
key->hasFreeVariables() || value->hasFreeVariables()) {}
|
||||
|
||||
std::string python_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "Dict[" << getKeyType()->python_str(printer) << ", "
|
||||
<< getValueType()->python_str(printer) << "]";
|
||||
ss << "Dict[" << getKeyType()->annotation_str(printer) << ", "
|
||||
<< getValueType()->annotation_str(printer) << "]";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -964,9 +975,9 @@ struct CAFFE2_API FutureType
|
||||
private:
|
||||
FutureType(TypePtr elem) : SingleElementType(elem) {}
|
||||
|
||||
std::string python_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "Future[" << getElementType()->python_str(printer) << "]";
|
||||
ss << "Future[" << getElementType()->annotation_str(printer) << "]";
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
@ -996,9 +1007,9 @@ struct CAFFE2_API RRefType
|
||||
private:
|
||||
RRefType(TypePtr elem) : SingleElementType(elem) {}
|
||||
|
||||
std::string python_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "RRef[" << getElementType()->python_str(printer) << "]";
|
||||
ss << "RRef[" << getElementType()->annotation_str(printer) << "]";
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
@ -1105,7 +1116,7 @@ struct CAFFE2_API TupleType : public NamedType {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string python_str_impl(TypePrinter printer = nullptr) const override;
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
|
||||
|
||||
std::vector<TypePtr> elements_;
|
||||
bool has_free_variables_;
|
||||
@ -1135,7 +1146,7 @@ struct CAFFE2_API NumberType : public Type {
|
||||
protected:
|
||||
NumberType(TypeKind kind = TypeKind::NumberType) : Type(kind) {}
|
||||
|
||||
std::string python_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
return "number"; // technically not a valid python type, but
|
||||
// we need to use it when parsing back in annotations
|
||||
// for implicit conversions
|
||||
@ -1164,7 +1175,7 @@ struct CAFFE2_API FloatType : public NumberType {
|
||||
|
||||
private:
|
||||
FloatType() : NumberType(TypeKind::FloatType) {}
|
||||
std::string python_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
return "float";
|
||||
}
|
||||
};
|
||||
@ -1191,7 +1202,7 @@ struct CAFFE2_API IntType : public NumberType {
|
||||
|
||||
private:
|
||||
IntType() : NumberType(TypeKind::IntType) {}
|
||||
std::string python_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
return "int";
|
||||
}
|
||||
};
|
||||
@ -1229,9 +1240,9 @@ struct CAFFE2_API StringType : public Type {
|
||||
}
|
||||
std::string str() const override {
|
||||
// we only use "str" (not "string") in both FunctionSchema and script
|
||||
return python_str();
|
||||
return annotation_str();
|
||||
}
|
||||
std::string python_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
return "str";
|
||||
}
|
||||
static const TypeKind Kind = TypeKind::StringType;
|
||||
@ -1266,7 +1277,7 @@ struct CAFFE2_API FunctionType : public NamedType {
|
||||
|
||||
private:
|
||||
FunctionType(torch::jit::Function* function);
|
||||
std::string python_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
const auto& n = name().value();
|
||||
return n.qualifiedName();
|
||||
}
|
||||
@ -1749,7 +1760,7 @@ struct CAFFE2_API ClassType : public NamedType {
|
||||
}
|
||||
|
||||
std::string str() const override {
|
||||
return python_str();
|
||||
return annotation_str();
|
||||
}
|
||||
|
||||
const std::vector<torch::jit::Function*>& methods() const;
|
||||
@ -1773,7 +1784,7 @@ struct CAFFE2_API ClassType : public NamedType {
|
||||
auto type = findAttribute(name);
|
||||
TORCH_CHECK(
|
||||
type,
|
||||
python_str(),
|
||||
repr_str(),
|
||||
" does not have an attribute with name '",
|
||||
name,
|
||||
"'");
|
||||
@ -1815,7 +1826,7 @@ struct CAFFE2_API ClassType : public NamedType {
|
||||
}
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
python_str(),
|
||||
repr_str(),
|
||||
" does not have an attribute with name '",
|
||||
name,
|
||||
"'");
|
||||
@ -1867,9 +1878,9 @@ struct CAFFE2_API ClassType : public NamedType {
|
||||
TypePtr atype = getAttribute(*slot_idx);
|
||||
TORCH_CHECK(
|
||||
ty->isSubtypeOf(atype),
|
||||
ty->python_str(),
|
||||
ty->repr_str(),
|
||||
" is not compatible with the type ",
|
||||
atype->python_str(),
|
||||
atype->repr_str(),
|
||||
" for the field '",
|
||||
name,
|
||||
"'");
|
||||
@ -1904,7 +1915,7 @@ struct CAFFE2_API ClassType : public NamedType {
|
||||
}
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
python_str(),
|
||||
repr_str(),
|
||||
" does not have constant field with the name '",
|
||||
name,
|
||||
"'");
|
||||
@ -2014,7 +2025,7 @@ struct CAFFE2_API ClassType : public NamedType {
|
||||
std::weak_ptr<CompilationUnit> cu,
|
||||
bool is_module);
|
||||
|
||||
std::string python_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
const auto& n = name().value();
|
||||
return n.qualifiedName();
|
||||
}
|
||||
@ -2095,7 +2106,7 @@ struct CAFFE2_API InterfaceType : public NamedType {
|
||||
const InterfaceType& rhs,
|
||||
std::ostream* why_not);
|
||||
|
||||
std::string python_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
return name()->qualifiedName();
|
||||
}
|
||||
|
||||
|
@ -268,9 +268,9 @@ c10::optional<TypePtr> unifyTypeList(
|
||||
auto maybe_unified = unifyTypes(ret_type, elements.at(i));
|
||||
if (!maybe_unified) {
|
||||
why_not << "Could not unify type list since element " << i << " of type "
|
||||
<< elements.at(i)->python_str()
|
||||
<< elements.at(i)->repr_str()
|
||||
<< " did not match the types before it ("
|
||||
<< ret_type->python_str() << ")";
|
||||
<< ret_type->repr_str() << ")";
|
||||
return c10::nullopt;
|
||||
}
|
||||
ret_type = maybe_unified.value();
|
||||
@ -300,8 +300,8 @@ MatchTypeReturn matchTypeVariables(
|
||||
}
|
||||
std::stringstream ss;
|
||||
ss << "Type variable '" << vt->name() << "' previously matched to type "
|
||||
<< it->second->python_str() << " is matched to type "
|
||||
<< actual->python_str();
|
||||
<< it->second->repr_str() << " is matched to type "
|
||||
<< actual->repr_str();
|
||||
return ss.str();
|
||||
} else if (auto lt_formal = formal->cast<ListType>()) {
|
||||
if (auto lt_actual = actual->cast<ListType>()) {
|
||||
@ -322,8 +322,8 @@ MatchTypeReturn matchTypeVariables(
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "Cannot match " << lt_formal->python_str() << " to "
|
||||
<< actual->python_str();
|
||||
ss << "Cannot match " << lt_formal->repr_str() << " to "
|
||||
<< actual->repr_str();
|
||||
return ss.str();
|
||||
} else if (auto tp_formal = formal->cast<TupleType>()) {
|
||||
if (auto tp_actual = actual->cast<TupleType>()) {
|
||||
@ -340,7 +340,7 @@ MatchTypeReturn matchTypeVariables(
|
||||
return MatchTypeReturn::Success();
|
||||
} else {
|
||||
std::stringstream ss;
|
||||
ss << "Cannot match a tuple to " << actual->python_str();
|
||||
ss << "Cannot match a tuple to " << actual->repr_str();
|
||||
return MatchTypeReturn(ss.str());
|
||||
}
|
||||
} else if (auto lt_formal = formal->cast<FutureType>()) {
|
||||
@ -353,7 +353,7 @@ MatchTypeReturn matchTypeVariables(
|
||||
return MatchTypeReturn::Success();
|
||||
} else {
|
||||
std::stringstream ss;
|
||||
ss << "Cannot match a future to " << actual->python_str();
|
||||
ss << "Cannot match a future to " << actual->repr_str();
|
||||
return ss.str();
|
||||
}
|
||||
} else if (auto lt_formal = formal->cast<RRefType>()) {
|
||||
@ -366,7 +366,7 @@ MatchTypeReturn matchTypeVariables(
|
||||
return MatchTypeReturn::Success();
|
||||
} else {
|
||||
std::stringstream ss;
|
||||
ss << "Cannot match a rref to " << actual->python_str();
|
||||
ss << "Cannot match a rref to " << actual->repr_str();
|
||||
return ss.str();
|
||||
}
|
||||
} else if (auto opt_formal = formal->cast<OptionalType>()) {
|
||||
@ -403,12 +403,12 @@ MatchTypeReturn matchTypeVariables(
|
||||
return MatchTypeReturn::Success();
|
||||
} else {
|
||||
std::stringstream ss;
|
||||
ss << "Cannot match a dict to " << actual->python_str();
|
||||
ss << "Cannot match a dict to " << actual->repr_str();
|
||||
return ss.str();
|
||||
}
|
||||
}
|
||||
|
||||
AT_ERROR("Unhandled free variable container: ", formal->python_str());
|
||||
AT_ERROR("Unhandled free variable container: ", formal->repr_str());
|
||||
}
|
||||
|
||||
// change return types like List[List[t]] into List[List[int]]
|
||||
@ -761,7 +761,7 @@ std::string TupleType::str() const {
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
std::string TupleType::python_str_impl(TypePrinter printer) const {
|
||||
std::string TupleType::annotation_str_impl(TypePrinter printer) const {
|
||||
std::stringstream ss;
|
||||
if (schema_ && name()) {
|
||||
ss << name()->qualifiedName();
|
||||
@ -770,7 +770,7 @@ std::string TupleType::python_str_impl(TypePrinter printer) const {
|
||||
for(size_t i = 0; i < elements().size(); ++i) {
|
||||
if(i > 0)
|
||||
ss << ", ";
|
||||
ss << elements()[i]->python_str(printer);
|
||||
ss << elements()[i]->annotation_str(printer);
|
||||
}
|
||||
ss << "]";
|
||||
}
|
||||
@ -978,7 +978,7 @@ void ClassType::addMethod(torch::jit::Function* method) {
|
||||
"Can't redefine method: ",
|
||||
method->name(),
|
||||
" on class: ",
|
||||
python_str());
|
||||
repr_str());
|
||||
methods_.push_back(method);
|
||||
}
|
||||
|
||||
@ -997,7 +997,7 @@ torch::jit::Function& ClassType::getMethod(const std::string& name) const {
|
||||
"Couldn't find method: '",
|
||||
name,
|
||||
"' on class: '",
|
||||
python_str(),
|
||||
repr_str(),
|
||||
"'");
|
||||
return *method;
|
||||
}
|
||||
@ -1019,7 +1019,7 @@ void ClassType::unsafeRemoveMethod(const std::string& name) {
|
||||
"Can't delete undefined method ",
|
||||
name,
|
||||
" on class: ",
|
||||
python_str());
|
||||
repr_str());
|
||||
}
|
||||
|
||||
ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
|
||||
@ -1047,8 +1047,8 @@ bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
|
||||
// Module Interface Type but the Class Type is not a Module Class Type
|
||||
if (!is_module() && iface->is_module()) {
|
||||
if (why_not) {
|
||||
*why_not << "Class '" << python_str() << "' is not a subtype of "
|
||||
<< "the module interface '" << rhs->python_str()
|
||||
*why_not << "Class '" << repr_str() << "' is not a subtype of "
|
||||
<< "the module interface '" << rhs->repr_str()
|
||||
<< "' , only ScriptModule class can be subtype of module"
|
||||
<< " interface.\n";
|
||||
}
|
||||
@ -1058,8 +1058,8 @@ bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
|
||||
auto self_method = findMethod(schema.name());
|
||||
if (!self_method) {
|
||||
if (why_not) {
|
||||
*why_not << "Class '" << python_str() << "' does not have method '"
|
||||
<< schema.name() << "' but '" << rhs->python_str()
|
||||
*why_not << "Class '" << repr_str() << "' does not have method '"
|
||||
<< schema.name() << "' but '" << rhs->repr_str()
|
||||
<< "' does.\n";
|
||||
}
|
||||
return false;
|
||||
@ -1067,9 +1067,9 @@ bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
|
||||
if (!self_method->getSchema().isSubtypeOf(
|
||||
schema, /*is_method=*/true, why_not)) {
|
||||
if (why_not) {
|
||||
*why_not << "Method on class '" << python_str()
|
||||
*why_not << "Method on class '" << repr_str()
|
||||
<< "' (1) is not compatible with interface '"
|
||||
<< rhs->python_str() << "' (2)\n"
|
||||
<< rhs->repr_str() << "' (2)\n"
|
||||
<< " (1) " << self_method->getSchema() << "\n"
|
||||
<< " (2) " << schema << "\n";
|
||||
}
|
||||
@ -1091,8 +1091,8 @@ bool InterfaceType::isSubTypeImpl(
|
||||
std::ostream* why_not) {
|
||||
if (!lhs.is_module() && rhs.is_module()) {
|
||||
if (why_not) {
|
||||
*why_not << "Interface '" << lhs.python_str() << "' is not a subtype of "
|
||||
<< "the module interface '" << rhs.python_str() << "'.\n";
|
||||
*why_not << "Interface '" << lhs.repr_str() << "' is not a subtype of "
|
||||
<< "the module interface '" << rhs.repr_str() << "'.\n";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@ -1100,17 +1100,17 @@ bool InterfaceType::isSubTypeImpl(
|
||||
auto self_schema = lhs.getMethod(schema.name());
|
||||
if (!self_schema) {
|
||||
if (why_not) {
|
||||
*why_not << "Interface '" << lhs.python_str()
|
||||
*why_not << "Interface '" << lhs.repr_str()
|
||||
<< "' does not have method '" << schema.name() << "' but interface '"
|
||||
<< rhs.python_str() << "' does.\n";
|
||||
<< rhs.repr_str() << "' does.\n";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if (!self_schema->isSubtypeOf(schema, /*is_method=*/true, why_not)) {
|
||||
if (why_not) {
|
||||
*why_not << "Method on interface '" << lhs.python_str()
|
||||
*why_not << "Method on interface '" << lhs.repr_str()
|
||||
<< "' (1) is not compatible with interface '"
|
||||
<< rhs.python_str() << "' (2)\n"
|
||||
<< rhs.repr_str() << "' (2)\n"
|
||||
<< " (1) " << *self_schema << "\n"
|
||||
<< " (2) " << schema << "\n";
|
||||
return false;
|
||||
@ -1178,7 +1178,7 @@ void ClassType::checkNotExist(const std::string& name, const std::string& what)
|
||||
" '",
|
||||
name,
|
||||
"' to ",
|
||||
python_str(),
|
||||
repr_str(),
|
||||
" but a constant field of the same name already exists with value ",
|
||||
constantValues_[i]);
|
||||
}
|
||||
@ -1192,9 +1192,9 @@ void ClassType::checkNotExist(const std::string& name, const std::string& what)
|
||||
" '",
|
||||
name,
|
||||
"' to ",
|
||||
python_str(),
|
||||
repr_str(),
|
||||
" but an attribute field of the same name already exists with type ",
|
||||
attributes_[i].getType()->python_str());
|
||||
attributes_[i].getType()->repr_str());
|
||||
}
|
||||
}
|
||||
|
||||
@ -1264,7 +1264,7 @@ IValue ClassType::getConstant(const std::string& name) const {
|
||||
const auto& v = findConstant(name);
|
||||
TORCH_CHECK(
|
||||
v.has_value(),
|
||||
python_str(),
|
||||
repr_str(),
|
||||
" does not have a constant field with name '",
|
||||
name,
|
||||
"'");
|
||||
@ -1275,7 +1275,7 @@ IValue ClassType::getConstant(size_t slot) const {
|
||||
TORCH_INTERNAL_ASSERT(constantNames_.size() == constantValues_.size());
|
||||
TORCH_CHECK(
|
||||
slot < constantValues_.size(),
|
||||
python_str(),
|
||||
repr_str(),
|
||||
" does not have a constant slot of index ",
|
||||
slot);
|
||||
return constantValues_[slot];
|
||||
@ -1336,9 +1336,9 @@ void checkNoAny(const Type& base, const char* what, const std::string& attrname,
|
||||
" '",
|
||||
attrname,
|
||||
"' of type ",
|
||||
attrtype->python_str(),
|
||||
attrtype->repr_str(),
|
||||
" to '",
|
||||
base.python_str(),
|
||||
base.repr_str(),
|
||||
"' but it contains an Any type. Any types cannot be members of modules, classes, or named tuples.");
|
||||
}
|
||||
|
||||
|
@ -19,12 +19,12 @@ TEST(TypeCustomPrinter, Basic) {
|
||||
// Tensor types should be rewritten
|
||||
torch::Tensor iv = torch::rand({2, 3});
|
||||
const auto type = TensorType::create(iv);
|
||||
EXPECT_EQ(type->python_str(), "Tensor");
|
||||
EXPECT_EQ(type->python_str(printer), "CustomTensor");
|
||||
EXPECT_EQ(type->annotation_str(), "Tensor");
|
||||
EXPECT_EQ(type->annotation_str(printer), "CustomTensor");
|
||||
|
||||
// Unrelated types shoudl not be affected
|
||||
const auto intType = IntType::create();
|
||||
EXPECT_EQ(intType->python_str(printer), intType->python_str());
|
||||
EXPECT_EQ(intType->annotation_str(printer), intType->annotation_str());
|
||||
}
|
||||
|
||||
TEST(TypeCustomPrinter, ContainedTypes) {
|
||||
@ -40,14 +40,14 @@ TEST(TypeCustomPrinter, ContainedTypes) {
|
||||
|
||||
// Contained types should work
|
||||
const auto tupleType = TupleType::create({type, IntType::get(), type});
|
||||
EXPECT_EQ(tupleType->python_str(), "Tuple[Tensor, int, Tensor]");
|
||||
EXPECT_EQ(tupleType->annotation_str(), "Tuple[Tensor, int, Tensor]");
|
||||
EXPECT_EQ(
|
||||
tupleType->python_str(printer), "Tuple[CustomTensor, int, CustomTensor]");
|
||||
tupleType->annotation_str(printer), "Tuple[CustomTensor, int, CustomTensor]");
|
||||
const auto dictType = DictType::create(IntType::get(), type);
|
||||
EXPECT_EQ(dictType->python_str(printer), "Dict[int, CustomTensor]");
|
||||
EXPECT_EQ(dictType->annotation_str(printer), "Dict[int, CustomTensor]");
|
||||
const auto listType = ListType::create(tupleType);
|
||||
EXPECT_EQ(
|
||||
listType->python_str(printer),
|
||||
listType->annotation_str(printer),
|
||||
"List[Tuple[CustomTensor, int, CustomTensor]]");
|
||||
}
|
||||
|
||||
@ -67,11 +67,11 @@ TEST(TypeCustomPrinter, NamedTuples) {
|
||||
|
||||
const auto namedTupleType = TupleType::createNamed(
|
||||
"my.named.tuple", {"foo", "bar"}, {type, IntType::get()});
|
||||
EXPECT_EQ(namedTupleType->python_str(printer), "Rewritten");
|
||||
EXPECT_EQ(namedTupleType->annotation_str(printer), "Rewritten");
|
||||
|
||||
// Put it inside another tuple, should still work
|
||||
const auto outerTupleType = TupleType::create({IntType::get(), namedTupleType});
|
||||
EXPECT_EQ(outerTupleType->python_str(printer), "Tuple[int, Rewritten]");
|
||||
EXPECT_EQ(outerTupleType->annotation_str(printer), "Tuple[int, Rewritten]");
|
||||
}
|
||||
|
||||
static TypePtr importType(
|
||||
|
@ -27,7 +27,7 @@ void testUnifyTypes() {
|
||||
TORCH_INTERNAL_ASSERT(out);
|
||||
|
||||
std::stringstream ss;
|
||||
ss << (*out)->python_str();
|
||||
ss << (*out)->annotation_str();
|
||||
testing::FileCheck()
|
||||
.check("Optional[Tuple[Optional[int], Optional[int]]]")
|
||||
->run(ss.str());
|
||||
|
@ -14,20 +14,20 @@ void testMobileTypeParser() {
|
||||
|
||||
std::string int_ps("int");
|
||||
auto int_tp = c10::parseType(int_ps);
|
||||
std::string int_tps = int_tp->python_str();
|
||||
std::string int_tps = int_tp->annotation_str();
|
||||
ASSERT_EQ(int_ps, int_tps);
|
||||
|
||||
std::string tuple_ps(
|
||||
"Tuple[str, Optional[float], Dict[str, List[Tensor]], int]");
|
||||
auto tuple_tp = c10::parseType(tuple_ps);
|
||||
std::string tuple_tps = tuple_tp->python_str();
|
||||
std::string tuple_tps = tuple_tp->annotation_str();
|
||||
ASSERT_EQ(tuple_ps, tuple_tps);
|
||||
|
||||
std::string tuple_space_ps(
|
||||
"Tuple[ str, Optional[float], Dict[str, List[Tensor ]] , int]");
|
||||
auto tuple_space_tp = c10::parseType(tuple_space_ps);
|
||||
// tuple_space_tps should not have weird white spaces
|
||||
std::string tuple_space_tps = tuple_space_tp->python_str();
|
||||
std::string tuple_space_tps = tuple_space_tp->annotation_str();
|
||||
ASSERT_EQ(tuple_ps, tuple_space_tps);
|
||||
|
||||
std::string typo_token("List[tensor]");
|
||||
|
@ -17623,7 +17623,8 @@ a")
|
||||
def foo(a):
|
||||
return a
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Inferred \'a\' to be of type \'Tensor"):
|
||||
with self.assertRaisesRegex(RuntimeError, (r"Expected a value of type \'Tensor \(inferred\)\'"
|
||||
r"[\S\s]*Inferred \'a\' to be of type \'Tensor\'")):
|
||||
foo(1)
|
||||
|
||||
def test_type_comments_in_body(self):
|
||||
|
@ -70,7 +70,7 @@ TypePtr tryInferTypeWithTypeHint(
|
||||
type_hint_ptr != nullptr &&
|
||||
module.value().type()->isSubtypeOfExt(
|
||||
type_hint_ptr, &subtype_check_msg),
|
||||
module.value().type()->python_str(),
|
||||
module.value().type()->repr_str(),
|
||||
" is not a subtype of the type hint: ",
|
||||
type_qualified_name.qualifiedName(),
|
||||
", did you pass a valid interface type?\n",
|
||||
|
@ -336,12 +336,12 @@ c10::intrusive_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
ownerRRef->type()->isSubtypeOf(TensorType::get()),
|
||||
"Expect OwnerRRef to be a sub-type of TensorType, but got ",
|
||||
ownerRRef->type()->python_str());
|
||||
ownerRRef->type()->repr_str());
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
ownerRRef->type() == type,
|
||||
"OwnerRRef type is ",
|
||||
ownerRRef->type()->python_str(),
|
||||
ownerRRef->type()->repr_str(),
|
||||
", expected type is ",
|
||||
type);
|
||||
}
|
||||
|
@ -40,11 +40,11 @@ struct TORCH_API Object {
|
||||
TORCH_CHECK(
|
||||
v.type()->isSubtypeOf(expected),
|
||||
"Expected a value of type '",
|
||||
expected->python_str(),
|
||||
expected->repr_str(),
|
||||
"' for field '",
|
||||
name,
|
||||
"', but found '",
|
||||
v.type()->python_str(),
|
||||
v.type()->repr_str(),
|
||||
"'");
|
||||
_ivalue()->setSlot(*slot, std::move(v));
|
||||
} else {
|
||||
@ -61,7 +61,7 @@ struct TORCH_API Object {
|
||||
}
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
_ivalue()->type()->python_str(),
|
||||
_ivalue()->type()->repr_str(),
|
||||
" does not have a field with name '",
|
||||
name,
|
||||
"'");
|
||||
|
@ -258,13 +258,13 @@ void ConcreteModuleType::dump() const {
|
||||
}
|
||||
std::cout << "\nAttributes: \n";
|
||||
for (const auto& pr : data_.attributes_) {
|
||||
std::cout << "\t" << pr.key() << ": " << pr.value().type_->python_str()
|
||||
std::cout << "\t" << pr.key() << ": " << pr.value().type_->annotation_str()
|
||||
<< "\n";
|
||||
}
|
||||
std::cout << "\nSubmodules: \n";
|
||||
for (const auto& info : data_.modules_) {
|
||||
std::cout << "\t" << info.name_ << ": "
|
||||
<< info.meta_->getJitType()->python_str() << "\n";
|
||||
<< info.meta_->getJitType()->annotation_str() << "\n";
|
||||
}
|
||||
std::cout << "\nOverloads: \n";
|
||||
for (const auto& pr : data_.overloads_) {
|
||||
@ -273,7 +273,7 @@ void ConcreteModuleType::dump() const {
|
||||
std::string isPoisoned = data_.isPoisoned_ ? "true" : "false";
|
||||
std::cout << "isPoisoned: " << isPoisoned << "\n";
|
||||
if (jitType_) {
|
||||
std::cout << "jit type: " << jitType_->python_str() << "\n";
|
||||
std::cout << "jit type: " << jitType_->annotation_str() << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -374,9 +374,9 @@ struct Environment {
|
||||
if (!as_simple_value->type()->isSubtypeOfExt(parent_type, &why_not)) {
|
||||
auto error = ErrorReport(loc);
|
||||
error << "Variable '" << name << "' previously has type "
|
||||
<< simple_parent->type()->python_str()
|
||||
<< simple_parent->type()->repr_str()
|
||||
<< " but is now being assigned to a value of type "
|
||||
<< as_simple_value->type()->python_str();
|
||||
<< as_simple_value->type()->repr_str();
|
||||
|
||||
// Special-cased error msg if we're trying to assign to a tensor list.
|
||||
if (simple_parent->type()->kind() == TypeKind::ListType &&
|
||||
@ -397,9 +397,9 @@ struct Environment {
|
||||
if (!as_simple_value->type()->isSubtypeOf(annotated_type)) {
|
||||
throw ErrorReport(loc)
|
||||
<< "Variable '" << name << "' is annotated with type "
|
||||
<< annotated_type->python_str()
|
||||
<< annotated_type->repr_str()
|
||||
<< " but is being assigned to a value of type "
|
||||
<< as_simple_value->type()->python_str();
|
||||
<< as_simple_value->type()->repr_str();
|
||||
}
|
||||
insertStore(name, loc, std::move(as_simple_value), annotated_type);
|
||||
} else {
|
||||
@ -972,8 +972,8 @@ struct to_ir {
|
||||
if (!result->type()->isSubtypeOf(result_type)) {
|
||||
throw ErrorReport(stmt.range())
|
||||
<< "Return value was annotated as having type "
|
||||
<< result_type->python_str() << " but is actually of type "
|
||||
<< result->type()->python_str();
|
||||
<< result_type->repr_str() << " but is actually of type "
|
||||
<< result->type()->repr_str();
|
||||
}
|
||||
} else {
|
||||
result_type = def_stack_.back().merged_return_type_;
|
||||
@ -984,9 +984,9 @@ struct to_ir {
|
||||
if (!merged_result_type) {
|
||||
throw ErrorReport(stmt.range())
|
||||
<< "Previous return statement returned a value of type "
|
||||
<< result_type->python_str()
|
||||
<< result_type->repr_str()
|
||||
<< " but this return statement returns a value of type "
|
||||
<< result->type()->python_str();
|
||||
<< result->type()->repr_str();
|
||||
}
|
||||
result_type = merged_result_type.value();
|
||||
}
|
||||
@ -1208,7 +1208,7 @@ struct to_ir {
|
||||
throw ErrorReport(loc)
|
||||
<< "Expected list type annotation for list comprehension"
|
||||
", found "
|
||||
<< type_hint->python_str();
|
||||
<< type_hint->repr_str();
|
||||
}
|
||||
list_value->setType(type_hint);
|
||||
type_set = true;
|
||||
@ -1326,8 +1326,8 @@ struct to_ir {
|
||||
auto unified = unifyTypes(true_type, false_type);
|
||||
if (!unified) {
|
||||
throw ErrorReport(range)
|
||||
<< "if-expression's true branch has type " << true_type->python_str()
|
||||
<< " but false branch has type " << false_type->python_str();
|
||||
<< "if-expression's true branch has type " << true_type->repr_str()
|
||||
<< " but false branch has type " << false_type->repr_str();
|
||||
}
|
||||
|
||||
// Add op outputs
|
||||
@ -1342,13 +1342,13 @@ struct to_ir {
|
||||
out = asSimple(bool_cast->call(loc, method, {v}, {}, 0));
|
||||
} catch (...) {
|
||||
throw ErrorReport(loc) << "Could not cast value of type "
|
||||
<< v->type()->python_str() << " to bool";
|
||||
<< v->type()->repr_str() << " to bool";
|
||||
}
|
||||
// cast value not response for checking output type
|
||||
if (!out->type()->isSubtypeOf(BoolType::get())) {
|
||||
throw ErrorReport(loc)
|
||||
<< "expected a bool expression for condition but found "
|
||||
<< out->type()->python_str();
|
||||
<< out->type()->repr_str();
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -1507,8 +1507,8 @@ struct to_ir {
|
||||
if (!unified) {
|
||||
ErrorReport error(loc);
|
||||
error << "Type mismatch: " << x << " is set to type "
|
||||
<< tv->type()->python_str() << " in the true branch"
|
||||
<< " and type " << fv->type()->python_str()
|
||||
<< tv->type()->repr_str() << " in the true branch"
|
||||
<< " and type " << fv->type()->repr_str()
|
||||
<< " in the false branch";
|
||||
if (save_true->findInParentFrame(x) ||
|
||||
save_false->findInParentFrame(x)) {
|
||||
@ -1922,7 +1922,7 @@ struct to_ir {
|
||||
magic_method_name = out_of_place_method_name;
|
||||
} else {
|
||||
throw ErrorReport(stmt.range())
|
||||
<< "Cannot emit inplace op on " << type->python_str()
|
||||
<< "Cannot emit inplace op on " << type->repr_str()
|
||||
<< " since it does not define an " << in_place_method_name << " or "
|
||||
<< out_of_place_method_name << " method";
|
||||
}
|
||||
@ -1954,7 +1954,7 @@ struct to_ir {
|
||||
const TypePtr type = sliceable->type();
|
||||
if (subscriptExprs.size() != 1) {
|
||||
throw ErrorReport(subscriptExprs)
|
||||
<< "Sliced expression not yet supported for " << type->python_str()
|
||||
<< "Sliced expression not yet supported for " << type->repr_str()
|
||||
<< " augmented assignment. "
|
||||
<< "File a bug if you want this";
|
||||
}
|
||||
@ -1968,7 +1968,7 @@ struct to_ir {
|
||||
|
||||
if (elemType == nullptr) {
|
||||
throw ErrorReport(lhs)
|
||||
<< type->python_str() << " does not support augmented assignment.";
|
||||
<< type->repr_str() << " does not support augmented assignment.";
|
||||
}
|
||||
const auto idxValue = emitExpr(subscriptExprs[0]);
|
||||
const auto containerArg =
|
||||
@ -2541,8 +2541,8 @@ struct to_ir {
|
||||
std::stringstream why_not;
|
||||
if (!expr->type()->isSubtypeOfExt(type, &why_not)) {
|
||||
throw ErrorReport(apply.inputs())
|
||||
<< "expected an expression of type " << type->python_str()
|
||||
<< " but found " << expr->type()->python_str() << "\n"
|
||||
<< "expected an expression of type " << type->repr_str()
|
||||
<< " but found " << expr->type()->repr_str() << "\n"
|
||||
<< why_not.str();
|
||||
}
|
||||
|
||||
@ -3050,7 +3050,7 @@ struct to_ir {
|
||||
// If the type hint was not a List[T] throw an error
|
||||
throw ErrorReport(tree)
|
||||
<< "Expected a List type hint but instead got "
|
||||
<< type_hint->python_str();
|
||||
<< type_hint->repr_str();
|
||||
}
|
||||
} else if (!values.empty()) {
|
||||
std::stringstream ss;
|
||||
@ -3068,8 +3068,8 @@ struct to_ir {
|
||||
if (!v->type()->isSubtypeOfExt(elem_type, &ss)) {
|
||||
throw ErrorReport(tree)
|
||||
<< "Lists must contain only a single type, expected: "
|
||||
<< elem_type->python_str() << " but found "
|
||||
<< v->type()->python_str() << " instead.\n"
|
||||
<< elem_type->repr_str() << " but found "
|
||||
<< v->type()->repr_str() << " instead.\n"
|
||||
<< ss.str();
|
||||
}
|
||||
}
|
||||
@ -3120,8 +3120,8 @@ struct to_ir {
|
||||
throw ErrorReport(trees[i])
|
||||
<< "Dict " << what
|
||||
<< " must contain only a single type, expected: "
|
||||
<< type->python_str() << " but found "
|
||||
<< values[i]->type()->python_str() << " instead.\n"
|
||||
<< type->repr_str() << " but found "
|
||||
<< values[i]->type()->repr_str() << " instead.\n"
|
||||
<< ss.str();
|
||||
}
|
||||
}
|
||||
@ -3329,7 +3329,7 @@ struct to_ir {
|
||||
} else {
|
||||
throw ErrorReport(loc)
|
||||
<< "Unsupported operation: indexing tensor with unsupported index type '"
|
||||
<< index->type()->python_str()
|
||||
<< index->type()->repr_str()
|
||||
<< "'. Only ints, slices, lists and tensors are supported";
|
||||
}
|
||||
};
|
||||
@ -3485,7 +3485,7 @@ struct to_ir {
|
||||
if (elems.size() == 0 ||
|
||||
!convertibleToList(tuple_typ, ListType::create(elems[0]))) {
|
||||
throw ErrorReport(loc)
|
||||
<< "Cannot index into a " << tuple_typ->python_str()
|
||||
<< "Cannot index into a " << tuple_typ->repr_str()
|
||||
<< " with a non-integer literal because we cannot resolve the output type";
|
||||
}
|
||||
output_type = elems[0];
|
||||
|
@ -148,8 +148,8 @@ static Value* tryMatchArgument(
|
||||
matchTypeVariables(arg.type(), value->type(), type_env);
|
||||
if (!matched.success()) {
|
||||
if (failure_messages) {
|
||||
err() << "Could not match type " << value->type()->python_str() << " to "
|
||||
<< arg.type()->python_str() << " in argument '" << arg.name()
|
||||
err() << "Could not match type " << value->type()->repr_str() << " to "
|
||||
<< arg.type()->repr_str() << " in argument '" << arg.name()
|
||||
<< "': " << matched.reason() << ".\n";
|
||||
}
|
||||
return nullptr;
|
||||
@ -157,9 +157,9 @@ static Value* tryMatchArgument(
|
||||
const auto concrete_type = tryEvalTypeVariables(arg.type(), type_env);
|
||||
if (!concrete_type) {
|
||||
if (failure_messages) {
|
||||
err() << "Type variables in type " << arg.type()->python_str()
|
||||
err() << "Type variables in type " << arg.type()->repr_str()
|
||||
<< " could not be inferred from actual type "
|
||||
<< value->type()->python_str();
|
||||
<< value->type()->repr_str();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@ -172,7 +172,7 @@ static Value* tryMatchArgument(
|
||||
concrete_type, /*why_not=*/(failure_messages) ? &ss : nullptr)) {
|
||||
if (failure_messages) {
|
||||
auto& ostream = err()
|
||||
<< arg.formatTypeMismatchMsg(value->type()->python_str());
|
||||
<< arg.formatTypeMismatchMsg(value->type()->repr_str());
|
||||
|
||||
if (auto pt = value->type()->cast<TensorType>()) {
|
||||
if (pt->isInferredType()) {
|
||||
@ -414,7 +414,7 @@ static c10::optional<MatchedSchema> tryMatchSchema(
|
||||
auto return_types = fmap(returns, [&](const Argument& r) {
|
||||
TypePtr result = tryEvalTypeVariables(r.type(), type_env);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
result, r.type()->python_str(), " has unbound type variables.");
|
||||
result, r.type()->repr_str(), " has unbound type variables.");
|
||||
return result;
|
||||
});
|
||||
// Codegen does not support return of namedtuples with undefined field names.
|
||||
|
@ -70,7 +70,7 @@ bool SimpleValue::hasAttr(
|
||||
auto class_type = value_->type()->cast<ClassType>();
|
||||
if (!class_type) {
|
||||
throw ErrorReport(loc) << "hasattr's first argument must be an object, got "
|
||||
<< value_->type()->python_str() << " instead";
|
||||
<< value_->type()->repr_str() << " instead";
|
||||
}
|
||||
|
||||
return class_type->hasMethod(field) || class_type->hasAttribute(field) ||
|
||||
@ -176,7 +176,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
|
||||
|
||||
ErrorReport report(loc);
|
||||
report << "Tried to access nonexistent attribute or method '" << field
|
||||
<< "' of type '" << value_->type()->python_str() << "'.";
|
||||
<< "' of type '" << value_->type()->repr_str() << "'.";
|
||||
if (value_->type()->kind() == ClassType::Kind) {
|
||||
report << " Did you forget to initialize an attribute in __init__()?";
|
||||
}
|
||||
@ -205,7 +205,7 @@ std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
|
||||
graph->insertNode(graph->createListUnpack(value_, *size_hint));
|
||||
return fmap(unpack->outputs(), make_simple_value);
|
||||
}
|
||||
throw ErrorReport(loc) << value_->type()->python_str()
|
||||
throw ErrorReport(loc) << value_->type()->repr_str()
|
||||
<< " cannot be used as a tuple";
|
||||
}
|
||||
|
||||
@ -232,8 +232,7 @@ void SimpleValue::setAttr(
|
||||
const auto classType = value_->type()->cast<ClassType>();
|
||||
if (!classType) {
|
||||
throw ErrorReport(loc) << "Tried to set an attribute: " << field
|
||||
<< " on a non-class: "
|
||||
<< value_->type()->python_str();
|
||||
<< " on a non-class: " << value_->type()->repr_str();
|
||||
}
|
||||
auto expectedType = classType->findAttribute(field);
|
||||
if (!expectedType) {
|
||||
@ -255,7 +254,7 @@ void SimpleValue::setAttr(
|
||||
throw ErrorReport(loc)
|
||||
<< "Assignment to attribute '" << field
|
||||
<< "' cannot be of a type that contains class "
|
||||
<< "'" << classType->python_str() << "'.\n"
|
||||
<< "'" << classType->repr_str() << "'.\n"
|
||||
<< "Classes that recursively contain instances of themselves"
|
||||
<< " are not yet supported";
|
||||
}
|
||||
@ -283,8 +282,8 @@ void SimpleValue::setAttr(
|
||||
const auto newType = newValue->type();
|
||||
if (!newType->isSubtypeOf(expectedType)) {
|
||||
throw ErrorReport(loc) << "Wrong type for attribute assignment. Expected "
|
||||
<< expectedType->python_str() << " but got "
|
||||
<< newType->python_str();
|
||||
<< expectedType->repr_str() << " but got "
|
||||
<< newType->repr_str();
|
||||
}
|
||||
|
||||
auto& g = *m.graph();
|
||||
@ -341,7 +340,7 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) {
|
||||
val_type->isSubtypeOf(TensorType::get())) {
|
||||
return g.insert(aten::len, {val}, {}, loc);
|
||||
} else {
|
||||
throw ErrorReport(loc) << "'" << val_type->python_str() << "'"
|
||||
throw ErrorReport(loc) << "'" << val_type->repr_str() << "'"
|
||||
<< " object is not iterable";
|
||||
}
|
||||
}
|
||||
@ -367,7 +366,7 @@ SugaredValuePtr SimpleValue::getitem(
|
||||
} else if (auto class_type = val_type->cast<ClassType>()) {
|
||||
return attr(loc, m, "__getitem__")->call(loc, m, {idx}, {}, 1);
|
||||
} else {
|
||||
throw ErrorReport(loc) << "'" << val_type->python_str() << "'"
|
||||
throw ErrorReport(loc) << "'" << val_type->repr_str() << "'"
|
||||
<< " object is not subscriptable";
|
||||
}
|
||||
}
|
||||
@ -393,7 +392,7 @@ SugaredValuePtr SimpleValue::iter(const SourceRange& loc, Function& m) {
|
||||
}
|
||||
return std::make_shared<SugaredTupleValue>(tup_sugared);
|
||||
} else {
|
||||
throw ErrorReport(loc) << "'" << type->python_str() << "'"
|
||||
throw ErrorReport(loc) << "'" << type->repr_str() << "'"
|
||||
<< " object is not iterable";
|
||||
}
|
||||
}
|
||||
@ -407,7 +406,7 @@ RangeValue::RangeValue(
|
||||
auto typ = inputs[i]->type();
|
||||
if (!typ->cast<IntType>()) {
|
||||
throw ErrorReport(loc)
|
||||
<< "all inputs of range must be ints, found " << typ->python_str()
|
||||
<< "all inputs of range must be ints, found " << typ->repr_str()
|
||||
<< " in argument " << c10::guts::to_string(i);
|
||||
}
|
||||
}
|
||||
|
@ -147,7 +147,7 @@ struct TORCH_API SimpleValue : public SugaredValue {
|
||||
SimpleValue(Value* value) : value_(value) {}
|
||||
std::string kind() const override {
|
||||
std::stringstream ss;
|
||||
ss << "value of type '" << value_->type()->python_str() << "'";
|
||||
ss << "value of type '" << value_->type()->annotation_str() << "'";
|
||||
return ss.str();
|
||||
}
|
||||
Value* asValue(const SourceRange& range, Function& m) override {
|
||||
|
@ -367,7 +367,7 @@ static IValue addInput(
|
||||
AT_ERROR(
|
||||
"Only tensors or (possibly nested) dict or tuples of tensors can be "
|
||||
"inputs to traced functions. Got ",
|
||||
type->python_str());
|
||||
type->repr_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1081,9 +1081,9 @@ void AliasDb::replaceWithNewValue(Value* existing, Value* new_value) {
|
||||
*unshapedType(existing->type()) == *unshapedType(new_value->type()),
|
||||
"Types must be strictly equal if you are replacing aliasing information. ",
|
||||
"Got existing: '",
|
||||
existing->type()->python_str(),
|
||||
existing->type()->repr_str(),
|
||||
"', new_value: '",
|
||||
new_value->type()->python_str(),
|
||||
new_value->type()->repr_str(),
|
||||
"'");
|
||||
if (!isMutableTypeInternal(existing)) {
|
||||
return;
|
||||
@ -1099,9 +1099,9 @@ void AliasDb::copyValue(Value* from, Value* to) {
|
||||
*unshapedType(from->type()) == *unshapedType(to->type()),
|
||||
"Types must be strictly equal if you are copying aliasing information. ",
|
||||
"Got from: '",
|
||||
from->type()->python_str(),
|
||||
from->type()->repr_str(),
|
||||
"', to: '",
|
||||
to->type()->python_str(),
|
||||
to->type()->repr_str(),
|
||||
"'");
|
||||
if (!isMutableTypeInternal(to)) {
|
||||
return;
|
||||
@ -1557,8 +1557,8 @@ void Lint(const AliasDb* db) {
|
||||
auto it = db->elementMap_.find(v);
|
||||
if (it == db->elementMap_.end()) {
|
||||
failed = true;
|
||||
ss << "Value %" << v->debugName() << " of type "
|
||||
<< v->type()->python_str() << " wasn't found in the element map.\n"
|
||||
ss << "Value %" << v->debugName() << " of type " << v->type()->repr_str()
|
||||
<< " wasn't found in the element map.\n"
|
||||
<< "It was defined in " << *v->node();
|
||||
}
|
||||
}
|
||||
|
@ -1599,9 +1599,9 @@ Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef<Value*> values) {
|
||||
TORCH_CHECK(
|
||||
v->type()->isSubtypeOf(elem_type),
|
||||
"Expected a list element that subtypes '",
|
||||
elem_type->python_str(),
|
||||
elem_type->repr_str(),
|
||||
"' but got an element of type '",
|
||||
v->type()->python_str(),
|
||||
v->type()->repr_str(),
|
||||
"'");
|
||||
}
|
||||
n->output()->setType(ListType::create(elem_type));
|
||||
@ -1714,7 +1714,7 @@ Value* Graph::insertToList(Value* v, TypePtr type) {
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
ptr->python_str(),
|
||||
ptr->repr_str(),
|
||||
" is not one of the supported element types for tolist: int, float, bool");
|
||||
}
|
||||
|
||||
|
@ -368,10 +368,9 @@ void IRParser::parseOperator(Block* b) {
|
||||
if (!schema_return_type->hasFreeVariables() &&
|
||||
!v.type->isSubtypeOf(schema_return_type)) {
|
||||
throw ErrorReport(source_range)
|
||||
<< "Annotated type " << v.type->python_str()
|
||||
<< "Annotated type " << v.type->repr_str()
|
||||
<< " does not match schema type "
|
||||
<< schema_return_type->python_str() << " for operator "
|
||||
<< *schema;
|
||||
<< schema_return_type->repr_str() << " for operator " << *schema;
|
||||
}
|
||||
vmap[v.name]->setType(v.type);
|
||||
}
|
||||
|
@ -136,7 +136,7 @@ static std::vector<IValue> loadTensors(const std::vector<Slot>& slots) {
|
||||
getCustomClass(
|
||||
"__torch__.torch.classes.quantized.LinearPackedParamsBase")),
|
||||
"Unknown type ",
|
||||
type->python_str(),
|
||||
type->repr_str(),
|
||||
" encountered in graph lowering. This type is not supported in ONNX export.");
|
||||
result.emplace_back(
|
||||
script::Object(obj.toObject()).run_method("__getstate__"));
|
||||
|
@ -351,9 +351,9 @@ inline InferredType tryToInferContainerType(py::handle input) {
|
||||
if (!unified_key) {
|
||||
return InferredType(c10::str(
|
||||
"Dictionary inputs to traced functions must have consistent type. Found ",
|
||||
key_type->python_str(),
|
||||
key_type->repr_str(),
|
||||
" and ",
|
||||
(entry_key_type_match.type())->python_str()));
|
||||
(entry_key_type_match.type())->repr_str()));
|
||||
}
|
||||
|
||||
// Try to infer the value type and unify it with the existing one
|
||||
@ -366,9 +366,9 @@ inline InferredType tryToInferContainerType(py::handle input) {
|
||||
if (!unified_value) {
|
||||
return InferredType(c10::str(
|
||||
"Dictionary inputs to traced functions must have consistent type. Found ",
|
||||
value_type->python_str(),
|
||||
value_type->repr_str(),
|
||||
" and ",
|
||||
(entry_value_type_match.type())->python_str()));
|
||||
(entry_value_type_match.type())->repr_str()));
|
||||
}
|
||||
|
||||
key_type = *unified_key;
|
||||
@ -395,9 +395,9 @@ inline InferredType tryToInferContainerType(py::handle input) {
|
||||
if (!unified_type) {
|
||||
return InferredType(c10::str(
|
||||
"List inputs to traced functions must have consistent element type. Found ",
|
||||
element_type->python_str(),
|
||||
element_type->repr_str(),
|
||||
" and ",
|
||||
(element_type_match.type())->python_str()));
|
||||
(element_type_match.type())->repr_str()));
|
||||
}
|
||||
element_type = *unified_type;
|
||||
}
|
||||
@ -453,7 +453,7 @@ inline Stack toTraceableStack(const py::tuple& inputs) {
|
||||
TORCH_CHECK(
|
||||
isTraceableType(info.type()),
|
||||
"Type '",
|
||||
info.type()->python_str(),
|
||||
info.type()->repr_str(),
|
||||
"' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and"
|
||||
" Tuples of Tensors can be traced");
|
||||
return info.toTuple()->elements();
|
||||
@ -542,7 +542,7 @@ inline IValue toIValue(
|
||||
"Object ",
|
||||
py::str(obj),
|
||||
" had a different number of elements than type ",
|
||||
type->python_str()));
|
||||
type->repr_str()));
|
||||
}
|
||||
std::vector<IValue> values;
|
||||
values.reserve(tuple_size);
|
||||
@ -669,7 +669,7 @@ inline IValue toIValue(
|
||||
"Object ",
|
||||
py::str(obj),
|
||||
" is not compatible with interface ",
|
||||
interfaceType->python_str(),
|
||||
interfaceType->repr_str(),
|
||||
"\n",
|
||||
why_not.str()));
|
||||
}
|
||||
@ -694,7 +694,7 @@ inline IValue toIValue(
|
||||
return py::cast<double>(obj);
|
||||
} else {
|
||||
throw py::cast_error(
|
||||
c10::str("Cannot cast ", py::str(obj), " to ", type->python_str()));
|
||||
c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str()));
|
||||
}
|
||||
}
|
||||
case TypeKind::RRefType: {
|
||||
@ -726,7 +726,7 @@ inline IValue toIValue(
|
||||
break;
|
||||
}
|
||||
throw py::cast_error(c10::str(
|
||||
"toIValue() cannot handle converting to type: ", type->python_str()));
|
||||
"toIValue() cannot handle converting to type: ", type->repr_str()));
|
||||
}
|
||||
|
||||
// Small wrapper around getting the type name string from Python to make
|
||||
|
@ -18,7 +18,7 @@ py::object ScriptClass::__call__(py::args args, py::kwargs kwargs) {
|
||||
fmt::format(
|
||||
"Custom C++ class: '{}' does not have an '__init__' method bound. "
|
||||
"Did you forget to add '.def(torch::init<...>)' to its registration?",
|
||||
instance.type()->python_str()));
|
||||
instance.type()->repr_str()));
|
||||
Method init_method(instance._ivalue(), init_fn);
|
||||
invokeScriptMethodFromPython(init_method, std::move(args), std::move(kwargs));
|
||||
return py::cast(instance);
|
||||
|
@ -641,7 +641,7 @@ void initPythonIRBindings(PyObject* module_) {
|
||||
|
||||
using ::c10::Type;
|
||||
py::class_<Type, std::shared_ptr<Type>>(m, "Type")
|
||||
.def("__repr__", [](Type& t) { return t.python_str(); })
|
||||
.def("__repr__", [](Type& t) { return t.annotation_str(); })
|
||||
.def(
|
||||
"str",
|
||||
[](Type& t) {
|
||||
|
@ -671,7 +671,7 @@ TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc) {
|
||||
TORCH_CHECK(
|
||||
type->isSubtypeOf(tt),
|
||||
"Can't to redefine NamedTuple: ",
|
||||
tt->python_str());
|
||||
tt->repr_str());
|
||||
return type;
|
||||
}
|
||||
get_python_cu()->register_type(tt);
|
||||
|
@ -260,7 +260,7 @@ FunctionSchema getSchemaWithNameAndDefaults(
|
||||
c10::optional<IValue> value = tryCalculateDefaultParam(arg, it->second);
|
||||
if (!value) {
|
||||
ErrorReport error(range);
|
||||
error << "Expected a default value of type " << arg.type()->python_str()
|
||||
error << "Expected a default value of type " << arg.type()->repr_str()
|
||||
<< " on parameter \"" << arg.name() << "\".";
|
||||
if (arg.is_inferred_type()) {
|
||||
error << "Because \"" << arg.name()
|
||||
@ -799,7 +799,7 @@ void initJitScriptBindings(PyObject* module) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
setstate_schema.arguments().size() == 2,
|
||||
"__setstate__ method for class ",
|
||||
class_type->python_str(),
|
||||
class_type->repr_str(),
|
||||
" must have exactly 2 arguments!");
|
||||
auto state_type = setstate_schema.arguments().at(1).type();
|
||||
(*setstate_method)(Stack{toIValue(state, state_type)});
|
||||
|
@ -158,7 +158,7 @@ IValue tensorToListRecursive(
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
ty->python_str(),
|
||||
ty->repr_str(),
|
||||
" is not one of the supported types for tolist: int, float, bool");
|
||||
}
|
||||
}
|
||||
|
@ -1319,16 +1319,16 @@ Function* checkSortSchema(const c10::TypePtr& list_element_type) {
|
||||
return method;
|
||||
}
|
||||
}
|
||||
error_str << "To sort a list of " << class_type->python_str()
|
||||
error_str << "To sort a list of " << class_type->repr_str()
|
||||
<< " it must define a "
|
||||
<< "__lt__ method with two inputs of type "
|
||||
<< class_type->python_str() << " that "
|
||||
<< class_type->repr_str() << " that "
|
||||
<< "returns a bool";
|
||||
} else {
|
||||
error_str << "To sort a list of " << list_element_type->python_str()
|
||||
error_str << "To sort a list of " << list_element_type->repr_str()
|
||||
<< " must be of Tensors, ints, floats, bools or "
|
||||
<< "a User Defined Class that defines the __lt__ compare method"
|
||||
<< ", got list of " << list_element_type->python_str() << "\n";
|
||||
<< ", got list of " << list_element_type->repr_str() << "\n";
|
||||
}
|
||||
throw std::runtime_error(error_str.str());
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) {
|
||||
elem_type != BoolType::get()) {
|
||||
std::stringstream error;
|
||||
error << "Input must be of ints, floats, or bools, "
|
||||
<< "got " << elem_type->python_str();
|
||||
<< "got " << elem_type->repr_str();
|
||||
// special case empty list torch.tensor([])
|
||||
if (elem_type->isSubtypeOf(TensorType::get())) {
|
||||
if (empty_list) {
|
||||
@ -220,11 +220,11 @@ int createTensorFromList(Stack& stack) {
|
||||
tensor.numel() == 0) {
|
||||
TORCH_WARN(
|
||||
"Creating a tensor from an empty ",
|
||||
elem_type->python_str(),
|
||||
elem_type->repr_str(),
|
||||
"list will create a tensor of default floating point type (currently ",
|
||||
default_type,
|
||||
") in python but a tensor of type ",
|
||||
elem_type->python_str(),
|
||||
elem_type->repr_str(),
|
||||
" in torchscript.\n",
|
||||
"Pass in a dtype argument to ensure consistent behavior");
|
||||
}
|
||||
|
@ -111,7 +111,7 @@ c10::IValue getFunctionTuple(const Function& func) {
|
||||
std::vector<IValue> types;
|
||||
types.reserve(code.type_table().size());
|
||||
for (const TypePtr& t : code.type_table()) {
|
||||
types.emplace_back(t->python_str());
|
||||
types.emplace_back(t->annotation_str());
|
||||
}
|
||||
|
||||
// since the register location is embedded into the bytecode, pass the
|
||||
|
@ -50,7 +50,7 @@ void postSetStateValidate(const IValue& v) {
|
||||
"The field '{}' was left uninitialized after '__setstate__', "
|
||||
"but expected a value of type '{}'",
|
||||
attrName,
|
||||
attrType->python_str()));
|
||||
attrType->repr_str()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -356,7 +356,7 @@ Module ScriptModuleDeserializer::LEGACY_convertModule(
|
||||
module_type->getAttributeName(i),
|
||||
"' was left unitialized after __setstate__, but expected a ",
|
||||
"value of type '",
|
||||
v.type()->python_str(),
|
||||
v.type()->repr_str(),
|
||||
"'");
|
||||
}
|
||||
}
|
||||
|
@ -497,7 +497,7 @@ void Pickler::endTypeTag(const IValue& ivalue) {
|
||||
|
||||
// Push the dict type
|
||||
TORCH_INTERNAL_ASSERT(ivalue.type());
|
||||
pushString(ivalue.type()->python_str());
|
||||
pushString(ivalue.type()->annotation_str());
|
||||
|
||||
// Pop the dict and type into a tuple
|
||||
push<PickleOpCode>(PickleOpCode::TUPLE2);
|
||||
@ -658,7 +658,7 @@ bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
|
||||
TORCH_CHECK(
|
||||
set_schema.returns().at(0).type()->isSubtypeOf(NoneType::get()),
|
||||
"'__setstate__' must return None, but found value of type",
|
||||
set_schema.returns().at(0).type()->python_str());
|
||||
set_schema.returns().at(0).type()->annotation_str());
|
||||
|
||||
// Check that the return type of __getstate__ matches the input to
|
||||
// __setstate__
|
||||
@ -668,9 +668,9 @@ bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
|
||||
TORCH_CHECK(
|
||||
get_type->isSubtypeOf(set_type),
|
||||
"'__getstate__'s return type (",
|
||||
get_type->python_str(),
|
||||
get_type->annotation_str(),
|
||||
") does not match '__setstate__'s argument type (",
|
||||
set_type->python_str(),
|
||||
set_type->annotation_str(),
|
||||
")");
|
||||
|
||||
return true;
|
||||
|
@ -500,7 +500,7 @@ struct PythonPrintImpl {
|
||||
indent();
|
||||
body_ << useOf(lhs[i]);
|
||||
if (requiresAnnotation(lhs[i], rhs[i])) {
|
||||
body_ << ": " << lhs[i]->type()->python_str(type_printer_);
|
||||
body_ << ": " << lhs[i]->type()->annotation_str(type_printer_);
|
||||
}
|
||||
body_ << " = " << useOf(rhs[i]) << "\n";
|
||||
}
|
||||
@ -769,7 +769,7 @@ struct PythonPrintImpl {
|
||||
if (i > 0) {
|
||||
body_ << ", ";
|
||||
}
|
||||
body_ << useOf(v) << ": " << v->type()->python_str(type_printer_);
|
||||
body_ << useOf(v) << ": " << v->type()->annotation_str(type_printer_);
|
||||
}
|
||||
body_ << "):\n";
|
||||
printBody(graph->block());
|
||||
@ -804,7 +804,7 @@ struct PythonPrintImpl {
|
||||
if (v.isTuple() && v.type()->expect<TupleType>()->schema()) {
|
||||
// print the namedtuple constructor and let rest of tuple printing
|
||||
// continue
|
||||
ss << v.type()->expect<TupleType>()->python_str(type_printer_);
|
||||
ss << v.type()->expect<TupleType>()->annotation_str(type_printer_);
|
||||
}
|
||||
return false;
|
||||
};
|
||||
@ -857,14 +857,14 @@ struct PythonPrintImpl {
|
||||
} break;
|
||||
case prim::Uninitialized: {
|
||||
stmt << "uninitialized("
|
||||
<< node->output()->type()->python_str(type_printer_) << ")";
|
||||
<< node->output()->type()->annotation_str(type_printer_) << ")";
|
||||
} break;
|
||||
case prim::Constant: {
|
||||
if (node->outputs().size() == 1 &&
|
||||
node->output()->type()->kind() == TypeKind::FunctionType) {
|
||||
auto fn = node->output()->type()->expect<FunctionType>();
|
||||
registerDependency(fn);
|
||||
stmt << fn->python_str(type_printer_);
|
||||
stmt << fn->annotation_str(type_printer_);
|
||||
} else if (!node->mustBeNone()) {
|
||||
IValue v = toIValue(node->output()).value();
|
||||
printConstant(stmt, v);
|
||||
@ -875,8 +875,9 @@ struct PythonPrintImpl {
|
||||
case aten::ScalarImplicit:
|
||||
case aten::FloatImplicit:
|
||||
case aten::IntImplicit: {
|
||||
stmt << "annotate(" << node->output()->type()->python_str(type_printer_)
|
||||
<< ", " << useOf(node->input()) << ")";
|
||||
stmt << "annotate("
|
||||
<< node->output()->type()->annotation_str(type_printer_) << ", "
|
||||
<< useOf(node->input()) << ")";
|
||||
} break;
|
||||
case aten::Int: {
|
||||
printValueList(stmt, node->inputs(), "int(", ")");
|
||||
@ -902,7 +903,7 @@ struct PythonPrintImpl {
|
||||
case prim::TupleConstruct: {
|
||||
if (auto qualname =
|
||||
node->output()->type()->expect<TupleType>()->name()) {
|
||||
stmt << node->output()->type()->python_str(type_printer_);
|
||||
stmt << node->output()->type()->annotation_str(type_printer_);
|
||||
}
|
||||
printValueList(
|
||||
stmt, node->inputs(), "(", node->inputs().size() == 1 ? ",)" : ")");
|
||||
@ -922,13 +923,14 @@ struct PythonPrintImpl {
|
||||
// what type is supposed to be inside them
|
||||
if (node->inputs().size() == 0) {
|
||||
stmt << "annotate("
|
||||
<< node->output()->type()->python_str(type_printer_) << ", [])";
|
||||
<< node->output()->type()->annotation_str(type_printer_)
|
||||
<< ", [])";
|
||||
// If we can't infer the type based on what's inside, explicitly
|
||||
// annotate it to disambiguate.
|
||||
// This happens for List[Tensor] vs. List[Optional[Tensor]]
|
||||
} else if (!elementTypeCanBeInferredFromMembers(elem_type)) {
|
||||
stmt << "annotate("
|
||||
<< node->output()->type()->python_str(type_printer_) << ", ";
|
||||
<< node->output()->type()->annotation_str(type_printer_) << ", ";
|
||||
printValueList(stmt, node->inputs(), "[", "]");
|
||||
stmt << ")";
|
||||
// Otherwise just print a list
|
||||
@ -947,7 +949,7 @@ struct PythonPrintImpl {
|
||||
!elementTypeCanBeInferredFromMembers(dict_type->getKeyType()) ||
|
||||
!elementTypeCanBeInferredFromMembers(dict_type->getValueType())) {
|
||||
stmt << "annotate("
|
||||
<< node->output()->type()->python_str(type_printer_) << ", ";
|
||||
<< node->output()->type()->annotation_str(type_printer_) << ", ";
|
||||
printDict(stmt, node->inputs());
|
||||
stmt << ")";
|
||||
// Otherwise just print a dict
|
||||
@ -957,8 +959,8 @@ struct PythonPrintImpl {
|
||||
} break;
|
||||
case prim::CreateObject: {
|
||||
const auto classType = node->output()->type()->expect<ClassType>();
|
||||
stmt << classType->python_str(type_printer_) << ".__new__("
|
||||
<< classType->python_str(type_printer_) << ")";
|
||||
stmt << classType->annotation_str(type_printer_) << ".__new__("
|
||||
<< classType->annotation_str(type_printer_) << ")";
|
||||
} break;
|
||||
case prim::GetAttr: {
|
||||
const auto obj = node->inputs().at(0);
|
||||
@ -1013,8 +1015,8 @@ struct PythonPrintImpl {
|
||||
if (node->input()->type()->isSubtypeOf(NoneType::get()) ||
|
||||
node->input()->mustBeNone()) {
|
||||
auto input_type = OptionalType::create(node->output()->type());
|
||||
stmt << "annotate(" << input_type->python_str(type_printer_) << ", "
|
||||
<< useOf(node->input()) << ")";
|
||||
stmt << "annotate(" << input_type->annotation_str(type_printer_)
|
||||
<< ", " << useOf(node->input()) << ")";
|
||||
} else {
|
||||
stmt << useOf(node->input());
|
||||
}
|
||||
@ -1027,14 +1029,14 @@ struct PythonPrintImpl {
|
||||
case prim::unchecked_unwrap_optional:
|
||||
case prim::unchecked_cast: {
|
||||
stmt << "unchecked_cast("
|
||||
<< node->output()->type()->python_str(type_printer_) << ", "
|
||||
<< node->output()->type()->annotation_str(type_printer_) << ", "
|
||||
<< useOf(node->input()) << ")";
|
||||
} break;
|
||||
case prim::isinstance: {
|
||||
stmt << "isinstance(" << useOf(node->input()) << ", ";
|
||||
const auto& types = node->tys(attr::types);
|
||||
if (types.size() == 1) {
|
||||
stmt << types.at(0)->python_str(type_printer_);
|
||||
stmt << types.at(0)->annotation_str(type_printer_);
|
||||
} else {
|
||||
// check multiple things, e.g. (str, list, int)
|
||||
stmt << "(";
|
||||
@ -1043,7 +1045,7 @@ struct PythonPrintImpl {
|
||||
if (!first) {
|
||||
stmt << ", ";
|
||||
}
|
||||
stmt << typ->python_str(type_printer_);
|
||||
stmt << typ->annotation_str(type_printer_);
|
||||
first = false;
|
||||
}
|
||||
stmt << ")";
|
||||
@ -1051,8 +1053,8 @@ struct PythonPrintImpl {
|
||||
stmt << ")";
|
||||
} break;
|
||||
case prim::tolist: {
|
||||
stmt << "annotate(" << node->output()->type()->python_str(type_printer_)
|
||||
<< ", ";
|
||||
stmt << "annotate("
|
||||
<< node->output()->type()->annotation_str(type_printer_) << ", ";
|
||||
stmt << useOf(node->input(0)) << ".tolist()"
|
||||
<< ")";
|
||||
} break;
|
||||
@ -1172,11 +1174,11 @@ struct PythonPrintImpl {
|
||||
// the flag print_first_argument_type determines when to do this
|
||||
body_ << arg_name;
|
||||
if (print_first_argument_type) {
|
||||
body_ << ": " << arg.type()->python_str(type_printer_);
|
||||
body_ << ": " << arg.type()->annotation_str(type_printer_);
|
||||
}
|
||||
} else {
|
||||
body_ << ",\n " << arg_name << ": "
|
||||
<< arg.type()->python_str(type_printer_);
|
||||
<< arg.type()->annotation_str(type_printer_);
|
||||
}
|
||||
if (arg.default_value()) {
|
||||
printDefaultValue(arg, body_, *arg.default_value());
|
||||
@ -1184,7 +1186,8 @@ struct PythonPrintImpl {
|
||||
assignValue(*param_it++, arg_name);
|
||||
}
|
||||
|
||||
body_ << ") -> " << schema.returns().at(0).type()->python_str(type_printer_)
|
||||
body_ << ") -> "
|
||||
<< schema.returns().at(0).type()->annotation_str(type_printer_)
|
||||
<< ":\n";
|
||||
printBody(graph.block());
|
||||
}
|
||||
@ -1275,12 +1278,12 @@ struct PythonPrintImpl {
|
||||
// Print out a direct manipulation of the annotations dict, like:
|
||||
// __annotations__["0"] = SomeType
|
||||
body_ << "__annotations__["
|
||||
<< "\"" << name << "\"] = " << type->python_str(type_printer_)
|
||||
<< "\n";
|
||||
<< "\"" << name
|
||||
<< "\"] = " << type->annotation_str(type_printer_) << "\n";
|
||||
} else {
|
||||
// Otherwise: just emit a python 3 attribute annotation, like:
|
||||
// foo : SomeType
|
||||
body_ << name << " : " << type->python_str(type_printer_) << "\n";
|
||||
body_ << name << " : " << type->annotation_str(type_printer_) << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
@ -1291,7 +1294,7 @@ struct PythonPrintImpl {
|
||||
|
||||
indent();
|
||||
body_ << name << " : "
|
||||
<< "Final[" << v.type()->python_str(type_printer_) << "] = ";
|
||||
<< "Final[" << v.type()->annotation_str(type_printer_) << "] = ";
|
||||
auto ss = std::make_shared<TaggedStringStream>(&source_range_stack_);
|
||||
printConstant(*ss, v);
|
||||
body_ << ss->str() << "\n";
|
||||
@ -1319,7 +1322,7 @@ struct PythonPrintImpl {
|
||||
TORCH_INTERNAL_ASSERT(attr.type());
|
||||
indent();
|
||||
body_ << attr.name() << " : "
|
||||
<< attr.type()->python_str(type_printer_) << "\n";
|
||||
<< attr.type()->annotation_str(type_printer_) << "\n";
|
||||
}
|
||||
}
|
||||
} else if (auto interfaceType = type->cast<InterfaceType>()) {
|
||||
@ -1342,11 +1345,12 @@ struct PythonPrintImpl {
|
||||
auto type = arg.type();
|
||||
registerClassDependencies(type);
|
||||
body_ << ", " << arg.name() << ": "
|
||||
<< type->python_str(type_printer_);
|
||||
<< type->annotation_str(type_printer_);
|
||||
}
|
||||
auto return_type = method.returns().at(0).type();
|
||||
registerClassDependencies(return_type);
|
||||
body_ << ") -> " << return_type->python_str(type_printer_) << ":\n";
|
||||
body_ << ") -> " << return_type->annotation_str(type_printer_)
|
||||
<< ":\n";
|
||||
indent();
|
||||
body_ << " pass\n";
|
||||
}
|
||||
|
@ -149,7 +149,7 @@ void restoreContainerTypeTags(IValue& ivalue, TypePtr type) {
|
||||
} else if (auto list_type = type->cast<ListType>()) {
|
||||
ivalue.toList().unsafeSetElementType(list_type->getElementType());
|
||||
} else {
|
||||
AT_ERROR("Unknown type for tag restoration: " + type->python_str());
|
||||
AT_ERROR("Unknown type for tag restoration: " + type->annotation_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -202,7 +202,7 @@ class class_ {
|
||||
TORCH_CHECK(
|
||||
*first_arg_type == *classTypePtr,
|
||||
"self argument of __getstate__ must be the custom class type. Got ",
|
||||
first_arg_type->python_str());
|
||||
first_arg_type->repr_str());
|
||||
TORCH_CHECK(
|
||||
getstate_schema.returns().size() == 1,
|
||||
"__getstate__ should return exactly one value for serialization. Got: ",
|
||||
@ -214,9 +214,9 @@ class class_ {
|
||||
(*arg_type == *ser_type),
|
||||
"__setstate__'s argument should be the same type as the "
|
||||
"return value of __getstate__. Got ",
|
||||
arg_type->python_str(),
|
||||
arg_type->repr_str(),
|
||||
" but expected ",
|
||||
ser_type->python_str());
|
||||
ser_type->repr_str());
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
Reference in New Issue
Block a user