# Source code for deeprm.train.extract_block

"""
Extract context blocks from reads.

Key steps to extract context blocks from reads:
    1. Index all k-mers from the read.
    2. Connect the spacers using the k-mer index.
    3. Build a DAG of spacers.
    4. Find the longest path in the DAG.
    5. Extract the sequence from the longest path.
"""

import gc
import glob
import itertools as it
import multiprocessing as mp
import os
from collections import defaultdict

import networkx as nx
import numpy as np
import pandas as pd
import polyleven as pl
import pysam
from tqdm import tqdm

from deeprm.utils.logging import get_logger
from deeprm.utils.memory import start_mem_watchdog
from deeprm.utils.utils import mean_phred

# Module-level logger used by the functions in this module.
log = get_logger(__name__)


def get_min_ideal_displacement_dict(cb_per_bb, spacer_size, cb_size):
    """
    Build a lookup of the smallest ideal displacement between every pair of spacer indices.

    Spacer indices run from 0 to ``cb_per_bb`` inclusive. Moving forward to a
    larger index costs one "big step" (``cb_size + spacer_size``) per index.
    Moving to the same or a smaller index wraps around. A wrap costs
    ``cb_per_bb - from_idx + to_idx`` big steps plus one "small step"
    (``spacer_size``).

    Args:
        cb_per_bb (int): Number of context blocks per base block.
        spacer_size (int): Size of the spacer.
        cb_size (int): Size of the context block.

    Returns:
        dict: Maps ``(from_idx, to_idx)`` to ``(displacement, small_steps, big_steps)``.
    """
    big_step = cb_size + spacer_size
    table = {}
    for src, dst in it.product(range(cb_per_bb + 1), repeat=2):
        wraps = src >= dst
        n_big = (cb_per_bb - src + dst) if wraps else (dst - src)
        n_small = 1 if wraps else 0
        table[(src, dst)] = (big_step * n_big + spacer_size * n_small, n_small, n_big)
    return table
def get_ideal_displacement(
    from_spacer_idx, to_spacer_idx, displacement, min_ideal_displacement_dict, cb_per_bb, bb_size
):
    """
    Snap an observed spacer-to-spacer displacement onto the nearest ideal layout.

    If the observed displacement does not exceed the minimum ideal displacement
    for this spacer pair, the minimum is returned unchanged. Otherwise the
    nearest whole number of extra base blocks (``bb_size`` each) is added. Each
    extra base block contributes one small step and ``cb_per_bb`` big steps.

    Args:
        from_spacer_idx (int): Index of the starting spacer.
        to_spacer_idx (int): Index of the ending spacer.
        displacement (int): Actual displacement between spacers.
        min_ideal_displacement_dict (dict): Output of ``get_min_ideal_displacement_dict``.
        cb_per_bb (int): Number of context blocks per base block.
        bb_size (int): Size of the base block.

    Returns:
        tuple: ``(ideal_displacement, small_steps, big_steps)``.
    """
    base_disp, base_small, base_big = min_ideal_displacement_dict[(from_spacer_idx, to_spacer_idx)]
    if displacement <= base_disp:
        return base_disp, base_small, base_big

    # round() here is Python's round-half-to-even.
    n_periods = round((displacement - base_disp) / bb_size)
    return (
        base_disp + n_periods * bb_size,
        base_small + n_periods,
        base_big + cb_per_bb * n_periods,
    )
def get_integer_partition(indel_tolerance, cb_size_tolerance):
    """
    Enumerate the ways a spacing error can be split into front and back indels.

    For every spacing error ``e`` in ``[-cb_size_tolerance, cb_size_tolerance]``,
    list all pairs ``(front, back)`` with ``front + back == e`` and
    ``|front| + |back| <= indel_tolerance``. Each list is ordered by total
    indel size, ties keeping ascending ``front`` order.

    Args:
        indel_tolerance (int): Indel tolerance.
        cb_size_tolerance (int): Context block size tolerance.

    Returns:
        dict: Maps each spacing error to a list of ``(front_error, back_error)`` tuples.
    """

    def total_indel(pair):
        return abs(pair[0]) + abs(pair[1])

    partitions = {}
    for spacing_error in range(-cb_size_tolerance, cb_size_tolerance + 1):
        candidates = (
            (front, spacing_error - front) for front in range(-indel_tolerance, indel_tolerance + 1)
        )
        partitions[spacing_error] = sorted(
            (pair for pair in candidates if total_indel(pair) <= indel_tolerance),
            key=total_indel,
        )
    return partitions
def get_kmer_dict(read, k, bq_cutoff, phred):
    """
    Index every k-mer start position in a read, optionally filtered by base quality.

    Args:
        read (str): The read sequence.
        k (int): Length of the k-mer.
        bq_cutoff (float): Minimum mean Phred over the k-mer window. A falsy
            value (``None`` or 0) disables the filter, and ``phred`` is then never read.
        phred (list): Per-base Phred quality scores aligned with ``read``.

    Returns:
        collections.defaultdict: Maps each k-mer to its list of start positions, in increasing order.
    """
    positions = defaultdict(list)
    for start in range(len(read) - k + 1):
        end = start + k
        if bq_cutoff and np.mean(phred[start:end]) < bq_cutoff:
            continue
        positions[read[start:end]].append(start)
    return positions
def get_ed_kmers(kmer, spacer_mismatch_tolerance):
    """
    Bucket every RNA k-mer of the same length by its edit distance to ``kmer``.

    All ``4 ** len(kmer)`` sequences over ``ACGU`` are enumerated. Each one is
    scored with polyleven's bounded Levenshtein distance, and the bucket at
    ``spacer_mismatch_tolerance + 1`` is then emptied.

    Args:
        kmer (str): The k-mer sequence.
        spacer_mismatch_tolerance (int): Maximum edit distance to keep.

    Returns:
        collections.defaultdict: Maps edit distance to the list of k-mers at that distance.
    """
    kmer_ed_dict = defaultdict(list)
    for letters in it.product("ACGU", repeat=len(kmer)):
        candidate = "".join(letters)
        # Bounded distance; polyleven is expected to report tolerance + 1 for anything
        # farther away (confirm against polyleven docs), which is cleared below.
        kmer_ed_dict[pl.levenshtein(kmer, candidate, spacer_mismatch_tolerance)].append(candidate)
    kmer_ed_dict[spacer_mismatch_tolerance + 1] = []
    return kmer_ed_dict
def validate_anchor(
    read,
    from_pos,
    to_pos,
    possible_indel_list,
    spacer_size,
    cb_pad,
    single_anchor,
    indel_penalty,
    anchor_mismatch_penalty,
    displacement_error,
):
    """
    Look for the anchor base of a context block between two spacers.

    The query is the read between the end of the starting spacer and the start
    of the ending spacer. For each ``(front_indel, back_indel)`` hypothesis, the
    anchor is expected ``cb_pad + front_indel`` bases into the query. The
    lowest-penalty hypothesis wins. A fallback "anchor missing" candidate is
    always available.

    Args:
        read (str): The read sequence.
        from_pos (int): Start position of the starting spacer.
        to_pos (int): Start position of the ending spacer.
        possible_indel_list (list): ``(front_indel, back_indel)`` tuples to try, in order.
        spacer_size (int): Size of the spacer.
        cb_pad (int): Context block padding (bases on each side of the anchor).
        single_anchor (str): Expected anchor base.
        indel_penalty (int): Penalty per indel.
        anchor_mismatch_penalty (int): Penalty added when the anchor is not found.
        displacement_error (int): Absolute displacement error, used as the indel
            count of the fallback "missing anchor" candidate.

    Returns:
        tuple: ``(missing_anchor, anchor_pos, total_indel)``. ``missing_anchor``
            is 0 when an anchor was found (``anchor_pos`` is then its read
            position), or 1 with ``anchor_pos`` ``None`` otherwise. Ties in
            penalty prefer a found anchor, then the earliest hypothesis.
    """
    query_start = from_pos + spacer_size
    query = read[query_start:to_pos]
    # (penalty, missing_anchor, anchor_pos, total_indel) of the best candidate so far.
    best = (displacement_error * indel_penalty + anchor_mismatch_penalty, 1, None, displacement_error)
    for front_indel, back_indel in possible_indel_list:
        offset = cb_pad + front_indel
        # Skip hypotheses that fall outside the query: a negative offset would
        # silently wrap to the query's end, and a too-large one raises IndexError.
        if not 0 <= offset < len(query):
            continue
        if query[offset] != single_anchor:
            continue
        total_indel = abs(front_indel) + abs(back_indel)
        candidate = (total_indel * indel_penalty, 0, query_start + offset, total_indel)
        # Strict comparison keeps the earliest candidate among equal (penalty, missing) keys.
        if candidate[:2] < best[:2]:
            best = candidate
    return best[1:]
def get_kmer_tuple(spacer_mismatch_tolerance, from_spacer_kmer_ed_dict, to_spacer_kmer_ed_dict):
    """
    Pair up candidate spacer k-mers whose combined edit distance stays within tolerance.

    Pairs are produced in order of increasing total mismatch. Within one total,
    the starting spacer's share of the mismatch increases from 0 upwards.

    Args:
        spacer_mismatch_tolerance (int): Maximum combined mismatch of both spacers.
        from_spacer_kmer_ed_dict (dict): Edit distance -> k-mers for the starting spacer.
        to_spacer_kmer_ed_dict (dict): Edit distance -> k-mers for the ending spacer.

    Returns:
        list: ``(from_kmer, to_kmer, total_mismatch)`` tuples.
    """
    pairs = []
    for total in range(spacer_mismatch_tolerance + 1):
        for front in range(total + 1):
            combos = it.product(from_spacer_kmer_ed_dict[front], to_spacer_kmer_ed_dict[total - front])
            pairs.extend((a, b, total) for a, b in combos)
    return pairs
def find_block_candidates(
    seq,
    phred,
    cb_bq_cutoff,
    spacer_kmer_ed_dict,
    skip_size_tolerance,
    cb_pad,
    cb_per_bb,
    indel_penalty,
    anchor_mismatch_penalty,
    spacer_mismatch_penalty,
    spacer_size,
    spacer_list,
    indel_dict,
    min_ideal_displacement_dict,
    anchor_list,
    score_converting_func,
    cb_size_tolerance,
    spacer_mismatch_tolerance,
    spacer_size_tolerance,
    bb_size,
):
    """
    Finds block candidates in the read sequence.

    Every read position whose k-mer matches (within tolerance) spacer ``i``
    becomes a node ``(i, pos)``. An edge is emitted between two nodes when the
    displacement is strictly positive and lies within
    ``big_steps * skip_size_tolerance`` of the ideal layout. The edge weight is
    ``score_converting_func(penalty)``. Edges spanning exactly one context block
    whose anchor base is found are also recorded in ``cb_info_dict``.

    Args:
        seq (str): The read sequence.
        phred (list): List of Phred quality scores.
        cb_bq_cutoff (float): Base quality cutoff for spacer k-mers (falsy disables it).
        spacer_kmer_ed_dict (dict): Dictionary of k-mers with edit distances for spacers.
        skip_size_tolerance (int): Allowed displacement error per big step.
        cb_pad (int): Context block padding.
        cb_per_bb (int): Number of context blocks per base block.
        indel_penalty (int): Penalty for indels.
        anchor_mismatch_penalty (int): Penalty for anchor mismatches.
        spacer_mismatch_penalty (int): Penalty for spacer mismatches.
        spacer_size (int): Size of the spacer.
        spacer_list (list): List of spacers.
        indel_dict (dict): Dictionary of integer partitions for indel tolerance.
        min_ideal_displacement_dict (dict): Dictionary of minimum ideal displacements.
        anchor_list (list): List of anchors.
        score_converting_func (typing.Callable): Function to convert penalty to score.
        cb_size_tolerance (int): Context block size tolerance.
        spacer_mismatch_tolerance (int): Tolerance for mismatches in spacers.
        spacer_size_tolerance (int): Tolerance for spacer size.
        bb_size (int): Size of the base block.

    Returns:
        tuple: ``(cb_info_dict, dag_list, dag_dict)``, with nodes being ``(spacer_idx, pos)``:
            - ``cb_info_dict``: ``{(from_node, to_node): [cb_idx, from_pos, to_pos, anchor_pos, penalty, score]}``
            - ``dag_list``: list of ``(from_node, to_node, score)`` edges
            - ``dag_dict``: ``{(from_node, to_node): score}``
    """
    kmer_pos_dict = get_kmer_dict(seq, spacer_size, cb_bq_cutoff, phred)
    dag_list = []
    dag_dict = {}
    cb_info_dict = {}
    n_spacers = len(spacer_list)

    for from_spacer_idx in range(n_spacers):
        # Spacer index cb_per_bb has no anchor, so edges starting there never become context blocks.
        single_anchor = None if from_spacer_idx == cb_per_bb else anchor_list[from_spacer_idx]
        from_spacer_kmer_ed_dict = spacer_kmer_ed_dict[from_spacer_idx]
        for to_spacer_idx in range(n_spacers):
            kmer_tuple_list = get_kmer_tuple(
                spacer_mismatch_tolerance, from_spacer_kmer_ed_dict, spacer_kmer_ed_dict[to_spacer_idx]
            )
            for from_kmer, to_kmer, kmer_mismatch in kmer_tuple_list:
                for from_pos, to_pos in it.product(kmer_pos_dict[from_kmer], kmer_pos_dict[to_kmer]):
                    displacement = to_pos - from_pos
                    # Only strictly forward edges. This guarantees the graph is acyclic: with very
                    # large skip tolerances a zero-length edge could otherwise create a self-loop
                    # or 2-cycle at one position, which nx.dag_longest_path rejects.
                    if displacement <= 0:
                        continue
                    ideal_displacement, small_steps, big_steps = get_ideal_displacement(
                        from_spacer_idx, to_spacer_idx, displacement, min_ideal_displacement_dict, cb_per_bb, bb_size
                    )
                    displacement_error = displacement - ideal_displacement
                    displacement_error_abs = abs(displacement_error)
                    # Allowed slack grows with the number of big steps spanned.
                    if displacement_error_abs > big_steps * skip_size_tolerance:
                        continue

                    anchor_pos = None
                    is_cb = False
                    missing_anchor = 1
                    if (
                        big_steps == 1
                        and small_steps == 0
                        and displacement_error_abs <= cb_size_tolerance
                        and single_anchor is not None
                    ):
                        # Exactly one context block between the spacers: look for its anchor.
                        missing_anchor, anchor_pos, total_indel = validate_anchor(
                            seq,
                            from_pos,
                            to_pos,
                            indel_dict[displacement_error],
                            spacer_size,
                            cb_pad,
                            single_anchor,
                            indel_penalty,
                            anchor_mismatch_penalty,
                            displacement_error_abs,
                        )
                        if missing_anchor == 0:
                            is_cb = True
                            # Charge the indels of the chosen anchor hypothesis instead.
                            displacement_error_abs = total_indel
                    elif big_steps == 0 and small_steps == 1 and displacement_error_abs <= spacer_size_tolerance:
                        # No context block lies between the spacers, so no anchor penalty applies.
                        missing_anchor = 0

                    penalty = (
                        spacer_mismatch_penalty * kmer_mismatch
                        + indel_penalty * displacement_error_abs
                        + anchor_mismatch_penalty * missing_anchor
                    )
                    score = score_converting_func(penalty)
                    edge = ((from_spacer_idx, from_pos), (to_spacer_idx, to_pos))
                    if is_cb:
                        cb_info_dict[edge] = [from_spacer_idx, from_pos, to_pos, anchor_pos, penalty, score]
                    dag_list.append((*edge, score))
                    dag_dict[edge] = score
    return cb_info_dict, dag_list, dag_dict
def dag_longest_path(edge_list):
    """
    Return the maximum-weight path through a weighted directed acyclic graph.

    Args:
        edge_list (list): ``(u, v, weight)`` tuples.

    Returns:
        list: Nodes along the path, as returned by ``networkx.dag_longest_path``.
    """
    sources = [edge[0] for edge in edge_list]
    targets = [edge[1] for edge in edge_list]
    graph = nx.DiGraph()
    # Register the deduplicated node set before the edges; node insertion order
    # feeds networkx's topological order and hence its tie-breaking.
    graph.add_nodes_from(list(set(sources + targets)))
    graph.add_weighted_edges_from(edge_list)
    return nx.dag_longest_path(graph, weight="weight")
def extract_blocks_from_read_list_mp_worker(
    record_list,
    indel_penalty,
    cb_size_tolerance,
    skip_size_tolerance,
    anchor_mismatch_penalty,
    spacer_size_tolerance,
    spacer_mismatch_tolerance,
    spacer_mismatch_penalty,
    cb_pad,
    cb_per_bb,
    cb_bq_cutoff,
    indel_dict,
    spacer_kmer_ed_dict,
    anchor_list,
    spacer_list,
    spacer_size,
    bb_size,
    flush_path,
    pid,
    flush_interval,
    score_converting_func,
    cb_size,
    min_ideal_displacement_dict,
    resume,
):
    """
    Worker function to extract blocks from a list of reads using multiprocessing.

    Results are flushed periodically to ``{flush_path}df_{pid}_{read_idx}.pkl``.
    ``read_idx`` is the index, within this worker's ``record_list``, of the last
    read included in that flush. At the end, all flushes are merged into
    ``{flush_path}df_{pid}.pkl``. ``flush_path`` is concatenated directly with
    the file name, so it must end with a path separator.

    Args:
        record_list (list): ``(read_id, sequence, phred_array, read_length)`` tuples.
        indel_penalty (int): Penalty for indels.
        cb_size_tolerance (int): Context block size tolerance.
        skip_size_tolerance (int): Tolerance for skip size.
        anchor_mismatch_penalty (int): Penalty for anchor mismatches.
        spacer_size_tolerance (int): Tolerance for spacer size.
        spacer_mismatch_tolerance (int): Tolerance for mismatches in spacers.
        spacer_mismatch_penalty (int): Penalty for spacer mismatches.
        cb_pad (int): Context block padding.
        cb_per_bb (int): Number of context blocks per base block.
        cb_bq_cutoff (float): Base quality cutoff for context blocks (falsy disables it).
        indel_dict (dict): Dictionary of integer partitions for indel tolerance.
        spacer_kmer_ed_dict (dict): Dictionary of k-mers with edit distances for spacers.
        anchor_list (list): List of anchors.
        spacer_list (list): List of spacers.
        spacer_size (int): Size of the spacer.
        bb_size (int): Size of the base block.
        flush_path (str): Prefix (directory with trailing separator) for flush files.
        pid (int): Worker index, used in file names.
        flush_interval (int): Interval, in reads, for flushing data to disk.
        score_converting_func (typing.Callable): Function to convert penalty to score.
        cb_size (int): Size of the context block.
        min_ideal_displacement_dict (dict): Dictionary of minimum ideal displacements.
        resume (str): Directory holding flush files of a previous run, or ``None``.
            Processing restarts at the read right after the newest flush. This
            assumes ``record_list`` is identical to the previous run's.

    Returns:
        None
    """
    start_mem_watchdog()
    len_record = len(record_list)
    block_df_list = []
    flush_file_list = []
    start_idx = 0

    def _flush_index(path):
        # File names look like "df_{pid}_{read_idx}.pkl".
        return int(os.path.basename(path).split("_")[-1].split(".")[0])

    if resume is not None:
        ## search for last flush file
        flush_file_list = sorted(glob.glob(os.path.join(resume, f"df_{pid}_*.pkl")), key=_flush_index)
        if flush_file_list:
            last_flush_idx = _flush_index(flush_file_list[-1])
            # The newest flush already contains read `last_flush_idx`; restarting *at* it
            # would re-flush to the same file name, overwriting that flush and listing it twice.
            start_idx = last_flush_idx + 1
            record_list = record_list[start_idx:]
            gc.collect()
            log.info(f"[Process-{pid}] Resuming from {start_idx}th read. {len(record_list)} reads remaining.")
        else:
            log.info(f"[Process-{pid}] No flush file found. Starting from the beginning.")

    for read_idx, record in tqdm(enumerate(record_list, start=start_idx), total=len(record_list)):
        read_id = record[0]
        seq = record[1].replace("T", "U")
        phred = record[2]

        cb_info_dict, dag_list, dag_dict = find_block_candidates(
            seq,
            phred,
            cb_bq_cutoff,
            spacer_kmer_ed_dict,
            skip_size_tolerance,
            cb_pad,
            cb_per_bb,
            indel_penalty,
            anchor_mismatch_penalty,
            spacer_mismatch_penalty,
            spacer_size,
            spacer_list,
            indel_dict,
            min_ideal_displacement_dict,
            anchor_list,
            score_converting_func,
            cb_size_tolerance,
            spacer_mismatch_tolerance,
            spacer_size_tolerance,
            bb_size,
        )

        if cb_info_dict:
            longest_path = dag_longest_path(dag_list)
            selected_cb = []
            total_score = 0
            # Every edge on the path contributes to the read score; only
            # context-block edges become output rows.
            for edge in zip(longest_path[:-1], longest_path[1:]):
                if edge in cb_info_dict:
                    selected_cb.append(cb_info_dict[edge])
                total_score += dag_dict[edge]

            if selected_cb:
                selected_cb_df = pd.DataFrame(
                    selected_cb, columns=["cb_idx", "start_pos", "end_pos", "pos_RM", "penalty", "score"]
                )
                selected_cb_df["read_id"] = read_id
                selected_cb_df["total_score"] = total_score
                # Mean quality over every spacer bounding a selected block.
                spacer_pos = np.unique(selected_cb_df[["start_pos", "end_pos"]].values.flatten())
                spacer_phred = np.concatenate([phred[x : x + spacer_size] for x in spacer_pos])
                selected_cb_df["mean_spacer_phred"] = np.mean(spacer_phred)
                # Re-center each block on its anchor: cb_pad bases on either side.
                selected_cb_df["start_pos"] = selected_cb_df["pos_RM"] - cb_pad
                selected_cb_df["end_pos"] = selected_cb_df["pos_RM"] + cb_pad + 1
                selected_cb_df["motif"] = selected_cb_df.apply(lambda x: seq[x["start_pos"] : x["end_pos"]], axis=1)
                selected_cb_df["bq"] = selected_cb_df.apply(lambda x: phred[x["start_pos"] : x["end_pos"]], axis=1)
                selected_cb_df = selected_cb_df[selected_cb_df["end_pos"] <= len(seq)]
                block_df_list.append(selected_cb_df)

        ## Periodic flush to reduce memory usage
        if (read_idx % flush_interval == 0 and read_idx != 0) or (read_idx == len_record - 1):
            if block_df_list:
                block_df_flush = pd.concat(block_df_list, axis=0).reset_index(drop=True)
                block_df_flush["bq_len"] = block_df_flush["bq"].apply(len)
                block_df_flush["motif_len"] = block_df_flush["motif"].apply(len)
                keep = (
                    (block_df_flush["start_pos"] >= 0)
                    & (block_df_flush["bq_len"] == cb_size)
                    & (block_df_flush["motif_len"] == cb_size)
                )
                # A falsy cutoff disables quality filtering, as in get_kmer_dict.
                if cb_bq_cutoff:
                    keep &= block_df_flush["mean_spacer_phred"] >= cb_bq_cutoff
                block_df_flush = block_df_flush[keep]
                flush_file = f"{flush_path}df_{pid}_{read_idx}.pkl"
                block_df_flush.to_pickle(flush_file)
                flush_file_list.append(flush_file)
                block_df_list = []
                del block_df_flush
                gc.collect()

    block_df_list = []
    if flush_file_list:
        for flush_file in flush_file_list:
            block_df_list.append(pd.read_pickle(flush_file))
        block_df = pd.concat(block_df_list, axis=0).reset_index(drop=True)
        del block_df_list
        block_df.to_pickle(f"{flush_path}df_{pid}.pkl")
        del block_df
        gc.collect()
    return None
def _penalty_to_score(penalty, max_cb_penalty):
    """Map a penalty linearly to a score: 0 -> 1.0, ``2 * max_cb_penalty`` -> 0.0."""
    return 1 - (penalty / (2 * max_cb_penalty))


def extract_block(
    input,
    output,
    indel_tolerance,
    indel_penalty,
    cb_size_tolerance,
    skip_size_tolerance,
    anchor_mismatch_penalty,
    spacer_size_tolerance,
    spacer_mismatch_tolerance,
    max_read_length,
    spacer_mismatch_penalty,
    anchor_list,
    spacer_list,
    spacer_size,
    cb_pad,
    cb_per_bb,
    read_bq_cutoff,
    cb_bq_cutoff,
    flush_path,
    flush_interval,
    ncpu,
    resume,
    sample,
    **kwargs,
):
    """
    Extracts context blocks from the reads of a BAM file using multiprocessing.

    Reads passing the quality and length filters are distributed over ``ncpu``
    worker processes. Worker results are merged and saved to
    ``{output}/block.pkl``, and a summary is written to ``{output}/dag_log.txt``.

    Args:
        input (str): Path to the input BAM file.
        output (str): Output directory.
        indel_tolerance (int): Indel tolerance (must be >= ``cb_size_tolerance``).
        indel_penalty (int): Penalty for indels.
        cb_size_tolerance (int): Context block size tolerance.
        skip_size_tolerance (int): Tolerance for skip size.
        anchor_mismatch_penalty (int): Penalty for anchor mismatches.
        spacer_size_tolerance (int): Tolerance for spacer size.
        spacer_mismatch_tolerance (int): Tolerance for mismatches in spacers.
        max_read_length (int): Maximum read length.
        spacer_mismatch_penalty (int): Penalty for spacer mismatches.
        anchor_list (list): List of anchors.
        spacer_list (list): List of spacers.
        spacer_size (int): Size of the spacer.
        cb_pad (int): Context block padding.
        cb_per_bb (int): Number of context blocks per base block.
        read_bq_cutoff (float): Minimum mean base quality for reads.
        cb_bq_cutoff (float): Base quality cutoff for context blocks.
        flush_path (str): Prefix (directory with trailing separator) for intermediate files.
        flush_interval (int): Interval for flushing data to disk.
        ncpu (int): Number of worker processes (also used for BAM decompression threads).
        resume (str): Directory with flush files of a previous run, or ``None``.
        sample (int): Number of reads to sample, or ``None`` for all.
        **kwargs: Additional arguments. ``min_read_length`` (default 0) sets the minimum read length.

    Returns:
        pandas.DataFrame: All extracted context blocks.

    Raises:
        AssertionError: If ``indel_tolerance < cb_size_tolerance``.
        ValueError: If no worker produced any result.
    """
    from functools import partial

    # Fail fast, before the expensive k-mer enumeration below.
    assert indel_tolerance >= cb_size_tolerance, "indel_tolerance must be >= cb_size_tolerance"

    spacer_list = [x.replace("T", "U") for x in spacer_list]
    anchor_list = [x.replace("T", "U") for x in anchor_list]
    indel_dict = get_integer_partition(indel_tolerance, cb_size_tolerance)
    spacer_kmer_ed_dict = {i: get_ed_kmers(kmer, spacer_mismatch_tolerance) for i, kmer in enumerate(spacer_list)}
    max_cb_penalty = (
        anchor_mismatch_penalty + spacer_mismatch_penalty * spacer_mismatch_tolerance + indel_penalty * indel_tolerance
    )
    # A module-level function bound with partial (instead of a nested closure) so the
    # callable can be pickled when multiprocessing uses the "spawn"/"forkserver" start methods.
    score_converting_func = partial(_penalty_to_score, max_cb_penalty=max_cb_penalty)
    cb_size = 2 * cb_pad + 1
    # NOTE(review): get_min_ideal_displacement_dict treats a full period (spacer i -> spacer i)
    # as cb_per_bb * (cb_size + spacer_size) + spacer_size. That differs from bb_size below by
    # cb_per_bb * spacer_size. Confirm which layout is intended.
    bb_size = cb_size * cb_per_bb + spacer_size
    min_ideal_displacement_dict = get_min_ideal_displacement_dict(cb_per_bb, spacer_size, cb_size)
    min_read_length = kwargs.get("min_read_length", 0)

    record_list = []
    with pysam.AlignmentFile(input, "rb", check_sq=False, threads=ncpu) as input_bam:
        try:
            total_reads = input_bam.mapped
        except ValueError:
            # `mapped` comes from the BAM index; unindexed BAMs cannot provide it.
            total_reads = None
        with tqdm(total=total_reads) as pbar:
            for record in input_bam:
                pbar.update(1)
                qualities = record.query_qualities
                # Records stored without SEQ/QUAL (e.g. some secondary alignments) are unusable.
                if record.query_sequence is None or qualities is None:
                    continue
                qscore = mean_phred(np.array(qualities, dtype=int))
                read_length = record.query_length
                if qscore >= read_bq_cutoff and min_read_length <= read_length <= max_read_length:
                    record_list.append(
                        (
                            str(record.query_name),
                            str(record.query_sequence),
                            np.array(qualities),
                            int(read_length),
                        )
                    )

    if sample is not None:
        # NOTE(review): unseeded sampling makes the worker split differ between runs,
        # so `resume` is only safe if the RNG is seeded elsewhere — confirm.
        n_sample = min(sample, len(record_list))
        if n_sample < sample:
            log.warning(f"Requested {sample} reads but only {len(record_list)} passed filters.")
        sample_idx = np.random.choice(len(record_list), n_sample, replace=False)
        record_list = [record_list[i] for i in sample_idx]
    record_list.sort(key=lambda x: x[3], reverse=True)
    record_cnt = len(record_list)

    # With reads sorted longest-first, deal them out in a back-and-forth order
    # (ncpu-1 .. 0, then 0 .. ncpu-1, ...), presumably to balance total work per worker.
    record_split_dict = {i: [] for i in range(ncpu)}
    for i, fastq in enumerate(record_list):
        group = int(abs((i % (2 * ncpu)) - ncpu + 0.5) - 0.5)
        record_split_dict[group].append(fastq)
    del record_list
    gc.collect()

    proc_list = []
    for pid in range(ncpu):
        proc = mp.Process(
            target=extract_blocks_from_read_list_mp_worker,
            args=(
                record_split_dict[pid],
                indel_penalty,
                cb_size_tolerance,
                skip_size_tolerance,
                anchor_mismatch_penalty,
                spacer_size_tolerance,
                spacer_mismatch_tolerance,
                spacer_mismatch_penalty,
                cb_pad,
                cb_per_bb,
                cb_bq_cutoff,
                indel_dict,
                spacer_kmer_ed_dict,
                anchor_list,
                spacer_list,
                spacer_size,
                bb_size,
                flush_path,
                pid,
                flush_interval,
                score_converting_func,
                cb_size,
                min_ideal_displacement_dict,
                resume,
            ),
        )
        proc_list.append(proc)
        proc.start()
    for pid, proc in enumerate(proc_list):
        proc.join()
        if proc.exitcode != 0:
            log.warning(f"Process-{pid} exited with code {proc.exitcode}.")

    block_df_list = []
    for pid in range(ncpu):
        try:
            block_df_list.append(pd.read_pickle(f"{flush_path}df_{pid}.pkl"))
        except Exception as e:
            # Best effort: a worker with no blocks writes no file; keep the others.
            log.warning(f"{e}")
            log.warning(f"PID {pid} did not return any result.")
    if not block_df_list:
        raise ValueError("No context blocks were extracted by any worker process.")
    block_df = pd.concat(block_df_list, axis=0).reset_index(drop=True)
    del block_df_list
    gc.collect()
    block_df.to_pickle(f"{output}/block.pkl")

    blocks_per_read = len(block_df) / record_cnt if record_cnt else 0.0
    dag_log = [
        f"Total number of passed reads: {record_cnt:,}",
        f"Total number of context blocks: {len(block_df):,}",
        f"Context blocks per read: {blocks_per_read:.2f}",
        block_df["score"].describe(),
        block_df["penalty"].describe(),
    ]
    log_path = f"{output}/dag_log.txt"
    with open(log_path, "w", encoding="utf-8") as log_file:
        for line in dag_log:
            log_file.write(f"{line}\n")
    print("=============================================")
    for line in dag_log:
        log.info(line)
    print("=============================================")
    log.info(f"Saved context blocks to {output}/block.pkl.")
    return block_df