Minor update.
[pysvrt.git] / test-svrt.py
index 92fc554..ad3677a 100755 (executable)
 #  General Public License for more details.
 #
 #  You should have received a copy of the GNU General Public License
 #  General Public License for more details.
 #
 #  You should have received a copy of the GNU General Public License
-#  along with selector.  If not, see <http://www.gnu.org/licenses/>.
+#  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
 
 import time
 
 import torch
 
 import time
 
 import torch
+import torchvision
 
 from torch import optim
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
 
 from torch import optim
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
+
 from torchvision import datasets, transforms, utils
 
 from torchvision import datasets, transforms, utils
 
-from _ext import svrt
+import svrt
+
+labels = torch.LongTensor(12).zero_()
+labels.narrow(0, 0, labels.size(0)//2).fill_(1)
+
+x = svrt.generate_vignettes(4, labels)
+
+print('compression factor {:f}'.format(x.storage().size() / svrt.compress(x.storage()).size()))
+
+x = x.view(x.size(0), 1, x.size(1), x.size(2))
+
+x.div_(255)
 
 
-train_set = svrt.generate_vignettes(12, 1234)
+torchvision.utils.save_image(x, 'example.png')
 
 
-print(str(type(train_set)), train_set.size())
+print('Wrote example.png')