Update.
[agtree2dot.git] / mlp.py
1 #!/usr/bin/env python
2
3 #########################################################################
4 # This program is free software: you can redistribute it and/or modify  #
5 # it under the terms of the version 3 of the GNU General Public License #
6 # as published by the Free Software Foundation.                         #
7 #                                                                       #
8 # This program is distributed in the hope that it will be useful, but   #
9 # WITHOUT ANY WARRANTY; without even the implied warranty of            #
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
11 # General Public License for more details.                              #
12 #                                                                       #
13 # You should have received a copy of the GNU General Public License     #
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.  #
15 #                                                                       #
16 # Written by and Copyright (C) Francois Fleuret                         #
17 # Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
18 #########################################################################
19
import subprocess
import sys

import torch
from torch import nn
from torch.nn import Module

import agtree2dot
27
class MLP(Module):
    """Two-layer perceptron computing ``fc2(tanh(fc1(x)))``.

    Args:
        input_dim: number of input features (``in_features`` of ``fc1``).
        hidden_dim: width of the hidden layer.
        output_dim: number of output features (``out_features`` of ``fc2``).

    The two ``nn.Linear`` layers are public attributes ``fc1`` and ``fc2``;
    the script below refers to their weights and biases by those names.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()
        # fc1 is created before fc2, so its parameters are drawn from the
        # RNG first; keep this order for reproducible initialisation.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply Linear -> tanh -> Linear; ``x``'s last dim must be ``input_dim``."""
        hidden = torch.tanh(self.fc1(x))
        return self.fc2(hidden)
39
# Build a small MLP and an MSE loss on random data; `loss` is the tensor
# handed to agtree2dot below.
mlp = MLP(10, 20, 1)
criterion = nn.MSELoss()

# Random batch: 100 samples with 10 features, and 100 scalar targets.
# Named `inputs` (not `input`) to avoid shadowing the builtin.
inputs = torch.randn(100, 10)
target = torch.randn(100, 1)

output = mlp(inputs)

loss = criterion(output, target)

# Write the graph behind `loss` in Graphviz format. The dict maps tensors to
# display names (presumably used as node labels by agtree2dot -- confirm
# against its docs). The `with` block guarantees mlp.dot is flushed and
# closed before `dot` reads it, rather than relying on the file object
# being garbage-collected.
with open('./mlp.dot', 'w') as dot_file:
    agtree2dot.save_dot(loss,
                        {
                            inputs: 'input',
                            target: 'target',
                            loss: 'loss',
                            mlp.fc1.weight: 'weight1',
                            mlp.fc1.bias: 'bias1',
                            mlp.fc2.weight: 'weight2',
                            mlp.fc2.bias: 'bias2',
                        },
                        dot_file)

print('Generated mlp.dot')
63
# Render mlp.dot to mlp.pdf with Graphviz's `dot`, using the same font for
# node (-N) and edge (-E) labels.
fontname = 'Computer Modern'
fontsize = 12

try:
    subprocess.check_call(['dot', 'mlp.dot',
                           '-Lg',
                           '-T', 'pdf',
                           '-Efontname=' + fontname, '-Efontsize=' + str(fontsize),
                           '-Nfontname=' + fontname, '-Nfontsize=' + str(fontsize),
                           '-o', 'mlp.pdf'])
except (subprocess.CalledProcessError, OSError):
    # CalledProcessError: `dot` ran but exited with a non-zero status.
    # OSError (e.g. FileNotFoundError): the `dot` executable could not be
    # launched at all -- the "Graphviz not installed" case.
    print('Calling the dot command failed. Is Graphviz installed?',
          file=sys.stderr)
    sys.exit(1)

print('Generated mlp.pdf')