{ "cells": [ { "cell_type": "markdown", "id": "a4e1cad9", "metadata": { "scrolled": true }, "source": [ "Any copyright is dedicated to the Public Domain.\n", "https://creativecommons.org/publicdomain/zero/1.0/\n", "\n", "Written by Francois Fleuret\n", "https://fleuret.org/francois" ] }, { "cell_type": "code", "execution_count": null, "id": "b0f4c709", "metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "from torch import nn\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "code", "execution_count": null, "id": "2045fb7d", "metadata": { "scrolled": true }, "outputs": [], "source": [ "mappings = [\n", " lambda x: x,\n", " lambda x: torch.sin(x * math.pi),\n", " lambda x: torch.cos(x * math.pi),\n", " lambda x: torch.sigmoid(5 * x) * 2 - 1,\n", " lambda x: 0.25 * x + 0.75 * torch.sign(x),\n", " lambda x: torch.ceil(x * 2) / 2,\n", "]\n", "\n", "mapping_names = [ 'id', 'sin', 'cos', 'sigmoid', 'gap', 'stairs', ]\n", "\n", "def comp(n1, n2, x):\n", " return mappings[n2](mappings[n1](x))\n", "\n", "x = torch.linspace(-1, 1, 250)\n", "\n", "for f, l in zip(mappings, mapping_names):\n", " plt.plot(x, f(x), label = l)\n", "\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "601b926f", "metadata": {}, "outputs": [], "source": [ "def create_set(nb, probas):\n", " probas = probas.view(-1) / probas.sum()\n", " x = torch.rand(nb, device = device) * 2 - 1\n", " y = x.new(x.size(0), len(mappings)**2, device = device)\n", " for k in range(len(mappings)**2):\n", " n1 = k // len(mappings)\n", " n2 = k % len(mappings)\n", " y[:, k] = comp(n1, n2, x)\n", " a = torch.distributions.categorical.Categorical(probas).sample((nb,))\n", " # y[n][m] = y[n, a[n][m]]\n", " y = y.gather(dim = 1, index = a[:, None])\n", " a1 = F.one_hot(a.div(len(mappings), rounding_mode = 'floor'), num_classes = len(mappings))\n", " a2 = F.one_hot(a%len(mappings), num_classes = len(mappings))\n", " x = torch.cat((x[:, None], a1 * 2 - 1, a2 * 2 - 1), 1)\n", " \n", " return x, y\n", "\n", "probas_uniform = torch.full((len(mappings), len(mappings)), 1.0, device = device)\n", "\n", "a = torch.arange(len(mappings), device = device)\n", "\n", "probas_band = ((a[:, None] - a[None, :])%len(mappings) < len(mappings)/2).float()\n", "\n", "probas_blocks = (\n", " a[:, None].div(len(mappings)//2, rounding_mode = 'floor') -\n", " a[None, :].div(len(mappings)//2, rounding_mode = 'floor') == 0\n", ").float()\n", "\n", "probas_checkboard = ((a[:, None] + a[None, :])%2 == 0).float()\n", "\n", "#probas_checkboard = (((a[:, None] + a[None, :])%2 == 0) + (a[:, None] == 0) + (a[None, :] == 0)).float()\n", "\n", "print(probas_uniform)\n", "print(probas_band)\n", "print(probas_blocks)\n", "print(probas_checkboard)" ] }, { "cell_type": "code", "execution_count": null, "id": "16ec81ec", "metadata": {}, "outputs": [], "source": [ "def train_model(probas_train, probas_test, nb_samples = 100000, nb_epochs = 25):\n", "\n", " dim_hidden = 64\n", "\n", " model = nn.Sequential(\n", " nn.Linear(1 + len(mappings) * 2, dim_hidden),\n", " nn.ReLU(),\n", " nn.Linear(dim_hidden, dim_hidden),\n", " nn.ReLU(),\n", " nn.Linear(dim_hidden, 1),\n", " ).to(device)\n", " \n", " batch_size = 100\n", "\n", " train_input, train_targets = create_set(nb_samples, probas_train)\n", " test_input, test_targets = create_set(nb_samples, probas_test)\n", " train_mu, train_std = train_input.mean(), train_input.std()\n", " train_input = (train_input - train_mu) / train_std\n", " test_input = (test_input - train_mu) / train_std\n", "\n", " for k in range(nb_epochs):\n", " optimizer = torch.optim.Adam(model.parameters(), lr = 1e-2 /(k + 1))\n", "\n", " acc_train_loss = 0.0\n", "\n", " for input, targets in zip(train_input.split(batch_size),\n", " train_targets.split(batch_size)):\n", " output = model(input)\n", " loss = F.mse_loss(output, targets)\n", " acc_train_loss += loss.item() * input.size(0)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", " \n", " acc_test_loss = 0.0\n", "\n", " for input, targets in zip(test_input.split(batch_size),\n", " test_targets.split(batch_size)):\n", " output = model(input)\n", " loss = F.mse_loss(output, targets)\n", " acc_test_loss += loss.item() * input.size(0)\n", "\n", " #print(f'loss {k} {acc_train_loss/train_input.size(0):f} {acc_test_loss/test_input.size(0):f}')\n", "\n", " return train_mu, train_std, model\n", "\n", "def prediction(model, mu, std, n1, n2, x):\n", " h1 = F.one_hot(torch.full((x.size(0),), n1, device = device), num_classes = len(mappings)) * 2 - 1\n", " h2 = F.one_hot(torch.full((x.size(0),), n2, device = device), num_classes = len(mappings)) * 2 - 1\n", " input = torch.cat((x[:, None], h1, h2), dim = 1)\n", " input = (input - mu) / std\n", " return model(input).view(-1).detach()" ] }, { "cell_type": "code", "execution_count": null, "id": "2aad3e36", "metadata": {}, "outputs": [], "source": [ "def plot_result(probas_train):\n", " \n", " train_mu, train_std, model = train_model(\n", " probas_train = probas_train,\n", " probas_test = probas_uniform,\n", " )\n", "\n", " e = torch.empty(len(mappings), len(mappings))\n", "\n", " x = torch.linspace(-1, 1, 250, device = device)\n", "\n", " for n1 in range(len(mappings)):\n", " for n2 in range(len(mappings)):\n", " gt = comp(n1, n2, x)\n", " pr = prediction(model, train_mu, train_std, n1, n2, x)\n", " e[n1, n2] = F.mse_loss(gt, pr)\n", " \n", " plt.matshow(e, cmap = plt.cm.Blues, vmin = 0, vmax = 1)\n", " \n", "plot_result(probas_uniform)\n", "plot_result(probas_band)\n", "plot_result(probas_blocks)\n", "plot_result(probas_checkboard)" ] }, { "cell_type": "code", "execution_count": null, "id": "93234c68", "metadata": {}, "outputs": [], "source": [ "train_mu, train_std, model = train_model(\n", " probas_train = probas_checkboard,\n", " probas_test = probas_uniform,\n", ")\n", "\n", "x = torch.linspace(-1, 1, 250, device = device)\n", "\n", "for n1, n2 in [ (1, 5), (1, 2), (5, 3), (4, 5) ]:\n", " plt.plot(x.to('cpu'), comp(n1, n2, x).to('cpu'), label = 'ground truth')\n", " plt.plot(x.to('cpu'), prediction(model, train_mu, train_std, n1, n2, x).to('cpu'), label = 'prediction')\n", " plt.legend()\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "4492c9d6", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 5 }