Skip to content
Snippets Groups Projects

Changes to the biogenerator

Merged Amir MOHAMMADI requested to merge predict into master
1 file
+15
-1
Compare changes
  • Side-by-side
  • Inline
@@ -153,6 +153,7 @@ def non_existing_files(paths, force=False):
@@ -153,6 +153,7 @@ def non_existing_files(paths, force=False):
def save_predictions(pool, output_dir, key, pred_buffer):
def save_predictions(pool, output_dir, key, pred_buffer):
outpath = make_output_path(output_dir, key)
outpath = make_output_path(output_dir, key)
create_directories_safe(os.path.dirname(outpath))
create_directories_safe(os.path.dirname(outpath))
 
logger.debug("Saving predictions for %s", key)
pool.apply_async(save, (np.mean(pred_buffer[key], axis=0), outpath))
pool.apply_async(save, (np.mean(pred_buffer[key], axis=0), outpath))
@@ -193,6 +194,9 @@ def main(argv=None):
@@ -193,6 +194,9 @@ def main(argv=None):
output_dir = get_from_config_or_commandline(
output_dir = get_from_config_or_commandline(
config, 'output_dir', args, defaults, False)
config, 'output_dir', args, defaults, False)
 
assert len(biofiles), "biofiles are empty!"
 
 
logger.info("number_of_parallel_jobs: %d", number_of_parallel_jobs)
if number_of_parallel_jobs > 1:
if number_of_parallel_jobs > 1:
start, end = indices(biofiles, number_of_parallel_jobs)
start, end = indices(biofiles, number_of_parallel_jobs)
biofiles = biofiles[start:end]
biofiles = biofiles[start:end]
@@ -201,7 +205,12 @@ def main(argv=None):
@@ -201,7 +205,12 @@ def main(argv=None):
paths = (make_output_path(output_dir, f.make_path("", ""))
paths = (make_output_path(output_dir, f.make_path("", ""))
for f in biofiles)
for f in biofiles)
indexes = non_existing_files(paths, force)
indexes = non_existing_files(paths, force)
biofiles = (biofiles[i] for i in indexes)
biofiles = [biofiles[i] for i in indexes]
 
 
if len(biofiles) == 0:
 
logger.warning(
 
"The biofiles are empty after checking for existing files.")
 
return
generator = BioGenerator(
generator = BioGenerator(
database, biofiles, load_data=load_data,
database, biofiles, load_data=load_data,
@@ -210,6 +219,9 @@ def main(argv=None):
@@ -210,6 +219,9 @@ def main(argv=None):
predict_input_fn = bio_predict_input_fn(
predict_input_fn = bio_predict_input_fn(
generator, generator.output_types, generator.output_shapes)
generator, generator.output_types, generator.output_shapes)
 
if checkpoint_path:
 
logger.info("Restoring the model from %s", checkpoint_path)
 
predictions = estimator.predict(
predictions = estimator.predict(
predict_input_fn,
predict_input_fn,
predict_keys=predict_keys,
predict_keys=predict_keys,
@@ -217,6 +229,8 @@ def main(argv=None):
@@ -217,6 +229,8 @@ def main(argv=None):
checkpoint_path=checkpoint_path,
checkpoint_path=checkpoint_path,
)
)
 
logger.info("Saving the predictions in %s", output_dir)
 
pool = Pool()
pool = Pool()
try:
try:
pred_buffer = defaultdict(list)
pred_buffer = defaultdict(list)
Loading