Coverage for bloodhound/models.py: 20.69%

import torch
from torch import nn
from hierarchicalsoftmax import HierarchicalSoftmaxLazyLinear, SoftmaxNode


class BloodhoundModel(nn.Module):
    def __init__(
        self,
        classification_tree: SoftmaxNode,
        features: int = 5120,
        intermediate_layers: int = 0,
        growth_factor: float = 2.0,
        attention_size: int = 512,
        dropout: float = 0.0,
        return_attention: bool = False,
    ):
        super().__init__()

        assert growth_factor > 0.0

        self.classification_tree = classification_tree
        modules = [nn.LazyLinear(out_features=features), nn.PReLU()]
        for _ in range(intermediate_layers):
            out_features = int(features * growth_factor + 0.5)
            modules += [nn.LazyLinear(out_features=out_features), nn.PReLU(), nn.Dropout(dropout)]
            features = out_features

        self.sequential = nn.Sequential(*modules)

        # `features` holds the width of the final layer built above. The loop
        # variable `out_features` cannot be used here: it is undefined when
        # intermediate_layers == 0, which would raise a NameError.
        self.attention_layer = nn.Sequential(
            nn.Linear(features, attention_size),  # (batch_size, seq_length, attention_size)
            nn.PReLU(),
            nn.Linear(attention_size, 1),  # (batch_size, seq_length, 1)
        )

        self.classifier = HierarchicalSoftmaxLazyLinear(root=classification_tree)
        self.model_dtype = next(self.sequential.parameters()).dtype
        self.return_attention = return_attention
    def forward(self, x):
        # Cast the input to the parameter dtype if the two differ.
        if self.model_dtype != x.dtype:
            x = x.to(dtype=self.model_dtype)

        x = self.sequential(x)

        # Attention pooling: score each sequence position, then normalise
        # the scores over the sequence dimension (dim=1).
        attention_scores = self.attention_layer(x)
        attention_weights = torch.softmax(attention_scores, dim=1)

        # Weighted sum over the sequence yields one context vector per batch item.
        context_vector = torch.sum(attention_weights * x, dim=1)

        result = self.classifier(context_vector)

        if self.return_attention:
            return result, attention_scores

        return result
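
Because every linear layer here is lazy, the model only materialises its weights on the first forward pass, so the input embedding width never needs to be declared up front. Below is a minimal usage sketch, not part of this file: the tree-building calls (`parent=` keyword, `set_indexes()`) are assumptions about the hierarchicalsoftmax package's API, and the tree labels and tensor shapes are made up for illustration.

import torch
from hierarchicalsoftmax import SoftmaxNode

# Build a small two-level classification tree (assumed hierarchicalsoftmax API).
root = SoftmaxNode("root")
mammals = SoftmaxNode("mammals", parent=root)
SoftmaxNode("dog", parent=mammals)
SoftmaxNode("cat", parent=mammals)
birds = SoftmaxNode("birds", parent=root)
SoftmaxNode("eagle", parent=birds)
root.set_indexes()  # finalise node indices before building the classifier (assumed API)

model = BloodhoundModel(classification_tree=root, features=64, intermediate_layers=1)

# Lazy layers infer their input width on this first call.
# Input shape: (batch_size, seq_length, embedding_size).
x = torch.randn(2, 10, 32)
result = model(x)  # hierarchical softmax scores, one row per batch item

Setting return_attention=True in the constructor would instead return a (result, attention_scores) pair, exposing the per-position scores before the softmax for inspection.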