Update
[beaver.git] / beaver.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # torch.backends.cuda.matmul.allow_tf32
9 # torch.autocast(torch.bfloat16)
10
11 import math, sys, argparse, time, tqdm, itertools, os
12
13 import torch, torchvision
14 from torch import nn
15 from torch.nn import functional as F
16
17 import mygpt, tensorstack
18
19 ######################################################################
20
# Run on the GPU when one is present; on CUDA, also let float32 matmuls use
# TF32 for speed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
26
27 ######################################################################
28
# Command-line configuration. Options that are dispatched on by name later
# (optimizer, one-shot input/output) are restricted with `choices`, so a typo
# is rejected at startup rather than after the whole training run.
parser = argparse.ArgumentParser(description="A maze shortest path solving with a GPT.")

parser.add_argument("--log_filename", type=str, default="train.log")

parser.add_argument("--result_dir", type=str, default="results_default")

parser.add_argument("--seed", type=int, default=0)

parser.add_argument("--nb_epochs", type=int, default=25)

parser.add_argument("--nb_train_samples", type=int, default=200000)

parser.add_argument("--nb_test_samples", type=int, default=50000)

parser.add_argument("--batch_size", type=int, default=25)

parser.add_argument(
    "--optim", type=str, default="adam", choices=["sgd", "adam", "adamw"]
)

parser.add_argument("--learning_rate", type=float, default=1e-3)

parser.add_argument(
    "--learning_rate_schedule",
    type=str,
    default="10: 2e-4,20: 4e-5,30: 8e-6",
    help='either "cos" or comma-separated "epoch: rate" pairs',
)

parser.add_argument("--dim_model", type=int, default=512)

parser.add_argument("--dim_keys", type=int, default=64)

parser.add_argument("--dim_hidden", type=int, default=2048)

parser.add_argument("--nb_heads", type=int, default=8)

parser.add_argument("--nb_blocks", type=int, default=12)

parser.add_argument("--dropout", type=float, default=0.1)

parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

parser.add_argument("--random_regression_order", action="store_true", default=False)

parser.add_argument("--no_checkpoint", action="store_true", default=False)

parser.add_argument("--overwrite_results", action="store_true", default=False)

parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")

##############################
# maze options

parser.add_argument("--maze_height", type=int, default=13)

parser.add_argument("--maze_width", type=int, default=21)

parser.add_argument("--maze_nb_walls", type=int, default=15)

##############################
# one-shot prediction

parser.add_argument("--oneshot", action="store_true", default=False)

parser.add_argument(
    "--oneshot_input", type=str, default="head", choices=["head", "deep"]
)

parser.add_argument(
    "--oneshot_output", type=str, default="trace", choices=["policy", "trace"]
)
92
93 ######################################################################
94
args = parser.parse_args()

# Refuse to write into an existing result directory unless explicitly
# allowed. makedirs (unlike mkdir) also creates missing parent directories.
try:
    os.makedirs(args.result_dir)
except FileExistsError:
    if not args.overwrite_results:
        print(f"result directory {args.result_dir} already exists")
        sys.exit(1)

# Append mode: re-running into the same directory keeps the earlier log.
log_file = open(
    os.path.join(args.result_dir, args.log_filename), "a", encoding="utf-8"
)

# A negative seed leaves the RNGs unseeded.
if args.seed >= 0:
    # Stricter determinism switches, intentionally left disabled:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
113
114 ######################################################################
115
116
def log_string(s):
    """Print *s* with a local-time prefix and mirror it to the log file.

    The same timestamped line is appended to the global ``log_file`` when it
    is not None; both the file and stdout are flushed so progress is visible
    immediately.
    """
    line = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) + s

    if log_file is not None:
        log_file.write(f"{line}\n")
        log_file.flush()

    print(line, flush=True)
126
127
# Record the full run configuration at the top of the log.
for name, value in vars(args).items():
    log_string(f"args.{name} {value}")
130
131 ######################################################################
132
133
def generation_order(x, fixed_len):
    """Return per-row indices giving the order in which the tokens of *x* are generated.

    Without ``--random_regression_order`` this is the identity order
    ``0..T-1`` for every row. With it, the first *fixed_len* positions still
    come first and in their natural order, and the remaining positions follow
    in an independent random permutation per row.

    NOTE(review): assumes *x* is 2-D (N, T), as at every call site here.
    """
    if not args.random_regression_order:
        positions = torch.arange(x.size(1), device=x.device)
        return positions.unsqueeze(0).expand(x.size(0), -1)

    # Random keys in [0, 1) for every position; the fixed prefix gets strictly
    # increasing negative keys so that sorting keeps it first and in order.
    keys = torch.rand(x.size(), device=x.device)
    keys[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=x.device)
    return keys.sort(1).indices
144
145
def reorder(x, order, back=False):
    """Permute the tokens of *x* along dim 1 according to *order*.

    *x* is N x T x D1 x ... x Dk and *order* is N x T' (indices into dim 1).

    - ``back=False``: gather, ``out[n, i] = x[n, order[n, i]]`` (T' tokens out).
    - ``back=True``: inverse scatter, ``out[n, order[n, i]] = x[n, i]``. The
      output has the shape of *x*, so each row of *order* must be a
      permutation of ``range(x.size(1))`` for every slot to be written.
    """
    flat = x.reshape(*x.shape[:2], -1)
    index = order.unsqueeze(-1).expand(-1, -1, flat.size(-1))
    if back:
        out = torch.empty_like(flat).scatter_(1, index, flat)
    else:
        out = flat.gather(1, index)
    return out.reshape(*out.shape[:2], *x.shape[2:])
156
157
def shuffle(x, fixed_len):
    """Put *x* in generation order and return ``(reordered_x, order)``.

    ``order`` comes from ``generation_order``; passing it to
    ``reorder(..., back=True)`` restores the original token order.
    """
    order = generation_order(x, fixed_len)
    reordered = reorder(x, order, back=False)
    return reordered, order
161
162
163 ######################################################################
164
# ar_mask is an integer 0/1 tensor of the same shape as input, with 1s on the
# tokens that should be generated. It is used arithmetically (1 - ar_mask), so
# a bool tensor would not work.


def masked_inplace_autoregression(model, batch_size, input, ar_mask, order=None):
    """Overwrite in place the tokens of *input* where *ar_mask* is 1, generating
    them autoregressively with *model*.

    *input* is processed in chunks of *batch_size* rows. For each chunk:

    - The prefix before the first masked column is fed once to fill the
      model's cache.
    - Every column from the first to the last masked one is then produced
      one token at a time.
    - Only masked entries are replaced.

    Tokens are chosen by argmax under ``--deterministic_synthesis``, and
    sampled from the predicted categorical distribution otherwise. Chunks
    with nothing masked are skipped.

    *order*, when given, is split like *input* and forwarded to the model.
    When it is None, the model is called with ``order=None``.
    """
    orders = itertools.repeat(None) if order is None else order.split(batch_size)
    for input_chunk, mask_chunk, order_chunk in zip(
        input.split(batch_size), ar_mask.split(batch_size), orders
    ):
        columns = (mask_chunk.sum(0) > 0).nonzero()
        if columns.numel() == 0:
            # Nothing to generate in this chunk.
            continue
        first, last = int(columns.min()), int(columns.max())
        if first > 0:
            # Needed to initialize the model's cache
            model(mygpt.BracketedSequence(input_chunk, 0, first), order=order_chunk)
        for s in range(first, last + 1):
            output = model(
                mygpt.BracketedSequence(input_chunk, s, 1), order=order_chunk
            ).x
            logits = output[:, s]
            if args.deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            # input_chunk is a view of input, so this writes into input.
            input_chunk[:, s] = (
                mask_chunk[:, s] * t_next + (1 - mask_chunk[:, s]) * input_chunk[:, s]
            )
186
187
188 ######################################################################
189
190
def compute_perplexity(model, task, fixed_len, split="train"):
    """Return exp(mean token cross-entropy) of *model* on *task*'s *split*.

    Each batch is put in generation order (``shuffle``), run through the
    model, and its logits are mapped back to the original token order before
    the loss is computed. Per-batch mean losses are weighted by batch size,
    and the exponent is capped at 100 to avoid overflow.

    Runs without gradients in eval mode. The model's previous train/eval mode
    is restored on exit, even if evaluation raises.
    """
    t = model.training
    model.eval()
    try:
        with torch.no_grad():
            nb_samples, acc_loss = 0, 0.0
            for input in task.batches(split=split):
                input = input.to(device)
                x, order = shuffle(input, fixed_len)
                x = model(mygpt.BracketedSequence(x), order=order).x
                output = reorder(x, order, back=True)
                loss = F.cross_entropy(output.transpose(1, 2), input)
                acc_loss += loss.item() * input.size(0)
                nb_samples += input.size(0)
    finally:
        model.train(t)

    return math.exp(min(100, acc_loss / nb_samples))
210
211
212 ######################################################################
213
214
def oneshot_policy_loss(mazes, output, policies, height, width):
    """Soft-target cross-entropy between the probe's per-cell action logits
    and the target policies.

    Only cells equal to ``maze.v_empty`` contribute. The summed loss is
    divided by the number of such cells. *height* and *width* are unused;
    they keep the signature interchangeable with ``oneshot_trace_loss``.
    """
    empty = (mazes == maze.v_empty).unsqueeze(-1)
    # Swap the last two axes of policies so the action axis lines up with
    # output's last (softmax) axis.
    target = policies.permute(0, 2, 1) * empty
    log_probs = (output * empty).log_softmax(-1)
    return -(log_probs * target).sum() / empty.sum()
220
221
def oneshot_trace_loss(mazes, output, policies, height, width):
    """Mean absolute error between the probe's scalar per-cell output and the
    densities returned by ``maze.stationary_densities`` for the given mazes
    and policies.

    Only cells equal to ``maze.v_empty`` count, both in the sum and in the
    normalizing cell count.
    """
    empty = mazes == maze.v_empty
    density = maze.stationary_densities(
        mazes.view(-1, height, width), policies.view(-1, 4, height, width)
    ).flatten(-2)
    prediction = output.squeeze(-1)
    return ((prediction * empty) - (density * empty)).abs().sum() / empty.sum()
230
231
def _oneshot_gpt_features(gpt, mazes, fixed_len):
    """Return the GPT's features for *mazes*, mapped back to the original token order.

    Runs under ``torch.no_grad()``: the GPT is only a frozen feature extractor
    for the probe, so no graph (and no parameter gradient) is built through it.
    """
    with torch.no_grad():
        x, order = shuffle(mazes, fixed_len)
        x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
        return reorder(x, order, back=True)


def _oneshot_save_image(gpt, model, task, n_epoch):
    """Save an image comparing the probe's predictions with the targets on the first 32 test mazes."""
    fixed_len = task.height * task.width
    mazes = task.test_input[:32, :fixed_len]
    # Zero the policies on wall cells exactly as policy_batches does, so the
    # image shows the same targets the loss was computed against.
    policies = task.test_policies[:32] * (mazes != maze.v_wall)[:, None]

    with torch.no_grad():
        output = model(_oneshot_gpt_features(gpt, mazes, fixed_len))

    if args.oneshot_output == "policy":
        targets = policies.permute(0, 2, 1)
        # 1.0 on cells where the target policy gives zero weight to the
        # predicted (argmax) action, 0.0 elsewhere.
        scores = (
            (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0
        ).float()
    elif args.oneshot_output == "trace":
        targets = maze.stationary_densities(
            mazes.view(-1, task.height, task.width),
            policies.view(-1, 4, task.height, task.width),
        ).flatten(-2)
        scores = output
    else:
        raise ValueError(f"{args.oneshot_output=}")

    filename = f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png"
    maze.save_image(
        os.path.join(args.result_dir, filename),
        mazes=mazes.reshape(-1, task.height, task.width),
        score_paths=scores.reshape(-1, task.height, task.width),
        score_truth=targets.reshape(-1, task.height, task.width),
    )
    log_string(f"wrote {filename}")


def oneshot(gpt, task):
    """Train a small MLP probe on top of the frozen GPT's features.

    The probe predicts, per maze cell, either the target policy
    (``--oneshot_output policy``) or the density computed by
    ``maze.stationary_densities`` (``--oneshot_output trace``). Its input is
    the GPT output selected by ``--oneshot_input``.

    Each epoch logs the train/test losses and writes a visualization. The GPT
    is kept in eval mode and never updated. Its previous train/eval mode is
    restored on exit.

    Raises ValueError for an unknown ``--oneshot_input``/``--oneshot_output``.
    """
    if args.oneshot_input == "head":
        dim_in = args.dim_model
    elif args.oneshot_input == "deep":
        # NOTE(review): assumes mygpt's "deep" mode exposes 2 * dim_model
        # features per block — confirm against mygpt.
        dim_in = args.dim_model * args.nb_blocks * 2
    else:
        raise ValueError(f"{args.oneshot_input=}")

    if args.oneshot_output == "policy":
        dim_out = 4
        compute_loss = oneshot_policy_loss
    elif args.oneshot_output == "trace":
        dim_out = 1
        compute_loss = oneshot_trace_loss
    else:
        raise ValueError(f"{args.oneshot_output=}")

    fixed_len = task.height * task.width

    model = nn.Sequential(
        nn.Linear(dim_in, args.dim_model),
        nn.ReLU(),
        nn.Linear(args.dim_model, args.dim_model),
        nn.ReLU(),
        nn.Linear(args.dim_model, dim_out),
    ).to(device)

    t = gpt.training
    gpt.eval()
    try:
        for n_epoch in range(args.nb_epochs):
            # A fresh optimizer per epoch with the scheduled rate, as in the
            # main training loop.
            learning_rate = learning_rate_schedule[n_epoch]
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

            acc_train_loss, nb_train_samples = 0, 0
            for mazes, policies in task.policy_batches(split="train"):
                output = model(_oneshot_gpt_features(gpt, mazes, fixed_len))
                loss = compute_loss(mazes, output, policies, task.height, task.width)
                acc_train_loss += loss.item() * mazes.size(0)
                nb_train_samples += mazes.size(0)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            acc_test_loss, nb_test_samples = 0, 0
            with torch.no_grad():
                for mazes, policies in task.policy_batches(split="test"):
                    output = model(_oneshot_gpt_features(gpt, mazes, fixed_len))
                    loss = compute_loss(
                        mazes, output, policies, task.height, task.width
                    )
                    acc_test_loss += loss.item() * mazes.size(0)
                    nb_test_samples += mazes.size(0)

            log_string(
                f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}"
            )

            _oneshot_save_image(gpt, model, task, n_epoch)
    finally:
        gpt.train(t)
331
332
333 ######################################################################
334
335
class Task:
    """Minimal task interface; concrete tasks such as TaskMaze override it.

    Every method here is a placeholder that does nothing and returns None.
    """

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Placeholder for iterating over input batches of *split*."""

    def vocabulary_size(self):
        """Placeholder for the number of token codes."""

    def produce_results(self, n_epoch, model):
        """Placeholder for evaluating *model* after epoch *n_epoch*."""
345
346
347 ######################################################################
348
349 import maze
350
351
class TaskMaze(Task):
    """Maze shortest-path task.

    A sample is a maze map followed by its path map. Each map is flattened and
    the two are concatenated into one token sequence (``map2seq``): the first
    ``height * width`` tokens are the maze, the remaining ones the path. The
    model is evaluated on generating the path tokens given the maze tokens.
    The per-sample policies returned by ``maze.create_maze_data`` are kept for
    the one-shot probe (``policy_batches``).
    """

    def map2seq(self, *m):
        """Flatten every tensor of *m* from dim 1 on and concatenate them along dim 1."""
        return torch.cat([x.flatten(1) for x in m], 1)

    def seq2map(self, s):
        """Inverse of ``map2seq``: return a generator of (N, height, width) maps."""
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, train_policies = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
        self.train_policies = train_policies.flatten(-2).to(device)

        test_mazes, test_paths, test_policies = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
        self.test_policies = test_policies.flatten(-2).to(device)

        # Size the vocabulary from both splits so any code seen at test time is
        # representable, and store a plain int rather than a 0-dim tensor.
        self.nb_codes = (
            max(
                int(t.max())
                for t in (self.train_input, self.test_input)
                if t.numel() > 0
            )
            + 1
        )

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield batches of at most ``batch_size`` sequences from *split*.

        Only the first *nb_to_use* sequences are used when it is positive.
        """
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def policy_batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield ``(mazes, policies)`` batches from *split*.

        ``mazes`` is the maze part of each sequence. ``policies`` has its
        entries zeroed on ``maze.v_wall`` cells. Only the first *nb_to_use*
        samples are used when it is positive.
        """
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        policies = self.train_policies if split == "train" else self.test_policies
        input = input[:, : self.height * self.width]

        if nb_to_use > 0:
            input = input[:nb_to_use]
            policies = policies[:nb_to_use]

        # Mask after truncating, so only the rows actually used are multiplied.
        policies = policies * (input != maze.v_wall)[:, None]

        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            zip(input.split(self.batch_size), policies.split(self.batch_size)),
            dynamic_ncols=True,
            desc=desc,
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def _generate_paths(self, model, input):
        """Blank the path part of *input* and regenerate it with *model*.

        Returns a new tensor in the original token order; *input* itself is
        not modified.
        """
        fixed_len = self.height * self.width
        result = input.clone()
        ar_mask = result.new_zeros(result.size())
        ar_mask[:, fixed_len:] = 1
        result *= 1 - ar_mask
        x, order = shuffle(result, fixed_len)
        # ar_mask is built in the original order, but it is also valid for the
        # shuffled x: generation_order keeps the first fixed_len positions
        # first, so in both orders the masked positions are those >= fixed_len.
        masked_inplace_autoregression(model, self.batch_size, x, ar_mask, order=order)
        return reorder(x, order, back=True)

    def compute_error(self, model, split="train", nb_to_use=-1):
        """Return ``(nb_total, nb_correct)`` as ints.

        Counts the sequences of *split* (the first *nb_to_use* when positive)
        whose regenerated path is judged correct by ``maze.path_correctness``.
        """
        nb_total, nb_correct = 0, 0
        for input in self.batches(split, nb_to_use):
            mazes, paths = self.seq2map(self._generate_paths(model, input))
            nb_correct += int(maze.path_correctness(mazes, paths).long().sum())
            nb_total += mazes.size(0)

        return nb_total, nb_correct

    def produce_results(self, n_epoch, model):
        """Log and save path-generation results for epoch *n_epoch*.

        Logs path accuracy on up to 1000 train and 1000 test sequences, then
        saves an image of the first 32 test mazes with their target and
        predicted paths. Runs without gradients in eval mode; the model's
        previous mode is restored on exit, even on error.
        """
        t = model.training
        model.eval()
        try:
            with torch.no_grad():
                train_nb_total, train_nb_correct = self.compute_error(
                    model, "train", nb_to_use=1000
                )
                log_string(
                    f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
                )

                test_nb_total, test_nb_correct = self.compute_error(
                    model, "test", nb_to_use=1000
                )
                log_string(
                    f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
                )

                input = self.test_input[:32]
                result = self._generate_paths(model, input)

                mazes, paths = self.seq2map(input)
                _, predicted_paths = self.seq2map(result)
                filename = f"result_{n_epoch:04d}.png"
                maze.save_image(
                    os.path.join(args.result_dir, filename),
                    mazes=mazes,
                    target_paths=paths,
                    predicted_paths=predicted_paths,
                    path_correct=maze.path_correctness(mazes, predicted_paths),
                )
                log_string(f"wrote {filename}")
        finally:
            model.train(t)
493
494
495 ######################################################################
496
log_string(f"device {device}")

# Generate the train/test maze datasets directly on the compute device.
task = TaskMaze(
    args.nb_train_samples,
    args.nb_test_samples,
    args.batch_size,
    height=args.maze_height,
    width=args.maze_width,
    nb_walls=args.maze_nb_walls,
    device=device,
)

vocabulary_size = task.vocabulary_size()
log_string(f"vocabulary_size {vocabulary_size}")
514
515 ##############################
516
# Causal GPT over the task's token vocabulary, moved to the compute device
# (nn.Module.to returns the module itself).
model = mygpt.MyGPT(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_blocks=args.nb_blocks,
    causal=True,
    dropout=args.dropout,
).to(device)

nb_parameters = sum(map(torch.numel, model.parameters()))
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
532
533 ######################################################################
534
# Resume from the checkpoint in the result directory when one exists.
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string("not trying to load checkpoint.")

else:
    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    try:
        # map_location="cpu" lets a checkpoint written by a GPU run be loaded
        # on a CPU-only machine; load_state_dict then copies the weights onto
        # the model's device, and the RNG states stay CPU byte tensors as
        # torch.set_rng_state requires.
        checkpoint = torch.load(checkpoint_name, map_location="cpu")
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        torch.set_rng_state(checkpoint["rng_state"])
        # A checkpoint written on a machine without CUDA has no CUDA RNG state.
        if torch.cuda.is_available() and "cuda_rng_state" in checkpoint:
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        log_string("starting from scratch.")

    except Exception as e:
        # Report the cause instead of swallowing it (and no longer catch
        # KeyboardInterrupt/SystemExit, as the bare except did).
        log_string(f"error when loading the checkpoint: {e!r}")
        sys.exit(1)
558
559 ######################################################################
560
# Perplexity of the training set's empirical unigram token distribution: a
# baseline for what the model's prediction perplexity should beat.
token_count = sum(
    F.one_hot(batch, num_classes=vocabulary_size).sum((0, 1))
    for batch in task.batches(split="train")
)
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
567
568 ##############################
569
# Build the per-epoch learning rate table from --learning_rate_schedule.
if args.learning_rate_schedule == "cos":
    # Cosine decay from args.learning_rate towards 0 over args.nb_epochs.
    learning_rate_schedule = {
        n_epoch: args.learning_rate
        * 0.5
        * (1 + math.cos(n_epoch / args.nb_epochs * math.pi))
        for n_epoch in range(args.nb_epochs)
    }
else:
    # Comma-separated "epoch: rate" pairs. Each listed rate stays in effect
    # until the next listed epoch; epochs before the first one use
    # args.learning_rate. Empty entries (a trailing comma, or an empty string
    # meaning "constant rate") are ignored instead of crashing the parse.
    u = {}
    for entry in args.learning_rate_schedule.split(","):
        if not entry.strip():
            continue
        k, sep, v = entry.partition(":")
        if not sep:
            raise ValueError(f"invalid learning rate schedule entry {entry!r}")
        u[int(k)] = float(v)

    learning_rate_schedule = {}
    learning_rate = args.learning_rate
    for n_epoch in range(args.nb_epochs):
        if n_epoch in u:
            learning_rate = u[n_epoch]
        learning_rate_schedule[n_epoch] = learning_rate

log_string(f"learning_rate_schedule {learning_rate_schedule}")
591
592 ##############################
593
# Training is already complete (e.g. resumed from a final checkpoint): only
# re-evaluate the loaded model and regenerate its results.
if nb_epochs_finished >= args.nb_epochs:
    n_epoch = nb_epochs_finished
    train_perplexity, test_perplexity = (
        compute_perplexity(
            model, task, fixed_len=task.height * task.width, split=which
        )
        for which in ("train", "test")
    )

    log_string(
        f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
    )

    task.produce_results(n_epoch, model)
608
609 ##############################
610
# Main training loop. It resumes after the last checkpointed epoch and saves
# a checkpoint after every epoch.
for n_epoch in range(nb_epochs_finished, args.nb_epochs):
    learning_rate = learning_rate_schedule[n_epoch]

    log_string(f"learning_rate {learning_rate}")

    # NOTE(review): a new optimizer is built every epoch to apply the
    # scheduled rate. This also resets Adam/AdamW moment estimates each epoch,
    # and optimizer state is not part of the checkpoint.
    if args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    elif args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    else:
        raise ValueError(f"{args.optim=}")

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split="train"):
        input = input.to(device)
        # Feed the sequence in generation order, then map the logits back to
        # the original order so they line up with the targets.
        x, order = shuffle(input, task.height * task.width)
        x = model(mygpt.BracketedSequence(x), order=order).x
        output = reorder(x, order, back=True)
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Train perplexity is accumulated while the weights change and with
    # dropout active, so it is not directly comparable to the test value.
    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
    test_perplexity = compute_perplexity(
        model, task, fixed_len=task.height * task.width, split="test"
    )

    log_string(
        f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
    )

    task.produce_results(n_epoch, model)

    checkpoint = {
        "nb_epochs_finished": n_epoch + 1,
        "model_state": model.state_dict(),
        "rng_state": torch.get_rng_state(),
    }

    if torch.cuda.is_available():
        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()

    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    # Write to a temporary file and rename it over the previous checkpoint.
    # os.replace is atomic within a filesystem, so an interruption mid-save
    # cannot leave a truncated checkpoint that the next run would refuse.
    tmp_checkpoint_name = checkpoint_name + ".tmp"
    torch.save(checkpoint, tmp_checkpoint_name)
    os.replace(tmp_checkpoint_name, checkpoint_name)
    log_string(f"saved checkpoint {checkpoint_name}")
665
666 ######################################################################
667
# Optionally fit the one-shot probe on top of the trained GPT.
if args.oneshot:
    oneshot(gpt=model, task=task)
670
671 ######################################################################