automatic commit
[folded-ctf.git] / folding.cc
diff --git a/folding.cc b/folding.cc
new file mode 100644 (file)
index 0000000..34c785c
--- /dev/null
@@ -0,0 +1,339 @@
+
+///////////////////////////////////////////////////////////////////////////
+// This program is free software: you can redistribute it and/or modify  //
+// it under the terms of the version 3 of the GNU General Public License //
+// as published by the Free Software Foundation.                         //
+//                                                                       //
+// This program is distributed in the hope that it will be useful, but   //
+// WITHOUT ANY WARRANTY; without even the implied warranty of            //
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
+// General Public License for more details.                              //
+//                                                                       //
+// You should have received a copy of the GNU General Public License     //
+// along with this program. If not, see <http://www.gnu.org/licenses/>.  //
+//                                                                       //
+// Written by Francois Fleuret, (C) IDIAP                                //
+// Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
+///////////////////////////////////////////////////////////////////////////
+
+#include <iostream>
+#include <fstream>
+#include <cmath>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+using namespace std;
+
+#include "misc.h"
+#include "param_parser.h"
+#include "global.h"
+#include "labelled_image_pool_file.h"
+#include "labelled_image_pool_subset.h"
+#include "tools.h"
+#include "detector.h"
+#include "pose_cell_hierarchy.h"
+#include "error_rates.h"
+#include "materials.h"
+
+//////////////////////////////////////////////////////////////////////
+
+void check(bool condition, const char *message) {
+  if(!condition) {
+    cerr << message << endl;
+    exit(1);
+  }
+}
+
+//////////////////////////////////////////////////////////////////////
+
+int main(int argc, char **argv) {
+  char *new_argv[argc];
+  int new_argc = 0;
+
+#ifdef DEBUG
+  cout << endl;
+  cout << "**********************************************************************" << endl;
+  cout << "**                     COMPILED IN DEBUG MODE                       **" << endl;
+  cout << "**********************************************************************" << endl;
+  cout << endl;
+#endif
+
+  cout << "-- ARGUMENTS ---------------------------------------------------------" << endl;
+  for(int i = 0; i < argc; i++)
+    cout << (i > 0 ? "  " : "") << argv[i] << (i < argc - 1 ? " \\" : "")
+         << endl;
+
+  {
+    ParamParser parser;
+    global.init_parser(&parser);
+    parser.parse_options(argc, argv, false, &new_argc, new_argv);
+    global.read_parser(&parser);
+    (*global.log_stream)
+      << "-- PARAMETERS --------------------------------------------------------"
+      << endl;
+    parser.print_all(global.log_stream);
+  }
+
+  nice(global.niceness);
+
+  (*global.log_stream) << "INFO RANDOM_SEED " << global.random_seed << endl;
+  srand48(global.random_seed);
+
+  LabelledImagePool *main_pool = 0;
+  LabelledImagePool *train_pool = 0, *validation_pool = 0, *hierarchy_pool = 0;
+  LabelledImagePool *test_pool = 0;
+  Detector *detector = 0;
+
+  {
+    char buffer[buffer_size];
+    gethostname(buffer, buffer_size);
+    (*global.log_stream) << "INFO HOSTNAME " << buffer << endl;
+  }
+
+  for(int c = 1; c < new_argc; c++) {
+
+    if(strcmp(new_argv[c], "open-pool") == 0) {
+      cout
+        << "-- OPENING POOL ------------------------------------------------------"
+        << endl;
+
+      check(!main_pool, "Pool already opened.");
+      check(global.pool_name[0], "No pool file.");
+
+      main_pool = new LabelledImagePoolFile(global.pool_name);
+
+      bool for_test[main_pool->nb_images()];
+      bool for_train[main_pool->nb_images()];
+      bool for_validation[main_pool->nb_images()];
+      bool for_hierarchy[main_pool->nb_images()];
+
+      for(int n = 0; n < main_pool->nb_images(); n++) {
+        for_test[n] = false;
+        for_train[n] = false;
+        for_validation[n] = false;
+        scalar_t r = drand48();
+        if(r < global.proportion_for_train)
+          for_train[n] = true;
+        else if(r < global.proportion_for_train + global.proportion_for_validation)
+          for_validation[n] = true;
+        else if(global.proportion_for_test < 0 ||
+                r < global.proportion_for_train +
+                global.proportion_for_validation +
+                global.proportion_for_test)
+          for_test[n] = true;
+        for_hierarchy[n] = for_train[n] || for_validation[n];
+      }
+
+      train_pool = new LabelledImagePoolSubset(main_pool, for_train);
+      validation_pool = new LabelledImagePoolSubset(main_pool, for_validation);
+      hierarchy_pool = new LabelledImagePoolSubset(main_pool, for_hierarchy);
+
+      if(global.test_pool_name[0]) {
+        test_pool = new LabelledImagePoolFile(global.test_pool_name);
+      } else {
+        test_pool = new LabelledImagePoolSubset(main_pool, for_test);
+      }
+
+      cout << "Using "
+           << train_pool->nb_images() << " images for train, "
+           << validation_pool->nb_images() << " images for validation, "
+           << hierarchy_pool->nb_images() << " images for the hierarchy and "
+           << test_pool->nb_images() << " images for test."
+           << endl;
+
+    }
+
+    else if(strcmp(new_argv[c], "write-target-poses") == 0) {
+      check(main_pool, "No pool available.");
+      LabelledImage *image;
+      for(int p = 0; p < main_pool->nb_images(); p++) {
+        image = main_pool->grab_image(p);
+        for(int t = 0; t < image->nb_targets(); t++) {
+          cout << "IMAGE " << p << " TARGET " << t << endl;
+          image->get_target_pose(t)->print(&cout);
+        }
+        main_pool->release_image(p);
+      }
+    }
+
+    //////////////////////////////////////////////////////////////////////
+
+    else if(strcmp(new_argv[c], "train-detector") == 0) {
+      cout << "-- TRAIN DETECTOR ----------------------------------------------------" << endl;
+      check(train_pool, "No train pool available.");
+      check(validation_pool, "No validation pool available.");
+      check(hierarchy_pool, "No hierarchy pool available.");
+      check(!detector, "Existing detector, can not train another one.");
+      detector = new Detector();
+      detector->train(train_pool, validation_pool, hierarchy_pool);
+    }
+
+    else if(strcmp(new_argv[c], "compute-thresholds") == 0) {
+      cout << "-- COMPUTE THRESHOLDS ------------------------------------------------" << endl;
+      check(validation_pool, "No validation pool available.");
+      check(detector, "No detector.");
+      detector->compute_thresholds(validation_pool, global.wanted_true_positive_rate);
+    }
+
+    else if(strcmp(new_argv[c], "check-hierarchy") == 0) {
+      cout << "-- CHECK HIERARCHY ---------------------------------------------------" << endl;
+      PoseCellHierarchy *h = new PoseCellHierarchy(hierarchy_pool);
+      cout << "Train incompatible poses " << h->nb_incompatible_poses(train_pool) << endl;
+      cout << "Validation incompatible poses " << h->nb_incompatible_poses(validation_pool) << endl;
+      delete h;
+    }
+
+    //////////////////////////////////////////////////////////////////////
+
+    else if(strcmp(new_argv[c], "validate-detector") == 0) {
+      cout << "-- VALIDATE DETECTOR -------------------------------------------------" << endl;
+
+      check(validation_pool, "No validation pool available.");
+      check(detector, "No detector.");
+
+      print_decimated_error_rate(global.nb_levels - 1, validation_pool, detector);
+    }
+
+    //////////////////////////////////////////////////////////////////////
+
+    else if(strcmp(new_argv[c], "test-detector") == 0) {
+      cout << "-- TEST DETECTOR -----------------------------------------------------" << endl;
+
+      check(test_pool, "No test pool available.");
+      check(detector, "No detector.");
+
+      if(test_pool->nb_images() > 0) {
+        print_decimated_error_rate(global.nb_levels - 1, test_pool, detector);
+      } else {
+        cout << "No test image." << endl;
+      }
+    }
+
+    else if(strcmp(new_argv[c], "parse-images") == 0) {
+      cout << "-- PARSING IMAGES -----------------------------------------------------" << endl;
+      check(detector, "No detector.");
+      while(!cin.eof()) {
+        char image_name[buffer_size];
+        cin.getline(image_name, buffer_size);
+        if(strlen(image_name) > 0) {
+          parse_scene(detector, image_name);
+        }
+      }
+    }
+
+    //////////////////////////////////////////////////////////////////////
+
+    else if(strcmp(new_argv[c], "sequence-test-detector") == 0) {
+      cout << "-- SEQUENCE TEST DETECTOR --------------------------------------------" << endl;
+
+      check(test_pool, "No test pool available.");
+      check(detector, "No detector.");
+
+      if(test_pool->nb_images() > 0) {
+
+        for(int n = 0; n < global.nb_wanted_true_positive_rates; n++) {
+          scalar_t r = global.wanted_true_positive_rate *
+            scalar_t(n + 1) / scalar_t(global.nb_wanted_true_positive_rates);
+          cout << "Testint at tp " << r
+               << " (" << n + 1 << "/" << global.nb_wanted_true_positive_rates << ")"
+               << endl;
+          (*global.log_stream) << "INFO THRESHOLD_FOR_TP " << r << endl;
+          detector->compute_thresholds(validation_pool, r);
+          print_decimated_error_rate(global.nb_levels - 1, test_pool, detector);
+        }
+      } else {
+        cout << "No test image." << endl;
+      }
+    }
+
+    //////////////////////////////////////////////////////////////////////
+
+    else if(strcmp(new_argv[c], "write-detector") == 0) {
+      cout << "-- WRITE DETECTOR ----------------------------------------------------" << endl;
+      ofstream out(global.detector_name);
+      if(out.fail()) {
+        cerr << "Can not write to " << global.detector_name << endl;
+        exit(1);
+      }
+      check(detector, "No detector available.");
+      detector->write(&out);
+    }
+
+    //////////////////////////////////////////////////////////////////////
+
+    else if(strcmp(new_argv[c], "read-detector") == 0) {
+      cout << "-- READ DETECTOR -----------------------------------------------------" << endl;
+
+      check(!detector, "Existing detector, can not load another one.");
+
+      ifstream in(global.detector_name);
+      if(in.fail()) {
+        cerr << "Can not read from " << global.detector_name << endl;
+        exit(1);
+      }
+
+      detector = new Detector();
+      detector->read(&in);
+    }
+
+    //////////////////////////////////////////////////////////////////////
+
+    else if(strcmp(new_argv[c], "write-pool-images") == 0) {
+      cout << "-- WRITING POOL IMAGES -----------------------------------------------" << endl;
+      check(global.nb_images > 0, "You must set nb_images to a positive value.");
+      check(train_pool, "No train pool available.");
+      write_pool_images_with_poses_and_referentials(train_pool, detector);
+    }
+
+    else if(strcmp(new_argv[c], "produce-materials") == 0) {
+      cout << "-- PRODUCING MATERIALS -----------------------------------------------" << endl;
+
+      check(hierarchy_pool, "No hierarchy pool available.");
+      check(test_pool, "No test pool available.");
+
+      PoseCellHierarchy *hierarchy;
+
+      cout << "Creating hierarchy" << endl;
+
+      hierarchy = new PoseCellHierarchy(hierarchy_pool);
+
+      LabelledImage *image;
+      for(int p = 0; p < test_pool->nb_images(); p++) {
+        image = test_pool->grab_image(p);
+        if(image->width() == 640 && image->height() == 480) {
+          PoseCellSet pcs;
+          hierarchy->add_root_cells(image, &pcs);
+          cout << "WE HAVE " << pcs.nb_cells() << " CELLS" << endl;
+          exit(0);
+          test_pool->release_image(p);
+        }
+      }
+
+      delete hierarchy;
+
+    }
+
+    //////////////////////////////////////////////////////////////////////
+
+    else {
+      cerr << "Unknown action " << new_argv[c] << endl;
+      exit(1);
+    }
+
+    //////////////////////////////////////////////////////////////////////
+
+  }
+
+  delete detector;
+
+  delete train_pool;
+  delete validation_pool;
+  delete hierarchy_pool;
+  delete test_pool;
+
+  delete main_pool;
+
+  cout << "-- FINISHED ----------------------------------------------------------" << endl;
+
+}