Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public MLWorkerIllegalReplyException(String errorString) {
}

public MLWorkerIllegalReplyException(MLWorkerWSErrorDTO error) {
super(error.getErrorStr());
super(error.getErrorStr() + "\n" + error.getDetail());
this.errorType = error.getErrorType() + ": " + error.getErrorStr();
this.errorString = error.getDetail();
}
Expand Down
2 changes: 1 addition & 1 deletion python-client/giskard/ml_worker/websocket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class ExplainText(WorkerReply):
class ExplainTextParam(BaseModel):
model: ArtifactRef
feature_name: str
columns: Dict[str, str]
columns: Dict[str, Optional[str]]
column_types: Dict[str, str]


Expand Down
26 changes: 14 additions & 12 deletions python-client/giskard/ml_worker/websocket/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def run_model(ml_worker: MLWorker, params: websocket.RunModelParam, *args, **kwa

@websocket_actor(MLWorkerAction.runModelForDataFrame)
def run_model_for_data_frame(
ml_worker: MLWorker, params: websocket.RunModelForDataFrameParam, *args, **kwargs
ml_worker: MLWorker, params: websocket.RunModelForDataFrameParam, *args, **kwargs
) -> websocket.RunModelForDataFrame:
model = BaseModel.download(ml_worker.client, params.model.project_key, params.model.id)
df = pd.DataFrame.from_records([r.columns for r in params.dataframe.rows])
Expand Down Expand Up @@ -440,19 +440,23 @@ def get_catalog(*args, **kwargs) -> websocket.Catalog:

@websocket_actor(MLWorkerAction.datasetProcessing)
def dataset_processing(
ml_worker: MLWorker, params: websocket.DatasetProcessingParam, *args, **kwargs
ml_worker: MLWorker, params: websocket.DatasetProcessingParam, *args, **kwargs
) -> websocket.DatasetProcessing:
dataset = Dataset.download(ml_worker.client, params.dataset.project_key, params.dataset.id, params.dataset.sample)

for function in params.functions:
arguments = parse_function_arguments(ml_worker, function.arguments)
if function.slicingFunction:
dataset.add_slicing_function(
SlicingFunction.download(function.slicingFunction.id, ml_worker.client, None)(**arguments)
SlicingFunction.download(
function.slicingFunction.id, ml_worker.client, function.slicingFunction.project_key
)(**arguments)
)
else:
dataset.add_transformation_function(
TransformationFunction.download(function.transformationFunction.id, ml_worker.client, None)(**arguments)
TransformationFunction.download(
function.transformationFunction.id, ml_worker.client, function.slicingFunction.project_key
)(**arguments)
)

result = dataset.process()
Expand Down Expand Up @@ -480,7 +484,7 @@ def dataset_processing(

@websocket_actor(MLWorkerAction.runAdHocTest)
def run_ad_hoc_test(
ml_worker: MLWorker, params: websocket.RunAdHocTestParam, *args, **kwargs
ml_worker: MLWorker, params: websocket.RunAdHocTestParam, *args, **kwargs
) -> websocket.RunAdHocTest:
test: GiskardTest = GiskardTest.download(params.testUuid, ml_worker.client, None)

Expand Down Expand Up @@ -553,7 +557,7 @@ def run_test_suite(ml_worker: MLWorker, params: websocket.TestSuiteParam, *args,

@websocket_actor(MLWorkerAction.generateTestSuite)
def generate_test_suite(
ml_worker: MLWorker, params: websocket.GenerateTestSuiteParam, *args, **kwargs
ml_worker: MLWorker, params: websocket.GenerateTestSuiteParam, *args, **kwargs
) -> websocket.GenerateTestSuite:
inputs = [map_suite_input_ws(i) for i in params.inputs]

Expand All @@ -579,9 +583,7 @@ def echo(params: websocket.EchoMsg, *args, **kwargs) -> websocket.EchoMsg:


@websocket_actor(MLWorkerAction.getPush)
def get_push(
ml_worker: MLWorker, params: websocket.GetPushParam, *args, **kwargs
) -> websocket.GetPushResponse:
def get_push(ml_worker: MLWorker, params: websocket.GetPushParam, *args, **kwargs) -> websocket.GetPushResponse:
object_uuid = ""
object_params = {}
project_key = params.model.project_key
Expand Down Expand Up @@ -648,8 +650,8 @@ def get_push(

# Upload related object depending on CTA type
if (
params.cta_kind == CallToActionKind.CREATE_SLICE
or params.cta_kind == CallToActionKind.CREATE_SLICE_OPEN_DEBUGGER
params.cta_kind == CallToActionKind.CREATE_SLICE
or params.cta_kind == CallToActionKind.CREATE_SLICE_OPEN_DEBUGGER
):
push.slicing_function.meta.tags.append("generated")
object_uuid = push.slicing_function.upload(ml_worker.client)
Expand All @@ -676,7 +678,7 @@ def get_push(
if object_uuid != "":
logger.info(f"Uploaded object for CTA with uuid: {object_uuid}")

if object_uuid != '':
if object_uuid != "":
return websocket.GetPushResponse(
contribution=contrib_ws,
perturbation=perturb_ws,
Expand Down
4 changes: 2 additions & 2 deletions python-client/giskard/ml_worker/websocket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def map_function_meta_ws(callable_type):
default=str(a.default),
argOrder=a.argOrder,
)
for a in test.args.values()
for a in (test.args.values() if test.args else []) # args could be None
],
)
for test in tests_registry.get_all().values()
Expand Down Expand Up @@ -145,7 +145,7 @@ def map_dataset_process_function_meta_ws(callable_type):
default=str(a.default),
argOrder=a.argOrder,
)
for a in test.args.values()
for a in (test.args.values() if test.args else []) # args could be None
],
cellLevel=test.cell_level,
columnType=test.column_type,
Expand Down