[mygpt.git] / mygpt.py
#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math

import torch

from torch import nn
from torch.nn import functional as F

##############################

# Computes x + f(x), where f is the single module passed to the constructor,
# or the nn.Sequential of all the modules passed to it
class Residual(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, x):
        return x + self.f(x)

##############################

class PositionalEncoding(nn.Module):
    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # From Vaswani et al. 2017, "Attention Is All You Need"
    # PE_{t,2i}   = sin(t/(L^{2i/D}))
    # PE_{t,2i+1} = cos(t/(L^{2i/D}))
    def forward(self, x):
        # t indexes the positions and j the channels; k = j%2 marks the odd
        # channels, where the pi/2 phase shift turns the sin into a cos
        t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
        j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
        k = j%2
        return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :]

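
# Illustrative addition, not part of the original file: a quick sanity check
# that the single sin() above produces interleaved sin/cos channels, i.e.
# PE_{t,2i} = sin(t / L^{2i/D}) and PE_{t,2i+1} = cos(t / L^{2i/D}), using the
# identity sin(a + pi/2) = cos(a). The function name is ours.
def _check_positional_encoding(len_max = 1e5, T = 5, D = 4):
    pe = PositionalEncoding(len_max)(torch.zeros(1, T, D))[0]
    t = torch.arange(T, dtype = pe.dtype)[:, None]
    i = torch.arange(D // 2, dtype = pe.dtype)[None, :]
    assert torch.allclose(pe[:, 0::2], torch.sin(t / len_max ** (2 * i / D)), atol = 1e-6)
    assert torch.allclose(pe[:, 1::2], torch.cos(t / len_max ** (2 * i / D)), atol = 1e-6)
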
##############################

class QKVAttention(nn.Module):
    def __init__(self, dim_in, dim_qk, dim_v, nb_heads = 1, causal = False, attention_dropout = 0.0):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1])))

        self.w_q = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.causal = causal
        self.attention_dropout = attention_dropout

    def forward(self, x):
        # x is (N, T, dim_in); q and k are (N, nb_heads, T, dim_qk), v is (N, nb_heads, T, dim_v)
        q = torch.einsum('ntc,hdc->nhtd', x, self.w_q)
        k = torch.einsum('ntc,hdc->nhtd', x, self.w_k)
        v = torch.einsum('ntc,hdc->nhtd', x, self.w_v)
        # Scaled dot-product attention scores, (N, nb_heads, T, T)
        r = math.sqrt(q.size(3))
        a = torch.einsum('nhtd,nhsd->nhts', q, k).div(r)
        if self.causal:
            # Mask out the scores at s > t so that position t only attends to positions s <= t
            mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0
            a = a.masked_fill(mask, float('-inf'))
        a = a.softmax(dim = 3)
        a = F.dropout(a, self.attention_dropout, self.training)
        y = torch.einsum('nhts,nhsd->nhtd', a, v)
        return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd)

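
# Illustrative addition, not part of the original file: a quick check of the
# shapes above. The heads are flattened back into the channel dimension, so
# the output is (N, T, nb_heads * dim_v); this is why MyGPT below passes
# dim_v = dim_model // nb_heads to keep the residual dimensions consistent.
# The function name and the default sizes are ours.
def _check_qkv_attention_shapes(N = 2, T = 7, dim_in = 16, dim_qk = 5, dim_v = 8, nb_heads = 3):
    att = QKVAttention(dim_in, dim_qk, dim_v, nb_heads = nb_heads, causal = True)
    y = att(torch.randn(N, T, dim_in))
    assert y.shape == (N, T, nb_heads * dim_v)
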
##############################

class MyGPT(nn.Module):
    def __init__(self,
                 vocabulary_size,
                 dim_model, dim_keys, dim_hidden,
                 nb_heads, nb_blocks, dropout = 0.):

        super().__init__()

        assert dim_model % nb_heads == 0

        # Token embedding followed by the additive positional encoding
        self.embedding = nn.Sequential(
            nn.Embedding(vocabulary_size, dim_model),
            nn.Dropout(dropout),
            PositionalEncoding(len_max = 1e5),
        )

        trunk_blocks = [ ]

        for _ in range(nb_blocks):
            trunk_blocks += [
                # Pre-norm causal multi-head self-attention block
                Residual(
                    nn.LayerNorm(dim_model),
                    QKVAttention(
                        dim_in = dim_model,
                        dim_qk = dim_keys, dim_v = dim_model // nb_heads,
                        nb_heads = nb_heads,
                        causal = True, attention_dropout = dropout
                    ),
                    nn.Linear(in_features = dim_model, out_features = dim_model),
                ),
                # Pre-norm feed-forward block
                Residual(
                    nn.LayerNorm(dim_model),
                    nn.Linear(in_features = dim_model, out_features = dim_hidden),
                    nn.ReLU(),
                    nn.Linear(in_features = dim_hidden, out_features = dim_model),
                    nn.Dropout(dropout),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)

    def forward(self, x):
        # x is a (N, T) batch of token indices, the result is the (N, T, vocabulary_size) logits
        x = self.embedding(x)
        x = self.trunk(x)
        x = self.readout(x)
        return x

######################################################################

if __name__ == '__main__':
    vocabulary_size = 10
    # A batch of 25 sequences of 100 token indices
    x = torch.randint(vocabulary_size, (25, 100))

    model = MyGPT(
        vocabulary_size = vocabulary_size,
        dim_model = 16, dim_keys = 50, dim_hidden = 100,
        nb_heads = 2, nb_blocks = 3,
        dropout = 0.1
    )

    # y is the (25, 100, vocabulary_size) tensor of logits
    y = model(x)

######################################################################
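
# Illustrative addition, not part of the original file: a minimal sketch of
# greedy autoregressive sampling with MyGPT. Since the attention is causal,
# the logits at position t depend only on tokens up to t, so the last
# position predicts the next token. The function name is ours.

def generate_greedy(model, x, nb_new_tokens):
    # x is a (N, T) batch of token indices used as the prompt
    model.eval()
    with torch.no_grad():
        for _ in range(nb_new_tokens):
            logits = model(x)                           # (N, T, vocabulary_size)
            next_token = logits[:, -1].argmax(dim = -1) # most likely next token
            x = torch.cat([x, next_token[:, None]], dim = 1)
    return x

######################################################################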