mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactoring TensorImpl by using constexpr and std::is_same_v (#161043)
As the title stated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161043 Approved by: https://github.com/Skylion007
This commit is contained in:
@ -643,47 +643,43 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// From https://stackoverflow.com/a/3057522/23845
|
|
||||||
// TODO: does C++14 have a stdlib template for this?
|
|
||||||
template <typename T>
|
|
||||||
struct identity {
|
|
||||||
typedef T type;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
ArrayRef<T> generic_sizes() {
|
ArrayRef<T> generic_sizes() {
|
||||||
return _generic_sizes(identity<T>());
|
static_assert(
|
||||||
}
|
std::is_same_v<T, int64_t> || std::is_same_v<T, c10::SymInt>,
|
||||||
|
"Only supports int64_t and c10::SymInt.");
|
||||||
|
|
||||||
ArrayRef<int64_t> _generic_sizes(identity<int64_t>) {
|
if constexpr (std::is_same_v<T, int64_t>) {
|
||||||
return sizes();
|
return sizes();
|
||||||
}
|
} else {
|
||||||
ArrayRef<c10::SymInt> _generic_sizes(identity<c10::SymInt>) {
|
return sym_sizes();
|
||||||
return sym_sizes();
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
ArrayRef<T> generic_strides() {
|
ArrayRef<T> generic_strides() {
|
||||||
return _generic_strides(identity<T>());
|
static_assert(
|
||||||
}
|
std::is_same_v<T, int64_t> || std::is_same_v<T, c10::SymInt>,
|
||||||
|
"Only supports int64_t and c10::SymInt.");
|
||||||
|
|
||||||
ArrayRef<int64_t> _generic_strides(identity<int64_t>) {
|
if constexpr (std::is_same_v<T, int64_t>) {
|
||||||
return strides();
|
return strides();
|
||||||
}
|
} else {
|
||||||
ArrayRef<c10::SymInt> _generic_strides(identity<c10::SymInt>) {
|
return sym_strides();
|
||||||
return sym_strides();
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T generic_storage_offset() {
|
T generic_storage_offset() {
|
||||||
return _generic_storage_offset(identity<T>());
|
static_assert(
|
||||||
}
|
std::is_same_v<T, int64_t> || std::is_same_v<T, c10::SymInt>,
|
||||||
|
"Only supports int64_t and c10::SymInt.");
|
||||||
|
|
||||||
int64_t _generic_storage_offset(identity<int64_t>) {
|
if constexpr (std::is_same_v<T, int64_t>) {
|
||||||
return storage_offset();
|
return storage_offset();
|
||||||
}
|
} else {
|
||||||
c10::SymInt _generic_storage_offset(identity<c10::SymInt>) {
|
return sym_storage_offset();
|
||||||
return sym_storage_offset();
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
Reference in New Issue
Block a user