Skip to content

Commit 74ec47e

Browse files
[GSK-2378] Only persist/read cache when explicitly asked (#1680)
* Only persist/read cache when explicitly asked
* Only load artifact once
* Fixed cache logic
* Typo
* Code improvement
* Added test to ensure same model instance is used in test suite
1 parent b63c3d1 commit 74ec47e

6 files changed

Lines changed: 165 additions & 18 deletions

File tree

giskard/ml_worker/websocket/listener.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import tempfile
77
import time
88
import traceback
9+
from collections import defaultdict
910
from concurrent.futures import CancelledError, Future
1011
from copy import copy
1112
from dataclasses import dataclass
@@ -520,17 +521,20 @@ def run_test_suite(
520521
client: Optional[GiskardClient], params: websocket.TestSuiteParam, *args, **kwargs
521522
) -> websocket.TestSuite:
522523
log_listener = LogListener()
524+
525+
loaded_artifacts = defaultdict(dict)
526+
523527
try:
524528
tests = [
525529
{
526530
"test": GiskardTest.download(t.testUuid, client, None),
527-
"arguments": parse_function_arguments(client, t.arguments),
531+
"arguments": parse_function_arguments(client, t.arguments, loaded_artifacts),
528532
"id": t.id,
529533
}
530534
for t in params.tests
531535
]
532536

533-
global_arguments = parse_function_arguments(client, params.globalArguments)
537+
global_arguments = parse_function_arguments(client, params.globalArguments, loaded_artifacts)
534538

535539
datasets = {arg.original_id: arg for arg in global_arguments.values() if isinstance(arg, Dataset)}
536540
for test in tests:

giskard/ml_worker/websocket/utils.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import os
33
import shutil
44
import uuid
5-
from typing import Any, Dict, List, Optional
5+
from collections import defaultdict
66

77
import pandas as pd
88
from mlflow.store.artifact.artifact_repo import verify_artifact_path
9+
from typing import Any, Dict, List, Optional, Callable
910

1011
from giskard.client.giskard_client import GiskardClient
1112
from giskard.core.suite import DatasetInput, ModelInput, SuiteInput
@@ -158,7 +159,21 @@ def map_dataset_process_function_meta_ws(callable_type):
158159
}
159160

160161

161-
def parse_function_arguments(client: Optional[GiskardClient], request_arguments: List[websocket.FuncArgument]):
162+
def _get_or_load(loaded_artifacts: Dict[str, Dict[str, Any]], type: str, uuid: str, load_fn: Callable[[], Any]) -> Any:
163+
if uuid not in loaded_artifacts[type]:
164+
loaded_artifacts[type][uuid] = load_fn()
165+
166+
return loaded_artifacts[type][uuid]
167+
168+
169+
def parse_function_arguments(
170+
client: Optional[GiskardClient],
171+
request_arguments: List[websocket.FuncArgument],
172+
loaded_artifacts: Optional[Dict[str, Dict[str, Any]]] = None,
173+
):
174+
if loaded_artifacts is None:
175+
loaded_artifacts = defaultdict(dict)
176+
162177
arguments = dict()
163178

164179
# Processing empty list
@@ -169,22 +184,32 @@ def parse_function_arguments(client: Optional[GiskardClient], request_arguments:
169184
if arg.is_none:
170185
continue
171186
if arg.dataset is not None:
172-
arguments[arg.name] = Dataset.download(
173-
client,
174-
arg.dataset.project_key,
187+
arguments[arg.name] = _get_or_load(
188+
loaded_artifacts,
189+
"Dataset",
175190
arg.dataset.id,
176-
arg.dataset.sample,
191+
lambda: Dataset.download(
192+
client,
193+
arg.dataset.project_key,
194+
arg.dataset.id,
195+
arg.dataset.sample,
196+
),
177197
)
178198
elif arg.model is not None:
179-
arguments[arg.name] = BaseModel.download(client, arg.model.project_key, arg.model.id)
199+
arguments[arg.name] = _get_or_load(
200+
loaded_artifacts,
201+
"BaseModel",
202+
arg.model.id,
203+
lambda: BaseModel.download(client, arg.model.project_key, arg.model.id),
204+
)
180205
elif arg.slicingFunction is not None:
181206
arguments[arg.name] = SlicingFunction.download(
182207
arg.slicingFunction.id, client, arg.slicingFunction.project_key
183-
)(**parse_function_arguments(client, arg.args))
208+
)(**parse_function_arguments(client, arg.args, loaded_artifacts))
184209
elif arg.transformationFunction is not None:
185210
arguments[arg.name] = TransformationFunction.download(
186211
arg.transformationFunction.id, client, arg.transformationFunction.project_key
187-
)(**parse_function_arguments(client, arg.args))
212+
)(**parse_function_arguments(client, arg.args, loaded_artifacts))
188213
elif arg.float_arg is not None:
189214
arguments[arg.name] = float(arg.float_arg)
190215
elif arg.int_arg is not None:

giskard/models/base/model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,12 @@ def __init__(
143143
if len(classification_labels) != len(set(classification_labels)):
144144
raise ValueError("Duplicates are found in 'classification_labels', please only provide unique values.")
145145

146-
self._cache = ModelCache(model_type, str(self.id), cache_dir=kwargs.get("prediction_cache_dir"))
146+
self._cache = ModelCache(
147+
model_type,
148+
str(self.id),
149+
persist_cache=kwargs.get("persist_cache", False),
150+
cache_dir=kwargs.get("prediction_cache_dir"),
151+
)
147152

148153
# sklearn and catboost will fill classification_labels before this check
149154
if model_type == SupportedModelTypes.CLASSIFICATION and not classification_labels:

giskard/models/cache/cache.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import csv
22
from pathlib import Path
3-
from typing import Any, Iterable, List, Optional
43

54
import numpy as np
65
import pandas as pd
6+
from typing import Any, Iterable, List, Optional
77

88
from ...client.python_utils import warning
99
from ...core.core import SupportedModelTypes
@@ -26,14 +26,23 @@ def flatten(xs):
2626
class ModelCache:
2727
_default_cache_dir_prefix = Path(settings.home_dir / settings.cache_dir / "global" / "prediction_cache")
2828

29-
def __init__(self, model_type: SupportedModelTypes, id: Optional[str] = None, cache_dir: Optional[Path] = None):
29+
def __init__(
30+
self,
31+
model_type: SupportedModelTypes,
32+
id: Optional[str] = None,
33+
persist_cache: bool = False,
34+
cache_dir: Optional[Path] = None,
35+
):
3036
self.id = id
3137
self.prediction_cache = dict()
3238

33-
if cache_dir is None and self.id:
34-
cache_dir = self._default_cache_dir_prefix.joinpath(self.id)
39+
if persist_cache:
40+
if cache_dir is None and self.id:
41+
cache_dir = self._default_cache_dir_prefix / self.id
3542

36-
self.cache_file = cache_dir / CACHE_CSV_FILENAME if cache_dir else None
43+
self.cache_file = cache_dir / CACHE_CSV_FILENAME if cache_dir else None
44+
else:
45+
self.cache_file = None
3746

3847
self.vectorized_get_cache_or_na = np.vectorize(self.get_cache_or_na, otypes=[object])
3948
self.model_type = model_type

tests/communications/test_websocket_actor_tests.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
import random
12
import uuid
23

34
import pandas as pd
45
import pytest
56

7+
import giskard
68
from giskard import test
79
from giskard.datasets.base import Dataset
810
from giskard.ml_worker import websocket
911
from giskard.ml_worker.testing.test_result import TestResult as GiskardTestResult, TestMessage, TestMessageLevel
1012
from giskard.ml_worker.websocket import listener
13+
from giskard.models.base import BaseModel
1114
from giskard.testing.tests import debug_prefix
1215
from tests import utils
1316

@@ -62,6 +65,13 @@ def my_simple_test_legacy_debug(dataset: Dataset, debug: bool = False):
6265
return GiskardTestResult(passed=False, output_df=output_ds)
6366

6467

68+
@giskard.test()
def same_prediction(left: BaseModel, right: BaseModel, ds: giskard.Dataset):
    """Pass when ``left`` and ``right`` produce identical raw predictions on ``ds``."""
    # Run both predictions first (left, then right), then compare their raw outputs as lists.
    predictions = [model.predict(ds) for model in (left, right)]
    left_values, right_values = (list(prediction.raw_prediction) for prediction in predictions)
    return giskard.TestResult(passed=left_values == right_values)
73+
74+
6575
def test_websocket_actor_run_ad_hoc_test_legacy_debug(enron_data: Dataset):
6676
project_key = str(uuid.uuid4())
6777

@@ -343,6 +353,97 @@ def test_websocket_actor_run_test_suite():
343353
assert not reply.results[2].result.passed
344354

345355

356+
def test_websocket_actor_run_test_suite_share_models_and_dataset_instance():
    """Run a suite whose tests reference the same model id as both ``left`` and ``right``.

    The model returns random predictions, so ``same_prediction`` is only
    expected to pass when both arguments resolve to the same loaded instance
    (and therefore share its prediction cache).
    """

    def random_prediction(df):
        return [random.randint(0, 9) for _ in df.index]

    # Use random model to ensure model prediction cache is shared (same instance loaded)
    random_model = giskard.Model(random_prediction, "regression", feature_names=["feature"])
    mock_dataset = giskard.Dataset(pd.DataFrame({"feature": range(100)}))

    def model_argument(name):
        # Every model argument points at the same random model artifact.
        return websocket.FuncArgument(
            name=name,
            model=websocket.ArtifactRef(project_key="project_key", id=str(random_model.id)),
            none=False,
        )

    def dataset_argument(name):
        return websocket.FuncArgument(
            name=name,
            dataset=websocket.ArtifactRef(project_key="project_key", id=str(mock_dataset.id)),
            none=False,
        )

    def suite_test(test_id, *arguments):
        return websocket.SuiteTestArgument(
            id=test_id,
            testUuid=same_prediction.meta.uuid,
            arguments=list(arguments),
        )

    with utils.MockedClient(mock_all=False) as (client, mr):
        params = websocket.TestSuiteParam(
            projectKey=str(uuid.uuid4()),
            tests=[
                # All arguments given explicitly
                suite_test(0, model_argument("left"), model_argument("right"), dataset_argument("ds")),
                # Only one argument given explicitly; the others come from globalArguments
                suite_test(1, model_argument("left")),
                suite_test(2, model_argument("right")),
                # No explicit argument; unique id (previously duplicated id=2)
                suite_test(3),
            ],
            globalArguments=[
                model_argument("left"),
                model_argument("right"),
                dataset_argument("ds"),
            ],
        )
        utils.register_uri_for_artifact_meta_info(mr, same_prediction, None)

        utils.register_uri_for_model_meta_info(mr, random_model, "project_key")
        utils.register_uri_for_model_artifact_info(mr, random_model, "project_key", register_file_contents=True)

        utils.register_uri_for_dataset_meta_info(mr, mock_dataset, "project_key")
        utils.register_uri_for_dataset_artifact_info(mr, mock_dataset, "project_key", register_file_contents=True)

        reply = listener.run_test_suite(client, params)

        assert isinstance(reply, websocket.TestSuite)
        assert not reply.is_error
        assert reply.is_pass
        assert 4 == len(reply.results)
445+
446+
346447
def test_websocket_actor_run_test_suite_raise_error():
347448
with utils.MockedClient(mock_all=False) as (client, mr):
348449
params = websocket.TestSuiteParam(

tests/models/test_model_cache.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
import pandas as pd
88
import pytest
99
import xxhash
10-
from langchain import LLMChain, PromptTemplate
1110
from langchain.llms.fake import FakeListLLM
1211

1312
import giskard
1413
from giskard import Dataset, Model
1514
from giskard.core.core import SupportedModelTypes
1615
from giskard.models.cache import ModelCache
16+
from langchain import LLMChain, PromptTemplate
17+
1718

1819
# https://symbl.cc/fr/unicode/blocks/
1920

@@ -31,6 +32,7 @@ def test_unicode_prediction(keys, values):
3132
with TemporaryDirectory() as temp_cache_dir:
3233
cache = ModelCache(
3334
model_type=SupportedModelTypes.TEXT_GENERATION,
35+
persist_cache=True,
3436
cache_dir=Path(temp_cache_dir),
3537
)
3638
key_series = pd.Series(keys)
@@ -43,6 +45,7 @@ def test_unicode_prediction(keys, values):
4345
warmed_up_cache = ModelCache(
4446
id="warmed_up",
4547
model_type=SupportedModelTypes.TEXT_GENERATION,
48+
persist_cache=True,
4649
cache_dir=Path(temp_cache_dir),
4750
)
4851
# Ensure warm up works fine

0 commit comments

Comments
 (0)