【Python&语义分割】Segment Anything(SAM)模型全局语义分割代码+掩膜保存(二)

2024-06-04 7561阅读

原创作者:RS迷途小书童

博客地址:https://blog.csdn.net/m0_56729804?type=blog

我上篇博文分享了Segment Anything(SAM)模型的基本操作,这篇给大家分享下官方的整张图片的语义分割代码(全局),同时我还修改了一部分支持掩膜和叠加影像的保存。

1 Segment Anything介绍

1.1 概况

        Meta AI 公司的 Segment Anything 模型是一项革命性的技术,该模型能够根据文本指令或图像识别,实现对任意物体的识别和分割。这一模型的推出,将极大地推动计算机视觉领域的发展,并使得图像分割技术进一步普及化。

        论文地址:https://arxiv.org/abs/2304.02643

        项目地址:Segment Anything

1.2 使用方法

        具体使用方法上,Segment Anything 提供了简单易用的接口,用户只需要通过提示,即可进行物体识别和分割操作。例如在图片处理中,用户可以通过 Hover & Click 或 Box 等方式来选取物体。值得一提的是,SAM 还支持通过上传自己的图片进行物体分割操作,提取物体用时仅需数秒。

        总的来说,Meta AI 的 Segment Anything 模型为我们提供了一种全新的物体识别和分割方式,其强大的泛化能力和广泛的应用前景将极大地推动计算机视觉领域的发展。未来,我们期待看到更多基于 Segment Anything 的创新应用,以及在科学图像分析、照片编辑等领域的广泛应用。

【Python&语义分割】Segment Anything(SAM)模型全局语义分割代码+掩膜保存(二) 第1张

​​2 模型代码+注释

2.1 模型预加载

        我这里将掩膜生成的函数单独拿出来了,因为里面集成了掩膜保存的代码。所以先给大家看预处理部分。

    try:
        image = cv2.imread(image_path)  # 读取的图像以NumPy数组的形式存储在变量image中
        print("[%s]正在转换图片格式......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 将图像从BGR颜色空间转换为RGB颜色空间,还原图片色彩(图像处理库所认同的格式)
        print("[%s]正在初始化模型参数......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    except:
        print("图片打开失败!请检查路径!")
        pass
        sys.exit()
    sys.path.append("..")  # 将当前路径上一级目录添加到sys.path列表,这里模型使用绝对路径所以这行没啥用
    sam_checkpoint = model_path  # 定义模型路径
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)  # 定义模型参数
    mask_generator = SamAutomaticMaskGenerator(model=sam,  # 用于掩膜预测的SAM模型
points_per_side=32,  # 图像一侧的采样点数,总采样点数是一侧采样点数的平方,点数给的越多,分割越细
# points_per_batch=64,  # 设置模型同时运行的点的数量。更高的数字可能会更快,但会使用更多的GPU内存
pred_iou_thresh=0.86,  # 滤波阈值,在[0,1]中,使用模型的预测掩膜质量0.86
stability_score_thresh=0.92,
# 滤波阈值,在[0,1]中,使用掩码在用于二进制化模型的掩码预测的截止点变化下的稳定性0.92
# stability_score_offset=1.0,  # 计算稳定性分数时,对截止点的偏移量
# box_nms_thresh=0.7,  # 非最大抑制用于过滤重复掩码的箱体IoU截止点
crop_n_layers=1,  # 如果>0,蒙版预测将在图像的裁剪上再次运行。设置运行的层数,其中每层有2**i_layer的图像裁剪数1
# crop_nms_thresh=0.7,  # 非最大抑制用于过滤不同作物之间的重复掩码的箱体IoU截止值
# crop_overlap_ratio=512 / 1500,  # 设置作物重叠的程度
crop_n_points_downscale_factor=2,
# 在图层n中每面采样的点数被crop_n_points_downscale_factor**n缩减2
# point_grids=None,  # 用于取样的明确网格的列表,归一化为[0,1]
min_mask_region_area=100,
# 如果>0,后处理将被应用于移除面积小于min_mask_region_area的遮罩中的不连接区域和孔。需要opencv。50
# output_mode="binary_mask"  # 掩模的返回形式。
# 可以是’binary_mask’, ‘uncompressed_rle’, 或者’coco_rle’。
# coco_rle’需要pycocotools。对于大的分辨率,'binary_mask’可能会消耗大量的内存
)  # 激活函数

2.2 模型预测代码

masks = mask_generator.generate(image)  # 类别掩膜提取(包含所有的,可按照索引查看)
# ---------------------------masks输出内容---------------------------
# segmentation : np的二维数组,为二值的mask图片
# area : mask的像素面积
# bbox : mask的外接矩形框,为X Y WH格式
# predicted_iou : 该mask的质量(模型预测出的与真实框的iou)
# point_coords : 用于生成该mask的point输入
# stability_score : mask质量的附加指标
# crop_box : 用于以X Y WH格式生成此遮罩的图像裁剪
# ------------------------------------------------------------------
print("[%s]正在绘制图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
plt.figure(figsize=(20, 20))  # 创建一个新的图形窗口,设置其大小为10x10英寸
plt.imshow(image)  # 使用imshow函数在创建的图形窗口中显示图像
print("[%s]正在制作掩膜......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print("【结果保存阶段】")
show_mask_auto(masks, out_path, out_path1)
plt.axis('on')  # 开启图像坐标轴,使得图像下的像素坐标可以显示出来
print("[%s]正在保存叠加结果......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
plt.savefig(out_image_path, dpi=300)
plt.show()  # 显示已经创建的图形窗口和其中的内容

2.3 掩膜生成+保存代码

        我这里在官方的掩膜生成的函数的基础上,加入了两段保存数据的代码。一个是彩色的mask(叠加显示的mask),一个是单波段的mask(DN值代表序号)。

        大家在使用这个函数时,将这段放在2.1,2.2展示的代码前面即可。

def show_mask_auto(masks_data, out_mask_path, out_path_01):
    """
    :param masks_data: 掩膜数据
    :param out_mask_path: 输出彩色掩膜
    :param out_path_01: 输出单波段掩膜
    :return: None
    """
    if len(masks_data) == 0:
        return
    sorted_masks_data = sorted(masks_data, key=(lambda x: x['area']), reverse=True)  # 按照面积大小降序排列
    ax = plt.gca()  # 获取当前的轴(axes)
    ax.set_autoscale_on(False)  # 关闭轴的自动缩放功能
    img = np.ones((sorted_masks_data[0]['segmentation'].shape[0], sorted_masks_data[0]['segmentation'].shape[1], 4))
    # 创建了一个新的三维数组img。数组的形状是基于segmentation']的形状,其中四个通道通常代表红色、绿色、蓝色和透明度(RGBA)
    img[:, :, 3] = 0  # 将新创建的图像的第四个通道(也就是透明度通道)设置为0
    img_raster = np.zeros((sorted_masks_data[0]['segmentation'].shape[0],
                          sorted_masks_data[0]['segmentation'].shape[1]))
    # 创建一个二维数组,用于保存掩膜做栅格转面
    j = 1
    for sorted_mask_data in sorted_masks_data:
        # 循环所有类别的掩膜
        m = sorted_mask_data['segmentation']
        # 获取当前类别的二值mask图片
        color_mask = np.concatenate([np.random.random(3), [0.65]])
        # 随机生成的RGB颜色,它的形状为(3,),0.65表示颜色的透明度。
        img[m] = color_mask
        # 将颜色赋予图片的数组
        img_raster[m] = j
        # 给掩膜赋值
        j += 1
    """for i in range(0, len(masks_data)):
        # 循环所有类别的掩膜
        rect = patches.Rectangle((masks_data[i]['bbox'][0], masks_data[i]['bbox'][1]), masks_data[i]['bbox'][2],
                                 masks_data[i]['bbox'][3], edgecolor=tuple(random.uniform(0, 1) for _ in range(3)),
                                 facecolor='none', linewidth=2)  # 绘制类别的外接矩形框
        ax.add_patch(rect)  # 将矩形添加到ax对象中"""
    plt.imshow(img, alpha=0.8)
    print("[%s]正在保存类别掩膜......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    driver = gdal.GetDriverByName('GTiff')  # 载入数据驱动,用于存储内存中的数组
    ds_result = driver.Create(out_mask_path, sorted_masks_data[0]['segmentation'].shape[1],
                              sorted_masks_data[0]['segmentation'].shape[0], bands=4, eType=gdal.GDT_Float64)
    # 创建一个数组,宽高为原始尺寸
    for i in range(3):
        ds_result.GetRasterBand(i+1).SetNoDataValue(0)  # 将无效值设为0
        ds_result.GetRasterBand(i+1).WriteArray(img[:, :, i])  # 将结果写入数组
    ds_result_raster = driver.Create(out_path_01, sorted_masks_data[0]['segmentation'].shape[1],
                                     sorted_masks_data[0]['segmentation'].shape[0], bands=1, eType=gdal.GDT_Float64)
    # ds_result.SetGeoTransform(ds_geo)  # 导入仿射地理变换参数
    # ds_result.SetProjection(ds_prj)  # 导入投影信息
    ds_result_raster.GetRasterBand(1).SetNoDataValue(0)  # 将无效值设为0
    ds_result_raster.GetRasterBand(1).WriteArray(img_raster)  # 将结果写入数组
    del ds_result
    del ds_result_raster

【Python&语义分割】Segment Anything(SAM)模型全局语义分割代码+掩膜保存(二) 第2张

【Python&语义分割】Segment Anything(SAM)模型全局语义分割代码+掩膜保存(二) 第3张

【Python&语义分割】Segment Anything(SAM)模型全局语义分割代码+掩膜保存(二) 第4张

3 完整代码

# -*- coding: utf-8 -*-
"""
@Time : 2023/10/8 10:15
@Auth : RS迷途小书童
@File :Segment Anything Auto.py
@IDE :PyCharm
@Purpose:Segment Anything Model自动全局语义分割
"""
import sys
import cv2
import random
import numpy as np
from osgeo import gdal
from datetime import datetime
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
def SAM_auto(image_path, model_path, model_type, device, out_path, out_path1, out_image_path):
    """
    :param image_path: 输入需要分割的影像
    :param model_path: 输入模型路径
    :param model_type: 输入模型类型
    :param device: 输入cpu or cuda
    :param out_path: 输出彩色掩膜文件
    :param out_path1: 输出单波段掩膜文件
    :param out_image_path: 输出叠加图片
    :return: None
    """
    def show_mask_auto(masks_data, out_mask_path, out_path_01):
        """
        :param masks_data: 掩膜数据
        :param out_mask_path: 输出彩色掩膜
        :param out_path_01: 输出单波段掩膜
        :return: None
        """
        if len(masks_data) == 0:
            return
        sorted_masks_data = sorted(masks_data, key=(lambda x: x['area']), reverse=True)  # 按照面积大小降序排列
        ax = plt.gca()  # 获取当前的轴(axes)
        ax.set_autoscale_on(False)  # 关闭轴的自动缩放功能
        img = np.ones((sorted_masks_data[0]['segmentation'].shape[0], sorted_masks_data[0]['segmentation'].shape[1], 4))
        # 创建了一个新的三维数组img。数组的形状是基于segmentation']的形状,其中四个通道通常代表红色、绿色、蓝色和透明度(RGBA)
        img[:, :, 3] = 0  # 将新创建的图像的第四个通道(也就是透明度通道)设置为0
        img_raster = np.zeros((sorted_masks_data[0]['segmentation'].shape[0],
                              sorted_masks_data[0]['segmentation'].shape[1]))
        # 创建一个二维数组,用于保存掩膜做栅格转面
        j = 1
        for sorted_mask_data in sorted_masks_data:
            # 循环所有类别的掩膜
            m = sorted_mask_data['segmentation']
            # 获取当前类别的二值mask图片
            color_mask = np.concatenate([np.random.random(3), [0.65]])
            # 随机生成的RGB颜色,它的形状为(3,),0.65表示颜色的透明度。
            img[m] = color_mask
            # 将颜色赋予图片的数组
            img_raster[m] = j
            # 给掩膜赋值
            j += 1
        """for i in range(0, len(masks_data)):
            # 循环所有类别的掩膜
            rect = patches.Rectangle((masks_data[i]['bbox'][0], masks_data[i]['bbox'][1]), masks_data[i]['bbox'][2],
                                     masks_data[i]['bbox'][3], edgecolor=tuple(random.uniform(0, 1) for _ in range(3)),
                                     facecolor='none', linewidth=2)  # 绘制类别的外接矩形框
            ax.add_patch(rect)  # 将矩形添加到ax对象中"""
        plt.imshow(img, alpha=0.8)
        print("[%s]正在保存类别掩膜......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        driver = gdal.GetDriverByName('GTiff')  # 载入数据驱动,用于存储内存中的数组
        ds_result = driver.Create(out_mask_path, sorted_masks_data[0]['segmentation'].shape[1],
                                  sorted_masks_data[0]['segmentation'].shape[0], bands=4, eType=gdal.GDT_Float64)
        # 创建一个数组,宽高为原始尺寸
        for i in range(3):
            ds_result.GetRasterBand(i+1).SetNoDataValue(0)  # 将无效值设为0
            ds_result.GetRasterBand(i+1).WriteArray(img[:, :, i])  # 将结果写入数组
        ds_result_raster = driver.Create(out_path_01, sorted_masks_data[0]['segmentation'].shape[1],
                                         sorted_masks_data[0]['segmentation'].shape[0], bands=1, eType=gdal.GDT_Float64)
        # ds_result.SetGeoTransform(ds_geo)  # 导入仿射地理变换参数
        # ds_result.SetProjection(ds_prj)  # 导入投影信息
        ds_result_raster.GetRasterBand(1).SetNoDataValue(0)  # 将无效值设为0
        ds_result_raster.GetRasterBand(1).WriteArray(img_raster)  # 将结果写入数组
        del ds_result
        del ds_result_raster
    print("【程序准备阶段】")
    print("[%s]正在读取图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    try:
        image = cv2.imread(image_path)  # 读取的图像以NumPy数组的形式存储在变量image中
        print("[%s]正在转换图片格式......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 将图像从BGR颜色空间转换为RGB颜色空间,还原图片色彩(图像处理库所认同的格式)
        print("[%s]正在初始化模型参数......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    except:
        print("图片打开失败!请检查路径!")
        pass
        sys.exit()
    sys.path.append("..")  # 将当前路径上一级目录添加到sys.path列表,这里模型使用绝对路径所以这行没啥用
    sam_checkpoint = model_path  # 定义模型路径
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)  # 定义模型参数
    mask_generator = SamAutomaticMaskGenerator(model=sam,  # 用于掩膜预测的SAM模型
points_per_side=32,  # 图像一侧的采样点数,总采样点数是一侧采样点数的平方,点数给的越多,分割越细
# points_per_batch=64,  # 设置模型同时运行的点的数量。更高的数字可能会更快,但会使用更多的GPU内存
pred_iou_thresh=0.86,  # 滤波阈值,在[0,1]中,使用模型的预测掩膜质量0.86
stability_score_thresh=0.92,
# 滤波阈值,在[0,1]中,使用掩码在用于二进制化模型的掩码预测的截止点变化下的稳定性0.92
# stability_score_offset=1.0,  # 计算稳定性分数时,对截止点的偏移量
# box_nms_thresh=0.7,  # 非最大抑制用于过滤重复掩码的箱体IoU截止点
crop_n_layers=1,  # 如果>0,蒙版预测将在图像的裁剪上再次运行。设置运行的层数,其中每层有2**i_layer的图像裁剪数1
# crop_nms_thresh=0.7,  # 非最大抑制用于过滤不同作物之间的重复掩码的箱体IoU截止值
# crop_overlap_ratio=512 / 1500,  # 设置作物重叠的程度
crop_n_points_downscale_factor=2,
# 在图层n中每面采样的点数被crop_n_points_downscale_factor**n缩减2
# point_grids=None,  # 用于取样的明确网格的列表,归一化为[0,1]
min_mask_region_area=100,
# 如果>0,后处理将被应用于移除面积小于min_mask_region_area的遮罩中的不连接区域和孔。需要opencv。50
# output_mode="binary_mask"  # 掩模的返回形式。
# 可以是’binary_mask’, ‘uncompressed_rle’, 或者’coco_rle’。
# coco_rle’需要pycocotools。对于大的分辨率,'binary_mask’可能会消耗大量的内存
)  # 激活函数
    print("【模型预测阶段】")
    print("[%s]正在分割图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    masks = mask_generator.generate(image)  # 类别掩膜提取(包含所有的,可按照索引查看)
    # ---------------------------masks输出内容---------------------------
    # segmentation : np的二维数组,为二值的mask图片
    # area : mask的像素面积
    # bbox : mask的外接矩形框,为X Y WH格式
    # predicted_iou : 该mask的质量(模型预测出的与真实框的iou)
    # point_coords : 用于生成该mask的point输入
    # stability_score : mask质量的附加指标
    # crop_box : 用于以X Y WH格式生成此遮罩的图像裁剪
    # ------------------------------------------------------------------
    print("[%s]正在绘制图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    plt.figure(figsize=(20, 20))  # 创建一个新的图形窗口,设置其大小为10x10英寸
    plt.imshow(image)  # 使用imshow函数在创建的图形窗口中显示图像
    print("[%s]正在制作掩膜......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    print("【结果保存阶段】")
    show_mask_auto(masks, out_path, out_path1)
    plt.axis('on')  # 开启图像坐标轴,使得图像下的像素坐标可以显示出来
    print("[%s]正在保存叠加结果......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    plt.savefig(out_image_path, dpi=300)
    plt.show()  # 显示已经创建的图形窗口和其中的内容
    print("-----------------------------------------语义分割已完成----------------------------------------")
if __name__ == "__main__":
    print("\n")
    print("--------------------------------------Segment Anything--------------------------------------")
    Image_path = r'B:/Personal/satellite.tif'  # 分割的影像
    Model_path = "G:/Neat Download Manager/Misc/sam_vit_h_4b8939.pth"  # 模型路径
    Out_mask_path = 'B:/Personal/my_figure1.tif'  # 彩色掩膜
    Out_mask_path1 = 'B:/Personal/my_figure2.tif'  # 二维掩膜用于转矢量
    Out_image_path = 'B:/Personal/my_figure3.png'  # 叠加结果
    Model_type = "vit_h"  # 定义模型类型
    Device = "cuda"  # "cpu"  or  "cuda"
    SAM_auto(Image_path, Model_path, Model_type, Device, Out_mask_path, Out_mask_path1, Out_image_path)
    # 图片,模型,类型,算力,彩色掩膜,黑白掩膜,叠加图片

    免责声明:我们致力于保护作者版权,注重分享,被刊用文章因无法核实真实出处,未能及时与作者取得联系,或有版权异议的,请联系管理员,我们会立即处理! 部分文章是来自自研大数据AI进行生成,内容摘自(百度百科,百度知道,头条百科,中国民法典,刑法,牛津词典,新华词典,汉语词典,国家院校,科普平台)等数据,内容仅供学习参考,不准确地方联系删除处理! 图片声明:本站部分配图来自人工智能系统AI生成,觅知网授权图片,PxHere摄影无版权图库和百度,360,搜狗等多加搜索引擎自动关键词搜索配图,如有侵权的图片,请第一时间联系我们,邮箱:ciyunidc@ciyunshuju.com。本站只作为美观性配图使用,无任何非法侵犯第三方意图,一切解释权归图片著作权方,本站不承担任何责任。如有恶意碰瓷者,必当奉陪到底严惩不贷!

    目录[+]