diff --git a/llm-finetuning/README.md b/llm-finetuning/README.md index 1d20fe58..c3de0383 100644 --- a/llm-finetuning/README.md +++ b/llm-finetuning/README.md @@ -78,13 +78,51 @@ python run.py --training-pipeline --config finetune_gcp.yaml # Deployment python run.py --deployment-pipeline --config -python run.py --deployment-pipeline --config finetune_gcp.yaml +python run.py --deployment-pipeline --config deployment_a100.yaml ``` The `feature_engineering` and `deployment` pipelines can be run with the `default` stack, but the training pipeline's [stack](https://docs.zenml.io/user-guide/production-guide/understand-stacks) will depend on the config. The `deployment` pipeline relies on the `training_pipeline` having run before. +## :cloud: Deployment + +We have created a custom ZenML model deployer for deploying models to Hugging Face Inference Endpoints. The code for the custom deployer is in the [huggingface](./huggingface/) folder. + +To run the deployment pipeline, we create a custom ZenML stack. Because we are using a custom model deployer, we have to register both the flavor and the model deployer, and then update the stack to use this custom model deployer for the deployment pipeline. + +```bash +zenml init +zenml stack register zencoder_hf_stack -o default -a default +zenml stack set zencoder_hf_stack +export HUGGINGFACE_USERNAME= +export HUGGINGFACE_TOKEN= +export NAMESPACE= +zenml secret create huggingface_creds --username=$HUGGINGFACE_USERNAME --token=$HUGGINGFACE_TOKEN +zenml model-deployer flavor register huggingface.hf_model_deployer_flavor.HuggingFaceModelDeployerFlavor +``` + +Afterward, you should see the new flavor in the list of available flavors: + +```bash +zenml model-deployer flavor list +``` + +Register the model deployer component and add it to the current stack: + +```bash +zenml model-deployer register hfendpoint --flavor=hfendpoint --token=$HUGGINGFACE_TOKEN --namespace=$NAMESPACE +zenml stack update zencoder_hf_stack -d hfendpoint +``` + +Run the deployment pipeline using the CLI: + +```shell +# Deployment +python run.py --deployment-pipeline --config +python run.py --deployment-pipeline --config deployment_a100.yaml +``` + ## 🥇Recent developments A working prototype has been trained and deployed as of Jan 19 2024. The model uses minimal data and was finetuned using QLoRA and PEFT. The model was trained using 1 A100 GPU on the cloud: @@ -114,6 +152,7 @@ This project recently did a [call of volunteers](https://www.linkedin.com/feed/u - [x] Create a functioning training pipeline. - [ ] Curate a set of 5-10 repositories that use the latest ZenML syntax and use the data generation pipeline to push a dataset to HuggingFace. - [ ] Create a Dockerfile for the training pipeline with all requirements installed including ZenML, torch, CUDA etc. Currently I am having trouble creating this in this [config file](configs/finetune_local.yaml). It probably makes sense to create a Docker image with the right CUDA and requirements including ZenML. See here: https://sdkdocs.zenml.io/0.54.0/integration_code_docs/integrations-aws/#zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorSettings + - [ ] Test the trained model on various metrics. - [ ] Create a custom [model deployer](https://docs.zenml.io/stacks-and-components/component-guide/model-deployers) that deploys a huggingface model from the hub to a huggingface inference endpoint. 
This would involve creating a [custom model deployer](https://docs.zenml.io/stacks-and-components/component-guide/model-deployers/custom) and editing the [deployment pipeline accordingly](pipelines/deployment.py) @@ -121,9 +160,9 @@ This project recently did a [call of volunteers](https://www.linkedin.com/feed/u While the work here is solely based on the task of finetuning the model for the ZenML library, the pipeline can be changed with minimal effort to point to any set of repositories on GitHub. Theoretically, one could extend this work to point to proprietary codebases to learn from them for any use-case. -For example, see how [VMWare fine-tuned StarCoder to learn their style](https://octo.vmware.com/fine-tuning-starcoder-to-learn-vmwares-coding-style/). +For example, see how [VMWare fine-tuned StarCoder to learn their style](https://octo.vmware.com/fine-tuning-starcoder-to-learn-vmwares-coding-style/). Also, make sure to join our Slack - Slack Community - to become part of the ZenML family! \ No newline at end of file + Slack Community + to become part of the ZenML family! diff --git a/llm-finetuning/configs/deployment_a10.yaml b/llm-finetuning/configs/deployment_a10.yaml index 4622ce40..1ae571c3 100644 --- a/llm-finetuning/configs/deployment_a10.yaml +++ b/llm-finetuning/configs/deployment_a10.yaml @@ -10,21 +10,22 @@ model: steps: deploy_model_to_hf_hub: parameters: - framework: pytorch - task: text-generation - accelerator: gpu - vendor: aws - region: us-east-1 - max_replica: 1 - instance_size: xxlarge - instance_type: g5.12xlarge - namespace: zenml - custom_image: - health_route: /health - env: - MAX_BATCH_PREFILL_TOKENS: "2048" - MAX_INPUT_LENGTH: "1024" - MAX_TOTAL_TOKENS: "1512" - QUANTIZE: bitsandbytes - MODEL_ID: /repository - url: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-564f2a3 \ No newline at end of file + hf_endpoint_cfg: + framework: pytorch + task: text-generation + accelerator: gpu + vendor: aws + region: us-east-1 + max_replica: 1 + instance_size: xxlarge + instance_type: g5.12xlarge + namespace: zenml + custom_image: + health_route: /health + env: + MAX_BATCH_PREFILL_TOKENS: "2048" + MAX_INPUT_LENGTH: "1024" + MAX_TOTAL_TOKENS: "1512" + QUANTIZE: bitsandbytes + MODEL_ID: /repository + url: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-564f2a3 diff --git a/llm-finetuning/configs/deployment_a100.yaml b/llm-finetuning/configs/deployment_a100.yaml index a7b1eaee..fedf6c58 100644 --- a/llm-finetuning/configs/deployment_a100.yaml +++ b/llm-finetuning/configs/deployment_a100.yaml @@ -10,21 +10,22 @@ model: steps: deploy_model_to_hf_hub: parameters: - framework: pytorch - task: text-generation - accelerator: gpu - vendor: aws - region: us-east-1 - max_replica: 1 - instance_size: xlarge - instance_type: p4de - namespace: zenml - custom_image: - health_route: /health - env: - MAX_BATCH_PREFILL_TOKENS: "2048" - MAX_INPUT_LENGTH: "1024" - MAX_TOTAL_TOKENS: "1512" - QUANTIZE: bitsandbytes - MODEL_ID: /repository - url: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-564f2a3 \ No newline at end of file + hf_endpoint_cfg: + framework: pytorch + task: text-generation + accelerator: gpu + vendor: aws + region: us-east-1 + max_replica: 1 + instance_size: xlarge + instance_type: p4de + namespace: zenml + custom_image: + health_route: /health + env: + MAX_BATCH_PREFILL_TOKENS: "2048" + MAX_INPUT_LENGTH: "1024" + MAX_TOTAL_TOKENS: "1512" + QUANTIZE: 
bitsandbytes + MODEL_ID: /repository + url: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-564f2a3 diff --git a/llm-finetuning/configs/deployment_t4.yaml b/llm-finetuning/configs/deployment_t4.yaml index 26d5733f..1b9cebe7 100644 --- a/llm-finetuning/configs/deployment_t4.yaml +++ b/llm-finetuning/configs/deployment_t4.yaml @@ -10,21 +10,22 @@ model: steps: deploy_model_to_hf_hub: parameters: - framework: pytorch - task: text-generation - accelerator: gpu - vendor: aws - region: us-east-1 - max_replica: 1 - instance_size: large - instance_type: g4dn.12xlarge - namespace: zenml - custom_image: - health_route: /health - env: - MAX_BATCH_PREFILL_TOKENS: "2048" - MAX_INPUT_LENGTH: "1024" - MAX_TOTAL_TOKENS: "1512" - QUANTIZE: bitsandbytes - MODEL_ID: /repository - url: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-564f2a3 \ No newline at end of file + hf_endpoint_cfg: + framework: pytorch + task: text-generation + accelerator: gpu + vendor: aws + region: us-east-1 + max_replica: 1 + instance_size: large + instance_type: g4dn.12xlarge + namespace: zenml + custom_image: + health_route: /health + env: + MAX_BATCH_PREFILL_TOKENS: "2048" + MAX_INPUT_LENGTH: "1024" + MAX_TOTAL_TOKENS: "1512" + QUANTIZE: bitsandbytes + MODEL_ID: /repository + url: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-564f2a3 diff --git a/llm-finetuning/huggingface/__init__.py b/llm-finetuning/huggingface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/llm-finetuning/huggingface/hf_deployment_base_config.py b/llm-finetuning/huggingface/hf_deployment_base_config.py new file mode 100644 index 00000000..2ecaaeed --- /dev/null +++ b/llm-finetuning/huggingface/hf_deployment_base_config.py @@ -0,0 +1,25 @@ +from pydantic import BaseModel +from typing import Optional, Dict +from zenml.utils.secret_utils import SecretField + + +class HuggingFaceBaseConfig(BaseModel): + """Huggingface Inference Endpoint configuration.""" + + endpoint_name: Optional[str] = "" + repository: Optional[str] = None + framework: Optional[str] = None + accelerator: Optional[str] = None + instance_size: Optional[str] = None + instance_type: Optional[str] = None + region: Optional[str] = None + vendor: Optional[str] = None + token: Optional[str] = None + account_id: Optional[str] = None + min_replica: Optional[int] = 0 + max_replica: Optional[int] = 1 + revision: Optional[str] = None + task: Optional[str] = None + custom_image: Optional[Dict] = None + namespace: Optional[str] = None + endpoint_type: str = "public" diff --git a/llm-finetuning/huggingface/hf_deployment_service.py b/llm-finetuning/huggingface/hf_deployment_service.py new file mode 100644 index 00000000..54b297ce --- /dev/null +++ b/llm-finetuning/huggingface/hf_deployment_service.py @@ -0,0 +1,200 @@ +"""Implementation of the Huggingface Deployment service.""" +from zenml.logger import get_logger +from typing import Generator, Tuple, Optional, Any, List +from zenml.services import ServiceType, ServiceState, ServiceStatus +from zenml.services.service import BaseDeploymentService, ServiceConfig +from huggingface_hub import ( + InferenceClient, + InferenceEndpointError, + InferenceEndpoint, + InferenceEndpointStatus, +) +from huggingface_hub.utils import HfHubHTTPError +from huggingface_hub import ( + create_inference_endpoint, + get_inference_endpoint, +) +from huggingface.hf_deployment_base_config import HuggingFaceBaseConfig + +from pydantic import 
Field + +logger = get_logger(__name__) + +POLLING_TIMEOUT = 1200 + + +class HuggingFaceServiceConfig(HuggingFaceBaseConfig, ServiceConfig): + """Base class for Huggingface configurations.""" + + +class HuggingFaceServiceStatus(ServiceStatus): + """HF Endpoint Inference service status.""" + + +class HuggingFaceDeploymentService(BaseDeploymentService): + """HuggingFace model deployment service.""" + + SERVICE_TYPE = ServiceType( + name="hf-endpoint-deployment", + type="model-serving", + flavor="hfendpoint", + description="Huggingface inference endpoint service", + ) + config: HuggingFaceServiceConfig + status: HuggingFaceServiceStatus = Field( + default_factory=lambda: HuggingFaceServiceStatus() + ) + + def __init__(self, config: HuggingFaceServiceConfig, **attrs: Any): + """_summary_.""" + super().__init__(config=config, **attrs) + + @property + def hf_endpoint(self) -> InferenceEndpoint: + """Get the deployed Huggingface inference endpoint. + + Returns: + Huggingface inference endpoint. + """ + return get_inference_endpoint( + name=self.config.endpoint_name, + token=self.config.token, + namespace=self.config.namespace, + ) + + @property + def prediction_url(self) -> Optional[str]: + """The prediction URI exposed by the prediction service. + + Returns: + The prediction URI exposed by the prediction service, or None if + the service is not yet ready. + """ + if not self.is_running: + return None + return self.hf_endpoint.url + + @property + def inference_client(self) -> InferenceClient: + """Get the Huggingface InferenceClient from Inference Endpoint. + + Returns: + Huggingface inference client. + """ + return self.hf_endpoint.client + + def provision(self) -> None: + """Provision or update remote Huggingface deployment instance. + + This should then match the current configuration. + """ + + _ = create_inference_endpoint( + name=self.config.endpoint_name, + repository=self.config.repository, + framework=self.config.framework, + accelerator=self.config.accelerator, + instance_size=self.config.instance_size, + instance_type=self.config.instance_type, + region=self.config.region, + vendor=self.config.vendor, + account_id=self.config.account_id, + min_replica=self.config.min_replica, + max_replica=self.config.max_replica, + revision=self.config.revision, + task=self.config.task, + custom_image=self.config.custom_image, + type=self.config.endpoint_type, + namespace=self.config.namespace, + token=self.config.token, + ).wait(timeout=POLLING_TIMEOUT) + + if self.hf_endpoint.url is not None: + logger.info( + "Huggingface inference endpoint successfully deployed." + ) + else: + logger.info( + "Failed to start huggingface inference endpoint service..." + ) + + def check_status(self) -> Tuple[ServiceState, str]: + """Check the the current operational state of the HuggingFace deployment. + + Returns: + The operational state of the HuggingFace deployment and a message + providing additional information about that state (e.g. a + description of the error, if one is encountered). 
+ """ + try: + _ = self.hf_endpoint.status + except (InferenceEndpointError, HfHubHTTPError): + return (ServiceState.INACTIVE, "") + + if self.hf_endpoint.status == InferenceEndpointStatus.RUNNING: + return ( + ServiceState.ACTIVE, + f"HuggingFace Inference Endpoint deployment is available", + ) + + if self.hf_endpoint.status == InferenceEndpointStatus.FAILED: + return ( + ServiceState.ERROR, + f"HuggingFace Inference Endpoint deployment failed: ", + ) + + if self.hf_endpoint.status == InferenceEndpointStatus.PENDING: + return ( + ServiceState.PENDING_STARTUP, + "HuggingFace Inference Endpoint deployment is being created: ", + ) + + def deprovision(self, force: bool = False) -> None: + """Deprovision the remote HuggingFace deployment instance. + + Args: + force: if True, the remote deployment instance will be + forcefully deprovisioned. + """ + try: + self.hf_endpoint.delete() + except HfHubHTTPError: + logger.error( + "Huggingface Inference Endpoint is deleted or cannot be found." + ) + pass + + def predict(self, data: "Any", max_new_tokens: int) -> "Any": + """Make a prediction using the service. + + Args: + data: input data + max_new_tokens: Number of new tokens to generate + + Returns: + The prediction result. + + Raises: + Exception: if the service is not running + ValueError: if the prediction endpoint is unknown. + """ + if not self.is_running: + raise Exception( + "Huggingface endpoint inference service is not running. " + "Please start the service before making predictions." + ) + if self.hf_endpoint.prediction_url is not None: + if self.hf_endpoint.task == "text-generation": + result = self.inference_client.task_generation( + data, max_new_tokens=max_new_tokens + ) + else: + raise NotImplementedError( + "Tasks other than text-generation is not implemented." 
+ ) + return result + + def get_logs( + self, follow: bool = False, tail: int = None + ) -> Generator[str, bool, None]: + return super().get_logs(follow, tail) diff --git a/llm-finetuning/huggingface/hf_model_deployer.py b/llm-finetuning/huggingface/hf_model_deployer.py new file mode 100644 index 00000000..544f03f3 --- /dev/null +++ b/llm-finetuning/huggingface/hf_model_deployer.py @@ -0,0 +1,452 @@ +"""Implementation of the Huggingface Model Deployer.""" +from uuid import UUID +from zenml.model_deployers import BaseModelDeployer +from huggingface.hf_model_deployer_flavor import HuggingFaceModelDeployerFlavor +from zenml.logger import get_logger + +from typing import List, Optional, cast, ClassVar, Type, Dict +from zenml.services import BaseService, ServiceConfig +from huggingface.hf_deployment_service import ( + HuggingFaceDeploymentService, + HuggingFaceServiceConfig, +) +from huggingface.hf_model_deployer_flavor import ( + HuggingFaceModelDeployerSettings, + HuggingFaceModelDeployerConfig, +) +from zenml.model_deployers.base_model_deployer import ( + DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + BaseModelDeployerFlavor, +) +from huggingface_hub import InferenceEndpoint, list_inference_endpoints +from zenml.client import Client +from zenml.services import ServiceRegistry +from zenml.artifacts.utils import save_artifact, log_artifact_metadata + +logger = get_logger(__name__) + +HUGGINGFACE_SERVICE_ARTIFACT = "hf_deployment_service" + + +class HuggingFaceModelDeployer(BaseModelDeployer): + """Huggingface endpoint model deployer.""" + + NAME: ClassVar[str] = "HFEndpoint" + FLAVOR: ClassVar[ + Type[BaseModelDeployerFlavor] + ] = HuggingFaceModelDeployerFlavor + + @property + def config(self) -> HuggingFaceModelDeployerConfig: + """Config class for the Huggingface Model deployer settings class. + + Returns: + The configuration. + """ + return cast(HuggingFaceModelDeployerConfig, self._config) + + @property + def settings_class(self) -> Type[HuggingFaceModelDeployerSettings]: + """Settings class for the Huggingface Model deployer settings class. + + Returns: + The settings class. + """ + return HuggingFaceModelDeployerSettings + + @property + def deployed_endpoints(self) -> List[InferenceEndpoint]: + """Get list of deployed endpoint from Huggingface. + + Returns: + List of deployed endpoints. + """ + return list_inference_endpoints( + token=self.config.token, + namespace=self.config.namespace, + ) + + def modify_endpoint_name( + self, endpoint_name: str, artifact_version: str + ) -> str: + """Modify endpoint name by adding suffix and prefix. + + It adds a prefix "zenml-" if not present and a suffix + of first 8 characters of uuid. + + Args: + endpoint_name : Name of the endpoint + artifact_version: Name of the artifact version + + Returns: + Modified endpoint name with added prefix and suffix + """ + + # Add zenml prefix if endpoint name is not set + # or it does not start with "zenml-" + if not endpoint_name and not endpoint_name.startswith("zenml-"): + endpoint_name = "zenml-" + endpoint_name + else: + endpoint_name = "zenml-" + + # Add first 8 characters of UUID to endpoint name + endpoint_name += artifact_version + return endpoint_name + + def _create_new_service( + self, timeout: int, config: HuggingFaceServiceConfig + ) -> HuggingFaceDeploymentService: + """Creates a new HuggingFaceDeploymentService. + + Args: + timeout: the timeout in seconds to wait for the Huggingface inference endpoint + to be provisioned and successfully started or updated. 
+ config: the configuration of the model to be deployed with Huggingface model deployer. + + Returns: + The HuggingFaceServiceConfig object that can be used to interact + with the Huggingface inference endpoint. + """ + # create a new service for the new model + service = HuggingFaceDeploymentService(config) + + # Use first 8 characters of UUID as artifact version + # Add same 8 characters as suffix to endpoint name + service_metadata = service.dict() + artifact_version = str(service_metadata["uuid"])[:8] + + service.config.endpoint_name = self.modify_endpoint_name( + service.config.endpoint_name, artifact_version + ) + + logger.info( + f"Creating an artifact {HUGGINGFACE_SERVICE_ARTIFACT} with service instance attached as metadata." + " If there's an active pipeline and/or model this artifact will be associated with it." + ) + + service_metadata = service.dict() + + save_artifact( + service, + HUGGINGFACE_SERVICE_ARTIFACT, + version=artifact_version, + is_deployment_artifact=True, + ) + # UUID object is not json serializable + service_metadata["uuid"] = str(service_metadata["uuid"]) + log_artifact_metadata( + artifact_name=HUGGINGFACE_SERVICE_ARTIFACT, + artifact_version=artifact_version, + metadata={HUGGINGFACE_SERVICE_ARTIFACT: service_metadata}, + ) + + service.start(timeout=timeout) + return service + + def _clean_up_existing_service( + self, + timeout: int, + force: bool, + existing_service: HuggingFaceDeploymentService, + ) -> None: + """Stop existing services. + + Args: + timeout: the timeout in seconds to wait for the Huggingface + deployment to be stopped. + force: if True, force the service to stop + existing_service: Existing Huggingface deployment service + """ + # stop the older service + existing_service.stop(timeout=timeout, force=force) + + def deploy_model( + self, + config: ServiceConfig, + replace: bool = True, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + ) -> BaseService: + """Create a new Huggingface deployment service or update an existing one. + + This should serve the supplied model and deployment configuration. + + Args: + config: the configuration of the model to be deployed with Huggingface. + Core + replace: set this flag to True to find and update an equivalent + Huggingface deployment server with the new model instead of + starting a new deployment server. + timeout: the timeout in seconds to wait for the Huggingface endpoint + to be provisioned and successfully started or updated. If set + to 0, the method will return immediately after the Huggingface + server is provisioned, without waiting for it to fully start. + + Raises: + RuntimeError: _description_ + + Returns: + The ZenML Huggingface deployment service object that can be used to + interact with the remote Huggingface inference endpoint server. 
+ """ + config = cast(HuggingFaceServiceConfig, config) + service = None + + # if replace is True, remove all existing services + if replace is True: + existing_services = self.find_model_server( + pipeline_name=config.pipeline_name, + pipeline_step_name=config.pipeline_step_name, + ) + + for existing_service in existing_services: + if service is None: + # keep the most recently created service + service = cast( + HuggingFaceDeploymentService, existing_service + ) + try: + # delete the older services and don't wait for them to + # be deprovisioned + self._clean_up_existing_service( + existing_service=cast( + HuggingFaceDeploymentService, existing_service + ), + timeout=timeout, + force=True, + ) + except RuntimeError: + # ignore errors encountered while stopping old services + pass + + if service: + # update an equivalent service in place + logger.info( + f"Updating an existing Huggingface deployment service: {service}" + ) + + # Default endpoint name is set to "" + # Using same name as endpoint name results in Bad name + service_metadata = service.dict() + artifact_version = str(service_metadata["uuid"])[:8] + config.endpoint_name = self.modify_endpoint_name( + config.endpoint_name, artifact_version + ) + + service.stop(timeout=timeout, force=True) + service.update(config) + service.start(timeout=timeout) + else: + # create a new HuggingFaceDeploymentService instance + service = self._create_new_service(timeout, config) + logger.info( + f"Creating a new huggingface inference endpoint service: {service}" + ) + + return cast(BaseService, service) + + def find_model_server( + self, + running: bool = False, + service_uuid: Optional[UUID] = None, + pipeline_name: Optional[str] = None, + run_name: Optional[str] = None, + pipeline_step_name: Optional[str] = None, + model_name: Optional[str] = None, + model_uri: Optional[str] = None, + model_type: Optional[str] = None, + ) -> List[BaseService]: + """Find one or more Huggingface model services that match the given criteria. + + + Args: + running: if true, only running services will be returned. + service_uuid: the UUID of the Huggingface service that was + originally used to create the Huggingface deployment resource. + pipeline_name: name of the pipeline that the deployed model was part + of. + run_name: Name of the pipeline run which the deployed model was + part of. + pipeline_step_name: the name of the pipeline model deployment step + that deployed the model. + model_name: the name of the deployed model. + model_uri: URI of the deployed model. + model_type: the Huggingface server implementation used to serve + the model + + Raises: + TypeError: _description_ + + Returns: + One or more Huggingface service objects representing Huggingface + model servers that match the input search criteria. 
+ """ + # Use a Huggingface deployment service configuration to compute the labels + config = HuggingFaceServiceConfig( + pipeline_name=pipeline_name or "", + run_name=run_name or "", + pipeline_run_id=run_name or "", + pipeline_step_name=pipeline_step_name or "", + model_name=model_name or "", + model_uri=model_uri or "", + implementation=model_type or "", + ) + + services: List[BaseService] = [] + + # Find all services that match input criteria + for endpoint in self.deployed_endpoints: + if endpoint.name.startswith("zenml-"): + artifact_version = endpoint.name[-8:] + # If service_uuid is supplied, fetch service for that uuid + if ( + service_uuid is not None + and str(service_uuid)[:8] != artifact_version + ): + continue + + # Fetch the saved metadata artifact from zenml server to recreate service + client = Client() + try: + service_artifact = client.get_artifact_version( + HUGGINGFACE_SERVICE_ARTIFACT, artifact_version + ) + hf_deployment_service_dict = service_artifact.run_metadata[ + HUGGINGFACE_SERVICE_ARTIFACT + ].value + + existing_service = ( + ServiceRegistry().load_service_from_dict( + hf_deployment_service_dict + ) + ) + + if not isinstance( + existing_service, HuggingFaceDeploymentService + ): + raise TypeError( + f"Expected service type HuggingFaceDeploymentService but got " + f"{type(existing_service)} instead" + ) + + existing_service.update_status() + if self._matches_search_criteria(existing_service, config): + if not running or existing_service.is_running: + services.append( + cast(BaseService, existing_service) + ) + + except KeyError: + pass + + return services + + def _matches_search_criteria( + self, + existing_service: HuggingFaceDeploymentService, + config: HuggingFaceModelDeployerConfig, + ) -> bool: + """Returns true if a service matches the input criteria. + + If any of the values in the input criteria are None, they are ignored. + This allows listing services just by common pipeline names or step + names, etc. + + Args: + existing_service: The materialized Service instance derived from + the config of the older (existing) service + config: The BentoMlDeploymentConfig object passed to the + deploy_model function holding parameters of the new service + to be created. + + Returns: + True if the service matches the input criteria. + """ + existing_service_config = existing_service.config + + # check if the existing service matches the input criteria + if ( + ( + not config.pipeline_name + or existing_service_config.pipeline_name + == config.pipeline_name + ) + and ( + not config.pipeline_step_name + or existing_service_config.pipeline_step_name + == config.pipeline_step_name + ) + and ( + not config.run_name + or existing_service_config.run_name == config.run_name + ) + ): + return True + + return False + + def stop_model_server( + self, + uuid: UUID, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + force: bool = False, + ) -> None: + """Method to stop a model server. + + Args: + uuid: UUID of the model server to stop. + timeout: Timeout in seconds to wait for the service to stop. + force: If True, force the service to stop. + """ + # get list of all services + existing_services = self.find_model_server(service_uuid=uuid) + + # if the service exists, stop it + if existing_services: + existing_services[0].stop(timeout=timeout, force=force) + + def start_model_server( + self, uuid: UUID, timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT + ) -> None: + """Method to start a model server. + + Args: + uuid: UUID of the model server to start. 
+ timeout: Timeout in seconds to wait for the service to start. + """ + # get list of all services + existing_services = self.find_model_server(service_uuid=uuid) + + # if the service exists, start it + if existing_services: + existing_services[0].start(timeout=timeout) + + def delete_model_server( + self, + uuid: UUID, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + force: bool = False, + ) -> None: + """Method to delete all configuration of a model server. + + Args: + uuid: UUID of the model server to delete. + timeout: Timeout in seconds to wait for the service to stop. + force: If True, force the service to stop. + """ + # get list of all services + existing_services = self.find_model_server(service_uuid=uuid) + + # if the service exists, clean it up + if existing_services: + service = cast(HuggingFaceDeploymentService, existing_services[0]) + self._clean_up_existing_service( + existing_service=service, timeout=timeout, force=force + ) + + def get_model_server_info( + self, + service_instance: "HuggingFaceDeploymentService", + ) -> Dict[str, Optional[str]]: + return { + "PREDICTION_URL": service_instance.prediction_url, + } diff --git a/llm-finetuning/huggingface/hf_model_deployer_flavor.py b/llm-finetuning/huggingface/hf_model_deployer_flavor.py new file mode 100644 index 00000000..6a85845a --- /dev/null +++ b/llm-finetuning/huggingface/hf_model_deployer_flavor.py @@ -0,0 +1,91 @@ +"""Huggingface model deployer flavor.""" +from typing import Optional, Type, TYPE_CHECKING +from zenml.model_deployers.base_model_deployer import ( + BaseModelDeployerFlavor, + BaseModelDeployerConfig, +) +from zenml.config.base_settings import BaseSettings +from huggingface.hf_deployment_base_config import HuggingFaceBaseConfig +from zenml.utils.secret_utils import SecretField + +if TYPE_CHECKING: + from huggingface.hf_model_deployer import HuggingFaceModelDeployer + + +HUGGINGFACE_MODEL_DEPLOYER_FLAVOR = "hfendpoint" + + +class HuggingFaceModelDeployerSettings(HuggingFaceBaseConfig, BaseSettings): + """Settings for the Huggingface model deployer.""" + + +class HuggingFaceModelDeployerConfig( + BaseModelDeployerConfig, HuggingFaceModelDeployerSettings +): + """Configuration for the Huggingface model deployer.""" + + token: str = SecretField() + + # The namespace to list endpoints for. Set to `"*"` to list all endpoints + # from all namespaces (i.e. personal namespace and all orgs the user belongs to). + namespace: str + + +class HuggingFaceModelDeployerFlavor(BaseModelDeployerFlavor): + """Huggingface Endpoint model deployer flavor.""" + + @property + def name(self) -> str: + """Name of the flavor. + + Returns: + The name of the flavor. + """ + return HUGGINGFACE_MODEL_DEPLOYER_FLAVOR + + @property + def docs_url(self) -> Optional[str]: + """A url to point at docs explaining this flavor. + + Returns: + A flavor docs url. + """ + return self.generate_default_docs_url() + + @property + def sdk_docs_url(self) -> Optional[str]: + """A url to point at SDK docs explaining this flavor. + + Returns: + A flavor SDK docs url. + """ + return self.generate_default_sdk_docs_url() + + @property + def logo_url(self) -> str: + """A url to represent the flavor in the dashboard. + + Returns: + The flavor logo. + """ + return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_registry/huggingface.png" + + @property + def config_class(self) -> Type[HuggingFaceModelDeployerConfig]: + """Returns `HuggingFaceModelDeployerConfig` config class. + + Returns: + The config class. 
+ """ + return HuggingFaceModelDeployerConfig + + @property + def implementation_class(self) -> Type["HuggingFaceModelDeployer"]: + """Implementation class for this flavor. + + Returns: + The implementation class. + """ + from huggingface.hf_model_deployer import HuggingFaceModelDeployer + + return HuggingFaceModelDeployer diff --git a/llm-finetuning/run.py b/llm-finetuning/run.py index 8e6f189d..0ffb9f46 100644 --- a/llm-finetuning/run.py +++ b/llm-finetuning/run.py @@ -113,24 +113,25 @@ def main( ) pipeline_args = {"enable_cache": not no_cache} if config: - pipeline_args["config_path"] = os.path.join( - config_folder, config - ) - + pipeline_args["config_path"] = os.path.join(config_folder, config) + # Execute Feature Engineering Pipeline if feature_pipeline: pipeline_args = {} from pipelines import generate_code_dataset + generate_code_dataset.with_options(**pipeline_args)() logger.info("Feature Engineering pipeline finished successfully!\n") - + elif training_pipeline: from pipelines import finetune_starcoder + finetune_starcoder.with_options(**pipeline_args)() logger.info("Training pipeline finished successfully!\n") - + elif deploy_pipeline: from pipelines import huggingface_deployment + huggingface_deployment.with_options(**pipeline_args)() logger.info("Deployment pipeline finished successfully!\n") diff --git a/llm-finetuning/steps/deployment.py b/llm-finetuning/steps/deployment.py index b101f67c..5d6572bc 100644 --- a/llm-finetuning/steps/deployment.py +++ b/llm-finetuning/steps/deployment.py @@ -1,30 +1,16 @@ from zenml import step -from zenml.client import Client -from huggingface_hub import create_inference_endpoint, get_inference_endpoint -from zenml import ArtifactConfig -from typing_extensions import Annotated from zenml import get_step_context -from typing import Optional, Dict -import random -from zenml import log_artifact_metadata -from zenml.metadata.metadata_types import Uri +from zenml.client import Client +from typing import Optional, cast, Dict from zenml.logger import get_logger -import time +from huggingface.hf_deployment_service import ( + HuggingFaceDeploymentService, + HuggingFaceServiceConfig, +) +from huggingface.hf_model_deployer import HuggingFaceModelDeployer logger = get_logger(__name__) -POLLING_TIMEOUT = 1200 - - -def generate_random_letters(number_of_letters: int = 10) -> str: - """Generates three random letters. - - Returns: - Three random letters. - """ - letters = "abcdefghijklmnopqrstuvwxyz" - return "".join(random.choice(letters) for i in range(number_of_letters)) - def parse_huggingface_url(url): # Split the URL into parts @@ -43,47 +29,23 @@ def parse_huggingface_url(url): @step(enable_cache=False) -def deploy_model_to_hf_hub( - framework: str, - accelerator: str, - instance_size: str, - instance_type: str, - region: str, - vendor: str, - endpoint_name: Optional[str] = None, - account_id: Optional[str] = None, - min_replica: int = 0, - max_replica: int = 1, - task: Optional[str] = None, - namespace: Optional[str] = None, - custom_image: Optional[Dict] = None, - endpoint_type: str = "public", -) -> Annotated[ - str, - ArtifactConfig(name="huggingface_service", is_deployment_artifact=True), -]: +def deploy_model_to_hf_hub(hf_endpoint_cfg: Optional[Dict] = None) -> None: """Pushes the dataset to the Hugging Face Hub. Args: - framework: The framework of the model. - accelerator: The accelerator of the model. - instance_size: The instance size of the model. - instance_type: The instance type of the model. - region: The region of the model. 
- vendor: The vendor of the model. - endpoint_name: The name of the model. - account_id: The account id of the model. - min_replica: The minimum replica of the model. - max_replica: The maximum replica of the model. - task: The task of the model. - custom_image: The custom image of the model. - endpoint_type: The type of the model. + hf_endpoint_cfg: The configuration for the Huggingface endpoint. + """ + endpoint_name = None + hf_endpoint_cfg = HuggingFaceServiceConfig(**hf_endpoint_cfg) + secret = Client().get_secret("huggingface_creds") hf_token = secret.secret_values["token"] + commit_info = get_step_context().model.run_metadata[ "merged_model_commit_info" ].value + model_namespace, repository, revision = parse_huggingface_url(commit_info) if repository is None: @@ -92,74 +54,30 @@ def deploy_model_to_hf_hub( "Please make sure that the training pipeline is configured correctly." ) - if endpoint_name is None: - endpoint_name = generate_random_letters() - - endpoint = create_inference_endpoint( - name=endpoint_name, - repository=f"{model_namespace}/{repository}", - framework=framework, - accelerator=accelerator, - instance_size=instance_size, - instance_type=instance_type, - region=region, - vendor=vendor, - account_id=account_id, - min_replica=min_replica, - max_replica=max_replica, - revision=revision, - task=task, - custom_image=custom_image, - type=endpoint_type, - namespace=namespace, - token=hf_token, + if ( + hf_endpoint_cfg.repository is None + or hf_endpoint_cfg.revision is None + or hf_endpoint_cfg.token is None + ): + logger.warning( + "The Huggingface endpoint configuration has already been set via an old pipeline run. " + "The endpoint name, repository, and revision will be overwritten." + ) + hf_endpoint_cfg.repository = f"{model_namespace}/{repository}" + hf_endpoint_cfg.revision = revision + hf_endpoint_cfg.token = hf_token + + # TODO: Can check if the model deployer is of the right type + model_deployer = cast( + HuggingFaceModelDeployer, + HuggingFaceModelDeployer.get_active_model_deployer(), ) - - model_url = f"https://huggingface.co/{model_namespace}/{repository}" - if revision: - model_url = f"{model_url}/tree/{revision}" - - log_artifact_metadata( - metadata={ - "service_type": "huggingface", - "status": "active", - "description": "Huggingface Inference Endpoint", - "endpoint_name": Uri(endpoint.name), - "huggingface_model": Uri(model_url), - "framework": framework, - "accelerator": accelerator, - "instance_size": instance_size, - "instance_type": instance_type, - "region": region, - "min_replica": min_replica, - "max_replica": max_replica, - "revision": revision, - "task": task, - "type": endpoint_type, - } + service = cast( + HuggingFaceDeploymentService, + model_deployer.deploy_model(config=hf_endpoint_cfg), ) - # Wait for initialization - try: - # Add timelimit - start_time = time.time() - endpoint_url = None - while endpoint_url is None: - logger.info( - f"Waiting for {endpoint.name} to deploy. This might take a few minutes.." - ) - endpoint_url = get_inference_endpoint( - name=endpoint.name, token=hf_token, namespace=namespace - ).url - time.sleep(5) - if time.time() - start_time > POLLING_TIMEOUT: - break - log_artifact_metadata( - metadata={ - "endpoint_url": Uri(endpoint_url), - } - ) - except KeyboardInterrupt: - logger.info("Detected keyboard interrupt. Stopping polling.") - - return str(endpoint) + logger.info( + f"Huggingface Inference Endpoint deployment service started and reachable at:\n" + f" {service.prediction_url}\n" + )
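
To round off the deployment step above, here is a minimal sketch of how a client script could talk to the endpoint once `deploy_model_to_hf_hub` has run. It only uses classes and methods introduced in this diff (`HuggingFaceModelDeployer.find_model_server` and `HuggingFaceDeploymentService.predict`); the file name `query_endpoint.py`, the `generate()` helper, and the pipeline name `"huggingface_deployment"` are illustrative assumptions, not code that ships with this change.

```python
# query_endpoint.py -- hypothetical usage sketch, not part of this diff.
from typing import cast

from huggingface.hf_deployment_service import HuggingFaceDeploymentService
from huggingface.hf_model_deployer import HuggingFaceModelDeployer


def generate(prompt: str, max_new_tokens: int = 128) -> str:
    """Send a prompt to the most recently deployed Hugging Face Inference Endpoint."""
    model_deployer = cast(
        HuggingFaceModelDeployer,
        HuggingFaceModelDeployer.get_active_model_deployer(),
    )
    # find_model_server() rebuilds the service from the saved
    # `hf_deployment_service` artifact and filters for running endpoints.
    services = model_deployer.find_model_server(
        pipeline_name="huggingface_deployment",  # assumed pipeline name
        pipeline_step_name="deploy_model_to_hf_hub",
        running=True,
    )
    if not services:
        raise RuntimeError("No running Hugging Face Inference Endpoint was found.")

    service = cast(HuggingFaceDeploymentService, services[0])
    # predict() delegates to the endpoint's InferenceClient for
    # text-generation tasks and raises for any other task type.
    return service.predict(prompt, max_new_tokens=max_new_tokens)


if __name__ == "__main__":
    print(generate("def load_the_latest_zenml_stack():"))
```

This assumes the `hfendpoint` model deployer registered in the README is part of the active stack, so that `get_active_model_deployer()` resolves to the custom deployer.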