Adaptive Chunking¶

  • Adaptive Chunking: The adaptive_chunking function uses NLTK to split the text into sentences and groups them into chunks based on a simple heuristic (e.g., chunking after every two sentences). You can modify this logic to suit your needs.
  • Embedding: The embed_and_store function converts the chunks into embeddings using the SentenceTransformer model and stores them in a FAISS index for efficient similarity search.
  • Retrieval: The retrieve function takes a user query, embeds it, and searches for the most relevant chunks in the FAISS index. If no results are found, it returns a message.
  • Visualization: The visualize_results function displays the original chunked data, the user query, and the retrieved data, including their indices and distances, using matplotlib.
  • Main Function: The main function orchestrates the chunking, embedding, retrieval, and visualization processes.
In [ ]:
%pip install -q nltk sentence-transformers faiss-cpu matplotlib
In [ ]:
import nltk
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import matplotlib.pyplot as plt
import textwrap  # Import textwrap for wrapping text

# Ensure necessary NLTK resources are available.
# 'punkt' supplies the pretrained sentence tokenizer that
# nltk.sent_tokenize() relies on below.
# NOTE(review): newer NLTK releases may also require 'punkt_tab' — confirm
# against the installed NLTK version.
nltk.download('punkt')

# Initialize the sentence transformer model used for all embeddings
# (shared module-level state; encode() is called in both the indexing
# and the query paths).
model = SentenceTransformer('all-MiniLM-L6-v2')

# Function to perform adaptive chunking
def adaptive_chunking(text, chunk_size=2):
    """Split *text* into chunks of consecutive sentences.

    Parameters
    ----------
    text : str
        The input text to be chunked.
    chunk_size : int, optional
        Maximum number of sentences per chunk (default 2, matching the
        original heuristic). Must be >= 1.

    Returns
    -------
    list[str]
        Space-joined sentence groups; the final chunk may hold fewer
        than ``chunk_size`` sentences.

    Raises
    ------
    ValueError
        If ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")

    sentences = nltk.sent_tokenize(text)
    # Slicing in steps of chunk_size reproduces the original grouping,
    # including the trailing partial chunk, without manual accumulation.
    return [
        ' '.join(sentences[start:start + chunk_size])
        for start in range(0, len(sentences), chunk_size)
    ]

# Function to embed chunks and store them in a FAISS index
def embed_and_store(chunks):
    """Encode *chunks* with the shared sentence-transformer model and
    index them for similarity search.

    Returns
    -------
    tuple
        ``(index, embeddings)`` — a FAISS flat L2 (exact, brute-force)
        index holding one vector per chunk, plus the raw embeddings.
    """
    vectors = model.encode(chunks)

    # FAISS expects float32 input; build an exact L2-distance index
    # sized to the embedding dimensionality.
    dim = vectors.shape[1]
    faiss_index = faiss.IndexFlatL2(dim)
    faiss_index.add(np.asarray(vectors, dtype='float32'))

    return faiss_index, vectors

# Function to retrieve relevant chunks based on user query
def retrieve(query, index, chunks, k=3):
    """Return the *k* chunks most similar to *query*.

    Parameters
    ----------
    query : str
        The user query to embed and search for.
    index : faiss.Index
        FAISS index built over the chunk embeddings.
    chunks : list[str]
        The chunk texts, in the same order they were indexed.
    k : int, optional
        Number of neighbours to retrieve (default 3).

    Returns
    -------
    list[tuple] or str
        ``(chunk_text, chunk_id, distance)`` tuples, or the string
        "No results found for user query" when nothing is found.
    """
    query_embedding = model.encode([query])
    D, I = index.search(np.array(query_embedding, dtype='float32'), k=k)

    if I[0][0] == -1:  # FAISS pads missing results with -1
        return "No results found for user query"

    # BUG FIX: the original indexed I[0]/D[0] by the *chunk id* instead
    # of the result position ([(chunks[i], I[0][i], D[0][i]) for i in I[0]]),
    # which raises IndexError whenever a retrieved id >= k and pairs wrong
    # distances otherwise. Pair ids and distances positionally, skipping
    # -1 padding when fewer than k chunks exist in the index.
    return [
        (chunks[chunk_id], chunk_id, dist)
        for chunk_id, dist in zip(I[0], D[0])
        if chunk_id != -1
    ]

# Function to visualize results
def visualize_results(original_chunks, retrieved_chunks, user_query):
    """Render a three-panel figure: the original chunks, the user query,
    and the retrieved chunks with their indices and distances.

    ``retrieved_chunks`` is either a list of (chunk, index, distance)
    tuples or the sentinel string "No results found for user query".
    """
    plt.figure(figsize=(12, 12))

    # Panel 1: one unit-height bar per chunk, labelled with wrapped text.
    plt.subplot(3, 1, 1)
    plt.title("Original Chunked Data")
    labels = [
        f'Index {i}:\n {textwrap.fill(chunk, width=50)}'
        for i, chunk in enumerate(original_chunks)
    ]
    plt.bar(range(len(original_chunks)), [1] * len(original_chunks),
            tick_label=labels)
    plt.xticks(rotation=0)
    plt.ylabel("Chunks")

    # Panel 2: the query text, centred in a light-blue box.
    plt.subplot(3, 1, 2)
    plt.text(0.5, 0.5, f'User Query :  {user_query}', ha='center',
             va='center', fontsize=12,
             bbox=dict(facecolor='lightblue', alpha=0.5))
    plt.axis('off')

    # Panel 3: retrieval results as a table, or the "no results" message.
    plt.subplot(3, 1, 3)
    plt.title("Retrieved Chunked Data with Index and Distance\n")

    if retrieved_chunks == "No results found for user query":
        plt.text(0.5, 0.5, retrieved_chunks, ha='center', va='center',
                 fontsize=12, color='red')
    else:
        rows = [
            [textwrap.fill(chunk, width=50), chunk_index, f"{distance:.4f}"]
            for chunk, chunk_index, distance in retrieved_chunks
        ]
        table = plt.table(cellText=rows,
                          colLabels=['Chunk', 'Index', 'Distance'],
                          cellLoc='center', loc='center',
                          bbox=[0, 0, 1, 1])
        table.auto_set_font_size(False)  # fixed size, set explicitly below
        table.set_fontsize(9)
        # Row 0 of the cell grid is the header — make it bold.
        for (row, _col), cell in table.get_celld().items():
            if row == 0:
                cell.set_text_props(fontweight='bold')

    plt.axis('off')
    plt.tight_layout()
    plt.show()

# Main function to tie it all together
def main(text, user_query):
    """Run the full pipeline: chunk the text, index the embeddings,
    retrieve matches for *user_query*, then visualize everything."""
    chunked = adaptive_chunking(text)
    faiss_index, _embeddings = embed_and_store(chunked)
    results = retrieve(user_query, faiss_index, chunked)
    visualize_results(chunked, results, user_query)

# Example usage: a short sample paragraph and a query about it.
# (Alternative sample texts that were previously kept as commented-out
# code have been removed for clarity.)
text = """This is a sample paragraph. It contains several sentences that form a coherent narrative. 
The goal is to chunk this paragraph intelligently. Adaptive chunking can improve the retrieval of relevant information based on user queries. 
Moreover, it aids in understanding the structure and themes of the text."""
user_query = "What is the goal of chunking"

main(text, user_query)
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
No description has been provided for this image