X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=9437136ce1a45b066d6884e205540083bfb4d2d6;hb=c3581ba868cd30cb45fbe2f97b80ddbc1fc26bbb;hp=0f2cb61f5645839cc2e205f1cfd3097fe8075697;hpb=232299b8af7e66a02e64bb2e47b525e2f50b099d;p=picoclvr.git diff --git a/main.py b/main.py index 0f2cb61..9437136 100755 --- a/main.py +++ b/main.py @@ -706,7 +706,7 @@ if args.task == "expr" and args.expr_input_file is not None: # Compute the entropy of the training tokens token_count = 0 -for input in task.batches(split="train"): +for input in task.batches(split="train", desc="train-entropy"): token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1)) token_probas = token_count / token_count.sum() entropy = -torch.xlogy(token_probas, token_probas).sum() @@ -728,9 +728,13 @@ if args.max_percents_of_test_in_train >= 0: yield s nb_test, nb_in_train = 0, 0 - for test_subset in subsets_as_tuples(task.batches(split="test"), 25000): + for test_subset in subsets_as_tuples( + task.batches(split="test", desc="test-check"), 25000 + ): in_train = set() - for train_subset in subsets_as_tuples(task.batches(split="train"), 25000): + for train_subset in subsets_as_tuples( + task.batches(split="train", desc="train-check"), 25000 + ): in_train.update(test_subset.intersection(train_subset)) nb_in_train += len(in_train) nb_test += len(test_subset)