Update.
[pytorch.git] / speed.py
index e5b0e3a..8363a6c 100755 (executable)
--- a/speed.py
+++ b/speed.py
@@ -8,19 +8,19 @@
 import time, torch
 
 if torch.cuda.is_available():
-    device = torch.device('cuda')
+    device = torch.device("cuda")
     sync = torch.cuda.synchronize
 else:
-    device = torch.device('cpu')
+    device = torch.device("cpu")
     sync = lambda: None
 
 max_duration = 30
 d1, d2, d3 = 2048, 2048, 2048
 
-for t in [ torch.float32, torch.float16 ]:
+for t in [torch.float32, torch.float16]:
     try:
-        a = torch.rand(d1, d2, device = device, dtype = t)
-        b = torch.rand(d2, d3, device = device, dtype = t)
+        a = torch.rand(d1, d2, device=device, dtype=t)
+        b = torch.rand(d2, d3, device=device, dtype=t)
         nb_runs = 0
 
         sync()
@@ -31,15 +31,15 @@ for t in [ torch.float32, torch.float16 ]:
         sync()
         duration = time.perf_counter() - start_time
 
-        nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops
+        nb_flop = float(nb_runs * d1 * d2 * d3 * 2)  # 1 multiply-and-add is 2 ops
         speed = nb_flop / duration
 
-        for u in [ '', 'K', 'M', 'G', 'T', 'P' ]:
-            if speed < 1e3: break
+        for u in ["", "K", "M", "G", "T", "P"]:
+            if speed < 1e3:
+                break
             speed /= 1e3
 
-        print(f'{speed:.02f} {u}flops with {t} on {device}')
+        print(f"{speed:.02f} {u}flops with {t} on {device}")
 
     except:
-
-        print(f'Cannot try with {t}')
+        print(f"{t} is not available on {device}")