Skip to content

Commit 801ef04

Browse files
authored
Merge pull request #1610 from Giskard-AI/GSK-2153
[GSK-2153] removing mlruns artifacts file produced by pytest
2 parents 6f72f9d + 83c977f commit 801ef04

1 file changed

Lines changed: 27 additions & 19 deletions

File tree

tests/integrations/test_mlflow.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import mlflow
22
import pytest
3+
from pathlib import Path
34

5+
from tempfile import TemporaryDirectory
46
from giskard.core.core import SupportedModelTypes
57

68
mlflow_model_types = {
@@ -10,20 +12,26 @@
1012
}
1113

1214

13-
def _evaluate(dataset, model, evaluator_config, request):
14-
mlflow.end_run()
15-
mlflow.start_run()
16-
model_info = model.to_mlflow()
17-
18-
mlflow.evaluate(
19-
model=model_info.model_uri,
20-
model_type=mlflow_model_types[model.meta.model_type],
21-
data=dataset.df,
22-
targets=dataset.target,
23-
evaluators="giskard",
24-
evaluator_config=evaluator_config,
25-
)
26-
mlflow.end_run()
15+
def _evaluate(dataset, model, evaluator_config):
16+
import platform
17+
import os
18+
19+
with TemporaryDirectory() as f:
20+
if platform.system() == "Windows":
21+
f = f.replace(os.sep, "/")
22+
f = "file://" + f
23+
mlflow.set_tracking_uri(Path(f))
24+
experiment_id = mlflow.create_experiment("test", artifact_location=f)
25+
with mlflow.start_run(experiment_id=experiment_id):
26+
model_info = model.to_mlflow()
27+
mlflow.evaluate(
28+
model=model_info.model_uri,
29+
model_type=mlflow_model_types[model.meta.model_type],
30+
data=dataset.df,
31+
targets=dataset.target,
32+
evaluators="giskard",
33+
evaluator_config=evaluator_config,
34+
)
2735

2836

2937
@pytest.mark.parametrize(
@@ -59,7 +67,7 @@ def _run_test(dataset_name, model_name, request):
5967
dataset = request.getfixturevalue(dataset_name)
6068
model = request.getfixturevalue(model_name)
6169
evaluator_config = {"model_config": {"classification_labels": model.meta.classification_labels}}
62-
_evaluate(dataset, model, evaluator_config, request)
70+
_evaluate(dataset, model, evaluator_config)
6371

6472

6573
@pytest.mark.parametrize("dataset_name,model_name", [("german_credit_data", "german_credit_model")])
@@ -74,23 +82,23 @@ def test_errors(dataset_name, model_name, request):
7482
dataset_copy.target = [1]
7583

7684
with pytest.raises(Exception) as e:
77-
_evaluate(dataset_copy, model, evaluator_config, request)
85+
_evaluate(dataset_copy, model, evaluator_config)
7886
assert e.match(r"Only pd.DataFrame are currently supported by the giskard evaluator.")
7987

8088
# dataset wrapping error
8189
dataset_copy = dataset.copy()
8290
dataset_copy.df.savings[0] = ["wrong_entry"]
8391

8492
with pytest.raises(Exception) as e:
85-
_evaluate(dataset_copy, model, evaluator_config, request)
93+
_evaluate(dataset_copy, model, evaluator_config)
8694
assert e.match(r"An error occurred while wrapping the dataset.*")
8795

8896
# model wrapping error
8997
dataset_copy = dataset.copy()
9098
evaluator_config = {"model_config": {"classification_labels": None}}
9199

92100
with pytest.raises(Exception) as e:
93-
_evaluate(dataset_copy, model, evaluator_config, request)
101+
_evaluate(dataset_copy, model, evaluator_config)
94102
assert e.match(r"An error occurred while wrapping the model.*")
95103

96104
# scan error
@@ -100,5 +108,5 @@ def test_errors(dataset_name, model_name, request):
100108
evaluator_config = {"model_config": {"classification_labels": cl}}
101109

102110
with pytest.raises(Exception) as e:
103-
_evaluate(dataset_copy, model, evaluator_config, request)
111+
_evaluate(dataset_copy, model, evaluator_config)
104112
assert e.match(r"An error occurred while scanning the model for vulnerabilities.*")

0 commit comments

Comments (0)