Update.
[mygptrnn.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, os, datetime, warnings
9
10 import torch, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import ffutils
15 import mygpt, tasks, problems
16
17 ######################################################################
18
19
20 def str2bool(x):
21     x = x.lower()
22     if x in {"1", "true", "yes"}:
23         return True
24     elif x in {"0", "false", "no"}:
25         return False
26     else:
27         raise ValueError
28
29
30 parser = argparse.ArgumentParser(
31     description="An implementation of GPT with cache.",
32     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
33 )
34
35 parser.add_argument(
36     "--task",
37     type=str,
38     default="twotargets",
39     help="byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
40 )
41
42 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
43
44 parser.add_argument("--result_dir", type=str, default=None)
45
46 parser.add_argument("--seed", type=int, default=0)
47
48 parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
49
50 parser.add_argument("--force_cpu", type=str2bool, default=False)
51
52 ########################################
53
54 parser.add_argument("--nb_epochs", type=int, default=50)
55
56 parser.add_argument("--batch_size", type=int, default=None)
57
58 parser.add_argument("--nb_train_samples", type=int, default=None)
59
60 parser.add_argument("--nb_test_samples", type=int, default=None)
61
62 parser.add_argument("--optim", type=str, default="adam")
63
64 ########################################
65
66 parser.add_argument("--nb_warmup_iter", type=int, default=100)
67
68 parser.add_argument("--nb_decay_iter", type=int, default=5000)
69
70 parser.add_argument("--learning_rate", type=float, default=6e-4)
71
72 parser.add_argument("--min_learning_rate", type=float, default=6e-5)
73
74 # legacy
75
76 parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True)
77
78 parser.add_argument("--legacy_large_lr", type=float, default=1e-4)
79
80 parser.add_argument("--legacy_small_lr", type=float, default=2e-5)
81
82 parser.add_argument("--legacy_nb_epoch_large_lr", type=float, default=10)
83
84 ########################################
85
86 parser.add_argument("--model", type=str, default=None)
87
88 parser.add_argument("--attention", type=str, default=None)
89
90 parser.add_argument("--dim_model", type=int, default=None)
91
92 parser.add_argument("--dim_keys", type=int, default=None)
93
94 parser.add_argument("--dim_hidden", type=int, default=None)
95
96 parser.add_argument("--nb_heads", type=int, default=None)
97
98 parser.add_argument("--nb_lines", type=int, default=None)
99
100 parser.add_argument("--caterpillar_height", type=int, default=None)
101
102 parser.add_argument("--rho", type=float, default=0.0)
103
104 parser.add_argument("--nb_blocks", type=int, default=None)
105
106 parser.add_argument("--dropout", type=float, default=0.1)
107
108 ########################################
109
110 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
111
112 parser.add_argument("--no_checkpoint", action="store_true", default=False)
113
114 parser.add_argument("--continue_training", action="store_true", default=False)
115
116 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
117
118 ##############################
119 # rpl options
120
121 parser.add_argument("--rpl_nb_starting_values", type=int, default=3)
122
123 parser.add_argument("--rpl_max_input", type=int, default=9)
124
125 parser.add_argument("--rpl_prog_len", type=int, default=8)
126
127 parser.add_argument("--rpl_nb_runs", type=int, default=5)
128
129 parser.add_argument("--rpl_no_prog", action="store_true", default=False)
130
131 ##############################
132 # grid options
133
134 parser.add_argument("--grid_size", type=int, default=6)
135
136 parser.add_argument("--grid_nb_colors", type=int, default=6)
137
138 parser.add_argument("--grid_nb_shapes", type=int, default=6)
139
140 ##############################
141 # picoclvr options
142
143 parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
144
145 parser.add_argument("--picoclvr_height", type=int, default=12)
146
147 parser.add_argument("--picoclvr_width", type=int, default=16)
148
149 parser.add_argument("--picocvlr_prune_properties", type=str, default="none")
150
151 ##############################
152 # Maze options
153
154 parser.add_argument("--maze_height", type=int, default=13)
155
156 parser.add_argument("--maze_width", type=int, default=21)
157
158 parser.add_argument("--maze_nb_walls", type=int, default=15)
159
160 ##############################
161 # Snake options
162
163 parser.add_argument("--snake_height", type=int, default=9)
164
165 parser.add_argument("--snake_width", type=int, default=12)
166
167 parser.add_argument("--snake_nb_colors", type=int, default=5)
168
169 parser.add_argument("--snake_length", type=int, default=200)
170
171 ##############################
172 # Stack options
173
174 parser.add_argument("--stack_nb_steps", type=int, default=100)
175
176 parser.add_argument("--stack_nb_stacks", type=int, default=3)
177
178 parser.add_argument("--stack_nb_digits", type=int, default=3)
179
180 parser.add_argument("--stack_fraction_values_for_train", type=float, default=0.75)
181
182 ##############################
183 # Expr options
184
185 parser.add_argument("--expr_nb_variables", type=int, default=5)
186
187 parser.add_argument("--expr_sequence_length", type=int, default=40)
188
189 parser.add_argument("--expr_operand_max", type=int, default=9)
190
191 parser.add_argument("--expr_result_max", type=int, default=99)
192
193 parser.add_argument("--expr_input_file", type=str, default=None)
194
195 ##############################
196 # Memory
197
198 parser.add_argument("--memory_len_total", type=int, default=32)
199
200 ##############################
201 # Mixing
202
203 parser.add_argument("--mixing_hard", action="store_true", default=False)
204
205 parser.add_argument("--mixing_deterministic_start", action="store_true", default=False)
206
207 ######################################################################
208
209 # args = parser.parse_args()
210
211 args, sup_args = parser.parse_known_args()
212
213 sup_args = dict([x.removeprefix("--").split("=") for x in sup_args])
214
215 if args.result_dir is None:
216     args.result_dir = f"results_{args.task}_{args.model}"
217
218 ######################################################################
219
220 if not args.force_cpu and torch.cuda.is_available():
221     device = torch.device("cuda")
222     torch.backends.cuda.matmul.allow_tf32 = True
223 else:
224     device = torch.device("cpu")
225
226 ######################################################################
227
228 default_task_args = {
229     "addition": {
230         "model": "352M",
231         "batch_size": 25,
232         "nb_train_samples": 250000,
233         "nb_test_samples": 10000,
234     },
235     "byheart": {
236         "model": "37M",
237         "batch_size": 25,
238         "nb_train_samples": 50000,
239         "nb_test_samples": 10000,
240     },
241     "expr": {
242         "model": "352M",
243         "batch_size": 25,
244         "nb_train_samples": 2500000,
245         "nb_test_samples": 10000,
246     },
247     "grid": {
248         "model": "37M",
249         "batch_size": 25,
250         "nb_train_samples": 250000,
251         "nb_test_samples": 10000,
252     },
253     "qmlp": {
254         "model": "37M",
255         "batch_size": 10,
256         "nb_train_samples": 100000,
257         "nb_test_samples": 1000,
258     },
259     "guessop": {
260         "model": "352M",
261         "batch_size": 25,
262         "nb_train_samples": 1000000,
263         "nb_test_samples": 10000,
264     },
265     "learnop": {
266         "model": "37M",
267         "batch_size": 25,
268         "nb_train_samples": 50000,
269         "nb_test_samples": 10000,
270     },
271     "maze": {
272         "model": "37M",
273         "batch_size": 5,
274         "nb_train_samples": 100000,
275         "nb_test_samples": 10000,
276     },
277     "picoclvr": {
278         "model": "37M",
279         "batch_size": 25,
280         "nb_train_samples": 250000,
281         "nb_test_samples": 10000,
282     },
283     "rpl": {
284         "model": "352M",
285         "batch_size": 5,
286         "nb_train_samples": 2500000,
287         "nb_test_samples": 10000,
288     },
289     "snake": {
290         "model": "37M",
291         "batch_size": 25,
292         "nb_train_samples": 250000,
293         "nb_test_samples": 10000,
294     },
295     "stack": {
296         "model": "37M",
297         "batch_size": 25,
298         "nb_train_samples": 100000,
299         "nb_test_samples": 1000,
300     },
301     "twotargets": {
302         "model": "37M",
303         "batch_size": 25,
304         "nb_train_samples": 50000,
305         "nb_test_samples": 10000,
306     },
307     "memory": {
308         "model": "37M",
309         "batch_size": 25,
310         "nb_train_samples": 25000,
311         "nb_test_samples": 10000,
312     },
313     "mixing": {
314         "model": "37M",
315         "batch_size": 25,
316         "nb_train_samples": 250000,
317         "nb_test_samples": 10000,
318     },
319     "mnist": {
320         "model": "37M",
321         "batch_size": 10,
322         "nb_train_samples": 60000,
323         "nb_test_samples": 10000,
324     },
325 }
326
327 if args.task in default_task_args:
328     for k, v in default_task_args[args.task].items():
329         if getattr(args, k) is None:
330             setattr(args, k, v)
331
332 ######################################################################
333
334 default_model_args = {
335     "17K": {
336         "attention": "mha",
337         "dim_model": 32,
338         "dim_keys": 32,
339         "dim_hidden": 32,
340         "nb_heads": 2,
341         "nb_blocks": 2,
342     },
343     "17K-C": {
344         "attention": "caterpillar",
345         "dim_model": 32,
346         "dim_keys": 32,
347         "dim_hidden": 32,
348         "nb_heads": 2,
349         "nb_lines": 16,
350         "caterpillar_height": 4,
351         "nb_blocks": 2,
352     },
353     "4M": {
354         "attention": "mha",
355         "dim_model": 256,
356         "dim_keys": 32,
357         "dim_hidden": 1024,
358         "nb_heads": 4,
359         "nb_blocks": 6,
360     },
361     "4M-C": {
362         "attention": "caterpillar",
363         "dim_model": 256,
364         "dim_keys": 32,
365         "dim_hidden": 1024,
366         "nb_heads": 4,
367         "nb_lines": 32,
368         "caterpillar_height": 4,
369         "nb_blocks": 6,
370     },
371     "37M": {
372         "attention": "mha",
373         "dim_model": 512,
374         "dim_keys": 64,
375         "dim_hidden": 2048,
376         "nb_heads": 8,
377         "nb_blocks": 12,
378     },
379     "37M-C": {
380         "attention": "caterpillar",
381         "dim_model": 512,
382         "dim_keys": 64,
383         "dim_hidden": 2048,
384         "nb_heads": 8,
385         "nb_lines": 256,
386         "caterpillar_height": 32,
387         "nb_blocks": 12,
388     },
389     "122M": {
390         "attention": "mha",
391         "dim_model": 768,
392         "dim_keys": 64,
393         "dim_hidden": 2048,
394         "nb_heads": 8,
395         "nb_blocks": 24,
396     },
397     "122M-C": {
398         "attention": "caterpillar",
399         "dim_model": 768,
400         "dim_keys": 64,
401         "dim_hidden": 2048,
402         "nb_heads": 8,
403         "nb_lines": 128,
404         "nb_blocks": 24,
405     },
406     "352M": {
407         "attention": "mha",
408         "dim_model": 1024,
409         "dim_keys": 64,
410         "dim_hidden": 2048,
411         "nb_heads": 8,
412         "nb_blocks": 48,
413     },
414     "352M-C": {
415         "attention": "caterpillar",
416         "dim_model": 1024,
417         "dim_keys": 64,
418         "dim_hidden": 2048,
419         "nb_heads": 8,
420         "nb_lines": 128,
421         "nb_blocks": 48,
422     },
423 }
424
425 if args.model in default_model_args:
426     for k, v in default_model_args[args.model].items():
427         if getattr(args, k) is None:
428             setattr(args, k, v)
429 else:
430     raise ValueError(f"Unknown model {args.model}")
431
432 ######################################################################
433
434 try:
435     os.mkdir(args.result_dir)
436 except FileExistsError:
437     if not args.continue_training:
438         print(f"result directory {args.result_dir} already exists")
439         exit(1)
440
441 loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
442
443 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
444
445 if args.seed >= 0:
446     # torch.backends.cudnn.deterministic = True
447     # torch.backends.cudnn.benchmark = False
448     # torch.use_deterministic_algorithms(True)
449     torch.manual_seed(args.seed)
450     if torch.cuda.is_available():
451         torch.cuda.manual_seed_all(args.seed)
452
453 ######################################################################
454
455
456 def log_string(s):
457     t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
458
459     if log_file is not None:
460         log_file.write(t + s + "\n")
461         log_file.flush()
462
463     print(t + s)
464     sys.stdout.flush()
465
466
467 with os.popen("sha256sum *.py") as f:
468     for l in f:
469         log_string(f"sha256sum {l.strip()}")
470
471 now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
472 os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
473
474 log_string(f"argv {' '.join(sys.argv)}")
475
476 for n in vars(args):
477     log_string(f"args.{n} {getattr(args, n)}")
478
479 for k, v in sup_args.items():
480     log_string(f'sup_args["{k}"] "{v}"')
481
482
483 ######################################################################
484
485
486 def get_lr(n_epoch, it):
487     if args.legacy_lr_schedule:
488         # my crude scheduling to compare to previous baseline, added
489         # warmup though
490
491         if it < args.nb_warmup_iter:
492             return args.legacy_large_lr * it / args.nb_warmup_iter
493         elif n_epoch < args.legacy_nb_epoch_large_lr:
494             return args.legacy_large_lr
495         else:
496             return args.legacy_small_lr
497
498     # from nanoGPT
499
500     # 1) linear warmup for warmup_iter steps
501     if it < args.nb_warmup_iter:
502         return args.learning_rate * it / args.nb_warmup_iter
503     # 2) if it > nb_decay_iter, return min learning rate
504     if it > args.nb_decay_iter:
505         return args.min_learning_rate
506     # 3) in between, use cosine decay down to min learning rate
507     decay_ratio = (it - args.nb_warmup_iter) / (
508         args.nb_decay_iter - args.nb_warmup_iter
509     )
510     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
511     return args.min_learning_rate + coeff * (
512         args.learning_rate - args.min_learning_rate
513     )
514
515
516 ######################################################################
517
518
519 assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
520
521
522 def picoclvr_pruner_horizontal_green(p):
523     return not ("green" in p and ("left" in p or "right" in p))
524
525
526 picoclvr_pruner_train = (
527     picoclvr_pruner_horizontal_green
528     if args.picocvlr_prune_properties in {"train+eval"}
529     else None
530 )
531
532 picoclvr_pruner_eval = (
533     (lambda p: not picoclvr_pruner_horizontal_green(p))
534     if args.picocvlr_prune_properties in {"train+eval", "eval"}
535     else None
536 )
537
538 ######################################################################
539
540 device_data = device
541
542 if args.task == "byheart":
543     task = tasks.SandBox(
544         problem=problems.ProblemByHeart(),
545         nb_train_samples=args.nb_train_samples,
546         nb_test_samples=args.nb_test_samples,
547         batch_size=args.batch_size,
548         logger=log_string,
549         device=device_data,
550     )
551     args.max_percents_of_test_in_train = -1
552
553 elif args.task == "learnop":
554     task = tasks.SandBox(
555         problem=problems.ProblemLearnOperator(),
556         nb_train_samples=args.nb_train_samples,
557         nb_test_samples=args.nb_test_samples,
558         batch_size=args.batch_size,
559         logger=log_string,
560         device=device_data,
561     )
562
563
564 elif args.task == "guessop":
565     task = tasks.SandBox(
566         problem=problems.ProblemGuessOperator(),
567         nb_train_samples=args.nb_train_samples,
568         nb_test_samples=args.nb_test_samples,
569         batch_size=args.batch_size,
570         logger=log_string,
571         device=device_data,
572     )
573
574
575 elif args.task == "twotargets":
576     task = tasks.SandBox(
577         problem=problems.ProblemTwoTargets(),
578         nb_train_samples=args.nb_train_samples,
579         nb_test_samples=args.nb_test_samples,
580         batch_size=args.batch_size,
581         logger=log_string,
582         device=device_data,
583     )
584
585 elif args.task == "memory":
586     task = tasks.SandBox(
587         problem=problems.ProblemMemory(len_total=args.memory_len_total),
588         nb_train_samples=args.nb_train_samples,
589         nb_test_samples=args.nb_test_samples,
590         batch_size=args.batch_size,
591         logger=log_string,
592         device=device_data,
593     )
594
595 elif args.task == "mixing":
596     task = tasks.SandBox(
597         problem=problems.ProblemMixing(
598             hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
599         ),
600         nb_train_samples=args.nb_train_samples,
601         nb_test_samples=args.nb_test_samples,
602         batch_size=args.batch_size,
603         logger=log_string,
604         device=device_data,
605     )
606
607 elif args.task == "addition":
608     task = tasks.SandBox(
609         problem=problems.ProblemAddition(),
610         nb_train_samples=args.nb_train_samples,
611         nb_test_samples=args.nb_test_samples,
612         batch_size=args.batch_size,
613         logger=log_string,
614         device=device_data,
615     )
616
617 elif args.task == "picoclvr":
618     task = tasks.PicoCLVR(
619         nb_train_samples=args.nb_train_samples,
620         nb_test_samples=args.nb_test_samples,
621         batch_size=args.batch_size,
622         height=args.picoclvr_height,
623         width=args.picoclvr_width,
624         nb_colors=args.picoclvr_nb_colors,
625         logger=log_string,
626         device=device_data,
627         pruner_train=picoclvr_pruner_train,
628         pruner_eval=picoclvr_pruner_eval,
629     )
630
631 elif args.task == "mnist":
632     task = tasks.MNIST(
633         nb_train_samples=args.nb_train_samples,
634         nb_test_samples=args.nb_test_samples,
635         batch_size=args.batch_size,
636         device=device_data,
637     )
638
639 elif args.task == "maze":
640     task = tasks.Maze(
641         nb_train_samples=args.nb_train_samples,
642         nb_test_samples=args.nb_test_samples,
643         batch_size=args.batch_size,
644         height=args.maze_height,
645         width=args.maze_width,
646         nb_walls=args.maze_nb_walls,
647         device=device_data,
648     )
649
650 elif args.task == "snake":
651     task = tasks.Snake(
652         nb_train_samples=args.nb_train_samples,
653         nb_test_samples=args.nb_test_samples,
654         batch_size=args.batch_size,
655         height=args.snake_height,
656         width=args.snake_width,
657         nb_colors=args.snake_nb_colors,
658         length=args.snake_length,
659         prompt_length=args.snake_length // 2,
660         device=device_data,
661     )
662
663 elif args.task == "stack":
664     task = tasks.Stack(
665         nb_train_samples=args.nb_train_samples,
666         nb_test_samples=args.nb_test_samples,
667         batch_size=args.batch_size,
668         logger=log_string,
669         nb_steps=args.stack_nb_steps,
670         nb_stacks=args.stack_nb_stacks,
671         nb_digits=args.stack_nb_digits,
672         fraction_values_for_train=args.stack_fraction_values_for_train,
673         device=device_data,
674     )
675
676 elif args.task == "expr":
677     task = tasks.Expr(
678         nb_train_samples=args.nb_train_samples,
679         nb_test_samples=args.nb_test_samples,
680         nb_variables=args.expr_nb_variables,
681         sequence_length=args.expr_sequence_length,
682         operand_max=args.expr_operand_max,
683         result_max=args.expr_result_max,
684         batch_size=args.batch_size,
685         device=device_data,
686     )
687
688 elif args.task == "rpl":
689     task = tasks.RPL(
690         nb_train_samples=args.nb_train_samples,
691         nb_test_samples=args.nb_test_samples,
692         batch_size=args.batch_size,
693         nb_starting_values=args.rpl_nb_starting_values,
694         max_input=args.rpl_max_input,
695         prog_len=args.rpl_prog_len,
696         nb_runs=args.rpl_nb_runs,
697         no_prog=args.rpl_no_prog,
698         logger=log_string,
699         device=device_data,
700     )
701
702 elif args.task == "grid":
703     task = tasks.Grid(
704         nb_train_samples=args.nb_train_samples,
705         nb_test_samples=args.nb_test_samples,
706         batch_size=args.batch_size,
707         size=args.grid_size,
708         nb_shapes=args.grid_nb_shapes,
709         nb_colors=args.grid_nb_colors,
710         logger=log_string,
711         device=device_data,
712     )
713
714 elif args.task == "qmlp":
715     task = tasks.QMLP(
716         nb_train_samples=args.nb_train_samples,
717         nb_test_samples=args.nb_test_samples,
718         batch_size=args.batch_size,
719         result_dir=args.result_dir,
720         logger=log_string,
721         device=device_data,
722     )
723
724 else:
725     raise ValueError(f"Unknown task {args.task}")
726
727 ######################################################################
728
729 log_string(f"device {device}")
730
731 vocabulary_size = task.vocabulary_size()
732
733 log_string(f"vocabulary_size {vocabulary_size}")
734
735 ##############################
736
737 model = mygpt.MyGPT(
738     vocabulary_size=vocabulary_size,
739     dim_model=args.dim_model,
740     dim_keys=args.dim_keys,
741     dim_hidden=args.dim_hidden,
742     nb_heads=args.nb_heads,
743     nb_lines=args.nb_lines,
744     caterpillar_height=args.caterpillar_height,
745     nb_blocks=args.nb_blocks,
746     causal=True,
747     dropout=args.dropout,
748     attention_layer=args.attention,
749     logger=log_string,
750     **sup_args,
751 )
752
753 model.to(device)
754
755 nb_parameters = sum(p.numel() for p in model.parameters())
756 log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
757
758 ######################################################################
759
760 nb_epochs_finished = 0
761
762 if args.no_checkpoint:
763     log_string(f"not trying to load checkpoint.")
764
765 else:
766     try:
767         checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
768         checkpoint = torch.load(checkpoint_name)
769         nb_epochs_finished = checkpoint["nb_epochs_finished"]
770         model.load_state_dict(checkpoint["model_state"])
771         torch.set_rng_state(checkpoint["rng_state"])
772         if torch.cuda.is_available():
773             torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
774
775         log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")
776
777     except FileNotFoundError:
778         log_string("starting from scratch.")
779
780     except:
781         log_string("error when loading the checkpoint.")
782         exit(1)
783
784 ######################################################################
785
786 if args.task == "expr" and args.expr_input_file is not None:
787     task.produce_results(
788         n_epoch=nb_epochs_finished,
789         model=model,
790         result_dir=args.result_dir,
791         logger=log_string,
792         deterministic_synthesis=args.deterministic_synthesis,
793         input_file=args.expr_input_file,
794     )
795
796     exit(0)
797
798 ######################################################################
799
800 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
801
802 # Compute the entropy of the training tokens
803
804 token_count = 0
805 for input in task.batches(split="train"):
806     token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
807 token_probas = token_count / token_count.sum()
808 entropy = -torch.xlogy(token_probas, token_probas).sum()
809 train_set_perplexity = math.exp(entropy)
810
811 ######################################################################
812 # A bit of paranoia never hurts
813
814 if args.max_percents_of_test_in_train >= 0:
815
816     def subsets_as_tuples(batches, cs):
817         s = set()
818         for batch in batches:
819             for x in batch:
820                 s.add(tuple([v.item() for v in x]))
821                 if len(s) == cs:
822                     yield s
823                     s = set()
824         yield s
825
826     nb_test, nb_in_train = 0, 0
827     for test_subset in subsets_as_tuples(task.batches(split="test"), 25000):
828         in_train = set()
829         for train_subset in subsets_as_tuples(task.batches(split="train"), 25000):
830             in_train.update(test_subset.intersection(train_subset))
831         nb_in_train += len(in_train)
832         nb_test += len(test_subset)
833
834     log_string(
835         f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
836     )
837
838     assert (
839         nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
840     ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
841
842 ##############################
843
844 if "calibrate" in sup_args:
845     for input in task.batches(split="train", desc="calibrate"):
846         input = input.to(device)
847         output = model(mygpt.BracketedSequence(input)).x
848
849     for n, m in model.named_modules():
850         for a in dir(m):
851             x = getattr(m, a)
852             if isinstance(x, mygpt.Calibrator):
853                 print(f"####### ${n} | ${a} ########################")
854                 mean, std = x.moments()
855                 print("mean\n", mean, "\n")
856                 print("std\n", std, "\n")
857                 print(f"############################################\n\n")
858
859     exit(0)
860
861 ##############################
862
863 nb_samples_seen = 0
864
865 if nb_epochs_finished >= nb_epochs:
866     task.produce_results(
867         n_epoch=nb_epochs_finished,
868         model=model,
869         result_dir=args.result_dir,
870         logger=log_string,
871         deterministic_synthesis=args.deterministic_synthesis,
872     )
873
874 time_pred_result = datetime.datetime.now()
875
876 it = 0
877
878 n_batch = 0
879
880 for n_epoch in range(nb_epochs_finished, nb_epochs):
881     if args.optim == "sgd":
882         optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
883     elif args.optim == "adam":
884         optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
885     elif args.optim == "adamw":
886         optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
887     else:
888         raise ValueError(f"Unknown optimizer {args.optim}.")
889
890     model.train()
891
892     nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
893
894     for input in task.batches(split="train"):
895         model.reset_inner_loss()
896         input = input.to(device)
897
898         output = model(mygpt.BracketedSequence(input)).x
899         loss = F.cross_entropy(output.transpose(1, 2), input)
900         inner_loss = model.get_inner_loss()
901
902         acc_train_loss += loss.item() * input.size(0)
903         acc_train_inner_loss += inner_loss.item() * input.size(0)
904
905         nb_train_samples += input.size(0)
906         nb_samples_seen += input.size(0)
907
908         total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
909
910         it += 1
911         lr = get_lr(n_epoch, it)
912         for param_group in optimizer.param_groups:
913             param_group["lr"] = lr
914
915         # log_string(f"learning_rate {lr}")
916
917         optimizer.zero_grad()
918         total_loss.backward()
919         optimizer.step()
920
921         grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
922
923         loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
924
925         n_batch += 1
926
927     with torch.autograd.no_grad():
928         model.eval()
929
930         nb_test_samples, acc_test_loss = 0, 0.0
931
932         for input in task.batches(split="test"):
933             input = input.to(device)
934
935             output = model(mygpt.BracketedSequence(input)).x
936             loss = F.cross_entropy(output.transpose(1, 2), input)
937             acc_test_loss += loss.item() * input.size(0)
938             nb_test_samples += input.size(0)
939
940         log_string(
941             f"loss {n_epoch} train_loss {acc_train_loss/nb_train_samples} train_inner_loss {acc_train_inner_loss/nb_train_samples} test_prediction {acc_test_loss/nb_test_samples}"
942         )
943
944         task.produce_results(
945             n_epoch=n_epoch,
946             model=model,
947             result_dir=args.result_dir,
948             logger=log_string,
949             deterministic_synthesis=args.deterministic_synthesis,
950         )
951
952         train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
953         test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
954
955         log_string(
956             f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
957         )
958
959         time_current_result = datetime.datetime.now()
960         log_string(
961             f"next_result {time_current_result + (time_current_result - time_pred_result)}"
962         )
963         time_pred_result = time_current_result
964
965     checkpoint = {
966         "nb_epochs_finished": n_epoch + 1,
967         "model_state": model.state_dict(),
968         "rng_state": torch.get_rng_state(),
969     }
970
971     if torch.cuda.is_available():
972         checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
973
974     checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
975     torch.save(checkpoint, checkpoint_name)
976     log_string(f"saved checkpoint {checkpoint_name}")
977
978 ######################################################################