#!/usr/bin/env python
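# Path-planning as denoising: generate random mazes, run a random walk in
# each of them, and train a convolutional ResNet to recover, from the walk's
# transition statistics and a marked free cell, the maze walls and the map
# of distances to that cell.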

import sys, math, time, argparse

import torch, torchvision

from torch import nn
from torch.nn import functional as F

######################################################################

parser = argparse.ArgumentParser(
    description='Path-planning as denoising.',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_epochs',
                    type = int, default = 25)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--nb_residual_blocks',
                    type = int, default = 16)

parser.add_argument('--nb_channels',
                    type = int, default = 128)

parser.add_argument('--kernel_size',
                    type = int, default = 3)

parser.add_argument('--nb_for_train',
                    type = int, default = 100000)

parser.add_argument('--nb_for_test',
                    type = int, default = 10000)

parser.add_argument('--world_height',
                    type = int, default = 23)

parser.add_argument('--world_width',
                    type = int, default = 31)

parser.add_argument('--world_nb_walls',
                    type = int, default = 15)

parser.add_argument('--seed',
                    type = int, default = 0,
                    help = 'Random seed (default 0, < 0 is no seeding)')

######################################################################

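# The script can be run with the defaults above, or with overrides, for
# instance:
#
#   python path.py --nb_epochs 50 --world_nb_walls 20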
args = parser.parse_args()

if args.seed >= 0:
    torch.manual_seed(args.seed)

######################################################################

label = ''

log_file = open(f'path_{label}train.log', 'w')

######################################################################

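# Prepend a time stamp to the message, write it to the log file, and echo
# it on stdout.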
def log_string(s):
    t = time.strftime('%Y%m%d-%H:%M:%S', time.localtime())
    s = t + ' - ' + s
    if log_file is not None:
        log_file.write(s + '\n')
        log_file.flush()

    print(s)
    sys.stdout.flush()

######################################################################

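# Small helper to estimate the completion time of a loop of n iterations
# from the time already spent on the first k.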
class ETA:
    def __init__(self, n):
        self.n = n
        self.t0 = time.time()

    def eta(self, k):
        if k > 0:
            t = time.time()
            u = self.t0 + ((t - self.t0) * self.n) // k
            return time.strftime('%Y%m%d-%H:%M:%S', time.localtime(u))
        else:
            return "n.a."

######################################################################

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

log_string(f'device {device}')

######################################################################

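# Build a h x w maze as a {0, 1} wall map: the border is filled with walls,
# then nb_walls straight vertical or horizontal segments are added at random
# on even-indexed rows/columns. A segment is accepted only if it is at least
# two cells long, no longer than half the world size, and overlaps at most
# one existing wall cell; if too many attempts fail, the construction
# restarts from scratch.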
def create_maze(h = 11, w = 15, nb_walls = 10):
    a, k = 0, 0

    while k < nb_walls:
        while True:
            if a == 0:
                m = torch.zeros(h, w, dtype = torch.int64)
                m[ 0,  :] = 1
                m[-1,  :] = 1
                m[ :,  0] = 1
                m[ :, -1] = 1

            r = torch.rand(4)

            if r[0] <= 0.5:
                i1, i2, j = int((r[1] * h).item()), int((r[2] * h).item()), int((r[3] * w).item())
                i1, i2, j = i1 - i1%2, i2 - i2%2, j - j%2
                i1, i2 = min(i1, i2), max(i1, i2)
                if i2 - i1 > 1 and i2 - i1 <= h/2 and m[i1:i2+1, j].sum() <= 1:
                    m[i1:i2+1, j] = 1
                    break
            else:
                i, j1, j2 = int((r[1] * h).item()), int((r[2] * w).item()), int((r[3] * w).item())
                i, j1, j2 = i - i%2, j1 - j1%2, j2 - j2%2
                j1, j2 = min(j1, j2), max(j1, j2)
                if j2 - j1 > 1 and j2 - j1 <= w/2 and m[i, j1:j2+1].sum() <= 1:
                    m[i, j1:j2+1] = 1
                    break
            a += 1

            if a > 10 * nb_walls: a, k = 0, 0

        k += 1

    return m

######################################################################

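# Pick uniformly at random a cell (i, j) that is not a wall.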
def random_free_position(walls):
    p = torch.randperm(walls.numel())
    k = p[walls.view(-1)[p] == 0][0].item()
    return k//walls.size(1), k%walls.size(1)

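# Simulate a random walk of nb uniformly random moves from a random free
# cell and return a (9, h, w) tensor: channels 0-3 hold, for each cell, the
# frequencies of successful moves (right, down, left, up), channels 4-7 the
# frequencies of moves blocked by a wall or the border, and channel 8 the
# total number of moves attempted from that cell.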
def create_transitions(walls, nb):
    # Use a float tensor so that the normalized frequencies computed below
    # are not truncated to integer values.
    trans = walls.new_zeros((9,) + walls.size()).float()
    t = torch.randint(4, (nb,))
    i, j = random_free_position(walls)

    for k in range(t.size(0)):
        di, dj = [ (0, 1), (1, 0), (0, -1), (-1, 0) ][t[k]]
        ip, jp = i + di, j + dj
        if ip < 0 or ip >= walls.size(0) or \
           jp < 0 or jp >= walls.size(1) or \
           walls[ip, jp] > 0:
            trans[t[k] + 4, i, j] += 1
        else:
            trans[t[k], i, j] += 1
            i, j = ip, jp

    n = trans[0:8].sum(dim = 0, keepdim = True)
    trans[8:9] = n
    trans[0:8] = trans[0:8] / (n + (n == 0).long())

    return trans

######################################################################

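# Compute by iterative relaxation the length of the shortest path from
# every free cell to (i, j); wall cells get distance 0 in the returned map.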
def compute_distance(walls, i, j):
    max_length = walls.numel()
    dist = torch.full_like(walls, max_length)

    dist[i, j] = 0
    pred_dist = torch.empty_like(dist)

    while True:
        pred_dist.copy_(dist)
        d = torch.cat(
            (
                dist[None, 1:-1, 0:-2],
                dist[None, 2:, 1:-1],
                dist[None, 1:-1, 2:],
                dist[None, 0:-2, 1:-1]
            ),
        0).min(dim = 0)[0] + 1

        dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
        dist = walls * max_length + (1 - walls) * dist

        if dist.equal(pred_dist): return dist * (1 - walls)

######################################################################

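# Compute, for every free cell, a (4, h, w) map of move probabilities that
# follow a shortest path toward (i, j), uniform over the optimal moves
# (e.g. compute_policy(create_maze(), 1, 1).sum(0) is 1 on free cells and 0
# on walls). It is not used below.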
def compute_policy(walls, i, j):
    distance = compute_distance(walls, i, j)
    distance = distance + walls.numel() * walls

    value = distance.new_full((4,) + distance.size(), walls.numel())
    value[0,  :  , 1:  ] = distance[ :  ,  :-1]
    value[1,  :  ,  :-1] = distance[ :  , 1:  ]
    value[2, 1:  ,  :  ] = distance[ :-1,  :  ]
    value[3,  :-1,  :  ] = distance[1:  ,  :  ]

    proba = (value.min(dim = 0)[0][None] == value).float()
    proba = proba / proba.sum(dim = 0)[None]
    proba = proba * (1 - walls)

    return proba

######################################################################

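# Generate nb samples. Each input is a (10, h, w) tensor: the 9 transition
# channels of a random walk (whose length is drawn uniformly in the given
# interval when traj_length is a pair) plus a one-hot map marking a random
# free cell. Each target is a (2, h, w) tensor: the wall map and the
# distance map to that marked cell.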
def create_maze_data(nb, h = 11, w = 17, nb_walls = 8, traj_length = 50):
    input = torch.empty(nb, 10, h, w)
    targets = torch.empty(nb, 2, h, w)

    if type(traj_length) == tuple:
        l = (torch.rand(nb) * (traj_length[1] - traj_length[0]) + traj_length[0]).long()
    else:
        l = torch.full((nb,), traj_length).long()

    eta = ETA(nb)

    for n in range(nb):
        if n%(max(10, nb//1000)) == 0:
            log_string(f'{(100 * n)/nb:.02f}% ETA {eta.eta(n+1)}')

        walls = create_maze(h, w, nb_walls)
        trans = create_transitions(walls, l[n])

        i, j = random_free_position(walls)
        start = walls.new_zeros(walls.size())
        start[i, j] = 1
        dist = compute_distance(walls, i, j)

        # trans is a float tensor, hence the cast of the integer start map
        input[n] = torch.cat((trans, start[None].float()), 0)
        targets[n] = torch.cat((walls[None], dist[None]), 0)

    return input, targets

######################################################################

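# Save a PNG with one row per sample: a rendering of the marked cell, of
# the random walk's visit counts, the true walls and the true distance map,
# followed, when a model output is given, by the predicted walls and
# distance map.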
def save_image(name, input, targets, output = None):
    input, targets = input.cpu(), targets.cpu()

    weight = torch.tensor(
        [
            [ 1.0, 0.0, 0.0 ],
            [ 1.0, 1.0, 0.0 ],
            [ 0.0, 1.0, 0.0 ],
            [ 0.0, 0.0, 1.0 ],
        ] ).t()[:, :, None, None]

    # img_trans = F.conv2d(input[:, 0:5], weight)
    # img_trans = img_trans / img_trans.max()

    img_trans = 1 / input[:, 8:9].expand(-1, 3, -1, -1)
    img_trans = 1 - img_trans / img_trans.max()

    img_start = input[:, 9:10].expand(-1, 3, -1, -1)
    img_start = 1 - img_start / img_start.max()

    img_walls = targets[:, 0:1].expand(-1, 3, -1, -1)
    img_walls = 1 - img_walls / img_walls.max()

    # img_pi = F.conv2d(targets[:, 2:6], weight)
    # img_pi = img_pi / img_pi.max()

    img_dist = targets[:, 1:2].expand(-1, 3, -1, -1)
    img_dist = img_dist / img_dist.max()

    img = (
        img_start[:, None],
        img_trans[:, None],
        img_walls[:, None],
        # img_pi[:, None],
        img_dist[:, None],
        )

    if output is not None:
        output = output.cpu()
        img_walls = output[:, 0:1].expand(-1, 3, -1, -1)
        img_walls = 1 - img_walls / img_walls.max()

        # img_pi = F.conv2d(output[:, 2:6].mul(100).softmax(dim = 1), weight)
        # img_pi = img_pi / img_pi.max() * output[:, 0:2].softmax(dim = 1)[:, 0:1]

        img_dist = output[:, 1:2].expand(-1, 3, -1, -1)
        img_dist = img_dist / img_dist.max()

        img += (
            img_walls[:, None],
            img_dist[:, None],
            # img_pi[:, None],
        )

    img_all = torch.cat(img, 1)

    img_all = img_all.view(
        img_all.size(0) * img_all.size(1),
        img_all.size(2),
        img_all.size(3),
        img_all.size(4),
    )

    torchvision.utils.save_image(
        img_all,
        name,
        padding = 1, pad_value = 0.5, nrow = len(img)
    )

    log_string(f'Wrote {name}')

######################################################################

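# Plain convolutional baseline. Note that it is not used below, and that its
# first convolution expects 6 input channels, so it would have to be adapted
# to the 10-channel inputs generated above.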
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        nh = 128
        self.conv1 = nn.Conv2d( 6, nh, kernel_size = 5, padding = 2)
        self.conv2 = nn.Conv2d(nh, nh, kernel_size = 5, padding = 2)
        self.conv3 = nn.Conv2d(nh, nh, kernel_size = 5, padding = 2)
        self.conv4 = nn.Conv2d(nh,  2, kernel_size = 5, padding = 2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.conv4(x)
        return x

######################################################################

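# Standard residual block: two 'same'-padded convolutions with batch
# normalization, a ReLU in between, and a ReLU after adding the input back.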
class ResNetBlock(nn.Module):
    def __init__(self, nb_channels, kernel_size):
        super().__init__()

        self.conv1 = nn.Conv2d(nb_channels, nb_channels,
                               kernel_size = kernel_size,
                               padding = (kernel_size - 1) // 2)

        self.bn1 = nn.BatchNorm2d(nb_channels)

        self.conv2 = nn.Conv2d(nb_channels, nb_channels,
                               kernel_size = kernel_size,
                               padding = (kernel_size - 1) // 2)

        self.bn2 = nn.BatchNorm2d(nb_channels)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(x + self.bn2(self.conv2(y)))
        return y

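# Fully convolutional ResNet: a convolution + batch norm + ReLU stem, a
# stack of residual blocks, and a final 1x1 convolution to the output
# channels. The spatial resolution is preserved throughout.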
class ResNet(nn.Module):

    def __init__(self,
                 in_channels, out_channels,
                 nb_residual_blocks, nb_channels, kernel_size):
        super().__init__()

        self.pre_process = nn.Sequential(
            nn.Conv2d(in_channels, nb_channels,
                      kernel_size = kernel_size,
                      padding = (kernel_size - 1) // 2),
            nn.BatchNorm2d(nb_channels),
            nn.ReLU(inplace = True),
        )

        blocks = []
        for k in range(nb_residual_blocks):
            blocks.append(ResNetBlock(nb_channels, kernel_size))

        self.resnet_blocks = nn.Sequential(*blocks)

        self.post_process = nn.Conv2d(nb_channels, out_channels, kernel_size = 1)

    def forward(self, x):
        x = self.pre_process(x)
        x = self.resnet_blocks(x)
        x = self.post_process(x)
        return x

######################################################################

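# Load the cached data set if it exists and matches the current arguments
# (an assert failure is caught by the catch-all except clause below);
# otherwise generate it and cache it to disk.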
data_filename = 'path.dat'

try:
    input, targets = torch.load(data_filename)
    log_string('Data loaded.')
    assert input.size(0) == args.nb_for_train + args.nb_for_test and \
           input.size(1) == 10 and \
           input.size(2) == args.world_height and \
           input.size(3) == args.world_width and \
           \
           targets.size(0) == args.nb_for_train + args.nb_for_test and \
           targets.size(1) == 2 and \
           targets.size(2) == args.world_height and \
           targets.size(3) == args.world_width

except FileNotFoundError:
    log_string('Generating data.')

    input, targets = create_maze_data(
        nb = args.nb_for_train + args.nb_for_test,
        h = args.world_height, w = args.world_width,
        nb_walls = args.world_nb_walls,
        traj_length = (100, 10000)
    )

    torch.save((input, targets), data_filename)

except:
    log_string('Error when loading data.')
    exit(1)

######################################################################

for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')

model = ResNet(
    in_channels = 10, out_channels = 2,
    nb_residual_blocks = args.nb_residual_blocks,
    nb_channels = args.nb_channels,
    kernel_size = args.kernel_size
)

criterion = nn.MSELoss()

model.to(device)
criterion.to(device)

input, targets = input.to(device), targets.to(device)

train_input, train_targets = input[:args.nb_for_train], targets[:args.nb_for_train]
test_input, test_targets   = input[args.nb_for_train:], targets[args.nb_for_train:]

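# Normalize the inputs with the mean and standard deviation estimated on
# the training set only.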
mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
test_input.sub_(mu).div_(std)

######################################################################

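# Train with Adam and the MSE on the two target maps, with a learning rate
# of 1e-2 for the first half of the epochs and 1e-3 for the second half.
# The optimizer is re-created every epoch, which also resets its moment
# estimates.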
eta = ETA(args.nb_epochs)

for e in range(args.nb_epochs):

    if e < args.nb_epochs // 2:
        lr = 1e-2
    else:
        lr = 1e-3

    optimizer = torch.optim.Adam(model.parameters(), lr = lr)

    acc_train_loss = 0.0

    for input, targets in zip(train_input.split(args.batch_size),
                              train_targets.split(args.batch_size)):
        output = model(input)

        loss = criterion(output, targets)
        acc_train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    test_loss = 0.0

    # No gradients are needed to compute the test loss
    with torch.no_grad():
        for input, targets in zip(test_input.split(args.batch_size),
                                  test_targets.split(args.batch_size)):
            output = model(input)
            loss = criterion(output, targets)
            test_loss += loss.item()

    log_string(
        f'{e} acc_train_loss {acc_train_loss / (args.nb_for_train / args.batch_size)} test_loss {test_loss / (args.nb_for_test / args.batch_size)} ETA {eta.eta(e+1)}'
    )

    # save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8]))
    # save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8]))

    save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8])[:, 0:2])
    save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8])[:, 0:2])

######################################################################