Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 70 additions & 23 deletions src/country_workspace/management/commands/upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any

from typing import TYPE_CHECKING, Any, Final
from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.exceptions import ValidationError
from django.core.management import BaseCommand, call_command
from django.core.validators import validate_email
Expand All @@ -20,6 +20,9 @@
logger = logging.getLogger(__name__)


FALLBACK_EMAIL_DOMAIN: Final[str] = "example.org"


class Command(BaseCommand):
requires_migrations_checks = False
requires_system_checks = []
Expand Down Expand Up @@ -83,6 +86,13 @@ def add_arguments(self, parser: "ArgumentParser") -> None:
default="",
help="Admin password",
)
parser.add_argument(
"--superusers",
nargs="+",
dest="superusers",
default=None,
help="Emails/usernames to grant superuser privileges (space-separated)",
)

def get_options(self, options: dict[str, Any]) -> None:
self.verbosity = options["verbosity"]
Expand All @@ -95,6 +105,7 @@ def get_options(self, options: dict[str, Any]) -> None:

self.admin_email = str(options["admin_email"] or env("ADMIN_EMAIL", ""))
self.admin_password = str(options["admin_password"] or env("ADMIN_PASSWORD", ""))
self.superusers = options["superusers"] if options["superusers"] is not None else env("SUPERUSERS", [])

def halt(self, e: Exception) -> None: # pragma: no cover
self.stdout.write(str(e), style_func=self.style.ERROR)
Expand All @@ -106,8 +117,62 @@ def halt(self, e: Exception) -> None: # pragma: no cover

sys.exit(1)

def handle(self, *args: Any, **options: Any) -> None: # noqa: C901
from country_workspace.models import Office, User
def _ensure_superuser(self, user: Any) -> bool:
if user.is_staff and user.is_superuser:
return False
user.is_staff = True
user.is_superuser = True
user.save(update_fields=["is_staff", "is_superuser"])
return True

def _run_createsuperuser(self, username: str, email: str) -> bool:
os.environ["DJANGO_SUPERUSER_USERNAME"] = username
os.environ["DJANGO_SUPERUSER_EMAIL"] = email

if password := self.admin_password if username == self.admin_email else "":
os.environ["DJANGO_SUPERUSER_PASSWORD"] = password
else:
os.environ.pop("DJANGO_SUPERUSER_PASSWORD", None)

call_command(
"createsuperuser",
email=email,
username=username,
verbosity=max(self.verbosity - 1, 0),
interactive=False,
)
return bool(password)

def _superuser_logins(self) -> list[str]:
raw = [self.admin_email, *self.superusers]
return list(dict.fromkeys(s for x in raw if x and (s := x.strip())))

def _create_superusers(self, echo: Any) -> None:
users = get_user_model().objects

for login in self._superuser_logins():
email = login if "@" in login else f"{login}@{FALLBACK_EMAIL_DOMAIN}"

if user := (
(users.filter(email=email).first() if "@" in login else None) or users.filter(username=login).first()
):
changed = self._ensure_superuser(user)
echo(
f"{'Granted superuser privileges' if changed else 'User found, skip'}: {login}",
style_func=self.style.WARNING,
)
continue

validate_email(email)
password_provided = self._run_createsuperuser(login, email)

echo(
f"Created superuser: {email}{'' if password_provided else ' with unusable password'}",
style_func=self.style.WARNING,
)

def handle(self, *args: Any, **options: Any) -> None:
from country_workspace.models import Office

self.get_options(options)
if self.verbosity >= 1:
Expand Down Expand Up @@ -143,25 +208,7 @@ def handle(self, *args: Any, **options: Any) -> None: # noqa: C901
echo("Run HOPE synchronisation")
call_command("sync", **extra)

if self.admin_email:
if User.objects.filter(email=self.admin_email).exists():
echo(
f"User {self.admin_email} found, skip creation",
style_func=self.style.WARNING,
)
else:
echo("Creating superuser")
validate_email(self.admin_email)
os.environ["DJANGO_SUPERUSER_USERNAME"] = self.admin_email
os.environ["DJANGO_SUPERUSER_EMAIL"] = self.admin_email
os.environ["DJANGO_SUPERUSER_PASSWORD"] = self.admin_password
call_command(
"createsuperuser",
email=self.admin_email,
username=self.admin_email,
verbosity=self.verbosity - 1,
interactive=False,
)
self._create_superusers(echo)

echo("Setup base security")
setup_workspace_group()
Expand Down
65 changes: 59 additions & 6 deletions tests/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from country_workspace.management.commands.sync import Command as SyncCommand, run_program_sync, run_geo_sync
import country_workspace.management.commands.gen_rdi as gen_rdi_cmd
from country_workspace.utils.gen_rdi import GenerationMode, GeneratorConfig
import country_workspace.management.commands.upgrade as upgrade_cmd

from testutils.factories import SuperUserFactory, UserFactory


if TYPE_CHECKING:
Expand Down Expand Up @@ -87,8 +90,6 @@ def test_upgrade_init(
@pytest.mark.parametrize("migrate", [1, 0], ids=["migrate", ""])
@override_config(HOPE_API_URL="https://dev-hope.unitst.org/api/rest/")
def test_upgrade(verbosity: int, migrate: int, mocker: MockerFixture, environment: dict[str, str]) -> None:
from testutils.factories import SuperUserFactory

out = StringIO()
SuperUserFactory()
mocker.patch.dict(os.environ, environment, clear=True)
Expand All @@ -98,8 +99,6 @@ def test_upgrade(verbosity: int, migrate: int, mocker: MockerFixture, environmen

@override_config(HOPE_API_URL="https://dev-hope.unitst.org/api/rest/")
def test_upgrade_next(mocked_responses: RequestsMock) -> None:
from testutils.factories import SuperUserFactory

SuperUserFactory()
out = StringIO()
call_command("upgrade", stdout=out, checks=False, sync_with_hope=False)
Expand All @@ -120,8 +119,6 @@ def test_upgrade_check(
def test_upgrade_admin(
mocker: MockerFixture, mocked_responses: RequestsMock, environment: dict[str, str], admin: str
) -> None:
from testutils.factories import SuperUserFactory

if admin:
email = SuperUserFactory().email
else:
Expand Down Expand Up @@ -287,3 +284,59 @@ def test_gen_rdi_validation_errors(cli_args: list[str], err: str) -> None:

with pytest.raises(CommandError, match=re.escape(err)):
call_command("gen_rdi", *cli_args)


@pytest.mark.parametrize(
("factory_key", "kwargs", "expect_changed"),
[
("user", {"is_staff": False, "is_superuser": False}, True),
("superuser", {}, False),
],
ids=["promotes", "noop"],
)
def test_upgrade_ensure_superuser(
mocker: MockerFixture,
factory_key: str,
kwargs: dict[str, object],
expect_changed: bool,
) -> None:
factory = {"user": UserFactory, "superuser": SuperUserFactory}[factory_key]
user = factory(**kwargs)
save_spy = mocker.spy(user, "save")

assert upgrade_cmd.Command()._ensure_superuser(user) is expect_changed

if expect_changed:
assert user.is_staff is True
assert user.is_superuser is True
save_spy.assert_called_once_with(update_fields=["is_staff", "is_superuser"])
else:
save_spy.assert_not_called()


def test_upgrade_run_createsuperuser_pops_password_env_when_missing(mocker: MockerFixture) -> None:
cmd = upgrade_cmd.Command()
cmd.admin_email = "[email protected]"
cmd.admin_password = ""
cmd.verbosity = 1

call = mocker.patch.object(upgrade_cmd, "call_command")
mocker.patch.dict(os.environ, {"DJANGO_SUPERUSER_PASSWORD": "stale"}, clear=True)

assert cmd._run_createsuperuser("[email protected]", "[email protected]") is False
assert "DJANGO_SUPERUSER_PASSWORD" not in os.environ

call.assert_called_once_with(
"createsuperuser",
email="[email protected]",
username="[email protected]",
verbosity=0,
interactive=False,
)


def test_upgrade_superuser_logins_drops_whitespace_only() -> None:
cmd = upgrade_cmd.Command()
cmd.admin_email = " [email protected] "
cmd.superusers = [" ", "u1", " u1 ", "", " "]
assert cmd._superuser_logins() == ["[email protected]", "u1"]