Skip to content

Commit 64a1f66

Browse files
committed
Updated TextNumberToWordTransformation to use row by row language instead of dataset language
1 parent 1cd39ee commit 64a1f66

1 file changed

Lines changed: 17 additions & 17 deletions

File tree

giskard/scanner/robustness/text_transformations.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import random
44
import re
55
from pathlib import Path
6-
from num2words import num2words
76

87
import numpy as np
98
import pandas as pd
9+
from num2words import num2words
1010

1111
from ...core.core import DatasetProcessFunctionMeta
1212
from ...datasets import Dataset
@@ -148,22 +148,6 @@ def make_perturbation(self, text):
148148
return "".join(pieces)
149149

150150

151-
class TextNumberToWordTransformation(TextTransformation):
152-
name = "Transform numbers to words"
153-
154-
def __init__(self, column, lang="en"):
155-
super().__init__(column)
156-
# Target language
157-
self.lang = lang
158-
159-
# Regex to match numbers in text
160-
self._regex = re.compile(r"(?<!\d/)(?<!\d\.)\b\d+(?:\.\d+)?\b(?!(?:\.\d+)?@|\d?/?\d)")
161-
162-
def make_perturbation(self, text):
163-
# Replace numbers with words
164-
return self._regex.sub(lambda x: num2words(x.group(), lang=self.lang), text)
165-
166-
167151
class TextLanguageBasedTransformation(TextTransformation):
168152
needs_dataset = True
169153

@@ -226,6 +210,22 @@ def _switch(self, word, language):
226210
return None
227211

228212

213+
class TextNumberToWordTransformation(TextLanguageBasedTransformation):
214+
name = "Transform numbers to words"
215+
216+
def __init__(self, column, lang="en"):
217+
super().__init__(column)
218+
# Target language
219+
self.lang = lang
220+
221+
# Regex to match numbers in text
222+
self._regex = re.compile(r"(?<!\d/)(?<!\d\.)\b\d+(?:\.\d+)?\b(?!(?:\.\d+)?@|\d?/?\d)")
223+
224+
def make_perturbation(self, row):
225+
# Replace numbers with words
226+
return self._regex.sub(lambda x: num2words(x.group(), lang=row["language__gsk__meta"]), row[self.column])
227+
228+
229229
class TextReligionTransformation(TextLanguageBasedTransformation):
230230
name = "Switch Religion"
231231

0 commit comments

Comments
 (0)