Skip to content

Commit 5cccfc5

Browse files
Merge pull request #1837 from Giskard-AI/GSK-2887
GSK-2887 Upload tests with project when uploading a suite
2 parents aa505d7 + c88c206 commit 5cccfc5

12 files changed

Lines changed: 110 additions & 348 deletions

File tree

giskard/core/savable.py

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -53,31 +53,17 @@ def _get_name(cls) -> str:
5353
return f"{cls.__class__.__name__.lower()}s"
5454

5555
@classmethod
56-
def _get_meta_endpoint(cls, uuid: str, project_key: Optional[str]) -> str:
57-
if project_key is None:
58-
return posixpath.join(cls._get_name(), uuid)
59-
else:
60-
return posixpath.join("project", project_key, cls._get_name(), uuid)
56+
def _get_meta_endpoint(cls, uuid: str, project_key: str) -> str:
57+
return posixpath.join("project", project_key, cls._get_name(), uuid)
6158

6259
def _save_meta_locally(self, local_dir):
6360
with open(Path(local_dir) / "meta.yaml", "w") as f:
6461
yaml.dump(self.meta, f)
6562

66-
@classmethod
67-
def _load_meta_locally(cls, local_dir, uuid: str) -> Optional[SMT]:
68-
file = Path(local_dir) / "meta.yaml"
69-
if not file.exists():
70-
return None
71-
72-
with open(file, "r") as f:
73-
# PyYAML prohibits the arbitary execution so our class cannot be loaded safely,
74-
# see: https://github.com/yaml/pyyaml/wiki/PyYAML-yaml.load(input)-Deprecation
75-
return yaml.load(f, Loader=yaml.UnsafeLoader)
76-
7763
def upload(
7864
self,
7965
client: GiskardClient,
80-
project_key: Optional[str] = None,
66+
project_key: str,
8167
uploaded_dependencies: Optional[Set["Artifact"]] = None,
8268
) -> str:
8369
"""
@@ -114,14 +100,13 @@ def upload(
114100
return self.meta.uuid
115101

116102
@classmethod
117-
def download(cls, uuid: str, client: Optional[GiskardClient], project_key: Optional[str]) -> "Artifact":
103+
def download(cls, uuid: str, client: GiskardClient, project_key: str) -> "Artifact":
118104
"""
119105
Downloads the artifact from the Giskard hub or retrieves it from the local cache.
120106
121107
Args:
122108
uuid (str): The UUID of the artifact to download.
123-
client (Optional[GiskardClient]): The Giskard client instance used for communication with the hub. If None,
124-
the artifact will be retrieved from the local cache if available. Defaults to None.
109+
client (GiskardClient): The Giskard client instance used for communication with the hub.
125110
project_key (Optional[str]): The project key where the artifact is located. If None, the artifact will be
126111
retrieved from the global scope. Defaults to None.
127112
@@ -135,11 +120,7 @@ def download(cls, uuid: str, client: Optional[GiskardClient], project_key: Optio
135120
name = cls._get_name()
136121

137122
local_dir = settings.home_dir / settings.cache_dir / name / uuid
138-
139-
if client is None:
140-
meta = cls._load_meta_locally(local_dir, uuid)
141-
else:
142-
meta = client.load_meta(cls._get_meta_endpoint(uuid, project_key), cls._get_meta_class())
123+
meta = client.load_meta(cls._get_meta_endpoint(uuid, project_key), cls._get_meta_class())
143124

144125
assert meta is not None, "Could not retrieve test meta"
145126

giskard/core/suite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def to_dto(self, client: GiskardClient, project_key: str, uploaded_uuid_status:
431431

432432
return SuiteTestDTO(
433433
id=self.suite_test_id,
434-
testUuid=self.giskard_test.upload(client),
434+
testUuid=self.giskard_test.upload(client, project_key),
435435
functionInputs=params,
436436
displayName=self.display_name,
437437
)
@@ -935,7 +935,7 @@ def download(cls, client: GiskardClient, project_key: str, suite_id: int) -> "Su
935935
suite.project_key = project_key
936936

937937
for test_json in suite_dto.tests:
938-
test = GiskardTest.download(test_json.testUuid, client, None)
938+
test = GiskardTest.download(test_json.testUuid, client, project_key)
939939
test_arguments = parse_function_arguments(client, project_key, test_json.functionInputs.values())
940940
suite.add_test(test(**test_arguments), suite_test_id=test_json.id)
941941

giskard/ml_worker/websocket/listener.py

Lines changed: 22 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
do_create_sub_dataset,
3737
do_run_adhoc_test,
3838
function_argument_to_ws,
39-
log_artifact_local,
4039
map_dataset_process_function_meta_ws,
4140
map_function_meta_ws,
4241
map_result_to_single_test_result_ws,
@@ -134,12 +133,12 @@ def parse_and_execute(
134133
action: MLWorkerAction,
135134
params,
136135
ml_worker: MLWorkerInfo,
137-
client_params: Optional[Dict[str, str]],
136+
client_params: Dict[str, str],
138137
) -> websocket.WorkerReply:
139138
action_params = parse_action_param(action, params)
140139
return callback(
141140
ml_worker=ml_worker,
142-
client=GiskardClient(**client_params) if client_params is not None else None,
141+
client=GiskardClient(**client_params),
143142
action=action.name,
144143
params=action_params,
145144
)
@@ -314,7 +313,7 @@ def run_other_model(dataset, prediction_results, is_text_generation):
314313

315314

316315
@websocket_actor(MLWorkerAction.runModel)
317-
def run_model(client: Optional[GiskardClient], params: websocket.RunModelParam, *args, **kwargs) -> websocket.Empty:
316+
def run_model(client: GiskardClient, params: websocket.RunModelParam, *args, **kwargs) -> websocket.Empty:
318317
try:
319318
model = BaseModel.download(client, params.model.project_key, params.model.id)
320319
dataset = Dataset.download(
@@ -349,35 +348,23 @@ def run_model(client: Optional[GiskardClient], params: websocket.RunModelParam,
349348
tmp_dir = Path(f)
350349
predictions_csv = get_file_name("predictions", "csv", params.dataset.sample)
351350
results.to_csv(index=False, path_or_buf=tmp_dir / predictions_csv)
352-
if client:
353-
client.log_artifact(
354-
tmp_dir / predictions_csv,
355-
f"models/inspections/{params.inspectionId}",
356-
)
357-
else:
358-
log_artifact_local(
359-
tmp_dir / predictions_csv,
360-
f"models/inspections/{params.inspectionId}",
361-
)
351+
client.log_artifact(
352+
tmp_dir / predictions_csv,
353+
f"models/inspections/{params.inspectionId}",
354+
)
362355

363356
calculated_csv = get_file_name("calculated", "csv", params.dataset.sample)
364357
calculated.to_csv(index=False, path_or_buf=tmp_dir / calculated_csv)
365-
if client:
366-
client.log_artifact(
367-
tmp_dir / calculated_csv,
368-
f"models/inspections/{params.inspectionId}",
369-
)
370-
else:
371-
log_artifact_local(
372-
tmp_dir / calculated_csv,
373-
f"models/inspections/{params.inspectionId}",
374-
)
358+
client.log_artifact(
359+
tmp_dir / calculated_csv,
360+
f"models/inspections/{params.inspectionId}",
361+
)
375362
return websocket.Empty()
376363

377364

378365
@websocket_actor(MLWorkerAction.runModelForDataFrame)
379366
def run_model_for_data_frame(
380-
client: Optional[GiskardClient], params: websocket.RunModelForDataFrameParam, *args, **kwargs
367+
client: GiskardClient, params: websocket.RunModelForDataFrameParam, *args, **kwargs
381368
) -> websocket.RunModelForDataFrame:
382369
model = BaseModel.download(client, params.model.project_key, params.model.id)
383370
df = pd.DataFrame.from_records([r.columns for r in params.dataframe.rows])
@@ -407,7 +394,7 @@ def run_model_for_data_frame(
407394

408395

409396
@websocket_actor(MLWorkerAction.explain)
410-
def explain_ws(client: Optional[GiskardClient], params: websocket.ExplainParam, *args, **kwargs) -> websocket.Explain:
397+
def explain_ws(client: GiskardClient, params: websocket.ExplainParam, *args, **kwargs) -> websocket.Explain:
411398
model = BaseModel.download(client, params.model.project_key, params.model.id)
412399
dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id, params.dataset.sample)
413400
explanations = explain(model, dataset, params.columns)
@@ -419,7 +406,7 @@ def explain_ws(client: Optional[GiskardClient], params: websocket.ExplainParam,
419406

420407
@websocket_actor(MLWorkerAction.explainText)
421408
def explain_text_ws(
422-
client: Optional[GiskardClient], params: websocket.ExplainTextParam, *args, **kwargs
409+
client: GiskardClient, params: websocket.ExplainTextParam, *args, **kwargs
423410
) -> websocket.ExplainText:
424411
model = BaseModel.download(client, params.model.project_key, params.model.id)
425412
text_column = params.feature_name
@@ -460,7 +447,7 @@ def get_catalog(*args, **kwargs) -> websocket.Catalog:
460447

461448
@websocket_actor(MLWorkerAction.datasetProcessing)
462449
def dataset_processing(
463-
client: Optional[GiskardClient], params: websocket.DatasetProcessingParam, *args, **kwargs
450+
client: GiskardClient, params: websocket.DatasetProcessingParam, *args, **kwargs
464451
) -> websocket.DatasetProcessing:
465452
dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id, params.dataset.sample)
466453

@@ -500,7 +487,7 @@ def dataset_processing(
500487

501488
@websocket_actor(MLWorkerAction.runAdHocTest)
502489
def run_ad_hoc_test(
503-
client: Optional[GiskardClient], params: websocket.RunAdHocTestParam, *args, **kwargs
490+
client: GiskardClient, params: websocket.RunAdHocTestParam, *args, **kwargs
504491
) -> websocket.RunAdHocTest:
505492
test: GiskardTest = GiskardTest.download(params.testUuid, client, params.projectKey)
506493

@@ -525,9 +512,7 @@ def run_ad_hoc_test(
525512

526513

527514
@websocket_actor(MLWorkerAction.runTestSuite)
528-
def run_test_suite(
529-
client: Optional[GiskardClient], params: websocket.TestSuiteParam, *args, **kwargs
530-
) -> websocket.TestSuite:
515+
def run_test_suite(client: GiskardClient, params: websocket.TestSuiteParam, *args, **kwargs) -> websocket.TestSuite:
531516
loaded_artifacts = defaultdict(dict)
532517

533518
try:
@@ -594,7 +579,7 @@ def echo(params: websocket.EchoMsg, *args, **kwargs) -> websocket.EchoResponse:
594579

595580

596581
def handle_cta(
597-
client: Optional[GiskardClient],
582+
client: GiskardClient,
598583
params: websocket.GetPushParam,
599584
push: Optional[Push],
600585
push_kind: PushKind,
@@ -635,9 +620,7 @@ def handle_cta(
635620

636621

637622
@websocket_actor(MLWorkerAction.getPush, timeout=30, ignore_timeout=True)
638-
def get_push(
639-
client: Optional[GiskardClient], params: websocket.GetPushParam, *args, **kwargs
640-
) -> websocket.GetPushResponse:
623+
def get_push(client: GiskardClient, params: websocket.GetPushParam, *args, **kwargs) -> websocket.GetPushResponse:
641624
# Save cta_kind and push_kind and remove it from params
642625
cta_kind = params.cta_kind
643626
push_kind = params.push_kind
@@ -690,7 +673,7 @@ def push_to_ws(push: Push):
690673
return push.to_ws() if push is not None else None
691674

692675

693-
def get_push_objects(client: Optional[GiskardClient], params: websocket.GetPushParam):
676+
def get_push_objects(client: GiskardClient, params: websocket.GetPushParam):
694677
try:
695678
model = BaseModel.download(client, params.model.project_key, params.model.id)
696679
dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id)
@@ -735,7 +718,7 @@ def get_push_objects(client: Optional[GiskardClient], params: websocket.GetPushP
735718

736719
@websocket_actor(MLWorkerAction.createSubDataset)
737720
def create_sub_dataset(
738-
client: Optional[GiskardClient], params: websocket.CreateSubDatasetParam, *arg, **kwargs
721+
client: GiskardClient, params: websocket.CreateSubDatasetParam, *arg, **kwargs
739722
) -> websocket.CreateSubDataset:
740723
datasets = {
741724
dateset_id: Dataset.download(
@@ -751,7 +734,7 @@ def create_sub_dataset(
751734

752735
@websocket_actor(MLWorkerAction.createDataset)
753736
def create_dataset(
754-
client: Optional[GiskardClient], params: websocket.CreateDatasetParam, *arg, **kwargs
737+
client: GiskardClient, params: websocket.CreateDatasetParam, *arg, **kwargs
755738
) -> websocket.CreateSubDataset:
756739
dataset = do_create_dataset(params.name, params.headers, params.rows)
757740

giskard/ml_worker/websocket/utils.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
from typing import Any, Callable, Dict, List, Optional
22

33
import logging
4-
import os
5-
import shutil
64
import uuid
75
from collections import defaultdict
86

97
import pandas as pd
10-
from mlflow.store.artifact.artifact_repo import verify_artifact_path
118

129
from giskard.client.giskard_client import GiskardClient
1310
from giskard.core.suite import DatasetInput, ModelInput, SuiteInput
@@ -34,7 +31,6 @@
3431
)
3532
from giskard.ml_worker.websocket.action import MLWorkerAction
3633
from giskard.models.base import BaseModel
37-
from giskard.path_utils import artifacts_dir
3834
from giskard.registry.registry import tests_registry
3935
from giskard.registry.slicing_function import SlicingFunction
4036
from giskard.registry.transformation_function import TransformationFunction
@@ -126,21 +122,6 @@ def map_function_meta_ws(callable_type):
126122
}
127123

128124

129-
def log_artifact_local(local_file, artifact_path=None):
130-
# Log artifact locally from an internal worker
131-
verify_artifact_path(artifact_path)
132-
133-
file_name = os.path.basename(local_file)
134-
135-
if artifact_path:
136-
artifact_file = artifacts_dir / artifact_path / file_name
137-
else:
138-
artifact_file = artifacts_dir / file_name
139-
artifact_file.parent.mkdir(parents=True, exist_ok=True)
140-
141-
shutil.copy(local_file, artifact_file)
142-
143-
144125
def map_dataset_process_function_meta_ws(callable_type):
145126
return {
146127
test.uuid: websocket.DatasetProcessFunctionMeta(
@@ -182,7 +163,7 @@ def _get_or_load(loaded_artifacts: Dict[str, Dict[str, Any]], type: str, uuid: s
182163

183164

184165
def parse_function_arguments(
185-
client: Optional[GiskardClient],
166+
client: GiskardClient,
186167
request_arguments: List[websocket.FuncArgument],
187168
loaded_artifacts: Optional[Dict[str, Dict[str, Any]]] = None,
188169
):
@@ -245,7 +226,7 @@ def parse_function_arguments(
245226
def map_result_to_single_test_result_ws(
246227
result,
247228
datasets: Dict[uuid.UUID, Dataset],
248-
client: Optional[GiskardClient] = None,
229+
client: GiskardClient,
249230
project_key: Optional[str] = None,
250231
) -> websocket.SingleTestResult:
251232
if isinstance(result, TestResult):
@@ -302,9 +283,6 @@ def _upload_generated_output_df(client, datasets, project_key, result):
302283
)
303284

304285
if result.output_df.original_id not in datasets.keys():
305-
if not client:
306-
raise RuntimeError("Legacy test debugging using `output_df` is not supported internal ML worker")
307-
308286
if not project_key:
309287
raise ValueError("Unable to upload debug dataset due to missing `project_key`")
310288

giskard/models/base/model.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ def upload(self, client: GiskardClient, project_key, validate_ds=None, *_args, *
464464
return str(self.id)
465465

466466
@classmethod
467-
def download(cls, client: Optional[GiskardClient], project_key, model_id, *_args, **_kwargs):
467+
def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs):
468468
"""
469469
Downloads the specified model from the Giskard hub and loads it into memory.
470470
@@ -480,29 +480,24 @@ def download(cls, client: Optional[GiskardClient], project_key, model_id, *_args
480480
AssertionError: If the local directory where the model should be saved does not exist.
481481
"""
482482
local_dir = settings.home_dir / settings.cache_dir / "models" / model_id
483-
if client is None:
484-
# internal worker case, no token based http client [deprecated, to be removed]
485-
assert local_dir.exists(), f"Cannot find existing model {project_key}.{model_id} in {local_dir}"
486-
meta_response, meta = cls.read_meta_from_local_dir(local_dir)
487-
else:
488-
client.load_artifact(local_dir, posixpath.join("models", model_id))
489-
meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)
490-
# internal worker case, no token based http client
491-
if not local_dir.exists():
492-
raise RuntimeError(f"Cannot find existing model {project_key}.{model_id} in {local_dir}")
493-
with (Path(local_dir) / META_FILENAME).open(encoding="utf-8") as f:
494-
file_meta = yaml.load(f, Loader=yaml.Loader)
495-
classification_labels = cls.cast_labels(meta_response)
496-
meta = ModelMeta(
497-
name=meta_response.name,
498-
description=meta_response.description,
499-
model_type=SupportedModelTypes[meta_response.modelType],
500-
feature_names=meta_response.featureNames,
501-
classification_labels=classification_labels,
502-
classification_threshold=meta_response.threshold,
503-
loader_module=file_meta["loader_module"],
504-
loader_class=file_meta["loader_class"],
505-
)
483+
client.load_artifact(local_dir, posixpath.join("models", model_id))
484+
meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)
485+
# internal worker case, no token based http client
486+
if not local_dir.exists():
487+
raise RuntimeError(f"Cannot find existing model {project_key}.{model_id} in {local_dir}")
488+
with (Path(local_dir) / META_FILENAME).open(encoding="utf-8") as f:
489+
file_meta = yaml.load(f, Loader=yaml.Loader)
490+
classification_labels = cls.cast_labels(meta_response)
491+
meta = ModelMeta(
492+
name=meta_response.name,
493+
description=meta_response.description,
494+
model_type=SupportedModelTypes[meta_response.modelType],
495+
feature_names=meta_response.featureNames,
496+
classification_labels=classification_labels,
497+
classification_threshold=meta_response.threshold,
498+
loader_module=file_meta["loader_module"],
499+
loader_class=file_meta["loader_class"],
500+
)
506501

507502
model_py_ver = (
508503
tuple(meta_response.languageVersion.split(".")) if "PYTHON" == meta_response.language.upper() else None

0 commit comments

Comments
 (0)