Update for recent libpng
[svrt.git] / vision_test.cc
1 /*
2  *  svrt is the ``Synthetic Visual Reasoning Test'', an image
3  *  generator for evaluating classification performance of machine
4  *  learning systems, humans and primates.
5  *
6  *  Copyright (c) 2009 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of svrt.
10  *
11  *  svrt is free software: you can redistribute it and/or modify it
12  *  under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  svrt is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #include <iostream>
26 #include <fstream>
27 #include <cmath>
28 #include <stdio.h>
29 #include <stdlib.h>
30 #include <unistd.h>
31
32 using namespace std;
33
34 #include "rgb_image.h"
35 #include "param_parser.h"
36 #include "global.h"
37
38 #include "vignette.h"
39 #include "shape.h"
40 #include "classifier.h"
41 #include "classifier_reader.h"
42 #include "naive_bayesian_classifier.h"
43 #include "boosted_classifier.h"
44 #include "error_rates.h"
45
46 #include "vision_problem_1.h"
47 #include "vision_problem_2.h"
48 #include "vision_problem_3.h"
49 #include "vision_problem_4.h"
50 #include "vision_problem_5.h"
51 #include "vision_problem_6.h"
52 #include "vision_problem_7.h"
53 #include "vision_problem_8.h"
54 #include "vision_problem_9.h"
55 #include "vision_problem_10.h"
56 #include "vision_problem_11.h"
57 #include "vision_problem_12.h"
58 #include "vision_problem_13.h"
59 #include "vision_problem_14.h"
60 #include "vision_problem_15.h"
61 #include "vision_problem_16.h"
62 #include "vision_problem_17.h"
63 #include "vision_problem_18.h"
64 #include "vision_problem_19.h"
65 #include "vision_problem_20.h"
66 #include "vision_problem_21.h"
67 #include "vision_problem_22.h"
68 #include "vision_problem_23.h"
69
70 //////////////////////////////////////////////////////////////////////
71
72 void check(bool condition, const char *message) {
73   if(!condition) {
74     cerr << message << endl;
75     exit(1);
76   }
77 }
78
79 int main(int argc, char **argv) {
80
81   char buffer[buffer_size];
82   char *new_argv[argc];
83   int new_argc = 0;
84
85   cout << "-- ARGUMENTS ---------------------------------------------------------" << endl;
86
87   for(int i = 0; i < argc; i++)
88     cout << (i > 0 ? "  " : "") << argv[i] << (i < argc - 1 ? " \\" : "")
89          << endl;
90
91   cout << "-- PARAMETERS --------------------------------------------------------" << endl;
92
93   {
94     ParamParser parser;
95     global.init_parser(&parser);
96     parser.parse_options(argc, argv, false, &new_argc, new_argv);
97     global.read_parser(&parser);
98     parser.print_all(&cout);
99   }
100
101   nice(global.niceness);
102   srand48(global.random_seed);
103
104   VignetteGenerator *generator;
105
106   switch(global.problem_number) {
107   case 1:
108     generator = new VisionProblem_1();
109     break;
110   case 2:
111     generator = new VisionProblem_2();
112     break;
113   case 3:
114     generator = new VisionProblem_3();
115     break;
116   case 4:
117     generator = new VisionProblem_4();
118     break;
119   case 5:
120     generator = new VisionProblem_5();
121     break;
122   case 6:
123     generator = new VisionProblem_6();
124     break;
125   case 7:
126     generator = new VisionProblem_7();
127     break;
128   case 8:
129     generator = new VisionProblem_8();
130     break;
131   case 9:
132     generator = new VisionProblem_9();
133     break;
134   case 10:
135     generator = new VisionProblem_10();
136     break;
137   case 11:
138     generator = new VisionProblem_11();
139     break;
140   case 12:
141     generator = new VisionProblem_12();
142     break;
143   case 13:
144     generator = new VisionProblem_13();
145     break;
146   case 14:
147     generator = new VisionProblem_14();
148     break;
149   case 15:
150     generator = new VisionProblem_15();
151     break;
152   case 16:
153     generator = new VisionProblem_16();
154     break;
155   case 17:
156     generator = new VisionProblem_17();
157     break;
158   case 18:
159     generator = new VisionProblem_18();
160     break;
161   case 19:
162     generator = new VisionProblem_19();
163     break;
164   case 20:
165     generator = new VisionProblem_20();
166     break;
167   case 21:
168     generator = new VisionProblem_21();
169     break;
170   case 22:
171     generator = new VisionProblem_22();
172     break;
173   case 23:
174     generator = new VisionProblem_23();
175     break;
176   default:
177     cerr << "Can not find problem "
178          << global.problem_number
179          << endl;
180     exit(1);
181   }
182
183   generator->precompute();
184
185   //////////////////////////////////////////////////////////////////////
186
187   Vignette *train_samples;
188   int *train_labels;
189
190   train_samples = new Vignette[global.nb_train_samples];
191   train_labels = new int[global.nb_train_samples];
192
193   //////////////////////////////////////////////////////////////////////
194
195   Classifier *classifier = 0;
196
197   cout << "-- COMPUTATIONS ------------------------------------------------------" << endl;
198
199   for(int c = 1; c < new_argc; c++) {
200
201     if(strcmp(new_argv[c], "randomize-train") == 0) {
202       cout << "Generating the training set." << endl;
203       for(int n = 0; n < global.nb_train_samples; n++) {
204         train_labels[n] = int(drand48() * 2);
205         generator->generate(train_labels[n], &train_samples[n]);
206       }
207     }
208
209     else if(strcmp(new_argv[c], "adaboost") == 0) {
210       delete classifier;
211       cout << "Building and training adaboost classifier." << endl;
212       classifier = new BoostedClassifier(global.nb_weak_learners);
213       classifier->train(global.nb_train_samples, train_samples, train_labels);
214     }
215
216     else if(strcmp(new_argv[c], "naive-bayesian") == 0) {
217       delete classifier;
218       cout << "Building and training naive bayesian classifier." << endl;
219       classifier = new NaiveBayesianClassifier();
220       classifier->train(global.nb_train_samples, train_samples, train_labels);
221     }
222
223     else if(strcmp(new_argv[c], "read-classifier") == 0) {
224       delete classifier;
225       sprintf(buffer, "%s", global.classifier_name);
226       cout << "Reading classifier from " << buffer << "." << endl;
227       ifstream in(buffer);
228       if(in.fail()) {
229         cerr << "Can not open " << buffer << " for reading." << endl;
230         exit(1);
231       }
232       classifier = read_classifier(&in);
233     }
234
235     else if(strcmp(new_argv[c], "write-classifier") == 0) {
236       check(classifier, "No classifier.");
237       sprintf(buffer, "%s/%s", global.result_path, global.classifier_name);
238       cout << "Writing classifier to " << buffer << "." << endl;
239       ofstream out(buffer);
240       if(out.fail()) {
241         cerr << "Can not open " << buffer << " for writing." << endl;
242         exit(1);
243       }
244       classifier->write(&out);
245     }
246
247     else if(strcmp(new_argv[c], "compute-errors-vs-nb-samples") == 0) {
248       for(int t = global.nb_train_samples; t >= 100; t /= 10) {
249         for(int n = 0; n < t; n++) {
250           train_labels[n] = int(drand48() * 2);
251           generator->generate(train_labels[n], &train_samples[n]);
252         }
253         Classifier *classifier = 0;
254         cout << "Building and training adaboost classifier with " << t << " samples." << endl;
255         classifier = new BoostedClassifier(global.nb_weak_learners);
256         classifier->train(t, train_samples, train_labels);
257         cout << "ERROR_RATES_VS_NB_SAMPLES "
258              << t
259              << " TRAIN_ERROR "
260              << error_rate(classifier, t, train_samples, train_labels)
261              << " TEST_ERROR "
262              << test_error_rate(generator, classifier, global.nb_test_samples) << endl;
263         delete classifier;
264       }
265     }
266
267     else if(strcmp(new_argv[c], "compute-train-error") == 0) {
268       check(classifier, "No classifier.");
269       cout << "TRAIN_ERROR_RATE "
270            << classifier->name()
271            << " "
272            << error_rate(classifier, global.nb_train_samples, train_samples, train_labels)
273            << endl;
274     }
275
276     else if(strcmp(new_argv[c], "compute-test-error") == 0) {
277       check(classifier, "No classifier.");
278       cout << "TEST_ERROR_RATE "
279            << classifier->name()
280            << " "
281            << test_error_rate(generator, classifier, global.nb_test_samples) << endl;
282     }
283
284     else if(strcmp(new_argv[c], "write-samples") == 0) {
285       Vignette vignette;
286       for(int k = 0; k < global.nb_train_samples; k++) {
287         for(int l = 0; l < 2; l++) {
288           generator->generate(l, &vignette);
289           sprintf(buffer, "%s/sample_%01d_%04d.png", global.result_path, l, k);
290           vignette.write_png(buffer, 1);
291           cout << "Wrote " << buffer << endl;
292         }
293       }
294     }
295
296     //////////////////////////////////////////////////////////////////////
297
298     //////////////////////////////////////////////////////////////////////
299
300     else {
301       cerr << "Unknown action " << new_argv[c] << endl;
302       exit(1);
303     }
304
305   }
306
307   cout << "-- FINISHED ----------------------------------------------------------" << endl;
308
309   delete classifier;
310   delete[] train_labels;
311   delete[] train_samples;
312   delete generator;
313 }