Skip to content

Fix named parameter leaking to query cache #3144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions src/NHibernate.Test/NHSpecificTest/GH3030/ByCodeFixture.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using System;
using System.Linq;
using NHibernate.Cfg.MappingSchema;
using NHibernate.Mapping.ByCode;
using NUnit.Framework;

namespace NHibernate.Test.NHSpecificTest.GH3030
{
[TestFixture]
public class ByCodeFixture : TestCaseMappingByCode
{
protected override HbmMapping GetMappings()
{
var mapper = new ModelMapper();
mapper.Class<Entity>(
rc =>
{
rc.Table("Entity");
rc.Id(x => x.Id, m => m.Generator(Generators.Assigned));
});

return mapper.CompileMappingForAllExplicitlyAddedEntities();
}

protected override void OnTearDown()
{
using (var session = OpenSession())
using (var transaction = session.BeginTransaction())
{
session.CreateQuery("delete from System.Object").ExecuteUpdate();
transaction.Commit();
}
}

[Test]
public void LinqShouldNotLeakEntityParameters()
{
WeakReference sessionReference = null;
WeakReference firstReference = null;
WeakReference secondReference = null;

new System.Action(
() =>
{
using (var session = ((DebugSessionFactory) Sfi).ActualFactory.OpenSession())
{
var first = new Entity { Id = 1 };
var second = new Entity { Id = 2 };

_ = session.Query<Entity>().FirstOrDefault(f => f == first);
_ = session.Query<Entity>().FirstOrDefault(f => f == second);

sessionReference = new WeakReference(session, true);
firstReference = new WeakReference(first, true);
secondReference = new WeakReference(second, true);
}
})();

GC.Collect();
GC.WaitForPendingFinalizers();

Assert.That(sessionReference.Target, Is.Null);
Assert.That(firstReference.Target, Is.Null);
Assert.That(secondReference.Target, Is.Null);
}

public class Entity
{
public virtual int Id { get; set; }
}
}
}
1 change: 0 additions & 1 deletion src/NHibernate/Engine/Query/QueryExpressionPlan.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using System;
using System.Collections.Generic;
using NHibernate.Hql;
using NHibernate.Linq;

namespace NHibernate.Engine.Query
{
Expand Down
24 changes: 16 additions & 8 deletions src/NHibernate/Engine/Query/QueryPlanCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ public IQueryExpressionPlan GetHQLQueryPlan(IQueryExpression queryExpression, bo
{
log.Debug("unable to locate HQL query plan in cache; generating ({0})", queryExpression.Key);
}

plan = new QueryExpressionPlan(queryExpression, shallow, enabledFilters, factory);
// 6.0 TODO: add "CanCachePlan { get; }" to IQueryExpression interface
if (!(queryExpression is ICacheableQueryExpression linqExpression) || linqExpression.CanCachePlan)
planCache.Put(key, plan);
planCache.Put(key, PreparePlanToCache(plan));
else
log.Debug("Query plan not cacheable");
}
Expand All @@ -79,23 +80,30 @@ public IQueryExpressionPlan GetHQLQueryPlan(IQueryExpression queryExpression, bo
return plan;
}

private QueryExpressionPlan PreparePlanToCache(QueryExpressionPlan plan)
{
if (plan.QueryExpression is NhLinqExpression planExpression)
{
return plan.Copy(new NhLinqExpressionCache(planExpression));
}

return plan;
}

private static QueryExpressionPlan CopyIfRequired(QueryExpressionPlan plan, IQueryExpression queryExpression)
{
var planExpression = plan.QueryExpression as NhLinqExpression;
var expression = queryExpression as NhLinqExpression;
if (planExpression != null && expression != null)
if (plan.QueryExpression is NhLinqExpressionCache cache && queryExpression is NhLinqExpression expression)
{
//NH-3413
//Here we have to use original expression.
//In most cases NH do not translate expression in second time, but
// for cases when we have list parameters in query, like @p1.Contains(...),
// it does, and then it uses parameters from first try.
//TODO: cache only required parts of QueryExpression

//NH-3436
// We have to return new instance plan with it's own query expression
// because other treads can override queryexpression of current plan during execution of query if we will use cached instance of plan
expression.CopyExpressionTranslation(planExpression);
// because other treads can override query expression of current plan during execution of query if we will use cached instance of plan
expression.CopyExpressionTranslation(cache);
plan = plan.Copy(expression);
}

Expand All @@ -118,7 +126,7 @@ public IQueryExpressionPlan GetFilterQueryPlan(IQueryExpression queryExpression,
plan = new FilterQueryPlan(queryExpression, collectionRole, shallow, enabledFilters, factory);
// 6.0 TODO: add "CanCachePlan { get; }" to IQueryExpression interface
if (!(queryExpression is ICacheableQueryExpression linqExpression) || linqExpression.CanCachePlan)
planCache.Put(key, plan);
planCache.Put(key, PreparePlanToCache(plan));
else
log.Debug("Query plan not cacheable");
}
Expand Down
2 changes: 1 addition & 1 deletion src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ static IQueryTranslator[] CreateQueryTranslators(

var translators = polymorphicParsers
.ToArray(hql => queryExpression is NhLinqExpression linqExpression
? new QueryTranslatorImpl(queryIdentifier, hql, filters, factory, linqExpression.NamedParameters)
? new QueryTranslatorImpl(queryIdentifier, hql, filters, factory, linqExpression.GetNamedParameterTypes())
: new QueryTranslatorImpl(queryIdentifier, hql, filters, factory));

foreach (var translator in translators)
Expand Down
28 changes: 7 additions & 21 deletions src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ public partial class HqlSqlWalker
private readonly LiteralProcessor _literalProcessor;

private readonly IDictionary<string, string> _tokenReplacements;
private readonly IDictionary<string, NamedParameter> _namedParameters;
private readonly IDictionary<IParameterSpecification, IType> _guessedParameterTypes = new Dictionary<IParameterSpecification, IType>();

private JoinType _impliedJoinType;
Expand All @@ -65,32 +64,20 @@ public partial class HqlSqlWalker
private IASTFactory _nodeFactory;
private readonly List<AssignmentSpecification> assignmentSpecifications = new List<AssignmentSpecification>();
private int numberOfParametersInSetClause;
private Stack<int> clauseStack=new Stack<int>();
private Stack<int> clauseStack = new Stack<int>();

public HqlSqlWalker(
QueryTranslatorImpl qti,
ISessionFactoryImplementor sfi,
ITreeNodeStream input,
IDictionary<string, string> tokenReplacements,
string collectionRole)
: this(qti, sfi, input, tokenReplacements, null, collectionRole)
{
}

internal HqlSqlWalker(
QueryTranslatorImpl qti,
ISessionFactoryImplementor sfi,
ITreeNodeStream input,
IDictionary<string, string> tokenReplacements,
IDictionary<string, NamedParameter> namedParameters,
string collectionRole)
: this(input)
{
_sessionFactoryHelper = new SessionFactoryHelperExtensions(sfi);
_qti = qti;
_literalProcessor = new LiteralProcessor(this);
_tokenReplacements = tokenReplacements;
_namedParameters = namedParameters;
_collectionFilterRole = collectionRole;
}

Expand Down Expand Up @@ -313,8 +300,7 @@ void PostProcessInsert(IASTNode insert)

IASTNode idSelectExprNode = null;

var seqGenerator = generator as SequenceGenerator;
if (seqGenerator != null)
if (generator is SequenceGenerator seqGenerator)
{
string seqName = seqGenerator.GeneratorKey();
string nextval = SessionFactoryHelper.Factory.Dialect.GetSelectSequenceNextValString(seqName);
Expand Down Expand Up @@ -1083,17 +1069,17 @@ IASTNode GenerateNamedParameter(IASTNode delimiterNode, IASTNode nameNode)
);

parameter.HqlParameterSpecification = paramSpec;
if (_namedParameters != null && _namedParameters.TryGetValue(name, out var namedParameter))
if (_qti.TryGetNamedParameterType(name, out var type, out var isGuessedType))
{
// Add the parameter type information so that we are able to calculate functions return types
// when the parameter is used as an argument.
if (namedParameter.IsGuessedType)
if (isGuessedType)
{
_guessedParameterTypes[paramSpec] = namedParameter.Type;
parameter.GuessedType = namedParameter.Type;
_guessedParameterTypes[paramSpec] = type;
parameter.GuessedType = type;
}
else
parameter.ExpectedType = namedParameter.Type;
parameter.ExpectedType = type;
}

_parameters.Add(paramSpec);
Expand Down
27 changes: 19 additions & 8 deletions src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public partial class QueryTranslatorImpl : IFilterTranslator
private readonly string _queryIdentifier;
private readonly IASTNode _stageOneAst;
private readonly ISessionFactoryImplementor _factory;
private readonly IDictionary<string, NamedParameter> _namedParameters;
private readonly IDictionary<string, Tuple<IType, bool>> _namedParameters;

private bool _shallowQuery;
private bool _compiled;
Expand Down Expand Up @@ -69,7 +69,7 @@ internal QueryTranslatorImpl(
IASTNode parsedQuery,
IDictionary<string, IFilter> enabledFilters,
ISessionFactoryImplementor factory,
IDictionary<string, NamedParameter> namedParameters)
IDictionary<string, Tuple<IType, bool>> namedParameters)
{
_queryIdentifier = queryIdentifier;
_stageOneAst = parsedQuery;
Expand Down Expand Up @@ -454,7 +454,7 @@ private static IStatementExecutor BuildAppropriateStatementExecutor(IStatement s

private HqlSqlTranslator Analyze(string collectionRole)
{
var translator = new HqlSqlTranslator(_stageOneAst, this, _factory, _tokenReplacements, _namedParameters, collectionRole);
var translator = new HqlSqlTranslator(_stageOneAst, this, _factory, _tokenReplacements, collectionRole);

translator.Translate();

Expand All @@ -468,6 +468,20 @@ private void ErrorIfDML()
throw new QueryExecutionRequestException("Not supported for DML operations", _queryIdentifier);
}
}

public bool TryGetNamedParameterType(string name, out IType type, out bool isGuessedType)
{
if (_namedParameters == null || !_namedParameters.TryGetValue(name, out var p))
{
type = null;
isGuessedType = false;
return false;
}

type = p.Item1;
isGuessedType = p.Item2;
return true;
}
}

public class HqlParseEngine
Expand Down Expand Up @@ -568,23 +582,20 @@ internal class HqlSqlTranslator
private readonly QueryTranslatorImpl _qti;
private readonly ISessionFactoryImplementor _sfi;
private readonly IDictionary<string, string> _tokenReplacements;
private readonly IDictionary<string, NamedParameter> _namedParameters;
private readonly string _collectionRole;
private IStatement _resultAst;

public HqlSqlTranslator(
internal HqlSqlTranslator(
IASTNode ast,
QueryTranslatorImpl qti,
ISessionFactoryImplementor sfi,
IDictionary<string, string> tokenReplacements,
IDictionary<string, NamedParameter> namedParameters,
string collectionRole)
{
_inputAst = ast;
_qti = qti;
_sfi = sfi;
_tokenReplacements = tokenReplacements;
_namedParameters = namedParameters;
_collectionRole = collectionRole;
}

Expand All @@ -604,7 +615,7 @@ public IStatement Translate()

var nodes = new BufferedTreeNodeStream(_inputAst);

var hqlSqlWalker = new HqlSqlWalker(_qti, _sfi, nodes, _tokenReplacements, _namedParameters, _collectionRole);
var hqlSqlWalker = new HqlSqlWalker(_qti, _sfi, nodes, _tokenReplacements, _collectionRole);
hqlSqlWalker.TreeAdaptor = new HqlSqlWalkerTreeAdaptor(hqlSqlWalker);

try
Expand Down
14 changes: 10 additions & 4 deletions src/NHibernate/Linq/NhLinqExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,18 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter
return DuplicateTree(ExpressionToHqlTranslationResults.Statement.AstNode);
}

internal void CopyExpressionTranslation(NhLinqExpression other)
internal void CopyExpressionTranslation(NhLinqExpressionCache cache)
{
ExpressionToHqlTranslationResults = other.ExpressionToHqlTranslationResults;
ParameterDescriptors = other.ParameterDescriptors;
ExpressionToHqlTranslationResults = cache.ExpressionToHqlTranslationResults;
ParameterDescriptors = cache.ParameterDescriptors;
// Type could have been overridden by translation.
Type = other.Type;
Type = cache.Type;
}

internal IDictionary<string, Tuple<IType, bool>> GetNamedParameterTypes()
{
return _constantToParameterMap.Values.Distinct()
.ToDictionary(p => p.Name, p => System.Tuple.Create(p.Type, p.IsGuessedType));
}

private static IASTNode DuplicateTree(IASTNode ast)
Expand Down
29 changes: 29 additions & 0 deletions src/NHibernate/Linq/NhLinqExpressionCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using NHibernate.Engine;
using NHibernate.Engine.Query;
using NHibernate.Hql.Ast.ANTLR.Tree;

namespace NHibernate.Linq
{
internal class NhLinqExpressionCache : IQueryExpression
{
internal NhLinqExpressionCache(NhLinqExpression expression)
{
ExpressionToHqlTranslationResults = expression.ExpressionToHqlTranslationResults ?? throw new ArgumentException("NhLinqExpression is not translated");
Key = expression.Key;
Type = expression.Type;
ParameterDescriptors = expression.ParameterDescriptors;
}

public ExpressionToHqlTranslationResults ExpressionToHqlTranslationResults { get; }
public string Key { get; }
public System.Type Type { get; }
public IList<NamedParameterDescriptor> ParameterDescriptors { get; }

public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter)
{
return ExpressionToHqlTranslationResults.Statement.AstNode;
}
}
}