From 9a09c1e1cf15dcf11569c1ee18a7a1d61c9e7af7 Mon Sep 17 00:00:00 2001 From: liyubo Date: Thu, 29 Jan 2026 11:51:20 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9D=A1=E5=BA=A6=E5=9D=A1=E5=90=91tif?= =?UTF-8?q?=E7=94=9F=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- b3dm/data_3dtiles_manager.py | 52 +- b3dm/data_3dtiles_to_dem.py | 314 +++++++- b3dm/earthwork_api.py | 104 ++- b3dm/earthwork_calculator_3d_tiles.py | 448 +++++++++-- b3dm/slope_aspect_tif.py | 1062 +++++++++++++++++++++++++ b3dm/terrain_api.py | 2 +- b3dm/terrain_calculator.py | 77 +- b3dm/tileset_data_source.py | 198 ++--- 8 files changed, 1931 insertions(+), 326 deletions(-) create mode 100644 b3dm/slope_aspect_tif.py diff --git a/b3dm/data_3dtiles_manager.py b/b3dm/data_3dtiles_manager.py index 18ae940..eff3c5f 100644 --- a/b3dm/data_3dtiles_manager.py +++ b/b3dm/data_3dtiles_manager.py @@ -145,11 +145,7 @@ class MinIO3DTilesManager: visited = set() # 下载入口文件 - entry_local_path = self.get_local_path( - entry_bucket, entry_path, - entry_bucket, entry_dir, - save_dir - ) + entry_local_path = self.get_local_path(entry_bucket, entry_path, save_dir) success, result = self.download_file(entry_bucket, entry_path, entry_local_path) if not success: @@ -352,39 +348,23 @@ class MinIO3DTilesManager: except Exception as e: return None - def get_local_path(self, bucket_name, object_name, base_bucket, base_object, save_dir): + def get_local_path(self, bucket_name, object_name, save_dir): """生成保持目录结构的本地路径""" clean_bucket = self.clean_filename(bucket_name) - bucket_dir = clean_bucket - if bucket_name == base_bucket and base_object: - base_dir = os.path.dirname(base_object) - - if base_dir: - if object_name.startswith(base_dir): - relative_path = object_name[len(base_dir):].lstrip('/\\') - else: - relative_path = object_name - else: - relative_path = object_name - else: - relative_path = object_name + path_parts = object_name.split('/') + cleaned_parts = [] + for part in path_parts: + cleaned_part = self.clean_filename(part) + if cleaned_part: + cleaned_parts.append(cleaned_part) - if relative_path: - path_parts = relative_path.split('/') - cleaned_parts = [] - for part in path_parts: - cleaned_part = self.clean_filename(part) - if cleaned_part: - cleaned_parts.append(cleaned_part) - - if cleaned_parts: - cleaned_relative = '/'.join(cleaned_parts) - local_path = os.path.join(save_dir, bucket_dir, cleaned_relative) - else: - local_path = os.path.join(save_dir, bucket_dir) + if cleaned_parts: + cleaned_relative = '/'.join(cleaned_parts) + local_path = os.path.join(save_dir, clean_bucket, cleaned_relative) else: - local_path = os.path.join(save_dir, bucket_dir) + local_path = os.path.join(save_dir, clean_bucket) + return os.path.normpath(local_path) @@ -437,11 +417,7 @@ class MinIO3DTilesManager: print(f"下载文件:{file_id}") visited.add(file_id) - local_path = self.get_local_path( - file_bucket, file_path, - base_bucket, base_dir, - save_dir - ) + local_path = self.get_local_path(file_bucket, file_path, save_dir) self.download_file(file_bucket, file_path, local_path) diff --git a/b3dm/data_3dtiles_to_dem.py b/b3dm/data_3dtiles_to_dem.py index 838c3a7..498b83c 100644 --- a/b3dm/data_3dtiles_to_dem.py +++ b/b3dm/data_3dtiles_to_dem.py @@ -778,7 +778,319 @@ def parse_tileset(tileset_path, region_coords=None, enable_enhancement=True, deb return [] # ========== DEM生成函数 ========== -def points_to_dem(points, output_dem_path, pixel_size=None, quality='medium'): +def points_to_dem(points, output_dem_path, pixel_size=None, quality='high', min_resolution=10): + """ + 将点云转换为DEM,确保足够的网格分辨率用于坡度坡向计算 + :param points: 点云数据 [(lon, lat, height), ...] + :param output_dem_path: 输出DEM文件路径 + :param pixel_size: 像素大小(度),如果为None则自动计算 + :param quality: 质量等级 'low'|'medium'|'high'|'terrain'(地形分析专用) + :param min_resolution: 最小网格分辨率(像素数),用于保证坡度计算 + """ + import time + import os + import numpy as np + from osgeo import gdal, osr + from scipy.interpolate import griddata + import warnings + + if len(points) == 0: + raise ValueError("无点云数据,无法生成DEM") + + start_time = time.time() + + # 转换为numpy数组 + points_array = np.array(points) + lons = points_array[:, 0] + lats = points_array[:, 1] + heights = points_array[:, 2] + + # 计算范围 + min_lon, max_lon = lons.min(), lons.max() + min_lat, max_lat = lats.min(), lats.max() + lon_range = max_lon - min_lon + lat_range = max_lat - min_lat + + print(f"[DEM生成] 点云范围:") + print(f" 经度: {min_lon:.6f}° ~ {max_lon:.6f}° (范围: {lon_range:.6f}°)") + print(f" 纬度: {min_lat:.6f}° ~ {max_lat:.6f}° (范围: {lat_range:.6f}°)") + print(f" 高程: {heights.min():.2f}m ~ {heights.max():.2f}m") + print(f" 点数: {len(points):,}") + + # 自动确定像素大小(优化版本) + if pixel_size is None: + # 根据数据量自动确定 + if quality == 'terrain': # 地形分析专用 + # 确保足够的网格分辨率用于坡度计算 + target_pixels = max(min_resolution, int(np.sqrt(len(points)) / 2)) + + # 根据数据范围计算像素大小 + lon_pixel = lon_range / target_pixels + lat_pixel = lat_range / target_pixels + + # 取较小的像素大小以保证分辨率 + pixel_size = min(lon_pixel, lat_pixel) + + # 设置像素大小范围限制 + pixel_size = max(pixel_size, 0.000001) # 最小约0.1米 + pixel_size = min(pixel_size, 0.0005) # 最大约55米 + + elif quality == 'high': + pixel_size = max(0.00001, lon_range / 100) # 至少1.1米,最多100像素 + elif quality == 'medium': + pixel_size = max(0.00002, lon_range / 50) # 至少2.2米,最多50像素 + else: # low + pixel_size = max(0.00005, lon_range / 20) # 至少5.6米,最多20像素 + + # 计算网格尺寸 + width = max(10, int(lon_range / pixel_size) + 1) + height = max(10, int(lat_range / pixel_size) + 1) + + # 确保网格大小符合最小分辨率要求 + if width * height < min_resolution * min_resolution: + print(f"[DEM生成] 警告:网格分辨率不足 ({width}x{height}),自动调整...") + # 重新计算像素大小以满足最小分辨率 + target_cells = min_resolution * min_resolution + target_pixel_size = np.sqrt(lon_range * lat_range / target_cells) + pixel_size = max(0.000001, min(pixel_size, target_pixel_size)) + + width = max(min_resolution, int(lon_range / pixel_size) + 1) + height = max(min_resolution, int(lat_range / pixel_size) + 1) + + # 扩展数据范围以增加边缘像素(有助于坡度计算) + expand_factor = 1.1 # 扩展10% + min_lon_exp = min_lon - lon_range * (expand_factor - 1) / 2 + max_lon_exp = max_lon + lon_range * (expand_factor - 1) / 2 + min_lat_exp = min_lat - lat_range * (expand_factor - 1) / 2 + max_lat_exp = max_lat + lat_range * (expand_factor - 1) / 2 + + print(f"[DEM生成] 网格设置:") + print(f" 像素大小: {pixel_size:.6f}° (~{pixel_size*111320:.1f}米)") + print(f" 网格尺寸: {width} × {height}") + print(f" 总像素数: {width * height:,}") + print(f" 点云密度: {len(points)/(width*height):.2f} 点/像素") + + # 创建网格 + x_grid = np.linspace(min_lon_exp, max_lon_exp, width) + y_grid = np.linspace(max_lat_exp, min_lat_exp, height) # 纬度从上到下 + xi, yi = np.meshgrid(x_grid, y_grid) + + # 插值 - 优化版本 + print("[DEM生成] 开始插值计算...") + + # 检查数据分布 + grid_points = np.column_stack([xi.flatten(), yi.flatten()]) + data_points = np.column_stack([lons, lats]) + + # 使用KD树加速最近邻查询 + try: + from scipy.spatial import cKDTree + tree = cKDTree(data_points) + dists, _ = tree.query(grid_points, k=1) + max_dist = np.max(dists) + print(f"[DEM生成] 最大最近邻距离: {max_dist*111320:.1f}米") + except: + pass + + # 尝试IDW插值(适用于地形) + try: + if len(points) > 100: + print("[DEM生成] 使用IDW插值...") + zi = idw_interpolation(lons, lats, heights, xi, yi, power=2) + else: + # 数据太少,使用线性+最近邻 + zi = griddata((lons, lats), heights, (xi, yi), method='linear', fill_value=np.nan) + except Exception as e: + print(f"[DEM生成] IDW插值失败: {e},使用线性插值") + zi = griddata((lons, lats), heights, (xi, yi), method='linear', fill_value=np.nan) + + # 处理空白区域 + nan_mask = np.isnan(zi) + if np.any(nan_mask): + nan_count = np.sum(nan_mask) + nan_percent = nan_count / (width * height) * 100 + print(f"[DEM生成] 插值空白: {nan_count:,} 像素 ({nan_percent:.1f}%)") + + if nan_percent < 50: # 空白区域少于50% + # 使用最近邻填充空白 + zi_nn = griddata((lons, lats), heights, (xi, yi), method='nearest') + zi[nan_mask] = zi_nn[nan_mask] + else: + print("[DEM生成] 警告:空白区域过多,使用最近邻插值") + zi = griddata((lons, lats), heights, (xi, yi), method='nearest') + + # 平滑处理(可选,有助于坡度计算) + if quality in ['terrain', 'high']: + try: + from scipy.ndimage import gaussian_filter + sigma = 0.5 # 高斯滤波参数 + zi = gaussian_filter(zi, sigma=sigma, mode='nearest') + print(f"[DEM生成] 应用高斯平滑 (sigma={sigma})") + except: + pass + + # 创建GeoTIFF + print("[DEM生成] 创建GeoTIFF文件...") + driver = gdal.GetDriverByName("GTiff") + + # 压缩选项 + if quality in ['terrain', 'high']: + options = ["COMPRESS=DEFLATE", "PREDICTOR=3", "ZLEVEL=6", "TILED=YES", "BLOCKXSIZE=256", "BLOCKYSIZE=256"] + else: + options = ["COMPRESS=LZW", "TILED=YES"] + + dem_ds = driver.Create(output_dem_path, width, height, 1, gdal.GDT_Float32, options) + + if dem_ds is None: + raise RuntimeError(f"无法创建DEM文件: {output_dem_path}") + + # 设置投影和地理变换 + srs = osr.SpatialReference() + srs.ImportFromEPSG(4326) # WGS84 + dem_ds.SetProjection(srs.ExportToWkt()) + + # 注意:使用扩展后的范围 + geotransform = [ + min_lon_exp, pixel_size, 0, + max_lat_exp, 0, -pixel_size + ] + dem_ds.SetGeoTransform(geotransform) + + # 写入数据 + band = dem_ds.GetRasterBand(1) + band.WriteArray(zi) + band.SetNoDataValue(-9999.0) + band.SetDescription("Elevation (meters)") + band.SetUnitType("meters") + + # 计算统计信息(处理可能的统计问题) + print("[DEM生成] 计算统计信息...") + band.FlushCache() + + try: + band.ComputeStatistics(False) + print("[DEM生成] 统计信息计算完成") + except RuntimeError as e: + print(f"[DEM生成] 警告:统计信息计算失败: {e}") + # 手动计算统计信息 + valid_mask = zi != -9999.0 + if np.any(valid_mask): + valid_data = zi[valid_mask] + stats = [ + float(np.min(valid_data)), + float(np.max(valid_data)), + float(np.mean(valid_data)), + float(np.std(valid_data)) + ] + band.SetStatistics(*stats) + print(f"[DEM生成] 手动设置统计信息:") + print(f" 最小值: {stats[0]:.2f}m") + print(f" 最大值: {stats[1]:.2f}m") + print(f" 平均值: {stats[2]:.2f}m") + print(f" 标准差: {stats[3]:.2f}m") + + # 构建金字塔 + if quality in ['terrain', 'high']: + print("[DEM生成] 构建金字塔...") + gdal.SetConfigOption('COMPRESS_OVERVIEW', 'DEFLATE') + dem_ds.BuildOverviews("AVERAGE", [2, 4, 8, 16]) + + dem_ds = None # 关闭文件 + + # 验证结果 + elapsed_time = time.time() - start_time + file_size_mb = os.path.getsize(output_dem_path) / (1024 * 1024) if os.path.exists(output_dem_path) else 0 + + # 重新打开验证 + try: + ds = gdal.Open(output_dem_path, gdal.GA_ReadOnly) + if ds: + band = ds.GetRasterBand(1) + actual_width = ds.RasterXSize + actual_height = ds.RasterYSize + + print(f"[DEM生成] 验证结果:") + print(f" 实际尺寸: {actual_width} × {actual_height}") + print(f" 文件大小: {file_size_mb:.2f} MB") + print(f" 处理时间: {elapsed_time:.1f}秒") + + # 检查是否适合坡度计算 + if actual_width >= 10 and actual_height >= 10: + print(f"[DEM生成] ✓ DEM分辨率适合坡度坡向计算") + else: + print(f"[DEM生成] ⚠ DEM分辨率较低,坡度计算可能不准确") + + ds = None + except: + print(f"[DEM生成] 完成! 文件: {output_dem_path}") + + return output_dem_path + + +def idw_interpolation(x, y, z, xi, yi, power=2, radius=None): + """ + 反距离权重插值 + """ + import numpy as np + from scipy.spatial import cKDTree + + x = np.asarray(x) + y = np.asarray(y) + z = np.asarray(z) + xi = np.asarray(xi) + yi = np.asarray(yi) + + # 展平网格 + xi_flat = xi.flatten() + yi_flat = yi.flatten() + + # 创建KD树 + tree = cKDTree(np.column_stack([x, y])) + + # 设置搜索半径 + if radius is None: + # 自动确定半径:平均点间距的3倍 + from scipy.spatial.distance import pdist + if len(x) > 1: + distances = pdist(np.column_stack([x, y])) + radius = np.mean(distances) * 3 + else: + radius = 0.01 # 默认值 + + # 查询最近邻点 + if len(z) > 50: # 数据较多时,限制搜索点数 + k = min(12, len(z)) + dists, idxs = tree.query(np.column_stack([xi_flat, yi_flat]), k=k, distance_upper_bound=radius) + else: + dists, idxs = tree.query(np.column_stack([xi_flat, yi_flat]), k=len(z)) + + # IDW插值 + zi_flat = np.zeros(len(xi_flat)) + + for i in range(len(xi_flat)): + valid_mask = idxs[i] < len(z) + if np.any(valid_mask): + valid_dists = dists[i][valid_mask] + valid_idx = idxs[i][valid_mask] + + # 避免除零 + valid_dists = np.maximum(valid_dists, 1e-10) + + # 计算权重 + weights = 1.0 / (valid_dists ** power) + weights = weights / np.sum(weights) + + # 加权平均 + zi_flat[i] = np.sum(z[valid_idx] * weights) + else: + zi_flat[i] = np.nan + + # 恢复形状 + zi = zi_flat.reshape(xi.shape) + + return zi + +def points_to_dem1(points, output_dem_path, pixel_size=None, quality='medium'): """ 将点云转换为DEM :param quality: 质量等级 'low'|'medium'|'high' diff --git a/b3dm/earthwork_api.py b/b3dm/earthwork_api.py index 5345eb7..2c52b60 100644 --- a/b3dm/earthwork_api.py +++ b/b3dm/earthwork_api.py @@ -14,33 +14,34 @@ earthwork_bp = Blueprint("earthwork", url_prefix="") logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# 全局变量 -_calculator_point_cloud = None -_calculator_3d_tiles = None -_data_source_3d_tiles = None - # 初始化函数 -async def init_app(): +def init_app(url, type = "3dtiles"): """初始化应用""" - global _data_source_3d_tiles, _calculator_3d_tiles, _calculator_point_cloud + data_source = None + calculator_3d_tiles = None + calculator_point_cloud = None try: - # 配置数据源 - tileset_path = "./data/3dtiles/tileset.json" - # 初始化数据源 - _data_source_3d_tiles = TilesetDataSource(tileset_path) - await _data_source_3d_tiles.initialize() + data_source = TilesetDataSource(url) + data_source.dowload_map_data(url) - # 初始化计算器-3dTiles - _calculator_3d_tiles = EarthworkCalculator3dTiles(_data_source_3d_tiles) - - # 初始化计算器-点云 - point_cloud_path = "./data/pointCloud/simulated_points.laz" - _calculator_point_cloud = EarthworkCalculatorPointCloud(point_cloud_path) + if type == "3dtiles" : + # 初始化计算器-3dTiles + calculator_3d_tiles = EarthworkCalculator3dTiles(data_source) + elif type == "pointcloud" : + # 初始化计算器-点云 + calculator_point_cloud = EarthworkCalculatorPointCloud(data_source.tileset_path) + else : + logger.info(f"不支持的3d地图数据格式:{type}") + raise logger.info("土方量计算器初始化完成") - + return { + "data_source":data_source, + "calculator_3d_tiles":calculator_3d_tiles, + "calculator_point_cloud":calculator_point_cloud + } except ImportError as e: logger.error(f"依赖库缺失: {str(e)}") raise @@ -56,12 +57,26 @@ async def calc_earthwork(request: Request): 请求参数示例: { - "polygonCoords": [[120.1, 30.1], [120.2, 30.1], [120.2, 30.2], [120.1, 30.2]], - "designElevation": 50.0, + "polygonCoords": [ + [ + 115.70440468338526, + 30.77363140345639 + ], + [ + 115.70443054007985, + 30.773510462589584 + ], + [ + 115.70459702429197, + 30.77360789911405 + ] + ], + "designElevation": 100, "algorithm": "tin", - "resolution": 1.0, + "resolution": 1, "crs": "EPSG:4326", - "interpolationMethod": "linear" + "interpolationMethod": "linear", + "url": "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/hbgldk/yzk/20260113/3D/terra_b3dms/tileset.json" } """ try: @@ -73,11 +88,14 @@ async def calc_earthwork(request: Request): # 2. 提取参数 polygon_coords = data.get("polygonCoords") design_elevation = data.get("designElevation") + url = data.get("url") if not polygon_coords: return _error_response("多边形坐标不能为空", 400) if design_elevation is None: return _error_response("设计高程不能为空", 400) + if url is None: + return _error_response("地图不能为空", 400) # 3. 可选参数 algorithm = data.get("algorithm", "tin") @@ -102,13 +120,13 @@ async def calc_earthwork(request: Request): return _error_response("分辨率必须在0-100米之间", 400) # 5. 确保计算器已初始化 - if _calculator_3d_tiles is None: - await init_app() + app_info = init_app(url) + calculator_3d_tiles = app_info.get("calculator_3d_tiles") # 6. 执行计算 algorithm_type = AlgorithmType(algorithm) - result = await _calculator_3d_tiles.calculate( + result = await calculator_3d_tiles.calculate( polygon_coords=polygon_coords, design_elevation=design_elevation, algorithm=algorithm_type, @@ -144,6 +162,10 @@ async def validate_earthwork(request: Request): if not polygon_coords: return _error_response("多边形坐标不能为空", 400) + url = data.get("url") + if url is None: + return _error_response("地图不能为空", 400) + # 3. 参数验证 if len(polygon_coords) < 3: return _error_response("多边形至少需要3个点", 400) @@ -153,11 +175,11 @@ async def validate_earthwork(request: Request): polygon_coords.append(polygon_coords[0]) # 4. 确保计算器已初始化 - if _calculator_3d_tiles is None: - await init_app() + app_info = init_app(url) + calculator_3d_tiles = app_info.get("calculator_3d_tiles") # 5. 执行验证 - validation_result = await _calculator_3d_tiles.validate(polygon_coords) + validation_result = await calculator_3d_tiles.validate(polygon_coords) # 6. 返回结果 return _success_response(validation_result) @@ -242,9 +264,7 @@ async def batch_calc_earthwork(request: Request): if len(calculations) > 100: return _error_response("批量计算数量超过限制(最多100个)", 400) - # 2. 确保计算器已初始化 - if _calculator_3d_tiles is None: - await init_app() + # 3. 执行批量计算 results = [] @@ -255,8 +275,9 @@ async def batch_calc_earthwork(request: Request): # 提取参数 polygon_coords = calc_data.get("polygonCoords") design_elevation = calc_data.get("designElevation") + url = calc_data.get("url") - if not polygon_coords or design_elevation is None: + if not polygon_coords or design_elevation is None or url is None: errors.append({ "index": i, "error": "缺少必要参数" @@ -280,11 +301,15 @@ async def batch_calc_earthwork(request: Request): resolution = calc_data.get("resolution", 1.0) crs = calc_data.get("crs", "EPSG:4326") interpolation_method = calc_data.get("interpolationMethod", "linear") + + # 2. 确保计算器已初始化 + app_info = init_app(url) + calculator_3d_tiles = app_info.get("calculator_3d_tiles") # 执行计算 algorithm_type = AlgorithmType(algorithm) - result = await _calculator_3d_tiles.calculate( + result = await calculator_3d_tiles.calculate( polygon_coords=polygon_coords, design_elevation=design_elevation, algorithm=algorithm_type, @@ -338,19 +363,22 @@ async def calc_earthwork_point_cloud(request: Request): polygon_coords = data.get("polygonCoords") # 计算区域多边形坐标 design_elev = data.get("designElevation") # 设计高程 crs = data.get("crs", "EPSG:4326") # 坐标系,默认WGS84 + url = data.get("url") + if url is None: + return _error_response("地图不能为空", 400) # 2. 确保计算器已初始化 - if _calculator_point_cloud is None: - await init_app() + app_info = init_app(url) + calculator_point_cloud = app_info.get("calculator_point_cloud") - result = _calculator_point_cloud.calculate_earthwork(polygon_coords=polygon_coords, design_elev=design_elev, crs=crs) + result = calculator_point_cloud.calculate_earthwork(polygon_coords=polygon_coords, design_elev=design_elev, crs=crs) # 3. 处理结果 if not result["success"]: return _error_response(result["error"], 400) # 4. 格式化结果 - formatted_result = _calculator_point_cloud.format_result(result) + formatted_result = calculator_point_cloud.format_result(result) # 5. 返回成功响应 return _success_response(formatted_result) diff --git a/b3dm/earthwork_calculator_3d_tiles.py b/b3dm/earthwork_calculator_3d_tiles.py index 090e0d7..96e1f9d 100644 --- a/b3dm/earthwork_calculator_3d_tiles.py +++ b/b3dm/earthwork_calculator_3d_tiles.py @@ -8,6 +8,7 @@ import logging from enum import Enum from abc import ABC, abstractmethod import math +from pyproj import Geod logger = logging.getLogger(__name__) @@ -37,13 +38,13 @@ class EarthworkResult3dTiles: """转换为字典""" return { "volume": { - "cut": round(self.cut_volume, 3), - "fill": round(self.fill_volume, 3), - "net": round(self.net_volume, 3), + "cut": round(self.cut_volume, 8), + "fill": round(self.fill_volume, 8), + "net": round(self.net_volume, 8), "unit": "m³" }, "area": { - "value": round(self.area, 3), + "value": round(self.area, 8), "unit": "m²" }, "elevation": { @@ -94,67 +95,375 @@ class TerrainDataSource(ABC): pass class GeometryUtils: - """几何计算工具类""" + """地理空间几何计算工具类(支持经纬度坐标)""" - @staticmethod - def calculate_polygon_area(polygon_coords: List[List[float]]) -> float: - """计算多边形面积(平面面积)""" - polygon_np = np.array(polygon_coords) - x = polygon_np[:, 0] - y = polygon_np[:, 1] - return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))) + def __init__(self, source_crs: str = "EPSG:4326", target_crs: str = "EPSG:3857"): + """ + 初始化 + Args: + source_crs: 源坐标系(通常是EPSG:4326) + target_crs: 目标投影坐标系(用于平面计算) + """ + self.source_crs = source_crs + self.target_crs = target_crs + self.geod = Geod(ellps="WGS84") + + # 创建坐标转换器 + self.transformer_to_proj = Transformer.from_crs( + source_crs, target_crs, always_xy=True + ) + self.transformer_to_geo = Transformer.from_crs( + target_crs, source_crs, always_xy=True + ) - @staticmethod - def is_point_in_polygon(point: np.ndarray, polygon: np.ndarray) -> bool: - """判断点是否在多边形内""" - from matplotlib.path import Path - path = Path(polygon) - return path.contains_point(point) + def calculate_polygon_area(self, polygon_coords: List[List[float]]) -> float: + """ + 计算多边形的地面实际面积(平方米) + + Args: + polygon_coords: 经纬度坐标列表 [[lon1, lat1], ...] + + Returns: + 面积(平方米) + """ + if len(polygon_coords) < 3: + return 0.0 + + # 确保多边形闭合 + closed_coords = self._ensure_closed_polygon(polygon_coords) + + # 提取经纬度 + lons = [coord[0] for coord in closed_coords] + lats = [coord[1] for coord in closed_coords] + + # 使用测地线计算面积 + area, _ = self.geod.polygon_area_perimeter(lons, lats) + return abs(area) - @staticmethod - def calculate_triangle_area(points: np.ndarray) -> float: - """计算三角形面积""" - a = np.linalg.norm(points[0] - points[1]) - b = np.linalg.norm(points[1] - points[2]) - c = np.linalg.norm(points[2] - points[0]) + def is_point_in_polygon(self, point: Tuple[float, float], + polygon_coords: List[List[float]], + use_spherical: bool = True) -> bool: + """ + 判断点是否在多边形内(支持地球表面判断) + + Args: + point: 点坐标 (lon, lat) + polygon_coords: 多边形顶点坐标 + use_spherical: 是否使用球面算法 + + Returns: + 是否在多边形内 + """ + if len(polygon_coords) < 3: + return False + + if use_spherical: + # 方法1:球面射线法(更准确) + return self._is_point_in_polygon_spherical(point, polygon_coords) + else: + # 方法2:投影到平面后判断(更快) + return self._is_point_in_polygon_planar(point, polygon_coords) + + def calculate_triangle_area(self, points: np.ndarray) -> float: + """ + 计算三角形的地面面积(平方米) + + Args: + points: 3×2数组,每行是[lon, lat] + + Returns: + 三角形地面面积(平方米) + """ + if points.shape != (3, 2): + raise ValueError("需要3个点的坐标") + + # 转换为球面坐标计算 + lons = points[:, 0] + lats = points[:, 1] + + # 使用球面三角形面积公式 + R = 6378137.0 # WGS84地球半径(米) + + # 转换为弧度 + lon_rad = np.radians(lons) + lat_rad = np.radians(lats) + + # 计算球面三角形的面积 + # 使用L'Huilier公式 + a = self._spherical_distance(lon_rad[0], lat_rad[0], lon_rad[1], lat_rad[1]) + b = self._spherical_distance(lon_rad[1], lat_rad[1], lon_rad[2], lat_rad[2]) + c = self._spherical_distance(lon_rad[2], lat_rad[2], lon_rad[0], lat_rad[0]) + s = (a + b + c) / 2 - return np.sqrt(s * (s - a) * (s - b) * (s - c)) + + # 防止数值误差 + tan_e2 = np.tan(s/2) * np.tan((s-a)/2) * np.tan((s-b)/2) * np.tan((s-c)/2) + tan_e2 = max(tan_e2, 0) # 避免负值 + + if tan_e2 > 0: + E = 4 * np.arctan(np.sqrt(tan_e2)) + else: + E = 0 + + area = R * R * E + return area - @staticmethod - def create_grid(polygon: np.ndarray, resolution: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """创建规则格网""" - x_min, y_min = polygon.min(axis=0) - x_max, y_max = polygon.max(axis=0) + def create_grid(self, polygon_coords: List[List[float]], + resolution_m: float, + use_projection: bool = True) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + 创建规则格网(地面距离为单位的网格) - # 扩展一个格网单元 - x_min -= resolution - x_max += resolution - y_min -= resolution - y_max += resolution - - x_grid = np.arange(x_min, x_max + resolution, resolution) - y_grid = np.arange(y_min, y_max + resolution, resolution) - xx, yy = np.meshgrid(x_grid, y_grid) - - return xx, yy, x_grid, y_grid + Args: + polygon_coords: 多边形坐标 + resolution_m: 网格分辨率(米) + use_projection: 是否使用投影坐标系 + + Returns: + xx, yy: 网格坐标 + grid_coords_geo: 网格点的地理坐标 + """ + if use_projection: + # 方法1:投影到平面坐标系创建网格 + return self._create_grid_projected(polygon_coords, resolution_m) + else: + # 方法2:直接在经纬度上创建近似网格(小区域可用) + return self._create_grid_geographic(polygon_coords, resolution_m) - @staticmethod - def interpolate_grid(xx: np.ndarray, yy: np.ndarray, - points: np.ndarray, method: str = 'linear') -> np.ndarray: - """格网插值""" - from scipy.interpolate import LinearNDInterpolator, CloughTocher2DInterpolator + def interpolate_grid(self, xx: np.ndarray, yy: np.ndarray, + points: np.ndarray, + method: str = 'linear', + return_geo: bool = False) -> np.ndarray: + """ + 格网插值 - grid_points = np.column_stack([xx.ravel(), yy.ravel()]) + Args: + xx, yy: 网格坐标(投影坐标系) + points: 已知点,每行是[lon, lat, elevation]或[x_proj, y_proj, elevation] + method: 插值方法 'linear' 或 'cubic' + return_geo: 是否返回地理坐标 + + Returns: + 插值后的高程网格 + """ + # 确保points是投影坐标 + if points.shape[1] != 3: + raise ValueError("points应为3列: x, y, z") + # 如果输入是地理坐标,转换为投影坐标 + if np.max(np.abs(points[:, 0])) > 180: # 粗略判断 + # 已经是投影坐标 + points_proj = points + else: + # 转换为投影坐标 + x_proj, y_proj = self.transformer_to_proj.transform( + points[:, 0], points[:, 1] + ) + points_proj = np.column_stack([x_proj, y_proj, points[:, 2]]) + + # 创建插值器 if method == 'linear': - interpolator = LinearNDInterpolator(points[:, :2], points[:, 2]) + interpolator = LinearNDInterpolator( + points_proj[:, :2], + points_proj[:, 2], + fill_value=np.nan + ) elif method == 'cubic': - interpolator = CloughTocher2DInterpolator(points[:, :2], points[:, 2]) + interpolator = CloughTocher2DInterpolator( + points_proj[:, :2], + points_proj[:, 2], + fill_value=np.nan + ) else: raise ValueError(f"不支持的插值方法: {method}") + # 插值 + grid_points = np.column_stack([xx.ravel(), yy.ravel()]) elevations = interpolator(grid_points) - return elevations.reshape(xx.shape) + result = elevations.reshape(xx.shape) + + if return_geo: + # 如果需要,将网格点转回地理坐标 + lon_grid, lat_grid = self.transformer_to_geo.transform( + xx.ravel(), yy.ravel() + ) + lon_grid = lon_grid.reshape(xx.shape) + lat_grid = lat_grid.reshape(xx.shape) + return result, lon_grid, lat_grid + + return result + + # ============ 私有方法 ============ + + def _ensure_closed_polygon(self, coords: List[List[float]]) -> List[List[float]]: + """确保多边形闭合""" + if len(coords) >= 3: + # 使用 numpy 比较 + if not np.array_equal(coords[0], coords[-1]): + return coords + [coords[0]] + return coords + + def _spherical_distance(self, lon1_rad: float, lat1_rad: float, + lon2_rad: float, lat2_rad: float) -> float: + """计算球面两点间角距离""" + dlon = lon2_rad - lon1_rad + dlat = lat2_rad - lat1_rad + + a = np.sin(dlat/2)**2 + np.cos(lat1_rad) * np.cos(lat2_rad) * np.sin(dlon/2)**2 + return 2 * np.arcsin(np.sqrt(a)) + + def _is_point_in_polygon_spherical(self, point: Tuple[float, float], + polygon_coords: List[List[float]]) -> bool: + """球面射线法判断点是否在多边形内""" + lon_p, lat_p = point + closed_polygon = self._ensure_closed_polygon(polygon_coords) + + # 将多边形的边转换为球面大圆弧 + crossings = 0 + n = len(closed_polygon) - 1 + + for i in range(n): + lon1, lat1 = closed_polygon[i] + lon2, lat2 = closed_polygon[i + 1] + + # 检查射线是否与边相交(近似算法) + # 简化:使用平面近似,对小区域足够准确 + if ((lat1 > lat_p) != (lat2 > lat_p)) and \ + (lon_p < (lon2 - lon1) * (lat_p - lat1) / (lat2 - lat1) + lon1): + crossings += 1 + + return crossings % 2 == 1 + + def _is_point_in_polygon_planar(self, point: Tuple[float, float], + polygon_coords: List[List[float]]) -> bool: + """投影到平面后判断""" + # 转换为投影坐标 + point_proj = np.array(self.transformer_to_proj.transform(point[0], point[1])).reshape(1, 2) + polygon_proj = np.array([ + self.transformer_to_proj.transform(lon, lat) + for lon, lat in polygon_coords + ]) + + # 使用平面方法判断 + path = Path(polygon_proj) + return path.contains_point(point_proj[0]) + + def _create_grid_projected(self, polygon_coords: List[List[float]], + resolution_m: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """在投影坐标系中创建网格""" + # 将多边形转换为投影坐标 + polygon_proj = [] + for lon, lat in polygon_coords: + x, y = self.transformer_to_proj.transform(lon, lat) + polygon_proj.append([x, y]) + + polygon_proj = np.array(polygon_proj) + + # 计算边界框 + x_min, y_min = polygon_proj.min(axis=0) + x_max, y_max = polygon_proj.max(axis=0) + + # 扩展半个网格 + x_min -= resolution_m / 2 + x_max += resolution_m / 2 + y_min -= resolution_m / 2 + y_max += resolution_m / 2 + + # 创建网格 + x_grid = np.arange(x_min, x_max + resolution_m, resolution_m) + y_grid = np.arange(y_min, y_max + resolution_m, resolution_m) + xx, yy = np.meshgrid(x_grid, y_grid) + + # 将网格点转回地理坐标 + grid_coords_geo = [] + for x, y in zip(xx.ravel(), yy.ravel()): + lon, lat = self.transformer_to_geo.transform(x, y) + grid_coords_geo.append([lon, lat]) + grid_coords_geo = np.array(grid_coords_geo).reshape(xx.shape[0], xx.shape[1], 2) + + return xx, yy, grid_coords_geo + + def _create_grid_geographic(self, polygon_coords: List[List[float]], + resolution_m: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """在经纬度坐标系中创建近似网格""" + # 计算中心点 + lons = [coord[0] for coord in polygon_coords] + lats = [coord[1] for coord in polygon_coords] + center_lon = np.mean(lons) + center_lat = np.mean(lats) + + # 计算经纬度到米的换算系数 + lat_rad = np.radians(center_lat) + meters_per_degree_lon = 111319.9 * np.cos(lat_rad) + meters_per_degree_lat = 111000.0 + + # 计算边界框(米) + x_min_m, x_max_m, y_min_m, y_max_m = 1e9, -1e9, 1e9, -1e9 + for lon, lat in polygon_coords: + x_m = (lon - center_lon) * meters_per_degree_lon + y_m = (lat - center_lat) * meters_per_degree_lat + x_min_m = min(x_min_m, x_m) + x_max_m = max(x_max_m, x_m) + y_min_m = min(y_min_m, y_m) + y_max_m = max(y_max_m, y_m) + + # 扩展半个网格 + x_min_m -= resolution_m / 2 + x_max_m += resolution_m / 2 + y_min_m -= resolution_m / 2 + y_max_m += resolution_m / 2 + + # 创建网格(米) + x_grid_m = np.arange(x_min_m, x_max_m + resolution_m, resolution_m) + y_grid_m = np.arange(y_min_m, y_max_m + resolution_m, resolution_m) + + # 转换为经纬度 + x_grid_lon = center_lon + x_grid_m / meters_per_degree_lon + y_grid_lat = center_lat + y_grid_m / meters_per_degree_lat + + xx, yy = np.meshgrid(x_grid_lon, y_grid_lat) + + # 网格坐标(经纬度) + grid_coords_geo = np.dstack([xx, yy]) + + return xx, yy, grid_coords_geo + + def get_polygon_bounds(self, polygon_coords: List[List[float]], + in_meters: bool = False) -> dict: + """ + 获取多边形边界信息 + + Args: + polygon_coords: 多边形坐标 + in_meters: 是否返回米为单位 + + Returns: + 边界信息字典 + """ + lons = [coord[0] for coord in polygon_coords] + lats = [coord[1] for coord in polygon_coords] + + bounds = { + 'min_lon': min(lons), + 'max_lon': max(lons), + 'min_lat': min(lats), + 'max_lat': max(lats), + 'center_lon': (min(lons) + max(lons)) / 2, + 'center_lat': (min(lats) + max(lats)) / 2 + } + + if in_meters: + # 计算实际尺寸(米) + center_lat = bounds['center_lat'] + lat_rad = np.radians(center_lat) + meters_per_degree_lon = 111319.9 * np.cos(lat_rad) + meters_per_degree_lat = 111000.0 + + bounds['width_m'] = (bounds['max_lon'] - bounds['min_lon']) * meters_per_degree_lon + bounds['height_m'] = (bounds['max_lat'] - bounds['min_lat']) * meters_per_degree_lat + bounds['area_m2'] = self.calculate_polygon_area(polygon_coords) + + return bounds class EarthworkCalculator3dTiles: """土方量计算器""" @@ -167,6 +476,7 @@ class EarthworkCalculator3dTiles: data_source: 地形数据源 """ self.data_source = data_source + self.geometryUtils = GeometryUtils() self._transformer_cache = {} async def calculate(self, @@ -183,7 +493,7 @@ class EarthworkCalculator3dTiles: polygon_coords: 多边形坐标 design_elevation: 设计高程 algorithm: 计算算法 - resolution: 格网分辨率(米) + resolution: 格网分辨率(米) target_crs: 目标坐标系 interpolation_method: 插值方法 @@ -195,7 +505,7 @@ class EarthworkCalculator3dTiles: points = await self.data_source.get_points_in_polygon(polygon_coords) if points.size == 0: - raise ValueError("区域内没有找到高程数据") + raise ValueError("区域内没有找到顶点数据") # 2. 坐标转换 points = await self._transform_coordinates(points, target_crs) @@ -250,10 +560,10 @@ class EarthworkCalculator3dTiles: polygon_np = np.array(polygon_coords) # 创建格网 - xx, yy, x_grid, y_grid = GeometryUtils.create_grid(polygon_np, resolution) + xx, yy, x_grid, y_grid = self.geometryUtils.create_grid(polygon_np, resolution) # 插值 - natural_elevations = GeometryUtils.interpolate_grid(xx, yy, points, interpolation_method) + natural_elevations = self.geometryUtils.interpolate_grid(xx, yy, points, interpolation_method) # 初始化挖填量 cut_volume = 0.0 @@ -273,7 +583,7 @@ class EarthworkCalculator3dTiles: # 检查格网中心点是否在多边形内 cell_center = cell_corners.mean(axis=0) - if not GeometryUtils.is_point_in_polygon(cell_center, polygon_np): + if not self.geometryUtils.is_point_in_polygon(cell_center, polygon_np): continue # 获取格网四个角点的高程 @@ -301,7 +611,7 @@ class EarthworkCalculator3dTiles: cut_volume += abs(height_diff) * cell_area # 计算统计信息 - area = GeometryUtils.calculate_polygon_area(polygon_coords) + area = self.geometryUtils.calculate_polygon_area(polygon_coords) mask = ~np.isnan(natural_elevations) valid_elevations = natural_elevations[mask] @@ -342,14 +652,16 @@ class EarthworkCalculator3dTiles: triangle_center = triangle_points.mean(axis=0)[:2] # 检查三角形中心是否在多边形内 - if not GeometryUtils.is_point_in_polygon(triangle_center, polygon_np): + if not self.geometryUtils.is_point_in_polygon(triangle_center, polygon_np): continue # 计算三角形面积 - area = GeometryUtils.calculate_triangle_area(triangle_points[:, :2]) + area = self.geometryUtils.calculate_triangle_area(triangle_points[:, :2]) + if math.isnan(area) : + continue total_area += area - # 计算平均高程(使用三个顶点的高程) + # 计算平均高程(使用三个顶点的高程) avg_elevation = triangle_points[:, 2].mean() # 计算挖填量 @@ -358,9 +670,13 @@ class EarthworkCalculator3dTiles: fill_volume += height_diff * area else: cut_volume += abs(height_diff) * area - + # if math.isnan(cut_volume) : + # print("cut_volume变为nan") + # if math.isnan(fill_volume) : + # print("fill_volume变为nan") + # 计算统计信息 - area = GeometryUtils.calculate_polygon_area(polygon_coords) + area = self.geometryUtils.calculate_polygon_area(polygon_coords) return EarthworkResult3dTiles( cut_volume=cut_volume, @@ -397,11 +713,11 @@ class EarthworkCalculator3dTiles: triangle_points = points[simplex] triangle_center = triangle_points.mean(axis=0)[:2] - if not GeometryUtils.is_point_in_polygon(triangle_center, polygon_np): + if not self.geometryUtils.is_point_in_polygon(triangle_center, polygon_np): continue # 计算三角形面积 - area = GeometryUtils.calculate_triangle_area(triangle_points[:, :2]) + area = self.geometryUtils.calculate_triangle_area(triangle_points[:, :2]) total_area += area # 对于每个三角形,计算三棱柱体积 @@ -418,7 +734,7 @@ class EarthworkCalculator3dTiles: # 计算边的平均挖填高度 avg_height = (abs(height_i) + abs(height_j)) / 2 - # 计算边的面积(假设边宽度为resolution) + # 计算边的面积(假设边宽度为resolution) edge_area = edge_length * resolution if height_i > 0 or height_j > 0: @@ -426,7 +742,7 @@ class EarthworkCalculator3dTiles: else: cut_volume += avg_height * edge_area - area = GeometryUtils.calculate_polygon_area(polygon_coords) + area = self.geometryUtils.calculate_polygon_area(polygon_coords) return EarthworkResult3dTiles( cut_volume=cut_volume, @@ -477,7 +793,7 @@ class EarthworkCalculator3dTiles: validation_result = { "polygon_valid": len(polygon_coords) >= 3, - "area": GeometryUtils.calculate_polygon_area(polygon_coords), + "area": self.geometryUtils.calculate_polygon_area(polygon_coords), "points_available": points.size > 0, "points_count": points.shape[0] if points.size > 0 else 0, "data_quality": "good" if points.shape[0] > 100 else "poor", diff --git a/b3dm/slope_aspect_tif.py b/b3dm/slope_aspect_tif.py new file mode 100644 index 0000000..e3a4cfe --- /dev/null +++ b/b3dm/slope_aspect_tif.py @@ -0,0 +1,1062 @@ +import numpy as np +from osgeo import gdal +import math +import os +from typing import Optional, Tuple, Dict, Any + + +class SlopeAspectGenerator: + """ + 坡度和坡向生成器 + 支持生成单独的坡度、坡向文件,或包含两者的合并文件 + """ + + def __init__(self, dem_path: str): + """ + 初始化坡度和坡向生成器 + + Args: + dem_path: 输入DEM文件路径 + """ + self.dem_path = dem_path + self.dataset = gdal.Open(dem_path, gdal.GA_ReadOnly) + + if not self.dataset: + raise FileNotFoundError(f"无法打开DEM文件: {dem_path}") + + # 获取地理信息 + self.geotransform = self.dataset.GetGeoTransform() + self.projection = self.dataset.GetProjection() + self.band = self.dataset.GetRasterBand(1) + self.data = self.band.ReadAsArray() + + # 获取栅格信息 + self.cols = self.dataset.RasterXSize + self.rows = self.dataset.RasterYSize + self.no_data = self.band.GetNoDataValue() + + # 如果DEM没有无效值设置,使用-9999 + if self.no_data is None: + self.no_data = -9999.0 + print(f"警告: DEM未设置无效值,将使用 {self.no_data} 作为无效值") + + # 像元大小(考虑非正方形像元) + self.cell_size_x = abs(self.geotransform[1]) + self.cell_size_y = abs(self.geotransform[5]) + + print(f"DEM信息: {self.rows}行 x {self.cols}列") + print(f"X方向像元大小: {self.cell_size_x:.4f}") + print(f"Y方向像元大小: {self.cell_size_y:.4f}") + print(f"无效值: {self.no_data}") + + def calculate_slope(self, z_factor: float = 1.0, algorithm: str = 'Horn') -> np.ndarray: + """ + 计算坡度 + + Args: + z_factor: 垂直单位转换因子 + algorithm: 算法选项 'Horn' 或 'ZevenbergenThorne' + + Returns: + 坡度数组(单位:度) + """ + print(f"正在计算坡度 (算法: {algorithm})...") + + # 创建输出数组 + slope = np.full((self.rows, self.cols), self.no_data, dtype=np.float32) + + # 创建填充后的数据(用于边界处理) + padded_data = np.pad( + self.data, + pad_width=1, + mode='constant', + constant_values=self.no_data + ) + + # 创建有效掩码 + valid_mask = (self.data != self.no_data) + padded_valid = np.pad(valid_mask, pad_width=1, mode='constant', constant_values=False) + + for i in range(self.rows): + for j in range(self.cols): + if not valid_mask[i, j]: + continue + + i_pad, j_pad = i + 1, j + 1 + + # 获取3x3窗口 + window = padded_data[i_pad-1:i_pad+2, j_pad-1:j_pad+2] + window_valid = padded_valid[i_pad-1:i_pad+2, j_pad-1:j_pad+2] + + # 如果窗口内有无效值,跳过 + if not np.all(window_valid): + continue + + if algorithm.upper() == 'ZEVERBERGENTHORNE': + # Zevenbergen & Thorne 算法 + dz_dx = (window[1, 2] - window[1, 0]) / (2 * self.cell_size_x) + dz_dy = (window[0, 1] - window[2, 1]) / (2 * self.cell_size_y) + else: + # Horn 算法(默认) + dz_dx = ((window[0, 2] + 2 * window[1, 2] + window[2, 2]) - + (window[0, 0] + 2 * window[1, 0] + window[2, 0])) / (8 * self.cell_size_x) + + dz_dy = ((window[2, 0] + 2 * window[2, 1] + window[2, 2]) - + (window[0, 0] + 2 * window[0, 1] + window[0, 2])) / (8 * self.cell_size_y) + + # 计算坡度(弧度转角度) + slope_rad = math.atan(z_factor * math.sqrt(dz_dx**2 + dz_dy**2)) + slope_deg = math.degrees(slope_rad) + slope[i, j] = slope_deg + + # 边界处理:将边缘设为无效值 + slope[0, :] = self.no_data + slope[-1, :] = self.no_data + slope[:, 0] = self.no_data + slope[:, -1] = self.no_data + + valid_slope = slope[slope != self.no_data] + if len(valid_slope) > 0: + print(f"坡度范围: {np.min(valid_slope):.2f} - {np.max(valid_slope):.2f} 度") + + return slope + + def calculate_aspect(self, algorithm: str = 'Horn') -> np.ndarray: + """ + 计算坡向 + + Args: + algorithm: 算法选项 'Horn' 或 'ZevenbergenThorne' + + Returns: + 坡向数组(单位:度,0-360,北为0度,顺时针增加) + """ + print(f"正在计算坡向 (算法: {algorithm})...") + + # 创建输出数组 + aspect = np.full((self.rows, self.cols), -1.0, dtype=np.float32) # 平地用-1表示 + + # 创建填充后的数据 + padded_data = np.pad( + self.data, + pad_width=1, + mode='constant', + constant_values=self.no_data + ) + + # 创建有效掩码 + valid_mask = (self.data != self.no_data) + padded_valid = np.pad(valid_mask, pad_width=1, mode='constant', constant_values=False) + + for i in range(self.rows): + for j in range(self.cols): + if not valid_mask[i, j]: + continue + + i_pad, j_pad = i + 1, j + 1 + + # 获取3x3窗口 + window = padded_data[i_pad-1:i_pad+2, j_pad-1:j_pad+2] + window_valid = padded_valid[i_pad-1:i_pad+2, j_pad-1:j_pad+2] + + # 如果窗口内有无效值,跳过 + if not np.all(window_valid): + continue + + if algorithm.upper() == 'ZEVERBERGENTHORNE': + # Zevenbergen & Thorne 算法 + dz_dx = (window[1, 2] - window[1, 0]) / (2 * self.cell_size_x) + dz_dy = (window[0, 1] - window[2, 1]) / (2 * self.cell_size_y) + else: + # Horn 算法(默认) + dz_dx = ((window[0, 2] + 2 * window[1, 2] + window[2, 2]) - + (window[0, 0] + 2 * window[1, 0] + window[2, 0])) / (8 * self.cell_size_x) + + dz_dy = ((window[2, 0] + 2 * window[2, 1] + window[2, 2]) - + (window[0, 0] + 2 * window[0, 1] + window[0, 2])) / (8 * self.cell_size_y) + + # 计算坡向 + if abs(dz_dx) < 1e-10 and abs(dz_dy) < 1e-10: + # 平地 + aspect[i, j] = -1.0 + else: + aspect_rad = math.atan2(dz_dy, -dz_dx) # atan2(y, x) + aspect_deg = math.degrees(aspect_rad) + + # 转换为0-360度(北为0度,顺时针) + if aspect_deg < 0: + aspect_deg += 360.0 + + aspect[i, j] = aspect_deg + + # 边界处理:将边缘设为-1 + aspect[0, :] = -1.0 + aspect[-1, :] = -1.0 + aspect[:, 0] = -1.0 + aspect[:, -1] = -1.0 + + valid_aspect = aspect[aspect != -1.0] + if len(valid_aspect) > 0: + print(f"坡向范围: {np.min(valid_aspect):.2f} - {np.max(valid_aspect):.2f} 度") + + return aspect + + def calculate_slope_aspect(self, z_factor: float = 1.0, + algorithm: str = 'Horn') -> Tuple[np.ndarray, np.ndarray]: + """ + 同时计算坡度和坡向(优化版本,减少重复计算) + + Args: + z_factor: 垂直单位转换因子 + algorithm: 算法选项 + + Returns: + (slope_array, aspect_array) + """ + print(f"同时计算坡度和坡向 (算法: {algorithm})...") + + # 创建输出数组 + slope = np.full((self.rows, self.cols), self.no_data, dtype=np.float32) + aspect = np.full((self.rows, self.cols), -1.0, dtype=np.float32) + + # 创建填充后的数据 + padded_data = np.pad( + self.data, + pad_width=1, + mode='constant', + constant_values=self.no_data + ) + + # 创建有效掩码 + valid_mask = (self.data != self.no_data) + padded_valid = np.pad(valid_mask, pad_width=1, mode='constant', constant_values=False) + + for i in range(self.rows): + for j in range(self.cols): + if not valid_mask[i, j]: + continue + + i_pad, j_pad = i + 1, j + 1 + + # 获取3x3窗口 + window = padded_data[i_pad-1:i_pad+2, j_pad-1:j_pad+2] + window_valid = padded_valid[i_pad-1:i_pad+2, j_pad-1:j_pad+2] + + # 如果窗口内有无效值,跳过 + if not np.all(window_valid): + continue + + # 计算导数(一次计算,两个都用) + if algorithm.upper() == 'ZEVERBERGENTHORNE': + dz_dx = (window[1, 2] - window[1, 0]) / (2 * self.cell_size_x) + dz_dy = (window[0, 1] - window[2, 1]) / (2 * self.cell_size_y) + else: + dz_dx = ((window[0, 2] + 2 * window[1, 2] + window[2, 2]) - + (window[0, 0] + 2 * window[1, 0] + window[2, 0])) / (8 * self.cell_size_x) + + dz_dy = ((window[2, 0] + 2 * window[2, 1] + window[2, 2]) - + (window[0, 0] + 2 * window[0, 1] + window[0, 2])) / (8 * self.cell_size_y) + + # 计算坡度 + slope_rad = math.atan(z_factor * math.sqrt(dz_dx**2 + dz_dy**2)) + slope_deg = math.degrees(slope_rad) + slope[i, j] = slope_deg + + # 计算坡向 + if abs(dz_dx) < 1e-10 and abs(dz_dy) < 1e-10: + aspect[i, j] = -1.0 + else: + aspect_rad = math.atan2(dz_dy, -dz_dx) + aspect_deg = math.degrees(aspect_rad) + + if aspect_deg < 0: + aspect_deg += 360.0 + + aspect[i, j] = aspect_deg + + # 边界处理 + slope[0, :] = self.no_data + slope[-1, :] = self.no_data + slope[:, 0] = self.no_data + slope[:, -1] = self.no_data + + aspect[0, :] = -1.0 + aspect[-1, :] = -1.0 + aspect[:, 0] = -1.0 + aspect[:, -1] = -1.0 + + # 显示统计信息 + valid_slope = slope[slope != self.no_data] + valid_aspect = aspect[aspect != -1.0] + + if len(valid_slope) > 0: + print(f"坡度范围: {np.min(valid_slope):.2f} - {np.max(valid_slope):.2f} 度") + if len(valid_aspect) > 0: + print(f"坡向范围: {np.min(valid_aspect):.2f} - {np.max(valid_aspect):.2f} 度") + + return slope, aspect + + def save_raster(self, data: np.ndarray, output_path: str, + band_name: str = '', data_type: int = gdal.GDT_Float32, + no_data_value: Optional[float] = None) -> None: + """ + 保存栅格数据为TIFF文件 + + Args: + data: 栅格数据数组 + output_path: 输出文件路径 + band_name: 波段名称 + data_type: GDAL数据类型 + no_data_value: 无效值 + """ + driver = gdal.GetDriverByName('GTiff') + + out_dataset = driver.Create( + output_path, + self.cols, + self.rows, + 1, # 单波段 + data_type, + options=['COMPRESS=LZW', 'PREDICTOR=2', 'TILED=YES'] + ) + + out_dataset.SetGeoTransform(self.geotransform) + out_dataset.SetProjection(self.projection) + + out_band = out_dataset.GetRasterBand(1) + out_band.WriteArray(data) + + if no_data_value is not None: + out_band.SetNoDataValue(no_data_value) + elif band_name == 'Slope': + out_band.SetNoDataValue(self.no_data) + elif band_name == 'Aspect': + out_band.SetNoDataValue(-1.0) + + if band_name: + out_band.SetDescription(band_name) + + # 设置统计信息(加速QGIS等软件加载) + out_band.ComputeStatistics(False) + + out_dataset.FlushCache() + out_dataset = None + + print(f"文件已保存: {output_path}") + if band_name: + print(f" 波段: {band_name}") + + def save_combined_raster(self, slope_data: np.ndarray, aspect_data: np.ndarray, + output_path: str) -> None: + """ + 保存包含坡度和坡向两个波段的TIFF文件 + + Args: + slope_data: 坡度数据 + aspect_data: 坡向数据 + output_path: 输出文件路径 + """ + driver = gdal.GetDriverByName('GTiff') + + # 创建包含2个波段的数据集 + out_dataset = driver.Create( + output_path, + self.cols, + self.rows, + 2, # 2个波段 + gdal.GDT_Float32, + options=['COMPRESS=LZW', 'PREDICTOR=2', 'TILED=YES'] + ) + + # 设置地理信息 + out_dataset.SetGeoTransform(self.geotransform) + out_dataset.SetProjection(self.projection) + + # 写入坡度波段(波段1) + slope_band = out_dataset.GetRasterBand(1) + slope_band.WriteArray(slope_data) + slope_band.SetNoDataValue(self.no_data) + slope_band.SetDescription('Slope') + slope_band.SetMetadataItem('UNITS', 'degrees') + slope_band.SetMetadataItem('RANGE', '0-90') + slope_band.ComputeStatistics(False) + + # 写入坡向波段(波段2) + aspect_band = out_dataset.GetRasterBand(2) + aspect_band.WriteArray(aspect_data) + aspect_band.SetNoDataValue(-1.0) + aspect_band.SetDescription('Aspect') + aspect_band.SetMetadataItem('UNITS', 'degrees') + aspect_band.SetMetadataItem('RANGE', '0-360') + aspect_band.SetMetadataItem('DIRECTION', 'Clockwise from North') + aspect_band.ComputeStatistics(False) + + # 设置数据集元数据 + out_dataset.SetMetadata({ + 'SOURCE_DEM': os.path.basename(self.dem_path), + 'PROCESSING_METHOD': 'SlopeAspectGenerator', + 'CELL_SIZE_X': str(self.cell_size_x), + 'CELL_SIZE_Y': str(self.cell_size_y) + }) + + out_dataset.FlushCache() + out_dataset = None + + print(f"合并文件已保存: {output_path}") + print(f" 波段1: Slope (坡度, 单位:度)") + print(f" 波段2: Aspect (坡向, 单位:度)") + + def generate_slope_tif(self, output_path: Optional[str] = None, + z_factor: float = 1.0, algorithm: str = 'Horn') -> np.ndarray: + """ + 生成坡度TIFF文件 + + Args: + output_path: 输出文件路径(可选) + z_factor: 垂直单位转换因子 + algorithm: 算法选项 + + Returns: + 坡度数据数组 + """ + if output_path is None: + base_name = os.path.splitext(self.dem_path)[0] + output_path = f"{base_name}_slope.tif" + + print(f"\n生成坡度文件...") + slope = self.calculate_slope(z_factor, algorithm) + self.save_raster(slope, output_path, 'Slope', no_data_value=self.no_data) + + return slope + + def generate_aspect_tif(self, output_path: Optional[str] = None, + algorithm: str = 'Horn') -> np.ndarray: + """ + 生成坡向TIFF文件 + + Args: + output_path: 输出文件路径(可选) + algorithm: 算法选项 + + Returns: + 坡向数据数组 + """ + if output_path is None: + base_name = os.path.splitext(self.dem_path)[0] + output_path = f"{base_name}_aspect.tif" + + print(f"\n生成坡向文件...") + aspect = self.calculate_aspect(algorithm) + self.save_raster(aspect, output_path, 'Aspect', no_data_value=-1.0) + + return aspect + + def generate_slope_aspect_separate(self, + slope_output: Optional[str] = None, + aspect_output: Optional[str] = None, + z_factor: float = 1.0, + algorithm: str = 'Horn') -> Tuple[np.ndarray, np.ndarray]: + """ + 分别生成坡度和坡向TIFF文件 + + Args: + slope_output: 坡度输出路径(可选) + aspect_output: 坡向输出路径(可选) + z_factor: 垂直单位转换因子 + algorithm: 算法选项 + + Returns: + (slope_data, aspect_data) + """ + # 设置输出路径 + base_name = os.path.splitext(self.dem_path)[0] + + if slope_output is None: + slope_output = f"{base_name}_slope.tif" + if aspect_output is None: + aspect_output = f"{base_name}_aspect.tif" + + print(f"\n分别生成坡度和坡向文件...") + + # 同时计算坡度和坡向(优化版本) + slope, aspect = self.calculate_slope_aspect(z_factor, algorithm) + + # 保存文件 + self.save_raster(slope, slope_output, 'Slope', no_data_value=self.no_data) + self.save_raster(aspect, aspect_output, 'Aspect', no_data_value=-1.0) + + return slope, aspect + + def generate_combined_tif(self, output_path: Optional[str] = None, + z_factor: float = 1.0, + algorithm: str = 'Horn') -> Tuple[np.ndarray, np.ndarray]: + """ + 生成包含坡度和坡向的合并TIFF文件 + + Args: + output_path: 输出文件路径(可选) + z_factor: 垂直单位转换因子 + algorithm: 算法选项 + + Returns: + (slope_data, aspect_data) + """ + if output_path is None: + base_name = os.path.splitext(self.dem_path)[0] + output_path = f"{base_name}_slope_aspect.tif" + + print(f"\n生成合并文件...") + + # 同时计算坡度和坡向 + slope, aspect = self.calculate_slope_aspect(z_factor, algorithm) + + # 保存合并文件 + self.save_combined_raster(slope, aspect, output_path) + + return slope, aspect + + def generate_all(self, output_dir: Optional[str] = None, + z_factor: float = 1.0, + algorithm: str = 'Horn') -> Dict[str, str]: + """ + 生成所有类型的文件:单独的坡度、单独的坡向、合并文件、统计信息 + + Args: + output_dir: 输出目录(可选) + z_factor: 垂直单位转换因子 + algorithm: 算法选项 + + Returns: + 包含所有输出文件路径的字典 + """ + # 设置输出路径 + base_name = os.path.basename(self.dem_path) + file_base = os.path.splitext(base_name)[0] + + if output_dir: + if not os.path.exists(output_dir): + os.makedirs(output_dir) + output_prefix = os.path.join(output_dir, file_base) + else: + output_prefix = file_base + + # 计算坡度和坡向(只计算一次) + print(f"\n计算坡度和坡向...") + slope, aspect = self.calculate_slope_aspect(z_factor, algorithm) + + # 生成各种文件 + results = {} + + # 1. 单独的坡度文件 + slope_file = f"{output_prefix}_slope.tif" + self.save_raster(slope, slope_file, 'Slope', no_data_value=self.no_data) + results['slope'] = slope_file + + # 2. 单独的坡向文件 + aspect_file = f"{output_prefix}_aspect.tif" + self.save_raster(aspect, aspect_file, 'Aspect', no_data_value=-1.0) + results['aspect'] = aspect_file + + # 3. 合并文件 + combined_file = f"{output_prefix}_slope_aspect.tif" + self.save_combined_raster(slope, aspect, combined_file) + results['combined'] = combined_file + + # 4. 统计信息文件(文本) + stats_file = f"{output_prefix}_statistics.txt" + self.save_statistics(slope, aspect, stats_file) + results['statistics'] = stats_file + + print(f"\n所有文件生成完成!") + for key, path in results.items(): + print(f" {key}: {os.path.basename(path)}") + + return results + + def save_statistics(self, slope_data: np.ndarray, aspect_data: np.ndarray, + output_path: str) -> None: + """ + 保存统计信息到文本文件 + + Args: + slope_data: 坡度数据 + aspect_data: 坡向数据 + output_path: 输出文件路径 + """ + valid_slope = slope_data[slope_data != self.no_data] + valid_aspect = aspect_data[aspect_data != -1.0] + + with open(output_path, 'w', encoding='utf-8') as f: + f.write("坡度和坡向统计信息\n") + f.write("=" * 50 + "\n") + f.write(f"输入DEM: {os.path.basename(self.dem_path)}\n") + f.write(f"栅格大小: {self.rows} x {self.cols}\n") + f.write(f"处理时间: {np.datetime64('now')}\n\n") + + f.write("坡度统计:\n") + f.write("-" * 30 + "\n") + if len(valid_slope) > 0: + f.write(f"最小值: {np.min(valid_slope):.4f} 度\n") + f.write(f"最大值: {np.max(valid_slope):.4f} 度\n") + f.write(f"平均值: {np.mean(valid_slope):.4f} 度\n") + f.write(f"标准差: {np.std(valid_slope):.4f} 度\n") + f.write(f"中位数: {np.median(valid_slope):.4f} 度\n") + f.write(f"有效像元数: {len(valid_slope)} / {self.rows * self.cols}\n") + + # 坡度分级统计 + f.write("\n坡度分级统计:\n") + slope_classes = ['平地(0-2)', '缓坡(2-5)', '斜坡(5-15)', '陡坡(15-30)', '急陡坡(30-45)', '峭壁(45-90)'] + bins = [0, 2, 5, 15, 30, 45, 90] + + hist, _ = np.histogram(valid_slope, bins=bins) + total = sum(hist) + + for i in range(len(slope_classes)): + if hist[i] > 0: + percentage = (hist[i] / total) * 100 + f.write(f"{slope_classes[i]}: {hist[i]} ({percentage:.2f}%)\n") + else: + f.write("无有效数据\n") + + f.write("\n坡向统计:\n") + f.write("-" * 30 + "\n") + if len(valid_aspect) > 0: + f.write(f"最小值: {np.min(valid_aspect):.4f} 度\n") + f.write(f"最大值: {np.max(valid_aspect):.4f} 度\n") + f.write(f"平均值: {np.mean(valid_aspect):.4f} 度\n") + f.write(f"标准差: {np.std(valid_aspect):.4f} 度\n") + f.write(f"中位数: {np.median(valid_aspect):.4f} 度\n") + f.write(f"有效像元数: {len(valid_aspect)} / {self.rows * self.cols}\n") + + # 坡向分类统计 + f.write("\n坡向分类统计:\n") + directions = ['北(0-22.5)', '东北(22.5-67.5)', '东(67.5-112.5)', '东南(112.5-157.5)', + '南(157.5-202.5)', '西南(202.5-247.5)', '西(247.5-292.5)', '西北(292.5-337.5)', '北(337.5-360)'] + bins = [0, 22.5, 67.5, 112.5, 157.5, 202.5, 247.5, 292.5, 337.5, 360] + + hist, _ = np.histogram(valid_aspect, bins=bins) + total = sum(hist) + + for i in range(len(directions)): + if hist[i] > 0: + percentage = (hist[i] / total) * 100 + f.write(f"{directions[i]}: {hist[i]} ({percentage:.2f}%)\n") + else: + f.write("无有效数据\n") + + print(f"统计信息已保存: {output_path}") + + def get_dem_info(self) -> Dict[str, Any]: + """ + 获取DEM文件的详细信息 + + Returns: + 包含DEM信息的字典 + """ + info = { + 'file_path': self.dem_path, + 'file_name': os.path.basename(self.dem_path), + 'rows': self.rows, + 'cols': self.cols, + 'cell_size_x': self.cell_size_x, + 'cell_size_y': self.cell_size_y, + 'no_data_value': self.no_data, + 'projection': self.projection + } + + # 获取高程统计 + valid_data = self.data[self.data != self.no_data] + if len(valid_data) > 0: + info['elevation_min'] = float(np.min(valid_data)) + info['elevation_max'] = float(np.max(valid_data)) + info['elevation_mean'] = float(np.mean(valid_data)) + info['elevation_std'] = float(np.std(valid_data)) + info['valid_cells'] = int(len(valid_data)) + info['total_cells'] = self.rows * self.cols + info['valid_percentage'] = (len(valid_data) / (self.rows * self.cols)) * 100 + + return info + + def get_slope_statistics(self, slope_data: np.ndarray) -> Dict[str, Any]: + """ + 获取坡度统计信息 + + Args: + slope_data: 坡度数据数组 + + Returns: + 坡度统计信息字典 + """ + valid_slope = slope_data[slope_data != self.no_data] + + stats = { + 'min': float(np.min(valid_slope)) if len(valid_slope) > 0 else None, + 'max': float(np.max(valid_slope)) if len(valid_slope) > 0 else None, + 'mean': float(np.mean(valid_slope)) if len(valid_slope) > 0 else None, + 'std': float(np.std(valid_slope)) if len(valid_slope) > 0 else None, + 'median': float(np.median(valid_slope)) if len(valid_slope) > 0 else None, + 'valid_cells': int(len(valid_slope)), + 'total_cells': self.rows * self.cols + } + + return stats + + def get_aspect_statistics(self, aspect_data: np.ndarray) -> Dict[str, Any]: + """ + 获取坡向统计信息 + + Args: + aspect_data: 坡向数据数组 + + Returns: + 坡向统计信息字典 + """ + valid_aspect = aspect_data[aspect_data != -1.0] + + stats = { + 'min': float(np.min(valid_aspect)) if len(valid_aspect) > 0 else None, + 'max': float(np.max(valid_aspect)) if len(valid_aspect) > 0 else None, + 'mean': float(np.mean(valid_aspect)) if len(valid_aspect) > 0 else None, + 'std': float(np.std(valid_aspect)) if len(valid_aspect) > 0 else None, + 'median': float(np.median(valid_aspect)) if len(valid_aspect) > 0 else None, + 'valid_cells': int(len(valid_aspect)), + 'total_cells': self.rows * self.cols + } + + # 坡向分类统计 + if len(valid_aspect) > 0: + bins = [0, 22.5, 67.5, 112.5, 157.5, 202.5, 247.5, 292.5, 337.5, 360] + directions = ['北', '东北', '东', '东南', '南', '西南', '西', '西北', '北'] + + hist, _ = np.histogram(valid_aspect, bins=bins) + direction_stats = {} + + for i in range(len(directions)): + direction_stats[directions[i]] = { + 'count': int(hist[i]), + 'percentage': (hist[i] / len(valid_aspect)) * 100 if len(valid_aspect) > 0 else 0 + } + + stats['direction_distribution'] = direction_stats + + return stats + + def close(self): + """关闭数据集,释放资源""" + if hasattr(self, 'dataset') and self.dataset: + self.dataset = None + if hasattr(self, 'band') and self.band: + self.band = None + print("资源已释放") + + def __enter__(self): + """支持with语句""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """退出with语句时自动关闭""" + self.close() + + +# 使用示例和辅助函数 +def create_slope_aspect(dem_path: str, output_type: str = 'combined', + output_path: Optional[str] = None, + z_factor: float = 1.0, algorithm: str = 'Horn') -> Dict[str, str]: + """ + 创建坡度坡向的便捷函数 + + Args: + dem_path: DEM文件路径 + output_type: 输出类型 'slope', 'aspect', 'separate', 'combined', 'all' + output_path: 输出路径(对于单个文件)或输出目录 + z_factor: 垂直单位转换因子 + algorithm: 算法选项 + + Returns: + 包含输出文件路径的字典 + """ + print(f"处理DEM文件: {dem_path}") + + with SlopeAspectGenerator(dem_path) as generator: + # 获取DEM信息 + dem_info = generator.get_dem_info() + print(f"DEM信息: {dem_info['rows']}x{dem_info['cols']}, " + f"高程范围: {dem_info.get('elevation_min', 'N/A'):.2f}-" + f"{dem_info.get('elevation_max', 'N/A'):.2f}") + + results = {} + + if output_type == 'slope': + # 只生成坡度 + slope_file = output_path or f"{os.path.splitext(dem_path)[0]}_slope.tif" + slope = generator.generate_slope_tif(slope_file, z_factor, algorithm) + results['slope'] = slope_file + + # 获取统计信息 + slope_stats = generator.get_slope_statistics(slope) + print(f"坡度统计: {slope_stats['min']:.2f}度 - {slope_stats['max']:.2f}度") + + elif output_type == 'aspect': + # 只生成坡向 + aspect_file = output_path or f"{os.path.splitext(dem_path)[0]}_aspect.tif" + aspect = generator.generate_aspect_tif(aspect_file, algorithm) + results['aspect'] = aspect_file + + # 获取统计信息 + aspect_stats = generator.get_aspect_statistics(aspect) + print(f"坡向统计: {aspect_stats['min']:.2f}度 - {aspect_stats['max']:.2f}度") + + elif output_type == 'separate': + # 分别生成坡度和坡向 + if output_path and os.path.isdir(output_path): + # 如果输出路径是目录 + slope_file = os.path.join(output_path, f"{os.path.splitext(os.path.basename(dem_path))[0]}_slope.tif") + aspect_file = os.path.join(output_path, f"{os.path.splitext(os.path.basename(dem_path))[0]}_aspect.tif") + slope, aspect = generator.generate_slope_aspect_separate( + slope_file, aspect_file, z_factor, algorithm) + else: + # 如果输出路径是文件前缀 + base = output_path or os.path.splitext(dem_path)[0] + slope_file = f"{base}_slope.tif" + aspect_file = f"{base}_aspect.tif" + slope, aspect = generator.generate_slope_aspect_separate( + slope_file, aspect_file, z_factor, algorithm) + + results['slope'] = slope_file + results['aspect'] = aspect_file + + # 获取统计信息 + slope_stats = generator.get_slope_statistics(slope) + aspect_stats = generator.get_aspect_statistics(aspect) + print(f"坡度范围: {slope_stats['min']:.2f}度 - {slope_stats['max']:.2f}度") + print(f"坡向范围: {aspect_stats['min']:.2f}度 - {aspect_stats['max']:.2f}度") + + elif output_type == 'combined': + # 生成合并文件 + combined_file = output_path or f"{os.path.splitext(dem_path)[0]}_slope_aspect.tif" + slope, aspect = generator.generate_combined_tif(combined_file, z_factor, algorithm) + results['combined'] = combined_file + + # 获取统计信息 + slope_stats = generator.get_slope_statistics(slope) + aspect_stats = generator.get_aspect_statistics(aspect) + print(f"坡度范围: {slope_stats['min']:.2f}度 - {slope_stats['max']:.2f}度") + print(f"坡向范围: {aspect_stats['min']:.2f}度 - {aspect_stats['max']:.2f}度") + + elif output_type == 'all': + # 生成所有文件 + output_dir = output_path if (output_path and os.path.isdir(output_path)) else None + results = generator.generate_all(output_dir, z_factor, algorithm) + + else: + raise ValueError(f"不支持的输出类型: {output_type}") + + return results + + +def batch_process_slope_aspect(dem_files: list, output_dir: str, + output_type: str = 'all', + z_factor: float = 1.0, + algorithm: str = 'Horn') -> Dict[str, list]: + """ + 批量处理多个DEM文件 + + Args: + dem_files: DEM文件路径列表 + output_dir: 输出目录 + output_type: 输出类型 + z_factor: 垂直单位转换因子 + algorithm: 算法选项 + + Returns: + 包含处理结果的字典 + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + results = { + 'success': [], + 'failed': [], + 'outputs': [] + } + + total_files = len(dem_files) + print(f"开始批量处理 {total_files} 个DEM文件...") + + for i, dem_file in enumerate(dem_files, 1): + if not os.path.exists(dem_file): + print(f"[{i}/{total_files}] 文件不存在: {dem_file}") + results['failed'].append(dem_file) + continue + + print(f"\n[{i}/{total_files}] 处理: {os.path.basename(dem_file)}") + + try: + # 处理文件 + file_results = create_slope_aspect( + dem_file, + output_type, + output_dir, + z_factor, + algorithm + ) + + results['success'].append(dem_file) + results['outputs'].append(file_results) + + print(f"[{i}/{total_files}] 完成") + + except Exception as e: + print(f"[{i}/{total_files}] 失败: {e}") + results['failed'].append(dem_file) + + # 打印总结 + print(f"\n" + "="*50) + print(f"批量处理完成!") + print(f"成功: {len(results['success'])} 个") + print(f"失败: {len(results['failed'])} 个") + + if results['failed']: + print("\n失败的文件:") + for failed_file in results['failed']: + print(f" {os.path.basename(failed_file)}") + + return results + + +# 主示例代码 +def main_example(): + """使用示例""" + + # 示例1: 基本使用 + print("示例1: 基本使用") + print("=" * 50) + + # 配置参数 + SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + dem_file = os.path.join(SCRIPT_DIR, "o_dem_23d110d0.tif") + + if os.path.exists(dem_file): + # 方法1: 使用便捷函数 + print("\n方法1: 使用便捷函数") + slope_file = os.path.join(SCRIPT_DIR, "o_combined.tif") + results = create_slope_aspect(dem_file, 'combined', slope_file) + + # 方法2: 直接使用类 + # print("\n方法2: 直接使用类") + # with SlopeAspectGenerator(dem_file) as gen: + # # 获取DEM信息 + # dem_info = gen.get_dem_info() + # print(f"DEM信息: {dem_info}") + + # # 生成合并文件 + # slope, aspect = gen.generate_combined_tif("o_combined.tif") + + # # 获取统计信息 + # slope_stats = gen.get_slope_statistics(slope) + # aspect_stats = gen.get_aspect_statistics(aspect) + # print(f"坡度统计: {slope_stats}") + # print(f"坡向统计: {aspect_stats}") + + else: + print(f"DEM文件不存在: {dem_file}") + print("请创建一个示例DEM文件或使用您自己的DEM文件") + + # 创建一个简单的示例DEM + print("\n创建一个示例DEM文件用于测试...") + create_sample_dem("sample_dem.tif") + + # 使用示例DEM + print("\n使用示例DEM进行测试...") + results = create_slope_aspect("sample_dem.tif", 'all') + + # 示例2: 批量处理 + # print("\n\n示例2: 批量处理") + # print("=" * 50) + + # # 假设有多个DEM文件 + # dem_files = ["dem1.tif", "dem2.tif", "dem3.tif"] # 替换为实际文件 + + # # 过滤出存在的文件 + # existing_files = [f for f in dem_files if os.path.exists(f)] + + # if len(existing_files) > 0: + # batch_results = batch_process_slope_aspect( + # existing_files, + # "batch_output", + # 'all' + # ) + # else: + # print("未找到DEM文件进行批量处理") + # print("创建示例文件进行批量处理演示...") + + # # 创建几个示例DEM文件 + # for i in range(1, 4): + # create_sample_dem(f"sample_dem_{i}.tif", offset=i*100) + + # sample_files = [f"sample_dem_{i}.tif" for i in range(1, 4)] + # batch_results = batch_process_slope_aspect( + # sample_files, + # "batch_output", + # 'all' + # ) + + +def create_sample_dem(output_path: str, rows: int = 100, cols: int = 100, + cell_size: float = 30.0, offset: float = 0.0): + """ + 创建一个示例DEM文件用于测试 + + Args: + output_path: 输出文件路径 + rows: 行数 + cols: 列数 + cell_size: 像元大小 + offset: 高程偏移量 + """ + from osgeo import osr + + # 创建一个简单的地形:一个山峰 + x = np.linspace(-5, 5, cols) + y = np.linspace(-5, 5, rows) + X, Y = np.meshgrid(x, y) + + # 创建高程数据(高斯山峰) + elevation = 1000 * np.exp(-(X**2 + Y**2) / 10) + offset + + # 添加一些噪声 + elevation += np.random.normal(0, 10, (rows, cols)) + + # 创建栅格文件 + driver = gdal.GetDriverByName('GTiff') + dataset = driver.Create( + output_path, + cols, + rows, + 1, + gdal.GDT_Float32 + ) + + # 设置地理变换(左下角坐标为(0,0)) + dataset.SetGeoTransform((0, cell_size, 0, 0, 0, -cell_size)) + + # 设置投影(WGS84) + srs = osr.SpatialReference() + srs.ImportFromEPSG(4326) + dataset.SetProjection(srs.ExportToWkt()) + + # 写入数据 + band = dataset.GetRasterBand(1) + band.WriteArray(elevation) + band.SetNoDataValue(-9999.0) + band.SetDescription('Elevation') + + dataset.FlushCache() + dataset = None + + print(f"示例DEM已创建: {output_path}") + + +if __name__ == "__main__": + # 运行示例 + main_example() \ No newline at end of file diff --git a/b3dm/terrain_api.py b/b3dm/terrain_api.py index 4394fe0..0f3ec19 100644 --- a/b3dm/terrain_api.py +++ b/b3dm/terrain_api.py @@ -174,7 +174,7 @@ async def preload_3dtiles(request: Request): # 创建并启动线程 script_dir = os.path.dirname(os.path.abspath(__file__)) - thread1 = threading.Thread(target=TerrainCalculator.preload_3dtiles, args=(vector.url)) + thread1 = threading.Thread(target=TerrainCalculator.preload_3dtiles, args=(vector.url,)) # 启动线程 thread1.start() url_prefix = extract_and_rebuild_url(vector.url) diff --git a/b3dm/terrain_calculator.py b/b3dm/terrain_calculator.py index cf0e2af..111b8db 100644 --- a/b3dm/terrain_calculator.py +++ b/b3dm/terrain_calculator.py @@ -3,58 +3,39 @@ from typing import List, Tuple, Dict, Any import logging import os import uuid -from b3dm.data_3dtiles_manager import MinIO3DTilesManager import b3dm.data_3dtiles_to_dem as data_3dtiles_to_dem import b3dm.slope_aspect_img as slope_aspect_img import b3dm.slope_aspect_tif as slope_aspect_tif +from b3dm.tileset_data_source import TilesetDataSource logger = logging.getLogger(__name__) -ENDPOINT_URL = "222.212.85.86:9000" -ACCESS_KEY = "WuRenJi" -SECRET_KEY = "WRJ@2024" +_data_source = None class TerrainCalculator: """地形坡度和坡向计算器""" - def preload_3dtiles(url) : + def preload_3dtiles(url: str) : # 下载3dtiles地图数据 - manager = MinIO3DTilesManager( - endpoint_url=ENDPOINT_URL, - access_key=ACCESS_KEY, - secret_key=SECRET_KEY, - secure=False - ) - script_dir = os.path.dirname(os.path.abspath(__file__)) - success, entry_local_path = manager.download_full_tileset( - tileset_url=url, - save_dir=f"data_3dtiles", - region_filter=None - ) - if not success : + _data_source = TilesetDataSource(url) + _data_source.dowload_map_data(url) + + if not _data_source.tileset_path : logger.info(f"下载地图数据失败: {url}") return "下载地图数据失败", None def generate_slopeAspect_3d_overlook(region_coords, url, overall_3d_png_name, minio_sub_path) : # 下载3dtiles地图数据 - manager = MinIO3DTilesManager( - endpoint_url=ENDPOINT_URL, - access_key=ACCESS_KEY, - secret_key=SECRET_KEY, - secure=False - ) - script_dir = os.path.dirname(os.path.abspath(__file__)) - success, entry_local_path = manager.download_full_tileset( - tileset_url=url, - save_dir=f"data_3dtiles", - region_filter=None - ) - if not success : + _data_source = TilesetDataSource(url) + _data_source.dowload_map_data(url) + + if not _data_source.tileset_path : logger.info(f"下载地图数据失败: {url},{region_coords}") return "下载地图数据失败", None - tileset_path = entry_local_path + tileset_path = _data_source.tileset_path + script_dir = os.path.dirname(os.path.abspath(__file__)) dem_path = os.path.join(script_dir, f"o_dem_{uuid.uuid4().hex[:8]}.tif") data_3dtiles_to_dem.generate_dem(tileset_path, dem_path, region_coords) @@ -66,8 +47,8 @@ class TerrainCalculator: slope_aspect_img.read_slope_aspect_by_dem(dem_path, overall_3d_png_path) logger.info(f"生成成功: {url},{region_coords},{overall_3d_png_path}") - entry_bucket, _ = manager.parse_minio_url(url); - success, minio_path = manager.upload_file(entry_bucket, f"{minio_sub_path}/{overall_3d_png_name}", overall_3d_png_path) + entry_bucket, _ = _data_source.parse_minio_url(url); + success, minio_path = _data_source.upload_file(entry_bucket, f"{minio_sub_path}/{overall_3d_png_name}", overall_3d_png_path) if success : return "生成成功", minio_path else : @@ -75,36 +56,28 @@ class TerrainCalculator: def generate_slopeAspect_tif(region_coords, url, slope_aspect_tif_name, minio_sub_path) : # 下载3dtiles地图数据 - manager = MinIO3DTilesManager( - endpoint_url=ENDPOINT_URL, - access_key=ACCESS_KEY, - secret_key=SECRET_KEY, - secure=False - ) - script_dir = os.path.dirname(os.path.abspath(__file__)) - success, entry_local_path = manager.download_full_tileset( - tileset_url=url, - save_dir=f"data_3dtiles", - region_filter=None - ) - if not success : + _data_source = TilesetDataSource(url) + _data_source.dowload_map_data(url) + + if not _data_source.tileset_path : logger.info(f"下载地图数据失败: {url},{region_coords}") return "下载地图数据失败", None - tileset_path = entry_local_path + tileset_path = _data_source.tileset_path + script_dir = os.path.dirname(os.path.abspath(__file__)) dem_path = os.path.join(script_dir, f"o_dem_{uuid.uuid4().hex[:8]}.tif") data_3dtiles_to_dem.generate_dem(tileset_path, dem_path, region_coords) if not os.path.exists(dem_path) : - logger.info(f"生成坡度坡向俯视图失败: {url},{region_coords}") - return "生成坡度坡向俯视图失败", None + logger.info(f"生成坡度坡向tif失败: {url},{region_coords}") + return "生成坡度坡向tif失败", None slope_aspect_tif_path = os.path.join(script_dir, slope_aspect_tif_name) slope_aspect_tif.create_slope_aspect(dem_path, 'combined', slope_aspect_tif_path) logger.info(f"生成成功: {url},{region_coords},{slope_aspect_tif_path}") - entry_bucket, _ = manager.parse_minio_url(url); - success, minio_path = manager.upload_file(entry_bucket, f"{minio_sub_path}/{slope_aspect_tif_name}", slope_aspect_tif_path) + entry_bucket, _ = _data_source.parse_minio_url(url); + success, minio_path = _data_source.upload_file(entry_bucket, f"{minio_sub_path}/{slope_aspect_tif_name}", slope_aspect_tif_path) if success : return "生成成功", minio_path else : diff --git a/b3dm/tileset_data_source.py b/b3dm/tileset_data_source.py index bcfcb19..27c83c9 100644 --- a/b3dm/tileset_data_source.py +++ b/b3dm/tileset_data_source.py @@ -5,148 +5,71 @@ import asyncio from concurrent.futures import ThreadPoolExecutor import logging import os +from pathlib import Path +from b3dm.data_3dtiles_manager import MinIO3DTilesManager +import b3dm.data_3dtiles_to_dem as data_3dtiles_to_dem logger = logging.getLogger(__name__) +ENDPOINT_URL = "222.212.85.86:9000" +ACCESS_KEY = "WuRenJi" +SECRET_KEY = "WRJ@2024" + class TilesetDataSource: """使用py3dtiles库的数据源""" - def __init__(self, tileset_path: str, cache_size: int = 1000): - self.tileset_path = os.path.abspath(tileset_path) - self.tileset_dir = os.path.dirname(self.tileset_path) - self.cache_size = cache_size - self._tileset = None - self._point_cache = {} - self._executor = ThreadPoolExecutor(max_workers=4) + def __init__(self, url: str, cache_size: int = 1000): + self.url = url + self.tileset_path = None + self.tileset_dir = None self._crs = "EPSG:4979" - - async def initialize(self): - """初始化""" - try: - # 尝试导入py3dtiles - try: - import py3dtiles - from py3dtiles.tileset import TileSet - except ImportError: - logger.warning("py3dtiles未安装,将使用简化数据源") - raise ImportError("请安装py3dtiles: pip install py3dtiles") - - loop = asyncio.get_event_loop() - self._tileset = await loop.run_in_executor( - self._executor, - TileSet.from_file, - self.tileset_path - ) - - logger.info(f"py3dtiles数据源初始化完成: {self.tileset_path}") - - except Exception as e: - logger.error(f"py3dtiles初始化失败: {str(e)}") - # 回退到简化数据源 - self._tileset = None + + def parse_minio_url(self, url): + manager = MinIO3DTilesManager( + endpoint_url=ENDPOINT_URL, + access_key=ACCESS_KEY, + secret_key=SECRET_KEY, + secure=False + ) + return manager.parse_minio_url(url) - async def get_points_in_polygon(self, - polygon_coords: List[List[float]], - z_range: Optional[Tuple[float, float]] = None) -> np.ndarray: - """获取点数据""" - if self._tileset is None: - # 使用简化数据源 - return await self._get_simulated_points(polygon_coords, z_range) + def upload_file(self, bucket_name, object_name, file_path): + manager = MinIO3DTilesManager( + endpoint_url=ENDPOINT_URL, + access_key=ACCESS_KEY, + secret_key=SECRET_KEY, + secure=False + ) + flag, path = manager.upload_file(bucket_name, object_name, file_path) + if flag : + os.remove(file_path) + return flag, path - try: - # 使用py3dtiles API获取数据 - points = [] - polygon_np = np.array(polygon_coords) - - # 遍历tileset中的所有tile - for tile in self._tileset.root_tile.traverse(): - tile_points = self._extract_tile_points(tile) - if tile_points.size > 0: - # 筛选多边形内的点 - points_in_polygon = self._filter_points_by_polygon(tile_points, polygon_np) - if points_in_polygon.size > 0: - points.append(points_in_polygon) - - if points: - all_points = np.vstack(points) - if z_range: - mask = (all_points[:, 2] >= z_range[0]) & (all_points[:, 2] <= z_range[1]) - all_points = all_points[mask] - return all_points - - return np.array([]) - - except Exception as e: - logger.error(f"py3dtiles获取数据失败: {str(e)}") - return await self._get_simulated_points(polygon_coords, z_range) + + def dowload_map_data(self, url: str) : + # 下载3dtiles地图数据 + manager = MinIO3DTilesManager( + endpoint_url=ENDPOINT_URL, + access_key=ACCESS_KEY, + secret_key=SECRET_KEY, + secure=False + ) + success, tileset_path = manager.download_full_tileset( + tileset_url=url, + save_dir=f"data_3dtiles", + region_filter=None + ) + if success : + self.tileset_path = os.path.abspath(tileset_path) + self.tileset_dir = os.path.dirname(tileset_path) - def _extract_tile_points(self, tile) -> np.ndarray: - """从tile提取点数据""" - try: - if hasattr(tile, 'content') and tile.content: - # 尝试获取点数据 - if hasattr(tile.content, 'points'): - return tile.content.points.positions - elif hasattr(tile.content, 'body'): - # 处理其他格式 - return np.array([]) - return np.array([]) - except: - return np.array([]) - - def _filter_points_by_polygon(self, points: np.ndarray, polygon: np.ndarray) -> np.ndarray: - """筛选多边形内的点""" - from matplotlib.path import Path - if points.size == 0: - return points - - path = Path(polygon[:, :2]) - mask = path.contains_points(points[:, :2]) - return points[mask] - - async def _get_simulated_points(self, - polygon_coords: List[List[float]], - z_range: Optional[Tuple[float, float]] = None) -> np.ndarray: - """获取模拟点数据""" - # 同上一个类的生成模拟数据方法 - polygon_np = np.array(polygon_coords) - - if polygon_np.shape[0] < 3: - return np.array([]) - - x_min, y_min = polygon_np.min(axis=0) - x_max, y_max = polygon_np.max(axis=0) - - grid_size = 100 - x = np.linspace(x_min, x_max, grid_size) - y = np.linspace(y_min, y_max, grid_size) - xx, yy = np.meshgrid(x, y) - - points = np.column_stack([xx.ravel(), yy.ravel()]) - - np.random.seed(42) - base_elevation = 50.0 - terrain_variation = 10.0 - z = base_elevation + terrain_variation * np.sin(0.1 * xx).ravel() * np.cos(0.1 * yy).ravel() - - points = np.column_stack([points[:, 0], points[:, 1], z]) - - from matplotlib.path import Path - path = Path(polygon_np[:, :2]) - mask = path.contains_points(points[:, :2]) - points = points[mask] - - if z_range: - mask = (points[:, 2] >= z_range[0]) & (points[:, 2] <= z_range[1]) - points = points[mask] - - logger.info(f"生成 {points.shape[0]} 个模拟点") - return points + async def get_points_in_polygon(self, polygon_coords, z_range=None): + """获取多边形内的点数据""" + points = data_3dtiles_to_dem.parse_tileset(self.tileset_path, polygon_coords) + return np.array(points) async def get_data_bounds(self) -> Dict[str, List[float]]: """获取数据边界""" - if self._tileset is None: - await self.initialize() bounds = { "min": [float('inf'), float('inf'), float('inf')], @@ -164,4 +87,19 @@ class TilesetDataSource: return bounds def get_crs(self) -> str: - return self._crs \ No newline at end of file + return self._crs + +async def main() : + SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + SCRIPT_PAR_DIR = os.path.dirname(SCRIPT_DIR) + tileset_path = os.path.join(SCRIPT_PAR_DIR, "data/3dtiles/tileset.json") + data_source_3d_tiles = TilesetDataSource(tileset_path) + # tileSet = TileSet() + # path = Path(tileset_path) + # tileset_data = tileSet.from_file(path) + print("====================================================") + + +if __name__ == "__main__": + loop = asyncio.get_event_loop() + loop.run_until_complete(main()) \ No newline at end of file