Update.
[picoclvr.git] / world.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9 import cairo
10
11
12 class Box:
13     def __init__(self, x, y, w, h, r, g, b):
14         self.x = x
15         self.y = y
16         self.w = w
17         self.h = h
18         self.r = r
19         self.g = g
20         self.b = b
21
22     def collision(self, scene):
23         for c in scene:
24             if (
25                 self is not c
26                 and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w)
27                 and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h)
28             ):
29                 return True
30         return False
31
32
33 def scene2tensor(xh, yh, scene, size=64):
34     width, height = size, size
35     pixel_map = torch.ByteTensor(width, height, 4).fill_(255)
36     data = pixel_map.numpy()
37     surface = cairo.ImageSurface.create_for_data(
38         data, cairo.FORMAT_ARGB32, width, height
39     )
40
41     ctx = cairo.Context(surface)
42     ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
43
44     for b in scene:
45         ctx.move_to(b.x * size, b.y * size)
46         ctx.rel_line_to(b.w * size, 0)
47         ctx.rel_line_to(0, b.h * size)
48         ctx.rel_line_to(-b.w * size, 0)
49         ctx.close_path()
50         ctx.set_source_rgba(b.r, b.g, b.b, 1.0)
51         ctx.fill()
52
53     hs = size * 0.1
54     ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
55     ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
56     ctx.rel_line_to(hs, 0)
57     ctx.rel_line_to(0, hs)
58     ctx.rel_line_to(-hs, 0)
59     ctx.close_path()
60     ctx.fill()
61
62     return pixel_map[None, :, :, :3].flip(-1).permute(0, 3, 1, 2).float() / 255
63
64
65 def random_scene():
66     scene = []
67     colors = [
68         (1.00, 0.00, 0.00),
69         (0.00, 1.00, 0.00),
70         (0.00, 0.00, 1.00),
71         (1.00, 1.00, 0.00),
72         (0.75, 0.75, 0.75),
73     ]
74
75     for k in range(10):
76         wh = torch.rand(2) * 0.2 + 0.2
77         xy = torch.rand(2) * (1 - wh)
78         c = colors[torch.randint(len(colors), (1,))]
79         b = Box(
80             xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
81         )
82         if not b.collision(scene):
83             scene.append(b)
84
85     return scene
86
87
88 def sequence(nb_steps=10, all_frames=False):
89     delta = 0.1
90     effects = [
91         (False, 0, 0),
92         (False, delta, 0),
93         (False, 0, delta),
94         (False, -delta, 0),
95         (False, 0, -delta),
96         (True, delta, 0),
97         (True, 0, delta),
98         (True, -delta, 0),
99         (True, 0, -delta),
100     ]
101
102     while True:
103
104         frames =[]
105
106         scene = random_scene()
107         xh, yh = tuple(x.item() for x in torch.rand(2))
108
109         frames.append(scene2tensor(xh, yh, scene))
110
111         actions = torch.randint(len(effects), (nb_steps,))
112         change = False
113
114         for a in actions:
115             g, dx, dy = effects[a]
116             if g:
117                 for b in scene:
118                     if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
119                         x, y = b.x, b.y
120                         b.x += dx
121                         b.y += dy
122                         if (
123                             b.x < 0
124                             or b.y < 0
125                             or b.x + b.w > 1
126                             or b.y + b.h > 1
127                             or b.collision(scene)
128                         ):
129                             b.x, b.y = x, y
130                         else:
131                             xh += dx
132                             yh += dy
133                             change = True
134             else:
135                 x, y = xh, yh
136                 xh += dx
137                 yh += dy
138                 if xh < 0 or xh > 1 or yh < 0 or yh > 1:
139                     xh, yh = x, y
140
141             if all_frames:
142                 frames.append(scene2tensor(xh, yh, scene))
143
144         if not all_frames:
145             frames.append(scene2tensor(xh, yh, scene))
146
147         if change:
148             break
149
150     return frames, actions
151
152
153 if __name__ == "__main__":
154     frames, actions = sequence(nb_steps=31,all_frames=True)
155     frames = torch.cat(frames,0)
156     print(f"{frames.size()=}")
157     torchvision.utils.save_image(frames, "seq.png", nrow=8)