Update.
author: Francois Fleuret <francois@fleuret.org>
Fri, 10 Jun 2022 09:18:26 +0000 (11:18 +0200)
committer: Francois Fleuret <francois@fleuret.org>
Fri, 10 Jun 2022 09:18:26 +0000 (11:18 +0200)
main.py

diff --git a/main.py b/main.py
index a6940f1..a31284e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -136,10 +136,10 @@ class TaskPicoCLVR(Task):
     def batches(self, split = 'train'):
         assert split in { 'train', 'test' }
         if split == 'train':
-            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'):
+            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
                 yield batch
         else:
-            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'):
+            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
                 yield batch
 
     def vocabulary_size(self):
@@ -237,7 +237,7 @@ class TaskWiki103(Task):
         if args.data_size > 0:
             data_iter = itertools.islice(data_iter, args.data_size)
 
-        return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch'))
+        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
 
     def vocabulary_size(self):
         return len(self.vocab)
@@ -296,7 +296,7 @@ class TaskMNIST(Task):
         data_input = data_set.data.view(-1, 28 * 28).long()
         if args.data_size >= 0:
             data_input = data_input[:args.data_size]
-        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'):
+        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
             yield batch
 
     def vocabulary_size(self):