+ np=torch.tensor(np)
+ count=torch.empty(np[:,0].max()+1,np[:,2].max()+1,dtype=torch.int64)
+ for i in range(count.size(0)):
+ for j in range(count.size(1)):
+ count[i,j]=((np[:,0]==i).long()*(np[:,2]==j).long()).sum()
+
+ if generate_images:
+ img = [
+ picoclvr.descr2img(d, height = self.height, width = self.width)
+ for d in result_descr
+ ]
+
+ img = torch.cat(img, 0)
+ image_name = f'result_picoclvr_{n_epoch:04d}.png'
+ torchvision.utils.save_image(
+ img / 255.,
+ image_name, nrow = nb_per_primer, pad_value = 0.8
+ )
+ log_string(f'wrote {image_name}')
+
+ return count
+
+ def produce_results(self, n_epoch, model):
+ primers_descr = [
+ 'red above green <sep> green top <sep> blue right of red <img>',
+ 'there is red <sep> there is yellow <sep> there is blue <img>',
+ 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
+ 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',