Skip to content

Commit a70b89a

Browse files
committed
standardization of bd93c86
1 parent bd93c86 commit a70b89a

2 files changed

Lines changed: 58 additions & 43 deletions

File tree

python-client/giskard/models/base/wrapper.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import pickle
3+
import yaml
34
from abc import ABC, abstractmethod
45
from inspect import isfunction, signature
56
from pathlib import Path
@@ -29,17 +30,17 @@ class WrapperModel(BaseModel, ABC):
2930

3031
@configured_validate_arguments
3132
def __init__(
32-
self,
33-
model: Any,
34-
model_type: ModelType,
35-
data_preprocessing_function: Optional[Callable[[pd.DataFrame], Any]] = None,
36-
model_postprocessing_function: Optional[Callable[[Any], Any]] = None,
37-
name: Optional[str] = None,
38-
feature_names: Optional[Iterable] = None,
39-
classification_threshold: Optional[float] = 0.5,
40-
classification_labels: Optional[Iterable] = None,
41-
batch_size: Optional[int] = None,
42-
**kwargs,
33+
self,
34+
model: Any,
35+
model_type: ModelType,
36+
data_preprocessing_function: Optional[Callable[[pd.DataFrame], Any]] = None,
37+
model_postprocessing_function: Optional[Callable[[Any], Any]] = None,
38+
name: Optional[str] = None,
39+
feature_names: Optional[Iterable] = None,
40+
classification_threshold: Optional[float] = 0.5,
41+
classification_labels: Optional[Iterable] = None,
42+
batch_size: Optional[int] = None,
43+
**kwargs,
4344
) -> None:
4445
"""
4546
Parameters
@@ -188,13 +189,11 @@ def model_predict(self, data):
188189

189190
def save(self, local_path: Union[str, Path]) -> None:
    """Persist the model and its wrapper-level artifacts under *local_path*.

    Writes, in order: the base-model files (via ``super().save``), the
    wrapper meta YAML, and — only when supplied — the pickled pre- and
    post-processing callables.
    """
    super().save(local_path)
    self.save_wrapper_meta(local_path)
    # The processing callables are optional; write their pickle files only
    # when the user actually provided them.
    optional_artifacts = (
        (self.data_preprocessing_function, self.save_data_preprocessing_function),
        (self.model_postprocessing_function, self.save_model_postprocessing_function),
    )
    for artifact, saver in optional_artifacts:
        if artifact:
            saver(local_path)
196-
if self.batch_size:
197-
self.save_batch_size(local_path)
198197

199198
@abstractmethod
200199
def save_model(self, path: Union[str, Path]) -> None:
@@ -215,15 +214,21 @@ def save_model_postprocessing_function(self, local_path: Union[str, Path]):
215214
with open(Path(local_path) / "giskard-model-postprocessing-function.pkl", "wb") as f:
216215
cloudpickle.dump(self.model_postprocessing_function, f, protocol=pickle.DEFAULT_PROTOCOL)
217216

218-
def save_batch_size(self, local_path: Union[str, Path]):
219-
with Path(local_path).joinpath("batch_size").open("w") as f:
220-
f.write(str(self.batch_size))
217+
def save_wrapper_meta(self, local_path):
    """Serialize wrapper-level configuration (currently only ``batch_size``)
    to ``giskard-model-wrapper-meta.yaml`` inside *local_path*."""
    meta_path = Path(local_path) / "giskard-model-wrapper-meta.yaml"
    wrapper_meta = {
        "batch_size": self.batch_size,
    }
    with meta_path.open("w") as f:
        # block style (default_flow_style=False) keeps the file diff-friendly
        yaml.dump(wrapper_meta, f, default_flow_style=False)
221226

222227
@classmethod
223228
def load(cls, local_dir, **kwargs):
224229
kwargs["data_preprocessing_function"] = cls.load_data_preprocessing_function(local_dir)
225230
kwargs["model_postprocessing_function"] = cls.load_model_postprocessing_function(local_dir)
226-
kwargs["batch_size"] = cls._load_batch_size(local_dir)
231+
kwargs.update(cls.load_wrapper_meta(local_dir))
227232
model_id, meta = cls.read_meta_from_local_dir(local_dir)
228233
constructor_params = meta.__dict__
229234
constructor_params["id"] = model_id
@@ -262,10 +267,15 @@ def load_model_postprocessing_function(cls, local_path: Union[str, Path]):
262267
return cloudpickle.load(f)
263268
return None
264269

265-
@staticmethod
266-
def _load_batch_size(local_path: Union[str, Path]):
267-
file_path = Path(local_path) / "batch_size"
268-
if file_path.exists():
269-
with file_path.open("r") as f:
270-
return int(f.read())
271-
return None
270+
@classmethod
def load_wrapper_meta(cls, local_dir):
    """Read the wrapper-level configuration written by ``save_wrapper_meta``.

    Parameters
    ----------
    local_dir : str or Path
        Directory expected to contain ``giskard-model-wrapper-meta.yaml``.

    Returns
    -------
    dict
        The wrapper meta mapping; ``batch_size`` is normalized to ``int``
        or ``None``.

    Raises
    ------
    ValueError
        If the meta file does not exist.
    """
    wrapper_meta_file = Path(local_dir) / "giskard-model-wrapper-meta.yaml"
    if not wrapper_meta_file.exists():
        raise ValueError(
            f"Cannot load model ({cls.__module__}.{cls.__name__}), " f"{wrapper_meta_file} file not found"
        )
    with open(wrapper_meta_file) as f:
        # Security: this file only holds plain scalars, so use safe_load —
        # yaml.load with the full Loader would allow arbitrary object
        # construction from a tampered model artifact.
        # Empty file -> safe_load returns None; fall back to an empty dict
        # instead of crashing on subscription below.
        wrapper_meta = yaml.safe_load(f) or {}
    # .get avoids a KeyError on older/partial meta files; falsy values
    # (missing, None, 0) all normalize to None, matching the constructor
    # default.
    batch_size = wrapper_meta.get("batch_size")
    wrapper_meta["batch_size"] = int(batch_size) if batch_size else None
    return wrapper_meta

python-client/giskard/models/pytorch.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,20 @@ def __getitem__(self, idx):
6060

6161
class PyTorchModel(MLFlowSerializableModel):
6262
def __init__(
63-
self,
64-
model,
65-
model_type: ModelType,
66-
torch_dtype: TorchDType = "float32",
67-
device="cpu",
68-
name: Optional[str] = None,
69-
data_preprocessing_function=None,
70-
model_postprocessing_function=None,
71-
feature_names=None,
72-
classification_threshold=0.5,
73-
classification_labels=None,
74-
iterate_dataset: bool = True,
75-
batch_size: Optional[int] = None,
76-
**kwargs,
63+
self,
64+
model,
65+
model_type: ModelType,
66+
torch_dtype: TorchDType = "float32",
67+
device="cpu",
68+
name: Optional[str] = None,
69+
data_preprocessing_function=None,
70+
model_postprocessing_function=None,
71+
feature_names=None,
72+
classification_threshold=0.5,
73+
classification_labels=None,
74+
iterate_dataset: bool = True,
75+
batch_size: Optional[int] = None,
76+
**kwargs,
7777
) -> None:
7878
"""Automatically wraps a PyTorch model.
7979
@@ -214,14 +214,19 @@ def save(self, local_path: Union[str, Path]) -> None:
214214

215215
@classmethod
def load(cls, local_dir, **kwargs):
    """Reload a PyTorch model: merge the persisted torch-specific meta
    (device, dtype, dataset-iteration mode) into *kwargs*, then defer to
    the parent loader."""
    torch_meta = cls.load_pytorch_meta(local_dir)
    kwargs.update(torch_meta)
    return super().load(local_dir, **kwargs)
219+
220+
@classmethod
def load_pytorch_meta(cls, local_dir):
    """Read ``giskard-model-pytorch-meta.yaml`` saved next to the model.

    Parameters
    ----------
    local_dir : str or Path
        Directory expected to contain the torch-specific meta file.

    Returns
    -------
    dict
        Meta mapping guaranteed to contain the keys ``device``,
        ``torch_dtype`` and ``iterate_dataset`` (``None`` when absent),
        ready to be merged into the constructor kwargs.

    Raises
    ------
    ValueError
        If the meta file does not exist.
    """
    pytorch_meta_file = Path(local_dir) / "giskard-model-pytorch-meta.yaml"
    if not pytorch_meta_file.exists():
        raise ValueError(
            f"Cannot load model ({cls.__module__}.{cls.__name__}), " f"{pytorch_meta_file} file not found"
        )
    with open(pytorch_meta_file) as f:
        # Security: the meta file only holds plain scalars, so use safe_load —
        # yaml.load with the full Loader would allow arbitrary object
        # construction from a tampered model artifact.
        # Empty file -> safe_load returns None; fall back to an empty dict
        # instead of crashing below.
        pytorch_meta = yaml.safe_load(f) or {}
    # Ensure every expected key exists (None when missing) so the caller can
    # pass the dict straight into the constructor kwargs.
    for key in ("device", "torch_dtype", "iterate_dataset"):
        pytorch_meta[key] = pytorch_meta.get(key)
    return pytorch_meta

0 commit comments

Comments
 (0)