X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=tinyae.py;h=70484f151c866d57e1ee8cd1b25c2b3d5b87d483;hp=c608c9ccb5e5c5e45bf786ae40242abc31fe2961;hb=HEAD;hpb=dcb8e93a2f882abf1a30326fe419a592484deb18 diff --git a/tinyae.py b/tinyae.py index c608c9c..b4f3aba 100755 --- a/tinyae.py +++ b/tinyae.py @@ -14,77 +14,75 @@ from torch.nn import functional as F ###################################################################### -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ###################################################################### -parser = argparse.ArgumentParser(description = 'Tiny LeNet-like auto-encoder.') +parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.") -parser.add_argument('--nb_epochs', - type = int, default = 25) +parser.add_argument("--nb_epochs", type=int, default=25) -parser.add_argument('--batch_size', - type = int, default = 100) +parser.add_argument("--batch_size", type=int, default=100) -parser.add_argument('--data_dir', - type = str, default = './data/') +parser.add_argument("--data_dir", type=str, default="./data/") -parser.add_argument('--log_filename', - type = str, default = 'train.log') +parser.add_argument("--log_filename", type=str, default="train.log") -parser.add_argument('--embedding_dim', - type = int, default = 8) +parser.add_argument("--embedding_dim", type=int, default=8) -parser.add_argument('--nb_channels', - type = int, default = 32) - -parser.add_argument('--force_train', - type = bool, default = False) +parser.add_argument("--nb_channels", type=int, default=32) args = parser.parse_args() -log_file = open(args.log_filename, 'w') +log_file = open(args.log_filename, "w") ###################################################################### + def log_string(s): t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime()) if log_file is not None: - log_file.write(t + s + '\n') + log_file.write(t + s + "\n") log_file.flush() print(t + s) sys.stdout.flush() + ###################################################################### + class AutoEncoder(nn.Module): def __init__(self, nb_channels, embedding_dim): - super(AutoEncoder, self).__init__() + super().__init__() self.encoder = nn.Sequential( - nn.Conv2d(1, nb_channels, kernel_size = 5), # to 24x24 - nn.ReLU(inplace = True), - nn.Conv2d(nb_channels, nb_channels, kernel_size = 5), # to 20x20 - nn.ReLU(inplace = True), - nn.Conv2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # to 9x9 - nn.ReLU(inplace = True), - nn.Conv2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # to 4x4 - nn.ReLU(inplace = True), - nn.Conv2d(nb_channels, embedding_dim, kernel_size = 4) + nn.Conv2d(1, nb_channels, kernel_size=5), # to 24x24 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=5), # to 20x20 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2), # to 9x9 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2), # to 4x4 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, embedding_dim, kernel_size=4), ) self.decoder = nn.Sequential( - nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size = 4), - nn.ReLU(inplace = True), - nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # from 4x4 - nn.ReLU(inplace = True), - nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # from 9x9 - nn.ReLU(inplace = True), - nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 5), # from 20x20 - nn.ReLU(inplace = True), - nn.ConvTranspose2d(nb_channels, 1, kernel_size = 5), # from 24x24 + nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size=4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d( + nb_channels, nb_channels, kernel_size=3, stride=2 + ), # from 4x4 + nn.ReLU(inplace=True), + nn.ConvTranspose2d( + nb_channels, nb_channels, kernel_size=4, stride=2 + ), # from 9x9 + nn.ReLU(inplace=True), + nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5), # from 20x20 + nn.ReLU(inplace=True), + nn.ConvTranspose2d(nb_channels, 1, kernel_size=5), # from 24x24 ) def encode(self, x): @@ -98,20 +96,23 @@ class AutoEncoder(nn.Module): x = self.decoder(x) return x + ###################################################################### -train_set = torchvision.datasets.MNIST(args.data_dir + '/mnist/', - train = True, download = True) +train_set = torchvision.datasets.MNIST( + args.data_dir + "/mnist/", train=True, download=True +) train_input = train_set.data.view(-1, 1, 28, 28).float() -test_set = torchvision.datasets.MNIST(args.data_dir + '/mnist/', - train = False, download = True) +test_set = torchvision.datasets.MNIST( + args.data_dir + "/mnist/", train=False, download=True +) test_input = test_set.data.view(-1, 1, 28, 28).float() ###################################################################### model = AutoEncoder(args.nb_channels, args.embedding_dim) -optimizer = optim.Adam(model.parameters(), lr = 1e-3) +optimizer = optim.Adam(model.parameters(), lr=1e-3) model.to(device) @@ -124,7 +125,6 @@ test_input.sub_(mu).div_(std) ###################################################################### for epoch in range(args.nb_epochs): - acc_loss = 0 for input in train_input.split(args.batch_size): @@ -137,7 +137,7 @@ for epoch in range(args.nb_epochs): acc_loss += loss.item() - log_string('acc_loss {:d} {:f}.'.format(epoch, acc_loss)) + log_string("acc_loss {:d} {:f}.".format(epoch, acc_loss)) ###################################################################### @@ -148,8 +148,8 @@ input = test_input[:256] z = model.encode(input) output = model.decode(z) -torchvision.utils.save_image(1 - input, 'ae-input.png', nrow = 16, pad_value = 0.8) -torchvision.utils.save_image(1 - output, 'ae-output.png', nrow = 16, pad_value = 0.8) +torchvision.utils.save_image(1 - input, "ae-input.png", nrow=16, pad_value=0.8) +torchvision.utils.save_image(1 - output, "ae-output.png", nrow=16, pad_value=0.8) # Dumb synthesis @@ -158,6 +158,6 @@ mu, std = z.mean(0), z.std(0) z = z.normal_() * std + mu output = model.decode(z) -torchvision.utils.save_image(1 - output, 'ae-synth.png', nrow = 16, pad_value = 0.8) +torchvision.utils.save_image(1 - output, "ae-synth.png", nrow=16, pad_value=0.8) ######################################################################