mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/139605 Approved by: https://github.com/ezyang
80 lines
2.5 KiB
C++
80 lines
2.5 KiB
C++
#pragma once
|
|
|
|
#include <ATen/core/jit_type.h>
|
|
#include <torch/csrc/jit/ir/alias_analysis.h>
|
|
#include <torch/csrc/jit/ir/ir_views.h>
|
|
#include <torch/csrc/jit/jit_log.h>
|
|
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
|
#include <torch/csrc/jit/passes/peephole.h>
|
|
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
|
|
#include <torch/csrc/jit/runtime/graph_executor.h>
|
|
|
|
namespace torch::jit {
|
|
|
|
// Refine from Value of type List -> len of list
|
|
// If a refinement mapping of List Value * -> len is present in a block
|
|
// the list is guaranteed to be that length
|
|
// TODO: vector may be faster
|
|
using ListRefinement = std::unordered_map<Value*, int64_t>;
|
|
|
|
TORCH_API ListRefinement
|
|
intersectRefinements(const ListRefinement& ref1, const ListRefinement& ref2);
|
|
|
|
TORCH_API ListRefinement
|
|
unionRefinements(const ListRefinement& ref1, const ListRefinement& ref2);
|
|
|
|
// Represents the refinement information that can be carried on a boolean
|
|
struct BooleanRefinementMapping {
|
|
BooleanRefinementMapping(
|
|
ListRefinement true_refine,
|
|
ListRefinement false_refine)
|
|
: true_refine_(std::move(true_refine)),
|
|
false_refine_(std::move(false_refine)) {}
|
|
BooleanRefinementMapping() = default; // empty
|
|
|
|
static BooleanRefinementMapping FalseRefinements(
|
|
ListRefinement false_refine) {
|
|
return BooleanRefinementMapping({}, std::move(false_refine));
|
|
}
|
|
|
|
static BooleanRefinementMapping TrueRefinements(ListRefinement true_refine) {
|
|
return BooleanRefinementMapping(std::move(true_refine), {});
|
|
}
|
|
|
|
BooleanRefinementMapping intersectBooleanRefinementMapping(
|
|
BooleanRefinementMapping& other) {
|
|
return BooleanRefinementMapping(
|
|
intersectRefinements(true_refine_, other.true_refine()),
|
|
intersectRefinements(false_refine_, other.false_refine()));
|
|
}
|
|
|
|
ListRefinement& true_refine() {
|
|
return true_refine_;
|
|
}
|
|
|
|
ListRefinement& false_refine() {
|
|
return false_refine_;
|
|
}
|
|
|
|
private:
|
|
ListRefinement true_refine_;
|
|
ListRefinement false_refine_;
|
|
};
|
|
|
|
TORCH_API void joinIfRefinements(
|
|
Node* if_node,
|
|
std::unordered_set<Block*>& throwing_blocks,
|
|
ListRefinement& curr_block_refinements,
|
|
ListRefinement& true_block_refinements,
|
|
ListRefinement& false_block_refinements,
|
|
std::unordered_map<Value*, BooleanRefinementMapping>& info);
|
|
|
|
// handles adding blocks to throwing blocks and propagating refinements via
|
|
// boolean comparisons
|
|
TORCH_API bool handleCommonRefinentOperators(
|
|
Node* n,
|
|
std::unordered_set<Block*>& throwing_blocks,
|
|
std::unordered_map<Value*, BooleanRefinementMapping>& info);
|
|
|
|
} // namespace torch::jit
|