mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Refactor TorchScript-based exporter logic to move it to a single (private) location for better code management. Original public module and method APIs are preserved. - Updated module paths in `torch/csrc/autograd/python_function.cpp` accordingly - Removed `check_onnx_broadcast` from `torch/autograd/_functions/utils.py` because it is private and unused @albanD / @soulitzer could you review changes in `torch/csrc/autograd/python_function.cpp` and `torch/autograd/_functions/utils.py`? Thanks! ## BC Breaking - **Deprecated members in `torch.onnx.verification` are removed** Differential Revision: [D81236421](https://our.internmc.facebook.com/intern/diff/D81236421) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161323 Approved by: https://github.com/titaiwangms, https://github.com/angelayi
211 lines
6.5 KiB
Python
211 lines
6.5 KiB
Python
# Owner(s): ["module: onnx"]
|
|
|
|
import pytorch_test_common
|
|
from onnx_test_common import run_model_test
|
|
|
|
import torch
|
|
from torch.onnx import OperatorExportTypes
|
|
from torch.testing._internal import common_utils
|
|
|
|
|
|
class TestAutogradFuns(pytorch_test_common.ExportTestCase):
    """Export tests for modules that call custom ``torch.autograd.Function``s.

    Each test defines one or more ``torch.autograd.Function`` subclasses,
    wraps them in a small ``torch.nn.Module`` and either hands the module to
    ``run_model_test`` or inspects the graph returned by
    ``torch.onnx.utils._model_to_graph``.

    NOTE(review): the ``backward`` implementations below are placeholders, not
    mathematically exact gradients; nothing in this file invokes them. They
    are kept well-formed (one gradient argument per forward output, one
    returned gradient per forward input) so they would not fail on arity if
    autograd ever ran them.
    """

    # Class-level export settings. Presumably consumed by the base class /
    # ``run_model_test`` -- confirm in pytorch_test_common / onnx_test_common.
    opset_version = 20
    keep_initializers_as_inputs = False
    onnx_shape_inference = True

    def test_single_output(self):
        """Export a module whose custom Function returns a single tensor."""

        class SingleOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result = i.exp()
                result = result.log()
                ctx.save_for_backward(result)
                return result

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Caller(torch.nn.Module):
            def forward(self, input):
                result = input + 5
                return SingleOut.apply(result) + 3

        model = Caller()
        sample = torch.ones(1)
        run_model_test(self, model, input_args=(sample,))

    def test_multi_output(self):
        """Export a module whose custom Function returns two tensors."""

        class MultiOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result_exp = i.exp()
                result_log = result_exp.log()
                ctx.save_for_backward(result_exp, result_log)
                return result_exp, result_log

            @staticmethod
            def backward(ctx, grad_exp, grad_log):
                # forward() has two outputs and saves two tensors, so backward
                # receives two gradients and must unpack two saved tensors.
                result_exp, result_log = ctx.saved_tensors
                return grad_exp * result_exp + grad_log * result_log

        class Caller(torch.nn.Module):
            def forward(self, input):
                return MultiOut.apply(input)

        model = Caller()
        sample = torch.ones(1, 5)
        run_model_test(self, model, input_args=(sample,))

    def test_partial_output(self):
        """Export a Function that returns only the values half of ``topk``."""

        class PartialOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                # torch.topk returns (values, indices); the indices are dropped.
                values, _ = torch.topk(input, 3)
                return values

        class Caller(torch.nn.Module):
            def forward(self, input):
                return PartialOut.apply(input)

        model = Caller()
        sample = torch.ones(1, 5)
        run_model_test(self, model, input_args=(sample,))

    def test_nested_autograd(self):
        """Export a Function whose forward applies another custom Function."""

        class Child(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result = i.log()
                result_log = result.log()
                ctx.save_for_backward(result_log)
                return result_log

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Parent(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result_exp = i.exp()
                result_log = Child.apply(result_exp)
                ctx.save_for_backward(result_exp, result_log)
                return result_exp, result_log

            @staticmethod
            def backward(ctx, grad_exp, grad_log):
                # Two forward outputs / two saved tensors: match both arities.
                result_exp, result_log = ctx.saved_tensors
                return grad_exp * result_exp + grad_log * result_log

        class Caller(torch.nn.Module):
            def forward(self, input):
                return Parent.apply(input)

        model = Caller()
        sample = torch.ones(1, 5)
        run_model_test(self, model, input_args=(sample,))

    # Run export in ONNX_FALLTHROUGH mode as torch.erf() is not supported
    def test_aten_unsupported(self):
        """Check the first graph node kind under two operator export types.

        With ``ONNX_FALLTHROUGH`` the first node is expected to be
        ``prim::PythonOp``; with ``ONNX_ATEN_FALLBACK`` it is expected to be
        ``aten::ATen``.
        """

        class Erf(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                erf_out = torch.special.erf(x)
                ctx.save_for_backward(erf_out)
                return erf_out

            @staticmethod
            def backward(ctx, grad_output):
                # saved_tensors is a tuple; unpack the tensor before using it.
                # forward() has a single input, so a single value is returned.
                (result,) = ctx.saved_tensors
                return torch.special.erfinv(result)

        class Caller(torch.nn.Module):
            def forward(self, input):
                return Erf.apply(input)

        model = Caller()
        sample = torch.ones(1, 5)

        # Test ONNX_FALLTHROUGH_MODE
        graph, _, _ = torch.onnx.utils._model_to_graph(
            model,
            (sample,),
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
        )
        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "prim::PythonOp")

        # Test ATEN_FALLBACK_MODE
        graph, _, _ = torch.onnx.utils._model_to_graph(
            model,
            (sample,),
            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )
        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "aten::ATen")

    def test_inline_and_symbolic(self):
        """Export a module mixing a Function with and without ``symbolic``.

        ``Exp`` defines a ``symbolic`` staticmethod that emits an ``Exp`` op;
        ``LogLog`` defines none.
        """

        class Exp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                # Save the forward argument itself (not a tensor captured from
                # the enclosing test's scope).
                ctx.save_for_backward(i)
                return i.exp()

            @staticmethod
            def symbolic(g, input):
                return g.op("Exp", input)

        class LogLog(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                ctx.save_for_backward(i)
                return i.log().log()

        class Caller(torch.nn.Module):
            def forward(self, input):
                exp_result = Exp.apply(input)
                return LogLog.apply(exp_result)

        model = Caller()
        sample = torch.ones(1)
        run_model_test(self, model, input_args=(sample,))

    def test_inline_with_scoped_tracing(self):
        """Same model as the inline/symbolic test, with a trace module map set.

        ``torch.jit._trace._trace_module_map`` is filled with a
        module -> type-name mapping for the duration of the export.
        NOTE(review): presumably this makes the tracer record module scopes --
        confirm against torch.jit._trace.
        """

        class Exp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                # Save the forward argument itself (not a tensor captured from
                # the enclosing test's scope).
                ctx.save_for_backward(i)
                return i.exp()

            @staticmethod
            def symbolic(g, input):
                return g.op("Exp", input)

        class LogLog(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                ctx.save_for_backward(i)
                return i.log().log()

        class Caller(torch.nn.Module):
            def forward(self, input):
                exp_result = Exp.apply(input)
                return LogLog.apply(exp_result)

        model = Caller()
        sample = torch.ones(1)

        torch.jit._trace._trace_module_map = {
            _m: torch.typename(type(_m)) for _m in model.modules()
        }
        try:
            run_model_test(self, model, input_args=(sample,))
        finally:
            # Always clear the process-wide map so a failing export cannot
            # leak tracing state into tests that run afterwards.
            torch.jit._trace._trace_module_map = None
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: defer to the shared PyTorch test runner.
    common_utils.run_tests()
|