Update.
[pytorch.git] / attentiontoy1d.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch, math, sys, argparse
9
10 from torch import nn, einsum
11 from torch.nn import functional as F
12
13 import matplotlib.pyplot as plt
14
15 ######################################################################
16
17 parser = argparse.ArgumentParser(description='Toy attention model.')
18
19 parser.add_argument('--nb_epochs',
20                     type = int, default = 250)
21
22 parser.add_argument('--with_attention',
23                     help = 'Use the model with an attention layer',
24                     action='store_true', default=False)
25
26 parser.add_argument('--group_by_locations',
27                     help = 'Use the task where the grouping is location-based',
28                     action='store_true', default=False)
29
30 parser.add_argument('--positional_encoding',
31                     help = 'Provide a positional encoding',
32                     action='store_true', default=False)
33
34 parser.add_argument('--seed',
35                     type = int, default = 0,
36                     help = 'Random seed (default 0, < 0 is no seeding)')
37
38 args = parser.parse_args()
39
40 if args.seed >= 0:
41     torch.manual_seed(args.seed)
42
43 ######################################################################
44
45 label=''
46
47 if args.with_attention: label = 'wa_'
48
49 if args.group_by_locations: label += 'lg_'
50
51 if args.positional_encoding: label += 'pe_'
52
53 log_file = open(f'att1d_{label}train.log', 'w')
54
55 ######################################################################
56
57 def log_string(s):
58     if log_file is not None:
59         log_file.write(s + '\n')
60         log_file.flush()
61     print(s)
62     sys.stdout.flush()
63
64 ######################################################################
65
66 if torch.cuda.is_available():
67     device = torch.device('cuda')
68     torch.backends.cudnn.benchmark = True
69 else:
70     device = torch.device('cpu')
71
72 ######################################################################
73
74 seq_height_min, seq_height_max = 1.0, 25.0
75 seq_width_min, seq_width_max = 5.0, 11.0
76 seq_length = 100
77
78 def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
79     st = torch.arange(seq_length, device = device).float()
80     st = st[None, :, None]
81     tr = tr[:, None, :, :]
82     bx = bx[:, None, :, :]
83
84     xtr =            torch.relu(tr[..., 1] - torch.relu(torch.abs(st - tr[..., 0]) - 0.5) * 2 * tr[..., 1] / tr[..., 2])
85     xbx = torch.sign(torch.relu(bx[..., 1] - torch.abs((st - bx[..., 0]) * 2 * bx[..., 1] / bx[..., 2]))) * bx[..., 1]
86
87     x = torch.cat((xtr, xbx), 2)
88
89     u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size = 2, stride = 1).permute(0, 2, 1)
90
91     collisions = (u.sum(2) > 1).max(1).values
92     y = x.max(2).values
93
94     return y + torch.rand_like(y) * noise_level - noise_level / 2, collisions
95
96 ######################################################################
97
98 def generate_sequences(nb):
99
100     # Position / height / width
101
102     tr = torch.empty(nb, 2, 3, device = device)
103     tr[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
104     tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
105     tr[:, :, 2].uniform_(seq_width_min, seq_width_max)
106
107     bx = torch.empty(nb, 2, 3, device = device)
108     bx[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
109     bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
110     bx[:, :, 2].uniform_(seq_width_min, seq_width_max)
111
112     if args.group_by_locations:
113         a = torch.cat((tr, bx), 1)
114         v = a[:, :, 0].sort(1).values[:, 2:3]
115         mask_left = (a[:, :, 0] < v).float()
116         h_left = (a[:, :, 1] * mask_left).sum(1) / 2
117         h_right = (a[:, :, 1] * (1 - mask_left)).sum(1) / 2
118         valid = (h_left - h_right).abs() > 4
119     else:
120         valid = (torch.abs(tr[:, 0, 1] - tr[:, 1, 1]) > 4) & (torch.abs(tr[:, 0, 1] - tr[:, 1, 1]) > 4)
121
122     input, collisions = positions_to_sequences(tr, bx)
123
124     if args.group_by_locations:
125         a = torch.cat((tr, bx), 1)
126         v = a[:, :, 0].sort(1).values[:, 2:3]
127         mask_left = (a[:, :, 0] < v).float()
128         h_left = (a[:, :, 1] * mask_left).sum(1, keepdim = True) / 2
129         h_right = (a[:, :, 1] * (1 - mask_left)).sum(1, keepdim = True) / 2
130         a[:, :, 1] = mask_left * h_left + (1 - mask_left) * h_right
131         tr, bx = a.split(2, 1)
132     else:
133         tr[:, :, 1:2] = tr[:, :, 1:2].mean(1, keepdim = True)
134         bx[:, :, 1:2] = bx[:, :, 1:2].mean(1, keepdim = True)
135
136     targets, _ = positions_to_sequences(tr, bx)
137
138     valid = valid & ~collisions
139     tr = tr[valid]
140     bx = bx[valid]
141     input = input[valid][:, None, :]
142     targets = targets[valid][:, None, :]
143
144     if input.size(0) < nb:
145         input2, targets2, tr2, bx2 = generate_sequences(nb - input.size(0))
146         input = torch.cat((input, input2), 0)
147         targets = torch.cat((targets, targets2), 0)
148         tr = torch.cat((tr, tr2), 0)
149         bx = torch.cat((bx, bx2), 0)
150
151     return input, targets, tr, bx
152
153 ######################################################################
154
155 def save_sequence_images(filename, sequences, tr = None, bx = None):
156     fig = plt.figure()
157     ax = fig.add_subplot(1, 1, 1)
158
159     ax.set_xlim(0, seq_length)
160     ax.set_ylim(-1, seq_height_max + 4)
161
162     for u in sequences:
163         ax.plot(
164             torch.arange(u[0].size(0)) + 0.5, u[0], color = u[1], label = u[2]
165         )
166
167     ax.legend(frameon = False, loc = 'upper left')
168
169     delta = -1.
170     if tr is not None:
171         ax.scatter(tr[:, 0].cpu(), torch.full((tr.size(0),), delta), color = 'black', marker = '^', clip_on=False)
172
173     if bx is not None:
174         ax.scatter(bx[:, 0].cpu(), torch.full((bx.size(0),), delta), color = 'black', marker = 's', clip_on=False)
175
176     fig.savefig(filename, bbox_inches='tight')
177
178     plt.close('all')
179
180 ######################################################################
181
182 class AttentionLayer(nn.Module):
183     def __init__(self, in_channels, out_channels, key_channels):
184         super().__init__()
185         self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
186         self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
187         self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size = 1, bias = False)
188
189     def forward(self, x):
190         Q = self.conv_Q(x)
191         K = self.conv_K(x)
192         V = self.conv_V(x)
193         A = einsum('nct,ncs->nts', Q, K).softmax(2)
194         y = einsum('nts,ncs->nct', A, V)
195         return y
196
197     def __repr__(self):
198         return self._get_name() + \
199             '(in_channels={}, out_channels={}, key_channels={})'.format(
200                 self.conv_Q.in_channels,
201                 self.conv_V.out_channels,
202                 self.conv_K.out_channels
203             )
204
205     def attention(self, x):
206         Q = self.conv_Q(x)
207         K = self.conv_K(x)
208         A = einsum('nct,ncs->nts', Q, K).softmax(2)
209         return A
210
211 ######################################################################
212
213 train_input, train_targets, train_tr, train_bx = generate_sequences(25000)
214 test_input, test_targets, test_tr, test_bx = generate_sequences(1000)
215
216 ######################################################################
217
218 ks = 5
219 nc = 64
220
221 if args.positional_encoding:
222     c = math.ceil(math.log(seq_length) / math.log(2.0))
223     positional_input = (torch.arange(seq_length).unsqueeze(0) // 2**torch.arange(c).unsqueeze(1))%2
224     positional_input = positional_input.unsqueeze(0).float()
225 else:
226     positional_input = torch.zeros(1, 0, seq_length)
227
228 in_channels = 1 + positional_input.size(1)
229
230 if args.with_attention:
231
232     model = nn.Sequential(
233         nn.Conv1d(in_channels, nc, kernel_size = ks, padding = ks//2),
234         nn.ReLU(),
235         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
236         nn.ReLU(),
237         AttentionLayer(nc, nc, nc),
238         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
239         nn.ReLU(),
240         nn.Conv1d(nc,  1, kernel_size = ks, padding = ks//2)
241     )
242
243 else:
244
245     model = nn.Sequential(
246         nn.Conv1d(in_channels, nc, kernel_size = ks, padding = ks//2),
247         nn.ReLU(),
248         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
249         nn.ReLU(),
250         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
251         nn.ReLU(),
252         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
253         nn.ReLU(),
254         nn.Conv1d(nc,  1, kernel_size = ks, padding = ks//2)
255     )
256
257 nb_parameters = sum(p.numel() for p in model.parameters())
258
259 with open(f'att1d_{label}model.log', 'w') as f:
260     f.write(str(model) + '\n\n')
261     f.write(f'nb_parameters {nb_parameters}\n')
262
263 ######################################################################
264
265 batch_size = 100
266
267 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
268 mse_loss = nn.MSELoss()
269
270 model.to(device)
271 mse_loss.to(device)
272 train_input, train_targets = train_input.to(device), train_targets.to(device)
273 test_input, test_targets = test_input.to(device), test_targets.to(device)
274 positional_input = positional_input.to(device)
275
276 mu, std = train_input.mean(), train_input.std()
277
278 for e in range(args.nb_epochs):
279     acc_loss = 0.0
280
281     for input, targets in zip(train_input.split(batch_size),
282                               train_targets.split(batch_size)):
283
284         input = torch.cat((input, positional_input.expand(input.size(0), -1, -1)), 1)
285
286         output = model((input - mu) / std)
287         loss = mse_loss(output, targets)
288
289         optimizer.zero_grad()
290         loss.backward()
291         optimizer.step()
292
293         acc_loss += loss.item()
294
295     log_string(f'{e+1} {acc_loss}')
296
297 ######################################################################
298
299 train_input = train_input.detach().to('cpu')
300 train_targets = train_targets.detach().to('cpu')
301
302 for k in range(15):
303     save_sequence_images(
304         f'att1d_{label}train_{k:03d}.pdf',
305         [
306             ( train_input[k, 0], 'blue', 'Input' ),
307             ( train_targets[k, 0], 'red', 'Target' ),
308         ],
309     )
310
311 ####################
312
313 test_input = torch.cat((test_input, positional_input.expand(test_input.size(0), -1, -1)), 1)
314 test_outputs = model((test_input - mu) / std).detach()
315
316 if args.with_attention:
317     k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
318     x = model[0:k]((test_input - mu) / std)
319     test_A = model[k].attention(x)
320     test_A = test_A.detach().to('cpu')
321
322 test_input = test_input.detach().to('cpu')
323 test_outputs = test_outputs.detach().to('cpu')
324 test_targets = test_targets.detach().to('cpu')
325 test_bx = test_bx.detach().to('cpu')
326 test_tr = test_tr.detach().to('cpu')
327
328 for k in range(15):
329     save_sequence_images(
330         f'att1d_{label}test_Y_{k:03d}.pdf',
331         [
332             ( test_input[k, 0], 'blue', 'Input' ),
333             ( test_outputs[k, 0], 'orange', 'Output' ),
334         ]
335     )
336
337     save_sequence_images(
338         f'att1d_{label}test_Yp_{k:03d}.pdf',
339         [
340             ( test_input[k, 0], 'blue', 'Input' ),
341             ( test_outputs[k, 0], 'orange', 'Output' ),
342         ],
343         test_tr[k],
344         test_bx[k]
345     )
346
347     if args.with_attention:
348         fig = plt.figure()
349         ax = fig.add_subplot(1, 1, 1)
350         ax.set_xlim(0, seq_length)
351         ax.set_ylim(0, seq_length)
352
353         ax.imshow(test_A[k], cmap = 'binary', interpolation='nearest')
354         delta = 0.
355         ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color = 'black', marker = 's', clip_on=False)
356         ax.scatter(torch.full((test_bx.size(1),), delta), test_bx[k, :, 0], color = 'black', marker = 's', clip_on=False)
357         ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color = 'black', marker = '^', clip_on=False)
358         ax.scatter(torch.full((test_tr.size(1),), delta), test_tr[k, :, 0], color = 'black', marker = '^', clip_on=False)
359
360         fig.savefig(f'att1d_{label}test_A_{k:03d}.pdf', bbox_inches='tight')
361
362     plt.close('all')
363
364 ######################################################################