Added comments in the main method.
[clueless-kmeans.git] / clueless-kmean.cc
1 /*
2  *  clueless-kmean is a variant of k-mean which enforces balanced
3  *  distribution of classes in every cluster
4  *
5  *  Copyright (c) 2013 Idiap Research Institute, http://www.idiap.ch/
6  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
7  *
8  *  This file is part of clueless-kmean.
9  *
10  *  clueless-kmean is free software: you can redistribute it and/or
11  *  modify it under the terms of the GNU General Public License
12  *  version 3 as published by the Free Software Foundation.
13  *
14  *  clueless-kmean is distributed in the hope that it will be useful,
15  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  *  General Public License for more details.
18  *
19  *  You should have received a copy of the GNU General Public License
20  *  along with selector.  If not, see <http://www.gnu.org/licenses/>.
21  *
22  */
23
24 #include <iostream>
25 #include <fstream>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <float.h>
29 #include <glpk.h>
30
31 using namespace std;
32
33 #include "misc.h"
34 #include "arrays.h"
35 #include "sample_set.h"
36 #include "clusterer.h"
37
38 void generate_toy_problem(SampleSet *sample_set) {
39   int dim = 2;
40   int nb_points = 1000;
41
42   sample_set->resize(dim, nb_points);
43   sample_set->nb_classes = 2;
44
45   for(int n = 0; n < nb_points; n++) {
46     sample_set->labels[n] = int(drand48() * 2);
47     if(sample_set->labels[n] == 0) {
48       sample_set->points[n][0] = (2 * drand48()  - 1) * 0.8;
49       sample_set->points[n][1] = - 0.6 + (2 * drand48()  - 1) * 0.4;
50     } else {
51       sample_set->points[n][0] = (2 * drand48()  - 1) * 0.4;
52       sample_set->points[n][1] =   0.6 + (2 * drand48()  - 1) * 0.4;
53     }
54   }
55 }
56
57 int main(int argc, char **argv) {
58   SampleSet sample_set;
59   Clusterer clusterer;
60   int nb_clusters = 3;
61
62   generate_toy_problem(&sample_set);
63
64   {
65     ofstream out("points.dat");
66     for(int n = 0; n < sample_set.nb_points; n++) {
67       out << sample_set.labels[n];
68       for(int d = 0; d < sample_set.dim; d++) {
69         out << " " << sample_set.points[n][d];
70       }
71       out << endl;
72     }
73   }
74
75   int *associated_clusters = new int[sample_set.nb_points];
76
77   glp_term_out(0);
78
79   clusterer.train(Clusterer::UNINFORMATIVE_LP_ASSOCIATION,
80                   nb_clusters,
81                   sample_set.dim,
82                   sample_set.nb_points, sample_set.points,
83                   sample_set.nb_classes, sample_set.labels,
84                   associated_clusters);
85
86   {
87     ofstream out("associated_clusters.dat");
88     for(int n = 0; n < sample_set.nb_points; n++) {
89       out << associated_clusters[n];
90       for(int d = 0; d < sample_set.dim; d++) {
91         out << " " << sample_set.points[n][d];
92       }
93       out << endl;
94     }
95   }
96
97   {
98     ofstream out("clusters.dat");
99     for(int k = 0 ; k < clusterer._nb_clusters; k++) {
100       out << k;
101       for(int d = 0; d < sample_set.dim; d++) {
102         out << " " << clusterer._cluster_means[k][d];
103       }
104       for(int d = 0; d < sample_set.dim; d++) {
105         out << " " << 2 * sqrt(clusterer._cluster_var[k][d]);
106       }
107       out << endl;
108     }
109   }
110
111   delete[] associated_clusters;
112 }