Skip to content

ntf server: skip duplicates when importing tokens and subscriptions #1526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/Simplex/Messaging/Notifications/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ import UnliftIO.STM
import GHC.Conc (listThreads)
#endif

import qualified Data.ByteString.Base64 as B64

runNtfServer :: NtfServerConfig -> IO ()
runNtfServer cfg = do
started <- newEmptyTMVarIO
Expand Down
161 changes: 95 additions & 66 deletions src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ import Data.Containers.ListUtils (nubOrd)
import Data.Either (fromRight)
import Data.Functor (($>))
import Data.Int (Int64)
import Data.List (foldl', intercalate)
import Data.List (findIndex, foldl')
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes, fromMaybe, mapMaybe)
import Data.Maybe (fromMaybe, mapMaybe)
import qualified Data.Set as S
import Data.Text (Text)
import qualified Data.Text as T
Expand Down Expand Up @@ -587,89 +587,96 @@ toLastNtf (qRow :. (ts, nonce, Binary encMeta)) =

importNtfSTMStore :: NtfPostgresStore -> NtfSTMStore -> IO (Int64, Int64, Int64)
importNtfSTMStore NtfPostgresStore {dbStore = s} stmStore = do
(tCnt, tIds) <- importTokens
sCnt <- importSubscriptions tIds
nCnt <- importLastNtfs
(tIds, tCnt) <- importTokens
subLookup <- readTVarIO $ subscriptionLookup stmStore
sCnt <- importSubscriptions tIds subLookup
nCnt <- importLastNtfs tIds subLookup
pure (tCnt, sCnt, nCnt)
where
importTokens = do
allTokens <- M.elems <$> readTVarIO (tokens stmStore)
tokens <- filterTokens allTokens
let skipped = length allTokens - length tokens
when (skipped /= 0) $ putStrLn $ "Total skipped tokens " <> show skipped
tCnt <- withConnection s $ \db -> foldM (insertToken db) 0 tokens
void $ checkCount "token" (length tokens) tCnt
-- uncomment this line instead of the next to import tokens one by one.
-- tCnt <- withConnection s $ \db -> foldM (importTkn db) 0 tokens
tRows <- mapM (fmap ntfTknToRow . mkTknRec) tokens
tCnt <- withConnection s $ \db -> DB.executeMany db insertNtfTknQuery tRows
let tokenIds = S.fromList $ map (\NtfTknData {ntfTknId} -> ntfTknId) tokens
pure (tCnt, tokenIds)
(tokenIds,) <$> checkCount "token" (length tokens) tCnt
where
filterTokens tokens = do
let deviceTokens = foldl' (\m t -> M.alter (Just . (t :) . fromMaybe []) (tokenKey t) m) M.empty tokens
tokenSubs <- readTVarIO (tokenSubscriptions stmStore)
filterM (keepTokenRegistration deviceTokens tokenSubs) tokens
tokenKey NtfTknData {token, tknVerifyKey} = strEncode token <> ":" <> C.toPubKey C.pubKeyBytes tknVerifyKey
keepTokenRegistration deviceTokens tokenSubs tkn@NtfTknData {ntfTknId, token, tknStatus} =
keepTokenRegistration deviceTokens tokenSubs tkn@NtfTknData {ntfTknId, tknStatus} =
case M.lookup (tokenKey tkn) deviceTokens of
Just ts
| length ts >= 2 ->
| length ts < 2 -> pure True
| otherwise ->
readTVarIO tknStatus >>= \case
NTConfirmed -> do
anyActive <- anyM $ map (\NtfTknData {tknStatus = tknStatus'} -> (NTActive ==) <$> readTVarIO tknStatus') ts
noSubs <- S.null <$> maybe (pure S.empty) readTVarIO (M.lookup ntfTknId tokenSubs)
if anyActive
then (
if noSubs
then False <$ putStrLn ("Skipped inactive token " <> enc ntfTknId <> " (no subscriptions)")
else pure True
)
hasSubs <- maybe (pure False) (\v -> not . S.null <$> readTVarIO v) $ M.lookup ntfTknId tokenSubs
if hasSubs
then pure True
else do
let noSubsStr = if noSubs then " no subscriptions" else " has subscriptions"
putStrLn $ "Error: more than one registration for token " <> enc ntfTknId <> " " <> show token <> noSubsStr
pure True
anyActive <- anyM $ map (\NtfTknData {tknStatus = tknStatus'} -> (NTActive ==) <$> readTVarIO tknStatus') ts
if anyActive
then False <$ putStrLn ("Skipped duplicate inactive token " <> enc ntfTknId)
else case findIndex (\NtfTknData {ntfTknId = tId} -> tId == ntfTknId) ts of
Just 0 -> pure True -- keeping the first token
Just _ -> False <$ putStrLn ("Skipped duplicate inactive token " <> enc ntfTknId <> " (no active token)")
Nothing -> True <$ putStrLn "Error: no device token in the list"
_ -> pure True
| otherwise -> pure True
Nothing -> True <$ putStrLn "Error: no device token in lookup map"
insertToken db !n tkn@NtfTknData {ntfTknId} = do
tknRow <- ntfTknToRow <$> mkTknRec tkn
(DB.execute db insertNtfTknQuery tknRow >>= pure . (n + )) `E.catch` \(e :: E.SomeException) ->
putStrLn ("Error inserting token " <> enc ntfTknId <> " " <> show e) $> n
importSubscriptions tIds = do
allSubs <- M.elems <$> readTVarIO (subscriptions stmStore)
let subs = filter (\NtfSubData {tokenId} -> S.member tokenId tIds) allSubs
skipped = length allSubs - length subs
when (skipped /= 0) $ putStrLn $ "Skipped subscriptions (no tokens) " <> show skipped
-- importTkn db !n tkn@NtfTknData {ntfTknId} = do
-- tknRow <- ntfTknToRow <$> mkTknRec tkn
-- (DB.execute db insertNtfTknQuery tknRow >>= pure . (n + )) `E.catch` \(e :: E.SomeException) ->
-- putStrLn ("Error inserting token " <> enc ntfTknId <> " " <> show e) $> n
importSubscriptions :: S.Set NtfTokenId -> M.Map SMPQueueNtf NtfSubscriptionId -> IO Int64
importSubscriptions tIds subLookup = do
subs <- filterSubs . M.elems =<< readTVarIO (subscriptions stmStore)
srvIds <- importServers subs
putStrLn $ "Importing " <> show (length subs) <> " subscriptions..."
-- uncomment this line instead of the next 2 lines to import subs one by one.
(sCnt, missingTkns) <- withConnection s $ \db -> foldM (importSub db srvIds) (0, M.empty) subs
-- sCnt <- foldM (importSubs srvIds) 0 $ toChunks 100000 subs
-- let missingTkns = M.empty
putStrLn $ "Imported " <> show sCnt <> " subscriptions"
unless (M.null missingTkns) $ do
putStrLn $ show (M.size missingTkns) <> " missing tokens:"
forM_ (M.assocs missingTkns) $ \(tId, sIds) ->
putStrLn $ "Token " <> enc tId <> " " <> show (length sIds) <> " subscriptions: " <> intercalate ", " (map enc sIds)
-- uncomment this line instead of the next to import subs one by one.
-- (sCnt, errTkns) <- withConnection s $ \db -> foldM (importSub db srvIds) (0, M.empty) subs
sCnt <- foldM (importSubs srvIds) 0 $ toChunks 500000 subs
checkCount "subscription" (length subs) sCnt
where
filterSubs allSubs = do
let subs = filter (\NtfSubData {tokenId} -> S.member tokenId tIds) allSubs
skipped = length allSubs - length subs
when (skipped /= 0) $ putStrLn $ "Skipped " <> show skipped <> " subscriptions of missing tokens"
let (removedSubTokens, removeSubs, dupQueues) = foldl' addSubToken (S.empty, S.empty, S.empty) subs
unless (null removeSubs) $ putStrLn $ "Skipped " <> show (S.size removeSubs) <> " duplicate subscriptions of " <> show (S.size removedSubTokens) <> " tokens for " <> show (S.size dupQueues) <> " queues"
pure $ filter (\NtfSubData {ntfSubId} -> S.notMember ntfSubId removeSubs) subs
where
addSubToken acc@(!stIds, !sIds, !qs) NtfSubData {ntfSubId, smpQueue, tokenId} =
case M.lookup smpQueue subLookup of
Just sId | sId /= ntfSubId ->
(S.insert tokenId stIds, S.insert ntfSubId sIds, S.insert smpQueue qs)
_ -> acc
importSubs srvIds !n subs = do
rows <- mapM (ntfSubRow srvIds) subs
cnt <- withConnection s $ \db -> DB.executeMany db insertNtfSubQuery $ L.toList rows
let n' = n + cnt
putStr $ "Imported " <> show n' <> " subscriptions" <> "\r"
hFlush stdout
pure n'
importSub db srvIds (!n, !missingTkns) sub@NtfSubData {ntfSubId = sId, tokenId} = do
subRow <- ntfSubRow srvIds sub
E.try (DB.execute db insertNtfSubQuery subRow) >>= \case
Right i -> do
let n' = n + i
when (n' `mod` 100000 == 0) $ do
putStr $ "Imported " <> show n' <> " subscriptions" <> "\r"
hFlush stdout
pure (n', missingTkns)
Left (e :: E.SomeException) -> do
when (n `mod` 100000 == 0) $ putStrLn ""
putStrLn $ "Error inserting subscription " <> enc sId <> " for token " <> enc tokenId <> " " <> show e
pure (n, M.alter (Just . (sId :) . fromMaybe []) tokenId missingTkns)
-- importSub db srvIds (!n, !errTkns) sub@NtfSubData {ntfSubId = sId, tokenId} = do
-- subRow <- ntfSubRow srvIds sub
-- E.try (DB.execute db insertNtfSubQuery subRow) >>= \case
-- Right i -> do
-- let n' = n + i
-- when (n' `mod` 100000 == 0) $ do
-- putStr $ "Imported " <> show n' <> " subscriptions" <> "\r"
-- hFlush stdout
-- pure (n', errTkns)
-- Left (e :: E.SomeException) -> do
-- when (n `mod` 100000 == 0) $ putStrLn ""
-- putStrLn $ "Error inserting subscription " <> enc sId <> " for token " <> enc tokenId <> " " <> show e
-- pure (n, M.alter (Just . maybe [sId] (sId :)) tokenId errTkns)
ntfSubRow srvIds sub = case M.lookup srv srvIds of
Just sId -> ntfSubToRow sId <$> mkSubRec sub
Nothing -> E.throwIO $ userError $ "no matching server ID for server " <> show srv
Expand All @@ -682,19 +689,32 @@ importNtfSTMStore NtfPostgresStore {dbStore = s} stmStore = do
where
srvQuery = "INSERT INTO smp_servers (smp_host, smp_port, smp_keyhash) VALUES (?, ?, ?) RETURNING smp_server_id"
srvs = nubOrd $ map ntfSubServer subs
importLastNtfs = do
subLookup <- readTVarIO $ subscriptionLookup stmStore
ntfRows <- fmap concat . mapM (lastNtfRows subLookup) . M.assocs =<< readTVarIO (tokenLastNtfs stmStore)
importLastNtfs :: S.Set NtfTokenId -> M.Map SMPQueueNtf NtfSubscriptionId -> IO Int64
importLastNtfs tIds subLookup = do
ntfs <- readTVarIO (tokenLastNtfs stmStore)
ntfRows <- filterLastNtfRows ntfs
nCnt <- withConnection s $ \db -> DB.executeMany db lastNtfQuery ntfRows
checkCount "last notification" (length ntfRows) nCnt
where
lastNtfQuery = "INSERT INTO last_notifications(token_id, subscription_id, sent_at, nmsg_nonce, nmsg_data) VALUES (?,?,?,?,?)"
lastNtfRows :: M.Map SMPQueueNtf NtfSubscriptionId -> (NtfTokenId, TVar (NonEmpty PNMessageData)) -> IO [(NtfTokenId, NtfSubscriptionId, SystemTime, C.CbNonce, Binary ByteString)]
lastNtfRows subLookup (tId, ntfs) = fmap catMaybes . mapM ntfRow . L.toList =<< readTVarIO ntfs
filterLastNtfRows ntfs = do
(skippedTkns, ntfCnt, (skippedQueues, ntfRows)) <- foldM lastNtfRows (S.empty, 0, (S.empty, [])) $ M.assocs ntfs
let skipped = ntfCnt - length ntfRows
when (skipped /= 0) $ putStrLn $ "Skipped last notifications " <> show skipped <> " for " <> show (S.size skippedTkns) <> " missing tokens and " <> show (S.size skippedQueues) <> " missing subscriptions with token present"
pure ntfRows
lastNtfRows (!stIds, !cnt, !acc) (tId, ntfVar) = do
ntfs <- L.toList <$> readTVarIO ntfVar
let cnt' = cnt + length ntfs
pure $
if S.member tId tIds
then (stIds, cnt', foldl' ntfRow acc ntfs)
else (S.insert tId stIds, cnt', acc)
where
ntfRow PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} = case M.lookup smpQueue subLookup of
Just ntfSubId -> pure $ Just (tId, ntfSubId, ntfTs, nmsgNonce, Binary encNMsgMeta)
Nothing -> Nothing <$ putStrLn ("Error: no subscription " <> show smpQueue <> " for notification of token " <> enc tId)
ntfRow (!qs, !rows) PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} = case M.lookup smpQueue subLookup of
Just ntfSubId ->
let row = (tId, ntfSubId, ntfTs, nmsgNonce, Binary encNMsgMeta)
in (qs, row : rows)
Nothing -> (S.insert smpQueue qs, rows)
checkCount name expected inserted
| fromIntegral expected == inserted = do
putStrLn $ "Imported " <> show inserted <> " " <> name <> "s."
Expand All @@ -711,12 +731,21 @@ exportNtfDbStore NtfPostgresStore {dbStoreLog = Nothing} _ =
exportNtfDbStore NtfPostgresStore {dbStore = s, dbStoreLog = Just sl} lastNtfsFile =
(,,) <$> exportTokens <*> exportSubscriptions <*> exportLastNtfs
where
exportTokens =
withConnection s $ \db -> DB.fold_ db ntfTknQuery 0 $ \ !i tkn ->
exportTokens = do
tCnt <- withConnection s $ \db -> DB.fold_ db ntfTknQuery 0 $ \ !i tkn ->
logCreateToken sl (rowToNtfTkn tkn) $> (i + 1)
exportSubscriptions =
withConnection s $ \db -> DB.fold_ db ntfSubQuery 0 $ \ !i sub ->
logCreateSubscription sl (toNtfSub sub) $> (i + 1)
putStrLn $ "Exported " <> show tCnt <> " tokens"
pure tCnt
exportSubscriptions = do
sCnt <- withConnection s $ \db -> DB.fold_ db ntfSubQuery 0 $ \ !i sub -> do
let i' = i + 1
logCreateSubscription sl (toNtfSub sub)
when (i' `mod` 500000 == 0) $ do
putStr $ "Exported " <> show i' <> " subscriptions" <> "\r"
hFlush stdout
pure i'
putStrLn $ "Exported " <> show sCnt <> " subscriptions"
pure sCnt
where
ntfSubQuery =
[sql|
Expand Down
4 changes: 2 additions & 2 deletions tests/AgentTests/NotificationTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -570,15 +570,15 @@ testNotificationSubscriptionExistingConnection apns baseId alice@AgentClient {ag
threadDelay 500000
suspendAgent alice 0
closeDBStore store
callCommand "sync"
threadDelay 500000 >> callCommand "sync" >> threadDelay 500000
putStrLn "before opening the database from another agent"

-- aliceNtf client doesn't have subscription and is allowed to get notification message
withAgent 3 aliceCfg initAgentServers testDB $ \aliceNtf -> do
(Just SMPMsgMeta {msgFlags = MsgFlags True}) :| _ <- getConnectionMessages aliceNtf [cId]
pure ()

callCommand "sync"
threadDelay 500000 >> callCommand "sync" >> threadDelay 500000
putStrLn "after closing the database in another agent"
reopenDBStore store
foregroundAgent alice
Expand Down
17 changes: 14 additions & 3 deletions tests/CLITests.hs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NamedFieldPuns #-}

module CLITests where

import AgentTests.FunctionalAPITests (runRight_)
import Control.Logger.Simple
import Control.Monad
import qualified Crypto.PubKey.RSA as RSA
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.HashMap.Strict as HM
import Data.Ini (Ini (..), lookupValue, readIniFile, writeIniFile)
Expand Down Expand Up @@ -41,8 +43,11 @@ import UnliftIO.Concurrent (threadDelay)
import UnliftIO.Exception (bracket)

#if defined(dbServerPostgres)
import NtfClient (ntfTestServerDBConnectInfo)
import qualified Database.PostgreSQL.Simple as PSQL
import Database.PostgreSQL.Simple.Types (Query (..))
import NtfClient (ntfTestServerDBConnectInfo, ntfTestServerDBConnstr, ntfTestStoreDBOpts)
import SMPClient (postgressBracket)
import Simplex.Messaging.Agent.Store.Postgres.Options (DBOpts (..))
import Simplex.Messaging.Notifications.Server.Main
#endif

Expand Down Expand Up @@ -77,7 +82,7 @@ cliTests = do
it "with store log, no password" $ smpServerTest True False
it "static files" smpServerTestStatic
#if defined(dbServerPostgres)
aroundAll_ (postgressBracket ntfTestServerDBConnectInfo) $
around_ (postgressBracket ntfTestServerDBConnectInfo) $ before_ (createNtfSchema ntfTestServerDBConnectInfo ntfTestStoreDBOpts) $
describe "Ntf server CLI" $ do
it "should initialize, start and delete the server (no store log)" $ ntfServerTest False
it "should initialize, start and delete the server (with store log)" $ ntfServerTest True
Expand Down Expand Up @@ -192,9 +197,15 @@ smpServerTestStatic = do
in map (X.signedObject . X.getSigned) cc

#if defined(dbServerPostgres)
createNtfSchema :: PSQL.ConnectInfo -> DBOpts -> IO ()
createNtfSchema connInfo DBOpts {schema} = do
db <- PSQL.connect connInfo
void $ PSQL.execute_ db $ Query $ "CREATE SCHEMA " <> schema
PSQL.close db

ntfServerTest :: Bool -> IO ()
ntfServerTest storeLog = do
capture_ (withArgs (["init"] <> ["--disable-store-log" | not storeLog]) $ ntfServerCLI ntfCfgPath ntfLogPath)
capture_ (withArgs (["init", "--database=" <> B.unpack ntfTestServerDBConnstr] <> ["--disable-store-log" | not storeLog]) $ ntfServerCLI ntfCfgPath ntfLogPath)
>>= (`shouldSatisfy` (("Server initialized, you can modify configuration in " <> ntfCfgPath <> "/ntf-server.ini") `isPrefixOf`))
Right ini <- readIniFile $ ntfCfgPath <> "/ntf-server.ini"
lookupValue "STORE_LOG" "enable" ini `shouldBe` Right (if storeLog then "on" else "off")
Expand Down
Loading