-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathextract_m3.py
More file actions
144 lines (122 loc) · 5.53 KB
/
extract_m3.py
File metadata and controls
144 lines (122 loc) · 5.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import torch
import numpy as np
from tqdm import tqdm
from config import *
from utils import *
from samplings import *
from accelerate import Accelerator
from transformers import BertConfig, GPT2Config
import argparse
import requests
# ---------------------------------------------------------------------------
# Command-line interface: the script takes the directory of input scores and
# the directory where extracted feature arrays will be written.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Process files to extract features.")
parser.add_argument("input_dir", type=str, help="Directory with input files")
parser.add_argument("output_dir", type=str, help="Directory to save extracted features")
args = parser.parse_args()

input_dir = args.input_dir
output_dir = args.output_dir

# Recursively gather every *.abc / *.mtf file under the input directory.
files = [
    os.path.join(walk_root, name)
    for walk_root, _, names in os.walk(input_dir)
    for name in names
    if name.endswith((".abc", ".mtf"))
]
print(f"Found {len(files)} files in total")

# Accelerate picks the device and coordinates multi-process execution.
accelerator = Accelerator()
device = accelerator.device
print("Using device:", device)
# ---------------------------------------------------------------------------
# Build the M3 model (BERT-style patch encoder + GPT-2-style token decoder)
# and load pretrained weights, downloading them from Hugging Face if absent.
# ---------------------------------------------------------------------------
patchilizer = M3Patchilizer()

encoder_config = BertConfig(
    vocab_size=1,
    hidden_size=M3_HIDDEN_SIZE,
    num_hidden_layers=PATCH_NUM_LAYERS,
    num_attention_heads=M3_HIDDEN_SIZE // 64,
    intermediate_size=M3_HIDDEN_SIZE * 4,
    max_position_embeddings=PATCH_LENGTH,
)
decoder_config = GPT2Config(
    vocab_size=128,
    n_positions=PATCH_SIZE,
    n_embd=M3_HIDDEN_SIZE,
    n_layer=TOKEN_NUM_LAYERS,
    n_head=M3_HIDDEN_SIZE // 64,
    n_inner=M3_HIDDEN_SIZE * 4,
)
model = M3Model(encoder_config, decoder_config).to(device)

# Quick sanity check: report the total parameter count.
print("Total Parameter Number: " + str(sum(p.numel() for p in model.parameters())))

model.eval()  # inference only: disable dropout etc.

checkpoint_path = M3_WEIGHTS_PATH
if not os.path.exists(checkpoint_path):
    # Weights are missing locally; stream them from the Hugging Face hub
    # with a progress bar sized from the Content-Length header (0 if absent).
    print("No M3 weights found. Downloading from Hugging Face...")
    checkpoint_url = "https://huggingface.co/sander-wood/clamp2/resolve/main/weights_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.pth"
    checkpoint_path = "weights_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.pth"
    response = requests.get(checkpoint_url, stream=True)
    response.raise_for_status()
    total_size = int(response.headers.get('content-length', 0))
    with open(checkpoint_path, "wb") as weights_file:
        with tqdm(
            desc="Downloading",
            total=total_size,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress:
            for chunk in response.iter_content(chunk_size=8192):
                if not chunk:
                    continue
                weights_file.write(chunk)
                progress.update(len(chunk))
    print("Weights file downloaded successfully.")

# weights_only=True restricts torch.load to tensor data (no arbitrary pickle).
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
print(f"Successfully Loaded Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
model.load_state_dict(checkpoint['model'])
def extract_feature(item):
    """Encode one score/performance string into patch-level hidden states.

    The input is patchilized, split into PATCH_LENGTH-sized windows (the
    final window is right-aligned so the encoder always sees a full-length
    segment), each window is run through the M3 encoder, and the overlapping
    part of the final window is dropped before concatenation.

    Args:
        item: Raw text content of an .abc or .mtf file.

    Returns:
        Tensor of shape (num_patches, hidden_size) on `device` — one hidden
        state per encoded patch.
    """
    target_patches = patchilizer.encode(item, add_special_patches=True)
    # Split into PATCH_LENGTH-sized windows; then right-align the last window
    # so the encoder never receives a short segment.
    target_patches_list = [
        target_patches[i:i + PATCH_LENGTH]
        for i in range(0, len(target_patches), PATCH_LENGTH)
    ]
    target_patches_list[-1] = target_patches[-PATCH_LENGTH:]

    last_hidden_states_list = []
    for input_patches in target_patches_list:
        input_masks = torch.tensor([1] * len(input_patches))
        input_patches = torch.tensor(input_patches)
        last_hidden_states = model.encoder(
            input_patches.unsqueeze(0).to(device), input_masks.unsqueeze(0).to(device)
        )["last_hidden_state"][0]
        last_hidden_states_list.append(last_hidden_states)

    # The right-aligned final window overlaps the previous one; keep only the
    # states of its trailing `remainder` patches. When the patch count is an
    # exact multiple of PATCH_LENGTH there is no overlap, so keep the whole
    # window. (The original expression `[-(len % PATCH_LENGTH):]` got the
    # zero-remainder case right only via the `-0 == 0` slicing coincidence;
    # this spells the intent out explicitly.)
    remainder = len(target_patches) % PATCH_LENGTH
    if remainder:
        last_hidden_states_list[-1] = last_hidden_states_list[-1][-remainder:]
    return torch.concat(last_hidden_states_list, 0)
def process_directory(input_dir, output_dir, files):
    """Extract features for this process's shard of `files` into `output_dir`.

    The file list is split evenly across Accelerate processes, with the last
    process also absorbing the remainder. Each input file yields one .npy
    array mirroring the input directory layout; outputs that already exist
    are skipped, and per-file failures are logged without stopping the run.
    """
    # Shard assignment: an equal slice per process, leftovers to the last one.
    per_process = len(files) // accelerator.num_processes
    start_idx = accelerator.process_index * per_process
    is_last = accelerator.process_index == accelerator.num_processes - 1
    end_idx = len(files) if is_last else start_idx + per_process

    for file in tqdm(files[start_idx:end_idx]):
        # Mirror the input subdirectory structure under output_dir.
        output_subdir = output_dir + os.path.dirname(file)[len(input_dir):]
        try:
            os.makedirs(output_subdir, exist_ok=True)
        except Exception as e:
            print(f"{output_subdir} cannot be created\n{e}")

        stem = os.path.splitext(os.path.basename(file))[0]
        output_file = os.path.join(output_subdir, stem + ".npy")
        if os.path.exists(output_file):
            print(f"Skipping {file}, output already exists")
            continue

        try:
            with open(file, "r", encoding="utf-8") as f:
                item = f.read()
            # Non-MTF (ABC) inputs: strip the default unit-note-length header.
            if not item.startswith("ticks_per_beat"):
                item = item.replace("L:1/8\n", "")
            with torch.no_grad():
                features = extract_feature(item).unsqueeze(0)
            np.save(output_file, features.detach().cpu().numpy())
        except Exception as e:
            print(f"Failed to process {file}: {e}")
# Entry point: extract features for this process's share of the collected
# files. NOTE(review): this runs at import time — the script is meant to be
# executed directly (e.g. via `accelerate launch`), not imported as a module.
process_directory(input_dir, output_dir, files)