Attention-Guided TransUNet for Multi-Class Retinal Fluid Segmentation
Author: Animesh Kumar | Newcastle University MSc Advanced Computer Science 2025β26
Target Venue: OMIA 2026 Workshop at MICCAI + medRxiv preprint
Framework: PyTorch | Compute: Google Colab H100
Live Demo: HuggingFace Space | Code: GitHub | DOI: Zenodo
Clinical Motivation
Retinal fluid accumulation is the primary biomarker for three major vision-threatening diseases:
| Fluid Type | Full Name | Clinical Significance |
|---|---|---|
| IRF | Intraretinal Fluid | Active inflammation in diabetic macular oedema (DME) and wet AMD. Requires anti-VEGF injection within days of detection. |
| SRF | Subretinal Fluid | Associated with neovascular AMD and central serous chorioretinopathy. Volume determines treatment frequency. |
| PED | Pigment Epithelial Detachment | Elevation of the retinal pigment epithelium β a key AMD progression marker. |
Manual segmentation by expert graders takes 20β40 minutes per volume with up to 15% inter-grader variability on small fluid pockets. This model provides calibrated, uncertainty-aware segmentation masks to reduce clinician workload and improve treatment monitoring consistency.
Architecture
Two Models β Dual Ensemble
| Component | V2S (Small) | V2L (Large) |
|---|---|---|
| Encoder | EfficientNetV2S | EfficientNetV2L |
| Encoder channels | s1=24, s2=48, s3=64, s4=160, bot=256 | s1=32, s2=64, s3=96, s4=192, bot=640 |
| Transformer d_model | 256 | 512 |
| Attention heads | 16 | 16 |
| Transformer layers | 2 | 2 |
| Total parameters | ~22M | ~127M |
| Phase B val Dice | 0.7443 (seed=42) | 0.7913 (seed=123) |
Novel Contributions
UCUS β Uncertainty-Weighted Clinical Urgency Score
Combines volume score, foveal multiplier, boundary uncertainty and uncertainty discount into a single triage band: Monitor / Review / Urgent.Dual Uncertainty Estimation
MC Dropout variance (20 forward passes) combined with inter-model disagreement between V2S and V2L predictions.Source-Adaptive BatchNorm (SA-BN)
Separate batch norm statistics per scanner source (DUKE, AROI, UMN-AMD, UMN-DME, OPTIMA). Enables cross-scanner domain adaptation without retraining.Multi-Source Four-Dataset Evaluation
Simultaneous evaluation across 4 independent acquisition sources with per-source Dice breakdown.
Architecture Data Flow
Input OCT B-scan (1Γ512Γ512)
β
βΌ
EfficientNetV2L Encoder (5 stages, ImageNet pretrained)
β skip connections s1-s4
βΌ
Transformer Bottleneck (2Γ MHA, d_model=512, 16 heads, learnable pos encoding)
β
βΌ
Attention-Gate Decoder (4 levels, SA-BatchNorm per decoder block)
β
βΌ
MC Dropout (p=0.3) + Output Head (4 classes)
β
βββ Segmentation Mask (BG/IRF/SRF/PED)
βββ Uncertainty Heatmap (pixel-wise std)
βββ UCUS Score (clinical triage)
Datasets
| Dataset | Volumes | Annotated Classes | Disease | Scanner |
|---|---|---|---|---|
| DUKE DME | 10 subjects, 110 B-scans | IRF only | DME | Spectralis |
| AROI | 24 patients | IRF + SRF + PED | AMD | Zeiss Cirrus |
| UMN AMD | 24 subjects | SRF (binary) | AMD | Spectralis |
| UMN DME | 29 subjects | IRF (binary) | DME | Spectralis |
Unified label space: 0=Background, 1=IRF, 2=SRF, 3=PED
Split after preprocessing:
- Train: 4983 fluid-only slices
- Validation: 552 slices
- Test: 503 slices
Training Protocol
Phase A β Decoder Only (5 epochs)
- Encoder frozen at ImageNet weights
- LR = 1e-3, Adam, batch size 8
- Target: val_dice > 0.50
Phase B β Full Fine-tuning (25 epochs)
- Encoder blocks 3-5 unfrozen
- LR = 1e-4, WarmupCosineDecay (5 epoch warmup)
- Batch size 4, early stopping patience=7
- Loss: Dice + 0.5 Γ CrossEntropy
Seeds Trained
V2S: 42, 123, 2024 | V2L: 42, 123, 2024
Results
Multi-Seed Validation Dice (mean Β± std across seeds 42/123/2024)
| Model | IRF | SRF | PED | Mean Fluid |
|---|---|---|---|---|
| V2S | 0.8658 Β± 0.0067 | 0.8272 Β± 0.0046 | 0.5184 Β± 0.0093 | 0.7371 Β± 0.0052 |
| V2L | 0.9158 Β± 0.0034 | 0.8560 Β± 0.0034 | 0.5811 Β± 0.0175 | 0.7843 Β± 0.0058 |
Test Set Results (503 slices, 4 sources)
| Metric | Mean | Std |
|---|---|---|
| dice_IRF | 0.2043 | Β±0.3482 |
| dice_SRF | 0.1712 | Β±0.2359 |
| dice_PED | 0.4463 | Β±0.4611 |
| dice_mean_fluid | 0.2739 | Β±0.2161 |
Note: Low test Dice is driven by domain shift across 4 independent sources. V2L alone achieves 0.4511 on the ablation β the multi-source test set is a hard benchmark.
Per-Source Breakdown
| Source | IRF | SRF | PED | Mean Fluid |
|---|---|---|---|---|
| AROI | 0.054 | 0.299 | 0.144 | 0.166 |
| DUKE | 0.071 | 0.000 | 0.902 | 0.324 |
| UMN | 0.381 | 0.176 | 0.409 | 0.322 |
DUKE SRF = 0.000 is expected β DUKE dataset contains only IRF annotations.
DUKE PED = 0.902 shows the model correctly detects PED on DUKE scans.
Clinical Safety Metrics
| Metric | Value | Significance |
|---|---|---|
| Inter-grader human ceiling | 0.9030 | Upper bound for automated systems |
| Model match threshold (95%) | 0.8579 | Target for clinical deployment |
| Uncertainty ratio | 1.34Γ | p=3.77e-05 β |
| SRF volume correlation | r=0.778 | p=6.33e-04 β |
| PED volume correlation | r=0.841 | p=8.64e-05 β |
| Total fluid correlation | r=0.562 | p=2.93e-02 β |
The uncertainty ratio finding means the model is 1.34Γ more uncertain at pixels where human experts disagree β a statistically significant result confirming that model uncertainty correlates with genuine clinical ambiguity.
Ablation Study
| Variant | Mean Dice | Std |
|---|---|---|
| V2S only (no MC, no TTA) | 0.338 | Β±0.332 |
| V2S + MC Dropout | 0.141 | Β±0.144 |
| V2S + TTA | 0.121 | Β±0.115 |
| V2L only (no MC, no TTA) | 0.449 | Β±0.304 |
| V2S + V2L ensemble | 0.415 | Β±0.318 |
| V2S + V2L + MC Dropout | 0.279 | Β±0.219 |
| Full (V2S+V2L+MC+TTA) | 0.293 | Β±0.218 |
INT8 Quantisation (Phase 5B)
| Model | FP32 | INT8 | Compression | Method |
|---|---|---|---|---|
| V2L | 510 MB | 132 MB | 3.9Γ | Per-tensor symmetric int8 |
| V2S | 91 MB | 24 MB | 3.8Γ | Per-tensor symmetric int8 |
Files in This Repository
| File | Size | Description |
|---|---|---|
ckpt_phaseB_V2L_s123.pth |
1526 MB | Best V2L checkpoint β val_dice=0.7913, epoch=25 |
ckpt_phaseB_V2L_s2024.pth |
1526 MB | Second V2L checkpoint β val_dice=0.7841, epoch=24 |
ckpt_phaseB_V2S_s42.pth |
271 MB | Best V2S checkpoint β val_dice=0.7443, epoch=34 |
ckpt_phaseB_V2L_s123_int8.pth |
132 MB | INT8 quantised V2L (3.9Γ compression) |
ckpt_phaseB_V2S_s42_int8.pth |
24 MB | INT8 quantised V2S (3.8Γ compression) |
deployment/slot1_v2l_seed2024.onnx |
β | ONNX export β ready for TensorRT/OpenVINO |
deployment/slot2_v2l_seed123.onnx |
β | ONNX export β ready for TensorRT/OpenVINO |
demo_results.json |
87 MB | 20 precomputed demo samples (5 per source) |
Usage
Load Best Model (V2L seed=123)
from huggingface_hub import hf_hub_download
import torch
path = hf_hub_download(
repo_id="animeshakr/oct-fluid-segmentation",
filename="ckpt_phaseB_V2L_s123.pth"
)
ck = torch.load(path, map_location="cpu")
print(f"val_dice: {ck['val_dice']:.4f}") # 0.7913
print(f"epoch: {ck['epoch']}") # 25
Load INT8 Quantised Model (Edge Deployment)
from huggingface_hub import hf_hub_download
import torch
path = hf_hub_download(
repo_id="animeshakr/oct-fluid-segmentation",
filename="ckpt_phaseB_V2L_s123_int8.pth"
)
qsd = torch.load(path, map_location="cpu")
# Dequantise at inference: weight = weight_int8 * scale
# 132MB vs 510MB original β 3.9x smaller
ONNX Inference
from huggingface_hub import hf_hub_download
import onnxruntime as ort
import numpy as np
# Download both files β .onnx.data must be downloaded first
hf_hub_download(repo_id="animeshakr/oct-fluid-segmentation",
filename="deployment/slot2_v2l_seed123.onnx.data")
path = hf_hub_download(repo_id="animeshakr/oct-fluid-segmentation",
filename="deployment/slot2_v2l_seed123.onnx")
sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
x = np.random.randn(1, 1, 512, 512).astype(np.float32)
out = sess.run(None, {sess.get_inputs()[0].name: x})[0]
pred_mask = out.argmax(axis=1)[0] # (512, 512) with values 0-3
Citation
@misc{kumar2026octseg,
title={Attention-Guided TransUNet for Multi-Class Retinal Fluid Segmentation
in OCT with MC Dropout Uncertainty Quantification},
author={Kumar, Animesh A.},
institution={Newcastle University, UK},
year={2026},
note={MSc Advanced Computer Science dissertation.
Targeting OMIA 2026 Workshop at MICCAI.}
}
Related Papers
- Bogunovic et al. (2019) β RETOUCH Challenge, IEEE TMI β benchmark definition
- Ronneberger et al. (2015) β U-Net β segmentation baseline
- Schlemper et al. (2019) β Attention U-Net β attention gate mechanism
- Chen et al. (2021) β TransUNet β transformer bottleneck design
- Rasti et al. (2022) β RetiFluidNet β current SOTA on RETOUCH
License
MIT License β see GitHub repository