projects
/
pytorch.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[pytorch.git]
/
speed.py
diff --git
a/speed.py
b/speed.py
index
e5b0e3a
..
8363a6c
100755
(executable)
--- a/
speed.py
+++ b/
speed.py
@@
-8,19
+8,19
@@
import time, torch
if torch.cuda.is_available():
import time, torch
if torch.cuda.is_available():
- device = torch.device(
'cuda'
)
+ device = torch.device(
"cuda"
)
sync = torch.cuda.synchronize
else:
sync = torch.cuda.synchronize
else:
- device = torch.device(
'cpu'
)
+ device = torch.device(
"cpu"
)
sync = lambda: None
max_duration = 30
d1, d2, d3 = 2048, 2048, 2048
sync = lambda: None
max_duration = 30
d1, d2, d3 = 2048, 2048, 2048
-for t in [
torch.float32, torch.float16
]:
+for t in [
torch.float32, torch.float16
]:
try:
try:
- a = torch.rand(d1, d2, device
= device, dtype =
t)
- b = torch.rand(d2, d3, device
= device, dtype =
t)
+ a = torch.rand(d1, d2, device
=device, dtype=
t)
+ b = torch.rand(d2, d3, device
=device, dtype=
t)
nb_runs = 0
sync()
nb_runs = 0
sync()
@@
-31,15
+31,15
@@
for t in [ torch.float32, torch.float16 ]:
sync()
duration = time.perf_counter() - start_time
sync()
duration = time.perf_counter() - start_time
- nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops
+ nb_flop = float(nb_runs * d1 * d2 * d3 * 2)
# 1 multiply-and-add is 2 ops
speed = nb_flop / duration
speed = nb_flop / duration
- for u in [ '', 'K', 'M', 'G', 'T', 'P' ]:
- if speed < 1e3: break
+ for u in ["", "K", "M", "G", "T", "P"]:
+ if speed < 1e3:
+ break
speed /= 1e3
speed /= 1e3
- print(f
'{speed:.02f} {u}flops with {t} on {device}'
)
+ print(f
"{speed:.02f} {u}flops with {t} on {device}"
)
except:
except:
-
- print(f'Cannot try with {t}')
+ print(f"{t} is not available on {device}")