parser.add_argument("--snake_nb_colors", type=int, default=5)
-parser.add_argument("--snake_length", type=int, default=400)
+parser.add_argument("--snake_length", type=int, default=200)
######################################################################
"batch_size": 25,
},
"snake": {
- "nb_epochs": 25,
- "batch_size": 20,
+ "nb_epochs": 5,
+ "batch_size": 25,
},
}
self.device = device
self.prompt_length = prompt_length
- self.train_input, self.train_prior_visits = snake.generate_sequences(
+ self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
nb_train_samples,
height,
width,
prompt_length,
self.device,
)
- self.test_input, self.test_prior_visits = snake.generate_sequences(
+ self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
nb_test_samples,
height,
width,