Retrieval-Augmented Generation
Retrieval-augmented generation, commonly known as RAG, is a powerful design pattern for chatbots: a retrieval system fetches verified, query-relevant sources or documents in real time and feeds them to a generative model (e.g., GPT-4) to produce the response.
Context is everything when it comes to getting the most out of an AI tool. To improve the relevance and quality of a generative AI output, you need to improve the relevance and quality of the input.
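The end-to-end example below implements this pattern with LangChain: it loads a PDF (the RAG survey paper at arxiv.org/abs/2312.10997), splits it into overlapping chunks, embeds each chunk with a BGE model, stores the vectors in a FAISS index, retrieves the chunks most similar to a question, and hands them to Claude 3 Haiku as context. The snippet reads its ANTHROPIC_API_KEY from Colab's secret storage, so it is written to run in Google Colab.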

from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from google.colab import userdata
# Load document
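# PyPDFLoader fetches the PDF and returns one Document per page.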
document_url = "https://arxiv.org/pdf/2312.10997.pdf"
loader = PyPDFLoader(document_url)
pages = loader.load()
# Split document into chunks
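# Sizes are in characters (length_function=len); the 40-character overlap keeps
# text that straddles a boundary available in both neighboring chunks.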
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=400,
    chunk_overlap=40,
    length_function=len,
    is_separator_regex=False,
)
chunks = text_splitter.split_documents(pages)
# Create embeddings from chunks
model_name = "BAAI/bge-small-en"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": True}
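# normalize_embeddings=True yields unit-length vectors, so distance-based search
# in FAISS ranks chunks the same way cosine similarity would.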
bge_embeddings = HuggingFaceBgeEmbeddings(
    model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
)
chunk_texts = [d.page_content for d in chunks]
embeddings = bge_embeddings.embed_documents(chunk_texts)
# Store embeddings
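# from_embeddings indexes the precomputed (text, vector) pairs; bge_embeddings is
# kept so that queries can later be embedded with the same model.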
text_embedding_pairs = zip(chunk_texts, embeddings)
db = FAISS.from_embeddings(text_embedding_pairs, bge_embeddings)
# Search database for similar contexts
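# similarity_search embeds the query with the same BGE model and returns the
# k nearest chunks.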
query = "Which are the drawbacks of Naive RAG?"
contexts = db.similarity_search(query, k=5)
# Chat with model
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are an expert at answering questions
based on a context extracted from a document.
The context extracted from the document is: {context}""",
        ),
        ("human", "{question}"),
    ]
)
api_key = userdata.get("ANTHROPIC_API_KEY")
model = ChatAnthropic(model="claude-3-haiku-20240307", api_key=api_key)
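# LangChain Expression Language: the | operator pipes the formatted prompt into the model.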
chain = prompt | model
response = chain.invoke(
    {
        "context": "\n\n".join(c.page_content for c in contexts),
        "question": query,
    }
)
print(response.content)
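
As a design note, the retrieval step above calls similarity_search directly; the same lookup can also go through LangChain's retriever interface, which plugs into larger chains. A minimal sketch, assuming the db built above and a recent langchain-core where retrievers are Runnables:

retriever = db.as_retriever(search_kwargs={"k": 5})  # same k-nearest-chunks lookup
contexts = retriever.invoke(query)  # returns the same list of Documents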