X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=08aa8caf997a2b54ca3cea8fb29dd784c18820e8;hb=128d372813e99d8474bb6e967d5c7e7f085c819d;hp=a53d213a8baa21715e68c5f3028b865cdacb3534;hpb=ac3d9ba45d72a7f3e399de4e3614698ac5e0ce39;p=picoclvr.git diff --git a/tasks.py b/tasks.py index a53d213..08aa8ca 100755 --- a/tasks.py +++ b/tasks.py @@ -1475,6 +1475,7 @@ class Grid(Task): nb_test_samples, batch_size, size, + fraction_play=0.0, logger=None, device=torch.device("cpu"), ): @@ -1490,10 +1491,12 @@ class Grid(Task): ) self.train_descr = self.grid_factory.generate_samples( - nb_train_samples, lambda r: tqdm.tqdm(r) + nb=nb_train_samples, + fraction_play=fraction_play, + progress_bar=lambda r: tqdm.tqdm(r), ) self.test_descr = self.grid_factory.generate_samples( - nb_test_samples, lambda r: tqdm.tqdm(r) + nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r) ) # Build the tokenizer