diff --git a/src/mednet/scripts/train_analysis.py b/src/mednet/scripts/train_analysis.py index 94867b6adf66443ff4adb902573c7482eca9c1da..7736d2cb9b362b6dbc5ba1d920bae9ff084a2a2e 100644 --- a/src/mednet/scripts/train_analysis.py +++ b/src/mednet/scripts/train_analysis.py @@ -66,6 +66,8 @@ def create_figures( import matplotlib.pyplot as plt + from matplotlib.axes import Axes + from matplotlib.figure import Figure from matplotlib.ticker import MaxNLocator figures = [] @@ -77,8 +79,8 @@ def create_figures( continue fig, ax = plt.subplots(1, 1) - ax = typing.cast(plt.Axes, ax) - fig = typing.cast(plt.Figure, fig) + ax = typing.cast(Axes, ax) + fig = typing.cast(Figure, fig) if len(curves) == 1: # there is only one curve, just plot it