Inference Slicer

InferenceSlicer

InferenceSlicer performs slicing-based inference for small target detection. This method, often referred to as Slicing Aided Hyper Inference (SAHI), divides a larger image into smaller slices, runs inference on each slice, and then merges the detections.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `slice_wh` | `Tuple[int, int]` | Dimensions of each slice in the format (width, height). | `(320, 320)` |
| `overlap_ratio_wh` | `Tuple[float, float]` | Overlap ratio between consecutive slices in the format (width_ratio, height_ratio). | `(0.2, 0.2)` |
| `iou_threshold` | `Optional[float]` | Intersection over Union (IoU) threshold used for non-max suppression. | `0.5` |
| `callback` | `Callable` | A function that performs inference on a given image slice and returns detections. | required |
| `thread_workers` | `int` | Number of threads for parallel execution. | `1` |

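As an illustration of what `thread_workers` controls, the sketch below runs a stand-in callback over two slices with a thread pool. It is a minimal example under stated assumptions, not the library's code path: `count_bright_pixels` and `run_on_slice` are hypothetical helpers, and the slice coordinates are hard-coded. Because `as_completed` yields futures in completion order rather than submission order, each result is returned together with the slice it belongs to.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np


def count_bright_pixels(image_slice: np.ndarray) -> int:
    """Hypothetical stand-in for a model callback."""
    return int((image_slice > 127).sum())


# Synthetic grayscale image: only the right-most 128 columns are bright.
image = np.zeros((480, 640), dtype=np.uint8)
image[:, 512:] = 255

# Two hard-coded slices in (xmin, ymin, xmax, ymax) format.
slices = [(0, 0, 320, 320), (512, 0, 640, 320)]


def run_on_slice(xyxy):
    xmin, ymin, xmax, ymax = xyxy
    # Return the slice coordinates with the result so order does not matter.
    return xyxy, count_bright_pixels(image[ymin:ymax, xmin:xmax])


results = {}
with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(run_on_slice, xyxy) for xyxy in slices]
    for future in as_completed(futures):
        xyxy, value = future.result()
        results[xyxy] = value
```

More threads tend to help only when the callback spends its time waiting on I/O or in code that releases the GIL, such as a remote inference request.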
Note

The class clips every slice to the boundaries of the original image. As a result, slices in the last row and column can be smaller than the specified slice dimensions whenever a full-size slice starting at the last stride position would extend past the image edge.
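The clipping described in the note can be reproduced with a few lines of NumPy. The sketch below mirrors the arithmetic of the `_generate_offset` method in the source listing that follows; the standalone function name `generate_offsets` and the 640x480 image size are assumptions chosen for illustration.

```python
import numpy as np


def generate_offsets(resolution_wh, slice_wh, overlap_ratio_wh):
    """Return an (n, 4) array of [xmin, ymin, xmax, ymax] slice coordinates."""
    image_width, image_height = resolution_wh
    slice_width, slice_height = slice_wh
    overlap_ratio_width, overlap_ratio_height = overlap_ratio_wh

    # The stride is the slice size minus the overlapping part.
    width_stride = slice_width - int(overlap_ratio_width * slice_width)
    height_stride = slice_height - int(overlap_ratio_height * slice_height)

    xmin, ymin = np.meshgrid(
        np.arange(0, image_width, width_stride),
        np.arange(0, image_height, height_stride),
    )
    # Clip so that no slice extends past the image border.
    xmax = np.clip(xmin + slice_width, 0, image_width)
    ymax = np.clip(ymin + slice_height, 0, image_height)
    return np.stack([xmin, ymin, xmax, ymax], axis=-1).reshape(-1, 4)


offsets = generate_offsets(
    resolution_wh=(640, 480), slice_wh=(320, 320), overlap_ratio_wh=(0.2, 0.2)
)
# Stride is 256 in both directions, giving 3 columns x 2 rows = 6 slices.
# The last slice is [512, 256, 640, 480], i.e. only 128 x 224 pixels.
print(offsets)
```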

Source code in supervision/detection/tools/inference_slicer.py
class InferenceSlicer:
    """
    InferenceSlicer performs slicing-based inference for small target detection. This
    method, often referred to as
    [Slicing Aided Hyper Inference (SAHI)](https://ieeexplore.ieee.org/document/9897990),
    involves dividing a larger image into smaller slices, performing inference on each
    slice, and then merging the detections.

    Args:
        slice_wh (Tuple[int, int]): Dimensions of each slice in the format
            `(width, height)`.
        overlap_ratio_wh (Tuple[float, float]): Overlap ratio between consecutive
            slices in the format `(width_ratio, height_ratio)`.
        iou_threshold (Optional[float]): Intersection over Union (IoU) threshold
            used for non-max suppression.
        callback (Callable): A function that performs inference on a given image
            slice and returns detections.
        thread_workers (int): Number of threads for parallel execution.

    Note:
        The class ensures that slices do not exceed the boundaries of the original
        image. As a result, the final slices in the row and column dimensions might be
        smaller than the specified slice dimensions if the image's width or height is
        not a multiple of the slice's width or height minus the overlap.
    """

    def __init__(
        self,
        callback: Callable[[np.ndarray], Detections],
        slice_wh: Tuple[int, int] = (320, 320),
        overlap_ratio_wh: Tuple[float, float] = (0.2, 0.2),
        iou_threshold: Optional[float] = 0.5,
        thread_workers: int = 1,
    ):
        self.slice_wh = slice_wh
        self.overlap_ratio_wh = overlap_ratio_wh
        self.iou_threshold = iou_threshold
        self.callback = callback
        self.thread_workers = thread_workers

    def __call__(self, image: np.ndarray) -> Detections:
        """
        Performs slicing-based inference on the provided image using the specified
            callback.

        Args:
            image (np.ndarray): The input image on which inference needs to be
                performed. The image should be in the format
                `(height, width, channels)`.

        Returns:
            Detections: A collection of detections for the entire image after merging
                results from all slices and applying NMS.

        Example:
            ```python
            import cv2
            import numpy as np
            import supervision as sv
            from ultralytics import YOLO

            image = cv2.imread(SOURCE_IMAGE_PATH)
            model = YOLO(...)

            def callback(image_slice: np.ndarray) -> sv.Detections:
                result = model(image_slice)[0]
                return sv.Detections.from_ultralytics(result)

            slicer = sv.InferenceSlicer(callback=callback)

            detections = slicer(image)
            ```
        """
        detections_list = []
        resolution_wh = (image.shape[1], image.shape[0])
        offsets = self._generate_offset(
            resolution_wh=resolution_wh,
            slice_wh=self.slice_wh,
            overlap_ratio_wh=self.overlap_ratio_wh,
        )

        with ThreadPoolExecutor(max_workers=self.thread_workers) as executor:
            futures = [
                executor.submit(self._run_callback, image, offset) for offset in offsets
            ]
            for future in as_completed(futures):
                detections_list.append(future.result())

        return Detections.merge(detections_list=detections_list).with_nms(
            threshold=self.iou_threshold
        )

    def _run_callback(self, image, offset) -> Detections:
        """
        Run the provided callback on a slice of an image.

        Args:
            image (np.ndarray): The input image on which inference needs to run
            offset (np.ndarray): An array of shape `(4,)` containing coordinates
                for the slice.

        Returns:
            Detections: A collection of detections for the slice.
        """
        image_slice = crop_image(image=image, xyxy=offset)
        detections = self.callback(image_slice)
        detections = move_detections(detections=detections, offset=offset[:2])

        return detections

    @staticmethod
    def _generate_offset(
        resolution_wh: Tuple[int, int],
        slice_wh: Tuple[int, int],
        overlap_ratio_wh: Tuple[float, float],
    ) -> np.ndarray:
        """
        Generate offset coordinates for slicing an image based on the given resolution,
        slice dimensions, and overlap ratios.

        Args:
            resolution_wh (Tuple[int, int]): A tuple representing the width and height
                of the image to be sliced.
            slice_wh (Tuple[int, int]): A tuple representing the desired width and
                height of each slice.
            overlap_ratio_wh (Tuple[float, float]): A tuple representing the desired
                overlap ratio for width and height between consecutive slices. Each
                value should be in the range [0, 1), where 0 means no overlap and a
                value close to 1 means high overlap.

        Returns:
            np.ndarray: An array of shape `(n, 4)` containing coordinates for each
                slice in the format `[xmin, ymin, xmax, ymax]`.

        Note:
            The function ensures that slices do not exceed the boundaries of the
                original image. As a result, the final slices in the row and column
                dimensions might be smaller than the specified slice dimensions if the
                image's width or height is not a multiple of the slice's width or
                height minus the overlap.
        """
        slice_width, slice_height = slice_wh
        image_width, image_height = resolution_wh
        overlap_ratio_width, overlap_ratio_height = overlap_ratio_wh

        width_stride = slice_width - int(overlap_ratio_width * slice_width)
        height_stride = slice_height - int(overlap_ratio_height * slice_height)

        ws = np.arange(0, image_width, width_stride)
        hs = np.arange(0, image_height, height_stride)

        xmin, ymin = np.meshgrid(ws, hs)
        xmax = np.clip(xmin + slice_width, 0, image_width)
        ymax = np.clip(ymin + slice_height, 0, image_height)

        offsets = np.stack([xmin, ymin, xmax, ymax], axis=-1).reshape(-1, 4)

        return offsets

__call__(image)

Performs slicing-based inference on the provided image using the specified callback.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `image` | `ndarray` | The input image on which inference needs to be performed. The image should be in the format (height, width, channels). | required |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `Detections` | `Detections` | A collection of detections for the entire image after merging results from all slices and applying NMS. |

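To make the merge step concrete, the sketch below shifts slice-local boxes into full-image coordinates and then removes duplicates from overlapping slices with greedy non-max suppression. It is a simplified NumPy illustration, not the library's implementation: `shift_boxes`, `box_iou`, and `greedy_nms` are hypothetical helpers, the boxes and scores are made up, and class IDs are ignored.

```python
import numpy as np


def shift_boxes(xyxy: np.ndarray, offset_xy: np.ndarray) -> np.ndarray:
    """Translate slice-local [xmin, ymin, xmax, ymax] boxes to image coordinates."""
    return xyxy + np.hstack([offset_xy, offset_xy])


def box_iou(box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """IoU between one box and an (n, 4) array of boxes."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def greedy_nms(xyxy: np.ndarray, scores: np.ndarray, iou_threshold: float = 0.5):
    """Return indices of the boxes kept, highest score first."""
    order = np.argsort(-scores)
    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(int(best))
        # Drop every remaining box that overlaps the kept box too much.
        order = rest[box_iou(xyxy[best], xyxy[rest]) <= iou_threshold]
    return keep


# Slice A starts at (0, 0); slice B starts at (256, 0) and overlaps A.
boxes_a = shift_boxes(np.array([[250.0, 100.0, 310.0, 160.0]]), np.array([0, 0]))
boxes_b = shift_boxes(
    np.array([[0.0, 100.0, 54.0, 160.0], [100.0, 50.0, 140.0, 90.0]]),
    np.array([256, 0]),
)

merged = np.vstack([boxes_a, boxes_b])
scores = np.array([0.9, 0.6, 0.8])
keep = greedy_nms(merged, scores, iou_threshold=0.5)
```

Here the first box of slice B covers the same object as the box from slice A (IoU 0.9), so only the higher-scoring copy survives, while the unrelated second box is kept.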

Example
import cv2
import numpy as np
import supervision as sv
from ultralytics import YOLO

image = cv2.imread(SOURCE_IMAGE_PATH)
model = YOLO(...)

def callback(image_slice: np.ndarray) -> sv.Detections:
    result = model(image_slice)[0]
    return sv.Detections.from_ultralytics(result)

slicer = sv.InferenceSlicer(callback=callback)

detections = slicer(image)
Source code in supervision/detection/tools/inference_slicer.py
def __call__(self, image: np.ndarray) -> Detections:
    """
    Performs slicing-based inference on the provided image using the specified
        callback.

    Args:
        image (np.ndarray): The input image on which inference needs to be
            performed. The image should be in the format
            `(height, width, channels)`.

    Returns:
        Detections: A collection of detections for the entire image after merging
            results from all slices and applying NMS.

    Example:
        ```python
        import cv2
        import numpy as np
        import supervision as sv
        from ultralytics import YOLO

        image = cv2.imread(SOURCE_IMAGE_PATH)
        model = YOLO(...)

        def callback(image_slice: np.ndarray) -> sv.Detections:
            result = model(image_slice)[0]
            return sv.Detections.from_ultralytics(result)

        slicer = sv.InferenceSlicer(callback=callback)

        detections = slicer(image)
        ```
    """
    detections_list = []
    resolution_wh = (image.shape[1], image.shape[0])
    offsets = self._generate_offset(
        resolution_wh=resolution_wh,
        slice_wh=self.slice_wh,
        overlap_ratio_wh=self.overlap_ratio_wh,
    )

    with ThreadPoolExecutor(max_workers=self.thread_workers) as executor:
        futures = [
            executor.submit(self._run_callback, image, offset) for offset in offsets
        ]
        for future in as_completed(futures):
            detections_list.append(future.result())

    return Detections.merge(detections_list=detections_list).with_nms(
        threshold=self.iou_threshold
    )
