OCD update.
[pytorch.git] / attentiontoy1d.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch, math, sys, argparse
9
10 from torch import nn
11 from torch.nn import functional as F
12
13 import matplotlib.pyplot as plt
14
15 ######################################################################
16
17 parser = argparse.ArgumentParser(description='Toy attention model.')
18
19 parser.add_argument('--nb_epochs',
20                     type = int, default = 250)
21
22 parser.add_argument('--with_attention',
23                     help = 'Use the model with an attention layer',
24                     action='store_true', default=False)
25
26 parser.add_argument('--group_by_locations',
27                     help = 'Use the task where the grouping is location-based',
28                     action='store_true', default=False)
29
30 parser.add_argument('--positional_encoding',
31                     help = 'Provide a positional encoding',
32                     action='store_true', default=False)
33
34 parser.add_argument('--seed',
35                     type = int, default = 0,
36                     help = 'Random seed (default 0, < 0 is no seeding)')
37
38 args = parser.parse_args()
39
40 if args.seed >= 0:
41     torch.manual_seed(args.seed)
42
43 ######################################################################
44
45 label=''
46
47 if args.with_attention: label = 'wa_'
48
49 if args.group_by_locations: label += 'lg_'
50
51 if args.positional_encoding: label += 'pe_'
52
53 log_file = open(f'att1d_{label}train.log', 'w')
54
55 ######################################################################
56
57 def log_string(s):
58     if log_file is not None:
59         log_file.write(s + '\n')
60         log_file.flush()
61     print(s)
62     sys.stdout.flush()
63
64 ######################################################################
65
66 if torch.cuda.is_available():
67     device = torch.device('cuda')
68     torch.backends.cudnn.benchmark = True
69 else:
70     device = torch.device('cpu')
71
72 ######################################################################
73
74 seq_height_min, seq_height_max = 1.0, 25.0
75 seq_width_min, seq_width_max = 5.0, 11.0
76 seq_length = 100
77
78 def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
79     st = torch.arange(seq_length).float()
80     st = st[None, :, None]
81     tr = tr[:, None, :, :]
82     bx = bx[:, None, :, :]
83
84     xtr =            torch.relu(tr[..., 1] - torch.relu(torch.abs(st - tr[..., 0]) - 0.5) * 2 * tr[..., 1] / tr[..., 2])
85     xbx = torch.sign(torch.relu(bx[..., 1] - torch.abs((st - bx[..., 0]) * 2 * bx[..., 1] / bx[..., 2]))) * bx[..., 1]
86
87     x = torch.cat((xtr, xbx), 2)
88
89     # u = x.sign()
90     u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size = 2, stride = 1).permute(0, 2, 1)
91
92     collisions = (u.sum(2) > 1).max(1).values
93     y = x.max(2).values
94
95     return y + torch.rand_like(y) * noise_level - noise_level / 2, collisions
96
97 ######################################################################
98
99 def generate_sequences(nb):
100
101     # Position / height / width
102
103     tr = torch.empty(nb, 2, 3)
104     tr[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
105     tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
106     tr[:, :, 2].uniform_(seq_width_min, seq_width_max)
107
108     bx = torch.empty(nb, 2, 3)
109     bx[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
110     bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
111     bx[:, :, 2].uniform_(seq_width_min, seq_width_max)
112
113     if args.group_by_locations:
114         a = torch.cat((tr, bx), 1)
115         v = a[:, :, 0].sort(1).values[:, 2:3]
116         mask_left = (a[:, :, 0] < v).float()
117         h_left = (a[:, :, 1] * mask_left).sum(1) / 2
118         h_right = (a[:, :, 1] * (1 - mask_left)).sum(1) / 2
119         valid = (h_left - h_right).abs() > 4
120     else:
121         valid = (torch.abs(tr[:, 0, 1] - tr[:, 1, 1]) > 4) & (torch.abs(tr[:, 0, 1] - tr[:, 1, 1]) > 4)
122
123     input, collisions = positions_to_sequences(tr, bx)
124
125     if args.group_by_locations:
126         a = torch.cat((tr, bx), 1)
127         v = a[:, :, 0].sort(1).values[:, 2:3]
128         mask_left = (a[:, :, 0] < v).float()
129         h_left = (a[:, :, 1] * mask_left).sum(1, keepdim = True) / 2
130         h_right = (a[:, :, 1] * (1 - mask_left)).sum(1, keepdim = True) / 2
131         a[:, :, 1] = mask_left * h_left + (1 - mask_left) * h_right
132         tr, bx = a.split(2, 1)
133     else:
134         tr[:, :, 1:2] = tr[:, :, 1:2].mean(1, keepdim = True)
135         bx[:, :, 1:2] = bx[:, :, 1:2].mean(1, keepdim = True)
136
137     targets, _ = positions_to_sequences(tr, bx)
138
139     valid = valid & ~collisions
140     tr = tr[valid]
141     bx = bx[valid]
142     input = input[valid][:, None, :]
143     targets = targets[valid][:, None, :]
144
145     if input.size(0) < nb:
146         input2, targets2, tr2, bx2 = generate_sequences(nb - input.size(0))
147         input = torch.cat((input, input2), 0)
148         targets = torch.cat((targets, targets2), 0)
149         tr = torch.cat((tr, tr2), 0)
150         bx = torch.cat((bx, bx2), 0)
151
152     return input, targets, tr, bx
153
154 ######################################################################
155
156 def save_sequence_images(filename, sequences, tr = None, bx = None):
157     fig = plt.figure()
158     ax = fig.add_subplot(1, 1, 1)
159
160     ax.set_xlim(0, seq_length)
161     ax.set_ylim(-1, seq_height_max + 4)
162
163     for u in sequences:
164         ax.plot(
165             torch.arange(u[0].size(0)) + 0.5, u[0], color = u[1], label = u[2]
166         )
167
168     ax.legend(frameon = False, loc = 'upper left')
169
170     delta = -1.
171     if tr is not None:
172         ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color = 'black', marker = '^', clip_on=False)
173
174     if bx is not None:
175         ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color = 'black', marker = 's', clip_on=False)
176
177     fig.savefig(filename, bbox_inches='tight')
178
179     plt.close('all')
180
181 ######################################################################
182
183 class AttentionLayer(nn.Module):
184     def __init__(self, in_channels, out_channels, key_channels):
185         super(AttentionLayer, self).__init__()
186         self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
187         self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
188         self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size = 1, bias = False)
189
190     def forward(self, x):
191         Q = self.conv_Q(x)
192         K = self.conv_K(x)
193         V = self.conv_V(x)
194         A = Q.permute(0, 2, 1).matmul(K).softmax(2)
195         x = A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1)
196         return x
197
198     def __repr__(self):
199         return self._get_name() + \
200             '(in_channels={}, out_channels={}, key_channels={})'.format(
201                 self.conv_Q.in_channels,
202                 self.conv_V.out_channels,
203                 self.conv_K.out_channels
204             )
205
206     def attention(self, x):
207         Q = self.conv_Q(x)
208         K = self.conv_K(x)
209         return Q.permute(0, 2, 1).matmul(K).softmax(2)
210
211 ######################################################################
212
213 train_input, train_targets, train_tr, train_bx = generate_sequences(25000)
214 test_input, test_targets, test_tr, test_bx = generate_sequences(1000)
215
216 ######################################################################
217
218 ks = 5
219 nc = 64
220
221 if args.positional_encoding:
222     c = math.ceil(math.log(seq_length) / math.log(2.0))
223     positional_input = (torch.arange(seq_length).unsqueeze(0) // 2**torch.arange(c).unsqueeze(1))%2
224     positional_input = positional_input.unsqueeze(0).float()
225 else:
226     positional_input = torch.zeros(1, 0, seq_length)
227
228 in_channels = 1 + positional_input.size(1)
229
230 if args.with_attention:
231
232     model = nn.Sequential(
233         nn.Conv1d(in_channels, nc, kernel_size = ks, padding = ks//2),
234         nn.ReLU(),
235         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
236         nn.ReLU(),
237         AttentionLayer(nc, nc, nc),
238         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
239         nn.ReLU(),
240         nn.Conv1d(nc,  1, kernel_size = ks, padding = ks//2)
241     )
242
243 else:
244
245     model = nn.Sequential(
246         nn.Conv1d(in_channels, nc, kernel_size = ks, padding = ks//2),
247         nn.ReLU(),
248         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
249         nn.ReLU(),
250         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
251         nn.ReLU(),
252         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
253         nn.ReLU(),
254         nn.Conv1d(nc,  1, kernel_size = ks, padding = ks//2)
255     )
256
257 nb_parameters = sum(p.numel() for p in model.parameters())
258
259 with open(f'att1d_{label}model.log', 'w') as f:
260     f.write(str(model) + '\n\n')
261     f.write(f'nb_parameters {nb_parameters}\n')
262
263 ######################################################################
264
265 batch_size = 100
266
267 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
268 mse_loss = nn.MSELoss()
269
270 model.to(device)
271 mse_loss.to(device)
272 train_input, train_targets = train_input.to(device), train_targets.to(device)
273 test_input, test_targets = test_input.to(device), test_targets.to(device)
274 positional_input = positional_input.to(device)
275
276 mu, std = train_input.mean(), train_input.std()
277
278 for e in range(args.nb_epochs):
279     acc_loss = 0.0
280
281     for input, targets in zip(train_input.split(batch_size),
282                               train_targets.split(batch_size)):
283
284         input = torch.cat((input, positional_input.expand(input.size(0), -1, -1)), 1)
285
286         output = model((input - mu) / std)
287         loss = mse_loss(output, targets)
288
289         optimizer.zero_grad()
290         loss.backward()
291         optimizer.step()
292
293         acc_loss += loss.item()
294
295     log_string(f'{e+1} {acc_loss}')
296
297 ######################################################################
298
299 train_input = train_input.detach().to('cpu')
300 train_targets = train_targets.detach().to('cpu')
301
302 for k in range(15):
303     save_sequence_images(
304         f'att1d_{label}train_{k:03d}.pdf',
305         [
306             ( train_input[k, 0], 'blue', 'Input' ),
307             ( train_targets[k, 0], 'red', 'Target' ),
308         ],
309     )
310
311 ####################
312
313 test_input = torch.cat((test_input, positional_input.expand(test_input.size(0), -1, -1)), 1)
314 test_outputs = model((test_input - mu) / std).detach()
315
316 if args.with_attention:
317     k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
318     x = model[0:k]((test_input - mu) / std)
319     test_A = model[k].attention(x)
320     test_A = test_A.detach().to('cpu')
321
322 test_input = test_input.detach().to('cpu')
323 test_outputs = test_outputs.detach().to('cpu')
324 test_targets = test_targets.detach().to('cpu')
325
326 for k in range(15):
327     save_sequence_images(
328         f'att1d_{label}test_Y_{k:03d}.pdf',
329         [
330             ( test_input[k, 0], 'blue', 'Input' ),
331             ( test_outputs[k, 0], 'orange', 'Output' ),
332         ]
333     )
334
335     save_sequence_images(
336         f'att1d_{label}test_Yp_{k:03d}.pdf',
337         [
338             ( test_input[k, 0], 'blue', 'Input' ),
339             ( test_outputs[k, 0], 'orange', 'Output' ),
340         ],
341         test_tr[k],
342         test_bx[k]
343     )
344
345     if args.with_attention:
346         fig = plt.figure()
347         ax = fig.add_subplot(1, 1, 1)
348         ax.set_xlim(0, seq_length)
349         ax.set_ylim(0, seq_length)
350
351         ax.imshow(test_A[k], cmap = 'binary', interpolation='nearest')
352         delta = 0.
353         ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color = 'black', marker = 's', clip_on=False)
354         ax.scatter(torch.full((test_bx.size(1),), delta), test_bx[k, :, 0], color = 'black', marker = 's', clip_on=False)
355         ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color = 'black', marker = '^', clip_on=False)
356         ax.scatter(torch.full((test_tr.size(1),), delta), test_tr[k, :, 0], color = 'black', marker = '^', clip_on=False)
357
358         fig.savefig(f'att1d_{label}test_A_{k:03d}.pdf', bbox_inches='tight')
359
360     plt.close('all')
361
362 ######################################################################