Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 13 Oct 2023 11:51:34 +0000 (13:51 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 13 Oct 2023 11:51:34 +0000 (13:51 +0200)
qmlp.py [moved from qmlp.y with 54% similarity]
tasks.py

diff --git a/qmlp.y b/qmlp.py
similarity index 54%
rename from qmlp.y
rename to qmlp.py
index 7b97edb..e12f0e1 100755 (executable)
--- a/qmlp.y
+++ b/qmlp.py
@@ -39,40 +39,24 @@ def dequantize(q, xmin, xmax):
 ######################################################################
 
 
-def create_model():
-    hidden_dim = 32
-
-    model = nn.Sequential(
-        nn.Linear(2, hidden_dim),
-        nn.ReLU(),
-        nn.Linear(hidden_dim, hidden_dim),
-        nn.ReLU(),
-        nn.Linear(hidden_dim, 2),
-    )
-
-    return model
-
-
-######################################################################
 
 
 def generate_sets_and_params(
-    nb_mlps,
+    batch_nb_mlps,
     nb_samples,
     batch_size,
     nb_epochs,
     device=torch.device("cpu"),
     print_log=False,
 ):
-    data_input = torch.zeros(nb_mlps, 2 * nb_samples, 2, device=device)
+    data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device)
     data_targets = torch.zeros(
-        nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
+        batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
     )
 
     while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
         i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
         nb = i.sum()
-        print(f"{nb=}")
 
         nb_rec = 2
         support = torch.rand(nb, nb_rec, 2, 3, device=device) * 2 - 1
@@ -108,10 +92,10 @@ def generate_sets_and_params(
     test_targets = test_targets
 
     hidden_dim = 32
-    w1 = torch.randn(nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
-    b1 = torch.zeros(nb_mlps, hidden_dim, device=device)
-    w2 = torch.randn(nb_mlps, 2, hidden_dim, device=device) / math.sqrt(hidden_dim)
-    b2 = torch.zeros(nb_mlps, 2, device=device)
+    w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
+    b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device)
+    w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(hidden_dim)
+    b2 = torch.zeros(batch_nb_mlps, 2, device=device)
 
     w1.requires_grad_()
     b1.requires_grad_()
@@ -158,13 +142,13 @@ def generate_sets_and_params(
         # print(f"{k=} {acc_train_loss=} {train_error=}")
 
     q_params = torch.cat(
-        [quantize(p.view(nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
+        [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
     )
     q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape(
-        nb_mlps, -1
+        batch_nb_mlps, -1
     )
     q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape(
-        nb_mlps, -1
+        batch_nb_mlps, -1
     )
 
     return q_train_set, q_test_set, q_params
@@ -173,51 +157,59 @@ def generate_sets_and_params(
 ######################################################################
 
 
-def evaluate_q_params(q_params, q_set, batch_size=25, device=torch.device("cpu")):
+def evaluate_q_params(q_params, q_set, batch_size=25, device=torch.device("cpu"), nb_mlps_per_batch=1024):
+
+    errors = []
     nb_mlps = q_params.size(0)
-    hidden_dim = 32
-    w1 = torch.empty(nb_mlps, hidden_dim, 2, device=device)
-    b1 = torch.empty(nb_mlps, hidden_dim, device=device)
-    w2 = torch.empty(nb_mlps, 2, hidden_dim, device=device)
-    b2 = torch.empty(nb_mlps, 2, device=device)
-
-    with torch.no_grad():
-        k = 0
-        for p in [w1, b1, w2, b2]:
-            print(f"{p.size()=}")
-            x = dequantize(q_params[:, k : k + p.numel() // nb_mlps], -2, 2).view(
-                p.size()
-            )
-            p.copy_(x)
-            k += p.numel() // nb_mlps
 
-    q_set = q_set.view(nb_mlps, -1, 3)
-    data_input = dequantize(q_set[:, :, :2], -1, 1).to(device)
-    data_targets = q_set[:, :, 2].to(device)
+    for n in range(0,nb_mlps,nb_mlps_per_batch):
+        batch_nb_mlps = min(nb_mlps_per_batch,nb_mlps-n)
+        batch_q_params = q_params[n:n+batch_nb_mlps]
+        batch_q_set = q_set[n:n+batch_nb_mlps]
+        hidden_dim = 32
+        w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device)
+        b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device)
+        w2 = torch.empty(batch_nb_mlps, 2, hidden_dim, device=device)
+        b2 = torch.empty(batch_nb_mlps, 2, device=device)
 
-    print(f"{data_input.size()=} {data_targets.size()=}")
+        with torch.no_grad():
+            k = 0
+            for p in [w1, b1, w2, b2]:
+                print(f"{p.size()=}")
+                x = dequantize(batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2).view(
+                    p.size()
+                )
+                p.copy_(x)
+                k += p.numel() // batch_nb_mlps
 
-    criterion = nn.CrossEntropyLoss()
-    criterion.to(device)
+        batch_q_set = batch_q_set.view(batch_nb_mlps, -1, 3)
+        data_input = dequantize(batch_q_set[:, :, :2], -1, 1).to(device)
+        data_targets = batch_q_set[:, :, 2].to(device)
+
+        print(f"{data_input.size()=} {data_targets.size()=}")
+
+        criterion = nn.CrossEntropyLoss()
+        criterion.to(device)
+
+        acc_loss = 0.0
+        nb_errors = 0
 
-    acc_loss = 0.0
-    nb_errors = 0
+        for input, targets in zip(
+            data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1)
+        ):
+            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
+            h = F.relu(h)
+            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
+            loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
+            acc_loss += loss.item() * input.size(0)
+            wta = output.argmax(-1)
+            nb_errors += (wta != targets).long().sum(-1)
 
-    for input, targets in zip(
-        data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1)
-    ):
-        h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
-        h = F.relu(h)
-        output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
-        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
-        acc_loss += loss.item() * input.size(0)
-        wta = output.argmax(-1)
-        nb_errors += (wta != targets).long().sum(-1)
+        errors.append(nb_errors / data_input.size(1))
+        acc_loss = acc_loss / data_input.size(1)
 
-    error = nb_errors / data_input.size(1)
-    acc_loss = acc_loss / data_input.size(1)
 
-    return error
+    return torch.cat(errors)
 
 
 ######################################################################
@@ -229,40 +221,41 @@ def generate_sequence_and_test_set(
     batch_size,
     nb_epochs,
     device,
+    nb_mlps_per_batch=1024,
 ):
-    q_train_set, q_test_set, q_params = generate_sets_and_params(
-        nb_mlps,
-        nb_samples,
-        batch_size,
-        nb_epochs,
-        device=device,
-    )
 
-    input = torch.cat(
-        [
-            q_train_set,
-            q_train_set.new_full(
-                (
-                    q_train_set.size(0),
-                    1,
+    inputs, q_test_sets = [],[]
+
+    for n in range(0,nb_mlps,nb_mlps_per_batch):
+        q_train_set, q_test_set, q_params = generate_sets_and_params(
+            batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n),
+            nb_samples=nb_samples,
+            batch_size=batch_size,
+            nb_epochs=nb_epochs,
+            device=device,
+        )
+
+        inputs.append(torch.cat(
+            [
+                q_train_set,
+                q_train_set.new_full(
+                    (
+                        q_train_set.size(0),
+                        1,
+                    ),
+                    nb_quantization_levels,
                 ),
-                nb_quantization_levels,
-            ),
-            q_params,
-        ],
-        dim=-1,
-    )
+                q_params,
+            ],
+            dim=-1,
+        ))
 
-    print(f"SANITY #1 {q_train_set.size()=} {q_params.size()=} {input.size()=}")
+        q_test_sets.append(q_test_set)
 
-    ar_mask = (
-        (torch.arange(input.size(0), device=input.device) > q_train_set.size(0) + 1)
-        .long()
-        .view(1, -1)
-        .reshape(nb_mlps, -1)
-    )
+    input = torch.cat(inputs)
+    q_test_set = torch.cat(q_test_sets)
 
-    return input, ar_mask, q_test_set
+    return input, q_test_set
 
 
 ######################################################################
@@ -270,7 +263,7 @@ def generate_sequence_and_test_set(
 if __name__ == "__main__":
     import time
 
-    nb_mlps, nb_samples = 128, 200
+    batch_nb_mlps, nb_samples = 128, 500
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -278,26 +271,22 @@ if __name__ == "__main__":
 
     data = []
 
-    for n in range(2):
-        data.append(
-            generate_sequence_and_test_set(
-                nb_mlps=nb_mlps,
-                nb_samples=nb_samples,
-                device=device,
-                batch_size=25,
-                nb_epochs=250,
-            )
-        )
+    input, q_test_set = generate_sequence_and_test_set(
+        nb_mlps=batch_nb_mlps,
+        nb_samples=nb_samples,
+        device=device,
+        batch_size=25,
+        nb_epochs=250,
+        nb_mlps_per_batch=17
+    )
 
     end_time = time.perf_counter()
-    nb = sum([i.size(0) for i, _, _ in data])
-    print(f"{nb / (end_time - start_time):.02f} samples per second")
-
-    for input, ar_mask, q_test_set in data:
-        q_train_set = input[:, : nb_samples * 3]
-        q_params = input[:, nb_samples * 3 + 1 :]
-        print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {input.size()=}")
-        error_train = evaluate_q_params(q_params, q_train_set)
-        print(f"train {error_train*100}%")
-        error_test = evaluate_q_params(q_params, q_test_set)
-        print(f"test {error_test*100}%")
+    print(f"{input.size(0) / (end_time - start_time):.02f} samples per second")
+
+    q_train_set = input[:, : nb_samples * 3]
+    q_params = input[:, nb_samples * 3 + 1 :]
+    print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {input.size()=}")
+    error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17)
+    print(f"train {error_train*100}%")
+    error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
+    print(f"test {error_test*100}%")
index 183c3cf..ea10d7c 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1550,3 +1550,105 @@ class Grid(Task):
 
 
 ######################################################################
+
+import qmlp
+
+
+class QMLP(Task):
+
+    ######################
+
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        logger=None,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        self.device = device
+        self.batch_size = batch_size
+
+        if logger is not None:
+            logger(
+                f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
+            )
+
+        self.train_descr = self.grid_factory.generate_samples(
+            nb_train_samples, lambda r: tqdm.tqdm(r)
+        )
+        self.test_descr = self.grid_factory.generate_samples(
+            nb_test_samples, lambda r: tqdm.tqdm(r)
+        )
+
+        # Build the tokenizer
+        tokens = set()
+        for d in [self.train_descr, self.test_descr]:
+            for s in d:
+                for t in s.strip().split(" "):
+                    tokens.add(t)
+        # make this set a sorted list to get the same tensors given
+        # the same descr
+        tokens = list(tokens)
+        tokens.sort()
+        tokens = ["#"] + tokens
+        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
+        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
+        self.t_nul = self.token2id["#"]
+        self.t_true = self.token2id["true"]
+        self.t_false = self.token2id["false"]
+
+        # Tokenize the train and test sets
+        self.train_input = self.str2tensor(self.train_descr)
+        self.test_input = self.str2tensor(self.test_descr)
+
+    def batches(self, split="train"):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+        ):
+            yield self.trim(batch)
+
+    def vocabulary_size(self):
+        return len(self.token2id)
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        correct = self.test_input[:1000]
+        result = correct.clone()
+        ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
+        result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
+
+        logger(f"----------------------------------------------------------")
+
+        for e in self.tensor2str(result[:10]):
+            logger(f"test_before {e}")
+
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            device=self.device,
+        )
+
+        logger(f"----------------------------------------------------------")
+
+        for e in self.tensor2str(result[:10]):
+            logger(f"test_after  {e}")
+
+        logger(f"----------------------------------------------------------")
+
+        nb_total = ar_mask.sum().item()
+        nb_correct = ((correct == result).long() * ar_mask).sum().item()
+
+        logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
+        logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
+
+
+######################################################################