Cleaning up.
[pysvrt.git] / cnn-svrt.py
index 79d3ff4..ad73f0c 100755 (executable)
@@ -65,7 +65,7 @@ args = parser.parse_args()
 
 log_file = open(args.log_file, 'w')
 
-print('Logging into ' + args.log_file)
+print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
 
 def log_string(s):
     s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + \
@@ -112,7 +112,7 @@ def train_model(train_input, train_target):
         model.cuda()
         criterion.cuda()
 
-    optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100
+    optimizer, bs = optim.Adam(model.parameters(), lr = 1e-1), 100
 
     for k in range(0, args.nb_epochs):
         acc_loss = 0.0
@@ -144,9 +144,7 @@ def nb_errors(model, data_input, data_target, bs = 100):
 
 ######################################################################
 
-# for problem_number in range(1, 24):
-
-for problem_number in [ 3 ]:
+for problem_number in range(1, 24):
     train_input, train_target = generate_set(problem_number, args.nb_train_samples)
     test_input, test_target = generate_set(problem_number, args.nb_test_samples)