
import random
import os
from collections import defaultdict
from dataclasses import dataclass, field

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import lightning as L

from corgi.seqtree import SeqTree


RANKS = ["phylum", "class", "order", "family", "genus", "species"]

def read_memmap(path, count:int, dtype:str="float16") -> np.memmap:
    """Opens a read-only memory map of `count` embedding rows, inferring the embedding size from the file size."""
    file_size = os.path.getsize(path)
    dtype_size = np.dtype(dtype).itemsize
    num_elements = file_size // dtype_size
    embedding_size = num_elements // count
    shape = (count, embedding_size)
    return np.memmap(path, dtype=dtype, mode="r", shape=shape)
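# A minimal usage sketch (the path and row count are hypothetical; the
# embedding width is inferred as file_size // (itemsize * count)):
#
#   embeddings = read_memmap("embeddings.dat", count=1000)
#   assert embeddings.shape[0] == 1000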

def gene_id_from_accession(accession:str) -> str:
    """Returns the gene ID component of an accession, i.e. everything after the final slash."""
    return accession.split("/")[-1]
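# For example, assuming accessions take the form "<species>/<gene_id>":
#
#   gene_id_from_accession("Homo_sapiens/COX1")  # "COX1"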

def choose_k_from_n(lst, k) -> list[int]:
    """Returns `k` items from `lst`: the whole list repeated as many times as fits, plus a random sample for the remainder."""
    n = len(lst)
    if n == 0:
        return []
    repetitions = k // n
    remainder = k % n
    result = lst * repetitions + random.sample(lst, remainder)
    return result
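# For example (a sketch; the extra sampled element varies with the random state):
#
#   choose_k_from_n([0, 1, 2], 7)  # e.g. [0, 1, 2, 0, 1, 2, 1]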

@dataclass(kw_only=True)
class BloodhoundStack():
    species:str
    array_indices:list[list[int]]

@dataclass(kw_only=True)
class BloodhoundPredictionDataset(Dataset):
    array:np.memmap|np.ndarray
    accessions: list[str]
    seq_count:int
    repeats:int = 2
    seed:int = 42
    stacks: list[BloodhoundStack] = field(init=False)

    def __post_init__(self):
        # Group the array indices for each accession by species
        species_to_array_indices = defaultdict(set)
        for index, accession in enumerate(self.accessions):
            slash_position = accession.rfind("/")
            assert slash_position != -1
            species = accession[:slash_position]
            species_to_array_indices[species].add(index)

        # Build stacks of `seq_count` indices per species
        random.seed(self.seed)
        self.stacks = []
        for species, species_array_indices in species_to_array_indices.items():
            stack_indices = []
            remainder = []
            for repeat_index in range(self.repeats + 1):
                if len(remainder) == 0 and repeat_index >= self.repeats:
                    break

                # Top up the remainder to a full stack with indices not already in it
                # (assumes each species has at least `seq_count` sequences;
                # sorting makes the sampling deterministic under the fixed seed)
                available = species_array_indices - set(remainder)
                to_add = random.sample(sorted(available), self.seq_count - len(remainder))
                to_add_set = set(to_add)
                assert not set(remainder) & to_add_set, "remainder and to_add should be disjoint"

                new_stack = sorted(remainder + to_add)
                stack_indices.append(new_stack)

                remainder = list(species_array_indices - to_add_set)
                random.shuffle(remainder)

                # If we have already added each item the required number of times, then stop
                if repeat_index >= self.repeats:
                    break

                # Emit full stacks from the shuffled remainder
                while len(remainder) >= self.seq_count:
                    stack_indices.append(remainder[:self.seq_count])
                    remainder = remainder[self.seq_count:]

            stack = BloodhoundStack(species=species, array_indices=stack_indices)
            self.stacks.append(stack)

    def __len__(self):
        return len(self.stacks)

    def __getitem__(self, idx):
        stack = self.stacks[idx]
        array_indices = stack.array_indices

        assert len(array_indices) > 0, f"Stack for species {stack.species} has no array indices"
        with torch.no_grad():
            data = np.array(self.array[array_indices, :], copy=False)
            embeddings = torch.tensor(data, dtype=torch.float16)
            del data

        return embeddings
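# A minimal usage sketch (the accessions and memmap are hypothetical; each
# item is the set of embedding stacks for one species, with shape
# (stacks, seq_count, embedding_size)):
#
#   array = read_memmap("embeddings.dat", count=len(accessions))
#   dataset = BloodhoundPredictionDataset(array=array, accessions=accessions, seq_count=32)
#   embeddings = dataset[0]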

@dataclass(kw_only=True)
class BloodhoundTrainingDataset(Dataset):
    accessions: list[str]
    seqtree: SeqTree
    array:np.memmap|np.ndarray
    gene_id_dict: dict[str, int]
    accession_to_array_index:dict[str,list[int]]|None=None
    seq_count:int = 0

    def __len__(self):
        return len(self.accessions)

    def __getitem__(self, idx):
        accession = self.accessions[idx]
        array_indices = self.accession_to_array_index[accession] if self.accession_to_array_index else [idx]
        if self.seq_count:
            array_indices = choose_k_from_n(array_indices, self.seq_count)

        assert len(array_indices) > 0, f"Accession {accession} has no array indices"
        with torch.no_grad():
            data = np.array(self.array[array_indices, :], copy=False)
            embedding = torch.tensor(data, dtype=torch.float16)
            del data

        node_id = int(self.seqtree[accession].node_id)

        return embedding, node_id
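# A sketch of direct use (normally constructed via BloodhoundDataModule below;
# `seqtree`, `array` and the dictionaries are assumed to be prebuilt):
#
#   dataset = BloodhoundTrainingDataset(
#       accessions=accessions,
#       seqtree=seqtree,
#       array=array,
#       accession_to_array_index=accession_to_array_index,
#       gene_id_dict=gene_id_dict,
#       seq_count=32,
#   )
#   embedding, node_id = dataset[0]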


class BloodhoundDataModule(L.LightningDataModule):

    def __init__(
        self,
        seqtree: SeqTree,
        array:np.memmap|np.ndarray,
        accession_to_array_index:dict[str,list[int]],
        gene_id_dict: dict[str,int],
        max_items: int = 0,
        batch_size: int = 16,
        num_workers: int|None = None,
        validation_partition:int = 0,
        test_partition:int = -1,
        seq_count:int = 0,
        train_all:bool = False,
    ):
        super().__init__()
        self.array = array
        self.accession_to_array_index = accession_to_array_index
        self.seqtree = seqtree
        self.gene_id_dict = gene_id_dict
        self.max_items = max_items
        self.batch_size = batch_size
        self.validation_partition = validation_partition
        self.test_partition = test_partition
        self.num_workers = min(os.cpu_count(), 8) if num_workers is None else num_workers
        self.seq_count = seq_count
        self.train_all = train_all

    def setup(self, stage=None):
        # Make assignments here (val/train/test split)
        # Called on every process in DDP
        self.training = []
        self.validation = []

        for accession, details in self.seqtree.items():
            partition = details.partition
            if partition == self.test_partition:
                continue

            dataset = self.validation if partition == self.validation_partition else self.training
            dataset.append(accession)

            if self.max_items and len(self.training) >= self.max_items and len(self.validation) > 0:
                break

        if self.train_all:
            self.training += self.validation

        self.train_dataset = self.create_dataset(self.training)
        self.val_dataset = self.create_dataset(self.validation)

    def create_dataset(self, accessions:list[str]) -> BloodhoundTrainingDataset:
        return BloodhoundTrainingDataset(
            accessions=accessions,
            seqtree=self.seqtree,
            array=self.array,
            accession_to_array_index=self.accession_to_array_index,
            gene_id_dict=self.gene_id_dict,
            seq_count=self.seq_count,
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
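# A minimal wiring sketch, assuming prebuilt inputs and a compatible
# LightningModule `model` (all names here are hypothetical):
#
#   data = BloodhoundDataModule(
#       seqtree=seqtree,
#       array=array,
#       accession_to_array_index=accession_to_array_index,
#       gene_id_dict=gene_id_dict,
#       batch_size=16,
#       seq_count=32,
#   )
#   trainer = L.Trainer(max_epochs=10)
#   trainer.fit(model, datamodule=data)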
