Refactoring TensorImpl by using constexpr and std::is_same_v (#161043)

As the title stated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161043
Approved by: https://github.com/Skylion007
This commit is contained in:
FFFrog
2025-08-20 14:48:24 +08:00
committed by PyTorch MergeBot
parent 9b4adc4db7
commit 2beffb3311

View File

@ -643,47 +643,43 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
}
}
// From https://stackoverflow.com/a/3057522/23845
// TODO: does C++14 have a stdlib template for this?
template <typename T>
struct identity {
typedef T type;
};
template <typename T>
ArrayRef<T> generic_sizes() {
return _generic_sizes(identity<T>());
}
static_assert(
std::is_same_v<T, int64_t> || std::is_same_v<T, c10::SymInt>,
"Only supports int64_t and c10::SymInt.");
ArrayRef<int64_t> _generic_sizes(identity<int64_t>) {
return sizes();
}
ArrayRef<c10::SymInt> _generic_sizes(identity<c10::SymInt>) {
return sym_sizes();
if constexpr (std::is_same_v<T, int64_t>) {
return sizes();
} else {
return sym_sizes();
}
}
template <typename T>
ArrayRef<T> generic_strides() {
return _generic_strides(identity<T>());
}
static_assert(
std::is_same_v<T, int64_t> || std::is_same_v<T, c10::SymInt>,
"Only supports int64_t and c10::SymInt.");
ArrayRef<int64_t> _generic_strides(identity<int64_t>) {
return strides();
}
ArrayRef<c10::SymInt> _generic_strides(identity<c10::SymInt>) {
return sym_strides();
if constexpr (std::is_same_v<T, int64_t>) {
return strides();
} else {
return sym_strides();
}
}
template <typename T>
T generic_storage_offset() {
return _generic_storage_offset(identity<T>());
}
static_assert(
std::is_same_v<T, int64_t> || std::is_same_v<T, c10::SymInt>,
"Only supports int64_t and c10::SymInt.");
int64_t _generic_storage_offset(identity<int64_t>) {
return storage_offset();
}
c10::SymInt _generic_storage_offset(identity<c10::SymInt>) {
return sym_storage_offset();
if constexpr (std::is_same_v<T, int64_t>) {
return storage_offset();
} else {
return sym_storage_offset();
}
}
/**