Skip to content

Commit e7fe4ce

Browse files
authored
[AutoModel] Fix bug with subfolders and local model paths when loading custom code (#13197)
* update * update
1 parent 3d90855 commit e7fe4ce

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

src/diffusers/utils/dynamic_modules_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,10 @@ def get_cached_module_file(
299299
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
300300
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
301301

302-
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
302+
if subfolder is not None:
303+
module_file_or_url = os.path.join(pretrained_model_name_or_path, subfolder, module_file)
304+
else:
305+
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
303306

304307
if os.path.isfile(module_file_or_url):
305308
resolved_module_file = module_file_or_url
@@ -384,7 +387,11 @@ def get_cached_module_file(
384387
if not os.path.exists(submodule_path / module_folder):
385388
os.makedirs(submodule_path / module_folder)
386389
module_needed = f"{module_needed}.py"
387-
shutil.copyfile(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
390+
if subfolder is not None:
391+
source_path = os.path.join(pretrained_model_name_or_path, subfolder, module_needed)
392+
else:
393+
source_path = os.path.join(pretrained_model_name_or_path, module_needed)
394+
shutil.copyfile(source_path, submodule_path / module_needed)
388395
else:
389396
# Get the commit hash
390397
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.

tests/models/test_models_auto.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import json
2+
import os
3+
import tempfile
14
import unittest
25
from unittest.mock import MagicMock, patch
36

7+
import torch
48
from transformers import CLIPTextModel, LongformerModel
59

610
from diffusers.models import AutoModel, UNet2DConditionModel
@@ -35,6 +39,45 @@ def test_load_from_model_index(self):
3539
)
3640
assert isinstance(model, CLIPTextModel)
3741

42+
def test_load_dynamic_module_from_local_path_with_subfolder(self):
43+
CUSTOM_MODEL_CODE = (
44+
"import torch\n"
45+
"from diffusers import ModelMixin, ConfigMixin\n"
46+
"from diffusers.configuration_utils import register_to_config\n"
47+
"\n"
48+
"class CustomModel(ModelMixin, ConfigMixin):\n"
49+
" @register_to_config\n"
50+
" def __init__(self, hidden_size=8):\n"
51+
" super().__init__()\n"
52+
" self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
53+
"\n"
54+
" def forward(self, x):\n"
55+
" return self.linear(x)\n"
56+
)
57+
58+
with tempfile.TemporaryDirectory() as tmpdir:
59+
subfolder = "custom_model"
60+
model_dir = os.path.join(tmpdir, subfolder)
61+
os.makedirs(model_dir)
62+
63+
with open(os.path.join(model_dir, "modeling.py"), "w") as f:
64+
f.write(CUSTOM_MODEL_CODE)
65+
66+
config = {
67+
"_class_name": "CustomModel",
68+
"_diffusers_version": "0.0.0",
69+
"auto_map": {"AutoModel": "modeling.CustomModel"},
70+
"hidden_size": 8,
71+
}
72+
with open(os.path.join(model_dir, "config.json"), "w") as f:
73+
json.dump(config, f)
74+
75+
torch.save({}, os.path.join(model_dir, "diffusion_pytorch_model.bin"))
76+
77+
model = AutoModel.from_pretrained(tmpdir, subfolder=subfolder, trust_remote_code=True)
78+
assert model.__class__.__name__ == "CustomModel"
79+
assert model.config["hidden_size"] == 8
80+
3881

3982
class TestAutoModelFromConfig(unittest.TestCase):
4083
@patch(

0 commit comments

Comments
 (0)