Fixed a typo !
[clueless-kmeans.git] / clueless-kmeans.cc
1 /*
2  *  clueless-kmeans is a variant of k-means which enforces balanced
3  *  distribution of classes in every cluster
4  *
5  *  Copyright (c) 2013 Idiap Research Institute, http://www.idiap.ch/
6  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
7  *
8  *  This file is part of clueless-kmeans.
9  *
10  *  clueless-kmeans is free software: you can redistribute it and/or
11  *  modify it under the terms of the GNU General Public License
12  *  version 3 as published by the Free Software Foundation.
13  *
14  *  clueless-kmeans is distributed in the hope that it will be useful,
15  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  *  General Public License for more details.
18  *
19  *  You should have received a copy of the GNU General Public License
20  *  along with selector.  If not, see <http://www.gnu.org/licenses/>.
21  *
22  */
23
24 #include <iostream>
25 #include <fstream>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <float.h>
29 #include <glpk.h>
30
31 using namespace std;
32
33 #include "misc.h"
34 #include "arrays.h"
35 #include "sample_set.h"
36 #include "clusterer.h"
37
38 void generate_toy_problem(SampleSet *sample_set) {
39   int dim = 2;
40   int nb_points = 1000;
41
42   sample_set->resize(dim, nb_points);
43   sample_set->nb_classes = 2;
44
45   for(int n = 0; n < nb_points; n++) {
46     sample_set->labels[n] = int(drand48() * 2);
47     if(sample_set->labels[n] == 0) {
48       sample_set->points[n][0] = (2 * drand48()  - 1) * 0.8;
49       sample_set->points[n][1] = - 0.6 + (2 * drand48()  - 1) * 0.4;
50     } else {
51       sample_set->points[n][0] = (2 * drand48()  - 1) * 0.4;
52       sample_set->points[n][1] =   0.6 + (2 * drand48()  - 1) * 0.4;
53     }
54   }
55 }
56
57 int main(int argc, char **argv) {
58   SampleSet sample_set;
59   Clusterer clusterer;
60   int nb_clusters = 3;
61
62   generate_toy_problem(&sample_set);
63
64   {
65     ofstream out("points.dat");
66     for(int n = 0; n < sample_set.nb_points; n++) {
67       out << sample_set.labels[n];
68       for(int d = 0; d < sample_set.dim; d++) {
69         out << " " << sample_set.points[n][d];
70       }
71       out << endl;
72     }
73   }
74
75   int *associated_clusters = new int[sample_set.nb_points];
76
77   glp_term_out(0);
78
79   int mode;
80
81   if(argc == 2) {
82     if(strcmp(argv[1], "standard") == 0) {
83       mode = Clusterer::STANDARD_LP_ASSOCIATION;
84     } else if(strcmp(argv[1], "clueless") == 0) {
85       mode = Clusterer::UNINFORMATIVE_LP_ASSOCIATION;
86     } else {
87       cerr << "Unknown association mode " << argv[1] << endl;
88       exit(EXIT_FAILURE);
89     }
90   } else {
91     cerr << "Usage: " << argv[0] << " standard|clueless" << endl;
92     exit(EXIT_FAILURE);
93   }
94
95   clusterer.train(mode,
96                   nb_clusters,
97                   sample_set.dim,
98                   sample_set.nb_points, sample_set.points,
99                   sample_set.nb_classes, sample_set.labels,
100                   associated_clusters);
101
102   {
103     ofstream out("associated_clusters.dat");
104     for(int n = 0; n < sample_set.nb_points; n++) {
105       out << associated_clusters[n];
106       for(int d = 0; d < sample_set.dim; d++) {
107         out << " " << sample_set.points[n][d];
108       }
109       out << endl;
110     }
111   }
112
113   {
114     ofstream out("clusters.dat");
115     for(int k = 0 ; k < clusterer._nb_clusters; k++) {
116       out << k;
117       for(int d = 0; d < sample_set.dim; d++) {
118         out << " " << clusterer._cluster_means[k][d];
119       }
120       for(int d = 0; d < sample_set.dim; d++) {
121         out << " " << 2 * sqrt(clusterer._cluster_var[k][d]);
122       }
123       out << endl;
124     }
125   }
126
127   delete[] associated_clusters;
128
129   glp_free_env(); // I do not want valgrind to complain
130 }