// Package mysql manages the shared MySQL connection pool: it creates the
// configured database when missing, runs the bootstrap SQL script and exposes
// the resulting *sqlx.DB to the rest of the application.
package mysql

import (
	"fmt"
	"os"
	"strings"

	_ "github.com/go-sql-driver/mysql" // registers the "mysql" driver with database/sql
	"github.com/jmoiron/sqlx"
	"github.com/spf13/viper"
	"go.uber.org/zap"
)

var (
	// sqlFile is the schema script executed by Init, resolved relative to the
	// process working directory.
	sqlFile = "dao/mysql/init.sql"

	// db is the shared pool set up by Init or replaced via ChangeDb.
	db *sqlx.DB
)

// Init connects to the MySQL server described by the "mysql.user",
// "mysql.password", "mysql.host" and "mysql.port" viper keys and creates the
// "mysql.dbname" database if it does not exist. It then opens the package pool
// with that database selected, runs the script at sqlFile and applies the pool
// limits. Every failure is logged via zap and returned.
func Init() error {
	user := viper.GetString("mysql.user")
	password := viper.GetString("mysql.password")
	host := viper.GetString("mysql.host")
	port := viper.GetInt("mysql.port")
	dbName := viper.GetString("mysql.dbname")

	// Server-level connection (no default schema), needed only to create the
	// database. It is kept local and closed right after use so its pool does
	// not leak when the package pool is opened below.
	serverDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/", user, password, host, port)
	serverDB, err := sqlx.Connect("mysql", serverDSN)
	if err != nil {
		zap.L().Error("connect DB failed", zap.Error(err))
		return err
	}
	err = createDatabaseIfNotExists(serverDB, dbName)
	// Best-effort close: the server-level pool is not used again either way.
	_ = serverDB.Close()
	if err != nil {
		zap.L().Error("create database failed", zap.Error(err))
		return err
	}

	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", user, password, host, port, dbName)
	db, err = sqlx.Connect("mysql", dsn)
	if err != nil {
		zap.L().Error("connect DB failed", zap.Error(err))
		return err
	}

	if err = importTableIfNotExists(); err != nil {
		zap.L().Error("import Sql failed", zap.Error(err))
		return err
	}

	// NOTE(review): unlike the connection settings these keys are not under
	// the "mysql." prefix; confirm this matches the config file layout.
	db.SetMaxIdleConns(viper.GetInt("max_idle_conns")) // maximum idle connections
	db.SetMaxOpenConns(viper.GetInt("max_open_conns")) // maximum open connections
	return nil
}

// createDatabaseIfNotExists runs CREATE DATABASE IF NOT EXISTS for dbName on
// conn. Identifiers cannot be bound as query parameters, so the name is quoted
// as a MySQL identifier instead of being spliced in raw. An empty name is
// rejected before reaching the server.
func createDatabaseIfNotExists(conn *sqlx.DB, dbName string) error {
	if dbName == "" {
		return fmt.Errorf("create database: empty database name")
	}
	_, err := conn.Exec("CREATE DATABASE IF NOT EXISTS " + quoteIdentifier(dbName))
	return err
}

// quoteIdentifier wraps name in backticks, doubling any embedded backtick, so
// it can be placed into a MySQL statement as a single identifier.
func quoteIdentifier(name string) string {
	return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}

// importTableIfNotExists runs the schema script at sqlFile against the
// package pool.
func importTableIfNotExists() error {
	return importSQL(db, sqlFile)
}

// importSQL reads the script at filePath and executes it on conn one
// ';'-separated fragment at a time, skipping blank fragments. The split is
// naive, so the script must not contain ';' inside string literals, comments
// or compound statements (triggers, procedures). The first failing fragment
// aborts the import. The returned error names the file and the 1-based
// fragment number, and wraps the driver error.
func importSQL(conn *sqlx.DB, filePath string) error {
	content, err := os.ReadFile(filePath)
	if err != nil {
		return fmt.Errorf("read sql file %q: %w", filePath, err)
	}
	for i, stmt := range strings.Split(string(content), ";") {
		if strings.TrimSpace(stmt) == "" {
			continue
		}
		if _, err := conn.Exec(stmt); err != nil {
			return fmt.Errorf("exec fragment %d of %q: %w", i+1, filePath, err)
		}
	}
	return nil
}

// GetDb returns the shared pool. It is nil until Init has connected or
// ChangeDb has been called.
func GetDb() *sqlx.DB {
	return db
}

// ChangeDb replaces the shared pool with anotherDb. The previous pool is
// not closed.
func ChangeDb(anotherDb *sqlx.DB) {
	db = anotherDb
}

// Close releases the shared pool. It is a no-op when no pool has been set.
func Close() {
	if db == nil {
		return
	}
	// Best-effort shutdown: callers have no way to act on a close error here.
	_ = db.Close()
}