Skip to content

Commit

Permalink
feat: Add failure recovery by retrying the whole job (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
sitaowang1998 authored Dec 23, 2024
1 parent 03d043a commit 636b8d3
Show file tree
Hide file tree
Showing 13 changed files with 278 additions and 29 deletions.
5 changes: 5 additions & 0 deletions src/spider/core/Task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ class Task {
m_state(state),
m_timeout(timeout) {}

// Stores `num_retries` as this task's maximum retry count.
// NOTE(review): the backing member is named `m_max_tries` while the accessors say "retries".
void set_max_retries(unsigned int num_retries) { m_max_tries = num_retries; }

void add_input(TaskInput const& input) { m_inputs.emplace_back(input); }

void add_output(TaskOutput const& output) { m_outputs.emplace_back(output); }
Expand All @@ -125,6 +127,8 @@ class Task {

[[nodiscard]] auto get_timeout() const -> float { return m_timeout; }

// Returns the maximum retry count stored for this task (the `m_max_tries` member).
[[nodiscard]] auto get_max_retries() const -> unsigned int { return m_max_tries; }

[[nodiscard]] auto get_num_inputs() const -> size_t { return m_inputs.size(); }

[[nodiscard]] auto get_num_outputs() const -> size_t { return m_outputs.size(); }
Expand All @@ -142,6 +146,7 @@ class Task {
std::string m_function_name;
TaskState m_state = TaskState::Pending;
float m_timeout = 0;
unsigned int m_max_tries = 0;
std::vector<TaskInput> m_inputs;
std::vector<TaskOutput> m_outputs;
};
Expand Down
18 changes: 17 additions & 1 deletion src/spider/scheduler/SchedulerMessage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,31 @@ class ScheduleTaskRequest {
: m_worker_id{worker_id},
m_worker_addr{std::move(addr)} {}

/**
 * Constructs a request that carries a task id in addition to the worker information.
 * @param worker_id Stored as the request's worker id.
 * @param addr Stored as the request's worker address; moved into the request.
 * @param task_id Stored in the optional task id, so `has_task_id()` returns true for this
 * request.
 */
ScheduleTaskRequest(
boost::uuids::uuid const worker_id,
std::string addr,
boost::uuids::uuid const task_id
)
: m_worker_id{worker_id},
m_worker_addr{std::move(addr)},
m_task_id{task_id} {}

// Returns whether this request carries a task id; check this before calling get_task_id().
[[nodiscard]] auto has_task_id() const -> bool { return m_task_id.has_value(); }

// Returns the task id carried by this request. The optional is accessed with `value()`, which
// throws `std::bad_optional_access` when no task id is set, so callers should check
// has_task_id() first.
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
[[nodiscard]] auto get_task_id() const -> boost::uuids::uuid { return m_task_id.value(); }

[[nodiscard]] auto get_worker_id() const -> boost::uuids::uuid { return m_worker_id; }

[[nodiscard]] auto get_worker_addr() const -> std::string const& { return m_worker_addr; }

MSGPACK_DEFINE_ARRAY(m_worker_id, m_worker_addr);
MSGPACK_DEFINE_ARRAY(m_worker_id, m_worker_addr, m_task_id);

private:
boost::uuids::uuid m_worker_id;
std::string m_worker_addr;
// Id of the task that failed; left empty when the request does not report a failed task
std::optional<boost::uuids::uuid> m_task_id = std::nullopt;
};

class ScheduleTaskResponse {
Expand Down
20 changes: 20 additions & 0 deletions src/spider/scheduler/SchedulerServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <boost/uuid/uuid_io.hpp>
#include <spdlog/spdlog.h>

#include "../core/Error.hpp"
#include "../io/BoostAsio.hpp" // IWYU pragma: keep
#include "../io/MsgPack.hpp" // IWYU pragma: keep
#include "../io/msgpack_message.hpp"
Expand Down Expand Up @@ -131,6 +132,25 @@ auto SchedulerServer::process_message(boost::asio::ip::tcp::socket socket
}
ScheduleTaskRequest const& request = optional_request.value();

// Reset the whole job if the task fails
if (request.has_task_id()) {
boost::uuids::uuid job_id;
core::StorageErr err = m_metadata_store->get_task_job_id(request.get_task_id(), &job_id);
// It is possible the job is deleted, so we don't need to reset it
if (!err.success()) {
spdlog::error(
"Cannot get job id for task {}",
boost::uuids::to_string(request.get_task_id())
);
} else {
err = m_metadata_store->reset_job(job_id);
if (!err.success()) {
spdlog::error("Cannot reset job {}", boost::uuids::to_string(job_id));
co_return;
}
}
}

std::optional<boost::uuids::uuid> const task_id = m_policy->schedule_next(
m_metadata_store,
m_data_store,
Expand Down
1 change: 1 addition & 0 deletions src/spider/storage/MetadataStorage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class MetadataStorage {
std::vector<boost::uuids::uuid>* job_ids
) -> StorageErr = 0;
virtual auto remove_job(boost::uuids::uuid id) -> StorageErr = 0;
virtual auto reset_job(boost::uuids::uuid id) -> StorageErr = 0;
virtual auto add_child(boost::uuids::uuid parent_id, Task const& child) -> StorageErr = 0;
virtual auto get_task(boost::uuids::uuid id, Task* task) -> StorageErr = 0;
virtual auto get_task_job_id(boost::uuids::uuid id, boost::uuids::uuid* job_id) -> StorageErr
Expand Down
65 changes: 60 additions & 5 deletions src/spider/storage/MysqlStorage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ char const* const cCreateTaskTable = R"(CREATE TABLE IF NOT EXISTS tasks (
`state` ENUM('pending', 'ready', 'running', 'success', 'cancel', 'fail') NOT NULL,
`timeout` FLOAT,
`max_retry` INT UNSIGNED DEFAULT 0,
`retry` INT UNSIGNED DEFAULT 0,
`instance_id` BINARY(16),
CONSTRAINT `task_job_id` FOREIGN KEY (`job_id`) REFERENCES `jobs` (`id`) ON UPDATE NO ACTION ON DELETE CASCADE,
PRIMARY KEY (`id`)
Expand Down Expand Up @@ -395,7 +396,7 @@ void MySqlMetadataStorage::add_task(sql::bytes job_id, Task const& task) {
// Add task
std::unique_ptr<sql::PreparedStatement> task_statement(
m_conn->prepareStatement("INSERT INTO `tasks` (`id`, `job_id`, `func_name`, `state`, "
"`timeout`) VALUES (?, ?, ?, ?, ?)")
"`timeout`, `max_retry`) VALUES (?, ?, ?, ?, ?, ?)")
);
sql::bytes task_id_bytes = uuid_get_bytes(task.get_id());
// NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers, readability-magic-numbers)
Expand All @@ -404,6 +405,7 @@ void MySqlMetadataStorage::add_task(sql::bytes job_id, Task const& task) {
task_statement->setString(3, task.get_function_name());
task_statement->setString(4, task_state_to_string(task.get_state()));
task_statement->setFloat(5, task.get_timeout());
task_statement->setUInt(6, task.get_max_retries());
// NOLINTEND(cppcoreguidelines-avoid-magic-numbers, readability-magic-numbers)
task_statement->executeUpdate();

Expand Down Expand Up @@ -855,6 +857,57 @@ auto MySqlMetadataStorage::remove_job(boost::uuids::uuid id) -> StorageErr {
return StorageErr{};
}

/**
 * Resets all tasks of a job so that the whole job can run again.
 *
 * If at least one task of the job has `retry >= max_retry`, nothing is modified: the
 * transaction is committed and a `StorageErrType::Success` error carrying the text
 * "Some tasks have reached max retry count" is returned. Otherwise, in order:
 *   1. `retry` is incremented for every task of the job;
 *   2. task states are reset: tasks with no input that references an output task become
 *      'ready', all other tasks become 'pending';
 *   3. `value` and `data_id` are cleared on every output of the job's tasks;
 *   4. `value` and `data_id` are cleared on every input that references an output task.
 *
 * @param id Id of the job to reset.
 * @return A default-constructed StorageErr after the changes are committed, the
 * `Success`-typed error described above when the retry limit is reached, or
 * `StorageErrType::OtherErr` with the exception message if a `sql::SQLException` is thrown
 * inside the try block (the transaction is rolled back in that case).
 */
auto MySqlMetadataStorage::reset_job(boost::uuids::uuid const id) -> StorageErr {
    try {
        // A job is only reset while none of its tasks has used up its retries
        std::unique_ptr<sql::PreparedStatement> exhausted_query{m_conn->prepareStatement(
                "SELECT `id` FROM `tasks` WHERE `job_id` = ? AND `retry` >= `max_retry`"
        )};
        sql::bytes job_id_bytes = uuid_get_bytes(id);
        exhausted_query->setBytes(1, &job_id_bytes);
        std::unique_ptr<sql::ResultSet> const exhausted_tasks{exhausted_query->executeQuery()};
        bool const retry_limit_reached = exhausted_tasks->rowsCount() > 0;
        if (retry_limit_reached) {
            m_conn->commit();
            return StorageErr{StorageErrType::Success, "Some tasks have reached max retry count"};
        }

        // Prepares `query`, binds the job id to its first `num_params` placeholders and runs it
        auto const run_for_job
                = [this, &job_id_bytes](char const* const query, int const num_params) {
                      std::unique_ptr<sql::PreparedStatement> update{
                              m_conn->prepareStatement(query)
                      };
                      for (int index = 1; index <= num_params; ++index) {
                          update->setBytes(index, &job_id_bytes);
                      }
                      update->executeUpdate();
                  };

        // Count this attempt on every task of the job
        run_for_job("UPDATE `tasks` SET `retry` = `retry` + 1 WHERE `job_id` = ?", 1);

        // Head tasks (no input fed by another task's output) go back to 'ready', the rest to
        // 'pending'.
        // NOTE(review): this statement reads `tasks` in a subquery while updating `tasks`;
        // stock MySQL may reject that (error 1093) - confirm against the target server.
        run_for_job(
                "UPDATE `tasks` SET `state` = IF(`id` NOT IN (SELECT `task_id` FROM `task_inputs` "
                "WHERE `task_id` IN (SELECT `id` FROM `tasks` WHERE `job_id` = ?) AND "
                "`output_task_id` IS NOT NULL), 'ready', 'pending') WHERE job_id = ?",
                2
        );

        // Drop everything the previous attempt produced
        run_for_job(
                "UPDATE `task_outputs` SET `value` = NULL, `data_id` = NULL "
                "WHERE `task_id` IN (SELECT `id` FROM `tasks` WHERE `job_id` = ?)",
                1
        );

        // Inputs fed by another task's output are cleared; other inputs are left untouched
        run_for_job(
                "UPDATE `task_inputs` SET `value` = NULL, `data_id` = NULL "
                "WHERE `task_id` IN (SELECT `id` FROM `tasks` WHERE `job_id` = ?) "
                "AND `output_task_id` IS NOT NULL",
                1
        );
    } catch (sql::SQLException& e) {
        m_conn->rollback();
        return StorageErr{StorageErrType::OtherErr, e.what()};
    }
    m_conn->commit();
    return StorageErr{};
}

auto MySqlMetadataStorage::add_child(boost::uuids::uuid parent_id, Task const& child)
-> StorageErr {
try {
Expand Down Expand Up @@ -935,11 +988,13 @@ auto MySqlMetadataStorage::get_task_job_id(boost::uuids::uuid id, boost::uuids::

auto MySqlMetadataStorage::get_ready_tasks(std::vector<Task>* tasks) -> StorageErr {
try {
// Get all ready tasks from job that has not failed or cancelled
std::unique_ptr<sql::Statement> statement(m_conn->createStatement());
std::unique_ptr<sql::ResultSet> const res(
statement->executeQuery("SELECT `id`, `func_name`, `state`, `timeout` "
"FROM `tasks` WHERE `state` = 'ready'")
);
std::unique_ptr<sql::ResultSet> res(statement->executeQuery(
"SELECT `id`, `func_name`, `state`, `timeout` FROM `tasks` WHERE `state` = 'ready' "
"AND `job_id` NOT IN (SELECT `job_id` FROM `tasks` WHERE `state` = 'fail' OR "
"`state` = 'cancel')"
));
while (res->next()) {
tasks->emplace_back(fetch_full_task(res));
}
Expand Down
1 change: 1 addition & 0 deletions src/spider/storage/MysqlStorage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class MySqlMetadataStorage : public MetadataStorage {
std::vector<boost::uuids::uuid>* job_ids
) -> StorageErr override;
auto remove_job(boost::uuids::uuid id) -> StorageErr override;
auto reset_job(boost::uuids::uuid id) -> StorageErr override;
auto add_child(boost::uuids::uuid parent_id, Task const& child) -> StorageErr override;
auto get_task(boost::uuids::uuid id, Task* task) -> StorageErr override;
auto get_task_job_id(boost::uuids::uuid id, boost::uuids::uuid* job_id) -> StorageErr override;
Expand Down
20 changes: 9 additions & 11 deletions src/spider/worker/WorkerClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include <boost/uuid/uuid.hpp>

#include "../core/Driver.hpp"
#include "../core/Task.hpp"
#include "../io/BoostAsio.hpp" // IWYU pragma: keep
#include "../io/MsgPack.hpp" // IWYU pragma: keep
#include "../io/msgpack_message.hpp"
Expand All @@ -34,16 +33,8 @@ WorkerClient::WorkerClient(
m_data_store(std::move(data_store)),
m_metadata_store(std::move(metadata_store)) {}

auto WorkerClient::task_finish(
core::TaskInstance const& instance,
std::vector<core::TaskOutput> const& outputs
auto WorkerClient::get_next_task(std::optional<boost::uuids::uuid> const& fail_task_id
) -> std::optional<boost::uuids::uuid> {
m_metadata_store->task_finish(instance, outputs);

return get_next_task();
}

auto WorkerClient::get_next_task() -> std::optional<boost::uuids::uuid> {
// Get schedulers
std::vector<core::Scheduler> schedulers;
if (!m_metadata_store->get_active_scheduler(&schedulers).success()) {
Expand Down Expand Up @@ -74,7 +65,14 @@ auto WorkerClient::get_next_task() -> std::optional<boost::uuids::uuid> {
boost::asio::ip::tcp::socket socket(context);
boost::asio::connect(socket, endpoints);

scheduler::ScheduleTaskRequest const request{m_worker_id, m_worker_addr};
scheduler::ScheduleTaskRequest request{m_worker_id, m_worker_addr};
if (fail_task_id.has_value()) {
request = scheduler::ScheduleTaskRequest{
m_worker_id,
m_worker_addr,
fail_task_id.value()
};
}
msgpack::sbuffer request_buffer;
msgpack::pack(request_buffer, request);

Expand Down
8 changes: 1 addition & 7 deletions src/spider/worker/WorkerClient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include <boost/uuid/uuid.hpp>

#include "../core/Task.hpp"
#include "../io/BoostAsio.hpp" // IWYU pragma: keep
#include "../storage/DataStorage.hpp"
#include "../storage/MetadataStorage.hpp"
Expand All @@ -30,13 +28,9 @@ class WorkerClient {
std::shared_ptr<core::MetadataStorage> metadata_store
);

auto task_finish(
core::TaskInstance const& instance,
std::vector<core::TaskOutput> const& outputs
auto get_next_task(std::optional<boost::uuids::uuid> const& fail_task_id
) -> std::optional<boost::uuids::uuid>;

auto get_next_task() -> std::optional<boost::uuids::uuid>;

private:
boost::uuids::uuid m_worker_id;
std::string m_worker_addr;
Expand Down
18 changes: 15 additions & 3 deletions src/spider/worker/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,19 @@ auto heartbeat_loop(

constexpr int cFetchTaskTimeout = 100;

auto fetch_task(spider::worker::WorkerClient& client) -> boost::uuids::uuid {
auto fetch_task(
spider::worker::WorkerClient& client,
std::optional<boost::uuids::uuid> fail_task_id
) -> boost::uuids::uuid {
spdlog::debug("Fetching task");
while (true) {
std::optional<boost::uuids::uuid> const optional_task_id = client.get_next_task();
std::optional<boost::uuids::uuid> const optional_task_id
= client.get_next_task(fail_task_id);
if (optional_task_id.has_value()) {
return optional_task_id.value();
}
// If the first request succeeds, later requests should not include the failed task id
fail_task_id = std::nullopt;
std::this_thread::sleep_for(std::chrono::milliseconds(cFetchTaskTimeout));
}
}
Expand Down Expand Up @@ -202,10 +208,12 @@ auto task_loop(
boost::process::v2::environment::value> const& environment,
spider::core::StopToken const& stop_token
) -> void {
std::optional<boost::uuids::uuid> fail_task_id = std::nullopt;
while (!stop_token.stop_requested()) {
boost::asio::io_context context;
boost::uuids::uuid const task_id = fetch_task(client);
boost::uuids::uuid const task_id = fetch_task(client, fail_task_id);
spdlog::debug("Fetched task {}", boost::uuids::to_string(task_id));
fail_task_id = std::nullopt;
// Fetch task detail from metadata storage
spider::core::Task task{""};
spider::core::StorageErr err = metadata_store->get_task(task_id, &task);
Expand Down Expand Up @@ -255,6 +263,7 @@ auto task_loop(
instance,
fmt::format("Task {} failed", task.get_function_name())
);
fail_task_id = task_id;
continue;
}

Expand All @@ -270,6 +279,7 @@ auto task_loop(
task.get_function_name()
)
);
fail_task_id = task_id;
continue;
}
std::vector<msgpack::sbuffer> const& result_buffers = optional_result_buffers.value();
Expand All @@ -283,12 +293,14 @@ auto task_loop(
task.get_function_name()
)
);
fail_task_id = task_id;
continue;
}
std::vector<spider::core::TaskOutput> const& outputs = optional_outputs.value();
// Submit result
spdlog::debug("Submitting result for task {}", boost::uuids::to_string(task_id));
err = metadata_store->task_finish(instance, outputs);
fail_task_id = std::nullopt;
if (!err.success()) {
spdlog::error("Submit task {} fails: {}", task.get_function_name(), err.description);
}
Expand Down
12 changes: 10 additions & 2 deletions tests/integration/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Task:
inputs: List[TaskInput]
outputs: List[TaskOutput]
timeout: float = 0.0
max_retries: int = 0


@dataclass
Expand Down Expand Up @@ -95,8 +96,15 @@ def submit_job(conn, client_id: uuid.UUID, graph: TaskGraph):
else:
state = "pending"
cursor.execute(
"INSERT INTO tasks (id, job_id, func_name, state, timeout) VALUES (%s, %s, %s, %s, %s)",
(task.id.bytes, graph.id.bytes, task.function_name, state, task.timeout),
"INSERT INTO tasks (id, job_id, func_name, state, timeout, max_retry) VALUES (%s, %s, %s, %s, %s, %s)",
(
task.id.bytes,
graph.id.bytes,
task.function_name,
state,
task.timeout,
task.max_retries,
),
)

for i, task_input in enumerate(task.inputs):
Expand Down
Loading

0 comments on commit 636b8d3

Please sign in to comment.