projects
/
pytorch.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[pytorch.git]
/
minidiffusion.py
diff --git
a/minidiffusion.py
b/minidiffusion.py
index
2c54d19
..
6fd8564
100755
(executable)
--- a/
minidiffusion.py
+++ b/
minidiffusion.py
@@
-18,10
+18,14
@@
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def sample_gaussian_mixture(nb):
p, std = 0.3, 0.2
def sample_gaussian_mixture(nb):
p, std = 0.3, 0.2
- result = torch.
empty(nb, 1).normal_(0, std)
+ result = torch.
randn(nb, 1) * std
result = result + torch.sign(torch.rand(result.size()) - p) / 2
return result
result = result + torch.sign(torch.rand(result.size()) - p) / 2
return result
+def sample_ramp(nb):
+ result = torch.min(torch.rand(nb, 1), torch.rand(nb, 1))
+ return result
+
def sample_two_discs(nb):
a = torch.rand(nb) * math.pi * 2
b = torch.rand(nb).sqrt()
def sample_two_discs(nb):
a = torch.rand(nb) * math.pi * 2
b = torch.rand(nb).sqrt()
@@
-35,8
+39,9
@@
def sample_two_discs(nb):
def sample_disc_grid(nb):
a = torch.rand(nb) * math.pi * 2
b = torch.rand(nb).sqrt()
def sample_disc_grid(nb):
a = torch.rand(nb) * math.pi * 2
b = torch.rand(nb).sqrt()
- q = torch.randint(5, (nb,)) / 2.5 - 2 / 2.5
- r = torch.randint(5, (nb,)) / 2.5 - 2 / 2.5
+ N = 4
+ q = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
+ r = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
b = b * 0.1
result = torch.empty(nb, 2)
result[:, 0] = a.cos() * b + q
b = b * 0.1
result = torch.empty(nb, 2)
result[:, 0] = a.cos() * b + q
@@
-59,6
+64,7
@@
def sample_mnist(nb):
samplers = {
'gaussian_mixture': sample_gaussian_mixture,
samplers = {
'gaussian_mixture': sample_gaussian_mixture,
+ 'ramp': sample_ramp,
'two_discs': sample_two_discs,
'disc_grid': sample_disc_grid,
'spiral': sample_spiral,
'two_discs': sample_two_discs,
'disc_grid': sample_disc_grid,
'spiral': sample_spiral,
@@
-179,7
+185,7
@@
train_mean, train_std = train_input.mean(), train_input.std()
# Model
if train_input.dim() == 2:
# Model
if train_input.dim() == 2:
- nh =
64
+ nh =
256
model = nn.Sequential(
nn.Linear(train_input.size(1) + 1, nh),
model = nn.Sequential(
nn.Linear(train_input.size(1) + 1, nh),
@@
-197,6
+203,8
@@
elif train_input.dim() == 4:
model.to(device)
model.to(device)
+print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
+
######################################################################
# Train
######################################################################
# Train
@@
-228,7
+236,7
@@
for k in range(args.nb_epochs):
ema.step()
ema.step()
-
if k%10 == 0:
print(f'{k} {acc_loss / train_input.size(0)}')
+ print(f'{k} {acc_loss / train_input.size(0)}')
ema.copy()
ema.copy()
@@
-281,18
+289,20
@@
if train_input.dim() == 2:
x = generate((1000, 2), model)
x = generate((1000, 2), model)
- ax.set_xlim(-1.
25, 1.2
5)
- ax.set_ylim(-1.
25, 1.2
5)
+ ax.set_xlim(-1.
5, 1.
5)
+ ax.set_ylim(-1.
5, 1.
5)
ax.set(aspect = 1)
ax.set(aspect = 1)
-
- d = train_input[:x.size(0)].detach().to('cpu').numpy()
- ax.scatter(d[:, 0], d[:, 1],
- color = 'lightblue', label = 'Train')
+ ax.spines.right.set_visible(False)
+ ax.spines.top.set_visible(False)
d = x.detach().to('cpu').numpy()
ax.scatter(d[:, 0], d[:, 1],
facecolors = 'none', color = 'red', label = 'Synthesis')
d = x.detach().to('cpu').numpy()
ax.scatter(d[:, 0], d[:, 1],
facecolors = 'none', color = 'red', label = 'Synthesis')
+ d = train_input[:x.size(0)].detach().to('cpu').numpy()
+ ax.scatter(d[:, 0], d[:, 1],
+ s = 1.0, color = 'blue', label = 'Train')
+
ax.legend(frameon = False, loc = 2)
filename = f'diffusion_{args.data}.pdf'
ax.legend(frameon = False, loc = 2)
filename = f'diffusion_{args.data}.pdf'