[RFC] PyTorchModel
With the current implementation of `PyTorchModel`, the weights and the architecture need to be provided through `checkpoint_path` and `config` in order to use the transformer. This constraint can be alleviated using the TorchScript scripting feature (ref).
TorchScript converts any `torch.nn.Module` into a persistent executable that can be loaded and used directly, without needing to build the architecture first. It essentially saves the "code" and the weights into a single file, in the same fashion as TensorFlow. Moreover, some optimizations can be enabled during the saving phase, such as converting ops into constant ops, freezing the graph, and so on.
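To make this concrete, here is a minimal sketch of the save/reload round trip (the module name `TinyNet` and the file path are illustrative): the scripted file can be loaded in a process that has no access to the original class definition.

```python
import torch

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

scripted = torch.jit.script(TinyNet())    # compile code + weights together
torch.jit.save(scripted, "tiny_net.pt")   # single self-contained file

restored = torch.jit.load("tiny_net.pt")  # no TinyNet definition needed here
restored.eval()

# Optional optimization: freeze folds parameters into constants.
frozen = torch.jit.freeze(restored)

out = frozen(torch.randn(1, 4))
print(tuple(out.shape))  # (1, 2)
```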
Such a mechanism can be used in the `PyTorchModel` base class to greatly simplify how new models are added. An example is provided below:
```python
import numpy as np
import torch as pt
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import check_array


class TorchScriptModel(TransformerMixin, BaseEstimator):
    def __init__(
        self,
        model_path,
        preprocessor,
        memory_demanding=False,
        device=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.memory_demanding = memory_demanding
        # Model
        self.model_path = model_path
        self.model = None
        self.preprocessor = preprocessor
        if device is None:
            device = pt.device("cpu")
        self.device = device

    def _model(self):
        if self.model is None:
            # For now, we suggest disabling the JIT autocast pass,
            # see https://github.com/pytorch/pytorch/issues/75956
            pt._C._jit_set_autocast_mode(False)
            self.model = pt.jit.load(self.model_path)
            self.model.eval()
            self.model.to(self.device)
        return self.model

    def transform(self, X):
        X = check_array(X, allow_nd=True).astype(np.float32)
        model = self._model()

        def _transform(x):
            x = pt.from_numpy(x)
            with pt.no_grad():
                # Preprocess
                x = x.to(self.device)
                x = self.preprocessor(x)
                # Extract embedding
                x = model(x)
            return x.cpu().numpy()

        if self.memory_demanding:
            features = []
            for x in X:
                f = _transform(x[None, ...])
                features.append(f)
            features = np.asarray(features)
            if features.ndim >= 3:
                features = np.vstack(features)
            return features
        else:
            return _transform(X)

    def __getstate__(self):
        # Handle unpicklable objects (the loaded TorchScript model)
        d = {}
        for key, value in self.__dict__.items():
            if key != "model":
                d[key] = value
        d["model"] = None
        return d

    def _more_tags(self):
        return {"requires_fit": False}
```
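Stripped of the estimator boilerplate, the core of `transform` above is just this pattern (shapes and file name are illustrative): lazy `torch.jit.load`, `eval()` mode, and batched `no_grad` inference on a NumPy array.

```python
import numpy as np
import torch as pt

# Stand-in for a real embedding extractor exported to TorchScript.
pt.jit.save(pt.jit.script(pt.nn.Linear(8, 4)), "embedding.pt")

model = pt.jit.load("embedding.pt")  # no architecture code needed
model.eval()

X = np.random.rand(3, 8).astype(np.float32)
with pt.no_grad():
    features = model(pt.from_numpy(X)).cpu().numpy()
print(features.shape)  # (3, 4)
```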
Moreover, some architectures are defined as derived classes of `PyTorchModel`, such as `IResnetXXX`, whereas others are defined through functions returning a `PipelineSimple` instance. Consistency could be improved, for instance, by using a `classmethod` to create a given type of architecture. It could look something like:
```python
class TorchScriptModel(TransformerMixin, BaseEstimator):
    @classmethod
    def IResNet(cls, version: Enum):
        # Mechanism to retrieve the model from the Idiap server
        return cls(...)
```
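A fleshed-out sketch of that factory pattern could look like the following. `IResNetVersion`, the path mapping, and the simplified class body are all hypothetical; the real implementation would download the TorchScript checkpoint from the Idiap server instead of formatting a local path.

```python
from enum import Enum

class IResNetVersion(Enum):
    IRESNET34 = "iresnet34"
    IRESNET50 = "iresnet50"
    IRESNET100 = "iresnet100"

class TorchScriptModel:
    # Simplified stand-in for the full class above.
    def __init__(self, model_path, **kwargs):
        self.model_path = model_path

    @classmethod
    def IResNet(cls, version: IResNetVersion):
        # Hypothetical mapping from enum value to a cached checkpoint path.
        model_path = f"/models/iresnet/{version.value}.pt"
        return cls(model_path=model_path)

extractor = TorchScriptModel.IResNet(IResNetVersion.IRESNET50)
print(extractor.model_path)  # /models/iresnet/iresnet50.pt
```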
What are your thoughts on this, @ydayer, @flavio.tarsetti, @lcolbois?