-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathrun_diffab.py
More file actions
163 lines (136 loc) · 4.35 KB
/
run_diffab.py
File metadata and controls
163 lines (136 loc) · 4.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
import glob
import yaml
import pprint
import subprocess
import pandas as pd
from Bio.PDB import PDBParser
from Bio.SeqUtils import seq1
import sys
sys.path.append("diffab")
import diffab.utils.protein.constants as constants
# Translate the short CDR ids used throughout this script ("H1" ... "L3")
# into the CDR names written into DiffAb's sampling config ("H_CDR1" ...).
# Insertion order is H1, H2, H3, L1, L2, L3.
DIFFAB_NAME_MAP = {
    f"{chain}{index}": f"{chain}_CDR{index}"
    for chain in ("H", "L")
    for index in (1, 2, 3)
}
# Short CDR id -> the matching attribute of DiffAb's ChothiaCDRRange.
# NOTE(review): these are presumably (start, end) Chothia residue numbers;
# confirm against diffab/utils/protein/constants.py.
DIFFAB_CDRS = {
    cdr_id: getattr(constants.ChothiaCDRRange, cdr_id)
    for cdr_id in ("H1", "H2", "H3", "L1", "L2", "L3")
}
def create_config(
    cdr_ids,
    results_dir,
    tag="",
    num_samples=10,
    config_dir="./diffab/configs/test",
):
    """Write a DiffAb sampling config that redesigns *cdr_ids*.

    Starts from ``codesign_single.yml`` when exactly one CDR is requested and
    from ``codesign_multicdrs.yml`` otherwise. It overrides
    ``sampling.cdrs`` and ``sampling.num_samples`` and writes the result to
    ``<results_dir>/<tag>.yml``.

    Args:
        cdr_ids: Short CDR ids (keys of ``DIFFAB_NAME_MAP``), e.g. ``["H3"]``.
        results_dir: Directory the new config is written to. It is created if
            missing.
        tag: Basename (without extension) of the new config. When empty, it
            is derived from *cdr_ids* (``["H1", "H2"]`` -> ``"h1_h2"``), so
            the file is never the bare hidden name ``.yml``.
        num_samples: Value written to ``sampling.num_samples``.
        config_dir: Directory holding the base DiffAb configs. A relative
            path resolves against the current working directory.

    Returns:
        Path of the written config file.

    Raises:
        ValueError: If *cdr_ids* is empty.
        KeyError: If a CDR id is not a key of ``DIFFAB_NAME_MAP``.
    """
    if not cdr_ids:
        raise ValueError("cdr_ids must name at least one CDR")
    if not tag:
        # Same naming scheme main() uses, so the output dir is predictable.
        tag = "_".join(cdr_id.lower() for cdr_id in cdr_ids)

    base_config = (
        "codesign_single.yml" if len(cdr_ids) == 1 else "codesign_multicdrs.yml"
    )
    with open(os.path.join(config_dir, base_config), "r", encoding="utf-8") as fd:
        config = yaml.safe_load(fd)

    config["sampling"]["cdrs"] = [DIFFAB_NAME_MAP[cdr_id] for cdr_id in cdr_ids]
    config["sampling"]["num_samples"] = num_samples

    os.makedirs(results_dir, exist_ok=True)
    new_fn = os.path.join(results_dir, f"{tag}.yml")
    with open(new_fn, "w", encoding="utf-8") as fd:
        yaml.dump(config, fd, default_flow_style=False)
    return new_fn
def run_diffab(
    pdb_dir,
    config_fn,
    results_dir,
    diffab_dir="./diffab",
):
    """Run DiffAb's ``design_pdb.py`` once per seed PDB in *pdb_dir*.

    Each run is ``python design_pdb.py <pdb> --config <config_fn>
    --out_root <results_dir>``, executed with ``cwd=diffab_dir``. Runs happen
    sequentially.

    Files whose *name* contains ``"chothia"`` are skipped. These are
    presumably the renumbered copies DiffAb writes next to its inputs.

    Args:
        pdb_dir: Directory containing the seed ``.pdb`` files.
        config_fn: DiffAb config file (e.g. from ``create_config``).
        results_dir: Passed to DiffAb as ``--out_root``.
        diffab_dir: Working directory for the DiffAb process.

    Returns:
        List of input PDB paths whose DiffAb run exited with a non-zero
        status (empty when every run succeeded).
    """
    pdb_fns = sorted(
        fn
        for fn in glob.glob(os.path.join(pdb_dir, "*.pdb"))
        # Check only the file name: a "chothia" elsewhere in the path must
        # not filter out every input.
        if "chothia" not in os.path.basename(fn)
    )
    # The subprocess runs from diffab_dir, so relative paths would otherwise
    # be resolved against the wrong directory.
    config_fn = os.path.abspath(config_fn)
    results_dir = os.path.abspath(results_dir)

    failed = []
    for pdb_fn in pdb_fns:
        # Build the argv list directly so paths containing spaces survive.
        command = [
            "python", "design_pdb.py", os.path.abspath(pdb_fn),
            "--config", config_fn,
            "--out_root", results_dir,
        ]
        print(" ".join(command))
        result = subprocess.run(command, cwd=diffab_dir)
        if result.returncode != 0:
            # Best-effort: keep going with the other seeds, but report it.
            print(f"WARNING: DiffAb exited with code {result.returncode} for {pdb_fn}")
            failed.append(pdb_fn)
    return failed
def parse_chains(pdb_file):
    """Return a ``{chain_id: one_letter_sequence}`` dict for *pdb_file*.

    Every residue of each chain contributes to its sequence. The residue
    names are concatenated and converted with a single ``seq1`` call.
    Chains are visited via ``structure.get_chains()``. If the same chain id
    is seen more than once (e.g. in several models), the last one wins.
    """
    parser = PDBParser()
    structure = parser.get_structure("", pdb_file)
    sequences = {}
    for chain in structure.get_chains():
        residue_names = [residue.resname for residue in chain]
        sequences[chain.id] = seq1("".join(residue_names))
    return sequences
def _chain_resseqs(pdb_file):
    """Return ``{chain_id: [residue sequence number, ...]}`` for *pdb_file*.

    Walks residues the same way ``parse_chains`` does (every residue of each
    chain from ``structure.get_chains()``). Each list therefore lines up
    position-by-position with that chain's sequence string.
    """
    structure = PDBParser().get_structure("", pdb_file)
    return {
        chain.id: [residue.id[1] for residue in chain]
        for chain in structure.get_chains()
    }


def _cdr_mask(resseqs, cdr_ranges):
    """Return a '0'/'1' string marking residues whose number is in any range.

    NOTE(review): ranges are treated as inclusive ``(start, end)`` residue
    numbers, as in DiffAb's ``ChothiaCDRRange.to_cdr``. Confirm against the
    installed diffab. Insertion codes are ignored, so 100A counts as 100.
    """
    return "".join(
        "1" if any(start <= resseq <= end for start, end in cdr_ranges) else "0"
        for resseq in resseqs
    )


def parse_diffab_results(
    results_dir,
    cdr_ids,
    tag,
):
    """Collect DiffAb samples under ``<results_dir>/<tag>`` into a DataFrame.

    Each directory matching ``<results_dir>/<tag>/*pdb*/*`` is one design
    run. The ``.pdb`` whose file name contains ``REF`` is the seed, and every
    other ``.pdb`` in that directory is a sample. Directories without a REF
    file are skipped with a warning.

    Args:
        results_dir: The ``--out_root`` passed to DiffAb.
        cdr_ids: Short CDR ids that were redesigned (keys of ``DIFFAB_CDRS``).
        tag: Sub-directory of *results_dir* holding this run's output.

    Returns:
        DataFrame with one row per sample and columns ``vh_seed``,
        ``vl_seed``, ``sample_num``, ``vh_sample``, ``vl_sample``,
        ``vh_mask``, ``vl_mask``. The columns are present even when no
        samples were found. Masks are '0'/'1' strings, one character per
        residue of the seed chain, with '1' on redesigned CDR residues.

    Raises:
        KeyError: If a seed structure lacks chain ``H`` or ``L``, or a CDR id
            is unknown.
    """
    columns = [
        "vh_seed", "vl_seed", "sample_num",
        "vh_sample", "vl_sample", "vh_mask", "vl_mask",
    ]
    base_dir = os.path.join(results_dir, tag)
    # Sorted so output order (and sample_num) is reproducible across runs.
    sample_dirs = sorted(
        d for d in glob.glob(os.path.join(base_dir, "*pdb*", "*"))
        if os.path.isdir(d)
    )
    # The CDR ranges depend only on cdr_ids, not on the sample directory.
    vh_ranges = [DIFFAB_CDRS[cdr_id] for cdr_id in cdr_ids if "H" in cdr_id]
    vl_ranges = [DIFFAB_CDRS[cdr_id] for cdr_id in cdr_ids if "L" in cdr_id]

    out = []
    for d in sample_dirs:
        pdb_files = sorted(glob.glob(os.path.join(d, "*.pdb")))
        # Match on the file name only: "REF" elsewhere in the path would
        # otherwise mark every file as the reference.
        ref_files = [f for f in pdb_files if "REF" in os.path.basename(f)]
        if not ref_files:
            print(f"WARNING: no REF structure in {d}; skipping")
            continue
        ref_file = ref_files[0]
        seed_chains = parse_chains(ref_file)
        seed_resseqs = _chain_resseqs(ref_file)

        # CDR ranges are residue *numbers*, not sequence positions: numbering
        # need not start at 1 and Chothia insertions (e.g. 82A-C) shift every
        # later position. The mask is therefore built from each residue's number.
        vh_mask = _cdr_mask(seed_resseqs["H"], vh_ranges)
        vl_mask = _cdr_mask(seed_resseqs["L"], vl_ranges)

        sample_files = [f for f in pdb_files if "REF" not in os.path.basename(f)]
        for i, pdb_file in enumerate(sample_files):
            sample_chains = parse_chains(pdb_file)
            out.append({
                "vh_seed": seed_chains["H"],
                "vl_seed": seed_chains["L"],
                "sample_num": i,
                "vh_sample": sample_chains["H"],
                "vl_sample": sample_chains["L"],
                "vh_mask": vh_mask,
                "vl_mask": vl_mask,
            })
    return pd.DataFrame(out, columns=columns)
def main(
    pdb_dir="/home/nvg7279/src/seq-struct/poas_seed_pdbs",
    results_dir="/home/nvg7279/src/seq-struct/diffab_infill",
):
    """Redesign several CDR combinations with DiffAb and save one CSV each.

    For every combination in ``cdr_combos`` this writes a config, runs
    DiffAb over every seed PDB in *pdb_dir*, and writes the parsed samples
    to ``<results_dir>/<tag>.csv``. Here *tag* is the lower-cased CDR ids
    joined by ``_`` (e.g. ``h1_h2_h3``).

    Args:
        pdb_dir: Directory of seed ``.pdb`` files. Defaults to the path this
            script was originally written for.
        results_dir: Output directory for configs, DiffAb runs and CSVs. It
            is created if missing.
    """
    os.makedirs(results_dir, exist_ok=True)
    cdr_combos = [
        ["H1"],
        ["H2"],
        ["H3"],
        ["H1", "H2", "H3"],
        ["L1"],
        ["L2"],
        ["L3"],
    ]
    for cdr_ids in cdr_combos:
        tag = "_".join(cdr.lower() for cdr in cdr_ids)
        config_fn = create_config(cdr_ids, results_dir, tag)
        run_diffab(pdb_dir, config_fn, results_dir)
        df = parse_diffab_results(results_dir, cdr_ids, tag)
        df.to_csv(os.path.join(results_dir, f"{tag}.csv"), index=False)


if __name__ == "__main__":
    main()