mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert D14901379: [jit] Add options to Operator to enable registration of alias analysis passes
Differential Revision: D14901379 Original commit changeset: d92a497e280f fbshipit-source-id: 51d31491ab90907a6c95af5d8a59dff5e5ed36a4
This commit is contained in:
committed by
Facebook Github Bot
parent
0414f23855
commit
242743eedb
@ -56,7 +56,6 @@ namespace jit {
|
||||
_(TopologicalMove) \
|
||||
_(SubgraphUtils) \
|
||||
_(AliasAnalysis) \
|
||||
_(AliasRegistration) \
|
||||
_(WriteTracking) \
|
||||
_(Wildcards) \
|
||||
_(MemoryDAG) \
|
||||
|
@ -605,37 +605,5 @@ void testMemoryDAG() {
|
||||
ASSERT_FALSE(t.mayAlias(foo, baz));
|
||||
}
|
||||
}
|
||||
|
||||
void testAliasRegistration() {
|
||||
{
|
||||
auto opts = OperatorOptions().aliasAnalysis(AliasAnalysisKind::DEFAULT);
|
||||
RegisterOperators reg({createOperator(
|
||||
"foo::rand",
|
||||
[](at::Tensor) -> at::Tensor {
|
||||
return at::rand({2, 2});
|
||||
},
|
||||
opts)});
|
||||
const auto rand_op = Symbol::fromQualString("foo::rand");
|
||||
auto graph = std::make_shared<Graph>();
|
||||
auto a = graph->addInput();
|
||||
auto b = graph->insert(rand_op, {a});
|
||||
AliasDb aliasDb(graph);
|
||||
// Conservatively we assume there is a reference
|
||||
ASSERT_TRUE(aliasDb.mayAlias(a, b));
|
||||
}
|
||||
{
|
||||
auto opts = OperatorOptions().aliasAnalysis(AliasAnalysisKind::PURE);
|
||||
RegisterOperators reg({createOperator(
|
||||
"foo::pure", [](at::Tensor t) -> at::Tensor { return t * 2; }, opts)});
|
||||
const auto rand_op = Symbol::fromQualString("foo::pure");
|
||||
auto graph = std::make_shared<Graph>();
|
||||
auto a = graph->addInput();
|
||||
auto b = graph->insert(rand_op, {a});
|
||||
AliasDb aliasDb(graph);
|
||||
// PURE means there is no reference
|
||||
ASSERT_FALSE(aliasDb.mayAlias(a, b));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -170,8 +170,7 @@ FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) {
|
||||
template <typename Implementation>
|
||||
Operator createOperator(
|
||||
const std::string& schemaOrName,
|
||||
Implementation&& implementation,
|
||||
OperatorOptions options = OperatorOptions()) {
|
||||
Implementation&& implementation) {
|
||||
using Traits = c10::guts::infer_function_traits_t<Implementation>;
|
||||
using ArgumentTypes =
|
||||
c10::guts::typelist::map_t<decay_t, typename Traits::parameter_types>;
|
||||
@ -202,20 +201,16 @@ Operator createOperator(
|
||||
name.ns().toUnqualString());
|
||||
}
|
||||
|
||||
return Operator(
|
||||
schema,
|
||||
[implementation, schema](Stack& stack) {
|
||||
ArgumentTuple tuple;
|
||||
torch::jit::detail::callOperatorWithTuple(
|
||||
schema,
|
||||
std::move(
|
||||
implementation), // NOLINT(bugprone-move-forwarding-reference)
|
||||
stack,
|
||||
tuple,
|
||||
typename MakeIndices<kNumberOfArguments>::indices{});
|
||||
return 0;
|
||||
},
|
||||
std::move(options));
|
||||
return Operator(schema, [implementation, schema](Stack& stack) {
|
||||
ArgumentTuple tuple;
|
||||
torch::jit::detail::callOperatorWithTuple(
|
||||
schema,
|
||||
std::move(implementation), // NOLINT(bugprone-move-forwarding-reference)
|
||||
stack,
|
||||
tuple,
|
||||
typename MakeIndices<kNumberOfArguments>::indices{});
|
||||
return 0;
|
||||
});
|
||||
}
|
||||
|
||||
/// Registration class for new operators. Effectively calls
|
||||
@ -245,10 +240,9 @@ struct TORCH_API RegisterOperators {
|
||||
template <typename Implementation>
|
||||
RegisterOperators& op(
|
||||
const std::string& name,
|
||||
Implementation&& implementation,
|
||||
OperatorOptions options = OperatorOptions()) {
|
||||
registerOperator(createOperator(
|
||||
name, std::forward<Implementation>(implementation), options));
|
||||
Implementation&& implementation) {
|
||||
registerOperator(
|
||||
createOperator(name, std::forward<Implementation>(implementation)));
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
@ -379,8 +379,7 @@ void registerOperator(Operator&& op) {
|
||||
op.schema().name(),
|
||||
". File a bug to add a case for this operator.\n");
|
||||
}
|
||||
if (!aliasAnalysisHasSpecialCaseFor(s) &&
|
||||
op.options().aliasAnalysis() == AliasAnalysisKind::DEFAULT) {
|
||||
if (!aliasAnalysisHasSpecialCaseFor(s)) {
|
||||
AT_ERROR(
|
||||
"Missing special case in alias analysis for non-schematized"
|
||||
" operator ",
|
||||
|
@ -3,10 +3,9 @@
|
||||
// it now to implement correct semantic checking for script
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/stack.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/jit/ir.h>
|
||||
#include <torch/csrc/jit/operator_options.h>
|
||||
#include <ATen/core/stack.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/core/function_schema.h>
|
||||
@ -59,31 +58,19 @@ using OperationCreator = std::function<Operation(const Node*)>;
|
||||
*/
|
||||
|
||||
struct TORCH_API Operator {
|
||||
Operator(
|
||||
FunctionSchema schema,
|
||||
OperationCreator op_creator,
|
||||
OperatorOptions options = OperatorOptions())
|
||||
Operator(FunctionSchema schema, OperationCreator op_creator)
|
||||
: schema_(std::make_shared<FunctionSchema>(std::move(schema))),
|
||||
op_creator_(std::move(op_creator)),
|
||||
options_(std::move(options)) {}
|
||||
op_creator_(std::move(op_creator)) {}
|
||||
|
||||
Operator(
|
||||
const std::string& schema,
|
||||
OperationCreator op_creator,
|
||||
OperatorOptions options = OperatorOptions())
|
||||
: schema_string_(schema),
|
||||
op_creator_(std::move(op_creator)),
|
||||
options_(std::move(options)) {}
|
||||
Operator(const std::string& schema, OperationCreator op_creator)
|
||||
: schema_string_(schema), op_creator_(std::move(op_creator)) {}
|
||||
|
||||
// Helper constructor to register `op` to run
|
||||
// run for _every_ IR Node where n.kind() == name, regardless of arguments.
|
||||
// This is accomplished by marking the schema varargs and having no required
|
||||
// arguments. This is used for things like prim::While or prim::If that can
|
||||
// take a number of different valid input types and lengths.
|
||||
Operator(
|
||||
Symbol name,
|
||||
OperationCreator op_creator,
|
||||
OperatorOptions options = OperatorOptions())
|
||||
Operator(Symbol name, OperationCreator op_creator)
|
||||
: Operator(
|
||||
FunctionSchema(
|
||||
name,
|
||||
@ -92,24 +79,15 @@ struct TORCH_API Operator {
|
||||
{},
|
||||
/*is_vararg*/ true,
|
||||
/*is_varret*/ true),
|
||||
std::move(op_creator),
|
||||
std::move(options)) {}
|
||||
std::move(op_creator)) {}
|
||||
|
||||
Operator(
|
||||
FunctionSchema schema,
|
||||
Operation op,
|
||||
OperatorOptions options = OperatorOptions())
|
||||
Operator(FunctionSchema schema, Operation op)
|
||||
: schema_(std::make_shared<FunctionSchema>(std::move(schema))),
|
||||
op_(std::make_shared<Operation>(std::move(op))),
|
||||
options_(std::move(options)) {}
|
||||
op_(std::make_shared<Operation>(std::move(op))) {}
|
||||
|
||||
Operator(
|
||||
const std::string& schema,
|
||||
Operation op,
|
||||
OperatorOptions options = OperatorOptions())
|
||||
Operator(const std::string& schema, Operation op)
|
||||
: schema_string_(schema),
|
||||
op_(std::make_shared<Operation>(std::move(op))),
|
||||
options_(std::move(options)) {}
|
||||
op_(std::make_shared<Operation>(std::move(op))) {}
|
||||
|
||||
bool matches(const Node* node) const;
|
||||
|
||||
@ -132,10 +110,6 @@ struct TORCH_API Operator {
|
||||
return *schema_;
|
||||
}
|
||||
|
||||
const OperatorOptions& options() const {
|
||||
return options_;
|
||||
}
|
||||
|
||||
private:
|
||||
mutable c10::optional<std::string> schema_string_;
|
||||
// cannot use c10::optional because windows has issues that require an
|
||||
@ -147,7 +121,6 @@ struct TORCH_API Operator {
|
||||
// NB: std::function has a default state (where it == nullptr).
|
||||
std::shared_ptr<Operation> op_;
|
||||
OperationCreator op_creator_;
|
||||
OperatorOptions options_;
|
||||
};
|
||||
|
||||
TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);
|
||||
|
@ -1,29 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/passes/alias_analysis.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
enum class AliasAnalysisKind {
|
||||
DEFAULT, // The most conservative alias analysis type, assumes side-effects
|
||||
PURE
|
||||
};
|
||||
|
||||
struct OperatorOptions {
|
||||
OperatorOptions(){};
|
||||
|
||||
OperatorOptions aliasAnalysis(AliasAnalysisKind aak) const noexcept {
|
||||
OperatorOptions r = *this;
|
||||
r.aliasAnalysisKind_ = aak;
|
||||
return r;
|
||||
}
|
||||
|
||||
const AliasAnalysisKind& aliasAnalysis() const {
|
||||
return aliasAnalysisKind_;
|
||||
}
|
||||
AliasAnalysisKind aliasAnalysisKind_ = AliasAnalysisKind::DEFAULT;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
@ -1,6 +1,5 @@
|
||||
#include <torch/csrc/jit/passes/alias_analysis.h>
|
||||
|
||||
#include <torch/csrc/jit/operator.h>
|
||||
#include <torch/csrc/jit/script/error_report.h>
|
||||
#include <torch/csrc/utils/memory.h>
|
||||
|
||||
@ -347,21 +346,6 @@ void AliasDb::analyze(Node* node) {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true if analysis was run using
|
||||
// the registered analyzer.
|
||||
bool AliasDb::tryRegisteredAnalysis(Node* node) {
|
||||
const Operator& op = getOperatorFor(node);
|
||||
auto analysis = op.options().aliasAnalysis();
|
||||
switch (analysis) {
|
||||
case AliasAnalysisKind::PURE:
|
||||
analyzeCreator(node);
|
||||
return true;
|
||||
case AliasAnalysisKind::DEFAULT:
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// The basic strategy is:
|
||||
// 1. Retrieve alias information for every input.
|
||||
// 2. Use the node's schema's alias annotations to propgagate alias/write
|
||||
@ -425,9 +409,6 @@ void AliasDb::analyzeImpl(Node* node) {
|
||||
// These ops do nothing
|
||||
return;
|
||||
default:
|
||||
if (tryRegisteredAnalysis(node)) {
|
||||
return;
|
||||
}
|
||||
AT_ASSERT(!aliasAnalysisHasSpecialCaseFor(node->kind()));
|
||||
}
|
||||
|
||||
|
@ -183,7 +183,6 @@ class AliasDb {
|
||||
void analyzeWait(Node* node);
|
||||
void analyzeSetAttr(Node* node);
|
||||
void analyzeCustomOp(Node* node);
|
||||
bool tryRegisteredAnalysis(Node* node);
|
||||
|
||||
/**
|
||||
* Alias manipulation methods
|
||||
|
Reference in New Issue
Block a user