diff --git a/src/mednet/scripts/train_analysis.py b/src/mednet/scripts/train_analysis.py
index 94867b6adf66443ff4adb902573c7482eca9c1da..7736d2cb9b362b6dbc5ba1d920bae9ff084a2a2e 100644
--- a/src/mednet/scripts/train_analysis.py
+++ b/src/mednet/scripts/train_analysis.py
@@ -66,6 +66,8 @@ def create_figures(
 
     import matplotlib.pyplot as plt
 
+    from matplotlib.axes import Axes
+    from matplotlib.figure import Figure
     from matplotlib.ticker import MaxNLocator
 
     figures = []
@@ -77,8 +79,8 @@ def create_figures(
             continue
 
         fig, ax = plt.subplots(1, 1)
-        ax = typing.cast(plt.Axes, ax)
-        fig = typing.cast(plt.Figure, fig)
+        ax = typing.cast(Axes, ax)
+        fig = typing.cast(Figure, fig)
 
         if len(curves) == 1:
             # there is only one curve, just plot it