#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math

import torch

from torch import nn
from torch.nn import functional as F

##############################

class WithResidual(nn.Module):
    # Adds a residual connection around f, which is either a single module or
    # the composition of the modules passed as arguments
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, x):
        return x + self.f(x)

##############################

class AddPositionalEncoding(nn.Module):
    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
    def forward(self, x):
        # x is (N, T, D); t indexes the positions, j the dimensions
        t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
        j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
        k = j%2
        # The pi/2 phase shift on odd dimensions turns the sine into a cosine
        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)
        return x + pe

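# A minimal verification sketch for the encoding above; the helper name
# check_positional_encoding is illustrative. Even dimensions should carry
# sin(t / L^{j/D}) and odd dimensions the matching cosine, which is what the
# pi/2 phase shift produces.

def check_positional_encoding(len_max = 1e5, T = 16, D = 8):
    pe = AddPositionalEncoding(len_max)(torch.zeros(1, T, D))[0]
    t = torch.arange(T, dtype = torch.float64)[:, None]
    j = torch.arange(0, D, 2, dtype = torch.float64)[None, :]
    assert torch.allclose(pe[:, 0::2].double(), torch.sin(t / len_max ** (j / D)), atol = 1e-5)
    assert torch.allclose(pe[:, 1::2].double(), torch.cos(t / len_max ** (j / D)), atol = 1e-5)
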
##############################

class QKVAttention(nn.Module):
    def __init__(self,
                 dim_in, dim_qk, dim_v,
                 nb_heads = 1, causal = False, attention_dropout = 0.0):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout

        self.w_q = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def forward(self, x_q, x_kv = None):
        if x_kv is None: x_kv = x_q

        # x_q is (N, T, dim_in); q, k and v are (N, nb_heads, T, dim_qk or dim_v)
        q = torch.einsum('ntc,hdc->nhtd', x_q, self.w_q)
        k = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_k)
        v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v)

        # Scaled dot-product attention scores, (N, nb_heads, T, S)
        a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3))

        if self.causal:
            # Forbid attention to future positions (s > t) before the softmax
            mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \
                   < torch.arange(a.size(3), device = q.device)[None, None, None, :]
            a = a.masked_fill(mask, float('-inf'))

        a = a.softmax(dim = 3)
        a = F.dropout(a, self.attention_dropout, self.training)
        # Weighted sum of the values, heads concatenated back to (N, T, dim_v * nb_heads)
        y = torch.einsum('nhts,nhsd->nthd', a, v).flatten(2)

        y = y @ self.w_o

        return y

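# A minimal check sketch for the module above; the helper name
# check_qkv_attention is illustrative. It verifies that the output keeps the
# (N, T, dim_in) shape expected by the residual wrapper, and that with
# causal = True a change at the last position leaves earlier outputs untouched.

def check_qkv_attention():
    attn = QKVAttention(dim_in = 16, dim_qk = 8, dim_v = 8, nb_heads = 2, causal = True)
    attn.eval()
    x = torch.randn(3, 10, 16)
    y = attn(x)
    assert y.shape == (3, 10, 16)
    x_perturbed = x.clone()
    x_perturbed[:, -1, :] += 1.0
    y_perturbed = attn(x_perturbed)
    # Positions before the last one must be unaffected by the perturbation
    assert torch.allclose(y[:, :-1], y_perturbed[:, :-1])
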
##############################

class MyGPT(nn.Module):
    def __init__(self,
                 vocabulary_size,
                 dim_model, dim_keys, dim_hidden,
                 nb_heads, nb_blocks,
                 dropout = 0.0, len_max = 1e5):

        super().__init__()

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            nn.Embedding(vocabulary_size, dim_model),
            nn.Dropout(dropout),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = [ ]

        # Each block is a pre-norm causal self-attention layer followed by a
        # pre-norm MLP, both wrapped in residual connections
        for _ in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    nn.LayerNorm((dim_model,)),
                    QKVAttention(
                        dim_in = dim_model,
                        dim_qk = dim_keys,
                        dim_v = dim_model // nb_heads,
                        nb_heads = nb_heads,
                        causal = True, attention_dropout = dropout
                    ),
                ),
                WithResidual(
                    nn.LayerNorm((dim_model,)),
                    nn.Linear(in_features = dim_model, out_features = dim_hidden),
                    nn.ReLU(),
                    nn.Linear(in_features = dim_hidden, out_features = dim_model),
                    nn.Dropout(dropout),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean = 0, std = 2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

    def forward(self, x):
        # Shift the input one position to the right (pad with token 0, drop the
        # last token), so that the logits at position t predict the token at t
        # from the tokens strictly before it
        x = F.pad(x, (1, -1))
        x = self.embedding(x)
        x = self.trunk(x)
        x = self.readout(x)
        return x

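# A minimal greedy-generation sketch, relying on the shift in forward() above:
# appending a placeholder token and reading the logits at the last position
# gives the prediction for that position from the tokens before it. The helper
# name generate_greedy is illustrative, not part of the original interface.

def generate_greedy(model, x, nb_new_tokens):
    # x is (N, T) of token indices, returned extended to (N, T + nb_new_tokens)
    model.eval()
    with torch.no_grad():
        for _ in range(nb_new_tokens):
            # The appended token's value is irrelevant, forward() drops it via the shift
            x = torch.cat([x, x.new_zeros(x.size(0), 1)], dim = 1)
            logits = model(x)
            x[:, -1] = logits[:, -1].argmax(dim = 1)
    return x
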
######################################################################

if __name__ == '__main__':
    print('Basic check.')

    vocabulary_size = 10
    x = torch.randint(vocabulary_size, (25, 100))

    model = MyGPT(
        vocabulary_size = vocabulary_size,
        dim_model = 18, dim_keys = 50, dim_hidden = 100,
        nb_heads = 2, nb_blocks = 3,
        dropout = 0.1
    )

    # y holds the next-token logits, of shape (25, 100, vocabulary_size)
    y = model(x)

######################################################################
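
# A minimal training-step sketch, given the shift applied in MyGPT.forward:
# the logits at position t are scored directly against the token at position t
# of the same sequence. The helper name train_batch is illustrative, not part
# of the original interface.

def train_batch(model, optimizer, x):
    logits = model(x)                                  # (N, T, vocabulary_size)
    loss = F.cross_entropy(logits.transpose(1, 2), x)  # expects (N, C, T) vs (N, T)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# e.g. with the model and batch of the basic check above:
# optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
# train_batch(model, optimizer, x)

######################################################################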