Source code for deeprm.model.deeprm_model

"""
Module: deeprm.model.deeprm_model
This module defines the DeepRM model architecture, including the ResNet block,
Transformer model, positional encoding, and regression head.
"""

import math

from deeprm.utils import check_deps

check_deps.check_torch_available()


import torch  # noqa
from torch import Tensor, nn  # noqa

from deeprm.utils.activations import get_activation_fn  # noqa


[docs] class ResNetBlock(nn.Module): """ A 1D ResNet block for 1D DeepRM. Args: in_channels (int): Number of input channels. out_channels (int): Number of output channels. hidden_channels (int): Number of hidden channels. If None, set to out_channels. (optional) kernel_size (int): Kernel size for the middle convolutional layer. Default is 3. (optional) stride (int): Stride for the convolutional layers. Default is 1. (optional) activation (str): Activation function to use. Default is 'gelu'. (optional) dropout (float): Dropout rate. Default is 0.1. (optional) groups (int): Number of groups for grouped convolution. Default is 1. (optional) Attributes: bn1 (torch.nn.BatchNorm1d): Batch normalization layer for the first convolution. activation (typing.Callable): Activation function. conv1 (torch.nn.Conv1d): First convolutional layer with kernel size 1. bn2 (torch.nn.BatchNorm1d): Batch normalization layer for the second convolution. conv2 (torch.nn.Conv1d): Second convolutional layer with specified kernel size and groups. bn3 (torch.nn.BatchNorm1d): Batch normalization layer for the third convolution. dropout (torch.nn.Dropout): Dropout layer. conv3 (torch.nn.Conv1d): Third convolutional layer with kernel size 1. shortcut (torch.nn.Module): Shortcut connection to match input and output dimensions. """ def __init__( self, in_channels: int, out_channels: int, hidden_channels: int = None, kernel_size: int = 3, stride: int = 1, activation: str = "gelu", dropout: float = 0.1, groups: int = 1, ) -> None: super().__init__() if hidden_channels is None: hidden_channels = out_channels self.bn1 = nn.BatchNorm1d(out_channels) self.activation = get_activation_fn(activation) self.conv1 = nn.Conv1d(in_channels, hidden_channels, kernel_size=1, stride=stride, padding="same") self.bn2 = nn.BatchNorm1d(out_channels) self.conv2 = nn.Conv1d(hidden_channels, hidden_channels, kernel_size, stride, padding="same", groups=groups) self.bn3 = nn.BatchNorm1d(out_channels) self.dropout = nn.Dropout(dropout) self.conv3 = nn.Conv1d(hidden_channels, out_channels, kernel_size=1, stride=stride, padding="same") if in_channels != out_channels or stride != 1: self.shortcut = nn.Sequential( nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride), nn.BatchNorm1d(out_channels) ) else: self.shortcut = nn.Identity()
[docs] def forward(self, x: Tensor) -> Tensor: """ Forward pass through the ResNet block. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_channels, sequence_length). Returns: torch.Tensor: Output tensor of shape (batch_size, out_channels, sequence_length). """ residual = self.shortcut(x) out = self.bn1(x) out = self.activation(out) out = self.conv1(out) out = self.bn2(out) out = self.activation(out) out = self.conv2(out) out = self.bn3(out) out = self.activation(out) out = self.dropout(out) out = self.conv3(out) out += residual return out
[docs] class TransformerModel(nn.Module): """ A Transformer model for DeepRM. Args: d_model (int): Dimension of the model. n_heads (int): Number of attention heads. d_ff (int): Dimension of the feed-forward network. n_layers (int): Number of encoder layers. encoder_dropout (float): Dropout rate for the encoder. Default is 0.1. (optional) lin_dropout (float): Dropout rate for the linear layers. Default is 0.1. (optional) kmer_size (int): Size of the k-mer. Default is 5. (optional) signal_size (int): Size of the signal input. Default is 25. (optional) block_len (int): Length of the block. Default is 17. (optional) seq_len (int): Length of the sequence. Default is 200. (optional) t_act (str): Activation function for the transformer. Default is 'gelu'. (optional) lin_act (str): Activation function for the linear layers. Default is 'relu'. (optional) lin_depth (int): Depth of the linear layers. Default is 1. (optional) signal_stride (int): Stride for the signal input. Default is 6. (optional) **kwargs: Additional keyword arguments. Attributes: kmer_embedding (torch.nn.Embedding): Embedding layer for k-mer sequences. signal_embedding (torch.nn.Linear): Linear layer for signal input. pos_encoding (PositionalEncoding): Positional encoding layer. cnn_encoder (torch.nn.Sequential): Sequential container for CNN encoder blocks. transformer_encoder (torch.nn.TransformerEncoder): Transformer encoder. regression_head (RegressionHead): Regression head for the model output. d_model (int): Dimension of the model. model_type (str): Type of the model, set to 'Transformer'. kmer_size (int): Size of the k-mer. signal_stride (int): Stride for the signal input. unit_size (int): Size of the unit for processing sequences. target_start_idx (int): Start index for the target in the sequence. target_end_idx (int): End index for the target in the sequence. seq_len (int): Length of the sequence. block_len (int): Length of the block. """ def __init__( self, d_model: int, n_heads: int, d_ff: int, n_layers: int, encoder_dropout: float = 0.1, lin_dropout: float = 0.1, kmer_size: int = 5, signal_size: int = 25, block_len=17, seq_len: int = 200, t_act: str = "gelu", lin_act: str = "relu", lin_depth: int = 1, signal_stride=6, **kwargs, ) -> None: super().__init__() ## Embedding Initialization self.kmer_embedding = nn.Embedding(4**kmer_size + 1, d_model) self.signal_embedding = nn.Linear(signal_size + 3, d_model) self.pos_encoding = PositionalEncoding(d_model, seq_len) ## Encoder Initialization self.d_model = d_model self.model_type = "Transformer" encoder_layer = nn.TransformerEncoderLayer( d_model, n_heads, d_ff, dropout=encoder_dropout, activation=t_act, batch_first=True ) self.cnn_encoder = nn.Sequential() self.cnn_encoder.add_module( "resnet_1", ResNetBlock(d_model, d_model, kernel_size=5, stride=1, groups=8, activation=t_act, dropout=encoder_dropout), ) self.cnn_encoder.add_module( "resnet_2", ResNetBlock( d_model, d_model, kernel_size=15, stride=1, groups=8, activation=t_act, dropout=encoder_dropout ), ) self.cnn_encoder.add_module( "resnet_3", ResNetBlock(d_model, d_model, kernel_size=5, stride=1, groups=8, activation=t_act, dropout=encoder_dropout), ) self.cnn_encoder.add_module( "resnet_4", ResNetBlock( d_model, d_model, kernel_size=15, stride=1, groups=8, activation=t_act, dropout=encoder_dropout ), ) self.cnn_encoder = nn.SyncBatchNorm.convert_sync_batchnorm(self.cnn_encoder) self.transformer_encoder = nn.TransformerEncoder(encoder_layer, n_layers) ## Regression Head Initialization self.regression_head = RegressionHead(d_model, lin_act, lin_depth, lin_dropout, seq_len) self.regression_head = nn.SyncBatchNorm.convert_sync_batchnorm(self.regression_head) ## Weight Initialization self.init_weights() self.kmer_size = kmer_size self.signal_stride = signal_stride self.unit_size = int((seq_len + kmer_size - 1) / block_len) self.target_start_idx = (block_len // 2) * self.unit_size - (kmer_size // 2) self.target_end_idx = self.target_start_idx + self.unit_size self.seq_len = seq_len self.block_len = block_len
[docs] def init_weights(self, initrange=0.1): """ Initialize the weights of the model. Args: initrange (float): Range for uniform initialization of weights. Default is 0.1. (optional) Returns: None """ self.kmer_embedding.weight.data.uniform_(-initrange, initrange) self.signal_embedding.weight.data.uniform_(-initrange, initrange) self.regression_head.init_weights(initrange) return None
[docs] def process_kmer(self, src_kmer: Tensor, src_seg_len_flat: Tensor) -> Tensor: """ Process the k-mer input to convert nucleotide characters to numerical indices. Args: src_kmer (torch.Tensor): Input tensor of shape (batch_size, seq_len) containing nucleotide characters. src_seg_len_flat (torch.Tensor): Flattened segment lengths for the input sequences. Returns: torch.Tensor: Processed k-mer tensor of shape (batch_size, seq_len) with numerical indices. """ batch = src_kmer.shape[0] src_kmer = (src_kmer - 65).clip(None, 8) % 5 ## Convert ACGTU to 01233. src_kmer = src_kmer.unfold(1, self.kmer_size, 1) src_kmer = src_kmer * (4 ** torch.arange(self.kmer_size, device=src_kmer.device, dtype=torch.int)).unsqueeze( 0 ).unsqueeze(0) src_kmer = src_kmer.sum(dim=-1) + 1 src_kmer = torch.cat([src_kmer, torch.zeros(batch, 1, device=src_kmer.device, dtype=torch.int)], dim=1) src_kmer = src_kmer.flatten() src_kmer = src_kmer.repeat_interleave(src_seg_len_flat) src_kmer = src_kmer.reshape(batch, self.seq_len) src_kmer = src_kmer.int() return src_kmer
[docs] def process_signal(self, src_signal: Tensor) -> Tensor: """ Process the signal input by unfolding it into segments based on the signal stride and k-mer size. Args: src_signal (torch.Tensor): Input tensor of shape (batch_size, seq_len, signal_size) containing signal data. Returns: torch.Tensor: Processed signal tensor of shape (batch_size, new_seq_len, signal_size) after unfolding. """ src_signal = src_signal.unfold(1, self.signal_stride * self.kmer_size, self.signal_stride) return src_signal
[docs] def flatten_seg_len(self, src_seg_len: Tensor) -> Tensor: """ Flatten the segment lengths to create a single dimension for each sequence. Args: src_seg_len (torch.Tensor): Input tensor of shape (batch_size, num_segments) containing segment lengths. Returns: torch.Tensor: Flattened segment lengths of shape (batch_size, seq_len). """ src_seg_len_flat = torch.cat([src_seg_len, self.seq_len - src_seg_len.sum(dim=1, keepdims=True)], dim=1) src_seg_len_flat = src_seg_len_flat.flatten() return src_seg_len_flat
[docs] def create_src_pad_mask(self, src_signal: Tensor, src_seg_len: Tensor) -> Tensor: """ Create a padding mask for the source signal to ignore padded values during processing. Args: src_signal (torch.Tensor): Input tensor of shape (batch_size, seq_len, signal_size) containing signal data. src_seg_len (torch.Tensor): Segment lengths tensor of shape (batch_size, num_segments). Returns: torch.Tensor: Padding mask of shape (batch_size, seq_len) where True indicates padded positions. """ batch = src_signal.shape[0] src_pad_mask = torch.arange(self.seq_len, device=src_signal.device) src_pad_mask = src_pad_mask.repeat(batch, 1) src_pad_mask = src_pad_mask >= src_seg_len.sum(dim=1, keepdim=True) return src_pad_mask
[docs] def create_target_mask(self, src_seg_len: Tensor, src_seg_len_flat: Tensor) -> Tensor: """ Create a target mask to identify the target positions in the sequence. Args: src_seg_len (torch.Tensor): Segment lengths tensor of shape (batch_size, num_segments). src_seg_len_flat (torch.Tensor): Flattened segment lengths tensor of shape (batch_size, seq_len). Returns: torch.Tensor: Target mask of shape (batch_size, seq_len) where True indicates target positions. """ batch = src_seg_len.shape[0] width = src_seg_len.shape[1] target_mask = torch.arange(width + 1, device=src_seg_len.device, dtype=torch.int) target_mask = target_mask == self.block_len // 2 target_mask = target_mask.repeat(batch) target_mask = target_mask.repeat_interleave(src_seg_len_flat) target_mask = target_mask.reshape(batch, self.seq_len) target_mask = target_mask.int() return target_mask
[docs] def process_dwell_bq(self, src_dwell_bq: Tensor, src_seg_len_flat: Tensor) -> Tensor: """ Process the dwell time and base quality input by flattening and repeating it based on segment lengths. Args: src_dwell_bq (torch.Tensor): Input tensor of shape (batch_size, seq_len, channel) containing dwell time and base quality. src_seg_len_flat (torch.Tensor): Flattened segment lengths for the input sequences. Returns: torch.Tensor: Processed dwell time and base quality tensor of shape (batch_size, seq_len, channel). """ batch = src_dwell_bq.shape[0] channel = src_dwell_bq.shape[2] src_dwell_bq = torch.cat( [src_dwell_bq, torch.zeros(batch, 1, channel, device=src_dwell_bq.device, dtype=torch.float32)], dim=1 ) src_dwell_bq = src_dwell_bq.flatten(end_dim=1) src_dwell_bq = src_dwell_bq.repeat_interleave(src_seg_len_flat, dim=0) src_dwell_bq = src_dwell_bq.reshape(batch, self.seq_len, channel) return src_dwell_bq
[docs] def forward(self, src_kmer: Tensor, src_signal: Tensor, src_seg_len: Tensor, src_dwell_bq: Tensor) -> Tensor: """ Forward pass through the Transformer model. Args: src_kmer (torch.Tensor): Input tensor of shape (batch_size, seq_len) containing k-mer sequences. src_signal (torch.Tensor): Input tensor of shape (batch_size, seq_len, signal_size) containing signal data. src_seg_len (torch.Tensor): Segment lengths tensor of shape (batch_size, num_segments). src_dwell_bq (torch.Tensor): Input tensor of shape (batch_size, seq_len, channel) containing dwell time and base quality. Returns: torch.Tensor: Output tensor of shape (batch_size, seq_len) after processing through the model. """ with torch.no_grad(): src_seg_len_flat = self.flatten_seg_len(src_seg_len) src_kmer = self.process_kmer(src_kmer, src_seg_len_flat) src_signal = self.process_signal(src_signal) src_dwell_bq = self.process_dwell_bq(src_dwell_bq, src_seg_len_flat) src_pad_mask = self.create_src_pad_mask(src_signal, src_seg_len) target_mask = self.create_target_mask(src_seg_len, src_seg_len_flat) src_signal = torch.cat([src_signal, src_dwell_bq], dim=-1) kmer_embedding = self.kmer_embedding(src_kmer) signal_embedding = self.signal_embedding(src_signal) pos_encoding = self.pos_encoding(src_kmer.shape[0]) ## add all embeddings and dropout final_embedding = torch.stack([kmer_embedding, signal_embedding, pos_encoding], dim=0).sum(dim=0) final_embedding = final_embedding.permute(0, 2, 1) # Change to (batch, feature, time) output = self.cnn_encoder(final_embedding) output = output.permute(0, 2, 1) # Change back to (batch, time, feature) output = self.transformer_encoder(src=output, mask=None, src_key_padding_mask=src_pad_mask) ## apply regression head to each token: output = self.regression_head(output) output = output.squeeze(-1) target_mask_sum = target_mask.sum(dim=1) output = output * target_mask output = output.sum(dim=1) output = output / target_mask_sum output = torch.sigmoid(output) return output
## END OF TransformerModel
[docs] class PositionalEncoding(nn.Module): """ Positional Encoding for Transformer models. Args: d_model (int): Dimension of the model. seq_len (int): Length of the sequence. Attributes: pe (torch.Tensor): Positional encoding tensor of shape (1, seq_len, d_model). """ def __init__(self, d_model: int, seq_len: int) -> None: super().__init__() position = torch.arange(seq_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe = torch.zeros(1, seq_len, d_model) pe[:, :, 0::2] = torch.sin(position * div_term) pe[:, :, 1::2] = torch.cos(position * div_term) self.register_buffer("pe", pe)
[docs] def forward(self, batch_size) -> Tensor: """ Forward pass to repeat the positional encoding for the given batch size. Args: x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` Returns: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` """ pe = self.pe.repeat(batch_size, 1, 1) return pe
## END OF PositionalEncoding
[docs] class RegressionHead(nn.Module): """ Regression head for the Transformer model. Args: d_model (int): Dimension of the model. lin_act (str): Activation function for the linear layers. lin_depth (int): Depth of the linear layers. lin_dropout (float): Dropout rate for the linear layers. seq_length (int): Length of the sequence. Attributes: lin_layers (torch.nn.Sequential): Sequential container for the linear layers. """ def __init__(self, d_model: int, lin_act: str, lin_depth: int, lin_dropout: float, seq_length: int): super().__init__() layer_list = [] for i in range(lin_depth - 1): layer_list.append(nn.Linear(d_model, d_model)) layer_list.append(nn.BatchNorm1d(seq_length)) layer_list.append(get_activation_fn(lin_act)) layer_list.append(nn.Dropout(lin_dropout)) layer_list.append(nn.Linear(d_model, d_model)) layer_list.append(get_activation_fn(lin_act)) layer_list.append(nn.Linear(d_model, d_model // 4)) layer_list.append(get_activation_fn(lin_act)) layer_list.append(nn.Linear(d_model // 4, 1)) self.lin_layers = nn.Sequential(*layer_list)
[docs] def forward(self, x: Tensor) -> Tensor: """ Forward pass through the regression head. Args: x (torch.Tensor): Input tensor of shape (batch_size, seq_length, d_model). Returns: torch.Tensor: Output tensor of shape (batch_size, seq_length, 1) after processing through the linear layers. """ return self.lin_layers(x)
[docs] def init_weights(self, initrange=0.1): """ Initialize the weights of the linear layers in the regression head. Args: initrange (float): Range for uniform initialization of weights. Default is 0.1. Returns: None """ for layer in self.lin_layers: if isinstance(layer, nn.Linear): layer.weight.data.uniform_(-initrange, initrange) if layer.bias is not None: layer.bias.data.zero_() return None
## END OF RegressionHead