automatic commit
[folded-ctf.git] / parsing.cc
1
2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify  //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation.                         //
6 //                                                                       //
7 // This program is distributed in the hope that it will be useful, but   //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
10 // General Public License for more details.                              //
11 //                                                                       //
12 // You should have received a copy of the GNU General Public License     //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
14 //                                                                       //
15 // Written by Francois Fleuret, (C) IDIAP                                //
16 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
17 ///////////////////////////////////////////////////////////////////////////
18
19 #include "parsing.h"
20 #include "fusion_sort.h"
21
22 Parsing::Parsing(LabelledImagePool *image_pool,
23                  PoseCellHierarchy *hierarchy,
24                  scalar_t proportion_negative_cells,
25                  int image_index) {
26
27   _image_pool = image_pool;
28   _image_index = image_index;
29
30   PoseCellSet cell_set;
31   LabelledImage *image;
32
33   image = _image_pool->grab_image(_image_index);
34
35   hierarchy->add_root_cells(image, &cell_set);
36
37   int *kept = new int[cell_set.nb_cells()];
38
39   _nb_cells = 0;
40
41   for(int c = 0; c < cell_set.nb_cells(); c++) {
42     int l = image->pose_cell_label(cell_set.get_cell(c));
43     kept[c] = (l > 0) || (l < 0 && drand48() < proportion_negative_cells);
44     if(kept[c]) _nb_cells++;
45   }
46
47   _cells = new PoseCell[_nb_cells];
48   _responses = new scalar_t[_nb_cells];
49   _labels = new int[_nb_cells];
50   _nb_positives = 0;
51   _nb_negatives = 0;
52
53   int d = 0;
54   for(int c = 0; c < cell_set.nb_cells(); c++) {
55     if(kept[c]) {
56       _cells[d] = *(cell_set.get_cell(c));
57       _labels[d] = image->pose_cell_label(&_cells[d]);
58       _responses[d] = 0;
59       if(_labels[d] < 0) {
60         _nb_negatives++;
61       } else if(_labels[d] > 0) {
62         _nb_positives++;
63       }
64       d++;
65     }
66   }
67
68   delete[] kept;
69
70   _image_pool->release_image(_image_index);
71 }
72
73 Parsing::~Parsing() {
74   delete[] _cells;
75   delete[] _responses;
76   delete[] _labels;
77 }
78
79 void Parsing::down_one_level(PoseCellHierarchy *hierarchy,
80                              int level, int *sample_nb_occurences, scalar_t *sample_responses) {
81   PoseCellSet cell_set;
82   LabelledImage *image;
83
84   int new_nb_cells = 0;
85   for(int c = 0; c < _nb_cells; c++) {
86     new_nb_cells += sample_nb_occurences[c];
87   }
88
89   PoseCell *new_cells = new PoseCell[new_nb_cells];
90   scalar_t *new_responses = new scalar_t[new_nb_cells];
91   int *new_labels = new int[new_nb_cells];
92
93   image = _image_pool->grab_image(_image_index);
94   int b = 0;
95
96   for(int c = 0; c < _nb_cells; c++) {
97
98     if(sample_nb_occurences[c] > 0) {
99
100       cell_set.erase_content();
101       hierarchy->add_subcells(level, _cells + c, &cell_set);
102
103       if(_labels[c] > 0) {
104         ASSERT(sample_nb_occurences[c] == 1);
105         int e = -1;
106         for(int d = 0; d < cell_set.nb_cells(); d++) {
107           if(image->pose_cell_label(cell_set.get_cell(d)) > 0) {
108             ASSERT(e < 0);
109             e = d;
110           }
111         }
112         ASSERT(e >= 0);
113         ASSERT(b < new_nb_cells);
114         new_cells[b] = *(cell_set.get_cell(e));
115         new_responses[b] = sample_responses[c];
116         new_labels[b] = 1;
117         b++;
118       }
119
120       else if(_labels[c] < 0) {
121         for(int d = 0; d < sample_nb_occurences[c]; d++) {
122           ASSERT(b < new_nb_cells);
123           new_cells[b] = *(cell_set.get_cell(int(drand48() * cell_set.nb_cells())));
124           new_responses[b] = sample_responses[c];
125           new_labels[b] = -1;
126           b++;
127         }
128       }
129
130       else {
131         cerr << "INCONSISTENCY" << endl;
132         abort();
133       }
134     }
135   }
136
137   ASSERT(b == new_nb_cells);
138
139   _image_pool->release_image(_image_index);
140
141   delete[] _cells;
142   delete[] _labels;
143   delete[] _responses;
144   _nb_cells = new_nb_cells;
145   _cells = new_cells;
146   _labels = new_labels;
147   _responses = new_responses;
148 }
149
150 void Parsing::update_cell_responses(PiFeatureFamily *pi_feature_family,
151                                     Classifier *classifier) {
152   LabelledImage *image;
153
154   image = _image_pool->grab_image(_image_index);
155   image->compute_rich_structure();
156
157   SampleSet *samples = new SampleSet(pi_feature_family->nb_features(), 1);
158
159   for(int c = 0; c < _nb_cells; c++) {
160     samples->set_sample(0, pi_feature_family, image, &_cells[c], 0);
161     _responses[c] += classifier->response(samples, 0);
162     ASSERT(!isnan(_responses[c]));
163   }
164
165   _image_pool->release_image(_image_index);
166   delete samples;
167 }
168
169 void Parsing::collect_samples(SampleSet *samples,
170                               PiFeatureFamily *pi_feature_family,
171                               int s,
172                               int *to_collect) {
173   LabelledImage *image;
174
175   image = _image_pool->grab_image(_image_index);
176   image->compute_rich_structure();
177
178   for(int c = 0; c < _nb_cells; c++) {
179     if(to_collect[c]) {
180       samples->set_sample(s, pi_feature_family, image, &_cells[c], _labels[c]);
181       s++;
182     }
183   }
184
185   _image_pool->release_image(_image_index);
186 }