[RFC] PyTorchModel

With the current implementation of PyTorchModel, both the weights and the architecture need to be provided, through checkpoint_path and config, in order to use the transformer. This constraint can be alleviated using the TorchScript scripting feature (ref).

TorchScript converts any torch.nn.Module into a persistent executable that can be loaded and used directly, without needing to build the architecture first. It essentially saves the "code" and the weights into a single file, in the same fashion as TensorFlow. Moreover, some optimizations can be enabled during the saving phase, such as folding operations into constants, freezing the graph, and so on.
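
For illustration, here is a minimal sketch of the save/load round trip (TinyBackbone and the file name backbone.pt are made up for the example):


import torch as pt
import torch.nn as nn


class TinyBackbone(nn.Module):
    """A toy module standing in for a real architecture."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 8)

    def forward(self, x):
        return self.fc(x)


model = TinyBackbone().eval()
# Compile the module into a ScriptModule.
scripted = pt.jit.script(model)
# Freezing inlines the weights as constants and applies optimizations
# such as constant folding.
frozen = pt.jit.freeze(scripted)
# The "code" and the weights end up in a single file.
frozen.save("backbone.pt")

# Later, possibly in another process, the model is restored without
# access to the TinyBackbone class definition.
restored = pt.jit.load("backbone.pt")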

Such a mechanism can be used in the PyTorchModel base class to greatly simplify how we add new models. An example is provided below:


import numpy as np
import torch as pt
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import check_array


class TorchScriptModel(TransformerMixin, BaseEstimator):

    def __init__(self,
                 model_path,
                 preprocessor,
                 memory_demanding=False,
                 device=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.memory_demanding = memory_demanding
        # Model
        self.model_path = model_path
        self.model = None
        self.preprocessor = preprocessor
        if device is None:
            device = pt.device('cpu')
        self.device = device

    def _model(self):
        if self.model is None:
            # For now, we suggest disabling the JIT autocast pass,
            # due to https://github.com/pytorch/pytorch/issues/75956
            pt._C._jit_set_autocast_mode(False)
            self.model = pt.jit.load(self.model_path)
            self.model.eval()
            self.model.to(self.device)
        return self.model

    def transform(self, X):
        X = check_array(X, allow_nd=True).astype(np.float32)
        model = self._model()

        def _transform(x):
            x = pt.from_numpy(x)
            with pt.no_grad():
                # Preprocess
                x = x.to(self.device)
                x = self.preprocessor(x)
                # Extract embedding
                x = model(x)
                return x.cpu().numpy()

        if self.memory_demanding:
            features = []
            for x in X:
                f = _transform(x[None, ...])
                features.append(f)
            features = np.asarray(features)
            if features.ndim >= 3:
                features = np.vstack(features)
            return features
        else:
            return _transform(X)

    def __getstate__(self):
        # The loaded ScriptModule is not picklable; drop it so that
        # _model lazily reloads it after unpickling.
        d = dict(self.__dict__)
        d['model'] = None
        return d

    def _more_tags(self):
        return {"requires_fit": False}
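
To make the intended use concrete, here is a hypothetical usage sketch (the file name, the input shape, and the preprocessor are placeholders):


import numpy as np


def scale_to_unit(x):
    # Placeholder preprocessing: scale pixel values to [0, 1].
    return x / 255.0


# embeddings.pt would be a TorchScript file produced beforehand.
transformer = TorchScriptModel("embeddings.pt", preprocessor=scale_to_unit)
X = np.random.randint(0, 256, size=(4, 3, 112, 112)).astype(np.float32)
embeddings = transformer.transform(X)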

Moreover, some architectures are defined as derived classes of PyTorchModel (such as IResnetXXX), whereas others are defined through functions returning a PipelineSimple instance. Consistency could be improved, for instance, by using a classmethod to create a given type of architecture. It could look something like:


from enum import Enum


class TorchScriptModel(TransformerMixin, BaseEstimator):

    @classmethod
    def IResNet(cls, version: Enum):
        # Mechanism to retrieve the model from the Idiap server
        return cls(...)
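
The call site could then look like the following (IResNetVersion and its members are hypothetical; the actual retrieval logic would live inside the factory, as sketched above):


from enum import Enum


class IResNetVersion(Enum):
    # Hypothetical identifiers; the real list would mirror the models
    # published on the Idiap server.
    IRESNET34 = "iresnet34"
    IRESNET100 = "iresnet100"


transformer = TorchScriptModel.IResNet(IResNetVersion.IRESNET100)
embeddings = transformer.transform(X)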

What are your thoughts on this, @ydayer, @flavio.tarsetti, @lcolbois?