You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

443 lines
16 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
from datetime import datetime, timedelta, timezone
import os
from backend.config import Config
def get_utc_now():
    """Return the current time in UTC as a *naive* datetime.

    The clock is read through an aware UTC ``now`` so the host's local
    timezone never leaks into the value; ``tzinfo`` is then dropped because
    this module compares and stores certificate times as naive UTC datetimes.
    """
    utc_clock_reading = datetime.now(tz=timezone.utc)
    naive_utc = utc_clock_reading.replace(tzinfo=None)
    return naive_utc
def generate_key_pair():
    """Create a fresh RSA key pair and return both halves as PEM text.

    The modulus size is taken from ``Config.CA_KEY_SIZE`` and the public
    exponent is 65537.

    Returns:
        tuple[str, str]: ``(private_pem, public_pem)``. The private key is
        unencrypted PKCS#8 PEM and the public key is SubjectPublicKeyInfo PEM;
        both are decoded as UTF-8 strings.
    """
    key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=Config.CA_KEY_SIZE,
        backend=default_backend(),
    )

    pkcs8_bytes = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    spki_bytes = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    private_text = pkcs8_bytes.decode('utf-8')
    public_text = spki_bytes.decode('utf-8')
    return private_text, public_text
def init_ca():
    """Create the CA root key and self-signed certificate if they are missing.

    Does nothing when both ``Config.CA_PRIVATE_KEY_PATH`` and
    ``Config.CA_CERTIFICATE_PATH`` already exist. Otherwise a new RSA key of
    ``Config.CA_KEY_SIZE`` bits and a 10-year self-signed CA certificate are
    generated and written to those two paths. If only one of the files existed,
    it is overwritten.

    The private key is written unencrypted, so the file is restricted to
    owner read/write (0o600).
    """
    if os.path.exists(Config.CA_PRIVATE_KEY_PATH) and os.path.exists(Config.CA_CERTIFICATE_PATH):
        return

    # Generate the CA key pair.
    ca_private_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=Config.CA_KEY_SIZE,
        backend=default_backend(),
    )
    ca_public_key = ca_private_key.public_key()

    # Self-signed root: subject and issuer are the same name.
    subject = issuer = x509.Name([
        x509.NameAttribute(NameOID.COUNTRY_NAME, "CN"),
        x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Beijing"),
        x509.NameAttribute(NameOID.LOCALITY_NAME, "Beijing"),
        x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Simple CA System"),
        x509.NameAttribute(NameOID.COMMON_NAME, "Simple CA Root"),
    ])

    now = get_utc_now()
    ca_cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(ca_public_key)
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + timedelta(days=3650))  # 10-year validity
        .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
        # RFC 5280 4.2.1.3: a CA that signs certificates/CRLs asserts keyCertSign
        # and cRLSign. Strict verifiers (OpenSSL X509_V_FLAG_X509_STRICT) reject
        # CA certificates that carry no keyUsage extension.
        .add_extension(
            x509.KeyUsage(
                digital_signature=False,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        # RFC 5280 4.2.1.2: CA certificates must include a subject key identifier.
        .add_extension(x509.SubjectKeyIdentifier.from_public_key(ca_public_key), critical=False)
        .sign(ca_private_key, hashes.SHA256(), default_backend())
    )

    ca_private_pem = ca_private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    ca_cert_pem = ca_cert.public_bytes(serialization.Encoding.PEM)

    # Create the directories that actually contain the configured files.
    # dirname() is '' for a bare filename, which means the current directory.
    for path in (Config.CA_PRIVATE_KEY_PATH, Config.CA_CERTIFICATE_PATH):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    # Create the key file as owner-only instead of relying on the umask.
    # chmod afterwards also covers a pre-existing file, whose mode os.open
    # would not change.
    key_fd = os.open(Config.CA_PRIVATE_KEY_PATH, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(key_fd, 'wb') as f:
        f.write(ca_private_pem)
    os.chmod(Config.CA_PRIVATE_KEY_PATH, 0o600)

    with open(Config.CA_CERTIFICATE_PATH, 'wb') as f:
        f.write(ca_cert_pem)
def load_ca_private_key():
    """Read and deserialize the CA private key.

    The file at ``Config.CA_PRIVATE_KEY_PATH`` is parsed as PEM with no
    password.

    Returns:
        The private key object produced by ``load_pem_private_key``.
    """
    with open(Config.CA_PRIVATE_KEY_PATH, 'rb') as key_file:
        key_bytes = key_file.read()
    return serialization.load_pem_private_key(
        key_bytes,
        password=None,
        backend=default_backend(),
    )
def load_ca_certificate():
    """Read and parse the CA certificate.

    Returns:
        The certificate parsed as PEM from ``Config.CA_CERTIFICATE_PATH``.
    """
    with open(Config.CA_CERTIFICATE_PATH, 'rb') as cert_file:
        pem_bytes = cert_file.read()
    return x509.load_pem_x509_certificate(pem_bytes, default_backend())
# Common English country names (upper case, single-spaced) mapped to their
# ISO 3166-1 alpha-2 codes. Built once at import time rather than per call.
# "UK" is listed because it is the common abbreviation but not the ISO code (GB).
_COUNTRY_NAME_TO_CODE = {
    'CHINA': 'CN',
    'USA': 'US',
    'UNITED STATES': 'US',
    'UNITED KINGDOM': 'GB',
    'UK': 'GB',
    'JAPAN': 'JP',
    'GERMANY': 'DE',
    'FRANCE': 'FR',
    'CANADA': 'CA',
    'AUSTRALIA': 'AU',
    'SOUTH KOREA': 'KR',
    'KOREA': 'KR',
    'INDIA': 'IN',
    'BRAZIL': 'BR',
    'RUSSIA': 'RU',
    'ITALY': 'IT',
    'SPAIN': 'ES',
    'NETHERLANDS': 'NL',
    'SWEDEN': 'SE',
    'NORWAY': 'NO',
    'DENMARK': 'DK',
    'FINLAND': 'FI',
    'POLAND': 'PL',
    'SWITZERLAND': 'CH',
    'AUSTRIA': 'AT',
    'BELGIUM': 'BE',
    'IRELAND': 'IE',
    'PORTUGAL': 'PT',
    'GREECE': 'GR',
    'TURKEY': 'TR',
    'MEXICO': 'MX',
    'ARGENTINA': 'AR',
    'SOUTH AFRICA': 'ZA',
    'SINGAPORE': 'SG',
    'HONG KONG': 'HK',
    'TAIWAN': 'TW',
    'THAILAND': 'TH',
    'VIETNAM': 'VN',
    'INDONESIA': 'ID',
    'MALAYSIA': 'MY',
    'PHILIPPINES': 'PH',
    'NEW ZEALAND': 'NZ',
}


def validate_country_code(country):
    """Normalize user input to an ISO 3166-1 alpha-2 country code.

    Args:
        country: A two-letter code in any case, or one of the English country
            names in ``_COUNTRY_NAME_TO_CODE``. Matching ignores case and
            collapses surrounding and repeated inner whitespace. May be
            ``None`` or empty.

    Returns:
        The upper-case two-letter code, or ``None`` if ``country`` is empty
        or blank.

    Raises:
        ValueError: If the input is neither a known country name nor exactly
            two ASCII letters. An X.509 countryName must be a 2-character
            code, so such values would otherwise fail later when the
            certificate name is built.
    """
    if not country:
        return None

    normalized = ' '.join(country.split()).upper()
    if not normalized:
        return None

    # Check names first so abbreviations like "UK" map to their real ISO code
    # instead of being accepted just because they are two letters long.
    mapped = _COUNTRY_NAME_TO_CODE.get(normalized)
    if mapped is not None:
        return mapped

    if len(normalized) == 2 and normalized.isascii() and normalized.isalpha():
        return normalized

    raise ValueError(f'国家代码必须是2个字符如CN、US当前输入: {normalized}。请输入ISO 3166-1 alpha-2标准的2字符国家代码')
def create_csr_from_data(data):
    """Build a CSR from a dict of subject fields and sign it with the given key.

    Args:
        data: Mapping with an unencrypted PEM ``'private_key'`` and a required
            ``'common_name'``. The optional keys ``'country'``, ``'province'``,
            ``'locality'``, ``'organization'``, ``'organization_unit_name'`` and
            ``'email_address'`` are included only when truthy. The country is
            normalized with ``validate_country_code``.

    Returns:
        tuple: ``(csr, private_key)``, the signed CSR and the loaded private key.
    """
    private_key = serialization.load_pem_private_key(
        data['private_key'].encode('utf-8'),
        password=None,
        backend=default_backend(),
    )

    rdns = []

    country_code = validate_country_code(data['country']) if data.get('country') else None
    if country_code:
        rdns.append(x509.NameAttribute(NameOID.COUNTRY_NAME, country_code))

    # Optional fields, in the order they appear in the subject DN.
    optional_fields = (
        ('province', NameOID.STATE_OR_PROVINCE_NAME),
        ('locality', NameOID.LOCALITY_NAME),
        ('organization', NameOID.ORGANIZATION_NAME),
        ('organization_unit_name', NameOID.ORGANIZATIONAL_UNIT_NAME),
    )
    for field_name, oid in optional_fields:
        field_value = data.get(field_name)
        if field_value:
            rdns.append(x509.NameAttribute(oid, field_value))

    # common_name is mandatory (KeyError if absent) and precedes the email.
    rdns.append(x509.NameAttribute(NameOID.COMMON_NAME, data['common_name']))

    email = data.get('email_address')
    if email:
        rdns.append(x509.NameAttribute(NameOID.EMAIL_ADDRESS, email))

    csr = (
        x509.CertificateSigningRequestBuilder()
        .subject_name(x509.Name(rdns))
        .sign(private_key, hashes.SHA256(), default_backend())
    )
    return csr, private_key
def parse_csr(csr_pem):
    """Extract the subject fields and the public key from a PEM CSR.

    Only the country, province, locality, organization, organizational unit,
    common name and email attributes are collected. If an attribute occurs
    more than once, the last occurrence wins.

    Returns:
        tuple: ``(fields, public_key_pem)``. ``fields`` is a dict keyed by
        'country', 'province', 'locality', 'organization',
        'organization_unit_name', 'common_name' and 'email_address' (only
        those present). ``public_key_pem`` is SubjectPublicKeyInfo PEM text.
    """
    csr = x509.load_pem_x509_csr(csr_pem.encode('utf-8'), default_backend())

    field_for_oid = {
        NameOID.COUNTRY_NAME: 'country',
        NameOID.STATE_OR_PROVINCE_NAME: 'province',
        NameOID.LOCALITY_NAME: 'locality',
        NameOID.ORGANIZATION_NAME: 'organization',
        NameOID.ORGANIZATIONAL_UNIT_NAME: 'organization_unit_name',
        NameOID.COMMON_NAME: 'common_name',
        NameOID.EMAIL_ADDRESS: 'email_address',
    }

    fields = {}
    for attribute in csr.subject:
        field_name = field_for_oid.get(attribute.oid)
        if field_name is not None:
            fields[field_name] = attribute.value

    spki_bytes = csr.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    return fields, spki_bytes.decode('utf-8')
def _load_request_public_key(public_key_text):
    """Load a public key from PEM text, also accepting a bare base64 body.

    Surrounding whitespace is ignored. Text that does not start with a PEM
    ``-----BEGIN`` line is treated as the base64 body of a
    SubjectPublicKeyInfo and wrapped in ``PUBLIC KEY`` armour.

    Raises:
        ValueError: ('无法解析公钥格式') if the key cannot be loaded.
    """
    from cryptography.exceptions import UnsupportedAlgorithm

    # Strip first so leading blank lines or spaces before "-----BEGIN" are
    # recognized as PEM instead of being wrapped a second time.
    key_data = public_key_text.strip().encode('utf-8')
    if not key_data.startswith(b'-----BEGIN'):
        key_data = b'-----BEGIN PUBLIC KEY-----\n' + key_data + b'\n-----END PUBLIC KEY-----\n'
    try:
        return serialization.load_pem_public_key(key_data, backend=default_backend())
    except (ValueError, UnsupportedAlgorithm) as exc:
        raise ValueError('无法解析公钥格式') from exc


def sign_certificate_from_request(cert_request):
    """Issue a CA-signed certificate directly from a request (no CSR needed).

    Args:
        cert_request: Object with ``country``, ``province``, ``locality``,
            ``organization``, ``organization_unit_name``, ``common_name``,
            ``email_address`` and ``public_key`` attributes. It is presumably
            a database row (TODO confirm). ``public_key`` is PEM text or the
            bare base64 body of a SubjectPublicKeyInfo.

    Returns:
        tuple: ``(cert_pem, serial_number, expire_time)``. ``cert_pem`` is PEM
        text, ``serial_number`` is the serial as a decimal string, and
        ``expire_time`` is the expiry as a naive UTC datetime (MySQL DATETIME
        cannot store a timezone).

    Raises:
        ValueError: If the country code or the public key is invalid.
    """
    ca_private_key = load_ca_private_key()
    ca_cert = load_ca_certificate()

    # Build the subject DN. Empty optional fields are skipped.
    name_attributes = []
    if cert_request.country:
        country_code = validate_country_code(cert_request.country)
        if country_code:
            name_attributes.append(x509.NameAttribute(NameOID.COUNTRY_NAME, country_code))
    for value, oid in (
        (cert_request.province, NameOID.STATE_OR_PROVINCE_NAME),
        (cert_request.locality, NameOID.LOCALITY_NAME),
        (cert_request.organization, NameOID.ORGANIZATION_NAME),
        (cert_request.organization_unit_name, NameOID.ORGANIZATIONAL_UNIT_NAME),
    ):
        if value:
            name_attributes.append(x509.NameAttribute(oid, value))
    name_attributes.append(x509.NameAttribute(NameOID.COMMON_NAME, cert_request.common_name))
    if cert_request.email_address:
        name_attributes.append(x509.NameAttribute(NameOID.EMAIL_ADDRESS, cert_request.email_address))
    subject_name = x509.Name(name_attributes)

    public_key = _load_request_public_key(cert_request.public_key)

    now = get_utc_now()
    cert = (
        x509.CertificateBuilder()
        .subject_name(subject_name)
        .issuer_name(ca_cert.subject)
        .public_key(public_key)
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + timedelta(days=Config.CERT_VALIDITY_DAYS))
        # End-entity certificate: it must not be usable as a CA.
        .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
        # SKI/AKI let verifiers link this certificate to the issuing CA key.
        # Strict chain validation requires an AKI on non-self-signed certificates.
        .add_extension(x509.SubjectKeyIdentifier.from_public_key(public_key), critical=False)
        .add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_cert.public_key()),
            critical=False,
        )
        .sign(ca_private_key, hashes.SHA256(), default_backend())
    )

    cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode('utf-8')
    serial_number = str(cert.serial_number)

    # The database column is a naive DATETIME, so convert to naive UTC.
    expire_time = cert.not_valid_after
    if expire_time.tzinfo is not None:
        expire_time = expire_time.astimezone(timezone.utc).replace(tzinfo=None)
    return cert_pem, serial_number, expire_time
def sign_certificate(csr_pem, request_id):
    """Issue a CA-signed certificate for a PEM-encoded CSR.

    The certificate copies the CSR's subject and public key, gets a random
    serial number, and is valid from now for ``Config.CERT_VALIDITY_DAYS``
    days. Extensions requested inside the CSR are not copied. The issued
    certificate carries BasicConstraints(ca=False), a SubjectKeyIdentifier
    and an AuthorityKeyIdentifier.

    Args:
        csr_pem: CSR as a PEM string.
        request_id: Not used by this function; kept for interface compatibility.

    Returns:
        tuple: ``(cert_pem, serial_number, expire_time)``. ``cert_pem`` is PEM
        text, ``serial_number`` is the serial as a decimal string, and
        ``expire_time`` is the expiry as a naive UTC datetime (MySQL DATETIME
        cannot store a timezone).

    Raises:
        ValueError: If the CSR cannot be parsed or its signature does not verify.
    """
    csr = x509.load_pem_x509_csr(csr_pem.encode('utf-8'), default_backend())

    # Proof of possession: the CSR must be signed by the private key matching
    # the public key it carries. Otherwise a requester could get a certificate
    # issued for someone else's public key.
    if not csr.is_signature_valid:
        raise ValueError('CSR签名无效')

    ca_private_key = load_ca_private_key()
    ca_cert = load_ca_certificate()
    subject_public_key = csr.public_key()

    now = get_utc_now()
    cert = (
        x509.CertificateBuilder()
        .subject_name(csr.subject)
        .issuer_name(ca_cert.subject)
        .public_key(subject_public_key)
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + timedelta(days=Config.CERT_VALIDITY_DAYS))
        # End-entity certificate: it must not be usable as a CA.
        .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
        # SKI/AKI let verifiers link this certificate to the issuing CA key.
        # Strict chain validation requires an AKI on non-self-signed certificates.
        .add_extension(x509.SubjectKeyIdentifier.from_public_key(subject_public_key), critical=False)
        .add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_cert.public_key()),
            critical=False,
        )
        .sign(ca_private_key, hashes.SHA256(), default_backend())
    )

    cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode('utf-8')
    serial_number = str(cert.serial_number)

    # The database column is a naive DATETIME, so convert to naive UTC.
    expire_time = cert.not_valid_after
    if expire_time.tzinfo is not None:
        expire_time = expire_time.astimezone(timezone.utc).replace(tzinfo=None)
    return cert_pem, serial_number, expire_time
def verify_certificate(cert_pem, check_crl=True):
    """Check that a certificate was issued by this CA and is currently usable.

    The checks run in this order: the validity window, the issuer name plus
    the CA's signature, and then (optionally) revocation status in the
    database.

    Args:
        cert_pem: Certificate as a PEM string.
        check_crl: Whether to check the certificate table and CRL for
            revocation. Defaults to True.

    Returns:
        tuple[bool, str]: ``(is_valid, message)``. Any exception raised while
        checking (parse errors, file or database errors) is reported as
        ``(False, "证书验证失败: ...")`` and not propagated.
    """
    # Only this function needs signature verification, so import locally.
    from cryptography.exceptions import InvalidSignature
    from cryptography.hazmat.primitives.asymmetric import padding

    try:
        cert = x509.load_pem_x509_certificate(cert_pem.encode('utf-8'), default_backend())
        ca_cert = load_ca_certificate()

        # Normalize the validity bounds to naive UTC so they compare
        # directly with get_utc_now(), whichever form the library returns.
        cert_start_time = cert.not_valid_before
        cert_expire_time = cert.not_valid_after
        if cert_start_time.tzinfo is not None:
            cert_start_time = cert_start_time.astimezone(timezone.utc).replace(tzinfo=None)
        if cert_expire_time.tzinfo is not None:
            cert_expire_time = cert_expire_time.astimezone(timezone.utc).replace(tzinfo=None)
        now = get_utc_now()

        if cert_start_time > now:
            return False, f"证书尚未生效(生效时间: {cert_start_time.isoformat()})"
        if cert_expire_time < now:
            return False, f"证书已过期(过期时间: {cert_expire_time.isoformat()})"

        if cert.issuer != ca_cert.subject:
            return False, "证书不是由本CA签发的"

        # A matching issuer name proves nothing on its own, because anyone can
        # put this CA's name into a self-made certificate. Verify the signature
        # with the CA public key. The CA key generated by init_ca is RSA, and
        # RSA signing in this module uses cryptography's default PKCS#1 v1.5 padding.
        try:
            ca_cert.public_key().verify(
                cert.signature,
                cert.tbs_certificate_bytes,
                padding.PKCS1v15(),
                cert.signature_hash_algorithm,
            )
        except InvalidSignature:
            return False, "证书不是由本CA签发的"

        if check_crl:
            from models import Certificate, CRL

            certificate = Certificate.query.filter_by(
                serial_number=str(cert.serial_number),
                deleted_at=None,
            ).first()
            if certificate:
                # NOTE(review): state == 2 is treated as "revoked" (per the
                # message below); confirm against the Certificate model.
                if certificate.state == 2:
                    return False, "证书已被吊销"

                crl_entry = CRL.query.filter_by(
                    certificate_id=certificate.id,
                    deleted_at=None,
                ).first()
                if crl_entry:
                    return False, "证书已被吊销在CRL中"

        return True, "证书有效"
    except Exception as e:
        return False, f"证书验证失败: {str(e)}"
def parse_certificate(cert_pem):
    """Summarize a PEM certificate as a plain dict.

    Returns:
        dict: Contains the following keys:
            - 'serial_number': the serial as a decimal string.
            - 'subject': country, province, locality, organization,
              organization_unit_name, common_name and email_address (only
              those present).
            - 'issuer': common_name and organization (only those present).
            - 'not_valid_before' / 'not_valid_after': ISO-8601 strings of the
              validity bounds as naive UTC.
    """
    cert = x509.load_pem_x509_certificate(cert_pem.encode('utf-8'), default_backend())

    def as_naive_utc(moment):
        # Aware values are converted to UTC and stripped; naive ones pass through.
        if moment.tzinfo is None:
            return moment
        return moment.astimezone(timezone.utc).replace(tzinfo=None)

    def collect(name, key_for_oid):
        # Later duplicates of the same attribute overwrite earlier ones.
        return {key_for_oid[attr.oid]: attr.value for attr in name if attr.oid in key_for_oid}

    subject_keys = {
        NameOID.COUNTRY_NAME: 'country',
        NameOID.STATE_OR_PROVINCE_NAME: 'province',
        NameOID.LOCALITY_NAME: 'locality',
        NameOID.ORGANIZATION_NAME: 'organization',
        NameOID.ORGANIZATIONAL_UNIT_NAME: 'organization_unit_name',
        NameOID.COMMON_NAME: 'common_name',
        NameOID.EMAIL_ADDRESS: 'email_address',
    }
    issuer_keys = {
        NameOID.COMMON_NAME: 'common_name',
        NameOID.ORGANIZATION_NAME: 'organization',
    }

    return {
        'serial_number': str(cert.serial_number),
        'subject': collect(cert.subject, subject_keys),
        'issuer': collect(cert.issuer, issuer_keys),
        'not_valid_before': as_naive_utc(cert.not_valid_before).isoformat(),
        'not_valid_after': as_naive_utc(cert.not_valid_after).isoformat(),
    }