Skip to content

Commit 980ce50

Browse files
authored
Merge pull request #1178 from Giskard-AI/task/GSK-1078
[GSK-1078] Correlation detector
2 parents e0de872 + d4ae7ae commit 980ce50

17 files changed

Lines changed: 420 additions & 71 deletions

python-client/giskard/core/model_validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _do_validate_model(model: BaseModel, validate_ds: Optional[Dataset] = None):
7272
else: # Classification with target = None
7373
validate_model_execution(model, validate_ds)
7474

75-
if model.meta.model_type == SupportedModelTypes.CLASSIFICATION:
75+
if model.meta.model_type == SupportedModelTypes.CLASSIFICATION and validate_ds.target is not None:
7676
validate_order_classifcation_labels(model, validate_ds)
7777

7878

python-client/giskard/scanner/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
".stochasticity.stochasticity_detector",
1313
".calibration.overconfidence_detector",
1414
".calibration.underconfidence_detector",
15+
".correlation.spurious_correlation_detector",
1516
".llm.toxicity_detector",
1617
]
1718

python-client/giskard/scanner/calibration/underconfidence_detector.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
@detector(name="underconfidence", tags=["underconfidence", "classification"])
1616
class UnderconfidenceDetector(LossBasedDetector):
17+
_needs_target = False
18+
1719
def __init__(self, threshold=0.1, p_threshold=0.95, method="tree"):
1820
self.threshold = threshold
1921
self.p_threshold = p_threshold

python-client/giskard/scanner/common/examples.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ def get_examples_dataframe(self, n=3, with_prediction: Union[int, bool] = 1):
2222
examples = dataset.df.copy()
2323

2424
# Keep only interesting columns
25-
cols_to_show = issue.features + [issue.dataset.target]
25+
cols_to_show = issue.features
26+
if issue.dataset.target is not None:
27+
cols_to_show += [issue.dataset.target]
2628
examples = examples.loc[:, cols_to_show]
2729

2830
# If metadata slice, add the metadata column
@@ -54,7 +56,10 @@ def get_examples_dataframe(self, n=3, with_prediction: Union[int, bool] = 1):
5456
else:
5557
pred_examples = model_pred.prediction
5658

57-
examples[f"Predicted `{issue.dataset.target}`"] = pred_examples
59+
predicted_label = "Predicted"
60+
if issue.dataset.target is not None:
61+
predicted_label += f" `{issue.dataset.target}`"
62+
examples[predicted_label] = pred_examples
5863

5964
n = min(len(examples), n)
6065
if n > 0:

python-client/giskard/scanner/common/loss_based_detector.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
from typing import Sequence
55
from abc import abstractmethod
66

7+
from ...slicing.slice_finder import SliceFinder
8+
79
from ..registry import Detector
810

911
from ...models.base import BaseModel
1012
from ...datasets.base import Dataset
11-
from ...slicing.utils import get_slicer
12-
from ...slicing.text_slicer import TextSlicer
13-
from ...slicing.category_slicer import CategorySlicer
1413
from ...ml_worker.testing.registry.slicing_function import SlicingFunction
1514
from ..logger import logger
1615
from ..issues import Issue
@@ -21,7 +20,13 @@ class LossBasedDetector(Detector):
2120
MAX_DATASET_SIZE = 10_000_000
2221
LOSS_COLUMN_NAME = "__gsk__loss"
2322

23+
_needs_target = True
24+
2425
def run(self, model: BaseModel, dataset: Dataset):
26+
if self._needs_target and dataset.target is None:
27+
logger.info(f"{self.__class__.__name__}: Skipping detection because the dataset has no target column.")
28+
return []
29+
2530
logger.info(f"{self.__class__.__name__}: Running")
2631

2732
# Check if we have enough data to run the scan
@@ -48,8 +53,7 @@ def run(self, model: BaseModel, dataset: Dataset):
4853
# Find slices
4954
logger.info(f"{self.__class__.__name__}: Finding data slices")
5055
start = perf_counter()
51-
dataset_to_slice = dataset.select_columns(model.meta.feature_names) if model.meta.feature_names else dataset
52-
slices = self._find_slices(dataset_to_slice, meta)
56+
slices = self._find_slices(model, dataset, meta)
5357
elapsed = perf_counter() - start
5458
logger.info(
5559
f"{self.__class__.__name__}: {len(slices)} slices found (took {datetime.timedelta(seconds=elapsed)})"
@@ -70,7 +74,9 @@ def run(self, model: BaseModel, dataset: Dataset):
7074
def _numerical_slicer_method(self):
7175
return "tree"
7276

73-
def _find_slices(self, dataset: Dataset, meta: pd.DataFrame):
77+
def _find_slices(self, model: BaseModel, dataset: Dataset, meta: pd.DataFrame):
78+
features = model.meta.feature_names or dataset.columns.drop(dataset.target, errors="ignore")
79+
7480
df_with_meta = dataset.df.join(meta, how="right")
7581

7682
column_types = dataset.column_types.copy()
@@ -85,28 +91,10 @@ def _find_slices(self, dataset: Dataset, meta: pd.DataFrame):
8591
# For performance
8692
dataset_with_meta.load_metadata_from_instance(dataset.column_meta)
8793

88-
# Columns by type
89-
cols_by_type = {
90-
type_val: [col for col, col_type in dataset.column_types.items() if col_type == type_val]
91-
for type_val in ["numeric", "category", "text"]
92-
}
93-
94-
# Numerical features
95-
slicer = get_slicer(self._numerical_slicer_method, dataset_with_meta, self.LOSS_COLUMN_NAME)
96-
97-
slices = []
98-
for col in cols_by_type["numeric"]:
99-
slices.extend(slicer.find_slices([col]))
100-
101-
# Categorical features
102-
slicer = CategorySlicer(dataset_with_meta, target=self.LOSS_COLUMN_NAME)
103-
for col in cols_by_type["category"]:
104-
slices.extend(slicer.find_slices([col]))
105-
106-
# Text features
107-
slicer = TextSlicer(dataset_with_meta, target=self.LOSS_COLUMN_NAME, slicer=self._numerical_slicer_method)
108-
for col in cols_by_type["text"]:
109-
slices.extend(slicer.find_slices([col]))
94+
# Find slices
95+
sf = SliceFinder(numerical_slicer=self._numerical_slicer_method)
96+
sliced = sf.run(dataset_with_meta, features, target=self.LOSS_COLUMN_NAME)
97+
slices = sum(sliced.values(), start=[])
11098

11199
# Keep only slices of size at least 5% of the dataset or 20 samples (whatever is larger)
112100
slices = [s for s in slices if max(0.05 * len(dataset), 20) <= len(dataset_with_meta.slice(s))]
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
from dataclasses import dataclass
2+
import pandas as pd
3+
from sklearn.metrics import adjusted_mutual_info_score, mutual_info_score
4+
from scipy import stats
5+
6+
from ..common.examples import ExampleExtractor
7+
from ...ml_worker.testing.registry.slicing_function import SlicingFunction
8+
from ..issues import Issue
9+
from ...slicing.slice_finder import SliceFinder
10+
from ..logger import logger
11+
from ...datasets.base import Dataset
12+
from ...models.base import BaseModel
13+
from ..registry import Detector
14+
from ..decorators import detector
15+
16+
17+
@detector(name="spurious_correlation", tags=["spurious_correlation", "classification"])
18+
class SpuriousCorrelationDetector(Detector):
19+
def __init__(self, method="theil", threshold=0.5) -> None:
20+
self.threshold = threshold
21+
self.method = method
22+
23+
def run(self, model: BaseModel, dataset: Dataset):
24+
logger.info(f"{self.__class__.__name__}: Running")
25+
26+
# Dataset prediction
27+
ds_predictions = pd.Series(model.predict(dataset).prediction, dataset.df.index)
28+
29+
# Keep only interesting features
30+
features = model.meta.feature_names or dataset.columns.drop(dataset.target, errors="ignore")
31+
32+
# Warm up text metadata
33+
for f in features:
34+
if dataset.column_types[f] == "text":
35+
dataset.column_meta[f, "text"]
36+
37+
# Prepare dataset for slicing
38+
df = dataset.df.copy()
39+
if dataset.target is not None:
40+
df.drop(columns=dataset.target, inplace=True)
41+
df["__gsk__target"] = pd.Categorical(ds_predictions)
42+
wdata = Dataset(df, target="__gsk__target", column_types=dataset.column_types)
43+
wdata.load_metadata_from_instance(dataset.column_meta)
44+
45+
# Find slices
46+
sliced_cols = SliceFinder("tree").run(wdata, features, target=wdata.target)
47+
48+
measure_fn, measure_name = self._get_measure_fn()
49+
issues = []
50+
for col, slices in sliced_cols.items():
51+
if not slices:
52+
continue
53+
54+
for slice_fn in slices:
55+
data_slice = dataset.slice(slice_fn)
56+
57+
# Skip small slices
58+
if len(data_slice) < 20 or len(data_slice) < 0.05 * len(dataset):
59+
continue
60+
61+
dx = pd.DataFrame(
62+
{
63+
"feature": dataset.df.index.isin(data_slice.df.index).astype(int),
64+
"prediction": ds_predictions,
65+
},
66+
index=dataset.df.index,
67+
)
68+
dx.dropna(inplace=True)
69+
70+
metric_value = measure_fn(dx.feature, dx.prediction)
71+
logger.info(f"{self.__class__.__name__}: {slice_fn}\tAssociation = {metric_value:.3f}")
72+
73+
if metric_value > self.threshold:
74+
predictions = dx[dx.feature > 0].prediction.value_counts(normalize=True)
75+
info = SpuriousCorrelationInfo(col, slice_fn, metric_value, measure_name, predictions)
76+
issues.append(SpuriousCorrelationIssue(model, dataset, "info", info))
77+
78+
return issues
79+
80+
def _get_measure_fn(self):
81+
if self.method == "theil":
82+
return _theil_u, "Theil's U"
83+
if self.method == "mutual_information" or self.method == "mi":
84+
return _mutual_information, "Mutual information"
85+
if self.method == "cramer":
86+
return _cramer_v, "Cramer's V"
87+
raise ValueError(f"Unknown method `{self.method}`")
88+
89+
90+
def _cramer_v(x, y):
91+
ct = pd.crosstab(x, y)
92+
return stats.contingency.association(ct, method="cramer")
93+
94+
95+
def _mutual_information(x, y):
96+
return adjusted_mutual_info_score(x, y)
97+
98+
99+
def _theil_u(x, y):
100+
return mutual_info_score(x, y) / stats.entropy(pd.Series(y).value_counts(normalize=True))
101+
102+
103+
@dataclass
104+
class SpuriousCorrelationInfo:
105+
feature: str
106+
slice_fn: SlicingFunction
107+
metric_value: float
108+
metric_name: str
109+
predictions: pd.DataFrame
110+
111+
112+
class SpuriousCorrelationIssue(Issue):
113+
group = "Spurious correlation"
114+
115+
@property
116+
def features(self):
117+
return [self.info.feature]
118+
119+
@property
120+
def domain(self) -> str:
121+
return str(self.info.slice_fn)
122+
123+
@property
124+
def metric(self) -> str:
125+
return f"Nominal association ({self.info.metric_name})"
126+
127+
@property
128+
def deviation(self) -> str:
129+
plabel, p = self.info.predictions.index[0], self.info.predictions.iloc[0]
130+
131+
return f"Prediction {self.dataset.target} = `{plabel}` for {p * 100:.2f}% of samples in the slice"
132+
133+
@property
134+
def slicing_fn(self):
135+
return self.info.slice_fn
136+
137+
@property
138+
def description(self) -> str:
139+
pred = self.model.predict(self.dataset.slice(self.info.slice_fn)).prediction
140+
classes = pd.Series(pred).value_counts(normalize=True)
141+
plabel, p = classes.index[0], classes.iloc[0]
142+
return f"Data slice {self.info.slice_fn} seems to be highly associated to prediction {self.dataset.target} = `{plabel}` ({p * 100:.2f}% of predictions in the data slice)."
143+
144+
# @lru_cache
145+
def examples(self, n=3):
146+
extractor = ExampleExtractor(self)
147+
return extractor.get_examples_dataframe(n, with_prediction=1)
148+
149+
@property
150+
def importance(self) -> float:
151+
return self.info.metric_value

python-client/giskard/scanner/result.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,13 @@ def _repr_html_(self):
4646
issues=self.issues,
4747
issues_by_group=issues_by_group,
4848
num_major_issues={
49-
group: len([i for i in issues if i.is_major]) for group, issues in issues_by_group.items()
49+
group: len([i for i in issues if i.level == "major"]) for group, issues in issues_by_group.items()
5050
},
5151
num_medium_issues={
52-
group: len([i for i in issues if not i.is_major]) for group, issues in issues_by_group.items()
52+
group: len([i for i in issues if i.level == "medium"]) for group, issues in issues_by_group.items()
53+
},
54+
num_info_issues={
55+
group: len([i for i in issues if i.level == "info"]) for group, issues in issues_by_group.items()
5356
},
5457
)
5558

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
<tr class="gsk-issue text-sm group peer text-left cursor-pointer hover:bg-zinc-700">
2+
<td class="p-3">
3+
<code class="mono text-blue-300">
4+
{{issue.domain|replace("&", "<br>")|safe}}
5+
</code>
6+
</td>
7+
<td class="p-3">
8+
{{ issue.metric }} = {{ issue.info.metric_value|format_metric }}
9+
</td>
10+
<td class="p-3 text-amber-200">
11+
{{ issue.deviation }}
12+
</td>
13+
<td class="p-3">
14+
<span class="text-gray-400">
15+
<!-- {{ issue.description }} -->
16+
</span>
17+
</td>
18+
<td class="p-3 text-xs text-right space-x-1">
19+
<a href="#"
20+
class="gsk-issue-detail-btn inline-block group-[.open]:hidden border border-zinc-100/50 text-zinc-100/90 hover:bg-zinc-500 hover:border-zinc-500 hover:text-white px-2 py-0.5 rounded-sm">Show details</a>
21+
<a href="#"
22+
class="hidden group-[.open]:inline-block gsk-issue-detail-btn border border-zinc-500 text-zinc-100/90 bg-zinc-500 hover:bg-zinc-400 hover:text-white px-2 py-0.5 rounded-sm">Hide details</a>
23+
</td>
24+
</tr>
25+
<tr class="gsk-issue-detail text-left collapse peer-[.open]:visible border-b border-zinc-500 bg-zinc-700">
26+
<td colspan="5" class="p-3">
27+
<h4 class="font-bold text-sm">Description</h4>
28+
{{ issue.description | safe }}
29+
30+
{% if issue.examples(3)|length %}
31+
<h4 class="font-bold text-sm mt-4">Examples</h4>
32+
<div class="text-white max-w-xl text-sm overflow-scroll" style="max-width: 920px">
33+
{{ issue.examples(3).to_html(notebook=True) | replace("\\n", "<br>") | safe }}
34+
</div>
35+
{% endif %}
36+
</td>
37+
</tr>

python-client/giskard/scanner/templates/_issues_table.html

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
{% include "_issues/stochasticity.html" %}
1717
{% elif issue.__class__.__name__ == 'LLMToxicityIssue' %}
1818
{% include "_issues/llm_toxicity.html" %}
19+
{% elif issue.__class__.__name__ == 'SpuriousCorrelationIssue' %}
20+
{% include "_issues/spurious_correlation.html" %}
1921
{% else %}
2022
{% include "_issues/default.html" %}
2123
{% endif %}

python-client/giskard/scanner/templates/_main_content.html

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@
8585
Your model seems to exhibit offensive behaviour when we use adversarial “Do Anything Now” (DAN)
8686
prompts.
8787
</p>
88+
{% elif issues[0].__class__.__name__ == "SpuriousCorrelationIssue" %}
89+
<p>
90+
We found potential spurious correlations between your data and the model predictions. Spurious
91+
correlations may occur when the model overfits on relations that are not causal. We recommend that you
92+
verify the causal relationship between the detected data slices and the target variable.
93+
</p>
8894
{% else %}
8995
<p>Found issues for {{ issues[0].group }}</p>
9096
{% endif %}
@@ -101,6 +107,10 @@ <h2 class="uppercase my-4 mr-2 font-medium">Issues</h2>
101107
<span class="text-xs border rounded px-1 uppercase text-amber-200 border-amber-200">{{num_medium_issues[group]}}
102108
medium</span>
103109
{% endif %}
110+
{% if num_info_issues[group] > 0 %}
111+
<span class="text-xs border rounded px-1 uppercase text-blue-200 border-blue-200">{{num_info_issues[group]}}
112+
info</span>
113+
{% endif %}
104114
</div>
105115

106116
{% include "_issues_table.html" %}

0 commit comments

Comments (0)