11import logging
22import pickle
3+ import yaml
34from abc import ABC , abstractmethod
45from inspect import isfunction , signature
56from pathlib import Path
@@ -29,17 +30,17 @@ class WrapperModel(BaseModel, ABC):
2930
3031 @configured_validate_arguments
3132 def __init__ (
32- self ,
33- model : Any ,
34- model_type : ModelType ,
35- data_preprocessing_function : Optional [Callable [[pd .DataFrame ], Any ]] = None ,
36- model_postprocessing_function : Optional [Callable [[Any ], Any ]] = None ,
37- name : Optional [str ] = None ,
38- feature_names : Optional [Iterable ] = None ,
39- classification_threshold : Optional [float ] = 0.5 ,
40- classification_labels : Optional [Iterable ] = None ,
41- batch_size : Optional [int ] = None ,
42- ** kwargs ,
33+ self ,
34+ model : Any ,
35+ model_type : ModelType ,
36+ data_preprocessing_function : Optional [Callable [[pd .DataFrame ], Any ]] = None ,
37+ model_postprocessing_function : Optional [Callable [[Any ], Any ]] = None ,
38+ name : Optional [str ] = None ,
39+ feature_names : Optional [Iterable ] = None ,
40+ classification_threshold : Optional [float ] = 0.5 ,
41+ classification_labels : Optional [Iterable ] = None ,
42+ batch_size : Optional [int ] = None ,
43+ ** kwargs ,
4344 ) -> None :
4445 """
4546 Parameters
@@ -188,13 +189,11 @@ def model_predict(self, data):
188189
189190 def save (self , local_path : Union [str , Path ]) -> None :
190191 super ().save (local_path )
191-
192+ self . save_wrapper_meta ( local_path )
192193 if self .data_preprocessing_function :
193194 self .save_data_preprocessing_function (local_path )
194195 if self .model_postprocessing_function :
195196 self .save_model_postprocessing_function (local_path )
196- if self .batch_size :
197- self .save_batch_size (local_path )
198197
199198 @abstractmethod
200199 def save_model (self , path : Union [str , Path ]) -> None :
@@ -215,15 +214,21 @@ def save_model_postprocessing_function(self, local_path: Union[str, Path]):
215214 with open (Path (local_path ) / "giskard-model-postprocessing-function.pkl" , "wb" ) as f :
216215 cloudpickle .dump (self .model_postprocessing_function , f , protocol = pickle .DEFAULT_PROTOCOL )
217216
218- def save_batch_size (self , local_path : Union [str , Path ]):
219- with Path (local_path ).joinpath ("batch_size" ).open ("w" ) as f :
220- f .write (str (self .batch_size ))
217+ def save_wrapper_meta (self , local_path ):
218+ with open (Path (local_path ) / "giskard-model-wrapper-meta.yaml" , "w" ) as f :
219+ yaml .dump (
220+ {
221+ "batch_size" : self .batch_size ,
222+ },
223+ f ,
224+ default_flow_style = False ,
225+ )
221226
222227 @classmethod
223228 def load (cls , local_dir , ** kwargs ):
224229 kwargs ["data_preprocessing_function" ] = cls .load_data_preprocessing_function (local_dir )
225230 kwargs ["model_postprocessing_function" ] = cls .load_model_postprocessing_function (local_dir )
226- kwargs [ "batch_size" ] = cls ._load_batch_size (local_dir )
231+ kwargs . update ( cls .load_wrapper_meta (local_dir ) )
227232 model_id , meta = cls .read_meta_from_local_dir (local_dir )
228233 constructor_params = meta .__dict__
229234 constructor_params ["id" ] = model_id
@@ -262,10 +267,15 @@ def load_model_postprocessing_function(cls, local_path: Union[str, Path]):
262267 return cloudpickle .load (f )
263268 return None
264269
265- @staticmethod
266- def _load_batch_size (local_path : Union [str , Path ]):
267- file_path = Path (local_path ) / "batch_size"
268- if file_path .exists ():
269- with file_path .open ("r" ) as f :
270- return int (f .read ())
271- return None
270+ @classmethod
271+ def load_wrapper_meta (cls , local_dir ):
272+ wrapper_meta_file = Path (local_dir ) / "giskard-model-wrapper-meta.yaml"
273+ if wrapper_meta_file .exists ():
274+ with open (wrapper_meta_file ) as f :
275+ wrapper_meta = yaml .load (f , Loader = yaml .Loader )
276+ wrapper_meta ["batch_size" ] = int (wrapper_meta ["batch_size" ]) if wrapper_meta ["batch_size" ] else None
277+ return wrapper_meta
278+ else :
279+ raise ValueError (
280+ f"Cannot load model ({ cls .__module__ } .{ cls .__name__ } ), " f"{ wrapper_meta_file } file not found"
281+ )
0 commit comments