9da2e685478f34b06d773eb265f9b2f155db1fe0
[mygpt.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math
9
10 import torch
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ##############################
16
17 class WithResidual(nn.Module):
18     def __init__(self, *f):
19         super().__init__()
20         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
21
22     def forward(self, x):
23         return x + self.f(x)
24
25 ##############################
26
27 class PositionalEncoding(nn.Module):
28     def __init__(self, len_max):
29         super().__init__()
30         self.len_max = len_max
31
32     # From Vaswani et al 2018
33     # PE_{t,2i}   = sin(t/(L^{2i/D}))
34     # PE_{t,2i+1} = cos(t/(L^{2i/D}))
35     def forward(self, x):
36         t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
37         j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
38         k = j%2
39         pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)
40         return x + pe
41
42 ##############################
43
44 class QKVAttention(nn.Module):
45     def __init__(self,
46                  dim_in, dim_qk, dim_v,
47                  nb_heads = 1, causal = False, attention_dropout = 0.0):
48         super().__init__()
49
50         def randw(*d):
51             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
52
53         self.causal = causal
54         self.attention_dropout = attention_dropout
55
56         self.w_q = randw(nb_heads, dim_qk, dim_in)
57         self.w_k = randw(nb_heads, dim_qk, dim_in)
58         self.w_v = randw(nb_heads, dim_v, dim_in)
59         self.w_o = randw(dim_v * nb_heads, dim_in)
60
61     def forward(self, x_q, x_kv = None):
62         if x_kv is None: x_kv = x_q
63
64         q = torch.einsum('ntc,hdc->nhtd', x_q, self.w_q)
65         k = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_k)
66         v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v)
67
68         a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3))
69
70         if self.causal:
71             mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \
72                    < torch.arange(a.size(3), device = q.device)[None, None, None, :]
73             a = a.masked_fill(mask, float('-inf'))
74
75         a = a.softmax(dim = 3)
76         a = F.dropout(a, self.attention_dropout, self.training)
77         y = torch.einsum('nhts,nhsd->nthd', a, v).flatten(2)
78
79         y = y @ self.w_o
80
81         return y
82
83 ##############################
84
85 class MyGPT(nn.Module):
86     def __init__(self,
87                  vocabulary_size,
88                  dim_model, dim_keys, dim_hidden,
89                  nb_heads, nb_blocks,
90                  dropout = 0.0, len_max = 1e5):
91
92         super().__init__()
93
94         assert dim_model % nb_heads == 0
95
96         self.embedding = nn.Sequential(
97             nn.Embedding(vocabulary_size, dim_model),
98             nn.Dropout(dropout),
99             PositionalEncoding(len_max),
100         )
101
102         trunk_blocks = [ ]
103
104         for _ in range(nb_blocks):
105             trunk_blocks += [
106                 WithResidual(
107                     nn.LayerNorm((dim_model,)),
108                     QKVAttention(
109                         dim_in = dim_model,
110                         dim_qk = dim_keys,
111                         dim_v = dim_model // nb_heads,
112                         nb_heads = nb_heads,
113                         causal = True, attention_dropout = dropout
114                     ),
115                 ),
116                 WithResidual(
117                     nn.LayerNorm((dim_model,)),
118                     nn.Linear(in_features = dim_model, out_features = dim_hidden),
119                     nn.ReLU(),
120                     nn.Linear(in_features = dim_hidden, out_features = dim_model),
121                     nn.Dropout(dropout),
122                 ),
123             ]
124
125         self.trunk = nn.Sequential(*trunk_blocks)
126
127         self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
128
129     def forward(self, x):
130         x = F.pad(x, (1, 0))
131         x = self.embedding(x)
132         x = self.trunk(x)
133         x = self.readout(x)
134         x = F.pad(x, (0, 0, 0, -1))
135         return x
136
137 ######################################################################
138
139 if __name__ == '__main__':
140     print('Basic check.')
141
142     vocabulary_size = 10
143     x = torch.randint(vocabulary_size, (25, 100))
144
145     model = MyGPT(
146         vocabulary_size = vocabulary_size,
147         dim_model = 18, dim_keys = 50, dim_hidden = 100,
148         nb_heads = 2, nb_blocks = 3,
149         dropout = 0.1
150     )
151
152     y = model(x)
153
154 ######################################################################