Update, seems to work!
[mtp.git] / miniksp.cc
1
2 ///////////////////////////////////////////////////////////////////////////
3 // START_IP_HEADER                                                       //
4 //                                                                       //
5 // This program is free software: you can redistribute it and/or modify  //
6 // it under the terms of the version 3 of the GNU General Public License //
7 // as published by the Free Software Foundation.                         //
8 //                                                                       //
9 // This program is distributed in the hope that it will be useful, but   //
10 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
12 // General Public License for more details.                              //
13 //                                                                       //
14 // You should have received a copy of the GNU General Public License     //
15 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
16 //                                                                       //
17 // Written by and Copyright (C) Francois Fleuret                         //
18 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
19 //                                                                       //
20 // END_IP_HEADER                                                         //
21 ///////////////////////////////////////////////////////////////////////////
22
23 // #define VERBOSE
24
25 #include <iostream>
26 #include <fstream>
27 #include <cmath>
28 #include <stdio.h>
29 #include <stdlib.h>
30 #include <float.h>
31
32 using namespace std;
33
34 typedef float scalar_t;
35
36 #ifdef DEBUG
37 #define ASSERT(x) if(!(x)) { \
38   std::cerr << "ASSERT FAILED IN " << __FILE__ << ":" << __LINE__ << endl; \
39   abort(); \
40 }
41 #else
42 #define ASSERT(x)
43 #endif
44
45 // In all the code:
46 //
47 //  * el[e] is the length of edge e
48 //  * ea[e] is its starting node
49 //  * eb[e] is its destination node.
50
51 // Adds to all the edge length the length of the shortest (which can
52 // be negative)
53 void raise_es(int nb_edges, scalar_t *el) {
54   scalar_t min_es = el[0];
55   for(int e = 1; e < nb_edges; e++) {
56     min_es = min(min_es, el[e]);
57   }
58   for(int e = 0; e < nb_edges; e++) {
59     el[e] -= min_es;
60   }
61 }
62
63 // Adds to every edge length the differential of the psi function on
64 // it
65 void add_dpsi_es(int nb_edges, scalar_t *el, int *ea, int *eb, scalar_t *psi) {
66   for(int e = 0; e < nb_edges; e++) {
67     el[e] += psi[ea[e]] - psi[eb[e]];
68   }
69 }
70
71 // Finds the shortest path in the graph and returns in
72 // result_edge_back, for each vertex, the edge to follow back from it
73 // to reach the source with the shortest path, and in result_dist the
74 // distance to the source. The edge lengths have to be positive.
75 void find_shortest(int nb_vertices,
76                    int nb_edges, scalar_t *el, int *ea, int *eb,
77                    int source, int sink,
78                    int *result_edge_back, scalar_t *result_dist) {
79   for(int v = 0; v < nb_vertices; v++) {
80     result_dist[v] = FLT_MAX;
81     result_edge_back[v] = -1;
82   }
83
84   result_dist[source] = 0;
85
86 #ifdef DEBUG
87   for(int e = 0; e < nb_edges; e++) {
88     if(el[e] < 0) abort();
89   }
90 #endif
91
92   int nb_changes;
93   scalar_t d;
94   do {
95     nb_changes = 0;
96     for(int e = 0; e < nb_edges; e++) {
97       d = result_dist[ea[e]] + el[e];
98       if(d < result_dist[eb[e]]) {
99         nb_changes++;
100         result_dist[eb[e]] = d;
101         result_edge_back[eb[e]] = e;
102       }
103     }
104   } while(nb_changes > 0);
105 }
106
107 // Iterates find_shortest as long as it finds paths of negative
108 // lengths. Returns which edges are occupied by the superposition of
109 // paths in result_edge_occupation.
110 //
111 // **WARNING** this routine changes the values of el, ea, and eb
112 // (i.e. the occupied edges are inverted).
113 void find_best_paths(int nb_vertices,
114                      int nb_edges, scalar_t *el, int *ea, int *eb,
115                      int source, int sink,
116                      int *result_edge_occupation) {
117   scalar_t *dist = new scalar_t[nb_vertices];
118   int *edge_back = new int[nb_vertices];
119   scalar_t *positive_el = new scalar_t[nb_edges];
120   scalar_t s;
121   int v;
122
123   for(int e = 0; e < nb_edges; e++) {
124     positive_el[e] = el[e];
125     result_edge_occupation[e] = 0;
126   }
127
128   raise_es(nb_edges, positive_el);
129
130   do {
131     find_shortest(nb_vertices, nb_edges, positive_el, ea, eb, source, sink, edge_back, dist);
132     add_dpsi_es(nb_edges, positive_el, ea, eb, dist);
133     s = 0.0;
134
135     // If the new path reaches the sink, we will backtrack on it to
136     // compute its score and invert edges
137
138     if(edge_back[sink] >= 0) {
139
140       v = sink;
141       while(v != source) {
142         int e = edge_back[v];
143         ASSERT(eb[e] == v);
144         v = ea[e];
145         s += el[e];
146       }
147
148       // We found a good path (score < 0), we revert the edges along
149       // the path and invert their occupation (since adding a path on
150       // a path already occupied means removing flow on it)
151
152       if(s < 0) {
153         v = sink;
154 #ifdef VERBOSE
155         cout << "FOUND A PATH OF LENGTH " << s << endl;
156 #endif
157         while(v != source) {
158           int e = edge_back[v];
159           ASSERT(eb[e] == v);
160           v = ea[e];
161 #ifdef VERBOSE
162           cout << "INVERTING " << ea[e] << " -> " << eb[e] << " [" << el[e] << "]" << endl;
163 #endif
164           int k = eb[e];
165           eb[e] = ea[e];
166           ea[e] = k;
167           positive_el[e] = - positive_el[e];
168           el[e] = - el[e];
169           result_edge_occupation[e] = 1 - result_edge_occupation[e];
170         }
171       }
172     }
173   } while(s < 0);
174
175   delete[] positive_el;
176   delete[] dist;
177   delete[] edge_back;
178 }
179
180 int main(int argc, char **argv) {
181
182   if(argc < 2) {
183     cerr << argv[0] << " <graph file>" << endl;
184     exit(EXIT_FAILURE);
185   }
186
187   ifstream *file = new ifstream(argv[1]);
188
189   int nb_edges, nb_vertices;
190   int source, sink;
191
192   if(file->good()) {
193
194     (*file) >> nb_vertices >> nb_edges;
195     (*file) >> source >> sink;
196
197     cout << "INPUT nb_edges " << nb_edges << endl;
198     cout << "INPUT nb_vertices " << nb_vertices << endl;
199     cout << "INPUT source " << source << endl;
200     cout << "INPUT sink " << sink << endl;
201
202     scalar_t *el = new scalar_t[nb_edges];
203     int *ea = new int[nb_edges];
204     int *eb = new int[nb_edges];
205     int *edge_occupation = new int[nb_edges];
206
207     for(int e = 0; e < nb_edges; e++) {
208       (*file) >> ea[e] >> eb[e] >> el[e];
209       cout << "INPUT_EDGE " << ea[e] << " " << eb[e] << " " << el[e] << endl;
210     }
211
212     find_best_paths(nb_vertices, nb_edges, el, ea, eb, source, sink,
213                     edge_occupation);
214
215 #ifdef VERBOSE
216     // Sanity check on the overall resulting score (the edge lengths
217     // have been changed, hence should be the opposite of the sum of
218     // the path lengths)
219     scalar_t s = 0;
220     for(int e = 0; e < nb_edges; e++) {
221       if(edge_occupation[e]) s += el[e];
222     }
223     cout << "RESULT_SANITY_CHECK_SCORE " << s << endl;
224 #endif
225
226     for(int e = 0; e < nb_edges; e++) {
227       if(edge_occupation[e]) {
228         cout << "RESULT_OCCUPIED_EDGE " << ea[e] << " " << eb[e] << endl;
229       }
230     }
231
232     delete[] edge_occupation;
233     delete[] el;
234     delete[] ea;
235     delete[] eb;
236
237   } else {
238
239     cerr << "Can not open " << argv[1] << endl;
240
241     delete file;
242     exit(EXIT_FAILURE);
243
244   }
245
246   delete file;
247   exit(EXIT_SUCCESS);
248 }