import os

import anndata as ad
import numpy as np
import pandas as pd
import trimesh

# ==== Parameters ====
script_dir = os.path.dirname(os.path.realpath(__file__))
glb_files = [f"CS{i}.glb" for i in range(19, 20)]  # currently only CS19.glb
N_total = 30000               # target number of sampled points per sample
use_surface_sampling = False  # False: subsample vertices; True: uniform surface sampling
n_genes = 1000
n_markers_per_label = 10

# ==== Load real gene names ====
gene_list_file = os.path.join(script_dir, "human_protein_coding_genes.txt")
assert os.path.exists(gene_list_file), f"❌ Missing gene name file {gene_list_file}"
all_gene_names = pd.read_csv(gene_list_file, header=None)[0].tolist()
gene_names = all_gene_names[:n_genes]

# ==== Marker genes ====
true_markers = {
    "Ectoderm": ["SOX2", "PAX6", "NES", "TUBB3", "OTX2"],
    "Mesoderm": ["TBXT", "MESP1", "HAND1", "GATA4", "PDGFRA"],
    "Endoderm": ["SOX17", "FOXA2", "GATA6", "CXCR4", "HNF1B"],
    # "Brachyury" is the protein name; its HGNC gene symbol is TBXT,
    # so use TBXT here or the marker silently fails the gene_names filter.
    "Notochord": ["SHH", "NOG", "CHRD", "FOXA2", "TBXT"],
    "NeuralTube": ["OLIG2", "NKX6-1", "HOXB4", "PAX3", "SNAI2"],
}

# ==== Sampling function ====
def sample_mesh(mesh, label, n_samples):
    """Sample up to n_samples points from a mesh; return (points, colors, labels)."""
    if len(mesh.vertices) == 0:
        print(f"⚠️ {label} has no vertices, skipping")
        return np.zeros((0, 3)), np.zeros((0, 3)), np.array([])
    if mesh.faces is None or len(mesh.faces) == 0 or not use_surface_sampling:
        idx = np.random.choice(len(mesh.vertices), min(n_samples, len(mesh.vertices)), replace=False)
        points = mesh.vertices[idx]
        vert_idx = idx  # each sampled point keeps its own vertex color
    else:
        points, face_idx = trimesh.sample.sample_surface(mesh, n_samples)
        vert_idx = mesh.faces[face_idx][:, 0]  # approximate with the first vertex of each sampled face
    if hasattr(mesh.visual, "vertex_colors") and len(mesh.visual.vertex_colors) == len(mesh.vertices):
        # Index by the sampled vertices so colors match points one-to-one
        # (slicing the first len(points) colors would mismatch the sample).
        colors = np.asarray(mesh.visual.vertex_colors)[vert_idx, :3]
    else:
        mat_color = np.array([200, 200, 200])  # fallback gray
        colors = np.tile(mat_color, (len(points), 1))
    labels = np.array([label] * len(points))
    return points, colors, labels

# ==== Step 1: compute global bounds ====
all_bounds = []
for glb_name in glb_files:
    glb_path = os.path.join(script_dir, glb_name)
    if not os.path.exists(glb_path):
        continue
    scene = trimesh.load(glb_path)
    if isinstance(scene, trimesh.Scene):
        for geom in scene.geometry.values():
            all_bounds.append(geom.bounds)
    else:
        all_bounds.append(scene.bounds)
assert len(all_bounds) > 0, "❌ None of the GLB files were found"
all_bounds = np.array(all_bounds)
global_min = np.min(all_bounds[:, 0, :], axis=0)
global_max = np.max(all_bounds[:, 1, :], axis=0)
global_size = np.max(global_max - global_min)
print(f"🌍 Common scaling reference: global_size={global_size}")

# ==== Step 2: process a single GLB ====
def process_glb(glb_path, sample_name):
    scene = trimesh.load(glb_path)
    all_points, all_colors, all_labels = [], [], []
    if isinstance(scene, trimesh.Scene):
        total_points = sum(len(g.vertices) for g in scene.geometry.values())
        for name, geom in scene.geometry.items():
            if len(geom.vertices) == 0:
                print(f"⚠️ {sample_name}: sub-mesh {name} has no vertices, skipping")
                continue
            # Sample each sub-mesh in proportion to its vertex count
            ratio = len(geom.vertices) / total_points if total_points > 0 else 0
            n_samples = max(100, int(N_total * ratio))  # ✅ at least 100 points per sub-mesh
            # ✅ auto-detect transform & compatibility fallback
            try:
                transform = scene.graph.get(name)[0]
                verts_world = trimesh.transform_points(geom.vertices, transform)
            except Exception:
                # No transform path: check whether the vertices already look global
                local_bounds = geom.bounds
                local_size = np.max(local_bounds[1] - local_bounds[0])
                if local_size > 0.5 * global_size:  # already in world coordinates
                    print(f"ℹ️ {sample_name}: sub-mesh {name} appears to be in world coordinates → using as-is")
                else:
                    print(f"⚠️ {sample_name}: sub-mesh {name} has no transform path and looks local → using raw vertices (may be misplaced)")
                verts_world = geom.vertices.copy()
            # Carry the visual over so vertex colors survive the rebuild
            mesh_world = trimesh.Trimesh(vertices=verts_world, faces=geom.faces,
                                         visual=geom.visual, process=False)
            p, c, l = sample_mesh(mesh_world, name, n_samples)
            if len(p) == 0:
                continue
            all_points.append(p)
            all_colors.append(c)
            all_labels.append(l)
    else:
        p, c, l = sample_mesh(scene, "mesh", N_total)
        all_points.append(p)
        all_colors.append(c)
        all_labels.append(l)
    if len(all_points) == 0:
        raise RuntimeError(f"❌ {sample_name}: no valid points were sampled!")
    points = np.vstack(all_points)
    colors = np.vstack(all_colors)
    labels = np.concatenate(all_labels)
    # ✅ center + uniform scaling (same global_size for every sample)
    center = points.mean(axis=0)
    points_scaled = (points - center) / global_size
    df = pd.DataFrame(points_scaled, columns=["x", "y", "z"])
    df["r"], df["g"], df["b"] = colors[:, 0], colors[:, 1], colors[:, 2]
    df["label"] = labels
    csv_file = os.path.join(script_dir, f"{sample_name}_point_cloud_{N_total}_centered_scaled.csv")
    df.to_csv(csv_file, index=False)
    print(f"✅ Exported {csv_file} (auto-detect + compatibility fallback)")
    return df

# ==== Step 3: build h5ad ====
def create_h5ad(df, sample_name):
    points = df[["x", "y", "z"]].to_numpy()
    cell_types = df["label"].to_numpy()
    unique_labels = sorted(set(cell_types))
    # Seed before any random draws so the marker fill-in is reproducible too
    np.random.seed(42)
    marker_genes = {}
    all_marker_set = set()
    for ct in unique_labels:
        # Keep curated markers present in the gene list, then pad with random genes
        valid_markers = [g for g in true_markers.get(ct, []) if g in gene_names]
        while len(valid_markers) < n_markers_per_label:
            g = np.random.choice(gene_names)
            if g not in valid_markers:
                valid_markers.append(g)
        marker_genes[ct] = valid_markers
        all_marker_set.update(valid_markers)
    # Background expression is Poisson noise; markers get an extra Poisson boost
    expr_matrix = np.random.poisson(lam=1.0, size=(len(df), len(gene_names))).astype(np.float32)
    for ct in unique_labels:
        cell_idx = np.where(cell_types == ct)[0]
        marker_idx = [gene_names.index(g) for g in marker_genes[ct]]
        expr_matrix[np.ix_(cell_idx, marker_idx)] += np.random.poisson(
            lam=5.0, size=(len(cell_idx), len(marker_idx)))
    obs = pd.DataFrame(index=[f"{sample_name}_cell{i}" for i in range(len(df))])
    obs["cell_type"] = cell_types
    var = pd.DataFrame(index=gene_names)
    var["is_marker"] = ["yes" if g in all_marker_set else "no" for g in gene_names]
    adata = ad.AnnData(X=expr_matrix, obs=obs, var=var)
    adata.obsm["spatial"] = points
    output_h5ad = os.path.join(script_dir, f"{sample_name}.h5ad")
    adata.write(output_h5ad)
    with open(os.path.join(script_dir, f"{sample_name}_marker_genes.txt"), "w") as f:
        for ct, genes in marker_genes.items():
            f.write(f"{ct} : {', '.join(genes)}\n")
    print(f"✅ Wrote {output_h5ad} ({len(df)} cells × {len(gene_names)} genes)")

# ==== Step 4: batch run ====
for glb_name in glb_files:
    glb_path = os.path.join(script_dir, glb_name)
    if os.path.exists(glb_path):
        sample_name = os.path.splitext(glb_name)[0]
        df = process_glb(glb_path, sample_name)
        create_h5ad(df, sample_name)
    else:
        print(f"⚠️ {glb_name} not found, skipping")
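
# ---- Optional sanity check (a minimal sketch added for illustration; not part of
# the original pipeline, safe to delete). Reloads the last .h5ad written above and
# confirms the spatial coordinates and expression matrix agree in shape; uses only
# anndata, which is already imported.
last_sample = os.path.splitext(glb_files[-1])[0]
check_path = os.path.join(script_dir, f"{last_sample}.h5ad")
if os.path.exists(check_path):
    check = ad.read_h5ad(check_path)
    assert check.obsm["spatial"].shape == (check.n_obs, 3)
    assert check.X.shape == (check.n_obs, check.n_vars)
    print(f"🔍 Sanity check passed: {check.n_obs} cells × {check.n_vars} genes")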