[RFC] PyTorchModel
With the current implementation of the `PyTorchModel`, both the weights and the architecture need to be provided, through `checkpoint_path` and `config`, in order to use the transformer. This constraint can be alleviated using the TorchScript scripting feature. TorchScript converts any `torch.nn.Module` into a persistent executable that can be loaded and used directly, without needing to build the architecture first. It essentially saves the "code" and the weights into a single file, in the same fashion as TensorFlow. Moreover, some optimizations can be enabled during the saving phase, such as converting all the ops into constant ops, freezing the graph, and so on.
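To illustrate the save/load round-trip described above, here is a minimal sketch assuming a recent PyTorch; the `Doubler` module and the file name are made up for the example:

```python
import torch

class Doubler(torch.nn.Module):
    def forward(self, x):
        return x * 2

# Scripting compiles the module's code together with its weights
scripted = torch.jit.script(Doubler())

# Optional: freezing inlines parameters/attributes as constants
frozen = torch.jit.freeze(scripted.eval())

torch.jit.save(frozen, "doubler.pt")

# The file can be loaded and run without rebuilding the architecture
loaded = torch.jit.load("doubler.pt")
out = loaded(torch.ones(3))
```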
Such a mechanism can be used in the `PyTorchModel` base class to greatly simplify how we add new models. An example is provided below:
```python
import numpy as np
import torch as pt
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import check_array


class TorchScriptModel(TransformerMixin, BaseEstimator):
    def __init__(
        self,
        model_path,
        preprocessor,
        memory_demanding=False,
        device=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.memory_demanding = memory_demanding

        # Model
        self.model_path = model_path
        self.model = None
        self.preprocessor = preprocessor

        if device is None:
            device = pt.device("cpu")
        self.device = device

    def _model(self):
        if self.model is None:
            # For now, we suggest disabling the JIT autocast pass,
            # see the issue: https://github.com/pytorch/pytorch/issues/75956
            pt._C._jit_set_autocast_mode(False)
            self.model = pt.jit.load(self.model_path)
            self.model.eval()
            self.model.to(self.device)
        return self.model

    def transform(self, X):
        X = check_array(X, allow_nd=True).astype(np.float32)
        model = self._model()

        def _transform(x):
            x = pt.from_numpy(x)
            with pt.no_grad():
                # Preprocess
                x = x.to(self.device)
                x = self.preprocessor(x)
                # Extract embedding
                x = model(x)
            return x.cpu().numpy()

        if self.memory_demanding:
            features = []
            for x in X:
                f = _transform(x[None, ...])
                features.append(f)
            features = np.asarray(features)
            if features.ndim >= 3:
                features = np.vstack(features)
            return features
        else:
            return _transform(X)

    def __getstate__(self):
        # Handling unpicklable objects
        d = {}
        for key, value in self.__dict__.items():
            if key != "model":
                d[key] = value
        d["model"] = None
        return d

    def _more_tags(self):
        return {"requires_fit": False}
```
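The `__getstate__` override is there because the loaded jit model generally cannot be pickled; the pattern can be demonstrated in isolation (a minimal sketch with a made-up `LazyLoader` class standing in for the transformer):

```python
import pickle

class LazyLoader:
    """Stand-in for a transformer holding an unpicklable model handle."""

    def __init__(self, path):
        self.path = path
        self.model = None  # loaded lazily, never serialized

    def __getstate__(self):
        # Drop the live model; it can be reloaded from `path` on demand
        d = {k: v for k, v in self.__dict__.items() if k != "model"}
        d["model"] = None
        return d

obj = LazyLoader("weights.pt")
obj.model = lambda x: x  # lambdas are unpicklable, like jit modules
clone = pickle.loads(pickle.dumps(obj))
```

Without the override, pickling `obj` would fail as soon as a model is attached; with it, the clone simply reloads the model on first use.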
Moreover, some architectures are defined as derived classes of `PyTorchModel`, such as `IResnetXXX`, whereas others are defined through functions returning a `PipelineSimple` instance. Consistency could be improved, for instance, by using classmethods to create a certain type of architecture. It could look something like:
```python
class TorchScriptModel(TransformerMixin, BaseEstimator):
    @classmethod
    def IResNet(cls, version: Enum):
        # Mechanism to retrieve the model from the Idiap server
        return cls(...)
```
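For concreteness, here is a runnable sketch of that factory pattern; the enum values and the cache path are invented for illustration and do not reflect the actual download mechanism:

```python
from enum import Enum

class IResNetVersion(Enum):
    IRESNET34 = "iresnet34"
    IRESNET100 = "iresnet100"

class TorchScriptModel:
    def __init__(self, model_path):
        self.model_path = model_path

    @classmethod
    def IResNet(cls, version: IResNetVersion):
        # Hypothetical: resolve the version to a checkpoint downloaded
        # from the server and cached locally
        path = f"/cache/{version.value}.pt"
        return cls(path)

model = TorchScriptModel.IResNet(IResNetVersion.IRESNET100)
```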
What are your thoughts on this @ydayer, @flavio.tarsetti, @lcolbois ?