|
1 | | -import spacy |
| 1 | +import nltk |
| 2 | +from nltk.tokenize import word_tokenize |
| 3 | +from nltk.corpus import stopwords |
2 | 4 | from instructor.retry import InstructorRetryException |
3 | 5 |
|
4 | 6 |
|
@@ -37,16 +39,9 @@ def __init__(self, strategy, llm_handler, history, schemas, response_handler): |
37 | 39 | self.endpoint_found_methods = {} |
38 | 40 | model_name = "en_core_web_sm" |
39 | 41 |
|
40 | | - # Check if the model is already installed |
41 | | - from spacy.util import is_package |
42 | | - if not is_package(model_name): |
43 | | - print(f"Model '{model_name}' is not installed. Installing now...") |
44 | | - spacy.cli.download(model_name) |
45 | | - |
46 | | - # Load the model |
47 | | - self.nlp = spacy.load(model_name) |
48 | | - |
49 | | - self.nlp = spacy.load("en_core_web_sm") |
| 42 | + # Check if the models are already installed |
| 43 | + nltk.download('punkt') |
| 44 | + nltk.download('stopwords') |
50 | 45 | self._prompt_history = history |
51 | 46 | self.prompt = self._prompt_history |
52 | 47 | self.previous_prompt = self._prompt_history[self.round]["content"] |
@@ -199,20 +194,19 @@ def chain_of_thought(self, doc=False, hint=""): |
199 | 194 |
|
def token_count(self, text):
    """
    Counts the number of word tokens in the provided text using NLTK's tokenizer.

    Punctuation-only tokens (e.g. ',', '...') are excluded from the count,
    mirroring the previous spaCy-based behaviour of skipping ``is_punct``
    tokens.

    Args:
        text (str): The input text to tokenize and count.

    Returns:
        int: The number of word tokens in the input text.
    """
    # Tokenize the text using NLTK.
    # NOTE(review): newer NLTK (>= 3.8.2) requires the 'punkt_tab' data
    # package for word_tokenize — confirm __init__ downloads it, not just
    # 'punkt'.
    tokens = word_tokenize(text)
    # Keep any token containing at least one alphanumeric character.
    # Using `token.isalnum()` here would wrongly drop contractions such as
    # "n't" (word_tokenize splits "don't" into "do" + "n't"), hyphenated
    # words ("co-op"), and decimals ("3.5") — none of which are punctuation.
    words = [token for token in tokens if any(ch.isalnum() for ch in token)]
    return len(words)
216 | 210 |
|
217 | 211 |
|
218 | 212 | def check_prompt(self, previous_prompt, chain_of_thought_steps, max_tokens=900): |
|
0 commit comments