ai_project_v1/b3dm/data_3dtiles_to_dem.py
2026-01-14 11:37:35 +08:00

757 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# data_3dtiles_to_dem52
import os
import json
import numpy as np
import pandas as pd
import pyproj
import struct
from osgeo import gdal, osr
import uuid
from glb_with_draco import DracoGLBParser
# 解决GDAL中文路径/警告问题(生产必加)
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
gdal.SetConfigOption("CPL_ZIP_ENCODING", "UTF-8")
gdal.PushErrorHandler('CPLQuietErrorHandler')
class RegionFilter:
"""完整的区域过滤器,支持所有格式"""
def __init__(self, region_coords=None, enable_tile_filter=True, debug=False):
"""
初始化区域过滤器
:param debug: 是否输出调试信息
"""
self.region_coords = region_coords
self.enable_tile_filter = enable_tile_filter
self.debug = debug
if region_coords:
# 提取区域边界
lons = [coord[0] for coord in region_coords]
lats = [coord[1] for coord in region_coords]
self.min_lon = min(lons)
self.max_lon = max(lons)
self.min_lat = min(lats)
self.max_lat = max(lats)
# 扩展边界(避免边缘误差)
self.expand_factor = 0.1
lon_expand = (self.max_lon - self.min_lon) * self.expand_factor
lat_expand = (self.max_lat - self.min_lat) * self.expand_factor
self.filter_min_lon = self.min_lon - lon_expand
self.filter_max_lon = self.max_lon + lon_expand
self.filter_min_lat = self.min_lat - lat_expand
self.filter_max_lat = self.max_lat + lat_expand
if self.debug:
print(f"[DEBUG] 区域过滤器初始化:")
print(f" 目标区域: Lon[{self.min_lon:.6f}, {self.max_lon:.6f}], "
f"Lat[{self.min_lat:.6f}, {self.max_lat:.6f}]")
print(f" 过滤区域: Lon[{self.filter_min_lon:.6f}, {self.filter_max_lon:.6f}], "
f"Lat[{self.filter_min_lat:.6f}, {self.filter_max_lat:.6f}]")
else:
self.min_lon = self.min_lat = -180
self.max_lon = self.max_lat = 180
self.filter_min_lon = self.filter_min_lat = -180
self.filter_max_lon = self.filter_max_lat = 180
if self.debug:
print("[DEBUG] 区域过滤器: 未指定区域,将处理所有数据")
def check_tile_bounding_volume(self, bounding_volume):
"""检查瓦片的包围体是否与指定区域相交"""
if not self.region_coords or not self.enable_tile_filter:
return True
try:
if 'region' in bounding_volume:
return self._check_region(bounding_volume['region'])
elif 'box' in bounding_volume:
box = bounding_volume['box']
if self.debug:
print(f"[DEBUG] 检查box包围体长度={len(box)}")
if len(box) == 12:
result = self._check_box_12(box)
elif len(box) == 15:
result = self._check_box_15(box)
else:
if self.debug:
print(f"[DEBUG] 异常box长度 {len(box)},默认通过")
return True
if self.debug and not result:
print(f"[DEBUG] Box被过滤")
return result
elif 'sphere' in bounding_volume:
return self._check_sphere(bounding_volume['sphere'])
return True
except Exception as e:
if self.debug:
print(f"[DEBUG] 包围体检查出错: {e}")
return True
def _check_region(self, region):
"""检查region格式 [west, south, east, north, minHeight, maxHeight]"""
if len(region) != 6:
return True
west, south, east, north, min_h, max_h = region
# 检查是否完全在过滤区域外
if (east < self.filter_min_lon or west > self.filter_max_lon or
north < self.filter_min_lat or south > self.filter_max_lat):
if self.debug:
print(f"[DEBUG] Region过滤: [{west:.3f},{south:.3f},{east:.3f},{north:.3f}]")
return False
return True
def _check_box_12(self, box):
"""检查12值box格式"""
# 提取参数
cx, cy, cz = box[0], box[1], box[2]
halfX = box[3]
halfY = box[7]
halfZ = box[11]
if self.debug:
print(f"[DEBUG] 12值box: center=({cx:.1f},{cy:.1f},{cz:.1f}), "
f"halfs=({halfX},{halfY},{halfZ})")
try:
# 获取转换器
transformer = self._get_transformer()
# 转换中心点
center_lon, center_lat, _ = transformer.transform(
cx, cy, cz, radians=False
)
# 计算最大偏移(简化方法,避免计算所有角点)
max_half = max(halfX, halfY, halfZ)
# 转换为经纬度偏移
earth_radius = 6378137.0
lon_offset = np.degrees(max_half / (earth_radius * np.cos(np.radians(center_lat))))
lat_offset = np.degrees(max_half / earth_radius)
# 计算box范围
box_min_lon = center_lon - lon_offset
box_max_lon = center_lon + lon_offset
box_min_lat = center_lat - lat_offset
box_max_lat = center_lat + lat_offset
if self.debug:
print(f"[DEBUG] 中心: ({center_lon:.6f}, {center_lat:.6f})")
print(f"[DEBUG] 范围: Lon[{box_min_lon:.6f}, {box_max_lon:.6f}], "
f"Lat[{box_min_lat:.6f}, {box_max_lat:.6f}]")
# 检查相交
if (box_max_lon < self.filter_min_lon or box_min_lon > self.filter_max_lon or
box_max_lat < self.filter_min_lat or box_min_lat > self.filter_max_lat):
return False
return True
except Exception as e:
if self.debug:
print(f"[DEBUG] Box12检查失败: {e}")
return True # 失败时默认通过
def _check_box_15(self, box):
"""检查标准15值box格式"""
if len(box) < 15:
return True
cx, cy, cz = box[0], box[1], box[2]
halfX, halfY, halfZ = box[12], box[13], box[14]
# 简化处理:只检查中心点
transformer = self._get_transformer()
center_lon, center_lat, _ = transformer.transform(cx, cy, cz, radians=False)
# 计算最大偏移
max_half = max(halfX, halfY, halfZ)
earth_radius = 6378137.0
lon_offset = np.degrees(max_half / (earth_radius * np.cos(np.radians(center_lat))))
lat_offset = np.degrees(max_half / earth_radius)
box_min_lon = center_lon - lon_offset
box_max_lon = center_lon + lon_offset
box_min_lat = center_lat - lat_offset
box_max_lat = center_lat + lat_offset
if (box_max_lon < self.filter_min_lon or box_min_lon > self.filter_max_lon or
box_max_lat < self.filter_min_lat or box_min_lat > self.filter_max_lat):
return False
return True
def _check_sphere(self, sphere):
"""检查sphere格式 [centerX, centerY, centerZ, radius]"""
if len(sphere) < 4:
return True
cx, cy, cz, radius = sphere[0], sphere[1], sphere[2], sphere[3]
transformer = self._get_transformer()
center_lon, center_lat, _ = transformer.transform(cx, cy, cz, radians=False)
# 计算半径对应的经纬度偏移
earth_radius = 6378137.0
lon_offset = np.degrees(radius / (earth_radius * np.cos(np.radians(center_lat))))
lat_offset = np.degrees(radius / earth_radius)
sphere_min_lon = center_lon - lon_offset
sphere_max_lon = center_lon + lon_offset
sphere_min_lat = center_lat - lat_offset
sphere_max_lat = center_lat + lat_offset
if (sphere_max_lon < self.filter_min_lon or sphere_min_lon > self.filter_max_lon or
sphere_max_lat < self.filter_min_lat or sphere_min_lat > self.filter_max_lat):
return False
return True
def _get_transformer(self):
"""获取或创建坐标转换器"""
if not hasattr(self, '_transformer'):
ecef = pyproj.Proj(proj='geocent', ellps='WGS84', datum='WGS84')
lla = pyproj.Proj(proj='latlong', ellps='WGS84', datum='WGS84')
self._transformer = pyproj.Transformer.from_proj(ecef, lla)
return self._transformer
def filter_points(self, points):
"""
过滤点集,只保留区域内的点
:param points: 点列表或numpy数组每行 [lon, lat, height]
:return: 过滤后的点列表
"""
if not self.region_coords or len(points) == 0:
return points
# 转换为numpy数组处理
if isinstance(points, list):
points_array = np.array(points)
else:
points_array = points
if len(points_array) == 0 or points_array.shape[1] < 2:
return points_array.tolist() if isinstance(points, list) else points_array
# 检查每个点是否在区域内
in_region_mask = (
(points_array[:, 0] >= self.min_lon) &
(points_array[:, 0] <= self.max_lon) &
(points_array[:, 1] >= self.min_lat) &
(points_array[:, 1] <= self.max_lat)
)
filtered_points = points_array[in_region_mask]
print(f"区域过滤: {len(points_array)} 个点 -> {len(filtered_points)} 个点 "
f"(过滤掉 {len(points_array) - len(filtered_points)} 个点)")
return filtered_points.tolist() if isinstance(points, list) else filtered_points
# ========== 核心工具函数:矩阵变换 ==========
def apply_transform_matrix(vertices, transform_matrix):
"""
将模型的局部相对顶点坐标通过transform矩阵转换为绝对ECEF坐标
:param vertices: 原始局部顶点 (n,3) numpy数组
:param transform_matrix: 瓦片的transform矩阵 一维列表/数组长度164x4矩阵
:return: 绝对ECEF坐标 (n,3) numpy数组
"""
if transform_matrix is None or len(transform_matrix) != 16:
return vertices
# reshape为4x4然后转置因为glTF是列主序
mat = np.array(transform_matrix).reshape(4, 4).astype(np.float64).T
# 顶点齐次坐标化 (n,3) -> (n,4) 最后一列补1
ones = np.ones((vertices.shape[0], 1), dtype=np.float64)
vertices_hom = np.hstack([vertices, ones])
# 矩阵乘法:顶点坐标 * 变换矩阵 = 绝对坐标
vertices_ecef_hom = np.dot(vertices_hom, mat.T)
# 还原为三维坐标 (n,4) -> (n,3)
vertices_ecef = vertices_ecef_hom[:, :3]
return vertices_ecef
def parse_b3dm_to_points(b3dm_path, region_filter=None, transform_matrix=None):
"""
解析B3DM文件,提取顶点的经纬度+高程
【关键修改】区域过滤移到读取顶点后进行
"""
# 获取脚本所在目录
script_dir = os.path.dirname(os.path.abspath(__file__))
temp_dir = os.path.join(script_dir, "temp_glb")
# 创建临时目录(如果不存在)
os.makedirs(temp_dir, exist_ok=True)
# 1. 读取B3DM二进制文件
with open(b3dm_path, "rb") as f:
b3dm_data = f.read()
# 跳过头部
header = struct.unpack('<4sIIIIII', b3dm_data[:28])
ft_json_len, ft_bin_len, bt_json_len, bt_bin_len = header[3:7]
offset = 28
offset += ft_json_len # 跳过Feature Table JSON
offset += ft_bin_len # 跳过Feature Table Binary
offset += bt_json_len # 跳过Batch Table JSON
offset += bt_bin_len # 跳过Batch Table Binary
# 提取glb数据
glb_data = b3dm_data[offset:]
if len(glb_data) < 12:
return []
# 2. 使用脚本目录下的临时文件
temp_file_path = None
try:
# 生成唯一临时文件名
temp_filename = f"temp_{uuid.uuid4().hex[:8]}.glb"
temp_file_path = os.path.join(temp_dir, temp_filename)
# 将GLB数据写入临时文件
with open(temp_file_path, "wb") as tmp_glb:
tmp_glb.write(glb_data)
# 使用DracoGLBParser解析
parser = DracoGLBParser(temp_file_path)
# 解析 GLB 结构
parser.parse_glb_structure()
# 分析结构
parser.analyze_structure()
# 解码 Draco 网格
mesh = parser.decode_draco_meshes()
except Exception as e:
print(f"读取GLB数据失败 {b3dm_path}: {e}")
return []
finally:
# 清理临时文件
if temp_file_path and os.path.exists(temp_file_path):
try:
os.unlink(temp_file_path)
except Exception as e:
pass
if mesh is None:
print(f"无法加载模型: {b3dm_path}")
return []
# 获取顶点数据
vertices = parser.get_all_vertices()
if vertices.size == 0 or len(vertices.shape) < 2 or vertices.shape[1] != 3:
print(f"顶点数据格式无效: {b3dm_path}")
return []
# 3. 应用transform矩阵局部坐标 → 绝对ECEF坐标
vertices = apply_transform_matrix(vertices, transform_matrix)
# 4. ECEF坐标转WGS84经纬度+高程
try:
ecef = pyproj.Proj(proj='geocent', ellps='WGS84', datum='WGS84')
lla = pyproj.Proj(proj='latlong', ellps='WGS84', datum='WGS84')
transformer = pyproj.Transformer.from_proj(ecef, lla, always_xy=True)
lons, lats, heights = transformer.transform(
vertices[:, 0], vertices[:, 1], vertices[:, 2], radians=False
)
# 组合成点集
points = np.column_stack([lons, lats, heights])
# 5. 基本数据清洗(过滤异常值)
# 过滤nan/inf
valid_mask = np.isfinite(points).all(axis=1)
points = points[valid_mask]
if len(points) == 0:
print(f"B3DM文件 {os.path.basename(b3dm_path)} 转换后无有效点")
return []
# 过滤经纬度超限值
geo_mask = (
(points[:, 0] >= -180) & (points[:, 0] <= 180) &
(points[:, 1] >= -90) & (points[:, 1] <= 90)
)
points = points[geo_mask]
if len(points) == 0:
print(f"B3DM文件 {os.path.basename(b3dm_path)} 经纬度超限")
return []
print(f"{os.path.basename(b3dm_path)} 提取到 {len(points)} 个原始顶点")
# 6. 【关键修改】应用区域过滤器(如果提供)
# 在读取顶点后进行区域过滤,而不是在瓦片级别过滤
if region_filter:
points = region_filter.filter_points(points)
if len(points) == 0:
print(f"B3DM文件 {os.path.basename(b3dm_path)} 的所有点都在区域外,跳过")
return []
print(f"最终保留 {len(points)} 个在区域内的顶点")
return points.tolist()
except Exception as e:
print(f"坐标转换失败 {b3dm_path}: {e}")
return []
def traverse_nested_tiles(tile_obj, base_dir, b3dm_paths, tile_transforms, region_filter=None, parent_transform=None):
"""
深度递归遍历瓦片,自动识别「子JSON」和「B3DM」
【修改】移除瓦片级别的区域过滤,完全依赖顶点级别的过滤
"""
# 1. 计算当前瓦片的最终变换矩阵
current_transform = parent_transform
if "transform" in tile_obj:
tile_mat = tile_obj["transform"]
if current_transform is None:
current_transform = tile_mat
else:
# 合并变换矩阵
mat1 = np.array(current_transform).reshape(4, 4)
mat2 = np.array(tile_mat).reshape(4, 4)
combined_mat = np.dot(mat1, mat2).flatten().tolist()
current_transform = combined_mat
# 2. 【修改】不再检查瓦片包围体,直接处理内容
# 这样可以避免因粗略的包围体判断而漏掉部分在区域内的顶点
# 3. 处理当前瓦片的内容
if "content" in tile_obj and "uri" in tile_obj["content"]:
tile_uri = tile_obj["content"]["uri"]
tile_abs_path = os.path.join(base_dir, tile_uri)
if tile_uri.lower().endswith(".json"):
# 情况1:uri是子JSON文件 → 递归解析这个子JSON
if os.path.exists(tile_abs_path):
print(f"解析嵌套子JSON文件: {tile_abs_path}")
with open(tile_abs_path, "r", encoding="utf-8") as f:
sub_tileset = json.load(f)
sub_base_dir = os.path.dirname(tile_abs_path)
traverse_nested_tiles(sub_tileset["root"], sub_base_dir, b3dm_paths, tile_transforms, region_filter, current_transform)
else:
print(f"嵌套子JSON文件不存在,跳过: {tile_abs_path}")
elif tile_uri.lower().endswith(".b3dm"):
# 情况2:uri是B3DM文件 → 收集路径+对应transform矩阵
if os.path.exists(tile_abs_path):
b3dm_paths.append(tile_abs_path)
tile_transforms.append(current_transform)
print(f"收集到B3DM文件: {tile_abs_path}")
else:
print(f"B3DM文件不存在,跳过: {tile_abs_path}")
# 4. 递归遍历当前tile的子节点children
if "children" in tile_obj and len(tile_obj["children"]) > 0:
for child_tile in tile_obj["children"]:
traverse_nested_tiles(child_tile, base_dir, b3dm_paths, tile_transforms, region_filter, current_transform)
def parse_tileset(tileset_path, region_coords=None):
"""重构主解析函数,支持区域过滤+矩阵变换"""
if not os.path.exists(tileset_path):
raise FileNotFoundError(f"根tileset.json文件不存在: {tileset_path}")
# 初始化区域过滤器(将在顶点级别使用)
region_filter = RegionFilter(region_coords)
# 读取根tileset.json
with open(tileset_path, "r", encoding="utf-8") as f:
tileset_json = json.load(f)
root_dir = os.path.dirname(tileset_path)
b3dm_paths = []
tile_transforms = []
print(f"开始遍历tileset结构...")
# 调用深度递归函数
traverse_nested_tiles(tileset_json["root"], root_dir, b3dm_paths, tile_transforms, region_filter, None)
print(f"\n遍历完成,共发现 {len(b3dm_paths)} 个B3DM文件:")
for i, b3dm_path in enumerate(b3dm_paths):
print(f" {i+1}. {os.path.basename(b3dm_path)}")
# 批量解析所有B3DM文件,合并点云
all_points = []
if len(b3dm_paths) == 0:
print("未提取到任何有效的B3DM文件")
return all_points
print(f"\n===== 开始解析B3DM文件 =====")
for i, (b3dm_path, transform_mat) in enumerate(zip(b3dm_paths, tile_transforms), 1):
print(f"解析文件 {i}/{len(b3dm_paths)}: {os.path.basename(b3dm_path)}")
points = parse_b3dm_to_points(b3dm_path, region_filter, transform_mat)
if points:
all_points.extend(points)
print(f" 提取到 {len(points)} 个点")
else:
print(f" 未提取到有效点")
# 点云去重+优化
if all_points:
all_points = np.array(all_points)
original_count = len(all_points)
all_points = np.unique(all_points.round(decimals=6), axis=0)
print(f"\n最终提取点云数量: {len(all_points)} 个 (已去重, 去除了 {original_count - len(all_points)} 个重复点)")
# ========== 新增:输出整个地图文件的经纬度高程范围 ==========
print("\n" + "=" * 60)
print("地图文件总体范围统计:")
print("-" * 60)
# 计算经纬度范围
min_lon = np.min(all_points[:, 0])
max_lon = np.max(all_points[:, 0])
min_lat = np.min(all_points[:, 1])
max_lat = np.max(all_points[:, 1])
min_height = np.min(all_points[:, 2])
max_height = np.max(all_points[:, 2])
avg_height = np.mean(all_points[:, 2])
std_height = np.std(all_points[:, 2])
# 计算中心点
center_lon = (min_lon + max_lon) / 2
center_lat = (min_lat + max_lat) / 2
center_height = (min_height + max_height) / 2
print(f"经度范围: {min_lon:.6f}° ~ {max_lon:.6f}° (跨度: {max_lon - min_lon:.6f}°)")
print(f"纬度范围: {min_lat:.6f}° ~ {max_lat:.6f}° (跨度: {max_lat - min_lat:.6f}°)")
print(f"高程范围: {min_height:.2f}m ~ {max_height:.2f}m (总高差: {max_height - min_height:.2f}m)")
print(f"平均高程: {avg_height:.2f}m (±{std_height:.2f}m)")
print(f"\n中心点坐标:")
print(f" 位置: ({center_lon:.6f}°, {center_lat:.6f}°)")
print(f" 高程: {center_height:.2f}m")
# 输出边界坐标(用于复制使用)
print(f"\n边界坐标:")
print(f" 西北角: ({min_lon:.6f}, {max_lat:.6f})")
print(f" 东北角: ({max_lon:.6f}, {max_lat:.6f})")
print(f" 西南角: ({min_lon:.6f}, {min_lat:.6f})")
print(f" 东南角: ({max_lon:.6f}, {min_lat:.6f})")
# 如果有区域过滤,显示过滤效果
if region_coords:
region_min_lon = min(coord[0] for coord in region_coords)
region_max_lon = max(coord[0] for coord in region_coords)
region_min_lat = min(coord[1] for coord in region_coords)
region_max_lat = max(coord[1] for coord in region_coords)
print(f"\n区域过滤效果:")
print(f" 原始地图范围: Lon[{region_min_lon:.6f}, {region_max_lon:.6f}], Lat[{region_min_lat:.6f}, {region_max_lat:.6f}]")
print(f" 提取数据范围: Lon[{min_lon:.6f}, {max_lon:.6f}], Lat[{min_lat:.6f}, {max_lat:.6f}]")
# 计算覆盖率
region_width = region_max_lon - region_min_lon
region_height = region_max_lat - region_min_lat
extracted_width = max_lon - min_lon
extracted_height = max_lat - min_lat
width_coverage = extracted_width / region_width * 100 if region_width > 0 else 0
height_coverage = extracted_height / region_height * 100 if region_height > 0 else 0
print(f" 经度方向覆盖率: {width_coverage:.1f}%")
print(f" 纬度方向覆盖率: {height_coverage:.1f}%")
print("=" * 60)
return all_points
def points_to_dem(points, output_dem_path, pixel_size=0.0001):
"""将离散点云插值为DEMGeoTIFF格式- 使用Scipy插值优化版本"""
if len(points) == 0:
raise ValueError("无有效点云数据,无法生成DEM")
# 转换为numpy数组
points_array = np.array(points)
lons = points_array[:, 0]
lats = points_array[:, 1]
heights = points_array[:, 2]
min_lon, max_lon = lons.min(), lons.max()
min_lat, max_lat = lats.min(), lats.max()
print(f"DEM范围: Lon[{min_lon:.6f}, {max_lon:.6f}], Lat[{min_lat:.6f}, {max_lat:.6f}]")
print(f"点云数量: {len(points)}")
print(f"高程范围: {heights.min():.2f} ~ {heights.max():.2f}")
# 计算网格尺寸
width = int((max_lon - min_lon) / pixel_size) + 1
height = int((max_lat - min_lat) / pixel_size) + 1
# 限制网格大小,避免过大
max_grid_size = 5000 # 最大网格尺寸
if width > max_grid_size or height > max_grid_size:
print(f"警告: 网格尺寸过大 ({width}x{height}),自动调整像素大小...")
# 重新计算像素大小
larger_dim = max(width, height)
pixel_size = pixel_size * (larger_dim / max_grid_size)
width = int((max_lon - min_lon) / pixel_size) + 1
height = int((max_lat - min_lat) / pixel_size) + 1
print(f"调整后像素大小: {pixel_size:.6f}°")
print(f"DEM网格: {width}x{height} (像素大小: {pixel_size:.6f}°)")
# 创建网格坐标
x_grid = np.linspace(min_lon, max_lon, width)
y_grid = np.linspace(max_lat, min_lat, height) # 纬度从上到下递减
xi, yi = np.meshgrid(x_grid, y_grid)
# 使用scipy进行插值
from scipy.interpolate import griddata
print("开始插值计算...")
# 方法1: 先尝试线性插值
try:
zi = griddata((lons, lats), heights, (xi, yi), method='linear')
nan_count = np.isnan(zi).sum()
nan_percent = nan_count / (width * height) * 100
print(f"线性插值完成,空白区域: {nan_count} 像素 ({nan_percent:.1f}%)")
# 如果有空白区域,使用最近邻方法填充
if nan_count > 0:
print("使用最近邻方法填充空白区域...")
zi_nn = griddata((lons, lats), heights, (xi, yi), method='nearest')
mask = np.isnan(zi)
zi[mask] = zi_nn[mask]
# 再次检查
nan_count = np.isnan(zi).sum()
if nan_count > 0:
print(f"仍有 {nan_count} 个空白像素,填充为最低高程值")
min_height = heights.min()
zi[np.isnan(zi)] = min_height
except Exception as e:
print(f"线性插值失败: {e}")
print("尝试使用最近邻插值...")
zi = griddata((lons, lats), heights, (xi, yi), method='nearest')
# 创建GeoTIFF
print("创建GeoTIFF文件...")
driver = gdal.GetDriverByName("GTiff")
dem_ds = driver.Create(
output_dem_path, width, height, 1, gdal.GDT_Float32,
options=["COMPRESS=LZW", "TILED=YES", "PREDICTOR=2", "ZLEVEL=9"]
)
if dem_ds is None:
raise RuntimeError(f"无法创建DEM文件: {output_dem_path}")
# 设置投影和地理变换
srs = osr.SpatialReference()
srs.ImportFromEPSG(4326) # WGS84
dem_ds.SetProjection(srs.ExportToWkt())
geotransform = [
min_lon, pixel_size, 0,
max_lat, 0, -pixel_size
]
dem_ds.SetGeoTransform(geotransform)
# 写入数据
band = dem_ds.GetRasterBand(1)
band.WriteArray(zi)
band.SetNoDataValue(-9999.0)
band.SetDescription("Elevation")
band.SetUnitType("meters")
# 计算统计信息
print("计算统计信息...")
band.FlushCache()
band.ComputeStatistics(False)
# 设置颜色表(可选)
try:
import matplotlib.pyplot as plt
cmap = plt.cm.terrain
colors = cmap(np.linspace(0, 1, 256))
colors = (colors[:, :3] * 255).astype(np.uint8)
color_table = gdal.ColorTable()
for i in range(256):
color_table.SetColorEntry(i, (colors[i, 0], colors[i, 1], colors[i, 2], 255))
band.SetColorTable(color_table)
band.SetColorInterpretation(gdal.GCI_PaletteIndex)
except:
pass # 如果设置颜色表失败,继续
dem_ds = None # 关闭文件
print(f"DEM生成成功: {output_dem_path}")
print(f" 文件大小: {os.path.getsize(output_dem_path) / (1024*1024):.2f} MB")
# 验证文件
if os.path.exists(output_dem_path):
ds = gdal.Open(output_dem_path, gdal.GA_ReadOnly)
if ds:
band = ds.GetRasterBand(1)
stats = band.GetStatistics(True, True)
print(f" DEM统计: 最小值={stats[0]:.2f}m, 最大值={stats[1]:.2f}m")
print(f" 平均值={stats[2]:.2f}m, 标准差={stats[3]:.2f}m")
ds = None
else:
print("警告: DEM文件可能未成功生成")
def generate_dem(REGION_COORDS=None, tileset_path=None, dem_path=None):
# 配置参数
script_dir = os.path.dirname(os.path.abspath(__file__))
if tileset_path:
TILESET_PATH = tileset_path
else:
TILESET_PATH = os.path.dirname(script_dir) + "/data/3dtiles/tileset.json"
if dem_path :
OUTPUT_DEM_PATH = dem_path
else :
OUTPUT_DEM_PATH = os.path.join(script_dir, f"o_dem_{uuid.uuid4().hex[:8]}.tif")
PIXEL_SIZE = 0.0001
# 执行流程
print("=" * 60)
print("开始解析3D Tiles...")
if REGION_COORDS:
print(f"启用区域过滤: {REGION_COORDS}")
else:
print("未启用区域过滤,将处理所有数据")
points = parse_tileset(TILESET_PATH, REGION_COORDS)
print(f"解析完成,共提取点云: {len(points)}")
if len(points) > 0:
points_to_dem(points, OUTPUT_DEM_PATH, pixel_size=PIXEL_SIZE)
return OUTPUT_DEM_PATH
else:
print("无点云数据,无法生成DEM")
return None
if __name__ == "__main__":
# 测试示例
# 石棉县核心区域
# SHIMIAN_CORE = [(100.22476304, 29.38340151), (110.32476304, 31.28340151)]
SHIMIAN_CORE = [(100.22476304, 29.18340151), (110.32476304, 31.28340151)]
# 可以根据需要启用或禁用区域过滤
REGION_COORDS = SHIMIAN_CORE # 启用区域过滤
# REGION_COORDS = None # 禁用区域过滤
dem_path = generate_dem(REGION_COORDS)
if dem_path:
print(f"\nDEM文件已生成: {dem_path}")