automatic commit
[mlp.git] / neural.h
diff --git a/neural.h b/neural.h
new file mode 100644 (file)
index 0000000..6843212
--- /dev/null
+++ b/neural.h
@@ -0,0 +1,118 @@
+/*
+ *  mlp-mnist is an implementation of a multi-layer neural network.
+ *
+ *  Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/
+ *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
+ *
+ *  This file is part of mlp-mnist.
+ *
+ *  mlp-mnist is free software: you can redistribute it and/or modify
+ *  it under the terms of the GNU General Public License version 3 as
+ *  published by the Free Software Foundation.
+ *
+ *  mlp-mnist is distributed in the hope that it will be useful, but
+ *  WITHOUT ANY WARRANTY; without even the implied warranty of
+ *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ *  General Public License for more details.
+ *
+ *  You should have received a copy of the GNU General Public License
+ *  along with mlp-mnist.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+#ifndef NEURAL_H
+#define NEURAL_H
+
+#include <cmath>
+#include <stdlib.h>
+
+#include "misc.h"
+#include "images.h"
+
+inline scalar_t normal_sample() {
+  scalar_t a = drand48();
+  scalar_t b = drand48();
+  return cos(2 * M_PI * a) * sqrt(-2 * log(b));
+}
+
+class MultiLayerPerceptron {
+protected:
+  static const scalar_t output_amplitude;
+
+  int _nb_layers;
+  int *_layer_sizes;
+  int _nb_activations, _nb_weights;
+
+  // We can 'freeze' certain layers and let the learning only change
+  // the others
+  bool *_frozen_layers;
+
+  // Tell us where the layers begin
+  int *_weights_index, *_activations_index;
+
+  scalar_t *_activations, *_pre_sigma_activations;
+  scalar_t *_weights;
+
+public:
+  MultiLayerPerceptron(const MultiLayerPerceptron &mlp);
+  MultiLayerPerceptron(int nb_layers, int *layer_sizes);
+  MultiLayerPerceptron(istream &is);
+  ~MultiLayerPerceptron();
+
+  void save(ostream &os);
+
+  void save_data();
+
+  inline int nb_layers() { return _nb_layers; }
+  inline int layer_size(int l) { return _layer_sizes[l]; }
+  inline int nb_weights() { return _nb_weights; }
+  inline void freeze(int l, bool f) { _frozen_layers[l] = f; }
+  scalar_t sigma(scalar_t x) { return 2 / (1 + exp(- x)) - 1; }
+  scalar_t dsigma(scalar_t x) { scalar_t e = exp(- x); return 2 * e / sq(1 + e); }
+
+  // Init all the weights with a normal distribution of given standard
+  // deviation
+  void init_random_weights(scalar_t stdd);
+
+  // Compute the gradient based on one single sample
+  void compute_gradient_1s(ImageSet *is, int p, scalar_t *gradient_1s);
+  // Compute the gradient based on all samples from the set
+  void compute_gradient(ImageSet *is, scalar_t *gradient);
+
+  // Compute the same gradient numerically (to check the one above)
+  void compute_numerical_gradient(ImageSet *is, scalar_t *gradient);
+
+  // Print the gradient
+  void print_gradient(ostream &os, scalar_t *gradient);
+
+  // Move all weights to origin + lambda * gradient
+  void move_on_line(scalar_t *origin, scalar_t *gradient, scalar_t lambda);
+
+  // The 'basic' gradient just goes through all samples and add dt
+  // time the gradient on each one
+  void one_step_basic_gradient(ImageSet *is, scalar_t dt);
+
+  // The global gradient uses a conjugate gradient to minmize the
+  // global quadratic error
+  void one_step_global_gradient(ImageSet *is, scalar_t *xi, scalar_t *g, scalar_t *h);
+
+  // Performs gradient descent until the test error as increased
+  // during 5 steps
+  void train(ImageSet *training_set, ImageSet *validation_set);
+
+  // Compute the activation of the network from one sample. The input
+  // layer has to be as large as the number of pixels in the images.
+  void compute_activations_1s(ImageSet *is, int p);
+
+  // Compute the activation of the network on all samples. The
+  // responses array has to be as large as the number of samples in is
+  // time the dimension of the output layer
+  void test(ImageSet *is, scalar_t *responses);
+
+  // Compute the quadratic error
+  scalar_t error(ImageSet *is);
+  // Compute the classification error
+  scalar_t classification_error(ImageSet *is);
+};
+
+#endif