Semantic Chunking

  • Chunking the Text: Use the SpacyTextSplitter to break the report into semantically meaningful chunks, splitting on sentence boundaries so that no sentence is cut in half.
  • Embedding the Chunks: Use sentence-transformers to generate dense vector embeddings for the text chunks. These embeddings represent the semantic meaning of the chunks.
  • Storing in FAISS: The embeddings are stored in a FAISS index, which allows us to efficiently retrieve the most similar chunks to a given user query using vector similarity search (L2 distance).
  • Retrieving Chunks for User Query (RAG): When a user submits a query, we embed the query, search the FAISS index, and retrieve the closest chunks (ranked by L2 distance, matching the index type).
  • Visualization: The script visualizes the chunked data (length of chunks) and the retrieved chunks (length and distance from the query). It uses matplotlib for this graphical representation.
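Before walking through the full notebook, the core retrieval step can be illustrated in isolation. The sketch below uses plain NumPy with hypothetical 3-dimensional vectors (real sentence-transformer embeddings have 384+ dimensions) to show what the FAISS L2 search computes: the chunk whose embedding has the smallest Euclidean distance to the query embedding wins.

```python
import numpy as np

# Hypothetical toy "chunk" embeddings (real models emit 384+ dims)
chunk_embeddings = np.array([
    [1.0, 0.0, 0.0],   # chunk 0
    [0.0, 1.0, 0.0],   # chunk 1
    [0.9, 0.1, 0.0],   # chunk 2
], dtype="float32")

query_embedding = np.array([1.0, 0.05, 0.0], dtype="float32")

# L2 (Euclidean) distance from the query to every chunk embedding
distances = np.linalg.norm(chunk_embeddings - query_embedding, axis=1)
nearest = int(np.argmin(distances))  # index of the closest chunk
print(nearest)  # → 0
```

FAISS's `IndexFlatL2` performs exactly this brute-force scan (over squared L2 distances), just far faster and at scale.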
In [ ]:
%pip install -q spacy sentence-transformers faiss-cpu langchain matplotlib
In [ ]:
!python -m spacy download en_core_web_sm
In [ ]:
import spacy
from langchain.text_splitter import SpacyTextSplitter
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import matplotlib.pyplot as plt
from textwrap import wrap

# Initialize Spacy NLP model for chunking
nlp = spacy.load("en_core_web_sm")

# Initialize Sentence Transformer model for embeddings (you can also use OpenAI embeddings)
embedder = SentenceTransformer('all-MiniLM-L6-v2')

# Function to perform semantic chunking using SpacyTextSplitter
def chunk_text(text, chunk_size=500):
    splitter = SpacyTextSplitter(chunk_size=chunk_size, chunk_overlap=50)
    chunks = splitter.split_text(text) 
    return chunks

# Function to generate embeddings for text chunks
def embed_chunks(chunks):
    embeddings = embedder.encode(chunks)
    return embeddings

# Function to create FAISS index and store embedded chunks
def create_faiss_index(embeddings):
    dim = embeddings.shape[1]  # Dimension of embeddings
    index = faiss.IndexFlatL2(dim)  # L2 distance (Euclidean distance)
    index.add(embeddings)  # Add embeddings to index
    return index

# Function to perform RAG and retrieve the closest chunks based on a user query
def rag_query(query, index, chunks, embedder, top_k=1):
    # Generate embedding for the query
    query_embedding = embedder.encode([query])
    
    # Perform search on FAISS index
    distances, indices = index.search(query_embedding, top_k)
    
    # FAISS returns -1 indices when it cannot fill the requested top_k;
    # return a 3-tuple either way so the caller can always unpack the result
    if indices[0][0] == -1:
        return "Couldn't find accurate match for user query.", indices, distances

    # Retrieve the corresponding chunks
    retrieved_chunks = [chunks[i] for i in indices[0]]
    return retrieved_chunks, indices, distances

# Function to visualize chunked data and RAG results
def visualize_chunks_and_query(chunks, query, retrieved_chunks, indices, distances):
    # Set up the figure and axis for subplots
    fig, axes = plt.subplots(3, 1, figsize=(15, 18))

    # Original Chunk Data (Bar graph)
    axes[0].set_title("Original Chunked Data")
    chunk_lengths = [len(chunk.split()) for chunk in chunks]
    axes[0].barh(range(len(chunks)), chunk_lengths, color='lightblue')
  
    # Wrapping text and displaying chunk index for each chunk
    axes[0].set_yticks(range(len(chunks)))
    axes[0].set_yticklabels([f"Chunk {i}\n" + "\n".join(wrap(chunk, width=40))
                             for i, chunk in enumerate(chunks)], fontsize=9)
    axes[0].set_xlabel("Number of Words in Chunk")
    axes[0].set_ylabel("Chunk Index")
    
    
    # Retrieved Chunks Data (Bar graph)
    axes[1].set_title("Retrieved Chunks for User Query")
    retrieved_chunk_lengths = [len(chunk.split()) for chunk in retrieved_chunks]
    axes[1].barh(range(len(retrieved_chunks)), retrieved_chunk_lengths, color='lightgreen')

    # Wrapping text and showing chunk index and distance
    axes[1].set_yticks(range(len(retrieved_chunks)))
    axes[1].set_yticklabels([f"Chunk {idx}\n" + "\n".join(wrap(chunk, width=40)) + f"\n(Distance: {dist:.3f})"
                             for idx, chunk, dist in zip(indices[0], retrieved_chunks, distances[0])], fontsize=9)
    axes[1].set_xlabel("Number of Words in Chunk")
    axes[1].set_ylabel("Retrieved Chunk Index")

    # User Query (Text)
    axes[2].text(0.5, 0.5, f"User Query: {query}", fontsize=12, ha='center', va='center')
    axes[2].axis('off')  # Hide the axes for the user query

    # Apply tight layout to prevent overlapping labels
    plt.tight_layout()
    plt.show()

    # Print the user query and retrieved results as standard text
    print(f"User Query: {query}\n")
    print("Retrieved Chunks:")
    for idx, chunk, dist in zip(indices[0], retrieved_chunks, distances[0]):
        print(f"Chunk {idx} (Distance: {dist:.3f}):\n{chunk}\n")

# Main function to process marketing report text
def process_marketing_report(text, user_query):
    # Step 1: Chunk the input text
    chunks = chunk_text(text)
    print(f"Number of chunks: {len(chunks)}")

    # Step 2: Generate embeddings for the chunks
    embeddings = embed_chunks(chunks)

    # Step 3: Create FAISS index and add embeddings
    embeddings_np = np.array(embeddings).astype('float32')
    faiss_index = create_faiss_index(embeddings_np)

    # Step 4: Perform RAG for user query
    retrieved_chunks, indices, distances = rag_query(user_query, faiss_index, chunks, embedder)

    # Step 5: Visualize chunk data and RAG results
    visualize_chunks_and_query(chunks, user_query, retrieved_chunks, indices, distances)

    # Step 6: Return final response
    if isinstance(retrieved_chunks, str):  # In case no results are found
        return retrieved_chunks
    else:
        return "\n".join(retrieved_chunks)

# Example usage
if __name__ == "__main__":
    # Sample marketing report text
    marketing_report = """
    The company's revenue increased by 15% in Q3 compared to Q2. This growth was primarily driven by an increase in consumer demand for our new product line.
    In particular, the launch of Product X in the North American market contributed significantly to the overall sales surge.
    Additionally, our marketing team implemented a series of targeted campaigns, which contributed to a 20% increase in conversion rates.
    The customer satisfaction score also improved, indicating a positive reception to the product's quality and design.
    We expect the growth trend to continue in Q4 with the launch of new marketing initiatives aimed at expanding our market share.
    """

    user_query = "What caused the revenue increase in Q3?"

    # Process the report and get RAG response
    result = process_marketing_report(marketing_report, user_query)
    print("\nFinal RAG Response:")
    print(result)
Number of chunks: 2
[Figure: horizontal bar charts of chunk word counts and retrieved-chunk word counts with distances, plus the user query text]
User Query: What caused the revenue increase in Q3?

Retrieved Chunks:
Chunk 0 (Distance: 0.622):
The company's revenue increased by 15% in Q3 compared to Q2.

This growth was primarily driven by an increase in consumer demand for our new product line.
    

In particular, the launch of Product X in the North American market contributed significantly to the overall sales surge.
    

Additionally, our marketing team implemented a series of targeted campaigns, which contributed to a 20% increase in conversion rates.


Final RAG Response:
The company's revenue increased by 15% in Q3 compared to Q2.

This growth was primarily driven by an increase in consumer demand for our new product line.
    

In particular, the launch of Product X in the North American market contributed significantly to the overall sales surge.
    

Additionally, our marketing team implemented a series of targeted campaigns, which contributed to a 20% increase in conversion rates.