main.py
import logging
import os
from typing import Any, AsyncGenerator

from llama_cpp import Llama

from leapfrogai_sdk import BackendConfig
from leapfrogai_sdk.llm import LLM, GenerationConfig

logging.basicConfig(
    level=os.getenv("LFAI_LOG_LEVEL", logging.INFO),
    format="%(name)s: %(asctime)s | %(levelname)s | %(filename)s:%(lineno)s >>> %(message)s",
)
logger = logging.getLogger(__name__)


@LLM
class Model:
    backend_config = BackendConfig()

    if not os.path.exists(backend_config.model.source):
        raise ValueError(f"Model path ({backend_config.model.source}) does not exist")

    # Load the model weights; n_gpu_layers=0 keeps inference entirely on the CPU.
    llm = Llama(
        model_path=backend_config.model.source,
        n_ctx=backend_config.max_context_length,
        n_gpu_layers=0,
    )

    async def generate(
        self, prompt: str, config: GenerationConfig
    ) -> AsyncGenerator[str, Any]:
        logger.info("Begin generating streamed response")
        # Stream completion chunks from llama.cpp and yield the text of each one.
        for res in self.llm(
            prompt,
            stream=True,
            temperature=config.temperature,
            max_tokens=config.max_new_tokens,
            top_p=config.top_p,
            top_k=config.top_k,
            stop=self.backend_config.stop_tokens,
        ):
            yield res["choices"][0]["text"]  # type: ignore
        logger.info("Streamed response complete")

    async def count_tokens(self, raw_text: str) -> int:
        # llama.cpp tokenizes raw UTF-8 bytes rather than Python strings.
        string_bytes: bytes = bytes(raw_text, "utf-8")
        tokens: list[int] = self.llm.tokenize(string_bytes)
        return len(tokens)
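
For local experimentation, the class can in principle be exercised directly, outside the leapfrogai_sdk serving layer. The sketch below is a minimal smoke test, assuming the @LLM decorator leaves the class instantiable with its generate and count_tokens methods callable as written, and assuming GenerationConfig accepts temperature, max_new_tokens, top_p, and top_k as keyword arguments (mirroring the fields read in generate above); neither assumption is guaranteed by this file alone.

import asyncio

from main import Model
from leapfrogai_sdk.llm import GenerationConfig


async def smoke_test() -> None:
    # Hypothetical direct use of the backend class for a quick local check.
    model = Model()

    prompt = "Explain what llama.cpp is in one sentence."
    config = GenerationConfig(
        temperature=0.7,      # assumed field names, matching the generate() call
        max_new_tokens=128,
        top_p=0.9,
        top_k=40,
    )

    print(await model.count_tokens(prompt), "prompt tokens")
    async for chunk in model.generate(prompt, config):
        print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(smoke_test())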