Files
pytorch/docs/source/scripts/build_activation_images.py
zeshengzong c9c0f8eae3 Add plot for torch.nn.Threshold and torch.nn.GLU (#150171)
Fixes #150170

## Changes

- Add plots for `torch.nn.Threshold` and `torch.nn.GLU` (see the usage sketch below)
- Add example outputs so users can more easily see what each function returns
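For reference, a minimal sketch (not part of the diff) of what the two newly plotted modules compute; the sample tensors here are illustrative only:

```python
import torch

# Threshold(threshold, value): keeps x where x > threshold,
# otherwise substitutes `value`. The new plot uses Threshold(0, 0.5).
thresh = torch.nn.Threshold(0, 0.5)
x = torch.tensor([-2.0, 0.0, 1.0, 3.0])
print(thresh(x))  # tensor([0.5000, 0.5000, 1.0000, 3.0000])

# GLU splits its input in half along `dim` (default -1) and returns
# a * sigmoid(b), so the last dimension of the output is halved.
glu = torch.nn.GLU()
y = torch.randn(4, 8)
print(glu(y).shape)  # torch.Size([4, 4])
```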

## Test Result

![image](https://github.com/user-attachments/assets/f6c5bc46-f9b7-4db7-9797-e08d8423d1b3)

![image](https://github.com/user-attachments/assets/ad4e6c84-7b29-44f1-b7bd-9c81e4a92ef8)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150171
Approved by: https://github.com/albanD
2025-04-08 03:55:37 +00:00

90 lines
2.3 KiB
Python

"""
This script will generate input-output plots for all of the activation
functions. These are for use in the documentation, and potentially in
online tutorials.
"""
from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt
import torch

matplotlib.use("Agg")

# Create a directory for the images, if it doesn't exist
ACTIVATION_IMAGE_PATH = Path(__file__).parent / "activation_images"
if not ACTIVATION_IMAGE_PATH.exists():
    ACTIVATION_IMAGE_PATH.mkdir()

# In a refactor, these ought to go into their own module or entry
# points so we can generate this list programmatically
functions = [
    torch.nn.ELU(),
    torch.nn.Hardshrink(),
    torch.nn.Hardtanh(),
    torch.nn.Hardsigmoid(),
    torch.nn.Hardswish(),
    torch.nn.LeakyReLU(negative_slope=0.1),
    torch.nn.LogSigmoid(),
    torch.nn.PReLU(),
    torch.nn.ReLU(),
    torch.nn.ReLU6(),
    torch.nn.RReLU(),
    torch.nn.SELU(),
    torch.nn.SiLU(),
    torch.nn.Mish(),
    torch.nn.CELU(),
    torch.nn.GELU(),
    torch.nn.Sigmoid(),
    torch.nn.Softplus(),
    torch.nn.Softshrink(),
    torch.nn.Softsign(),
    torch.nn.Tanh(),
    torch.nn.Tanhshrink(),
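    # Plots for Threshold and GLU were added in #150171; Threshold has no
    # default arguments, so a sample threshold/value pair is supplied here.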
    torch.nn.Threshold(0, 0.5),
    torch.nn.GLU(),
]


def plot_function(function, **args):
    """
    Plot a function on the current plot. The additional arguments may
    be used to specify color, alpha, etc.
    """
    if isinstance(function, torch.nn.GLU):
        xrange = torch.arange(-7.0, 7.0, 0.01).unsqueeze(1).repeat(1, 2)
        x = xrange.numpy()[:, 0]
    else:
        xrange = torch.arange(-7.0, 7.0, 0.01)  # We need to go beyond 6 for ReLU6
        x = xrange.numpy()
    y = function(xrange).detach().numpy()
    plt.plot(x, y, **args)


# Step through all the functions
for function in functions:
    function_name = function._get_name()
    plot_path = ACTIVATION_IMAGE_PATH / f"{function_name}.png"
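    # Only render plots that don't already exist, so repeated builds are cheap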
    if not plot_path.exists():
        # Start a new plot
        plt.clf()
        plt.grid(color="k", alpha=0.2, linestyle="--")

        # Plot the current function
        plot_function(function)

        plt.title(function)
        plt.xlabel("Input")
        plt.ylabel("Output")
        plt.xlim([-7, 7])
        plt.ylim([-7, 7])

        # And save it
        plt.savefig(plot_path)

        print(f"Saved activation image for {function_name} at {plot_path}")