Starts to look decent.
[mtp.git] / mtp_graph.cc
1
2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify  //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation.                         //
6 //                                                                       //
7 // This program is distributed in the hope that it will be useful, but   //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
10 // General Public License for more details.                              //
11 //                                                                       //
12 // You should have received a copy of the GNU General Public License     //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
14 //                                                                       //
15 // Written by and Copyright (C) Francois Fleuret                         //
16 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
17 ///////////////////////////////////////////////////////////////////////////
18
19 #include "mtp_graph.h"
20
21 #include <iostream>
22 #include <float.h>
23
24 using namespace std;
25
26 class Edge {
27 public:
28   int id, occupied;
29   scalar_t length, work_length;
30   Vertex *origin_vertex, *terminal_vertex;
31   Edge *next, *pred;
32
33   inline void revert();
34 };
35
36 class Vertex {
37 public:
38   int id, iteration;
39   Edge *root_edge;
40   scalar_t distance_from_source;
41   Edge *pred_edge;
42
43   Vertex();
44   inline void add_edge(Edge *e);
45   inline void del_edge(Edge *e);
46 };
47
48 //////////////////////////////////////////////////////////////////////
49
50 void Edge::revert() {
51   length = - length;
52   work_length = 0;
53   origin_vertex->del_edge(this);
54   terminal_vertex->add_edge(this);
55   Vertex *t = terminal_vertex;
56   terminal_vertex = origin_vertex;
57   origin_vertex = t;
58 }
59
60 //////////////////////////////////////////////////////////////////////
61
62 Vertex::Vertex() {
63   root_edge = 0;
64 }
65
66 void Vertex::add_edge(Edge *e) {
67   e->next = root_edge;
68   e->pred = 0;
69   if(root_edge) { root_edge->pred = e; }
70   root_edge = e;
71 }
72
73 void Vertex::del_edge(Edge *e) {
74   if(e == root_edge) { root_edge = e->next; }
75   if(e->pred) { e->pred->next = e->next; }
76   if(e->next) { e->next->pred = e->pred; }
77 }
78
79 //////////////////////////////////////////////////////////////////////
80
81 void MTPGraph::print() {
82   for(int k = 0; k < _nb_edges; k++) {
83     Edge *e = edges + k;
84     cout << e->origin_vertex->id
85          << " -> "
86          << e->terminal_vertex->id
87          << " "
88          << e->length;
89     if(e->occupied) {
90       cout << " *";
91     }
92     cout << endl;
93   }
94 }
95
96 void MTPGraph::print_dot() {
97   cout << "digraph {" << endl;
98   cout << "  node[shape=circle];" << endl;
99   for(int k = 0; k < _nb_edges; k++) {
100     Edge *e = edges + k;
101     if(e->occupied) {
102       cout << "  " << e->origin_vertex->id << " -> " << e->terminal_vertex->id
103            << " [style=bold,color=black,label=\"" << -e->length << "\"];" << endl;
104     } else {
105       cout << "  " << e->origin_vertex->id << " -> " << e->terminal_vertex->id
106            << " [color=gray,label=\"" << e->length << "\"];" << endl;
107     }
108   }
109   cout << "}" << endl;
110 }
111
112 MTPGraph::MTPGraph(int nb_vertices, int nb_edges,
113                    int *from, int *to,
114                    int src, int snk) {
115   _nb_vertices = nb_vertices;
116   _nb_edges = nb_edges;
117
118   edges = new Edge[_nb_edges];
119   vertices = new Vertex[_nb_vertices];
120   _front = new Vertex *[_nb_vertices];
121   _new_front = new Vertex *[_nb_vertices];
122
123   _source = &vertices[src];
124   _sink = &vertices[snk];
125
126   for(int v = 0; v < _nb_vertices; v++) {
127     vertices[v].id = v;
128   }
129
130   for(int e = 0; e < nb_edges; e++) {
131     vertices[from[e]].add_edge(&edges[e]);
132     edges[e].occupied = 0;
133     edges[e].id = e;
134     edges[e].origin_vertex = &vertices[from[e]];
135     edges[e].terminal_vertex = &vertices[to[e]];
136   }
137
138 }
139
140 MTPGraph::~MTPGraph() {
141   delete[] vertices;
142   delete[] edges;
143   delete[] _front;
144   delete[] _new_front;
145 }
146
147 void MTPGraph::initialize_work_lengths() {
148   scalar_t length_min = 0;
149   for(int n = 0; n < _nb_vertices; n++) {
150     for(Edge *e = vertices[n].root_edge; e; e = e->next) {
151       length_min = min(e->length, length_min);
152     }
153   }
154   for(int n = 0; n < _nb_vertices; n++) {
155     for(Edge *e = vertices[n].root_edge; e; e = e->next) {
156       e->work_length = e->length - length_min;
157     }
158   }
159 }
160
161 void MTPGraph::update_work_lengths() {
162   for(int k = 0; k < _nb_edges; k++) {
163     Edge *e = edges + k;
164     e->work_length += e->terminal_vertex->distance_from_source - e->terminal_vertex->distance_from_source;
165   }
166 }
167
168 void MTPGraph::force_positive_work_lengths() {
169 #ifdef VERBOSE
170   scalar_t residual_error = 0.0;
171 #endif
172   for(int n = 0; n < _nb_vertices; n++) {
173     for(Edge *e = vertices[n].root_edge; e; e = e->next) {
174       if(e->work_length < 0) {
175 #ifdef VERBOSE
176         residual_error -= e->work_length;
177 #endif
178         e->work_length = 0.0;
179       }
180     }
181   }
182 #ifdef VERBOSE
183   cerr << "residual_error " << residual_error << endl;
184 #endif
185 }
186
187 void MTPGraph::find_shortest_path(Vertex **_front, Vertex **_new_front) {
188   Vertex **tmp_front;
189   int tmp_front_size;
190   Vertex *v, *tv;
191   scalar_t d;
192
193   for(int v = 0; v < _nb_vertices; v++) {
194     vertices[v].distance_from_source = FLT_MAX;
195     vertices[v].pred_edge = 0;
196     vertices[v].iteration = 0;
197   }
198
199   int iteration = 0;
200
201   int _front_size = 0, _new_front_size;
202   _front[_front_size++] = _source;
203   _source->distance_from_source = 0;
204
205   do {
206     _new_front_size = 0;
207     iteration++;
208     for(int f = 0; f < _front_size; f++) {
209       v = _front[f];
210       for(Edge *e = v->root_edge; e; e = e->next) {
211         d = v->distance_from_source + e->work_length;
212         tv = e->terminal_vertex;
213         if(d < tv->distance_from_source) {
214           tv->distance_from_source = d;
215           tv->pred_edge = e;
216           if(tv->iteration < iteration) {
217             _new_front[_new_front_size++] = tv;
218             tv->iteration = iteration;
219           }
220         }
221       }
222     }
223
224     tmp_front = _new_front;
225     _new_front = _front;
226     _front = tmp_front;
227
228     tmp_front_size = _new_front_size;
229     _new_front_size = _front_size;
230     _front_size = tmp_front_size;
231   } while(_front_size > 0);
232 }
233
234 void MTPGraph::find_best_paths(scalar_t *lengths, int *result_edge_occupation) {
235   scalar_t total_length;
236   Vertex *v;
237   Edge *e;
238
239   for(int e = 0; e < _nb_edges; e++) {
240     edges[e].length = lengths[e];
241     edges[e].work_length = edges[e].length;
242   }
243
244   find_shortest_path(_front, _new_front);
245   update_work_lengths();
246
247   // #warning
248   // initialize_work_lengths();
249
250   do {
251     force_positive_work_lengths();
252     find_shortest_path(_front, _new_front);
253     update_work_lengths();
254
255     total_length = 0.0;
256
257     // Do we reach the _sink?
258     if(_sink->pred_edge) {
259       // If yes, compute the length of the best path
260       v = _sink;
261       while(v->pred_edge) {
262         total_length += v->pred_edge->length;
263         v = v->pred_edge->origin_vertex;
264       }
265       // If that length is negative
266       if(total_length < 0.0) {
267 #ifdef VERBOSE
268         cout << "Found a path of length " << total_length << endl;
269 #endif
270         // Invert all the edges along the best path
271         v = _sink;
272         while(v->pred_edge) {
273           e = v->pred_edge;
274           v = e->origin_vertex;
275           e->revert();
276           e->occupied = 1 - e->occupied;
277         }
278       }
279     }
280
281   } while(total_length < 0.0);
282
283   for(int k = 0; k < _nb_edges; k++) {
284     Edge *e = edges + k;
285     if(e->occupied) { e->revert(); }
286   }
287
288   for(int n = 0; n < _nb_vertices; n++) {
289     Vertex *v = &vertices[n];
290     for(Edge *e = v->root_edge; e; e = e->next) {
291       result_edge_occupation[e->id] = e->occupied;
292     }
293   }
294 }