Skip to content

Commit f709d1d

Browse files
Merge branch 'main' into 'websocket-unit-tests'
2 parents (490a72b + 4a1e2da); merge commit f709d1d

7 files changed

Lines changed: 51 additions & 31 deletions

File tree

giskard/commands/cli_worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def initialize_hf_token(hf_token, is_server):
136136
def _start_command(is_server, url: AnyHttpUrl, api_key, is_daemon, hf_token=None, nb_workers=None):
137137
from giskard.ml_worker.ml_worker import MLWorker
138138

139+
os.environ["TQDM_DISABLE"] = "1"
139140
start_msg = "Starting ML Worker"
140141
start_msg += " server" if is_server else " client"
141142
if is_daemon:

giskard/ml_worker/ml_worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ async def start(self, nb_workers: Optional[int] = None):
161161
# as described in https://github.com/jasonrbriggs/stomp.py/issues/424
162162
# and https://github.com/websocket-client/websocket-client/issues/930
163163
logger.warn(f"WebSocket connection may not be properly closed: {e}")
164+
logger.exception(e)
164165

165166
def stop(self):
166167
if self.ws_conn:

giskard/ml_worker/websocket/listener.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,25 @@ def websocket_log_actor(ml_worker: MLWorkerInfo, req: Dict, *args, **kwargs):
7878
WEBSOCKET_ACTORS = dict((action.name, websocket_log_actor) for action in MLWorkerAction)
7979

8080

81-
def wrapped_handle_result(action: MLWorkerAction, ml_worker: MLWorker, start: float, rep_id: Optional[str]):
81+
def wrapped_handle_result(
82+
action: MLWorkerAction, ml_worker: MLWorker, start: float, rep_id: Optional[str], ignore_timeout: bool
83+
):
8284
def handle_result(future: Union[Future, Callable[..., websocket.WorkerReply]]):
8385
log_pool_stats()
8486

8587
info = None # Needs to be defined in case of cancellation
8688

8789
try:
8890
info: websocket.WorkerReply = future.result() if isinstance(future, Future) else future()
89-
except CancelledError:
90-
info: websocket.WorkerReply = websocket.Empty()
91-
logger.warning("Task for %s has timed out and been cancelled", action.name)
91+
except CancelledError as e:
92+
if ignore_timeout:
93+
info: websocket.WorkerReply = websocket.Empty()
94+
logger.warning("Task for %s has timed out and been cancelled", action.name)
95+
else:
96+
info: websocket.WorkerReply = websocket.ErrorReply(
97+
error_str=str(e), error_type=type(e).__name__, detail=traceback.format_exc()
98+
)
99+
logger.warning(e)
92100
except Exception as e:
93101
info: websocket.WorkerReply = websocket.ErrorReply(
94102
error_str=str(e), error_type=type(e).__name__, detail=traceback.format_exc()
@@ -171,7 +179,7 @@ def parse_and_execute(
171179
)
172180

173181

174-
def dispatch_action(callback, ml_worker, action, req, execute_in_pool, timeout=None):
182+
def dispatch_action(callback, ml_worker, action, req, execute_in_pool, timeout=None, ignore_timeout=False):
175183
# Parse the response ID
176184
rep_id = req["id"] if "id" in req.keys() else None
177185
# Parse the param
@@ -199,7 +207,7 @@ def dispatch_action(callback, ml_worker, action, req, execute_in_pool, timeout=N
199207
)
200208
start = time.process_time()
201209

202-
result_handler = wrapped_handle_result(action, ml_worker, start, rep_id)
210+
result_handler = wrapped_handle_result(action, ml_worker, start, rep_id, ignore_timeout=ignore_timeout)
203211
# If execution should be done in a pool
204212
if execute_in_pool:
205213
logger.debug("Submitting for action %s '%s' into the pool", action.name, callback.__name__)
@@ -227,7 +235,9 @@ def dispatch_action(callback, ml_worker, action, req, execute_in_pool, timeout=N
227235
)
228236

229237

230-
def websocket_actor(action: MLWorkerAction, execute_in_pool: bool = True, timeout: Optional[float] = None):
238+
def websocket_actor(
239+
action: MLWorkerAction, execute_in_pool: bool = True, timeout: Optional[float] = None, ignore_timeout: bool = False
240+
):
231241
"""
232242
Register a function as an actor to an action from WebSocket connection
233243
"""
@@ -238,7 +248,7 @@ def websocket_actor_callback(callback: callable):
238248
logger.debug(f'Registered "{callback.__name__}" for ML Worker "{action.name}"')
239249

240250
def wrapped_callback(ml_worker: MLWorker, req: dict, *args, **kwargs):
241-
dispatch_action(callback, ml_worker, action, req, execute_in_pool, timeout)
251+
dispatch_action(callback, ml_worker, action, req, execute_in_pool, timeout, ignore_timeout)
242252

243253
WEBSOCKET_ACTORS[action.name] = wrapped_callback
244254

@@ -664,7 +674,7 @@ def echo(params: websocket.EchoMsg, *args, **kwargs) -> websocket.EchoMsg:
664674
return params
665675

666676

667-
@websocket_actor(MLWorkerAction.getPush, timeout=30)
677+
@websocket_actor(MLWorkerAction.getPush, timeout=30, ignore_timeout=True)
668678
def get_push(
669679
client: Optional[GiskardClient], params: websocket.GetPushParam, *args, **kwargs
670680
) -> websocket.GetPushResponse:

giskard/models/model_explanation.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1+
from typing import Any, Callable, Dict, List
2+
13
import logging
24
import warnings
3-
from typing import Callable, Dict, List, Any
45

56
import numpy as np
67
import pandas as pd
78

9+
from giskard.core.errors import GiskardImportError
810
from giskard.datasets.base import Dataset
9-
from giskard.models.base import BaseModel
1011
from giskard.ml_worker.utils.logging import timer
11-
from giskard.core.errors import GiskardImportError
12+
from giskard.models.base import BaseModel
1213

1314
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
1415
logger = logging.getLogger(__name__)
@@ -189,6 +190,7 @@ def explain_with_shap(model: BaseModel, dataset: Dataset, only_highest_proba: bo
189190

190191
try:
191192
from shap import Explanation
193+
192194
from giskard.models.shap_result import ShapResult
193195
except ImportError as e:
194196
raise GiskardImportError("shap") from e
@@ -244,8 +246,8 @@ def explain(model: BaseModel, dataset: Dataset, input_data: Dict):
244246
@timer()
245247
def explain_text(model: BaseModel, input_df: pd.DataFrame, text_column: str, text_document: str):
246248
try:
247-
from shap.maskers import Text
248249
from shap import Explainer
250+
from shap.maskers import Text
249251
except ImportError as e:
250252
raise GiskardImportError("shap") from e
251253
try:
@@ -259,7 +261,8 @@ def explain_text(model: BaseModel, input_df: pd.DataFrame, text_column: str, tex
259261
else (shap_values[0].data, shap_values[0].values)
260262
)
261263
except Exception as e:
262-
logger.exception(f"Failed to explain text: {text_document}", e)
264+
logger.error("Failed to explain text %s", text_document)
265+
logger.exception(e)
263266
raise Exception("Failed to create text explanation") from e
264267

265268

@@ -303,9 +306,12 @@ def text_explanation_prediction_wrapper(
303306
) -> Callable:
304307
def text_predict(text_documents: List[str]):
305308
num_documents = len(text_documents)
306-
307-
df_with_text_documents = input_example.append([input_example] * (num_documents - 1), ignore_index=True)
308-
df_with_text_documents[text_column] = pd.DataFrame(text_documents)
309+
df_with_text_documents = (
310+
input_example.copy()
311+
if num_documents == 1
312+
else pd.concat([input_example] * num_documents, ignore_index=True)
313+
)
314+
df_with_text_documents[text_column] = text_documents
309315
return prediction_function(df_with_text_documents)
310316

311317
return text_predict

giskard/scanner/llm/gender_stereotype_detector.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from typing import Sequence
2+
13
import re
24
from pathlib import Path
3-
from typing import Sequence
45

56
import pandas as pd
67
import scipy.stats as stats
@@ -21,12 +22,11 @@ def __init__(self, threshold: float = 0.05):
2122
def run(self, model: LangchainModel, dataset: Dataset) -> Sequence[Issue]:
2223
# @TODO: add Winogender Schemas
2324
df_job = self._read_job_data()
24-
25-
dataset = Dataset(df=df_job.loc[:, ("job",)], column_types={"job": "text"})
25+
read_dataset = Dataset(df=df_job.loc[:, ("job",)], column_types={"job": "text"})
2626
test_model = model.rewrite_prompt(_prompt_template, input_variables=["job"])
2727

2828
# Get model output and count gender-specific pronouns
29-
output = test_model.predict(dataset).prediction
29+
output = test_model.predict(read_dataset).prediction
3030
detected_genders = [detect_gender(sentence) for sentence in output]
3131

3232
df = df_job.copy()
@@ -64,7 +64,7 @@ def run(self, model: LangchainModel, dataset: Dataset) -> Sequence[Issue]:
6464
return [
6565
Issue(
6666
model,
67-
dataset,
67+
read_dataset,
6868
level=IssueLevel.MAJOR,
6969
group=Stereotypes,
7070
description=desc,

giskard/scanner/llm/harmfulness_detector.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from pathlib import Path
21
from typing import List, Sequence
32

3+
from pathlib import Path
4+
45
import pandas as pd
56

67
from ...datasets import Dataset
@@ -27,11 +28,11 @@ def run(self, model: LangchainModel, dataset: Dataset) -> Sequence[Issue]:
2728
)
2829

2930
# Prepare test model and dataset
30-
dataset = Dataset(df=prompts_df, column_types={"text": "text"})
31+
read_dataset = Dataset(df=prompts_df, column_types={"text": "text"})
3132
test_model = model.rewrite_prompt("{text}", input_variables=["text"])
3233

3334
# Run prediction and evaluate toxicity/harmfulness
34-
output = test_model.predict(dataset).prediction
35+
output = test_model.predict(read_dataset).prediction
3536
harmfulness = self._compute_harmfulness(output)
3637

3738
# Filter the examples based
@@ -58,7 +59,7 @@ def run(self, model: LangchainModel, dataset: Dataset) -> Sequence[Issue]:
5859
return [
5960
Issue(
6061
model,
61-
dataset,
62+
read_dataset,
6263
level=IssueLevel.MAJOR,
6364
group=Harmfulness,
6465
description=desc,

giskard/scanner/llm/minority_stereotype_detector.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from typing import List, Optional, Sequence
2+
13
import itertools
24
from pathlib import Path
3-
from typing import List, Optional, Sequence
45

56
import pandas as pd
67

@@ -34,12 +35,12 @@ def run(self, model: LangchainModel, dataset: Dataset) -> Sequence[Issue]:
3435
],
3536
columns=["text", "target"],
3637
)
37-
dataset = Dataset(prompt_df.loc[:, ("text",)], column_types={"text": "text"})
38+
read_dataset = Dataset(prompt_df.loc[:, ("text",)], column_types={"text": "text"})
3839

3940
test_model = model.rewrite_prompt("{text}", input_variables=["text"])
4041

4142
# Generate output and predict score
42-
output = test_model.predict(dataset).prediction
43+
output = test_model.predict(read_dataset).prediction
4344
bias_score = self._compute_bias(output)
4445

4546
examples = pd.DataFrame(
@@ -63,7 +64,7 @@ def run(self, model: LangchainModel, dataset: Dataset) -> Sequence[Issue]:
6364
issues.append(
6465
Issue(
6566
model,
66-
dataset,
67+
read_dataset,
6768
level=IssueLevel.MAJOR,
6869
group=Stereotypes,
6970
meta={
@@ -84,4 +85,4 @@ def _compute_bias(self, sentences: List[str]):
8485
raise LLMImportError() from err
8586

8687
results = Detoxify().predict(list(sentences))
87-
return results["identity_attack"]
88+
return results["identity_attack"]

0 commit comments

Comments (0)