Revert D14901379: [jit] Add options to Operator to enable registration of alias analysis passes

Differential Revision:
D14901379

Original commit changeset: d92a497e280f

fbshipit-source-id: 51d31491ab90907a6c95af5d8a59dff5e5ed36a4
This commit is contained in:
Michael Suo
2019-04-17 16:48:28 -07:00
committed by Facebook Github Bot
parent 0414f23855
commit 242743eedb
8 changed files with 26 additions and 142 deletions

View File

@ -56,7 +56,6 @@ namespace jit {
_(TopologicalMove) \
_(SubgraphUtils) \
_(AliasAnalysis) \
_(AliasRegistration) \
_(WriteTracking) \
_(Wildcards) \
_(MemoryDAG) \

View File

@ -605,37 +605,5 @@ void testMemoryDAG() {
ASSERT_FALSE(t.mayAlias(foo, baz));
}
}
void testAliasRegistration() {
{
auto opts = OperatorOptions().aliasAnalysis(AliasAnalysisKind::DEFAULT);
RegisterOperators reg({createOperator(
"foo::rand",
[](at::Tensor) -> at::Tensor {
return at::rand({2, 2});
},
opts)});
const auto rand_op = Symbol::fromQualString("foo::rand");
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto b = graph->insert(rand_op, {a});
AliasDb aliasDb(graph);
// Conservatively we assume there is a reference
ASSERT_TRUE(aliasDb.mayAlias(a, b));
}
{
auto opts = OperatorOptions().aliasAnalysis(AliasAnalysisKind::PURE);
RegisterOperators reg({createOperator(
"foo::pure", [](at::Tensor t) -> at::Tensor { return t * 2; }, opts)});
const auto rand_op = Symbol::fromQualString("foo::pure");
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto b = graph->insert(rand_op, {a});
AliasDb aliasDb(graph);
// PURE means there is no reference
ASSERT_FALSE(aliasDb.mayAlias(a, b));
}
}
} // namespace jit
} // namespace torch

View File

@ -170,8 +170,7 @@ FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) {
template <typename Implementation>
Operator createOperator(
const std::string& schemaOrName,
Implementation&& implementation,
OperatorOptions options = OperatorOptions()) {
Implementation&& implementation) {
using Traits = c10::guts::infer_function_traits_t<Implementation>;
using ArgumentTypes =
c10::guts::typelist::map_t<decay_t, typename Traits::parameter_types>;
@ -202,20 +201,16 @@ Operator createOperator(
name.ns().toUnqualString());
}
return Operator(
schema,
[implementation, schema](Stack& stack) {
ArgumentTuple tuple;
torch::jit::detail::callOperatorWithTuple(
schema,
std::move(
implementation), // NOLINT(bugprone-move-forwarding-reference)
stack,
tuple,
typename MakeIndices<kNumberOfArguments>::indices{});
return 0;
},
std::move(options));
return Operator(schema, [implementation, schema](Stack& stack) {
ArgumentTuple tuple;
torch::jit::detail::callOperatorWithTuple(
schema,
std::move(implementation), // NOLINT(bugprone-move-forwarding-reference)
stack,
tuple,
typename MakeIndices<kNumberOfArguments>::indices{});
return 0;
});
}
/// Registration class for new operators. Effectively calls
@ -245,10 +240,9 @@ struct TORCH_API RegisterOperators {
template <typename Implementation>
RegisterOperators& op(
const std::string& name,
Implementation&& implementation,
OperatorOptions options = OperatorOptions()) {
registerOperator(createOperator(
name, std::forward<Implementation>(implementation), options));
Implementation&& implementation) {
registerOperator(
createOperator(name, std::forward<Implementation>(implementation)));
return *this;
}
};

View File

@ -379,8 +379,7 @@ void registerOperator(Operator&& op) {
op.schema().name(),
". File a bug to add a case for this operator.\n");
}
if (!aliasAnalysisHasSpecialCaseFor(s) &&
op.options().aliasAnalysis() == AliasAnalysisKind::DEFAULT) {
if (!aliasAnalysisHasSpecialCaseFor(s)) {
AT_ERROR(
"Missing special case in alias analysis for non-schematized"
" operator ",

View File

@ -3,10 +3,9 @@
// it now to implement correct semantic checking for script
#pragma once
#include <ATen/core/stack.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator_options.h>
#include <ATen/core/stack.h>
#include <ATen/ATen.h>
#include <ATen/core/function_schema.h>
@ -59,31 +58,19 @@ using OperationCreator = std::function<Operation(const Node*)>;
*/
struct TORCH_API Operator {
Operator(
FunctionSchema schema,
OperationCreator op_creator,
OperatorOptions options = OperatorOptions())
Operator(FunctionSchema schema, OperationCreator op_creator)
: schema_(std::make_shared<FunctionSchema>(std::move(schema))),
op_creator_(std::move(op_creator)),
options_(std::move(options)) {}
op_creator_(std::move(op_creator)) {}
Operator(
const std::string& schema,
OperationCreator op_creator,
OperatorOptions options = OperatorOptions())
: schema_string_(schema),
op_creator_(std::move(op_creator)),
options_(std::move(options)) {}
Operator(const std::string& schema, OperationCreator op_creator)
: schema_string_(schema), op_creator_(std::move(op_creator)) {}
// Helper constructor to register `op` to run
// run for _every_ IR Node where n.kind() == name, regardless of arguments.
// This is accomplished by marking the schema varargs and having no required
// arguments. This is used for things like prim::While or prim::If that can
// take a number of different valid input types and lengths.
Operator(
Symbol name,
OperationCreator op_creator,
OperatorOptions options = OperatorOptions())
Operator(Symbol name, OperationCreator op_creator)
: Operator(
FunctionSchema(
name,
@ -92,24 +79,15 @@ struct TORCH_API Operator {
{},
/*is_vararg*/ true,
/*is_varret*/ true),
std::move(op_creator),
std::move(options)) {}
std::move(op_creator)) {}
Operator(
FunctionSchema schema,
Operation op,
OperatorOptions options = OperatorOptions())
Operator(FunctionSchema schema, Operation op)
: schema_(std::make_shared<FunctionSchema>(std::move(schema))),
op_(std::make_shared<Operation>(std::move(op))),
options_(std::move(options)) {}
op_(std::make_shared<Operation>(std::move(op))) {}
Operator(
const std::string& schema,
Operation op,
OperatorOptions options = OperatorOptions())
Operator(const std::string& schema, Operation op)
: schema_string_(schema),
op_(std::make_shared<Operation>(std::move(op))),
options_(std::move(options)) {}
op_(std::make_shared<Operation>(std::move(op))) {}
bool matches(const Node* node) const;
@ -132,10 +110,6 @@ struct TORCH_API Operator {
return *schema_;
}
const OperatorOptions& options() const {
return options_;
}
private:
mutable c10::optional<std::string> schema_string_;
// cannot use c10::optional because windows has issues that require an
@ -147,7 +121,6 @@ struct TORCH_API Operator {
// NB: std::function has a default state (where it == nullptr).
std::shared_ptr<Operation> op_;
OperationCreator op_creator_;
OperatorOptions options_;
};
TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);

View File

@ -1,29 +0,0 @@
#pragma once
#include <torch/csrc/jit/passes/alias_analysis.h>
namespace torch {
namespace jit {
enum class AliasAnalysisKind {
DEFAULT, // The most conservative alias analysis type, assumes side-effects
PURE
};
struct OperatorOptions {
OperatorOptions(){};
OperatorOptions aliasAnalysis(AliasAnalysisKind aak) const noexcept {
OperatorOptions r = *this;
r.aliasAnalysisKind_ = aak;
return r;
}
const AliasAnalysisKind& aliasAnalysis() const {
return aliasAnalysisKind_;
}
AliasAnalysisKind aliasAnalysisKind_ = AliasAnalysisKind::DEFAULT;
};
} // namespace jit
} // namespace torch

View File

@ -1,6 +1,5 @@
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/utils/memory.h>
@ -347,21 +346,6 @@ void AliasDb::analyze(Node* node) {
}
}
// Returns true if analysis was run using
// the registered analyzer.
bool AliasDb::tryRegisteredAnalysis(Node* node) {
const Operator& op = getOperatorFor(node);
auto analysis = op.options().aliasAnalysis();
switch (analysis) {
case AliasAnalysisKind::PURE:
analyzeCreator(node);
return true;
case AliasAnalysisKind::DEFAULT:
return false;
}
return false;
}
// The basic strategy is:
// 1. Retrieve alias information for every input.
// 2. Use the node's schema's alias annotations to propgagate alias/write
@ -425,9 +409,6 @@ void AliasDb::analyzeImpl(Node* node) {
// These ops do nothing
return;
default:
if (tryRegisteredAnalysis(node)) {
return;
}
AT_ASSERT(!aliasAnalysisHasSpecialCaseFor(node->kind()));
}

View File

@ -183,7 +183,6 @@ class AliasDb {
void analyzeWait(Node* node);
void analyzeSetAttr(Node* node);
void analyzeCustomOp(Node* node);
bool tryRegisteredAnalysis(Node* node);
/**
* Alias manipulation methods