# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import torch, math, sys, argparse

from torch import nn, einsum
from torch.nn import functional as F

import matplotlib.pyplot as plt

######################################################################
parser = argparse.ArgumentParser(description='Toy attention model.')

parser.add_argument('--nb_epochs',
                    type = int, default = 250)

parser.add_argument('--with_attention',
                    help = 'Use the model with an attention layer',
                    action='store_true', default=False)

parser.add_argument('--group_by_locations',
                    help = 'Use the task where the grouping is location-based',
                    action='store_true', default=False)

parser.add_argument('--positional_encoding',
                    help = 'Provide a positional encoding',
                    action='store_true', default=False)

parser.add_argument('--seed',
                    type = int, default = 0,
                    help = 'Random seed (default 0, < 0 is no seeding)')

args = parser.parse_args()
if args.seed >= 0:
    torch.manual_seed(args.seed)
######################################################################

label = ''

if args.with_attention: label = 'wa_'

if args.group_by_locations: label += 'lg_'

if args.positional_encoding: label += 'pe_'

log_file = open(f'att1d_{label}train.log', 'w')

######################################################################
def log_string(s):
    if log_file is not None:
        log_file.write(s + '\n')
        log_file.flush()
    print(s)
    sys.stdout.flush()

######################################################################
if torch.cuda.is_available():
    device = torch.device('cuda')
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device('cpu')

######################################################################
seq_height_min, seq_height_max = 1.0, 25.0
seq_width_min, seq_width_max = 5.0, 11.0
seq_length = 100
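
# Render shape parameters into 1d signals. tr and bx each hold
# (position, height, width) triples for triangles and boxes; a triangle
# ramps linearly up to its height, a box is a plateau at its height.
# Returns the rendered sequences with additive uniform noise, and a
# per-sequence flag indicating whether two shapes overlap.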
def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
    st = torch.arange(seq_length, device = device).float()
    st = st[None, :, None]
    tr = tr[:, None, :, :]
    bx = bx[:, None, :, :]

    xtr = torch.relu(tr[..., 1] - torch.relu(torch.abs(st - tr[..., 0]) - 0.5) * 2 * tr[..., 1] / tr[..., 2])
    xbx = torch.sign(torch.relu(bx[..., 1] - torch.abs((st - bx[..., 0]) * 2 * bx[..., 1] / bx[..., 2]))) * bx[..., 1]

    x = torch.cat((xtr, xbx), 2)

    # A collision is a position where more than one shape is non-zero
    u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size = 2, stride = 1).permute(0, 2, 1)

    collisions = (u.sum(2) > 1).max(1).values
    y = x.max(2).values

    return y + torch.rand_like(y) * noise_level - noise_level / 2, collisions

######################################################################
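
# Generate input/target pairs. The inputs are noisy renderings of two
# triangles and two boxes at random positions, heights, and widths. In
# the default task the target averages the two triangle heights and,
# separately, the two box heights; with --group_by_locations it instead
# averages the heights of the two left-most and of the two right-most
# shapes. Sequences whose height gap is too small to be unambiguous, or
# whose shapes collide, are rejected and regenerated.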
def generate_sequences(nb):

    # Position / height / width

    tr = torch.empty(nb, 2, 3, device = device)
    tr[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
    tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
    tr[:, :, 2].uniform_(seq_width_min, seq_width_max)

    bx = torch.empty(nb, 2, 3, device = device)
    bx[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
    bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
    bx[:, :, 2].uniform_(seq_width_min, seq_width_max)

    if args.group_by_locations:
        a = torch.cat((tr, bx), 1)
        v = a[:, :, 0].sort(1).values[:, 2:3]
        mask_left = (a[:, :, 0] < v).float()
        h_left = (a[:, :, 1] * mask_left).sum(1) / 2
        h_right = (a[:, :, 1] * (1 - mask_left)).sum(1) / 2
        valid = (h_left - h_right).abs() > 4
    else:
        valid = (torch.abs(tr[:, 0, 1] - tr[:, 1, 1]) > 4) & (torch.abs(bx[:, 0, 1] - bx[:, 1, 1]) > 4)
    input, collisions = positions_to_sequences(tr, bx)

    if args.group_by_locations:
        a = torch.cat((tr, bx), 1)
        v = a[:, :, 0].sort(1).values[:, 2:3]
        mask_left = (a[:, :, 0] < v).float()
        h_left = (a[:, :, 1] * mask_left).sum(1, keepdim = True) / 2
        h_right = (a[:, :, 1] * (1 - mask_left)).sum(1, keepdim = True) / 2
        a[:, :, 1] = mask_left * h_left + (1 - mask_left) * h_right
        tr, bx = a.split(2, 1)
    else:
        tr[:, :, 1:2] = tr[:, :, 1:2].mean(1, keepdim = True)
        bx[:, :, 1:2] = bx[:, :, 1:2].mean(1, keepdim = True)

    targets, _ = positions_to_sequences(tr, bx)

    valid = valid & ~collisions
    tr = tr[valid]
    bx = bx[valid]
    input = input[valid][:, None, :]
    targets = targets[valid][:, None, :]

    # Regenerate recursively to replace the rejected sequences
    if input.size(0) < nb:
        input2, targets2, tr2, bx2 = generate_sequences(nb - input.size(0))
        input = torch.cat((input, input2), 0)
        targets = torch.cat((targets, targets2), 0)
        tr = torch.cat((tr, tr2), 0)
        bx = torch.cat((bx, bx2), 0)

    return input, targets, tr, bx

######################################################################
def save_sequence_images(filename, sequences, tr = None, bx = None):
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)

    ax.set_xlim(0, seq_length)
    ax.set_ylim(-1, seq_height_max + 4)

    for u in sequences:
        ax.plot(
            torch.arange(u[0].size(0)) + 0.5, u[0], color = u[1], label = u[2]
        )

    ax.legend(frameon = False, loc = 'upper left')

    # Mark shape centers below the curves: '^' for triangles, 's' for boxes
    delta = -1.
    if tr is not None:
        ax.scatter(tr[:, 0].cpu(), torch.full((tr.size(0),), delta), color = 'black', marker = '^', clip_on=False)

    if bx is not None:
        ax.scatter(bx[:, 0].cpu(), torch.full((bx.size(0),), delta), color = 'black', marker = 's', clip_on=False)

    fig.savefig(filename, bbox_inches='tight')

    plt.close('all')

######################################################################
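
# Single-head self-attention over 1d signals: queries, keys, and values
# are 1x1 convolutions of the input, the attention matrix is
# A[n, t, s] = softmax_s(Q[n, :, t] . K[n, :, s]), and the output at t
# is the attention-weighted average of the values. Unlike the standard
# Transformer formulation, the dot products are not scaled by
# 1/sqrt(key_channels).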
class AttentionLayer(nn.Module):
    def __init__(self, in_channels, out_channels, key_channels):
        super().__init__()
        self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
        self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
        self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size = 1, bias = False)

    def forward(self, x):
        Q = self.conv_Q(x)
        K = self.conv_K(x)
        V = self.conv_V(x)
        A = einsum('nct,ncs->nts', Q, K).softmax(2)
        y = einsum('nts,ncs->nct', A, V)
        return y

    def __repr__(self):
        return self._get_name() + \
            '(in_channels={}, out_channels={}, key_channels={})'.format(
                self.conv_Q.in_channels,
                self.conv_V.out_channels,
                self.conv_K.out_channels
            )

    def attention(self, x):
        Q = self.conv_Q(x)
        K = self.conv_K(x)
        A = einsum('nct,ncs->nts', Q, K).softmax(2)
        return A

######################################################################
train_input, train_targets, train_tr, train_bx = generate_sequences(25000)
test_input, test_targets, test_tr, test_bx = generate_sequences(1000)

######################################################################
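
# With --positional_encoding, each position gets ceil(log2(seq_length))
# extra input channels holding its binary representation, so the model
# can condition on absolute location.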
ks = 5   # convolution kernel size
nc = 64  # number of channels

if args.positional_encoding:
    c = math.ceil(math.log(seq_length) / math.log(2.0))
    positional_input = (torch.arange(seq_length).unsqueeze(0) // 2**torch.arange(c).unsqueeze(1)) % 2
    positional_input = positional_input.unsqueeze(0).float()
else:
    positional_input = torch.zeros(1, 0, seq_length)

in_channels = 1 + positional_input.size(1)
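
# Without attention the model is a plain convolutional network whose
# receptive field (a stack of kernel-size-ks layers) is far shorter
# than the sequence, so it cannot relate shapes that are far apart; the
# attention layer in the middle of the first variant provides that
# global context.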
if args.with_attention:

    model = nn.Sequential(
        nn.Conv1d(in_channels, nc, kernel_size = ks, padding = ks//2),
        nn.ReLU(),
        nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
        nn.ReLU(),
        AttentionLayer(nc, nc, nc),
        nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
        nn.ReLU(),
        nn.Conv1d(nc, 1, kernel_size = ks, padding = ks//2)
    )

else:

    model = nn.Sequential(
        nn.Conv1d(in_channels, nc, kernel_size = ks, padding = ks//2),
        nn.ReLU(),
        nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
        nn.ReLU(),
        nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
        nn.ReLU(),
        nn.Conv1d(nc, nc, kernel_size = ks, padding = ks//2),
        nn.ReLU(),
        nn.Conv1d(nc, 1, kernel_size = ks, padding = ks//2)
    )
nb_parameters = sum(p.numel() for p in model.parameters())

with open(f'att1d_{label}model.log', 'w') as f:
    f.write(str(model) + '\n\n')
    f.write(f'nb_parameters {nb_parameters}\n')

######################################################################
batch_size = 100

optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
mse_loss = nn.MSELoss()

model.to(device)
mse_loss.to(device)
train_input, train_targets = train_input.to(device), train_targets.to(device)
test_input, test_targets = test_input.to(device), test_targets.to(device)
positional_input = positional_input.to(device)

mu, std = train_input.mean(), train_input.std()
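
# Train with Adam on the MSE between model output and target. Inputs
# are standardized with the training mean and standard deviation, and
# the positional channels are appended to every batch.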
for e in range(args.nb_epochs):
    acc_loss = 0.0

    for input, targets in zip(train_input.split(batch_size),
                              train_targets.split(batch_size)):

        input = torch.cat((input, positional_input.expand(input.size(0), -1, -1)), 1)

        output = model((input - mu) / std)
        loss = mse_loss(output, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc_loss += loss.item()

    log_string(f'{e+1} {acc_loss}')

######################################################################
train_input = train_input.detach().to('cpu')
train_targets = train_targets.detach().to('cpu')

# Save a few training examples
for k in range(15):
    save_sequence_images(
        f'att1d_{label}train_{k:03d}.pdf',
        [
            ( train_input[k, 0], 'blue', 'Input' ),
            ( train_targets[k, 0], 'red', 'Target' ),
        ]
    )

######################################################################
test_input = torch.cat((test_input, positional_input.expand(test_input.size(0), -1, -1)), 1)
test_outputs = model((test_input - mu) / std).detach()

if args.with_attention:
    k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
    x = model[0:k]((test_input - mu) / std)
    test_A = model[k].attention(x)
    test_A = test_A.detach().to('cpu')

test_input = test_input.detach().to('cpu')
test_outputs = test_outputs.detach().to('cpu')
test_targets = test_targets.detach().to('cpu')
test_bx = test_bx.detach().to('cpu')
test_tr = test_tr.detach().to('cpu')
for k in range(15):
    save_sequence_images(
        f'att1d_{label}test_Y_{k:03d}.pdf',
        [
            ( test_input[k, 0], 'blue', 'Input' ),
            ( test_outputs[k, 0], 'orange', 'Output' ),
        ]
    )

    save_sequence_images(
        f'att1d_{label}test_Yp_{k:03d}.pdf',
        [
            ( test_input[k, 0], 'blue', 'Input' ),
            ( test_outputs[k, 0], 'orange', 'Output' ),
        ],
        tr = test_tr[k],
        bx = test_bx[k]
    )

    if args.with_attention:
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.set_xlim(0, seq_length)
        ax.set_ylim(0, seq_length)
        delta = 0.
        ax.imshow(test_A[k], cmap = 'binary', interpolation='nearest')
        # Mark the shape positions along both axes of the attention map
        ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color = 'black', marker = 's', clip_on=False)
        ax.scatter(torch.full((test_bx.size(1),), delta), test_bx[k, :, 0], color = 'black', marker = 's', clip_on=False)
        ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color = 'black', marker = '^', clip_on=False)
        ax.scatter(torch.full((test_tr.size(1),), delta), test_tr[k, :, 0], color = 'black', marker = '^', clip_on=False)

        fig.savefig(f'att1d_{label}test_A_{k:03d}.pdf', bbox_inches='tight')

        plt.close('all')

######################################################################