Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[9/N] Fix clang-tidy warnings in jit (#132010)
Follows #131997

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132010
Approved by: https://github.com/Skylion007
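Most of the hunks below make the same mechanical change: the nested namespace torch { namespace jit { ... } } pairs are collapsed into a single C++17 nested namespace definition, which is what clang-tidy's modernize checks suggest. A minimal before/after sketch (illustrative only; examplePass is a made-up name, not part of this commit):

// Before: two nested namespace declarations and two closing braces.
namespace torch {
namespace jit {
void examplePass();  // hypothetical declaration, for illustration
} // namespace jit
} // namespace torch

// After: a single C++17 nested namespace definition.
namespace torch::jit {
void examplePass();
} // namespace torch::jit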
@@ -1,8 +1,7 @@
 #include <torch/csrc/jit/passes/add_if_then_else.h>
 #include <torch/csrc/jit/runtime/graph_iterator.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace {
 
@@ -51,5 +50,4 @@ bool AddIfThenElseOp(std::shared_ptr<Graph>& graph) {
   return !to_replace.empty();
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -2,10 +2,8 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API bool AddIfThenElseOp(std::shared_ptr<Graph>& graph);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -2,8 +2,7 @@
 
 #include <atomic>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 static void AnnotateWarns(Block* b) {
   static std::atomic<int64_t> idx(0);
@@ -25,5 +24,4 @@ void AnnotateWarns(const std::shared_ptr<Graph>& graph) {
   AnnotateWarns(graph->block());
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -2,10 +2,8 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API void AnnotateWarns(const std::shared_ptr<Graph>& graph);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -13,8 +13,7 @@
 #include <unordered_set>
 #include <vector>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace {
 
@@ -79,10 +78,10 @@ std::optional<AutocastScope> parseAutocast(
       if (use.user->kind() == prim::SetAttr &&
           use.user->s(attr::name) == "_enabled") {
         // Search for `prim::SetAttr[name="_enabled"]`
-        auto ret = constant_as<bool>(use.user->input(1));
+        enabled = constant_as<bool>(use.user->input(1));
         TORCH_CHECK(
-            ret.has_value(), "Autocast _enabled argument must be a constant");
-        enabled = ret.value();
+            enabled.has_value(),
+            "Autocast _enabled argument must be a constant");
       } else if (
           use.user->kind() == prim::SetAttr &&
           use.user->s(attr::name) == "device") {
@@ -532,5 +531,4 @@ void Autocast(const std::shared_ptr<Graph>& graph) {
   GRAPH_DUMP("\nAfter Autocast: ", graph);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
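The parseAutocast hunk assigns the parsed constant straight into the std::optional that the surrounding code already uses and checks has_value() before relying on it, instead of going through a named temporary and .value(). A hedged sketch of that pattern, with stand-in names rather than the real autocast code:

#include <optional>

// parseFlag stands in for constant_as<bool>(...) from the hunk above.
std::optional<bool> parseFlag() {
  return true;
}

bool autocastEnabledFromAttr() {
  std::optional<bool> enabled = parseFlag();  // assign directly to the optional
  // In the real pass this is TORCH_CHECK(enabled.has_value(), "...").
  if (!enabled.has_value()) {
    return false;
  }
  return *enabled;
}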
@@ -3,13 +3,11 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API void Autocast(const std::shared_ptr<Graph>& graph);
 
 TORCH_API bool setAutocastMode(bool value);
 TORCH_API bool autocastEnabled();
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -12,8 +12,7 @@
 #include <unordered_set>
 #include <utility>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 static bool shouldBeCapturedInByBailOut(Node* n) {
   return n->kind() != prim::Constant;
@@ -223,7 +222,7 @@ struct BailOutGraphBuilderForNode {
 // version of an original graph from a particular point
 struct BailOutInserter {
   explicit BailOutInserter(std::shared_ptr<Graph> graph)
-      : graph_(std::move(graph)), bailout_index_(0) {}
+      : graph_(std::move(graph)) {}
 
   void run() {
     liveness_sets_ = BuildLivenessSets(graph_);
@@ -322,7 +321,7 @@ struct BailOutInserter {
 
   std::shared_ptr<Graph> graph_;
   std::map<Node*, Node*> subgraphs;
-  std::size_t bailout_index_;
+  std::size_t bailout_index_{0};
  std::unordered_map<Node*, std::vector<Value*>> liveness_sets_;
   std::vector<Node*> bailouts_;
   std::map<Value*, Value*> replacements_;
@@ -394,5 +393,4 @@ TORCH_API std::shared_ptr<Graph> BuildBailOutGraphFrom(
   return bailout_graph;
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
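The BailOutInserter hunks replace the constructor-initializer entry for bailout_index_ with a default member initializer at the declaration, so every constructor gets the zero value without repeating it. A minimal sketch of the same idiom; the struct and type names here are hypothetical stand-ins:

#include <cstddef>
#include <memory>
#include <utility>

struct GraphStub {};  // illustrative stand-in for torch::jit::Graph

struct InserterSketch {
  explicit InserterSketch(std::shared_ptr<GraphStub> graph)
      : graph_(std::move(graph)) {}  // no bailout_index_(0) entry needed

  std::shared_ptr<GraphStub> graph_;
  std::size_t bailout_index_{0};  // default member initializer
};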
@@ -10,8 +10,7 @@
 #include <list>
 #include <vector>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // Replaces prim::Guard nodes with prim::BailOut nodes and
 // computes sets of inputs needed to resume execution at
@@ -30,5 +29,4 @@ TORCH_API std::shared_ptr<Graph> BuildBailOutGraphFrom(
     int64_t bailout_index,
     const std::shared_ptr<Graph>& orig,
     const std::shared_ptr<Graph>& target);
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -16,8 +16,7 @@
 #include <unordered_map>
 #include <utility>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace {
 c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() {
@@ -490,5 +489,4 @@ void BatchMM(std::shared_ptr<Graph>& graph) {
   PeepholeOptimize(graph, /*disable_shape_peepholes*/ true);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -2,10 +2,8 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API void BatchMM(std::shared_ptr<Graph>& graph);
 
 }
-} // namespace torch
@@ -3,8 +3,7 @@
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/ir/ir_views.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // Canonicalize a graph, renumbering it so that all structurally equivalent
 // graphs have same numbers.
@@ -231,5 +230,4 @@ static void CanonicalizeOutputs(Block* block) {
 void CanonicalizeOutputs(std::shared_ptr<Graph>& graph) {
   CanonicalizeOutputs(graph->block());
 }
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -2,8 +2,7 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API std::shared_ptr<Graph> Canonicalize(
     const std::shared_ptr<Graph>& graph,
@@ -18,5 +17,4 @@ TORCH_API bool isBeforeOrAfter(
     const Use& b,
     bool checking_before);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -3,8 +3,7 @@
 #include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 struct ChunkOutput {
   ChunkOutput(Value* v, size_t o) : val(v), offset(o){};
@@ -96,5 +95,4 @@ void CanonicalizeOps(const std::shared_ptr<Graph>& graph) {
   EliminateDeadCode(graph);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -2,10 +2,8 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API void CanonicalizeOps(const std::shared_ptr<Graph>& graph);
 
 }
-} // namespace torch
@@ -7,10 +7,8 @@
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/quantization/helper.h>
 #include <torch/csrc/jit/runtime/graph_iterator.h>
 #include <unordered_map>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace {
 
@@ -77,7 +75,7 @@ static void checkForUnfusedOps(Node* enter_node) {
     for (Node* n : guarding_ifs) {
       ss << *n << "\n";
     }
-    throw ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str();
+    throw(ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str());
   }
 
   // autodiff/nnc both insert a number of guards, see
@@ -110,8 +108,7 @@ static void checkForUnfusedOps(Node* enter_node) {
     }
     ss << "\n";
   }
-  auto range = enter_node->input()->node()->sourceRange();
-  throw ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str();
+  throw(ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str());
 }
 }
 
@@ -128,5 +125,4 @@ void CheckStrictFusion(std::shared_ptr<Graph>& graph) {
   // TODO: improve control flow not taken, right now always errors
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
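Both checkForUnfusedOps hunks keep the stream-style error construction but wrap the whole streamed expression in parentheses as the operand of throw, and the second hunk also drops an unused local. A hedged sketch of the shape of that pattern, using std::runtime_error in place of torch::jit::ErrorReport:

#include <sstream>
#include <stdexcept>
#include <string>

void reportUnfusedOps(const std::string& details) {
  std::ostringstream ss;
  ss << "Found unfused operators:\n" << details;
  // Parenthesizing makes the fully built message expression the thrown
  // operand, mirroring the throw(ErrorReport(...) << ss.str()) form above.
  throw(std::runtime_error(ss.str()));
}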
@@ -3,10 +3,8 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API void CheckStrictFusion(std::shared_ptr<Graph>& graph);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -2,8 +2,7 @@
 
 #include <torch/csrc/jit/jit_log.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 void unprofileGraphInputs(const std::shared_ptr<Graph>& graph) {
   for (auto i : graph->inputs()) {
@@ -45,5 +44,4 @@ void ClearProfilingInformation(const std::shared_ptr<Graph>& graph) {
   GRAPH_DUMP("After ClearProfilingInformation: ", graph);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -6,8 +6,7 @@
 #include <torch/csrc/Export.h>
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API void unprofileGraphInputs(const std::shared_ptr<Graph>& graph);
 TORCH_API void unprofileBlock(Block* start_block);
@@ -15,5 +14,4 @@ TORCH_API void unprofileBlock(Block* start_block);
 
 TORCH_API void ClearProfilingInformation(const std::shared_ptr<Graph>& graph);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -2,8 +2,7 @@
 
 #include <torch/csrc/jit/jit_log.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 static void clearUndefinedness(Value* o) {
   if (o->type()->kind() == TensorType::Kind) {
@@ -35,5 +34,4 @@ void ClearUndefinedness(const std::shared_ptr<Graph>& graph) {
   GRAPH_DUMP("After removeUndefinedness: ", graph);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -6,8 +6,7 @@
 #include <torch/csrc/Export.h>
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // Undefinedness makes argument matching fail for regular tensor operations
 // if 1+ arguments are undefined or possibly undefined tensors.
@@ -20,5 +19,4 @@ namespace jit {
 // When this happens, this pass will be removed
 TORCH_API void ClearUndefinedness(const std::shared_ptr<Graph>& graph);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -7,8 +7,7 @@
 
 #include <unordered_map>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 namespace {
 
 struct CommonSubexpressionEliminator {
@@ -126,5 +125,4 @@ bool EliminateCommonSubexpression(const std::shared_ptr<Graph>& graph) {
   CommonSubexpressionEliminator cse(graph);
   return cse.run([](Node*) { return nullptr; });
 }
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -2,10 +2,8 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API bool EliminateCommonSubexpression(
     const std::shared_ptr<Graph>& graph);
 }
-} // namespace torch
@@ -16,8 +16,7 @@
 #include <torch/csrc/jit/passes/remove_mutation.h>
 #include <torch/csrc/jit/runtime/graph_iterator.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace {
 
@@ -699,5 +698,4 @@ bool CombineConcats(const std::shared_ptr<Graph>& graph) {
   return changed;
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -2,8 +2,7 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // Eliminates common inputs among `aten::cat` ops.
 TORCH_API bool EliminateConcatCommonInputs(const std::shared_ptr<Graph>& graph);
@@ -15,5 +14,4 @@ TORCH_API void ExpandConcatAndEliminateRedundancy(
 
 TORCH_API bool CombineConcats(const std::shared_ptr<Graph>& graph);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -6,8 +6,7 @@
 #include <torch/csrc/jit/ir/node_hashing.h>
 #include <unordered_set>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace {
 
@@ -71,5 +70,4 @@ void ConstantPooling(const std::shared_ptr<Graph>& graph) {
   std::unordered_set<Node*, HashNode, EqualNode> constants;
   ConstantPooling(graph->block(), constants, aliasDb);
 }
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -2,10 +2,8 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API void ConstantPooling(const std::shared_ptr<Graph>& graph);
 
 }
-} // namespace torch
@@ -16,8 +16,7 @@
 
 #include <utility>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 std::optional<std::vector<IValue>> runNodeIfInputsAreConstant(
     const Node* n,
@@ -434,5 +433,4 @@ bool ConstantPropagationImmutableTypes(std::shared_ptr<Graph>& graph) {
   return made_change;
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -2,8 +2,7 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // Runs constant propagation on all objects unless ignore_custom_classes is
 // specified as true, in which case user defined classes are skipped. This is
@@ -28,5 +27,4 @@ TORCH_API std::optional<Stack> runNodeIfInputsAreConstant(
     bool ignore_custom_classes = false,
     AliasDb* db = nullptr);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -10,8 +10,7 @@
 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
 #include <torch/csrc/jit/runtime/autodiff.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace {
 
@@ -374,9 +373,9 @@ std::optional<bool> findRequiresGradForOutput(
   }
 
   if (use.user->kind() == prim::profile) {
-    std::optional<bool> req_grad_use;
-    if ((req_grad_use = getProfileNodeRequiresGrad(use.user)).has_value()) {
-      return req_grad_use.value();
+    auto req_grad_use = getProfileNodeRequiresGrad(use.user);
+    if (req_grad_use.has_value()) {
+      return req_grad_use;
     }
   }
 
@@ -393,10 +392,9 @@ std::optional<bool> findRequiresGradForOutput(
   }
 
   if (dg_use.user->kind() == prim::profile) {
-    std::optional<bool> req_grad_use;
-    if ((req_grad_use = getProfileNodeRequiresGrad(dg_use.user))
-            .has_value()) {
-      return req_grad_use.value();
+    auto req_grad_use = getProfileNodeRequiresGrad(dg_use.user);
+    if (req_grad_use.has_value()) {
+      return req_grad_use;
     }
   }
 }
@@ -474,5 +472,4 @@ std::vector<Node*> CreateAutodiffSubgraphs(
   GRAPH_DEBUG("diff_nodes.size() ", diff_nodes.size());
   return diff_nodes;
 }
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -5,8 +5,7 @@
 
 #include <cstddef>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // insert GraphExecutor nodes that group together
 // subgraphs that are differentiable by the jit's autodiff passes
@@ -15,5 +14,4 @@ namespace jit {
 TORCH_API std::vector<Node*> CreateAutodiffSubgraphs(
     const std::shared_ptr<Graph>& graph,
     size_t threshold = 2);
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -8,8 +8,7 @@
 #include <cstddef>
 #include <limits>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace {
 
@@ -223,5 +222,4 @@ void InlineFunctionalGraphs(const std::shared_ptr<Graph>& graph) {
   InlineFunctionalGraphs(graph->block());
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -3,12 +3,10 @@
 #include <torch/csrc/Export.h>
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API void CreateFunctionalGraphs(const std::shared_ptr<Graph>& graph);
 
 TORCH_API void InlineFunctionalGraphs(const std::shared_ptr<Graph>& graph);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -7,8 +7,7 @@
 
 #include <unordered_map>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace prim {
 using namespace ::c10::prim;
@@ -458,5 +457,4 @@ void EliminateDeadCode(
   eliminator.run(block, /*recurse=*/true);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -2,8 +2,7 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // If given a top-level graph, DCE will construct do alias analysis that allows
 // for "smarter" dead code elimination (we will eliminate mutable ops if we can
@@ -38,5 +37,4 @@ TORCH_API void EliminateDeadCode(
     std::function<void(const std::unordered_set<const Value*>&)> cb,
     DCESideEffectPolicy sideEffectPolicy =
         DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS);
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -10,8 +10,7 @@
 
 #include <ATen/core/symbol.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace {
 c10::AliasAnalysisKind aliasAnalysisFromSchema() {
@@ -231,5 +230,4 @@ void DecomposeOps(std::shared_ptr<Graph>& graph) {
   }
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -2,10 +2,8 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 TORCH_API void DecomposeOps(std::shared_ptr<Graph>& graph);
 
 }
-} // namespace torch
@@ -10,8 +10,7 @@
 #include <optional>
 #include <utility>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace {
 
@@ -67,7 +66,7 @@ bool returnSecondArgDeviceRule(Node* n) {
   return setReturnsToDevice(n, tensor_type->device());
 }
 
-bool isZerodimCPUTensor(std::shared_ptr<TensorType> tensor_type) {
+bool isZerodimCPUTensor(const std::shared_ptr<TensorType>& tensor_type) {
   // CPU devices on zerodim tensors are the only device that can be
   // overwritten by another device. Therefore, to be conservative
   // assume that it is not a zerodim cpu tensor if something is not known.
@@ -149,7 +148,7 @@ bool defaultDeviceProp(Node* n) {
 
 struct DeviceTypePropagationPass : public PropertyPropBase {
   explicit DeviceTypePropagationPass(std::shared_ptr<Graph> graph)
-      : PropertyPropBase(graph) {
+      : PropertyPropBase(std::move(graph)) {
     buildRuleRegistry();
   }
 
@@ -261,5 +260,4 @@ bool DeviceTypePropagation(std::shared_ptr<Graph>& graph) {
   return changed;
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
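Two recurring fixes show up in this file: a shared_ptr that is only read is taken by const reference (isZerodimCPUTensor), and a shared_ptr parameter that is stored in a member is moved instead of copied (DeviceTypePropagationPass). A minimal sketch of both, with stand-in type names rather than the real Graph class:

#include <memory>
#include <utility>

struct GraphStub {};  // illustrative stand-in for torch::jit::Graph

// Read-only use: a const reference avoids bumping the refcount on every call.
bool isEmptyGraph(const std::shared_ptr<GraphStub>& graph) {
  return graph == nullptr;
}

// Sink parameter: take by value, then move into the member, so the caller's
// copy is the only refcount increment.
struct PassSketch {
  explicit PassSketch(std::shared_ptr<GraphStub> graph)
      : graph_(std::move(graph)) {}
  std::shared_ptr<GraphStub> graph_;
};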
@@ -2,12 +2,10 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 struct Graph;
 
 // Propagates Device type info throughout the given graph.
 TORCH_API bool DeviceTypePropagation(std::shared_ptr<Graph>& graph);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -21,8 +21,7 @@
 #include <memory>
 #include <stdexcept>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace {
 
@@ -336,5 +335,4 @@ bool DtypePropagation(std::shared_ptr<Graph>& graph) {
   return changed;
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -4,8 +4,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <memory>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 struct Graph;
 
 // Propagate tensor properties (e.g., dtype, device, is_contiguous, layout)
@@ -13,5 +12,4 @@ struct Graph;
 // propagation
 TORCH_API bool DtypePropagation(std::shared_ptr<Graph>& graph);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -24,13 +24,12 @@
 #endif
 
 #include <exception>
 #include <iostream>
 #include <memory>
 #include <sstream>
 #include <utility>
 #include <vector>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 bool mergeTypes(
     ArrayRef<Value*> lhs,
@@ -64,9 +63,10 @@ void PropertyPropBase::propagateBlock(Block* block, bool insert_expands) {
     } catch (propagation_error& e) {
       setUnshapedType(node);
     } catch (std::exception& e) {
-      throw ErrorReport(node->sourceRange())
+      throw(
+          ErrorReport(node->sourceRange())
           << ExceptionMessage(e)
-          << "\nThe above operation failed shape propagation in this context";
+          << "\nThe above operation failed shape propagation in this context");
     }
   }
 }
@@ -314,8 +314,7 @@ class ShapePropagator : public PropertyPropBase {
         return at::empty_strided(
                    *type->sizes().concrete_sizes(),
                    *type->strides().concrete_sizes(),
-                   at::TensorOptions(*type->device())
-                       .dtype(*type->scalarType()))
+                   at::TensorOptions(*type->device()).dtype(type->scalarType()))
             .zero_();
       }
       // fallthrough
@@ -992,7 +991,7 @@ class ShapePropagator : public PropertyPropBase {
           arg_for_type = 1;
         }
        auto t = (*maybe_tensor_types)[arg_for_type]->scalarType();
-        return {broadcast(*maybe_tensor_types, *t)};
+        return {broadcast(*maybe_tensor_types, t)};
       }
       return {};
     }};
@@ -1009,7 +1008,7 @@ class ShapePropagator : public PropertyPropBase {
        if (!dtype) {
          return {};
        }
-        return {broadcast(*maybe_tensor_types, *dtype)};
+        return {broadcast(*maybe_tensor_types, dtype)};
       }
       return {};
     }};
@@ -1714,7 +1713,7 @@ class ShapePropagator : public PropertyPropBase {
       Symbol shape_input,
       const std::vector<TensorTypePtr>& tensor_types) -> TensorTypePtr {
     if (auto list_size = determineListSize(node->namedInput(shape_input))) {
-      return tensor_types.at(0)->withDim(*list_size);
+      return tensor_types.at(0)->withDim(list_size);
     }
     return nullptr;
   };
@@ -1799,8 +1798,8 @@ class ShapePropagator : public PropertyPropBase {
     if (!tensor_types.at(0)->dim() || !tensor_types.at(1)->dim()) {
       return nullptr;
     }
-    int dim1 = *tensor_types.at(0)->dim();
-    int dim2 = *tensor_types.at(1)->dim();
+    auto dim1 = *tensor_types.at(0)->dim();
+    auto dim2 = *tensor_types.at(1)->dim();
     if (dim1 == 1 && dim2 == 1) {
       // Dot product
       return tensor_types.at(0)->withDim(0);
@@ -2048,8 +2047,7 @@ class ShapePropagator : public PropertyPropBase {
             /*const_inputs=*/attr::size)) {
       auto sizes = node->get<c10::List<int64_t>>(attr::size).value();
       bool inferred = false;
-      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-      size_t inferred_idx;
+      size_t inferred_idx = 0;
       int64_t size_product = 1;
       for (const auto i : c10::irange(sizes.size())) {
         if (sizes.get(i) == -1) {
@@ -2064,7 +2062,7 @@ class ShapePropagator : public PropertyPropBase {
 
       if (inferred) {
         SHAPE_ASSERT(size_product != 0);
-        size_t numel = 1;
+        int64_t numel = 1;
         auto concrete_sizes =
             tensor_types.at(0)->sizes().concrete_sizes().value();
         for (int64_t s : concrete_sizes)
@@ -2154,7 +2152,9 @@ namespace {
 
 using TypeCache = std::unordered_map<TypePtr, TypePtr>;
 
-TypePtr getOrCreateUnshapedType(TypePtr type, TypeCache& unshaped_type_cache);
+TypePtr getOrCreateUnshapedType(
+    const TypePtr& type,
+    TypeCache& unshaped_type_cache);
 
 TypePtr unshapedTypeImpl(TypePtr type, TypeCache& unshaped_type_cache) {
   if (type->isSubtypeOf(*TensorType::get())) {
@@ -2172,7 +2172,9 @@ TypePtr unshapedTypeImpl(TypePtr type, TypeCache& unshaped_type_cache) {
   return type->withContained(std::move(unshaped_contained_types));
 }
 
-TypePtr getOrCreateUnshapedType(TypePtr type, TypeCache& unshaped_type_cache) {
+TypePtr getOrCreateUnshapedType(
+    const TypePtr& type,
+    TypeCache& unshaped_type_cache) {
   auto maybe_cached_type = unshaped_type_cache.find(type);
   if (maybe_cached_type != unshaped_type_cache.end()) {
     return maybe_cached_type->second;
@@ -2220,5 +2222,4 @@ void EraseShapeInformation(const std::shared_ptr<Graph>& graph) {
   TypeCache unshaped_type_cache;
   EraseShapeInformation(graph->block(), unshaped_type_cache);
 }
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
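One of the shape_analysis hunks replaces an uninitialized local plus its NOLINTNEXTLINE(cppcoreguidelines-init-variables) suppression with a declaration that carries a value. A small self-contained sketch of the same fix; the function and variable names here are illustrative, not the real pass code:

#include <cstddef>
#include <cstdint>
#include <vector>

// The index gets a value at its declaration, so no lint suppression is
// needed; `inferred` still records whether a -1 ("infer this dim") was seen.
std::size_t findInferredDim(const std::vector<int64_t>& sizes) {
  bool inferred = false;
  std::size_t inferred_idx = 0;  // initialized instead of suppressed
  for (std::size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] == -1) {
      inferred = true;
      inferred_idx = i;
    }
  }
  return inferred ? inferred_idx : sizes.size();
}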
@@ -4,8 +4,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <memory>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 struct Graph;
 
@@ -39,5 +38,4 @@ TORCH_API bool mergeTypes(
     ArrayRef<Value*> rhs,
     ArrayRef<Value*> outputs);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -9,8 +9,7 @@
 #include <ATen/core/symbol.h>
 #include <c10/util/irange.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 static const auto countsAttribute = Symbol::attr("none_counts");
 
@@ -477,5 +476,4 @@ void specializeAutogradZero(std::shared_ptr<Graph> g) {
   azs.run();
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -2,8 +2,7 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // propagate autograd zero information through a gradient graph and
 // remove grad_of blocks if present.
@@ -17,5 +16,4 @@ struct ProfilingRecord;
 
 TORCH_API void InsertProfileNodesForSpecializeAutogradZero(ProfilingRecord* pr);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -7,8 +7,7 @@
 
 #include <utility>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 namespace {
 void update_source_range_and_cs_ptr(
@@ -220,5 +219,4 @@ Module PatternBasedRewrite(const Module& module) {
   return subgraph_rewriter.runOnModule(module);
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -16,8 +16,7 @@
 #include <unordered_set>
 #include <vector>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // Forward declarations.
 struct RewritePatternDescr;
@@ -113,5 +112,4 @@ struct RewritePatternDescr {
   std::unordered_map<std::string, std::string> value_name_map;
 };
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -25,7 +25,6 @@
 #include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
 #include <algorithm>
 #include <memory>
 #include <numeric>
 #include <unordered_map>
 #include <utility>
 #include <vector>
@@ -42,8 +41,7 @@ but not limited to:
 
 static bool symbolic_shape_analysis_test_mode = false;
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // This is similar to c10::SymbolicShape, but instead of either having
 // a concrete dimension or a symbolic dimension, an argument may be:
@@ -210,7 +208,7 @@ bool isListOfTensors(const TypePtr& type) {
 
 std::optional<size_t> normIndex(int64_t index, size_t len) {
   if (index < 0) {
-    index = index + len;
+    index = index + static_cast<int64_t>(len);
   }
   if (index >= 0 && index < static_cast<int64_t>(len)) {
     return index;
@@ -235,7 +233,7 @@ bool shapeGraphCleanupPasses(std::shared_ptr<Graph> graph) {
   return made_change;
 }
 
-void replaceWithIValue(Value* v, IValue val) {
+void replaceWithIValue(Value* v, const IValue& val) {
   WithInsertPoint guard(*v->node()->owningBlock()->nodes().begin());
   v->replaceAllUsesWith(v->owningGraph()->insertConstant(val));
 }
@@ -600,7 +598,7 @@ struct SymbolicShapeOpAnalyzer {
 
   SymbolicShapeOpAnalyzer(
       const FunctionSchema* schema,
-      std::shared_ptr<Graph> graph)
+      const std::shared_ptr<Graph>& graph)
      : schema_(schema) {
     shape_compute_graph_ = graph->copy();
   }
@@ -895,7 +893,7 @@ struct SymbolicShapeGraphAnalyzer {
             output_index_to_symbolic_shape_[i];
       }
     }
-    for (int64_t i = erase_indices.size() - 1; i >= 0; i--) {
+    for (auto i = static_cast<int64_t>(erase_indices.size()) - 1; i >= 0; i--) {
       stitched_shape_compute_graph->eraseOutput(erase_indices[i]);
     }
     for (size_t i = 0; i < stitched_shape_compute_graph->inputs().size();) {
@@ -945,7 +943,7 @@ struct SymbolicShapeGraphAnalyzer {
   }
 
   void registerStitchedComputeOutput(
-      std::shared_ptr<Graph> stitched_shape_compute_graph,
+      const std::shared_ptr<Graph>& stitched_shape_compute_graph,
       Value* output,
       int64_t symbolic_shape) {
     stitched_shape_compute_graph->registerOutput(output);
@@ -958,8 +956,8 @@ struct SymbolicShapeGraphAnalyzer {
 
   void joinPartialEvaluatedShapeGraphToLargeShapeGraph(
       Node* curr,
-      std::shared_ptr<Graph> partial_eval_graph,
-      std::shared_ptr<Graph> stitched_shape_compute_graph) {
+      const std::shared_ptr<Graph>& partial_eval_graph,
+      const std::shared_ptr<Graph>& stitched_shape_compute_graph) {
     // we are building up the large shape compute graph by iteratively
     // combining partially evaluated individual node shape graphs.
 
@@ -1183,5 +1181,4 @@ calculateSymbolicShapesOnOp(
   return res;
 }
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
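Several helpers in this file stop taking std::shared_ptr<Graph> by value and take it by const reference instead, since they only read through the pointer and never keep a copy. A hedged sketch of that guideline with a stand-in type, not the real Graph class:

#include <memory>

struct GraphStub {};  // illustrative stand-in for torch::jit::Graph

// By-value parameter: copies the shared_ptr and bumps the refcount per call.
void analyzeByValue(std::shared_ptr<GraphStub> graph) { (void)graph; }

// Const reference: no refcount traffic; appropriate when the callee neither
// stores the pointer nor needs to extend its lifetime.
void analyzeByRef(const std::shared_ptr<GraphStub>& graph) { (void)graph; }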
@@ -6,8 +6,7 @@
 #include <utility>
 #include <variant>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // CAUTION NOT TO BE USED, STILL A WIP, NOT STABLE
 
@@ -54,5 +53,4 @@ TORCH_API std::optional<std::vector<c10::SymbolicShape>>
 calculateSymbolicShapesOnOp(
     const FunctionSchema* schema,
     const std::vector<SSAInput>& inputs);
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -5,8 +5,8 @@
 #include <utility>
 
 // SHAPE CACHING CODE
-namespace torch {
-namespace jit {
-
+namespace torch::jit {
 namespace {
 using CanonicalArg = std::variant<CanonicalizedSymbolicShape, IValue>;
 using CanonicalArgVec = std::vector<CanonicalArg>;
@@ -206,5 +206,4 @@ bool operator==(
     const CanonicalizedSymbolicShape& b) {
   return a.values_ == b.values_;
 };
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -3,8 +3,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 struct TORCH_API CanonicalizedSymbolicShape {
   // TODO: Consider in the future if it is reasonable to
@@ -53,5 +52,4 @@ TORCH_API void cache_shape_function(
 TORCH_API void clear_shape_cache();
 TORCH_API size_t get_shape_cache_size();
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -15,8 +15,7 @@
 #include <sstream>
 #include <utility>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // Inserts the Compute for Each Symbolic Shape in the TensorExpr Graph
 // and returns back a map from Symbolic Shape Value to its runtime Value *
@@ -180,7 +179,7 @@ static StrideInput summarizeOutputStrides(const TensorType& tt) {
 // specializations
 static std::optional<std::vector<std::vector<StrideInput>>>
 TryGeneralizeInputDimensionsToSymbolicShapes(
-    std::shared_ptr<Graph> tensorexpr_graph) {
+    const std::shared_ptr<Graph>& tensorexpr_graph) {
   std::map<size_t, int64_t> shape_to_sym_shape;
   std::vector<std::vector<StrideInput>> input_striding;
 
@@ -214,7 +213,7 @@ TryGeneralizeInputDimensionsToSymbolicShapes(
 
 static void moveConstantTensorsOutOfSubgraph(
     Node* tensorexpr_graph_node,
-    std::shared_ptr<Graph> tensorexpr_graph) {
+    const std::shared_ptr<Graph>& tensorexpr_graph) {
   auto parent = tensorexpr_graph_node->owningGraph();
 
   auto env = [&](Value* v) {
@@ -586,7 +585,7 @@ RegisterOperators reg_guard({
           } else {
             // use index for set if it exists, otherwise extend the vector
             // of sym shapes by 1
-            int64_t sym_dim_index;
+            size_t sym_dim_index = 0;
             if (sym_dim_flat_index.count(value)) {
               sym_dim_index = sym_dim_flat_index[value];
             } else {
@@ -596,7 +595,8 @@ RegisterOperators reg_guard({
             }
             // TODO: potential optimization - if there is a Symbolic
             // Sym with only one use we dont need to test anything
-            flattened_input_dims.push_back(sym_dim_index);
+            flattened_input_dims.push_back(
+                static_cast<int64_t>(sym_dim_index));
           }
         }
       }
@@ -680,7 +680,7 @@ RegisterOperators reg_guard({
         flattened_stride_offset += num_dims;
       }
       for (const auto dim_index : c10::irange(num_dims)) {
-        const int64_t dim_value =
+        const auto dim_value =
            flattened_input_dims[dim_index + flattened_dim_offset];
         const int64_t tensor_dim = sizes[dim_index];
         if (dim_value >= 0) {
@@ -743,5 +743,4 @@ RegisterOperators TensorExprDynamicOp({
         AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
 });
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

@@ -4,10 +4,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
 
 #include <unordered_map>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // Takes in a TensorExprGraph of static shapes and generalizes the input shapes
 // to symbolic dimensions. Dimensions of value 1 will be preserved, otherwise
@@ -51,5 +48,4 @@ enum class StrideInput {
 TORCH_API std::string toString(StrideInput si);
 TORCH_API StrideInput strideInputFromString(const std::string& si);
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
@@ -37,8 +37,7 @@ C10_DEFINE_bool(
     false,
     "enable TE fusion using dynamic shapes");
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 static bool texpr_reductions_enabled = false;
 
@@ -560,8 +559,7 @@ class TensorExprFuser {
     inlineSmallFusionGroups(graph_->block());
     GRAPH_DUMP("After inlining small fusion groups: ", graph_);
     if (fuse_to_dynamic_shapes_) {
-      VLOG(1) << "TensorExpr fusion with dynamic shapes is enabled"
-              << std::endl;
+      VLOG(1) << "TensorExpr fusion with dynamic shapes is enabled" << '\n';
       generalizeFusionGroups(graph_->block());
       GRAPH_DUMP("After generalizing fusion groups: ", graph_);
     } else {
@@ -1288,7 +1286,7 @@ class TensorExprFuser {
       VLOG(1) << "GenerateGuard for fusion group: " << *fusion_group;
       if (!GenerateGuard(fusion_group, add_composed_op_)) {
         VLOG(1) << " Unfusing the fusion group because GenerateGuard failed"
-                << std::endl;
+                << '\n';
         SubgraphUtils::unmergeSubgraph(fusion_group);
       }
     }
@@ -1451,5 +1449,4 @@ RegisterOperators TensorExprOps({
         AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
 });
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
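The logging hunks replace std::endl with '\n': both end the line, but std::endl also forces a flush of the stream on every message, which this verbose logging does not need. A tiny sketch:

#include <iostream>

void logDynamicShapeFusionEnabled() {
  // '\n' ends the line without the per-message flush that std::endl forces.
  std::cout << "TensorExpr fusion with dynamic shapes is enabled" << '\n';
}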
@@ -4,8 +4,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <memory>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // Run TensorExpressions-based fuser.
 // If add_composed_op is true, creates a single operation that
@@ -71,5 +70,4 @@ TORCH_API bool isSupported(Node* node);
 ///
 TORCH_API OperatorSet& getCustomOperatorSet();
 } // namespace tensorexpr
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit