Update.
[agtree2dot.git] / mlp.py
1 #!/usr/bin/env python
2
3 #########################################################################
4 # This program is free software: you can redistribute it and/or modify  #
5 # it under the terms of the version 3 of the GNU General Public License #
6 # as published by the Free Software Foundation.                         #
7 #                                                                       #
8 # This program is distributed in the hope that it will be useful, but   #
9 # WITHOUT ANY WARRANTY; without even the implied warranty of            #
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
11 # General Public License for more details.                              #
12 #                                                                       #
13 # You should have received a copy of the GNU General Public License     #
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.  #
15 #                                                                       #
16 # Written by and Copyright (C) Francois Fleuret                         #
17 # Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
18 #########################################################################
19
20 from torch import nn
21 from torch.nn import functional as fn
22 from torch import Tensor
23 from torch.autograd import Variable
24 from torch.nn import Module
25
26 import agtree2dot
27
28 class MLP(Module):
29     def __init__(self, input_dim, hidden_dim, output_dim):
30         super(MLP, self).__init__()
31         self.fc1 = nn.Linear(input_dim, hidden_dim)
32         self.fc2 = nn.Linear(hidden_dim, output_dim)
33
34     def forward(self, x):
35         x = self.fc1(x)
36         x = fn.tanh(x)
37         x = self.fc2(x)
38         return x
39
40 mlp = MLP(10, 20, 1)
41 input = Variable(Tensor(100, 10).normal_())
42 target = Variable(Tensor(100).normal_())
43 output = mlp(input)
44 criterion = nn.MSELoss()
45 loss = criterion(output, target)
46
47 agtree2dot.save_dot(loss,
48                     { input: 'input', target: 'target', loss: 'loss' },
49                     open('./mlp.dot', 'w'))
50
51 print('Generated mlp.dot. You can convert it to pdf with')
52 print('> dot mlp.dot -Lg -T pdf -o mlp.pdf')