Update: draft per-problem sequence generation for the SandBox task in tasks.py; add DiscreteSampler2d and always apply the quantizer in world.py.
author    François Fleuret <francois@fleuret.org>
          Mon, 17 Jul 2023 16:48:43 +0000 (18:48 +0200)
committer François Fleuret <francois@fleuret.org>
          Mon, 17 Jul 2023 16:48:43 +0000 (18:48 +0200)
tasks.py
world.py

index 8b57cb2..5583fc8 100755
--- a/tasks.py
+++ b/tasks.py
@@ -73,8 +73,12 @@ class Problem:
 
 class ProblemByheart(Problem):
     def __init__(self):
-        pass
+        nb_seq, len_prompt, len_result = 100, 5, 5
+        self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
+        self.seq[:, len_prompt] = -1
 
+    def generate_sequences(self, nb):
+        return self.seq[torch.randint(self.seq.size(0), (nb,))]
 
 class SandBox(Task):
     def __init__(
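For reference, a minimal sketch of what the new ProblemByheart yields, assuming tasks.py is importable as a module (the names below come from the patch, the printed values are illustrative): each sampled row is a 5-token prompt, a -1 separator at position len_prompt, and a 5-token result, drawn with replacement from the fixed pool of nb_seq = 100 sequences.

    import torch
    from tasks import ProblemByheart

    p = ProblemByheart()
    seq = p.generate_sequences(4)
    print(seq.size())  # torch.Size([4, 11]), i.e. len_prompt + 1 + len_result
    print(seq[0])      # e.g. tensor([ 3,  7,  1,  0,  9, -1,  2,  2,  5,  8,  4])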
@@ -89,13 +93,23 @@ class SandBox(Task):
 
         self.batch_size = batch_size
 
+        problems = [ProblemByheart()]
+        nb_common_codes = 100
+
         def generate_sequences(nb_samples):
             problem_indexes = torch.randint(len(problems), (nb_samples,))
            nb_samples_per_problem = torch.nn.functional.one_hot(problem_indexes, num_classes=len(problems)).sum(0)
             print(f"{nb_samples_per_problem}")
+            all_seq = []
+            for nb, p in zip(nb_samples_per_problem, problems):
+                all_seq.append(p.generate_sequences(int(nb)))
+            return all_seq
+
+        train_seq = generate_sequences(nb_train_samples)
+        test_seq = generate_sequences(nb_test_samples)
 
-        self.train_input = generate_sequences(nb_train_samples)
-        self.test_input = generate_sequences(nb_test_samples)
+        self.train_input = torch.cat(train_seq, 0)
+        self.test_input = torch.cat(test_seq, 0)
 
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
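A quick sketch of the dispatch logic in the new generate_sequences closure, outside the patch and with hypothetical counts: one_hot plus sum turns the per-sample problem indices into per-problem sample counts, and each problem then generates exactly its share, returned as a list with one tensor per problem.

    import torch
    import torch.nn.functional as F

    nb_samples, nb_problems = 10, 3
    problem_indexes = torch.randint(nb_problems, (nb_samples,))
    nb_samples_per_problem = F.one_hot(problem_indexes, num_classes=nb_problems).sum(0)
    print(nb_samples_per_problem)  # e.g. tensor([4, 3, 3]); always sums to nb_samples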
 
index 64c7434..3d6abbe 100755
--- a/world.py
+++ b/world.py
@@ -61,6 +61,19 @@ class SignSTE(nn.Module):
         else:
             return s
 
+class DiscreteSampler2d(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        s = (x >= x.max(-3, keepdim=True).values).float()
+
+        if self.training:
+            u = x.softmax(dim=-3)
+            return s + u - u.detach()
+        else:
+            return s
+
 
 def loss_H(binary_logits, h_threshold=1):
     p = binary_logits.sigmoid().mean(0)
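DiscreteSampler2d uses the same straight-through pattern as SignSTE above: the forward value is the hard per-pixel one-hot over the channel dimension, while the gradient is that of the channel softmax. A self-contained sketch of the s + u - u.detach() identity, with illustrative shapes:

    import torch

    x = torch.randn(2, 4, 3, 3, requires_grad=True)    # (N, C, H, W)
    w = torch.randn(2, 4, 3, 3)

    s = (x >= x.max(-3, keepdim=True).values).float()  # hard one-hot over channels
    u = x.softmax(dim=-3)
    y = s + u - u.detach()                             # straight-through combination

    print(torch.allclose(y.detach(), s))               # True: the value is discrete
    (y * w).sum().backward()
    g, = torch.autograd.grad((x.softmax(dim=-3) * w).sum(), x)
    print(torch.allclose(x.grad, g))                   # True: the gradient is the softmax's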
@@ -159,7 +172,7 @@ def train_encoder(
         for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
             input = input.to(device)
             z = encoder(input)
-            zq = z if k < 2 else quantizer(z)
+            zq = quantizer(z)
             output = decoder(zq)
 
             output = output.reshape(
@@ -182,7 +195,7 @@ def train_encoder(
         for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"):
             input = input.to(device)
             z = encoder(input)
-            zq = z if k < 1 else quantizer(z)
+            zq = quantizer(z)
             output = decoder(zq)
 
             output = output.reshape(
@@ -440,7 +453,7 @@ if __name__ == "__main__":
         seq2frame,
     ) = create_data_and_processors(
         25000, 1000,
-        nb_epochs=10,
+        nb_epochs=5,
         mode="first_last",
         nb_steps=20,
     )