digital-embryo/embryo-backend/Data/glbTrans.py
2025-07-27 19:18:17 +08:00

184 lines
7.1 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import trimesh
import pandas as pd
import numpy as np
import anndata as ad
# ==== Parameters ====
# Directory containing this script; every input and output path is built from it.
script_dir = os.path.dirname(os.path.realpath(__file__))
# Model files CS11.glb ... CS23.glb, looked up next to this script.
glb_files = [f"CS{i}.glb" for i in range(11,24,1)]
# Point budget per GLB file; it is split across sub-meshes by vertex share.
N_total = 30000
# False: pick existing mesh vertices; True: sample points on the mesh surface.
use_surface_sampling = False
# Number of gene names kept from the head of the gene list file.
n_genes = 1000
# Number of marker genes assigned to every label.
n_markers_per_label = 10
# ==== Load real gene names ====
gene_list_file = os.path.join(script_dir, "human_protein_coding_genes.txt")
# Explicit check instead of `assert`: asserts are removed under `python -O`,
# which would defer the failure to a less helpful error inside pandas.
if not os.path.exists(gene_list_file):
    raise FileNotFoundError(f"❌ 缺少基因名文件 {gene_list_file}")
# The file is read as a single header-less column of gene names.
all_gene_names = pd.read_csv(gene_list_file, header=None)[0].tolist()
gene_names = all_gene_names[:n_genes]
# ==== Marker genes ====
# Known marker genes per label. A marker is only used when it also appears in
# `gene_names`; labels missing from this table get randomly chosen markers.
# NOTE(review): "Brachyury" is a common alias rather than an official gene
# symbol (the symbol is TBXT, already listed under Mesoderm) — it will only
# match if the gene list file contains the alias. Confirm intent.
true_markers = {
    "Ectoderm": ["SOX2", "PAX6", "NES", "TUBB3", "OTX2"],
    "Mesoderm": ["TBXT", "MESP1", "HAND1", "GATA4", "PDGFRA"],
    "Endoderm": ["SOX17", "FOXA2", "GATA6", "CXCR4", "HNF1B"],
    "Notochord": ["SHH", "NOG", "CHRD", "FOXA2", "Brachyury"],
    "NeuralTube": ["OLIG2", "NKX6-1", "HOXB4", "PAX3", "SNAI2"]
}
# ==== Sampling function ====
def sample_mesh(mesh, label, n_samples):
    """Sample points from ``mesh`` and return them with colours and labels.

    Vertex sampling (no faces, or the module-level ``use_surface_sampling`` is
    false) draws up to ``n_samples`` distinct vertices at random. Surface
    sampling delegates to ``trimesh.sample.sample_surface``.

    Args:
        mesh: Object exposing ``vertices``, ``faces`` and ``visual``.
        label: Label attached to every sampled point.
        n_samples: Requested number of points. Vertex sampling is capped at
            the number of vertices.

    Returns:
        ``(points, colors, labels)``. ``colors`` holds the first three colour
        channels of the sampled vertices when ``mesh.visual.vertex_colors``
        has one entry per vertex, otherwise a constant grey (200, 200, 200).
        A mesh without vertices yields three empty arrays.
    """
    if len(mesh.vertices) == 0:
        print(f"⚠️ {label} 无顶点,跳过")
        return np.zeros((0,3)), np.zeros((0,3)), np.array([])

    if mesh.faces is None or len(mesh.faces) == 0 or not use_surface_sampling:
        n_pick = min(n_samples, len(mesh.vertices))
        vertex_idx = np.random.choice(len(mesh.vertices), n_pick, replace=False)
        points = mesh.vertices[vertex_idx]
    else:
        points, face_idx = trimesh.sample.sample_surface(mesh, n_samples)
        # A surface sample has no vertex of its own; approximate its colour
        # with the first corner vertex of the face it was drawn from.
        vertex_idx = np.asarray(mesh.faces)[face_idx, 0]

    vertex_colors = getattr(mesh.visual, "vertex_colors", None)
    if vertex_colors is not None and len(vertex_colors) == len(mesh.vertices):
        # Index by the sampled vertex ids so each point keeps its own colour
        # (slicing the first len(points) rows would pair points with the
        # colours of unrelated vertices).
        colors = np.asarray(vertex_colors)[vertex_idx, :3]
    else:
        mat_color = np.array([200,200,200])
        colors = np.tile(mat_color, (len(points),1))

    labels = np.array([label] * len(points))
    return points, colors, labels
# ==== Step 1: compute global bounds ====
# Collect the bounding box of every geometry in every available GLB file so
# that all samples can later be scaled by one shared factor.
all_bounds = []
for glb_name in glb_files:
    glb_path = os.path.join(script_dir, glb_name)
    if not os.path.exists(glb_path):
        continue
    scene = trimesh.load(glb_path)
    if isinstance(scene, trimesh.Scene):
        geometries = list(scene.geometry.values())
    else:
        geometries = [scene]
    for geom in geometries:
        # Skip geometries that report no bounds (presumably empty ones —
        # verify against the trimesh version in use); a None entry would
        # break the array slicing below.
        if geom.bounds is not None:
            all_bounds.append(geom.bounds)
# Fail with a clear message instead of an IndexError on an empty array when
# none of the expected GLB files is present.
if not all_bounds:
    raise RuntimeError(f"❌ 在 {script_dir} 中未找到任何可用的 GLB 几何体,无法计算全局边界")
all_bounds = np.array(all_bounds)
global_min = np.min(all_bounds[:,0,:], axis=0)
global_max = np.max(all_bounds[:,1,:], axis=0)
# Largest extent along any axis; used as the common scale divisor.
global_size = np.max(global_max - global_min)
print(f"🌍 统一缩放基准: global_size={global_size}")
# ==== Step 2: process a single GLB ====
def _world_vertices(scene, name, geom, sample_name):
    """Return the vertices of ``geom`` in scene coordinates (best effort).

    The geometry name is first tried directly as a scene-graph node name.
    If that fails, the nodes that reference the geometry are looked up and
    the first one is used. If no transform can be resolved the raw vertices
    are returned and a diagnostic message is printed.
    """
    graph = scene.graph
    try:
        try:
            transform = graph.get(name)[0]
        except Exception:
            # Geometry names and node names are separate namespaces; fall
            # back to the first node that references this geometry.
            node_name = graph.geometry_nodes[name][0]
            transform = graph.get(node_name)[0]
        return trimesh.transform_points(geom.vertices, transform)
    except Exception:
        # Deliberately broad: transform resolution is best effort and must
        # not abort the export. Guess whether the vertices already look
        # global by comparing the mesh extent with the global extent.
        local_bounds = geom.bounds
        local_size = np.max(local_bounds[1] - local_bounds[0])
        if local_size > 0.5 * global_size:
            print(f" {sample_name}: 子网格 {name} 似乎已是全局坐标 → 直接使用")
        else:
            print(f"⚠️ {sample_name}: 子网格 {name} 无变换路径且看似局部坐标 → 仍使用原始顶点(可能错位)")
        return geom.vertices.copy()


def process_glb(glb_path, sample_name):
    """Sample a point cloud from one GLB file and export it as CSV.

    For a multi-geometry scene the ``N_total`` point budget is split across
    sub-meshes in proportion to their vertex counts, with at least 100 points
    requested per sub-mesh. The sampled points are centred on their mean and
    divided by the module-level ``global_size``.

    Args:
        glb_path: Path of the GLB file to load.
        sample_name: Name used in log messages and in the CSV file name.

    Returns:
        DataFrame with columns ``x, y, z, r, g, b, label``. The same table is
        written to ``<sample_name>_point_cloud_30000_centered_scaled.csv`` in
        ``script_dir``.

    Raises:
        RuntimeError: If no points could be sampled from the file.
    """
    scene = trimesh.load(glb_path)
    all_points, all_colors, all_labels = [], [], []
    if isinstance(scene, trimesh.Scene):
        total_points = sum(len(g.vertices) for g in scene.geometry.values())
        for name, geom in scene.geometry.items():
            if len(geom.vertices) == 0:
                print(f"⚠️ {sample_name}: 子网格 {name} 无顶点,跳过")
                continue
            ratio = len(geom.vertices) / total_points if total_points > 0 else 0
            n_samples = max(100, int(N_total * ratio))  # at least 100 points
            verts_world = _world_vertices(scene, name, geom, sample_name)
            # Carry the source vertex colours over; a mesh rebuilt from bare
            # vertices and faces would otherwise lose them.
            source_colors = getattr(getattr(geom, "visual", None), "vertex_colors", None)
            if source_colors is not None and len(source_colors) != len(verts_world):
                source_colors = None
            mesh_world = trimesh.Trimesh(
                vertices=verts_world,
                faces=geom.faces,
                vertex_colors=source_colors,
                process=False,
            )
            p, c, l = sample_mesh(mesh_world, name, n_samples)
            if len(p) == 0:
                continue
            all_points.append(p)
            all_colors.append(c)
            all_labels.append(l)
    else:
        p, c, l = sample_mesh(scene, "mesh", N_total)
        # Skip empty samples here too so the check below can report them.
        if len(p) > 0:
            all_points.append(p)
            all_colors.append(c)
            all_labels.append(l)
    if len(all_points) == 0:
        raise RuntimeError(f"{sample_name}: 没有有效点被采样!")
    points = np.vstack(all_points)
    colors = np.vstack(all_colors)
    labels = np.concatenate(all_labels)
    # Centre on the mean, then apply the shared scale.
    center = points.mean(axis=0)
    points_centered = points - center
    # A degenerate global extent (0) would turn every coordinate into NaN/inf.
    scale = global_size if global_size > 0 else 1.0
    points_scaled = points_centered / scale
    df = pd.DataFrame(points_scaled, columns=["x","y","z"])
    df["r"], df["g"], df["b"] = colors[:,0], colors[:,1], colors[:,2]
    df["label"] = labels
    csv_file = os.path.join(script_dir, f"{sample_name}_point_cloud_30000_centered_scaled.csv")
    df.to_csv(csv_file, index=False)
    print(f"✅ 已导出 {csv_file} (自动检测 + 兼容修复)")
    return df
# ==== Step 3: build the h5ad file ====
def create_h5ad(df, sample_name):
    """Create a synthetic expression dataset for a sampled point cloud.

    Every row of ``df`` becomes one cell whose type is its ``label``. Each
    label gets marker genes: the entries of ``true_markers`` that occur in
    ``gene_names``, padded with randomly chosen genes up to
    ``n_markers_per_label``. Expression is Poisson(1) noise, with an extra
    Poisson(5) added to the marker genes of each cell's own type.

    Args:
        df: DataFrame with columns ``x``, ``y``, ``z`` and ``label``.
        sample_name: Prefix for cell names and output file names.

    Side effects:
        Writes ``<sample_name>.h5ad`` and ``<sample_name>_marker_genes.txt``
        into ``script_dir`` and reseeds NumPy's global RNG with 42.
    """
    points = df[['x','y','z']].to_numpy()
    cell_types = df['label'].to_numpy()
    unique_labels = sorted(set(cell_types))

    # First-occurrence column index per gene: O(1) membership and lookup
    # instead of repeated list scans.
    gene_index = {}
    for col, gene in enumerate(gene_names):
        gene_index.setdefault(gene, col)

    # Dedicated seeded generator so the random padding is reproducible
    # without disturbing the global stream used for the expression matrix.
    marker_rng = np.random.RandomState(42)
    marker_genes = {}
    all_marker_set = set()
    for ct in unique_labels:
        valid_markers = [g for g in true_markers.get(ct, []) if g in gene_index]
        n_missing = n_markers_per_label - len(valid_markers)
        if n_missing > 0:
            taken = set(valid_markers)
            candidates = [g for g in gene_index if g not in taken]
            # Cap at the number of available genes: a retry-until-unique
            # loop would never finish when too few distinct genes exist.
            n_pick = min(n_missing, len(candidates))
            if n_pick > 0:
                picked = marker_rng.choice(len(candidates), size=n_pick, replace=False)
                valid_markers.extend(candidates[i] for i in picked)
        marker_genes[ct] = valid_markers
        all_marker_set.update(valid_markers)

    np.random.seed(42)
    expr_matrix = np.random.poisson(lam=1.0, size=(len(df), len(gene_names))).astype(np.float32)
    for ct in unique_labels:
        cell_idx = np.where(cell_types == ct)[0]
        marker_idx = [gene_index[g] for g in marker_genes[ct]]
        # Raise expression of this type's markers in this type's cells only.
        expr_matrix[np.ix_(cell_idx, marker_idx)] += np.random.poisson(lam=5.0, size=(len(cell_idx), len(marker_idx)))

    obs = pd.DataFrame(index=[f"{sample_name}_cell{i}" for i in range(len(df))])
    obs["cell_type"] = cell_types
    var = pd.DataFrame(index=gene_names)
    var["is_marker"] = ["yes" if g in all_marker_set else "no" for g in gene_names]
    adata = ad.AnnData(X=expr_matrix, obs=obs, var=var)
    adata.obsm["spatial"] = points
    output_h5ad = os.path.join(script_dir, f"{sample_name}.h5ad")
    adata.write(output_h5ad)

    # Explicit encoding: labels come from mesh names and may be non-ASCII.
    marker_file = os.path.join(script_dir, f"{sample_name}_marker_genes.txt")
    with open(marker_file, "w", encoding="utf-8") as f:
        for ct, genes in marker_genes.items():
            f.write(f"{ct} : {', '.join(genes)}\n")
    print(f"✅ 已生成 {output_h5ad} ({len(df)} cells × {len(gene_names)} genes)")
# ==== Step 4: batch run ====
# Convert every expected GLB file that is present next to this script;
# missing files are reported and skipped.
for glb_name in glb_files:
    glb_path = os.path.join(script_dir, glb_name)
    if not os.path.exists(glb_path):
        print(f"⚠️ 未找到 {glb_name},跳过")
        continue
    # The sample name is the file name without its extension.
    sample_name = os.path.splitext(glb_name)[0]
    df = process_glb(glb_path, sample_name)
    create_h5ad(df, sample_name)