"""
|
|
This script will generate input-out plots for all of the activation
|
|
functions. These are for use in the documentation, and potentially in
|
|
online tutorials.
|
|
"""

import os.path

import torch.nn.modules.activation
import torch.autograd
import matplotlib

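# Use a non-interactive backend so the script can render plots headlessly
# (e.g. on a build machine with no display)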
matplotlib.use('Agg')

import pylab

# Create a directory for the images, if it doesn't exist
ACTIVATION_IMAGE_PATH = os.path.join(
    os.path.realpath(os.path.join(__file__, "..")),
    "activation_images"
)

if not os.path.exists(ACTIVATION_IMAGE_PATH):
    os.mkdir(ACTIVATION_IMAGE_PATH)

# In a refactor, these ought to go into their own module or entry
# points so we can generate this list programmatically
functions = [
    'ELU',
    'Hardshrink',
    'Hardtanh',
    'Hardsigmoid',
    'Hardswish',
    'LeakyReLU',  # Perhaps we should add text explaining the slight negative slope?
    'LogSigmoid',
    'PReLU',
    'ReLU',
    'ReLU6',
    'RReLU',
    'SELU',
    'SiLU',
    'Mish',
    'CELU',
    'GELU',
    'Sigmoid',
    'Softplus',
    'Softshrink',
    'Softsign',
    'Tanh',
    'Tanhshrink',
]
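
# Each name above is a Module class in torch.nn.modules.activation that can
# be constructed with no arguments; the loop below instantiates each by name.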


def plot_function(function, **kwargs):
    """
    Plot a function on the current plot. Any additional keyword arguments
    are forwarded to pylab.plot and may be used to specify color, alpha, etc.
    """
    xrange = torch.arange(-7.0, 7.0, 0.01)  # We need to go beyond 6 for ReLU6
    pylab.plot(
        xrange.numpy(),
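        # detach() is needed because activations with learnable parameters
        # (e.g. PReLU) return outputs still attached to the autograd graph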
        function(xrange).detach().numpy(),
        **kwargs
    )


# Step through all the functions
for function_name in functions:
    plot_path = os.path.join(ACTIVATION_IMAGE_PATH, function_name + ".png")
    if not os.path.exists(plot_path):
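        # Look up the activation class by name and instantiate it with its
        # default arguments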
        function = torch.nn.modules.activation.__dict__[function_name]()

        # Start a new plot
        pylab.clf()
        pylab.grid(color='k', alpha=0.2, linestyle='--')

        # Plot the current function
        plot_function(function)

        # The titles are a little redundant, given the context?
        pylab.title(function_name + " activation function")
        pylab.xlabel("Input")
        pylab.ylabel("Output")
        pylab.xlim([-7, 7])
        pylab.ylim([-7, 7])

        # And save it
        pylab.savefig(plot_path)
        print('Saved activation image for {} at {}'.format(function_name, plot_path))