Adaptive Chunking
- Adaptive Chunking: The adaptive_chunking function uses NLTK to split the text into sentences and groups them into chunks based on a simple heuristic (e.g., chunking after every two sentences). You can modify this logic to suit your needs.
- Embedding: The embed_and_store function converts the chunks into embeddings using the SentenceTransformer model and stores them in a FAISS index for efficient similarity search.
- Retrieval: The retrieve function takes a user query, embeds it, and searches for the most relevant chunks in the FAISS index. If no results are found, it returns a message.
- Visualization: The visualize_results function displays the original chunked data, the user query, and the retrieved data, including their indices and distances, using matplotlib.
- Main Function: The main function orchestrates the chunking, embedding, retrieval, and visualization processes.
In [ ]:
%pip install -q nltk sentence-transformers faiss-cpu matplotlib
In [ ]:
import nltk
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import matplotlib.pyplot as plt
import textwrap  # Used to wrap long chunk text inside plot labels and tables

# Ensure the sentence-tokenizer resources are available.
# NLTK >= 3.8.2 loads the tokenizer data from 'punkt_tab'; older versions
# use 'punkt' — download both so sent_tokenize works on either version.
nltk.download('punkt')
nltk.download('punkt_tab')

# Initialize the sentence transformer model (compact general-purpose
# sentence-embedding model; downloaded on first use).
model = SentenceTransformer('all-MiniLM-L6-v2')
# Function to perform adaptive chunking
def adaptive_chunking(text, chunk_size=2):
    """Split *text* into chunks of consecutive sentences.

    Args:
        text: The input text to chunk.
        chunk_size: Number of sentences grouped per chunk (default 2,
            matching the original heuristic; now adjustable per call).

    Returns:
        A list of chunk strings. The final chunk may contain fewer than
        ``chunk_size`` sentences; empty input yields an empty list.
    """
    # Tokenize text into sentences (requires the NLTK punkt data).
    sentences = nltk.sent_tokenize(text)
    chunks = []
    current_chunk = []
    for sentence in sentences:
        current_chunk.append(sentence)
        if len(current_chunk) >= chunk_size:
            chunks.append(' '.join(current_chunk))
            current_chunk = []  # Start a new chunk
    if current_chunk:
        chunks.append(' '.join(current_chunk))  # Flush the final partial chunk
    return chunks
# Function to embed chunks and store them in a FAISS index
def embed_and_store(chunks):
    """Embed text chunks and store them in a FAISS index.

    Args:
        chunks: List of chunk strings to embed.

    Returns:
        A ``(index, embeddings)`` pair: the populated ``faiss.IndexFlatL2``
        and the embedding matrix produced by the sentence-transformer model.
    """
    vectors = model.encode(chunks)
    # Build an exact (brute-force) L2-distance index sized to the
    # embedding dimensionality.
    dim = vectors.shape[1]
    store = faiss.IndexFlatL2(dim)
    # FAISS expects float32 input.
    store.add(np.array(vectors, dtype='float32'))
    return store, vectors
# Function to retrieve relevant chunks based on user query
def retrieve(query, index, chunks, k=3):
    """Return the top-``k`` chunks most similar to *query*.

    Args:
        query: User query string.
        index: FAISS index built over the chunk embeddings.
        chunks: The original chunk texts, in the order they were indexed.
        k: Number of results to request (default 3, as before).

    Returns:
        A list of ``(chunk_text, chunk_index, distance)`` tuples ordered by
        increasing L2 distance, or the string
        ``"No results found for user query"`` when nothing matches.
    """
    query_embedding = model.encode([query])
    D, I = index.search(np.array(query_embedding, dtype='float32'), k=k)
    # FAISS pads results with -1 when the index holds fewer than k vectors,
    # so filter those out instead of only checking the first slot.
    # NOTE: the previous version indexed D[0]/I[0] by *chunk index* rather
    # than result rank, which paired chunks with the wrong distances
    # whenever retrieval order differed from storage order.
    retrieved_chunks = [
        (chunks[idx], idx, D[0][rank])
        for rank, idx in enumerate(I[0])
        if idx != -1
    ]
    if not retrieved_chunks:
        return "No results found for user query"
    return retrieved_chunks
# Function to visualize results
def visualize_results(original_chunks, retrieved_chunks, user_query):
    """Render a three-panel matplotlib figure: the original chunks, the
    user query, and the retrieved chunks with their indices and distances.

    Args:
        original_chunks: All chunk strings produced by chunking.
        retrieved_chunks: Output of ``retrieve`` — either a list of
            ``(chunk, index, distance)`` tuples or the no-results message
            string "No results found for user query".
        user_query: The query string being visualized.
    """
    # Create a figure to visualize results
    plt.figure(figsize=(12, 12))
    # Panel 1: every original chunk shown as a labeled bar.
    plt.subplot(3, 1, 1)
    plt.title("Original Chunked Data")
    # Wrap text to fit so long chunks don't overflow their tick labels.
    wrapped_chunks = [textwrap.fill(chunk, width=50) for chunk in original_chunks]
    plt.bar(range(len(original_chunks)), [1]*len(original_chunks), tick_label=[f'Index {i}:\n {wrapped}' for i, wrapped in enumerate(wrapped_chunks)])
    plt.xticks(rotation=0)
    plt.ylabel("Chunks")
    # Panel 2: the user query as centered text on a blank axis.
    plt.subplot(3, 1, 2)
    #plt.title("User Query")
    plt.text(0.5, 0.5, f'User Query : {user_query}', ha='center', va='center', fontsize=12, bbox=dict(facecolor='lightblue', alpha=0.5))
    plt.axis('off')
    # Panel 3: retrieved chunks as a table, or the no-results message.
    plt.subplot(3, 1, 3)
    plt.title("Retrieved Chunked Data with Index and Distance\n")
    # ``retrieve`` signals "no hits" with a sentinel string, so compare
    # against it to pick the table branch vs. the message branch.
    if retrieved_chunks != "No results found for user query":
        # Prepare table data: one row per retrieved chunk.
        table_data = []
        for idx, (chunk, index, distance) in enumerate(retrieved_chunks):
            wrapped_chunk = textwrap.fill(chunk, width=50)  # Wrap retrieved chunk text
            table_data.append([wrapped_chunk, index, f"{distance:.4f}"])
        # Create a table filling the whole subplot area.
        column_labels = ['Chunk', 'Index', 'Distance']
        table = plt.table(cellText=table_data, colLabels=column_labels, cellLoc='center', loc='center', bbox=[0, 0, 1, 1])
        # Set font size manually and bold the header row.
        table.auto_set_font_size(False)  # Disable automatic font size
        table.set_fontsize(9)  # Set font size for table
        for key, cell in table.get_celld().items():
            if key[0] == 0:  # Header row
                cell.set_text_props(fontweight='bold')  # Set header font to bold
    else:
        # No hits: show the sentinel message in red instead of a table.
        plt.text(0.5, 0.5, retrieved_chunks, ha='center', va='center', fontsize=12, color='red')
    # NOTE(review): source indentation was lost in export; this axis('off')
    # is placed at function level so it applies to panel 3 in both branches
    # — confirm against the original notebook.
    plt.axis('off')
    plt.tight_layout()
    plt.show()
# Main function to tie it all together
def main(text, user_query):
    """Run the full pipeline: chunk the text, index the embeddings,
    retrieve chunks relevant to the query, and visualize the results."""
    doc_chunks = adaptive_chunking(text)
    faiss_index, _embeddings = embed_and_store(doc_chunks)
    matches = retrieve(user_query, faiss_index, doc_chunks)
    visualize_results(doc_chunks, matches, user_query)
# Example usage: chunk a short sample paragraph, then retrieve and
# visualize the chunks most relevant to a query.
text = """This is a sample paragraph. It contains several sentences that form a coherent narrative.
The goal is to chunk this paragraph intelligently. Adaptive chunking can improve the retrieval of relevant information based on user queries.
Moreover, it aids in understanding the structure and themes of the text."""
user_query = "What is the goal of chunking"
main(text, user_query)
[nltk_data] Downloading package punkt to /root/nltk_data... [nltk_data] Package punkt is already up-to-date!