Update.
[pytorch.git] / attentiontoy1d.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch, math, sys, argparse
9
10 from torch import nn
11 from torch.nn import functional as F
12
13 ######################################################################
14
15 parser = argparse.ArgumentParser(description='Toy RNN.')
16
17 parser.add_argument('--nb_epochs',
18                     type = int, default = 250)
19
20 parser.add_argument('--with_attention',
21                     help = 'Use the model with an attention layer',
22                     action='store_true', default=False)
23
24 parser.add_argument('--group_by_locations',
25                     help = 'Use the task where the grouping is location-based',
26                     action='store_true', default=False)
27
28 parser.add_argument('--positional_encoding',
29                     help = 'Provide a positional encoding',
30                     action='store_true', default=False)
31
32 args = parser.parse_args()
33
34 ######################################################################
35
36 label=''
37
38 if args.with_attention: label = 'wa_'
39
40 if args.group_by_locations: label += 'lg_'
41
42 if args.positional_encoding: label += 'pe_'
43
44 log_file = open(f'att1d_{label}train.log', 'w')
45
46 ######################################################################
47
48 def log_string(s):
49     if log_file is not None:
50         log_file.write(s + '\n')
51         log_file.flush()
52     print(s)
53     sys.stdout.flush()
54
55 ######################################################################
56
57 if torch.cuda.is_available():
58     device = torch.device('cuda')
59     torch.backends.cudnn.benchmark = True
60 else:
61     device = torch.device('cpu')
62
63 torch.manual_seed(1)
64
65 ######################################################################
66
67 seq_height_min, seq_height_max = 1.0, 25.0
68 seq_width_min, seq_width_max = 5.0, 11.0
69 seq_length = 100
70
71 def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
72     st = torch.arange(seq_length).float()
73     st = st[None, :, None]
74     tr = tr[:, None, :, :]
75     bx = bx[:, None, :, :]
76
77     xtr =            torch.relu(tr[..., 1] - torch.relu(torch.abs(st - tr[..., 0]) - 0.5) * 2 * tr[..., 1] / tr[..., 2])
78     xbx = torch.sign(torch.relu(bx[..., 1] - torch.abs((st - bx[..., 0]) * 2 * bx[..., 1] / bx[..., 2]))) * bx[..., 1]
79
80     x = torch.cat((xtr, xbx), 2)
81
82     # u = x.sign()
83     u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size = 2, stride = 1).permute(0, 2, 1)
84
85     collisions = (u.sum(2) > 1).max(1).values
86     y = x.max(2).values
87
88     return y + torch.rand_like(y) * noise_level - noise_level / 2, collisions
89
90 ######################################################################
91
92 def generate_sequences(nb):
93
94     # Position / height / width
95
96     tr = torch.empty(nb, 2, 3)
97     tr[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
98     tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
99     tr[:, :, 2].uniform_(seq_width_min, seq_width_max)
100
101     bx = torch.empty(nb, 2, 3)
102     bx[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
103     bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
104     bx[:, :, 2].uniform_(seq_width_min, seq_width_max)
105
106     if args.group_by_locations:
107         a = torch.cat((tr, bx), 1)
108         v = a[:, :, 0].sort(1).values[:, 2:3]
109         mask_left = (a[:, :, 0] < v).float()
110         h_left = (a[:, :, 1] * mask_left).sum(1) / 2
111         h_right = (a[:, :, 1] * (1 - mask_left)).sum(1) / 2
112         valid = (h_left - h_right).abs() > 4
113     else:
114         valid = (torch.abs(tr[:, 0, 1] - tr[:, 1, 1]) > 4) & (torch.abs(tr[:, 0, 1] - tr[:, 1, 1]) > 4)
115
116     input, collisions = positions_to_sequences(tr, bx)
117
118     if args.group_by_locations:
119         a = torch.cat((tr, bx), 1)
120         v = a[:, :, 0].sort(1).values[:, 2:3]
121         mask_left = (a[:, :, 0] < v).float()
122         h_left = (a[:, :, 1] * mask_left).sum(1, keepdim = True) / 2
123         h_right = (a[:, :, 1] * (1 - mask_left)).sum(1, keepdim = True) / 2
124         a[:, :, 1] = mask_left * h_left + (1 - mask_left) * h_right
125         tr, bx = a.split(2, 1)
126     else:
127         tr[:, :, 1:2] = tr[:, :, 1:2].mean(1, keepdim = True)
128         bx[:, :, 1:2] = bx[:, :, 1:2].mean(1, keepdim = True)
129
130     targets, _ = positions_to_sequences(tr, bx)
131
132     valid = valid & ~collisions
133     tr = tr[valid]
134     bx = bx[valid]
135     input = input[valid][:, None, :]
136     targets = targets[valid][:, None, :]
137
138     if input.size(0) < nb:
139         input2, targets2, tr2, bx2 = generate_sequences(nb - input.size(0))
140         input = torch.cat((input, input2), 0)
141         targets = torch.cat((targets, targets2), 0)
142         tr = torch.cat((tr, tr2), 0)
143         bx = torch.cat((bx, bx2), 0)
144
145     return input, targets, tr, bx
146
147 ######################################################################
148
149 import matplotlib.pyplot as plt
150 import matplotlib.collections as mc
151
152 def save_sequence_images(filename, sequences, tr = None, bx = None):
153     fig = plt.figure()
154     ax = fig.add_subplot(1, 1, 1)
155
156     ax.set_xlim(0, seq_length)
157     ax.set_ylim(-1, seq_height_max + 4)
158
159     for u in sequences:
160         ax.plot(
161             torch.arange(u[0].size(0)) + 0.5, u[0], color = u[1], label = u[2]
162         )
163
164     ax.legend(frameon = False, loc = 'upper left')
165
166     delta = -1.
167     if tr is not None:
168         ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color = 'black', marker = '^', clip_on=False)
169
170     if bx is not None:
171         ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color = 'black', marker = 's', clip_on=False)
172
173     fig.savefig(filename, bbox_inches='tight')
174
175     plt.close('all')
176
177 ######################################################################
178
179 class AttentionLayer(nn.Module):
180     def __init__(self, in_channels, out_channels, key_channels):
181         super(AttentionLayer, self).__init__()
182         self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
183         self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
184         self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size = 1, bias = False)
185
186     def forward(self, x):
187         Q = self.conv_Q(x)
188         K = self.conv_K(x)
189         V = self.conv_V(x)
190         A = Q.permute(0, 2, 1).matmul(K).softmax(2)
191         x = A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1)
192         return x
193
194     def __repr__(self):
195         return self._get_name() + \
196             '(in_channels={}, out_channels={}, key_channels={})'.format(
197                 self.conv_Q.in_channels,
198                 self.conv_V.out_channels,
199                 self.conv_K.out_channels
200             )
201
202     def attention(self, x):
203         Q = self.conv_Q(x)
204         K = self.conv_K(x)
205         return Q.permute(0, 2, 1).matmul(K).softmax(2)
206
207 ######################################################################
208
209 train_input, train_targets, train_tr, train_bx = generate_sequences(25000)
210 test_input, test_targets, test_tr, test_bx = generate_sequences(1000)
211
212 ######################################################################
213
214 ks = 5
215 nc = 64
216
217 if args.positional_encoding:
218     c = math.ceil(math.log(seq_length) / math.log(2.0))
219     positional_input = (torch.arange(seq_length).unsqueeze(0) // 2**torch.arange(c).unsqueeze(1))%2
220     positional_input = positional_input.unsqueeze(0).float()
221 else:
222     positional_input = torch.zeros(1, 0, seq_length)
223
224 in_channels = 1 + positional_input.size(1)
225
226 if args.with_attention:
227
228     model = nn.Sequential(
229         nn.Conv1d(in_channels, nc, kernel_size = ks, padding = ks//2),
230         nn.ReLU(),
231         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
232         nn.ReLU(),
233         AttentionLayer(nc, nc, nc),
234         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
235         nn.ReLU(),
236         nn.Conv1d(nc,  1, kernel_size = ks, padding = ks//2)
237     )
238
239 else:
240
241     model = nn.Sequential(
242         nn.Conv1d(in_channels, nc, kernel_size = ks, padding = ks//2),
243         nn.ReLU(),
244         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
245         nn.ReLU(),
246         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
247         nn.ReLU(),
248         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
249         nn.ReLU(),
250         nn.Conv1d(nc,  1, kernel_size = ks, padding = ks//2)
251     )
252
253 nb_parameters = sum(p.numel() for p in model.parameters())
254
255 with open(f'att1d_{label}model.log', 'w') as f:
256     f.write(str(model) + '\n\n')
257     f.write(f'nb_parameters {nb_parameters}\n')
258
259 ######################################################################
260
261 batch_size = 100
262
263 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
264 mse_loss = nn.MSELoss()
265
266 model.to(device)
267 mse_loss.to(device)
268 train_input, train_targets = train_input.to(device), train_targets.to(device)
269 test_input, test_targets = test_input.to(device), test_targets.to(device)
270 positional_input = positional_input.to(device)
271
272 mu, std = train_input.mean(), train_input.std()
273
274 for e in range(args.nb_epochs):
275     acc_loss = 0.0
276
277     for input, targets in zip(train_input.split(batch_size),
278                               train_targets.split(batch_size)):
279
280         input = torch.cat((input, positional_input.expand(input.size(0), -1, -1)), 1)
281
282         output = model((input - mu) / std)
283         loss = mse_loss(output, targets)
284
285         optimizer.zero_grad()
286         loss.backward()
287         optimizer.step()
288
289         acc_loss += loss.item()
290
291     log_string(f'{e+1} {acc_loss}')
292
293 ######################################################################
294
295 train_input = train_input.detach().to('cpu')
296 train_targets = train_targets.detach().to('cpu')
297
298 for k in range(15):
299     save_sequence_images(
300         f'att1d_{label}train_{k:03d}.pdf',
301         [
302             ( train_input[k, 0], 'blue', 'Input' ),
303             ( train_targets[k, 0], 'red', 'Target' ),
304         ],
305     )
306
307 ####################
308
309 test_input = torch.cat((test_input, positional_input.expand(test_input.size(0), -1, -1)), 1)
310 test_outputs = model((test_input - mu) / std).detach()
311
312 if args.with_attention:
313     x = model[0:4]((test_input - mu) / std)
314     test_A = model[4].attention(x)
315     test_A = test_A.detach().to('cpu')
316
317 test_input = test_input.detach().to('cpu')
318 test_outputs = test_outputs.detach().to('cpu')
319 test_targets = test_targets.detach().to('cpu')
320
321 for k in range(15):
322     save_sequence_images(
323         f'att1d_{label}test_Y_{k:03d}.pdf',
324         [
325             ( test_input[k, 0], 'blue', 'Input' ),
326             ( test_outputs[k, 0], 'orange', 'Output' ),
327         ]
328     )
329
330     save_sequence_images(
331         f'att1d_{label}test_Yp_{k:03d}.pdf',
332         [
333             ( test_input[k, 0], 'blue', 'Input' ),
334             ( test_outputs[k, 0], 'orange', 'Output' ),
335         ],
336         test_tr[k],
337         test_bx[k]
338     )
339
340     if args.with_attention:
341         fig = plt.figure()
342         ax = fig.add_subplot(1, 1, 1)
343         ax.set_xlim(0, seq_length)
344         ax.set_ylim(0, seq_length)
345
346         ax.imshow(test_A[k], cmap = 'binary', interpolation='nearest')
347         delta = 0.
348         ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color = 'black', marker = 's', clip_on=False)
349         ax.scatter(torch.full((test_bx.size(1),), delta), test_bx[k, :, 0], color = 'black', marker = 's', clip_on=False)
350         ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color = 'black', marker = '^', clip_on=False)
351         ax.scatter(torch.full((test_tr.size(1),), delta), test_tr[k, :, 0], color = 'black', marker = '^', clip_on=False)
352
353         fig.savefig(f'att1d_{label}test_A_{k:03d}.pdf', bbox_inches='tight')
354
355     plt.close('all')
356
357 ######################################################################