mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactoring TensorImpl by using constexpr and std::is_same_v (#161043)
As the title stated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161043 Approved by: https://github.com/Skylion007
This commit is contained in:
@ -643,47 +643,43 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
}
|
||||
}
|
||||
|
||||
// From https://stackoverflow.com/a/3057522/23845
|
||||
// TODO: does C++14 have a stdlib template for this?
|
||||
template <typename T>
|
||||
struct identity {
|
||||
typedef T type;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
ArrayRef<T> generic_sizes() {
|
||||
return _generic_sizes(identity<T>());
|
||||
}
|
||||
static_assert(
|
||||
std::is_same_v<T, int64_t> || std::is_same_v<T, c10::SymInt>,
|
||||
"Only supports int64_t and c10::SymInt.");
|
||||
|
||||
ArrayRef<int64_t> _generic_sizes(identity<int64_t>) {
|
||||
return sizes();
|
||||
}
|
||||
ArrayRef<c10::SymInt> _generic_sizes(identity<c10::SymInt>) {
|
||||
return sym_sizes();
|
||||
if constexpr (std::is_same_v<T, int64_t>) {
|
||||
return sizes();
|
||||
} else {
|
||||
return sym_sizes();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ArrayRef<T> generic_strides() {
|
||||
return _generic_strides(identity<T>());
|
||||
}
|
||||
static_assert(
|
||||
std::is_same_v<T, int64_t> || std::is_same_v<T, c10::SymInt>,
|
||||
"Only supports int64_t and c10::SymInt.");
|
||||
|
||||
ArrayRef<int64_t> _generic_strides(identity<int64_t>) {
|
||||
return strides();
|
||||
}
|
||||
ArrayRef<c10::SymInt> _generic_strides(identity<c10::SymInt>) {
|
||||
return sym_strides();
|
||||
if constexpr (std::is_same_v<T, int64_t>) {
|
||||
return strides();
|
||||
} else {
|
||||
return sym_strides();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T generic_storage_offset() {
|
||||
return _generic_storage_offset(identity<T>());
|
||||
}
|
||||
static_assert(
|
||||
std::is_same_v<T, int64_t> || std::is_same_v<T, c10::SymInt>,
|
||||
"Only supports int64_t and c10::SymInt.");
|
||||
|
||||
int64_t _generic_storage_offset(identity<int64_t>) {
|
||||
return storage_offset();
|
||||
}
|
||||
c10::SymInt _generic_storage_offset(identity<c10::SymInt>) {
|
||||
return sym_storage_offset();
|
||||
if constexpr (std::is_same_v<T, int64_t>) {
|
||||
return storage_offset();
|
||||
} else {
|
||||
return sym_storage_offset();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
Reference in New Issue
Block a user