Source code for deeprm.train.train

"""
DeepRM Training Module

This module provides the training functionality for the DeepRM Transformer model.
It includes the Trainer class, which handles the training loop, evaluation, and checkpointing.
"""

import argparse
import gc
import glob
import importlib
import os
import time

import numpy as np
import tqdm

from deeprm.utils import check_deps
from deeprm.utils.logging import get_logger

check_deps.check_torch_available()

import torch  # noqa: E402
import torch.distributed as dist  # noqa: E402
import torch.multiprocessing as mp  # noqa: E402
from torch.nn.parallel import DistributedDataParallel as DDP  # noqa: E402
from torch.utils.tensorboard import SummaryWriter  # noqa: E402

from deeprm.train.train_dataloader import load_dataset  # noqa: E402

log = get_logger(__name__)

# torchmetrics is optional: when it cannot be imported, training still runs,
# but main_worker builds an empty metric dictionary (validation metrics skipped).
try:
    import torchmetrics.classification as cm
except Exception:
    TORCHMETRICS_AVAILABLE = False
else:
    TORCHMETRICS_AVAILABLE = True


def _str_to_bool(value) -> bool:
    """
    Parses a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` converts every non-empty string (including
    ``"False"``) to ``True``; this converter interprets the text instead.

    Args:
        value: Raw command-line value (or an already-parsed bool).

    Returns:
        bool: The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value (true/false), got {value!r}.")


def add_arguments(parser: argparse.ArgumentParser):
    """
    Adds command-line arguments.

    Args:
        parser (argparse.ArgumentParser): Argument parser to which arguments will be added.

    Returns:
        None
    """
    # Hardware / batching
    parser.add_argument("--gpu", dest="num_gpu", type=int, default=None, help="Number of GPUs to use")
    parser.add_argument("--batch", dest="batch_size", type=int, default=1024, help="Batch size for training")
    parser.add_argument(
        "--eval_batch", dest="eval_batch_size", type=int, default=None, help="Batch size for evaluation"
    )
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--epochs", type=int, default=1000, help="Number of epochs to train")
    # Paths
    parser.add_argument("--data", dest="data_path", type=str, required=True, help="Path to the dataset")
    parser.add_argument("--output", type=str, required=True, help="Output directory for saving models and logs")
    parser.add_argument("--tb", dest="tb_path", type=str, default=None, help="TensorBoard log directory")
    parser.add_argument("--model", dest="model_type", type=str, default="deeprm_model", help="Model type")
    # Early stopping
    parser.add_argument("--es-delta", type=float, default=1e-5, help="Early stopping delta")
    parser.add_argument("--es-patience", type=int, default=50, help="Early stopping patience")
    parser.add_argument("--es-start", type=int, default=1000, help="Epoch to start early stopping")
    parser.add_argument("--disk-shard-size", type=int, default=None, help="Disk shard size")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    # Model architecture
    parser.add_argument("--enc-dim", type=int, default=512, help="Encoder dimension")
    parser.add_argument("--lin-dim", type=int, default=1024, help="Linear layer dimension")
    parser.add_argument("--head", type=int, default=8, help="Number of attention heads")
    parser.add_argument("--enc-layer", type=int, default=6, help="Number of encoder layers")
    parser.add_argument("--lin-layer", type=int, default=4, help="Number of linear layers")
    parser.add_argument("--enc-dropout", type=float, default=0.1, help="Dropout rate for encoder")
    parser.add_argument("--lin-dropout", type=float, default=0.1, help="Dropout rate for linear layers")
    parser.add_argument("--period", type=int, default=30, help="Period for logging")
    parser.add_argument(
        "--buffer_size", dest="shuffle_buffer_size", type=int, default=160000, help="Shuffle buffer size"
    )
    parser.add_argument("--kmer-size", type=int, default=5, help="K-mer size")
    parser.add_argument("--signal-size", type=int, default=30, help="Signal size")
    parser.add_argument("--block-len", type=int, default=17, help="Block length")
    parser.add_argument("--seq-len", type=int, default=200, help="Sequence length")
    parser.add_argument("--t-act", type=str, default="gelu", help="Activation function for transformer")
    parser.add_argument("--lin-act", type=str, default="gelu", help="Activation function for linear layers")
    # Optimisation schedule and intervals
    parser.add_argument("--lr-step", type=int, default=4000, help="Learning rate step size")
    parser.add_argument("--lr-interval", type=int, default=100, help="Learning rate interval")
    parser.add_argument("--weight-decay", type=float, default=0.1, help="Weight decay for optimizer")
    parser.add_argument("--class-ratio", type=int, default=None, help="Class ratio for balancing")
    parser.add_argument("--log-interval", type=int, default=10, help="Interval for logging")
    parser.add_argument("--eval-interval", type=int, default=1000, help="Interval for evaluation")
    parser.add_argument("--save-interval", type=int, default=None, help="Interval for saving checkpoints")
    parser.add_argument("--grad-clip", type=float, default=1.0, help="Gradient clipping value")
    parser.add_argument("--profiler", type=int, default=0, help="Profiler flag")
    parser.add_argument("--pin-memory", type=int, default=1, help="Pin memory flag")
    parser.add_argument("--yield_period", type=int, default=None, help="Yield period for data loading")
    parser.add_argument("--rlrop", type=float, default=None, help="ReduceLROnPlateau threshold")
    parser.add_argument("--loss", type=str, default="BCE", help="Loss function")
    # Boolean options taking a value: parsed by _str_to_bool because type=bool
    # would turn "--score-feature False" into True.
    parser.add_argument(
        "--score-feature", type=_str_to_bool, default=False, help="Score feature flag (true/false)"
    )
    parser.add_argument("--gpu-pool", type=int, nargs="+", default=None, help="GPU pool")
    parser.add_argument("--cut-overlap", type=_str_to_bool, default=False, help="Cut overlap flag (true/false)")
    # Checkpointing / data loading
    parser.add_argument("--load-checkpoint", type=str, default=None, help="Path to load checkpoint")
    parser.add_argument("--workers", dest="num_workers", type=int, default=8, help="Number of workers")
    parser.add_argument("--prefetch", type=int, default=512, help="Prefetch factor")
    parser.add_argument("--stride", dest="signal_stride", type=int, default=6, help="Signal stride")
    parser.add_argument("--no-bq", action="store_true", default=False, help="No base quality flag")
    parser.add_argument("--load-weight-only", action="store_true", default=False, help="Load weights only flag")
    parser.add_argument("--override-lr", action="store_true", default=False, help="Override learning rate flag")
    parser.add_argument("--comment", type=str, default="None", help="Comment for the run")
    parser.add_argument("--model-name", type=str, default=None, help="Model name")
    return None
def main(args: argparse.Namespace):
    """
    Main function to start the training process.

    Resolves arguments whose defaults depend on other arguments or on the
    machine. These are the GPU count and pool, disk shard size, evaluation
    batch size, model name, yield period, save interval, output/TensorBoard
    paths and seed. It then creates the output directories and spawns one
    ``main_worker`` process per GPU.

    Args:
        args (argparse.Namespace): Parsed command-line arguments. Updated in place.

    Returns:
        None

    Raises:
        ValueError: If no GPU is available/requested, or the GPU pool is smaller
            than the number of GPUs to use.
        FileNotFoundError: If ``disk_shard_size`` is not given and no
            ``train/pos/*.npz`` file exists under ``data_path`` to infer it from.
    """
    strfttime = time.strftime("%Y%m%d-%H%M%S")

    if args.num_gpu is None:
        args.num_gpu = torch.cuda.device_count() if args.gpu_pool is None else len(args.gpu_pool)
    if args.num_gpu < 1:
        # Spawning zero workers would train nothing yet still report completion.
        raise ValueError("At least one GPU is required for training, but none is available or requested.")
    if args.gpu_pool is None:
        args.gpu_pool = list(range(args.num_gpu))
    elif len(args.gpu_pool) < args.num_gpu:
        raise ValueError("GPU Pool should be the same or larger than the number of GPUs to use.")

    if args.disk_shard_size is None:
        # Infer the shard size from the first positive training shard on disk.
        pattern = os.path.join(args.data_path, "train", "pos", "*.npz")
        sample_files = glob.glob(pattern)
        if not sample_files:
            raise FileNotFoundError(f"Cannot infer disk shard size: no files match {pattern!r}.")
        with np.load(sample_files[0]) as f:
            args.disk_shard_size = f["kmer_token"].shape[0]
        log.info(f"Setting disk shard size to {args.disk_shard_size}.")

    if args.eval_batch_size is None:
        args.eval_batch_size = args.batch_size * 4
    if args.model_name is None:
        args.model_name = f"{args.model_type}-{args.comment}-{strfttime}"
    if args.yield_period is None:
        args.yield_period = args.disk_shard_size
    if args.save_interval is None:
        args.save_interval = args.eval_interval

    # Each run gets its own sub-directory named after the model.
    args.output = os.path.join(args.output, args.model_name)
    if args.tb_path is None:
        args.tb_path = os.path.join(args.output, "tensorboard_log")
    else:
        args.tb_path = os.path.join(args.tb_path, args.model_name)

    # vars() returns the namespace's own __dict__, so edits below also affect args.
    args_dict = vars(args)
    torch.multiprocessing.set_sharing_strategy("file_system")
    os.makedirs(args_dict["output"], exist_ok=True)
    os.makedirs(args_dict["tb_path"], exist_ok=True)
    if args_dict["seed"] is None:
        args_dict["seed"] = np.random.randint(0, 10000000)

    log.info("Training Program Started.")
    log.info(f"Seed: {args_dict['seed']}")
    log.info(f"Using {args_dict['num_gpu']} GPUs.")
    try:
        mp.spawn(main_worker, nprocs=args_dict["num_gpu"], args=(args_dict,))
    except Exception:
        log.error("Training Program Failed.")
        raise
    log.info("Training Program Complete.")
    return None
class Trainer:
    """
    Per-process DDP training loop.

    Handles batches, periodic loss logging, evaluation, checkpointing,
    learning-rate scheduling and early stopping.
    """

    def __init__(
        self,
        rank: int,
        gpu_id: int,
        model: torch.nn.Module,
        train_loader,
        val_loader,
        optimizer: torch.optim.Optimizer,
        scheduler,
        loss_func: torch.nn.Module,
        grad_clip: float,
        metric_func_dict: dict,
        checkpoint_path: str,
        tb_path: str,
        es_start: int,
        es_patience: int,
        es_delta: float,
        model_name: str,
        num_gpu: int,
        lr_interval: int,
        eval_interval: int,
        log_interval: int,
        save_interval: int,
        model_config: dict = None,
        soft_label: float = None,
        score_feature: bool = False,
        cut_overlap: bool = False,
        signal_stride: int = 6,
        no_bq: bool = False,
        **kwargs,
    ) -> None:
        """
        Initializes the Trainer class.

        Args:
            rank (int): Rank of the current process.
            gpu_id (int): GPU ID to use.
            model (torch.nn.Module): Model to train (DDP-wrapped; ``.module`` is saved).
            train_loader: DataLoader for training data
                (deeprm.train.train_dataloader.NanoporeDataLoader).
            val_loader: DataLoader for validation data
                (deeprm.train.train_dataloader.NanoporeDataLoader). Fully cached in memory here.
            optimizer (torch.optim.Optimizer): Optimizer for training.
            scheduler: Learning rate scheduler.
            loss_func (torch.nn.Module): Loss function.
            grad_clip (float): Gradient-norm clipping value; ``<= 0`` disables clipping.
            metric_func_dict (dict): Dictionary of metric functions.
            checkpoint_path (str): Directory to save checkpoints in.
            tb_path (str): Path for TensorBoard logs (writer created on rank 0 only).
            es_start (int): Epoch after which early stopping may trigger.
            es_patience (int): Patience (in epochs) for early stopping.
            es_delta (float): Minimum validation-loss improvement counted as progress.
            model_name (str): Name of the model (checkpoint file prefix).
            num_gpu (int): Number of GPUs used (for averaging reduced values).
            lr_interval (int): Step interval for learning rate updates.
            eval_interval (int): Step interval for evaluation.
            log_interval (int): Step interval for loss logging.
            save_interval (int): Step interval for checkpoint attempts.
            model_config (dict): Model configuration dictionary stored in checkpoints. Defaults to None.
            soft_label (float): Soft label value. Defaults to None.
            score_feature (bool): Score feature flag. Defaults to False.
            cut_overlap (bool): Cut overlap flag. Defaults to False.
            signal_stride (int): Signal stride. Defaults to 6.
            no_bq (bool): If True, the model is called without the dwell/base-quality token. Defaults to False.
            **kwargs: Additional keyword arguments (ignored).
        """
        self.rank = rank
        self.gpu_id = gpu_id
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.loss_func = loss_func
        self.grad_clip = grad_clip
        self.scheduler = scheduler
        self.lr_interval = lr_interval
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.eval_interval = eval_interval
        self.metric_func_dict = metric_func_dict
        self.checkpoint_path = checkpoint_path
        self.es_start = es_start
        self.es_patience = es_patience
        self.es_delta = es_delta
        self.continue_training = 1
        self.best_val_loss = np.inf
        self.best_val_loss_epoch = 0
        self.current_val_loss = 0  # placeholder until the first evaluation
        self.current_val_metric_dict = {}
        self.has_evaluated = False  # set by _run_eval; guards _save_checkpoint
        self.current_epoch = 0
        self.current_batch = 0
        self.current_step = 0
        self.current_batch_loss = 0
        self.current_interval_loss = 0
        self.current_interval_losses = []
        self.current_lr = 0
        self.pbar = None
        self.model_name = model_name
        self.num_gpu = num_gpu
        self.model_config = model_config
        self.devname = f"{os.uname()[1]}-{self.gpu_id}"
        self.tb_path = tb_path
        self.soft_label = soft_label
        self.score_feature = score_feature
        self.cut_overlap = cut_overlap
        self.signal_stride = signal_stride
        self.histogram = False
        self.no_bq = no_bq
        if self.rank == 0 and self.tb_path is not None:
            self.tb_writer = SummaryWriter(tb_path)
        else:
            self.tb_writer = None
        self.eval_sources, self.eval_targets = self._cache_eval_data()

    def _cache_eval_data(self):
        """
        Caches evaluation data for faster evaluation.

        Returns:
            tuple: Lists of cached sources and targets, one entry per validation batch.
        """
        sources = []
        targets = []
        with torch.no_grad():
            for source, target in self.val_loader:
                sources.append(source)
                targets.append(target)
        return sources, targets

    def _feed_model(self, source, target):
        """
        Moves a batch to this process's GPU and runs the model on it.

        Args:
            source (dict): Batch with ``kmer_token``, ``signal_token``, ``segment_len``
                and (unless ``no_bq``) ``dwell_bq_token`` entries.
            target (torch.Tensor): Target data.

        Returns:
            tuple: Model output and float32 target on the GPU.
        """
        target = target.to(torch.float32).to(self.gpu_id)
        src_kmer = source["kmer_token"].to(self.gpu_id)
        src_signal = source["signal_token"].to(self.gpu_id)
        src_seg_len = source["segment_len"].to(self.gpu_id)
        if self.no_bq:
            output = self.model(src_kmer, src_signal, src_seg_len)
        else:
            src_dwell_bq = source["dwell_bq_token"].to(self.gpu_id)
            output = self.model(src_kmer, src_signal, src_seg_len, src_dwell_bq)
        return output, target

    def _run_batch(self, source, target):
        """
        Runs a single optimisation step and updates the progress bar.

        Args:
            source (dict): Source data.
            target (torch.Tensor): Target data.

        Returns:
            None
        """
        self.optimizer.zero_grad()
        output, target = self._feed_model(source, target)
        loss = self.loss_func(output, target)
        loss.backward()
        if self.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
        self.optimizer.step()
        self.current_batch_loss = loss.item()

        # Synchronise ranks and stagger per-GPU progress-bar updates slightly.
        dist.barrier()
        time.sleep(0.0001 * self.gpu_id)
        evaltext = (
            f"LR {self.current_lr:.3E} | T-Loss {self.current_batch_loss:.3E} | V-Loss {self.current_val_loss:.3E} | "
        )
        evaltext += " | ".join(
            [f"{k.upper()} {v:.3E}" for k, v in self.current_val_metric_dict.items() if (k in ["auroc", "ap"])]
        )
        self.pbar.update(1)
        self.pbar.set_postfix_str(evaltext)
        return None

    def _run_epoch(self):
        """
        Runs a single epoch of training.

        Every step also handles, at their respective intervals: loss logging
        (averaged across ranks), evaluation, checkpointing (rank 0) and
        learning-rate scheduling.

        Returns:
            None
        """
        self.train_loader.set_epoch(self.current_epoch)
        self.val_loader.set_epoch(self.current_epoch)
        self.current_batch = 0
        self.model.train()
        self.current_lr = self.optimizer.param_groups[0]["lr"]
        colour_choice = ["red", "green", "blue", "yellow", "magenta", "cyan", "white", "black"]
        # Stagger progress-bar creation per GPU.
        dist.barrier()
        time.sleep(0.03 * self.gpu_id)
        with tqdm.tqdm(
            total=len(self.train_loader),
            desc=f"[GPU {self.gpu_id}] Epoch {self.current_epoch}",
            position=self.rank,
            colour=colour_choice[self.rank % len(colour_choice)],
            smoothing=0,
        ) as self.pbar:
            for source, targets in self.train_loader:
                self._run_batch(source, targets)
                self.current_interval_losses.append(self.current_batch_loss)

                if self.current_step % self.log_interval == 0:
                    # Average the interval loss locally, then across ranks.
                    current_interval_loss = np.mean(self.current_interval_losses)
                    self.current_interval_losses = []
                    dist.barrier()
                    current_interval_loss = torch.tensor(current_interval_loss).to(self.gpu_id)
                    dist.all_reduce(current_interval_loss, op=dist.ReduceOp.SUM)
                    self.current_interval_loss = current_interval_loss / self.num_gpu
                    if self.rank == 0 and self.tb_writer is not None:
                        self.tb_writer.add_scalar("Loss", self.current_interval_loss, self.current_step)
                        self.tb_writer.add_scalar("Learning_Rate", self.current_lr, self.current_step)
                    dist.barrier()

                if self.current_step % self.eval_interval == 0 and self.current_step > 0:
                    self._run_eval()

                if self.current_step % self.save_interval == 0 and self.current_step > 0:
                    dist.barrier()
                    if self.rank == 0:
                        if self.histogram and self.tb_writer is not None:
                            try:
                                for name, parameter in self.model.named_parameters():
                                    self.tb_writer.add_histogram(
                                        name, parameter.clone().cpu().data.numpy(), self.current_step
                                    )
                            except Exception as e:
                                # Histograms are diagnostics only; never abort training for them.
                                log.warning(f"Error adding histogram: {e}")
                        self._save_checkpoint()
                    dist.barrier()

                if self.current_step % self.lr_interval == 0:
                    if self.scheduler.__class__.__name__ == "ReduceLROnPlateau":
                        self.scheduler.step(self.current_val_loss)
                    else:
                        self.scheduler.step()
                    self.current_lr = self.optimizer.param_groups[0]["lr"]

                self.current_batch += 1
                self.current_step += 1
        gc.collect()
        return None

    def _run_eval(self):
        """
        Runs evaluation on the cached validation data.

        The loss and each metric are computed on this rank's shard and then
        averaged across ranks with ``all_reduce``.

        Returns:
            None
        """
        self.model.eval()
        val_loss = []
        outputs = []
        with torch.no_grad():
            for source, target in zip(self.eval_sources, self.eval_targets):
                output, target = self._feed_model(source, target)
                loss = self.loss_func(output, target)
                val_loss.append(loss.item())
                outputs.append(output)
        val_loss = np.mean(val_loss)
        outputs = torch.cat(outputs, dim=0)
        targets = torch.cat(self.eval_targets, dim=0)
        targets = targets.to(torch.long).to(self.gpu_id)
        metric_dict = {}
        for metric_name, metric_func in self.metric_func_dict.items():
            metric_dict[metric_name] = metric_func(outputs, targets)
            # Calling a torchmetrics Metric also accumulates into its internal
            # state; reset it so that state does not grow with every evaluation.
            reset = getattr(metric_func, "reset", None)
            if callable(reset):
                reset()
        dist.barrier()
        val_loss = torch.tensor(val_loss).to(self.gpu_id)
        dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)
        val_loss = val_loss / self.num_gpu
        for key, value in metric_dict.items():
            metric = value.clone().detach().to(self.gpu_id)
            dist.all_reduce(metric, op=dist.ReduceOp.SUM)
            metric_dict[key] = metric / self.num_gpu
        self.current_val_loss = val_loss
        self.current_val_metric_dict = metric_dict
        self.has_evaluated = True
        if self.rank == 0 and self.tb_writer is not None:
            self.tb_writer.add_scalar("Val_Loss", val_loss, self.current_step)
            for key, value in metric_dict.items():
                self.tb_writer.add_scalar(f"Val_{key}", value, self.current_step)
        dist.barrier()
        self.model.train()
        return None

    def _save_checkpoint(self):
        """
        Saves the model checkpoint if validation loss improves, and updates the
        early-stopping decision in ``self.continue_training``.

        Does nothing until at least one evaluation has run, because
        ``current_val_loss`` is only a placeholder before then.

        Returns:
            None
        """
        if not self.has_evaluated:
            return None
        if self.best_val_loss - self.current_val_loss > self.es_delta:
            ## Save Model if Improved
            self.best_val_loss = self.current_val_loss
            self.best_val_loss_epoch = self.current_epoch
            torch.save(
                {
                    "model_state_dict": self.model.module.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                    "scheduler_state_dict": self.scheduler.state_dict(),
                    "val_loss": self.current_val_loss,
                    "metric_dict": self.current_val_metric_dict,
                    "model_config": self.model_config,
                },
                f"{self.checkpoint_path}/{self.model_name}-{self.current_epoch}-{self.current_step}.pt",
            )
            self.continue_training = 1
        elif self.current_epoch > self.es_start and self.current_epoch - self.best_val_loss_epoch > self.es_patience:
            self.continue_training = 0
        else:
            self.continue_training = 1
        return None

    def _sync_continue_training(self):
        """
        Broadcasts rank 0's early-stopping decision to every rank.

        ``_save_checkpoint`` runs on rank 0 only, so without this the other
        ranks would keep training (and block in ``dist.barrier``) after rank 0 stops.

        Returns:
            None
        """
        flag = torch.tensor([int(self.continue_training)], dtype=torch.int32, device=self.gpu_id)
        dist.broadcast(flag, src=0)
        self.continue_training = int(flag.item())
        return None

    def train(self, max_epochs: int):
        """
        Trains the model for a specified number of epochs.

        Args:
            max_epochs (int): The maximum number of epochs to train the model.

        Returns:
            None
        """
        for epoch in range(max_epochs):
            dist.barrier()
            self.current_epoch = epoch
            self._run_epoch()
            if self.tb_writer is not None:
                self.tb_writer.flush()
            # All ranks must agree before deciding whether to stop.
            self._sync_continue_training()
            if self.continue_training == 0:
                log.error(f"Early Stopping at Epoch {self.current_epoch}")
                break
        if self.tb_writer is not None:
            self.tb_writer.flush()
        return None
## END of Class Trainer
def setup_ddp(rank, world_size, gpu_id):
    """
    Sets up Distributed Data Parallel (DDP) for multi-GPU training.

    The rendezvous address and port default to ``localhost:12355``. Existing
    ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables take precedence,
    so several single-node runs can share one machine on different ports.

    Args:
        rank (int): Rank of the current process.
        world_size (int): Total number of processes.
        gpu_id (int): GPU ID to use.

    Returns:
        None
    """
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    # Select this process's GPU before creating the NCCL process group
    # (the order used in PyTorch's multi-GPU DDP tutorial).
    torch.cuda.set_device(gpu_id)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    return None
def prepare_dataloader(data_path, rank, num_gpu, num_workers, **kwargs):
    """
    Builds the training and validation loaders for one DDP process.

    Each split reads positives from ``<data_path>/<split>/pos`` and negatives
    from ``<data_path>/<split>/neg`` via ``load_dataset``. The training split
    is requested with ``shuffle=True, drop_last=True``; the validation split
    with ``shuffle=False, drop_last=False``.

    Args:
        data_path (str): Path to the dataset directory.
        rank (int): Rank of the current process.
        num_gpu (int): Number of GPUs (passed as ``num_replicas``).
        num_workers (int): Number of worker processes per loader.
        **kwargs: Forwarded unchanged to ``load_dataset``.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    if rank == 0:
        log.info(f"Total number of dataloader workers: {num_workers * num_gpu}")

    def _loader_for(split, is_training):
        return load_dataset(
            pos_data_path=f"{data_path}/{split}/pos",
            neg_data_path=f"{data_path}/{split}/neg",
            rank=rank,
            num_replicas=num_gpu,
            num_workers=num_workers,
            shuffle=is_training,
            drop_last=is_training,
            **kwargs,
        )

    # Tuple elements are evaluated left to right: training loader first.
    return _loader_for("train", True), _loader_for("val", False)
def _build_scheduler(optimizer, args_dict):
    """
    Creates the learning-rate scheduler selected by the arguments.

    If ``rlrop`` is set, this is ``ReduceLROnPlateau`` (patience ``lr_step``,
    relative threshold ``rlrop``). Otherwise it is ``CosineAnnealingLR`` with
    ``T_max=lr_step``.
    """
    if args_dict["rlrop"] is not None:
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.5,
            patience=args_dict["lr_step"],
            threshold=args_dict["rlrop"],
            threshold_mode="rel",
            cooldown=0,
            min_lr=1e-6,
            eps=1e-8,
        )
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args_dict["lr_step"], eta_min=1e-6)


def _build_loss_func(name):
    """
    Instantiates the loss named by ``--loss`` (MSE, BCE, BCEWL or CE).

    Raises:
        ValueError: If ``name`` is not one of the supported losses.
    """
    loss_classes = {
        "MSE": torch.nn.MSELoss,
        "BCE": torch.nn.BCELoss,
        "BCEWL": torch.nn.BCEWithLogitsLoss,
        "CE": torch.nn.CrossEntropyLoss,
    }
    if name not in loss_classes:
        raise ValueError(f"Loss Function {name} Not Implemented.")
    return loss_classes[name]()


def _build_metric_func_dict(gpu_id):
    """Creates the binary validation metrics on ``gpu_id``; returns ``{}`` if torchmetrics is unavailable."""
    if not TORCHMETRICS_AVAILABLE:
        log.warning("torchmetrics is not available. Some metrics will not be computed.")
        return {}
    return {
        "acc": cm.BinaryAccuracy().to(gpu_id),
        "auroc": cm.BinaryAUROC().to(gpu_id),
        "ap": cm.BinaryAveragePrecision().to(gpu_id),
        "f-1": cm.BinaryF1Score().to(gpu_id),
    }


def main_worker(rank, args_dict):
    """
    Main worker function for training the model.

    Runs in each process spawned by ``main``. It joins the process group,
    builds the model (optionally restoring a checkpoint), optimizer,
    scheduler, loss, metrics and data loaders, then trains. The process group
    is destroyed even if training raises.

    Args:
        rank (int): Rank of the current process.
        args_dict (dict): Dictionary of command-line arguments. A
            ``checkpoint_path`` entry is added to it.

    Returns:
        None

    Raises:
        ValueError: If ``args_dict["loss"]`` is not a supported loss name.
    """
    gpu_id = args_dict["gpu_pool"][rank]
    setup_ddp(rank, args_dict["num_gpu"], gpu_id)
    try:
        TransformerModel = importlib.import_module(f"deeprm.model.{args_dict['model_type']}").TransformerModel
        model = TransformerModel(
            d_model=args_dict["enc_dim"],
            n_heads=args_dict["head"],
            d_ff=args_dict["lin_dim"],
            n_layers=args_dict["enc_layer"],
            lin_depth=args_dict["lin_layer"],
            **args_dict,
        )
        if rank == 0:
            total_params = sum(parameter.numel() for _, parameter in model.named_parameters())
            log.info(f"Total Params: {total_params:,}")
        model = model.to(gpu_id)

        load_checkpoint = args_dict["load_checkpoint"]
        save_dict = {}
        if load_checkpoint is not None:

            def _to_local_gpu(storage, location):
                # Remap tensors saved on *any* CUDA device to this process's GPU
                # (a checkpoint may come from a rank that ran on a GPU other than
                # cuda:0); CPU tensors stay on the CPU.
                return storage.cuda(gpu_id) if location.startswith("cuda") else storage

            # NOTE(review): weights_only=False unpickles arbitrary objects; only
            # load checkpoints from trusted sources.
            save_dict = torch.load(load_checkpoint, map_location=_to_local_gpu, weights_only=False)
            model.load_state_dict(state_dict=save_dict["model_state_dict"])
            if args_dict["load_weight_only"]:
                for param in model.parameters():
                    param.requires_grad = True

        model = DDP(model, device_ids=[gpu_id], output_device=gpu_id, find_unused_parameters=False)
        optimizer = torch.optim.AdamW(model.parameters(), lr=args_dict["lr"], weight_decay=args_dict["weight_decay"])
        if load_checkpoint is not None and not args_dict["load_weight_only"]:
            optimizer.load_state_dict(save_dict["optimizer_state_dict"])
            # Only a restored optimizer can carry a stale LR; a fresh one already uses args_dict["lr"].
            if args_dict["override_lr"]:
                for param_group in optimizer.param_groups:
                    param_group["lr"] = args_dict["lr"]

        scheduler = _build_scheduler(optimizer, args_dict)
        if load_checkpoint is not None and not (args_dict["load_weight_only"] or args_dict["override_lr"]):
            scheduler.load_state_dict(save_dict["scheduler_state_dict"])
        if load_checkpoint is not None:
            # Drop this function's reference to the loaded checkpoint dictionary.
            save_dict.clear()
            del save_dict
            gc.collect()

        loss_func = _build_loss_func(args_dict["loss"])
        metric_func_dict = _build_metric_func_dict(gpu_id)

        args_dict["checkpoint_path"] = args_dict["output"]
        train_loader, val_loader = prepare_dataloader(rank=rank, gpu_id=gpu_id, **args_dict)
        trainer = Trainer(
            rank=rank,
            gpu_id=gpu_id,
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            loss_func=loss_func,
            metric_func_dict=metric_func_dict,
            model_config=args_dict,
            **args_dict,
        )
        log.info(f"[GPU {gpu_id}] Trainer Setup Complete.")
        trainer.train(args_dict["epochs"])
        log.info(f"[GPU {gpu_id}] Training Loop Complete.")
    finally:
        dist.destroy_process_group()
    return None