# Source code for torch_activation.piece_wise

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor

class DELU(nn.Module):
    r"""
    Applies the DELU activation function:

    :math:`\text{DELU}(x) = \begin{cases} \text{SiLU}(x), & x \leqslant 0 \\ (n + 0.5)x + |e^{-x} - 1|, & \text{otherwise} \end{cases}`

    See: https://doi.org/10.20944/preprints202301.0463.v1

    Args:
        n (float, optional): Scaling factor for the positive part of the input.
            Stored as a 0-d :class:`torch.nn.Parameter`, so it is trainable.
            Default: 1.0.
        inplace (bool, optional): can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = DELU()
        >>> x = torch.randn(2)
        >>> output = m(x)

        >>> m = DELU(inplace=True)
        >>> x = torch.randn(2)
        >>> m(x)
    """

    def __init__(self, n: float = 1.0, inplace: bool = False):
        super().__init__()
        # Lower-case ``torch.tensor`` stores the *value* n as a 0-d tensor.
        # The upper-case ``Tensor(n)`` constructor raises TypeError for a float
        # and, for an int, allocates an uninitialised tensor of size n instead.
        self.n = nn.Parameter(torch.tensor(float(n)))
        self.inplace = inplace

    def forward(self, x: Tensor) -> Tensor:
        """Apply DELU, writing into ``x`` when ``inplace`` is set."""
        return self._forward_inplace(x) if self.inplace else self._forward(x)

    def _forward(self, x: Tensor) -> Tensor:
        """Out-of-place DELU; returns a new tensor."""
        # torch.where evaluates BOTH branches over the whole tensor. Computing
        # the positive branch on clamp(x, min=0) keeps exp(-x) <= 1. Without
        # this, very negative inputs overflow exp(-x) to inf, and backward
        # yields NaN gradients (0 * inf) via the unselected branch. Values for
        # x > 0 are unchanged because the clamp is the identity there.
        x_pos = x.clamp(min=0)
        positive = (self.n + 0.5) * x_pos + torch.abs(torch.exp(-x_pos) - 1)
        return torch.where(x <= 0, F.silu(x), positive)

    def _forward_inplace(self, x: Tensor) -> Tensor:
        """In-place DELU; overwrites and returns ``x``."""
        # Build both masks from the original values before mutating x.
        # Elements that are neither (NaN) are left untouched.
        neg = x <= 0
        pos = x > 0
        x[neg] = F.silu(x[neg])
        x_pos = x[pos]
        x[pos] = (self.n + 0.5) * x_pos + torch.abs(torch.exp(-x_pos) - 1)
        return x
# NOTE(review): the following is a bare module-level string literal, not code.
# The ``DReLUs`` class written inside it is therefore never defined or exported.
# Its layout is mangled: the class docstring's own quotes are gone, so the
# docstring text and the method bodies run together. If this class is revived:
#   * ``__init__`` must call ``super().__init__()`` before assigning attributes
#     (nn.Module requires this).
#   * The docstring says ``alpha`` scales the *positive* part, but the formula
#     and both forward paths apply it only where x <= 0.
#   * The example uses ``nn.DReLUs``; torch.nn has no such class.
#   * The docstring should be a raw string (r"""...""") so LaTeX such as
#     ``\text``/``\begin``/``\alpha`` is not read as escape sequences.
""" class DReLUs(nn.Module): Applies the DReLUs activation function: :math:`\text{DReLUs}(x) = \begin{cases} \alpha (e ^ x - 1), x \leqslant 0 \\x, \text{otherwise} \end{cases}` Args: alpha (float, optional): Scaling factor for the positive part of the input. Default: 1.0. inplace (bool, optional): can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. Examples: >>> m = nn.DReLUs() >>> x = torch.randn(2) >>> output = m(x) def __init__(self, alpha: float = 1.0, inplace: bool = False): self.alpha = alpha self.inplace = inplace def forward(self, x) -> Tensor: return self._forward_inplace(x) if self.inplace else self._forward(x) def _forward(self, x): return torch.where(x > 0, x, self.alpha * (torch.exp(x) - 1)) def _forward_inplace(self, x): x[x <= 0] = (torch.exp(x[x <= 0]) - 1) * self.alpha return x """