Update.
[mygptrnn.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, os, datetime, warnings
9
10 import torch, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 # torch.autograd.set_detect_anomaly(True) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
15
16 import ffutils
17 import mygpt, tasks, problems
18
19 ######################################################################
20
21
def str2bool(x):
    """Convert a command-line string into a boolean.

    "1", "true" and "yes" map to True, "0", "false" and "no" map to
    False, case-insensitively. Any other string raises ValueError, which
    argparse reports as an invalid argument value.
    """
    known_values = {
        "1": True,
        "true": True,
        "yes": True,
        "0": False,
        "false": False,
        "no": False,
    }

    try:
        return known_values[x.lower()]
    except KeyError:
        raise ValueError from None
30
31
# Command-line interface. Several options default to None on purpose:
# they are filled in further down from default_task_args (per task) and
# default_model_args (per model) when the user did not set them.
parser = argparse.ArgumentParser(
    description="An implementation of GPT with cache.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
    "--task",
    type=str,
    default="twotargets",
    help="byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
)

parser.add_argument("--log_filename", type=str, default="train.log", help=" ")

# None means "results_{task}_{model}", computed right after parsing.
parser.add_argument("--result_dir", type=str, default=None)

# A negative seed skips the seeding of the torch random generators.
parser.add_argument("--seed", type=int, default=0)

# A negative value disables the train/test overlap check.
parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)

parser.add_argument("--force_cpu", type=str2bool, default=False)

########################################
# Training options

parser.add_argument("--nb_epochs", type=int, default=25)

parser.add_argument("--physical_batch_size", type=int, default=None)

parser.add_argument("--batch_size", type=int, default=25)

parser.add_argument("--nb_train_samples", type=int, default=None)

parser.add_argument("--nb_test_samples", type=int, default=None)

# One of "sgd", "adam", "adamw".
parser.add_argument("--optim", type=str, default="adam")

########################################
# Learning rate schedule, see get_lr()

parser.add_argument("--nb_warmup_iter", type=int, default=100)

parser.add_argument("--nb_decay_iter", type=int, default=5000)

parser.add_argument("--learning_rate", type=float, default=6e-4)

parser.add_argument("--min_learning_rate", type=float, default=6e-5)

# legacy: when legacy_lr_schedule is true (the default), get_lr() uses
# the legacy_* values below instead of the cosine schedule.

parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True)

parser.add_argument("--legacy_large_lr", type=float, default=1e-4)

parser.add_argument("--legacy_small_lr", type=float, default=2e-5)

parser.add_argument("--legacy_nb_epoch_large_lr", type=float, default=10)

########################################
# Model options; the None defaults are filled from default_model_args

parser.add_argument("--model", type=str, default=None)

parser.add_argument("--attention", type=str, default=None)

# Probability, per training batch, of also emitting a memex sequence.
parser.add_argument("--memex_proba", type=float, default=0)

# When set, memex sequences are only generated for epochs below this value.
parser.add_argument("--memex_nb_epochs", type=float, default=None)

parser.add_argument("--dim_model", type=int, default=None)

parser.add_argument("--dim_keys", type=int, default=None)

parser.add_argument("--dim_hidden", type=int, default=None)

parser.add_argument("--nb_heads", type=int, default=None)

parser.add_argument("--nb_lines", type=int, default=None)

parser.add_argument("--caterpillar_height", type=int, default=None)

parser.add_argument("--gate_dropout_proba", type=float, default=0.0)

parser.add_argument("--gate_dropout_sync", type=str2bool, default=False)

parser.add_argument("--gate_dropout_replace", type=str2bool, default=False)

# Weight of the model's inner loss in the total training loss.
parser.add_argument("--rho_inner_loss", type=float, default=0.0)

parser.add_argument("--nb_blocks", type=int, default=None)

parser.add_argument("--dropout", type=float, default=0.1)

########################################
# Synthesis and checkpointing

parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

parser.add_argument("--no_checkpoint", action="store_true", default=False)

parser.add_argument("--continue_training", action="store_true", default=False)

parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")

##############################
# rpl options

parser.add_argument("--rpl_nb_starting_values", type=int, default=3)

parser.add_argument("--rpl_max_input", type=int, default=9)

parser.add_argument("--rpl_prog_len", type=int, default=8)

parser.add_argument("--rpl_nb_runs", type=int, default=5)

parser.add_argument("--rpl_no_prog", action="store_true", default=False)

##############################
# grid options

parser.add_argument("--grid_size", type=int, default=6)

parser.add_argument("--grid_nb_colors", type=int, default=6)

parser.add_argument("--grid_nb_shapes", type=int, default=6)

##############################
# picoclvr options

parser.add_argument("--picoclvr_nb_colors", type=int, default=5)

parser.add_argument("--picoclvr_height", type=int, default=12)

parser.add_argument("--picoclvr_width", type=int, default=16)

# NOTE: this option is spelled "picocvlr" (not "picoclvr"); the code
# reading it uses the same spelling. One of "none", "train+eval", "eval".
parser.add_argument("--picocvlr_prune_properties", type=str, default="none")

##############################
# Maze options

parser.add_argument("--maze_height", type=int, default=13)

parser.add_argument("--maze_width", type=int, default=21)

parser.add_argument("--maze_nb_walls", type=int, default=15)

##############################
# Snake options

parser.add_argument("--snake_height", type=int, default=9)

parser.add_argument("--snake_width", type=int, default=12)

parser.add_argument("--snake_nb_colors", type=int, default=5)

parser.add_argument("--snake_length", type=int, default=200)

##############################
# Stack options

parser.add_argument("--stack_nb_steps", type=int, default=100)

parser.add_argument("--stack_nb_stacks", type=int, default=3)

parser.add_argument("--stack_nb_digits", type=int, default=3)

parser.add_argument("--stack_fraction_values_for_train", type=float, default=0.75)

##############################
# Expr options

parser.add_argument("--expr_nb_variables", type=int, default=5)

parser.add_argument("--expr_sequence_length", type=int, default=40)

parser.add_argument("--expr_operand_max", type=int, default=9)

parser.add_argument("--expr_result_max", type=int, default=99)

# When set with --task expr, results are produced from this file and the
# script exits without training.
parser.add_argument("--expr_input_file", type=str, default=None)

##############################
# Memory

parser.add_argument("--memory_len_total", type=int, default=32)

##############################
# Mixing

parser.add_argument("--mixing_hard", action="store_true", default=False)

parser.add_argument("--mixing_deterministic_start", action="store_true", default=False)
220
221 ######################################################################
222
# parse_known_args() is used instead of parse_args() so that options the
# parser does not know are kept as supplementary arguments rather than
# rejected.

args, sup_args = parser.parse_known_args()

# Turn the supplementary "--key=value" strings into a dictionary.
# partition() is used rather than split() so that a bare "--flag" maps to
# an empty string and a value may itself contain "="; both cases used to
# raise a ValueError when building the dictionary.
sup_args = {
    key: value
    for key, _, value in (x.removeprefix("--").partition("=") for x in sup_args)
}

# NOTE(review): args.model can still be None here when --model was not
# given, since the per-task defaults are only applied further down.
if args.result_dir is None:
    args.result_dir = f"results_{args.task}_{args.model}"
231
232 ######################################################################
233
# Run on the GPU when there is one and the user did not force the CPU.
use_cuda = (not args.force_cpu) and torch.cuda.is_available()

device = torch.device("cuda" if use_cuda else "cpu")

if use_cuda:
    # Allow TF32 for CUDA matrix multiplications.
    torch.backends.cuda.matmul.allow_tf32 = True
239
240 ######################################################################
241
# Per-task default values for the options whose command-line default is
# None: model preset name, physical batch size and data set sizes.
default_task_args = {
    "addition": {
        "model": "352M",
        "physical_batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "byheart": {
        "model": "37M",
        "physical_batch_size": 25,
        "nb_train_samples": 50000,
        "nb_test_samples": 10000,
    },
    "expr": {
        "model": "352M",
        "physical_batch_size": 25,
        "nb_train_samples": 2500000,
        "nb_test_samples": 10000,
    },
    "grid": {
        "model": "37M",
        "physical_batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "qmlp": {
        "model": "37M",
        "physical_batch_size": 10,
        "nb_train_samples": 100000,
        "nb_test_samples": 1000,
    },
    "guessop": {
        "model": "352M",
        "physical_batch_size": 25,
        "nb_train_samples": 1000000,
        "nb_test_samples": 10000,
    },
    "learnop": {
        "model": "37M",
        "physical_batch_size": 25,
        "nb_train_samples": 50000,
        "nb_test_samples": 10000,
    },
    "maze": {
        "model": "37M",
        "physical_batch_size": 5,
        "nb_train_samples": 100000,
        "nb_test_samples": 10000,
    },
    "picoclvr": {
        "model": "37M",
        "physical_batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "rpl": {
        "model": "352M",
        "physical_batch_size": 5,
        "nb_train_samples": 2500000,
        "nb_test_samples": 10000,
    },
    "snake": {
        "model": "37M",
        "physical_batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "stack": {
        "model": "37M",
        "physical_batch_size": 25,
        "nb_train_samples": 100000,
        "nb_test_samples": 1000,
    },
    "twotargets": {
        "model": "37M",
        "physical_batch_size": 25,
        "nb_train_samples": 50000,
        "nb_test_samples": 10000,
    },
    "memory": {
        "model": "37M",
        "physical_batch_size": 25,
        "nb_train_samples": 25000,
        "nb_test_samples": 10000,
    },
    "mixing": {
        "model": "37M",
        "physical_batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "mnist": {
        "model": "37M",
        "physical_batch_size": 5,
        "nb_train_samples": 60000,
        "nb_test_samples": 10000,
    },
}

# Apply the defaults of the selected task, without overriding any value
# given on the command line. A task with no entry here is silently left
# as is; unknown tasks are rejected later, when the task is built.
if args.task in default_task_args:
    for k, v in default_task_args[args.task].items():
        if getattr(args, k) is None:
            setattr(args, k, v)
345
346 ######################################################################
347
# Model presets, keyed by the value of --model. The "-C" variants use the
# "caterpillar" attention layer and additionally set nb_lines.
# NOTE(review): "122M-C" and "352M-C" do not set caterpillar_height, so it
# stays None for them unless given on the command line -- confirm this is
# intended.
default_model_args = {
    "17K": {
        "attention": "mha",
        "dim_model": 32,
        "dim_keys": 32,
        "dim_hidden": 32,
        "nb_heads": 2,
        "nb_blocks": 2,
    },
    "17K-C": {
        "attention": "caterpillar",
        "dim_model": 32,
        "dim_keys": 32,
        "dim_hidden": 32,
        "nb_heads": 2,
        "nb_lines": 16,
        "caterpillar_height": 4,
        "nb_blocks": 2,
    },
    "4M": {
        "attention": "mha",
        "dim_model": 256,
        "dim_keys": 32,
        "dim_hidden": 1024,
        "nb_heads": 4,
        "nb_blocks": 6,
    },
    "4M-C": {
        "attention": "caterpillar",
        "dim_model": 256,
        "dim_keys": 32,
        "dim_hidden": 1024,
        "nb_heads": 4,
        "nb_lines": 32,
        "caterpillar_height": 4,
        "nb_blocks": 6,
    },
    "37M": {
        "attention": "mha",
        "dim_model": 512,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 12,
    },
    "37M-C": {
        "attention": "caterpillar",
        "dim_model": 512,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_lines": 256,
        "caterpillar_height": 32,
        "nb_blocks": 12,
    },
    "122M": {
        "attention": "mha",
        "dim_model": 768,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 24,
    },
    "122M-C": {
        "attention": "caterpillar",
        "dim_model": 768,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_lines": 128,
        "nb_blocks": 24,
    },
    "352M": {
        "attention": "mha",
        "dim_model": 1024,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 48,
    },
    "352M-C": {
        "attention": "caterpillar",
        "dim_model": 1024,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_lines": 128,
        "nb_blocks": 48,
    },
}

# Apply the preset of the selected model, without overriding any value
# given on the command line. Contrary to the task defaults, a model name
# with no preset (including None) is an error.
if args.model in default_model_args:
    for k, v in default_model_args[args.model].items():
        if getattr(args, k) is None:
            setattr(args, k, v)
else:
    raise ValueError(f"Unknown model {args.model}")
445
446 ######################################################################
447
# Create the result directory. An already existing one is only accepted
# when the user asked to continue a previous training.
try:
    os.mkdir(args.result_dir)
except FileExistsError:
    if args.continue_training:
        pass
    else:
        print(f"result directory {args.result_dir} already exists")
        exit(1)

# Both files are opened in append mode, so a continued run extends them.
loss_file, log_file = (
    open(os.path.join(args.result_dir, filename), "a")
    for filename in ("loss.dat", args.log_filename)
)

# A negative seed leaves the random generators unseeded. Only the seeds
# are set here: the cudnn / deterministic-algorithm switches are not
# enabled.
if args.seed >= 0:
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
466
467 ######################################################################
468
469
def log_string(s):
    """Log the string s, prefixed with the local date and time.

    The line is appended to the module-level log_file, when there is
    one, and printed on the standard output. Both are flushed right away.
    """
    line = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) + s

    if log_file is not None:
        log_file.write(line + "\n")
        log_file.flush()

    print(line, flush=True)
479
480
# Log the checksum of every Python file of the current directory.
with os.popen("sha256sum *.py") as checksums:
    for checksum_line in checksums:
        log_string(f"sha256sum {checksum_line.strip()}")

# Archive the sources in the result directory, under a time-stamped name.
now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")

# Log the command line and the value of every argument.
log_string(f"argv {' '.join(sys.argv)}")

for arg_name, arg_value in vars(args).items():
    log_string(f"args.{arg_name} {arg_value}")

for sup_name, sup_value in sup_args.items():
    log_string(f'sup_args["{sup_name}"] "{sup_value}"')
495
496
497 ######################################################################
498
499
def get_lr(n_epoch, it):
    """Return the learning rate for epoch n_epoch and iteration it.

    The schedule is configured through the module-level args.

    With args.legacy_lr_schedule, the rate ramps up linearly to
    args.legacy_large_lr during the first args.nb_warmup_iter iterations,
    stays there while n_epoch < args.legacy_nb_epoch_large_lr, and is
    args.legacy_small_lr afterwards.

    Otherwise (schedule from nanoGPT), the rate ramps up linearly to
    args.learning_rate during the warmup, follows a cosine decay down to
    args.min_learning_rate until args.nb_decay_iter, and stays at
    args.min_learning_rate from then on.
    """
    if args.legacy_lr_schedule:
        # my crude scheduling to compare to previous baseline, added
        # warmup though

        if it < args.nb_warmup_iter:
            return args.legacy_large_lr * it / args.nb_warmup_iter
        elif n_epoch < args.legacy_nb_epoch_large_lr:
            return args.legacy_large_lr
        else:
            return args.legacy_small_lr

    # from nanoGPT

    # 1) linear warmup for warmup_iter steps
    if it < args.nb_warmup_iter:
        return args.learning_rate * it / args.nb_warmup_iter
    # 2) if it >= nb_decay_iter, return min learning rate. The cosine
    # formula also gives min_learning_rate at it == nb_decay_iter, so
    # using ">=" changes no value; it only avoids the division by zero
    # below when nb_decay_iter == nb_warmup_iter.
    if it >= args.nb_decay_iter:
        return args.min_learning_rate
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - args.nb_warmup_iter) / (
        args.nb_decay_iter - args.nb_warmup_iter
    )
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return args.min_learning_rate + coeff * (
        args.learning_rate - args.min_learning_rate
    )
528
529
530 ######################################################################
531
532
def add_memex_v2(batches, memex_proba, marker_token):
    """Yield the batches, some of them preceded by a "memex" sequence.

    Every batch of batches is yielded unchanged. With probability
    memex_proba, it is preceded by a tensor with the same number of rows
    and 1 + 2 * input.size(1) columns where each row is made of a prefix
    of random length of the original row, followed by marker_token
    values, followed by a copy of the original row from its beginning.
    The positions past the copy, if any, also hold marker_token, and the
    copy is cut at the end of the sequence.

    The number of markers between the prefix and the copy is drawn
    between 1 and args.nb_lines // args.caterpillar_height, read from the
    module-level args.
    """
    for input in batches:
        if torch.rand(1).item() < memex_proba:
            dev = input.device
            nb_rows, seq_len = input.size(0), input.size(1)

            # Position index of every element of the new sequences
            positions = torch.arange(1 + 2 * seq_len, device=dev).repeat(nb_rows, 1)

            # Per row: length of the prefix, and start of the copy
            prefix_len = torch.randint(seq_len, (nb_rows, 1), device=dev)
            caterpillar_length = args.nb_lines // args.caterpillar_height
            gap = torch.randint(caterpillar_length, (nb_rows, 1), device=dev)
            copy_start = prefix_len + gap + 1

            in_prefix = positions < prefix_len
            in_copy = (positions >= copy_start) & (positions < copy_start + seq_len)

            # Index in the original row to read from, or -1 for a marker.
            # The two regions are disjoint since copy_start > prefix_len.
            source = torch.where(
                in_prefix,
                positions,
                torch.where(
                    in_copy, positions - copy_start, torch.full_like(positions, -1)
                ),
            )

            is_marker = (source < 0).long()
            rows = torch.arange(nb_rows, device=dev)[:, None]

            gathered = input[rows, source.clamp(min=0)]
            new_input = gathered * (1 - is_marker) + marker_token * is_marker

            yield new_input

        yield input
567
568
def add_memex_v3(batches, memex_proba, marker_token):
    """Yield the batches, some of them preceded by a "memex" sequence.

    Every batch of batches is yielded unchanged. With probability
    memex_proba, it is preceded by a tensor with the same number of rows
    and 2 * input.size(1) columns where each row is made of a prefix of
    random length of the original row, followed by marker_token values,
    followed by a copy of the original row from its beginning. The
    positions past the copy, if any, also hold marker_token, and the
    copy is cut at the end of the sequence.

    The number of markers between the prefix and the copy is drawn
    between 1 and args.nb_lines // args.caterpillar_height, read from the
    module-level args.

    NOTE(review): the fragment offsets u are computed but not used to
    build the result yet, so the output currently follows the same
    construction as add_memex_v2 on a sequence one position shorter.
    The computation is kept as is since it consumes random numbers.
    """
    for input in batches:
        if torch.rand(1).item() < memex_proba:
            # Position index of every element of the new sequences
            t = (
                torch.arange(2 * input.size(1), device=input.device)[None, :]
                .expand(input.size(0), -1)
                .clone()
            )

            # Random fragment starts, drawn only in the second half of
            # the sequence (the first half is forced above the threshold)
            u = torch.rand(t.size(), device=t.device)
            u[:, : input.size(1)] = 1.0
            memex_v3_proba_fragment = 1 / 20
            u = (u < memex_v3_proba_fragment).long()
            # Fixed: the random offsets must live on the same device as u
            v = u * torch.randint(input.size(1), u.size(), device=u.device)
            u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[
                :, : input.size(1) - 1
            ] * input.size(1)
            # Fixed: cumsum() requires the dimension to accumulate along,
            # here the sequence dimension
            u = u.cumsum(dim=1).clamp(min=0)

            # Per row: u0 is the length of the prefix, u1 the start of
            # the copy
            u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
            caterpillar_length = args.nb_lines // args.caterpillar_height
            u1 = (
                u0
                + torch.randint(
                    caterpillar_length, (input.size(0), 1), device=input.device
                )
                + 1
            )

            # m0 selects the prefix, m1 the copy
            m0 = (t < u0).long()
            m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()

            # t becomes the index to read in the original row, or -1
            # where a marker has to be put
            t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
            m = (t < 0).long()
            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
                -1, t.size(1)
            )

            new_input = input[n, t.clamp(min=0)]
            new_input = (1 - m) * new_input + m * (marker_token)

            yield new_input

        yield input
613
614
615 ######################################################################
616
# Checked with an explicit test rather than an assert, so that the
# validation of this command-line value is not stripped by python -O.
if args.picocvlr_prune_properties not in {"none", "train+eval", "eval"}:
    raise ValueError(
        f"Unknown picocvlr_prune_properties {args.picocvlr_prune_properties}"
    )
618
619
def picoclvr_pruner_horizontal_green(p):
    """Return False when p contains "green" and "left" or "right".

    p is anything supporting the "in" operator with strings. The result
    is True in all the other cases.
    """
    if "green" not in p:
        return True

    return "left" not in p and "right" not in p
622
623
# The train pruner is only set in "train+eval" mode.
picoclvr_pruner_train = None
if args.picocvlr_prune_properties == "train+eval":
    picoclvr_pruner_train = picoclvr_pruner_horizontal_green

# The eval pruner is the negation of the train one, and is set in both
# the "train+eval" and "eval" modes.
picoclvr_pruner_eval = None
if args.picocvlr_prune_properties in ("train+eval", "eval"):
    picoclvr_pruner_eval = lambda p: not picoclvr_pruner_horizontal_green(p)
635
636 ######################################################################
637
# Build the task object selected with --task. Every task receives the
# train / test set sizes, the physical batch size and the device to put
# the data on, plus its own options. An unknown task name is an error.

device_data = device

if args.task == "byheart":
    task = tasks.SandBox(
        problem=problems.ProblemByHeart(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device_data,
    )
    # A negative value disables the train / test overlap check done
    # before training.
    args.max_percents_of_test_in_train = -1

elif args.task == "learnop":
    task = tasks.SandBox(
        problem=problems.ProblemLearnOperator(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device_data,
    )


elif args.task == "guessop":
    task = tasks.SandBox(
        problem=problems.ProblemGuessOperator(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device_data,
    )


elif args.task == "twotargets":
    task = tasks.SandBox(
        problem=problems.ProblemTwoTargets(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "memory":
    task = tasks.SandBox(
        problem=problems.ProblemMemory(len_total=args.memory_len_total),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "mixing":
    # --mixing_deterministic_start is passed negated, as random_start
    task = tasks.SandBox(
        problem=problems.ProblemMixing(
            hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
        ),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "addition":
    task = tasks.SandBox(
        problem=problems.ProblemAddition(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "picoclvr":
    # The pruners are None unless --picocvlr_prune_properties asks for them
    task = tasks.PicoCLVR(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        height=args.picoclvr_height,
        width=args.picoclvr_width,
        nb_colors=args.picoclvr_nb_colors,
        logger=log_string,
        device=device_data,
        pruner_train=picoclvr_pruner_train,
        pruner_eval=picoclvr_pruner_eval,
    )

elif args.task == "mnist":
    task = tasks.MNIST(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        device=device_data,
    )

elif args.task == "maze":
    task = tasks.Maze(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        height=args.maze_height,
        width=args.maze_width,
        nb_walls=args.maze_nb_walls,
        device=device_data,
    )

elif args.task == "snake":
    # The prompt length is half of the snake length
    task = tasks.Snake(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        height=args.snake_height,
        width=args.snake_width,
        nb_colors=args.snake_nb_colors,
        length=args.snake_length,
        prompt_length=args.snake_length // 2,
        device=device_data,
    )

elif args.task == "stack":
    task = tasks.Stack(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        nb_steps=args.stack_nb_steps,
        nb_stacks=args.stack_nb_stacks,
        nb_digits=args.stack_nb_digits,
        fraction_values_for_train=args.stack_fraction_values_for_train,
        device=device_data,
    )

elif args.task == "expr":
    task = tasks.Expr(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        nb_variables=args.expr_nb_variables,
        sequence_length=args.expr_sequence_length,
        operand_max=args.expr_operand_max,
        result_max=args.expr_result_max,
        batch_size=args.physical_batch_size,
        device=device_data,
    )

elif args.task == "rpl":
    task = tasks.RPL(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        nb_starting_values=args.rpl_nb_starting_values,
        max_input=args.rpl_max_input,
        prog_len=args.rpl_prog_len,
        nb_runs=args.rpl_nb_runs,
        no_prog=args.rpl_no_prog,
        logger=log_string,
        device=device_data,
    )

elif args.task == "grid":
    task = tasks.Grid(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        size=args.grid_size,
        nb_shapes=args.grid_nb_shapes,
        nb_colors=args.grid_nb_colors,
        logger=log_string,
        device=device_data,
    )

elif args.task == "qmlp":
    task = tasks.QMLP(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        result_dir=args.result_dir,
        logger=log_string,
        device=device_data,
    )

else:
    raise ValueError(f"Unknown task {args.task}")
824
825 ######################################################################
826
log_string(f"device {device}")

# When memex sequences are generated, one additional token is needed for
# the marker, which is the last token of the vocabulary.
vocabulary_size = task.vocabulary_size() + (1 if args.memex_proba > 0 else 0)

log_string(f"vocabulary_size {vocabulary_size}")

##############################

# Build the model from the command-line / preset values and move it to
# the computation device.
model = mygpt.MyGPT(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_blocks=args.nb_blocks,
    nb_lines=args.nb_lines,
    caterpillar_height=args.caterpillar_height,
    attention_layer=args.attention,
    dropout=args.dropout,
    causal=True,
    logger=log_string,
    args=args,
)

model.to(device)

# Total number of scalar parameters of the model
nb_parameters = 0
for param in model.parameters():
    nb_parameters += param.numel()

log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
858
859 ######################################################################
860
# Resume from the checkpoint of the result directory if there is one.
# The checkpoint holds the number of epochs already done, the model
# parameters, and the states of the CPU and CUDA random generators.

nb_epochs_finished = 0

if args.no_checkpoint:
    log_string("not trying to load checkpoint.")

else:
    try:
        checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
        checkpoint = torch.load(checkpoint_name)
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        torch.set_rng_state(checkpoint["rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        # No checkpoint yet, this is a fresh run
        log_string("starting from scratch.")

    except Exception as e:
        # Any other failure is fatal. "except Exception" instead of a
        # bare "except" so that KeyboardInterrupt and SystemExit are not
        # reported as checkpoint errors, and the cause is logged instead
        # of being silently dropped.
        log_string(f"checkpoint loading failed with {e!r}")
        log_string("error when loading the checkpoint.")
        exit(1)
884
885 ######################################################################
886
# With the expr task, an input file can be given to only produce the
# results for it with the current model, without any training.
if args.task == "expr":
    if args.expr_input_file is not None:
        task.produce_results(
            n_epoch=nb_epochs_finished,
            model=model,
            result_dir=args.result_dir,
            logger=log_string,
            deterministic_synthesis=args.deterministic_synthesis,
            input_file=args.expr_input_file,
        )

        exit(0)
898
899 ######################################################################
900
# A non-positive --nb_epochs selects the default number of epochs. The
# original fallback, nb_epochs_default, is not defined before this point
# and raised a NameError; the default declared in the argument parser is
# used instead.
nb_epochs = args.nb_epochs if args.nb_epochs > 0 else parser.get_default("nb_epochs")

# Compute the entropy of the training tokens

# Count the occurrences of every token over the whole train set. The
# one-hot encodings are summed over the dimensions 0 and 1, which assumes
# batches of shape (samples, positions) -- TODO confirm.
token_count = 0
for input in task.batches(split="train"):
    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
# xlogy(0, 0) is 0, so the tokens never seen do not contribute
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
911
912 ######################################################################
913 # A bit of paranoia never hurts
914
if args.max_percents_of_test_in_train >= 0:

    def subsets_as_tuples(batches, cs):
        """Yield the samples of batches as sets of tuples.

        Every sample is converted to a tuple of Python values. A set is
        yielded each time it reaches cs distinct samples, and the
        remaining ones, possibly none, are yielded in a last set.
        """
        chunk = set()
        for batch in batches:
            for sample in batch:
                chunk.add(tuple(v.item() for v in sample))
                if len(chunk) == cs:
                    yield chunk
                    chunk = set()
        yield chunk

    # The two sets are compared chunk by chunk to bound the memory used
    nb_test, nb_in_train = 0, 0
    for test_subset in subsets_as_tuples(task.batches(split="test"), 25000):
        in_train = set()
        for train_subset in subsets_as_tuples(task.batches(split="train"), 25000):
            in_train |= test_subset & train_subset
        nb_in_train += len(in_train)
        nb_test += len(test_subset)

    log_string(
        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
    )

    # Largest number of test samples tolerated in the train set
    max_nb_in_train = args.max_percents_of_test_in_train * nb_test / 100

    assert (
        nb_in_train <= max_nb_in_train
    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
942
943 ##############################
944
# Calibration mode, selected with the supplementary argument "calibrate":
# run the model once over the train set, print the moments gathered by
# every mygpt.Calibrator found among the attributes of its modules, and
# exit without training.
if "calibrate" in sup_args:
    for batch in task.batches(split="train", desc="calibrate"):
        batch = batch.to(device)
        _ = model(mygpt.BracketedSequence(batch)).x

    for module_name, module in model.named_modules():
        for attr_name in dir(module):
            candidate = getattr(module, attr_name)
            if not isinstance(candidate, mygpt.Calibrator):
                continue

            print(f"####### ${module_name} | ${attr_name} ########################")
            mean, std = candidate.moments()
            print("mean\n", mean, "\n")
            print("std\n", std, "\n")
            print(f"############################################\n\n")

    exit(0)
961
962 ##############################
963
# Counters updated by the training loop: samples seen, iterations and
# batches done.
nb_samples_seen, it, n_batch = 0, 0, 0

# Nothing is left to train: only produce the results of the loaded model
if nb_epochs_finished >= nb_epochs:
    task.produce_results(
        n_epoch=nb_epochs_finished,
        model=model,
        result_dir=args.result_dir,
        logger=log_string,
        deterministic_synthesis=args.deterministic_synthesis,
    )

time_pred_result = datetime.datetime.now()
980
def add_none(batches):
    """Yield every item of `batches`, then a final None.

    The trailing None is an end-of-epoch sentinel: it lets the training loop
    flush a partially accumulated batch after the last real one.
    """
    yield from batches
    yield None


# Main loop: one pass over the train split with gradient accumulation, then
# an evaluation on the test split, result production, and a checkpoint.
for n_epoch in range(nb_epochs_finished, nb_epochs):
    # A fresh optimizer is built at the start of every epoch, so any
    # optimizer state (e.g. Adam moment estimates) restarts here; the
    # checkpoint saved at the end of the epoch stores no optimizer state.
    # NOTE(review): confirm that this per-epoch reset is intended.
    if args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
    elif args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    else:
        raise ValueError(f"Unknown optimizer {args.optim}.")

    model.train()

    nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0

    # The memex probability is used for the first args.memex_nb_epochs epochs
    # (or for all of them if that argument is None), and is zero afterwards.
    memex_proba = (
        args.memex_proba
        if args.memex_nb_epochs is None or n_epoch < args.memex_nb_epochs
        else 0.0
    )

    log_string(f"memex_proba {memex_proba}")

    train_batches = add_memex_v2(
        batches=task.batches(split="train"),
        memex_proba=memex_proba,
        marker_token=vocabulary_size - 1,
    )

    # Number of samples whose gradients have been accumulated since the last
    # optimizer step.
    nb_acc_samples = 0

    for input in add_none(train_batches):
        if input is not None:
            model.reset_inner_loss()
            input = input.to(device)

            output = model(mygpt.BracketedSequence(input)).x
            loss = F.cross_entropy(output.transpose(1, 2), input)
            inner_loss = model.get_inner_loss()

            # Per-batch mean losses are weighted by the batch size so that
            # the epoch averages computed below are per-sample averages.
            acc_train_loss += loss.item() * input.size(0)
            acc_train_inner_loss += inner_loss.item() * input.size(0)

            nb_train_samples += input.size(0)
            nb_samples_seen += input.size(0)

            total_loss = loss + (
                args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
            )

            # The learning rate is recomputed by get_lr at every iteration.
            it += 1
            lr = get_lr(n_epoch, it)
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

            # Gradients accumulate across calls until the optimizer step.
            total_loss.backward()
            nb_acc_samples += input.size(0)

        # Step once args.batch_size samples have been accumulated, or at the
        # end of the epoch (None sentinel) if some samples are still pending.
        if (input is None and nb_acc_samples > 0) or nb_acc_samples == args.batch_size:
            assert nb_acc_samples <= args.batch_size
            optimizer.step()
            # Parameters that took no part in the backward passes have no
            # gradient (p.grad is None) and are left out of the norm.
            squared_norms = [
                p.grad.pow(2).sum() for p in model.parameters() if p.grad is not None
            ]
            grad_norm = sum(squared_norms).sqrt().item() if squared_norms else 0.0
            # `loss` is the loss of the last processed batch only.
            loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm}\n")
            optimizer.zero_grad()
            nb_acc_samples = 0

        # Incremented for every item of the loop, the None sentinel included.
        n_batch += 1

    with torch.autograd.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split="test"):
            input = input.to(device)

            output = model(mygpt.BracketedSequence(input)).x
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        log_string(
            f"loss {n_epoch} train_loss {acc_train_loss/nb_train_samples} train_inner_loss {acc_train_inner_loss/nb_train_samples} test_prediction {acc_test_loss/nb_test_samples}"
        )

        task.produce_results(
            n_epoch=n_epoch,
            model=model,
            result_dir=args.result_dir,
            logger=log_string,
            deterministic_synthesis=args.deterministic_synthesis,
        )

        # The mean loss is clamped to 100 before exponentiation to avoid an
        # overflow in math.exp.
        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(
            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
        )

        # Estimate the time of the next results by assuming the next epoch
        # will last as long as the one that just ended.
        time_current_result = datetime.datetime.now()
        log_string(
            f"next_result {time_current_result + (time_current_result - time_pred_result)}"
        )
        time_pred_result = time_current_result

    # Save what is needed to resume: epoch count, model weights, RNG states.
    checkpoint = {
        "nb_epochs_finished": n_epoch + 1,
        "model_state": model.state_dict(),
        "rng_state": torch.get_rng_state(),
    }

    if torch.cuda.is_available():
        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()

    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    torch.save(checkpoint, checkpoint_name)
    log_string(f"saved checkpoint {checkpoint_name}")

######################################################################
1105
1106 ######################################################################