Learn how to use TileDB-Vector-Search to perform similarity search over a protein dataset.
How to run this tutorial
We recommend running this tutorial, as well as the various other tutorials in the Tutorials section, inside TileDB Cloud. This lets you experiment quickly without installation, deployment, and configuration hassles. Sign up for the free tier, spin up a TileDB Cloud notebook with a Python kernel, and follow the tutorial instructions. To run the tutorials locally on your machine, read the Tutorials: Running Locally tutorial.
This tutorial shows how you can use TileDB-Vector-Search to search for similar proteins within a protein dataset.
Dataset
You will use the Swiss-Prot dataset from UniProtKB, which contains about 570k manually annotated proteins with information extracted from the literature and curator-evaluated computational analysis.
Embeddings
Protein embeddings are a way to encode functional and structural properties of a protein from its amino-acid sequence. Generating such embeddings is computationally expensive, but once computed they can be reused for different tasks, such as sequence similarity search, sequence clustering, and sequence classification.
UniProt provides raw embeddings (per-protein and per-residue, computed with the ProtT5 model) for the Swiss-Prot dataset.
The embeddings were generated using the bio_embeddings tool, and the specific model used is prottrans_t5_xl_u50. You can check more details here.
You can also generate the embeddings yourself with the publicly available Hugging Face model ProtT5-XL-UniRef50, as in the following code snippet:
```python
import re

import torch
from transformers import T5EncoderModel, T5Tokenizer

# Run on a GPU if one is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50").to(device)
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")

sequence_examples = ["PRTEINO", "SEQWENCE"]
# this will replace all rare/ambiguous amino acids by X and introduce white-space between all amino acids
sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_examples]

# tokenize sequences and pad up to the longest sequence in the batch
ids = tokenizer.batch_encode_plus(sequence_examples, add_special_tokens=True, padding="longest")
input_ids = torch.tensor(ids["input_ids"]).to(device)
attention_mask = torch.tensor(ids["attention_mask"]).to(device)

# generate embeddings
with torch.no_grad():
    embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask)

# extract embeddings for the first ([0,:]) sequence in the batch while removing padded & special tokens ([0,:7])
emb_0 = embedding_repr.last_hidden_state[0, :7]  # shape (7 x 1024)
print(f"Shape of per-residue embedding of first sequence: {emb_0.shape}")
# do the same for the second ([1,:]) sequence in the batch while taking into account different sequence lengths ([1,:8])
emb_1 = embedding_repr.last_hidden_state[1, :8]  # shape (8 x 1024)

# if you want to derive a single representation (per-protein embedding) for the whole protein
emb_0_per_protein = emb_0.mean(dim=0)  # shape (1024)
print(f"Shape of per-protein embedding of first sequence: {emb_0_per_protein.shape}")
```
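The snippet above pools one sequence at a time by slicing off its residues by hand ([0,:7]). To pool a whole batch at once, a minimal NumPy sketch follows. It assumes the per-residue output has been copied to a NumPy array of shape (batch, tokens, 1024) and that each sequence of length L occupies the first L token positions, followed by the special token and padding, as the slicing above implies. The helper mean_pool_per_protein is hypothetical, not part of transformers or TileDB.

```python
import numpy as np


def mean_pool_per_protein(last_hidden_state: np.ndarray, seq_lengths) -> np.ndarray:
    """Hypothetical helper: average per-residue embeddings into one vector per protein.

    last_hidden_state: (batch, tokens, dim) array, e.g.
        embedding_repr.last_hidden_state.cpu().numpy()
    seq_lengths: amino-acid count of each sequence; only token positions [0, L)
        are averaged, which skips the trailing special token and the padding.
    """
    tokens = last_hidden_state.shape[1]
    lengths = np.asarray(seq_lengths)
    mask = np.arange(tokens)[None, :] < lengths[:, None]          # (batch, tokens) bool
    summed = (last_hidden_state * mask[:, :, None]).sum(axis=1)   # (batch, dim)
    return (summed / lengths[:, None]).astype(last_hidden_state.dtype)


# Toy batch: 2 "proteins" with 3 and 2 residues, 4 token positions, dim = 2
hidden = np.arange(16, dtype=np.float32).reshape(2, 4, 2)
print(mean_pool_per_protein(hidden, [3, 2]))
```

For the two example sequences above, the lengths are those of the raw sequences (7 and 8), not of the space-separated strings passed to the tokenizer, so the first row of mean_pool_per_protein(embedding_repr.last_hidden_state.cpu().numpy(), [7, 8]) should match emb_0_per_protein.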
This tutorial uses the pre-computed per-protein embeddings for the Homo sapiens subset of the Swiss-Prot dataset.
Setup
If you are running this tutorial locally, you will additionally need to install the following:
pip install icn3dpy h5py
Start by importing the necessary libraries and defining URI variables for the different assets:
```python
import os
import shutil
import random

import numpy as np
import h5py
import icn3dpy
import tiledb
import tiledb.vector_search as vs

input_dir = "protein-data"
swiss_prot_uri = "swiss-prot-data"
index_uri = "swiss-prot-index"
```
Download dataset
Download the Swiss-Prot dataset and the per-protein embeddings for its Homo sapiens subset. The following steps expect the files at {input_dir}/uniprot_sprot.fasta and {input_dir}/per-protein.h5.
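A minimal sketch of this step using only the Python standard library follows. The two URLs are an assumption (they are not taken from this tutorial); confirm them on the UniProt download pages, including that the embeddings file covers the proteins you expect. The helper fetch is hypothetical: it saves each file into a directory and gunzips .gz archives so the names match the paths used below.

```python
import gzip
import os
import shutil
import urllib.request

# Assumed UniProt FTP locations: verify before use
SWISS_PROT_FASTA_URL = (
    "https://ftp.uniprot.org/pub/databases/uniprot/current_release/"
    "knowledgebase/complete/uniprot_sprot.fasta.gz"
)
HUMAN_EMBEDDINGS_URL = (
    "https://ftp.uniprot.org/pub/databases/uniprot/current_release/"
    "knowledgebase/embeddings/UP000005640_9606/per-protein.h5"
)


def fetch(url: str, dest_dir: str) -> str:
    """Hypothetical helper: download url into dest_dir, gunzip *.gz files, return the local path."""
    os.makedirs(dest_dir, exist_ok=True)
    local_path = os.path.join(dest_dir, os.path.basename(url))
    urllib.request.urlretrieve(url, local_path)
    if not local_path.endswith(".gz"):
        return local_path
    out_path = local_path[: -len(".gz")]
    with gzip.open(local_path, "rb") as src, open(out_path, "wb") as dst:
        shutil.copyfileobj(src, dst)
    os.remove(local_path)  # keep only the decompressed file
    return out_path
```

Calling fetch(SWISS_PROT_FASTA_URL, input_dir) and fetch(HUMAN_EMBEDDINGS_URL, input_dir) would then leave uniprot_sprot.fasta and per-protein.h5 in input_dir.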
You can now parse the FASTA-encoded Swiss-Prot dataset and load it into a TileDB array for convenient retrieval of protein data.
```python
# Util function to load a FASTA file into a TileDB array
def fasta_to_tiledb(fasta_path, tiledb_array_uri):
    max_size = 1000000
    sequences = np.empty(max_size, dtype="O")
    metadata = np.empty(max_size, dtype="O")
    uniprot_ids = np.empty(max_size, dtype="O")
    prot_ids = np.empty(max_size, dtype="O")
    ids = np.zeros(max_size, dtype=np.uint64)
    prot_id = -1
    with open(fasta_path, "r") as fasta_f:
        for line in fasta_f:
            if line.startswith(">"):
                # Header line, e.g. ">sp|<accession>|<entry name> <description>"
                prot_id += 1
                prot_metadata = line.replace(">", "").strip()
                prot_metadata = (
                    prot_metadata.replace("/", "_").replace(".", "_").split(" ", 1)
                )
                uniprot_id = prot_metadata[0]
                sequences[prot_id] = ""
                uniprot_ids[prot_id] = uniprot_id
                prot_ids[prot_id] = uniprot_id.split("|")[1]
                # Python salts str hashes per process: these IDs only match
                # embedding IDs computed in the same session
                ids[prot_id] = abs(hash(uniprot_id.split("|")[1]))
                metadata[prot_id] = prot_metadata[1]
            else:
                # drop whitespace and gaps and cast to upper-case
                sequences[prot_id] += "".join(line.split()).upper().replace("-", "")
    prot_id += 1  # number of proteins parsed
    with tiledb.open(tiledb_array_uri, mode="w") as swiss_prot_array:
        swiss_prot_array[ids[0:prot_id]] = {
            "sequence": sequences[0:prot_id],
            "metadata": metadata[0:prot_id],
            "uniprot_id": uniprot_ids[0:prot_id],
            "prot_id": prot_ids[0:prot_id],
        }


# Delete the array if it exists
if os.path.isdir(swiss_prot_uri):
    shutil.rmtree(swiss_prot_uri)

# Create the TileDB array: one sparse uint64 dimension and four string attributes
dim = tiledb.Dim(name="id", domain=(0, np.iinfo(np.uint64).max - 1), dtype=np.uint64)
dom = tiledb.Domain(dim)
sequence_attr = tiledb.Attr(name="sequence", dtype=str)
metadata_attr = tiledb.Attr(name="metadata", dtype=str)
uniprot_id_attr = tiledb.Attr(name="uniprot_id", dtype=str)
prot_id_attr = tiledb.Attr(name="prot_id", dtype=str)
schema = tiledb.ArraySchema(
    domain=dom,
    sparse=True,
    attrs=[sequence_attr, metadata_attr, uniprot_id_attr, prot_id_attr],
)
tiledb.Array.create(swiss_prot_uri, schema)

# Load the FASTA file into the TileDB array
fasta_to_tiledb(f"{input_dir}/uniprot_sprot.fasta", swiss_prot_uri)
```
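fasta_to_tiledb preallocates arrays for up to one million records. If you would rather not fix a maximum size, a generator-based parser is one alternative. A minimal sketch, assuming UniProt-style headers of the form ">db|accession|entry_name description"; the name iter_fasta is hypothetical, and unlike fasta_to_tiledb it leaves the description text unchanged (no "/" or "." substitution).

```python
import io
from typing import Iterator, TextIO, Tuple


def iter_fasta(handle: TextIO) -> Iterator[Tuple[str, str, str]]:
    """Hypothetical helper: yield (accession, description, sequence) per FASTA record.

    Assumes UniProt-style headers: ">db|accession|entry_name description".
    """
    accession, description, chunks = None, "", []
    for line in handle:
        if line.startswith(">"):
            if accession is not None:
                yield accession, description, "".join(chunks)
            header, _, description = line[1:].strip().partition(" ")
            accession = header.split("|")[1]
            chunks = []
        else:
            # Same normalization as fasta_to_tiledb: drop whitespace and gaps, upper-case
            chunks.append("".join(line.split()).upper().replace("-", ""))
    if accession is not None:
        yield accession, description, "".join(chunks)


# Usage on a made-up two-record file
toy = io.StringIO(">sp|TOY1|TOY1_HUMAN Toy protein\nmkt-ay\nIAK\n>sp|TOY2|TOY2_HUMAN Second\nGG\n")
for accession, description, sequence in iter_fasta(toy):
    print(accession, description, sequence)
```

The records can then be collected into lists and written with the same swiss_prot_array[...] = {...} assignment that fasta_to_tiledb uses.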
Index
First, let's read the protein embeddings from the downloaded HDF5 file.
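```python
with h5py.File(f"{input_dir}/per-protein.h5", "r") as h5_file:
    # One entry per protein: key = UniProt accession, value = 1024-dim ProtT5 embedding
    size = len(h5_file)
    external_ids = np.zeros(size, dtype=np.uint64)
    embeddings = np.zeros((size, 1024), dtype=np.float32)
    for i, (sequence_id, embedding) in enumerate(h5_file.items()):
        # Same hashing as fasta_to_tiledb, so IDs match the Swiss-Prot array coordinates
        external_ids[i] = abs(hash(sequence_id))
        embeddings[i] = np.array(embedding)
```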
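Both fasta_to_tiledb and the loop above derive IDs with abs(hash(...)). Python salts str hashes per process (see PYTHONHASHSEED), so these IDs line up only when the array and the index are built in the same session. If you build them in separate sessions, a deterministic 64-bit hash is safer. A minimal sketch using hashlib.blake2b; the helper stable_id is hypothetical, and it would have to replace abs(hash(...)) in both places.

```python
import hashlib

import numpy as np

# Largest coordinate allowed by the array's "id" dimension: (0, uint64 max - 1)
MAX_ID = np.iinfo(np.uint64).max - 1


def stable_id(accession: str) -> np.uint64:
    """Hypothetical helper: map an accession to a uint64 that is identical across processes."""
    digest = hashlib.blake2b(accession.encode("utf-8"), digest_size=8).digest()
    value = int.from_bytes(digest, "big")
    # Fold into [0, MAX_ID] so the ID always fits the dimension domain
    return np.uint64(value % (int(MAX_ID) + 1))


print(stable_id("Q8N1N4"))
```

Any 64-bit hash can in principle collide, but for a few hundred thousand accessions the probability is negligible.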
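Index the protein embeddings using an IVF_FLAT index. The hashed accessions are passed as external_ids, so query results come back as the same IDs used as coordinates in the Swiss-Prot array and can be looked up there directly.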
```python
if os.path.isdir(index_uri):
    shutil.rmtree(index_uri)
index = vs.ingest(
    index_type="IVF_FLAT",
    index_uri=index_uri,
    input_vectors=embeddings,
    external_ids=external_ids,
)
```
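Conceptually, an IVF_FLAT index clusters the vectors into partitions (inverted lists) around centroids. At query time it compares the query exactly ("flat") only against the vectors in the nprobe partitions whose centroids are closest. The NumPy sketch below illustrates that idea with squared Euclidean distance and plain k-means. It is a simplified illustration, not TileDB-Vector-Search's implementation, and the names build_ivf and query_ivf are hypothetical.

```python
import numpy as np


def _sq_dists(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances between rows of a (n, d) and b (m, d)."""
    return ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=2)


def build_ivf(vectors, n_lists, n_iter=10, seed=0):
    """Cluster vectors into n_lists partitions with plain k-means.

    Returns (centroids, assignment), where assignment[j] is the partition of vectors[j].
    """
    rng = np.random.default_rng(seed)
    centroids = vectors[rng.choice(len(vectors), size=n_lists, replace=False)].copy()
    for _ in range(n_iter):
        assignment = _sq_dists(vectors, centroids).argmin(axis=1)
        for c in range(n_lists):
            members = vectors[assignment == c]
            if len(members):  # leave a centroid in place if its list is empty
                centroids[c] = members.mean(axis=0)
    assignment = _sq_dists(vectors, centroids).argmin(axis=1)
    return centroids, assignment


def query_ivf(query, vectors, ids, centroids, assignment, k, nprobe):
    """Return (distances, ids) of the k nearest vectors within the nprobe closest partitions."""
    probed = np.argsort(_sq_dists(query[None, :], centroids)[0])[:nprobe]
    candidates = np.flatnonzero(np.isin(assignment, probed))
    # "Flat" step: exact distances to every vector in the probed partitions
    cand_dists = _sq_dists(query[None, :], vectors[candidates])[0]
    order = np.argsort(cand_dists)[:k]
    return cand_dists[order], ids[candidates[order]]


# Usage on random data
rng = np.random.default_rng(42)
data = rng.normal(size=(500, 16)).astype(np.float32)
data_ids = np.arange(1000, 1500, dtype=np.uint64)
centroids, assignment = build_ivf(data, n_lists=10)
dists, found = query_ivf(data[7], data, data_ids, centroids, assignment, k=4, nprobe=3)
print(found)
```

With nprobe equal to the number of partitions the search becomes exact; smaller values scan fewer vectors at the risk of missing true neighbors. Because each stored vector sits in the partition of its nearest centroid, querying with a stored vector returns that vector first at distance 0.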
Protein similarity search
Pick a query protein
Start by picking a random protein and displaying its 3D structure:
```python
from IPython.display import display  # predefined in Jupyter; imported here to be explicit

# Open the Swiss-Prot vector index
index = vs.IVFFlatIndex(uri=index_uri)

# Pick a random protein from the embedded (Homo sapiens) subset
random_prot_id = random.randrange(size)
with tiledb.open(swiss_prot_uri) as swiss_prot_array:
    random_prot_data = swiss_prot_array[external_ids[random_prot_id]]
print(f"Query protein: {random_prot_data['prot_id'][0]} metadata: {random_prot_data['metadata']}")

# Render the protein's 3D structure with iCn3D, looked up by UniProt accession
view = icn3dpy.view(q=f"mmdbafid={random_prot_data['prot_id'][0]}")
display(view)
```
Query protein: Q8N1N4 metadata: ['Keratin, type II cytoskeletal 78 OS=Homo sapiens OX=9606 GN=KRT78 PE=1 SV=2']
Similarity search
Now search for similar proteins in the index:
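```python
# Search for similar proteins in Swiss-Prot. k=4 because the closest hit is normally
# the query protein itself; nprobe is the number of IVF partitions scanned.
# d holds the distances and i the external IDs of the nearest neighbors.
d, i = index.query(np.array([embeddings[random_prot_id]]), k=4, nprobe=30)
```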
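IVF_FLAT search is approximate: with nprobe=30 only 30 partitions are scanned, so a true neighbor in an unprobed partition can be missed. One way to check how much this matters is to compare against exact brute-force search over the in-memory embeddings array. A minimal sketch, assuming squared Euclidean distance is the metric of interest; exact_knn and recall_at_k are hypothetical helpers.

```python
import numpy as np


def exact_knn(query, vectors, ids, k):
    """Hypothetical helper: brute-force k nearest neighbors by squared Euclidean distance."""
    dists = ((vectors - query) ** 2).sum(axis=1)
    order = np.argsort(dists)[:k]
    return dists[order], ids[order]


def recall_at_k(approx_ids, exact_ids) -> float:
    """Fraction of the exact k nearest neighbors that the approximate search also returned."""
    exact = set(np.asarray(exact_ids).tolist())
    return len(exact & set(np.asarray(approx_ids).tolist())) / len(exact)
```

In the notebook, _, exact_ids = exact_knn(embeddings[random_prot_id], embeddings, external_ids, k=4) followed by recall_at_k(i[0], exact_ids) gives 1.0 when the index returned the same four proteins as an exhaustive scan.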
Display results
Finally, display the results along with their 3D structure:
```python
# Display the results, skipping the first hit (normally the query protein itself)
with tiledb.open(swiss_prot_uri) as swiss_prot_array:
    for similar_prot_id in i[0][1:]:
        similar_prot_data = swiss_prot_array[similar_prot_id]
        print(
            f"Similar protein: {similar_prot_data['prot_id'][0]} "
            f"metadata: {similar_prot_data['metadata']}"
        )
        view = icn3dpy.view(q=f"mmdbafid={similar_prot_data['prot_id'][0]}")
        display(view)
```
Similar protein: Q14CN4 metadata: ['Keratin, type II cytoskeletal 72 OS=Homo sapiens OX=9606 GN=KRT72 PE=1 SV=2']
Similar protein: Q86Y46 metadata: ['Keratin, type II cytoskeletal 73 OS=Homo sapiens OX=9606 GN=KRT73 PE=1 SV=1']
Similar protein: Q3SY84 metadata: ['Keratin, type II cytoskeletal 71 OS=Homo sapiens OX=9606 GN=KRT71 PE=1 SV=3']
Clean up
Clean up all the generated data.
```python
# Clean up past data
for uri in (input_dir, swiss_prot_uri, index_uri):
    if os.path.exists(uri):
        shutil.rmtree(uri)
```