-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathrun_diffusion.py
More file actions
71 lines (56 loc) · 1.79 KB
/
run_diffusion.py
File metadata and controls
71 lines (56 loc) · 1.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import torch
import hydra
from seq_models.sample import sample_outer_loop
def _load_checkpoint(model, ckpt_path):
    """Load the ``state_dict`` entry of the checkpoint at *ckpt_path* into *model*.

    The file is loaded onto the CPU (``map_location="cpu"``) so that a
    checkpoint written from a GPU run can also be restored on a CPU-only
    machine; the caller is responsible for moving the model to a GPU.

    Loading uses ``strict=False``: keys in the checkpoint that the model does
    not have are only reported, but keys the model expects and the checkpoint
    lacks are treated as an error.

    Raises:
        ValueError: if the model has parameters/buffers missing from the
            checkpoint.
    """
    # NOTE(review): torch.load unpickles the file -- only point this at
    # checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # Assumes a dict-style checkpoint with a "state_dict" entry (e.g. as
    # written by PyTorch Lightning) -- TODO confirm against the trainer.
    result = model.load_state_dict(checkpoint["state_dict"], strict=False)
    if result.missing_keys:
        raise ValueError(f"Missing keys: {result.missing_keys}")
    if result.unexpected_keys:
        print(f"Unexpected keys: {result.unexpected_keys}")


def _model_tag(target):
    """Return a short tag for the model class path *target*.

    Returns ``'mlm'`` if the path contains "mlm" (checked first),
    otherwise ``'gaussian'`` if it contains "gaussian", otherwise ``''``.
    """
    if "mlm" in target:
        return "mlm"
    if "gaussian" in target:
        return "gaussian"
    return ""


@hydra.main(config_path="../../configs", config_name="sample")
def main(config):
    """Load a trained sequence model from a checkpoint and run the sampling sweep.

    Config fields read: ``ckpt_path``, ``seeds_fn``, ``results_dir``,
    ``vocab_file`` and ``model`` (the latter is instantiated via Hydra with
    ``_recursive_=False``). The config is composed from ``configs/sample``,
    located via a path relative to this file.

    Raises:
        ValueError: if ``ckpt_path``, ``seeds_fn`` or ``results_dir`` is
            unset, or if the checkpoint lacks keys the model expects.
    """
    if config.ckpt_path is None:
        raise ValueError("Must specify a checkpoint path")
    if config.seeds_fn is None:
        raise ValueError("Must specify a seeds file")
    if config.results_dir is None:
        raise ValueError("Must specify a results directory")

    # NOTE(review): depending on the Hydra version / `hydra.job.chdir`, the
    # process may be running inside Hydra's output directory, so relative
    # paths in the config resolve against it rather than the launch
    # directory -- confirm this is intended.
    # exist_ok avoids the exists()/makedirs() race (e.g. Hydra multirun).
    os.makedirs(config.results_dir, exist_ok=True)

    model = hydra.utils.instantiate(config.model, _recursive_=False)
    _load_checkpoint(model, config.ckpt_path)
    if torch.cuda.is_available():
        model.cuda()
    model.eval()

    model_tag = _model_tag(config.model["_target_"])

    numbering_schemes = ["chothia", "aho"]
    cdr_combos = [
        ["hcdr1"],
        ["hcdr2"],
        ["hcdr3"],
        ["hcdr1", "hcdr2", "hcdr3"],
        ["lcdr1"],
        ["lcdr2"],
        ["lcdr3"],
    ]
    sampling_kwargs_list = [
        {"fixed_length": True},
        {"fixed_length": False},
    ]

    # Presumably sweeps every (numbering scheme, CDR combo, sampling kwargs)
    # combination and writes results under results_dir -- see
    # seq_models.sample.sample_outer_loop to confirm.
    sample_outer_loop(
        model,
        model_tag,
        config.results_dir,
        config.seeds_fn,
        config.vocab_file,
        numbering_schemes,
        cdr_combos,
        sampling_kwargs_list,
    )
if __name__ == "__main__":
    # Script entry point. The @hydra.main decorator parses command-line
    # overrides and supplies `config`, so main() takes no arguments here.
    main()  # pylint: disable=no-value-for-parameter