X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=066f1bbec05fcc0d65365823ea60931010e118cb;hb=26ef53ee3769c3b6b92b85d15b5a43cbd18ede07;hp=ea10d7cbfc758373e8e75f7d419b45bec16d3ad6;hpb=f44ab6863f93ae348e66ffbf52251d96d3b5453c;p=picoclvr.git

diff --git a/tasks.py b/tasks.py
index ea10d7c..066f1bb 100755
--- a/tasks.py
+++ b/tasks.py
@@ -1570,39 +1570,28 @@ class QMLP(Task):
         self.device = device
         self.batch_size = batch_size
+        self.nb_samples_per_mlp = 256
 
         if logger is not None:
             logger(
                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
             )
 
-        self.train_descr = self.grid_factory.generate_samples(
-            nb_train_samples, lambda r: tqdm.tqdm(r)
-        )
-        self.test_descr = self.grid_factory.generate_samples(
-            nb_test_samples, lambda r: tqdm.tqdm(r)
+        seq, q_test_set = generate_sequence_and_test_set(
+            nb_mlps=nb_train_samples+nb_test_samples,
+            nb_samples=self.nb_samples_per_mlp,
+            device=self.device,
+            batch_size=64,
+            nb_epochs=250,
+            nb_mlps_per_batch=1024
         )
 
-        # Build the tokenizer
-        tokens = set()
-        for d in [self.train_descr, self.test_descr]:
-            for s in d:
-                for t in s.strip().split(" "):
-                    tokens.add(t)
-        # make this set a sorted list to get the same tensors given
-        # the same descr
-        tokens = list(tokens)
-        tokens.sort()
-        tokens = ["#"] + tokens
-        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
-        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
-        self.t_nul = self.token2id["#"]
-        self.t_true = self.token2id["true"]
-        self.t_false = self.token2id["false"]
+        self.train_input = seq[:nb_train_samples]
+        self.train_q_test_set = q_test_set[:nb_train_samples]
+        self.test_input = seq[nb_train_samples:]
+        self.test_q_test_set = q_test_set[nb_train_samples:]
 
-        # Tokenize the train and test sets
-        self.train_input = self.str2tensor(self.train_descr)
-        self.test_input = self.str2tensor(self.test_descr)
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
     def batches(self, split="train"):
         assert split in {"train", "test"}
@@ -1613,14 +1602,14 @@ class QMLP(Task):
             yield self.trim(batch)
 
     def vocabulary_size(self):
-        return len(self.token2id)
+        return self.nb_codes
 
     def produce_results(
         self, n_epoch, model, result_dir, logger, deterministic_synthesis
     ):
         correct = self.test_input[:1000]
         result = correct.clone()
-        ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
+        ar_mask = torch.arange(result.size(1)) > self.nb_samples_per_mlp * 3 + 1
         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
 
         logger(f"----------------------------------------------------------")
         logger(f"----------------------------------------------------------")
 
-        nb_total = ar_mask.sum().item()
-        nb_correct = ((correct == result).long() * ar_mask).sum().item()
+        q_train_set = result[:, : nb_samples * 3]
+        q_params = result[:, nb_samples * 3 + 1 :]
+        error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
 
-        logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
-        logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
+        logger(f"{error_test=}")
 
 ######################################################################