Test now saves an example image.
[pysvrt.git] / svrt.c
1
2 /*
3  *  svrt is the ``Synthetic Visual Reasoning Test'', an image
4  *  generator for evaluating classification performance of machine
5  *  learning systems, humans and primates.
6  *
7  *  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9  *
10  *  This file is part of svrt.
11  *
12  *  svrt is free software: you can redistribute it and/or modify it
13  *  under the terms of the GNU General Public License version 3 as
14  *  published by the Free Software Foundation.
15  *
16  *  svrt is distributed in the hope that it will be useful, but
17  *  WITHOUT ANY WARRANTY; without even the implied warranty of
18  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  *  General Public License for more details.
20  *
21  *  You should have received a copy of the GNU General Public License
22  *  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23  *
24  */
25
26 #include <TH/TH.h>
27
28 #include "svrt_generator.h"
29
30 THByteTensor *generate_vignettes(long n_problem, long nb_vignettes) {
31   struct VignetteSet vs;
32   long st0, st1, st2;
33   long v, i, j;
34   unsigned char *a, *b;
35
36   svrt_generate_vignettes(n_problem, nb_vignettes, &vs);
37   printf("SANITY %d %d %d\n", vs.nb_vignettes, vs.width, vs.height);
38
39   THLongStorage *size = THLongStorage_newWithSize(3);
40   size->data[0] = vs.nb_vignettes;
41   size->data[1] = vs.height;
42   size->data[2] = vs.width;
43
44   THByteTensor *result = THByteTensor_newWithSize(size, NULL);
45   THLongStorage_free(size);
46
47   st0 = THByteTensor_stride(result, 0);
48   st1 = THByteTensor_stride(result, 1);
49   st2 = THByteTensor_stride(result, 2);
50
51   unsigned char *r = vs.data;
52   for(v = 0; v < vs.nb_vignettes; v++) {
53     a = THByteTensor_storage(result)->data + THByteTensor_storageOffset(result) + v * st0;
54     for(i = 0; i < vs.height; i++) {
55       b = a + i * st1;
56       for(j = 0; j < vs.width; j++) {
57         *b = (unsigned char) (*r);
58         r++;
59         b += st2;
60       }
61     }
62   }
63
64   return result;
65 }