从零开始使用Surya-OCR——项目源码拆解
本文记录详细拆解surya项目源代码,以及在实际部署中对模型的进一步修改服务于实际应用
目录
一、Surya模型检测使用Python接口中的源码详解
使用surya源码进行模型检测的过程中, 模型的各种参数设置、环境变量配置都写在下载的源码下 /surya/settings.py 文件中,修改其中的参数即可实现全局的配置。
1.选择模型检测GPU
修改 settings.py 文件中的 TORCH_DEVICE 参数,默认为 None 时,运行监测代码会自动检查当前设备,并选择索引顺序最小的GPU运行——‘cuda:0’。在实际部署中,如果服务器中第一块GPU——‘cuda:0’有其他模型在跑,可能需要调整模型预测位置,将模型放到另一块GPU上运行。只需修改以下参数代码。
# 指定模型所在GPU
TORCH_DEVICE: Optional[str] = 'cuda:1'
2.配置加载模型参数
在 settings.py 文件中模型参数默认为在线加载,当服务器无法连接外部网络时,离线部署加载模型参数需调整设置中的地址。需根据自己模型存放位置,修改为下面的参数字符串。
## 地址需改为你存放模型的绝对地址
# 文本行检测模型
DETECTOR_MODEL_CHECKPOINT: str = "//Surya-OCR/hugging_model/surya_det2"
# 文本区域检测模型
LAYOUT_MODEL_CHECKPOINT: str = "//Surya-OCR/hugging_model/surya_layout"
默认在线加载模型参数位置:
参数修改内容:(为你存放离线下载模型地址)
测试模型是否加载成功的代码如下:
from surya.model.detection.segformer import load_model, load_processor
from surya.settings import settings
# 行检测模型:surya_det_2
det_model = load_model()
det_processor = load_processor()
print('det2_model load success')
# 区域检测:surya_layout
model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
print('layout_model load success')
修改后,检验模型加载情况:
3.批量检测图片
实际部署中,官方文档提供 Python 接口检测的代码对单个图片检测顺利,但在批量检测图片集——文件夹时报错。 (单图代码写在上一篇博文《从零开始使用Surya-OCR——文本目标检测模型的安装与部署》中,具体可参考https://blog.csdn.net/qq_58718853/article/details/137150986)
第一个报错是,官方提供接口函数,无法读取文件夹内图片,报读取文件权限被拒。暂未实现直接解决该问题的办法。参看后续 batch_text_detection 源代码传参信息,得知图片读取后的传入函数结果是一个列表,可以选择替代方案实现同等效果,替代方案代码如下。
import os
from PIL import Image
from surya.detection import batch_text_detection
from surya.model.detection.segformer import load_model, load_processor
IMAGE_PATH = 'image_path'
model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
print('model load success')
# 批量将文件夹图片读入images列表中
images = []
for file in os.listdir(IMAGE_PATH):
image_path = os.path.join(IMAGE_PATH, file)
image = Image.open(image_path)
images.append(image)
predictions = batch_text_detection(images, model, processor)
print(predictions)
使用新代码后,原本的报错问题解决了,但出现了新的报错。
第二个报错是,检查surya模型做批量预测任务时,在得到模型输出后,还会对多张图片的结果进行多进程处理。问题就在多进程处理时,源代码内置函数重复调用主函数——导致模型本来只需加载一次,此时不断加载模型致使GPU显存爆了。我们找到对应报错的源码处检查。
按住 Ctrl 并点击报错的 batch_text_detection 函数,即可进入源码处,检查研究发现可能是Windows系统和 Linux 系统对于多进程和多线程的解释存在差异,本机是windows系统, 可能在导入 ProcessPoolExecutor 函数时,将我的主函数视为多进程对象,不断创建新的进程。但是此处希望实现的是中间过程结果的多线程处理,而不影响主函数。因此需要将 ProcessPoolExecutor 改为使用多线程的 ThreadPoolExecutor,问题即可解决。
经过上述所有源代码修改处理,成功运行主函数,得到surya模型批量检测图片后得到的框信息结果,并将其打印出来。
4.检测输出结果源码解读
Surya模型输出是自定义类的数据格式,下面根据其官方文档和项目源码解读其输出的格式,以方便后续对输出的处理,提取出所需的数据信息。
官方文档:https://github.com/VikParuchuri/surya
Surya 模型有三种预测模式——OCR & Text Line & Layout,对应三种模型输出的格式,每种模式的输出都是以类的形式定义的。下面重点放在 Text Line 文本行检测和 Layout 区域检测的源码信息解读上。
①文本行检测的模型输出——Text Line
将与输出相关的源代码从项目中单独提取出来看,下面是输出的基础类,即每个图片模型预测后的信息都封装在了 TextDetectionResult 里面。
"文本行检测"
# 输出的基础类
class TextDetectionResult(BaseModel):
bboxes: List[PolygonBox]
vertical_lines: List[ColumnLine]
horizontal_lines: List[ColumnLine]
heatmap: Any
affinity_map: Any
image_bbox: List[float]
下面分别解释输出基础类中的具体信息都是怎么定义的,从源码中提出相关代码。
输出的第一个类信息:PolygonBox 注解
(下面非完整代码,为清晰类输出含义,只选取主要功能代码)
# 输出框信息类
class PolygonBox(BaseModel):
polygon: List[List[float]] ## 存储框四个角——全坐标
confidence: Optional[float] = None ## 框预测置信度
def bbox(self) -> List[float]:
box = [self.polygon[0][0], self.polygon[0][1], self.polygon[1][0], self.polygon[2][1]]
if box[0] > box[2]:
box[0], box[2] = box[2], box[0]
if box[1] > box[3]:
box[1], box[3] = box[3], box[1]
return box ## 存储框左上右下——对角坐标
通过源码可知,此类保存的是检测框的坐标信息和置信度,这是预测中的主要信息。使用具体的模型输出结果,可以更清楚该类的输出形状。输出TextDetectionResult 包含多个类信息,其中定义的 bboxes ——即 PolygonBox 存储的是一张图片内检测出来的所有框,而每个框的信息结构是包含三个子类:全坐标、置信度和对角坐标。
输出的第二个类信息:ColumnLine 注解
(下面非完整代码,为清晰类输出含义,只选取主要功能代码)
# 图片中线的检测
class ColumnLine(Bbox):
vertical: bool # 垂直线:有-True;无-False
horizontal: bool # 水平线:有-True;无-False
# 检测线的框坐标
class Bbox(BaseModel):
bbox: List[float]
同样通过源码可知,输出的第二个类信息 ColumnLine 是用来保存在模型预测中对图片检测到的水平线、垂直线的信息,如果检测到了那么就同时保存其框位置。通过具体模型输出对应位置检索,可以清晰理解。
输出的剩余类信息:heatmap、affinity_map、image_bbox
剩余的类信息是对图片结果输出的补充,我们可以直接将其打印出来,看看其内容。最容易看出来的是 image_bbox ,其实际就是用一个框将整个图片框起来,然后返回对角坐标。而heatmap、affinity_map 则是 PIL.Image.Image 的类信息。
②文本区域检测的模型输出 ——Layout
同文本行检测一样,废话少说,直接上代码和图。
# 区域检测输出
class LayoutResult(BaseModel):
bboxes: List[LayoutBox]
segmentation_map: Any
image_bbox: List[float]、
class LayoutBox(PolygonBox):
label: str ## 多了一个区域类别的预测
5.批量信息的保存和可视化
将surya模型自定义的类输出转化为 json 格式保存到指定文件夹,代码如下。
import os
import json
from PIL import Image
from surya.detection import batch_text_detection
from surya.model.detection.segformer import load_model, load_processor
import cv2
import numpy as np
IMAGE_PATH = 'iamge_path' ## 检测图片保存地址
json_file = 'json_path' ## 框json保存地址
checkpoint = 'model_path' ## 模型参数加载地址
heat_file = 'heat_path' ## 热图保存
########### 上述为修改部分,根据实际地址填入,下面无需修改 ############
model, processor = load_model(checkpoint=checkpoint), load_processor(checkpoint=checkpoint)
print('model load success')
# 模型预测
images = []
image_name = []
for file in os.listdir(IMAGE_PATH):
image_path = os.path.join(IMAGE_PATH, file)
image = Image.open(image_path)
images.append(image)
image_name.append(file)
predictions = batch_text_detection(images, model, processor)
print('predict success')
# 保存模型结果
## 类型转为json
def class_to_json(bboxes, file, box_type=True):
json_list = []
for i, bbox in enumerate(bboxes):
if box_type:
json_dict = dict()
box = bbox.bbox
box.append(bbox.confidence)
json_dict["id"] = i
json_dict["name"] = file
json_dict["box"] = box
json_list.append(json_dict)
else:
json_dict = dict()
box = bbox.bbox
json_dict["id"] = i
json_dict["name"] = file
json_dict["box"] = box
json_list.append(json_dict)
return json_list
## 保存到指定文件夹
def save_json(json_list, json_path):
with open(json_path, 'w') as f:
json.dump(json_list, f)
## 主函数
def save_predict(predictions, image_name, heat_file):
for i, pred in enumerate(predictions):
# 框信息保存
bboxes = pred.bboxes
vertical = pred.vertical_lines
horizontal = pred.horizontal_lines
file = image_name[i]
bboxes_json = class_to_json(bboxes, file)
vertical_json = class_to_json(vertical, file, box_type=False)
horizontal_json = class_to_json(horizontal, file, box_type=False)
basename = file.split('.')[0]
save_json(bboxes_json, os.path.join(json_file+'box/', basename + '.json'))
save_json(vertical_json, os.path.join(json_file + 'vertical/', basename + '.json'))
save_json(horizontal_json, os.path.join(json_file + 'horizontal/', basename + '.json'))
# 热图调参信息保存
heatmap = pred.heatmap
img = cv2.cvtColor(np.asarray(heatmap), cv2.COLOR_RGB2BGR)
cv2.imwrite(heat_file+basename+'.jpg', img)
print(basename + ' success')
if __name__ == '__main__':
save_predict(predictions, image_name, heat_file)
可视化框的代码如下。
import os
import json
import cv2
# jpg、json、vis文件位置
jpg_path = 'JPG'
json_path = 'JSON'
vis_path = 'VIS'
########### 上述为修改部分,根据实际地址填入,下面无需修改 ############
# 可视化锚框
## 锚框展示细节
def hsv2bgr(h, s, v):
h_i = int(h * 6)
f = h * 6 - h_i
p = v * (1 - s)
q = v * (1 - f * s)
t = v * (1 - (1 - f) * s)
r, g, b = 0, 0, 0
if h_i == 0:
r, g, b = v, t, p
elif h_i == 1:
r, g, b = q, v, p
elif h_i == 2:
r, g, b = p, v, t
elif h_i == 3:
r, g, b = p, q, v
elif h_i == 4:
r, g, b = t, p, v
elif h_i == 5:
r, g, b = v, p, q
return int(b * 255), int(g * 255), int(r * 255)
def random_color(id):
h_plane = (((id << 2) ^ 0x937151) % 100) / 100.0
s_plane = (((id << 3) ^ 0x315793) % 100) / 100.0
return hsv2bgr(h_plane, s_plane, 1)
# 可视化主函数
def visualize(json_path, jpg_path, vis_path, box_type=True):
if box_type:
for file in os.listdir(json_path):
with open(json_path+file,'r') as f:
drawResult = json.load(f)
basefile = file.split('.')[0]
jpg_file = os.path.join(jpg_path,basefile+".jpg")
img = cv2.imread(jpg_file)
for idx, result in enumerate(drawResult):
left, top, right, bottom = int(result['box'][0]), int(result['box'][1]), int(result['box'][2]), int(result['box'][3])
label = int(result['box'][4])
color = random_color(1)
cv2.rectangle(img, (left, top), (right, bottom), color=color ,thickness=2, lineType=cv2.LINE_AA)
caption = f"{'ZW'}"
w, h = cv2.getTextSize(caption, 0, 1, 2)[0]
cv2.rectangle(img, (left - 3, top - 33), (left + w + 10, top), color, -1)
cv2.putText(img, caption, (left, top - 5), 0, 1, (0, 0, 0), 2, 16)
save_file = os.path.join(vis_path, basefile+".jpg")
print(save_file)
cv2.imwrite(save_file, img)
else:
for file in os.listdir(json_path):
with open(json_path+file,'r') as f:
drawResult = json.load(f)
basefile = file.split('.')[0]
jpg_file = os.path.join(jpg_path,basefile+".jpg")
img = cv2.imread(jpg_file)
for idx, result in enumerate(drawResult):
left, top, right, bottom = int(result['box'][0]), int(result['box'][1]), int(result['box'][2]), int(result['box'][3])
color = random_color(1)
cv2.rectangle(img, (left, top), (right, bottom), color=color ,thickness=2, lineType=cv2.LINE_AA)
caption = f"{'ZW'}"
w, h = cv2.getTextSize(caption, 0, 1, 2)[0]
cv2.rectangle(img, (left - 3, top - 33), (left + w + 10, top), color, -1)
cv2.putText(img, caption, (left, top - 5), 0, 1, (0, 0, 0), 2, 16)
save_file = os.path.join(vis_path, basefile+".jpg")
print(save_file)
cv2.imwrite(save_file, img)
if __name__ == '__main__':
visualize(json_path+'box/', jpg_path, vis_path+'box/')
visualize(json_path + 'vertical/', jpg_path, vis_path + 'vertical/', box_type=False)
visualize(json_path + 'horizontal/', jpg_path, vis_path + 'horizontal/', box_type=False)
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)