# Source code for torch_activation.utils

import os
import torch
import psutil
import plotly
import plotly.graph_objects as go


def plot_activation(
    activation: torch.nn.Module,
    params: dict,
    save_dir="./images/activation_images",
    x_range=(-5, 5),
    y_range=None,
    preview=False,
    plot_derivative=True,
):
    """
    Plot an activation function and, optionally, its derivative.

    One curve is drawn per entry of ``params``: for every ``name: value`` pair
    the activation class is instantiated with that single keyword argument and
    evaluated on 1000 evenly spaced points spanning ``x_range``. When ``params``
    is empty, a single curve is drawn for the activation with its defaults.

    Parameters:
        activation (torch.nn.Module): The activation to plot. Either the module
            class itself (e.g. ``torch.nn.ELU``) or an instance of it. When an
            instance is given together with ``params``, its class is
            re-instantiated for each parameter; with empty ``params`` the
            instance is plotted as-is.
        params (dict): A dictionary of parameter names and values for the
            activation function. Each entry produces its own curve.
        save_dir (str, optional): The directory to save the generated image. It
            is created if it does not exist. Defaults to
            "./images/activation_images".
        x_range (tuple, optional): The x-axis range for the plot. Defaults to
            (-5, 5).
        y_range (tuple, optional): The y-axis range for the plot. Defaults to
            None, in which case the range covers every plotted curve
            (derivatives included) plus a 10% padding.
        preview (bool, optional): Whether to display the plot interactively.
            Defaults to False.
        plot_derivative (bool, optional): Whether to plot the derivative of the
            activation function. Defaults to True.

    Returns:
        None

    The resulting plot is saved as ``<ActivationClassName>.png`` inside
    ``save_dir``. If ``preview`` is set to True, the plot is also displayed
    interactively before it is saved.

    Example::

        >>> # ELU and its derivative for alpha = 1.0
        >>> plot_activation(torch.nn.ELU, {'alpha': 1.0})
        >>> # A parameter-free activation, passed as an instance
        >>> plot_activation(torch.nn.Sigmoid(), {})
    """
    # Accept both a module class and a module instance; the class is what gets
    # instantiated per parameter and what names the output file.
    if isinstance(activation, type):
        activation_cls, given_instance = activation, None
    else:
        activation_cls, given_instance = type(activation), activation
    activation_name = activation_cls.__name__

    if params:
        curves = [
            (f"{param_name}: {param_value}", activation_cls(**{param_name: param_value}))
            for param_name, param_value in params.items()
        ]
    else:
        # Nothing to sweep over: plot the activation once with its defaults.
        module = given_instance if given_instance is not None else activation_cls()
        curves = [(activation_name, module)]

    # x has to track gradients, otherwise autograd cannot differentiate the
    # activation output with respect to it.
    x = torch.linspace(x_range[0], x_range[1], 1000, requires_grad=bool(plot_derivative))
    x_values = x.detach().numpy()

    fig = go.Figure()

    # Color for each curve; cycle through the palette so that no curve is
    # dropped when there are more curves than palette entries.
    palette = plotly.colors.qualitative.D3

    plotted = []  # every tensor drawn so far, used to auto-scale the y axis
    for index, (label, m) in enumerate(curves):
        color = palette[index % len(palette)]
        y = m(x)
        plotted.append(y.detach())

        fig.add_trace(go.Scatter(x=x_values, y=y.detach().numpy(),
                                 name=label, line=dict(color=color)))

        if plot_derivative:
            # autograd only differentiates scalars implicitly. Summing first
            # yields dy_i/dx_i per point, which is the derivative as long as
            # the activation acts element-wise.
            d = None
            if y.requires_grad:
                d = torch.autograd.grad(y.sum(), x, allow_unused=True)[0]
            if d is None:
                # The output does not depend on x in the autograd graph.
                d = torch.zeros_like(x)
            d = d.detach()
            plotted.append(d)

            fig.add_trace(go.Scatter(x=x_values, y=d.numpy(),
                                     name=f"Derivative {label}",
                                     line=dict(color=color, dash='dash')))

    # Determine the y-axis range from everything that was plotted, not just
    # the first curve, so that no trace is clipped.
    if y_range is None:
        values = torch.cat([t.flatten() for t in plotted])
        finite = values[torch.isfinite(values)]
        if finite.numel() > 0:
            y_min = torch.min(finite).item()
            y_max = torch.max(finite).item()
            y_padding = (y_max - y_min) * 0.1  # Add a 10% padding
            if y_padding == 0:
                # Flat curve: avoid a zero-height axis.
                y_padding = 1.0
            y_range = (y_min - y_padding, y_max + y_padding)
        # Otherwise leave y_range as None and let plotly auto-scale.

    fig.update_layout(xaxis=dict(range=x_range),
                      yaxis=dict(range=y_range),
                      legend=dict(title="Params"))

    if preview:
        fig.show()

    # write_image does not create missing directories itself.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    file_name = os.path.join(save_dir, activation_name + '.png')
    fig.write_image(file_name)
    print(f"Image saved as {file_name}")
# Script entry-point guard: running this module directly intentionally does nothing.
if __name__ == "__main__": pass