From 3dea181a5903a0e577e4830c66405b40f2a2df1d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 17 Jul 2023 18:48:43 +0200 Subject: [PATCH] Update. --- tasks.py | 20 +++++++++++++++++--- world.py | 19 ++++++++++++++++--- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/tasks.py b/tasks.py index 8b57cb2..5583fc8 100755 --- a/tasks.py +++ b/tasks.py @@ -73,8 +73,12 @@ class Problem: class ProblemByheart(Problem): def __init__(self): - pass + nb_seq, len_prompt, len_result = 100, 5, 5 + self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) + self.seq[:,len_prompt]=-1 + def generate_sequences(self, nb): + return self.seq[torch.randint(self.seq.size(0), (nb,))] class SandBox(Task): def __init__( @@ -89,13 +93,23 @@ class SandBox(Task): self.batch_size = batch_size + problems = [ ProblemByheart() ] + nb_common_codes = 100 + def generate_sequences(nb_samples): problem_indexes = torch.randint(len(problems), (nb_samples,)) nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) print(f"{nb_samples_per_problem}") + all_seq = [] + for nb, p in zip(nb_samples_per_problem,problems): + all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) + return all_seq + + train_seq = generate_sequences(nb_train_samples) + test_seq = generate_sequences(nb_test_samples) - self.train_input = generate_sequences(nb_train_samples) - self.test_input = generate_sequences(nb_test_samples) + for strain, stest in zip(train_seq, test_seq): + s = torch.cat((strain,stest),0) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 diff --git a/world.py b/world.py index 64c7434..3d6abbe 100755 --- a/world.py +++ b/world.py @@ -61,6 +61,19 @@ class SignSTE(nn.Module): else: return s +class DiscreteSampler2d(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + s = (x >= x.max(-3,keepdim=True).values).float() + + if self.training: + u = x.softmax(dim=-3) + return s + u - u.detach() + else: + return s + def loss_H(binary_logits, h_threshold=1): p = binary_logits.sigmoid().mean(0) @@ -159,7 +172,7 @@ def train_encoder( for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"): input = input.to(device) z = encoder(input) - zq = z if k < 2 else quantizer(z) + zq = quantizer(z) output = decoder(zq) output = output.reshape( @@ -182,7 +195,7 @@ def train_encoder( for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"): input = input.to(device) z = encoder(input) - zq = z if k < 1 else quantizer(z) + zq = quantizer(z) output = decoder(zq) output = output.reshape( @@ -440,7 +453,7 @@ if __name__ == "__main__": seq2frame, ) = create_data_and_processors( 25000, 1000, - nb_epochs=10, + nb_epochs=5, mode="first_last", nb_steps=20, ) -- 2.20.1