svrt.generate_vignettes now takes a 1d label tensor as arguments.
[pysvrt.git] / svrt_generator.cc
index 80cfd12..90f781d 100644 (file)
@@ -153,7 +153,8 @@ struct VignetteSet {
   unsigned char *data;
 };
 
-void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *result) {
+void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
+                             VignetteSet *result) {
   Vignette tmp;
 
   VignetteGenerator *vg = new_generator(n_problem);
@@ -165,7 +166,7 @@ void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *resul
 
   unsigned char *s = result->data;
   for(int i = 0; i < nb_vignettes; i++) {
-    vg->generate(drand48() < 0.5 ? 1 : 0, &tmp);
+    vg->generate(labels[i], &tmp);
     int *r = tmp.content;
     for(int k = 0; k < Vignette::width * Vignette::height; k++) {
       *s++ = *r++;