Content-aware Chunking¶

  • Chunking: The chunk_text function uses NLTK's RegexpParser to identify noun phrases, verb phrases, prepositional phrases, and clauses based on a defined grammar.
  • Embedding: The embed_chunks function utilizes the SentenceTransformer model to create embeddings for each chunked piece of text.
  • Vector Database: The create_vector_database function builds a FAISS index to store and search the embeddings efficiently.
  • RAG Functionality: The rag_search function retrieves the top-k closest chunks from the vector database based on the user's query.
  • Extracting Matched Information: The code captures the matched index, matched chunk, and distance from the best match in retrieved_data.
  • Output: The results are displayed graphically using Matplotlib, showing the original chunks, the user query, and the retrieved chunks.
Additional enhancements can be made as below:¶
  • Adjust the embeddings model if needed based on the specific domain or performance requirements.
  • You can further enhance the user query matching mechanism by experimenting with different distance thresholds or fine-tuning the embedding model.
In [ ]:
%pip install -q nltk sentence-transformers faiss-cpu matplotlib
In [ ]:
import nltk
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
import matplotlib.pyplot as plt

# Ensure NLTK resources are available
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('maxent_ne_chunker')
nltk.download('words')

# Define the grammar for chunking
grammar = r"""
  NP: {<DT>?<JJ>*<NN.*>+}   # Noun Phrase
  VP: {<VB.*><NP|PP|CLAUSE>+$}  # Verb Phrase
  PP: {<IN><NP>}             # Prepositional Phrase
  CLAUSE: {<NP><VP>}         # Clause
"""

# Function to perform content-aware chunking and return chunks with indices
def chunk_text(text):
    """Split *text* into grammar-based phrase chunks.

    Each sentence is tokenized, POS-tagged, and parsed with the module-level
    ``grammar``; the leaves of every matched subtree (NP/VP/PP/CLAUSE) are
    joined back into a text chunk.

    Args:
        text: Input document as a single string.

    Returns:
        tuple[list[str], list[tuple[int, str]]]: the flat list of chunk
        strings, and a parallel list of ``(sentence_index, chunk)`` pairs.
    """
    sentences = nltk.sent_tokenize(text)
    # The parser is loop-invariant (grammar never changes) — build it once
    # instead of recompiling it for every sentence as before.
    chunk_parser = nltk.RegexpParser(grammar)
    chunked_data = []
    indices = []

    for idx, sentence in enumerate(sentences):
        tokens = nltk.word_tokenize(sentence)
        pos_tags = nltk.pos_tag(tokens)
        tree = chunk_parser.parse(pos_tags)
        # Only Tree children are grammar matches; plain (word, tag) leaves
        # at the top level are tokens that matched no rule.
        chunks = [' '.join(word for word, tag in subtree.leaves())
                  for subtree in tree if isinstance(subtree, nltk.Tree)]

        # Append the chunks along with their sentence indices
        chunked_data.extend(chunks)
        indices.extend((idx, chunk) for chunk in chunks)

    return chunked_data, indices

# Function to embed text using SentenceTransformer
def embed_chunks(chunks):
    """Encode each chunk into a dense vector with all-MiniLM-L6-v2.

    Args:
        chunks: Sequence of text strings to embed.

    Returns:
        An array of embeddings, one row per input chunk.
    """
    encoder = SentenceTransformer('all-MiniLM-L6-v2')
    return encoder.encode(chunks)

# Function to create a vector database
def create_vector_database(embeddings):
    """Build a flat FAISS index (exact L2 search) over the embeddings.

    Args:
        embeddings: 2-D array of shape (n_chunks, dim).

    Returns:
        A ``faiss.IndexFlatL2`` populated with the embeddings.
    """
    vectors = np.array(embeddings).astype('float32')  # FAISS requires float32
    db = faiss.IndexFlatL2(embeddings.shape[1])  # exact L2 distance
    db.add(vectors)
    return db

# Function for retrieval-augmented generation
def rag_search(index, query, chunks, k=5):
    """Retrieve the top-*k* chunks closest to *query* from the FAISS index.

    Args:
        index: A populated FAISS index.
        query: The user query string.
        chunks: The chunk list the index was built from (row i -> chunks[i]).
        k: Number of nearest neighbours to request.

    Returns:
        list[tuple[str, float, int]]: ``(chunk, distance, row_index)`` tuples
        for each valid hit, in ascending distance order.
    """
    model = SentenceTransformer('all-MiniLM-L6-v2')
    query_embedding = model.encode([query]).astype('float32')
    D, I = index.search(query_embedding, k)  # Search the top k results
    # FAISS pads missing neighbours with index -1; require 0 <= i so padding
    # is skipped (the old `i < len(chunks)` check let -1 silently select the
    # LAST chunk via negative indexing). The 1e10 bound drops the sentinel
    # distances FAISS reports for those padded slots.
    return [(chunks[i], D[0][j], i)
            for j, i in enumerate(I[0])
            if 0 <= i < len(chunks) and D[0][j] < 1e10]

# Main function to encapsulate the complete functionality
def content_aware_chunking_and_rag(text, user_query):
    """Chunk *text*, index the chunk embeddings, retrieve matches for
    *user_query*, and plot the pipeline's inputs and outputs with Matplotlib.

    Args:
        text: Source document to chunk and index.
        user_query: Natural-language query to search the chunks with.

    Returns:
        None. Results are rendered via ``plt.show()``; prints a message and
        returns early when nothing is retrieved.
    """
    # Step 1: Chunk the text
    chunked_data, original_indices = chunk_text(text)

    # Step 2: Embed the chunks
    embeddings = embed_chunks(chunked_data)

    # Step 3: Create a vector database
    index = create_vector_database(embeddings)

    # Step 4: RAG search for the user query.
    # NOTE: the query embedding is deliberately NOT added to the index (the
    # old code did, which made the query its own nearest neighbour and wasted
    # one of the top-k slots before being filtered back out).
    retrieved_data = rag_search(index, user_query, chunked_data)

    # Step 5: Display results
    if not retrieved_data:
        print("No results found for user query.")
        return

    # Prepare output for display
    original_chunked_data = '\n'.join(f"{i}: {chunk}" for i, (idx, chunk) in enumerate(original_indices))
    # Loop variable renamed from `index` to avoid shadowing the FAISS index.
    retrieved_chunks = '\n'.join(f"{chunk} (distance: {dist:.4f}, index: {row})"
                                 for chunk, dist, row in retrieved_data)

    # Graphical display of results
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.axis('off')

    # Create and position text boxes for each section
    plt.text(0.5, 1.05, 'Original Chunks with Indices', ha='center', va='center', fontsize=12, weight='bold')
    plt.text(0.5, 0.85, original_chunked_data, ha='center', va='center', fontsize=12, wrap=True,
             bbox=dict(facecolor='lightblue', alpha=0.5, edgecolor='black', boxstyle='round,pad=0.5'))

    plt.text(0.5, 0.55, 'User Query', ha='center', va='center', fontsize=12, weight='bold')
    plt.text(0.5, 0.5, user_query, ha='center', va='center', fontsize=12, wrap=True,
             bbox=dict(facecolor='lightyellow', alpha=0.5, edgecolor='black', boxstyle='round,pad=0.5'))

    plt.text(0.5, 0.30, 'Retrieved Chunks\n', ha='center', va='center', fontsize=12, weight='bold')
    plt.text(0.5, 0.2, retrieved_chunks, ha='center', va='center', fontsize=12, wrap=True,
             bbox=dict(facecolor='lightgreen', alpha=0.5, edgecolor='black', boxstyle='round,pad=0.5'))

    plt.show()

    # New graph for matched index, chunk, and distance in a single graph
    best_match = retrieved_data[0]  # FAISS results are distance-sorted, so [0] is the best
    matched_chunk, matched_distance, matched_index = best_match

    # Create a new simple graph for matched index, chunk, and distance
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.axis('off')
    plt.title('Matched Index, Chunk, and Distance', fontsize=12, weight='bold')

    # Display all information in a single text box
    combined_text = (
        f"Matched Index: {matched_index}\n"
        f"Matched Chunk: {matched_chunk}\n"
        f"Distance: {matched_distance:.4f}"
    )

    plt.text(0.5, 0.5, combined_text, ha='center', va='center', fontsize=12,
             bbox=dict(facecolor='lightgreen', alpha=0.5, edgecolor='black', boxstyle='round,pad=0.5'))

    plt.show()

# Example usage: run the full chunk -> embed -> index -> retrieve pipeline
# against a short maintenance document.
if __name__ == "__main__":
    sample_document = """To ensure optimal performance and longevity of hydroelectric pumps, regular maintenance is essential.
               Begin by inspecting the pump and associated piping for any signs of wear, leaks, or corrosion."""
    sample_query = "How to maintain the hydroelectric pump?"

    content_aware_chunking_and_rag(sample_document, sample_query)
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package maxent_ne_chunker to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package maxent_ne_chunker is already up-to-date!
[nltk_data] Downloading package words to /root/nltk_data...
[nltk_data]   Package words is already up-to-date!
No description has been provided for this image
No description has been provided for this image