make profiling take no_grad flags into account (#31071)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31071

Previously the profiler would record that Tensors require grad even
when the no_grad flag was enabled during execution. This change makes
profiling and guards respect the no_grad flag, which eliminates the extra
differentiable graphs that otherwise appear in the backward graph (where
no_grad is typically enabled).
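
To make the recorded semantics concrete, here is a minimal sketch in plain
eager-mode PyTorch (no JIT involved) of the behavior the profiler must now
capture: under no_grad, an op on a tensor that requires grad produces a
non-differentiable result, so a type profiled in that context should carry
requires_grad=False.

    import torch

    x = torch.randn(3, requires_grad=True)

    y = x * 2
    print(y.requires_grad)   # True: grad mode is on, the result is differentiable

    with torch.no_grad():
        z = x * 2
    print(z.requires_grad)   # False: under no_grad the same op yields a
                             # non-differentiable result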

Test Plan: Imported from OSS

Differential Revision: D18915468

Pulled By: zdevito

fbshipit-source-id: 1ae816a16ab78ae5352825cc6b4a68ed7681a089
Author:    Zachary DeVito
Date:      2019-12-17 13:18:53 -08:00
Committer: Facebook Github Bot
Parent:    dab5f72543
Commit:    cc8d6342fc

5 changed files with 23 additions and 16 deletions

test/test_jit_fuser.py

@@ -309,7 +309,7 @@ class TestFuser(JitTestCase):
             with enable_profiling_mode():
                 warmup_backward(c.sum())
             graph = backward_graph(s)
-            self.assertAllFused(graph, except_for={'aten::Float'})
+            self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'})
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")

torch/csrc/jit/interpreter.cpp

@@ -51,6 +51,17 @@ namespace jit {
 // indicating whether this is the last use of the value. The interpreter
 // should generate a move rather than a copy in this case.
 
+TensorTypePtr tensorTypeInCurrentExecutionContext(const at::Tensor& t) {
+  if (!t.defined()) {
+    return TensorType::get()->withUndefined();
+  }
+  auto r = TensorType::create(t);
+  if (!at::GradMode::is_enabled()) {
+    return r->withRequiresGrad(false);
+  }
+  return r;
+}
+
 namespace {
 // insert Drop nodes to kill references for anything unused:
@@ -1010,9 +1021,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
       } break;
       case GUARD: {
         auto t = stack.back().toTensor();
-        auto actual = t.defined() ? TensorType::create(t)
-                                  : TensorType::get()->withUndefined();
-        const TypePtr &expected = af.types[inst.X];
+        auto actual = tensorTypeInCurrentExecutionContext(t);
+        const TypePtr& expected = af.types[inst.X];
         push(stack, *expected == *actual);
         ++af.pc;
       } break;
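
To make the guard semantics concrete, here is a hedged Python analogue; the
function name tensor_type_in_current_execution_context and the tuple encoding
of a type are hypothetical stand-ins for illustration, not the real API (the
actual implementation is the C++ above):

    import torch

    def tensor_type_in_current_execution_context(t):
        # Hypothetical analogue of the C++ helper: record dtype, shape, and a
        # requires_grad bit that is forced to False when grad mode is disabled.
        if t is None:  # stands in for an undefined tensor
            return None
        return (t.dtype, tuple(t.shape),
                t.requires_grad and torch.is_grad_enabled())

    # A GUARD conceptually compares the type recorded at profiling time with
    # the type observed now; a mismatch sends execution to a fallback path.
    def guard_ok(expected, t):
        return expected == tensor_type_in_current_execution_context(t)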

torch/csrc/jit/interpreter.h

@@ -99,5 +99,11 @@ struct InterpreterContinuation {
   bool grad_mode_enabled;
 };
 
+// Returns the tensor's type, including state from the current execution
+// context that modifies how the tensor behaves. For instance, if no_grad is
+// enabled, the returned TensorType will have requires_grad=False.
+TORCH_API at::TensorTypePtr tensorTypeInCurrentExecutionContext(
+    const at::Tensor& t);
+
 } // namespace jit
 } // namespace torch
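
Continuing the sketch above, the same tensor is recorded differently depending
on the ambient grad mode, which is exactly the behavior the header comment
describes:

    x = torch.randn(2, 2, requires_grad=True)

    print(tensor_type_in_current_execution_context(x))
    # (torch.float32, (2, 2), True)

    with torch.no_grad():
        print(tensor_type_in_current_execution_context(x))
    # (torch.float32, (2, 2), False): types profiled while no_grad is enabled
    # (as in the backward pass) now carry requires_grad=False, so the guards
    # they produce match at runtime.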

torch/csrc/jit/profiling_record.cpp

@@ -1,5 +1,6 @@
-#include <torch/csrc/jit/profiling_record.h>
+#include <torch/csrc/jit/interpreter.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
+#include <torch/csrc/jit/profiling_record.h>
 
 namespace torch {
 namespace jit {
@@ -58,7 +59,7 @@ void ProfilingRecord::insertShapeProfile(Node *n, Value *i) {
     if (t.isTensor()) {
      if (t.toTensor().defined()) {
-        auto pttp = TensorType::create(t.toTensor());
+        auto pttp = tensorTypeInCurrentExecutionContext(t.toTensor());
         std::lock_guard<std::mutex> lock(this->mutex_);
         if (auto type = pno->type()->cast<TensorType>()) {
           if (!first) {
@@ -127,14 +128,5 @@ std::unique_ptr<ProfilingRecord> ProfilingRecord::instrumentGraph(
   return pr;
 }
 
-TensorTypePtr ProfilingRecord::toTensorTypePtr(const IValue& ival) {
-  if (ival.isTensor()) {
-    auto tensor = ival.toTensor();
-    return TensorType::create(tensor);
-  }
-
-  return {nullptr};
-}
-
 } // namespace jit
 } // namespace torch

torch/csrc/jit/profiling_record.h

@@ -21,7 +21,6 @@ struct ProfilingRecord {
   // are captured in callbacks_
   ProfilingRecord(const ProfilingRecord&) = delete;
   ProfilingRecord(ProfilingRecord&&) noexcept = delete;
-  static TensorTypePtr toTensorTypePtr(const IValue& ival);
   TORCH_API static std::unique_ptr<ProfilingRecord> instrumentGraph(
       const std::shared_ptr<Graph>& graph);