X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=mi_estimator.py;h=68fd51f75e7a226b09da6e0b2667228537a848ca;hp=f8b859dd6fc07dc69f2b8fecfa2c30f4d1eeace3;hb=e916a8624b6a09737696c124f35059030f0f20e4;hpb=236238fdfe7d65612b58fbbb5bb29cff4ec45d54 diff --git a/mi_estimator.py b/mi_estimator.py index f8b859d..68fd51f 100755 --- a/mi_estimator.py +++ b/mi_estimator.py @@ -1,21 +1,9 @@ #!/usr/bin/env python -######################################################################### -# This program is free software: you can redistribute it and/or modify # -# it under the terms of the version 3 of the GNU General Public License # -# as published by the Free Software Foundation. # -# # -# This program is distributed in the hope that it will be useful, but # -# WITHOUT ANY WARRANTY; without even the implied warranty of # -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # -# General Public License for more details. # -# # -# You should have received a copy of the GNU General Public License # -# along with this program. If not, see . # -# # -# Written by and Copyright (C) Francois Fleuret # -# Contact for comments & bug reports # -######################################################################### +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret import argparse, math, sys from copy import deepcopy @@ -38,19 +26,20 @@ else: parser = argparse.ArgumentParser( description = '''An implementation of a Mutual Information estimator with a deep model -Three different toy data-sets are implemented: + Three different toy data-sets are implemented, each consists of + pairs of samples, that may be from different spaces: - (1) Two MNIST images of same class. The "true" MI is the log of the - number of used MNIST classes. + (1) Two MNIST images of same class. The "true" MI is the log of the + number of used MNIST classes. - (2) One MNIST image and a pair of real numbers whose difference is - the class of the image. The "true" MI is the log of the number of - used MNIST classes. + (2) One MNIST image and a pair of real numbers whose difference is + the class of the image. The "true" MI is the log of the number of + used MNIST classes. - (3) Two 1d sequences, the first with a single peak, the second with - two peaks, and the height of the peak in the first is the - difference of timing of the peaks in the second. The "true" MI is - the log of the number of possible peak heights.''', + (3) Two 1d sequences, the first with a single peak, the second with + two peaks, and the height of the peak in the first is the + difference of timing of the peaks in the second. The "true" MI is + the log of the number of possible peak heights.''', formatter_class = argparse.ArgumentDefaultsHelpFormatter ) @@ -197,6 +186,8 @@ def create_image_values_pairs(train = False): ###################################################################### +# + def create_sequences_pairs(train = False): nb, length = 10000, 1024 noise_level = 2e-2 @@ -229,9 +220,6 @@ def create_sequences_pairs(train = False): noise = b.new(b.size()).normal_(0, noise_level) b = b + noise - # a = (a - a.mean()) / a.std() - # b = (b - b.mean()) / b.std() - return a, b, ha ######################################################################