Update.
[mygpt.git] / mygpt.py
index 970ee7b..13fbe8e 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -252,6 +252,7 @@ class TaskPicoCLVR(Task):
     def produce_results(self, n_epoch, model, nb_tokens = 50):
         img = [ ]
         nb_per_primer = 8
+
         for primer in [
                 'red above green <sep> green top <sep> blue right of red <img>',
                 'there is red <sep> there is yellow <sep> there is blue <img>',
@@ -507,7 +508,10 @@ for k in range(args.nb_epochs):
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)
 
-        log_string(f'perplexity {k+1} train {math.exp(min(100, acc_train_loss/nb_train_samples))} test {math.exp(min(100, acc_test_loss/nb_test_samples))}')
+        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
+        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
+
+        log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
 
         task.produce_results(k, model)