Initial commit.
[mtp.git] / miniksp.cc
1
2 ////////////////////////////////////////////////////////////////////
3 // START_IP_HEADER                                                //
4 //                                                                //
5 // Written by Francois Fleuret                                    //
6 // Contact <francois.fleuret@idiap.ch> for comments & bug reports //
7 //                                                                //
8 // END_IP_HEADER                                                  //
9 ////////////////////////////////////////////////////////////////////
10
11 #include <iostream>
12 #include <fstream>
13 #include <cmath>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <float.h>
17
18 using namespace std;
19
20 typedef float scalar_t;
21
22 #ifdef DEBUG
23 #define ASSERT(x) if(!(x)) { \
24   std::cerr << "ASSERT FAILED IN " << __FILE__ << ":" << __LINE__ << endl; \
25   abort(); \
26 }
27 #else
28 #define ASSERT(x)
29 #endif
30
31 void raise_es(int nb_edges, scalar_t *es) {
32   scalar_t min_es = es[0];
33   for(int e = 1; e < nb_edges; e++) {
34     min_es = min(min_es, es[e]);
35   }
36   for(int e = 0; e < nb_edges; e++) {
37     es[e] -= min_es;
38   }
39 }
40
41 void add_dpsi_es(int nb_edges, scalar_t *es, int *ea, int *eb, scalar_t *psi) {
42   for(int e = 0; e < nb_edges; e++) {
43     es[e] += psi[ea[e]] - psi[eb[e]];
44   }
45 }
46
47 void find_shortest(int nb_vertices,
48                    int nb_edges, scalar_t *es, int *ea, int *eb,
49                    int source, int sink,
50                    int *pred, scalar_t *dist) {
51   for(int v = 0; v < nb_vertices; v++) {
52     dist[v] = FLT_MAX;
53   }
54
55   dist[source] = 0;
56
57   for(int e = 0; e < nb_edges; e++) {
58     pred[e] = -1;
59     ASSERT(se[e] >= 0);
60   }
61
62   int nb_changes;
63   scalar_t d;
64   do {
65     nb_changes = 0;
66     for(int e = 0; e < nb_edges; e++) {
67       d = dist[ea[e]] + es[e];
68       if(d < dist[eb[e]]) {
69         nb_changes++;
70         dist[eb[e]] = d;
71         pred[eb[e]] = ea[e];
72       }
73     }
74   } while(nb_changes > 0);
75
76   ASSERT(pred[sink] >= 0);
77 }
78
79 void find_best_paths(int nb_vertices,
80                      int nb_edges, scalar_t *es, int *ea, int *eb,
81                      int source, int sink) {
82   scalar_t *dist = new scalar_t[nb_vertices];
83   int *pred = new int[nb_vertices];
84
85   raise_es(nb_edges, es);
86
87   scalar_t s;
88   do {
89     find_shortest(nb_vertices, nb_edges, es, ea, eb, source, sink, pred, dist);
90     add_dpsi_es(nb_edges, es, ea, eb, dist);
91     s = 0.0;
92     for(int e = 0; e < nb_edges; e++) {
93       if(pred[eb[e]] == ea[e]) {
94         s += es[e];
95         int k = eb[e];
96         eb[e] = ea[e];
97         ea[e] = k;
98         es[e] = - es[e];
99       }
100     }
101   } while(s < 0);
102
103   delete[] dist;
104   delete[] pred;
105 }
106
107 int main(int argc, char **argv) {
108   int nb_time_steps = 4;
109   int nb_locations = 5;
110   // Add the source and sink
111   int nb_vertices = nb_time_steps * nb_locations + 2;
112   int nb_edges = nb_locations + (nb_time_steps - 1) * nb_locations * nb_locations + nb_locations;
113   int source = 0;
114   int sink = nb_locations - 1;
115   scalar_t *es = new scalar_t[nb_edges];
116   int *ea = new int[nb_edges];
117   int *eb = new int[nb_edges];
118
119   delete[] es;
120   delete[] ea;
121   delete[] eb;
122 }