From c8d0cf6842db19f84a78c1b3a4d2666b323a5d4a Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 15 Jul 2022 17:07:47 +0200 Subject: [PATCH] Update. --- main.py | 9 ++++++--- picoclvr.py | 35 ++++++++++++++++++++--------------- result_picoclvr_0007.png | Bin 630 -> 639 bytes 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/main.py b/main.py index 85cf4cf..11cf0a3 100755 --- a/main.py +++ b/main.py @@ -25,7 +25,7 @@ parser.add_argument('--log_filename', type = str, default = 'train.log') parser.add_argument('--download', - type = bool, default = False) + action='store_true', default = False) parser.add_argument('--seed', type = int, default = 0) @@ -67,11 +67,14 @@ parser.add_argument('--dropout', type = float, default = 0.1) parser.add_argument('--synthesis_sampling', - type = bool, default = True) + action='store_true', default = True) parser.add_argument('--checkpoint_name', type = str, default = 'checkpoint.pth') +parser.add_argument('--picoclvr_many_colors', + action='store_true', default = False) + ###################################################################### args = parser.parse_args() @@ -353,7 +356,7 @@ if args.data == 'wiki103': elif args.data == 'mnist': task = TaskMNIST(batch_size = args.batch_size, device = device) elif args.data == 'picoclvr': - task = TaskPicoCLVR(batch_size = args.batch_size, device = device) + task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device) else: raise ValueError(f'Unknown dataset {args.data}.') diff --git a/picoclvr.py b/picoclvr.py index 601bdf7..6dd8114 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -71,6 +71,25 @@ color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] ) ###################################################################### +def all_properties(height, width, nb_squares, square_i, square_j, square_c): + s = [ ] + + for r, c in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]: + s += [ f'there is {c}' ] + + if square_i[r] >= height - height//3: s += [ f'{c} bottom' ] + if square_i[r] < height//3: s += [ f'{c} top' ] + if square_j[r] >= width - width//3: s += [ f'{c} right' ] + if square_j[r] < width//3: s += [ f'{c} left' ] + + for t, d in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]: + if square_i[r] > square_i[t]: s += [ f'{c} below {d}' ] + if square_i[r] < square_i[t]: s += [ f'{c} above {d}' ] + if square_j[r] > square_j[t]: s += [ f'{c} right of {d}' ] + if square_j[r] < square_j[t]: s += [ f'{c} left of {d}' ] + + return s + def generate(nb, height = 6, width = 8, max_nb_squares = 5, max_nb_statements = 10, many_colors = False): @@ -93,21 +112,7 @@ def generate(nb, height = 6, width = 8, # generates all the true relations - s = [ ] - - for r, c in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]: - s += [ f'there is {c}' ] - - if square_i[r] >= height - height//3: s += [ f'{c} bottom' ] - if square_i[r] < height//3: s += [ f'{c} top' ] - if square_j[r] >= width - width//3: s += [ f'{c} right' ] - if square_j[r] < width//3: s += [ f'{c} left' ] - - for t, d in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]: - if square_i[r] > square_i[t]: s += [ f'{c} below {d}' ] - if square_i[r] < square_i[t]: s += [ f'{c} above {d}' ] - if square_j[r] > square_j[t]: s += [ f'{c} right of {d}' ] - if square_j[r] < square_j[t]: s += [ f'{c} left of {d}' ] + s = all_properties(height, width, nb_squares, square_i, square_j, square_c) # pick at most max_nb_statements at random diff --git a/result_picoclvr_0007.png b/result_picoclvr_0007.png index e36efb698edc19ebcaf1a4a7f444287f955dba06..569bfc3729c752a8e9835c2ea1ed2bd69895da22 100644 GIT binary patch delta 614 zcmV-s0-62x1pfq(B!5OpL_t(|ob6dXcEd0T7JYrq*sh&3X0%V!%jD%UIYEyhj~P36 z&e+#M9i+-g0)!kV9&DD_20jq_0THm)n)T7s(he*`&;FSj%#a8HQox0vP95 z?$4Lm?coIO=m;*d=s`wq0UNpfXCKj6y9`&HZEO!9KiotB0u*~$ogpuC;t>nx^ol$yw8-Pc>^wkmrF)wgHahy*Wv%+NLQQ(KCcCL{ zEpMZIDE$AVwLv-BmVuMPE0vU^kP2^kd$oJ}4_&K5a?;M1& delta 605 zcmV-j0;2u@1oi}wB!4|gL_t(|ob6dzwu2xHO+0T$y8ji}HRXL+Fo8^1JDj8Ur$}bN zBvM@06|RTr+YWpOf7{QzEI`zgACCtB%<~KYfY|T|DhDDEIaCB%SB+Xsr0gyW0`*?+ zvYlY1mfhB3MnEwjf-)0;2sO^CN3*uO(aIKXMBx?JqnZOEWPj6~M?d5nH^EbD9RN_J z`d}i#q-I`c+kza@%-5Im=wX|?5mR%=!G&-A6j-~xd5`ijgZI092CUuQyeIiMfPsbj zAnN|%UAHfUQn&l+3II-J9-%hn69EbOs*XG_4B8E3xdXCpZ&WN?r=rOWMzQhJhwoC5 zdUgj&Xp^8)!+&)EHTjRqy%bm$CA6V;o3}j1cu)S3hRz``a6G!22BxAeqmQ>d#yGE` zp!STG8f34gV@mH&doBQ|VUJd3Yn4*ArEu7>hQ95sbi8?`y}#@Bh48G~%QCA+$*R2# zRaN~{)at*e+u^JIrmvz*COw+HWd@P{PmC+8-sc?hQh!}dv_0me$WW_Vzh;Os>MrAR z5_#*Q`Pxo7`Z+1HWR`3VJDd77jQFiAie!#Y3mW6VG}m0iR+rB9K=n6p@qfx_u>#eQ!6_JM(25%BDI3#O+&!*WSEWz2`So1`Odq2;41c9bccMe4<~Pk_xI44EE@B+bcgPc%l~~0?X9Aa0Ytg( r8@{m5B=Mxr3nk1Y?QY|)+ZVz=