Skip to content
Snippets Groups Projects
Commit 1e3a8031 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'scatter-plot' into 'master'

Make the model optional in the scatter plot

See merge request !12
parents 5dd16f8e 57e2dc76
Branches
Tags
1 merge request!12Make the model optional in the scatter plot
Pipeline #40887 passed
...@@ -51,11 +51,14 @@ def plot_boundary_decision(algorithm, scores, score_labels, threshold, ...@@ -51,11 +51,14 @@ def plot_boundary_decision(algorithm, scores, score_labels, threshold,
xx, yy = np.meshgrid( xx, yy = np.meshgrid(
np.linspace(x_min, x_max, resolution), np.linspace(x_min, x_max, resolution),
np.linspace(y_min, y_max, resolution)) np.linspace(y_min, y_max, resolution))
temp = np.c_[xx.ravel(), yy.ravel()]
temp = algorithm.preprocess(temp)
Z = (algorithm.fuse(temp) > threshold).reshape(xx.shape)
contourf = plt.contour(xx, yy, Z, 1, alpha=1, cmap=plt.cm.gray) contourf = None
if algorithm is not None:
temp = np.c_[xx.ravel(), yy.ravel()]
temp = algorithm.preprocess(temp)
Z = (algorithm.fuse(temp) > threshold).reshape(xx.shape)
contourf = plt.contour(xx, yy, Z, 1, alpha=1, cmap=plt.cm.gray)
if do_grouping: if do_grouping:
gen = grouping(X[Y == 0, :], **kwargs) gen = grouping(X[Y == 0, :], **kwargs)
...@@ -94,9 +97,9 @@ $ bob fusion boundary -vvv {sys1,sys2}/scores-eval -m /path/to/Model.pkl ...@@ -94,9 +97,9 @@ $ bob fusion boundary -vvv {sys1,sys2}/scores-eval -m /path/to/Model.pkl
""") """)
@click.argument('scores', nargs=-1, required=True, @click.argument('scores', nargs=-1, required=True,
type=click.Path(exists=True)) type=click.Path(exists=True))
@click.option('-m', '--model-file', required=True, @click.option('-m', '--model-file', required=False,
help='The path to where the algorithm will be loaded from.') help='The path to where the algorithm will be loaded from.')
@click.option('-t', '--threshold', type=click.FLOAT, required=True, @click.option('-t', '--threshold', type=click.FLOAT, required=False,
help='The threshold to classify scores after fusion. Usually ' help='The threshold to classify scores after fusion. Usually '
'calculated from fused development set.') 'calculated from fused development set.')
@click.option('-g', '--group', type=click.INT, default=0, show_default=True, @click.option('-g', '--group', type=click.INT, default=0, show_default=True,
...@@ -127,8 +130,10 @@ def boundary(scores, model_file, threshold, group, grouping, output, x_label, ...@@ -127,8 +130,10 @@ def boundary(scores, model_file, threshold, group, grouping, output, x_label,
plotted on the x-axis. plotted on the x-axis.
""" """
# load the algorithm # load the algorithm
algorithm = Algorithm() algorithm = None
algorithm = algorithm.load(model_file) if model_file:
algorithm = Algorithm().load(model_file)
assert threshold is not None, "threshold must be provided with the model"
# load the scores # load the scores
score_lines_list_eval = [load_score(path) for path in scores] score_lines_list_eval = [load_score(path) for path in scores]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment