projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
a4bc783
)
OCD update.
author
Francois Fleuret
<francois@fleuret.org>
Mon, 22 Jun 2020 07:59:39 +0000
(09:59 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Mon, 22 Jun 2020 07:59:39 +0000
(09:59 +0200)
ddpol.py
patch
|
blob
|
history
diff --git
a/ddpol.py
b/ddpol.py
index
51e7636
..
35d98a0
100755
(executable)
--- a/
ddpol.py
+++ b/
ddpol.py
@@
-83,12
+83,6
@@
def compute_mse(nb_train_samples):
return mse_train.median(0).values, mse_test.median(0).values
return mse_train.median(0).values, mse_test.median(0).values
-######################################################################
-
-torch.manual_seed(0)
-
-mse_train, mse_test = compute_mse(args.nb_train_samples)
-
######################################################################
# Plot the MSE vs. degree curves
######################################################################
# Plot the MSE vs. degree curves
@@
-100,7
+94,14
@@
ax.set_ylim(1e-5, 1)
ax.set_xlabel('Polynomial degree', labelpad = 10)
ax.set_ylabel('MSE', labelpad = 10)
ax.set_xlabel('Polynomial degree', labelpad = 10)
ax.set_ylabel('MSE', labelpad = 10)
-ax.axvline(x = args.nb_train_samples - 1, color = 'gray', linewidth = 0.5)
+ax.axvline(x = args.nb_train_samples - 1,
+ color = 'gray', linewidth = 0.5, linestyle = '--')
+ax.text(args.nb_train_samples - 1.2, 1e-4, 'Nb. params = nb. samples',
+ fontsize = 10, color = 'gray',
+ rotation = 90, rotation_mode='anchor')
+
+mse_train, mse_test = compute_mse(args.nb_train_samples)
+
ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train error')
ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test error')
ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train error')
ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test error')