trainer test
@@ -22,12 +22,12 @@ from taixuan.atom_types import PDB_ATOM_TYPES

 def main():
     project_root = Path(__file__).resolve().parent.parent
-    train_data_dir = project_root / 'data' / 'train'
+    train_data_dir = project_root / 'data' / 'val'
     assert train_data_dir.exists(), f"Training data directory does not exist: {train_data_dir}"

     # Take only the first two files to overfit on
     dataset = PdbDataset(str(train_data_dir), max_seq_len=500)
-    dataset.valid_files = dataset.valid_files[:2]
+    dataset.valid_files = dataset.valid_files[:1]

     # Lightweight model configuration
     config = TaixuanConfig(
@@ -40,6 +40,7 @@ def main():
         dropout=0.1,
     )
     model = TaixuanForSequenceToSequence(config)
+    model.load_state_dict(torch.load(project_root / 'example_output' / 'overfit_model' / 'model.pt'))

     output_dir = project_root / 'example_output'
     output_dir.mkdir(exist_ok=True)
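A minimal sketch of a more portable way to restore the overfit checkpoint, assuming the same model.pt path as the added line above: torch.load accepts a map_location argument, so a checkpoint saved on a GPU machine can still be loaded on a CPU-only host before the weights are copied into the model.

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location remaps the stored tensors onto the locally available device
    state_dict = torch.load(
        project_root / 'example_output' / 'overfit_model' / 'model.pt',
        map_location=device,
    )
    model.load_state_dict(state_dict)
    model.to(device)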
@@ -14,7 +14,7 @@ from taixuan.data import PdbReader
 def main():
     project_root = Path(__file__).resolve().parent.parent
     model_dir = project_root / 'example_output' / 'overfit_model'
-    pdb_file = next((project_root / 'data' / 'train').glob('*.pdb*'))  # take the first training sample
+    pdb_file = next((project_root / 'data' / 'val').glob('*.pdb*'))  # take the first training sample

     print(f"Loading model: {model_dir}\nUsing PDB file: {pdb_file}")
     predictor = MolecularPredictor(pretrained_model_path=str(model_dir))
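One caveat with the changed line, shown here as a sketch rather than a suggested change to the commit: Path.glob yields entries in an order the filesystem chooses, so next(...) may pick a different "first" file on different machines; sorting makes the selected sample deterministic.

    # Deterministic alternative (illustration only): sort before taking the first match
    pdb_file = sorted((project_root / 'data' / 'val').glob('*.pdb*'))[0]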
@@ -37,6 +37,14 @@ class MolecularPredictor:
         self.config = config
         self.atom_types = PDB_ATOM_TYPES

+        # Automatic device detection
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        elif hasattr(torch, "mps") and torch.mps.is_available():
+            self.device = torch.device("mps")
+        else:
+            self.device = torch.device("cpu")
+
         # Create the model and load the pretrained weights
         if pretrained_model_path:
             # Try to load the config from the pretrained model directory
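A hedged side note on the MPS branch above: the availability check described in the PyTorch documentation is torch.backends.mps.is_available(), which also exists on older builds where the torch.mps module may not yet expose is_available(). A sketch of the same detection written against that API:

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():   # documented MPS availability check
        device = torch.device("mps")
    else:
        device = torch.device("cpu")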
@@ -44,6 +52,9 @@ class MolecularPredictor:
             self.model = TaixuanForSequenceToSequence.from_pretrained(pretrained_model_path, config)
         else:
             self.model = TaixuanForSequenceToSequence(config)

+        # Move the model to the target device
+        self.model.to(self.device)
+
     def _load_config_from_pretrained(self, model_path: str, default_config: TaixuanConfig) -> TaixuanConfig:
         """
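After the new self.model.to(self.device) call, a quick sanity check of placement (illustration only, not part of the commit) is to look at where the parameters actually live:

    param_device = next(self.model.parameters()).device
    assert param_device.type == self.device.type, f"model on {param_device}, expected {self.device}"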
@@ -141,18 +152,18 @@ class MolecularPredictor:

         # Input sequence IDs (here we use each atom type's index in the vocab)
         atom_type_to_idx = {atom: idx for idx, atom in enumerate(self.atom_types)}
-        input_ids = torch.tensor([[atom_type_to_idx.get(atom, 0) for atom in validated_atom_types]],
-                                 dtype=torch.long)
-        input_ids = input_ids.repeat(batch_size, 1)  # (batch_size, seq_len)
+        input_ids = torch.tensor([[atom_type_to_idx.get(atom, 0) for atom in validated_atom_types]], dtype=torch.long)
+        input_ids = input_ids.repeat(batch_size, 1)

         # Attention mask
         attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

         # Initial coordinates (kept consistent with training: zeros as the initial value)
         # Note: zero coordinates rather than random ones, because training used the real coordinates
         # At prediction time we expect the model to predict a reasonable molecular structure from zero coordinates
         initial_coords = torch.zeros(batch_size, seq_len, 3, dtype=torch.float32)

         # Move to device
         input_ids = input_ids.to(self.device)
         attention_mask = attention_mask.to(self.device)
         initial_coords = initial_coords.to(self.device)

         # 4. Model inference
         self.model.eval()
         with torch.no_grad():
@@ -164,7 +175,7 @@ class MolecularPredictor:
             )

         # 5. Extract the predicted coordinates
-        predicted_coords = outputs['predicted_coords']  # (batch_size, seq_len, 3)
+        predicted_coords = outputs['predicted_coords'].cpu()  # (batch_size, seq_len, 3)

         # 6. Build the result
         sequence = "".join(validated_atom_types)  # for backward compatibility
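Why the added .cpu() matters downstream (a sketch under the assumption that the coordinates are later converted to NumPy, e.g. for writing a PDB file): a tensor still living on a CUDA or MPS device cannot be handed to .numpy() directly, so bringing it back to the host first avoids a runtime error.

    coords_np = predicted_coords.numpy()   # valid only once the tensor is back on the CPU
    first_atom_xyz = coords_np[0, 0]       # (3,) coordinates of batch 0, atom 0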
@@ -253,18 +264,18 @@ class MolecularPredictor:

         # Input sequence IDs (here we use each atom type's index in the vocab)
         atom_type_to_idx = {atom: idx for idx, atom in enumerate(self.atom_types)}
-        input_ids = torch.tensor([[atom_type_to_idx.get(atom, 0) for atom in atom_types]],
-                                 dtype=torch.long)
-        input_ids = input_ids.repeat(batch_size, 1)  # (batch_size, seq_len)
+        input_ids = torch.tensor([[atom_type_to_idx.get(atom, 0) for atom in atom_types]], dtype=torch.long)
+        input_ids = input_ids.repeat(batch_size, 1)

         # Attention mask
         attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

         # Initial coordinates (kept consistent with training: zeros as the initial value)
         # Note: zero coordinates rather than random ones, because training used the real coordinates
         # At prediction time we expect the model to predict a reasonable molecular structure from zero coordinates
         initial_coords = torch.zeros(batch_size, seq_len, 3, dtype=torch.float32)

         # Move to device
         input_ids = input_ids.to(self.device)
         attention_mask = attention_mask.to(self.device)
         initial_coords = initial_coords.to(self.device)

         # 4. Model inference
         self.model.eval()
         with torch.no_grad():
@@ -276,7 +287,7 @@ class MolecularPredictor:
             )

         # 5. Extract the predicted coordinates
-        predicted_coords = outputs['predicted_coords']  # (batch_size, seq_len, 3)
+        predicted_coords = outputs['predicted_coords'].cpu()  # (batch_size, seq_len, 3)

         # 6. Build the results
         results = {
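This second block repeats the input construction from the earlier hunk almost verbatim (only atom_types versus validated_atom_types differs). A hypothetical refactor sketch, with the helper name _build_inputs invented here and not present in the diff, shows how the duplication could be factored out:

    def _build_inputs(self, atom_types, batch_size):
        # Shared tensor construction for both prediction paths (hypothetical helper)
        atom_type_to_idx = {atom: idx for idx, atom in enumerate(self.atom_types)}
        input_ids = torch.tensor([[atom_type_to_idx.get(a, 0) for a in atom_types]], dtype=torch.long)
        input_ids = input_ids.repeat(batch_size, 1).to(self.device)
        attention_mask = torch.ones(batch_size, len(atom_types), dtype=torch.long, device=self.device)
        initial_coords = torch.zeros(batch_size, len(atom_types), 3, dtype=torch.float32, device=self.device)
        return input_ids, attention_mask, initial_coords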