From: Francois Fleuret Date: Thu, 15 Nov 2018 10:21:25 +0000 (+0100) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=commitdiff_plain;h=7abf09dfdb0059f0a0f4d4fcb5892f030ee75e4e Update. --- diff --git a/mine_mnist.py b/mine_mnist.py index c6dc287..82f6530 100755 --- a/mine_mnist.py +++ b/mine_mnist.py @@ -94,20 +94,21 @@ for e in range(nb_epochs): input_br = input_b[torch.randperm(input_b.size(0))] - mi = 0.0 + acc_mi = 0.0 for batch_a, batch_b, batch_br in zip(input_a.split(batch_size), input_b.split(batch_size), input_br.split(batch_size)): - loss = - (model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()) - mi -= loss.item() + mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log() + loss = - mi + acc_mi += mi.item() optimizer.zero_grad() loss.backward() optimizer.step() - mi /= (input_a.size(0) // batch_size) + acc_mi /= (input_a.size(0) // batch_size) - print('%d %.04f %.04f'%(e, mi / math.log(2), class_entropy / math.log(2))) + print('%d %.04f %.04f'%(e, acc_mi / math.log(2), class_entropy / math.log(2))) sys.stdout.flush() @@ -122,16 +123,16 @@ for e in range(nb_epochs): input_br = input_b[torch.randperm(input_b.size(0))] - mi = 0.0 + acc_mi = 0.0 for batch_a, batch_b, batch_br in zip(input_a.split(batch_size), input_b.split(batch_size), input_br.split(batch_size)): loss = - (model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()) - mi -= loss.item() + acc_mi -= loss.item() - mi /= (input_a.size(0) // batch_size) + acc_mi /= (input_a.size(0) // batch_size) -print('test %.04f %.04f'%(mi / math.log(2), class_entropy / math.log(2))) +print('test %.04f %.04f'%(acc_mi / math.log(2), class_entropy / math.log(2))) ######################################################################