Skip to content

Commit

Permalink
[Misc] Enhance offline_inference to support user-configurable paramet… (
Browse files Browse the repository at this point in the history
#10392)

Signed-off-by: wchen61 <[email protected]>
  • Loading branch information
wchen61 authored Nov 17, 2024
1 parent 80d85c5 commit d1557e6
Showing 1 changed file with 78 additions and 20 deletions.
98 changes: 78 additions & 20 deletions examples/offline_inference.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,80 @@
from dataclasses import asdict

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser


def get_prompts(num_prompts: int):
# The default sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

if num_prompts != len(prompts):
prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts]

return prompts


def main(args):
# Create prompts
prompts = get_prompts(args.num_prompts)

# Create a sampling params object.
sampling_params = SamplingParams(n=args.n,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
max_tokens=args.max_tokens)

# Create an LLM.
# The default model is 'facebook/opt-125m'
engine_args = EngineArgs.from_cli_args(args)
llm = LLM(**asdict(engine_args))

# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
group = parser.add_argument_group("SamplingParams options")
group.add_argument("--num-prompts",
type=int,
default=4,
help="Number of prompts used for inference")
group.add_argument("--max-tokens",
type=int,
default=16,
help="Generated output length for sampling")
group.add_argument('--n',
type=int,
default=1,
help='Number of generated sequences per prompt')
group.add_argument('--temperature',
type=float,
default=0.8,
help='Temperature for text generation')
group.add_argument('--top-p',
type=float,
default=0.95,
help='top_p for text generation')
group.add_argument('--top-k',
type=int,
default=-1,
help='top_k for text generation')

# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
args = parser.parse_args()
main(args)

0 comments on commit d1557e6

Please sign in to comment.