#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

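# Toy 1D experiment: the input is a noisy sequence made of two triangles and
# two boxes with random positions, heights, and widths; the target is the same
# scene with the heights averaged per shape kind (or per left/right group with
# --group_by_locations). A convolutional regressor is trained on this task,
# optionally with a single self-attention layer and a positional encoding.
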
import torch, math, sys, argparse

from torch import nn, einsum
from torch.nn import functional as F

import matplotlib.pyplot as plt

######################################################################

parser = argparse.ArgumentParser(description="Toy attention model.")

parser.add_argument("--nb_epochs", type=int, default=250)

parser.add_argument(
    "--with_attention",
    help="Use the model with an attention layer",
    action="store_true",
    default=False,
)

parser.add_argument(
    "--group_by_locations",
    help="Use the task where the grouping is location-based",
    action="store_true",
    default=False,
)

parser.add_argument(
    "--positional_encoding",
    help="Provide a positional encoding",
    action="store_true",
    default=False,
)

parser.add_argument(
    "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
)

args = parser.parse_args()

if args.seed >= 0:
    torch.manual_seed(args.seed)

######################################################################

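# Output files are tagged with the configuration: wa_ = with attention,
# lg_ = location-based grouping, pe_ = positional encoding.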
label = ""

if args.with_attention:
    label = "wa_"

if args.group_by_locations:
    label += "lg_"

if args.positional_encoding:
    label += "pe_"

log_file = open(f"att1d_{label}train.log", "w")

######################################################################


def log_string(s):
    if log_file is not None:
        log_file.write(s + "\n")
        log_file.flush()
    print(s)
    sys.stdout.flush()


######################################################################

if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")

######################################################################

seq_height_min, seq_height_max = 1.0, 25.0
seq_width_min, seq_width_max = 5.0, 11.0
seq_length = 100


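# Rasterize triangles (tr) and boxes (bx), each given as (position, height,
# width) triples of shape (nb, 2, 3), into sequences of length seq_length.
# Returns the per-position maximum over the four shape profiles plus uniform
# noise, and a per-sample boolean flag set when two shapes overlap or touch
# (a "collision").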
def positions_to_sequences(tr=None, bx=None, noise_level=0.3):
    st = torch.arange(seq_length, device=device).float()
    st = st[None, :, None]
    tr = tr[:, None, :, :]
    bx = bx[:, None, :, :]

    xtr = torch.relu(
        tr[..., 1]
        - torch.relu(torch.abs(st - tr[..., 0]) - 0.5) * 2 * tr[..., 1] / tr[..., 2]
    )
    xbx = (
        torch.sign(
            torch.relu(
                bx[..., 1] - torch.abs((st - bx[..., 0]) * 2 * bx[..., 1] / bx[..., 2])
            )
        )
        * bx[..., 1]
    )

    x = torch.cat((xtr, xbx), 2)

    u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size=2, stride=1).permute(
        0, 2, 1
    )

    collisions = (u.sum(2) > 1).max(1).values
    y = x.max(2).values

    return y + torch.rand_like(y) * noise_level - noise_level / 2, collisions


######################################################################


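# Draw nb valid samples. Each sample has two triangles and two boxes with
# random positions, heights, and widths. The input is the noisy rasterization
# of the shapes; the target is the same scene with the heights replaced by
# their mean, either per shape kind (default) or per left/right group
# (--group_by_locations). Samples with colliding shapes or with too little
# height contrast are rejected, and the function recurses to replace them.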
def generate_sequences(nb):
    # Position / height / width

    tr = torch.empty(nb, 2, 3, device=device)
    tr[:, :, 0].uniform_(seq_width_max / 2, seq_length - seq_width_max / 2)
    tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
    tr[:, :, 2].uniform_(seq_width_min, seq_width_max)

    bx = torch.empty(nb, 2, 3, device=device)
    bx[:, :, 0].uniform_(seq_width_max / 2, seq_length - seq_width_max / 2)
    bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
    bx[:, :, 2].uniform_(seq_width_min, seq_width_max)

    if args.group_by_locations:
        a = torch.cat((tr, bx), 1)
        v = a[:, :, 0].sort(1).values[:, 2:3]
        mask_left = (a[:, :, 0] < v).float()
        h_left = (a[:, :, 1] * mask_left).sum(1) / 2
        h_right = (a[:, :, 1] * (1 - mask_left)).sum(1) / 2
        valid = (h_left - h_right).abs() > 4
    else:
        # The targets below average the two triangle heights and the two box
        # heights, so require a minimum contrast within each pair.
        valid = (torch.abs(tr[:, 0, 1] - tr[:, 1, 1]) > 4) & (
            torch.abs(bx[:, 0, 1] - bx[:, 1, 1]) > 4
        )

    input, collisions = positions_to_sequences(tr, bx)

    if args.group_by_locations:
        a = torch.cat((tr, bx), 1)
        v = a[:, :, 0].sort(1).values[:, 2:3]
        mask_left = (a[:, :, 0] < v).float()
        h_left = (a[:, :, 1] * mask_left).sum(1, keepdim=True) / 2
        h_right = (a[:, :, 1] * (1 - mask_left)).sum(1, keepdim=True) / 2
        a[:, :, 1] = mask_left * h_left + (1 - mask_left) * h_right
        tr, bx = a.split(2, 1)
    else:
        tr[:, :, 1:2] = tr[:, :, 1:2].mean(1, keepdim=True)
        bx[:, :, 1:2] = bx[:, :, 1:2].mean(1, keepdim=True)

    targets, _ = positions_to_sequences(tr, bx)

    valid = valid & ~collisions
    tr = tr[valid]
    bx = bx[valid]
    input = input[valid][:, None, :]
    targets = targets[valid][:, None, :]

    if input.size(0) < nb:
        input2, targets2, tr2, bx2 = generate_sequences(nb - input.size(0))
        input = torch.cat((input, input2), 0)
        targets = torch.cat((targets, targets2), 0)
        tr = torch.cat((tr, tr2), 0)
        bx = torch.cat((bx, bx2), 0)

    return input, targets, tr, bx


######################################################################


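# Plot the given sequences (each a (values, color, legend label) triple) on a
# common axis, optionally mark triangle centers (^) and box centers (s) below
# the curves, and save the figure to filename.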
def save_sequence_images(filename, sequences, tr=None, bx=None):
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)

    ax.set_xlim(0, seq_length)
    ax.set_ylim(-1, seq_height_max + 4)

    for u in sequences:
        ax.plot(torch.arange(u[0].size(0)) + 0.5, u[0], color=u[1], label=u[2])

    ax.legend(frameon=False, loc="upper left")

    delta = -1.0
    if tr is not None:
        ax.scatter(
            tr[:, 0].cpu(),
            torch.full((tr.size(0),), delta),
            color="black",
            marker="^",
            clip_on=False,
        )

    if bx is not None:
        ax.scatter(
            bx[:, 0].cpu(),
            torch.full((bx.size(0),), delta),
            color="black",
            marker="s",
            clip_on=False,
        )

    fig.savefig(filename, bbox_inches="tight")

    plt.close("all")


######################################################################


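# Single-head self-attention over 1D feature maps. Queries, keys, and values
# are computed with 1x1 convolutions; A[n, t, s] = softmax_s(Q[n, :, t] . K[n, :, s]),
# and the output at position t is the A-weighted average of V over the source
# positions s (no 1/sqrt(d) scaling in this toy version).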
class AttentionLayer(nn.Module):
    def __init__(self, in_channels, out_channels, key_channels):
        super().__init__()
        self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size=1, bias=False)
        self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size=1, bias=False)
        self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        Q = self.conv_Q(x)
        K = self.conv_K(x)
        V = self.conv_V(x)
        A = einsum("nct,ncs->nts", Q, K).softmax(2)
        y = einsum("nts,ncs->nct", A, V)
        return y

    def __repr__(self):
        return (
            self._get_name()
            + "(in_channels={}, out_channels={}, key_channels={})".format(
                self.conv_Q.in_channels,
                self.conv_V.out_channels,
                self.conv_K.out_channels,
            )
        )

    def attention(self, x):
        Q = self.conv_Q(x)
        K = self.conv_K(x)
        A = einsum("nct,ncs->nts", Q, K).softmax(2)
        return A


######################################################################

train_input, train_targets, train_tr, train_bx = generate_sequences(25000)
test_input, test_targets, test_tr, test_bx = generate_sequences(1000)

######################################################################

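# Convolution kernel size and number of channels.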
ks = 5
nc = 64

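# Optional positional encoding: position i is encoded by the binary digits of
# i, adding ceil(log2(seq_length)) channels of 0/1 values to the input.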
if args.positional_encoding:
    c = math.ceil(math.log(seq_length) / math.log(2.0))
    positional_input = (
        torch.arange(seq_length).unsqueeze(0) // 2 ** torch.arange(c).unsqueeze(1)
    ) % 2
    positional_input = positional_input.unsqueeze(0).float()
else:
    positional_input = torch.zeros(1, 0, seq_length)

in_channels = 1 + positional_input.size(1)

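# Fully convolutional regressor. With --with_attention, one of the middle
# conv+ReLU blocks is replaced by a self-attention layer; otherwise the model
# is a plain stack of five convolutions.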
if args.with_attention:
    model = nn.Sequential(
        nn.Conv1d(in_channels, nc, kernel_size=ks, padding=ks // 2),
        nn.ReLU(),
        nn.Conv1d(nc, nc, kernel_size=ks, padding=ks // 2),
        nn.ReLU(),
        AttentionLayer(nc, nc, nc),
        nn.Conv1d(nc, nc, kernel_size=ks, padding=ks // 2),
        nn.ReLU(),
        nn.Conv1d(nc, 1, kernel_size=ks, padding=ks // 2),
    )

else:
    model = nn.Sequential(
        nn.Conv1d(in_channels, nc, kernel_size=ks, padding=ks // 2),
        nn.ReLU(),
        nn.Conv1d(nc, nc, kernel_size=ks, padding=ks // 2),
        nn.ReLU(),
        nn.Conv1d(nc, nc, kernel_size=ks, padding=ks // 2),
        nn.ReLU(),
        nn.Conv1d(nc, nc, kernel_size=ks, padding=ks // 2),
        nn.ReLU(),
        nn.Conv1d(nc, 1, kernel_size=ks, padding=ks // 2),
    )

nb_parameters = sum(p.numel() for p in model.parameters())

with open(f"att1d_{label}model.log", "w") as f:
    f.write(str(model) + "\n\n")
    f.write(f"nb_parameters {nb_parameters}\n")

######################################################################

batch_size = 100

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
mse_loss = nn.MSELoss()

model.to(device)
mse_loss.to(device)
train_input, train_targets = train_input.to(device), train_targets.to(device)
test_input, test_targets = test_input.to(device), test_targets.to(device)
positional_input = positional_input.to(device)

mu, std = train_input.mean(), train_input.std()

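# Train with Adam and an MSE objective. Inputs are standardized with the
# scalar mean/std of the training signal computed above, and the optional
# positional channels are appended to every batch.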
for e in range(args.nb_epochs):
    acc_loss = 0.0

    for input, targets in zip(
        train_input.split(batch_size), train_targets.split(batch_size)
    ):
        input = torch.cat((input, positional_input.expand(input.size(0), -1, -1)), 1)

        output = model((input - mu) / std)
        loss = mse_loss(output, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc_loss += loss.item()

    log_string(f"{e+1} {acc_loss}")

######################################################################

train_input = train_input.detach().to("cpu")
train_targets = train_targets.detach().to("cpu")

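# Save a few training pairs (input vs. target) as PDF figures.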
for k in range(15):
    save_sequence_images(
        f"att1d_{label}train_{k:03d}.pdf",
        [
            (train_input[k, 0], "blue", "Input"),
            (train_targets[k, 0], "red", "Target"),
        ],
    )

####################

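# Run the trained model on the test set. In the attention case, also recover
# the attention matrix: run the layers preceding the attention layer and call
# its attention() method on the resulting activations.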
test_input = torch.cat(
    (test_input, positional_input.expand(test_input.size(0), -1, -1)), 1
)
test_outputs = model((test_input - mu) / std).detach()

if args.with_attention:
    k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
    x = model[0:k]((test_input - mu) / std)
    test_A = model[k].attention(x)
    test_A = test_A.detach().to("cpu")

test_input = test_input.detach().to("cpu")
test_outputs = test_outputs.detach().to("cpu")
test_targets = test_targets.detach().to("cpu")
test_bx = test_bx.detach().to("cpu")
test_tr = test_tr.detach().to("cpu")

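# For each test sample, save the input/output plot, the same plot with the
# shape positions marked, and, in the attention case, the attention matrix
# with the triangle (^) and box (s) positions marked along both axes.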
for k in range(15):
    save_sequence_images(
        f"att1d_{label}test_Y_{k:03d}.pdf",
        [
            (test_input[k, 0], "blue", "Input"),
            (test_outputs[k, 0], "orange", "Output"),
        ],
    )

    save_sequence_images(
        f"att1d_{label}test_Yp_{k:03d}.pdf",
        [
            (test_input[k, 0], "blue", "Input"),
            (test_outputs[k, 0], "orange", "Output"),
        ],
        test_tr[k],
        test_bx[k],
    )

    if args.with_attention:
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.set_xlim(0, seq_length)
        ax.set_ylim(0, seq_length)

        ax.imshow(test_A[k], cmap="binary", interpolation="nearest")
        delta = 0.0
        ax.scatter(
            test_bx[k, :, 0],
            torch.full((test_bx.size(1),), delta),
            color="black",
            marker="s",
            clip_on=False,
        )
        ax.scatter(
            torch.full((test_bx.size(1),), delta),
            test_bx[k, :, 0],
            color="black",
            marker="s",
            clip_on=False,
        )
        ax.scatter(
            test_tr[k, :, 0],
            torch.full((test_tr.size(1),), delta),
            color="black",
            marker="^",
            clip_on=False,
        )
        ax.scatter(
            torch.full((test_tr.size(1),), delta),
            test_tr[k, :, 0],
            color="black",
            marker="^",
            clip_on=False,
        )

        fig.savefig(f"att1d_{label}test_A_{k:03d}.pdf", bbox_inches="tight")

    plt.close("all")

######################################################################