Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import gzip 

2import os 

3from pathlib import Path 

4import os 

5from abc import ABC, abstractmethod 

6from Bio import SeqIO 

7import random 

8import numpy as np 

9from rich.progress import track 

10from hierarchicalsoftmax import SoftmaxNode 

11from corgi.seqtree import SeqTree 

12import tarfile 

13import torch 

14from io import StringIO 

15from torchapp.cli import CLIApp, tool, method 

16import typer 

17from dataclasses import dataclass 

18 

19from .data import read_memmap, RANKS 

20 

21 

def _open(path, mode='rt', **kwargs):
    """
    Return an open handle for ``path``, transparently handling gzip files.

    Args:
        path (str or Path): Location of the file to open.
        mode (str): Mode forwarded to the opener (defaults to 'rt', i.e. reading text).
        **kwargs: Extra keyword arguments forwarded to ``gzip.open`` or ``open``.

    Returns:
        A file object; it is gzip-backed when the final suffix of ``path`` is '.gz'.
    """
    target = Path(path)
    # Only the last suffix decides the opener, so 'x.fa.gz' is treated as gzip.
    opener = gzip.open if target.suffix == '.gz' else open
    return opener(target, mode, **kwargs)

38 

39 

def set_validation_rank_to_seqtree(
    seqtree:SeqTree,
    validation_rank:str="species",
    partitions:int=5,
) -> SeqTree:
    """
    Re-assigns the partition of every item in ``seqtree`` so that items sharing
    the same ancestor at ``validation_rank`` end up in the same partition.

    Args:
        seqtree (SeqTree): The tree to update; it is modified in place.
        validation_rank (str): Name of the rank (case-insensitive) that must appear in ``RANKS``.
        partitions (int): How many partitions to draw from; values are in ``[0, partitions-1]``.

    Returns:
        SeqTree: The same ``seqtree`` object that was passed in.
    """
    rank = validation_rank.lower()
    assert rank in RANKS
    # The position of the rank in RANKS is used to index into each node's ancestors.
    depth = RANKS.index(rank)

    assigned = {}
    for key in seqtree:
        anchor = seqtree.node(key).ancestors[depth]
        # Draw one random partition per anchor node, the first time it is seen.
        if anchor not in assigned:
            assigned[anchor] = random.randint(0,partitions-1)
        seqtree[key].partition = assigned[anchor]

    return seqtree

61 

62 

def get_key(accession:str, gene:str) -> str:
    """ Builds the standard key, which is the accession and the gene joined by a slash. """
    return "{}/{}".format(accession, gene)

67 

68 

def get_node(lineage:str, lineage_to_node:dict[str,SoftmaxNode]) -> SoftmaxNode:
    """
    Returns the node for a semi-colon delimited lineage, creating any missing
    nodes along the way and registering them in ``lineage_to_node``.

    Args:
        lineage (str): The full lineage string, with ranks separated by ';'.
        lineage_to_node (dict[str,SoftmaxNode]): Cache mapping lineage prefixes to nodes; updated in place.

    Returns:
        SoftmaxNode: The node corresponding to ``lineage``.

    Raises:
        AssertionError: If no prefix of ``lineage`` is in the cache, i.e. a prefix
            without ';' is reached that is still unknown.
    """
    # Walk towards the root, remembering every prefix that has no node yet.
    missing = []
    current = lineage
    while current not in lineage_to_node:
        assert ";" in current, f"Semi-colon ';' not found in lineage '{current}'"
        missing.append(current)
        current = current.rpartition(";")[0]

    # Create the missing nodes from the known ancestor down to the requested lineage.
    node = lineage_to_node[current]
    for pending in reversed(missing):
        node = SoftmaxNode(name=pending.rpartition(";")[2], parent=node)
        lineage_to_node[pending] = node

    return node

81 

82 

def generate_overlapping_intervals(total: int, interval_size: int, min_overlap: int, check:bool=True, variable_size:bool=False) -> list[tuple[int, int]]:
    """
    Creates a list of overlapping intervals that together cover ``[0, total)``.

    By default every interval has exactly ``interval_size`` units and the overlap between
    consecutive intervals is raised above ``min_overlap`` where needed so the intervals are
    spread evenly. With ``variable_size=True`` the overlap is kept at ``min_overlap`` and the
    interval size is shrunk instead.

    Args:
        total (int): The total range within which intervals are to be created.
        interval_size (int): The (maximum) size of each interval.
        min_overlap (int): The minimum number of units by which consecutive intervals overlap.
        check (bool): If True, assertions verify that the intervals meet the specified conditions.
        variable_size (bool): If True, reduce the interval size rather than increasing the overlap.

    Returns:
        list[tuple[int, int]]: A list of tuples where each tuple represents the start (inclusive)
        and end (exclusive) of an interval. An empty list if ``total`` is 0, and the single
        interval ``(0, total)`` if ``total`` fits within one interval.

    Raises:
        AssertionError: If ``interval_size`` is falsy, ``min_overlap`` is None, or
            ``interval_size`` is not greater than ``min_overlap``.

    Example:
        >>> generate_overlapping_intervals(20, 5, 2)
        [(0, 5), (3, 8), (6, 11), (9, 14), (12, 17), (15, 20)]
    """
    intervals = []
    start = 0

    if total == 0:
        return intervals

    max_interval_size = interval_size
    assert interval_size
    assert min_overlap is not None
    assert interval_size > min_overlap, f"Max interval size of {interval_size} must be greater than min overlap of {min_overlap}"

    # A range that fits inside a single interval needs no tiling. Handling it here
    # avoids the consistency checks below rejecting a shorter-than-requested interval.
    if total <= interval_size:
        return [(0, total)]

    # Calculate the number of intervals needed to cover the range (ceiling division).
    # Because total > interval_size > min_overlap, this is always at least 2.
    num_intervals, remainder = divmod(total - min_overlap, interval_size - min_overlap)
    if remainder > 0:
        num_intervals += 1

    overlap = min_overlap
    if variable_size:
        # Keep the overlap fixed and shrink the interval size (rounded up) to fit evenly.
        interval_size, remainder = divmod(total + (num_intervals - 1) * overlap, num_intervals)
        if remainder > 0:
            interval_size += 1
    else:
        # If the size is fixed, then vary the overlap to keep it even
        overlap = (num_intervals * interval_size - total) // (num_intervals - 1)
        if overlap < min_overlap:
            overlap = min_overlap

    while True:
        end = start + interval_size
        if end > total:
            # Slide the final interval back so that it ends exactly at `total`.
            end = total
            start = max(end - interval_size,0)
        intervals.append((start, end))
        start += interval_size - overlap
        if end >= total:
            break

    if check:
        assert intervals[0][0] == 0
        assert intervals[-1][1] == total
        assert len(intervals) == num_intervals, f"Expected {num_intervals} intervals, got {len(intervals)}"

        assert interval_size <= max_interval_size, f"Interval size of {interval_size} exceeds max interval size of {max_interval_size}"
        for interval in intervals:
            assert interval[1] - interval[0] == interval_size, f"Interval size of {interval[1] - interval[0]} is not the expected size {interval_size}"

        for i in range(1, len(intervals)):
            overlap = intervals[i - 1][1] - intervals[i][0]
            assert overlap >= min_overlap, f"Min overlap condition of {min_overlap} not met for intervals {intervals[i - 1]} and {intervals[i]} (overlap {overlap})"

    return intervals

156 

157 

@dataclass
class Embedding(CLIApp, ABC):
    """
    A class for embedding protein sequences.

    Subclasses implement ``embed``. Calling an instance embeds a sequence and reduces
    the result to a single vector, splitting long sequences into overlapping chunks.
    """
    # Longest sequence passed to `embed` in one piece; None (or 0) disables chunking.
    max_length:int|None=None
    # Minimum overlap between consecutive chunks; also caps the blending weight ramp.
    overlap:int=64

    def __post_init__(self):
        super().__init__()

    @abstractmethod
    def embed(self, seq:str) -> torch.Tensor:
        """ Takes a protein sequence as a string and returns an embedding vector. """
        raise NotImplementedError

    def reduce(self, tensor:torch.Tensor) -> torch.Tensor:
        """ Averages a 2D tensor over its first dimension; a 1D tensor is returned unchanged. """
        if tensor.ndim == 2:
            tensor = tensor.mean(dim=0)
        assert tensor.ndim == 1
        return tensor

    def __call__(self, seq:str) -> torch.Tensor:
        """
        Takes a protein sequence as a string and returns an embedding vector.

        Sequences longer than ``max_length`` are embedded in overlapping chunks. Each chunk
        must come back from ``embed`` with one row per residue; the rows are blended with
        weights that ramp up with the distance from the chunk's ends, then reduced.
        """
        if not self.max_length or len(seq) <= self.max_length:
            tensor = self.embed(seq)
            return self.reduce(tensor)

        epsilon = 0.1  # keeps the weights at the very ends of a chunk strictly positive
        intervals = generate_overlapping_intervals(len(seq), self.max_length, self.overlap)
        weights = torch.zeros( (len(seq),), device="cpu" )
        tensor = None
        for start,end in intervals:
            result = self.embed(seq[start:end]).cpu()

            length = end - start
            assert result.shape[0] == length
            embedding_size = result.shape[1]

            if tensor is None:
                tensor = torch.zeros( (len(seq), embedding_size ), device="cpu")

            assert tensor.shape[-1] == embedding_size

            # Positions are relative to the chunk, so the distance to the nearest end must
            # also be computed relative to the chunk. Subtracting the absolute `start` here
            # would produce negative weights for every chunk after the first.
            interval_indexes = torch.arange(length)
            distance_from_ends = torch.minimum(interval_indexes, length - 1 - interval_indexes)

            weight = epsilon + torch.minimum(distance_from_ends, torch.tensor(self.overlap))

            tensor[start:end] += result * weight.unsqueeze(1)
            weights[start:end] += weight

        # Normalise by the accumulated weights to get a weighted average per position.
        tensor = tensor/weights.unsqueeze(1)

        return self.reduce(tensor)

    @method
    def setup(self, **kwargs):
        """ Hook for subclasses to prepare the embedding model; does nothing by default. """
        pass

    def build_seqtree(self, taxonomy:Path) -> tuple[SeqTree,dict[str,SoftmaxNode]]:
        """
        Builds the classification tree from a two-column, tab-separated taxonomy file.

        Args:
            taxonomy (Path): File (optionally gzipped) with lines of ``accession<TAB>lineage``,
                where the lineage is delimited with ';'.

        Returns:
            tuple[SeqTree,dict[str,SoftmaxNode]]: The empty seqtree rooted at the first rank of
            the first lineage, and a mapping from accession to its node in the tree.
        """
        lineage_to_node = {}
        root = None

        # Fill out tree with taxonomy
        accession_to_node = {}
        with _open(taxonomy) as f:
            for line in f:
                # Drop the line terminator so that it does not end up in the leaf node's name.
                line = line.rstrip("\r\n")
                if not line:
                    continue
                accession, lineage = line.split("\t")

                if root is None:
                    root_name = lineage.split(";")[0]
                    root = SoftmaxNode(root_name)
                    lineage_to_node[root_name] = root

                node = get_node(lineage, lineage_to_node)
                accession_to_node[accession] = node

        seqtree = SeqTree(classification_tree=root)
        return seqtree, accession_to_node

    @tool("setup")
    def test_lengths(
        self,
        end:int=5_000,
        start:int=1000,
        retries:int=5,
        **kwargs,
    ):
        """
        Embeds random amino acid sequences of increasing length (without chunking) and stops
        at the first length that raises an exception, printing the length and the error.
        """
        def random_amino_acid_sequence(k):
            amino_acids = "ACDEFGHIKLMNPQRSTVWY"  # standard 20 amino acids
            return ''.join(random.choice(amino_acids) for _ in range(k))

        self.max_length = None
        self.setup(**kwargs)
        for ii in track(range(start,end)):
            for _ in range(retries):
                seq = random_amino_acid_sequence(ii)
                try:
                    self(seq)
                except Exception as err:
                    print(f"{ii}: {err}")
                    return


    @tool("setup")
    def build_gene_array(
        self,
        marker_genes:Path=typer.Option(default=..., help="The path to the marker genes tarball (e.g. bac120_msa_marker_genes_all_r220.tar.gz)."),
        family_index:int=typer.Option(default=..., help="The index for the gene family to use. E.g. if there are 120 gene families then this should be a number from 0 to 119."),
        output_dir:Path=typer.Option(default=..., help="A directory to store the output which includes the memmap array, the listing of accessions and an error log."),
        flush_every:int=typer.Option(default=5_000, help="An interval to flush the memmap array as it is generated."),
        max_length:int=None,
        **kwargs,
    ):
        """
        Embeds every sequence of one gene family in the marker genes tarball.

        Writes ``<family_index>.npy`` (float16 memmap, one row per embedded sequence),
        ``<family_index>.txt`` (the key for each row) and ``<family_index>-errors.txt`` into
        ``output_dir``. Sequences that fail to embed are logged and skipped. If nothing could
        be embedded, no array or key listing is written.
        """
        self.max_length = max_length
        self.setup(**kwargs)

        assert marker_genes is not None
        assert family_index is not None
        assert output_dir is not None

        dtype = 'float16'

        memmap_wip_array = None
        output_dir.mkdir(parents=True, exist_ok=True)
        memmap_wip_path = output_dir / f"{family_index}-wip.npy"
        error = output_dir / f"{family_index}-errors.txt"
        accessions_wip = output_dir / f"{family_index}-accessions-wip.txt"

        accessions = []

        print(f"Loading {marker_genes} file.")
        with tarfile.open(marker_genes, "r:gz") as tar, open(error, "w") as error_file, open(accessions_wip, "w") as accessions_wip_file:
            def report(message):
                # Failures go to both the error log and the console.
                print(message, file=error_file)
                print(message)

            members = [member for member in tar.getmembers() if member.isfile() and member.name.endswith(".faa")]
            prefix_length = len(os.path.commonprefix([Path(member.name).with_suffix("").name for member in members]))

            member = members[family_index]
            print(f"Processing file {family_index} in {marker_genes}")

            # The marker ID is the file stem with the prefix common to all families removed.
            marker_id = Path(member.name).with_suffix("").name[prefix_length:]

            with tar.extractfile(member) as f:
                fasta_io = StringIO(f.read().decode('ascii'))

            # First pass counts the records so the work-in-progress array can be sized.
            total = sum(1 for _ in SeqIO.parse(fasta_io, "fasta"))
            fasta_io.seek(0)
            print(marker_id, total)

            for record in track(SeqIO.parse(fasta_io, "fasta"), total=total):
                species_accession = record.id

                key = get_key(species_accession, marker_id)

                # Strip alignment gaps and stop symbols before embedding.
                seq = str(record.seq).replace("-","").replace("*","")
                try:
                    vector = self(seq)
                except Exception as err:
                    report(f"{key} ({len(seq)}): {err}")
                    continue

                if vector is None:
                    report(f"{key} ({len(seq)}): Embedding is None")
                    continue

                if torch.isnan(vector).any():
                    report(f"{key} ({len(seq)}): Embedding contains NaN")
                    continue

                if memmap_wip_array is None:
                    # The embedding size is only known once the first vector is available.
                    size = len(vector)
                    shape = (total,size)
                    memmap_wip_array = np.memmap(memmap_wip_path, dtype=dtype, mode='w+', shape=shape)

                index = len(accessions)
                memmap_wip_array[index,:] = vector.cpu().half().numpy()
                if index % flush_every == 0:
                    memmap_wip_array.flush()

                accessions.append(key)
                print(key, file=accessions_wip_file)

        if memmap_wip_array is None:
            # Nothing was embedded, so there is no array (or embedding size) to write out.
            print(f"No embeddings were generated for gene family {family_index}. See {error}")
            accessions_wip.unlink()
            return

        memmap_wip_array.flush()

        accessions_path = output_dir / f"{family_index}.txt"
        with open(accessions_path, "w") as f:
            for accession in accessions:
                print(accession, file=f)

        # Save final memmap array now that we know the final size
        memmap_path = output_dir / f"{family_index}.npy"
        shape = (len(accessions),size)
        print(f"Writing final memmap array of shape {shape}: {memmap_path}")
        memmap_array = np.memmap(memmap_path, dtype=dtype, mode='w+', shape=shape)
        memmap_array[:len(accessions),:] = memmap_wip_array[:len(accessions),:]
        memmap_array.flush()

        # Clean up: release both mappings before removing the work-in-progress files,
        # since a file that is still mapped cannot be deleted on every platform.
        for array in (memmap_array, memmap_wip_array):
            array._mmap.close()
            array._mmap = None
        memmap_array = None
        memmap_wip_array = None
        memmap_wip_path.unlink()
        accessions_wip.unlink()

    @tool
    def set_validation_rank(
        self,
        seqtree:Path=typer.Option(default=..., help="The path to the seqtree file."),
        output:Path=typer.Option(default=..., help="The path to save the adapted seqtree file."),
        validation_rank:str=typer.Option(default="species", help="The rank to hold out for cross-validation."),
        partitions:int=typer.Option(default=5, help="The number of cross-validation partitions."),
    ) -> SeqTree:
        """ Loads a seqtree, re-assigns its partitions at the given rank and saves it to ``output``. """
        seqtree = SeqTree.load(seqtree)
        set_validation_rank_to_seqtree(seqtree, validation_rank=validation_rank, partitions=partitions)
        seqtree.save(output)
        return seqtree

    @tool
    def preprocess(
        self,
        taxonomy:Path=typer.Option(default=..., help="The path to the TSV taxonomy file (e.g. bac120_taxonomy_r220.tsv)."),
        marker_genes:Path=typer.Option(default=..., help="The path to the marker genes tarball (e.g. bac120_msa_marker_genes_all_r220.tar.gz)."),
        output_dir:Path=typer.Option(default=..., help="A directory to store the output which includes the memmap array, the listing of accessions and an error log."),
        partitions:int=typer.Option(default=5, help="The number of cross-validation partitions."),
        seed:int=typer.Option(default=42, help="The random seed."),
    ):
        """
        Combines the per-family outputs found in ``output_dir`` into a single seqtree,
        a single memmap array and a single key listing, all named after ``output_dir``.

        Gene families without a ``<family_index>.txt`` listing are skipped.
        """
        seqtree, accession_to_node = self.build_seqtree(taxonomy)

        dtype = 'float16'

        random.seed(seed)

        print(f"Loading {marker_genes} file.")
        with tarfile.open(marker_genes, "r:gz") as tar:
            members = [member for member in tar.getmembers() if member.isfile() and member.name.endswith(".faa")]
            family_count = len(members)
        print(f"{family_count} gene families found.")

        # Read and collect accessions
        print(f"Building seqtree")
        keys = []
        counts = []
        node_to_partition_dict = dict()
        for family_index in track(range(family_count)):
            keys_path = output_dir / f"{family_index}.txt"

            if not keys_path.exists():
                counts.append(0)
                continue

            with open(keys_path) as f:
                family_index_keys = [line.strip() for line in f]
                keys += family_index_keys
                counts.append(len(family_index_keys))

                for key in family_index_keys:
                    species_accession = key.split("/")[0]
                    node = accession_to_node[species_accession]
                    # One partition per taxonomy node, so that all genes belonging to the same
                    # node are held out together. Keying on the (unique) key instead would
                    # scatter them across partitions.
                    if node not in node_to_partition_dict:
                        node_to_partition_dict[node] = random.randint(0, partitions - 1)
                    partition = node_to_partition_dict[node]

                    # Add to seqtree
                    seqtree.add(key, node, partition)

        assert len(counts) == family_count

        # Save seqtree
        seqtree_path = output_dir / f"{output_dir.name}.st"
        print(f"Saving seqtree to {seqtree_path}")
        seqtree.save(seqtree_path)

        # Concatenate numpy memmap arrays
        memmap_array = None
        memmap_array_path = output_dir / f"{output_dir.name}.npy"
        print(f"Saving memmap to {memmap_array_path}")
        current_index = 0
        for family_index, count in track(enumerate(counts), total=len(counts)):
            my_memmap_path = output_dir / f"{family_index}.npy"

            # Gene families without an array are skipped; they are not built here.
            if not my_memmap_path.exists():
                continue

            my_memmap = read_memmap(my_memmap_path, count)

            # Build memmap for output if it doesn't exist
            if memmap_array is None:
                size = my_memmap.shape[1]
                shape = (len(keys),size)
                memmap_array = np.memmap(memmap_array_path, dtype=dtype, mode='w+', shape=shape)

            # Copy memmap for gene family into output memmap
            memmap_array[current_index:current_index+count,:] = my_memmap[:,:]

            current_index += count

        # Every collected key must have a row, i.e. no listed family may lack its array.
        assert len(keys) == current_index

        # There is no array to flush if no gene family had any output.
        if memmap_array is not None:
            memmap_array.flush()

        # Save keys
        keys_path = output_dir / f"{output_dir.name}.txt"
        print(f"Saving keys to {keys_path}")
        with open(keys_path, "w") as f:
            for key in keys:
                print(key, file=f)

468