Coverage for bloodhound/embeddings/esm.py : 29.55%

from enum import Enum
import typer
from pathlib import Path
import torch
from torchapp.cli import method
from bloodhound.embedding import Embedding


class ESMLayers(Enum):
    T6 = "6"
    T12 = "12"
    T30 = "30"
    T33 = "33"
    T36 = "36"
    T48 = "48"

    @classmethod
    def from_value(cls, value: int | str) -> "ESMLayers":
        for layer in cls:
            if layer.value == str(value):
                return layer
        return None

    def __int__(self):
        return int(self.value)

    def __str__(self):
        return str(self.value)

    def model_name(self) -> str:
        match self:
            case ESMLayers.T48:
                return "esm2_t48_15B_UR50D"
            case ESMLayers.T36:
                return "esm2_t36_3B_UR50D"
            case ESMLayers.T33:
                return "esm2_t33_650M_UR50D"
            case ESMLayers.T30:
                return "esm2_t30_150M_UR50D"
            case ESMLayers.T12:
                return "esm2_t12_35M_UR50D"
            case ESMLayers.T6:
                return "esm2_t6_8M_UR50D"

    def get_model_alphabet(self) -> tuple["ESM2", "Alphabet"]:
        return torch.hub.load("facebookresearch/esm:main", self.model_name())
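
# Illustrative note (an assumption about intended usage, not taken from the original
# source): each ESMLayers member selects a pretrained ESM-2 checkpoint by its layer
# count, e.g. ESMLayers.from_value(33).model_name() evaluates to "esm2_t33_650M_UR50D",
# and get_model_alphabet() then fetches that checkpoint and its Alphabet via torch.hub.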


class ESMEmbedding(Embedding):
    @method
    def setup(
        self,
        layers: ESMLayers = typer.Option(ESMLayers.T6, help="The number of ESM layers to use."),
        hub_dir: Path = typer.Option(None, help="The torch hub directory where the ESM models will be cached."),
    ):
        if layers and not getattr(self, 'layers', None):
            self.layers = layers

        if isinstance(self.layers, (str, int)):
            self.layers = ESMLayers.from_value(self.layers)

        assert self.layers is not None, "Please ensure the number of ESM layers is one of " + ", ".join(layer.value for layer in ESMLayers)
        assert isinstance(self.layers, ESMLayers)

        self.hub_dir = hub_dir
        if hub_dir:
            torch.hub.set_dir(str(hub_dir))
        self.model = None
        self.device = None
        self.batch_converter = None
        self.alphabet = None

    def __getstate__(self):
        # Pickle only the lightweight attributes; the model, alphabet, batch
        # converter and device are dropped and rebuilt lazily after unpickling.
        return dict(max_length=self.max_length, layers=str(self.layers))

    def __setstate__(self, state):
        self.__init__()

        # Restore the object state from the unpickled state
        self.__dict__.update(state)
        self.model = None
        self.device = None
        self.batch_converter = None
        self.alphabet = None
        self.hub_dir = None

    def load(self):
        self.model, self.alphabet = self.layers.get_model_alphabet()
        self.batch_converter = self.alphabet.get_batch_converter()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

    def embed(self, seq: str) -> torch.Tensor:
        """ Takes a protein sequence as a string and returns an embedding tensor per residue. """
        if isinstance(self.layers, (str, int)):
            self.layers = ESMLayers.from_value(self.layers)

        layers = int(self.layers.value)

        # Handle ambiguous amino acid codes: replace "J" with the unknown residue "X"
        # https://github.com/facebookresearch/esm/issues/164
        seq = seq.replace("J", "X")

        if not self.model:
            self.load()

        _, _, batch_tokens = self.batch_converter([("marker_id", seq)])
        batch_tokens = batch_tokens.to(self.device)
        batch_lens = (batch_tokens != self.alphabet.padding_idx).sum(1)

        # Extract per-residue representations
        with torch.no_grad():
            results = self.model(batch_tokens, repr_layers=[layers], return_contacts=True)
        token_representations = results["representations"][layers]

        assert len(batch_lens) == 1, "More than one length found"
        assert token_representations.size(0) == 1, "More than one representation found"

        # Strip off the beginning-of-sequence and end-of-sequence tokens
        embedding_tensor = token_representations[0, 1 : batch_lens[0] - 1]
        assert len(seq) == len(embedding_tensor), f"Embedding representation has incorrect length: should be {len(seq)} but is {len(embedding_tensor)}"

        return embedding_tensor
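

# Minimal usage sketch (an illustration, not part of the coverage-measured module).
# It assumes network access so torch.hub can download the 6-layer ESM-2 weights, and
# that setup() can be called directly with keyword arguments in this torchapp version.
if __name__ == "__main__":
    embedding = ESMEmbedding()
    embedding.setup(layers=ESMLayers.T6, hub_dir=None)
    # Any protein sequence given as a string of one-letter amino acid codes.
    tensor = embedding.embed("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")
    # One row per residue; the 6-layer ESM-2 checkpoint uses a 320-dimensional hidden size.
    print(tensor.shape)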