import os
import sys
import cv2
import math
import json
import torch
import gradio as gr
import numpy as np
from PIL import Image
from PIL import ImageOps
from pathlib import Path
import multiprocessing as mp

# Make the repo root importable before pulling in the local packages below.
repo_root = Path(__file__).parent
sys.path.insert(0, str(repo_root))

from vitra.utils.data_utils import resize_short_side_to_target, load_normalizer, recon_traj
from vitra.utils.config_utils import load_config
from scipy.spatial.transform import Rotation as R
import spaces

from visualization.visualize_core import HandVisualizer, normalize_camera_intrinsics, save_to_video, Renderer, process_single_hand_labels
from visualization.visualize_core import Config as HandConfig

from inference_human_prediction import (
    get_state,
    euler_traj_to_rotmat_traj,
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

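# Global service handles, populated once by initialize_services() at startup.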
vla_model = None
vla_normalizer = None
hand_reconstructor = None
visualizer = None
hand_config = None

|
def vla_predict(model, normalizer, image, instruction, state, state_mask,
                action_mask, fov, num_ddim_steps, cfg_scale, sample_times):
    """
    VLA prediction function that runs on the GPU.
    The model is already loaded and moved to CUDA in the main process.
    """
    from vitra.datasets.human_dataset import pad_state_human, pad_action
    from vitra.datasets.dataset_utils import ActionFeature, StateFeature

    # Normalize the raw state, then pad state and action masks into the
    # unified feature space expected by the model.
    norm_state = normalizer.normalize_state(state.copy())

    unified_action_dim = ActionFeature.ALL_FEATURES[1]
    unified_state_dim = StateFeature.ALL_FEATURES[1]

    unified_state, unified_state_mask = pad_state_human(
        state=norm_state,
        state_mask=state_mask,
        action_dim=normalizer.action_mean.shape[0],
        state_dim=normalizer.state_mean.shape[0],
        unified_state_dim=unified_state_dim,
    )
    _, unified_action_mask = pad_action(
        actions=None,
        action_mask=action_mask.copy(),
        action_dim=normalizer.action_mean.shape[0],
        unified_action_dim=unified_action_dim
    )

    # Add a batch dimension and move everything onto the GPU.
    device = torch.device('cuda')
    fov = torch.from_numpy(fov).unsqueeze(0).to(device)
    unified_state = unified_state.unsqueeze(0).to(device)
    unified_state_mask = unified_state_mask.unsqueeze(0).to(device)
    unified_action_mask = unified_action_mask.unsqueeze(0).to(device)

    model = model.to(device)

    norm_action = model.predict_action(
        image=image,
        instruction=instruction,
        current_state=unified_state,
        current_state_mask=unified_state_mask,
        action_mask_torch=unified_action_mask,
        num_ddim_steps=num_ddim_steps,
        cfg_scale=cfg_scale,
        fov=fov,
        sample_times=sample_times,
    )

    # Keep only the 102-D bimanual action (51 dims per hand) and undo normalization.
    norm_action = norm_action[:, :, :102]
    unnorm_action = normalizer.unnormalize_action(norm_action)

    if isinstance(unnorm_action, torch.Tensor):
        unnorm_action_np = unnorm_action.cpu().numpy()
    else:
        unnorm_action_np = np.array(unnorm_action)

    return unnorm_action_np

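# Output shape, as consumed by generate_prediction() below (the chunk length
# comes from the 'fwd_pred_next_n' config entry, default 16):
#   unnorm_action: (sample_times, chunk_size, 102),
#   where [..., 0:51] is the left-hand action and [..., 51:102] the right-hand action.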
|
class GradioConfig:
    """Configuration for the Gradio app"""
    def __init__(self):
        # VLA model (Hub repo id; local overrides stay None by default)
        self.config_path = 'microsoft/VITRA-VLA-3B'
        self.model_path = None
        self.statistics_path = None

        # Hand reconstruction weights
        self.hawor_model_path = 'arnoldland/HAWOR'
        self.detector_path = './weights/hawor/external/detector.pt'
        self.moge_model_name = 'Ruicheng/moge-2-vitl'
        self.mano_path = './weights/mano'

        # Output video frame rate
        self.fps = 8
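
# Example override (hypothetical local checkpoint paths; the Hub defaults
# above work as-is):
#   cfg = GradioConfig()
#   cfg.model_path = './weights/vitra_vla_3b.pt'
#   cfg.statistics_path = './weights/statistics.json'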
def initialize_services():
    """Initialize all models once at startup"""
    global vla_model, vla_normalizer, hand_reconstructor, visualizer, hand_config, app_config

    if vla_model is not None:
        return "Services already initialized"

    try:
        app_config = GradioConfig()

        # Optional HuggingFace Hub login (needed for gated weights).
        hf_token = os.environ.get('HF_TOKEN', None)
        if hf_token:
            from huggingface_hub import login
            login(token=hf_token)
            print("Logged in to HuggingFace Hub")

        print("Loading VLA model...")
        from vitra.models import load_model

        configs = load_config(app_config.config_path)
        if app_config.model_path is not None:
            configs['model_load_path'] = app_config.model_path
        if app_config.statistics_path is not None:
            configs['statistics_path'] = app_config.statistics_path

        # The `global` declaration above lets us assign the module-level
        # handles directly.
        vla_model = load_model(configs).cuda()
        vla_model.eval()
        vla_normalizer = load_normalizer(configs)
        print("VLA model loaded")

        print("Loading Hand Reconstructor...")
        from data.tools.hand_recon_core import Config, HandReconstructor
        from types import SimpleNamespace

        args_obj = SimpleNamespace(
            hawor_model_path=app_config.hawor_model_path,
            detector_path=app_config.detector_path,
            moge_model_name=app_config.moge_model_name,
            mano_path=app_config.mano_path,
        )

        recon_config = Config(args_obj)
        hand_reconstructor = HandReconstructor(config=recon_config, device='cuda')
        print("Hand Reconstructor loaded")

        print("Loading Visualizer...")
        hand_config = HandConfig(app_config)
        hand_config.FPS = app_config.fps
        visualizer = HandVisualizer(hand_config, render_gradual_traj=False)
        visualizer.mano = visualizer.mano.cuda()
        print("Visualizer loaded")

        return "✅ All services initialized successfully!"

    except Exception as e:
        import traceback
        return f"❌ Failed to initialize services: {str(e)}\n{traceback.format_exc()}"


def validate_image_dimensions(image):
    """Validate image dimensions before GPU allocation.
    Returns (is_valid, message)
    """
    if image is None:
        return True, ""

    if isinstance(image, np.ndarray):
        img_pil = Image.fromarray(image)
    else:
        img_pil = image

    width, height = img_pil.size
    if width < height:
        error_msg = f"❌ Please upload a landscape image (width ≥ height).\nCurrent image: {width}x{height} (portrait orientation)"
        return False, error_msg

    return True, ""
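
# Example: a 1920x1080 (landscape) upload passes, while a 1080x1920 (portrait)
# upload is rejected before any GPU time is requested.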


def validate_and_process_wrapper(image, session_state, progress=gr.Progress()):
    """Wrapper function that validates the image before GPU allocation"""
    if image is None:
        return ("Waiting for image upload...",
                gr.update(interactive=False),
                None,
                False,
                False,
                session_state)

    is_valid, error_msg = validate_image_dimensions(image)
    if not is_valid:
        return (error_msg,
                gr.update(interactive=False),
                None,
                False,
                False,
                session_state)

    return process_image_upload(image, session_state, progress)


@spaces.GPU(duration=120)
def process_image_upload(image, session_state, progress=gr.Progress()):
    """Process the uploaded image and run hand reconstruction"""
    global hand_reconstructor
    if torch.cuda.is_available():
        print("CUDA is available for image processing")
    else:
        print("CUDA is NOT available for image processing")

    # Wait (up to 60 s) for the GPU worker's CUDA context to become usable.
    import time
    start_time = time.time()
    while time.time() - start_time < 60:
        try:
            if torch.cuda.is_available():
                torch.zeros(1).cuda()
                break
        except Exception:
            time.sleep(2)

    if hand_reconstructor is None:
        return ("Services not initialized. Please wait for initialization to complete.",
                gr.update(interactive=False),
                None,
                False,
                False,
                session_state)

    try:
        progress(0, desc="Preparing image...")

        if isinstance(image, np.ndarray):
            img_pil = Image.fromarray(image)
        else:
            img_pil = image

        session_state['current_image'] = img_pil

        progress(0.2, desc="Running hand reconstruction...")

        # The reconstructor expects BGR frames (OpenCV convention).
        image_np = np.array(img_pil)
        image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

        image_list = [image_bgr]
        hand_data = hand_reconstructor.recon(image_list)

        session_state['current_hand_data'] = hand_data

        progress(1.0, desc="Hand reconstruction complete!")

        has_left = 'left' in hand_data and len(hand_data['left']) > 0
        has_right = 'right' in hand_data and len(hand_data['right']) > 0

        info_msg = "✅ Hand reconstruction complete!\n"
        info_msg += "Detected hands: "
        if has_left and has_right:
            info_msg += "Left ✓, Right ✓"
        elif has_left:
            info_msg += "Left ✓, Right ✗"
        elif has_right:
            info_msg += "Left ✗, Right ✓"
        else:
            info_msg += "None detected"

        session_state['detected_left'] = has_left
        session_state['detected_right'] = has_right

        return (info_msg,
                gr.update(interactive=True),
                hand_data,
                has_left,
                has_right,
                session_state)

    except Exception as e:
        import traceback
        error_msg = f"❌ Hand reconstruction failed: {str(e)}\n{traceback.format_exc()}"

        session_state['detected_left'] = False
        session_state['detected_right'] = False

        return (error_msg,
                gr.update(interactive=True),
                None,
                False,
                False,
                session_state)
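
# hand_data layout (inferred from its use in this file): a dict with optional
# 'left' / 'right' keys, each holding per-hand reconstruction results that
# get_state() later turns into a 51-D state vector plus MANO betas.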


def update_checkboxes(has_left, has_right):
    """Update checkbox states based on detected hands (no progress bar)"""
    left_checkbox_update = gr.update(
        value=has_left,
        interactive=has_left,
        elem_classes="" if has_left else "disabled-checkbox"
    )
    right_checkbox_update = gr.update(
        value=has_right,
        interactive=has_right,
        elem_classes="" if has_right else "disabled-checkbox"
    )

    left_instruction_update = gr.update(
        interactive=has_left,
        elem_classes="" if has_left else "disabled-textbox"
    )
    right_instruction_update = gr.update(
        interactive=has_right,
        elem_classes="" if has_right else "disabled-textbox"
    )

    return left_checkbox_update, right_checkbox_update, left_instruction_update, right_instruction_update


def update_instruction_interactivity(use_left, use_right):
    """Update instruction textbox interactivity based on checkbox states"""
    left_update = gr.update(
        interactive=use_left,
        elem_classes="" if use_left else "disabled-textbox"
    )
    right_update = gr.update(
        interactive=use_right,
        elem_classes="" if use_right else "disabled-textbox"
    )
    return left_update, right_update


def update_final_instruction(left_instruction, right_instruction, use_left, use_right):
    """Update the final instruction based on left/right inputs and checkbox states"""
    left_text = left_instruction if use_left else "None."
    right_text = right_instruction if use_right else "None."

    final = f"Left hand: {left_text} Right hand: {right_text}"

    styled_output = f"""<div style='padding: 12px; background-color: #f0f7ff; border-left: 4px solid #4A90E2; border-radius: 4px; margin-top: 10px;'>
<strong style='color: #2c5282;'>📝 Final Instruction:</strong><br>
<span style='color: #1a365d; font-size: 14px;'>{final}</span>
</div>"""

    return gr.update(value=styled_output), final
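
# Example:
#   update_final_instruction("Open the jar.", "Hold the lid.", True, False)
#   -> (HTML update, "Left hand: Open the jar. Right hand: None.")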


def parse_instruction(instruction_text):
    """Parse a combined instruction into left and right parts"""
    import re

    left_match = re.search(r'Left(?:\s+hand)?:\s*([^.]*(?:\.[^LR]*)*)(?=Right|$)', instruction_text, re.IGNORECASE)
    right_match = re.search(r'Right(?:\s+hand)?:\s*(.+?)$', instruction_text, re.IGNORECASE)

    left_text = left_match.group(1).strip() if left_match else "None."
    right_text = right_match.group(1).strip() if right_match else "None."

    return left_text, right_text
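
# Example:
#   parse_instruction("Left hand: Put the trash into the garbage. Right hand: None.")
#   -> ("Put the trash into the garbage.", "None.")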


@spaces.GPU(duration=120)
def generate_prediction(instruction, use_left, use_right, sample_times, num_ddim_steps, cfg_scale, hand_data, image, progress=gr.Progress()):
    """Generate hand motion prediction and visualization"""
    global vla_model, vla_normalizer, visualizer, hand_config, app_config

    # Wait (up to 60 s) for the GPU worker's CUDA context to become usable.
    import time
    start_time = time.time()
    while time.time() - start_time < 60:
        try:
            if torch.cuda.is_available():
                torch.zeros(1).cuda()
                break
        except Exception:
            time.sleep(2)

    if hand_data is None:
        return None, "Please upload an image and wait for hand reconstruction first"

    if not use_left and not use_right:
        return None, "Please select at least one hand (left or right)"

    try:
        progress(0, desc="Preparing data...")

        if image is None:
            return None, "Image not found. Please upload an image first."

        ori_w, ori_h = image.size

        # Respect EXIF orientation if present.
        try:
            image = ImageOps.exif_transpose(image)
        except Exception:
            pass

        image_resized = resize_short_side_to_target(image, target=224)
        w, h = image_resized.size

        current_state_left = None
        current_state_right = None
        beta_left = None
        beta_right = None

        progress(0.1, desc="Extracting hand states...")

        if use_right:
            current_state_right, beta_right, fov_x, _ = get_state(hand_data, hand_side='right')
        if use_left:
            current_state_left, beta_left, fov_x, _ = get_state(hand_data, hand_side='left')

        # Convert the horizontal FOV to radians, derive the vertical FOV from
        # the original aspect ratio, and build pinhole intrinsics for the
        # resized image.
        fov_x = fov_x * np.pi / 180
        f_ori = ori_w / np.tan(fov_x / 2) / 2
        fov_y = 2 * np.arctan(ori_h / (2 * f_ori))

        f = w / np.tan(fov_x / 2) / 2
        intrinsics = np.array([
            [f, 0, w/2],
            [0, f, h/2],
            [0, 0, 1]
        ])

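        # Worked example of the pinhole relation above: with fov_x = 90° and
        # w = 398 px, f = 398 / (2 * tan(45°)) = 199 px.
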
        if current_state_left is None and current_state_right is None:
            return None, "No valid hand states found"

        # Zero-fill the unused hand so both sides have consistent shapes.
        state_left = current_state_left if use_left else np.zeros_like(current_state_right)
        beta_left = beta_left if use_left else np.zeros_like(beta_right)
        state_right = current_state_right if use_right else np.zeros_like(current_state_left)
        beta_right = beta_right if use_right else np.zeros_like(beta_left)

        state = np.concatenate([state_left, beta_left, state_right, beta_right], axis=0)
        state_mask = np.array([use_left, use_right], dtype=bool)

        configs = load_config(app_config.config_path)
        chunk_size = configs.get('fwd_pred_next_n', 16)
        action_mask = np.tile(np.array([[use_left, use_right]], dtype=bool), (chunk_size, 1))

        fov = np.array([fov_x, fov_y], dtype=np.float32)
        image_resized_np = np.array(image_resized)

        progress(0.3, desc="Running VLA inference...")

        unnorm_action = vla_predict(
            model=vla_model,
            normalizer=vla_normalizer,
            image=image_resized_np,
            instruction=instruction,
            state=state,
            state_mask=state_mask,
            action_mask=action_mask,
            fov=fov,
            num_ddim_steps=num_ddim_steps,
            cfg_scale=cfg_scale,
            sample_times=sample_times,
        )

        progress(0.6, desc="Visualizing predictions...")

        fx_exo = intrinsics[0, 0]
        fy_exo = intrinsics[1, 1]
        renderer = Renderer(w, h, (fx_exo, fy_exo), 'cuda')

        # T = chunk_size action steps plus the current state.
        T = chunk_size + 1
        traj_right_list = np.zeros((sample_times, T, 51), dtype=np.float32)
        traj_left_list = np.zeros((sample_times, T, 51), dtype=np.float32)

        traj_mask = np.tile(np.array([[use_left, use_right]], dtype=bool), (T, 1))
        left_hand_mask = traj_mask[:, 0]
        right_hand_mask = traj_mask[:, 1]
        hand_mask = (left_hand_mask, right_hand_mask)

        all_rendered_frames = []

        for i in range(sample_times):
            progress(0.6 + 0.3 * (i / sample_times), desc=f"Rendering sample {i+1}/{sample_times}...")

            traj_right = traj_right_list[i]
            traj_left = traj_left_list[i]

            # Reconstruct absolute trajectories from the predicted relative actions.
            if use_left:
                traj_left = recon_traj(
                    state=state_left,
                    rel_action=unnorm_action[i, :, 0:51],
                )
            if use_right:
                traj_right = recon_traj(
                    state=state_right,
                    rel_action=unnorm_action[i, :, 51:102],
                )

            # Each 51-D frame is [translation (3), global orientation (3, Euler),
            # hand pose (45, Euler)].
            left_hand_labels = {
                'transl_worldspace': traj_left[:, 0:3],
                'global_orient_worldspace': R.from_euler('xyz', traj_left[:, 3:6]).as_matrix(),
                'hand_pose': euler_traj_to_rotmat_traj(traj_left[:, 6:51], T),
                'beta': beta_left,
            }
            right_hand_labels = {
                'transl_worldspace': traj_right[:, 0:3],
                'global_orient_worldspace': R.from_euler('xyz', traj_right[:, 3:6]).as_matrix(),
                'hand_pose': euler_traj_to_rotmat_traj(traj_right[:, 6:51], T),
                'beta': beta_right,
            }

            verts_left_worldspace, _ = process_single_hand_labels(left_hand_labels, left_hand_mask, visualizer.mano, is_left=True)
            verts_right_worldspace, _ = process_single_hand_labels(right_hand_labels, right_hand_mask, visualizer.mano, is_left=False)

            hand_traj_worldspace = (verts_left_worldspace, verts_right_worldspace)

            # Identity extrinsics: render in the camera frame of the input view.
            R_w2c = np.broadcast_to(np.eye(3), (T, 3, 3)).copy()
            t_w2c = np.zeros((T, 3, 1), dtype=np.float32)
            extrinsics = (R_w2c, t_w2c)

            image_bgr = image_resized_np[..., ::-1]
            resize_video_frames = [image_bgr] * T
            save_frames = visualizer._render_hand_trajectory(
                resize_video_frames,
                hand_traj_worldspace,
                hand_mask,
                extrinsics,
                renderer,
                mode='first'
            )

            all_rendered_frames.append(save_frames)

        progress(0.95, desc="Creating output video...")

        # Tile the per-sample videos into a near-square grid.
        num_frames = len(all_rendered_frames[0])
        grid_cols = math.ceil(math.sqrt(sample_times))
        grid_rows = math.ceil(sample_times / grid_cols)
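        # e.g. sample_times=4 -> a 2x2 grid; sample_times=5 -> 3 columns x 2
        # rows, with the last cell padded by a black frame below.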

        combined_frames = []
        for frame_idx in range(num_frames):
            sample_frames = [all_rendered_frames[i][frame_idx] for i in range(sample_times)]

            while len(sample_frames) < grid_rows * grid_cols:
                black_frame = np.zeros_like(sample_frames[0])
                sample_frames.append(black_frame)

            rows = []
            for row_idx in range(grid_rows):
                row_frames = sample_frames[row_idx * grid_cols:(row_idx + 1) * grid_cols]
                row_concat = np.concatenate(row_frames, axis=1)
                rows.append(row_concat)

            combined_frame = np.concatenate(rows, axis=0)
            combined_frames.append(combined_frame)

        output_dir = Path("./temp_gradio/outputs")
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / "prediction.mp4"
        save_to_video(combined_frames, str(output_path), fps=hand_config.FPS)

        progress(1.0, desc="Complete!")

        return str(output_path), f"✅ Generated {sample_times} prediction samples successfully!"

    except Exception as e:
        import traceback
        error_msg = f"❌ Prediction failed: {str(e)}\n{traceback.format_exc()}"
        return None, error_msg


def load_examples():
    """Automatically load all image examples from the examples folder"""
    examples_dir = Path(__file__).parent / "examples"

    default_instructions = {
        "0001.jpg": "Left hand: Put the trash into the garbage. Right hand: None.",
        "0002.jpg": "Left hand: None. Right hand: Pick up the picture of Michael Jackson.",
        "0003.png": "Left hand: None. Right hand: Pick up the metal water cup.",
        "0004.jpg": "Left hand: Squeeze the dish sponge. Right hand: None.",
        "0005.jpg": "Left hand: None. Right hand: Cut the meat with the knife.",
        "0006.jpg": "Left hand: Open the closet door. Right hand: None.",
        "0007.jpg": "Left hand: None. Right hand: Cut the paper with the scissors.",
        "0008.jpg": "Left hand: Wipe the countertop with the cloth. Right hand: None.",
        "0009.jpg": "Left hand: None. Right hand: Open the cabinet door.",
        "0010.png": "Left hand: None. Right hand: Turn on the faucet.",
        "0011.jpg": "Left hand: Put the drink bottle into the trash can. Right hand: None.",
        "0012.jpg": "Left hand: None. Right hand: Pick up the gray cup from the cabinet.",
        "0013.jpg": "Left hand: None. Right hand: Take the milk bottle out of the fridge.",
        "0014.jpg": "Left hand: None. Right hand: Pick up the balloon.",
        "0015.jpg": "Left hand: None. Right hand: Pick up the picture with the smaller red heart.",
        "0016.jpg": "Left hand: None. Right hand: Pick up the picture with \"Cat\".",
        "0017.jpg": "Left hand: None. Right hand: Pick up the picture of the Statue of Liberty.",
        "0018.jpg": "Left hand: None. Right hand: Pick up the picture of the two people.",
    }

    examples_images = []
    instructions_map = {}

    if examples_dir.exists():
        image_files = sorted([f for f in examples_dir.iterdir()
                              if f.suffix.lower() in ['.jpg', '.jpeg', '.png']])

        for img_path in image_files:
            img_path_str = str(img_path)
            instruction = default_instructions.get(
                img_path.name,
                "Left hand: Perform the action. Right hand: None."
            )

            examples_images.append([img_path_str])
            instructions_map[img_path_str] = instruction

    return examples_images, instructions_map
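
# Example return values (the exact paths depend on where the app lives and
# what ./examples contains):
#   examples_images  -> [['<app_dir>/examples/0001.jpg'], ...]
#   instructions_map -> {'<app_dir>/examples/0001.jpg':
#                        'Left hand: Put the trash into the garbage. Right hand: None.', ...}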


def get_instruction_for_image(image_path, instructions_map):
    """Get the instruction for a given image path"""
    if image_path is None:
        return gr.update()

    instruction = instructions_map.get(str(image_path), "")
    return instruction


def create_gradio_interface():
    """Create the Gradio interface"""

    with gr.Blocks(delete_cache=(600, 600), title="3D Hand Motion Prediction with VITRA") as demo:

        gr.HTML("""
        <style>
        .disabled-checkbox {
            opacity: 0.5 !important;
            pointer-events: none !important;
        }
        .disabled-textbox textarea {
            background-color: #f5f5f5 !important;
            color: #9e9e9e !important;
            cursor: not-allowed !important;
        }
        </style>
        """)

        gr.HTML("""
        <div align="center">
            <h1> 🤖 Hand Action Prediction with <a href="https://microsoft.github.io/VITRA/" target="_blank" style="text-decoration: underline; font-weight: bold; color: #4A90E2;">VITRA</a> <a title="Github" href="https://github.com/microsoft/VITRA" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> <img src="https://img.shields.io/github/stars/microsoft/VITRA?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars"> </a> </h1>
        </div>

        <div style="line-height: 1.8;">
            <br>
            <p style="font-size: 16px;">Upload a <strong style="color: #7C4DFF;">landscape</strong>, <strong style="color: #7C4DFF;">egocentric (first-person)</strong> image containing hand(s) and provide instructions to predict future 3D hand trajectories.</p>

            <h3>🌟 Steps:</h3>
            <ol>
                <li>Upload a landscape-view image containing hand(s).</li>
                <li>Enter text instructions describing the desired task.</li>
                <li>Configure advanced settings (optional) and click "Generate 3D Hand Trajectory".</li>
            </ol>

            <h3>💡 Tips:</h3>
            <ul>
                <li><strong>Use Left/Right Hand</strong>: Select which hand to predict based on what is detected and what you want to predict.</li>
                <li><strong>Instruction</strong>: Provide clear, specific imperative instructions separately for the left and right hands, and enter them in the corresponding fields. If the results are unsatisfactory, <strong style="color: #7C4DFF;">try providing more detailed instructions</strong> (e.g., color, orientation, etc.).</li>
                <li>For best inference quality, <strong style="color: #7C4DFF;">capture landscape-view images from a camera height close to that of a human head</strong>. Highly unusual or distorted hand poses/positions may cause inference failures.</li>
                <li>Note that each generation produces only a single action chunk starting from the current state, which <strong style="color: #7C4DFF;">does not necessarily complete the entire task</strong>. Executing an entire chunk in one step may lead to reduced precision.</li>
            </ul>

        </div>

        <hr style='border: none; border-top: 1px solid #e0e0e0; margin: 20px 0;'>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("""
                <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; border-radius: 8px; margin-bottom: 15px;'>
                    <h3 style='color: white; margin: 0; text-align: center;'>📄 Input</h3>
                </div>
                """)

                input_image = gr.Image(
                    label="🖼️ Upload Image with Hands",
                    type="pil",
                    height=300,
                )

                recon_status = gr.Textbox(
                    label="🔍 Hand Reconstruction Status",
                    value="⏳ Waiting for image upload...",
                    interactive=False,
                    lines=2,
                    container=True
                )

                gr.Markdown("---")
                gr.HTML("""
                <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; border-radius: 8px; margin-bottom: 15px;'>
                    <h3 style='color: white; margin: 0; text-align: center;'>⚙️ Prediction Settings</h3>
                </div>
                """)
                gr.HTML("""
                <div style='padding: 8px; background-color: #e8eaf6; border-left: 4px solid #5c6bc0; border-radius: 4px; margin-bottom: 10px;'>
                    <strong style='color: #3949ab;'>👋 Select Hands:</strong>
                </div>
                """)
                with gr.Row():
                    use_left = gr.Checkbox(label="Use Left Hand", value=True)
                    use_right = gr.Checkbox(label="Use Right Hand", value=True)

                gr.HTML("""
                <div style='padding: 8px; background-color: #e8eaf6; border-left: 4px solid #5c6bc0; border-radius: 4px; margin: 15px 0 10px 0;'>
                    <strong style='color: #3949ab;'>✍️ Instructions:</strong>
                </div>
                """)
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            gr.HTML("<div style='display: flex; align-items: center; min-height: 40px; padding-right: 2px;'><span style='font-weight: 600; color: #5c6bc0; white-space: nowrap;'>Left hand:</span></div>")
                            left_instruction = gr.Textbox(
                                label="",
                                value="Put the trash into the garbage.",
                                lines=1,
                                max_lines=5,
                                placeholder="Describe left hand action...",
                                show_label=False,
                                interactive=True,
                                scale=3
                            )
                    with gr.Column():
                        with gr.Row():
                            gr.HTML("<div style='display: flex; align-items: center; min-height: 40px; padding-right: 2px;'><span style='font-weight: 600; color: #5c6bc0; white-space: nowrap;'>Right hand:</span></div>")
                            right_instruction = gr.Textbox(
                                label="",
                                value="None.",
                                lines=1,
                                max_lines=5,
                                placeholder="Describe right hand action...",
                                show_label=False,
                                interactive=True,
                                scale=3
                            )

                final_instruction = gr.HTML(
                    value="""<div style='padding: 12px; background-color: #f0f7ff; border-left: 4px solid #4A90E2; border-radius: 4px; margin-top: 10px;'>
<strong style='color: #2c5282;'>📝 Final Instruction:</strong><br>
<span style='color: #1a365d; font-size: 14px;'>Left hand: Put the trash into the garbage. Right hand: None.</span>
</div>""",
                    show_label=False
                )
                final_instruction_text = gr.State(value="Left hand: Put the trash into the garbage. Right hand: None.")

                with gr.Accordion("🔧 Advanced Settings", open=False):
                    sample_times = gr.Slider(
                        minimum=1,
                        maximum=9,
                        value=4,
                        step=1,
                        label="Number of Samples",
                        info="Multiple samples show different possible trajectories."
                    )
                    num_ddim_steps = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=10,
                        step=5,
                        label="DDIM Steps",
                        info="DDIM steps of the diffusion model. 10 is usually sufficient."
                    )
                    cfg_scale = gr.Slider(
                        minimum=1.0,
                        maximum=15.0,
                        value=5.0,
                        step=0.5,
                        label="CFG Scale",
                        info="Classifier-free guidance scale of the diffusion model."
                    )

                generate_btn = gr.Button("🎬 Generate 3D Hand Trajectory", variant="primary", size="lg")

                hand_data = gr.State(value=None)
                detected_left = gr.State(value=False)
                detected_right = gr.State(value=False)

                session_state = gr.State(value={})

            with gr.Column(scale=1):
                gr.HTML("""
                <div style='background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%); padding: 15px; border-radius: 8px; margin-bottom: 15px;'>
                    <h3 style='color: white; margin: 0; text-align: center;'>🎬 Output</h3>
                </div>
                """)

                output_video = gr.Video(
                    label="🎬 Predicted Hand Motion",
                    height=500,
                    autoplay=True
                )

                gen_status = gr.Textbox(
                    label="📊 Generation Status",
                    value="",
                    interactive=False,
                    lines=2
                )

                gr.Markdown("---")
                gr.HTML("""
                <div style='background: linear-gradient(135deg, #89f7fe 0%, #66a6ff 100%); padding: 15px; border-radius: 8px; margin: 20px 0 10px 0;'>
                    <h3 style='color: white; margin: 0; text-align: center;'>📋 Examples</h3>
                </div>
                """)
                gr.HTML("""
                <div style='padding: 10px; background-color: #e7f3ff; border-left: 4px solid #2196F3; border-radius: 4px; margin-bottom: 15px;'>
                    <span style='color: #1565c0;'>👆 Click any example below to load the image and instruction</span>
                </div>
                """)

                examples_images, instructions_map = load_examples()

                example_gallery = gr.Gallery(
                    value=[img[0] for img in examples_images],
                    label="",
                    columns=6,
                    height="450",
                    object_fit="contain",
                    show_label=False
                )

                def load_example_from_gallery(evt: gr.SelectData):
                    selected_index = evt.index
                    if selected_index < len(examples_images):
                        img_path = examples_images[selected_index][0]
                        instruction_text = instructions_map.get(img_path, "")

                        left_text, right_text = parse_instruction(instruction_text)

                        return gr.update(value=img_path), gr.update(value=left_text), gr.update(value=right_text), gr.update(interactive=False)
                    return gr.update(), gr.update(), gr.update(), gr.update()

                example_gallery.select(
                    fn=load_example_from_gallery,
                    inputs=[],
                    outputs=[input_image, left_instruction, right_instruction, generate_btn],
                    show_progress=False
                ).then(
                    fn=update_final_instruction,
                    inputs=[left_instruction, right_instruction, use_left, use_right],
                    outputs=[final_instruction, final_instruction_text],
                    show_progress=False
                )

        # Uploading an image triggers validation + hand reconstruction, then
        # syncs the checkboxes/textboxes with the detected hands.
        input_image.change(
            fn=validate_and_process_wrapper,
            inputs=[input_image, session_state],
            outputs=[recon_status, generate_btn, hand_data, detected_left, detected_right, session_state],
            show_progress='full'
        ).then(
            fn=update_checkboxes,
            inputs=[detected_left, detected_right],
            outputs=[use_left, use_right, left_instruction, right_instruction],
            show_progress=False
        )

        # Toggling a hand updates textbox interactivity and the final instruction.
        use_left.change(
            fn=update_instruction_interactivity,
            inputs=[use_left, use_right],
            outputs=[left_instruction, right_instruction],
            show_progress=False
        ).then(
            fn=update_final_instruction,
            inputs=[left_instruction, right_instruction, use_left, use_right],
            outputs=[final_instruction, final_instruction_text],
            show_progress=False
        )

        use_right.change(
            fn=update_instruction_interactivity,
            inputs=[use_left, use_right],
            outputs=[left_instruction, right_instruction],
            show_progress=False
        ).then(
            fn=update_final_instruction,
            inputs=[left_instruction, right_instruction, use_left, use_right],
            outputs=[final_instruction, final_instruction_text],
            show_progress=False
        )

        # Editing either textbox rebuilds the combined instruction.
        left_instruction.change(
            fn=update_final_instruction,
            inputs=[left_instruction, right_instruction, use_left, use_right],
            outputs=[final_instruction, final_instruction_text],
            show_progress=False
        )

        right_instruction.change(
            fn=update_final_instruction,
            inputs=[left_instruction, right_instruction, use_left, use_right],
            outputs=[final_instruction, final_instruction_text],
            show_progress=False
        )

        generate_btn.click(
            fn=generate_prediction,
            inputs=[final_instruction_text, use_left, use_right, sample_times, num_ddim_steps, cfg_scale, hand_data, input_image],
            outputs=[output_video, gen_status],
            show_progress='full'
        )

    return demo


if __name__ == "__main__":
    # Launch the Gradio app.
    print("Initializing services...")
    init_msg = initialize_services()
    print(init_msg)

    if "Failed" in init_msg:
        print("⚠️ Services failed to initialize. Please check the configuration and try again.")

    demo = create_gradio_interface()

    demo.launch()
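
# For local debugging, demo.launch(share=True) (a standard Gradio option)
# would additionally create a temporary public URL; the plain launch() above
# is what a hosted Space expects.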