diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/.gitignore b/airavata-api/airavata-client-sdks/airavata-python-sdk/.gitignore index 37e0880890..794ce1ca9c 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/.gitignore +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/.gitignore @@ -4,3 +4,12 @@ airavata_python_sdk.egg-info .tox dist build +__pycache__/ +.DS_Store +.ipynb_checkpoints +*.egg-info/ +data/ +results*/ +plan.json +settings*.ini +auth.state \ No newline at end of file diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/__init__.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/__init__.py new file mode 100644 index 0000000000..f45c820c4f --- /dev/null +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/__init__.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +from . import base, plan +from .auth import login, logout +from .runtime import list_runtimes, Runtime + +__all__ = ["login", "logout", "list_runtimes", "base", "plan"] + +def display_runtimes(runtimes: list[Runtime]): + """ + Display runtimes in a tabular format + """ + import pandas as pd + + records = [] + for runtime in runtimes: + record = dict(id=runtime.id, **runtime.args) + records.append(record) + + return pd.DataFrame(records) + +def display_experiments(experiments: list[base.Experiment]): + """ + Display experiments in a tabular format + """ + import pandas as pd + + records = [] + for experiment in experiments: + record = dict(name=experiment.name, application=experiment.application.app_id, num_tasks=len(experiment.tasks)) + for k, v in experiment.inputs.items(): + record[k] = ", ".join(v) if isinstance(v, list) else str(v) + records.append(record) + + return pd.DataFrame(records) + +def display_plans(plans: list[plan.Plan]): + """ + Display plans in a tabular format + """ + import pandas as pd + + records = [] + for plan in plans: + for task in plan.tasks: + record = dict(plan_id=str(plan.id)) + for k, v in task.model_dump().items(): + record[k] = ", ".join(v) if isinstance(v, list) else str(v) + records.append(record) + + return pd.DataFrame(records) + +def display(arg): + + if isinstance(arg, list): + if all(isinstance(x, Runtime) for x in arg): + return display_runtimes(arg) + if all(isinstance(x, base.Experiment) for x in arg): + return display_experiments(arg) + if all(isinstance(x, plan.Plan) for x in arg): + return display_plans(arg) + else: + if isinstance(arg, Runtime): + return display_runtimes([arg]) + if isinstance(arg, base.Experiment): + return display_experiments([arg]) + if isinstance(arg, plan.Plan): + return display_plans([arg]) + + raise NotImplementedError(f"Cannot 
display object of type {type(arg)}") \ No newline at end of file diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py new file mode 100644 index 0000000000..ebab013cba --- /dev/null +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py @@ -0,0 +1,780 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +from pathlib import Path +from typing import Literal, NamedTuple +from .sftp import SFTPConnector +import time +import warnings +import requests +from urllib.parse import urlparse +import uuid +import os +import base64 + +import jwt +from airavata.model.security.ttypes import AuthzToken +from airavata.model.experiment.ttypes import ExperimentModel, ExperimentType, UserConfigurationDataModel +from airavata.model.scheduling.ttypes import ComputationalResourceSchedulingModel +from airavata.model.data.replica.ttypes import DataProductModel, DataProductType, DataReplicaLocationModel, ReplicaLocationCategory +from airavata_sdk.clients.api_server_client import APIServerClient + +warnings.filterwarnings("ignore", category=DeprecationWarning) +logger = logging.getLogger("airavata_sdk.clients") +logger.setLevel(logging.INFO) + +LaunchState = NamedTuple("LaunchState", [ + ("experiment_id", str), + ("agent_ref", str), + ("process_id", str), + ("mount_point", Path), + ("experiment_dir", str), + ("sr_host", str), +]) + +class Settings: + + def __init__(self, config_path: str) -> None: + + import configparser + config = configparser.ConfigParser() + config.read(config_path) + + # api server client settings + self.API_SERVER_HOST = config.get('APIServer', 'API_HOST') + self.API_SERVER_PORT = config.getint('APIServer', 'API_PORT') + self.API_SERVER_SECURE = config.getboolean('APIServer', 'API_SECURE') + self.CONNECTION_SVC_URL = config.get('APIServer', 'CONNECTION_SVC_URL') + self.FILEMGR_SVC_URL = config.get('APIServer', 'FILEMGR_SVC_URL') + + # gateway settings + self.GATEWAY_ID = config.get('Gateway', 'GATEWAY_ID') + self.GATEWAY_URL = config.get('Gateway', 'GATEWAY_URL') + self.GATEWAY_DATA_STORE_DIR = config.get('Gateway', 'GATEWAY_DATA_STORE_DIR') + self.STORAGE_RESOURCE_HOST = config.get('Gateway', 'STORAGE_RESOURCE_HOST') + self.SFTP_PORT = config.get('Gateway', 'SFTP_PORT') + + # runtime-specific settings + self.PROJECT_NAME = config.get('User', 'PROJECT_NAME') + self.GROUP_RESOURCE_PROFILE_NAME = config.get('User', 'GROUP_RESOURCE_PROFILE_NAME') + + +class AiravataOperator: + + def register_input_file( + self, + file_identifier: str, + storage_name: str, + storageId: str, + gateway_id: str, + input_file_name: str, + uploaded_storage_path: str, + ) -> str: + + dataProductModel = 
DataProductModel( + gatewayId=gateway_id, + ownerName=self.user_id, + productName=file_identifier, + dataProductType=DataProductType.FILE, + replicaLocations=[ + DataReplicaLocationModel( + replicaName="{} gateway data store copy".format(input_file_name), + replicaLocationCategory=ReplicaLocationCategory.GATEWAY_DATA_STORE, + storageResourceId=storageId, + filePath="file://{}:{}".format(storage_name, uploaded_storage_path + input_file_name), + )], + ) + + return self.api_server_client.register_data_product(self.airavata_token, dataProductModel) # type: ignore + + def create_experiment_model( + self, + project_name: str, + application_name: str, + experiment_name: str, + description: str, + gateway_id: str, + ) -> ExperimentModel: + + execution_id = self.get_app_interface_id(application_name) + project_id = self.get_project_id(project_name) + return ExperimentModel( + experimentName=experiment_name, + gatewayId=gateway_id, + userName=self.user_id, + description=description, + projectId=project_id, + experimentType=ExperimentType.SINGLE_APPLICATION, + executionId=execution_id + ) + + def get_resource_host_id(self, resource_name): + resources: dict = self.api_server_client.get_all_compute_resource_names(self.airavata_token) # type: ignore + return next((str(k) for k, v in resources.items() if v == resource_name)) + + def configure_computation_resource_scheduling( + self, + experiment_model: ExperimentModel, + computation_resource_name: str, + group_resource_profile_name: str, + storageId: str, + node_count: int, + total_cpu_count: int, + queue_name: str, + wall_time_limit: int, + experiment_dir_path: str, + auto_schedule=False, + ) -> ExperimentModel: + resource_host_id = self.get_resource_host_id(computation_resource_name) + groupResourceProfileId = self.get_group_resource_profile_id(group_resource_profile_name) + computRes = ComputationalResourceSchedulingModel() + computRes.resourceHostId = resource_host_id + computRes.nodeCount = node_count + computRes.totalCPUCount = total_cpu_count + computRes.queueName = queue_name + computRes.wallTimeLimit = wall_time_limit + + userConfigData = UserConfigurationDataModel() + userConfigData.computationalResourceScheduling = computRes + + userConfigData.groupResourceProfileId = groupResourceProfileId + userConfigData.storageId = storageId + + userConfigData.experimentDataDir = experiment_dir_path + userConfigData.airavataAutoSchedule = auto_schedule + experiment_model.userConfigurationData = userConfigData + + return experiment_model + + def __init__(self, access_token: str, config_file: str = "settings.ini"): + # store variables + self.access_token = access_token + self.settings = Settings(config_file) + # load api server settings and create client + self.api_server_client = APIServerClient(api_server_settings=self.settings) + # load gateway settings + gateway_id = self.default_gateway_id() + self.airavata_token = self.__airavata_token__(self.access_token, gateway_id) + + def default_gateway_id(self): + return self.settings.GATEWAY_ID + + def default_gateway_grp_name(self): + return self.settings.GROUP_RESOURCE_PROFILE_NAME + + def default_gateway_data_store_dir(self): + return self.settings.GATEWAY_DATA_STORE_DIR + + def default_sftp_port(self): + return self.settings.SFTP_PORT + + def default_sr_hostname(self): + return self.settings.STORAGE_RESOURCE_HOST + + def default_project_name(self): + return self.settings.PROJECT_NAME + + def connection_svc_url(self): + return self.settings.CONNECTION_SVC_URL + + def filemgr_svc_url(self): + return 
self.settings.FILEMGR_SVC_URL + + def __airavata_token__(self, access_token: str, gateway_id: str): + """ + Decode access token (string) and create AuthzToken (object) + + """ + decode = jwt.decode(access_token, options={"verify_signature": False}) + self.user_id = str(decode["preferred_username"]) + claimsMap = {"userName": self.user_id, "gatewayID": gateway_id} + return AuthzToken(accessToken=self.access_token, claimsMap=claimsMap) + + def get_experiment(self, experiment_id: str): + """ + Get experiment by id + + """ + return self.api_server_client.get_experiment(self.airavata_token, experiment_id) + + def get_process_id(self, experiment_id: str) -> str: + """ + Get process id by experiment id + + """ + tree: any = self.api_server_client.get_detailed_experiment_tree(self.airavata_token, experiment_id) # type: ignore + processModels: list = tree.processes + assert len(processModels) == 1, f"Expected 1 process model, got {len(processModels)}" + return processModels[0].processId + + def get_accessible_apps(self, gateway_id: str | None = None): + """ + Get all applications available in the gateway + + """ + # use defaults for missing values + gateway_id = gateway_id or self.default_gateway_id() + # logic + app_interfaces = self.api_server_client.get_all_application_interfaces(self.airavata_token, gateway_id) + return app_interfaces + + def get_preferred_storage(self, gateway_id: str | None = None, sr_hostname: str | None = None): + """ + Get preferred storage resource + + """ + # use defaults for missing values + gateway_id = gateway_id or self.default_gateway_id() + sr_hostname = sr_hostname or self.default_sr_hostname() + # logic + sr_names: dict[str, str] = self.api_server_client.get_all_storage_resource_names(self.airavata_token) # type: ignore + sr_id = next((str(k) for k, v in sr_names.items() if v == sr_hostname)) + return self.api_server_client.get_gateway_storage_preference(self.airavata_token, gateway_id, sr_id) + + def get_storage(self, storage_name: str | None = None) -> any: # type: ignore + """ + Get storage resource by name + + """ + # use defaults for missing values + storage_name = storage_name or self.default_sr_hostname() + # logic + sr_names: dict[str, str] = self.api_server_client.get_all_storage_resource_names(self.airavata_token) # type: ignore + sr_id = next((str(k) for k, v in sr_names.items() if v == storage_name)) + storage = self.api_server_client.get_storage_resource(self.airavata_token, sr_id) + return storage + + def get_group_resource_profile_id(self, grp_name: str | None = None) -> str: + """ + Get group resource profile id by name + + """ + # use defaults for missing values + grp_name = grp_name or self.default_gateway_grp_name() + # logic + grps: list = self.api_server_client.get_group_resource_list(self.airavata_token, self.default_gateway_id()) # type: ignore + grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == grp_name)) + return str(grp_id) + + def get_group_resource_profile(self, grp_id: str): + grp: any = self.api_server_client.get_group_resource_profile(self.airavata_token, grp_id) # type: ignore + return grp + + def get_compatible_deployments(self, app_interface_id: str, grp_name: str | None = None): + """ + Get compatible deployments for an application interface and group resource profile + + """ + # use defaults for missing values + grp_name = grp_name or self.default_gateway_grp_name() + # logic + grps: list = self.api_server_client.get_group_resource_list(self.airavata_token, self.default_gateway_id()) # 
type: ignore + grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == grp_name)) + deployments = self.api_server_client.get_application_deployments_for_app_module_and_group_resource_profile(self.airavata_token, app_interface_id, grp_id) + return deployments + + def get_app_interface_id(self, app_name: str, gateway_id: str | None = None): + """ + Get application interface id by name + + """ + gateway_id = str(gateway_id or self.default_gateway_id()) + apps: list = self.api_server_client.get_all_application_interfaces(self.airavata_token, gateway_id) # type: ignore + app_id = next((app.applicationInterfaceId for app in apps if app.applicationName == app_name)) + return str(app_id) + + def get_project_id(self, project_name: str, gateway_id: str | None = None): + gateway_id = str(gateway_id or self.default_gateway_id()) + projects: list = self.api_server_client.get_user_projects(self.airavata_token, gateway_id, self.user_id, 10, 0) # type: ignore + project_id = next((p.projectID for p in projects if p.name == project_name and p.owner == self.user_id)) + return str(project_id) + + def get_application_inputs(self, app_interface_id: str) -> list: + """ + Get application inputs by id + + """ + return list(self.api_server_client.get_application_inputs(self.airavata_token, app_interface_id)) # type: ignore + + def get_compute_resources_by_ids(self, resource_ids: list[str]): + """ + Get compute resources by ids + + """ + return [self.api_server_client.get_compute_resource(self.airavata_token, resource_id) for resource_id in resource_ids] + + def make_experiment_dir(self, sr_host: str, project_name: str, experiment_name: str) -> str: + """ + Make experiment directory on storage resource, and return the remote path + + Return Path: /{project_name}/{experiment_name} + + """ + host = sr_host + port = self.default_sftp_port() + sftp_connector = SFTPConnector(host=host, port=int(port), username=self.user_id, password=self.access_token) + remote_path = sftp_connector.mkdir(project_name, experiment_name) + logger.info("Experiment directory created at %s", remote_path) + return remote_path + + def upload_files(self, process_id: str | None, agent_ref: str | None, sr_host: str, local_files: list[Path], remote_dir: str) -> list[str]: + """ + Upload local files to a remote directory of a storage resource + TODO add data_svc fallback + + Return Path: /{project_name}/{experiment_name} + + """ + + # step = experiment staging + if process_id is None and agent_ref is None: + host = sr_host + port = self.default_sftp_port() + sftp_connector = SFTPConnector(host=host, port=int(port), username=self.user_id, password=self.access_token) + paths = sftp_connector.put(local_files, remote_dir) + logger.info(f"{len(paths)} Local files uploaded to remote dir: %s", remote_dir) + return paths + + # step = post-staging file upload + elif process_id is not None and agent_ref is not None: + assert len(local_files) == 1, f"Expected 1 file, got {len(local_files)}" + file = local_files[0] + fp = os.path.join("/data", file.name) + rawdata = file.read_bytes() + b64data = base64.b64encode(rawdata).decode() + res = requests.post(f"{self.connection_svc_url()}/agent/executecommandrequest", json={ + "agentId": agent_ref, + "workingDir": ".", + "arguments": ["sh", "-c", f"echo {b64data} | base64 -d > {fp}"] + }) + data = res.json() + if data["error"] is not None: + if str(data["error"]) == "Agent not found": + port = self.default_sftp_port() + sftp_connector = SFTPConnector(host=sr_host, port=int(port), 
username=self.user_id, password=self.access_token) + paths = sftp_connector.put(local_files, remote_dir) + return paths + else: + raise Exception(data["error"]) + else: + exc_id = data["executionId"] + while True: + res = requests.get(f"{self.connection_svc_url()}/agent/executecommandresponse/{exc_id}") + data = res.json() + if data["available"]: + return [fp] + time.sleep(1) + + # step = unknown + else: + raise ValueError("Invalid arguments for upload_files") + + # file manager service fallback + assert process_id is not None, f"Expected process_id, got {process_id}" + file = local_files[0] + url_path = os.path.join(process_id, file.name) + filemgr_svc_upload_url = f"{self.filemgr_svc_url()}/upload/live/{url_path}" + + def list_files(self, process_id: str, agent_ref: str, sr_host: str, remote_dir: str) -> list[str]: + """ + List files in a remote directory of a storage resource + TODO add data_svc fallback + + Return Path: /{project_name}/{experiment_name} + + """ + res = requests.post(f"{self.connection_svc_url()}/agent/executecommandrequest", json={ + "agentId": agent_ref, + "workingDir": ".", + "arguments": ["sh", "-c", r"find /data -type d -name 'venv' -prune -o -type f -printf '%P\n' | sort"] + }) + data = res.json() + if data["error"] is not None: + if str(data["error"]) == "Agent not found": + port = self.default_sftp_port() + sftp_connector = SFTPConnector(host=sr_host, port=int(port), username=self.user_id, password=self.access_token) + return sftp_connector.ls(remote_dir) + else: + raise Exception(data["error"]) + else: + exc_id = data["executionId"] + while True: + res = requests.get(f"{self.connection_svc_url()}/agent/executecommandresponse/{exc_id}") + data = res.json() + if data["available"]: + files = data["responseString"].split("\n") + return files + time.sleep(1) + + # file manager service fallback + assert process_id is not None, f"Expected process_id, got {process_id}" + filemgr_svc_ls_url = f"{self.filemgr_svc_url()}/list/live/{process_id}" + + def download_file(self, process_id: str, agent_ref: str, sr_host: str, remote_file: str, remote_dir: str, local_dir: str) -> str: + """ + Download files from a remote directory of a storage resource to a local directory + TODO add data_svc fallback + + Return Path: /{project_name}/{experiment_name} + + """ + import os + fp = os.path.join("/data", remote_file) + res = requests.post(f"{self.connection_svc_url()}/agent/executecommandrequest", json={ + "agentId": agent_ref, + "workingDir": ".", + "arguments": ["sh", "-c", f"cat {fp} | base64 -w0"] + }) + data = res.json() + if data["error"] is not None: + if str(data["error"]) == "Agent not found": + port = self.default_sftp_port() + fp = os.path.join(remote_dir, remote_file) + sftp_connector = SFTPConnector(host=sr_host, port=int(port), username=self.user_id, password=self.access_token) + path = sftp_connector.get(fp, local_dir) + return path + else: + raise Exception(data["error"]) + else: + exc_id = data["executionId"] + while True: + res = requests.get(f"{self.connection_svc_url()}/agent/executecommandresponse/{exc_id}") + data = res.json() + if data["available"]: + content = data["responseString"] + import base64 + content = base64.b64decode(content) + path = Path(local_dir) / remote_file + with open(path, "wb") as f: + f.write(content) + return path.as_posix() + time.sleep(1) + + # file manager service fallback + assert process_id is not None, f"Expected process_id, got {process_id}" + url_path = os.path.join(process_id, remote_file) + filemgr_svc_download_url = 
f"{self.filemgr_svc_url()}/download/live/{url_path}" + + def cat_file(self, process_id: str, agent_ref: str, sr_host: str, remote_file: str, remote_dir: str) -> bytes: + """ + Download files from a remote directory of a storage resource to a local directory + TODO add data_svc fallback + + Return Path: /{project_name}/{experiment_name} + + """ + import os + fp = os.path.join("/data", remote_file) + res = requests.post(f"{self.connection_svc_url()}/agent/executecommandrequest", json={ + "agentId": agent_ref, + "workingDir": ".", + "arguments": ["sh", "-c", f"cat {fp} | base64 -w0"] + }) + data = res.json() + if data["error"] is not None: + if str(data["error"]) == "Agent not found": + port = self.default_sftp_port() + fp = os.path.join(remote_dir, remote_file) + sftp_connector = SFTPConnector(host=sr_host, port=int(port), username=self.user_id, password=self.access_token) + data = sftp_connector.cat(fp) + return data + else: + raise Exception(data["error"]) + else: + exc_id = data["executionId"] + while True: + res = requests.get(f"{self.connection_svc_url()}/agent/executecommandresponse/{exc_id}") + data = res.json() + if data["available"]: + content = data["responseString"] + import base64 + content = base64.b64decode(content) + return content + time.sleep(1) + + # file manager service fallback + assert process_id is not None, f"Expected process_id, got {process_id}" + url_path = os.path.join(process_id, remote_file) + filemgr_svc_download_url = f"{self.filemgr_svc_url()}/download/live/{url_path}" + + def launch_experiment( + self, + experiment_name: str, + app_name: str, + inputs: dict[str, dict[str, str | int | float | list[str]]], + computation_resource_name: str, + queue_name: str, + node_count: int, + cpu_count: int, + walltime: int, + *, + gateway_id: str | None = None, + grp_name: str | None = None, + sr_host: str | None = None, + project_name: str | None = None, + auto_schedule: bool = False, + ) -> LaunchState: + """ + Launch an experiment and return its id + + """ + # preprocess args (str) + print("[AV] Preprocessing args...") + gateway_id = str(gateway_id or self.default_gateway_id()) + grp_name = str(grp_name or self.default_gateway_grp_name()) + sr_host = str(sr_host or self.default_sr_hostname()) + mount_point = Path(self.default_gateway_data_store_dir()) / self.user_id + project_name = str(project_name or self.default_project_name()) + server_url = urlparse(self.connection_svc_url()).netloc + + # validate args (str) + print("[AV] Validating args...") + assert len(experiment_name) > 0, f"Invalid experiment_name: {experiment_name}" + assert len(app_name) > 0, f"Invalid app_name: {app_name}" + assert len(computation_resource_name) > 0, f"Invalid computation_resource_name: {computation_resource_name}" + assert len(inputs) > 0, f"Invalid inputs: {inputs}" + assert len(gateway_id) > 0, f"Invalid gateway_id: {gateway_id}" + assert len(queue_name) > 0, f"Invalid queue_name: {queue_name}" + assert len(grp_name) > 0, f"Invalid grp_name: {grp_name}" + assert len(sr_host) > 0, f"Invalid sr_host: {sr_host}" + assert len(project_name) > 0, f"Invalid project_name: {project_name}" + assert len(mount_point.as_posix()) > 0, f"Invalid mount_point: {mount_point}" + + # validate args (int) + assert node_count > 0, f"Invalid node_count: {node_count}" + assert cpu_count > 0, f"Invalid cpu_count: {cpu_count}" + assert walltime > 0, f"Invalid walltime: {walltime}" + + # parse and validate inputs + file_inputs = dict[str, Path | list[Path]]() + data_inputs = dict[str, str | int | float]() + for 
input_name, input_spec in inputs.items(): + input_type = input_spec["type"] + input_value = input_spec["value"] + if input_type == "uri": + assert isinstance(input_value, str) and os.path.isfile(str(input_value)), f"Invalid {input_name}: {input_value}" + file_inputs[input_name] = Path(input_value) + elif input_type == "uri[]": + assert isinstance(input_value, list) and all([os.path.isfile(str(v)) for v in input_value]), f"Invalid {input_name}: {input_value}" + file_inputs[input_name] = [Path(v) for v in input_value] + else: + assert isinstance(input_value, (int, float, str)), f"Invalid {input_name}: {input_value}" + data_inputs[input_name] = input_value + data_inputs.update({"agent_id": data_inputs.get("agent_id", str(uuid.uuid4()))}) + data_inputs.update({"server_url": server_url}) + + # setup runtime params + print("[AV] Setting up runtime params...") + storage = self.get_storage(sr_host) + sr_id = storage.storageResourceId + + # setup application interface + print("[AV] Setting up application interface...") + app_interface_id = self.get_app_interface_id(app_name) + assert app_interface_id is not None, f"Invalid app_interface_id: {app_interface_id}" + + # setup experiment + print("[AV] Setting up experiment...") + experiment = self.create_experiment_model( + experiment_name=experiment_name, + application_name=app_name, + project_name=project_name, + description=experiment_name, + gateway_id=gateway_id, + ) + # setup experiment directory + print("[AV] Setting up experiment directory...") + exp_dir = self.make_experiment_dir( + sr_host=storage.hostName, + project_name=project_name, + experiment_name=experiment_name, + ) + abs_path = (mount_point / exp_dir.lstrip("/")).as_posix().rstrip("/") + "/" + print("[AV] exp_dir:", exp_dir) + print("[AV] abs_path:", abs_path) + + experiment = self.configure_computation_resource_scheduling( + experiment_model=experiment, + computation_resource_name=computation_resource_name, + group_resource_profile_name=grp_name, + storageId=sr_id, + node_count=node_count, + total_cpu_count=cpu_count, + wall_time_limit=walltime, + queue_name=queue_name, + experiment_dir_path=abs_path, + auto_schedule=auto_schedule, + ) + + def register_input_file(file: Path) -> str: + return str(self.register_input_file(file.name, sr_host, sr_id, gateway_id, file.name, abs_path)) + + # set up file inputs + print("[AV] Setting up file inputs...") + files_to_upload = list[Path]() + file_refs = dict[str, str | list[str]]() + for key, value in file_inputs.items(): + if isinstance(value, Path): + files_to_upload.append(value) + file_refs[key] = register_input_file(value) + elif isinstance(value, list): + assert all([isinstance(v, Path) for v in value]), f"Invalid file input value: {value}" + files_to_upload.extend(value) + file_refs[key] = [*map(register_input_file, value)] + else: + raise ValueError("Invalid file input type") + + # configure experiment inputs + experiment_inputs = [] + for exp_input in self.api_server_client.get_application_inputs(self.airavata_token, app_interface_id): # type: ignore + if exp_input.type < 3 and exp_input.name in data_inputs: + value = data_inputs[exp_input.name] + if exp_input.type == 0: + exp_input.value = str(value) + else: + exp_input.value = repr(value) + elif exp_input.type == 3 and exp_input.name in file_refs: + exp_input.value = file_refs[exp_input.name] + elif exp_input.type == 4 and exp_input.name in file_refs: + exp_input.value = ','.join(file_refs[exp_input.name]) + experiment_inputs.append(exp_input) + experiment.experimentInputs = 
experiment_inputs + + # configure experiment outputs + outputs = self.api_server_client.get_application_outputs(self.airavata_token, app_interface_id) + experiment.experimentOutputs = outputs + + # upload file inputs for experiment + print(f"[AV] Uploading {len(files_to_upload)} file inputs for experiment...") + self.upload_files(None, None, storage.hostName, files_to_upload, exp_dir) + + # create experiment + ex_id = self.api_server_client.create_experiment(self.airavata_token, gateway_id, experiment) + ex_id = str(ex_id) + print(f"[AV] Experiment {experiment_name} CREATED with id: {ex_id}") + + # launch experiment + self.api_server_client.launch_experiment(self.airavata_token, ex_id, gateway_id) + print(f"[AV] Experiment {experiment_name} STARTED with id: {ex_id}") + + # get process id + print(f"[AV] Experiment {experiment_name} WAITING until experiment begins...") + process_id = None + while process_id is None: + try: + process_id = self.get_process_id(ex_id) + except: + time.sleep(2) + else: + time.sleep(2) + print(f"[AV] Experiment {experiment_name} EXECUTING with pid: {process_id}") + + return LaunchState( + experiment_id=ex_id, + agent_ref=str(data_inputs["agent_id"]), + process_id=process_id, + mount_point=mount_point, + experiment_dir=exp_dir, + sr_host=storage.hostName, + ) + + def get_experiment_status(self, experiment_id: str) -> Literal["CREATED", "VALIDATED", "SCHEDULED", "LAUNCHED", "EXECUTING", "CANCELING", "CANCELED", "COMPLETED", "FAILED"]: + states = ["CREATED", "VALIDATED", "SCHEDULED", "LAUNCHED", "EXECUTING", "CANCELING", "CANCELED", "COMPLETED", "FAILED"] + status: any = self.api_server_client.get_experiment_status(self.airavata_token, experiment_id) # type: ignore + return states[status.state] + + def stop_experiment(self, experiment_id: str): + status = self.api_server_client.terminate_experiment( + self.airavata_token, experiment_id, self.default_gateway_id()) + return status + + def execute_py(self, libraries: list[str], code: str, agent_id: str, pid: str, runtime_args: dict, cold_start: bool = True) -> str | None: + # lambda to send request + print(f"[av] Attempting to submit to agent {agent_id}...") + make_request = lambda: requests.post(f"{self.connection_svc_url()}/agent/executepythonrequest", json={ + "libraries": libraries, + "code": code, + "pythonVersion": "3.10", # TODO verify + "keepAlive": False, # TODO verify + "parentExperimentId": "/data", # the working directory + "agentId": agent_id, + }) + try: + if cold_start: + res = make_request() + data = res.json() + if data["error"] == "Agent not found": + # waiting for agent to be available + print(f"[av] Agent {agent_id} not found! 
Relaunching...") + self.launch_experiment( + experiment_name="Agent", + app_name="AiravataAgent", + inputs={ + "agent_id": {"type": "str", "value": agent_id}, + "server_url": {"type": "str", "value": urlparse(self.connection_svc_url()).netloc}, + "process_id": {"type": "str", "value": pid}, + }, + computation_resource_name=runtime_args["cluster"], + queue_name=runtime_args["queue_name"], + node_count=1, + cpu_count=runtime_args["cpu_count"], + walltime=runtime_args["walltime"], + ) + return self.execute_py(libraries, code, agent_id, pid, runtime_args, cold_start=False) + elif data["executionId"] is not None: + print(f"[av] Submitted to Python Interpreter") + # agent response + exc_id = data["executionId"] + else: + # unrecoverable error + raise Exception(data["error"]) + else: + # poll until agent is available + while True: + res = make_request() + data = res.json() + if data["error"] == "Agent not found": + # print(f"[av] Waiting for Agent {agent_id}...") + time.sleep(2) + continue + elif data["executionId"] is not None: + print(f"[av] Submitted to Python Interpreter") + exc_id = data["executionId"] + break + else: + raise Exception(data["error"]) + assert exc_id is not None, f"Invalid execution id: {exc_id}" + + # wait for the execution response to be available + while True: + res = requests.get(f"{self.connection_svc_url()}/agent/executepythonresponse/{exc_id}") + data = res.json() + if data["available"]: + response = str(data["responseString"]) + return response + time.sleep(1) + except Exception as e: + print(f"[av] Remote execution failed! {e}") + return None + + def get_available_runtimes(self): + from .runtime import Remote + return [ + Remote(cluster="login.expanse.sdsc.edu", category="gpu", queue_name="gpu-shared", node_count=1, cpu_count=10, walltime=30), + Remote(cluster="login.expanse.sdsc.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=10, walltime=30), + Remote(cluster="anvil.rcac.purdue.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=24, walltime=30), + ] diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/__init__.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/__init__.py new file mode 100644 index 0000000000..86bf71e7f5 --- /dev/null +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/__init__.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from .device_auth import DeviceFlowAuthenticator + +context = DeviceFlowAuthenticator( + idp_url="https://auth.cybershuttle.org", + realm="10000000", + client_id="cybershuttle-agent", +) + + +def login(): + context.login() + + +def logout(): + context.logout() diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/device_auth.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/device_auth.py new file mode 100644 index 0000000000..165d1a7175 --- /dev/null +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/device_auth.py @@ -0,0 +1,186 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import datetime +import json +import os +import time +import webbrowser + +import jwt +import requests + + +class DeviceFlowAuthenticator: + + idp_url: str + realm: str + client_id: str + interval: int + device_code: str | None + _access_token: str | None + _refresh_token: str | None + + def __has_expired__(self, token: str) -> bool: + try: + decoded = jwt.decode(token, options={"verify_signature": False}) + tA = datetime.datetime.now(datetime.timezone.utc).timestamp() + tB = int(decoded.get("exp", 0)) + return tA >= tB + except: + return True + + @property + def access_token(self) -> str: + if self._access_token and not self.__has_expired__(self._access_token): + return self._access_token + elif self._refresh_token and not self.__has_expired__(self._refresh_token): + self.refresh() + else: + self.login() + assert self._access_token + return self._access_token + + @property + def refresh_token(self) -> str: + if self._refresh_token and not self.__has_expired__(self._refresh_token): + return self._refresh_token + else: + self.login() + assert self._refresh_token + return self._refresh_token + + + def __init__( + self, + idp_url: str, + realm: str, + client_id: str, + ): + self.idp_url = idp_url + self.realm = realm + self.client_id = client_id + + if not self.client_id or not self.realm or not self.idp_url: + raise ValueError( + "Missing required environment variables for client ID, realm, or auth server URL") + + self.interval = 5 + self.device_code = None + self._access_token = None + self._refresh_token = None + + def refresh(self) -> None: + auth_device_url = f"{self.idp_url}/realms/{self.realm}/protocol/openid-connect/token" + response = requests.post(auth_device_url, data={ + "client_id": self.client_id, + "grant_type": "refresh_token", + "scope": "openid", + "refresh_token": self._refresh_token + }) + if response.status_code != 200: + raise Exception(f"Error in token refresh request: {response.status_code} - {response.text}") + data = response.json() + self._refresh_token = data["refresh_token"] + self._access_token = data["access_token"] + assert 
self._access_token is not None + assert self._refresh_token is not None + self.__persist_token__(self._refresh_token, self._access_token) + + def login(self, interactive: bool = True) -> None: + + try: + # [Flow A] Reuse saved token + if os.path.exists("auth.state"): + try: + # [A1] Load token from file + with open("auth.state", "r") as f: + data = json.load(f) + self._refresh_token = str(data["refresh_token"]) + self._access_token = str(data["access_token"]) + except: + print("Failed to load auth.state file!") + else: + # [A2] Check if access token is valid, if so, return + if not self.__has_expired__(self._access_token): + print("Authenticated via saved access token!") + return None + else: + print("Access token is invalid!") + # [A3] Check if refresh token is valid. if so, refresh + try: + if not self.__has_expired__(self._refresh_token): + self.refresh() + print("Authenticated via saved refresh token!") + return None + else: + print("Refresh token is invalid!") + except Exception as e: + print(*e.args) + + # [Flow B] Request device and user code + + # [B1] Initiate device auth flow + auth_device_url = f"{self.idp_url}/realms/{self.realm}/protocol/openid-connect/auth/device" + response = requests.post(auth_device_url, data={ + "client_id": self.client_id, + "scope": "openid", + }) + if response.status_code != 200: + raise Exception(f"Error in device authorization request: {response.status_code} - {response.text}") + data = response.json() + self.device_code = data.get("device_code", self.device_code) + self.interval = data.get("interval", self.interval) + url = data['verification_uri_complete'] + print(f"Please authenticate by visiting: {url}") + if interactive: + webbrowser.open(url) + + # [B2] Poll until token is received + token_url = f"{self.idp_url}/realms/{self.realm}/protocol/openid-connect/token" + print("Waiting for authorization...") + while True: + response = requests.post( + token_url, + data={ + "client_id": self.client_id, + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + "device_code": self.device_code, + }, + ) + if response.status_code == 200: + data = response.json() + self.__persist_token__(data["refresh_token"], data["access_token"]) + print("Authenticated via device auth!") + return + elif response.status_code == 400 and response.json().get("error") == "authorization_pending": + time.sleep(self.interval) + else: + raise Exception(f"Authorization error: {response.status_code} - {response.text}") + + except Exception as e: + print("login() failed!", e) + + def logout(self) -> None: + self._access_token = None + self._refresh_token = None + + def __persist_token__(self, refresh_token: str, access_token: str) -> None: + self._access_token = access_token + self._refresh_token = refresh_token + import json + with open("auth.state", "w") as f: + json.dump({"refresh_token": self._refresh_token, "access_token": self._access_token}, f) diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/base.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/base.py new file mode 100644 index 0000000000..1967cf6f73 --- /dev/null +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/base.py @@ -0,0 +1,135 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
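A minimal sketch of driving the authenticator defined above directly, assuming network access to the configured Keycloak realm (the constructor arguments mirror those used in airavata_experiments.auth):

from airavata_experiments.auth.device_auth import DeviceFlowAuthenticator

auth = DeviceFlowAuthenticator(
    idp_url="https://auth.cybershuttle.org",
    realm="10000000",
    client_id="cybershuttle-agent",
)
auth.login(interactive=False)  # prints the verification URL and polls the token endpoint
token = auth.access_token      # refreshes via the stored refresh token, or re-runs login() when expired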
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import abc +from itertools import product +from typing import Any, Generic, TypeVar +import uuid +import random + +from .plan import Plan +from .runtime import Runtime +from .task import Task + + +class GUIApp: + + app_id: str + + def __init__(self, app_id: str) -> None: + self.app_id = app_id + + def open(self, runtime: Runtime, location: str) -> None: + """ + Open the GUI application + """ + raise NotImplementedError() + + @classmethod + @abc.abstractmethod + def initialize(cls, **kwargs) -> GUIApp: ... + + +class ExperimentApp: + + app_id: str + + def __init__(self, app_id: str) -> None: + self.app_id = app_id + + @classmethod + @abc.abstractmethod + def initialize(cls, **kwargs) -> Experiment: ... + + +T = TypeVar("T", ExperimentApp, GUIApp) + + +class Experiment(Generic[T], abc.ABC): + + name: str + application: T + inputs: dict[str, Any] + input_mapping: dict[str, tuple[Any, str]] + resource: Runtime = Runtime.default() + tasks: list[Task] = [] + + def __init__(self, name: str, application: T): + self.name = name + self.application = application + self.input_mapping = {} + + def with_inputs(self, **inputs: Any) -> Experiment[T]: + """ + Add shared inputs to the experiment + """ + self.inputs = inputs + return self + + def with_resource(self, resource: Runtime) -> Experiment[T]: + self.resource = resource + return self + + def add_replica(self, *allowed_runtimes: Runtime) -> None: + """ + Add a replica to the experiment. + This will create a copy of the application with the given inputs. + + """ + runtime = random.choice(allowed_runtimes) if len(allowed_runtimes) > 0 else self.resource + uuid_str = str(uuid.uuid4())[:4].upper() + + self.tasks.append( + Task( + name=f"{self.name}_{uuid_str}", + app_id=self.application.app_id, + inputs={**self.inputs}, + runtime=runtime, + ) + ) + print(f"Added replica. ({len(self.tasks)} tasks in total)") + + def add_sweep(self, *allowed_runtimes: Runtime, **space: list) -> None: + """ + Add a sweep to the experiment. 
+ + """ + for values in product(space.values()): + runtime = random.choice(allowed_runtimes) if len(allowed_runtimes) > 0 else self.resource + uuid_str = str(uuid.uuid4())[:4].upper() + + task_specific_params = dict(zip(space.keys(), values)) + agg_inputs = {**self.inputs, **task_specific_params} + task_inputs = {k: {"value": agg_inputs[v[0]], "type": v[1]} for k, v in self.input_mapping.items()} + + self.tasks.append(Task( + name=f"{self.name}_{uuid_str}", + app_id=self.application.app_id, + inputs=task_inputs, + runtime=runtime or self.resource, + )) + + def plan(self, **kwargs) -> Plan: + if len(self.tasks) == 0: + self.add_replica(self.resource) + tasks = [] + for t in self.tasks: + agg_inputs = {**self.inputs, **t.inputs} + task_inputs = {k: {"value": agg_inputs[v[0]], "type": v[1]} for k, v in self.input_mapping.items()} + tasks.append(Task(name=t.name, app_id=self.application.app_id, inputs=task_inputs, runtime=t.runtime)) + return Plan(tasks=tasks) diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/md/__init__.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/md/__init__.py new file mode 100644 index 0000000000..5c3b0d790c --- /dev/null +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/md/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .applications import NAMD, VMD + +__all__ = ["NAMD", "VMD"] diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/md/applications.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/md/applications.py new file mode 100644 index 0000000000..b68b1f7acb --- /dev/null +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/md/applications.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
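To show how the Experiment API above composes with the NAMD application defined below and the Plan methods introduced later in this patch, a hedged end-to-end sketch; the input file paths, sweep values, and the ./results directory are placeholders, not part of this patch:

from airavata_experiments import md, list_runtimes

exp = md.NAMD.initialize(
    name="namd_demo",
    config_file="pull.conf",           # placeholder local files
    pdb_file="structure.pdb",
    psf_file="structure.psf",
    ffp_files=["par_all36m_prot.prm"],
)
gpu = list_runtimes(cluster="login.expanse.sdsc.edu", category="gpu")
exp.add_sweep(*gpu, num_replicas=["1", "2"])  # sweep a mapped input over candidate GPU runtimes
p = exp.plan()
p.launch()                 # asks for confirmation, then submits each task
p.wait_for_completion()
p.download("./results")    # must be an existing local directory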
+# + +from typing import Literal +from ..base import Experiment, ExperimentApp, GUIApp + + +class NAMD(ExperimentApp): + """ + Nanoscale Molecular Dynamics (NAMD, formerly Not Another Molecular Dynamics Program) + is a computer software for molecular dynamics simulation, written using the Charm++ + parallel programming model (not to be confused with CHARMM). + It is noted for its parallel efficiency and is often used to simulate large systems + (millions of atoms). It has been developed by the collaboration of the Theoretical + and Computational Biophysics Group (TCB) and the Parallel Programming Laboratory (PPL) + at the University of Illinois Urbana–Champaign. + + """ + + def __init__( + self, + ) -> None: + super().__init__(app_id="NAMD") + + @classmethod + def initialize( + cls, + name: str, + config_file: str, + pdb_file: str, + psf_file: str, + ffp_files: list[str], + other_files: list[str] = [], + parallelism: Literal["CPU", "GPU"] = "CPU", + num_replicas: int = 1, + ) -> Experiment[ExperimentApp]: + app = cls() + obj = Experiment[ExperimentApp](name, app).with_inputs( + config_file=config_file, + pdb_file=pdb_file, + psf_file=psf_file, + ffp_files=ffp_files, + parallelism=parallelism, + other_files=other_files, + num_replicas=num_replicas, + ) + obj.input_mapping = { + "MD-Instructions-Input": ("config_file", "uri"), # uri? [REQUIRED] + "Coordinates-PDB-File": ("pdb_file", "uri"), # uri? [OPTIONAL] + "Protein-Structure-File_PSF": ("psf_file", "uri"), # uri? [REQUIRED] + "FF-Parameter-Files": ("ffp_files", "uri[]"), # uri[]? [REQUIRED] + "Execution_Type": ("parallelism", "str"), # "CPU" | "GPU" [REQUIRED] + "Optional_Inputs": ("other_files", "uri[]"), # uri[]? [OPTIONAL] + "Number of Replicas": ("num_replicas", "str"), # integer [REQUIRED] + # "Constraints-PDB": ("pdb_file", "uri"), # uri? [OPTIONAL] + # "Replicate": (None, "str"), # "yes"? [OPTIONAL] + # "Continue_from_Previous_Run?": (None, "str"), # "yes"? [OPTIONAL] + # "Previous_JobID": (None, "str"), # string? [OPTIONAL] [show if "Continue_from_Previous_Run?" == "yes"] + # "GPU Resource Warning": (None, "str"), # string? [OPTIONAL] [show if "Continue_from_Previous_Run?" == "yes"] + # "Restart_Replicas_List": (None, "str[]"), # string [OPTIONAL] [show if "Continue_from_Previous_Run?" == "yes"] + } + obj.tasks = [] + return obj + + +class VMD(GUIApp): + """ + Visual Molecular Dynamics (VMD) is a molecular visualization and analysis program + designed for biological systems such as proteins, nucleic acids, lipid bilayer assemblies, + etc. It also includes tools for working with volumetric data, sequence data, and arbitrary + graphics objects. VMD can be used to animate and analyze the trajectory of molecular dynamics + simulations, and can interactively manipulate molecules being simulated on remote computers + (Interactive MD). + + """ + + def __init__( + self, + ) -> None: + super().__init__(app_id="vmd") + + @classmethod + def initialize( + cls, + name: str, + ) -> GUIApp: + app = cls() + return app diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/plan.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/plan.py new file mode 100644 index 0000000000..b6cdaa497c --- /dev/null +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/plan.py @@ -0,0 +1,200 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import json +import time +import os + +import pydantic +from rich.progress import Progress +from .runtime import is_terminal_state +from .task import Task +import uuid + +from .airavata import AiravataOperator +from .auth import context + +class Plan(pydantic.BaseModel): + + id: str | None = pydantic.Field(default=None) + tasks: list[Task] = [] + + @pydantic.field_validator("tasks", mode="before") + def default_tasks(cls, v): + if isinstance(v, list): + return [Task(**task) if isinstance(task, dict) else task for task in v] + return v + + def __stage_prepare__(self) -> None: + print("Preparing execution plan...") + + def __stage_confirm__(self, silent: bool) -> None: + print("Confirming execution plan...") + if not silent: + while True: + res = input("Here is the execution plan. continue? (Y/n) ") + if res.upper() in ["N"]: + raise Exception("Execution was aborted by user.") + elif res.upper() in ["Y", ""]: + break + else: + continue + + def __stage_launch_task__(self) -> None: + print("Launching tasks...") + for task in self.tasks: + task.launch() + + def __stage_status__(self) -> list: + statuses = [] + for task in self.tasks: + statuses.append(task.status()) + return statuses + + def __stage_stop__(self) -> None: + print("Stopping task(s)...") + for task in self.tasks: + task.stop() + print("Task(s) stopped.") + + def __stage_fetch__(self, local_dir: str) -> list[list[str]]: + print("Fetching results...") + fps = list[list[str]]() + for task in self.tasks: + fps.append(task.download_all(local_dir)) + print("Results fetched.") + self.save_json(os.path.join(local_dir, "plan.json")) + return fps + + def launch(self, silent: bool = False) -> None: + try: + self.__stage_prepare__() + self.__stage_confirm__(silent) + self.__stage_launch_task__() + self.save() + except Exception as e: + print(*e.args, sep="\n") + + def status(self) -> None: + statuses = self.__stage_status__() + print(f"Plan {self.id} ({len(self.tasks)} tasks):") + for task, status in zip(self.tasks, statuses): + print(f"* {task.name}: {status}") + + def wait_for_completion(self, check_every_n_mins: float = 0.1) -> None: + n = len(self.tasks) + try: + with Progress() as progress: + pbars = [progress.add_task(f"{task.name} ({i+1}/{n}): CHECKING", total=None) for i, task in enumerate(self.tasks)] + while True: + completed = [False] * n + statuses = self.__stage_status__() + for i, (task, status, pbar) in enumerate(zip(self.tasks, statuses, pbars)): + completed[i] = is_terminal_state(status) + progress.update(pbar, description=f"{task.name} ({i+1}/{n}): {status}", completed=completed[i], refresh=True) + if all(completed): + break + sleep_time = check_every_n_mins * 60 + time.sleep(sleep_time) + print("All tasks completed.") + except KeyboardInterrupt: + print("Interrupted by user.") + + def download(self, local_dir: 
str): + assert os.path.isdir(local_dir) + self.__stage_fetch__(local_dir) + + def stop(self) -> None: + self.__stage_stop__() + self.save() + + def save_json(self, filename: str) -> None: + with open(filename, "w") as f: + json.dump(self.model_dump(), f, indent=2) + + def save(self) -> None: + av = AiravataOperator(context.access_token) + az = av.__airavata_token__(av.access_token, av.default_gateway_id()) + assert az.accessToken is not None + assert az.claimsMap is not None + headers = { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer ' + az.accessToken, + 'X-Claims': json.dumps(az.claimsMap) + } + import requests + if self.id is None: + self.id = str(uuid.uuid4()) + response = requests.post("https://api.gateway.cybershuttle.org/api/v1/plan", headers=headers, json=self.model_dump()) + print(f"Plan saved: {self.id}") + else: + response = requests.put(f"https://api.gateway.cybershuttle.org/api/v1/plan/{self.id}", headers=headers, json=self.model_dump()) + print(f"Plan updated: {self.id}") + + if response.status_code == 200: + body = response.json() + plan = json.loads(body["data"]) + assert plan["id"] == self.id + else: + raise Exception(response) + +def load_json(filename: str) -> Plan: + with open(filename, "r") as f: + model = json.load(f) + return Plan(**model) + +def load(id: str | None) -> Plan: + assert id is not None + av = AiravataOperator(context.access_token) + az = av.__airavata_token__(av.access_token, av.default_gateway_id()) + assert az.accessToken is not None + assert az.claimsMap is not None + headers = { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer ' + az.accessToken, + 'X-Claims': json.dumps(az.claimsMap) + } + import requests + response = requests.get(f"https://api.gateway.cybershuttle.org/api/v1/plan/{id}", headers=headers) + + if response.status_code == 200: + body = response.json() + plan = json.loads(body["data"]) + return Plan(**plan) + else: + raise Exception(response) + +def query() -> list[Plan]: + av = AiravataOperator(context.access_token) + az = av.__airavata_token__(av.access_token, av.default_gateway_id()) + assert az.accessToken is not None + assert az.claimsMap is not None + headers = { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer ' + az.accessToken, + 'X-Claims': json.dumps(az.claimsMap) + } + import requests + response = requests.get(f"https://api.gateway.cybershuttle.org/api/v1/plan/user", headers=headers) + + if response.status_code == 200: + items: list = response.json() + plans = [json.loads(item["data"]) for item in items] + return [Plan(**plan) for plan in plans] + else: + raise Exception(response) diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py new file mode 100644 index 0000000000..260b784e65 --- /dev/null +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py @@ -0,0 +1,265 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations +from .auth import context +import abc +from typing import Any +from pathlib import Path + +import pydantic + +# from .task import Task +Task = Any + +class Runtime(abc.ABC, pydantic.BaseModel): + + id: str + args: dict[str, str | int | float] = pydantic.Field(default={}) + + @abc.abstractmethod + def execute(self, task: Task) -> None: ... + + @abc.abstractmethod + def execute_py(self, libraries: list[str], code: str, task: Task) -> None: ... + + @abc.abstractmethod + def status(self, task: Task) -> str: ... + + @abc.abstractmethod + def signal(self, signal: str, task: Task) -> None: ... + + @abc.abstractmethod + def ls(self, task: Task) -> list[str]: ... + + @abc.abstractmethod + def upload(self, file: Path, task: Task) -> str: ... + + @abc.abstractmethod + def download(self, file: str, local_dir: str, task: Task) -> str: ... + + @abc.abstractmethod + def cat(self, file: str, task: Task) -> bytes: ... + + def __str__(self) -> str: + return f"{self.__class__.__name__}(args={self.args})" + + @staticmethod + def default(): + return Remote.default() + + @staticmethod + def create(id: str, args: dict[str, Any]) -> Runtime: + if id == "mock": + return Mock(**args) + elif id == "remote": + return Remote(**args) + else: + raise ValueError(f"Unknown runtime id: {id}") + + @staticmethod + def Remote(**kwargs): + return Remote(**kwargs) + + @staticmethod + def Local(**kwargs): + return Mock(**kwargs) + + +class Mock(Runtime): + + _state: int = 0 + + def __init__(self) -> None: + super().__init__(id="mock") + + def execute(self, task: Task) -> None: + import uuid + task.agent_ref = str(uuid.uuid4()) + task.ref = str(uuid.uuid4()) + + def execute_py(self, libraries: list[str], code: str, task: Task) -> None: + pass + + def status(self, task: Task) -> str: + import random + + self._state += random.randint(0, 5) + if self._state > 10: + return "COMPLETED" + return "RUNNING" + + def signal(self, signal: str, task: Task) -> None: + pass + + def ls(self, task: Task) -> list[str]: + return [""] + + def upload(self, file: Path, task: Task) -> str: + return "" + + def download(self, file: str, local_dir: str, task: Task) -> str: + return "" + + def cat(self, file: str, task: Task) -> bytes: + return b"" + + @staticmethod + def default(): + return Mock() + + +class Remote(Runtime): + + def __init__(self, cluster: str, category: str, queue_name: str, node_count: int, cpu_count: int, walltime: int) -> None: + super().__init__(id="remote", args=dict( + cluster=cluster, + category=category, + queue_name=queue_name, + node_count=node_count, + cpu_count=cpu_count, + walltime=walltime, + )) + + def execute(self, task: Task) -> None: + assert task.ref is None + assert task.agent_ref is None + assert {"cluster", "queue_name", "node_count", "cpu_count", "walltime"}.issubset(self.args.keys()) + print(f"[Remote] Creating Experiment: name={task.name}") + + from .airavata import AiravataOperator + av = AiravataOperator(context.access_token) + try: + launch_state = av.launch_experiment( + experiment_name=task.name, + app_name=task.app_id, + inputs=task.inputs, + 
computation_resource_name=str(self.args["cluster"]), + queue_name=str(self.args["queue_name"]), + node_count=int(self.args["node_count"]), + cpu_count=int(self.args["cpu_count"]), + walltime=int(self.args["walltime"]), + ) + task.agent_ref = launch_state.agent_ref + task.pid = launch_state.process_id + task.ref = launch_state.experiment_id + task.workdir = launch_state.experiment_dir + task.sr_host = launch_state.sr_host + print(f"[Remote] Experiment Launched: id={task.ref}") + except Exception as e: + print(f"[Remote] Failed to launch experiment: {e}") + raise e + + def execute_py(self, libraries: list[str], code: str, task: Task) -> None: + assert task.ref is not None + assert task.agent_ref is not None + assert task.pid is not None + + from .airavata import AiravataOperator + av = AiravataOperator(context.access_token) + result = av.execute_py(libraries, code, task.agent_ref, task.pid, task.runtime.args) + print(result) + + def status(self, task: Task): + assert task.ref is not None + assert task.agent_ref is not None + + from .airavata import AiravataOperator + av = AiravataOperator(context.access_token) + status = av.get_experiment_status(task.ref) + return status + + def signal(self, signal: str, task: Task) -> None: + assert task.ref is not None + assert task.agent_ref is not None + + from .airavata import AiravataOperator + av = AiravataOperator(context.access_token) + av.stop_experiment(task.ref) + + def ls(self, task: Task) -> list[str]: + assert task.ref is not None + assert task.pid is not None + assert task.agent_ref is not None + assert task.sr_host is not None + assert task.workdir is not None + + from .airavata import AiravataOperator + av = AiravataOperator(context.access_token) + files = av.list_files(task.pid, task.agent_ref, task.sr_host, task.workdir) + return files + + def upload(self, file: Path, task: Task) -> str: + assert task.ref is not None + assert task.pid is not None + assert task.agent_ref is not None + assert task.sr_host is not None + assert task.workdir is not None + + from .airavata import AiravataOperator + av = AiravataOperator(context.access_token) + result = av.upload_files(task.pid, task.agent_ref, task.sr_host, [file], task.workdir).pop() + return result + + def download(self, file: str, local_dir: str, task: Task) -> str: + assert task.ref is not None + assert task.pid is not None + assert task.agent_ref is not None + assert task.sr_host is not None + assert task.workdir is not None + + from .airavata import AiravataOperator + av = AiravataOperator(context.access_token) + result = av.download_file(task.pid, task.agent_ref, task.sr_host, file, task.workdir, local_dir) + return result + + def cat(self, file: str, task: Task) -> bytes: + assert task.ref is not None + assert task.pid is not None + assert task.agent_ref is not None + assert task.sr_host is not None + assert task.workdir is not None + + from .airavata import AiravataOperator + av = AiravataOperator(context.access_token) + content = av.cat_file(task.pid, task.agent_ref, task.sr_host, file, task.workdir) + return content + + @staticmethod + def default(): + return list_runtimes(cluster="login.expanse.sdsc.edu", category="gpu").pop() + + +def list_runtimes( + cluster: str | None = None, + category: str | None = None, + node_count: int | None = None, + cpu_count: int | None = None, + walltime: int | None = None, +) -> list[Runtime]: + from .airavata import AiravataOperator + av = AiravataOperator(context.access_token) + all_runtimes = av.get_available_runtimes() + out_runtimes = [] + for r 
in all_runtimes: + if (cluster in [None, r.args["cluster"]]) and (category in [None, r.args["category"]]): + r.args["node_count"] = node_count or r.args["node_count"] + r.args["cpu_count"] = cpu_count or r.args["cpu_count"] + r.args["walltime"] = walltime or r.args["walltime"] + out_runtimes.append(r) + return out_runtimes + +def is_terminal_state(x): + return x in ["CANCELED", "COMPLETED", "FAILED"] \ No newline at end of file diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/scripter.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/scripter.py new file mode 100644 index 0000000000..76aa874ddb --- /dev/null +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/scripter.py @@ -0,0 +1,144 @@ +import inspect +import ast +import textwrap +import sys + + +def scriptize(func): + # Get the source code of the decorated function + source_code = textwrap.dedent(inspect.getsource(func)) + func_tree = ast.parse(source_code) + + # Retrieve the module where the function is defined + module_name = func.__module__ + if module_name in sys.modules: + module = sys.modules[module_name] + else: + raise RuntimeError(f"Cannot find module {module_name} for function {func.__name__}") + + # Attempt to get the module source. + # If this fails (e.g., in a Jupyter notebook), fallback to an empty module tree. + try: + module_source = textwrap.dedent(inspect.getsource(module)) + module_tree = ast.parse(module_source) + except (TypeError, OSError): + # In Jupyter (or certain environments), we can't get the module source this way. + # Use an empty module tree as a fallback. + module_tree = ast.parse("") + + # Find the function definition node + func_def = next( + (node for node in func_tree.body if isinstance(node, ast.FunctionDef)), None) + if not func_def: + raise ValueError("No function definition found in func_tree.") + + # ---- NEW: Identify used names in the function body ---- + # We'll walk the function body to collect all names used. + class NameCollector(ast.NodeVisitor): + def __init__(self): + self.used_names = set() + + def visit_Name(self, node): + self.used_names.add(node.id) + self.generic_visit(node) + + def visit_Attribute(self, node): + # This accounts for usage like time.sleep (attribute access) + # We add 'time' if we see something like time.sleep + # The top-level name is usually in node.value + if isinstance(node.value, ast.Name): + self.used_names.add(node.value.id) + self.generic_visit(node) + + name_collector = NameCollector() + name_collector.visit(func_def) + used_names = name_collector.used_names + + # For imports, we need to consider a few cases: + # - `import module` + # - `import module as alias` + # - `from module import name` + # We'll keep an import if it introduces at least one name or module referenced by the function. + def is_import_used(import_node): + + if isinstance(import_node, ast.Import): + # import something [as alias] + for alias in import_node.names: + # If we have something like `import time` and "time" is used, + # or `import pandas as pd` and "pd" is used, keep it. 
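+                # Hedged illustration: for a decorated function whose body calls
+                # np.mean(...) after `import numpy as np`, NameCollector records
+                # "np", so that import is kept; an unreferenced `import json`
+                # would be dropped because "json" never appears in used_names.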
+ if alias.asname and alias.asname in used_names: + return True + if alias.name.split('.')[0] in used_names: + return True + return False + elif isinstance(import_node, ast.ImportFrom): + # from module import name(s) + # Keep if any of the imported names or their asnames are used + for alias in import_node.names: + # Special case: if we have `from module import task_context`, ignore it + if alias.name == "task_context": + return False + # If from module import x as y, check y; else check x + if alias.asname and alias.asname in used_names: + return True + if alias.name in used_names: + return True + # Another subtlety: if we have `from time import sleep` + # and we call `time.sleep()` is that detected? + # Actually, we already caught attribute usage above, which would add "time" to used_names + # but not "sleep". If the code does `sleep(n)` directly, then "sleep" is in used_names. + return False + return False + + # For other functions, include only if their name is referenced. + def is_function_used(func_node): + return func_node.name in used_names + + def wrapper(*args, **kwargs): + # Bind arguments + func_signature = inspect.signature(func) + bound_args = func_signature.bind(*args, **kwargs) + bound_args.apply_defaults() + + # Convert the original function body to source + body_source_lines = [ast.unparse(stmt) for stmt in func_def.body] + body_source_code = "\n".join(body_source_lines) + + # Collect relevant code blocks: + relevant_code_blocks = [] + for node in module_tree.body: + if isinstance(node, ast.Import) or isinstance(node, ast.ImportFrom): + # Include only used imports + if is_import_used(node): + relevant_code_blocks.append(ast.unparse(node).strip()) + elif isinstance(node, ast.FunctionDef): + # Include only used functions, excluding the decorator itself and the decorated function + if node.name not in ('task_context', func.__name__) and is_function_used(node): + func_code = ast.unparse(node).strip() + relevant_code_blocks.append(func_code) + + # Prepare argument assignments + arg_assignments = [] + for arg_name, arg_value in bound_args.arguments.items(): + # Stringify arguments as before + if isinstance(arg_value, str): + arg_assignments.append(f"{arg_name} = {arg_value!r}") + else: + arg_assignments.append(f"{arg_name} = {repr(arg_value)}") + + # Combine everything + combined_code_parts = [] + if relevant_code_blocks: + combined_code_parts.append("\n\n".join(relevant_code_blocks)) + if arg_assignments: + if combined_code_parts: + combined_code_parts.append("") # blank line before args + combined_code_parts.extend(arg_assignments) + if arg_assignments: + combined_code_parts.append("") # blank line before body + combined_code_parts.append(body_source_code) + + combined_code = "\n".join(combined_code_parts).strip() + return combined_code + + return wrapper diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/sftp.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/sftp.py new file mode 100644 index 0000000000..18e72a1d15 --- /dev/null +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/sftp.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +from pathlib import Path +from datetime import datetime +from rich.progress import Progress + +import paramiko +from scp import SCPClient + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("paramiko").setLevel(logging.WARNING) + + +def create_pkey(pkey_path): + if pkey_path is not None: + return paramiko.RSAKey.from_private_key_file(pkey_path) + return None + + +class SFTPConnector(object): + + def __init__(self, host: str, port: int, username: str, password: str | None = None, pkey: str | None = None): + self.host = host + self.port = port + self.username = username + self.password = password + + ssh = paramiko.SSHClient() + self.ssh = ssh + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + def mkdir(self, project_name: str, exprement_id: str): + project_name = project_name.replace(" ", "_") + time = datetime.now().strftime("%Y-%m-%d %H:%M:%S").replace(" ", "_") + time = time.replace(":", "_") + time = time.replace("-", "_") + exprement_id = exprement_id + "_" + time + base_path = "/" + project_name + remote_path = base_path + "/" + exprement_id + transport = paramiko.Transport(sock=(self.host, int(self.port))) + transport.connect(username=self.username, password=self.password) + try: + connection = paramiko.SFTPClient.from_transport(transport) + assert connection is not None + try: + connection.lstat(base_path) # Test if remote_path exists + except IOError: + connection.mkdir(base_path) + try: + connection.lstat(remote_path) # Test if remote_path exists + except IOError: + connection.mkdir(remote_path) + finally: + transport.close() + return remote_path + + def put(self, local_paths: list[Path], remote_path: str) -> list[str]: + transport = paramiko.Transport(sock=(self.host, int(self.port))) + transport.connect(username=self.username, password=self.password) + remote_paths = [] + try: + with Progress() as progress: + task = progress.add_task("Uploading...", total=len(local_paths)-1) + for file in local_paths: + connection = paramiko.SFTPClient.from_transport(transport) + assert connection is not None + try: + connection.lstat(remote_path) # Test if remote_path exists + except IOError: + connection.mkdir(remote_path) + remote_fpath = remote_path + "/" + file.name + connection.put(file, remote_fpath) + remote_paths.append(remote_fpath) + progress.update(task, advance=1, description=f"Uploading: {file.name}") + progress.update(task, completed=True) + finally: + transport.close() + return remote_paths + + def ls(self, remote_path: str) -> list[str]: + transport = paramiko.Transport(sock=(self.host, int(self.port))) + transport.connect(username=self.username, password=self.password) + try: + connection = paramiko.SFTPClient.from_transport(transport) + assert connection is not None + files = connection.listdir(remote_path) + finally: + transport.close() + return files + + def get(self, remote_path: str, local_path: str) -> str: + transport = paramiko.Transport(sock=(self.host, int(self.port))) + transport.connect(username=self.username, password=self.password) + with SCPClient(transport) as conn: + 
conn.get(remote_path, local_path, recursive=True) + self.ssh.close() + return (Path(local_path) / Path(remote_path).name).as_posix() + + def cat(self, remote_path: str) -> bytes: + transport = paramiko.Transport(sock=(self.host, int(self.port))) + transport.connect(username=self.username, password=self.password) + sftp = paramiko.SFTPClient.from_transport(transport) + assert sftp is not None + with sftp.open(remote_path, "r") as f: + content = f.read() + return content diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/task.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/task.py new file mode 100644 index 0000000000..1612f44f74 --- /dev/null +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/task.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations +from typing import Any +import pydantic +from .runtime import Runtime +from rich.progress import Progress + +class Task(pydantic.BaseModel): + + name: str + app_id: str + inputs: dict[str, Any] + runtime: Runtime + ref: str | None = pydantic.Field(default=None) + pid: str | None = pydantic.Field(default=None) + agent_ref: str | None = pydantic.Field(default=None) + workdir: str | None = pydantic.Field(default=None) + sr_host: str | None = pydantic.Field(default=None) + + @pydantic.field_validator("runtime", mode="before") + def set_runtime(cls, v): + if isinstance(v, dict) and "id" in v: + id = v.pop("id") + args = v.pop("args", {}) + return Runtime.create(id=id, args=args) + return v + + def __str__(self) -> str: + return f"Task(\nname={self.name}\napp_id={self.app_id}\ninputs={self.inputs}\nruntime={self.runtime}\nref={self.ref}\nagent_ref={self.agent_ref}\nfile_path={self.sr_host}:{self.workdir}\n)" + + def launch(self, force=True) -> None: + if not force and self.ref is not None: + print(f"[Task] Task {self.name} has already launched: ref={self.ref}") + return + if self.ref is not None: + input("[NOTE] Past runs will be overwritten! 
Hit Enter to continue...") + self.ref = None + self.agent_ref = None + print(f"[Task] Executing {self.name} on {self.runtime}") + self.runtime.execute(self) + + def status(self) -> str: + assert self.ref is not None + return self.runtime.status(self) + + def ls(self) -> list[str]: + assert self.ref is not None + return self.runtime.ls(self) + + def upload(self, file: str) -> str: + assert self.ref is not None + from pathlib import Path + return self.runtime.upload(Path(file), self) + + def download(self, file: str, local_dir: str) -> str: + assert self.ref is not None + from pathlib import Path + Path(local_dir).mkdir(parents=True, exist_ok=True) + return self.runtime.download(file, local_dir, self) + + def download_all(self, local_dir: str) -> list[str]: + assert self.ref is not None + import os + os.makedirs(local_dir, exist_ok=True) + fps_task = list[str]() + files = self.ls() + with Progress() as progress: + pbar = progress.add_task(f"Downloading: ...", total=len(files)) + for remote_fp in self.ls(): + fp = self.runtime.download(remote_fp, local_dir, self) + progress.update(pbar, description=f"Downloading: {remote_fp}", advance=1) + fps_task.append(fp) + progress.update(pbar, description=f"Downloading: DONE", refresh=True) + return fps_task + + def cat(self, file: str) -> bytes: + assert self.ref is not None + return self.runtime.cat(file, self) + + def stop(self) -> None: + assert self.ref is not None + return self.runtime.signal("SIGTERM", self) + + def context(self, packages: list[str]) -> Any: + def decorator(func): + def wrapper(*args, **kwargs): + from .scripter import scriptize + make_script = scriptize(func) + return self.runtime.execute_py(packages, make_script(*args, **kwargs), self) + return wrapper + return decorator diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/api_server_client.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/api_server_client.py index c6a10132fd..1c120c10d3 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/api_server_client.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/api_server_client.py @@ -29,9 +29,12 @@ class APIServerClient(object): - def __init__(self, configuration_file_location=None): - self.api_server_settings = APIServerClientSettings(configuration_file_location) - self._load_settings(configuration_file_location) + def __init__(self, configuration_file_location=None, api_server_settings=None): + if configuration_file_location is not None: + self.api_server_settings = APIServerClientSettings(configuration_file_location) + self._load_settings(configuration_file_location) + elif api_server_settings is not None: + self.api_server_settings = api_server_settings self.api_server_client_pool = utils.initialize_api_client_pool(self.api_server_settings.API_SERVER_HOST, self.api_server_settings.API_SERVER_PORT, self.api_server_settings.API_SERVER_SECURE) diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/sftp_file_handling_client.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/sftp_file_handling_client.py index badd5f9755..3cbe194e97 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/sftp_file_handling_client.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/sftp_file_handling_client.py @@ -83,4 +83,4 @@ def download_files(self, local_path, remote_path): @staticmethod 
def uploading_info(uploaded_file_size, total_file_size): logging.info('uploaded_file_size : {} total_file_size : {}'. - format(uploaded_file_size, total_file_size)) + format(uploaded_file_size, total_file_size)) \ No newline at end of file diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/utils/experiment_handler_util.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/utils/experiment_handler_util.py index 9e81f539b7..df3991de67 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/utils/experiment_handler_util.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/utils/experiment_handler_util.py @@ -38,15 +38,14 @@ class ExperimentHandlerUtil(object): - def __init__(self, configuration_file_location=None): + def __init__(self, configuration_file_location=None, access_token=None): self.configuration_file = configuration_file_location - self.authenticator = Authenticator(configuration_file_location) self.gateway_conf = GatewaySettings(configuration_file_location) self.experiment_conf = ExperimentSettings(configuration_file_location) - self.keycloak_conf = KeycloakConfiguration(configuration_file_location) self.authenticator = Authenticator(self.configuration_file) - self.authenticator.authenticate_with_auth_code() - access_token = getpass.getpass('Copy paste the access token') + if access_token is None: + self.authenticator.authenticate_with_auth_code() + access_token = getpass.getpass('Copy paste the access token') self.access_token = access_token decode = jwt.decode(access_token, options={"verify_signature": False}) self.user_id = decode['preferred_username'] @@ -178,7 +177,7 @@ def launch_experiment(self, experiment_name="default_exp", description="this is logger.info("experiment launched id: %s", ex_id) - experiment_url = 'https://' + self.gateway_conf.GATEWAY_ID + '.org/workspace/experiments/' + ex_id + experiment_url = 'https://' + self.gateway_conf.GATEWAY_URL + '.org/workspace/experiments/' + ex_id logger.info("For more information visit %s", experiment_url) if self.experiment_conf.MONITOR_STATUS: diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/transport/settings.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/transport/settings.py index b1e61f22aa..d36cc65550 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/transport/settings.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/transport/settings.py @@ -113,6 +113,7 @@ def __init__(self, configFileLocation=None): if configFileLocation is not None: config.read(configFileLocation) self.GATEWAY_ID = config.get('Gateway', 'GATEWAY_ID') + self.GATEWAY_URL = config.get('Gateway', 'GATEWAY_URL') self.GATEWAY_DATA_STORE_RESOURCE_ID = config.get('Gateway', 'GATEWAY_DATA_STORE_RESOURCE_ID') self.GATEWAY_DATA_STORE_DIR = config.get('Gateway', 'GATEWAY_DATA_STORE_DIR') self.GATEWAY_DATA_STORE_HOSTNAME = config.get('Gateway', 'GATEWAY_DATA_STORE_HOSTNAME') diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml b/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml new file mode 100644 index 0000000000..1024105272 --- /dev/null +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml @@ -0,0 +1,38 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "airavata-python-sdk-test" +version = 
"0.0.13" +description = "Apache Airavata Python SDK" +readme = "README.md" +license = { text = "Apache License 2.0" } +authors = [{ name = "Airavata Developers", email = "dev@airavata.apache.org" }] +requires-python = ">=3.10" +dependencies = [ + "oauthlib", + "requests", + "requests-oauthlib", + "thrift", + "thrift_connector", + "paramiko", + "scp", + "pysftp", + "configparser", + "urllib3", + "pyjwt", + "pydantic", + "rich", + "ipywidgets", + "pandas", +] + +[tool.setuptools.packages.find] +where = ["."] +include = ["airavata*"] +exclude = ["*.egg-info"] + +[tool.setuptools.package-data] +"airavata_sdk.transport" = ["*.ini"] +"airavata_sdk.samples.resources" = ["*.pem"] diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/requirements.txt b/airavata-api/airavata-client-sdks/airavata-python-sdk/requirements.txt deleted file mode 100644 index 9304d80884..0000000000 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/requirements.txt +++ /dev/null @@ -1,11 +0,0 @@ -oauthlib -requests==2.13.0 -requests-oauthlib==0.7.0 -thrift==0.10.0 -thrift_connector==0.24 -paramiko -scp -pysftp -configparser -urllib3 -pyjwt diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/setup.cfg b/airavata-api/airavata-client-sdks/airavata-python-sdk/setup.cfg deleted file mode 100644 index 3397c30e4f..0000000000 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/setup.cfg +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -[bdist_wheel] -universal = 1 - -[metadata] -description-file = README.md -license_file = LICENSE - -[aliases] -test = pytest \ No newline at end of file diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/setup.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/setup.py deleted file mode 100644 index 648e5497d4..0000000000 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/setup.py +++ /dev/null @@ -1,21 +0,0 @@ -import os - -from setuptools import setup, find_packages - - -def read(fname): - with open(os.path.join(os.path.dirname(__file__), fname)) as f: - return f.read() - - -setup( - name='airavata-python-sdk', - version='1.1.6', - packages=find_packages(), - package_data={'airavata_sdk.transport': ['*.ini'], 'airavata_sdk.samples.resources': ['*.pem']}, - url='http://airavata.com', - license='Apache License 2.0', - author='Airavata Developers', - author_email='dev@airavata.apache.org', - description='Apache Airavata Python SDK' -) diff --git a/dev-tools/deployment/jupyterhub/user-container/MD/.gitignore b/dev-tools/deployment/jupyterhub/user-container/MD/.gitignore new file mode 100644 index 0000000000..423a839aa4 --- /dev/null +++ b/dev-tools/deployment/jupyterhub/user-container/MD/.gitignore @@ -0,0 +1,3 @@ +plan.json +auth.state +results*/ \ No newline at end of file diff --git a/dev-tools/deployment/jupyterhub/user-container/MD/poc.ipynb b/dev-tools/deployment/jupyterhub/user-container/MD/smd_cpu.ipynb similarity index 95% rename from dev-tools/deployment/jupyterhub/user-container/MD/poc.ipynb rename to dev-tools/deployment/jupyterhub/user-container/MD/smd_cpu.ipynb index 858f473622..22b4faa37e 100644 --- a/dev-tools/deployment/jupyterhub/user-container/MD/poc.ipynb +++ b/dev-tools/deployment/jupyterhub/user-container/MD/smd_cpu.ipynb @@ -159,9 +159,8 @@ " \"data/b4pull.restart.xsc\",\n", " ],\n", " parallelism=\"CPU\",\n", - " num_replicas=1,\n", ")\n", - "exp.add_replica(*ae.list_runtimes(cluster=\"login.expanse.sdsc.edu\", category=\"cpu\"))\n", + "exp.add_replica(*ae.list_runtimes(cluster=\"login.expanse.sdsc.edu\", category=\"cpu\", walltime=60))\n", "ae.display(exp)" ] }, @@ -337,8 +336,8 @@ "outputs": [], "source": [ "plan.wait_for_completion() # wait for plan to complete\n", - "for task in plan.tasks:\n", - " task.download_all(f\"./results_{task.name}\") # download plan outputs" + "# for task in plan.tasks:\n", + "# task.download_all(f\"./results_{task.name}\") # download plan outputs" ] }, { @@ -363,17 +362,24 @@ " @task.context(packages=[\"numpy\", \"pandas\"])\n", " def analyze() -> None:\n", " import numpy as np\n", - " with open(\"pull.conf\", \"r\") as f:\n", + " with open(\"pull_cpu.conf\", \"r\") as f:\n", " data = f.read()\n", - " print(\"pull.conf has\", len(data), \"chars\")\n", + " print(\"pull_cpu.conf has\", len(data), \"chars\")\n", " print(np.arange(10))\n", " analyze()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "airavata", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -387,9 +393,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.7" + "version": "3.11.6" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/dev-tools/deployment/jupyterhub/user-container/MD/smd_gpu.ipynb b/dev-tools/deployment/jupyterhub/user-container/MD/smd_gpu.ipynb new file mode 100644 index 
0000000000..d9de2a5b38 --- /dev/null +++ b/dev-tools/deployment/jupyterhub/user-container/MD/smd_gpu.ipynb @@ -0,0 +1,401 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cybershuttle SDK - Molecular Dynamics\n", + "> Define, run, monitor, and analyze molecular dynamics experiments in a HPC-agnostic way.\n", + "\n", + "This notebook shows how users can setup and launch a **NAMD** experiment with replicas, monitor its execution, and run analyses both during and after execution." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installing Required Packages\n", + "\n", + "First, install the `airavata-python-sdk-test` package from the pip repository." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install --upgrade airavata-python-sdk-test" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Importing the SDK" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import airavata_experiments as ae\n", + "from airavata_experiments import md" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Authenticating\n", + "\n", + "To authenticate for remote execution, call the `ae.login()` method.\n", + "This method will prompt you to enter your credentials and authenticate your session." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ae.login()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once authenticated, the `ae.list_runtimes()` function can be called to list HPC resources that the user has access to." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "runtimes = ae.list_runtimes()\n", + "ae.display(runtimes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Uploading Experiment Files\n", + "\n", + "Drag and drop experiment files onto the workspace that this notebook is run on.\n", + "\n", + "```bash\n", + "(sh) $: tree data\n", + "data\n", + "├── b4pull.pdb\n", + "├── b4pull.restart.coor\n", + "├── b4pull.restart.vel\n", + "├── b4pull.restart.xsc\n", + "├── par_all36_water.prm\n", + "├── par_all36m_prot.prm\n", + "├── pull.conf\n", + "├── structure.pdb\n", + "└── structure.psf\n", + "\n", + "1 directory, 9 files\n", + "\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Defining a NAMD Experiment\n", + "\n", + "The `md.NAMD.initialize()` is used to define a NAMD experiment.\n", + "Here, provide the paths to the `.conf` file, the `.pdb` file, the `.psf` file, any optional files you want to run NAMD on.\n", + "You can preview the function definition through auto-completion.\n", + "\n", + "```python\n", + "def initialize(\n", + " name: str,\n", + " config_file: str,\n", + " pdb_file: str,\n", + " psf_file: str,\n", + " ffp_files: list[str],\n", + " other_files: list[str] = [],\n", + " parallelism: Literal['CPU', 'GPU'] = \"CPU\",\n", + " num_replicas: int = 1\n", + ") -> Experiment[ExperimentApp]\n", + "```\n", + "\n", + "To add replica runs, simply call the `exp.add_replica()` function.\n", + "You can call the `add_replica()` function as many times as you want replicas.\n", + "Any optional resource constraint can be provided here.\n", + "\n", + "You can also call `ae.display()` to pretty-print the experiment." 
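+    ,
+    "\n",
+    "For instance, replicas with different resource constraints can be attached by calling `add_replica()` once per runtime. A sketch (the walltime values below are illustrative):\n",
+    "\n",
+    "```python\n",
+    "exp.add_replica(*ae.list_runtimes(cluster=\"login.expanse.sdsc.edu\", category=\"gpu\", walltime=60))\n",
+    "exp.add_replica(*ae.list_runtimes(cluster=\"login.expanse.sdsc.edu\", category=\"gpu\", walltime=180))\n",
+    "```\n"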
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "exp = md.NAMD.initialize(\n", + " name=\"SMD\",\n", + " config_file=\"data/pull_gpu.conf\",\n", + " pdb_file=\"data/structure.pdb\",\n", + " psf_file=\"data/structure.psf\",\n", + " ffp_files=[\n", + " \"data/par_all36_water.prm\",\n", + " \"data/par_all36m_prot.prm\"\n", + " ],\n", + " other_files=[\n", + " \"data/b4pull.pdb\",\n", + " \"data/b4pull.restart.coor\",\n", + " \"data/b4pull.restart.vel\",\n", + " \"data/b4pull.restart.xsc\",\n", + " ],\n", + " parallelism=\"GPU\",\n", + ")\n", + "exp.add_replica(*ae.list_runtimes(cluster=\"login.expanse.sdsc.edu\", category=\"gpu\", walltime=180))\n", + "ae.display(exp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating an Execution Plan\n", + "\n", + "Call the `exp.plan()` function to transform the experiment definition + replicas into a stateful execution plan." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan = exp.plan()\n", + "ae.display(plan)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving the Plan\n", + "\n", + "A created plan can be saved locally (in JSON) or remotely (in a user-local DB) for later reference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan.save() # this will save the plan in DB\n", + "plan.save_json(\"plan_gpu.json\") # save the plan state locally" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Launching the Plan\n", + "\n", + "A created plan can be launched using the `plan.launch()` function.\n", + "Changes to plan states will be automatically saved onto the remote.\n", + "However, plan state can also be tracked locally by invoking `plan.save_json()`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan.launch()\n", + "plan.save_json(\"plan_gpu.json\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Checking the Plan Status\n", + "The status of a plan can be retrieved by calling `plan.status()`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan.status()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading a Saved Plan\n", + "\n", + "A saved plan can be loaded by calling `ae.plan.load_json(plan_path)` (for local plans) or `ae.plan.load(plan_id)` (for remote plans)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan = ae.plan.load_json(\"plan_gpu.json\")\n", + "plan = ae.plan.load(plan.id)\n", + "plan.status()\n", + "ae.display(plan)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fetching User-Defined Plans\n", + "\n", + "The `ae.plan.query()` function retrieves all plans stored in the remote." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plans = ae.plan.query()\n", + "ae.display(plans)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Managing Plan Execution\n", + "\n", + "The `plan.stop()` function will stop a currently executing plan.\n", + "The `plan.wait_for_completion()` function would block until the plan finishes executing." 
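+    ,
+    "\n",
+    "`wait_for_completion()` also accepts a polling interval in minutes. A minimal sketch:\n",
+    "\n",
+    "```python\n",
+    "plan.wait_for_completion(check_every_n_mins=1)  # re-check task states once per minute\n",
+    "```\n"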
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# plan.stop()\n", + "# plan.wait_for_completion()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Interacting with Files\n", + "\n", + "The `task` object has several helper functions to perform file operations within its context.\n", + "\n", + "* `task.ls()` - list all remote files (inputs, outputs, logs, etc.)\n", + "* `task.upload(, )` - upload a local file to remote\n", + "* `task.cat()` - displays contents of a remote file\n", + "* `task.download(, )` - fetch a remote file to local" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for task in plan.tasks:\n", + " print(task.name, task.pid)\n", + " display(task.ls()) # list files\n", + " task.upload(\"data/sample.txt\") # upload sample.txt\n", + " display(task.ls()) # list files AFTER upload\n", + " display(task.cat(\"sample.txt\")) # preview sample.txt\n", + " task.download(\"sample.txt\", f\"./results_{task.name}\") # download sample.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan.wait_for_completion() # wait for plan to complete\n", + "# for task in plan.tasks:\n", + "# task.download_all(f\"./results_{task.name}\") # download plan outputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Executing Task-Local Code Remotely\n", + "\n", + "The `@task.context()` decorator can be applied on Python functions to run them remotely within the task context.\n", + "The functions executed this way has access to the task files, as well as the remote compute resources.\n", + "\n", + "**NOTE**: Currently, remote code execution is only supported for ongoing tasks. In future updates, we will support both ongoing and completed tasks. Stay tuned!" 
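+    ,
+    "\n",
+    "Under the hood, the decorator converts the function body (plus the imports it actually uses) into a standalone script and ships it to the task runtime. Roughly equivalent, using the `analyze` function from the next cell, is the following sketch:\n",
+    "\n",
+    "```python\n",
+    "from airavata_experiments.scripter import scriptize\n",
+    "\n",
+    "script = scriptize(analyze)()  # function body + required imports, as source text\n",
+    "task.runtime.execute_py([\"numpy\", \"pandas\"], script, task)\n",
+    "```\n"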
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for index, task in enumerate(plan.tasks):\n", + " @task.context(packages=[\"numpy\", \"pandas\"])\n", + " def analyze() -> None:\n", + " import numpy as np\n", + " with open(\"pull_gpu.conf\", \"r\") as f:\n", + " data = f.read()\n", + " print(\"pull_gpu.conf has\", len(data), \"chars\")\n", + " print(np.arange(10))\n", + " analyze()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/dev-tools/deployment/scripts/.gitkeep b/dev-tools/deployment/scripts/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dev-tools/deployment/scripts/expanse/agent.sh b/dev-tools/deployment/scripts/expanse/agent.sh new file mode 100644 index 0000000000..b0df5471e3 --- /dev/null +++ b/dev-tools/deployment/scripts/expanse/agent.sh @@ -0,0 +1,40 @@ +#!/bin/bash -x + +# ##################################################################### +# Standalone Airavata Agent for Expanse +# ##################################################################### +# +# ---------------------------------------------------------------------- +# CONTRIBUTORS +# ---------------------------------------------------------------------- +# * Dimuthu Wannipurage +# * Lahiru Jayathilake +# * Yasith Jayawardana +# ###################################################################### + +#----------------------------------------------------------------------- +# STEP 1 - PARSE COMMAND LINE ARGS +#----------------------------------------------------------------------- +while getopts a:s:p: option; do + case $option in + a) AGENT_ID=$OPTARG ;; + s) SERVER_URL=$OPTARG ;; + p) PROCESS_ID=$OPTARG ;; + \?) cat <! Usage: $0 [-a AGENT_ID ] !< +>! [-s SERVER_URL] !< +>! 
[-w PROCESS_ID] !< +ENDCAT + esac +done + +echo "AGENT_ID=$AGENT_ID" +echo "SERVER_URL=$SERVER_URL" +echo "PROCESS_ID=$PROCESS_ID" + +# ---------------------------------------------------------------------- +# STEP 2 - RUN AGENT +# ---------------------------------------------------------------------- +SIF_PATH=/home/scigap/agent-framework/airavata-agent.sif +module load singularitypro +singularity exec --bind /expanse/lustre/scratch/scigap/temp_project/neuro-workdirs/$PROCESS_ID:/data $SIF_PATH bash -c "/opt/airavata-agent $SERVER_URL:19900 $AGENT_ID" diff --git a/dev-tools/deployment/scripts/expanse/alphafold2-agent.sh b/dev-tools/deployment/scripts/expanse/alphafold2-agent.sh new file mode 100755 index 0000000000..d752b6baa8 --- /dev/null +++ b/dev-tools/deployment/scripts/expanse/alphafold2-agent.sh @@ -0,0 +1,133 @@ +#!/bin/bash -x + +# ##################################################################### +# AlphaFold2 Driver + Airavata Agent for Expanse +# ##################################################################### +# +# ---------------------------------------------------------------------- +# CONTRIBUTORS +# ---------------------------------------------------------------------- +# * Sudhakar Pamidigantham +# * Lahiru Jayathilake +# * Dimuthu Wannipurage +# * Yasith Jayawardana +# +# ###################################################################### + +######################################################################## +# Part 1 - Housekeeping +######################################################################## + +#----------------------------------------------------------------------- +# Step 1.1 - Check command line +#----------------------------------------------------------------------- + +while getopts t:p:m: option; do + case $option in + t) MaxDate=$OPTARG ;; + p) MODEL_PRESET=$OPTARG ;; + m) Num_Multi=$OPTARG ;; + \?) cat <! Usage: $0 [-t Maximum Template Date ] !< +>! [-p Model Preset ] !< +>! 
[-m Number of Multimers per Model ] !< +ENDCAT + # exit 1 ;; + esac +done + +if [ $Num_Multi = "" ]; then + export Num_Multi=1 +fi +#set the environment PATH +export PYTHONNOUSERSITE=True +module reset +module load singularitypro +ALPHAFOLD_DATA_PATH=/expanse/projects/qstore/data/alphafold-v2.3.2 +ALPHAFOLD_MODELS=/expanse/projects/qstore/data/alphafold-v2.3.2/params + +#ALPHAFOLD_DATA_PATH=/expanse/projects/qstore/data/alphafold +#ALPHAFOLD_MODELS=/expanse/projects/qstore/data/alphafold/params +pdb70="" +uniprot="" +pdbseqres="" +nummulti="" + +# check_flags +if [ "monomer" = "${MODEL_PRESET%_*}" ]; then + export pdb70="--pdb70_database_path=/data/pdb70/pdb70" +else + export uniprot="--uniprot_database_path=/data/uniprot/uniprot.fasta" + export pdbseqres="--pdb_seqres_database_path=/data/pdb_seqres/pdb_seqres.txt" + export nummulti="--num_multimer_predictions_per_model=$Num_Multi" +fi + +## Copy input to node local scratch +cp input.fasta /scratch/$USER/job_$SLURM_JOBID +#cp -r /expanse/projects/qstore/data/alphafold/uniclust30/uniclust30_2018_08 /scratch/$USER/job_$SLURM_JOBID/ +cd /scratch/$USER/job_$SLURM_JOBID +ln -s /expanse/projects/qstore/data/alphafold/uniclust30/uniclust30_2018_08 +mkdir bfd +cp /expanse/projects/qstore/data/alphafold/bfd/*index bfd/ +#cp /expanse/projects/qstore/data/alphafold/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_hhm.ffdata bfd/ +#cp /expanse/projects/qstore/data/alphafold/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_cs219.ffdata bfd/ +cd bfd +ln -s /expanse/projects/qstore/data/alphafold/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_hhm.ffdata +ln -s /expanse/projects/qstore/data/alphafold/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_cs219.ffdata +ln -s /expanse/projects/qstore/data/alphafold/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_a3m.ffdata +cd ../ +mkdir alphafold_output +# Create soft links ro rundir form submitdir + +ln -s /scratch/$USER/job_$SLURM_JOBID $SLURM_SUBMIT_DIR/rundir + +#Run the command +singularity run --nv \ + -B /expanse/lustre \ + -B /expanse/projects \ + -B /scratch \ + -B $ALPHAFOLD_DATA_PATH:/data \ + -B $ALPHAFOLD_MODELS \ + /cm/shared/apps/containers/singularity/alphafold/alphafold_aria2_v2.3.2.simg \ + --fasta_paths=/scratch/$USER/job_$SLURM_JOBID/input.fasta \ + --uniref90_database_path=/data/uniref90/uniref90.fasta \ + --data_dir=/data \ + --mgnify_database_path=/data/mgnify/mgy_clusters_2022_05.fa \ + --bfd_database_path=/scratch/$USER/job_$SLURM_JOBID/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \ + --uniref30_database_path=/data/uniref30/UniRef30_2021_03 \ + $pdbseqres \ + $pdb70 \ + $uniprot \ + --template_mmcif_dir=/data/pdb_mmcif/mmcif_files \ + --obsolete_pdbs_path=/data/pdb_mmcif/obsolete.dat \ + --output_dir=/scratch/$USER/job_$SLURM_JOBID/alphafold_output \ + --max_template_date=$MaxDate \ + --model_preset=$MODEL_PRESET \ + --use_gpu_relax=true \ + --models_to_relax=best \ + $nummulti + +#-B .:/etc \ +#/cm/shared/apps/containers/singularity/alphafold/alphafold.sif \ +#--fasta_paths=input.fasta \ +#--uniref90_database_path=/data/uniref90/uniref90.fasta \ +#--data_dir=/data \ +#--mgnify_database_path=/data/mgnify/mgy_clusters.fa \ +#--bfd_database_path=/scratch/$USER/job_$SLURM_JOBID/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \ +#--uniclust30_database_path=/scratch/$USER/job_$SLURM_JOBID/uniclust30_2018_08/uniclust30_2018_08 \ +#--pdb70_database_path=/data/pdb70/pdb70 \ 
+#--template_mmcif_dir=/data/pdb_mmcif/mmcif_files \ +#--obsolete_pdbs_path=/data/pdb_mmcif/obsolete.dat \ +#--output_dir=alphafold_output \ +#--max_template_date=$MaxDate \ +#--preset=$MODEL_PRESET + +#make a user choice --preset=casp14 +#make this a user choice --max_template_date=2020-05-14 \ +# --model_names='model_1' \ +# Remove model data +unlink $SLURM_SUBMIT_DIR/rundir + +### Copy back results + +tar -cvf $SLURM_SUBMIT_DIR/alphafold_output.tar alphafold_output diff --git a/dev-tools/deployment/scripts/expanse/namd-agent.sh b/dev-tools/deployment/scripts/expanse/namd-agent.sh new file mode 100755 index 0000000000..63ea64bd29 --- /dev/null +++ b/dev-tools/deployment/scripts/expanse/namd-agent.sh @@ -0,0 +1,231 @@ +#!/bin/bash -x + +# ##################################################################### +# NAMD Driver + Airavata Agent for Expanse +# ##################################################################### +# +# ---------------------------------------------------------------------- +# CONTRIBUTORS +# ---------------------------------------------------------------------- +# * Sudhakar Pamidigantham +# * Diego Gomes +# * Lahiru Jayathilake +# * Yasith Jayawardana +# +# ---------------------------------------------------------------------- +# CHANGELOG +# ---------------------------------------------------------------------- +# * 2024/12/13 - Agent subprocess and graceful shutdown (Yasith J) +# * 2024/12/09 - Reviewed (Diego Gomes) +# ###################################################################### + +######################################################################## +# Part 1 - Housekeeping +######################################################################## + +#----------------------------------------------------------------------- +# Step 1.1 - Check command line +#----------------------------------------------------------------------- +if [ $# -lt 1 -o $# -gt 15 ]; then + echo 1>&2 "Usage: $0 -t [CPU/GPU] -r [PJobID] -l [Continue_Replicas_list] -n [Number_of_Replicas] -i input_conf [SEAGrid_UserName] " + exit 127 +fi + +# subdir depends on whether we're doing freq, water or PES. For freq and water, +# it should be hardcoded in the Xbaya workflow. For PES, it should be an +# additional array generated by the frontend. The contents of this array are +# trivial, but creating an extra Xbaya service to generate it would add +# unnecessary extra complexity. Besides, the frontend cannot avoid having to +# pass at least one array: the array with gjf files. + +subdir="$PWD" +while getopts t:r:l:n:i:a:s: option; do + case $option in + t) ExeTyp=$OPTARG ;; + r) PJobID=$OPTARG ;; + l) rep_list=$OPTARG ;; + n) num_rep=$OPTARG ;; + i) input=$OPTARG ;; + a) agent_id=$OPTARG ;; + s) server_url=$OPTARG ;; + \?) cat <! Usage: $0 [-et execution type cpu/gpu ] !< +>! [-rr Previous JobID for continuation (optional)] !< +>! [-rl replica list for contiuation (optional)] !< +>! 
[-rep Number of replicas to run (optional)] !< +ENDCAT + esac +done + +echo "ExeTyp=$ExeTyp" +echo "PJobID=$PJobID" +echo "rep_list=$rep_list" +echo "num_rep=$num_rep" +echo "input=$input" +echo "agent_id=$agent_id" +echo "server_url=$server_url" + +# ---------------------------------------------------------------------- +# RUN AGENT AS SUBPROCESS (for now) +# ---------------------------------------------------------------------- +SIF_PATH=/home/scigap/agent-framework/airavata-agent.sif +module load singularitypro +singularity exec --bind $PWD:/data $SIF_PATH bash -c "/opt/airavata-agent $server_url:19900 $agent_id" & +agent_pid=$! # save agent PID for graceful shutdown + +#----------------------------------------------------------------------- +# Step 1.2 - Validate inputs +#----------------------------------------------------------------------- + +if [ ! $AIRAVATA_USERNAME ]; then + echo "Missing AIRAVATA_USERNAME. Check with support!" + exit +fi + +if [ ! $ExeTyp ]; then + echo "Missing Execution Type: [CPU, GPU]" + exit +fi + +SG_UserName="$AIRAVATA_USERNAME" +echo "Execution Type: $ExeTyp" + +#----------------------------------------------------------------------- +# Step 1.3 - Get the input configuration filename +#----------------------------------------------------------------------- +filename=$(basename -- "$input") +filename="${filename%.*}" + +#----------------------------------------------------------------------- +# Step 1.4 - Copy previous files if this a continuation. +#----------------------------------------------------------------------- +if [ "$PJobID" ]; then + cp $input saveInput #save configuration + ls -lt $localarc/$PJobID/ + cp -r $localarc/$PJobID/. . + cp saveInput $input +fi + +#----------------------------------------------------------------------- +# Step 1.5 - Create folders for replicas (if necessary) +#----------------------------------------------------------------------- +echo " Creating folders for replica run(s)" +input_files=$(ls *.* | grep -v slurm) + +# Create one subdirectory per replica and copy over the inputs +for i in $(seq 1 ${num_rep}); do + if [ ! 
-d ${i} ]; then + mkdir ${i} + cp $input_files ${i}/ + fi +done + +######################################################################## +# Part 2 - Machine specific Options (SDSC-Expanse) +######################################################################## + +#----------------------------------------------------------------------- +# Step 2.1 - Load modules (SDSC-Expanse) +#----------------------------------------------------------------------- + +module purge +module load slurm/expanse/current + +if [ $ExeTyp = "CPU" ]; then + echo "Loading CPU modules" + module load cpu/0.17.3b gcc/10.2.0 openmpi/4.1.1 +fi +if [ $ExeTyp = "GPU" ]; then + echo "Loading GPU modules" + module load gpu/0.17.3b +fi + +module list + +#----------------------------------------------------------------------- +# Step 2.2 - Set NAMD binary and command line for SDSC-Expanse +#----------------------------------------------------------------------- +APP_PATH=/home/scigap/applications +if [ $ExeTyp == "CPU" ]; then + export NAMDPATH="$APP_PATH/NAMD_3.1alpha2_Linux-x86_64-multicore" +fi +if [ $ExeTyp == "GPU" ]; then + export NAMDPATH="$APP_PATH/NAMD_3.0.1_Linux-x86_64-multicore-CUDA" +fi + +#----------------------------------------------------------------------- +# Step 2.3 A - Run NAMD3 (CPU, Serial) +#----------------------------------------------------------------------- +# - one replica at a given time +# - each replica uses all CPUs +#----------------------------------------------------------------------- +if [ ${ExeTyp} == "CPU" ]; then + for replica in $(seq 1 ${num_rep}); do + cd ${subdir}/${replica}/ # Go to folder + + # Run NAMD3 + ${NAMDPATH}/namd3 \ + +setcpuaffinity \ + +p ${SLURM_CPUS_ON_NODE} \ + $input >${filename}.out 2>${filename}.err + done +fi + +#----------------------------------------------------------------------- +# Step 2.3 B - Run NAMD3 (GPU, Batched) +#----------------------------------------------------------------------- +# - one replica PER GPU at a given time +# - each replica uses all CPUs +#----------------------------------------------------------------------- +if [ ${ExeTyp} == "GPU" ]; then + GPU_ID=0 + subtask_pids=() + for replica in $(seq 1 ${num_rep}); do + cd ${subdir}/${replica}/ # Go to folder + + # Run NAMD3 in background + ${NAMDPATH}/namd3 \ + +setcpuaffinity \ + +p ${SLURM_CPUS_ON_NODE} \ + +devices ${GPU_ID} \ + $input >${filename}.out 2>${filename}.err & + + subtask_pids+=($!) # Store PID of the background NAMD task + let GPU_ID+=1 # Increment GPU_ID + + # Wait for a batch of replicas to complete + if [ ${GPU_ID} == ${SLURM_GPUS_ON_NODE} ]; then + wait "${subtask_pids[@]}" # wait for current batch to complete + subtask_pids=() # clear subtask_pids of current batch + GPU_ID=0 # reset gpu counter + fi + done + + # Wait for the last batch of replicas to complete + if [ ${#subtask_pids[@]} -gt 0 ]; then + wait "${subtask_pids[@]}" # wait for last batch to complete + subtask_pids=() # clear subtask_pids of last batch + fi +fi + +# Once done, go back to main folder +cd ${subdir} + +######################################################################## +# Part 3 - Output Flattening +######################################################################## +for replica in $(seq 1 ${num_rep}); do + for file in $(ls ${replica}/*.*); do + mv ${file} ${replica}"_"$(basename $file) + done + rm -rf ${replica}/ +done + +# Send SIGTERM to agent, and wait for completion +kill -TERM $agent_pid +wait $agent_pid + +# Give it a break when jobs are done +sleep 10 + +# bye! 
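+
+# Example invocation (hypothetical values, for illustration only), normally
+# wrapped in a SLURM batch job since the script reads SLURM_* variables:
+#   sbatch ... namd-agent.sh -t GPU -n 4 -i pull_gpu.conf -a <AGENT_ID> -s <SERVER_URL>
+# The flags map to the getopts block above (-t type, -n replicas, -i config,
+# -a agent id, -s agent server URL).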
diff --git a/modules/agent-framework/airavata-agent/.gitignore b/modules/agent-framework/airavata-agent/.gitignore index 3569009354..771f823695 100644 --- a/modules/agent-framework/airavata-agent/.gitignore +++ b/modules/agent-framework/airavata-agent/.gitignore @@ -1,4 +1,5 @@ airavata-agent +airavata-agent-linux go.sum jupyter/extension/airavata_jupyter_magic/dist jupyter/extension/airavata_jupyter_magic/lib diff --git a/modules/agent-framework/airavata-agent/Dockerfile b/modules/agent-framework/airavata-agent/Dockerfile index f90643cbc7..745a271579 100644 --- a/modules/agent-framework/airavata-agent/Dockerfile +++ b/modules/agent-framework/airavata-agent/Dockerfile @@ -1,4 +1,4 @@ -FROM python:slim +FROM python:3.12-slim RUN pip install flask jupyter jupyter-client RUN mkdir -p /opt/jupyter diff --git a/modules/agent-framework/airavata-agent/agent.go b/modules/agent-framework/airavata-agent/agent.go index 99d28c8088..7ae28fcc37 100644 --- a/modules/agent-framework/airavata-agent/agent.go +++ b/modules/agent-framework/airavata-agent/agent.go @@ -2,21 +2,21 @@ package main import ( protos "airavata-agent/protos" - "bufio" "bytes" "context" "encoding/json" "fmt" "io" - "io/ioutil" "log" "net" "net/http" "os" "os/exec" + "strings" "golang.org/x/crypto/ssh" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" ) func main() { @@ -27,7 +27,7 @@ func main() { grpcStreamChannel := make(chan struct{}) kernelChannel := make(chan struct{}) - conn, err := grpc.Dial(serverUrl, grpc.WithInsecure(), grpc.WithBlock()) + conn, err := grpc.NewClient(serverUrl, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { log.Fatalf("did not connect: %v", err) } @@ -49,57 +49,57 @@ func main() { log.Printf("Connected to the server...") } - go func() { - log.Printf("Starting jupyter kernel") - cmd := exec.Command("python", "/opt/jupyter/kernel.py") - //cmd := exec.Command("jupyter/venv/bin/python", "jupyter/kernel.py") - stdout, err := cmd.StdoutPipe() - - if err != nil { - fmt.Println("[agent.go] Error creating StdoutPipe:", err) - return - } - - // Get stderr pipe - stderr, err := cmd.StderrPipe() - if err != nil { - fmt.Println("[agent.go] Error creating StderrPipe:", err) - return - } - - log.Printf("[agent.go] Starting command for execution") - // Start the command - if err := cmd.Start(); err != nil { - fmt.Println("[agent.go] Error starting command:", err) - return - } - - // Create channels to read from stdout and stderr - stdoutScanner := bufio.NewScanner(stdout) - stderrScanner := bufio.NewScanner(stderr) - - // Stream stdout - go func() { - for stdoutScanner.Scan() { - fmt.Printf("[agent.go] stdout: %s\n", stdoutScanner.Text()) - } - }() - - // Stream stderr - go func() { - for stderrScanner.Scan() { - fmt.Printf("[agent.go] stderr: %s\n", stderrScanner.Text()) - } - }() - - // Wait for the command to finish - if err := cmd.Wait(); err != nil { - fmt.Println("[agent.go] Error waiting for command:", err) - return - } - - fmt.Println("[agent.go] Command finished") - }() + // go func() { + // log.Printf("Starting jupyter kernel") + // cmd := exec.Command("python", "/opt/jupyter/kernel.py") + // //cmd := exec.Command("jupyter/venv/bin/python", "jupyter/kernel.py") + // stdout, err := cmd.StdoutPipe() + + // if err != nil { + // fmt.Println("[agent.go] Error creating StdoutPipe:", err) + // return + // } + + // // Get stderr pipe + // stderr, err := cmd.StderrPipe() + // if err != nil { + // fmt.Println("[agent.go] Error creating StderrPipe:", err) + // return + // } + 
+ // log.Printf("[agent.go] Starting command for execution") + // // Start the command + // if err := cmd.Start(); err != nil { + // fmt.Println("[agent.go] Error starting command:", err) + // return + // } + + // // Create channels to read from stdout and stderr + // stdoutScanner := bufio.NewScanner(stdout) + // stderrScanner := bufio.NewScanner(stderr) + + // // Stream stdout + // go func() { + // for stdoutScanner.Scan() { + // fmt.Printf("[agent.go] stdout: %s\n", stdoutScanner.Text()) + // } + // }() + + // // Stream stderr + // go func() { + // for stderrScanner.Scan() { + // fmt.Printf("[agent.go] stderr: %s\n", stderrScanner.Text()) + // } + // }() + + // // Wait for the command to finish + // if err := cmd.Wait(); err != nil { + // fmt.Println("[agent.go] Error waiting for command:", err) + // return + // } + + // fmt.Println("[agent.go] Command finished") + // }() go func() { for { @@ -127,23 +127,61 @@ func main() { log.Printf("[agent.go] Working Dir %s", workingDir) log.Printf("[agent.go] Libraries %s", libraries) - // TODO: cd into working dir, create the virtual environment with provided libraries - cmd := exec.Command("python3", "-c", code) //TODO: Load python runtime from a config - - output, err := cmd.Output() - if err != nil { - fmt.Println("[agent.go] Failed to run python command:", err) - return - } + go func() { + + // setup the venv + venvCmd := fmt.Sprintf(` + agentId="%s" + pkgs="%s" + + if [ ! -f "/tmp/$agentId/venv" ]; then + mkdir -p /tmp/$agentId + python3 -m venv /tmp/$agentId/venv + fi + + source /tmp/$agentId/venv/bin/activate + python3 -m pip install $pkgs + + `, agentId, strings.Join(libraries, " ")) + log.Println("[agent.go] venv setup:", venvCmd) + venvExc := exec.Command("bash", "-c", venvCmd) + venvOut, venvErr := venvExc.CombinedOutput() + if venvErr != nil { + fmt.Println("[agent.go] venv setup: ERR", venvErr) + return + } + venvStdout := string(venvOut) + fmt.Println("[agent.go] venv setup:", venvStdout) + + // execute the python code + pyCmd := fmt.Sprintf(` + workingDir="%s"; + agentId="%s"; + + cd $workingDir; + source /tmp/$agentId/venv/bin/activate; + python3 <
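For context on the transport change in agent.go above: grpc.Dial with grpc.WithInsecure() and grpc.WithBlock() is replaced by grpc.NewClient with explicit insecure transport credentials. The minimal sketch below shows that pattern in isolation; the connect helper is illustrative, not part of the patch. Unlike Dial with WithBlock, NewClient returns immediately and only establishes the connection on the first RPC.

// Minimal sketch of the grpc.NewClient pattern used in this patch.
package main

import (
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func connect(serverUrl string) *grpc.ClientConn {
	conn, err := grpc.NewClient(serverUrl, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		// NewClient only fails on invalid targets or options, not on network errors.
		log.Fatalf("did not connect: %v", err)
	}
	return conn
}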