Update.
[pytorch.git] / speed.py
index f682a16..8363a6c 100755 (executable)
--- a/speed.py
+++ b/speed.py
@@ -1,28 +1,45 @@
 #!/usr/bin/env python
 
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
 import time, torch
 
 if torch.cuda.is_available():
-    device = torch.device('cuda')
+    device = torch.device("cuda")
+    sync = torch.cuda.synchronize
 else:
-    device = torch.device('cpu')
-
-nb_runs = 10000
-d1, d2, d3 = 50000, 256, 512
-
-a, b = torch.rand(d1, d2).to(device), torch.rand(d2, d3).to(device)
-
-start_time = time.perf_counter()
-for k in range(nb_runs):
-    c = a @ b
-duration = time.perf_counter() - start_time
-
-nb_flop = float(nb_runs * d1 * d2 * d3)
-speed = nb_flop / duration
-
-for u in [ '', 'K', 'M', 'G', 'T', 'P' ]:
-    if speed < 1e3: break
-    speed /= 1e3
-
-print(f'{speed:.02f} {u}flops on {device}')
-
+    device = torch.device("cpu")
+    sync = lambda: None
+
+max_duration = 30
+d1, d2, d3 = 2048, 2048, 2048
+
+for t in [torch.float32, torch.float16]:
+    try:
+        a = torch.rand(d1, d2, device=device, dtype=t)
+        b = torch.rand(d2, d3, device=device, dtype=t)
+        nb_runs = 0
+
+        sync()
+        start_time = time.perf_counter()
+        while time.perf_counter() - start_time < max_duration:
+            c = a @ b
+            nb_runs += 1
+        sync()
+        duration = time.perf_counter() - start_time
+
+        nb_flop = float(nb_runs * d1 * d2 * d3 * 2)  # 1 multiply-and-add is 2 ops
+        speed = nb_flop / duration
+
+        for u in ["", "K", "M", "G", "T", "P"]:
+            if speed < 1e3:
+                break
+            speed /= 1e3
+
+        print(f"{speed:.02f} {u}flops with {t} on {device}")
+
+    except:
+        print(f"{t} is not available on {device}")