|
|
package gorm_v2
|
|
|
|
|
|
import (
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"go.uber.org/zap"
|
|
|
"gorm.io/driver/mysql"
|
|
|
"gorm.io/driver/postgres"
|
|
|
"gorm.io/driver/sqlserver"
|
|
|
"gorm.io/gorm"
|
|
|
gormLog "gorm.io/gorm/logger"
|
|
|
"gorm.io/plugin/dbresolver"
|
|
|
"goskeleton/app/global/my_errors"
|
|
|
"goskeleton/app/global/variable"
|
|
|
"strings"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
// 获取一个 mysql 客户端
|
|
|
func GetOneMysqlClient() (*gorm.DB, error) {
|
|
|
sqlType := "Mysql"
|
|
|
readDbIsOpen := variable.ConfigGormv2Yml.GetInt("Gormv2." + sqlType + ".IsOpenReadDb")
|
|
|
return GetSqlDriver(sqlType, readDbIsOpen)
|
|
|
}
|
|
|
|
|
|
// 获取一个 sqlserver 客户端
|
|
|
func GetOneSqlserverClient() (*gorm.DB, error) {
|
|
|
sqlType := "SqlServer"
|
|
|
readDbIsOpen := variable.ConfigGormv2Yml.GetInt("Gormv2." + sqlType + ".IsOpenReadDb")
|
|
|
return GetSqlDriver(sqlType, readDbIsOpen)
|
|
|
}
|
|
|
|
|
|
// 获取一个 postgresql 客户端
|
|
|
func GetOnePostgreSqlClient() (*gorm.DB, error) {
|
|
|
sqlType := "Postgresql"
|
|
|
readDbIsOpen := variable.ConfigGormv2Yml.GetInt("Gormv2." + sqlType + ".IsOpenReadDb")
|
|
|
return GetSqlDriver(sqlType, readDbIsOpen)
|
|
|
}
|
|
|
|
|
|
// 获取数据库驱动, 可以通过options 动态参数连接任意多个数据库
|
|
|
func GetSqlDriver(sqlType string, readDbIsOpen int, dbConf ...ConfigParams) (*gorm.DB, error) {
|
|
|
|
|
|
var dbDialector gorm.Dialector
|
|
|
if val, err := getDbDialector(sqlType, "Write", dbConf...); err != nil {
|
|
|
variable.ZapLog.Error(my_errors.ErrorsDialectorDbInitFail+sqlType, zap.Error(err))
|
|
|
} else {
|
|
|
dbDialector = val
|
|
|
}
|
|
|
gormDb, err := gorm.Open(dbDialector, &gorm.Config{
|
|
|
SkipDefaultTransaction: true,
|
|
|
PrepareStmt: true,
|
|
|
Logger: redefineLog(sqlType), //拦截、接管 gorm v2 自带日志
|
|
|
})
|
|
|
if err != nil {
|
|
|
//gorm 数据库驱动初始化失败
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
// 如果开启了读写分离,配置读数据库(resource、read、replicas)
|
|
|
// 读写分离配置只
|
|
|
if readDbIsOpen == 1 {
|
|
|
if val, err := getDbDialector(sqlType, "Read", dbConf...); err != nil {
|
|
|
variable.ZapLog.Error(my_errors.ErrorsDialectorDbInitFail+sqlType, zap.Error(err))
|
|
|
} else {
|
|
|
dbDialector = val
|
|
|
}
|
|
|
resolverConf := dbresolver.Config{
|
|
|
Replicas: []gorm.Dialector{dbDialector}, // 读 操作库,查询类
|
|
|
Policy: dbresolver.RandomPolicy{}, // sources/replicas 负载均衡策略适用于
|
|
|
}
|
|
|
err = gormDb.Use(dbresolver.Register(resolverConf).SetConnMaxIdleTime(time.Second * 30).
|
|
|
SetConnMaxLifetime(variable.ConfigGormv2Yml.GetDuration("Gormv2."+sqlType+".Read.SetConnMaxLifetime") * time.Second).
|
|
|
SetMaxIdleConns(variable.ConfigGormv2Yml.GetInt("Gormv2." + sqlType + ".Read.SetMaxIdleConns")).
|
|
|
SetMaxOpenConns(variable.ConfigGormv2Yml.GetInt("Gormv2." + sqlType + ".Read.SetMaxOpenConns")))
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 查询没有数据,屏蔽 gorm v2 包中会爆出的错误
|
|
|
// https://github.com/go-gorm/gorm/issues/3789 此 issue 所反映的问题就是我们本次解决掉的
|
|
|
_ = gormDb.Callback().Query().Before("gorm:query").Register("disable_raise_record_not_found", MaskNotDataError)
|
|
|
|
|
|
// https://github.com/go-gorm/gorm/issues/4838
|
|
|
_ = gormDb.Callback().Create().Before("gorm:before_create").Register("CreateBeforeHook", CreateBeforeHook)
|
|
|
// 为了完美支持gorm的一系列回调函数
|
|
|
_ = gormDb.Callback().Update().Before("gorm:before_update").Register("UpdateBeforeHook", UpdateBeforeHook)
|
|
|
|
|
|
// 为主连接设置连接池(43行返回的数据库驱动指针)
|
|
|
if rawDb, err := gormDb.DB(); err != nil {
|
|
|
return nil, err
|
|
|
} else {
|
|
|
rawDb.SetConnMaxIdleTime(time.Second * 30)
|
|
|
rawDb.SetConnMaxLifetime(variable.ConfigGormv2Yml.GetDuration("Gormv2."+sqlType+".Write.SetConnMaxLifetime") * time.Second)
|
|
|
rawDb.SetMaxIdleConns(variable.ConfigGormv2Yml.GetInt("Gormv2." + sqlType + ".Write.SetMaxIdleConns"))
|
|
|
rawDb.SetMaxOpenConns(variable.ConfigGormv2Yml.GetInt("Gormv2." + sqlType + ".Write.SetMaxOpenConns"))
|
|
|
// 全局sql的debug配置
|
|
|
if variable.ConfigGormv2Yml.GetBool("Gormv2.SqlDebug") {
|
|
|
return gormDb.Debug(), nil
|
|
|
} else {
|
|
|
return gormDb, nil
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 获取一个数据库方言(Dialector),通俗的说就是根据不同的连接参数,获取具体的一类数据库的连接指针
|
|
|
func getDbDialector(sqlType, readWrite string, dbConf ...ConfigParams) (gorm.Dialector, error) {
|
|
|
var dbDialector gorm.Dialector
|
|
|
dsn := getDsn(sqlType, readWrite, dbConf...)
|
|
|
switch strings.ToLower(sqlType) {
|
|
|
case "mysql":
|
|
|
dbDialector = mysql.Open(dsn)
|
|
|
case "sqlserver", "mssql":
|
|
|
dbDialector = sqlserver.Open(dsn)
|
|
|
case "postgres", "postgresql", "postgre":
|
|
|
dbDialector = postgres.Open(dsn)
|
|
|
default:
|
|
|
return nil, errors.New(my_errors.ErrorsDbDriverNotExists + sqlType)
|
|
|
}
|
|
|
return dbDialector, nil
|
|
|
}
|
|
|
|
|
|
// 根据配置参数生成数据库驱动 dsn
|
|
|
func getDsn(sqlType, readWrite string, dbConf ...ConfigParams) string {
|
|
|
Host := variable.ConfigGormv2Yml.GetString("Gormv2." + sqlType + "." + readWrite + ".Host")
|
|
|
DataBase := variable.ConfigGormv2Yml.GetString("Gormv2." + sqlType + "." + readWrite + ".DataBase")
|
|
|
Port := variable.ConfigGormv2Yml.GetInt("Gormv2." + sqlType + "." + readWrite + ".Port")
|
|
|
User := variable.ConfigGormv2Yml.GetString("Gormv2." + sqlType + "." + readWrite + ".User")
|
|
|
Pass := variable.ConfigGormv2Yml.GetString("Gormv2." + sqlType + "." + readWrite + ".Pass")
|
|
|
Charset := variable.ConfigGormv2Yml.GetString("Gormv2." + sqlType + "." + readWrite + ".Charset")
|
|
|
|
|
|
if len(dbConf) > 0 {
|
|
|
if strings.ToLower(readWrite) == "write" {
|
|
|
if len(dbConf[0].Write.Host) > 0 {
|
|
|
Host = dbConf[0].Write.Host
|
|
|
}
|
|
|
if len(dbConf[0].Write.DataBase) > 0 {
|
|
|
DataBase = dbConf[0].Write.DataBase
|
|
|
}
|
|
|
if dbConf[0].Write.Port > 0 {
|
|
|
Port = dbConf[0].Write.Port
|
|
|
}
|
|
|
if len(dbConf[0].Write.User) > 0 {
|
|
|
User = dbConf[0].Write.User
|
|
|
}
|
|
|
if len(dbConf[0].Write.Pass) > 0 {
|
|
|
Pass = dbConf[0].Write.Pass
|
|
|
}
|
|
|
if len(dbConf[0].Write.Charset) > 0 {
|
|
|
Charset = dbConf[0].Write.Charset
|
|
|
}
|
|
|
} else {
|
|
|
if len(dbConf[0].Read.Host) > 0 {
|
|
|
Host = dbConf[0].Read.Host
|
|
|
}
|
|
|
if len(dbConf[0].Read.DataBase) > 0 {
|
|
|
DataBase = dbConf[0].Read.DataBase
|
|
|
}
|
|
|
if dbConf[0].Read.Port > 0 {
|
|
|
Port = dbConf[0].Read.Port
|
|
|
}
|
|
|
if len(dbConf[0].Read.User) > 0 {
|
|
|
User = dbConf[0].Read.User
|
|
|
}
|
|
|
if len(dbConf[0].Read.Pass) > 0 {
|
|
|
Pass = dbConf[0].Read.Pass
|
|
|
}
|
|
|
if len(dbConf[0].Read.Charset) > 0 {
|
|
|
Charset = dbConf[0].Read.Charset
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
switch strings.ToLower(sqlType) {
|
|
|
case "mysql":
|
|
|
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=false&loc=Local", User, Pass, Host, Port, DataBase, Charset)
|
|
|
case "sqlserver", "mssql":
|
|
|
return fmt.Sprintf("server=%s;port=%d;database=%s;user id=%s;password=%s;encrypt=disable", Host, Port, DataBase, User, Pass)
|
|
|
case "postgresql", "postgre", "postgres":
|
|
|
return fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable TimeZone=Asia/Shanghai", Host, Port, DataBase, User, Pass)
|
|
|
}
|
|
|
return ""
|
|
|
}
|
|
|
|
|
|
// 创建自定义日志模块,对 gorm 日志进行拦截、
|
|
|
func redefineLog(sqlType string) gormLog.Interface {
|
|
|
return createCustomGormLog(sqlType,
|
|
|
SetInfoStrFormat("[info] %s\n"), SetWarnStrFormat("[warn] %s\n"), SetErrStrFormat("[error] %s\n"),
|
|
|
SetTraceStrFormat("[traceStr] %s [%.3fms] [rows:%v] %s\n"), SetTracWarnStrFormat("[traceWarn] %s %s [%.3fms] [rows:%v] %s\n"), SetTracErrStrFormat("[traceErr] %s %s [%.3fms] [rows:%v] %s\n"))
|
|
|
}
|