Add dynamodb retry config for throttling and other errors. Add exponential backoff and jitter for unprocessed keys. Fix edge case where we successfully process keys on our last attempt but still fail
KaspariK committed Jan 15, 2025
1 parent 04418af commit b361235
Showing 1 changed file with 55 additions and 30 deletions.
85 changes: 55 additions & 30 deletions tron/serialize/runstate/dynamodb_state_store.py
@@ -4,6 +4,7 @@
 import math
 import os
 import pickle
+import random
 import threading
 import time
 from collections import defaultdict
@@ -20,6 +21,7 @@
 from typing import TypeVar

 import boto3  # type: ignore
+from botocore.config import Config

 import tron.prom_metrics as prom_metrics
 from tron.core.job import Job
@@ -35,16 +37,33 @@
 # to contain other attributes like object name and number of partitions.
 OBJECT_SIZE = 200_000  # TODO: TRON-2240 - consider swapping back to 400_000 now that we've removed pickles
 MAX_SAVE_QUEUE = 500
-MAX_ATTEMPTS = 10
+# This is distinct from the number of retries in the retry_config as this is used for handling unprocessed
+# keys outside the bounds of something like retrying on a ThrottlingException. We need this limit to avoid
+# infinite loops in the case where a key is truly unprocessable.
+MAX_UNPROCESSED_KEYS_RETRIES = 10
 MAX_TRANSACT_WRITE_ITEMS = 100
 log = logging.getLogger(__name__)
 T = TypeVar("T")


 class DynamoDBStateStore:
     def __init__(self, name, dynamodb_region, stopping=False) -> None:
-        self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region)
-        self.client = boto3.client("dynamodb", region_name=dynamodb_region)
+        # Standard mode includes an exponential backoff by a base factor of 2 for a
+        # maximum backoff time of 20 seconds (min(b*r^i, MAX_BACKOFF) where b is a
+        # random number between 0 and 1 and r is the base factor of 2). This might
+        # look like:
+        #
+        # seconds_to_sleep = min(1 × 2^1, 20) = min(2, 20) = 2 seconds
+        #
+        # By our 5th retry (2^5 is 32) we will be sleeping *up to* 20 seconds, depending
+        # on the random jitter.
+        #
+        # It handles transient errors like RequestTimeout and ConnectionError, as well
+        # as Service-side errors like Throttling, SlowDown, and LimitExceeded.
+        retry_config = Config(retries={"max_attempts": 5, "mode": "standard"})
+
+        self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region, config=retry_config)
+        self.client = boto3.client("dynamodb", region_name=dynamodb_region, config=retry_config)
         self.name = name
         self.dynamodb_region = dynamodb_region
         self.table = self.dynamodb.Table(name)
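For intuition, the schedule described in the comment above can be simulated in a few lines. This is only a sketch of the min(b * r^i, MAX_BACKOFF) formula as the comment states it, not botocore's actual implementation (botocore tracks attempts and error types internally):

import random

MAX_BACKOFF = 20  # seconds, the standard-mode cap mentioned above
BASE_FACTOR = 2

def simulated_sleep(attempt: int) -> float:
    b = random.random()  # random number between 0 and 1
    return min(b * BASE_FACTOR**attempt, MAX_BACKOFF)

# Retries 1..5 are capped at 2, 4, 8, 16, and 20 seconds respectively;
# the sampled sleep lands anywhere below the cap because of the jitter.
for attempt in range(1, 6):
    print(f"retry {attempt}: cap {min(BASE_FACTOR**attempt, MAX_BACKOFF)}s, sampled {simulated_sleep(attempt):.2f}s")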
@@ -63,11 +82,11 @@ def build_key(self, type, iden) -> str:

     def restore(self, keys, read_json: bool = False) -> dict:
         """
-        Fetch all under the same parition key(s).
+        Fetch all under the same partition key(s).
         ret: <dict of key to states>
         """
         # format of the keys always passed here is
-        # job_state job_name --> high level info about the job: enabled, run_nums
+        # job_state job_name --> high level info about the job: enabled, run_nums
         # job_run_state job_run_name --> high level info about the job run
         first_items = self._get_first_partitions(keys)
         remaining_items = self._get_remaining_partitions(first_items, read_json)
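Concretely, a caller would hand restore() keys shaped like the comment above describes. The job names below are invented for illustration, assuming store is a DynamoDBStateStore:

# Hypothetical keys in the "<type> <name>" format described above
keys = [
    "job_state MASTER.example_job",           # enabled flag, run_nums
    "job_run_state MASTER.example_job.1234",  # high level info about one run
]
state_by_key = store.restore(keys)  # <dict of key to states>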
@@ -87,8 +106,11 @@ def _get_items(self, table_keys: list) -> object:
         items = []
         # let's avoid potentially mutating our input :)
         cand_keys_list = copy.copy(table_keys)
-        attempts_to_retrieve_keys = 0
-        while len(cand_keys_list) != 0:
+        attempts = 0
+        base_delay = 0.5
+        max_delay = 10
+
+        while len(cand_keys_list) != 0 and attempts < MAX_UNPROCESSED_KEYS_RETRIES:
             with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                 responses = [
                     executor.submit(
@@ -106,20 +128,33 @@
                 cand_keys_list = []
             for resp in concurrent.futures.as_completed(responses):
                 try:
-                    items.extend(resp.result()["Responses"][self.name])
-                    # add any potential unprocessed keys to the thread pool
-                    if resp.result()["UnprocessedKeys"].get(self.name) and attempts_to_retrieve_keys < MAX_ATTEMPTS:
-                        cand_keys_list.extend(resp.result()["UnprocessedKeys"][self.name]["Keys"])
-                    elif attempts_to_retrieve_keys >= MAX_ATTEMPTS:
-                        failed_keys = resp.result()["UnprocessedKeys"][self.name]["Keys"]
-                        error = Exception(
-                            f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{failed_keys}\n from dynamodb\n{resp.result()}"
-                        )
-                        raise error
+                    result = resp.result()
+                    items.extend(result.get("Responses", {}).get(self.name, []))
+
+                    # If DynamoDB returns unprocessed keys, we need to collect them and retry
+                    unprocessed_keys = result.get("UnprocessedKeys", {}).get(self.name, {}).get("Keys", [])
+                    if unprocessed_keys:
+                        cand_keys_list.extend(unprocessed_keys)
                 except Exception as e:
                     log.exception("Encountered issues retrieving data from DynamoDB")
                     raise e
-            attempts_to_retrieve_keys += 1
+            if cand_keys_list:
+                attempts += 1
+                # Exponential backoff for retrying unprocessed keys
+                exponential_delay = min(base_delay * (2 ** (attempts - 1)), max_delay)
+                # Full jitter (i.e. from 0 to exponential_delay) will help minimize the number and length of calls
+                jitter = random.uniform(0, exponential_delay)
+                delay = jitter
+                log.warning(
+                    f"Attempt {attempts}/{MAX_UNPROCESSED_KEYS_RETRIES} - Retrying {len(cand_keys_list)} unprocessed keys after {delay:.2f}s delay."
+                )
+                time.sleep(delay)
+        if cand_keys_list:
+            error = Exception(
+                f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{cand_keys_list}\n from dynamodb after {MAX_UNPROCESSED_KEYS_RETRIES} retries."
+            )
+            log.error(repr(error))
+            raise error
         return items

     def _get_first_partitions(self, keys: list):
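Stripped of the thread pool, the loop above is capped exponential backoff with full jitter wrapped around DynamoDB's UnprocessedKeys contract. A condensed, single-threaded sketch of the same pattern (table and key names are placeholders; the real code also chunks requests to 100 keys, DynamoDB's BatchGetItem limit, and fans them out across threads):

import random
import time

MAX_UNPROCESSED_KEYS_RETRIES = 10

def batch_get_with_backoff(client, table: str, keys: list, base_delay: float = 0.5, max_delay: float = 10) -> list:
    """Sketch: client is a boto3 DynamoDB client; keys is <= 100 key dicts."""
    items: list = []
    pending = list(keys)
    attempts = 0
    while pending and attempts < MAX_UNPROCESSED_KEYS_RETRIES:
        resp = client.batch_get_item(RequestItems={table: {"Keys": pending}})
        items.extend(resp.get("Responses", {}).get(table, []))
        # Under throttling, DynamoDB may process only a subset and return the rest
        pending = resp.get("UnprocessedKeys", {}).get(table, {}).get("Keys", [])
        if pending:
            attempts += 1
            # Full jitter: sleep anywhere in [0, min(base * 2^(n-1), cap))
            time.sleep(random.uniform(0, min(base_delay * 2 ** (attempts - 1), max_delay)))
    if pending:
        raise Exception(f"failed to retrieve {len(pending)} keys after {MAX_UNPROCESSED_KEYS_RETRIES} retries")
    return items

Note that the final check is on the keys still pending, not on the attempt counter, so a batch that finishes on the last allowed attempt raises nothing; that is the edge case the commit message calls out.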
@@ -336,25 +371,15 @@ def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
                     "N": str(num_json_val_partitions),
                 }

-            count = 0
             items.append(item)

             while len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
                 try:
                     self.client.transact_write_items(TransactItems=items)
                     items = []
                     break  # exit the while loop on successful writing
                 except Exception as e:
-                    count += 1
-                    if count > 3:
-                        timer(
-                            name="tron.dynamodb.setitem",
-                            delta=time.time() - start,
-                        )
-                        log.error(f"Failed to save partition for key: {key}, error: {repr(e)}")
-                        raise e
-                    else:
-                        log.warning(f"Got error while saving {key}, trying again: {repr(e)}")
+                    log.error(f"Failed to save partition for key: {key}, error: {repr(e)}")
+                    raise e
         timer(
             name="tron.dynamodb.setitem",
             delta=time.time() - start,
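On the write path, MAX_TRANSACT_WRITE_ITEMS reflects DynamoDB's limit of 100 actions per TransactWriteItems call, which is why __setitem__ flushes in batches instead of retrying item by item. A minimal sketch of that batching with a made-up region, table, and key schema (the real code also tracks partitions and emits a timer metric):

import boto3

MAX_TRANSACT_WRITE_ITEMS = 100  # DynamoDB's per-transaction action limit

def write_in_batches(client, puts: list) -> None:
    # Flush at most 100 Put actions per transaction
    for start in range(0, len(puts), MAX_TRANSACT_WRITE_ITEMS):
        client.transact_write_items(TransactItems=puts[start : start + MAX_TRANSACT_WRITE_ITEMS])

client = boto3.client("dynamodb", region_name="us-west-1")
puts = [
    {"Put": {"TableName": "example-table", "Item": {"key": {"S": f"job_state job_{i}"}, "index": {"N": "0"}}}}
    for i in range(250)
]
write_in_batches(client, puts)  # three transactions: 100 + 100 + 50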
