Skip to content

Commit 77dd440

Browse files
maca88bahusoid
authored andcommitted
Fix decimal equality comparison for Sqlite (nhibernate#2807)
Fixes nhibernate#2792
1 parent 3f7ae00 commit 77dd440

File tree

8 files changed

+46
-3
lines changed

8 files changed

+46
-3
lines changed

src/NHibernate.Test/Async/Linq/OperatorTests.cs

+14
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,20 @@ public async Task UnaryMinusAsync()
3535
Assert.AreEqual(1, await (session.Query<TimesheetEntry>().CountAsync(a => -a.NumberOfHours == -7)));
3636
}
3737

38+
[Test]
39+
public async Task DecimalAddAsync()
40+
{
41+
decimal offset = 5.5m;
42+
decimal test = 10248 + offset;
43+
var result = await (session.Query<Order>().Where(e => offset + e.OrderId == test).ToListAsync());
44+
Assert.That(result, Has.Count.EqualTo(1));
45+
46+
offset = 5.5m;
47+
test = 32.38m + offset;
48+
result = await (session.Query<Order>().Where(e => offset + e.Freight == test).ToListAsync());
49+
Assert.That(result, Has.Count.EqualTo(1));
50+
}
51+
3852
[Test]
3953
public async Task UnaryPlusAsync()
4054
{

src/NHibernate.Test/Async/Linq/ParameterTests.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ public async Task CompareFloatingPointParametersAndColumnsAsync()
320320
totalParameters,
321321
sql =>
322322
{
323-
Assert.That(sql, Does.Not.Contain("cast"));
323+
Assert.That(sql, pair.Value == "Decimal" && Dialect.IsDecimalStoredAsFloatingPointNumber ? Does.Contain("cast") : Does.Not.Contain("cast"));
324324
Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters));
325325
}));
326326
}

src/NHibernate.Test/Linq/OperatorTests.cs

+14
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ public void UnaryMinus()
2424
Assert.AreEqual(1, session.Query<TimesheetEntry>().Count(a => -a.NumberOfHours == -7));
2525
}
2626

27+
[Test]
28+
public void DecimalAdd()
29+
{
30+
decimal offset = 5.5m;
31+
decimal test = 10248 + offset;
32+
var result = session.Query<Order>().Where(e => offset + e.OrderId == test).ToList();
33+
Assert.That(result, Has.Count.EqualTo(1));
34+
35+
offset = 5.5m;
36+
test = 32.38m + offset;
37+
result = session.Query<Order>().Where(e => offset + e.Freight == test).ToList();
38+
Assert.That(result, Has.Count.EqualTo(1));
39+
}
40+
2741
[Test]
2842
public void UnaryPlus()
2943
{

src/NHibernate.Test/Linq/ParameterTests.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ public void CompareFloatingPointParametersAndColumns()
308308
totalParameters,
309309
sql =>
310310
{
311-
Assert.That(sql, Does.Not.Contain("cast"));
311+
Assert.That(sql, pair.Value == "Decimal" && Dialect.IsDecimalStoredAsFloatingPointNumber ? Does.Contain("cast") : Does.Not.Contain("cast"));
312312
Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters));
313313
});
314314
}

src/NHibernate.Test/TestCase.cs

+3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
using NHibernate.Dialect;
1818
using NHibernate.Driver;
1919
using NHibernate.Engine.Query;
20+
using NHibernate.SqlTypes;
2021
using NHibernate.Util;
2122
using NSubstitute;
2223

@@ -525,6 +526,8 @@ protected void ClearQueryPlanCache()
525526
var forPartsOfMethod = ReflectHelper.GetMethodDefinition(() => Substitute.ForPartsOf<object>());
526527
var substitute = (Dialect.Dialect) forPartsOfMethod.MakeGenericMethod(origDialect.GetType())
527528
.Invoke(null, new object[] { new object[0] });
529+
substitute.GetCastTypeName(Arg.Any<SqlType>())
530+
.ReturnsForAnyArgs(x => origDialect.GetCastTypeName(x.ArgAt<SqlType>(0)));
528531

529532
dialectProperty.SetValue(Sfi.Settings, substitute);
530533

src/NHibernate/Dialect/Dialect.cs

+5
Original file line numberDiff line numberDiff line change
@@ -2611,6 +2611,11 @@ public virtual bool SupportsSqlBatches
26112611
get { return false; }
26122612
}
26132613

2614+
/// <summary>
2615+
/// Whether <see cref="decimal"/> is stored as a floating point number.
2616+
/// </summary>
2617+
public virtual bool IsDecimalStoredAsFloatingPointNumber => false;
2618+
26142619
public virtual bool IsKnownToken(string currentToken, string nextToken)
26152620
{
26162621
return false;

src/NHibernate/Dialect/SQLiteDialect.cs

+3
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,9 @@ public override bool GenerateTablePrimaryKeyConstraintForIdentityColumn
362362
get { return false; }
363363
}
364364

365+
/// <inheritdoc />
366+
public override bool IsDecimalStoredAsFloatingPointNumber => true;
367+
365368
public override string Qualify(string catalog, string schema, string table)
366369
{
367370
StringBuilder qualifiedName = new StringBuilder();

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

+5-1
Original file line numberDiff line numberDiff line change
@@ -579,8 +579,12 @@ protected HqlTreeNode VisitConstantExpression(ConstantExpression expression)
579579
if (_parameters.ConstantToParameterMap.TryGetValue(expression, out namedParameter))
580580
{
581581
_parameters.RequiredHqlParameters.Add(new NamedParameterDescriptor(namedParameter.Name, null, false));
582+
var parameter = _hqlTreeBuilder.Parameter(namedParameter.Name).AsExpression();
582583

583-
return _hqlTreeBuilder.Parameter(namedParameter.Name).AsExpression();
584+
// SQLite driver binds decimal parameters to text, which can cause unexpected results in arithmetic operations.
585+
return _parameters.SessionFactory.Dialect.IsDecimalStoredAsFloatingPointNumber && expression.Type.UnwrapIfNullable() == typeof(decimal)
586+
? _hqlTreeBuilder.TransparentCast(parameter, expression.Type)
587+
: parameter;
584588
}
585589

586590
return _hqlTreeBuilder.Constant(expression.Value);

0 commit comments

Comments
 (0)