Coverage for bloodhound/embedding.py : 35.17%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import gzip
2import os
3from pathlib import Path
4import os
5from abc import ABC, abstractmethod
6from Bio import SeqIO
7import random
8import numpy as np
9from rich.progress import track
10from hierarchicalsoftmax import SoftmaxNode
11from corgi.seqtree import SeqTree
12import tarfile
13import torch
14from io import StringIO
15from torchapp.cli import CLIApp, tool, method
16import typer
17from dataclasses import dataclass
19from .data import read_memmap, RANKS
def _open(path, mode='rt', **kwargs):
    """
    Return a file handle for ``path``, going through gzip when the name ends in ``.gz``.

    Args:
        path (str or Path): Location of the file to open.
        mode (str): Mode string handed to the opener (default 'rt', i.e. reading text).
        **kwargs: Extra keyword arguments forwarded to ``gzip.open`` or ``open``.

    Returns:
        The opened file object.
    """
    path = Path(path)
    # Only the final suffix decides which opener is used
    opener = gzip.open if path.suffix == '.gz' else open
    return opener(path, mode, **kwargs)
def set_validation_rank_to_seqtree(
    seqtree:SeqTree,
    validation_rank:str="species",
    partitions:int=5,
) -> SeqTree:
    """
    Re-assigns the partition of every item in ``seqtree`` so that all items sharing the same
    ancestor at ``validation_rank`` land in the same randomly chosen partition.

    Args:
        seqtree (SeqTree): The tree whose items are updated in place.
        validation_rank (str): Name of the rank (case-insensitive) that must appear in RANKS.
        partitions (int): Number of partitions; values are drawn from 0 to partitions-1 inclusive.

    Returns:
        SeqTree: The same ``seqtree`` object that was passed in.
    """
    # Locate the position of the requested rank within RANKS
    rank_name = validation_rank.lower()
    assert rank_name in RANKS
    rank_position = RANKS.index(rank_name)

    # One random partition per distinct ancestor, drawn the first time it is seen
    partition_of_ancestor = {}
    for key in seqtree:
        # presumably ancestors[i] is the ancestor at rank RANKS[i] -- verify against SoftmaxNode
        ancestor = seqtree.node(key).ancestors[rank_position]
        try:
            chosen = partition_of_ancestor[ancestor]
        except KeyError:
            chosen = random.randint(0, partitions - 1)
            partition_of_ancestor[ancestor] = chosen

        seqtree[key].partition = chosen

    return seqtree
def get_key(accession:str, gene:str) -> str:
    """ Builds the standard key for a gene of an accession, in the form 'accession/gene'. """
    return f"{accession}/{gene}"
def get_node(lineage:str, lineage_to_node:dict[str,SoftmaxNode]) -> SoftmaxNode:
    """
    Returns the node for a semi-colon separated lineage, creating it and any missing ancestors.

    Args:
        lineage (str): The lineage string, with levels separated by ';'.
        lineage_to_node (dict): Cache of lineage string -> node. New nodes are added to it in place.

    Returns:
        SoftmaxNode: The node registered for ``lineage``.

    Raises:
        AssertionError: If an unregistered lineage (or prefix) contains no ';' to split on.
    """
    # Climb towards the root until a lineage prefix is found that is already registered
    missing = []
    current = lineage
    while current not in lineage_to_node:
        assert ";" in current, f"Semi-colon ';' not found in lineage '{current}'"
        cut = current.rfind(";")
        missing.append((current, current[cut+1:]))
        current = current[:cut]

    # Create the missing levels from the shallowest to the deepest
    node = lineage_to_node[current]
    for full_lineage, name in reversed(missing):
        node = SoftmaxNode(name=name, parent=node)
        lineage_to_node[full_lineage] = node
    return node
def generate_overlapping_intervals(total: int, interval_size: int, min_overlap: int, check:bool=True, variable_size:bool=False) -> list[tuple[int, int]]:
    """
    Creates a list of overlapping intervals within a specified range, adjusting the overlap (or the
    interval size if ``variable_size`` is set) so that the overlap is approximately the same across
    all intervals.

    Args:
        total (int): The total range within which intervals are to be created.
        interval_size (int): The maximum size of each interval.
        min_overlap (int): The minimum number of units by which consecutive intervals overlap.
        check (bool): If True, checks are performed to ensure that the intervals meet the specified conditions.
        variable_size (bool): If True, the interval size is reduced to keep the overlap at ``min_overlap``.
            If False (default), the interval size is kept fixed and the overlap is increased instead.

    Returns:
        list[tuple[int, int]]: A list of tuples where each tuple represents the start (inclusive)
        and end (exclusive) of an interval. An empty list is returned if ``total`` is zero and a
        single interval ``(0, total)`` is returned if the whole range fits within one interval.

    Raises:
        AssertionError: If ``interval_size`` is not greater than ``min_overlap`` or if a check fails.

    Example:
        >>> generate_overlapping_intervals(20, 5, 2)
        [(0, 5), (3, 8), (6, 11), (9, 14), (12, 17), (15, 20)]
    """
    intervals = []
    start = 0

    if total == 0:
        return intervals

    max_interval_size = interval_size
    assert interval_size
    assert min_overlap is not None
    assert interval_size > min_overlap, f"Max interval size of {interval_size} must be greater than min overlap of {min_overlap}"

    # A range that fits inside one interval needs no tiling. Without this shortcut the interval
    # count below can come out as zero (total <= min_overlap) and the lone interval is shorter than
    # interval_size, so the checks would reject a perfectly valid result.
    if 0 < total <= interval_size:
        return [(0, total)]

    # Calculate the number of intervals needed to cover the range
    num_intervals, remainder = divmod(total - min_overlap, interval_size - min_overlap)
    if remainder > 0:
        num_intervals += 1

    # Calculate the exact interval size to ensure consistent overlap
    overlap = min_overlap
    if variable_size:
        if num_intervals > 1:
            interval_size, remainder = divmod(total + (num_intervals - 1) * overlap, num_intervals)
            if remainder > 0:
                interval_size += 1
    else:
        # If the size is fixed, then vary the overlap to keep it even
        if num_intervals > 1:
            overlap, remainder = divmod(num_intervals * interval_size - total, num_intervals - 1)
            if overlap < min_overlap:
                overlap = min_overlap

    while True:
        end = start + interval_size
        if end > total:
            # Pull the final interval back so that it ends exactly at the end of the range
            end = total
            start = max(end - interval_size, 0)
        intervals.append((start, end))
        start += interval_size - overlap
        if end >= total:
            break

    if check:
        assert intervals[0][0] == 0
        assert intervals[-1][1] == total
        assert len(intervals) == num_intervals, f"Expected {num_intervals} intervals, got {len(intervals)}"

        assert interval_size <= max_interval_size, f"Interval size of {interval_size} exceeds max interval size of {max_interval_size}"
        for interval in intervals:
            assert interval[1] - interval[0] == interval_size, f"Interval size of {interval[1] - interval[0]} is not the expected size {interval_size}"

        for i in range(1, len(intervals)):
            actual_overlap = intervals[i - 1][1] - intervals[i][0]
            assert actual_overlap >= min_overlap, f"Min overlap condition of {min_overlap} not met for intervals {intervals[i - 1]} and {intervals[i]} (overlap {actual_overlap})"

    return intervals
@dataclass
class Embedding(CLIApp, ABC):
    """
    A class for embedding protein sequences.

    Subclasses implement ``embed``. Calling an instance embeds a sequence and reduces the result
    to a single vector, splitting the sequence into overlapping windows if it is longer than
    ``max_length``.
    """
    # Longest sequence passed to ``embed`` in one piece; falsy values disable windowing
    max_length:int|None=None
    # Minimum overlap between consecutive windows, also the cap on the blending weight
    overlap:int=64

    def __post_init__(self):
        super().__init__()

    @abstractmethod
    def embed(self, seq:str) -> torch.Tensor:
        """ Takes a protein sequence as a string and returns an embedding vector. """
        raise NotImplementedError

    def reduce(self, tensor:torch.Tensor) -> torch.Tensor:
        """ Averages a 2D tensor over its first dimension so that a 1D vector is returned. """
        if tensor.ndim == 2:
            tensor = tensor.mean(dim=0)
        assert tensor.ndim == 1
        return tensor

    def __call__(self, seq:str) -> torch.Tensor:
        """
        Takes a protein sequence as a string and returns an embedding vector.

        Sequences longer than ``max_length`` are embedded in overlapping windows. The per-position
        results of the windows are blended with a weighted average before being reduced.
        """
        if not self.max_length or len(seq) <= self.max_length:
            tensor = self.embed(seq)
            return self.reduce(tensor)

        # Small constant so that the positions at the very edge of a window still have a non-zero weight
        epsilon = 0.1
        intervals = generate_overlapping_intervals(len(seq), self.max_length, self.overlap)
        weights = torch.zeros( (len(seq),), device="cpu" )
        tensor = None
        for start,end in intervals:
            result = self.embed(seq[start:end]).cpu()

            # The windowed path needs one embedding row per residue in the window
            assert result.shape[0] == end-start
            embedding_size = result.shape[1]

            if tensor is None:
                tensor = torch.zeros( (len(seq), embedding_size ), device="cpu")

            assert tensor.shape[-1] == embedding_size

            # Distance of each position to the nearest edge of its own window. The indexes here are
            # relative to the window, so the global ``start`` must not be subtracted from them
            # (doing so gave negative weights for every window after the first).
            window_length = end - start
            window_indexes = torch.arange(window_length)
            distance_from_ends = torch.minimum( window_indexes, window_length-window_indexes-1 )

            weight = epsilon + torch.minimum(distance_from_ends, torch.tensor(self.overlap))

            tensor[start:end] += result * weight.unsqueeze(1)
            weights[start:end] += weight

        tensor = tensor/weights.unsqueeze(1)

        return self.reduce(tensor)

    @method
    def setup(self, **kwargs):
        """ Hook called before embedding starts. Does nothing in this base class. """
        pass

    def build_seqtree(self, taxonomy:Path) -> tuple[SeqTree,dict[str,SoftmaxNode]]:
        """
        Builds the classification tree from a tab-separated taxonomy file.

        Args:
            taxonomy (Path): File (optionally gzipped) with an accession and a ';' separated
                lineage on each line.

        Returns:
            tuple: A new SeqTree using the root of the tree and a dictionary from accession to node.
        """
        # Create root of tree
        lineage_to_node = {}
        root = None

        # Fill out tree with taxonomy
        accession_to_node = {}
        with _open(taxonomy) as f:
            for line in f:
                # Drop the line ending so that it does not become part of the last name in the lineage
                line = line.rstrip("\r\n")
                if not line:
                    continue

                accession, lineage = line.split("\t")

                if not root:
                    root_name = lineage.split(";")[0]
                    root = SoftmaxNode(root_name)
                    lineage_to_node[root_name] = root

                node = get_node(lineage, lineage_to_node)
                accession_to_node[accession] = node

        seqtree = SeqTree(classification_tree=root)
        return seqtree, accession_to_node

    @tool("setup")
    def test_lengths(
        self,
        end:int=5_000,
        start:int=1000,
        retries:int=5,
        **kwargs,
    ):
        """
        Embeds random amino acid sequences of increasing length without windowing.
        Prints the length and error of the first failure and then stops.
        """
        def random_amino_acid_sequence(k):
            amino_acids = "ACDEFGHIKLMNPQRSTVWY"  # standard 20 amino acids
            return ''.join(random.choice(amino_acids) for _ in range(k))

        # Disable windowing so that the whole sequence is given to the model at once
        self.max_length = None
        self.setup(**kwargs)
        for ii in track(range(start,end)):
            for _ in range(retries):
                seq = random_amino_acid_sequence(ii)
                try:
                    self(seq)
                except Exception as err:
                    print(f"{ii}: {err}")
                    return

    @tool("setup")
    def build_gene_array(
        self,
        marker_genes:Path=typer.Option(default=..., help="The path to the marker genes tarball (e.g. bac120_msa_marker_genes_all_r220.tar.gz)."),
        family_index:int=typer.Option(default=..., help="The index for the gene family to use. E.g. if there are 120 gene families then this should be a number from 0 to 119."),
        output_dir:Path=typer.Option(default=..., help="A directory to store the output which includes the memmap array, the listing of accessions and an error log."),
        flush_every:int=typer.Option(default=5_000, help="An interval to flush the memmap array as it is generated."),
        max_length:int=None,
        **kwargs,
    ):
        """
        Embeds every sequence of one gene family in the marker genes tarball.

        Writes ``{family_index}.npy`` (float16 memmap with one row per embedded sequence),
        ``{family_index}.txt`` (the key of each row) and ``{family_index}-errors.txt`` to
        ``output_dir``. Sequences which fail to embed are logged and skipped.
        """
        self.max_length = max_length
        self.setup(**kwargs)

        assert marker_genes is not None
        assert family_index is not None
        assert output_dir is not None

        dtype = 'float16'

        memmap_wip_array = None
        output_dir.mkdir(parents=True, exist_ok=True)
        memmap_wip_path = output_dir / f"{family_index}-wip.npy"
        error = output_dir / f"{family_index}-errors.txt"
        accessions_wip = output_dir / f"{family_index}-accessions-wip.txt"

        accessions = []

        print(f"Loading {marker_genes} file.")
        with tarfile.open(marker_genes, "r:gz") as tar, open(error, "w") as error_file, open(accessions_wip, "w") as accessions_wip_file:
            members = [member for member in tar.getmembers() if member.isfile() and member.name.endswith(".faa")]
            prefix_length = len(os.path.commonprefix([Path(member.name).with_suffix("").name for member in members]))

            member = members[family_index]
            print(f"Processing file {family_index} in {marker_genes}")

            # The marker ID is the part of the file name that is not shared with the other members
            marker_id = Path(member.name).with_suffix("").name[prefix_length:]

            with tar.extractfile(member) as f:
                fasta_io = StringIO(f.read().decode('ascii'))

            # First pass counts the records so that the work-in-progress array can be sized
            total = sum(1 for _ in SeqIO.parse(fasta_io, "fasta"))
            fasta_io.seek(0)
            print(marker_id, total)

            for record in track(SeqIO.parse(fasta_io, "fasta"), total=total):
                species_accession = record.id

                key = get_key(species_accession, marker_id)

                # Remove the gap and stop characters
                seq = str(record.seq).replace("-","").replace("*","")
                try:
                    vector = self(seq)
                except Exception as err:
                    print(f"{key} ({len(seq)}): {err}", file=error_file)
                    print(f"{key} ({len(seq)}): {err}")
                    continue

                if vector is None:
                    print(f"{key} ({len(seq)}): Embedding is None", file=error_file)
                    print(f"{key} ({len(seq)}): Embedding is None")
                    continue

                if torch.isnan(vector).any():
                    print(f"{key} ({len(seq)}): Embedding contains NaN", file=error_file)
                    print(f"{key} ({len(seq)}): Embedding contains NaN")
                    continue

                # The embedding size is only known once the first vector has been produced
                if memmap_wip_array is None:
                    size = len(vector)
                    shape = (total,size)
                    memmap_wip_array = np.memmap(memmap_wip_path, dtype=dtype, mode='w+', shape=shape)

                index = len(accessions)
                memmap_wip_array[index,:] = vector.cpu().half().numpy()
                if index % flush_every == 0:
                    memmap_wip_array.flush()

                accessions.append(key)
                print(key, file=accessions_wip_file)

        if memmap_wip_array is not None:
            memmap_wip_array.flush()

        accessions_path = output_dir / f"{family_index}.txt"
        with open(accessions_path, "w") as f:
            for accession in accessions:
                print(accession, file=f)

        # If nothing could be embedded then there is no array (or embedding size) to save
        if memmap_wip_array is None:
            print(f"No embeddings generated for gene family {family_index}. No memmap array written.")
            accessions_wip.unlink()
            return

        # Save final memmap array now that we know the final size
        memmap_path = output_dir / f"{family_index}.npy"
        shape = (len(accessions),size)
        print(f"Writing final memmap array of shape {shape}: {memmap_path}")
        memmap_array = np.memmap(memmap_path, dtype=dtype, mode='w+', shape=shape)
        memmap_array[:len(accessions),:] = memmap_wip_array[:len(accessions),:]
        memmap_array.flush()

        # Clean up
        memmap_array._mmap.close()
        memmap_array._mmap = None
        memmap_array = None
        # Release the work-in-progress mapping before deleting the file behind it
        memmap_wip_array = None
        memmap_wip_path.unlink()
        accessions_wip.unlink()

    @tool
    def set_validation_rank(
        self,
        seqtree:Path=typer.Option(default=..., help="The path to the seqtree file."),
        output:Path=typer.Option(default=..., help="The path to save the adapted seqtree file."),
        validation_rank:str=typer.Option(default="species", help="The rank to hold out for cross-validation."),
        partitions:int=typer.Option(default=5, help="The number of cross-validation partitions."),
    ) -> SeqTree:
        """ Loads a seqtree, reassigns the partitions at the validation rank and saves it to the output path. """
        seqtree = SeqTree.load(seqtree)
        set_validation_rank_to_seqtree(seqtree, validation_rank=validation_rank, partitions=partitions)
        seqtree.save(output)
        return seqtree

    @tool
    def preprocess(
        self,
        taxonomy:Path=typer.Option(default=..., help="The path to the TSV taxonomy file (e.g. bac120_taxonomy_r220.tsv)."),
        marker_genes:Path=typer.Option(default=..., help="The path to the marker genes tarball (e.g. bac120_msa_marker_genes_all_r220.tar.gz)."),
        output_dir:Path=typer.Option(default=..., help="A directory to store the output which includes the memmap array, the listing of accessions and an error log."),
        partitions:int=typer.Option(default=5, help="The number of cross-validation partitions."),
        seed:int=typer.Option(default=42, help="The random seed."),
    ):
        """
        Combines the per-family outputs found in ``output_dir`` into a single dataset.

        Writes ``{output_dir.name}.st`` (seqtree), ``{output_dir.name}.npy`` (float16 memmap of all
        the per-family arrays concatenated) and ``{output_dir.name}.txt`` (the key of each row).
        Gene families without a keys file or an array in ``output_dir`` are skipped.
        """
        seqtree, accession_to_node = self.build_seqtree(taxonomy)

        dtype = 'float16'

        random.seed(seed)

        print(f"Loading {marker_genes} file.")
        with tarfile.open(marker_genes, "r:gz") as tar:
            members = [member for member in tar.getmembers() if member.isfile() and member.name.endswith(".faa")]
            family_count = len(members)
        print(f"{family_count} gene families found.")

        # Read and collect accessions
        print(f"Building seqtree")
        keys = []
        counts = []
        # NOTE(review): despite the name, this is indexed by the sequence key and not the node,
        # so each sequence gets its own random partition -- confirm that this is intended
        node_to_partition_dict = dict()
        for family_index in track(range(family_count)):
            keys_path = output_dir / f"{family_index}.txt"

            if not keys_path.exists():
                counts.append(0)
                continue

            with open(keys_path) as f:
                family_index_keys = [line.strip() for line in f]
                keys += family_index_keys
                counts.append(len(family_index_keys))

                for key in family_index_keys:
                    species_accession = key.split("/")[0]
                    node = accession_to_node[species_accession]
                    partition = node_to_partition_dict.setdefault(key, random.randint(0, partitions - 1))

                    # Add to seqtree
                    seqtree.add(key, node, partition)

        assert len(counts) == family_count

        # Save seqtree
        seqtree_path = output_dir / f"{output_dir.name}.st"
        print(f"Saving seqtree to {seqtree_path}")
        seqtree.save(seqtree_path)

        # Concatenate numpy memmap arrays
        memmap_array = None
        memmap_array_path = output_dir / f"{output_dir.name}.npy"
        print(f"Saving memmap to {memmap_array_path}")
        current_index = 0
        for family_index, count in track(enumerate(counts), total=len(counts)):
            my_memmap_path = output_dir / f"{family_index}.npy"

            # Skip gene families which have no array
            if not my_memmap_path.exists():
                continue

            my_memmap = read_memmap(my_memmap_path, count)

            # Build memmap for output if it doesn't exist
            if memmap_array is None:
                size = my_memmap.shape[1]
                shape = (len(keys),size)
                memmap_array = np.memmap(memmap_array_path, dtype=dtype, mode='w+', shape=shape)

            # Copy memmap for gene family into output memmap
            memmap_array[current_index:current_index+count,:] = my_memmap[:,:]

            current_index += count

        assert len(keys) == current_index

        # There is no output array if none of the gene families had one
        if memmap_array is not None:
            memmap_array.flush()

        # Save keys
        keys_path = output_dir / f"{output_dir.name}.txt"
        print(f"Saving keys to {keys_path}")
        with open(keys_path, "w") as f:
            for key in keys:
                print(key, file=f)