
MTCNN is not serializable

With the current implementation, MTCNN cannot be serialized with pickle. Locally this is not an issue; however, when running on the cluster, we get lovely messages such as the following in the workers' logs:

distributed.protocol.pickle - INFO - Failed to serialize CheckpointWrapper(estimator=SampleWrapper(estimator=BoundingBoxAnnotatorCrop(annotator=MTCNN(thresholds=(0.1,
                                                                                                         0.2,
                                                                                                         0.2)),
                                                                             eyes_cropper=FaceEyesNorm(final_image_size=(112,
                                                                                                                         112),
                                                                                                       reference_eyes_location={'bottomright': (112,
                                                                                                                                                112),
                                                                                                                                'leye': (55,
                                                                                                                                         72),
                                                                                                                                'reye': (55,
                                                                                                                                         40),
                                                                                                                                'topleft': (0,
                                                                                                                                            0)})),
                                          fit_extra_arguments=(),
                                          input_attribute='data',
                                          output_attribute='data',
                                          transform_...',
                                                                      'annotations'),)),
                  extension='.h5',
                  features_dir='/idiap/temp/cecabert/experiments/bob101/results-sge/ijbc/cropper',
                  hash_fn=<function hash_string at 0x7fa50d891000>,
                  load_func=<function load at 0x7fa507f5ed40>,
                  model_path='/idiap/temp/cecabert/experiments/bob101/results-sge/ijbc/cropper.pkl',
                  sample_attribute='data',
                  save_func=<function save at 0x7fa507f5ef80>). Exception: cannot pickle '_thread.RLock' object
distributed.protocol.core - CRITICAL - Failed to deserialize
Traceback (most recent call last):
  File "/remote/idiap.svm/temp.biometric03/cecabert/mambaforge/envs/bob_deps/lib/python3.10/site-packages/distributed/protocol/core.py", line 111, in loads
    return msgpack.loads(
  File "msgpack/_unpacker.pyx", line 194, in msgpack._cmsgpack.unpackb
  File "/remote/idiap.svm/temp.biometric03/cecabert/mambaforge/envs/bob_deps/lib/python3.10/site-packages/distributed/protocol/core.py", line 103, in _decode_default
    return merge_and_deserialize(
  File "/remote/idiap.svm/temp.biometric03/cecabert/mambaforge/envs/bob_deps/lib/python3.10/site-packages/distributed/protocol/serialize.py", line 488, in merge_and_deserialize
    return deserialize(header, merged_frames, deserializers=deserializers)
  File "/remote/idiap.svm/temp.biometric03/cecabert/mambaforge/envs/bob_deps/lib/python3.10/site-packages/distributed/protocol/serialize.py", line 417, in deserialize
    return loads(header, frames)
  File "/remote/idiap.svm/temp.biometric03/cecabert/mambaforge/envs/bob_deps/lib/python3.10/site-packages/distributed/protocol/serialize.py", line 180, in serialization_error_loads
    raise TypeError(msg)
TypeError: Could not serialize object of type CheckpointWrapper.
Traceback (most recent call last):
  File "/remote/idiap.svm/temp.biometric03/cecabert/mambaforge/envs/bob_deps/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 40, in dumps
    result = pickle.dumps(x, **dump_kwargs)
AttributeError: Can't pickle local object 'WeakSet.__init__.<locals>._remove'
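The failure mode can presumably be reproduced without TensorFlow or Dask. The sketch below is a hypothetical stand-in, not the actual MTCNN code: `LazyModel` and its `fun` property are invented for illustration, and a `threading.RLock` plays the role of the lazily built TF graph. The object pickles fine until the lazy attribute is populated, after which `pickle.dumps` fails with the same kind of error seen in the log.

```python
import pickle
import threading


class LazyModel:
    """Hypothetical stand-in for MTCNN: builds an unpicklable resource on first use."""

    def __init__(self, thresholds=(0.1, 0.2, 0.2)):
        self.thresholds = thresholds
        self._fun = None  # lazily created cache, analogous to MTCNN's `_fun`

    @property
    def fun(self):
        if self._fun is None:
            # An RLock stands in for the TensorFlow graph/function state.
            self._fun = threading.RLock()
        return self._fun


model = LazyModel()
pickle.dumps(model)  # works: the cache is still empty

model.fun  # populate the cache
try:
    pickle.dumps(model)
except TypeError as exc:
    # On Python 3.10 the message is "cannot pickle '_thread.RLock' object"
    print(exc)
```

This also explains why the problem only shows up on the cluster: locally nothing gets pickled, while Dask has to ship the (already initialized) estimator to the workers.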

The issue is that, in some cases, pickle tries to serialize the already loaded underlying TensorFlow graph. This can be solved with the same mechanism used in the PyTorchModel class, by overriding the __getstate__ method as follows:

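def __getstate__(self):
    # Handle unpicklable objects: copy every attribute except the cached
    # TensorFlow function, which holds on to the loaded graph.
    state = {
        key: value
        for key, value in super().__getstate__().items()
        if key != "_fun"
    }
    # Reset the cache; the TF graph is rebuilt lazily after unpickling.
    state["_fun"] = None
    return state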

With this change, the serialization now works properly and can be tested with:

import pickle

from bob.bio.face.annotator import MTCNN

mtcnn = MTCNN()
mtcnn.mtcnn_fun  # Force instantiation of the TF graph
other = pickle.loads(pickle.dumps(mtcnn))
# No "AttributeError: Can't pickle local object 'WeakSet.__init__.<locals>._remove'"
# The TF graph will be lazily re-initialized if needed
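The same pattern can be tried in isolation with a hypothetical stand-in class (`PicklableLazyModel`, invented for illustration, with a `threading.RLock` in place of the TF graph). One assumed difference from the snippet above: MTCNN presumably inherits a `__getstate__` from its base class (e.g. scikit-learn's `BaseEstimator`), which is why it calls `super().__getstate__()`; a plain class on Python < 3.11 has no such method, so this sketch copies `self.__dict__` directly.

```python
import pickle
import threading


class PicklableLazyModel:
    """Hypothetical stand-in for MTCNN that survives a pickle round trip."""

    def __init__(self, thresholds=(0.1, 0.2, 0.2)):
        self.thresholds = thresholds
        self._fun = None  # lazily built, unpicklable cache

    @property
    def fun(self):
        if self._fun is None:
            # An RLock stands in for the TensorFlow graph/function state.
            self._fun = threading.RLock()
        return self._fun

    def __getstate__(self):
        # Copy every attribute except the cache, then reset the cache so the
        # resource is rebuilt lazily in the process that unpickles the object.
        state = {k: v for k, v in self.__dict__.items() if k != "_fun"}
        state["_fun"] = None
        return state


model = PicklableLazyModel()
model.fun  # populate the cache, like `mtcnn.mtcnn_fun` above
clone = pickle.loads(pickle.dumps(model))
print(clone.thresholds, clone._fun)  # (0.1, 0.2, 0.2) None
clone.fun  # rebuilt on first access in the new copy
```

No `__setstate__` is needed here: when `__getstate__` returns a dict, pickle's default behaviour is to update the new instance's `__dict__` with it.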