Heavy fix.
author Francois Fleuret <francois@fleuret.org>
Fri, 16 Jun 2017 08:12:13 +0000 (10:12 +0200)
committer Francois Fleuret <francois@fleuret.org>
Fri, 16 Jun 2017 08:12:13 +0000 (10:12 +0200)
cnn-svrt.py
vignette_set.py

index 7bef242..694f035 100755 (executable)
@@ -184,15 +184,19 @@ for problem_number in range(1, 24):
     for p in model.parameters(): nb_parameters += p.numel()
     log_string('nb_parameters {:d}'.format(nb_parameters))
 
     for p in model.parameters(): nb_parameters += p.numel()
     log_string('nb_parameters {:d}'.format(nb_parameters))
 
+    need_to_train = False
     try:
     try:
-
         model.load_state_dict(torch.load(model_filename))
         log_string('loaded_model ' + model_filename)
         model.load_state_dict(torch.load(model_filename))
         log_string('loaded_model ' + model_filename)
-
     except:
     except:
+        need_to_train = True
+
+    if need_to_train:
 
         log_string('training_model ' + model_filename)
 
 
         log_string('training_model ' + model_filename)
 
+        t = time.time()
+
         if args.compress_vignettes:
             train_set = CompressedVignetteSet(problem_number,
                                               args.nb_train_batches, args.batch_size,
         if args.compress_vignettes:
             train_set = CompressedVignetteSet(problem_number,
                                               args.nb_train_batches, args.batch_size,
@@ -208,6 +212,10 @@ for problem_number in range(1, 24):
                                    args.nb_test_batches, args.batch_size,
                                    cuda=torch.cuda.is_available())
 
                                    args.nb_test_batches, args.batch_size,
                                    cuda=torch.cuda.is_available())
 
+        log_string('data_generation {:0.2f} samples / s'.format(
+            (train_set.nb_samples + test_set.nb_samples) / (time.time() - t))
+        )
+
         train_model(model, train_set)
         torch.save(model.state_dict(), model_filename)
         log_string('saved_model ' + model_filename)
         train_model(model, train_set)
         torch.save(model.state_dict(), model_filename)
         log_string('saved_model ' + model_filename)
index 72880ba..c46beea 100755 (executable)
@@ -32,11 +32,12 @@ import svrt
 ######################################################################
 
 def generate_one_batch(s):
 ######################################################################
 
 def generate_one_batch(s):
-    svrt.seed(s)
-    target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
+    problem_number, batch_size, cuda, random_seed = s
+    svrt.seed(random_seed)
+    target = torch.LongTensor(batch_size).bernoulli_(0.5)
     input = svrt.generate_vignettes(problem_number, target)
     input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
     input = svrt.generate_vignettes(problem_number, target)
     input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
-    if self.cuda:
+    if cuda:
         input = input.cuda()
         target = target.cuda()
     return [ input, target ]
         input = input.cuda()
         target = target.cuda()
     return [ input, target ]
@@ -50,13 +51,16 @@ class VignetteSet:
         self.nb_batches = nb_batches
         self.nb_samples = self.nb_batches * self.batch_size
 
         self.nb_batches = nb_batches
         self.nb_samples = self.nb_batches * self.batch_size
 
-        seed_list = torch.LongTensor(self.nb_batches).random_().tolist()
+        seeds = torch.LongTensor(self.nb_batches).random_()
+        mp_args = []
+        for b in range(0, self.nb_batches):
+            mp_args.append( [ problem_number, batch_size, cuda, seeds[b] ])
 
         # self.data = []
         # for b in range(0, self.nb_batches):
 
         # self.data = []
         # for b in range(0, self.nb_batches):
-            # self.data.append(generate_one_batch(seed_list[b]))
+            # self.data.append(generate_one_batch(mp_args[b]))
 
 
-        self.data = Pool(cpu_count()).map(generate_one_batch, seed_list)
+        self.data = Pool(cpu_count()).map(generate_one_batch, mp_args)
 
         acc = 0.0
         acc_sq = 0.0
 
         acc = 0.0
         acc_sq = 0.0