nb_test_samples,
batch_size,
size,
+ fraction_play=0.0,
logger=None,
device=torch.device("cpu"),
):
)
self.train_descr = self.grid_factory.generate_samples(
- nb_train_samples, lambda r: tqdm.tqdm(r)
+ nb=nb_train_samples,
+ fraction_play=fraction_play,
+ progress_bar=lambda r: tqdm.tqdm(r),
)
self.test_descr = self.grid_factory.generate_samples(
- nb_test_samples, lambda r: tqdm.tqdm(r)
+ nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
)
# Build the tokenizer