"""
DeepRM Train DataLoader
This module provides an IterableDataset implementation for loading
chunked binary classification datasets from NPZ files.
It randomly selects positive and negative samples based on a specified class ratio.
Partially inspired by:
https://discuss.pytorch.org/t/an-iterabledataset-implementation-for-chunked-data/124437
"""
import functools
import gc
import glob
import math
import os
import numpy as np
from deeprm.utils import check_deps
from deeprm.utils.logging import get_logger
log = get_logger(__name__)
check_deps.check_torch_available()
import torch
from torch.utils.data import DataLoader, IterableDataset
class BinaryClassDatasetIterator:
    """
    Iterator that streams samples of a two-class dataset out of NPZ files.

    Class 0 is backed by ``neg_file_paths`` and class 1 by ``pos_file_paths``. For every
    sample the class is drawn at random (see ``class_ratio``). Each class owns a buffer that
    is refilled from disk whenever its current chunk has been consumed. Iteration stops once
    both classes have run out of data.

    Args:
        pos_file_paths (list): NPZ file paths of the positive class (label 1).
        neg_file_paths (list): NPZ file paths of the negative class (label 0).
        disk_shard_size (int): Size of the disk shard. Stored only; not read by the iterator.
        shuffle_buffer_size (int): Files are loaded into a class buffer until it holds at
            least this many samples (or no file is left).
        shuffle (bool): Whether to permute the buffer before samples are handed out.
        class_ratio (float): Probability of drawing the negative class (label 0) while both
            classes still have data.
        soft_label (bool): Whether to use soft labels. Stored only; not read by the iterator.
        yield_period (int): Maximum number of samples handed out after each buffer refill.
            Must not exceed ``shuffle_buffer_size``.
        batch_size (int): Batch size. Stored only; not read by the iterator.
    """

    def __init__(
        self,
        pos_file_paths,
        neg_file_paths,
        disk_shard_size,
        shuffle_buffer_size,
        shuffle=True,
        class_ratio=0.5,
        soft_label=False,
        yield_period=1,
        batch_size=1,
    ):
        ## Index 0 is the negative class, index 1 the positive class.
        self.paths = [neg_file_paths, pos_file_paths]
        self.paths_len = [len(neg_file_paths), len(pos_file_paths)]
        self.shuffle = shuffle
        self.disk_shard_size = disk_shard_size
        self.shuffle_buffer_size = shuffle_buffer_size
        self.current_class = 0
        self.avail_class = [0, 1]
        self.total_class = len(self.avail_class)
        self.len_iterator = [0 for _ in range(self.total_class)]
        self.current_df_index = [-1 for _ in range(self.total_class)]
        ## Start with real (empty) iterators so that ``next()`` is always a valid call.
        self.current_iterator = [iter(()) for _ in range(self.total_class)]
        self.class_ratio = class_ratio
        self.yield_period = yield_period
        self.batch_size = batch_size
        self.soft_label = soft_label
        self.keys = [
            "segment_len_arr",
            "signal_token",
            "kmer_token",
            "dwell_motor_token",
            "dwell_pore_token",
            "bq_token",
        ]
        ## buffer[class][key_idx] is a list of arrays waiting to be concatenated.
        self.buffer = [[[] for _ in self.keys] for _ in range(self.total_class)]
        assert (
            self.yield_period <= self.shuffle_buffer_size
        ), "Shuffle period should be less than or equal to shuffle buffer size"

    def __iter__(self):
        """
        Returns the iterator object itself.

        Returns:
            BinaryClassDatasetIterator: The iterator object.
        """
        return self

    def __next__(self):
        """
        Returns the next sample together with the class it was drawn from.

        Returns:
            tuple: ``(sample, label)`` where ``sample`` is a tuple with one entry per key in
            ``self.keys`` and ``label`` is 0 (negative) or 1 (positive).

        Raises:
            StopIteration: If neither class has data left.
        """
        ## The tuple is evaluated left to right, so current_class is read after _next() set it.
        return self._next(), self.current_class

    def _read_shuffle_data(self, first_read=False):
        """
        Refills the buffer of the current class from disk and exposes the next chunk.

        Files are loaded until the buffer holds at least ``shuffle_buffer_size`` samples or
        no file is left. Files that cannot be read are skipped with a warning. The buffer is
        then concatenated, optionally shuffled, and at most ``yield_period`` samples are
        exposed through the class iterator; the remainder stays in the buffer.

        Args:
            first_read (bool): Whether it is the first read for the current class.

        Returns:
            None
        """
        cls = self.current_class
        current_buffer = self.buffer[cls]
        len_buffer = len(current_buffer[0])
        if first_read:
            assert len_buffer == 0, "Buffer should be empty when first read"
        else:
            assert len_buffer <= 1, f"Buffer should have at most one item, but has {len_buffer}"
            if len_buffer == 1:
                ## One leftover chunk: count its samples rather than the number of chunks.
                len_buffer = len(current_buffer[0][0])

        while len_buffer < self.shuffle_buffer_size and (self.current_df_index[cls] < self.paths_len[cls] - 1):
            self.current_df_index[cls] += 1
            path = self.paths[cls][self.current_df_index[cls]]
            try:
                ## Load every key before touching the buffer: a file that fails half-way
                ## must not leave the per-key buffers with different lengths.
                with np.load(path) as npz:
                    arrays = [npz[key] for key in self.keys]
                num_loaded = len(arrays[0])
            except Exception as e:
                log.warning(f"Skipping file {path} due to error: {e}")
                continue
            for key_idx, array in enumerate(arrays):
                current_buffer[key_idx].append(array)
            len_buffer += num_loaded

        if not current_buffer[0]:
            ## Nothing could be loaded (no files, or every file was skipped). Expose an empty
            ## iterator so that _next() retires this class instead of np.concatenate failing.
            self.current_iterator[cls] = iter(())
            self.len_iterator[cls] = 0
            return None

        ## Concat and shuffle the data
        if self.shuffle:
            shuffle_idx = np.random.permutation(len_buffer)
        data_to_yield = []
        for key_idx in range(len(self.keys)):
            data = np.concatenate(current_buffer[key_idx])
            assert len_buffer == len(data), f"Buffer length {len_buffer} != data length {len(data)}"
            if self.shuffle:
                ## The same permutation is applied to every key to keep samples aligned.
                data = data[shuffle_idx]
            if len_buffer > self.yield_period:
                data_to_yield.append(data[: self.yield_period])
                current_buffer[key_idx] = [data[self.yield_period :]]
            else:
                data_to_yield.append(data)
                current_buffer[key_idx] = []
        self._set_iterator(data_to_yield)
        gc.collect()
        return None

    def _exhaust_buffer(self):
        """
        Exposes everything that is left in the buffer of the current class and empties it.

        Returns:
            None
        """
        current_buffer = self.buffer[self.current_class]
        assert (
            len(current_buffer[0]) == 1
        ), f"Buffer to be exhausted should have exactly one item, but has {len(current_buffer[0])}"
        data_to_yield = [chunks[0] for chunks in current_buffer]
        self._set_iterator(data_to_yield)
        for key_idx in range(len(current_buffer)):
            current_buffer[key_idx] = []
        gc.collect()
        return None

    def _set_iterator(self, data_to_yield):
        """
        Installs the iterator of the current class over the given per-key arrays.

        Args:
            data_to_yield (list): One array per key, all of the same length.

        Returns:
            None
        """
        lengths = [len(data) for data in data_to_yield]
        assert len(set(lengths)) == 1, f"Data lengths are not equal: {lengths}"
        self.current_iterator[self.current_class] = zip(*data_to_yield)
        self.len_iterator[self.current_class] = lengths[0]
        gc.collect()
        return None

    def _get_rand_class(self):
        """
        Randomly selects a class based on the class ratio.

        Returns:
            int: 0 (negative) with probability ``class_ratio``, otherwise 1 (positive).
        """
        rand = torch.rand(1).item()
        if rand < self.class_ratio:
            return 0
        return 1

    def _next(self):
        """
        Picks a class, refills its chunk if necessary and returns its next sample.

        Returns:
            tuple: The next sample (one entry per key in ``self.keys``).

        Raises:
            StopIteration: If neither class has data left.
        """
        ## Randomly decide between positive and negative data
        if len(self.avail_class) == 0:  ## No more data to read in both classes
            raise StopIteration
        elif len(self.avail_class) == 1:  ## Only one class available
            self.current_class = self.avail_class[0]
        else:
            self.current_class = self._get_rand_class()

        cls = self.current_class
        if self.current_df_index[cls] == -1:  ## First time reading data
            self._read_shuffle_data(first_read=True)
        elif self.len_iterator[cls] == 0:  ## Current iterator ran out of data
            if self.current_df_index[cls] < self.paths_len[cls] - 1:  ## Still have data to read from disk
                self._read_shuffle_data(first_read=False)
            elif len(self.buffer[cls][0]) > 0:  ## No data to read from disk, but buffer has data
                self._exhaust_buffer()
            ## Otherwise: nothing on disk and nothing buffered; next() below raises StopIteration.

        try:  ## Check if the current iterator has data to process
            result = next(self.current_iterator[cls])
            self.len_iterator[cls] -= 1
        except StopIteration:  ## Current iterator ran out of data, and no more data to read.
            self.avail_class.remove(cls)
            result = self._next()
        return result
## END of BinaryClassDatasetIterator
class NanoporeDataset(IterableDataset):
    """
    Iterable dataset for loading Nanopore data from NPZ files.

    The file lists are split into disjoint per-replica subsets, where a replica is one
    DataLoader worker on one GPU. Each call to ``iter()`` builds a
    ``BinaryClassDatasetIterator`` over the subset that belongs to the calling replica.

    Args:
        pos_data_path (list): Paths of the NPZ files holding positive samples.
        neg_data_path (list): Paths of the NPZ files holding negative samples.
        batch_size (int): Batch size for loading data.
        disk_shard_size (int): Size of the disk shard; used to compute the dataset length.
        rank (int): Rank of the current process.
        num_replicas (int): Number of replicas.
        shuffle_buffer_size (int): Size of the shuffle buffer.
        yield_period (int): Period for yielding data. When None, ``batch_size * 25`` is used
            at iteration time.
        seed (int): Random seed for the file-level shuffle.
        shuffle (bool): Whether to shuffle the data.
        drop_last (bool): If True the per-replica shard count is rounded down, otherwise up.
        class_ratio (float): Number of negative shards used per positive shard (neg:pos).
        soft_label (bool): Whether to use soft labels.
    """

    def __init__(
        self,
        pos_data_path,
        neg_data_path,
        batch_size,
        disk_shard_size,
        rank,
        num_replicas,
        shuffle_buffer_size,
        yield_period=None,
        seed=0,
        shuffle=True,
        drop_last=True,
        class_ratio=1,
        soft_label=False,
    ):
        ## ``super(NanoporeDataset).__init__()`` only built an unbound super object and never
        ## reached the parent initializer; the zero-argument form does.
        super().__init__()
        self.pos_file_paths = pos_data_path
        self.neg_file_paths = neg_data_path
        self.batch_size = batch_size
        self.disk_shard_size = disk_shard_size
        self.rank = rank
        self.num_replicas = num_replicas
        ## Keep the process-level values; reinit() derives the worker-level ones from them.
        self.rank_gpu = rank
        self.num_gpu = num_replicas
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0
        self.seed = seed
        self.shuffle_buffer_size = shuffle_buffer_size
        self.class_ratio = class_ratio
        self.soft_label = soft_label
        self.yield_period = yield_period
        self._update_shard_counts()

    def _update_shard_counts(self):
        """
        Recomputes the shard counts and dataset sizes from the current ``num_replicas``.

        Returns:
            None
        """
        round_fn = math.floor if self.drop_last else math.ceil
        pos_num_shard = round_fn(len(self.pos_file_paths) / self.num_replicas)
        neg_num_shard = round_fn(len(self.neg_file_paths) / self.num_replicas)
        ## Cap the positives so that class_ratio negatives per positive are available.
        self.pos_num_shard = min(pos_num_shard, int(neg_num_shard / self.class_ratio))
        self.neg_num_shard = int(self.pos_num_shard * self.class_ratio)
        self.pos_total_num_shard = self.pos_num_shard * self.num_replicas
        self.pos_dataset_size = self.pos_total_num_shard * self.disk_shard_size
        self.neg_total_num_shard = self.neg_num_shard * self.num_replicas
        self.neg_dataset_size = self.neg_total_num_shard * self.disk_shard_size
        self.dataset_size = self.pos_dataset_size + self.neg_dataset_size
        return None

    def reinit(self):
        """
        Reinitializes the dataset using worker information.

        Inside a DataLoader worker the rank and replica count are expanded so that every
        worker of every GPU is a replica of its own; the shard counts are then recomputed.

        Returns:
            None
        """
        ## After replicating the dataset, reinitialize the dataset using worker_info
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            self.rank = worker_info.id + self.rank_gpu * worker_info.num_workers
            self.num_replicas = worker_info.num_workers * self.num_gpu
        self._update_shard_counts()

    def __len__(self):
        """
        Returns the length of the dataset.

        Returns:
            int: Number of samples over all replicas, assuming ``disk_shard_size`` samples per file.
        """
        return self.dataset_size

    def __iter__(self):
        """
        Returns an iterator over the files assigned to the current replica.

        Returns:
            BinaryClassDatasetIterator: Iterator for the dataset.
        """
        self.reinit()
        pos_file_paths = self._deterministic_shuffle_and_sample(
            self.pos_file_paths, self.pos_num_shard, self.pos_total_num_shard
        )
        neg_file_paths = self._deterministic_shuffle_and_sample(
            self.neg_file_paths, self.neg_num_shard, self.neg_total_num_shard
        )
        ## The iterator expects the probability of drawing a negative sample, not neg:pos.
        class_ratio_iterator = self.class_ratio / (1 + self.class_ratio)
        ## Same fallback as load_dataset(); passing None on would fail the iterator's size check.
        yield_period = self.yield_period if self.yield_period is not None else self.batch_size * 25
        return BinaryClassDatasetIterator(
            pos_file_paths,
            neg_file_paths,
            disk_shard_size=self.disk_shard_size,
            shuffle_buffer_size=self.shuffle_buffer_size,
            shuffle=self.shuffle,
            class_ratio=class_ratio_iterator,
            soft_label=self.soft_label,
            yield_period=yield_period,
            batch_size=self.batch_size,
        )

    def set_epoch(self, epoch: int) -> None:
        """
        Sets the epoch for the dataset; it is added to the seed of the file-level shuffle.

        Args:
            epoch (int): The epoch number.

        Returns:
            None
        """
        self.epoch = epoch
        return None

    def _deterministic_shuffle_and_sample(self, data_path_list, num_shard, total_num_shard):
        """
        Deterministically shuffles the data paths and picks the ones of the current replica.

        Args:
            data_path_list (list): List of data paths.
            num_shard (int): Number of paths to return for this replica.
            total_num_shard (int): Number of paths distributed over all replicas.

        Returns:
            list: The paths assigned to ``self.rank``.
        """
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(data_path_list), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(data_path_list)))  # type: ignore[arg-type]

        if len(indices) < total_num_shard:
            # add extra samples to make it evenly divisible (repeat the list as often as needed)
            padding_size = total_num_shard - len(indices)
            indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        elif len(indices) > total_num_shard:
            # remove tail of data to make it evenly divisible.
            indices = indices[:total_num_shard]
        assert len(indices) == total_num_shard, f"{len(indices)} != {total_num_shard}"

        # every replica takes each num_replicas-th entry, starting at its own rank
        indices = indices[self.rank : total_num_shard : self.num_replicas][:num_shard]
        return [data_path_list[i] for i in indices]
## END of NanoporeDataset
class NanoporeDataLoader(DataLoader):
    """
    DataLoader for loading Nanopore data.

    Shuffling and sampling are left to the dataset, so the loader is always created with
    ``shuffle=False`` and no sampler, and with persistent workers.

    Args:
        dataset (NanoporeDataset): The dataset to load data from.
        batch_size (int): Batch size for loading data.
        num_workers (int): Number of worker processes.
        pin_memory (bool): Whether to pin memory.
        drop_last (bool): Whether to drop the last incomplete batch.
        collate_fn (typing.Callable): Function to collate data into batches.
        prefetch_factor (int): Number of batches to prefetch.
    """

    def __init__(
        self, dataset: "NanoporeDataset", batch_size, num_workers, pin_memory, drop_last, collate_fn, prefetch_factor
    ):
        shuffle = False
        sampler = None
        self.dataset = dataset
        super().__init__(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
            shuffle=shuffle,
            sampler=sampler,
            collate_fn=collate_fn,
            prefetch_factor=prefetch_factor,
            persistent_workers=True,
        )

    def __len__(self):
        """
        Returns the length of the DataLoader.

        The shard arithmetic of the dataset is repeated here with
        ``dataset.num_gpu * num_workers`` replicas.

        Returns:
            int: Number of full batches per worker, multiplied by the number of workers.
        """
        dataset = self.dataset
        # Total replicas = GPUs (num_gpu) * DataLoader workers
        num_replicas = dataset.num_gpu * self.num_workers
        round_fn = math.floor if dataset.drop_last else math.ceil
        pos_num_shard = round_fn(len(dataset.pos_file_paths) / num_replicas)
        neg_num_shard = round_fn(len(dataset.neg_file_paths) / num_replicas)
        pos_num_shard = min(pos_num_shard, int(neg_num_shard / dataset.class_ratio))
        neg_num_shard = int(pos_num_shard * dataset.class_ratio)
        per_rank_samples = (pos_num_shard + neg_num_shard) * dataset.disk_shard_size
        # NOTE: the attribute was misspelled as ``num_workersl``, so len() always raised AttributeError.
        return per_rank_samples // self.batch_size * self.num_workers

    def set_epoch(self, epoch: int) -> None:
        """
        Sets the epoch for the DataLoader by forwarding it to the dataset.

        Args:
            epoch (int): The epoch number.

        Returns:
            None
        """
        self.dataset.set_epoch(epoch)
        return None
## END of NanoporeDataLoader
def load_dataset(
    pos_data_path,
    neg_data_path,
    batch_size,
    disk_shard_size,
    rank,
    num_replicas,
    shuffle_buffer_size,
    yield_period,
    seed=0,
    shuffle=True,
    drop_last=True,
    pad_to=200,
    class_ratio=1,
    prefetch_factor=512,
    pin_memory=True,
    soft_label=False,
    num_workers=4,
    signal_stride=6,
    kmer_size=5,
    **kwargs,
):
    """
    Loads the Nanopore dataset using DataLoader.

    Args:
        pos_data_path (str): Path to the directory containing positive samples (``*.npz``).
        neg_data_path (str): Path to the directory containing negative samples (``*.npz``).
        batch_size (int): Batch size for loading data.
        disk_shard_size (int): Size of the disk shard.
        rank (int): Rank of the current process.
        num_replicas (int): Number of replicas.
        shuffle_buffer_size (int): Size of the shuffle buffer.
        yield_period (int): Period for yielding data. When None, ``batch_size * 25`` is used.
        seed (int): Random seed. Defaults to 0. (optional)
        shuffle (bool): Whether to shuffle the data. Defaults to True. (optional)
        drop_last (bool): Whether to drop the last incomplete batch. Defaults to True. (optional)
        pad_to (int): Padding length for sequences. Defaults to 200. (optional)
        class_ratio (float): Neg:Pos ratio. When None it is set to the ratio of the number of
            negative files to the number of positive files. Defaults to 1. (optional)
        prefetch_factor (int): Number of batches to prefetch. Defaults to 512. (optional)
        pin_memory (bool): Whether to pin memory. Defaults to True. (optional)
        soft_label (bool): Whether to use soft labels. Defaults to False. (optional)
        num_workers (int): Number of worker processes. Reduced when there are fewer files
            than ``num_workers * num_replicas``. Defaults to 4. (optional)
        signal_stride (int): Signal stride. Defaults to 6. (optional)
        kmer_size (int): K-mer size. Defaults to 5. (optional)
        **kwargs: Additional keyword arguments; accepted and ignored. (optional)

    Returns:
        NanoporeDataLoader: DataLoader for loading the dataset.

    Raises:
        ValueError: If either class has fewer files than ``num_replicas``.
    """
    pad_collate_func = functools.partial(pad_collate, pad_to=pad_to, signal_stride=signal_stride, kmer_size=kmer_size)

    ## glob returns files in arbitrary order; sort so that every rank sees the same list,
    ## which the per-rank shard selection of the dataset depends on.
    pos_data_paths = sorted(glob.glob(os.path.join(pos_data_path, "*.npz")))
    neg_data_paths = sorted(glob.glob(os.path.join(neg_data_path, "*.npz")))

    ## Validate the file counts first: the ratio below divides by the number of positive files.
    min_len = min(len(pos_data_paths), len(neg_data_paths))
    if min_len < num_replicas:
        log.error(f"The number of datapoints is smaller than the number of GPUs: {min_len} < {num_replicas}")
        raise ValueError(f"The number of datapoints is smaller than the number of GPUs: {min_len} < {num_replicas}")

    if class_ratio is None:
        class_ratio = len(neg_data_paths) / len(pos_data_paths)
        if rank == 0:
            log.info(f"Neg:Pos = {class_ratio:.3f}:1")

    if yield_period is None:
        yield_period = batch_size * 25

    if min_len < num_workers * num_replicas:
        ## Not enough files for one per worker: shrink the worker count per process.
        new_max_workers = min_len // num_replicas
        log.warning(
            f"The number of datapoints is smaller than the number of workers: {min_len} < {num_workers * num_replicas}"
        )
        log.warning(f"Setting number of workers to {new_max_workers * num_replicas}")
        num_workers = new_max_workers

    ## Use DataLoader to load the dataset
    dataset = NanoporeDataset(
        pos_data_paths,
        neg_data_paths,
        batch_size,
        disk_shard_size,
        rank,
        num_replicas,
        shuffle_buffer_size,
        yield_period,
        seed,
        shuffle,
        drop_last,
        class_ratio=class_ratio,
        soft_label=soft_label,
    )
    dataloader = NanoporeDataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        collate_fn=pad_collate_func,
        prefetch_factor=prefetch_factor,
    )
    return dataloader
def pad_collate(batch, pad_to, signal_stride, kmer_size, trim=2):
    """
    Collate function for DataLoader: stacks per-sample arrays into batch-first tensors.

    Every element of ``batch`` is a ``(source, target)`` pair whose ``source`` is indexed in
    the order ``segment_len_arr, signal_token, kmer_token, dwell_motor_token,
    dwell_pore_token, bq_token``.

    Args:
        batch (list): List of samples in the batch.
        pad_to (int): Padding length for sequences. Accepted but not used here.
        signal_stride (int): Signal stride. Accepted but not used here.
        kmer_size (int): K-mer size. Accepted but not used here.
        trim (int): Trim length. Defaults to 2. Accepted but not used here. (optional)

    Returns:
        tuple: ``(source, target)``. ``source`` is a dict with the tensors ``kmer_token``
        (int32), ``segment_len`` (int32), ``signal_token`` (float32) and ``dwell_bq_token``
        (float32; motor dwell, pore dwell and base quality stacked on a new last axis).
        ``target`` is a float32 tensor of the labels.
    """
    # Positions of the individual features inside a sample's ``source`` tuple.
    seg_len_pos, signal_pos, kmer_pos, motor_pos, pore_pos, bq_pos = range(6)

    labels = []
    columns = [[] for _ in range(6)]
    for features, label in batch:
        labels.append(label)
        for position, column in enumerate(columns):
            column.append(features[position])

    target = torch.tensor(labels, dtype=torch.float32)
    kmer = torch.tensor(np.stack(columns[kmer_pos]), dtype=torch.int32)
    seg_len = torch.tensor(np.stack(columns[seg_len_pos]), dtype=torch.int32)
    # The three per-position scalar features share one tensor, stacked on the last axis.
    dwell_bq_parts = [np.stack(columns[position]) for position in (motor_pos, pore_pos, bq_pos)]
    signal = torch.tensor(np.stack(columns[signal_pos]), dtype=torch.float32)
    dwell_bq = torch.tensor(np.stack(dwell_bq_parts, axis=-1), dtype=torch.float32)

    source = {
        "kmer_token": kmer,
        "segment_len": seg_len,
        "signal_token": signal,
        "dwell_bq_token": dwell_bq,
    }
    return source, target