import os
import sys
import math
import argparse
import multiprocessing as mp
from pathlib import Path

import cv2
import torch
import numpy as np
from PIL import Image, ImageOps
from scipy.spatial.transform import Rotation as R

# Make the repo root importable before pulling in first-party packages.
repo_root = Path(__file__).parent.parent
sys.path.insert(0, str(repo_root))

# The heavy vitra model imports live inside the worker processes below.
from vitra.utils.data_utils import resize_short_side_to_target, recon_traj
from vitra.utils.config_utils import load_config
from visualization.visualize_core import (
    HandVisualizer,
    Renderer,
    process_single_hand_labels,
    save_to_video,
)
from visualization.visualize_core import Config as HandConfig


def main():
    """
    Main entry point for hand action prediction and visualization.

    Uses a multi-process architecture that separates hand reconstruction and
    VLA inference into independent processes, preventing CUDA conflicts.

    Workflow:
        1. Parse command-line arguments and load the model configuration.
        2. Initialize persistent services:
           - HandReconstructionService: runs HAWOR + MOGE models in a separate process.
           - VLAInferenceService: runs the VLA model in a separate process.
        3. Load or reconstruct the hand state:
           - Uses a precomputed .npy file if available (same stem as the image).
           - Otherwise runs the hand reconstruction service.
        4. Prepare the input data:
           - Load and resize the image.
           - Extract the hand state (translation, rotation, pose) for the left/right hands.
           - Create state and action masks based on which hands to predict.
        5. Run VLA inference to predict future hand actions (multiple samples for diversity).
        6. Reconstruct absolute hand trajectories from the relative actions.
        7. Visualize the predicted hand motions with the MANO hand model.
        8. Assemble a grid-layout video of all samples and save it to file.
        9. Clean up: shut down the persistent services and free GPU memory.
    """
    parser = argparse.ArgumentParser(description="Hand VLA inference and visualization.")

    # Model configuration
    parser.add_argument('--config_path', type=str, required=True, help='Path to model configuration JSON file')
    parser.add_argument('--model_path', type=str, default=None, help='Path to model checkpoint (overrides config)')
    parser.add_argument('--statistics_path', type=str, default=None, help='Path to normalization statistics JSON (overrides config)')

    # Input / output paths
    parser.add_argument('--image_path', type=str, required=True, help='Path to input image file')
    parser.add_argument('--hand_path', type=str, default=None, help='Path to hand state .npy file (optional; reconstruction runs if not provided)')
    parser.add_argument('--video_path', type=str, default='./example_human_inf.mp4', help='Path to save output visualization video')

    # Hand reconstruction model weights
    parser.add_argument('--hawor_model_path', type=str, default='./weights/hawor/checkpoints/hawor.ckpt', help='Path to HAWOR model weights')
    parser.add_argument('--detector_path', type=str, default='./weights/hawor/external/detector.pt', help='Path to hand detector model')
    parser.add_argument('--moge_model_name', type=str, default='Ruicheng/moge-2-vitl', help='MOGE model name from Hugging Face')
    parser.add_argument('--mano_path', type=str, default='./weights/mano', help='Path to MANO model files')

    # Prediction options
    parser.add_argument('--use_left', action='store_true', help='Enable left hand prediction')
    parser.add_argument('--use_right', action='store_true', help='Enable right hand prediction')
    parser.add_argument('--instruction', type=str, default="Left hand: Put the trash into the garbage. Right hand: None.", help='Text instruction for hand motion')
    parser.add_argument('--sample_times', type=int, default=4, help='Number of action samples to generate for diversity')
    parser.add_argument('--fps', type=int, default=8, help='Frames per second for output video')

    parser.add_argument('--save_state_local', action='store_true', help='Save hand state locally as .npy file')

    # Avoid tokenizer thread-pool warnings inside the spawned workers.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    args = parser.parse_args()
    if not args.use_left and not args.use_right:
        raise ValueError("At least one of --use_left or --use_right must be specified.")

    configs = load_config(args.config_path)

    # Command-line overrides take precedence over the config file.
    if args.model_path is not None:
        configs['model_load_path'] = args.model_path
    if args.statistics_path is not None:
        configs['statistics_path'] = args.statistics_path

    # Hand state location: an explicit --hand_path wins; otherwise look for a
    # .npy next to the image with the same stem.
    image_path_obj = Path(args.image_path)
    npy_path = Path(args.hand_path) if args.hand_path is not None else image_path_obj.with_suffix('.npy')

    print("Initializing services...")
    if npy_path.exists():
        print(f"Found precomputed hand state results: {npy_path}. Using the saved state instead of running hand reconstruction.")
        hand_data = np.load(npy_path, allow_pickle=True).item()
        hand_recon_service = None
    else:
        print(f"No precomputed hand state .npy found at {npy_path}. Starting hand reconstruction service.")
        hand_recon_service = HandReconstructionService(args)
        hand_data = None

| |
| vla_service = VLAInferenceService(configs) |
| |
| |
| hand_config = HandConfig(args) |
| hand_config.FPS = args.fps |
| visualizer = HandVisualizer(hand_config, render_gradual_traj=False) |
|
|
    try:
        if hand_data is None:
            print("Running hand reconstruction...")
            hand_data = hand_recon_service.reconstruct(args.image_path)
            if args.save_state_local:
                np.save(npy_path, hand_data, allow_pickle=True)
                print(f"Saved reconstructed hand state to {npy_path}")

        image = Image.open(args.image_path)

        # Apply the EXIF orientation before reading the size, so ori_w/ori_h
        # describe the same pixel grid that is resized below.
        try:
            image = ImageOps.exif_transpose(image)
        except Exception:
            pass
        ori_w, ori_h = image.size

        image_resized = resize_short_side_to_target(image, target=224)
        w, h = image_resized.size

        use_right = args.use_right
        use_left = args.use_left

        current_state_left = None
        current_state_right = None

        if use_right:
            current_state_right, beta_right, fov_x, _ = get_state(hand_data, hand_side='right')
        if use_left:
            current_state_left, beta_left, fov_x, _ = get_state(hand_data, hand_side='left')

        # Pinhole camera model: f = (width / 2) / tan(fov / 2).
        fov_x = fov_x * np.pi / 180  # degrees -> radians
        f_ori = ori_w / (2 * np.tan(fov_x / 2))
        fov_y = 2 * np.arctan(ori_h / (2 * f_ori))

        # Intrinsics for the resized image, principal point at its center.
        f = w / (2 * np.tan(fov_x / 2))
        intrinsics = np.array([
            [f, 0, w / 2],
            [0, f, h / 2],
            [0, 0, 1],
        ])
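        # Worked example with illustrative numbers: fov_x = 60 deg and w = 224
        # give f = 224 / (2 * tan(30 deg)) ≈ 194 pixels.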

        if current_state_left is None and current_state_right is None:
            raise ValueError("No hand state could be extracted for either hand.")

        # Zero-fill the disabled hand so the state vector keeps a fixed layout:
        # [left state (51), left beta (10), right state (51), right beta (10)].
        state_left = current_state_left if use_left else np.zeros_like(current_state_right)
        beta_left = beta_left if use_left else np.zeros_like(beta_right)
        state_right = current_state_right if use_right else np.zeros_like(current_state_left)
        beta_right = beta_right if use_right else np.zeros_like(beta_left)

        state = np.concatenate([state_left, beta_left, state_right, beta_right], axis=0)
        state_mask = np.array([use_left, use_right], dtype=bool)

        # One (left, right) mask row per predicted action step.
        chunk_size = configs.get('fwd_pred_next_n', 16)
        action_mask = np.tile(np.array([[use_left, use_right]], dtype=bool), (chunk_size, 1))

        fov = np.array([fov_x, fov_y], dtype=np.float32)

        image_resized_np = np.array(image_resized)

        instruction = args.instruction

        print("Running VLA inference...")
        sample_times = args.sample_times
        unnorm_action = vla_service.predict(
            image=image_resized_np,
            instruction=instruction,
            state=state,
            state_mask=state_mask,
            action_mask=action_mask,
            fov=fov,
            num_ddim_steps=10,
            cfg_scale=5.0,
            sample_times=sample_times,
        )

        fx_exo = intrinsics[0, 0]
        fy_exo = intrinsics[1, 1]
        renderer = Renderer(w, h, (fx_exo, fy_exo), 'cuda')

        # T frames: the initial state plus one frame per predicted action step.
        T = len(action_mask) + 1
        traj_right_list = np.zeros((sample_times, T, 51), dtype=np.float32)
        traj_left_list = np.zeros((sample_times, T, 51), dtype=np.float32)

        traj_mask = np.tile(np.array([[use_left, use_right]], dtype=bool), (T, 1))
        left_hand_mask = traj_mask[:, 0]
        right_hand_mask = traj_mask[:, 1]

        hand_mask = (left_hand_mask, right_hand_mask)
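        # Each sample yields one (T, 51) trajectory per hand; with the default
        # chunk of 16 action steps, T = 17.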

        all_rendered_frames = []

        for i in range(sample_times):
            traj_right = traj_right_list[i]
            traj_left = traj_left_list[i]

            # Integrate the relative actions into absolute hand trajectories.
            if use_left:
                traj_left = recon_traj(
                    state=state_left,
                    rel_action=unnorm_action[i, :, 0:51],
                )
            if use_right:
                traj_right = recon_traj(
                    state=state_right,
                    rel_action=unnorm_action[i, :, 51:102],
                )

            left_hand_labels = {
                'transl_worldspace': traj_left[:, 0:3],
                'global_orient_worldspace': R.from_euler('xyz', traj_left[:, 3:6]).as_matrix(),
                'hand_pose': euler_traj_to_rotmat_traj(traj_left[:, 6:51], T),
                'beta': beta_left,
            }
            right_hand_labels = {
                'transl_worldspace': traj_right[:, 0:3],
                'global_orient_worldspace': R.from_euler('xyz', traj_right[:, 3:6]).as_matrix(),
                'hand_pose': euler_traj_to_rotmat_traj(traj_right[:, 6:51], T),
                'beta': beta_right,
            }

            verts_left_worldspace, _ = process_single_hand_labels(left_hand_labels, left_hand_mask, visualizer.mano, is_left=True)
            verts_right_worldspace, _ = process_single_hand_labels(right_hand_labels, right_hand_mask, visualizer.mano, is_left=False)

            hand_traj_worldspace = (verts_left_worldspace, verts_right_worldspace)

            # Identity extrinsics: the world frame coincides with the camera frame.
            R_w2c = np.broadcast_to(np.eye(3), (T, 3, 3)).copy()
            t_w2c = np.zeros((T, 3, 1), dtype=np.float32)

            extrinsics = (R_w2c, t_w2c)

            # PIL provides RGB; the renderer consumes BGR frames.
            image_bgr = image_resized_np[..., ::-1]
            resize_video_frames = [image_bgr] * T
            save_frames = visualizer._render_hand_trajectory(
                resize_video_frames,
                hand_traj_worldspace,
                hand_mask,
                extrinsics,
                renderer,
                mode='first'
            )

            all_rendered_frames.append(save_frames)
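        # Grid sizing examples: 4 samples -> a 2 x 2 grid; 5 samples -> 3
        # columns x 2 rows with one black padding tile.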

        num_frames = len(all_rendered_frames[0])

        # Arrange the samples in a near-square grid.
        grid_cols = math.ceil(math.sqrt(sample_times))
        grid_rows = math.ceil(sample_times / grid_cols)

        combined_frames = []
        for frame_idx in range(num_frames):
            sample_frames = [all_rendered_frames[i][frame_idx] for i in range(sample_times)]

            # Pad with black frames so every grid cell is filled.
            while len(sample_frames) < grid_rows * grid_cols:
                black_frame = np.zeros_like(sample_frames[0])
                sample_frames.append(black_frame)

            rows = []
            for row_idx in range(grid_rows):
                row_frames = sample_frames[row_idx * grid_cols:(row_idx + 1) * grid_cols]
                row_concat = np.concatenate(row_frames, axis=1)
                rows.append(row_concat)

            combined_frame = np.concatenate(rows, axis=0)
            combined_frames.append(combined_frame)

        save_to_video(combined_frames, args.video_path, fps=hand_config.FPS)
        print(f"Combined video with {sample_times} samples saved to {args.video_path}")

    finally:
        # Always release the worker processes, even if inference failed.
        print("Shutting down services...")
        if hand_recon_service is not None:
            hand_recon_service.shutdown()
        vla_service.shutdown()
        print("All services shut down successfully")

def get_state(hand_data, hand_side='right'):
    """
    Extract a single-hand state from reconstructed hand data.

    Args:
        hand_data (dict): Dictionary containing hand data.
        hand_side (str): Which hand to extract, either 'left' or 'right'. Default is 'right'.

    Returns:
        tuple: (state_t0, beta, fov_x, None) where:
            - state_t0 (np.ndarray): Hand state [51] containing translation (3),
              global rotation (3 Euler angles), and hand pose (45 Euler angles)
            - beta (np.ndarray): MANO shape parameters [10]
            - fov_x (float): Horizontal field of view in degrees
            - None: Placeholder for optional text annotations
    """
    if hand_side not in ['left', 'right']:
        raise ValueError(f"hand_side must be 'left' or 'right', got '{hand_side}'")

    hand_pose_t0 = hand_data[hand_side][0]['hand_pose']
    hand_pose_t0_euler = R.from_matrix(hand_pose_t0).as_euler('xyz', degrees=False)
    hand_pose_t0_euler = hand_pose_t0_euler.reshape(-1)
    global_orient_mat_t0 = hand_data[hand_side][0]['global_orient']
    R_t0_euler = R.from_matrix(global_orient_mat_t0).as_euler('xyz', degrees=False)
    transl_t0 = hand_data[hand_side][0]['transl']
    state_t0 = np.concatenate([transl_t0, R_t0_euler, hand_pose_t0_euler])
    fov_x = hand_data['fov_x']

    return state_t0, hand_data[hand_side][0]['beta'], fov_x, None
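
# For reference, the 51-D per-hand state layout used throughout this script:
#   [0:3]  translation, [3:6] global orientation (XYZ Euler angles),
#   [6:51] hand pose as 15 joints x 3 Euler angles.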

def euler_traj_to_rotmat_traj(euler_traj, T):
    """
    Convert an Euler angle trajectory to a rotation matrix trajectory.

    Converts a sequence of hand poses represented as Euler angles into
    rotation matrices suitable for MANO model input.

    Args:
        euler_traj (np.ndarray): Hand pose trajectory as Euler angles.
            Shape: [T, 45] where T is the number of timesteps
            and 45 = 15 joints * 3 Euler angles per joint.
        T (int): Number of timesteps in the trajectory.

    Returns:
        np.ndarray: Rotation matrix trajectory. Shape: [T, 15, 3, 3],
            where each [3, 3] block is a rotation matrix for one joint.
    """
    hand_pose = euler_traj.reshape(-1, 3)
    pose_matrices = R.from_euler('xyz', hand_pose).as_matrix()
    pose_matrices = pose_matrices.reshape(T, 15, 3, 3)

    return pose_matrices
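
# Quick shape check (synthetic input, for illustration only):
#   euler_traj_to_rotmat_traj(np.zeros((17, 45)), 17).shape == (17, 15, 3, 3)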

def _hand_reconstruction_worker(args_dict, task_queue, result_queue):
    """
    Persistent worker for hand reconstruction that runs in a separate process.
    Keeps the model loaded and processes requests until the shutdown signal.
    """
    from data.tools.hand_recon_core import Config, HandReconstructor

    hand_reconstructor = None

    try:
        # Rebuild an args-like object from the plain dict that crossed the
        # process boundary.
        args_obj = argparse.Namespace(**args_dict)

        print("[HandRecon Process] Initializing hand reconstructor...")
        config = Config(args_obj)
        hand_reconstructor = HandReconstructor(config=config, device='cuda')
        print("[HandRecon Process] Hand reconstructor ready")

        # Tell the parent process we are ready to accept tasks.
        result_queue.put({'type': 'ready'})

        while True:
            task = task_queue.get()

            if task['type'] == 'shutdown':
                print("[HandRecon Process] Received shutdown signal")
                break

            elif task['type'] == 'reconstruct':
                try:
                    image_path = task['image_path']
                    image = cv2.imread(image_path)
                    if image is None:
                        raise ValueError(f"Failed to load image from {image_path}")

                    image_list = [image]
                    recon_results = hand_reconstructor.recon(image_list)

                    result_queue.put({
                        'type': 'result',
                        'success': True,
                        'data': recon_results
                    })

                except Exception as e:
                    import traceback
                    result_queue.put({
                        'type': 'result',
                        'success': False,
                        'error': str(e),
                        'traceback': traceback.format_exc()
                    })

    except Exception as e:
        import traceback
        result_queue.put({
            'type': 'error',
            'error': str(e),
            'traceback': traceback.format_exc()
        })

    finally:
        # Free GPU memory before the process exits.
        if hand_reconstructor is not None:
            del hand_reconstructor
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print("[HandRecon Process] Cleaned up and exiting")
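
# Message protocol shared by both workers (summarized from the handlers):
#   task_queue:   {'type': 'reconstruct' | 'predict' | 'shutdown', ...payload}
#   result_queue: {'type': 'ready'}
#                 {'type': 'result', 'success': bool, 'data': ... or 'error'/'traceback'}
#                 {'type': 'error', 'error': str, 'traceback': str}  # fatal init failure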

def _vla_inference_worker(configs_dict, task_queue, result_queue):
    """
    Persistent worker for VLA model inference that runs in a separate process.
    Keeps the model loaded and processes requests until the shutdown signal.
    """
    from vitra.models import load_model
    from vitra.utils.data_utils import load_normalizer
    from vitra.datasets.human_dataset import pad_state_human, pad_action
    from vitra.datasets.dataset_utils import ActionFeature, StateFeature

    model = None
    normalizer = None

    try:
        print("[VLA Process] Loading VLA model...")
        model = load_model(configs_dict).cuda()
        model.eval()
        normalizer = load_normalizer(configs_dict)
        print("[VLA Process] VLA model ready.")

        # Tell the parent process we are ready to accept tasks.
        result_queue.put({'type': 'ready'})

        while True:
            task = task_queue.get()

            if task['type'] == 'shutdown':
                print("[VLA Process] Received shutdown signal")
                break

            elif task['type'] == 'predict':
                try:
                    image = task['image']
                    instruction = task['instruction']
                    state = task['state']
                    state_mask = task['state_mask']
                    action_mask = task['action_mask']
                    fov = task['fov']
                    num_ddim_steps = task.get('num_ddim_steps', 10)
                    cfg_scale = task.get('cfg_scale', 5.0)
                    sample_times = task.get('sample_times', 1)

                    norm_state = normalizer.normalize_state(state.copy())

                    # Pad the hand-specific state/action into the unified
                    # feature layout expected by the model.
                    unified_action_dim = ActionFeature.ALL_FEATURES[1]
                    unified_state_dim = StateFeature.ALL_FEATURES[1]

                    unified_state, unified_state_mask = pad_state_human(
                        state=norm_state,
                        state_mask=state_mask,
                        action_dim=normalizer.action_mean.shape[0],
                        state_dim=normalizer.state_mean.shape[0],
                        unified_state_dim=unified_state_dim,
                    )
                    _, unified_action_mask = pad_action(
                        actions=None,
                        action_mask=action_mask.copy(),
                        action_dim=normalizer.action_mean.shape[0],
                        unified_action_dim=unified_action_dim
                    )

                    # Add a batch dimension to every tensor input.
                    fov = torch.from_numpy(fov).unsqueeze(0)
                    unified_state = unified_state.unsqueeze(0)
                    unified_state_mask = unified_state_mask.unsqueeze(0)
                    unified_action_mask = unified_action_mask.unsqueeze(0)

                    norm_action = model.predict_action(
                        image=image,
                        instruction=instruction,
                        current_state=unified_state,
                        current_state_mask=unified_state_mask,
                        action_mask_torch=unified_action_mask,
                        num_ddim_steps=num_ddim_steps,
                        cfg_scale=cfg_scale,
                        fov=fov,
                        sample_times=sample_times,
                    )

                    # Keep only the 102 hand dimensions (51 per hand), then
                    # undo the normalization.
                    norm_action = norm_action[:, :, :102]
                    unnorm_action = normalizer.unnormalize_action(norm_action)

                    if isinstance(unnorm_action, torch.Tensor):
                        unnorm_action_np = unnorm_action.cpu().numpy()
                    else:
                        unnorm_action_np = np.array(unnorm_action)

                    result_queue.put({
                        'type': 'result',
                        'success': True,
                        'data': unnorm_action_np
                    })

                except Exception as e:
                    import traceback
                    result_queue.put({
                        'type': 'result',
                        'success': False,
                        'error': str(e),
                        'traceback': traceback.format_exc()
                    })

    except Exception as e:
        import traceback
        result_queue.put({
            'type': 'error',
            'error': str(e),
            'traceback': traceback.format_exc()
        })

    finally:
        # Free GPU memory before the process exits.
        if model is not None:
            del model
        if normalizer is not None:
            del normalizer
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print("[VLA Process] Cleaned up and exiting")
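
# Each worker imports its heavy model dependencies inside the function body,
# so the owning child process (spawned below) initializes its own CUDA context
# and the parent stays model-free.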

class HandReconstructionService:
    """Service wrapper around the persistent hand reconstruction process."""

    def __init__(self, args):
        self.ctx = mp.get_context('spawn')
        self.task_queue = self.ctx.Queue()
        self.result_queue = self.ctx.Queue()

        # Only plain values cross the process boundary.
        args_dict = {
            'hawor_model_path': args.hawor_model_path,
            'detector_path': args.detector_path,
            'moge_model_name': args.moge_model_name,
            'mano_path': args.mano_path,
        }

        self.process = self.ctx.Process(
            target=_hand_reconstruction_worker,
            args=(args_dict, self.task_queue, self.result_queue)
        )
        self.process.start()

        # Block until the worker has loaded its models.
        ready_msg = self.result_queue.get()
        if ready_msg['type'] == 'ready':
            print("Hand reconstruction service initialized")
        elif ready_msg['type'] == 'error':
            raise RuntimeError(f"Failed to initialize hand reconstruction: {ready_msg['error']}")

    def reconstruct(self, image_path):
        """Request hand reconstruction for an image."""
        self.task_queue.put({
            'type': 'reconstruct',
            'image_path': image_path
        })

        result = self.result_queue.get()
        if result['type'] == 'result' and result['success']:
            return result['data']
        else:
            raise RuntimeError(f"Hand reconstruction failed: {result.get('error', 'Unknown error')}")

    def shutdown(self):
        """Shut down the persistent process, terminating it if it hangs."""
        self.task_queue.put({'type': 'shutdown'})
        self.process.join(timeout=10)
        if self.process.is_alive():
            self.process.terminate()
            self.process.join()
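
# Minimal usage sketch for the service wrappers (paths illustrative):
#   service = HandReconstructionService(args)
#   hand_data = service.reconstruct('examples/desk.jpg')
#   service.shutdown()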

class VLAInferenceService:
    """Service wrapper around the persistent VLA inference process."""

    def __init__(self, configs):
        self.ctx = mp.get_context('spawn')
        self.task_queue = self.ctx.Queue()
        self.result_queue = self.ctx.Queue()

        self.process = self.ctx.Process(
            target=_vla_inference_worker,
            args=(configs, self.task_queue, self.result_queue)
        )
        self.process.start()

        # Block until the worker has loaded the model.
        ready_msg = self.result_queue.get()
        if ready_msg['type'] == 'ready':
            print("VLA inference service initialized")
        elif ready_msg['type'] == 'error':
            raise RuntimeError(f"Failed to initialize VLA model: {ready_msg['error']}")

    def predict(self, image, instruction, state, state_mask, action_mask,
                fov, num_ddim_steps=10, cfg_scale=5.0, sample_times=1):
        """Request action prediction; normalization and padding happen in the worker."""
        self.task_queue.put({
            'type': 'predict',
            'image': image,
            'instruction': instruction,
            'state': state,
            'state_mask': state_mask,
            'action_mask': action_mask,
            'fov': fov,
            'num_ddim_steps': num_ddim_steps,
            'cfg_scale': cfg_scale,
            'sample_times': sample_times,
        })

        result = self.result_queue.get()
        if result['type'] == 'result' and result['success']:
            return result['data']
        else:
            raise RuntimeError(f"VLA inference failed: {result.get('error', 'Unknown error')}")

    def shutdown(self):
        """Shut down the persistent process, terminating it if it hangs."""
        self.task_queue.put({'type': 'shutdown'})
        self.process.join(timeout=10)
        if self.process.is_alive():
            self.process.terminate()
            self.process.join()
| if __name__ == "__main__": |
| |
| mp.set_start_method('spawn', force=True) |
| main() |