hallucination-guard-env / inference.py
SamSankar's picture
Upload inference.py
6d1b06e verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
inference.py β€” HallucinationGuard-Env Inference Script
=======================================================
Mandatory submission script for the Meta PyTorch OpenEnv Hackathon 2026.
Environment variables (set before running):
API_BASE_URL The API endpoint for the LLM (e.g. https://huggingface.co/proxy/router.huggingface.co/v1)
MODEL_NAME The model identifier (e.g. Qwen/Qwen2.5-72B-Instruct)
HF_TOKEN Your HuggingFace API key
Usage:
export API_BASE_URL="https://huggingface.co/proxy/router.huggingface.co/v1"
export MODEL_NAME="Qwen/Qwen2.5-72B-Instruct"
export HF_TOKEN="hf_..."
python inference.py
# Dry-run without API key (heuristic agent):
python inference.py --heuristic
# Run against local dev server:
python inference.py --env-url http://localhost:7860
Expected baseline scores (heuristic agent, seed=42, 3 episodes x 5 steps):
task_1_factual_grounding : ~0.29
task_2_multi_hop_synthesis : ~0.25
task_3_adversarial_resistance : ~0.22
overall : ~0.25
"""
from __future__ import annotations
import os
# Fix Unicode encoding for Windows console
os.environ['PYTHONIOENCODING'] = 'utf-8'
import sys
import json
import time
import argparse
import logging
from typing import Dict, Any, List, Optional, Callable
import requests
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)
# ── Structured stdout logging for hackathon evaluation ──────────────────────────
# Required format:
# [START] task=<task_name> env=<benchmark> model=<model_name>
# [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
# [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
BENCHMARK = "hallucination-guard-env"
def log_start(task: str, env: str, model: str) -> None:
"""Emit [START] log in required format."""
print(f"[START] task={task} env={env} model={model}", flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str] = None) -> None:
"""Emit [STEP] log in required format."""
error_val = error if error else "null"
done_val = str(done).lower()
# Truncate action if too long and handle Unicode
action_trunc = action[:200].replace("\n", " ") if len(action) > 200 else action.replace("\n", " ")
# Replace non-ASCII characters to avoid encoding issues
action_trunc = action_trunc.encode('ascii', 'replace').decode('ascii')
print(f"[STEP] step={step} action={action_trunc} reward={reward:.2f} done={done_val} error={error_val}", flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
"""Emit [END] log in required format."""
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
# ── Mandatory environment variables ──────────────────────────────────────────
API_BASE_URL = os.getenv("API_BASE_URL", "https://huggingface.co/proxy/router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN", "")
# ── Defaults ──────────────────────────────────────────────────────────────────
DEFAULT_ENV_URL = os.environ.get(
"HALLUGUARD_ENV_URL",
"https://huggingface.co/proxy/samsankar-hallucination-guard-env.hf.space",
)
DEFAULT_EPISODES = 3
DEFAULT_STEPS = 5
SEED = 42
TASK_ORDER = [
("task_1_factual_grounding", "beginner"),
("task_2_multi_hop_synthesis", "intermediate"),
("task_3_adversarial_resistance", "advanced"),
]
SYSTEM_PROMPT = """You are a precise, grounded question-answering assistant.
RULES (follow strictly):
1. Answer ONLY using information present in the CONTEXT provided.
2. If the answer is not in the context, say exactly: "I cannot answer from the provided context."
3. Keep answers concise β€” 1-3 sentences maximum.
4. Never fabricate facts, names, dates, or numbers not in the context.
5. If uncertain, express that uncertainty explicitly in your answer.
"""
ANSWER_PROMPT_TEMPLATE = """CONTEXT:
{context}
QUESTION:
{question}
Instructions:
- Answer using ONLY the context above.
- Provide a source_quote: a short verbatim phrase from the context that supports your answer.
- Rate your confidence from 0.0 (unsure) to 1.0 (certain).
Respond in JSON with these exact keys:
{{
"answer": "<your answer>",
"source_quote": "<verbatim phrase from context>",
"confidence": <float 0.0-1.0>
}}"""
# ── Environment client ────────────────────────────────────────────────────────
class EnvClient:
"""Thin HTTP wrapper around the HallucinationGuard REST API."""
def __init__(self, base_url: str, timeout: int = 300):
self.base = base_url.rstrip("/")
self.timeout = timeout
self.session = requests.Session()
self._session_id: Optional[str] = None
def _request_with_retry(self, method: str, path: str, body: Dict[str, Any] = None, retries: int = 3, backoff: float = 2.0) -> Dict[str, Any]:
url = f"{self.base}{path}"
for attempt in range(retries):
try:
if method == "GET":
r = self.session.get(url, timeout=self.timeout)
else:
r = self.session.post(url, json=body, timeout=self.timeout)
r.raise_for_status()
return r.json()
except (requests.exceptions.ChunkedEncodingError,
requests.exceptions.ConnectionError,
requests.exceptions.ReadTimeout) as e:
if attempt < retries - 1:
wait = backoff * (attempt + 1)
logger.warning(f"Request to {path} failed ({type(e).__name__}), retrying in {wait:.0f}s... ({attempt+1}/{retries})")
time.sleep(wait)
else:
raise
def _get(self, path: str) -> Dict[str, Any]:
return self._request_with_retry("GET", path)
def _post(self, path: str, body: Dict[str, Any] = {}) -> Dict[str, Any]:
return self._request_with_retry("POST", path, body)
def health(self) -> Dict[str, Any]:
return self._get("/health")
def list_tasks(self) -> Dict[str, Any]:
return self._get("/tasks")
def reset(self, difficulty: str, seed: int) -> Dict[str, Any]:
result = self._post("/reset", {"difficulty": difficulty, "seed": seed})
self._session_id = result.get("session_id")
return result
def step(self, answer: str, confidence: float, source_quote: str) -> Dict[str, Any]:
body: Dict[str, Any] = {
"answer": answer,
"confidence": confidence,
"source_quote": source_quote,
}
if self._session_id:
body["session_id"] = self._session_id
return self._post("/step", body)
def grade(self, task_id: str,
step_rewards: List[float],
step_infos: List[Dict[str, Any]]) -> Dict[str, Any]:
return self._post("/grader", {
"task_id": task_id,
"step_rewards": step_rewards,
"step_infos": step_infos,
})
# ── Agents ────────────────────────────────────────────────────────────────────
def heuristic_agent(question: str, context: str) -> Dict[str, Any]:
"""
Deterministic heuristic baseline β€” no LLM required.
Extracts the first meaningful sentence of the context as the answer.
Used when --heuristic flag is set or no API credentials are available.
"""
sentences = [s.strip() for s in context.replace("\n", " ").split(".") if len(s.strip()) > 10]
answer = sentences[0] if sentences else context[:120]
source_quote = context[:80] if context else ""
return {"answer": answer, "confidence": 0.6, "source_quote": source_quote}
def openai_agent(model: str, base_url: str, api_key: str) -> Callable:
"""
Returns a callable agent backed by any OpenAI-compatible chat endpoint.
Uses API_BASE_URL, MODEL_NAME, HF_TOKEN from environment variables.
"""
try:
from openai import OpenAI
except ImportError:
logger.error("openai package not installed. Run: pip install openai")
sys.exit(1)
if not api_key:
logger.error(
"HF_TOKEN not set. Export it or use --heuristic for the "
"no-API baseline.\n"
" export HF_TOKEN=hf_..."
)
sys.exit(1)
client = OpenAI(base_url=base_url, api_key=api_key)
def _call(question: str, context: str) -> Dict[str, Any]:
prompt = ANSWER_PROMPT_TEMPLATE.format(
context=context[:3000],
question=question,
)
# Try with JSON response format first, fall back to no format
for use_json_format in [True, False]:
try:
kwargs = dict(
model=model,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
],
temperature=0.0,
max_tokens=512,
)
if use_json_format:
kwargs["response_format"] = {"type": "json_object"}
resp = client.chat.completions.create(**kwargs)
raw = (resp.choices[0].message.content or "").strip()
# Strip markdown code fences if present (```json ... ```)
import re
fence_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', raw, re.DOTALL)
if fence_match:
raw = fence_match.group(1)
# Try direct JSON parse
try:
parsed = json.loads(raw)
answer = str(parsed.get("answer", ""))
if answer:
return {
"answer": answer,
"confidence": float(parsed.get("confidence", 0.5)),
"source_quote": str(parsed.get("source_quote", "")),
}
except json.JSONDecodeError:
pass
# Try to extract JSON object from mixed text
json_match = re.search(r'\{[^{}]*"answer"[^{}]*\}', raw, re.DOTALL)
if json_match:
try:
parsed = json.loads(json_match.group(0))
return {
"answer": str(parsed.get("answer", "")),
"confidence": float(parsed.get("confidence", 0.5)),
"source_quote": str(parsed.get("source_quote", "")),
}
except json.JSONDecodeError:
pass
# If JSON format was tried and failed, fall through to no-format attempt
if use_json_format:
logger.warning("JSON parse failed, trying without response_format")
continue
# Last resort: use raw text as answer
return {"answer": raw[:200], "confidence": 0.4, "source_quote": ""}
except Exception as e:
if use_json_format:
error_msg = str(e)
if "response_format" in error_msg.lower() or "json_validate_failed" in error_msg:
logger.warning(f"JSON format not supported, trying without: {e}")
continue
else:
logger.warning(f"LLM call failed: {e}")
return {"answer": "", "confidence": 0.0, "source_quote": ""}
else:
logger.warning(f"LLM call failed: {e}")
return {"answer": "", "confidence": 0.0, "source_quote": ""}
return {"answer": "", "confidence": 0.0, "source_quote": ""}
return _call
# ── Episode runner ────────────────────────────────────────────────────────────
def run_episode(
env: EnvClient,
agent_fn: Callable,
task_id: str,
difficulty: str,
steps: int,
seed: int,
episode_num: int,
model_label: str,
) -> Dict[str, Any]:
"""Run one episode and return rewards + infos for the grader."""
# Emit START log at beginning of each task
if episode_num == 0:
log_start(task=task_id, env=BENCHMARK, model=model_label)
obs = env.reset(difficulty=difficulty, seed=seed + episode_num)
step_rewards: List[float] = []
step_infos: List[Dict[str, Any]] = []
for step_n in range(steps):
if obs.get("done", False):
break
question = obs.get("question", "")
context = obs.get("context", "")
action = agent_fn(question, context)
obs = env.step(
answer=action["answer"],
confidence=action["confidence"],
source_quote=action["source_quote"],
)
reward = float(obs.get("reward") or 0.0)
done = bool(obs.get("done", False))
step_rewards.append(reward)
# Extract metrics from observation metadata (returned by the environment)
obs_metadata = obs.get("metadata", {})
if isinstance(obs_metadata, dict):
obs_correctness = obs_metadata.get("correctness", 0.0)
obs_calibration = obs_metadata.get("calibration", 0.0)
obs_hall_score = obs_metadata.get("hallucination_score", 0.0)
else:
obs_correctness = 0.0
obs_calibration = 0.0
obs_hall_score = 0.0
# Extract ML component scores from reward_breakdown if available
rb = obs_metadata.get("reward_breakdown", {}) if isinstance(obs_metadata, dict) else {}
step_infos.append({
"correctness": obs_correctness,
"grounding": obs.get("grounding_score", 0.0),
"calibration": obs_calibration if obs_calibration else action["confidence"],
"hallucination_score": obs_hall_score if obs_hall_score else (1.0 if obs.get("is_hallucination") else 0.0),
"is_hallucination": bool(obs.get("is_hallucination", False)),
"semantic_consistency": rb.get("semantic_consistency", 0.0),
"rouge_l": rb.get("rouge_l", 0.0),
"bert_score": rb.get("bert_score", 0.0),
"align_score": rb.get("align_score", 0.0),
})
# Format action for logging (truncated answer)
action_str = f'answer="{action["answer"][:100]}" confidence={action["confidence"]:.2f}'
# Emit STEP log
log_step(
step=step_n + 1,
action=action_str,
reward=reward,
done=done,
error=None,
)
status = "HALLUCINATION" if obs.get("is_hallucination") else "OK"
logger.info(
f" [{task_id[:25]}] ep={episode_num+1} step={step_n+1} "
f"reward={reward:.3f} [{status}]"
)
grade = env.grade(task_id, step_rewards, step_infos)
episode_score = grade.get("score", 0.0)
return {
"episode": episode_num + 1,
"score": episode_score,
"rewards": step_rewards,
"grade": grade,
}
# ── Main ──────────────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(
description="HallucinationGuard-Env inference script",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--env-url", default=DEFAULT_ENV_URL, help="Environment URL")
parser.add_argument("--episodes", type=int, default=DEFAULT_EPISODES)
parser.add_argument("--steps", type=int, default=DEFAULT_STEPS)
parser.add_argument("--seed", type=int, default=SEED)
parser.add_argument("--heuristic", action="store_true",
help="Use heuristic agent (no API key needed)")
parser.add_argument("--output", default=None,
help="Write JSON results to this file")
args = parser.parse_args()
# ── Connect to environment ────────────────────────────────────────────────
env = EnvClient(args.env_url)
logger.info(f"Connecting to environment: {args.env_url}")
try:
h = env.health()
logger.info(f" Environment: {h.get('service')} v{h.get('version')} β€” healthy")
except Exception as e:
logger.error(f"Cannot reach environment: {e}")
sys.exit(1)
# Verify /tasks endpoint
try:
tasks_info = env.list_tasks()
task_ids = [t["task_id"] for t in tasks_info.get("tasks", [])]
logger.info(f" Tasks: {task_ids}")
except Exception as e:
logger.error(f"/tasks endpoint failed: {e}")
sys.exit(1)
# ── Select agent ─────────────────────────────────────────────────────────
if args.heuristic or not HF_TOKEN:
logger.info("Using heuristic baseline agent (no LLM).")
agent_fn = heuristic_agent
model_label = "heuristic_baseline"
else:
logger.info(f"Using LLM agent: {MODEL_NAME} via {API_BASE_URL}")
agent_fn = openai_agent(MODEL_NAME, API_BASE_URL, HF_TOKEN)
model_label = MODEL_NAME
# ── Run all 3 tasks ───────────────────────────────────────────────────────
task_results: List[Dict[str, Any]] = []
all_scores: List[float] = []
all_rewards: List[float] = []
total_steps = 0
start_time = time.time()
for task_id, difficulty in TASK_ORDER:
logger.info(f"\n{'='*55}")
logger.info(f"TASK: {task_id} (difficulty={difficulty})")
logger.info(f"{'='*55}")
episode_scores: List[float] = []
task_rewards: List[float] = []
for ep in range(args.episodes):
ep_result = run_episode(
env=env,
agent_fn=agent_fn,
task_id=task_id,
difficulty=difficulty,
steps=args.steps,
seed=args.seed,
episode_num=ep,
model_label=model_label,
)
episode_scores.append(ep_result["score"])
all_scores.append(ep_result["score"])
all_rewards.extend(ep_result["rewards"])
task_rewards.extend(ep_result["rewards"])
total_steps += len(ep_result["rewards"])
task_avg = sum(episode_scores) / max(len(episode_scores), 1)
task_std = (
(sum((s - task_avg) ** 2 for s in episode_scores) / max(len(episode_scores), 1)) ** 0.5
if len(episode_scores) > 1 else 0.0
)
# Emit [END] log for this task
success = task_avg >= 0.5 # Consider success if score >= 0.5
log_end(
success=success,
steps=len(task_rewards),
score=task_avg,
rewards=task_rewards,
)
task_results.append({
"task_id": task_id,
"difficulty": difficulty,
"episodes": args.episodes,
"episode_scores": [round(s, 4) for s in episode_scores],
"avg_score": round(task_avg, 4),
"std_score": round(task_std, 4),
})
logger.info(f"\n Task score: {task_avg:.4f} Β± {task_std:.4f}")
elapsed = time.time() - start_time
overall_score = sum(all_scores) / max(len(all_scores), 1)
avg_reward = sum(all_rewards) / max(len(all_rewards), 1)
summary = {
"model": model_label,
"api_base_url": API_BASE_URL,
"env_url": args.env_url,
"seed": args.seed,
"episodes_per_task": args.episodes,
"steps_per_episode": args.steps,
"total_steps": total_steps,
"elapsed_seconds": round(elapsed, 1),
"tasks": task_results,
"overall": {
"score": round(overall_score, 4),
"avg_reward": round(avg_reward, 4),
},
}
# ── Print results ─────────────────────────────────────────────────────────
print("\n" + "=" * 55)
print("INFERENCE RESULTS")
print("=" * 55)
print(f"Model : {model_label}")
print(f"Seed : {args.seed} | {args.episodes} episodes x {args.steps} steps")
print(f"Elapsed : {elapsed:.1f}s")
print()
for t in task_results:
# Use ASCII characters for progress bar
bar = "#" * round(t["avg_score"] * 20)
print(
f" {t['task_id']:<42} "
f"{t['avg_score']:.4f} +- {t['std_score']:.4f} |{bar:<20}|"
)
print()
print(f" {'OVERALL':<42} {overall_score:.4f}")
print("=" * 55)
if args.output:
with open(args.output, "w") as f:
json.dump(summary, f, indent=2)
logger.info(f"Results written to {args.output}")
return summary
if __name__ == "__main__":
main()