Source code for torch_activation.glus

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor


class ReGLU(nn.Module):
    r"""
    Applies the ReGLU activation function, defined as:

    :math:`\text{ReGLU}(x) = \text{ReLU} (xW + b) \odot (xV + c)`

    This module has no parameters of its own: ``forward`` splits the input
    into two equal halves :math:`a, b` along ``dim`` and returns
    :math:`a \odot \text{ReLU}(b)`. The two projections in the formula above
    are therefore assumed to be computed by the caller and concatenated
    along ``dim``, with the half that goes through ReLU placed second.

    See: https://doi.org/10.48550/arXiv.2002.05202

    Args:
        dim (int, optional): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of
          additional dimensions. :math:`N` must be even.
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Raises:
        ValueError: if the size of the input along ``dim`` is odd.

    Examples::

        >>> m = ReGLU()
        >>> input = torch.randn(3, 20, 20)
        >>> output = m(input)
        >>> output.shape
        torch.Size([3, 20, 10])
    """

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Split ``x`` in two along ``self.dim`` and return ``a * relu(b)``."""
        size = x.size(self.dim)
        # An odd size would make ``chunk`` return unequal pieces; for size 3
        # the (2, 1) pieces would silently broadcast into a wrong result.
        if size % 2 != 0:
            raise ValueError(
                f"ReGLU expects an even size along dim={self.dim}, got {size}"
            )
        a, b = x.chunk(2, dim=self.dim)
        return a * F.relu(b)
class GeGLU(nn.Module):
    r"""
    Applies the GeGLU activation function, defined as:

    :math:`\text{GeGLU}(x) = \text{GELU} (xW + b) \odot (xV + c)`

    This module has no parameters of its own: ``forward`` splits the input
    into two equal halves :math:`a, b` along ``dim`` and returns
    :math:`a \odot \text{GELU}(b)`. The two projections in the formula above
    are therefore assumed to be computed by the caller and concatenated
    along ``dim``, with the half that goes through GELU placed second.

    See: https://doi.org/10.48550/arXiv.2002.05202

    Args:
        dim (int, optional): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of
          additional dimensions. :math:`N` must be even.
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Raises:
        ValueError: if the size of the input along ``dim`` is odd.

    Examples::

        >>> m = GeGLU()
        >>> input = torch.randn(3, 20, 20)
        >>> output = m(input)
        >>> output.shape
        torch.Size([3, 20, 10])
    """

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Split ``x`` in two along ``self.dim`` and return ``a * gelu(b)``."""
        size = x.size(self.dim)
        # An odd size would make ``chunk`` return unequal pieces; for size 3
        # the (2, 1) pieces would silently broadcast into a wrong result.
        if size % 2 != 0:
            raise ValueError(
                f"GeGLU expects an even size along dim={self.dim}, got {size}"
            )
        # Split on the configured dimension rather than always on the last.
        a, b = x.chunk(2, dim=self.dim)
        return a * F.gelu(b)
class SwiGLU(nn.Module):
    r"""
    Applies the SwiGLU activation function, defined as:

    :math:`\text{SwiGLU}(x) = \text{Swish} (xW + b) \odot (xV + c)`

    This module has no parameters of its own: ``forward`` splits the input
    into two equal halves :math:`a, b` along ``dim`` and returns
    :math:`a \odot \text{SiLU}(b)` (SiLU is PyTorch's name for Swish with
    :math:`\beta = 1`). The two projections in the formula above are
    therefore assumed to be computed by the caller and concatenated along
    ``dim``, with the half that goes through Swish placed second.

    See: https://doi.org/10.48550/arXiv.2002.05202

    Args:
        dim (int, optional): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of
          additional dimensions. :math:`N` must be even.
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Raises:
        ValueError: if the size of the input along ``dim`` is odd.

    Examples::

        >>> m = SwiGLU()
        >>> input = torch.randn(3, 20, 20)
        >>> output = m(input)
        >>> output.shape
        torch.Size([3, 20, 10])
    """

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Split ``x`` in two along ``self.dim`` and return ``a * silu(b)``."""
        size = x.size(self.dim)
        # An odd size would make ``chunk`` return unequal pieces; for size 3
        # the (2, 1) pieces would silently broadcast into a wrong result.
        if size % 2 != 0:
            raise ValueError(
                f"SwiGLU expects an even size along dim={self.dim}, got {size}"
            )
        # Split on the configured dimension rather than always on the last.
        a, b = x.chunk(2, dim=self.dim)
        return a * F.silu(b)
class SeGLU(nn.Module):
    r"""
    Applies the SeGLU activation function, defined as:

    :math:`\text{SeGLU}(x) = \text{SELU} (xW + b) \odot (xV + c)`

    This module has no parameters of its own: ``forward`` splits the input
    into two equal halves :math:`a, b` along ``dim`` and returns
    :math:`a \odot \text{SELU}(b)`. The two projections in the formula above
    are therefore assumed to be computed by the caller and concatenated
    along ``dim``, with the half that goes through SELU placed second.

    Args:
        dim (int, optional): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of
          additional dimensions. :math:`N` must be even.
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Raises:
        ValueError: if the size of the input along ``dim`` is odd.

    Examples::

        >>> m = SeGLU()
        >>> input = torch.randn(3, 20, 20)
        >>> output = m(input)
        >>> output.shape
        torch.Size([3, 20, 10])
    """

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Split ``x`` in two along ``self.dim`` and return ``a * selu(b)``."""
        size = x.size(self.dim)
        # An odd size would make ``chunk`` return unequal pieces; for size 3
        # the (2, 1) pieces would silently broadcast into a wrong result.
        if size % 2 != 0:
            raise ValueError(
                f"SeGLU expects an even size along dim={self.dim}, got {size}"
            )
        # Split on the configured dimension rather than always on the last.
        a, b = x.chunk(2, dim=self.dim)
        return a * F.selu(b)
# Script entry-point guard: running this module directly does nothing.
if __name__ == "__main__": pass