mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Follows #131997 Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/132010 Approved by: https://github.com/Skylion007
38 lines
933 B
C++
38 lines
933 B
C++
#include <torch/csrc/jit/passes/clear_undefinedness.h>
|
|
|
|
#include <torch/csrc/jit/jit_log.h>
|
|
|
|
namespace torch::jit {
|
|
|
|
// Resets the static type of a single value to its generic form.
//   * A value whose type kind is TensorType becomes plain `TensorType::get()`.
//   * A list whose element type kind is TensorType becomes a list of plain
//     `TensorType::get()`.
//   * Any other value is left untouched.
// NOTE(review): judging by the pass name, the refinement being discarded is
// the "undefined tensor" information -- confirm against the callers.
static void clearUndefinedness(Value* o) {
  const auto& current = o->type();

  if (current->kind() == TensorType::Kind) {
    o->setType(TensorType::get());
    return;
  }

  // Everything below only concerns lists.
  if (current->kind() != ListType::Kind) {
    return;
  }

  const auto& element = current->expectRef<ListType>().getElementType();
  if (element->kind() == TensorType::Kind) {
    o->setType(ListType::create(TensorType::get()));
  }
}
|
|
|
|
// Walks every node of `block`, resetting the type of each node output and
// then recursing into the node's nested blocks. Only node outputs are
// rewritten by this walk.
static void clearUndefinedness(Block* block) {
  for (auto* node : block->nodes()) {
    for (auto* output : node->outputs()) {
      clearUndefinedness(output);
    }
    for (auto* nested : node->blocks()) {
      clearUndefinedness(nested);
    }
  }
}
|
|
|
|
// Pass entry point: resets tensor and tensor-list value types throughout
// `graph` to their generic form (`TensorType::get()` / list thereof).
//
// The block walk only rewrites node outputs, so the graph inputs are reset
// explicitly first; afterwards the top-level block (and, recursively, every
// nested block) is processed. The resulting graph is dumped through the JIT
// logging facility.
void ClearUndefinedness(const std::shared_ptr<Graph>& graph) {
  for (auto* input : graph->inputs()) {
    clearUndefinedness(input);
  }
  clearUndefinedness(graph->block());
  // Label the dump with this pass's actual name so it can be located when
  // filtering JIT logs (the previous label named a nonexistent pass).
  GRAPH_DUMP("After ClearUndefinedness: ", graph);
}
|
|
|
|
} // namespace torch::jit
|