@@ -17,8 +17,8 @@ use std::os::unix::io::RawFd as RawSocket;
17
17
#[ cfg( windows) ]
18
18
use std:: os:: windows:: io:: RawSocket ;
19
19
use std:: ptr:: null_mut;
20
- use std:: sync:: Once ;
21
20
use std:: sync:: { Arc , Mutex , MutexGuard } ;
21
+ use std:: sync:: { Once , Weak } ;
22
22
use std:: time:: Duration ;
23
23
24
24
mod channel;
@@ -72,9 +72,12 @@ fn initialize() -> SshResult<()> {
72
72
}
73
73
74
74
pub ( crate ) struct SessionHolder {
75
+ outer : Weak < Mutex < SessionHolder > > ,
75
76
sess : sys:: ssh_session ,
76
77
callbacks : sys:: ssh_callbacks_struct ,
77
78
auth_callback : Option < Box < dyn FnMut ( & str , bool , bool , Option < String > ) -> SshResult < String > > > ,
79
+ channel_open_request_auth_agent_callback :
80
+ Option < Box < dyn FnMut ( Channel ) -> Result < ( ) , RequestAuthAgentError > > > ,
78
81
}
79
82
unsafe impl Send for SessionHolder { }
80
83
@@ -197,11 +200,16 @@ impl Session {
197
200
channel_open_request_x11_function : None ,
198
201
channel_open_request_auth_agent_function : None ,
199
202
} ;
200
- let sess = Arc :: new ( Mutex :: new ( SessionHolder {
201
- sess,
202
- callbacks,
203
- auth_callback : None ,
204
- } ) ) ;
203
+ let sess = Arc :: new_cyclic ( |outer| {
204
+ let outer = outer. clone ( ) ;
205
+ Mutex :: new ( SessionHolder {
206
+ outer,
207
+ sess,
208
+ callbacks,
209
+ auth_callback : None ,
210
+ channel_open_request_auth_agent_callback : None ,
211
+ } )
212
+ } ) ;
205
213
206
214
{
207
215
let mut sess = sess. lock ( ) . unwrap ( ) ;
@@ -274,6 +282,55 @@ impl Session {
274
282
}
275
283
}
276
284
285
+ unsafe extern "C" fn bridge_channel_open_request_auth_agent_callback (
286
+ session : sys:: ssh_session ,
287
+ userdata : * mut :: std:: os:: raw:: c_void ,
288
+ ) -> sys:: ssh_channel {
289
+ let result = std:: panic:: catch_unwind ( || -> SshResult < sys:: ssh_channel > {
290
+ let sess: & mut SessionHolder = & mut * ( userdata as * mut SessionHolder ) ;
291
+ assert ! (
292
+ std:: ptr:: eq( session, sess. sess) ,
293
+ "invalid callback invocation: session mismatch"
294
+ ) ;
295
+ let cb = sess
296
+ . channel_open_request_auth_agent_callback
297
+ . as_mut ( )
298
+ . unwrap ( ) ;
299
+ let chan = unsafe { sys:: ssh_channel_new ( session) } ;
300
+ if chan. is_null ( ) {
301
+ return Err ( sess
302
+ . last_error ( )
303
+ . unwrap_or_else ( || Error :: fatal ( "ssh_channel_new failed" ) ) ) ;
304
+ }
305
+ match cb ( Channel :: new ( & sess. outer . upgrade ( ) . unwrap ( ) , chan) ) {
306
+ // SAFETY: We steal a *mut sys::ssh_channel_struct here and let libssh
307
+ // temporarily "borrows" it for an unspecified amount of time.
308
+ // libssh is guaranteed to finish using it before returning from the outermost
309
+ // libssh function call that triggered this callback. As such function call
310
+ // always happens with Session locked and dropping Channel needs to lock the
311
+ // session first, we can be sure that this *mut sys::ssh_channel_struct will not
312
+ // be freed while libssh is still using it.
313
+ Ok ( _) => Ok ( chan) ,
314
+ Err ( RequestAuthAgentError ( err, mut chan_obj) ) => {
315
+ unsafe { sys:: ssh_channel_free ( chan_obj. chan_inner ) } ;
316
+ chan_obj. chan_inner = std:: ptr:: null_mut ( ) ;
317
+ Err ( err)
318
+ }
319
+ }
320
+ } ) ;
321
+ match result {
322
+ Err ( err) => {
323
+ eprintln ! ( "Panic in request auth agent callback: {:?}" , err) ;
324
+ std:: ptr:: null_mut ( )
325
+ }
326
+ Ok ( Err ( err) ) => {
327
+ eprintln ! ( "Error in request auth agent callback: {:#}" , err) ;
328
+ std:: ptr:: null_mut ( )
329
+ }
330
+ Ok ( Ok ( chan) ) => chan,
331
+ }
332
+ }
333
+
277
334
/// Sets a callback that is used by libssh when it needs to prompt
278
335
/// for the passphrase during public key authentication.
279
336
/// This is NOT used for password or keyboard interactive authentication.
@@ -326,6 +383,32 @@ impl Session {
326
383
sess. callbacks . auth_function = Some ( Self :: bridge_auth_callback) ;
327
384
}
328
385
386
+ /// Sets a callback that is used by libssh when the remote side requests a new channel
387
+ /// for SSH agent forwarding.
388
+ /// The callback has the signature:
389
+ ///
390
+ /// ```no_run
391
+ /// use libssh_rs::RequestAuthAgentResult;
392
+ /// fn callback(channel: Channel) -> RequestAuthAgentResult {
393
+ /// unimplemented!()
394
+ /// }
395
+ /// ```
396
+ ///
397
+ /// The callback should decide whether to allow the agent forward and if so, take ownership of
398
+ /// the channel (and further move it elsewhere to handle agent protocol within). Otherwise or
399
+ /// in case of an error, the callback should return the channel back as it is not possible to
400
+ /// drop it in the callback.
401
+ pub fn set_channel_open_request_auth_agent_callback < F > ( & self , callback : F )
402
+ where
403
+ F : FnMut ( Channel ) -> Result < ( ) , RequestAuthAgentError > + ' static ,
404
+ {
405
+ let mut sess = self . lock_session ( ) ;
406
+ sess. channel_open_request_auth_agent_callback
407
+ . replace ( Box :: new ( callback) ) ;
408
+ sess. callbacks . channel_open_request_auth_agent_function =
409
+ Some ( Self :: bridge_channel_open_request_auth_agent_callback) ;
410
+ }
411
+
329
412
/// Create a new channel.
330
413
/// Channels are used to handle I/O for commands and forwarded streams.
331
414
pub fn new_channel ( & self ) -> SshResult < Channel > {
0 commit comments