ai_project_v1/b3dm/glb_with_draco.py
2026-01-19 10:42:21 +08:00

469 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import struct
import numpy as np
import DracoPy
class DracoGLBParser:
    """Parse GLB (binary glTF) files whose meshes use KHR_draco_mesh_compression.

    Draco payloads are decoded with DracoPy. Decoded primitives are cached in
    ``decoded_meshes`` so the vertex accessor methods only decode once.

    Typical use::

        parser = DracoGLBParser("model.glb")
        parser.parse_glb_structure()
        meshes = parser.decode_draco_meshes()
        vertices = parser.get_all_vertices()
    """

    # glTF attribute name -> property of the DracoPy decode result that is tried
    # when the attribute cannot be found by its Draco unique id. The property
    # names assume the DracoPy mesh API (normals / tex_coord / colors) - verify
    # against the installed DracoPy version.
    _DECODER_FALLBACK_PROPERTIES = {
        'NORMAL': 'normals',
        'TEXCOORD_0': 'tex_coord',
        'COLOR_0': 'colors',
    }

    def __init__(self, glb_file_path):
        """Store the path; the file is not read until a parse/decode method runs."""
        self.glb_file_path = glb_file_path
        self.json_data = None  # parsed glTF JSON chunk
        self.binary_data = None  # raw BIN chunk bytes; None when the file has none
        self.decoded_meshes = []  # cached results of decode_draco_meshes()
        # Set once decode_draco_meshes() has run, so an empty result is cached too.
        self._decode_attempted = False

    def parse_glb_structure(self):
        """Read the GLB header, the JSON chunk and (if present) the BIN chunk.

        Sets ``self.json_data`` and ``self.binary_data`` and prints a summary.

        Returns:
            self, so calls can be chained.

        Raises:
            ValueError: if the file is not a GLB container, is truncated, or its
                chunks are not JSON followed by an optional BIN chunk.
        """
        self.binary_data = None
        with open(self.glb_file_path, 'rb') as f:
            # 12-byte header: magic, container version, total file length.
            header = f.read(12)
            if len(header) < 12:
                raise ValueError("GLB 文件头不完整")
            magic, version, total_length = struct.unpack('<4sII', header)
            if magic != b'glTF':
                raise ValueError(f"不是有效的 GLB 文件, magic: {magic!r}")
            print("=" * 60)
            print("GLB 文件分析:")
            print(f" 文件类型: {magic.decode('utf-8')}")
            print(f" 版本: {version}")
            print(f" 总大小: {total_length:,} bytes")

            # The first chunk must be the JSON chunk (length, type, payload).
            chunk_header = f.read(8)
            if len(chunk_header) < 8:
                raise ValueError("GLB 文件缺少 JSON chunk")
            json_length, json_type = struct.unpack('<I4s', chunk_header)
            if json_type != b'JSON':
                raise ValueError(f"期望 JSON chunk但得到: {json_type}")
            self.json_data = json.loads(f.read(json_length).decode('utf-8'))
            print(f" JSON 大小: {json_length:,} bytes")

            # The BIN chunk is optional; when present it directly follows JSON.
            if f.tell() < total_length:
                chunk_header = f.read(8)
                if len(chunk_header) == 8:
                    bin_length, bin_type = struct.unpack('<I4s', chunk_header)
                    if bin_type != b'BIN\x00':
                        raise ValueError(f"期望 BIN chunk但得到: {bin_type}")
                    self.binary_data = f.read(bin_length)
                    if len(self.binary_data) < bin_length:
                        raise ValueError("BIN chunk 不完整")
                    print(f" 二进制数据大小: {bin_length:,} bytes")
            print("=" * 60)
        return self

    def analyze_structure(self):
        """Print a summary of the glTF JSON.

        The summary covers asset info, Draco extension usage, meshes/primitives
        and buffer/bufferView/accessor counts. The file is parsed first if that
        has not happened yet.

        Returns:
            self, so calls can be chained.
        """
        if self.json_data is None:
            self.parse_glb_structure()
        print("\nGLB 结构分析:")
        print("-" * 40)
        # Asset information
        asset = self.json_data.get('asset', {})
        print(f"生成器: {asset.get('generator', '未知')}")
        print(f"glTF 版本: {asset.get('version', '未知')}")
        # Draco extension usage
        extensions_used = self.json_data.get('extensionsUsed', [])
        extensions_required = self.json_data.get('extensionsRequired', [])
        if 'KHR_draco_mesh_compression' in extensions_used:
            print("使用 Draco 压缩")
        if 'KHR_draco_mesh_compression' in extensions_required:
            print("Draco 压缩是必需的")
        # Meshes and their primitives
        meshes = self.json_data.get('meshes', [])
        print(f"\n网格数量: {len(meshes)}")
        for i, mesh in enumerate(meshes):
            print(f" 网格 {i}: {mesh.get('name', '未命名')}")
            primitives = mesh.get('primitives', [])
            print(f" 图元数量: {len(primitives)}")
            for j, primitive in enumerate(primitives):
                print(f" 图元 {j}:")
                draco_info = primitive.get('extensions', {}).get('KHR_draco_mesh_compression')
                if draco_info:
                    print(" 使用 Draco 压缩")
                    print(f" 属性: {draco_info.get('attributes', {})}")
        # Buffer bookkeeping
        buffers = self.json_data.get('buffers', [])
        buffer_views = self.json_data.get('bufferViews', [])
        accessors = self.json_data.get('accessors', [])
        print(f"\n缓冲区: {len(buffers)}")
        print(f"BufferViews: {len(buffer_views)}")
        print(f"访问器: {len(accessors)}")
        return self

    def decode_draco_meshes(self):
        """Decode every primitive that carries KHR_draco_mesh_compression.

        Returns:
            list of dicts, one per successfully decoded primitive. Each dict has
            keys 'mesh_idx', 'primitive_idx' and 'name', plus the arrays produced
            by ``_parse_draco_result``. The list is also cached in
            ``decoded_meshes``.
        """
        if self.json_data is None:
            self.parse_glb_structure()
        meshes = []
        buffer_views = self.json_data.get('bufferViews', [])
        print("\n" + "=" * 60)
        print("开始解码 Draco 压缩数据...")
        print("=" * 60)
        for mesh_idx, mesh in enumerate(self.json_data.get('meshes', [])):
            mesh_name = mesh.get('name', f'mesh_{mesh_idx}')
            for primitive_idx, primitive in enumerate(mesh.get('primitives', [])):
                draco_info = primitive.get('extensions', {}).get('KHR_draco_mesh_compression')
                if not draco_info:
                    continue
                print(f"\n解码: {mesh_name} - 图元 {primitive_idx}")
                mesh_data = self._decode_primitive(draco_info, buffer_views)
                if mesh_data:
                    meshes.append({
                        'mesh_idx': mesh_idx,
                        'primitive_idx': primitive_idx,
                        'name': mesh_name,
                        **mesh_data,
                    })
        self.decoded_meshes = meshes  # cache the decode results
        self._decode_attempted = True
        print("\n" + "=" * 60)
        print(f"解码完成!共解码 {len(meshes)} 个网格")
        print("=" * 60)
        return meshes

    def _ensure_decoded(self):
        """Decode on first use only.

        An empty result stays cached, so a file without Draco primitives is not
        decoded again on every accessor call.
        """
        if not self.decoded_meshes and not self._decode_attempted:
            self.decode_draco_meshes()

    def get_vertices(self, mesh_idx=0, primitive_idx=0):
        """Return the vertex array of one decoded primitive.

        Args:
            mesh_idx: mesh index (default 0).
            primitive_idx: primitive index within that mesh (default 0).

        Returns:
            np.ndarray of shape (n, 3), or None if that primitive was not decoded.
        """
        self._ensure_decoded()
        for mesh in self.decoded_meshes:
            if mesh['mesh_idx'] == mesh_idx and mesh['primitive_idx'] == primitive_idx:
                return mesh.get('vertices')
        print(f"未找到网格 {mesh_idx} 的图元 {primitive_idx}")
        return None

    def get_all_vertices(self):
        """Return the vertices of all decoded primitives stacked into one array.

        Returns:
            np.ndarray of shape (n, 3), or None when nothing was decoded.
        """
        self._ensure_decoded()
        if not self.decoded_meshes:
            print("没有解码的网格数据")
            return None
        all_vertices = [
            mesh['vertices'] for mesh in self.decoded_meshes
            if mesh.get('vertices') is not None
        ]
        if not all_vertices:
            return None
        return np.vstack(all_vertices)

    def get_vertices_by_mesh_name(self, mesh_name):
        """Return the vertex arrays of every decoded primitive of the mesh named *mesh_name*.

        Returns:
            list of np.ndarray (possibly empty).
        """
        self._ensure_decoded()
        return [
            mesh['vertices'] for mesh in self.decoded_meshes
            if mesh['name'] == mesh_name and mesh.get('vertices') is not None
        ]

    def get_vertex_count(self):
        """Return the total number of vertices over all decoded primitives (0 if none)."""
        vertices = self.get_all_vertices()
        return len(vertices) if vertices is not None else 0

    def _decode_primitive(self, draco_info, buffer_views):
        """Decode one primitive's Draco payload from the BIN chunk.

        Args:
            draco_info: the primitive's KHR_draco_mesh_compression object
                ('bufferView' index and 'attributes' map).
            buffer_views: the glTF 'bufferViews' list.

        Returns:
            dict from ``_parse_draco_result``, or None if decoding failed. The
            error and its traceback are printed, and the caller continues with
            the remaining primitives.
        """
        try:
            buffer_view_idx = draco_info['bufferView']
            attributes = draco_info['attributes']
            buffer_view = buffer_views[buffer_view_idx]
            byte_offset = buffer_view.get('byteOffset', 0)
            byte_length = buffer_view['byteLength']
            print(f" BufferView: {buffer_view_idx}")
            print(f" 属性映射: {attributes}")
            print(f" 数据位置: offset={byte_offset}, length={byte_length}")
            # Only the embedded BIN chunk (buffer 0 of a GLB) is available here.
            if self.binary_data is None:
                raise ValueError("GLB 文件不包含 BIN chunk")
            if buffer_view.get('buffer', 0) != 0:
                raise ValueError(f"不支持外部缓冲区: buffer={buffer_view.get('buffer')}")
            end = byte_offset + byte_length
            if end > len(self.binary_data):
                raise ValueError(f"BufferView 超出 BIN chunk 范围: {end} > {len(self.binary_data)}")
            draco_data = self.binary_data[byte_offset:end]
            print(f" Draco 数据大小: {len(draco_data):,} bytes")
            print(" 正在使用 DracoPy 解码...")
            draco_decoder = DracoPy.decode(draco_data)
            return self._parse_draco_result(draco_decoder, attributes)
        except Exception as e:
            # Deliberately broad: report and skip this primitive, keep decoding the others.
            print(f" 解码失败: {e}")
            import traceback
            traceback.print_exc()
            return None

    @staticmethod
    def _as_rows(data, dtype, width):
        """Convert *data* to an array of *dtype*.

        A flat (1-D) array is reshaped to (-1, width) when *width* is given and
        divides its length. Arrays that are already 2-D are left unchanged.
        """
        arr = np.asarray(data, dtype=dtype)
        if width and arr.ndim == 1 and arr.size % width == 0:
            arr = arr.reshape(-1, width)
        return arr

    @staticmethod
    def _to_batch_ids(data):
        """Return one uint32 batch id per vertex as a 1-D array."""
        arr = np.asarray(data).reshape(-1)
        if np.issubdtype(arr.dtype, np.floating):
            # Round float-encoded ids so e.g. 2.9999 becomes 3 instead of truncating to 2.
            arr = np.rint(arr)
        return arr.astype(np.uint32)

    @staticmethod
    def _find_draco_attribute(draco_decoder, unique_id):
        """Look up the Draco attribute whose unique id is *unique_id*.

        DracoPy releases expose generic attributes differently, so three forms
        are tried in order:
          1. a ``get_attribute_by_unique_id`` method;
          2. an ``attributes`` dict keyed by id;
          3. an ``attributes`` list of dicts that carry a 'unique_id' key.

        Returns:
            (data, num_components); data is None when the attribute is not found,
            and num_components is None when the entry does not state it.
        """
        found = None
        getter = getattr(draco_decoder, 'get_attribute_by_unique_id', None)
        if callable(getter):
            try:
                found = getter(unique_id)
            except (KeyError, IndexError, ValueError, TypeError):
                found = None
        if found is None:
            attrs = getattr(draco_decoder, 'attributes', None)
            if isinstance(attrs, dict):
                found = attrs.get(unique_id)
            elif isinstance(attrs, (list, tuple)):
                found = next(
                    (a for a in attrs if isinstance(a, dict) and a.get('unique_id') == unique_id),
                    None,
                )
        if isinstance(found, dict):
            return found.get('data'), found.get('num_components')
        return found, None

    def _parse_draco_result(self, draco_decoder, attributes):
        """Convert a DracoPy decode result into NumPy arrays.

        Args:
            draco_decoder: object returned by ``DracoPy.decode``.
            attributes: the ``attributes`` map of the primitive's
                KHR_draco_mesh_compression extension (glTF attribute name ->
                Draco attribute unique id).

        Returns:
            dict with the following keys; any that are not present are None:
              - 'vertices': (n, 3) float32
              - 'faces': (m, 3) uint32
              - 'texcoords': (n, 2) float32
              - 'batch_ids': (n,) uint32
              - 'normals': (n, 3) float32
              - 'colors': float32
        """
        result = {
            'vertices': None,
            'faces': None,
            'texcoords': None,
            'batch_ids': None,
            'normals': None,
            'colors': None,
        }
        # Vertex positions (may be flat [x0, y0, z0, x1, ...] in older DracoPy releases).
        points = getattr(draco_decoder, 'points', None)
        if points is not None:
            result['vertices'] = self._as_rows(points, np.float32, 3)
            print(f" 顶点数量: {len(result['vertices'])}")
        # Triangles: accepts nested [[a, b, c], ...] as well as flat [a, b, c, ...].
        if hasattr(draco_decoder, 'faces'):
            faces_data = draco_decoder.faces
            if faces_data is not None and len(faces_data) > 0:
                result['faces'] = self._as_rows(faces_data, np.uint32, 3)
            print(f" 面数量: {len(result['faces']) if result['faces'] is not None else 0}")
        # Attributes listed by the glTF extension, looked up by Draco unique id.
        for gltf_attr_name, draco_attr_id in (attributes or {}).items():
            data, num_components = self._find_draco_attribute(draco_decoder, draco_attr_id)
            if data is None:
                prop = self._DECODER_FALLBACK_PROPERTIES.get(gltf_attr_name)
                if prop:
                    data = getattr(draco_decoder, prop, None)
            if data is None or np.size(data) == 0:
                continue
            if gltf_attr_name == 'POSITION':
                result['vertices'] = self._as_rows(data, np.float32, 3)
            elif gltf_attr_name == 'TEXCOORD_0':
                result['texcoords'] = self._as_rows(data, np.float32, 2)
            elif gltf_attr_name == '_BATCHID':
                result['batch_ids'] = self._to_batch_ids(data)
            elif gltf_attr_name == 'NORMAL':
                result['normals'] = self._as_rows(data, np.float32, 3)
            elif gltf_attr_name == 'COLOR_0':
                result['colors'] = self._as_rows(data, np.float32, num_components)
            print(f" 已提取属性: {gltf_attr_name} (ID: {draco_attr_id})")
        # Last resort for decoder objects that only offer a get_points() method.
        if result['vertices'] is None and hasattr(draco_decoder, 'get_points'):
            try:
                result['vertices'] = self._as_rows(draco_decoder.get_points(), np.float32, 3)
            except Exception as e:
                print(f" 获取顶点失败: {e}")
        return result

    def save_decoded_meshes(self, meshes, output_format='obj', output_dir=None):
        """Write each decoded primitive to its own file.

        Args:
            meshes: list of mesh dicts as returned by ``decode_draco_meshes``.
            output_format: 'obj', 'ply' or 'npz'.
            output_dir: target directory. Defaults to "<glb basename>_decoded",
                relative to the current working directory.

        Files are named "<mesh name>_p<primitive index>.<format>". Characters that
        are not allowed in file names are replaced with '_'. When two meshes share
        a name, the later one gets an extra "_m<mesh index>" so that no file is
        overwritten.
        """
        import os
        import re

        savers = {
            'obj': self._save_as_obj,
            'ply': self._save_as_ply,
            'npz': self._save_as_npz,
        }
        saver = savers.get(output_format)
        if saver is None:
            print(f"不支持的格式: {output_format}")
            return
        if output_dir is None:
            base_name = os.path.splitext(os.path.basename(self.glb_file_path))[0]
            output_dir = f"{base_name}_decoded"
        os.makedirs(output_dir, exist_ok=True)

        used_stems = set()  # lower-cased, since Windows/macOS file names are case-insensitive
        for mesh in meshes:
            mesh_idx = mesh.get('mesh_idx', 0)
            primitive_idx = mesh.get('primitive_idx', 0)
            # Mesh names come from the file: remove path separators and other invalid characters.
            safe_name = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '_', str(mesh.get('name') or '')).strip(' .')
            if not safe_name:
                safe_name = f"mesh_{mesh_idx}"
            stem = f"{safe_name}_p{primitive_idx}"
            if stem.lower() in used_stems:
                stem = f"{safe_name}_m{mesh_idx}_p{primitive_idx}"
            candidate, counter = stem, 1
            while candidate.lower() in used_stems:
                candidate = f"{stem}_{counter}"
                counter += 1
            used_stems.add(candidate.lower())
            filename = os.path.join(output_dir, f"{candidate}.{output_format}")
            saver(mesh, filename)
            print(f"已保存: {filename}")

    def _save_as_obj(self, mesh, filename):
        """Write *mesh* as a Wavefront OBJ file.

        A glTF primitive uses a single index for all its attributes. When texture
        coordinates or normals exist for every vertex, each face corner therefore
        references them with the same index (``f v/vt/vn``). V is flipped because
        glTF puts the UV origin at the top-left corner and OBJ at the bottom-left.
        """
        vertices = mesh.get('vertices')
        texcoords = mesh.get('texcoords')
        normals = mesh.get('normals')
        faces = mesh.get('faces')
        vertex_count = len(vertices) if vertices is not None else 0
        # Only link per-vertex attributes whose count matches the vertex count.
        link_vt = texcoords is not None and vertex_count > 0 and len(texcoords) == vertex_count
        link_vn = normals is not None and vertex_count > 0 and len(normals) == vertex_count
        if link_vt and link_vn:
            corner = "{0}/{0}/{0}"
        elif link_vt:
            corner = "{0}/{0}"
        elif link_vn:
            corner = "{0}//{0}"
        else:
            corner = "{0}"
        with open(filename, 'w', encoding='utf-8') as f:
            if vertices is not None:
                f.writelines(f"v {v[0]} {v[1]} {v[2]}\n" for v in vertices)
            if texcoords is not None:
                f.writelines(f"vt {uv[0]} {1.0 - uv[1]}\n" for uv in texcoords)
            if normals is not None:
                f.writelines(f"vn {n[0]} {n[1]} {n[2]}\n" for n in normals)
            if faces is not None:
                for face in faces:
                    # OBJ indices are 1-based.
                    f.write("f " + " ".join(corner.format(int(idx) + 1) for idx in face) + "\n")

    def _save_as_ply(self, mesh, filename):
        """Write the vertices (and triangles, if any) as a binary PLY file using plyfile.

        Does nothing when the mesh has no vertices.
        """
        from plyfile import PlyData, PlyElement

        vertices = mesh.get('vertices')
        faces = mesh.get('faces')
        if vertices is None:
            return
        vertex_data = np.zeros(len(vertices), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
        vertex_data['x'] = vertices[:, 0]
        vertex_data['y'] = vertices[:, 1]
        vertex_data['z'] = vertices[:, 2]
        elements = [PlyElement.describe(vertex_data, 'vertex')]
        if faces is not None:
            face_data = np.zeros(len(faces), dtype=[('vertex_indices', 'i4', (3,))])
            face_data['vertex_indices'] = faces
            elements.append(PlyElement.describe(face_data, 'face'))
        PlyData(elements, text=False).write(filename)

    def _save_as_npz(self, mesh, filename):
        """Save the mesh arrays to an .npz archive.

        Missing (None) arrays are omitted rather than stored as pickled object
        arrays, so the archive loads with a plain ``np.load`` (allow_pickle=False).
        """
        arrays = {
            key: mesh[key]
            for key in ('vertices', 'faces', 'texcoords', 'batch_ids', 'normals', 'colors')
            if mesh.get(key) is not None
        }
        np.savez(filename, **arrays)
# Usage example: parse a Draco-compressed GLB, print its structure, demonstrate
# the vertex accessors, and export every decoded primitive as OBJ.
def main(glb_file_path=r"D:\devForBdzlWork\ai_project_v1\b3dm\test\temp_glb\temp_6e895637.glb"):
    """Run the end-to-end demo on *glb_file_path*.

    Args:
        glb_file_path: GLB file to process. Defaults to the developer test file
            that used to be hard-coded here. When the module is run as a script,
            the first command-line argument overrides it.
    """
    parser = DracoGLBParser(glb_file_path)
    # Parse the GLB container, then print the glTF structure.
    parser.parse_glb_structure()
    parser.analyze_structure()
    # Decode every Draco-compressed primitive.
    meshes = parser.decode_draco_meshes()

    print("\n" + "=" * 60)
    print("顶点获取方法演示:")
    print("=" * 60)

    # 1. Vertices of primitive 0 of mesh 0.
    vertices = parser.get_vertices(mesh_idx=0, primitive_idx=0)
    if vertices is not None:
        print("1. 获取第一个网格顶点:")
        print(f" 形状: {vertices.shape}")
        print(f" 数据类型: {vertices.dtype}")
        print(f" 前5个顶点: \n{vertices[:5]}")

    # 2. All vertices, stacked into one array.
    all_vertices = parser.get_all_vertices()
    if all_vertices is not None:
        print("\n2. 获取所有网格顶点(合并):")
        print(f" 总顶点数: {len(all_vertices)}")
        print(f" 形状: {all_vertices.shape}")

    # 3. Total vertex count.
    total_vertices = parser.get_vertex_count()
    print(f"\n3. 总顶点数: {total_vertices}")

    # 4. Vertices looked up by mesh name.
    if meshes:
        mesh_name = meshes[0]['name']
        vertices_list = parser.get_vertices_by_mesh_name(mesh_name)
        print(f"\n4. 根据名称 '{mesh_name}' 获取的顶点:")
        for i, verts in enumerate(vertices_list):
            print(f" 图元 {i}: {verts.shape if verts is not None else 'None'}")
        # Export only when something was decoded (avoids creating an empty directory).
        parser.save_decoded_meshes(meshes, output_format='obj')


if __name__ == "__main__":
    import sys

    # Optional command-line argument: path of the GLB file to process.
    if len(sys.argv) > 1:
        main(sys.argv[1])
    else:
        main()