Minor update.
[pysvrt.git] / svrt.c
diff --git a/svrt.c b/svrt.c
index fdee66f..4969d5a 100644 (file)
--- a/svrt.c
+++ b/svrt.c
@@ -19,7 +19,7 @@
  *  General Public License for more details.
  *
  *  You should have received a copy of the GNU General Public License
- *  along with selector.  If not, see <http://www.gnu.org/licenses/>.
+ *  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
  *
  */
 
 
 #include "svrt_generator.h"
 
-THByteTensor *generate_vignettes(long n_problem, long nb_vignettes) {
+THByteStorage *compress(THByteStorage *x) {
+  long k, g, n;
+
+  k = 0; n = 0;
+  while(k < x->size) {
+    g = 0;
+    while(k < x->size && x->data[k] == 255 && g < 255) { g++; k++; }
+    n++;
+    if(k < x->size && g < 255) { k++; }
+  }
+
+  if(x->data[k-1] == 0) {
+    n++;
+  }
+
+  THByteStorage *result = THByteStorage_newWithSize(n);
+
+  k = 0; n = 0;
+  while(k < x->size) {
+    g = 0;
+    while(k < x->size && x->data[k] == 255 && g < 255) { g++; k++; }
+    result->data[n++] = g;
+    if(k < x->size && g < 255) { k++; }
+  }
+  if(x->data[k-1] == 0) {
+    result->data[n++] = 0;
+  }
+
+  return result;
+}
+
+THByteStorage *uncompress(THByteStorage *x) {
+  long k, g, n;
+
+  k = 0;
+  for(n = 0; n < x->size - 1; n++) {
+    k = k + x->data[n];
+    if(x->data[n] < 255) { k++; }
+  }
+  k = k + x->data[n];
+
+  THByteStorage *result = THByteStorage_newWithSize(k);
+
+  k = 0;
+  for(n = 0; n < x->size - 1; n++) {
+    for(g = 0; g < x->data[n]; g++) {
+      result->data[k++] = 255;
+    }
+    if(x->data[n] < 255) {
+      result->data[k++] = 0;
+    }
+  }
+  for(g = 0; g < x->data[n]; g++) {
+    result->data[k++] = 255;
+  }
+
+  return result;
+}
+
+void seed(long s) {
+  srand48(s);
+}
+
+THByteTensor *generate_vignettes(long n_problem, THLongTensor *labels) {
   struct VignetteSet vs;
+  long nb_vignettes;
   long st0, st1, st2;
   long v, i, j;
+  long *m, *l;
   unsigned char *a, *b;
 
-  svrt_generate_vignettes(n_problem, nb_vignettes, &vs);
-  printf("SANITY %d %d %d\n", vs.nb_vignettes, vs.width, vs.height);
+  if(THLongTensor_nDimension(labels) != 1) {
+    printf("Label tensor has to be of dimension 1.\n");
+    exit(1);
+  }
+
+  nb_vignettes = THLongTensor_size(labels, 0);
+  m = THLongTensor_storage(labels)->data + THLongTensor_storageOffset(labels);
+  st0 = THLongTensor_stride(labels, 0);
+  l = (long *) malloc(sizeof(long) * nb_vignettes);
+  for(v = 0; v < nb_vignettes; v++) {
+    l[v] = *m;
+    m += st0;
+  }
+
+  svrt_generate_vignettes(n_problem, nb_vignettes, l, &vs);
+  free(l);
 
   THLongStorage *size = THLongStorage_newWithSize(3);
   size->data[0] = vs.nb_vignettes;
@@ -61,5 +140,7 @@ THByteTensor *generate_vignettes(long n_problem, long nb_vignettes) {
     }
   }
 
+  free(vs.data);
+
   return result;
 }