@@ -193,6 +193,8 @@ def save(self, local_path: Union[str, Path]) -> None:
193193 self .save_data_preprocessing_function (local_path )
194194 if self .model_postprocessing_function :
195195 self .save_model_postprocessing_function (local_path )
196+ if self .batch_size :
197+ self .save_batch_size (local_path )
196198
197199 @abstractmethod
198200 def save_model (self , path : Union [str , Path ]) -> None :
@@ -213,10 +215,15 @@ def save_model_postprocessing_function(self, local_path: Union[str, Path]):
213215 with open (Path (local_path ) / "giskard-model-postprocessing-function.pkl" , "wb" ) as f :
214216 cloudpickle .dump (self .model_postprocessing_function , f , protocol = pickle .DEFAULT_PROTOCOL )
215217
218+ def save_batch_size (self , local_path : Union [str , Path ]):
219+ with Path (local_path ).joinpath ("batch_size" ).open ("w" ) as f :
220+ f .write (str (self .batch_size ))
221+
216222 @classmethod
217223 def load (cls , local_dir , ** kwargs ):
218224 kwargs ["data_preprocessing_function" ] = cls .load_data_preprocessing_function (local_dir )
219225 kwargs ["model_postprocessing_function" ] = cls .load_model_postprocessing_function (local_dir )
226+ kwargs ["batch_size" ] = cls ._load_batch_size (local_dir )
220227 model_id , meta = cls .read_meta_from_local_dir (local_dir )
221228 constructor_params = meta .__dict__
222229 constructor_params ["id" ] = model_id
@@ -244,8 +251,7 @@ def load_data_preprocessing_function(cls, local_path: Union[str, Path]):
244251 if file_path .exists ():
245252 with open (file_path , "rb" ) as f :
246253 return cloudpickle .load (f )
247- else :
248- return None
254+ return None
249255
250256 @classmethod
251257 def load_model_postprocessing_function (cls , local_path : Union [str , Path ]):
@@ -254,5 +260,12 @@ def load_model_postprocessing_function(cls, local_path: Union[str, Path]):
254260 if file_path .exists ():
255261 with open (file_path , "rb" ) as f :
256262 return cloudpickle .load (f )
257- else :
258- return None
263+ return None
264+
265+ @staticmethod
266+ def _load_batch_size (local_path : Union [str , Path ]):
267+ file_path = Path (local_path ) / "batch_size"
268+ if file_path .exists ():
269+ with file_path .open ("r" ) as f :
270+ return int (f .read ())
271+ return None
0 commit comments