X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=causal-autoregression.py;h=77542655fa5fb958bcdb6d13927bd1add205b21a;hp=c2f6161511b08c56253f224535c31f15b8946fcc;hb=HEAD;hpb=762a2c5e2485e0ebd7c26fe980893a4de2544bb9

diff --git a/causal-autoregression.py b/causal-autoregression.py
index c2f6161..0c931fb 100755
--- a/causal-autoregression.py
+++ b/causal-autoregression.py
@@ -18,46 +18,46 @@ from torch.nn import functional as F
 
 ######################################################################
 
-def save_images(x, filename, nrow = 12):
-    print(f'Writing {filename}')
-    torchvision.utils.save_image(x.narrow(0,0, min(48, x.size(0))),
-                                 filename,
-                                 nrow = nrow, pad_value=1.0)
+
+def save_images(x, filename, nrow=12):
+    print(f"Writing {filename}")
+    torchvision.utils.save_image(
+        x.narrow(0, 0, min(48, x.size(0))), filename, nrow=nrow, pad_value=1.0
+    )
+
 
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description = 'An implementation of a causal autoregression model',
-    formatter_class = argparse.ArgumentDefaultsHelpFormatter
+    description="An implementation of a causal autoregression model",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
-parser.add_argument('--data',
-                    type = str, default = 'toy1d',
-                    help = 'What data')
+parser.add_argument("--data", type=str, default="toy1d", help="What data")
 
-parser.add_argument('--seed',
-                    type = int, default = 0,
-                    help = 'Random seed (default 0, < 0 is no seeding)')
+parser.add_argument(
+    "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
+)
 
-parser.add_argument('--nb_epochs',
-                    type = int, default = -1,
-                    help = 'How many epochs')
+parser.add_argument("--nb_epochs", type=int, default=-1, help="How many epochs")
 
-parser.add_argument('--batch_size',
-                    type = int, default = 100,
-                    help = 'Batch size')
+parser.add_argument("--batch_size", type=int, default=100, help="Batch size")
 
-parser.add_argument('--learning_rate',
-                    type = float, default = 1e-3,
-                    help = 'Batch size')
+parser.add_argument("--learning_rate", type=float, default=1e-3, help="Batch size")
 
-parser.add_argument('--positional',
-                    action='store_true', default = False,
-                    help = 'Do we provide a positional encoding as input')
+parser.add_argument(
+    "--positional",
+    action="store_true",
+    default=False,
+    help="Do we provide a positional encoding as input",
+)
 
-parser.add_argument('--dilation',
-                    action='store_true', default = False,
-                    help = 'Do we provide a positional encoding as input')
+parser.add_argument(
+    "--dilation",
+    action="store_true",
+    default=False,
+    help="Do we provide a positional encoding as input",
+)
 
 ######################################################################
 
@@ -67,32 +67,33 @@ if args.seed >= 0:
     torch.manual_seed(args.seed)
 
 if args.nb_epochs < 0:
-    if args.data == 'toy1d':
+    if args.data == "toy1d":
         args.nb_epochs = 100
-    elif args.data == 'mnist':
+    elif args.data == "mnist":
         args.nb_epochs = 25
 
 ######################################################################
 
 if torch.cuda.is_available():
-    print('Cuda is available')
-    device = torch.device('cuda')
+    print("Cuda is available")
+    device = torch.device("cuda")
     torch.backends.cudnn.benchmark = True
 else:
-    device = torch.device('cpu')
+    device = torch.device("cpu")
 
 ######################################################################
 
+
 class NetToy1d(nn.Module):
-    def __init__(self, nb_classes, ks = 2, nc = 32):
-        super(NetToy1d, self).__init__()
+    def __init__(self, nb_classes, ks=2, nc=32):
+        super().__init__()
         self.pad = (ks - 1, 0)
-        self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
-        self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks)
-        self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks)
-        self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks)
-        self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks)
-        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)
+        self.conv0 = nn.Conv1d(1, nc, kernel_size=1)
+        self.conv1 = nn.Conv1d(nc, nc, kernel_size=ks)
+        self.conv2 = nn.Conv1d(nc, nc, kernel_size=ks)
+        self.conv3 = nn.Conv1d(nc, nc, kernel_size=ks)
+        self.conv4 = nn.Conv1d(nc, nc, kernel_size=ks)
+        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size=1)
 
     def forward(self, x):
         x = F.relu(self.conv0(F.pad(x, (1, -1))))
@@ -103,19 +104,20 @@ class NetToy1d(nn.Module):
         x = self.conv5(x)
         return x.permute(0, 2, 1).contiguous()
 
+
 class NetToy1dWithDilation(nn.Module):
-    def __init__(self, nb_classes, ks = 2, nc = 32):
-        super(NetToy1dWithDilation, self).__init__()
-        self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
-        self.pad1 = ((ks-1) * 2, 0)
-        self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 2)
-        self.pad2 = ((ks-1) * 4, 0)
-        self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 4)
-        self.pad3 = ((ks-1) * 8, 0)
-        self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 8)
-        self.pad4 = ((ks-1) * 16, 0)
-        self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 16)
-        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)
+    def __init__(self, nb_classes, ks=2, nc=32):
+        super().__init__()
+        self.conv0 = nn.Conv1d(1, nc, kernel_size=1)
+        self.pad1 = ((ks - 1) * 2, 0)
+        self.conv1 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=2)
+        self.pad2 = ((ks - 1) * 4, 0)
+        self.conv2 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=4)
+        self.pad3 = ((ks - 1) * 8, 0)
+        self.conv3 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=8)
+        self.pad4 = ((ks - 1) * 16, 0)
+        self.conv4 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=16)
+        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size=1)
 
     def forward(self, x):
         x = F.relu(self.conv0(F.pad(x, (1, -1))))
@@ -126,21 +128,23 @@ class NetToy1dWithDilation(nn.Module):
         x = self.conv5(x)
         return x.permute(0, 2, 1).contiguous()
 
+
 ######################################################################
 
+
 class PixelCNN(nn.Module):
-    def __init__(self, nb_classes, in_channels = 1, ks = 5):
-        super(PixelCNN, self).__init__()
+    def __init__(self, nb_classes, in_channels=1, ks=5):
+        super().__init__()
 
-        self.hpad = (ks//2, ks//2, ks//2, 0)
-        self.vpad = (ks//2, 0, 0, 0)
+        self.hpad = (ks // 2, ks // 2, ks // 2, 0)
+        self.vpad = (ks // 2, 0, 0, 0)
 
-        self.conv1h = nn.Conv2d(in_channels, 32, kernel_size = (ks//2+1, ks))
-        self.conv2h = nn.Conv2d(32, 64, kernel_size = (ks//2+1, ks))
-        self.conv1v = nn.Conv2d(in_channels, 32, kernel_size = (1, ks//2+1))
-        self.conv2v = nn.Conv2d(32, 64, kernel_size = (1, ks//2+1))
-        self.final1 = nn.Conv2d(128, 128, kernel_size = 1)
-        self.final2 = nn.Conv2d(128, nb_classes, kernel_size = 1)
+        self.conv1h = nn.Conv2d(in_channels, 32, kernel_size=(ks // 2 + 1, ks))
+        self.conv2h = nn.Conv2d(32, 64, kernel_size=(ks // 2 + 1, ks))
+        self.conv1v = nn.Conv2d(in_channels, 32, kernel_size=(1, ks // 2 + 1))
+        self.conv2v = nn.Conv2d(32, 64, kernel_size=(1, ks // 2 + 1))
+        self.final1 = nn.Conv2d(128, 128, kernel_size=1)
+        self.final2 = nn.Conv2d(128, nb_classes, kernel_size=1)
 
     def forward(self, x):
         xh = F.pad(x, (0, 0, 1, -1))
@@ -154,8 +158,10 @@ class PixelCNN(nn.Module):
 
         return x.permute(0, 2, 3, 1).contiguous()
 
+
 ######################################################################
 
+
 def positional_tensor(height, width):
     index_h = torch.arange(height).view(1, -1)
     m_h = (2 ** torch.arange(math.ceil(math.log2(height)))).view(-1, 1)
@@ -169,26 +175,30 @@ def positional_tensor(height, width):
 
     return torch.cat((i_w, i_h), 1)
 
+
 ######################################################################
 
 str_experiment = args.data
 
 if args.positional:
-    str_experiment += '-positional'
+    str_experiment += "-positional"
 
 if args.dilation:
-    str_experiment += '-dilation'
+    str_experiment += "-dilation"
+
+log_file = open("causalar-" + str_experiment + "-train.log", "w")
 
-log_file = open('causalar-' + str_experiment + '-train.log', 'w')
 
 def log_string(s):
-    s = time.strftime("%Y%m%d-%H:%M:%S", time.localtime()) + ' ' + s
+    s = time.strftime("%Y%m%d-%H:%M:%S", time.localtime()) + " " + s
     print(s)
-    log_file.write(s + '\n')
+    log_file.write(s + "\n")
     log_file.flush()
 
+
 ######################################################################
 
+
 def generate_sequences(nb, len):
     nb_parts = 2
 
@@ -196,32 +206,33 @@ def generate_sequences(nb, len):
 
     x = torch.empty(nb, nb_parts).uniform_(-1, 1)
     x = x.view(nb, nb_parts, 1).expand(nb, nb_parts, len)
-    x = x * torch.linspace(0, len-1, len).view(1, -1) + len
+    x = x * torch.linspace(0, len - 1, len).view(1, -1) + len
 
     for n in range(nb):
-        a = torch.randperm(len - 2)[:nb_parts+1].sort()[0]
+        a = torch.randperm(len - 2)[: nb_parts + 1].sort()[0]
         a[0] = 0
         a[a.size(0) - 1] = len
         for k in range(a.size(0) - 1):
-            r[n, a[k]:a[k+1]] = x[n, k, :a[k+1]-a[k]]
+            r[n, a[k] : a[k + 1]] = x[n, k, : a[k + 1] - a[k]]
 
     return r.round().long()
 
+
 ######################################################################
 
-if args.data == 'toy1d':
+if args.data == "toy1d":
     len = 32
     train_input = generate_sequences(50000, len).to(device).unsqueeze(1)
    if args.dilation:
-        model = NetToy1dWithDilation(nb_classes = 2 * len).to(device)
+        model = NetToy1dWithDilation(nb_classes=2 * len).to(device)
     else:
-        model = NetToy1d(nb_classes = 2 * len).to(device)
+        model = NetToy1d(nb_classes=2 * len).to(device)
 
-elif args.data == 'mnist':
-    train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
+elif args.data == "mnist":
+    train_set = torchvision.datasets.MNIST("./data/mnist/", train=True, download=True)
     train_input = train_set.data.view(-1, 1, 28, 28).long().to(device)
 
-    model = PixelCNN(nb_classes = 256, in_channels = 1).to(device)
+    model = PixelCNN(nb_classes=256, in_channels=1).to(device)
 
     in_channels = train_input.size(1)
 
     if args.positional:
@@ -229,40 +240,35 @@ elif args.data == 'mnist':
         positional_input = positional_tensor(height, width).float().to(device)
         in_channels += positional_input.size(1)
 
-    model = PixelCNN(nb_classes = 256, in_channels = in_channels).to(device)
+    model = PixelCNN(nb_classes=256, in_channels=in_channels).to(device)
 
 else:
-    raise ValueError('Unknown data ' + args.data)
+    raise ValueError("Unknown data " + args.data)
 
 ######################################################################
 
 mean, std = train_input.float().mean(), train_input.float().std()
 
 nb_parameters = sum(t.numel() for t in model.parameters())
-log_string(f'nb_parameters {nb_parameters}')
+log_string(f"nb_parameters {nb_parameters}")
 
 cross_entropy = nn.CrossEntropyLoss().to(device)
-optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
 for e in range(args.nb_epochs):
-
     nb_batches, acc_loss = 0, 0.0
 
     for sequences in train_input.split(args.batch_size):
-        input = (sequences - mean)/std
+        input = (sequences - mean) / std
 
         if args.positional:
             input = torch.cat(
-                (input, positional_input.expand(input.size(0), -1, -1, -1)),
-                1
+                (input, positional_input.expand(input.size(0), -1, -1, -1)), 1
             )
 
         output = model(input)
-        loss = cross_entropy(
-            output.view(-1, output.size(-1)),
-            sequences.view(-1)
-        )
+        loss = cross_entropy(output.view(-1, output.size(-1)), sequences.view(-1))
 
         optimizer.zero_grad()
         loss.backward()
@@ -271,7 +277,7 @@ for e in range(args.nb_epochs):
         nb_batches += 1
         acc_loss += loss.item()
 
-    log_string(f'{e} {acc_loss / nb_batches} {math.exp(acc_loss / nb_batches)}')
+    log_string(f"{e} {acc_loss / nb_batches} {math.exp(acc_loss / nb_batches)}")
 
     sys.stdout.flush()
 
@@ -284,36 +290,36 @@ flat = generated.view(generated.size(0), -1)
 for t in range(flat.size(1)):
     input = (generated.float() - mean) / std
     if args.positional:
-        input = torch.cat((input, positional_input.expand(input.size(0), -1, -1, -1)), 1)
+        input = torch.cat(
+            (input, positional_input.expand(input.size(0), -1, -1, -1)), 1
+        )
     output = model(input)
     logits = output.view(flat.size() + (-1,))[:, t]
-    dist = torch.distributions.categorical.Categorical(logits = logits)
+    dist = torch.distributions.categorical.Categorical(logits=logits)
     flat[:, t] = dist.sample()
 
 ######################################################################
 
-if args.data == 'toy1d':
-
-    with open('causalar-' + str_experiment + '-train.dat', 'w') as file:
+if args.data == "toy1d":
+    with open("causalar-" + str_experiment + "-train.dat", "w") as file:
         for j in range(train_input.size(2)):
-            file.write(f'{j}')
+            file.write(f"{j}")
             for i in range(min(train_input.size(0), 25)):
-                file.write(f' {train_input[i, 0, j]}')
-            file.write('\n')
+                file.write(f" {train_input[i, 0, j]}")
+            file.write("\n")
 
-    with open('causalar-' + str_experiment + '-generated.dat', 'w') as file:
+    with open("causalar-" + str_experiment + "-generated.dat", "w") as file:
         for j in range(generated.size(2)):
-            file.write(f'{j}')
+            file.write(f"{j}")
             for i in range(generated.size(0)):
-                file.write(f' {generated[i, 0, j]}')
-            file.write('\n')
-
-elif args.data == 'mnist':
+                file.write(f" {generated[i, 0, j]}")
+                file.write("\n")
 
-    img_train = 1 - train_input[:generated.size(0)].float() / 255
+elif args.data == "mnist":
+    img_train = 1 - train_input[: generated.size(0)].float() / 255
     img_generated = 1 - generated.float() / 255
 
-    save_images(img_train, 'causalar-' + str_experiment + '-train.png', nrow = 12)
-    save_images(img_generated, 'causalar-' + str_experiment + '-generated.png', nrow = 12)
+    save_images(img_train, "causalar-" + str_experiment + "-train.png", nrow=12)
+    save_images(img_generated, "causalar-" + str_experiment + "-generated.png", nrow=12)
 
 ######################################################################
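The causality in NetToy1d above comes from left-only padding: the hidden convolutions have kernel width ks and are fed F.pad(x, self.pad) with self.pad = (ks - 1, 0), so the activation at time step t never depends on inputs after t (conv0 additionally shifts the input with F.pad(x, (1, -1)), so the prediction at t sees only strictly earlier values). A minimal sketch of that padding trick, assuming made-up shapes and names rather than anything from this patch:

import torch
from torch import nn
from torch.nn import functional as F

# Causal 1d convolution via left-only padding: the output at time t
# depends only on inputs at times <= t, and the sequence length is preserved.
ks, nc, length = 2, 32, 16
conv = nn.Conv1d(nc, nc, kernel_size=ks)

x = torch.randn(4, nc, length)        # (batch, channels, time)
y = conv(F.pad(x, (ks - 1, 0)))       # pad = (left, right): no look-ahead
assert y.size() == (4, nc, length)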