autodiff changes to enable profiling
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25397

Differential Revision: D17565747
Pulled By: Krovatkin
fbshipit-source-id: b772437d9e02df99db6e662cb7d1227359959bed

commit db5791d543, parent 0cb10d7ebf, committed by Facebook Github Bot
@@ -379,15 +379,14 @@ struct CAFFE2_API TensorType : public Type {
     return TensorTypePtr(new TensorType(t));
   }
 
-  static TensorTypePtr create(
-      c10::optional<at::ScalarType> scalar_type,
-      c10::optional<Device> device,
-      const VaryingShape& sizes,
-      const VaryingStrides& strides,
-      c10::optional<bool> requires_grad,
-      c10::optional<bool> autograd_zero=c10::nullopt) {
-    return TensorTypePtr(new TensorType(
-        scalar_type, device, sizes, strides, requires_grad));
+  static TensorTypePtr create(c10::optional<at::ScalarType> scalar_type,
+                              c10::optional<Device> device,
+                              const VaryingShape &sizes,
+                              const VaryingStrides &strides,
+                              c10::optional<bool> requires_grad,
+                              c10::optional<bool> undefined = false) {
+    return TensorTypePtr(new TensorType(scalar_type, device, sizes, strides,
+                                        requires_grad, undefined));
   }
 
   static TensorTypePtr create(
@@ -462,8 +461,9 @@ struct CAFFE2_API TensorType : public Type {
 
     auto rt = rhs.expect<TensorType>();
     return scalar_type_ == rt->scalarType() && sizes() == rt->sizes() &&
-           strides() == rt->strides() && device() == rt->device() &&
-           requiresGrad() == rt->requiresGrad() && autogradZero() == rt->autogradZero();
+           strides() == rt->strides() && device() == rt->device() &&
+           requiresGrad() == rt->requiresGrad() &&
+           undefined() == rt->undefined();
   }
   bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const override;
 
@@ -533,68 +533,52 @@ struct CAFFE2_API TensorType : public Type {
     return cloned;
   }
 
-  TensorTypePtr merge(TensorTypePtr other) const {
-    auto scalar_type = merge_primitive(scalarType(), other->scalarType());
-    auto dev = merge_primitive(device(), other->device());
-    auto sz = sizes().merge(other->sizes());
-    auto srs = strides().merge(other->strides());
-    auto gr = merge_primitive(requiresGrad(), other->requiresGrad());
-    auto zero = merge_primitive(autogradZero(), other->autogradZero());
-    return TensorType::create(scalar_type, dev, sz, srs, gr, zero);
-  }
+  TensorTypePtr merge(TensorTypePtr other) const;
 
   // is all information about the type specified except for autograd?
   // This replaces the notion of a 'CompleteTensorType' that used to exist
-  // in the type-hierarchy. Excluding require_grad and autogradZero allows
+  // in the type-hierarchy. Excluding require_grad and undefined allows
   // this to match the old behavior.
   bool isComplete() const {
     return scalar_type_ && device_ && sizes_.isComplete() && strides_.isComplete();
   }
 
-  TensorTypePtr withAutogradZero() {
+  TensorTypePtr withUndefined() {
     auto r = clone();
-    r->autograd_zero_ = true;
+    r->undefined_ = true;
    return r;
   }
 
-  c10::optional<bool> autogradZero() const {
-    return autograd_zero_;
-  }
+  c10::optional<bool> undefined() const { return undefined_; }
 
   static TensorTypePtr get();
 
   static const TypeKind Kind = TypeKind::TensorType;
 
  private:
-  TensorType(const at::Tensor& tensor)
-      : Type(TypeKind::TensorType),
-        scalar_type_(tensor.scalar_type()),
-        device_(tensor.device()),
-        sizes_(tensor.sizes().size()),
-        strides_(tensor.sizes().size()),
-        requires_grad_(tensor.requires_grad()) {
-    if (!tensor.is_mkldnn() && !tensor.is_sparse()) {
-      sizes_ = tensor.sizes().vec();
-      strides_ = tensor.strides().vec();
-    }
+  TensorType(const at::Tensor &tensor)
+      : Type(TypeKind::TensorType), scalar_type_(tensor.scalar_type()),
+        device_(tensor.device()), sizes_(tensor.sizes().size()),
+        strides_(tensor.sizes().size()),
+        requires_grad_(tensor.requires_grad()), undefined_(false) {
+    if (!tensor.is_mkldnn() && !tensor.is_sparse()) {
+      sizes_ = tensor.sizes().vec();
+      strides_ = tensor.strides().vec();
+    }
   }
-  TensorType(
-      c10::optional<at::ScalarType> scalar_type,
-      c10::optional<Device> device,
-      const VaryingShape& sizes,
-      const VaryingStrides& strides,
-      c10::optional<bool> requires_grad,
-      c10::optional<bool> autograd_zero=c10::nullopt)
-      : Type(TypeKind::TensorType),
-        scalar_type_(scalar_type),
-        device_(device),
-        sizes_(sizes),
-        strides_(strides),
-        requires_grad_(requires_grad),
-        autograd_zero_(autograd_zero) {}
+  TensorType(c10::optional<at::ScalarType> scalar_type,
+             c10::optional<Device> device, const VaryingShape &sizes,
+             const VaryingStrides &strides,
+             c10::optional<bool> requires_grad,
+             c10::optional<bool> undefined = false)
+      : Type(TypeKind::TensorType), scalar_type_(scalar_type),
+        device_(device), sizes_(sizes), strides_(strides),
+        requires_grad_(requires_grad), undefined_(undefined) {}
 
-  TensorTypePtr clone() const {
-    return TensorTypePtr(new TensorType(
-        scalar_type_, device_, sizes_, strides_, requires_grad_, autograd_zero_));
+  TensorTypePtr clone() const {
+    return TensorTypePtr(new TensorType(scalar_type_, device_, sizes_,
+                                        strides_, requires_grad_,
+                                        undefined_));
   }
 
   static std::vector<int64_t> contiguousStridesOf(at::IntArrayRef sizes) {
@@ -614,11 +598,17 @@ struct CAFFE2_API TensorType : public Type {
   VaryingStrides strides_;
   c10::optional<bool> requires_grad_;
   // we exploit the fact certain tensors must be zero in the autograd to
-  // optimize gradient computation. If true, this means that this tensor
-  // must only contain zeros. Normally this will be nullopt, meaning
-  // the tensor may or may not contain only zeros. If false,
-  // this means the tensor must have some non-zero elements.
-  c10::optional<bool> autograd_zero_;
+  // optimize gradient computation. Such zero tensors are currently implemented
+  // with `UndefinedTensorImpl.` They can be handled only by special operators
+  // (e.g. `AutogradAdd`) and their `Tensor::defined()` property returns false.
+  // Normally, `undefined_` is set to false, unless a type was created
+  // with `withUndefined`
+  // This will also mean that `undefined` tensors will fail
+  // `subtypeOf(TensorType::get())` check
+  // undefined_ may become `c10::nullopt` if the tensor was observed to be both
+  // defined and undefined. However, no tensor type starts out with
+  // `undefined_` set to `c10::nullopt`
+  c10::optional<bool> undefined_;
 };
 
 struct ListType;
@@ -31,8 +31,8 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
       }
       out << ")";
     }
-    if (value->autogradZero() && *value->autogradZero()) {
-      out << "[AutogradZero]";
+    if (value->undefined() && *value->undefined()) {
+      out << "[Undefined]";
     }
   } else if(t.kind() == TypeKind::ListType) {
     auto prim = t.cast<ListType>()->getElementType();
@@ -513,6 +513,16 @@ VaryingShape VaryingShape::merge(const VaryingShape& other) const {
   return VaryingShape(std::move(dims));
 }
 
+TensorTypePtr TensorType::merge(TensorTypePtr other) const {
+  auto scalar_type = merge_primitive(scalarType(), other->scalarType());
+  auto dev = merge_primitive(device(), other->device());
+  auto sz = sizes().merge(other->sizes());
+  auto srs = strides().merge(other->strides());
+  auto gr = merge_primitive(requiresGrad(), other->requiresGrad());
+  auto undef = merge_primitive(undefined(), other->undefined());
+  return TensorType::create(scalar_type, dev, sz, srs, gr, undef);
+}
+
 std::ostream& operator<<(std::ostream & out, const VaryingShape & vs) {
 
   out << "(";
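The merge above combines two observed type facts field by field. The key behavior, described in the `undefined_` comment earlier in this diff, is that agreeing observations survive while conflicting ones widen to `c10::nullopt` ("unknown"). A minimal standalone sketch of that assumed `merge_primitive` semantics, using `std::optional` in place of `c10::optional` (illustrative only, not the PyTorch source):

#include <iostream>
#include <optional>

// Sketch: keep a fact both observations agree on; widen any
// disagreement (or missing value) to "unknown" (nullopt).
template <typename T>
std::optional<T> merge_primitive(const std::optional<T>& a,
                                 const std::optional<T>& b) {
  if (a.has_value() && b.has_value() && *a == *b) {
    return a;
  }
  return std::nullopt;
}

int main() {
  std::optional<bool> undefined_seen = true;
  std::optional<bool> defined_seen = false;
  // Agreement is preserved:
  std::cout << merge_primitive(undefined_seen, undefined_seen).has_value() << "\n"; // 1
  // Conflict widens to nullopt, which is how undefined_ can become unset
  // when a value is observed both defined and undefined:
  std::cout << merge_primitive(undefined_seen, defined_seen).has_value() << "\n";   // 0
}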
@@ -25,6 +25,11 @@ std::vector<at::Tensor> run(
   return fmap(stack, [](const IValue& i) { return i.toTensor(); });
 }
 
+static void unpackReturnTuple(Stack &stack) {
+  auto tuple = pop(stack).toTuple();
+  stack.insert(stack.end(), tuple->elements().begin(), tuple->elements().end());
+}
+
 std::pair<tensor_list, tensor_list> runGradient(
     Gradient& grad_spec,
     tensor_list& tensors_in,
@@ -32,6 +37,7 @@ std::pair<tensor_list, tensor_list> runGradient(
   static const auto as_tensorlist = [](const Stack& stack) {
     return fmap(stack, [](const IValue& i) { return i.toTensor(); });
   };
+
   Code f_code{grad_spec.f}, df_code{grad_spec.df};
   InterpreterState f_interpreter{f_code}, df_interpreter{df_code};
 
@@ -46,7 +52,7 @@ std::pair<tensor_list, tensor_list> runGradient(
   for (auto offset : grad_spec.df_input_captured_outputs)
     df_stack.push_back(f_stack[offset]);
   df_interpreter.run(df_stack);
-
+  unpackReturnTuple(df_stack);
   // Outputs of f needs to be sliced
   f_stack.erase(f_stack.begin() + grad_spec.f_real_outputs, f_stack.end());
   return std::make_pair(as_tensorlist(f_stack), as_tensorlist(df_stack));
@@ -16953,7 +16953,7 @@ nn_functional_tests = [
     ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
     ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
     ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),
-        '', (True, 'aten::_batch_norm_impl_index')),
+        '', (False, 'aten::_batch_norm_impl_index')),
     ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
     ('layer_norm', (S, S, S, S), ([5],), '',
         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
@@ -22,7 +22,7 @@ class TestFuser(JitTestCase):
     def assertAllFused(self, graph, except_for=()):
         if [n.kind() for n in graph.nodes()] == ['prim::DifferentiableGraph']:
             graph = next(graph.nodes()).g('Subgraph')
-        allowed_nodes = {'prim::Constant', 'prim::FusionGroup'} | set(except_for)
+        allowed_nodes = {'prim::Constant', 'prim::FusionGroup', 'prim::TupleConstruct'} | set(except_for)
         self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
                         'got {}'.format(graph))
         self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
@@ -878,8 +878,9 @@ class TestFuser(JitTestCase):
                     assert backward is None
                     backward = g
                     old_plans.add(str(backward))
-            self.assertEqual(len([1 for o in backward.outputs() if o.node().kind() == "aten::_grad_sum_to_size"]), i)
-            self.assertEqual(len([1 for o in backward.outputs() if o.node().kind() == "prim::Param"]), 3 - i)
+            self.assertEqual(len([1 for o in next(backward.outputs()).node().inputs()
+                                  if o.node().kind() == "aten::_grad_sum_to_size"]), i)
+            self.assertEqual(len([1 for o in next(backward.outputs()).node().inputs() if o.node().kind() == "prim::Param"]), 3 - i)
 
 
 if __name__ == '__main__':
@@ -215,8 +215,7 @@ void ArgumentSpecCreator::specializeTypes(
         input_stack.back()++;
         auto& arg = spec.tensorAt(tensor_arg_spec_offset++);
         if (!arg.defined()) {
-          result_stack.back().emplace_back(
-              TensorType::get()->withAutogradZero());
+          result_stack.back().emplace_back(TensorType::get()->withUndefined());
         } else {
           result_stack.back().emplace_back(arg.toType());
         }
@@ -10,7 +10,6 @@
 #include <torch/csrc/jit/passes/lower_tuples.h>
 #include <torch/csrc/jit/script/compiler.h>
 #include <torch/csrc/jit/symbolic_script.h>
-
 #include <c10/util/Exception.h>
 
 #include <algorithm>
@@ -525,6 +524,58 @@ static void liftConstants(Block* block, Block* move_to_this_block) {
   liftConstants(block->return_node(), move_to_this_block);
 }
 
+// we need to fold aten::_size_if_not_equal at the differentiation time
+// while we know the shapes of aten::_size_if_not_equal's arguments
+// Otherwise, they will become inputs to a reverse Graph, and we will
+// lose this information and we don't profile Scalars, or Lists yet.
+static void foldSizeIfNotEqual(Block *node);
+
+static void foldSizeIfNotEqual(Node *node) {
+  for (Value *input : node->inputs()) {
+
+    if (input->node()->kind() != aten::_size_if_not_equal) {
+      continue;
+    }
+
+    auto ptt_input =
+        input->node()->input(0)->node()->input()->type()->expect<TensorType>();
+    auto ptt_output =
+        input->node()->input(1)->node()->input()->type()->expect<TensorType>();
+
+    auto input_size = ptt_input->sizes().concrete_sizes();
+    auto output_size = ptt_output->sizes().concrete_sizes();
+
+    if (!input_size || !output_size) {
+      continue;
+    }
+    // insert in front of _grad_sum_to_size
+    WithInsertPoint guard(node);
+    IValue ival{};
+    Value *size;
+    if (input_size != output_size) {
+      size = node->owningGraph()->insertConstant(*input_size);
+    } else {
+      size = node->owningGraph()->insertConstant(IValue());
+    }
+    node->replaceInputWith(input, size);
+  }
+
+  for (auto ib : node->blocks()) {
+    foldSizeIfNotEqual(ib);
+  }
+}
+
+// we need to fold aten::_size_if_not_equal at the differentiation time
+// while we know the shapes of aten::_size_if_not_equal's arguments
+// Otherwise, they will become inputs to a reverse Graph, and we will
+// lose this information and we don't profile Scalars, or Lists yet.
+static void foldSizeIfNotEqual(Block *reverse_block) {
+  for (auto n : reverse_block->nodes()) {
+    foldSizeIfNotEqual(n);
+  }
+  foldSizeIfNotEqual(reverse_block->return_node());
+}
+
 static void deduplicateSizeCaptures(
     Gradient& grad_desc,
     ReverseDetails& rev_info) {
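For intuition: `aten::_size_if_not_equal(a, b)` feeds `_grad_sum_to_size` in backward graphs, yielding the pre-broadcast size when broadcasting changed the shape and None otherwise. Once both shapes are concretely known at differentiation time, the pass above collapses the call to a constant so it never becomes a reverse-graph input. A self-contained sketch of just the folding rule (a hypothetical standalone function on plain vectors, not the JIT code):

#include <iostream>
#include <optional>
#include <vector>

using Sizes = std::vector<int64_t>;

// Sketch: the value _size_if_not_equal is assumed to produce, computed
// eagerly from concrete shapes instead of being left symbolic.
std::optional<Sizes> foldedSizeIfNotEqual(const Sizes& input_size,
                                          const Sizes& output_size) {
  if (input_size != output_size) {
    return input_size;   // broadcast happened: gradient must be summed back
  }
  return std::nullopt;   // shapes match: _grad_sum_to_size receives None (no-op)
}

int main() {
  auto folded = foldedSizeIfNotEqual({3, 1}, {3, 4});
  std::cout << (folded ? "constant size" : "None") << "\n"; // constant size
}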
@@ -592,6 +643,8 @@ static void Optimize(Gradient& grad_desc, ReverseDetails& rev_info) {
   // have time before the 1.0 release, so I put this only as a peephole
   // optimization.
   liftConstants(rev_info.reverse_block, rev_info.reverse_block);
+  // TODO: see if this pass can be replaced with peephole pass
+  foldSizeIfNotEqual(rev_info.reverse_block);
   // We generally add a lot of aten::size calls (for derivatives of broadcasting
   // operators), and they often end up duplicated, and would get captured
   // multiple times. Make sure we deduplicate them before lifting.
@@ -693,6 +746,7 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
     // f's outputs that we differentiate).
     if (rev_info.grad_map.count(tmp) == 0)
       continue;
+
     Value* tmp_vjp_in = reverse_block->addInput()->setType(tmp->type());
     Value* tmp_vjp_prev = rev_info.grad_map.at(tmp);
     // This is quite weird because we can't first make a sum and then replace
@@ -745,6 +799,14 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
   reverse_block->owningNode()->destroy();
 }
 
+void packReturnValuesIntoTuple(const std::shared_ptr<Graph> &graph) {
+  auto returnNode = graph->block()->return_node();
+  WithInsertPoint wip(returnNode);
+  auto tuple = graph->insertNode(graph->createTuple(returnNode->inputs()));
+  returnNode->removeAllInputs();
+  returnNode->addInput(tuple->output());
+}
+
 Gradient differentiate(std::shared_ptr<Graph>& graph) {
   Gradient grad_desc;
   // Take ownership of the graph
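Paired with the `unpackReturnTuple` helpers added elsewhere in this diff, this gives the backward graph a single tuple output that is splatted back onto the stack after execution, which is also why `assertAllFused` now tolerates `prim::TupleConstruct`. A toy round trip over a plain vector "stack" (stand-in types, not the JIT's `Stack`/`IValue`):

#include <cassert>
#include <vector>

using Stack = std::vector<int>; // stand-in for torch::jit::Stack
using Tuple = std::vector<int>; // stand-in for a tuple IValue

// Mirrors packReturnValuesIntoTuple: gather every output into one tuple.
Tuple packReturnValues(Stack& stack) {
  Tuple tuple(stack.begin(), stack.end());
  stack.clear();
  return tuple;
}

// Mirrors unpackReturnTuple: splat the tuple's elements back on the stack.
void unpackReturnTuple(Stack& stack, const Tuple& tuple) {
  stack.insert(stack.end(), tuple.begin(), tuple.end());
}

int main() {
  Stack stack{1, 2, 3};
  Tuple packed = packReturnValues(stack);
  assert(stack.empty());
  unpackReturnTuple(stack, packed);
  assert((stack == Stack{1, 2, 3})); // outputs survive the round trip
}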
@@ -770,6 +832,7 @@ Gradient differentiate(std::shared_ptr<Graph>& graph) {
   // It's possible the we've cloned the same constants many times, so
   // de-duplicate them
   ConstantPooling(grad_desc.df);
+  packReturnValuesIntoTuple(grad_desc.df);
   return grad_desc;
 }
 } // namespace jit
@@ -207,6 +207,12 @@ struct UnpackInstructions {
   std::vector<size_t> sizes_;
 };
 
+// unpack values packed by `packReturnValuesIntoTuple`
+static void unpackReturnTuple(Stack &stack) {
+  auto tuple = pop(stack).toTuple();
+  stack.insert(stack.end(), tuple->elements().begin(), tuple->elements().end());
+}
+
 struct DifferentiableGraphBackward : public autograd::Node {
   DifferentiableGraphBackward(
       GraphExecutor executor,
@@ -223,6 +229,7 @@ struct DifferentiableGraphBackward : public autograd::Node {
     input_instructions_.unpack(std::move(inputs), stack);
     captures_.unpack(stack, shared_from_this());
     executor.run(stack);
+    unpackReturnTuple(stack);
 
     // NB: stack.size() == num_outputs() is not always true
     // after we added TensorList support.
@@ -552,6 +559,7 @@ struct GraphExecutorImpl : public GraphExecutorImplBase {
     // Phase 0. Inline functions, then clean up any artifacts that the inliner
     // left in that may inhibit optimization
     Inline(*opt_graph);
+    specializeAutogradZero(*opt_graph);
     LowerSimpleTuples(opt_graph);
     ConstantPooling(opt_graph);
 
@@ -642,7 +650,6 @@ GraphExecutorState GraphExecutor::getDebugState() {
 }
 
 void runRequiredPasses(const std::shared_ptr<Graph>& g) {
-  specializeAutogradZero(*g);
   LowerGradOf(*g);
   // implicit inserted expand nodes are not necessarily always valid
   // when used inside script methods that might have unstable shapes
@@ -985,8 +985,10 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
         ++af.pc;
       } break;
       case GUARD: {
-        auto actual = TensorType::create(stack.back().toTensor());
-        const TypePtr& expected = af.types[inst.X];
+        auto t = stack.back().toTensor();
+        auto actual = t.defined() ? TensorType::create(t)
+                                  : TensorType::get()->withUndefined();
+        const TypePtr &expected = af.types[inst.X];
         push(stack, *expected == *actual);
         ++af.pc;
       } break;
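The GUARD change covers the case where the guarded runtime value is an undefined tensor: `TensorType::create(t)` reads scalar type, sizes, and strides from `t`, which an undefined tensor cannot supply, so the guard instead compares against the special `withUndefined()` type. A schematic of that dispatch with toy stand-in types (illustrative only, not the interpreter's):

#include <iostream>
#include <string>

// Toy stand-in for a tensor that may be undefined (e.g. a zero gradient).
struct Tensor {
  bool defined;
};

// Sketch of the guarded-type computation: defined tensors get a concrete
// runtime type; undefined ones map to the "undefined" TensorType.
std::string guardType(const Tensor& t) {
  return t.defined ? "TensorType::create(t)"
                   : "TensorType::get()->withUndefined()";
}

int main() {
  std::cout << guardType(Tensor{true}) << "\n";
  std::cout << guardType(Tensor{false}) << "\n";
}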
@@ -15,15 +15,13 @@ namespace jit {
 // introduce more granularity here (e.g. List[int] will never alias
 // List[float]).
 c10::optional<TypeKind> AliasDb::getMutableTypeKind(const TypePtr& type) {
-  if (type->isSubtypeOf(TensorType::get())) {
-    return TypeKind::TensorType;
-  }
 
   switch (type->kind()) {
     case TypeKind::ListType:
     case TypeKind::TupleType:
     case TypeKind::DictType:
     case TypeKind::ClassType:
+    case TypeKind::TensorType:
       return type->kind();
     case TypeKind::OptionalType:
       return getMutableTypeKind(type->cast<OptionalType>()->getElementType());
@@ -1,3 +1,4 @@
 #include <torch/csrc/jit/graph_executor.h>
+#include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/specialize_autogradzero.h>
 
@@ -10,14 +11,14 @@ namespace jit {
 // operations generated by the symbolic autodiff code and cleans up
 // AutogradAdds when possible. Outputs of other nodes are conservatively
 // marked Unknown and not optimized.
-void specializeAutogradZero(Graph& g) {
+void specializeAutogradZero(Graph &g) {
   enum class State { Nonzero, Zero, Unknown };
   std::unordered_map<Value*, State> state;
 
   for (Value* input : g.inputs()) {
     const auto& tp = input->type();
     if (auto tt = tp->cast<TensorType>()) {
-      if (tt->autogradZero() && *tt->autogradZero()) {
+      if (tt->undefined() && *tt->undefined()) {
         state[input] = State::Zero;
       } else {
         state[input] = State::Nonzero;
|
||||
} else if (state[a] == State::Nonzero && state[b] == State::Nonzero) {
|
||||
// when both are Nonzero, we can use a normal, optimizable add
|
||||
// instruction
|
||||
|
||||
WithInsertPoint guard(n);
|
||||
auto* g = n->owningGraph();
|
||||
auto* cOne = g->insertConstant(1);
|
||||
@@ -124,6 +126,32 @@ void specializeAutogradZero(Graph& g) {
       case prim::AutogradZero: {
         state[n->output()] = State::Zero;
       } break;
+      case prim::profile: {
+        // if prim::profile doesn't have an input
+        // it's a counter to keep track how many times
+        // a graph was profiled
+        if (n->inputs().size() > 0) {
+          state[n->output()] = State::Unknown;
+          // state[n->input()];
+        }
+        break;
+      }
+      case prim::BailOut: {
+        if (auto ptt = n->output()->type()->expect<TensorType>()) {
+          state[n->output()] =
+              ptt->undefined()
+                  ? *ptt->undefined() ? State::Zero : State::Nonzero
+                  : State::Unknown;
+        }
+      } break;
+      case prim::Guard: {
+        if (auto ptt = n->output()->type()->expect<TensorType>()) {
+          state[n->output()] =
+              ptt->undefined()
+                  ? *ptt->undefined() ? State::Zero : State::Nonzero
+                  : State::Unknown;
+        }
+      } break;
       default:
         for (auto o : n->outputs()) {
           state[o] = State::Unknown;
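The new `prim::profile`, `prim::BailOut`, and `prim::Guard` cases feed the same three-state analysis this pass uses to simplify `prim::AutogradAdd`: a known-Zero (undefined) operand drops out of the sum, two Nonzero operands can become a regular add, and anything Unknown is left untouched. A condensed sketch of that decision table (illustrative only, describing the rewrites as strings rather than performing graph surgery):

#include <iostream>

enum class State { Nonzero, Zero, Unknown };

// Sketch of how specializeAutogradZero is assumed to treat
// AutogradAdd(a, b) once the states of its operands are known.
const char* simplifyAutogradAdd(State a, State b) {
  if (a == State::Zero && b == State::Zero)
    return "replace with AutogradZero";
  if (a == State::Zero)
    return "replace with b";              // 0 + b == b
  if (b == State::Zero)
    return "replace with a";              // a + 0 == a
  if (a == State::Nonzero && b == State::Nonzero)
    return "lower to a normal add";       // the optimizable instruction
  return "keep prim::AutogradAdd";        // Unknown: cannot specialize
}

int main() {
  std::cout << simplifyAutogradAdd(State::Zero, State::Nonzero) << "\n";
  std::cout << simplifyAutogradAdd(State::Unknown, State::Nonzero) << "\n";
}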
@@ -11,7 +11,7 @@ namespace jit {
 // operations generated by the symbolic autodiff code and cleans up
 // AutogradAdds when possible. Outputs of other nodes are conservatively
 // marked Unknown and not optimized.
-TORCH_API void specializeAutogradZero(Graph& g);
+TORCH_API void specializeAutogradZero(Graph &g);
 
 } // namespace jit
 } // namespace torch
@@ -1068,7 +1068,7 @@ const std::vector<std::string> functions = {
             return grad_self, None, None, None, None, None
         return output, indices, backward
 
-    def batch_norm(input : Tensor,
+    def batch_norm_disabled(input : Tensor,
                    weight : Optional[Tensor],
                    bias : Optional[Tensor],
                    running_mean : Optional[Tensor],