projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (from parent 1:
6f7f454
)
Update.
author
Francois Fleuret
<francois@fleuret.org>
Thu, 3 Sep 2020 06:18:03 +0000
(08:18 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Thu, 3 Sep 2020 06:18:03 +0000
(08:18 +0200)
speed.py
patch
|
blob
|
history
diff --git
a/speed.py
b/speed.py
index
e03b3b7
..
e5b0e3a
100755
(executable)
--- a/
speed.py
+++ b/
speed.py
@@
-9,30
+9,37
@@
import time, torch
if torch.cuda.is_available():
device = torch.device('cuda')
if torch.cuda.is_available():
device = torch.device('cuda')
- sync =
lambda: torch.cuda.synchronize()
+ sync =
torch.cuda.synchronize
else:
device = torch.device('cpu')
sync = lambda: None
else:
device = torch.device('cpu')
sync = lambda: None
-
nb_runs = 1000
0
+
max_duration = 3
0
d1, d2, d3 = 2048, 2048, 2048
for t in [ torch.float32, torch.float16 ]:
d1, d2, d3 = 2048, 2048, 2048
for t in [ torch.float32, torch.float16 ]:
- a = torch.rand(d1, d2, device = device, dtype = t)
- b = torch.rand(d2, d3, device = device, dtype = t)
+ try:
+ a = torch.rand(d1, d2, device = device, dtype = t)
+ b = torch.rand(d2, d3, device = device, dtype = t)
+ nb_runs = 0
- sync()
- start_time = time.perf_counter()
- for k in range(nb_runs):
- c = a @ b
- sync()
- duration = time.perf_counter() - start_time
+ sync()
+ start_time = time.perf_counter()
+ while time.perf_counter() - start_time < max_duration:
+ c = a @ b
+ nb_runs += 1
+ sync()
+ duration = time.perf_counter() - start_time
- nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops
- speed = nb_flop / duration
+
nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops
+
speed = nb_flop / duration
- for u in [ '', 'K', 'M', 'G', 'T', 'P' ]:
- if speed < 1e3: break
- speed /= 1e3
+
for u in [ '', 'K', 'M', 'G', 'T', 'P' ]:
+
if speed < 1e3: break
+
speed /= 1e3
- print(f'{speed:.02f} {u}flops with {t} on {device}')
+ print(f'{speed:.02f} {u}flops with {t} on {device}')
+
+ except:
+
+ print(f'Cannot try with {t}')