Update.
[python.git] / covid19.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import os, time
9 import numpy, csv
10 import matplotlib.pyplot as plt
11 import matplotlib.dates as mdates
12 import urllib.request
13
14 ######################################################################
15
16 def gentle_download(url, delay = 86400):
17     filename = url[url.rfind('/') + 1:]
18     if not os.path.isfile(filename) or os.path.getmtime(filename) < time.time() - delay:
19         print(f'Retrieving {url}')
20         urllib.request.urlretrieve(url, filename)
21     return filename
22
23 ######################################################################
24
25 nbcases_filename = gentle_download(
26     'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv'
27 )
28
29 ######################################################################
30
31 with open(nbcases_filename, newline='') as csvfile:
32     reader = csv.reader(csvfile, delimiter=',')
33     times = []
34     nb_cases = {}
35     time_col = 5
36     for row_nb, row in enumerate(reader):
37         for col_nb, field in enumerate(row):
38             if row_nb == 0 and col_nb >= time_col:
39                 times.append(time.mktime(time.strptime(field, '%m/%d/%y')))
40             if row_nb >= 1:
41                 if col_nb == 1:
42                     country = field
43                     if not country in nb_cases:
44                         nb_cases[country] = numpy.zeros(len(times))
45                 elif col_nb >= time_col:
46                     # if field == '': field = '0'
47                     nb_cases[country][col_nb - time_col] += int(field)
48
49 countries = list(nb_cases.keys())
50 countries.sort()
51 print('Countries: ', countries)
52
53 nb_cases['World'] = sum(nb_cases.values())
54
55 ######################################################################
56
57 fig = plt.figure()
58 ax = fig.add_subplot(1, 1, 1)
59
60 ax.yaxis.grid(color='gray', linestyle='-', linewidth=0.25)
61 ax.set_title('Nb. of COVID-19 cases')
62 ax.set_xlabel('Date', labelpad = 10)
63 ax.set_yscale('log')
64
65 myFmt = mdates.DateFormatter('%b %d')
66
67 ax.xaxis.set_major_formatter(myFmt)
68 dates = mdates.epoch2num(times)
69
70 for key, color, label in [
71         ('World', 'blue', 'World'),
72         ('Switzerland', 'red', 'Switzerland'),
73         ('France', 'lightgreen', 'France'),
74         ('US', 'black', 'USA'),
75         ('Korea, South', 'gray', 'South Korea'),
76         ('Italy', 'purple', 'Italy'),
77         ('China', 'orange', 'China')
78 ]:
79     ax.plot(dates, nb_cases[key],
80             color = color, label = label, linewidth = 2)
81
82 ax.legend(frameon = False)
83
84 plt.show()
85
86 fig.savefig('covid19_nb_cases.png')
87
88 ######################################################################