Tutorial 3 Querying niches across different technologies
In this tutorial, we run niche query experiments on a Mouse Olfactory Bulb Tissue (MOBT) dataset.
[ ]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import logging
import warnings
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
[ ]:
import quest.utils as utils
from quest.trainer import QueSTTrainer
[ ]:
warnings.filterwarnings("ignore")
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.WARNING)
Read and explore the data
[ ]:
dataset = "MouseOlfactoryBulbTissue"
data_path = "../data/MOBT"
model_path = "../results/MOBT/model/quest_model.pth"
sample_ids = ["Stereo-seq", "10X", "Slide-seq V2"]
adata_list = [sc.read_h5ad(f"{data_path}/{data_id}.h5ad") for data_id in sample_ids]
This dataset includes three samples generated with the 10X Visium, Stereo-seq, and Slide-seq V2 technologies, with spot radii of 50 μm, 35 μm, and 10 μm, respectively.
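To make the resolution gap concrete, the short sketch below turns the radii quoted above into relative spot areas. The dictionary `spot_radius_um` and the helper `relative_spot_area` are illustrative names, not part of QueST; the pairing of technology and radius follows the "respectively" ordering of the sentence above.

```python
import math

# Spot radii in micrometers, as quoted in the text above
# (illustrative dictionary, not a QueST object).
spot_radius_um = {"10X": 50.0, "Stereo-seq": 35.0, "Slide-seq V2": 10.0}

def relative_spot_area(radii, reference):
    """Return each technology's spot area divided by the reference technology's spot area."""
    ref_area = math.pi * radii[reference] ** 2
    return {name: math.pi * r ** 2 / ref_area for name, r in radii.items()}

areas = relative_spot_area(spot_radius_um, reference="Slide-seq V2")
for name, ratio in areas.items():
    print(f"{name}: {ratio:.2f}x the area of a Slide-seq V2 spot")
```

With these numbers, a 10X spot covers 25 times the area of a Slide-seq V2 spot, so the same tissue niche may span very different numbers of spots in each sample.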
[ ]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))
axes = axes.flatten()
for i, adata in enumerate(adata_list):
    spot_size = utils.get_spot_size(dataset=dataset, ref_id=adata.uns['library_id'])
    sc.pl.spatial(adata, color='layer', spot_size=spot_size, ax=axes[i], title=adata.uns['library_id'], show=False)
    axes[i].invert_yaxis()
fig.tight_layout()
fig.show()
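We use the interface between the granule cell layer (GCL) and the mitral cell layer (MCL) as the query niche. The query niche is shown on the Stereo-seq sample, and the region matching score is shown on the other two samples.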
[ ]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))
axes = axes.flatten()
for i, adata in enumerate(adata_list):
    spot_size = utils.get_spot_size(dataset=dataset, ref_id=adata.uns['library_id'])
    if adata.uns['library_id'] == "Stereo-seq":
        sc.pl.spatial(adata, color="GCL_MCL_30_layer", spot_size=spot_size, ax=axes[i],
                      title="Query niche GCL_MCL_30 on Stereo-seq sample", show=False)
    else:
        sc.pl.spatial(adata, color='GCL_MCL_30_region_matching_score', spot_size=spot_size, ax=axes[i],
                      title=f"Region Matching Score on {adata.uns['library_id']} sample", show=False)
    axes[i].invert_yaxis()
fig.tight_layout()
fig.show()
Configure the niche query task
[ ]:
query_sample_id = "Stereo-seq"
query_niches = ['GCL_MCL_30_niche']
Set up QueST trainer
[ ]:
trainer = QueSTTrainer(dataset=dataset, data_path=data_path, sample_ids=sample_ids, adata_list=adata_list,
                       query_niches=query_niches, query_sample_id=query_sample_id,
                       model_path=model_path,
                       epochs=20, save_model=True, hvg=None, min_count=0, normalize=False)
Train QueST model
We also provide pretrained QueST model checkpoint weights at https://cloud.tsinghua.edu.cn/d/ab0cc439245a4e1f99dd/. To skip training and use the pretrained checkpoint, simply place it at the corresponding model path and comment out the “trainer.train()” line.
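As an alternative to commenting the line out by hand, training can be made conditional on whether a checkpoint already exists. This is a minimal sketch, not part of the QueST API: `checkpoint_available` is a hypothetical helper, and it assumes the checkpoint is a single file stored at `model_path`.

```python
import os

def checkpoint_available(path):
    """Return True if a non-empty checkpoint file exists at `path` (hypothetical helper)."""
    return os.path.isfile(path) and os.path.getsize(path) > 0

# Intended use in this notebook (assumes `trainer` and `model_path` from the cells above):
# if not checkpoint_available(model_path):
#     trainer.train()
# trainer.inference(ckpt_path=model_path, save_embedding=False, query=True)
```

The size check is only a cheap guard against an empty or interrupted download; it does not verify that the file is a valid QueST checkpoint.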
[ ]:
trainer.train()
trainer.inference(ckpt_path=model_path, save_embedding=False, query=True)
computing 3-hop subgraph (Stereo-seq): 100%|██████████| 8762/8762 [00:05<00:00, 1686.03it/s]
computing 3-hop subgraph (10X): 100%|██████████| 1185/1185 [00:00<00:00, 1758.54it/s]
computing 3-hop subgraph (Slide-seq V2): 100%|██████████| 18173/18173 [00:10<00:00, 1740.68it/s]
training: 100%|██████████| 20/20 [04:04<00:00, 12.22s/epoch]
[ ]:
fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(10, 5))
axes = axes.flatten()
query_niche = query_niches[0]
adata_ref = trainer.adata_ref_list[0]
for i, adata_ref in enumerate(trainer.adata_ref_list):
    spot_size = utils.get_spot_size(dataset, adata_ref.uns['library_id'])
    sc.pl.spatial(adata_ref, color=f"{query_niche} predicted matching score", ax=axes[i], show=False,
                  spot_size=spot_size, title=f"{adata_ref.uns['library_id']}")
    axes[i].invert_yaxis()
fig.tight_layout()
fig.show()
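Beyond the visual check, the predicted scores can also be ranked to list the best-matching spots in each reference sample. The sketch below is a hedged illustration: `top_matching_spots` is a hypothetical helper, and it assumes the score is stored in `adata_ref.obs` under the same column name used for plotting above (`f"{query_niche} predicted matching score"`). A small pandas DataFrame stands in for `adata_ref.obs` so that the example runs on its own.

```python
import pandas as pd

def top_matching_spots(obs, score_col, k=5):
    """Return the k rows of `obs` with the highest values in `score_col`."""
    return obs.nlargest(k, score_col)

# Toy stand-in for adata_ref.obs; the values are made up for illustration only.
score_col = "GCL_MCL_30_niche predicted matching score"
toy_obs = pd.DataFrame(
    {score_col: [0.10, 0.85, 0.40, 0.92, 0.05]},
    index=["spot_0", "spot_1", "spot_2", "spot_3", "spot_4"],
)

top = top_matching_spots(toy_obs, score_col, k=2)
print(top)
```

In the notebook, the same call would be `top_matching_spots(adata_ref.obs, f"{query_niche} predicted matching score")` inside the loop over `trainer.adata_ref_list`, provided the column is present in `obs`.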