Skip to content

Commit bcc2fa1

Browse files
authored
Improve UX with private repos (#3065)
* Fetch and check credentials stored on the server side if no credentials provided via the command line (otherwise, check the provided credentials as usual) * Detect the default branch using the provided or stored credentials Closes: #3061
1 parent 13c67fd commit bcc2fa1

File tree

6 files changed

+193
-121
lines changed

6 files changed

+193
-121
lines changed

src/dstack/_internal/cli/commands/init.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from dstack._internal.cli.commands import BaseCommand
77
from dstack._internal.cli.services.repos import (
88
get_repo_from_dir,
9-
get_repo_from_url,
109
is_git_repo_url,
1110
register_init_repo_args,
1211
)
1312
from dstack._internal.cli.utils.common import configure_logging, confirm_ask, console, warn
1413
from dstack._internal.core.errors import ConfigurationError
14+
from dstack._internal.core.models.repos.remote import RemoteRepo
1515
from dstack._internal.core.services.configs import ConfigManager
1616
from dstack.api import Client
1717

@@ -101,7 +101,7 @@ def _command(self, args: argparse.Namespace):
101101
if repo_url is not None:
102102
# Dummy repo branch to avoid autodetection that fails on private repos.
103103
# We don't need branch/hash for repo_id anyway.
104-
repo = get_repo_from_url(repo_url, repo_branch="master")
104+
repo = RemoteRepo.from_url(repo_url, repo_branch="master")
105105
elif repo_path is not None:
106106
repo = get_repo_from_dir(repo_path, local=local)
107107
else:

src/dstack/_internal/cli/services/configurators/run.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from dstack._internal.cli.services.profile import apply_profile_args, register_profile_args
1818
from dstack._internal.cli.services.repos import (
1919
get_repo_from_dir,
20-
get_repo_from_url,
2120
init_default_virtual_repo,
2221
is_git_repo_url,
2322
register_init_repo_args,
@@ -43,13 +42,19 @@
4342
ServiceConfiguration,
4443
TaskConfiguration,
4544
)
45+
from dstack._internal.core.models.repos import RepoHeadWithCreds
4646
from dstack._internal.core.models.repos.base import Repo
4747
from dstack._internal.core.models.repos.local import LocalRepo
48+
from dstack._internal.core.models.repos.remote import RemoteRepo, RemoteRepoCreds
4849
from dstack._internal.core.models.resources import CPUSpec
4950
from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunSpec, RunStatus
5051
from dstack._internal.core.services.configs import ConfigManager
5152
from dstack._internal.core.services.diff import diff_models
52-
from dstack._internal.core.services.repos import load_repo
53+
from dstack._internal.core.services.repos import (
54+
InvalidRepoCredentialsError,
55+
get_repo_creds_and_default_branch,
56+
load_repo,
57+
)
5358
from dstack._internal.utils.common import local_time
5459
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
5560
from dstack._internal.utils.logging import get_logger
@@ -535,15 +540,17 @@ def get_repo(
535540
return init_default_virtual_repo(api=self.api)
536541

537542
repo: Optional[Repo] = None
543+
repo_head: Optional[RepoHeadWithCreds] = None
538544
repo_branch: Optional[str] = configurator_args.repo_branch
539545
repo_hash: Optional[str] = configurator_args.repo_hash
546+
repo_creds: Optional[RemoteRepoCreds] = None
547+
git_identity_file: Optional[str] = configurator_args.git_identity_file
548+
git_private_key: Optional[str] = None
549+
oauth_token: Optional[str] = configurator_args.gh_token
540550
# Should we (re)initialize the repo?
541551
# If any Git credentials provided, we reinitialize the repo, as the user may have provided
542552
# updated credentials.
543-
init = (
544-
configurator_args.git_identity_file is not None
545-
or configurator_args.gh_token is not None
546-
)
553+
init = git_identity_file is not None or oauth_token is not None
547554

548555
url: Optional[str] = None
549556
local_path: Optional[Path] = None
@@ -576,15 +583,15 @@ def get_repo(
576583
local_path = Path.cwd()
577584
legacy_local_path = True
578585
if url:
579-
repo = get_repo_from_url(repo_url=url, repo_branch=repo_branch, repo_hash=repo_hash)
580-
if not self.api.repos.is_initialized(repo, by_user=True):
581-
init = True
586+
# "master" is a dummy value, we'll fetch the actual default branch later
587+
repo = RemoteRepo.from_url(repo_url=url, repo_branch="master")
588+
repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
582589
elif local_path:
583590
if legacy_local_path:
584591
if repo_config := config_manager.get_repo_config(local_path):
585592
repo = load_repo(repo_config)
586-
# allow users with legacy configurations use shared repo creds
587-
if self.api.repos.is_initialized(repo, by_user=False):
593+
repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
594+
if repo_head is not None:
588595
warn(
589596
"The repo is not specified but found and will be used in the run\n"
590597
"Future versions will not load repos automatically\n"
@@ -611,20 +618,55 @@ def get_repo(
611618
)
612619
local: bool = configurator_args.local
613620
repo = get_repo_from_dir(local_path, local=local)
614-
if not self.api.repos.is_initialized(repo, by_user=True):
615-
init = True
621+
repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
622+
if isinstance(repo, RemoteRepo):
623+
repo_branch = repo.run_repo_data.repo_branch
624+
repo_hash = repo.run_repo_data.repo_hash
616625
else:
617626
assert False, "should not reach here"
618627

619628
if repo is None:
620629
return init_default_virtual_repo(api=self.api)
621630

631+
if isinstance(repo, RemoteRepo):
632+
assert repo.repo_url is not None
633+
634+
if repo_head is not None and repo_head.repo_creds is not None:
635+
if git_identity_file is None and oauth_token is None:
636+
git_private_key = repo_head.repo_creds.private_key
637+
oauth_token = repo_head.repo_creds.oauth_token
638+
else:
639+
init = True
640+
641+
try:
642+
repo_creds, default_repo_branch = get_repo_creds_and_default_branch(
643+
repo_url=repo.repo_url,
644+
identity_file=git_identity_file,
645+
private_key=git_private_key,
646+
oauth_token=oauth_token,
647+
)
648+
except InvalidRepoCredentialsError as e:
649+
raise CLIError(*e.args) from e
650+
651+
if repo_branch is None and repo_hash is None:
652+
repo_branch = default_repo_branch
653+
if repo_branch is None:
654+
raise CLIError(
655+
"Failed to automatically detect remote repo branch."
656+
" Specify branch or hash."
657+
)
658+
repo = RemoteRepo.from_url(
659+
repo_url=repo.repo_url, repo_branch=repo_branch, repo_hash=repo_hash
660+
)
661+
622662
if init:
623663
self.api.repos.init(
624664
repo=repo,
625-
git_identity_file=configurator_args.git_identity_file,
626-
oauth_token=configurator_args.gh_token,
665+
git_identity_file=git_identity_file,
666+
oauth_token=oauth_token,
667+
creds=repo_creds,
627668
)
669+
628670
if isinstance(repo, LocalRepo):
629671
warn(
630672
f"{repo.repo_dir} is a local repo\n"

src/dstack/_internal/cli/services/repos.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import argparse
2-
from typing import Literal, Optional, Union, overload
2+
from typing import Literal, Union, overload
33

44
import git
55

@@ -8,7 +8,6 @@
88
from dstack._internal.core.models.repos.local import LocalRepo
99
from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo, RepoError
1010
from dstack._internal.core.models.repos.virtual import VirtualRepo
11-
from dstack._internal.core.services.repos import get_default_branch
1211
from dstack._internal.utils.path import PathLike
1312
from dstack.api._public import Client
1413

@@ -43,22 +42,6 @@ def init_default_virtual_repo(api: Client) -> VirtualRepo:
4342
return repo
4443

4544

46-
def get_repo_from_url(
47-
repo_url: str, repo_branch: Optional[str] = None, repo_hash: Optional[str] = None
48-
) -> RemoteRepo:
49-
if repo_branch is None and repo_hash is None:
50-
repo_branch = get_default_branch(repo_url)
51-
if repo_branch is None:
52-
raise CLIError(
53-
"Failed to automatically detect remote repo branch. Specify branch or hash."
54-
)
55-
return RemoteRepo.from_url(
56-
repo_url=repo_url,
57-
repo_branch=repo_branch,
58-
repo_hash=repo_hash,
59-
)
60-
61-
6245
@overload
6346
def get_repo_from_dir(repo_dir: PathLike, local: Literal[False] = False) -> RemoteRepo: ...
6447

0 commit comments

Comments
 (0)