1- from typing import Any , Callable , Iterable , Optional
1+ from pathlib import Path
2+ from typing import Any , Callable , Iterable , Optional , Union , Dict
23
3- import mlflow
44import pandas as pd
55
66from giskard .core .core import SupportedModelTypes
77from giskard .core .validation import configured_validate_arguments
8- from giskard .models .base import MLFlowSerializableModel
8+ from giskard .models .base import WrapperModel
99
1010
11- class LangchainModel (MLFlowSerializableModel ):
11+ class LangchainModel (WrapperModel ):
1212 @configured_validate_arguments
1313 def __init__ (
14- self ,
15- model ,
16- model_type : SupportedModelTypes ,
17- name : Optional [str ] = None ,
18- data_preprocessing_function : Optional [Callable [[pd .DataFrame ], Any ]] = None ,
19- model_postprocessing_function : Optional [Callable [[Any ], Any ]] = None ,
20- feature_names : Optional [Iterable ] = None ,
21- classification_threshold : Optional [float ] = 0.5 ,
22- classification_labels : Optional [Iterable ] = None ,
23- ** kwargs ,
14+ self ,
15+ model ,
16+ model_type : SupportedModelTypes ,
17+ name : Optional [str ] = None ,
18+ data_preprocessing_function : Optional [Callable [[pd .DataFrame ], Any ]] = None ,
19+ model_postprocessing_function : Optional [Callable [[Any ], Any ]] = None ,
20+ feature_names : Optional [Iterable ] = None ,
21+ classification_threshold : Optional [float ] = 0.5 ,
22+ classification_labels : Optional [Iterable ] = None ,
23+ ** kwargs ,
2424 ) -> None :
2525 assert (
26- model_type == SupportedModelTypes .TEXT_GENERATION
26+ model_type == SupportedModelTypes .TEXT_GENERATION
2727 ), "LangchainModel only support text_generation ModelType"
2828
2929 super ().__init__ (
@@ -38,15 +38,49 @@ def __init__(
3838 ** kwargs ,
3939 )
4040
41- def save_model (self , local_path , mlflow_meta ):
42- mlflow .langchain .save_model (self .model , path = local_path , mlflow_model = mlflow_meta )
def save(self, local_path: Union[str, Path]) -> None:
    """Persist this wrapper under ``local_path``.

    Runs the parent class's ``save`` first, then stores the wrapped chain
    through ``save_model`` and finally hands the ``artifacts`` sub-directory
    of ``local_path`` to ``save_artifacts``.
    """
    super().save(local_path)
    self.save_model(local_path)

    artifact_dir = Path(local_path) / "artifacts"
    self.save_artifacts(artifact_dir)
45+
def save_model(self, local_path: Union[str, Path]) -> None:
    """Write the wrapped chain to ``chain.json`` inside ``local_path``.

    The serialization itself is delegated to the ``save`` method of
    ``self.model``; this method only builds the target file path.
    """
    chain_file = Path(local_path) / "chain.json"
    self.model.save(chain_file)
49+
def save_artifacts(self, artifact_dir) -> None:
    """Hook for persisting extra artifacts into ``artifact_dir``.

    This implementation intentionally does nothing and returns ``None``.
    """
    return None
52+
@classmethod
def load(cls, local_dir, **kwargs):
    """Rebuild an instance from the contents of ``local_dir``.

    Constructor parameters come from ``cls.load_constructor_params``; any
    artifacts returned by ``cls.load_artifacts`` (read from the ``artifacts``
    sub-directory) are merged into those parameters and are also forwarded
    as keyword arguments to ``cls.load_model``.
    """
    constructor_params = cls.load_constructor_params(local_dir, **kwargs)

    artifacts = cls.load_artifacts(Path(local_dir) / "artifacts")
    if not artifacts:
        # A falsy result (e.g. ``None``) is treated as "no artifacts".
        artifacts = dict()
    constructor_params.update(artifacts)

    model = cls.load_model(local_dir, **artifacts)
    return cls(model=model, **constructor_params)
61+
@classmethod
def load_model(cls, local_dir, **kwargs):
    """Load the chain stored as ``chain.json`` inside ``local_dir``.

    Extra keyword arguments are passed straight through to langchain's
    ``load_chain``.
    """
    # Imported lazily, inside the method, exactly as the original does.
    from langchain.chains import load_chain

    chain_file = Path(local_dir) / "chain.json"
    return load_chain(chain_file, **kwargs)
4368
@classmethod
def load_artifacts(cls, local_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
    """Hook for loading extra artifacts from ``local_path``.

    This implementation intentionally does nothing and returns ``None``.
    """
    return None
4772
def model_predict(self, df):
    """Run the wrapped chain on every row of ``df`` and collect its outputs.

    Each row is converted to a dict (``df.to_dict("records")``) and passed to
    ``self.model``. When the chain declares exactly one output key, the value
    stored under that key is returned for each row; otherwise each row yields
    the string form of a dict restricted to the declared output keys.
    """
    chain = self.model

    generations = []
    for record in df.to_dict("records"):
        generations.append(chain(record))

    # Read after the calls, as in the original ordering.
    output_keys = chain.output_keys

    if len(output_keys) != 1:
        return [
            str({name: val for name, val in generation.items() if name in output_keys})
            for generation in generations
        ]

    only_key = output_keys[0]
    return [generation[only_key] for generation in generations]
5084
5185 def rewrite_prompt (self , template , input_variables = None , ** kwargs ):
5286 from langchain import LLMChain
@@ -70,6 +104,3 @@ def rewrite_prompt(self, template, input_variables=None, **kwargs):
70104 model_kwargs .update (kwargs )
71105
72106 return self .__class__ (chain , ** model_kwargs )
73-
74- def to_mlflow (self , artifact_path : str = "langchain-model-from-giskard" , ** kwargs ):
75- return mlflow .langchain .log_model (self .model , artifact_path , ** kwargs )
0 commit comments