Recursive Character Text Splitting
Text Splitting : The recursive_text_splitter function uses the RecursiveCharacterTextSplitter from Langchain to split the case law paragraph into smaller chunks based on a specified maximum chunk size.
Embedding and Storage: Each chunk is embedded using sentence-transformers, and the embeddings are stored in a FAISS index for efficient similarity searching.
RAG Retrieval: The retrieve_relevant_chunks function takes a user query, embeds it, and retrieves the top 3 most similar chunks from the FAISS index.
Visualization: The display_chunks function uses Matplotlib to display the original text, the user query, and the retrieved chunks.
In [ ]:
%pip install -q langchain sentence-transformers matplotlib faiss-cpu
In [ ]:
import numpy as np
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
import faiss
# Initialize the sentence transformer model for embeddings
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Initialize FAISS index for storing embeddings.
# Derive the embedding dimension from the model itself rather than
# hard-coding 384, so swapping in a different model cannot silently
# create a dimension mismatch with the index.
dimension = embedding_model.get_sentence_embedding_dimension()
index = faiss.IndexFlatL2(dimension)
# Sample paragraph for Case Law Research.
# This is the corpus for the demo: it gets chunked, embedded, and indexed
# below, and queries are answered against it.
case_law_paragraph = (
"In the case of Smith v. Jones, the court held that the defendant's actions "
"constituted negligence. The plaintiff demonstrated that the defendant failed "
"to act with reasonable care, leading to the plaintiff's injuries. The ruling "
"established a precedent regarding the duty of care owed to individuals in public spaces. "
"Furthermore, the court emphasized the importance of evidence in establishing the "
"causal link between the defendant's actions and the harm suffered by the plaintiff."
)
# Thin wrapper around Langchain's recursive character splitter.
def recursive_text_splitter(text, max_chunk_size):
    """Split ``text`` into chunks of at most ``max_chunk_size`` characters.

    Uses ``RecursiveCharacterTextSplitter`` with no overlap between
    consecutive chunks.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=max_chunk_size,
        chunk_overlap=0,
    )
    return splitter.split_text(text)
# Generate chunks using recursive splitting
max_chunk_size = 100  # Maximum characters per chunk
chunks = recursive_text_splitter(case_law_paragraph, max_chunk_size)
# Embed all chunks in a single batched call — SentenceTransformer.encode
# accepts a list and returns a 2-D array, which is faster than one
# encode() call per chunk — then add the matrix to the FAISS index
# (FAISS expects float32).
chunk_embeddings = np.asarray(embedding_model.encode(chunks), dtype='float32')
index.add(chunk_embeddings)
# Retrieval step (the "R" in RAG): embed the query and look up the
# nearest stored chunks in the FAISS index.
def retrieve_relevant_chunks(query, top_k=3):
    """Return up to ``top_k`` (chunk_text, distance) pairs similar to ``query``.

    Parameters
    ----------
    query : str
        Natural-language query to embed and search with.
    top_k : int, optional
        Maximum number of chunks to retrieve (default 3).

    Returns
    -------
    list of (str, float)
        Retrieved chunk texts with their L2 distances, nearest first.
    """
    query_embedding = embedding_model.encode(query).reshape(1, -1).astype('float32')
    # Never ask FAISS for more neighbours than are stored: it pads
    # missing results with index -1, and chunks[-1] would silently
    # return the *last* chunk instead of failing.
    k = min(top_k, len(chunks))
    distances, indices = index.search(query_embedding, k=k)
    return [
        (chunks[i], distances[0][j])
        for j, i in enumerate(indices[0])
        if i != -1  # defensive: drop any padded (not-found) slots
    ]
# Render the chunking + retrieval results as three stacked panels.
def display_chunks(original_chunks, retrieved_chunks, user_query):
    """Show original chunks, the user query, and retrieved chunks.

    Draws three vertically stacked Matplotlib panels: bar lengths are
    the character counts of each chunk; the middle panel shows the
    query as plain centred text.
    """
    fig, (panel_orig, panel_query, panel_hits) = plt.subplots(
        3, 1, figsize=(12, 10), gridspec_kw={'height_ratios': [1, 1, 1]}
    )

    # Top panel: one horizontal bar per original chunk, sized by length.
    orig_positions = range(len(original_chunks))
    panel_orig.barh(orig_positions,
                    [len(chunk) for chunk in original_chunks],
                    color='skyblue')
    panel_orig.set_yticks(orig_positions)
    panel_orig.set_yticklabels(original_chunks, fontsize=9)  # small font so labels fit
    panel_orig.set_title('Original Chunks')
    panel_orig.invert_yaxis()  # first chunk at the top

    # Middle panel: the query as centred text on a hidden axis.
    panel_query.text(0.5, 0.5, f'User Query: {user_query}',
                     fontsize=11, ha='center', va='center', wrap=True)
    panel_query.axis('off')

    # Bottom panel: retrieved chunk texts (distances are not plotted).
    retrieved_texts, _ = zip(*retrieved_chunks)
    hit_positions = range(len(retrieved_texts))
    panel_hits.barh(hit_positions,
                    [len(text) for text in retrieved_texts],
                    color='lightgreen')
    panel_hits.set_yticks(hit_positions)
    panel_hits.set_yticklabels(retrieved_texts, fontsize=9)
    panel_hits.set_title('Retrieved Chunks')
    panel_hits.invert_yaxis()

    plt.tight_layout()
    plt.show()
In [ ]:
# User query for RAG
user_query = "What are the key points regarding negligence in Smith v. Jones?"
# Embed the query and fetch the most similar chunks from the FAISS index.
retrieved_chunks = retrieve_relevant_chunks(user_query)
# Display results: original chunks, the query, and the retrieved chunks
# stacked in one Matplotlib figure.
display_chunks(chunks, retrieved_chunks, user_query)