[9/N] Fix clang-tidy warnings in jit (#132010)

Follows  #131997

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132010
Approved by: https://github.com/Skylion007
This commit is contained in:
cyy
2024-07-29 18:38:35 +00:00
committed by PyTorch MergeBot
parent f389bca2e9
commit c764ef6d53
54 changed files with 152 additions and 262 deletions

View File

@ -1,8 +1,7 @@
#include <torch/csrc/jit/passes/add_if_then_else.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
@ -51,5 +50,4 @@ bool AddIfThenElseOp(std::shared_ptr<Graph>& graph) {
return !to_replace.empty();
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,10 +2,8 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API bool AddIfThenElseOp(std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,8 +2,7 @@
#include <atomic>
namespace torch {
namespace jit {
namespace torch::jit {
static void AnnotateWarns(Block* b) {
static std::atomic<int64_t> idx(0);
@ -25,5 +24,4 @@ void AnnotateWarns(const std::shared_ptr<Graph>& graph) {
AnnotateWarns(graph->block());
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,10 +2,8 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API void AnnotateWarns(const std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -13,8 +13,7 @@
#include <unordered_set>
#include <vector>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
@ -79,10 +78,10 @@ std::optional<AutocastScope> parseAutocast(
if (use.user->kind() == prim::SetAttr &&
use.user->s(attr::name) == "_enabled") {
// Search for `prim::SetAttr[name="_enabled"]`
auto ret = constant_as<bool>(use.user->input(1));
enabled = constant_as<bool>(use.user->input(1));
TORCH_CHECK(
ret.has_value(), "Autocast _enabled argument must be a constant");
enabled = ret.value();
enabled.has_value(),
"Autocast _enabled argument must be a constant");
} else if (
use.user->kind() == prim::SetAttr &&
use.user->s(attr::name) == "device") {
@ -532,5 +531,4 @@ void Autocast(const std::shared_ptr<Graph>& graph) {
GRAPH_DUMP("\nAfter Autocast: ", graph);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -3,13 +3,11 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API void Autocast(const std::shared_ptr<Graph>& graph);
TORCH_API bool setAutocastMode(bool value);
TORCH_API bool autocastEnabled();
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -12,8 +12,7 @@
#include <unordered_set>
#include <utility>
namespace torch {
namespace jit {
namespace torch::jit {
static bool shouldBeCapturedInByBailOut(Node* n) {
return n->kind() != prim::Constant;
@ -223,7 +222,7 @@ struct BailOutGraphBuilderForNode {
// version of an original graph from a particular point
struct BailOutInserter {
explicit BailOutInserter(std::shared_ptr<Graph> graph)
: graph_(std::move(graph)), bailout_index_(0) {}
: graph_(std::move(graph)) {}
void run() {
liveness_sets_ = BuildLivenessSets(graph_);
@ -322,7 +321,7 @@ struct BailOutInserter {
std::shared_ptr<Graph> graph_;
std::map<Node*, Node*> subgraphs;
std::size_t bailout_index_;
std::size_t bailout_index_{0};
std::unordered_map<Node*, std::vector<Value*>> liveness_sets_;
std::vector<Node*> bailouts_;
std::map<Value*, Value*> replacements_;
@ -394,5 +393,4 @@ TORCH_API std::shared_ptr<Graph> BuildBailOutGraphFrom(
return bailout_graph;
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -10,8 +10,7 @@
#include <list>
#include <vector>
namespace torch {
namespace jit {
namespace torch::jit {
// Replaces prim::Guard nodes with prim::BailOut nodes and
// computes sets of inputs needed to resume execution at
@ -30,5 +29,4 @@ TORCH_API std::shared_ptr<Graph> BuildBailOutGraphFrom(
int64_t bailout_index,
const std::shared_ptr<Graph>& orig,
const std::shared_ptr<Graph>& target);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -16,8 +16,7 @@
#include <unordered_map>
#include <utility>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() {
@ -490,5 +489,4 @@ void BatchMM(std::shared_ptr<Graph>& graph) {
PeepholeOptimize(graph, /*disable_shape_peepholes*/ true);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,10 +2,8 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API void BatchMM(std::shared_ptr<Graph>& graph);
}
} // namespace torch

View File

@ -3,8 +3,7 @@
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/ir_views.h>
namespace torch {
namespace jit {
namespace torch::jit {
// Canonicalize a graph, renumbering it so that all structurally equivalent
// graphs have same numbers.
@ -231,5 +230,4 @@ static void CanonicalizeOutputs(Block* block) {
void CanonicalizeOutputs(std::shared_ptr<Graph>& graph) {
CanonicalizeOutputs(graph->block());
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,8 +2,7 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API std::shared_ptr<Graph> Canonicalize(
const std::shared_ptr<Graph>& graph,
@ -18,5 +17,4 @@ TORCH_API bool isBeforeOrAfter(
const Use& b,
bool checking_before);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -3,8 +3,7 @@
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
namespace torch {
namespace jit {
namespace torch::jit {
struct ChunkOutput {
ChunkOutput(Value* v, size_t o) : val(v), offset(o){};
@ -96,5 +95,4 @@ void CanonicalizeOps(const std::shared_ptr<Graph>& graph) {
EliminateDeadCode(graph);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,10 +2,8 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API void CanonicalizeOps(const std::shared_ptr<Graph>& graph);
}
} // namespace torch

View File

@ -7,10 +7,8 @@
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <unordered_map>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
@ -77,7 +75,7 @@ static void checkForUnfusedOps(Node* enter_node) {
for (Node* n : guarding_ifs) {
ss << *n << "\n";
}
throw ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str();
throw(ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str());
}
// autodiff/nnc both insert a number of guards, see
@ -110,8 +108,7 @@ static void checkForUnfusedOps(Node* enter_node) {
}
ss << "\n";
}
auto range = enter_node->input()->node()->sourceRange();
throw ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str();
throw(ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str());
}
}
@ -128,5 +125,4 @@ void CheckStrictFusion(std::shared_ptr<Graph>& graph) {
// TODO: improve control flow not taken, right now always errors
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -3,10 +3,8 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API void CheckStrictFusion(std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,8 +2,7 @@
#include <torch/csrc/jit/jit_log.h>
namespace torch {
namespace jit {
namespace torch::jit {
void unprofileGraphInputs(const std::shared_ptr<Graph>& graph) {
for (auto i : graph->inputs()) {
@ -45,5 +44,4 @@ void ClearProfilingInformation(const std::shared_ptr<Graph>& graph) {
GRAPH_DUMP("After ClearProfilingInformation: ", graph);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -6,8 +6,7 @@
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API void unprofileGraphInputs(const std::shared_ptr<Graph>& graph);
TORCH_API void unprofileBlock(Block* start_block);
@ -15,5 +14,4 @@ TORCH_API void unprofileBlock(Block* start_block);
TORCH_API void ClearProfilingInformation(const std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,8 +2,7 @@
#include <torch/csrc/jit/jit_log.h>
namespace torch {
namespace jit {
namespace torch::jit {
static void clearUndefinedness(Value* o) {
if (o->type()->kind() == TensorType::Kind) {
@ -35,5 +34,4 @@ void ClearUndefinedness(const std::shared_ptr<Graph>& graph) {
GRAPH_DUMP("After removeUndefinedness: ", graph);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -6,8 +6,7 @@
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
// Undefinedness makes argument matching fail for regular tensor operations
// if 1+ arguments are undefined or possibly undefined tensors.
@ -20,5 +19,4 @@ namespace jit {
// When this happens, this pass will be removed
TORCH_API void ClearUndefinedness(const std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -7,8 +7,7 @@
#include <unordered_map>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
struct CommonSubexpressionEliminator {
@ -126,5 +125,4 @@ bool EliminateCommonSubexpression(const std::shared_ptr<Graph>& graph) {
CommonSubexpressionEliminator cse(graph);
return cse.run([](Node*) { return nullptr; });
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,10 +2,8 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API bool EliminateCommonSubexpression(
const std::shared_ptr<Graph>& graph);
}
} // namespace torch

View File

@ -16,8 +16,7 @@
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
@ -699,5 +698,4 @@ bool CombineConcats(const std::shared_ptr<Graph>& graph) {
return changed;
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,8 +2,7 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
// Eliminates common inputs among `aten::cat` ops.
TORCH_API bool EliminateConcatCommonInputs(const std::shared_ptr<Graph>& graph);
@ -15,5 +14,4 @@ TORCH_API void ExpandConcatAndEliminateRedundancy(
TORCH_API bool CombineConcats(const std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -6,8 +6,7 @@
#include <torch/csrc/jit/ir/node_hashing.h>
#include <unordered_set>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
@ -71,5 +70,4 @@ void ConstantPooling(const std::shared_ptr<Graph>& graph) {
std::unordered_set<Node*, HashNode, EqualNode> constants;
ConstantPooling(graph->block(), constants, aliasDb);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,10 +2,8 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API void ConstantPooling(const std::shared_ptr<Graph>& graph);
}
} // namespace torch

View File

@ -16,8 +16,7 @@
#include <utility>
namespace torch {
namespace jit {
namespace torch::jit {
std::optional<std::vector<IValue>> runNodeIfInputsAreConstant(
const Node* n,
@ -434,5 +433,4 @@ bool ConstantPropagationImmutableTypes(std::shared_ptr<Graph>& graph) {
return made_change;
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,8 +2,7 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
// Runs constant propagation on all objects unless ignore_custom_classes is
// specified as true, in which case user defined classes are skipped. This is
@ -28,5 +27,4 @@ TORCH_API std::optional<Stack> runNodeIfInputsAreConstant(
bool ignore_custom_classes = false,
AliasDb* db = nullptr);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -10,8 +10,7 @@
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/autodiff.h>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
@ -374,9 +373,9 @@ std::optional<bool> findRequiresGradForOutput(
}
if (use.user->kind() == prim::profile) {
std::optional<bool> req_grad_use;
if ((req_grad_use = getProfileNodeRequiresGrad(use.user)).has_value()) {
return req_grad_use.value();
auto req_grad_use = getProfileNodeRequiresGrad(use.user);
if (req_grad_use.has_value()) {
return req_grad_use;
}
}
@ -393,10 +392,9 @@ std::optional<bool> findRequiresGradForOutput(
}
if (dg_use.user->kind() == prim::profile) {
std::optional<bool> req_grad_use;
if ((req_grad_use = getProfileNodeRequiresGrad(dg_use.user))
.has_value()) {
return req_grad_use.value();
auto req_grad_use = getProfileNodeRequiresGrad(dg_use.user);
if (req_grad_use.has_value()) {
return req_grad_use;
}
}
}
@ -474,5 +472,4 @@ std::vector<Node*> CreateAutodiffSubgraphs(
GRAPH_DEBUG("diff_nodes.size() ", diff_nodes.size());
return diff_nodes;
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -5,8 +5,7 @@
#include <cstddef>
namespace torch {
namespace jit {
namespace torch::jit {
// insert GraphExecutor nodes that group together
// subgraphs that are differentiable by the jit's autodiff passes
@ -15,5 +14,4 @@ namespace jit {
TORCH_API std::vector<Node*> CreateAutodiffSubgraphs(
const std::shared_ptr<Graph>& graph,
size_t threshold = 2);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -8,8 +8,7 @@
#include <cstddef>
#include <limits>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
@ -223,5 +222,4 @@ void InlineFunctionalGraphs(const std::shared_ptr<Graph>& graph) {
InlineFunctionalGraphs(graph->block());
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -3,12 +3,10 @@
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API void CreateFunctionalGraphs(const std::shared_ptr<Graph>& graph);
TORCH_API void InlineFunctionalGraphs(const std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -7,8 +7,7 @@
#include <unordered_map>
namespace torch {
namespace jit {
namespace torch::jit {
namespace prim {
using namespace ::c10::prim;
@ -458,5 +457,4 @@ void EliminateDeadCode(
eliminator.run(block, /*recurse=*/true);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,8 +2,7 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
// If given a top-level graph, DCE will construct do alias analysis that allows
// for "smarter" dead code elimination (we will eliminate mutable ops if we can
@ -38,5 +37,4 @@ TORCH_API void EliminateDeadCode(
std::function<void(const std::unordered_set<const Value*>&)> cb,
DCESideEffectPolicy sideEffectPolicy =
DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -10,8 +10,7 @@
#include <ATen/core/symbol.h>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
c10::AliasAnalysisKind aliasAnalysisFromSchema() {
@ -231,5 +230,4 @@ void DecomposeOps(std::shared_ptr<Graph>& graph) {
}
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,10 +2,8 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
TORCH_API void DecomposeOps(std::shared_ptr<Graph>& graph);
}
} // namespace torch

View File

@ -10,8 +10,7 @@
#include <optional>
#include <utility>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
@ -67,7 +66,7 @@ bool returnSecondArgDeviceRule(Node* n) {
return setReturnsToDevice(n, tensor_type->device());
}
bool isZerodimCPUTensor(std::shared_ptr<TensorType> tensor_type) {
bool isZerodimCPUTensor(const std::shared_ptr<TensorType>& tensor_type) {
// CPU devices on zerodim tensors are the only device that can be
// overwritten by another device. Therefore, to be conservative
// assume that it is not a zerodim cpu tensor if something is not known.
@ -149,7 +148,7 @@ bool defaultDeviceProp(Node* n) {
struct DeviceTypePropagationPass : public PropertyPropBase {
explicit DeviceTypePropagationPass(std::shared_ptr<Graph> graph)
: PropertyPropBase(graph) {
: PropertyPropBase(std::move(graph)) {
buildRuleRegistry();
}
@ -261,5 +260,4 @@ bool DeviceTypePropagation(std::shared_ptr<Graph>& graph) {
return changed;
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,12 +2,10 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
struct Graph;
// Propagates Device type info throughout the given graph.
TORCH_API bool DeviceTypePropagation(std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -21,8 +21,7 @@
#include <memory>
#include <stdexcept>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
@ -336,5 +335,4 @@ bool DtypePropagation(std::shared_ptr<Graph>& graph) {
return changed;
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -4,8 +4,7 @@
#include <torch/csrc/jit/ir/ir.h>
#include <memory>
namespace torch {
namespace jit {
namespace torch::jit {
struct Graph;
// Propagate tensor properties (e.g., dtype, device, is_contiguous, layout)
@ -13,5 +12,4 @@ struct Graph;
// propagation
TORCH_API bool DtypePropagation(std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -24,13 +24,12 @@
#endif
#include <exception>
#include <iostream>
#include <memory>
#include <sstream>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
namespace torch::jit {
bool mergeTypes(
ArrayRef<Value*> lhs,
@ -64,9 +63,10 @@ void PropertyPropBase::propagateBlock(Block* block, bool insert_expands) {
} catch (propagation_error& e) {
setUnshapedType(node);
} catch (std::exception& e) {
throw ErrorReport(node->sourceRange())
throw(
ErrorReport(node->sourceRange())
<< ExceptionMessage(e)
<< "\nThe above operation failed shape propagation in this context";
<< "\nThe above operation failed shape propagation in this context");
}
}
}
@ -314,8 +314,7 @@ class ShapePropagator : public PropertyPropBase {
return at::empty_strided(
*type->sizes().concrete_sizes(),
*type->strides().concrete_sizes(),
at::TensorOptions(*type->device())
.dtype(*type->scalarType()))
at::TensorOptions(*type->device()).dtype(type->scalarType()))
.zero_();
}
// fallthrough
@ -992,7 +991,7 @@ class ShapePropagator : public PropertyPropBase {
arg_for_type = 1;
}
auto t = (*maybe_tensor_types)[arg_for_type]->scalarType();
return {broadcast(*maybe_tensor_types, *t)};
return {broadcast(*maybe_tensor_types, t)};
}
return {};
}};
@ -1009,7 +1008,7 @@ class ShapePropagator : public PropertyPropBase {
if (!dtype) {
return {};
}
return {broadcast(*maybe_tensor_types, *dtype)};
return {broadcast(*maybe_tensor_types, dtype)};
}
return {};
}};
@ -1714,7 +1713,7 @@ class ShapePropagator : public PropertyPropBase {
Symbol shape_input,
const std::vector<TensorTypePtr>& tensor_types) -> TensorTypePtr {
if (auto list_size = determineListSize(node->namedInput(shape_input))) {
return tensor_types.at(0)->withDim(*list_size);
return tensor_types.at(0)->withDim(list_size);
}
return nullptr;
};
@ -1799,8 +1798,8 @@ class ShapePropagator : public PropertyPropBase {
if (!tensor_types.at(0)->dim() || !tensor_types.at(1)->dim()) {
return nullptr;
}
int dim1 = *tensor_types.at(0)->dim();
int dim2 = *tensor_types.at(1)->dim();
auto dim1 = *tensor_types.at(0)->dim();
auto dim2 = *tensor_types.at(1)->dim();
if (dim1 == 1 && dim2 == 1) {
// Dot product
return tensor_types.at(0)->withDim(0);
@ -2048,8 +2047,7 @@ class ShapePropagator : public PropertyPropBase {
/*const_inputs=*/attr::size)) {
auto sizes = node->get<c10::List<int64_t>>(attr::size).value();
bool inferred = false;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t inferred_idx;
size_t inferred_idx = 0;
int64_t size_product = 1;
for (const auto i : c10::irange(sizes.size())) {
if (sizes.get(i) == -1) {
@ -2064,7 +2062,7 @@ class ShapePropagator : public PropertyPropBase {
if (inferred) {
SHAPE_ASSERT(size_product != 0);
size_t numel = 1;
int64_t numel = 1;
auto concrete_sizes =
tensor_types.at(0)->sizes().concrete_sizes().value();
for (int64_t s : concrete_sizes)
@ -2154,7 +2152,9 @@ namespace {
using TypeCache = std::unordered_map<TypePtr, TypePtr>;
TypePtr getOrCreateUnshapedType(TypePtr type, TypeCache& unshaped_type_cache);
TypePtr getOrCreateUnshapedType(
const TypePtr& type,
TypeCache& unshaped_type_cache);
TypePtr unshapedTypeImpl(TypePtr type, TypeCache& unshaped_type_cache) {
if (type->isSubtypeOf(*TensorType::get())) {
@ -2172,7 +2172,9 @@ TypePtr unshapedTypeImpl(TypePtr type, TypeCache& unshaped_type_cache) {
return type->withContained(std::move(unshaped_contained_types));
}
TypePtr getOrCreateUnshapedType(TypePtr type, TypeCache& unshaped_type_cache) {
TypePtr getOrCreateUnshapedType(
const TypePtr& type,
TypeCache& unshaped_type_cache) {
auto maybe_cached_type = unshaped_type_cache.find(type);
if (maybe_cached_type != unshaped_type_cache.end()) {
return maybe_cached_type->second;
@ -2220,5 +2222,4 @@ void EraseShapeInformation(const std::shared_ptr<Graph>& graph) {
TypeCache unshaped_type_cache;
EraseShapeInformation(graph->block(), unshaped_type_cache);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -4,8 +4,7 @@
#include <torch/csrc/jit/ir/ir.h>
#include <memory>
namespace torch {
namespace jit {
namespace torch::jit {
struct Graph;
@ -39,5 +38,4 @@ TORCH_API bool mergeTypes(
ArrayRef<Value*> rhs,
ArrayRef<Value*> outputs);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -9,8 +9,7 @@
#include <ATen/core/symbol.h>
#include <c10/util/irange.h>
namespace torch {
namespace jit {
namespace torch::jit {
static const auto countsAttribute = Symbol::attr("none_counts");
@ -477,5 +476,4 @@ void specializeAutogradZero(std::shared_ptr<Graph> g) {
azs.run();
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -2,8 +2,7 @@
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace torch::jit {
// propagate autograd zero information through a gradient graph and
// remove grad_of blocks if present.
@ -17,5 +16,4 @@ struct ProfilingRecord;
TORCH_API void InsertProfileNodesForSpecializeAutogradZero(ProfilingRecord* pr);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -7,8 +7,7 @@
#include <utility>
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
void update_source_range_and_cs_ptr(
@ -220,5 +219,4 @@ Module PatternBasedRewrite(const Module& module) {
return subgraph_rewriter.runOnModule(module);
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -16,8 +16,7 @@
#include <unordered_set>
#include <vector>
namespace torch {
namespace jit {
namespace torch::jit {
// Forward declarations.
struct RewritePatternDescr;
@ -113,5 +112,4 @@ struct RewritePatternDescr {
std::unordered_map<std::string, std::string> value_name_map;
};
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -25,7 +25,6 @@
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <algorithm>
#include <memory>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>
@ -42,8 +41,7 @@ but not limited to:
static bool symbolic_shape_analysis_test_mode = false;
namespace torch {
namespace jit {
namespace torch::jit {
// This is similar to c10::SymbolicShape, but instead of either having
// a concrete dimension or a symbolic dimension, an argument may be:
@ -210,7 +208,7 @@ bool isListOfTensors(const TypePtr& type) {
std::optional<size_t> normIndex(int64_t index, size_t len) {
if (index < 0) {
index = index + len;
index = index + static_cast<int64_t>(len);
}
if (index >= 0 && index < static_cast<int64_t>(len)) {
return index;
@ -235,7 +233,7 @@ bool shapeGraphCleanupPasses(std::shared_ptr<Graph> graph) {
return made_change;
}
void replaceWithIValue(Value* v, IValue val) {
void replaceWithIValue(Value* v, const IValue& val) {
WithInsertPoint guard(*v->node()->owningBlock()->nodes().begin());
v->replaceAllUsesWith(v->owningGraph()->insertConstant(val));
}
@ -600,7 +598,7 @@ struct SymbolicShapeOpAnalyzer {
SymbolicShapeOpAnalyzer(
const FunctionSchema* schema,
std::shared_ptr<Graph> graph)
const std::shared_ptr<Graph>& graph)
: schema_(schema) {
shape_compute_graph_ = graph->copy();
}
@ -895,7 +893,7 @@ struct SymbolicShapeGraphAnalyzer {
output_index_to_symbolic_shape_[i];
}
}
for (int64_t i = erase_indices.size() - 1; i >= 0; i--) {
for (auto i = static_cast<int64_t>(erase_indices.size()) - 1; i >= 0; i--) {
stitched_shape_compute_graph->eraseOutput(erase_indices[i]);
}
for (size_t i = 0; i < stitched_shape_compute_graph->inputs().size();) {
@ -945,7 +943,7 @@ struct SymbolicShapeGraphAnalyzer {
}
void registerStitchedComputeOutput(
std::shared_ptr<Graph> stitched_shape_compute_graph,
const std::shared_ptr<Graph>& stitched_shape_compute_graph,
Value* output,
int64_t symbolic_shape) {
stitched_shape_compute_graph->registerOutput(output);
@ -958,8 +956,8 @@ struct SymbolicShapeGraphAnalyzer {
void joinPartialEvaluatedShapeGraphToLargeShapeGraph(
Node* curr,
std::shared_ptr<Graph> partial_eval_graph,
std::shared_ptr<Graph> stitched_shape_compute_graph) {
const std::shared_ptr<Graph>& partial_eval_graph,
const std::shared_ptr<Graph>& stitched_shape_compute_graph) {
// we are building up the large shape compute graph by iteratively
// combining partially evaluated individual node shape graphs.
@ -1183,5 +1181,4 @@ calculateSymbolicShapesOnOp(
return res;
}
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -6,8 +6,7 @@
#include <utility>
#include <variant>
namespace torch {
namespace jit {
namespace torch::jit {
// CAUTION NOT TO BE USED, STILL A WIP, NOT STABLE
@ -54,5 +53,4 @@ TORCH_API std::optional<std::vector<c10::SymbolicShape>>
calculateSymbolicShapesOnOp(
const FunctionSchema* schema,
const std::vector<SSAInput>& inputs);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -5,8 +5,8 @@
#include <utility>
// SHAPE CACHING CODE
namespace torch {
namespace jit {
namespace torch::jit {
namespace {
using CanonicalArg = std::variant<CanonicalizedSymbolicShape, IValue>;
using CanonicalArgVec = std::vector<CanonicalArg>;
@ -206,5 +206,4 @@ bool operator==(
const CanonicalizedSymbolicShape& b) {
return a.values_ == b.values_;
};
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -3,8 +3,7 @@
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
namespace torch {
namespace jit {
namespace torch::jit {
struct TORCH_API CanonicalizedSymbolicShape {
// TODO: Consider in the future if it is reasonable to
@ -53,5 +52,4 @@ TORCH_API void cache_shape_function(
TORCH_API void clear_shape_cache();
TORCH_API size_t get_shape_cache_size();
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -15,8 +15,7 @@
#include <sstream>
#include <utility>
namespace torch {
namespace jit {
namespace torch::jit {
// Inserts the Compute for Each Symbolic Shape in the TensorExpr Graph
// and returns back a map from Symbolic Shape Value to its runtime Value *
@ -180,7 +179,7 @@ static StrideInput summarizeOutputStrides(const TensorType& tt) {
// specializations
static std::optional<std::vector<std::vector<StrideInput>>>
TryGeneralizeInputDimensionsToSymbolicShapes(
std::shared_ptr<Graph> tensorexpr_graph) {
const std::shared_ptr<Graph>& tensorexpr_graph) {
std::map<size_t, int64_t> shape_to_sym_shape;
std::vector<std::vector<StrideInput>> input_striding;
@ -214,7 +213,7 @@ TryGeneralizeInputDimensionsToSymbolicShapes(
static void moveConstantTensorsOutOfSubgraph(
Node* tensorexpr_graph_node,
std::shared_ptr<Graph> tensorexpr_graph) {
const std::shared_ptr<Graph>& tensorexpr_graph) {
auto parent = tensorexpr_graph_node->owningGraph();
auto env = [&](Value* v) {
@ -586,7 +585,7 @@ RegisterOperators reg_guard({
} else {
// use index for set if it exists, otherwise extend the vector
// of sym shapes by 1
int64_t sym_dim_index;
size_t sym_dim_index = 0;
if (sym_dim_flat_index.count(value)) {
sym_dim_index = sym_dim_flat_index[value];
} else {
@ -596,7 +595,8 @@ RegisterOperators reg_guard({
}
// TODO: potential optimization - if there is a Symbolic
// Sym with only one use we dont need to test anything
flattened_input_dims.push_back(sym_dim_index);
flattened_input_dims.push_back(
static_cast<int64_t>(sym_dim_index));
}
}
}
@ -680,7 +680,7 @@ RegisterOperators reg_guard({
flattened_stride_offset += num_dims;
}
for (const auto dim_index : c10::irange(num_dims)) {
const int64_t dim_value =
const auto dim_value =
flattened_input_dims[dim_index + flattened_dim_offset];
const int64_t tensor_dim = sizes[dim_index];
if (dim_value >= 0) {
@ -743,5 +743,4 @@ RegisterOperators TensorExprDynamicOp({
AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -4,10 +4,7 @@
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <unordered_map>
namespace torch {
namespace jit {
namespace torch::jit {
// Takes in a TensorExprGraph of static shapes and generalizes the input shapes
// to symbolic dimensions. Dimensions of value 1 will be preserved, otherwise
@ -51,5 +48,4 @@ enum class StrideInput {
TORCH_API std::string toString(StrideInput si);
TORCH_API StrideInput strideInputFromString(const std::string& si);
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -37,8 +37,7 @@ C10_DEFINE_bool(
false,
"enable TE fusion using dynamic shapes");
namespace torch {
namespace jit {
namespace torch::jit {
static bool texpr_reductions_enabled = false;
@ -560,8 +559,7 @@ class TensorExprFuser {
inlineSmallFusionGroups(graph_->block());
GRAPH_DUMP("After inlining small fusion groups: ", graph_);
if (fuse_to_dynamic_shapes_) {
VLOG(1) << "TensorExpr fusion with dynamic shapes is enabled"
<< std::endl;
VLOG(1) << "TensorExpr fusion with dynamic shapes is enabled" << '\n';
generalizeFusionGroups(graph_->block());
GRAPH_DUMP("After generalizing fusion groups: ", graph_);
} else {
@ -1288,7 +1286,7 @@ class TensorExprFuser {
VLOG(1) << "GenerateGuard for fusion group: " << *fusion_group;
if (!GenerateGuard(fusion_group, add_composed_op_)) {
VLOG(1) << " Unfusing the fusion group because GenerateGuard failed"
<< std::endl;
<< '\n';
SubgraphUtils::unmergeSubgraph(fusion_group);
}
}
@ -1451,5 +1449,4 @@ RegisterOperators TensorExprOps({
AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});
} // namespace jit
} // namespace torch
} // namespace torch::jit

View File

@ -4,8 +4,7 @@
#include <torch/csrc/jit/ir/ir.h>
#include <memory>
namespace torch {
namespace jit {
namespace torch::jit {
// Run TensorExpressions-based fuser.
// If add_composed_op is true, creates a single operation that
@ -71,5 +70,4 @@ TORCH_API bool isSupported(Node* node);
///
TORCH_API OperatorSet& getCustomOperatorSet();
} // namespace tensorexpr
} // namespace jit
} // namespace torch
} // namespace torch::jit