diff --git a/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs b/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs index fd5998c93b3..aab7c3bce7f 100644 --- a/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs +++ b/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs @@ -50,6 +50,38 @@ public static TValue Deserialize(this IBsonSerializer serializer return serializer.Deserialize(context, args); } + /// + /// Gets the serializer for a base type starting from a serializer for a derived type. + /// + /// The serializer for the derived type. + /// The base type. + /// The serializer for the base type. + public static IBsonSerializer GetBaseTypeSerializer(this IBsonSerializer serializer, Type baseType) + { + if (!baseType.IsAssignableFrom(serializer.ValueType)) + { + throw new ArgumentException($"{baseType} is not assignable from {serializer.ValueType}."); + } + + return BsonSerializer.LookupSerializer(baseType); // TODO: should be able to navigate from serializer + } + + /// + /// Gets the serializer for a derived type starting from a serializer for a base type. + /// + /// The serializer for the base type. + /// The derived type. + /// The serializer for the derived type. + public static IBsonSerializer GetDerivedTypeSerializer(this IBsonSerializer serializer, Type derivedType) + { + if (!serializer.ValueType.IsAssignableFrom(derivedType)) + { + throw new ArgumentException($"{serializer.ValueType} is not assignable from {derivedType}."); + } + + return BsonSerializer.LookupSerializer(derivedType); // TODO: should be able to navigate from serializer + } + /// /// Gets the discriminator convention for a serializer. /// diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs index db7618ce677..3879d911cc9 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs @@ -60,5 +60,17 @@ public static TValue GetConstantValue(this Expression expression, Expres var message = $"Expression must be a constant: {expression} in {containingExpression}."; throw new ExpressionNotSupportedException(message); } + + public static bool IsConvert(this Expression expression, out Expression operand) + { + if (expression is UnaryExpression unaryExpression && unaryExpression.NodeType == ExpressionType.Convert) + { + operand = unaryExpression.Operand; + return true; + } + + operand = null; + return false; + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs index 4d62eaea95c..38e93428c37 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs @@ -64,7 +64,8 @@ private AstStage RenderProjectStage( out IBsonSerializer outputSerializer) { var partiallyEvaluatedOutput = (Expression>)PartialEvaluator.EvaluatePartially(_output); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedOutput.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedOutput, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var outputTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedOutput, inputSerializer, asRoot: true); var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(outputTranslation); outputSerializer = (IBsonSerializer)projectSerializer; @@ -106,7 +107,8 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedGroupBy.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var valueSerializer = (IBsonSerializer)groupByTranslation.Serializer; @@ -150,7 +152,8 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer, TInput>> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedGroupBy.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var valueSerializer = (IBsonSerializer)groupByTranslation.Serializer; @@ -188,7 +191,8 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedGroupBy.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar); var groupBySerializer = (IBsonSerializer)groupByTranslation.Serializer; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/KnownSerializerFinder.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/KnownSerializerFinder.cs new file mode 100644 index 00000000000..5f30e573726 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/KnownSerializerFinder.cs @@ -0,0 +1,368 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using ExpressionVisitor = System.Linq.Expressions.ExpressionVisitor; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; + +internal class KnownSerializerFinder : ExpressionVisitor +{ + public static KnownSerializerMap FindKnownSerializers( + Expression expression) + { + var knownSerializers = new KnownSerializerMap(); + return FindKnownSerializers(expression, knownSerializers); + } + + public static KnownSerializerMap FindKnownSerializers( + Expression expression, + Expression initialNode, + IBsonSerializer initialSerializer) + { + var knownSerializers = new KnownSerializerMap(); + knownSerializers.AddSerializer(initialNode, initialSerializer); + return FindKnownSerializers(expression, knownSerializers); + } + + public static KnownSerializerMap FindKnownSerializers( + Expression expression, + KnownSerializerMap knownSerializers) + { + var finder = new KnownSerializerFinder(knownSerializers); + + int oldSerializerCount; + int newSerializerCount; + do + { + oldSerializerCount = finder._knownSerializers.Count; + finder.Visit(expression); + newSerializerCount = finder._knownSerializers.Count; + } + while (newSerializerCount > oldSerializerCount); // I don't know yet if this can be done in a single pass + + #if DEBUG + var expressionWithUnknownSerializer = UnknownSerializerFinder.FindExpressionWithUnknownSerializer(expression, knownSerializers); + if (expressionWithUnknownSerializer != null) + { + throw new ExpressionNotSupportedException(expressionWithUnknownSerializer, because: "unable to determine which serializer to use"); + } + #endif + + return knownSerializers; + } + + private readonly KnownSerializerMap _knownSerializers; + + public KnownSerializerFinder(KnownSerializerMap knownSerializers) + { + _knownSerializers = knownSerializers; + } + + protected override Expression VisitBinary(BinaryExpression node) + { + base.VisitBinary(node); + + var @operator = node.NodeType; + var left = node.Left; + var right = node.Right; + + if (IsSymmetricalBinaryOperator(@operator) && + CanDeduceSerializer(left, right, out var unknownNode, out var knownSerializer)) + { + // expr1 op expr2 => expr1: expr2Serializer or expr2: expr1Serializer + _knownSerializers.AddSerializer(unknownNode, knownSerializer); + return node; + } + + if (@operator == ExpressionType.ArrayIndex) + { + if (_knownSerializers.IsNotKnown(node) && + _knownSerializers.IsKnown(left, out var leftSerializer)) + { + if (leftSerializer is not IBsonSerializer arraySerializer) + { + throw new ExpressionNotSupportedException(node, because: $"serializer type {leftSerializer.GetType()} does not implement IBsonArraySerializer"); + } + + var itemSerializer = ArraySerializerHelper.GetItemSerializer(arraySerializer); + + // expr[index] => node: itemSerializer + _knownSerializers.AddSerializer(node, itemSerializer); + } + + return node; + } + + if (left.IsConvert(out var leftConvertOperand) && right.IsConvert(out var rightConvertOperand)) + { + if (CanDeduceSerializer(leftConvertOperand, rightConvertOperand, out unknownNode, out knownSerializer)) + { + // Convert(expr1, T) op Convert(expr2, T) => expr1: expr2Serializer or expr2: expr1Serializer + _knownSerializers.AddSerializer(unknownNode, knownSerializer); + } + + return node; + } + + return node; + + static bool IsSymmetricalBinaryOperator(ExpressionType @operator) + => @operator is + ExpressionType.Add or + ExpressionType.AddChecked or + ExpressionType.And or + ExpressionType.AndAlso or + ExpressionType.Divide or + ExpressionType.Equal or + ExpressionType.GreaterThan or + ExpressionType.GreaterThanOrEqual or + ExpressionType.Modulo or + ExpressionType.Multiply or + ExpressionType.MultiplyChecked or + ExpressionType.Or or + ExpressionType.OrElse or + ExpressionType.Subtract or + ExpressionType.SubtractChecked; + } + + protected override Expression VisitConstant(ConstantExpression node) + { + var value = node.Value; + + if (_knownSerializers.IsNotKnown(node) && + value is IQueryable queryable && + queryable.Provider is IMongoQueryProviderInternal provider && + queryable.Expression is ConstantExpression constantExpression && + constantExpression.Value == value) + { + var documentSerializer = provider.PipelineInputSerializer; + var queryableSerializer = QueryableSerializer.Create(itemSerializer: documentSerializer); + + // originalSource => node: new QueryableSerializer(documentSerializer) + _knownSerializers.AddSerializer(node, queryableSerializer); + } + + return node; + } + + protected override Expression VisitMember(MemberExpression node) + { + base.VisitMember(node); + + var containingExpression = node.Expression; + if (_knownSerializers.IsKnown(containingExpression, out var containingSerializer) && + _knownSerializers.IsNotKnown(node)) + { + // TODO: handle special cases like DateTime.Year etc. + + if (containingSerializer is not IBsonDocumentSerializer documentSerializer) + { + throw new ExpressionNotSupportedException(node, because: $"serializer type {containingSerializer.GetType()} does not implement the {nameof(IBsonDocumentSerializer)} interface"); + } + + var memberName = node.Member.Name; + if (!documentSerializer.TryGetMemberSerializationInfo(memberName, out var memberSerializationInfo)) + { + throw new ExpressionNotSupportedException(node, because: $"serializer type {containingSerializer.GetType()} does not support a member named: {memberName}"); + } + var memberSerializer = memberSerializationInfo.Serializer; + + // expr.Member => node: memberSerializer + _knownSerializers.AddSerializer(node, memberSerializer); + } + + return node; + } + + protected override Expression VisitMemberInit(MemberInitExpression node) + { + base.VisitMemberInit(node); + + if (_knownSerializers.IsKnown(node, out var newSerializer)) + { + if (newSerializer is not IBsonDocumentSerializer documentSerializer) + { + throw new ExpressionNotSupportedException(node, because: $"serializer type {newSerializer.GetType()} does not implement IBsonDocumentSerializer interface"); + } + + foreach (var binding in node.Bindings) + { + if (binding is MemberAssignment memberAssignment) + { + if (_knownSerializers.IsNotKnown(memberAssignment.Expression)) + { + var member = memberAssignment.Member; + var memberName = member.Name; + if (!documentSerializer.TryGetMemberSerializationInfo(memberName, out var memberSerializationInfo)) + { + throw new ExpressionNotSupportedException(node, because: $"type {member.DeclaringType} does not have a member named: {memberName}"); + } + var expressionSerializer = memberSerializationInfo.Serializer; + + if (expressionSerializer.ValueType != memberAssignment.Expression.Type && + expressionSerializer.ValueType.IsAssignableFrom(memberAssignment.Expression.Type)) + { + expressionSerializer = expressionSerializer.GetDerivedTypeSerializer(memberAssignment.Expression.Type); + } + + // member = expression => expression: memberSerializer (or derivedTypeSerializer) + _knownSerializers.AddSerializer(memberAssignment.Expression, expressionSerializer); + } + } + } + } + + return node; + } + + protected override Expression VisitMethodCall(MethodCallExpression node) + { + base.VisitMethodCall(node); + + var method = node.Method; + var arguments = node.Arguments; + + if (method.IsStatic && + arguments.Count >= 2 && + arguments[0] is var sourceExpression && + arguments[1] is LambdaExpression lambdaExpression && + sourceExpression.Type.ImplementsIEnumerable(out var sourceItemType) && + lambdaExpression.Parameters.Count == 1 && + lambdaExpression.Parameters[0] is var itemExpression && + itemExpression.Type == sourceItemType) + { + IBsonSerializer itemSerializer; + + if (_knownSerializers.IsKnown(sourceExpression, out var sourceSerializer) && + _knownSerializers.IsNotKnown(itemExpression)) + { + itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceSerializer); + + // source.method(item => ...) => item: itemSerializer + _knownSerializers.AddSerializer(itemExpression, itemSerializer); + } + + if (_knownSerializers.IsNotKnown(sourceExpression) && + _knownSerializers.IsKnown(itemExpression, out itemSerializer)) + { + sourceSerializer = BsonSerializer.LookupSerializer(sourceExpression.Type); // TODO: is it OK to use BsonSerializer registry? + if (sourceSerializer is IChildSerializerConfigurable childSerializerConfigurable) + { + sourceSerializer = childSerializerConfigurable.WithChildSerializer(itemSerializer); + + // source.method(item => ...) => source: sourceSerializer + _knownSerializers.AddSerializer(sourceExpression, sourceSerializer); + } + } + } + + return node; + } + + protected override Expression VisitNew(NewExpression node) + { + base.VisitNew(node); + + if (_knownSerializers.IsKnown(node, out var nodeSerializer) && + node.Arguments.Any(_knownSerializers.IsNotKnown)) + { + var matchingMemberSerializationInfos = nodeSerializer.GetMatchingMemberSerializationInfosForConstructorParameters(node, node.Constructor); + for (var i = 0; i < matchingMemberSerializationInfos.Count; i++) + { + var argumentExpression = node.Arguments[i]; + var matchingMemberSerializationInfo = matchingMemberSerializationInfos[i]; + + if (_knownSerializers.IsNotKnown(argumentExpression)) + { + // arg => arg: matchingMemberSerializer + _knownSerializers.AddSerializer(argumentExpression, matchingMemberSerializationInfo.Serializer); + } + } + } + + return node; + } + + protected override Expression VisitNewArray(NewArrayExpression node) + { + base.VisitNewArray(node); + + if (_knownSerializers.IsKnown(node, out var nodeSerializer)) + { + if (nodeSerializer is not IBsonArraySerializer arraySerializer) + { + throw new ExpressionNotSupportedException(node, because: $"serializer type {nodeSerializer.GetType()} does not implement IBsonArraySerializer"); + } + var itemSerializer = ArraySerializerHelper.GetItemSerializer(arraySerializer); + + foreach (var expression in node.Expressions) + { + if (_knownSerializers.IsNotKnown(expression)) + { + // new T[] { ..., expr, ... } => expr: itemSerializer + _knownSerializers.AddSerializer(expression, itemSerializer); + } + } + } + + return node; + } + + private bool CanDeduceSerializer(Expression node1, Expression node2, out Expression unknownNode, out IBsonSerializer knownSerializer) + { + if (_knownSerializers.IsKnown(node1, out var node1Serializer) && + _knownSerializers.IsNotKnown(node2)) + { + unknownNode = node2; + knownSerializer = node1Serializer; + return true; + } + + if (_knownSerializers.IsNotKnown(node1) && + _knownSerializers.IsKnown(node2, out var node2Serializer)) + { + unknownNode = node1; + knownSerializer = node2Serializer; + return true; + } + + unknownNode = null; + knownSerializer = null; + return false; + } +} + +internal class UnknownSerializerFinder : ExpressionVisitor +{ + public static Expression FindExpressionWithUnknownSerializer(Expression expression, KnownSerializerMap knownSerializers) + { + return null; // TODO: implement + } + + private Expression _expressionWithUnknownSerializer = null; + private readonly KnownSerializerMap _knownSerializers; + + public UnknownSerializerFinder(KnownSerializerMap knownSerializers) + { + _knownSerializers = knownSerializers; + } + + public Expression ExpressionWithUnknownSerialier => _expressionWithUnknownSerializer; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/KnownSerializerMap.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/KnownSerializerMap.cs new file mode 100644 index 00000000000..3f8be92dd80 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/KnownSerializerMap.cs @@ -0,0 +1,84 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; + +internal class KnownSerializerMap +{ + private readonly Dictionary _map = new(); + + public int Count => _map.Count; + + public void AddSerializer(Expression node, IBsonSerializer serializer) + { + if (serializer.ValueType != node.Type) + { + if (node.Type.IsAssignableFrom(serializer.ValueType)) + { + serializer = DowncastingSerializer.Create(baseType: node.Type, derivedType: serializer.ValueType, derivedTypeSerializer: serializer); + } + else if (serializer.ValueType.IsAssignableFrom(node.Type)) + { + serializer = UpcastingSerializer.Create(baseType: serializer.ValueType, derivedType: node.Type, baseTypeSerializer: serializer); + } + else + { + throw new ArgumentException($"Serializer value type {serializer.ValueType} does not match expression value type {node.Type}", nameof(serializer)); + } + } + + if (_map.TryGetValue(node, out var existingSerializer)) + { + throw new ExpressionNotSupportedException( + node, + because: $"there are duplicate known serializers for expression '{node}': {serializer.GetType()} and {existingSerializer.GetType()}"); + } + + _map.Add(node, serializer); + } + + public IBsonSerializer GetSerializer(Expression node) + { + if (_map.TryGetValue(node, out var knownSerializer)) + { + return knownSerializer; + } + + throw new ExpressionNotSupportedException(node, because: "unable to determine which serializer to use"); + } + + public bool IsNotKnown(Expression node) + { + return !IsKnown(node); + } + + public bool IsKnown(Expression node) + { + return _map.ContainsKey(node); + } + + public bool IsKnown(Expression node, out IBsonSerializer serializer) + { + return _map.TryGetValue(node, out serializer); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs index fe96bacae36..a868f3a8508 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs @@ -41,7 +41,7 @@ internal class MongoQuery : MongoQuery, IOrderedQue public MongoQuery(MongoQueryProvider provider) { _provider = provider; - _expression = Expression.Constant(this); + _expression = Expression.Constant(this, typeof(IQueryable<>).MakeGenericType(typeof(TDocument))); } public MongoQuery(MongoQueryProvider provider, Expression expression) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IBsonSerializerExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IBsonSerializerExtensions.cs new file mode 100644 index 00000000000..9d8be120004 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IBsonSerializerExtensions.cs @@ -0,0 +1,70 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson.Serialization; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class IBsonSerializerExtensions +{ + public static IReadOnlyList GetMatchingMemberSerializationInfosForConstructorParameters( + this IBsonSerializer serializer, + Expression expression, + ConstructorInfo constructorInfo) + { + if (serializer is not IBsonDocumentSerializer documentSerializer) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer type {serializer.GetType().Name} does not implement IBsonDocumentSerializer"); + } + + var matchingMemberSerializationInfos = new List(); + foreach (var constructorParameter in constructorInfo.GetParameters()) + { + var matchingMemberSerializationInfo = GetMatchingMemberSerializationInfo(expression, documentSerializer, constructorParameter.Name); + matchingMemberSerializationInfos.Add(matchingMemberSerializationInfo); + } + + return matchingMemberSerializationInfos; + + static BsonSerializationInfo GetMatchingMemberSerializationInfo( + Expression expression, + IBsonDocumentSerializer documentSerializer, + string constructorParameterName) + { + var possibleMatchingMembers = documentSerializer.ValueType.GetMembers().Where(m => m.Name.Equals(constructorParameterName, StringComparison.OrdinalIgnoreCase)).ToArray(); + if (possibleMatchingMembers.Length == 0) + { + throw new ExpressionNotSupportedException(expression, because: $"no matching member found for constructor parameter: {constructorParameterName}"); + } + if (possibleMatchingMembers.Length > 1) + { + throw new ExpressionNotSupportedException(expression, because: $"multiple possible matching members found for constructor parameter: {constructorParameterName}"); + } + var matchingMemberName = possibleMatchingMembers[0].Name; + + if (!documentSerializer.TryGetMemberSerializationInfo(matchingMemberName, out var matchingMemberSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer of type {documentSerializer.GetType().Name} did not provide serialization info for member {matchingMemberName}"); + } + + return matchingMemberSerializationInfo; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/QueryableSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/QueryableSerializer.cs new file mode 100644 index 00000000000..2acef5c2e7e --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/QueryableSerializer.cs @@ -0,0 +1,46 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Linq; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class QueryableSerializer +{ + public static IBsonSerializer Create(IBsonSerializer itemSerializer) + { + var serializerType = typeof(QueryableSerializer<>).MakeGenericType(itemSerializer.ValueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType, itemSerializer); + } +} + +internal class QueryableSerializer : SerializerBase>, IBsonArraySerializer +{ + private readonly IBsonSerializer _itemSerializer; + + public QueryableSerializer(IBsonSerializer itemSerializer) + { + _itemSerializer = itemSerializer; + } + + public bool TryGetItemSerializationInfo(out BsonSerializationInfo serializationInfo) + { + serializationInfo = new(null, _itemSerializer, typeof(T)); + return true; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UpcastingSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UpcastingSerializer.cs new file mode 100644 index 00000000000..e2843cb8602 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UpcastingSerializer.cs @@ -0,0 +1,92 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers +{ + internal static class UpcastingSerializer + { + public static IBsonSerializer Create( + Type baseType, + Type derivedType, + IBsonSerializer baseTypeSerializer) + { + var upcastingSerializerType = typeof(UpcastingSerializer<,>).MakeGenericType(baseType, derivedType); + return (IBsonSerializer)Activator.CreateInstance(upcastingSerializerType, baseTypeSerializer); + } + } + + internal sealed class UpcastingSerializer : SerializerBase, IBsonArraySerializer, IBsonDocumentSerializer + where TDerived : TBase + { + private readonly IBsonSerializer _baseTypeSerializer; + + public UpcastingSerializer(IBsonSerializer baseTypeSerializer) + { + _baseTypeSerializer = baseTypeSerializer ?? throw new ArgumentNullException(nameof(baseTypeSerializer)); + } + + public Type BaseType => typeof(TBase); + + public IBsonSerializer BaseTypeSerializer => _baseTypeSerializer; + + public Type DerivedType => typeof(TDerived); + + public override TDerived Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + return (TDerived)_baseTypeSerializer.Deserialize(context); + } + + public override bool Equals(object obj) + { + if (object.ReferenceEquals(obj, null)) { return false; } + if (object.ReferenceEquals(this, obj)) { return true; } + return + base.Equals(obj) && + obj is UpcastingSerializer other && + object.Equals(_baseTypeSerializer, other._baseTypeSerializer); + } + + public override int GetHashCode() => 0; + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TDerived value) + { + _baseTypeSerializer.Serialize(context, value); + } + + public bool TryGetItemSerializationInfo(out BsonSerializationInfo serializationInfo) + { + if (_baseTypeSerializer is not IBsonArraySerializer arraySerializer) + { + throw new NotSupportedException($"The class {_baseTypeSerializer.GetType().FullName} does not implement IBsonArraySerializer."); + } + + return arraySerializer.TryGetItemSerializationInfo(out serializationInfo); + } + + public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) + { + if (_baseTypeSerializer is not IBsonDocumentSerializer documentSerializer) + { + throw new NotSupportedException($"The class {_baseTypeSerializer.GetType().FullName} does not implement IBsonDocumentSerializer."); + } + + return documentSerializer.TryGetMemberSerializationInfo(memberName, out serializationInfo); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs index f3bb40aaf3a..c66f84b213e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs @@ -98,6 +98,20 @@ public bool TryGetItemSerializationInfo(out BsonSerializationInfo serializationI public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) { + if (_valueSerializer is IBsonDocumentSerializer documentSerializer) + { + if (documentSerializer.TryGetMemberSerializationInfo(memberName, out serializationInfo)) + { + serializationInfo = BsonSerializationInfo.CreateWithPath( + [_fieldName, serializationInfo.ElementName], + serializationInfo.Serializer, + serializationInfo.NominalType); + return true; + } + + return false; + } + throw new InvalidOperationException(); } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs index 7487627213d..cf6698e46ec 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs @@ -23,10 +23,14 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class ConstantExpressionToAggregationExpressionTranslator { - public static TranslatedExpression Translate(ConstantExpression constantExpression) + public static TranslatedExpression Translate(TranslationContext context, ConstantExpression constantExpression) { var constantType = constantExpression.Type; - var constantSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType); + // TODO: throw if serializer is not known? + if (!context.KnownSerializers.IsKnown(constantExpression, out var constantSerializer)) + { + constantSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType); + } return Translate(constantExpression, constantSerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs index c2d8e0010e9..5eeb2857f9a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs @@ -67,7 +67,7 @@ public static TranslatedExpression Translate(TranslationContext context, Express case ExpressionType.Conditional: return ConditionalExpressionToAggregationExpressionTranslator.Translate(context, (ConditionalExpression)expression); case ExpressionType.Constant: - return ConstantExpressionToAggregationExpressionTranslator.Translate((ConstantExpression)expression); + return ConstantExpressionToAggregationExpressionTranslator.Translate(context, (ConstantExpression)expression); case ExpressionType.Index: return IndexExpressionToAggregationExpressionTranslator.Translate(context, (IndexExpression)expression); case ExpressionType.ListInit: diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs index 20f7e81312c..03bf3afeb9f 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs @@ -13,16 +13,9 @@ * limitations under the License. */ -using System; using System.Collections.Generic; -using System.Linq; using System.Linq.Expressions; -using System.Reflection; using MongoDB.Bson; -using MongoDB.Bson.Serialization; -using MongoDB.Driver.Linq.Linq3Implementation.Ast; -using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; -using MongoDB.Driver.Linq.Linq3Implementation.Misc; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -44,168 +37,10 @@ public static TranslatedExpression Translate( NewExpression newExpression, IReadOnlyList bindings) { - var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct - var constructorArguments = newExpression.Arguments; - var computedFields = new List(); - var classMap = CreateClassMap(newExpression.Type, constructorInfo, out var creatorMap); - - if (constructorInfo != null && creatorMap != null) - { - var constructorParameters = constructorInfo.GetParameters(); - var creatorMapParameters = creatorMap.Arguments?.ToArray(); - if (constructorParameters.Length > 0) - { - if (creatorMapParameters == null) - { - throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching properties for constructor parameters."); - } - if (creatorMapParameters.Length != constructorParameters.Length) - { - throw new ExpressionNotSupportedException(expression, because: $"the constructor has {constructorParameters} parameters but the creatorMap has {creatorMapParameters.Length} parameters."); - } - - for (var i = 0; i < creatorMapParameters.Length; i++) - { - var creatorMapParameter = creatorMapParameters[i]; - var constructorArgumentExpression = constructorArguments[i]; - var constructorArgumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, constructorArgumentExpression); - var constructorArgumentType = constructorArgumentExpression.Type; - var constructorArgumentSerializer = constructorArgumentTranslation.Serializer ?? BsonSerializer.LookupSerializer(constructorArgumentType); - var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter); - EnsureDefaultValue(memberMap); - var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, constructorArgumentSerializer); - memberMap.SetSerializer(memberSerializer); - computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, constructorArgumentTranslation.Ast)); - } - } - } - - foreach (var binding in bindings) - { - var memberAssignment = (MemberAssignment)binding; - var member = memberAssignment.Member; - var memberMap = FindMemberMap(expression, classMap, member.Name); - var valueExpression = memberAssignment.Expression; - var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); - var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, valueTranslation.Serializer); - memberMap.SetSerializer(memberSerializer); - computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, valueTranslation.Ast)); - } - - var ast = AstExpression.ComputedDocument(computedFields); - classMap.Freeze(); - var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(newExpression.Type); - var serializer = (IBsonSerializer)Activator.CreateInstance(serializerType, classMap); - - return new TranslatedExpression(expression, ast, serializer); - } - - private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap) - { - BsonClassMap baseClassMap = null; - if (classType.BaseType != null) - { - baseClassMap = CreateClassMap(classType.BaseType, null, out _); - } - - var classMapType = typeof(BsonClassMap<>).MakeGenericType(classType); - var classMapConstructorInfo = classMapType.GetConstructor(new Type[] { typeof(BsonClassMap) }); - var classMap = (BsonClassMap)classMapConstructorInfo.Invoke(new object[] { baseClassMap }); - if (constructorInfo != null) - { - creatorMap = classMap.MapConstructor(constructorInfo); - } - else - { - creatorMap = null; - } - - classMap.AutoMap(); - classMap.IdMemberMap?.SetElementName("_id"); // normally happens when Freeze is called but we need it sooner here - - return classMap; - } - - private static IBsonSerializer CoerceSourceSerializerToMemberSerializer(BsonMemberMap memberMap, IBsonSerializer sourceSerializer) - { - var memberType = memberMap.MemberType; - var memberSerializer = memberMap.GetSerializer(); - var sourceType = sourceSerializer.ValueType; - - if (memberType != sourceType && - memberType.ImplementsIEnumerable(out var memberItemType) && - sourceType.ImplementsIEnumerable(out var sourceItemType) && - sourceItemType == memberItemType && - sourceSerializer is IBsonArraySerializer sourceArraySerializer && - sourceArraySerializer.TryGetItemSerializationInfo(out var sourceItemSerializationInfo) && - memberSerializer is IChildSerializerConfigurable memberChildSerializerConfigurable) - { - var sourceItemSerializer = sourceItemSerializationInfo.Serializer; - return memberChildSerializerConfigurable.WithChildSerializer(sourceItemSerializer); - } - - return sourceSerializer; - } - - private static BsonMemberMap EnsureMemberMap(Expression expression, BsonClassMap classMap, MemberInfo creatorMapParameter) - { - var declaringClassMap = classMap; - while (declaringClassMap.ClassType != creatorMapParameter.DeclaringType) - { - declaringClassMap = declaringClassMap.BaseClassMap; - - if (declaringClassMap == null) - { - throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching property for constructor parameter: {creatorMapParameter.Name}"); - } - } - - foreach (var memberMap in declaringClassMap.DeclaredMemberMaps) - { - if (MemberMapMatchesCreatorMapParameter(memberMap, creatorMapParameter)) - { - return memberMap; - } - } - - return declaringClassMap.MapMember(creatorMapParameter); - - static bool MemberMapMatchesCreatorMapParameter(BsonMemberMap memberMap, MemberInfo creatorMapParameter) - { - var memberInfo = memberMap.MemberInfo; - return - memberInfo.MemberType == creatorMapParameter.MemberType && - memberInfo.Name.Equals(creatorMapParameter.Name, StringComparison.OrdinalIgnoreCase); - } - } - - private static void EnsureDefaultValue(BsonMemberMap memberMap) - { - if (memberMap.IsDefaultValueSpecified) - { - return; - } - - var defaultValue = memberMap.MemberType.IsValueType ? Activator.CreateInstance(memberMap.MemberType) : null; - memberMap.SetDefaultValue(defaultValue); - } - - private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName) - { - foreach (var memberMap in classMap.DeclaredMemberMaps) - { - if (memberMap.MemberName == memberName) - { - return memberMap; - } - } - - if (classMap.BaseClassMap != null) - { - return FindMemberMap(expression, classMap.BaseClassMap, memberName); - } - - throw new ExpressionNotSupportedException(expression, because: $"can't find member map: {memberName}"); + return + context.KnownSerializers.IsKnown(expression, out var knownSerializer) ? + MemberInitExpressionWithKnownSerializerToAggregationExpressionTranslator.TranslateWithKnownSerializer(context, expression, newExpression, bindings, knownSerializer) : + MemberInitExpressionWithoutKnownSerializerToAggregationExpressionTranslator.TranslateWithoutKnownSerializer(context, expression, newExpression, bindings); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionWithKnownSerializerToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionWithKnownSerializerToAggregationExpressionTranslator.cs new file mode 100644 index 00000000000..60528e7e894 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionWithKnownSerializerToAggregationExpressionTranslator.cs @@ -0,0 +1,92 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Ast; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators +{ + internal static class MemberInitExpressionWithKnownSerializerToAggregationExpressionTranslator + { + public static TranslatedExpression TranslateWithKnownSerializer( + TranslationContext context, + Expression expression, + NewExpression newExpression, + IReadOnlyList bindings, + IBsonSerializer knownSerializer) + { + var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct + var constructorArguments = newExpression.Arguments; + + var computedFields = new List(); + if (constructorInfo != null && constructorArguments.Count > 0) + { + var matchingMemberSerializationInfos = knownSerializer.GetMatchingMemberSerializationInfosForConstructorParameters(expression, constructorInfo); + + for (var i = 0; i < constructorArguments.Count; i++) + { + var argument = constructorArguments[i]; + var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argument); + var matchingMemberSerializationInfo = matchingMemberSerializationInfos[i]; + + if (!argumentTranslation.Serializer.Equals(matchingMemberSerializationInfo.Serializer)) + { + throw new ExpressionNotSupportedException(argument, expression, because: "argument serializer is not equal to member serializer"); + } + + var computedField = AstExpression.ComputedField(matchingMemberSerializationInfo.ElementName, argumentTranslation.Ast); + computedFields.Add(computedField); + } + } + + if (bindings.Count > 0) + { + if (knownSerializer is not IBsonDocumentSerializer documentSerializer) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer type {knownSerializer.GetType()} does not implement IBsonDocumentSerializer"); + } + + foreach (var binding in bindings) + { + var memberAssignment = (MemberAssignment)binding; + var member = memberAssignment.Member; + + if (!documentSerializer.TryGetMemberSerializationInfo(member.Name, out var memberSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"member {member.Name} was not found"); + } + + var valueExpression = memberAssignment.Expression; + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); + + if (!valueTranslation.Serializer.Equals(memberSerializationInfo.Serializer)) + { + throw new ExpressionNotSupportedException(valueExpression, expression, because: $"value serializer is not equal to serializer for member {member.Name}"); + } + + var computedField = AstExpression.ComputedField(memberSerializationInfo.ElementName, valueTranslation.Ast); + computedFields.Add(computedField); + } + } + + var ast = AstExpression.ComputedDocument(computedFields); + return new TranslatedExpression(expression, ast, knownSerializer); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionWithoutKnownSerializerToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionWithoutKnownSerializerToAggregationExpressionTranslator.cs new file mode 100644 index 00000000000..609a5b22680 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionWithoutKnownSerializerToAggregationExpressionTranslator.cs @@ -0,0 +1,201 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Ast; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators +{ + internal static class MemberInitExpressionWithoutKnownSerializerToAggregationExpressionTranslator + { + public static TranslatedExpression TranslateWithoutKnownSerializer( + TranslationContext context, + Expression expression, + NewExpression newExpression, + IReadOnlyList bindings) + { + var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct + var constructorArguments = newExpression.Arguments; + var computedFields = new List(); + var classMap = CreateClassMap(newExpression.Type, constructorInfo, out var creatorMap); + + if (constructorInfo != null && creatorMap != null) + { + var constructorParameters = constructorInfo.GetParameters(); + var creatorMapParameters = creatorMap.Arguments?.ToArray(); + if (constructorParameters.Length > 0) + { + if (creatorMapParameters == null) + { + throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching properties for constructor parameters."); + } + if (creatorMapParameters.Length != constructorParameters.Length) + { + throw new ExpressionNotSupportedException(expression, because: $"the constructor has {constructorParameters} parameters but the creatorMap has {creatorMapParameters.Length} parameters."); + } + + for (var i = 0; i < creatorMapParameters.Length; i++) + { + var creatorMapParameter = creatorMapParameters[i]; + var constructorArgumentExpression = constructorArguments[i]; + var constructorArgumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, constructorArgumentExpression); + var constructorArgumentType = constructorArgumentExpression.Type; + var constructorArgumentSerializer = constructorArgumentTranslation.Serializer ?? BsonSerializer.LookupSerializer(constructorArgumentType); + var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter); + EnsureDefaultValue(memberMap); + var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, constructorArgumentSerializer); + memberMap.SetSerializer(memberSerializer); + computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, constructorArgumentTranslation.Ast)); + } + } + } + + foreach (var binding in bindings) + { + var memberAssignment = (MemberAssignment)binding; + var member = memberAssignment.Member; + var memberMap = FindMemberMap(expression, classMap, member.Name); + var valueExpression = memberAssignment.Expression; + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); + var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, valueTranslation.Serializer); + memberMap.SetSerializer(memberSerializer); + computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, valueTranslation.Ast)); + } + + var ast = AstExpression.ComputedDocument(computedFields); + classMap.Freeze(); + var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(newExpression.Type); + var serializer = (IBsonSerializer)Activator.CreateInstance(serializerType, classMap); + + return new TranslatedExpression(expression, ast, serializer); + } + + private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap) + { + BsonClassMap baseClassMap = null; + if (classType.BaseType != null) + { + baseClassMap = CreateClassMap(classType.BaseType, null, out _); + } + + var classMapType = typeof(BsonClassMap<>).MakeGenericType(classType); + var classMapConstructorInfo = classMapType.GetConstructor(new Type[] { typeof(BsonClassMap) }); + var classMap = (BsonClassMap)classMapConstructorInfo.Invoke(new object[] { baseClassMap }); + if (constructorInfo != null) + { + creatorMap = classMap.MapConstructor(constructorInfo); + } + else + { + creatorMap = null; + } + + classMap.AutoMap(); + classMap.IdMemberMap?.SetElementName("_id"); // normally happens when Freeze is called but we need it sooner here + + return classMap; + } + + private static IBsonSerializer CoerceSourceSerializerToMemberSerializer(BsonMemberMap memberMap, IBsonSerializer sourceSerializer) + { + var memberType = memberMap.MemberType; + var memberSerializer = memberMap.GetSerializer(); + var sourceType = sourceSerializer.ValueType; + + if (memberType != sourceType && + memberType.ImplementsIEnumerable(out var memberItemType) && + sourceType.ImplementsIEnumerable(out var sourceItemType) && + sourceItemType == memberItemType && + sourceSerializer is IBsonArraySerializer sourceArraySerializer && + sourceArraySerializer.TryGetItemSerializationInfo(out var sourceItemSerializationInfo) && + memberSerializer is IChildSerializerConfigurable memberChildSerializerConfigurable) + { + var sourceItemSerializer = sourceItemSerializationInfo.Serializer; + return memberChildSerializerConfigurable.WithChildSerializer(sourceItemSerializer); + } + + return sourceSerializer; + } + + private static BsonMemberMap EnsureMemberMap(Expression expression, BsonClassMap classMap, MemberInfo creatorMapParameter) + { + var declaringClassMap = classMap; + while (declaringClassMap.ClassType != creatorMapParameter.DeclaringType) + { + declaringClassMap = declaringClassMap.BaseClassMap; + + if (declaringClassMap == null) + { + throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching property for constructor parameter: {creatorMapParameter.Name}"); + } + } + + foreach (var memberMap in declaringClassMap.DeclaredMemberMaps) + { + if (MemberMapMatchesCreatorMapParameter(memberMap, creatorMapParameter)) + { + return memberMap; + } + } + + return declaringClassMap.MapMember(creatorMapParameter); + + static bool MemberMapMatchesCreatorMapParameter(BsonMemberMap memberMap, MemberInfo creatorMapParameter) + { + var memberInfo = memberMap.MemberInfo; + return + memberInfo.MemberType == creatorMapParameter.MemberType && + memberInfo.Name.Equals(creatorMapParameter.Name, StringComparison.OrdinalIgnoreCase); + } + } + + private static void EnsureDefaultValue(BsonMemberMap memberMap) + { + if (memberMap.IsDefaultValueSpecified) + { + return; + } + + var defaultValue = memberMap.MemberType.IsValueType ? Activator.CreateInstance(memberMap.MemberType) : null; + memberMap.SetDefaultValue(defaultValue); + } + + private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName) + { + foreach (var memberMap in classMap.DeclaredMemberMaps) + { + if (memberMap.MemberName == memberName) + { + return memberMap; + } + } + + if (classMap.BaseClassMap != null) + { + return FindMemberMap(expression, classMap.BaseClassMap, memberName); + } + + throw new ExpressionNotSupportedException(expression, because: $"can't find member map: {memberName}"); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs index b96a193e323..29f929085d2 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs @@ -31,7 +31,7 @@ public static ExecutableQuery> Translate TranslateScalar _data; + public KnownSerializerMap KnownSerializers => _knownSerializers; public NameGenerator NameGenerator => _nameGenerator; public SymbolTable SymbolTable => _symbolTable; public ExpressionTranslationOptions TranslationOptions => _translationOptions; @@ -124,7 +149,7 @@ public TranslationContext WithSymbols(params Symbol[] newSymbols) public TranslationContext WithSymbolTable(SymbolTable symbolTable) { - return new TranslationContext(_translationOptions, _data, symbolTable, _nameGenerator); + return new TranslationContext(_translationOptions, _knownSerializers, _data, symbolTable, _nameGenerator); } } } diff --git a/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs b/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs index 67ca25b4261..c1abab9bfdd 100644 --- a/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs +++ b/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs @@ -61,7 +61,8 @@ internal static BsonValue TranslateExpressionToAggregateExpression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(translationOptions, contextData); + var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: sourceSerializer, translationOptions: translationOptions, data: contextData); var translation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, expression, sourceSerializer, asRoot: true); var simplifiedAst = AstSimplifier.Simplify(translation.Ast); @@ -76,7 +77,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField( { expression = (LambdaExpression)PartialEvaluator.EvaluatePartially(expression); var parameter = expression.Parameters.Single(); - var context = TranslationContext.Create(translationOptions); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: documentSerializer, translationOptions: translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var body = RemovePossibleConvertToObject(expression.Body); @@ -106,7 +107,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField>)PartialEvaluator.EvaluatePartially(expression); var parameter = expression.Parameters.Single(); - var context = TranslationContext.Create(translationOptions); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: documentSerializer, translationOptions: translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var fieldTranslation = ExpressionToFilterFieldTranslator.Translate(context, expression.Body); @@ -125,8 +126,8 @@ internal static BsonDocument TranslateExpressionToElemMatchFilter( ExpressionTranslationOptions translationOptions) { expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(translationOptions); var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: elementSerializer, translationOptions: translationOptions); var symbol = context.CreateSymbol(parameter, "@", elementSerializer); // @ represents the implied element context = context.WithSingleSymbol(symbol); // @ is the only symbol visible inside an $elemMatch var filter = ExpressionToFilterTranslator.Translate(context, expression.Body, exprOk: false); @@ -142,7 +143,8 @@ internal static BsonDocument TranslateExpressionToFilter( ExpressionTranslationOptions translationOptions) { expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(translationOptions); + var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: documentSerializer, translationOptions: translationOptions); var filter = ExpressionToFilterTranslator.TranslateLambda(context, expression, documentSerializer, asRoot: true); filter = AstSimplifier.SimplifyAndConvert(filter); @@ -176,7 +178,8 @@ private static RenderedProjectionDefinition TranslateExpressionToProjec } expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(translationOptions); + var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var simplifier = forFind ? new AstFindProjectionSimplifier() : new AstSimplifier(); try @@ -215,8 +218,19 @@ internal static BsonDocument TranslateExpressionToSetStage( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - var context = TranslationContext.Create(translationOptions); // do not partially evaluate expression var parameter = expression.Parameters.Single(); + var body = expression.Body; + + var knownSerializers = new KnownSerializerMap(); + knownSerializers.AddSerializer(parameter, documentSerializer); + if (body is MemberInitExpression memberInitExpression && + memberInitExpression.Type == typeof(TDocument)) + { + knownSerializers.AddSerializer(memberInitExpression, documentSerializer); + } + KnownSerializerFinder.FindKnownSerializers(expression, knownSerializers); + + var context = TranslationContext.Create(translationOptions, knownSerializers); // do not partially evaluate expression var symbol = context.CreateRootSymbol(parameter, documentSerializer); context = context.WithSymbol(symbol); var setStage = ExpressionToSetStageTranslator.Translate(context, documentSerializer, expression); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4593Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4593Tests.cs new file mode 100644 index 00000000000..c04aa5c0116 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4593Tests.cs @@ -0,0 +1,147 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Linq; +using System.Linq.Expressions; +using FluentAssertions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4593Tests : LinqIntegrationTest +{ + public CSharp4593Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void First_example_should_work() + { + var collection = Fixture.Orders; + + var find = collection + .Find(o => o.RateBasisHistoryId == "abc") + .Project(r => r.Id); + + var translatedFilter = TranslateFindFilter(collection, find); + translatedFilter.Should().Be("{ RateBasisHistoryId : 'abc' }"); + + var translatedProjection = TranslateFindProjection(collection, find); + translatedProjection.Should().Be("{ _id : 1 }"); + + var result = find.Single(); + result.Should().Be("a"); + } + + [Fact] + public void First_example_workaround_should_work() + { + var collection = Fixture.Orders; + + var find = collection + .Find(o => o.RateBasisHistoryId == "abc") + .Project(Builders.Projection.Include(o => o.Id)); + + var translatedFilter = TranslateFindFilter(collection, find); + translatedFilter.Should().Be("{ RateBasisHistoryId : 'abc' }"); + + var translatedProjection = TranslateFindProjection(collection, find); + translatedProjection.Should().Be("{ _id : 1 }"); + + var result = find.Single(); + result["_id"].AsString.Should().Be("a"); + } + + [Fact] + public void Second_example_should_work() + { + var collection = Fixture.Entities; + var idsFilter = Builders.Filter.Eq(x => x.Id, 1); + + var aggregate = collection.Aggregate() + .Match(idsFilter) + .Project(e => new + { + _id = e.Id, + CampaignId = e.CampaignId, + Accepted = e.Status.Key == "Accepted" ? 1 : 0, + Rejected = e.Status.Key == "Rejected" ? 1 : 0, + }); + + var stages = Translate(collection, aggregate); + AssertStages( + stages, + "{ $match : { _id : 1 } }", + """ + { $project : + { + _id : "$_id", + CampaignId : "$CampaignId", + Accepted : { $cond : { if : { $eq : ["$Status.Key", "Accepted"] }, then : 1, else : 0 } }, + Rejected : { $cond : { if : { $eq : ["$Status.Key", "Rejected"] }, then : 1, else : 0 } } + } + } + """); + + var results = aggregate.ToList(); + results.Count.Should().Be(1); + results[0]._id.Should().Be(1); + results[0].CampaignId.Should().Be(11); + results[0].Accepted.Should().Be(1); + results[0].Rejected.Should().Be(0); + } + + public class Order + { + public string Id { get; set; } + public string RateBasisHistoryId { get; set; } + } + + public class Entity + { + public int Id { get; set; } + public int CampaignId { get; set; } + public Status Status { get; set; } + } + + public class Status + { + public string Key { get; set; } + } + + public sealed class ClassFixture : MongoDatabaseFixture + { + public IMongoCollection Orders { get; private set; } + public IMongoCollection Entities { get; private set; } + + protected override void InitializeFixture() + { + Orders = CreateCollection("orders"); + Orders.InsertMany( + [ + new Order { Id = "a", RateBasisHistoryId = "abc" } + ]); + + Entities = CreateCollection("entities"); + Entities.InsertMany( + [ + new Entity { Id = 1, CampaignId = 11, Status = new Status { Key = "Accepted" } }, + new Entity { Id = 2, CampaignId = 22, Status = new Status { Key = "Rejected" } } + ]); + } + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4819Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4819Tests.cs new file mode 100644 index 00000000000..9f8f49eff4e --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4819Tests.cs @@ -0,0 +1,68 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4819Tests : LinqIntegrationTest +{ + public CSharp4819Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void ReplaceWith_should_use_configured_element_name() + { + var collection = Fixture.Collection; + var stage = PipelineStageDefinitionBuilder + .ReplaceWith((User u) => new User { UserId = u.UserId }); + + var aggregate = collection.Aggregate() + .AppendStage(stage); + + var stages = Translate(collection, aggregate); + AssertStages( + stages, + "{ $replaceWith : { uuid : '$uuid' } }"); + + var result = aggregate.Single(); + result.Id.Should().Be(0); + result.UserId.Should().Be(Guid.Parse("00112233-4455-6677-8899-aabbccddeeff")); + } + + public class User + { + public int Id { get; set; } + [BsonElement("uuid")] + [BsonGuidRepresentation(GuidRepresentation.Standard)] + public Guid UserId { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new User { Id = 1, UserId = Guid.Parse("00112233-4455-6677-8899-aabbccddeeff") } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4820Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4820Tests.cs new file mode 100644 index 00000000000..ba690ef3aa6 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4820Tests.cs @@ -0,0 +1,116 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Linq; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.TestHelpers.XunitExtensions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4820Tests : LinqIntegrationTest +{ + public CSharp4820Tests(ClassFixture fixture) + : base(fixture) + { + } + + static CSharp4820Tests() + { + BsonClassMap.RegisterClassMap(cm => + { + cm.AutoMap(); + var readonlyCollectionMemberMap = cm.GetMemberMap(x => x.ReadOnlyCollection); + var readOnlyCollectionSerializer = readonlyCollectionMemberMap.GetSerializer(); + var bracketingCollectionSerializer = ((IChildSerializerConfigurable)readOnlyCollectionSerializer).WithChildSerializer(new StringBracketingSerializer()); + readonlyCollectionMemberMap.SetSerializer(bracketingCollectionSerializer); + }); + } + + [Fact] + public void Update_Set_with_List_should_work() + { + var values = new List() { "abc", "def" }; + var update = Builders.Update.Set(x => x.ReadOnlyCollection, values); + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + + var rendered = (BsonDocument)update.Render(new (documentSerializer, serializerRegistry)); + + rendered.Should().Be("{ $set : { ReadOnlyCollection : ['[abc]', '[def]'] } }"); + } + + [Fact] + public void Update_Set_with_Enumerable_should_throw() + { + var values = new[] { "abc", "def" }.Select(x => x); + var update = Builders.Update.Set(x => x.ReadOnlyCollection, values); + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + + var rendered = (BsonDocument)update.Render(new (documentSerializer, serializerRegistry)); + + rendered.Should().Be("{ $set : { ReadOnlyCollection : ['[abc]', '[def]'] } }"); + } + + [Fact] + public void Update_Set_with_Enumerable_ToList_should_work() + { + var values = new[] { "abc", "def" }.Select(x => x); + var update = Builders.Update.Set(x => x.ReadOnlyCollection, values.ToList()); + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + + var rendered = (BsonDocument)update.Render(new (documentSerializer, serializerRegistry)); + + rendered.Should().Be("{ $set : { ReadOnlyCollection : ['[abc]', '[def]'] } }"); + } + + public class C + { + public int Id { get; set; } + public IReadOnlyCollection ReadOnlyCollection { get; set; } + } + + + private class StringBracketingSerializer : SerializerBase + { + public override string Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var bracketedValue = StringSerializer.Instance.Deserialize(context, args); + return bracketedValue.Substring(1, bracketedValue.Length - 2); + } + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, string value) + { + var bracketedValue = "[" + value + "]"; + StringSerializer.Instance.Serialize(context, bracketedValue); + } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => null; + // [ + // new C { } + // ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4967Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4967Tests.cs new file mode 100644 index 00000000000..527bc03f27e --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4967Tests.cs @@ -0,0 +1,78 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4967Tests : LinqIntegrationTest +{ + public CSharp4967Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void Set_Nested_should_work() + { + var collection = Fixture.Collection; + var update = Builders.Update + .Pipeline(new EmptyPipelineDefinition() + .Set(c => new MyDocument + { + Nested = new MyNestedDocument + { + ValueCopy = c.Value, + }, + })); + + var renderedUpdate = update.Render(new(collection.DocumentSerializer, BsonSerializer.SerializerRegistry)).AsBsonArray; + renderedUpdate.Count.Should().Be(1); + renderedUpdate[0].Should().Be("{ $set : { Nested : { ValueCopy : '$Value' } } }"); + + collection.UpdateMany("{ }", update); + + var updatedDocument = collection.FindSync("{}").Single(); + updatedDocument.Nested.ValueCopy.Should().Be("Value"); + } + + public class MyDocument + { + public int Id { get; set; } + public string Value { get; set; } + public string AnotherValue { get; set; } + public MyNestedDocument Nested { get; set; } + } + + public class MyNestedDocument + { + public string ValueCopy { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new MyDocument { Id = 1, Value = "Value" } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs new file mode 100644 index 00000000000..dd4df74ece6 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs @@ -0,0 +1,226 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.IO; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Bson.Serialization.Serializers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira +{ + public class CSharp5435Tests : Linq3IntegrationTest + { + [Fact] + public void Test_set_ValueObject_Value_using_creator_map() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyValue(x.ValueObject == null ? 1 : x.ValueObject.Value + 1) + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { ValueObject : { Value : { $cond : { if : { $eq : ['$ValueObject', null] }, then : 1, else : { $add : ['$ValueObject.Value', 1] } } } } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_ValueObject_Value_using_property_setter() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyValue() + { + Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1 + } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { ValueObject : { Value : { $cond : { if : { $eq : ['$ValueObject', null] }, then : 1, else : { $add : ['$ValueObject.Value', 1] } } } } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_ValueObject_to_derived_value_using_property_setter() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyDerivedValue() + { + Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1, + B = 42 + } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_X_using_constructor() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + X = new X(x.Y) + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { X : { Y : '$Y' } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_A() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + A = new [] { 2, x.A[0] } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { A : ['2', { $arrayElemAt : ['$A', 0] }] } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + private IMongoCollection GetCollection() + { + var collection = GetCollection("test"); + CreateCollection( + collection.Database.GetCollection("test"), + BsonDocument.Parse("{ _id : 1 }"), + BsonDocument.Parse("{ _id : 2, X : null }"), + BsonDocument.Parse("{ _id : 3, X : 3 }")); + return collection; + } + + class MyDocument + { + [BsonRepresentation(MongoDB.Bson.BsonType.ObjectId)] + public string Id { get; set; } = ObjectId.GenerateNewId().ToString(); + + public MyValue ValueObject { get; set; } + + public long Long { get; set; } + + public X X { get; set; } + + public int Y { get; set; } + + [BsonRepresentation(BsonType.String)] + public int[] A { get; set; } + } + + class MyValue + { + [BsonConstructor] + public MyValue() { } + [BsonConstructor] + public MyValue(int value) { Value = value; } + public int Value { get; set; } + } + + class MyDerivedValue : MyValue + { + public int B { get; set; } + } + + [BsonSerializer(typeof(XSerializer))] + class X + { + public X(int y) + { + Y = y; + } + public int Y { get; } + } + + class XSerializer : SerializerBase, IBsonDocumentSerializer + { + public override X Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var reader = context.Reader; + reader.ReadStartArray(); + _ = reader.ReadName(); + var y = reader.ReadInt32(); + reader.ReadEndDocument(); + + return new X(y); + } + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, X value) + { + var writer = context.Writer; + writer.WriteStartDocument(); + writer.WriteName("Y"); + writer.WriteInt32(value.Y); + writer.WriteEndDocument(); + } + + public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) + { + serializationInfo = memberName == "Y" ? new BsonSerializationInfo("Y", Int32Serializer.Instance, typeof(int)) : null; + return serializationInfo != null; + } + } + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5519Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5519Tests.cs new file mode 100644 index 00000000000..30f3a73072a --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5519Tests.cs @@ -0,0 +1,66 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using MongoDB.Driver; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp5519Tests : LinqIntegrationTest +{ + public CSharp5519Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void Array_constant_Any_should_serialize_array_correctly() + { + var collection = Fixture.Collection; + var array = new[] { E.A, E.B }; + + var find = collection.Find(x => array.Any(e => x.E == e)); + + var filter = TranslateFindFilter(collection, find); + filter.Should().Be("{ E : { $in : ['A', 'B'] } }"); + + var results = find.ToList(); + results.Select(x => x.Id).Should().Equal(1, 2); + } + + public class C + { + public int Id { get; set; } + [BsonRepresentation(BsonType.String)] public E E { get; set; } + } + + public enum E { A, B, C } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new C { Id = 1, E = E.A }, + new C { Id = 2, E = E.B }, + new C { Id = 3, E = E.C } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5532Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5532Tests.cs new file mode 100644 index 00000000000..6a4a6df3b14 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5532Tests.cs @@ -0,0 +1,189 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp5532Tests : LinqIntegrationTest +{ + private static readonly ObjectId id1 = ObjectId.Parse("111111111111111111111111"); + private static readonly ObjectId id2 = ObjectId.Parse("222222222222222222222222"); + private static readonly ObjectId id3 = ObjectId.Parse("333333333333333333333333"); + + public CSharp5532Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void Filter_should_translate_correctly() + { + var collection = Fixture.Collection; + List jobIds = [id2.ToString()]; + + var find = collection + .Find(x => x.Parts.Any(a => a.Refs.Any(b => jobIds.Contains(b.id)))); + + var filter = TranslateFindFilter(collection, find); + + filter.Should().Be("{ Parts : { $elemMatch : { Refs : { $elemMatch : { _id : { $in : [ObjectId('222222222222222222222222')] } } } } } }"); + } + + [Fact] + public void Projection_should_translate_correctly() + { + var collection = Fixture.Collection; + List jobIds = [id2.ToString()]; + + var find = collection + .Find("{}") + .Project(chain => + new + { + chain.Parts + .First(p => p.Refs.Any(j => jobIds.Contains(j.id))) + .Refs.First(j => jobIds.Contains(j.id)).id + });; + + var projectionTranslation = TranslateFindProjection(collection, find); + + projectionTranslation.Should().Be( + """ + { + _id : + { + $let : + { + vars : + { + this : + { + $arrayElemAt : + [ + { + $filter : + { + input : + { + $let : + { + vars : + { + this : + { + $arrayElemAt : + [ + { + $filter : + { + input : "$Parts", + as : "p", + cond : + { + $anyElementTrue : + { + $map : + { + input : "$$p.Refs", + as : "j", + in : { $in : ["$$j._id", ["222222222222222222222222"]] } + } + } + } + } + }, + 0 + ] + } + }, + in : "$$this.Refs" + } + }, + as : "j", + cond : { $in : ['$$j._id', ["222222222222222222222222"]] } + } + }, + 0 + ] + } + }, + in : "$$this._id" + } + } + } + """); + } + + public class Document + { + [BsonId] + [BsonRepresentation(BsonType.ObjectId)] + public string id { get; set; } + } + + public class Chain : Document + { + public ICollection Parts { get; set; } = new List(); + } + + public class Unit + { + public ICollection Refs { get; set; } + + public Unit() + { + Refs = new List(); + } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new Chain + { + id = "0102030405060708090a0b0c", + Parts = new List() + { + new() + { + Refs = new List() + { + new() + { + id = id1.ToString(), + }, + new() + { + id = id2.ToString(), + }, + new() + { + id = id3.ToString(), + }, + } + } + } + } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs index 10f3f2a5d14..700bbcbf7ee 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs @@ -183,7 +183,7 @@ private void Assert(AstFilter result, string path, BsonValue divisor, BsonValue private TranslationContext CreateContext(ParameterExpression parameter) { var serializer = BsonSerializer.LookupSerializer(parameter.Type); - var context = TranslationContext.Create(translationOptions: null); + var context = TranslationContext.Create(parameter, translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); return context.WithSymbol(symbol); } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs index 45bbf7067af..82bc41fd5ec 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs @@ -546,7 +546,7 @@ private ProjectedResult Group(Expression(Expression> expression, int var parameter = expression.Parameters.Single(); var serializer = BsonSerializer.LookupSerializer(); - var context = TranslationContext.Create(translationOptions: null); + var context = TranslationContext.Create(expression, translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); context = context.WithSymbol(symbol); var filterAst = ExpressionToFilterTranslator.Translate(context, expression.Body); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs index 0869d70822e..4d878ca70b9 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs @@ -1154,7 +1154,7 @@ public List Assert(IMongoCollection collection, var serializer = BsonSerializer.SerializerRegistry.GetSerializer(); var parameter = filter.Parameters.Single(); - var context = TranslationContext.Create(translationOptions: null); + var context = TranslationContext.Create(filter, translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); context = context.WithSymbol(symbol); var filterAst = ExpressionToFilterTranslator.Translate(context, filter.Body);