projects
/
mygptrnn.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[mygptrnn.git]
/
main.py
diff --git
a/main.py
b/main.py
index
c51035c
..
969b47f
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-16,14
+16,6
@@
import mygpt, tasks, problems
######################################################################
######################################################################
-if torch.cuda.is_available():
- device = torch.device("cuda")
- torch.backends.cuda.matmul.allow_tf32 = True
-else:
- device = torch.device("cpu")
-
-######################################################################
-
def str2bool(x):
x = x.lower()
def str2bool(x):
x = x.lower()
@@
-55,6
+47,8
@@
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+parser.add_argument("--force_cpu", type=str2bool, default=False)
+
########################################
parser.add_argument("--nb_epochs", type=int, default=50)
########################################
parser.add_argument("--nb_epochs", type=int, default=50)
@@
-217,6
+211,14
@@
if args.result_dir is None:
######################################################################
######################################################################
+if not args.force_cpu and torch.cuda.is_available():
+ device = torch.device("cuda")
+ torch.backends.cuda.matmul.allow_tf32 = True
+else:
+ device = torch.device("cpu")
+
+######################################################################
+
default_task_args = {
"addition": {
"model": "352M",
default_task_args = {
"addition": {
"model": "352M",
@@
-832,7
+834,7
@@
if nb_epochs_finished >= nb_epochs:
deterministic_synthesis=args.deterministic_synthesis,
)
deterministic_synthesis=args.deterministic_synthesis,
)
-time_pred_result =
None
+time_pred_result =
datetime.datetime.now()
it = 0
it = 0
@@
-910,10
+912,9
@@
for n_epoch in range(nb_epochs_finished, nb_epochs):
)
time_current_result = datetime.datetime.now()
)
time_current_result = datetime.datetime.now()
- if time_pred_result is not None:
- log_string(
- f"next_result {time_current_result + (time_current_result - time_pred_result)}"
- )
+ log_string(
+ f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+ )
time_pred_result = time_current_result
checkpoint = {
time_pred_result = time_current_result
checkpoint = {