Coverage for bloodhound/apps.py : 24.12%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import torch
2import numpy as np
3from pathlib import Path
4from torch import nn
5import lightning as L
6from torchmetrics import Metric
7from hierarchicalsoftmax.metrics import RankAccuracyTorchMetric
8from corgi.seqtree import SeqTree
9from hierarchicalsoftmax import HierarchicalSoftmaxLoss, SoftmaxNode
10from torch.utils.data import DataLoader
11from collections.abc import Iterable
12from rich.console import Console
14from bloodhound.markers import extract_single_copy_markers
15from collections import defaultdict
16from rich.progress import track
18import pandas as pd
19from hierarchicalsoftmax.inference import node_probabilities, greedy_predictions, render_probabilities
21from torchapp import Param, method, TorchApp
22from .models import BloodhoundModel
23from .data import read_memmap, RANKS, BloodhoundDataModule, BloodhoundPredictionDataset
24from .embeddings.esm import ESMEmbedding
# Module-level rich console; used further down to report where result files are written.
console = Console()
30class Bloodhound(TorchApp):
@method
def setup(
    self,
    memmap: str = None,
    memmap_index: str = None,
    seqtree: str = None,
    in_memory: bool = False,
    tip_alpha: float = None,
) -> None:
    """
    Load the training inputs and store them on the app instance.

    Sets ``self.seqtree``, ``self.accession_to_array_index``, ``self.array``,
    ``self.classification_tree`` and ``self.gene_id_dict``.

    Args:
        memmap: Path to the memory-mapped array, read with ``read_memmap``.
        memmap_index: Path to a text file with one key per line. The position of a
            key in the file is used as its row index in the array.
        seqtree: Path to a saved ``SeqTree``.
        in_memory: If True, copy the memory-mapped array into a regular NumPy array.
        tip_alpha: If truthy, assigned to ``alpha`` on the parent of every leaf of the
            classification tree.

    Raises:
        ValueError: If ``seqtree``, ``memmap`` or ``memmap_index`` is missing, or if the
            index file contains no lines.
    """
    if not seqtree:
        raise ValueError("seqtree is required")
    if not memmap:
        raise ValueError("memmap is required")
    if not memmap_index:
        raise ValueError("memmap_index is required")

    print(f"Loading seqtree {seqtree}")
    individual_seqtree = SeqTree.load(seqtree)
    self.seqtree = SeqTree(classification_tree=individual_seqtree.classification_tree)

    # Sets the loss weighting for the tips
    if tip_alpha:
        for tip in self.seqtree.classification_tree.leaves:
            tip.parent.alpha = tip_alpha

    print("Loading memmap")
    self.accession_to_array_index = defaultdict(list)
    count = 0  # stays 0 when the index file is empty, instead of being left unbound
    with open(memmap_index) as f:
        for key_index, key in enumerate(f):
            key = key.strip()
            # Keys look like "<accession>/<rest>"; everything before the first slash
            # identifies the accession that the rows are grouped under.
            accession = key.split("/")[0]

            # The first key seen for an accession supplies its entry in the new seqtree.
            if not self.accession_to_array_index[accession]:
                self.seqtree[accession] = individual_seqtree[key]

            self.accession_to_array_index[accession].append(key_index)
            count = key_index + 1

    if count == 0:
        raise ValueError(f"memmap_index '{memmap_index}' does not contain any keys")

    self.array = read_memmap(memmap, count)

    # If there's enough memory, then read into RAM
    if in_memory:
        self.array = np.array(self.array)

    self.classification_tree = self.seqtree.classification_tree
    assert self.classification_tree is not None

    # Get list of gene families
    # NOTE(review): self.seqtree is keyed by the accession (the part before the first
    # "/"), so taking the last "/"-separated component here yields the accession itself
    # rather than a gene family id. Confirm whether the ids should instead come from
    # the keys of the memmap index.
    family_ids = {accession.split("/")[-1] for accession in self.seqtree}

    self.gene_id_dict = {family_id: index for index, family_id in enumerate(sorted(family_ids))}
@method
def model(
    self,
    features: int = 768,
    intermediate_layers: int = 2,
    growth_factor: float = 2.0,
    attention_size: int = 512,
) -> nn.Module:
    """
    Construct the ``BloodhoundModel`` for the loaded classification tree.

    All arguments are forwarded unchanged to ``BloodhoundModel`` together with
    ``self.classification_tree``.
    """
    architecture = dict(
        features=features,
        intermediate_layers=intermediate_layers,
        growth_factor=growth_factor,
        attention_size=attention_size,
    )
    return BloodhoundModel(classification_tree=self.classification_tree, **architecture)
101 # @method
102 # def input_count(self) -> int:
103 # return 1
@method
def loss_function(self):
    """Return a ``HierarchicalSoftmaxLoss`` rooted at ``self.classification_tree``."""
    tree_root = self.classification_tree
    return HierarchicalSoftmaxLoss(root=tree_root)
@method
def metrics(self) -> list[tuple[str, Metric]]:
    """
    Return the metrics to track as ``(name, metric)`` pairs.

    A single ``RankAccuracyTorchMetric`` is created; its ``ranks`` mapping numbers the
    entries of ``RANKS`` starting from 1.
    """
    numbered_ranks = dict(enumerate(RANKS, start=1))
    accuracy_metric = RankAccuracyTorchMetric(
        root=self.classification_tree,
        ranks=numbered_ranks,
    )
    return [('rank_accuracy', accuracy_metric)]
@method
def data(
    self,
    max_items: int = 0,
    num_workers: int = 4,
    validation_partition: int = 0,
    batch_size: int = 4,
    test_partition: int = -1,
    seq_count: int = 32,
    train_all: bool = False,
) -> Iterable | L.LightningDataModule:
    """
    Create the ``BloodhoundDataModule`` from the state prepared in ``setup``.

    The array, the accession-to-row index, the seqtree and the gene id dictionary come
    from the instance; every argument of this method is forwarded unchanged.
    """
    # Values loaded by setup().
    loaded_state = dict(
        array=self.array,
        accession_to_array_index=self.accession_to_array_index,
        seqtree=self.seqtree,
        gene_id_dict=self.gene_id_dict,
    )
    # Options supplied by the caller.
    options = dict(
        max_items=max_items,
        batch_size=batch_size,
        num_workers=num_workers,
        validation_partition=validation_partition,
        test_partition=test_partition,
        seq_count=seq_count,
        train_all=train_all,
    )
    return BloodhoundDataModule(**loaded_state, **options)
@method
def extra_hyperparameters(self, embedding_model: str = "") -> dict:
    """
    Extra hyperparameters to save with the module.

    ``embedding_model`` is matched case-insensitively. Only names starting with "esm"
    are understood; whatever follows that prefix is passed to ``ESMEmbedding.setup`` as
    ``layers``.

    Returns:
        A dict with the configured embedding model, the classification tree and the
        gene id dictionary.

    Raises:
        AssertionError: If ``embedding_model`` is empty.
        ValueError: If the name does not start with "esm".
    """
    assert embedding_model, f"Please provide an embedding model."

    model_name = embedding_model.lower()
    if not model_name.startswith("esm"):
        raise ValueError(f"Cannot understand embedding model: {model_name}")

    # Everything after the "esm" prefix selects the layers.
    embedder = ESMEmbedding()
    embedder.setup(layers=model_name[3:].strip())

    return dict(
        embedding_model=embedder,
        classification_tree=self.seqtree.classification_tree,
        gene_id_dict=self.gene_id_dict,
    )
@method
def prediction_dataloader(
    self,
    module,
    input: Path = Param(help="A path to a directory of fasta files or a single fasta file."),
    out_dir: Path = Param(help="A path to the output directory."),
    hmm_models_dir: Path = Param(help="A path to the HMM models directory containing the Pfam and TIGRFAM HMMs."),
    torch_hub: Path = Param(help="The path to the Torch Hub directory", envvar="TORCH_HOME"),
    memmap_array: Path = None,  # TODO explain
    memmap_index: Path = None,  # TODO explain
    extension='fa',
    prefix: str = "gtdbtk",
    cpus: int = 1,
    batch_size: int = 64,
    num_workers: int = 0,
    force_embed: bool = False,
    repeats: int = Param(2, help="The minimum number of times to use each protein embedding in the prediction."),
    **kwargs,
) -> Iterable:
    """
    Build the dataloader used for prediction on a genome.

    Embeddings are loaded from ``memmap_array``/``memmap_index`` when both files exist
    and ``force_embed`` is False. Otherwise single copy marker genes are extracted from
    the input, each marker sequence is embedded with ``module.hparams.embedding_model``
    and, if both memmap paths were given, the embeddings and their accessions are
    written to those paths for reuse.

    Side effects: sets ``self.classification_tree``, ``self.name`` and
    ``self.prediction_dataset``.

    Args:
        module: The trained module; its ``hparams`` supply ``seq_count`` (default 32),
            ``classification_tree`` and ``embedding_model``.
        input: A directory searched recursively for ``*.{extension}`` files, or a single file.
        out_dir: Output directory handed to ``extract_single_copy_markers``.
        hmm_models_dir: Directory containing ``pfam/Pfam-A.hmm`` and ``tigrfam/tigrfam.hmm``.
        torch_hub: Currently unused in this method.
        memmap_array: Optional path of the embeddings memmap (read or written).
        memmap_index: Optional path of the accession list that accompanies ``memmap_array``.
        extension: File extension used when ``input`` is a directory.
        prefix: Currently unused in this method.
        cpus: Passed to ``extract_single_copy_markers``.
        batch_size: Batch size of the returned dataloader.
        num_workers: Worker count of the returned dataloader.
        force_embed: Recompute embeddings even if the memmap files exist.
        repeats: Passed to ``BloodhoundPredictionDataset``.

    Returns:
        A ``DataLoader`` over ``self.prediction_dataset`` (not shuffled).

    Raises:
        AssertionError: If more than one genome was found (only one is supported for now).
        ValueError: If no sequence produced a usable embedding.
    """
    seq_count = module.hparams.get('seq_count', 32)
    self.classification_tree = module.hparams.classification_tree

    input = Path(input)
    genomes = dict()
    if input.is_dir():
        for path in input.rglob(f"*.{extension}"):
            genomes[path.stem] = str(path)
    else:
        genomes[input.stem] = str(input)

    self.name = input.name

    # Accept plain strings as well as Path objects; empty values mean "no cache".
    memmap_array_path = Path(memmap_array) if memmap_array else None
    memmap_index_path = Path(memmap_index) if memmap_index else None
    cache_available = (
        memmap_array_path is not None
        and memmap_array_path.exists()
        and memmap_index_path is not None
        and memmap_index_path.exists()
    )

    if cache_available and not force_embed:
        print("Loading memmap")
        accessions = memmap_index_path.read_text().strip().split("\n")
        embeddings = read_memmap(memmap_array_path, len(accessions))
    else:
        # TODO: figure out the best way to set this e.g. use the number of ar53 vs bac120 genes found
        # or use extract it from the bloodhound model
        domain = "bac120"

        ####################
        # Extract single copy marker genes
        ####################
        hmm_models_dir = Path(hmm_models_dir)
        fastas = extract_single_copy_markers(
            genomes=genomes,
            out_dir=str(out_dir),
            cpus=cpus,
            force=True,
            pfam_db=hmm_models_dir / "pfam" / "Pfam-A.hmm",
            tigr_db=hmm_models_dir / "tigrfam" / "tigrfam.hmm",
        )

        #######################
        # Create Embeddings
        #######################
        embeddings = []
        accessions = []
        assert len(genomes) == 1 # hack for now
        genome = next(iter(genomes))
        fastas = fastas[genome][domain]
        for fasta in track(fastas, description="[cyan]Embedding... ", total=len(fastas)):
            fasta = Path(fasta)

            # Read the first record without its header. The sequence may be wrapped
            # over several lines, so gather every line up to the next header.
            sequence_lines = []
            for line in fasta.read_text().splitlines()[1:]:
                if line.startswith(">"):
                    break
                sequence_lines.append(line.strip())
            seq = "".join(sequence_lines)
            if not seq:
                # Nothing to embed in this file.
                continue

            vector = module.hparams.embedding_model(seq)
            if vector is None or torch.isnan(vector).any():
                continue

            # Keep embeddings and accessions in lockstep: both are only appended
            # when the embedding is usable.
            embeddings.append(vector.cpu().detach().clone().numpy())
            accessions.append(f"{genome}/{fasta.stem}")

        if not embeddings:
            raise ValueError(f"No protein embeddings could be generated for '{genome}'")

        embeddings = np.asarray(embeddings).astype(np.float16)
        if memmap_array_path is not None and memmap_index_path is not None:
            memmap_array_path.parent.mkdir(exist_ok=True, parents=True)
            cache = np.memmap(memmap_array_path, dtype=embeddings.dtype, mode='w+', shape=embeddings.shape)
            cache[:] = embeddings
            cache.flush()

            memmap_index_path.parent.mkdir(exist_ok=True, parents=True)
            memmap_index_path.write_text("\n".join(accessions))

    self.prediction_dataset = BloodhoundPredictionDataset(
        array=embeddings,
        accessions=accessions,
        seq_count=seq_count,
        repeats=repeats,
        seed=42,
    )
    dataloader = DataLoader(self.prediction_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

    return dataloader
def node_to_str(self, node: SoftmaxNode) -> str:
    """
    Return the text after the last comma of ``str(node)``, stripped of surrounding
    whitespace. If there is no comma, the whole stripped string is returned.
    """
    _head, _comma, last_component = str(node).rpartition(",")
    return last_component.strip()
def output_results_to_df(
    self,
    names,
    results,
    output_csv: Path,
    output_tips_csv: Path,
    image_dir: Path,
    image_threshold: float,
    prediction_threshold: float,
) -> pd.DataFrame:
    """
    Convert raw model outputs into a table of node probabilities and write the
    requested outputs.

    Args:
        names: One name per row of ``results``; stored in the ``name`` column and used
            for the image file names.
        results: Model outputs passed to ``node_probabilities``.
        output_csv: If set, the transposed table is written here.
        output_tips_csv: If set, the table without the columns of non-leaf nodes is
            written here (without the DataFrame index).
        image_dir: If set, one ``<name>.png`` render per row is written here.
        image_threshold: Threshold passed to ``render_probabilities``.
        prediction_threshold: Threshold passed to ``greedy_predictions``.

    Returns:
        A DataFrame with the columns ``name``, ``greedy_prediction``, ``probability``
        followed by one column per non-root node of ``node_list_softmax``.
    """
    classification_probabilities = node_probabilities(results, root=self.classification_tree)

    node_list = self.classification_tree.node_list_softmax
    category_names = [self.node_to_str(node) for node in node_list if not node.is_root]

    results_df = pd.DataFrame(classification_probabilities.numpy(), columns=category_names)

    classification_probabilities = torch.as_tensor(results_df[category_names].to_numpy())

    # get greedy predictions which can use the raw activation or the softmax probabilities
    predictions = greedy_predictions(
        classification_probabilities,
        root=self.classification_tree,
        threshold=prediction_threshold,
    )

    results_df['greedy_prediction'] = [
        self.node_to_str(node)
        for node in predictions
    ]

    def get_prediction_probability(row):
        # Look up the probability column named after the predicted node; a prediction
        # without a matching column is reported with probability 1.0.
        prediction = row["greedy_prediction"]
        if prediction in row:
            return row[prediction]
        return 1.0

    results_df['probability'] = results_df.apply(get_prediction_probability, axis=1)

    results_df["name"] = names

    # Reorder columns
    results_df = results_df[["name", "greedy_prediction", "probability"] + category_names]

    # "name" deliberately stays a regular column: it is read below to build the image
    # paths and must be part of the tips CSV, which is written without the index.
    # (A previous `results_df.set_index('name')` call here discarded its result and
    # therefore never had any effect.)

    if not (image_dir or output_csv or output_tips_csv):
        print("No output files requested.")

    # Output images
    if image_dir:
        console.print(f"Writing inference probability renders to: {image_dir}")
        image_dir = Path(image_dir)
        # Create the directory up front, as is done for the CSV outputs below.
        image_dir.mkdir(exist_ok=True, parents=True)
        image_paths = [image_dir/f"{name}.png" for name in results_df["name"]]
        render_probabilities(
            root=self.classification_tree,
            filepaths=image_paths,
            probabilities=classification_probabilities,
            predictions=predictions,
            threshold=image_threshold,
        )

    if output_tips_csv:
        output_tips_csv = Path(output_tips_csv)
        output_tips_csv.parent.mkdir(exist_ok=True, parents=True)
        non_tips = [self.node_to_str(node) for node in self.classification_tree.node_list if not node.is_leaf]
        tips_df = results_df.drop(columns=non_tips)
        tips_df.to_csv(output_tips_csv, index=False)

    if output_csv:
        output_csv = Path(output_csv)
        output_csv.parent.mkdir(exist_ok=True, parents=True)
        console.print(f"Writing results for {len(results_df)} sequences to: {output_csv}")
        results_df.transpose().to_csv(output_csv)

    return results_df
@method
def output_results(
    self,
    results,
    output_csv: Path = Param(default=None, help="A path to output the results as a CSV."),
    output_tips_csv: Path = Param(default=None, help="A path to output the results as a CSV which only stores the probabilities at the tips."),
    output_averaged_csv: Path = Param(default=None, help="A path to output the results as a CSV."),
    output_averaged_tips_csv: Path = Param(default=None, help="A path to output the results as a CSV which only stores the probabilities at the tips."),
    output_gene_csv: Path = Param(default=None, help="A path to output the results for individual genes as a CSV."),
    output_gene_tips_csv: Path = Param(default=None, help="A path to output the results as a CSV which only stores the probabilities at the tips."),
    image_dir: Path = Param(default=None, help="A path to output the results as images."),
    image_threshold: float = 0.005,
    prediction_threshold: float = Param(default=0.0, help="The threshold value for making hierarchical predictions."),
    gene_images: bool = False,
    **kwargs,
):
    """
    Write the prediction results and return the averaged prediction table.

    The rows of ``results`` are averaged (mean over axis 0) into a single prediction
    named ``self.name``, which is written via ``output_results_to_df`` to
    ``output_csv``, ``output_tips_csv`` and ``image_dir``. When ``output_gene_csv``,
    ``output_gene_tips_csv`` or ``gene_images`` is given, the unaveraged rows are
    written as well.

    ``output_averaged_csv`` and ``output_averaged_tips_csv`` are accepted but currently
    unused.

    Returns:
        The DataFrame of the averaged prediction.
    """
    assert self.classification_tree

    if output_gene_csv or output_gene_tips_csv or gene_images:
        # `gene_family_names` is not assigned anywhere in the visible part of this
        # class, so reading it directly raised AttributeError. Fall back to one
        # generated name per result row when it is absent.
        row_names = getattr(self, "gene_family_names", None)
        if row_names is None:
            row_names = [f"{self.name}-{index}" for index in range(len(results))]

        self.output_results_to_df(
            row_names,
            results,
            output_gene_csv,
            output_gene_tips_csv,
            image_dir if gene_images else None,
            image_threshold=image_threshold,
            prediction_threshold=prediction_threshold,
        )

    result = self.output_results_to_df(
        [self.name],
        results.mean(axis=0, keepdims=True),
        output_csv,
        output_tips_csv,
        image_dir,
        image_threshold=image_threshold,
        prediction_threshold=prediction_threshold,
    )

    return result
405 # @method
406 # def extra_callbacks_off(self, **kwargs):
407 # from lightning.pytorch.callbacks import Callback
408 # import tracemalloc
409 # class MemoryLeakCallback(Callback):
410 # def on_train_start(self, trainer, pl_module):
411 # # Start tracing memory allocations at the beginning of the training
412 # tracemalloc.start()
413 # print("tracemalloc started")
415 # def on_train_batch_start(self, trainer, pl_module, *args, **kwargs):
416 # # Take a snapshot before the batch starts
417 # self.snapshot_before = tracemalloc.take_snapshot()
419 # def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
420 # # Take a snapshot after the batch ends
421 # snapshot_after = tracemalloc.take_snapshot()
423 # # Compare the snapshots
424 # stats = snapshot_after.compare_to(self.snapshot_before, 'lineno')
426 # # Log the top memory-consuming lines
427 # print(f"[Batch {trainer.global_step}] Memory differences:")
428 # for stat in stats[:20]:
429 # print(stat)
431 # # Optionally, monitor peak memory usage
432 # current, peak = tracemalloc.get_traced_memory()
433 # print(f"Current memory usage: {current / 1024**2:.2f} MB; Peak: {peak / 1024**2:.2f} MB")
435 # # Clear traces if needed to prevent tracemalloc from consuming too much memory itself
436 # tracemalloc.clear_traces()
438 # return [MemoryLeakCallback()]