digital-embryo/embryo-backend/Data/glbTrans.py
2025-07-27 17:57:58 +08:00

174 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import trimesh
import pandas as pd
import numpy as np
import anndata as ad
from collections import defaultdict
# ==== Parameters ====
# Directory holding this script; every input and output path is built from it.
script_dir = os.path.dirname(os.path.realpath(__file__))
# Expected model files: CS11.glb through CS23.glb (13 files).
glb_files = ["CS%d.glb" % stage for stage in range(11, 24)]
# Upper bound on the number of points kept per model.
N_total = 30000
# True: sample points on the mesh faces; False: subsample existing vertices.
use_surface_sampling = False
# Number of gene names taken from the head of the gene list.
n_genes = 1000
# Minimum number of marker genes assigned to each label.
n_markers_per_label = 10
# ==== Load the real gene names ====
gene_list_file = os.path.join(script_dir, "human_protein_coding_genes.txt")
# Validate with an explicit raise: an `assert` is removed when Python runs
# with -O, which would defer the failure to a less obvious read error.
if not os.path.exists(gene_list_file):
    raise FileNotFoundError(f"❌ 缺少基因名文件 {gene_list_file}")
# The file is read without a header row; gene names come from the first column.
all_gene_names = pd.read_csv(gene_list_file, header=None)[0].tolist()
# Only the first `n_genes` names are used for the synthetic expression matrix.
gene_names = all_gene_names[:n_genes]
# ==== Real marker genes ====
# Known marker genes per label. Only the names that also occur in
# `gene_names` are used; the rest are silently dropped downstream.
# NOTE(review): "Brachyury" looks like a protein alias rather than a gene
# symbol (TBXT is already listed under Mesoderm) — confirm against the gene list.
true_markers = {
    "Ectoderm": [
        "SOX2", "PAX6", "NES", "TUBB3", "OTX2",
    ],
    "Mesoderm": [
        "TBXT", "MESP1", "HAND1", "GATA4", "PDGFRA",
    ],
    "Endoderm": [
        "SOX17", "FOXA2", "GATA6", "CXCR4", "HNF1B",
    ],
    "Notochord": [
        "SHH", "NOG", "CHRD", "FOXA2", "Brachyury",
    ],
    "NeuralTube": [
        "OLIG2", "NKX6-1", "HOXB4", "PAX3", "SNAI2",
    ],
}
# ==== Sampling helper ====
def sample_mesh(mesh, label, n_samples, *, surface_sampling=None):
    """Sample coloured, labelled points from ``mesh``.

    Args:
        mesh: Object exposing ``vertices`` and ``visual`` (plus ``faces`` when
            surface sampling is used).
        label: Value repeated once for every returned point.
        n_samples: Number of surface samples, or the maximum number of
            vertices to keep when sampling vertices.
        surface_sampling: Optional override of the module-level
            ``use_surface_sampling`` flag; ``None`` (default) uses the flag.

    Returns:
        Tuple ``(points, colors, labels)``. ``colors`` holds the first three
        colour channels per point, or (200, 200, 200) for every point when the
        mesh has no per-vertex colours matching its vertex count.
    """
    if surface_sampling is None:
        surface_sampling = use_surface_sampling

    # Per-vertex colours are usable only when there is exactly one per vertex.
    vertex_colors = getattr(mesh.visual, "vertex_colors", None)
    has_vertex_colors = (
        vertex_colors is not None and len(vertex_colors) == len(mesh.vertices)
    )
    fallback_color = np.array([200, 200, 200])

    if surface_sampling:
        points, face_idx = trimesh.sample.sample_surface(mesh, n_samples)
        if has_vertex_colors:
            # Average the three corner colours of the sampled faces only,
            # instead of averaging every face and indexing afterwards.
            corner_colors = vertex_colors[mesh.faces[face_idx]]
            colors = corner_colors.mean(axis=1)[:, :3]
        else:
            colors = np.tile(fallback_color, (len(points), 1))
    else:
        verts = mesh.vertices
        if len(verts) > n_samples:
            # More vertices than requested: keep a random subset, no repeats.
            idx = np.random.choice(len(verts), n_samples, replace=False)
            points = verts[idx]
        else:
            idx = np.arange(len(verts))
            points = verts
        if has_vertex_colors:
            colors = vertex_colors[idx, :3]
        else:
            colors = np.tile(fallback_color, (len(points), 1))

    labels = np.array([label] * len(points))
    return points, colors, labels
# ==== Step 1: compute the global maximum edge length ====
# Collect the (min, max) bounds of every geometry in every model file that
# exists, so that all models can be divided by one shared scale factor.
# NOTE(review): bounds are read per geometry; any scene-graph node transforms
# are presumably not applied here — confirm this matches the input files.
all_bounds = []
for glb_name in glb_files:
    glb_path = os.path.join(script_dir, glb_name)
    if not os.path.exists(glb_path):
        continue
    scene = trimesh.load(glb_path)
    if isinstance(scene, trimesh.Scene):
        geometries = list(scene.geometry.values())
    else:
        geometries = [scene]
    for geom in geometries:
        # Skip geometry that reports no bounds (e.g. it is empty); a None
        # entry would prevent building the (n, 2, 3) array below.
        if geom.bounds is not None:
            all_bounds.append(geom.bounds)  # (min, max)

if all_bounds:
    all_bounds = np.array(all_bounds)
    global_min = np.min(all_bounds[:, 0, :], axis=0)
    global_max = np.max(all_bounds[:, 1, :], axis=0)
    global_size = np.max(global_max - global_min)
else:
    # No model file (or no usable geometry) was found. Keep the names defined
    # so the batch step can still run and report each missing file.
    all_bounds = np.empty((0, 2, 3))
    global_min = np.zeros(3)
    global_max = np.zeros(3)
    global_size = 0.0

if not global_size > 0:
    # A zero (or NaN) extent would make the later division meaningless;
    # fall back to a neutral scale of 1.0 and say so.
    print("⚠️ 未能计算有效的包围盒,global_size 回退为 1.0")
    global_size = 1.0
print(f"🌍 统一缩放基准: global_size={global_size}")
# ==== Step 2: process a single GLB ====
def process_glb(glb_path, sample_name):
    """Convert one model file into a centred, uniformly scaled point cloud.

    Points are sampled with ``sample_mesh``. For a multi-geometry scene each
    geometry receives a share of ``N_total`` proportional to its vertex count
    (at least one sample) and is labelled with its geometry name; a single
    mesh is labelled ``"mesh"``. The result is capped at ``N_total`` points,
    shifted so its mean is the origin, divided by the module-level
    ``global_size``, and written as CSV next to this script.

    NOTE(review): vertices are taken from each geometry directly; any
    scene-graph node transforms are presumably not applied — confirm.

    Args:
        glb_path: Path of the model file to load.
        sample_name: Prefix of the exported CSV file name.

    Returns:
        DataFrame with columns ``x, y, z, r, g, b, label``.

    Raises:
        ValueError: If the loaded scene contains no vertices at all.
    """
    scene = trimesh.load(glb_path)
    all_points, all_colors, all_labels = [], [], []

    if isinstance(scene, trimesh.Scene):
        # Geometry without vertices contributes no points; dropping it also
        # keeps the proportional split below from dividing by zero.
        named_geoms = [
            (name, geom)
            for name, geom in scene.geometry.items()
            if len(geom.vertices) > 0
        ]
        if not named_geoms:
            raise ValueError(f"❌ {glb_path} 中没有可采样的顶点")
        total_points = sum(len(geom.vertices) for _, geom in named_geoms)
        for name, geom in named_geoms:
            ratio = len(geom.vertices) / total_points
            n_samples = max(1, int(N_total * ratio))
            p, c, l = sample_mesh(geom, name, n_samples)
            all_points.append(p)
            all_colors.append(c)
            all_labels.append(l)
    else:
        p, c, l = sample_mesh(scene, "mesh", N_total)
        all_points.append(p)
        all_colors.append(c)
        all_labels.append(l)

    points = np.vstack(all_points)
    colors = np.vstack(all_colors)
    labels = np.concatenate(all_labels)

    # The per-geometry minimum of one sample can push the total above N_total.
    if len(points) > N_total:
        idx = np.random.choice(len(points), N_total, replace=False)
        points, colors, labels = points[idx], colors[idx], labels[idx]

    # Translate so the mean of the sampled points sits at the origin.
    center = points.mean(axis=0)
    points_centered = points - center
    # Uniform scaling by the shared factor keeps the original proportions.
    points_scaled = points_centered / global_size
    # To change the axis convention, Y and Z could be swapped here, e.g.:
    # points_scaled = points_scaled[:, [0, 2, 1]]
    # points_scaled[:, 1] *= -1

    df = pd.DataFrame(points_scaled, columns=["x", "y", "z"])
    df["r"], df["g"], df["b"] = colors[:, 0], colors[:, 1], colors[:, 2]
    df["label"] = labels

    # The "30000" in the file name is fixed text; it does not follow N_total.
    csv_file = os.path.join(script_dir, f"{sample_name}_point_cloud_30000_centered_scaled.csv")
    df.to_csv(csv_file, index=False)
    print(f"✅ 已导出 {csv_file} (中心化 & 统一大小, 保持比例)")
    return df
# ==== Step 3: generate the h5ad file ====
def create_h5ad(df, sample_name):
    """Write a synthetic expression dataset for ``df`` as ``<sample_name>.h5ad``.

    Each row of ``df`` becomes one cell whose type is its ``label``. For every
    label the known markers from ``true_markers`` that occur in ``gene_names``
    are kept and topped up with randomly chosen genes until there are
    ``n_markers_per_label`` of them (or as many as distinct genes exist).
    Expression is Poisson(1) noise for all genes plus an extra Poisson(5) on
    each label's marker genes in the cells of that label. The ``x, y, z``
    columns are stored in ``obsm["spatial"]``. The chosen markers are also
    written to ``<sample_name>_marker_genes.txt``.

    The global NumPy random state is re-seeded with 42 on every call, which
    also affects later ``np.random`` calls made elsewhere.

    Args:
        df: DataFrame with at least the columns ``x``, ``y``, ``z``, ``label``.
        sample_name: Prefix of the cell identifiers and of both output files.

    Returns:
        None. Results are written to disk next to this script.
    """
    # Seed before any random draw so the filled-in marker genes are
    # reproducible too, not only the expression counts.
    np.random.seed(42)

    points = df[["x", "y", "z"]].to_numpy()
    cell_types = df["label"].to_numpy()
    unique_labels = sorted(set(cell_types))

    # First position of each gene name: constant-time membership tests and
    # lookups that agree with list.index() if a name occurs more than once.
    gene_pos = {}
    for pos, gene in enumerate(gene_names):
        gene_pos.setdefault(gene, pos)
    # Never require more distinct markers than there are distinct genes,
    # otherwise the top-up below could not be satisfied.
    n_required = min(n_markers_per_label, len(gene_pos))

    marker_genes = {}
    all_marker_set = set()
    for ct in unique_labels:
        valid_markers = [g for g in true_markers.get(ct, []) if g in gene_pos]
        missing = n_required - len(valid_markers)
        if missing > 0:
            # Top up with distinct genes that are not already markers.
            taken = set(valid_markers)
            candidates = [g for g in gene_pos if g not in taken]
            extra = np.random.choice(candidates, size=missing, replace=False)
            valid_markers.extend(str(g) for g in extra)
        marker_genes[ct] = valid_markers
        all_marker_set.update(valid_markers)

    # Background expression for every cell and gene.
    expr_matrix = np.random.poisson(lam=1.0, size=(len(df), len(gene_names))).astype(np.float32)
    # Boost each label's marker genes in the cells carrying that label.
    for ct in unique_labels:
        cell_idx = np.where(cell_types == ct)[0]
        marker_idx = [gene_pos[g] for g in marker_genes[ct]]
        expr_matrix[np.ix_(cell_idx, marker_idx)] += np.random.poisson(lam=5.0, size=(len(cell_idx), len(marker_idx)))

    obs = pd.DataFrame(index=[f"{sample_name}_cell{i}" for i in range(len(df))])
    obs["cell_type"] = cell_types
    var = pd.DataFrame(index=gene_names)
    var["is_marker"] = ["yes" if g in all_marker_set else "no" for g in gene_names]

    adata = ad.AnnData(X=expr_matrix, obs=obs, var=var)
    adata.obsm["spatial"] = points
    output_h5ad = os.path.join(script_dir, f"{sample_name}.h5ad")
    adata.write(output_h5ad)

    # Explicit UTF-8 so label names outside the platform's default code page
    # cannot make the write fail or produce a file in a locale-specific encoding.
    marker_file = os.path.join(script_dir, f"{sample_name}_marker_genes.txt")
    with open(marker_file, "w", encoding="utf-8") as f:
        for ct, genes in marker_genes.items():
            f.write(f"{ct} : {', '.join(genes)}\n")
    print(f"✅ 已生成 {output_h5ad} ({len(df)} cells × {len(gene_names)} genes)")
# ==== Step 4: batch run ====
# Convert every expected model file found next to this script and report
# the ones that are missing.
for glb_name in glb_files:
    glb_path = os.path.join(script_dir, glb_name)
    if not os.path.exists(glb_path):
        print(f"⚠️ 未找到 {glb_name},跳过")
        continue
    # The file name without its extension names the sample.
    sample_name = os.path.splitext(glb_name)[0]
    df = process_glb(glb_path, sample_name)
    create_h5ad(df, sample_name)