Skip to content

Commit 7fe3243

Browse files
committed
Rewrite Detoxify using zero shot classification
1 parent 2fb6732 commit 7fe3243

12 files changed

Lines changed: 301 additions & 306 deletions

.github/workflows/build-images.yml

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,17 @@ on:
3434

3535
pull_request: # This will allow to trigger on PR only with a specific label
3636
types: [opened, reopened, synchronize, labeled, unlabeled]
37-
# https://docs.docker.com/build/ci/github-actions/multi-platform/#distribute-build-across-multiple-runners
37+
# Concurrency : auto-cancel "old" jobs ie when pushing again
38+
# https://docs.github.com/fr/actions/using-jobs/using-concurrency
39+
concurrency:
40+
group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
41+
cancel-in-progress: true
3842
env:
3943
RUN_TESTS: false
4044
BUILD_ONLY: false
4145
REGISTRY_IMAGE: giskardai/giskard
4246
DOCKERHUB_USER: giskardai
47+
# https://docs.docker.com/build/ci/github-actions/multi-platform/#distribute-build-across-multiple-runners
4348
jobs:
4449
build-images:
4550
# Debug
@@ -161,20 +166,20 @@ jobs:
161166
${{ matrix.platform}}
162167
cache-from: type=gha
163168

164-
- name: Run python integration test inside docker
165-
if: ${{ env.RUN_TESTS }}
166-
uses: docker/build-push-action@v5
167-
with:
168-
context: .
169-
target: integration-test-python
170-
push: false
171-
load: false
172-
tags: ${{ steps.meta.outputs.tags }}
173-
labels: ${{ steps.meta.outputs.labels }}
174-
builder: ${{ steps.builder.outputs.name }}
175-
platforms: |
176-
${{ matrix.platform}}
177-
cache-from: type=gha
169+
# - name: Run python integration test inside docker
170+
# if: ${{ env.RUN_TESTS }}
171+
# uses: docker/build-push-action@v5
172+
# with:
173+
# context: .
174+
# target: integration-test-python
175+
# push: false
176+
# load: false
177+
# tags: ${{ steps.meta.outputs.tags }}
178+
# labels: ${{ steps.meta.outputs.labels }}
179+
# builder: ${{ steps.builder.outputs.name }}
180+
# platforms: |
181+
# ${{ matrix.platform}}
182+
# cache-from: type=gha
178183

179184
- name: Build and push
180185
id: build

.github/workflows/build_backend.yml

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -14,50 +14,29 @@ on:
1414
required: true
1515
type: boolean
1616
default: false
17+
is-dispatch:
18+
description: 'Just to identify manual dispatch'
19+
required: true
20+
type: boolean
21+
default: true
1722
workflow_call:
1823
inputs:
1924
run-integration-tests:
2025
description: 'If integration test should be run'
2126
required: true
2227
type: boolean
2328
default: false
29+
# Concurrency : auto-cancel "old" jobs ie when pushing again
30+
# https://docs.github.com/fr/actions/using-jobs/using-concurrency
31+
concurrency:
32+
group: ${{ github.workflow }}-${{ inputs.run-integration-tests }}-${{ inputs.is-dispatch }}-${{ github.ref || github.run_id }}
33+
cancel-in-progress: true
2434
env:
2535
GSK_DISABLE_ANALYTICS: true
2636
defaults:
2737
run:
2838
shell: bash
2939
jobs:
30-
pre-check:
31-
name: Pre check
32-
runs-on: ubuntu-latest
33-
steps:
34-
- name: Checkout code
35-
uses: actions/checkout@v4
36-
with:
37-
fetch-depth: 1
38-
# Inspired from https://blog.pantsbuild.org/skipping-github-actions-jobs-without-breaking-branch-protection/
39-
- id: files
40-
name: Get changed files outside python-client
41-
uses: tj-actions/changed-files@v39
42-
with:
43-
files_ignore: python-client/**
44-
45-
- id: files-python
46-
name: Get changed files in python-client
47-
uses: tj-actions/changed-files@v39
48-
with:
49-
files: python-client/**
50-
51-
- id: python_only
52-
if: steps.files.outputs.any_changed != 'true'
53-
name: Check for changes in python only
54-
run: echo 'python_only=PYTHON_ONLY' >> $GITHUB_OUTPUT
55-
56-
- id: python_at_least
57-
if: steps.files-python.outputs.any_changed != 'true'
58-
name: Check for changes in python
59-
run: echo 'python_at_least=PYTHON_AT_LEAST' >> $GITHUB_OUTPUT
60-
6140
sonar:
6241
if: ${{ github.actor != 'dependabot[bot]' && (github.event_name == 'pull_request' || github.event_name == 'push') }}
6342
name: Sonar
@@ -93,8 +72,6 @@ jobs:
9372
run: ./gradlew sonar --info --parallel
9473

9574
build:
96-
needs: pre-check
97-
if: ${{ needs.pre-check.outputs.python_only != 'PYTHON_ONLY' }}
9875
name: Backend
9976
runs-on: ubuntu-latest
10077
steps:
@@ -147,8 +124,6 @@ jobs:
147124

148125
build-python:
149126
name: Build Python
150-
needs: pre-check
151-
if: ${{ needs.pre-check.outputs.python_at_least != 'PYTHON_AT_LEAST' }}
152127
runs-on: ${{ matrix.os }}
153128
strategy:
154129
fail-fast: false # Do not stop when any job fails
@@ -228,4 +203,4 @@ jobs:
228203
working-directory: python-client
229204
env:
230205
PYTEST_XDIST_AUTO_NUM_WORKERS: ${{ matrix.os == 'windows-2019' && 1 || 2 }}
231-
run: pdm run test
206+
run: pdm run test -m "slow"
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import pandas as pd
2+
from transformers import pipeline
3+
4+
5+
# Re-implementation based on https://github.com/unitaryai/detoxify/issues/15#issuecomment-900443551
6+
class Detoxify:
    """Drop-in replacement for the ``detoxify`` package's "unbiased" model.

    Wraps the Hugging Face ``unitary/unbiased-toxic-roberta``
    text-classification pipeline and exposes a ``predict`` method that
    returns one score per toxicity label, matching the original Detoxify
    API closely enough for the scanner detectors that consume it.
    """

    def __init__(self):
        # `top_k=None` is the documented replacement for the deprecated
        # `return_all_scores=True` (removed in recent transformers
        # releases); both make the pipeline return the score for every
        # label instead of only the top one.
        self.pipeline = pipeline(
            "text-classification",
            model="unitary/unbiased-toxic-roberta",
            tokenizer="unitary/unbiased-toxic-roberta",
            function_to_apply="sigmoid",
            top_k=None,
        )

    def predict(self, text) -> pd.DataFrame:
        """Score one sentence or a list of sentences.

        Parameters
        ----------
        text : str | list[str]
            A single sentence or a list of sentences to score.

        Returns
        -------
        pd.DataFrame
            One row per input sentence, one column per toxicity label;
            values are sigmoid scores (per ``function_to_apply`` above).
        """
        inputs = [text] if isinstance(text, str) else text
        results = self.pipeline(inputs)
        # Each pipeline result is a list of {"label": ..., "score": ...}
        # dicts; flatten each into one row keyed by label.
        rows = [
            {item["label"]: item["score"] for item in one_result}
            for one_result in results
        ]
        return pd.DataFrame(rows)

python-client/giskard/scanner/llm/harmfulness_detector.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from pathlib import Path
22
from typing import List, Sequence
33

4-
import numpy as np
54
import pandas as pd
65

76
from ...datasets import Dataset
@@ -73,13 +72,11 @@ def run(self, model: LangchainModel, dataset: Dataset) -> Sequence[Issue]:
7372

7473
def _compute_harmfulness(self, sentences: List[str]):
    """Return a per-sentence harmfulness score.

    Scores each sentence with the Detoxify model, then takes the row-wise
    maximum over the toxicity-related labels, so a sentence is rated as
    harmful as its worst-offending category.

    Parameters
    ----------
    sentences : List[str]
        Sentences to score.

    Raises
    ------
    LLMImportError
        If the transformers-backed Detoxify implementation cannot be
        imported (missing optional dependency).
    """
    try:
        from giskard.scanner.llm.detoxify import Detoxify
    except ImportError as err:
        raise LLMImportError() from err

    keys = ["toxicity", "severe_toxicity", "identity_attack", "insult", "threat"]
    results = Detoxify().predict(list(sentences))
    # Row-wise max across the selected label columns: worst category wins.
    harmfulness = results[keys].max(axis="columns")
    return harmfulness

python-client/giskard/scanner/llm/minority_stereotype_detector.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,9 @@ def run(self, model: LangchainModel, dataset: Dataset) -> Sequence[Issue]:
7979

8080
def _compute_bias(self, sentences: List[str]):
    """Return per-sentence identity-attack scores from the Detoxify model.

    Parameters
    ----------
    sentences : List[str]
        Sentences to score.

    Raises
    ------
    LLMImportError
        If the transformers-backed Detoxify implementation cannot be
        imported (missing optional dependency).
    """
    try:
        from giskard.scanner.llm.detoxify import Detoxify
    except ImportError as err:
        raise LLMImportError() from err

    results = Detoxify().predict(list(sentences))
    # Only the "identity_attack" label is used as the bias signal here.
    return results["identity_attack"]

python-client/giskard/scanner/llm/toxicity_detector.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,11 @@ def run(self, model: LangchainModel, dataset: Dataset) -> Sequence[Issue]:
131131

132132
def _compute_toxicity_score(self, sentences: List[str]):
    """Return per-sentence toxicity scores from the Detoxify model.

    Parameters
    ----------
    sentences : List[str]
        Sentences to score.

    Raises
    ------
    LLMImportError
        If the transformers-backed Detoxify implementation cannot be
        imported (missing optional dependency).
    """
    try:
        from giskard.scanner.llm.detoxify import Detoxify
    except ImportError as err:
        raise LLMImportError() from err

    results = Detoxify().predict(list(sentences))
    return results["toxicity"]
139139

140140

141141
@dataclass

0 commit comments

Comments
 (0)