ai_project_v1/b3dm/data_3dtiles_to_dem.py

757 lines
29 KiB
Python
Raw Normal View History

2026-01-14 11:37:35 +08:00
# data_3dtiles_to_dem52
import os
import json
import numpy as np
import pandas as pd
import pyproj
import struct
from osgeo import gdal, osr
import uuid
from glb_with_draco import DracoGLBParser
# 解决GDAL中文路径/警告问题(生产必加)
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
gdal.SetConfigOption("CPL_ZIP_ENCODING", "UTF-8")
gdal.PushErrorHandler('CPLQuietErrorHandler')
class RegionFilter:
"""完整的区域过滤器,支持所有格式"""
def __init__(self, region_coords=None, enable_tile_filter=True, debug=False):
"""
初始化区域过滤器
:param debug: 是否输出调试信息
"""
self.region_coords = region_coords
self.enable_tile_filter = enable_tile_filter
self.debug = debug
if region_coords:
# 提取区域边界
lons = [coord[0] for coord in region_coords]
lats = [coord[1] for coord in region_coords]
self.min_lon = min(lons)
self.max_lon = max(lons)
self.min_lat = min(lats)
self.max_lat = max(lats)
# 扩展边界(避免边缘误差)
self.expand_factor = 0.1
lon_expand = (self.max_lon - self.min_lon) * self.expand_factor
lat_expand = (self.max_lat - self.min_lat) * self.expand_factor
self.filter_min_lon = self.min_lon - lon_expand
self.filter_max_lon = self.max_lon + lon_expand
self.filter_min_lat = self.min_lat - lat_expand
self.filter_max_lat = self.max_lat + lat_expand
if self.debug:
print(f"[DEBUG] 区域过滤器初始化:")
print(f" 目标区域: Lon[{self.min_lon:.6f}, {self.max_lon:.6f}], "
f"Lat[{self.min_lat:.6f}, {self.max_lat:.6f}]")
print(f" 过滤区域: Lon[{self.filter_min_lon:.6f}, {self.filter_max_lon:.6f}], "
f"Lat[{self.filter_min_lat:.6f}, {self.filter_max_lat:.6f}]")
else:
self.min_lon = self.min_lat = -180
self.max_lon = self.max_lat = 180
self.filter_min_lon = self.filter_min_lat = -180
self.filter_max_lon = self.filter_max_lat = 180
if self.debug:
print("[DEBUG] 区域过滤器: 未指定区域,将处理所有数据")
def check_tile_bounding_volume(self, bounding_volume):
"""检查瓦片的包围体是否与指定区域相交"""
if not self.region_coords or not self.enable_tile_filter:
return True
try:
if 'region' in bounding_volume:
return self._check_region(bounding_volume['region'])
elif 'box' in bounding_volume:
box = bounding_volume['box']
if self.debug:
print(f"[DEBUG] 检查box包围体长度={len(box)}")
if len(box) == 12:
result = self._check_box_12(box)
elif len(box) == 15:
result = self._check_box_15(box)
else:
if self.debug:
print(f"[DEBUG] 异常box长度 {len(box)},默认通过")
return True
if self.debug and not result:
print(f"[DEBUG] Box被过滤")
return result
elif 'sphere' in bounding_volume:
return self._check_sphere(bounding_volume['sphere'])
return True
except Exception as e:
if self.debug:
print(f"[DEBUG] 包围体检查出错: {e}")
return True
def _check_region(self, region):
"""检查region格式 [west, south, east, north, minHeight, maxHeight]"""
if len(region) != 6:
return True
west, south, east, north, min_h, max_h = region
# 检查是否完全在过滤区域外
if (east < self.filter_min_lon or west > self.filter_max_lon or
north < self.filter_min_lat or south > self.filter_max_lat):
if self.debug:
print(f"[DEBUG] Region过滤: [{west:.3f},{south:.3f},{east:.3f},{north:.3f}]")
return False
return True
def _check_box_12(self, box):
"""检查12值box格式"""
# 提取参数
cx, cy, cz = box[0], box[1], box[2]
halfX = box[3]
halfY = box[7]
halfZ = box[11]
if self.debug:
print(f"[DEBUG] 12值box: center=({cx:.1f},{cy:.1f},{cz:.1f}), "
f"halfs=({halfX},{halfY},{halfZ})")
try:
# 获取转换器
transformer = self._get_transformer()
# 转换中心点
center_lon, center_lat, _ = transformer.transform(
cx, cy, cz, radians=False
)
# 计算最大偏移(简化方法,避免计算所有角点)
max_half = max(halfX, halfY, halfZ)
# 转换为经纬度偏移
earth_radius = 6378137.0
lon_offset = np.degrees(max_half / (earth_radius * np.cos(np.radians(center_lat))))
lat_offset = np.degrees(max_half / earth_radius)
# 计算box范围
box_min_lon = center_lon - lon_offset
box_max_lon = center_lon + lon_offset
box_min_lat = center_lat - lat_offset
box_max_lat = center_lat + lat_offset
if self.debug:
print(f"[DEBUG] 中心: ({center_lon:.6f}, {center_lat:.6f})")
print(f"[DEBUG] 范围: Lon[{box_min_lon:.6f}, {box_max_lon:.6f}], "
f"Lat[{box_min_lat:.6f}, {box_max_lat:.6f}]")
# 检查相交
if (box_max_lon < self.filter_min_lon or box_min_lon > self.filter_max_lon or
box_max_lat < self.filter_min_lat or box_min_lat > self.filter_max_lat):
return False
return True
except Exception as e:
if self.debug:
print(f"[DEBUG] Box12检查失败: {e}")
return True # 失败时默认通过
def _check_box_15(self, box):
"""检查标准15值box格式"""
if len(box) < 15:
return True
cx, cy, cz = box[0], box[1], box[2]
halfX, halfY, halfZ = box[12], box[13], box[14]
# 简化处理:只检查中心点
transformer = self._get_transformer()
center_lon, center_lat, _ = transformer.transform(cx, cy, cz, radians=False)
# 计算最大偏移
max_half = max(halfX, halfY, halfZ)
earth_radius = 6378137.0
lon_offset = np.degrees(max_half / (earth_radius * np.cos(np.radians(center_lat))))
lat_offset = np.degrees(max_half / earth_radius)
box_min_lon = center_lon - lon_offset
box_max_lon = center_lon + lon_offset
box_min_lat = center_lat - lat_offset
box_max_lat = center_lat + lat_offset
if (box_max_lon < self.filter_min_lon or box_min_lon > self.filter_max_lon or
box_max_lat < self.filter_min_lat or box_min_lat > self.filter_max_lat):
return False
return True
def _check_sphere(self, sphere):
"""检查sphere格式 [centerX, centerY, centerZ, radius]"""
if len(sphere) < 4:
return True
cx, cy, cz, radius = sphere[0], sphere[1], sphere[2], sphere[3]
transformer = self._get_transformer()
center_lon, center_lat, _ = transformer.transform(cx, cy, cz, radians=False)
# 计算半径对应的经纬度偏移
earth_radius = 6378137.0
lon_offset = np.degrees(radius / (earth_radius * np.cos(np.radians(center_lat))))
lat_offset = np.degrees(radius / earth_radius)
sphere_min_lon = center_lon - lon_offset
sphere_max_lon = center_lon + lon_offset
sphere_min_lat = center_lat - lat_offset
sphere_max_lat = center_lat + lat_offset
if (sphere_max_lon < self.filter_min_lon or sphere_min_lon > self.filter_max_lon or
sphere_max_lat < self.filter_min_lat or sphere_min_lat > self.filter_max_lat):
return False
return True
def _get_transformer(self):
"""获取或创建坐标转换器"""
if not hasattr(self, '_transformer'):
ecef = pyproj.Proj(proj='geocent', ellps='WGS84', datum='WGS84')
lla = pyproj.Proj(proj='latlong', ellps='WGS84', datum='WGS84')
self._transformer = pyproj.Transformer.from_proj(ecef, lla)
return self._transformer
def filter_points(self, points):
"""
过滤点集只保留区域内的点
:param points: 点列表或numpy数组每行 [lon, lat, height]
:return: 过滤后的点列表
"""
if not self.region_coords or len(points) == 0:
return points
# 转换为numpy数组处理
if isinstance(points, list):
points_array = np.array(points)
else:
points_array = points
if len(points_array) == 0 or points_array.shape[1] < 2:
return points_array.tolist() if isinstance(points, list) else points_array
# 检查每个点是否在区域内
in_region_mask = (
(points_array[:, 0] >= self.min_lon) &
(points_array[:, 0] <= self.max_lon) &
(points_array[:, 1] >= self.min_lat) &
(points_array[:, 1] <= self.max_lat)
)
filtered_points = points_array[in_region_mask]
print(f"区域过滤: {len(points_array)} 个点 -> {len(filtered_points)} 个点 "
f"(过滤掉 {len(points_array) - len(filtered_points)} 个点)")
return filtered_points.tolist() if isinstance(points, list) else filtered_points
# ========== 核心工具函数:矩阵变换 ==========
def apply_transform_matrix(vertices, transform_matrix):
"""
将模型的局部相对顶点坐标通过transform矩阵转换为绝对ECEF坐标
:param vertices: 原始局部顶点 (n,3) numpy数组
:param transform_matrix: 瓦片的transform矩阵 一维列表/数组长度164x4矩阵
:return: 绝对ECEF坐标 (n,3) numpy数组
"""
if transform_matrix is None or len(transform_matrix) != 16:
return vertices
# reshape为4x4然后转置因为glTF是列主序
mat = np.array(transform_matrix).reshape(4, 4).astype(np.float64).T
# 顶点齐次坐标化 (n,3) -> (n,4) 最后一列补1
ones = np.ones((vertices.shape[0], 1), dtype=np.float64)
vertices_hom = np.hstack([vertices, ones])
# 矩阵乘法:顶点坐标 * 变换矩阵 = 绝对坐标
vertices_ecef_hom = np.dot(vertices_hom, mat.T)
# 还原为三维坐标 (n,4) -> (n,3)
vertices_ecef = vertices_ecef_hom[:, :3]
return vertices_ecef
def parse_b3dm_to_points(b3dm_path, region_filter=None, transform_matrix=None):
"""
解析B3DM文件,提取顶点的经纬度+高程
关键修改区域过滤移到读取顶点后进行
"""
# 获取脚本所在目录
script_dir = os.path.dirname(os.path.abspath(__file__))
temp_dir = os.path.join(script_dir, "temp_glb")
# 创建临时目录(如果不存在)
os.makedirs(temp_dir, exist_ok=True)
# 1. 读取B3DM二进制文件
with open(b3dm_path, "rb") as f:
b3dm_data = f.read()
# 跳过头部
header = struct.unpack('<4sIIIIII', b3dm_data[:28])
ft_json_len, ft_bin_len, bt_json_len, bt_bin_len = header[3:7]
offset = 28
offset += ft_json_len # 跳过Feature Table JSON
offset += ft_bin_len # 跳过Feature Table Binary
offset += bt_json_len # 跳过Batch Table JSON
offset += bt_bin_len # 跳过Batch Table Binary
# 提取glb数据
glb_data = b3dm_data[offset:]
if len(glb_data) < 12:
return []
# 2. 使用脚本目录下的临时文件
temp_file_path = None
try:
# 生成唯一临时文件名
temp_filename = f"temp_{uuid.uuid4().hex[:8]}.glb"
temp_file_path = os.path.join(temp_dir, temp_filename)
# 将GLB数据写入临时文件
with open(temp_file_path, "wb") as tmp_glb:
tmp_glb.write(glb_data)
# 使用DracoGLBParser解析
parser = DracoGLBParser(temp_file_path)
# 解析 GLB 结构
parser.parse_glb_structure()
# 分析结构
parser.analyze_structure()
# 解码 Draco 网格
mesh = parser.decode_draco_meshes()
except Exception as e:
print(f"读取GLB数据失败 {b3dm_path}: {e}")
return []
finally:
# 清理临时文件
if temp_file_path and os.path.exists(temp_file_path):
try:
os.unlink(temp_file_path)
except Exception as e:
pass
if mesh is None:
print(f"无法加载模型: {b3dm_path}")
return []
# 获取顶点数据
vertices = parser.get_all_vertices()
if vertices.size == 0 or len(vertices.shape) < 2 or vertices.shape[1] != 3:
print(f"顶点数据格式无效: {b3dm_path}")
return []
# 3. 应用transform矩阵局部坐标 → 绝对ECEF坐标
vertices = apply_transform_matrix(vertices, transform_matrix)
# 4. ECEF坐标转WGS84经纬度+高程
try:
ecef = pyproj.Proj(proj='geocent', ellps='WGS84', datum='WGS84')
lla = pyproj.Proj(proj='latlong', ellps='WGS84', datum='WGS84')
transformer = pyproj.Transformer.from_proj(ecef, lla, always_xy=True)
lons, lats, heights = transformer.transform(
vertices[:, 0], vertices[:, 1], vertices[:, 2], radians=False
)
# 组合成点集
points = np.column_stack([lons, lats, heights])
# 5. 基本数据清洗(过滤异常值)
# 过滤nan/inf
valid_mask = np.isfinite(points).all(axis=1)
points = points[valid_mask]
if len(points) == 0:
print(f"B3DM文件 {os.path.basename(b3dm_path)} 转换后无有效点")
return []
# 过滤经纬度超限值
geo_mask = (
(points[:, 0] >= -180) & (points[:, 0] <= 180) &
(points[:, 1] >= -90) & (points[:, 1] <= 90)
)
points = points[geo_mask]
if len(points) == 0:
print(f"B3DM文件 {os.path.basename(b3dm_path)} 经纬度超限")
return []
print(f"{os.path.basename(b3dm_path)} 提取到 {len(points)} 个原始顶点")
# 6. 【关键修改】应用区域过滤器(如果提供)
# 在读取顶点后进行区域过滤,而不是在瓦片级别过滤
if region_filter:
points = region_filter.filter_points(points)
if len(points) == 0:
print(f"B3DM文件 {os.path.basename(b3dm_path)} 的所有点都在区域外,跳过")
return []
print(f"最终保留 {len(points)} 个在区域内的顶点")
return points.tolist()
except Exception as e:
print(f"坐标转换失败 {b3dm_path}: {e}")
return []
def traverse_nested_tiles(tile_obj, base_dir, b3dm_paths, tile_transforms, region_filter=None, parent_transform=None):
"""
深度递归遍历瓦片,自动识别子JSONB3DM
修改移除瓦片级别的区域过滤完全依赖顶点级别的过滤
"""
# 1. 计算当前瓦片的最终变换矩阵
current_transform = parent_transform
if "transform" in tile_obj:
tile_mat = tile_obj["transform"]
if current_transform is None:
current_transform = tile_mat
else:
# 合并变换矩阵
mat1 = np.array(current_transform).reshape(4, 4)
mat2 = np.array(tile_mat).reshape(4, 4)
combined_mat = np.dot(mat1, mat2).flatten().tolist()
current_transform = combined_mat
# 2. 【修改】不再检查瓦片包围体,直接处理内容
# 这样可以避免因粗略的包围体判断而漏掉部分在区域内的顶点
# 3. 处理当前瓦片的内容
if "content" in tile_obj and "uri" in tile_obj["content"]:
tile_uri = tile_obj["content"]["uri"]
tile_abs_path = os.path.join(base_dir, tile_uri)
if tile_uri.lower().endswith(".json"):
# 情况1:uri是子JSON文件 → 递归解析这个子JSON
if os.path.exists(tile_abs_path):
print(f"解析嵌套子JSON文件: {tile_abs_path}")
with open(tile_abs_path, "r", encoding="utf-8") as f:
sub_tileset = json.load(f)
sub_base_dir = os.path.dirname(tile_abs_path)
traverse_nested_tiles(sub_tileset["root"], sub_base_dir, b3dm_paths, tile_transforms, region_filter, current_transform)
else:
print(f"嵌套子JSON文件不存在,跳过: {tile_abs_path}")
elif tile_uri.lower().endswith(".b3dm"):
# 情况2:uri是B3DM文件 → 收集路径+对应transform矩阵
if os.path.exists(tile_abs_path):
b3dm_paths.append(tile_abs_path)
tile_transforms.append(current_transform)
print(f"收集到B3DM文件: {tile_abs_path}")
else:
print(f"B3DM文件不存在,跳过: {tile_abs_path}")
# 4. 递归遍历当前tile的子节点children
if "children" in tile_obj and len(tile_obj["children"]) > 0:
for child_tile in tile_obj["children"]:
traverse_nested_tiles(child_tile, base_dir, b3dm_paths, tile_transforms, region_filter, current_transform)
def parse_tileset(tileset_path, region_coords=None):
"""重构主解析函数,支持区域过滤+矩阵变换"""
if not os.path.exists(tileset_path):
raise FileNotFoundError(f"根tileset.json文件不存在: {tileset_path}")
# 初始化区域过滤器(将在顶点级别使用)
region_filter = RegionFilter(region_coords)
# 读取根tileset.json
with open(tileset_path, "r", encoding="utf-8") as f:
tileset_json = json.load(f)
root_dir = os.path.dirname(tileset_path)
b3dm_paths = []
tile_transforms = []
print(f"开始遍历tileset结构...")
# 调用深度递归函数
traverse_nested_tiles(tileset_json["root"], root_dir, b3dm_paths, tile_transforms, region_filter, None)
print(f"\n遍历完成,共发现 {len(b3dm_paths)} 个B3DM文件:")
for i, b3dm_path in enumerate(b3dm_paths):
print(f" {i+1}. {os.path.basename(b3dm_path)}")
# 批量解析所有B3DM文件,合并点云
all_points = []
if len(b3dm_paths) == 0:
print("未提取到任何有效的B3DM文件")
return all_points
print(f"\n===== 开始解析B3DM文件 =====")
for i, (b3dm_path, transform_mat) in enumerate(zip(b3dm_paths, tile_transforms), 1):
print(f"解析文件 {i}/{len(b3dm_paths)}: {os.path.basename(b3dm_path)}")
points = parse_b3dm_to_points(b3dm_path, region_filter, transform_mat)
if points:
all_points.extend(points)
print(f" 提取到 {len(points)} 个点")
else:
print(f" 未提取到有效点")
# 点云去重+优化
if all_points:
all_points = np.array(all_points)
original_count = len(all_points)
all_points = np.unique(all_points.round(decimals=6), axis=0)
print(f"\n最终提取点云数量: {len(all_points)} 个 (已去重, 去除了 {original_count - len(all_points)} 个重复点)")
# ========== 新增:输出整个地图文件的经纬度高程范围 ==========
print("\n" + "=" * 60)
print("地图文件总体范围统计:")
print("-" * 60)
# 计算经纬度范围
min_lon = np.min(all_points[:, 0])
max_lon = np.max(all_points[:, 0])
min_lat = np.min(all_points[:, 1])
max_lat = np.max(all_points[:, 1])
min_height = np.min(all_points[:, 2])
max_height = np.max(all_points[:, 2])
avg_height = np.mean(all_points[:, 2])
std_height = np.std(all_points[:, 2])
# 计算中心点
center_lon = (min_lon + max_lon) / 2
center_lat = (min_lat + max_lat) / 2
center_height = (min_height + max_height) / 2
print(f"经度范围: {min_lon:.6f}° ~ {max_lon:.6f}° (跨度: {max_lon - min_lon:.6f}°)")
print(f"纬度范围: {min_lat:.6f}° ~ {max_lat:.6f}° (跨度: {max_lat - min_lat:.6f}°)")
print(f"高程范围: {min_height:.2f}m ~ {max_height:.2f}m (总高差: {max_height - min_height:.2f}m)")
print(f"平均高程: {avg_height:.2f}m (±{std_height:.2f}m)")
print(f"\n中心点坐标:")
print(f" 位置: ({center_lon:.6f}°, {center_lat:.6f}°)")
print(f" 高程: {center_height:.2f}m")
# 输出边界坐标(用于复制使用)
print(f"\n边界坐标:")
print(f" 西北角: ({min_lon:.6f}, {max_lat:.6f})")
print(f" 东北角: ({max_lon:.6f}, {max_lat:.6f})")
print(f" 西南角: ({min_lon:.6f}, {min_lat:.6f})")
print(f" 东南角: ({max_lon:.6f}, {min_lat:.6f})")
# 如果有区域过滤,显示过滤效果
if region_coords:
region_min_lon = min(coord[0] for coord in region_coords)
region_max_lon = max(coord[0] for coord in region_coords)
region_min_lat = min(coord[1] for coord in region_coords)
region_max_lat = max(coord[1] for coord in region_coords)
print(f"\n区域过滤效果:")
print(f" 原始地图范围: Lon[{region_min_lon:.6f}, {region_max_lon:.6f}], Lat[{region_min_lat:.6f}, {region_max_lat:.6f}]")
print(f" 提取数据范围: Lon[{min_lon:.6f}, {max_lon:.6f}], Lat[{min_lat:.6f}, {max_lat:.6f}]")
# 计算覆盖率
region_width = region_max_lon - region_min_lon
region_height = region_max_lat - region_min_lat
extracted_width = max_lon - min_lon
extracted_height = max_lat - min_lat
width_coverage = extracted_width / region_width * 100 if region_width > 0 else 0
height_coverage = extracted_height / region_height * 100 if region_height > 0 else 0
print(f" 经度方向覆盖率: {width_coverage:.1f}%")
print(f" 纬度方向覆盖率: {height_coverage:.1f}%")
print("=" * 60)
return all_points
def points_to_dem(points, output_dem_path, pixel_size=0.0001):
"""将离散点云插值为DEMGeoTIFF格式- 使用Scipy插值优化版本"""
if len(points) == 0:
raise ValueError("无有效点云数据,无法生成DEM")
# 转换为numpy数组
points_array = np.array(points)
lons = points_array[:, 0]
lats = points_array[:, 1]
heights = points_array[:, 2]
min_lon, max_lon = lons.min(), lons.max()
min_lat, max_lat = lats.min(), lats.max()
print(f"DEM范围: Lon[{min_lon:.6f}, {max_lon:.6f}], Lat[{min_lat:.6f}, {max_lat:.6f}]")
print(f"点云数量: {len(points)}")
print(f"高程范围: {heights.min():.2f} ~ {heights.max():.2f}")
# 计算网格尺寸
width = int((max_lon - min_lon) / pixel_size) + 1
height = int((max_lat - min_lat) / pixel_size) + 1
# 限制网格大小,避免过大
max_grid_size = 5000 # 最大网格尺寸
if width > max_grid_size or height > max_grid_size:
print(f"警告: 网格尺寸过大 ({width}x{height}),自动调整像素大小...")
# 重新计算像素大小
larger_dim = max(width, height)
pixel_size = pixel_size * (larger_dim / max_grid_size)
width = int((max_lon - min_lon) / pixel_size) + 1
height = int((max_lat - min_lat) / pixel_size) + 1
print(f"调整后像素大小: {pixel_size:.6f}°")
print(f"DEM网格: {width}x{height} (像素大小: {pixel_size:.6f}°)")
# 创建网格坐标
x_grid = np.linspace(min_lon, max_lon, width)
y_grid = np.linspace(max_lat, min_lat, height) # 纬度从上到下递减
xi, yi = np.meshgrid(x_grid, y_grid)
# 使用scipy进行插值
from scipy.interpolate import griddata
print("开始插值计算...")
# 方法1: 先尝试线性插值
try:
zi = griddata((lons, lats), heights, (xi, yi), method='linear')
nan_count = np.isnan(zi).sum()
nan_percent = nan_count / (width * height) * 100
print(f"线性插值完成,空白区域: {nan_count} 像素 ({nan_percent:.1f}%)")
# 如果有空白区域,使用最近邻方法填充
if nan_count > 0:
print("使用最近邻方法填充空白区域...")
zi_nn = griddata((lons, lats), heights, (xi, yi), method='nearest')
mask = np.isnan(zi)
zi[mask] = zi_nn[mask]
# 再次检查
nan_count = np.isnan(zi).sum()
if nan_count > 0:
print(f"仍有 {nan_count} 个空白像素,填充为最低高程值")
min_height = heights.min()
zi[np.isnan(zi)] = min_height
except Exception as e:
print(f"线性插值失败: {e}")
print("尝试使用最近邻插值...")
zi = griddata((lons, lats), heights, (xi, yi), method='nearest')
# 创建GeoTIFF
print("创建GeoTIFF文件...")
driver = gdal.GetDriverByName("GTiff")
dem_ds = driver.Create(
output_dem_path, width, height, 1, gdal.GDT_Float32,
options=["COMPRESS=LZW", "TILED=YES", "PREDICTOR=2", "ZLEVEL=9"]
)
if dem_ds is None:
raise RuntimeError(f"无法创建DEM文件: {output_dem_path}")
# 设置投影和地理变换
srs = osr.SpatialReference()
srs.ImportFromEPSG(4326) # WGS84
dem_ds.SetProjection(srs.ExportToWkt())
geotransform = [
min_lon, pixel_size, 0,
max_lat, 0, -pixel_size
]
dem_ds.SetGeoTransform(geotransform)
# 写入数据
band = dem_ds.GetRasterBand(1)
band.WriteArray(zi)
band.SetNoDataValue(-9999.0)
band.SetDescription("Elevation")
band.SetUnitType("meters")
# 计算统计信息
print("计算统计信息...")
band.FlushCache()
band.ComputeStatistics(False)
# 设置颜色表(可选)
try:
import matplotlib.pyplot as plt
cmap = plt.cm.terrain
colors = cmap(np.linspace(0, 1, 256))
colors = (colors[:, :3] * 255).astype(np.uint8)
color_table = gdal.ColorTable()
for i in range(256):
color_table.SetColorEntry(i, (colors[i, 0], colors[i, 1], colors[i, 2], 255))
band.SetColorTable(color_table)
band.SetColorInterpretation(gdal.GCI_PaletteIndex)
except:
pass # 如果设置颜色表失败,继续
dem_ds = None # 关闭文件
print(f"DEM生成成功: {output_dem_path}")
print(f" 文件大小: {os.path.getsize(output_dem_path) / (1024*1024):.2f} MB")
# 验证文件
if os.path.exists(output_dem_path):
ds = gdal.Open(output_dem_path, gdal.GA_ReadOnly)
if ds:
band = ds.GetRasterBand(1)
stats = band.GetStatistics(True, True)
print(f" DEM统计: 最小值={stats[0]:.2f}m, 最大值={stats[1]:.2f}m")
print(f" 平均值={stats[2]:.2f}m, 标准差={stats[3]:.2f}m")
ds = None
else:
print("警告: DEM文件可能未成功生成")
def generate_dem(REGION_COORDS=None, tileset_path=None, dem_path=None):
# 配置参数
script_dir = os.path.dirname(os.path.abspath(__file__))
if tileset_path:
TILESET_PATH = tileset_path
else:
TILESET_PATH = os.path.dirname(script_dir) + "/data/3dtiles/tileset.json"
if dem_path :
OUTPUT_DEM_PATH = dem_path
else :
OUTPUT_DEM_PATH = os.path.join(script_dir, f"o_dem_{uuid.uuid4().hex[:8]}.tif")
PIXEL_SIZE = 0.0001
# 执行流程
print("=" * 60)
print("开始解析3D Tiles...")
if REGION_COORDS:
print(f"启用区域过滤: {REGION_COORDS}")
else:
print("未启用区域过滤,将处理所有数据")
points = parse_tileset(TILESET_PATH, REGION_COORDS)
print(f"解析完成,共提取点云: {len(points)}")
if len(points) > 0:
points_to_dem(points, OUTPUT_DEM_PATH, pixel_size=PIXEL_SIZE)
return OUTPUT_DEM_PATH
else:
print("无点云数据,无法生成DEM")
return None
if __name__ == "__main__":
# 测试示例
# 石棉县核心区域
# SHIMIAN_CORE = [(100.22476304, 29.38340151), (110.32476304, 31.28340151)]
SHIMIAN_CORE = [(100.22476304, 29.18340151), (110.32476304, 31.28340151)]
# 可以根据需要启用或禁用区域过滤
REGION_COORDS = SHIMIAN_CORE # 启用区域过滤
# REGION_COORDS = None # 禁用区域过滤
dem_path = generate_dem(REGION_COORDS)
if dem_path:
print(f"\nDEM文件已生成: {dem_path}")