Only allow dense floating-point types as the default tensor type. (#5674)

This commit is contained in:
gchanan
2018-03-09 23:50:18 -05:00
committed by Soumith Chintala
parent 03f2ad9029
commit b5ee5e585b
2 changed files with 29 additions and 15 deletions

View File

@ -335,6 +335,14 @@ void py_set_default_tensor_type(PyObject* obj) {
throw unavailable_type(*type);
}
if (!at::isFloatingType(type->aten_type->scalarType())) {
throw TypeError("only floating-point types are supported as the default type");
}
if (type->aten_type->is_sparse()) {
throw TypeError("only dense types are supported as the default type");
}
// get the storage first, so if it doesn't exist we don't change the default tensor type
THPObjectPtr storage = get_storage_obj(*type);
set_default_tensor_type(*type->aten_type);