mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-30 19:54:53 +08:00
Only allow dense floating-point types as the default tensor type. (#5674)
This commit is contained in:
committed by
Soumith Chintala
parent
03f2ad9029
commit
b5ee5e585b
@ -335,6 +335,14 @@ void py_set_default_tensor_type(PyObject* obj) {
|
||||
throw unavailable_type(*type);
|
||||
}
|
||||
|
||||
if (!at::isFloatingType(type->aten_type->scalarType())) {
|
||||
throw TypeError("only floating-point types are supported as the default type");
|
||||
}
|
||||
|
||||
if (type->aten_type->is_sparse()) {
|
||||
throw TypeError("only dense types are supported as the default type");
|
||||
}
|
||||
|
||||
// get the storage first, so if it doesn't exist we don't change the default tensor type
|
||||
THPObjectPtr storage = get_storage_obj(*type);
|
||||
set_default_tensor_type(*type->aten_type);
|
||||
|
||||
Reference in New Issue
Block a user