from enum import Enum
import typer
from pathlib import Path
import torch
from torchapp.cli import method
from bloodhound.embedding import Embedding


class ESMLayers(Enum):
    T6 = "6"
    T12 = "12"
    T30 = "30"
    T33 = "33"
    T36 = "36"
    T48 = "48"

    @classmethod
    def from_value(cls, value: int | str) -> "ESMLayers | None":
        """ Returns the member whose value matches the given number of layers, or None if there is no match. """
        for layer in cls:
            if layer.value == str(value):
                return layer
        return None

    def __int__(self):
        return int(self.value)

    def __str__(self):
        return str(self.value)

    def model_name(self) -> str:
        """ Returns the name of the pretrained ESM2 checkpoint with this number of layers. """
        match self:
            case ESMLayers.T48:
                return "esm2_t48_15B_UR50D"
            case ESMLayers.T36:
                return "esm2_t36_3B_UR50D"
            case ESMLayers.T33:
                return "esm2_t33_650M_UR50D"
            case ESMLayers.T30:
                return "esm2_t30_150M_UR50D"
            case ESMLayers.T12:
                return "esm2_t12_35M_UR50D"
            case ESMLayers.T6:
                return "esm2_t6_8M_UR50D"

    def get_model_alphabet(self) -> tuple["ESM2", "Alphabet"]:
        """ Loads the pretrained ESM2 model and its alphabet from torch hub, downloading the weights if they are not already cached. """
        return torch.hub.load("facebookresearch/esm:main", self.model_name())
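
# For illustration: from_value accepts the layer count as an int or a string and maps it to
# the matching member, which in turn names the pretrained checkpoint on torch hub, e.g.
#
#     layer = ESMLayers.from_value(33)              # ESMLayers.T33
#     layer.model_name()                            # "esm2_t33_650M_UR50D"
#     model, alphabet = layer.get_model_alphabet()  # downloads the weights if not cached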


class ESMEmbedding(Embedding):
    @method
    def setup(
        self,
        layers: ESMLayers = typer.Option(ESMLayers.T6, help="The number of ESM layers to use."),
        hub_dir: Path = typer.Option(None, help="The torch hub directory where the ESM models will be cached."),
    ):
        if layers and not getattr(self, 'layers', None):
            self.layers = layers

        if isinstance(self.layers, (str, int)):
            self.layers = ESMLayers.from_value(self.layers)

        assert self.layers is not None, "Please ensure the number of ESM layers is one of " + ", ".join(layer.value for layer in ESMLayers)
        assert isinstance(self.layers, ESMLayers)

        self.hub_dir = hub_dir
        if hub_dir:
            torch.hub.set_dir(str(hub_dir))

        # The model is loaded lazily on the first call to embed
        self.model = None
        self.device = None
        self.batch_converter = None
        self.alphabet = None

    def __getstate__(self):
        # Return only the lightweight attributes to be pickled; the model, alphabet,
        # batch converter and device are excluded and reloaded lazily after unpickling.
        return dict(max_length=self.max_length, layers=str(self.layers))

    def __setstate__(self, state):
        self.__init__()

        # Restore the object state from the unpickled state
        self.__dict__.update(state)
        self.model = None
        self.device = None
        self.batch_converter = None
        self.alphabet = None
        self.hub_dir = None

    def load(self):
        """ Loads the pretrained ESM model, its alphabet and batch converter, and moves the model to the available device. """
        self.model, self.alphabet = self.layers.get_model_alphabet()
        self.batch_converter = self.alphabet.get_batch_converter()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

    def embed(self, seq: str) -> torch.Tensor:
        """ Takes a protein sequence as a string and returns an embedding tensor per residue. """
        if isinstance(self.layers, (str, int)):
            self.layers = ESMLayers.from_value(self.layers)

        layers = int(self.layers.value)

        # Handle ambiguous AAs
        # https://github.com/facebookresearch/esm/issues/164
        seq = seq.replace("J", "X")

        if not self.model:
            self.load()

        _, _, batch_tokens = self.batch_converter([("marker_id", seq)])
        batch_tokens = batch_tokens.to(self.device)
        batch_lens = (batch_tokens != self.alphabet.padding_idx).sum(1)

        # Extract per-residue representations
        with torch.no_grad():
            results = self.model(batch_tokens, repr_layers=[layers], return_contacts=True)
        token_representations = results["representations"][layers]

        assert len(batch_lens) == 1, "More than one length found"
        assert token_representations.size(0) == 1, "More than one representation found"

        # Strip off the beginning-of-sequence and end-of-sequence tokens
        embedding_tensor = token_representations[0, 1 : batch_lens[0] - 1]
        assert len(seq) == len(embedding_tensor), f"Embedding representation has incorrect length: should be {len(seq)} but is {len(embedding_tensor)}"

        return embedding_tensor
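

# A minimal usage sketch, assuming ESMEmbedding can be constructed without arguments
# (as __setstate__ calling self.__init__() suggests), that setup() is still directly
# callable despite the torchapp @method decorator, and that torch hub can download
# the ESM2 weights on first use.
if __name__ == "__main__":
    embedding = ESMEmbedding()
    embedding.setup(layers=ESMLayers.T6, hub_dir=None)  # 8M-parameter ESM2 model; larger variants need more memory
    tensor = embedding.embed("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")
    print(tensor.shape)  # one embedding vector per residue: (sequence length, embedding dimension)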