{ "cells": [ { "cell_type": "markdown", "id": "aa1e816b", "metadata": {}, "source": [ "# Spleen Dataset" ] }, { "cell_type": "code", "execution_count": 1, "id": "336c1897", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tutorial environment initialized\n" ] } ], "source": [ "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "import os\n", "import sys\n", "sys.path.insert(0, os.path.join(os.getcwd(), 'DREAM_stage1'))\n", "os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n", "os.environ[\"MKL_NUM_THREADS\"] = \"1\"\n", "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n", "\n", "import torch\n", "torch.set_num_threads(1)\n", "torch.use_deterministic_algorithms(True)\n", "\n", "from notebook_pipeline import (\n", " SpleenTutorialConfig,\n", " prepare_adata,\n", " attach_pseudo_labels,\n", " build_model_and_optimizer,\n", " train_embedding,\n", " cluster_and_report,\n", ")\n", "\n", "print(\"Tutorial environment initialized\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "f3610728", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SpleenTutorialConfig(data_dir='/home/USER/Data/scNiche_data/', batches=['BALBc-1', 'BALBc-2', 'BALBc-3'], dnn_model='/home/USER/Code/spatialID-main/result/Spleen/best_DNN_model_all.pth', seed=2025, n_layers=4, agg_method='Mean', prune_long_links=False, model_name='Muti', gae_dim=[128, 64], dae_dim=[128, 64], feat_dim=64, include_cat_covariates_contrastive_loss=False, epochs=1000, optimizer='Adam', use_dnn=True, lr=0.001, attr_loss_weight=1.0, bottleneck=False, n_attributes=1, edge_weight=True, kd_T=1, w_dae=1.0, w_gae=1.0, n_cluster=4, batch_size=4096, weight_decay=0.0001, scheduler_step=20, device=device(type='cuda'))\n" ] } ], "source": [ "cfg = SpleenTutorialConfig(\n", " data_dir='/home//Data/',\n", " dnn_model='/home//Result/Spleen/DNN_model.pth',\n", " batches=['BALBc-1', 'BALBc-2', 'BALBc-3'],\n", " epochs=1000,\n", "
n_cluster=4,\n", ")\n", "\n", "print(cfg)" ] }, { "cell_type": "markdown", "id": "9a15194e", "metadata": {}, "source": [ "## Step 1: Load and preprocess the multi-batch data" ] }, { "cell_type": "code", "execution_count": 3, "id": "15c5b5fc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 5/5 [00:00<00:00, 12.91it/s]\n", "100%|██████████| 5/5 [00:00<00:00, 13.90it/s]\n", "100%|██████████| 5/5 [00:00<00:00, 13.95it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "adata shape: (244233, 30)\n", "batches: ['BALBc-1', 'BALBc-2', 'BALBc-3']\n" ] } ], "source": [ "adata = prepare_adata(cfg)\n", "print(f\"adata shape: {adata.shape}\")\n", "print(f\"batches: {adata.obs['batch'].cat.categories.tolist()}\")" ] }, { "cell_type": "markdown", "id": "6f647570", "metadata": {}, "source": [ "## Step 2: Generate pseudo-labels with the pretrained DNN" ] }, { "cell_type": "code", "execution_count": 4, "id": "3317656c", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
batchpseudo_class
BALBc_1_Cell1BALBc-1CD106(+)CD16/32(+)CD31(-)Ly6C(-) stroma
BALBc_1_Cell2BALBc-1ERTR7(+) stroma
BALBc_1_Cell3BALBc-1ERTR7(+) stroma
BALBc_1_Cell4BALBc-1CD106(+)CD16/32(+)CD31(-)Ly6C(-) stroma
BALBc_1_Cell5BALBc-1ERTR7(+) stroma
\n", "
" ], "text/plain": [ " batch pseudo_class\n", "BALBc_1_Cell1 BALBc-1 CD106(+)CD16/32(+)CD31(-)Ly6C(-) stroma\n", "BALBc_1_Cell2 BALBc-1 ERTR7(+) stroma\n", "BALBc_1_Cell3 BALBc-1 ERTR7(+) stroma\n", "BALBc_1_Cell4 BALBc-1 CD106(+)CD16/32(+)CD31(-)Ly6C(-) stroma\n", "BALBc_1_Cell5 BALBc-1 ERTR7(+) stroma" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "attach_pseudo_labels(adata, cfg)\n", "adata.obs[['batch', 'pseudo_class']].head()" ] }, { "cell_type": "markdown", "id": "7a58a540", "metadata": {}, "source": [ "## Step 3: Build model and optimizer" ] }, { "cell_type": "code", "execution_count": 5, "id": "56fde8d3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cat_covariates_keys: ['batch']\n", "CATEGORICAL COVARIATES EMBEDDINGS INJECTION -> ['decoder']\n", "self.include_cat_covariates_contrastive_loss False\n", "['decoder']\n", "Decoder embedding effective!\n", "SpatialModel_cov\n" ] } ], "source": [ "model, optimizer, scheduler = build_model_and_optimizer(adata, cfg)\n", "print(model.__class__.__name__)\n" ] }, { "cell_type": "markdown", "id": "d138504e", "metadata": {}, "source": [ "## Step 4: Train and write back latent representations" ] }, { "cell_type": "code", "execution_count": 6, "id": "ba3c520a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Training Epoch: 100%|██████████| 1000/1000 [15:01<00:00, 1.11it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training epochs: 1000\n", "Final total loss: 148503168.0000\n" ] } ], "source": [ "history = train_embedding(model, optimizer, adata, cfg)\n", "print(f\"Training epochs: {len(history['losses'])}\")\n", "print(f\"Final total loss: {history['losses'][-1]:.4f}\")" ] }, { "cell_type": "markdown", "id": "35d405d5", "metadata": {}, "source": [ "## Step 5: GMM clustering and ARI/NMI summary per batch" ] }, { "cell_type": "code", "execution_count": 7, "id": "0efe211d", 
"metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
batcharinmi
0BALBc-10.6190940.592073
1BALBc-20.6572780.618800
2BALBc-30.5649920.564693
\n", "
" ], "text/plain": [ " batch ari nmi\n", "0 BALBc-1 0.619094 0.592073\n", "1 BALBc-2 0.657278 0.618800\n", "2 BALBc-3 0.564992 0.564693" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "metrics_df = cluster_and_report(\n", " adata,\n", " batches=cfg.batches,\n", " n_cluster=cfg.n_cluster,\n", " seed=cfg.seed,\n", " pred_key=\"GM\",\n", " truth_key=\"Compartment\",\n", ")\n", "metrics_df\n" ] } ], "metadata": { "kernelspec": { "display_name": "DREAM_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.19" } }, "nbformat": 4, "nbformat_minor": 5 }