mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
make profiling take no_grad flags into account (#31071)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31071 Previously, the profiler assumed that Tensors required grad even when the no_grad flag was enabled during execution. This change makes profiling and guards respect the no_grad flag, which eliminates the extra differentiable graphs that appeared in the backward graph (where no_grad is typically enabled). Test Plan: Imported from OSS. Differential Revision: D18915468. Pulled By: zdevito. fbshipit-source-id: 1ae816a16ab78ae5352825cc6b4a68ed7681a089
This commit is contained in:
committed by
Facebook Github Bot
parent
dab5f72543
commit
cc8d6342fc
@ -309,7 +309,7 @@ class TestFuser(JitTestCase):
|
||||
with enable_profiling_mode():
|
||||
warmup_backward(c.sum())
|
||||
graph = backward_graph(s)
|
||||
self.assertAllFused(graph, except_for={'aten::Float'})
|
||||
self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'})
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
|
||||
|
@ -51,6 +51,17 @@ namespace jit {
|
||||
// indicating whether this is the last use of the value. The interpreter
|
||||
// should generate a move rather than a copy in this case.
|
||||
|
||||
// Computes the TensorType for `t` as observed in the current execution
// context. An undefined tensor yields the undefined TensorType. Otherwise the
// type is created from the tensor, and if grad mode is currently disabled the
// returned type is marked as not requiring grad.
TensorTypePtr tensorTypeInCurrentExecutionContext(const at::Tensor& t) {
|
||||
if (!t.defined()) {
|
||||
return TensorType::get()->withUndefined();
|
||||
}
|
||||
auto r = TensorType::create(t);
|
||||
if (!at::GradMode::is_enabled()) {
|
||||
return r->withRequiresGrad(false);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// insert Drop nodes to kill references for anything unused:
|
||||
@ -1010,9 +1021,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
||||
} break;
|
||||
case GUARD: {
|
||||
auto t = stack.back().toTensor();
|
||||
auto actual = t.defined() ? TensorType::create(t)
|
||||
: TensorType::get()->withUndefined();
|
||||
const TypePtr &expected = af.types[inst.X];
|
||||
auto actual = tensorTypeInCurrentExecutionContext(t);
|
||||
const TypePtr& expected = af.types[inst.X];
|
||||
push(stack, *expected == *actual);
|
||||
++af.pc;
|
||||
} break;
|
||||
|
@ -99,5 +99,11 @@ struct InterpreterContinuation {
|
||||
bool grad_mode_enabled;
|
||||
};
|
||||
|
||||
// The tensor's type, including state from the current execution context
|
||||
// that modifies how the tensor behaves. For instance, if no_grad is enabled
|
||||
// this will cause the TensorType to have requires_grad=False.
|
||||
TORCH_API at::TensorTypePtr tensorTypeInCurrentExecutionContext(
|
||||
const at::Tensor& t);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include <torch/csrc/jit/profiling_record.h>
|
||||
#include <torch/csrc/jit/interpreter.h>
|
||||
#include <torch/csrc/jit/passes/constant_propagation.h>
|
||||
#include <torch/csrc/jit/profiling_record.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
@ -58,7 +59,7 @@ void ProfilingRecord::insertShapeProfile(Node *n, Value *i) {
|
||||
if (t.isTensor()) {
|
||||
|
||||
if (t.toTensor().defined()) {
|
||||
auto pttp = TensorType::create(t.toTensor());
|
||||
auto pttp = tensorTypeInCurrentExecutionContext(t.toTensor());
|
||||
std::lock_guard<std::mutex> lock(this->mutex_);
|
||||
if (auto type = pno->type()->cast<TensorType>()) {
|
||||
if (!first) {
|
||||
@ -127,14 +128,5 @@ std::unique_ptr<ProfilingRecord> ProfilingRecord::instrumentGraph(
|
||||
return pr;
|
||||
}
|
||||
|
||||
TensorTypePtr ProfilingRecord::toTensorTypePtr(const IValue& ival) {
|
||||
if (ival.isTensor()) {
|
||||
auto tensor = ival.toTensor();
|
||||
return TensorType::create(tensor);
|
||||
}
|
||||
|
||||
return {nullptr};
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -21,7 +21,6 @@ struct ProfilingRecord {
|
||||
// are captured in callbacks_
|
||||
ProfilingRecord(const ProfilingRecord&) = delete;
|
||||
ProfilingRecord(ProfilingRecord&&) noexcept = delete;
|
||||
static TensorTypePtr toTensorTypePtr(const IValue& ival);
|
||||
TORCH_API static std::unique_ptr<ProfilingRecord> instrumentGraph(
|
||||
const std::shared_ptr<Graph>& graph);
|
||||
|
||||
|
Reference in New Issue
Block a user