scrna5/6 Jupyter Notebook lamindata

Train an ML model on a dataset#

In the previous tutorial, we loaded an entire dataset into memory to perform a simple analysis.

Here, we’ll iterate over the files within the dataset to train an ML model.

import lamindb as ln
import anndata as ad
import numpy as np
💡 lamindb instance: testuser1/test-scrna
ln.track()
💡 notebook imports: anndata==0.9.2 lamindb==0.61.0 numpy==1.26.2 torch==2.1.1
💡 saved: Transform(uid='Qr1kIHvK506rz8', name='Train an ML model on a dataset', short_name='scrna5', version='0', type=notebook, updated_at=2023-11-20 22:27:32 UTC, created_by_id=1)
💡 saved: Run(uid='gVos6BQcOwDHyQBIj7bQ', run_at=2023-11-20 22:27:32 UTC, transform_id=5, created_by_id=1)

Preprocessing#

Let us get our dataset:

dataset_v2 = ln.Dataset.filter(name="My versioned scRNA-seq dataset", version="2").one()
dataset_v2
Dataset(uid='xDJG50qUj6fnqYEgYaA5', name='My versioned scRNA-seq dataset', version='2', hash='-J1PZEjWCBP0OptD6HtZ', visibility=0, updated_at=2023-11-20 22:27:09 UTC, transform_id=2, run_id=2, initial_version_id=1, created_by_id=1)

We’ll need to make a decision on the features that we want to use for training the model.

Because each file is validated, they’re all indexed by ensembl_gene_id in the var slot of AnnData.

To make our live easy, we’ll intersect features across files:

files = dataset_v2.files.all()
# the gene sets are stored in the "var" slot of features
shared_genes = files[0].features["var"]
for file in files[1:]:
    # QuerySet objects allow set operations
    shared_genes = shared_genes & file.features["var"]
shared_genes_ensembl = shared_genes.list("ensembl_gene_id")

We’ll now store the raw representations and create a training dataset:

raw_files = []
for file in files:
    adata_raw = file.load().raw[:, shared_genes_ensembl].to_adata()
    raw_file = ln.File(adata_raw, description=f"Raw data of file {file.uid}")
    raw_files.append(raw_file)
ln.save(raw_files)

ds_train = ln.Dataset(raw_files, name="My training dataset", version="2")
ds_train.save()
ds_train.view_flow()
Hide code cell output
_images/f8df0ff39fed97d3a3fb6c90c8ef9c4a03eb68e089742ceaf16ffd4202c871f3.svg

Create a pytorch DataLoader#

If you need to train your model on a list of files, you can use ln.Dataset.indexed() with the pytorch DataLoader.

It doesn’t load anything into memory and thus allows to work with very large datasets.

from torch.utils.data import DataLoader, WeightedRandomSampler

Files in the dataset should have the same variables, we have already taken care of this.

ds_mapped = ds_train.mapped(label_keys=["cell_type"])

This is compatible with pytorch DataLoader because it implements __getitem__ over a list of AnnData files.

ds_mapped[5]
[array([ 0.,  0.,  0.,  2.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0., 17.,  0.,  0.,  0.,  2.,  0.,  0.,  2.,  1.,
         0.,  0.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,
         0.,  0.,  0.,  0.,  0.,  4.,  0.,  2.,  3.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,
         3.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,
         1.,  0.,  3.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,
         2.,  0.,  0.,  5.,  6.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,
         0.,  0.,  0.,  1.,  1.,  1.,  1.,  3.,  0.,  0.,  4.,  1.,  3.,
         0.,  0.,  0.,  0.,  0.,  2.,  0.,  2.,  1.,  0.,  0.,  1.,  0.,
         0.,  0.,  1.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  2.,  0.,  1.,
         0.,  0.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         5.,  0.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,
         0.,  2.,  1.,  0.,  0.,  1.,  3.,  4.,  1.,  0.,  2.,  1.,  1.,
         1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,  2.,  0.,  0.,  3.,
         0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 96.,  0.,  6.,
         1.,  2.,  0.,  0.,  1.,  0.,  6.,  1.,  0.,  0.,  1.,  0.,  2.,
         0.,  3.,  0.,  0.,  2., 10.,  0.,  0.,  0.,  5.,  1., 26.,  2.,
        14.,  6.,  5.,  0.,  3.,  0.,  8., 10.,  0.,  1.,  0.,  1.,  0.,
         1.,  1.,  0.,  5.,  1.,  0.,  3.,  1.,  1.,  1.,  0.,  0.,  0.,
         2.,  1.,  3.,  0.,  0.,  1.,  3.,  3.,  0.,  0.,  2.,  0.,  1.,
         0.,  4.,  1.,  0.,  0.,  1.,  0.,  0.,  1.,  1.,  7.,  1.,  0.,
         0.,  0.,  0.,  0.,  0.,  2.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  4.,  0.,  0.,  0.,  0.,  1.,  0.,  1.,  0.,  0.,
         0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  2.,  0.,  0.,  0.,
         0.,  0.,  0.,  2.,  0.,  1.,  2.,  0.,  1.,  0.,  2.,  0.,  1.,
         0.,  1.,  0.,  0.,  1.,  0.,  0.,  0., 17.,  0.,  0.,  0.,  0.,
         4.,  0.,  0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         2.,  1., 13.,  0.,  1.,  1.,  1.,  0.,  0.,  1.,  0.,  1.,  0.,
         0.,  0.,  0.,  2.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  2.,  2.,  1.,  0.,  0.,  0.,  6.,  0.,  2.,  0.,  0.,
         0.,  0.,  1.,  1.,  0.,  1.,  0.,  2.,  0.,  0.,  1.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  1.,  0.,  1.,  0.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  2.,  0.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  4.,  0.,  6.,  1.,  0.,  0.,  0.,  2.,  0.,  2.,
         1.,  0.,  0.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  5.,  0.,  7.,  1.,  0.,
         0.,  0.,  0.,  1., 10.,  1.,  6.,  0.,  0.,  1.,  4.,  0.,  0.,
         0.,  0.,  0.,  2.,  0.,  1.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,
         0.,  0.,  0.,  0.,  0.,  0.,  1.,  2.,  0.,  0.,  0.,  0.,  1.,
         0.,  0.,  0.,  0.,  1.,  2.,  1.,  0.,  0.,  0.,  4.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  3.,  0.,  0.,  0.,  0.,
         0.,  0.,  2.,  0.,  0.,  0.,  3.,  0.,  0.,  3.,  0.,  0.,  0.,
         0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  0.,  0.,  2.,
         0.,  6.,  0.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  2.,  0.,  0.,  0.,  1.,
         1.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,  2.,
         0.,  0.,  0.,  0.,  0.,  1.,  3.,  4.,  0.,  0.,  0.,  1.,  0.,
        10.,  0.,  1.,  0.,  1.,  1.,  0.,  3.,  0.,  0.,  0.,  0.,  0.,
         1.,  0.,  0.,  4.,  0.,  0.,  0.,  0.,  0., 15.,  0.,  6.,  0.,
         1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,
         0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  2.,  3.,  0.,  6.,
         1.,  1.,  0.,  1.,  1.,  0.,  2.,  0.,  1.,  0.,  0.,  0.,  1.,
         1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,
         0.,  4.,  0.,  1.,  1.,  0.,  0.,  0.,  2.,  0.,  0.,  0.,  2.,
         1.,  1.,  0.,  0.,  0.,  0.,  1.,  1.,  1.,  2.,  0.,  1.,  0.,
         0.,  0., 53.,  1.,  0.,  1.,  0., 35.]),
 3]

The labels are encoded into integers.

ds_mapped.encoders
[{'CD4-positive helper T cell': 0,
  'classical monocyte': 1,
  'megakaryocyte': 2,
  'memory B cell': 3,
  'gamma-delta T cell': 4,
  'effector memory CD8-positive, alpha-beta T cell, terminally differentiated': 5,
  'CD16-negative, CD56-bright natural killer cell, human': 6,
  'group 3 innate lymphoid cell': 7,
  'cytotoxic T cell': 8,
  'lymphocyte': 9,
  'dendritic cell': 10,
  'naive thymus-derived CD8-positive, alpha-beta T cell': 11,
  'macrophage': 12,
  'progenitor cell': 13,
  'mucosal invariant T cell': 14,
  'plasmablast': 15,
  'B cell, CD19-positive': 16,
  'naive thymus-derived CD4-positive, alpha-beta T cell': 17,
  'alveolar macrophage': 18,
  'CD38-positive naive B cell': 19,
  'non-classical monocyte': 20,
  'dendritic cell, human': 21,
  'plasma cell': 22,
  'effector memory CD4-positive, alpha-beta T cell': 23,
  'CD8-positive, alpha-beta memory T cell, CD45RO-positive': 24,
  'CD16-positive, CD56-dim natural killer cell, human': 25,
  'conventional dendritic cell': 26,
  'CD8-positive, alpha-beta memory T cell': 27,
  'animal cell': 28,
  'CD14-positive, CD16-negative classical monocyte': 29,
  'germinal center B cell': 30,
  'CD4-positive, alpha-beta T cell': 31,
  'mast cell': 32,
  'alpha-beta T cell': 33,
  'effector memory CD4-positive, alpha-beta T cell, terminally differentiated': 34,
  'T follicular helper cell': 35,
  'naive B cell': 36,
  'CD8-positive, CD25-positive, alpha-beta regulatory T cell': 37,
  'plasmacytoid dendritic cell': 38,
  'regulatory T cell': 39}]

Let us use a weighted sampler:

# label_key for weight doesn't have to be in labels on init
sampler = WeightedRandomSampler(
    weights=ds_mapped.get_label_weights("cell_type"), num_samples=len(ds_mapped)
)
dl = DataLoader(ds_mapped, batch_size=128, sampler=sampler)
for batch in dl:
    pass
# clean up test instance
!lamin delete --force test-scrna
!rm -r ./test-scrna
💡 deleting instance testuser1/test-scrna
✅     deleted instance settings file: /home/runner/.lamin/instance--testuser1--test-scrna.env
✅     instance cache deleted
✅     deleted '.lndb' sqlite file
❗     consider manually deleting your stored data: /home/runner/work/lamin-usecases/lamin-usecases/docs/test-scrna