Any copyright is dedicated to the Public Domain.
https://creativecommons.org/publicdomain/zero/1.0/

Written by Francois Fleuret
https://fleuret.org/francois

In [None]:
import math

import torch
import torch.nn.functional as F
from torch import nn

import matplotlib.pyplot as plt

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Elementary scalar mappings, keyed by display name. They are applied
# element-wise to tensors; the insertion order defines the index of each
# mapping in `mappings` and of its label in `mapping_names`.
_named_mappings = {
    'id': lambda x: x,
    'sin': lambda x: torch.sin(x * math.pi),
    'cos': lambda x: torch.cos(x * math.pi),
    'sigmoid': lambda x: torch.sigmoid(5 * x) * 2 - 1,
    'gap': lambda x: 0.25 * x + 0.75 * torch.sign(x),
    'stairs': lambda x: torch.ceil(x * 2) / 2,
}

mappings = list(_named_mappings.values())

mapping_names = list(_named_mappings.keys())

def comp(n1, n2, x):
    """Apply mapping number `n1` to `x`, then mapping number `n2` to the result."""
    inner, outer = mappings[n1], mappings[n2]
    return outer(inner(x))

# Draw every elementary mapping on 250 evenly spaced points of [-1, 1],
# one labelled curve per mapping.
x = torch.linspace(-1, 1, 250)

for name, mapping in zip(mapping_names, mappings):
    plt.plot(x, mapping(x), label = name)

plt.legend()
plt.show()

In [None]:
def create_set(nb, probas):
    """Sample `nb` examples of randomly chosen compositions of two mappings.

    Args:
        nb: number of samples to generate.
        probas: tensor of non-negative weights, one per ordered pair of
            mappings. Once flattened, entry `n1 * len(mappings) + n2` is the
            weight of `comp(n1, n2, .)`. It is normalized here, so it does
            not have to sum to one.

    Returns:
        A pair `(x, y)`. Each row of `x` holds the scalar input, drawn
        uniformly in [-1, 1), followed by the +1/-1 one-hot code of the first
        mapping and the +1/-1 one-hot code of the second mapping. `y` has
        shape `(nb, 1)` and holds the value of the selected composition at
        that input.
    """
    nb_mappings = len(mappings)
    probas = probas.view(-1) / probas.sum()
    x = torch.rand(nb, device = device) * 2 - 1

    # Evaluate every composition on every sample, one column per pair: column
    # k is mapping k // nb_mappings followed by mapping k % nb_mappings.
    y = torch.stack(
        [comp(k // nb_mappings, k % nb_mappings, x) for k in range(nb_mappings**2)],
        dim = 1,
    )

    # Draw one pair index per sample and keep only that column, that is
    # y[n, 0] = y[n, a[n]].
    a = torch.distributions.categorical.Categorical(probas).sample((nb,))
    y = y.gather(dim = 1, index = a[:, None])

    a1 = F.one_hot(a.div(nb_mappings, rounding_mode = 'floor'), num_classes = nb_mappings)
    a2 = F.one_hot(a % nb_mappings, num_classes = nb_mappings)

    # The one-hot codes are integer tensors: cast them explicitly so that the
    # concatenation does not depend on implicit type promotion.
    x = torch.cat(
        (x[:, None], (a1 * 2 - 1).to(x.dtype), (a2 * 2 - 1).to(x.dtype)),
        dim = 1,
    )

    return x, y

# Sampling weights over ordered pairs of mappings: entry [n1, n2] is the
# weight given to the composition comp(n1, n2, .).

# Every pair gets the same weight.
probas_uniform = torch.ones(len(mappings), len(mappings), device = device)

a = torch.arange(len(mappings), device = device)

# Pairs whose index difference, taken modulo the number of mappings, is
# strictly less than half the number of mappings.
probas_band = (a[:, None] - a[None, :]).remainder(len(mappings)).lt(len(mappings) / 2).float()

# Pairs whose two indices have the same quotient by len(mappings) // 2.
probas_blocks = (
    torch.div(a, len(mappings) // 2, rounding_mode = 'floor')[:, None] ==
    torch.div(a, len(mappings) // 2, rounding_mode = 'floor')[None, :]
).float()

# Pairs whose index sum is even.
probas_checkboard = (a[:, None] + a[None, :]).remainder(2).eq(0).float()

# Disabled variant that additionally enables row 0 and column 0:
#probas_checkboard = (((a[:, None] + a[None, :])%2 == 0) + (a[:, None] == 0) + (a[None, :] == 0)).float()

for probas_pattern in (probas_uniform, probas_band, probas_blocks, probas_checkboard):
    print(probas_pattern)

In [None]:
def train_model(probas_train, probas_test, nb_samples = 100000, nb_epochs = 25, verbose = False):
    """Train a small MLP to predict the value of a composition of two mappings.

    Args:
        probas_train: pair weights passed to `create_set` for the training set.
        probas_test: pair weights passed to `create_set` for the test set.
        nb_samples: number of samples in each of the two sets.
        nb_epochs: number of passes over the training set.
        verbose: when true, print after each epoch the epoch index, the mean
            training loss and the mean test loss.

    Returns:
        A triple `(train_mu, train_std, model)`: the scalar mean and standard
        deviation used to normalize the inputs, and the trained model. The
        same normalization has to be applied to any input fed to the model.
    """
    dim_hidden = 64
    batch_size = 100

    model = nn.Sequential(
        nn.Linear(1 + len(mappings) * 2, dim_hidden),
        nn.ReLU(),
        nn.Linear(dim_hidden, dim_hidden),
        nn.ReLU(),
        nn.Linear(dim_hidden, 1),
    ).to(device)

    # The test set is always generated, even when it is not evaluated, so that
    # the random number stream does not depend on `verbose`.
    train_input, train_targets = create_set(nb_samples, probas_train)
    test_input, test_targets = create_set(nb_samples, probas_test)

    # A single scalar mean / standard deviation, computed over all the entries
    # of the training input, normalizes both sets.
    train_mu, train_std = train_input.mean(), train_input.std()
    train_input = (train_input - train_mu) / train_std
    test_input = (test_input - train_mu) / train_std

    for k in range(nb_epochs):
        # The learning rate decays as 1 / (k + 1). Re-creating the optimizer
        # every epoch also resets Adam's running moment estimates; this is
        # kept as is so that training results are unchanged.
        optimizer = torch.optim.Adam(model.parameters(), lr = 1e-2 / (k + 1))

        acc_train_loss = 0.0

        for batch_input, batch_targets in zip(train_input.split(batch_size),
                                              train_targets.split(batch_size)):
            output = model(batch_input)
            loss = F.mse_loss(output, batch_targets)
            acc_train_loss += loss.item() * batch_input.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if verbose:
            acc_test_loss = 0.0

            # Evaluation only: no autograd graph is needed.
            with torch.no_grad():
                for batch_input, batch_targets in zip(test_input.split(batch_size),
                                                      test_targets.split(batch_size)):
                    output = model(batch_input)
                    loss = F.mse_loss(output, batch_targets)
                    acc_test_loss += loss.item() * batch_input.size(0)

            print(f'loss {k} {acc_train_loss/train_input.size(0):f} {acc_test_loss/test_input.size(0):f}')

    return train_mu, train_std, model

def prediction(model, mu, std, n1, n2, x):
    """Return the model's output for mapping `n1` followed by mapping `n2` at points `x`.

    `x` is a 1d tensor of inputs; `mu` and `std` are the normalization
    constants applied to the assembled input before it is fed to `model`.
    The result is a detached 1d tensor with one value per entry of `x`.
    """
    nb_points = x.size(0)
    nb_mappings = len(mappings)

    def signed_one_hot(index):
        # +1 at position `index` and -1 elsewhere, repeated for every point.
        labels = torch.full((nb_points,), index, device = device)
        return F.one_hot(labels, num_classes = nb_mappings) * 2 - 1

    features = torch.cat((x[:, None], signed_one_hot(n1), signed_one_hot(n2)), dim = 1)
    normalized = (features - mu) / std
    return model(normalized).view(-1).detach()

In [None]:
def plot_result(probas_train):
    """Train a model with pair weights `probas_train` and show its error matrix.

    The model is trained with `train_model`, using `probas_uniform` for the
    test set. For every ordered pair of mappings, the mean squared error
    between the true composition and the model's prediction is measured on
    250 evenly spaced points of [-1, 1]. The matrix of errors, indexed by
    (first mapping, second mapping), is displayed with a color scale going
    from 0 to 1.
    """
    train_mu, train_std, model = train_model(
        probas_train = probas_train,
        probas_test = probas_uniform,
    )

    nb_mappings = len(mappings)
    errors = torch.empty(nb_mappings, nb_mappings)
    grid = torch.linspace(-1, 1, 250, device = device)

    # Visit the pairs in row-major order: (0, 0), (0, 1), ...
    for k in range(nb_mappings**2):
        first, second = divmod(k, nb_mappings)
        truth = comp(first, second, grid)
        predicted = prediction(model, train_mu, train_std, first, second, grid)
        errors[first, second] = F.mse_loss(truth, predicted)

    plt.matshow(errors, cmap = plt.cm.Blues, vmin = 0, vmax = 1)
 
# Train one model per sampling pattern and show its error matrix.
for train_pattern in (probas_uniform, probas_band, probas_blocks, probas_checkboard):
    plot_result(train_pattern)

In [None]:
# Train on the checkerboard pattern, then, for a few pairs of mappings, plot
# the true composition and the model's prediction on the same axes.
train_mu, train_std, model = train_model(
    probas_train = probas_checkboard,
    probas_test = probas_uniform,
)

x = torch.linspace(-1, 1, 250, device = device)

for n1, n2 in ((1, 5), (1, 2), (5, 3), (4, 5)):
    truth = comp(n1, n2, x)
    predicted = prediction(model, train_mu, train_std, n1, n2, x)

    # matplotlib needs the data on the CPU.
    plt.plot(x.to('cpu'), truth.to('cpu'), label = 'ground truth')
    plt.plot(x.to('cpu'), predicted.to('cpu'), label = 'prediction')
    plt.legend()
    plt.show()