Skip to content
Snippets Groups Projects
Commit dafa65b0 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Load only what is annotated (allow flexibility)

parent 9a317867
No related branches found
No related tags found
1 merge request!353DFV and multiple fixes
...@@ -99,7 +99,7 @@ def validate(args): ...@@ -99,7 +99,7 @@ def validate(args):
'--batch': schema.Use(int), '--batch': schema.Use(int),
'--iterations': schema.Use(int), '--iterations': schema.Use(int),
'<database>': lambda n: n in ('fv3d',), '<database>': lambda n: n in ('fv3d',),
'<protocol>': lambda n: n in ('central',), '<protocol>': lambda n: n in ('central', 'left', 'right'),
'<group>': lambda n: n in ('dev',), '<group>': lambda n: n in ('dev',),
str: object, #ignores strings we don't care about str: object, #ignores strings we don't care about
}, ignore_extra_keys=True) }, ignore_extra_keys=True)
...@@ -142,17 +142,20 @@ def main(user_input=None): ...@@ -142,17 +142,20 @@ def main(user_input=None):
from ..preprocessor.utils import poly_to_mask from ..preprocessor.utils import poly_to_mask
features = None features = None
target = None target = None
loaded = 0
for k, sample in enumerate(objects): for k, sample in enumerate(objects):
if args['--samples'] is not None and k >= args['--samples']: break if args['--samples'] is not None and loaded >= args['--samples']:
break
path = sample.make_path(directory=db.original_directory, path = sample.make_path(directory=db.original_directory,
extension=db.original_extension) extension=db.original_extension)
logger.info('Loading sample %d/%d (%s)...', k, len(objects), path) logger.info('Loading sample %d/%d (%s)...', loaded, len(objects), path)
image = sample.load(directory=db.original_directory, image = sample.load(directory=db.original_directory,
extension=db.original_extension) extension=db.original_extension)
if not (hasattr(image, 'metadata') and 'roi' in image.metadata): if not (hasattr(image, 'metadata') and 'roi' in image.metadata):
logger.info('Skipping sample (no ROI)') logger.info('Skipping sample (no ROI)')
continue continue
loaded += 1
# copy() required by skimage.util.shape.view_as_windows() # copy() required by skimage.util.shape.view_as_windows()
image = image.copy().astype('float64') / 255. image = image.copy().astype('float64') / 255.
...@@ -169,12 +172,17 @@ def main(user_input=None): ...@@ -169,12 +172,17 @@ def main(user_input=None):
mask = mask[1:-1, 1:-1] mask = mask[1:-1, 1:-1]
for y in range(windows.shape[0]): for y in range(windows.shape[0]):
for x in range(windows.shape[1]): for x in range(windows.shape[1]):
idx = (k*windows.shape[0]*windows.shape[1]) + (y*windows.shape[1]) + x idx = ((loaded-1)*windows.shape[0]*windows.shape[1]) + \
(y*windows.shape[1]) + x
features[idx,:-2] = windows[y,x].flatten() features[idx,:-2] = windows[y,x].flatten()
features[idx,-2] = y+1 features[idx,-2] = y+1
features[idx,-1] = x+1 features[idx,-1] = x+1
target[idx] = mask[y,x] target[idx] = mask[y,x]
# if number of loaded samples is smaller than expected, clip features array
features = features[:loaded*windows.shape[0]*windows.shape[1]]
target = target[:loaded*windows.shape[0]*windows.shape[1]]
# normalize w.r.t. dimensions # normalize w.r.t. dimensions
features[:,-2] /= image.shape[0] features[:,-2] /= image.shape[0]
features[:,-1] /= image.shape[1] features[:,-1] /= image.shape[1]
...@@ -214,7 +222,7 @@ def main(user_input=None): ...@@ -214,7 +222,7 @@ def main(user_input=None):
except KeyboardInterrupt: except KeyboardInterrupt:
print() #avoids the ^C line print() #avoids the ^C line
logger.info('Gracefully stopping training before limit (%d iterations)', logger.info('Gracefully stopping training before limit (%d iterations)',
args['--batch'] args['--batch'])
break break
# describe errors # describe errors
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment