CRCLM Dataset

[1]:
# Environment must be configured *before* torch is imported so the thread
# pools and the visible-GPU mask pick these values up.
import os
import warnings

warnings.filterwarnings("ignore")

# Single-threaded OpenMP / MKL.
os.environ.update({
    "OMP_NUM_THREADS": "1",
    "MKL_NUM_THREADS": "1",
})
# Keep a caller-provided GPU selection; otherwise default to GPU 0.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
# Fixed cuBLAS workspace; presumably set for deterministic cuBLAS
# (see PyTorch reproducibility notes).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch

torch.set_num_threads(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
[2]:
import sys
import os
import random
import numpy as np
import pandas as pd
import scanpy as sc
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import json
from tqdm import tqdm
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
                             confusion_matrix, classification_report, roc_auc_score)

# Make the local `DREAM_stage2` package importable whether the notebook is
# launched from the project root or from a directory one level below it.
# Falls back to the current working directory when neither contains it.
ROOT_DIR = os.path.abspath(os.getcwd())
_candidate_roots = (ROOT_DIR, os.path.dirname(ROOT_DIR))
PROJECT_ROOT = next(
    (d for d in _candidate_roots if os.path.isdir(os.path.join(d, 'DREAM_stage2'))),
    ROOT_DIR,
)
sys.path.insert(0, PROJECT_ROOT)

from DREAM_stage2.dataset import prepare_data_loaders
from DREAM_stage2.models import create_end2end_model
from DREAM_stage2.train import accuracy, AverageMeter, train_and_evaluate_fold, run_loocv

sns.set_palette("husl")
plt.rcParams['figure.dpi'] = 100

# Seed every RNG the pipeline may draw from (Python, NumPy, torch CPU+CUDA).
# cuDNN autotuning/nondeterministic kernels are disabled so that, together
# with the CUBLAS_WORKSPACE_CONFIG set earlier, repeated runs match.
SEED = 2025
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print("="*80)
print("Spatial Transcriptomics Concept Bottleneck Model + Attribution Tutorial")
print("Reference PAUSE framework")
print("="*80)
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}\n")
2026-03-05 11:33:53.186147: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-03-05 11:33:53.251931: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-03-05 11:33:53.530789: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2026-03-05 11:33:53.530822: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2026-03-05 11:33:53.530824: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
================================================================================
Spatial Transcriptomics Concept Bottleneck Model + Attribution Tutorial
Reference PAUSE framework
================================================================================

PyTorch version: 2.7.0+cu118
CUDA available: True

1. Load Data

[3]:

# Data location; override with the CRCLM_DATA_DIR environment variable
# instead of editing the absolute path.
DATA_DIR = os.environ.get("CRCLM_DATA_DIR", "/home/Data/CRCLM")
DATA_PATH = os.path.join(DATA_DIR, "adata_DREAM_Stage.h5ad")
RAW_DATA_PATH = os.path.join(DATA_DIR, "adata_raw.h5ad")

# Processed AnnData (features in .obsm['latent']) and the raw-matrix AnnData.
adata_scvi = sc.read_h5ad(DATA_PATH)
adata_raw = sc.read_h5ad(RAW_DATA_PATH)

# The raw matrix is attached positionally, so both files must contain the
# same number of cells (and, presumably, in the same order — TODO confirm).
if 'X_raw' not in adata_scvi.uns:
    if adata_raw.n_obs != adata_scvi.n_obs:
        raise ValueError(
            f"Raw AnnData has {adata_raw.n_obs} cells but processed AnnData has "
            f"{adata_scvi.n_obs}; cannot attach X_raw positionally."
        )
    adata_scvi.uns['X_raw'] = adata_raw.X

print(f"  Number of cells: {adata_scvi.obsm['latent'].shape[0]}")
print(f"  Number of features: {adata_scvi.obsm['latent'].shape[1]}")
print(f"  Number of spatial domains: {adata_scvi.obs['leiden'].nunique()}")
print(f"  Number of clinical phenotypes: {adata_scvi.obs['type'].nunique()}")
print(f"  Number of slices: {adata_scvi.obs['slice_name'].nunique()}")
  Number of cells: 68261
  Number of features: 64
  Number of spatial domains: 21
  Number of clinical phenotypes: 2
  Number of slices: 24

2. Prepare Data (no test split)

[4]:
# Build loaders over the `latent` embedding in DATA_PATH.
# With use_all_data=True every slice is used for training; the "validation"
# slices reported below are then expected to be the same set as training
# (24 / 24 in the recorded run). `val_split` is presumably ignored in this
# mode — confirm in DREAM_stage2.dataset.prepare_data_loaders.
(train_loader, val_loader, all_loader, metadata) = prepare_data_loaders(
    DATA_PATH,
    val_split=0.2,
    seed=SEED,
    expr_key='latent',
    use_all_data=True,
)

# NOTE: with expr_key='latent', "genes" here are latent features, not genes.
print(f"  Number of genes: {metadata['n_genes']}")
print(f"  Number of concepts: {metadata['n_concepts']} ({', '.join(metadata['concept_classes'])})")
print(f"  Number of labels: {metadata['n_labels']} ({', '.join(metadata['label_classes'])})")
print(f"  Training slices: {len(metadata['train_slices'])}")
print(f"  Validation slices: {len(metadata['val_slices'])}")
  Number of genes: 64
  Number of concepts: 21 (0, 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9)
  Number of labels: 2 (Metastasis, Primary)
  Training slices: 24
  Validation slices: 24
[5]:

# Materialise the full dataset as numpy arrays for sklearn's LeaveOneGroupOut.
all_X = all_loader.dataset.X.numpy()
all_concepts = all_loader.dataset.concepts.numpy()
all_labels = all_loader.dataset.labels.numpy()

# Slice membership comes from `adata_scvi`, while X/labels come from the
# loader built from DATA_PATH; the two are assumed row-aligned.
groups = adata_scvi.obs['slice_name'].values
if len(groups) != len(all_labels):
    raise ValueError(
        f"groups ({len(groups)}) and labels ({len(all_labels)}) are not aligned"
    )
unique_slices = np.unique(groups)

# One clinical label per slice. Refuse to silently take the first cell's
# label if a slice turns out to contain mixed labels.
slice_to_label = {}
for slice_name in unique_slices:
    slice_labels = all_labels[groups == slice_name]
    if not (slice_labels == slice_labels[0]).all():
        raise ValueError(f"Slice {slice_name!r} contains more than one label")
    slice_to_label[slice_name] = slice_labels[0]

logo = LeaveOneGroupOut()
n_splits = logo.get_n_splits(all_X, all_labels, groups)

print("\nData preparation complete:")
print(f"  Total cells: {len(all_X)}")
print(f"  Total slices: {len(unique_slices)}")
print(f"  LOOCV Folds: {n_splits} (each slice as validation once)")
print("  Using sklearn.model_selection.LeaveOneGroupOut")

Data preparation complete:
  Total cells: 68261
  Total slices: 24
  LOOCV Folds: 24 (each slice as validation once)
  Using sklearn.model_selection.LeaveOneGroupOut

3. Run Leave-One-Group-Out (one slice held out per fold) cross-validation

[6]:
print("Step 3: Run LeaveOneGroupOut cross-validation")

# Training budget, applied per fold (run_loocv logs "Training N epochs per fold").
NUM_EPOCHS = 5000
LEARNING_RATE = 0.001

results = run_loocv(
    all_X=all_X,
    all_concepts=all_concepts,
    all_labels=all_labels,
    groups=groups,
    metadata=metadata,
    device=device,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    verbose=True,
)

# Pull the per-fold outputs out of the results dict, in a fixed order.
(loocv_results, all_true_labels, all_pred_labels,
 all_pred_probs, all_concept_accs, all_loss_histories) = (
    results[key] for key in (
        'loocv_results',
        'all_true_labels',
        'all_pred_labels',
        'all_pred_probs',
        'all_concept_accs',
        'all_loss_histories',
    )
)
Step 3: Run LeaveOneGroupOut cross-validation

Starting LOOCV training (24 folds)...
  Training 5000 epochs per fold
  Learning rate: 0.001
  Using sklearn LeaveOneGroupOut

[Fold 1/24] Validation slice: Garbarino_LM1
  Train cells: 64763, Val cells: 3498
  True: Metastasis, Pred: Metastasis ✓
  Concept accuracy: 83.59%

[Fold 2/24] Validation slice: Garbarino_LM2
  Train cells: 64485, Val cells: 3776
  True: Metastasis, Pred: Metastasis ✓
  Concept accuracy: 83.95%

[Fold 3/24] Validation slice: Garbarino_LM3
  Train cells: 67013, Val cells: 1248
  True: Metastasis, Pred: Metastasis ✓
  Concept accuracy: 87.42%

[Fold 4/24] Validation slice: Garbarino_LM4
  Train cells: 66724, Val cells: 1537
  True: Metastasis, Pred: Metastasis ✓
  Concept accuracy: 82.82%

[Fold 5/24] Validation slice: Valdeolivas_CRC_s2_rep1
  Train cells: 67069, Val cells: 1192
  True: Primary, Pred: Primary ✓
  Concept accuracy: 83.22%

[Fold 6/24] Validation slice: Valdeolivas_CRC_s3_rep1
  Train cells: 65944, Val cells: 2317
  True: Primary, Pred: Primary ✓
  Concept accuracy: 80.45%

[Fold 7/24] Validation slice: Valdeolivas_CRC_s3_rep2
  Train cells: 66458, Val cells: 1803
  True: Primary, Pred: Primary ✓
  Concept accuracy: 88.19%

[Fold 8/24] Validation slice: Valdeolivas_CRC_s4_rep2
  Train cells: 67213, Val cells: 1048
  True: Primary, Pred: Primary ✓
  Concept accuracy: 84.06%

[Fold 9/24] Validation slice: Valdeolivas_CRC_s5_rep1
  Train cells: 66058, Val cells: 2203
  True: Primary, Pred: Primary ✓
  Concept accuracy: 88.11%

[Fold 10/24] Validation slice: Valdeolivas_CRC_s5_rep2
  Train cells: 65876, Val cells: 2385
  True: Primary, Pred: Primary ✓
  Concept accuracy: 90.69%

[Fold 11/24] Validation slice: Valdeolivas_CRC_s6_rep1
  Train cells: 66133, Val cells: 2128
  True: Primary, Pred: Primary ✓
  Concept accuracy: 88.77%

[Fold 12/24] Validation slice: Valdeolivas_CRC_s6_rep2
  Train cells: 66570, Val cells: 1691
  True: Primary, Pred: Primary ✓
  Concept accuracy: 87.29%

[Fold 13/24] Validation slice: Villemin_LM1
  Train cells: 65374, Val cells: 2887
  True: Metastasis, Pred: Metastasis ✓
  Concept accuracy: 81.64%

[Fold 14/24] Validation slice: Villemin_LM4
  Train cells: 65944, Val cells: 2317
  True: Metastasis, Pred: Metastasis ✓
  Concept accuracy: 82.22%

[Fold 15/24] Validation slice: Wang_CRC1
  Train cells: 66207, Val cells: 2054
  True: Primary, Pred: Primary ✓
  Concept accuracy: 71.28%

[Fold 16/24] Validation slice: Wang_CRC2
  Train cells: 64956, Val cells: 3305
  True: Primary, Pred: Primary ✓
  Concept accuracy: 70.86%

[Fold 17/24] Validation slice: Wang_CRC3
  Train cells: 64844, Val cells: 3417
  True: Primary, Pred: Metastasis ✗
  Concept accuracy: 62.31%

[Fold 18/24] Validation slice: Wang_CRC4
  Train cells: 64245, Val cells: 4016
  True: Primary, Pred: Primary ✓
  Concept accuracy: 85.53%

[Fold 19/24] Validation slice: Wang_LM1
  Train cells: 63589, Val cells: 4672
  True: Metastasis, Pred: Metastasis ✓
  Concept accuracy: 89.43%

[Fold 20/24] Validation slice: Wang_LM2
  Train cells: 63465, Val cells: 4796
  True: Metastasis, Pred: Metastasis ✓
  Concept accuracy: 87.57%

[Fold 21/24] Validation slice: Wu_CRC1
  Train cells: 64948, Val cells: 3313
  True: Primary, Pred: Primary ✓
  Concept accuracy: 79.84%

[Fold 22/24] Validation slice: Wu_CRC2
  Train cells: 64087, Val cells: 4174
  True: Primary, Pred: Primary ✓
  Concept accuracy: 82.97%

[Fold 23/24] Validation slice: Wu_LM1
  Train cells: 64435, Val cells: 3826
  True: Metastasis, Pred: Metastasis ✓
  Concept accuracy: 80.87%

[Fold 24/24] Validation slice: Wu_LM2
  Train cells: 63603, Val cells: 4658
  True: Metastasis, Pred: Metastasis ✓
  Concept accuracy: 80.18%

[7]:
# Slice-level summary: run_loocv yields one prediction per held-out slice.
# Coerce to arrays so the element-wise comparison also works if lists are returned.
true_slice_labels = np.asarray(all_true_labels)
pred_slice_labels = np.asarray(all_pred_labels)
is_correct = true_slice_labels == pred_slice_labels
n_correct = int(is_correct.sum())
assert is_correct.size == len(unique_slices), (
    f"expected one prediction per slice ({len(unique_slices)}), got {is_correct.size}"
)

print(f"Total slices: {len(unique_slices)}")
print(f"Correct predictions: {n_correct}")
print(f"Incorrect predictions: {is_correct.size - n_correct}")
Total slices: 24
Correct predictions: 23
Incorrect predictions: 1