Update.
[pytorch.git] / lazy_linear.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 from torch import nn, Tensor
9
10 ######################################################################
11
12 class LazyLinear(nn.Module):
13
14     def __init__(self, out_dim, bias = True):
15         super().__init__()
16         self.out_dim = out_dim
17         self.bias = bias
18         self.core = None
19
20     def forward(self, x):
21         x = x.view(x.size(0), -1)
22
23         if self.core is None:
24             if self.training:
25                 self.core = nn.Linear(x.size(1), self.out_dim, self.bias)
26             else:
27                 raise RuntimeError('Undefined LazyLinear core in inference mode.')
28
29         return self.core(x)
30
31     def named_parameters(self, memo=None, prefix=''):
32         assert self.core is not None, 'Parameters not yet defined'
33         return super().named_parameters(memo, prefix)
34
35 ######################################################################
36
37 if __name__ == "__main__":
38     model = nn.Sequential(nn.Conv2d(3, 8, kernel_size = 5),
39                           nn.ReLU(inplace = True),
40                           LazyLinear(128),
41                           nn.ReLU(inplace = True),
42                           nn.Linear(128, 10))
43
44     # model.eval()
45
46     input = Tensor(100, 3, 32, 32).normal_()
47
48     output = model(input)
49
50     for n, x in model.named_parameters():
51         print(n, x.size())
52