Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import torch 

2import numpy as np 

3from pathlib import Path 

4from torch import nn 

5import lightning as L 

6from torchmetrics import Metric 

7from hierarchicalsoftmax.metrics import RankAccuracyTorchMetric 

8from corgi.seqtree import SeqTree 

9from hierarchicalsoftmax import HierarchicalSoftmaxLoss, SoftmaxNode 

10from torch.utils.data import DataLoader 

11from collections.abc import Iterable 

12from rich.console import Console 

13 

14from bloodhound.markers import extract_single_copy_markers 

15from collections import defaultdict 

16from rich.progress import track 

17 

18import pandas as pd 

19from hierarchicalsoftmax.inference import node_probabilities, greedy_predictions, render_probabilities 

20 

21from torchapp import Param, method, TorchApp 

22from .models import BloodhoundModel 

23from .data import read_memmap, RANKS, BloodhoundDataModule, BloodhoundPredictionDataset 

24from .embeddings.esm import ESMEmbedding 

25 

26console = Console() 

27 

28 

29 

30class Bloodhound(TorchApp): 

31 @method 

32 def setup( 

33 self, 

34 memmap:str=None, 

35 memmap_index:str=None, 

36 seqtree:str=None, 

37 in_memory:bool=False, 

38 tip_alpha:float=None, 

39 ) -> None: 

40 if not seqtree: 

41 raise ValueError("seqtree is required") 

42 if not memmap: 

43 raise ValueError("memmap is required") 

44 if not memmap_index: 

45 raise ValueError("memmap_index is required") 

46 

47 print(f"Loading seqtree {seqtree}") 

48 individual_seqtree = SeqTree.load(seqtree) 

49 self.seqtree = SeqTree(classification_tree=individual_seqtree.classification_tree) 

50 

51 # Sets the loss weighting for the tips 

52 if tip_alpha: 

53 for tip in self.seqtree.classification_tree.leaves: 

54 tip.parent.alpha = tip_alpha 

55 

56 print(f"Loading memmap") 

57 self.accession_to_array_index = defaultdict(list) 

58 with open(memmap_index) as f: 

59 for key_index, key in enumerate(f): 

60 key = key.strip() 

61 accession = key.strip().split("/")[0] 

62 

63 if len(self.accession_to_array_index[accession]) == 0: 

64 self.seqtree[accession] = individual_seqtree[key] 

65 

66 self.accession_to_array_index[accession].append(key_index) 

67 count = key_index + 1 

68 self.array = read_memmap(memmap, count) 

69 

70 # If there's enough memory, then read into RAM 

71 if in_memory: 

72 self.array = np.array(self.array) 

73 

74 self.classification_tree = self.seqtree.classification_tree 

75 assert self.classification_tree is not None 

76 

77 # Get list of gene families 

78 family_ids = set() 

79 for accession in self.seqtree: 

80 gene_id = accession.split("/")[-1] 

81 family_ids.add(gene_id) 

82 

83 self.gene_id_dict = {family_id:index for index, family_id in enumerate(sorted(family_ids))} 

84 

85 @method 

86 def model( 

87 self, 

88 features:int=768, 

89 intermediate_layers:int=2, 

90 growth_factor:float=2.0, 

91 attention_size:int=512, 

92 ) -> nn.Module: 

93 return BloodhoundModel( 

94 classification_tree=self.classification_tree, 

95 features=features, 

96 intermediate_layers=intermediate_layers, 

97 growth_factor=growth_factor, 

98 attention_size=attention_size, 

99 ) 

100 

101 # @method 

102 # def input_count(self) -> int: 

103 # return 1 

104 

105 @method 

106 def loss_function(self): 

107 # def dummy_loss(prediction, target): 

108 # return prediction[0,0] * 0.0 

109 # # return prediction.mean() * 0.0 

110 # return dummy_loss # hack 

111 return HierarchicalSoftmaxLoss(root=self.classification_tree) 

112 

113 @method 

114 def metrics(self) -> list[tuple[str,Metric]]: 

115 rank_accuracy = RankAccuracyTorchMetric( 

116 root=self.classification_tree, 

117 ranks={1+i:rank for i, rank in enumerate(RANKS)}, 

118 ) 

119 

120 return [('rank_accuracy', rank_accuracy)] 

121 

122 @method 

123 def data( 

124 self, 

125 max_items:int=0, 

126 num_workers:int=4, 

127 validation_partition:int=0, 

128 batch_size:int = 4, 

129 test_partition:int=-1, 

130 seq_count:int=32, 

131 train_all:bool = False, 

132 ) -> Iterable|L.LightningDataModule: 

133 return BloodhoundDataModule( 

134 array=self.array, 

135 accession_to_array_index=self.accession_to_array_index, 

136 seqtree=self.seqtree, 

137 gene_id_dict=self.gene_id_dict, 

138 max_items=max_items, 

139 batch_size=batch_size, 

140 num_workers=num_workers, 

141 validation_partition=validation_partition, 

142 test_partition=test_partition, 

143 seq_count=seq_count, 

144 train_all=train_all, 

145 ) 

146 

147 @method 

148 def extra_hyperparameters(self, embedding_model:str="") -> dict: 

149 """ Extra hyperparameters to save with the module. """ 

150 assert embedding_model, f"Please provide an embedding model." 

151 embedding_model = embedding_model.lower() 

152 if embedding_model.startswith("esm"): 

153 layers = embedding_model[3:].strip() 

154 embedding_model = ESMEmbedding() 

155 embedding_model.setup(layers=layers) 

156 else: 

157 raise ValueError(f"Cannot understand embedding model: {embedding_model}") 

158 

159 return dict( 

160 embedding_model=embedding_model, 

161 classification_tree=self.seqtree.classification_tree, 

162 gene_id_dict=self.gene_id_dict, 

163 ) 

164 

165 @method 

166 def prediction_dataloader( 

167 self, 

168 module, 

169 input:Path=Param(help="A path to a directory of fasta files or a single fasta file."), 

170 out_dir:Path=Param(help="A path to the output directory."), 

171 hmm_models_dir:Path=Param(help="A path to the HMM models directory containing the Pfam and TIGRFAM HMMs."), 

172 torch_hub:Path=Param(help="The path to the Torch Hub directory", envvar="TORCH_HOME"), 

173 memmap_array:Path=None, # TODO explain 

174 memmap_index:Path=None, # TODO explain 

175 extension='fa', 

176 prefix:str="gtdbtk", 

177 cpus:int=1, 

178 batch_size:int = 64, 

179 num_workers: int = 0, 

180 force_embed:bool=False, 

181 repeats:int = Param(2, help="The minimum number of times to use each protein embedding in the prediction."), 

182 **kwargs, 

183 ) -> Iterable: 

184 # Get hyperparameters from checkpoint 

185 # esm_layers = ESMLayers.from_value(module.hparams.get('esm_layers', module.hparams.embedding_model.layers)) 

186 # embedding_model = module.hparams.embedding_model 

187 # embedding_model.setup(layers = esm_layers, hub_dir=torch_hub) # HACK 

188 

189 seq_count = module.hparams.get('seq_count', 32) 

190 self.classification_tree = module.hparams.classification_tree 

191 genomes = dict() 

192 input = Path(input) 

193 if input.is_dir(): 

194 for path in input.rglob(f"*.{extension}"): 

195 genomes[path.stem] = str(path) 

196 else: 

197 genomes[input.stem] = str(input) 

198 

199 self.name = input.name 

200 

201 

202 memmap_array_path = memmap_array 

203 if memmap_array_path and memmap_array_path.exists() and memmap_index and memmap_index.exists() and not force_embed: 

204 print(f"Loading memmap") 

205 accessions = memmap_index.read_text().strip().split("\n") 

206 embeddings = read_memmap(memmap_array_path, len(accessions)) 

207 else: 

208 # TODO: figure out the best way to set this e.g. use the number of ar53 vs bac120 genes found  

209 # or use extract it from the bloodhound model 

210 domain = "bac120" 

211 # domain = "ar53" if len(module.hparams.gene_id_dict) == 53 else "bac120" 

212 

213 #################### 

214 # Extract single copy marker genes 

215 #################### 

216 fastas = extract_single_copy_markers( 

217 genomes=genomes, 

218 out_dir=str(out_dir), 

219 cpus=cpus, 

220 force=True, 

221 pfam_db=hmm_models_dir / "pfam" / "Pfam-A.hmm", 

222 tigr_db=hmm_models_dir / "tigrfam" / "tigrfam.hmm", 

223 ) 

224 

225 ####################### 

226 # Create Embeddings 

227 ####################### 

228 embeddings = [] 

229 accessions = [] 

230 assert len(genomes) == 1 # hack for now 

231 genome = list(genomes.keys())[0] 

232 fastas = fastas[genome][domain] 

233 for fasta in track(fastas, description="[cyan]Embedding... ", total=len(fastas)): 

234 # read the fasta file sequence remove the header 

235 fasta = Path(fasta) 

236 seq = fasta.read_text().split("\n")[1] 

237 vector = module.hparams.embedding_model(seq) 

238 if vector is not None and not torch.isnan(vector).any(): 

239 vector = vector.cpu().detach().clone().numpy() 

240 embeddings.append(vector) 

241 

242 gene_family_id = fasta.stem 

243 accession = f"{genome}/{gene_family_id}" 

244 accessions.append(accession) 

245 

246 del vector 

247 

248 embeddings = np.asarray(embeddings).astype(np.float16) 

249 if memmap_array_path is not None and memmap_index is not None: 

250 memmap_array_path.parent.mkdir(exist_ok=True, parents=True) 

251 memmap_array = np.memmap(memmap_array_path, dtype=embeddings.dtype, mode='w+', shape=embeddings.shape) 

252 memmap_array[:] = embeddings[:,:] 

253 memmap_array.flush() 

254 

255 memmap_index.parent.mkdir(exist_ok=True, parents=True) 

256 memmap_index.write_text("\n".join(accessions)) 

257 

258 # Copy memmap for gene family into output memmap 

259 self.prediction_dataset = BloodhoundPredictionDataset( 

260 array=embeddings, 

261 accessions=accessions, 

262 seq_count=seq_count, 

263 repeats=repeats, 

264 seed=42, 

265 ) 

266 dataloader = DataLoader(self.prediction_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) 

267 

268 return dataloader 

269 

270 def node_to_str(self, node:SoftmaxNode) -> str: 

271 """  

272 Converts the node to a string 

273 """ 

274 return str(node).split(",")[-1].strip() 

275 

276 def output_results_to_df( 

277 self, 

278 names, 

279 results, 

280 output_csv: Path, 

281 output_tips_csv: Path, 

282 image_dir: Path, 

283 image_threshold:float, 

284 prediction_threshold:float, 

285 ) -> pd.DataFrame: 

286 classification_probabilities = node_probabilities(results, root=self.classification_tree) 

287 

288 node_list = self.classification_tree.node_list_softmax 

289 category_names = [self.node_to_str(node) for node in node_list if not node.is_root] 

290 

291 results_df = pd.DataFrame(classification_probabilities.numpy(), columns=category_names) 

292 

293 classification_probabilities = torch.as_tensor(results_df[category_names].to_numpy()) 

294 

295 # get greedy predictions which can use the raw activation or the softmax probabilities 

296 predictions = greedy_predictions( 

297 classification_probabilities, 

298 root=self.classification_tree, 

299 threshold=prediction_threshold, 

300 ) 

301 

302 results_df['greedy_prediction'] = [ 

303 self.node_to_str(node) 

304 for node in predictions 

305 ] 

306 

307 def get_prediction_probability(row): 

308 prediction = row["greedy_prediction"] 

309 if prediction in row: 

310 return row[prediction] 

311 return 1.0 

312 

313 results_df['probability'] = results_df.apply(get_prediction_probability, axis=1) 

314 

315 results_df["name"] = names 

316 

317 # Reorder columns 

318 results_df = results_df[["name", "greedy_prediction", "probability" ] + category_names] 

319 

320 results_df.set_index('name') 

321 

322 if not (image_dir or output_csv or output_tips_csv): 

323 print("No output files requested.") 

324 

325 # Output images 

326 if image_dir: 

327 console.print(f"Writing inference probability renders to: {image_dir}") 

328 image_dir = Path(image_dir) 

329 image_paths = [image_dir/f"{name}.png" for name in results_df["name"]] 

330 render_probabilities( 

331 root=self.classification_tree, 

332 filepaths=image_paths, 

333 probabilities=classification_probabilities, 

334 predictions=predictions, 

335 threshold=image_threshold, 

336 ) 

337 

338 if output_tips_csv: 

339 output_tips_csv = Path(output_tips_csv) 

340 output_tips_csv.parent.mkdir(exist_ok=True, parents=True) 

341 non_tips = [self.node_to_str(node) for node in self.classification_tree.node_list if not node.is_leaf] 

342 tips_df = results_df.drop(columns=non_tips) 

343 tips_df.to_csv(output_tips_csv, index=False) 

344 

345 if output_csv: 

346 output_csv = Path(output_csv) 

347 output_csv.parent.mkdir(exist_ok=True, parents=True) 

348 console.print(f"Writing results for {len(results_df)} sequences to: {output_csv}") 

349 results_df.transpose().to_csv(output_csv) 

350 

351 return results_df 

352 

353 @method 

354 def output_results( 

355 self, 

356 results, 

357 output_csv: Path = Param(default=None, help="A path to output the results as a CSV."), 

358 output_tips_csv: Path = Param(default=None, help="A path to output the results as a CSV which only stores the probabilities at the tips."), 

359 output_averaged_csv: Path = Param(default=None, help="A path to output the results as a CSV."), 

360 output_averaged_tips_csv: Path = Param(default=None, help="A path to output the results as a CSV which only stores the probabilities at the tips."), 

361 output_gene_csv: Path = Param(default=None, help="A path to output the results for individual genes as a CSV."), 

362 output_gene_tips_csv: Path = Param(default=None, help="A path to output the results as a CSV which only stores the probabilities at the tips."), 

363 image_dir: Path = Param(default=None, help="A path to output the results as images."), 

364 image_threshold:float = 0.005, 

365 prediction_threshold:float = Param(default=0.0, help="The threshold value for making hierarchical predictions."), 

366 gene_images:bool=False, 

367 **kwargs, 

368 ): 

369 assert self.classification_tree 

370 

371 if output_gene_csv or output_gene_tips_csv or gene_images: 

372 self.output_results_to_df( 

373 self.gene_family_names, 

374 results, 

375 output_gene_csv, 

376 output_gene_tips_csv, 

377 image_dir if gene_images else None, 

378 image_threshold=image_threshold, 

379 prediction_threshold=prediction_threshold, 

380 ) 

381 

382 result = self.output_results_to_df( 

383 [self.name], 

384 results.mean(axis=0, keepdims=True), 

385 output_csv, 

386 output_tips_csv, 

387 image_dir, 

388 image_threshold=image_threshold, 

389 prediction_threshold=prediction_threshold, 

390 ) 

391 

392 # if output_averaged_csv or output_averaged_tips_csv: 

393 # self.output_results_to_df( 

394 # [self.name+"-averaged"], 

395 # results.mean(axis=0, keepdims=True), 

396 # output_averaged_csv, 

397 # output_averaged_tips_csv, 

398 # image_dir, 

399 # image_threshold=image_threshold, 

400 # prediction_threshold=prediction_threshold, 

401 # ) 

402 

403 return result 

404 

405 # @method 

406 # def extra_callbacks_off(self, **kwargs): 

407 # from lightning.pytorch.callbacks import Callback 

408 # import tracemalloc 

409 # class MemoryLeakCallback(Callback): 

410 # def on_train_start(self, trainer, pl_module): 

411 # # Start tracing memory allocations at the beginning of the training 

412 # tracemalloc.start() 

413 # print("tracemalloc started") 

414 

415 # def on_train_batch_start(self, trainer, pl_module, *args, **kwargs): 

416 # # Take a snapshot before the batch starts 

417 # self.snapshot_before = tracemalloc.take_snapshot() 

418 

419 # def on_train_batch_end(self, trainer, pl_module, *args, **kwargs): 

420 # # Take a snapshot after the batch ends 

421 # snapshot_after = tracemalloc.take_snapshot() 

422 

423 # # Compare the snapshots 

424 # stats = snapshot_after.compare_to(self.snapshot_before, 'lineno') 

425 

426 # # Log the top memory-consuming lines 

427 # print(f"[Batch {trainer.global_step}] Memory differences:") 

428 # for stat in stats[:20]: 

429 # print(stat) 

430 

431 # # Optionally, monitor peak memory usage 

432 # current, peak = tracemalloc.get_traced_memory() 

433 # print(f"Current memory usage: {current / 1024**2:.2f} MB; Peak: {peak / 1024**2:.2f} MB") 

434 

435 # # Clear traces if needed to prevent tracemalloc from consuming too much memory itself 

436 # tracemalloc.clear_traces() 

437 

438 # return [MemoryLeakCallback()] 

439 

440