Removed the definition of basename, which conflicted with the existing system one.
[folded-ctf.git] / parsing.cc
1 /*
2  *  folded-ctf is an implementation of the folded hierarchy of
3  *  classifiers for object detection, developed by Francois Fleuret
4  *  and Donald Geman.
5  *
6  *  Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of folded-ctf.
10  *
11  *  folded-ctf is free software: you can redistribute it and/or modify
12  *  it under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  folded-ctf is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with folded-ctf.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #include "parsing.h"
26 #include "fusion_sort.h"
27
28 Parsing::Parsing(LabelledImagePool *image_pool,
29                  PoseCellHierarchy *hierarchy,
30                  scalar_t proportion_negative_cells,
31                  int image_index) {
32
33   _image_pool = image_pool;
34   _image_index = image_index;
35
36   PoseCellSet cell_set;
37   LabelledImage *image;
38
39   image = _image_pool->grab_image(_image_index);
40
41   hierarchy->add_root_cells(image, &cell_set);
42
43   int *kept = new int[cell_set.nb_cells()];
44
45   _nb_cells = 0;
46
47   for(int c = 0; c < cell_set.nb_cells(); c++) {
48     int l = image->pose_cell_label(cell_set.get_cell(c));
49     kept[c] = (l > 0) || (l < 0 && drand48() < proportion_negative_cells);
50     if(kept[c]) _nb_cells++;
51   }
52
53   _cells = new PoseCell[_nb_cells];
54   _responses = new scalar_t[_nb_cells];
55   _labels = new int[_nb_cells];
56   _nb_positives = 0;
57   _nb_negatives = 0;
58
59   int d = 0;
60   for(int c = 0; c < cell_set.nb_cells(); c++) {
61     if(kept[c]) {
62       _cells[d] = *(cell_set.get_cell(c));
63       _labels[d] = image->pose_cell_label(&_cells[d]);
64       _responses[d] = 0;
65       if(_labels[d] < 0) {
66         _nb_negatives++;
67       } else if(_labels[d] > 0) {
68         _nb_positives++;
69       }
70       d++;
71     }
72   }
73
74   delete[] kept;
75
76   _image_pool->release_image(_image_index);
77 }
78
79 Parsing::~Parsing() {
80   delete[] _cells;
81   delete[] _responses;
82   delete[] _labels;
83 }
84
85 void Parsing::down_one_level(PoseCellHierarchy *hierarchy,
86                              int level, int *sample_nb_occurences, scalar_t *sample_responses) {
87   PoseCellSet cell_set;
88   LabelledImage *image;
89
90   int new_nb_cells = 0;
91   for(int c = 0; c < _nb_cells; c++) {
92     new_nb_cells += sample_nb_occurences[c];
93   }
94
95   PoseCell *new_cells = new PoseCell[new_nb_cells];
96   scalar_t *new_responses = new scalar_t[new_nb_cells];
97   int *new_labels = new int[new_nb_cells];
98
99   image = _image_pool->grab_image(_image_index);
100   int b = 0;
101
102   for(int c = 0; c < _nb_cells; c++) {
103
104     if(sample_nb_occurences[c] > 0) {
105
106       cell_set.erase_content();
107       hierarchy->add_subcells(level, _cells + c, &cell_set);
108
109       if(_labels[c] > 0) {
110         ASSERT(sample_nb_occurences[c] == 1);
111         int e = -1;
112         for(int d = 0; d < cell_set.nb_cells(); d++) {
113           if(image->pose_cell_label(cell_set.get_cell(d)) > 0) {
114             ASSERT(e < 0);
115             e = d;
116           }
117         }
118         ASSERT(e >= 0);
119         ASSERT(b < new_nb_cells);
120         new_cells[b] = *(cell_set.get_cell(e));
121         new_responses[b] = sample_responses[c];
122         new_labels[b] = 1;
123         b++;
124       }
125
126       else if(_labels[c] < 0) {
127         for(int d = 0; d < sample_nb_occurences[c]; d++) {
128           ASSERT(b < new_nb_cells);
129           new_cells[b] = *(cell_set.get_cell(int(drand48() * cell_set.nb_cells())));
130           new_responses[b] = sample_responses[c];
131           new_labels[b] = -1;
132           b++;
133         }
134       }
135
136       else {
137         cerr << "INCONSISTENCY" << endl;
138         abort();
139       }
140     }
141   }
142
143   ASSERT(b == new_nb_cells);
144
145   _image_pool->release_image(_image_index);
146
147   delete[] _cells;
148   delete[] _labels;
149   delete[] _responses;
150   _nb_cells = new_nb_cells;
151   _cells = new_cells;
152   _labels = new_labels;
153   _responses = new_responses;
154 }
155
156 void Parsing::update_cell_responses(PiFeatureFamily *pi_feature_family,
157                                     Classifier *classifier) {
158   LabelledImage *image;
159
160   image = _image_pool->grab_image(_image_index);
161   image->compute_rich_structure();
162
163   SampleSet *samples = new SampleSet(pi_feature_family->nb_features(), 1);
164
165   for(int c = 0; c < _nb_cells; c++) {
166     samples->set_sample(0, pi_feature_family, image, &_cells[c], 0);
167     _responses[c] += classifier->response(samples, 0);
168     ASSERT(!isnan(_responses[c]));
169   }
170
171   _image_pool->release_image(_image_index);
172   delete samples;
173 }
174
175 void Parsing::collect_samples(SampleSet *samples,
176                               PiFeatureFamily *pi_feature_family,
177                               int s,
178                               int *to_collect) {
179   LabelledImage *image;
180
181   image = _image_pool->grab_image(_image_index);
182   image->compute_rich_structure();
183
184   for(int c = 0; c < _nb_cells; c++) {
185     if(to_collect[c]) {
186       samples->set_sample(s, pi_feature_family, image, &_cells[c], _labels[c]);
187       s++;
188     }
189   }
190
191   _image_pool->release_image(_image_index);
192 }