Continuous Bag of Words (CBOW)
  • Create CBOW Embeddings:
    • The function create_cbow_embeddings uses Gensim’s Word2Vec with sg=0 to train CBOW embeddings. CBOW predicts the target word from the context words around it. The Skip-Gram model (sg=1) works the other way around and predicts the context words from the target word.
  • Store Embeddings in FAISS:
    • The embeddings are stored in a FAISS index for fast similarity search. The function create_faiss_index builds a faiss.IndexFlatL2 index, which performs exact (brute-force) search by Euclidean distance, and adds all word embeddings to it. The index is later used to retrieve the top-k nearest words.
  • Retrieval-Augmented Generation (RAG):
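    • The rag_query function lowercases the user query and splits it on whitespace. For each query word that is in the vocabulary, it uses the FAISS index to retrieve the top-k nearest words (k=2 in this example), ranked by squared L2 (Euclidean) distance in the embedding space, not by cosine distance. Because every query word is also stored in the index, its first hit is always the word itself at distance 0. Only the retrieval step of RAG is implemented here; no text is generated.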
  • Display Results:
    • PrettyTable for Word Embeddings: The function display_original_embeddings formats the CBOW word embeddings as a table showing each word's index, the word, and the first 50 characters of its embedding vector.
    • PrettyTable for Retrieved Words: The function display_retrieved_data shows the retrieved words, their embeddings, indices, and distances in a table.
    • Heatmap for Cosine Similarity: The display_retrieval_heatmap function plots a heatmap of the cosine similarity between each query word and every retrieved word.
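The CBOW objective can be made concrete with a toy, full-softmax forward pass in NumPy. The context vectors are averaged, and the result is scored against an output matrix to give a probability for every candidate target word. This is an illustrative sketch, not Gensim's implementation. The vocabulary, the size d, the random matrices W_in and W_out, and the helper cbow_predict are all hypothetical. Gensim's Word2Vec also trains with negative sampling by default rather than the full softmax shown here. Its default cbow_mean=1 does average the context vectors, as the sketch does.

```python
import numpy as np

# Hypothetical toy vocabulary and embedding size (not Gensim internals)
vocab = ["driven", "by", "advancements", "in", "ai"]
word_to_id = {w: i for i, w in enumerate(vocab)}
V, d = len(vocab), 8

rng = np.random.default_rng(0)
W_in = rng.normal(scale=0.1, size=(V, d))   # context (input) vectors
W_out = rng.normal(scale=0.1, size=(V, d))  # target (output) vectors


def cbow_predict(context_words):
    """Full-softmax CBOW: probability of each vocabulary word being the
    target, given the surrounding context words."""
    ids = [word_to_id[w] for w in context_words]
    h = W_in[ids].mean(axis=0)       # average the context vectors
    scores = W_out @ h               # one score per vocabulary word
    scores = scores - scores.max()   # subtract max for numerical stability
    exp_scores = np.exp(scores)
    return exp_scores / exp_scores.sum()


# Score candidates for the word between "driven" and "advancements" (true target: "by")
probs = cbow_predict(["driven", "advancements"])
print(vocab[int(np.argmax(probs))], probs.round(3))
```

The matrices are untrained, so the top-scoring word here is arbitrary. Training adjusts W_in and W_out so that the observed middle word receives a high probability.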
In [ ]:
%pip install -q gensim faiss-cpu pandas prettytable matplotlib seaborn scikit-learn
Note: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.
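It may help to see what faiss.IndexFlatL2 computes before building it. The NumPy sketch below performs the same kind of exact, brute-force k-nearest-neighbour search. Like IndexFlatL2.search, it reports squared Euclidean distances, which is why the Distance column in the retrieval output holds squared values. It also shows why a vector that is itself in the database comes back as its own nearest neighbour at distance 0. The helper exact_l2_search and the random data are hypothetical stand-ins and are not part of the FAISS API.

```python
import numpy as np


def exact_l2_search(database, queries, k):
    """Brute-force k-nearest-neighbour search returning squared L2
    distances, mirroring what faiss.IndexFlatL2.search reports."""
    # ||q - x||^2 = ||q||^2 - 2 q.x + ||x||^2, computed for all pairs at once
    d2 = (
        (queries ** 2).sum(axis=1, keepdims=True)
        - 2.0 * queries @ database.T
        + (database ** 2).sum(axis=1)
    )
    d2 = np.maximum(d2, 0.0)                      # clip tiny negatives from rounding
    idx = np.argsort(d2, axis=1)[:, :k]           # k smallest distances per query
    dist = np.take_along_axis(d2, idx, axis=1)
    return dist.astype("float32"), idx.astype("int64")


rng = np.random.default_rng(42)
embeddings = rng.normal(size=(6, 4)).astype("float32")  # hypothetical 6 words, 4 dims
distances, indices = exact_l2_search(embeddings, embeddings[[2]], k=2)
print(indices, distances)  # row 2 is found first, at (near-)zero distance
```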
In [ ]:
import gensim
import faiss
import numpy as np
import pandas as pd
from prettytable import PrettyTable
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

# Step 1: Create CBOW Embedding using Gensim
def create_cbow_embeddings(sentences):
    # Tokenize the sentences and create a CBOW model
    tokenized_sentences = [sentence.lower().split() for sentence in sentences]
    model = gensim.models.Word2Vec(sentences=tokenized_sentences, vector_size=100, window=5, sg=0, min_count=1, workers=4)

    # Extract word embeddings from the model
    word_embeddings = {word: model.wv[word] for word in model.wv.index_to_key}
    return model, word_embeddings

# Step 2: Store the embeddings in a FAISS index
def create_faiss_index(word_embeddings):
    # Rows follow model.wv.index_to_key order; FAISS expects float32
    embedding_matrix = np.array(list(word_embeddings.values()), dtype="float32")
    
    # Exact (brute-force) index; search() returns squared L2 distances
    dimension = embedding_matrix.shape[1]  # Dimensionality of the embeddings
    index = faiss.IndexFlatL2(dimension)
    
    # Add the embeddings to the FAISS index
    index.add(embedding_matrix)
    
    return index

# Step 3: Retrieval step of RAG: query the FAISS index for the top-k nearest embeddings
def rag_query(user_query, model, index, k=2):
    query_tokens = user_query.lower().split()  # Tokenize the query
    results = []
    
    # For each word in the query, retrieve the closest k words from FAISS
    for token in query_tokens:
        if token in model.wv:
            query_embedding = model.wv[token].reshape(1, -1)  # Reshape for FAISS query
            distances, indices = index.search(query_embedding, k)  # Top-k hits; distances are squared L2
            results.append((token, indices[0], distances[0]))  # Collect word, indices, and distances
    
    return results

# Step 4: Display results

# 4.1: Display the original CBOW embeddings in a neat table format using PrettyTable
def display_original_embeddings(word_embeddings):
    # Create a PrettyTable for Original Word Embeddings with Index
    table = PrettyTable()
    table.field_names = ["Index", "Word", "Embedding"]

    for idx, word in enumerate(word_embeddings.keys()):
        table.add_row([idx, word, str(word_embeddings[word])[:50]])  # Shorten embedding for display

    return table

# 4.2: Display the retrieved data in a table format using PrettyTable
def display_retrieved_data(results, model):
    # Create a PrettyTable for Retrieved Data
    table = PrettyTable()
    table.field_names = ["Query Word", "Predicted Word", "Index", "Distance", "Embedding"]

    for word, indices, distances in results:
        for idx, dist in zip(indices, distances):
            retrieved_word = model.wv.index_to_key[idx]  # FAISS row idx maps to this word (rows were added in index_to_key order)
            table.add_row([word, retrieved_word, idx, dist, str(model.wv[retrieved_word])[:50]])

    return table

# 4.3: Display the heatmap of the retrieved words based on cosine similarity
def display_retrieval_heatmap(user_query, results, model):
    # Extract the embeddings of the query word and its top-k retrieved words
    query_words = [result[0] for result in results]  # Words from user query
    retrieved_words = []
    embeddings = []
    
    for word, indices, _ in results:
        for idx in indices:
            retrieved_word = model.wv.index_to_key[idx]
            retrieved_words.append(retrieved_word)
            embeddings.append(model.wv[retrieved_word])
    
    # Query word embeddings
    query_embeddings = np.array([model.wv[word] for word in query_words])
    
    # Calculate cosine similarity between the query and retrieved words
    similarity_matrix = cosine_similarity(query_embeddings, embeddings)

    # Create a DataFrame for better visualization
    similarity_df = pd.DataFrame(similarity_matrix, index=query_words, columns=retrieved_words)

    # Display the heatmap
    plt.figure(figsize=(8, 3.5))
    sns.heatmap(similarity_df, annot=True, cmap="YlGnBu", cbar=True, linewidths=0.5)
    plt.title(f"Cosine Similarity Heatmap for Query '{user_query}' and Retrieved Words")
    plt.show()

# Sample training corpus: five short real-world sentences about AI
sentences = [
    "Artificial intelligence is transforming the world in multiple ways",
    "Natural language processing enables computers to understand human language",
    "Machine learning models are being used in various industries",
    "The future of technology is driven by advancements in AI and data science",
    "Computer vision allows machines to interpret and understand visual data"
]

# Create CBOW embeddings using Gensim
model, word_embeddings = create_cbow_embeddings(sentences)

# Create FAISS index with the embeddings
faiss_index = create_faiss_index(word_embeddings)

# Example user query
user_query = "driven by advancements"
results = rag_query(user_query, model, faiss_index, k=2)

# Step 4: Display results

# 4.1: Display the original word embeddings table
print("Original CBOW Word Embeddings (with Index):")
original_embeddings_table = display_original_embeddings(word_embeddings)
print(original_embeddings_table)

# Display original user query as text
print("\nOriginal User Query:", f"'{user_query}'")

# 4.2: Display the retrieved words and their embeddings
print("\nRetrieved Data for User Query:", f"'{user_query}'")
retrieved_data_table = display_retrieved_data(results, model)
print(retrieved_data_table)

# 4.3: Display the heatmap for the retrieved words based on cosine similarity
display_retrieval_heatmap(user_query, results, model)
Original CBOW Word Embeddings (with Index):
+-------+--------------+----------------------------------------------------+
| Index |     Word     |                     Embedding                      |
+-------+--------------+----------------------------------------------------+
|   0   |      in      | [-5.4488698e-04  2.3148068e-04  5.1030992e-03  9.0 |
|   1   |   language   | [-8.6322334e-03  3.6704775e-03  5.1830662e-03  5.7 |
|   2   |      is      | [ 8.8889763e-05  3.0749419e-03 -6.8128402e-03 -1.3 |
|   3   |  understand  | [-8.2551921e-03  9.2970291e-03 -2.0232004e-04 -1.9 |
|   4   |     the      | [-0.00713902  0.00124103 -0.00717672 -0.00224462   |
|   5   |      to      | [-8.7441904e-03  2.1391853e-03 -8.7682781e-04 -9.3 |
|   6   |     data     | [ 8.1304898e-03 -4.4578831e-03 -1.0723007e-03  1.0 |
|   7   |     and      | [ 8.1668133e-03 -4.4421479e-03  8.9852260e-03  8.2 |
|   8   |    visual    | [-9.58045293e-03  8.93678144e-03  4.16930532e-03   |
|   9   |    human     | [-0.00516869 -0.00666477 -0.00777491  0.00830756 - |
|   10  |  computers   | [ 7.0799328e-03 -1.5651411e-03  7.9463385e-03 -9.4 |
|   11  |   enables    | [ 9.7702928e-03  8.1651136e-03  1.2809705e-03  5.0 |
|   12  |  processing  | [-1.9522113e-03 -5.2662930e-03  9.4475150e-03 -9.2 |
|   13  |   multiple   | [-0.00950223  0.00956123 -0.00777185 -0.00264515 - |
|   14  |   natural    | [ 7.6966453e-03  9.1206403e-03  1.1355019e-03 -8.3 |
|   15  |     ways     | [-7.1909428e-03  4.2328904e-03  2.1633934e-03  7.4 |
|   16  |   learning   | [ 1.2987166e-03 -9.8028453e-03  4.5865774e-03 -5.3 |
|   17  |    world     | [ 0.00180023  0.00704609  0.0029447  -0.00698085   |
|   18  | transforming | [ 0.00973343 -0.00978134 -0.00649983  0.00278333   |
|   19  | intelligence | [ 5.6267120e-03  5.4973699e-03  1.8291187e-03  5.7 |
|   20  |   machine    | [ 0.00256773  0.0008465  -0.00253994  0.00935823   |
|   21  |     are      | [ 1.3269740e-03  6.5380698e-03  9.9846479e-03  9.0 |
|   22  |    models    | [-2.4468597e-04  4.2178901e-03  2.1142603e-03  9.9 |
|   23  |      by      | [-0.00250877 -0.00590266  0.00748334 -0.00725973 - |
|   24  |   machines   | [-4.9735666e-03 -1.2833046e-03  3.2806373e-03 -6.4 |
|   25  |    allows    | [ 0.00964853  0.00732483  0.00126166 -0.00340524 - |
|   26  |    vision    | [-6.9688787e-03 -2.4623817e-03 -8.0255056e-03  7.4 |
|   27  |   computer   | [ 0.00211351  0.00573515 -0.00211641  0.0031723    |
|   28  |   science    | [ 8.3416523e-03 -5.7128520e-04 -9.4387336e-03  4.7 |
|   29  |      ai      | [-4.2858059e-03 -9.3295584e-03 -1.8737190e-03 -3.7 |
|   30  | advancements | [ 1.9187220e-04  2.1426824e-03  1.0673234e-03  7.5 |
|   31  |    driven    | [-0.00220363 -0.00971407  0.00929654  0.00204508 - |
|   32  |  interpret   | [ 6.4095901e-03 -8.9552216e-03 -7.3440163e-03 -1.7 |
|   33  |  technology  | [ 0.00479019 -0.00362633 -0.00426102  0.00122164 - |
|   34  |      of      | [-1.5110135e-03 -4.0345048e-03 -4.3988540e-03 -4.6 |
|   35  |    future    | [ 0.00794268 -0.00645123  0.00579293 -0.00021677 - |
|   36  |  industries  | [ 5.3819208e-03  9.8090153e-03 -7.0540742e-03 -5.8 |
|   37  |   various    | [ 2.2229629e-03 -7.5374460e-03  5.6277504e-03 -5.0 |
|   38  |     used     | [-0.00778307 -0.00675631 -0.00315467  0.00661474 - |
|   39  |    being     | [-0.00172406  0.00640525 -0.00927498 -0.00914481 - |
|   40  |  artificial  | [-4.7935294e-03 -4.3203579e-03 -4.7943974e-03 -9.7 |
+-------+--------------+----------------------------------------------------+

Original User Query: 'driven by advancements'

Retrieved Data for User Query: 'driven by advancements'
+--------------+----------------+-------+--------------+----------------------------------------------------+
|  Query Word  | Predicted Word | Index |   Distance   |                     Embedding                      |
+--------------+----------------+-------+--------------+----------------------------------------------------+
|    driven    |     driven     |   31  |     0.0      | [-0.00220363 -0.00971407  0.00929654  0.00204508 - |
|    driven    |      and       |   7   | 0.005862892  | [ 8.1668133e-03 -4.4421479e-03  8.9852260e-03  8.2 |
|      by      |       by       |   23  |     0.0      | [-0.00250877 -0.00590266  0.00748334 -0.00725973 - |
|      by      |   artificial   |   40  | 0.005529711  | [-4.7935294e-03 -4.3203579e-03 -4.7943974e-03 -9.7 |
| advancements |  advancements  |   30  |     0.0      | [ 1.9187220e-04  2.1426824e-03  1.0673234e-03  7.5 |
| advancements |    machine     |   20  | 0.0047652363 | [ 0.00256773  0.0008465  -0.00253994  0.00935823   |
+--------------+----------------+-------+--------------+----------------------------------------------------+
[Figure: Cosine Similarity Heatmap for Query 'driven by advancements' and Retrieved Words]
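In the retrieved-data table, each query word's top hit is the word itself at distance 0.0, because its own vector is stored in the index. With k=2, that leaves only one real neighbour per word. One way around this is to search for k + 1 neighbours and drop the query's own row. In this notebook the FAISS rows follow model.wv.index_to_key, so a query's own row would be model.wv.key_to_index[token]. The helper drop_self_matches and the sample arrays below are hypothetical. As a cross-check, Gensim's model.wv.most_similar already excludes the query word and ranks by cosine similarity. With only five training sentences and Gensim's default of 5 epochs, the vectors likely stay close to their random initialization. Neighbours such as "driven" and "and" should therefore not be read as semantic similarity.

```python
import numpy as np


def drop_self_matches(indices, distances, self_ids, k):
    """Remove each query's own entry from k + 1 search results and keep
    the next k neighbours.

    indices, distances: arrays of shape (n_queries, k + 1), in the layout
        returned by a FAISS index.search(queries, k + 1) call
    self_ids: index row of each query word (hypothetical input)
    """
    kept_idx, kept_dist = [], []
    for row_idx, row_dist, self_id in zip(indices, distances, self_ids):
        mask = row_idx != self_id          # True for every non-self hit
        kept_idx.append(row_idx[mask][:k])
        kept_dist.append(row_dist[mask][:k])
    return np.array(kept_idx), np.array(kept_dist)


# Made-up search output for two query words, searched with k + 1 = 3
indices = np.array([[31, 7, 12], [23, 40, 5]])
distances = np.array([[0.0, 0.5, 0.7], [0.0, 0.4, 0.9]], dtype="float32")
idx, dist = drop_self_matches(indices, distances, self_ids=[31, 23], k=2)
print(idx)
print(dist)
```

Inside rag_query, this would mean calling index.search(query_embedding, k + 1) and passing self_ids=[model.wv.key_to_index[token]].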
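Retrieval ranks words by squared L2 distance, while the heatmap scores them by cosine similarity. For unit-length vectors the two are linked by ||u - v||^2 = 2 - 2 cos(u, v). Normalizing the embeddings before indexing therefore makes L2 ranking and cosine ranking agree. In FAISS this is commonly done with faiss.normalize_L2 followed by an inner-product index such as faiss.IndexFlatIP. The NumPy sketch below computes the same pairwise matrix as scikit-learn's cosine_similarity and checks the identity numerically. The helper cosine_matrix and the random data are hypothetical.

```python
import numpy as np


def cosine_matrix(a, b):
    """Pairwise cosine similarity between the rows of a and the rows of b,
    the same quantity sklearn.metrics.pairwise.cosine_similarity returns."""
    a_unit = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_unit = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a_unit @ b_unit.T


rng = np.random.default_rng(7)
queries = rng.normal(size=(3, 5))    # stand-ins for 3 query-word vectors
retrieved = rng.normal(size=(6, 5))  # stand-ins for 6 retrieved-word vectors
sim = cosine_matrix(queries, retrieved)  # shape (3, 6), like the heatmap

# For unit vectors u, v: ||u - v||^2 = 2 - 2 * cos(u, v),
# so ranking by squared L2 distance and by cosine gives the same order.
u = queries[0] / np.linalg.norm(queries[0])
v = retrieved[0] / np.linalg.norm(retrieved[0])
print(np.isclose(((u - v) ** 2).sum(), 2 - 2 * sim[0, 0]))  # True
```

Word2Vec vectors are not unit-length by default, so in the notebook above the L2 ranking and the cosine scores in the heatmap need not agree.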