diff --git a/src/mednet/libs/classification/scripts/experiment.py b/src/mednet/libs/classification/scripts/experiment.py index f15f4c7d89ae7d6a54da74fc8594320d06e491c9..3a469a8c3a59458befbd9d4f64afda7b429e3ab6 100644 --- a/src/mednet/libs/classification/scripts/experiment.py +++ b/src/mednet/libs/classification/scripts/experiment.py @@ -99,10 +99,11 @@ def experiment( logger.info("Started train analysis") from mednet.libs.common.scripts.train_analysis import train_analysis + logdir = train_output_folder / "logs" ctx.invoke( train_analysis, - logdir=train_output_folder / "logs", - output_folder=output_folder / "trainlog.pdf", + logdir=logdir, + output_folder=train_output_folder, ) logger.info("Ended train analysis") diff --git a/src/mednet/libs/classification/tests/test_cli_classification.py b/src/mednet/libs/classification/tests/test_cli_classification.py index 8eb8a93c69488d1d4e37bf0650e21e1c84e12ecf..7e81d045561de23cfd8e6a4c11a52ace8e48c8ff 100644 --- a/src/mednet/libs/classification/tests/test_cli_classification.py +++ b/src/mednet/libs/classification/tests/test_cli_classification.py @@ -456,7 +456,7 @@ def test_experiment(temporary_basedir): ) == 1 ) - assert (output_folder / "trainlog.pdf").exists() + assert (output_folder / "model" / "trainlog.pdf").exists() assert ( len( list(