{
"cells": [
{
"cell_type": "markdown",
"id": "aa1e816b",
"metadata": {},
"source": [
"# Spleen Dataset"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "336c1897",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tutorial environment initialized\n"
]
}
],
"source": [
"# Environment bootstrap for the tutorial: quiet warnings, expose the local\n",
"# pipeline code on sys.path, pin thread counts and request deterministic torch ops.\n",
"import os\n",
"import sys\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"# Presumably the directory that contains notebook_pipeline; resolved relative to\n",
"# the current working directory, so launch the notebook from the repository root.\n",
"pipeline_dir = os.path.join(os.getcwd(), 'DREAM_stage1')\n",
"sys.path.insert(0, pipeline_dir)\n",
"\n",
"# Kept ahead of the torch import, exactly as in the original ordering.\n",
"os.environ.update({\n",
"    \"OMP_NUM_THREADS\": \"1\",\n",
"    \"MKL_NUM_THREADS\": \"1\",\n",
"    \"CUBLAS_WORKSPACE_CONFIG\": \":4096:8\",\n",
"})\n",
"\n",
"import torch\n",
"\n",
"torch.set_num_threads(1)\n",
"torch.use_deterministic_algorithms(True)\n",
"\n",
"from notebook_pipeline import (\n",
"    SpleenTutorialConfig,\n",
"    attach_pseudo_labels,\n",
"    build_model_and_optimizer,\n",
"    cluster_and_report,\n",
"    prepare_adata,\n",
"    train_embedding,\n",
")\n",
"\n",
"print(\"Tutorial environment initialized\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f3610728",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SpleenTutorialConfig(data_dir='/home/zhangdaijun/Data/scNiche_data/', batches=['BALBc-1', 'BALBc-2', 'BALBc-3'], dnn_model='/home/zhangdaijun/Code/spatialID-main/result/Spleen/best_DNN_model_all.pth', seed=2025, n_layers=4, agg_method='Mean', prune_long_links=False, model_name='Muti', gae_dim=[128, 64], dae_dim=[128, 64], feat_dim=64, include_cat_covariates_contrastive_loss=False, epochs=1000, optimizer='Adam', use_dnn=True, lr=0.001, attr_loss_weight=1.0, bottleneck=False, n_attributes=1, edge_weight=True, kd_T=1, w_dae=1.0, w_gae=1.0, n_cluster=4, batch_size=4096, weight_decay=0.0001, scheduler_step=20, device=device(type='cuda'))\n"
]
}
],
"source": [
"# Dataset locations. Set DREAM_DATA_DIR / DREAM_DNN_MODEL in the environment to\n",
"# point at your own copies instead of editing this cell. The fallbacks are the\n",
"# original placeholder paths and must be replaced with real locations.\n",
"DATA_DIR = os.environ.get(\"DREAM_DATA_DIR\", '/home//Data/')\n",
"DNN_MODEL_PATH = os.environ.get(\"DREAM_DNN_MODEL\", '/home//Result/Spleen/DNN_model.pth')\n",
"\n",
"# Only the fields below are overridden; every other hyper-parameter keeps the\n",
"# default declared on SpleenTutorialConfig (the printed repr lists them all).\n",
"cfg = SpleenTutorialConfig(\n",
"    data_dir=DATA_DIR,\n",
"    dnn_model=DNN_MODEL_PATH,\n",
"    batches=['BALBc-1', 'BALBc-2', 'BALBc-3'],\n",
"    epochs=1000,\n",
"    n_cluster=4,\n",
")\n",
"\n",
"print(cfg)"
]
},
{
"cell_type": "markdown",
"id": "9a15194e",
"metadata": {},
"source": [
"## Step 1: Load and preprocess the multi-batch data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "15c5b5fc",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5/5 [00:00<00:00, 12.91it/s]\n",
"100%|██████████| 5/5 [00:00<00:00, 13.90it/s]\n",
"100%|██████████| 5/5 [00:00<00:00, 13.95it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"adata shape: (244233, 30)\n",
"batches: ['BALBc-1', 'BALBc-2', 'BALBc-3']\n"
]
}
],
"source": [
"# Build the dataset object from the configuration, then report its size and\n",
"# the batch categories recorded in adata.obs['batch'].\n",
"adata = prepare_adata(cfg)\n",
"\n",
"print(f\"adata shape: {adata.shape}\")\n",
"loaded_batches = adata.obs['batch'].cat.categories.tolist()\n",
"print(f\"batches: {loaded_batches}\")"
]
},
{
"cell_type": "markdown",
"id": "6f647570",
"metadata": {},
"source": [
"## Step 2: Generate pseudo-labels with the pretrained DNN"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3317656c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
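"<div>\n",
"<style scoped>\n",
"    .dataframe tbody tr th:only-of-type {\n",
"        vertical-align: middle;\n",
"    }\n",
"\n",
"    .dataframe tbody tr th {\n",
"        vertical-align: top;\n",
"    }\n",
"\n",
"    .dataframe thead th {\n",
"        text-align: right;\n",
"    }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>batch</th>\n",
"      <th>pseudo_class</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>BALBc_1_Cell1</th>\n",
"      <td>BALBc-1</td>\n",
"      <td>CD106(+)CD16/32(+)CD31(-)Ly6C(-) stroma</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>BALBc_1_Cell2</th>\n",
"      <td>BALBc-1</td>\n",
"      <td>ERTR7(+) stroma</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>BALBc_1_Cell3</th>\n",
"      <td>BALBc-1</td>\n",
"      <td>ERTR7(+) stroma</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>BALBc_1_Cell4</th>\n",
"      <td>BALBc-1</td>\n",
"      <td>CD106(+)CD16/32(+)CD31(-)Ly6C(-) stroma</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>BALBc_1_Cell5</th>\n",
"      <td>BALBc-1</td>\n",
"      <td>ERTR7(+) stroma</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"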
],
"text/plain": [
" batch pseudo_class\n",
"BALBc_1_Cell1 BALBc-1 CD106(+)CD16/32(+)CD31(-)Ly6C(-) stroma\n",
"BALBc_1_Cell2 BALBc-1 ERTR7(+) stroma\n",
"BALBc_1_Cell3 BALBc-1 ERTR7(+) stroma\n",
"BALBc_1_Cell4 BALBc-1 CD106(+)CD16/32(+)CD31(-)Ly6C(-) stroma\n",
"BALBc_1_Cell5 BALBc-1 ERTR7(+) stroma"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The return value is unused; the 'pseudo_class' column previewed below is\n",
"# presumably written onto adata.obs by this call.\n",
"attach_pseudo_labels(adata, cfg)\n",
"\n",
"preview_columns = ['batch', 'pseudo_class']\n",
"adata.obs[preview_columns].head()"
]
},
{
"cell_type": "markdown",
"id": "7a58a540",
"metadata": {},
"source": [
"## Step 3: Build model and optimizer"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "56fde8d3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cat_covariates_keys: ['batch']\n",
"CATEGORICAL COVARIATES EMBEDDINGS INJECTION -> ['decoder']\n",
"self.include_cat_covariates_contrastive_loss False\n",
"['decoder']\n",
"Decoder embedding effective!\n",
"SpatialModel_cov\n"
]
}
],
"source": [
"# Construct the network together with its optimizer and LR scheduler.\n",
"# NOTE(review): `scheduler` is not referenced by any later cell in this\n",
"# notebook; confirm whether train_embedding is expected to receive it.\n",
"model, optimizer, scheduler = build_model_and_optimizer(adata, cfg)\n",
"\n",
"print(type(model).__name__)\n"
]
},
{
"cell_type": "markdown",
"id": "d138504e",
"metadata": {},
"source": [
"## Step 4: Train and write back latent representations"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ba3c520a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training Epoch: 100%|██████████| 1000/1000 [15:01<00:00, 1.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training epochs: 1000\n",
"Final total loss: 148503168.0000\n"
]
}
],
"source": [
"# Run training; `history['losses']` holds one entry per recorded epoch.\n",
"history = train_embedding(model, optimizer, adata, cfg)\n",
"\n",
"loss_curve = history['losses']\n",
"print(f\"Training epochs: {len(loss_curve)}\")\n",
"print(f\"Final total loss: {loss_curve[-1]:.4f}\")"
]
},
{
"cell_type": "markdown",
"id": "35d405d5",
"metadata": {},
"source": [
"## Step 5: GMM clustering and ARI/NMI summary per batch"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0efe211d",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
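"<div>\n",
"<style scoped>\n",
"    .dataframe tbody tr th:only-of-type {\n",
"        vertical-align: middle;\n",
"    }\n",
"\n",
"    .dataframe tbody tr th {\n",
"        vertical-align: top;\n",
"    }\n",
"\n",
"    .dataframe thead th {\n",
"        text-align: right;\n",
"    }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>batch</th>\n",
"      <th>ari</th>\n",
"      <th>nmi</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>BALBc-1</td>\n",
"      <td>0.619094</td>\n",
"      <td>0.592073</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>BALBc-2</td>\n",
"      <td>0.657278</td>\n",
"      <td>0.618800</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>2</th>\n",
"      <td>BALBc-3</td>\n",
"      <td>0.564992</td>\n",
"      <td>0.564693</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"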
],
"text/plain": [
" batch ari nmi\n",
"0 BALBc-1 0.619094 0.592073\n",
"1 BALBc-2 0.657278 0.618800\n",
"2 BALBc-3 0.564992 0.564693"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Clustering / scoring options: predictions are stored under \"GM\" and scored\n",
"# against the \"Compartment\" annotation, using the batches, cluster count and\n",
"# seed taken from the config.\n",
"report_options = dict(\n",
"    batches=cfg.batches,\n",
"    n_cluster=cfg.n_cluster,\n",
"    seed=cfg.seed,\n",
"    pred_key=\"GM\",\n",
"    truth_key=\"Compartment\",\n",
")\n",
"\n",
"metrics_df = cluster_and_report(adata, **report_options)\n",
"metrics_df\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "DREAM_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}