X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=README.md;h=f4e07af22d33e897e1f457d43d81a78dfd3387c3;hb=de4d7faef08d682d83c075253e532af54fd39c45;hp=9e350b8fc7252c4d0ab66c134c954b5ac91e4af8;hpb=3feef9000c7201dc25b872d9a604a0faf1caca3b;p=pysvrt.git diff --git a/README.md b/README.md index 9e350b8..f4e07af 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,8 @@ # Introduction # -This is the port of the Synthetic Visual Reasoning Test to the pytorch -framework. - -The main function is - -``` -torch.ByteTensor svrt.generate_vignettes(int problem_number, torch.LongTensor labels) -``` - -where - - * `problem_number` indicates which of the 23 problem to use - * `labels` indicates the boolean labels of the vignettes to generate - -The returned ByteTensor has three dimensions: - - * Vignette index - * Pixel row - * Pixel col +This is a port of the Synthetic Visual Reasoning Test problems to the +pytorch framework, with an implementation of two convolutional +networks to solve them. # Installation and test # @@ -35,7 +19,44 @@ Note that the image generation does not take advantage of GPUs or multi-core, and can be as fast as 10,000 vignettes per second and as slow as 40 on a 4GHz i7-6700K. -# Vignette compression # +# Vignette generation and compression # + +## Vignette sets ## + +The svrtset.py implements the classes `VignetteSet` and +`CompressedVignetteSet` with the following constructor + +``` +__init__(problem_number, nb_samples, batch_size, cuda = False, logger = None) +``` + +and the following method to return one batch + +``` +(torch.FloatTensor, torch.LongTensor) get_batch(b) +``` + +as a pair composed of a 4d 'input' Tensor (i.e. single channel 128x128 +images), and a 1d 'target' Tensor (i.e. Boolean labels). + +## Low-level functions ## + +The main function for genering vignettes is + +``` +torch.ByteTensor svrt.generate_vignettes(int problem_number, torch.LongTensor labels) +``` + +where + + * `problem_number` indicates which of the 23 problem to use + * `labels` indicates the boolean labels of the vignettes to generate + +The returned ByteTensor has three dimensions: + + * Vignette index + * Pixel row + * Pixel col The two additional functions @@ -62,15 +83,7 @@ See vignette_set.py for a class CompressedVignetteSet using it. # Testing convolution networks # -The file - -``` -cnn-svrt.py -``` - -provides the implementation of two deep networks, and use the -compressed vignette code to allow the training with several millions -vignettes on a PC with 16Gb and a GPU with 8Gb. - -The networks were designed by Afroze Baqapuri during an internship at -Idiap. +The file `cnn-svrt.py` provides the implementation of two deep +networks designed by Afroze Baqapuri during an internship at Idiap, +and allows to train them with several millions vignettes on a PC with +16Gb and a GPU with 8Gb.