Source code for torch_activation.lincomb

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from typing import Iterable


class LinComb(nn.Module):
    r"""
    Applies the LinComb activation function:

    :math:`\text{LinComb}(x) = \sum_{i=1}^{n} w_i \cdot F_i(x)`

    Each :math:`F_i` is one of the supplied activation modules and each
    :math:`w_i` is a learnable scalar initialised uniformly in ``[-1, 1]``.

    See: https://doi.org/10.20944/preprints202301.0463.v1

    Args:
        activations (Iterable[nn.Module]): Activation modules to combine.
            Must contain at least one module.

    Raises:
        ValueError: If ``activations`` is empty.

    Shape:
        - Input: :math:`(*)` where :math:`*` means any number of additional dimensions.
        - Output: :math:`(*)`

    Examples::

        >>> activations = [nn.ReLU(), nn.Sigmoid()]
        >>> m = LinComb(activations)
        >>> input = torch.randn(10)
        >>> output = m(input)
    """

    def __init__(self, activations: Iterable[nn.Module]):
        super(LinComb, self).__init__()
        self.activations = nn.ModuleList(activations)

        # Size the weights from the ModuleList rather than the raw argument so
        # that generators and other unsized iterables are accepted.
        num_activations = len(self.activations)
        if num_activations == 0:
            raise ValueError("LinComb requires at least one activation function")

        # One learnable coefficient per activation, drawn from U(-1, 1).
        self.weights = nn.Parameter(torch.empty(num_activations))
        nn.init.uniform_(self.weights, -1.0, 1.0)

    def forward(self, input) -> Tensor:
        """Return the weighted sum of every activation applied to ``input``."""
        weighted = [
            self.weights[i] * activation(input)
            for i, activation in enumerate(self.activations)
        ]
        # Stack along a new leading axis and reduce it, keeping input's shape.
        return torch.sum(torch.stack(weighted), dim=0)
class NormLinComb(nn.Module):
    r"""
    Applies the NormLinComb activation function:

    :math:`\text{NormLinComb}(x) = \frac{\sum_{i=1}^{n} w_i \cdot F_i(x)}{\|W\|}`

    Each :math:`F_i` is one of the supplied activation modules, each
    :math:`w_i` is a learnable scalar initialised uniformly in ``[-1, 1]``,
    and :math:`\|W\|` is the norm of the weight vector as computed by
    ``torch.norm``.

    See: https://doi.org/10.20944/preprints202301.0463.v1

    Args:
        activations (Iterable[nn.Module]): Activation modules to combine.
            Must contain at least one module.

    Raises:
        ValueError: If ``activations`` is empty.

    Shape:
        - Input: :math:`(*)` where :math:`*` means any number of additional dimensions.
        - Output: :math:`(*)`

    Examples::

        >>> activations = [nn.ReLU(), nn.Sigmoid()]
        >>> m = NormLinComb(activations)
        >>> input = torch.randn(10)
        >>> output = m(input)
    """

    def __init__(self, activations: Iterable[nn.Module]):
        super(NormLinComb, self).__init__()
        self.activations = nn.ModuleList(activations)

        # Size the weights from the ModuleList rather than the raw argument so
        # that generators and other unsized iterables are accepted.
        num_activations = len(self.activations)
        if num_activations == 0:
            raise ValueError("NormLinComb requires at least one activation function")

        # One learnable coefficient per activation, drawn from U(-1, 1).
        self.weights = nn.Parameter(torch.empty(num_activations))
        nn.init.uniform_(self.weights, -1.0, 1.0)

    def forward(self, input) -> Tensor:
        """Return the weighted sum of the activations divided by the weight norm."""
        weighted = [
            self.weights[i] * activation(input)
            for i, activation in enumerate(self.activations)
        ]
        # Stack along a new leading axis and reduce it, keeping input's shape.
        output = torch.sum(torch.stack(weighted), dim=0)
        # Normalise by the norm of the weight vector. No epsilon is added, so
        # an all-zero weight vector yields inf/NaN.
        return output / torch.norm(self.weights)