X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=clueless-kmeans.git;a=blobdiff_plain;f=clusterer.h;fp=clusterer.h;h=b45f8c329b83e2fbf9767151a67e9ee1cee8cd60;hp=88c168a488511fc1f2d5e41f916d5b04641f7004;hb=792d1a58d91f607d47ea5316ec680fb5e3454e5e;hpb=eda8a45be4bdd4c56709cfed4d70059c78a52895 diff --git a/clusterer.h b/clusterer.h index 88c168a..b45f8c3 100644 --- a/clusterer.h +++ b/clusterer.h @@ -29,6 +29,9 @@ class Clusterer { public: + + enum { STANDARD_ASSOCIATION, STANDARD_LP_ASSOCIATION, UNINFORMATIVE_LP_ASSOCIATION }; + const static int max_nb_iterations = 10; const static scalar_t min_iteration_improvement = 0.999; @@ -38,26 +41,38 @@ public: void initialize_clusters(int nb_points, scalar_t **points); + // Does the standard hard k-mean association + scalar_t baseline_cluster_association(int nb_points, scalar_t **points, int nb_classes, int *labels, scalar_t **gamma); + // Does the same with an LP formulation, as a sanity check + scalar_t baseline_lp_cluster_association(int nb_points, scalar_t **points, int nb_classes, int *labels, scalar_t **gamma); + // Does the association under constraints that each cluster gets + // associated clusters with the same class proportion as the overall + // training set + scalar_t uninformative_lp_cluster_association(int nb_points, scalar_t **points, int nb_classes, int *labels, scalar_t **gamma); - void baseline_update_clusters(int nb_points, scalar_t **points, scalar_t **gamma); + void update_clusters(int nb_points, scalar_t **points, scalar_t **gamma); public: Clusterer(); ~Clusterer(); - void train(int nb_clusters, int dim, + + void train(int mode, + int nb_clusters, int dim, int nb_points, scalar_t **points, int nb_classes, int *labels, + // This last array returns for each sample to what + // cluster it was associated. It can be null. int *cluster_associations); int cluster(scalar_t *point);