Learn how to augment LLMs with conversational memory using a TileDB-Vector-Search index on the conversation history.
How to run this tutorial
We recommend running this tutorial, as well as the other tutorials in the Tutorials section, inside TileDB Cloud. This will allow you to experiment quickly, avoiding all the installation, deployment, and configuration hassles. Sign up for the free tier, spin up a TileDB Cloud notebook with a Python kernel, and follow the tutorial instructions. If you wish to learn how to run tutorials locally on your machine, read the Tutorials: Running Locally tutorial.
In this tutorial, you will learn how to use TileDB-Vector-Search to store the interaction history of a user with an LLM. This allows an LLM to remember past conversations and user preferences, and answer questions about them appropriately.
Set up
To run this tutorial, you will need an OpenAI API key. In addition, if you wish to use your local machine instead of a TileDB Cloud notebook, you will need to install a few Python packages.
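The package names below are inferred from the imports used later in this tutorial and may vary slightly depending on your environment and LangChain version; treat this as a minimal sketch rather than a pinned requirements list. The API key value is a placeholder you must replace with your own.

# Install the libraries used in this tutorial (package names assumed from the
# imports below; pin versions as appropriate for your environment).
%pip install tiledb-vector-search langchain openai numpy

# The OpenAI integrations read the API key from the OPENAI_API_KEY environment
# variable. Replace the placeholder with your own key.
import os
os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder, not a real key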
Import the necessary libraries, set the URI you will use throughout the tutorial, and clean up any previously generated data.
import numpy as np
import os
import shutil

from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import VectorStoreRetrieverMemory
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationChain
from langchain.vectorstores.tiledb import TileDB

# URI to be used throughout the tutorial
index_uri = "memory_index"

# Clean up past data
if os.path.exists(index_uri):
    shutil.rmtree(index_uri)
Set up history
Next, create a vector index that will hold the LLM history, and add some past conversations to it.
# Create a TileDB vector index to store the conversation history
embedding_size = 1536  # Dimensions of the OpenAIEmbeddings
TileDB.create(
    index_uri=index_uri,
    index_type="IVF_FLAT",
    dimensions=embedding_size,
    vector_type=np.float32,
)
vectorstore = TileDB.load(
    index_uri=index_uri,
    embedding=OpenAIEmbeddings(),
    allow_dangerous_deserialization=True,
)
retriever = vectorstore.as_retriever(search_kwargs=dict(k=2))
memory = VectorStoreRetrieverMemory(retriever=retriever)

# Add some conversation history
memory.save_context({"input": "My name is Nikos"}, {"output": "Hello Nikos"})
memory.save_context(
    {"input": "My favorite food is pizza"}, {"output": "This is a classic choice"}
)
memory.save_context(
    {"input": "Blue is the best color"}, {"output": "Green is also nice"}
)
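Before wiring the memory into a chat model, you can optionally inspect what the vector-backed memory retrieves for a given prompt. The snippet below is a small sketch using LangChain's standard load_memory_variables method; the sample query string is just an illustrative example.

# Inspect what the memory retrieves for a sample prompt. The most similar
# past exchanges (k=2 above) come back as the "history" variable that will
# later be injected into the LLM prompt.
print(memory.load_memory_variables({"input": "What do I like to eat?"})["history"])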
Chat with memory
Now, initialize GPT-3.5 Turbo, passing the vector index as its memory, and ask some questions about the past history you created above. Observe that the model successfully answers those questions, properly taking the history into account.
llm = ChatOpenAI(
    model="gpt-3.5-turbo",
)
qa = ConversationChain(llm=llm, memory=memory)

question = "What is my name?"
print(f"User: {question}")
print(f"AI: {qa.predict(input=question)}\n")

question = "Are there any football teams with my favorite color in England?"
print(f"User: {question}")
print(f"AI: {qa.predict(input=question)}\n")

question = "Please suggest a recipe for my favorite food"
print(f"User: {question}")
print(f"AI: {qa.predict(input=question)}\n")
User: What is my name?
AI: Your name is Nikos.
User: Are there any football teams with my favorite color in England?
AI: Yes, there are football teams in England that have blue as their primary color. Some examples include Chelsea FC, Everton FC, and Manchester City FC.
User: Please suggest a recipe for my favorite food
AI: There are so many delicious pizza recipes out there! One popular option is a classic Margherita pizza with fresh basil, mozzarella cheese, and tomato sauce on a thin crust. Or you could try a BBQ chicken pizza with barbecue sauce, chicken, red onions, and cilantro. Another tasty choice is a veggie supreme pizza with bell peppers, mushrooms, olives, and onions. The possibilities are endless! Let me know if you'd like more specific details on any of these recipes.
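Because the ConversationChain writes each new exchange back into the memory after every call, the vector index keeps growing as the conversation continues. As a purely illustrative follow-up (the question below is not part of the original example, and the model's answer will vary), you can also reference something mentioned earlier in this same session:

# The chain saved the previous turns to the vector store, so follow-up
# questions can draw on this session's exchanges as well.
question = "Which of the pizza recipes you suggested is vegetarian?"
print(f"User: {question}")
print(f"AI: {qa.predict(input=question)}\n")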
Clean up
Clean up all the generated data.
# Clean up past data
if os.path.exists(index_uri):
    shutil.rmtree(index_uri)