Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 0f2cb61..9437136 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -706,7 +706,7 @@ if args.task == "expr" and args.expr_input_file is not None:
 # Compute the entropy of the training tokens
 
 token_count = 0
-for input in task.batches(split="train"):
+for input in task.batches(split="train", desc="train-entropy"):
     token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
 token_probas = token_count / token_count.sum()
 entropy = -torch.xlogy(token_probas, token_probas).sum()
@@ -728,9 +728,13 @@ if args.max_percents_of_test_in_train >= 0:
         yield s
 
     nb_test, nb_in_train = 0, 0
-    for test_subset in subsets_as_tuples(task.batches(split="test"), 25000):
+    for test_subset in subsets_as_tuples(
+        task.batches(split="test", desc="test-check"), 25000
+    ):
         in_train = set()
-        for train_subset in subsets_as_tuples(task.batches(split="train"), 25000):
+        for train_subset in subsets_as_tuples(
+            task.batches(split="train", desc="train-check"), 25000
+        ):
             in_train.update(test_subset.intersection(train_subset))
         nb_in_train += len(in_train)
         nb_test += len(test_subset)