Skip to content

Commit ae2f318

Browse files
committed
Fix conflicts
1 parent 87c6194 commit ae2f318

4 files changed

Lines changed: 31 additions & 15 deletions

File tree

python-client/giskard/models/base/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(
8181
feature_names: Optional[Iterable] = None,
8282
classification_threshold: Optional[float] = 0.5,
8383
classification_labels: Optional[Iterable] = None,
84+
id: Optional[str] = None,
8485
**kwargs,
8586
) -> None:
8687
"""
@@ -102,7 +103,7 @@ def __init__(
102103
The initialized object contains the following attributes:
103104
- meta: a ModelMeta object containing metadata about the model.
104105
"""
105-
self.id = uuid.UUID(kwargs.get("id", uuid.uuid4().hex))
106+
self.id = uuid.UUID(id) if id is not None else uuid.UUID(kwargs.get("id", uuid.uuid4().hex))
106107
if type(model_type) == str:
107108
try:
108109
model_type = SupportedModelTypes(model_type)
@@ -407,7 +408,7 @@ def download(cls, client: GiskardClient, project_key, model_id):
407408
clazz = cls.determine_model_class(meta, local_dir)
408409

409410
constructor_params = meta.__dict__
410-
constructor_params["id"] = model_id
411+
constructor_params["id"] = str(model_id)
411412

412413
del constructor_params["loader_module"]
413414
del constructor_params["loader_class"]

python-client/giskard/models/base/wrapper.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
22
import pickle
3-
import yaml
43
from abc import ABC, abstractmethod
54
from inspect import isfunction, signature
65
from pathlib import Path
@@ -9,6 +8,7 @@
98
import cloudpickle
109
import numpy as np
1110
import pandas as pd
11+
import yaml
1212

1313
from ...core.core import ModelType
1414
from ...core.validation import configured_validate_arguments
@@ -30,17 +30,18 @@ class WrapperModel(BaseModel, ABC):
3030

3131
@configured_validate_arguments
3232
def __init__(
33-
self,
34-
model: Any,
35-
model_type: ModelType,
36-
data_preprocessing_function: Optional[Callable[[pd.DataFrame], Any]] = None,
37-
model_postprocessing_function: Optional[Callable[[Any], Any]] = None,
38-
name: Optional[str] = None,
39-
feature_names: Optional[Iterable] = None,
40-
classification_threshold: Optional[float] = 0.5,
41-
classification_labels: Optional[Iterable] = None,
42-
batch_size: Optional[int] = None,
43-
**kwargs,
33+
self,
34+
model: Any,
35+
model_type: ModelType,
36+
data_preprocessing_function: Optional[Callable[[pd.DataFrame], Any]] = None,
37+
model_postprocessing_function: Optional[Callable[[Any], Any]] = None,
38+
name: Optional[str] = None,
39+
feature_names: Optional[Iterable] = None,
40+
classification_threshold: Optional[float] = 0.5,
41+
classification_labels: Optional[Iterable] = None,
42+
id: Optional[str] = None,
43+
batch_size: Optional[int] = None,
44+
**kwargs,
4445
) -> None:
4546
"""
4647
Parameters
@@ -67,7 +68,16 @@ def __init__(
6768
The batch size to use for inference. Default is ``None``, which
6869
means inference will be done on the full dataframe.
6970
"""
70-
super().__init__(model_type, name, feature_names, classification_threshold, classification_labels, **kwargs)
71+
super().__init__(
72+
model_type=model_type,
73+
name=name,
74+
feature_names=feature_names,
75+
classification_threshold=classification_threshold,
76+
classification_labels=classification_labels,
77+
id=id,
78+
batch_size=batch_size,
79+
**kwargs,
80+
)
7181
self.model = model
7282
self.data_preprocessing_function = data_preprocessing_function
7383
self.model_postprocessing_function = model_postprocessing_function

python-client/giskard/models/huggingface.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def __init__(
131131
feature_names: Optional[Iterable] = None,
132132
classification_threshold: Optional[float] = 0.5,
133133
classification_labels: Optional[Iterable] = None,
134+
id: Optional[str] = None,
134135
batch_size: Optional[int] = 1,
135136
**kwargs,
136137
) -> None:
@@ -176,6 +177,7 @@ def __init__(
176177
feature_names=feature_names,
177178
classification_threshold=classification_threshold,
178179
classification_labels=classification_labels,
180+
id=id,
179181
batch_size=batch_size,
180182
**kwargs,
181183
)

python-client/giskard/models/sklearn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(
2424
feature_names: Optional[Iterable] = None,
2525
classification_threshold: Optional[float] = 0.5,
2626
classification_labels: Optional[Iterable] = None,
27+
id: Optional[str] = None,
2728
batch_size: Optional[int] = None,
2829
**kwargs,
2930
) -> None:
@@ -45,6 +46,8 @@ def __init__(
4546
feature_names=feature_names,
4647
classification_threshold=classification_threshold,
4748
classification_labels=classification_labels,
49+
id=id,
50+
batch_size=batch_size,
4851
**kwargs,
4952
)
5053

0 commit comments

Comments
 (0)