11import mlflow
22import pytest
3+ from pathlib import Path
34
5+ from tempfile import TemporaryDirectory
46from giskard .core .core import SupportedModelTypes
57
68mlflow_model_types = {
1012}
1113
1214
def _evaluate(dataset, model, evaluator_config, request):
    """Run the giskard MLflow evaluator on *model* over *dataset*.

    Any run left open by a previous test is closed first so the evaluation
    happens inside a fresh MLflow run.  *request* is accepted for pytest
    fixture plumbing but is not used inside this helper.
    """
    # Make sure no stale run is active before opening our own.
    mlflow.end_run()
    mlflow.start_run()

    # Export the giskard model into MLflow format to obtain a model URI.
    exported = model.to_mlflow()

    mlflow.evaluate(
        model=exported.model_uri,
        data=dataset.df,
        targets=dataset.target,
        model_type=mlflow_model_types[model.meta.model_type],
        evaluators="giskard",
        evaluator_config=evaluator_config,
    )

    # Close the run we opened so later tests start from a clean state.
    mlflow.end_run()
def _evaluate(dataset, model, evaluator_config):
    """Run the giskard MLflow evaluator on *model* over *dataset*.

    A throwaway temporary directory is used as the MLflow tracking store and
    artifact location, so each call is fully isolated and leaves no state
    behind once the context manager cleans up.
    """
    import os
    import platform

    with TemporaryDirectory() as tmpdir:
        location = tmpdir
        if platform.system() == "Windows":
            # File URIs use forward slashes even on Windows — normalise
            # the native separators before building the URI.
            location = location.replace(os.sep, "/")
        location = "file://" + location

        mlflow.set_tracking_uri(Path(location))
        experiment_id = mlflow.create_experiment("test", artifact_location=location)

        with mlflow.start_run(experiment_id=experiment_id):
            exported = model.to_mlflow()
            mlflow.evaluate(
                model=exported.model_uri,
                data=dataset.df,
                targets=dataset.target,
                model_type=mlflow_model_types[model.meta.model_type],
                evaluators="giskard",
                evaluator_config=evaluator_config,
            )
2735
2836
2937@pytest .mark .parametrize (
@@ -59,7 +67,7 @@ def _run_test(dataset_name, model_name, request):
5967 dataset = request .getfixturevalue (dataset_name )
6068 model = request .getfixturevalue (model_name )
6169 evaluator_config = {"model_config" : {"classification_labels" : model .meta .classification_labels }}
62- _evaluate (dataset , model , evaluator_config , request )
70+ _evaluate (dataset , model , evaluator_config )
6371
6472
6573@pytest .mark .parametrize ("dataset_name,model_name" , [("german_credit_data" , "german_credit_model" )])
@@ -74,23 +82,23 @@ def test_errors(dataset_name, model_name, request):
7482 dataset_copy .target = [1 ]
7583
7684 with pytest .raises (Exception ) as e :
77- _evaluate (dataset_copy , model , evaluator_config , request )
85+ _evaluate (dataset_copy , model , evaluator_config )
7886 assert e .match (r"Only pd.DataFrame are currently supported by the giskard evaluator." )
7987
8088 # dataset wrapping error
8189 dataset_copy = dataset .copy ()
8290 dataset_copy .df .savings [0 ] = ["wrong_entry" ]
8391
8492 with pytest .raises (Exception ) as e :
85- _evaluate (dataset_copy , model , evaluator_config , request )
93+ _evaluate (dataset_copy , model , evaluator_config )
8694 assert e .match (r"An error occurred while wrapping the dataset.*" )
8795
8896 # model wrapping error
8997 dataset_copy = dataset .copy ()
9098 evaluator_config = {"model_config" : {"classification_labels" : None }}
9199
92100 with pytest .raises (Exception ) as e :
93- _evaluate (dataset_copy , model , evaluator_config , request )
101+ _evaluate (dataset_copy , model , evaluator_config )
94102 assert e .match (r"An error occurred while wrapping the model.*" )
95103
96104 # scan error
@@ -100,5 +108,5 @@ def test_errors(dataset_name, model_name, request):
100108 evaluator_config = {"model_config" : {"classification_labels" : cl }}
101109
102110 with pytest .raises (Exception ) as e :
103- _evaluate (dataset_copy , model , evaluator_config , request )
111+ _evaluate (dataset_copy , model , evaluator_config )
104112 assert e .match (r"An error occurred while scanning the model for vulnerabilities.*" )
0 commit comments