X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=svrt_generator.cc;h=33b98ee905541633632374f08ec3334806f9f3d6;hp=82b7c3b154c009e845b97c81ba4fccd5a3195527;hb=HEAD;hpb=f542d0542b1e51ca7dd12bc6b96f6a299371ae8d diff --git a/svrt_generator.cc b/svrt_generator.cc index 82b7c3b..33b98ee 100644 --- a/svrt_generator.cc +++ b/svrt_generator.cc @@ -18,7 +18,7 @@ * General Public License for more details. * * You should have received a copy of the GNU General Public License - * along with selector. If not, see . + * along with svrt. If not, see . * */ @@ -145,22 +145,46 @@ VignetteGenerator *new_generator(int nb) { extern "C" { - struct VignetteSet { - int n_problem; - int nb_vignettes; - int width; - int height; - unsigned char *data; - }; - - void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *result) { - VignetteGenerator *vg = new_generator(n_problem); - result->n_problem = n_problem; - result->nb_vignettes = nb_vignettes; - result->width = Vignette::width; - result->height = Vignette::height; - result->data = (unsigned char *) malloc(sizeof(unsigned char) * result->nb_vignettes * result->width * result->height); - delete vg; +struct VignetteSet { + int n_problem; + int nb_vignettes; + int width; + int height; + unsigned char *data; +}; + +void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels, + VignetteSet *result) { + Vignette tmp; + + if(n_problem < 1 || n_problem > NB_PROBLEMS) { + printf("Problem number should be between 1 and %d. Provided value is %d.\n", NB_PROBLEMS, n_problem); + exit(1); + } + + VignetteGenerator *vg = new_generator(n_problem); + result->n_problem = n_problem; + result->nb_vignettes = nb_vignettes; + result->width = Vignette::width; + result->height = Vignette::height; + result->data = (unsigned char *) malloc(sizeof(unsigned char) * result->nb_vignettes * result->width * result->height); + + unsigned char *s = result->data; + for(int i = 0; i < nb_vignettes; i++) { + if(labels[i] == 0 || labels[i] == 1) { + vg->generate(labels[i], &tmp); + } else { + printf("Vignette class label has to be 0 or 1. Provided value is %ld.\n", labels[i]); + exit(1); + } + + int *r = tmp.content; + for(int k = 0; k < Vignette::width * Vignette::height; k++) { + *s++ = *r++; + } } + delete vg; +} + }