# Introduction #

This package provides a function that generates a
[dot file](https://en.wikipedia.org/wiki/DOT_(graph_description_language))
from a [PyTorch](http://pytorch.org) autograd graph.

# Usage #

## Functions ##

### agtree2dot.save_dot(variable, variable_labels, result_file) ###

Writes to the open file `result_file` a dot description of the autograd
graph of the `Variable` `variable` (in PyTorch 0.4 and later, a plain
tensor). The dictionary `variable_labels` maps some of the variables in
the graph to strings, which are used as their labels in the resulting
graph.

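To give an idea of what such a function has to do, here is a minimal sketch of a breadth-first walk that writes a DOT graph. It is not agtree2dot's implementation: `write_dot` is a hypothetical name, and the sketch only assumes that nodes behave like PyTorch autograd nodes, i.e. that `loss.grad_fn` and its successors expose `next_functions` as `(node_or_None, index)` pairs, and that leaf `AccumulateGrad` nodes expose their tensor as `variable`. Labels are looked up by object identity and are not escaped. In this sketch, tensors that do not require gradients (such as `input` and `target` in the example) never appear, since autograd records no node for them. The two stand-in classes at the end only let it run without PyTorch.

```python
import io
from collections import deque


def write_dot(root, labels, out):
    """Write a DOT digraph of the autograd graph reachable from `root`.

    Assumes PyTorch-like nodes: `next_functions` is a tuple of
    `(node_or_None, index)` pairs, and leaf nodes may carry the tensor
    they accumulate into as `variable`.
    """
    by_id = {id(v): name for v, name in labels.items()}
    # id(node) -> (DOT name, node); holding the node keeps its id unique.
    names = {}

    def name_of(node):
        if id(node) not in names:
            names[id(node)] = (f"n{len(names)}", node)
        return names[id(node)][0]

    out.write("digraph autograd {\n")
    seen = set()
    queue = deque([root])
    while queue:  # breadth-first, from the output towards the leaves
        node = queue.popleft()
        if id(node) in seen:
            continue
        seen.add(id(node))
        label = type(node).__name__  # e.g. "MulBackward0"
        var = getattr(node, "variable", None)
        if var is not None and id(var) in by_id:
            label = by_id[id(var)]  # user-supplied name for a leaf tensor
        out.write(f'  {name_of(node)} [label="{label}"];\n')
        for child, _ in getattr(node, "next_functions", ()):
            if child is None:  # input that does not require grad
                continue
            # Edge in forward direction: from the input side to its consumer.
            out.write(f"  {name_of(child)} -> {name_of(node)};\n")
            queue.append(child)
    out.write("}\n")


# Hypothetical stand-ins mimicking two autograd node types.
class AccumulateGrad:
    def __init__(self, variable):
        self.variable = variable
        self.next_functions = ()


class MulBackward0:
    def __init__(self, next_functions):
        self.next_functions = next_functions


weight = object()  # plays the role of a parameter tensor
demo_root = MulBackward0(((AccumulateGrad(weight), 0), (None, 0)))
buf = io.StringIO()
write_dot(demo_root, {weight: "weight"}, buf)
print(buf.getvalue())
```

With PyTorch available, a call such as `write_dot(loss.grad_fn, {mlp.fc1.weight: 'weight1'}, f)` would walk the graph built in the `mlp` example of this README. Keying the labels by `id` avoids comparing tensors with `==`, which is elementwise in PyTorch.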
## Example ##

A typical use is shown in [mlp.py](https://fleuret.org/git-extract/agtree2dot/mlp.py):

```python
import subprocess
import sys

import torch
from torch import nn
from torch.nn import Module

import agtree2dot

class MLP(Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.tanh(x)  # torch.nn.functional.tanh is deprecated
        x = self.fc2(x)
        return x

mlp = MLP(10, 20, 1)
input = torch.randn(100, 10)
target = torch.randn(100, 1)
output = mlp(input)
criterion = nn.MSELoss()
loss = criterion(output, target)

# Write the graph of `loss`, naming some of its tensors. The with block
# flushes and closes mlp.dot before Graphviz reads it.
with open('./mlp.dot', 'w') as dot_file:
    agtree2dot.save_dot(loss,
                        {
                            input: 'input',
                            target: 'target',
                            loss: 'loss',
                            mlp.fc1.weight: 'weight1',
                            mlp.fc1.bias: 'bias1',
                            mlp.fc2.weight: 'weight2',
                            mlp.fc2.bias: 'bias2',
                        },
                        dot_file)

print('Generated mlp.dot')

try:
    subprocess.check_call(['dot', 'mlp.dot', '-Lg', '-T', 'pdf', '-o', 'mlp.pdf'])
except (OSError, subprocess.CalledProcessError):  # OSError: dot not found
    print('Calling the dot command failed. Is Graphviz installed?')
    sys.exit(1)

print('Generated mlp.pdf')
```

Running it generates the file `mlp.dot` and then tries to convert it into
[mlp.pdf](https://fleuret.org/git-extract/agtree2dot/mlp.pdf)
with the `dot` command from [Graphviz](http://www.graphviz.org/).
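The conversion step of the example can also be packaged as a small reusable helper. This is only a sketch, not part of agtree2dot: `render_dot` and its parameters are hypothetical names. It checks for the Graphviz executable with `shutil.which` before running it, so a missing installation is reported as `False` rather than as an exception, and it passes the output format to `dot`'s `-T` option.

```python
import shutil
import subprocess


def render_dot(dot_path, out_path, fmt="pdf", dot_cmd="dot"):
    """Render `dot_path` to `out_path` with Graphviz; return True on success.

    `fmt` is passed to `-T` (e.g. "pdf", "png", "svg").
    """
    exe = shutil.which(dot_cmd)
    if exe is None:  # Graphviz not installed, or not on PATH
        return False
    result = subprocess.run([exe, "-T", fmt, "-o", out_path, dot_path])
    return result.returncode == 0
```

For instance, `render_dot('mlp.dot', 'mlp.svg', fmt='svg')` would write an SVG version of the graph, provided Graphviz is installed.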