From 62533ba50393866c15b322074cad836684dd69e7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 3 Dec 2022 14:06:44 -0600 Subject: [PATCH] Update. --- main.py | 67 +++++++++++++++++++++++++++++++++++++++-------------- picoclvr.py | 19 ++++++++++----- 2 files changed, 63 insertions(+), 23 deletions(-) diff --git a/main.py b/main.py index b6eb6fe..aa1b517 100755 --- a/main.py +++ b/main.py @@ -216,17 +216,11 @@ class TaskPicoCLVR(Task): def vocabulary_size(self): return len(self.token2id) - def produce_results(self, n_epoch, model): + def test_model(self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False): nb_tokens_to_generate = self.height * self.width + 3 result_descr = [ ] - nb_per_primer = 8 - for primer_descr in [ - 'red above green green top blue right of red ', - 'there is red there is yellow there is blue ', - 'red below yellow yellow below green green below blue red right yellow left green right blue left ', - 'green bottom yellow bottom green left of blue yellow right of blue blue top ', - ]: + for primer_descr in primers_descr: results = autoregression( model, @@ -249,18 +243,57 @@ class TaskPicoCLVR(Task): log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}') - img = [ - picoclvr.descr2img(d, height = self.height, width = self.width) - for d in result_descr + np=torch.tensor(np) + count=torch.empty(np[:,0].max()+1,np[:,2].max()+1,dtype=torch.int64) + for i in range(count.size(0)): + for j in range(count.size(1)): + count[i,j]=((np[:,0]==i).long()*(np[:,2]==j).long()).sum() + + if generate_images: + img = [ + picoclvr.descr2img(d, height = self.height, width = self.width) + for d in result_descr + ] + + img = torch.cat(img, 0) + image_name = f'result_picoclvr_{n_epoch:04d}.png' + torchvision.utils.save_image( + img / 255., + image_name, nrow = nb_per_primer, pad_value = 0.8 + ) + log_string(f'wrote {image_name}') + + return count + + def produce_results(self, n_epoch, model): + primers_descr = [ + 'red above green green top blue right of red ', + 'there is red there is yellow there is blue ', + 'red below yellow yellow below green green below blue red right yellow left green right blue left ', + 'green bottom yellow bottom green left of blue yellow right of blue blue top ', ] - img = torch.cat(img, 0) - image_name = f'result_picoclvr_{n_epoch:04d}.png' - torchvision.utils.save_image( - img / 255., - image_name, nrow = nb_per_primer, pad_value = 0.8 + self.test_model( + n_epoch, model, + primers_descr, + nb_per_primer=8, generate_images=True ) - log_string(f'wrote {image_name}') + + # FAR TOO SLOW!!! + + # test_primers_descr=[ s.split('')[0] for s in self.test_descr ] + + # count=self.test_model( + # n_epoch, model, + # test_primers_descr, + # nb_per_primer=1, generate_images=False + # ) + + # with open(f'perf_{n_epoch:04d}.txt', 'w') as f: + # for i in range(count.size(0)): + # for j in range(count.size(1)): + # f.write(f'{count[i,j]}') + # f.write(" " if j= max_nb_squares and nb_colors <= len(color_tokens) - 1 @@ -117,6 +118,9 @@ def generate(nb, height, width, s = all_properties(height, width, nb_squares, square_i, square_j, square_c) + if pruning_criterion is not None: + s = list(filter(pruning_criterion,s)) + # pick at most max_nb_properties at random nb_properties = torch.randint(max_nb_properties, (1,)) + 1 @@ -206,23 +210,26 @@ def nb_properties(descr, height, width): ###################################################################### if __name__ == '__main__': - descr = generate(nb = 5) + descr = generate( + nb = 5, height = 12, width = 16, + pruning_criterion = lambda s: not ('green' in s and ('right' in s or 'left' in s)) + ) - #print(descr2properties(descr)) - print(nb_properties(descr)) + print(descr2properties(descr, height = 12, width = 16)) + print(nb_properties(descr, height = 12, width = 16)) with open('picoclvr_example.txt', 'w') as f: for d in descr: f.write(f'{d}\n\n') - img = descr2img(descr) + img = descr2img(descr, height = 12, width = 16) torchvision.utils.save_image(img / 255., 'picoclvr_example.png', nrow = 16, pad_value = 0.8) import time start_time = time.perf_counter() - descr = generate(nb = 1000) + descr = generate(nb = 1000, height = 12, width = 16) end_time = time.perf_counter() print(f'{len(descr) / (end_time - start_time):.02f} samples per second') -- 2.20.1