[JIT] Modify is_nondeterministic to utilize tags in schemaInfo and integrate with ir.cpp (#81836)
- Modified the is_nondeterministic method in the SchemaInfo class to utilize tags.
- Modified the isNondeterministic method in ir.cpp to utilize SchemaInfo when a Node is an aten op.
- Added an assert to ensure that if a node is an aten op kind, it has a schema.
- Tested by verifying that all ir.cpp tests pass, and by adding two custom determinism checks covering the special dropout edge case and a general bernoulli case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81836
Approved by: https://github.com/davidberard98
Committed by: PyTorch MergeBot
Parent: 14968d59f2
Commit: fc3555ce4d
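The commit message above describes the new mechanism: instead of matching against a hard-coded operator list, nondeterminism is now derived from the `nondeterministic_seeded` tag through `SchemaInfo`, with a special case for `aten::dropout` when `train` is known to be false. The sketch below is a rough illustration of how that check could be exercised from a standalone program linked against libtorch; it is not part of this PR, and the header location for `torch::jit::parseSchema` is an assumption (the schema strings are taken from the diff below).

```cpp
// Minimal sketch (not part of this PR) of the tag-driven nondeterminism check
// performed by SchemaInfo::is_nondeterministic(). Assumes a program linked
// against libtorch; the parseSchema header path is an assumed location.
#include <torch/csrc/utils/schema_info.h>                    // torch::utils::SchemaInfo
#include <torch/csrc/jit/frontend/function_schema_parser.h>  // torch::jit::parseSchema (assumed path)
#include <iostream>

int main() {
  // aten::bernoulli carries the nondeterministic_seeded tag, so the
  // dispatcher tag lookup reports it as nondeterministic.
  torch::utils::SchemaInfo bernoulli_info(torch::jit::parseSchema(
      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor"));
  std::cout << "bernoulli: " << bernoulli_info.is_nondeterministic() << '\n';  // expected: 1

  // Dropout special case: once the "train" argument is known to be false,
  // the op is treated as deterministic even though its schema is tagged.
  torch::utils::SchemaInfo dropout_info(torch::jit::parseSchema(
      "aten::dropout(Tensor input, float p, bool train) -> Tensor"));
  std::cout << "dropout (train unknown): " << dropout_info.is_nondeterministic()
            << '\n';  // expected: 1
  dropout_info.addArgumentValue("train", false);  // bool converts to an IValue
  std::cout << "dropout (train=false): " << dropout_info.is_nondeterministic()
            << '\n';  // expected: 0
  return 0;
}
```

This mirrors what `Node::isNondeterministic()` now does in ir.cpp: it extracts the constant value of a `train` input, feeds it to `SchemaInfo` via `addArgumentValue`, and defers the final decision to the tag lookup.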
@@ -1607,5 +1607,57 @@ TEST(AliasRegistrationTest, PureWithAnnotationsShouldError2) {
       [&graph] { AliasDb aliasDb(graph); },
       "Tried to register operator foo::rand12(Tensor(a) arg1) -> Tensor(b) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
 }
 
+TEST(IRNonDeterminismTest, Basic) {
+  auto graph = std::make_shared<Graph>();
+  auto graph_string = R"IR(
+  graph():
+    %x : Tensor = prim::MakeTestTensor()
+    %0 : int = prim::Constant[value=0]()
+    %1 : NoneType = prim::Constant()
+    %2 : Tensor = aten::bernoulli(%x, %1)
+    %3 : Tensor = aten::add(%x, %2, %0)
+    return (%3))IR";
+  parseIR(graph_string, graph.get());
+
+  for (Node* n : graph->nodes()) {
+    if (n->kind() == aten::bernoulli) {
+      ASSERT_TRUE(n->isNondeterministic());
+    } else {
+      ASSERT_FALSE(n->isNondeterministic());
+    }
+  }
+}
+
+TEST(IRNonDeterminismTest, DropoutSpecialCase) {
+  auto graph = std::make_shared<Graph>();
+  auto graph_string = R"IR(
+  graph():
+    %x : Tensor = prim::MakeTestTensor()
+    %0 : bool = prim::Constant[value=0]()
+    %1 : bool = prim::Constant[value=1]()
+    %2 : float = prim::Constant[value=1.0]()
+    %3 : int = prim::Constant[value=1]()
+    %4 : Tensor = aten::dropout(%x, %2, %0)
+    %5 : Tensor = aten::dropout(%x, %2, %1)
+    %6 : Tensor = aten::add(%4, %5, %3)
return (%6))IR";
|
||||
parseIR(graph_string, graph.get());
|
||||
|
||||
bool train = false;
|
||||
for (Node* n : graph->nodes()) {
|
||||
if (n->kind() == aten::dropout) {
|
||||
if (!train) {
|
||||
ASSERT_FALSE(n->isNondeterministic());
|
||||
train = true;
|
||||
} else {
|
||||
ASSERT_TRUE(n->isNondeterministic());
|
||||
}
|
||||
} else {
|
||||
ASSERT_FALSE(n->isNondeterministic());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@@ -1144,40 +1144,25 @@ Operation Node::getOperation() const {
 }
 
 bool Node::isNondeterministic() const {
-  static const OperatorSet nondeterministic_ops = {
-      "aten::dropout(Tensor input, float p, bool train) -> Tensor",
-      "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
-      "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
-      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
-      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
-      "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
-      "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)",
-      "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
-      "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
-      "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
-      "aten::poisson(Tensor self, Generator? generator) -> Tensor",
-      "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor",
-      "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
-      "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
-      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};
-
-  if (!isMemberOf(nondeterministic_ops)) {
+  const auto schema = maybeSchema();
+  if (!kind().is_aten()) {
     return false;
   }
-  // Dropout with train = False is deterministic
-  if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") &&
-      is_constant(attr::train) && !get<bool>(attr::train).value()) {
+  // All aten ops are expected to have a schema. However this is left as a
+  // warning instead of an assert to ensure that previous use cases do not
+  // break.
+  if (!schema) {
+    TORCH_WARN("aten Schema not found.");
     return false;
   }
-  return true;
+  torch::utils::SchemaInfo schema_info(*schema);
+  if (hasNamedInput("train")) {
+    auto value = constant_as<bool>(namedInput("train"));
+    if (value.has_value()) {
+      schema_info.addArgumentValue("train", *value);
+    }
+  }
+  return schema_info.is_nondeterministic();
 }
 
 bool Node::hasSideEffects() const {
@@ -1,3 +1,4 @@
+#include <ATen/core/dispatch/Dispatcher.h>
 #include <torch/csrc/utils/schema_info.h>
 
 namespace torch {
@@ -107,20 +108,17 @@ bool SchemaInfo::is_mutable(c10::string_view name) {
 }
 
 bool SchemaInfo::is_nondeterministic() const {
-  static const std::vector<c10::FunctionSchema> nondeterministic_ops =
-      getNonDeterministicOps();
-  static const c10::FunctionSchema detach_schema = torch::jit::parseSchema(
+  static const c10::FunctionSchema dropout_schema = torch::jit::parseSchema(
       "aten::dropout(Tensor input, float p, bool train) -> Tensor");
-  if (detach_schema == this->schema_ && value_map_.count("train") &&
+  if (dropout_schema == schema_ && value_map_.count("train") &&
       !value_map_.at("train").toBool()) {
     return false;
   }
-  return std::any_of(
-      nondeterministic_ops.begin(),
-      nondeterministic_ops.end(),
-      [this](const c10::FunctionSchema& nondeterministic_op) {
-        return nondeterministic_op == this->schema_;
-      });
+
+  const auto& op = c10::Dispatcher::singleton().findOp(
+      c10::OperatorName(schema_.name(), schema_.overload_name()));
+  return op && op->hasTag(at::Tag::nondeterministic_seeded);
 }
 
 bool SchemaInfo::may_alias(
@@ -203,42 +201,6 @@ bool SchemaInfo::mayContainAliasImpl(
       wildcard_set_.count(rhs);
 }
 
-std::vector<c10::FunctionSchema> SchemaInfo::getNonDeterministicOps() {
-  // This list of nondeterministic ops is copied from JIT ir.cpp.
-  static const std::vector<std::string> nondeterministic_op_strings = {
-      "aten::dropout(Tensor input, float p, bool train) -> Tensor",
-      "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
-      "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
-      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
-      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
-      "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
-      "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)",
-      "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
-      "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
-      "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
-      "aten::poisson(Tensor self, Generator? generator) -> Tensor",
-      "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor",
-      "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
-      "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
-      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
-      "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
-      "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};
-
-  std::vector<c10::FunctionSchema> nondeterministic_ops;
-  nondeterministic_ops.reserve(nondeterministic_op_strings.size());
-  for (const std::string& signature : nondeterministic_op_strings) {
-    nondeterministic_ops.push_back(torch::jit::parseSchema(signature));
-  }
-
-  return nondeterministic_ops;
-}
-
 void SchemaInfo::ensureConservativity(
     const std::unordered_set<at::Symbol>& duplicates,
     const std::vector<c10::Argument>& arguments_list,
@@ -81,8 +81,6 @@ struct TORCH_API SchemaInfo {
       const c10::SchemaArgument& lhs,
       const c10::SchemaArgument& rhs);
 
-  static std::vector<c10::FunctionSchema> getNonDeterministicOps();
-
   static std::vector<c10::FunctionSchema> getTrainingOps();
 
   // Set of all wildcard arguments