TabDPT Checkpoints

Pre-trained TabDPT model weights from three training runs, each using a different random seed. Every checkpoint is taken from epoch 2040 of production training.

Files

| File | Training Seed |
|---|---|
| production_seed42.safetensors | 42 |
| production_seed123.safetensors | 123 |
| production_seed456.safetensors | 456 |

Model Architecture

  • Embedding size: 512
  • Attention heads: 8
  • Layers: 12
  • Hidden factor: 2
  • Max features: 100
  • Max classes: 10
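
As a rough sanity check on the numbers above, the sketch below estimates the parameter count of the Transformer stack. It assumes a standard encoder layer (four attention projections, a two-layer feed-forward block whose width is hidden factor × embedding size, and two layer norms). TabDPT's actual layers may differ, and input encoders and output heads are ignored. The names `TabDPTShape` and `approx_encoder_params` are hypothetical and not part of the tabdpt package.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TabDPTShape:
    """Hyperparameters from the list above (hypothetical container)."""
    embed_size: int = 512
    num_heads: int = 8
    num_layers: int = 12
    hidden_factor: int = 2

    @property
    def head_dim(self) -> int:
        # Embedding size is split evenly across attention heads
        return self.embed_size // self.num_heads

    @property
    def ffn_size(self) -> int:
        # Assumption: feed-forward width = hidden_factor * embed_size
        return self.hidden_factor * self.embed_size


def approx_encoder_params(shape: TabDPTShape) -> int:
    """Estimate weights + biases for a stack of standard encoder layers."""
    d, h = shape.embed_size, shape.ffn_size
    attention = 4 * (d * d + d)               # Q, K, V and output projections
    feed_forward = (d * h + h) + (h * d + d)  # two linear layers
    layer_norms = 2 * (2 * d)                 # scale and shift, twice per layer
    return shape.num_layers * (attention + feed_forward + layer_norms)


shape = TabDPTShape()
params = approx_encoder_params(shape)
print(f"head dim: {shape.head_dim}, FFN width: {shape.ffn_size}")
print(f"~{params / 1e6:.1f}M parameters, ~{params * 4 / 1e6:.0f} MB as float32")
```

Under these assumptions the estimate comes to roughly 25M parameters, or about 100 MB at 4 bytes per parameter, which is in the same range as the SafeTensors file size reported in the Format section.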

Benchmark Results

Classification: Breast Cancer (binary, 30 features)

| Checkpoint | Accuracy | Ensemble Accuracy |
|---|---|---|
| seed42 | 99.4% | 99.4% |
| seed123 | 98.8% | 98.8% |
| seed456 | 98.2% | 98.2% |
| HF default (Layer6/TabDPT) | — | 99.4% |

Classification: Wine (3-class, 13 features)

| Checkpoint | Accuracy | Ensemble Accuracy |
|---|---|---|
| seed42 | 100% | 100% |
| seed123 | 100% | 100% |
| seed456 | 100% | 100% |
| HF default (Layer6/TabDPT) | — | 100% |

Regression: Diabetes (10 features)

| Checkpoint | MSE | Correlation |
|---|---|---|
| seed42 | 2618.6 | 0.718 |
| seed123 | 2655.1 | 0.713 |
| seed456 | 2795.5 | 0.701 |
| HF default (Layer6/TabDPT) | 2673.1 | 0.711 |
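
To put the seed-to-seed differences in context, the short sketch below summarizes the regression numbers from the table above using only the Python standard library. The values are copied from the table; the summary statistics are an illustration and not part of the original evaluation.

```python
from statistics import mean, stdev

# Values copied from the Diabetes regression table above
mse = {"seed42": 2618.6, "seed123": 2655.1, "seed456": 2795.5}
corr = {"seed42": 0.718, "seed123": 0.713, "seed456": 0.701}
hf_default_mse = 2673.1

mse_values = list(mse.values())
mse_mean = mean(mse_values)
mse_std = stdev(mse_values)  # sample standard deviation (n - 1)
corr_mean = mean(list(corr.values()))

# Does the HF default fall inside the range spanned by the three seeds?
within_range = min(mse_values) <= hf_default_mse <= max(mse_values)

print(f"MSE across seeds: {mse_mean:.1f} +/- {mse_std:.1f}")
print(f"Mean correlation: {corr_mean:.3f}")
print(f"HF default MSE within seed range: {within_range}")
```

The default model's MSE lies between the best and worst seed, so on this dataset the gap to the default is no larger than the spread between seeds.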

Training Stats (from checkpoint metadata)

| Metric | seed42 | seed123 | seed456 |
|---|---|---|---|
| CC18 Accuracy | 0.877 | 0.878 | 0.879 |
| CC18 F1 | 0.870 | 0.872 | 0.873 |
| CC18 AUC | 0.927 | 0.927 | 0.928 |
| CTR Correlation | 0.830 | 0.830 | 0.827 |
| CTR R² | 0.726 | 0.730 | 0.725 |

Format

These checkpoints were converted from PyTorch Lightning .ckpt files (which include optimizer state, ~295 MB each) to SafeTensors format (model weights only, ~103 MB each), the same format used by the official Layer6/TabDPT release. The tabdpt package loads SafeTensors files natively via the model_weight_path argument, so no extra conversion is needed.
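
For readers who want to reproduce this kind of conversion on their own Lightning checkpoints, here is a minimal sketch. It is not the script used to produce these files. It assumes the usual Lightning layout, where weights live under the "state_dict" key, and the `prefix` parameter and the input file name are hypothetical. `extract_model_weights` and `convert_checkpoint` are illustrative names.

```python
from typing import Any, Dict


def extract_model_weights(checkpoint: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Keep only the model weights from a Lightning-style checkpoint dict.

    Lightning stores weights under "state_dict"; optimizer state and other
    bookkeeping live under separate keys and are dropped here.
    `prefix` is an assumption: set it if the weights are nested under an
    attribute name such as "model." and should be stored without it.
    """
    state_dict = checkpoint["state_dict"]
    return {
        (key[len(prefix):] if prefix and key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def convert_checkpoint(ckpt_path: str, out_path: str, prefix: str = "") -> None:
    # Imported lazily so the helper above works without these packages
    import torch
    from safetensors.torch import save_file

    # weights_only=False because Lightning checkpoints contain non-tensor
    # objects; only load files you trust
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    weights = extract_model_weights(checkpoint, prefix)
    # safetensors requires contiguous tensors
    save_file({k: v.contiguous() for k, v in weights.items()}, out_path)


# Example (hypothetical input file name):
# convert_checkpoint("production_seed42.ckpt", "production_seed42.safetensors")
```

Dropping everything except the state dict is what accounts for the size reduction, since optimizer state is stored alongside the weights in the .ckpt file.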

Usage

from tabdpt import TabDPTClassifier
from huggingface_hub import hf_hub_download

# Download once (cached afterwards)
path = hf_hub_download("dwahdany/TabDPT", "production_seed42.safetensors")

# Use exactly like the default model
clf = TabDPTClassifier(model_weight_path=path)
clf.fit(X_train, y_train)
preds = clf.predict(X_test)

TabDPTRegressor works the same way: pass the checkpoint path as model_weight_path.
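
Because three independently seeded checkpoints are available, one option is to combine their predictions. The sketch below does this with a simple majority vote over the output of `predict`. It is an illustration, not a procedure described by the checkpoint author, and the helper names `majority_vote` and `predict_with_all_seeds` are hypothetical. It reuses only the calls shown in the usage example above.

```python
from collections import Counter
from typing import Hashable, List, Sequence

SEEDS = (42, 123, 456)


def majority_vote(per_model_preds: Sequence[Sequence[Hashable]]) -> List[Hashable]:
    """Combine predictions from several models, one vote per model per sample.

    Ties go to the label predicted by the earliest model in the list.
    """
    return [Counter(votes).most_common(1)[0][0] for votes in zip(*per_model_preds)]


def predict_with_all_seeds(X_train, y_train, X_test):
    # Imported lazily; requires tabdpt and huggingface_hub to be installed
    from huggingface_hub import hf_hub_download
    from tabdpt import TabDPTClassifier

    all_preds = []
    for seed in SEEDS:
        # Each file is downloaded once and cached afterwards
        path = hf_hub_download("dwahdany/TabDPT", f"production_seed{seed}.safetensors")
        clf = TabDPTClassifier(model_weight_path=path)
        clf.fit(X_train, y_train)
        all_preds.append(clf.predict(X_test))
    return majority_vote(all_preds)
```

With three voters, a tie can only occur when all three checkpoints disagree, in which case the seed42 prediction is kept.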
