Test now saves an example image.
authorFrancois Fleuret <francois@fleuret.org>
Wed, 14 Jun 2017 16:06:35 +0000 (18:06 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Wed, 14 Jun 2017 16:06:35 +0000 (18:06 +0200)
svrt.c
svrt_generator.cc
test-svrt.py

diff --git a/svrt.c b/svrt.c
index 0f53642..fdee66f 100644 (file)
--- a/svrt.c
+++ b/svrt.c
 
 THByteTensor *generate_vignettes(long n_problem, long nb_vignettes) {
   struct VignetteSet vs;
 
 THByteTensor *generate_vignettes(long n_problem, long nb_vignettes) {
   struct VignetteSet vs;
+  long st0, st1, st2;
+  long v, i, j;
+  unsigned char *a, *b;
 
   svrt_generate_vignettes(n_problem, nb_vignettes, &vs);
   printf("SANITY %d %d %d\n", vs.nb_vignettes, vs.width, vs.height);
 
   THLongStorage *size = THLongStorage_newWithSize(3);
 
   svrt_generate_vignettes(n_problem, nb_vignettes, &vs);
   printf("SANITY %d %d %d\n", vs.nb_vignettes, vs.width, vs.height);
 
   THLongStorage *size = THLongStorage_newWithSize(3);
-  size->data[0] = nb_vignettes;
+  size->data[0] = vs.nb_vignettes;
   size->data[1] = vs.height;
   size->data[2] = vs.width;
 
   THByteTensor *result = THByteTensor_newWithSize(size, NULL);
   THLongStorage_free(size);
 
   size->data[1] = vs.height;
   size->data[2] = vs.width;
 
   THByteTensor *result = THByteTensor_newWithSize(size, NULL);
   THLongStorage_free(size);
 
-  /* st0 = THByteTensor_stride(result, 0); */
-  /* st1 = THByteTensor_stride(result, 1); */
-  /* st2 = THByteTensor_stride(result, 2); */
+  st0 = THByteTensor_stride(result, 0);
+  st1 = THByteTensor_stride(result, 1);
+  st2 = THByteTensor_stride(result, 2);
+
+  unsigned char *r = vs.data;
+  for(v = 0; v < vs.nb_vignettes; v++) {
+    a = THByteTensor_storage(result)->data + THByteTensor_storageOffset(result) + v * st0;
+    for(i = 0; i < vs.height; i++) {
+      b = a + i * st1;
+      for(j = 0; j < vs.width; j++) {
+        *b = (unsigned char) (*r);
+        r++;
+        b += st2;
+      }
+    }
+  }
 
   return result;
 }
 
   return result;
 }
index 82b7c3b..80cfd12 100644 (file)
@@ -145,22 +145,34 @@ VignetteGenerator *new_generator(int nb) {
 
 extern "C" {
 
 
 extern "C" {
 
-  struct VignetteSet {
-    int n_problem;
-    int nb_vignettes;
-    int width;
-    int height;
-    unsigned char *data;
-  };
-
-  void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *result) {
-    VignetteGenerator *vg = new_generator(n_problem);
-    result->n_problem = n_problem;
-    result->nb_vignettes = nb_vignettes;
-    result->width = Vignette::width;
-    result->height = Vignette::height;
-    result->data = (unsigned char *) malloc(sizeof(unsigned char) * result->nb_vignettes * result->width * result->height);
-    delete vg;
+struct VignetteSet {
+  int n_problem;
+  int nb_vignettes;
+  int width;
+  int height;
+  unsigned char *data;
+};
+
+void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *result) {
+  Vignette tmp;
+
+  VignetteGenerator *vg = new_generator(n_problem);
+  result->n_problem = n_problem;
+  result->nb_vignettes = nb_vignettes;
+  result->width = Vignette::width;
+  result->height = Vignette::height;
+  result->data = (unsigned char *) malloc(sizeof(unsigned char) * result->nb_vignettes * result->width * result->height);
+
+  unsigned char *s = result->data;
+  for(int i = 0; i < nb_vignettes; i++) {
+    vg->generate(drand48() < 0.5 ? 1 : 0, &tmp);
+    int *r = tmp.content;
+    for(int k = 0; k < Vignette::width * Vignette::height; k++) {
+      *s++ = *r++;
+    }
   }
 
   }
 
+  delete vg;
+}
+
 }
 }
index 92fc554..6b5f826 100755 (executable)
 import time
 
 import torch
 import time
 
 import torch
+import torchvision
 
 from torch import optim
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
 
 from torch import optim
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
+
 from torchvision import datasets, transforms, utils
 
 from _ext import svrt
 
 from torchvision import datasets, transforms, utils
 
 from _ext import svrt
 
-train_set = svrt.generate_vignettes(12, 1234)
+train_set = svrt.generate_vignettes(12, 64)
 
 print(str(type(train_set)), train_set.size())
 
 print(str(type(train_set)), train_set.size())
+
+train_set.div_(255)
+
+torchvision.utils.save_image(train_set.view(train_set.size(0), 1, train_set.size(1), train_set.size(2)), 'example.png')