Attention-Guided TransUNet for Multi-Class Retinal Fluid Segmentation

Author: Animesh Kumar | Newcastle University MSc Advanced Computer Science 2025–26
Target Venue: OMIA 2026 Workshop at MICCAI + medRxiv preprint
Framework: PyTorch | Compute: Google Colab H100
Live Demo: HuggingFace Space | Code: GitHub | DOI: Zenodo


Clinical Motivation

Retinal fluid accumulation is the primary biomarker for three major vision-threatening diseases:

Fluid Type Full Name Clinical Significance
IRF Intraretinal Fluid Active inflammation in diabetic macular oedema (DME) and wet AMD. Requires anti-VEGF injection within days of detection.
SRF Subretinal Fluid Associated with neovascular AMD and central serous chorioretinopathy. Volume determines treatment frequency.
PED Pigment Epithelial Detachment Elevation of the retinal pigment epithelium β€” a key AMD progression marker.

Manual segmentation by expert graders takes 20–40 minutes per volume with up to 15% inter-grader variability on small fluid pockets. This model provides calibrated, uncertainty-aware segmentation masks to reduce clinician workload and improve treatment monitoring consistency.


Architecture

Two Models β€” Dual Ensemble

Component V2S (Small) V2L (Large)
Encoder EfficientNetV2S EfficientNetV2L
Encoder channels s1=24, s2=48, s3=64, s4=160, bot=256 s1=32, s2=64, s3=96, s4=192, bot=640
Transformer d_model 256 512
Attention heads 16 16
Transformer layers 2 2
Total parameters ~22M ~127M
Phase B val Dice 0.7443 (seed=42) 0.7913 (seed=123)

Novel Contributions

  1. UCUS β€” Uncertainty-Weighted Clinical Urgency Score
    Combines volume score, foveal multiplier, boundary uncertainty and uncertainty discount into a single triage band: Monitor / Review / Urgent.

  2. Dual Uncertainty Estimation
    MC Dropout variance (20 forward passes) combined with inter-model disagreement between V2S and V2L predictions.

  3. Source-Adaptive BatchNorm (SA-BN)
    Separate batch norm statistics per scanner source (DUKE, AROI, UMN-AMD, UMN-DME, OPTIMA). Enables cross-scanner domain adaptation without retraining.

  4. Multi-Source Four-Dataset Evaluation
    Simultaneous evaluation across 4 independent acquisition sources with per-source Dice breakdown.

Architecture Data Flow

Input OCT B-scan (1Γ—512Γ—512)
        β”‚
        β–Ό
EfficientNetV2L Encoder (5 stages, ImageNet pretrained)
        β”‚  skip connections s1-s4
        β–Ό
Transformer Bottleneck (2Γ— MHA, d_model=512, 16 heads, learnable pos encoding)
        β”‚
        β–Ό
Attention-Gate Decoder (4 levels, SA-BatchNorm per decoder block)
        β”‚
        β–Ό
MC Dropout (p=0.3) + Output Head (4 classes)
        β”‚
        β”œβ”€β”€ Segmentation Mask (BG/IRF/SRF/PED)
        β”œβ”€β”€ Uncertainty Heatmap (pixel-wise std)
        └── UCUS Score (clinical triage)

Datasets

Dataset Volumes Annotated Classes Disease Scanner
DUKE DME 10 subjects, 110 B-scans IRF only DME Spectralis
AROI 24 patients IRF + SRF + PED AMD Zeiss Cirrus
UMN AMD 24 subjects SRF (binary) AMD Spectralis
UMN DME 29 subjects IRF (binary) DME Spectralis

Unified label space: 0=Background, 1=IRF, 2=SRF, 3=PED

Split after preprocessing:

  • Train: 4983 fluid-only slices
  • Validation: 552 slices
  • Test: 503 slices

Training Protocol

Phase A β€” Decoder Only (5 epochs)

  • Encoder frozen at ImageNet weights
  • LR = 1e-3, Adam, batch size 8
  • Target: val_dice > 0.50

Phase B β€” Full Fine-tuning (25 epochs)

  • Encoder blocks 3-5 unfrozen
  • LR = 1e-4, WarmupCosineDecay (5 epoch warmup)
  • Batch size 4, early stopping patience=7
  • Loss: Dice + 0.5 Γ— CrossEntropy

Seeds Trained

V2S: 42, 123, 2024 | V2L: 42, 123, 2024


Results

Multi-Seed Validation Dice (mean Β± std across seeds 42/123/2024)

Model IRF SRF PED Mean Fluid
V2S 0.8658 Β± 0.0067 0.8272 Β± 0.0046 0.5184 Β± 0.0093 0.7371 Β± 0.0052
V2L 0.9158 Β± 0.0034 0.8560 Β± 0.0034 0.5811 Β± 0.0175 0.7843 Β± 0.0058

Test Set Results (503 slices, 4 sources)

Metric Mean Std
dice_IRF 0.2043 Β±0.3482
dice_SRF 0.1712 Β±0.2359
dice_PED 0.4463 Β±0.4611
dice_mean_fluid 0.2739 Β±0.2161

Note: Low test Dice is driven by domain shift across 4 independent sources. V2L alone achieves 0.4511 on the ablation β€” the multi-source test set is a hard benchmark.

Per-Source Breakdown

Source IRF SRF PED Mean Fluid
AROI 0.054 0.299 0.144 0.166
DUKE 0.071 0.000 0.902 0.324
UMN 0.381 0.176 0.409 0.322

DUKE SRF = 0.000 is expected β€” DUKE dataset contains only IRF annotations.
DUKE PED = 0.902 shows the model correctly detects PED on DUKE scans.

Clinical Safety Metrics

Metric Value Significance
Inter-grader human ceiling 0.9030 Upper bound for automated systems
Model match threshold (95%) 0.8579 Target for clinical deployment
Uncertainty ratio 1.34Γ— p=3.77e-05 βœ…
SRF volume correlation r=0.778 p=6.33e-04 βœ…
PED volume correlation r=0.841 p=8.64e-05 βœ…
Total fluid correlation r=0.562 p=2.93e-02 βœ…

The uncertainty ratio finding means the model is 1.34Γ— more uncertain at pixels where human experts disagree β€” a statistically significant result confirming that model uncertainty correlates with genuine clinical ambiguity.

Ablation Study

Variant Mean Dice Std
V2S only (no MC, no TTA) 0.338 Β±0.332
V2S + MC Dropout 0.141 Β±0.144
V2S + TTA 0.121 Β±0.115
V2L only (no MC, no TTA) 0.449 Β±0.304
V2S + V2L ensemble 0.415 Β±0.318
V2S + V2L + MC Dropout 0.279 Β±0.219
Full (V2S+V2L+MC+TTA) 0.293 Β±0.218

INT8 Quantisation (Phase 5B)

Model FP32 INT8 Compression Method
V2L 510 MB 132 MB 3.9Γ— Per-tensor symmetric int8
V2S 91 MB 24 MB 3.8Γ— Per-tensor symmetric int8

Files in This Repository

File Size Description
ckpt_phaseB_V2L_s123.pth 1526 MB Best V2L checkpoint β€” val_dice=0.7913, epoch=25
ckpt_phaseB_V2L_s2024.pth 1526 MB Second V2L checkpoint β€” val_dice=0.7841, epoch=24
ckpt_phaseB_V2S_s42.pth 271 MB Best V2S checkpoint β€” val_dice=0.7443, epoch=34
ckpt_phaseB_V2L_s123_int8.pth 132 MB INT8 quantised V2L (3.9Γ— compression)
ckpt_phaseB_V2S_s42_int8.pth 24 MB INT8 quantised V2S (3.8Γ— compression)
deployment/slot1_v2l_seed2024.onnx β€” ONNX export β€” ready for TensorRT/OpenVINO
deployment/slot2_v2l_seed123.onnx β€” ONNX export β€” ready for TensorRT/OpenVINO
demo_results.json 87 MB 20 precomputed demo samples (5 per source)

Usage

Load Best Model (V2L seed=123)

from huggingface_hub import hf_hub_download
import torch

path = hf_hub_download(
    repo_id="animeshakr/oct-fluid-segmentation",
    filename="ckpt_phaseB_V2L_s123.pth"
)
ck = torch.load(path, map_location="cpu")
print(f"val_dice: {ck['val_dice']:.4f}")  # 0.7913
print(f"epoch:    {ck['epoch']}")          # 25

Load INT8 Quantised Model (Edge Deployment)

from huggingface_hub import hf_hub_download
import torch

path = hf_hub_download(
    repo_id="animeshakr/oct-fluid-segmentation",
    filename="ckpt_phaseB_V2L_s123_int8.pth"
)
qsd = torch.load(path, map_location="cpu")
# Dequantise at inference: weight = weight_int8 * scale
# 132MB vs 510MB original β€” 3.9x smaller

ONNX Inference

from huggingface_hub import hf_hub_download
import onnxruntime as ort
import numpy as np

# Download both files β€” .onnx.data must be downloaded first
hf_hub_download(repo_id="animeshakr/oct-fluid-segmentation",
                filename="deployment/slot2_v2l_seed123.onnx.data")
path = hf_hub_download(repo_id="animeshakr/oct-fluid-segmentation",
                       filename="deployment/slot2_v2l_seed123.onnx")

sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
x = np.random.randn(1, 1, 512, 512).astype(np.float32)
out = sess.run(None, {sess.get_inputs()[0].name: x})[0]
pred_mask = out.argmax(axis=1)[0]  # (512, 512) with values 0-3

Citation

@misc{kumar2026octseg,
  title={Attention-Guided TransUNet for Multi-Class Retinal Fluid Segmentation
         in OCT with MC Dropout Uncertainty Quantification},
  author={Kumar, Animesh A.},
  institution={Newcastle University, UK},
  year={2026},
  note={MSc Advanced Computer Science dissertation.
        Targeting OMIA 2026 Workshop at MICCAI.}
}

Related Papers

  • Bogunovic et al. (2019) β€” RETOUCH Challenge, IEEE TMI β€” benchmark definition
  • Ronneberger et al. (2015) β€” U-Net β€” segmentation baseline
  • Schlemper et al. (2019) β€” Attention U-Net β€” attention gate mechanism
  • Chen et al. (2021) β€” TransUNet β€” transformer bottleneck design
  • Rasti et al. (2022) β€” RetiFluidNet β€” current SOTA on RETOUCH

License

MIT License β€” see GitHub repository

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. πŸ™‹ Ask for provider support

Spaces using animeshakr/oct-fluid-segmentation 3