Refactorization to introduce the separation between Model metadata and deployment #1903

Merged
merged 112 commits on Nov 18, 2023
Changes from 101 commits
e6b9729
Add the Tokenizer object logic
JosselinSomervilleRoberts Oct 4, 2023
38784f7
Changed most clients to use a Tokenizer object
JosselinSomervilleRoberts Oct 5, 2023
daf431b
Changed remaining clients to use a Tokenizer object
JosselinSomervilleRoberts Oct 6, 2023
cfd0c15
Removed calls to CachableClient.tokenize()
JosselinSomervilleRoberts Oct 6, 2023
4bf3c46
Add TODOs
JosselinSomervilleRoberts Oct 7, 2023
d6341ef
Make client methods abstract
JosselinSomervilleRoberts Oct 7, 2023
d2ac135
Resolve merge conflicts
JosselinSomervilleRoberts Oct 9, 2023
4e83fd2
Fix ICE Tokenizer test
JosselinSomervilleRoberts Oct 9, 2023
0445a78
Fix Critique breaking change
JosselinSomervilleRoberts Oct 9, 2023
e1cbe32
Revert fix
JosselinSomervilleRoberts Oct 9, 2023
40337f7
Fix all window service test issues except for Cohere
JosselinSomervilleRoberts Oct 10, 2023
037a869
Resolve merge conflicts with HuggingFace refactorization
JosselinSomervilleRoberts Oct 11, 2023
1077699
Refactor CachableClient -> CachingClient
JosselinSomervilleRoberts Oct 11, 2023
b0fefef
Refactor yalm_tokenizer_src -> yalm_tokenizer_data
JosselinSomervilleRoberts Oct 11, 2023
0e47750
Merge #1891
JosselinSomervilleRoberts Oct 11, 2023
c1b06e3
First draft of the model deployment/metadata refactorization
JosselinSomervilleRoberts Oct 17, 2023
802b2ec
Fix one of the TODO
JosselinSomervilleRoberts Oct 17, 2023
a907774
Merge branch 'main' into joss-refactor-1-tokenizer
JosselinSomervilleRoberts Oct 17, 2023
c9aa4fd
CachableTokenizer -> CachingTokenizer
JosselinSomervilleRoberts Oct 17, 2023
b1badb7
Port VLM model Idefics to use new Tokenizer logic
JosselinSomervilleRoberts Oct 17, 2023
57fd565
Add TODOs to remove tokenize and decode methods from Client
JosselinSomervilleRoberts Oct 17, 2023
50bb454
Change methods of CachingTokenizer to preserve existing Cache
JosselinSomervilleRoberts Oct 17, 2023
9876162
Change raw_request to request in _tokenization_raw_response_to_tokens…
JosselinSomervilleRoberts Oct 18, 2023
359994c
Merge branch 'main' into joss-refactor-1-tokenizer
JosselinSomervilleRoberts Oct 18, 2023
4d8c080
First draft of the model deployment/metadata refactorization
JosselinSomervilleRoberts Oct 17, 2023
ec97f8d
Fix one of the TODO
JosselinSomervilleRoberts Oct 17, 2023
068b199
Merge branch 'joss-refactor-4-deployments' of https://github.com/stan…
JosselinSomervilleRoberts Oct 18, 2023
427e723
Fixing black and mypy
JosselinSomervilleRoberts Oct 18, 2023
15d3cc4
Merge main
JosselinSomervilleRoberts Oct 25, 2023
06d5707
Support model in conf
JosselinSomervilleRoberts Oct 25, 2023
cebc733
Move model definition to yaml
JosselinSomervilleRoberts Oct 27, 2023
e73c8a9
Done all metadata until Cohere (included)
JosselinSomervilleRoberts Oct 28, 2023
6edb232
Fix helm-summarize
JosselinSomervilleRoberts Oct 28, 2023
ae414ca
AI21 models done
JosselinSomervilleRoberts Oct 28, 2023
01be528
AI21 comment updated
JosselinSomervilleRoberts Oct 28, 2023
24e3446
Aleph Alpha deployments done
JosselinSomervilleRoberts Oct 28, 2023
5a3e504
Change AlephAlpha API key name
JosselinSomervilleRoberts Oct 28, 2023
57fd268
Merge branch 'main' into joss-refactor-4-deployments
JosselinSomervilleRoberts Oct 30, 2023
a412734
Add api_key binding for tokenizers
JosselinSomervilleRoberts Oct 30, 2023
faacc1b
Added AI21 Tokenizer
JosselinSomervilleRoberts Oct 30, 2023
f2fcd3e
Added AlephAlpha Tokenizer
JosselinSomervilleRoberts Oct 31, 2023
3495980
Added most of the model metadatas
JosselinSomervilleRoberts Oct 31, 2023
2fadd8c
Add last model metadatas
JosselinSomervilleRoberts Nov 1, 2023
71d76e6
Correct a few errors in metadatas
JosselinSomervilleRoberts Nov 1, 2023
bc8b068
Added Anthropic, BigScience and BigCode model deployments
JosselinSomervilleRoberts Nov 1, 2023
48ac8eb
Removed tags that were not necessary anymore with the new architecture
JosselinSomervilleRoberts Nov 1, 2023
3ca69f1
Added and tested tokenizers for Anthropic, BigCode and BigScience
JosselinSomervilleRoberts Nov 1, 2023
7f6dc4e
Added Cohere deployments
JosselinSomervilleRoberts Nov 1, 2023
56796cc
Use register_model_metadata
JosselinSomervilleRoberts Nov 1, 2023
a605581
Added Cohere command models and deprecated old Cohere models
JosselinSomervilleRoberts Nov 3, 2023
f27ac03
Merge branch 'main' into joss-refactor-4-deployments
JosselinSomervilleRoberts Nov 3, 2023
3b875cb
Added all tokenizers
JosselinSomervilleRoberts Nov 3, 2023
a3c9479
Added almost all model deployments (except palm, neurips and simple) …
JosselinSomervilleRoberts Nov 3, 2023
1a71401
Added deprecated field and updated many Together models
JosselinSomervilleRoberts Nov 4, 2023
3603549
Cleaning up
JosselinSomervilleRoberts Nov 4, 2023
077c1e5
Updating arguments for dependency injection
JosselinSomervilleRoberts Nov 4, 2023
15f60d9
Clean up handling of old keyword 'model'
JosselinSomervilleRoberts Nov 4, 2023
080702b
Better handle backward compatibility and remove auto model metadata
JosselinSomervilleRoberts Nov 4, 2023
af28659
Set some together model to legacy
JosselinSomervilleRoberts Nov 6, 2023
08fc833
Merge branch 'main' into joss-refactor-4-deployments
JosselinSomervilleRoberts Nov 6, 2023
c47c436
Nearly all tests should now pass
JosselinSomervilleRoberts Nov 7, 2023
578185a
All tests should now pass
JosselinSomervilleRoberts Nov 7, 2023
2d1ebb5
Merge branch 'main' into joss-refactor-4-deployments
JosselinSomervilleRoberts Nov 7, 2023
e5436d3
Trying to make the regression tests pass
JosselinSomervilleRoberts Nov 7, 2023
48e01c9
Lazy instantiate Aleph Alpha Client to pass regression test
JosselinSomervilleRoberts Nov 7, 2023
64b1685
Trying to make summarize work as expected
JosselinSomervilleRoberts Nov 7, 2023
84b6f48
Trying to make summarize compatible with old runs
JosselinSomervilleRoberts Nov 7, 2023
97377bf
helm-summarize is now compatible with old HELM
JosselinSomervilleRoberts Nov 7, 2023
bf4e201
Merge branch 'main' into joss-refactor-4-deployments
JosselinSomervilleRoberts Nov 8, 2023
7d0af69
Changes to frontend
JosselinSomervilleRoberts Nov 8, 2023
00d730f
Making sure all the input type for helm-run are handled properly
JosselinSomervilleRoberts Nov 9, 2023
205b871
Remove # ========= # in configs
JosselinSomervilleRoberts Nov 9, 2023
e177a1c
Rename host group to host organization
JosselinSomervilleRoberts Nov 9, 2023
401fcb4
Deleting get_default_deployment_for_model()
JosselinSomervilleRoberts Nov 9, 2023
ccff08e
Change creator organization to host organization for Cache
JosselinSomervilleRoberts Nov 9, 2023
9875991
A lot of small changes to answer some comments on the PR
JosselinSomervilleRoberts Nov 9, 2023
a892bab
Rename model_metadatas.yaml to singular
JosselinSomervilleRoberts Nov 9, 2023
8f0e964
Set openai/text-embedding-ada-002 as non deprecated
JosselinSomervilleRoberts Nov 9, 2023
fa04702
Merge branch 'main' into joss-refactor-4-deployments
JosselinSomervilleRoberts Nov 9, 2023
67a7351
Remove fancy headers
JosselinSomervilleRoberts Nov 9, 2023
3c61bf6
Add back model to adapter_keys_shown
JosselinSomervilleRoberts Nov 9, 2023
7628ce5
Revert frontend changes
JosselinSomervilleRoberts Nov 9, 2023
8901151
Added request.model_deployment
JosselinSomervilleRoberts Nov 9, 2023
606382a
Remove calls to client.tokenize
JosselinSomervilleRoberts Nov 9, 2023
33d8470
Added Lit GPT
JosselinSomervilleRoberts Nov 9, 2023
40c18cc
Added Neurips local
JosselinSomervilleRoberts Nov 9, 2023
47704c8
Fix YAML typo for lit-gpt
JosselinSomervilleRoberts Nov 9, 2023
d8d97ff
Fix test_run_entry
JosselinSomervilleRoberts Nov 9, 2023
4b461a3
Add some missing comments
JosselinSomervilleRoberts Nov 9, 2023
3e490ad
Merge branch 'main' into joss-refactor-4-deployments
JosselinSomervilleRoberts Nov 13, 2023
853eca6
Add test to ensure that all models are available
JosselinSomervilleRoberts Nov 13, 2023
c6f8df7
Update comment style
JosselinSomervilleRoberts Nov 13, 2023
1fa74ce
Update get_deployment_name_from_model_arg()
JosselinSomervilleRoberts Nov 13, 2023
e4a8db4
Rename maybe_register_helm and move it to its own file
JosselinSomervilleRoberts Nov 14, 2023
8e515b5
Remove deprecation warning for cases like mode=text
JosselinSomervilleRoberts Nov 14, 2023
e0ca539
Split test read run specs in several tests
JosselinSomervilleRoberts Nov 14, 2023
d3f0e4e
Change Exception to Warning when deployments are found
JosselinSomervilleRoberts Nov 14, 2023
3f97b6d
Use importlib so that local paths work on a pypi install
JosselinSomervilleRoberts Nov 14, 2023
7cecdbd
Add default model metadata registration for huggingface models
JosselinSomervilleRoberts Nov 14, 2023
a8d931b
Changing Request so that model and model_deployment are always both f…
JosselinSomervilleRoberts Nov 14, 2023
57431f1
Fix test server service
JosselinSomervilleRoberts Nov 14, 2023
67252bc
Update tutorial
JosselinSomervilleRoberts Nov 15, 2023
abf9d0a
Alternative model deployment proposal (#2002)
yifanmai Nov 15, 2023
6ba0199
Merge main
JosselinSomervilleRoberts Nov 15, 2023
4b849a7
Fix helm-run and a few tests
JosselinSomervilleRoberts Nov 15, 2023
5749a8f
Merge branch 'main' into joss-refactor-4-deployments
JosselinSomervilleRoberts Nov 15, 2023
6f8a13c
Fix broken test
JosselinSomervilleRoberts Nov 17, 2023
e7eb250
Small fixes to the configs
JosselinSomervilleRoberts Nov 17, 2023
f177503
Fix Mistral #1998
JosselinSomervilleRoberts Nov 17, 2023
a3323bd
Fix files that were still not specifying both model and deployment in …
JosselinSomervilleRoberts Nov 17, 2023
c2d3e03
Fix some mypy issues
JosselinSomervilleRoberts Nov 17, 2023
5ce5648
Merge branch 'main' into joss-refactor-4-deployments
JosselinSomervilleRoberts Nov 17, 2023
6 changes: 3 additions & 3 deletions demo.py
@@ -17,7 +17,7 @@
print(account.usages)

# Make a request
request = Request(model="ai21/j1-large", prompt="Life is like a box of", echo_prompt=True)
request = Request(model_deployment="ai21/j2-large", prompt="Life is like a box of", echo_prompt=True)
request_result: RequestResult = service.make_request(auth, request)
print(request_result.completions[0].text)

@@ -28,12 +28,12 @@
print(request_result.completions[0].text)

# How to get the embedding for some text
request = Request(model="openai/text-similarity-ada-001", prompt="Life is like a box of", embedding=True)
request = Request(model_deployment="openai/text-similarity-ada-002", prompt="Life is like a box of", embedding=True)
request_result = service.make_request(auth, request)
print(request_result.embedding)

# Tokenize
request = TokenizationRequest(tokenizer="ai21/j1-jumbo", text="Tokenize me please.")
request = TokenizationRequest(tokenizer="ai21/j2-jumbo", text="Tokenize me please.")
tokenization_request_result: TokenizationRequestResult = service.tokenize(auth, request)
print(f"Number of tokens: {len(tokenization_request_result.tokens)}")

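The demo.py hunk above renames the `model` keyword argument to `model_deployment`. A minimal sketch of the migration, using a hypothetical stand-in dataclass rather than helm's real `Request` (the field list here is abbreviated; the real class carries many more decoding parameters):

```python
from dataclasses import dataclass

# Hypothetical stand-in for helm's Request, reduced to the fields
# relevant to this PR.
@dataclass(frozen=True)
class Request:
    model_deployment: str = ""  # new: deployment that serves the request
    model: str = ""             # deprecated: kept for backward compatibility
    prompt: str = ""
    echo_prompt: bool = False

# Old style (pre-PR): Request(model="ai21/j1-large", ...)
# New style (post-PR): pass the deployment name instead.
request = Request(model_deployment="ai21/j2-large", prompt="Life is like a box of")
```

Callers that still pass `model=` keep working during the transition, since the deprecated field is retained alongside the new one.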
4 changes: 2 additions & 2 deletions scripts/compute_request_limits.py
@@ -57,7 +57,7 @@ def try_request(

try:
request = Request(
model=model_name,
model_deployment=model_name,
prompt=prefix + " ".join(["hello"] * (sequence_length - num_tokens_prefix - num_tokens_suffix)) + suffix,
max_tokens=num_tokens,
)
@@ -287,7 +287,7 @@ def main():
print("client successfully created")

print("Making short request...")
request = Request(model=args.model_name, prompt=args.prefix + "hello" + args.suffix, max_tokens=1)
request = Request(model_deployment=args.model_name, prompt=args.prefix + "hello" + args.suffix, max_tokens=1)
response = client.make_request(request)
if not response.success:
raise ValueError("Request failed")
6 changes: 5 additions & 1 deletion src/helm/benchmark/adaptation/adapter_spec.py
@@ -73,7 +73,11 @@ class AdapterSpec:

# Decoding parameters (inherited by `Request`)

# Model to make the request to (need to fill in)
# Model deployment to make the request to (need to fill in)
model_deployment: str = ""

# DEPRECATED: old model field, kept for backward compatibility
# TODO: Remove this once we do not wish to support backward compatibility anymore.
model: str = ""

# Temperature to use
2 changes: 1 addition & 1 deletion src/helm/benchmark/adaptation/adapters/adapter.py
@@ -21,7 +21,7 @@ class Adapter(ABC):
def __init__(self, adapter_spec: AdapterSpec, tokenizer_service: TokenizerService):
self.adapter_spec: AdapterSpec = adapter_spec
self.window_service: WindowService = WindowServiceFactory.get_window_service(
adapter_spec.model, tokenizer_service
adapter_spec.model_deployment, tokenizer_service
)

@abstractmethod
@@ -49,7 +49,7 @@ def generate_requests(
reference_index=reference_index,
)
request = Request(
model=self.adapter_spec.model,
model_deployment=self.adapter_spec.model_deployment,
prompt=prompt.text,
num_completions=self.adapter_spec.num_outputs,
temperature=self.adapter_spec.temperature,
@@ -38,7 +38,7 @@ def generate_requests(
training_instances, eval_instance, include_output=False, reference_index=None
)
request = Request(
model=self.adapter_spec.model,
model_deployment=self.adapter_spec.model_deployment,
prompt=prompt.text,
num_completions=self.adapter_spec.num_outputs,
temperature=self.adapter_spec.temperature,
@@ -113,7 +113,7 @@ def _generate_requests(self, eval_instance: Instance) -> List[RequestState]:
self.window_service.encode(prefix_token).tokens, tokens[:first_seq_len], max_request_length, text
)
request = Request(
model=self.adapter_spec.model,
model_deployment=self.adapter_spec.model_deployment,
prompt=prompt_text,
num_completions=1,
temperature=0,
@@ -161,7 +161,7 @@ def _generate_requests(self, eval_instance: Instance) -> List[RequestState]:
)

request = Request(
model=self.adapter_spec.model,
model_deployment=self.adapter_spec.model_deployment,
prompt=prompt_text,
num_completions=1,
temperature=0,
@@ -28,7 +28,7 @@ def generate_requests(
)

request = Request(
model=self.adapter_spec.model,
model_deployment=self.adapter_spec.model_deployment,
multimodal_prompt=prompt.multimedia_object,
num_completions=self.adapter_spec.num_outputs,
temperature=self.adapter_spec.temperature,
@@ -26,7 +26,7 @@ def generate_requests(
)

request = Request(
model=self.adapter_spec.model,
model_deployment=self.adapter_spec.model_deployment,
multimodal_prompt=prompt.multimedia_object,
num_completions=self.adapter_spec.num_outputs,
temperature=self.adapter_spec.temperature,
@@ -21,7 +21,7 @@ def teardown_method(self, _):

def test_construct_prompt(self):
adapter_spec: AdapterSpec = AdapterSpec(
model="simple/model1",
model_deployment="simple/model1",
method=ADAPT_GENERATION_MULTIMODAL,
global_prefix="[START]",
instructions="Please answer the following question about the images.",
@@ -90,7 +90,7 @@ def test_construct_prompt(self):

def test_construct_prompt_multi_label(self):
adapter_spec: AdapterSpec = AdapterSpec(
model="simple/model1",
model_deployment="simple/model1",
method=ADAPT_GENERATION_MULTIMODAL,
global_prefix="[START]",
instructions="Please answer the following question about the images.",
@@ -170,7 +170,7 @@ def test_construct_prompt_idefics_instruct_example(self):
Constructing the same prompt from this example: https://huggingface.co/blog/idefics
"""
adapter_spec: AdapterSpec = AdapterSpec(
model="simple/model1",
model_deployment="simple/model1",
method=ADAPT_GENERATION_MULTIMODAL,
input_prefix="User: ",
input_suffix="<end_of_utterance>",
@@ -54,7 +54,7 @@ def generate_requests(
for reference_index, reference in enumerate(eval_instance.references)
)
request = Request(
model=self.adapter_spec.model,
model_deployment=self.adapter_spec.model_deployment,
prompt=prompt.text,
num_completions=1,
top_k_per_token=self.adapter_spec.num_outputs,
@@ -40,7 +40,7 @@ def construct_request_state(
request_mode: str = "original",
) -> RequestState:
request = Request(
model=self.adapter_spec.model,
model_deployment=self.adapter_spec.model_deployment,
prompt=prompt.text,
num_completions=1,
temperature=0,
22 changes: 15 additions & 7 deletions src/helm/benchmark/adaptation/adapters/test_generation_adapter.py
@@ -32,7 +32,7 @@ def test_adapt(self):

def test_construct_prompt(self):
adapter_spec = AdapterSpec(
model="openai/davinci",
model_deployment="openai/davinci",
method=ADAPT_GENERATION,
input_prefix="",
input_suffix="",
@@ -59,7 +59,11 @@

def test_construct_prompt_with_truncation(self):
adapter_spec = AdapterSpec(
model="openai/davinci", method=ADAPT_GENERATION, input_prefix="", output_prefix="", max_tokens=100
model_deployment="openai/davinci",
method=ADAPT_GENERATION,
input_prefix="",
output_prefix="",
max_tokens=100,
)
adapter = AdapterFactory.get_adapter(adapter_spec, self.tokenizer_service)
correct_reference = Reference(Output(text=""), tags=[CORRECT_TAG])
@@ -80,7 +84,7 @@ def test_construct_prompt_with_truncation(self):
assert prompt_text.count("eval") == 1948

def test_sample_examples_without_references(self):
adapter_spec = AdapterSpec(method=ADAPT_GENERATION, model="openai/ada", max_train_instances=1)
adapter_spec = AdapterSpec(method=ADAPT_GENERATION, model_deployment="openai/ada", max_train_instances=1)
adapter = AdapterFactory.get_adapter(adapter_spec, self.tokenizer_service)
all_train_instances = [
Instance(Input(text="prompt1"), references=[]),
@@ -92,7 +96,7 @@ def test_sample_examples_without_references(self):
assert len(examples) == 1

def test_sample_examples_open_ended_generation(self):
adapter_spec = AdapterSpec(method=ADAPT_GENERATION, model="openai/ada", max_train_instances=3)
adapter_spec = AdapterSpec(method=ADAPT_GENERATION, model_deployment="openai/ada", max_train_instances=3)
adapter = AdapterFactory.get_adapter(adapter_spec, self.tokenizer_service)

all_train_instances: List[Instance] = [
@@ -106,7 +110,7 @@ def test_sample_examples_open_ended_generation(self):
assert seed0_examples != seed1_examples, "Examples should differ when changing the seed"

def test_sample_examples_open_ended_generation_stress(self):
adapter_spec = AdapterSpec(method=ADAPT_GENERATION, model="openai/ada", max_train_instances=5)
adapter_spec = AdapterSpec(method=ADAPT_GENERATION, model_deployment="openai/ada", max_train_instances=5)
adapter = AdapterFactory.get_adapter(adapter_spec, self.tokenizer_service)

all_train_instances: List[Instance] = [
@@ -146,7 +150,7 @@ def test_sample_examples_open_ended_generation_stress(self):

def test_multiple_correct_reference(self):
adapter_spec = AdapterSpec(
method=ADAPT_GENERATION, model="openai/ada", max_train_instances=2, sample_train=False
method=ADAPT_GENERATION, model_deployment="openai/ada", max_train_instances=2, sample_train=False
)
adapter = AdapterFactory.get_adapter(adapter_spec, self.tokenizer_service)
train_instances = [
@@ -191,7 +195,7 @@ def test_multiple_correct_reference(self):

def test_multiple_correct_reference_multi_label(self):
adapter_spec = AdapterSpec(
method=ADAPT_GENERATION, model="openai/ada", max_train_instances=2, multi_label=True, sample_train=False
method=ADAPT_GENERATION,
model_deployment="openai/ada",
max_train_instances=2,
multi_label=True,
sample_train=False,
)
adapter = AdapterFactory.get_adapter(adapter_spec, self.tokenizer_service)
train_instances = [
@@ -15,7 +15,7 @@ def test_construct_language_modeling_prompt(self):
adapter_spec = AdapterSpec(
method=ADAPT_LANGUAGE_MODELING,
input_prefix="",
model="openai/davinci",
model_deployment="openai/davinci",
output_prefix="",
max_tokens=0,
)
@@ -38,7 +38,7 @@ def test_fits_tokens_within_context_window(self):
adapter_spec = AdapterSpec(
method=ADAPT_LANGUAGE_MODELING,
input_prefix="",
model="openai/curie",
model_deployment="openai/curie",
output_prefix="",
max_tokens=0,
)
@@ -69,7 +69,7 @@ def test_prompt_truncated(self):
adapter_spec = AdapterSpec(
method=ADAPT_LANGUAGE_MODELING,
input_prefix="",
model="anthropic/claude-v1.3",
model_deployment="anthropic/claude-v1.3",
output_prefix="",
max_tokens=0,
)
@@ -104,7 +104,7 @@ def test_prompt_truncated(self):
adapter_spec_2_ = AdapterSpec(
method=ADAPT_LANGUAGE_MODELING,
input_prefix="",
model="anthropic/claude-v1.3",
model_deployment="anthropic/claude-v1.3",
output_prefix="",
max_tokens=2000,
)
@@ -7,7 +7,9 @@

class TestMultipleChoiceJointAdapter(TestAdapter):
def test_sample_examples(self):
adapter_spec = AdapterSpec(method=ADAPT_MULTIPLE_CHOICE_JOINT, model="openai/ada", max_train_instances=4)
adapter_spec = AdapterSpec(
method=ADAPT_MULTIPLE_CHOICE_JOINT, model_deployment="openai/ada", max_train_instances=4
)
adapter = AdapterFactory.get_adapter(adapter_spec, self.tokenizer_service)
all_train_instances = [
Instance(Input(text="say no"), references=[Reference(Output(text="no"), tags=[CORRECT_TAG])]),
@@ -27,13 +29,17 @@ def test_sample_examples(self):
assert examples[3].input.text == "say yes3"

def test_sample_examples_no_train_instances(self):
adapter_spec = AdapterSpec(method=ADAPT_MULTIPLE_CHOICE_JOINT, model="openai/ada", max_train_instances=2)
adapter_spec = AdapterSpec(
method=ADAPT_MULTIPLE_CHOICE_JOINT, model_deployment="openai/ada", max_train_instances=2
)
adapter = AdapterFactory.get_adapter(adapter_spec, self.tokenizer_service)
examples = adapter.sample_examples(all_train_instances=[], seed=0)
assert len(examples) == 0

def test_sample_examples_greater_max_train_instances(self):
adapter_spec = AdapterSpec(method=ADAPT_MULTIPLE_CHOICE_JOINT, model="openai/ada", max_train_instances=10)
adapter_spec = AdapterSpec(
method=ADAPT_MULTIPLE_CHOICE_JOINT, model_deployment="openai/ada", max_train_instances=10
)
adapter = AdapterFactory.get_adapter(adapter_spec, self.tokenizer_service)
all_train_instances = [
Instance(Input(text="say no"), references=[Reference(Output(text="no"), tags=[CORRECT_TAG])]),
@@ -46,7 +52,10 @@ def test_sample_examples_greater_max_train_instances(self):

def test_multiple_correct_reference(self):
adapter_spec = AdapterSpec(
method=ADAPT_MULTIPLE_CHOICE_JOINT, model="openai/ada", max_train_instances=10, sample_train=False
method=ADAPT_MULTIPLE_CHOICE_JOINT,
model_deployment="openai/ada",
max_train_instances=10,
sample_train=False,
)
adapter = AdapterFactory.get_adapter(adapter_spec, self.tokenizer_service)
train_instances = [
@@ -101,7 +110,7 @@ def test_multiple_correct_reference(self):
def test_multiple_correct_reference_multi_label(self):
adapter_spec = AdapterSpec(
method=ADAPT_MULTIPLE_CHOICE_JOINT,
model="openai/ada",
model_deployment="openai/ada",
max_train_instances=10,
multi_label=True,
sample_train=False,
14 changes: 14 additions & 0 deletions src/helm/benchmark/config_registry.py
@@ -0,0 +1,14 @@
from helm.benchmark.model_deployment_registry import register_deployments_if_not_already_registered
from helm.benchmark.model_metadata_registry import register_metadatas_if_not_already_registered
from helm.benchmark.tokenizer_config_registry import register_tokenizers_if_not_already_registered

HELM_REGISTERED: bool = False


def register_helm_configurations():
global HELM_REGISTERED
if not HELM_REGISTERED:
register_metadatas_if_not_already_registered()
register_tokenizers_if_not_already_registered()
register_deployments_if_not_already_registered()
HELM_REGISTERED = True
30 changes: 30 additions & 0 deletions src/helm/benchmark/huggingface_registration.py
@@ -1,13 +1,22 @@
import os
from typing import Optional
from datetime import date

from helm.benchmark.model_deployment_registry import (
ClientSpec,
ModelDeployment,
WindowServiceSpec,
register_model_deployment,
)
from helm.benchmark.model_metadata_registry import (
get_model_metadata,
ModelMetadata,
register_model_metadata,
TEXT_MODEL_TAG,
FULL_FUNCTIONALITY_TEXT_MODEL_TAG,
)
from helm.benchmark.tokenizer_config_registry import TokenizerConfig, TokenizerSpec, register_tokenizer_config
from helm.common.hierarchical_logger import hlog


def register_huggingface_model(
@@ -30,6 +39,27 @@ def register_huggingface_model(
args=object_spec_args,
),
)

# We check if the model is already registered because we don't want to
# overwrite the model metadata if it's already registered.
# If it's not registered, we register it, as otherwise an error would be thrown
# when we try to register the model deployment.
try:
_ = get_model_metadata(model_name=helm_model_name)
except ValueError:
register_model_metadata(
ModelMetadata(
name=helm_model_name,
creator_organization_name="Unknown",
display_name=helm_model_name,
description=helm_model_name,
access="open",
release_date=date.today(),
tags=[TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG],
)
)
hlog(f"Registered default metadata for model {helm_model_name}")

register_model_deployment(model_deployment)
tokenizer_config = TokenizerConfig(
name=helm_model_name,