Update.
[mygpt.git] / main.py
diff --git a/main.py b/main.py
index b6eb6fe..aa1b517 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -216,17 +216,11 @@ class TaskPicoCLVR(Task):
     def vocabulary_size(self):
         return len(self.token2id)
 
-    def produce_results(self, n_epoch, model):
+    def test_model(self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False):
         nb_tokens_to_generate = self.height * self.width + 3
         result_descr = [ ]
-        nb_per_primer = 8
 
-        for primer_descr in [
-                'red above green <sep> green top <sep> blue right of red <img>',
-                'there is red <sep> there is yellow <sep> there is blue <img>',
-                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
-                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
-        ]:
+        for primer_descr in primers_descr:
 
             results = autoregression(
                 model,
@@ -249,18 +243,57 @@ class TaskPicoCLVR(Task):
 
         log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
 
-        img = [
-            picoclvr.descr2img(d, height = self.height, width = self.width)
-            for d in result_descr
+        np=torch.tensor(np)
+        count=torch.empty(np[:,0].max()+1,np[:,2].max()+1,dtype=torch.int64)
+        for i in range(count.size(0)):
+            for j in range(count.size(1)):
+                count[i,j]=((np[:,0]==i).long()*(np[:,2]==j).long()).sum()
+
+        if generate_images:
+            img = [
+                picoclvr.descr2img(d, height = self.height, width = self.width)
+                for d in result_descr
+            ]
+
+            img = torch.cat(img, 0)
+            image_name = f'result_picoclvr_{n_epoch:04d}.png'
+            torchvision.utils.save_image(
+                img / 255.,
+                image_name, nrow = nb_per_primer, pad_value = 0.8
+            )
+            log_string(f'wrote {image_name}')
+
+        return count
+
+    def produce_results(self, n_epoch, model):
+        primers_descr = [
+            'red above green <sep> green top <sep> blue right of red <img>',
+            'there is red <sep> there is yellow <sep> there is blue <img>',
+            'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
+            'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
         ]
 
-        img = torch.cat(img, 0)
-        image_name = f'result_picoclvr_{n_epoch:04d}.png'
-        torchvision.utils.save_image(
-            img / 255.,
-            image_name, nrow = nb_per_primer, pad_value = 0.8
+        self.test_model(
+            n_epoch, model,
+            primers_descr,
+            nb_per_primer=8, generate_images=True
         )
-        log_string(f'wrote {image_name}')
+
+        # FAR TOO SLOW!!!
+
+        # test_primers_descr=[ s.split('<img>')[0] for s in self.test_descr ]
+
+        # count=self.test_model(
+            # n_epoch, model,
+            # test_primers_descr,
+            # nb_per_primer=1, generate_images=False
+        # )
+
+        # with open(f'perf_{n_epoch:04d}.txt', 'w') as f:
+            # for i in range(count.size(0)):
+                # for j in range(count.size(1)):
+                    # f.write(f'{count[i,j]}')
+                    # f.write(" " if j<count.size(1)-1 else "\n")
 
 ######################################################################