Coverage for bloodhound/data.py : 36.84%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import random
2import os
3import numpy as np
4from collections import defaultdict
5import torch
6from torch.utils.data import DataLoader
7from torch.utils.data import Dataset
8import lightning as L
9from dataclasses import dataclass, field
10from corgi.seqtree import SeqTree
# Taxonomic rank names, listed in order from phylum down to species.
RANKS = [
    "phylum",
    "class",
    "order",
    "family",
    "genus",
    "species",
]
def read_memmap(path, count, dtype:str="float16") -> np.memmap:
    """
    Open a flat binary file read-only as a two-dimensional memory map.

    The file is viewed as ``count`` rows of ``dtype`` values. The row width is
    derived from the file size with integer division, so trailing bytes that
    do not fill a whole row are simply not mapped.

    Args:
        path: Location of the binary file on disk.
        count: Number of rows the file holds.
        dtype: NumPy dtype name of the stored values.

    Returns:
        A read-only ``np.memmap`` of shape ``(count, row_width)``.

    Raises:
        ZeroDivisionError: If ``count`` is zero.
    """
    total_bytes = os.path.getsize(path)
    bytes_per_value = np.dtype(dtype).itemsize
    # Whole values in the file, then whole values per row.
    row_width = (total_bytes // bytes_per_value) // count
    return np.memmap(path, dtype=dtype, mode='r', shape=(count, row_width))
def gene_id_from_accession(accession:str):
    """
    Return the text after the last ``/`` in ``accession``.

    That trailing component is taken to be the gene ID. If there is no ``/``
    the whole string is returned unchanged.
    """
    _, _, gene_id = accession.rpartition("/")
    return gene_id
def choose_k_from_n(lst, k) -> list[int]:
    """
    Pick ``k`` items from ``lst``, repeating the collection when ``k`` exceeds its size.

    The items are emitted ``k // n`` times in their original order, followed by
    ``k % n`` distinct items drawn at random (via the global ``random`` module)
    to make up the difference.

    Args:
        lst: Items to choose from. Any sized iterable is accepted (list, tuple,
            range, NumPy array, ...).
        k: Number of items to return.

    Returns:
        A list of ``k`` items, or an empty list when ``lst`` is empty.
    """
    # Materialise as a real list first: ``*`` and ``+`` below must mean
    # repetition/concatenation. A tuple or range would raise TypeError when
    # concatenated with a list, and a NumPy array would be multiplied
    # element-wise instead of repeated.
    pool = list(lst)
    n = len(pool)
    if n == 0:
        return []

    repetitions, remainder = divmod(k, n)
    return pool * repetitions + random.sample(pool, remainder)
38@dataclass(kw_only=True)
39class BloodhoundStack():
40 species:str
41 array_indices:np.array
@dataclass(kw_only=True)
class BloodhoundPredictionDataset(Dataset):
    """
    Prediction dataset that serves, per species, stacks of embedding rows.

    Accessions are grouped by the text before their last ``/`` (treated as the
    species). For each species the row indices are arranged into stacks of
    ``seq_count`` indices, with the intent that every index is used at least
    ``repeats`` times; a stack that straddles two passes is topped up with
    randomly drawn indices.

    Attributes:
        array: Embedding matrix indexed by row.
            NOTE(review): assumed to be 2-D (rows x embedding width) - confirm.
        accessions: One accession per row of ``array``, each containing a ``/``.
        seq_count: Number of row indices in every stack.
        repeats: Number of passes over each species' indices.
        seed: Seed applied to the global ``random`` module before stacks are built.
        stacks: One ``BloodhoundStack`` per species, built in ``__post_init__``.
    """
    array:np.memmap|np.ndarray
    accessions: list[str]
    seq_count:int
    repeats:int = 2
    seed:int = 42
    stacks: list[BloodhoundStack] = field(init=False)

    def __post_init__(self):
        """
        Group row indices by species and build the stacks.

        Raises:
            ValueError: If ``seq_count`` is not positive, or a species has fewer
                rows than ``seq_count``, when stacks have to be drawn.
        """
        species_to_array_indices = defaultdict(set)
        for index, accession in enumerate(self.accessions):
            slash_position = accession.rfind("/")
            assert slash_position != -1, f"Accession {accession!r} does not contain '/'"
            species = accession[:slash_position]
            species_to_array_indices[species].add(index)

        # Build stacks. This reseeds the *global* random module (kept as-is so
        # that existing callers see the same random stream).
        random.seed(self.seed)
        self.stacks = []
        for species, species_array_indices in species_to_array_indices.items():
            stack_indices = []
            remainder = []
            # Loop-invariant copy of this species' indices.
            all_indices = set(species_array_indices)

            for repeat_index in range(self.repeats + 1):
                # The extra final pass exists only to flush a leftover remainder.
                if len(remainder) == 0 and repeat_index >= self.repeats:
                    break

                # Finish the remainder: top it up to a full stack with indices
                # that are not already part of it.
                available = all_indices - set(remainder)
                needed = self.seq_count - len(remainder)
                if self.seq_count <= 0 or needed > len(available):
                    # seq_count == 0 would otherwise loop forever below, and an
                    # undersized species would fail inside random.sample with an
                    # unhelpful message.
                    raise ValueError(
                        f"Cannot build stacks of {self.seq_count} sequences for species "
                        f"{species!r}, which has {len(all_indices)} sequence(s); "
                        "seq_count must be between 1 and the number of sequences per species"
                    )
                to_add = random.sample(list(available), needed)
                to_add_set = set(to_add)
                assert not set(remainder) & to_add_set, "remainder and to_add should be disjoint"

                stack_indices.append(sorted(remainder + to_add))

                # Everything except the top-up indices still needs to be used in this pass.
                remainder = list(all_indices - to_add_set)
                random.shuffle(remainder)

                # If we have already added each item the required number of times, then stop
                if repeat_index >= self.repeats:
                    break

                # Emit as many full stacks as the shuffled remainder allows.
                while len(remainder) >= self.seq_count:
                    stack_indices.append(remainder[:self.seq_count])
                    remainder = remainder[self.seq_count:]

            self.stacks.append(BloodhoundStack(species=species, array_indices=stack_indices))

    def __len__(self):
        """Return the number of species, i.e. the number of entries in ``stacks``."""
        return len(self.stacks)

    def __getitem__(self, idx):
        """
        Return the embeddings for every stack of the ``idx``-th species.

        Returns:
            A ``torch.float16`` tensor gathered from ``array`` with the species'
            list of index lists.
            NOTE(review): for a 2-D ``array`` this should have shape
            (stacks, seq_count, embedding width) - confirm.
        """
        array_indices = self.stacks[idx].array_indices

        assert len(array_indices) > 0, "Stack has no array indices"
        with torch.no_grad():
            # np.asarray avoids the NumPy 2.0 ``copy=False`` semantics, which
            # raise when a copy cannot be avoided.
            data = np.asarray(self.array[array_indices, :])
            embeddings = torch.tensor(data, dtype=torch.float16)
            del data

        return embeddings
@dataclass(kw_only=True)
class BloodhoundTrainingDataset(Dataset):
    """
    Training dataset yielding ``(embedding, node_id)`` pairs, one per accession.

    Attributes:
        accessions: Accessions served by this dataset, in index order.
        seqtree: Mapping from accession to a record exposing ``node_id``.
        array: Embedding matrix indexed by row.
        gene_id_dict: Mapping from gene ID to an integer. It is stored but not
            read by ``__getitem__``.
        accession_to_array_index: Optional mapping from accession to the list
            of rows of ``array`` that belong to it. When it is missing or
            empty, item ``idx`` uses row ``idx`` of ``array``.
        seq_count: If non-zero, the number of rows selected per item with
            ``choose_k_from_n`` (rows are repeated when there are fewer).
    """
    accessions: list[str]
    seqtree: SeqTree
    array:np.memmap|np.ndarray
    gene_id_dict: dict[str, int]
    accession_to_array_index:dict[str,list[int]]|None=None
    seq_count:int = 0

    def __len__(self):
        """Return the number of accessions."""
        return len(self.accessions)

    def __getitem__(self, idx):
        """
        Return the embedding rows and tree node ID for the ``idx``-th accession.

        Returns:
            A tuple ``(embedding, node_id)`` where ``embedding`` is a
            ``torch.float16`` tensor of the selected rows and ``node_id`` is an int.
        """
        accession = self.accessions[idx]
        if self.accession_to_array_index:
            array_indices = self.accession_to_array_index[accession]
        else:
            # Positional fallback. It must be a list: a bare int has no len()
            # and cannot be passed to choose_k_from_n.
            array_indices = [idx]

        if self.seq_count:
            array_indices = choose_k_from_n(array_indices, self.seq_count)

        assert len(array_indices) > 0, f"Accession {accession} has no array indices"
        with torch.no_grad():
            # np.asarray avoids the NumPy 2.0 ``copy=False`` semantics, which
            # raise when a copy cannot be avoided.
            data = np.asarray(self.array[array_indices, :])
            embedding = torch.tensor(data, dtype=torch.float16)
            del data

        seq_detail = self.seqtree[accession]
        node_id = int(seq_detail.node_id)
        del seq_detail

        return embedding, node_id
144# @dataclass(kw_only=True)
145# class BloodhoundPredictionDataset(Dataset):
146# embeddings: list[torch.Tensor]
147# gene_family_ids: list[int]
149# def __post_init__(self):
150# assert len(self.embeddings) == len(self.gene_family_ids)
152# def __len__(self):
153# return len(self.gene_family_ids)
155# def __getitem__(self, idx):
156# return self.embeddings[idx] #, self.gene_family_ids[idx]
@dataclass
class BloodhoundDataModule(L.LightningDataModule):
    """
    Lightning data module that splits a SeqTree into training and validation sets.

    Accessions are assigned by their ``partition`` value: the test partition is
    skipped, the validation partition feeds the validation set and everything
    else feeds the training set.

    The field declarations below describe the stored attributes. Instances are
    constructed through the hand-written ``__init__``, which ``@dataclass``
    leaves in place.
    """
    seqtree: SeqTree
    # seqbank: SeqBank
    array:np.memmap|np.ndarray
    accession_to_array_index:dict[str,list[int]]
    gene_id_dict: dict[str,int]
    max_items: int = 0
    batch_size: int = 16
    num_workers: int = 0
    validation_partition:int = 0
    test_partition:int = -1
    train_all:bool = False

    def __init__(
        self,
        seqtree: SeqTree,
        array:np.memmap|np.ndarray,
        accession_to_array_index:dict[str,list[int]],
        gene_id_dict: dict[str,int],
        max_items: int = 0,
        batch_size: int = 16,
        num_workers: int|None = None,
        validation_partition:int = 0,
        test_partition:int=-1,
        seq_count:int=0,
        train_all:bool=False,
    ):
        """
        Store the configuration used by ``setup`` and the dataloaders.

        Args:
            seqtree: Mapping from accession to a record with ``partition`` (and ``node_id``).
            array: Embedding matrix shared by both datasets.
            accession_to_array_index: Mapping from accession to its rows in ``array``.
            gene_id_dict: Mapping from gene ID to an integer, forwarded to the datasets.
            max_items: If non-zero, stop collecting accessions once this many
                training items (and at least one validation item) were found.
            batch_size: Batch size for both dataloaders.
            num_workers: Dataloader worker count; ``None`` selects the CPU
                count capped at 8.
            validation_partition: Partition used for validation.
            test_partition: Partition that is skipped entirely.
            seq_count: Forwarded to the datasets as their ``seq_count``.
            train_all: If true, validation accessions are also added to the training set.
        """
        super().__init__()
        self.array = array
        self.accession_to_array_index = accession_to_array_index
        self.seqtree = seqtree
        self.gene_id_dict = gene_id_dict
        self.max_items = max_items
        self.batch_size = batch_size
        self.validation_partition = validation_partition
        self.test_partition = test_partition
        if num_workers is None:
            # os.cpu_count() may return None when the count cannot be determined.
            num_workers = min(os.cpu_count() or 1, 8)
        self.num_workers = num_workers
        self.seq_count = seq_count
        self.train_all = train_all

    def setup(self, stage=None):
        """Split the accessions into training/validation lists and build both datasets."""
        # make assignments here (val/train/test split)
        # called on every process in DDP
        self.training = []
        self.validation = []

        for accession, details in self.seqtree.items():
            partition = details.partition
            if partition == self.test_partition:
                continue

            dataset = self.validation if partition == self.validation_partition else self.training
            dataset.append(accession)

            # Optional cap: stop early once enough training items were seen,
            # but only after the validation set is non-empty.
            if self.max_items and len(self.training) >= self.max_items and len(self.validation) > 0:
                break

        if self.train_all:
            self.training += self.validation

        self.train_dataset = self.create_dataset(self.training)
        self.val_dataset = self.create_dataset(self.validation)

    def create_dataset(self, accessions:list[str]) -> BloodhoundTrainingDataset:
        """Build a ``BloodhoundTrainingDataset`` over ``accessions`` using this module's shared data."""
        return BloodhoundTrainingDataset(
            accessions=accessions,
            seqtree=self.seqtree,
            array=self.array,
            accession_to_array_index=self.accession_to_array_index,
            gene_id_dict=self.gene_id_dict,
            seq_count=self.seq_count,
        )

    def train_dataloader(self):
        """Return the shuffled dataloader over the training dataset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation dataset."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)