Skip to content

Commit bd93c86

Browse files
committed
Save the batch size
1 parent d6be499 commit bd93c86

2 files changed

Lines changed: 38 additions & 5 deletions

File tree

python-client/giskard/models/base/wrapper.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ def save(self, local_path: Union[str, Path]) -> None:
193193
self.save_data_preprocessing_function(local_path)
194194
if self.model_postprocessing_function:
195195
self.save_model_postprocessing_function(local_path)
196+
if self.batch_size:
197+
self.save_batch_size(local_path)
196198

197199
@abstractmethod
198200
def save_model(self, path: Union[str, Path]) -> None:
@@ -213,10 +215,15 @@ def save_model_postprocessing_function(self, local_path: Union[str, Path]):
213215
with open(Path(local_path) / "giskard-model-postprocessing-function.pkl", "wb") as f:
214216
cloudpickle.dump(self.model_postprocessing_function, f, protocol=pickle.DEFAULT_PROTOCOL)
215217

218+
def save_batch_size(self, local_path: Union[str, Path]):
219+
with Path(local_path).joinpath("batch_size").open("w") as f:
220+
f.write(str(self.batch_size))
221+
216222
@classmethod
217223
def load(cls, local_dir, **kwargs):
218224
kwargs["data_preprocessing_function"] = cls.load_data_preprocessing_function(local_dir)
219225
kwargs["model_postprocessing_function"] = cls.load_model_postprocessing_function(local_dir)
226+
kwargs["batch_size"] = cls._load_batch_size(local_dir)
220227
model_id, meta = cls.read_meta_from_local_dir(local_dir)
221228
constructor_params = meta.__dict__
222229
constructor_params["id"] = model_id
@@ -244,8 +251,7 @@ def load_data_preprocessing_function(cls, local_path: Union[str, Path]):
244251
if file_path.exists():
245252
with open(file_path, "rb") as f:
246253
return cloudpickle.load(f)
247-
else:
248-
return None
254+
return None
249255

250256
@classmethod
251257
def load_model_postprocessing_function(cls, local_path: Union[str, Path]):
@@ -254,5 +260,12 @@ def load_model_postprocessing_function(cls, local_path: Union[str, Path]):
254260
if file_path.exists():
255261
with open(file_path, "rb") as f:
256262
return cloudpickle.load(f)
257-
else:
258-
return None
263+
return None
264+
265+
@staticmethod
266+
def _load_batch_size(local_path: Union[str, Path]):
267+
file_path = Path(local_path) / "batch_size"
268+
if file_path.exists():
269+
with file_path.open("r") as f:
270+
return int(f.read())
271+
return None

python-client/tests/models/test_wrapper_model.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import tempfile
2+
13
import numpy as np
24
import pandas as pd
5+
from sklearn.linear_model import LogisticRegression
36

4-
from giskard import Dataset
7+
from giskard import Dataset, Model
58
from giskard.models.base.wrapper import WrapperModel
69

710

@@ -36,3 +39,20 @@ def save_model(self, path):
3639
model.expected_batch_size = [20, 1]
3740

3841
model.predict(dataset)
42+
43+
44+
def test_wrapper_model_saves_and_loads_batch_size():
45+
base_model = LogisticRegression()
46+
model = Model(
47+
base_model.predict_proba,
48+
model_type="classification",
49+
feature_names=["one", "two"],
50+
classification_threshold=0.5,
51+
classification_labels=[0, 1],
52+
)
53+
with tempfile.TemporaryDirectory() as tmpdir:
54+
model.batch_size = 127
55+
model.save(tmpdir)
56+
loaded_model = Model.load(tmpdir)
57+
58+
assert loaded_model.batch_size == 127

0 commit comments

Comments
 (0)