diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index c574b538b4be..9f9f685231c5 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -233,3 +233,7 @@ def avg_pool2d(g, input, kernel_size, stride, padding, ceil_mode, count_include_ def log_softmax(g, input, dim=None): return g.op("Log", g.op('Softmax', input, axis_i=dim).setTypeAs(input)) + + +def unfold(g, input, dimension, size, step): + return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step)