Source code for deeprm.qc.inspect_run

"""
DeepRM QC Module: Inspect Basecalled Run

Open a bam file and get the stats of read, then plot.
1. Read length distribution
2. Quality score distribution
"""

import argparse
import multiprocessing as mp
import os
import pickle
from collections import deque

import numpy as np
import pysam
import seaborn as sns
from matplotlib import pyplot as plt
from tqdm import tqdm

from deeprm.utils.logging import get_logger
from deeprm.utils.utils import mean_phred

log = get_logger(__name__)

plt.style.use("default")
plt.style.use("seaborn-v0_8-whitegrid")
plt.rcParams.update(
    {"font.size": 22, "legend.facecolor": "white", "legend.framealpha": 0.5, "legend.frameon": 1, "lines.linewidth": 2}
)


[docs] def add_arguments(parser: argparse.ArgumentParser): """ Adds command-line arguments. Args: parser (argparse.ArgumentParser): Argument parser to which arguments will be added. Returns: None """ parser.add_argument("--input", "-i", dest="bam_path", type=str, required=True, help="Input bam file") parser.add_argument("--output", "-o", dest="out_path", type=str, required=True, help="Output directory") parser.add_argument( "--process", "-p", dest="process", type=int, default=int(mp.cpu_count() * 0.95 // 4), help="Number of processes" ) parser.add_argument("--threads", "-t", dest="threads", type=int, default=4, help="Number of threads") parser.add_argument("--bq", "-q", dest="bq_thres", type=int, default=7, help="Base quality threshold") parser.add_argument("--bb", "-b", dest="bb_length", type=int, default=87, help="BB length") parser.add_argument("--mrna", "-m", action="store_true", help="mRNA mode") parser.add_argument("--len", "-l", dest="len_cutoff", type=int, default=200, help="Length cutoff") return None
[docs] def main(args: argparse.Namespace): """ Main function to run the script. It reads a BAM file, collects statistics on read lengths, mean quality scores, and poly(A) lengths, and generates plots for these statistics. Args: args (argparse.Namespace): Parsed command-line arguments. Returns: None """ ## Check if BAM is indexed if not os.path.exists(args.bam_path + ".bai"): log.info("BAM file is not indexed. Indexing BAM file...") pysam.index(args.bam_path, nthreads=args.threads * args.process) log.info("BAM file indexed.") load_success = False if os.path.exists(args.out_path): try: with open(f"{args.out_path}/read_len.pkl", "rb") as f: read_len_arr = pickle.load(f) with open(f"{args.out_path}/mean_qual.pkl", "rb") as f: mean_qual_arr = pickle.load(f) with open(f"{args.out_path}/polya_len.pkl", "rb") as f: polya_len_arr = pickle.load(f) load_success = True log.info("Output directory already exists. Pickle files loaded successfully.") except Exception: load_success = False log.warning("Output directory exists but pickle files are not found or corrupted. Recalculating.") if not load_success: os.makedirs(args.out_path, exist_ok=True) manager = mp.Manager() collect_dict = manager.dict() collect_dict["read_len_arr"] = manager.list() collect_dict["qual_arr"] = manager.list() collect_dict["polya_len_arr"] = manager.list() log.info("Reading BAM file") processes = [] for pid in range(args.process): p = mp.Process(target=read_bam_worker, args=(args, pid, collect_dict)) processes.append(p) p.start() for p in processes: p.join() log.info("Collecting results") read_len_arr = np.concatenate(collect_dict["read_len_arr"]) mean_qual_arr = np.concatenate(collect_dict["qual_arr"]) polya_len_arr = np.concatenate(collect_dict["polya_len_arr"]) manager.shutdown() log.info("Saving pickle") ## save pickle with open(f"{args.out_path}/read_len.pkl", "wb") as f: pickle.dump(read_len_arr, f) with open(f"{args.out_path}/mean_qual.pkl", "wb") as f: pickle.dump(mean_qual_arr, f) with open(f"{args.out_path}/polya_len.pkl", "wb") as f: pickle.dump(polya_len_arr, f) log.info("Plotting") if not args.mrna: plot_read_len_oligo(read_len_arr, mean_qual_arr, args.bq_thres, args.out_path, args.bb_length) else: plot_read_len_mrna(read_len_arr, mean_qual_arr, args.bq_thres, args.out_path) plot_qual(mean_qual_arr, args.out_path, bq_thres=args.bq_thres) plot_polya_len(polya_len_arr, mean_qual_arr, args.bq_thres, args.out_path) return None
[docs] def plot_read_len_oligo(read_len_arr, mean_qual_arr, bq_thres, out_path, bb_length): """ Plot read length distribution for oligo data. Args: read_len_arr (numpy.ndarray): Array of read lengths. mean_qual_arr (numpy.ndarray): Array of mean quality scores. bq_thres (int): Base quality threshold. out_path (str): Output directory path. bb_length (int): Length of the barcode. Returns: None """ read_len_arr_passed = read_len_arr[mean_qual_arr >= bq_thres] read_len_arr_failed = read_len_arr[mean_qual_arr < bq_thres] ## plot read length KDE fig, ax = plt.subplots(figsize=(10, 10)) read_len_max = np.percentile(read_len_arr, 99.9) binrange = (0, read_len_max) binwidth = 10 if bb_length is not None: for i in range(1, 1000 // bb_length + 1): ligate_length = bb_length * i ax.axvline(ligate_length, color="grey", linestyle="-", linewidth=2) sns.histplot( read_len_arr_passed, ax=ax, color="royalblue", label=f"Passed (n={len(read_len_arr_passed):,})", binwidth=binwidth, binrange=binrange, fill=False, lw=5, element="step", stat="density", ) sns.histplot( read_len_arr_failed, ax=ax, color="tomato", label=f"Failed (n={len(read_len_arr_failed):,})", binwidth=binwidth, binrange=binrange, fill=False, lw=5, element="step", stat="density", ) ## Median ax.axvline(np.median(read_len_arr_passed), color="royalblue", linestyle="--", linewidth=2) ax.text( np.median(read_len_arr_passed), 0.9 * ax.get_ylim()[1], f"Passed median = {np.median(read_len_arr_passed):.0f}", color="black", ) ax.axvline(np.median(read_len_arr_failed), color="tomato", linestyle="--", linewidth=2) ax.text( np.median(read_len_arr_failed), 0.8 * ax.get_ylim()[1], f"Failed median = {np.median(read_len_arr_failed):.0f}", color="black", ) ax.set_title(f"Read Length Distribution (n={len(read_len_arr):,})") ax.set_xlabel("Read Length") ax.set_ylabel("Count") ax.set_xlim(0, 1000) ax.legend() ## Peak Detection ## Vline at median fig.savefig(f"{out_path}/read_len_hist.png", dpi=300) plt.close(fig) return None
[docs] def plot_read_len_mrna(read_len_arr, mean_qual_arr, bq_thres, out_path): """ Plot read length distribution for mRNA data. Args: read_len_arr (numpy.ndarray): Array of read lengths. mean_qual_arr (numpy.ndarray): Array of mean quality scores. bq_thres (int): Base quality threshold. out_path (str): Output directory path. Returns: None """ read_len_arr_passed = read_len_arr[mean_qual_arr >= bq_thres] read_len_arr_failed = read_len_arr[mean_qual_arr < bq_thres] ## plot read length KDE fig, ax = plt.subplots(figsize=(10, 10)) read_len_max = np.percentile(read_len_arr, 99.9) binrange = (0, read_len_max) binwidth = 10 sns.histplot( read_len_arr_passed, ax=ax, color="royalblue", label=f"Passed (n={len(read_len_arr_passed):,})", binwidth=binwidth, binrange=binrange, fill=False, lw=5, element="step", stat="density", ) sns.histplot( read_len_arr_failed, ax=ax, color="tomato", label=f"Failed (n={len(read_len_arr_failed):,})", binwidth=binwidth, binrange=binrange, fill=False, lw=5, element="step", stat="density", ) ## Median ax.axvline(np.median(read_len_arr_passed), color="royalblue", linestyle="--", linewidth=2) ax.text( np.median(read_len_arr_passed), 0.9 * ax.get_ylim()[1], f"Passed median = {np.median(read_len_arr_passed):.0f}", color="black", ) ax.axvline(np.median(read_len_arr_failed), color="tomato", linestyle="--", linewidth=2) ax.text( np.median(read_len_arr_failed), 0.8 * ax.get_ylim()[1], f"Failed median = {np.median(read_len_arr_failed):.0f}", color="black", ) ax.set_title(f"Read Length Distribution (n={len(read_len_arr):,})") ax.set_xlabel("Read Length") ax.set_ylabel("Count") ax.set_xlim(0, 3000) ax.legend() ## Vline at median fig.savefig(f"{out_path}/read_len_hist.png", dpi=300) plt.close(fig) return None
[docs] def plot_polya_len(read_len_arr, mean_qual_arr, bq_thres, out_path): """ Plot poly(A) length distribution. Args: read_len_arr (numpy.ndarray): Array of read lengths. mean_qual_arr (numpy.ndarray): Array of mean quality scores. bq_thres (int): Base quality threshold. out_path (str): Output directory path. Returns: None """ read_len_arr_passed = read_len_arr[mean_qual_arr >= bq_thres] read_len_arr_failed = read_len_arr[mean_qual_arr < bq_thres] ## plot read length KDE fig, ax = plt.subplots(figsize=(10, 10)) read_len_max = np.percentile(read_len_arr, 99.9) binrange = (0, read_len_max) binwidth = 10 sns.histplot( read_len_arr_passed, ax=ax, color="royalblue", label=f"Passed (n={len(read_len_arr_passed):,})", binwidth=binwidth, binrange=binrange, fill=False, lw=5, element="step", stat="density", ) sns.histplot( read_len_arr_failed, ax=ax, color="tomato", label=f"Failed (n={len(read_len_arr_failed):,})", binwidth=binwidth, binrange=binrange, fill=False, lw=5, element="step", stat="density", ) ## Median ax.axvline(np.median(read_len_arr_passed), color="black", linestyle="--", linewidth=2) ax.text( np.median(read_len_arr_passed), 0.9 * ax.get_ylim()[1], f"Passed median = {np.median(read_len_arr_passed):.0f}", color="black", ) ax.axvline(np.median(read_len_arr_failed), color="black", linestyle="--", linewidth=2) ax.text( np.median(read_len_arr_failed), 0.8 * ax.get_ylim()[1], f"Failed median = {np.median(read_len_arr_failed):.0f}", color="black", ) ax.set_title(f"Poly(A) Length Distribution (n={len(read_len_arr):,})") ax.set_xlabel("Poly(A) Length") ax.set_ylabel("Count") ax.set_xlim(0, 300) ax.legend() ## Vline at median fig.savefig(f"{out_path}/polya_len_hist.png", dpi=300) plt.close(fig) return None
[docs] def plot_qual(mean_qual_arr, out_path, bq_thres=7, max_bq=30): """ Plot mean quality score distribution. Args: mean_qual_arr (numpy.ndarray): Array of mean quality scores. out_path (str): Output directory path. bq_thres (int): Base quality threshold. max_bq (int): Maximum base quality score for plotting. Returns: None """ ## plot mean quality score KDE with histogram fig, ax = plt.subplots(figsize=(10, 10)) pass_arr = mean_qual_arr[mean_qual_arr >= bq_thres] fail_arr = mean_qual_arr[mean_qual_arr < bq_thres] ax.set_title(f"Read Mean Base Quality Distribution (n={len(mean_qual_arr):,})") pass_percent = len(pass_arr) / len(mean_qual_arr) * 100 fail_percent = len(fail_arr) / len(mean_qual_arr) * 100 sns.histplot( data=pass_arr, ax=ax, color="royalblue", label=f"Pass (n={len(pass_arr):,}, {pass_percent:.2f}%)", binwidth=0.1, binrange=(0, max_bq), ) sns.histplot( data=fail_arr, ax=ax, color="tomato", label=f"Fail (n={len(fail_arr):,}, {fail_percent:.2f}%)", binwidth=0.1, binrange=(0, max_bq), ) ## vline at median ax.axvline(np.median(mean_qual_arr), color="black", linestyle="--", linewidth=2) ax.text(np.median(mean_qual_arr), 0.9 * ax.get_ylim()[1], f"Median = {np.median(mean_qual_arr):.2f}", color="black") ## vline at passed median ax.axvline(np.median(pass_arr), color="black", linestyle="--", linewidth=2) ax.text(np.median(pass_arr), 0.8 * ax.get_ylim()[1], f"Passed Median = {np.median(pass_arr):.2f}", color="black") ## vline at failed median ax.axvline(np.median(fail_arr), color="black", linestyle="--", linewidth=2) ax.text(np.median(fail_arr), 0.7 * ax.get_ylim()[1], f"Failed Median = {np.median(fail_arr):.2f}", color="black") ax.legend() ax.set_xlim(0, max_bq) fig.savefig(f"{out_path}/mean_qual_hist.png", dpi=300) plt.close(fig) return None
[docs] def read_bam_worker(args, pid, collect_dict): """ Worker function to read BAM file and collect statistics. Args: args (argparse.Namespace): Parsed command line arguments. pid (int): Process ID. collect_dict (dict): Shared dictionary to collect results. Returns: None """ bam_file = pysam.AlignmentFile(args.bam_path, "rb", check_sq=False, threads=args.threads) total = bam_file.mapped + bam_file.unmapped proc_len = (total // args.process) + 1 read_len_arr = deque(maxlen=proc_len) qual_arr = deque(maxlen=proc_len) polya_len_arr = deque(maxlen=proc_len) for i, read in tqdm(enumerate(bam_file), total=total): if i % args.process == pid: if read.is_secondary: continue if read.has_tag("pi"): continue if args.len_cutoff > 0: if read.query_length < args.len_cutoff: continue bq = np.array(read.query_qualities, dtype=int) rl = read.query_length if read.has_tag("TL") and read.has_tag("TR"): trim_5p = read.get_tag("TL") trim_3p = read.get_tag("TR") if trim_3p - trim_5p <= 0: continue bq = read.query_qualities[trim_5p:trim_3p] rl = trim_3p - trim_5p qual_arr.append(mean_phred(bq)) read_len_arr.append(rl) try: polya_len_arr.append(read.get_tag("pt")) except Exception: polya_len_arr.append(0) collect_dict["read_len_arr"].append(np.array(read_len_arr)) collect_dict["qual_arr"].append(np.array(qual_arr)) collect_dict["polya_len_arr"].append(np.array(polya_len_arr)) return None