|
| 1 | +import random |
1 | 2 | import uuid |
2 | 3 |
|
3 | 4 | import pandas as pd |
4 | 5 | import pytest |
5 | 6 |
|
| 7 | +import giskard |
6 | 8 | from giskard import test |
7 | 9 | from giskard.datasets.base import Dataset |
8 | 10 | from giskard.ml_worker import websocket |
9 | 11 | from giskard.ml_worker.testing.test_result import TestResult as GiskardTestResult, TestMessage, TestMessageLevel |
10 | 12 | from giskard.ml_worker.websocket import listener |
| 13 | +from giskard.models.base import BaseModel |
11 | 14 | from giskard.testing.tests import debug_prefix |
12 | 15 | from tests import utils |
13 | 16 |
|
@@ -62,6 +65,13 @@ def my_simple_test_legacy_debug(dataset: Dataset, debug: bool = False): |
62 | 65 | return GiskardTestResult(passed=False, output_df=output_ds) |
63 | 66 |
|
64 | 67 |
|
| 68 | +@giskard.test() |
| 69 | +def same_prediction(left: BaseModel, right: BaseModel, ds: giskard.Dataset): |
| 70 | + left_pred = left.predict(ds) |
| 71 | + right_pred = right.predict(ds) |
| 72 | + return giskard.TestResult(passed=list(left_pred.raw_prediction) == list(right_pred.raw_prediction)) |
| 73 | + |
| 74 | + |
65 | 75 | def test_websocket_actor_run_ad_hoc_test_legacy_debug(enron_data: Dataset): |
66 | 76 | project_key = str(uuid.uuid4()) |
67 | 77 |
|
@@ -343,6 +353,97 @@ def test_websocket_actor_run_test_suite(): |
343 | 353 | assert not reply.results[2].result.passed |
344 | 354 |
|
345 | 355 |
|
| 356 | +def test_websocket_actor_run_test_suite_share_models_and_dataset_instance(): |
| 357 | + def random_prediction(df): |
| 358 | + return [random.randint(0, 9) for _ in df.index] |
| 359 | + |
| 360 | + # Use random model to ensure model prediction cache is shared (same instance loaded) |
| 361 | + random_model = giskard.Model(random_prediction, "regression", feature_names=["feature"]) |
| 362 | + mock_dataset = giskard.Dataset(pd.DataFrame({"feature": range(100)})) |
| 363 | + |
| 364 | + with utils.MockedClient(mock_all=False) as (client, mr): |
| 365 | + params = websocket.TestSuiteParam( |
| 366 | + projectKey=str(uuid.uuid4()), |
| 367 | + tests=[ |
| 368 | + websocket.SuiteTestArgument( |
| 369 | + id=0, |
| 370 | + testUuid=same_prediction.meta.uuid, |
| 371 | + arguments=[ |
| 372 | + websocket.FuncArgument( |
| 373 | + name="left", |
| 374 | + model=websocket.ArtifactRef(project_key="project_key", id=str(random_model.id)), |
| 375 | + none=False, |
| 376 | + ), |
| 377 | + websocket.FuncArgument( |
| 378 | + name="right", |
| 379 | + model=websocket.ArtifactRef(project_key="project_key", id=str(random_model.id)), |
| 380 | + none=False, |
| 381 | + ), |
| 382 | + websocket.FuncArgument( |
| 383 | + name="ds", |
| 384 | + dataset=websocket.ArtifactRef(project_key="project_key", id=str(mock_dataset.id)), |
| 385 | + none=False, |
| 386 | + ), |
| 387 | + ], |
| 388 | + ), |
| 389 | + websocket.SuiteTestArgument( |
| 390 | + id=1, |
| 391 | + testUuid=same_prediction.meta.uuid, |
| 392 | + arguments=[ |
| 393 | + websocket.FuncArgument( |
| 394 | + name="left", |
| 395 | + model=websocket.ArtifactRef(project_key="project_key", id=str(random_model.id)), |
| 396 | + none=False, |
| 397 | + ) |
| 398 | + ], |
| 399 | + ), |
| 400 | + websocket.SuiteTestArgument( |
| 401 | + id=2, |
| 402 | + testUuid=same_prediction.meta.uuid, |
| 403 | + arguments=[ |
| 404 | + websocket.FuncArgument( |
| 405 | + name="right", |
| 406 | + model=websocket.ArtifactRef(project_key="project_key", id=str(random_model.id)), |
| 407 | + none=False, |
| 408 | + ) |
| 409 | + ], |
| 410 | + ), |
| 411 | + websocket.SuiteTestArgument(id=2, testUuid=same_prediction.meta.uuid, arguments=[]), |
| 412 | + ], |
| 413 | + globalArguments=[ |
| 414 | + websocket.FuncArgument( |
| 415 | + name="left", |
| 416 | + model=websocket.ArtifactRef(project_key="project_key", id=str(random_model.id)), |
| 417 | + none=False, |
| 418 | + ), |
| 419 | + websocket.FuncArgument( |
| 420 | + name="right", |
| 421 | + model=websocket.ArtifactRef(project_key="project_key", id=str(random_model.id)), |
| 422 | + none=False, |
| 423 | + ), |
| 424 | + websocket.FuncArgument( |
| 425 | + name="ds", |
| 426 | + dataset=websocket.ArtifactRef(project_key="project_key", id=str(mock_dataset.id)), |
| 427 | + none=False, |
| 428 | + ), |
| 429 | + ], |
| 430 | + ) |
| 431 | + utils.register_uri_for_artifact_meta_info(mr, same_prediction, None) |
| 432 | + |
| 433 | + utils.register_uri_for_model_meta_info(mr, random_model, "project_key") |
| 434 | + utils.register_uri_for_model_artifact_info(mr, random_model, "project_key", register_file_contents=True) |
| 435 | + |
| 436 | + utils.register_uri_for_dataset_meta_info(mr, mock_dataset, "project_key") |
| 437 | + utils.register_uri_for_dataset_artifact_info(mr, mock_dataset, "project_key", register_file_contents=True) |
| 438 | + |
| 439 | + reply = listener.run_test_suite(client, params) |
| 440 | + |
| 441 | + assert isinstance(reply, websocket.TestSuite) |
| 442 | + assert not reply.is_error |
| 443 | + assert reply.is_pass |
| 444 | + assert 4 == len(reply.results) |
| 445 | + |
| 446 | + |
346 | 447 | def test_websocket_actor_run_test_suite_raise_error(): |
347 | 448 | with utils.MockedClient(mock_all=False) as (client, mr): |
348 | 449 | params = websocket.TestSuiteParam( |
|
0 commit comments