
import torch
from torch import nn
from hierarchicalsoftmax import HierarchicalSoftmaxLazyLinear, SoftmaxNode


class BloodhoundModel(nn.Module):
    def __init__(
        self,
        classification_tree: SoftmaxNode,
        features: int = 5120,
        intermediate_layers: int = 0,
        growth_factor: float = 2.0,
        attention_size: int = 512,
        dropout: float = 0.0,
        return_attention: bool = False,
    ):
        super().__init__()

        assert growth_factor > 0.0

        self.classification_tree = classification_tree

        # Input projection, then optional intermediate layers whose widths
        # grow by `growth_factor` each time.
        modules = [nn.LazyLinear(out_features=features), nn.PReLU()]
        for _ in range(intermediate_layers):
            out_features = int(features * growth_factor + 0.5)
            modules += [nn.LazyLinear(out_features=out_features), nn.PReLU(), nn.Dropout(dropout)]
            features = out_features

        self.sequential = nn.Sequential(*modules)

        # `features` holds the width of the final hidden layer, so the
        # attention MLP is sized correctly even when intermediate_layers == 0
        # (the loop variable `out_features` would be unbound in that case).
        self.attention_layer = nn.Sequential(
            nn.Linear(features, attention_size),  # (batch_size, seq_length, attention_size)
            nn.PReLU(),
            nn.Linear(attention_size, 1),  # (batch_size, seq_length, 1)
        )

        self.classifier = HierarchicalSoftmaxLazyLinear(root=classification_tree)
        self.model_dtype = next(self.sequential.parameters()).dtype
        self.return_attention = return_attention

    def forward(self, x):
        # Cast inputs to the parameter dtype so embeddings in another
        # precision (e.g. float16) do not crash the linear layers.
        if self.model_dtype != x.dtype:
            x = x.to(dtype=self.model_dtype)

        x = self.sequential(x)  # (batch_size, seq_length, hidden_size)

        # Additive attention pooling over the sequence dimension.
        attention_scores = self.attention_layer(x)  # (batch_size, seq_length, 1)
        attention_weights = torch.softmax(attention_scores, dim=1)

        # Weighted sum over seq_length -> (batch_size, hidden_size)
        context_vector = torch.sum(attention_weights * x, dim=1)

        result = self.classifier(context_vector)

        if self.return_attention:
            return result, attention_scores

        return result
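

# Usage sketch (a minimal, hypothetical example, not part of the model code):
# builds a small classification tree and runs one forward pass. The tree
# construction follows the hierarchicalsoftmax pattern
# SoftmaxNode(name, parent=...); the taxon labels, the input shape
# (batch_size, seq_length, embedding_dim), and the layer sizes below are
# assumptions for illustration only.
if __name__ == "__main__":
    root = SoftmaxNode("root")
    bacteria = SoftmaxNode("bacteria", parent=root)
    SoftmaxNode("e_coli", parent=bacteria)
    SoftmaxNode("b_subtilis", parent=bacteria)
    archaea = SoftmaxNode("archaea", parent=root)
    SoftmaxNode("halobacterium", parent=archaea)
    SoftmaxNode("haloferax", parent=archaea)

    model = BloodhoundModel(classification_tree=root, features=64, attention_size=32)
    embeddings = torch.randn(2, 10, 320)  # hypothetical per-position embeddings
    logits = model(embeddings)
    print(logits.shape)  # (2, root.layer_size): concatenated per-node logits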