From 9e62722596c40655041a0a812512115f1036c6fc Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sun, 9 Jul 2023 13:01:55 +0200
Subject: [PATCH] Update.

---
 world.py | 118 ++++++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 98 insertions(+), 20 deletions(-)

diff --git a/world.py b/world.py
index fb5d5c7..e76c07f 100755
--- a/world.py
+++ b/world.py
@@ -203,12 +203,12 @@ def patchify(x, factor, invert_size=None):
     if invert_size is None:
         return (
             x.reshape(
-                x.size(0), #0
-                x.size(1), #1
-                factor, #2
-                x.size(2) // factor,#3
-                factor,#4
-                x.size(3) // factor,#5
+                x.size(0),  # 0
+                x.size(1),  # 1
+                factor,  # 2
+                x.size(2) // factor,  # 3
+                factor,  # 4
+                x.size(3) // factor,  # 5
             )
             .permute(0, 2, 4, 1, 3, 5)
             .reshape(-1, x.size(1), x.size(2) // factor, x.size(3) // factor)
@@ -216,18 +216,95 @@ def patchify(x, factor, invert_size=None):
     else:
         return (
             x.reshape(
-                invert_size[0], #0
-                factor, #1
-                factor, #2
-                invert_size[1], #3
-                invert_size[2] // factor, #4
-                invert_size[3] // factor, #5
+                invert_size[0],  # 0
+                factor,  # 1
+                factor,  # 2
+                invert_size[1],  # 3
+                invert_size[2] // factor,  # 4
+                invert_size[3] // factor,  # 5
             )
             .permute(0, 3, 1, 4, 2, 5)
             .reshape(invert_size)
         )
 
 
+def train_encoder(
+    train_input,
+    train_targets,
+    test_input,
+    test_targets,
+    lr,
+    nb_epochs,
+    batch_size,
+    device=torch.device("cpu"),
+):
+    class SomeLeNet(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
+            self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
+            self.fc1 = nn.Linear(256, 200)
+            self.fc2 = nn.Linear(200, 10)
+
+        def forward(self, x):
+            x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3))
+            x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2))
+            x = x.view(x.size(0), -1)
+            x = F.relu(self.fc1(x))
+            x = self.fc2(x)
+            return x
+
+    ######################################################################
+
+    model = SomeLeNet()
+
+    nb_parameters = sum(p.numel() for p in model.parameters())
+
+    print(f"nb_parameters {nb_parameters}")
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+    criterion = nn.CrossEntropyLoss()
+
+    model.to(device)
+    criterion.to(device)
+
+    train_input, train_targets = train_input.to(device), train_targets.to(device)
+    test_input, test_targets = test_input.to(device), test_targets.to(device)
+
+    mu, std = train_input.mean(), train_input.std()
+    train_input.sub_(mu).div_(std)
+    test_input.sub_(mu).div_(std)
+
+    start_time = time.perf_counter()
+
+    for k in range(nb_epochs):
+        acc_loss = 0.0
+
+        for input, targets in zip(
+            train_input.split(batch_size), train_targets.split(batch_size)
+        ):
+            output = model(input)
+            loss = criterion(output, targets)
+            acc_loss += loss.item()
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+        nb_test_errors = 0
+        for input, targets in zip(
+            test_input.split(batch_size), test_targets.split(batch_size)
+        ):
+            wta = model(input).argmax(1)
+            nb_test_errors += (wta != targets).long().sum()
+        test_error = nb_test_errors / test_input.size(0)
+        duration = time.perf_counter() - start_time
+
+        print(f"loss {k} {duration:.02f}s {acc_loss:.02f} {test_error*100:.02f}%")
+
+
+######################################################################
+
 if __name__ == "__main__":
     import time
 
@@ -241,13 +318,14 @@ if __name__ == "__main__":
     print(f"{nb / (end_time - start_time):.02f} samples per second")
 
     input = torch.cat(all_frames, 0)
-    x = patchify(input, 8)
-    y = x.reshape(x.size(0), -1)
-    print(f"{x.size()=} {y.size()=}")
-    centroids, t = kmeans(y, 4096)
-    results = centroids[t]
-    results = results.reshape(x.size())
-    results = patchify(results, 8, input.size())
+
+    # x = patchify(input, 8)
+    # y = x.reshape(x.size(0), -1)
+    # print(f"{x.size()=} {y.size()=}")
+    # centroids, t = kmeans(y, 4096)
+    # results = centroids[t]
+    # results = results.reshape(x.size())
+    # results = patchify(results, 8, input.size())
 
-    print(f"{input.size()=} {results.size()=}")
+    print(f"{input.size()=}")
 
-- 
2.20.1
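
As a sanity check on the reformatted patchify, a minimal round-trip sketch
(assuming world.py imports cleanly, which the __main__ guard above suggests,
and that height and width are divisible by the factor):

    import torch

    from world import patchify

    x = torch.randn(2, 3, 16, 16)  # batch, channels, height, width
    p = patchify(x, 8)  # -> (2 * 8 * 8, 3, 2, 2): an 8x8 grid of 2x2 patches per frame
    x2 = patchify(p, 8, invert_size=x.size())  # inverse permute/reshape reassembles exactly
    assert torch.equal(x, x2)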
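
For reference, the pipeline that this commit comments out is a patch-level
vector quantization: cut the frames into patches, cluster the flattened
patches with k-means, snap every patch to its nearest centroid, and
reassemble the frames. The sketch below assumes world.py still defines
kmeans with the (centroids, assignments) return convention used above; 64
centroids stand in for the original 4096 so the toy data suffices:

    import torch

    from world import kmeans, patchify

    input = torch.randn(10, 3, 16, 16)  # stand-in for the generated frames
    x = patchify(input, 8)  # (10 * 8 * 8, 3, 2, 2)
    y = x.reshape(x.size(0), -1)  # one flat vector per patch
    centroids, t = kmeans(y, 64)  # codebook and assignment of each patch
    results = centroids[t].reshape(x.size())  # replace patches by their centroids
    results = patchify(results, 8, input.size())  # rebuild full-size frames
    print(f"{input.size()=} {results.size()=}")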
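
The new train_encoder is not called anywhere yet. A hypothetical smoke test,
written as it would sit at the end of the __main__ block (where import time
has already run); the 1x28x28 input shape is an assumption, being the size
for which the conv/pool/linear dimensions line up (5x5 conv, pool 3, 5x5
conv, pool 2 gives 64 * 2 * 2 = 256 features), and the hyper-parameter
values are illustrative only:

    # hypothetical call, appended inside the __main__ block of world.py
    train_input = torch.randn(1000, 1, 28, 28)  # MNIST-like stand-in data
    train_targets = torch.randint(10, (1000,))
    test_input = torch.randn(200, 1, 28, 28)
    test_targets = torch.randint(10, (200,))

    train_encoder(
        train_input,
        train_targets,
        test_input,
        test_targets,
        lr=1e-2,  # illustrative values, not from the patch
        nb_epochs=2,
        batch_size=100,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )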