Make SandBox take a Problem instance, report train/test accuracy, and tag VQ-AE log lines.
author     François Fleuret <francois@fleuret.org>
           Tue, 18 Jul 2023 06:44:21 +0000 (08:44 +0200)
committer  François Fleuret <francois@fleuret.org>
           Tue, 18 Jul 2023 06:44:21 +0000 (08:44 +0200)
main.py
tasks.py
world.py

diff --git a/main.py b/main.py
index e18887b..3be3d55 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -266,6 +266,7 @@ picoclvr_pruner_eval = (
 
 if args.task == "sandbox":
     task = tasks.SandBox(
+        tasks.ProblemByheart(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
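
main.py constructs the problem at the call site and passes it through, so any
other Problem subclass could be swapped in the same way. A sketch of such a
swap (OtherProblem is hypothetical, not part of this commit):

task = tasks.SandBox(
    tasks.OtherProblem(),  # hypothetical; any object with generate_sequences(nb)
    nb_train_samples=args.nb_train_samples,
    nb_test_samples=args.nb_test_samples,
    batch_size=args.batch_size,
)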
diff --git a/tasks.py b/tasks.py
index 9cd06ae..eef84af 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -64,10 +64,10 @@ class Task:
 
 
 class Problem:
-    def generate(nb):
+    def generate_sequences(self, nb):
         pass
 
-    def perf(seq, logger):
+    def log_performance(self, sequences, logger):
         pass
 
 
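The renamed methods now take self and fix the interface every problem must
implement. A minimal sketch of a conforming subclass (HypotheticalProblem is
illustrative, not in the commit):

import torch

class HypotheticalProblem(Problem):
    def generate_sequences(self, nb):
        # Return (sequences, ar_mask), both of shape (nb, length);
        # ar_mask is 1 where the model must generate, 0 where the
        # prompt is given.
        sequences = torch.randint(10, (nb, 8))
        ar_mask = torch.zeros_like(sequences)
        ar_mask[:, 4:] = 1  # condition on the first half, generate the second
        return sequences, ar_mask
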
@@ -75,15 +75,19 @@ class ProblemByheart(Problem):
     def __init__(self):
         nb_seq, len_prompt, len_result = 100, 5, 5
         self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
-        self.seq[:, len_prompt] = -1
+        self.seq[:, len_prompt] = 10
 
     def generate_sequences(self, nb):
-        return self.seq[torch.randint(self.seq.size(0), (nb,))]
-
+        sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
+        ar_mask = (sequences == 10).long()
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
 
 class SandBox(Task):
     def __init__(
         self,
+        problem,
         nb_train_samples,
         nb_test_samples,
         batch_size,
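
The two-line mask construction in the hunk above is dense; a worked example
(arbitrary values) makes its effect explicit:

import torch

seq = torch.tensor([[3, 1, 4, 1, 5, 10, 9, 2, 6]])  # prompt, separator 10, result
m = (seq == 10).long()              # tensor([[0, 0, 0, 0, 0, 1, 0, 0, 0]])
m = (m.cumsum(1) - m).clamp(max=1)  # tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1]])
# 1 exactly on the tokens after the separator, i.e. the part to generate.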
@@ -93,24 +111,10 @@ class SandBox(Task):
         super().__init__()
 
         self.batch_size = batch_size
+        self.device = device
 
-        problems = [ProblemByheart()]
-        nb_common_codes = 100
-
-        def generate_sequences(nb_samples):
-            problem_indexes = torch.randint(len(problems), (nb_samples,))
-            nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
-            print(f"{nb_samples_per_problem}")
-            all_seq = []
-            for nb, p in zip(nb_samples_per_problem, problems):
-                all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
-            return all_seq
-
-        train_seq = generate_sequences(nb_train_samples)
-        test_seq = generate_sequences(nb_test_samples)
-
-        for strain, stest in zip(train_seq, test_seq):
-            s = torch.cat((strain, stest), 0)
+        self.train_input, self.train_ar_mask = problem.generate_sequences(nb_train_samples)
+        self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples)
 
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
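Replacing the -1 separator by 10 (first hunk above) matters here: token values
are used as indices into an embedding table, so a -1 id would fall outside it,
and nb_codes is derived from the data rather than hard-coded. A quick check of
the resulting vocabulary size:

import torch

train_input = torch.randint(10, (4, 11))  # digit tokens 0-9
train_input[:, 5] = 10                    # separator token
nb_codes = train_input.max().item() + 1   # 11: ten digits plus the separator
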
@@ -132,11 +136,35 @@ class SandBox(Task):
     def produce_results(
         self, n_epoch, model, result_dir, logger, deterministic_synthesis
     ):
-        # logger(
-        # f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
-        # )
-        pass
 
+        def compute_accuracy(input, ar_mask):
+            result = input.clone() * (1 - ar_mask)
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                progress_bar_desc=None,
+                device=self.device,
+            )
+
+            nb_total = ar_mask.sum().item()
+            nb_correct = ((result == input).long() * ar_mask).sum().item()
+
+            return nb_total, nb_correct
+
+        train_nb_total, train_nb_correct = compute_accuracy(self.train_input, self.train_ar_mask)
+
+        logger(
+            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+        )
+
+        test_nb_total, test_nb_correct = compute_accuracy(self.test_input, self.test_ar_mask)
+
+        logger(
+            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+        )
 
 ######################################################################
 
diff --git a/world.py b/world.py
index b35a08e..12c6553 100755 (executable)
--- a/world.py
+++ b/world.py
@@ -96,8 +96,6 @@ def train_encoder(
     logger=None,
     device=torch.device("cpu"),
 ):
-    if logger is None:
-        logger = lambda s: print(s)
 
     mu, std = train_input.float().mean(), train_input.float().std()
 
@@ -157,7 +155,7 @@ def train_encoder(
 
     nb_parameters = sum(p.numel() for p in model.parameters())
 
-    logger(f"nb_parameters {nb_parameters}")
+    logger(f"vqae nb_parameters {nb_parameters}")
 
     model.to(device)
 
@@ -209,7 +207,7 @@ def train_encoder(
         train_loss = acc_train_loss / train_input.size(0)
         test_loss = acc_test_loss / test_input.size(0)
 
-        logger(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
+        logger(f"vqae train {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
         sys.stdout.flush()
 
     return encoder, quantizer, decoder
@@ -378,6 +376,9 @@ def create_data_and_processors(
     if mode == "first_last":
         steps = [True] + [False] * (nb_steps + 1) + [True]
 
+    if logger is None:
+        logger = lambda s: print(s)
+
     train_input, train_actions = generate_episodes(nb_train_samples, steps)
     train_input, train_actions = train_input.to(device_storage), train_actions.to(
         device_storage
@@ -405,6 +406,8 @@ def create_data_and_processors(
     pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :]
     z_h, z_w = z.size(2), z.size(3)
 
+    logger(f"vqae input {train_input[0].size()} output {z[0].size()}")
+
     def frame2seq(input, batch_size=25):
         seq = []
         p = pow2.to(device)
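
The pow2 tensor prepared above lets frame2seq (truncated here) pack the
per-pixel bit codes produced by the quantizer into single integer tokens. The
axis layout in the real function differs, but the idea, assuming bits in
{0, 1} and a code tensor of shape (N, nb_bits, h, w), is a dot product with
powers of two:

import torch

z = torch.randint(2, (1, 4, 2, 2))                 # (N, nb_bits, h, w)
pow2 = (2 ** torch.arange(z.size(1)))[None, :, None, None]
tokens = (z * pow2).sum(dim=1)                     # (N, h, w), values in 0..15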