forked from Asaad47/llama-finetuning
inference.py
from modal import Image, gpu, method
import subprocess
import os
from common import stub, BASE_MODELS, VOLUME_CONFIG
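
# common.py provides the shared Modal stub, the BASE_MODELS name -> Hugging Face
# id mapping, and VOLUME_CONFIG (assumed here to mount the /pretrained cache and
# the /results training outputs used below). The TGI image defined next is
# reused both for merging adapter weights and for serving.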
tgi_image = (
    Image.from_registry("ghcr.io/huggingface/text-generation-inference:1.0.3")
    .dockerfile_commands("ENTRYPOINT []")
    .pip_install("text-generation", "transformers>=4.33.0")
    .env(dict(HUGGINGFACE_HUB_CACHE="/pretrained"))
)
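
# Note: clearing ENTRYPOINT above lets Modal run its own commands inside the TGI
# container, and HUGGINGFACE_HUB_CACHE keeps base-model downloads on the
# persistent /pretrained volume instead of ephemeral container storage.
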
@stub.function(image=tgi_image, volumes=VOLUME_CONFIG, timeout=60 * 20)
def merge(run_id: str, commit: bool = False):
    from text_generation_server.utils.peft import download_and_unload_peft

    os.mkdir(f"/results/{run_id}/merged")
    subprocess.call(f"cp /results/{run_id}/*.* /results/{run_id}/merged", shell=True)

    print(f"Merging weights for fine-tuned {run_id=}.")
    download_and_unload_peft(f"/results/{run_id}/merged", None, False)

    if commit:
        print("Committing merged model permanently (can take a few minutes).")
        stub.results_volume.commit()
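
# As I understand TGI 1.0.x, download_and_unload_peft loads the LoRA adapter
# from the copied checkpoint, folds the adapter weights into the base model, and
# rewrites /results/{run_id}/merged with full weights, so text-generation-launcher
# can serve the fine-tuned model like any plain checkpoint.
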
@stub.cls(
    image=tgi_image,
    gpu=gpu.A100(count=1, memory=40),
    allow_concurrent_inputs=100,
    volumes=VOLUME_CONFIG,
)
class Model:
    def __init__(self, base: str = "", run_id: str = ""):
        from text_generation import AsyncClient
        import socket
        import time

        model = f"/results/{run_id}/merged" if run_id else BASE_MODELS[base]
        if run_id and not os.path.isdir(model):
            merge.local(run_id)  # local = run in the same container

        print(f"Loading {model} into GPU ... ")
        launch_cmd = ["text-generation-launcher", "--model-id", model, "--port", "8000"]
        self.launcher = subprocess.Popen(launch_cmd, stdout=subprocess.DEVNULL)

        self.client = None
        # poll() updates returncode, so a crashed launcher breaks the wait loop
        while not self.client and self.launcher.poll() is None:
            try:
                socket.create_connection(("127.0.0.1", 8000), timeout=1).close()
                self.client = AsyncClient("http://127.0.0.1:8000", timeout=60)
            except (socket.timeout, ConnectionRefusedError):
                time.sleep(1.0)

        assert self.launcher.returncode is None
    def __exit__(self, _exc_type, _exc_value, _traceback):
        self.launcher.terminate()

    @method()
    async def generate(self, prompt: str):
        result = await self.client.generate(prompt, max_new_tokens=512)
        return result.generated_text
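
# One text-generation-launcher process is kept per Modal container: __init__
# starts it and waits for port 8000 to accept connections, and __exit__
# terminates it on container shutdown. allow_concurrent_inputs=100 lets many
# concurrent generate() calls share that single loaded model via the async client.
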
@stub.local_entrypoint()
def main(prompt: str, base: str, run_id: str = "", batch: int = 1):
    print(f"Running completion for prompt:\n{prompt}")

    print("=" * 20 + "Generating without adapter" + "=" * 20)
    for output in Model(base).generate.map([prompt] * batch):
        print(output)

    if run_id:
        print("=" * 20 + "Generating with adapter" + "=" * 20)
        for output in Model(base, run_id).generate.map([prompt] * batch):
            print(output)
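
# Example invocation (hypothetical values: --base must be a key defined in
# common.BASE_MODELS and --run-id an existing run directory under /results):
#
#   modal run inference.py --prompt "Tell me a joke." --base chat7 --run-id <run-id> --batch 4
#
# `modal run` exposes the local_entrypoint arguments above as CLI flags.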