Skip to content

Commit

Permalink
Lower email before checking justice domain (#1439)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeljcollinsuk authored Jan 21, 2025
1 parent b93cd0f commit 30137e0
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 7 deletions.
2 changes: 1 addition & 1 deletion controlpanel/api/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -1510,7 +1510,7 @@ def get_name_from_email(self, user_email):

name, address = user_email.split("@")

if address not in settings.JUSTICE_EMAIL_DOMAINS:
if address.lower() not in settings.JUSTICE_EMAIL_DOMAINS:
raise ValueError("Expecting justice email")

if "." not in name:
Expand Down
14 changes: 10 additions & 4 deletions controlpanel/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,8 @@ def create_user(self, claims):
"username": claims.get(settings.OIDC_FIELD_USERNAME),
"email": claims.get(settings.OIDC_FIELD_EMAIL),
"name": self.normalise_name(claims.get(settings.OIDC_FIELD_NAME)),
"justice_email": self.get_justice_email(claims.get(settings.OIDC_FIELD_EMAIL)),
}
email_domain = user_details["email"].split("@")[-1]
if email_domain in settings.JUSTICE_EMAIL_DOMAINS:
user_details["justice_email"] = user_details["email"]

return User.objects.create(**user_details)

def normalise_name(self, name):
Expand All @@ -49,6 +46,15 @@ def normalise_name(self, name):
name = " ".join(reversed(parts))
return name

def get_justice_email(self, email):
"""
Check if the email uses a justice domain and return it if it does, otherwise return None
"""
email_domain = email.split("@")[-1].lower()
if email_domain in settings.JUSTICE_EMAIL_DOMAINS:
return email
return None

def update_user(self, user, claims):
# Update the non-key information to sync the user's info
# with user profile from idp when the user's username is not changed.
Expand Down
1 change: 1 addition & 0 deletions tests/api/test_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,6 +1195,7 @@ def test_get_embed_url(quicksight_service):
("[email protected]", "Carol", "Vor"),
("[email protected]", "Ronnie", "Hotdogs"),
("[email protected]", "Ci", "Ca"),
("[email protected]", "Ci", "Ca"),
],
)
def test_get_name_from_email(email, expected_forename, expected_surname):
Expand Down
20 changes: 18 additions & 2 deletions tests/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ def test_success_url(users, email, success_url):
@pytest.mark.parametrize(
"email, name, expected_name, expected_justice_email",
[
("email@exmaple.com", "User, Test", "Test User", None),
("email@exmaple.com", "Test User", "Test User", None),
("email@example.com", "User, Test", "Test User", None),
("email@example.com", "Test User", "Test User", None),
("[email protected]", "User, Test", "Test User", "[email protected]"),
("[email protected]", "Test User", "Test User", "[email protected]"),
("[email protected]", "Test User", "Test User", "[email protected]"),
("[email protected]", "Test User", "Test User", "[email protected]"),
],
)
def test_create_user(email, name, expected_name, expected_justice_email):
Expand All @@ -49,3 +51,17 @@ def test_create_user(email, name, expected_name, expected_justice_email):
)
assert user.name == expected_name
assert user.justice_email == expected_justice_email


@pytest.mark.parametrize(
"email, expected",
[
("[email protected]", None),
("[email protected]", "[email protected]"),
("[email protected]", "[email protected]"),
("[email protected]", "[email protected]"),
("[email protected]", "[email protected]"),
],
)
def test_get_justice_email(email, expected):
assert OIDCSubAuthenticationBackend().get_justice_email(email) == expected

0 comments on commit 30137e0

Please sign in to comment.