X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=agtree2dot.py;fp=agtree2dot.py;h=2d89af5cc8130d90e9e8087e549b45f66b69aefd;hb=35a4a9a26e7b35e507755a5d4fe3ea7f4f1ca6e0;hp=f215f94e26a141c377453eddf5848519d39bbb23;hpb=387cadee9eb11bc505f8f3a73524500b8171b859;p=agtree2dot.git
diff --git a/agtree2dot.py b/agtree2dot.py
index f215f94..2d89af5 100755
--- a/agtree2dot.py
+++ b/agtree2dot.py
@@ -1,108 +1,129 @@
-
-#########################################################################
-# This program is free software: you can redistribute it and/or modify #
-# it under the terms of the version 3 of the GNU General Public License #
-# as published by the Free Software Foundation. #
-# #
-# This program is distributed in the hope that it will be useful, but #
-# WITHOUT ANY WARRANTY; without even the implied warranty of #
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
-# General Public License for more details. #
-# #
-# You should have received a copy of the GNU General Public License #
-# along with this program. If not, see . #
-# #
-# Written by and Copyright (C) Francois Fleuret #
-# Contact for comments & bug reports #
-#########################################################################
+#!/usr/bin/env python-for-pytorch
import torch
-import re
-import sys
+import math, sys, re
-import torch.autograd
+from torch import nn
+from torch.nn import functional as fn
-######################################################################
+from torch import Tensor
+from torch.autograd import Variable
+from torch.nn.parameter import Parameter
+from torch.nn import Module
-def save_dot_rec(x, node_labels = {}, out = sys.stdout, drawn_node_id = {}):
+######################################################################
- if isinstance(x, set):
+class Link:
+ def __init__(self, from_node, from_nb, to_node, to_nb):
+ self.from_node = from_node
+ self.from_nb = from_nb
+ self.to_node = to_node
+ self.to_nb = to_nb
+
+class Node:
+ def __init__(self, id, label):
+ self.id = id
+ self.label = label
+ self.max_in = -1
+ self.max_out = -1
+
+def slot(node_list, n, k, for_input):
+ if for_input:
+ if node_list[n].max_out > 0:
+ return str(node_list[n].id) + ':input' + str(k)
+ else:
+ return str(node_list[n].id)
+ else:
+ if node_list[n].max_in > 0:
+ return str(node_list[n].id) + ':output' + str(k)
+ else:
+ return str(node_list[n].id)
- for y in x:
- save_dot_rec(y, node_labels, out, drawn_node_id)
+def slot_string(k, for_input):
+ result = ''
+ if for_input:
+ label = 'input'
else:
+ label = 'output'
+
+ if k > 0:
+ if not for_input: result = ' |' + result
+ result += ' { <' + label + '0> 0'
+ for j in range(1, k+1):
+ result += " | " + '<' + label + str(j) + '> ' + str(j)
+ result += " } "
+ if for_input: result = result + '| '
- if not x in drawn_node_id:
- drawn_node_id[x] = len(drawn_node_id) + 1
-
- # Draw the node (Variable or Function) if not already
- # drawn
-
- if isinstance(x, torch.autograd.Variable):
- name = ((x in node_labels and node_labels[x]) or 'Variable')
- # Add the tensor size
- name = name + ' ['
- for d in range(0, x.data.dim()):
- if d > 0: name = name + ', '
- name = name + str(x.data.size(d))
- name = name + ']'
-
- out.write(' ' + str(drawn_node_id[x]) +
- ' [shape=record,penwidth=1,style=rounded,label="' + name + '"]\n')
-
- if hasattr(x, 'creator') and x.creator:
- y = x.creator
- save_dot_rec(y, node_labels, out, drawn_node_id)
- # Edge to the creator
- out.write(' ' + str(drawn_node_id[y]) + ' -> ' + str(drawn_node_id[x]) + '\n')
-
- elif isinstance(x, torch.autograd.Function):
- name = ((x in node_labels and (node_labels[x] + ': ')) or '') + \
- re.search('<.*\.([a-zA-Z0-9_]*)\'>', str(type(x))).group(1)
-
- prefix = ''
- suffix = ''
-
- if hasattr(x, 'num_inputs') and x.num_inputs > 1:
- prefix = '{ '
- for i in range(0, x.num_inputs):
- if i > 0: prefix = prefix + ' | '
- prefix = prefix + ' ' + str(i)
- prefix = prefix + ' } | '
-
- if hasattr(x, 'num_outputs') and x.num_outputs > 1:
- suffix = ' | { '
- for i in range(0, x.num_outputs):
- if i > 0: suffix = suffix + ' | '
- suffix = suffix + '