Skip to content

Infrastructure: Change db from mariadb to postgres #711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ctfd/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ schema==0.7.5
bleach==6.1.0
ipython==8.16.1
flask-shell-ipython==0.5.1
psycopg2-binary==2.9.10

# CTFd
Flask==2.2.5
Expand Down
3 changes: 0 additions & 3 deletions db/Dockerfile

This file was deleted.

1 change: 1 addition & 0 deletions db/init/00_pgcrypto.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE EXTENSION IF NOT EXISTS pgcrypto;
8 changes: 0 additions & 8 deletions db/start.sh

This file was deleted.

24 changes: 10 additions & 14 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ services:
hard: 1048576
environment:
- UPLOAD_FOLDER=/var/uploads
- DATABASE_URL=mysql+pymysql://${DB_USER}:${DB_PASS}@${DB_HOST}/${DB_NAME}
- DATABASE_URL=postgresql+psycopg2://${DB_USER}:${DB_PASS}@${DB_HOST}/${DB_NAME}
- REDIS_URL=redis://cache:6379
- WORKERS=8
- LOG_FOLDER=/var/log/CTFd
Expand Down Expand Up @@ -140,19 +140,18 @@ services:
container_name: db
profiles:
- main
build: ./db
image: postgres:17.5
restart: always
environment:
- MYSQL_ROOT_PASSWORD=${DB_PASS}
- MYSQL_USER=${DB_USER}
- MYSQL_PASSWORD=${DB_PASS}
- MYSQL_DATABASE=${DB_NAME}
- DB_EXTERNAL=${DB_EXTERNAL}
- POSTGRES_USER=${DB_USER}
- POSTGRES_PASSWORD=${DB_PASS}
- POSTGRES_DB=${DB_NAME}
- PGUSER=${DB_USER}
volumes:
- /data/mysql:/var/lib/mysql
command: [/start.sh]
- /data/postgres:/var/lib/postgresql/data
- ./db/init:/docker-entrypoint-initdb.d:ro
healthcheck:
test: ["CMD", "mysqladmin", "ping", "-p${DB_PASS}", "-u${DB_USER}", "-h${DB_HOST}"]
test: ["CMD", "pg_isready"]
interval: 10s
timeout: 10s
retries: 3
Expand All @@ -177,10 +176,7 @@ services:
- /var/run/docker.sock:/var/run/docker.sock:ro
- /data/mac:/var/data/mac:ro
environment:
- DB_HOST=${DB_HOST}
- DB_NAME=${DB_NAME}
- DB_USER=${DB_USER}
- DB_PASS=${DB_PASS}
- DATABASE_URL=postgresql+psycopg2://${DB_USER}:${DB_PASS}@${DB_HOST}/${DB_NAME}
- REDIS_URL=redis://cache:6379
- MAC_HOSTNAME=${MAC_HOSTNAME}
- MAC_USERNAME=${MAC_USERNAME}
Expand Down
2 changes: 0 additions & 2 deletions docs/architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,6 @@ The DOJO uses mysql.

The DOJO database lives in the `db` container by default.
You can use an external database by setting `DB_HOST` in `config.env`.
To save resources, also set `DB_EXTERNAL` to `yes` so that the `db` container does not actually start mysql.

You can launch a database client session with `dojo db`.

## CTFd and the dojo-plugin
Expand Down
31 changes: 15 additions & 16 deletions dojo/dojo
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,7 @@ case "$ACTION" in
shift
fi
DOJO_UID="$1"
[ -n "${DOJO_UID//[0-9]}" ] && DOJO_UID=$(
echo "select id from users where name='$DOJO_UID'" |
$0 db -s
)
[ -n "${DOJO_UID//[0-9]}" ] && DOJO_UID=$(echo "select id from users where name='$DOJO_UID'" | dojo db -qAt)
CONTAINER="user_$DOJO_UID"
shift

Expand Down Expand Up @@ -84,13 +81,25 @@ case "$ACTION" in

# HELP: db: launch a mysql client session, connected to the ctfd db
"db")
docker exec $DOCKER_ARGS db mysql -h ${DB_HOST} -p${DB_PASS} -D${DB_NAME} -u${DB_USER} "$@"
docker exec $DOCKER_ARGS db psql "$@"
;;

# HELP: backup: does a dojo db backup into the `/data/backups` directory.
"backup")
mkdir -p /data/backups
docker exec db mysqldump -h ${DB_HOST} -p${DB_PASS} -u${DB_USER} --single-transaction --routines --triggers ${DB_NAME} | gzip > "/data/backups/db-$(date -Iseconds).sql.gz"
BACKUP_PATH="/data/backups/db-$(date -Iseconds).dump"
docker exec db pg_dump -Fc > "$BACKUP_PATH"
echo "Created backup at $BACKUP_PATH"
;;

# HELP: restore PATH: restores a dojo db backup. Path arg is relative to the `/data/backups` directory
"restore")
BACKUP_PATH="/data/backups/$1"
if [ -f "$BACKUP_PATH" ]; then
docker exec -i db pg_restore --clean --if-exists --dbname="$DB_NAME" < "$BACKUP_PATH"
else
echo "Error: missing file to restore from" >&2
fi
;;

# HELP: cloud-backup: upload the last day's worth of cloud backups to S3, but encrypt it at rest
Expand All @@ -112,16 +121,6 @@ case "$ACTION" in
done
;;

# HELP: restore PATH: restores a dojo db backup. Path arg is relative to the `/data/backups` directory
"restore")
BACKUP_PATH="/data/backups/$1"
if [ -f "$BACKUP_PATH" ]; then
gunzip < "$BACKUP_PATH" | docker exec -i db mysql -h ${DB_HOST} -p${DB_PASS} -u${DB_USER} -D${DB_NAME}
else
echo "Error: missing file to restore from" >&2
fi
;;

# HELP: vscode: start vscode tunnel
"vscode")
dojo-vscode "$@"
Expand Down
1 change: 0 additions & 1 deletion dojo/dojo-init
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ define DB_HOST db
define DB_NAME ctfd
define DB_USER ctfd
define DB_PASS ctfd
define DB_EXTERNAL no # change to anything but no and the db container will not start mysql
define BACKUP_AES_KEY_FILE
define S3_BACKUP_BUCKET
define AWS_DEFAULT_REGION
Expand Down
2 changes: 0 additions & 2 deletions dojo_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from .pages.discord import discord
from .pages.course import course
from .pages.canvas import sync_canvas_user, canvas
from .pages.writeups import writeups
from .pages.belts import belts
from .pages.research import research
from .pages.index import static_html_override
Expand Down Expand Up @@ -146,7 +145,6 @@ def load(app):
app.register_blueprint(users)
app.register_blueprint(course)
app.register_blueprint(canvas)
app.register_blueprint(writeups)
app.register_blueprint(belts)
app.register_blueprint(research)
app.register_blueprint(api, url_prefix="/pwncollege_api/v1")
Expand Down
1 change: 1 addition & 0 deletions dojo_plugin/api/v1/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def get(self):
.join(Dojos)
.filter(Dojos.official, DojoChallenges.visible())
.distinct()
.with_entities(Challenges.id)
)
rank = db.func.row_number().over(
order_by=(db.func.count(Solves.id).desc(), db.func.max(Solves.id))
Expand Down
12 changes: 7 additions & 5 deletions dojo_plugin/api/v1/scoreboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ def get_scoreboard_for(model, duration):
.over(order_by=(solves.desc(), db.func.max(Solves.id)))
.label("rank")
)
user_entities = [Solves.user_id, Users.name, Users.email]
query = (
model.solves()
.filter(duration_filter)
.group_by(Solves.user_id)
.group_by(*user_entities)
.order_by(rank)
.with_entities(rank, solves, Solves.user_id, Users.name, Users.email)
.with_entities(rank, solves, *user_entities)
)

row_results = query.all()
Expand All @@ -63,12 +64,14 @@ def invalidate_scoreboard_cache():

# handle cache invalidation for new solves, dojo creation, dojo challenge creation
@event.listens_for(Dojos, 'after_insert', propagate=True)
@event.listens_for(Dojos, 'after_delete', propagate=True)
@event.listens_for(Solves, 'after_insert', propagate=True)
@event.listens_for(Solves, 'after_delete', propagate=True)
@event.listens_for(Awards, 'after_insert', propagate=True)
@event.listens_for(Belts, 'after_insert', propagate=True)
@event.listens_for(Emojis, 'after_insert', propagate=True)
@event.listens_for(Awards, 'after_delete', propagate=True)
@event.listens_for(Belts, 'after_insert', propagate=True)
@event.listens_for(Belts, 'after_delete', propagate=True)
@event.listens_for(Emojis, 'after_insert', propagate=True)
@event.listens_for(Emojis, 'after_delete', propagate=True)
def hook_object_creation(mapper, connection, target):
invalidate_scoreboard_cache()
Expand All @@ -84,7 +87,6 @@ def hook_object_creation(mapper, connection, target):
@event.listens_for(DojoChallengeVisibilities, 'after_update', propagate=True)
@event.listens_for(Belts, 'after_update', propagate=True)
@event.listens_for(Emojis, 'after_update', propagate=True)
@event.listens_for(Awards, 'after_insert', propagate=True)
def hook_object_update(mapper, connection, target):
# according to the docs, this is a necessary check to see if the
# target actually was modified (and thus an update was made)
Expand Down
80 changes: 47 additions & 33 deletions dojo_plugin/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import yaml
from flask import current_app
from sqlalchemy import String, DateTime, case, cast, Numeric
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import synonym
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy.orm.session import object_session
Expand Down Expand Up @@ -70,7 +71,7 @@ class Dojos(db.Model):
official = db.Column(db.Boolean, index=True)
password = db.Column(db.String(128))

data = db.Column(db.JSON)
data = db.Column(JSONB)
data_fields = ["type", "award", "course", "pages", "privileged", "importable", "comparator"]
data_defaults = {
"pages": [],
Expand Down Expand Up @@ -217,7 +218,7 @@ def ordering(cls):
return (
~cls.official,
cls.data["type"],
cast(case([(cls.data["comparator"] == None, 1000)], else_=cls.data["comparator"]), Numeric()),
db.func.coalesce(cast(cls.data["comparator"].astext, Numeric()), 1000),
cls.name,
)

Expand All @@ -226,7 +227,7 @@ def viewable(cls, id=None, user=None):
return (
(cls.from_id(id) if id is not None else cls.query)
.filter(or_(cls.official,
and_(cls.data["type"] == "public", cls.password == None),
and_(cls.data["type"].astext == "public", cls.password == None),
cls.dojo_id.in_(db.session.query(DojoUsers.dojo_id)
.filter_by(user=user)
.subquery())))
Expand All @@ -237,15 +238,22 @@ def solves(self, **kwargs):
return DojoChallenges.solves(dojo=self, **kwargs)

def completions(self):
"""
Returns a list of (User, completion_timestamp) tuples for users, sorted by time in ascending order.
"""
sq = Solves.query.join(DojoChallenges, Solves.challenge_id == DojoChallenges.challenge_id).add_columns(
Solves.user_id.label("solve_user_id"), db.func.count().label("solve_count"), db.func.max(Solves.date).label("last_solve")
).filter(DojoChallenges.dojo == self).group_by(Solves.user_id).subquery()
return Users.query.join(sq).filter_by(
solve_count=len(self.challenges)
).add_column(sq.columns.last_solve).order_by(sq.columns.last_solve).all()
solves_subquery = (
self.solves(ignore_visibility=True, ignore_admins=False)
.with_entities(Solves.user_id,
db.func.count().label("solve_count"),
db.func.max(Solves.date).label("last_solve"))
.group_by(Solves.user_id)
.having(db.func.count() == len(self.challenges))
.subquery()
)
return (
Users.query
.join(solves_subquery, Users.id == solves_subquery.c.user_id)
.add_columns(solves_subquery.c.last_solve)
.order_by(solves_subquery.c.last_solve)
.all()
)

def awards(self):
if not self.award:
Expand Down Expand Up @@ -286,7 +294,7 @@ class DojoUsers(db.Model):
dojo = db.relationship("Dojos", back_populates="users", overlaps="admins,members,students")
user = db.relationship("Users")

survey_responses = db.relationship("SurveyResponses", back_populates="users", overlaps="admins,members,students")
# survey_responses = db.relationship("SurveyResponses", back_populates="users", overlaps="admins,members,students")

def solves(self, **kwargs):
return DojoChallenges.solves(user=self.user, dojo=self.dojo, **kwargs)
Expand Down Expand Up @@ -331,7 +339,7 @@ class DojoModules(db.Model):
name = db.Column(db.String(128))
description = db.Column(db.Text)

data = db.Column(db.JSON)
data = db.Column(JSONB)
data_fields = ["importable"]
data_defaults = {
"importable": True
Expand Down Expand Up @@ -466,7 +474,7 @@ class DojoChallenges(db.Model):
name = db.Column(db.String(128))
description = db.Column(db.Text)

data = db.Column(db.JSON)
data = db.Column(JSONB)
data_fields = ["image", "path_override", "importable", "allow_privileged", "progression_locked", "survey"]
data_defaults = {
"importable": True,
Expand All @@ -485,7 +493,7 @@ class DojoChallenges(db.Model):
cascade="all, delete-orphan",
back_populates="challenge")

survey_responses = db.relationship("SurveyResponses", back_populates="challenge", cascade="all, delete-orphan")
# survey_responses = db.relationship("SurveyResponses", back_populates="challenge", cascade="all, delete-orphan")

def __init__(self, *args, **kwargs):
default = kwargs.pop("default", None)
Expand Down Expand Up @@ -550,7 +558,7 @@ def solves(self, *, user=None, dojo=None, module=None, ignore_visibility=False,
))
.join(Dojos, and_(
Dojos.dojo_id == DojoChallenges.dojo_id,
or_(Dojos.official, Dojos.data["type"] == "public", DojoUsers.user_id != None),
or_(Dojos.official, Dojos.data["type"].astext == "public", DojoUsers.user_id != None),
))
.join(Users, Users.id == Solves.user_id)
)
Expand Down Expand Up @@ -609,19 +617,19 @@ def resolve(self):

class SurveyResponses(db.Model):
__tablename__ = "survey_responses"

id = db.Column(db.Integer, primary_key=True, autoincrement=True)
dojo_id = db.Column(db.Integer, db.ForeignKey("dojo_challenges.dojo_id", ondelete="CASCADE"), nullable=False)
challenge_id = db.Column(db.Integer, db.ForeignKey("challenges.id", ondelete="CASCADE"), index=True, nullable=False)
user_id = db.Column(db.Integer, db.ForeignKey("dojo_users.user_id", ondelete="CASCADE"), nullable=False)
type = db.Column(db.String(64), nullable=False)
prompt = db.Column(db.Text, nullable=False)
response = db.Column(db.Text, nullable=False)
timestamp = db.Column(db.DateTime, default=datetime.datetime.utcnow, nullable=False)
dojo_id = db.Column(db.Integer, db.ForeignKey("dojos.dojo_id", ondelete="CASCADE"))
challenge_id = db.Column(db.Integer, db.ForeignKey("challenges.id", ondelete="CASCADE"))
user_id = db.Column(db.Integer, db.ForeignKey("users.id", ondelete="CASCADE"))

type = db.Column(db.String(64))
prompt = db.Column(db.Text)
response = db.Column(db.Text)
timestamp = db.Column(db.DateTime, default=datetime.datetime.utcnow)

challenge = db.relationship("DojoChallenges", back_populates="survey_responses")
users = db.relationship("DojoUsers", back_populates="survey_responses")
# challenge = db.relationship("DojoChallenges", back_populates="survey_responses")
# users = db.relationship("DojoUsers", back_populates="survey_responses")


class DojoResources(db.Model):
Expand All @@ -640,7 +648,7 @@ class DojoResources(db.Model):
type = db.Column(db.String(80), index=True)
name = db.Column(db.String(128))

data = db.Column(db.JSON)
data = db.Column(JSONB)
data_fields = ["content", "video", "playlist", "slides"]

dojo = db.relationship("Dojos", back_populates="resources", viewonly=True)
Expand Down Expand Up @@ -760,10 +768,16 @@ class DojoModuleVisibilities(db.Model):

class SSHKeys(db.Model):
__tablename__ = "ssh_keys"
user_id = db.Column(
db.Integer, db.ForeignKey("users.id", ondelete="CASCADE"), primary_key=True

id = db.Column(db.Integer, primary_key=True, autoincrement=True)
user_id = db.Column(db.Integer, db.ForeignKey("users.id", ondelete="CASCADE"), index=True)
value = db.Column(db.Text)

__table_args__ = (
db.Index("uq_ssh_keys_digest",
db.func.digest(value, "sha256"),
unique=True),
)
value = db.Column(db.String(750), primary_key=True, unique=True)

user = db.relationship("Users")

Expand All @@ -788,7 +802,7 @@ class DiscordUsers(db.Model):
user_id = db.Column(
db.Integer, db.ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
)
discord_id = db.Column(db.Integer, unique=True)
discord_id = db.Column(db.BigInteger, unique=True)

user = db.relationship("Users")

Expand Down
1 change: 0 additions & 1 deletion dojo_plugin/pages/course.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from ..utils import is_dojo_admin
from ..utils.dojo import dojo_route
from ..utils.discord import add_role, get_discord_member
from .writeups import WriteupComments, writeup_weeks, all_writeups

course = Blueprint("course", __name__)

Expand Down
Loading