1 {
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "a4e1cad9",
6    "metadata": {
7     "scrolled": true
8    },
9    "source": [
10     "Any copyright is dedicated to the Public Domain.\n",
11     "https://creativecommons.org/publicdomain/zero/1.0/\n",
12     "\n",
13     "Written by Francois Fleuret\n",
14     "https://fleuret.org/francois"
15    ]
16   },
17   {
18    "cell_type": "code",
19    "execution_count": null,
20    "id": "b0f4c709",
21    "metadata": {},
22    "outputs": [],
23    "source": [
24     "import math\n",
25     "\n",
26     "import torch\n",
27     "import torch.nn.functional as F\n",
28     "from torch import nn\n",
29     "\n",
30     "import matplotlib.pyplot as plt\n",
31     "\n",
32     "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
33    ]
34   },
35   {
36    "cell_type": "code",
37    "execution_count": null,
38    "id": "2045fb7d",
39    "metadata": {
40     "scrolled": true
41    },
42    "outputs": [],
43    "source": [
    "# Six scalar maps on [-1, 1]; composing any two of them gives the 36\n",
    "# target functions the model has to learn.\n",
    "mappings = [\n",
    "    lambda x: x,\n",
    "    lambda x: torch.sin(x * math.pi),\n",
    "    lambda x: torch.cos(x * math.pi),\n",
    "    lambda x: torch.sigmoid(5 * x) * 2 - 1,\n",
    "    lambda x: 0.25 * x + 0.75 * torch.sign(x),\n",
    "    lambda x: torch.ceil(x * 2) / 2,\n",
    "]\n",
    "\n",
    "mapping_names = [ 'id', 'sin', 'cos', 'sigmoid', 'gap', 'stairs', ]\n",
    "\n",
    "def comp(n1, n2, x):\n",
    "    # Composition mappings[n2] applied after mappings[n1].\n",
    "    return mappings[n2](mappings[n1](x))\n",
    "\n",
    "x = torch.linspace(-1, 1, 250)\n",
    "\n",
    "# Plot each base mapping for visual inspection.\n",
    "for f, l in zip(mappings, mapping_names):\n",
    "    plt.plot(x, f(x), label = l)\n",
    "\n",
    "plt.legend()\n",
    "plt.show()"
65    ]
66   },
67   {
68    "cell_type": "code",
69    "execution_count": null,
70    "id": "601b926f",
71    "metadata": {},
72    "outputs": [],
73    "source": [
74     "def create_set(nb, probas):\n",
75     "    probas = probas.view(-1) / probas.sum()\n",
76     "    x = torch.rand(nb, device = device) * 2 - 1\n",
77     "    y = x.new(x.size(0), len(mappings)**2, device = device)\n",
78     "    for k in range(len(mappings)**2):\n",
79     "        n1 = k // len(mappings)\n",
80     "        n2 = k % len(mappings)\n",
81     "        y[:, k] = comp(n1, n2, x)\n",
82     "    a = torch.distributions.categorical.Categorical(probas).sample((nb,))\n",
83     "    # y[n][m] = y[n, a[n][m]]\n",
84     "    y = y.gather(dim = 1, index = a[:, None])\n",
85     "    a1 = F.one_hot(a.div(len(mappings), rounding_mode = 'floor'), num_classes = len(mappings))\n",
86     "    a2 = F.one_hot(a%len(mappings), num_classes = len(mappings))\n",
87     "    x = torch.cat((x[:, None], a1 * 2 - 1, a2 * 2 - 1), 1)\n",
88     "    \n",
89     "    return x, y\n",
90     "\n",
    "# Training-distribution masks over (n1, n2) pairs; zero entries are\n",
    "# compositions never seen during training.\n",
    "probas_uniform = torch.full((len(mappings), len(mappings)), 1.0, device = device)\n",
    "\n",
    "a = torch.arange(len(mappings), device = device)\n",
    "\n",
    "# Cyclic band: keep pairs whose index difference mod len(mappings) is small.\n",
    "probas_band = ((a[:, None] - a[None, :])%len(mappings) < len(mappings)/2).float()\n",
    "\n",
    "# Two diagonal blocks: keep pairs with n1 and n2 in the same half of the list.\n",
    "probas_blocks = (\n",
    "     a[:, None].div(len(mappings)//2, rounding_mode = 'floor') -\n",
    "     a[None, :].div(len(mappings)//2, rounding_mode = 'floor') == 0\n",
    ").float()\n",
    "\n",
    "# Checkerboard: keep pairs with an even index sum.\n",
    "probas_checkboard = ((a[:, None] + a[None, :])%2 == 0).float()\n",
    "\n",
    "#probas_checkboard = (((a[:, None] + a[None, :])%2 == 0) + (a[:, None] == 0) + (a[None, :] == 0)).float()\n",
    "\n",
    "print(probas_uniform)\n",
    "print(probas_band)\n",
    "print(probas_blocks)\n",
    "print(probas_checkboard)"
110    ]
111   },
112   {
113    "cell_type": "code",
114    "execution_count": null,
115    "id": "16ec81ec",
116    "metadata": {},
117    "outputs": [],
118    "source": [
119     "def train_model(probas_train, probas_test, nb_samples = 100000, nb_epochs = 25):\n",
120     "\n",
121     "    dim_hidden = 64\n",
122     "\n",
123     "    model = nn.Sequential(\n",
124     "        nn.Linear(1 + len(mappings) * 2, dim_hidden),\n",
125     "        nn.ReLU(),\n",
126     "        nn.Linear(dim_hidden, dim_hidden),\n",
127     "        nn.ReLU(),\n",
128     "        nn.Linear(dim_hidden, 1),\n",
129     "    ).to(device)\n",
130     "    \n",
131     "    batch_size = 100\n",
132     "\n",
133     "    train_input, train_targets = create_set(nb_samples, probas_train)\n",
134     "    test_input, test_targets = create_set(nb_samples, probas_test)\n",
135     "    train_mu, train_std = train_input.mean(), train_input.std()\n",
136     "    train_input = (train_input - train_mu) / train_std\n",
137     "    test_input = (test_input - train_mu) / train_std\n",
138     "\n",
139     "    for k in range(nb_epochs):\n",
140     "        optimizer = torch.optim.Adam(model.parameters(), lr = 1e-2 /(k + 1))\n",
141     "\n",
142     "        acc_train_loss = 0.0\n",
143     "\n",
144     "        for input, targets in zip(train_input.split(batch_size),\n",
145     "                                  train_targets.split(batch_size)):\n",
146     "            output = model(input)\n",
147     "            loss = F.mse_loss(output, targets)\n",
148     "            acc_train_loss += loss.item() * input.size(0)\n",
149     "\n",
150     "            optimizer.zero_grad()\n",
151     "            loss.backward()\n",
152     "            optimizer.step()\n",
153     "        \n",
154     "        acc_test_loss = 0.0\n",
155     "\n",
156     "        for input, targets in zip(test_input.split(batch_size),\n",
157     "                                  test_targets.split(batch_size)):\n",
158     "            output = model(input)\n",
159     "            loss = F.mse_loss(output, targets)\n",
160     "            acc_test_loss += loss.item() * input.size(0)\n",
161     "\n",
162     "        #print(f'loss {k} {acc_train_loss/train_input.size(0):f} {acc_test_loss/test_input.size(0):f}')\n",
163     "\n",
164     "    return train_mu, train_std, model\n",
165     "\n",
166     "def prediction(model, mu, std, n1, n2, x):\n",
167     "    h1 = F.one_hot(torch.full((x.size(0),), n1, device = device), num_classes = len(mappings)) * 2 - 1\n",
168     "    h2 = F.one_hot(torch.full((x.size(0),), n2, device = device), num_classes = len(mappings)) * 2 - 1\n",
169     "    input = torch.cat((x[:, None], h1, h2), dim = 1)\n",
170     "    input = (input - mu) / std\n",
171     "    return model(input).view(-1).detach()"
172    ]
173   },
174   {
175    "cell_type": "code",
176    "execution_count": null,
177    "id": "2aad3e36",
178    "metadata": {},
179    "outputs": [],
180    "source": [
181     "def plot_result(probas_train):\n",
182     "    \n",
183     "    train_mu, train_std, model = train_model(\n",
184     "        probas_train = probas_train,\n",
185     "        probas_test = probas_uniform,\n",
186     "    )\n",
187     "\n",
188     "    e = torch.empty(len(mappings), len(mappings))\n",
189     "\n",
190     "    x = torch.linspace(-1, 1, 250, device = device)\n",
191     "\n",
192     "    for n1 in range(len(mappings)):\n",
193     "        for n2 in range(len(mappings)):\n",
194     "            gt = comp(n1, n2, x)\n",
195     "            pr = prediction(model, train_mu, train_std, n1, n2, x)\n",
196     "            e[n1, n2] = F.mse_loss(gt, pr)\n",
197     "        \n",
198     "    plt.matshow(e, cmap = plt.cm.Blues, vmin = 0, vmax = 1)\n",
199     "    \n",
200     "plot_result(probas_uniform)\n",
201     "plot_result(probas_band)\n",
202     "plot_result(probas_blocks)\n",
203     "plot_result(probas_checkboard)"
204    ]
205   },
206   {
207    "cell_type": "code",
208    "execution_count": null,
209    "id": "93234c68",
210    "metadata": {},
211    "outputs": [],
212    "source": [
    "# Train on the checkerboard distribution and compare predictions with\n",
    "# the ground truth on a few compositions, seen and unseen at train time.\n",
    "train_mu, train_std, model = train_model(\n",
    "    probas_train = probas_checkboard,\n",
    "    probas_test = probas_uniform,\n",
    ")\n",
    "\n",
    "x = torch.linspace(-1, 1, 250, device = device)\n",
    "\n",
    "for n1, n2 in [ (1, 5), (1, 2), (5, 3), (4, 5) ]:\n",
    "    plt.plot(x.to('cpu'), comp(n1, n2, x).to('cpu'), label = 'ground truth')\n",
    "    plt.plot(x.to('cpu'), prediction(model, train_mu, train_std, n1, n2, x).to('cpu'), label = 'prediction')\n",
    "    plt.legend()\n",
    "    plt.show()"
225    ]
226   },
227   {
228    "cell_type": "code",
229    "execution_count": null,
230    "id": "4492c9d6",
231    "metadata": {},
232    "outputs": [],
233    "source": []
234   }
235  ],
236  "metadata": {
237   "kernelspec": {
238    "display_name": "Python 3 (ipykernel)",
239    "language": "python",
240    "name": "python3"
241   },
242   "language_info": {
243    "codemirror_mode": {
244     "name": "ipython",
245     "version": 3
246    },
247    "file_extension": ".py",
248    "mimetype": "text/x-python",
249    "name": "python",
250    "nbconvert_exporter": "python",
251    "pygments_lexer": "ipython3",
252    "version": "3.9.12"
253   }
254  },
255  "nbformat": 4,
256  "nbformat_minor": 5
257 }