Fix contiguous AD and Autogradzero inconsistency (#18633)
Summary: Fixes #17962
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18633
Differential Revision: D14700449
Pulled By: wanchaol
fbshipit-source-id: 3d15d67c01b69b28394a0f2f001db90ed9fd31dc
Committed by: Facebook Github Bot
Parent: 5950c1e8c4
Commit: a21e256e8d
@@ -538,6 +538,8 @@ def method_tests():
     ('norm', (), (3, 0, True), 'keepdim_3_dim_scalar', (), [1]),
     ('clone', (S, M, S), NO_ARGS),
     ('clone', (), NO_ARGS, 'scalar'),
+    ('contiguous', (S, S), NO_ARGS, '', (True,)),
+    ('contiguous', torch.randn(S, S).transpose(0, 1), NO_ARGS, 'not_contiguous', (True,)),
     ('dist', (S, S, S), ((S, S, S),)),
     ('dist', (S, S, S), ((S,),), 'broadcast_rhs'),
     ('dist', (S,), ((S, S, S),), 'broadcast_lhs'),
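The two new entries exercise the contiguous autodiff formula through the generic method_tests() machinery, including a transposed (non-contiguous) input. A minimal eager-mode sketch of the same scenario, not taken from the commit:

    import torch

    # A transposed view is non-contiguous, so .contiguous() must materialize a
    # copy; the gradient should still pass through unchanged.
    x = torch.randn(5, 5, requires_grad=True)
    y = x.transpose(0, 1).contiguous()
    y.sum().backward()
    print(x.grad)  # a 5x5 tensor of ones: d(sum)/dx passes straight through contiguous()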
@@ -102,6 +102,16 @@ RegisterOperators reg({
         return 0;
       }
     ),
+    Operator(
+        "aten::is_contiguous(Tensor self) -> bool",
+        [](Stack & stack) {
+          autograd::profiler::RecordFunction record("is_contiguous");
+          auto result = ((std::move(peek(stack, 0, 1))).toTensor()).is_contiguous();
+          drop(stack, 1);
+          pack(stack, std::move(result));
+          return 0;
+        }
+    ),
 
     // Generated operators
     ${constructors}
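With aten::is_contiguous registered as a JIT operator, Tensor.is_contiguous() becomes callable from TorchScript. A small usage sketch, assuming a PyTorch build that includes this operator (not part of the diff):

    import torch

    @torch.jit.script
    def needs_copy(x: torch.Tensor) -> bool:
        # Dispatches to the aten::is_contiguous operator registered above.
        return not x.is_contiguous()

    print(needs_copy(torch.randn(4, 4)))                  # False
    print(needs_copy(torch.randn(4, 4).transpose(0, 1)))  # True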
@@ -198,7 +198,7 @@ bool isDifferentiable(Graph& g) {
 // graph and a backward graph. Forward graph will be used to replace the node in
 // grad_desc.f, and backward graph will be used to construct GradOf(node) in
 // reverse_block. Grad_values(a.k.a gradOutputs) propagated through
-// node->owningGraph() in **reversed** order, thus GradientPair.forward ahould
+// node->owningGraph() in **reversed** order, thus GradientPair.forward should
 // be inserted **after** the node being replaced, so that we don't traverse the
 // graph infinite times.
 //
@@ -775,10 +775,10 @@ class GradientHelper {
 // later. It is ok to replace any backward function with known-zero inputs with
 // something that produces known-zero outputs. This function encloses each
 // know-linear backward function in a 'GradOf' sub-block so that we can perform
-// optimizations using this information. In particular, specializeUndef will
-// observe if all the inputs to the linear block are Undef, which the autograd
-// uses to represent zeros, and then propagate the undefs to the outputs of the
-// block.
+// optimizations using this information. In particular, specializeAutogradZero will
+// observe if all the inputs to the linear block are AutogradZeroTensor, which the
+// autograd uses to represent zeros, and then propagate the zeros to the outputs
+// of the block.
 static std::vector<Value*> linearGradientForNode(
     Node* node,
     ArrayRef<Value*> grad_values) {
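The rewritten comment describes the zero propagation that specializeAutogradZero performs on GradOf blocks. A conceptual Python sketch of that idea, with None standing in for AutogradZeroTensor and a hypothetical run_grad_of helper (not PyTorch code):

    from typing import Callable, List, Optional

    def run_grad_of(body: Callable[[List[Optional[float]]], List[Optional[float]]],
                    num_outputs: int,
                    grad_inputs: List[Optional[float]]) -> List[Optional[float]]:
        # Each GradOf block wraps a backward that is linear in its gradient inputs,
        # so all-zero inputs (None here) imply all-zero outputs and the body can be skipped.
        if all(g is None for g in grad_inputs):
            return [None] * num_outputs
        return body(grad_inputs)

    # The backward of add(x, y) fans the incoming gradient out to both inputs.
    def add_backward(grads):
        return [grads[0], grads[0]]

    print(run_grad_of(add_backward, 2, [None]))  # [None, None] -- block body skipped
    print(run_grad_of(add_backward, 2, [1.0]))   # [1.0, 1.0]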
@@ -795,7 +795,7 @@ static std::vector<Value*> linearGradientForNode(
   WithInsertPoint guard(block);
   auto results = GradientHelper(node).gradient(grad_values);
   return fmap(results, [block, linear](Value* grad) -> Value* {
-    if (!grad)
+    if (!grad || grad->mustBeNone())
       return nullptr;
     block->registerOutput(grad);
     return linear->addOutput()->copyMetadata(grad);
@@ -841,8 +841,8 @@ static ReverseDetails addReverseInline(Gradient& grad_desc) {
   const auto get_grad = [&](Value* v) -> Value* {
     auto it = grad_map.find(v);
     if (it == grad_map.end()) {
-      auto undef = graph.insertNode(graph.createAutogradZero());
-      std::tie(it, std::ignore) = grad_map.emplace(v, undef->output());
+      auto autograd_zero = graph.insertNode(graph.createAutogradZero());
+      std::tie(it, std::ignore) = grad_map.emplace(v, autograd_zero->output());
     }
     return it->second;
   };
@@ -404,7 +404,7 @@ const std::vector<std::string> functions = {
 
         def contiguous(self):
             def backward(grad_output):
-                return None
+                return grad_output
 
             return self.contiguous(), backward
 
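The corrected formula makes the symbolic (TorchScript) gradient of contiguous the identity, matching eager-mode autograd. A quick end-to-end check, not taken from the commit:

    import torch

    @torch.jit.script
    def f(x: torch.Tensor) -> torch.Tensor:
        # Force a copy via contiguous() on a non-contiguous view, then reduce.
        return x.transpose(0, 1).contiguous().sum()

    x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
    (grad,) = torch.autograd.grad(f(x), x)
    print(torch.allclose(grad, torch.ones_like(x)))  # True: grad_output passes straight through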