Skip to content

Commit 80205d4

Browse files
svyatonikascjones
authored andcommitted
Support multiple trailing arguments (#365)
* support multiple trailing arguments * Try to get the Windows build to work
1 parent ec5249e commit 80205d4

File tree

3 files changed

+102
-42
lines changed

3 files changed

+102
-42
lines changed

derive/src/to_delegate.rs

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,17 @@ impl RpcMethod {
173173
.map(|x| ident(&((x + 'a' as u8) as char).to_string()))
174174
.collect();
175175
let param_types = &param_types;
176-
let parse_params =
177-
// if the last argument is an `Option` then it can be made an optional 'trailing' argument
178-
if let Some(ref trailing) = param_types.iter().last().and_then(try_get_option) {
179-
self.params_with_trailing(trailing, param_types, tuple_fields)
176+
let parse_params = {
177+
// last arguments that are `Option`-s are optional 'trailing' arguments
178+
let trailing_args_num = param_types.iter().rev().take_while(|t| is_option_type(t)).count();
179+
if trailing_args_num != 0 {
180+
self.params_with_trailing(trailing_args_num, param_types, tuple_fields)
180181
} else if param_types.is_empty() {
181182
quote! { let params = params.expect_no_params(); }
182183
} else {
183184
quote! { let params = params.parse::<(#(#param_types, )*)>(); }
184-
};
185+
}
186+
};
185187

186188
let method_ident = self.ident();
187189
let result = &self.trait_item.sig.decl.output;
@@ -257,48 +259,52 @@ impl RpcMethod {
257259

258260
fn params_with_trailing(
259261
&self,
260-
trailing: &syn::Type,
262+
trailing_args_num: usize,
261263
param_types: &[syn::Type],
262264
tuple_fields: &[syn::Ident],
263265
) -> proc_macro2::TokenStream {
264-
let param_types_no_trailing: Vec<_> =
265-
param_types.iter().cloned().filter(|arg| arg != trailing).collect();
266-
let tuple_fields_no_trailing: &Vec<_> =
267-
&tuple_fields.iter().take(tuple_fields.len() - 1).collect();
268-
let num = param_types_no_trailing.len();
269-
let all_params_len = param_types.len();
270-
let no_trailing_branch =
271-
if all_params_len > 1 {
272-
quote! {
273-
params.parse::<(#(#param_types_no_trailing, )*)>()
274-
.map( |(#(#tuple_fields_no_trailing, )*)|
275-
(#(#tuple_fields_no_trailing, )* None))
276-
.map_err(Into::into)
266+
let total_args_num = param_types.len();
267+
let required_args_num = total_args_num - trailing_args_num;
268+
269+
let switch_branches = (0..trailing_args_num+1)
270+
.map(|passed_trailing_args_num| {
271+
let passed_args_num = required_args_num + passed_trailing_args_num;
272+
let passed_param_types = &param_types[..passed_args_num];
273+
let passed_tuple_fields = &tuple_fields[..passed_args_num];
274+
let missed_args_num = total_args_num - passed_args_num;
275+
let missed_params_values = ::std::iter::repeat(quote! { None }).take(missed_args_num).collect::<Vec<_>>();
276+
277+
if passed_args_num == 0 {
278+
quote! {
279+
#passed_args_num => params.expect_no_params()
280+
.map(|_| (#(#missed_params_values, ) *))
281+
.map_err(Into::into)
282+
}
283+
} else {
284+
quote! {
285+
#passed_args_num => params.parse::<(#(#passed_param_types, )*)>()
286+
.map(|(#(#passed_tuple_fields,)*)|
287+
(#(#passed_tuple_fields, )* #(#missed_params_values, )*))
288+
.map_err(Into::into)
289+
}
277290
}
278-
} else if all_params_len == 1 {
279-
quote! ( Ok((None,)) )
280-
} else {
281-
panic!("Should be at least one trailing param; qed")
282-
};
291+
}).collect::<Vec<_>>();
292+
283293
quote! {
284-
let params_len = match params {
294+
let passed_args_num = match params {
285295
_jsonrpc_core::Params::Array(ref v) => Ok(v.len()),
286296
_jsonrpc_core::Params::None => Ok(0),
287297
_ => Err(_jsonrpc_core::Error::invalid_params("`params` should be an array"))
288298
};
289299

290-
let params = params_len.and_then(|len| {
291-
match len.checked_sub(#num) {
292-
Some(0) => #no_trailing_branch,
293-
Some(1) => params.parse::<(#(#param_types, )*) > ()
294-
.map( |(#(#tuple_fields_no_trailing, )* id,)|
295-
(#(#tuple_fields_no_trailing, )* id,))
296-
.map_err(Into::into),
297-
None => Err(_jsonrpc_core::Error::invalid_params(
298-
format!("`params` should have at least {} argument(s)", #num))),
300+
let params = passed_args_num.and_then(|passed_args_num| {
301+
match passed_args_num {
302+
_ if passed_args_num < #required_args_num => Err(_jsonrpc_core::Error::invalid_params(
303+
format!("`params` should have at least {} argument(s)", #required_args_num))),
304+
#(#switch_branches),*,
299305
_ => Err(_jsonrpc_core::Error::invalid_params_with_details(
300-
format!("Expected {} or {} parameters.", #num, #num + 1),
301-
format!("Got: {}", len))),
306+
format!("Expected from {} to {} parameters.", #required_args_num, #total_args_num),
307+
format!("Got: {}", passed_args_num))),
302308
}
303309
});
304310
}
@@ -318,15 +324,14 @@ fn ident(s: &str) -> syn::Ident {
318324
syn::Ident::new(s, proc_macro2::Span::call_site())
319325
}
320326

321-
fn try_get_option(ty: &syn::Type) -> Option<syn::Type> {
327+
fn is_option_type(ty: &syn::Type) -> bool {
322328
if let syn::Type::Path(path) = ty {
323329
path.path.segments
324330
.first()
325-
.and_then(|t| {
326-
if t.value().ident == "Option" { Some(ty.clone()) } else { None }
327-
})
331+
.map(|t| t.value().ident == "Option")
332+
.unwrap_or(false)
328333
} else {
329-
None
334+
false
330335
}
331336
}
332337

derive/tests/compiletests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ fn run_mode(mode: &'static str) {
77

88
config.mode = mode.parse().expect("Invalid mode");
99
config.src_base = PathBuf::from(format!("tests/{}", mode));
10-
config.link_deps(); // Populate config.target_rustcflags with dependencies on the path
10+
config.target_rustcflags = Some("-L ../target/debug/ -L ../target/debug/deps/".to_owned());
1111
config.clean_rmeta(); // If your tests import the parent crate, this helps with E0464
1212

1313
compiletest::run_tests(&config);

derive/tests/trailing.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ pub trait Rpc {
1111
/// Echos back the message, example of a single param trailing
1212
#[rpc(name = "echo")]
1313
fn echo(&self, _: Option<String>) -> Result<String>;
14+
15+
/// Adds up to three numbers and returns a result
16+
#[rpc(name = "add_multi")]
17+
fn add_multi(&self, _: Option<u64>, _: Option<u64>, _: Option<u64>) -> Result<u64>;
1418
}
1519

1620
#[derive(Default)]
@@ -24,6 +28,10 @@ impl Rpc for RpcImpl {
2428
fn echo(&self, x: Option<String>) -> Result<String> {
2529
Ok(x.unwrap_or("".into()))
2630
}
31+
32+
fn add_multi(&self, a: Option<u64>, b: Option<u64>, c: Option<u64>) -> Result<u64> {
33+
Ok(a.unwrap_or_default() + b.unwrap_or_default() + c.unwrap_or_default())
34+
}
2735
}
2836

2937
#[test]
@@ -92,3 +100,50 @@ fn should_accept_single_trailing_param() {
92100
"id": 1
93101
}"#).unwrap());
94102
}
103+
104+
#[test]
105+
fn should_accept_multiple_trailing_params() {
106+
let mut io = IoHandler::new();
107+
let rpc = RpcImpl::default();
108+
io.extend_with(rpc.to_delegate());
109+
110+
// when
111+
let req1 = r#"{"jsonrpc":"2.0","id":1,"method":"add_multi","params":[]}"#;
112+
let req2 = r#"{"jsonrpc":"2.0","id":1,"method":"add_multi","params":[1]}"#;
113+
let req3 = r#"{"jsonrpc":"2.0","id":1,"method":"add_multi","params":[1, 2]}"#;
114+
let req4 = r#"{"jsonrpc":"2.0","id":1,"method":"add_multi","params":[1, 2, 3]}"#;
115+
116+
let res1 = io.handle_request_sync(req1);
117+
let res2 = io.handle_request_sync(req2);
118+
let res3 = io.handle_request_sync(req3);
119+
let res4 = io.handle_request_sync(req4);
120+
121+
// then
122+
let result1: Response = serde_json::from_str(&res1.unwrap()).unwrap();
123+
assert_eq!(result1, serde_json::from_str(r#"{
124+
"jsonrpc": "2.0",
125+
"result": 0,
126+
"id": 1
127+
}"#).unwrap());
128+
129+
let result2: Response = serde_json::from_str(&res2.unwrap()).unwrap();
130+
assert_eq!(result2, serde_json::from_str(r#"{
131+
"jsonrpc": "2.0",
132+
"result": 1,
133+
"id": 1
134+
}"#).unwrap());
135+
136+
let result3: Response = serde_json::from_str(&res3.unwrap()).unwrap();
137+
assert_eq!(result3, serde_json::from_str(r#"{
138+
"jsonrpc": "2.0",
139+
"result": 3,
140+
"id": 1
141+
}"#).unwrap());
142+
143+
let result4: Response = serde_json::from_str(&res4.unwrap()).unwrap();
144+
assert_eq!(result4, serde_json::from_str(r#"{
145+
"jsonrpc": "2.0",
146+
"result": 6,
147+
"id": 1
148+
}"#).unwrap());
149+
}

0 commit comments

Comments
 (0)