#!/usr/bin/env python-for-pytorch

import torch

from torchvision import datasets
from PIL import Image
######################################################################
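
# Tile a batch of image sequences into a single RGB picture: one row per
# sequence, one image per column, separated by gap-colored padding.
# Expects x of shape (nb_sequences, nb_images_per_sequence, 3, height, width)
# as a ByteTensor with values in [0, 255].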
def sequences_to_image(x):
    nb_sequences = x.size(0)
    nb_images_per_sequences = x.size(1)
    nb_channels = 3

    if x.size(2) != nb_channels:
        raise ValueError('Can only handle 3 channel tensors.')

    height = x.size(3)
    width = x.size(4)
    gap = 1  # padding between tiles, in pixels (value assumed)
    gap_color = (0, 128, 255)
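
    # Output canvas: nb_channels x (rows of sequences) x (columns of images),
    # with a gap-wide border around and between the tiles.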
    result = torch.ByteTensor(nb_channels,
                              gap + nb_sequences * (height + gap),
                              gap + nb_images_per_sequences * (width + gap))
    result[0].fill_(gap_color[0])
    result[1].fill_(gap_color[1])
    result[2].fill_(gap_color[2])
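
    # Paste each image of each sequence into its slot on the canvas.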
    for s in range(nb_sequences):
        for i in range(nb_images_per_sequences):
            result.narrow(1, gap + s * (height + gap), height) \
                  .narrow(2, gap + i * (width + gap), width) \
                  .copy_(x[s][i])
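
    # Convert from CHW to HWC so PIL can interpret the array as an RGB image.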
    result_numpy = result.cpu().byte().transpose(0, 2).transpose(0, 1).numpy()

    return Image.fromarray(result_numpy, 'RGB')

######################################################################
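
# flatland is the compiled C extension shipped with this repository;
# generate_sequence presumably resizes and fills x in place with 10 sequences.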
from _ext import flatland

x = torch.ByteTensor()

flatland.generate_sequence(10, x)

sequences_to_image(x).save('sequences.png')