Compare commits

..

No commits in common. 'master' and 'master' have entirely different histories.

Binary file not shown.

@ -1,80 +1,52 @@
# admin.py - Django后台管理配置文件
# 导入Django表单模块
from django import forms
# 导入Django用户管理类
from django.contrib.auth.admin import UserAdmin
# 导入用户修改表单
from django.contrib.auth.forms import UserChangeForm
# 导入用户名字段
from django.contrib.auth.forms import UsernameField
# 导入国际化翻译函数
from django.utils.translation import gettext_lazy as _
# 在这里注册模型
# 导入自定义的用户模型
# Register your models here.
from .models import BlogUser
# 自定义用户创建表单
class BlogUserCreationForm(forms.ModelForm):
# 密码字段1 - 用于输入密码
password1 = forms.CharField(label=_('password'), widget=forms.PasswordInput)
# 密码字段2 - 用于确认密码
password2 = forms.CharField(label=_('Enter password again'), widget=forms.PasswordInput)
class Meta:
# 指定关联的模型
model = BlogUser
# 表单包含的字段
fields = ('email',)
# 清理密码确认字段的方法
def clean_password2(self):
# 从已清理的数据中获取两个密码字段的值
# Check that the two password entries match
password1 = self.cleaned_data.get("password1")
password2 = self.cleaned_data.get("password2")
# 检查两个密码是否匹配
if password1 and password2 and password1 != password2:
raise forms.ValidationError(_("passwords do not match"))
return password2
# 保存用户的方法
def save(self, commit=True):
# 调用父类的save方法但不立即提交到数据库
# Save the provided password in hashed format
user = super().save(commit=False)
# 设置哈希后的密码
user.set_password(self.cleaned_data["password1"])
if commit:
# 设置用户来源为管理员站点
user.source = 'adminsite'
# 保存用户到数据库
user.save()
return user
# 自定义用户修改表单
class BlogUserChangeForm(UserChangeForm):
class Meta:
# 指定关联的模型
model = BlogUser
# 包含所有字段
fields = '__all__'
# 字段类型映射
field_classes = {'username': UsernameField}
# 初始化方法
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# 自定义用户管理类
class BlogUserAdmin(UserAdmin):
# 指定修改表单
form = BlogUserChangeForm
# 指定创建表单
add_form = BlogUserCreationForm
# 列表页面显示的字段
list_display = (
'id',
'nickname',
@ -83,7 +55,5 @@ class BlogUserAdmin(UserAdmin):
'last_login',
'date_joined',
'source')
# 列表中可点击链接的字段
list_display_links = ('id', 'username')
# 默认排序字段按ID降序
ordering = ('-id',)
ordering = ('-id',)

@ -1,10 +1,5 @@
# apps.py - Django应用程序配置文件
# 导入Django应用配置基类
from django.apps import AppConfig
# 定义账户应用的配置类
class AccountsConfig(AppConfig):
# 指定应用程序的完整Python路径
name = 'accounts'
name = 'accounts'

@ -1,71 +1,47 @@
# forms.py - Django表单定义文件
# 导入Django表单模块
from django import forms
# 导入用户模型相关函数和表单类
from django.contrib.auth import get_user_model, password_validation
from django.contrib.auth.forms import AuthenticationForm, UserCreationForm
# 导入验证错误异常
from django.core.exceptions import ValidationError
# 导入表单小部件
from django.forms import widgets
# 导入国际化翻译函数
from django.utils.translation import gettext_lazy as _
# 导入工具模块
from . import utils
# 导入用户模型
from .models import BlogUser
# 登录表单类
class LoginForm(AuthenticationForm):
# 初始化方法,设置表单字段的样式和属性
def __init__(self, *args, **kwargs):
super(LoginForm, self).__init__(*args, **kwargs)
# 设置用户名字段的输入框样式
self.fields['username'].widget = widgets.TextInput(
attrs={'placeholder': "username", "class": "form-control"})
# 设置密码字段的输入框样式
self.fields['password'].widget = widgets.PasswordInput(
attrs={'placeholder': "password", "class": "form-control"})
# 注册表单类
class RegisterForm(UserCreationForm):
# 初始化方法,设置表单字段的样式和属性
def __init__(self, *args, **kwargs):
super(RegisterForm, self).__init__(*args, **kwargs)
# 设置用户名字段的输入框样式
self.fields['username'].widget = widgets.TextInput(
attrs={'placeholder': "username", "class": "form-control"})
# 设置邮箱字段的输入框样式
self.fields['email'].widget = widgets.EmailInput(
attrs={'placeholder': "email", "class": "form-control"})
# 设置密码字段的输入框样式
self.fields['password1'].widget = widgets.PasswordInput(
attrs={'placeholder': "password", "class": "form-control"})
# 设置确认密码字段的输入框样式
self.fields['password2'].widget = widgets.PasswordInput(
attrs={'placeholder': "repeat password", "class": "form-control"})
# 清理邮箱字段的方法,检查邮箱是否已存在
def clean_email(self):
email = self.cleaned_data['email']
# 检查邮箱是否已被注册
if get_user_model().objects.filter(email=email).exists():
raise ValidationError(_("email already exists"))
return email
class Meta:
# 指定关联的用户模型
model = get_user_model()
# 表单包含的字段
fields = ("username", "email")
# 忘记密码表单类
class ForgetPasswordForm(forms.Form):
# 新密码字段
new_password1 = forms.CharField(
label=_("New password"),
widget=forms.PasswordInput(
@ -76,7 +52,6 @@ class ForgetPasswordForm(forms.Form):
),
)
# 确认新密码字段
new_password2 = forms.CharField(
label="确认密码",
widget=forms.PasswordInput(
@ -87,7 +62,6 @@ class ForgetPasswordForm(forms.Form):
),
)
# 邮箱字段
email = forms.EmailField(
label='邮箱',
widget=forms.TextInput(
@ -98,7 +72,6 @@ class ForgetPasswordForm(forms.Form):
),
)
# 验证码字段
code = forms.CharField(
label=_('Code'),
widget=forms.TextInput(
@ -109,22 +82,17 @@ class ForgetPasswordForm(forms.Form):
),
)
# 清理确认密码字段的方法
def clean_new_password2(self):
password1 = self.data.get("new_password1")
password2 = self.data.get("new_password2")
# 检查两个密码是否匹配
if password1 and password2 and password1 != password2:
raise ValidationError(_("passwords do not match"))
# 验证密码强度
password_validation.validate_password(password2)
return password2
# 清理邮箱字段的方法,检查邮箱是否存在
def clean_email(self):
user_email = self.cleaned_data.get("email")
# 检查邮箱是否已注册
if not BlogUser.objects.filter(
email=user_email
).exists():
@ -132,10 +100,8 @@ class ForgetPasswordForm(forms.Form):
raise ValidationError(_("email does not exist"))
return user_email
# 清理验证码字段的方法
def clean_code(self):
code = self.cleaned_data.get("code")
# 验证邮箱和验证码是否匹配
error = utils.verify(
email=self.cleaned_data.get("email"),
code=code,
@ -145,9 +111,7 @@ class ForgetPasswordForm(forms.Form):
return code
# 获取忘记密码验证码表单类
class ForgetPasswordCodeForm(forms.Form):
# 邮箱字段
email = forms.EmailField(
label=_('Email'),
)
)

@ -1,58 +1,35 @@
# models.py - Django数据模型定义文件
# 导入Django内置的用户抽象基类
from django.contrib.auth.models import AbstractUser
# 导入Django数据库模型
from django.db import models
# 导入URL反向解析函数
from django.urls import reverse
# 导入时间相关函数
from django.utils.timezone import now
# 导入国际化翻译函数
from django.utils.translation import gettext_lazy as _
# 导入工具函数获取当前站点
from djangoblog.utils import get_current_site
# 在这里创建模型
# Create your models here.
# 博客用户模型继承自Django的AbstractUser
class BlogUser(AbstractUser):
# 昵称字段最大长度100字符允许为空
nickname = models.CharField(_('nick name'), max_length=100, blank=True)
# 创建时间字段,默认值为当前时间
creation_time = models.DateTimeField(_('creation time'), default=now)
# 最后修改时间字段,默认值为当前时间
last_modify_time = models.DateTimeField(_('last modify time'), default=now)
# 用户来源字段记录创建来源最大长度100字符允许为空
source = models.CharField(_('create source'), max_length=100, blank=True)
# 获取用户绝对URL的方法
def get_absolute_url(self):
return reverse(
'blog:author_detail', kwargs={
'author_name': self.username})
# 对象的字符串表示方法,返回邮箱地址
def __str__(self):
return self.email
# 获取完整URL的方法包含域名
def get_full_url(self):
# 获取当前站点域名
site = get_current_site().domain
# 构建完整的URL
url = "https://{site}{path}".format(site=site,
path=self.get_absolute_url())
return url
# 模型的元数据配置
class Meta:
# 默认按ID降序排列
ordering = ['-id']
# 单数形式的显示名称
verbose_name = _('user')
# 复数形式的显示名称(与单数相同)
verbose_name_plural = verbose_name
# 指定获取最新对象的字段
get_latest_by = 'id'
get_latest_by = 'id'

@ -1,186 +1,135 @@
# tests.py - Django测试用例文件
# 导入Django测试相关模块
from django.test import Client, RequestFactory, TestCase
# 导入URL反向解析
from django.urls import reverse
# 导入时区相关功能
from django.utils import timezone
# 导入国际化翻译函数
from django.utils.translation import gettext_lazy as _
# 导入账户相关模型
from accounts.models import BlogUser
# 导入博客相关模型
from blog.models import Article, Category
# 导入工具函数
from djangoblog.utils import *
# 导入当前应用的工具模块
from . import utils
# 在这里创建测试用例
# Create your tests here.
# 账户测试类
class AccountTest(TestCase):
# 测试前置设置方法
def setUp(self):
# 创建测试客户端
self.client = Client()
# 创建请求工厂
self.factory = RequestFactory()
# 创建测试用户
self.blog_user = BlogUser.objects.create_user(
username="test",
email="admin@admin.com",
password="12345678"
)
# 设置测试用的新密码
self.new_test = "xxx123--="
# 测试账户验证功能
def test_validate_account(self):
# 获取当前站点域名
site = get_current_site().domain
# 创建超级用户
user = BlogUser.objects.create_superuser(
email="liangliangyy1@gmail.com",
username="liangliangyy1",
password="qwer!@#$ggg")
# 获取刚创建的用户
testuser = BlogUser.objects.get(username='liangliangyy1')
# 测试登录功能
loginresult = self.client.login(
username='liangliangyy1',
password='qwer!@#$ggg')
# 断言登录成功
self.assertEqual(loginresult, True)
# 测试访问管理员页面
response = self.client.get('/admin/')
self.assertEqual(response.status_code, 200)
# 创建测试分类
category = Category()
category.name = "categoryaaa"
category.creation_time = timezone.now()
category.last_modify_time = timezone.now()
category.save()
# 创建测试文章
article = Article()
article.title = "nicetitleaaa"
article.body = "nicecontentaaa"
article.author = user
article.category = category
article.type = 'a' # 文章类型
article.status = 'p' # 发布状态
article.type = 'a'
article.status = 'p'
article.save()
# 测试访问文章管理页面
response = self.client.get(article.get_admin_url())
self.assertEqual(response.status_code, 200)
# 测试用户注册功能
def test_validate_register(self):
# 断言邮箱不存在
self.assertEquals(
0, len(
BlogUser.objects.filter(
email='user123@user.com')))
# 发送注册请求
response = self.client.post(reverse('account:register'), {
'username': 'user1233',
'email': 'user123@user.com',
'password1': 'password123!q@wE#R$T',
'password2': 'password123!q@wE#R$T',
})
# 断言用户创建成功
self.assertEquals(
1, len(
BlogUser.objects.filter(
email='user123@user.com')))
# 获取新创建的用户
user = BlogUser.objects.filter(email='user123@user.com')[0]
# 生成验证签名
sign = get_sha256(get_sha256(settings.SECRET_KEY + str(user.id)))
path = reverse('accounts:result')
# 构建验证URL
url = '{path}?type=validation&id={id}&sign={sign}'.format(
path=path, id=user.id, sign=sign)
# 测试验证页面
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
# 测试用户登录
self.client.login(username='user1233', password='password123!q@wE#R$T')
user = BlogUser.objects.filter(email='user123@user.com')[0]
# 设置用户为超级用户和管理员
user.is_superuser = True
user.is_staff = True
user.save()
# 清理侧边栏缓存
delete_sidebar_cache()
# 创建分类
category = Category()
category.name = "categoryaaa"
category.creation_time = timezone.now()
category.last_modify_time = timezone.now()
category.save()
# 创建文章
article = Article()
article.category = category
article.title = "nicetitle333"
article.body = "nicecontentttt"
article.author = user
article.type = 'a'
article.status = 'p'
article.save()
# 测试访问文章管理页面
response = self.client.get(article.get_admin_url())
self.assertEqual(response.status_code, 200)
# 测试退出登录
response = self.client.get(reverse('account:logout'))
self.assertIn(response.status_code, [301, 302, 200])
# 测试退出后访问管理页面(应该重定向)
response = self.client.get(article.get_admin_url())
self.assertIn(response.status_code, [301, 302, 200])
# 测试错误密码登录
response = self.client.post(reverse('account:login'), {
'username': 'user1233',
'password': 'password123'
})
self.assertIn(response.status_code, [301, 302, 200])
# 测试登录后访问管理页面
response = self.client.get(article.get_admin_url())
self.assertIn(response.status_code, [301, 302, 200])
# 测试邮箱验证码功能
def test_verify_email_code(self):
to_email = "admin@admin.com"
# 生成验证码
code = generate_code()
# 设置验证码
utils.set_code(to_email, code)
# 发送验证邮件
utils.send_verify_email(to_email, code)
# 测试正确验证码
err = utils.verify("admin@admin.com", code)
self.assertEqual(err, None)
# 测试错误邮箱
err = utils.verify("admin@123.com", code)
self.assertEqual(type(err), str)
# 测试成功发送忘记密码验证码
def test_forget_password_email_code_success(self):
resp = self.client.post(
path=reverse("account:forget_password_code"),
@ -190,40 +139,33 @@ class AccountTest(TestCase):
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.content.decode("utf-8"), "ok")
# 测试发送忘记密码验证码失败情况
def test_forget_password_email_code_fail(self):
# 测试空数据
resp = self.client.post(
path=reverse("account:forget_password_code"),
data=dict()
)
self.assertEqual(resp.content.decode("utf-8"), "错误的邮箱")
# 测试错误邮箱格式
resp = self.client.post(
path=reverse("account:forget_password_code"),
data=dict(email="admin@com")
)
self.assertEqual(resp.content.decode("utf-8"), "错误的邮箱")
# 测试成功重置密码
def test_forget_password_email_success(self):
# 生成并设置验证码
code = generate_code()
utils.set_code(self.blog_user.email, code)
# 准备重置密码数据
data = dict(
new_password1=self.new_test,
new_password2=self.new_test,
email=self.blog_user.email,
code=code,
)
# 发送重置密码请求
resp = self.client.post(
path=reverse("account:forget_password"),
data=data
)
self.assertEqual(resp.status_code, 302) # 重定向响应
self.assertEqual(resp.status_code, 302)
# 验证用户密码是否修改成功
blog_user = BlogUser.objects.filter(
@ -232,7 +174,6 @@ class AccountTest(TestCase):
self.assertNotEqual(blog_user, None)
self.assertEqual(blog_user.check_password(data["new_password1"]), True)
# 测试重置密码时邮箱不存在的情况
def test_forget_password_email_not_user(self):
data = dict(
new_password1=self.new_test,
@ -247,7 +188,7 @@ class AccountTest(TestCase):
self.assertEqual(resp.status_code, 200)
# 测试重置密码时验证码错误的情况
def test_forget_password_email_code_error(self):
code = generate_code()
utils.set_code(self.blog_user.email, code)
@ -255,11 +196,12 @@ class AccountTest(TestCase):
new_password1=self.new_test,
new_password2=self.new_test,
email=self.blog_user.email,
code="111111", # 错误的验证码
code="111111",
)
resp = self.client.post(
path=reverse("account:forget_password"),
data=data
)
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.status_code, 200)

@ -1,54 +1,28 @@
# urls.py - Django URL路由配置文件
# 导入Django URL路由相关函数
from django.urls import path
from django.urls import re_path
# 导入当前应用的视图模块
from . import views
# 导入登录表单类
from .forms import LoginForm
# 定义应用命名空间用于URL反向解析
app_name = "accounts"
# URL模式列表 - 定义URL路径与视图的映射关系
urlpatterns = [
# 登录URL - 使用正则表达式匹配路径
re_path(r'^login/$',
# 使用类视图,设置登录成功后跳转到首页
views.LoginView.as_view(success_url='/'),
name='login', # URL名称用于反向解析
# 传递额外参数,指定认证表单类
kwargs={'authentication_form': LoginForm}),
# 注册URL - 使用正则表达式匹配路径
re_path(r'^register/$',
# 使用类视图,设置注册成功后跳转到首页
views.RegisterView.as_view(success_url="/"),
name='register'), # URL名称用于反向解析
# 退出登录URL - 使用正则表达式匹配路径
re_path(r'^logout/$',
# 使用类视图
views.LogoutView.as_view(),
name='logout'), # URL名称用于反向解析
# 账户操作结果页面URL - 使用path匹配精确路径
path(r'account/result.html',
# 使用函数视图
views.account_result,
name='result'), # URL名称用于反向解析
# 忘记密码URL - 使用正则表达式匹配路径
re_path(r'^forget_password/$',
# 使用类视图
views.ForgetPasswordView.as_view(),
name='forget_password'), # URL名称用于反向解析
# 获取忘记密码验证码URL - 使用正则表达式匹配路径
re_path(r'^forget_password_code/$',
# 使用类视图
views.ForgetPasswordEmailCode.as_view(),
name='forget_password_code'), # URL名称用于反向解析
]
urlpatterns = [re_path(r'^login/$',
views.LoginView.as_view(success_url='/'),
name='login',
kwargs={'authentication_form': LoginForm}),
re_path(r'^register/$',
views.RegisterView.as_view(success_url="/"),
name='register'),
re_path(r'^logout/$',
views.LogoutView.as_view(),
name='logout'),
path(r'account/result.html',
views.account_result,
name='result'),
re_path(r'^forget_password/$',
views.ForgetPasswordView.as_view(),
name='forget_password'),
re_path(r'^forget_password_code/$',
views.ForgetPasswordEmailCode.as_view(),
name='forget_password_code'),
]

@ -1,42 +1,26 @@
# user_login_backend.py - 自定义用户认证后端
# 导入获取用户模型的函数
from django.contrib.auth import get_user_model
# 导入Django模型认证后端基类
from django.contrib.auth.backends import ModelBackend
# 自定义认证后端类,支持邮箱或用户名登录
class EmailOrUsernameModelBackend(ModelBackend):
"""
允许使用用户名或邮箱登录的自定义认证后端
允许使用用户名或邮箱登录
"""
# 用户认证方法
def authenticate(self, request, username=None, password=None, **kwargs):
# 判断输入的是邮箱还是用户名
if '@' in username:
# 如果包含@符号,按邮箱处理
kwargs = {'email': username}
else:
# 否则按用户名处理
kwargs = {'username': username}
try:
# 根据用户名或邮箱获取用户
user = get_user_model().objects.get(**kwargs)
# 验证密码是否正确
if user.check_password(password):
return user # 认证成功,返回用户对象
return user
except get_user_model().DoesNotExist:
# 用户不存在返回None
return None
# 根据用户ID获取用户的方法
def get_user(self, username):
try:
# 根据主键用户ID获取用户
return get_user_model().objects.get(pk=username)
except get_user_model().DoesNotExist:
# 用户不存在返回None
return None
return None

@ -1,71 +1,49 @@
# utils.py - 工具函数模块,处理验证码相关功能
# 导入类型提示模块
import typing
# 导入时间间隔类
from datetime import timedelta
# 导入Django缓存模块
from django.core.cache import cache
# 导入国际化翻译函数
from django.utils.translation import gettext
from django.utils.translation import gettext_lazy as _
# 导入发送邮件工具函数
from djangoblog.utils import send_email
# 验证码有效期设置为5分钟
_code_ttl = timedelta(minutes=5)
def send_verify_email(to_mail: str, code: str, subject: str = _("Verify Email")):
"""发送验证邮件
"""发送重设密码验证码
Args:
to_mail: 收邮箱地址
subject: 邮件主题默认为"验证邮件"
to_mail: 受邮箱
subject: 邮件主题
code: 验证码
"""
# 构建邮件HTML内容包含验证码信息
html_content = _(
"You are resetting the password, the verification code is%(code)s, valid within 5 minutes, please keep it "
"properly") % {'code': code}
# 调用发送邮件函数
send_email([to_mail], subject, html_content)
def verify(email: str, code: str) -> typing.Optional[str]:
"""验证验证码是否有效
"""验证code是否有效
Args:
email: 请求验证的邮箱地址
code: 用户输入的验证码
email: 请求邮箱
code: 验证码
Return:
如果验证失败返回错误信息字符串成功返回None
Note:
这里的错误处理不太合理应该采用raise抛出异常
调用方也需要对error进行处理
如果有错误就返回错误str
Node:
这里的错误处理不太合理应该采用raise抛出
调用方也需要对error进行处理
"""
# 从缓存中获取该邮箱对应的验证码
cache_code = get_code(email)
# 比较用户输入的验证码和缓存中的验证码
if cache_code != code:
return gettext("Verification code error")
def set_code(email: str, code: str):
"""设置验证码到缓存中
Args:
email: 邮箱地址作为缓存的key
code: 验证码作为缓存的value
"""
# 将验证码存入缓存设置过期时间为5分钟
"""设置code"""
cache.set(email, code, _code_ttl.seconds)
def get_code(email: str) -> typing.Optional[str]:
"""从缓存中获取验证码
Args:
email: 邮箱地址作为缓存的key
Return:
返回验证码字符串如果不存在则返回None
"""
return cache.get(email)
"""获取code"""
return cache.get(email)

@ -1,89 +1,59 @@
# views.py - Django视图文件处理用户账户相关请求
# 导入日志模块
import logging
# 导入国际化翻译函数
from django.utils.translation import gettext_lazy as _
# 导入Django设置
from django.conf import settings
# 导入Django认证相关模块
from django.contrib import auth
from django.contrib.auth import REDIRECT_FIELD_NAME
from django.contrib.auth import get_user_model
from django.contrib.auth import logout
from django.contrib.auth.forms import AuthenticationForm
from django.contrib.auth.hashers import make_password
# 导入HTTP响应类
from django.http import HttpResponseRedirect, HttpResponseForbidden
from django.http.request import HttpRequest
from django.http.response import HttpResponse
# 导入快捷函数
from django.shortcuts import get_object_or_404
from django.shortcuts import render
# 导入URL反向解析
from django.urls import reverse
# 导入方法装饰器
from django.utils.decorators import method_decorator
from django.utils.http import url_has_allowed_host_and_scheme
# 导入基于类的视图
from django.views import View
from django.views.decorators.cache import never_cache
from django.views.decorators.csrf import csrf_protect
from django.views.decorators.debug import sensitive_post_parameters
from django.views.generic import FormView, RedirectView
# 导入工具函数
from djangoblog.utils import send_email, get_sha256, get_current_site, generate_code, delete_sidebar_cache
# 导入当前应用的工具模块
from . import utils
# 导入表单类
from .forms import RegisterForm, LoginForm, ForgetPasswordForm, ForgetPasswordCodeForm
# 导入用户模型
from .models import BlogUser
# 获取日志记录器
logger = logging.getLogger(__name__)
# 在这里创建视图
# Create your views here.
# 用户注册视图
class RegisterView(FormView):
# 指定使用的表单类
form_class = RegisterForm
# 指定模板文件
template_name = 'account/registration_form.html'
# 使用CSRF保护装饰器
@method_decorator(csrf_protect)
def dispatch(self, *args, **kwargs):
return super(RegisterView, self).dispatch(*args, **kwargs)
# 表单验证通过后的处理
def form_valid(self, form):
if form.is_valid():
# 保存用户但不提交到数据库
user = form.save(False)
# 设置用户为非激活状态(需要邮箱验证)
user.is_active = False
# 设置用户来源
user.source = 'Register'
# 保存用户到数据库
user.save(True)
# 获取当前站点域名
site = get_current_site().domain
# 生成验证签名
sign = get_sha256(get_sha256(settings.SECRET_KEY + str(user.id)))
# 调试模式下使用本地地址
if settings.DEBUG:
site = '127.0.0.1:8000'
# 构建验证URL
path = reverse('account:result')
url = "http://{site}{path}?type=validation&id={id}&sign={sign}".format(
site=site, path=path, id=user.id, sign=sign)
# 构建邮件内容
content = """
<p>请点击下面链接验证您的邮箱</p>
@ -94,7 +64,6 @@ class RegisterView(FormView):
如果上面链接无法打开请将此链接复制至浏览器
{url}
""".format(url=url)
# 发送验证邮件
send_email(
emailto=[
user.email,
@ -102,59 +71,43 @@ class RegisterView(FormView):
title='验证您的电子邮箱',
content=content)
# 重定向到结果页面
url = reverse('accounts:result') + \
'?type=register&id=' + str(user.id)
return HttpResponseRedirect(url)
else:
# 表单验证失败,重新渲染表单
return self.render_to_response({
'form': form
})
# 用户退出登录视图
class LogoutView(RedirectView):
# 退出后重定向的URL
url = '/login/'
# 禁用缓存
@method_decorator(never_cache)
def dispatch(self, request, *args, **kwargs):
return super(LogoutView, self).dispatch(request, *args, **kwargs)
# 处理GET请求
def get(self, request, *args, **kwargs):
# 执行退出登录操作
logout(request)
# 清理侧边栏缓存
delete_sidebar_cache()
return super(LogoutView, self).get(request, *args, **kwargs)
# 用户登录视图
class LoginView(FormView):
# 指定使用的表单类
form_class = LoginForm
# 指定模板文件
template_name = 'account/login.html'
# 登录成功后的默认重定向URL
success_url = '/'
# 重定向字段名
redirect_field_name = REDIRECT_FIELD_NAME
# 登录会话有效期(一个月)
login_ttl = 2626560
login_ttl = 2626560 # 一个月的时间
# 使用多个装饰器保护敏感操作
@method_decorator(sensitive_post_parameters('password'))
@method_decorator(csrf_protect)
@method_decorator(never_cache)
def dispatch(self, request, *args, **kwargs):
return super(LoginView, self).dispatch(request, *args, **kwargs)
# 获取上下文数据
def get_context_data(self, **kwargs):
# 获取重定向URL
redirect_to = self.request.GET.get(self.redirect_field_name)
if redirect_to is None:
redirect_to = '/'
@ -162,33 +115,26 @@ class LoginView(FormView):
return super(LoginView, self).get_context_data(**kwargs)
# 表单验证通过后的处理
def form_valid(self, form):
form = AuthenticationForm(data=self.request.POST, request=self.request)
if form.is_valid():
# 清理侧边栏缓存
delete_sidebar_cache()
# 记录日志
logger.info(self.redirect_field_name)
# 执行登录操作
auth.login(self.request, form.get_user())
# 如果用户选择"记住我",设置会话有效期
if self.request.POST.get("remember"):
self.request.session.set_expiry(self.login_ttl)
return super(LoginView, self).form_valid(form)
# return HttpResponseRedirect('/')
else:
# 表单验证失败,重新渲染表单
return self.render_to_response({
'form': form
})
# 获取登录成功后的重定向URL
def get_success_url(self):
# 从POST数据中获取重定向URL
redirect_to = self.request.POST.get(self.redirect_field_name)
# 检查URL是否安全
if not url_has_allowed_host_and_scheme(
url=redirect_to, allowed_hosts=[
self.request.get_host()]):
@ -196,91 +142,63 @@ class LoginView(FormView):
return redirect_to
# 账户操作结果页面视图函数
def account_result(request):
# 获取操作类型和用户ID
type = request.GET.get('type')
id = request.GET.get('id')
# 获取用户对象如果不存在返回404
user = get_object_or_404(get_user_model(), id=id)
logger.info(type)
# 如果用户已激活,重定向到首页
if user.is_active:
return HttpResponseRedirect('/')
# 处理注册和验证类型
if type and type in ['register', 'validation']:
if type == 'register':
# 注册成功页面内容
content = '''
恭喜您注册成功一封验证邮件已经发送到您的邮箱请验证您的邮箱后登录本站
'''
title = '注册成功'
else:
# 验证邮箱签名
c_sign = get_sha256(get_sha256(settings.SECRET_KEY + str(user.id)))
sign = request.GET.get('sign')
# 签名不匹配返回403禁止访问
if sign != c_sign:
return HttpResponseForbidden()
# 激活用户账户
user.is_active = True
user.save()
content = '''
恭喜您已经成功的完成邮箱验证您现在可以使用您的账号来登录本站
'''
title = '验证成功'
# 渲染结果页面
return render(request, 'account/result.html', {
'title': title,
'content': content
})
else:
# 无效类型,重定向到首页
return HttpResponseRedirect('/')
# 忘记密码视图
class ForgetPasswordView(FormView):
# 指定使用的表单类
form_class = ForgetPasswordForm
# 指定模板文件
template_name = 'account/forget_password.html'
# 表单验证通过后的处理
def form_valid(self, form):
if form.is_valid():
# 根据邮箱获取用户
blog_user = BlogUser.objects.filter(email=form.cleaned_data.get("email")).get()
# 设置新密码(自动哈希)
blog_user.password = make_password(form.cleaned_data["new_password2"])
# 保存用户
blog_user.save()
# 重定向到登录页面
return HttpResponseRedirect('/login/')
else:
# 表单验证失败,重新渲染表单
return self.render_to_response({'form': form})
# 忘记密码验证码发送视图
class ForgetPasswordEmailCode(View):
# 处理POST请求
def post(self, request: HttpRequest):
# 初始化表单
form = ForgetPasswordCodeForm(request.POST)
# 表单验证
if not form.is_valid():
return HttpResponse("错误的邮箱")
# 获取邮箱地址
to_email = form.cleaned_data["email"]
# 生成验证码
code = generate_code()
# 发送验证邮件
utils.send_verify_email(to_email, code)
# 保存验证码到缓存
utils.set_code(to_email, code)
return HttpResponse("ok")
return HttpResponse("ok")

@ -17,22 +17,18 @@ class ArticleForm(forms.ModelForm):
fields = '__all__'
#xjh管理员动作函数 - 发布选中的文章
def makr_article_publish(modeladmin, request, queryset):
queryset.update(status='p')
#xjh管理员动作函数 - 将选中的文章设为草稿
def draft_article(modeladmin, request, queryset):
queryset.update(status='d')
#xjh管理员动作函数 - 关闭文章评论
def close_article_commentstatus(modeladmin, request, queryset):
queryset.update(comment_status='c')
#xjh管理员动作函数 - 打开文章评论
def open_article_commentstatus(modeladmin, request, queryset):
queryset.update(comment_status='o')
@ -44,7 +40,6 @@ open_article_commentstatus.short_description = _('Open article comments')
class ArticlelAdmin(admin.ModelAdmin):
"""xjh文章模型的后台管理配置"""
list_per_page = 20
search_fields = ('body', 'title')
form = ArticleForm
@ -70,7 +65,6 @@ class ArticlelAdmin(admin.ModelAdmin):
open_article_commentstatus]
def link_to_category(self, obj):
"""xjh在文章列表显示分类链接"""
info = (obj.category._meta.app_label, obj.category._meta.model_name)
link = reverse('admin:%s_%s_change' % info, args=(obj.category.id,))
return format_html(u'<a href="%s">%s</a>' % (link, obj.category.name))
@ -78,18 +72,15 @@ class ArticlelAdmin(admin.ModelAdmin):
link_to_category.short_description = _('category')
def get_form(self, request, obj=None, **kwargs):
"""xjh限制作者字段只能选择超级用户"""
form = super(ArticlelAdmin, self).get_form(request, obj, **kwargs)
form.base_fields['author'].queryset = get_user_model(
).objects.filter(is_superuser=True)
return form
def save_model(self, request, obj, form, change):
"""xjh保存文章模型"""
super(ArticlelAdmin, self).save_model(request, obj, form, change)
def get_view_on_site_url(self, obj=None):
"""xjh获取文章在前台的URL"""
if obj:
url = obj.get_full_url()
return url
@ -100,27 +91,22 @@ class ArticlelAdmin(admin.ModelAdmin):
class TagAdmin(admin.ModelAdmin):
"""xjh标签模型的后台管理配置"""
exclude = ('slug', 'last_mod_time', 'creation_time')
class CategoryAdmin(admin.ModelAdmin):
"""xjh分类模型的后台管理配置"""
list_display = ('name', 'parent_category', 'index')
exclude = ('slug', 'last_mod_time', 'creation_time')
class LinksAdmin(admin.ModelAdmin):
"""xjh友情链接模型的后台管理配置"""
exclude = ('last_mod_time', 'creation_time')
class SideBarAdmin(admin.ModelAdmin):
"""xjh侧边栏模型的后台管理配置"""
list_display = ('name', 'content', 'is_enable', 'sequence')
exclude = ('last_mod_time', 'creation_time')
class BlogSettingsAdmin(admin.ModelAdmin):
"""xjh博客设置模型的后台管理配置"""
pass
pass

@ -2,5 +2,4 @@ from django.apps import AppConfig
class BlogConfig(AppConfig):
"""xjh博客应用的配置类"""
name = 'blog'
name = 'blog'

@ -9,7 +9,6 @@ logger = logging.getLogger(__name__)
def seo_processor(requests):
"""xjhSEO上下文处理器为模板提供SEO相关变量"""
key = 'seo_processor'
value = cache.get(key)
if value:
@ -17,7 +16,6 @@ def seo_processor(requests):
else:
logger.info('set processor cache.')
setting = get_blog_setting()
#xjh构建包含SEO和网站配置的上下文字典
value = {
'SITE_NAME': setting.site_name,
'SHOW_GOOGLE_ADSENSE': setting.show_google_adsense,
@ -41,6 +39,5 @@ def seo_processor(requests):
"GLOBAL_FOOTER": setting.global_footer,
"COMMENT_NEED_REVIEW": setting.comment_need_review,
}
#xjh缓存上下文数据10小时
cache.set(key, value, 60 * 60 * 10)
return value
return value

@ -7,11 +7,9 @@ from elasticsearch_dsl.connections import connections
from blog.models import Article
#xjh检查是否启用了Elasticsearch
ELASTICSEARCH_ENABLED = hasattr(settings, 'ELASTICSEARCH_DSL')
if ELASTICSEARCH_ENABLED:
#xjh创建Elasticsearch连接
connections.create_connection(
hosts=[settings.ELASTICSEARCH_DSL['default']['hosts']])
from elasticsearch import Elasticsearch
@ -23,7 +21,6 @@ if ELASTICSEARCH_ENABLED:
try:
c.get_pipeline('geoip')
except elasticsearch.exceptions.NotFoundError:
#xjh创建geoip处理管道用于IP地址地理位置解析
c.put_pipeline('geoip', body='''{
"description" : "Add geoip info",
"processors" : [
@ -37,7 +34,6 @@ if ELASTICSEARCH_ENABLED:
class GeoIp(InnerDoc):
"""xjhIP地理位置信息文档结构"""
continent_name = Keyword()
country_iso_code = Keyword()
country_name = Keyword()
@ -45,25 +41,21 @@ class GeoIp(InnerDoc):
class UserAgentBrowser(InnerDoc):
"""xjh用户代理浏览器信息文档结构"""
Family = Keyword()
Version = Keyword()
class UserAgentOS(UserAgentBrowser):
"""xjh用户代理操作系统信息文档结构"""
pass
class UserAgentDevice(InnerDoc):
"""xjh用户代理设备信息文档结构"""
Family = Keyword()
Brand = Keyword()
Model = Keyword()
class UserAgent(InnerDoc):
"""xjh完整的用户代理信息文档结构"""
browser = Object(UserAgentBrowser, required=False)
os = Object(UserAgentOS, required=False)
device = Object(UserAgentDevice, required=False)
@ -72,7 +64,6 @@ class UserAgent(InnerDoc):
class ElapsedTimeDocument(Document):
"""xjh页面加载耗时记录文档"""
url = Keyword()
time_taken = Long()
log_datetime = Date()
@ -92,10 +83,8 @@ class ElapsedTimeDocument(Document):
class ElaspedTimeDocumentManager:
"""xjh耗时文档管理器"""
@staticmethod
def build_index():
"""xjh构建性能索引"""
from elasticsearch import Elasticsearch
client = Elasticsearch(settings.ELASTICSEARCH_DSL['default']['hosts'])
res = client.indices.exists(index="performance")
@ -104,14 +93,12 @@ class ElaspedTimeDocumentManager:
@staticmethod
def delete_index():
"""xjh删除性能索引"""
from elasticsearch import Elasticsearch
es = Elasticsearch(settings.ELASTICSEARCH_DSL['default']['hosts'])
es.indices.delete(index='performance', ignore=[400, 404])
@staticmethod
def create(url, time_taken, log_datetime, useragent, ip):
"""xjh创建耗时记录"""
ElaspedTimeDocumentManager.build_index()
ua = UserAgent()
ua.browser = UserAgentBrowser()
@ -129,7 +116,6 @@ class ElaspedTimeDocumentManager:
ua.string = useragent.ua_string
ua.is_bot = useragent.is_bot
#xjh使用当前时间戳作为文档ID
doc = ElapsedTimeDocument(
meta={
'id': int(
@ -141,11 +127,10 @@ class ElaspedTimeDocumentManager:
time_taken=time_taken,
log_datetime=log_datetime,
useragent=ua, ip=ip)
doc.save(pipeline="geoip") #xjh保存文档并通过geoip管道处理
doc.save(pipeline="geoip")
class ArticleDocument(Document):
"""xjh文章搜索文档结构"""
body = Text(analyzer='ik_max_word', search_analyzer='ik_smart')
title = Text(analyzer='ik_max_word', search_analyzer='ik_smart')
author = Object(properties={
@ -180,23 +165,19 @@ class ArticleDocument(Document):
class ArticleDocumentManager():
"""xjh文章文档管理器"""
def __init__(self):
self.create_index()
def create_index(self):
"""xjh创建文章索引"""
ArticleDocument.init()
def delete_index(self):
"""xjh删除文章索引"""
from elasticsearch import Elasticsearch
es = Elasticsearch(settings.ELASTICSEARCH_DSL['default']['hosts'])
es.indices.delete(index='blog', ignore=[400, 404])
def convert_to_doc(self, articles):
"""xjh将文章对象转换为搜索文档"""
return [
ArticleDocument(
meta={
@ -221,7 +202,6 @@ class ArticleDocumentManager():
article_order=article.article_order) for article in articles]
def rebuild(self, articles=None):
"""xjh重建文章索引"""
ArticleDocument.init()
articles = articles if articles else Article.objects.all()
docs = self.convert_to_doc(articles)
@ -229,6 +209,5 @@ class ArticleDocumentManager():
doc.save()
def update_docs(self, docs):
"""xjh更新文档"""
for doc in docs:
doc.save()
doc.save()

@ -7,15 +7,13 @@ logger = logging.getLogger(__name__)
class BlogSearchForm(SearchForm):
"""xjh博客搜索表单"""
querydata = forms.CharField(required=True)
def search(self):
"""xjh执行搜索"""
datas = super(BlogSearchForm, self).search()
if not self.is_valid():
return self.no_query_found()
if self.cleaned_data['querydata']:
logger.info(self.cleaned_data['querydata'])
return datas
return datas

@ -10,16 +10,14 @@ logger = logging.getLogger(__name__)
class OnlineMiddleware(object):
"""xjh在线中间件用于记录页面加载时间和用户访问信息"""
def __init__(self, get_response=None):
self.get_response = get_response
super().__init__()
def __call__(self, request):
''' xjh页面渲染时间统计 '''
''' page render time '''
start_time = time.time()
response = self.get_response(request)
#xjh获取用户IP和User-Agent信息
http_user_agent = request.META.get('HTTP_USER_AGENT', '')
ip, _ = get_client_ip(request)
user_agent = parse(http_user_agent)
@ -27,7 +25,6 @@ class OnlineMiddleware(object):
try:
cast_time = time.time() - start_time
if ELASTICSEARCH_ENABLED:
#xjh记录页面加载耗时到Elasticsearch
time_taken = round((cast_time) * 1000, 2)
url = request.path
from django.utils import timezone
@ -37,10 +34,9 @@ class OnlineMiddleware(object):
log_datetime=timezone.now(),
useragent=user_agent,
ip=ip)
#xjh在页面内容中替换加载时间占位符
response.content = response.content.replace(
b'<!!LOAD_TIMES!!>', str.encode(str(cast_time)[:5]))
except Exception as e:
logger.error("Error OnlineMiddleware: %s" % e)
return response
return response

@ -18,7 +18,6 @@ logger = logging.getLogger(__name__)
class LinkShowType(models.TextChoices):
"""xjh链接显示类型选择"""
I = ('i', _('index'))
L = ('l', _('list'))
P = ('p', _('post'))
@ -27,20 +26,17 @@ class LinkShowType(models.TextChoices):
class BaseModel(models.Model):
"""xjh基础模型类包含公共字段"""
id = models.AutoField(primary_key=True)
creation_time = models.DateTimeField(_('creation time'), default=now)
last_modify_time = models.DateTimeField(_('modify time'), default=now)
def save(self, *args, **kwargs):
#xjh特殊处理文章浏览量更新
is_update_views = isinstance(
self,
Article) and 'update_fields' in kwargs and kwargs['update_fields'] == ['views']
if is_update_views:
Article.objects.filter(pk=self.pk).update(views=self.views)
else:
#xjh自动生成slug字段
if 'slug' in self.__dict__:
slug = getattr(
self, 'title') if 'title' in self.__dict__ else getattr(
@ -49,7 +45,6 @@ class BaseModel(models.Model):
super().save(*args, **kwargs)
def get_full_url(self):
"""xjh获取完整URL"""
site = get_current_site().domain
url = "https://{site}{path}".format(site=site,
path=self.get_absolute_url())
@ -60,12 +55,11 @@ class BaseModel(models.Model):
@abstractmethod
def get_absolute_url(self):
"""xjh抽象方法获取绝对URL"""
pass
class Article(BaseModel):
"""xjh文章模型"""
"""文章"""
STATUS_CHOICES = (
('d', _('Draft')),
('p', _('Published')),
@ -124,7 +118,6 @@ class Article(BaseModel):
get_latest_by = 'id'
def get_absolute_url(self):
"""xjh获取文章绝对URL"""
return reverse('blog:detailbyid', kwargs={
'article_id': self.id,
'year': self.creation_time.year,
@ -134,7 +127,6 @@ class Article(BaseModel):
@cache_decorator(60 * 60 * 10)
def get_category_tree(self):
"""xjh获取分类树"""
tree = self.category.get_category_tree()
names = list(map(lambda c: (c.name, c.get_absolute_url()), tree))
@ -144,12 +136,10 @@ class Article(BaseModel):
super().save(*args, **kwargs)
def viewed(self):
"""xjh增加文章浏览量"""
self.views += 1
self.save(update_fields=['views'])
def comment_list(self):
"""xjh获取文章评论列表"""
cache_key = 'article_comments_{id}'.format(id=self.id)
value = cache.get(cache_key)
if value:
@ -162,24 +152,23 @@ class Article(BaseModel):
return comments
def get_admin_url(self):
"""xjh获取文章管理后台URL"""
info = (self._meta.app_label, self._meta.model_name)
return reverse('admin:%s_%s_change' % info, args=(self.pk,))
@cache_decorator(expiration=60 * 100)
def next_article(self):
"""xjh获取下一篇文章"""
# 下一篇
return Article.objects.filter(
id__gt=self.id, status='p').order_by('id').first()
@cache_decorator(expiration=60 * 100)
def prev_article(self):
"""xjh获取上一篇文章"""
# 前一篇
return Article.objects.filter(id__lt=self.id, status='p').first()
def get_first_image_url(self):
"""
xjh从文章内容中提取第一张图片URL
Get the first image url from article.body.
:return:
"""
match = re.search(r'!\[.*?\]\((.+?)\)', self.body)
@ -189,7 +178,7 @@ class Article(BaseModel):
class Category(BaseModel):
"""xjh文章分类模型"""
"""文章分类"""
name = models.CharField(_('category name'), max_length=30, unique=True)
parent_category = models.ForeignKey(
'self',
@ -216,7 +205,7 @@ class Category(BaseModel):
@cache_decorator(60 * 60 * 10)
def get_category_tree(self):
"""
xjh递归获得分类目录的父级
递归获得分类目录的父级
:return:
"""
categorys = []
@ -232,7 +221,7 @@ class Category(BaseModel):
@cache_decorator(60 * 60 * 10)
def get_sub_categorys(self):
"""
xjh获得当前分类目录所有子集
获得当前分类目录所有子集
:return:
"""
categorys = []
@ -252,7 +241,7 @@ class Category(BaseModel):
class Tag(BaseModel):
"""xjh文章标签模型"""
"""文章标签"""
name = models.CharField(_('tag name'), max_length=30, unique=True)
slug = models.SlugField(default='no-slug', max_length=60, blank=True)
@ -264,7 +253,6 @@ class Tag(BaseModel):
@cache_decorator(60 * 60 * 10)
def get_article_count(self):
"""xjh获取标签下的文章数量"""
return Article.objects.filter(tags__name=self.name).distinct().count()
class Meta:
@ -274,7 +262,7 @@ class Tag(BaseModel):
class Links(models.Model):
"""xjh友情链接模型"""
"""友情链接"""
name = models.CharField(_('link name'), max_length=30, unique=True)
link = models.URLField(_('link'))
@ -299,7 +287,7 @@ class Links(models.Model):
class SideBar(models.Model):
"""xjh侧边栏模型可以展示一些html内容"""
"""侧边栏,可以展示一些html内容"""
name = models.CharField(_('title'), max_length=100)
content = models.TextField(_('content'))
sequence = models.IntegerField(_('order'), unique=True)
@ -317,7 +305,7 @@ class SideBar(models.Model):
class BlogSettings(models.Model):
"""xjh博客设置模型"""
"""blog的配置"""
site_name = models.CharField(
_('site name'),
max_length=200,
@ -379,12 +367,10 @@ class BlogSettings(models.Model):
return self.site_name
def clean(self):
"""xjh验证只能有一个博客配置实例"""
if BlogSettings.objects.exclude(id=self.id).count():
raise ValidationError(_('There can only be one configuration'))
def save(self, *args, **kwargs):
"""xjh保存时清除缓存"""
super().save(*args, **kwargs)
from djangoblog.utils import cache
cache.clear()
cache.clear()

@ -4,13 +4,10 @@ from blog.models import Article
class ArticleIndex(indexes.SearchIndex, indexes.Indexable):
"""xjh文章搜索索引配置"""
text = indexes.CharField(document=True, use_template=True)
def get_model(self):
"""xjh指定搜索模型"""
return Article
def index_queryset(self, using=None):
"""xjh指定索引查询集只索引已发布的文章"""
return self.get_model().objects.filter(status='p')
return self.get_model().objects.filter(status='p')

@ -20,13 +20,11 @@ from oauth.models import OAuthUser, OAuthConfig
# Create your tests here.
class ArticleTest(TestCase):
"""xjh文章相关测试用例"""
def setUp(self):
self.client = Client()
self.factory = RequestFactory()
def test_validate_article(self):
"""xjh测试文章相关功能"""
site = get_current_site().domain
user = BlogUser.objects.get_or_create(
email="liangliangyy@gmail.com",
@ -151,7 +149,6 @@ class ArticleTest(TestCase):
self.client.get('/admin/admin/logentry/1/change/')
def check_pagination(self, p, type, value):
"""xjh检查分页功能"""
for page in range(1, p.num_pages + 1):
s = load_pagination_info(p.page(page), type, value)
self.assertIsNotNone(s)
@ -163,7 +160,6 @@ class ArticleTest(TestCase):
self.assertEqual(response.status_code, 200)
def test_image(self):
"""xjh测试图片上传和处理功能"""
import requests
rsp = requests.get(
'https://www.python.org/static/img/python-logo.png')
@ -187,12 +183,10 @@ class ArticleTest(TestCase):
'https://www.python.org/static/img/python-logo.png')
def test_errorpage(self):
"""xjh测试错误页面"""
rsp = self.client.get('/eee')
self.assertEqual(rsp.status_code, 404)
def test_commands(self):
"""xjh测试管理命令"""
user = BlogUser.objects.get_or_create(
email="liangliangyy@gmail.com",
username="liangliangyy")[0]
@ -235,4 +229,4 @@ class ArticleTest(TestCase):
call_command("create_testdata")
call_command("clear_cache")
call_command("sync_user_avatar")
call_command("build_search_words")
call_command("build_search_words")

@ -5,71 +5,58 @@ from . import views
app_name = "blog"
urlpatterns = [
#xjh首页
path(
r'',
views.IndexView.as_view(),
name='index'),
#xjh首页分页
path(
r'page/<int:page>/',
views.IndexView.as_view(),
name='index_page'),
#xjh文章详情页
path(
r'article/<int:year>/<int:month>/<int:day>/<int:article_id>.html',
views.ArticleDetailView.as_view(),
name='detailbyid'),
#xjh分类目录页
path(
r'category/<slug:category_name>.html',
views.CategoryDetailView.as_view(),
name='category_detail'),
#xjh分类目录分页
path(
r'category/<slug:category_name>/<int:page>.html',
views.CategoryDetailView.as_view(),
name='category_detail_page'),
#xjh作者文章页
path(
r'author/<author_name>.html',
views.AuthorDetailView.as_view(),
name='author_detail'),
#xjh作者文章分页
path(
r'author/<author_name>/<int:page>.html',
views.AuthorDetailView.as_view(),
name='author_detail_page'),
#xjh标签文章页
path(
r'tag/<slug:tag_name>.html',
views.TagDetailView.as_view(),
name='tag_detail'),
#xjh标签文章分页
path(
r'tag/<slug:tag_name>/<int:page>.html',
views.TagDetailView.as_view(),
name='tag_detail_page'),
#xjh文章归档页缓存1小时
path(
'archives.html',
cache_page(
60 * 60)(
views.ArchivesView.as_view()),
name='archives'),
#xjh友情链接页
path(
'links.html',
views.LinkListView.as_view(),
name='links'),
#xjh文件上传接口
path(
r'upload',
views.fileupload,
name='upload'),
#xjh清理缓存接口
path(
r'clean',
views.clean_cache_view,
name='clean'),
]
]

@ -25,7 +25,6 @@ logger = logging.getLogger(__name__)
class ArticleListView(ListView):
"""xjh文章列表视图基类"""
# template_name属性用于指定使用哪个模板进行渲染
template_name = 'blog/article_index.html'
@ -62,7 +61,7 @@ class ArticleListView(ListView):
def get_queryset_from_cache(self, cache_key):
'''
xjh缓存页面数据
缓存页面数据
:param cache_key: 缓存key
:return:
'''
@ -78,7 +77,7 @@ class ArticleListView(ListView):
def get_queryset(self):
'''
xjh重写默认从缓存获取数据
重写默认从缓存获取数据
:return:
'''
key = self.get_queryset_cache_key()
@ -92,7 +91,7 @@ class ArticleListView(ListView):
class IndexView(ArticleListView):
'''
xjh首页视图
首页
'''
# 友情链接类型
link_type = LinkShowType.I
@ -108,7 +107,7 @@ class IndexView(ArticleListView):
class ArticleDetailView(DetailView):
'''
xjh文章详情页面视图
文章详情页面
'''
template_name = 'blog/article_detail.html'
model = Article
@ -164,7 +163,7 @@ class ArticleDetailView(DetailView):
class CategoryDetailView(ArticleListView):
'''
xjh分类目录列表视图
分类目录列表
'''
page_type = "分类目录归档"
@ -203,7 +202,7 @@ class CategoryDetailView(ArticleListView):
class AuthorDetailView(ArticleListView):
'''
xjh作者详情页视图
作者详情页
'''
page_type = '作者文章归档'
@ -229,7 +228,7 @@ class AuthorDetailView(ArticleListView):
class TagDetailView(ArticleListView):
'''
xjh标签列表页面视图
标签列表页面
'''
page_type = '分类标签归档'
@ -261,7 +260,7 @@ class TagDetailView(ArticleListView):
class ArchivesView(ArticleListView):
'''
xjh文章归档页面视图
文章归档页面
'''
page_type = '文章归档'
paginate_by = None
@ -277,7 +276,6 @@ class ArchivesView(ArticleListView):
class LinkListView(ListView):
"""xjh友情链接列表视图"""
model = Links
template_name = 'blog/links_list.html'
@ -286,7 +284,6 @@ class LinkListView(ListView):
class EsSearchView(SearchView):
"""xjhElasticsearch搜索视图"""
def get_context(self):
paginator, page = self.build_page()
context = {
@ -306,7 +303,7 @@ class EsSearchView(SearchView):
@csrf_exempt
def fileupload(request):
"""
xjh文件上传接口该方法需自己写调用端来上传图片该方法仅提供图床功能
该方法需自己写调用端来上传图片该方法仅提供图床功能
:param request:
:return:
"""
@ -347,7 +344,6 @@ def page_not_found_view(
request,
exception,
template_name='blog/error_page.html'):
"""xjh404页面处理视图"""
if exception:
logger.error(exception)
url = request.get_full_path()
@ -359,7 +355,6 @@ def page_not_found_view(
def server_error_view(request, template_name='blog/error_page.html'):
"""xjh500页面处理视图"""
return render(request,
template_name,
{'message': _('Sorry, the server is busy, please click the home page to see other?'),
@ -371,7 +366,6 @@ def permission_denied_view(
request,
exception,
template_name='blog/error_page.html'):
"""xjh403页面处理视图"""
if exception:
logger.error(exception)
return render(
@ -381,6 +375,5 @@ def permission_denied_view(
def clean_cache_view(request):
"""xjh清理缓存视图"""
cache.clear()
return HttpResponse('ok')
return HttpResponse('ok')

@ -5,15 +5,15 @@ from django.utils.translation import gettext_lazy as _
def disable_commentstatus(modeladmin, request, queryset):
queryset.update(is_enable=False) # 杨智鑫:批量设置评论为禁用状态
queryset.update(is_enable=False)
def enable_commentstatus(modeladmin, request, queryset):
queryset.update(is_enable=True) # 杨智鑫:批量设置评论为启用状态
queryset.update(is_enable=True)
disable_commentstatus.short_description = _('Disable comments') # 杨智鑫:批量禁用评论
enable_commentstatus.short_description = _('Enable comments') # 杨智鑫:批量启用评论
disable_commentstatus.short_description = _('Disable comments')
enable_commentstatus.short_description = _('Enable comments')
class CommentAdmin(admin.ModelAdmin):
@ -24,24 +24,24 @@ class CommentAdmin(admin.ModelAdmin):
'link_to_userinfo',
'link_to_article',
'is_enable',
'creation_time') # 杨智鑫:显示
list_display_links = ('id', 'body', 'is_enable') # 杨智鑫:可点击
list_filter = ('is_enable',) # 杨智鑫:过滤
exclude = ('creation_time', 'last_modify_time') # 杨智鑫:不显示创建时间
actions = [disable_commentstatus, enable_commentstatus] # 杨智鑫:批量操作
'creation_time')
list_display_links = ('id', 'body', 'is_enable')
list_filter = ('is_enable',)
exclude = ('creation_time', 'last_modify_time')
actions = [disable_commentstatus, enable_commentstatus]
def link_to_userinfo(self, obj):
info = (obj.author._meta.app_label, obj.author._meta.model_name) # 杨智鑫:获取用户信息
link = reverse('admin:%s_%s_change' % info, args=(obj.author.id,)) # 杨智鑫:获取用户信息
info = (obj.author._meta.app_label, obj.author._meta.model_name)
link = reverse('admin:%s_%s_change' % info, args=(obj.author.id,))
return format_html(
u'<a href="%s">%s</a>' %
(link, obj.author.nickname if obj.author.nickname else obj.author.email)) # 杨智鑫:获取用户信息
(link, obj.author.nickname if obj.author.nickname else obj.author.email))
def link_to_article(self, obj):
info = (obj.article._meta.app_label, obj.article._meta.model_name)
link = reverse('admin:%s_%s_change' % info, args=(obj.article.id,)) # 杨智鑫:获取文章信息
link = reverse('admin:%s_%s_change' % info, args=(obj.article.id,))
return format_html(
u'<a href="%s">%s</a>' % (link, obj.article.title)) # 杨智鑫:获取文章信息
u'<a href="%s">%s</a>' % (link, obj.article.title))
link_to_userinfo.short_description = _('User') # 杨智鑫:用户
link_to_article.short_description = _('Article') # 杨智鑫:文章
link_to_userinfo.short_description = _('User')
link_to_article.short_description = _('Article')

@ -2,4 +2,4 @@ from django.apps import AppConfig
class CommentsConfig(AppConfig):
name = 'comments' # 杨智鑫:应用名称
name = 'comments'

@ -6,8 +6,8 @@ from .models import Comment
class CommentForm(ModelForm):
parent_comment_id = forms.IntegerField(
widget=forms.HiddenInput, required=False) # 杨智鑫隐藏字段用于处理回复评论的父评论ID
widget=forms.HiddenInput, required=False)
class Meta:
model = Comment # 杨智鑫:指定表单关联的模型
fields = ['body'] # 杨智鑫:表单只包含评论内容字段
model = Comment
fields = ['body']

@ -9,31 +9,31 @@ from blog.models import Article
# Create your models here.
class Comment(models.Model):
body = models.TextField('正文', max_length=300) # 杨智鑫评论正文最大长度300字符
creation_time = models.DateTimeField(_('creation time'), default=now) # 杨智鑫:评论创建时间
last_modify_time = models.DateTimeField(_('last modify time'), default=now) # 杨智鑫:最后修改时间
body = models.TextField('正文', max_length=300)
creation_time = models.DateTimeField(_('creation time'), default=now)
last_modify_time = models.DateTimeField(_('last modify time'), default=now)
author = models.ForeignKey(
settings.AUTH_USER_MODEL,
verbose_name=_('author'),
on_delete=models.CASCADE) # 杨智鑫:关联用户模型,删除用户时级联删除评论
on_delete=models.CASCADE)
article = models.ForeignKey(
Article,
verbose_name=_('article'),
on_delete=models.CASCADE) # 杨智鑫:关联文章模型,删除文章时级联删除评论
on_delete=models.CASCADE)
parent_comment = models.ForeignKey(
'self',
verbose_name=_('parent comment'),
blank=True,
null=True,
on_delete=models.CASCADE) # 杨智鑫:关联父级评论,默认为空
on_delete=models.CASCADE)
is_enable = models.BooleanField(_('enable'),
default=False, blank=False, null=False) # 杨智鑫:是否启用
default=False, blank=False, null=False)
class Meta:
ordering = ['-id'] # 杨智鑫:排序
verbose_name = _('comment') # 杨智鑫:模型名称
verbose_name_plural = verbose_name # 杨智鑫:模型名称复数
get_latest_by = 'id' # 杨智鑫:获取最新
ordering = ['-id']
verbose_name = _('comment')
verbose_name_plural = verbose_name
get_latest_by = 'id'
def __str__(self):
return self.body # 杨智鑫:返回评论内容
return self.body

@ -4,7 +4,7 @@ from django.urls import reverse
from accounts.models import BlogUser
from blog.models import Category, Article
from comments.models import Comment
from comments.templatetags.comments_tags import * # 杨智鑫:导入所有模板标签函数
from comments.templatetags.comments_tags import *
from djangoblog.utils import get_max_articleid_commentid
@ -12,39 +12,39 @@ from djangoblog.utils import get_max_articleid_commentid
class CommentsTest(TransactionTestCase):
def setUp(self):
self.client = Client() # 杨智鑫创建一个client
self.factory = RequestFactory() # 杨智鑫创建一个factory
self.client = Client()
self.factory = RequestFactory()
from blog.models import BlogSettings
value = BlogSettings() # 杨智鑫创建一个BlogSettings对象
value.comment_need_review = True # 杨智鑫:设置评论是否需要审核
value.save() # 杨智鑫:保存
value = BlogSettings()
value.comment_need_review = True
value.save()
self.user = BlogUser.objects.create_superuser(
email="liangliangyy1@gmail.com",
username="liangliangyy1",
password="liangliangyy1") # 杨智鑫:创建一个用户
password="liangliangyy1")
def update_article_comment_status(self, article):
comments = article.comment_set.all()
for comment in comments: # 杨智鑫:遍历所有评论
comment.is_enable = True # 杨智鑫:设置评论是否可用
comment.save() # 杨智鑫:保存
for comment in comments:
comment.is_enable = True
comment.save()
def test_validate_comment(self):
self.client.login(username='liangliangyy1', password='liangliangyy1') # 杨智鑫:登录
self.client.login(username='liangliangyy1', password='liangliangyy1')
category = Category() # 杨智鑫:创建一个分类
category.name = "categoryccc" # 杨智鑫:创建一个分类
category = Category()
category.name = "categoryccc"
category.save()
article = Article() # 杨智鑫:创建一个文章
article.title = "nicetitleccc" # 杨智鑫:创建一个文章
article.body = "nicecontentccc" # 杨智鑫:创建一个文章
article = Article()
article.title = "nicetitleccc"
article.body = "nicecontentccc"
article.author = self.user
article.category = category
article.type = 'a'
article.status = 'p'
article.save() # 杨智鑫:保存
article.save()
comment_url = reverse(
'comments:postcomment', kwargs={
@ -55,25 +55,25 @@ class CommentsTest(TransactionTestCase):
'body': '123ffffffffff'
})
self.assertEqual(response.status_code, 302) # 杨智鑫:判断返回状态码
self.assertEqual(response.status_code, 302)
article = Article.objects.get(pk=article.pk)
self.assertEqual(len(article.comment_list()), 0) #杨智鑫:判断评论数量
self.assertEqual(len(article.comment_list()), 0)
self.update_article_comment_status(article)
self.assertEqual(len(article.comment_list()), 1) #杨智鑫:判断评论数量
self.assertEqual(len(article.comment_list()), 1)
response = self.client.post(comment_url,
{
'body': '123ffffffffff',
}) # 杨智鑫:提交数据
})
self.assertEqual(response.status_code, 302) # 杨智鑫:判断返回状态码
self.assertEqual(response.status_code, 302)
article = Article.objects.get(pk=article.pk) # 杨智鑫:获取文章
self.update_article_comment_status(article) # 杨智鑫:更新文章评论状态
self.assertEqual(len(article.comment_list()), 2) #杨智鑫:判断评论数量
parent_comment_id = article.comment_list()[0].id #杨智鑫获取父评论id
article = Article.objects.get(pk=article.pk)
self.update_article_comment_status(article)
self.assertEqual(len(article.comment_list()), 2)
parent_comment_id = article.comment_list()[0].id
response = self.client.post(comment_url,
{
@ -91,19 +91,19 @@ class CommentsTest(TransactionTestCase):
''',
'parent_comment_id': parent_comment_id
}) # 杨智鑫:提交数据
self.assertEqual(response.status_code, 302) # 杨智鑫:判断返回状态码
self.update_article_comment_status(article) # 杨智鑫:更新文章评论状态
article = Article.objects.get(pk=article.pk) # 杨智鑫:获取文章
self.assertEqual(len(article.comment_list()), 3) # 杨智鑫:判断评论数量
comment = Comment.objects.get(id=parent_comment_id) # 杨智鑫:获取父评论
tree = parse_commenttree(article.comment_list(), comment) # 杨智鑫:获取子评论
self.assertEqual(len(tree), 1) # 杨智鑫:判断子评论数量
data = show_comment_item(comment, True) # 杨智鑫:获取评论项
self.assertIsNotNone(data) # 杨智鑫:判断数据是否为空
s = get_max_articleid_commentid() # 杨智鑫获取最大文章id和评论id
self.assertIsNotNone(s) # 杨智鑫:判断数据是否为空
})
self.assertEqual(response.status_code, 302)
self.update_article_comment_status(article)
article = Article.objects.get(pk=article.pk)
self.assertEqual(len(article.comment_list()), 3)
comment = Comment.objects.get(id=parent_comment_id)
tree = parse_commenttree(article.comment_list(), comment)
self.assertEqual(len(tree), 1)
data = show_comment_item(comment, True)
self.assertIsNotNone(data)
s = get_max_articleid_commentid()
self.assertIsNotNone(s)
from comments.utils import send_comment_email
send_comment_email(comment) # 杨智鑫:发送邮件
send_comment_email(comment)

@ -2,10 +2,10 @@ from django.urls import path
from . import views
app_name = "comments" # 杨智鑫:定义应用命名空间
app_name = "comments"
urlpatterns = [
path(
'article/<int:article_id>/postcomment',
views.CommentPostView.as_view(), # 杨智鑫:定义路由
name='postcomment'), # 杨智鑫:定义路由名称
views.CommentPostView.as_view(),
name='postcomment'),
]

@ -5,12 +5,12 @@ from django.utils.translation import gettext_lazy as _
from djangoblog.utils import get_current_site
from djangoblog.utils import send_email
logger = logging.getLogger(__name__) # 杨智鑫:获取当前模块的日志器
logger = logging.getLogger(__name__)
def send_comment_email(comment):
site = get_current_site().domain # 杨智鑫:获取当前站点
subject = _('Thanks for your comment') # 杨智鑫获取当前语言的Thanks for your comment
site = get_current_site().domain
subject = _('Thanks for your comment')
article_url = f"https://{site}{comment.article.get_absolute_url()}"
html_content = _("""<p>Thank you very much for your comments on this site</p>
You can visit <a href="%(article_url)s" rel="bookmark">%(article_title)s</a>
@ -18,9 +18,9 @@ def send_comment_email(comment):
Thank you again!
<br />
If the link above cannot be opened, please copy this link to your browser.
%(article_url)s""") % {'article_url': article_url, 'article_title': comment.article.title} # 杨智鑫:获取当前语言
tomail = comment.author.email # 杨智鑫:获取评论者的邮箱
send_email([tomail], subject, html_content) # 杨智鑫:发送邮件
%(article_url)s""") % {'article_url': article_url, 'article_title': comment.article.title}
tomail = comment.author.email
send_email([tomail], subject, html_content)
try:
if comment.parent_comment:
html_content = _("""Your comment on <a href="%(article_url)s" rel="bookmark">%(article_title)s</a><br/> has
@ -33,6 +33,6 @@ def send_comment_email(comment):
""") % {'article_url': article_url, 'article_title': comment.article.title,
'comment_body': comment.parent_comment.body}
tomail = comment.parent_comment.author.email
send_email([tomail], subject, html_content) # 杨智鑫:发送邮件
send_email([tomail], subject, html_content)
except Exception as e:
logger.error(e) # 杨智鑫:记录错误
logger.error(e)

@ -13,51 +13,51 @@ from .models import Comment
class CommentPostView(FormView):
form_class = CommentForm # 杨智鑫:指定使用的表单类
template_name = 'blog/article_detail.html' # 杨智鑫:指定使用的模板
form_class = CommentForm
template_name = 'blog/article_detail.html'
@method_decorator(csrf_protect)
def dispatch(self, *args, **kwargs): # 杨智鑫添加csrf_protect装饰器
return super(CommentPostView, self).dispatch(*args, **kwargs) # 杨智鑫调用父类的dispatch方法
def dispatch(self, *args, **kwargs):
return super(CommentPostView, self).dispatch(*args, **kwargs)
def get(self, request, *args, **kwargs):
article_id = self.kwargs['article_id'] # 杨智鑫获取文章id
article = get_object_or_404(Article, pk=article_id) # 杨智鑫:获取文章对象
url = article.get_absolute_url() # 杨智鑫获取文章的url
return HttpResponseRedirect(url + "#comments") # 杨智鑫:跳转到文章的评论区
article_id = self.kwargs['article_id']
article = get_object_or_404(Article, pk=article_id)
url = article.get_absolute_url()
return HttpResponseRedirect(url + "#comments")
def form_invalid(self, form):
article_id = self.kwargs['article_id'] # 杨智鑫获取文章id
article = get_object_or_404(Article, pk=article_id) # 杨智鑫:获取文章对象
article_id = self.kwargs['article_id']
article = get_object_or_404(Article, pk=article_id)
return self.render_to_response({
'form': form,
'article': article
}) # 杨智鑫:返回错误信息
})
def form_valid(self, form):
"""提交的数据验证合法后的逻辑"""
user = self.request.user # 杨智鑫:获取用户
author = BlogUser.objects.get(pk=user.pk) # 杨智鑫:获取用户对象
article_id = self.kwargs['article_id'] # 杨智鑫获取文章id
article = get_object_or_404(Article, pk=article_id) # 杨智鑫:获取文章对象
user = self.request.user
author = BlogUser.objects.get(pk=user.pk)
article_id = self.kwargs['article_id']
article = get_object_or_404(Article, pk=article_id)
if article.comment_status == 'c' or article.status == 'c':
raise ValidationError("该文章评论已关闭.") # 杨智鑫:抛出异常
comment = form.save(False) # 杨智鑫:保存评论
comment.article = article # 杨智鑫:设置评论所属文章
raise ValidationError("该文章评论已关闭.")
comment = form.save(False)
comment.article = article
from djangoblog.utils import get_blog_setting
settings = get_blog_setting() # 杨智鑫:获取博客设置
settings = get_blog_setting()
if not settings.comment_need_review:
comment.is_enable = True
comment.author = author # 杨智鑫:设置评论作者
comment.author = author
if form.cleaned_data['parent_comment_id']: # 杨智鑫:判断是否有父级评论
if form.cleaned_data['parent_comment_id']:
parent_comment = Comment.objects.get(
pk=form.cleaned_data['parent_comment_id']) # 杨智鑫:获取父级评论对象
comment.parent_comment = parent_comment # 杨智鑫:设置父级评论
pk=form.cleaned_data['parent_comment_id'])
comment.parent_comment = parent_comment
comment.save(True) # 杨智鑫:保存评论
comment.save(True)
return HttpResponseRedirect(
"%s#div-comment-%d" %
(article.get_absolute_url(), comment.pk)) # 杨智鑫:跳转到评论区
(article.get_absolute_url(), comment.pk))

@ -18,7 +18,6 @@ from servermanager.models import *
class DjangoBlogAdminSite(AdminSite):
#mj 自定义Django管理站点
site_header = 'djangoblog administration'
site_title = 'djangoblog site admin'
@ -26,7 +25,6 @@ class DjangoBlogAdminSite(AdminSite):
super().__init__(name)
def has_permission(self, request):
#mj 只有超级用户才能访问管理站点
return request.user.is_superuser
# def get_urls(self):
@ -42,7 +40,6 @@ class DjangoBlogAdminSite(AdminSite):
admin_site = DjangoBlogAdminSite(name='admin')
#mj 注册所有模型到自定义管理站点
admin_site.register(Article, ArticlelAdmin)
admin_site.register(Category, CategoryAdmin)
admin_site.register(Tag, TagAdmin)
@ -65,4 +62,3 @@ admin_site.register(OwnTrackLog, OwnTrackLogsAdmin)
admin_site.register(Site, SiteAdmin)
admin_site.register(LogEntry, LogEntryAdmin)
# [file content end]

@ -6,8 +6,6 @@ class DjangoblogAppConfig(AppConfig):
def ready(self):
super().ready()
#mj 应用启动时加载插件
# Import and load plugins here
from .plugin_manage.loader import load_plugins
load_plugins()
# [file content end]
load_plugins()

@ -18,7 +18,6 @@ from oauth.models import OAuthUser
logger = logging.getLogger(__name__)
#mj 定义自定义信号
oauth_user_login_signal = django.dispatch.Signal(['id'])
send_email_signal = django.dispatch.Signal(
['emailto', 'title', 'content'])
@ -26,7 +25,6 @@ send_email_signal = django.dispatch.Signal(
@receiver(send_email_signal)
def send_email_signal_handler(sender, **kwargs):
#mj 发送邮件信号处理器
emailto = kwargs['emailto']
title = kwargs['title']
content = kwargs['content']
@ -55,7 +53,6 @@ def send_email_signal_handler(sender, **kwargs):
@receiver(oauth_user_login_signal)
def oauth_user_login_signal_handler(sender, **kwargs):
#mj OAuth用户登录信号处理器
id = kwargs['id']
oauthuser = OAuthUser.objects.get(id=id)
site = get_current_site().domain
@ -76,7 +73,6 @@ def model_post_save_callback(
using,
update_fields,
**kwargs):
#mj 模型保存后的回调函数
clearcache = False
if isinstance(instance, LogEntry):
return
@ -92,7 +88,6 @@ def model_post_save_callback(
clearcache = True
if isinstance(instance, Comment):
#mj 处理评论保存逻辑
if instance.is_enable:
path = instance.article.get_absolute_url()
site = get_current_site().domain
@ -121,9 +116,7 @@ def model_post_save_callback(
@receiver(user_logged_in)
@receiver(user_logged_out)
def user_auth_callback(sender, request, user, **kwargs):
#mj 用户登录/登出回调函数
if user and user.username:
logger.info(user)
delete_sidebar_cache()
# cache.clear()
# [file content end]

@ -12,7 +12,6 @@ logger = logging.getLogger(__name__)
class ElasticSearchBackend(BaseSearchBackend):
#mj ElasticSearch搜索后端实现
def __init__(self, connection_alias, **connection_options):
super(
ElasticSearchBackend,
@ -43,6 +42,7 @@ class ElasticSearchBackend(BaseSearchBackend):
self.manager.update_docs(docs)
def update(self, index, iterable, commit=True):
models = self._get_models(iterable)
self.manager.update_docs(models)
@ -56,6 +56,7 @@ class ElasticSearchBackend(BaseSearchBackend):
@staticmethod
def get_suggestion(query: str) -> str:
"""获取推荐词, 如果没有找到添加原搜索词"""
search = ArticleDocument.search() \
.query("match", body=query) \
.suggest('suggest_search', query, term={'field': 'body'}) \
@ -72,7 +73,6 @@ class ElasticSearchBackend(BaseSearchBackend):
@log_query
def search(self, query_string, **kwargs):
#mj 执行搜索
logger.info('search query_string:' + query_string)
start_offset = kwargs.get('start_offset')
@ -123,7 +123,6 @@ class ElasticSearchBackend(BaseSearchBackend):
class ElasticSearchQuery(BaseSearchQuery):
#mj ElasticSearch查询类
def _convert_datetime(self, date):
if hasattr(date, 'hour'):
return force_str(date.strftime('%Y%m%d%H%M%S'))
@ -171,7 +170,7 @@ class ElasticSearchQuery(BaseSearchQuery):
class ElasticSearchModelSearchForm(ModelSearchForm):
#mj ElasticSearch模型搜索表单
def search(self):
# 是否建议搜索
self.searchqueryset.query.backend.is_suggest = self.data.get("is_suggest") != "no"
@ -180,7 +179,5 @@ class ElasticSearchModelSearchForm(ModelSearchForm):
class ElasticSearchEngine(BaseEngine):
#mj ElasticSearch引擎
backend = ElasticSearchBackend
query = ElasticSearchQuery
# [file content end]

@ -8,7 +8,6 @@ from djangoblog.utils import CommonMarkdown
class DjangoBlogFeed(Feed):
#mj RSS订阅源
feed_type = Rss201rev2Feed
description = '大巧无工,重剑无锋.'
@ -22,14 +21,12 @@ class DjangoBlogFeed(Feed):
return get_user_model().objects.first().get_absolute_url()
def items(self):
#mj 返回最新的5篇文章
return Article.objects.filter(type='a', status='p').order_by('-pub_time')[:5]
def item_title(self, item):
return item.title
def item_description(self, item):
#mj 将Markdown内容转换为HTML
return CommonMarkdown.get_markdown(item.body)
def feed_copyright(self):
@ -41,4 +38,3 @@ class DjangoBlogFeed(Feed):
def item_guid(self, item):
return
# [file content end]

@ -9,7 +9,6 @@ from django.utils.translation import gettext_lazy as _
class LogEntryAdmin(admin.ModelAdmin):
#mj 日志条目管理
list_filter = [
'content_type'
]
@ -44,7 +43,6 @@ class LogEntryAdmin(admin.ModelAdmin):
return False
def object_link(self, obj):
#mj 生成对象链接
object_link = escape(obj.object_repr)
content_type = obj.content_type
@ -65,7 +63,6 @@ class LogEntryAdmin(admin.ModelAdmin):
object_link.short_description = _('object')
def user_link(self, obj):
#mj 生成用户链接
content_type = ContentType.objects.get_for_model(type(obj.user))
user_link = escape(force_str(obj.user))
try:
@ -92,4 +89,3 @@ class LogEntryAdmin(admin.ModelAdmin):
if 'delete_selected' in actions:
del actions['delete_selected']
return actions
# [file content end]

@ -1,3 +1,14 @@
"""
Django settings for djangoblog project.
Generated by 'django-admin startproject' using Django 1.10.2.
For more information on this file, see
https://docs.djangoproject.com/en/1.10/topics/settings/
For the full list of settings and their values, see
https://docs.djangoproject.com/en/1.10/ref/settings/
"""
import os
import sys
from pathlib import Path
@ -6,7 +17,6 @@ from django.utils.translation import gettext_lazy as _
def env_to_bool(env, default):
#mj 环境变量转换为布尔值
str_val = os.environ.get(env)
return default if str_val is None else str_val == 'True'
@ -329,4 +339,3 @@ ACTIVE_PLUGINS = [
'view_count',
'seo_optimizer'
]
# [file content end]

@ -1,5 +1,3 @@
f# [file name]: sitemap.py
# [file content begin]
from django.contrib.sitemaps import Sitemap
from django.urls import reverse
@ -7,7 +5,6 @@ from blog.models import Article, Category, Tag
class StaticViewSitemap(Sitemap):
#mj 静态视图站点地图
priority = 0.5
changefreq = 'daily'
@ -19,7 +16,6 @@ class StaticViewSitemap(Sitemap):
class ArticleSiteMap(Sitemap):
#mj 文章站点地图
changefreq = "monthly"
priority = "0.6"
@ -31,7 +27,6 @@ class ArticleSiteMap(Sitemap):
class CategorySiteMap(Sitemap):
#mj 分类站点地图
changefreq = "Weekly"
priority = "0.6"
@ -43,7 +38,6 @@ class CategorySiteMap(Sitemap):
class TagSiteMap(Sitemap):
#mj 标签站点地图
changefreq = "Weekly"
priority = "0.3"
@ -55,7 +49,6 @@ class TagSiteMap(Sitemap):
class UserSiteMap(Sitemap):
#mj 用户站点地图
changefreq = "Weekly"
priority = "0.3"
@ -64,4 +57,3 @@ class UserSiteMap(Sitemap):
def lastmod(self, obj):
return obj.date_joined
# [file content end]

@ -1,5 +1,3 @@
# [file name]: spider_notify.py
# [file content begin]
import logging
import requests
@ -9,11 +7,15 @@ logger = logging.getLogger(__name__)
class SpiderNotify():
#mj 搜索引擎爬虫通
@staticmethod
def baidu_notify(urls):
try:
data = '\n'.join(urls)
result = requests.post(settings.BAIDU_NOTIFY_URL, data=data)
logger.info(result.text)
except Exception as e:
except Exception as e:
logger.error(e)
@staticmethod
def notify(url):
SpiderNotify.baidu_notify(url)

@ -1,16 +1,13 @@
from django.test import TestCase
from djangoblog.utils import *
class DjangoBlogTest(TestCase):
#mj Django博客测试类
def setUp(self):
pass
def test_utils(self):
#mj 测试工具函数
md5 = get_sha256('test')
self.assertIsNotNone(md5)
c = CommonMarkdown.get_markdown('''
@ -33,4 +30,3 @@ class DjangoBlogTest(TestCase):
}
data = parse_dict_to_url(d)
self.assertIsNotNone(data)
# [file content end]

@ -1,3 +1,18 @@
"""djangoblog URL Configuration
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/1.10/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: url(r'^$', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.conf.urls import url, include
2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls'))
"""
from django.conf import settings
from django.conf.urls.i18n import i18n_patterns
from django.conf.urls.static import static
@ -12,7 +27,6 @@ from djangoblog.elasticsearch_backend import ElasticSearchModelSearchForm
from djangoblog.feeds import DjangoBlogFeed
from djangoblog.sitemap import ArticleSiteMap, CategorySiteMap, StaticViewSitemap, TagSiteMap, UserSiteMap
#mj 站点地图配置
sitemaps = {
'blog': ArticleSiteMap,
@ -22,7 +36,6 @@ sitemaps = {
'static': StaticViewSitemap
}
#mj 错误处理视图
handler404 = 'blog.views.page_not_found_view'
handler500 = 'blog.views.server_error_view'
handle403 = 'blog.views.permission_denied_view'
@ -30,7 +43,6 @@ handle403 = 'blog.views.permission_denied_view'
urlpatterns = [
path('i18n/', include('django.conf.urls.i18n')),
]
#mj 国际化URL模式
urlpatterns += i18n_patterns(
re_path(r'^admin/', admin_site.urls),
re_path(r'', include('blog.urls', namespace='blog')),
@ -47,8 +59,6 @@ urlpatterns += i18n_patterns(
re_path(r'', include('servermanager.urls', namespace='servermanager')),
re_path(r'', include('owntracks.urls', namespace='owntracks'))
, prefix_default_language=False) + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT)
#mj 调试模式下添加媒体文件服务
if settings.DEBUG:
urlpatterns += static(settings.MEDIA_URL,
document_root=settings.MEDIA_ROOT)
# [file content end]

@ -1,3 +1,7 @@
#!/usr/bin/env python
# encoding: utf-8
import logging
import os
import random
@ -17,20 +21,17 @@ logger = logging.getLogger(__name__)
def get_max_articleid_commentid():
#mj 获取最大文章ID和评论ID
from blog.models import Article
from comments.models import Comment
return (Article.objects.latest().pk, Comment.objects.latest().pk)
def get_sha256(str):
#mj 计算字符串的SHA256哈希值
m = sha256(str.encode('utf-8'))
return m.hexdigest()
def cache_decorator(expiration=3 * 60):
#mj 缓存装饰器
def wrapper(func):
def news(*args, **kwargs):
try:
@ -93,13 +94,11 @@ def expire_view_cache(path, servername, serverport, key_prefix=None):
@cache_decorator()
def get_current_site():
#mj 获取当前站点(带缓存)
site = Site.objects.get_current()
return site
class CommonMarkdown:
#mj Markdown处理工具类
@staticmethod
def _convert_markdown(value):
md = markdown.Markdown(
@ -116,19 +115,16 @@ class CommonMarkdown:
@staticmethod
def get_markdown_with_toc(value):
#mj 获取带目录的Markdown
body, toc = CommonMarkdown._convert_markdown(value)
return body, toc
@staticmethod
def get_markdown(value):
#mj 获取Markdown内容
body, toc = CommonMarkdown._convert_markdown(value)
return body
def send_email(emailto, title, content):
#mj 发送邮件
from djangoblog.blog_signals import send_email_signal
send_email_signal.send(
send_email.__class__,
@ -143,7 +139,6 @@ def generate_code() -> str:
def parse_dict_to_url(dict):
#mj 将字典转换为URL查询字符串
from urllib.parse import quote
url = '&'.join(['{}={}'.format(quote(k, safe='/'), quote(v, safe='/'))
for k, v in dict.items()])
@ -151,14 +146,12 @@ def parse_dict_to_url(dict):
def get_blog_setting():
#mj 获取博客设置(带缓存)
value = cache.get('get_blog_setting')
if value:
return value
else:
from blog.models import BlogSettings
if not BlogSettings.objects.count():
#mj 如果不存在设置,创建默认设置
setting = BlogSettings()
setting.site_name = 'djangoblog'
setting.site_description = '基于Django的博客系统'
@ -209,7 +202,6 @@ def save_user_avatar(url):
def delete_sidebar_cache():
#mj 删除侧边栏缓存
from blog.models import LinkShowType
keys = ["sidebar" + x for x in LinkShowType.values]
for k in keys:
@ -218,14 +210,12 @@ def delete_sidebar_cache():
def delete_view_cache(prefix, keys):
#mj 删除视图缓存
from django.core.cache.utils import make_template_fragment_key
key = make_template_fragment_key(prefix, keys)
cache.delete(key)
def get_resource_url():
#mj 获取资源URL
if settings.STATIC_URL:
return settings.STATIC_URL
else:
@ -233,13 +223,10 @@ def get_resource_url():
return 'http://' + site.domain + '/static/'
#mj 允许的HTML标签和属性用于HTML清理
ALLOWED_TAGS = ['a', 'abbr', 'acronym', 'b', 'blockquote', 'code', 'em', 'i', 'li', 'ol', 'pre', 'strong', 'ul', 'h1',
'h2', 'p']
ALLOWED_ATTRIBUTES = {'a': ['href', 'title'], 'abbr': ['title'], 'acronym': ['title']}
def sanitize_html(html):
#mj 清理HTML防止XSS攻击
return bleach.clean(html, tags=ALLOWED_TAGS, attributes=ALLOWED_ATTRIBUTES)
# [file content end]

@ -1,3 +1,5 @@
# encoding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals
import json
@ -61,7 +63,6 @@ class WhooshHtmlFormatter(HtmlFormatter):
class WhooshSearchBackend(BaseSearchBackend):
#mj Whoosh搜索后端
# Word reserved by Whoosh for special use.
RESERVED_WORDS = (
'AND',
@ -144,7 +145,6 @@ class WhooshSearchBackend(BaseSearchBackend):
self.setup_complete = True
def build_schema(self, fields):
#mj 构建Whoosh索引模式
schema_fields = {
ID: WHOOSH_ID(stored=True, unique=True),
DJANGO_CT: WHOOSH_ID(stored=True),
@ -185,7 +185,6 @@ class WhooshSearchBackend(BaseSearchBackend):
field_boost=field_class.boost)
else:
# schema_fields[field_class.index_fieldname] = TEXT(stored=True, analyzer=StemmingAnalyzer(), field_boost=field_class.boost, sortable=True)
#mj 使用中文分析器处理文本字段
schema_fields[field_class.index_fieldname] = TEXT(
stored=True, analyzer=ChineseAnalyzer(), field_boost=field_class.boost, sortable=True)
if field_class.document is True:
@ -201,7 +200,6 @@ class WhooshSearchBackend(BaseSearchBackend):
return (content_field_name, Schema(**schema_fields))
def update(self, index, iterable, commit=True):
#mj 更新索引
if not self.setup_complete:
self.setup()
@ -247,7 +245,6 @@ class WhooshSearchBackend(BaseSearchBackend):
writer.commit()
def remove(self, obj_or_string, commit=True):
#mj 从索引中移除文档
if not self.setup_complete:
self.setup()
@ -270,7 +267,6 @@ class WhooshSearchBackend(BaseSearchBackend):
exc_info=True)
def clear(self, models=None, commit=True):
#mj 清空索引
if not self.setup_complete:
self.setup()
@ -308,7 +304,6 @@ class WhooshSearchBackend(BaseSearchBackend):
"Failed to clear Whoosh index: %s", e, exc_info=True)
def delete_index(self):
#mj 删除索引文件
# Per the Whoosh mailing list, if wiping out everything from the index,
# it's much more efficient to simply delete the index files.
if self.use_file_storage and os.path.exists(self.path):
@ -320,7 +315,6 @@ class WhooshSearchBackend(BaseSearchBackend):
self.setup()
def optimize(self):
#mj 优化索引
if not self.setup_complete:
self.setup()
@ -328,7 +322,6 @@ class WhooshSearchBackend(BaseSearchBackend):
self.index.optimize()
def calculate_page(self, start_offset=0, end_offset=None):
#mj 计算分页参数
# Prevent against Whoosh throwing an error. Requires an end_offset
# greater than 0.
if end_offset is not None and end_offset <= 0:
@ -373,7 +366,6 @@ class WhooshSearchBackend(BaseSearchBackend):
limit_to_registered_models=None,
result_class=None,
**kwargs):
#mj 执行搜索
if not self.setup_complete:
self.setup()
@ -578,7 +570,6 @@ class WhooshSearchBackend(BaseSearchBackend):
limit_to_registered_models=None,
result_class=None,
**kwargs):
#mj 查找相似文档
if not self.setup_complete:
self.setup()
@ -691,7 +682,6 @@ class WhooshSearchBackend(BaseSearchBackend):
query_string='',
spelling_query=None,
result_class=None):
#mj 处理搜索结果
from haystack import connections
results = []
@ -709,4 +699,346 @@ class WhooshSearchBackend(BaseSearchBackend):
for doc_offset, raw_result in enumerate(raw_page):
score = raw_page.score(doc_offset) or 0
app_label, model_name = raw_result[DJANGO_CT].
app_label, model_name = raw_result[DJANGO_CT].split('.')
additional_fields = {}
model = haystack_get_model(app_label, model_name)
if model and model in indexed_models:
for key, value in raw_result.items():
index = unified_index.get_index(model)
string_key = str(key)
if string_key in index.fields and hasattr(
index.fields[string_key], 'convert'):
# Special-cased due to the nature of KEYWORD fields.
if index.fields[string_key].is_multivalued:
if value is None or len(value) == 0:
additional_fields[string_key] = []
else:
additional_fields[string_key] = value.split(
',')
else:
additional_fields[string_key] = index.fields[string_key].convert(
value)
else:
additional_fields[string_key] = self._to_python(value)
del (additional_fields[DJANGO_CT])
del (additional_fields[DJANGO_ID])
if highlight:
sa = StemmingAnalyzer()
formatter = WhooshHtmlFormatter('em')
terms = [token.text for token in sa(query_string)]
whoosh_result = whoosh_highlight(
additional_fields.get(self.content_field_name),
terms,
sa,
ContextFragmenter(),
formatter
)
additional_fields['highlighted'] = {
self.content_field_name: [whoosh_result],
}
result = result_class(
app_label,
model_name,
raw_result[DJANGO_ID],
score,
**additional_fields)
results.append(result)
else:
hits -= 1
if self.include_spelling:
if spelling_query:
spelling_suggestion = self.create_spelling_suggestion(
spelling_query)
else:
spelling_suggestion = self.create_spelling_suggestion(
query_string)
return {
'results': results,
'hits': hits,
'facets': facets,
'spelling_suggestion': spelling_suggestion,
}
def create_spelling_suggestion(self, query_string):
spelling_suggestion = None
reader = self.index.reader()
corrector = reader.corrector(self.content_field_name)
cleaned_query = force_str(query_string)
if not query_string:
return spelling_suggestion
# Clean the string.
for rev_word in self.RESERVED_WORDS:
cleaned_query = cleaned_query.replace(rev_word, '')
for rev_char in self.RESERVED_CHARACTERS:
cleaned_query = cleaned_query.replace(rev_char, '')
# Break it down.
query_words = cleaned_query.split()
suggested_words = []
for word in query_words:
suggestions = corrector.suggest(word, limit=1)
if len(suggestions) > 0:
suggested_words.append(suggestions[0])
spelling_suggestion = ' '.join(suggested_words)
return spelling_suggestion
def _from_python(self, value):
"""
Converts Python values to a string for Whoosh.
Code courtesy of pysolr.
"""
if hasattr(value, 'strftime'):
if not hasattr(value, 'hour'):
value = datetime(value.year, value.month, value.day, 0, 0, 0)
elif isinstance(value, bool):
if value:
value = 'true'
else:
value = 'false'
elif isinstance(value, (list, tuple)):
value = u','.join([force_str(v) for v in value])
elif isinstance(value, (six.integer_types, float)):
# Leave it alone.
pass
else:
value = force_str(value)
return value
def _to_python(self, value):
"""
Converts values from Whoosh to native Python values.
A port of the same method in pysolr, as they deal with data the same way.
"""
if value == 'true':
return True
elif value == 'false':
return False
if value and isinstance(value, six.string_types):
possible_datetime = DATETIME_REGEX.search(value)
if possible_datetime:
date_values = possible_datetime.groupdict()
for dk, dv in date_values.items():
date_values[dk] = int(dv)
return datetime(
date_values['year'],
date_values['month'],
date_values['day'],
date_values['hour'],
date_values['minute'],
date_values['second'])
try:
# Attempt to use json to load the values.
converted_value = json.loads(value)
# Try to handle most built-in types.
if isinstance(
converted_value,
(list,
tuple,
set,
dict,
six.integer_types,
float,
complex)):
return converted_value
except BaseException:
# If it fails (SyntaxError or its ilk) or we don't trust it,
# continue on.
pass
return value
class WhooshSearchQuery(BaseSearchQuery):
def _convert_datetime(self, date):
if hasattr(date, 'hour'):
return force_str(date.strftime('%Y%m%d%H%M%S'))
else:
return force_str(date.strftime('%Y%m%d000000'))
def clean(self, query_fragment):
"""
Provides a mechanism for sanitizing user input before presenting the
value to the backend.
Whoosh 1.X differs here in that you can no longer use a backslash
to escape reserved characters. Instead, the whole word should be
quoted.
"""
words = query_fragment.split()
cleaned_words = []
for word in words:
if word in self.backend.RESERVED_WORDS:
word = word.replace(word, word.lower())
for char in self.backend.RESERVED_CHARACTERS:
if char in word:
word = "'%s'" % word
break
cleaned_words.append(word)
return ' '.join(cleaned_words)
def build_query_fragment(self, field, filter_type, value):
from haystack import connections
query_frag = ''
is_datetime = False
if not hasattr(value, 'input_type_name'):
# Handle when we've got a ``ValuesListQuerySet``...
if hasattr(value, 'values_list'):
value = list(value)
if hasattr(value, 'strftime'):
is_datetime = True
if isinstance(value, six.string_types) and value != ' ':
# It's not an ``InputType``. Assume ``Clean``.
value = Clean(value)
else:
value = PythonData(value)
# Prepare the query using the InputType.
prepared_value = value.prepare(self)
if not isinstance(prepared_value, (set, list, tuple)):
# Then convert whatever we get back to what pysolr wants if needed.
prepared_value = self.backend._from_python(prepared_value)
# 'content' is a special reserved word, much like 'pk' in
# Django's ORM layer. It indicates 'no special field'.
if field == 'content':
index_fieldname = ''
else:
index_fieldname = u'%s:' % connections[self._using].get_unified_index(
).get_index_fieldname(field)
filter_types = {
'content': '%s',
'contains': '*%s*',
'endswith': "*%s",
'startswith': "%s*",
'exact': '%s',
'gt': "{%s to}",
'gte': "[%s to]",
'lt': "{to %s}",
'lte': "[to %s]",
'fuzzy': u'%s~',
}
if value.post_process is False:
query_frag = prepared_value
else:
if filter_type in [
'content',
'contains',
'startswith',
'endswith',
'fuzzy']:
if value.input_type_name == 'exact':
query_frag = prepared_value
else:
# Iterate over terms & incorportate the converted form of
# each into the query.
terms = []
if isinstance(prepared_value, six.string_types):
possible_values = prepared_value.split(' ')
else:
if is_datetime is True:
prepared_value = self._convert_datetime(
prepared_value)
possible_values = [prepared_value]
for possible_value in possible_values:
terms.append(
filter_types[filter_type] %
self.backend._from_python(possible_value))
if len(terms) == 1:
query_frag = terms[0]
else:
query_frag = u"(%s)" % " AND ".join(terms)
elif filter_type == 'in':
in_options = []
for possible_value in prepared_value:
is_datetime = False
if hasattr(possible_value, 'strftime'):
is_datetime = True
pv = self.backend._from_python(possible_value)
if is_datetime is True:
pv = self._convert_datetime(pv)
if isinstance(pv, six.string_types) and not is_datetime:
in_options.append('"%s"' % pv)
else:
in_options.append('%s' % pv)
query_frag = "(%s)" % " OR ".join(in_options)
elif filter_type == 'range':
start = self.backend._from_python(prepared_value[0])
end = self.backend._from_python(prepared_value[1])
if hasattr(prepared_value[0], 'strftime'):
start = self._convert_datetime(start)
if hasattr(prepared_value[1], 'strftime'):
end = self._convert_datetime(end)
query_frag = u"[%s to %s]" % (start, end)
elif filter_type == 'exact':
if value.input_type_name == 'exact':
query_frag = prepared_value
else:
prepared_value = Exact(prepared_value).prepare(self)
query_frag = filter_types[filter_type] % prepared_value
else:
if is_datetime is True:
prepared_value = self._convert_datetime(prepared_value)
query_frag = filter_types[filter_type] % prepared_value
if len(query_frag) and not isinstance(value, Raw):
if not query_frag.startswith('(') and not query_frag.endswith(')'):
query_frag = "(%s)" % query_frag
return u"%s%s" % (index_fieldname, query_frag)
# if not filter_type in ('in', 'range'):
# # 'in' is a bit of a special case, as we don't want to
# # convert a valid list/tuple to string. Defer handling it
# # until later...
# value = self.backend._from_python(value)
class WhooshEngine(BaseEngine):
backend = WhooshSearchBackend
query = WhooshSearchQuery

@ -9,32 +9,29 @@ logger = logging.getLogger(__name__)
class OAuthUserAdmin(admin.ModelAdmin):
search_fields = ('nickname', 'email') # zzh: 配置搜索字段,支持按昵称和邮箱搜索
list_per_page = 20 # zzh: 设置列表页每页显示20条记录
search_fields = ('nickname', 'email')
list_per_page = 20
list_display = (
'id',
'nickname',
'link_to_usermodel', # zzh: 自定义字段,显示关联用户模型的链接
'show_user_image', # zzh: 自定义字段,显示用户头像
'link_to_usermodel',
'show_user_image',
'type',
'email',
)
list_display_links = ('id', 'nickname') # zzh: 设置哪些字段可以作为链接点击进入编辑页
list_filter = ('author', 'type',) # zzh: 配置右侧筛选器,支持按作者和类型筛选
readonly_fields = [] # zzh: 初始化只读字段列表
list_display_links = ('id', 'nickname')
list_filter = ('author', 'type',)
readonly_fields = []
def get_readonly_fields(self, request, obj=None):
# zzh: 动态设置所有字段为只读防止在admin中修改OAuth用户数据
return list(self.readonly_fields) + \
[field.name for field in obj._meta.fields] + \
[field.name for field in obj._meta.many_to_many]
def has_add_permission(self, request):
# zzh: 禁用添加权限OAuth用户应该通过认证流程自动创建
return False
def link_to_usermodel(self, obj):
# zzh: 自定义方法,生成指向关联用户编辑页面的链接
if obj.author:
info = (obj.author._meta.app_label, obj.author._meta.model_name)
link = reverse('admin:%s_%s_change' % info, args=(obj.author.id,))
@ -43,16 +40,15 @@ class OAuthUserAdmin(admin.ModelAdmin):
(link, obj.author.nickname if obj.author.nickname else obj.author.email))
def show_user_image(self, obj):
# zzh: 自定义方法在admin中显示用户头像图片
img = obj.picture
return format_html(
u'<img src="%s" style="width:50px;height:50px"></img>' %
(img))
link_to_usermodel.short_description = '用户' # zzh: 设置自定义字段在admin中的显示名称
show_user_image.short_description = '用户头像' # zzh: 设置自定义字段在admin中的显示名称
link_to_usermodel.short_description = '用户'
show_user_image.short_description = '用户头像'
class OAuthConfigAdmin(admin.ModelAdmin):
list_display = ('type', 'appkey', 'appsecret', 'is_enable') # zzh: 配置OAuth配置项的显示字段
list_filter = ('type',) # zzh: 支持按OAuth类型进行筛选
list_display = ('type', 'appkey', 'appsecret', 'is_enable')
list_filter = ('type',)

@ -2,4 +2,4 @@ from django.apps import AppConfig
class OauthConfig(AppConfig):
name = 'oauth' # zzh: 定义Django应用的完整Python路径这是AppConfig必须设置的属性
name = 'oauth'

@ -3,14 +3,10 @@ from django.forms import widgets
class RequireEmailForm(forms.Form):
# zzh: 定义邮箱字段,标签为'电子邮箱',且为必填字段
email = forms.EmailField(label='电子邮箱', required=True)
# zzh: 定义隐藏的oauthid字段用于在表单中传递但不显示给用户的OAuth身份ID
oauthid = forms.IntegerField(widget=forms.HiddenInput, required=False)
def __init__(self, *args, **kwargs):
# zzh: 初始化表单,调用父类的构造方法
super(RequireEmailForm, self).__init__(*args, **kwargs)
# zzh: 自定义邮箱字段的widget添加placeholder提示和CSS类名
self.fields['email'].widget = widgets.EmailInput(
attrs={'placeholder': "email", "class": "form-control"})
attrs={'placeholder': "email", "class": "form-control"})

@ -7,47 +7,32 @@ from django.utils.translation import gettext_lazy as _
class OAuthUser(models.Model):
# zzh: 关联到系统用户模型允许为空表示OAuth用户可能未绑定本地用户
author = models.ForeignKey(
settings.AUTH_USER_MODEL,
verbose_name=_('author'),
blank=True,
null=True,
on_delete=models.CASCADE)
# zzh: OAuth服务提供商返回的用户唯一标识
openid = models.CharField(max_length=50)
# zzh: 用户在第三方平台的昵称
nickname = models.CharField(max_length=50, verbose_name=_('nick name'))
# zzh: OAuth访问令牌用于调用第三方API
token = models.CharField(max_length=150, null=True, blank=True)
# zzh: 用户头像的URL地址
picture = models.CharField(max_length=350, blank=True, null=True)
# zzh: OAuth类型如weibo、github等
type = models.CharField(blank=False, null=False, max_length=50)
# zzh: 用户邮箱,可能为空
email = models.CharField(max_length=50, null=True, blank=True)
# zzh: 存储额外的OAuth用户信息通常为JSON格式
metadata = models.TextField(null=True, blank=True)
# zzh: 记录创建时间,自动设置为当前时间
creation_time = models.DateTimeField(_('creation time'), default=now)
# zzh: 记录最后修改时间,自动设置为当前时间
last_modify_time = models.DateTimeField(_('last modify time'), default=now)
def __str__(self):
# zzh: 定义对象的字符串表示,返回用户昵称
return self.nickname
class Meta:
# zzh: 在Django admin中显示的单数名称
verbose_name = _('oauth user')
# zzh: 在Django admin中显示的复数名称
verbose_name_plural = verbose_name
# zzh: 默认按创建时间降序排列
ordering = ['-creation_time']
class OAuthConfig(models.Model):
# zzh: 定义支持的OAuth类型选项
TYPE = (
('weibo', _('weibo')),
('google', _('google')),
@ -55,39 +40,28 @@ class OAuthConfig(models.Model):
('facebook', 'FaceBook'),
('qq', 'QQ'),
)
# zzh: OAuth类型选择字段从预定义类型中选择
type = models.CharField(_('type'), max_length=10, choices=TYPE, default='a')
# zzh: OAuth应用的AppKey/Client ID
appkey = models.CharField(max_length=200, verbose_name='AppKey')
# zzh: OAuth应用的AppSecret/Client Secret
appsecret = models.CharField(max_length=200, verbose_name='AppSecret')
# zzh: OAuth回调地址认证成功后跳转的URL
callback_url = models.CharField(
max_length=200,
verbose_name=_('callback url'),
blank=False,
default='')
# zzh: 标记该OAuth配置是否启用
is_enable = models.BooleanField(
_('is enable'), default=True, blank=False, null=False)
# zzh: 配置创建时间
creation_time = models.DateTimeField(_('creation time'), default=now)
# zzh: 配置最后修改时间
last_modify_time = models.DateTimeField(_('last modify time'), default=now)
def clean(self):
# zzh: 数据验证方法确保同类型的OAuth配置只能有一个
if OAuthConfig.objects.filter(
type=self.type).exclude(id=self.id).count():
raise ValidationError(_(self.type + _('already exists')))
def __str__(self):
# zzh: 返回OAuth类型作为对象的字符串表示
return self.type
class Meta:
# zzh: 在Django admin中显示的中文名称
verbose_name = 'oauth配置'
verbose_name_plural = verbose_name
# zzh: 默认按创建时间降序排列
ordering = ['-creation_time']
ordering = ['-creation_time']

@ -14,72 +14,64 @@ logger = logging.getLogger(__name__)
class OAuthAccessTokenException(Exception):
'''
oauth授权失败异常 zzh: 自定义异常类用于处理OAuth授权过程中的错误
oauth授权失败异常
'''
class BaseOauthManager(metaclass=ABCMeta):
"""获取用户授权""" zzh: OAuth管理器的抽象基类定义了OAuth流程的标准接口
AUTH_URL = None # zzh: 授权页面URL用户在此页面进行OAuth授权
TOKEN_URL = None # zzh: 获取access_token的API地址
API_URL = None # zzh: 获取用户信息的API地址
ICON_NAME = None # zzh: OAuth平台标识名用于区分不同平台
"""获取用户授权"""
AUTH_URL = None
"""获取token"""
TOKEN_URL = None
"""获取用户信息"""
API_URL = None
'''icon图标名'''
ICON_NAME = None
def __init__(self, access_token=None, openid=None):
# zzh: 初始化OAuth管理器可传入已获取的access_token和openid
self.access_token = access_token
self.openid = openid
@property
def is_access_token_set(self):
# zzh: 检查access_token是否已设置
return self.access_token is not None
@property
def is_authorized(self):
# zzh: 检查是否已完成授权既有access_token又有openid
return self.is_access_token_set and self.access_token is not None and self.openid is not None
@abstractmethod
def get_authorization_url(self, nexturl='/'):
# zzh: 抽象方法生成授权页面URL
pass
@abstractmethod
def get_access_token_by_code(self, code):
# zzh: 抽象方法通过授权码获取access_token
pass
@abstractmethod
def get_oauth_userinfo(self):
# zzh: 抽象方法,获取用户信息
pass
@abstractmethod
def get_picture(self, metadata):
# zzh: 抽象方法,从元数据中提取用户头像
pass
def do_get(self, url, params, headers=None):
# zzh: 执行GET请求的通用方法
rsp = requests.get(url=url, params=params, headers=headers)
logger.info(rsp.text)
return rsp.text
def do_post(self, url, params, headers=None):
# zzh: 执行POST请求的通用方法
rsp = requests.post(url, params, headers=headers)
logger.info(rsp.text)
return rsp.text
def get_config(self):
# zzh: 从数据库获取对应平台的OAuth配置
value = OAuthConfig.objects.filter(type=self.ICON_NAME)
return value[0] if value else None
class WBOauthManager(BaseOauthManager):
# zzh: 微博OAuth管理器实现
AUTH_URL = 'https://api.weibo.com/oauth2/authorize'
TOKEN_URL = 'https://api.weibo.com/oauth2/access_token'
API_URL = 'https://api.weibo.com/2/users/show.json'
@ -87,9 +79,9 @@ class WBOauthManager(BaseOauthManager):
def __init__(self, access_token=None, openid=None):
config = self.get_config()
self.client_id = config.appkey if config else '' # zzh: 微博应用AppKey
self.client_secret = config.appsecret if config else '' # zzh: 微博应用AppSecret
self.callback_url = config.callback_url if config else '' # zzh: 回调地址
self.client_id = config.appkey if config else ''
self.client_secret = config.appsecret if config else ''
self.callback_url = config.callback_url if config else ''
super(
WBOauthManager,
self).__init__(
@ -97,7 +89,6 @@ class WBOauthManager(BaseOauthManager):
openid=openid)
def get_authorization_url(self, nexturl='/'):
# zzh: 生成微博授权URL包含next_url参数用于授权后跳转
params = {
'client_id': self.client_id,
'response_type': 'code',
@ -107,7 +98,7 @@ class WBOauthManager(BaseOauthManager):
return url
def get_access_token_by_code(self, code):
# zzh: 使用授权码获取微博access_token
params = {
'client_id': self.client_id,
'client_secret': self.client_secret,
@ -120,7 +111,7 @@ class WBOauthManager(BaseOauthManager):
obj = json.loads(rsp)
if 'access_token' in obj:
self.access_token = str(obj['access_token'])
self.openid = str(obj['uid']) # zzh: 微博返回的是uid字段
self.openid = str(obj['uid'])
return self.get_oauth_userinfo()
else:
raise OAuthAccessTokenException(rsp)
@ -136,9 +127,9 @@ class WBOauthManager(BaseOauthManager):
try:
datas = json.loads(rsp)
user = OAuthUser()
user.metadata = rsp # zzh: 保存原始API响应数据
user.picture = datas['avatar_large'] # zzh: 微博大头像
user.nickname = datas['screen_name'] # zzh: 微博昵称
user.metadata = rsp
user.picture = datas['avatar_large']
user.nickname = datas['screen_name']
user.openid = datas['id']
user.type = 'weibo'
user.token = self.access_token
@ -151,13 +142,11 @@ class WBOauthManager(BaseOauthManager):
return None
def get_picture(self, metadata):
# zzh: 从元数据中提取微博用户头像
datas = json.loads(metadata)
return datas['avatar_large']
class ProxyManagerMixin:
# zzh: 代理混入类为需要代理访问的OAuth平台提供支持
def __init__(self, *args, **kwargs):
if os.environ.get("HTTP_PROXY"):
self.proxies = {
@ -168,20 +157,17 @@ class ProxyManagerMixin:
self.proxies = None
def do_get(self, url, params, headers=None):
# zzh: 重写GET方法支持代理
rsp = requests.get(url=url, params=params, headers=headers, proxies=self.proxies)
logger.info(rsp.text)
return rsp.text
def do_post(self, url, params, headers=None):
# zzh: 重写POST方法支持代理
rsp = requests.post(url, params, headers=headers, proxies=self.proxies)
logger.info(rsp.text)
return rsp.text
class GoogleOauthManager(ProxyManagerMixin, BaseOauthManager):
# zzh: Google OAuth管理器实现继承代理混入类
AUTH_URL = 'https://accounts.google.com/o/oauth2/v2/auth'
TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'
API_URL = 'https://www.googleapis.com/oauth2/v3/userinfo'
@ -199,7 +185,6 @@ class GoogleOauthManager(ProxyManagerMixin, BaseOauthManager):
openid=openid)
def get_authorization_url(self, nexturl='/'):
# zzh: 生成Google授权URLscope包含openid和email
params = {
'client_id': self.client_id,
'response_type': 'code',
@ -224,7 +209,7 @@ class GoogleOauthManager(ProxyManagerMixin, BaseOauthManager):
if 'access_token' in obj:
self.access_token = str(obj['access_token'])
self.openid = str(obj['id_token']) # zzh: Google使用id_token作为openid
self.openid = str(obj['id_token'])
logger.info(self.ICON_NAME + ' oauth ' + rsp)
return self.access_token
else:
@ -242,9 +227,9 @@ class GoogleOauthManager(ProxyManagerMixin, BaseOauthManager):
datas = json.loads(rsp)
user = OAuthUser()
user.metadata = rsp
user.picture = datas['picture'] # zzh: Google用户头像
user.picture = datas['picture']
user.nickname = datas['name']
user.openid = datas['sub'] # zzh: Google用户唯一标识
user.openid = datas['sub']
user.token = self.access_token
user.type = 'google'
if datas['email']:
@ -261,7 +246,6 @@ class GoogleOauthManager(ProxyManagerMixin, BaseOauthManager):
class GitHubOauthManager(ProxyManagerMixin, BaseOauthManager):
# zzh: GitHub OAuth管理器实现
AUTH_URL = 'https://github.com/login/oauth/authorize'
TOKEN_URL = 'https://github.com/login/oauth/access_token'
API_URL = 'https://api.github.com/user'
@ -283,7 +267,7 @@ class GitHubOauthManager(ProxyManagerMixin, BaseOauthManager):
'client_id': self.client_id,
'response_type': 'code',
'redirect_uri': f'{self.callback_url}&next_url={next_url}',
'scope': 'user' # zzh: GitHub授权范围获取用户基本信息
'scope': 'user'
}
url = self.AUTH_URL + "?" + urllib.parse.urlencode(params)
return url
@ -300,7 +284,7 @@ class GitHubOauthManager(ProxyManagerMixin, BaseOauthManager):
rsp = self.do_post(self.TOKEN_URL, params)
from urllib import parse
r = parse.parse_qs(rsp) # zzh: GitHub返回的是查询字符串格式需要解析
r = parse.parse_qs(rsp)
if 'access_token' in r:
self.access_token = (r['access_token'][0])
return self.access_token
@ -308,14 +292,14 @@ class GitHubOauthManager(ProxyManagerMixin, BaseOauthManager):
raise OAuthAccessTokenException(rsp)
def get_oauth_userinfo(self):
# zzh: GitHub获取用户信息需要在header中传递token
rsp = self.do_get(self.API_URL, params={}, headers={
"Authorization": "token " + self.access_token
})
try:
datas = json.loads(rsp)
user = OAuthUser()
user.picture = datas['avatar_url'] # zzh: GitHub头像URL
user.picture = datas['avatar_url']
user.nickname = datas['name']
user.openid = datas['id']
user.type = 'github'
@ -335,7 +319,6 @@ class GitHubOauthManager(ProxyManagerMixin, BaseOauthManager):
class FaceBookOauthManager(ProxyManagerMixin, BaseOauthManager):
# zzh: Facebook OAuth管理器实现
AUTH_URL = 'https://www.facebook.com/v16.0/dialog/oauth'
TOKEN_URL = 'https://graph.facebook.com/v16.0/oauth/access_token'
API_URL = 'https://graph.facebook.com/me'
@ -353,7 +336,6 @@ class FaceBookOauthManager(ProxyManagerMixin, BaseOauthManager):
openid=openid)
def get_authorization_url(self, next_url='/'):
# zzh: Facebook授权范围包含email和public_profile
params = {
'client_id': self.client_id,
'response_type': 'code',
@ -367,7 +349,7 @@ class FaceBookOauthManager(ProxyManagerMixin, BaseOauthManager):
params = {
'client_id': self.client_id,
'client_secret': self.client_secret,
# 'grant_type': 'authorization_code', # zzh: Facebook不需要显式指定grant_type
# 'grant_type': 'authorization_code',
'code': code,
'redirect_uri': self.callback_url
@ -383,7 +365,6 @@ class FaceBookOauthManager(ProxyManagerMixin, BaseOauthManager):
raise OAuthAccessTokenException(rsp)
def get_oauth_userinfo(self):
# zzh: Facebook需要指定fields参数来获取特定字段
params = {
'access_token': self.access_token,
'fields': 'id,name,picture,email'
@ -400,7 +381,6 @@ class FaceBookOauthManager(ProxyManagerMixin, BaseOauthManager):
if 'email' in datas and datas['email']:
user.email = datas['email']
if 'picture' in datas and datas['picture'] and datas['picture']['data'] and datas['picture']['data']['url']:
# zzh: Facebook头像URL嵌套在多层结构中
user.picture = str(datas['picture']['data']['url'])
return user
except Exception as e:
@ -413,11 +393,10 @@ class FaceBookOauthManager(ProxyManagerMixin, BaseOauthManager):
class QQOauthManager(BaseOauthManager):
# zzh: QQ OAuth管理器实现
AUTH_URL = 'https://graph.qq.com/oauth2.0/authorize'
TOKEN_URL = 'https://graph.qq.com/oauth2.0/token'
API_URL = 'https://graph.qq.com/user/get_user_info'
OPEN_ID_URL = 'https://graph.qq.com/oauth2.0/me' # zzh: QQ需要单独获取openid
OPEN_ID_URL = 'https://graph.qq.com/oauth2.0/me'
ICON_NAME = 'qq'
def __init__(self, access_token=None, openid=None):
@ -450,7 +429,7 @@ class QQOauthManager(BaseOauthManager):
}
rsp = self.do_get(self.TOKEN_URL, params)
if rsp:
d = urllib.parse.parse_qs(rsp) # zzh: QQ返回查询字符串格式
d = urllib.parse.parse_qs(rsp)
if 'access_token' in d:
token = d['access_token']
self.access_token = token[0]
@ -459,14 +438,12 @@ class QQOauthManager(BaseOauthManager):
raise OAuthAccessTokenException(rsp)
def get_open_id(self):
# zzh: QQ需要单独调用接口获取openid
if self.is_access_token_set:
params = {
'access_token': self.access_token
}
rsp = self.do_get(self.OPEN_ID_URL, params)
if rsp:
# zzh: QQ返回的是JSONP格式需要清理
rsp = rsp.replace(
'callback(', '').replace(
')', '').replace(
@ -481,7 +458,7 @@ class QQOauthManager(BaseOauthManager):
if openid:
params = {
'access_token': self.access_token,
'oauth_consumer_key': self.client_id, # zzh: QQ需要传递appkey
'oauth_consumer_key': self.client_id,
'openid': self.openid
}
rsp = self.do_get(self.API_URL, params)
@ -496,7 +473,7 @@ class QQOauthManager(BaseOauthManager):
if 'email' in obj:
user.email = obj['email']
if 'figureurl' in obj:
user.picture = str(obj['figureurl']) # zzh: QQ标准头像
user.picture = str(obj['figureurl'])
return user
def get_picture(self, metadata):
@ -506,7 +483,6 @@ class QQOauthManager(BaseOauthManager):
@cache_decorator(expiration=100 * 60)
def get_oauth_apps():
# zzh: 获取所有启用的OAuth应用使用缓存提高性能
configs = OAuthConfig.objects.filter(is_enable=True).all()
if not configs:
return []
@ -517,7 +493,6 @@ def get_oauth_apps():
def get_manager_by_type(type):
# zzh: 根据类型获取对应的OAuth管理器实例
applications = get_oauth_apps()
if applications:
finds = list(
@ -526,4 +501,4 @@ def get_manager_by_type(type):
applications))
if finds:
return finds[0]
return None
return None

@ -14,12 +14,10 @@ from oauth.oauthmanager import BaseOauthManager
# Create your tests here.
class OAuthConfigTest(TestCase):
def setUp(self):
# zzh: 测试类初始化,创建测试客户端和请求工厂
self.client = Client()
self.factory = RequestFactory()
def test_oauth_login_test(self):
# zzh: 测试OAuth配置和登录流程
c = OAuthConfig()
c.type = 'weibo'
c.appkey = 'appkey'
@ -27,23 +25,21 @@ class OAuthConfigTest(TestCase):
c.save()
response = self.client.get('/oauth/oauthlogin?type=weibo')
self.assertEqual(response.status_code, 302) # zzh: 验证重定向状态码
self.assertTrue("api.weibo.com" in response.url) # zzh: 验证重定向到微博授权页面
self.assertEqual(response.status_code, 302)
self.assertTrue("api.weibo.com" in response.url)
response = self.client.get('/oauth/authorize?type=weibo&code=code')
self.assertEqual(response.status_code, 302) # zzh: 验证授权后重定向
self.assertEqual(response.url, '/') # zzh: 验证重定向到首页
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, '/')
class OauthLoginTest(TestCase):
def setUp(self) -> None:
# zzh: 测试类初始化创建测试环境和初始化OAuth应用
self.client = Client()
self.factory = RequestFactory()
self.apps = self.init_apps()
def init_apps(self):
# zzh: 初始化所有OAuth应用配置为每个平台创建测试配置
applications = [p() for p in BaseOauthManager.__subclasses__()]
for application in applications:
c = OAuthConfig()
@ -54,7 +50,6 @@ class OauthLoginTest(TestCase):
return applications
def get_app_by_type(self, type):
# zzh: 根据类型获取对应的OAuth应用实例
for app in self.apps:
if app.ICON_NAME.lower() == type:
return app
@ -62,15 +57,12 @@ class OauthLoginTest(TestCase):
@patch("oauth.oauthmanager.WBOauthManager.do_post")
@patch("oauth.oauthmanager.WBOauthManager.do_get")
def test_weibo_login(self, mock_do_get, mock_do_post):
# zzh: 测试微博OAuth登录流程使用mock模拟API调用
weibo_app = self.get_app_by_type('weibo')
assert weibo_app
url = weibo_app.get_authorization_url()
# zzh: 模拟微博API返回的token响应
mock_do_post.return_value = json.dumps({"access_token": "access_token",
"uid": "uid"
})
# zzh: 模拟微博API返回的用户信息
mock_do_get.return_value = json.dumps({
"avatar_large": "avatar_large",
"screen_name": "screen_name",
@ -78,22 +70,19 @@ class OauthLoginTest(TestCase):
"email": "email",
})
userinfo = weibo_app.get_access_token_by_code('code')
self.assertEqual(userinfo.token, 'access_token') # zzh: 验证token正确性
self.assertEqual(userinfo.openid, 'id') # zzh: 验证openid正确性
self.assertEqual(userinfo.token, 'access_token')
self.assertEqual(userinfo.openid, 'id')
@patch("oauth.oauthmanager.GoogleOauthManager.do_post")
@patch("oauth.oauthmanager.GoogleOauthManager.do_get")
def test_google_login(self, mock_do_get, mock_do_post):
# zzh: 测试Google OAuth登录流程
google_app = self.get_app_by_type('google')
assert google_app
url = google_app.get_authorization_url()
# zzh: 模拟Google API返回的token响应
mock_do_post.return_value = json.dumps({
"access_token": "access_token",
"id_token": "id_token",
})
# zzh: 模拟Google API返回的用户信息
mock_do_get.return_value = json.dumps({
"picture": "picture",
"name": "name",
@ -102,21 +91,18 @@ class OauthLoginTest(TestCase):
})
token = google_app.get_access_token_by_code('code')
userinfo = google_app.get_oauth_userinfo()
self.assertEqual(userinfo.token, 'access_token') # zzh: 验证token正确性
self.assertEqual(userinfo.openid, 'sub') # zzh: 验证openid正确性Google使用sub字段
self.assertEqual(userinfo.token, 'access_token')
self.assertEqual(userinfo.openid, 'sub')
@patch("oauth.oauthmanager.GitHubOauthManager.do_post")
@patch("oauth.oauthmanager.GitHubOauthManager.do_get")
def test_github_login(self, mock_do_get, mock_do_post):
# zzh: 测试GitHub OAuth登录流程
github_app = self.get_app_by_type('github')
assert github_app
url = github_app.get_authorization_url()
self.assertTrue("github.com" in url) # zzh: 验证授权URL包含GitHub域名
self.assertTrue("client_id" in url) # zzh: 验证授权URL包含client_id参数
# zzh: 模拟GitHub API返回的token响应查询字符串格式
self.assertTrue("github.com" in url)
self.assertTrue("client_id" in url)
mock_do_post.return_value = "access_token=gho_16C7e42F292c6912E7710c838347Ae178B4a&scope=repo%2Cgist&token_type=bearer"
# zzh: 模拟GitHub API返回的用户信息
mock_do_get.return_value = json.dumps({
"avatar_url": "avatar_url",
"name": "name",
@ -125,22 +111,19 @@ class OauthLoginTest(TestCase):
})
token = github_app.get_access_token_by_code('code')
userinfo = github_app.get_oauth_userinfo()
self.assertEqual(userinfo.token, 'gho_16C7e42F292c6912E7710c838347Ae178B4a') # zzh: 验证token正确性
self.assertEqual(userinfo.openid, 'id') # zzh: 验证openid正确性
self.assertEqual(userinfo.token, 'gho_16C7e42F292c6912E7710c838347Ae178B4a')
self.assertEqual(userinfo.openid, 'id')
@patch("oauth.oauthmanager.FaceBookOauthManager.do_post")
@patch("oauth.oauthmanager.FaceBookOauthManager.do_get")
def test_facebook_login(self, mock_do_get, mock_do_post):
# zzh: 测试Facebook OAuth登录流程
facebook_app = self.get_app_by_type('facebook')
assert facebook_app
url = facebook_app.get_authorization_url()
self.assertTrue("facebook.com" in url) # zzh: 验证授权URL包含Facebook域名
# zzh: 模拟Facebook API返回的token响应
self.assertTrue("facebook.com" in url)
mock_do_post.return_value = json.dumps({
"access_token": "access_token",
})
# zzh: 模拟Facebook API返回的用户信息包含嵌套的头像数据结构
mock_do_get.return_value = json.dumps({
"name": "name",
"id": "id",
@ -153,13 +136,12 @@ class OauthLoginTest(TestCase):
})
token = facebook_app.get_access_token_by_code('code')
userinfo = facebook_app.get_oauth_userinfo()
self.assertEqual(userinfo.token, 'access_token') # zzh: 验证token正确性
self.assertEqual(userinfo.token, 'access_token')
@patch("oauth.oauthmanager.QQOauthManager.do_get", side_effect=[
# zzh: 使用side_effect模拟QQ OAuth三个连续的API调用
'access_token=access_token&expires_in=3600', # zzh: 模拟获取token的响应
'callback({"client_id":"appid","openid":"openid"} );', # zzh: 模拟获取openid的响应JSONP格式
json.dumps({ # zzh: 模拟获取用户信息的响应
'access_token=access_token&expires_in=3600',
'callback({"client_id":"appid","openid":"openid"} );',
json.dumps({
"nickname": "nickname",
"email": "email",
"figureurl": "figureurl",
@ -167,19 +149,18 @@ class OauthLoginTest(TestCase):
})
])
def test_qq_login(self, mock_do_get):
# zzh: 测试QQ OAuth登录流程需要三个连续的API调用
qq_app = self.get_app_by_type('qq')
assert qq_app
url = qq_app.get_authorization_url()
self.assertTrue("qq.com" in url) # zzh: 验证授权URL包含QQ域名
self.assertTrue("qq.com" in url)
token = qq_app.get_access_token_by_code('code')
userinfo = qq_app.get_oauth_userinfo()
self.assertEqual(userinfo.token, 'access_token') # zzh: 验证token正确性
self.assertEqual(userinfo.token, 'access_token')
@patch("oauth.oauthmanager.WBOauthManager.do_post")
@patch("oauth.oauthmanager.WBOauthManager.do_get")
def test_weibo_authoriz_login_with_email(self, mock_do_get, mock_do_post):
# zzh: 测试包含邮箱的微博授权登录完整流程
mock_do_post.return_value = json.dumps({"access_token": "access_token",
"uid": "uid"
})
@ -197,16 +178,15 @@ class OauthLoginTest(TestCase):
response = self.client.get('/oauth/authorize?type=weibo&code=code')
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, '/') # zzh: 有邮箱时直接登录成功并跳转首页
self.assertEqual(response.url, '/')
user = auth.get_user(self.client)
assert user.is_authenticated # zzh: 验证用户已认证
assert user.is_authenticated
self.assertTrue(user.is_authenticated)
self.assertEqual(user.username, mock_user_info['screen_name']) # zzh: 验证用户名正确
self.assertEqual(user.email, mock_user_info['email']) # zzh: 验证邮箱正确
self.client.logout() # zzh: 登出用户
self.assertEqual(user.username, mock_user_info['screen_name'])
self.assertEqual(user.email, mock_user_info['email'])
self.client.logout()
# zzh: 测试重复登录场景
response = self.client.get('/oauth/authorize?type=weibo&code=code')
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, '/')
@ -220,7 +200,7 @@ class OauthLoginTest(TestCase):
@patch("oauth.oauthmanager.WBOauthManager.do_post")
@patch("oauth.oauthmanager.WBOauthManager.do_get")
def test_weibo_authoriz_login_without_email(self, mock_do_get, mock_do_post):
# zzh: 测试不包含邮箱的微博授权登录流程(需要补充邮箱)
mock_do_post.return_value = json.dumps({"access_token": "access_token",
"uid": "uid"
})
@ -239,34 +219,31 @@ class OauthLoginTest(TestCase):
self.assertEqual(response.status_code, 302)
oauth_user_id = int(response.url.split('/')[-1].split('.')[0]) # zzh: 从URL中提取oauth_user_id
self.assertEqual(response.url, f'/oauth/requireemail/{oauth_user_id}.html') # zzh: 无邮箱时跳转到补充邮箱页面
oauth_user_id = int(response.url.split('/')[-1].split('.')[0])
self.assertEqual(response.url, f'/oauth/requireemail/{oauth_user_id}.html')
# zzh: 提交邮箱表单
response = self.client.post(response.url, {'email': 'test@gmail.com', 'oauthid': oauth_user_id})
self.assertEqual(response.status_code, 302)
# zzh: 生成邮箱确认签名
sign = get_sha256(settings.SECRET_KEY +
str(oauth_user_id) + settings.SECRET_KEY)
url = reverse('oauth:bindsuccess', kwargs={
'oauthid': oauth_user_id,
})
self.assertEqual(response.url, f'{url}?type=email') # zzh: 跳转到绑定成功页面
self.assertEqual(response.url, f'{url}?type=email')
# zzh: 模拟邮箱确认链接点击
path = reverse('oauth:email_confirm', kwargs={
'id': oauth_user_id,
'sign': sign
})
response = self.client.get(path)
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, f'/oauth/bindsuccess/{oauth_user_id}.html?type=success') # zzh: 绑定成功跳转
self.assertEqual(response.url, f'/oauth/bindsuccess/{oauth_user_id}.html?type=success')
user = auth.get_user(self.client)
from oauth.models import OAuthUser
oauth_user = OAuthUser.objects.get(author=user) # zzh: 验证OAuth用户关联关系
self.assertTrue(user.is_authenticated) # zzh: 验证用户已认证
self.assertEqual(user.username, mock_user_info['screen_name']) # zzh: 验证用户名正确
self.assertEqual(user.email, 'test@gmail.com') # zzh: 验证补充的邮箱正确
self.assertEqual(oauth_user.pk, oauth_user_id) # zzh: 验证OAuth用户ID正确
oauth_user = OAuthUser.objects.get(author=user)
self.assertTrue(user.is_authenticated)
self.assertEqual(user.username, mock_user_info['screen_name'])
self.assertEqual(user.email, 'test@gmail.com')
self.assertEqual(oauth_user.pk, oauth_user_id)

@ -2,24 +2,24 @@ from django.urls import path
from . import views
app_name = "oauth" # zzh: 定义应用的命名空间用于反向解析URL时避免冲突
app_name = "oauth"
urlpatterns = [
path(
r'oauth/authorize', # zzh: OAuth授权回调URL第三方平台授权后跳转至此
views.authorize), # zzh: 处理授权回调获取access_token和用户信息
r'oauth/authorize',
views.authorize),
path(
r'oauth/requireemail/<int:oauthid>.html', # zzh: 需要补充邮箱的页面URL包含oauthid参数
views.RequireEmailView.as_view(), # zzh: 类视图,处理需要补充邮箱的情况
name='require_email'), # zzh: URL名称用于反向解析
r'oauth/requireemail/<int:oauthid>.html',
views.RequireEmailView.as_view(),
name='require_email'),
path(
r'oauth/emailconfirm/<int:id>/<sign>.html', # zzh: 邮箱确认URL包含用户ID和签名验证
views.emailconfirm, # zzh: 验证邮箱确认链接的签名并完成绑定
name='email_confirm'), # zzh: URL名称用于生成邮箱确认链接
r'oauth/emailconfirm/<int:id>/<sign>.html',
views.emailconfirm,
name='email_confirm'),
path(
r'oauth/bindsuccess/<int:oauthid>.html', # zzh: 绑定成功页面URL
views.bindsuccess, # zzh: 显示绑定成功信息
name='bindsuccess'), # zzh: URL名称用于跳转到绑定成功页面
r'oauth/bindsuccess/<int:oauthid>.html',
views.bindsuccess,
name='bindsuccess'),
path(
r'oauth/oauthlogin', # zzh: OAuth登录入口URL
views.oauthlogin, # zzh: 跳转到第三方平台授权页面
name='oauthlogin')] # zzh: URL名称用于生成OAuth登录链接
r'oauth/oauthlogin',
views.oauthlogin,
name='oauthlogin')]

@ -27,7 +27,6 @@ logger = logging.getLogger(__name__)
def get_redirecturl(request):
# zzh: 获取重定向URL进行安全性检查
nexturl = request.GET.get('next_url', None)
if not nexturl or nexturl == '/login/' or nexturl == '/login':
nexturl = '/'
@ -42,7 +41,6 @@ def get_redirecturl(request):
def oauthlogin(request):
# zzh: OAuth登录入口跳转到第三方平台授权页面
type = request.GET.get('type', None)
if not type:
return HttpResponseRedirect('/')
@ -55,7 +53,6 @@ def oauthlogin(request):
def authorize(request):
# zzh: OAuth授权回调处理获取用户信息并登录
type = request.GET.get('type', None)
if not type:
return HttpResponseRedirect('/')
@ -76,11 +73,9 @@ def authorize(request):
return HttpResponseRedirect(manager.get_authorization_url(nexturl))
user = manager.get_oauth_userinfo()
if user:
# zzh: 处理用户昵称为空的情况,生成默认昵称
if not user.nickname or not user.nickname.strip():
user.nickname = "djangoblog" + timezone.now().strftime('%y%m%d%I%M%S')
try:
# zzh: 检查是否已存在该OAuth用户更新信息
temp = OAuthUser.objects.get(type=type, openid=user.openid)
temp.picture = user.picture
temp.metadata = user.metadata
@ -88,11 +83,10 @@ def authorize(request):
user = temp
except ObjectDoesNotExist:
pass
# zzh: facebook的token过长,清空存储
# facebook的token过长
if type == 'facebook':
user.token = ''
if user.email:
# zzh: 有邮箱的情况,直接创建或关联用户并登录
with transaction.atomic():
author = None
try:
@ -100,17 +94,14 @@ def authorize(request):
except ObjectDoesNotExist:
pass
if not author:
# zzh: 根据邮箱获取或创建用户
result = get_user_model().objects.get_or_create(email=user.email)
author = result[0]
if result[1]:
# zzh: 新创建用户,设置用户名
try:
get_user_model().objects.get(username=user.nickname)
except ObjectDoesNotExist:
author.username = user.nickname
else:
# zzh: 用户名冲突时生成唯一用户名
author.username = "djangoblog" + timezone.now().strftime('%y%m%d%I%M%S')
author.source = 'authorize'
author.save()
@ -118,13 +109,11 @@ def authorize(request):
user.author = author
user.save()
# zzh: 发送OAuth用户登录信号
oauth_user_login_signal.send(
sender=authorize.__class__, id=user.id)
login(request, author)
return HttpResponseRedirect(nexturl)
else:
# zzh: 没有邮箱的情况保存OAuth用户并跳转到补充邮箱页面
user.save()
url = reverse('oauth:require_email', kwargs={
'oauthid': user.id
@ -136,10 +125,8 @@ def authorize(request):
def emailconfirm(request, id, sign):
# zzh: 邮箱确认处理,验证签名并完成用户绑定
if not sign:
return HttpResponseForbidden()
# zzh: 验证签名防止篡改
if not get_sha256(settings.SECRET_KEY +
str(id) +
settings.SECRET_KEY).upper() == sign.upper():
@ -149,7 +136,6 @@ def emailconfirm(request, id, sign):
if oauthuser.author:
author = get_user_model().objects.get(pk=oauthuser.author_id)
else:
# zzh: 创建新用户并关联OAuth账户
result = get_user_model().objects.get_or_create(email=oauthuser.email)
author = result[0]
if result[1]:
@ -159,7 +145,6 @@ def emailconfirm(request, id, sign):
author.save()
oauthuser.author = author
oauthuser.save()
# zzh: 发送登录信号并登录用户
oauth_user_login_signal.send(
sender=emailconfirm.__class__,
id=oauthuser.id)
@ -177,7 +162,6 @@ def emailconfirm(request, id, sign):
%(site)s
''') % {'oauthuser_type': oauthuser.type, 'site': site}
# zzh: 发送绑定成功邮件
send_email(emailto=[oauthuser.email, ], title=_('Congratulations on your successful binding!'), content=content)
url = reverse('oauth:bindsuccess', kwargs={
'oauthid': id
@ -187,7 +171,6 @@ def emailconfirm(request, id, sign):
class RequireEmailView(FormView):
# zzh: 需要补充邮箱的表单视图
form_class = RequireEmailForm
template_name = 'oauth/require_email.html'
@ -201,7 +184,6 @@ class RequireEmailView(FormView):
return super(RequireEmailView, self).get(request, *args, **kwargs)
def get_initial(self):
# zzh: 设置表单初始数据
oauthid = self.kwargs['oauthid']
return {
'email': '',
@ -209,7 +191,6 @@ class RequireEmailView(FormView):
}
def get_context_data(self, **kwargs):
# zzh: 添加上下文数据,用于在模板中显示用户头像
oauthid = self.kwargs['oauthid']
oauthuser = get_object_or_404(OAuthUser, pk=oauthid)
if oauthuser.picture:
@ -217,13 +198,11 @@ class RequireEmailView(FormView):
return super(RequireEmailView, self).get_context_data(**kwargs)
def form_valid(self, form):
# zzh: 表单验证通过,保存邮箱并发送确认邮件
email = form.cleaned_data['email']
oauthid = form.cleaned_data['oauthid']
oauthuser = get_object_or_404(OAuthUser, pk=oauthid)
oauthuser.email = email
oauthuser.save()
# zzh: 生成邮箱确认签名
sign = get_sha256(settings.SECRET_KEY +
str(oauthuser.id) + settings.SECRET_KEY)
site = get_current_site().domain
@ -246,7 +225,6 @@ class RequireEmailView(FormView):
<br />
%(url)s
""") % {'url': url}
# zzh: 发送邮箱确认邮件
send_email(emailto=[email, ], title=_('Bind your email'), content=content)
url = reverse('oauth:bindsuccess', kwargs={
'oauthid': oauthid
@ -256,17 +234,14 @@ class RequireEmailView(FormView):
def bindsuccess(request, oauthid):
# zzh: 绑定成功页面,根据类型显示不同内容
type = request.GET.get('type', None)
oauthuser = get_object_or_404(OAuthUser, pk=oauthid)
if type == 'email':
# zzh: 邮箱已提交,等待确认
title = _('Bind your email')
content = _(
'Congratulations, the binding is just one step away. '
'Please log in to your email to check the email to complete the binding. Thank you.')
else:
# zzh: 绑定已完成
title = _('Binding successful')
content = _(
"Congratulations, you have successfully bound your email address. You can use %(oauthuser_type)s"
@ -275,4 +250,4 @@ def bindsuccess(request, oauthid):
return render(request, 'oauth/bindsuccess.html', {
'title': title,
'content': content
})
})

Loading…
Cancel
Save