h5ad generation

This commit is contained in:
Flash 2025-07-27 19:18:17 +08:00
parent b39288fb9f
commit c85b6b94f6
3 changed files with 44 additions and 44 deletions

BIN
embryo-backend/Data/CS11.h5ad (Stored with Git LFS)

Binary file not shown.

BIN
embryo-backend/Data/CS12.h5ad (Stored with Git LFS)

Binary file not shown.

View File

@@ -3,7 +3,6 @@ import trimesh
import pandas as pd
import numpy as np
import anndata as ad
from collections import defaultdict
# ==== 参数 ====
script_dir = os.path.dirname(os.path.realpath(__file__))
@@ -30,32 +29,26 @@ true_markers = {
# ==== 采样函数 ====
def sample_mesh(mesh, label, n_samples):
"""采样点云 + 颜色 + 标签"""
if use_surface_sampling:
points, face_idx = trimesh.sample.sample_surface(mesh, n_samples)
if hasattr(mesh.visual, "vertex_colors") and len(mesh.visual.vertex_colors) == len(mesh.vertices):
face_colors = mesh.visual.vertex_colors[mesh.faces]
colors = face_colors.mean(axis=1)[face_idx, :3]
else:
mat_color = np.array([200, 200, 200])
colors = np.tile(mat_color, (n_samples, 1))
if len(mesh.vertices) == 0:
print(f"⚠️ {label} 无顶点,跳过")
return np.zeros((0,3)), np.zeros((0,3)), np.array([])
if mesh.faces is None or len(mesh.faces) == 0 or not use_surface_sampling:
idx = np.random.choice(len(mesh.vertices), min(n_samples, len(mesh.vertices)), replace=False)
points = mesh.vertices[idx]
else:
verts = mesh.vertices
if len(verts) > n_samples:
idx = np.random.choice(len(verts), n_samples, replace=False)
verts = verts[idx]
else:
idx = np.arange(len(verts))
if hasattr(mesh.visual, "vertex_colors") and len(mesh.visual.vertex_colors) == len(mesh.vertices):
colors = mesh.visual.vertex_colors[idx, :3]
else:
mat_color = np.array([200, 200, 200])
colors = np.tile(mat_color, (len(verts), 1))
points = verts
points, face_idx = trimesh.sample.sample_surface(mesh, n_samples)
if hasattr(mesh.visual, "vertex_colors") and len(mesh.visual.vertex_colors) == len(mesh.vertices):
colors = mesh.visual.vertex_colors[:len(points), :3]
else:
mat_color = np.array([200,200,200])
colors = np.tile(mat_color, (len(points),1))
labels = np.array([label] * len(points))
return points, colors, labels
# ==== Step 1: 计算全局最大边长 ====
# ==== Step 1: 计算全局边界 ====
all_bounds = []
for glb_name in glb_files:
glb_path = os.path.join(script_dir, glb_name)
@@ -74,7 +67,7 @@ global_max = np.max(all_bounds[:,1,:], axis=0)
global_size = np.max(global_max - global_min)
print(f"🌍 统一缩放基准: global_size={global_size}")
# ==== Step 2: 处理单个 GLB(应用全局变换) ====
# ==== Step 2: 处理单个 GLB ====
def process_glb(glb_path, sample_name):
scene = trimesh.load(glb_path)
all_points, all_colors, all_labels = [], [], []
@@ -82,23 +75,32 @@ def process_glb(glb_path, sample_name):
if isinstance(scene, trimesh.Scene):
total_points = sum(len(g.vertices) for g in scene.geometry.values())
for name, geom in scene.geometry.items():
ratio = len(geom.vertices) / total_points
n_samples = max(1, int(N_total * ratio))
if len(geom.vertices) == 0:
print(f"⚠️ {sample_name}: 子网格 {name} 无顶点,跳过")
continue
# ✅ 保护性获取变换矩阵
ratio = len(geom.vertices) / total_points if total_points > 0 else 0
n_samples = max(100, int(N_total * ratio)) # ✅ 最少100点
# ✅ 自动检测 & 兼容修复
try:
transform = scene.graph.get(name)[0]
verts_world = trimesh.transform_points(geom.vertices, transform)
except Exception:
print(f"⚠️ {sample_name}: 子网格 {name} 没有变换路径,使用单位矩阵")
transform = np.eye(4)
# 判断顶点是否已经全局化
local_bounds = geom.bounds
local_size = np.max(local_bounds[1] - local_bounds[0])
if local_size > 0.5 * global_size: # 说明已在全局坐标
print(f" {sample_name}: 子网格 {name} 似乎已是全局坐标 → 直接使用")
verts_world = geom.vertices.copy()
else:
print(f"⚠️ {sample_name}: 子网格 {name} 无变换路径且看似局部坐标 → 仍使用原始顶点(可能错位)")
verts_world = geom.vertices.copy()
# ✅ 应用全局变换
verts_world = trimesh.transform_points(geom.vertices, transform)
# ✅ 生成全局 mesh 进行采样
mesh_world = trimesh.Trimesh(vertices=verts_world, faces=geom.faces, process=False)
p, c, l = sample_mesh(mesh_world, name, n_samples)
if len(p) == 0:
continue
all_points.append(p)
all_colors.append(c)
all_labels.append(l)
@@ -108,15 +110,14 @@ def process_glb(glb_path, sample_name):
all_colors.append(c)
all_labels.append(l)
if len(all_points) == 0:
raise RuntimeError(f"{sample_name}: 没有有效点被采样!")
points = np.vstack(all_points)
colors = np.vstack(all_colors)
labels = np.concatenate(all_labels)
if len(points) > N_total:
idx = np.random.choice(len(points), N_total, replace=False)
points, colors, labels = points[idx], colors[idx], labels[idx]
# ✅ 一次性中心化 + 统一缩放
# ✅ 中心化 + 统一缩放
center = points.mean(axis=0)
points_centered = points - center
points_scaled = points_centered / global_size
@@ -127,10 +128,9 @@ def process_glb(glb_path, sample_name):
csv_file = os.path.join(script_dir, f"{sample_name}_point_cloud_30000_centered_scaled.csv")
df.to_csv(csv_file, index=False)
print(f"✅ 已导出 {csv_file} (修正组织错位, 中心化 & 统一比例)")
print(f"✅ 已导出 {csv_file} (自动检测 + 兼容修复)")
return df
# ==== Step 3: 生成 h5ad ====
def create_h5ad(df, sample_name):
points = df[['x','y','z']].to_numpy()