mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Format all onnx python code with black and isort with ```sh isort torch/onnx/ test/onnx black torch/onnx/ test/onnx ``` Updated lintrunner config to include these paths. Pull Request resolved: https://github.com/pytorch/pytorch/pull/76754 Approved by: https://github.com/suo, https://github.com/BowenBao
20 lines
518 B
Python
20 lines
518 B
Python
# Owner(s): ["module: onnx"]
|
|
|
|
import torch
|
|
|
|
|
|
# Autograd function that is a replica of the autograd function in
|
|
# test_utility_funs.py (test_autograd_module_name)
|
|
class CustomFunction(torch.autograd.Function):
    """Custom autograd function that clamps its input below at zero.

    ``forward`` returns ``input.clamp(min=0)``; ``backward`` forwards the
    upstream gradient unchanged except at positions where the saved input
    was negative, where the gradient is set to zero.
    """

    @staticmethod
    def forward(ctx, input):
        """Clamp ``input`` to be non-negative and remember it for backward."""
        output = input.clamp(min=0)
        # The raw input (not the clamped output) is what backward inspects.
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` with entries zeroed where the input was < 0."""
        (saved_input,) = ctx.saved_tensors
        negative_mask = saved_input < 0
        # Clone first so the incoming gradient tensor is never mutated.
        masked_grad = grad_output.clone()
        masked_grad[negative_mask] = 0
        return masked_grad
|