nb_train_samples,
nb_test_samples,
batch_size,
- height,
- width,
+ size,
logger=None,
device=torch.device("cpu"),
):
self.device = device
self.batch_size = batch_size
- self.grid_factory = grid.GridFactory(height=height, width=width)
+ self.grid_factory = grid.GridFactory(size=size)
if logger is not None:
logger(