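// Propagates requires_grad information through a JIT graph by annotating
// DimensionedTensorType Values in place.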
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>

#include <algorithm>
#include <vector>

namespace torch {
namespace jit {

namespace {

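// Reads the requires_grad bit recorded in a Value's type.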
bool getRequiresGrad(Value* value) {
  return value->requires_grad();
}

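// Writes the flag back into the Value's type. Only DimensionedTensorType
// carries a requires_grad bit, so Values of any other type are left unchanged.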
void setRequiresGrad(Value* value, bool req_value) {
  if (auto type = value->type()->cast<DimensionedTensorType>()) {
    value->setType(type->withRequiresGrad(req_value));
  }
}

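// Convenience overloads: apply the per-Value setter element-wise to a list of
// values or to all outputs of a node.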
void setRequiresGrad(
    at::ArrayRef<Value*> outputs,
    const std::vector<bool>& values) {
  AT_ASSERT(outputs.size() == values.size());
  for (size_t i = 0; i < values.size(); ++i) {
    setRequiresGrad(outputs[i], values[i]);
  }
}

void setRequiresGrad(Node* node, const std::vector<bool>& values) {
  setRequiresGrad(node->outputs(), values);
}

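// Element-wise logical OR of two equally sized vectors of flags.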
std::vector<bool> bitwiseOr(std::vector<bool> a, const std::vector<bool>& b) {
  AT_ASSERT(a.size() == b.size());
  for (size_t i = 0; i < a.size(); ++i) {
    a[i] = a[i] || b[i];
  }
  return a;
}

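// Handles every node that is not control flow. Comparison ops and detach never
// require grad, type_as copies the flag from its first input, and aten::tensor
// honours its requires_grad argument when it is a constant; for any other node
// an output requires grad iff some input does and the output is a
// floating-point tensor.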
void PropagateRequiresGradSimpleNode(Node* node) {
  static const OperatorSet comparison_ops = {
      "aten::lt(Tensor self, Tensor other) -> Tensor",
      "aten::le(Tensor self, Tensor other) -> Tensor",
      "aten::gt(Tensor self, Tensor other) -> Tensor",
      "aten::ge(Tensor self, Tensor other) -> Tensor",
      "aten::eq(Tensor self, Tensor other) -> Tensor",
      "aten::ne(Tensor self, Tensor other) -> Tensor",
      "aten::lt(Tensor self, Scalar other) -> Tensor",
      "aten::le(Tensor self, Scalar other) -> Tensor",
      "aten::gt(Tensor self, Scalar other) -> Tensor",
      "aten::ge(Tensor self, Scalar other) -> Tensor",
      "aten::eq(Tensor self, Scalar other) -> Tensor",
      "aten::ne(Tensor self, Scalar other) -> Tensor",
  };

  if (comparison_ops.find(node)) {
    return setRequiresGrad(node->output(), false);
  } else if (node->matches(
                 "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
    return setRequiresGrad(node->output(), node->input(0)->requires_grad());
  } else if (node->matches("aten::detach(Tensor self) -> Tensor")) {
    return setRequiresGrad(node->output(), false);
  } else if (node->kind() == aten::tensor) {
    if (auto grad_index =
            node->schema().argumentIndexWithName("requires_grad")) {
      if (auto const_arg = constant_as<bool>(node->inputs().at(*grad_index))) {
        return setRequiresGrad(node->output(), *const_arg);
      }
    }
    if (auto type = node->output()->type()->cast<DimensionedTensorType>()) {
      setRequiresGrad(node->output(), at::isFloatingType(type->scalarType()));
    }
    return;
  }

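  // Default rule: an output requires grad when any input does, and only
  // floating-point tensors can carry the flag.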
  auto inputs = node->inputs();
  auto outputs = node->outputs();
  bool should_require =
      std::any_of(inputs.begin(), inputs.end(), getRequiresGrad);
  for (Value* output : outputs) {
    if (auto type = output->type()->cast<DimensionedTensorType>()) {
      setRequiresGrad(
          output, should_require && at::isFloatingType(type->scalarType()));
    }
  }
}

void PropagateRequiresGrad(Block* block);

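// Control-flow aware dispatch: If nodes take the union of the flags from the
// two branches, Loop nodes are iterated to a fixed point, and everything else
// falls through to the simple per-node rule above.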
void PropagateRequiresGrad(Node* node) {
  if (node->kind() == prim::If) {
    auto blocks = node->blocks();
    auto true_block = blocks.at(0);
    auto false_block = blocks.at(1);

    PropagateRequiresGrad(true_block);
    PropagateRequiresGrad(false_block);

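    // An If output requires grad when the matching output of either branch
    // does.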
    auto outputs_require = bitwiseOr(
        fmap(true_block->outputs(), getRequiresGrad),
        fmap(false_block->outputs(), getRequiresGrad));
    setRequiresGrad(node, outputs_require);
  } else if (node->kind() == prim::Loop) {
    auto body = node->blocks().at(0);
    std::vector<bool> body_inputs_require =
        fmap(node->inputs().slice(2), getRequiresGrad);
    std::vector<bool> body_outputs_require(node->outputs().size(), false);

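    // Flags on loop-carried values can flow from the body outputs back into
    // the body inputs, so the body has to be re-analyzed until both sets of
    // flags stop changing.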
    std::vector<bool> new_body_inputs_require = body_inputs_require;
    std::vector<bool> new_body_outputs_require = body_outputs_require;

    // continue iterating until the results have converged
    do {
      body_inputs_require = new_body_inputs_require;
      body_outputs_require = new_body_outputs_require;

      new_body_inputs_require =
          bitwiseOr(body_inputs_require, body_outputs_require);
      setRequiresGrad(
          body->param_node()->outputs().slice(1), new_body_inputs_require);
      PropagateRequiresGrad(body);
      new_body_outputs_require =
          fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
    } while (new_body_inputs_require != body_inputs_require ||
             new_body_outputs_require != body_outputs_require);

    setRequiresGrad(node, body_outputs_require);
  } else {
    PropagateRequiresGradSimpleNode(node);
  }
}

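// Runs the analysis over every node of a block, in program order.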
void PropagateRequiresGrad(Block* block) {
  for (Node* node : block->nodes()) {
    PropagateRequiresGrad(node);
  }
}
} // anonymous namespace

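// Public entry point: runs the analysis over the whole graph in place.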
void PropagateRequiresGrad(std::shared_ptr<Graph>& graph) {
  PropagateRequiresGrad(graph->block());
}
} // namespace jit
} // namespace torch