class TaskPicoCLVR(Task):
- def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')):
+ def __init__(self, batch_size,
+ height = 6, width = 8, many_colors = False,
+ device = torch.device('cpu')):
+
self.batch_size = batch_size
self.device = device
nb = args.data_size if args.data_size > 0 else 250000
- descr = picoclvr.generate(nb, height = height, width = width)
+ descr = picoclvr.generate(
+ nb,
+ height = height, width = width,
+ many_colors = many_colors
+ )
+
descr = [ s.strip().split(' ') for s in descr ]
l = max([ len(s) for s in descr ])
descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
def batches(self, split = 'train'):
    """Yield mini-batches of the requested split.

    Args:
        split: either 'train' or 'test'; selects which input tensor
            the batches are drawn from.

    Yields:
        Tensors of at most ``self.batch_size`` rows, produced by
        splitting the selected input tensor.
    """
    assert split in { 'train', 'test' }
    # The two branches differed only in the source tensor; pick it once
    # instead of duplicating the tqdm loop in each branch.
    input = self.train_input if split == 'train' else self.test_input
    for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
        yield batch
def vocabulary_size(self):
if args.data_size > 0:
data_iter = itertools.islice(data_iter, args.data_size)
- return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch'))
+ return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
def vocabulary_size(self):
    """Return the number of distinct tokens this task's vocabulary holds."""
    vocab = self.vocab
    return len(vocab)
data_input = data_set.data.view(-1, 28 * 28).long()
if args.data_size >= 0:
data_input = data_input[:args.data_size]
- for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'):
+ for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
yield batch
def vocabulary_size(self):