Update.
[pytorch.git] / attentiontoy1d.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch, math, sys, argparse
9
10 from torch import nn
11 from torch.nn import functional as F
12
13 ######################################################################
14
15 parser = argparse.ArgumentParser(description='Toy attention model.')
16
17 parser.add_argument('--nb_epochs',
18                     type = int, default = 250)
19
20 parser.add_argument('--with_attention',
21                     help = 'Use the model with an attention layer',
22                     action='store_true', default=False)
23
24 parser.add_argument('--group_by_locations',
25                     help = 'Use the task where the grouping is location-based',
26                     action='store_true', default=False)
27
28 parser.add_argument('--positional_encoding',
29                     help = 'Provide a positional encoding',
30                     action='store_true', default=False)
31
32 args = parser.parse_args()
33
34 ######################################################################
35
36 label=''
37
38 if args.with_attention: label = 'wa_'
39
40 if args.group_by_locations: label += 'lg_'
41
42 if args.positional_encoding: label += 'pe_'
43
44 log_file = open(f'att1d_{label}train.log', 'w')
45
46 ######################################################################
47
48 def log_string(s):
49     if log_file is not None:
50         log_file.write(s + '\n')
51         log_file.flush()
52     print(s)
53     sys.stdout.flush()
54
55 ######################################################################
56
57 if torch.cuda.is_available():
58     device = torch.device('cuda')
59     torch.backends.cudnn.benchmark = True
60 else:
61     device = torch.device('cpu')
62
63 torch.manual_seed(1)
64
65 ######################################################################
66
67 seq_height_min, seq_height_max = 1.0, 25.0
68 seq_width_min, seq_width_max = 5.0, 11.0
69 seq_length = 100
70
71 def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
72     st = torch.arange(seq_length).float()
73     st = st[None, :, None]
74     tr = tr[:, None, :, :]
75     bx = bx[:, None, :, :]
76
77     xtr =            torch.relu(tr[..., 1] - torch.relu(torch.abs(st - tr[..., 0]) - 0.5) * 2 * tr[..., 1] / tr[..., 2])
78     xbx = torch.sign(torch.relu(bx[..., 1] - torch.abs((st - bx[..., 0]) * 2 * bx[..., 1] / bx[..., 2]))) * bx[..., 1]
79
80     x = torch.cat((xtr, xbx), 2)
81
82     # u = x.sign()
83     u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size = 2, stride = 1).permute(0, 2, 1)
84
85     collisions = (u.sum(2) > 1).max(1).values
86     y = x.max(2).values
87
88     return y + torch.rand_like(y) * noise_level - noise_level / 2, collisions
89
90 ######################################################################
91
92 def generate_sequences(nb):
93
94     # Position / height / width
95
96     tr = torch.empty(nb, 2, 3)
97     tr[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
98     tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
99     tr[:, :, 2].uniform_(seq_width_min, seq_width_max)
100
101     bx = torch.empty(nb, 2, 3)
102     bx[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
103     bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
104     bx[:, :, 2].uniform_(seq_width_min, seq_width_max)
105
106     if args.group_by_locations:
107         a = torch.cat((tr, bx), 1)
108         v = a[:, :, 0].sort(1).values[:, 2:3]
109         mask_left = (a[:, :, 0] < v).float()
110         h_left = (a[:, :, 1] * mask_left).sum(1) / 2
111         h_right = (a[:, :, 1] * (1 - mask_left)).sum(1) / 2
112         valid = (h_left - h_right).abs() > 4
113     else:
114         valid = (torch.abs(tr[:, 0, 1] - tr[:, 1, 1]) > 4) & (torch.abs(tr[:, 0, 1] - tr[:, 1, 1]) > 4)
115
116     input, collisions = positions_to_sequences(tr, bx)
117
118     if args.group_by_locations:
119         a = torch.cat((tr, bx), 1)
120         v = a[:, :, 0].sort(1).values[:, 2:3]
121         mask_left = (a[:, :, 0] < v).float()
122         h_left = (a[:, :, 1] * mask_left).sum(1, keepdim = True) / 2
123         h_right = (a[:, :, 1] * (1 - mask_left)).sum(1, keepdim = True) / 2
124         a[:, :, 1] = mask_left * h_left + (1 - mask_left) * h_right
125         tr, bx = a.split(2, 1)
126     else:
127         tr[:, :, 1:2] = tr[:, :, 1:2].mean(1, keepdim = True)
128         bx[:, :, 1:2] = bx[:, :, 1:2].mean(1, keepdim = True)
129
130     targets, _ = positions_to_sequences(tr, bx)
131
132     valid = valid & ~collisions
133     tr = tr[valid]
134     bx = bx[valid]
135     input = input[valid][:, None, :]
136     targets = targets[valid][:, None, :]
137
138     if input.size(0) < nb:
139         input2, targets2, tr2, bx2 = generate_sequences(nb - input.size(0))
140         input = torch.cat((input, input2), 0)
141         targets = torch.cat((targets, targets2), 0)
142         tr = torch.cat((tr, tr2), 0)
143         bx = torch.cat((bx, bx2), 0)
144
145     return input, targets, tr, bx
146
147 ######################################################################
148
149 import matplotlib.pyplot as plt
150
151 def save_sequence_images(filename, sequences, tr = None, bx = None):
152     fig = plt.figure()
153     ax = fig.add_subplot(1, 1, 1)
154
155     ax.set_xlim(0, seq_length)
156     ax.set_ylim(-1, seq_height_max + 4)
157
158     for u in sequences:
159         ax.plot(
160             torch.arange(u[0].size(0)) + 0.5, u[0], color = u[1], label = u[2]
161         )
162
163     ax.legend(frameon = False, loc = 'upper left')
164
165     delta = -1.
166     if tr is not None:
167         ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color = 'black', marker = '^', clip_on=False)
168
169     if bx is not None:
170         ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color = 'black', marker = 's', clip_on=False)
171
172     fig.savefig(filename, bbox_inches='tight')
173
174     plt.close('all')
175
176 ######################################################################
177
178 class AttentionLayer(nn.Module):
179     def __init__(self, in_channels, out_channels, key_channels):
180         super(AttentionLayer, self).__init__()
181         self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
182         self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
183         self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size = 1, bias = False)
184
185     def forward(self, x):
186         Q = self.conv_Q(x)
187         K = self.conv_K(x)
188         V = self.conv_V(x)
189         A = Q.permute(0, 2, 1).matmul(K).softmax(2)
190         x = A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1)
191         return x
192
193     def __repr__(self):
194         return self._get_name() + \
195             '(in_channels={}, out_channels={}, key_channels={})'.format(
196                 self.conv_Q.in_channels,
197                 self.conv_V.out_channels,
198                 self.conv_K.out_channels
199             )
200
201     def attention(self, x):
202         Q = self.conv_Q(x)
203         K = self.conv_K(x)
204         return Q.permute(0, 2, 1).matmul(K).softmax(2)
205
206 ######################################################################
207
208 train_input, train_targets, train_tr, train_bx = generate_sequences(25000)
209 test_input, test_targets, test_tr, test_bx = generate_sequences(1000)
210
211 ######################################################################
212
213 ks = 5
214 nc = 64
215
216 if args.positional_encoding:
217     c = math.ceil(math.log(seq_length) / math.log(2.0))
218     positional_input = (torch.arange(seq_length).unsqueeze(0) // 2**torch.arange(c).unsqueeze(1))%2
219     positional_input = positional_input.unsqueeze(0).float()
220 else:
221     positional_input = torch.zeros(1, 0, seq_length)
222
223 in_channels = 1 + positional_input.size(1)
224
225 if args.with_attention:
226
227     model = nn.Sequential(
228         nn.Conv1d(in_channels, nc, kernel_size = ks, padding = ks//2),
229         nn.ReLU(),
230         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
231         nn.ReLU(),
232         AttentionLayer(nc, nc, nc),
233         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
234         nn.ReLU(),
235         nn.Conv1d(nc,  1, kernel_size = ks, padding = ks//2)
236     )
237
238 else:
239
240     model = nn.Sequential(
241         nn.Conv1d(in_channels, nc, kernel_size = ks, padding = ks//2),
242         nn.ReLU(),
243         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
244         nn.ReLU(),
245         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
246         nn.ReLU(),
247         nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
248         nn.ReLU(),
249         nn.Conv1d(nc,  1, kernel_size = ks, padding = ks//2)
250     )
251
252 nb_parameters = sum(p.numel() for p in model.parameters())
253
254 with open(f'att1d_{label}model.log', 'w') as f:
255     f.write(str(model) + '\n\n')
256     f.write(f'nb_parameters {nb_parameters}\n')
257
258 ######################################################################
259
260 batch_size = 100
261
262 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
263 mse_loss = nn.MSELoss()
264
265 model.to(device)
266 mse_loss.to(device)
267 train_input, train_targets = train_input.to(device), train_targets.to(device)
268 test_input, test_targets = test_input.to(device), test_targets.to(device)
269 positional_input = positional_input.to(device)
270
271 mu, std = train_input.mean(), train_input.std()
272
273 for e in range(args.nb_epochs):
274     acc_loss = 0.0
275
276     for input, targets in zip(train_input.split(batch_size),
277                               train_targets.split(batch_size)):
278
279         input = torch.cat((input, positional_input.expand(input.size(0), -1, -1)), 1)
280
281         output = model((input - mu) / std)
282         loss = mse_loss(output, targets)
283
284         optimizer.zero_grad()
285         loss.backward()
286         optimizer.step()
287
288         acc_loss += loss.item()
289
290     log_string(f'{e+1} {acc_loss}')
291
292 ######################################################################
293
294 train_input = train_input.detach().to('cpu')
295 train_targets = train_targets.detach().to('cpu')
296
297 for k in range(15):
298     save_sequence_images(
299         f'att1d_{label}train_{k:03d}.pdf',
300         [
301             ( train_input[k, 0], 'blue', 'Input' ),
302             ( train_targets[k, 0], 'red', 'Target' ),
303         ],
304     )
305
306 ####################
307
308 test_input = torch.cat((test_input, positional_input.expand(test_input.size(0), -1, -1)), 1)
309 test_outputs = model((test_input - mu) / std).detach()
310
311 if args.with_attention:
312     x = model[0:4]((test_input - mu) / std)
313     test_A = model[4].attention(x)
314     test_A = test_A.detach().to('cpu')
315
316 test_input = test_input.detach().to('cpu')
317 test_outputs = test_outputs.detach().to('cpu')
318 test_targets = test_targets.detach().to('cpu')
319
320 for k in range(15):
321     save_sequence_images(
322         f'att1d_{label}test_Y_{k:03d}.pdf',
323         [
324             ( test_input[k, 0], 'blue', 'Input' ),
325             ( test_outputs[k, 0], 'orange', 'Output' ),
326         ]
327     )
328
329     save_sequence_images(
330         f'att1d_{label}test_Yp_{k:03d}.pdf',
331         [
332             ( test_input[k, 0], 'blue', 'Input' ),
333             ( test_outputs[k, 0], 'orange', 'Output' ),
334         ],
335         test_tr[k],
336         test_bx[k]
337     )
338
339     if args.with_attention:
340         fig = plt.figure()
341         ax = fig.add_subplot(1, 1, 1)
342         ax.set_xlim(0, seq_length)
343         ax.set_ylim(0, seq_length)
344
345         ax.imshow(test_A[k], cmap = 'binary', interpolation='nearest')
346         delta = 0.
347         ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color = 'black', marker = 's', clip_on=False)
348         ax.scatter(torch.full((test_bx.size(1),), delta), test_bx[k, :, 0], color = 'black', marker = 's', clip_on=False)
349         ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color = 'black', marker = '^', clip_on=False)
350         ax.scatter(torch.full((test_tr.size(1),), delta), test_tr[k, :, 0], color = 'black', marker = '^', clip_on=False)
351
352         fig.savefig(f'att1d_{label}test_A_{k:03d}.pdf', bbox_inches='tight')
353
354     plt.close('all')
355
356 ######################################################################