h5ad generation
This commit is contained in:
parent
b39288fb9f
commit
c85b6b94f6
BIN
embryo-backend/Data/CS11.h5ad
(Stored with Git LFS)
BIN
embryo-backend/Data/CS11.h5ad
(Stored with Git LFS)
Binary file not shown.
BIN
embryo-backend/Data/CS12.h5ad
(Stored with Git LFS)
BIN
embryo-backend/Data/CS12.h5ad
(Stored with Git LFS)
Binary file not shown.
@@ -3,7 +3,6 @@ import trimesh
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import anndata as ad
|
||||
from collections import defaultdict
|
||||
|
||||
# ==== 参数 ====
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
@@ -30,32 +29,26 @@ true_markers = {
|
||||
|
||||
# ==== 采样函数 ====
|
||||
def sample_mesh(mesh, label, n_samples):
|
||||
"""采样点云 + 颜色 + 标签"""
|
||||
if use_surface_sampling:
|
||||
points, face_idx = trimesh.sample.sample_surface(mesh, n_samples)
|
||||
if hasattr(mesh.visual, "vertex_colors") and len(mesh.visual.vertex_colors) == len(mesh.vertices):
|
||||
face_colors = mesh.visual.vertex_colors[mesh.faces]
|
||||
colors = face_colors.mean(axis=1)[face_idx, :3]
|
||||
else:
|
||||
mat_color = np.array([200, 200, 200])
|
||||
colors = np.tile(mat_color, (n_samples, 1))
|
||||
if len(mesh.vertices) == 0:
|
||||
print(f"⚠️ {label} 无顶点,跳过")
|
||||
return np.zeros((0,3)), np.zeros((0,3)), np.array([])
|
||||
|
||||
if mesh.faces is None or len(mesh.faces) == 0 or not use_surface_sampling:
|
||||
idx = np.random.choice(len(mesh.vertices), min(n_samples, len(mesh.vertices)), replace=False)
|
||||
points = mesh.vertices[idx]
|
||||
else:
|
||||
verts = mesh.vertices
|
||||
if len(verts) > n_samples:
|
||||
idx = np.random.choice(len(verts), n_samples, replace=False)
|
||||
verts = verts[idx]
|
||||
else:
|
||||
idx = np.arange(len(verts))
|
||||
if hasattr(mesh.visual, "vertex_colors") and len(mesh.visual.vertex_colors) == len(mesh.vertices):
|
||||
colors = mesh.visual.vertex_colors[idx, :3]
|
||||
else:
|
||||
mat_color = np.array([200, 200, 200])
|
||||
colors = np.tile(mat_color, (len(verts), 1))
|
||||
points = verts
|
||||
points, face_idx = trimesh.sample.sample_surface(mesh, n_samples)
|
||||
|
||||
if hasattr(mesh.visual, "vertex_colors") and len(mesh.visual.vertex_colors) == len(mesh.vertices):
|
||||
colors = mesh.visual.vertex_colors[:len(points), :3]
|
||||
else:
|
||||
mat_color = np.array([200,200,200])
|
||||
colors = np.tile(mat_color, (len(points),1))
|
||||
|
||||
labels = np.array([label] * len(points))
|
||||
return points, colors, labels
|
||||
|
||||
# ==== Step 1: 计算全局最大边长 ====
|
||||
# ==== Step 1: 计算全局边界 ====
|
||||
all_bounds = []
|
||||
for glb_name in glb_files:
|
||||
glb_path = os.path.join(script_dir, glb_name)
|
||||
@@ -74,7 +67,7 @@ global_max = np.max(all_bounds[:,1,:], axis=0)
|
||||
global_size = np.max(global_max - global_min)
|
||||
print(f"🌍 统一缩放基准: global_size={global_size}")
|
||||
|
||||
# ==== Step 2: 处理单个 GLB(应用全局变换) ====
|
||||
# ==== Step 2: 处理单个 GLB ====
|
||||
def process_glb(glb_path, sample_name):
|
||||
scene = trimesh.load(glb_path)
|
||||
all_points, all_colors, all_labels = [], [], []
|
||||
@@ -82,23 +75,32 @@ def process_glb(glb_path, sample_name):
|
||||
if isinstance(scene, trimesh.Scene):
|
||||
total_points = sum(len(g.vertices) for g in scene.geometry.values())
|
||||
for name, geom in scene.geometry.items():
|
||||
ratio = len(geom.vertices) / total_points
|
||||
n_samples = max(1, int(N_total * ratio))
|
||||
if len(geom.vertices) == 0:
|
||||
print(f"⚠️ {sample_name}: 子网格 {name} 无顶点,跳过")
|
||||
continue
|
||||
|
||||
# ✅ 保护性获取变换矩阵
|
||||
ratio = len(geom.vertices) / total_points if total_points > 0 else 0
|
||||
n_samples = max(100, int(N_total * ratio)) # ✅ 最少100点
|
||||
|
||||
# ✅ 自动检测 & 兼容修复
|
||||
try:
|
||||
transform = scene.graph.get(name)[0]
|
||||
verts_world = trimesh.transform_points(geom.vertices, transform)
|
||||
except Exception:
|
||||
print(f"⚠️ {sample_name}: 子网格 {name} 没有变换路径,使用单位矩阵")
|
||||
transform = np.eye(4)
|
||||
# 判断顶点是否已经全局化
|
||||
local_bounds = geom.bounds
|
||||
local_size = np.max(local_bounds[1] - local_bounds[0])
|
||||
if local_size > 0.5 * global_size: # 说明已在全局坐标
|
||||
print(f"ℹ️ {sample_name}: 子网格 {name} 似乎已是全局坐标 → 直接使用")
|
||||
verts_world = geom.vertices.copy()
|
||||
else:
|
||||
print(f"⚠️ {sample_name}: 子网格 {name} 无变换路径且看似局部坐标 → 仍使用原始顶点(可能错位)")
|
||||
verts_world = geom.vertices.copy()
|
||||
|
||||
# ✅ 应用全局变换
|
||||
verts_world = trimesh.transform_points(geom.vertices, transform)
|
||||
|
||||
# ✅ 生成全局 mesh 进行采样
|
||||
mesh_world = trimesh.Trimesh(vertices=verts_world, faces=geom.faces, process=False)
|
||||
p, c, l = sample_mesh(mesh_world, name, n_samples)
|
||||
|
||||
if len(p) == 0:
|
||||
continue
|
||||
all_points.append(p)
|
||||
all_colors.append(c)
|
||||
all_labels.append(l)
|
||||
@@ -108,15 +110,14 @@ def process_glb(glb_path, sample_name):
|
||||
all_colors.append(c)
|
||||
all_labels.append(l)
|
||||
|
||||
if len(all_points) == 0:
|
||||
raise RuntimeError(f"❌ {sample_name}: 没有有效点被采样!")
|
||||
|
||||
points = np.vstack(all_points)
|
||||
colors = np.vstack(all_colors)
|
||||
labels = np.concatenate(all_labels)
|
||||
|
||||
if len(points) > N_total:
|
||||
idx = np.random.choice(len(points), N_total, replace=False)
|
||||
points, colors, labels = points[idx], colors[idx], labels[idx]
|
||||
|
||||
# ✅ 一次性中心化 + 统一缩放
|
||||
# ✅ 中心化 + 统一缩放
|
||||
center = points.mean(axis=0)
|
||||
points_centered = points - center
|
||||
points_scaled = points_centered / global_size
|
||||
@@ -127,10 +128,9 @@ def process_glb(glb_path, sample_name):
|
||||
|
||||
csv_file = os.path.join(script_dir, f"{sample_name}_point_cloud_30000_centered_scaled.csv")
|
||||
df.to_csv(csv_file, index=False)
|
||||
print(f"✅ 已导出 {csv_file} (修正组织错位, 中心化 & 统一比例)")
|
||||
print(f"✅ 已导出 {csv_file} (自动检测 + 兼容修复)")
|
||||
return df
|
||||
|
||||
|
||||
# ==== Step 3: 生成 h5ad ====
|
||||
def create_h5ad(df, sample_name):
|
||||
points = df[['x','y','z']].to_numpy()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user