|
| 1 | +from typing import Tuple |
| 2 | + |
| 3 | +from pathlib import Path |
| 4 | + |
| 5 | +import pandas as pd |
| 6 | + |
| 7 | +from giskard import Dataset, Model |
| 8 | + |
# Public URL of the IPCC AR6 Synthesis Report (longer report) PDF used as the
# knowledge base for retrieval.
IPCC_REPORT_URL = "https://www.ipcc.ch/report/ar6/syr/downloads/report/IPCC_AR6_SYR_LongerReport.pdf"

# OpenAI completion model used by the QA chain.
LLM_NAME = "gpt-3.5-turbo-instruct"

# Name of the dataset column that holds the user questions.
TEXT_COLUMN_NAME = "query"

# Prompt sent to the LLM. `{context}` is filled with retrieved report excerpts
# and `{question}` with the user query (see the PromptTemplate input_variables
# in ippc_model_and_dataset).
PROMPT_TEMPLATE = """You are the Climate Assistant, a helpful AI assistant made by Giskard.
Your task is to answer common questions on climate change.
You will be given a question and relevant excerpts from the IPCC Climate Change Synthesis Report (2023).
Please provide short and clear answers based on the provided context. Be polite and helpful.

Context:
{context}

Question:
{question}

Your answer:
"""
| 28 | + |
| 29 | + |
def ippc_model_and_dataset() -> Tuple[Model, Dataset]:
    """Build the climate-QA RAG model and a small validation dataset.

    Returns:
        A tuple ``(giskard_model, giskard_dataset)`` where the model wraps a
        LangChain RetrievalQA chain over the IPCC report and the dataset holds
        a couple of sample questions for scan validation.
    """
    from langchain import FAISS, OpenAI, PromptTemplate
    from langchain.chains import RetrievalQA, load_chain
    from langchain.chains.base import Chain
    from langchain.document_loaders import PyPDFLoader
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    # Custom Giskard wrapper so the FAISS-backed chain can be (de)serialized.
    class FAISSRAGModel(Model):
        def model_predict(self, df: pd.DataFrame) -> pd.DataFrame:
            # Run the QA chain once per question in the input dataframe.
            return df[TEXT_COLUMN_NAME].apply(lambda question: self.model.run({"query": question}))

        def save_model(self, path: str, *args, **kwargs):
            destination = Path(path)
            # Persist the chain configuration itself...
            self.model.save(destination.joinpath("model.json"))
            # ...and the FAISS vector index that backs its retriever.
            vector_store = self.model.retriever.vectorstore
            vector_store.save_local(destination.joinpath("faiss"))

        @classmethod
        def load_model(cls, path: str, *args, **kwargs) -> Chain:
            source = Path(path)
            # Restore the FAISS index first, then rebuild the chain around it.
            vector_store = FAISS.load_local(source.joinpath("faiss"), OpenAIEmbeddings())
            return load_chain(source.joinpath("model.json"), retriever=vector_store.as_retriever())

    def get_context_storage() -> FAISS:
        """Initialize a vector storage of embedded IPCC report chunks (context)."""
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100, add_start_index=True)
        report_chunks = PyPDFLoader(IPCC_REPORT_URL).load_and_split(splitter)
        return FAISS.from_documents(report_chunks, OpenAIEmbeddings())

    # Assemble the retrieval-augmented QA chain.
    llm = OpenAI(model=LLM_NAME, temperature=0)
    prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["question", "context"])
    climate_qa_chain = RetrievalQA.from_llm(llm=llm, retriever=get_context_storage().as_retriever(), prompt=prompt)

    # Wrap the chain for giskard scanning.
    giskard_model = FAISSRAGModel(
        model=climate_qa_chain,  # The chain encapsulating retrieval and generation, executed by the scan.
        model_type="text_generation",  # Either regression, classification or text_generation.
        name="Climate Change Question Answering",  # Optional.
        description="This model answers any question about climate change based on IPCC reports",  # Used to generate prompts during the scan.
        feature_names=[TEXT_COLUMN_NAME],  # Default: all columns of your dataset.
    )

    # Optional: sample input prompts to validate the model wrapping and to
    # narrow specific tests' queries during the scan.
    giskard_dataset = Dataset(
        pd.DataFrame(
            {
                TEXT_COLUMN_NAME: [
                    "According to the IPCC report, what are key risks in the Europe?",
                    "Is sea level rise avoidable? When will it stop?",
                ]
            }
        ),
        target=None,
    )
    return giskard_model, giskard_dataset
0 commit comments