You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

191 lines
7.2 KiB

package gorm_v2
import (
"errors"
"fmt"
"go.uber.org/zap"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
gormLog "gorm.io/gorm/logger"
"gorm.io/plugin/dbresolver"
"goskeleton/app/global/my_errors"
"goskeleton/app/global/variable"
"strings"
"time"
)
// GetOneMysqlClient returns a gorm client for the "Mysql" configuration section.
// The read-replica switch is taken from the Gormv2.Mysql.IsOpenReadDb setting
// and passed straight through to GetSqlDriver.
func GetOneMysqlClient() (*gorm.DB, error) {
	const sqlType = "Mysql"
	return GetSqlDriver(sqlType, variable.ConfigGormv2Yml.GetInt("Gormv2."+sqlType+".IsOpenReadDb"))
}
// GetOneSqlserverClient returns a gorm client for the "SqlServer" configuration
// section. The read-replica switch is taken from the
// Gormv2.SqlServer.IsOpenReadDb setting and passed straight through to
// GetSqlDriver.
func GetOneSqlserverClient() (*gorm.DB, error) {
	const sqlType = "SqlServer"
	return GetSqlDriver(sqlType, variable.ConfigGormv2Yml.GetInt("Gormv2."+sqlType+".IsOpenReadDb"))
}
// GetOnePostgreSqlClient returns a gorm client for the "Postgresql"
// configuration section. The read-replica switch is taken from the
// Gormv2.Postgresql.IsOpenReadDb setting and passed straight through to
// GetSqlDriver.
func GetOnePostgreSqlClient() (*gorm.DB, error) {
	const sqlType = "Postgresql"
	return GetSqlDriver(sqlType, variable.ConfigGormv2Yml.GetInt("Gormv2."+sqlType+".IsOpenReadDb"))
}
// GetSqlDriver opens a gorm connection for the given database type, installs
// the project's query/create/update callbacks and configures the connection
// pool from the "Gormv2.<sqlType>" configuration section.
//
// Parameters:
//   - sqlType: database type; it selects the dialector and is used verbatim in
//     the configuration keys.
//   - readDbIsOpen: when equal to 1, a read replica is registered through the
//     dbresolver plugin in addition to the write connection.
//   - dbConf: optional connection parameters forwarded to getDbDialector, which
//     makes it possible to connect to any number of databases dynamically.
//
// Returns the ready-to-use *gorm.DB (in debug mode when Gormv2.SqlDebug is
// true), or a nil handle and an error as soon as any initialization step fails.
func GetSqlDriver(sqlType string, readDbIsOpen int, dbConf ...ConfigParams) (*gorm.DB, error) {
	writeDialector, err := getDbDialector(sqlType, "Write", dbConf...)
	if err != nil {
		// Stop here: continuing would hand a nil dialector to gorm.Open.
		variable.ZapLog.Error(my_errors.ErrorsDialectorDbInitFail+sqlType, zap.Error(err))
		return nil, err
	}

	gormDb, err := gorm.Open(writeDialector, &gorm.Config{
		SkipDefaultTransaction: true,
		PrepareStmt:            true,
		Logger:                 redefineLog(sqlType), // intercept and take over gorm v2's built-in logging
	})
	if err != nil {
		// gorm database driver initialization failed
		return nil, err
	}

	// When read/write splitting is enabled, register the read database as a
	// dbresolver replica.
	if readDbIsOpen == 1 {
		readDialector, readErr := getDbDialector(sqlType, "Read", dbConf...)
		if readErr != nil {
			// Do not silently fall back to the write dialector as a "replica".
			variable.ZapLog.Error(my_errors.ErrorsDialectorDbInitFail+sqlType, zap.Error(readErr))
			return nil, readErr
		}
		resolverConf := dbresolver.Config{
			Replicas: []gorm.Dialector{readDialector}, // read database, used for query operations
			Policy:   dbresolver.RandomPolicy{},       // sources/replicas load-balancing policy
		}
		resolver := dbresolver.Register(resolverConf).
			SetConnMaxIdleTime(time.Second * 30).
			SetConnMaxLifetime(variable.ConfigGormv2Yml.GetDuration("Gormv2."+sqlType+".Read.SetConnMaxLifetime") * time.Second).
			SetMaxIdleConns(variable.ConfigGormv2Yml.GetInt("Gormv2." + sqlType + ".Read.SetMaxIdleConns")).
			SetMaxOpenConns(variable.ConfigGormv2Yml.GetInt("Gormv2." + sqlType + ".Read.SetMaxOpenConns"))
		if err := gormDb.Use(resolver); err != nil {
			return nil, err
		}
	}

	// Suppress the error gorm v2 raises when a query returns no rows.
	// This addresses https://github.com/go-gorm/gorm/issues/3789
	if err := gormDb.Callback().Query().Before("gorm:query").Register("disable_raise_record_not_found", MaskNotDataError); err != nil {
		return nil, fmt.Errorf("register query callback disable_raise_record_not_found: %w", err)
	}
	// https://github.com/go-gorm/gorm/issues/4838
	if err := gormDb.Callback().Create().Before("gorm:before_create").Register("CreateBeforeHook", CreateBeforeHook); err != nil {
		return nil, fmt.Errorf("register create callback CreateBeforeHook: %w", err)
	}
	// Registered so that gorm's callback chain is fully supported on updates too.
	if err := gormDb.Callback().Update().Before("gorm:before_update").Register("UpdateBeforeHook", UpdateBeforeHook); err != nil {
		return nil, fmt.Errorf("register update callback UpdateBeforeHook: %w", err)
	}

	// Configure the connection pool of the main (write) connection opened above.
	rawDb, err := gormDb.DB()
	if err != nil {
		return nil, err
	}
	rawDb.SetConnMaxIdleTime(time.Second * 30)
	rawDb.SetConnMaxLifetime(variable.ConfigGormv2Yml.GetDuration("Gormv2."+sqlType+".Write.SetConnMaxLifetime") * time.Second)
	rawDb.SetMaxIdleConns(variable.ConfigGormv2Yml.GetInt("Gormv2." + sqlType + ".Write.SetMaxIdleConns"))
	rawDb.SetMaxOpenConns(variable.ConfigGormv2Yml.GetInt("Gormv2." + sqlType + ".Write.SetMaxOpenConns"))

	// Global SQL debug switch.
	if variable.ConfigGormv2Yml.GetBool("Gormv2.SqlDebug") {
		return gormDb.Debug(), nil
	}
	return gormDb, nil
}
// getDbDialector builds the gorm Dialector for the requested database type.
//
// The DSN is produced by getDsn from sqlType, readWrite and the optional
// dbConf. sqlType is matched case-insensitively against the supported names
// (mysql; sqlserver/mssql; postgres/postgresql/postgre). Any other value
// yields a nil Dialector and an error whose text is
// my_errors.ErrorsDbDriverNotExists followed by sqlType.
func getDbDialector(sqlType, readWrite string, dbConf ...ConfigParams) (gorm.Dialector, error) {
	dsn := getDsn(sqlType, readWrite, dbConf...)

	switch kind := strings.ToLower(sqlType); kind {
	case "mysql":
		return mysql.Open(dsn), nil
	case "sqlserver", "mssql":
		return sqlserver.Open(dsn), nil
	case "postgres", "postgresql", "postgre":
		return postgres.Open(dsn), nil
	}
	return nil, errors.New(my_errors.ErrorsDbDriverNotExists + sqlType)
}
// getDsn assembles the driver DSN string for the given database type.
//
// Connection values are first read from the configuration keys
// "Gormv2.<sqlType>.<readWrite>.<Field>". If dbConf is supplied, only its first
// element is consulted: when readWrite equals "write" (case-insensitive) the
// Write settings are used, otherwise the Read settings. A non-empty string
// field or a positive Port replaces the corresponding configured value.
//
// The result is formatted for mysql, sqlserver/mssql or
// postgres/postgresql/postgre (matched case-insensitively); any other sqlType
// produces an empty string.
func getDsn(sqlType, readWrite string, dbConf ...ConfigParams) string {
	keyPrefix := "Gormv2." + sqlType + "." + readWrite + "."
	host := variable.ConfigGormv2Yml.GetString(keyPrefix + "Host")
	database := variable.ConfigGormv2Yml.GetString(keyPrefix + "DataBase")
	port := variable.ConfigGormv2Yml.GetInt(keyPrefix + "Port")
	user := variable.ConfigGormv2Yml.GetString(keyPrefix + "User")
	pass := variable.ConfigGormv2Yml.GetString(keyPrefix + "Pass")
	charset := variable.ConfigGormv2Yml.GetString(keyPrefix + "Charset")

	// Caller-supplied parameters take precedence over the configuration file.
	if len(dbConf) > 0 {
		override := dbConf[0]
		switch strings.ToLower(readWrite) {
		case "write":
			w := override.Write
			if w.Host != "" {
				host = w.Host
			}
			if w.DataBase != "" {
				database = w.DataBase
			}
			if w.Port > 0 {
				port = w.Port
			}
			if w.User != "" {
				user = w.User
			}
			if w.Pass != "" {
				pass = w.Pass
			}
			if w.Charset != "" {
				charset = w.Charset
			}
		default:
			r := override.Read
			if r.Host != "" {
				host = r.Host
			}
			if r.DataBase != "" {
				database = r.DataBase
			}
			if r.Port > 0 {
				port = r.Port
			}
			if r.User != "" {
				user = r.User
			}
			if r.Pass != "" {
				pass = r.Pass
			}
			if r.Charset != "" {
				charset = r.Charset
			}
		}
	}

	var dsn string
	switch strings.ToLower(sqlType) {
	case "mysql":
		dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=false&loc=Local", user, pass, host, port, database, charset)
	case "sqlserver", "mssql":
		dsn = fmt.Sprintf("server=%s;port=%d;database=%s;user id=%s;password=%s;encrypt=disable", host, port, database, user, pass)
	case "postgresql", "postgre", "postgres":
		dsn = fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable TimeZone=Asia/Shanghai", host, port, database, user, pass)
	}
	return dsn
}
// redefineLog creates the custom logger that is handed to gorm so that gorm's
// own log output is intercepted. It calls createCustomGormLog with sqlType and
// the format strings for the info, warn, error and the three trace message
// kinds.
func redefineLog(sqlType string) gormLog.Interface {
	customLog := createCustomGormLog(
		sqlType,
		SetInfoStrFormat("[info] %s\n"),
		SetWarnStrFormat("[warn] %s\n"),
		SetErrStrFormat("[error] %s\n"),
		SetTraceStrFormat("[traceStr] %s [%.3fms] [rows:%v] %s\n"),
		SetTracWarnStrFormat("[traceWarn] %s %s [%.3fms] [rows:%v] %s\n"),
		SetTracErrStrFormat("[traceErr] %s %s [%.3fms] [rows:%v] %s\n"),
	)
	return customLog
}