[mygptrnn.git] / main.py
#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, argparse, time, tqdm, os, datetime, warnings

import torch, torchvision
from torch import nn
from torch.nn import functional as F

import ffutils
import mygpt, tasks, problems

######################################################################


def str2bool(x):
    x = x.lower()
    if x in {"1", "true", "yes"}:
        return True
    elif x in {"0", "false", "no"}:
        return False
    else:
        raise ValueError


parser = argparse.ArgumentParser(
    description="An implementation of GPT with cache.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
    "--task",
    type=str,
    default="twotargets",
    help="byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
)

parser.add_argument("--log_filename", type=str, default="train.log", help=" ")

parser.add_argument("--result_dir", type=str, default=None)

parser.add_argument("--seed", type=int, default=0)

parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)

parser.add_argument("--force_cpu", type=str2bool, default=False)

########################################

parser.add_argument("--nb_epochs", type=int, default=50)

parser.add_argument("--batch_size", type=int, default=None)

parser.add_argument("--nb_train_samples", type=int, default=None)

parser.add_argument("--nb_test_samples", type=int, default=None)

parser.add_argument("--optim", type=str, default="adam")

########################################

parser.add_argument("--nb_warmup_iter", type=int, default=100)

parser.add_argument("--nb_decay_iter", type=int, default=5000)

parser.add_argument("--learning_rate", type=float, default=6e-4)

parser.add_argument("--min_learning_rate", type=float, default=6e-5)

# legacy

parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True)

parser.add_argument("--legacy_large_lr", type=float, default=1e-4)

parser.add_argument("--legacy_small_lr", type=float, default=2e-5)

parser.add_argument("--legacy_nb_epoch_large_lr", type=float, default=10)

########################################

parser.add_argument("--model", type=str, default=None)

parser.add_argument("--attention", type=str, default=None)

parser.add_argument("--memex_proba", type=float, default=0)

parser.add_argument("--memex_nb_epochs", type=float, default=1)

parser.add_argument("--dim_model", type=int, default=None)

parser.add_argument("--dim_keys", type=int, default=None)

parser.add_argument("--dim_hidden", type=int, default=None)

parser.add_argument("--nb_heads", type=int, default=None)

parser.add_argument("--nb_lines", type=int, default=None)

parser.add_argument("--caterpillar_height", type=int, default=None)

parser.add_argument("--gate_dropout_proba", type=float, default=0.0)

parser.add_argument("--gate_dropout_sync", type=str2bool, default=False)

parser.add_argument("--gate_dropout_replace", type=str2bool, default=False)

parser.add_argument("--rho_inner_loss", type=float, default=0.0)

parser.add_argument("--nb_blocks", type=int, default=None)

parser.add_argument("--dropout", type=float, default=0.1)

########################################

parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

parser.add_argument("--no_checkpoint", action="store_true", default=False)

parser.add_argument("--continue_training", action="store_true", default=False)

parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")

##############################
# rpl options

parser.add_argument("--rpl_nb_starting_values", type=int, default=3)

parser.add_argument("--rpl_max_input", type=int, default=9)

parser.add_argument("--rpl_prog_len", type=int, default=8)

parser.add_argument("--rpl_nb_runs", type=int, default=5)

parser.add_argument("--rpl_no_prog", action="store_true", default=False)

##############################
# grid options

parser.add_argument("--grid_size", type=int, default=6)

parser.add_argument("--grid_nb_colors", type=int, default=6)

parser.add_argument("--grid_nb_shapes", type=int, default=6)

##############################
# picoclvr options

parser.add_argument("--picoclvr_nb_colors", type=int, default=5)

parser.add_argument("--picoclvr_height", type=int, default=12)

parser.add_argument("--picoclvr_width", type=int, default=16)

parser.add_argument("--picocvlr_prune_properties", type=str, default="none")

##############################
# Maze options

parser.add_argument("--maze_height", type=int, default=13)

parser.add_argument("--maze_width", type=int, default=21)

parser.add_argument("--maze_nb_walls", type=int, default=15)

##############################
# Snake options

parser.add_argument("--snake_height", type=int, default=9)

parser.add_argument("--snake_width", type=int, default=12)

parser.add_argument("--snake_nb_colors", type=int, default=5)

parser.add_argument("--snake_length", type=int, default=200)

##############################
# Stack options

parser.add_argument("--stack_nb_steps", type=int, default=100)

parser.add_argument("--stack_nb_stacks", type=int, default=3)

parser.add_argument("--stack_nb_digits", type=int, default=3)

parser.add_argument("--stack_fraction_values_for_train", type=float, default=0.75)

##############################
# Expr options

parser.add_argument("--expr_nb_variables", type=int, default=5)

parser.add_argument("--expr_sequence_length", type=int, default=40)

parser.add_argument("--expr_operand_max", type=int, default=9)

parser.add_argument("--expr_result_max", type=int, default=99)

parser.add_argument("--expr_input_file", type=str, default=None)

##############################
# Memory

parser.add_argument("--memory_len_total", type=int, default=32)

##############################
# Mixing

parser.add_argument("--mixing_hard", action="store_true", default=False)

parser.add_argument("--mixing_deterministic_start", action="store_true", default=False)

######################################################################

# args = parser.parse_args()

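# Any options not declared above are accepted as extra "--key=value" flags and
# collected into the sup_args dict below; e.g. running with --calibrate=1
# yields sup_args == {"calibrate": "1"}, which is what the calibration branch
# further down checks for.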
args, sup_args = parser.parse_known_args()

sup_args = dict([x.removeprefix("--").split("=") for x in sup_args])

if args.result_dir is None:
    args.result_dir = f"results_{args.task}_{args.model}"

######################################################################

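# Run on CUDA when available (unless --force_cpu), and allow TF32 matmuls for
# faster float32 GEMMs on GPUs that support them.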
if not args.force_cpu and torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cuda.matmul.allow_tf32 = True
else:
    device = torch.device("cpu")

######################################################################

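# Per-task defaults: the model size, batch size and numbers of train / test
# samples below are applied only to the arguments left at None on the command
# line.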
default_task_args = {
    "addition": {
        "model": "352M",
        "batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "byheart": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 50000,
        "nb_test_samples": 10000,
    },
    "expr": {
        "model": "352M",
        "batch_size": 25,
        "nb_train_samples": 2500000,
        "nb_test_samples": 10000,
    },
    "grid": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "qmlp": {
        "model": "37M",
        "batch_size": 10,
        "nb_train_samples": 100000,
        "nb_test_samples": 1000,
    },
    "guessop": {
        "model": "352M",
        "batch_size": 25,
        "nb_train_samples": 1000000,
        "nb_test_samples": 10000,
    },
    "learnop": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 50000,
        "nb_test_samples": 10000,
    },
    "maze": {
        "model": "37M",
        "batch_size": 5,
        "nb_train_samples": 100000,
        "nb_test_samples": 10000,
    },
    "picoclvr": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "rpl": {
        "model": "352M",
        "batch_size": 5,
        "nb_train_samples": 2500000,
        "nb_test_samples": 10000,
    },
    "snake": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "stack": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 100000,
        "nb_test_samples": 1000,
    },
    "twotargets": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 50000,
        "nb_test_samples": 10000,
    },
    "memory": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 25000,
        "nb_test_samples": 10000,
    },
    "mixing": {
        "model": "37M",
        "batch_size": 25,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
    "mnist": {
        "model": "37M",
        "batch_size": 10,
        "nb_train_samples": 60000,
        "nb_test_samples": 10000,
    },
}

if args.task in default_task_args:
    for k, v in default_task_args[args.task].items():
        if getattr(args, k) is None:
            setattr(args, k, v)

######################################################################

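# Model presets: each key names an approximate parameter count, and the "-C"
# variants use the "caterpillar" attention layer instead of "mha", adding its
# nb_lines (and, for some presets, caterpillar_height) hyper-parameters.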
default_model_args = {
    "17K": {
        "attention": "mha",
        "dim_model": 32,
        "dim_keys": 32,
        "dim_hidden": 32,
        "nb_heads": 2,
        "nb_blocks": 2,
    },
    "17K-C": {
        "attention": "caterpillar",
        "dim_model": 32,
        "dim_keys": 32,
        "dim_hidden": 32,
        "nb_heads": 2,
        "nb_lines": 16,
        "caterpillar_height": 4,
        "nb_blocks": 2,
    },
    "4M": {
        "attention": "mha",
        "dim_model": 256,
        "dim_keys": 32,
        "dim_hidden": 1024,
        "nb_heads": 4,
        "nb_blocks": 6,
    },
    "4M-C": {
        "attention": "caterpillar",
        "dim_model": 256,
        "dim_keys": 32,
        "dim_hidden": 1024,
        "nb_heads": 4,
        "nb_lines": 32,
        "caterpillar_height": 4,
        "nb_blocks": 6,
    },
    "37M": {
        "attention": "mha",
        "dim_model": 512,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 12,
    },
    "37M-C": {
        "attention": "caterpillar",
        "dim_model": 512,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_lines": 256,
        "caterpillar_height": 32,
        "nb_blocks": 12,
    },
    "122M": {
        "attention": "mha",
        "dim_model": 768,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 24,
    },
    "122M-C": {
        "attention": "caterpillar",
        "dim_model": 768,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_lines": 128,
        "nb_blocks": 24,
    },
    "352M": {
        "attention": "mha",
        "dim_model": 1024,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 48,
    },
    "352M-C": {
        "attention": "caterpillar",
        "dim_model": 1024,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_lines": 128,
        "nb_blocks": 48,
    },
}

if args.model in default_model_args:
    for k, v in default_model_args[args.model].items():
        if getattr(args, k) is None:
            setattr(args, k, v)
else:
    raise ValueError(f"Unknown model {args.model}")

######################################################################

try:
    os.mkdir(args.result_dir)
except FileExistsError:
    if not args.continue_training:
        print(f"result directory {args.result_dir} already exists")
        exit(1)

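# loss.dat gets one line per training batch, appended by the training loop
# below: epoch, batch index, cross-entropy loss, gradient norm.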
loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")

log_file = open(os.path.join(args.result_dir, args.log_filename), "a")

if args.seed >= 0:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

######################################################################


def log_string(s):
    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())

    if log_file is not None:
        log_file.write(t + s + "\n")
        log_file.flush()

    print(t + s)
    sys.stdout.flush()


with os.popen("sha256sum *.py") as f:
    for l in f:
        log_string(f"sha256sum {l.strip()}")

now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")

log_string(f"argv {' '.join(sys.argv)}")

for n in vars(args):
    log_string(f"args.{n} {getattr(args, n)}")

for k, v in sup_args.items():
    log_string(f'sup_args["{k}"] "{v}"')


######################################################################


def get_lr(n_epoch, it):
    if args.legacy_lr_schedule:
        # my crude scheduling to compare to previous baseline, added
        # warmup though

        if it < args.nb_warmup_iter:
            return args.legacy_large_lr * it / args.nb_warmup_iter
        elif n_epoch < args.legacy_nb_epoch_large_lr:
            return args.legacy_large_lr
        else:
            return args.legacy_small_lr

    # from nanoGPT

    # 1) linear warmup for warmup_iter steps
    if it < args.nb_warmup_iter:
        return args.learning_rate * it / args.nb_warmup_iter
    # 2) if it > nb_decay_iter, return min learning rate
    if it > args.nb_decay_iter:
        return args.min_learning_rate
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - args.nb_warmup_iter) / (
        args.nb_decay_iter - args.nb_warmup_iter
    )
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return args.min_learning_rate + coeff * (
        args.learning_rate - args.min_learning_rate
    )


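# Example values with the nanoGPT-style schedule (--legacy_lr_schedule 0) and
# the default hyper-parameters:
#   it = 50    -> 6e-4 * 50 / 100 = 3e-4   (linear warmup)
#   it = 100   -> 6e-4                     (end of warmup)
#   it = 2550  -> ~3.3e-4                  (midpoint of the cosine decay)
#   it > 5000  -> 6e-5                     (min learning rate)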
######################################################################


assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}


def picoclvr_pruner_horizontal_green(p):
    return not ("green" in p and ("left" in p or "right" in p))


picoclvr_pruner_train = (
    picoclvr_pruner_horizontal_green
    if args.picocvlr_prune_properties in {"train+eval"}
    else None
)

picoclvr_pruner_eval = (
    (lambda p: not picoclvr_pruner_horizontal_green(p))
    if args.picocvlr_prune_properties in {"train+eval", "eval"}
    else None
)

######################################################################

device_data = device

if args.task == "byheart":
    task = tasks.SandBox(
        problem=problems.ProblemByHeart(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device_data,
    )
    args.max_percents_of_test_in_train = -1

elif args.task == "learnop":
    task = tasks.SandBox(
        problem=problems.ProblemLearnOperator(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device_data,
    )


elif args.task == "guessop":
    task = tasks.SandBox(
        problem=problems.ProblemGuessOperator(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device_data,
    )


elif args.task == "twotargets":
    task = tasks.SandBox(
        problem=problems.ProblemTwoTargets(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "memory":
    task = tasks.SandBox(
        problem=problems.ProblemMemory(len_total=args.memory_len_total),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "mixing":
    task = tasks.SandBox(
        problem=problems.ProblemMixing(
            hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
        ),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "addition":
    task = tasks.SandBox(
        problem=problems.ProblemAddition(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        device=device_data,
    )

elif args.task == "picoclvr":
    task = tasks.PicoCLVR(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.picoclvr_height,
        width=args.picoclvr_width,
        nb_colors=args.picoclvr_nb_colors,
        logger=log_string,
        device=device_data,
        pruner_train=picoclvr_pruner_train,
        pruner_eval=picoclvr_pruner_eval,
    )

elif args.task == "mnist":
    task = tasks.MNIST(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        device=device_data,
    )

elif args.task == "maze":
    task = tasks.Maze(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.maze_height,
        width=args.maze_width,
        nb_walls=args.maze_nb_walls,
        device=device_data,
    )

elif args.task == "snake":
    task = tasks.Snake(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        height=args.snake_height,
        width=args.snake_width,
        nb_colors=args.snake_nb_colors,
        length=args.snake_length,
        prompt_length=args.snake_length // 2,
        device=device_data,
    )

elif args.task == "stack":
    task = tasks.Stack(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        logger=log_string,
        nb_steps=args.stack_nb_steps,
        nb_stacks=args.stack_nb_stacks,
        nb_digits=args.stack_nb_digits,
        fraction_values_for_train=args.stack_fraction_values_for_train,
        device=device_data,
    )

elif args.task == "expr":
    task = tasks.Expr(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        nb_variables=args.expr_nb_variables,
        sequence_length=args.expr_sequence_length,
        operand_max=args.expr_operand_max,
        result_max=args.expr_result_max,
        batch_size=args.batch_size,
        device=device_data,
    )

elif args.task == "rpl":
    task = tasks.RPL(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        nb_starting_values=args.rpl_nb_starting_values,
        max_input=args.rpl_max_input,
        prog_len=args.rpl_prog_len,
        nb_runs=args.rpl_nb_runs,
        no_prog=args.rpl_no_prog,
        logger=log_string,
        device=device_data,
    )

elif args.task == "grid":
    task = tasks.Grid(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        size=args.grid_size,
        nb_shapes=args.grid_nb_shapes,
        nb_colors=args.grid_nb_colors,
        logger=log_string,
        device=device_data,
    )

elif args.task == "qmlp":
    task = tasks.QMLP(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
        result_dir=args.result_dir,
        logger=log_string,
        device=device_data,
    )

else:
    raise ValueError(f"Unknown task {args.task}")

######################################################################

log_string(f"device {device}")

vocabulary_size = task.vocabulary_size()

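# With --memex_proba > 0, one extra token (index vocabulary_size - 1) is
# reserved; it is used as the separator inserted by add_memex() in the
# training loop below.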
if args.memex_proba > 0:
    vocabulary_size += 1

log_string(f"vocabulary_size {vocabulary_size}")

##############################

model = mygpt.MyGPT(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_lines=args.nb_lines,
    caterpillar_height=args.caterpillar_height,
    nb_blocks=args.nb_blocks,
    causal=True,
    dropout=args.dropout,
    attention_layer=args.attention,
    logger=log_string,
    args=args,
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")

######################################################################

nb_epochs_finished = 0

if args.no_checkpoint:
    log_string(f"not trying to load checkpoint.")

else:
    try:
        checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
        checkpoint = torch.load(checkpoint_name)
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        torch.set_rng_state(checkpoint["rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        log_string("starting from scratch.")

    except:
        log_string("error when loading the checkpoint.")
        exit(1)

######################################################################

if args.task == "expr" and args.expr_input_file is not None:
    task.produce_results(
        n_epoch=nb_epochs_finished,
        model=model,
        result_dir=args.result_dir,
        logger=log_string,
        deterministic_synthesis=args.deterministic_synthesis,
        input_file=args.expr_input_file,
    )

    exit(0)

######################################################################

assert args.nb_epochs > 0
nb_epochs = args.nb_epochs

# Compute the entropy of the training tokens
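# With p the empirical unigram distribution of the training tokens, the
# entropy is H = -sum_i p_i log p_i, and exp(H) is the perplexity of a model
# predicting tokens independently with those frequencies; it is logged later
# as the "train_set" perplexity, as a reference for the model's own train /
# test perplexities.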

token_count = 0
for input in task.batches(split="train"):
    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)

######################################################################
# A bit of paranoia never hurts
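# The check below counts how many test sequences also appear verbatim in the
# train split, comparing sequences as tuples of token values in chunks of
# 25,000 to bound memory, and aborts if the overlap exceeds
# --max_percents_of_test_in_train percent.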

if args.max_percents_of_test_in_train >= 0:

    def subsets_as_tuples(batches, cs):
        s = set()
        for batch in batches:
            for x in batch:
                s.add(tuple([v.item() for v in x]))
                if len(s) == cs:
                    yield s
                    s = set()
        yield s

    nb_test, nb_in_train = 0, 0
    for test_subset in subsets_as_tuples(task.batches(split="test"), 25000):
        in_train = set()
        for train_subset in subsets_as_tuples(task.batches(split="train"), 25000):
            in_train.update(test_subset.intersection(train_subset))
        nb_in_train += len(in_train)
        nb_test += len(test_subset)

    log_string(
        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
    )

    assert (
        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"

##############################

if "calibrate" in sup_args:
    for input in task.batches(split="train", desc="calibrate"):
        input = input.to(device)
        output = model(mygpt.BracketedSequence(input)).x

    for n, m in model.named_modules():
        for a in dir(m):
            x = getattr(m, a)
            if isinstance(x, mygpt.Calibrator):
                print(f"####### {n} | {a} ########################")
                mean, std = x.moments()
                print("mean\n", mean, "\n")
                print("std\n", std, "\n")
                print(f"############################################\n\n")

    exit(0)

##############################

nb_samples_seen = 0

if nb_epochs_finished >= nb_epochs:
    task.produce_results(
        n_epoch=nb_epochs_finished,
        model=model,
        result_dir=args.result_dir,
        logger=log_string,
        deterministic_synthesis=args.deterministic_synthesis,
    )

time_pred_result = datetime.datetime.now()

it = 0

n_batch = 0

for n_epoch in range(nb_epochs_finished, nb_epochs):
    if args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
    elif args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    else:
        raise ValueError(f"Unknown optimizer {args.optim}.")

    model.train()

    nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0

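    # With probability memex_proba, a training batch is also presented as
    # [input, sep, input] along the sequence dimension, sep being the reserved
    # extra token, i.e. the same sequences repeated after a separator.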
    def add_memex(batches, memex_proba):
        for input in batches:
            if torch.rand(1).item() < memex_proba:
                sep = torch.full(
                    (input.size(0), 1), vocabulary_size - 1, device=input.device
                )

                yield torch.cat(
                    [
                        input,
                        sep,
                        input,
                    ],
                    dim=1,
                )
            yield input

    train_batches = add_memex(
        task.batches(split="train"),
        args.memex_proba if n_epoch < args.memex_nb_epochs else 0.0,
    )

    for input in train_batches:
        model.reset_inner_loss()
        input = input.to(device)

        output = model(mygpt.BracketedSequence(input)).x
        loss = F.cross_entropy(output.transpose(1, 2), input)
        inner_loss = model.get_inner_loss()

        acc_train_loss += loss.item() * input.size(0)
        acc_train_inner_loss += inner_loss.item() * input.size(0)

        nb_train_samples += input.size(0)
        nb_samples_seen += input.size(0)

        total_loss = loss + (
            args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
        )

        it += 1
        lr = get_lr(n_epoch, it)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # log_string(f"learning_rate {lr}")

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()

        loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")

        n_batch += 1

    with torch.autograd.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split="test"):
            input = input.to(device)

            output = model(mygpt.BracketedSequence(input)).x
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        log_string(
            f"loss {n_epoch} train_loss {acc_train_loss/nb_train_samples} train_inner_loss {acc_train_inner_loss/nb_train_samples} test_prediction {acc_test_loss/nb_test_samples}"
        )

        task.produce_results(
            n_epoch=n_epoch,
            model=model,
            result_dir=args.result_dir,
            logger=log_string,
            deterministic_synthesis=args.deterministic_synthesis,
        )

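        # Perplexities are exp of the average per-token cross-entropy; the
        # value is clamped at 100 before exponentiation only to avoid
        # overflow in math.exp.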
        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(
            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
        )

        time_current_result = datetime.datetime.now()
        log_string(
            f"next_result {time_current_result + (time_current_result - time_pred_result)}"
        )
        time_pred_result = time_current_result

    checkpoint = {
        "nb_epochs_finished": n_epoch + 1,
        "model_state": model.state_dict(),
        "rng_state": torch.get_rng_state(),
    }

    if torch.cuda.is_available():
        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()

    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    torch.save(checkpoint, checkpoint_name)
    log_string(f"saved checkpoint {checkpoint_name}")

######################################################################