CRCLM Dataset
[1]:
# Runtime setup for the whole notebook.
import os
import warnings

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Environment variables are written BEFORE torch is imported below.
os.environ.update({
    "OMP_NUM_THREADS": "1",
    "MKL_NUM_THREADS": "1",
    # NOTE(review): presumably set for reproducible cuBLAS results; nothing
    # visible here enables torch's deterministic-algorithms mode - confirm.
    "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
})
# Respect a GPU selection made by the caller; otherwise expose device "0".
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

import torch

# Restrict torch to one intra-op CPU thread.
torch.set_num_threads(1)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
[2]:
import sys
import os
import numpy as np
import pandas as pd
import scanpy as sc
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import json
from tqdm import tqdm
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
confusion_matrix, classification_report, roc_auc_score)
# Make the local DREAM_stage2 package importable.
# Search order: the current working directory, then its parent. If neither
# contains a DREAM_stage2 directory, the working directory is used anyway.
ROOT_DIR = os.path.abspath(os.getcwd())
parent = os.path.dirname(ROOT_DIR)
_package_home = next(
    (
        candidate
        for candidate in (ROOT_DIR, parent)
        if os.path.isdir(os.path.join(candidate, 'DREAM_stage2'))
    ),
    ROOT_DIR,
)
# Prepend so this location takes priority over other entries on sys.path.
sys.path.insert(0, _package_home)
from DREAM_stage2.dataset import prepare_data_loaders
from DREAM_stage2.models import create_end2end_model
from DREAM_stage2.train import accuracy, AverageMeter, train_and_evaluate_fold, run_loocv
# Global plot styling: figure resolution and the seaborn "husl" colour palette.
plt.rcParams.update({'figure.dpi': 100})
sns.set_palette("husl")
# One seed shared by NumPy and PyTorch; later cells also pass it on as SEED.
SEED = 2025
np.random.seed(SEED)
torch.manual_seed(SEED)

# Banner followed by a short environment report.
_rule = "=" * 80
for _banner_line in (
    _rule,
    "Spatial Transcriptomics Concept Bottleneck Model + Attribution Tutorial",
    "Reference PAUSE framework",
    _rule,
):
    print(_banner_line)
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}\n")
2026-03-05 11:33:53.186147: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-03-05 11:33:53.251931: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-03-05 11:33:53.530789: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2026-03-05 11:33:53.530822: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2026-03-05 11:33:53.530824: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
================================================================================
Spatial Transcriptomics Concept Bottleneck Model + Attribution Tutorial
Reference PAUSE framework
================================================================================
PyTorch version: 2.7.0+cu118
CUDA available: True
1. Load Data
[3]:
# Input files: the processed AnnData used for modelling and the raw counterpart.
DATA_PATH = "/home/Data/CRCLM/adata_DREAM_Stage.h5ad"
RAW_DATA_PATH = "/home/Data/CRCLM/adata_raw.h5ad"

adata_scvi, adata_raw = (sc.read_h5ad(h5ad_path) for h5ad_path in (DATA_PATH, RAW_DATA_PATH))

# Attach the raw matrix to the processed object unless it is already present.
# NOTE(review): this assumes both files hold the same cells in the same order;
# nothing here verifies that - confirm.
if 'X_raw' not in adata_scvi.uns:
    adata_scvi.uns['X_raw'] = adata_raw.X

# Dataset summary. "features" is the width of the 'latent' embedding in obsm.
latent_embedding = adata_scvi.obsm['latent']
cell_meta = adata_scvi.obs
print(f" Number of cells: {latent_embedding.shape[0]}")
print(f" Number of features: {latent_embedding.shape[1]}")
print(f" Number of spatial domains: {cell_meta['leiden'].nunique()}")
print(f" Number of clinical phenotypes: {cell_meta['type'].nunique()}")
print(f" Number of slices: {cell_meta['slice_name'].nunique()}")
Number of cells: 68261
Number of features: 64
Number of spatial domains: 21
Number of clinical phenotypes: 2
Number of slices: 24
2. Prepare Data (no test split)
[4]:
# Build the data loaders from the processed AnnData file, using the 'latent'
# embedding as the expression input.
# With use_all_data=True the recorded output lists all 24 slices under both
# "Training" and "Validation", so no slice is held out at this stage.
# NOTE(review): an earlier version of this cell also unpacked slice-level
# loaders; the call as written yields four values.
train_loader, val_loader, all_loader, metadata = prepare_data_loaders(
    DATA_PATH,
    expr_key='latent',
    seed=SEED,
    val_split=0.2,
    use_all_data=True,
)

# Summarise what the loaders expose.
concept_names = ', '.join(metadata['concept_classes'])
label_names = ', '.join(metadata['label_classes'])
print(f" Number of genes: {metadata['n_genes']}")
print(f" Number of concepts: {metadata['n_concepts']} ({concept_names})")
print(f" Number of labels: {metadata['n_labels']} ({label_names})")
print(f" Training slices: {len(metadata['train_slices'])}")
print(f" Validation slices: {len(metadata['val_slices'])}")
Number of genes: 64
Number of concepts: 21 (0, 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9)
Number of labels: 2 (Metastasis, Primary)
Training slices: 24
Validation slices: 24
[5]:
# Pull the full (unsplit) inputs, concept targets and phenotype labels out of
# the loader's dataset as NumPy arrays.
all_X = all_loader.dataset.X.numpy()
all_concepts = all_loader.dataset.concepts.numpy()
all_labels = all_loader.dataset.labels.numpy()

# Slice identifiers come from the AnnData object loaded earlier, not from the
# loader, so the two sources must describe the same cells. A length mismatch
# would silently assign cells to the wrong fold, so fail loudly instead.
# NOTE(review): equal length does not prove equal ordering; the row order of
# the loader dataset is still assumed to match adata_scvi.obs - confirm.
groups = adata_scvi.obs['slice_name'].values
if not (len(groups) == len(all_X) == len(all_concepts) == len(all_labels)):
    raise ValueError(
        "Row count mismatch between loader arrays and adata_scvi.obs: "
        f"X={len(all_X)}, concepts={len(all_concepts)}, "
        f"labels={len(all_labels)}, groups={len(groups)}"
    )
unique_slices = np.unique(groups)

# Each slice is treated as carrying a single phenotype label. Verify that
# assumption rather than silently keeping whichever label comes first.
slice_to_label = {}
for slice_name in unique_slices:
    mask = groups == slice_name
    slice_labels = all_labels[mask]
    if not (slice_labels == slice_labels[0]).all():
        raise ValueError(
            f"Slice {slice_name!r} contains more than one label; "
            "a slice-level label is not well defined."
        )
    slice_to_label[slice_name] = slice_labels[0]

# Leave-one-slice-out splitter: one fold per distinct value in `groups`.
logo = LeaveOneGroupOut()
n_splits = logo.get_n_splits(all_X, all_labels, groups)

print("\nData preparation complete:")
print(f" Total cells: {len(all_X)}")
print(f" Total slices: {len(unique_slices)}")
print(f" LOOCV Folds: {n_splits} (each slice as validation once)")
print(f" Using sklearn.model_selection.LeaveOneGroupOut")
Data preparation complete:
Total cells: 68261
Total slices: 24
LOOCV Folds: 24 (each slice as validation once)
Using sklearn.model_selection.LeaveOneGroupOut
3. Run LeaveOneGroupOut cross-validation
[6]:
print("Step 3: Run LeaveOneGroupOut cross-validation")

# Training hyperparameters; the recorded log reports NUM_EPOCHS as epochs per fold.
NUM_EPOCHS = 5000
LEARNING_RATE = 0.001

# Collect the arguments once, then run leave-one-slice-out cross-validation.
loocv_kwargs = dict(
    all_X=all_X,
    all_concepts=all_concepts,
    all_labels=all_labels,
    groups=groups,
    metadata=metadata,
    device=device,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    verbose=True,
)
results = run_loocv(**loocv_kwargs)

# Expose each entry of the result mapping under its own notebook-level name.
(
    loocv_results,
    all_true_labels,
    all_pred_labels,
    all_pred_probs,
    all_concept_accs,
    all_loss_histories,
) = (
    results[result_key]
    for result_key in (
        'loocv_results',
        'all_true_labels',
        'all_pred_labels',
        'all_pred_probs',
        'all_concept_accs',
        'all_loss_histories',
    )
)
Step 3: Run LeaveOneGroupOut cross-validation
Starting LOOCV training (24 folds)...
Training 5000 epochs per fold
Learning rate: 0.001
Using sklearn LeaveOneGroupOut
[Fold 1/24] Validation slice: Garbarino_LM1
Train cells: 64763, Val cells: 3498
True: Metastasis, Pred: Metastasis ✓
Concept accuracy: 83.59%
[Fold 2/24] Validation slice: Garbarino_LM2
Train cells: 64485, Val cells: 3776
True: Metastasis, Pred: Metastasis ✓
Concept accuracy: 83.95%
[Fold 3/24] Validation slice: Garbarino_LM3
Train cells: 67013, Val cells: 1248
True: Metastasis, Pred: Metastasis ✓
Concept accuracy: 87.42%
[Fold 4/24] Validation slice: Garbarino_LM4
Train cells: 66724, Val cells: 1537
True: Metastasis, Pred: Metastasis ✓
Concept accuracy: 82.82%
[Fold 5/24] Validation slice: Valdeolivas_CRC_s2_rep1
Train cells: 67069, Val cells: 1192
True: Primary, Pred: Primary ✓
Concept accuracy: 83.22%
[Fold 6/24] Validation slice: Valdeolivas_CRC_s3_rep1
Train cells: 65944, Val cells: 2317
True: Primary, Pred: Primary ✓
Concept accuracy: 80.45%
[Fold 7/24] Validation slice: Valdeolivas_CRC_s3_rep2
Train cells: 66458, Val cells: 1803
True: Primary, Pred: Primary ✓
Concept accuracy: 88.19%
[Fold 8/24] Validation slice: Valdeolivas_CRC_s4_rep2
Train cells: 67213, Val cells: 1048
True: Primary, Pred: Primary ✓
Concept accuracy: 84.06%
[Fold 9/24] Validation slice: Valdeolivas_CRC_s5_rep1
Train cells: 66058, Val cells: 2203
True: Primary, Pred: Primary ✓
Concept accuracy: 88.11%
[Fold 10/24] Validation slice: Valdeolivas_CRC_s5_rep2
Train cells: 65876, Val cells: 2385
True: Primary, Pred: Primary ✓
Concept accuracy: 90.69%
[Fold 11/24] Validation slice: Valdeolivas_CRC_s6_rep1
Train cells: 66133, Val cells: 2128
True: Primary, Pred: Primary ✓
Concept accuracy: 88.77%
[Fold 12/24] Validation slice: Valdeolivas_CRC_s6_rep2
Train cells: 66570, Val cells: 1691
True: Primary, Pred: Primary ✓
Concept accuracy: 87.29%
[Fold 13/24] Validation slice: Villemin_LM1
Train cells: 65374, Val cells: 2887
True: Metastasis, Pred: Metastasis ✓
Concept accuracy: 81.64%
[Fold 14/24] Validation slice: Villemin_LM4
Train cells: 65944, Val cells: 2317
True: Metastasis, Pred: Metastasis ✓
Concept accuracy: 82.22%
[Fold 15/24] Validation slice: Wang_CRC1
Train cells: 66207, Val cells: 2054
True: Primary, Pred: Primary ✓
Concept accuracy: 71.28%
[Fold 16/24] Validation slice: Wang_CRC2
Train cells: 64956, Val cells: 3305
True: Primary, Pred: Primary ✓
Concept accuracy: 70.86%
[Fold 17/24] Validation slice: Wang_CRC3
Train cells: 64844, Val cells: 3417
True: Primary, Pred: Metastasis ✗
Concept accuracy: 62.31%
[Fold 18/24] Validation slice: Wang_CRC4
Train cells: 64245, Val cells: 4016
True: Primary, Pred: Primary ✓
Concept accuracy: 85.53%
[Fold 19/24] Validation slice: Wang_LM1
Train cells: 63589, Val cells: 4672
True: Metastasis, Pred: Metastasis ✓
Concept accuracy: 89.43%
[Fold 20/24] Validation slice: Wang_LM2
Train cells: 63465, Val cells: 4796
True: Metastasis, Pred: Metastasis ✓
Concept accuracy: 87.57%
[Fold 21/24] Validation slice: Wu_CRC1
Train cells: 64948, Val cells: 3313
True: Primary, Pred: Primary ✓
Concept accuracy: 79.84%
[Fold 22/24] Validation slice: Wu_CRC2
Train cells: 64087, Val cells: 4174
True: Primary, Pred: Primary ✓
Concept accuracy: 82.97%
[Fold 23/24] Validation slice: Wu_LM1
Train cells: 64435, Val cells: 3826
True: Metastasis, Pred: Metastasis ✓
Concept accuracy: 80.87%
[Fold 24/24] Validation slice: Wu_LM2
Train cells: 63603, Val cells: 4658
True: Metastasis, Pred: Metastasis ✓
Concept accuracy: 80.18%
[7]:
# Slice-level tally: compare true and predicted labels element-wise once,
# then count the matches and the mismatches.
_slice_hits = all_true_labels == all_pred_labels
print(f"Total slices: {len(unique_slices)}")
print(f"Correct predictions: {_slice_hits.sum()}")
print(f"Incorrect predictions: {(~_slice_hits).sum()}")
Total slices: 24
Correct predictions: 23
Incorrect predictions: 1