pray committed on
Commit
738d527
·
1 Parent(s): 4fcf3b3

Support GPU and parallel processing

Browse files
scripts/prepare_database.py CHANGED
@@ -12,7 +12,7 @@ import argparse
12
  import logging
13
  from typing import List, Tuple
14
  from tqdm import tqdm
15
- import time
16
 
17
  import sys
18
  from pathlib import Path
@@ -59,12 +59,46 @@ def download_image(url: str, save_path: str, timeout: int = 10) -> bool:
59
  return False
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def process_dataset(
63
  csv_path: str,
64
  output_dir: str,
65
  detector_type: str = "retinaface",
66
  max_images: int = None,
67
- skip_existing: bool = True
 
 
68
  ) -> Tuple[List[str], List[str], List[str], List[np.ndarray]]:
69
  """
70
  Process dataset: download images and detect faces
@@ -75,6 +109,8 @@ def process_dataset(
75
  detector_type: "retinaface" or "haarcascade"
76
  max_images: Maximum number of images to process
77
  skip_existing: Skip if aligned face already exists
 
 
78
 
79
  Returns:
80
  Tuple of (names, image_paths, original_paths, aligned_faces)
@@ -101,12 +137,11 @@ def process_dataset(
101
 
102
  logger.info(f"Processing {len(rows)} images with {detector_type} detector...")
103
 
104
- names = []
105
- image_paths = []
106
- original_paths = []
107
- aligned_faces = []
108
 
109
- for row in tqdm(rows, desc="Processing images"):
110
  name = row['name']
111
  image_id = row['image_id']
112
  url = row['url']
@@ -121,23 +156,45 @@ def process_dataset(
121
  cropped_filename = f"{name}_{image_id}_crop.jpg".replace(" ", "_")
122
  cropped_path = os.path.join(cropped_dir, cropped_filename)
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  # Skip if both aligned and cropped faces already exist
125
- if skip_existing and os.path.exists(aligned_path) and os.path.exists(cropped_path) and os.path.exists(download_path):
126
  aligned_face = cv2.imread(aligned_path)
127
  if aligned_face is not None:
128
  names.append(name)
129
- image_paths.append(cropped_path) # Store path to cropped face
130
- original_paths.append(download_path) # Store path to original download
131
  aligned_faces.append(aligned_face)
132
  continue
133
 
134
- # Download image if not exists
135
  if not os.path.exists(download_path):
136
- if not download_image(url, download_path):
137
- continue
138
- time.sleep(0.1) # Rate limiting
139
 
140
- # Load image
141
  image = cv2.imread(download_path)
142
  if image is None:
143
  continue
@@ -148,13 +205,13 @@ def process_dataset(
148
  continue
149
 
150
  # Save both versions
151
- cv2.imwrite(aligned_path, aligned_face) # For embedding extraction
152
- cv2.imwrite(cropped_path, original_crop) # For display in gallery
153
 
154
  # Store results
155
  names.append(name)
156
- image_paths.append(cropped_path) # Store path to original crop for gallery
157
- original_paths.append(download_path) # Store path to original download
158
  aligned_faces.append(aligned_face)
159
 
160
  logger.info(f"Successfully processed {len(names)} faces")
@@ -164,7 +221,10 @@ def process_dataset(
164
 
165
  def extract_embeddings(
166
  aligned_faces: List[np.ndarray],
167
- model_path: str
 
 
 
168
  ) -> np.ndarray:
169
  """
170
  Extract face embeddings using MobileFaceNet
@@ -172,21 +232,36 @@ def extract_embeddings(
172
  Args:
173
  aligned_faces: List of aligned face images
174
  model_path: Path to ONNX model
 
 
 
175
 
176
  Returns:
177
  Numpy array of embeddings (N x embedding_dim)
178
  """
179
- logger.info("Extracting embeddings...")
180
-
181
  # Initialize embedding extractor
182
- extractor = FaceEmbeddingExtractor(model_path)
183
-
184
- embeddings = []
185
- for aligned_face in tqdm(aligned_faces, desc="Extracting embeddings"):
186
- embedding = extractor.extract_embedding(aligned_face)
187
- embeddings.append(embedding)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
- embeddings_array = np.array(embeddings, dtype=np.float32)
190
  logger.info(f"Extracted {len(embeddings_array)} embeddings with shape {embeddings_array.shape}")
191
 
192
  return embeddings_array
@@ -302,6 +377,37 @@ def main():
302
  default=False,
303
  help='Reset database by dropping existing collection and creating new one'
304
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
  args = parser.parse_args()
307
 
@@ -312,13 +418,23 @@ def main():
312
  # Create necessary directories
313
  Config.create_directories()
314
 
 
 
 
 
 
 
 
 
315
  # Process dataset
316
  names, image_paths, original_paths, aligned_faces = process_dataset(
317
  csv_path=args.csv,
318
  output_dir=args.output_dir,
319
  detector_type=args.detector,
320
  max_images=args.max_images,
321
- skip_existing=args.skip_existing
 
 
322
  )
323
 
324
  if len(names) == 0:
@@ -326,7 +442,13 @@ def main():
326
  return
327
 
328
  # Extract embeddings
329
- embeddings = extract_embeddings(aligned_faces, model_path=args.model)
 
 
 
 
 
 
330
 
331
  # Populate database
332
  populate_database(
 
12
  import logging
13
  from typing import List, Tuple
14
  from tqdm import tqdm
15
+ from concurrent.futures import ThreadPoolExecutor, as_completed
16
 
17
  import sys
18
  from pathlib import Path
 
59
  return False
60
 
61
 
62
def download_images_parallel(download_tasks: List[Tuple[str, str]], max_workers: int = 10) -> List[bool]:
    """
    Download multiple images in parallel

    Every task is submitted to a thread pool that calls ``download_image``.
    The returned list is index-aligned with ``download_tasks`` even though
    downloads complete in arbitrary order.

    Args:
        download_tasks: List of (url, save_path) tuples
        max_workers: Maximum number of parallel downloads. Values below 1
            are clamped to 1 instead of raising inside the executor.

    Returns:
        List of success status for each download. An entry is False when
        the download reported failure or its task raised an exception.
    """
    # Nothing to do: avoid creating an executor at all.
    if not download_tasks:
        return []

    # ThreadPoolExecutor raises ValueError for max_workers <= 0, which is
    # reachable from an unvalidated command-line integer. Fall back to a
    # single worker so the downloads still happen. None is passed through
    # untouched so the executor can pick its own default.
    if max_workers is not None:
        max_workers = max(1, max_workers)

    results = [False] * len(download_tasks)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all download tasks, remembering each task's original index
        future_to_idx = {
            executor.submit(download_image, url, save_path): idx
            for idx, (url, save_path) in enumerate(download_tasks)
        }

        # Collect results as they complete
        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                results[idx] = future.result()
            except Exception as e:
                # Best-effort: one failed download must not abort the rest.
                logger.debug(f"Download task {idx} failed: {e}")
                results[idx] = False

    return results
92
+
93
+
94
  def process_dataset(
95
  csv_path: str,
96
  output_dir: str,
97
  detector_type: str = "retinaface",
98
  max_images: int = None,
99
+ skip_existing: bool = True,
100
+ parallel: bool = True,
101
+ max_workers: int = 20
102
  ) -> Tuple[List[str], List[str], List[str], List[np.ndarray]]:
103
  """
104
  Process dataset: download images and detect faces
 
109
  detector_type: "retinaface" or "haarcascade"
110
  max_images: Maximum number of images to process
111
  skip_existing: Skip if aligned face already exists
112
+ parallel: Use parallel processing for downloads (default: True)
113
+ max_workers: Number of parallel workers for downloads (default: 20)
114
 
115
  Returns:
116
  Tuple of (names, image_paths, original_paths, aligned_faces)
 
137
 
138
  logger.info(f"Processing {len(rows)} images with {detector_type} detector...")
139
 
140
+ # Step 1: Prepare file paths and identify images to download
141
+ download_tasks = []
142
+ row_info = [] # Store (row, download_path, aligned_path, cropped_path)
 
143
 
144
+ for row in rows:
145
  name = row['name']
146
  image_id = row['image_id']
147
  url = row['url']
 
156
  cropped_filename = f"{name}_{image_id}_crop.jpg".replace(" ", "_")
157
  cropped_path = os.path.join(cropped_dir, cropped_filename)
158
 
159
+ row_info.append((row, download_path, aligned_path, cropped_path))
160
+
161
+ # Add to download queue if image doesn't exist
162
+ if not os.path.exists(download_path):
163
+ download_tasks.append((url, download_path))
164
+
165
+ # Step 2: Download missing images (parallel or sequential)
166
+ if download_tasks:
167
+ if parallel:
168
+ logger.info(f"Downloading {len(download_tasks)} images in parallel (workers={max_workers})...")
169
+ download_images_parallel(download_tasks, max_workers=max_workers)
170
+ else:
171
+ logger.info(f"Downloading {len(download_tasks)} images sequentially...")
172
+ for url, save_path in tqdm(download_tasks, desc="Downloading images"):
173
+ download_image(url, save_path)
174
+
175
+ # Step 3: Process faces (detection and alignment)
176
+ names = []
177
+ image_paths = []
178
+ original_paths = []
179
+ aligned_faces = []
180
+
181
+ for row, download_path, aligned_path, cropped_path in tqdm(row_info, desc="Detecting and aligning faces"):
182
+ name = row['name']
183
+
184
  # Skip if both aligned and cropped faces already exist
185
+ if skip_existing and os.path.exists(aligned_path) and os.path.exists(cropped_path):
186
  aligned_face = cv2.imread(aligned_path)
187
  if aligned_face is not None:
188
  names.append(name)
189
+ image_paths.append(cropped_path)
190
+ original_paths.append(download_path)
191
  aligned_faces.append(aligned_face)
192
  continue
193
 
194
+ # Load image
195
  if not os.path.exists(download_path):
196
+ continue
 
 
197
 
 
198
  image = cv2.imread(download_path)
199
  if image is None:
200
  continue
 
205
  continue
206
 
207
  # Save both versions
208
+ cv2.imwrite(aligned_path, aligned_face)
209
+ cv2.imwrite(cropped_path, original_crop)
210
 
211
  # Store results
212
  names.append(name)
213
+ image_paths.append(cropped_path)
214
+ original_paths.append(download_path)
215
  aligned_faces.append(aligned_face)
216
 
217
  logger.info(f"Successfully processed {len(names)} faces")
 
221
 
222
def extract_embeddings(
    aligned_faces: List[np.ndarray],
    model_path: str,
    device: str = "cuda",
    batch_size: int = 32,
    use_batch: bool = True
) -> np.ndarray:
    """
    Extract face embeddings using MobileFaceNet

    Args:
        aligned_faces: List of aligned face images
        model_path: Path to ONNX model
        device: Device to use ("cuda" or "cpu")
        batch_size: Batch size for batch processing (default: 32)
        use_batch: Use batch processing for faster inference (default: True)

    Returns:
        Numpy float32 array of embeddings (N x embedding_dim). When
        ``aligned_faces`` is empty an empty float32 array is returned on
        both the batch and the sequential path.
    """
    # Initialize embedding extractor
    extractor = FaceEmbeddingExtractor(model_path, device=device)

    if use_batch and batch_size > 1:
        logger.info(f"Extracting embeddings with batch processing (batch_size={batch_size}, device={device})...")
        embeddings = []
        # Ceiling division so a trailing partial batch is counted by tqdm
        num_batches = (len(aligned_faces) + batch_size - 1) // batch_size

        for i in tqdm(range(0, len(aligned_faces), batch_size),
                      desc="Extracting embeddings", total=num_batches):
            batch = aligned_faces[i:i + batch_size]
            batch_embeddings = extractor.extract_embeddings_batch(batch)
            embeddings.append(batch_embeddings)

        if embeddings:
            # Keep the dtype identical to the sequential branch below
            embeddings_array = np.vstack(embeddings).astype(np.float32, copy=False)
        else:
            # np.vstack raises ValueError on an empty sequence; return the
            # same empty result the sequential branch produces instead.
            embeddings_array = np.array([], dtype=np.float32)
    else:
        logger.info(f"Extracting embeddings sequentially (device={device})...")
        embeddings = []
        for aligned_face in tqdm(aligned_faces, desc="Extracting embeddings"):
            embedding = extractor.extract_embedding(aligned_face)
            embeddings.append(embedding)
        embeddings_array = np.array(embeddings, dtype=np.float32)

    logger.info(f"Extracted {len(embeddings_array)} embeddings with shape {embeddings_array.shape}")

    return embeddings_array
 
377
  default=False,
378
  help='Reset database by dropping existing collection and creating new one'
379
  )
380
+ parser.add_argument(
381
+ '--device',
382
+ type=str,
383
+ choices=['cuda', 'cpu'],
384
+ default='cuda',
385
+ help='Device for ONNX inference: cuda (GPU) or cpu (default: cuda)'
386
+ )
387
+ parser.add_argument(
388
+ '--parallel',
389
+ action='store_true',
390
+ default=True,
391
+ help='Use parallel processing for downloads and batch inference (default: True)'
392
+ )
393
+ parser.add_argument(
394
+ '--no-parallel',
395
+ dest='parallel',
396
+ action='store_false',
397
+ help='Disable parallel processing, use sequential mode'
398
+ )
399
+ parser.add_argument(
400
+ '--batch_size',
401
+ type=int,
402
+ default=32,
403
+ help='Batch size for embedding extraction (default: 32, only used with --parallel)'
404
+ )
405
+ parser.add_argument(
406
+ '--max_workers',
407
+ type=int,
408
+ default=20,
409
+ help='Number of parallel workers for image downloads (default: 20, only used with --parallel)'
410
+ )
411
 
412
  args = parser.parse_args()
413
 
 
418
  # Create necessary directories
419
  Config.create_directories()
420
 
421
+ # Log configuration
422
+ logger.info(f"Configuration:")
423
+ logger.info(f" Device: {args.device}")
424
+ logger.info(f" Parallel mode: {args.parallel}")
425
+ if args.parallel:
426
+ logger.info(f" Batch size: {args.batch_size}")
427
+ logger.info(f" Max workers: {args.max_workers}")
428
+
429
  # Process dataset
430
  names, image_paths, original_paths, aligned_faces = process_dataset(
431
  csv_path=args.csv,
432
  output_dir=args.output_dir,
433
  detector_type=args.detector,
434
  max_images=args.max_images,
435
+ skip_existing=args.skip_existing,
436
+ parallel=args.parallel,
437
+ max_workers=args.max_workers
438
  )
439
 
440
  if len(names) == 0:
 
442
  return
443
 
444
  # Extract embeddings
445
+ embeddings = extract_embeddings(
446
+ aligned_faces,
447
+ model_path=args.model,
448
+ device=args.device,
449
+ batch_size=args.batch_size if args.parallel else 1,
450
+ use_batch=args.parallel
451
+ )
452
 
453
  # Populate database
454
  populate_database(
src/face_matcher/config.py CHANGED
@@ -17,8 +17,9 @@ PROJECT_ROOT = Path(__file__).parent.parent.parent
17
  class ModelConfig:
18
  """Model configuration"""
19
  model_path: str = str(PROJECT_ROOT / "models/MobileFaceNet.onnx")
20
- device: str = "cpu" # "cpu" or "cuda"
21
  embedding_dim: int = 128
 
22
 
23
 
24
  @dataclass
 
17
  class ModelConfig:
18
  """Model configuration"""
19
  model_path: str = str(PROJECT_ROOT / "models/MobileFaceNet.onnx")
20
+ device: str = "cuda" # "cpu" or "cuda"
21
  embedding_dim: int = 128
22
+ batch_size: int = 32 # Batch size for embedding extraction
23
 
24
 
25
  @dataclass
src/face_matcher/core/recognition.py CHANGED
@@ -132,6 +132,45 @@ class FaceEmbeddingExtractor:
132
 
133
  return embedding
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def extract_embedding_from_path(self, image_path: str) -> Optional[np.ndarray]:
136
  """
137
  Extract face embedding from image file path
 
132
 
133
  return embedding
134
 
135
def extract_embeddings_batch(self, images: list) -> np.ndarray:
    """
    Extract face embeddings from multiple aligned face images

    Note: Inference is run one image at a time inside this method; the
    images are not stacked into a single model call. True batching
    requires a dynamic batch size ONNX model.

    Args:
        images: List of aligned face images (BGR or RGB format)

    Returns:
        float32 numpy array of L2-normalized face embeddings, one row per
        input image in input order (N, embedding_dim). An empty float32
        array is returned when ``images`` is empty.
    """
    if len(images) == 0:
        # Same dtype as the non-empty path so callers see one dtype.
        return np.array([], dtype=np.float32)

    embeddings_list = []

    # Process each image (model has fixed batch size of 1)
    for img in images:
        # Preprocess image
        input_data = self.preprocess_image(img)

        # Run inference
        outputs = self.session.run([self.output_name], {self.input_name: input_data})

        # Get embedding (remove batch dimension)
        embedding = outputs[0][0]

        # L2 normalize. A zero vector is left as-is: dividing by a zero
        # norm would fill the row with NaN/inf.
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm

        embeddings_list.append(embedding)

    # Stack all embeddings
    return np.array(embeddings_list, dtype=np.float32)
173
+
174
  def extract_embedding_from_path(self, image_path: str) -> Optional[np.ndarray]:
175
  """
176
  Extract face embedding from image file path