Medal S: Spatio-Textual Prompt Model for Medical Segmentation

Paper OpenReview HuggingFace GitHub

This repository provides guidance for training and inference of Medal S for the CVPR 2025 challenge "Foundation Models for Text-Guided 3D Biomedical Image Segmentation".

Docker link for the 2025/05/30 testing submission: Medal S

Requirements

Python Version: 3.10.16

Installation

  1. Create conda environment and install dependencies:
#!/bin/bash

# Create environment
conda create -n medals_local_test python=3.10 -y

conda activate medals_local_test

# Install packages
pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121
pip install transformers==4.51.3 monai==1.4.0 nibabel==5.3.2 tensorboard einops positional_encodings scipy pandas scikit-learn scikit-image batchgenerators acvl_utils

# Install nnU-Net
wget https://github.com/MIC-DKFZ/nnUNet/archive/refs/tags/v2.4.1.tar.gz
tar -xvf v2.4.1.tar.gz
pip install -e nnUNet-2.4.1

# Install dynamic network architectures
cd model && pip install -e dynamic-network-architectures-main && cd ..

Training Guidance

First, download the dataset from Hugging Face: junma/CVPR-BiomedSegFM.

  • Data Preparation: Preprocess and organize all training data into a train_all.jsonl file using the provided script: data/challenge_data/get_train_jsonl.py.

  • Knowledge Enhancement: You can either use the pre-trained text encoder from SAT (https://github.com/zhaoziheng/SAT/tree/cvpr2025challenge) available on Hugging Face, or pre-train it yourself following the guidance in this repository. As recommended by SAT, we freeze the text encoder when training the segmentation model.

  • Segmentation: The training script is located at sh/cvpr2025_Blosc2_pretrain_1.0_1.0_1.0_UNET_ps192.sh. Before training starts, the NPZ files are converted to the Blosc2 compressed format, a storage format adopted from the nnU-Net framework.

Training the model with a 224x224x128 patch size at (1.5, 1.5, 3.0) spacing takes approximately 7 days on 2x H100-80GB GPUs with a batch size of 2 per GPU. The model with a 192x192x192 patch size at (1.0, 1.0, 1.0) spacing requires 4x H100-80GB GPUs, also with a batch size of 2 per GPU. To train on GPUs with less memory, reduce the patch size and batch size.
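
To compare the two configurations when choosing a smaller patch size, the sketch below computes the voxel count per patch and the physical extent each patch covers. It assumes the spacing tuple is listed in the same axis order as the patch size and is given in millimetres; neither is stated above, and `patch_stats` is a hypothetical helper, not part of this repository.

```python
from math import prod


def patch_stats(patch_size, spacing):
    """Return (voxels per patch, physical extent per axis)."""
    voxels = prod(patch_size)
    # Extent = number of voxels along an axis times the voxel spacing.
    extent = tuple(n * s for n, s in zip(patch_size, spacing))
    return voxels, extent


# The two configurations quoted above (axis order is an assumption).
configs = {
    "224x224x128 @ (1.5, 1.5, 3.0)": ((224, 224, 128), (1.5, 1.5, 3.0)),
    "192x192x192 @ (1.0, 1.0, 1.0)": ((192, 192, 192), (1.0, 1.0, 1.0)),
}

for name, (patch, spacing) in configs.items():
    voxels, extent = patch_stats(patch, spacing)
    print(f"{name}: {voxels:,} voxels per patch, extent {extent}")
```

The 192x192x192 patch holds roughly 10% more voxels than the 224x224x128 patch while covering a much smaller physical region, so halving one patch dimension is a quick way to estimate the effect of a smaller patch on input size.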

Inference Guidance

Quick Start

The inference pipeline takes raw NIfTI images and free-form text prompts as input, and outputs segmentation predictions. All preprocessing (spacing conversion, normalization, resampling) is handled automatically.

Basic Usage:

# Using shell script (recommended)
bash run_inference_medals_nifti.sh

# Or using Python directly
python inference_medals_nifti.py \
    --input input.nii.gz \
    --output output.nii.gz \
    --config config_CT.json \
    --mode stage1+stage2 \
    --device cuda:0 \
    --checkpoints ./checkpoints
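
To run the same command over several scans, one option is to build the argument list in Python. The sketch below only constructs the command; it uses the flags shown above, while `build_command` and the output naming scheme (`<name>_<mode>.nii.gz`) are illustrative assumptions rather than part of this repository.

```python
from pathlib import Path


def build_command(input_path, output_dir, config="config_CT.json",
                  mode="stage1+stage2", device="cuda:0",
                  checkpoints="./checkpoints"):
    """Build the argument list for one inference_medals_nifti.py run."""
    input_path = Path(input_path)
    # Strip ".nii.gz" or ".nii" to get the bare case name.
    stem = input_path.name.removesuffix(".gz").removesuffix(".nii")
    output_path = Path(output_dir) / f"{stem}_{mode}.nii.gz"
    return [
        "python", "inference_medals_nifti.py",
        "--input", str(input_path),
        "--output", str(output_path),
        "--config", config,
        "--mode", mode,
        "--device", device,
        "--checkpoints", checkpoints,
    ]


cmd = build_command("scans/chest_ct.nii.gz", "predictions")
print(" ".join(cmd))
```

Each resulting list can be passed to `subprocess.run(cmd, check=True)` inside a loop over the input files.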

Configuration Files

The inference uses JSON configuration files to specify text prompts and modality-specific settings:

For CT images (config_CT.json):

  • Supports multiple window types: soft_tissue, bone, lung
  • Each window type has its own text prompts and window settings
  • Multiple window types are processed separately and merged automatically
{
    "texts_soft_tissue": ["Aorta in whole body CT", "Liver in whole body CT"],
    "texts_bone": ["Vertebrae C1 in whole body CT"],
    "texts_lung": ["Left lung in whole body CT"],
    "window_settings": {
        "soft_tissue": {"window_level": 40, "window_width": 400},
        "bone": {"window_level": 500, "window_width": 1500},
        "lung": {"window_level": -600, "window_width": 1500}
    },
    "modality": "CT"
}
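
For reference, window/level settings such as those above are conventionally applied by clipping intensities to `[level - width/2, level + width/2]` and rescaling. The sketch below shows that standard operation with NumPy; the `[0, 1]` output range and the `apply_window` helper are assumptions, and the pipeline's internal normalization may differ.

```python
import numpy as np


def apply_window(hu, window_level, window_width):
    """Clip intensities to the window and rescale them to [0, 1]."""
    lower = window_level - window_width / 2.0
    upper = window_level + window_width / 2.0
    clipped = np.clip(hu.astype(np.float32), lower, upper)
    return (clipped - lower) / (upper - lower)


# Window settings copied from config_CT.json above.
window_settings = {
    "soft_tissue": {"window_level": 40, "window_width": 400},
    "bone": {"window_level": 500, "window_width": 1500},
    "lung": {"window_level": -600, "window_width": 1500},
}

# A few sample intensities (air, fat-like, soft tissue, dense tissue, bone).
hu = np.array([-1000.0, -160.0, 40.0, 240.0, 1000.0])
for name, settings in window_settings.items():
    print(name, apply_window(hu, **settings))
```

With the soft-tissue window, everything at or below -160 maps to 0 and everything at or above 240 maps to 1, which is why bone and lung prompts are given their own windows.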

For non-CT images (config_nonCT.json):

  • Uses percentile-based normalization
  • Supports MRI, US, PET, microscopy
{
    "texts": ["Spleen in MRI"],
    "normalization_settings": {
        "percentile_lower": 0.5,
        "percentile_upper": 99.5,
        "preserve_zero": true
    },
    "modality": "MRI"
}
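
As an illustration of what these settings typically mean, the sketch below clips an image to its lower and upper intensity percentiles and rescales to `[0, 1]`. The interpretation of `preserve_zero` (statistics computed over non-zero voxels only, with zero-valued voxels left at zero) is an assumption, as is the `percentile_normalize` helper; the pipeline's actual behaviour may differ.

```python
import numpy as np


def percentile_normalize(image, percentile_lower=0.5, percentile_upper=99.5,
                         preserve_zero=True):
    """Clip to the given percentiles and rescale to [0, 1]."""
    image = image.astype(np.float32)
    # Assumption: with preserve_zero, zeros are treated as background,
    # excluded from the percentile statistics, and kept at zero.
    mask = image != 0 if preserve_zero else np.ones(image.shape, dtype=bool)
    if not mask.any():
        return np.zeros_like(image)
    lo, hi = np.percentile(image[mask], [percentile_lower, percentile_upper])
    if hi <= lo:
        # Constant foreground: nothing to rescale.
        return np.zeros_like(image)
    out = (np.clip(image, lo, hi) - lo) / (hi - lo)
    out[~mask] = 0.0
    return out


image = np.array([0, 0, 10, 20, 30, 40], dtype=np.float32)
print(percentile_normalize(image))
```

Clipping at the 0.5th and 99.5th percentiles limits the influence of a few extreme voxels, which matters for modalities such as MRI where intensities have no fixed physical scale.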

Input/Output

Input:

  • Raw NIfTI image file (.nii or .nii.gz) with spacing metadata
  • Free-form text prompt(s) describing the anatomical structures or lesions to segment
  • Optional: Modality information (CT, MRI, US, PET, microscopy)

Output:

  • Segmentation prediction as NIfTI file with same spacing/origin/direction as input
  • Label values correspond to the order of text prompts in the config

Example:

# Input: raw CT image + text prompts
Input:  "chest_ct.nii.gz"
Texts:  ["Aorta in whole body CT", "Liver in whole body CT"]
Labels: [1, 2]  # Auto-generated if not specified

# Output: segmentation mask
Output: "chest_ct_stage1+stage2.nii.gz"
        # Voxels labeled as 1 = Aorta, 2 = Liver, 0 = Background
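
The correspondence between prompt order and label values can be sketched as follows. `assign_labels` is a hypothetical helper that mirrors the behaviour described above (consecutive integers starting from 1, with 0 reserved for background); it is not the repository's own function.

```python
def assign_labels(texts, labels=None):
    """Map each label value to its prompt; default to 1..N in prompt order."""
    if labels is None:
        labels = list(range(1, len(texts) + 1))
    if len(labels) != len(texts):
        raise ValueError("need exactly one label per text prompt")
    if 0 in labels:
        raise ValueError("label 0 is reserved for background")
    if len(set(labels)) != len(labels):
        raise ValueError("label values must be unique")
    return dict(zip(labels, texts))


legend = assign_labels(["Aorta in whole body CT", "Liver in whole body CT"])
for value, text in legend.items():
    print(value, "->", text)
```

Keeping such a legend next to each output file makes it possible to interpret the mask later without consulting the config that produced it.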

Core Inference Logic

The inference pipeline automatically handles:

  1. Input Loading: Reads raw NIfTI image with spacing metadata
  2. Preprocessing:
    • CT: Window/level normalization (automatically selects appropriate window based on text prompts)
    • Non-CT: Percentile-based normalization
    • Spacing conversion and resampling to target resolution
  3. Text Encoding: Encodes free-form text prompts using the knowledge encoder
  4. Segmentation: Performs patch-based inference with optional two-stage refinement
  5. Post-processing: Resamples prediction back to original image space
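
The spacing conversion in step 2 determines the size of the volume the model actually sees. A minimal sketch of that calculation, assuming the resampled grid preserves the physical extent of the image (`resampled_shape` and the example shape and spacing are hypothetical):

```python
def resampled_shape(shape, spacing, target_spacing):
    """Grid size after resampling while preserving physical extent."""
    return tuple(
        max(1, round(n * s / t))
        for n, s, t in zip(shape, spacing, target_spacing)
    )


# Hypothetical CT volume: 512x512x300 voxels at 0.8 x 0.8 x 1.5 spacing.
original_shape = (512, 512, 300)
original_spacing = (0.8, 0.8, 1.5)

model_shape = resampled_shape(original_shape, original_spacing, (1.0, 1.0, 1.0))
print(model_shape)
```

Because of rounding, recomputing the original shape from the resampled one is not always exact, so step 5 is most safely done by resampling the prediction to the stored original shape.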

Inference Modes

  • stage2_only: Single-stage high-resolution inference (faster)
  • stage1+stage2: Two-stage inference with ROI refinement (more accurate)
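
A common way to implement ROI refinement is to take the bounding box of a coarse first-stage mask, pad it by a margin, and run the second stage only inside that box. The sketch below illustrates this generic idea with NumPy; it is not taken from the Medal S code, and `roi_from_mask` and the margin value are assumptions.

```python
import numpy as np


def roi_from_mask(mask, margin=8):
    """Bounding box (as slices) around non-zero voxels, padded by margin."""
    coords = np.argwhere(mask > 0)
    if coords.size == 0:
        # Nothing found in the coarse pass: fall back to the full volume.
        return tuple(slice(0, n) for n in mask.shape)
    lo = np.maximum(coords.min(axis=0) - margin, 0)
    hi = np.minimum(coords.max(axis=0) + 1 + margin, mask.shape)
    return tuple(slice(int(a), int(b)) for a, b in zip(lo, hi))


# Toy coarse prediction with one foreground block.
coarse = np.zeros((64, 64, 64), dtype=np.uint8)
coarse[20:30, 25:40, 10:15] = 1

roi = roi_from_mask(coarse, margin=4)
print(roi, coarse[roi].shape)
```

Restricting the second pass to the cropped region is what makes a two-stage scheme affordable at high resolution, at the cost of depending on the coarse pass not missing the structure entirely.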

Python API Example

For programmatic usage, you can call the inference function directly:

from inference_medals_nifti import run_inference, load_config_from_json

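# Option 1: Using config file (recommended)
# Note: this assumes load_config_from_json returns merged "texts" and "labels"
# entries; config_CT.json itself stores prompts per window type (texts_soft_tissue, ...).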
config = load_config_from_json("config_CT.json")
pred_array, inference_time = run_inference(
    image_path="input.nii.gz",
    output_path="output.nii.gz",
    modality=config["modality"],
    texts=config["texts"],
    label_values=[str(l) for l in config["labels"]],
    inference_mode="stage1+stage2",
    device="cuda:0",
    checkpoints_path="./checkpoints",
    window_settings=config.get("window_settings"),
    normalization_settings=config.get("normalization_settings"),
    window_type_mapping=config.get("window_type_mapping")
)

# Option 2: Direct parameters
pred_array, inference_time = run_inference(
    image_path="input.nii.gz",
    output_path="output.nii.gz",
    modality="CT",
    texts=["Aorta in whole body CT", "Liver in whole body CT"],
    label_values=["1", "2"],
    inference_mode="stage1+stage2",
    device="cuda:0",
    checkpoints_path="./checkpoints"
)

Important Notes

  • Preprocessing: All preprocessing steps (spacing, normalization, resampling) are handled internally. To avoid discrepancies, do not preprocess images yourself.
  • Text Prompts: Use descriptive, modality-specific prompts (e.g., "Aorta in whole body CT" rather than just "Aorta")
  • Labels: If not specified in config, labels are auto-generated as consecutive integers starting from 1
  • GPU Memory: For large 3D images, adjust category_batch_size in the config to manage GPU memory usage
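
To illustrate the GPU memory note: a setting like category_batch_size usually limits how many prompts (categories) are processed in one forward pass. The sketch below shows only the chunking step; how the pipeline actually applies the setting is an assumption, and `chunk_prompts` is a hypothetical helper.

```python
def chunk_prompts(texts, category_batch_size):
    """Split the prompt list into groups of at most category_batch_size."""
    if category_batch_size < 1:
        raise ValueError("category_batch_size must be at least 1")
    return [
        texts[i:i + category_batch_size]
        for i in range(0, len(texts), category_batch_size)
    ]


prompts = [
    "Aorta in whole body CT",
    "Liver in whole body CT",
    "Vertebrae C1 in whole body CT",
    "Left lung in whole body CT",
    "Spleen in MRI",
]
for batch in chunk_prompts(prompts, category_batch_size=2):
    print(len(batch), batch)
```

Under this reading, a smaller value lowers peak memory because fewer per-category outputs are held at once, while increasing the number of passes over the image.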

Citation

@misc{shi2025medalsspatiotextualprompt,
      title={Medal S: Spatio-Textual Prompt Model for Medical Segmentation}, 
      author={Pengcheng Shi and Jiawei Chen and Jiaqi Liu and Xinglin Zhang and Tao Chen and Lei Li},
      year={2025},
      eprint={2511.13001},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2511.13001}, 
}

Acknowledgements

This project builds on nnU-Net and SAT and improves on them significantly. We extend our gratitude to both projects. Medal S is developed and maintained by Medical Image Insights.
