From 2ffaeee7af483b4032f06f8b3cf65d4c5156dbc0 Mon Sep 17 00:00:00 2001
From: djq <1092424998@qq.com>
Date: Sat, 8 Nov 2025 00:24:22 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../DjangoBlog-master/accounts/urls.py | 77 +-
.../accounts/user_login_backend.py | 42 +-
.../DjangoBlog-master/blog/documents.py | 191 +++--
.../management/commands/sync_user_avatar.py | 63 +-
.../DjangoBlog-master/blog/models.py | 195 ++---
.../blog/templatetags/blog_tags.py | 138 +---
.../DjangoBlog-master/blog/views.py | 438 +++++-----
.../DjangoBlog-master/djangoblog/urls.py | 118 ++-
.../djangoblog/whoosh_cn_backend.py | 761 ++----------------
.../DjangoBlog-master/oauth/oauthmanager.py | 382 ++++-----
.../DjangoBlog-master/servermanager/robot.py | 136 ++--
11 files changed, 885 insertions(+), 1656 deletions(-)
diff --git a/src/DjangoBlog-master(1)/DjangoBlog-master/accounts/urls.py b/src/DjangoBlog-master(1)/DjangoBlog-master/accounts/urls.py
index 9eb1999..7f28245 100644
--- a/src/DjangoBlog-master(1)/DjangoBlog-master/accounts/urls.py
+++ b/src/DjangoBlog-master(1)/DjangoBlog-master/accounts/urls.py
@@ -1,42 +1,35 @@
-from django.urls import path
-from django.urls import re_path
-
-from . import views
-from .forms import LoginForm
-
-# 定义应用的命名空间,用于反向解析URL
-app_name = "accounts"
-
-# URL模式配置 - 用户账户相关功能
-urlpatterns = [
- # 用户登录
- re_path(r'^login/$', # 使用正则表达式匹配登录路径,必须以/login/结尾
- views.LoginView.as_view(success_url='/'), # 登录类视图,登录成功后跳转到首页
- name='login', # URL名称,用于反向解析
- kwargs={'authentication_form': LoginForm}), # 传递自定义登录表单类
-
- # 用户注册
- re_path(r'^register/$', # 注册路径,必须以/register/结尾
- views.RegisterView.as_view(success_url="/"), # 注册类视图,注册成功后跳转到首页
- name='register'), # URL名称
-
- # 用户退出登录
- re_path(r'^logout/$', # 退出登录路径,必须以/logout/结尾
- views.LogoutView.as_view(), # 退出登录类视图
- name='logout'), # URL名称
-
- # 账户操作结果页面
- path(r'account/result.html', # 结果页面路径,使用path函数
- views.account_result, # 使用函数视图显示账户操作结果
- name='result'), # URL名称
-
- # 忘记密码页面
- re_path(r'^forget_password/$', # 忘记密码路径,必须以/forget_password/结尾
- views.ForgetPasswordView.as_view(), # 忘记密码类视图
- name='forget_password'), # URL名称
-
- # 忘记密码验证码处理
- re_path(r'^forget_password_code/$', # 忘记密码验证码路径,必须以/forget_password_code/结尾
- views.ForgetPasswordEmailCode.as_view(), # 忘记密码邮箱验证码类视图
- name='forget_password_code'), # URL名称
-]
\ No newline at end of file
+import typing
+from datetime import datetime, timedelta
+from django.core.mail import send_mail
+# 假设存在验证码存储模型(如VerificationCode)
+from .models import VerificationCode # 根据实际模型导入
+
+
+# (省略其他代码,如发送邮件相关逻辑)
+
+
+def verify(email: str, code: str) -> typing.Optional[str]:
+ """
+ 验证验证码是否有效
+ Args:
+ email: 请求邮箱
+ code: 验证码
+ Return:
+ 有效时返回邮箱字符串,无效/过期时返回None(保证所有分支返回类型一致)
+ """
+ try:
+ # 查询该邮箱的最新验证码记录
+ verification = VerificationCode.objects.filter(
+ email=email,
+ code=code
+ ).latest('created_time')
+
+ # 检查验证码是否在有效期内(假设有效期5分钟)
+ if datetime.now() - verification.created_time <= timedelta(minutes=5):
+ return email # 验证通过,返回邮箱(字符串类型)
+ else:
+ return None # 过期返回None
+
+ except VerificationCode.DoesNotExist:
+ # 验证码不存在或不匹配
+ return None # 异常分支返回None,与其他分支类型一致
\ No newline at end of file
diff --git a/src/DjangoBlog-master(1)/DjangoBlog-master/accounts/user_login_backend.py b/src/DjangoBlog-master(1)/DjangoBlog-master/accounts/user_login_backend.py
index 73cdca1..6017029 100644
--- a/src/DjangoBlog-master(1)/DjangoBlog-master/accounts/user_login_backend.py
+++ b/src/DjangoBlog-master(1)/DjangoBlog-master/accounts/user_login_backend.py
@@ -1,26 +1,42 @@
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
+from django.db.models import Q
+
+UserModel = get_user_model() # 提前获取用户模型,避免重复调用
class EmailOrUsernameModelBackend(ModelBackend):
"""
- 允许使用用户名或邮箱登录
+ 允许使用用户名或邮箱登录(优化返回一致性、健壮性)
"""
def authenticate(self, request, username=None, password=None, **kwargs):
- if '@' in username:
- kwargs = {'email': username}
- else:
- kwargs = {'username': username}
- try:
- user = get_user_model().objects.get(**kwargs)
- if user.check_password(password):
- return user
- except get_user_model().DoesNotExist:
+ # 边界条件:username 或 password 为空时直接返回 None(避免无效查询)
+ if not username or not password:
return None
- def get_user(self, username):
try:
- return get_user_model().objects.get(pk=username)
- except get_user_model().DoesNotExist:
+ # 用 Q 对象简化逻辑:同时匹配 username 或 email(无需分支判断)
+ user = UserModel.objects.get(
+ Q(username__exact=username) | Q(email__exact=username)
+ )
+ except UserModel.DoesNotExist:
+ # 用户不存在时返回 None(与其他分支返回类型一致)
return None
+ except UserModel.MultipleObjectsReturned:
+ # 极端情况:多个用户匹配(如邮箱重复),取第一个并验证密码
+ user = UserModel.objects.filter(
+ Q(username__exact=username) | Q(email__exact=username)
+ ).first()
+
+ # 验证密码:通过则返回用户,否则返回 None(所有分支返回类型统一为 User/None)
+ if user and user.check_password(password):
+ return user
+ return None
+
+ def get_user(self, user_id): # 修改变量名:user_id 更贴合语义(原 username 实际是 pk)
+ try:
+ # 明确指定 pk 字段查询(避免模型 pk 不是 username 时出错)
+ return UserModel.objects.get(pk=user_id)
+ except UserModel.DoesNotExist:
+ return None
\ No newline at end of file
diff --git a/src/DjangoBlog-master(1)/DjangoBlog-master/blog/documents.py b/src/DjangoBlog-master(1)/DjangoBlog-master/blog/documents.py
index 0f1db7b..e265cbf 100644
--- a/src/DjangoBlog-master(1)/DjangoBlog-master/blog/documents.py
+++ b/src/DjangoBlog-master(1)/DjangoBlog-master/blog/documents.py
@@ -1,38 +1,46 @@
import time
-
-import elasticsearch.client
from django.conf import settings
+from elasticsearch import Elasticsearch
+from elasticsearch.client import IngestClient
from elasticsearch_dsl import Document, InnerDoc, Date, Integer, Long, Text, Object, GeoPoint, Keyword, Boolean
from elasticsearch_dsl.connections import connections
from blog.models import Article
+# 全局配置与客户端初始化(统一管理,避免重复创建)
ELASTICSEARCH_ENABLED = hasattr(settings, 'ELASTICSEARCH_DSL')
+es_client = None # 全局 Elasticsearch 客户端实例(复用)
if ELASTICSEARCH_ENABLED:
+ # 初始化 elasticsearch-dsl 连接
connections.create_connection(
- hosts=[settings.ELASTICSEARCH_DSL['default']['hosts']])
- from elasticsearch import Elasticsearch
-
- es = Elasticsearch(settings.ELASTICSEARCH_DSL['default']['hosts'])
- from elasticsearch.client import IngestClient
-
- c = IngestClient(es)
+ hosts=[settings.ELASTICSEARCH_DSL['default']['hosts']]
+ )
+ # 创建全局 Elasticsearch 客户端(供所有方法复用)
+ es_client = Elasticsearch(settings.ELASTICSEARCH_DSL['default']['hosts'])
+ # 初始化 GeoIP 管道(仅当不存在时创建)
+ ingest_client = IngestClient(es_client)
try:
- c.get_pipeline('geoip')
- except elasticsearch.exceptions.NotFoundError:
- c.put_pipeline('geoip', body='''{
- "description" : "Add geoip info",
- "processors" : [
- {
- "geoip" : {
- "field" : "ip"
- }
- }
- ]
- }''')
-
-
+ ingest_client.get_pipeline('geoip')
+ except Elasticsearch.exceptions.NotFoundError:
+ ingest_client.put_pipeline(
+ id='geoip',
+ body='''{
+ "description": "Add geoip info",
+ "processors": [
+ {
+ "geoip": {
+ "field": "ip"
+ }
+ }
+ ]
+ }'''
+ )
+
+
+# ------------------------------
+# 内部文档模型(InnerDoc)
+# ------------------------------
class GeoIp(InnerDoc):
continent_name = Keyword()
country_iso_code = Keyword()
@@ -46,6 +54,7 @@ class UserAgentBrowser(InnerDoc):
class UserAgentOS(UserAgentBrowser):
+ """继承自 UserAgentBrowser,属性一致"""
pass
@@ -63,89 +72,105 @@ class UserAgent(InnerDoc):
is_bot = Boolean()
+# ------------------------------
+# 性能日志文档模型(ElapsedTimeDocument)
+# ------------------------------
class ElapsedTimeDocument(Document):
url = Keyword()
- time_taken = Long()
- log_datetime = Date()
- ip = Keyword()
- geoip = Object(GeoIp, required=False)
- useragent = Object(UserAgent, required=False)
+ time_taken = Long() # 耗时(毫秒)
+ log_datetime = Date() # 日志时间
+ ip = Keyword() # 客户端 IP
+ geoip = Object(GeoIp, required=False) # GeoIP 解析结果
+ useragent = Object(UserAgent, required=False) # User-Agent 解析结果
class Index:
- name = 'performance'
+ name = 'performance' # 索引名
settings = {
"number_of_shards": 1,
"number_of_replicas": 0
}
class Meta:
- doc_type = 'ElapsedTime'
+ doc_type = 'ElapsedTime' # 文档类型(ES 7.x+ 已废弃,兼容旧版本)
+
+class ElapsedTimeDocumentManager:
+ """修复类名拼写错误:Elasped → Elapsed"""
-class ElaspedTimeDocumentManager:
@staticmethod
def build_index():
- from elasticsearch import Elasticsearch
- client = Elasticsearch(settings.ELASTICSEARCH_DSL['default']['hosts'])
- res = client.indices.exists(index="performance")
- if not res:
+ """创建索引(不存在时初始化)"""
+ if not ELASTICSEARCH_ENABLED:
+ return
+ # 复用全局客户端,避免重复创建
+ if not es_client.indices.exists(index="performance"):
ElapsedTimeDocument.init()
@staticmethod
def delete_index():
- from elasticsearch import Elasticsearch
- es = Elasticsearch(settings.ELASTICSEARCH_DSL['default']['hosts'])
- es.indices.delete(index='performance', ignore=[400, 404])
+ """删除索引"""
+ if not ELASTICSEARCH_ENABLED:
+ return
+ es_client.indices.delete(index='performance', ignore=[400, 404])
@staticmethod
def create(url, time_taken, log_datetime, useragent, ip):
- ElaspedTimeDocumentManager.build_index()
+ """创建性能日志文档(自动触发 GeoIP 管道)"""
+ ElapsedTimeDocumentManager.build_index()
+
+ # 构建 UserAgent 内部文档
ua = UserAgent()
- ua.browser = UserAgentBrowser()
- ua.browser.Family = useragent.browser.family
- ua.browser.Version = useragent.browser.version_string
-
- ua.os = UserAgentOS()
- ua.os.Family = useragent.os.family
- ua.os.Version = useragent.os.version_string
-
- ua.device = UserAgentDevice()
- ua.device.Family = useragent.device.family
- ua.device.Brand = useragent.device.brand
- ua.device.Model = useragent.device.model
+ ua.browser = UserAgentBrowser(
+ Family=useragent.browser.family,
+ Version=useragent.browser.version_string
+ )
+ ua.os = UserAgentOS(
+ Family=useragent.os.family,
+ Version=useragent.os.version_string
+ )
+ ua.device = UserAgentDevice(
+ Family=useragent.device.family,
+ Brand=useragent.device.brand,
+ Model=useragent.device.model
+ )
ua.string = useragent.ua_string
ua.is_bot = useragent.is_bot
+ # 构建并保存文档(用时间戳作为唯一 ID)
doc = ElapsedTimeDocument(
- meta={
- 'id': int(
- round(
- time.time() *
- 1000))
- },
+ meta={'id': int(round(time.time() * 1000))},
url=url,
time_taken=time_taken,
log_datetime=log_datetime,
- useragent=ua, ip=ip)
- doc.save(pipeline="geoip")
+ useragent=ua,
+ ip=ip
+ )
+ doc.save(pipeline="geoip") # 应用 GeoIP 管道解析 IP
+# ------------------------------
+# 文章文档模型(ArticleDocument)
+# ------------------------------
class ArticleDocument(Document):
+ # 正文和标题使用 IK 分词器(ik_max_word 分词更细,ik_smart 搜索更高效)
body = Text(analyzer='ik_max_word', search_analyzer='ik_smart')
title = Text(analyzer='ik_max_word', search_analyzer='ik_smart')
+ # 关联作者信息(嵌套对象)
author = Object(properties={
'nickname': Text(analyzer='ik_max_word', search_analyzer='ik_smart'),
'id': Integer()
})
+ # 关联分类信息(嵌套对象)
category = Object(properties={
'name': Text(analyzer='ik_max_word', search_analyzer='ik_smart'),
'id': Integer()
})
+ # 关联标签信息(嵌套对象列表)
tags = Object(properties={
'name': Text(analyzer='ik_max_word', search_analyzer='ik_smart'),
'id': Integer()
})
-
+ # 其他字段
pub_time = Date()
status = Text()
comment_status = Text()
@@ -154,54 +179,61 @@ class ArticleDocument(Document):
article_order = Integer()
class Index:
- name = 'blog'
+ name = 'blog' # 索引名
settings = {
"number_of_shards": 1,
"number_of_replicas": 0
}
class Meta:
- doc_type = 'Article'
-
+ doc_type = 'Article' # 文档类型(兼容旧版本)
-class ArticleDocumentManager():
+class ArticleDocumentManager:
def __init__(self):
+ """初始化时自动创建索引"""
self.create_index()
def create_index(self):
- ArticleDocument.init()
+ """创建文章索引"""
+ if ELASTICSEARCH_ENABLED:
+ ArticleDocument.init()
def delete_index(self):
- from elasticsearch import Elasticsearch
- es = Elasticsearch(settings.ELASTICSEARCH_DSL['default']['hosts'])
- es.indices.delete(index='blog', ignore=[400, 404])
+ """删除文章索引"""
+ if not ELASTICSEARCH_ENABLED:
+ return
+ es_client.indices.delete(index='blog', ignore=[400, 404])
def convert_to_doc(self, articles):
+ """将 Django ORM 模型转换为 Elasticsearch 文档"""
return [
ArticleDocument(
- meta={
- 'id': article.id},
+ meta={'id': article.id},
body=article.body,
title=article.title,
author={
'nickname': article.author.username,
- 'id': article.author.id},
+ 'id': article.author.id
+ },
category={
'name': article.category.name,
- 'id': article.category.id},
- tags=[
- {
- 'name': t.name,
- 'id': t.id} for t in article.tags.all()],
+ 'id': article.category.id
+ },
+ tags=[{'name': t.name, 'id': t.id} for t in article.tags.all()],
pub_time=article.pub_time,
status=article.status,
comment_status=article.comment_status,
type=article.type,
views=article.views,
- article_order=article.article_order) for article in articles]
+ article_order=article.article_order
+ ) for article in articles
+ ]
def rebuild(self, articles=None):
+ """重建索引(默认同步所有文章)"""
+ if not ELASTICSEARCH_ENABLED:
+ return
ArticleDocument.init()
articles = articles if articles else Article.objects.all()
docs = self.convert_to_doc(articles)
@@ -209,5 +241,8 @@ class ArticleDocumentManager():
doc.save()
def update_docs(self, docs):
+ """批量更新文档"""
+ if not ELASTICSEARCH_ENABLED:
+ return
for doc in docs:
- doc.save()
+ doc.save()
\ No newline at end of file
diff --git a/src/DjangoBlog-master(1)/DjangoBlog-master/blog/management/commands/sync_user_avatar.py b/src/DjangoBlog-master(1)/DjangoBlog-master/blog/management/commands/sync_user_avatar.py
index d0f4612..5b910bb 100644
--- a/src/DjangoBlog-master(1)/DjangoBlog-master/blog/management/commands/sync_user_avatar.py
+++ b/src/DjangoBlog-master(1)/DjangoBlog-master/blog/management/commands/sync_user_avatar.py
@@ -1,6 +1,7 @@
import requests
from django.core.management.base import BaseCommand
from django.templatetags.static import static
+from requests.exceptions import RequestException # 导入具体异常类型
from djangoblog.utils import save_user_avatar
from oauth.models import OAuthUser
@@ -11,37 +12,49 @@ class Command(BaseCommand):
help = 'sync user avatar'
def test_picture(self, url):
+ """
+ 验证图片URL是否可访问(返回布尔值,确保所有分支返回类型一致)
+ """
try:
- if requests.get(url, timeout=2).status_code == 200:
- return True
- except:
- pass
+ # 明确指定请求方法为GET,避免隐式参数问题
+ response = requests.get(url, timeout=2)
+ return response.status_code == 200 # 直接返回布尔值
+ except RequestException: # 捕获具体异常类型(替代裸露except)
+ return False # 异常时返回False,与正常分支返回类型一致
def handle(self, *args, **options):
static_url = static("../")
users = OAuthUser.objects.all()
self.stdout.write(f'开始同步{len(users)}个用户头像')
- for u in users:
- self.stdout.write(f'开始同步:{u.nickname}')
- url = u.picture
- if url:
- if url.startswith(static_url):
- if self.test_picture(url):
- continue
- else:
- if u.metadata:
- manage = get_manager_by_type(u.type)
- url = manage.get_picture(u.metadata)
- url = save_user_avatar(url)
+
+ for user in users: # 变量名u改为user,提高可读性
+ self.stdout.write(f'开始同步:{user.nickname}')
+ avatar_url = user.picture # 变量名url改为avatar_url,明确业务含义
+
+ if avatar_url:
+ # 处理静态资源路径的头像
+ if avatar_url.startswith(static_url):
+ # 验证图片可访问性,不可访问则重新获取
+ if not self.test_picture(avatar_url):
+ if user.metadata:
+ # 通过OAuth管理器获取最新头像
+ oauth_manager = get_manager_by_type(user.type)
+ avatar_url = oauth_manager.get_picture(user.metadata)
+ avatar_url = save_user_avatar(avatar_url)
else:
- url = static('blog/img/avatar.png')
+ # 无元数据时使用默认头像
+ avatar_url = static('blog/img/avatar.png')
else:
- url = save_user_avatar(url)
+ # 非静态路径头像直接保存
+ avatar_url = save_user_avatar(avatar_url)
else:
- url = static('blog/img/avatar.png')
- if url:
- self.stdout.write(
- f'结束同步:{u.nickname}.url:{url}')
- u.picture = url
- u.save()
- self.stdout.write('结束同步')
+ # 无头像时使用默认头像
+ avatar_url = static('blog/img/avatar.png')
+
+ # 保存更新后的头像URL
+ if avatar_url:
+ self.stdout.write(f'结束同步:{user.nickname}.url:{avatar_url}')
+ user.picture = avatar_url
+ user.save()
+
+ self.stdout.write('结束同步')
\ No newline at end of file
diff --git a/src/DjangoBlog-master(1)/DjangoBlog-master/blog/models.py b/src/DjangoBlog-master(1)/DjangoBlog-master/blog/models.py
index f53cebb..a5ef90b 100644
--- a/src/DjangoBlog-master(1)/DjangoBlog-master/blog/models.py
+++ b/src/DjangoBlog-master(1)/DjangoBlog-master/blog/models.py
@@ -11,6 +11,7 @@ from django.utils.translation import gettext_lazy as _
from mdeditor.fields import MDTextField
from uuslug import slugify
+# 全局导入 cache,供所有方法复用(避免内部重复导入)
from djangoblog.utils import cache_decorator, cache
from djangoblog.utils import get_current_site
@@ -18,7 +19,6 @@ logger = logging.getLogger(__name__)
class LinkShowType(models.TextChoices):
- # 定义友情链接显示类型的枚举类,分别对应首页、列表页、文章页、所有页面、轮播
I = ('i', _('index'))
L = ('l', _('list'))
P = ('p', _('post'))
@@ -27,118 +27,95 @@ class LinkShowType(models.TextChoices):
class BaseModel(models.Model):
- """
- 基础模型类,为其他模型提供通用的字段和方法
- """
- id = models.AutoField(primary_key=True) # 自增主键
- creation_time = models.DateTimeField(_('creation time'), default=now) # 创建时间
- last_modify_time = models.DateTimeField(_('modify time'), default=now) # 最后修改时间
+ id = models.AutoField(primary_key=True)
+ creation_time = models.DateTimeField(_('creation time'), default=now)
+ last_modify_time = models.DateTimeField(_('modify time'), default=now)
def save(self, *args, **kwargs):
- """
- 重写save方法,处理slug字段(如果模型有slug和title/name字段),并调用父类save方法
- 同时处理仅更新views字段的特殊情况
- """
is_update_views = isinstance(
self,
Article) and 'update_fields' in kwargs and kwargs['update_fields'] == ['views']
if is_update_views:
Article.objects.filter(pk=self.pk).update(views=self.views)
else:
- # 如果模型有slug字段,生成slug(基于title或name字段)
if 'slug' in self.__dict__:
slug_source = getattr(self, 'title') if 'title' in self.__dict__ else getattr(self, 'name')
setattr(self, 'slug', slugify(slug_source))
super().save(*args, **kwargs)
def get_full_url(self):
- """
- 获取模型对象的完整URL(包含域名)
- """
site = get_current_site().domain
url = "https://{site}{path}".format(site=site,
path=self.get_absolute_url())
return url
class Meta:
- abstract = True # 抽象模型,不生成数据库表
+ abstract = True
@abstractmethod
def get_absolute_url(self):
- """
- 抽象方法,子类必须实现,用于获取模型对象的绝对URL
- """
pass
class Article(BaseModel):
- """
- 文章模型类,存储文章的相关信息
- """
- # 文章状态:草稿、已发布
STATUS_CHOICES = (
('d', _('Draft')),
('p', _('Published')),
)
- # 评论状态:开启、关闭
COMMENT_STATUS = (
('o', _('Open')),
('c', _('Close')),
)
- # 文章类型:文章、页面
TYPE = (
('a', _('Article')),
('p', _('Page')),
)
- title = models.CharField(_('title'), max_length=200, unique=True) # 文章标题,唯一
- body = MDTextField(_('body')) # 文章内容,使用MDTextField支持markdown
+ title = models.CharField(_('title'), max_length=200, unique=True)
+ body = MDTextField(_('body'))
pub_time = models.DateTimeField(
- _('publish time'), blank=False, null=False, default=now) # 发布时间
+ _('publish time'), blank=False, null=False, default=now)
status = models.CharField(
_('status'),
max_length=1,
choices=STATUS_CHOICES,
- default='p') # 文章状态
+ default='p')
comment_status = models.CharField(
_('comment status'),
max_length=1,
choices=COMMENT_STATUS,
- default='o') # 评论状态
- type = models.CharField(_('type'), max_length=1, choices=TYPE, default='a') # 文章类型
- views = models.PositiveIntegerField(_('views'), default=0) # 文章浏览量
+ default='o')
+ type = models.CharField(_('type'), max_length=1, choices=TYPE, default='a')
+ views = models.PositiveIntegerField(_('views'), default=0)
author = models.ForeignKey(
settings.AUTH_USER_MODEL,
verbose_name=_('author'),
blank=False,
null=False,
- on_delete=models.CASCADE) # 文章作者,外键关联用户模型
+ on_delete=models.CASCADE)
article_order = models.IntegerField(
- _('order'), blank=False, null=False, default=0) # 文章排序序号
- show_toc = models.BooleanField(_('show toc'), blank=False, null=False, default=False) # 是否显示目录
+ _('order'), blank=False, null=False, default=0)
+ show_toc = models.BooleanField(_('show toc'), blank=False, null=False, default=False)
category = models.ForeignKey(
'Category',
verbose_name=_('category'),
on_delete=models.CASCADE,
blank=False,
- null=False) # 文章分类,外键关联Category模型
- tags = models.ManyToManyField('Tag', verbose_name=_('tag'), blank=True) # 文章标签,多对多关联Tag模型
+ null=False)
+ tags = models.ManyToManyField('Tag', verbose_name=_('tag'), blank=True)
def body_to_string(self):
- """将文章内容转换为字符串返回"""
return self.body
def __str__(self):
- """自定义字符串表示,返回文章标题"""
return self.title
class Meta:
- ordering = ['-article_order', '-pub_time'] # 排序规则:先按article_order降序,再按pub_time降序
+ ordering = ['-article_order', '-pub_time']
verbose_name = _('article')
verbose_name_plural = verbose_name
get_latest_by = 'id'
def get_absolute_url(self):
- """获取文章的绝对URL,用于生成文章详情页链接"""
return reverse('blog:detailbyid', kwargs={
'article_id': self.id,
'year': self.creation_time.year,
@@ -148,26 +125,18 @@ class Article(BaseModel):
@cache_decorator(60 * 60 * 10)
def get_category_tree(self):
- """
- 获取文章分类的树形结构(包含当前分类及其所有父级分类),并缓存
- """
tree = self.category.get_category_tree()
names = list(map(lambda c: (c.name, c.get_absolute_url()), tree))
return names
def save(self, *args, **kwargs):
- """重写save方法,调用父类save方法"""
super().save(*args, **kwargs)
def viewed(self):
- """文章被浏览时,浏览量加1并保存"""
self.views += 1
self.save(update_fields=['views'])
def comment_list(self):
- """
- 获取文章的评论列表,优先从缓存获取,缓存不存在则查询数据库并缓存
- """
cache_key = 'article_comments_{id}'.format(id=self.id)
value = cache.get(cache_key)
if value:
@@ -180,26 +149,19 @@ class Article(BaseModel):
return comments
def get_admin_url(self):
- """获取文章在admin后台的编辑URL"""
info = (self._meta.app_label, self._meta.model_name)
return reverse('admin:%s_%s_change' % info, args=(self.pk,))
@cache_decorator(expiration=60 * 100)
def next_article(self):
- """获取下一篇文章(id大于当前文章且已发布的第一篇),并缓存"""
return Article.objects.filter(
id__gt=self.id, status='p').order_by('id').first()
@cache_decorator(expiration=60 * 100)
def prev_article(self):
- """获取前一篇文章(id小于当前文章且已发布的第一篇),并缓存"""
return Article.objects.filter(id__lt=self.id, status='p').first()
def get_first_image_url(self):
- """
- 从文章内容中提取第一张图片的URL
- 通过正则表达式匹配markdown图片语法中的图片链接
- """
match = re.search(r'!\[.*?\]\((.+?)\)', self.body)
if match:
return match.group(1)
@@ -207,39 +169,31 @@ class Article(BaseModel):
class Category(BaseModel):
- """
- 文章分类模型类
- """
- name = models.CharField(_('category name'), max_length=30, unique=True) # 分类名称,唯一
+ name = models.CharField(_('category name'), max_length=30, unique=True)
parent_category = models.ForeignKey(
'self',
verbose_name=_('parent category'),
blank=True,
null=True,
- on_delete=models.CASCADE) # 父分类,自关联
- slug = models.SlugField(default='no-slug', max_length=60, blank=True) # 分类的slug,用于URL
- index = models.IntegerField(default=0, verbose_name=_('index')) # 分类排序序号
+ on_delete=models.CASCADE)
+ slug = models.SlugField(default='no-slug', max_length=60, blank=True)
+ index = models.IntegerField(default=0, verbose_name=_('index'))
class Meta:
- ordering = ['-index'] # 按index降序排序
+ ordering = ['-index']
verbose_name = _('category')
verbose_name_plural = verbose_name
def get_absolute_url(self):
- """获取分类的绝对URL,用于生成分类页链接"""
return reverse(
'blog:category_detail', kwargs={
'category_name': self.slug})
def __str__(self):
- """自定义字符串表示,返回分类名称"""
return self.name
@cache_decorator(60 * 60 * 10)
def get_category_tree(self):
- """
- 递归获取分类的树形结构(当前分类及其所有父级分类),并缓存
- """
categorys = []
def parse(category):
@@ -252,9 +206,6 @@ class Category(BaseModel):
@cache_decorator(60 * 60 * 10)
def get_sub_categorys(self):
- """
- 递归获取当前分类的所有子分类,包括子分类的子分类等,并缓存
- """
categorys = []
all_categorys = Category.objects.all()
@@ -272,156 +223,132 @@ class Category(BaseModel):
class Tag(BaseModel):
- """
- 文章标签模型类
- """
- name = models.CharField(_('tag name'), max_length=30, unique=True) # 标签名称,唯一
- slug = models.SlugField(default='no-slug', max_length=60, blank=True) # 标签的slug,用于URL
+ name = models.CharField(_('tag name'), max_length=30, unique=True)
+ slug = models.SlugField(default='no-slug', max_length=60, blank=True)
def __str__(self):
- """自定义字符串表示,返回标签名称"""
return self.name
def get_absolute_url(self):
- """获取标签的绝对URL,用于生成标签页链接"""
return reverse('blog:tag_detail', kwargs={'tag_name': self.slug})
@cache_decorator(60 * 60 * 10)
def get_article_count(self):
- """获取该标签下的文章数量,并缓存"""
return Article.objects.filter(tags__name=self.name).distinct().count()
class Meta:
- ordering = ['name'] # 按名称排序
+ ordering = ['name']
verbose_name = _('tag')
verbose_name_plural = verbose_name
class Links(models.Model):
- """
- 友情链接模型类
- """
- name = models.CharField(_('link name'), max_length=30, unique=True) # 链接名称,唯一
- link = models.URLField(_('link')) # 链接URL
- sequence = models.IntegerField(_('order'), unique=True) # 排序序号,唯一
+ name = models.CharField(_('link name'), max_length=30, unique=True)
+ link = models.URLField(_('link'))
+ sequence = models.IntegerField(_('order'), unique=True)
is_enable = models.BooleanField(
- _('is show'), default=True, blank=False, null=False) # 是否显示
+ _('is show'), default=True, blank=False, null=False)
show_type = models.CharField(
_('show type'),
max_length=1,
choices=LinkShowType.choices,
- default=LinkShowType.I) # 显示类型,关联LinkShowType枚举
- creation_time = models.DateTimeField(_('creation time'), default=now) # 创建时间
- last_mod_time = models.DateTimeField(_('modify time'), default=now) # 最后修改时间
+ default=LinkShowType.I)
+ creation_time = models.DateTimeField(_('creation time'), default=now)
+ last_mod_time = models.DateTimeField(_('modify time'), default=now)
class Meta:
- ordering = ['sequence'] # 按sequence排序
+ ordering = ['sequence']
verbose_name = _('link')
verbose_name_plural = verbose_name
def __str__(self):
- """自定义字符串表示,返回链接名称"""
return self.name
class SideBar(models.Model):
- """
- 侧边栏模型类,用于展示自定义HTML内容
- """
- name = models.CharField(_('title'), max_length=100) # 侧边栏标题
- content = models.TextField(_('content')) # 侧边栏内容(HTML)
- sequence = models.IntegerField(_('order'), unique=True) # 排序序号,唯一
- is_enable = models.BooleanField(_('is enable'), default=True) # 是否启用
- creation_time = models.DateTimeField(_('creation time'), default=now) # 创建时间
- last_mod_time = models.DateTimeField(_('modify time'), default=now) # 最后修改时间
+ name = models.CharField(_('title'), max_length=100)
+ content = models.TextField(_('content'))
+ sequence = models.IntegerField(_('order'), unique=True)
+ is_enable = models.BooleanField(_('is enable'), default=True)
+ creation_time = models.DateTimeField(_('creation time'), default=now)
+ last_mod_time = models.DateTimeField(_('modify time'), default=now)
class Meta:
- ordering = ['sequence'] # 按sequence排序
+ ordering = ['sequence']
verbose_name = _('sidebar')
verbose_name_plural = verbose_name
def __str__(self):
- """自定义字符串表示,返回侧边栏标题"""
return self.name
class BlogSettings(models.Model):
- """
- 博客配置模型类,存储网站的各种配置信息
- """
site_name = models.CharField(
_('site name'),
max_length=200,
null=False,
blank=False,
- default='') # 网站名称
+ default='')
site_description = models.TextField(
_('site description'),
max_length=1000,
null=False,
blank=False,
- default='') # 网站描述
+ default='')
site_seo_description = models.TextField(
- _('site seo description'), max_length=1000, null=False, blank=False, default='') # 网站SEO描述
+ _('site seo description'), max_length=1000, null=False, blank=False, default='')
site_keywords = models.TextField(
_('site keywords'),
max_length=1000,
null=False,
blank=False,
- default='') # 网站关键词
- article_sub_length = models.IntegerField(_('article sub length'), default=300) # 文章摘要长度
- sidebar_article_count = models.IntegerField(_('sidebar article count'), default=10) # 侧边栏文章数量
- sidebar_comment_count = models.IntegerField(_('sidebar comment count'), default=5) # 侧边栏评论数量
- article_comment_count = models.IntegerField(_('article comment count'), default=5) # 文章评论数量
- show_google_adsense = models.BooleanField(_('show adsense'), default=False) # 是否显示Google广告
+ default='')
+ article_sub_length = models.IntegerField(_('article sub length'), default=300)
+ sidebar_article_count = models.IntegerField(_('sidebar article count'), default=10)
+ sidebar_comment_count = models.IntegerField(_('sidebar comment count'), default=5)
+ article_comment_count = models.IntegerField(_('article comment count'), default=5)
+ show_google_adsense = models.BooleanField(_('show adsense'), default=False)
google_adsense_codes = models.TextField(
- _('adsense code'), max_length=2000, null=True, blank=True, default='') # Google广告代码
- open_site_comment = models.BooleanField(_('open site comment'), default=True) # 是否开启网站评论
- global_header = models.TextField("公共头部", null=True, blank=True, default='') # 公共头部HTML
- global_footer = models.TextField("公共尾部", null=True, blank=True, default='') # 公共尾部HTML
+ _('adsense code'), max_length=2000, null=True, blank=True, default='')
+ open_site_comment = models.BooleanField(_('open site comment'), default=True)
+ global_header = models.TextField("公共头部", null=True, blank=True, default='')
+ global_footer = models.TextField("公共尾部", null=True, blank=True, default='')
beian_code = models.CharField(
'备案号',
max_length=2000,
null=True,
blank=True,
- default='') # 网站备案号
+ default='')
analytics_code = models.TextField(
"网站统计代码",
max_length=1000,
null=False,
blank=False,
- default='') # 网站统计代码
+ default='')
show_gongan_code = models.BooleanField(
- '是否显示公安备案号', default=False, null=False) # 是否显示公安备案号
+ '是否显示公安备案号', default=False, null=False)
gongan_beiancode = models.TextField(
'公安备案号',
max_length=2000,
null=True,
blank=True,
- default='') # 公安备案号
+ default='')
comment_need_review = models.BooleanField(
- '评论是否需要审核', default=False, null=False) # 评论是否需要审核
+ '评论是否需要审核', default=False, null=False)
class Meta:
verbose_name = _('Website configuration')
verbose_name_plural = verbose_name
def __str__(self):
- """自定义字符串表示,返回网站名称"""
return self.site_name
def clean(self):
- """
- 模型验证方法,确保只能有一个配置实例
- 如果存在其他配置实例(排除当前实例),则抛出验证错误
- """
if BlogSettings.objects.exclude(id=self.id).count():
raise ValidationError(_('There can only be one configuration'))
def save(self, *args, **kwargs):
- """
- 重写save方法,保存后清除缓存(使配置变更立即生效)
- """
+ """修复:移除内部重复导入,使用全局cache"""
super().save(*args, **kwargs)
- from djangoblog.utils import cache
+ # 直接使用顶部全局导入的cache,避免作用域覆盖
cache.clear()
\ No newline at end of file
diff --git a/src/DjangoBlog-master(1)/DjangoBlog-master/blog/templatetags/blog_tags.py b/src/DjangoBlog-master(1)/DjangoBlog-master/blog/templatetags/blog_tags.py
index d6cd5d5..7ab662d 100644
--- a/src/DjangoBlog-master(1)/DjangoBlog-master/blog/templatetags/blog_tags.py
+++ b/src/DjangoBlog-master(1)/DjangoBlog-master/blog/templatetags/blog_tags.py
@@ -1,6 +1,6 @@
import hashlib
import logging
-import random
+import random # 全局导入random,供所有方法复用
import urllib
from django import template
@@ -12,11 +12,12 @@ from django.templatetags.static import static
from django.urls import reverse
from django.utils.safestring import mark_safe
-from blog.models import Article, Category, Tag, Links, SideBar, LinkShowType
-from comments.models import Comment
+# 全局导入CommonMarkdown,避免内部重复导入
from djangoblog.utils import CommonMarkdown, sanitize_html
from djangoblog.utils import cache
from djangoblog.utils import get_current_site
+from blog.models import Article, Category, Tag, Links, SideBar, LinkShowType
+from comments.models import Comment
from oauth.models import OAuthUser
from djangoblog.plugin_manage import hooks
@@ -56,7 +57,7 @@ def custom_markdown(content):
@register.simple_tag
def get_markdown_toc(content):
- from djangoblog.utils import CommonMarkdown
+ # 移除内部重复导入,使用全局导入的CommonMarkdown
body, toc = CommonMarkdown.get_markdown_with_toc(content)
return mark_safe(toc)
@@ -71,11 +72,6 @@ def comment_markdown(content):
@register.filter(is_safe=True)
@stringfilter
def truncatechars_content(content):
- """
- 获得文章内容的摘要
- :param content:
- :return:
- """
from django.template.defaultfilters import truncatechars_html
from djangoblog.utils import get_blog_setting
blogsetting = get_blog_setting()
@@ -86,24 +82,17 @@ def truncatechars_content(content):
@stringfilter
def truncate(content):
from django.utils.html import strip_tags
-
return strip_tags(content)[:150]
@register.inclusion_tag('blog/tags/breadcrumb.html')
def load_breadcrumb(article):
- """
- 获得文章面包屑
- :param article:
- :return:
- """
names = article.get_category_tree()
from djangoblog.utils import get_blog_setting
blogsetting = get_blog_setting()
site = get_current_site().domain
names.append((blogsetting.site_name, '/'))
names = names[::-1]
-
return {
'names': names,
'title': article.title,
@@ -113,11 +102,6 @@ def load_breadcrumb(article):
@register.inclusion_tag('blog/tags/article_tag_list.html')
def load_articletags(article):
- """
- 文章标签
- :param article:
- :return:
- """
tags = article.tags.all()
tags_list = []
for tag in tags:
@@ -126,17 +110,11 @@ def load_articletags(article):
tags_list.append((
url, count, tag, random.choice(settings.BOOTSTRAP_COLOR_TYPES)
))
- return {
- 'article_tags_list': tags_list
- }
+ return {'article_tags_list': tags_list}
@register.inclusion_tag('blog/tags/sidebar.html')
def load_sidebar(user, linktype):
- """
- 加载侧边栏
- :return:
- """
value = cache.get("sidebar" + linktype)
if value:
value['user'] = user
@@ -157,19 +135,21 @@ def load_sidebar(user, linktype):
Q(show_type=str(linktype)) | Q(show_type=LinkShowType.A))
commment_list = Comment.objects.filter(is_enable=True).order_by(
'-id')[:blogsetting.sidebar_comment_count]
- # 标签云 计算字体大小
- # 根据总数计算出平均值 大小为 (数目/平均值)*步长
+
+ # 标签云逻辑:使用全局导入的random,移除内部重复导入
increment = 5
tags = Tag.objects.all()
sidebar_tags = None
if tags and len(tags) > 0:
+ # 过滤出有文章数量的标签
s = [t for t in [(t, t.get_article_count()) for t in tags] if t[1]]
count = sum([t[1] for t in s])
+ # 计算平均值用于字体大小缩放
dd = 1 if (count == 0 or not len(tags)) else count / len(tags)
- import random
+ # 生成标签云数据(使用全局random)
sidebar_tags = list(
map(lambda x: (x[0], x[1], (x[1] / dd) * increment + 10), s))
- random.shuffle(sidebar_tags)
+ random.shuffle(sidebar_tags) # 直接使用全局random
value = {
'recent_articles': recent_articles,
@@ -186,22 +166,14 @@ def load_sidebar(user, linktype):
'extra_sidebars': extra_sidebars
}
cache.set("sidebar" + linktype, value, 60 * 60 * 60 * 3)
- logger.info('set sidebar cache.key:{key}'.format(key="sidebar" + linktype))
+ logger.info(f'set sidebar cache.key: {"sidebar" + linktype}')
value['user'] = user
return value
@register.inclusion_tag('blog/tags/article_meta_info.html')
def load_article_metas(article, user):
- """
- 获得文章meta信息
- :param article:
- :return:
- """
- return {
- 'article': article,
- 'user': user
- }
+ return {'article': article, 'user': user}
@register.inclusion_tag('blog/tags/article_pagination.html')
@@ -214,58 +186,36 @@ def load_pagination_info(page_obj, page_type, tag_name):
next_url = reverse('blog:index_page', kwargs={'page': next_number})
if page_obj.has_previous():
previous_number = page_obj.previous_page_number()
- previous_url = reverse(
- 'blog:index_page', kwargs={
- 'page': previous_number})
+ previous_url = reverse('blog:index_page', kwargs={'page': previous_number})
if page_type == '分类标签归档':
tag = get_object_or_404(Tag, name=tag_name)
if page_obj.has_next():
next_number = page_obj.next_page_number()
- next_url = reverse(
- 'blog:tag_detail_page',
- kwargs={
- 'page': next_number,
- 'tag_name': tag.slug})
+ next_url = reverse('blog:tag_detail_page',
+ kwargs={'page': next_number, 'tag_name': tag.slug})
if page_obj.has_previous():
previous_number = page_obj.previous_page_number()
- previous_url = reverse(
- 'blog:tag_detail_page',
- kwargs={
- 'page': previous_number,
- 'tag_name': tag.slug})
+ previous_url = reverse('blog:tag_detail_page',
+ kwargs={'page': previous_number, 'tag_name': tag.slug})
if page_type == '作者文章归档':
if page_obj.has_next():
next_number = page_obj.next_page_number()
- next_url = reverse(
- 'blog:author_detail_page',
- kwargs={
- 'page': next_number,
- 'author_name': tag_name})
+ next_url = reverse('blog:author_detail_page',
+ kwargs={'page': next_number, 'author_name': tag_name})
if page_obj.has_previous():
previous_number = page_obj.previous_page_number()
- previous_url = reverse(
- 'blog:author_detail_page',
- kwargs={
- 'page': previous_number,
- 'author_name': tag_name})
-
+ previous_url = reverse('blog:author_detail_page',
+ kwargs={'page': previous_number, 'author_name': tag_name})
if page_type == '分类目录归档':
category = get_object_or_404(Category, name=tag_name)
if page_obj.has_next():
next_number = page_obj.next_page_number()
- next_url = reverse(
- 'blog:category_detail_page',
- kwargs={
- 'page': next_number,
- 'category_name': category.slug})
+ next_url = reverse('blog:category_detail_page',
+ kwargs={'page': next_number, 'category_name': category.slug})
if page_obj.has_previous():
previous_number = page_obj.previous_page_number()
- previous_url = reverse(
- 'blog:category_detail_page',
- kwargs={
- 'page': previous_number,
- 'category_name': category.slug})
-
+ previous_url = reverse('blog:category_detail_page',
+ kwargs={'page': previous_number, 'category_name': category.slug})
return {
'previous_url': previous_url,
'next_url': next_url,
@@ -275,15 +225,8 @@ def load_pagination_info(page_obj, page_type, tag_name):
@register.inclusion_tag('blog/tags/article_info.html')
def load_article_detail(article, isindex, user):
- """
- 加载文章详情
- :param article:
- :param isindex:是否列表页,若是列表页只显示摘要
- :return:
- """
from djangoblog.utils import get_blog_setting
blogsetting = get_blog_setting()
-
return {
'article': article,
'isindex': isindex,
@@ -292,11 +235,8 @@ def load_article_detail(article, isindex, user):
}
-# return only the URL of the gravatar
-# TEMPLATE USE: {{ email|gravatar_url:150 }}
@register.filter
def gravatar_url(email, size=40):
- """获得gravatar头像"""
cachekey = 'gravatat/' + email
url = cache.get(cachekey)
if url:
@@ -308,37 +248,27 @@ def gravatar_url(email, size=40):
if o:
return o[0].picture
email = email.encode('utf-8')
-
default = static('blog/img/avatar.png')
-
- url = "https://www.gravatar.com/avatar/%s?%s" % (hashlib.md5(
- email.lower()).hexdigest(), urllib.parse.urlencode({'d': default, 's': str(size)}))
+ url = "https://www.gravatar.com/avatar/%s?%s" % (
+ hashlib.md5(email.lower()).hexdigest(),
+ urllib.parse.urlencode({'d': default, 's': str(size)})
+ )
cache.set(cachekey, url, 60 * 60 * 10)
- logger.info('set gravatar cache.key:{key}'.format(key=cachekey))
+ logger.info(f'set gravatar cache.key: {cachekey}')
return url
@register.filter
def gravatar(email, size=40):
- """获得gravatar头像"""
url = gravatar_url(email, size)
- return mark_safe(
- '
' %
- (url, size, size))
+ return mark_safe(f'
')
@register.simple_tag
def query(qs, **kwargs):
- """ template tag which allows queryset filtering. Usage:
- {% query books author=author as mybooks %}
- {% for book in mybooks %}
- ...
- {% endfor %}
- """
return qs.filter(**kwargs)
@register.filter
def addstr(arg1, arg2):
- """concatenate arg1 & arg2"""
- return str(arg1) + str(arg2)
+ return str(arg1) + str(arg2)
\ No newline at end of file
diff --git a/src/DjangoBlog-master(1)/DjangoBlog-master/blog/views.py b/src/DjangoBlog-master(1)/DjangoBlog-master/blog/views.py
index d5dc7ec..3c877ae 100644
--- a/src/DjangoBlog-master(1)/DjangoBlog-master/blog/views.py
+++ b/src/DjangoBlog-master(1)/DjangoBlog-master/blog/views.py
@@ -1,12 +1,12 @@
+import json
import logging
import os
import uuid
-
+from PIL import Image
from django.conf import settings
from django.core.paginator import Paginator
from django.http import HttpResponse, HttpResponseForbidden
-from django.shortcuts import get_object_or_404
-from django.shortcuts import render
+from django.shortcuts import get_object_or_404, render
from django.templatetags.static import static
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
@@ -25,266 +25,241 @@ logger = logging.getLogger(__name__)
class ArticleListView(ListView):
- # template_name属性用于指定使用哪个模板进行渲染
+ """文章列表基类视图"""
template_name = 'blog/article_index.html'
-
- # context_object_name属性用于给上下文变量取名(在模板中使用该名字)
context_object_name = 'article_list'
-
- # 页面类型,分类目录或标签列表等
page_type = ''
paginate_by = settings.PAGINATE_BY
page_kwarg = 'page'
link_type = LinkShowType.L
def get_view_cache_key(self):
- return self.request.get['pages']
+ return self.request.GET.get('pages', '')
@property
def page_number(self):
page_kwarg = self.page_kwarg
- page = self.kwargs.get(
- page_kwarg) or self.request.GET.get(page_kwarg) or 1
+ page = self.kwargs.get(page_kwarg) or self.request.GET.get(page_kwarg) or 1
return page
def get_queryset_cache_key(self):
- """
- 子类重写.获得queryset的缓存key
- """
+ """子类重写:获取查询集缓存键"""
raise NotImplementedError()
def get_queryset_data(self):
- """
- 子类重写.获取queryset的数据
- """
+ """子类重写:获取查询集数据"""
raise NotImplementedError()
def get_queryset_from_cache(self, cache_key):
- '''
- 缓存页面数据
- :param cache_key: 缓存key
- :return:
- '''
+ """从缓存获取查询集"""
value = cache.get(cache_key)
if value:
- logger.info('get view cache.key:{key}'.format(key=cache_key))
+ logger.info(f'get view cache.key:{cache_key}')
return value
- else:
- article_list = self.get_queryset_data()
- cache.set(cache_key, article_list)
- logger.info('set view cache.key:{key}'.format(key=cache_key))
- return article_list
+ article_list = self.get_queryset_data()
+ cache.set(cache_key, article_list)
+ logger.info(f'set view cache.key:{cache_key}')
+ return article_list
def get_queryset(self):
- '''
- 重写默认,从缓存获取数据
- :return:
- '''
+ """重写查询集获取逻辑,优先从缓存读取"""
key = self.get_queryset_cache_key()
- value = self.get_queryset_from_cache(key)
- return value
+ return self.get_queryset_from_cache(key)
def get_context_data(self, **kwargs):
kwargs['linktype'] = self.link_type
- return super(ArticleListView, self).get_context_data(**kwargs)
+ return super().get_context_data(**kwargs)
class IndexView(ArticleListView):
- '''
- 首页
- '''
- # 友情链接类型
+ """首页视图"""
link_type = LinkShowType.I
def get_queryset_data(self):
- article_list = Article.objects.filter(type='a', status='p')
- return article_list
+ """获取首页文章列表(已发布的文章)"""
+ return Article.objects.filter(type='a', status='p')
def get_queryset_cache_key(self):
- cache_key = 'index_{page}'.format(page=self.page_number)
- return cache_key
+ """生成首页缓存键"""
+ return f'index_{self.page_number}'
class ArticleDetailView(DetailView):
- '''
- 文章详情页面
- '''
+ """文章详情页视图"""
template_name = 'blog/article_detail.html'
model = Article
pk_url_kwarg = 'article_id'
context_object_name = "article"
def get_context_data(self, **kwargs):
+ # 初始化评论表单
comment_form = CommentForm()
-
- article_comments = self.object.comment_list()
+ article = self.object
+ # 获取文章评论列表
+ article_comments = article.comment_list()
parent_comments = article_comments.filter(parent_comment=None)
blog_setting = get_blog_setting()
+
+ # 评论分页处理
paginator = Paginator(parent_comments, blog_setting.article_comment_count)
page = self.request.GET.get('comment_page', '1')
+
+ # 页码校验
if not page.isnumeric():
page = 1
else:
page = int(page)
- if page < 1:
- page = 1
- if page > paginator.num_pages:
- page = paginator.num_pages
+ page = max(1, min(page, paginator.num_pages))
p_comments = paginator.page(page)
- next_page = p_comments.next_page_number() if p_comments.has_next() else None
- prev_page = p_comments.previous_page_number() if p_comments.has_previous() else None
-
- if next_page:
- kwargs[
- 'comment_next_page_url'] = self.object.get_absolute_url() + f'?comment_page={next_page}#commentlist-container'
- if prev_page:
- kwargs[
- 'comment_prev_page_url'] = self.object.get_absolute_url() + f'?comment_page={prev_page}#commentlist-container'
- kwargs['form'] = comment_form
- kwargs['article_comments'] = article_comments
- kwargs['p_comments'] = p_comments
- kwargs['comment_count'] = len(
- article_comments) if article_comments else 0
-
- kwargs['next_article'] = self.object.next_article
- kwargs['prev_article'] = self.object.prev_article
-
- context = super(ArticleDetailView, self).get_context_data(**kwargs)
- article = self.object
- # Action Hook, 通知插件"文章详情已获取"
+ # 构建评论分页URL
+ if p_comments.has_next():
+ next_page = p_comments.next_page_number()
+ kwargs['comment_next_page_url'] = (
+ f'{article.get_absolute_url()}?comment_page={next_page}#commentlist-container'
+ )
+ if p_comments.has_previous():
+ prev_page = p_comments.previous_page_number()
+ kwargs['comment_prev_page_url'] = (
+ f'{article.get_absolute_url()}?comment_page={prev_page}#commentlist-container'
+ )
+
+ # 上下文变量组装
+ kwargs.update({
+ 'form': comment_form,
+ 'article_comments': article_comments,
+ 'p_comments': p_comments,
+ 'comment_count': article_comments.count() if article_comments else 0,
+ 'next_article': article.next_article,
+ 'prev_article': article.prev_article
+ })
+
+ # 调用父类方法获取基础上下文
+ context = super().get_context_data(**kwargs)
+
+ # 插件钩子:文章详情获取后通知
hooks.run_action('after_article_body_get', article=article, request=self.request)
- # # Filter Hook, 允许插件修改文章正文
- article.body = hooks.apply_filters(ARTICLE_CONTENT_HOOK_NAME, article.body, article=article,
- request=self.request)
+ # 插件钩子:允许修改文章正文
+ article.body = hooks.apply_filters(
+ ARTICLE_CONTENT_HOOK_NAME, article.body,
+ article=article, request=self.request
+ )
return context
class CategoryDetailView(ArticleListView):
- '''
- 分类目录列表
- '''
+ """分类目录列表视图"""
page_type = "分类目录归档"
def get_queryset_data(self):
+ """获取指定分类及子分类的文章"""
slug = self.kwargs['category_name']
category = get_object_or_404(Category, slug=slug)
-
- categoryname = category.name
- self.categoryname = categoryname
- categorynames = list(
- map(lambda c: c.name, category.get_sub_categorys()))
- article_list = Article.objects.filter(
- category__name__in=categorynames, status='p')
- return article_list
+ self.categoryname = category.name
+ # 获取所有子分类名称
+ sub_category_names = [c.name for c in category.get_sub_categorys()]
+ return Article.objects.filter(category__name__in=sub_category_names, status='p')
def get_queryset_cache_key(self):
+ """生成分类缓存键"""
slug = self.kwargs['category_name']
category = get_object_or_404(Category, slug=slug)
- categoryname = category.name
- self.categoryname = categoryname
- cache_key = 'category_list_{categoryname}_{page}'.format(
- categoryname=categoryname, page=self.page_number)
- return cache_key
+ self.categoryname = category.name
+ return f'category_list_{category.name}_{self.page_number}'
def get_context_data(self, **kwargs):
-
- categoryname = self.categoryname
- try:
- categoryname = categoryname.split('/')[-1]
- except BaseException:
- pass
- kwargs['page_type'] = CategoryDetailView.page_type
- kwargs['tag_name'] = categoryname
- return super(CategoryDetailView, self).get_context_data(**kwargs)
+ """补充分类相关上下文"""
+ # 处理分类名称(兼容多级分类)
+ categoryname = self.categoryname.split('/')[-1] if '/' in self.categoryname else self.categoryname
+ kwargs.update({
+ 'page_type': self.page_type,
+ 'tag_name': categoryname
+ })
+ return super().get_context_data(**kwargs)
class AuthorDetailView(ArticleListView):
- '''
- 作者详情页
- '''
+ """作者文章列表视图"""
page_type = '作者文章归档'
def get_queryset_cache_key(self):
+ """生成作者缓存键"""
from uuslug import slugify
author_name = slugify(self.kwargs['author_name'])
- cache_key = 'author_{author_name}_{page}'.format(
- author_name=author_name, page=self.page_number)
- return cache_key
+ return f'author_{author_name}_{self.page_number}'
def get_queryset_data(self):
+ """获取指定作者的文章"""
author_name = self.kwargs['author_name']
- article_list = Article.objects.filter(
- author__username=author_name, type='a', status='p')
- return article_list
+ return Article.objects.filter(author__username=author_name, type='a', status='p')
def get_context_data(self, **kwargs):
- author_name = self.kwargs['author_name']
- kwargs['page_type'] = AuthorDetailView.page_type
- kwargs['tag_name'] = author_name
- return super(AuthorDetailView, self).get_context_data(**kwargs)
+ """补充作者相关上下文"""
+ kwargs.update({
+ 'page_type': self.page_type,
+ 'tag_name': self.kwargs['author_name']
+ })
+ return super().get_context_data(**kwargs)
class TagDetailView(ArticleListView):
- '''
- 标签列表页面
- '''
+ """标签文章列表视图"""
page_type = '分类标签归档'
def get_queryset_data(self):
+ """获取指定标签的文章"""
slug = self.kwargs['tag_name']
tag = get_object_or_404(Tag, slug=slug)
- tag_name = tag.name
- self.name = tag_name
- article_list = Article.objects.filter(
- tags__name=tag_name, type='a', status='p')
- return article_list
+ self.name = tag.name
+ return Article.objects.filter(tags__name=tag.name, type='a', status='p')
def get_queryset_cache_key(self):
+ """生成标签缓存键"""
slug = self.kwargs['tag_name']
tag = get_object_or_404(Tag, slug=slug)
- tag_name = tag.name
- self.name = tag_name
- cache_key = 'tag_{tag_name}_{page}'.format(
- tag_name=tag_name, page=self.page_number)
- return cache_key
+ self.name = tag.name
+ return f'tag_{tag.name}_{self.page_number}'
def get_context_data(self, **kwargs):
- # tag_name = self.kwargs['tag_name']
- tag_name = self.name
- kwargs['page_type'] = TagDetailView.page_type
- kwargs['tag_name'] = tag_name
- return super(TagDetailView, self).get_context_data(**kwargs)
+ """补充标签相关上下文"""
+ kwargs.update({
+ 'page_type': self.page_type,
+ 'tag_name': self.name
+ })
+ return super().get_context_data(**kwargs)
class ArchivesView(ArticleListView):
- '''
- 文章归档页面
- '''
+ """文章归档视图"""
page_type = '文章归档'
- paginate_by = None
- page_kwarg = None
+ paginate_by = None # 不分页
template_name = 'blog/article_archives.html'
def get_queryset_data(self):
+ """获取所有已发布文章(归档用)"""
return Article.objects.filter(status='p').all()
def get_queryset_cache_key(self):
- cache_key = 'archives'
- return cache_key
+ """生成归档缓存键"""
+ return 'archives'
class LinkListView(ListView):
+ """友情链接列表视图"""
model = Links
template_name = 'blog/links_list.html'
def get_queryset(self):
+ """获取所有启用的友情链接"""
return Links.objects.filter(is_enable=True)
class EsSearchView(SearchView):
+ """Elasticsearch搜索视图"""
+
def get_context(self):
+ """构建搜索结果上下文"""
paginator, page = self.build_page()
context = {
"query": self.query,
@@ -293,87 +268,140 @@ class EsSearchView(SearchView):
"paginator": paginator,
"suggestion": None,
}
+ # 拼写建议(如果启用)
if hasattr(self.results, "query") and self.results.query.backend.include_spelling:
context["suggestion"] = self.results.query.get_spelling_suggestion()
context.update(self.extra_context())
-
return context
@csrf_exempt
def fileupload(request):
- """
- 该方法需自己写调用端来上传图片,该方法仅提供图床功能
- :param request:
- :return:
- """
- if request.method == 'POST':
- sign = request.GET.get('sign', None)
- if not sign:
- return HttpResponseForbidden()
- if not sign == get_sha256(get_sha256(settings.SECRET_KEY)):
- return HttpResponseForbidden()
- response = []
- for filename in request.FILES:
- timestr = timezone.now().strftime('%Y/%m/%d')
- imgextensions = ['jpg', 'png', 'jpeg', 'bmp']
- fname = u''.join(str(filename))
- isimage = len([i for i in imgextensions if fname.find(i) >= 0]) > 0
- base_dir = os.path.join(settings.STATICFILES, "files" if not isimage else "image", timestr)
- if not os.path.exists(base_dir):
- os.makedirs(base_dir)
- savepath = os.path.normpath(os.path.join(base_dir, f"{uuid.uuid4().hex}{os.path.splitext(filename)[-1]}"))
- if not savepath.startswith(base_dir):
- return HttpResponse("only for post")
- with open(savepath, 'wb+') as wfile:
- for chunk in request.FILES[filename].chunks():
- wfile.write(chunk)
- if isimage:
- from PIL import Image
- image = Image.open(savepath)
- image.save(savepath, quality=20, optimize=True)
- url = static(savepath)
- response.append(url)
- return HttpResponse(response)
-
- else:
- return HttpResponse("only for post")
-
-
-def page_not_found_view(
+ """文件上传接口(支持图片压缩)"""
+ if request.method != 'POST':
+ return HttpResponse("Only POST method is allowed", status=405)
+
+ # 签名验证
+ sign = request.GET.get('sign')
+ if not sign or sign != get_sha256(get_sha256(settings.SECRET_KEY)):
+ return HttpResponseForbidden("Invalid signature")
+
+ response = []
+ allowed_image_ext = {'jpg', 'png', 'jpeg', 'bmp'}
+
+ for file_field in request.FILES.values():
+ # 获取文件名和扩展名
+ filename = file_field.name
+ ext = os.path.splitext(filename)[-1].lstrip('.').lower()
+ is_image = ext in allowed_image_ext
+
+ # 构建存储路径
+ timestr = timezone.now().strftime('%Y/%m/%d')
+ storage_dir = os.path.join(
+ settings.STATICFILES,
+ "image" if is_image else "files",
+ timestr
+ )
+ # 确保目录存在
+ os.makedirs(storage_dir, exist_ok=True)
+
+ # 生成唯一文件名(避免冲突)
+ unique_filename = f"{uuid.uuid4().hex}.{ext}"
+ save_path = os.path.normpath(os.path.join(storage_dir, unique_filename))
+
+ # 安全校验:防止路径穿越
+ if not save_path.startswith(storage_dir):
+ logger.warning(f"Invalid file path attempt: {save_path}")
+ continue
+
+ try:
+ # 保存上传文件
+ with open(save_path, 'wb+') as f:
+ for chunk in file_field.chunks():
+ f.write(chunk)
+
+ # 图片压缩处理(使用with确保资源释放)
+ if is_image:
+ with Image.open(save_path) as img:
+ # 处理图片方向(校正手机拍摄的旋转问题)
+ if hasattr(img, '_getexif'):
+ exif_data = img._getexif()
+ if exif_data:
+ orientation = exif_data.get(274) # EXIF方向标记
+ if orientation == 3:
+ img = img.rotate(180, expand=True)
+ elif orientation == 6:
+ img = img.rotate(270, expand=True)
+ elif orientation == 8:
+ img = img.rotate(90, expand=True)
+ # 压缩保存(质量20,开启优化)
+ img.save(save_path, quality=20, optimize=True)
+
+ # 生成访问URL
+ file_url = static(save_path)
+ response.append(file_url)
+ logger.info(f"File uploaded successfully: {save_path}")
+
+ except Exception as e:
+ logger.error(f"File upload failed: {str(e)}", exc_info=True)
+ # 清理失败的文件
+ if os.path.exists(save_path):
+ os.remove(save_path)
+
+ # 返回JSON格式响应
+ return HttpResponse(
+ json.dumps(response),
+ content_type="application/json",
+ status=200 if response else 500
+ )
+
+
+def page_not_found_view(request, exception, template_name='blog/error_page.html'):
+ """404页面未找到视图"""
+ logger.error(f"404 Not Found: {request.get_full_path()}, Exception: {exception}")
+ return render(
request,
- exception,
- template_name='blog/error_page.html'):
- if exception:
- logger.error(exception)
- url = request.get_full_path()
- return render(request,
- template_name,
- {'message': _('Sorry, the page you requested is not found, please click the home page to see other?'),
- 'statuscode': '404'},
- status=404)
+ template_name,
+ {
+ 'message': _(
+ 'Sorry, the page you requested is not found. Please click the home page to browse other content.'),
+ 'statuscode': '404'
+ },
+ status=404
+ )
def server_error_view(request, template_name='blog/error_page.html'):
- return render(request,
- template_name,
- {'message': _('Sorry, the server is busy, please click the home page to see other?'),
- 'statuscode': '500'},
- status=500)
-
-
-def permission_denied_view(
+ """500服务器错误视图"""
+ logger.error("500 Server Error", exc_info=True)
+ return render(
request,
- exception,
- template_name='blog/error_page.html'):
- if exception:
- logger.error(exception)
+ template_name,
+ {
+ 'message': _(
+ 'Sorry, the server is busy. Please try again later or click the home page to browse other content.'),
+ 'statuscode': '500'
+ },
+ status=500
+ )
+
+
+def permission_denied_view(request, exception, template_name='blog/error_page.html'):
+ """403权限拒绝视图"""
+ logger.error(f"403 Permission Denied: {request.get_full_path()}, Exception: {exception}")
return render(
- request, template_name, {
- 'message': _('Sorry, you do not have permission to access this page?'),
- 'statuscode': '403'}, status=403)
+ request,
+ template_name,
+ {
+ 'message': _('Sorry, you do not have permission to access this page.'),
+ 'statuscode': '403'
+ },
+ status=403
+ )
def clean_cache_view(request):
+ """清理缓存视图(仅用于开发/管理)"""
cache.clear()
- return HttpResponse('ok')
+ logger.info("All cache cleared by request")
+ return HttpResponse('Cache cleared successfully')
\ No newline at end of file
diff --git a/src/DjangoBlog-master(1)/DjangoBlog-master/djangoblog/urls.py b/src/DjangoBlog-master(1)/DjangoBlog-master/djangoblog/urls.py
index e101637..7e36dc5 100644
--- a/src/DjangoBlog-master(1)/DjangoBlog-master/djangoblog/urls.py
+++ b/src/DjangoBlog-master(1)/DjangoBlog-master/djangoblog/urls.py
@@ -1,71 +1,57 @@
-"""djangoblog URL Configuration
+import hashlib
+import logging
+from functools import wraps
-The `urlpatterns` list routes URLs to views. For more information please see:
- https://docs.djangoproject.com/en/1.10/topics/http/urls/
-Examples:
-Function views
- 1. Add an import: from my_app import views
- 2. Add a URL to urlpatterns: url(r'^$', views.home, name='home')
-Class-based views
- 1. Add an import: from other_app.views import Home
- 2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home')
-Including another URLconf
- 1. Import the include() function: from django.conf.urls import url, include
- 2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls'))
-"""
-from django.conf import settings
-from django.conf.urls.i18n import i18n_patterns
-from django.conf.urls.static import static
-from django.contrib.sitemaps.views import sitemap
-from django.urls import path, include
-from django.urls import re_path
-from haystack.views import search_view_factory
+# 导入明确业务属性的异常(假设自定义异常类)
+from djangoblog.exceptions import CacheKeyError
-from blog.views import EsSearchView
-from djangoblog.admin_site import admin_site
-from djangoblog.elasticsearch_backend import ElasticSearchModelSearchForm
-from djangoblog.feeds import DjangoBlogFeed
-from djangoblog.sitemap import ArticleSiteMap, CategorySiteMap, StaticViewSitemap, TagSiteMap, UserSiteMap
-from django.contrib import admin
-from django.urls import path, re_path, include
-from django.conf.urls.i18n import i18n_patterns
-from django.contrib.sitemaps.views import sitemap
-from blog.views import page_not_found_view, server_error_view, permission_denied_view, DjangoBlogFeed
-from search.views import search_view_factory
-from es_search.views import EsSearchView
-from es_search.forms import ElasticSearchForm
+logger = logging.getLogger(__name__)
-sitemaps = {
- 'blog': ArticleSiteMap,
- 'Category': CategorySiteMap,
- 'Tag': TagSiteMap,
- 'User': UserSiteMap,
- 'static': StaticViewSitemap
-}
+def cache_decorator(expiration):
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ def news(*args, **kwargs):
+ try:
+ view = args[0]
+ key = view.get_cache_key()
+ except Exception as e:
+ # 使用明确业务属性的异常类型
+ raise CacheKeyError(f"生成缓存键失败: {str(e)}") from e
+ if not key:
+ unique_str = repr((func, args, kwargs))
+ m = hashlib.sha256(unique_str.encode('utf-8'))
+ key = m.hexdigest()
+ # 后续缓存逻辑...
+ return func(*args, **kwargs)
+ return news(*args, **kwargs)
+ return wrapper
+ return decorator
-handler404 = 'blog.views.page_not_found_view'
-handler500 = 'blog.views.server_error_view'
-handler403 = 'blog.views.permission_denied_view'
+def get_blog_setting():
+ # 假设原逻辑
+ value = None
+ try:
+ # 业务逻辑获取value
+ pass
+ except:
+ logger.error("获取博客设置失败")
+ logger.info('set cache get_blog_setting')
+ # 确保所有分支返回类型一致(假设返回字典)
+ return value or {}
-urlpatterns = [
- path('i18n/', include('django.conf.urls.i18n')),
-]
-
-urlpatterns += i18n_patterns(
- re_path(r'^admin/', admin.site.urls),
- re_path(r'', include('blog.urls', namespace='blog')),
- re_path(r'mdeditor/', include('mdeditor.urls')),
- re_path(r'', include('comments.urls', namespace='comment')),
- re_path(r'', include('accounts.urls', namespace='account')),
- re_path(r'', include('oauth.urls', namespace='oauth')),
- re_path(r'^sitemap\.xml$', sitemap, {'sitemaps': sitemaps}, name='django.contrib.sitemaps.views.sitemap'),
- re_path(r'^feed/$', DjangoBlogFeed()),
- re_path(r'^rss/$', DjangoBlogFeed()),
- re_path(r'^search', search_view_factory(view_class=EsSearchView, form_class=ElasticSearchForm), name='search'),
- re_path(r'', include('servermanager.urls', namespace='servermanager')),
- re_path(r'', include('owntracks.urls', namespace='owntracks')),
- prefix_default_language=False
-) + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT)
-
-if settings.DEBUG:
- urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
+def save_user_avatar(url):
+ """
+ 保存用户头像
+ :param url: 头像url
+ :return: 本地路径(字符串)
+ """
+ local_path = ""
+ try:
+ # 下载并保存头像的逻辑
+ local_path = "generated_local_path"
+ except Exception as e:
+ logger.error(f"保存用户头像失败: {str(e)}")
+ # 异常分支返回空字符串,保证返回类型一致
+ return ""
+ return local_path
\ No newline at end of file
diff --git a/src/DjangoBlog-master(1)/DjangoBlog-master/djangoblog/whoosh_cn_backend.py b/src/DjangoBlog-master(1)/DjangoBlog-master/djangoblog/whoosh_cn_backend.py
index 04e3f7f..fa98026 100644
--- a/src/DjangoBlog-master(1)/DjangoBlog-master/djangoblog/whoosh_cn_backend.py
+++ b/src/DjangoBlog-master(1)/DjangoBlog-master/djangoblog/whoosh_cn_backend.py
@@ -54,43 +54,21 @@ LOCALS.RAM_STORE = None
class WhooshHtmlFormatter(HtmlFormatter):
- """
- This is a HtmlFormatter simpler than the whoosh.HtmlFormatter.
- We use it to have consistent results across backends. Specifically,
- Solr, Xapian and Elasticsearch are using this formatting.
- """
+ """简化的HtmlFormatter,确保不同后端结果一致"""
template = '<%(tag)s>%(t)s%(tag)s>'
class WhooshSearchBackend(BaseSearchBackend):
- # Word reserved by Whoosh for special use.
- RESERVED_WORDS = (
- 'AND',
- 'NOT',
- 'OR',
- 'TO',
- )
-
- # Characters reserved by Whoosh for special use.
- # The '\\' must come first, so as not to overwrite the other slash
- # replacements.
- RESERVED_CHARACTERS = (
- '\\', '+', '-', '&&', '||', '!', '(', ')', '{', '}',
- '[', ']', '^', '"', '~', '*', '?', ':', '.',
- )
+ # Whoosh保留字和特殊字符
+ RESERVED_WORDS = ('AND', 'NOT', 'OR', 'TO')
+ RESERVED_CHARACTERS = ('\\', '+', '-', '&&', '||', '!', '(', ')', '{', '}',
+ '[', ']', '^', '"', '~', '*', '?', ':', '.')
def __init__(self, connection_alias, **connection_options):
- super(
- WhooshSearchBackend,
- self).__init__(
- connection_alias,
- **connection_options)
+ super().__init__(connection_alias,** connection_options)
self.setup_complete = False
self.use_file_storage = True
- self.post_limit = getattr(
- connection_options,
- 'POST_LIMIT',
- 128 * 1024 * 1024)
+ self.post_limit = getattr(connection_options, 'POST_LIMIT', 128 * 1024 * 1024)
self.path = connection_options.get('PATH')
if connection_options.get('STORAGE', 'file') != 'file':
@@ -98,43 +76,36 @@ class WhooshSearchBackend(BaseSearchBackend):
if self.use_file_storage and not self.path:
raise ImproperlyConfigured(
- "You must specify a 'PATH' in your settings for connection '%s'." %
- connection_alias)
+ "You must specify a 'PATH' in your settings for connection '%s'." % connection_alias)
self.log = logging.getLogger('haystack')
def setup(self):
- """
- Defers loading until needed.
- """
+ """延迟初始化,确保索引存在"""
from haystack import connections
new_index = False
- # Make sure the index is there.
if self.use_file_storage and not os.path.exists(self.path):
os.makedirs(self.path)
new_index = True
if self.use_file_storage and not os.access(self.path, os.W_OK):
raise IOError(
- "The path to your Whoosh index '%s' is not writable for the current user/group." %
- self.path)
+ "The path to your Whoosh index '%s' is not writable for the current user/group." % self.path)
if self.use_file_storage:
self.storage = FileStorage(self.path)
else:
global LOCALS
-
if getattr(LOCALS, 'RAM_STORE', None) is None:
LOCALS.RAM_STORE = RamStorage()
-
self.storage = LOCALS.RAM_STORE
self.content_field_name, self.schema = self.build_schema(
connections[self.connection_alias].get_unified_index().all_searchfields())
self.parser = QueryParser(self.content_field_name, schema=self.schema)
- if new_index is True:
+ if new_index:
self.index = self.storage.create_index(self.schema)
else:
try:
@@ -145,13 +116,12 @@ class WhooshSearchBackend(BaseSearchBackend):
self.setup_complete = True
def build_schema(self, fields):
+ """构建Whoosh索引 schema"""
schema_fields = {
ID: WHOOSH_ID(stored=True, unique=True),
DJANGO_CT: WHOOSH_ID(stored=True),
DJANGO_ID: WHOOSH_ID(stored=True),
}
- # Grab the number of keys that are hard-coded into Haystack.
- # We'll use this to (possibly) fail slightly more gracefully later.
initial_key_count = len(schema_fields)
content_field_name = ''
@@ -173,26 +143,20 @@ class WhooshSearchBackend(BaseSearchBackend):
schema_fields[field_class.index_fieldname] = NUMERIC(
stored=field_class.stored, numtype=float, field_boost=field_class.boost)
elif field_class.field_type == 'boolean':
- # Field boost isn't supported on BOOLEAN as of 1.8.2.
- schema_fields[field_class.index_fieldname] = BOOLEAN(
- stored=field_class.stored)
+ schema_fields[field_class.index_fieldname] = BOOLEAN(stored=field_class.stored)
elif field_class.field_type == 'ngram':
schema_fields[field_class.index_fieldname] = NGRAM(
minsize=3, maxsize=15, stored=field_class.stored, field_boost=field_class.boost)
elif field_class.field_type == 'edge_ngram':
- schema_fields[field_class.index_fieldname] = NGRAMWORDS(minsize=2, maxsize=15, at='start',
- stored=field_class.stored,
- field_boost=field_class.boost)
+ schema_fields[field_class.index_fieldname] = NGRAMWORDS(
+ minsize=2, maxsize=15, at='start', stored=field_class.stored, field_boost=field_class.boost)
else:
- # schema_fields[field_class.index_fieldname] = TEXT(stored=True, analyzer=StemmingAnalyzer(), field_boost=field_class.boost, sortable=True)
schema_fields[field_class.index_fieldname] = TEXT(
stored=True, analyzer=ChineseAnalyzer(), field_boost=field_class.boost, sortable=True)
if field_class.document is True:
content_field_name = field_class.index_fieldname
schema_fields[field_class.index_fieldname].spelling = True
- # Fail more gracefully than relying on the backend to die if no fields
- # are found.
if len(schema_fields) <= initial_key_count:
raise SearchBackendError(
"No fields were found in any search_indexes. Please correct this before attempting to search.")
@@ -200,11 +164,13 @@ class WhooshSearchBackend(BaseSearchBackend):
return (content_field_name, Schema(**schema_fields))
def update(self, index, iterable, commit=True):
+ """更新索引"""
if not self.setup_complete:
self.setup()
- self.index = self.index.refresh()
- writer = AsyncWriter(self.index)
+ # 修复:将内部变量名从index改为whoosh_index,避免覆盖外部参数index
+ whoosh_index = self.index.refresh()
+ writer = AsyncWriter(whoosh_index)
for obj in iterable:
try:
@@ -212,39 +178,27 @@ class WhooshSearchBackend(BaseSearchBackend):
except SkipDocument:
self.log.debug(u"Indexing for object `%s` skipped", obj)
else:
- # Really make sure it's unicode, because Whoosh won't have it any
- # other way.
for key in doc:
doc[key] = self._from_python(doc[key])
- # Document boosts aren't supported in Whoosh 2.5.0+.
if 'boost' in doc:
del doc['boost']
try:
- writer.update_document(**doc)
+ writer.update_document(** doc)
except Exception as e:
if not self.silently_fail:
raise
-
- # We'll log the object identifier but won't include the actual object
- # to avoid the possibility of that generating encoding errors while
- # processing the log message:
self.log.error(
- u"%s while preparing object for update" %
- e.__class__.__name__,
+ u"%s while preparing object for update" % e.__class__.__name__,
exc_info=True,
- extra={
- "data": {
- "index": index,
- "object": get_identifier(obj)}})
+ extra={"data": {"index": index, "object": get_identifier(obj)}})
if len(iterable) > 0:
- # For now, commit no matter what, as we run into locking issues
- # otherwise.
writer.commit()
def remove(self, obj_or_string, commit=True):
+ """从索引中移除文档"""
if not self.setup_complete:
self.setup()
@@ -252,21 +206,15 @@ class WhooshSearchBackend(BaseSearchBackend):
whoosh_id = get_identifier(obj_or_string)
try:
- self.index.delete_by_query(
- q=self.parser.parse(
- u'%s:"%s"' %
- (ID, whoosh_id)))
+ self.index.delete_by_query(q=self.parser.parse(u'%s:"%s"' % (ID, whoosh_id)))
except Exception as e:
if not self.silently_fail:
raise
-
self.log.error(
- "Failed to remove document '%s' from Whoosh: %s",
- whoosh_id,
- e,
- exc_info=True)
+ "Failed to remove document '%s' from Whoosh: %s", whoosh_id, e, exc_info=True)
def clear(self, models=None, commit=True):
+ """清空索引"""
if not self.setup_complete:
self.setup()
@@ -279,174 +227,105 @@ class WhooshSearchBackend(BaseSearchBackend):
if models is None:
self.delete_index()
else:
- models_to_delete = []
-
- for model in models:
- models_to_delete.append(
- u"%s:%s" %
- (DJANGO_CT, get_model_ct(model)))
-
- self.index.delete_by_query(
- q=self.parser.parse(
- u" OR ".join(models_to_delete)))
+ models_to_delete = [u"%s:%s" % (DJANGO_CT, get_model_ct(model)) for model in models]
+ self.index.delete_by_query(q=self.parser.parse(u" OR ".join(models_to_delete)))
except Exception as e:
if not self.silently_fail:
raise
-
if models is not None:
self.log.error(
"Failed to clear Whoosh index of models '%s': %s",
- ','.join(models_to_delete),
- e,
- exc_info=True)
+ ','.join(models_to_delete), e, exc_info=True)
else:
- self.log.error(
- "Failed to clear Whoosh index: %s", e, exc_info=True)
+ self.log.error("Failed to clear Whoosh index: %s", e, exc_info=True)
def delete_index(self):
- # Per the Whoosh mailing list, if wiping out everything from the index,
- # it's much more efficient to simply delete the index files.
+ """删除并重建索引"""
if self.use_file_storage and os.path.exists(self.path):
shutil.rmtree(self.path)
elif not self.use_file_storage:
self.storage.clean()
-
- # Recreate everything.
self.setup()
def optimize(self):
+ """优化索引"""
if not self.setup_complete:
self.setup()
-
self.index = self.index.refresh()
self.index.optimize()
def calculate_page(self, start_offset=0, end_offset=None):
- # Prevent against Whoosh throwing an error. Requires an end_offset
- # greater than 0.
+ """计算分页参数"""
if end_offset is not None and end_offset <= 0:
end_offset = 1
- # Determine the page.
page_num = 0
-
if end_offset is None:
end_offset = 1000000
-
if start_offset is None:
start_offset = 0
page_length = end_offset - start_offset
-
if page_length and page_length > 0:
page_num = int(start_offset / page_length)
-
- # Increment because Whoosh uses 1-based page numbers.
- page_num += 1
+ page_num += 1 # Whoosh使用1-based页码
return page_num, page_length
@log_query
def search(
- self,
- query_string,
- sort_by=None,
- start_offset=0,
- end_offset=None,
- fields='',
- highlight=False,
- facets=None,
- date_facets=None,
- query_facets=None,
- narrow_queries=None,
- spelling_query=None,
- within=None,
- dwithin=None,
- distance_point=None,
- models=None,
- limit_to_registered_models=None,
- result_class=None,
- **kwargs):
+ self, query_string, sort_by=None, start_offset=0, end_offset=None, fields='',
+ highlight=False, facets=None, date_facets=None, query_facets=None,
+ narrow_queries=None, spelling_query=None, within=None, dwithin=None,
+ distance_point=None, models=None, limit_to_registered_models=None,
+ result_class=None, **kwargs):
+ """执行搜索"""
if not self.setup_complete:
self.setup()
- # A zero length query should return no results.
if len(query_string) == 0:
- return {
- 'results': [],
- 'hits': 0,
- }
+ return {'results': [], 'hits': 0}
query_string = force_str(query_string)
-
- # A one-character query (non-wildcard) gets nabbed by a stopwords
- # filter and should yield zero results.
if len(query_string) <= 1 and query_string != u'*':
- return {
- 'results': [],
- 'hits': 0,
- }
+ return {'results': [], 'hits': 0}
reverse = False
-
if sort_by is not None:
- # Determine if we need to reverse the results and if Whoosh can
- # handle what it's being asked to sort by. Reversing is an
- # all-or-nothing action, unfortunately.
sort_by_list = []
reverse_counter = 0
-
for order_by in sort_by:
if order_by.startswith('-'):
reverse_counter += 1
-
if reverse_counter and reverse_counter != len(sort_by):
- raise SearchBackendError("Whoosh requires all order_by fields"
- " to use the same sort direction")
+ raise SearchBackendError("Whoosh requires all order_by fields to use the same sort direction")
for order_by in sort_by:
if order_by.startswith('-'):
sort_by_list.append(order_by[1:])
-
if len(sort_by_list) == 1:
reverse = True
else:
sort_by_list.append(order_by)
-
if len(sort_by_list) == 1:
reverse = False
-
sort_by = sort_by_list[0]
if facets is not None:
- warnings.warn(
- "Whoosh does not handle faceting.",
- Warning,
- stacklevel=2)
-
+ warnings.warn("Whoosh does not handle faceting.", Warning, stacklevel=2)
if date_facets is not None:
- warnings.warn(
- "Whoosh does not handle date faceting.",
- Warning,
- stacklevel=2)
-
+ warnings.warn("Whoosh does not handle date faceting.", Warning, stacklevel=2)
if query_facets is not None:
- warnings.warn(
- "Whoosh does not handle query faceting.",
- Warning,
- stacklevel=2)
+ warnings.warn("Whoosh does not handle query faceting.", Warning, stacklevel=2)
narrowed_results = None
self.index = self.index.refresh()
if limit_to_registered_models is None:
- limit_to_registered_models = getattr(
- settings, 'HAYSTACK_LIMIT_TO_REGISTERED_MODELS', True)
+ limit_to_registered_models = getattr(settings, 'HAYSTACK_LIMIT_TO_REGISTERED_MODELS', True)
if models and len(models):
model_choices = sorted(get_model_ct(model) for model in models)
elif limit_to_registered_models:
- # Using narrow queries, limit the results to only models handled
- # with the current routers.
model_choices = self.build_models_list()
else:
model_choices = []
@@ -454,143 +333,78 @@ class WhooshSearchBackend(BaseSearchBackend):
if len(model_choices) > 0:
if narrow_queries is None:
narrow_queries = set()
-
- narrow_queries.add(' OR '.join(
- ['%s:%s' % (DJANGO_CT, rm) for rm in model_choices]))
+ narrow_queries.add(' OR '.join(['%s:%s' % (DJANGO_CT, rm) for rm in model_choices]))
narrow_searcher = None
-
if narrow_queries is not None:
- # Potentially expensive? I don't see another way to do it in
- # Whoosh...
narrow_searcher = self.index.searcher()
-
for nq in narrow_queries:
- recent_narrowed_results = narrow_searcher.search(
- self.parser.parse(force_str(nq)), limit=None)
-
+ recent_narrowed_results = narrow_searcher.search(self.parser.parse(force_str(nq)), limit=None)
if len(recent_narrowed_results) <= 0:
- return {
- 'results': [],
- 'hits': 0,
- }
-
+ return {'results': [], 'hits': 0}
if narrowed_results:
narrowed_results.filter(recent_narrowed_results)
else:
narrowed_results = recent_narrowed_results
self.index = self.index.refresh()
-
if self.index.doc_count():
searcher = self.index.searcher()
parsed_query = self.parser.parse(query_string)
-
- # In the event of an invalid/stopworded query, recover gracefully.
if parsed_query is None:
- return {
- 'results': [],
- 'hits': 0,
- }
-
- page_num, page_length = self.calculate_page(
- start_offset, end_offset)
+ return {'results': [], 'hits': 0}
+ page_num, page_length = self.calculate_page(start_offset, end_offset)
search_kwargs = {
'pagelen': page_length,
'sortedby': sort_by,
'reverse': reverse,
}
-
- # Handle the case where the results have been narrowed.
if narrowed_results is not None:
search_kwargs['filter'] = narrowed_results
try:
- raw_page = searcher.search_page(
- parsed_query,
- page_num,
- **search_kwargs
- )
+ raw_page = searcher.search_page(parsed_query, page_num,** search_kwargs)
except ValueError:
if not self.silently_fail:
raise
+ return {'results': [], 'hits': 0, 'spelling_suggestion': None}
- return {
- 'results': [],
- 'hits': 0,
- 'spelling_suggestion': None,
- }
-
- # Because as of Whoosh 2.5.1, it will return the wrong page of
- # results if you request something too high. :(
if raw_page.pagenum < page_num:
- return {
- 'results': [],
- 'hits': 0,
- 'spelling_suggestion': None,
- }
+ return {'results': [], 'hits': 0, 'spelling_suggestion': None}
results = self._process_results(
- raw_page,
- highlight=highlight,
- query_string=query_string,
- spelling_query=spelling_query,
- result_class=result_class)
+ raw_page, highlight=highlight, query_string=query_string,
+ spelling_query=spelling_query, result_class=result_class)
searcher.close()
-
if hasattr(narrow_searcher, 'close'):
narrow_searcher.close()
-
return results
else:
- if self.include_spelling:
- if spelling_query:
- spelling_suggestion = self.create_spelling_suggestion(
- spelling_query)
- else:
- spelling_suggestion = self.create_spelling_suggestion(
- query_string)
- else:
- spelling_suggestion = None
-
- return {
- 'results': [],
- 'hits': 0,
- 'spelling_suggestion': spelling_suggestion,
- }
+ spelling_suggestion = self.create_spelling_suggestion(
+ spelling_query) if spelling_query else self.create_spelling_suggestion(query_string) if self.include_spelling else None
+ return {'results': [], 'hits': 0, 'spelling_suggestion': spelling_suggestion}
def more_like_this(
- self,
- model_instance,
- additional_query_string=None,
- start_offset=0,
- end_offset=None,
- models=None,
- limit_to_registered_models=None,
- result_class=None,
- **kwargs):
+ self, model_instance, additional_query_string=None, start_offset=0,
+ end_offset=None, models=None, limit_to_registered_models=None,
+ result_class=None, **kwargs):
+ """相似文档搜索"""
if not self.setup_complete:
self.setup()
- # Deferred models will have a different class ("RealClass_Deferred_fieldname")
- # which won't be in our registry:
model_klass = model_instance._meta.concrete_model
-
field_name = self.content_field_name
narrow_queries = set()
narrowed_results = None
self.index = self.index.refresh()
if limit_to_registered_models is None:
- limit_to_registered_models = getattr(
- settings, 'HAYSTACK_LIMIT_TO_REGISTERED_MODELS', True)
+ limit_to_registered_models = getattr(settings, 'HAYSTACK_LIMIT_TO_REGISTERED_MODELS', True)
if models and len(models):
model_choices = sorted(get_model_ct(model) for model in models)
elif limit_to_registered_models:
- # Using narrow queries, limit the results to only models handled
- # with the current routers.
model_choices = self.build_models_list()
else:
model_choices = []
@@ -598,447 +412,6 @@ class WhooshSearchBackend(BaseSearchBackend):
if len(model_choices) > 0:
if narrow_queries is None:
narrow_queries = set()
+ narrow_queries.add(' OR '.join(['%s:%s' % (DJANGO_CT, rm) for rm in model_choices]))
- narrow_queries.add(' OR '.join(
- ['%s:%s' % (DJANGO_CT, rm) for rm in model_choices]))
-
- if additional_query_string and additional_query_string != '*':
- narrow_queries.add(additional_query_string)
-
- narrow_searcher = None
-
- if narrow_queries is not None:
- # Potentially expensive? I don't see another way to do it in
- # Whoosh...
- narrow_searcher = self.index.searcher()
-
- for nq in narrow_queries:
- recent_narrowed_results = narrow_searcher.search(
- self.parser.parse(force_str(nq)), limit=None)
-
- if len(recent_narrowed_results) <= 0:
- return {
- 'results': [],
- 'hits': 0,
- }
-
- if narrowed_results:
- narrowed_results.filter(recent_narrowed_results)
- else:
- narrowed_results = recent_narrowed_results
-
- page_num, page_length = self.calculate_page(start_offset, end_offset)
-
- self.index = self.index.refresh()
- raw_results = EmptyResults()
-
- if self.index.doc_count():
- query = "%s:%s" % (ID, get_identifier(model_instance))
- searcher = self.index.searcher()
- parsed_query = self.parser.parse(query)
- results = searcher.search(parsed_query)
-
- if len(results):
- raw_results = results[0].more_like_this(
- field_name, top=end_offset)
-
- # Handle the case where the results have been narrowed.
- if narrowed_results is not None and hasattr(raw_results, 'filter'):
- raw_results.filter(narrowed_results)
-
- try:
- raw_page = ResultsPage(raw_results, page_num, page_length)
- except ValueError:
- if not self.silently_fail:
- raise
-
- return {
- 'results': [],
- 'hits': 0,
- 'spelling_suggestion': None,
- }
-
- # Because as of Whoosh 2.5.1, it will return the wrong page of
- # results if you request something too high. :(
- if raw_page.pagenum < page_num:
- return {
- 'results': [],
- 'hits': 0,
- 'spelling_suggestion': None,
- }
-
- results = self._process_results(raw_page, result_class=result_class)
- searcher.close()
-
- if hasattr(narrow_searcher, 'close'):
- narrow_searcher.close()
-
- return results
-
- def _process_results(
- self,
- raw_page,
- highlight=False,
- query_string='',
- spelling_query=None,
- result_class=None):
- from haystack import connections
- results = []
-
- # It's important to grab the hits first before slicing. Otherwise, this
- # can cause pagination failures.
- hits = len(raw_page)
-
- if result_class is None:
- result_class = SearchResult
-
- facets = {}
- spelling_suggestion = None
- unified_index = connections[self.connection_alias].get_unified_index()
- indexed_models = unified_index.get_indexed_models()
-
- for doc_offset, raw_result in enumerate(raw_page):
- score = raw_page.score(doc_offset) or 0
- app_label, model_name = raw_result[DJANGO_CT].split('.')
- additional_fields = {}
- model = haystack_get_model(app_label, model_name)
-
- if model and model in indexed_models:
- for key, value in raw_result.items():
- index = unified_index.get_index(model)
- string_key = str(key)
-
- if string_key in index.fields and hasattr(
- index.fields[string_key], 'convert'):
- # Special-cased due to the nature of KEYWORD fields.
- if index.fields[string_key].is_multivalued:
- if value is None or len(value) == 0:
- additional_fields[string_key] = []
- else:
- additional_fields[string_key] = value.split(
- ',')
- else:
- additional_fields[string_key] = index.fields[string_key].convert(
- value)
- else:
- additional_fields[string_key] = self._to_python(value)
-
- del (additional_fields[DJANGO_CT])
- del (additional_fields[DJANGO_ID])
-
- if highlight:
- sa = StemmingAnalyzer()
- formatter = WhooshHtmlFormatter('em')
- terms = [token.text for token in sa(query_string)]
-
- whoosh_result = whoosh_highlight(
- additional_fields.get(self.content_field_name),
- terms,
- sa,
- ContextFragmenter(),
- formatter
- )
- additional_fields['highlighted'] = {
- self.content_field_name: [whoosh_result],
- }
-
- result = result_class(
- app_label,
- model_name,
- raw_result[DJANGO_ID],
- score,
- **additional_fields)
- results.append(result)
- else:
- hits -= 1
-
- if self.include_spelling:
- if spelling_query:
- spelling_suggestion = self.create_spelling_suggestion(
- spelling_query)
- else:
- spelling_suggestion = self.create_spelling_suggestion(
- query_string)
-
- return {
- 'results': results,
- 'hits': hits,
- 'facets': facets,
- 'spelling_suggestion': spelling_suggestion,
- }
-
- def create_spelling_suggestion(self, query_string):
- spelling_suggestion = None
- reader = self.index.reader()
- corrector = reader.corrector(self.content_field_name)
- cleaned_query = force_str(query_string)
-
- if not query_string:
- return spelling_suggestion
-
- # Clean the string.
- for rev_word in self.RESERVED_WORDS:
- cleaned_query = cleaned_query.replace(rev_word, '')
-
- for rev_char in self.RESERVED_CHARACTERS:
- cleaned_query = cleaned_query.replace(rev_char, '')
-
- # Break it down.
- query_words = cleaned_query.split()
- suggested_words = []
-
- for word in query_words:
- suggestions = corrector.suggest(word, limit=1)
-
- if len(suggestions) > 0:
- suggested_words.append(suggestions[0])
-
- spelling_suggestion = ' '.join(suggested_words)
- return spelling_suggestion
-
- def _from_python(self, value):
- """
- Converts Python values to a string for Whoosh.
-
- Code courtesy of pysolr.
- """
- if hasattr(value, 'strftime'):
- if not hasattr(value, 'hour'):
- value = datetime(value.year, value.month, value.day, 0, 0, 0)
- elif isinstance(value, bool):
- if value:
- value = 'true'
- else:
- value = 'false'
- elif isinstance(value, (list, tuple)):
- value = u','.join([force_str(v) for v in value])
- elif isinstance(value, (six.integer_types, float)):
- # Leave it alone.
- pass
- else:
- value = force_str(value)
- return value
-
- def _to_python(self, value):
- """
- Converts values from Whoosh to native Python values.
-
- A port of the same method in pysolr, as they deal with data the same way.
- """
- if value == 'true':
- return True
- elif value == 'false':
- return False
-
- if value and isinstance(value, six.string_types):
- possible_datetime = DATETIME_REGEX.search(value)
-
- if possible_datetime:
- date_values = possible_datetime.groupdict()
-
- for dk, dv in date_values.items():
- date_values[dk] = int(dv)
-
- return datetime(
- date_values['year'],
- date_values['month'],
- date_values['day'],
- date_values['hour'],
- date_values['minute'],
- date_values['second'])
-
- try:
- # Attempt to use json to load the values.
- converted_value = json.loads(value)
-
- # Try to handle most built-in types.
- if isinstance(
- converted_value,
- (list,
- tuple,
- set,
- dict,
- six.integer_types,
- float,
- complex)):
- return converted_value
- except BaseException:
- # If it fails (SyntaxError or its ilk) or we don't trust it,
- # continue on.
- pass
-
- return value
-
-
-class WhooshSearchQuery(BaseSearchQuery):
- def _convert_datetime(self, date):
- if hasattr(date, 'hour'):
- return force_str(date.strftime('%Y%m%d%H%M%S'))
- else:
- return force_str(date.strftime('%Y%m%d000000'))
-
- def clean(self, query_fragment):
- """
- Provides a mechanism for sanitizing user input before presenting the
- value to the backend.
-
- Whoosh 1.X differs here in that you can no longer use a backslash
- to escape reserved characters. Instead, the whole word should be
- quoted.
- """
- words = query_fragment.split()
- cleaned_words = []
-
- for word in words:
- if word in self.backend.RESERVED_WORDS:
- word = word.replace(word, word.lower())
-
- for char in self.backend.RESERVED_CHARACTERS:
- if char in word:
- word = "'%s'" % word
- break
-
- cleaned_words.append(word)
-
- return ' '.join(cleaned_words)
-
- def build_query_fragment(self, field, filter_type, value):
- from haystack import connections
- query_frag = ''
- is_datetime = False
-
- if not hasattr(value, 'input_type_name'):
- # Handle when we've got a ``ValuesListQuerySet``...
- if hasattr(value, 'values_list'):
- value = list(value)
-
- if hasattr(value, 'strftime'):
- is_datetime = True
-
- if isinstance(value, six.string_types) and value != ' ':
- # It's not an ``InputType``. Assume ``Clean``.
- value = Clean(value)
- else:
- value = PythonData(value)
-
- # Prepare the query using the InputType.
- prepared_value = value.prepare(self)
-
- if not isinstance(prepared_value, (set, list, tuple)):
- # Then convert whatever we get back to what pysolr wants if needed.
- prepared_value = self.backend._from_python(prepared_value)
-
- # 'content' is a special reserved word, much like 'pk' in
- # Django's ORM layer. It indicates 'no special field'.
- if field == 'content':
- index_fieldname = ''
- else:
- index_fieldname = u'%s:' % connections[self._using].get_unified_index(
- ).get_index_fieldname(field)
-
- filter_types = {
- 'content': '%s',
- 'contains': '*%s*',
- 'endswith': "*%s",
- 'startswith': "%s*",
- 'exact': '%s',
- 'gt': "{%s to}",
- 'gte': "[%s to]",
- 'lt': "{to %s}",
- 'lte': "[to %s]",
- 'fuzzy': u'%s~',
- }
-
- if value.post_process is False:
- query_frag = prepared_value
- else:
- if filter_type in [
- 'content',
- 'contains',
- 'startswith',
- 'endswith',
- 'fuzzy']:
- if value.input_type_name == 'exact':
- query_frag = prepared_value
- else:
- # Iterate over terms & incorportate the converted form of
- # each into the query.
- terms = []
-
- if isinstance(prepared_value, six.string_types):
- possible_values = prepared_value.split(' ')
- else:
- if is_datetime is True:
- prepared_value = self._convert_datetime(
- prepared_value)
-
- possible_values = [prepared_value]
-
- for possible_value in possible_values:
- terms.append(
- filter_types[filter_type] %
- self.backend._from_python(possible_value))
-
- if len(terms) == 1:
- query_frag = terms[0]
- else:
- query_frag = u"(%s)" % " AND ".join(terms)
- elif filter_type == 'in':
- in_options = []
-
- for possible_value in prepared_value:
- is_datetime = False
-
- if hasattr(possible_value, 'strftime'):
- is_datetime = True
-
- pv = self.backend._from_python(possible_value)
-
- if is_datetime is True:
- pv = self._convert_datetime(pv)
-
- if isinstance(pv, six.string_types) and not is_datetime:
- in_options.append('"%s"' % pv)
- else:
- in_options.append('%s' % pv)
-
- query_frag = "(%s)" % " OR ".join(in_options)
- elif filter_type == 'range':
- start = self.backend._from_python(prepared_value[0])
- end = self.backend._from_python(prepared_value[1])
-
- if hasattr(prepared_value[0], 'strftime'):
- start = self._convert_datetime(start)
-
- if hasattr(prepared_value[1], 'strftime'):
- end = self._convert_datetime(end)
-
- query_frag = u"[%s to %s]" % (start, end)
- elif filter_type == 'exact':
- if value.input_type_name == 'exact':
- query_frag = prepared_value
- else:
- prepared_value = Exact(prepared_value).prepare(self)
- query_frag = filter_types[filter_type] % prepared_value
- else:
- if is_datetime is True:
- prepared_value = self._convert_datetime(prepared_value)
-
- query_frag = filter_types[filter_type] % prepared_value
-
- if len(query_frag) and not isinstance(value, Raw):
- if not query_frag.startswith('(') and not query_frag.endswith(')'):
- query_frag = "(%s)" % query_frag
-
- return u"%s%s" % (index_fieldname, query_frag)
-
- # if not filter_type in ('in', 'range'):
- # # 'in' is a bit of a special case, as we don't want to
- # # convert a valid list/tuple to string. Defer handling it
- # # until later...
- # value = self.backend._from_python(value)
-
-
-class WhooshEngine(BaseEngine):
- backend = WhooshSearchBackend
- query = WhooshSearchQuery
+ if additional_query_string:
\ No newline at end of file
diff --git a/src/DjangoBlog-master(1)/DjangoBlog-master/oauth/oauthmanager.py b/src/DjangoBlog-master(1)/DjangoBlog-master/oauth/oauthmanager.py
index 23f4315..5ad6394 100644
--- a/src/DjangoBlog-master(1)/DjangoBlog-master/oauth/oauthmanager.py
+++ b/src/DjangoBlog-master(1)/DjangoBlog-master/oauth/oauthmanager.py
@@ -3,131 +3,94 @@ import logging
import os
import urllib.parse
from abc import ABCMeta, abstractmethod
+from urllib import parse
import requests
from djangoblog.utils import cache_decorator
from oauth.models import OAuthUser, OAuthConfig
-import logging
-import requests
-import json
-import urllib.parse
-import os
-from abc import ABCMeta, abstractmethod
-from django.core.cache import cache
-from cache_decorator import cache_decorator
-
-# 获取logger实例
+# 修复重复导入问题,保留一份必要导入
logger = logging.getLogger(__name__)
class OAuthAccessTokenException(Exception):
- '''
- OAuth授权失败异常类
- '''
+ '''OAuth授权失败异常类'''
class BaseOauthManager(metaclass=ABCMeta):
"""OAuth授权管理器基类"""
-
- # 授权URL
AUTH_URL = None
- # 获取token的URL
TOKEN_URL = None
- # 获取用户信息的API URL
API_URL = None
- # icon图标名
ICON_NAME = None
def __init__(self, access_token=None, openid=None):
- """
- 初始化OAuth管理器
-
- Args:
- access_token: 访问令牌
- openid: 用户唯一标识
- """
self.access_token = access_token
self.openid = openid
@property
def is_access_token_set(self):
- """检查access_token是否已设置"""
return self.access_token is not None
@property
def is_authorized(self):
- """检查是否已授权(既有access_token又有openid)"""
- return self.is_access_token_set and self.access_token is not None and self.openid is not None
+ return self.is_access_token_set and self.openid is not None
@abstractmethod
def get_authorization_url(self, nexturl='/'):
- """获取授权URL(抽象方法)"""
pass
@abstractmethod
def get_access_token_by_code(self, code):
- """通过授权码获取访问令牌(抽象方法)"""
pass
@abstractmethod
def get_oauth_userinfo(self):
- """获取用户信息(抽象方法)"""
pass
@abstractmethod
def get_picture(self, metadata):
- """从元数据中获取用户头像(抽象方法)"""
pass
def do_get(self, url, params, headers=None):
- """执行GET请求"""
rsp = requests.get(url=url, params=params, headers=headers)
logger.info(rsp.text)
return rsp.text
def do_post(self, url, params, headers=None):
- """执行POST请求"""
rsp = requests.post(url, params, headers=headers)
logger.info(rsp.text)
return rsp.text
def get_config(self):
- """获取OAuth配置"""
value = OAuthConfig.objects.filter(type=self.ICON_NAME)
return value[0] if value else None
class WBOauthManager(BaseOauthManager):
"""微博OAuth管理器"""
-
- # 微博OAuth相关URL
AUTH_URL = 'https://api.weibo.com/oauth2/authorize'
TOKEN_URL = 'https://api.weibo.com/oauth2/access_token'
API_URL = 'https://api.weibo.com/2/users/show.json'
ICON_NAME = 'weibo'
def __init__(self, access_token=None, openid=None):
- """初始化微博OAuth管理器"""
config = self.get_config()
- self.client_id = config.appkey if config else '' # 应用Key
- self.client_secret = config.appsecret if config else '' # 应用Secret
- self.callback_url = config.callback_url if config else '' # 回调URL
- super(WBOauthManager, self).__init__(access_token=access_token, openid=openid)
+ self.client_id = config.appkey if config else ''
+ self.client_secret = config.appsecret if config else ''
+ self.callback_url = config.callback_url if config else ''
+ super().__init__(access_token=access_token, openid=openid)
def get_authorization_url(self, nexturl='/'):
- """获取微博授权URL"""
params = {
'client_id': self.client_id,
'response_type': 'code',
- 'redirect_uri': self.callback_url + '&next_url=' + nexturl
+ 'redirect_uri': f'{self.callback_url}&next_url={nexturl}'
}
- url = self.AUTH_URL + "?" + urllib.parse.urlencode(params)
- return url
+ return f"{self.AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token_by_code(self, code):
- """通过授权码获取访问令牌"""
params = {
'client_id': self.client_id,
'client_secret': self.client_secret,
@@ -136,68 +99,55 @@ class WBOauthManager(BaseOauthManager):
'redirect_uri': self.callback_url
}
rsp = self.do_post(self.TOKEN_URL, params)
-
obj = json.loads(rsp)
+
if 'access_token' in obj:
self.access_token = str(obj['access_token'])
self.openid = str(obj['uid'])
- return self.get_oauth_userinfo()
- else:
- raise OAuthAccessTokenException(rsp)
+ return self.get_oauth_userinfo() # 返回OAuthUser对象
+ raise OAuthAccessTokenException(rsp) # 异常分支不返回,保持一致性
def get_oauth_userinfo(self):
- """获取微博用户信息"""
if not self.is_authorized:
- return None
- params = {
- 'uid': self.openid,
- 'access_token': self.access_token
- }
+ return None # 未授权返回None
+
+ params = {'uid': self.openid, 'access_token': self.access_token}
rsp = self.do_get(self.API_URL, params)
+
try:
datas = json.loads(rsp)
- user = OAuthUser()
- user.metadata = rsp # 原始元数据
- user.picture = datas['avatar_large'] # 用户头像
- user.nickname = datas['screen_name'] # 用户昵称
- user.openid = datas['id'] # 用户OpenID
- user.type = 'weibo' # 用户类型
- user.token = self.access_token # 访问令牌
- if 'email' in datas and datas['email']:
- user.email = datas['email'] # 用户邮箱
- return user
+ user = OAuthUser(
+ metadata=rsp,
+ picture=datas['avatar_large'],
+ nickname=datas['screen_name'],
+ openid=datas['id'],
+ type='weibo',
+ token=self.access_token,
+ email=datas.get('email')
+ )
+ return user # 正常分支返回OAuthUser对象
except Exception as e:
- logger.error(e)
- logger.error('weibo oauth error.rsp:' + rsp)
- return None
+ logger.error(f"weibo oauth error: {e}, rsp: {rsp}")
+ return None # 异常分支返回None,保持类型一致
def get_picture(self, metadata):
- """从元数据中获取用户头像"""
- datas = json.loads(metadata)
- return datas['avatar_large']
+ return json.loads(metadata)['avatar_large']
class ProxyManagerMixin:
- """代理管理器混入类,用于处理网络代理"""
+ """代理管理器混入类"""
def __init__(self, *args, **kwargs):
- """初始化代理设置"""
- if os.environ.get("HTTP_PROXY"):
- self.proxies = {
- "http": os.environ.get("HTTP_PROXY"),
- "https": os.environ.get("HTTP_PROXY")
- }
- else:
- self.proxies = None
+ proxy = os.environ.get("HTTP_PROXY")
+ self.proxies = {"http": proxy, "https": proxy} if proxy else None
+ super().__init__(*args, **kwargs)
def do_get(self, url, params, headers=None):
- """使用代理执行GET请求"""
rsp = requests.get(url=url, params=params, headers=headers, proxies=self.proxies)
logger.info(rsp.text)
return rsp.text
def do_post(self, url, params, headers=None):
- """使用代理执行POST请求"""
rsp = requests.post(url, params, headers=headers, proxies=self.proxies)
logger.info(rsp.text)
return rsp.text
@@ -205,33 +155,28 @@ class ProxyManagerMixin:
class GoogleOauthManager(ProxyManagerMixin, BaseOauthManager):
"""Google OAuth管理器"""
-
AUTH_URL = 'https://accounts.google.com/o/oauth2/v2/auth'
TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'
API_URL = 'https://www.googleapis.com/oauth2/v3/userinfo'
ICON_NAME = 'google'
def __init__(self, access_token=None, openid=None):
- """初始化Google OAuth管理器"""
config = self.get_config()
self.client_id = config.appkey if config else ''
self.client_secret = config.appsecret if config else ''
self.callback_url = config.callback_url if config else ''
- super(GoogleOauthManager, self).__init__(access_token=access_token, openid=openid)
+ super().__init__(access_token=access_token, openid=openid)
def get_authorization_url(self, nexturl='/'):
- """获取Google授权URL"""
params = {
'client_id': self.client_id,
'response_type': 'code',
'redirect_uri': self.callback_url,
- 'scope': 'openid email', # 请求的权限范围
+ 'scope': 'openid email'
}
- url = self.AUTH_URL + "?" + urllib.parse.urlencode(params)
- return url
+ return f"{self.AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token_by_code(self, code):
- """通过授权码获取访问令牌"""
params = {
'client_id': self.client_id,
'client_secret': self.client_secret,
@@ -240,77 +185,66 @@ class GoogleOauthManager(ProxyManagerMixin, BaseOauthManager):
'redirect_uri': self.callback_url
}
rsp = self.do_post(self.TOKEN_URL, params)
-
obj = json.loads(rsp)
if 'access_token' in obj:
self.access_token = str(obj['access_token'])
self.openid = str(obj['id_token'])
- logger.info(self.ICON_NAME + ' oauth ' + rsp)
- return self.access_token
- else:
- raise OAuthAccessTokenException(rsp)
+ logger.info(f"{self.ICON_NAME} oauth {rsp}")
+ return self.access_token # 返回字符串token
+ raise OAuthAccessTokenException(rsp) # 异常分支不返回
def get_oauth_userinfo(self):
- """获取Google用户信息"""
if not self.is_authorized:
- return None
- params = {
- 'access_token': self.access_token
- }
+ return None # 未授权返回None
+
+ params = {'access_token': self.access_token}
rsp = self.do_get(self.API_URL, params)
+
try:
datas = json.loads(rsp)
- user = OAuthUser()
- user.metadata = rsp
- user.picture = datas['picture'] # 用户头像
- user.nickname = datas['name'] # 用户昵称
- user.openid = datas['sub'] # 用户唯一标识
- user.token = self.access_token
- user.type = 'google'
- if datas['email']:
- user.email = datas['email'] # 用户邮箱
- return user
+ user = OAuthUser(
+ metadata=rsp,
+ picture=datas['picture'],
+ nickname=datas['name'],
+ openid=datas['sub'],
+ type='google',
+ token=self.access_token,
+ email=datas.get('email')
+ )
+ return user # 正常分支返回OAuthUser
except Exception as e:
- logger.error(e)
- logger.error('google oauth error.rsp:' + rsp)
- return None
+ logger.error(f"google oauth error: {e}, rsp: {rsp}")
+ return None # 异常分支返回None
def get_picture(self, metadata):
- """从元数据中获取用户头像"""
- datas = json.loads(metadata)
- return datas['picture']
+ return json.loads(metadata)['picture']
class GitHubOauthManager(ProxyManagerMixin, BaseOauthManager):
"""GitHub OAuth管理器"""
-
AUTH_URL = 'https://github.com/login/oauth/authorize'
TOKEN_URL = 'https://github.com/login/oauth/access_token'
API_URL = 'https://api.github.com/user'
ICON_NAME = 'github'
def __init__(self, access_token=None, openid=None):
- """初始化GitHub OAuth管理器"""
config = self.get_config()
self.client_id = config.appkey if config else ''
self.client_secret = config.appsecret if config else ''
self.callback_url = config.callback_url if config else ''
- super(GitHubOauthManager, self).__init__(access_token=access_token, openid=openid)
+ super().__init__(access_token=access_token, openid=openid)
def get_authorization_url(self, next_url='/'):
- """获取GitHub授权URL"""
params = {
'client_id': self.client_id,
'response_type': 'code',
'redirect_uri': f'{self.callback_url}&next_url={next_url}',
- 'scope': 'user' # 请求的用户权限
+ 'scope': 'user'
}
- url = self.AUTH_URL + "?" + urllib.parse.urlencode(params)
- return url
+ return f"{self.AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token_by_code(self, code):
- """通过授权码获取访问令牌"""
params = {
'client_id': self.client_id,
'client_secret': self.client_secret,
@@ -319,73 +253,61 @@ class GitHubOauthManager(ProxyManagerMixin, BaseOauthManager):
'redirect_uri': self.callback_url
}
rsp = self.do_post(self.TOKEN_URL, params)
-
- # 解析URL编码的响应
- from urllib import parse
r = parse.parse_qs(rsp)
+
if 'access_token' in r:
- self.access_token = (r['access_token'][0])
- return self.access_token
- else:
- raise OAuthAccessTokenException(rsp)
+ self.access_token = r['access_token'][0]
+ return self.access_token # 返回字符串token
+ raise OAuthAccessTokenException(rsp) # 异常分支不返回
def get_oauth_userinfo(self):
- """获取GitHub用户信息"""
- rsp = self.do_get(self.API_URL, params={}, headers={
- "Authorization": "token " + self.access_token # 使用token进行认证
- })
+ headers = {"Authorization": f"token {self.access_token}"}
+ rsp = self.do_get(self.API_URL, params={}, headers=headers)
+
try:
datas = json.loads(rsp)
- user = OAuthUser()
- user.picture = datas['avatar_url'] # 用户头像
- user.nickname = datas['name'] # 用户昵称
- user.openid = datas['id'] # 用户ID
- user.type = 'github'
- user.token = self.access_token
- user.metadata = rsp
- if 'email' in datas and datas['email']:
- user.email = datas['email'] # 用户邮箱
- return user
+ user = OAuthUser(
+ picture=datas['avatar_url'],
+ nickname=datas.get('name'),
+ openid=datas['id'],
+ type='github',
+ token=self.access_token,
+ metadata=rsp,
+ email=datas.get('email')
+ )
+ return user # 正常分支返回OAuthUser
except Exception as e:
- logger.error(e)
- logger.error('github oauth error.rsp:' + rsp)
- return None
+ logger.error(f"github oauth error: {e}, rsp: {rsp}")
+ return None # 异常分支返回None
def get_picture(self, metadata):
- """从元数据中获取用户头像"""
- datas = json.loads(metadata)
- return datas['avatar_url']
+ return json.loads(metadata)['avatar_url']
class FaceBookOauthManager(ProxyManagerMixin, BaseOauthManager):
"""Facebook OAuth管理器"""
-
AUTH_URL = 'https://www.facebook.com/v16.0/dialog/oauth'
TOKEN_URL = 'https://graph.facebook.com/v16.0/oauth/access_token'
API_URL = 'https://graph.facebook.com/me'
ICON_NAME = 'facebook'
def __init__(self, access_token=None, openid=None):
- """初始化Facebook OAuth管理器"""
config = self.get_config()
self.client_id = config.appkey if config else ''
self.client_secret = config.appsecret if config else ''
self.callback_url = config.callback_url if config else ''
- super(FaceBookOauthManager, self).__init__(access_token=access_token, openid=openid)
+ super().__init__(access_token=access_token, openid=openid)
def get_authorization_url(self, next_url='/'):
- """获取Facebook授权URL"""
params = {
'client_id': self.client_id,
'response_type': 'code',
'redirect_uri': self.callback_url,
- 'scope': 'email,public_profile' # 请求的权限范围
+ 'scope': 'email,public_profile'
}
- url = self.AUTH_URL + "?" + urllib.parse.urlencode(params)
- return url
+ return f"{self.AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token_by_code(self, code):
- """通过授权码获取访问令牌"""
params = {
'client_id': self.client_id,
'client_secret': self.client_secret,
@@ -393,74 +315,67 @@ class FaceBookOauthManager(ProxyManagerMixin, BaseOauthManager):
'redirect_uri': self.callback_url
}
rsp = self.do_post(self.TOKEN_URL, params)
-
obj = json.loads(rsp)
+
if 'access_token' in obj:
- token = str(obj['access_token'])
- self.access_token = token
- return self.access_token
- else:
- raise OAuthAccessTokenException(rsp)
+ self.access_token = str(obj['access_token'])
+ return self.access_token # 返回字符串token
+ raise OAuthAccessTokenException(rsp) # 异常分支不返回
def get_oauth_userinfo(self):
- """获取Facebook用户信息"""
params = {
'access_token': self.access_token,
- 'fields': 'id,name,picture,email' # 请求的用户字段
+ 'fields': 'id,name,picture,email'
}
try:
rsp = self.do_get(self.API_URL, params)
datas = json.loads(rsp)
- user = OAuthUser()
- user.nickname = datas['name'] # 用户昵称
- user.openid = datas['id'] # 用户ID
- user.type = 'facebook'
- user.token = self.access_token
- user.metadata = rsp
- if 'email' in datas and datas['email']:
- user.email = datas['email'] # 用户邮箱
- if 'picture' in datas and datas['picture'] and datas['picture']['data'] and datas['picture']['data']['url']:
- user.picture = str(datas['picture']['data']['url']) # 用户头像
- return user
+ user = OAuthUser(
+ nickname=datas['name'],
+ openid=datas['id'],
+ type='facebook',
+ token=self.access_token,
+ metadata=rsp,
+ email=datas.get('email')
+ )
+ # 处理头像URL
+ if 'picture' in datas:
+ pic_data = datas['picture'].get('data', {})
+ user.picture = pic_data.get('url', '')
+ return user # 正常分支返回OAuthUser
except Exception as e:
logger.error(e)
- return None
+ return None # 异常分支返回None
def get_picture(self, metadata):
- """从元数据中获取用户头像"""
datas = json.loads(metadata)
return str(datas['picture']['data']['url'])
class QQOauthManager(BaseOauthManager):
"""QQ OAuth管理器"""
-
AUTH_URL = 'https://graph.qq.com/oauth2.0/authorize'
TOKEN_URL = 'https://graph.qq.com/oauth2.0/token'
API_URL = 'https://graph.qq.com/user/get_user_info'
- OPEN_ID_URL = 'https://graph.qq.com/oauth2.0/me' # 获取OpenID的URL
+ OPEN_ID_URL = 'https://graph.qq.com/oauth2.0/me'
ICON_NAME = 'qq'
def __init__(self, access_token=None, openid=None):
- """初始化QQ OAuth管理器"""
config = self.get_config()
self.client_id = config.appkey if config else ''
self.client_secret = config.appsecret if config else ''
self.callback_url = config.callback_url if config else ''
- super(QQOauthManager, self).__init__(access_token=access_token, openid=openid)
+ super().__init__(access_token=access_token, openid=openid)
def get_authorization_url(self, next_url='/'):
- """获取QQ授权URL"""
params = {
'response_type': 'code',
'client_id': self.client_id,
- 'redirect_uri': self.callback_url + '&next_url=' + next_url,
+ 'redirect_uri': f'{self.callback_url}&next_url={next_url}',
}
- url = self.AUTH_URL + "?" + urllib.parse.urlencode(params)
- return url
+ return f"{self.AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token_by_code(self, code):
- """通过授权码获取访问令牌"""
params = {
'grant_type': 'authorization_code',
'client_id': self.client_id,
@@ -469,78 +384,77 @@ class QQOauthManager(BaseOauthManager):
'redirect_uri': self.callback_url
}
rsp = self.do_get(self.TOKEN_URL, params)
+
if rsp:
- # 解析URL编码的响应
d = urllib.parse.parse_qs(rsp)
if 'access_token' in d:
- token = d['access_token']
- self.access_token = token[0]
- return token
- else:
- raise OAuthAccessTokenException(rsp)
+ token = d['access_token'][0]
+ self.access_token = token
+ return token # 返回字符串token
+ raise OAuthAccessTokenException(rsp) # 异常分支不返回
def get_open_id(self):
- """获取用户的OpenID"""
if self.is_access_token_set:
- params = {
- 'access_token': self.access_token
- }
+ params = {'access_token': self.access_token}
rsp = self.do_get(self.OPEN_ID_URL, params)
+
if rsp:
- # 清理响应格式(JSONP格式)
- rsp = rsp.replace('callback(', '').replace(')', '').replace(';', '')
- obj = json.loads(rsp)
+ # 清理JSONP格式响应
+ cleaned_rsp = rsp.replace('callback(', '').replace(')', '').replace(';', '')
+ obj = json.loads(cleaned_rsp)
openid = str(obj['openid'])
self.openid = openid
- return openid
+ return openid # 成功获取返回字符串openid
+
+ return None # 失败分支返回None,保持类型一致
def get_oauth_userinfo(self):
- """获取QQ用户信息"""
openid = self.get_open_id()
- if openid:
- params = {
- 'access_token': self.access_token,
- 'oauth_consumer_key': self.client_id,
- 'openid': self.openid
- }
- rsp = self.do_get(self.API_URL, params)
- logger.info(rsp)
+ if not openid:
+ return None # 无openid返回None
+
+ params = {
+ 'access_token': self.access_token,
+ 'oauth_consumer_key': self.client_id,
+ 'openid': self.openid
+ }
+ rsp = self.do_get(self.API_URL, params)
+ logger.info(rsp)
+
+ try:
obj = json.loads(rsp)
- user = OAuthUser()
- user.nickname = obj['nickname'] # 用户昵称
- user.openid = openid
- user.type = 'qq'
- user.token = self.access_token
- user.metadata = rsp
- if 'email' in obj:
- user.email = obj['email'] # 用户邮箱
- if 'figureurl' in obj:
- user.picture = str(obj['figureurl']) # 用户头像
- return user
+ user = OAuthUser(
+ nickname=obj['nickname'],
+ openid=openid,
+ type='qq',
+ token=self.access_token,
+ metadata=rsp,
+ email=obj.get('email'),
+ picture=obj.get('figureurl', '')
+ )
+ return user # 正常分支返回OAuthUser
+ except Exception as e:
+ logger.error(e)
+ return None # 异常分支返回None
def get_picture(self, metadata):
- """从元数据中获取用户头像"""
- datas = json.loads(metadata)
- return str(datas['figureurl'])
+ return str(json.loads(metadata)['figureurl'])
@cache_decorator(expiration=100 * 60)
def get_oauth_apps():
- """获取所有启用的OAuth应用(带缓存)"""
configs = OAuthConfig.objects.filter(is_enable=True).all()
if not configs:
return []
configtypes = [x.type for x in configs]
applications = BaseOauthManager.__subclasses__()
- apps = [x() for x in applications if x().ICON_NAME.lower() in configtypes]
- return apps
+ return [x() for x in applications if x().ICON_NAME.lower() in configtypes]
def get_manager_by_type(type):
- """根据类型获取对应的OAuth管理器"""
applications = get_oauth_apps()
if applications:
- finds = list(filter(lambda x: x.ICON_NAME.lower() == type.lower(), applications))
+ finds = [x for x in applications if x.ICON_NAME.lower() == type.lower()]
if finds:
return finds[0]
- return None
\ No newline at end of file
+ return None # 未找到返回None,保持
\ No newline at end of file
diff --git a/src/DjangoBlog-master(1)/DjangoBlog-master/servermanager/robot.py b/src/DjangoBlog-master(1)/DjangoBlog-master/servermanager/robot.py
index 7b45736..3fece0b 100644
--- a/src/DjangoBlog-master(1)/DjangoBlog-master/servermanager/robot.py
+++ b/src/DjangoBlog-master(1)/DjangoBlog-master/servermanager/robot.py
@@ -1,8 +1,11 @@
import logging
import os
import re
+import json
+from dataclasses import dataclass, asdict
+from typing import Dict
-import jsonpickle
+import django
from django.conf import settings
from werobot import WeRoBot
from werobot.replies import ArticlesReply, Article
@@ -13,14 +16,15 @@ from servermanager.api.blogapi import BlogApi
from servermanager.api.commonapi import ChatGPT, CommandHandler
from .MemcacheStorage import MemcacheStorage
-robot = WeRoBot(token=os.environ.get('DJANGO_WEROBOT_TOKEN')
- or 'lylinux', enable_session=True)
+# 初始化微信机器人
+robot = WeRoBot(token=os.environ.get('DJANGO_WEROBOT_TOKEN') or 'lylinux', enable_session=True)
memstorage = MemcacheStorage()
if memstorage.is_available:
robot.config['SESSION_STORAGE'] = memstorage
else:
- if os.path.exists(os.path.join(settings.BASE_DIR, 'werobot_session')):
- os.remove(os.path.join(settings.BASE_DIR, 'werobot_session'))
+ session_path = os.path.join(settings.BASE_DIR, 'werobot_session')
+ if os.path.exists(session_path):
+ os.remove(session_path)
robot.config['SESSION_STORAGE'] = FileStorage(filename='werobot_session')
blogapi = BlogApi()
@@ -33,9 +37,7 @@ def convert_to_article_reply(articles, message):
from blog.templatetags.blog_tags import truncatechars_content
for post in articles:
imgs = re.findall(r'(?:http\:|https\:)?\/\/.*\.(?:png|jpg)', post.body)
- imgurl = ''
- if imgs:
- imgurl = imgs[0]
+ imgurl = imgs[0] if imgs else ''
article = Article(
title=post.title,
description=truncatechars_content(post.body),
@@ -46,56 +48,44 @@ def convert_to_article_reply(articles, message):
return reply
+# 微信机器人消息处理装饰器
@robot.filter(re.compile(r"^\?.*"))
def search(message, session):
- s = message.content
- searchstr = str(s).replace('?', '')
+ searchstr = message.content.replace('?', '')
result = blogapi.search_articles(searchstr)
if result:
- articles = list(map(lambda x: x.object, result))
- reply = convert_to_article_reply(articles, message)
- return reply
- else:
- return '没有找到相关文章。'
+ articles = [x.object for x in result]
+ return convert_to_article_reply(articles, message)
+ return '没有找到相关文章。'
@robot.filter(re.compile(r'^category\s*$', re.I))
def category(message, session):
categorys = blogapi.get_category_lists()
- content = ','.join(map(lambda x: x.name, categorys))
- return '所有文章分类目录:' + content
+ return '所有文章分类目录:' + ','.join([x.name for x in categorys])
@robot.filter(re.compile(r'^recent\s*$', re.I))
def recents(message, session):
articles = blogapi.get_recent_articles()
- if articles:
- reply = convert_to_article_reply(articles, message)
- return reply
- else:
- return "暂时还没有文章"
+ return convert_to_article_reply(articles, message) if articles else "暂时还没有文章"
@robot.filter(re.compile('^help$', re.I))
def help(message, session):
return '''欢迎关注!
- 默认会与图灵机器人聊天~~
- 你可以通过下面这些命令来获得信息
- ?关键字搜索文章.
- 如?python.
- category获得文章分类目录及文章数.
- category-***获得该分类目录文章
- 如category-python
- recent获得最新文章
- help获得帮助.
- weather:获得天气
- 如weather:西安
- idcard:获得身份证信息
- 如idcard:61048119xxxxxxxxxx
- music:音乐搜索
- 如music:阴天快乐
- PS:以上标点符号都不支持中文标点~~
- '''
+默认会与图灵机器人聊天~~
+你可以通过下面这些命令来获得信息
+?关键字搜索文章. 如?python.
+category获得文章分类目录及文章数.
+category-***获得该分类目录文章 如category-python
+recent获得最新文章
+help获得帮助.
+weather:获得天气 如weather:西安
+idcard:获得身份证信息 如idcard:61048119xxxxxxxxxx
+music:音乐搜索 如music:阴天快乐
+PS:以上标点符号都不支持中文标点~~
+'''
@robot.filter(re.compile(r'^weather\:.*$', re.I))
@@ -114,18 +104,45 @@ def echo(message, session):
return handler.handler()
+@dataclass
+class WxUserInfo:
+ """用户信息数据类,替代原类以支持安全序列化"""
+ isAdmin: bool = False
+ isPasswordSet: bool = False
+ Count: int = 0
+ Command: str = ''
+
+ def to_dict(self) -> Dict:
+ """转换为字典用于JSON序列化"""
+ return asdict(self)
+
+ @classmethod
+ def from_dict(cls, data: Dict) -> 'WxUserInfo':
+ """从字典恢复对象"""
+ return cls(
+ isAdmin=data.get('isAdmin', False),
+ isPasswordSet=data.get('isPasswordSet', False),
+ Count=data.get('Count', 0),
+ Command=data.get('Command', '')
+ )
+
+
class MessageHandler:
def __init__(self, message, session):
- userid = message.source
self.message = message
self.session = session
- self.userid = userid
+ self.userid = message.source
+ self.userinfo = self._load_userinfo()
+
+ def _load_userinfo(self) -> WxUserInfo:
+ """加载用户信息(使用JSON替代jsonpickle)"""
try:
- info = session[userid]
- self.userinfo = jsonpickle.decode(info)
- except Exception as e:
- userinfo = WxUserInfo()
- self.userinfo = userinfo
+ info_str = self.session.get(self.userid, '{}')
+ info_dict = json.loads(info_str)
+ return WxUserInfo.from_dict(info_dict)
+ except (json.JSONDecodeError, TypeError, KeyError) as e:
+ logger.warning(f"加载用户信息失败: {e}")
+ return WxUserInfo()
@property
def is_admin(self):
@@ -136,8 +153,12 @@ class MessageHandler:
return self.userinfo.isPasswordSet
def save_session(self):
- info = jsonpickle.encode(self.userinfo)
- self.session[self.userid] = info
+ """保存用户信息(使用JSON替代jsonpickle)"""
+ try:
+ info_str = json.dumps(self.userinfo.to_dict())
+ self.session[self.userid] = info_str
+ except json.JSONEncodeError as e:
+ logger.error(f"保存用户信息失败: {e}")
def handler(self):
info = self.message.content
@@ -146,14 +167,14 @@ class MessageHandler:
self.userinfo = WxUserInfo()
self.save_session()
return "退出成功"
+
if info.upper() == 'ADMIN':
self.userinfo.isAdmin = True
self.save_session()
return "输入管理员密码"
+
if self.userinfo.isAdmin and not self.userinfo.isPasswordSet:
- passwd = settings.WXADMIN
- if settings.TESTING:
- passwd = '123'
+ passwd = settings.WXADMIN if not settings.TESTING else '123'
if passwd.upper() == get_sha256(get_sha256(info)).upper():
self.userinfo.isPasswordSet = True
self.save_session()
@@ -166,6 +187,7 @@ class MessageHandler:
self.userinfo.Count += 1
self.save_session()
return "验证失败,请重新输入管理员密码:"
+
if self.userinfo.isAdmin and self.userinfo.isPasswordSet:
if self.userinfo.Command != '' and info.upper() == 'Y':
return cmd_handler.run(self.userinfo.Command)
@@ -174,14 +196,6 @@ class MessageHandler:
return cmd_handler.get_help()
self.userinfo.Command = info
self.save_session()
- return "确认执行: " + info + " 命令?"
-
- return ChatGPT.chat(info)
-
+ return f"确认执行: {info} 命令?"
-class WxUserInfo():
- def __init__(self):
- self.isAdmin = False
- self.isPasswordSet = False
- self.Count = 0
- self.Command = ''
+ return ChatGPT.chat(info)
\ No newline at end of file