{ "cells": [ { "cell_type": "markdown", "id": "c63b582b", "metadata": {}, "source": [ "# CRCLM Dataset\n", "\n", "Leave-one-slice-out cross-validation (sklearn `LeaveOneGroupOut`) of the DREAM stage-2 concept bottleneck model: every slice is held out once and its clinical phenotype (Primary vs. Metastasis) is predicted by a model trained on the remaining slices." ] }, { "cell_type": "code", "execution_count": 1, "id": "57dde758", "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", "import os\n", "os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n", "os.environ[\"MKL_NUM_THREADS\"] = \"1\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"0\")\n", "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n", "import torch\n", "torch.set_num_threads(1)\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "3a7a0f5f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2026-03-05 11:33:53.186147: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX_VNNI FMA\n", "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2026-03-05 11:33:53.251931: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "2026-03-05 11:33:53.530789: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2026-03-05 11:33:53.530822: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2026-03-05 11:33:53.530824: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "================================================================================\n", "Spatial Transcriptomics Concept Bottleneck Model + Attribution Tutorial\n", "Reference PAUSE framework\n", "================================================================================\n", "\n", "PyTorch version: 2.7.0+cu118\n", "CUDA available: True\n", "\n" ] } ], "source": [ "import sys\n", "import os\n", "import numpy as np\n", "import pandas as pd\n", "import scanpy as sc\n", "import torch\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import json\n", "from tqdm import tqdm\n", "from sklearn.model_selection import LeaveOneGroupOut\n", "from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,\n", " confusion_matrix, classification_report, roc_auc_score)\n", "ROOT_DIR = os.path.abspath(os.getcwd())\n", "if os.path.isdir(os.path.join(ROOT_DIR, 'DREAM_stage2')):\n", " sys.path.insert(0, ROOT_DIR)\n", "else:\n", " parent = os.path.dirname(ROOT_DIR)\n", " if os.path.isdir(os.path.join(parent, 'DREAM_stage2')):\n", " 
sys.path.insert(0, parent)\n", " else:\n", " sys.path.insert(0, ROOT_DIR)\n", "\n", "from DREAM_stage2.dataset import prepare_data_loaders\n", "from DREAM_stage2.models import create_end2end_model\n", "from DREAM_stage2.train import accuracy, AverageMeter, train_and_evaluate_fold, run_loocv\n", "\n", "sns.set_palette(\"husl\")\n", "plt.rcParams['figure.dpi'] = 100\n", "\n", "SEED = 2025\n", "torch.manual_seed(SEED)\n", "np.random.seed(SEED)\n", "\n", "print(\"=\"*80)\n", "print(\"Spatial Transcriptomics Concept Bottleneck Model + Attribution Tutorial\")\n", "print(\"Reference PAUSE framework\")\n", "print(\"=\"*80)\n", "print(f\"\\nPyTorch version: {torch.__version__}\")\n", "print(f\"CUDA available: {torch.cuda.is_available()}\\n\")" ] }, { "cell_type": "markdown", "id": "5cd699af", "metadata": {}, "source": [ "## 1. Load Data" ] }, { "cell_type": "code", "execution_count": 3, "id": "8fb400ee", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Number of cells: 68261\n", " Number of features: 64\n", " Number of spatial domains: 21\n", " Number of clinical phenotypes: 2\n", " Number of slices: 24\n" ] } ], "source": [ "\n", "DATA_PATH = \"/home/Data/CRCLM/adata_DREAM_Stage.h5ad\"\n", "RAW_DATA_PATH = \"/home/Data/CRCLM/adata_raw.h5ad\"\n", "\n", "adata_scvi = sc.read_h5ad(DATA_PATH)\n", "adata_raw = sc.read_h5ad(RAW_DATA_PATH)\n", "\n", "if 'X_raw' not in adata_scvi.uns:\n", " adata_scvi.uns['X_raw'] = adata_raw.X\n", "\n", "\n", "print(f\" Number of cells: {adata_scvi.obsm['latent'].shape[0]}\")\n", "print(f\" Number of features: {adata_scvi.obsm['latent'].shape[1]}\")\n", "print(f\" Number of spatial domains: {adata_scvi.obs['leiden'].nunique()}\")\n", "print(f\" Number of clinical phenotypes: {adata_scvi.obs['type'].nunique()}\")\n", "print(f\" Number of slices: {adata_scvi.obs['slice_name'].nunique()}\")" ] }, { "cell_type": "markdown", "id": "4f61a010", "metadata": {}, "source": [ "## 2. 
Prepare Data (no test split)" ] }, { "cell_type": "code", "execution_count": 4, "id": "876df202", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Number of genes: 64\n", " Number of concepts: 21 (0, 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9)\n", " Number of labels: 2 (Metastasis, Primary)\n", " Training slices: 24\n", " Validation slices: 24\n" ] } ], "source": [ "(train_loader, val_loader, all_loader,\n", "# slice_train_loader, slice_val_loader, slice_all_loader,\n", " metadata) = prepare_data_loaders(\n", " DATA_PATH,\n", " val_split=0.2,\n", " seed=SEED,\n", " expr_key='latent',\n", " use_all_data=True # Use all data for training (no val split)\n", ")\n", "\n", "print(f\" Number of genes: {metadata['n_genes']}\")\n", "print(f\" Number of concepts: {metadata['n_concepts']} ({', '.join(metadata['concept_classes'])})\")\n", "print(f\" Number of labels: {metadata['n_labels']} ({', '.join(metadata['label_classes'])})\")\n", "print(f\" Training slices: {len(metadata['train_slices'])}\")\n", "print(f\" Validation slices: {len(metadata['val_slices'])}\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "4972109f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Data preparation complete:\n", " Total cells: 68261\n", " Total slices: 24\n", " LOOCV Folds: 24 (each slice as validation once)\n", " Using sklearn.model_selection.LeaveOneGroupOut\n" ] } ], "source": [ "\n", "all_X = all_loader.dataset.X.numpy()\n", "all_concepts = all_loader.dataset.concepts.numpy()\n", "all_labels = all_loader.dataset.labels.numpy()\n", "\n", "groups = adata_scvi.obs['slice_name'].values \n", "\n", "unique_slices = np.unique(groups)\n", "slice_to_label = {}\n", "for slice_name in unique_slices:\n", " mask = groups == slice_name\n", " slice_to_label[slice_name] = all_labels[mask][0]\n", "\n", "\n", "logo = LeaveOneGroupOut()\n", "n_splits = logo.get_n_splits(all_X, all_labels, 
groups)\n", "\n", "print(f\"\\nData preparation complete:\")\n", "print(f\" Total cells: {len(all_X)}\")\n", "print(f\" Total slices: {len(unique_slices)}\")\n", "print(f\" LOOCV Folds: {n_splits} (each slice as validation once)\")\n", "print(f\" Using sklearn.model_selection.LeaveOneGroupOut\")\n" ] }, { "cell_type": "markdown", "id": "1de7b498", "metadata": {}, "source": [ "## 3. Run LeaveOneGroupOut cross-validation\n", "\n", "Each slice is used as the validation set exactly once (leave-one-slice-out); in every fold the model is trained on the cells of all remaining slices." ] }, { "cell_type": "code", "execution_count": 6, "id": "52b00262", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 3: Run LeaveOneGroupOut cross-validation\n", "\n", "Starting LOOCV training (24 folds)...\n", " Training 5000 epochs per fold\n", " Learning rate: 0.001\n", " Using sklearn LeaveOneGroupOut\n", "\n", "[Fold 1/24] Validation slice: Garbarino_LM1\n", " Train cells: 64763, Val cells: 3498\n", " True: Metastasis, Pred: Metastasis ✓\n", " Concept accuracy: 83.59%\n", "\n", "[Fold 2/24] Validation slice: Garbarino_LM2\n", " Train cells: 64485, Val cells: 3776\n", " True: Metastasis, Pred: Metastasis ✓\n", " Concept accuracy: 83.95%\n", "\n", "[Fold 3/24] Validation slice: Garbarino_LM3\n", " Train cells: 67013, Val cells: 1248\n", " True: Metastasis, Pred: Metastasis ✓\n", " Concept accuracy: 87.42%\n", "\n", "[Fold 4/24] Validation slice: Garbarino_LM4\n", " Train cells: 66724, Val cells: 1537\n", " True: Metastasis, Pred: Metastasis ✓\n", " Concept accuracy: 82.82%\n", "\n", "[Fold 5/24] Validation slice: Valdeolivas_CRC_s2_rep1\n", " Train cells: 67069, Val cells: 1192\n", " True: Primary, Pred: Primary ✓\n", " Concept accuracy: 83.22%\n", "\n", "[Fold 6/24] Validation slice: Valdeolivas_CRC_s3_rep1\n", " Train cells: 65944, Val cells: 2317\n", " True: Primary, Pred: Primary ✓\n", " Concept accuracy: 80.45%\n", "\n", "[Fold 7/24] Validation slice: Valdeolivas_CRC_s3_rep2\n", " Train cells: 66458, Val cells: 1803\n", " True: Primary, Pred: Primary ✓\n", " Concept accuracy: 88.19%\n", "\n", 
"[Fold 8/24] Validation slice: Valdeolivas_CRC_s4_rep2\n", " Train cells: 67213, Val cells: 1048\n", " True: Primary, Pred: Primary ✓\n", " Concept accuracy: 84.06%\n", "\n", "[Fold 9/24] Validation slice: Valdeolivas_CRC_s5_rep1\n", " Train cells: 66058, Val cells: 2203\n", " True: Primary, Pred: Primary ✓\n", " Concept accuracy: 88.11%\n", "\n", "[Fold 10/24] Validation slice: Valdeolivas_CRC_s5_rep2\n", " Train cells: 65876, Val cells: 2385\n", " True: Primary, Pred: Primary ✓\n", " Concept accuracy: 90.69%\n", "\n", "[Fold 11/24] Validation slice: Valdeolivas_CRC_s6_rep1\n", " Train cells: 66133, Val cells: 2128\n", " True: Primary, Pred: Primary ✓\n", " Concept accuracy: 88.77%\n", "\n", "[Fold 12/24] Validation slice: Valdeolivas_CRC_s6_rep2\n", " Train cells: 66570, Val cells: 1691\n", " True: Primary, Pred: Primary ✓\n", " Concept accuracy: 87.29%\n", "\n", "[Fold 13/24] Validation slice: Villemin_LM1\n", " Train cells: 65374, Val cells: 2887\n", " True: Metastasis, Pred: Metastasis ✓\n", " Concept accuracy: 81.64%\n", "\n", "[Fold 14/24] Validation slice: Villemin_LM4\n", " Train cells: 65944, Val cells: 2317\n", " True: Metastasis, Pred: Metastasis ✓\n", " Concept accuracy: 82.22%\n", "\n", "[Fold 15/24] Validation slice: Wang_CRC1\n", " Train cells: 66207, Val cells: 2054\n", " True: Primary, Pred: Primary ✓\n", " Concept accuracy: 71.28%\n", "\n", "[Fold 16/24] Validation slice: Wang_CRC2\n", " Train cells: 64956, Val cells: 3305\n", " True: Primary, Pred: Primary ✓\n", " Concept accuracy: 70.86%\n", "\n", "[Fold 17/24] Validation slice: Wang_CRC3\n", " Train cells: 64844, Val cells: 3417\n", " True: Primary, Pred: Metastasis ✗\n", " Concept accuracy: 62.31%\n", "\n", "[Fold 18/24] Validation slice: Wang_CRC4\n", " Train cells: 64245, Val cells: 4016\n", " True: Primary, Pred: Primary ✓\n", " Concept accuracy: 85.53%\n", "\n", "[Fold 19/24] Validation slice: Wang_LM1\n", " Train cells: 63589, Val cells: 4672\n", " True: Metastasis, Pred: Metastasis 
✓\n", " Concept accuracy: 89.43%\n", "\n", "[Fold 20/24] Validation slice: Wang_LM2\n", " Train cells: 63465, Val cells: 4796\n", " True: Metastasis, Pred: Metastasis ✓\n", " Concept accuracy: 87.57%\n", "\n", "[Fold 21/24] Validation slice: Wu_CRC1\n", " Train cells: 64948, Val cells: 3313\n", " True: Primary, Pred: Primary ✓\n", " Concept accuracy: 79.84%\n", "\n", "[Fold 22/24] Validation slice: Wu_CRC2\n", " Train cells: 64087, Val cells: 4174\n", " True: Primary, Pred: Primary ✓\n", " Concept accuracy: 82.97%\n", "\n", "[Fold 23/24] Validation slice: Wu_LM1\n", " Train cells: 64435, Val cells: 3826\n", " True: Metastasis, Pred: Metastasis ✓\n", " Concept accuracy: 80.87%\n", "\n", "[Fold 24/24] Validation slice: Wu_LM2\n", " Train cells: 63603, Val cells: 4658\n", " True: Metastasis, Pred: Metastasis ✓\n", " Concept accuracy: 80.18%\n", "\n" ] } ], "source": [ "print(\"Step 3: Run LeaveOneGroupOut cross-validation\")\n", "\n", "NUM_EPOCHS = 5000\n", "LEARNING_RATE = 0.001\n", "\n", "\n", "results = run_loocv(\n", " all_X=all_X,\n", " all_concepts=all_concepts,\n", " all_labels=all_labels,\n", " groups=groups,\n", " metadata=metadata,\n", " device=device,\n", " num_epochs=NUM_EPOCHS,\n", " learning_rate=LEARNING_RATE,\n", " verbose=True,\n", ")\n", "\n", "\n", "loocv_results = results['loocv_results']\n", "all_true_labels = results['all_true_labels']\n", "all_pred_labels = results['all_pred_labels']\n", "all_pred_probs = results['all_pred_probs']\n", "all_concept_accs = results['all_concept_accs']\n", "all_loss_histories = results['all_loss_histories']" ] }, { "cell_type": "code", "execution_count": 7, "id": "df8efd00", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total slices: 24\n", "Correct predictions: 23\n", "Incorrect predictions: 1\n" ] } ], "source": [ "print(f\"Total slices: {len(unique_slices)}\")\n", "print(f\"Correct predictions: {(all_true_labels == all_pred_labels).sum()}\")\n", "print(f\"Incorrect 
predictions: {(all_true_labels != all_pred_labels).sum()}\")" ] } ], "metadata": { "kernelspec": { "display_name": "STT", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 5 }