+
+input_a, input_b, count = create_pair_set(used_classes, test_input, test_target)
+
+for e in range(nb_epochs):
+ class_proba = count.float()
+ class_proba /= class_proba.sum()
+ class_entropy = - (class_proba.log() * class_proba).sum().item()
+
+ input_br = input_b[torch.randperm(input_b.size(0))]
+
+ mi = 0.0
+
+ for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
+ input_b.split(batch_size),
+ input_br.split(batch_size)):
+ loss = - (model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log())
+ mi -= loss.item()
+
+ mi /= (input_a.size(0) // batch_size)
+
+print('test %.04f %.04f'%(mi / math.log(2), class_entropy / math.log(2)))
+
+######################################################################