
ML Models Quickstart

Learn how to install TileDB-ML and work with various libraries, including TensorFlow and PyTorch.

This quickstart gives you a rapid introduction to TileDB-ML and its capabilities. It’s split into two main categories:

  1. Models in TileDB
  2. Data in TileDB

This tutorial will teach you how to:

  • Store an ML model in TileDB and load it back.
  • Register a TileDB-ML model in TileDB Cloud.
  • Store a basic dataset in TileDB and load it back.
  • Run sample queries on this dataset.
  • Register a dataset in TileDB Cloud.
  • Run model training on a dataset in TileDB.

Prerequisites

Familiarize yourself with Jupyter notebooks to run data exploration and analysis efficiently. You can review Jupyter’s documentation on installing and running notebooks.

Installation

You can install TileDB-ML in two ways: via pip, the preferred mechanism, or from source.

  • Pip

pip install tiledb-ml

  • Source

gh repo clone TileDB-Inc/TileDB-ML
cd TileDB-ML
pip install .

Installation options

Given that TileDB-ML integrates with many ML frameworks, you have the following options:

  1. Install all frameworks.
  2. Limit TileDB-ML’s installation to your preferred framework, without installing extra dependencies.

Pip

  • PyTorch

pip install tiledb-ml[pytorch]
# or, in zsh, escape the brackets:
pip install tiledb-ml\[pytorch\]

  • TensorFlow

pip install tiledb-ml[tensorflow]
# or, in zsh, escape the brackets:
pip install tiledb-ml\[tensorflow\]

Complete installation

To install TileDB-ML with all its dependencies and all options enabled, run the following:

  • Bash

pip install tiledb-ml[full]

  • Zsh

pip install tiledb-ml\[full\]
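
To verify the installation, try importing the package from Python. This one-liner is just a quick sanity check and not part of the official instructions:

python -c "import tiledb.ml; print('TileDB-ML imported successfully')"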

Basic ML example

This quickstart guide will show you how to ingest the MNIST dataset, train an ML model, and use it for inference on this dataset.

Note

The MNIST dataset is a widely used benchmark dataset in the field of machine learning. It consists of a collection of 28×28 pixel grayscale images of handwritten digits (0 to 9), along with their corresponding labels showing the digit represented in each image. The dataset is commonly used for training and evaluating machine learning models, particularly for image classification tasks.

  • PyTorch

Import libraries

Start by importing the libraries used in this tutorial.

import os
import tempfile

import idx2numpy
import matplotlib.pyplot as plt
import numpy as np
import tiledb
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from tiledb.ml.models.pytorch import PyTorchTileDBModel
from tiledb.ml.readers.pytorch import PyTorchTileDBDataLoader
from tiledb.ml.readers.types import ArrayParams

Download the dataset

def load_mnist_data():
    data_home = os.path.join(os.path.pardir, "data")
    # Downloading any split fetches all four raw MNIST files,
    # including the training images and labels read below.
    _ = torchvision.datasets.MNIST(root=data_home, train=False, download=True)
    img_path = os.path.join(data_home, "MNIST/raw/train-images-idx3-ubyte")
    labels_path = os.path.join(data_home, "MNIST/raw/train-labels-idx1-ubyte")
    images = idx2numpy.convert_from_file(img_path)
    labels = idx2numpy.convert_from_file(labels_path)
    return images, labels


(images, labels) = load_mnist_data()

Ingest in TileDB

def ingest_in_tiledb(data: np.ndarray, batch_size: int, uri: str):
    # Create one TileDB dimension per NumPy array dimension.
    dims = [
        tiledb.Dim(
            name="dim_" + str(dim),
            domain=(0, data.shape[dim] - 1),
            tile=data.shape[dim] if dim > 0 else batch_size,
            dtype=np.int32,
        )
        for dim in range(data.ndim)
    ]
    # Define the TileDB array schema.
    schema = tiledb.ArraySchema(
        domain=tiledb.Domain(*dims),
        sparse=False,
        attrs=[tiledb.Attr(name="features", dtype=data.dtype)],
    )
    # Create the array on disk.
    tiledb.Array.create(uri, schema)
    # Write the data.
    with tiledb.open(uri, "w") as tiledb_array:
        tiledb_array[:] = {"features": data}


dataset = tempfile.mkdtemp("mnist_array")
# Ingest images
training_images = os.path.join(dataset, "training_images")
ingest_in_tiledb(data=images, batch_size=64, uri=training_images)

# Ingest labels
training_labels = os.path.join(dataset, "training_labels")
ingest_in_tiledb(data=labels, batch_size=64, uri=training_labels)

TileDB dataset

images_array = tiledb.open(training_images)
labels_array = tiledb.open(training_labels)

Array schemas

images_array.schema

Domain
  Name   Domain      Tile  Data Type  Is Var-Length  Filters
  dim_0  (0, 59999)  64    int32      False          -
  dim_1  (0, 27)     28    int32      False          -
  dim_2  (0, 27)     28    int32      False          -

Attributes
  Name      Data Type  Is Var-Length  Is Nullable  Filters
  features  uint8      False          False        -

Cell Order: row-major
Tile Order: row-major
Sparse: False

labels_array.schema

Domain
  Name   Domain      Tile  Data Type  Is Var-Length  Filters
  dim_0  (0, 59999)  64    int32      False          -

Attributes
  Name      Data Type  Is Var-Length  Is Nullable  Filters
  features  uint8      False          False        -

Cell Order: row-major
Tile Order: row-major
Sparse: False
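
Sample queries

With the data ingested, you can already run sample queries against these arrays with the core TileDB API. The following is a minimal sketch that slices a few cells from the training_images and training_labels arrays created above:

# Read the first five images and their labels straight from storage.
with tiledb.open(training_images) as imgs, tiledb.open(training_labels) as lbls:
    first_images = imgs[0:5]["features"]  # NumPy array of shape (5, 28, 28)
    first_labels = lbls[0:5]["features"]  # NumPy array of shape (5,)
    print(first_images.shape, first_labels)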

  • TensorFlow

Import libraries

Start by importing the libraries used in this tutorial.

import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tiledb
from tiledb.ml.models.tensorflow_keras import TensorflowKerasTileDBModel
from tiledb.ml.readers.tensorflow import ArrayParams, TensorflowTileDBDataset

Download the dataset

# Load MNIST from Keras and scale the pixel values to [0, 1].
(images, labels), _ = tf.keras.datasets.mnist.load_data()
images = images / 255.0

Ingest in TileDB

def ingest_in_tiledb(data: np.ndarray, batch_size: int, uri: str):
    # Create one TileDB dimension per NumPy array dimension.
    dims = [
        tiledb.Dim(
            name="dim_" + str(dim),
            domain=(0, data.shape[dim] - 1),
            tile=data.shape[dim] if dim > 0 else batch_size,
            dtype=np.int32,
        )
        for dim in range(data.ndim)
    ]
    # Define the TileDB array schema.
    schema = tiledb.ArraySchema(
        domain=tiledb.Domain(*dims),
        sparse=False,
        attrs=[tiledb.Attr(name="features", dtype=data.dtype)],
    )
    # Create the array on disk.
    tiledb.Array.create(uri, schema)
    # Write the data.
    with tiledb.open(uri, "w") as tiledb_array:
        tiledb_array[:] = {"features": data}


dataset = tempfile.mkdtemp("mnist_array")
# Ingest images
training_images = os.path.join(dataset, "training_images")
ingest_in_tiledb(data=images, batch_size=64, uri=training_images)

# Ingest labels
training_labels = os.path.join(dataset, "training_labels")
ingest_in_tiledb(data=labels, batch_size=64, uri=training_labels)

TileDB dataset

images_array = tiledb.open(training_images)
labels_array = tiledb.open(training_labels)

Array schemas

images_array.schema

Domain
  Name   Domain      Tile  Data Type  Is Var-Length  Filters
  dim_0  (0, 59999)  64    int32      False          -
  dim_1  (0, 27)     28    int32      False          -
  dim_2  (0, 27)     28    int32      False          -

Attributes
  Name      Data Type  Is Var-Length  Is Nullable  Filters
  features  float64    False          False        -

Cell Order: row-major
Tile Order: row-major
Sparse: False

labels_array.schema

Domain
  Name   Domain      Tile  Data Type  Is Var-Length  Filters
  dim_0  (0, 59999)  64    int32      False          -

Attributes
  Name      Data Type  Is Var-Length  Is Nullable  Filters
  features  uint8      False          False        -

Cell Order: row-major
Tile Order: row-major
Sparse: False
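
Sample queries

As in the PyTorch tab, you can run sample queries against the ingested arrays with the core TileDB API. This minimal sketch slices a few cells from the training_images and training_labels arrays created above:

# Read the first five images and their labels straight from storage.
with tiledb.open(training_images) as imgs, tiledb.open(training_labels) as lbls:
    first_images = imgs[0:5]["features"]  # NumPy array of shape (5, 28, 28)
    first_labels = lbls[0:5]["features"]  # NumPy array of shape (5,)
    print(first_images.shape, first_labels)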

Dataloaders

TileDB offers an API with native dataloaders for all the ML frameworks with which TileDB-ML integrates. After you store your data, you can use this API to create a dataloader for your framework of choice, which then serves as input to the model’s training stage. The API takes two TileDB arrays as inputs: x, which holds the sample data, and y, which holds the label corresponding to each sample in x. The dataloader collates the two arrays into a single data object that you can later feed into model training.

Note

Jupyter notebooks have limited support for Python multiprocessing. Avoid using multiple workers in Jupyter when you need multiprocessing; instead, run such scripts with a regular Python interpreter.

  • PyTorch

with tiledb.open(training_images) as x, tiledb.open(training_labels) as y:
    train_loader = PyTorchTileDBDataLoader(
        ArrayParams(x),
        ArrayParams(y),
        batch_size=128,
        num_workers=0,
        shuffle_buffer_size=256,
    )
    batch_imgs, batch_labels = next(iter(train_loader))
    print(f"Input Shape: {batch_imgs.shape}")
    print(f"Label Shape: {batch_labels.shape}")
Input Shape: torch.Size([128, 28, 28])
Label Shape: torch.Size([128])
  • TensorFlow

with tiledb.open(training_images) as x, tiledb.open(training_labels) as y:
    tiledb_dataset = TensorflowTileDBDataset(
        ArrayParams(array=x, fields=["features"]),
        ArrayParams(array=y, fields=["features"]),
    )
    batched_dataset = tiledb_dataset.batch(128)
    batch_imgs, batch_labels = next(batched_dataset.as_numpy_iterator())
    print(f"Input Shape: {batch_imgs.shape}")
    print(f"Label Shape: {batch_labels.shape}")
Input Shape: (128, 28, 28)
Label Shape: (128,)

Visualization

Render the first image from the batched data fetched by TileDB-ML loaders:

  • PyTorch

image = batch_imgs[0]
plt.subplot(1, 2, 1)
plt.imshow(image, cmap="gray")

Figure 1: The first image of the PyTorch batch, rendered in grayscale.

  • TensorFlow

image = batch_imgs[0]
plt.subplot(1, 2, 1)
plt.imshow(image, cmap="gray")

Figure 2: The first image of the TensorFlow batch, rendered in grayscale.

Model training

  • PyTorch

Configure the model

epochs = 2
batch_size_train = 128
learning_rate = 0.01
momentum = 0.5
log_interval = 100

Define the model

class Net(nn.Module):
    def __init__(self, shape):
        super(Net, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(np.prod(shape), 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

Model training

train_losses = []
train_counter = []

# Instantiate the model defined above before creating the optimizer.
network = Net(shape=(28, 28))
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)


def do_random_noise(img, mag=0.1):
    # Add uniform random noise and clip to the [0, 1] range
    # (this assumes the image values are scaled to [0, 1]).
    noise = np.random.uniform(-1, 1, img.shape) * mag
    img = img + noise
    img = np.clip(img, 0, 1)
    return img


with tiledb.open(training_images) as x, tiledb.open(training_labels) as y:
    train_loader = PyTorchTileDBDataLoader(
        ArrayParams(x, fn=do_random_noise),
        ArrayParams(y),
        batch_size=batch_size_train,
        num_workers=0,
        shuffle_buffer_size=256,
    )
    for epoch in range(1, epochs + 1):
        network.train()
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            # Zero the parameter gradients.
            optimizer.zero_grad()
            # Forward pass, backward pass, optimizer step.
            outputs = network(inputs.to(torch.float))
            loss = criterion(outputs, labels.long())
            loss.backward()
            optimizer.step()
            if batch_idx % log_interval == 0:
                print(
                    "Train Epoch: {} Batch: {} Loss: {:.6f}".format(
                        epoch, batch_idx, loss.item()
                    )
                )
Train Epoch: 1 Batch: 0 Loss: 2.306959
Train Epoch: 1 Batch: 100 Loss: 2.268600
Train Epoch: 1 Batch: 200 Loss: 2.188591
Train Epoch: 1 Batch: 300 Loss: 1.987314
Train Epoch: 1 Batch: 400 Loss: 1.855456
Train Epoch: 2 Batch: 0 Loss: 1.624074
Train Epoch: 2 Batch: 100 Loss: 1.569407
Train Epoch: 2 Batch: 200 Loss: 1.558700
Train Epoch: 2 Batch: 300 Loss: 1.329881
Train Epoch: 2 Batch: 400 Loss: 1.217714

  • TensorFlow

Define the model

model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10),
    ]
)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])

Training

with tiledb.open(training_images) as x, tiledb.open(training_labels) as y:
    tiledb_dataset = TensorflowTileDBDataset(
        ArrayParams(array=x),
        ArrayParams(array=y),
        num_workers=2 if os.cpu_count() > 2 else 0,
    )
    tiledb_dataset = tiledb_dataset.batch(64).shuffle(128)
    model.fit(tiledb_dataset, epochs=5)
Epoch 1/5
938/938 [==============================] - 3s 2ms/step - loss: 0.0755 - accuracy: 0.9772
Epoch 2/5
938/938 [==============================] - 3s 2ms/step - loss: 0.0675 - accuracy: 0.9786
Epoch 3/5
938/938 [==============================] - 3s 2ms/step - loss: 0.0590 - accuracy: 0.9806
Epoch 4/5
938/938 [==============================] - 3s 2ms/step - loss: 0.0551 - accuracy: 0.9828
Epoch 5/5
938/938 [==============================] - 3s 2ms/step - loss: 0.0500 - accuracy: 0.9837
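
As a quick sanity check (not part of the original flow), you can evaluate the trained model against the same TileDB-backed dataset. Because this evaluates on the training data, treat the score as a smoke test rather than a true benchmark:

with tiledb.open(training_images) as x, tiledb.open(training_labels) as y:
    eval_dataset = TensorflowTileDBDataset(
        ArrayParams(array=x),
        ArrayParams(array=y),
    ).batch(64)
    # Returns [loss, accuracy], since the model was compiled with the
    # accuracy metric.
    loss, accuracy = model.evaluate(eval_dataset)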

Store models

  • PyTorch

Configure store paths

model_dir = tempfile.mkdtemp("mnist_model")
model_uri = os.path.join(model_dir, "mnist-1")

Save the model

# Save the network and optimizer trained above, together with metadata.
# (Re-creating the network here would discard the trained weights.)
tiledb_model = PyTorchTileDBModel(uri=model_uri, model=network, optimizer=optimizer)
tiledb_model.save(meta={"epochs": epochs})
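
Load the model

To complete the store-and-load goal, the sketch below restores the saved state into a fresh network and runs inference on the batch fetched earlier. The keyword-argument load signature reflects the author's reading of the TileDB-ML API; verify it against the API reference:

# Restore the saved weights into a fresh model and optimizer.
loaded_network = Net(shape=(28, 28))
loaded_optimizer = optim.SGD(
    loaded_network.parameters(), lr=learning_rate, momentum=momentum
)
PyTorchTileDBModel(uri=model_uri).load(
    model=loaded_network, optimizer=loaded_optimizer
)

# Run inference on the batch fetched earlier.
loaded_network.eval()
with torch.no_grad():
    predictions = loaded_network(batch_imgs.to(torch.float)).argmax(dim=1)
print(predictions[:10])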

  • TensorFlow

Configure store paths

model_dir = tempfile.mkdtemp("mnist_model")
uri = os.path.join(model_dir, "mnist-1")

Save the model

tiledb_model = TensorflowKerasTileDBModel(uri=uri, model=model)
tiledb_model.save(include_optimizer=True)
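
Load the model

Similarly for TensorFlow, plus a note on TileDB Cloud registration. Both the compile_model flag and the namespace argument are assumptions about the TileDB-ML API; check the API reference before relying on them:

# Load the Keras model back from the TileDB array.
loaded_model = TensorflowKerasTileDBModel(uri=uri).load(compile_model=True)

# To register a model in TileDB Cloud, the model classes accept a
# namespace argument (assumption; this requires TileDB Cloud credentials
# and a cloud storage URI):
# cloud_model = TensorflowKerasTileDBModel(
#     uri="s3://your-bucket/mnist-1",  # hypothetical storage path
#     namespace="your_username",       # hypothetical TileDB Cloud namespace
#     model=model,
# )
# cloud_model.save(include_optimizer=True)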