# Source code for deeprm.utils.activations

"""
Activation function utilities for DeepRM.
"""

from deeprm.utils import check_deps

# The dependency check runs before the torch import on purpose; presumably it
# reports a missing PyTorch install more clearly than a bare ImportError
# (TODO confirm against deeprm.utils.check_deps).
check_deps.check_torch_available()
import torch.nn as nn


def get_activation_fn(activation: str) -> nn.Module:
    """
    Returns the activation function module based on the given activation name.

    Args:
        activation (str): Name of the activation function. Supported values are
            "relu", "gelu", "silu", and "elu". Matching is exact (case-sensitive).

    Returns:
        torch.nn.Module: A newly constructed activation module, created with
            its default constructor arguments.

    Raises:
        RuntimeError: If the given activation function name is not supported.
    """
    # Map each supported name to its module class; only the requested class
    # is instantiated, so every call returns a fresh module.
    activation_classes = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "elu": nn.ELU,
    }
    try:
        activation_cls = activation_classes[activation]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, so every unsupported value is
        # reported through the same RuntimeError.
        raise RuntimeError(f"The following activation function is not supported: {activation}") from None
    return activation_cls()