# Deepfake Detection with Vision Transformer (ViT) and LoRA

## Overview
This project focuses on fine-tuning a pre-trained Vision Transformer (ViT) model for the task of deepfake detection. Leveraging the power of Transfer Learning and Low-Rank Adaptation (LoRA), we adapt a google/vit-base-patch16-224-in21k model (specifically initialized from prithivMLmods/Deep-Fake-Detector-v2-Model) to classify images as either "Real" or "Fake".
The model is trained on a balanced subset of the OpenRL/DeepFakeFace dataset, containing 12,000 images. Due to hardware constraints, this subset was carefully selected to ensure diverse representation of various generative techniques (Stable Diffusion, Inpainting, InsightFace).
## Key Features

- Efficient Fine-Tuning: Uses LoRA (Low-Rank Adaptation) to fine-tune the heavy ViT model with significantly fewer trainable parameters (r=16), making training feasible on consumer GPUs.
- Robust Data Augmentation: Implements ColorJitter, RandomResizedCrop, and RandomHorizontalFlip to improve model generalization.
- Video Inference Support: Includes a custom pipeline for analyzing videos frame-by-frame to detect deepfakes in motion.
- High Confidence Thresholding: Inference logic implements a strict threshold (e.g., 90%) for classifying an image as "Real" to minimize false negatives in security-critical contexts.
## Model Architecture

- Base Model: google/vit-base-patch16-224-in21k
- Pre-trained Weights: Sourced from prithivMLmods/Deep-Fake-Detector-v2-Model
- Architecture Type: Vision Transformer (ViT)
- Fine-tuning Method: PEFT (Parameter-Efficient Fine-Tuning) with LoRA
  - Target Modules: query, value
  - Rank (r): 16
  - Alpha: 32
  - Dropout: 0.1
## Dataset
The project uses the DeepFakeFace (DFF) dataset (OpenRL/DeepFakeFace). A balanced subset of 12,000 images was curated using a custom selection script (select_dataset.py).
### Data Distribution (12,000 Images Total)

| Class | Count | Source / Generator | Description |
|---|---|---|---|
| Real | 6,000 | wiki | Real human faces from Wikipedia |
| Fake | 2,000 | text2img | Generated via Stable Diffusion v1.5 |
| Fake | 2,000 | inpainting | Generated via SD Inpainting |
| Fake | 2,000 | insight | Generated via InsightFace |
### Data Splits
- Train: 70% (8,400 images)
- Validation: 20% (2,401 images)
- Test: 10% (1,199 images)
## Training Details
The model was trained using the Hugging Face Trainer API.
### Hyperparameters
- Optimizer: AdamW
- Learning Rate: 1e-4
- Scheduler: Cosine with Warmup (ratio 0.1)
- Batch Size: 16 (Train) / 32 (Eval)
- Epochs: 10
- Weight Decay: 0.01
- Precision: FP16 (Mixed Precision)
- Loss Function: CrossEntropyLoss
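These settings map onto `TrainingArguments` roughly as follows (a sketch, not the project's actual script: `output_dir` is a placeholder, and `fp16` is guarded so the fragment also loads on CPU). AdamW and CrossEntropyLoss are the Trainer defaults for classification, so they need no explicit configuration.

```python
import torch
from transformers import TrainingArguments

# Hyperparameters from the list above; output_dir is a placeholder
training_args = TrainingArguments(
    output_dir="vit-deepfake-detector",
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=10,
    weight_decay=0.01,
    fp16=torch.cuda.is_available(),  # mixed precision when a GPU is present
)
```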
### Data Augmentation
To prevent overfitting, the following transformations are applied during training:
- RandomResizedCrop: Crops a random region covering 80–100% of the image area, then resizes it to the model's input size.
- RandomHorizontalFlip: Probability 0.5.
- ColorJitter: Brightness (±20%), Contrast (±20%), Saturation (±20%), Hue (±10%).
- Normalization: Standard ImageNet mean and std.
## Performance

## Inference
### 1. Using Hugging Face Pipeline

```python
# Use a pipeline as a high-level helper
from transformers import pipeline

# Load the model
pipe = pipeline("image-classification", model="shunda012/vit-deepfake-detector")

# Predict on an image
result = pipe("path_to_image.jpg")
print(result)
```
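The custom video pipeline mentioned in Key Features is not reproduced in this card. As a rough illustration of the frame-by-frame idea, per-frame fake probabilities could be aggregated into a single verdict like this; the function name and both thresholds are hypothetical:

```python
def aggregate_frame_predictions(frame_fake_probs, fake_threshold=0.5, min_fake_ratio=0.2):
    """Label a video 'Fake' when a sufficient fraction of sampled frames look fake,
    so a single spurious frame alone does not flag the whole clip."""
    if not frame_fake_probs:
        raise ValueError("no frames analyzed")
    flagged = sum(p >= fake_threshold for p in frame_fake_probs)
    return "Fake" if flagged / len(frame_fake_probs) >= min_fake_ratio else "Real"

# Two of five sampled frames look fake -> the clip is flagged
print(aggregate_frame_predictions([0.1, 0.2, 0.9, 0.8, 0.1]))  # Fake
```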
### 2. Using PyTorch

```python
from transformers import ViTForImageClassification, ViTImageProcessor
import torch
from PIL import Image

# Load base model & processor
model = ViTForImageClassification.from_pretrained("shunda012/vit-deepfake-detector")
processor = ViTImageProcessor.from_pretrained("shunda012/vit-deepfake-detector")

# Load and preprocess the image
image = Image.open("path_to_image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# Predict
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits

probs = torch.softmax(logits, dim=-1)
predicted_class = torch.argmax(probs, dim=-1).item()

# Print probabilities for each class
print(f"Fake Prob: {probs[0][0]:.2f}, Real Prob: {probs[0][1]:.2f}")

# Map class index to label
label = model.config.id2label[predicted_class]
print(f"Predicted Label: {label}")
```
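The strict "Real" threshold described in Key Features can be layered on top of these probabilities. The helper below is a hypothetical illustration; the class order (index 0 = Fake, index 1 = Real) and the 0.90 cutoff follow the descriptions earlier in this card:

```python
# Hypothetical helper implementing the high-confidence "Real" threshold
def classify_with_threshold(probs, real_index=1, threshold=0.90):
    """Return 'Real' only when the Real probability clears the threshold;
    anything less confident is treated as 'Fake' in security-critical use."""
    return "Real" if probs[real_index] >= threshold else "Fake"

print(classify_with_threshold([0.15, 0.85]))  # 0.85 < 0.90  -> Fake
print(classify_with_threshold([0.02, 0.98]))  # 0.98 >= 0.90 -> Real
```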
## Limitations
- Hardware and Data Limits: Because of hardware limits, the model was trained on only 12,000 images, which is a small fraction of the original 120k dataset. Scaling this up would likely improve accuracy.
- Generalization: While it covers three generation methods (SD, Inpainting, InsightFace), it may struggle with newer, unseen generative models (e.g., Flux, Midjourney v6).
- Resolution: The model operates at 224x224 resolution. High-quality deepfakes might lose artifacts when downscaled.
## Future Work

- Upgrade to Video Detection: Instead of just looking at still images, the next step is to upgrade the model to analyze videos frame-by-frame. This will help catch unnatural movements, like weird blinking or strange background flickering.
- Combine Audio and Visual Checks: To catch the most advanced deepfakes, the system needs to listen as well as look. By combining image detection with audio analysis (Multimodal Learning), the system can flag a video if a person's voice doesn't match their lip movements.
- Train on the Newest AI: The training data must be expanded to include the latest and most realistic AI generators, such as OpenAI's Sora (for video) and Stable Diffusion 3 (for images).
- Defend Against Hackers: Future versions should be trained using "adversarial examples." These are fake images intentionally injected with invisible digital noise designed to trick AI detectors. Training against these will make the system much harder to fool.
