|
| 1 | +from dataclasses import dataclass |
| 2 | +import pandas as pd |
| 3 | +from sklearn.metrics import adjusted_mutual_info_score, mutual_info_score |
| 4 | +from scipy import stats |
| 5 | + |
| 6 | +from ..common.examples import ExampleExtractor |
| 7 | +from ...ml_worker.testing.registry.slicing_function import SlicingFunction |
| 8 | +from ..issues import Issue |
| 9 | +from ...slicing.slice_finder import SliceFinder |
| 10 | +from ..logger import logger |
| 11 | +from ...datasets.base import Dataset |
| 12 | +from ...models.base import BaseModel |
| 13 | +from ..registry import Detector |
| 14 | +from ..decorators import detector |
| 15 | + |
| 16 | + |
@detector(name="spurious_correlation", tags=["spurious_correlation", "classification"])
class SpuriousCorrelationDetector(Detector):
    """Detects data slices whose membership is strongly associated with the model's predictions.

    The dataset is sliced feature by feature with a tree-based slice finder
    trained against the model's own predictions; for each resulting slice, the
    nominal association between slice membership (a binary indicator) and the
    predicted label is measured. Slices whose association exceeds
    ``threshold`` are reported as :class:`SpuriousCorrelationIssue`.
    """

    # Slices smaller than either bound are skipped: an absolute row floor and
    # a minimum fraction of the full dataset.
    _MIN_SLICE_SIZE = 20
    _MIN_SLICE_FRACTION = 0.05

    def __init__(self, method: str = "theil", threshold: float = 0.5) -> None:
        """
        Parameters
        ----------
        method : str
            Association measure: ``"theil"`` (Theil's U),
            ``"mutual_information"``/``"mi"`` (adjusted mutual information),
            or ``"cramer"`` (Cramér's V).
        threshold : float
            Minimum association value for a slice to be reported as an issue.
        """
        self.threshold = threshold
        self.method = method

    def run(self, model: BaseModel, dataset: Dataset):
        """Scan `dataset` and return a list of spurious-correlation issues."""
        logger.info(f"{self.__class__.__name__}: Running")

        # Predictions over the whole dataset, aligned on the dataframe index.
        ds_predictions = pd.Series(model.predict(dataset).prediction, dataset.df.index)

        # Keep only interesting features (fall back to all non-target columns).
        features = model.meta.feature_names or dataset.columns.drop(dataset.target, errors="ignore")

        # Warm up text metadata so slicing on text columns finds it precomputed.
        for f in features:
            if dataset.column_types[f] == "text":
                dataset.column_meta[f, "text"]

        # Build a working dataset whose target is the model's own prediction,
        # so the slice finder searches for prediction-homogeneous regions.
        df = dataset.df.copy()
        if dataset.target is not None:
            df.drop(columns=dataset.target, inplace=True)
        df["__gsk__target"] = pd.Categorical(ds_predictions)
        wdata = Dataset(df, target="__gsk__target", column_types=dataset.column_types)
        wdata.load_metadata_from_instance(dataset.column_meta)

        # Find candidate slices per feature column.
        sliced_cols = SliceFinder("tree").run(wdata, features, target=wdata.target)

        measure_fn, measure_name = self._get_measure_fn()
        issues = []
        for col, slices in sliced_cols.items():
            if not slices:
                continue

            for slice_fn in slices:
                data_slice = dataset.slice(slice_fn)

                # Skip small slices — associations there are too noisy.
                if (
                    len(data_slice) < self._MIN_SLICE_SIZE
                    or len(data_slice) < self._MIN_SLICE_FRACTION * len(dataset)
                ):
                    continue

                # Binary slice-membership indicator vs. prediction, NaNs dropped.
                dx = pd.DataFrame(
                    {
                        "feature": dataset.df.index.isin(data_slice.df.index).astype(int),
                        "prediction": ds_predictions,
                    },
                    index=dataset.df.index,
                )
                dx.dropna(inplace=True)

                metric_value = measure_fn(dx.feature, dx.prediction)
                logger.info(f"{self.__class__.__name__}: {slice_fn}\tAssociation = {metric_value:.3f}")

                if metric_value > self.threshold:
                    # Normalized distribution of predictions inside the slice.
                    predictions = dx[dx.feature > 0].prediction.value_counts(normalize=True)
                    info = SpuriousCorrelationInfo(col, slice_fn, metric_value, measure_name, predictions)
                    issues.append(SpuriousCorrelationIssue(model, dataset, "info", info))

        return issues

    def _get_measure_fn(self):
        """Return a ``(measure_fn, display_name)`` pair for the configured method.

        Raises
        ------
        ValueError
            If ``self.method`` is not one of the supported measures.
        """
        measures = {
            "theil": (_theil_u, "Theil's U"),
            "mutual_information": (_mutual_information, "Mutual information"),
            "mi": (_mutual_information, "Mutual information"),
            "cramer": (_cramer_v, "Cramer's V"),
        }
        try:
            return measures[self.method]
        except KeyError:
            raise ValueError(f"Unknown method `{self.method}`") from None
| 88 | + |
| 89 | + |
| 90 | +def _cramer_v(x, y): |
| 91 | + ct = pd.crosstab(x, y) |
| 92 | + return stats.contingency.association(ct, method="cramer") |
| 93 | + |
| 94 | + |
def _mutual_information(x, y):
    """Adjusted mutual information between two label assignments.

    Chance-corrected: close to 0 for independent labelings, 1 for identical ones.
    """
    score = adjusted_mutual_info_score(x, y)
    return score
| 97 | + |
| 98 | + |
| 99 | +def _theil_u(x, y): |
| 100 | + return mutual_info_score(x, y) / stats.entropy(pd.Series(y).value_counts(normalize=True)) |
| 101 | + |
| 102 | + |
@dataclass
class SpuriousCorrelationInfo:
    """Details of one detected spurious association, carried by a SpuriousCorrelationIssue."""

    # Column the slicing function was built on.
    feature: str
    # Slicing function whose membership correlates with the predictions.
    slice_fn: SlicingFunction
    # Measured nominal association strength.
    metric_value: float
    # Display name of the measure used (e.g. "Theil's U").
    metric_name: str
    # Normalized prediction distribution inside the slice, produced by
    # `value_counts(normalize=True)` — a Series, not a DataFrame.
    predictions: pd.Series
| 110 | + |
| 111 | + |
class SpuriousCorrelationIssue(Issue):
    """Issue reporting a data slice highly associated with the model's predictions.

    Presentation wrapper around a SpuriousCorrelationInfo (exposed as
    ``self.info``); ``self.model`` and ``self.dataset`` are presumably set by
    the ``Issue`` base class constructor — TODO confirm.
    """

    group = "Spurious correlation"

    @property
    def features(self):
        # Single feature the slicing function operates on.
        return [self.info.feature]

    @property
    def domain(self) -> str:
        # Human-readable rendering of the slice condition.
        return str(self.info.slice_fn)

    @property
    def metric(self) -> str:
        # Which association measure produced the metric value.
        return f"Nominal association ({self.info.metric_name})"

    @property
    def deviation(self) -> str:
        # `predictions` is a value_counts(normalize=True) Series, sorted by
        # descending frequency, so position 0 is the majority predicted label
        # in the slice and its share.
        plabel, p = self.info.predictions.index[0], self.info.predictions.iloc[0]

        return f"Prediction {self.dataset.target} = `{plabel}` for {p * 100:.2f}% of samples in the slice"

    @property
    def slicing_fn(self):
        # Slicing function used to materialize the problematic data slice.
        return self.info.slice_fn

    @property
    def description(self) -> str:
        # Re-runs the model on the slice to report the majority predicted
        # class and its share within the slice.
        pred = self.model.predict(self.dataset.slice(self.info.slice_fn)).prediction
        classes = pd.Series(pred).value_counts(normalize=True)
        plabel, p = classes.index[0], classes.iloc[0]
        return f"Data slice {self.info.slice_fn} seems to be highly associated to prediction {self.dataset.target} = `{plabel}` ({p * 100:.2f}% of predictions in the data slice)."

    # NOTE(review): was marked `# @lru_cache`; caching a method with lru_cache
    # would keep every issue instance alive for the cache's lifetime, so it is
    # deliberately left uncached.
    def examples(self, n=3):
        # Representative rows from the slice, with model predictions attached.
        extractor = ExampleExtractor(self)
        return extractor.get_examples_dataframe(n, with_prediction=1)

    @property
    def importance(self) -> float:
        # Association strength; presumably used to rank issues — TODO confirm.
        return self.info.metric_value
0 commit comments