Skip to content

Commit 21fba7d

Browse files
Merge pull request #1490 from Giskard-AI/feature/gsk-1921-save-langchain-chain-without-ml_flow-to-reduce-compatibility
GSK-1921 & GSK-1851 & GSK-1901 - save langchain chain without MLflow to reduce compatibility issues
2 parents 0a8c342 + 868b4c0 commit 21fba7d

12 files changed

Lines changed: 1000 additions & 84 deletions

File tree

giskard/models/automodel.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
import pandas as pd
77

8-
from ..core.core import ModelType, SupportedModelTypes
98
from .base.serialization import CloudpickleSerializableModel
109
from .function import PredictionFunctionModel
10+
from ..core.core import ModelType, SupportedModelTypes
1111

1212
logger = logging.getLogger(__name__)
1313

@@ -128,8 +128,6 @@ def __new__(
128128
giskard_cls = CloudpickleSerializableModel
129129
# if save_model and load_model are overridden, replace them; if not, these equalities will be identities.
130130
possibly_overriden_cls = cls
131-
possibly_overriden_cls.save_model = giskard_cls.save_model
132-
possibly_overriden_cls.load_model = giskard_cls.load_model
133131
possibly_overriden_cls.should_save_model_class = True
134132
elif giskard_cls:
135133
input_type = "'prediction_function'" if giskard_cls == PredictionFunctionModel else "'model'"

giskard/models/base/serialization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from .wrapper import WrapperModel
99

10+
1011
# @TODO: decouple the serialization logic from models. These abstract classes
1112
# could be implemented as mixins and then used in the models that need them.
1213
# The logic of saving the model should be moved to the serialization classes.

giskard/models/base/wrapper.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
import pandas as pd
1212
import yaml
1313

14+
from .model import BaseModel
15+
from ..utils import warn_once
1416
from ...core.core import ModelType
1517
from ...core.validation import configured_validate_arguments
16-
from ..utils import warn_once
17-
from .model import BaseModel
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -242,6 +242,12 @@ def save_wrapper_meta(self, local_path):
242242

243243
@classmethod
244244
def load(cls, local_dir, **kwargs):
245+
constructor_params = cls.load_constructor_params(local_dir, **kwargs)
246+
247+
return cls(model=cls.load_model(local_dir), **constructor_params)
248+
249+
@classmethod
250+
def load_constructor_params(cls, local_dir, **kwargs):
245251
params = cls.load_wrapper_meta(local_dir)
246252
params["data_preprocessing_function"] = cls.load_data_preprocessing_function(local_dir)
247253
params["model_postprocessing_function"] = cls.load_model_postprocessing_function(local_dir)
@@ -253,7 +259,7 @@ def load(cls, local_dir, **kwargs):
253259
constructor_params = constructor_params.copy()
254260
constructor_params.update(params)
255261

256-
return cls(model=cls.load_model(local_dir), **constructor_params)
262+
return constructor_params
257263

258264
@classmethod
259265
@abstractmethod

giskard/models/huggingface.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ class explicitly using :class:`giskard.models.huggingface.HuggingFaceModel`.
103103
from giskard.core.core import ModelType
104104
from giskard.core.validation import configured_validate_arguments
105105
from giskard.models.base import WrapperModel
106-
107106
from ..client.python_utils import warning
108107

109108
try:

giskard/models/langchain.py

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
1-
from typing import Any, Callable, Iterable, Optional
1+
from pathlib import Path
2+
from typing import Any, Callable, Iterable, Optional, Union, Dict
23

3-
import mlflow
44
import pandas as pd
55

66
from giskard.core.core import SupportedModelTypes
77
from giskard.core.validation import configured_validate_arguments
8-
from giskard.models.base import MLFlowSerializableModel
8+
from giskard.models.base import WrapperModel
99

1010

11-
class LangchainModel(MLFlowSerializableModel):
11+
class LangchainModel(WrapperModel):
1212
@configured_validate_arguments
1313
def __init__(
14-
self,
15-
model,
16-
model_type: SupportedModelTypes,
17-
name: Optional[str] = None,
18-
data_preprocessing_function: Optional[Callable[[pd.DataFrame], Any]] = None,
19-
model_postprocessing_function: Optional[Callable[[Any], Any]] = None,
20-
feature_names: Optional[Iterable] = None,
21-
classification_threshold: Optional[float] = 0.5,
22-
classification_labels: Optional[Iterable] = None,
23-
**kwargs,
14+
self,
15+
model,
16+
model_type: SupportedModelTypes,
17+
name: Optional[str] = None,
18+
data_preprocessing_function: Optional[Callable[[pd.DataFrame], Any]] = None,
19+
model_postprocessing_function: Optional[Callable[[Any], Any]] = None,
20+
feature_names: Optional[Iterable] = None,
21+
classification_threshold: Optional[float] = 0.5,
22+
classification_labels: Optional[Iterable] = None,
23+
**kwargs,
2424
) -> None:
2525
assert (
26-
model_type == SupportedModelTypes.TEXT_GENERATION
26+
model_type == SupportedModelTypes.TEXT_GENERATION
2727
), "LangchainModel only support text_generation ModelType"
2828

2929
super().__init__(
@@ -38,15 +38,49 @@ def __init__(
3838
**kwargs,
3939
)
4040

41-
def save_model(self, local_path, mlflow_meta):
42-
mlflow.langchain.save_model(self.model, path=local_path, mlflow_model=mlflow_meta)
41+
def save(self, local_path: Union[str, Path]) -> None:
42+
super().save(local_path)
43+
self.save_model(local_path)
44+
self.save_artifacts(Path(local_path) / "artifacts")
45+
46+
def save_model(self, local_path: Union[str, Path]) -> None:
47+
path = Path(local_path)
48+
self.model.save(path / "chain.json")
49+
50+
def save_artifacts(self, artifact_dir) -> None:
51+
...
52+
53+
@classmethod
54+
def load(cls, local_dir, **kwargs):
55+
constructor_params = cls.load_constructor_params(local_dir, **kwargs)
56+
57+
artifacts = cls.load_artifacts(Path(local_dir) / "artifacts") or dict()
58+
constructor_params.update(artifacts)
59+
60+
return cls(model=cls.load_model(local_dir, **artifacts), **constructor_params)
61+
62+
@classmethod
63+
def load_model(cls, local_dir, **kwargs):
64+
from langchain.chains import load_chain
65+
66+
path = Path(local_dir)
67+
return load_chain(path / "chain.json", **kwargs)
4368

4469
@classmethod
45-
def load_model(cls, local_dir):
46-
return mlflow.langchain.load_model(local_dir)
70+
def load_artifacts(cls, local_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
71+
...
4772

4873
def model_predict(self, df):
49-
return [self.model.predict(**data) for data in df.to_dict("records")]
74+
generations = [self.model(data) for data in df.to_dict("records")]
75+
output_keys = self.model.output_keys
76+
77+
if len(output_keys) == 1:
78+
return [generation[output_keys[0]] for generation in generations]
79+
else:
80+
return [
81+
str({key: value for key, value in generation.items() if key in output_keys})
82+
for generation in generations
83+
]
5084

5185
def rewrite_prompt(self, template, input_variables=None, **kwargs):
5286
from langchain import LLMChain
@@ -70,6 +104,3 @@ def rewrite_prompt(self, template, input_variables=None, **kwargs):
70104
model_kwargs.update(kwargs)
71105

72106
return self.__class__(chain, **model_kwargs)
73-
74-
def to_mlflow(self, artifact_path: str = "langchain-model-from-giskard", **kwargs):
75-
return mlflow.langchain.log_model(self.model, artifact_path, **kwargs)

giskard/models/pytorch.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from torch.utils.data import DataLoader
1111
from torch.utils.data import Dataset as torch_dataset
1212

13-
from ..client.python_utils import warning
14-
from ..core.core import ModelType
1513
from .base.serialization import MLFlowSerializableModel
1614
from .utils import map_to_tuples
15+
from ..client.python_utils import warning
16+
from ..core.core import ModelType
1717

1818
TorchDType = Literal[
1919
"float32",
@@ -60,21 +60,21 @@ def __getitem__(self, idx):
6060

6161
class PyTorchModel(MLFlowSerializableModel):
6262
def __init__(
63-
self,
64-
model,
65-
model_type: ModelType,
66-
torch_dtype: TorchDType = "float32",
67-
device="cpu",
68-
name: Optional[str] = None,
69-
data_preprocessing_function=None,
70-
model_postprocessing_function=None,
71-
feature_names=None,
72-
classification_threshold=0.5,
73-
classification_labels=None,
74-
iterate_dataset: bool = True,
75-
id: Optional[str] = None,
76-
batch_size: Optional[int] = None,
77-
**kwargs,
63+
self,
64+
model,
65+
model_type: ModelType,
66+
torch_dtype: TorchDType = "float32",
67+
device="cpu",
68+
name: Optional[str] = None,
69+
data_preprocessing_function=None,
70+
model_postprocessing_function=None,
71+
feature_names=None,
72+
classification_threshold=0.5,
73+
classification_labels=None,
74+
iterate_dataset: bool = True,
75+
id: Optional[str] = None,
76+
batch_size: Optional[int] = None,
77+
**kwargs,
7878
) -> None:
7979
"""Automatically wraps a PyTorch model.
8080

giskard/models/sklearn.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import mlflow
44
import pandas as pd
55

6+
from .base.serialization import MLFlowSerializableModel
67
from ..core.core import ModelType, SupportedModelTypes
78
from ..core.validation import configured_validate_arguments
8-
from .base.serialization import MLFlowSerializableModel
99

1010

1111
class SKLearnModel(MLFlowSerializableModel):
@@ -15,18 +15,18 @@ class SKLearnModel(MLFlowSerializableModel):
1515

1616
@configured_validate_arguments
1717
def __init__(
18-
self,
19-
model,
20-
model_type: ModelType,
21-
name: Optional[str] = None,
22-
data_preprocessing_function: Optional[Callable[[pd.DataFrame], Any]] = None,
23-
model_postprocessing_function: Optional[Callable[[Any], Any]] = None,
24-
feature_names: Optional[Iterable] = None,
25-
classification_threshold: Optional[float] = 0.5,
26-
classification_labels: Optional[Iterable] = None,
27-
id: Optional[str] = None,
28-
batch_size: Optional[int] = None,
29-
**kwargs,
18+
self,
19+
model,
20+
model_type: ModelType,
21+
name: Optional[str] = None,
22+
data_preprocessing_function: Optional[Callable[[pd.DataFrame], Any]] = None,
23+
model_postprocessing_function: Optional[Callable[[Any], Any]] = None,
24+
feature_names: Optional[Iterable] = None,
25+
classification_threshold: Optional[float] = 0.5,
26+
classification_labels: Optional[Iterable] = None,
27+
id: Optional[str] = None,
28+
batch_size: Optional[int] = None,
29+
**kwargs,
3030
) -> None:
3131
model_type = SupportedModelTypes(model_type) if isinstance(model_type, str) else model_type
3232
if model_type == SupportedModelTypes.CLASSIFICATION:
@@ -75,9 +75,7 @@ def model_predict(self, df):
7575
else:
7676
return self.model.predict_proba(df)
7777

78-
def to_mlflow(self,
79-
artifact_path="sklearn-model-from-giskard",
80-
**kwargs):
81-
return mlflow.sklearn.log_model(sk_model=self.model, artifact_path=artifact_path,
82-
pyfunc_predict_fn=self._get_pyfunc_predict_fn(),
83-
**kwargs)
78+
def to_mlflow(self, artifact_path="sklearn-model-from-giskard", **kwargs):
79+
return mlflow.sklearn.log_model(
80+
sk_model=self.model, artifact_path=artifact_path, pyfunc_predict_fn=self._get_pyfunc_predict_fn(), **kwargs
81+
)

giskard/models/tensorflow.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,27 @@
44
import mlflow
55
import pandas as pd
66

7+
from .base import MLFlowSerializableModel
78
from ..core.core import ModelType
89
from ..core.validation import configured_validate_arguments
9-
from .base import MLFlowSerializableModel
1010

1111
logger = logging.getLogger(__name__)
1212

1313

1414
class TensorFlowModel(MLFlowSerializableModel):
1515
@configured_validate_arguments
1616
def __init__(
17-
self,
18-
model,
19-
model_type: ModelType,
20-
name: Optional[str] = None,
21-
data_preprocessing_function: Callable[[pd.DataFrame], Any] = None,
22-
model_postprocessing_function: Callable[[Any], Any] = None,
23-
feature_names: Optional[Iterable] = None,
24-
classification_threshold: Optional[float] = 0.5,
25-
classification_labels: Optional[Iterable] = None,
26-
id: Optional[str] = None,
27-
**kwargs,
17+
self,
18+
model,
19+
model_type: ModelType,
20+
name: Optional[str] = None,
21+
data_preprocessing_function: Callable[[pd.DataFrame], Any] = None,
22+
model_postprocessing_function: Callable[[Any], Any] = None,
23+
feature_names: Optional[Iterable] = None,
24+
classification_threshold: Optional[float] = 0.5,
25+
classification_labels: Optional[Iterable] = None,
26+
id: Optional[str] = None,
27+
**kwargs,
2828
):
2929
super().__init__(
3030
model=model,

0 commit comments

Comments
 (0)