Skip to content

Commit e4b7d5d

Browse files
committed
Adapt code to make it compatible with pydantic 2.0
1 parent 0438629 commit e4b7d5d

14 files changed

Lines changed: 45 additions & 32 deletions

File tree

.github/workflows/build_backend.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,8 @@ jobs:
221221
- name: Check Pydantic installed version
222222
working-directory: python-client
223223
run: |
224-
pdm run pip freeze | tail -n +1 | grep '^pydantic'
225-
pdm run pip freeze | tail -n +1 | grep -q '^pydantic==${{ matrix.pydantic_v1 && '1' || '2' }}\.'
224+
pdm run pip freeze | grep '^pydantic'
225+
pdm run pip freeze | grep -q '^pydantic==${{ matrix.pydantic_v1 && '1' || '2' }}\.'
226226
227227
- name: Test code
228228
working-directory: python-client

python-client/giskard/core/validation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,12 @@ def configured_validate_arguments(func):
2222
# If you check https://docs.pydantic.dev/latest/usage/validation_decorator/#coercion-and-strictness,
2323
# this explains it will try to convert/coerce type to the type hinting
2424
# So a string will be "coerced" to an enum element, and so on
25-
to_return = functools.wraps(func)(validate_arguments(config={"arbitrary_types_allowed":True})(func))
26-
return to_return
25+
26+
# Add validation wrapper
27+
validated_func = validate_arguments(func, config={"arbitrary_types_allowed":True})
28+
# Call wraps, to update name, docs, ...
29+
validated_func = functools.wraps(func)(validated_func)
30+
return func
2731

2832

2933
def validate_is_pandasdataframe(df):

python-client/giskard/llm/talk/talk.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,17 @@
4444
Thought: I should look at the features that are required to see what I should gather to predict in the available tools
4545
{agent_scratchpad}"""
4646

47+
class LenientBaseToolkit(BaseToolkit):
48+
"""Extended class to allow arbitrary_types_allowed for pydantic compatibility"""
49+
class Config:
50+
arbitrary_types_allowed = True
51+
4752

4853
class ModelSpec(BaseModel):
4954
"""Base class for model spec."""
50-
5155
model: GiskardBaseModel
52-
dataset: Optional[Dataset]
53-
scan_report: Optional[ScanReport]
56+
dataset: Optional[Dataset] = None
57+
scan_report: Optional[ScanReport] = None
5458

5559
def _parse_json_inputs(self, json_inputs: str) -> Dict[str, Any]:
5660
features = json.loads(json_inputs)
@@ -242,7 +246,7 @@ async def _arun(
242246
return self._run(tool_input)
243247

244248

245-
class ModelToolkit(BaseToolkit):
249+
class ModelToolkit(LenientBaseToolkit):
246250
"""Toolkit for interacting with an ML model."""
247251

248252
spec: ModelSpec

python-client/giskard/ml_worker/testing/registry/decorators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Callable, Optional, List, Union, Type, TypeVar
55

66
from giskard.core.core import TestFunctionMeta
7-
from giskard.core.validation import configured_validate_arguments
87
from giskard.ml_worker.testing.registry.decorators_utils import make_all_optional_or_suite_input, set_return_type
98
from giskard.ml_worker.testing.registry.giskard_test import GiskardTestMethod, GiskardTest
109

@@ -45,4 +44,5 @@ def _wrap_test_method(original):
4544
make_all_optional_or_suite_input(giskard_test_method)
4645
set_return_type(giskard_test_method, GiskardTestMethod)
4746

48-
return configured_validate_arguments(giskard_test_method)()
47+
## why a copy ?
48+
return giskard_test_method #configured_validate_arguments(giskard_test_method)#()

python-client/giskard/ml_worker/testing/registry/giskard_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import cloudpickle
1010

1111
from giskard.core.core import TestFunctionMeta, SMT
12+
from giskard.core.validation import configured_validate_arguments
1213
from giskard.ml_worker.core.savable import Artifact
1314
from giskard.ml_worker.testing.registry.registry import tests_registry, get_object_uuid
1415
from giskard.ml_worker.testing.test_result import TestResult
@@ -104,7 +105,7 @@ class GiskardTestMethod(GiskardTest):
104105
def __init__(self, test_fn: Function) -> None:
105106
self.params = {}
106107
self.is_initialized = False
107-
self.test_fn = test_fn
108+
self.test_fn = configured_validate_arguments(test_fn)
108109
test_uuid = get_object_uuid(test_fn)
109110
meta = tests_registry.get_test(test_uuid)
110111
if meta is None:

python-client/giskard/ml_worker/testing/registry/slicing_function.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self, func: Optional[SlicingFunctionType], row_level=True, cell_lev
6060
"""
6161
self.is_initialized = False
6262
self.params = {}
63-
self.func = func
63+
self.func = configured_validate_arguments(func)
6464
self.row_level = row_level
6565
self.cell_level = cell_level
6666

@@ -147,7 +147,7 @@ def inner(func: Union[SlicingFunctionType, Type[SlicingFunction]]) -> SlicingFun
147147
if inspect.isclass(func) and issubclass(func, SlicingFunction):
148148
return func
149149

150-
return _wrap_slicing_function(func, row_level, cell_level)()
150+
return _wrap_slicing_function(func, row_level, cell_level)#()
151151

152152
if callable(_fn):
153153
return functools.wraps(_fn)(inner(_fn))
@@ -166,4 +166,4 @@ def _wrap_slicing_function(original: Callable, row_level: bool, cell_level: bool
166166
make_all_optional_or_suite_input(slicing_fn)
167167
set_return_type(slicing_fn, SlicingFunction)
168168

169-
return configured_validate_arguments(slicing_fn)
169+
return slicing_fn #configured_validate_arguments(slicing_fn)

python-client/giskard/ml_worker/testing/registry/transformation_function.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def _get_name(cls) -> str:
3232
return "transformations"
3333

3434
def __init__(self, func: Optional[TransformationFunctionType], row_level=True, cell_level=False):
35-
self.func = func
35+
self.func = configured_validate_arguments(func)
3636
self.row_level = row_level
3737
self.cell_level = cell_level
3838

@@ -136,5 +136,5 @@ def _wrap_transformation_function(original: Callable, row_level: bool, cell_leve
136136

137137
make_all_optional_or_suite_input(transformation_fn)
138138
set_return_type(transformation_fn, TransformationFunction)
139-
140-
return configured_validate_arguments(transformation_fn)
139+
# Copy or not copy ?
140+
return transformation_fn #configured_validate_arguments(transformation_fn) #()

python-client/giskard/models/automodel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pandas as pd
77

8-
from ..core.core import ModelType
8+
from ..core.core import ModelType, SupportedModelTypes
99
from .base.serialization import CloudpickleSerializableModel
1010
from .function import PredictionFunctionModel
1111

@@ -159,11 +159,11 @@ def __new__(
159159

160160
obj = output_cls(
161161
model=model,
162-
model_type=model_type,
162+
model_type=SupportedModelTypes(model_type) if isinstance(model_type, str) else model_type,
163163
data_preprocessing_function=data_preprocessing_function,
164164
model_postprocessing_function=model_postprocessing_function,
165165
name=name,
166-
feature_names=feature_names,
166+
feature_names=list(feature_names) if feature_names is not None else None,
167167
classification_threshold=classification_threshold,
168168
classification_labels=classification_labels,
169169
**kwargs,

python-client/giskard/models/base/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def __init__(
128128
self.meta = ModelMeta(
129129
name=name if name is not None else self.__class__.__name__,
130130
model_type=model_type,
131-
feature_names=list(feature_names) if feature_names else None,
131+
feature_names=list(feature_names) if feature_names is not None else None,
132132
classification_labels=np_types_to_native(classification_labels),
133133
loader_class=self.__class__.__name__,
134134
loader_module=self.__module__,

python-client/giskard/models/base/model_prediction.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from typing import Any, Optional
22

3-
import pydantic
3+
from pydantic import BaseModel, Field
44

55

66
# @TODO: Define the fields of this class more rigorously.
7-
class ModelPredictionResults(pydantic.BaseModel):
7+
class ModelPredictionResults(BaseModel):
88
"""Data structure for model predictions.
99
1010
For regression models, the `prediction` field of the returned `ModelPredictionResults` object will contain the same
@@ -29,8 +29,8 @@ class ModelPredictionResults(pydantic.BaseModel):
2929
The predicted probabilities for all class labels for each example in the input dataset.
3030
"""
3131

32-
raw: Any = []
33-
prediction: Any = []
34-
raw_prediction: Any = []
35-
probabilities: Optional[Any]
36-
all_predictions: Optional[Any]
32+
raw: Any = Field(default_factory=list)
33+
prediction: Any = Field(default_factory=list)
34+
raw_prediction: Any = Field(default_factory=list)
35+
probabilities: Optional[Any] = None
36+
all_predictions: Optional[Any] = None

0 commit comments

Comments
 (0)