From 38c69cc69cffd1a54b92bfe993f52aa649afb7d4 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 29 Jul 2022 06:13:14 +0200 Subject: [PATCH] Update. --- main.py | 4 ++-- mygpt.py | 6 +++--- picoclvr.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index bcba9ee..83227bb 100755 --- a/main.py +++ b/main.py @@ -160,8 +160,8 @@ class TaskPicoCLVR(Task): def tensorize(self, descr): token_descr = [ s.strip().split(' ') for s in descr ] l = max([ len(s) for s in token_descr ]) - token_descr = [ [ '' ] * (l - len(s)) + s for s in token_descr ] - #token_descr = [ s + [ '' ] * (l - len(s)) for s in token_descr ] + #token_descr = [ [ '' ] * (l - len(s)) + s for s in token_descr ] + token_descr = [ s + [ '' ] * (l - len(s)) for s in token_descr ] id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ] return torch.tensor(id_descr, device = self.device) diff --git a/mygpt.py b/mygpt.py index d6879dc..9da2e68 100755 --- a/mygpt.py +++ b/mygpt.py @@ -14,7 +14,7 @@ from torch.nn import functional as F ############################## -class Residual(nn.Module): +class WithResidual(nn.Module): def __init__(self, *f): super().__init__() self.f = f[0] if len(f) == 1 else nn.Sequential(*f) @@ -103,7 +103,7 @@ class MyGPT(nn.Module): for _ in range(nb_blocks): trunk_blocks += [ - Residual( + WithResidual( nn.LayerNorm((dim_model,)), QKVAttention( dim_in = dim_model, @@ -113,7 +113,7 @@ class MyGPT(nn.Module): causal = True, attention_dropout = dropout ), ), - Residual( + WithResidual( nn.LayerNorm((dim_model,)), nn.Linear(in_features = dim_model, out_features = dim_hidden), nn.ReLU(), diff --git a/picoclvr.py b/picoclvr.py index f097eb0..8201f5d 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -172,7 +172,7 @@ def descr2properties(descr, height, width): return [] seen[x] = (color_id[x], k // width, k % width) - square_infos = zip(*seen.values()) + square_infos = tuple(zip(*seen.values())) square_c = torch.tensor(square_infos[0]) square_i = torch.tensor(square_infos[1]) square_j = torch.tensor(square_infos[2]) -- 2.20.1