trainer 测试

This commit is contained in:
frozenleaves
2025-08-04 12:11:10 +08:00
parent 8871b6bdea
commit 770ac79815
4 changed files with 32 additions and 19 deletions

1
README.md Normal file
View File

@ -0,0 +1 @@
### 太玄:预测分子中原子的三维空间分布

View File

@ -22,12 +22,12 @@ from taixuan.atom_types import PDB_ATOM_TYPES
def main():
project_root = Path(__file__).resolve().parent.parent
train_data_dir = project_root / 'data' / 'train'
train_data_dir = project_root / 'data' / 'val'
assert train_data_dir.exists(), f"训练数据目录不存在: {train_data_dir}"
# 仅取第一个文件实现过拟合(此前为前两个文件)
dataset = PdbDataset(str(train_data_dir), max_seq_len=500)
dataset.valid_files = dataset.valid_files[:2]
dataset.valid_files = dataset.valid_files[:1]
# 轻量化模型配置
config = TaixuanConfig(
@ -40,6 +40,7 @@ def main():
dropout=0.1,
)
model = TaixuanForSequenceToSequence(config)
model.load_state_dict(torch.load(project_root / 'example_output' / 'overfit_model' / 'model.pt'))
output_dir = project_root / 'example_output'
output_dir.mkdir(exist_ok=True)

View File

@ -14,7 +14,7 @@ from taixuan.data import PdbReader
def main():
project_root = Path(__file__).resolve().parent.parent
model_dir = project_root / 'example_output' / 'overfit_model'
pdb_file = next((project_root / 'data' / 'train').glob('*.pdb*')) # 取第一个训练样本
pdb_file = next((project_root / 'data' / 'val').glob('*.pdb*')) # 取第一个训练样本
print(f"加载模型: {model_dir}\n使用PDB文件: {pdb_file}")
predictor = MolecularPredictor(pretrained_model_path=str(model_dir))

View File

@ -37,6 +37,14 @@ class MolecularPredictor:
self.config = config
self.atom_types = PDB_ATOM_TYPES
# 设备自动检测
if torch.cuda.is_available():
self.device = torch.device("cuda")
elif hasattr(torch, "mps") and torch.mps.is_available():
self.device = torch.device("mps")
else:
self.device = torch.device("cpu")
# 创建模型并加载预训练权重
if pretrained_model_path:
# 尝试从预训练模型目录加载配置
@ -44,6 +52,9 @@ class MolecularPredictor:
self.model = TaixuanForSequenceToSequence.from_pretrained(pretrained_model_path, config)
else:
self.model = TaixuanForSequenceToSequence(config)
# 将模型移动到目标设备
self.model.to(self.device)
def _load_config_from_pretrained(self, model_path: str, default_config: TaixuanConfig) -> TaixuanConfig:
"""
@ -141,18 +152,18 @@ class MolecularPredictor:
# 输入序列ID这里我们使用原子类型在vocab中的索引
atom_type_to_idx = {atom: idx for idx, atom in enumerate(self.atom_types)}
input_ids = torch.tensor([[atom_type_to_idx.get(atom, 0) for atom in validated_atom_types]],
dtype=torch.long)
input_ids = input_ids.repeat(batch_size, 1) # (batch_size, seq_len)
input_ids = torch.tensor([[atom_type_to_idx.get(atom, 0) for atom in validated_atom_types]], dtype=torch.long)
input_ids = input_ids.repeat(batch_size, 1)
# 注意力掩码
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
# 初始坐标(与训练时保持一致,使用零坐标作为初始值)
# 注意:这里使用零坐标而不是随机坐标,因为训练时使用的是真实坐标
# 在预测时,我们期望模型能够从零坐标预测到合理的分子结构
initial_coords = torch.zeros(batch_size, seq_len, 3, dtype=torch.float32)
# 移动到设备
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
initial_coords = initial_coords.to(self.device)
# 4. 模型推理
self.model.eval()
with torch.no_grad():
@ -164,7 +175,7 @@ class MolecularPredictor:
)
# 5. 提取预测坐标
predicted_coords = outputs['predicted_coords'] # (batch_size, seq_len, 3)
predicted_coords = outputs['predicted_coords'].cpu() # (batch_size, seq_len, 3)
# 6. 构建结果
sequence = "".join(validated_atom_types) # 为了向后兼容
@ -253,18 +264,18 @@ class MolecularPredictor:
# 输入序列ID这里我们使用原子类型在vocab中的索引
atom_type_to_idx = {atom: idx for idx, atom in enumerate(self.atom_types)}
input_ids = torch.tensor([[atom_type_to_idx.get(atom, 0) for atom in atom_types]],
dtype=torch.long)
input_ids = input_ids.repeat(batch_size, 1) # (batch_size, seq_len)
input_ids = torch.tensor([[atom_type_to_idx.get(atom, 0) for atom in atom_types]], dtype=torch.long)
input_ids = input_ids.repeat(batch_size, 1)
# 注意力掩码
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
# 初始坐标(与训练时保持一致,使用零坐标作为初始值)
# 注意:这里使用零坐标而不是随机坐标,因为训练时使用的是真实坐标
# 在预测时,我们期望模型能够从零坐标预测到合理的分子结构
initial_coords = torch.zeros(batch_size, seq_len, 3, dtype=torch.float32)
# 移动到设备
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
initial_coords = initial_coords.to(self.device)
# 4. 模型推理
self.model.eval()
with torch.no_grad():
@ -276,7 +287,7 @@ class MolecularPredictor:
)
# 5. 提取预测坐标
predicted_coords = outputs['predicted_coords'] # (batch_size, seq_len, 3)
predicted_coords = outputs['predicted_coords'].cpu() # (batch_size, seq_len, 3)
# 6. 构建结果
results = {