Deepfake Detection with Vision Transformer (ViT) and LoRA

Overview

This project fine-tunes a pre-trained Vision Transformer (ViT) for deepfake detection. Using transfer learning and Low-Rank Adaptation (LoRA), we adapt a google/vit-base-patch16-224-in21k model (initialized from prithivMLmods/Deep-Fake-Detector-v2-Model) to classify images as either "Real" or "Fake".

The model is trained on a balanced subset of the OpenRL/DeepFakeFace dataset, containing 12,000 images. Due to hardware constraints, this subset was carefully selected to ensure diverse representation of various generative techniques (Stable Diffusion, Inpainting, InsightFace).

Key Features

  • Efficient Fine-Tuning: Uses LoRA (Low-Rank Adaptation) to fine-tune the heavy ViT model with significantly fewer trainable parameters (r=16), making training feasible on consumer GPUs.
  • Robust Data Augmentation: Implements ColorJitter, RandomResizedCrop, and RandomHorizontalFlip to improve model generalization.
  • Video Inference Support: Includes a custom pipeline for analyzing videos frame-by-frame to detect deepfakes in motion.
  • High Confidence Thresholding: Inference logic requires a strict confidence threshold (e.g., 90%) before classifying an image as "Real"; anything below it is treated as "Fake", reducing the risk of deepfakes slipping through in security-critical contexts.
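The thresholding rule above can be sketched as follows; the helper name and the exact 0.9 cutoff are illustrative:

```python
def classify_with_threshold(real_prob: float, threshold: float = 0.9) -> str:
    """Label an image "Real" only when the model is highly confident.

    Anything below the threshold is treated as "Fake", so borderline
    deepfakes are not waved through in security-critical settings.
    """
    return "Real" if real_prob >= threshold else "Fake"

print(classify_with_threshold(0.95))  # -> Real
print(classify_with_threshold(0.80))  # -> Fake
```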

Model Architecture

Dataset

The project uses the OpenRL/DeepFakeFace (DFF) dataset. A balanced subset of 12,000 images was curated using a custom selection script (select_dataset.py).

Data Distribution (12,000 Images Total)

| Class | Count | Source / Generator | Description |
|-------|-------|--------------------|-------------|
| Real  | 6,000 | wiki dataset       | Real human faces from Wikipedia |
| Fake  | 2,000 | text2img           | Generated via Stable Diffusion v1.5 |
| Fake  | 2,000 | inpainting         | Generated via SD Inpainting |
| Fake  | 2,000 | insight            | Generated via InsightFace |

Data Splits

  • Train: 70% (8,400 images)
  • Validation: 20% (2,401 images)
  • Test: 10% (1,199 images)
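The split above can be reproduced with a simple shuffled partition; the seed and the placeholder paths are illustrative (the actual select_dataset.py logic is not shown here):

```python
import random

random.seed(42)  # assumed seed for reproducibility
image_paths = [f"img_{i:05d}.jpg" for i in range(12_000)]  # placeholder paths
random.shuffle(image_paths)

n = len(image_paths)
train = image_paths[: int(0.7 * n)]              # 8,400 images
val = image_paths[int(0.7 * n): int(0.9 * n)]    # 2,400 (2,401 in the actual split)
test = image_paths[int(0.9 * n):]                # 1,200 (1,199 in the actual split)
print(len(train), len(val), len(test))  # -> 8400 2400 1200
```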

Training Details

The model was trained using the Hugging Face Trainer API.

Hyperparameters

  • Optimizer: AdamW
  • Learning Rate: 1e-4
  • Scheduler: Cosine with Warmup (ratio 0.1)
  • Batch Size: 16 (Train) / 32 (Eval)
  • Epochs: 10
  • Weight Decay: 0.01
  • Precision: FP16 (Mixed Precision)
  • Loss Function: CrossEntropyLoss
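These hyperparameters map onto the Trainer's `TrainingArguments` roughly as follows (a sketch; `output_dir` and the fp16 guard are assumptions):

```python
import torch
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./vit-deepfake-detector",  # assumed
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=10,
    weight_decay=0.01,
    fp16=torch.cuda.is_available(),        # mixed precision on GPU
)
```

AdamW is the Trainer's default optimizer, and classification heads in transformers use CrossEntropyLoss automatically, so neither needs to be configured explicitly.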

Data Augmentation

To prevent overfitting, the following transformations are applied during training:

  1. RandomResizedCrop: crops 80–100% of the original image area, then resizes to 224x224.
  2. RandomHorizontalFlip: Probability 0.5.
  3. ColorJitter: Brightness (±20%), Contrast (±20%), Saturation (±20%), Hue (±10%).
  4. Normalization: Standard ImageNet mean and std.

Performance

  • Training curves (train & validation loss): vit-training_curves.png
  • Confusion matrix: vit-confusion_matrix.png
  • Test Accuracy: ~85.32%
  • Test F1-Score: ~85.31%

Inference

1. Using Hugging Face Pipeline

# Use a pipeline as a high-level helper
from transformers import pipeline

# Load the model
pipe = pipeline("image-classification", model="shunda012/vit-deepfake-detector")

# Predict on an image
result = pipe("path_to_image.jpg")
print(result)

2. Using PyTorch

from transformers import ViTForImageClassification, ViTImageProcessor
import torch
from PIL import Image

# Load the fine-tuned model & processor
model = ViTForImageClassification.from_pretrained("shunda012/vit-deepfake-detector")
processor = ViTImageProcessor.from_pretrained("shunda012/vit-deepfake-detector")

# Load and preprocess the image
image = Image.open("path_to_image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# Predict
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    probs = torch.softmax(logits, dim=-1)
    predicted_class = torch.argmax(probs, dim=-1).item()

# Print probabilities for each class using the model's label mapping
for idx, prob in enumerate(probs[0]):
    print(f"{model.config.id2label[idx]} Prob: {prob:.2f}")

# Map class index to label
label = model.config.id2label[predicted_class]
print(f"Predicted Label: {label}")

Limitations

  • Hardware and Data Limits: Because of hardware limits, the model was trained on only 12,000 images, which is a small fraction of the original 120k dataset. Scaling this up would likely improve accuracy.
  • Generalization: While it covers three generation methods (SD, Inpainting, InsightFace), it may struggle with newer, unseen generative models (e.g., Flux, Midjourney v6).
  • Resolution: The model operates at 224x224 resolution. High-quality deepfakes might lose artifacts when downscaled.

Future Work

  • Upgrade to Video Detection: Extend from per-frame image classification to temporal video analysis, which can catch unnatural motion such as irregular blinking or background flicker.

  • Combine Audio and Visual Checks: To catch the most advanced deepfakes, the system needs to listen as well as look. By combining image detection with audio analysis (Multimodal Learning), the system can flag a video if a person's voice doesn't perfectly match their lip movements.

  • Train on the Newest AI: The training data must be expanded to include the latest and most realistic AI generators, such as OpenAI's Sora (for video) and Stable Diffusion 3 (for images).

  • Defend Against Hackers: Future versions should be trained using "adversarial examples." These are fake images intentionally injected with invisible digital noise designed to trick AI detectors. Training against these will make the system much harder to fool.
