Update.
[pytorch.git] / causal-autoregression.py
index c2f6161..0c931fb 100755 (executable)
@@ -18,46 +18,46 @@ from torch.nn import functional as F
 
 ######################################################################
 
-def save_images(x, filename, nrow = 12):
-    print(f'Writing {filename}')
-    torchvision.utils.save_image(x.narrow(0,0, min(48, x.size(0))),
-                                 filename,
-                                 nrow = nrow, pad_value=1.0)
+
+def save_images(x, filename, nrow=12):
+    print(f"Writing {filename}")
+    torchvision.utils.save_image(
+        x.narrow(0, 0, min(48, x.size(0))), filename, nrow=nrow, pad_value=1.0
+    )
+
 
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description = 'An implementation of a causal autoregression model',
-    formatter_class = argparse.ArgumentDefaultsHelpFormatter
+    description="An implementation of a causal autoregression model",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
-parser.add_argument('--data',
-                    type = str, default = 'toy1d',
-                    help = 'What data')
+parser.add_argument("--data", type=str, default="toy1d", help="What data")
 
-parser.add_argument('--seed',
-                    type = int, default = 0,
-                    help = 'Random seed (default 0, < 0 is no seeding)')
+parser.add_argument(
+    "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
+)
 
-parser.add_argument('--nb_epochs',
-                    type = int, default = -1,
-                    help = 'How many epochs')
+parser.add_argument("--nb_epochs", type=int, default=-1, help="How many epochs")
 
-parser.add_argument('--batch_size',
-                    type = int, default = 100,
-                    help = 'Batch size')
+parser.add_argument("--batch_size", type=int, default=100, help="Batch size")
 
-parser.add_argument('--learning_rate',
-                    type = float, default = 1e-3,
-                    help = 'Batch size')
+parser.add_argument("--learning_rate", type=float, default=1e-3, help="Batch size")
 
-parser.add_argument('--positional',
-                    action='store_true', default = False,
-                    help = 'Do we provide a positional encoding as input')
+parser.add_argument(
+    "--positional",
+    action="store_true",
+    default=False,
+    help="Do we provide a positional encoding as input",
+)
 
-parser.add_argument('--dilation',
-                    action='store_true', default = False,
-                    help = 'Do we provide a positional encoding as input')
+parser.add_argument(
+    "--dilation",
+    action="store_true",
+    default=False,
+    help="Do we provide a positional encoding as input",
+)
 
 ######################################################################
 
@@ -67,32 +67,33 @@ if args.seed >= 0:
     torch.manual_seed(args.seed)
 
 if args.nb_epochs < 0:
-    if args.data == 'toy1d':
+    if args.data == "toy1d":
         args.nb_epochs = 100
-    elif args.data == 'mnist':
+    elif args.data == "mnist":
         args.nb_epochs = 25
 
 ######################################################################
 
 if torch.cuda.is_available():
-    print('Cuda is available')
-    device = torch.device('cuda')
+    print("Cuda is available")
+    device = torch.device("cuda")
     torch.backends.cudnn.benchmark = True
 else:
-    device = torch.device('cpu')
+    device = torch.device("cpu")
 
 ######################################################################
 
+
 class NetToy1d(nn.Module):
-    def __init__(self, nb_classes, ks = 2, nc = 32):
-        super(NetToy1d, self).__init__()
+    def __init__(self, nb_classes, ks=2, nc=32):
+        super().__init__()
         self.pad = (ks - 1, 0)
-        self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
-        self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks)
-        self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks)
-        self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks)
-        self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks)
-        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)
+        self.conv0 = nn.Conv1d(1, nc, kernel_size=1)
+        self.conv1 = nn.Conv1d(nc, nc, kernel_size=ks)
+        self.conv2 = nn.Conv1d(nc, nc, kernel_size=ks)
+        self.conv3 = nn.Conv1d(nc, nc, kernel_size=ks)
+        self.conv4 = nn.Conv1d(nc, nc, kernel_size=ks)
+        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size=1)
 
     def forward(self, x):
         x = F.relu(self.conv0(F.pad(x, (1, -1))))
@@ -103,19 +104,20 @@ class NetToy1d(nn.Module):
         x = self.conv5(x)
         return x.permute(0, 2, 1).contiguous()
 
+
 class NetToy1dWithDilation(nn.Module):
-    def __init__(self, nb_classes, ks = 2, nc = 32):
-        super(NetToy1dWithDilation, self).__init__()
-        self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
-        self.pad1 = ((ks-1) * 2, 0)
-        self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 2)
-        self.pad2 = ((ks-1) * 4, 0)
-        self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 4)
-        self.pad3 = ((ks-1) * 8, 0)
-        self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 8)
-        self.pad4 = ((ks-1) * 16, 0)
-        self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 16)
-        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)
+    def __init__(self, nb_classes, ks=2, nc=32):
+        super().__init__()
+        self.conv0 = nn.Conv1d(1, nc, kernel_size=1)
+        self.pad1 = ((ks - 1) * 2, 0)
+        self.conv1 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=2)
+        self.pad2 = ((ks - 1) * 4, 0)
+        self.conv2 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=4)
+        self.pad3 = ((ks - 1) * 8, 0)
+        self.conv3 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=8)
+        self.pad4 = ((ks - 1) * 16, 0)
+        self.conv4 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=16)
+        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size=1)
 
     def forward(self, x):
         x = F.relu(self.conv0(F.pad(x, (1, -1))))
@@ -126,21 +128,23 @@ class NetToy1dWithDilation(nn.Module):
         x = self.conv5(x)
         return x.permute(0, 2, 1).contiguous()
 
+
 ######################################################################
 
+
 class PixelCNN(nn.Module):
-    def __init__(self, nb_classes, in_channels = 1, ks = 5):
-        super(PixelCNN, self).__init__()
+    def __init__(self, nb_classes, in_channels=1, ks=5):
+        super().__init__()
 
-        self.hpad = (ks//2, ks//2, ks//2, 0)
-        self.vpad = (ks//2,     0,     0, 0)
+        self.hpad = (ks // 2, ks // 2, ks // 2, 0)
+        self.vpad = (ks // 2, 0, 0, 0)
 
-        self.conv1h = nn.Conv2d(in_channels, 32, kernel_size = (ks//2+1, ks))
-        self.conv2h = nn.Conv2d(32, 64, kernel_size = (ks//2+1, ks))
-        self.conv1v = nn.Conv2d(in_channels, 32, kernel_size = (1, ks//2+1))
-        self.conv2v = nn.Conv2d(32, 64, kernel_size = (1, ks//2+1))
-        self.final1 = nn.Conv2d(128, 128, kernel_size = 1)
-        self.final2 = nn.Conv2d(128, nb_classes, kernel_size = 1)
+        self.conv1h = nn.Conv2d(in_channels, 32, kernel_size=(ks // 2 + 1, ks))
+        self.conv2h = nn.Conv2d(32, 64, kernel_size=(ks // 2 + 1, ks))
+        self.conv1v = nn.Conv2d(in_channels, 32, kernel_size=(1, ks // 2 + 1))
+        self.conv2v = nn.Conv2d(32, 64, kernel_size=(1, ks // 2 + 1))
+        self.final1 = nn.Conv2d(128, 128, kernel_size=1)
+        self.final2 = nn.Conv2d(128, nb_classes, kernel_size=1)
 
     def forward(self, x):
         xh = F.pad(x, (0, 0, 1, -1))
@@ -154,8 +158,10 @@ class PixelCNN(nn.Module):
 
         return x.permute(0, 2, 3, 1).contiguous()
 
+
 ######################################################################
 
+
 def positional_tensor(height, width):
     index_h = torch.arange(height).view(1, -1)
     m_h = (2 ** torch.arange(math.ceil(math.log2(height)))).view(-1, 1)
@@ -169,26 +175,30 @@ def positional_tensor(height, width):
 
     return torch.cat((i_w, i_h), 1)
 
+
 ######################################################################
 
 str_experiment = args.data
 
 if args.positional:
-    str_experiment += '-positional'
+    str_experiment += "-positional"
 
 if args.dilation:
-    str_experiment += '-dilation'
+    str_experiment += "-dilation"
+
+log_file = open("causalar-" + str_experiment + "-train.log", "w")
 
-log_file = open('causalar-' + str_experiment + '-train.log', 'w')
 
 def log_string(s):
-    s = time.strftime("%Y%m%d-%H:%M:%S", time.localtime()) + ' ' + s
+    s = time.strftime("%Y%m%d-%H:%M:%S", time.localtime()) + " " + s
     print(s)
-    log_file.write(s + '\n')
+    log_file.write(s + "\n")
     log_file.flush()
 
+
 ######################################################################
 
+
 def generate_sequences(nb, len):
     nb_parts = 2
 
@@ -196,32 +206,33 @@ def generate_sequences(nb, len):
 
     x = torch.empty(nb, nb_parts).uniform_(-1, 1)
     x = x.view(nb, nb_parts, 1).expand(nb, nb_parts, len)
-    x = x * torch.linspace(0, len-1, len).view(1, -1) + len
+    x = x * torch.linspace(0, len - 1, len).view(1, -1) + len
 
     for n in range(nb):
-        a = torch.randperm(len - 2)[:nb_parts+1].sort()[0]
+        a = torch.randperm(len - 2)[: nb_parts + 1].sort()[0]
         a[0] = 0
         a[a.size(0) - 1] = len
         for k in range(a.size(0) - 1):
-            r[n, a[k]:a[k+1]] = x[n, k, :a[k+1]-a[k]]
+            r[n, a[k] : a[k + 1]] = x[n, k, : a[k + 1] - a[k]]
 
     return r.round().long()
 
+
 ######################################################################
 
-if args.data == 'toy1d':
+if args.data == "toy1d":
     len = 32
     train_input = generate_sequences(50000, len).to(device).unsqueeze(1)
     if args.dilation:
-        model = NetToy1dWithDilation(nb_classes = 2 * len).to(device)
+        model = NetToy1dWithDilation(nb_classes=2 * len).to(device)
     else:
-        model = NetToy1d(nb_classes = 2 * len).to(device)
+        model = NetToy1d(nb_classes=2 * len).to(device)
 
-elif args.data == 'mnist':
-    train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
+elif args.data == "mnist":
+    train_set = torchvision.datasets.MNIST("./data/mnist/", train=True, download=True)
     train_input = train_set.data.view(-1, 1, 28, 28).long().to(device)
 
-    model = PixelCNN(nb_classes = 256, in_channels = 1).to(device)
+    model = PixelCNN(nb_classes=256, in_channels=1).to(device)
     in_channels = train_input.size(1)
 
     if args.positional:
@@ -229,40 +240,35 @@ elif args.data == 'mnist':
         positional_input = positional_tensor(height, width).float().to(device)
         in_channels += positional_input.size(1)
 
-    model = PixelCNN(nb_classes = 256, in_channels = in_channels).to(device)
+    model = PixelCNN(nb_classes=256, in_channels=in_channels).to(device)
 
 else:
-    raise ValueError('Unknown data ' + args.data)
+    raise ValueError("Unknown data " + args.data)
 
 ######################################################################
 
 mean, std = train_input.float().mean(), train_input.float().std()
 
 nb_parameters = sum(t.numel() for t in model.parameters())
-log_string(f'nb_parameters {nb_parameters}')
+log_string(f"nb_parameters {nb_parameters}")
 
 cross_entropy = nn.CrossEntropyLoss().to(device)
-optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
 for e in range(args.nb_epochs):
-
     nb_batches, acc_loss = 0, 0.0
 
     for sequences in train_input.split(args.batch_size):
-        input = (sequences - mean)/std
+        input = (sequences - mean) / std
 
         if args.positional:
             input = torch.cat(
-                (input, positional_input.expand(input.size(0), -1, -1, -1)),
-                1
+                (input, positional_input.expand(input.size(0), -1, -1, -1)), 1
             )
 
         output = model(input)
 
-        loss = cross_entropy(
-            output.view(-1, output.size(-1)),
-            sequences.view(-1)
-        )
+        loss = cross_entropy(output.view(-1, output.size(-1)), sequences.view(-1))
 
         optimizer.zero_grad()
         loss.backward()
@@ -271,7 +277,7 @@ for e in range(args.nb_epochs):
         nb_batches += 1
         acc_loss += loss.item()
 
-    log_string(f'{e} {acc_loss / nb_batches} {math.exp(acc_loss / nb_batches)}')
+    log_string(f"{e} {acc_loss / nb_batches} {math.exp(acc_loss / nb_batches)}")
 
     sys.stdout.flush()
 
@@ -284,36 +290,36 @@ flat = generated.view(generated.size(0), -1)
 for t in range(flat.size(1)):
     input = (generated.float() - mean) / std
     if args.positional:
-        input = torch.cat((input, positional_input.expand(input.size(0), -1, -1, -1)), 1)
+        input = torch.cat(
+            (input, positional_input.expand(input.size(0), -1, -1, -1)), 1
+        )
     output = model(input)
     logits = output.view(flat.size() + (-1,))[:, t]
-    dist = torch.distributions.categorical.Categorical(logits = logits)
+    dist = torch.distributions.categorical.Categorical(logits=logits)
     flat[:, t] = dist.sample()
 
 ######################################################################
 
-if args.data == 'toy1d':
-
-    with open('causalar-' + str_experiment + '-train.dat', 'w') as file:
+if args.data == "toy1d":
+    with open("causalar-" + str_experiment + "-train.dat", "w") as file:
         for j in range(train_input.size(2)):
-            file.write(f'{j}')
+            file.write(f"{j}")
             for i in range(min(train_input.size(0), 25)):
-                file.write(f' {train_input[i, 0, j]}')
-            file.write('\n')
+                file.write(f" {train_input[i, 0, j]}")
+            file.write("\n")
 
-    with open('causalar-' + str_experiment + '-generated.dat', 'w') as file:
+    with open("causalar-" + str_experiment + "-generated.dat", "w") as file:
         for j in range(generated.size(2)):
-            file.write(f'{j}')
+            file.write(f"{j}")
             for i in range(generated.size(0)):
-                file.write(f' {generated[i, 0, j]}')
-            file.write('\n')
-
-elif args.data == 'mnist':
+                file.write(f" {generated[i, 0, j]}")
+            file.write("\n")
 
-    img_train = 1 - train_input[:generated.size(0)].float() / 255
+elif args.data == "mnist":
+    img_train = 1 - train_input[: generated.size(0)].float() / 255
     img_generated = 1 - generated.float() / 255
 
-    save_images(img_train, 'causalar-' + str_experiment + '-train.png', nrow = 12)
-    save_images(img_generated, 'causalar-' + str_experiment + '-generated.png', nrow = 12)
+    save_images(img_train, "causalar-" + str_experiment + "-train.png", nrow=12)
+    save_images(img_generated, "causalar-" + str_experiment + "-generated.png", nrow=12)
 
 ######################################################################