---
library_name: transformers
tags:
- medical
license: mit
pipeline_tag: image-segmentation
---

```py
import os
import warnings

import albumentations as album
import cv2
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import torch
from albumentations.pytorch import ToTensorV2

warnings.filterwarnings("ignore")


class Pipeline:
    """Run polyp segmentation inference with a pickled segmentation model.

    The pipeline resizes an RGB image to a fixed size, applies the
    ImageNet preprocessing of the ``efficientnet-b3`` encoder, runs the model
    and converts the output into a binary mask, a colour-coded mask and a
    heatmap, all resized back to the input resolution.
    """

    def __init__(self, model_path, device=None):
        """Load the model and set up preprocessing.

        Args:
            model_path: Path to a model object saved with ``torch.save``.
            device: Anything accepted by ``torch.device``. When ``None``,
                CUDA is used if available, otherwise the CPU.

        Raises:
            Exception: Whatever ``torch.load`` / ``model.eval()`` raises; the
                error is printed and re-raised unchanged.
        """
        # (width, height): the order cv2.resize expects for its dsize argument.
        self.img_size = (384, 288)
        self.classes = ['background', 'polyp']
        # One RGB colour per entry of self.classes, in the same order.
        self.class_rgb_values = [
            [0, 0, 0],
            [255, 0, 0],
        ]

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"Using device: {self.device}")

        try:
            self.model = self._load_model(model_path)
            self.model.eval()
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise

        encoder_name = 'efficientnet-b3'
        self.preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder_name, 'imagenet')

    def _load_model(self, model_path):
        """Unpickle the full model object stored at ``model_path``."""
        # SECURITY NOTE: this unpickles arbitrary Python objects. Only load
        # checkpoints from a source you trust.
        try:
            # torch >= 2.6 defaults to weights_only=True, which rejects a
            # fully pickled model; request the legacy behaviour explicitly.
            return torch.load(model_path, map_location=self.device, weights_only=False)
        except TypeError:
            # Older torch releases do not know the weights_only keyword.
            return torch.load(model_path, map_location=self.device)

    def preprocess_image(self, image):
        """Resize and normalise ``image`` and return a batched tensor.

        Args:
            image: H x W x 3 array (``predict`` converts other layouts first).

        Returns:
            A float32 tensor of shape (1, 3, 288, 384) on ``self.device``.
        """
        image = cv2.resize(image, self.img_size, interpolation=cv2.INTER_AREA)
        preprocessing = self._get_preprocessing(self.preprocessing_fn)
        sample = preprocessing(image=image)
        image = sample['image']
        return torch.from_numpy(image).unsqueeze(0).to(self.device)

    def _get_preprocessing(self, preprocessing_fn=None):
        """Build the albumentations pipeline: optional normalisation, then HWC -> CHW."""
        _transform = []
        if preprocessing_fn:
            _transform.append(album.Lambda(image=preprocessing_fn))
        _transform.append(album.Lambda(image=self._to_tensor))
        return album.Compose(_transform)

    def _to_tensor(self, x, **kwargs):
        """Move channels first (HWC -> CHW) and cast to float32."""
        return x.transpose(2, 0, 1).astype('float32')

    def _reverse_one_hot(self, image):
        """Collapse the last (class) axis to the index of its largest value."""
        return np.argmax(image, axis=-1)

    def _colour_code_segmentation(self, image):
        """Map an array of class indices to the RGB colours in ``class_rgb_values``."""
        # uint8 keeps the mask a regular 8-bit image: cv2.resize does not
        # accept NumPy's default int64 arrays.
        colour_codes = np.array(self.class_rgb_values, dtype=np.uint8)
        return colour_codes[image.astype(int)]

    def predict(self, image):
        """Segment ``image`` and return masks at the input resolution.

        Args:
            image: Grayscale (H x W or H x W x 1), RGB (H x W x 3) or RGBA
                (H x W x 4) array. The alpha channel of RGBA input is dropped.

        Returns:
            dict with keys:
                ``'binary_mask'``: uint8 H x W array, 1 where the polyp score
                    exceeds 0.5.
                ``'colored_mask'``: uint8 H x W x 3 colour-coded argmax mask.
                ``'heatmap'``: H x W array of the polyp channel of the model
                    output.
        """
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 1:
            # Single-channel 3-D input: replicate the channel to get RGB.
            image = np.repeat(image, 3, axis=2)
        elif image.shape[2] == 4:
            image = image[:, :, :3]

        original_h, original_w = image.shape[:2]
        x_tensor = self.preprocess_image(image)

        with torch.no_grad():
            pred_mask = self.model(x_tensor)

        # Drop only the batch axis, then go CHW -> HWC.
        # NOTE(review): assumes the model returns (1, C, H, W) per-class
        # probabilities (the 0.5 threshold below relies on that) - confirm.
        pred_mask = pred_mask.detach().squeeze(0).cpu().numpy()
        pred_mask = np.transpose(pred_mask, (1, 2, 0))

        polyp_heatmap = pred_mask[:, :, self.classes.index('polyp')]
        binary_mask = (polyp_heatmap > 0.5).astype(np.uint8)
        colored_mask = self._colour_code_segmentation(self._reverse_one_hot(pred_mask))

        # img_size is (width, height); reverse it to compare with (h, w).
        if (original_h, original_w) != self.img_size[::-1]:
            target_size = (original_w, original_h)
            # Nearest-neighbour keeps label images free of interpolated values.
            binary_mask = cv2.resize(binary_mask, target_size, interpolation=cv2.INTER_NEAREST)
            colored_mask = cv2.resize(colored_mask, target_size, interpolation=cv2.INTER_NEAREST)
            polyp_heatmap = cv2.resize(polyp_heatmap, target_size, interpolation=cv2.INTER_LINEAR)

        return {
            'binary_mask': binary_mask,
            'colored_mask': colored_mask,
            'heatmap': polyp_heatmap,
        }

    def visualize_prediction(self, image, prediction, save_path=None):
        """Plot the image, heatmap, binary mask and overlay side by side.

        Args:
            image: The RGB image that was passed to ``predict``.
            prediction: The dict returned by ``predict`` for ``image``.
            save_path: Optional path; when given, the figure is also saved.

        Returns:
            A copy of ``image`` with the predicted polyp pixels blended 50/50
            with the colour-coded mask.
        """
        fig, axs = plt.subplots(1, 4, figsize=(20, 5))

        axs[0].imshow(image)
        axs[0].set_title('Original Image')
        axs[0].axis('off')

        axs[1].imshow(prediction['heatmap'], cmap='jet')
        axs[1].set_title('Polyp Probability Heatmap')
        axs[1].axis('off')

        axs[2].imshow(prediction['binary_mask'], cmap='gray')
        axs[2].set_title('Binary Polyp Mask')
        axs[2].axis('off')

        overlay = image.copy()
        colored_mask = prediction['colored_mask']
        mask_condition = prediction['binary_mask'] > 0
        blended = overlay[mask_condition] * 0.5 + colored_mask[mask_condition] * 0.5
        # The blend is computed in float; cast back to the image dtype explicitly.
        overlay[mask_condition] = blended.astype(overlay.dtype)

        axs[3].imshow(overlay)
        axs[3].set_title('Polyp Overlay')
        axs[3].axis('off')

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Visualization saved to {save_path}")
        plt.show()
        return overlay


def main():
    """Run the pipeline on ``test.png`` and report the polyp coverage."""
    model_path = './best_model.pth'  # Path to model
    pipeline = Pipeline(model_path)

    image_path = 'test.png'
    if not os.path.exists(image_path):
        print(f"Image not found at {image_path}")
        return

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals an unreadable/undecodable file by returning None.
        print(f"Failed to read image at {image_path}")
        return
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    prediction = pipeline.predict(image)
    pipeline.visualize_prediction(image, prediction, save_path='prediction_result.png')

    polyp_percentage = np.mean(prediction['binary_mask']) * 100
    print(f"Polyp covers approximately {polyp_percentage:.2f}% of the image")


if __name__ == "__main__":
    main()
```