X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=miniflow.py;h=04b9a23b7e372b041ffbccb97fe2d2ca9bad0436;hb=5ab8211805831629148d7b436b8770590f1987b0;hp=b5c8cb430bceab57c05136b4c26e534b56ab257e;hpb=7080726691be9341436db5a664778679600c5f62;p=pytorch.git diff --git a/miniflow.py b/miniflow.py index b5c8cb4..04b9a23 100755 --- a/miniflow.py +++ b/miniflow.py @@ -33,18 +33,15 @@ def sample_phi(nb): ###################################################################### -# START_LOG_PROBA def LogProba(x, ldj): log_p = ldj - 0.5 * (x**2 + math.log(2*pi)) return log_p -# END_LOG_PROBA ###################################################################### -# START_MODEL class PiecewiseLinear(nn.Module): def __init__(self, nb, xmin, xmax): - super(PiecewiseLinear, self).__init__() + super().__init__() self.xmin = xmin self.xmax = xmax self.nb = nb @@ -59,7 +56,6 @@ class PiecewiseLinear(nn.Module): a = (u - n).clamp(0, 1) x = (1 - a) * y[n] + a * y[n + 1] return x -# END_MODEL def invert(self, y): ys = self.alpha + self.xi.exp().cumsum(0).view(1, -1) @@ -88,7 +84,6 @@ criterion = nn.MSELoss() for k in range(nb_epochs): acc_loss = 0 -# START_OPTIMIZATION for input in train_input.split(batch_size): input.requires_grad_() output = model(input) @@ -103,7 +98,6 @@ for k in range(nb_epochs): optimizer.zero_grad() loss.backward() optimizer.step() -# END_OPTIMIZATION acc_loss += loss.item() if k%10 == 0: print(k, loss.item())