X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=speed.py;h=075b07ee3bb54deda3dc34b02cc05074d4d39381;hp=e5b0e3a6d13c777afbe6a58bc9cacd46e35293ff;hb=HEAD;hpb=20285925e51c7adc6e7bb64bb9d2a5cab92c6aac diff --git a/speed.py b/speed.py index e5b0e3a..8363a6c 100755 --- a/speed.py +++ b/speed.py @@ -8,19 +8,19 @@ import time, torch if torch.cuda.is_available(): - device = torch.device('cuda') + device = torch.device("cuda") sync = torch.cuda.synchronize else: - device = torch.device('cpu') + device = torch.device("cpu") sync = lambda: None max_duration = 30 d1, d2, d3 = 2048, 2048, 2048 -for t in [ torch.float32, torch.float16 ]: +for t in [torch.float32, torch.float16]: try: - a = torch.rand(d1, d2, device = device, dtype = t) - b = torch.rand(d2, d3, device = device, dtype = t) + a = torch.rand(d1, d2, device=device, dtype=t) + b = torch.rand(d2, d3, device=device, dtype=t) nb_runs = 0 sync() @@ -31,15 +31,15 @@ for t in [ torch.float32, torch.float16 ]: sync() duration = time.perf_counter() - start_time - nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops + nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops speed = nb_flop / duration - for u in [ '', 'K', 'M', 'G', 'T', 'P' ]: - if speed < 1e3: break + for u in ["", "K", "M", "G", "T", "P"]: + if speed < 1e3: + break speed /= 1e3 - print(f'{speed:.02f} {u}flops with {t} on {device}') + print(f"{speed:.02f} {u}flops with {t} on {device}") except: - - print(f'Cannot try with {t}') + print(f"{t} is not available on {device}")