X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=tinymnist.py;h=896477e9527adec5a362a03bdb212da39fd6a413;hp=8642b22082f844864af3fecd4521d04d09c1df67;hb=05b9b133a45ac9bd5abe6f8b6d29095f9c82797a;hpb=ca897077ed89fbc3c7e8d812ad262146a0c72b71 diff --git a/tinymnist.py b/tinymnist.py index 8642b22..896477e 100755 --- a/tinymnist.py +++ b/tinymnist.py @@ -12,47 +12,49 @@ from torch.nn import functional as F lr, nb_epochs, batch_size = 1e-1, 10, 100 -data_dir = os.environ.get('PYTORCH_DATA_DIR') or './data/' +data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/" -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ###################################################################### -train_set = torchvision.datasets.MNIST(root = data_dir, train = True, download = True) +train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True) train_input = train_set.data.view(-1, 1, 28, 28).float() train_targets = train_set.targets -test_set = torchvision.datasets.MNIST(root = data_dir, train = False, download = True) +test_set = torchvision.datasets.MNIST(root=data_dir, train=False, download=True) test_input = test_set.data.view(-1, 1, 28, 28).float() test_targets = test_set.targets ###################################################################### + class SomeLeNet(nn.Module): def __init__(self): super().__init__() - self.conv1 = nn.Conv2d(1, 32, kernel_size = 5) - self.conv2 = nn.Conv2d(32, 64, kernel_size = 5) + self.conv1 = nn.Conv2d(1, 32, kernel_size=5) + self.conv2 = nn.Conv2d(32, 64, kernel_size=5) self.fc1 = nn.Linear(256, 200) self.fc2 = nn.Linear(200, 10) def forward(self, x): - x = F.relu(F.max_pool2d(self.conv1(x), kernel_size = 3)) - x = F.relu(F.max_pool2d(self.conv2(x), kernel_size = 2)) + x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3)) + x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2)) x = x.view(x.size(0), -1) x = F.relu(self.fc1(x)) x = self.fc2(x) return x + ###################################################################### model = SomeLeNet() nb_parameters = sum(p.numel() for p in model.parameters()) -print(f'nb_parameters {nb_parameters}') +print(f"nb_parameters {nb_parameters}") -optimizer = torch.optim.SGD(model.parameters(), lr = lr) +optimizer = torch.optim.SGD(model.parameters(), lr=lr) criterion = nn.CrossEntropyLoss() model.to(device) @@ -68,10 +70,11 @@ test_input.sub_(mu).div_(std) start_time = time.perf_counter() for k in range(nb_epochs): - acc_loss = 0. + acc_loss = 0.0 - for input, targets in zip(train_input.split(batch_size), - train_targets.split(batch_size)): + for input, targets in zip( + train_input.split(batch_size), train_targets.split(batch_size) + ): output = model(input) loss = criterion(output, targets) acc_loss += loss.item() @@ -81,13 +84,14 @@ for k in range(nb_epochs): optimizer.step() nb_test_errors = 0 - for input, targets in zip(test_input.split(batch_size), - test_targets.split(batch_size)): + for input, targets in zip( + test_input.split(batch_size), test_targets.split(batch_size) + ): wta = model(input).argmax(1) nb_test_errors += (wta != targets).long().sum() test_error = nb_test_errors / test_input.size(0) duration = time.perf_counter() - start_time - print(f'loss {k} {duration:.02f}s {acc_loss:.02f} {test_error*100:.02f}%') + print(f"loss {k} {duration:.02f}s {acc_loss:.02f} {test_error*100:.02f}%") ######################################################################