@@ -17,8 +17,8 @@ use std::os::unix::io::RawFd as RawSocket;
17
17
#[ cfg( windows) ]
18
18
use std:: os:: windows:: io:: RawSocket ;
19
19
use std:: ptr:: null_mut;
20
- use std:: sync:: Once ;
21
20
use std:: sync:: { Arc , Mutex , MutexGuard } ;
21
+ use std:: sync:: { Once , Weak } ;
22
22
use std:: time:: Duration ;
23
23
24
24
mod channel;
@@ -72,9 +72,12 @@ fn initialize() -> SshResult<()> {
72
72
}
73
73
74
74
pub ( crate ) struct SessionHolder {
75
+ outer : Weak < Mutex < SessionHolder > > ,
75
76
sess : sys:: ssh_session ,
76
77
callbacks : sys:: ssh_callbacks_struct ,
77
78
auth_callback : Option < Box < dyn FnMut ( & str , bool , bool , Option < String > ) -> SshResult < String > > > ,
79
+ channel_open_request_auth_agent_callback :
80
+ Option < Box < dyn FnMut ( Channel ) -> RequestAuthAgentResult > > ,
78
81
}
79
82
unsafe impl Send for SessionHolder { }
80
83
@@ -197,11 +200,16 @@ impl Session {
197
200
channel_open_request_x11_function : None ,
198
201
channel_open_request_auth_agent_function : None ,
199
202
} ;
200
- let sess = Arc :: new ( Mutex :: new ( SessionHolder {
201
- sess,
202
- callbacks,
203
- auth_callback : None ,
204
- } ) ) ;
203
+ let sess = Arc :: new_cyclic ( |outer| {
204
+ let outer = outer. clone ( ) ;
205
+ Mutex :: new ( SessionHolder {
206
+ outer,
207
+ sess,
208
+ callbacks,
209
+ auth_callback : None ,
210
+ channel_open_request_auth_agent_callback : None ,
211
+ } )
212
+ } ) ;
205
213
206
214
{
207
215
let mut sess = sess. lock ( ) . unwrap ( ) ;
@@ -274,6 +282,66 @@ impl Session {
274
282
}
275
283
}
276
284
285
+ unsafe extern "C" fn bridge_channel_open_request_auth_agent_callback (
286
+ session : sys:: ssh_session ,
287
+ userdata : * mut :: std:: os:: raw:: c_void ,
288
+ ) -> sys:: ssh_channel {
289
+ let result = std:: panic:: catch_unwind ( || -> SshResult < sys:: ssh_channel > {
290
+ let sess: & mut SessionHolder = & mut * ( userdata as * mut SessionHolder ) ;
291
+ assert ! (
292
+ std:: ptr:: eq( session, sess. sess) ,
293
+ "invalid callback invocation: session mismatch"
294
+ ) ;
295
+ let cb = sess
296
+ . channel_open_request_auth_agent_callback
297
+ . as_mut ( )
298
+ . unwrap ( ) ;
299
+ let chan = unsafe { sys:: ssh_channel_new ( session) } ;
300
+ if chan. is_null ( ) {
301
+ return Err ( sess
302
+ . last_error ( )
303
+ . unwrap_or_else ( || Error :: fatal ( "ssh_channel_new failed" ) ) ) ;
304
+ }
305
+ match cb ( Channel :: new ( & sess. outer . upgrade ( ) . unwrap ( ) , chan) ) {
306
+ // SAFETY: We steal a *mut sys::ssh_channel_struct here and let libssh
307
+ // temporarily "borrows" it for an unspecified amount of time.
308
+ // libssh is guaranteed to finish using it before returning from the outermost
309
+ // libssh function call that triggered this callback. As such function call
310
+ // always happens with Session locked and dropping Channel needs to lock the
311
+ // session first, we can be sure that this *mut sys::ssh_channel_struct will not
312
+ // be freed while libssh is still using it.
313
+ RequestAuthAgentResult :: Accept => Ok ( chan) ,
314
+ RequestAuthAgentResult :: Reject ( mut chan_obj) => {
315
+ unsafe { sys:: ssh_channel_free ( chan_obj. chan_inner ) } ;
316
+ chan_obj. chan_inner = std:: ptr:: null_mut ( ) ;
317
+ Err ( Error :: RequestDenied ( "request auth agent" . to_string ( ) ) )
318
+ }
319
+ RequestAuthAgentResult :: Err ( mut chan_obj, err) => {
320
+ unsafe { sys:: ssh_channel_free ( chan_obj. chan_inner ) } ;
321
+ chan_obj. chan_inner = std:: ptr:: null_mut ( ) ;
322
+ Err ( err)
323
+ }
324
+ }
325
+ } ) ;
326
+ match result {
327
+ Err ( err) => {
328
+ eprintln ! (
329
+ "Panic in channel open request auth agent callback: {:?}" ,
330
+ err
331
+ ) ;
332
+ std:: ptr:: null_mut ( )
333
+ }
334
+ Ok ( Err ( err) ) => {
335
+ eprintln ! (
336
+ "Error in channel open request auth agent callback: {:#}" ,
337
+ err
338
+ ) ;
339
+ std:: ptr:: null_mut ( )
340
+ }
341
+ Ok ( Ok ( chan) ) => chan,
342
+ }
343
+ }
344
+
277
345
/// Sets a callback that is used by libssh when it needs to prompt
278
346
/// for the passphrase during public key authentication.
279
347
/// This is NOT used for password or keyboard interactive authentication.
@@ -326,6 +394,32 @@ impl Session {
326
394
sess. callbacks . auth_function = Some ( Self :: bridge_auth_callback) ;
327
395
}
328
396
397
+ /// Sets a callback that is used by libssh when the remote side requests a new channel
398
+ /// for SSH agent forwarding.
399
+ /// The callback has the signature:
400
+ ///
401
+ /// ```no_run
402
+ /// use libssh_rs::RequestAuthAgentResult;
403
+ /// fn callback(channel: Channel) -> RequestAuthAgentResult {
404
+ /// unimplemented!()
405
+ /// }
406
+ /// ```
407
+ ///
408
+ /// The callback should decide whether to allow the agent forward and if so, take ownership of
409
+ /// the channel (and further move it elsewhere to handle agent protocol within). Otherwise or
410
+ /// in case of an error, the callback should return the channel back as it is not possible to
411
+ /// drop it in the callback.
412
+ pub fn set_channel_open_request_auth_agent_callback < F > ( & self , callback : F )
413
+ where
414
+ F : FnMut ( Channel ) -> RequestAuthAgentResult + ' static ,
415
+ {
416
+ let mut sess = self . lock_session ( ) ;
417
+ sess. channel_open_request_auth_agent_callback
418
+ . replace ( Box :: new ( callback) ) ;
419
+ sess. callbacks . channel_open_request_auth_agent_function =
420
+ Some ( Self :: bridge_channel_open_request_auth_agent_callback) ;
421
+ }
422
+
329
423
/// Create a new channel.
330
424
/// Channels are used to handle I/O for commands and forwarded streams.
331
425
pub fn new_channel ( & self ) -> SshResult < Channel > {
@@ -1421,6 +1515,12 @@ pub struct InteractiveAuthInfo {
1421
1515
pub prompts : Vec < InteractiveAuthPrompt > ,
1422
1516
}
1423
1517
1518
+ pub enum RequestAuthAgentResult {
1519
+ Accept ,
1520
+ Reject ( Channel ) ,
1521
+ Err ( Channel , Error ) ,
1522
+ }
1523
+
1424
1524
/// A utility function that will prompt the user for input
1425
1525
/// via the console/tty.
1426
1526
///
0 commit comments