)
parser.add_argument(
- "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake, stack, expr"
+ "--task",
+ type=str,
+ default="picoclvr",
+ help="picoclvr, mnist, maze, snake, stack, expr",
)
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
progress_bar_desc="autoregression",
device=torch.device("cpu"),
):
-
batches = zip(input.split(batch_size), ar_mask.split(batch_size))
if progress_bar_desc is not None:
train_sequences = expr.generate_sequences(nb_train_samples)
test_sequences = expr.generate_sequences(nb_test_samples)
- self.char2id = dict([ (c,n) for n,c in enumerate(set("".join(train_sequences + test_sequences))) ])
- self.id2char = dict([ (n,c) for n,c in self.char2id.items() ])
+ self.char2id = dict(
+ [
+ (c, n)
+ for n, c in enumerate(set("".join(train_sequences + test_sequences)))
+ ]
+ )
+ self.id2char = dict([(n, c) for n, c in self.char2id.items()])
len_max = max([len(x) for x in train_sequences + test_sequences])
- self.train_input = torch.cat([torch.tensor([char2id(c) for c in s + " "*(len_max-len(s))] for s in train_sequences)], 0)
- self.test_input = torch.cat([torch.tensor([char2id(c) for c in s + " "*(len_max-len(s))] for s in test_sequences)], 0)
+ self.train_input = torch.cat(
+ [
+ torch.tensor(
+ [char2id(c) for c in s + " " * (len_max - len(s))]
+ for s in train_sequences
+ )
+ ],
+ 0,
+ )
+ self.test_input = torch.cat(
+ [
+ torch.tensor(
+ [char2id(c) for c in s + " " * (len_max - len(s))]
+ for s in test_sequences
+ )
+ ],
+ 0,
+ )
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
def batches(self, split="train", nb_to_use=-1, desc=None):