diff --git a/server/api/v1/resource_pool.proto b/server/api/v1/resource_pool.proto index 3b18322..bbcb968 100644 --- a/server/api/v1/resource_pool.proto +++ b/server/api/v1/resource_pool.proto @@ -95,12 +95,14 @@ message Nodes { message ResourcePoolCreateRequest { string pool_name = 1; repeated Nodes nodes = 2; + int32 pool_type = 3; } message ResourcePoolUpdateRequest { int64 pool_id = 1; string pool_name = 2; repeated Nodes nodes = 3; + int32 pool_type = 4; } message ResourcePoolDeleteRequest { @@ -154,6 +156,7 @@ message ResourcePoolListData{ int64 disk_size = 8; repeated Nodes node_list = 9; string link_url = 10; + int32 pool_type = 11; } message ResourcePoolListRequest { diff --git a/server/config/db.sql b/server/config/db.sql index 268e0ba..25c07c9 100644 --- a/server/config/db.sql +++ b/server/config/db.sql @@ -18,3 +18,6 @@ create table nodes( ); INSERT INTO hami.resource_pool (id, pool_name) VALUES (1, '大模型资源池'); + + +alter table resource_pool add column pool_type int(8) default 0 comment '类型'; diff --git a/server/internal/database/resource_pool_db.go b/server/internal/database/resource_pool_db.go index 02380ce..e57313a 100644 --- a/server/internal/database/resource_pool_db.go +++ b/server/internal/database/resource_pool_db.go @@ -12,6 +12,7 @@ import ( type ResourcePool struct { Id int64 `db:"id"` PoolName string `db:"pool_name"` + PoolType int32 `db:"pool_type"` CreateTime time.Time `db:"create_time"` UpdateTime time.Time `db:"update_time"` } @@ -42,8 +43,8 @@ func ExistsResourcePoolByPoolName(poolName string) bool { func QueryResourcePoolById(poolId int64) (*ResourcePool, error) { var pool ResourcePool - err := db.QueryRow("SELECT id, pool_name, create_time, update_time FROM resource_pool WHERE id = ?", poolId). - Scan(&pool.Id, &pool.PoolName, &pool.CreateTime, &pool.UpdateTime) + err := db.QueryRow("SELECT id, pool_name, pool_type, create_time, update_time FROM resource_pool WHERE id = ?", poolId). + Scan(&pool.Id, &pool.PoolName, &pool.PoolType, &pool.CreateTime, &pool.UpdateTime) if err != nil { if errors.Is(err, sql.ErrNoRows) { log.Infof("No record found with id %d", poolId) @@ -58,7 +59,7 @@ func QueryResourcePoolById(poolId int64) (*ResourcePool, error) { func QueryResourcePoolListAll() ([]*ResourcePool, error) { // 执行查询 - rows, err := db.Query("SELECT id, pool_name, create_time, update_time FROM resource_pool order by create_time desc") + rows, err := db.Query("SELECT id, pool_name, pool_type, create_time, update_time FROM resource_pool order by create_time desc") if err != nil { log.Infof("Query failed: %v", err) return nil, err @@ -71,7 +72,7 @@ func QueryResourcePoolListAll() ([]*ResourcePool, error) { // 遍历每一行 for rows.Next() { var pool ResourcePool - err := rows.Scan(&pool.Id, &pool.PoolName, &pool.CreateTime, &pool.UpdateTime) + err := rows.Scan(&pool.Id, &pool.PoolName, &pool.PoolType, &pool.CreateTime, &pool.UpdateTime) if err != nil { log.Infof("Scan failed: %v", err) return nil, err @@ -211,10 +212,10 @@ func QueryResourceNamesByNodeName(nodeName string) ([]string, error) { return resourcePoolNames, nil } -func InsertResourcePool(poolName string) (int64, error) { - querySql := "INSERT INTO resource_pool(pool_name) VALUES (?)" +func InsertResourcePool(poolName string, poolType int32) (int64, error) { + querySql := "INSERT INTO resource_pool(pool_name, pool_type) VALUES (?, ?)" - result, err := db.Exec(querySql, poolName) + result, err := db.Exec(querySql, poolName, poolType) if err != nil { log.Infof("Failed to insert record: %v", err) return 0, err @@ -229,9 +230,9 @@ func InsertResourcePool(poolName string) (int64, error) { return id, nil } -func UpdateResourcePool(poolId int64, poolName string) (int64, error) { - updateSql := "UPDATE resource_pool SET pool_name=? where id=?" - result, err := db.Exec(updateSql, poolName, poolId) +func UpdateResourcePool(poolId int64, poolName string, poolType int32) (int64, error) { + updateSql := "UPDATE resource_pool SET pool_name=?, pool_type=? where id=?" + result, err := db.Exec(updateSql, poolName, poolType, poolId) if err != nil { log.Infof("Failed to update record: %v", err) return 0, err @@ -259,6 +260,7 @@ func InsertNodes(poolId int64, nodes []*NodeInfo) (int64, error) { strings.Join(valueStrings, ","), ) + log.Info("InsertNodes: ", insertSql) result, err := db.Exec(insertSql, valueArgs...) if err != nil { log.Infof("Batch insert failed: %v", err) diff --git a/server/internal/service/resource_pool.go b/server/internal/service/resource_pool.go index b2c3c78..90db032 100644 --- a/server/internal/service/resource_pool.go +++ b/server/internal/service/resource_pool.go @@ -30,17 +30,23 @@ func NewResourcePoolService(uc *biz.NodeUsecase, pod *biz.PodUseCase, summary *b func (s *ResourcePoolService) Create(ctx context.Context, req *pb.ResourcePoolCreateRequest) (*pb.BaseResponse, error) { log.Info("CreateResourcePool called", req) poolName := req.PoolName + poolType := req.PoolType if database.ExistsResourcePoolByPoolName(poolName) { return &pb.BaseResponse{Code: 500, Message: "资源池:'" + poolName + "'已经存在"}, nil } - poolId, err := database.InsertResourcePool(poolName) + poolId, err := database.InsertResourcePool(poolName, poolType) if err != nil { return &pb.BaseResponse{Code: 500, Message: poolName + "创建资源池失败"}, nil } - nodes := make([]*database.NodeInfo, 0, len(req.Nodes)) + nodeSize := len(req.Nodes) + if poolType != 2 && nodeSize > 1 { + return &pb.BaseResponse{Code: 500, Message: "非多机多卡只能选择一个节点"}, nil + } + + nodes := make([]*database.NodeInfo, 0, nodeSize) for _, node := range req.Nodes { nodes = append(nodes, &database.NodeInfo{ Name: node.NodeName, @@ -58,8 +64,9 @@ func (s *ResourcePoolService) Create(ctx context.Context, req *pb.ResourcePoolCr } func (s *ResourcePoolService) Update(ctx context.Context, req *pb.ResourcePoolUpdateRequest) (*pb.BaseResponse, error) { - log.Info("UpdateResourcePool called", req) + log.Info("UpdateResourcePool called ", req) poolId := req.PoolId + poolType := req.PoolType resourcePool, err := database.QueryResourcePoolById(poolId) if err != nil { return &pb.BaseResponse{Code: 500, Message: "更新资源池失败"}, nil @@ -74,7 +81,12 @@ func (s *ResourcePoolService) Update(ctx context.Context, req *pb.ResourcePoolUp return &pb.BaseResponse{Code: 500, Message: "更新资源池失败"}, nil } - nodes := make([]*database.NodeInfo, 0, len(req.Nodes)) + nodeSize := len(req.Nodes) + if poolType != 2 && nodeSize > 1 { + return &pb.BaseResponse{Code: 500, Message: "非多机多卡只能选择一个节点"}, nil + } + + nodes := make([]*database.NodeInfo, 0, nodeSize) for _, node := range req.Nodes { nodes = append(nodes, &database.NodeInfo{ Name: node.NodeName, @@ -82,7 +94,7 @@ func (s *ResourcePoolService) Update(ctx context.Context, req *pb.ResourcePoolUp }) } _, err = database.InsertNodes(poolId, nodes) - _, err = database.UpdateResourcePool(poolId, req.PoolName) + _, err = database.UpdateResourcePool(poolId, req.PoolName, poolType) if err != nil { return &pb.BaseResponse{Code: 500, Message: "更新资源池失败"}, nil } @@ -144,6 +156,7 @@ func (s *ResourcePoolService) List(ctx context.Context, req *pb.ResourcePoolList var poolData pb.ResourcePoolListData poolData.PoolId = resourcePool.Id poolData.PoolName = resourcePool.PoolName + poolData.PoolType = resourcePool.PoolType dbNodes, _ := database.QueryNodesByPoolId(resourcePool.Id) poolData.NodeNum = int64(len(dbNodes)) diff --git a/server/openapi.yaml b/server/openapi.yaml index 0b450b4..3339b74 100644 --- a/server/openapi.yaml +++ b/server/openapi.yaml @@ -289,6 +289,9 @@ components: type: array items: $ref: '#/components/schemas/Nodes' + poolType: + type: integer + format: int32 ResourcePoolDeleteRequest: type: object properties: @@ -331,6 +334,9 @@ components: $ref: '#/components/schemas/Nodes' linkUrl: type: string + poolType: + type: integer + format: int32 ResourcePoolListResponse: type: object properties: @@ -349,6 +355,9 @@ components: type: array items: $ref: '#/components/schemas/Nodes' + poolType: + type: integer + format: int32 Status: type: object properties: