Skip to content

Commit 00c0e07

Browse files
Detect identical expressions in extractAggregateFields to allow distinctAggregator to be used with orderAggregate
1 parent b7dbbeb commit 00c0e07

File tree

4 files changed

+71
-44
lines changed

4 files changed

+71
-44
lines changed

opaleye.cabal

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ library
2929
aeson >= 0.6 && < 2.3
3030
, base >= 4.9 && < 4.19
3131
, base16-bytestring >= 0.1.1.6 && < 1.1
32-
, case-insensitive >= 1.2 && < 1.3
3332
, bytestring >= 0.10 && < 0.12
33+
, case-insensitive >= 1.2 && < 1.3
34+
, containers >= 0.5 && < 0.8
3435
, contravariant >= 1.2 && < 1.6
3536
, postgresql-simple >= 0.6 && < 0.8
3637
, pretty >= 1.1.1.0 && < 1.2

src/Opaleye/Internal/Aggregate.hs

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,19 @@
22
module Opaleye.Internal.Aggregate where
33

44
import Control.Applicative (Applicative, liftA2, pure, (<*>))
5+
import Control.Arrow ((***))
56
import Data.Foldable (toList)
67
import Data.Traversable (for)
78

9+
import Data.Map.Strict (Map)
10+
import qualified Data.Map.Strict as Map
11+
812
import qualified Data.Profunctor as P
913
import qualified Data.Profunctor.Product as PP
1014

15+
import Control.Monad.Trans.Class (lift)
16+
import Control.Monad.Trans.State.Strict (StateT, gets, modify, runStateT)
17+
1118
import qualified Opaleye.Field as F
1219
import qualified Opaleye.Internal.Column as C
1320
import qualified Opaleye.Internal.Order as O
@@ -130,42 +137,61 @@ aggregatorApply = Aggregator $ PM.PackMap $ \f (agg, a) ->
130137
--
131138
-- Instead of detecting when we are aggregating over a field from a
132139
-- previous query we just create new names for all field before we
133-
-- aggregate. On the other hand, referring to a field from a previous
134-
-- query in an ORDER BY expression is totally fine!
140+
-- aggregate.
141+
--
142+
-- Additionally, PostgreSQL imposes a limitation on aggregations using ORDER
143+
-- BY in combination with DISTINCT - essentially the expression you pass to
144+
-- ORDER BY must also be present in the argument list to the aggregation
145+
-- function. This means that not only do we also have to also create new
146+
-- names for the ORDER BY expressions (if we only rewrite the function
147+
-- arguments then they can't match and therefore ORDER BY can never be used
148+
-- with DISTINCT), but that these names actually have to match the names
149+
-- created for the aggregation function arguments. To accomplish this, when
150+
-- traversing over the aggregations, we keep track of all the expressions
151+
-- we've encountered so far, and only create new names for new expressions,
152+
-- reusing old names where possible.
135153
aggregateU :: Aggregator a b
136154
-> (a, PQ.PrimQuery, T.Tag) -> (b, PQ.PrimQuery)
137-
aggregateU agg (c0, primQ, t0) = (c1, primQ')
138-
where (c1, projPEs_inners) =
139-
PM.run (runAggregator agg (extractAggregateFields t0) c0)
140-
141-
projPEs = map fst projPEs_inners
142-
inners = concatMap snd projPEs_inners
155+
aggregateU agg (a, primQ, tag) = (b, primQ')
156+
where
157+
(inners, outers, b) =
158+
runSymbols (runAggregator agg (extractAggregateFields tag) a)
143159

144-
primQ' = PQ.Aggregate projPEs (PQ.Rebind True inners primQ)
160+
primQ' = PQ.Aggregate outers (PQ.Rebind True inners primQ)
145161

146162
extractAggregateFields
147163
:: Traversable t
148164
=> T.Tag
149-
-> (t HPQ.PrimExpr)
150-
-> PM.PM [((HPQ.Symbol,
151-
t HPQ.Symbol),
152-
PQ.Bindings HPQ.PrimExpr)]
153-
HPQ.PrimExpr
165+
-> t HPQ.PrimExpr
166+
-> Symbols HPQ.PrimExpr (PQ.Bindings (t HPQ.Symbol)) HPQ.PrimExpr
154167
extractAggregateFields tag agg = do
155-
i <- PM.new
156-
157-
let souter = HPQ.Symbol ("result" ++ i) tag
158-
159-
bindings <- for agg $ \pe -> do
160-
j <- PM.new
161-
let sinner = HPQ.Symbol ("inner" ++ j) tag
162-
pure (sinner, pe)
163-
164-
let agg' = fmap fst bindings
168+
result <- mkSymbol "result" <$> lift PM.new
169+
agg' <- traverse (symbolize (mkSymbol "inner")) agg
170+
lift $ PM.write (result, agg')
171+
pure $ HPQ.AttrExpr result
172+
where
173+
mkSymbol name i = HPQ.Symbol (name ++ i) tag
165174

166-
PM.write ((souter, agg'), toList bindings)
175+
type Symbols e s =
176+
StateT
177+
(Map e HPQ.Symbol, PQ.Bindings e -> PQ.Bindings e)
178+
(PM.PM s)
167179

168-
pure (HPQ.AttrExpr souter)
180+
runSymbols :: Symbols e [s] a -> (PQ.Bindings e, [s], a)
181+
runSymbols m = (dlist [], outers, a)
182+
where
183+
((a, (_, dlist)), outers) = PM.run $ runStateT m (Map.empty, id)
184+
185+
symbolize :: Ord e =>
186+
(String -> HPQ.Symbol) -> e -> Symbols e s HPQ.Symbol
187+
symbolize f expr = do
188+
msymbol <- gets (Map.lookup expr . fst)
189+
case msymbol of
190+
Just symbol -> pure symbol
191+
Nothing -> do
192+
symbol <- f <$> lift PM.new
193+
modify (Map.insert expr symbol *** (. ((symbol, expr) :)))
194+
pure symbol
169195

170196
unsafeMax :: Aggregator (C.Field a) (C.Field a)
171197
unsafeMax = makeAggr HPQ.AggrMax

src/Opaleye/Internal/HaskellDB/PrimQuery.hs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ type Name = String
1717
type Scheme = [Attribute]
1818
type Assoc = [(Attribute,PrimExpr)]
1919

20-
data Symbol = Symbol String T.Tag deriving (Read, Show)
20+
data Symbol = Symbol String T.Tag deriving (Eq, Ord, Read, Show)
2121

2222
data PrimExpr = AttrExpr Symbol
2323
| BaseTableAttrExpr Attribute
@@ -40,7 +40,7 @@ data PrimExpr = AttrExpr Symbol
4040
| ArrayExpr [PrimExpr] -- ^ ARRAY[..]
4141
| RangeExpr String BoundExpr BoundExpr
4242
| ArrayIndex PrimExpr PrimExpr
43-
deriving (Read,Show)
43+
deriving (Eq, Ord, Read, Show)
4444

4545
data Literal = NullLit
4646
| DefaultLit -- ^ represents a default value
@@ -51,7 +51,7 @@ data Literal = NullLit
5151
| DoubleLit Double
5252
| NumericLit Sci.Scientific
5353
| OtherLit String -- ^ used for hacking in custom SQL
54-
deriving (Read,Show)
54+
deriving (Eq, Ord, Read, Show)
5555

5656
data BinOp = (:==) | (:<) | (:<=) | (:>) | (:>=) | (:<>)
5757
| OpAnd | OpOr
@@ -66,7 +66,7 @@ data BinOp = (:==) | (:<) | (:<=) | (:>) | (:>=) | (:<>)
6666
| (:->) | (:->>) | (:#>) | (:#>>)
6767
| (:@>) | (:<@) | (:?) | (:?|) | (:?&)
6868
| (:&&) | (:<<) | (:>>) | (:&<) | (:&>) | (:-|-)
69-
deriving (Show,Read)
69+
deriving (Eq, Ord, Read, Show)
7070

7171
data UnOp = OpNot
7272
| OpIsNull
@@ -77,22 +77,22 @@ data UnOp = OpNot
7777
| OpLower
7878
| OpUpper
7979
| UnOpOther String
80-
deriving (Show,Read)
80+
deriving (Eq, Ord, Read, Show)
8181

8282
data AggrOp = AggrCount | AggrSum | AggrAvg | AggrMin | AggrMax
8383
| AggrStdDev | AggrStdDevP | AggrVar | AggrVarP
8484
| AggrBoolOr | AggrBoolAnd | AggrArr | JsonArr
8585
| AggrStringAggr
8686
| AggrOther String
87-
deriving (Show,Read)
87+
deriving (Eq, Ord, Read, Show)
8888

8989
data AggrDistinct = AggrDistinct | AggrAll
90-
deriving (Eq,Show,Read)
90+
deriving (Eq, Ord, Read, Show)
9191

9292
type Aggregate = Aggregate' PrimExpr
9393

9494
data Aggregate' a = GroupBy a | Aggregate (Aggr' a)
95-
deriving (Functor, Foldable, Traversable, Show, Read)
95+
deriving (Functor, Foldable, Traversable, Eq, Ord, Read, Show)
9696

9797
data Aggr' a = Aggr
9898
{ aggrOp :: !AggrOp
@@ -102,25 +102,25 @@ data Aggr' a = Aggr
102102
, aggrGroup :: ![OrderExpr' a]
103103
, aggrFilter :: !(Maybe PrimExpr)
104104
}
105-
deriving (Functor, Foldable, Traversable, Show, Read)
105+
deriving (Functor, Foldable, Traversable, Eq, Ord, Read, Show)
106106

107107
type OrderExpr = OrderExpr' PrimExpr
108108

109109
data OrderExpr' a = OrderExpr OrderOp a
110-
deriving (Functor, Foldable, Traversable, Show, Read)
110+
deriving (Functor, Foldable, Traversable, Eq, Ord, Read, Show)
111111

112112
data OrderNulls = NullsFirst | NullsLast
113-
deriving (Show,Read)
113+
deriving (Eq, Ord, Read, Show)
114114

115115
data OrderDirection = OpAsc | OpDesc
116-
deriving (Show,Read)
116+
deriving (Eq, Ord, Read, Show)
117117

118118
data OrderOp = OrderOp { orderDirection :: OrderDirection
119119
, orderNulls :: OrderNulls }
120-
deriving (Show,Read)
120+
deriving (Eq, Ord, Read, Show)
121121

122122
data BoundExpr = Inclusive PrimExpr | Exclusive PrimExpr | PosInfinity | NegInfinity
123-
deriving (Show,Read)
123+
deriving (Eq, Ord, Read, Show)
124124

125125
data WndwOp
126126
= WndwRowNumber
@@ -135,10 +135,10 @@ data WndwOp
135135
| WndwLastValue PrimExpr
136136
| WndwNthValue PrimExpr PrimExpr
137137
| WndwAggregate AggrOp [PrimExpr]
138-
deriving (Show,Read)
138+
deriving (Eq, Ord, Read, Show)
139139

140140
data Partition = Partition
141141
{ partitionBy :: [PrimExpr]
142142
, orderBy :: [OrderExpr]
143143
}
144-
deriving (Read, Show)
144+
deriving (Eq, Ord, Read, Show)

src/Opaleye/Internal/Tag.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module Opaleye.Internal.Tag where
33
import Control.Monad.Trans.State.Strict ( get, modify', State )
44

55
-- | Tag is for use as a source of unique IDs in QueryArr
6-
newtype Tag = UnsafeTag Int deriving (Read, Show)
6+
newtype Tag = UnsafeTag Int deriving (Eq, Ord, Read, Show)
77

88
start :: Tag
99
start = UnsafeTag 1

0 commit comments

Comments
 (0)