scvi-tools 深度学习单细胞分析:功能详解

scvi-tools 深度学习单细胞分析:功能详解

Written By
技能练习生
技能练习生

深入掌握 scvi-tools 的核心功能

本章系统讲解 scvi-tools 各模块的详细用法,帮助你建立完整的知识体系,应对复杂的研究需求。

3.1 模型选择指南

3.1.1 scRNA-seq 数据整合

scVI:无监督整合的瑞士军刀

适用场景

  • 整合 3+ 个批次的数据
  • 批次效应较强(如跨平台、跨实验室)
  • 无需细胞类型标签
  • 需要进行差异表达分析

核心特性

import scvi

# 基础用法
scvi.model.SCVI.setup_anndata(adata, layer='counts', batch_key='batch')
model = scvi.model.SCVI(adata)

# 关键参数
model = scvi.model.SCVI(
    adata,
    latent_dim=30,        # 潜在维度:20-50(默认 10)
    n_layers=2,           # 神经网络层数:1-3
    n_hidden=128,         # 隐藏层大小:64-256
    gene_likelihood='nb',  # 分布假设:'nb'(负二项)或 'zinb'
    dispersion='gene-batch',  # 离散度参数:'gene', 'gene-batch'
)

超参数调优策略

参数增大效果减小效果推荐值
latent_dim保留更多信息,过拟合风险↑丢失细微差异20-40
n_layers模型更复杂,训练时间↑欠拟合风险2-3
n_hidden学习复杂模式无法捕获非线性关系128

训练技巧

# 渐进式训练(快速收敛)
model.train(max_epochs=50)  # 初始训练
model.train(max_epochs=100, plan_kwargs={'lr': 1e-3})  # 降低学习率继续训练

# 早停策略
model.train(
    max_epochs=400,
    early_stopping=True,
    early_stopping_patience=15,  # 15 轮无改善则停止
    early_stopping_min_delta=0.001
)

scANVI:半监督标签转移

适用场景

  • 有已注释的参考数据集
  • 需要将标签转移到新数据
  • 只有部分细胞有标签(半监督场景)

核心特性

# 准备数据:参考数据有标签,查询数据标签设为 'Unknown'
adata.obs['cell_type'] = adata.obs['cell_type'].astype(str)
adata.obs.loc[adata.obs['batch'] == 'query', 'cell_type'] = 'Unknown'

# 设置模型
scvi.model.SCANVI.setup_anndata(
    adata,
    layer='counts',
    batch_key='batch',
    labels_key='cell_type',  # 标签列
    unlabeled_category='Unknown'  # 未标记类别
)

# 从 scVI 模型初始化(加速收敛)
scvi_model = scvi.model.SCVI(adata)
scvi_model.train(max_epochs=100)
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    scvi_model,
    unlabeled_category='Unknown'
)
scanvi_model.train(max_epochs=50)

# 获取预测标签和概率
adata.obs['predicted_cell_type'] = scanvi_model.predict()
adata.obs['prediction_confidence'] = scanvi_model.predict(soft=True).max(axis=1)

评估标签转移质量

# 如果有部分真实标签,可计算准确率
from sklearn.metrics import accuracy_score, classification_report

mask = adata.obs['cell_type'] != 'Unknown'
true_labels = adata.obs.loc[mask, 'cell_type']
pred_labels = adata.obs.loc[mask, 'predicted_cell_type']

print(f"准确率:{accuracy_score(true_labels, pred_labels):.2%}")
print(classification_report(true_labels, pred_labels))

3.1.2 多模态数据分析

totalVI:CITE-seq(RNA + 蛋白质)

适用场景

  • CITE-seq 数据(TotalSeq 抗体标记)
  • REAP-seq 数据
  • 需要联合分析 RNA 和蛋白质

数据准备

# AnnData 结构要求:
# adata.X -> RNA 计数
# adata.obsm['protein_counts'] -> 蛋白质计数 (n_cells × n_proteins)

import scvi

# 设置模型
scvi.model.TOTALVI.setup_anndata(
    adata,
    batch_key='batch',
    protein_expression_obsm_key='protein_counts'  # 蛋白数据位置
)

# 训练
model = scvi.model.TOTALVI(adata)
model.train(max_epochs=400)

# 获取去噪的蛋白质表达(纠正技术噪声)
protein_expression = model.get_normalized_expression(
    adata,
    transform='log',  # 对数变换
    get_protein_background=True,
    include_n_background_proteins=10
)

adata.obsm['protein_denoised'] = protein_expression[1]

应用案例

# 比较原始与去噪后的蛋白质表达
import matplotlib.pyplot as plt

protein_idx = 0  # 选择某个蛋白
plt.scatter(
    adata.obsm['protein_counts'][:, protein_idx],
    adata.obsm['protein_denoised'][:, protein_idx],
    alpha=0.5
)
plt.xlabel('Raw counts')
plt.ylabel('Denoised expression')
plt.title('TotalVI denoising')
plt.show()

MultiVI:Multiome(RNA + ATAC)

适用场景

  • 10X Multiome 数据
  • 联合分析基因表达和染色质可及性
  • 推断基因调控关系

数据准备

# AnnData 结构要求:
# adata.X -> RNA 计数
# adata.obsm['mode2'] -> ATAC 峰计数 (n_cells × n_peaks)

scvi.model.MULTIVI.setup_anndata(
    adata,
    batch_key='batch',
    modalities_key='modality'  # 标识每个细胞是哪种模态
)

# 训练
model = scvi.model.MULTIVI(
    adata,
    n_genes=adata.shape[1],
    n_regions=adata.obsm['mode2'].shape[1]
)
model.train(max_epochs=400)

# 获取联合潜在表示
adata.obsm["X_multivi"] = model.get_latent_representation()

# 分析模态特异性
modality_specificity = model.get_modality_specificity()
# 值接近 0:RNA 和 ATAC 一致
# 值接近 1:模态特异性

3.1.3 空间转录组与参考映射

DestVI:空间反卷积

适用场景

  • 空间转录组数据(Visium、Slide-seq)
  • 需要解析每个 spot 的细胞类型组成
  • 有匹配的 scRNA-seq 参考数据

工作流程

# 1. 准备参考数据(scRNA-seq)
scvi.model.SCVI.setup_anndata(adata_ref, layer='counts', batch_key='batch')
ref_model = scvi.model.SCVI(adata_ref)
ref_model.train()

# 2. 准备空间数据
scvi.model.DESTVI.setup_anndata(
    adata_spatial,
    layer='counts',
    batch_key='batch'
)

# 3. 创建 DestVI 模型
spatial_model = scvi.model.DESTVI.from_scvi_model(ref_model, adata_spatial)
spatial_model.train(max_epochs=200)

# 4. 获取细胞类型比例
cell_type_abundance = spatial_model.get_proportions()
# shape: (n_spots, n_cell_types)

adata_spatial.obs['cell_type_composition'] = [
    {ct: prop for ct, prop in zip(cell_types, row)}
    for row in cell_type_abundance
]

scArches:查询数据映射

适用场景

  • 有大规模预训练参考模型
  • 需要将新数据快速映射到参考空间
  • 无需重新训练整个模型

优势

  • 训练速度提升 10-100 倍
  • 内存占用减少 90%+
  • 适合频繁更新参考图谱
# 1. 保存预训练的参考模型
ref_model = scvi.model.SCVI(adata_ref)
ref_model.train()
ref_model.save('reference_model/', overwrite=True)

# 2. 加载新数据并映射
adata_query = sc.read_h5ad('new_data.h5ad')

# 使用 scArches 方式训练(冻结部分网络)
query_model = scvi.model.SCVI.load_query_data(
    'reference_model/',
    adata_query,
    freeze_dropout=True,  # 冻结 dropout
    freeze_expression=True,  # 冻结编码器部分层
)

query_model.train(max_epochs=50)  # 快速训练

# 3. 获取映射后的表示
adata_query.obsm['X_mapped'] = query_model.get_latent_representation()

3.2 差异表达分析

3.2.1 scVI 的 DE 分析

为什么优于传统方法

  • 基于零膨胀负二项分布,更符合单细胞数据特性
  • 自动校正批次效应
  • 提供贝叶斯因子,比 p 值更稳健
# 簇间差异表达
de_results = model.differential_expression(
    groupby='leiden',
    group1='0',  # 目标簇
    group2='1',  # 对照簇
    mode='vanilla',  # 'vanilla': 简单比较,'change': 最小检测限
    batch_correction=True  # 批次校正
)

# 结果列:
# - lfc_mean: 对数折叠变化(后验均值)
# - lfc_median: 对数折叠变化(后验中位数)
# - prob_de: 差异表达概率
# - bayes_factor: 贝叶斯因子(>3 为强证据)

# 筛选显著基因
significant_genes = de_results[
    (de_results['bayes_factor'] > 3) &
    (abs(de_results['lfc_mean']) > 0.5)
]
print(significant_genes.head(10))

3.2.2 多组比较

# 找到每个簇的标志基因
from scvi.tools import differential_expression

cluster_markers = {}
for cluster in adata.obs['leiden'].unique():
    de_results = model.differential_expression(
        groupby='leiden',
        group1=cluster,
        group2='rest'  # 与其他所有簇比较
    )
    cluster_markers[cluster] = de_results[
        de_results['bayes_factor'] > 3
    ].head(20)

3.3 高级训练技巧

3.3.1 GPU 加速配置

# 检查 GPU 可用性
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {torch.cuda.device_count()}")

# 指定 GPU 设备
model.train(use_gpu=True, gpus=[0])  # 使用第一块 GPU

# 多 GPU 训练(大数据集)
model.train(use_gpu=True, gpus=[0, 1], accelerator='ddp')

3.3.2 训练监控

# 自定义回调
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint(
    dirpath='checkpoints/',
    monitor='elbo_validation',
    save_top_k=1,
    mode='min'
)

model.train(
    max_epochs=400,
    callbacks=[checkpoint],
    check_val_every_n_epoch=5  # 每 5 轮验证一次
)

# 重新加载最佳模型
model = scvi.model.SCVI.load('checkpoints/last.ckpt', adata)

3.3.3 处理超大数据集

# 数据流式加载(内存不足时)
model.train(
    max_epochs=400,
    data_loader_kwargs={
        'batch_size': 1024,  # 增大批大小
        'num_workers': 4  # 多进程加载
    }
)

# 分层训练(先训练子集)
adata_subset = adata[:10000, :]  # 训练 1 万细胞
model.train(max_epochs=200)
# 再用全数据微调
model.train(max_epochs=100)

3.4 模型解释与可视化

3.4.1 潜在空间探索

# 获取不确定性估计
latent, uncertainty = model.get_latent_representation(
    adata,
    give_mean=False,  # 返回采样而非均值
    return_dist=True  # 返回均值和标准差
)

adata.obsm['X_scvi_uncertainty'] = uncertainty
sc.pl.umap(adata, color='X_scvi_uncertainty', cmap='Reds')

3.4.2 特征重要性分析

# 获取每个细胞类型的标志基因重要性
importance = model.get_feature_correlation(
    adata,
    groupby='leiden',
    group1='0'
)

# 可视化 top 基因
import seaborn as sns
top_genes = importance.nlargest(20, importance.columns[0])
sns.barplot(x=top_genes.iloc[:, 0], y=top_genes.index)
plt.title('Top contributing genes for cluster 0')
plt.show()

3.5 性能优化清单

优化方向具体措施预期提升
训练速度使用 GPU10-50x
减少高变基因数量2-5x
增大批大小1.5-3x
内存占用启用数据流式加载减少 50-70%
降低 latent_dim减少 20-30%
模型质量增加训练轮数提升收敛性
使用早停策略避免过拟合
调整学习率更稳定的收敛

3.6 常用命令速查

# ============ 模型设置 ============
scvi.model.SCVI.setup_anndata(adata, layer='counts', batch_key='batch')
scvi.model.SCANVI.setup_anndata(adata, layer='counts', batch_key='batch', labels_key='cell_type', unlabeled_category='Unknown')
scvi.model.TOTALVI.setup_anndata(adata, batch_key='batch', protein_expression_obsm_key='protein_counts')

# ============ 模型训练 ============
model.train(max_epochs=400, batch_size=128, use_gpu=True, early_stopping=True)

# ============ 获取表示 =-----------
adata.obsm["X_scvi"] = model.get_latent_representation()
denoised = model.get_normalized_expression(adata, transform='log')

# ============ 差异表达 ============
de_results = model.differential_expression(groupby='leiden', group1='0', group2='1')

# ============ 模型保存 ============
model.save('model_dir/', overwrite=True)
loaded_model = scvi.model.SCVI.load('model_dir/', adata)

下一章

现在你已经掌握了 scvi-tools 的核心功能,让我们通过真实的科研案例看看这些工具如何解决实际问题。

→ 继续阅读:第四章 - 应用案例