From a92a5ca00f4277f7a133fa6cfaada2bc1981f524 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 17 Jul 2023 12:57:36 +0200 Subject: [PATCH] Update. --- main.py | 23 ++++++++++++++--- tasks.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ world.py | 43 +++++++++++++++++++++++-------- 3 files changed, 128 insertions(+), 15 deletions(-) diff --git a/main.py b/main.py index 305bd3c..69ee58f 100755 --- a/main.py +++ b/main.py @@ -34,8 +34,8 @@ parser = argparse.ArgumentParser( parser.add_argument( "--task", type=str, - default="picoclvr", - help="picoclvr, mnist, maze, snake, stack, expr, world", + default="sandbox", + help="sandbox, picoclvr, mnist, maze, snake, stack, expr, world", ) parser.add_argument("--log_filename", type=str, default="train.log", help=" ") @@ -150,6 +150,12 @@ if args.result_dir is None: ###################################################################### default_args = { + "sandbox": { + "nb_epochs": 10, + "batch_size": 25, + "nb_train_samples": 25000, + "nb_test_samples": 10000, + }, "picoclvr": { "nb_epochs": 25, "batch_size": 25, @@ -189,7 +195,7 @@ default_args = { "world": { "nb_epochs": 10, "batch_size": 25, - "nb_train_samples": 125000, + "nb_train_samples": 25000, "nb_test_samples": 1000, }, } @@ -257,7 +263,16 @@ picoclvr_pruner_eval = ( ###################################################################### -if args.task == "picoclvr": +if args.task == "sandbox": + task = tasks.SandBox( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device, + ) + +elif args.task == "picoclvr": task = tasks.PicoCLVR( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, diff --git a/tasks.py b/tasks.py index f8fb9b9..8b57cb2 100755 --- a/tasks.py +++ b/tasks.py @@ -60,6 +60,69 @@ class Task: pass +###################################################################### + + +class Problem: + def generate(nb): + pass + + def perf(seq, logger): + pass + + +class ProblemByheart(Problem): + def __init__(self): + pass + + +class SandBox(Task): + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + logger=None, + device=torch.device("cpu"), + ): + super().__init__() + + self.batch_size = batch_size + + def generate_sequences(nb_samples): + problem_indexes = torch.randint(len(problems), (nb_samples,)) + nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) + print(f"{nb_samples_per_problem}") + + self.train_input = generate_sequences(nb_train_samples) + self.test_input = generate_sequences(nb_test_samples) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + # logger( + # f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + # ) + pass + + ###################################################################### import picoclvr @@ -108,6 +171,8 @@ class PicoCLVR(Task): pruner_train=None, pruner_eval=None, ): + super().__init__() + def generate_descr(nb, cache_suffix, pruner): return picoclvr.generate( nb, @@ -296,6 +361,8 @@ class MNIST(Task): def __init__( self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu") ): + super().__init__() + self.nb_train_samples = (nb_train_samples,) self.nb_test_samples = (nb_test_samples,) self.batch_size = batch_size @@ -366,6 +433,8 @@ class Maze(Task): nb_walls, device=torch.device("cpu"), ): + super().__init__() + self.batch_size = batch_size self.height = height self.width = width @@ -537,6 +606,8 @@ class Snake(Task): prompt_length, device=torch.device("cpu"), ): + super().__init__() + self.batch_size = batch_size self.height = height self.width = width @@ -635,6 +706,8 @@ class Stack(Task): fraction_values_for_train=None, device=torch.device("cpu"), ): + super().__init__() + self.batch_size = batch_size self.nb_steps = nb_steps self.nb_stacks = nb_stacks @@ -782,6 +855,8 @@ class Expr(Task): batch_size, device=torch.device("cpu"), ): + super().__init__() + self.batch_size = batch_size self.device = device @@ -961,6 +1036,8 @@ class World(Task): device=torch.device("cpu"), device_storage=torch.device("cpu"), ): + super().__init__() + self.batch_size = batch_size self.device = device diff --git a/world.py b/world.py index fa305cf..da7de75 100755 --- a/world.py +++ b/world.py @@ -62,12 +62,20 @@ class SignSTE(nn.Module): return s +def loss_H(binary_logits, h_threshold=1): + p = binary_logits.sigmoid().mean(0) + h = (-p.xlogy(p) - (1 - p).xlogy(1 - p)) / math.log(2) + h.clamp_(max=h_threshold) + return h_threshold - h.mean() + + def train_encoder( train_input, test_input, depth=2, dim_hidden=48, nb_bits_per_token=8, + lambda_entropy=0.0, lr_start=1e-3, lr_end=1e-4, nb_epochs=10, @@ -160,6 +168,9 @@ def train_encoder( train_loss = F.cross_entropy(output, input) + if lambda_entropy > 0: + loss = loss + lambda_entropy * loss_H(z, h_threshold=0.5) + acc_train_loss += train_loss.item() * input.size(0) optimizer.zero_grad() @@ -238,7 +249,7 @@ def scene2tensor(xh, yh, scene, size): ) -def random_scene(): +def random_scene(nb_insert_attempts=3): scene = [] colors = [ ((Box.nb_rgb_levels - 1), 0, 0), @@ -252,7 +263,7 @@ def random_scene(): ), ] - for k in range(10): + for k in range(nb_insert_attempts): wh = torch.rand(2) * 0.2 + 0.2 xy = torch.rand(2) * (1 - wh) c = colors[torch.randint(len(colors), (1,))] @@ -286,14 +297,15 @@ def generate_episode(steps, size=64): xh, yh = tuple(x.item() for x in torch.rand(2)) actions = torch.randint(len(effects), (len(steps),)) - change = False + nb_changes = 0 for s, a in zip(steps, actions): if s: frames.append(scene2tensor(xh, yh, scene, size=size)) - g, dx, dy = effects[a] - if g: + grasp, dx, dy = effects[a] + + if grasp: for b in scene: if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh: x, y = b.x, b.y @@ -310,7 +322,7 @@ def generate_episode(steps, size=64): else: xh += dx yh += dy - change = True + nb_changes += 1 else: x, y = xh, yh xh += dx @@ -318,7 +330,7 @@ def generate_episode(steps, size=64): if xh < 0 or xh > 1 or yh < 0 or yh > 1: xh, yh = x, y - if change: + if nb_changes > len(steps) // 3: break return frames, actions @@ -352,12 +364,21 @@ def create_data_and_processors( steps = [True] + [False] * (nb_steps + 1) + [True] train_input, train_actions = generate_episodes(nb_train_samples, steps) - train_input, train_actions = train_input.to(device_storage), train_actions.to(device_storage) + train_input, train_actions = train_input.to(device_storage), train_actions.to( + device_storage + ) test_input, test_actions = generate_episodes(nb_test_samples, steps) - test_input, test_actions = test_input.to(device_storage), test_actions.to(device_storage) + test_input, test_actions = test_input.to(device_storage), test_actions.to( + device_storage + ) encoder, quantizer, decoder = train_encoder( - train_input, test_input, nb_epochs=nb_epochs, logger=logger, device=device + train_input, + test_input, + lambda_entropy=1.0, + nb_epochs=nb_epochs, + logger=logger, + device=device, ) encoder.train(False) quantizer.train(False) @@ -371,7 +392,7 @@ def create_data_and_processors( seq = [] p = pow2.to(device) for x in input.split(batch_size): - x=x.to(device) + x = x.to(device) z = encoder(x) ze_bool = (quantizer(z) >= 0).long() output = ( -- 2.20.1