X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=speed.py;h=075b07ee3bb54deda3dc34b02cc05074d4d39381;hp=f682a169076aca29f3f7b4b695bb55fc179b305d;hb=HEAD;hpb=31b3e7141d2acb0cb1ff298d2b7fca3911889f25 diff --git a/speed.py b/speed.py index f682a16..8363a6c 100755 --- a/speed.py +++ b/speed.py @@ -1,28 +1,45 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import time, torch if torch.cuda.is_available(): - device = torch.device('cuda') + device = torch.device("cuda") + sync = torch.cuda.synchronize else: - device = torch.device('cpu') - -nb_runs = 10000 -d1, d2, d3 = 50000, 256, 512 - -a, b = torch.rand(d1, d2).to(device), torch.rand(d2, d3).to(device) - -start_time = time.perf_counter() -for k in range(nb_runs): - c = a @ b -duration = time.perf_counter() - start_time - -nb_flop = float(nb_runs * d1 * d2 * d3) -speed = nb_flop / duration - -for u in [ '', 'K', 'M', 'G', 'T', 'P' ]: - if speed < 1e3: break - speed /= 1e3 - -print(f'{speed:.02f} {u}flops on {device}') - + device = torch.device("cpu") + sync = lambda: None + +max_duration = 30 +d1, d2, d3 = 2048, 2048, 2048 + +for t in [torch.float32, torch.float16]: + try: + a = torch.rand(d1, d2, device=device, dtype=t) + b = torch.rand(d2, d3, device=device, dtype=t) + nb_runs = 0 + + sync() + start_time = time.perf_counter() + while time.perf_counter() - start_time < max_duration: + c = a @ b + nb_runs += 1 + sync() + duration = time.perf_counter() - start_time + + nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops + speed = nb_flop / duration + + for u in ["", "K", "M", "G", "T", "P"]: + if speed < 1e3: + break + speed /= 1e3 + + print(f"{speed:.02f} {u}flops with {t} on {device}") + + except: + print(f"{t} is not available on {device}")