mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/81145 Approved by: https://github.com/ezyang
35 lines
815 B
C++
35 lines
815 B
C++
#include <c10/core/SymIntArrayRef.h>
|
|
#include <c10/util/Optional.h>
|
|
#include <iostream>
|
|
|
|
namespace c10 {
|
|
|
|
// Converts a SymIntArrayRef into a plain IntArrayRef, failing loudly when any
// element is symbolic.
//
// This is the throwing counterpart of asIntArrayRefSlowOpt: the scan is done
// there, and an empty result is turned into a TORCH_CHECK failure here.
at::IntArrayRef asIntArrayRefSlow(c10::SymIntArrayRef ar) {
  c10::optional<at::IntArrayRef> concrete = asIntArrayRefSlowOpt(ar);
  TORCH_CHECK(
      concrete.has_value(),
      "SymIntArrayRef expected to contain only concrete integers");
  // Safe to dereference: the check above guarantees a value is present.
  return *concrete;
}
|
|
|
|
// Attempts to view a SymIntArrayRef as a plain IntArrayRef.
//
// Scans every element and returns c10::nullopt as soon as one is symbolic.
// If all elements are concrete, returns the view produced by
// asIntArrayRefUnchecked(ar). That view reinterprets ar's own storage, so it
// is only valid while the data behind `ar` is alive.
c10::optional<at::IntArrayRef> asIntArrayRefSlowOpt(c10::SymIntArrayRef ar) {
  // Bind by const reference: the loop only inspects each element, so copying
  // every SymInt on each iteration is unnecessary.
  for (const c10::SymInt& sci : ar) {
    if (sci.is_symbolic()) {
      return c10::nullopt;
    }
  }

  return {asIntArrayRefUnchecked(ar)};
}
|
|
|
|
// Reinterprets the storage of a SymIntArrayRef as an IntArrayRef without
// checking whether any element is symbolic.
//
// The returned view aliases ar's storage (no copy is made). Callers must
// ensure every element is concrete; see asIntArrayRefSlowOpt for the checked
// variant.
at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar) {
  // The reinterpret_cast below is only meaningful if each element occupies
  // exactly one int64_t. Fail at compile time if that layout ever changes.
  static_assert(
      sizeof(*ar.data()) == sizeof(int64_t),
      "asIntArrayRefUnchecked requires SymInt to be layout-compatible with int64_t");
  return IntArrayRef(reinterpret_cast<const int64_t*>(ar.data()), ar.size());
}
|
|
|
|
// Streams a SymInt as "SymInt(<raw data>)", where <raw data> is the value
// returned by SymInt::data(). Returns the stream to allow chaining.
std::ostream& operator<<(std::ostream& os, SymInt s) {
  return os << "SymInt(" << s.data() << ")";
}
|
|
|
|
} // namespace c10
|