@@ -13,6 +13,8 @@ use git2::{
13
13
} ;
14
14
use scopetime:: scope_time;
15
15
16
+ pub const DEFAULT_REMOTE_NAME : & str = "origin" ;
17
+
16
18
///
17
19
#[ derive( Debug , Clone ) ]
18
20
pub enum ProgressNotification {
@@ -66,28 +68,45 @@ pub fn get_remotes(repo_path: &str) -> Result<Vec<String>> {
66
68
Ok ( remotes)
67
69
}
68
70
69
- ///
70
- pub fn get_first_remote ( repo_path : & str ) -> Result < String > {
71
+ /// tries to find origin or the only remote that is defined if any
72
+ /// in case of multiple remotes and none named *origin* we fail
73
+ pub fn get_default_remote ( repo_path : & str ) -> Result < String > {
71
74
let repo = utils:: repo ( repo_path) ?;
72
- get_first_remote_in_repo ( & repo)
75
+ get_default_remote_in_repo ( & repo)
73
76
}
74
77
75
- ///
76
- pub ( crate ) fn get_first_remote_in_repo (
78
+ /// see `get_default_remote`
79
+ pub ( crate ) fn get_default_remote_in_repo (
77
80
repo : & Repository ,
78
81
) -> Result < String > {
79
- scope_time ! ( "get_remotes " ) ;
82
+ scope_time ! ( "get_default_remote_in_repo " ) ;
80
83
81
84
let remotes = repo. remotes ( ) ?;
82
85
83
- let first_remote = remotes
84
- . iter ( )
85
- . next ( )
86
- . flatten ( )
87
- . map ( String :: from)
88
- . ok_or_else ( || Error :: Generic ( "no remote found" . into ( ) ) ) ?;
86
+ // if `origin` exists return that
87
+ let found_origin = remotes. iter ( ) . any ( |r| {
88
+ r. map ( |r| r == DEFAULT_REMOTE_NAME ) . unwrap_or_default ( )
89
+ } ) ;
90
+ if found_origin {
91
+ return Ok ( DEFAULT_REMOTE_NAME . into ( ) ) ;
92
+ }
89
93
90
- Ok ( first_remote)
94
+ //if only one remote exists pick that
95
+ if remotes. len ( ) == 1 {
96
+ let first_remote = remotes
97
+ . iter ( )
98
+ . next ( )
99
+ . flatten ( )
100
+ . map ( String :: from)
101
+ . ok_or_else ( || {
102
+ Error :: Generic ( "no remote found" . into ( ) )
103
+ } ) ?;
104
+
105
+ return Ok ( first_remote) ;
106
+ }
107
+
108
+ //inconclusive
109
+ Err ( Error :: NoDefaultRemoteFound )
91
110
}
92
111
93
112
///
@@ -96,7 +115,7 @@ pub fn fetch_origin(repo_path: &str, branch: &str) -> Result<usize> {
96
115
97
116
let repo = utils:: repo ( repo_path) ?;
98
117
let mut remote =
99
- repo. find_remote ( & get_first_remote_in_repo ( & repo) ?) ?;
118
+ repo. find_remote ( & get_default_remote_in_repo ( & repo) ?) ?;
100
119
101
120
let mut options = FetchOptions :: new ( ) ;
102
121
options. remote_callbacks ( remote_callbacks ( None , None ) ) ;
@@ -288,7 +307,7 @@ mod tests {
288
307
}
289
308
290
309
#[ test]
291
- fn test_first_remote ( ) {
310
+ fn test_default_remote ( ) {
292
311
let td = TempDir :: new ( ) . unwrap ( ) ;
293
312
294
313
debug_cmd_print (
@@ -311,7 +330,44 @@ mod tests {
311
330
vec![ String :: from( "origin" ) , String :: from( "second" ) ]
312
331
) ;
313
332
314
- let first = get_first_remote_in_repo (
333
+ let first = get_default_remote_in_repo (
334
+ & utils:: repo ( repo_path) . unwrap ( ) ,
335
+ )
336
+ . unwrap ( ) ;
337
+ assert_eq ! ( first, String :: from( "origin" ) ) ;
338
+ }
339
+
340
+ #[ test]
341
+ fn test_default_remote_out_of_order ( ) {
342
+ let td = TempDir :: new ( ) . unwrap ( ) ;
343
+
344
+ debug_cmd_print (
345
+ td. path ( ) . as_os_str ( ) . to_str ( ) . unwrap ( ) ,
346
+ "git clone https://github.com/extrawurst/brewdump.git" ,
347
+ ) ;
348
+
349
+ debug_cmd_print (
350
+ td. path ( ) . as_os_str ( ) . to_str ( ) . unwrap ( ) ,
351
+ "cd brewdump && git remote rename origin alternate" ,
352
+ ) ;
353
+
354
+ debug_cmd_print (
355
+ td. path ( ) . as_os_str ( ) . to_str ( ) . unwrap ( ) ,
356
+ "cd brewdump && git remote add origin https://github.com/extrawurst/brewdump.git" ,
357
+ ) ;
358
+
359
+ let repo_path = td. path ( ) . join ( "brewdump" ) ;
360
+ let repo_path = repo_path. as_os_str ( ) . to_str ( ) . unwrap ( ) ;
361
+
362
+ //NOTE: aparently remotes are not chronolically sorted but alphabetically
363
+ let remotes = get_remotes ( repo_path) . unwrap ( ) ;
364
+
365
+ assert_eq ! (
366
+ remotes,
367
+ vec![ String :: from( "alternate" ) , String :: from( "origin" ) ]
368
+ ) ;
369
+
370
+ let first = get_default_remote_in_repo (
315
371
& utils:: repo ( repo_path) . unwrap ( ) ,
316
372
)
317
373
. unwrap ( ) ;
0 commit comments