You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
2.0 KiB
101 lines
2.0 KiB
package mysql
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/spf13/viper"
|
|
"go.uber.org/zap"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
)
|
|
|
|
var sqlFile = "dao/mysql/init.sql"
|
|
|
|
var db *sqlx.DB
|
|
|
|
func Init() (err error) {
|
|
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/",
|
|
viper.GetString("mysql.user"),
|
|
viper.GetString("mysql.password"),
|
|
viper.GetString("mysql.host"),
|
|
viper.GetInt("mysql.port"),
|
|
)
|
|
db, err = sqlx.Connect("mysql", dsn)
|
|
if err != nil {
|
|
zap.L().Error("connect DB failed", zap.Error(err))
|
|
return err
|
|
}
|
|
dbName := viper.GetString("mysql.dbname")
|
|
err = createDatabaseIfNotExists(db, dbName)
|
|
if err != nil {
|
|
zap.L().Error("connect Database failed", zap.Error(err))
|
|
return err
|
|
}
|
|
dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s",
|
|
viper.GetString("mysql.user"),
|
|
viper.GetString("mysql.password"),
|
|
viper.GetString("mysql.host"),
|
|
viper.GetInt("mysql.port"),
|
|
viper.GetString("mysql.dbname"),
|
|
)
|
|
db, err = sqlx.Connect("mysql", dsn)
|
|
if err != nil {
|
|
zap.L().Error("connect DB failed", zap.Error(err))
|
|
return err
|
|
}
|
|
err = importTableIfNotExists()
|
|
if err != nil {
|
|
zap.L().Error("import Sql failed", zap.Error(err))
|
|
return err
|
|
}
|
|
// 最大闲置连接
|
|
db.SetMaxIdleConns(viper.GetInt("max_idle_conns"))
|
|
// 最大连接
|
|
db.SetMaxOpenConns(viper.GetInt("max_open_conns"))
|
|
return
|
|
}
|
|
|
|
func createDatabaseIfNotExists(db *sqlx.DB, dbName string) error {
|
|
_, err := db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbName))
|
|
return err
|
|
}
|
|
|
|
func importTableIfNotExists() error {
|
|
return importSQL(db, sqlFile)
|
|
}
|
|
|
|
func importSQL(db *sqlx.DB, filePath string) error {
|
|
sqlBytes, err := os.ReadFile(filePath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
sqlStatements := string(sqlBytes)
|
|
statements := strings.Split(sqlStatements, ";")
|
|
|
|
for _, statement := range statements {
|
|
if strings.TrimSpace(statement) != "" {
|
|
_, err = db.Exec(statement)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func GetDb() *sqlx.DB {
|
|
return db
|
|
}
|
|
|
|
func ChangeDb(anotherDb *sqlx.DB) {
|
|
db = anotherDb
|
|
}
|
|
|
|
func Close() {
|
|
_ = db.Close()
|
|
}
|