"""
This script generates input-output plots for all of the activation functions.
These are for use in the documentation, and potentially in online tutorials.
"""

from pathlib import Path

import matplotlib

# Select the non-interactive "Agg" backend *before* pyplot is imported, so the
# script renders straight to image files without needing a display.
matplotlib.use("Agg")

from matplotlib import pyplot as plt  # noqa: E402  (must follow matplotlib.use)
import torch


# Create a directory for the images, if it doesn't exist. ``exist_ok=True``
# avoids a check-then-create race if two builds run at the same time.
ACTIVATION_IMAGE_PATH = Path(__file__).parent / "activation_images"
ACTIVATION_IMAGE_PATH.mkdir(exist_ok=True)

# Half-width of the plotted window on both axes. It must go beyond 6 so that
# the saturation point of ReLU6 is visible.
PLOT_LIMIT = 7.0

# In a refactor, these ought to go into their own module or entry
# points so we can generate this list programmatically
functions = [
    torch.nn.ELU(),
    torch.nn.Hardshrink(),
    torch.nn.Hardtanh(),
    torch.nn.Hardsigmoid(),
    torch.nn.Hardswish(),
    torch.nn.LeakyReLU(negative_slope=0.1),
    torch.nn.LogSigmoid(),
    torch.nn.PReLU(),
    torch.nn.ReLU(),
    torch.nn.ReLU6(),
    torch.nn.RReLU(),
    torch.nn.SELU(),
    torch.nn.SiLU(),
    torch.nn.Mish(),
    torch.nn.CELU(),
    torch.nn.GELU(),
    torch.nn.Sigmoid(),
    torch.nn.Softplus(),
    torch.nn.Softshrink(),
    torch.nn.Softsign(),
    torch.nn.Tanh(),
    torch.nn.Tanhshrink(),
]


def plot_function(function, **args):
    """
    Plot a function on the current plot.

    ``function`` is called on a 1-D tensor of inputs running from
    ``-PLOT_LIMIT`` up to (but not including) ``PLOT_LIMIT`` in steps of 0.01.
    The additional keyword arguments are forwarded to ``plt.plot`` and may be
    used to specify color, alpha, etc.
    """
    xs = torch.arange(-PLOT_LIMIT, PLOT_LIMIT, 0.01)
    # detach(): modules with learnable parameters (e.g. PReLU) return outputs
    # that require grad, and such tensors cannot be converted with .numpy().
    ys = function(xs).detach()
    plt.plot(xs.numpy(), ys.numpy(), **args)


# Step through all the functions. Images that already exist on disk are kept
# and not regenerated.
for function in functions:
    function_name = type(function).__name__
    plot_path = ACTIVATION_IMAGE_PATH / f"{function_name}.png"
    if plot_path.exists():
        continue

    # nn.Module instances start in training mode, where RReLU samples a random
    # negative slope for every element and would draw a noisy, non-reproducible
    # curve. Evaluation mode uses its fixed slope (lower + upper) / 2 instead.
    function.eval()

    # Start a new plot
    plt.clf()
    plt.grid(color="k", alpha=0.2, linestyle="--")

    # Plot the current function
    plot_function(function)

    # The module's repr (e.g. "ELU(alpha=1.0)") doubles as a descriptive title.
    plt.title(str(function))
    plt.xlabel("Input")
    plt.ylabel("Output")
    plt.xlim([-PLOT_LIMIT, PLOT_LIMIT])
    plt.ylim([-PLOT_LIMIT, PLOT_LIMIT])

    # And save it
    plt.savefig(plot_path)
    print(f"Saved activation image for {function_name} at {plot_path}")