Update.
[mygpt.git] / main.py
diff --git a/main.py b/main.py
index 3bf7587..11cf0a3 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -25,7 +25,7 @@ parser.add_argument('--log_filename',
                     type = str, default = 'train.log')
 
 parser.add_argument('--download',
-                    type = bool, default = False)
+                    action='store_true', default = False)
 
 parser.add_argument('--seed',
                     type = int, default = 0)
@@ -67,7 +67,13 @@ parser.add_argument('--dropout',
                     type = float, default = 0.1)
 
 parser.add_argument('--synthesis_sampling',
-                    type = bool, default = True)
+                    action='store_true', default = True)
+
+parser.add_argument('--checkpoint_name',
+                    type = str, default = 'checkpoint.pth')
+
+parser.add_argument('--picoclvr_many_colors',
+                    action='store_true', default = False)
 
 ######################################################################
 
@@ -350,7 +356,7 @@ if args.data == 'wiki103':
 elif args.data == 'mnist':
     task = TaskMNIST(batch_size = args.batch_size, device = device)
 elif args.data == 'picoclvr':
-    task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
+    task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device)
 else:
     raise ValueError(f'Unknown dataset {args.data}.')
 
@@ -366,11 +372,11 @@ model = mygpt.MyGPT(
     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
 )
 
+model.to(device)
+
 nb_parameters = sum(p.numel() for p in model.parameters())
 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
 
-model.to(device)
-
 ######################################################################
 
 if args.optim == 'sgd':
@@ -382,7 +388,27 @@ elif args.optim == 'adamw':
 else:
     raise ValueError(f'Unknown optimizer {args.optim}.')
 
-for k in range(args.nb_epochs):
+######################################################################
+
+nb_epochs_finished = 0
+
+try:
+    checkpoint = torch.load(args.checkpoint_name, map_location = device)
+    nb_epochs_finished = checkpoint['nb_epochs_finished']
+    model.load_state_dict(checkpoint['model_state'])
+    optimizer.load_state_dict(checkpoint['optimizer_state'])
+    print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
+
+except FileNotFoundError:
+    print('Starting from scratch.')
+
+except:
+    print('Error when loading the checkpoint.')
+    exit(1)
+
+######################################################################
+
+for k in range(nb_epochs_finished, args.nb_epochs):
 
     model.train()
 
@@ -419,4 +445,12 @@ for k in range(args.nb_epochs):
 
         task.produce_results(k, model)
 
+    checkpoint = {
+        'nb_epochs_finished': k + 1,
+        'model_state': model.state_dict(),
+        'optimizer_state': optimizer.state_dict()
+    }
+
+    torch.save(checkpoint, args.checkpoint_name)
+
 ######################################################################