mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This reverts commit 45411d1fc9a2b6d2f891b6ab0ae16409719e09fc. Reverted https://github.com/pytorch/pytorch/pull/129409 on behalf of https://github.com/jeanschmidt due to Breaking internal CI, @albanD please help get this PR merged ([comment](https://github.com/pytorch/pytorch/pull/129409#issuecomment-2571316444))
82 lines
2.1 KiB
Python
82 lines
2.1 KiB
Python
"""
|
|
This script will generate input-output plots for all of the activation
|
|
functions. These are for use in the documentation, and potentially in
|
|
online tutorials.
|
|
"""
|
|
|
|
from pathlib import Path
|
|
|
|
import matplotlib
|
|
from matplotlib import pyplot as plt
|
|
|
|
import torch
|
|
|
|
|
|
# Select matplotlib's non-interactive Agg backend: the plots are only ever
# written to image files, so no display is required.
matplotlib.use("Agg")

# Create a directory for the images, if it doesn't exist. exist_ok=True makes
# this a single race-free call instead of a separate exists() check followed
# by mkdir().
ACTIVATION_IMAGE_PATH = Path(__file__).parent / "activation_images"
ACTIVATION_IMAGE_PATH.mkdir(exist_ok=True)

# In a refactor, these ought to go into their own module or entry
# points so we can generate this list programmatically.
# One PNG is produced per entry, named after the module's class name.
functions = [
    torch.nn.ELU(),
    torch.nn.Hardshrink(),
    torch.nn.Hardtanh(),
    torch.nn.Hardsigmoid(),
    torch.nn.Hardswish(),
    # Presumably a larger slope than the 0.01 default so the negative side is
    # visibly non-flat in the plot.
    torch.nn.LeakyReLU(negative_slope=0.1),
    torch.nn.LogSigmoid(),
    torch.nn.PReLU(),
    torch.nn.ReLU(),
    torch.nn.ReLU6(),
    # NOTE(review): RReLU is left in its default training mode, where its
    # negative slopes are sampled randomly on every call, so this plot is not
    # deterministic across runs — confirm that is intended.
    torch.nn.RReLU(),
    torch.nn.SELU(),
    torch.nn.SiLU(),
    torch.nn.Mish(),
    torch.nn.CELU(),
    torch.nn.GELU(),
    torch.nn.Sigmoid(),
    torch.nn.Softplus(),
    torch.nn.Softshrink(),
    torch.nn.Softsign(),
    torch.nn.Tanh(),
    torch.nn.Tanhshrink(),
]
|
|
|
|
|
|
def plot_function(function, **args):
    """Draw ``function`` over inputs sampled from [-7, 7) on the current plot.

    Any extra keyword arguments (color, alpha, ...) are forwarded unchanged to
    ``plt.plot``.
    """
    # The sampled span extends past +/-6 so that ReLU6's saturation is visible.
    inputs = torch.arange(-7.0, 7.0, 0.01)
    # Modules with learnable parameters (such as PReLU) return tensors that
    # require grad; detach() lets .numpy() convert them.
    outputs = function(inputs).detach()
    plt.plot(inputs.numpy(), outputs.numpy(), **args)
|
|
|
|
|
|
# Render one PNG per activation module. Images already on disk are skipped, so
# a re-run only produces the missing files.
for function in functions:
    function_name = function._get_name()
    plot_path = ACTIVATION_IMAGE_PATH / f"{function_name}.png"
    if plot_path.exists():
        continue

    # Wipe the figure left over from the previous iteration and lay down a
    # faint dashed grid before drawing the curve.
    plt.clf()
    plt.grid(color="k", alpha=0.2, linestyle="--")
    plot_function(function)

    # The title is str(function), i.e. the module's printed form.
    plt.title(function)
    plt.xlabel("Input")
    plt.ylabel("Output")
    # Fixed square viewing window covering the sampled input span.
    plt.xlim([-7, 7])
    plt.ylim([-7, 7])

    # Write the image, then report where it went.
    plt.savefig(plot_path)
    print(f"Saved activation image for {function_name} at {plot_path}")
|