Skip to content

Commit

Permalink
Merge pull request #927 from bjwswang/main
Browse files Browse the repository at this point in the history
fix: run worker atop of two gpus in one node
  • Loading branch information
bjwswang authored Mar 26, 2024
2 parents 65aeaff + ec662de commit 5f06644
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 39 deletions.
33 changes: 33 additions & 0 deletions pkg/config/config_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package config

import (
"fmt"

arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1"
)

Expand Down Expand Up @@ -72,4 +74,35 @@ type RayCluster struct {
DashboardHost string `json:"dashboardHost,omitempty"`
// Overwrite the python version in the woker
PythonVersion string `json:"pythonVersion,omitempty"`
// Ray cluster version
RayVersion string `json:"rayVersion,omitempty"`
}

func (rayCluster RayCluster) String() string {
return fmt.Sprintf("Name:%s HeadAddress: %s DashboardHost:%s PythonVersion:%s RayVersion: %s", rayCluster.Name, rayCluster.HeadAddress, rayCluster.DashboardHost, rayCluster.PythonVersion, rayCluster.RayVersion)
}

// GetRayVersion in ray cluster
func (rayCluster RayCluster) GetRayVersion() string {
// Default ray version is 2.9.3
if rayCluster.RayVersion == "" {
return "2.9.3"
}
return rayCluster.RayVersion
}

func (rayCluster RayCluster) GetPythonVersion() string {
// Default python version is 3.9.5
if rayCluster.PythonVersion == "" {
return "3.9.5"
}
return rayCluster.PythonVersion
}

func DefaultRayCluster() RayCluster {
return RayCluster{
Name: "default",
PythonVersion: "3.9.5",
RayVersion: "2.9.3",
}
}
80 changes: 41 additions & 39 deletions pkg/worker/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ const (
defaultFastchatVLLMImage = "kubeagi/arcadia-fastchat-worker:vllm-v0.2.36"
// defaultKubeAGIImage for RunnerKubeAGI
defaultKubeAGIImage = "kubeagi/core-library-cli:v0.0.1"

// mount path in runner
defaultModelMountPath = "/data/models"
defaultShmMountPath = "/dev/shm"
)

// ModelRunner run a model service
Expand Down Expand Up @@ -95,7 +99,7 @@ func (runner *RunnerFastchat) Build(ctx context.Context, model *arcadiav1alpha1.
}
}

modelFileDir := fmt.Sprintf("/data/models/%s", model.Name)
modelFileDir := fmt.Sprintf("%s/%s", defaultModelMountPath, model.Name)
additionalEnvs := []corev1.EnvVar{}
extraArgs := fmt.Sprintf("--device %s %s", runner.Device().String(), extraAgrs)
if runner.modelFileFromRemote {
Expand Down Expand Up @@ -139,7 +143,7 @@ func (runner *RunnerFastchat) Build(ctx context.Context, model *arcadiav1alpha1.
{Name: "http", ContainerPort: arcadiav1alpha1.DefaultWorkerPort},
},
VolumeMounts: []corev1.VolumeMount{
{Name: "models", MountPath: "/data/models"},
{Name: "models", MountPath: defaultModelMountPath},
},
Resources: runner.w.Spec.Resources,
}
Expand Down Expand Up @@ -187,46 +191,45 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
return nil, fmt.Errorf("failed to get arcadia config with %w", err)
}

rayEnabled := false
rayClusterAddress := ""
pythonVersion := ""
extraAgrs := ""
additionalEnvs := []corev1.EnvVar{}

rayClusterIndex := 0
// configure ray cluster
rayCluster := config.DefaultRayCluster()
for _, envItem := range runner.w.Spec.AdditionalEnvs {
// Check if Ray is enabled using distributed inference
if envItem.Name == "NUMBER_GPUS" {
rayEnabled = true
}
if envItem.Name == "RAY_CLUSTER_INDEX" {
rayClusterIndex, _ = strconv.Atoi(envItem.Value)
rayEnabled = true
externalRayClusterIndex, _ := strconv.Atoi(envItem.Value)
rayClusters, err := config.GetRayClusters(ctx, runner.c)
if err != nil || len(rayClusters) == 0 {
return nil, fmt.Errorf("failed to find ray clusters: %s", err.Error())
}
if len(rayClusters) == 0 {
return nil, fmt.Errorf("no ray clusters configured")
}
rayCluster = rayClusters[externalRayClusterIndex]
}
// extra arguments to run llm
if envItem.Name == "EXTRA_ARGS" {
extraAgrs = envItem.Value
}
}

// Get ray config from configMap
if rayEnabled {
rayClusters, err := config.GetRayClusters(ctx, runner.c)
if err != nil || len(rayClusters) == 0 {
klog.Warningln("no ray cluster configured, fallback to local resource: ", err)
} else {
// Use the 1st ray cluster for now
// TODO: let user to select with ray cluster to use
rayClusterAddress = rayClusters[rayClusterIndex].HeadAddress
pythonVersion = rayClusters[rayClusterIndex].PythonVersion
klog.Infof("run worker using ray: %s, number of GPU: %s", rayClusterAddress, runner.NumberOfGPUs())
}
} else {
// Set gpu number to the number of GPUs in the worker's resource
additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "NUMBER_GPUS", Value: runner.NumberOfGPUs()})
klog.Infof("run worker with %s GPU", runner.NumberOfGPUs())
}

modelFileDir := fmt.Sprintf("/data/models/%s", model.Name)
// set ray configurations into additional environments
additionalEnvs = append(additionalEnvs,
corev1.EnvVar{
Name: "RAY_ADDRESS",
Value: rayCluster.HeadAddress,
}, corev1.EnvVar{
Name: "RAY_VERSION",
Value: rayCluster.GetRayVersion(),
}, corev1.EnvVar{
Name: "PYTHON_VERSION",
Value: rayCluster.GetPythonVersion(),
})
// Set gpu number to the number of GPUs in the worker's resource
additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "NUMBER_GPUS", Value: runner.NumberOfGPUs()})
klog.V(5).Infof("run worker with raycluster:\n %s", rayCluster.String())

modelFileDir := fmt.Sprintf("%s/%s", defaultModelMountPath, model.Name)
// --enforce-eager to disable cupy
// TODO: remove --enforce-eager when https://github.com/kubeagi/arcadia/issues/878 is fixed
extraAgrs = fmt.Sprintf("%s --trust-remote-code --enforce-eager", extraAgrs)
Expand Down Expand Up @@ -264,15 +267,14 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
{Name: "FASTCHAT_WORKER_ADDRESS", Value: fmt.Sprintf("http://%s.%s:%d", runner.w.Name+WokerCommonSuffix, runner.w.Namespace, arcadiav1alpha1.DefaultWorkerPort)},
{Name: "FASTCHAT_CONTROLLER_ADDRESS", Value: gw.Controller},
{Name: "EXTRA_ARGS", Value: extraAgrs},
// Need python version and ray address for distributed inference
{Name: "PYTHON_VERSION", Value: pythonVersion},
{Name: "RAY_ADDRESS", Value: rayClusterAddress},
},
Ports: []corev1.ContainerPort{
{Name: "http", ContainerPort: arcadiav1alpha1.DefaultWorkerPort},
},
VolumeMounts: []corev1.VolumeMount{
{Name: "models", MountPath: "/data/models"},
{Name: "models", MountPath: defaultModelMountPath},
// mount volume to /dev/shm to avoid Bus error
{Name: "models", MountPath: defaultShmMountPath},
},
Resources: runner.w.Spec.Resources,
}
Expand Down Expand Up @@ -322,8 +324,8 @@ func (runner *KubeAGIRunner) Build(ctx context.Context, model *arcadiav1alpha1.T
}

// read worker address
mountPath := "/data/models"
rerankModelPath := fmt.Sprintf("%s/%s", mountPath, model.Name)
modelMountPath := "/data/models"
rerankModelPath := fmt.Sprintf("%s/%s", modelMountPath, model.Name)

if runner.modelFileFromRemote {
m := arcadiav1alpha1.Model{}
Expand Down Expand Up @@ -355,7 +357,7 @@ func (runner *KubeAGIRunner) Build(ctx context.Context, model *arcadiav1alpha1.T
{Name: "http", ContainerPort: arcadiav1alpha1.DefaultWorkerPort},
},
VolumeMounts: []corev1.VolumeMount{
{Name: "models", MountPath: mountPath},
{Name: "models", MountPath: defaultModelMountPath},
},
Resources: runner.w.Spec.Resources,
}
Expand Down

0 comments on commit 5f06644

Please sign in to comment.