Skip to content

Commit

Permalink
Include new task to train models with task cluster (#3748)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpangas authored Nov 8, 2023
1 parent 49e304d commit 1abf12e
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
31 changes: 31 additions & 0 deletions .taskcluster.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ tasks:
then: ${event.pull_request.head.repo.html_url}
else: ${event.repository.html_url}

pr_description:
$if: 'tasks_for == "github-pull-request"'
then: ${event.pull_request.body}

taskboot_image: "mozilla/taskboot:0.4.2"
in:
$if: 'tasks_for == "github-push" || (tasks_for == "github-pull-request" && event["action"] in ["opened", "reopened", "synchronize"])'
Expand Down Expand Up @@ -206,6 +210,33 @@ tasks:
owner: [email protected]
source: ${repository}/raw/${head_rev}/.taskcluster.yml

# Optional training task: only emitted for pull-request events whose
# description contains the "Train on Taskcluster: " keyword.
- $if: 'tasks_for == "github-pull-request" && "Train on Taskcluster: " in pr_description'
then:
# Needs the image produced by the docker_build task (used under payload.image).
dependencies:
- { $eval: as_slugid("docker_build") }
taskId: { $eval: as_slugid("train_on_taskcluster") }
created: { $fromNow: "" }
deadline: { $fromNow: "1 day" }
provisionerId: proj-bugbug
workerType: compute-large
payload:
# NOTE(review): presumably seconds (10800 s = 3 h) — confirm against the worker payload schema.
maxRunTime: 10800
# Run inside the bugbug-base image artifact published by the docker_build task.
image:
type: task-image
path: public/bugbug/bugbug-base.tar.zst
taskId: { $eval: as_slugid("docker_build") }
env:
# Hands the PR description to scripts.trainer_extract_args, which reads it
# from the PR_DESCRIPTION environment variable.
PR_DESCRIPTION: "${pr_description}"
command:
- "/bin/sh"
- "-lcx"
# The stdout of scripts.trainer_extract_args (the model name) is substituted
# in as the argument list of scripts.trainer.
- "python -m scripts.trainer $(python -m scripts.trainer_extract_args)"
metadata:
name: bugbug train on TC
description: Train a BugBug model on Taskcluster
owner: [email protected]
source: ${repository}/raw/${head_rev}/.taskcluster.yml

- taskId: { $eval: as_slugid("frontend_build") }
created: { $fromNow: "" }
deadline: { $fromNow: "1 hour" }
Expand Down
39 changes: 39 additions & 0 deletions scripts/trainer_extract_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.

import logging
import os
import re

# Configure root logging at INFO level. basicConfig's default handler writes
# to stderr, so stdout carries only what main() prints.
logging.basicConfig(level=logging.INFO)
# Module-level logger used for the error reports in get_model_name().
logger = logging.getLogger(__name__)


def get_model_name() -> str | None:
pr_description = os.environ.get("PR_DESCRIPTION")
if not pr_description:
logger.error("The PR_DESCRIPTION environment variable does not exist")
return None

match = re.search(r"Train on Taskcluster:\s+([a-z_1-9]+)", pr_description)
if not match:
logger.error(
"Could not identify the model name using the 'Train on Taskcluster' keyword from the Pull Request description"
)
return None

model_name = match.group(1)

return model_name


def main():
    """Print the model name parsed from the PR description, if there is one.

    Nothing is written to stdout when no model name could be determined.
    """
    requested_model = get_model_name()
    if not requested_model:
        return
    print(requested_model)


# Run only when executed as a script (e.g. `python -m scripts.trainer_extract_args`),
# not when the module is imported.
if __name__ == "__main__":
    main()

0 comments on commit 1abf12e

Please sign in to comment.