Update.
[mygptrnn.git] / fridge
1
2 ######################################################################
3
4 2024 Jan 07 21:37:48 (from mygpt.py)
5
6
7 # This is one order of magnitude more complicated than I expected, not
8 # elegant, slow, hopefully not buggy
9
10
11 def flash_back_time_src(N, H, t0, t1, CL, CH, proba, device):
12     # starting flash backs
13     fb_start = (torch.rand(N, CH, t1 - t0, device=device) <= proba).long()
14     fb_start[:, :, -CL:] = 0
15     fb_start[:, :, :CL] = 0
16
17     # Remove series longer than CL
18     fb_body = fb_start.clone()
19     fb_body[:, :, CL + 1 :] -= fb_start[:, :, : -(CL + 1)]
20     fb_body = fb_body.cumsum(dim=2)
21     fb_start = fb_start * (fb_body == 1)
22
23     # Set a origin source time (starting time of the chunck to copy
24     # here) We set it as the current time minus a multiple of CL to be
25     # consistent with the "rolling" caterpillar
26     t = torch.arange(fb_start.size(2), device=fb_start.device)[None, None, :]
27     src_time = fb_start * (
28         t
29         - CL
30         * (
31             1
32             + (
33                 torch.rand(fb_start.size(), device=fb_start.device) * (t // CL - 1)
34             ).long()
35         )
36     )
37     src_time[:, :, CL:] -= src_time.clone()[:, :, :-CL]
38     src_time = src_time.cumsum(dim=2)
39
40     src_head = fb_start * torch.randint(H, fb_start.size(), device=fb_start.device)
41     src_head[:, :, CL:] -= src_head.clone()[:, :, :-CL]
42     src_head = src_head.cumsum(dim=2)
43
44     # combine
45     src_delta = fb_start.clone()
46     src_delta[:, :, CL:] -= fb_start[:, :, :-CL]
47     src_delta = src_delta.cumsum(dim=2)
48     src_delta[:, :, CL:] -= CL * fb_start[:, :, :-CL]
49     src_time += src_delta.cumsum(dim=2) - 1
50
51     return src_time, src_head
52
53
54 def insert_flash_back(rec_V, V, rec_K, K, t0, t1, CL, proba):
55     N, H, CH = V.size(0), V.size(1), rec_V.size(1)
56
57     fbt, fbh = flash_back_time_src(N, H, t0, t1, CL, CH, proba, rec_V.device)
58
59     fbt_V = fbt[:, :, :, None]
60     fbh_V = fbh[:, :, :, None]
61     t = fbt_V.clamp(min=0)
62     n = torch.arange(V.size(0), device=V.device)[:, None, None, None]
63     d = torch.arange(V.size(3), device=V.device)[None, None, None, :]
64     q = V[:, :, t0:t1][n, fbh_V, t, d]
65     rec_V[:, :, t0:t1] = q * (fbt_V >= 0) + rec_V[:, :, t0:t1] * (fbt_V < 0)
66
67     fbt_K = fbt[:, :, :, None]
68     fbh_K = fbh[:, :, :, None]
69     t = fbt_K.clamp(min=0)
70     n = torch.arange(K.size(0), device=K.device)[:, None, None, None]
71     d = torch.arange(K.size(3), device=K.device)[None, None, None, :]
72     q = K[:, :, t0:t1][n, fbh_K, t, d]
73     rec_K[:, :, t0:t1] = q * (fbt_K >= 0) + rec_K[:, :, t0:t1] * (fbt_K < 0)
74
75
76 ######################################################################
77
78 ######################################################################
79
80 2024 Jan 07 21:38:11 (from mygpt.py)
81
82             # insert_flash_back(self.rec_V,V,self.rec_K,K,t0,t1,CL,proba=self.proba_flashback / CL,)
83
84
85 ######################################################################
86
87 2024 Jan 09 14:24:42 (from mygpt.py)
88
89             # This piece of code makes the assumption that there is
90             # nothing informative before t0, otherwise we'd have to
91             # implement a cache for V and K too. This should not be
92             # too much of a problem since this is used only during
93             # train, where full sequence are available
94
95             # n = torch.arange(N, device=X.device)[:, None, None, None]
96             # t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
97             # dv = torch.arange(DV, device=X.device)[None, None, None, :]
98             # dk = torch.arange(DK, device=X.device)[None, None, None, :]
99
100             # u = (
101                 # torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t).long() // CL
102             # ) * CL
103
104             # src_time = t - u - t0
105             # src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device)
106
107             # mask = (
108                 # torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback
109             # ).long()
110
111             # self.rec_V[:, :, t0:t1] = (
112                 # mask * V[n, src_head, src_time, dv]
113                 # + (1 - mask) * self.rec_V[:, :, t0:t1]
114             # )
115
116             # self.rec_K[:, :, t0:t1] = (
117                 # mask * K[n, src_head, src_time, dk]
118                 # + (1 - mask) * self.rec_K[:, :, t0:t1]
119             # )
120
121 ######################################################################
122
123 2024 Jan 10 08:10:39 (from mygpt.py)
124
125         # That was a bad idea
126         # G = F.dropout(G, self.attention_dropout, self.training)
127
128
129 ######################################################################
130
131 2024 Jan 10 08:46:13 (from mygpt.py)
132
133         #################################################################
134         # Flashbacks. This version sucks, about to replace it 
135         if self.training and self.proba_flashback > 0.0:
136             warnings.warn("flash back", RuntimeWarning)
137             # This piece of code makes the assumption that there is
138             # nothing informative before t0, otherwise we'd have to
139             # implement a cache for V and K too. This should not be
140             # too much of a problem since this is used only during
141             # train, where full sequence are available
142
143             n = torch.arange(N, device=X.device)[:, None, None, None]
144             t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
145             dv = torch.arange(DV, device=X.device)[None, None, None, :]
146             dk = torch.arange(DK, device=X.device)[None, None, None, :]
147
148             u = (
149                 torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t).long() // CL
150             ) * CL
151
152             src_time = t - u - t0
153             src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device)
154
155             mask = (
156                 torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback
157             ).long()
158
159             self.rec_V[:, :, t0:t1] = (
160                 mask * V[n, src_head, src_time, dv]
161                 + (1 - mask) * self.rec_V[:, :, t0:t1]
162             )
163
164             self.rec_K[:, :, t0:t1] = (
165                 mask * K[n, src_head, src_time, dk]
166                 + (1 - mask) * self.rec_K[:, :, t0:t1]
167             )
168
169
170 ######################################################################
171
172 2024 Jan 13 13:38:31 (from mygpt.py)
173
174         g= F.sigmoid(self.b_G)
175         a=1-g
176
177         print(f"\n\nSANITY {a**T}\n")
178         exit(0)
179
180
181 ######################################################################
182
183 2024 Jan 14 13:39:37 (from mygpt.py)
184
185             epsilon = 0.5
186
187             dropout_head = (
188                 (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0)
189                 .expand_as(G)
190                 .float()
191             )
192
193             dropout_tail = dropout_head.cumsum(dim=3) - dropout_head
194
195             dropout_active = (
196                 torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
197             ).long()
198
199             dropout_head *= dropout_active
200             dropout_tail *= dropout_active
201
202             G = (
203                 G
204                 + dropout_head * (1 - epsilon - G.detach())
205                 - dropout_tail * G.detach()
206             )
207
208 ######################################################################
209
210 2024 Jan 18 07:39:29 (from mygpt.py)
211
212 class Calibrator:
213     def __init__(self, w=None, b=None):
214         self.w = w
215         self.b = b
216         self.s, self.s_sq, self.n = 0, 0, 0
217         self.mean, self.std = 0, 0
218
219     def update(self, X):
220         X = X.detach()
221         self.s += X.sum(dim=0)
222         self.s_sq += X.pow(2).sum(dim=0)
223         self.n += X.size(0)
224
225     def moments(self):
226         mean = self.s / self.n
227         std = (self.s_sq / self.n - mean * mean).sqrt()
228         return mean, std
229
230     def normalize(self):
231         mean, std = self.moments()
232         if self.b is not None:
233             self.b.sub_(mean)
234         if self.w is not None:
235             self.w.div_(std)
236         result = mean - self.mean, std - self.std
237         self.mean, self.std = mean, std
238         self.s, self.s_sq, self.n = 0, 0, 0
239         return result
240
241
242
243 ######################################################################
244
245 2024 Jan 18 07:39:34 (from mygpt.py)
246
247         # self.calibrator_G = Calibrator()
248         # self.calibrator_rec_V = Calibrator()
249         # self.calibrator_rec_K = Calibrator()
250
251
252 ######################################################################
253
254 2024 Jan 18 07:39:37 (from mygpt.py)
255
256         # self.calibrator_G.update(G.reshape(-1, G.size(-1)))
257
258
259 ######################################################################
260
261 2024 Jan 18 07:39:42 (from mygpt.py)
262
263         # self.calibrator_rec_V.update(
264         # next_V.permute(0, 1, 3, 2).reshape(-1, next_V.size(2))
265         # )
266         # self.calibrator_rec_K.update(
267         # next_K.permute(0, 1, 3, 2).reshape(-1, next_K.size(2))
268         # )
269
270
271 ######################################################################
272
273 2024 Jan 18 07:47:12 (from mygpt.py)
274
275         ######################################################################
276         # Roll the gating indexes
277
278         # warnings.warn("rotating barrel", RuntimeWarning)
279
280         # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
281         # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
282         # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
283         # G = G.gather(dim=2, index=r_barrel.expand_as(G))
284
285
286 ######################################################################
287
288 2024 Jan 18 07:47:25 (from mygpt.py)
289
290         # warnings.warn("harmonic recurrence", RuntimeWarning)
291         # har = torch.arange(t0, t1, device = G.device).float() + 1
292         # A = har / (har + 1)
293         # G = G / har
294
295
296 ######################################################################
297
298 2024 Jan 18 08:46:18 (from mygpt.py)
299
300         # warnings.warn("softmax gating", RuntimeWarning)
301
302         # G = (
303         # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
304         # ).softmax(dim=2)
305
306 ######################################################################
307
308 2024 Jan 21 16:55:24 (from main.py)
309
310         with open("test.dat", "a") as f:
311             for m filter(lambda m: isinstance(m,mygpt.Catenn.Linear),model.modules()):
312                 for p in m.parameters() ]
313
314
315         for m in model.modules():
316             if isinstance(m, mygpt.Caterpillar):
317                 
318
319
320 ######################################################################
321
322 2024 Feb 13 22:53:52 (from mygpt.py)
323
324         ######################################################################
325         # Prepare the keys
326
327         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
328
329         warnings.warn("rotating key barrel", RuntimeWarning)
330         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
331         t_barrel = torch.arange(t0, t1, device=k_star.device)
332         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
333         l_barrel = (
334             torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
335         ) % k_star.size(0)
336         k_star = k_star[l_barrel, t_barrel]
337
338
339 ######################################################################
340
341 2024 Feb 15 23:10:50 (from main.py)
342
343
344 def add_memex_v4(batches, memex_proba, marker_token):
345     for input in batches:
346         if torch.rand(1).item() < memex_proba:
347             t = (
348                 torch.arange(2 * input.size(1), device=input.device)[None, :]
349                 .expand(input.size(0), -1)
350                 .clone()
351             )
352
353             u = torch.rand(t.size(), device=t.device)
354             u[:, : input.size(1)] = 1.0
355             memex_v3_proba_fragment = 1 / 20
356             u = (u < memex_v3_proba_fragment).long()
357             v = u * torch.randint(input.size(1), u.size())
358             u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[
359                 :, : input.size(1) - 1
360             ] * input.size(1)
361             u = u.cumsum().clamp(min=0)
362
363             u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
364             caterpillar_length = args.nb_lines // args.caterpillar_height
365             u1 = (
366                 u0
367                 + torch.randint(
368                     caterpillar_length, (input.size(0), 1), device=input.device
369                 )
370                 + 1
371             )
372
373             m0 = (t < u0).long()
374             m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
375
376             t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
377             m = (t < 0).long()
378             n = torch.arange(input.size(0), device=input.device)[:, None].expand(
379                 -1, t.size(1)
380             )
381
382             new_input = input[n, t.clamp(min=0)]
383             new_input = (1 - m) * new_input + m * (marker_token)
384
385             yield new_input
386
387         yield input
388
389
390
391 ######################################################################
392
393 2024 Feb 16 17:07:48 (from main.py)
394
395                 # ||gn + lambda * gm|| = max(||gn||,||gm||)
396                 # ||gn||^2 + lambda<gn,gm> + lambda^2||gm||^2 = max(||gn||^2,||gm||^2)
397                 # A = ||gm||^2 B = <gn,gm> C = ||gn||^2 - max(||gn||^2, ||gm||^2)
398
399 ######################################################################
400
401 2024 Feb 16 17:07:51 (from main.py)
402
403                 # A,B,C = gmgm, gngm, gngn - max(gngn,gmgm)
404                 # Delta = B*B - 4*A*C
405                 # if(delta >= 0):
406                     # l = ( -B - sqrt(Delta))/(2*A)
407                 # ||gn||+l*rho*||gm|| = max(||gn||,rho*||gm||)