Skip to content

Commit d9cea51

Browse files
committed
Support list (de)serialization with any serde compatible format
1 parent 39bfcbf commit d9cea51

File tree

7 files changed

+225
-75
lines changed

7 files changed

+225
-75
lines changed

Cargo.lock

Lines changed: 32 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ cfg-if = "1"
2121
cidr = { version = "0.2", features = ["serde"] }
2222
criterion = "0.5"
2323
dyn-clone = "1.0.20"
24+
erased-serde = "0.4.9"
2425
fnv = "1.0.6"
2526
getrandom = { version = "0.3" }
2627
indoc = "2"

engine/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,20 @@ backtrace.workspace = true
2424
cfg-if.workspace = true
2525
cidr.workspace = true
2626
dyn-clone.workspace = true
27+
erased-serde.workspace = true
2728
fnv.workspace = true
2829
memmem.workspace = true
2930
rand.workspace = true
3031
regex-automata = { workspace = true, optional = true }
3132
serde.workspace = true
32-
serde_json.workspace = true
3333
sliceslice.workspace = true
3434
thiserror.workspace = true
3535
wildcard.workspace = true
3636

3737
[dev-dependencies]
3838
criterion.workspace = true
3939
indoc.workspace = true
40+
serde_json.workspace = true
4041

4142
[features]
4243
default = [ "regex" ]

engine/src/ast/field_expr.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,7 @@ mod tests {
818818
types::ExpectedType,
819819
};
820820
use cidr::IpCidr;
821+
use serde::Deserialize;
821822
use std::sync::LazyLock;
822823
use std::{convert::TryFrom, iter::once, net::IpAddr};
823824

@@ -950,12 +951,13 @@ mod tests {
950951
pub struct NumMListDefinition {}
951952

952953
impl ListDefinition for NumMListDefinition {
953-
fn matcher_from_json_value(
954+
fn deserialize_matcher<'de>(
954955
&self,
955956
_: Type,
956-
_: serde_json::Value,
957-
) -> Result<Box<dyn ListMatcher>, serde_json::Error> {
958-
Ok(Box::new(NumMatcher {}))
957+
deserializer: &mut dyn erased_serde::Deserializer<'de>,
958+
) -> Result<Box<dyn ListMatcher>, erased_serde::Error> {
959+
let matcher = erased_serde::deserialize::<NumMatcher>(deserializer)?;
960+
Ok(Box::new(matcher))
959961
}
960962

961963
fn new_matcher(&self) -> Box<dyn ListMatcher> {
@@ -2467,7 +2469,7 @@ mod tests {
24672469
);
24682470
}
24692471

2470-
#[derive(Debug, PartialEq, Eq, Serialize, Clone)]
2472+
#[derive(Debug, PartialEq, Eq, Serialize, Clone, Deserialize)]
24712473
pub struct NumMatcher {}
24722474

24732475
impl ListMatcher for NumMatcher {
@@ -2485,10 +2487,6 @@ mod tests {
24852487
}
24862488
}
24872489

2488-
fn to_json_value(&self) -> serde_json::Value {
2489-
serde_json::Value::Null
2490-
}
2491-
24922490
fn clear(&mut self) {}
24932491
}
24942492

@@ -2565,7 +2563,10 @@ mod tests {
25652563
assert_eq!(expr.execute_one(ctx), true);
25662564

25672565
let json = serde_json::to_string(ctx).unwrap();
2568-
assert_eq!(json, "{\"tcp.port\":1001,\"$lists\":[]}");
2566+
assert_eq!(
2567+
json,
2568+
"{\"tcp.port\":1001,\"$lists\":[{\"type\":\"Int\",\"data\":{}}]}"
2569+
);
25692570
}
25702571

25712572
#[test]

engine/src/execution_context.rs

Lines changed: 155 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ use crate::{
33
scheme::{Field, List, Scheme, SchemeMismatchError},
44
types::{GetType, LhsValue, LhsValueSeed, Type, TypeMismatchError},
55
};
6-
use serde::de::{self, DeserializeSeed, Deserializer, MapAccess, Visitor};
6+
use serde::Serialize;
7+
use serde::de::{self, DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor};
78
use serde::ser::{SerializeMap, SerializeSeq, Serializer};
8-
use serde::{Deserialize, Serialize};
99
use std::borrow::Cow;
1010
use std::fmt;
1111
use std::fmt::Debug;
@@ -292,11 +292,142 @@ impl<U, T> Drop for ExecutionContextGuard<'_, '_, U, T> {
292292
}
293293
}
294294

295-
#[derive(Serialize, Deserialize)]
296-
struct ListData {
297-
#[serde(rename = "type")]
298-
ty: Type,
299-
data: serde_json::Value,
295+
struct ListMatcherData<'a>(ListRef<'a>);
296+
297+
impl<'de> DeserializeSeed<'de> for ListMatcherData<'_> {
298+
type Value = Box<dyn ListMatcher>;
299+
300+
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
301+
where
302+
D: Deserializer<'de>,
303+
{
304+
use serde::de::Error;
305+
306+
let mut erased = <dyn erased_serde::Deserializer<'_>>::erase(deserializer);
307+
self.0
308+
.definition()
309+
.deserialize_matcher(self.0.get_type(), &mut erased)
310+
.map_err(D::Error::custom)
311+
}
312+
}
313+
314+
struct ListMatcherEntry<'a>(&'a Scheme, &'a mut [Box<dyn ListMatcher>]);
315+
316+
impl<'de> DeserializeSeed<'de> for ListMatcherEntry<'_> {
317+
type Value = ();
318+
319+
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
320+
where
321+
D: Deserializer<'de>,
322+
{
323+
struct ListMatcherEntryVisitor<'a>(&'a Scheme, &'a mut [Box<dyn ListMatcher>]);
324+
325+
impl<'de> Visitor<'de> for ListMatcherEntryVisitor<'_> {
326+
type Value = ();
327+
328+
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
329+
write!(formatter, "list matcher data")
330+
}
331+
332+
fn visit_map<M>(self, mut access: M) -> Result<(), M::Error>
333+
where
334+
M: MapAccess<'de>,
335+
{
336+
use serde::de::Error;
337+
338+
let Some(key) = access.next_key::<Cow<'_, str>>()? else {
339+
return Err(M::Error::missing_field("type"));
340+
};
341+
342+
if key != "type" {
343+
return Err(M::Error::unknown_field(&key, &["type", "data"]));
344+
}
345+
346+
let ty = access.next_value::<Type>()?;
347+
348+
let Some(list) = self.0.get_list(&ty) else {
349+
return Err(M::Error::custom(format!("no list defined for type {ty}")));
350+
};
351+
352+
let Some(key) = access.next_key::<Cow<'_, str>>()? else {
353+
return Err(M::Error::missing_field("data"));
354+
};
355+
356+
if key != "data" {
357+
return Err(M::Error::unknown_field(&key, &["type", "data"]));
358+
}
359+
360+
let matcher = access.next_value_seed(ListMatcherData(list))?;
361+
362+
self.1[list.index()] = matcher;
363+
364+
Ok(())
365+
}
366+
}
367+
368+
const FIELDS: &[&str] = &["type", "data"];
369+
deserializer.deserialize_struct(
370+
"ListMatcher",
371+
FIELDS,
372+
ListMatcherEntryVisitor(self.0, self.1),
373+
)
374+
}
375+
}
376+
377+
struct ListMatcherSlice<'a>(&'a Scheme, &'a mut [Box<dyn ListMatcher>]);
378+
379+
impl<'de> DeserializeSeed<'de> for ListMatcherSlice<'_> {
380+
type Value = ();
381+
382+
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
383+
where
384+
D: Deserializer<'de>,
385+
{
386+
struct ListMatcherSliceVisitor<'a>(&'a Scheme, &'a mut [Box<dyn ListMatcher>]);
387+
388+
impl<'de> Visitor<'de> for ListMatcherSliceVisitor<'_> {
389+
type Value = ();
390+
391+
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
392+
write!(formatter, "a list of list matcher data")
393+
}
394+
395+
fn visit_seq<S>(self, mut access: S) -> Result<(), S::Error>
396+
where
397+
S: SeqAccess<'de>,
398+
{
399+
while let Some(()) = access.next_element_seed(ListMatcherEntry(self.0, self.1))? {}
400+
401+
Ok(())
402+
}
403+
}
404+
405+
deserializer.deserialize_seq(ListMatcherSliceVisitor(self.0, self.1))
406+
}
407+
}
408+
409+
impl Serialize for ListMatcherSlice<'_> {
410+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
411+
where
412+
S: Serializer,
413+
{
414+
#[derive(Serialize)]
415+
struct TypedListMatcher<'a> {
416+
#[serde(rename = "type")]
417+
ty: Type,
418+
data: &'a dyn erased_serde::Serialize,
419+
}
420+
421+
let mut seq = serializer.serialize_seq(Some(self.1.len()))?;
422+
for list in self.0.lists() {
423+
let matcher = &*self.1[list.index()] as &dyn erased_serde::Serialize;
424+
seq.serialize_element(&TypedListMatcher {
425+
ty: list.get_type(),
426+
data: matcher,
427+
})?;
428+
}
429+
seq.end()
430+
}
300431
}
301432

302433
impl<'de, U> DeserializeSeed<'de> for &mut ExecutionContext<'de, U> {
@@ -312,7 +443,7 @@ impl<'de, U> DeserializeSeed<'de> for &mut ExecutionContext<'de, U> {
312443
type Value = ();
313444

314445
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
315-
write!(formatter, "a map of lhs value")
446+
write!(formatter, "a serialized execution context")
316447
}
317448

318449
fn visit_map<M>(self, mut access: M) -> Result<(), M::Error>
@@ -322,20 +453,10 @@ impl<'de, U> DeserializeSeed<'de> for &mut ExecutionContext<'de, U> {
322453
while let Some(key) = access.next_key::<Cow<'_, str>>()? {
323454
if key == "$lists" {
324455
// Deserialize lists
325-
let vec = access.next_value::<Vec<ListData>>()?;
326-
for ListData { ty, data } in vec.into_iter() {
327-
let list = self.0.scheme.get_list(&ty).ok_or_else(|| {
328-
de::Error::custom(format!("unknown list for type: {ty:?}"))
329-
})?;
330-
self.0.list_matchers[list.index()] = list
331-
.definition()
332-
.matcher_from_json_value(ty, data)
333-
.map_err(|err| {
334-
de::Error::custom(format!(
335-
"failed to deserialize list matcher: {err:?}"
336-
))
337-
})?;
338-
}
456+
access.next_value_seed(ListMatcherSlice(
457+
&self.0.scheme,
458+
&mut self.0.list_matchers,
459+
))?;
339460
} else {
340461
let field = self
341462
.0
@@ -381,20 +502,25 @@ impl<U> Serialize for ExecutionContext<'_, U> {
381502

382503
struct ListMatcherSlice<'a>(&'a Scheme, &'a [Box<dyn ListMatcher>]);
383504

505+
#[derive(Serialize)]
506+
struct TypedListMatcher<'a> {
507+
#[serde(rename = "type")]
508+
ty: Type,
509+
data: &'a dyn erased_serde::Serialize,
510+
}
511+
384512
impl Serialize for ListMatcherSlice<'_> {
385513
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
386514
where
387515
S: Serializer,
388516
{
389517
let mut seq = serializer.serialize_seq(Some(self.1.len()))?;
390518
for list in self.0.lists() {
391-
let data = self.1[list.index()].to_json_value();
392-
if data != serde_json::Value::Null {
393-
seq.serialize_element(&ListData {
394-
ty: list.get_type(),
395-
data,
396-
})?;
397-
}
519+
let matcher = &*self.1[list.index()] as &dyn erased_serde::Serialize;
520+
seq.serialize_element(&TypedListMatcher {
521+
ty: list.get_type(),
522+
data: matcher,
523+
})?;
398524
}
399525
seq.end()
400526
}

0 commit comments

Comments
 (0)