- def forward(self, z):
- output = self.model(z.view(z.size(0), -1, 1, 1))
+ def encode(self, x):
+ output = self.encoder(x).view(x.size(0), 2, -1)
+ mu, log_var = output[:, 0], output[:, 1]
+ return mu, log_var
+
+ def decode(self, z):
+ # return self.decoder(z.view(z.size(0), -1, 1, 1)).permute(0, 2, 3, 1)
+ output = self.decoder(z.view(z.size(0), -1, 1, 1))