#!/usr/bin/env python3
"""Mask PII in a single text with a GlobalPointer span model.

Loads a model directory with ``IrishCoreGlobalPointerModel``, scores candidate
spans for ``--text`` and prints either the masked text or, with ``--json``, a
report containing the detected spans and the masked text.
"""
from __future__ import annotations

import argparse
import json
import os
import warnings

# These must be set before torch/transformers are imported below; they are
# intended to keep transformers on its PyTorch backend and avoid loading the
# TensorFlow/Flax (and presumably torchvision) integrations at import time.
os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
os.environ.setdefault("TRANSFORMERS_NO_FLAX", "1")
os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
os.environ["USE_TF"] = "0"
os.environ["USE_FLAX"] = "0"
os.environ["USE_TORCH"] = "1"

import torch  # noqa: E402
from transformers import AutoConfig  # noqa: E402

from common import decode_span_matrix, safe_auto_tokenizer  # noqa: E402
from model import IrishCoreGlobalPointerModel  # noqa: E402


def replacement(label: str) -> str:
    """Return the placeholder text that replaces a span labelled *label*."""
    return f"[PII:{label}]"


def _merge_overlapping(spans: list[dict]) -> list[tuple[int, int, str]]:
    """Collapse possibly-overlapping spans into disjoint ``(start, end, label)`` regions.

    Spans are visited by ascending start, longest first on ties. A span that
    starts inside the current region widens that region to cover it, and the
    region keeps the label of the span that opened it. Spans that merely touch
    (``start == previous end``) stay separate.
    """
    regions: list[list] = []
    for span in sorted(spans, key=lambda item: (item["start"], -item["end"])):
        start, end = span["start"], span["end"]
        if regions and start < regions[-1][1]:
            regions[-1][1] = max(regions[-1][1], end)
        else:
            regions.append([start, end, span["label"]])
    return [(start, end, label) for start, end, label in regions]


def mask_text(text: str, spans: list[dict]) -> str:
    """Return *text* with each span's characters replaced by its placeholder.

    Args:
        text: The original input text.
        spans: Dicts with character offsets ``"start"``/``"end"`` (used as a
            slice, so ``end`` is exclusive) and a ``"label"``.

    Overlapping or nested spans, if present, are merged first. Each character
    is therefore masked at most once, and no substitution can shift the
    offsets of another span. Non-overlapping spans are replaced one-to-one.
    """
    pieces: list[str] = []
    cursor = 0
    for start, end, label in _merge_overlapping(spans):
        pieces.append(text[cursor:start])
        pieces.append(replacement(label))
        cursor = end
    pieces.append(text[cursor:])
    return "".join(pieces)


def predict(text: str, model, tokenizer, min_score: float):
    """Detect PII spans in *text*.

    Args:
        text: Input text. It is tokenized with ``truncation=True``, so only the
            part that fits the tokenizer's maximum length is scored.
        model: Model whose output exposes ``span_logits``; only batch item 0
            is used.
        tokenizer: Tokenizer that supports ``return_offsets_mapping``.
        min_score: Score threshold forwarded to ``decode_span_matrix``.

    Returns:
        The span dicts from ``decode_span_matrix``, each extended with a
        ``"replacement"`` key holding its placeholder text.

    Warns:
        UserWarning: if truncation appears to have dropped non-whitespace
            text. That tail is never scanned and would be left unmasked.
    """
    encoded = tokenizer(text, return_offsets_mapping=True, return_tensors="pt", truncation=True)
    offsets = [tuple(item) for item in encoded.pop("offset_mapping")[0].tolist()]

    # The largest end offset is how far into ``text`` the (possibly truncated)
    # encoding reaches; special tokens typically carry (0, 0) and don't affect it.
    covered = max((end for _, end in offsets), default=0)
    if text[covered:].strip():
        warnings.warn(
            "Input appears truncated to the tokenizer's maximum length; "
            f"text after character {covered} was not scanned and will not be masked.",
            stacklevel=2,
        )

    device = next(model.parameters()).device
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():
        output = model(**encoded)
    # Upcast before sigmoid/NumPy conversion: NumPy has no bfloat16, so a bf16
    # model would otherwise crash here. A no-op for float32 models.
    span_scores = torch.sigmoid(output.span_logits[0].float()).cpu().numpy()
    spans = decode_span_matrix(text, offsets, span_scores, model.config, min_score)
    for span in spans:
        span["replacement"] = replacement(span["label"])
    return spans


def main() -> None:
    """Parse CLI arguments, run the model on ``--text`` and print the result."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--text", required=True)
    parser.add_argument("--min-score", type=float, default=0.5)
    parser.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")
    parser.add_argument("--json", action="store_true")
    args = parser.parse_args()

    tokenizer = safe_auto_tokenizer(args.model)
    config = AutoConfig.from_pretrained(args.model)
    model = IrishCoreGlobalPointerModel.from_pretrained(args.model, config=config)

    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    model.to(device)
    model.eval()

    spans = predict(args.text, model, tokenizer, args.min_score)
    result = {
        "model": args.model,
        "backend": "transformers_global_pointer",
        "min_score": args.min_score,
        "spans": spans,
        "masked_text": mask_text(args.text, spans),
    }
    if args.json:
        print(json.dumps(result, indent=2, ensure_ascii=False))
    else:
        print(result["masked_text"])


if __name__ == "__main__":
    main()