Commit dafa65b0 authored by André Anjos's avatar André Anjos

Load only what is annotated (allow flexibility)

parent 9a317867
......@@ -99,7 +99,7 @@ def validate(args):
'--batch': schema.Use(int),
'--iterations': schema.Use(int),
'<database>': lambda n: n in ('fv3d',),
'<protocol>': lambda n: n in ('central',),
'<protocol>': lambda n: n in ('central', 'left', 'right'),
'<group>': lambda n: n in ('dev',),
str: object, #ignores strings we don't care about
}, ignore_extra_keys=True)
......@@ -142,17 +142,20 @@ def main(user_input=None):
from ..preprocessor.utils import poly_to_mask
features = None
target = None
loaded = 0
for k, sample in enumerate(objects):
if args['--samples'] is not None and k >= args['--samples']: break
if args['--samples'] is not None and loaded >= args['--samples']:
break
path = sample.make_path(directory=db.original_directory,
extension=db.original_extension)
logger.info('Loading sample %d/%d (%s)...', k, len(objects), path)
logger.info('Loading sample %d/%d (%s)...', loaded, len(objects), path)
image = sample.load(directory=db.original_directory,
extension=db.original_extension)
if not (hasattr(image, 'metadata') and 'roi' in image.metadata):
logger.info('Skipping sample (no ROI)')
continue
loaded += 1
# copy() required by skimage.util.shape.view_as_windows()
image = image.copy().astype('float64') / 255.
......@@ -169,12 +172,17 @@ def main(user_input=None):
mask = mask[1:-1, 1:-1]
for y in range(windows.shape[0]):
for x in range(windows.shape[1]):
idx = (k*windows.shape[0]*windows.shape[1]) + (y*windows.shape[1]) + x
idx = ((loaded-1)*windows.shape[0]*windows.shape[1]) + \
(y*windows.shape[1]) + x
features[idx,:-2] = windows[y,x].flatten()
features[idx,-2] = y+1
features[idx,-1] = x+1
target[idx] = mask[y,x]
# if number of loaded samples is smaller than expected, clip features array
features = features[:loaded*windows.shape[0]*windows.shape[1]]
target = target[:loaded*windows.shape[0]*windows.shape[1]]
# normalize w.r.t. dimensions
features[:,-2] /= image.shape[0]
features[:,-1] /= image.shape[1]
......@@ -214,7 +222,7 @@ def main(user_input=None):
except KeyboardInterrupt:
print() #avoids the ^C line
logger.info('Gracefully stopping training before limit (%d iterations)',
args['--batch']
args['--batch'])
break
# describe errors
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment