Skip to content

Commit e4dd465

Browse files
committed
DeclareDatabaseContextFunction now takes a rolename-based acl because it faces the user
1 parent 6b604f8 commit e4dd465

File tree

7 files changed

+117
-105
lines changed

7 files changed

+117
-105
lines changed

examples/zoo.hs

Lines changed: 27 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,28 @@
1-
{-# LANGUAGE DeriveGeneric, DerivingVia, DeriveAnyClass, OverloadedStrings #-}
2-
import ProjectM36.Client
3-
import ProjectM36.Relation
4-
import ProjectM36.Tupleable
5-
import ProjectM36.TupleSet
6-
import ProjectM36.Relation.Show.Term
7-
8-
import Data.Typeable
9-
import GHC.Generics
10-
import System.Random (initStdGen)
1+
module Zoo where
2+
import ProjectM36.Module
3+
import ProjectM36.AccessControlList
114
import Data.Time.Calendar
12-
import qualified Data.Text.IO as T
13-
14-
data Ticket = Ticket
15-
{ ticketId :: Integer
16-
, visitorAge :: Integer -- years
17-
, basePrice :: Integer -- base price before adjustments
18-
, visitDate :: Day
19-
}
20-
deriving (Generic, Tupleable)
21-
22-
23-
main :: IO ()
24-
main = do
25-
let ticketAttrs = toAttributes (Proxy :: Proxy Ticket)
26-
ticket1 = Ticket { ticketId = 1,
27-
visitorAge = 8,
28-
basePrice = 20,
29-
visitDate = fromGregorian 2025 10 01 }
30-
ticket2 = Ticket { ticketId = 2,
31-
visitorAge = 25,
32-
basePrice = 20,
33-
visitDate = fromGregorian 2025 12 25 }
34-
Right ticketTupleSet = mkTupleSet ticketAttrs [toTuple ticket1,
35-
toTuple ticket2]
36-
Right ticketRel = mkRelation ticketAttrs ticketTupleSet
37-
rando <- initStdGen
38-
-- connect to the database
39-
conn <- failFast $ connectProjectM36 (InProcessConnectionInfo NoPersistence emptyNotificationCallback [] basicDatabaseContext rando "admin")
40-
41-
-- create a session on the master branch
42-
sessionId <- failFast $ createSessionAtHead conn "master"
43-
44-
-- add the ticket discount function
45-
let func_type = [integerTypeCons, integerTypeCons, ADTypeConstructor "Either" [ADTypeConstructor "AtomFunctionError" [], integerTypeCons]]
46-
integerTypeCons = PrimitiveTypeConstructor "Integer" IntegerAtomType
47-
func_body = "(\\[IntegerAtom age,IntegerAtom price] -> pure (IntegerAtom (if age < 10 then price `div` 2 else price))) :: [Atom] -> Either AtomFunctionError Atom"
48-
failFast $ executeDatabaseContextIOExpr sessionId conn (AddAtomFunction "apply_discount" func_type func_body)
49-
50-
-- calculate the proper discount per ticket and add it to the database
51-
let discountedTicketRel = Extend (AttributeExtendTupleExpr "discounted_price" func_apply_discount) (ExistingRelation ticketRel)
52-
func_apply_discount = FunctionAtomExpr "apply_discount" [AttributeAtomExpr "visitorAge", AttributeAtomExpr "basePrice"] ()
53-
failFast $ executeDatabaseContextExpr sessionId conn (Assign "ticket_sales" discountedTicketRel)
54-
55-
-- print out the resultant relation
56-
Right ticketSalesRelation <- executeRelationalExpr sessionId conn (RelationVariable "ticket_sales" ())
57-
T.putStrLn $ showRelation ticketSalesRelation
58-
59-
60-
61-
62-
-- | Apply a 50% discount for kids under 10 years old. Arguments: age, base price
63-
applyDiscount :: Integer -> Integer -> Integer
64-
applyDiscount age base_price =
65-
if age < 10 then base_price `div` 2 else base_price
66-
67-
failFast :: Show a => IO (Either a b) -> IO b
68-
failFast m = do
69-
ret <- m
70-
case ret of
71-
Left err -> error (show err)
72-
Right val -> pure val
73-
74-
5+
import ProjectM36.Base
6+
import qualified Data.Map as M
7+
8+
apply_discount :: Integer -> Integer -> Integer
9+
apply_discount age price =
10+
if age <= 10 then
11+
price `div` 2
12+
else
13+
price
14+
15+
addSale :: Integer -> Integer -> Integer -> Day -> DatabaseContextFunctionMonad ()
16+
addSale ticketId age price purchaseDay = do
17+
let tuples = [TupleExpr (M.fromList [("ticketId", i ticketId),
18+
("visitorAge", i age),
19+
("basePrice", FunctionAtomExpr "applyDiscount" [i age, i price] ()),
20+
("visitDate", NakedAtomExpr (DayAtom purchaseDay))])]
21+
i = NakedAtomExpr . IntegerAtom
22+
executeDatabaseContextExpr (Insert "ticket_sales" (MakeRelationFromExprs Nothing (TupleExprs () tuples)))
23+
24+
25+
projectM36Functions :: EntryPoints ()
26+
projectM36Functions = do
27+
declareAtomFunction "apply_discount"
28+
declareDatabaseContextFunction "addSale" (AccessControlList (M.singleton "ticket_seller" (M.singleton ExecuteDBCFunctionPermission False)))

src/lib/ProjectM36/AccessControlList.hs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import Data.Hashable
88
import qualified Data.Map as M
99
import Data.Maybe (fromMaybe)
1010
import Data.Default
11+
import Control.Monad (foldM)
1112

1213
newtype AccessControlList role' permission =
1314
AccessControlList (M.Map role' (RoleAccess permission))
@@ -120,8 +121,11 @@ normalize (AccessControlList acl) = AccessControlList $ M.filter (not . M.null)
120121
empty :: Ord r => AccessControlList r p
121122
empty = AccessControlList mempty
122123

123-
allPermissionsForRoleId :: (AllPermissions p, Ord p) => r -> AccessControlList r p
124-
allPermissionsForRoleId roleid = AccessControlList (M.singleton roleid (M.fromList (map (,True) (S.toList allPermissions))))
124+
allPermissionsForRole :: (AllPermissions p, Ord p) => r -> AccessControlList r p
125+
allPermissionsForRole roleid = AccessControlList (M.singleton roleid (M.fromList (map (,True) (S.toList allPermissions))))
126+
127+
permissionForRole :: Ord p => p -> r -> AccessControlList r p
128+
permissionForRole perm roleid = AccessControlList (M.singleton roleid (M.singleton perm False))
125129

126130
merge :: (Ord p, Eq r, Ord r) => AccessControlList r p -> AccessControlList r p -> AccessControlList r p
127131
merge (AccessControlList acl1) (AccessControlList acl2) =
@@ -191,4 +195,20 @@ data SomePermission = SomeRelVarPermission RelVarPermission |
191195
SomeDBCFunctionPermission DBCFunctionPermission
192196
deriving (Show, NFData, Generic, Eq, Hashable)
193197

198+
-- | Resolve the first parameter in the AccessContrList, typically a role name resolving to a role id.
199+
resolve :: Ord r2 => (r1 -> IO (Maybe r2)) -> AccessControlList r1 a -> IO (Either r1 (AccessControlList r2 a))
200+
resolve resolver (AccessControlList aclIn) = do
201+
let folder (Right acc) (r, v) = do
202+
eR2 <- resolver r
203+
case eR2 of
204+
Nothing -> pure (Left r)
205+
Just r2 -> pure (Right ((r2,v):acc))
206+
folder (Left acc) _ = pure (Left acc)
207+
eres <- foldM folder (Right []) (M.toList aclIn)
208+
case eres of
209+
Left errR -> pure (Left errR)
210+
Right res ->
211+
pure (Right (AccessControlList (M.fromList res)))
212+
213+
194214

src/lib/ProjectM36/Client.hs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ import ProjectM36.Session
186186
import ProjectM36.ValueMarker
187187
import ProjectM36.AccessControl
188188
import ProjectM36.Sessions
189-
import ProjectM36.AccessControlList
189+
import ProjectM36.AccessControlList as ACL
190190
import ProjectM36.HashSecurely (SecureHash)
191191
import ProjectM36.RegisteredQuery
192192
import qualified ProjectM36.Cache.RelationalExprCache as RelExprCache
@@ -769,7 +769,18 @@ executeDatabaseContextIOExpr sessionId (InProcessConnection conf) expr = do
769769
Left err -> pure (Left err)
770770
Right session -> do
771771
graph <- readTVarIO (ipTransactionGraph conf)
772-
let env = RE.DatabaseContextIOEvalEnv transId graph scriptSession myRoleId objFilesPath dbcFuncUtils
772+
let env = RE.DatabaseContextIOEvalEnv transId graph scriptSession myRoleId resolveRoleNameACL objFilesPath dbcFuncUtils
773+
resolveRoleNameACL :: forall a. AccessControlList RoleName a -> IO (Either RelationalError (AccessControlList RoleId a))
774+
resolveRoleNameACL acl' = do
775+
eRes <- ACL.resolve (\r -> do
776+
eRid <- LoginRoles.roleIdForRoleName r (ipLoginRoles conf)
777+
pure $ case eRid of
778+
Left _err -> Nothing
779+
Right rid -> Just rid
780+
) acl'
781+
case eRes of
782+
Left badR -> pure (Left (NoSuchRoleNameError badR))
783+
Right res -> pure (Right res)
773784
dbcEnv = RE.mkDatabaseContextEvalEnv transId graph dbcFuncUtils
774785
roleNameResolver nam = fst <$> lookup nam roles
775786
dbcFuncUtils = DBC.DatabaseContextFunctionUtils {

src/lib/ProjectM36/Module.hs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ import Control.Monad.RWS.Strict (RWST, get, put, ask, runRWST)
99
import Control.Monad.Except (ExceptT, throwError, runExceptT)
1010
import Data.Functor.Identity
1111

12+
-- | Variant of ACLs which use roles to be later resolved to role ids used by user-facing ProjectM36.Module.
13+
type DBCFunctionRoleNameAccessControlList = AccessControlList RoleName DBCFunctionPermission
14+
1215
declareAtomFunction :: FunctionName -> EntryPoints ()
1316
declareAtomFunction nam = tell [DeclareAtomFunction nam]
1417

15-
declareDatabaseContextFunction :: FunctionName -> DBCFunctionAccessControlList -> EntryPoints ()
18+
declareDatabaseContextFunction :: FunctionName -> DBCFunctionRoleNameAccessControlList -> EntryPoints ()
1619
declareDatabaseContextFunction nam acl' = tell [DeclareDatabaseContextFunction nam acl']
1720

1821
type EntryPoints = Writer [DeclareFunction]
@@ -21,7 +24,7 @@ runEntryPoints :: EntryPoints () -> [DeclareFunction]
2124
runEntryPoints = execWriter
2225

2326
data DeclareFunctionBase a = DeclareAtomFunction a |
24-
DeclareDatabaseContextFunction a DBCFunctionAccessControlList
27+
DeclareDatabaseContextFunction a DBCFunctionRoleNameAccessControlList
2528
deriving (Show)
2629

2730
type DeclareFunction = DeclareFunctionBase FunctionName

src/lib/ProjectM36/RelationalExpression.hs

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
{-# LANGUAGE FlexibleContexts #-}
44
{-# LANGUAGE OverloadedStrings #-}
55
{-# LANGUAGE MultiParamTypeClasses #-}
6+
{-# LANGUAGE RankNTypes #-}
67
module ProjectM36.RelationalExpression where
78
import ProjectM36.Relation
89
import ProjectM36.Tuple
@@ -51,7 +52,7 @@ import Control.Monad.Trans.Except (except)
5152
import ProjectM36.NormalizeExpr
5253
import ProjectM36.WithNameExpr
5354
import ProjectM36.Function
54-
import ProjectM36.AccessControlList as ACL
55+
import ProjectM36.AccessControlList (RoleId, AccessControlList, SomePermission(..), relvarsACL, dbcFunctionsACL, schemaACL, transGraphACL, aclACL, allPermissionsForRole, addAccess, removeAccess)
5556
import Test.QuickCheck
5657
import Data.Functor (void)
5758
import qualified Data.Functor.Foldable as Fold
@@ -65,8 +66,10 @@ import ProjectM36.Module
6566
import GHC hiding (getContext)
6667
import Control.Exception
6768
import GHC.Paths
69+
--import System.FilePath
6870
--import GHC.Unit.State (emptyUnitState)
6971
--import GHC.Driver.Ppr (showSDocForUser)
72+
--import GHC.Utils.Outputable (ppr)
7073
--import GHC.Unit.Finder.Types (FindResult(..))
7174
--import GHC.Unit.Finder (findImportedModule)
7275
import GHC.Types.Name.Occurrence (mkVarOcc, mkTcOcc)
@@ -77,7 +80,7 @@ import GHC.Builtin.Types (unitTy)
7780
import GHC.Core.TyCo.Compare (eqType)
7881
import Unsafe.Coerce
7982
import Control.Monad (forM)
80-
--import GHC.Utils.Outputable (ppr)
83+
8184
#endif
8285

8386
data DatabaseContextExprDetails = CountUpdatedTuples
@@ -557,6 +560,7 @@ data DatabaseContextIOEvalEnv = DatabaseContextIOEvalEnv
557560
dbcio_graph :: TransactionGraph,
558561
dbcio_mScriptSession :: Maybe ScriptSession,
559562
dbcio_roleId :: RoleId,
563+
resolveRoleNameACL :: forall x. AccessControlList RoleName x -> IO (Either RelationalError (AccessControlList RoleId x)),
560564
dbcio_mModulesDirectory :: Maybe FilePath, -- ^ when running in persistent mode, this must be a Just value to a directory containing .o/.so/.dynlib files which the user has placed there for access to compiled functions
561565
dbcio_dbcfunctionUtils :: DatabaseContextFunctionUtils
562566
}
@@ -660,7 +664,7 @@ evalGraphRefDatabaseContextIOExpr (AddDatabaseContextFunction funcName' funcType
660664
funcName = funcName',
661665
funcType = funcAtomType,
662666
funcBody = FunctionScriptBody script compiledFunc,
663-
funcACL = allPermissionsForRoleId myRoleId
667+
funcACL = allPermissionsForRole myRoleId
664668
}
665669
-- check if the name is already in use
666670
if HS.member funcName' (HS.map funcName dbcFuncs) then
@@ -1888,14 +1892,15 @@ importModuleFromPath :: ScriptSession -> ModuleBody -> DatabaseContextIOEvalMona
18881892
importModuleFromPath _scriptSession _moduleSource = throwError (ScriptError ScriptCompilationDisabledError)
18891893
#else
18901894
importModuleFromPath scriptSession moduleSource = do
1895+
resolveRoleNameACLF <- resolveRoleNameACL <$> ask
18911896
res <- liftIO $ try $ do
18921897
withSystemTempFile "pm36module" $ \tempModulePath tempModuleHandle -> do
18931898
hClose tempModuleHandle
18941899
TIO.writeFile tempModulePath moduleSource
18951900
runGhc (Just libdir) $ do
18961901
-- GHC needs to see the module on disk, so we write it to a temporary location
18971902
setSession (hscEnv scriptSession)
1898-
dflags <- getSessionDynFlags
1903+
dflags <- getSessionDynFlags
18991904
let target = Target {
19001905
targetId = TargetFile tempModulePath Nothing,
19011906
targetAllowObjCode = False,
@@ -1907,7 +1912,7 @@ importModuleFromPath scriptSession moduleSource = do
19071912
case loadSuccess of
19081913
Failed -> pure (Left (ScriptError ModuleLoadError))
19091914
Succeeded -> do
1910-
{-modRes <- liftIO $ findImportedModule (hscEnv scriptSession) (mkModuleName "ProjectM36.Base") NoPkgQual
1915+
{- modRes <- liftIO $ findImportedModule (hscEnv scriptSession) (mkModuleName "ProjectM36.Base") NoPkgQual
19111916
liftIO $ case modRes of
19121917
Found modLoc _mod -> do
19131918
let packageLoc = takeDirectory (takeDirectory (takeDirectory (ml_dyn_obj_file modLoc)))
@@ -1958,21 +1963,26 @@ importModuleFromPath scriptSession moduleSource = do
19581963
tyConv <- mkTypeConversions
19591964
mkFunctions <- forM funcDeclarations $ \funcDecl -> do
19601965
case funcDecl of
1961-
DeclareDatabaseContextFunction funcS acl' -> do
1962-
fType <- exprType TM_Default (T.unpack funcS)
1963-
dbcFuncMonadType <- exprType TM_Default "undefined :: DatabaseContextFunctionMonad ()"
1964-
-- extract arguments for dbc function
1965-
let eAtomFuncType = convertGhcTypeToDatabaseContextFunctionAtomType dflags tyConv dbcFuncMonadType fType
1966-
case eAtomFuncType of
1967-
Left err -> throw (OtherScriptCompilationError (show err))
1968-
Right dbcFuncType -> do
1969-
let interpretedFunc = wrapDatabaseContextFunction dbcFuncType funcS
1970-
dbcFunc :: DatabaseContextFunctionBodyType <- unsafeCoerce <$> compileExpr interpretedFunc
1971-
let newDBCFunc = Function { funcName = funcS,
1972-
funcType = dbcFuncType,
1973-
funcBody = FunctionScriptBody (T.pack interpretedFunc) dbcFunc,
1974-
funcACL = acl' }
1975-
pure (MkDatabaseContextFunction newDBCFunc)
1966+
DeclareDatabaseContextFunction funcS roleNameACL -> do
1967+
--resolve role name ACL into role-id-based ACL
1968+
eACL <- liftIO $ resolveRoleNameACLF roleNameACL
1969+
case eACL of
1970+
Left err -> throw err
1971+
Right acl' -> do
1972+
fType <- exprType TM_Default (T.unpack funcS)
1973+
dbcFuncMonadType <- exprType TM_Default "undefined :: DatabaseContextFunctionMonad ()"
1974+
-- extract arguments for dbc function
1975+
let eAtomFuncType = convertGhcTypeToDatabaseContextFunctionAtomType dflags tyConv dbcFuncMonadType fType
1976+
case eAtomFuncType of
1977+
Left err -> throw (OtherScriptCompilationError (show err))
1978+
Right dbcFuncType -> do
1979+
let interpretedFunc = wrapDatabaseContextFunction dbcFuncType funcS
1980+
dbcFunc :: DatabaseContextFunctionBodyType <- unsafeCoerce <$> compileExpr interpretedFunc
1981+
let newDBCFunc = Function { funcName = funcS,
1982+
funcType = dbcFuncType,
1983+
funcBody = FunctionScriptBody (T.pack interpretedFunc) dbcFunc,
1984+
funcACL = acl' }
1985+
pure (MkDatabaseContextFunction newDBCFunc)
19761986
DeclareAtomFunction funcS -> do
19771987
--extract type from function in script
19781988
fType <- exprType TM_Default (T.unpack funcS)

src/lib/ProjectM36/ScriptSession.hs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ import GHC.LanguageExtensions (Extension(OverloadedStrings,ExtendedDefaultRules,
2828
-- GHC 9.4+
2929
import Data.List.NonEmpty(NonEmpty(..))
3030
import GHC.Utils.Panic (handleGhcException)
31-
import GHC.Driver.Session (projectVersion, PackageDBFlag(PackageDB), PkgDbRef(PkgDbPath), TrustFlag(TrustPackage), gopt_set, xopt_set, PackageFlag(ExposePackage), PackageArg(PackageArg), ModRenaming(ModRenaming))
31+
import GHC.Driver.Session (projectVersion, PackageDBFlag(PackageDB), PkgDbRef(PkgDbPath), TrustFlag(TrustPackage), gopt_set, xopt_set, PackageFlag(ExposePackage), PackageArg(PackageArg), ModRenaming(ModRenaming), runCmdLineP, setFlagsFromEnvFile)
32+
import GHC.Driver.CmdLine (runEwM)
3233
import GHC.Types.SourceText (SourceText(NoSourceText))
3334
import GHC.Driver.Ppr (showSDocForUser)
34-
import GHC.Utils.Outputable (ppr, Outputable)
35+
import GHC.Utils.Outputable (Outputable, ppr)
3536
--import GHC.Unit.Module.Graph (showModMsg, mgModSummaries')
3637
import GHC.Core.TyCo.Ppr (pprType)
3738
import GHC.Utils.Encoding (zEncodeString)
@@ -148,8 +149,10 @@ initScriptSession ghcPkgPaths = do
148149
"project-m36",
149150
"bytestring"]
150151
packages = map (\m -> ExposePackage ("-package " ++ m) (PackageArg m) (ModRenaming True [])) required_packages
151-
--liftIO $ traceShowM (showSDoc dflags' (ppr packages))
152-
_ <- setSessionDynFlags dflags'
152+
--liftIO $ traceShowM (showSDoc dflags' (ppr packages))
153+
-- load GHC environment file, if specific in environment variables
154+
dflags'' <- loadGhcEnvFile dflags'
155+
_ <- setSessionDynFlags dflags''
153156
let unqualifiedModules = map (\modn -> IIDecl $ safeImportDecl modn Nothing) [
154157
"Prelude",
155158
"Data.Map",
@@ -429,4 +432,15 @@ convertGhcTypeToDatabaseContextFunctionAtomType dflags tyConv dbcFuncMonadType t
429432
rest <- convertGhcTypeToDatabaseContextFunctionAtomType dflags tyConv dbcFuncMonadType (ft_res fun)
430433
pure (arg <> rest)
431434
other -> Left (UnsupportedTypeConversionError (pprShow other dflags))
435+
436+
loadGhcEnvFile :: GhcMonad m => DynFlags -> m DynFlags
437+
loadGhcEnvFile dflags = do
438+
mEnvPath <- liftIO $ lookupEnv "GHC_ENVIRONMENT"
439+
-- traceShowM ("GHC_ENVIRONMENT"::String, mEnvPath)
440+
case mEnvPath of
441+
Nothing -> pure dflags
442+
Just envPath -> do
443+
content <- liftIO $ readFile envPath
444+
let (_, dflags') = runCmdLineP (runEwM (setFlagsFromEnvFile envPath content)) dflags
445+
pure dflags'
432446
#endif

test/TutorialD/Interpreter/TestModule.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ projectM36Functions :: EntryPoints ()
1616
projectM36Functions = do
1717
declareAtomFunction "multiTypesAtomFunc"
1818
declareAtomFunction "applyDiscount"
19-
declareDatabaseContextFunction "addSale" (allPermissionsForRoleId adminRoleId)
20-
declareDatabaseContextFunction "multiTypesDBCFunc" (allPermissionsForRoleId adminRoleId)
19+
declareDatabaseContextFunction "addSale" (allPermissionsForRole "admin")
20+
declareDatabaseContextFunction "multiTypesDBCFunc" (allPermissionsForRole "admin")
2121

2222
applyDiscount :: Integer -> Integer -> Integer
2323
applyDiscount age price =

0 commit comments

Comments
 (0)