automatic commit
[folded-ctf.git] / parsing.cc
1
2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify  //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation.                         //
6 //                                                                       //
7 // This program is distributed in the hope that it will be useful, but   //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
10 // General Public License for more details.                              //
11 //                                                                       //
12 // You should have received a copy of the GNU General Public License     //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
14 //                                                                       //
15 // Written by Francois Fleuret                                           //
16 // (C) Idiap Research Institute                                          //
17 //                                                                       //
18 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
19 ///////////////////////////////////////////////////////////////////////////
20
21 #include "parsing.h"
22 #include "fusion_sort.h"
23
24 Parsing::Parsing(LabelledImagePool *image_pool,
25                  PoseCellHierarchy *hierarchy,
26                  scalar_t proportion_negative_cells,
27                  int image_index) {
28
29   _image_pool = image_pool;
30   _image_index = image_index;
31
32   PoseCellSet cell_set;
33   LabelledImage *image;
34
35   image = _image_pool->grab_image(_image_index);
36
37   hierarchy->add_root_cells(image, &cell_set);
38
39   int *kept = new int[cell_set.nb_cells()];
40
41   _nb_cells = 0;
42
43   for(int c = 0; c < cell_set.nb_cells(); c++) {
44     int l = image->pose_cell_label(cell_set.get_cell(c));
45     kept[c] = (l > 0) || (l < 0 && drand48() < proportion_negative_cells);
46     if(kept[c]) _nb_cells++;
47   }
48
49   _cells = new PoseCell[_nb_cells];
50   _responses = new scalar_t[_nb_cells];
51   _labels = new int[_nb_cells];
52   _nb_positives = 0;
53   _nb_negatives = 0;
54
55   int d = 0;
56   for(int c = 0; c < cell_set.nb_cells(); c++) {
57     if(kept[c]) {
58       _cells[d] = *(cell_set.get_cell(c));
59       _labels[d] = image->pose_cell_label(&_cells[d]);
60       _responses[d] = 0;
61       if(_labels[d] < 0) {
62         _nb_negatives++;
63       } else if(_labels[d] > 0) {
64         _nb_positives++;
65       }
66       d++;
67     }
68   }
69
70   delete[] kept;
71
72   _image_pool->release_image(_image_index);
73 }
74
75 Parsing::~Parsing() {
76   delete[] _cells;
77   delete[] _responses;
78   delete[] _labels;
79 }
80
81 void Parsing::down_one_level(PoseCellHierarchy *hierarchy,
82                              int level, int *sample_nb_occurences, scalar_t *sample_responses) {
83   PoseCellSet cell_set;
84   LabelledImage *image;
85
86   int new_nb_cells = 0;
87   for(int c = 0; c < _nb_cells; c++) {
88     new_nb_cells += sample_nb_occurences[c];
89   }
90
91   PoseCell *new_cells = new PoseCell[new_nb_cells];
92   scalar_t *new_responses = new scalar_t[new_nb_cells];
93   int *new_labels = new int[new_nb_cells];
94
95   image = _image_pool->grab_image(_image_index);
96   int b = 0;
97
98   for(int c = 0; c < _nb_cells; c++) {
99
100     if(sample_nb_occurences[c] > 0) {
101
102       cell_set.erase_content();
103       hierarchy->add_subcells(level, _cells + c, &cell_set);
104
105       if(_labels[c] > 0) {
106         ASSERT(sample_nb_occurences[c] == 1);
107         int e = -1;
108         for(int d = 0; d < cell_set.nb_cells(); d++) {
109           if(image->pose_cell_label(cell_set.get_cell(d)) > 0) {
110             ASSERT(e < 0);
111             e = d;
112           }
113         }
114         ASSERT(e >= 0);
115         ASSERT(b < new_nb_cells);
116         new_cells[b] = *(cell_set.get_cell(e));
117         new_responses[b] = sample_responses[c];
118         new_labels[b] = 1;
119         b++;
120       }
121
122       else if(_labels[c] < 0) {
123         for(int d = 0; d < sample_nb_occurences[c]; d++) {
124           ASSERT(b < new_nb_cells);
125           new_cells[b] = *(cell_set.get_cell(int(drand48() * cell_set.nb_cells())));
126           new_responses[b] = sample_responses[c];
127           new_labels[b] = -1;
128           b++;
129         }
130       }
131
132       else {
133         cerr << "INCONSISTENCY" << endl;
134         abort();
135       }
136     }
137   }
138
139   ASSERT(b == new_nb_cells);
140
141   _image_pool->release_image(_image_index);
142
143   delete[] _cells;
144   delete[] _labels;
145   delete[] _responses;
146   _nb_cells = new_nb_cells;
147   _cells = new_cells;
148   _labels = new_labels;
149   _responses = new_responses;
150 }
151
152 void Parsing::update_cell_responses(PiFeatureFamily *pi_feature_family,
153                                     Classifier *classifier) {
154   LabelledImage *image;
155
156   image = _image_pool->grab_image(_image_index);
157   image->compute_rich_structure();
158
159   SampleSet *samples = new SampleSet(pi_feature_family->nb_features(), 1);
160
161   for(int c = 0; c < _nb_cells; c++) {
162     samples->set_sample(0, pi_feature_family, image, &_cells[c], 0);
163     _responses[c] += classifier->response(samples, 0);
164     ASSERT(!isnan(_responses[c]));
165   }
166
167   _image_pool->release_image(_image_index);
168   delete samples;
169 }
170
171 void Parsing::collect_samples(SampleSet *samples,
172                               PiFeatureFamily *pi_feature_family,
173                               int s,
174                               int *to_collect) {
175   LabelledImage *image;
176
177   image = _image_pool->grab_image(_image_index);
178   image->compute_rich_structure();
179
180   for(int c = 0; c < _nb_cells; c++) {
181     if(to_collect[c]) {
182       samples->set_sample(s, pi_feature_family, image, &_cells[c], _labels[c]);
183       s++;
184     }
185   }
186
187   _image_pool->release_image(_image_index);
188 }