Skip to content

Commit

Permalink
cas: add login_attrs()
Browse files Browse the repository at this point in the history
  • Loading branch information
taoky committed Oct 14, 2023
1 parent 60ed8dd commit 126bd1a
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 17 deletions.
4 changes: 4 additions & 0 deletions frontend/auth_providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def on_get_account(self, account):
def normalize_identity(self):
return self.identity

# kwargs 会在首次登录(注册)时传入 User.create
# 用户可以修改这些信息。
# 对于 CAS 等场合(identity 不是学号,或者返回了除了学号以外更多的信息)
# 可以将有关信息记录在 AccountLog 中,以备查阅。
def login(self, **kwargs):
account, created = Account.objects.get_or_create(
provider=self.provider,
Expand Down
7 changes: 5 additions & 2 deletions frontend/auth_providers/cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from django.contrib import messages
from django.shortcuts import redirect

from typing import Optional
from typing import Optional, Any

from .base import BaseLoginView

Expand All @@ -19,6 +19,9 @@ class CASBaseLoginView(BaseLoginView):

YALE_CAS_URL = "{http://www.yale.edu/tp/cas}"

def login_attrs(self) -> dict[str, Any]:
raise NotImplementedError("CAS 登录需要实现 login_attrs()")

def get(self, request: HttpRequest):
self.service = request.build_absolute_uri(request.path)
self.ticket = request.GET.get("ticket")
Expand All @@ -27,7 +30,7 @@ def get(self, request: HttpRequest):
self.cas_login_url + "?" + urlencode({"service": self.service})
)
if self.check_ticket():
self.login(sno=self.sno)
self.login(**self.login_attrs())
return redirect("hub")

def check_ticket(self) -> Optional[ElementTree.Element]:
Expand Down
13 changes: 10 additions & 3 deletions frontend/auth_providers/sustech.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from django.urls import path

from typing import Optional
from typing import Any, Optional

from ..models import AccountLog
from .cas import CASBaseLoginView
Expand All @@ -19,12 +19,19 @@ class LoginView(CASBaseLoginView):
cas_login_url = 'https://sso.cra.ac.cn/realms/cra-service-realm/protocol/cas/login'
cas_service_validate_url = 'https://sso.cra.ac.cn/realms/cra-service-realm/protocol/cas/serviceValidate'

def login_attrs(self) -> dict[str, Any]:
return {
"sno": self.identity,
"email": self.email,
"name": self.name,
}

def check_ticket(self) -> Optional[ElementTree.Element]:
tree = super().check_ticket()
if not tree:
return None
self.identity = tree.find(self.YALE_CAS_URL + 'user').text.strip()
self.mail = tree.find(self.YALE_CAS_URL + 'attributes').find(self.YALE_CAS_URL + 'mail').text.strip()
self.email = tree.find(self.YALE_CAS_URL + 'attributes').find(self.YALE_CAS_URL + 'mail').text.strip()
self.name = tree.find(self.YALE_CAS_URL + 'attributes').find(self.YALE_CAS_URL + 'cn').text.strip()
return tree

Expand All @@ -34,7 +41,7 @@ def to_set(s):
def from_set(vs):
return ','.join(sorted(vs))
custom_attrs: list[tuple[str, str]] = [
('邮箱', self.mail),
('邮箱', self.email),
('姓名', self.name)
]
for display_name, self_value in custom_attrs:
Expand Down
7 changes: 6 additions & 1 deletion frontend/auth_providers/ustc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from django.urls import path

from typing import Optional
from typing import Optional, Any

from ..models import AccountLog
from .cas import CASBaseLoginView
Expand All @@ -19,6 +19,11 @@ class LoginView(CASBaseLoginView):
cas_login_url = 'https://passport.ustc.edu.cn/login'
cas_service_validate_url = 'https://passport.ustc.edu.cn/serviceValidate'

def login_attrs(self) -> dict[str, Any]:
return {
"sno": self.sno,
}

def check_ticket(self) -> Optional[ElementTree.Element]:
tree = super().check_ticket()
if not tree:
Expand Down
39 changes: 28 additions & 11 deletions frontend/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,24 +53,41 @@ def mock_urlopen(url, timeout=None):
raise ValueError("Unknown URL")


@mock.patch("frontend.auth_providers.cas.urlopen", new=mock_urlopen)
def ustc_check_ticket():
v = USTCLoginView()
v.service = "http://example.com/accounts/ustc/login/"
v.ticket = "ST-1234567890"
return v, v.check_ticket()

@mock.patch("frontend.auth_providers.cas.urlopen", new=mock_urlopen)
def sustech_check_ticket():
v = SUSTECHLoginView()
v.service = "http://example.com/accounts/sustech/login/"
v.ticket = "ST-1234567890"
return v, v.check_ticket()


class AuthProviderCASServiceValidateTest(TestCase):
@mock.patch("frontend.auth_providers.cas.urlopen", new=mock_urlopen)
def test_ustc(self):
v = USTCLoginView()
v.service = "http://example.com/accounts/ustc/login/"
v.ticket = "ST-1234567890"
tree = v.check_ticket()
v, tree = ustc_check_ticket()
self.assertEqual(tree.tag, "{http://www.yale.edu/tp/cas}authenticationSuccess")
self.assertEqual(v.identity, "2201234567")
self.assertEqual(v.sno, "SA21011000")

@mock.patch("frontend.auth_providers.cas.urlopen", new=mock_urlopen)
def test_sustech(self):
v = SUSTECHLoginView()
v.service = "http://example.com/accounts/sustech/login/"
v.ticket = "ST-1234567890"
tree = v.check_ticket()
v, tree = sustech_check_ticket()
self.assertEqual(tree.tag, "{http://www.yale.edu/tp/cas}authenticationSuccess")
self.assertEqual(v.identity, "11899999")
self.assertEqual(v.mail, "[email protected]")
self.assertEqual(v.email, "[email protected]")
self.assertEqual(v.name, "ZHANG San")


class AuthProviderCASHasImplementedLoginAttrs(TestCase):
def test_ustc(self):
v, _ = ustc_check_ticket()
v.login_attrs()

def test_sustech(self):
v, _ = sustech_check_ticket()
v.login_attrs()

0 comments on commit 126bd1a

Please sign in to comment.