[mygptrnn.git] / main.py
#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, argparse, time, tqdm, os, datetime, warnings

import torch, torchvision
from torch import nn
from torch.nn import functional as F

import ffutils
import mygpt, tasks, problems

######################################################################


def str2bool(x):
    x = x.lower()
    if x in {"1", "true", "yes"}:
        return True
    elif x in {"0", "false", "no"}:
        return False
    else:
        raise ValueError(f"cannot interpret '{x}' as a boolean")


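# A minimal usage note: argparse converts the ValueError raised above into an
# "invalid str2bool value" error message, so boolean options can be given as,
# e.g.,
#
#   python main.py --force_cpu=yes --legacy_lr_schedule=false
#
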
parser = argparse.ArgumentParser(
    description="An implementation of GPT with cache.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
    "--task",
    type=str,
    default="twotargets",
    help="byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
)

parser.add_argument("--log_filename", type=str, default="train.log", help=" ")

parser.add_argument("--result_dir", type=str, default=None)

parser.add_argument("--seed", type=int, default=0)

parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)

parser.add_argument("--force_cpu", type=str2bool, default=False)

########################################

parser.add_argument("--nb_epochs", type=int, default=50)

parser.add_argument("--batch_size", type=int, default=None)

parser.add_argument("--nb_train_samples", type=int, default=None)

parser.add_argument("--nb_test_samples", type=int, default=None)

parser.add_argument("--optim", type=str, default="adam")

########################################

parser.add_argument("--nb_warmup_iter", type=int, default=100)

parser.add_argument("--nb_decay_iter", type=int, default=5000)

parser.add_argument("--learning_rate", type=float, default=6e-4)

parser.add_argument("--min_learning_rate", type=float, default=6e-5)

# legacy

parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True)

parser.add_argument("--legacy_large_lr", type=float, default=1e-4)

parser.add_argument("--legacy_small_lr", type=float, default=2e-5)

parser.add_argument("--legacy_nb_epoch_large_lr", type=float, default=10)

########################################

parser.add_argument("--model", type=str, default=None)

parser.add_argument("--attention", type=str, default=None)

parser.add_argument("--dim_model", type=int, default=None)

parser.add_argument("--dim_keys", type=int, default=None)

parser.add_argument("--dim_hidden", type=int, default=None)

parser.add_argument("--nb_heads", type=int, default=None)

parser.add_argument("--nb_lines", type=int, default=None)

parser.add_argument("--caterpillar_height", type=int, default=None)

parser.add_argument("--rho", type=float, default=0.0)

parser.add_argument("--nb_blocks", type=int, default=None)

parser.add_argument("--dropout", type=float, default=0.1)

########################################

parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

parser.add_argument("--no_checkpoint", action="store_true", default=False)

parser.add_argument("--continue_training", action="store_true", default=False)

parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")

##############################
# rpl options

parser.add_argument("--rpl_nb_starting_values", type=int, default=3)

parser.add_argument("--rpl_max_input", type=int, default=9)

parser.add_argument("--rpl_prog_len", type=int, default=8)

parser.add_argument("--rpl_nb_runs", type=int, default=5)

parser.add_argument("--rpl_no_prog", action="store_true", default=False)

##############################
# grid options

parser.add_argument("--grid_size", type=int, default=6)

##############################
# picoclvr options

parser.add_argument("--picoclvr_nb_colors", type=int, default=5)

parser.add_argument("--picoclvr_height", type=int, default=12)

parser.add_argument("--picoclvr_width", type=int, default=16)

parser.add_argument("--picoclvr_prune_properties", type=str, default="none")

##############################
# Maze options

parser.add_argument("--maze_height", type=int, default=13)

parser.add_argument("--maze_width", type=int, default=21)

parser.add_argument("--maze_nb_walls", type=int, default=15)

##############################
# Snake options

parser.add_argument("--snake_height", type=int, default=9)

parser.add_argument("--snake_width", type=int, default=12)

parser.add_argument("--snake_nb_colors", type=int, default=5)

parser.add_argument("--snake_length", type=int, default=200)

##############################
# Stack options

parser.add_argument("--stack_nb_steps", type=int, default=100)

parser.add_argument("--stack_nb_stacks", type=int, default=3)

parser.add_argument("--stack_nb_digits", type=int, default=3)

parser.add_argument("--stack_fraction_values_for_train", type=float, default=0.75)

##############################
# Expr options

parser.add_argument("--expr_nb_variables", type=int, default=5)

parser.add_argument("--expr_sequence_length", type=int, default=40)

parser.add_argument("--expr_operand_max", type=int, default=9)

parser.add_argument("--expr_result_max", type=int, default=99)

parser.add_argument("--expr_input_file", type=str, default=None)

##############################
# Memory

parser.add_argument("--memory_len_total", type=int, default=32)

##############################
# Mixing

parser.add_argument("--mixing_hard", action="store_true", default=False)

parser.add_argument("--mixing_deterministic_start", action="store_true", default=False)

######################################################################

args, sup_args = parser.parse_known_args()

sup_args = dict([x.removeprefix("--").split("=", 1) for x in sup_args])

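# Unknown "--key=value" arguments are not rejected: parse_known_args() returns
# them separately, and they are forwarded verbatim (as strings) to the model
# constructor below. For instance, an extra "--foo=bar" on the command line
# yields sup_args == {"foo": "bar"}, which is then passed as a keyword
# argument to mygpt.MyGPT.
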
if args.result_dir is None:
    args.result_dir = f"results_{args.task}_{args.model}"

######################################################################

if not args.force_cpu and torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cuda.matmul.allow_tf32 = True
else:
    device = torch.device("cpu")

######################################################################

default_task_args = {
    "addition": {
        "model": "352M",
        "batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "byheart": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 50000,
        "nb_test_samples": 10000,
    },
    "expr": {
        "model": "352M",
        "batch_size": 25,
        "nb_train_samples": 2500000,
        "nb_test_samples": 10000,
    },
    "grid": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "qmlp": {
        "model": "37M",
        "batch_size": 10,
        "nb_train_samples": 100000,
        "nb_test_samples": 1000,
    },
    "guessop": {
        "model": "352M",
        "batch_size": 25,
        "nb_train_samples": 1000000,
        "nb_test_samples": 10000,
    },
    "learnop": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 50000,
        "nb_test_samples": 10000,
    },
    "maze": {
        "model": "37M",
        "batch_size": 5,
        "nb_train_samples": 100000,
        "nb_test_samples": 10000,
    },
    "picoclvr": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "rpl": {
        "model": "352M",
        "batch_size": 5,
        "nb_train_samples": 2500000,
        "nb_test_samples": 10000,
    },
    "snake": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "stack": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 100000,
        "nb_test_samples": 1000,
    },
    "twotargets": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 50000,
        "nb_test_samples": 10000,
    },
    "memory": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 25000,
        "nb_test_samples": 10000,
    },
    "mixing": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "mnist": {
        "model": "37M",
        "batch_size": 10,
        "nb_train_samples": 60000,
        "nb_test_samples": 10000,
    },
}

if args.task in default_task_args:
    for k, v in default_task_args[args.task].items():
        if getattr(args, k) is None:
            setattr(args, k, v)

######################################################################

default_model_args = {
    "17K": {
        "attention": "mha",
        "dim_model": 32,
        "dim_keys": 32,
        "dim_hidden": 32,
        "nb_heads": 2,
        "nb_blocks": 2,
    },
    "17K-C": {
        "attention": "caterpillar",
        "dim_model": 32,
        "dim_keys": 32,
        "dim_hidden": 32,
        "nb_heads": 2,
        "nb_lines": 16,
        "caterpillar_height": 4,
        "nb_blocks": 2,
    },
    "4M": {
        "attention": "mha",
        "dim_model": 256,
        "dim_keys": 32,
        "dim_hidden": 1024,
        "nb_heads": 4,
        "nb_blocks": 6,
    },
    "4M-C": {
        "attention": "caterpillar",
        "dim_model": 256,
        "dim_keys": 32,
        "dim_hidden": 1024,
        "nb_heads": 4,
        "nb_lines": 32,
        "caterpillar_height": 4,
        "nb_blocks": 6,
    },
    "37M": {
        "attention": "mha",
        "dim_model": 512,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 12,
    },
    "37M-C": {
        "attention": "caterpillar",
        "dim_model": 512,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_lines": 256,
        "caterpillar_height": 32,
        "nb_blocks": 12,
    },
    "122M": {
        "attention": "mha",
        "dim_model": 768,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 24,
    },
    "122M-C": {
        "attention": "caterpillar",
        "dim_model": 768,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_lines": 128,
        "nb_blocks": 24,
    },
    "352M": {
        "attention": "mha",
        "dim_model": 1024,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 48,
    },
    "352M-C": {
        "attention": "caterpillar",
        "dim_model": 1024,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_lines": 128,
        "nb_blocks": 48,
    },
}

if args.model in default_model_args:
    for k, v in default_model_args[args.model].items():
        if getattr(args, k) is None:
            setattr(args, k, v)
else:
    raise ValueError(f"Unknown model {args.model}")

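# The model names encode an approximate parameter count ("37M" is on the order
# of 37 million parameters; the exact count is logged at startup). The "-C"
# variants use the caterpillar attention layer instead of standard multi-head
# attention ("mha"), and additionally set nb_lines and, for the smaller
# configurations, caterpillar_height.
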
######################################################################

try:
    os.mkdir(args.result_dir)
except FileExistsError:
    if not args.continue_training:
        print(f"result directory {args.result_dir} already exists")
        exit(1)

loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")

log_file = open(os.path.join(args.result_dir, args.log_filename), "a")

if args.seed >= 0:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

######################################################################


def log_string(s):
    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())

    if log_file is not None:
        log_file.write(t + s + "\n")
        log_file.flush()

    print(t + s)
    sys.stdout.flush()


with os.popen("sha256sum *.py") as f:
    for l in f:
        log_string(f"sha256sum {l.strip()}")

now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
os.system(f"tar --ignore-failed-read -zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")

log_string(f"argv {' '.join(sys.argv)}")

for n in vars(args):
    log_string(f"args.{n} {getattr(args, n)}")

for k, v in sup_args.items():
    log_string(f"sup_args.{k} {v}")


######################################################################


def get_lr(n_epoch, it):
    if args.legacy_lr_schedule:
        # crude schedule, kept to compare against the previous
        # baseline, with warmup added

        if it < args.nb_warmup_iter:
            return args.legacy_large_lr * it / args.nb_warmup_iter
        elif n_epoch < args.legacy_nb_epoch_large_lr:
            return args.legacy_large_lr
        else:
            return args.legacy_small_lr

    # from nanoGPT

    # 1) linear warmup for nb_warmup_iter steps
    if it < args.nb_warmup_iter:
        return args.learning_rate * it / args.nb_warmup_iter
    # 2) if it > nb_decay_iter, return the min learning rate
    if it > args.nb_decay_iter:
        return args.min_learning_rate
    # 3) in between, use cosine decay down to the min learning rate
    decay_ratio = (it - args.nb_warmup_iter) / (
        args.nb_decay_iter - args.nb_warmup_iter
    )
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return args.min_learning_rate + coeff * (
        args.learning_rate - args.min_learning_rate
    )


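# Worked values for the nanoGPT-style schedule (reached with
# --legacy_lr_schedule=no), using the defaults nb_warmup_iter=100,
# nb_decay_iter=5000, learning_rate=6e-4, min_learning_rate=6e-5:
#
#   get_lr(n, 50)   == 3e-4    (halfway through the linear warmup)
#   get_lr(n, 2550) == 3.3e-4  (decay_ratio=0.5, coeff=0.5)
#   get_lr(n, 6000) == 6e-5    (past nb_decay_iter, floor reached)
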
######################################################################


assert args.picoclvr_prune_properties in {"none", "train+eval", "eval"}


def picoclvr_pruner_horizontal_green(p):
    return not ("green" in p and ("left" in p or "right" in p))


picoclvr_pruner_train = (
    picoclvr_pruner_horizontal_green
    if args.picoclvr_prune_properties in {"train+eval"}
    else None
)

picoclvr_pruner_eval = (
    (lambda p: not picoclvr_pruner_horizontal_green(p))
    if args.picoclvr_prune_properties in {"train+eval", "eval"}
    else None
)

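# A pruner is a predicate over a property string. Assuming the task keeps a
# property when its pruner returns True, "train+eval" removes from training
# every property that mentions "green" together with "left" or "right", while
# the eval pruner retains exactly those, so the model is tested on horizontal
# relations about green objects that it never saw during training.
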
######################################################################

device_data = device

if args.task == "byheart":
    task = tasks.SandBox(
        problem=problems.ProblemByHeart(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device_data,
    )
    args.max_percents_of_test_in_train = -1  # skip the test-in-train check

elif args.task == "learnop":
    task = tasks.SandBox(
        problem=problems.ProblemLearnOperator(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "guessop":
    task = tasks.SandBox(
        problem=problems.ProblemGuessOperator(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "twotargets":
    task = tasks.SandBox(
        problem=problems.ProblemTwoTargets(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "memory":
    task = tasks.SandBox(
        problem=problems.ProblemMemory(len_total=args.memory_len_total),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "mixing":
    task = tasks.SandBox(
        problem=problems.ProblemMixing(
            hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
        ),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "addition":
    task = tasks.SandBox(
        problem=problems.ProblemAddition(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "picoclvr":
    task = tasks.PicoCLVR(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.picoclvr_height,
        width=args.picoclvr_width,
        nb_colors=args.picoclvr_nb_colors,
        logger=log_string,
        device=device_data,
        pruner_train=picoclvr_pruner_train,
        pruner_eval=picoclvr_pruner_eval,
    )

elif args.task == "mnist":
    task = tasks.MNIST(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        device=device_data,
    )

elif args.task == "maze":
    task = tasks.Maze(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.maze_height,
        width=args.maze_width,
        nb_walls=args.maze_nb_walls,
        device=device_data,
    )

elif args.task == "snake":
    task = tasks.Snake(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.snake_height,
        width=args.snake_width,
        nb_colors=args.snake_nb_colors,
        length=args.snake_length,
        prompt_length=args.snake_length // 2,
        device=device_data,
    )

elif args.task == "stack":
    task = tasks.Stack(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        nb_steps=args.stack_nb_steps,
        nb_stacks=args.stack_nb_stacks,
        nb_digits=args.stack_nb_digits,
        fraction_values_for_train=args.stack_fraction_values_for_train,
        device=device_data,
    )

elif args.task == "expr":
    task = tasks.Expr(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        nb_variables=args.expr_nb_variables,
        sequence_length=args.expr_sequence_length,
        operand_max=args.expr_operand_max,
        result_max=args.expr_result_max,
        batch_size=args.batch_size,
        device=device_data,
    )

elif args.task == "rpl":
    task = tasks.RPL(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        nb_starting_values=args.rpl_nb_starting_values,
        max_input=args.rpl_max_input,
        prog_len=args.rpl_prog_len,
        nb_runs=args.rpl_nb_runs,
        no_prog=args.rpl_no_prog,
        logger=log_string,
        device=device_data,
    )

elif args.task == "grid":
    task = tasks.Grid(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        size=args.grid_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "qmlp":
    task = tasks.QMLP(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        result_dir=args.result_dir,
        logger=log_string,
        device=device_data,
    )

else:
    raise ValueError(f"Unknown task {args.task}")

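# For instance, a hypothetical run of this script could be
#
#   python main.py --task=maze --model=37M-C --result_dir=results_maze_test
#
# which trains the caterpillar variant on the maze task with the per-task
# defaults above (batch_size=5, 100000 train and 10000 test samples).
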
######################################################################

log_string(f"device {device}")

vocabulary_size = task.vocabulary_size()

log_string(f"vocabulary_size {vocabulary_size}")

##############################

model = mygpt.MyGPT(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_lines=args.nb_lines,
    caterpillar_height=args.caterpillar_height,
    nb_blocks=args.nb_blocks,
    causal=True,
    dropout=args.dropout,
    attention_layer=args.attention,
    logger=log_string,
    **sup_args,
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")

######################################################################

nb_epochs_finished = 0

if args.no_checkpoint:
    log_string("not trying to load checkpoint.")

else:
    try:
        checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
        checkpoint = torch.load(checkpoint_name, map_location=device)
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        torch.set_rng_state(checkpoint["rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        log_string("starting from scratch.")

    except Exception:
        log_string("error when loading the checkpoint.")
        exit(1)

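# A checkpoint is written at the end of every epoch (see the bottom of the
# training loop). To resume an interrupted run, re-launch with the same
# arguments plus --continue_training so the existing result_dir is reused;
# the checkpoint is then picked up automatically unless --no_checkpoint is
# given.
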
######################################################################

if args.task == "expr" and args.expr_input_file is not None:
    task.produce_results(
        n_epoch=nb_epochs_finished,
        model=model,
        result_dir=args.result_dir,
        logger=log_string,
        deterministic_synthesis=args.deterministic_synthesis,
        input_file=args.expr_input_file,
    )

    exit(0)

######################################################################

nb_epochs = args.nb_epochs if args.nb_epochs > 0 else 50  # fall back to the parser default

# Compute the entropy of the training tokens

token_count = 0
for input in task.batches(split="train"):
    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)

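# token_probas is the empirical unigram distribution of the training tokens,
# and entropy = -sum_t p_t log p_t (in nats), so train_set_perplexity = exp(H)
# is the perplexity of a model that exactly matches the unigram statistics.
# It is logged every epoch as a reference point for the prediction
# perplexities computed below.
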
######################################################################
# A bit of paranoia never hurts

if args.max_percents_of_test_in_train >= 0:

    def subsets_as_tuples(batches, cs):
        s = set()
        for batch in batches:
            for x in batch:
                s.add(tuple([v.item() for v in x]))
                if len(s) == cs:
                    yield s
                    s = set()
        yield s

    nb_test, nb_in_train = 0, 0
    for test_subset in subsets_as_tuples(task.batches(split="test"), 25000):
        in_train = set()
        for train_subset in subsets_as_tuples(task.batches(split="train"), 25000):
            in_train.update(test_subset.intersection(train_subset))
        nb_in_train += len(in_train)
        nb_test += len(test_subset)

    log_string(
        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
    )

    assert (
        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"

##############################

nb_samples_seen = 0

if nb_epochs_finished >= nb_epochs:
    task.produce_results(
        n_epoch=nb_epochs_finished,
        model=model,
        result_dir=args.result_dir,
        logger=log_string,
        deterministic_synthesis=args.deterministic_synthesis,
    )

time_pred_result = datetime.datetime.now()

it = 0

n_batch = 0

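# The optimizer is re-created at the start of every epoch, so the running
# moment estimates of adam/adamw are reset each time; optimizer state is
# likewise absent from the checkpoint written below, which stores only the
# epoch count, the model weights, and the RNG states.
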
for n_epoch in range(nb_epochs_finished, nb_epochs):
    if args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
    elif args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    else:
        raise ValueError(f"Unknown optimizer {args.optim}.")

    model.train()

    nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0

    for input in task.batches(split="train"):
        model.reset_inner_loss()
        input = input.to(device)

        output = model(mygpt.BracketedSequence(input)).x
        loss = F.cross_entropy(output.transpose(1, 2), input)
        inner_loss = model.get_inner_loss()

        acc_train_loss += loss.item() * input.size(0)
        acc_train_inner_loss += inner_loss.item() * input.size(0)

        nb_train_samples += input.size(0)
        nb_samples_seen += input.size(0)

        total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)

        it += 1
        lr = get_lr(n_epoch, it)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # log_string(f"learning_rate {lr}")

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        grad_norm = sum(
            p.grad.pow(2).sum() for p in model.parameters() if p.grad is not None
        ).sqrt()

        loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")

        n_batch += 1

    with torch.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split="test"):
            input = input.to(device)

            output = model(mygpt.BracketedSequence(input)).x
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        log_string(
            f"loss {n_epoch} train_loss {acc_train_loss/nb_train_samples} train_inner_loss {acc_train_inner_loss/nb_train_samples} test_prediction {acc_test_loss/nb_test_samples}"
        )

        task.produce_results(
            n_epoch=n_epoch,
            model=model,
            result_dir=args.result_dir,
            logger=log_string,
            deterministic_synthesis=args.deterministic_synthesis,
        )

        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

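        # Perplexity is the exponential of the mean cross-entropy in nats;
        # the argument of math.exp is clamped at 100 so that a diverging
        # loss produces a finite (if astronomical) value instead of an
        # OverflowError.
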
        log_string(
            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
        )

        time_current_result = datetime.datetime.now()
        log_string(
            f"next_result {time_current_result + (time_current_result - time_pred_result)}"
        )
        time_pred_result = time_current_result

    checkpoint = {
        "nb_epochs_finished": n_epoch + 1,
        "model_state": model.state_dict(),
        "rng_state": torch.get_rng_state(),
    }

    if torch.cuda.is_available():
        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()

    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    torch.save(checkpoint, checkpoint_name)
    log_string(f"saved checkpoint {checkpoint_name}")

######################################################################