Build a semantic search tool using FAISS

This post provides an overview of implementing semantic search. Why? Because too often, we notice that testers skip testing more complex features like autocomplete. That might be ok in most applications, but in domain-specific applications, testing the autocomplete capabilities of the product is important. Since testers can benefit from understanding implementation details, in this post we will look at how autocomplete is usually implemented. We will follow up in a later post with some interesting tests that build upon the insights from this post.

An example to play along with:

For this post, let us imagine a text box. You start to type the name of a technology and related technologies pop up.

Aside: we needed to write a similar tool recently and this post borrows heavily from our work on that internal tool. We recognized the need to address a data duplication issue within our internal survey app at Qxf2 Services. Every week, our team completes a survey, listing the technologies used during that week. This process helps us track and evaluate our efforts to remain current with developments in the tech industry. You can read more about it in our blog post ‘Qxf2 Tech Used in 2021’. However, with several years’ worth of data collected, we noticed a challenge: the same technologies were represented differently, causing redundancy. For instance, SQS was sometimes recorded as just ‘SQS’, other times as ‘AWS-SQS’ or ‘Amazon Simple Queue Service’. This inconsistency was causing database bloat. To tackle the issue, we sought a solution that would suggest existing words when someone filled in the survey.


About semantic search:

Remember autocomplete capabilities from many years ago? You start typing a word and get suggestions on ways to complete it. This sort of autocomplete relies on the spelling of the word. It has a fancy name – lexical search. Semantic search is a methodology that improves upon lexical search by seeking contextually related terms. For instance, when searching for ‘dog’, the results might also include ‘German shepherd’ because the two are related. This is achieved by transforming words like ‘dog’ and ‘German shepherd’ into vectors that sit close together in a vector space. During a search, the query is likewise transformed into a vector and placed within the same vector space; the closest vectors are then identified and retrieved, yielding relevant results.
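A toy sketch of the idea (with made-up 3-dimensional vectors, not the real model's 768-dimensional embeddings): related terms sit closer together, so their cosine similarity is higher.

import numpy as np

def cosine(vec_a, vec_b):
    "Cosine similarity between two vectors"
    return np.dot(vec_a, vec_b) / (np.linalg.norm(vec_a) * np.linalg.norm(vec_b))

# Hypothetical embeddings, chosen only to illustrate relative closeness
dog = np.array([0.9, 0.1, 0.2])
german_shepherd = np.array([0.85, 0.2, 0.25])
database = np.array([0.1, 0.9, 0.3])

print(cosine(dog, german_shepherd))  # high: the terms are related
print(cosine(dog, database))         # low: the terms are not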


Tools used to implement semantic search:

During a recent exploration of context-based question answering using LLMs, I came across FAISS. We realized that this library could help us resolve the data duplication problem. As the initial phase of addressing the issue, I developed a semantic search tool using the FAISS library, leveraging a Stack Overflow dataset. I built my application by referencing the example provided in Tutorial: semantic search using Faiss & MPNet. I am sharing this post in the hope of aiding fellow engineers in their own related tasks. The tool relies on two key pieces:

1. sentence-transformers/all-mpnet-base-v2, to create vector embeddings
2. FAISS, to index the vectors; it also provides APIs to search and retrieve relevant vectors


Environment setup:

The setup process involves the following steps:

# Install torch
pip install torch==2.0.1
pip install torchvision==0.15.2

# Install the Transformers module
pip install transformers==4.30.2

# Install the FAISS library
pip install faiss-cpu==1.7.4

# Install BeautifulSoup
pip install bs4==0.0.1
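Once the installs finish, a quick import check helps confirm the environment is sane (a minimal sketch; the version attributes are the standard ones these packages expose):

# Sanity check: all four libraries should import cleanly
import torch
import transformers
import faiss
import bs4

print(torch.__version__)         # expect 2.0.1
print(transformers.__version__)  # expect 4.30.2
print(faiss.__version__)         # expect 1.7.4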

Creating a semantic search tool with FAISS:

I have created objects for the following purposes:
1. Read the rows from the Stack Overflow XML data file
2. Convert words to vectors
3. Create a FAISS index
4. Create a pickle

Read the rows from the XML file:

I have used this Stack Overflow dataset to create the tool.

class XMLReader:
    "An XML object to read the values from an XML file"
    @staticmethod
    def read_from_file(xml_file, html_property):
        "Read the given attribute from every row of an XML file"
        # html.parser lowercases attribute names, so lowercase the property too
        html_property = html_property.lower()
        with open(xml_file, 'r', encoding='utf-8') as xmlfile:
            xml = xmlfile.read()

        soup = BeautifulSoup(xml, "html.parser")
        rows = soup.find_all('row')
        return [row[html_property] for row in rows]

The XMLReader object reads the XML data file and returns the requested attribute (TagName, in our case) of every row as a list.
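Here is a minimal, self-contained usage sketch; the sample rows imitate the format of the Stack Overflow Tags.xml file, where each row carries a TagName attribute:

# A small sketch (assumes the XMLReader class above); the sample rows
# imitate the Stack Overflow Tags.xml format
SAMPLE_XML = '<tags><row Id="1" TagName="python" /><row Id="2" TagName="java" /></tags>'
with open('sample_tags.xml', 'w', encoding='utf-8') as sample_file:
    sample_file.write(SAMPLE_XML)

tags = XMLReader.read_from_file(xml_file='sample_tags.xml', html_property='TagName')
print(tags)  # ['python', 'java']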

Convert words to vectors:
class SemanticEmbedding:
    "A semantic embedding object to get the word embeddings"
    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
        "object initialization"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
 
    def mean_pooling(self, model_output, attention_mask):
        """
        Mean Pooling - take the attention mask into account for correct averaging.
        Mean pooling is mainly useful for building a vector for a whole sentence,
        but it works just as well in our case, where the input is a single word.
        """
        # First element of model_output contains all token embeddings
        token_embeddings = model_output[0]
        input_mask_exp = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings*input_mask_exp, 1)/torch.clamp(input_mask_exp.sum(1),
                                                                         min=1e-9)
 
    def get_embedding(self, word):
        "create word embeddings"
        encoded_input = self.tokenizer(word, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        # Perform pooling
        word_embedding = self.mean_pooling(model_output, encoded_input['attention_mask'])
 
        # Normalize embeddings
        word_embedding = torch.nn.functional.normalize(word_embedding, p=2, dim=1)
        return word_embedding.detach().numpy()

The SemanticEmbedding object provides methods to create embeddings for words. The get_embedding method takes text as input and returns a normalized embedding of shape (1, 768) as a NumPy array.
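A quick usage sketch (note that the model weights are downloaded from Hugging Face on first use):

# Embed a single word and inspect the result
embedder = SemanticEmbedding()
vector = embedder.get_embedding('python')
print(vector.shape)  # (1, 768) - one normalized 768-dimensional embedding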

Create a FAISS index:
class FaissIdxObject:
    "A FAISS object to create, add documents to, search and save an index"
    def __init__(self, dim=768):
        "Object initialization"
        self.dim = dim

    def create_index(self):
        "Create a new index"
        return faiss.IndexFlatIP(self.dim)

    @staticmethod
    def get_index(index_name):
        "Read a saved index from disk"
        try:
            return faiss.read_index(index_name)
        except (FileNotFoundError, RuntimeError) as err:
            # FAISS surfaces I/O problems as RuntimeError
            raise FileNotFoundError(f"Unable to find {index_name}, does the file exist?") from err

    @staticmethod
    def add_doc_to_index(index, embedded_document_text):
        "Add a document embedding to the index"
        index.add(embedded_document_text)

    @staticmethod
    def search_index(embedded_query, index, doc_map, k=5, return_scores=False):
        "Search the index for the k nearest neighbours of the query"
        D, I = index.search(embedded_query, k)
        if return_scores:
            value = [{doc_map[idx]: str(score)} for idx, score in zip(I[0], D[0]) if idx in doc_map]
        else:
            value = [doc_map[idx] for idx in I[0] if idx in doc_map]
        return value

    @staticmethod
    def save_index(index, index_name):
        "Save the index to a local file"
        faiss.write_index(index, index_name)

The FaissIdxObject object provides methods to create an index, search it with a query vector and retrieve the related vectors. The search_index method returns D, the similarity scores of the k nearest neighbours, and I, their positions in the index.
For my application, I opted for the IndexFlatIP index. This choice was driven by its use of the inner product as the distance metric, which, for normalized embeddings, equates to cosine similarity.
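That equivalence is easy to verify with a toy sketch: for unit-length vectors, the inner products returned by IndexFlatIP are exactly cosine similarities.

import faiss
import numpy as np

DIM = 4  # a tiny dimension, just for the demo
index = faiss.IndexFlatIP(DIM)

# FAISS expects float32 arrays of shape (n, dim); these rows are unit length
docs = np.array([[1.0, 0.0, 0.0, 0.0],
                 [0.6, 0.8, 0.0, 0.0]], dtype='float32')
index.add(docs)

query = np.array([[1.0, 0.0, 0.0, 0.0]], dtype='float32')
D, I = index.search(query, 2)
print(I)  # [[0 1]] - positions of the nearest documents, best first
print(D)  # [[1.  0.6]] - inner products, i.e. cosine similarities here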

Create a pickle:
class PickleObject:
    "A pickle object to save and read the human-readable dataset"
    def create_dict(self):
        "Create a new dict"
        return {}

    @staticmethod
    def get_pickle(pickle_name):
        "Read the local pickle file"
        try:
            with open(pickle_name, 'rb') as pickled_file:
                return pickle.load(pickled_file)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Unable to find {pickle_name}, does the file exist?") from err

    @staticmethod
    def add_doc_to_pickle(pickle_dict, counter, doc):
        "Add an entry to the pickle dictionary"
        pickle_dict[counter] = doc

    @staticmethod
    def save_pickle(pickle_file, pickle_name):
        "Save the pickle file locally"
        with open(pickle_name, 'wb') as pf:
            pickle.dump(pickle_file, pf, protocol=pickle.HIGHEST_PROTOCOL)

The PickleObject provides methods to save and retrieve the human-readable data locally. It stores a Python dictionary that maps each position in the FAISS index (an integer key) to the original text, which lets us turn search results back into words.
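A short usage sketch of the round trip (hypothetical entries, written to a throwaway file):

# Build a small dictionary, save it and read it back
pickle_obj = PickleObject()
doc_dict = pickle_obj.create_dict()
pickle_obj.add_doc_to_pickle(pickle_dict=doc_dict, counter=0, doc='python')
pickle_obj.add_doc_to_pickle(pickle_dict=doc_dict, counter=1, doc='java')
pickle_obj.save_pickle(pickle_file=doc_dict, pickle_name='sample.pickle')
print(PickleObject.get_pickle(pickle_name='sample.pickle'))  # {0: 'python', 1: 'java'}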

Putting it all together:
"""
A semantic search tool
"""
import pickle
from pathlib import Path
import faiss
import torch
from bs4 import BeautifulSoup
from transformers import AutoTokenizer, AutoModel
 
class SemanticEmbedding:
    "A semantic embedding object to get the word embeddings"
    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
        "object initialization"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
 
    def mean_pooling(self, model_output, attention_mask):
        """
        Mean Pooling - take the attention mask into account for correct averaging.
        Mean pooling is mainly useful for building a vector for a whole sentence,
        but it works just as well in our case, where the input is a single word.
        """
        # First element of model_output contains all token embeddings
        token_embeddings = model_output[0]
        input_mask_exp = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings*input_mask_exp, 1)/torch.clamp(input_mask_exp.sum(1),
                                                                         min=1e-9)
 
    def get_embedding(self, word):
        "create word embeddings"
        encoded_input = self.tokenizer(word, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        # Perform pooling
        word_embedding = self.mean_pooling(model_output, encoded_input['attention_mask'])
 
        # Normalize embeddings
        word_embedding = torch.nn.functional.normalize(word_embedding, p=2, dim=1)
        return word_embedding.detach().numpy()
 
class FaissIdxObject:
    "A FAISS object to create, add documents to, search and save an index"
    def __init__(self, dim=768):
        "Object initialization"
        self.dim = dim

    def create_index(self):
        "Create a new index"
        return faiss.IndexFlatIP(self.dim)

    @staticmethod
    def get_index(index_name):
        "Read a saved index from disk"
        try:
            return faiss.read_index(index_name)
        except (FileNotFoundError, RuntimeError) as err:
            # FAISS surfaces I/O problems as RuntimeError
            raise FileNotFoundError(f"Unable to find {index_name}, does the file exist?") from err

    @staticmethod
    def add_doc_to_index(index, embedded_document_text):
        "Add a document embedding to the index"
        index.add(embedded_document_text)

    @staticmethod
    def search_index(embedded_query, index, doc_map, k=5, return_scores=False):
        "Search the index for the k nearest neighbours of the query"
        D, I = index.search(embedded_query, k)
        if return_scores:
            value = [{doc_map[idx]: str(score)} for idx, score in zip(I[0], D[0]) if idx in doc_map]
        else:
            value = [doc_map[idx] for idx in I[0] if idx in doc_map]
        return value

    @staticmethod
    def save_index(index, index_name):
        "Save the index to a local file"
        faiss.write_index(index, index_name)
 
class PickleObject:
    "A pickle object to save and read the human-readable dataset"
    def create_dict(self):
        "Create a new dict"
        return {}

    @staticmethod
    def get_pickle(pickle_name):
        "Read the local pickle file"
        try:
            with open(pickle_name, 'rb') as pickled_file:
                return pickle.load(pickled_file)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Unable to find {pickle_name}, does the file exist?") from err

    @staticmethod
    def add_doc_to_pickle(pickle_dict, counter, doc):
        "Add an entry to the pickle dictionary"
        pickle_dict[counter] = doc

    @staticmethod
    def save_pickle(pickle_file, pickle_name):
        "Save the pickle file locally"
        with open(pickle_name, 'wb') as pf:
            pickle.dump(pickle_file, pf, protocol=pickle.HIGHEST_PROTOCOL)
 
class XMLReader:
    "An XML object to read the values from an XML file"
    @staticmethod
    def read_from_file(xml_file, html_property):
        "Read the given attribute from every row of an XML file"
        # html.parser lowercases attribute names, so lowercase the property too
        html_property = html_property.lower()
        with open(xml_file, 'r', encoding='utf-8') as xmlfile:
            xml = xmlfile.read()

        soup = BeautifulSoup(xml, "html.parser")
        rows = soup.find_all('row')
        return [row[html_property] for row in rows]
 
if __name__ == '__main__':
    embedder = SemanticEmbedding()
 
    if not Path('Tags.index').is_file() or not Path('Tags.pickle').is_file():
        faiss_obj = FaissIdxObject()
        pickle_obj = PickleObject()
        xml_reader = XMLReader()
        faiss_index = faiss_obj.create_index()
        doc_dict = pickle_obj.create_dict()
        input_rows = xml_reader.read_from_file(xml_file='Tags.xml',
                                               html_property='TagName')
        for counter, row in enumerate(input_rows):
            embedded_content = embedder.get_embedding(row)
            faiss_obj.add_doc_to_index(index=faiss_index,
                                       embedded_document_text=embedded_content)
            pickle_obj.add_doc_to_pickle(pickle_dict=doc_dict,
                                         counter=counter,
                                         doc=row)
 
        faiss_obj.save_index(index=faiss_index, index_name='Tags.index')
        pickle_obj.save_pickle(pickle_file=doc_dict,pickle_name='Tags.pickle')
 
    else:
        faiss_index = FaissIdxObject.get_index(index_name='Tags.index')
        doc_dict = PickleObject.get_pickle(pickle_name='Tags.pickle')
 
    while True:
        tech = input("\nEnter a tech: ")
        if tech == "exit":
            break
        if tech.strip() == "":
            continue
        embedded_input = embedder.get_embedding(tech)
        output = FaissIdxObject.search_index(embedded_query=embedded_input,
                                             index=faiss_index,
                                             doc_map=doc_dict,
                                             k=10,
                                             return_scores=True)
        print(output)

You can find the snippet here
Note: The script is very rudimentary and yet to go through our stringent code review process.

Output:
Enter a tech: python
[{'python': '1.0'}, {'java': '0.62871754'}, {'coding': '0.62170094'}, {'c#': '0.5832066'}, {'unix': '0.575704'}, {'python-3.x': '0.57134223'}, {'programming-languages': '0.5683065'}, {'languages': '0.55639744'}, {'c++': '0.5557853'}, {'javascript': '0.5530219'}]

From the output, it is clear that the tool not only returned python itself but also a few other entries it considered similar. Since the embeddings are normalized, the scores are cosine similarities; the exact match python scores a perfect 1.0.


What next?

Now that we have a mental model of what happens under the hood of an autocomplete feature, the next obvious question is: how can we test this feature effectively? Our team discussed testing strategies for this application and, in the next post, we will delve into some of those techniques.


Hire technical testers from Qxf2

You will not come across many technical testers who can understand the technical aspects of semantic search, implement FAISS and then apply this knowledge to test products better. But Qxf2 is stacked with such QA engineers. We enjoy the technical aspects of testing, and our approach goes well beyond traditional test automation. If you are working in a highly technical domain and would like good testers on your team – reach out!

