[JIT] Modify is_nondeterministic to utilize tags in schemaInfo and integrate with ir.cpp (#81836)

- Modified the is_nondeterministic method in the SchemaInfo class to check the nondeterministic_seeded operator tag through the dispatcher instead of a hard-coded op list.
- Modified the isNondeterministic method in ir.cpp to use SchemaInfo when a Node is an aten op.
- Added a warning (rather than a hard assert) when a node with an aten op kind has no schema, so that existing use cases do not break.
- Tested by verifying that all ir.cpp tests pass and by adding two custom determinism checks: one for the dropout train=False special case and one for a general bernoulli case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81836
Approved by: https://github.com/davidberard98
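For context, a minimal sketch of the tag-based check this change relies on. The helper name is illustrative and not part of this PR; the dispatcher calls mirror the schema_info.cpp hunk below.

#include <ATen/core/dispatch/Dispatcher.h>

// Illustrative helper (not code from this PR): an operator counts as
// nondeterministic when its registered schema carries the
// nondeterministic_seeded tag.
bool hasNondeterministicSeededTag(const c10::FunctionSchema& schema) {
  const auto op = c10::Dispatcher::singleton().findOp(
      c10::OperatorName(schema.name(), schema.overload_name()));
  return op && op->hasTag(at::Tag::nondeterministic_seeded);
}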
goldenxuett
2022-07-22 15:26:09 -07:00
committed by PyTorch MergeBot
parent 14968d59f2
commit fc3555ce4d
4 changed files with 75 additions and 78 deletions

View File

@@ -1607,5 +1607,57 @@ TEST(AliasRegistrationTest, PureWithAnnotationsShouldError2) {
       [&graph] { AliasDb aliasDb(graph); },
       "Tried to register operator foo::rand12(Tensor(a) arg1) -> Tensor(b) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
 }
+
+TEST(IRNonDeterminismTest, Basic) {
+  auto graph = std::make_shared<Graph>();
+  auto graph_string = R"IR(
+  graph():
+    %x : Tensor = prim::MakeTestTensor()
+    %0 : int = prim::Constant[value=0]()
+    %1 : NoneType = prim::Constant()
+    %2 : Tensor = aten::bernoulli(%x, %1)
+    %3 : Tensor = aten::add(%x, %2, %0)
+    return (%3))IR";
+  parseIR(graph_string, graph.get());
+  for (Node* n : graph->nodes()) {
+    if (n->kind() == aten::bernoulli) {
+      ASSERT_TRUE(n->isNondeterministic());
+    } else {
+      ASSERT_FALSE(n->isNondeterministic());
+    }
+  }
+}
+
+TEST(IRNonDeterminismTest, DropoutSpecialCase) {
+  auto graph = std::make_shared<Graph>();
+  auto graph_string = R"IR(
+  graph():
+    %x : Tensor = prim::MakeTestTensor()
+    %0 : bool = prim::Constant[value=0]()
+    %1 : bool = prim::Constant[value=1]()
+    %3 : float = prim::Constant[value=1.0]()
+    %4 : Tensor = aten::dropout(%x, %3, %0)
+    %5 : Tensor = aten::dropout(%x, %3, %1)
+    %6 : Tensor = aten::add(%4, %5, %3)
+    return (%6))IR";
+  parseIR(graph_string, graph.get());
+  bool train = false;
+  for (Node* n : graph->nodes()) {
+    if (n->kind() == aten::dropout) {
+      if (!train) {
+        ASSERT_FALSE(n->isNondeterministic());
+        train = true;
+      } else {
+        ASSERT_TRUE(n->isNondeterministic());
+      }
+    } else {
+      ASSERT_FALSE(n->isNondeterministic());
+    }
+  }
+}
 } // namespace jit
 } // namespace torch

View File

@@ -1144,40 +1144,25 @@ Operation Node::getOperation() const {
 }
 bool Node::isNondeterministic() const {
-  static const OperatorSet nondeterministic_ops = {
-      "aten::dropout(Tensor input, float p, bool train) -> Tensor",
-      "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
-      "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
-      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
-      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
-      "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
-      "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)",
-      "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
-      "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
-      "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
-      "aten::poisson(Tensor self, Generator? generator) -> Tensor",
-      "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor",
-      "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
-      "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
-      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};
-  if (!isMemberOf(nondeterministic_ops)) {
+  const auto schema = maybeSchema();
+  if (!kind().is_aten()) {
     return false;
   }
-  // Dropout with train = False is deterministic
-  if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") &&
-      is_constant(attr::train) && !get<bool>(attr::train).value()) {
+  // All aten ops are expected to have a schema. However this is left as a
+  // warning instead of an assert to ensure that previous use cases do not
+  // break.
+  if (!schema) {
+    TORCH_WARN("aten Schema not found.");
     return false;
   }
-  return true;
+  torch::utils::SchemaInfo schema_info(*schema);
+  if (hasNamedInput("train")) {
+    auto value = constant_as<bool>(namedInput("train"));
+    if (value.has_value()) {
+      schema_info.addArgumentValue("train", *value);
+    }
+  }
+  return schema_info.is_nondeterministic();
 }
 bool Node::hasSideEffects() const {
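A hedged usage sketch of the SchemaInfo path above (illustrative, not code from this PR; the header paths and the assumption that aten::dropout carries the nondeterministic_seeded tag are mine — the schema string and addArgumentValue call mirror the diff):

#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/utils/schema_info.h>

void dropoutDeterminismSketch() {
  // Without a known value for `train`, dropout is reported as
  // nondeterministic via its nondeterministic_seeded tag.
  torch::utils::SchemaInfo info(torch::jit::parseSchema(
      "aten::dropout(Tensor input, float p, bool train) -> Tensor"));
  TORCH_INTERNAL_ASSERT(info.is_nondeterministic());
  // Feeding a constant train=false triggers the special case, and the op is
  // reported as deterministic.
  info.addArgumentValue("train", false);
  TORCH_INTERNAL_ASSERT(!info.is_nondeterministic());
}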

View File

@@ -1,3 +1,4 @@
+#include <ATen/core/dispatch/Dispatcher.h>
 #include <torch/csrc/utils/schema_info.h>
 namespace torch {
@@ -107,20 +108,17 @@ bool SchemaInfo::is_mutable(c10::string_view name) {
 }
 bool SchemaInfo::is_nondeterministic() const {
-  static const std::vector<c10::FunctionSchema> nondeterministic_ops =
-      getNonDeterministicOps();
-  static const c10::FunctionSchema detach_schema = torch::jit::parseSchema(
+  static const c10::FunctionSchema dropout_schema = torch::jit::parseSchema(
       "aten::dropout(Tensor input, float p, bool train) -> Tensor");
-  if (detach_schema == this->schema_ && value_map_.count("train") &&
+  if (dropout_schema == schema_ && value_map_.count("train") &&
       !value_map_.at("train").toBool()) {
     return false;
   }
-  return std::any_of(
-      nondeterministic_ops.begin(),
-      nondeterministic_ops.end(),
-      [this](const c10::FunctionSchema& nondeterministic_op) {
-        return nondeterministic_op == this->schema_;
-      });
+  const auto& op = c10::Dispatcher::singleton().findOp(
+      c10::OperatorName(schema_.name(), schema_.overload_name()));
+  return op && op->hasTag(at::Tag::nondeterministic_seeded);
 }
 bool SchemaInfo::may_alias(
@@ -203,42 +201,6 @@ bool SchemaInfo::mayContainAliasImpl(
       wildcard_set_.count(rhs);
 }
-std::vector<c10::FunctionSchema> SchemaInfo::getNonDeterministicOps() {
-  // This list of nondeterministic ops is copied from JIT ir.cpp.
-  static const std::vector<std::string> nondeterministic_op_strings = {
-      "aten::dropout(Tensor input, float p, bool train) -> Tensor",
-      "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
-      "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
-      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
-      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
-      "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
-      "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)",
-      "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
-      "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
-      "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
-      "aten::poisson(Tensor self, Generator? generator) -> Tensor",
-      "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor",
-      "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
-      "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
-      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};
-  std::vector<c10::FunctionSchema> nondeterministic_ops;
-  nondeterministic_ops.reserve(nondeterministic_op_strings.size());
-  for (const std::string& signature : nondeterministic_op_strings) {
-    nondeterministic_ops.push_back(torch::jit::parseSchema(signature));
-  }
-  return nondeterministic_ops;
-}
 void SchemaInfo::ensureConservativity(
     const std::unordered_set<at::Symbol>& duplicates,
     const std::vector<c10::Argument>& arguments_list,

View File

@@ -81,8 +81,6 @@ struct TORCH_API SchemaInfo {
       const c10::SchemaArgument& lhs,
       const c10::SchemaArgument& rhs);
-  static std::vector<c10::FunctionSchema> getNonDeterministicOps();
   static std::vector<c10::FunctionSchema> getTrainingOps();
   // Set of all wildcard arguments