ai_project_v1/uav_module/segementation.py

70 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import datetime
import json
import os
from io import BytesIO
from middleware.minio_util import downFile, upload_file_from_buffer, upload_file
from mqtt_pub import MQTTClient
from uav_module.uav_seg import UAVSegPredictor
# Module-level MQTT client placeholder.
# NOTE(review): segementation_func creates its own *local* client (no
# ``global`` statement), so this name stays None within this file --
# confirm whether any other module relies on it.
mqtt_client = None

# MQTT broker connection settings.
# Alternate broker address that was previously used: 112.44.103.230
# NOTE(review): the original comment described this as a free public MQTT
# broker -- confirm which environment this address actually points at.
broker: str = "8.137.54.85"
port: int = 1883  # default MQTT port

# Topic on which AI result events are published.
topic: str = "thing/product/ai/events"
def segementation_func(task_id, s3_id, s3_url, func_id):
    """Segment a batch of UAV images and publish the results over MQTT.

    For each image URL: download it with ``downFile``, run a DeepLabV3+
    ``UAVSegPredictor`` (7 classes) on it, upload the original image and the
    colourised prediction via the ``minio_util`` helpers, and publish a JSON
    message with both upload paths through ``MQTTClient``.

    Args:
        task_id: Flight task identifier, echoed as ``flight_task_id`` in every
            published message.
        s3_id: Not used by this function; kept for interface compatibility.
        s3_url: Iterable of image URLs to process. A single URL string is also
            accepted and treated as a one-element batch.
        func_id: Not used by this function; kept for interface compatibility.

    Returns:
        None. Results are delivered only through the uploads and MQTT messages.
    """
    # A bare string would otherwise be iterated character by character,
    # passing each single character to downFile.
    urls = [s3_url] if isinstance(s3_url, str) else s3_url

    predictor = UAVSegPredictor(
        model_path='uav_module/checkpoints/deeplabv3plus_best.pth',
        model_type='deeplabv3plus',
        num_classes=7,
    )

    # One RGB colour per class index 0-6, passed to predictor.predict().
    # NOTE(review): these colours disagree with the class legend kept in the
    # original comments:
    #   class 0 black  - background
    #   class 1 brown  - barren land
    #   class 2 green  - woodland
    #   class 3 yellow - farmland
    #   class 4 blue   - water
    #   class 5 grey   - road
    #   class 6 cyan   - construction land
    # Confirm which mapping is intended before relying on the output colours.
    color_map = [
        [255, 0, 0],    # class 0: red
        [0, 255, 0],    # class 1: green
        [0, 0, 255],    # class 2: blue
        [255, 255, 0],  # class 3: yellow
        [255, 0, 255],  # class 4: magenta
        [0, 255, 255],  # class 5: cyan
        [128, 0, 128],  # class 6: purple
    ]

    # NOTE(review): this is a new *local* client on every call (the module
    # global ``mqtt_client`` is not updated), and it is never disconnected
    # here -- confirm whether MQTTClient needs explicit cleanup.
    mqtt_client = MQTTClient(broker, port, topic)

    for img_url in urls:
        # presumably a local file path (basename/upload_file are applied to
        # it below) -- TODO confirm downFile's return type and cleanup needs.
        pic = downFile(img_url)
        # NOTE(review): the second argument is a fixed path shared by every
        # image; if predict() writes to it, later images overwrite earlier
        # ones -- confirm it is still needed.
        color_buffer = predictor.predict(
            pic, 'uav_module/output_test/patch_0011.png', color_map)
        if color_buffer is None:
            continue

        pic_name = os.path.basename(pic)
        # Read the clock once so the date folder and the timestamp prefix
        # always agree (two separate now() calls can straddle midnight).
        now = datetime.datetime.now()
        date_str = now.strftime("%Y%m%d")
        time_s = now.timestamp()
        pic_name_after = f"{date_str}/{time_s}-after-{pic_name}"

        # NOTE(review): the original image is uploaded with None as the second
        # argument, while the prediction gets a dated "...-after-..." name.
        # Confirm whether the original should instead be uploaded as
        # f"{date_str}/{time_s}-before-{pic_name}".
        minio_path_before, _ = upload_file(pic, None)
        minio_path_after, file_type_after = upload_file_from_buffer(
            color_buffer, pic_name_after)

        message = {
            "flight_task_id": task_id,
            "minio": {
                "minio_path_before": minio_path_before,
                "minio_path_after": minio_path_after,
                "file_type": file_type_after,
            },
        }
        mqtt_client.publish_message(
            json.dumps(message, indent=4, ensure_ascii=False))