FastText Embedding
- Train the Model:
- Train on the generated sentences/words with the open-source FastText library and generate embeddings for input sentences/words.
- Deterministic Behavior:
- set_random_seed(seed=42) seeds both Python's random module and NumPy's random number generator, which makes the data generation and the Python side of the pipeline reproducible.
- The seed is also passed to the train_fasttext_model function; since fastText itself does not take a seed here, training is additionally restricted to a single thread so that the learned vectors are reproducible.
- Vector Storage (FAISS):
- Store the embeddings in a vector database (FAISS).
- For each word in the words list, the word vectors are retrieved using the FastText model and stored in both a dictionary (word_to_vector) and a list. The list is then converted into a numpy array (required by FAISS).
- Build the index with faiss.IndexFlatL2, which stores the word vectors and uses Euclidean (L2) distance as the similarity-search metric.
- OOV Handling:
- For unseen words (oov_word), FastText can still produce a vector from its subword n-grams. If a word is not in the FAISS index, the code displays it with its generated FastText vector, showing that known words are served efficiently from the FAISS index while OOV words get vectors generated on the fly (see the search sketch after this list).
- Prediction Accuracy: Because FastText embeds OOV words, the notebook simulates a "prediction accuracy" as the cosine similarity between the known word's vector and the OOV word's vector and plots it as a bar chart; the score indicates how close the OOV word lands to the known word.
- Prediction Accuracy Determinism: As the cosine similarity calculation does not involve randomness, it remains deterministic as long as the word vectors are deterministic.
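The FAISS index in the code below is built but never actually queried. As a minimal sketch of the search step implied above, the function below looks up the nearest indexed words for an OOV query; it assumes the model, faiss_index, and faiss_words objects created later in this notebook, and the helper name nearest_known_words is illustrative rather than part of the original code.

def nearest_known_words(model, faiss_index, faiss_words, query_word, k=3):
    # FastText can embed the query word even if it is OOV (via subword n-grams);
    # FAISS then performs an exact L2 nearest-neighbour search over the indexed vectors.
    query_vector = model.get_word_vector(query_word).astype('float32').reshape(1, -1)
    distances, indices = faiss_index.search(query_vector, k)
    return [(faiss_words[i], float(d)) for i, d in zip(indices[0], distances[0])]

# Example: nearest_known_words(model, faiss_index, faiss_words, 'mandarin') returns the k
# indexed words whose vectors are closest (smallest L2 distance) to the OOV word's vector.

Because IndexFlatL2 searches exhaustively, the returned distances are exact; for much larger vocabularies an approximate index such as faiss.IndexIVFFlat would normally replace it.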
In [ ]:
%pip install -q fasttext faiss-cpu numpy pandas matplotlib
Note: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
petastorm 0.12.1 requires pyspark>=2.1.0, which is not installed.
databricks-feature-store 0.14.3 requires pyspark<4,>=3.1.2, which is not installed.
ydata-profiling 4.2.0 requires numpy<1.24,>=1.16.0, but you have numpy 2.1.3 which is incompatible.
scipy 1.9.1 requires numpy<1.25.0,>=1.18.5, but you have numpy 2.1.3 which is incompatible.
numba 0.55.1 requires numpy<1.22,>=1.18, but you have numpy 2.1.3 which is incompatible.
mleap 0.20.0 requires scikit-learn<0.23.0,>=0.22.0, but you have scikit-learn 1.1.1 which is incompatible.
langchain 0.0.217 requires numpy<2,>=1, but you have numpy 2.1.3 which is incompatible.
databricks-feature-store 0.14.3 requires numpy<2,>=1.19.2, but you have numpy 2.1.3 which is incompatible.
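The conflicts reported above arise because the unpinned install resolves numpy to 2.x while several preinstalled packages on this Databricks runtime require numpy<2. A pinned install along these lines (a sketch, not what was run above) would avoid them and also cover the seaborn and scikit-learn imports used later:

%pip install -q fasttext faiss-cpu "numpy<2" pandas matplotlib seaborn scikit-learn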
In [ ]:
import fasttext
import numpy as np
import pandas as pd
import faiss # FAISS library for vector search
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
import random
def generate_sample_data(file_path, num_sentences=1000, sentence_length=5):
"""
Generate sample training data and save it to a text file.
Args:
- file_path (str): Path to save the generated data.
- num_sentences (int): Number of sentences to generate.
- sentence_length (int): Number of words in each sentence.
"""
words = ['apple', 'banana', 'cherry', 'prune', 'elderberry',
'fig', 'grape', 'honeydew', 'kiwi', 'lemon',
'mango', 'nectarine', 'orange', 'papaya', 'quince']
with open(file_path, 'w') as f:
for _ in range(num_sentences):
sentence = ' '.join(random.choices(words, k=sentence_length))
f.write(sentence + '\n')
return words # Return the list of words used for generating sentences
def set_random_seed(seed=42):
"""
Set the random seed for deterministic behavior.
Args:
- seed (int): The seed value for random number generators.
"""
random.seed(seed)
np.random.seed(seed)
def train_fasttext_model(data_file, model_file='model.bin', dim=300, lr=0.1, epoch=5, min_count=1, seed=42):
"""
Train a FastText model using the provided data file with deterministic behavior.
Args:
- data_file (str): Path to the training data file.
- model_file (str): Path to save the trained model.
- dim (int): Dimension of the word vectors.
- lr (float): Learning rate.
- epoch (int): Number of training epochs.
- min_count (int): Minimum word frequency to consider for training.
- seed (int): The random seed for deterministic training.
Returns:
- model: The trained FastText model.
"""
set_random_seed(seed) # Set seed for reproducibility
    # fastText itself does not expose a seed argument here; seeding Python/NumPy covers the
    # surrounding pipeline, and thread=1 avoids the nondeterminism of multi-threaded training.
    model = fasttext.train_unsupervised(data_file, dim=dim, lr=lr, epoch=epoch, minCount=min_count, thread=1)
model.save_model(model_file)
return model
def get_word_vector(model, word):
"""
Get the word vector for a given word, handling OOV words.
Args:
- model: The trained FastText model.
- word (str): The word to get the vector for.
Returns:
- np.ndarray: The word vector.
"""
vector = model.get_word_vector(word)
return vector
def build_faiss_index(vectors, dimension):
"""
Build a FAISS index to store and search word vectors.
Args:
- vectors (np.ndarray): Array of word vectors.
- dimension (int): Dimensionality of the word vectors.
Returns:
- index: The FAISS index storing the vectors.
"""
index = faiss.IndexFlatL2(dimension)
index.add(vectors)
return index
def display_vectors(known_word_vector, oov_word_vector, known_word='apple', oov_word='unseenword'):
"""
Display known and OOV word vectors in a pandas DataFrame with table styling.
Also show prediction accuracy visualization (cosine similarity).
Args:
- known_word_vector (np.ndarray): Vector for the known word.
- oov_word_vector (np.ndarray): Vector for the OOV word.
- known_word (str): The known word.
- oov_word (str): The OOV word.
"""
# Create DataFrame for vectors
df = pd.DataFrame({
'Word': [known_word, oov_word],
'Vector': [known_word_vector, oov_word_vector]
})
# Split the vector into separate columns
vector_df = df['Vector'].apply(pd.Series)
df = pd.concat([df.drop('Vector', axis=1), vector_df], axis=1)
# Style the DataFrame for presentation
styled_df = df.style.set_table_styles(
[{'selector': 'thead th', 'props': [('border', '1px solid black')]},
{'selector': 'tbody td', 'props': [('border', '1px solid black')]}]
)
display(styled_df)
# Cosine similarity as "prediction accuracy"
similarity_score = cosine_similarity([known_word_vector], [oov_word_vector])[0][0]
# Plotting prediction accuracy
plt.figure(figsize=(4, 4))
sns.barplot(x=['Prediction Accuracy'], y=[similarity_score], palette='coolwarm')
plt.title(f'Cosine Similarity (Prediction Accuracy) for OOV word: {oov_word}')
plt.ylim(0, 1) # Accuracy is between 0 and 1
plt.ylabel('Cosine Similarity')
plt.show()
def display_word_lists(words_list, title, vector_dimension=None):
"""
Display a list of words in a neat table along with their vector dimension if provided.
Args:
- words_list (list): List of words to display.
- title (str): Title of the table.
- vector_dimension (int, optional): Dimension of the word vectors (if provided).
"""
# Display the word list in a table
df = pd.DataFrame(words_list, columns=['Words'])
# Add vector dimension if specified
if vector_dimension is not None:
df['Vector Dimension'] = vector_dimension
styled_df = df.style.set_table_styles(
[{'selector': 'thead th', 'props': [('border', '1px solid black')]},
{'selector': 'tbody td', 'props': [('border', '1px solid black')]}]
)
print(f"\n{title}")
display(styled_df)
# Example usage
if __name__ == "__main__":
# Set random seed for reproducibility
set_random_seed(42)
# Generate sample training data and capture the words used
data_file = 'sample_data.txt'
words_used = generate_sample_data(data_file)
# Display the list of words used to generate the sample data
display_word_lists(words_used, "Words Used in Sample Data")
# Train the FastText model on the generated data
model = train_fasttext_model(data_file)
# List of words to store in FAISS index
faiss_words = ['apple', 'banana', 'cherry', 'prune', 'elderberry', 'fig', 'grape', 'honeydew', 'kiwi', 'lemon']
# Display the list of words to store in FAISS index with vector dimensions
display_word_lists(faiss_words, "List of Words to Store in FAISS Index", vector_dimension=300)
# Get embeddings for known words and store them in a list
word_vectors = []
word_to_vector = {}
for word in faiss_words:
vector = get_word_vector(model, word)
word_vectors.append(vector)
word_to_vector[word] = vector
# Convert list of word vectors into a numpy array
word_vectors = np.array(word_vectors).astype('float32')
# Build FAISS index for the word vectors
faiss_index = build_faiss_index(word_vectors, dimension=300)
# Test with a known word and an unseen word
known_word = 'apple' # 'prune'
oov_word = 'mandarin' # 'raisins'
# Get the embedding of the known word
known_word_vector = word_to_vector.get(known_word, None)
# For unseen words, get the FastText embedding (handled by FastText)
oov_word_vector = get_word_vector(model, oov_word)
print('Known word : ', known_word)
print('Out-of-Vocabulary (OOV) word: ', oov_word)
# Display the vectors and prediction accuracy
display_vectors(known_word_vector, oov_word_vector, known_word, oov_word)
Words Used in Sample Data
|   | Words |
|---|---|
| 0 | apple |
| 1 | banana |
| 2 | cherry |
| 3 | prune |
| 4 | elderberry |
| 5 | fig |
| 6 | grape |
| 7 | honeydew |
| 8 | kiwi |
| 9 | lemon |
| 10 | mango |
| 11 | nectarine |
| 12 | orange |
| 13 | papaya |
| 14 | quince |
List of Words to Store in FAISS Index
|   | Words | Vector Dimension |
|---|---|---|
| 0 | apple | 300 |
| 1 | banana | 300 |
| 2 | cherry | 300 |
| 3 | prune | 300 |
| 4 | elderberry | 300 |
| 5 | fig | 300 |
| 6 | grape | 300 |
| 7 | honeydew | 300 |
| 8 | kiwi | 300 |
| 9 | lemon | 300 |
Known word :  apple
Out-of-Vocabulary (OOV) word:  mandarin
|   | Word | 0 | 1 | 2 | 3 | 4 | … | 299 |
|---|---|---|---|---|---|---|---|---|
| 0 | apple | -0.000403 | 0.000335 | -0.000500 | -0.000574 | 0.001006 | … | 0.000015 |
| 1 | mandarin | -0.000065 | 0.000339 | -0.000552 | -0.000853 | -0.000178 | … | 0.000090 |

(300-dimensional vectors; the middle columns are omitted here for readability. The full styled DataFrame lists all 300 components for both words.)
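As a closing sanity check for the determinism claims, two models trained from the same data and seed can be compared directly. The sketch below reuses set_random_seed, train_fasttext_model, and get_word_vector from this notebook and assumes single-threaded training (multi-threaded fastText training is not exactly reproducible); the helper name check_determinism and the model file names are illustrative.

def check_determinism(data_file, words, seed=42, atol=1e-6):
    # Train two models on the same data with the same seed and compare their word vectors.
    model_a = train_fasttext_model(data_file, model_file='model_a.bin', seed=seed)
    model_b = train_fasttext_model(data_file, model_file='model_b.bin', seed=seed)
    return all(
        np.allclose(get_word_vector(model_a, w), get_word_vector(model_b, w), atol=atol)
        for w in words
    )

# Example: check_determinism('sample_data.txt', ['apple', 'mandarin']) is expected to be True
# with single-threaded training; with multiple threads the vectors can drift slightly.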