diff --git a/src/NHibernate.DomainModel/FooComponent.cs b/src/NHibernate.DomainModel/FooComponent.cs index 4bd536eed96..e1c88c7b449 100644 --- a/src/NHibernate.DomainModel/FooComponent.cs +++ b/src/NHibernate.DomainModel/FooComponent.cs @@ -92,6 +92,8 @@ public Int32 Count set { _count = value; } } + public int NotMapped { get; set; } + public DateTime[] ImportantDates { get { return _importantDates; } diff --git a/src/NHibernate.DomainModel/NHSpecific/NullableInt32.cs b/src/NHibernate.DomainModel/NHSpecific/NullableInt32.cs index 89d599183bf..95abf028f98 100644 --- a/src/NHibernate.DomainModel/NHSpecific/NullableInt32.cs +++ b/src/NHibernate.DomainModel/NHSpecific/NullableInt32.cs @@ -7,7 +7,7 @@ namespace NHibernate.DomainModel.NHSpecific /// A nullable type that wraps an value. /// [TypeConverter(typeof(NullableInt32Converter)), Serializable()] - public struct NullableInt32 : IFormattable, IComparable + public struct NullableInt32 : IFormattable, IComparable, IConvertible { public static readonly NullableInt32 Default = new NullableInt32(); @@ -234,5 +234,94 @@ public static NullableInt32 Parse(string s) // TODO: implement the rest of the Parse overloads found in Int32 #endregion + + #region IConvertible + + public TypeCode GetTypeCode() + { + return _value.GetTypeCode(); + } + + public bool ToBoolean(IFormatProvider provider) + { + return ((IConvertible) _value).ToBoolean(provider); + } + + public char ToChar(IFormatProvider provider) + { + return ((IConvertible) _value).ToChar(provider); + } + + public sbyte ToSByte(IFormatProvider provider) + { + return ((IConvertible) _value).ToSByte(provider); + } + + public byte ToByte(IFormatProvider provider) + { + return ((IConvertible) _value).ToByte(provider); + } + + public short ToInt16(IFormatProvider provider) + { + return ((IConvertible) _value).ToInt16(provider); + } + + public ushort ToUInt16(IFormatProvider provider) + { + return ((IConvertible) _value).ToUInt16(provider); + } + + public int ToInt32(IFormatProvider provider) + { + return ((IConvertible) _value).ToInt32(provider); + } + + public uint ToUInt32(IFormatProvider provider) + { + return ((IConvertible) _value).ToUInt32(provider); + } + + public long ToInt64(IFormatProvider provider) + { + return ((IConvertible) _value).ToInt64(provider); + } + + public ulong ToUInt64(IFormatProvider provider) + { + return ((IConvertible) _value).ToUInt64(provider); + } + + public float ToSingle(IFormatProvider provider) + { + return ((IConvertible) _value).ToSingle(provider); + } + + public double ToDouble(IFormatProvider provider) + { + return ((IConvertible) _value).ToDouble(provider); + } + + public decimal ToDecimal(IFormatProvider provider) + { + return ((IConvertible) _value).ToDecimal(provider); + } + + public DateTime ToDateTime(IFormatProvider provider) + { + return ((IConvertible) _value).ToDateTime(provider); + } + + public string ToString(IFormatProvider provider) + { + return _value.ToString(provider); + } + + public object ToType(System.Type conversionType, IFormatProvider provider) + { + return ((IConvertible) _value).ToType(conversionType, provider); + } + + #endregion } } diff --git a/src/NHibernate.DomainModel/Northwind/Entities/Address.cs b/src/NHibernate.DomainModel/Northwind/Entities/Address.cs index d224bc50cf7..d2d56fd6823 100755 --- a/src/NHibernate.DomainModel/Northwind/Entities/Address.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/Address.cs @@ -61,6 +61,8 @@ public string Fax get { return _fax; } } + public int NotMapped => 1; + public static bool operator ==(Address address1, Address address2) { if (!ReferenceEquals(address1, null) && @@ -114,4 +116,4 @@ public override int GetHashCode() (_fax ?? string.Empty).GetHashCode(); } } -} \ No newline at end of file +} diff --git a/src/NHibernate.DomainModel/Northwind/Entities/IEntity.cs b/src/NHibernate.DomainModel/Northwind/Entities/IEntity.cs new file mode 100644 index 00000000000..52c661c9ec1 --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Entities/IEntity.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace NHibernate.DomainModel.Northwind.Entities +{ + public interface IEntity + { + TId Id { get; set; } + } + + public interface IEntity : IEntity + { + } +} diff --git a/src/NHibernate.DomainModel/Northwind/Entities/Product.cs b/src/NHibernate.DomainModel/Northwind/Entities/Product.cs index 5af739cc0d2..8c72b3895d0 100755 --- a/src/NHibernate.DomainModel/Northwind/Entities/Product.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/Product.cs @@ -103,9 +103,11 @@ public virtual float ShippingWeight set { _shippingWeight = value; } } + public virtual int NotMapped => 1; + public virtual ReadOnlyCollection OrderLines { get { return new ReadOnlyCollection(_orderLines); } } } -} \ No newline at end of file +} diff --git a/src/NHibernate.DomainModel/Northwind/Entities/User.cs b/src/NHibernate.DomainModel/Northwind/Entities/User.cs index a2bde32af30..c3f220ffda5 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/User.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/User.cs @@ -28,7 +28,7 @@ public interface IUser EnumStoredAsInt32 Enum2 { get; set; } } - public class User : IUser + public class User : IUser, IEntity { public virtual int Id { get; set; } @@ -50,6 +50,10 @@ public class User : IUser public virtual EnumStoredAsInt32 Enum2 { get; set; } + public virtual int NotMapped { get; set; } + + public virtual Role NotMappedRole { get; set; } + public User() { } public User(string name, DateTime registeredAt) diff --git a/src/NHibernate.Test/Async/Linq/SelectionTests.cs b/src/NHibernate.Test/Async/Linq/SelectionTests.cs index b4ac7e372e4..cf065e6bf5d 100644 --- a/src/NHibernate.Test/Async/Linq/SelectionTests.cs +++ b/src/NHibernate.Test/Async/Linq/SelectionTests.cs @@ -11,7 +11,9 @@ using System; using System.Collections.Generic; using System.Linq; +using NHibernate.DomainModel.NHSpecific; using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Type; using NUnit.Framework; using NHibernate.Linq; @@ -307,6 +309,10 @@ public async Task CanProjectWithCastAsync() var names5 = await (db.Users.Select(p => new { p1 = (p as IUser).Name }).ToListAsync()); Assert.AreEqual(3, names5.Count); + + var names6 = await (db.Users.Select(p => new { p1 = (long) p.Id }).ToListAsync()); + Assert.AreEqual(3, names6.Count); + // ReSharper restore RedundantCast } @@ -453,6 +459,23 @@ public async Task CanSelectConditionalObjectAsync() Assert.That(fatherIsKnown, Has.Exactly(1).With.Property("FatherIsKnown").True); } + [Test] + public async Task CanCastToDerivedTypeAsync() + { + var dogs = await (db.Animals + .Where(a => ((Dog) a).Pregnant) + .Select(a => new {a.SerialNumber}) + .ToListAsync()); + Assert.That(dogs, Has.Exactly(1).With.Property("SerialNumber").Not.Null); + } + + [Test] + public async Task CanCastToCustomRegisteredTypeAsync() + { + TypeFactory.RegisterType(typeof(NullableInt32), new NullableInt32Type(), Enumerable.Empty()); + Assert.That(await (db.Users.Where(o => (NullableInt32) o.Id == 1).ToListAsync()), Has.Count.EqualTo(1)); + } + public class Wrapper { public T item; diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs new file mode 100644 index 00000000000..7b148a864dc --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH2029/Fixture.cs @@ -0,0 +1,190 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System; +using System.Linq; +using NHibernate.Cfg.MappingSchema; +using NHibernate.Dialect; +using NHibernate.Mapping.ByCode; +using NUnit.Framework; +using NHibernate.Linq; + +namespace NHibernate.Test.NHSpecificTest.GH2029 +{ + using System.Threading.Tasks; + [TestFixture] + public class FixtureAsync : TestCaseMappingByCode + { + protected override HbmMapping GetMappings() + { + var mapper = new ModelMapper(); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.Native)); + rc.Property(x => x.NullableInt32Prop); + rc.Property(x => x.Int32Prop); + rc.Property(x => x.NullableInt64Prop); + rc.Property(x => x.Int64Prop); + }); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); + } + + protected override bool AppliesTo(Dialect.Dialect dialect) + { + return !(dialect is SQLiteDialect); + } + + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var tx = session.BeginTransaction()) + { + session.Save(new TestClass + { + Int32Prop = int.MaxValue, + NullableInt32Prop = int.MaxValue, + Int64Prop = int.MaxValue, + NullableInt64Prop = int.MaxValue + }); + session.Save(new TestClass + { + Int32Prop = int.MaxValue, + NullableInt32Prop = int.MaxValue, + Int64Prop = int.MaxValue, + NullableInt64Prop = int.MaxValue + }); + session.Save(new TestClass + { + Int32Prop = int.MaxValue, + NullableInt32Prop = null, + Int64Prop = int.MaxValue, + NullableInt64Prop = null + }); + + tx.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var tx = session.BeginTransaction()) + { + session.CreateQuery("delete from TestClass").ExecuteUpdate(); + + tx.Commit(); + } + } + + [Test] + public async Task NullableIntOverflowAsync() + { + var hasCast = Dialect.GetCastTypeName(NHibernateUtil.Int32.SqlType) != + Dialect.GetCastTypeName(NHibernateUtil.Int64.SqlType); + + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = await (session.Query() + .GroupBy(i => 1) + .Select(g => new + { + s = g.Sum(i => (long) i.NullableInt32Prop) + }) + .ToListAsync()); + + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(hasCast ? 1 : 0)); + Assert.That(groups, Has.Count.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2)); + } + } + + [Test] + public async Task IntOverflowAsync() + { + var hasCast = Dialect.GetCastTypeName(NHibernateUtil.Int32.SqlType) != + Dialect.GetCastTypeName(NHibernateUtil.Int64.SqlType); + + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = await (session.Query() + .GroupBy(i => 1) + .Select(g => new + { + s = g.Sum(i => (long) i.Int32Prop) + }) + .ToListAsync()); + + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(hasCast ? 1 : 0)); + Assert.That(groups, Has.Count.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3)); + } + } + + [Test] + public async Task NullableInt64NoCastAsync() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = await (session.Query() + .GroupBy(i => 1) + .Select(g => new { + s = g.Sum(i => i.NullableInt64Prop) + }) + .ToListAsync()); + + Assert.That(sqlLog.GetWholeLog(), Does.Not.Contains("cast")); + Assert.That(groups, Has.Count.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2)); + } + } + + [Test] + public async Task Int64NoCastAsync() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = await (session.Query() + .GroupBy(i => 1) + .Select(g => new { + s = g.Sum(i => i.Int64Prop) + }) + .ToListAsync()); + + Assert.That(sqlLog.GetWholeLog(), Does.Not.Contains("cast")); + Assert.That(groups, Has.Count.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3)); + } + } + + private int FindAllOccurrences(string source, string substring) + { + if (source == null) + { + return 0; + } + int n = 0, count = 0; + while ((n = source.IndexOf(substring, n, StringComparison.InvariantCulture)) != -1) + { + n += substring.Length; + count++; + } + return count; + } + } +} diff --git a/src/NHibernate.Test/Linq/SelectionTests.cs b/src/NHibernate.Test/Linq/SelectionTests.cs index 3873558badf..7aac7edc2da 100644 --- a/src/NHibernate.Test/Linq/SelectionTests.cs +++ b/src/NHibernate.Test/Linq/SelectionTests.cs @@ -1,7 +1,9 @@ using System; using System.Collections.Generic; using System.Linq; +using NHibernate.DomainModel.NHSpecific; using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Type; using NUnit.Framework; namespace NHibernate.Test.Linq @@ -346,6 +348,10 @@ public void CanProjectWithCast() var names5 = db.Users.Select(p => new { p1 = (p as IUser).Name }).ToList(); Assert.AreEqual(3, names5.Count); + + var names6 = db.Users.Select(p => new { p1 = (long) p.Id }).ToList(); + Assert.AreEqual(3, names6.Count); + // ReSharper restore RedundantCast } @@ -492,6 +498,23 @@ public void CanSelectConditionalObject() Assert.That(fatherIsKnown, Has.Exactly(1).With.Property("FatherIsKnown").True); } + [Test] + public void CanCastToDerivedType() + { + var dogs = db.Animals + .Where(a => ((Dog) a).Pregnant) + .Select(a => new {a.SerialNumber}) + .ToList(); + Assert.That(dogs, Has.Exactly(1).With.Property("SerialNumber").Not.Null); + } + + [Test] + public void CanCastToCustomRegisteredType() + { + TypeFactory.RegisterType(typeof(NullableInt32), new NullableInt32Type(), Enumerable.Empty()); + Assert.That(db.Users.Where(o => (NullableInt32) o.Id == 1).ToList(), Has.Count.EqualTo(1)); + } + public class Wrapper { public T item; diff --git a/src/NHibernate.Test/Linq/TryGetMappedTests.cs b/src/NHibernate.Test/Linq/TryGetMappedTests.cs new file mode 100644 index 00000000000..11724e1ac9b --- /dev/null +++ b/src/NHibernate.Test/Linq/TryGetMappedTests.cs @@ -0,0 +1,816 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.DomainModel; +using NHibernate.DomainModel.NHSpecific; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Engine; +using NHibernate.Engine.Query; +using NHibernate.Linq; +using NHibernate.Linq.Visitors; +using NHibernate.Persister.Entity; +using NHibernate.Type; +using NHibernate.Util; +using NUnit.Framework; +using IQueryable = System.Linq.IQueryable; + +namespace NHibernate.Test.Linq +{ + /// + /// Tests form ExpressionsHelper.TryGetMappedType and ExpressionsHelper.TryGetMappedNullability + /// + public class TryGetMappedTests : LinqTestCase + { + private static readonly TryGetMappedType _tryGetMappedType; + private static readonly TryGetMappedNullability _tryGetMappedNullability; + + private delegate bool TryGetMappedType( + ISessionFactoryImplementor sessionFactory, + Expression expression, + out IType mappedType, + out IEntityPersister entityPersister, + out IAbstractComponentType component, + out string memberPath); + + private delegate bool TryGetMappedNullability( + ISessionFactoryImplementor sessionFactory, + Expression expression, + out bool nullability); + + static TryGetMappedTests() + { + var method = typeof(ExpressionsHelper).GetMethod( + nameof(TryGetMappedType), + BindingFlags.NonPublic | BindingFlags.Static); + var sessionFactoryParam = Expression.Parameter(typeof(ISessionFactoryImplementor), "sessionFactory"); + var expressionParam = Expression.Parameter(typeof(Expression), "expression"); + var mappedTypeParam = Expression.Parameter(typeof(IType).MakeByRefType(), "mappedType"); + var entityPersisterParam = Expression.Parameter(typeof(IEntityPersister).MakeByRefType(), "entityPersister"); + var componentParam = Expression.Parameter(typeof(IAbstractComponentType).MakeByRefType(), "component"); + var memberPathParam = Expression.Parameter(typeof(string).MakeByRefType(), "memberPath"); + var methodCall = Expression.Call( + method, + sessionFactoryParam, + expressionParam, + mappedTypeParam, + entityPersisterParam, + componentParam, + memberPathParam); + _tryGetMappedType = Expression.Lambda( + methodCall, + sessionFactoryParam, + expressionParam, + mappedTypeParam, + entityPersisterParam, + componentParam, + memberPathParam).Compile(); + + method = typeof(ExpressionsHelper).GetMethod( + nameof(TryGetMappedNullability), + BindingFlags.NonPublic | BindingFlags.Static); + var nullabilityParam = Expression.Parameter(typeof(bool).MakeByRefType(), "nullability"); + methodCall = Expression.Call( + method, + sessionFactoryParam, + expressionParam, + nullabilityParam); + _tryGetMappedNullability = Expression.Lambda( + methodCall, + sessionFactoryParam, + expressionParam, + nullabilityParam).Compile(); + } + + protected override string[] Mappings + { + get + { + return + new[] + { + "ABC.hbm.xml", + "Baz.hbm.xml", + "FooBar.hbm.xml", + "Glarch.hbm.xml", + "Fee.hbm.xml", + "Qux.hbm.xml", + "Fum.hbm.xml", + "Holder.hbm.xml", + "One.hbm.xml", + "Many.hbm.xml" + }.Concat(base.Mappings).ToArray(); + } + } + + [Test] + public void SelfTest() + { + var query = db.OrderLines.Select(o => o); + AssertSupportedAndResultNotNullable( + query, + typeof(OrderLine).FullName, + null, + o => o is EntityType entityType && entityType.ReturnedClass == typeof(OrderLine)); + } + + [Test] + public void SelfCastNotMappedTest() + { + var query = session.Query().Select(o => (object) o); + AssertSupportedAndResultNotNullable( + query, + false, + typeof(A).FullName, + null, + o => o is SerializableType serializableType && serializableType.ReturnedClass == typeof(object)); + } + + [Test] + public void PropertyTest() + { + var query = db.OrderLines.Select(o => o.Quantity); + AssertSupportedAndResultNotNullable(query, typeof(OrderLine).FullName, "Quantity", o => o is Int32Type); + } + + [Test] + public void NotMappedPropertyTest() + { + var query = db.Users.Select(o => o.NotMapped); + AssertUnsupported(query, typeof(User).FullName, "NotMapped", o => o is null); + } + + [Test] + public void NestedNotMappedPropertyTest() + { + var query = db.Users.Select(o => o.Name.Length); + AssertUnsupported(query, false, null, null, o => o is null); + } + + [Test] + public void PropertyCastTest() + { + var query = db.OrderLines.Select(o => (long) o.Quantity); + AssertSupportedAndResultNotNullable(query, typeof(OrderLine).FullName, "Quantity", o => o is Int64Type); + } + + [Test] + public void PropertyIndexer() + { + var query = db.Products.Select(o => o.Name[0]); + AssertUnsupported(query, null, null, o => o == null); + } + + [Test] + public void EnumInt32Test() + { + var query = db.Users.Select(o => o.Enum2); + AssertSupportedAndResultNotNullable( + query, + typeof(User).FullName, + "Enum2", + o => o.GetType().GetGenericArguments().FirstOrDefault() == typeof(EnumStoredAsInt32)); + } + + [Test] + public void EnumInt32CastTest() + { + var query = db.Users.Select(o => (int) o.Enum2); + AssertSupportedAndResultNotNullable(query, typeof(User).FullName, "Enum2", o => o is Int32Type); + } + + [Test] + public void EnumAsStringTest() + { + var query = db.Users.Select(o => o.Enum1); + AssertSupported(query, typeof(User).FullName, "Enum1", o => o is EnumStoredAsStringType); + } + + [Test] + public void IdentifierTest() + { + var query = db.OrderLines.Select(o => o.Id); + AssertSupportedAndResultNotNullable(query, typeof(OrderLine).FullName, "Id", o => o is Int64Type); + } + + [Test] + public void CompositeIdentifierTest() + { + var query = session.Query().Select(o => o.Id.Date); + AssertSupportedAndResultNotNullable( + query, + typeof(Fum).FullName, + "Id.Date", + o => o is DateTimeType, + o => o?.Name == "component[String,Short,Date]"); + } + + [Test] + public void ComponentTest() + { + var query = db.Customers.Select(o => o.Address); + AssertSupported( + query, + typeof(Customer).FullName, + "Address", + o => o is ComponentType && o.Name == "component[Street,City,Region,PostalCode,Country,PhoneNumber,Fax]"); + } + + [Test] + public void ComponentPropertyTest() + { + var query = db.Customers.Select(o => o.Address.City); + AssertSupported( + query, + typeof(Customer).FullName, + "Address.City", + o => o is StringType, + o => o?.Name == "component[Street,City,Region,PostalCode,Country,PhoneNumber,Fax]"); + } + + [Test] + public void ComponentNotMappedPropertyTest() + { + var query = db.Customers.Select(o => o.Address.NotMapped); + AssertUnsupported( + query, + typeof(Customer).FullName, + "Address.NotMapped", + o => o == null, + o => o?.Name == "component[Street,City,Region,PostalCode,Country,PhoneNumber,Fax]"); + } + + [Test] + public void ComponentNestedNotMappedPropertyTest() + { + var query = db.Customers.Select(o => o.Address.City.Length); + AssertUnsupported(query, false, null, null, o => o == null); + } + + [Test] + public void NestedComponentPropertyTest() + { + var query = db.Users.Select(o => o.Component.OtherComponent.OtherProperty1); + AssertSupported( + query, + typeof(User).FullName, + "Component.OtherComponent.OtherProperty1", + o => o is AnsiStringType, + o => o?.Name == "component[OtherProperty1]"); + } + + [Test] + public void NestedComponentPropertyCastTest() + { + var query = db.Users.Select(o => (object) o.Component.OtherComponent.OtherProperty1); + AssertSupported( + query, + typeof(User).FullName, + "Component.OtherComponent.OtherProperty1", + o => o is SerializableType serializableType && serializableType.ReturnedClass == typeof(object), + o => o?.Name == "component[OtherProperty1]"); + } + + [Test] + public void ManyToOneTest() + { + var query = db.OrderLines.Select(o => o.Order); + AssertSupportedAndResultNotNullable(query, typeof(OrderLine).FullName, "Order", + o => o is ManyToOneType manyToOne && manyToOne.PropertyName == "Order"); + } + + [Test] + public void ManyToOnePropertyTest() + { + var query = db.OrderLines.Select(o => o.Order.Freight); + AssertSupported(query, typeof(Order).FullName, "Freight", o => o is DecimalType); + } + + [Test] + public void ManyToOneNotMappedPropertyTest() + { + var query = db.OrderLines.Select(o => o.Product.NotMapped); + AssertUnsupported(query, typeof(Product).FullName, "NotMapped", o => o == null); + } + + [Test] + public void NotMappedManyToOnePropertyTest() + { + var query = db.Users.Select(o => o.NotMappedRole.Name); + AssertUnsupported(query, false, null, null, o => o is null); + } + + [Test] + public void NestedManyToOneTest() + { + var query = db.OrderLines.Select(o => o.Order.Employee); + AssertSupported(query, false, typeof(Order).FullName, "Employee", + o => o is ManyToOneType manyToOne && manyToOne.PropertyName == "Employee"); + } + + [Test] + public void NestedManyToOnePropertyTest() + { + var query = db.OrderLines.Select(o => o.Order.Employee.BirthDate); + AssertSupported(query, typeof(Employee).FullName, "BirthDate", o => o is DateTimeType); + } + + [Test] + public void OneToManyTest() + { + var query = db.Customers.SelectMany(o => o.Orders); + AssertSupported( + query, + typeof(Customer).FullName, + "Orders", + o => o is CollectionType collectionType && collectionType.Role == $"{typeof(Customer).FullName}.Orders"); + } + + [Test] + public void OneToManyElementIndexerTest() + { + var query = session.Query().Select(o => o.StringList[0]); + AssertSupported(query, false, typeof(Baz).FullName, "StringList", o => o is StringType); + } + + [Test] + public void OneToManyElementIndexerNotMappedPropertyTest() + { + var query = session.Query().Select(o => o.StringList[0].Length); + AssertUnsupported(query, false, null, null, o => o == null); + } + + [Test] + public void OneToManyCustomElementIndexerTest() + { + var query = session.Query().Select(o => o.Customs[0]); + AssertSupported( + query, + false, + typeof(Baz).FullName, + "Customs", + o => o is CompositeCustomType customType && customType.UserType is DoubleStringType); + } + + [Test] + public void OneToManyIndexerCastTest() + { + var query = session.Query().Select(o => (long) o.IntArray[0]); + AssertSupported(query, false, typeof(Baz).FullName, "IntArray", o => o is Int64Type); + } + + [Test] + public void OneToManyIndexerPropertyTest() + { + var query = session.Query().Select(o => o.Fees[0].Count); + AssertSupported(query, false, typeof(Fee).FullName, "Count", o => o is Int32Type); + } + + [Test] + public void OneToManyElementAtTest() + { + var query = session.Query().Select(o => o.StringList.ElementAt(0)); + AssertSupported(query, false, typeof(Baz).FullName, "StringList", o => o is StringType); + } + + [Test] + public void NestedOneToManyManyToOneComponentPropertyTest() + { + var query = session.Query().SelectMany(o => o.Fees).Select(o => o.TheFee.Compon.Name); + AssertSupported( + query, + typeof(Fee).FullName, + "Compon.Name", + o => o is StringType, + o => o?.Name == "component[Name,NullString]"); + } + + [Test] + public void OneToManyCompositeElementPropertyTest() + { + var query = session.Query().Select(o => o.Components[0].Count); + AssertSupported( + query, + false, + null, + "Count", + o => o is Int32Type, + o => o?.Name == "component[Name,Count,Subcomponent]"); + } + + [Test] + public void OneToManyCompositeElementPropertyIndexerTest() + { + var query = session.Query().Select(o => o.Components[0].Name[0]); + AssertUnsupported(query, false, null, null, o => o == null); + } + + [Test] + public void OneToManyCompositeElementNotMappedPropertyTest() + { + var query = session.Query().Select(o => o.Components[0].NotMapped); + AssertUnsupported( + query, + false, + null, + "NotMapped", + o => o == null, + o => o?.Name == "component[Name,Count,Subcomponent]"); + } + + [Test] + public void OneToManyCompositeElementCastPropertyTest() + { + var query = session.Query().Select(o => (long) o.Components[0].Count); + AssertSupported( + query, + false, + null, + "Count", + o => o is Int64Type, + o => o?.Name == "component[Name,Count,Subcomponent]"); + } + + [Test] + public void OneToManyCompositeElementCollectionNotMappedPropertyTest() + { + var query = session.Query().SelectMany(o => o.Components[0].ImportantDates); + AssertUnsupported( + query, + false, + null, + "ImportantDates", + o => o == null, + o => o?.Name == "component[Name,Count,Subcomponent]"); + } + + [Test] + public void NestedOneToManyCompositeElementTest() + { + var query = session.Query().Select(o => o.Components[0].Subcomponent); + AssertSupported( + query, + false, + null, + "Subcomponent", + o => o is IAbstractComponentType componentType && componentType.ReturnedClass == typeof(FooComponent), + o => o?.Name == "component[Name,Count,Subcomponent]"); + } + + [Test] + public void NestedOneToManyCompositeElementPropertyTest() + { + var query = session.Query().Select(o => o.Components[0].Subcomponent.Name); + AssertSupported(query, false, null, "Name", o => o is StringType, o => o?.Name == "component[Name,Count]"); + } + + [Test] + public void NestedOneToManyCompositeElementPropertyIndexerTest() + { + var query = session.Query().Select(o => o.Components[0].Subcomponent.Name[0]); + AssertUnsupported(query, false, null, null, o => o == null); + } + + [Test] + public void ManyToManyTest() + { + var query = session.Query().Select(o => o.FooArray); + AssertSupported( + query, + false, + typeof(Baz).FullName, + "FooArray", + o => o is ArrayType arrayType && arrayType.Role == $"{typeof(Baz).FullName}.FooArray"); + } + + [Test] + public void ManyToManyIndexerTest() + { + var query = session.Query().Select(o => o.FooArray[0].Null); + AssertSupported(query, false, typeof(Foo).FullName, "Null", o => o is NullableInt32Type); + } + + [Test] + public void SubclassCastTest() + { + var query = session.Query().Select(o => (B) o); + AssertSupportedAndResultNotNullable( + query, + typeof(A).FullName, + null, + o => o is EntityType entityType && entityType.ReturnedClass == typeof(B)); + } + + [Test] + public void NestedSubclassCastTest() + { + var query = session.Query().Select(o => (C1) ((B) o)); + AssertSupportedAndResultNotNullable( + query, + false, + typeof(A).FullName, + null, + o => o is EntityType entityType && entityType.ReturnedClass == typeof(C1)); + } + + [Test] + public void SubclassPropertyTest() + { + var query = session.Query().Select(o => ((C1) o).Count); + AssertSupported(query, typeof(C1).FullName, "Count", o => o is Int32Type); + } + + [Test] + public void NestedSubclassCastPropertyTest() + { + var query = session.Query().Select(o => ((C1) ((B) o)).Id); + AssertSupportedAndResultNotNullable(query, typeof(C1).FullName, "Id", o => o is Int64Type); + } + + [Test] + public void AnyTest() + { + var query = session.Query().Select(o => o.Object); + AssertSupported(query, typeof(Bar).FullName, "Object", o => o.IsAnyType); + } + + [Test] + public void CastAnyTest() + { + var query = session.Query().Select(o => (Foo) o.Object); + AssertSupported( + query, + typeof(Bar).FullName, + "Object", + o => o is EntityType entityType && entityType.ReturnedClass == typeof(Foo)); + } + + [Test] + public void NestedCastAnyTest() + { + var query = session.Query().Select(o => (Foo) ((Bar) o.Object).Object); + AssertSupported( + query, + false, + typeof(Bar).FullName, + "Object", + o => o is EntityType entityType && entityType.ReturnedClass == typeof(Foo)); + } + + [Test] + public void CastAnyManyToOneTest() + { + var query = session.Query().Select(o => ((Foo) o.Object).Dependent); + AssertSupportedAndResultNotNullable( + query, + typeof(Foo).FullName, + "Dependent", + o => o is EntityType entityType && entityType.ReturnedClass == typeof(Fee)); + } + + [Test] + public void CastAnyPropertyTest() + { + var query = session.Query().Select(o => ((Foo) o.Object).String); + AssertSupported(query, false, typeof(Foo).FullName, "String", o => o is StringType); + } + + [Test] + public void QueryUnmappedEntityTest() + { + var query = session.Query>().Select(o => o.Id); + AssertSupportedAndResultNotNullable(query, typeof(User).FullName, "Id", o => o is Int32Type); + } + + [Test] + public void ConditionalExpressionTest() + { + var query = db.Users.Select(o => (o.Name == "Test" ? o.RegisteredAt : o.LastLoginDate)); + AssertSupported(query, false, typeof(User).FullName, "RegisteredAt", o => o is DateTimeType); + } + + [Test] + public void ConditionalIfFalseExpressionTest() + { + var query = db.Users.Select(o => (o.Name == "Test" ? DateTime.Today : o.LastLoginDate)); + AssertSupported(query, false, typeof(User).FullName, "LastLoginDate", o => o is DateTimeType); + } + + [Test] + public void ConditionalMemberExpressionTest() + { + var query = db.Users.Select(o => (o.Name == "Test" ? o.NotMappedRole : o.Role).IsActive); + AssertSupported(query, false, typeof(Role).FullName, "IsActive", o => o is BooleanType); + } + + [Test] + public void ConditionalNestedExpressionTest() + { + var query = db.Users.Select(o => (o.Name == "Test" ? o.Component.OtherComponent.OtherProperty1 : o.Component.Property1)); + AssertSupported( + query, + false, + typeof(User).FullName, + "Component.OtherComponent.OtherProperty1", + o => o is AnsiStringType, + o => o?.Name == "component[OtherProperty1]"); + } + + [Test] + public void CoalesceExpressionTest() + { + var query = db.Users.Select(o => o.LastLoginDate ?? o.RegisteredAt); + AssertSupported(query, false, typeof(User).FullName, "LastLoginDate", o => o is DateTimeType); + } + + [Test] + public void CoalesceRightExpressionTest() + { + var query = db.Users.Select(o => ((DateTime?) DateTime.Now) ?? o.RegisteredAt); + AssertSupported(query, false, typeof(User).FullName, "RegisteredAt", o => o is DateTimeType); + } + + [Test] + public void CoalesceMemberExpressionTest() + { + var query = db.Users.Select(o => (o.NotMappedRole ?? o.Role).IsActive); + AssertSupported(query, false, typeof(Role).FullName, "IsActive", o => o is BooleanType); + } + + [Test] + public void CoalesceNestedExpressionTest() + { + var query = db.Users.Select(o => o.Component.OtherComponent.OtherProperty1 ?? o.Component.Property1); + AssertSupported( + query, + false, + typeof(User).FullName, + "Component.OtherComponent.OtherProperty1", + o => o is AnsiStringType, + o => o?.Name == "component[OtherProperty1]"); + } + + [Test] + public void CoalesceConditionalMemberExpressionTest() + { + var query = db.Users.Select(o => (o.Name == "Test" ? o.NotMappedRole : (o.NotMappedRole ?? new Role() ?? o.Role)).IsActive); + AssertSupported(query, false, typeof(Role).FullName, "IsActive", o => o is BooleanType); + } + + [Test] + public void JoinTest() + { + var query = from o in db.Orders + from p in db.Products + join d in db.OrderLines + on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} + into details + from d in details + select d.UnitPrice; + AssertSupportedAndResultNotNullable(query, typeof(OrderLine).FullName, "UnitPrice", o => o is DecimalType); + } + + [Test] + public void NotNullComponentPropertyTest() + { + var query = session.Query().SelectMany(o => o.PatientRecords.Select(r => r.Name.FirstName)); + AssertSupportedAndResultNotNullable( + query, + typeof(PatientRecord).FullName, + "Name.FirstName", + o => o is StringType, + o => o?.Name == "component[FirstName,LastName]"); + } + + [Test] + public void NotRelatedTypeTest() + { + var query = session.Query().Select(o => o.CanReduce); + AssertUnsupported(query, null, null, o => o == null); + } + + [Test] + public void NotNhQueryableTest() + { + var query = new List().AsQueryable().Select(o => o.Name); + AssertUnsupported(query, false, null, null, o => o == null); + } + + private void AssertUnsupported( + IQueryable query, + string expectedEntityName, + string expectedMemberPath, + Predicate expectedMemberType, + Predicate expectedComponentType = null) + { + AssertResult(query, true, false, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType); + } + + private void AssertUnsupported( + IQueryable query, + bool rewriteQuery, + string expectedEntityName, + string expectedMemberPath, + Predicate expectedMemberType, + Predicate expectedComponentType = null) + { + AssertResult(query, rewriteQuery, false, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType); + } + + private void AssertSupported( + IQueryable query, + string expectedEntityName, + string expectedMemberPath, + Predicate expectedMemberType, + Predicate expectedComponentType = null) + { + AssertResult(query, true, true, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType); + } + + private void AssertSupported( + IQueryable query, + bool rewriteQuery, + string expectedEntityName, + string expectedMemberPath, + Predicate expectedMemberType, + Predicate expectedComponentType = null) + { + AssertResult(query, rewriteQuery, true, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType); + } + + private void AssertSupportedAndResultNotNullable( + IQueryable query, + string expectedEntityName, + string expectedMemberPath, + Predicate expectedMemberType, + Predicate expectedComponentType = null) + { + AssertResult(query, true, true, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType, false); + } + + private void AssertSupportedAndResultNotNullable( + IQueryable query, + bool rewriteQuery, + string expectedEntityName, + string expectedMemberPath, + Predicate expectedMemberType, + Predicate expectedComponentType = null) + { + AssertResult(query, rewriteQuery, true, expectedEntityName, expectedMemberPath, expectedMemberType, expectedComponentType, false); + } + + private void AssertResult( + IQueryable query, + bool rewriteQuery, + bool supported, + string expectedEntityName, + string expectedMemberPath, + Predicate expectedMemberType, + Predicate expectedComponentType = null, + bool nullability = true) + { + expectedComponentType = expectedComponentType ?? (o => o == null); + + var expression = query.Expression; + NhRelinqQueryParser.PreTransform(expression); + var constantToParameterMap = ExpressionParameterVisitor.Visit(expression, Sfi); + var queryModel = NhRelinqQueryParser.Parse(expression); + var requiredHqlParameters = new List(); + var visitorParameters = new VisitorParameters( + Sfi, + constantToParameterMap, + requiredHqlParameters, + new QuerySourceNamer(), + expression.Type, + QueryMode.Select); + if (rewriteQuery) + { + QueryModelVisitor.GenerateHqlQuery( + queryModel, + visitorParameters, + true, + NhLinqExpressionReturnType.Scalar); + } + + var found = _tryGetMappedType( + Sfi, + queryModel.SelectClause.Selector, + out var memberType, + out var entityPersister, + out var componentType, + out var memberPath); + Assert.That(found, Is.EqualTo(supported), $"Expression should be {(supported ? "supported" : "unsupported")}"); + Assert.That(entityPersister?.EntityName, Is.EqualTo(expectedEntityName), "Invalid entity name"); + Assert.That(memberPath, Is.EqualTo(expectedMemberPath), "Invalid member path"); + Assert.That(() => expectedMemberType(memberType), $"Invalid member type: {memberType?.Name ?? "null"}"); + Assert.That(() => expectedComponentType(componentType), $"Invalid component type: {componentType?.Name ?? "null"}"); + + if (found) + { + Assert.That(_tryGetMappedNullability(Sfi, queryModel.SelectClause.Selector, out var isNullable), Is.True, "Expression should be supported"); + Assert.That(nullability, Is.EqualTo(isNullable), "Nullability is not correct"); + } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs new file mode 100644 index 00000000000..544034db0ea --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2029/Fixture.cs @@ -0,0 +1,178 @@ +using System; +using System.Linq; +using NHibernate.Cfg.MappingSchema; +using NHibernate.Dialect; +using NHibernate.Mapping.ByCode; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH2029 +{ + [TestFixture] + public class Fixture : TestCaseMappingByCode + { + protected override HbmMapping GetMappings() + { + var mapper = new ModelMapper(); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.Native)); + rc.Property(x => x.NullableInt32Prop); + rc.Property(x => x.Int32Prop); + rc.Property(x => x.NullableInt64Prop); + rc.Property(x => x.Int64Prop); + }); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); + } + + protected override bool AppliesTo(Dialect.Dialect dialect) + { + return !(dialect is SQLiteDialect); + } + + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var tx = session.BeginTransaction()) + { + session.Save(new TestClass + { + Int32Prop = int.MaxValue, + NullableInt32Prop = int.MaxValue, + Int64Prop = int.MaxValue, + NullableInt64Prop = int.MaxValue + }); + session.Save(new TestClass + { + Int32Prop = int.MaxValue, + NullableInt32Prop = int.MaxValue, + Int64Prop = int.MaxValue, + NullableInt64Prop = int.MaxValue + }); + session.Save(new TestClass + { + Int32Prop = int.MaxValue, + NullableInt32Prop = null, + Int64Prop = int.MaxValue, + NullableInt64Prop = null + }); + + tx.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var tx = session.BeginTransaction()) + { + session.CreateQuery("delete from TestClass").ExecuteUpdate(); + + tx.Commit(); + } + } + + [Test] + public void NullableIntOverflow() + { + var hasCast = Dialect.GetCastTypeName(NHibernateUtil.Int32.SqlType) != + Dialect.GetCastTypeName(NHibernateUtil.Int64.SqlType); + + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = session.Query() + .GroupBy(i => 1) + .Select(g => new + { + s = g.Sum(i => (long) i.NullableInt32Prop) + }) + .ToList(); + + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(hasCast ? 1 : 0)); + Assert.That(groups, Has.Count.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2)); + } + } + + [Test] + public void IntOverflow() + { + var hasCast = Dialect.GetCastTypeName(NHibernateUtil.Int32.SqlType) != + Dialect.GetCastTypeName(NHibernateUtil.Int64.SqlType); + + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = session.Query() + .GroupBy(i => 1) + .Select(g => new + { + s = g.Sum(i => (long) i.Int32Prop) + }) + .ToList(); + + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(hasCast ? 1 : 0)); + Assert.That(groups, Has.Count.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3)); + } + } + + [Test] + public void NullableInt64NoCast() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = session.Query() + .GroupBy(i => 1) + .Select(g => new { + s = g.Sum(i => i.NullableInt64Prop) + }) + .ToList(); + + Assert.That(sqlLog.GetWholeLog(), Does.Not.Contains("cast")); + Assert.That(groups, Has.Count.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2)); + } + } + + [Test] + public void Int64NoCast() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + using (var sqlLog = new SqlLogSpy()) + { + var groups = session.Query() + .GroupBy(i => 1) + .Select(g => new { + s = g.Sum(i => i.Int64Prop) + }) + .ToList(); + + Assert.That(sqlLog.GetWholeLog(), Does.Not.Contains("cast")); + Assert.That(groups, Has.Count.EqualTo(1)); + Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3)); + } + } + + private int FindAllOccurrences(string source, string substring) + { + if (source == null) + { + return 0; + } + int n = 0, count = 0; + while ((n = source.IndexOf(substring, n, StringComparison.InvariantCulture)) != -1) + { + n += substring.Length; + count++; + } + return count; + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH2029/TestClass.cs b/src/NHibernate.Test/NHSpecificTest/GH2029/TestClass.cs new file mode 100644 index 00000000000..c15c60dfee3 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2029/TestClass.cs @@ -0,0 +1,17 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace NHibernate.Test.NHSpecificTest.GH2029 +{ + public class TestClass + { + public virtual int Id { get; set; } + public virtual int? NullableInt32Prop { get; set; } + public virtual int Int32Prop { get; set; } + public virtual long? NullableInt64Prop { get; set; } + public virtual long Int64Prop { get; set; } + } +} diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs index 7447f9494e7..e3d0b6cacd9 100644 --- a/src/NHibernate/Dialect/Dialect.cs +++ b/src/NHibernate/Dialect/Dialect.cs @@ -266,6 +266,18 @@ public virtual string GetLongestTypeName(DbType dbType) public virtual string GetCastTypeName(SqlType sqlType) => GetCastTypeName(sqlType, _typeNames); + /// + /// Get the name of the database type appropriate for casting operations + /// (via the CAST() SQL function) for the given typecode. + /// + /// The typecode. + /// The database type name that will be set in case it was found. + /// Whether the type name was found. + public virtual bool TryGetCastTypeName(SqlType sqlType, out string typeName) + { + return TryGetCastTypeName(sqlType, _typeNames, out typeName); + } + /// /// Get the name of the database type appropriate for casting operations /// (via the CAST() SQL function) for the given typecode. @@ -274,9 +286,27 @@ public virtual string GetCastTypeName(SqlType sqlType) => /// The source for type names. /// The database type name. protected virtual string GetCastTypeName(SqlType sqlType, TypeNames castTypeNames) + { + if (!TryGetCastTypeName(sqlType, castTypeNames, out var result)) + { + throw new ArgumentException("Dialect does not support DbType." + sqlType.DbType, nameof(sqlType)); + } + + return result; + } + + /// + /// Get the name of the database type appropriate for casting operations + /// (via the CAST() SQL function) for the given typecode. + /// + /// The typecode. + /// The source for type names. + /// The database type name that will be set in case it was found. + /// Whether the type name was found. + protected virtual bool TryGetCastTypeName(SqlType sqlType, TypeNames castTypeNames, out string typeName) { if (sqlType.LengthDefined || sqlType.PrecisionDefined || sqlType.ScaleDefined) - return castTypeNames.Get(sqlType.DbType, sqlType.Length, sqlType.Precision, sqlType.Scale); + return castTypeNames.TryGet(sqlType.DbType, sqlType.Length, sqlType.Precision, sqlType.Scale, out typeName); switch (sqlType.DbType) { case DbType.Decimal: @@ -284,18 +314,18 @@ protected virtual string GetCastTypeName(SqlType sqlType, TypeNames castTypeName case DbType.Double: // We cannot know if the user needs its digit after or before the dot, so use a configurable // default. - return castTypeNames.Get(sqlType.DbType, 0, DefaultCastPrecision, DefaultCastScale); + return castTypeNames.TryGet(sqlType.DbType, 0, DefaultCastPrecision, DefaultCastScale, out typeName); case DbType.DateTime: case DbType.DateTime2: case DbType.DateTimeOffset: case DbType.Time: case DbType.Currency: // Use default for these, dialects are supposed to map them to max capacity - return castTypeNames.Get(sqlType.DbType); + return castTypeNames.TryGet(sqlType.DbType, out typeName); default: // Other types are either length bound or not length/precision/scale bound. Otherwise they need to be // handled previously. - return castTypeNames.Get(sqlType.DbType, DefaultCastLength, 0, 0); + return castTypeNames.TryGet(sqlType.DbType, DefaultCastLength, 0, 0, out typeName); } } diff --git a/src/NHibernate/Dialect/Function/CastFunction.cs b/src/NHibernate/Dialect/Function/CastFunction.cs index 747dd90440a..8580da1997f 100644 --- a/src/NHibernate/Dialect/Function/CastFunction.cs +++ b/src/NHibernate/Dialect/Function/CastFunction.cs @@ -50,13 +50,8 @@ public SqlString Render(IList args, ISessionFactoryImplementor factory) { throw new QueryException("invalid NHibernate type for cast(), was:" + typeName); } + sqlType = factory.Dialect.GetCastTypeName(sqlTypeCodes[0]); - if (sqlType == null) - { - //TODO: never reached, since GetTypeName() actually throws an exception! - sqlType = typeName; - } - //else //{ // //trim off the length/precision/scale // int loc = sqlType.IndexOf('('); diff --git a/src/NHibernate/Dialect/MySQLDialect.cs b/src/NHibernate/Dialect/MySQLDialect.cs index 475c1d1e5c2..9a83de26eb3 100644 --- a/src/NHibernate/Dialect/MySQLDialect.cs +++ b/src/NHibernate/Dialect/MySQLDialect.cs @@ -497,6 +497,10 @@ protected void RegisterCastType(DbType code, int capacity, string name) public override string GetCastTypeName(SqlType sqlType) => GetCastTypeName(sqlType, castTypeNames); + /// + public override bool TryGetCastTypeName(SqlType sqlType, out string typeName) => + TryGetCastTypeName(sqlType, castTypeNames, out typeName); + public override long TimestampResolutionInTicks { get diff --git a/src/NHibernate/Dialect/TypeNames.cs b/src/NHibernate/Dialect/TypeNames.cs index 7c25a211461..b2fbd4529fc 100644 --- a/src/NHibernate/Dialect/TypeNames.cs +++ b/src/NHibernate/Dialect/TypeNames.cs @@ -57,13 +57,24 @@ public class TypeNames /// the default type name associated with the specified key public string Get(DbType typecode) { - if (!defaults.TryGetValue(typecode, out var result)) + if (!TryGet(typecode, out var result)) { throw new ArgumentException("Dialect does not support DbType." + typecode, nameof(typecode)); } return result; } + /// + /// Get default type name for specified type. + /// + /// The type key. + /// The default type name that will be set in case it was found. + /// Whether the default type name was found. + public bool TryGet(DbType typecode, out string typeName) + { + return defaults.TryGetValue(typecode, out typeName); + } + /// /// Get the type name specified type and size /// @@ -76,6 +87,28 @@ public string Get(DbType typecode) /// if available, otherwise the default type name. /// public string Get(DbType typecode, int size, int precision, int scale) + { + if (!TryGet(typecode, size, precision, scale, out var result)) + { + throw new ArgumentException("Dialect does not support DbType." + typecode, nameof(typecode)); + } + + return result; + } + + /// + /// Get the type name specified type and size. + /// + /// The type key. + /// The SQL length. + /// The SQL scale. + /// The SQL precision. + /// + /// The associated name with smallest capacity >= size (or precision for decimal, or scale for date time types) + /// if available, otherwise the default type name. + /// + /// Whether the type name was found. + public bool TryGet(DbType typecode, int size, int precision, int scale, out string typeName) { weighted.TryGetValue(typecode, out var map); if (map != null && map.Count > 0) @@ -88,7 +121,8 @@ public string Get(DbType typecode, int size, int precision, int scale) { if (requiredCapacity <= entry.Key) { - return Replace(entry.Value, size, precision, scale); + typeName = Replace(entry.Value, size, precision, scale); + return true; } } if (isPrecisionType && precision != 0) @@ -102,11 +136,12 @@ public string Get(DbType typecode, int size, int precision, int scale) // But if the type is used for storing amounts, this may cause losing the ability to store cents... // So better just reduce as few as possible. var adjustedScale = Math.Min(scale, adjustedPrecision); - return Replace(maxEntry.Value, size, adjustedPrecision, adjustedScale); + typeName = Replace(maxEntry.Value, size, adjustedPrecision, adjustedScale); + return true; } } //Could not find a specific type for the capacity, using the default - return Get(typecode); + return TryGet(typecode, out typeName); } /// diff --git a/src/NHibernate/Hql/Ast/HqlTreeNode.cs b/src/NHibernate/Hql/Ast/HqlTreeNode.cs index 7289d5acbc5..5964b99db90 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeNode.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeNode.cs @@ -257,6 +257,28 @@ internal HqlIdent(IASTFactory factory, System.Type type) throw new NotSupportedException(string.Format("Don't currently support idents of type {0}", type.Name)); } } + + internal static bool SupportsType(System.Type type) + { + type = type.UnwrapIfNullable(); + switch (System.Type.GetTypeCode(type)) + { + case TypeCode.Boolean: + case TypeCode.Int16: + case TypeCode.Int32: + case TypeCode.Int64: + case TypeCode.Decimal: + case TypeCode.Single: + case TypeCode.DateTime: + case TypeCode.String: + case TypeCode.Double: + return true; + default: + return + type == typeof(Guid) || + type == typeof(DateTimeOffset); + } + } } public class HqlRange : HqlStatement diff --git a/src/NHibernate/Linq/Functions/ListIndexerGenerator.cs b/src/NHibernate/Linq/Functions/ListIndexerGenerator.cs index 6435b88a476..f03f6700230 100644 --- a/src/NHibernate/Linq/Functions/ListIndexerGenerator.cs +++ b/src/NHibernate/Linq/Functions/ListIndexerGenerator.cs @@ -12,20 +12,30 @@ namespace NHibernate.Linq.Functions { internal class ListIndexerGenerator : BaseHqlGeneratorForMethod,IRuntimeMethodHqlGenerator { + private static readonly HashSet _supportedMethods = new HashSet + { + ReflectHelper.GetMethodDefinition(() => Enumerable.ElementAt(null, 0)), + ReflectHelper.GetMethodDefinition(() => Queryable.ElementAt(null, 0)) + }; + public ListIndexerGenerator() { - SupportedMethods = new[] - { - ReflectHelper.GetMethodDefinition(() => Enumerable.ElementAt(null, 0)), - ReflectHelper.GetMethodDefinition(() => Queryable.ElementAt(null, 0)) - }; + SupportedMethods = _supportedMethods; } public bool SupportsMethod(MethodInfo method) { - return method != null && - method.Name == "get_Item" && - (method.IsMethodOf(typeof(IList)) || method.IsMethodOf(typeof(IList<>))); + return IsRuntimeMethodSupported(method); + } + + public static bool IsMethodSupported(MethodInfo method) + { + if (method.IsGenericMethod) + { + method = method.GetGenericMethodDefinition(); + } + + return _supportedMethods.Contains(method) || IsRuntimeMethodSupported(method); } public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method) @@ -40,5 +50,12 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, return treeBuilder.Index(collection, index); } + + private static bool IsRuntimeMethodSupported(MethodInfo method) + { + return method != null && + method.Name == "get_Item" && + (method.IsMethodOf(typeof(IList)) || method.IsMethodOf(typeof(IList<>))); + } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 3d2e9be5f47..4e6df61afba 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -1,13 +1,16 @@ using System; +using System.Data; using System.Dynamic; using System.Linq; using System.Linq.Expressions; using System.Runtime.CompilerServices; using NHibernate.Engine.Query; using NHibernate.Hql.Ast; +using NHibernate.Hql.Ast.ANTLR; using NHibernate.Linq.Expressions; using NHibernate.Linq.Functions; using NHibernate.Param; +using NHibernate.Type; using NHibernate.Util; using Remotion.Linq.Clauses.Expressions; @@ -237,10 +240,16 @@ constant.Value is CallSite site && protected HqlTreeNode VisitNhAverage(NhAverageExpression expression) { + // We need to cast the argument when its type is different from Average method return type, + // otherwise the result may be incorrect. In SQL Server avg always returns int + // when the argument is int. var hqlExpression = VisitExpression(expression.Expression).AsExpression(); - if (expression.Type != expression.Expression.Type) - hqlExpression = _hqlTreeBuilder.Cast(hqlExpression, expression.Type); + hqlExpression = IsCastRequired(expression.Expression, expression.Type, out _) + ? (HqlExpression) _hqlTreeBuilder.Cast(hqlExpression, expression.Type) + : _hqlTreeBuilder.TransparentCast(hqlExpression, expression.Type); + // In Oracle the avg function can return a number with up to 40 digits which cannot be retrieved from the data reader due to the lack of such + // numeric type in .NET. In order to avoid that we have to add a cast to trim the number so that it can be converted into a .NET numeric type. return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Average(hqlExpression), expression.Type); } @@ -261,7 +270,9 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression) protected HqlTreeNode VisitNhSum(NhSumExpression expression) { - return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type); + return IsCastRequired("sum", expression.Expression, expression.Type) + ? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type) + : _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type); } protected HqlTreeNode VisitNhDistinct(NhDistinctExpression expression) @@ -475,13 +486,12 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) case ExpressionType.Convert: case ExpressionType.ConvertChecked: case ExpressionType.TypeAs: - if ((expression.Operand.Type.IsPrimitive || expression.Operand.Type == typeof(Decimal)) && - (expression.Type.IsPrimitive || expression.Type == typeof(Decimal))) - { - return _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type); - } - - return VisitExpression(expression.Operand); + return IsCastRequired(expression.Operand, expression.Type, out var existType) + ? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type) + // Make a transparent cast when an IType exists, so that it can be used to retrieve the value from the data reader + : existType && HqlIdent.SupportsType(expression.Type) + ? _hqlTreeBuilder.TransparentCast(VisitExpression(expression.Operand).AsExpression(), expression.Type) + : VisitExpression(expression.Operand); } throw new NotSupportedException(expression.ToString()); @@ -582,5 +592,85 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression) var expressionSubTree = expression.Expressions.ToArray(exp => VisitExpression(exp)); return _hqlTreeBuilder.ExpressionSubTreeHolder(expressionSubTree); } + + private bool IsCastRequired(Expression expression, System.Type toType, out bool existType) + { + existType = false; + return toType != typeof(object) && + IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType), out existType); + } + + private bool IsCastRequired(IType type, IType toType, out bool existType) + { + // A type can be null when casting an entity into a base class, in that case we should not cast + if (type == null || toType == null || Equals(type, toType)) + { + existType = false; + return false; + } + + var sqlTypes = type.SqlTypes(_parameters.SessionFactory); + var toSqlTypes = toType.SqlTypes(_parameters.SessionFactory); + if (sqlTypes.Length != 1 || toSqlTypes.Length != 1) + { + existType = false; + return false; // Casting a multi-column type is not possible + } + + existType = true; + if (sqlTypes[0].DbType == toSqlTypes[0].DbType) + { + return false; + } + + if (type.ReturnedClass.IsEnum && sqlTypes[0].DbType == DbType.String) + { + existType = false; + return false; // Never cast an enum that is mapped as string, the type will provide a string for the parameter value + } + + // Some dialects can map several sql types into one, cast only if the dialect types are different + if (!_parameters.SessionFactory.Dialect.TryGetCastTypeName(sqlTypes[0], out var castTypeName) || + !_parameters.SessionFactory.Dialect.TryGetCastTypeName(toSqlTypes[0], out var toCastTypeName)) + { + return false; // The dialect does not support such cast + } + + return castTypeName != toCastTypeName; + } + + private bool IsCastRequired(string sqlFunctionName, Expression argumentExpression, System.Type returnType) + { + var argumentType = GetType(argumentExpression); + if (argumentType == null || returnType == typeof(object)) + { + return false; + } + + var returnNhType = TypeFactory.GetDefaultTypeFor(returnType); + if (returnNhType == null) + { + return true; // Fallback to the old behavior + } + + var sqlFunction = _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction(sqlFunctionName); + if (sqlFunction == null) + { + return true; // Fallback to the old behavior + } + + var fnReturnType = sqlFunction.ReturnType(argumentType, _parameters.SessionFactory); + return fnReturnType == null || IsCastRequired(fnReturnType, returnNhType, out _); + } + + private IType GetType(Expression expression) + { + // Try to get the mapped type for the member as it may be a non default one + return expression.Type == typeof(object) + ? null + : (ExpressionsHelper.TryGetMappedType(_parameters.SessionFactory, expression, out var type, out _, out _, out _) + ? type + : TypeFactory.GetDefaultTypeFor(expression.Type)); + } } } diff --git a/src/NHibernate/Tuple/Entity/EntityMetamodel.cs b/src/NHibernate/Tuple/Entity/EntityMetamodel.cs index b12c49975c4..3b793d348a8 100644 --- a/src/NHibernate/Tuple/Entity/EntityMetamodel.cs +++ b/src/NHibernate/Tuple/Entity/EntityMetamodel.cs @@ -50,6 +50,8 @@ public class EntityMetamodel private readonly CascadeStyle[] cascadeStyles; private readonly Dictionary propertyIndexes = new Dictionary(); + private readonly IDictionary _identifierPropertyTypes = new Dictionary(); + private readonly IDictionary _propertyTypes = new Dictionary(); private readonly bool hasCollections; private readonly bool hasMutableProperties; private readonly bool hasLazyProperties; @@ -91,6 +93,7 @@ public EntityMetamodel(PersistentClass persistentClass, ISessionFactoryImplement identifierProperty = PropertyFactory.BuildIdentifierProperty(persistentClass, sessionFactory.GetIdentifierGenerator(rootName)); + MapIdentifierPropertyTypes(identifierProperty); versioned = persistentClass.IsVersioned; @@ -409,13 +412,44 @@ private bool HasPartialUpdateComponentGeneration(Mapping.Component component) private void MapPropertyToIndex(Mapping.Property prop, int i) { - propertyIndexes[prop.Name] = i; - Mapping.Component comp = prop.Value as Mapping.Component; - if (comp != null) + MapPropertyToIndex(null, prop, i); + } + + private void MapPropertyToIndex(string path, Mapping.Property prop, int i) + { + var propPath = !string.IsNullOrEmpty(path) ? $"{path}.{prop.Name}" : prop.Name; + propertyIndexes[propPath] = i; + _propertyTypes[propPath] = prop.Type; + if (!(prop.Value is Mapping.Component comp)) + { + return; + } + + foreach (var subprop in comp.PropertyIterator) { - foreach (Mapping.Property subprop in comp.PropertyIterator) + MapPropertyToIndex(propPath, subprop, i); + } + } + + private void MapIdentifierPropertyTypes(IdentifierProperty identifier) + { + MapIdentifierPropertyTypes(identifier.Name, identifier.Type); + } + + private void MapIdentifierPropertyTypes(string path, IType propertyType) + { + if (!string.IsNullOrEmpty(path)) + { + _identifierPropertyTypes[path] = propertyType; + } + + if (propertyType is IAbstractComponentType componentType) + { + for (var i = 0; i < componentType.PropertyNames.Length; i++) { - propertyIndexes[prop.Name + '.' + subprop.Name] = i; + MapIdentifierPropertyTypes( + !string.IsNullOrEmpty(path) ? $"{path}.{componentType.PropertyNames[i]}" : componentType.PropertyNames[i], + componentType.Subtypes[i]); } } } @@ -534,6 +568,18 @@ public int GetPropertyIndex(string propertyName) return null; } + internal IType GetIdentifierPropertyType(string memberPath) + { + return _identifierPropertyTypes.TryGetValue(memberPath, out var propertyType) ? propertyType : null; + } + + internal IType GetPropertyType(string memberPath) + { + return _propertyTypes.TryGetValue(memberPath, out var propertyType) + ? propertyType + : GetIdentifierPropertyType(memberPath); + } + public bool HasCollections { get { return hasCollections; } diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 6fa3a44615f..86fb8d04bd3 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -1,6 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; using System.Reflection; -using System; +using NHibernate.Engine; +using NHibernate.Linq; +using NHibernate.Linq.Expressions; +using NHibernate.Linq.Functions; +using NHibernate.Linq.Visitors; +using NHibernate.Persister.Collection; +using NHibernate.Persister.Entity; +using NHibernate.Type; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.Expressions; namespace NHibernate.Util { @@ -15,5 +27,734 @@ public static MemberInfo DecodeMemberAccessExpression(Expressi } return ((MemberExpression)expression.Body).Member; } + + /// + /// Try to get the mapped nullability from the given expression. + /// + /// The session factory. + /// The expression to evaluate. + /// Output parameter that represents whether the is nullable. + /// Whether the mapped nullability was found. + internal static bool TryGetMappedNullability( + ISessionFactoryImplementor sessionFactory, + Expression expression, + out bool nullable) + { + if (!TryGetMappedType( + sessionFactory, + expression, + out _, + out var entityPersister, + out var componentType, + out var memberPath)) + { + nullable = false; + return false; + } + + // The source entity is always not null, as it gets translated to the entity identifier + if (memberPath == null) + { + nullable = false; + return true; + } + + int index; + if (componentType != null) + { + index = Array.IndexOf( + componentType.PropertyNames, + memberPath.Substring(memberPath.LastIndexOf('.') + 1)); + nullable = componentType.PropertyNullability[index]; + return true; + } + + if (entityPersister.EntityMetamodel.GetIdentifierPropertyType(memberPath) != null) + { + nullable = false; // Identifier is always not-null + return true; + } + + index = entityPersister.EntityMetamodel.GetPropertyIndex(memberPath); + nullable = entityPersister.PropertyNullability[index]; + return true; + } + + /// + /// Try to get the mapped type from the given expression. When the type is + /// , the will be set based on the expression type + /// only when the mapping for was found, otherwise + /// will be returned. + /// + /// The session factory to retrieve types. + /// The expression to evaluate. + /// Output parameter that represents the mapped type of . + /// + /// Output parameter that represents the entity persister of the entity where is defined. + /// This parameter will not be set when represents a property in a collection composite element. + /// + /// + /// Output parameter that represents the component type where is defined. + /// This parameter will not be set when does not represent a property in a component. + /// + /// + /// Output parameter that represents the path of the mapped member, which in most cases is the member name. In case + /// when the mapped member is defined inside a component the path will be prefixed with the name of the component member and a dot. + /// (e.g. Component.Property). + /// Whether the mapped type was found. + /// + /// When the contains an expression of type , the + /// result may not be correct when casting to an entity that is mapped with multiple entity names. + /// When the is polymorphic, the first implementor will be returned. + /// When the contains a , the first found entity name + /// will be returned from or . + /// When the contains a expression, the first found entity name + /// will be returned from or . + /// + internal static bool TryGetMappedType( + ISessionFactoryImplementor sessionFactory, + Expression expression, + out IType mappedType, + out IEntityPersister entityPersister, + out IAbstractComponentType component, + out string memberPath) + { + // In order to get the correct entity name from the expression we first have to find the constant expression that contains the + // IEntityNameProvider instance, from which we can retrieve the starting entity name. Once we have it, we have to traverse all + // expressions that we had to traverse in order to find the IEntityNameProvider instance, but in reverse order (bottom to top) + // and keep tracking the entity name until we reach to top. + + memberPath = null; + mappedType = null; + entityPersister = null; + component = null; + // Try to retrieve the starting entity name with all members that were traversed in that process. + if (!MemberMetadataExtractor.TryGetAllMemberMetadata(expression, out var metadataResults)) + { + // Failed to find the starting entity name, due to: + // - Unsupported expression + // - The expression didn't contain the IEntityNameProvider instance + return false; + } + + // Due to coalesce and conditional expressions we can have multiple paths to traverse, in that case find the first path + // for which we are able to determine the mapped type. + foreach (var metadataResult in metadataResults) + { + if (ProcessMembersMetadataResult( + metadataResult, + sessionFactory, + out mappedType, + out entityPersister, + out component, + out memberPath)) + { + return true; + } + } + + return false; + } + + private static bool ProcessMembersMetadataResult( + MemberMetadataResult metadataResult, + ISessionFactoryImplementor sessionFactory, + out IType mappedType, + out IEntityPersister entityPersister, + out IAbstractComponentType component, + out string memberPath) + { + if (!TryGetEntityPersister(metadataResult.EntityName, null, sessionFactory, out var currentEntityPersister)) + { + // Failed to find the starting entity name, due to: + // - Querying a type that is not related to any entity e.g. s.Query().Where(a => a.Type == "A") + memberPath = null; + mappedType = null; + entityPersister = null; + component = null; + return false; + } + + if (metadataResult.MemberPaths.Count == 0) // The expression do not contain any member expressions + { + if (metadataResult.ConvertType != null) + { + mappedType = TryGetEntityPersister( + currentEntityPersister, + metadataResult.ConvertType, + sessionFactory, + out var convertPersister) + ? convertPersister.EntityMetamodel.EntityType // ((Subclass)q) + : TypeFactory.GetDefaultTypeFor(metadataResult.ConvertType); // ((NotMapped)q) + } + else + { + mappedType = currentEntityPersister.EntityMetamodel.EntityType; // q + } + + memberPath = null; + component = null; + entityPersister = currentEntityPersister; + return mappedType != null; + } + + // If there was a cast right after the constant expression that contains the IEntityNameProvider instance, we have + // to update the entity persister according to it, otherwise use the value returned by TryGetAllMemberMetadata method. + if (metadataResult.ConvertType != null) + { + if (!TryGetEntityPersister( + currentEntityPersister, + metadataResult.ConvertType, + sessionFactory, + out var convertPersister)) // ((NotMapped)q).Id + { + memberPath = null; + mappedType = null; + entityPersister = null; + component = null; + return false; + } + + currentEntityPersister = convertPersister; // ((Subclass)q).Id + } + + return TraverseMembers( + sessionFactory, + metadataResult.MemberPaths, + currentEntityPersister, + out mappedType, + out entityPersister, + out component, + out memberPath); + } + + private static bool TraverseMembers( + ISessionFactoryImplementor sessionFactory, + Stack memberPaths, + IEntityPersister currentEntityPersister, + out IType mappedType, + out IEntityPersister entityPersister, + out IAbstractComponentType component, + out string memberPath) + { + // Traverse the members that were traversed by the TryGetAllMemberMetadata method in the reverse order and try to keep + // tracking the entity persister until all members are traversed. + var member = memberPaths.Pop(); + var currentType = currentEntityPersister.EntityMetamodel.GetPropertyType(member.Path); + IAbstractComponentType currentComponentType = null; + while (memberPaths.Count > 0 && currentType != null) + { + memberPath = member.Path; + var convertType = member.ConvertType; + member = memberPaths.Pop(); + + switch (currentType) + { + case IAssociationType associationType: + ProcessAssociationType( + associationType, + sessionFactory, + member, + convertType, + out currentType, + out currentEntityPersister, + out currentComponentType); + break; + case IAbstractComponentType componentType: + currentComponentType = componentType; + if (currentEntityPersister == null) + { + // When persister is not available (q.OneToManyCompositeElement[0].Prop), try to get the type from the component + currentType = TryGetComponentPropertyType(componentType, member.Path); + } + else + { + // Concatenate the component property path in order to be able to use EntityMetamodel.GetPropertyType to retrieve the type. + // As GetPropertyType supports only components, do not concatenate when dealing with collection composite elements or elements. + // q.Component.Prop + member = new MemberMetadata( + $"{memberPath}.{member.Path}", + member.ConvertType, + member.HasIndexer); + + // q.Component.Prop + currentType = currentEntityPersister.EntityMetamodel.GetPropertyType(member.Path); + } + + break; + default: + // q.Prop.NotMappedProp + currentType = null; + currentEntityPersister = null; + currentComponentType = null; + break; + } + } + + // When traversed to the top of the expression, return the current tracking values + if (memberPaths.Count == 0) + { + memberPath = currentEntityPersister != null || currentComponentType != null ? member.Path : null; + mappedType = GetType(currentEntityPersister, currentType, member, sessionFactory); + entityPersister = currentEntityPersister; + component = currentComponentType; + return mappedType != null; + } + + // Member not mapped + memberPath = null; + mappedType = null; + entityPersister = null; + component = null; + return false; + } + + private static IType TryGetComponentPropertyType(IAbstractComponentType componentType, string memberPath) + { + var index = Array.IndexOf(componentType.PropertyNames, memberPath); + return index < 0 + ? null // q.OneToManyCompositeElement[0].NotMappedProp + : componentType.Subtypes[index]; // q.OneToManyCompositeElement[0].Prop + } + + private static void ProcessAssociationType( + IAssociationType associationType, + ISessionFactoryImplementor sessionFactory, + MemberMetadata member, + System.Type convertType, + out IType memberType, + out IEntityPersister memberPersister, + out IAbstractComponentType memberComponent) + { + if (associationType.IsCollectionType) + { + // Check manually for entity association as GetAssociatedEntityName throws when there is none. + var queryableCollection = + (IQueryableCollection) associationType.GetAssociatedJoinable(sessionFactory); + if (!queryableCollection.ElementType.IsEntityType) // q.OneToManyCompositeElement[0].Member, q.OneToManyElement[0].Member + { + memberPersister = null; + // Can be or + switch (queryableCollection.ElementType) + { + case IAbstractComponentType componentType: // q.OneToManyCompositeElement[0].Member + memberComponent = componentType; + memberType = TryGetComponentPropertyType(componentType, member.Path); + return; + default: // q.OneToManyElement[0].Member + memberType = null; + memberComponent = null; + return; + } + } + + // q.OneToMany[0].Member + TryGetEntityPersister( + associationType.GetAssociatedEntityName(sessionFactory), + convertType, + sessionFactory, + out memberPersister); + } + else if (associationType.IsAnyType) + { + // ((Address)q.AnyType).Member, q.AnyType.Member + // Unfortunately we cannot detect the exact entity name as cast does not provide it, + // so the only option is to guess it. + TryGetEntityPersister(convertType, sessionFactory, out memberPersister); + } + else // q.ManyToOne.Member + { + TryGetEntityPersister( + associationType.GetAssociatedEntityName(sessionFactory), + convertType, + sessionFactory, + out memberPersister); + } + + memberComponent = null; + memberType = memberPersister != null + ? memberPersister.EntityMetamodel.GetPropertyType(member.Path) + : null; // q.AnyType.Member, ((NotMappedClass)q.ManyToOne) + } + + private static bool TryGetEntityPersister( + string currentEntityName, + System.Type convertedType, + ISessionFactoryImplementor sessionFactory, + out IEntityPersister persister) + { + var currentEntityPersister = sessionFactory.TryGetEntityPersister(currentEntityName); + if (currentEntityPersister == null) + { + // When dealing with a polymorphic query it is not important which entity name we pick + // as they all need to have the same mapped types for members of the type that is queried. + // If one of the entites has a different type mapped (e.g. enum mapped as string instead of numeric), + // the query will fail to execute as currently the ParameterMetadata is bound to IQueryPlan and not to IQueryTranslator + // (e.g. s.Query().Where(a => a.MyEnum == MyEnum.Option)). + currentEntityName = sessionFactory.GetImplementors(currentEntityName).FirstOrDefault(); + if (currentEntityName == null) + { + persister = null; + return false; + } + + currentEntityPersister = sessionFactory.GetEntityPersister(currentEntityName); + } + + return TryGetEntityPersister(currentEntityPersister, convertedType, sessionFactory, out persister); + } + + private static bool TryGetEntityPersister( + IEntityPersister currentEntityPersister, + System.Type convertedType, + ISessionFactoryImplementor sessionFactory, + out IEntityPersister persister) + { + if (convertedType == null) + { + persister = currentEntityPersister; + return true; + } + + if (currentEntityPersister.EntityMetamodel.HasSubclasses) + { + // When a class is casted to a subclass e.g. ((PizzaOrder)c.Order).PizzaName, we + // can only guess the entity name of it, as there can be many entity names mapped + // to the same subclass. + persister = currentEntityPersister.EntityMetamodel.SubclassEntityNames + .Select(sessionFactory.GetEntityPersister) + .FirstOrDefault(p => p.MappedClass == convertedType); + + return persister != null; + } + + return TryGetEntityPersister(convertedType, sessionFactory, out persister); + } + + private static bool TryGetEntityPersister( + System.Type convertedType, + ISessionFactoryImplementor sessionFactory, + out IEntityPersister persister) + { + if (convertedType == null) + { + persister = null; + return false; + } + + var entityName = sessionFactory.TryGetGuessEntityName(convertedType); + if (entityName == null) + { + persister = null; + return false; + } + + persister = sessionFactory.GetEntityPersister(entityName); + return true; + } + + private static IType GetType( + IEntityPersister currentEntityPersister, + IType currentType, + MemberMetadata member, + ISessionFactoryImplementor sessionFactory) + { + // Not mapped + if (currentType == null) + { + return null; + } + + IEntityPersister persister; + if (!member.HasIndexer || currentEntityPersister == null) + { + if (member.ConvertType == null) + { + return currentType; // q.Prop, q.OneToManyCompositeElement[0].Prop + } + + return TryGetEntityPersister(member.ConvertType, sessionFactory, out persister) + ? persister.EntityMetamodel.EntityType // (Entity)q.Prop, (Entity)q.OneToManyCompositeElement[0].Prop + : TypeFactory.GetDefaultTypeFor(member.ConvertType); // (long)q.Prop, (long)q.OneToManyCompositeElement[0].Prop + } + + if (!(currentType is IAssociationType associationType)) + { + // q.Prop[0] + return null; + } + + var queryableCollection = (IQueryableCollection) associationType.GetAssociatedJoinable(sessionFactory); + if (member.ConvertType == null) + { + // q.OneToMany[0] + return queryableCollection.ElementType; + } + + return TryGetEntityPersister(member.ConvertType, sessionFactory, out persister) + ? persister.EntityMetamodel.EntityType // (Entity)q.OneToMany[0] + : TypeFactory.GetDefaultTypeFor(member.ConvertType); // (long)q.OneToMany[0] + } + + private class MemberMetadataExtractor : NhExpressionVisitor + { + private readonly List _childrenResults = new List(); + private readonly Stack _memberPaths; + private System.Type _convertType; + private bool _hasIndexer; + private string _entityName; + + private MemberMetadataExtractor(Stack memberPaths, System.Type convertType, bool hasIndexer) + { + _memberPaths = memberPaths; + _convertType = convertType; + _hasIndexer = hasIndexer; + } + + /// + /// Traverses the expression from top to bottom until the first containing an IEntityNameProvider + /// instance is found. + /// + /// The expression to traverse. + /// Output parameter that represents a collection, where each item contains information about all + /// that were traversed until the first containing an + /// instance is found. The number of items depends on how many different paths exist + /// in the that contains a instance. When + /// is not found or one of the expressions is not supported the parameter will be set to . + /// Whether was populated. + public static bool TryGetAllMemberMetadata(Expression expression, out List results) + { + if (TryGetAllMemberMetadata(expression, new Stack(), null, false, out var result)) + { + results = result.GetAllResults().ToList(); + return true; + } + + results = null; + return false; + } + + private static bool TryGetAllMemberMetadata( + Expression expression, + Stack memberPaths, + System.Type convertType, + bool hasIndexer, + out MemberMetadataResult results) + { + var extractor = new MemberMetadataExtractor(memberPaths, convertType, hasIndexer); + extractor.Accept(expression); + results = extractor._entityName != null || extractor._childrenResults.Count > 0 + ? new MemberMetadataResult( + extractor._childrenResults, + extractor._memberPaths, + extractor._entityName, + extractor._convertType) + : null; + + return results != null; + } + + private void Accept(Expression expression) + { + base.Visit(expression); + } + + protected override Expression VisitMember(MemberExpression node) + { + _memberPaths.Push(new MemberMetadata(node.Member.Name, _convertType, _hasIndexer)); + _convertType = null; + _hasIndexer = false; + return base.Visit(node.Expression); + } + + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression node) + { + if (node.ReferencedQuerySource is IFromClause fromClause) + { + return base.Visit(fromClause.FromExpression); + } + + if (node.ReferencedQuerySource is JoinClause joinClause) + { + return base.Visit(joinClause.InnerSequence); + } + + // Not supported expression + _entityName = null; + return node; + } + + protected override Expression VisitUnary(UnaryExpression node) + { + // Store only the outermost cast, when there are multiple casts for the same member + if (_convertType == null) + { + _convertType = node.Type; + } + + return base.Visit(node.Operand); + } + + protected internal override Expression VisitNhNominated(NhNominatedExpression node) + { + return base.Visit(node.Expression); + } + + protected override Expression VisitConstant(ConstantExpression node) + { + _entityName = node.Value is IEntityNameProvider entityNameProvider + ? entityNameProvider.EntityName + : null; // Not a NhQueryable + + return node; + } + + protected override Expression VisitBinary(BinaryExpression node) + { + if (node.NodeType == ExpressionType.ArrayIndex) + { + _hasIndexer = true; + return base.Visit(node.Left); + } + + if (node.NodeType == ExpressionType.Coalesce && + (TryGetMembersMetadata(node.Left) | TryGetMembersMetadata(node.Right))) + { + return node; + } + + return Visit(node); + } + + protected override Expression VisitConditional(ConditionalExpression node) + { + if (TryGetMembersMetadata(node.IfTrue) | TryGetMembersMetadata(node.IfFalse)) + { + return node; + } + + return Visit(node); + } + + protected override Expression VisitMethodCall(MethodCallExpression node) + { + if (ListIndexerGenerator.IsMethodSupported(node.Method)) + { + _hasIndexer = true; + return base.Visit( + node.Object == null + ? Enumerable.First(node.Arguments) // q.Children.ElementAt(0) + : node.Object // q.Children[0] + ); + } + + return Visit(node); + } + + public override Expression Visit(Expression node) + { + // Not supported expression + _entityName = null; + return node; + } + + private bool TryGetMembersMetadata(Expression expression) + { + if (TryGetAllMemberMetadata(expression, Clone(_memberPaths), _convertType, _hasIndexer, out var result)) + { + _childrenResults.Add(result); + return true; + } + + return false; + } + + private static Stack Clone(Stack original) + { + var arr = new T[original.Count]; + original.CopyTo(arr, 0); + Array.Reverse(arr); + return new Stack(arr); + } + } + + private struct MemberMetadata + { + public MemberMetadata(string path, System.Type convertType, bool hasIndexer) + { + Path = path; + ConvertType = convertType; + HasIndexer = hasIndexer; + } + + public string Path { get; } + + public System.Type ConvertType { get; } + + public bool HasIndexer { get; } + } + + private class MemberMetadataResult + { + public MemberMetadataResult( + List childrenResults, + Stack memberPaths, + string entityName, + System.Type convertType) + { + ChildrenResults = childrenResults; + MemberPaths = memberPaths; + EntityName = entityName; + ConvertType = convertType; + } + + /// + /// Metadata about all that were traversed. + /// + public Stack MemberPaths { get; } + + /// + /// type that was used on a containing + /// an . + /// + public System.Type ConvertType { get; } + + /// + /// The entity name from . + /// + public string EntityName { get; } + + /// + /// Direct children of the current metadata result. + /// + public List ChildrenResults { get; } + + /// + /// Gets all leaf (bottom) children that have the entity name set. + /// + /// + public IEnumerable GetAllResults() + { + return GetAllResults(this); + } + + private static IEnumerable GetAllResults(MemberMetadataResult result) + { + if (result.ChildrenResults.Count == 0) + { + yield return result; + } + else + { + foreach (var childResult in result.ChildrenResults) + { + foreach (var childChildrenResult in GetAllResults(childResult)) + { + yield return childChildrenResult; + } + } + } + } + } } -} \ No newline at end of file +}