Skip to content

Commit ca00c14

Browse files
committed
CSHARP-4337: Use correct serializer for conditional result.
1 parent 3d67e80 commit ca00c14

File tree

6 files changed

+293
-46
lines changed

6 files changed

+293
-46
lines changed

src/MongoDB.Bson/Serialization/Serializers/EnumSerializer.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,20 @@ public override TEnum Deserialize(BsonDeserializationContext context, BsonDeseri
102102
}
103103
}
104104

105+
/// <inheritdoc/>
106+
public override bool Equals(object obj)
107+
{
108+
return
109+
obj is EnumSerializer<TEnum> other &&
110+
_representation == other.Representation;
111+
}
112+
113+
/// <inheritdoc/>
114+
public override int GetHashCode()
115+
{
116+
return _representation.GetHashCode();
117+
}
118+
105119
/// <summary>
106120
/// Serializes a value.
107121
/// </summary>

src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/KnownSerializers/KnownSerializerFinder.cs

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,42 @@ public override Expression Visit(Expression node)
6868

6969
var result = base.Visit(node);
7070
_registry.Add(node, _currentKnownSerializersNode);
71-
_currentKnownSerializersNode = _currentKnownSerializersNode.Parent;
71+
72+
var parent = _currentKnownSerializersNode.Parent;
73+
if (ShouldPropagateKnownSerializersToParent(parent))
74+
{
75+
parent.AddKnownSerializersFromChild(_currentKnownSerializersNode);
76+
}
77+
_currentKnownSerializersNode = parent;
78+
79+
return result;
80+
}
81+
82+
protected override Expression VisitConditional(ConditionalExpression node)
83+
{
84+
var result = base.VisitConditional(node);
85+
86+
if (_currentKnownSerializersNode.KnownSerializers.TryGetValue(node.Type, out var resultSerializers) &&
87+
resultSerializers.Count > 1)
88+
{
89+
var ifTrueSerializer = _registry.GetSerializerAtThisLevel(node.IfTrue);
90+
var ifFalseSerializer = _registry.GetSerializerAtThisLevel(node.IfFalse);
91+
92+
if (ifTrueSerializer != null && ifFalseSerializer != null && !ifTrueSerializer.Equals(ifFalseSerializer))
93+
{
94+
throw new ExpressionNotSupportedException(node, because: "IfTrue and IfFalse expressions have different serializers");
95+
}
96+
97+
if (ifTrueSerializer != null)
98+
{
99+
_currentKnownSerializersNode.SetKnownSerializerForType(node.Type, ifTrueSerializer);
100+
}
101+
else if (ifFalseSerializer != null)
102+
{
103+
_currentKnownSerializersNode.SetKnownSerializerForType(node.Type, ifFalseSerializer);
104+
}
105+
}
106+
72107
return result;
73108
}
74109

@@ -87,14 +122,14 @@ protected override Expression VisitBinary(BinaryExpression node)
87122
{
88123
var rightExpressionSerializer = _registry.GetSerializer(rightExpression);
89124
var leftExpressionSerializer = EnumUnderlyingTypeSerializer.Create(rightExpressionSerializer);
90-
_registry.AddKnownSerializer(leftExpression, leftExpressionSerializer, allowPropagation: false);
125+
_registry.SetNodeSerializer(leftExpression, leftExpressionSerializer);
91126
}
92127

93128
if (rightExpression is ConstantExpression rightConstantExpression)
94129
{
95130
var leftExpressionSerializer = _registry.GetSerializer(leftExpression);
96131
var rightExpressionSerializer = EnumUnderlyingTypeSerializer.Create(leftExpressionSerializer);
97-
_registry.AddKnownSerializer(rightExpression, rightExpressionSerializer, allowPropagation: false);
132+
_registry.SetNodeSerializer(rightExpression, rightExpressionSerializer);
98133
}
99134
}
100135
}
@@ -202,5 +237,20 @@ protected override Expression VisitParameter(ParameterExpression node)
202237

203238
return result;
204239
}
240+
241+
private bool ShouldPropagateKnownSerializersToParent(KnownSerializersNode parent)
242+
{
243+
if (parent == null)
244+
{
245+
return false;
246+
}
247+
248+
return parent.Expression.NodeType switch
249+
{
250+
ExpressionType.MemberInit => false,
251+
ExpressionType.New => false,
252+
_ => true
253+
};
254+
}
205255
}
206256
}

src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/KnownSerializers/KnownSerializersNode.cs

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ internal class KnownSerializersNode
2727
// private fields
2828
private readonly Expression _expression;
2929
private readonly Dictionary<Type, HashSet<IBsonSerializer>> _knownSerializers = new Dictionary<Type, HashSet<IBsonSerializer>>();
30+
private IBsonSerializer _nodeSerializer; // a serializer used only for this node (not propagated upwards)
3031
private readonly KnownSerializersNode _parent;
3132

3233
// constructors
@@ -42,7 +43,16 @@ public KnownSerializersNode(Expression expression, KnownSerializersNode parent)
4243
public KnownSerializersNode Parent => _parent;
4344

4445
// public methods
45-
public void AddKnownSerializer(Type type, IBsonSerializer serializer, bool allowPropagation = true)
46+
public void AddKnownSerializersFromChild(KnownSerializersNode child)
47+
{
48+
foreach (var type in child.KnownSerializers.Keys)
49+
foreach (var serializer in child.KnownSerializers[type])
50+
{
51+
AddKnownSerializer(type, serializer);
52+
}
53+
}
54+
55+
public void AddKnownSerializer(Type type, IBsonSerializer serializer)
4656
{
4757
if (!_knownSerializers.TryGetValue(type, out var set))
4858
{
@@ -51,15 +61,35 @@ public void AddKnownSerializer(Type type, IBsonSerializer serializer, bool allow
5161
}
5262

5363
set.Add(serializer);
64+
}
65+
66+
public void SetKnownSerializerForType(Type type, IBsonSerializer serializer)
67+
{
68+
if (serializer.ValueType != type)
69+
{
70+
throw new ArgumentException($"Serializer value type {serializer.ValueType} does not match expected type {type}.");
71+
}
72+
73+
_knownSerializers[type] = new HashSet<IBsonSerializer> { serializer };
74+
}
5475

55-
if (allowPropagation && ShouldPropagateKnownSerializerToParent())
76+
public void SetNodeSerializer(IBsonSerializer serializer)
77+
{
78+
if (serializer.ValueType != _expression.Type)
5679
{
57-
_parent.AddKnownSerializer(type, serializer);
80+
throw new ArgumentException($"Serializer value type {serializer.ValueType} does not match expression type {_expression.Type}.");
5881
}
82+
83+
_nodeSerializer = serializer;
5984
}
6085

6186
public HashSet<IBsonSerializer> GetPossibleSerializers(Type type)
6287
{
88+
if (_nodeSerializer != null && _nodeSerializer.ValueType == type)
89+
{
90+
return new HashSet<IBsonSerializer> { _nodeSerializer };
91+
}
92+
6393
var possibleSerializers = GetPossibleSerializersAtThisLevel(type);
6494
if (possibleSerializers.Count > 0)
6595
{
@@ -115,20 +145,5 @@ private HashSet<IBsonSerializer> GetPossibleSerializersAtThisLevel(Type type)
115145

116146
return possibleSerializers;
117147
}
118-
119-
private bool ShouldPropagateKnownSerializerToParent()
120-
{
121-
if (_parent == null)
122-
{
123-
return false;
124-
}
125-
126-
return _parent.Expression.NodeType switch
127-
{
128-
ExpressionType.MemberInit => false,
129-
ExpressionType.New => false,
130-
_ => true
131-
};
132-
}
133148
}
134149
}

src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/KnownSerializers/KnownSerializersRegistry.cs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,19 @@ public void Add(Expression expression, KnownSerializersNode knownSerializers)
4242
_registry.Add(expression, knownSerializers);
4343
}
4444

45-
public void AddKnownSerializer(Expression expression, IBsonSerializer knownSerializer, bool allowPropagation = true)
45+
public void SetNodeSerializer(Expression expression, IBsonSerializer nodeSerializer)
4646
{
47-
if (knownSerializer.ValueType != expression.Type)
47+
if (nodeSerializer.ValueType != expression.Type)
4848
{
49-
throw new ArgumentException($"Serializer value type {knownSerializer.ValueType} does not match expresion type {expression.Type}.", nameof(knownSerializer));
49+
throw new ArgumentException($"Serializer value type {nodeSerializer.ValueType} does not match expresion type {expression.Type}.", nameof(nodeSerializer));
5050
}
5151

5252
if (!_registry.TryGetValue(expression, out var knownSerializers))
5353
{
5454
throw new InvalidOperationException("KnownSerializersNode does not exist yet for expression: {expression}.");
5555
}
5656

57-
knownSerializers.AddKnownSerializer(expression.Type, knownSerializer, allowPropagation);
57+
knownSerializers.SetNodeSerializer(nodeSerializer);
5858
}
5959

6060
public IBsonSerializer GetSerializer(Expression expression, IBsonSerializer defaultSerializer = null)
@@ -74,6 +74,18 @@ public IBsonSerializer GetSerializer(Expression expression, Type type, IBsonSeri
7474
};
7575
}
7676

77+
public IBsonSerializer GetSerializerAtThisLevel(Expression expression)
78+
{
79+
var expressionType = expression is LambdaExpression lambdaExpression ? lambdaExpression.ReturnType : expression.Type;
80+
return GetSerializerAtThisLevel(expression, expressionType);
81+
}
82+
83+
public IBsonSerializer GetSerializerAtThisLevel(Expression expression, Type type)
84+
{
85+
var possibleSerializers = _registry.TryGetValue(expression, out var knownSerializers) ? knownSerializers.GetPossibleSerializers(type) : new HashSet<IBsonSerializer>();
86+
return possibleSerializers.Count == 1 ? possibleSerializers.Single() : null;
87+
}
88+
7789
private IBsonSerializer LookupSerializer(Expression expression, Type type)
7890
{
7991
if (type.IsConstructedGenericType &&

0 commit comments

Comments
 (0)