Skip to content

Commit

Permalink
Try to catch the race condidion when creating the user with a partner (
Browse files Browse the repository at this point in the history
  • Loading branch information
acasajus authored Nov 14, 2024
1 parent 8dd3771 commit dd2cfae
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 22 deletions.
67 changes: 49 additions & 18 deletions app/account_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import arrow
from arrow import Arrow
from newrelic import agent
from psycopg2.errors import UniqueViolation
from sqlalchemy import or_

from app.db import Session
Expand Down Expand Up @@ -160,15 +161,55 @@ def process(self) -> LinkResult:

class NewUserStrategy(ClientMergeStrategy):
def process(self) -> LinkResult:
# Will create a new SL User with a random password
canonical_email = canonicalize_email(self.link_request.email)
new_user = User.create(
email=canonical_email,
name=self.link_request.name,
password=random_string(20),
activated=True,
from_partner=self.link_request.from_partner,
try:
# Will create a new SL User with a random password
new_user = User.create(
email=canonical_email,
name=self.link_request.name,
password=random_string(20),
activated=True,
from_partner=self.link_request.from_partner,
)
self.create_partner_user(new_user)
Session.commit()

if not new_user.created_by_partner:
send_welcome_email(new_user)

agent.record_custom_event(
"PartnerUserCreation", {"partner": self.partner.name}
)

return LinkResult(
user=new_user,
strategy=self.__class__.__name__,
)
except UniqueViolation:
return self.create_missing_link(canonical_email)

def create_missing_link(self, canonical_email: str):
# If there's a unique key violation due to race conditions try to create only the partner if needed
partner_user = PartnerUser.get_by(
external_user_id=self.link_request.external_user_id,
partner_id=self.partner.id,
)
if partner_user is None:
# Get the user by canonical email and if not by normal email
user = User.get_by(email=canonical_email) or User.get_by(
email=self.link_request.email
)
if not user:
raise RuntimeError(
"Tried to create only partner on UniqueViolation but cannot find the user"
)
partner_user = self.create_partner_user(user)
Session.commit()
return LinkResult(
user=partner_user.user, strategy=ExistingUnlinkedUserStrategy.__name__
)

def create_partner_user(self, new_user: User):
partner_user = create_partner_user(
user=new_user,
partner_id=self.partner.id,
Expand All @@ -182,17 +223,7 @@ def process(self) -> LinkResult:
partner_user,
self.link_request.plan,
)
Session.commit()

if not new_user.created_by_partner:
send_welcome_email(new_user)

agent.record_custom_event("PartnerUserCreation", {"partner": self.partner.name})

return LinkResult(
user=new_user,
strategy=self.__class__.__name__,
)
return partner_user


class ExistingUnlinkedUserStrategy(ClientMergeStrategy):
Expand Down
8 changes: 4 additions & 4 deletions app/alias_suffix.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def verify_prefix_suffix(

# alias_domain must be either one of user custom domains or built-in domains
if alias_domain not in user.available_alias_domains(alias_options=alias_options):
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
LOG.i("wrong alias suffix %s, user %s", alias_suffix, user)
return False

# SimpleLogin domain case:
Expand All @@ -75,17 +75,17 @@ def verify_prefix_suffix(
and not config.DISABLE_ALIAS_SUFFIX
):
if not alias_domain_prefix.startswith("."):
LOG.e("User %s submits a wrong alias suffix %s", user, alias_suffix)
LOG.i("User %s submits a wrong alias suffix %s", user, alias_suffix)
return False

else:
if alias_domain not in user_custom_domains:
if not config.DISABLE_ALIAS_SUFFIX:
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
LOG.i("wrong alias suffix %s, user %s", alias_suffix, user)
return False

if alias_domain not in available_sl_domains:
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
LOG.i("wrong alias suffix %s, user %s", alias_suffix, user)
return False

return True
Expand Down
15 changes: 15 additions & 0 deletions tests/test_account_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,21 @@ def test_login_case_from_web():
assert audit_logs[0].action == UserAuditLogAction.LinkAccount.value


def test_new_user_strategy_create_missing_link():
email = random_email()
user = User.create(email, commit=True)
nus = NewUserStrategy(
link_request=random_link_request(
email=user.email, external_user_id=random_string(), from_partner=False
),
user=None,
partner=get_proton_partner(),
)
result = nus.create_missing_link(user.email)
assert result.user.id == user.id
assert result.strategy == ExistingUnlinkedUserStrategy.__name__


def test_get_strategy_existing_sl_user():
email = random_email()
user = User.create(email, commit=True)
Expand Down

0 comments on commit dd2cfae

Please sign in to comment.