Update.
author	François Fleuret <francois@fleuret.org>
Sun, 9 Jul 2023 11:01:55 +0000 (13:01 +0200)
committer	François Fleuret <francois@fleuret.org>
Sun, 9 Jul 2023 11:01:55 +0000 (13:01 +0200)
world.py

index fb5d5c7..e76c07f 100755 (executable)
--- a/world.py
+++ b/world.py
@@ -203,12 +203,12 @@ def patchify(x, factor, invert_size=None):
     if invert_size is None:
         return (
             x.reshape(
-                x.size(0), #0
-                x.size(1), #1
-                factor, #2
-                x.size(2) // factor,#3
-                factor,#4
-                x.size(3) // factor,#5
+                x.size(0),  # 0
+                x.size(1),  # 1
+                factor,  # 2
+                x.size(2) // factor,  # 3
+                factor,  # 4
+                x.size(3) // factor,  # 5
             )
             .permute(0, 2, 4, 1, 3, 5)
             .reshape(-1, x.size(1), x.size(2) // factor, x.size(3) // factor)
@@ -216,18 +216,86 @@ def patchify(x, factor, invert_size=None):
     else:
         return (
             x.reshape(
-                invert_size[0], #0
-                factor, #1
-                factor, #2
-                invert_size[1], #3
-                invert_size[2] // factor, #4
-                invert_size[3] // factor, #5
+                invert_size[0],  # 0
+                factor,  # 1
+                factor,  # 2
+                invert_size[1],  # 3
+                invert_size[2] // factor,  # 4
+                invert_size[3] // factor,  # 5
             )
             .permute(0, 3, 1, 4, 2, 5)
             .reshape(invert_size)
         )
 
 
+def train_encoder(input, device=torch.device("cpu")):
+    class SomeLeNet(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
+            self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
+            self.fc1 = nn.Linear(256, 200)
+            self.fc2 = nn.Linear(200, 10)
+
+        def forward(self, x):
+            x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3))
+            x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2))
+            x = x.view(x.size(0), -1)
+            x = F.relu(self.fc1(x))
+            x = self.fc2(x)
+            return x
+
+    ######################################################################
+
+    model = SomeLeNet()
+
+    nb_parameters = sum(p.numel() for p in model.parameters())
+
+    print(f"nb_parameters {nb_parameters}")
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+    criterion = nn.CrossEntropyLoss()
+
+    model.to(device)
+    criterion.to(device)
+
+    train_input, train_targets = train_input.to(device), train_targets.to(device)
+    test_input, test_targets = test_input.to(device), test_targets.to(device)
+
+    mu, std = train_input.mean(), train_input.std()
+    train_input.sub_(mu).div_(std)
+    test_input.sub_(mu).div_(std)
+
+    start_time = time.perf_counter()
+
+    for k in range(nb_epochs):
+        acc_loss = 0.0
+
+        for input, targets in zip(
+            train_input.split(batch_size), train_targets.split(batch_size)
+        ):
+            output = model(input)
+            loss = criterion(output, targets)
+            acc_loss += loss.item()
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+        nb_test_errors = 0
+        for input, targets in zip(
+            test_input.split(batch_size), test_targets.split(batch_size)
+        ):
+            wta = model(input).argmax(1)
+            nb_test_errors += (wta != targets).long().sum()
+        test_error = nb_test_errors / test_input.size(0)
+        duration = time.perf_counter() - start_time
+
+        print(f"loss {k} {duration:.02f}s {acc_loss:.02f} {test_error*100:.02f}%")
+
+
+######################################################################
+
 if __name__ == "__main__":
     import time
 
@@ -241,13 +309,14 @@ if __name__ == "__main__":
     print(f"{nb / (end_time - start_time):.02f} samples per second")
 
     input = torch.cat(all_frames, 0)
-    x = patchify(input, 8)
-    y = x.reshape(x.size(0), -1)
-    print(f"{x.size()=} {y.size()=}")
-    centroids, t = kmeans(y, 4096)
-    results = centroids[t]
-    results = results.reshape(x.size())
-    results = patchify(results, 8, input.size())
+
+    # x = patchify(input, 8)
+    # y = x.reshape(x.size(0), -1)
+    # print(f"{x.size()=} {y.size()=}")
+    # centroids, t = kmeans(y, 4096)
+    # results = centroids[t]
+    # results = results.reshape(x.size())
+    # results = patchify(results, 8, input.size())
 
     print(f"{input.size()=} {results.size()=}")