Update.
[mygptrnn.git] / snake.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch, torchvision
9 import torch.nn.functional as F
10
11
12 def generate_sequences(
13     nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
14 ):
15     worlds = torch.randint(nb_colors, (nb, height, width), device=device)
16     world_prior_visits = torch.zeros(nb, height, width, device=device)
17
18     # nb x 2
19     snake_position = torch.cat(
20         (
21             torch.randint(height, (nb, 1), device=device),
22             torch.randint(width, (nb, 1), device=device),
23         ),
24         1,
25     )
26     snake_direction = torch.randint(4, (nb,), device=device)
27     sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
28     sequences_prior_visits = torch.zeros(
29         nb, 2 * length, device=device, dtype=torch.int64
30     )
31     i = torch.arange(nb, device=device)  # [:,None]
32
33     for l in range(length):
34         # nb x 3
35         snake_next_direction = torch.cat(
36             (
37                 (snake_direction[:, None] - 1) % 4,
38                 snake_direction[:, None],
39                 (snake_direction[:, None] + 1) % 4,
40             ),
41             1,
42         )
43
44         # nb x 3
45         vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1)
46         vw = snake_next_direction % 2 * (snake_next_direction - 2)
47
48         # nb x 3 x 2
49         snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2)
50         snake_next_position = snake_position[:, None, :] + snake_next_speed
51
52         # nb x 3
53         val = torch.logical_and(
54             torch.logical_and(
55                 snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height
56             ),
57             torch.logical_and(
58                 snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width
59             ),
60         ).float()
61         val = (
62             # The multiplicative factors bias toward moving forward
63             torch.rand_like(val)
64             * val
65             * torch.tensor([[1.0, 2.0, 1.0]], device=device)
66         )
67
68         # nb
69         j = val.argmax(1)
70         snake_direction = snake_next_direction[i, j]
71
72         sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
73         sequences_prior_visits[:, 2 * l] = world_prior_visits[
74             i, snake_position[:, 0], snake_position[:, 1]
75         ]
76         if l < prompt_length:
77             world_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
78         sequences[:, 2 * l + 1] = snake_direction
79
80         # nb x 2
81         snake_position = snake_next_position[i, j]
82
83     return sequences, sequences_prior_visits, worlds, world_prior_visits
84
85
86 # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
87 # exit(0)
88
89
90 def solver(input, ar_mask):
91     for n in range(input.size(0)):
92         i, j, memory = 0, 0, {}
93         # print(input[n])
94         # print(ar_mask[n])
95         for l in range(input.size(1) // 2):
96             if ar_mask[n, 2 * l] == 1:
97                 if memory.get((i, j)) is None:
98                     input[n, 2 * l] = -1
99                 else:
100                     input[n, 2 * l] = memory[(i, j)]
101             else:
102                 # print(f'@3 {memory=}')
103                 if memory.get((i, j)) is None:
104                     memory[(i, j)] = input[n, 2 * l]
105                 else:
106                     assert memory[(i, j)] == input[n, 2 * l], f"n={n} l={l}"
107             # print(f'@1 {i=} {j=}')
108             d = input[n, 2 * l + 1].item()
109             i += (d + 1) % 2 * (d - 1)
110             j += d % 2 * (d - 2)
111             # print(f'@2 {i=} {j=}')
112
113
114 def seq2str(seq):
115     return "".join(["NESW123456789"[i] for i in seq])
116
117
118 ######################################################################
119
120 if __name__ == "__main__":
121     train_input, train_prior_visits, _, _ = generate_sequences(
122         nb=20,
123         height=9,
124         width=12,
125         nb_colors=5,
126         length=50,
127         prompt_length=100,
128     )
129
130     print([seq2str(s) for s in train_input])
131
132 ######################################################################