def batches(self, split = 'train'):
assert split in { 'train', 'test' }
if split == 'train':
def batches(self, split = 'train'):
assert split in { 'train', 'test' }
if split == 'train':
data_input = data_set.data.view(-1, 28 * 28).long()
if args.data_size >= 0:
data_input = data_input[:args.data_size]
data_input = data_set.data.view(-1, 28 * 28).long()
if args.data_size >= 0:
data_input = data_input[:args.data_size]