---
library_name: transformers
tags:
- medical
license: mit
pipeline_tag: image-segmentation
---

```py
import os
import cv2
import numpy as np
import torch
import albumentations as album
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import warnings
warnings.filterwarnings("ignore")

class Pipeline:
    """Polyp segmentation inference wrapper around a pickled PyTorch model.

    The model is loaded as a whole object with ``torch.load``. Input images are
    resized to ``img_size`` and normalised with the encoder's ImageNet
    preprocessing. ``predict`` returns a binary polyp mask, a colour-coded class
    mask and the polyp-channel heatmap, all at the input resolution.
    """

    def __init__(self, model_path, device=None, encoder_name='efficientnet-b3'):
        """Load the model and the encoder preprocessing function.

        Args:
            model_path: Path to a checkpoint holding a whole model object (``.eval()``
                is called on whatever ``torch.load`` returns, so a bare ``state_dict``
                will fail).
            device: Torch device spec; defaults to ``"cuda"`` when available, else ``"cpu"``.
            encoder_name: ``segmentation_models_pytorch`` encoder whose ImageNet
                preprocessing is applied; should match the encoder used in training.

        Raises:
            Exception: Whatever loading or ``.eval()`` raised; it is printed, then re-raised.
        """
        import inspect

        # (width, height) -- the argument order cv2.resize expects.
        self.img_size = (384, 288)
        self.classes = ['background', 'polyp']
        # RGB colour per class index, aligned with self.classes.
        self.class_rgb_values = [
            [0, 0, 0],
            [255, 0, 0]
        ]

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        print(f"Using device: {self.device}")

        load_kwargs = {'map_location': self.device}
        # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle a
        # full nn.Module. Opt out explicitly where the argument exists.
        # SECURITY: this unpickles arbitrary objects -- only load trusted checkpoints.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False

        try:
            self.model = torch.load(model_path, **load_kwargs)
            self.model.eval()
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise

        self.preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder_name, 'imagenet')

    def preprocess_image(self, image):
        """Resize and normalise an ``(H, W, 3)`` image into a ``(1, 3, h, w)`` float32 tensor.

        ``(w, h)`` is ``self.img_size``; the tensor is moved to ``self.device``.
        """
        image = cv2.resize(image, self.img_size, interpolation=cv2.INTER_AREA)

        preprocessing = self._get_preprocessing(self.preprocessing_fn)
        image = preprocessing(image=image)['image']

        return torch.from_numpy(image).unsqueeze(0).to(self.device)

    def _get_preprocessing(self, preprocessing_fn=None):
        """Build an albumentations pipeline: optional normalisation, then HWC -> CHW float32."""
        _transform = []
        if preprocessing_fn:
            _transform.append(album.Lambda(image=preprocessing_fn))
        _transform.append(album.Lambda(image=self._to_tensor))
        return album.Compose(_transform)

    def _to_tensor(self, x, **kwargs):
        """Transpose an ``(H, W, C)`` array to ``(C, H, W)`` float32; extra kwargs are ignored."""
        return x.transpose(2, 0, 1).astype('float32')

    def _reverse_one_hot(self, image):
        """Return the per-pixel argmax over the last (class) axis."""
        return np.argmax(image, axis=-1)

    def _colour_code_segmentation(self, image):
        """Map a class-index image to an ``(..., 3)`` uint8 RGB image via ``class_rgb_values``."""
        # uint8 keeps the result usable by cv2.resize (OpenCV 4.x has no int64 depth).
        colour_codes = np.array(self.class_rgb_values, dtype=np.uint8)
        return colour_codes[image.astype(int)]

    def predict(self, image):
        """Segment polyps in *image*.

        Args:
            image: ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)`` or ``(H, W, 4)`` array in RGB(A)
                order. Grayscale is replicated to three channels; alpha is dropped.

        Returns:
            dict, every entry at ``(H, W)`` resolution:
              ``'binary_mask'``: uint8 0/1 mask where the polyp channel is > 0.5;
              ``'colored_mask'``: ``(H, W, 3)`` uint8 colour image of the argmax class;
              ``'heatmap'``: polyp-channel output of the model.
        """
        if image.ndim == 2:
            image = image[:, :, np.newaxis]
        if image.shape[2] == 1:
            # Replicate gray into R, G and B (same values as cv2.COLOR_GRAY2RGB).
            image = np.repeat(image, 3, axis=2)
        elif image.shape[2] == 4:
            image = image[:, :, :3]

        original_h, original_w = image.shape[:2]

        x_tensor = self.preprocess_image(image)

        with torch.no_grad():
            # Take batch item 0 rather than squeeze(), which would also drop any
            # other size-1 dimension.
            pred_mask = self.model(x_tensor)[0].detach().cpu().numpy()
        # (C, h, w) -> (h, w, C); assumes the model returns (1, C, h, w) -- TODO confirm.
        pred_mask = np.transpose(pred_mask, (1, 2, 0))

        polyp_heatmap = pred_mask[:, :, self.classes.index('polyp')]

        # NOTE(review): 0.5 assumes the model already applies an activation -- confirm.
        binary_mask = (polyp_heatmap > 0.5).astype(np.uint8)

        colored_mask = self._colour_code_segmentation(self._reverse_one_hot(pred_mask))

        # img_size is (w, h); compare against the input's (h, w).
        if (original_h, original_w) != self.img_size[::-1]:
            target = (original_w, original_h)
            binary_mask = cv2.resize(binary_mask, target, interpolation=cv2.INTER_NEAREST)
            colored_mask = cv2.resize(colored_mask, target, interpolation=cv2.INTER_NEAREST)
            polyp_heatmap = cv2.resize(polyp_heatmap, target, interpolation=cv2.INTER_LINEAR)

        return {
            'binary_mask': binary_mask,
            'colored_mask': colored_mask,
            'heatmap': polyp_heatmap
        }

    def visualize_prediction(self, image, prediction, save_path=None):
        """Plot image, heatmap, binary mask and overlay side by side.

        Args:
            image: ``(H, W, 3)`` RGB image that *prediction* was computed from.
            prediction: Dict returned by ``predict``.
            save_path: If given, the figure is also saved there at 150 dpi.

        Returns:
            Copy of *image* where pixels in the binary mask are blended 50/50 with
            the colour-coded mask.
        """
        fig, axs = plt.subplots(1, 4, figsize=(20, 5))

        axs[0].imshow(image)
        axs[0].set_title('Original Image')
        axs[0].axis('off')

        axs[1].imshow(prediction['heatmap'], cmap='jet')
        axs[1].set_title('Polyp Probability Heatmap')
        axs[1].axis('off')

        axs[2].imshow(prediction['binary_mask'], cmap='gray')
        axs[2].set_title('Binary Polyp Mask')
        axs[2].axis('off')

        overlay = image.copy()
        colored_mask = prediction['colored_mask']
        mask_condition = prediction['binary_mask'] > 0
        overlay[mask_condition] = overlay[mask_condition] * 0.5 + colored_mask[mask_condition] * 0.5

        axs[3].imshow(overlay)
        axs[3].set_title('Polyp Overlay')
        axs[3].axis('off')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Visualization saved to {save_path}")

        plt.show()
        # Release the figure; repeated calls would otherwise accumulate open figures.
        plt.close(fig)

        return overlay

def main(model_path='./best_model.pth', image_path='test.png'):
    """Segment one image, show/save the visualisation and print polyp coverage.

    Args:
        model_path: Serialized model passed to ``Pipeline``.
        image_path: Image to segment. The visualisation is written to
            ``prediction_result.png`` in the working directory.
    """
    # Validate the input before paying for model loading.
    if not os.path.exists(image_path):
        print(f"Image not found at {image_path}")
        return

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None (rather than raising) for unreadable/unsupported files.
        print(f"Could not read image at {image_path}")
        return
    # OpenCV loads BGR; the pipeline and matplotlib expect RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    pipeline = Pipeline(model_path)
    prediction = pipeline.predict(image)

    pipeline.visualize_prediction(image, prediction, save_path='prediction_result.png')

    # binary_mask holds 0/1, so its mean is the fraction of polyp pixels.
    polyp_percentage = np.mean(prediction['binary_mask']) * 100
    print(f"Polyp covers approximately {polyp_percentage:.2f}% of the image")


if __name__ == "__main__":
    main()
```