
scvi-tools 深度学习单细胞分析:功能详解
Written By

技能练习生
深入掌握 scvi-tools 的核心功能
本章系统讲解 scvi-tools 各模块的详细用法,帮助你建立完整的知识体系,应对复杂的研究需求。
3.1 模型选择指南
3.1.1 scRNA-seq 数据整合
scVI:无监督整合的瑞士军刀
适用场景:
- 整合 3+ 个批次的数据
- 批次效应较强(如跨平台、跨实验室)
- 无需细胞类型标签
- 需要进行差异表达分析
核心特性:
import scvi
# 基础用法
scvi.model.SCVI.setup_anndata(adata, layer='counts', batch_key='batch')
model = scvi.model.SCVI(adata)
# 关键参数
model = scvi.model.SCVI(
adata,
latent_dim=30, # 潜在维度:20-50(默认 10)
n_layers=2, # 神经网络层数:1-3
n_hidden=128, # 隐藏层大小:64-256
gene_likelihood='nb', # 分布假设:'nb'(负二项)或 'zinb'
dispersion='gene-batch', # 离散度参数:'gene', 'gene-batch'
)超参数调优策略:
| 参数 | 增大效果 | 减小效果 | 推荐值 |
|---|---|---|---|
latent_dim | 保留更多信息,过拟合风险↑ | 丢失细微差异 | 20-40 |
n_layers | 模型更复杂,训练时间↑ | 欠拟合风险 | 2-3 |
n_hidden | 学习复杂模式 | 无法捕获非线性关系 | 128 |
训练技巧:
# 渐进式训练(快速收敛)
model.train(max_epochs=50) # 初始训练
model.train(max_epochs=100, plan_kwargs={'lr': 1e-3}) # 降低学习率继续训练
# 早停策略
model.train(
max_epochs=400,
early_stopping=True,
early_stopping_patience=15, # 15 轮无改善则停止
early_stopping_min_delta=0.001
)scANVI:半监督标签转移
适用场景:
- 有已注释的参考数据集
- 需要将标签转移到新数据
- 只有部分细胞有标签(半监督场景)
核心特性:
# 准备数据:参考数据有标签,查询数据标签设为 'Unknown'
adata.obs['cell_type'] = adata.obs['cell_type'].astype(str)
adata.obs.loc[adata.obs['batch'] == 'query', 'cell_type'] = 'Unknown'
# 设置模型
scvi.model.SCANVI.setup_anndata(
adata,
layer='counts',
batch_key='batch',
labels_key='cell_type', # 标签列
unlabeled_category='Unknown' # 未标记类别
)
# 从 scVI 模型初始化(加速收敛)
scvi_model = scvi.model.SCVI(adata)
scvi_model.train(max_epochs=100)
scanvi_model = scvi.model.SCANVI.from_scvi_model(
scvi_model,
unlabeled_category='Unknown'
)
scanvi_model.train(max_epochs=50)
# 获取预测标签和概率
adata.obs['predicted_cell_type'] = scanvi_model.predict()
adata.obs['prediction_confidence'] = scanvi_model.predict(soft=True).max(axis=1)评估标签转移质量:
# 如果有部分真实标签,可计算准确率
from sklearn.metrics import accuracy_score, classification_report
mask = adata.obs['cell_type'] != 'Unknown'
true_labels = adata.obs.loc[mask, 'cell_type']
pred_labels = adata.obs.loc[mask, 'predicted_cell_type']
print(f"准确率:{accuracy_score(true_labels, pred_labels):.2%}")
print(classification_report(true_labels, pred_labels))3.1.2 多模态数据分析
totalVI:CITE-seq(RNA + 蛋白质)
适用场景:
- CITE-seq 数据(TotalSeq 抗体标记)
- REAP-seq 数据
- 需要联合分析 RNA 和蛋白质
数据准备:
# AnnData 结构要求:
# adata.X -> RNA 计数
# adata.obsm['protein_counts'] -> 蛋白质计数 (n_cells × n_proteins)
import scvi
# 设置模型
scvi.model.TOTALVI.setup_anndata(
adata,
batch_key='batch',
protein_expression_obsm_key='protein_counts' # 蛋白数据位置
)
# 训练
model = scvi.model.TOTALVI(adata)
model.train(max_epochs=400)
# 获取去噪的蛋白质表达(纠正技术噪声)
protein_expression = model.get_normalized_expression(
adata,
transform='log', # 对数变换
get_protein_background=True,
include_n_background_proteins=10
)
adata.obsm['protein_denoised'] = protein_expression[1]应用案例:
# 比较原始与去噪后的蛋白质表达
import matplotlib.pyplot as plt
protein_idx = 0 # 选择某个蛋白
plt.scatter(
adata.obsm['protein_counts'][:, protein_idx],
adata.obsm['protein_denoised'][:, protein_idx],
alpha=0.5
)
plt.xlabel('Raw counts')
plt.ylabel('Denoised expression')
plt.title('TotalVI denoising')
plt.show()MultiVI:Multiome(RNA + ATAC)
适用场景:
- 10X Multiome 数据
- 联合分析基因表达和染色质可及性
- 推断基因调控关系
数据准备:
# AnnData 结构要求:
# adata.X -> RNA 计数
# adata.obsm['mode2'] -> ATAC 峰计数 (n_cells × n_peaks)
scvi.model.MULTIVI.setup_anndata(
adata,
batch_key='batch',
modalities_key='modality' # 标识每个细胞是哪种模态
)
# 训练
model = scvi.model.MULTIVI(
adata,
n_genes=adata.shape[1],
n_regions=adata.obsm['mode2'].shape[1]
)
model.train(max_epochs=400)
# 获取联合潜在表示
adata.obsm["X_multivi"] = model.get_latent_representation()
# 分析模态特异性
modality_specificity = model.get_modality_specificity()
# 值接近 0:RNA 和 ATAC 一致
# 值接近 1:模态特异性3.1.3 空间转录组与参考映射
DestVI:空间反卷积
适用场景:
- 空间转录组数据(Visium、Slide-seq)
- 需要解析每个 spot 的细胞类型组成
- 有匹配的 scRNA-seq 参考数据
工作流程:
# 1. 准备参考数据(scRNA-seq)
scvi.model.SCVI.setup_anndata(adata_ref, layer='counts', batch_key='batch')
ref_model = scvi.model.SCVI(adata_ref)
ref_model.train()
# 2. 准备空间数据
scvi.model.DESTVI.setup_anndata(
adata_spatial,
layer='counts',
batch_key='batch'
)
# 3. 创建 DestVI 模型
spatial_model = scvi.model.DESTVI.from_scvi_model(ref_model, adata_spatial)
spatial_model.train(max_epochs=200)
# 4. 获取细胞类型比例
cell_type_abundance = spatial_model.get_proportions()
# shape: (n_spots, n_cell_types)
adata_spatial.obs['cell_type_composition'] = [
{ct: prop for ct, prop in zip(cell_types, row)}
for row in cell_type_abundance
]scArches:查询数据映射
适用场景:
- 有大规模预训练参考模型
- 需要将新数据快速映射到参考空间
- 无需重新训练整个模型
优势:
- 训练速度提升 10-100 倍
- 内存占用减少 90%+
- 适合频繁更新参考图谱
# 1. 保存预训练的参考模型
ref_model = scvi.model.SCVI(adata_ref)
ref_model.train()
ref_model.save('reference_model/', overwrite=True)
# 2. 加载新数据并映射
adata_query = sc.read_h5ad('new_data.h5ad')
# 使用 scArches 方式训练(冻结部分网络)
query_model = scvi.model.SCVI.load_query_data(
'reference_model/',
adata_query,
freeze_dropout=True, # 冻结 dropout
freeze_expression=True, # 冻结编码器部分层
)
query_model.train(max_epochs=50) # 快速训练
# 3. 获取映射后的表示
adata_query.obsm['X_mapped'] = query_model.get_latent_representation()3.2 差异表达分析
3.2.1 scVI 的 DE 分析
为什么优于传统方法:
- 基于零膨胀负二项分布,更符合单细胞数据特性
- 自动校正批次效应
- 提供贝叶斯因子,比 p 值更稳健
# 簇间差异表达
de_results = model.differential_expression(
groupby='leiden',
group1='0', # 目标簇
group2='1', # 对照簇
mode='vanilla', # 'vanilla': 简单比较,'change': 最小检测限
batch_correction=True # 批次校正
)
# 结果列:
# - lfc_mean: 对数折叠变化(后验均值)
# - lfc_median: 对数折叠变化(后验中位数)
# - prob_de: 差异表达概率
# - bayes_factor: 贝叶斯因子(>3 为强证据)
# 筛选显著基因
significant_genes = de_results[
(de_results['bayes_factor'] > 3) &
(abs(de_results['lfc_mean']) > 0.5)
]
print(significant_genes.head(10))3.2.2 多组比较
# 找到每个簇的标志基因
from scvi.tools import differential_expression
cluster_markers = {}
for cluster in adata.obs['leiden'].unique():
de_results = model.differential_expression(
groupby='leiden',
group1=cluster,
group2='rest' # 与其他所有簇比较
)
cluster_markers[cluster] = de_results[
de_results['bayes_factor'] > 3
].head(20)3.3 高级训练技巧
3.3.1 GPU 加速配置
# 检查 GPU 可用性
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {torch.cuda.device_count()}")
# 指定 GPU 设备
model.train(use_gpu=True, gpus=[0]) # 使用第一块 GPU
# 多 GPU 训练(大数据集)
model.train(use_gpu=True, gpus=[0, 1], accelerator='ddp')3.3.2 训练监控
# 自定义回调
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(
dirpath='checkpoints/',
monitor='elbo_validation',
save_top_k=1,
mode='min'
)
model.train(
max_epochs=400,
callbacks=[checkpoint],
check_val_every_n_epoch=5 # 每 5 轮验证一次
)
# 重新加载最佳模型
model = scvi.model.SCVI.load('checkpoints/last.ckpt', adata)3.3.3 处理超大数据集
# 数据流式加载(内存不足时)
model.train(
max_epochs=400,
data_loader_kwargs={
'batch_size': 1024, # 增大批大小
'num_workers': 4 # 多进程加载
}
)
# 分层训练(先训练子集)
adata_subset = adata[:10000, :] # 训练 1 万细胞
model.train(max_epochs=200)
# 再用全数据微调
model.train(max_epochs=100)3.4 模型解释与可视化
3.4.1 潜在空间探索
# 获取不确定性估计
latent, uncertainty = model.get_latent_representation(
adata,
give_mean=False, # 返回采样而非均值
return_dist=True # 返回均值和标准差
)
adata.obsm['X_scvi_uncertainty'] = uncertainty
sc.pl.umap(adata, color='X_scvi_uncertainty', cmap='Reds')3.4.2 特征重要性分析
# 获取每个细胞类型的标志基因重要性
importance = model.get_feature_correlation(
adata,
groupby='leiden',
group1='0'
)
# 可视化 top 基因
import seaborn as sns
top_genes = importance.nlargest(20, importance.columns[0])
sns.barplot(x=top_genes.iloc[:, 0], y=top_genes.index)
plt.title('Top contributing genes for cluster 0')
plt.show()3.5 性能优化清单
| 优化方向 | 具体措施 | 预期提升 |
|---|---|---|
| 训练速度 | 使用 GPU | 10-50x |
| 减少高变基因数量 | 2-5x | |
| 增大批大小 | 1.5-3x | |
| 内存占用 | 启用数据流式加载 | 减少 50-70% |
| 降低 latent_dim | 减少 20-30% | |
| 模型质量 | 增加训练轮数 | 提升收敛性 |
| 使用早停策略 | 避免过拟合 | |
| 调整学习率 | 更稳定的收敛 |
3.6 常用命令速查
# ============ 模型设置 ============
scvi.model.SCVI.setup_anndata(adata, layer='counts', batch_key='batch')
scvi.model.SCANVI.setup_anndata(adata, layer='counts', batch_key='batch', labels_key='cell_type', unlabeled_category='Unknown')
scvi.model.TOTALVI.setup_anndata(adata, batch_key='batch', protein_expression_obsm_key='protein_counts')
# ============ 模型训练 ============
model.train(max_epochs=400, batch_size=128, use_gpu=True, early_stopping=True)
# ============ 获取表示 =-----------
adata.obsm["X_scvi"] = model.get_latent_representation()
denoised = model.get_normalized_expression(adata, transform='log')
# ============ 差异表达 ============
de_results = model.differential_expression(groupby='leiden', group1='0', group2='1')
# ============ 模型保存 ============
model.save('model_dir/', overwrite=True)
loaded_model = scvi.model.SCVI.load('model_dir/', adata)下一章
现在你已经掌握了 scvi-tools 的核心功能,让我们通过真实的科研案例看看这些工具如何解决实际问题。
→ 继续阅读:第四章 - 应用案例