Commit a43b31fd authored by Tiago de Freitas Pereira

Merge branch 'fix-167' into 'master'

Creating a setter for AlgorithmTransformer that sets projector file

Closes #167

See merge request !269
parents 00e9b63b ebdf02a1
1 merge request: !269 Creating a setter for AlgorithmTransformer that sets projector file
Pipeline #56921 passed
@@ -358,6 +358,22 @@ def checkpoint_vanilla_biometrics(
     if isinstance(pipeline.biometric_algorithm, BioAlgorithmLegacy):
         pipeline.biometric_algorithm.base_dir = bio_ref_scores_dir
+
+        # Here we need to check whether the LAST transformer
+        # 1. is an instance of CheckpointWrapper and
+        # 2. its estimator wraps an AlgorithmTransformer
+        if (
+            isinstance(pipeline.transformer[-1], CheckpointWrapper)
+            and hasattr(pipeline.transformer[-1].estimator, "estimator")
+            and isinstance(
+                pipeline.transformer[-1].estimator.estimator, AlgorithmTransformer
+            )
+        ):
+            pipeline.transformer[
+                -1
+            ].estimator.estimator.projector_file = bio_ref_scores_dir
+
     else:
         pipeline.biometric_algorithm = BioAlgorithmCheckpointWrapper(
             pipeline.biometric_algorithm, base_dir=bio_ref_scores_dir, hash_fn=hash_fn
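For context, the hasattr/isinstance chain above assumes a nesting in which the last pipeline step is a CheckpointWrapper whose estimator holds the AlgorithmTransformer one level deeper. A minimal stand-alone sketch of that traversal follows; the classes are simplified stand-ins for illustration, not the real bob.pipelines wrappers.

class FakeAlgorithmTransformer:
    def __init__(self, projector_file):
        self.projector_file = projector_file

class FakeSampleWrapper:  # hypothetical intermediate wrapper layer
    def __init__(self, estimator):
        self.estimator = estimator

class FakeCheckpointWrapper:
    def __init__(self, estimator):
        self.estimator = estimator

last = FakeCheckpointWrapper(
    FakeSampleWrapper(FakeAlgorithmTransformer("Projector.hdf5"))
)

# This is the attribute chain the new branch walks before the assignment:
assert hasattr(last.estimator, "estimator")
last.estimator.estimator.projector_file = "bio_ref_scores_dir"  # as in the diff above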
@@ -51,7 +51,7 @@ class AlgorithmTransformer(TransformerMixin, BaseEstimator):
             raise ValueError(f"{instance} needs to be picklable")
         self.instance = instance
-        self.projector_file = projector_file
+        self._projector_file = projector_file
         super().__init__(**kwargs)

     def fit(self, X, y=None):
@@ -74,6 +74,16 @@ class AlgorithmTransformer(TransformerMixin, BaseEstimator):
             for data, metadata in zip(X, metadata)
         ]

+    @property
+    def projector_file(self):
+        return self._projector_file
+
+    @projector_file.setter
+    def projector_file(self, v):
+        base_dir = os.path.dirname(v)
+        filename = os.path.basename(self.projector_file)
+        self._projector_file = os.path.join(base_dir, filename)
+
     def _more_tags(self):
         return {
             "stateless": not self.instance.requires_projector_training,
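A minimal stand-alone illustration of what the new setter does (all paths below are made up): it keeps the file name stored at construction time and only swaps in the directory portion of the value being assigned.

import os

stored = "/old/experiment/Projector.hdf5"           # what __init__ stored in _projector_file
assigned = "/new/experiment/biometric_references"   # e.g. bio_ref_scores_dir

base_dir = os.path.dirname(assigned)                # "/new/experiment"
filename = os.path.basename(stored)                 # "Projector.hdf5"
print(os.path.join(base_dir, filename))             # "/new/experiment/Projector.hdf5"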