Compare commits

...

2 Commits

Author SHA1 Message Date
bc60273eca Merge branch 'main' of http://10.80.10.11/wjsjwr/digital-embryo 2025-07-27 19:20:08 +08:00
c85b6b94f6 h5ad generation 2025-07-27 19:18:17 +08:00
3 changed files with 44 additions and 44 deletions

BIN
embryo-backend/Data/CS11.h5ad (Stored with Git LFS)

Binary file not shown.

BIN
embryo-backend/Data/CS12.h5ad (Stored with Git LFS)

Binary file not shown.

View File

@ -3,7 +3,6 @@ import trimesh
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import anndata as ad import anndata as ad
from collections import defaultdict
# ==== 参数 ==== # ==== 参数 ====
script_dir = os.path.dirname(os.path.realpath(__file__)) script_dir = os.path.dirname(os.path.realpath(__file__))
@ -30,32 +29,26 @@ true_markers = {
# ==== 采样函数 ==== # ==== 采样函数 ====
def sample_mesh(mesh, label, n_samples): def sample_mesh(mesh, label, n_samples):
"""采样点云 + 颜色 + 标签""" if len(mesh.vertices) == 0:
if use_surface_sampling: print(f"⚠️ {label} 无顶点,跳过")
return np.zeros((0,3)), np.zeros((0,3)), np.array([])
if mesh.faces is None or len(mesh.faces) == 0 or not use_surface_sampling:
idx = np.random.choice(len(mesh.vertices), min(n_samples, len(mesh.vertices)), replace=False)
points = mesh.vertices[idx]
else:
points, face_idx = trimesh.sample.sample_surface(mesh, n_samples) points, face_idx = trimesh.sample.sample_surface(mesh, n_samples)
if hasattr(mesh.visual, "vertex_colors") and len(mesh.visual.vertex_colors) == len(mesh.vertices): if hasattr(mesh.visual, "vertex_colors") and len(mesh.visual.vertex_colors) == len(mesh.vertices):
face_colors = mesh.visual.vertex_colors[mesh.faces] colors = mesh.visual.vertex_colors[:len(points), :3]
colors = face_colors.mean(axis=1)[face_idx, :3]
else: else:
mat_color = np.array([200,200,200]) mat_color = np.array([200,200,200])
colors = np.tile(mat_color, (n_samples, 1)) colors = np.tile(mat_color, (len(points),1))
else:
verts = mesh.vertices
if len(verts) > n_samples:
idx = np.random.choice(len(verts), n_samples, replace=False)
verts = verts[idx]
else:
idx = np.arange(len(verts))
if hasattr(mesh.visual, "vertex_colors") and len(mesh.visual.vertex_colors) == len(mesh.vertices):
colors = mesh.visual.vertex_colors[idx, :3]
else:
mat_color = np.array([200, 200, 200])
colors = np.tile(mat_color, (len(verts), 1))
points = verts
labels = np.array([label] * len(points)) labels = np.array([label] * len(points))
return points, colors, labels return points, colors, labels
# ==== Step 1: 计算全局最大边长 ==== # ==== Step 1: 计算全局边界 ====
all_bounds = [] all_bounds = []
for glb_name in glb_files: for glb_name in glb_files:
glb_path = os.path.join(script_dir, glb_name) glb_path = os.path.join(script_dir, glb_name)
@ -74,7 +67,7 @@ global_max = np.max(all_bounds[:,1,:], axis=0)
global_size = np.max(global_max - global_min) global_size = np.max(global_max - global_min)
print(f"🌍 统一缩放基准: global_size={global_size}") print(f"🌍 统一缩放基准: global_size={global_size}")
# ==== Step 2: 处理单个 GLB(应用全局变换) ==== # ==== Step 2: 处理单个 GLB ====
def process_glb(glb_path, sample_name): def process_glb(glb_path, sample_name):
scene = trimesh.load(glb_path) scene = trimesh.load(glb_path)
all_points, all_colors, all_labels = [], [], [] all_points, all_colors, all_labels = [], [], []
@ -82,23 +75,32 @@ def process_glb(glb_path, sample_name):
if isinstance(scene, trimesh.Scene): if isinstance(scene, trimesh.Scene):
total_points = sum(len(g.vertices) for g in scene.geometry.values()) total_points = sum(len(g.vertices) for g in scene.geometry.values())
for name, geom in scene.geometry.items(): for name, geom in scene.geometry.items():
ratio = len(geom.vertices) / total_points if len(geom.vertices) == 0:
n_samples = max(1, int(N_total * ratio)) print(f"⚠️ {sample_name}: 子网格 {name} 无顶点,跳过")
continue
# ✅ 保护性获取变换矩阵 ratio = len(geom.vertices) / total_points if total_points > 0 else 0
n_samples = max(100, int(N_total * ratio)) # ✅ 最少100点
# ✅ 自动检测 & 兼容修复
try: try:
transform = scene.graph.get(name)[0] transform = scene.graph.get(name)[0]
except Exception:
print(f"⚠️ {sample_name}: 子网格 {name} 没有变换路径,使用单位矩阵")
transform = np.eye(4)
# ✅ 应用全局变换
verts_world = trimesh.transform_points(geom.vertices, transform) verts_world = trimesh.transform_points(geom.vertices, transform)
except Exception:
# 判断顶点是否已经全局化
local_bounds = geom.bounds
local_size = np.max(local_bounds[1] - local_bounds[0])
if local_size > 0.5 * global_size: # 说明已在全局坐标
print(f" {sample_name}: 子网格 {name} 似乎已是全局坐标 → 直接使用")
verts_world = geom.vertices.copy()
else:
print(f"⚠️ {sample_name}: 子网格 {name} 无变换路径且看似局部坐标 → 仍使用原始顶点(可能错位)")
verts_world = geom.vertices.copy()
# ✅ 生成全局 mesh 进行采样
mesh_world = trimesh.Trimesh(vertices=verts_world, faces=geom.faces, process=False) mesh_world = trimesh.Trimesh(vertices=verts_world, faces=geom.faces, process=False)
p, c, l = sample_mesh(mesh_world, name, n_samples) p, c, l = sample_mesh(mesh_world, name, n_samples)
if len(p) == 0:
continue
all_points.append(p) all_points.append(p)
all_colors.append(c) all_colors.append(c)
all_labels.append(l) all_labels.append(l)
@ -108,15 +110,14 @@ def process_glb(glb_path, sample_name):
all_colors.append(c) all_colors.append(c)
all_labels.append(l) all_labels.append(l)
if len(all_points) == 0:
raise RuntimeError(f"{sample_name}: 没有有效点被采样!")
points = np.vstack(all_points) points = np.vstack(all_points)
colors = np.vstack(all_colors) colors = np.vstack(all_colors)
labels = np.concatenate(all_labels) labels = np.concatenate(all_labels)
if len(points) > N_total: # ✅ 中心化 + 统一缩放
idx = np.random.choice(len(points), N_total, replace=False)
points, colors, labels = points[idx], colors[idx], labels[idx]
# ✅ 一次性中心化 + 统一缩放
center = points.mean(axis=0) center = points.mean(axis=0)
points_centered = points - center points_centered = points - center
points_scaled = points_centered / global_size points_scaled = points_centered / global_size
@ -127,10 +128,9 @@ def process_glb(glb_path, sample_name):
csv_file = os.path.join(script_dir, f"{sample_name}_point_cloud_30000_centered_scaled.csv") csv_file = os.path.join(script_dir, f"{sample_name}_point_cloud_30000_centered_scaled.csv")
df.to_csv(csv_file, index=False) df.to_csv(csv_file, index=False)
print(f"✅ 已导出 {csv_file} (修正组织错位, 中心化 & 统一比例)") print(f"✅ 已导出 {csv_file} (自动检测 + 兼容修复)")
return df return df
# ==== Step 3: 生成 h5ad ==== # ==== Step 3: 生成 h5ad ====
def create_h5ad(df, sample_name): def create_h5ad(df, sample_name):
points = df[['x','y','z']].to_numpy() points = df[['x','y','z']].to_numpy()