ai_project_v1/b3dm/data_3dtiles_to_dem.py
2026-01-29 11:51:20 +08:00

1367 lines
50 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 文件过滤+顶点过滤+点云补足
#!/usr/bin/env python3
"""
3D Tiles 到 DEM 转换器
功能:
1. 多层过滤:快速过滤明显不在区域内的B3DM文件
2. 精确过滤:顶点级别的区域过滤
3. 点云补足:稀疏点云的智能增强
4. DEM生成:高质量DEM输出
"""
import os
import json
import numpy as np
import struct
import uuid
import re
import time
from math import radians, sin, cos, sqrt, atan2
from osgeo import gdal, osr
import pyproj
from b3dm.glb_with_draco import DracoGLBParser
# Process-wide GDAL settings applied at import time: treat file names as UTF-8
# (paths may contain Chinese characters), decode ZIP entry names as UTF-8, and
# install GDAL's quiet error handler so GDAL messages are not printed.
for _gdal_option, _gdal_value in (
    ("GDAL_FILENAME_IS_UTF8", "YES"),
    ("CPL_ZIP_ENCODING", "UTF-8"),
):
    gdal.SetConfigOption(_gdal_option, _gdal_value)
del _gdal_option, _gdal_value
gdal.PushErrorHandler('CPLQuietErrorHandler')
# ========== 区域过滤器类 ==========
class RegionFilter:
    """Multi-level region filter for 3D Tiles content.

    Two levels are supported:

    * tile level -- cheap rejection of tiles whose ``boundingVolume`` lies
      completely outside the (slightly expanded) target region;
    * point level -- exact rejection of ``[lon, lat, height]`` points outside
      the target region.

    All region bounds handled by this class are WGS84 degrees.
    """

    # WGS84 semi-major axis in metres, used for metre -> degree approximations.
    _EARTH_RADIUS = 6378137.0

    def __init__(self, region_coords=None, enable_tile_filter=True, debug=False):
        """
        Initialise the region filter.

        :param region_coords: region corners ``[(min_lon, min_lat), (max_lon, max_lat)]``
            in degrees (the bounding box of all given points is used);
            falsy disables filtering entirely
        :param enable_tile_filter: whether tile-level (bounding volume) checks run
        :param debug: print diagnostic messages
        """
        self.region_coords = region_coords
        self.enable_tile_filter = enable_tile_filter
        self.debug = debug
        self._transformer = None
        if region_coords:
            lons = [coord[0] for coord in region_coords]
            lats = [coord[1] for coord in region_coords]
            self.min_lon = min(lons)
            self.max_lon = max(lons)
            self.min_lat = min(lats)
            self.max_lat = max(lats)
            # Tile-level checks use a region expanded by 10% of its size on each
            # side, so approximation errors near the border do not reject tiles.
            self.expand_factor = 0.1
            lon_expand = (self.max_lon - self.min_lon) * self.expand_factor
            lat_expand = (self.max_lat - self.min_lat) * self.expand_factor
            self.filter_min_lon = self.min_lon - lon_expand
            self.filter_max_lon = self.max_lon + lon_expand
            self.filter_min_lat = self.min_lat - lat_expand
            self.filter_max_lat = self.max_lat + lat_expand
            if self.debug:
                print(f"[区域过滤] 初始化:")
                print(f" 目标区域: Lon[{self.min_lon:.6f}, {self.max_lon:.6f}], "
                      f"Lat[{self.min_lat:.6f}, {self.max_lat:.6f}]")
                print(f" 过滤区域: Lon[{self.filter_min_lon:.6f}, {self.filter_max_lon:.6f}], "
                      f"Lat[{self.filter_min_lat:.6f}, {self.filter_max_lat:.6f}]")
        else:
            self.min_lon = self.min_lat = -180
            self.max_lon = self.max_lat = 180
            self.filter_min_lon = self.filter_min_lat = -180
            self.filter_max_lon = self.filter_max_lat = 180
            if self.debug:
                print("[区域过滤] 未指定区域,处理所有数据")

    @staticmethod
    def _to_matrix(transform_matrix):
        """Return the 4x4 matrix of a column-major 16-value 3D Tiles transform, or None."""
        if transform_matrix is None or len(transform_matrix) != 16:
            return None
        # order='F' reads the column-major array directly into the matrix.
        return np.asarray(transform_matrix, dtype=np.float64).reshape(4, 4, order='F')

    def apply_transform_to_box(self, center_point, transform_matrix):
        """Apply a column-major 4x4 tile transform to a 3D point.

        A ``None`` (or malformed) transform is treated as identity, so tiles
        without a ``transform`` no longer raise here.

        :return: transformed point as a float64 array of length 3
        """
        point = np.asarray(center_point, dtype=np.float64)[:3]
        mat = self._to_matrix(transform_matrix)
        if mat is None:
            return point.copy()
        point_h = mat @ np.append(point, 1.0)
        return point_h[:3] / point_h[3]

    def _max_scale(self, transform_matrix):
        """Largest stretch factor (spectral norm) of the transform's linear part; 1.0 if none."""
        mat = self._to_matrix(transform_matrix)
        if mat is None:
            return 1.0
        return float(np.linalg.norm(mat[:3, :3], 2))

    def _get_transformer(self):
        """Lazily build and cache an ECEF -> WGS84 lon/lat/height transformer."""
        if self._transformer is None:
            ecef = pyproj.Proj(proj='geocent', ellps='WGS84', datum='WGS84')
            lla = pyproj.Proj(proj='latlong', ellps='WGS84', datum='WGS84')
            self._transformer = pyproj.Transformer.from_proj(ecef, lla, always_xy=True)
        return self._transformer

    def _is_outside(self, min_lon, min_lat, max_lon, max_lat):
        """True if the lon/lat box (degrees) does not intersect the expanded filter region."""
        return (max_lon < self.filter_min_lon or min_lon > self.filter_max_lon or
                max_lat < self.filter_min_lat or min_lat > self.filter_max_lat)

    def _check_extent(self, center_ecef, radius, label):
        """Test a sphere (ECEF centre, radius in metres) against the filter region.

        The sphere is approximated by a lon/lat box around its centre using a
        spherical Earth of radius ``_EARTH_RADIUS``.
        """
        transformer = self._get_transformer()
        center_lon, center_lat, _ = transformer.transform(
            center_ecef[0], center_ecef[1], center_ecef[2], radians=False)
        lat_offset = np.degrees(radius / self._EARTH_RADIUS)
        # cos(lat) -> 0 near the poles; clamp to keep the longitude offset finite.
        cos_lat = max(float(np.cos(np.radians(center_lat))), 1e-12)
        lon_offset = np.degrees(radius / (self._EARTH_RADIUS * cos_lat))
        if self._is_outside(center_lon - lon_offset, center_lat - lat_offset,
                            center_lon + lon_offset, center_lat + lat_offset):
            if self.debug:
                print(f"[{label}过滤] 区域外: 中心({center_lon:.3f},{center_lat:.3f})")
            return False
        return True

    def check_tile_bounding_volume(self, current_transform, bounding_volume):
        """Return False only if the tile's bounding volume is certainly outside the region.

        Unknown volume types and any error during the check keep the tile (True).
        """
        if not self.region_coords or not self.enable_tile_filter:
            return True
        try:
            if 'region' in bounding_volume:
                return self._check_region(bounding_volume['region'])
            elif 'box' in bounding_volume:
                return self._check_box(current_transform, bounding_volume['box'])
            elif 'sphere' in bounding_volume:
                return self._check_sphere(bounding_volume['sphere'], current_transform)
            return True
        except Exception as e:
            if self.debug:
                print(f"[包围体检查] 错误: {e}")
            return True

    def _check_region(self, region):
        """Check a ``region`` volume ``[west, south, east, north, minH, maxH]``.

        Per the 3D Tiles spec the angles are radians (EPSG:4979) and are not
        affected by the tile transform; they are converted to degrees here.
        """
        if len(region) != 6:
            return True
        west, south, east, north = (float(v) for v in np.degrees(region[:4]))
        if west > east:
            # Crosses the antimeridian: keep the tile rather than risk a false rejection.
            return True
        if self._is_outside(west, south, east, north):
            if self.debug:
                print(f"[region过滤] 区域外: [{west:.3f},{south:.3f},{east:.3f},{north:.3f}]")
            return False
        return True

    def _check_box(self, current_transform, box):
        """Dispatch a ``box`` volume to the 12-value or 15-value checker."""
        if len(box) == 12:
            return self._check_box_12(current_transform, box)
        elif len(box) == 15:
            return self._check_box_15(current_transform, box)
        else:
            if self.debug:
                print(f"[box检查] 异常长度 {len(box)},默认通过")
            return True

    def _check_box_12(self, current_transform, box):
        """Check a standard 12-value box: centre followed by three half-axis vectors.

        Centre and half-axes are transformed into the world frame, and the box
        is replaced by its circumscribed sphere, so rotated boxes are bounded
        correctly (the old code read only the diagonal terms).
        """
        try:
            center = self.apply_transform_to_box(box[0:3], current_transform)
            half_axes = np.asarray(box[3:12], dtype=np.float64).reshape(3, 3)  # rows = axis vectors
            mat = self._to_matrix(current_transform)
            if mat is not None:
                half_axes = half_axes @ mat[:3, :3].T  # apply the linear part to each axis
            signs = np.array([[sx, sy, sz] for sx in (-1.0, 1.0)
                              for sy in (-1.0, 1.0) for sz in (-1.0, 1.0)])
            # Farthest corner from the centre bounds the whole box.
            radius = float(np.max(np.linalg.norm(signs @ half_axes, axis=1)))
            if self.debug:
                print(f"[box12检查] center=({center[0]:.1f},{center[1]:.1f},{center[2]:.1f}), "
                      f"radius={radius:.1f}")
            return self._check_extent(center, radius, "box12")
        except Exception as e:
            if self.debug:
                print(f"[box12检查] 错误: {e}")
            return True

    def _check_box_15(self, current_transform, box):
        """Check a 15-value box: centre in [0:3], half-lengths in [12:15].

        NOTE(review): this layout is not part of the 3D Tiles spec; the meaning
        of [3:12] is unknown and ignored. The centre is transformed like the
        12-value box, and the radius bounds a box with the given half-lengths.
        """
        if len(box) < 15:
            return True
        try:
            center = self.apply_transform_to_box(box[0:3], current_transform)
            radius = float(np.linalg.norm(np.asarray(box[12:15], dtype=np.float64)))
            radius *= self._max_scale(current_transform)
            return self._check_extent(center, radius, "box15")
        except Exception as e:
            if self.debug:
                print(f"[box15检查] 错误: {e}")
            return True

    def _check_sphere(self, sphere, current_transform=None):
        """Check a ``sphere`` volume ``[cx, cy, cz, radius]`` given in tile coordinates."""
        if len(sphere) < 4:
            return True
        try:
            center = self.apply_transform_to_box(sphere[0:3], current_transform)
            radius = float(sphere[3]) * self._max_scale(current_transform)
            return self._check_extent(center, radius, "sphere")
        except Exception as e:
            if self.debug:
                print(f"[sphere检查] 错误: {e}")
            return True

    def check_b3dm_file_quick(self, b3dm_path, current_transform, bounding_volume=None):
        """Cheap pre-check deciding whether a B3DM file is worth parsing.

        Uses the tile bounding volume when one is given; otherwise rejects
        files smaller than 1 KiB. Anything else is kept.
        """
        if not self.region_coords or not self.enable_tile_filter:
            return True
        filename = os.path.basename(b3dm_path)
        if self.debug:
            print(f"[快速检查] B3DM文件: {filename}")
        if bounding_volume:
            return self.check_tile_bounding_volume(current_transform, bounding_volume)
        try:
            file_size = os.path.getsize(b3dm_path)
        except OSError:
            # Size unknown: keep the file and let the parser report problems.
            return True
        if file_size < 1024:
            if self.debug:
                print(f"[快速检查] 文件过小({file_size} bytes),跳过")
            return False
        return True

    def filter_points(self, points):
        """Keep only points whose lon/lat lies inside the exact (unexpanded) region.

        :param points: sequence or N x 3 array of [lon, lat, height]
        :return: ``points`` unchanged when no region is set, otherwise a list
        """
        if not self.region_coords or len(points) == 0:
            return points
        points_array = np.array(points)
        if len(points_array) == 0:
            return []
        in_region_mask = (
            (points_array[:, 0] >= self.min_lon) &
            (points_array[:, 0] <= self.max_lon) &
            (points_array[:, 1] >= self.min_lat) &
            (points_array[:, 1] <= self.max_lat)
        )
        filtered_points = points_array[in_region_mask]
        if self.debug:
            filtered_percent = (len(filtered_points) / len(points_array)) * 100
            print(f"[点过滤] {len(points_array)}{len(filtered_points)} 点 ({filtered_percent:.1f}%保留)")
        return filtered_points.tolist()
# ========== 点云增强器类 ==========
class PointCloudEnhancer:
    """Densify sparse point clouds with synthetic points.

    NOTE(review): synthetic points are sampled around / inside the bounding
    box of the supplied mesh vertices. They are guesses, not measurements, and
    can bias a DEM (for example points inside buildings), so use with care.
    """

    def __init__(self, strategy='balanced', seed=None):
        """
        Initialise the enhancer.

        :param strategy: 'minimal' | 'balanced' | 'aggressive'; any other value
            uses the 'aggressive' grid sampling
        :param seed: optional integer seed for reproducible output; None draws
            from numpy's global random state (the original behaviour)
        """
        self.strategy = strategy
        self._rng = np.random if seed is None else np.random.RandomState(seed)

    def enhance_points(self, base_points, original_vertices=None,
                       target_density=0.5, pixel_size=0.0001,
                       min_points=100, target_points=1000):
        """
        Top up a sparse point cloud with synthetic points.

        :param base_points: sequence of [lon, lat, height]
        :param original_vertices: mesh vertices in ECEF metres (N x 3); synthetic
            points are generated in this frame and converted to WGS84
        :param target_density: unused, kept for API compatibility
        :param pixel_size: unused, kept for API compatibility
        :param min_points: enhancement only runs when fewer base points exist
        :param target_points: total number of points to top up to
        :return: ``base_points`` unchanged, or a new list of [lon, lat, height]
        """
        if len(base_points) >= min_points:
            return base_points
        print(f"[点云增强] 点数过少({len(base_points)}),启动增强")
        if original_vertices is None or len(original_vertices) == 0:
            return base_points
        enhanced = self._generate_enhanced_points(
            original_vertices, num_points=target_points - len(base_points))
        if enhanced.size == 0:
            return base_points
        ecef = pyproj.Proj(proj='geocent', ellps='WGS84', datum='WGS84')
        lla = pyproj.Proj(proj='latlong', ellps='WGS84', datum='WGS84')
        transformer = pyproj.Transformer.from_proj(ecef, lla, always_xy=True)
        lons, lats, heights = transformer.transform(
            enhanced[:, 0], enhanced[:, 1], enhanced[:, 2], radians=False
        )
        enhanced_points = np.column_stack([lons, lats, heights])
        # reshape(-1, 3) keeps vstack valid even when base_points is empty.
        base_array = np.asarray(base_points, dtype=np.float64).reshape(-1, 3)
        combined = np.vstack([base_array, enhanced_points])
        print(f"[点云增强] 增强后点数: {len(combined)}")
        return combined.tolist()

    def _generate_enhanced_points(self, vertices, num_points):
        """Generate ``num_points`` synthetic points (N x 3) in the frame of ``vertices``.

        * 'minimal'    -- random vertices jittered by N(0, 0.1)
        * 'balanced'   -- 70% uniform in the vertices' bounding box, 30% vertices jittered by N(0, 0.05)
        * otherwise    -- regular grid over the bounding box, randomly subsampled
        """
        vertices = np.asarray(vertices, dtype=np.float64)
        if len(vertices) == 0 or num_points <= 0:
            return np.array([])
        rng = self._rng
        min_coords = np.min(vertices, axis=0)
        max_coords = np.max(vertices, axis=0)
        ranges = max_coords - min_coords
        if self.strategy == 'minimal':
            noise = rng.randn(num_points, 3) * 0.1
            indices = rng.randint(0, len(vertices), num_points)
            enhanced = vertices[indices] + noise
        elif self.strategy == 'balanced':
            uniform_num = int(num_points * 0.7)
            noise_num = num_points - uniform_num
            uniform_points = min_coords + rng.rand(uniform_num, 3) * ranges
            indices = rng.randint(0, len(vertices), noise_num)
            noise_points = vertices[indices] + rng.randn(noise_num, 3) * 0.05
            enhanced = np.vstack([uniform_points, noise_points])
        else:  # aggressive
            points_per_dim = int(np.ceil(num_points ** (1 / 3)))
            x = np.linspace(min_coords[0], max_coords[0], points_per_dim)
            y = np.linspace(min_coords[1], max_coords[1], points_per_dim)
            z = np.linspace(min_coords[2], max_coords[2], points_per_dim)
            xx, yy, zz = np.meshgrid(x, y, z)
            grid_points = np.column_stack([xx.ravel(), yy.ravel(), zz.ravel()])
            if len(grid_points) > num_points:
                indices = rng.choice(len(grid_points), num_points, replace=False)
                enhanced = grid_points[indices]
            else:
                enhanced = grid_points
        return enhanced[:num_points]
# ========== 核心工具函数 ==========
def apply_transform_matrix(vertices, transform_matrix):
    """Transform N x 3 vertices by a 3D Tiles column-major 4x4 matrix.

    :param vertices: N x 3 numpy array
    :param transform_matrix: 16 values in column-major order, or None
    :return: new N x 3 float64 array; ``vertices`` itself (not a copy) when
        the matrix is None or does not have exactly 16 elements
    """
    if transform_matrix is not None and len(transform_matrix) == 16:
        # order='F' reads the column-major array straight into the matrix M.
        matrix = np.asarray(transform_matrix).astype(np.float64).reshape(4, 4, order='F')
        ones_column = np.ones((vertices.shape[0], 1), dtype=np.float64)
        homogeneous = np.hstack([vertices, ones_column])
        # Row-vector form of M @ v for every vertex.
        projected = homogeneous @ matrix.T
        return projected[:, :3] / projected[:, 3:4]
    return vertices
def parse_b3dm_to_points(b3dm_path, region_filter=None, transform_matrix=None,
                         enhancer=None, min_points_threshold=500):
    """
    Extract the vertices of one B3DM tile as WGS84 ``[lon, lat, height]`` points.

    The function does the following:

    1. Read the file.
    2. Skip the 28-byte B3DM header and the feature/batch tables.
    3. Write the embedded GLB to a temporary file and decode it with
       ``DracoGLBParser``.
    4. Apply the tile transform and convert ECEF to WGS84.
    5. Drop non-finite points, then optionally clip to the region and densify
       sparse results.

    NOTE(review): a feature-table ``RTC_CENTER`` and the glTF y-up -> z-up
    convention are not applied here. Confirm that DracoGLBParser or the data
    make this unnecessary.

    :param b3dm_path: path of the .b3dm file
    :param region_filter: optional RegionFilter for point-level clipping
    :param transform_matrix: accumulated column-major 4x4 tile transform (16 values) or None
    :param enhancer: optional PointCloudEnhancer
    :param min_points_threshold: densify when fewer points than this remain
    :return: list of [lon, lat, height]; [] on any failure
    """
    try:
        with open(b3dm_path, "rb") as f:
            b3dm_data = f.read()
    except Exception as e:
        print(f"[B3DM解析] 读取失败 {b3dm_path}: {e}")
        return []

    # Header: magic, version, byteLength, featureTable JSON/binary lengths,
    # batchTable JSON/binary lengths (all little-endian uint32 after magic).
    try:
        (magic, _version, _byte_length, ft_json_len, ft_bin_len,
         bt_json_len, bt_bin_len) = struct.unpack('<4sIIIIII', b3dm_data[:28])
    except struct.error as e:
        print(f"[B3DM解析] 头部解析失败 {b3dm_path}: {e}")
        return []
    if magic != b'b3dm':
        print(f"[B3DM解析] 不是B3DM文件 {b3dm_path}: magic={magic!r}")
        return []
    offset = 28 + ft_json_len + ft_bin_len + bt_json_len + bt_bin_len
    glb_data = b3dm_data[offset:]
    if len(glb_data) < 12:  # shorter than a GLB header
        return []

    # DracoGLBParser takes a path, so the GLB payload goes through a temp file.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    temp_dir = os.path.join(script_dir, "temp_glb")
    os.makedirs(temp_dir, exist_ok=True)
    temp_file_path = os.path.join(temp_dir, f"temp_{uuid.uuid4().hex[:8]}.glb")
    try:
        with open(temp_file_path, "wb") as tmp_glb:
            tmp_glb.write(glb_data)
        parser = DracoGLBParser(temp_file_path)
        parser.parse_glb_structure()
        parser.analyze_structure()
        mesh = parser.decode_draco_meshes()
        if mesh is None:
            print(f"[B3DM解析] 无法加载模型: {b3dm_path}")
            return []
        vertices = np.asarray(parser.get_all_vertices())
        if vertices.size == 0 or len(vertices.shape) < 2 or vertices.shape[1] != 3:
            print(f"[B3DM解析] 顶点数据无效: {b3dm_path}")
            return []

        vertices_ecef = apply_transform_matrix(vertices, transform_matrix)
        ecef = pyproj.Proj(proj='geocent', ellps='WGS84', datum='WGS84')
        lla = pyproj.Proj(proj='latlong', ellps='WGS84', datum='WGS84')
        transformer = pyproj.Transformer.from_proj(ecef, lla, always_xy=True)
        lons, lats, heights = transformer.transform(
            vertices_ecef[:, 0], vertices_ecef[:, 1], vertices_ecef[:, 2], radians=False
        )
        points = np.column_stack([lons, lats, heights])
        valid_mask = np.isfinite(points).all(axis=1)
        points = points[valid_mask]
        if len(points) == 0:
            return []

        if region_filter:
            # filter_points may return a list; normalise so .tolist() below is safe.
            points = np.asarray(region_filter.filter_points(points), dtype=np.float64)
            if len(points) == 0:
                print(f"[B3DM解析] 所有点都在区域外: {os.path.basename(b3dm_path)}")
                return []
        print(f"[B3DM解析] {os.path.basename(b3dm_path)}: {len(points)}")

        if enhancer and len(points) < min_points_threshold:
            # The enhancer converts its synthetic points from ECEF, so it needs
            # the transformed vertices, not the tile-local ones.
            enhanced = enhancer.enhance_points(
                points.tolist(),
                vertices_ecef[valid_mask],
                target_density=0.3,
                pixel_size=0.0001
            )
            if region_filter:
                enhanced = region_filter.filter_points(enhanced)
            points = np.asarray(enhanced, dtype=np.float64)
        return points.tolist()
    except Exception as e:
        print(f"[B3DM解析] 解析失败 {b3dm_path}: {e}")
        return []
    finally:
        if os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except OSError:
                pass  # best-effort cleanup of the temporary GLB
# ========== 遍历函数 ==========
def traverse_nested_tiles(tile_obj, base_dir, b3dm_paths, tile_transforms,
                          tile_bounding_volumes, region_filter=None,
                          parent_transform=None, stats=None, depth=0):
    """
    Recursively walk a 3D Tiles tile tree and collect B3DM content.

    External ``.json`` tilesets referenced as content are loaded and walked
    with their own directory as base. For every accepted ``.b3dm`` file the
    path, the accumulated world transform (column-major 16 values or None)
    and the tile's bounding volume are appended to the three output lists.

    :param tile_obj: tile dict (e.g. ``tileset["root"]``)
    :param base_dir: directory that relative content URIs resolve against
    :param b3dm_paths: output list of absolute B3DM paths
    :param tile_transforms: output list of accumulated transforms
    :param tile_bounding_volumes: output list of boundingVolume dicts
    :param region_filter: optional RegionFilter for the quick per-file check
    :param parent_transform: accumulated transform of the parent tile
    :param stats: counters dict, created when None
    :param depth: recursion depth (used for log indentation)
    :return: the stats dict
    """
    if stats is None:
        stats = {
            'tiles_checked': 0,
            'b3dm_collected': 0,
            'b3dm_filtered': 0,
            'json_files': 0
        }
    stats['tiles_checked'] += 1
    indent = ' ' * depth

    current_transform = parent_transform
    tile_mat = tile_obj.get("transform")
    if tile_mat is not None:
        if current_transform is None:
            current_transform = tile_mat
        else:
            # Matrices are column-major; order='F' reads them as real matrices.
            # A child's world transform is parent @ child (spec: tile transforms
            # are post-multiplied down the tree).
            parent = np.asarray(current_transform, dtype=np.float64).reshape(4, 4, order='F')
            child = np.asarray(tile_mat, dtype=np.float64).reshape(4, 4, order='F')
            current_transform = (parent @ child).flatten(order='F').tolist()

    content = tile_obj.get("content") or {}
    # "uri" is 3D Tiles 1.0+; "url" is the pre-1.0 key still found in older tilesets.
    tile_uri = content.get("uri") or content.get("url")
    if tile_uri:
        tile_abs_path = os.path.join(base_dir, tile_uri)
        if tile_uri.lower().endswith(".json"):
            stats['json_files'] += 1
            if os.path.exists(tile_abs_path):
                try:
                    with open(tile_abs_path, "r", encoding="utf-8") as f:
                        sub_tileset = json.load(f)
                    sub_base_dir = os.path.dirname(tile_abs_path)
                    traverse_nested_tiles(
                        sub_tileset["root"], sub_base_dir, b3dm_paths, tile_transforms,
                        tile_bounding_volumes, region_filter, current_transform,
                        stats, depth + 1
                    )
                except Exception as e:
                    print(f"{indent}[遍历] JSON解析失败: {tile_abs_path}: {e}")
            else:
                print(f"{indent}[遍历] JSON文件不存在: {tile_abs_path}")
        elif tile_uri.lower().endswith(".b3dm"):
            if os.path.exists(tile_abs_path):
                bounding_volume = tile_obj.get("boundingVolume", {})
                should_process = True
                if region_filter:
                    should_process = region_filter.check_b3dm_file_quick(
                        tile_abs_path, current_transform, bounding_volume
                    )
                if should_process:
                    b3dm_paths.append(tile_abs_path)
                    tile_transforms.append(current_transform)
                    tile_bounding_volumes.append(bounding_volume)
                    stats['b3dm_collected'] += 1
                    if region_filter and region_filter.debug:
                        print(f"{indent}[遍历] 收集: {os.path.basename(tile_abs_path)}")
                else:
                    stats['b3dm_filtered'] += 1
                    if region_filter and region_filter.debug:
                        print(f"{indent}[遍历] 过滤: {os.path.basename(tile_abs_path)}")
            else:
                print(f"{indent}[遍历] B3DM文件不存在: {tile_abs_path}")

    for child_tile in tile_obj.get("children") or []:
        traverse_nested_tiles(
            child_tile, base_dir, b3dm_paths, tile_transforms,
            tile_bounding_volumes, region_filter, current_transform,
            stats, depth + 1
        )
    return stats
# ========== 主解析函数 ==========
def parse_tileset(tileset_path, region_coords=None, enable_enhancement=True, debug=False):
    """
    Parse a 3D Tiles tileset into a WGS84 point cloud.

    If a region is given and the quick tile filter collects no B3DM file,
    the tree is rescanned once with the quick filter disabled.

    :param tileset_path: path of the root tileset.json
    :param region_coords: optional region [(min_lon, min_lat), (max_lon, max_lat)] in degrees
    :param enable_enhancement: densify sparse tiles with PointCloudEnhancer('balanced')
    :param debug: verbose filter output
    :return: list of [lon, lat, height], rounded to 6 decimals, de-duplicated
        and clipped to the region; [] when nothing was extracted
    :raises FileNotFoundError: if tileset_path does not exist
    """
    t_start = time.time()
    if not os.path.exists(tileset_path):
        raise FileNotFoundError(f"tileset.json不存在: {tileset_path}")

    separator = "=" * 60
    print(separator)
    print("开始解析3D Tiles数据")
    print(separator)

    region_filter = RegionFilter(region_coords, enable_tile_filter=True, debug=debug)
    enhancer = None
    if enable_enhancement:
        enhancer = PointCloudEnhancer(strategy='balanced')

    with open(tileset_path, "r", encoding="utf-8") as fh:
        tileset_json = json.load(fh)
    root_dir = os.path.dirname(tileset_path)
    paths, transforms, volumes = [], [], []

    def collect():
        """Walk the tree from the root with fresh counters, filling the three lists."""
        fresh_stats = {'tiles_checked': 0, 'b3dm_collected': 0,
                       'b3dm_filtered': 0, 'json_files': 0}
        return traverse_nested_tiles(
            tileset_json["root"], root_dir, paths, transforms, volumes,
            region_filter, None, fresh_stats
        )

    print("[主解析] 开始遍历tileset结构...")
    stats = collect()
    print(f"\n[主解析] 遍历完成:")
    print(f" 检查瓦片: {stats['tiles_checked']}")
    print(" JSON入口: 1")
    print(f" JSON文件: {stats['json_files']}")
    print(f" 收集B3DM: {stats['b3dm_collected']}")
    print(f" 过滤B3DM: {stats['b3dm_filtered']}")

    if not paths and region_coords:
        # The quick filter may be too strict; rescan with it disabled.
        print("\n[主解析] 警告: 未收集到B3DM文件尝试禁用快速过滤...")
        region_filter.enable_tile_filter = False
        for bucket in (paths, transforms, volumes):
            bucket.clear()
        stats = collect()
        print(f"[主解析] 重新扫描收集到 {len(paths)} 个B3DM文件")

    if not paths:
        print("[主解析] 错误: 未找到有效的B3DM文件")
        return []

    file_count = len(paths)
    print(f"\n[主解析] 开始解析 {file_count} 个B3DM文件...")
    collected = []
    total = 0
    for idx, (path, matrix) in enumerate(zip(paths, transforms), start=1):
        print(f" [{idx}/{file_count}] 解析: {os.path.basename(path)}")
        tile_points = parse_b3dm_to_points(
            path, region_filter, matrix, enhancer, min_points_threshold=300
        )
        if not tile_points:
            print(" 无有效点")
            continue
        collected.extend(tile_points)
        total += len(tile_points)
        print(f" 提取: {len(tile_points)}")

    if not collected:
        print("[主解析] 错误: 未提取到任何点云数据")
        return []

    # Round to 6 decimals before de-duplicating so near-identical points merge.
    merged = np.unique(np.array(collected).round(decimals=6), axis=0)
    if region_filter:
        merged = np.array(region_filter.filter_points(merged))
    elapsed = time.time() - t_start
    print(f"\n[主解析] 解析完成:")
    print(f" 原始点数: {total}")
    print(f" 去重后数: {len(merged)}")
    print(f" 处理时间: {elapsed:.1f}")
    print(f" 点云密度: {len(merged)/max(1, total)*100:.1f}%保留")
    return merged.tolist()
# ========== DEM生成函数 ==========
def points_to_dem(points, output_dem_path, pixel_size=None, quality='high', min_resolution=10):
    """
    Rasterise a point cloud into a single-band Float32 GeoTIFF DEM (EPSG:4326).

    The data extent is padded by 10% (5% per side) and the grid is sized for
    slope/aspect analysis. Cells are sampled at pixel centres
    ``origin + (i + 0.5) * pixel_size``, so the written geotransform describes
    exactly the grid that was interpolated. Previously, linspace sampling over
    the padded extent made the grid spacing differ from ``pixel_size``.

    :param points: sequence of (lon, lat, height)
    :param output_dem_path: output GeoTIFF path
    :param pixel_size: pixel size in degrees; None = derived from ``quality``
    :param quality: 'low' | 'medium' | 'high' | 'terrain' (terrain analysis)
    :param min_resolution: minimum grid dimension (pixels) used when auto-adjusting
    :return: output_dem_path
    :raises ValueError: if ``points`` is empty or ``pixel_size`` is not positive
    :raises RuntimeError: if GDAL cannot create the output file
    """
    if len(points) == 0:
        raise ValueError("无点云数据无法生成DEM")
    if pixel_size is not None and pixel_size <= 0:
        raise ValueError(f"pixel_size must be positive, got {pixel_size}")
    from scipy.interpolate import griddata

    start_time = time.time()
    points_array = np.asarray(points, dtype=np.float64)
    lons = points_array[:, 0]
    lats = points_array[:, 1]
    heights = points_array[:, 2]

    min_lon, max_lon = lons.min(), lons.max()
    min_lat, max_lat = lats.min(), lats.max()
    lon_range = max_lon - min_lon
    lat_range = max_lat - min_lat
    print(f"[DEM生成] 点云范围:")
    print(f" 经度: {min_lon:.6f}° ~ {max_lon:.6f}° (范围: {lon_range:.6f}°)")
    print(f" 纬度: {min_lat:.6f}° ~ {max_lat:.6f}° (范围: {lat_range:.6f}°)")
    print(f" 高程: {heights.min():.2f}m ~ {heights.max():.2f}m")
    print(f" 点数: {len(points):,}")

    if pixel_size is None:
        if quality == 'terrain':
            # Enough cells for slope computation, scaled with the point count.
            target_pixels = max(min_resolution, int(np.sqrt(len(points)) / 2))
            pixel_size = min(lon_range / target_pixels, lat_range / target_pixels)
            pixel_size = max(pixel_size, 0.000001)  # at least ~0.1 m
            pixel_size = min(pixel_size, 0.0005)    # at most ~55 m
        elif quality == 'high':
            pixel_size = max(0.00001, lon_range / 100)  # >= ~1.1 m, <= ~100 px wide
        elif quality == 'medium':
            pixel_size = max(0.00002, lon_range / 50)   # >= ~2.2 m, <= ~50 px wide
        else:  # low
            pixel_size = max(0.00005, lon_range / 20)   # >= ~5.6 m, <= ~20 px wide

    width = max(10, int(lon_range / pixel_size) + 1)
    height = max(10, int(lat_range / pixel_size) + 1)
    if width * height < min_resolution * min_resolution:
        print(f"[DEM生成] 警告:网格分辨率不足 ({width}x{height}),自动调整...")
        target_cells = min_resolution * min_resolution
        target_pixel_size = np.sqrt(lon_range * lat_range / target_cells)
        pixel_size = max(0.000001, min(pixel_size, target_pixel_size))
        width = max(min_resolution, int(lon_range / pixel_size) + 1)
        height = max(min_resolution, int(lat_range / pixel_size) + 1)

    # Pad the extent by 10% (5% per side) so slope kernels get edge pixels.
    expand_factor = 1.1
    pad_lon = lon_range * (expand_factor - 1) / 2
    pad_lat = lat_range * (expand_factor - 1) / 2
    min_lon_exp = min_lon - pad_lon
    max_lon_exp = max_lon + pad_lon
    min_lat_exp = min_lat - pad_lat
    max_lat_exp = max_lat + pad_lat
    # The raster must cover the padded extent at exactly pixel_size spacing,
    # otherwise the geotransform written below would not match the data.
    width = max(width, int(np.ceil((max_lon_exp - min_lon_exp) / pixel_size)))
    height = max(height, int(np.ceil((max_lat_exp - min_lat_exp) / pixel_size)))

    print(f"[DEM生成] 网格设置:")
    print(f" 像素大小: {pixel_size:.6f}° (~{pixel_size*111320:.1f}米)")
    print(f" 网格尺寸: {width} × {height}")
    print(f" 总像素数: {width * height:,}")
    print(f" 点云密度: {len(points)/(width*height):.2f} 点/像素")

    # Pixel-centre coordinates; rows run north -> south to match the geotransform.
    x_grid = min_lon_exp + (np.arange(width) + 0.5) * pixel_size
    y_grid = max_lat_exp - (np.arange(height) + 0.5) * pixel_size
    xi, yi = np.meshgrid(x_grid, y_grid)

    print("[DEM生成] 开始插值计算...")
    try:
        from scipy.spatial import cKDTree
        dists, _ = cKDTree(np.column_stack([lons, lats])).query(
            np.column_stack([xi.ravel(), yi.ravel()]), k=1)
        print(f"[DEM生成] 最大最近邻距离: {np.max(dists)*111320:.1f}")
    except Exception as e:
        # Diagnostic only; never abort DEM generation because of it.
        print(f"[DEM生成] 最近邻诊断跳过: {e}")

    def linear_or_nearest():
        """Linear griddata; nearest-neighbour when Qhull cannot triangulate (e.g. collinear points)."""
        try:
            return griddata((lons, lats), heights, (xi, yi), method='linear', fill_value=np.nan)
        except Exception as err:
            print(f"[DEM生成] 线性插值失败: {err},使用最近邻插值")
            return griddata((lons, lats), heights, (xi, yi), method='nearest')

    try:
        if len(points) > 100:
            print("[DEM生成] 使用IDW插值...")
            zi = idw_interpolation(lons, lats, heights, xi, yi, power=2)
        else:
            zi = linear_or_nearest()
    except Exception as e:
        print(f"[DEM生成] IDW插值失败: {e},使用线性插值")
        zi = linear_or_nearest()

    nan_mask = np.isnan(zi)
    if np.any(nan_mask):
        nan_count = int(np.sum(nan_mask))
        nan_percent = nan_count / (width * height) * 100
        print(f"[DEM生成] 插值空白: {nan_count:,} 像素 ({nan_percent:.1f}%)")
        if nan_percent < 50:
            zi_nn = griddata((lons, lats), heights, (xi, yi), method='nearest')
            zi[nan_mask] = zi_nn[nan_mask]
        else:
            print("[DEM生成] 警告:空白区域过多,使用最近邻插值")
            zi = griddata((lons, lats), heights, (xi, yi), method='nearest')

    if quality in ['terrain', 'high']:
        try:
            from scipy.ndimage import gaussian_filter
            sigma = 0.5  # light smoothing helps slope computation
            zi = gaussian_filter(zi, sigma=sigma, mode='nearest')
            print(f"[DEM生成] 应用高斯平滑 (sigma={sigma})")
        except ImportError:
            pass  # smoothing is optional

    nodata = -9999.0
    # Any cell that is still undefined is written as nodata, never NaN.
    zi = np.where(np.isfinite(zi), zi, nodata)

    print("[DEM生成] 创建GeoTIFF文件...")
    driver = gdal.GetDriverByName("GTiff")
    if quality in ['terrain', 'high']:
        options = ["COMPRESS=DEFLATE", "PREDICTOR=3", "ZLEVEL=6", "TILED=YES", "BLOCKXSIZE=256", "BLOCKYSIZE=256"]
    else:
        options = ["COMPRESS=LZW", "TILED=YES"]
    dem_ds = driver.Create(output_dem_path, width, height, 1, gdal.GDT_Float32, options)
    if dem_ds is None:
        raise RuntimeError(f"无法创建DEM文件: {output_dem_path}")
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(4326)  # WGS84
    dem_ds.SetProjection(srs.ExportToWkt())
    # Top-left corner of the padded extent; matches the pixel-centre grid above.
    dem_ds.SetGeoTransform([
        min_lon_exp, pixel_size, 0,
        max_lat_exp, 0, -pixel_size
    ])

    band = dem_ds.GetRasterBand(1)
    band.WriteArray(zi)
    band.SetNoDataValue(nodata)
    band.SetDescription("Elevation (meters)")
    band.SetUnitType("meters")

    print("[DEM生成] 计算统计信息...")
    band.FlushCache()
    try:
        band.ComputeStatistics(False)
        print("[DEM生成] 统计信息计算完成")
    except RuntimeError as e:
        print(f"[DEM生成] 警告:统计信息计算失败: {e}")
        valid_mask = zi != nodata
        if np.any(valid_mask):
            valid_data = zi[valid_mask]
            stats = [
                float(np.min(valid_data)),
                float(np.max(valid_data)),
                float(np.mean(valid_data)),
                float(np.std(valid_data))
            ]
            band.SetStatistics(*stats)
            print(f"[DEM生成] 手动设置统计信息:")
            print(f" 最小值: {stats[0]:.2f}m")
            print(f" 最大值: {stats[1]:.2f}m")
            print(f" 平均值: {stats[2]:.2f}m")
            print(f" 标准差: {stats[3]:.2f}m")

    if quality in ['terrain', 'high']:
        print("[DEM生成] 构建金字塔...")
        gdal.SetConfigOption('COMPRESS_OVERVIEW', 'DEFLATE')
        dem_ds.BuildOverviews("AVERAGE", [2, 4, 8, 16])
    band = None
    dem_ds = None  # closes and flushes the file

    elapsed_time = time.time() - start_time
    file_size_mb = os.path.getsize(output_dem_path) / (1024 * 1024) if os.path.exists(output_dem_path) else 0
    try:
        ds = gdal.Open(output_dem_path, gdal.GA_ReadOnly)
        if ds:
            actual_width = ds.RasterXSize
            actual_height = ds.RasterYSize
            ds = None
            print(f"[DEM生成] 验证结果:")
            print(f" 实际尺寸: {actual_width} × {actual_height}")
            print(f" 文件大小: {file_size_mb:.2f} MB")
            print(f" 处理时间: {elapsed_time:.1f}")
            if actual_width >= 10 and actual_height >= 10:
                print(f"[DEM生成] ✓ DEM分辨率适合坡度坡向计算")
            else:
                print(f"[DEM生成] ⚠ DEM分辨率较低坡度计算可能不准确")
    except RuntimeError as e:
        print(f"[DEM生成] 警告:验证失败: {e}")
    print(f"[DEM生成] 完成! 文件: {output_dem_path}")
    return output_dem_path
def idw_interpolation(x, y, z, xi, yi, power=2, radius=None, max_radius_samples=2000):
    """
    Inverse-distance-weighted interpolation of scattered samples onto query points.

    With more than 50 samples, only the 12 nearest samples within ``radius``
    contribute. Otherwise every sample contributes.

    :param x: 1-D sample x coordinates
    :param y: 1-D sample y coordinates
    :param z: 1-D sample values
    :param xi: query x coordinates (any shape, e.g. from np.meshgrid)
    :param yi: query y coordinates (same shape as ``xi``)
    :param power: weight exponent, weights are ``1 / d**power``
    :param radius: neighbour search radius. None sets it to 3 x the mean
        pairwise distance between samples, estimated on at most
        ``max_radius_samples`` evenly spaced samples. pdist over all samples
        is O(n^2) in memory.
    :param max_radius_samples: cap on samples used to estimate ``radius``
    :return: array shaped like ``xi``; NaN where no sample is within ``radius``
    """
    from scipy.spatial import cKDTree
    from scipy.spatial.distance import pdist

    x = np.asarray(x, dtype=np.float64).ravel()
    y = np.asarray(y, dtype=np.float64).ravel()
    z = np.asarray(z, dtype=np.float64).ravel()
    xi = np.asarray(xi)
    yi = np.asarray(yi)
    query = np.column_stack([xi.ravel(), yi.ravel()])
    n = len(z)
    if n == 0:
        return np.full(xi.shape, np.nan)

    samples = np.column_stack([x, y])
    tree = cKDTree(samples)

    if radius is None:
        if n > 1:
            if n > max_radius_samples:
                subset = samples[np.linspace(0, n - 1, max_radius_samples).astype(int)]
            else:
                subset = samples
            radius = np.mean(pdist(subset)) * 3
        else:
            radius = 0.01

    if n > 50:
        dists, idxs = tree.query(query, k=min(12, n), distance_upper_bound=radius)
    else:
        dists, idxs = tree.query(query, k=n)
    # k == 1 yields 1-D results; normalise to (num_queries, k).
    dists = np.asarray(dists, dtype=np.float64).reshape(len(query), -1)
    idxs = np.asarray(idxs).reshape(len(query), -1)

    # cKDTree reports missing neighbours with index n and distance inf.
    valid = idxs < n
    safe_idx = np.where(valid, idxs, 0)
    with np.errstate(divide='ignore', over='ignore'):
        # Clamp distances to avoid division by zero at coincident points.
        weights = np.where(valid, 1.0 / np.maximum(dists, 1e-10) ** power, 0.0)
    weight_sum = weights.sum(axis=1)
    has_neighbour = weight_sum > 0

    zi_flat = np.full(len(query), np.nan)
    zi_flat[has_neighbour] = (
        (weights[has_neighbour] * z[safe_idx[has_neighbour]]).sum(axis=1)
        / weight_sum[has_neighbour]
    )
    return zi_flat.reshape(xi.shape)
def points_to_dem1(points, output_dem_path, pixel_size=None, quality='medium'):
    """
    Legacy rasteriser: point cloud -> Float32 GeoTIFF DEM (EPSG:4326).

    Uses linear interpolation, fills gaps with nearest-neighbour values and
    fills any remaining gaps with the minimum height. Cells are sampled at
    pixel centres, so the geotransform (origin = top-left corner) matches the
    interpolated grid.

    :param points: sequence of (lon, lat, height)
    :param output_dem_path: output GeoTIFF path
    :param pixel_size: pixel size in degrees; None = chosen from quality and point count
    :param quality: 'low' | 'medium' | 'high'
    :return: output_dem_path
    :raises ValueError: if ``points`` is empty or ``pixel_size`` is not positive
    :raises RuntimeError: if GDAL cannot create the output file
    """
    if len(points) == 0:
        raise ValueError("无点云数据无法生成DEM")
    if pixel_size is not None and pixel_size <= 0:
        raise ValueError(f"pixel_size must be positive, got {pixel_size}")
    from scipy.interpolate import griddata

    start_time = time.time()
    points_array = np.asarray(points, dtype=np.float64)
    lons = points_array[:, 0]
    lats = points_array[:, 1]
    heights = points_array[:, 2]

    min_lon, max_lon = lons.min(), lons.max()
    min_lat, max_lat = lats.min(), lats.max()
    print(f"[DEM生成] 点云范围:")
    print(f" 经度: {min_lon:.6f}° ~ {max_lon:.6f}°")
    print(f" 纬度: {min_lat:.6f}° ~ {max_lat:.6f}°")
    print(f" 高程: {heights.min():.2f}m ~ {heights.max():.2f}m")
    print(f" 点数: {len(points)}")

    if pixel_size is None:
        if quality == 'high':
            pixel_size = 0.00005  # ~5 m
        elif quality == 'medium':
            pixel_size = 0.0001   # ~10 m
        else:  # low
            pixel_size = 0.0002   # ~20 m
        # Sparse clouds get coarser pixels.
        if len(points) < 1000:
            pixel_size = max(pixel_size, 0.0005)  # at least ~50 m
        elif len(points) < 5000:
            pixel_size = max(pixel_size, 0.0002)  # at least ~20 m

    width = int((max_lon - min_lon) / pixel_size) + 1
    height = int((max_lat - min_lat) / pixel_size) + 1
    max_grid_size = 10000
    if width > max_grid_size or height > max_grid_size:
        print(f"[DEM生成] 警告: 网格过大({width}x{height}),调整像素大小")
        larger_dim = max(width, height)
        pixel_size = pixel_size * (larger_dim / max_grid_size)
        width = int((max_lon - min_lon) / pixel_size) + 1
        height = int((max_lat - min_lat) / pixel_size) + 1
    print(f"[DEM生成] 网格设置:")
    print(f" 像素大小: {pixel_size:.6f}° (~{pixel_size*111320:.1f}米)")
    print(f" 网格尺寸: {width} × {height}")
    print(f" 总像素数: {width * height:,}")

    # Pixel-centre coordinates; rows run north -> south like the geotransform.
    x_grid = min_lon + (np.arange(width) + 0.5) * pixel_size
    y_grid = max_lat - (np.arange(height) + 0.5) * pixel_size
    xi, yi = np.meshgrid(x_grid, y_grid)

    print("[DEM生成] 开始插值计算...")
    try:
        zi = griddata((lons, lats), heights, (xi, yi), method='linear')
        nan_count = np.isnan(zi).sum()
        if nan_count > 0:
            print(f"[DEM生成] 线性插值空白: {nan_count} 像素 ({nan_count/(width*height)*100:.1f}%)")
            zi_nn = griddata((lons, lats), heights, (xi, yi), method='nearest')
            mask = np.isnan(zi)
            zi[mask] = zi_nn[mask]
        nan_count = np.isnan(zi).sum()
        if nan_count > 0:
            print(f"[DEM生成] 填充剩余空白: {nan_count} 像素")
            zi[np.isnan(zi)] = heights.min()
    except Exception as e:
        # Qhull fails on fewer than 3 or collinear points.
        print(f"[DEM生成] 插值失败: {e},使用最近邻插值")
        zi = griddata((lons, lats), heights, (xi, yi), method='nearest')

    print("[DEM生成] 创建GeoTIFF文件...")
    driver = gdal.GetDriverByName("GTiff")
    if quality == 'high':
        # PREDICTOR=3 is the floating-point predictor suited to Float32 data.
        options = ["COMPRESS=DEFLATE", "PREDICTOR=3", "ZLEVEL=9", "TILED=YES", "BIGTIFF=IF_SAFER"]
    else:
        options = ["COMPRESS=LZW", "TILED=YES"]
    dem_ds = driver.Create(output_dem_path, width, height, 1, gdal.GDT_Float32, options)
    if dem_ds is None:
        raise RuntimeError(f"无法创建DEM文件: {output_dem_path}")
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(4326)  # WGS84
    dem_ds.SetProjection(srs.ExportToWkt())
    dem_ds.SetGeoTransform([
        min_lon, pixel_size, 0,
        max_lat, 0, -pixel_size
    ])

    band = dem_ds.GetRasterBand(1)
    band.WriteArray(zi)
    band.SetNoDataValue(-9999.0)
    band.SetDescription("Elevation (meters)")
    band.SetUnitType("meters")
    print("[DEM生成] 计算统计信息...")
    band.FlushCache()
    band.ComputeStatistics(False)
    # No colour table: GTiff palettes apply only to Byte/UInt16 bands, and
    # marking a Float32 elevation band as palette-indexed is wrong.
    band = None
    dem_ds = None  # closes and flushes the file

    if os.path.exists(output_dem_path):
        ds = gdal.Open(output_dem_path, gdal.GA_ReadOnly)
        if ds:
            stats = ds.GetRasterBand(1).GetStatistics(True, True)
            ds = None
            print(f"[DEM生成] DEM统计:")
            print(f" 最小值: {stats[0]:.2f}m")
            print(f" 最大值: {stats[1]:.2f}m")
            print(f" 平均值: {stats[2]:.2f}m")
            print(f" 标准差: {stats[3]:.2f}m")
        elapsed_time = time.time() - start_time
        file_size_mb = os.path.getsize(output_dem_path) / (1024 * 1024)
        print(f"[DEM生成] 完成!")
        print(f" 输出文件: {output_dem_path}")
        print(f" 文件大小: {file_size_mb:.2f} MB")
        print(f" 处理时间: {elapsed_time:.1f}")
    return output_dem_path
# ========== 主函数 ==========
def generate_dem(tileset_path, dem_path=None, region_coords=None,
                 pixel_size=None, quality='medium', enable_enhancement=True,
                 debug=False):
    """
    Convert a 3D Tiles tileset into a DEM GeoTIFF.

    :param tileset_path: path of the root tileset.json
    :param dem_path: output GeoTIFF path; None = a randomly named
        ``o_dem_<hex>.tif`` next to this script. Missing parent directories
        are created.
    :param region_coords: optional region [(min_lon, min_lat), (max_lon, max_lat)] in degrees
    :param pixel_size: DEM pixel size in degrees; None = chosen automatically
    :param quality: passed to points_to_dem: 'low' | 'medium' | 'high' | 'terrain'
    :param enable_enhancement: densify sparse tiles with synthetic points
    :param debug: verbose filter output
    :return: the DEM path, or None when no points were extracted
    :raises FileNotFoundError: if tileset_path does not exist
    """
    print("=" * 60)
    print("3D Tiles 到 DEM 转换器")
    print("=" * 60)
    if not os.path.exists(tileset_path):
        raise FileNotFoundError(f"输入文件不存在: {tileset_path}")

    if dem_path is None:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        dem_path = os.path.join(script_dir, f"o_dem_{uuid.uuid4().hex[:8]}.tif")

    print(f"[主函数] 开始解析3D Tiles数据...")
    if region_coords:
        print(f"[主函数] 区域过滤: {region_coords}")
    points = parse_tileset(
        tileset_path,
        region_coords=region_coords,
        enable_enhancement=enable_enhancement,
        debug=debug
    )
    if len(points) == 0:
        print("[主函数] 错误: 未提取到点云数据")
        return None

    # GDAL's GTiff driver does not create missing directories.
    os.makedirs(os.path.dirname(os.path.abspath(dem_path)), exist_ok=True)
    print(f"\n[主函数] 开始生成DEM...")
    dem_path = points_to_dem(
        points,
        dem_path,
        pixel_size=pixel_size,
        quality=quality
    )
    print(f"\n[主函数] 转换完成!")
    print(f" DEM文件: {dem_path}")
    print(f" 点云数量: {len(points)}")
    return dem_path
# ========== 示例用法 ==========
if __name__ == "__main__":
    SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

    # 1. Root tileset.json, expected at <parent of this script's dir>/data/3dtiles/.
    TILESET_PATH = os.path.join(os.path.dirname(SCRIPT_DIR), "data", "3dtiles", "tileset.json")

    # 2. Optional region of interest as [(min_lon, min_lat), (max_lon, max_lat)]
    #    in degrees; set REGION_COORDS = None to process the whole tileset.
    # BEIJING_REGION = [(116.3, 39.9), (116.5, 40.1)]
    # SHANGHAI_REGION = [(121.4, 31.1), (121.6, 31.3)]
    SHIMIAN_CORE = [(100.22476304, 29.18340151), (110.32476304, 31.28340151)]
    SHIMIAN_CORE1 = [(102.216344, 29.376723), (102.218344, 29.378723)]
    SHIMIAN_CORE2 = [(102.216344, 29.376723), (102.220344, 29.380723)]
    SHIMIAN_CORE3 = [(102.216344, 29.376723), (102.222344, 29.382723)]
    SHIMIAN_CORE4 = [(102.216344, 29.376723), (102.224344, 29.384723)]
    REGION_COORDS = SHIMIAN_CORE3

    # 3. Output DEM file: a randomly named GeoTIFF next to this script.
    OUTPUT_DEM_PATH = os.path.join(SCRIPT_DIR, f"o_dem_{uuid.uuid4().hex[:8]}.tif")

    # 4. Run the conversion.
    try:
        dem_file = generate_dem(
            tileset_path=TILESET_PATH,
            dem_path=OUTPUT_DEM_PATH,
            region_coords=REGION_COORDS,
            quality='medium',
            enable_enhancement=True,  # densify sparse tiles
            debug=False
        )
        if dem_file:
            print(f"\nDEM生成成功!")
            print(f" 文件位置: {dem_file}")
            print(f" 文件大小: {os.path.getsize(dem_file) / (1024*1024):.2f} MB")
        else:
            print("\nDEM生成失败!")
    except Exception as e:
        print(f"\n处理失败: {e}")
        import traceback
        traceback.print_exc()