diff --git a/server/api/v1/node.proto b/server/api/v1/node.proto index f39f70d..1ce2e52 100644 --- a/server/api/v1/node.proto +++ b/server/api/v1/node.proto @@ -46,6 +46,26 @@ service Node { summary: "禁用启用节点"; }; } + + rpc DiscoveredNode (DiscoveredNodeRequest) returns (DiscoveredNodeResponse) { + option (google.api.http) = { + post: "/v1/node/discovered", + body: "*" + }; + option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = { + summary: "发现节点"; + }; + } + + rpc JoinNode (JoinNodeRequest) returns (JoinNodeResponse) { + option (google.api.http) = { + post: "/v1/node/join", + body: "*" + }; + option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = { + summary: "加入节点"; + }; + } } message GetSummaryReq { @@ -119,4 +139,28 @@ message UpdateNodeStatusRequest { message UpdateNodeStatusResponse { int32 code = 1; string message = 2; -} \ No newline at end of file +} + + +message DiscoveredNodeRequest { + +} + +message DiscoveredNodeResponse { + repeated DiscoveredNodeInfo list = 1; +} + +message DiscoveredNodeInfo{ + string node_ip = 1; + string node_name = 2; +} + +message JoinNodeRequest{ + repeated string node_names = 1; +} + + +message JoinNodeResponse { + int32 code = 1; + string message = 2; +} diff --git a/server/internal/biz/node.go b/server/internal/biz/node.go index 7ab9e5b..adbe20a 100644 --- a/server/internal/biz/node.go +++ b/server/internal/biz/node.go @@ -2,6 +2,7 @@ package biz import ( "context" + "vgpu/internal/database" "github.com/go-kratos/kratos/v2/log" ) @@ -27,6 +28,7 @@ type Node struct { AvailableMemory int64 // 可用内存(字节) DiskTotal int64 // 磁盘总大小(字节) StorageNum int64 + Lables map[string]string } type DeviceInfo struct { @@ -60,6 +62,8 @@ type NodeRepo interface { FindDeviceByAliasId(string) (*DeviceInfo, error) EnableNode(context.Context, string) error DisableNode(context.Context, string) error + DiscoveredNode() ([]*database.Nodes, error) + JoinNode([]string) error } type NodeUsecase struct { @@ -98,3 +102,11 @@ func (uc *NodeUsecase) EnableNode(ctx context.Context, nodeName string) error { func (uc *NodeUsecase) DisableNode(ctx context.Context, nodeName string) error { return uc.repo.DisableNode(ctx, nodeName) } + +func (uc *NodeUsecase) DiscoveredNode() ([]*database.Nodes, error) { + return uc.repo.DiscoveredNode() +} + +func (uc *NodeUsecase) JoinNode(nodeNames []string) error { + return uc.repo.JoinNode(nodeNames) +} diff --git a/server/internal/data/node.go b/server/internal/data/node.go index f716e17..02415bb 100644 --- a/server/internal/data/node.go +++ b/server/internal/data/node.go @@ -17,6 +17,7 @@ import ( "sync" "time" "vgpu/internal/biz" + "vgpu/internal/database" "vgpu/internal/provider" "vgpu/internal/provider/ascend" "vgpu/internal/provider/hygon" @@ -104,7 +105,8 @@ func (r *nodeRepo) updateLocalNodes() { r.nodes = n var all []*biz.Node - allNodes, _ := r.nodeLister.List(labels.Everything()) + + allNodes, _ := r.nodeLister.List(labels.Set{"gpu": "on"}.AsSelector()) for _, node := range allNodes { bizNode := r.fetchNodeInfo(node) gpuNode := n[k8stypes.UID(bizNode.Uid)] @@ -145,8 +147,6 @@ func (r *nodeRepo) onDeletedNode(obj interface{}) { } func (r *nodeRepo) fetchNodeInfo(node *corev1.Node) *biz.Node { - //b, _ := json.MarshalIndent(node, "", " ") - //fmt.Println(string(b)) n := &biz.Node{IsSchedulable: !node.Spec.Unschedulable} for _, addr := range node.Status.Addresses { if addr.Type == corev1.NodeInternalIP { @@ -159,6 +159,7 @@ func (r *nodeRepo) fetchNodeInfo(node *corev1.Node) *biz.Node { n.IsReady = true } } + n.Uid = string(node.UID) n.Name = node.Name n.OSImage = node.Status.NodeInfo.OSImage @@ -169,6 +170,7 @@ func (r *nodeRepo) fetchNodeInfo(node *corev1.Node) *biz.Node { n.KubeProxyVersion = node.Status.NodeInfo.KubeProxyVersion n.Architecture = strings.ToUpper(node.Status.NodeInfo.Architecture) n.CreationTimestamp = node.CreationTimestamp.Format("2006-01-02 15:04:05") + n.Lables = node.Labels capacity := node.Status.Capacity allocatable := node.Status.Allocatable @@ -242,7 +244,6 @@ func (r *nodeRepo) FindDeviceByAliasId(aliasId string) (*biz.DeviceInfo, error) return nil, errors.New(fmt.Sprintf("aliasID:%s device not found", aliasId)) } -// DisableNode 禁用节点(标记为不可调度并排空Pod) func (r *nodeRepo) EnableNode(ctx context.Context, nodeName string) error { // 1. 标记为可调度 patch := []byte(`{"spec":{"unschedulable":false}}`) @@ -266,7 +267,6 @@ func (r *nodeRepo) EnableNode(ctx context.Context, nodeName string) error { return nil } -// DisableNode 禁用节点(标记为不可调度并排空Pod) func (r *nodeRepo) DisableNode(ctx context.Context, nodeName string) error { // 1. 标记为不可调度 patch := []byte(`{"spec":{"unschedulable":true}}`) @@ -289,6 +289,60 @@ func (r *nodeRepo) DisableNode(ctx context.Context, nodeName string) error { return nil } +func (r *nodeRepo) DiscoveredNode() ([]*database.Nodes, error) { + distinctNodes, err := database.QueryDistinctNodes() + if err != nil { + return nil, err + } + + ipSet := make(map[string]struct{}) + for _, value := range distinctNodes { + ipSet[value.NodeIp] = struct{}{} + } + + var discoverNodes []*database.Nodes + for _, value := range r.allNodes { + if value.Lables["gpu"] == "on" { + continue + } + log.Infof("发现节点%s", value.IP) + if _, found := ipSet[value.IP]; !found { + discoverNodes = append(discoverNodes, &database.Nodes{ + NodeIp: value.IP, + NodeName: value.Name, + }) + } + } + + return discoverNodes, err +} + +func (r *nodeRepo) JoinNode(nodeNames []string) error { + for _, nodeName := range nodeNames { + err := r.labelNode(nodeName) + if err != nil { + return err + } + } + return nil +} + +func (r *nodeRepo) labelNode(nodeName string) error { + node, err := r.data.k8sCl.CoreV1().Nodes().Get(context.TODO(), nodeName, metav1.GetOptions{}) + if err != nil { + return err + } + + node.Labels["gpu"] = "on" + _, err = r.data.k8sCl.CoreV1().Nodes().Update(context.TODO(), node, metav1.UpdateOptions{}) + if err != nil { + return fmt.Errorf("failed to label node: %v", err) + } + + fmt.Printf("Successfully labeled node %s\n", nodeName) + return nil +} + func (r *nodeRepo) evictPodsOnNode(ctx context.Context, nodeName string) error { // 获取该节点上的 Pod 列表 pods, err := r.data.k8sCl.CoreV1().Pods("").List(ctx, metav1.ListOptions{ diff --git a/server/internal/database/resource_pool_db.go b/server/internal/database/resource_pool_db.go index 8403656..992719f 100644 --- a/server/internal/database/resource_pool_db.go +++ b/server/internal/database/resource_pool_db.go @@ -118,6 +118,37 @@ func QueryNodesByPoolId(poolId int64) ([]*Nodes, error) { return nodes, nil } +func QueryDistinctNodes() ([]*Nodes, error) { + // 执行查询 + rows, err := db.Query("select distinct nodes.node_name, nodes.node_ip from nodes") + if err != nil { + log.Infof("Query failed: %v", err) + return nil, err + } + defer rows.Close() + + // 存放结果的切片 + nodes := make([]*Nodes, 0) + + // 遍历每一行 + for rows.Next() { + var node Nodes + err := rows.Scan(&node.NodeName, &node.NodeIp) + if err != nil { + log.Infof("Scan failed: %v", err) + return nil, err + } + nodes = append(nodes, &node) + } + + // 检查 rows 是否遍历中出错 + if err := rows.Err(); err != nil { + return nil, err + } + + return nodes, nil +} + func InsertResourcePool(poolName string) (int64, error) { querySql := "INSERT INTO resource_pool(pool_name) VALUES (?)" diff --git a/server/internal/service/node.go b/server/internal/service/node.go index 01895ac..5def2b2 100644 --- a/server/internal/service/node.go +++ b/server/internal/service/node.go @@ -105,6 +105,29 @@ func (s *NodeService) UpdateNodeStatus(ctx context.Context, req *pb.UpdateNodeSt return &pb.UpdateNodeStatusResponse{Code: 200, Message: "成功"}, nil } +func (s *NodeService) DiscoveredNode(ctx context.Context, req *pb.DiscoveredNodeRequest) (*pb.DiscoveredNodeResponse, error) { + nodes, err := s.uc.DiscoveredNode() + if err != nil { + return nil, err + } + + var list []*pb.DiscoveredNodeInfo + for _, value := range nodes { + list = append(list, &pb.DiscoveredNodeInfo{NodeIp: value.NodeIp, NodeName: value.NodeName}) + } + + return &pb.DiscoveredNodeResponse{List: list}, nil +} + +func (s *NodeService) JoinNode(ctx context.Context, req *pb.JoinNodeRequest) (*pb.JoinNodeResponse, error) { + err := s.uc.JoinNode(req.NodeNames) + if err != nil { + return &pb.JoinNodeResponse{Code: 500, Message: err.Error()}, err + } + + return &pb.JoinNodeResponse{Code: 200, Message: "成功"}, nil +} + func (s *NodeService) buildNodeReply(ctx context.Context, node *biz.Node) (*pb.NodeReply, error) { nodeReply := &pb.NodeReply{ Name: node.Name, diff --git a/server/openapi.yaml b/server/openapi.yaml index 96fe1e9..ad4b343 100644 --- a/server/openapi.yaml +++ b/server/openapi.yaml @@ -29,6 +29,54 @@ paths: application/json: schema: $ref: '#/components/schemas/Status' + /v1/node/discovered: + post: + tags: + - Node + operationId: Node_DiscoveredNode + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/DiscoveredNodeRequest' + required: true + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/DiscoveredNodeResponse' + default: + description: Default error response + content: + application/json: + schema: + $ref: '#/components/schemas/Status' + /v1/node/join: + post: + tags: + - Node + operationId: Node_JoinNode + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/JoinNodeRequest' + required: true + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/JoinNodeResponse' + default: + description: Default error response + content: + application/json: + schema: + $ref: '#/components/schemas/Status' /v1/node/status/update: post: tags: @@ -130,6 +178,23 @@ components: nodeCount: type: integer format: int32 + DiscoveredNodeInfo: + type: object + properties: + nodeIp: + type: string + nodeName: + type: string + DiscoveredNodeRequest: + type: object + properties: {} + DiscoveredNodeResponse: + type: object + properties: + list: + type: array + items: + $ref: '#/components/schemas/DiscoveredNodeInfo' GetAllNodesReq: type: object properties: @@ -166,6 +231,21 @@ components: description: The type of the serialized message. additionalProperties: true description: Contains an arbitrary serialized message along with a @type that describes the type of the serialized message. + JoinNodeRequest: + type: object + properties: + nodeNames: + type: array + items: + type: string + JoinNodeResponse: + type: object + properties: + code: + type: integer + format: int32 + message: + type: string NodeReply: type: object properties: