-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into bump_quepid_7_7_0
- Loading branch information
Showing
9 changed files
with
406 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Train Models with Jina's Finetuner | ||
|
||
This folder contains some example scripts to fine-tune embedding (vector search) models. | ||
More details on the models produced by these scripts and fine-tuning in general are provided in the 13th kata. | ||
|
||
All the commands displayed below are meant to be executed from within the folder `finetuning`. | ||
|
||
## Setup the Environment | ||
|
||
Finetuner runs with Python. | ||
Accordingly, you need a Python environment to run it. | ||
We recommend you to create also a virtual environment by running: | ||
``python3 -m venv venv`` | ||
Afterward, you can activate the environment by running: | ||
``source venv/bin/activate`` | ||
To install the dependencies, to run the fine-tuning examples, run | ||
``pip install -r requirements.txt`` | ||
|
||
## Run Unsupervised SBERT fine-tuning (Text-to-Text Search) | ||
|
||
To automatically create training data finetuner needs a dataset of documents and a set of queries. | ||
For the corpus we use the product data in `data-encoder/ecommerce/vectors/data`. | ||
Since this dataset does not come with a set of queries (which can usually be extracted from a query log), we generate some queries. | ||
To download them one can execute the following command: | ||
``` | ||
wget https://finetuner-ecommerce-experiment.s3.eu-central-1.amazonaws.com/generated-queries.jsonl -P sbert_unsupervised/ | ||
``` | ||
To extract queries and documents and transform them into Finetuner's preferred format for finetuner execute the following command*: | ||
``` | ||
python3 sbert_unsupervised/data_preparation.py | ||
``` | ||
*While running the script, a browser window will open up and ask you to log into your Jina account. If you don't have one, you need to sign up to run the script. | ||
After you logged into your account, a session token will be sent back to the script. This token is then used to authenticate the script to the Jina Cloud in order to download or upload models and datasets. | ||
|
||
After that, you can run a job to generated training data for a text-to-text embedding model automatically by running: | ||
``` | ||
python3 sbert_unsupervised/data_synthesis.py | ||
``` | ||
|
||
If you terminated the script or the connection to the log stream got interrupted (which happens from time to time), | ||
you can execute the following code in a script or the Python interpreter to get the current status the logs of the job: | ||
|
||
```python | ||
import finetuner | ||
|
||
finetuner.login() | ||
|
||
run = finetuner.get_run('ecommerce-synthesis') | ||
print(f'Status: {run.status()["status"]}') | ||
print('Logs:\n', run.logs()) | ||
|
||
``` | ||
|
||
|
||
Finally, you can run the fine-tuning job by running: | ||
``` | ||
python3 sbert_unsupervised/finetune.py | ||
``` | ||
This will save a fine-tuned model in the `finetuning` folder with the name `sbert_unsupervised`. | ||
|
||
## Run Unsupervised CLIP fine-tuning (Text-to-Image Search) | ||
|
||
For fine-tuning a CLIP model we directly create a dataset with text image pairs from the product dataset. | ||
This can be executed with the following command: | ||
``` | ||
python3 clip_unsupervised/data_preparation.py | ||
``` | ||
After that the fine-tuning job can be executed with the following command: | ||
``` | ||
python3 clip_unsupervised/finetune.py | ||
``` | ||
|
||
## Integrate the Fine-Tuned Models into Chorus | ||
|
||
After running the scripts mentioned above, you obtain fine-tuned embedding models which can be integrated into Chorus. | ||
If you are only interested in the integration, you can download the final models from the following links: | ||
- [sbert_unsupervised](https://finetuner-ecommerce-experiment.s3.eu-central-1.amazonaws.com/fine-tuned-sbert-model.zip) | ||
- [clip_unsupervised](https://finetuner-ecommerce-experiment.s3.eu-central-1.amazonaws.com/fine-tuned-clip-model.zip) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import json | ||
import click | ||
from tqdm import tqdm | ||
from zipfile import ZipFile | ||
|
||
from docarray import Document, DocumentArray | ||
|
||
|
||
def prepare_training_data(filepaths: list[str]) -> dict[str, str]: | ||
dataset = DocumentArray() | ||
for filepath in filepaths: | ||
with ZipFile(filepath, 'r') as archive: | ||
for filename in archive.namelist(): | ||
with archive.open(filename, 'r') as f: | ||
objs = json.load(f) | ||
for obj in tqdm(objs, desc=f'Load data from "{filename}"'): | ||
title = obj.get('title', '') | ||
supplier = obj.get('supplier', '') | ||
product_type = obj.get('attr_t_product_type') | ||
if 'img_500x500' in obj: | ||
if not title and supplier: | ||
continue | ||
else: | ||
text_value = ( | ||
f'{title} {supplier} {product_type}' | ||
if product_type | ||
else f'{title} {supplier}' | ||
) | ||
try: | ||
img_doc = Document( | ||
uri=obj.get('img_500x500') | ||
).load_uri_to_blob() | ||
text_doc = Document(text=text_value) | ||
dataset.append( | ||
Document(chunks=DocumentArray([text_doc, img_doc])) | ||
) | ||
except: | ||
pass | ||
return dataset | ||
|
||
|
||
@click.command() | ||
@click.option( | ||
'--input', | ||
'-i', | ||
default=[ | ||
'../data-encoder/ecommerce/vectors/data/1.json.zip', | ||
'../data-encoder/ecommerce/vectors/data/2.json.zip', | ||
'../data-encoder/ecommerce/vectors/data/3.json.zip', | ||
'../data-encoder/ecommerce/vectors/data/4.json.zip', | ||
], | ||
multiple=True, | ||
help='Input filepaths', | ||
) | ||
def main(input): | ||
train_dataset = prepare_training_data(input) | ||
train_dataset.summary() | ||
train_dataset.save_binary('clip_unsupervised/clip_train_dataset.da') | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import finetuner | ||
from docarray import DocumentArray | ||
|
||
# To use finetuner an account for the Jina AI Cloud is required. | ||
finetuner.login() | ||
|
||
training_data = DocumentArray.from_bytes('clip_unsupervised/clip_train_dataset.da') | ||
|
||
training_run = finetuner.fit( | ||
model='clip-large-en', | ||
# name of the pre-trained model | ||
train_data=training_data, | ||
# path to the prepared training dataset | ||
loss='CLIPLoss', | ||
# contrastive loss function for text-image pairs | ||
optimizer='AdamW', | ||
learning_rate=1e-7, | ||
# choose a small learning rate since the model is already pre-trained | ||
batch_size=8, | ||
# too high batch-sizes can lead to memory issues | ||
epochs=1, | ||
# for fine-tuning usually a low number of epochs is enough | ||
) | ||
|
||
print(training_run.name) | ||
|
||
# print logs | ||
for entry in training_run.stream_logs(): | ||
print(entry) | ||
|
||
# download model | ||
training_run.save_artifact('finetuned_clip_model') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
finetuner[full]=0.7.8 | ||
docarray[commons]==0.21.0 | ||
click==8.1.3 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import json | ||
import click | ||
from zipfile import ZipFile | ||
|
||
from typing import List | ||
|
||
from docarray import Document, DocumentArray | ||
|
||
QUERIES_DA_PATH = 'sbert_unsupervised/queries.da' | ||
CORPUS_DA_PATH = 'sbert_unsupervised/corpus.da' | ||
|
||
|
||
def load_queries(filename: str) -> List[str]: | ||
"""Loads queries from a jsonl file.""" | ||
queries = list() | ||
with open(filename, 'r') as f: | ||
for i, line in enumerate(f): | ||
obj = json.loads(line) | ||
queries.append(obj['text']) | ||
return queries | ||
|
||
|
||
def load_corpus(filepaths: List[str]) -> List[str]: | ||
"""Loads corpus from a list of zipped jsonl files with product properties. | ||
For the training, we only consider the properties `title` and `supplier`. | ||
""" | ||
corpus = list() | ||
for filepath in filepaths: | ||
with ZipFile(filepath, 'r') as archive: | ||
for filename in archive.namelist(): | ||
with archive.open(filename, 'r') as f: | ||
objs = json.load(f) | ||
for obj in objs: | ||
title = obj.get('title', '') | ||
supplier = obj.get('supplier', '') | ||
if title and supplier: | ||
corpus.append(f'{title} {supplier}') | ||
return corpus | ||
|
||
|
||
def prepare_queries(queries: List[str]) -> DocumentArray: | ||
"""Transforms queries into Finetuner's dataset format (DocumentArray).""" | ||
return DocumentArray([Document(text=q) for q in queries]) | ||
|
||
|
||
def prepare_corpus(corpus: List[str]) -> DocumentArray: | ||
"""Transforms product data into Finetuner's dataset format (DocumentArray).""" | ||
return DocumentArray([Document(text=p) for p in corpus]) | ||
|
||
|
||
@click.command() | ||
@click.option( | ||
'--queries', | ||
default='sbert_unsupervised/generated-queries.jsonl', | ||
help='Path to queries file', | ||
) | ||
@click.option( | ||
'--corpus', | ||
default=[ | ||
'../data-encoder/ecommerce/vectors/data/1.json.zip', | ||
'../data-encoder/ecommerce/vectors/data/2.json.zip', | ||
'../data-encoder/ecommerce/vectors/data/3.json.zip', | ||
'../data-encoder/ecommerce/vectors/data/4.json.zip', | ||
], | ||
help='Path to corpus files', | ||
multiple=True, | ||
) | ||
def main(queries: str, corpus: str): | ||
queries = load_queries(queries) | ||
corpus = load_corpus(corpus) | ||
queries = prepare_queries(queries) | ||
corpus = prepare_corpus(corpus) | ||
queries.summary() | ||
corpus.summary() | ||
queries.save_binary(QUERIES_DA_PATH) | ||
corpus.save_binary(CORPUS_DA_PATH) | ||
print( | ||
f'Prepared queries and corpus for unsupervised training of SBERT model:', | ||
f'{QUERIES_DA_PATH}, {CORPUS_DA_PATH}', | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import finetuner | ||
from finetuner.model import synthesis_model_en | ||
|
||
from docarray import DocumentArray | ||
|
||
DATASET_RUN_NAME = 'ecommerce-synthesis' | ||
|
||
finetuner.login(force=True) | ||
|
||
# load the datasets pre-prepared for Finetuner | ||
queries = DocumentArray.load_binary('sbert_unsupervised/queries.da') | ||
corpus = DocumentArray.load_binary('sbert_unsupervised/corpus.da') | ||
|
||
# upload the data to the Jina AI Cloud | ||
queries.push('ecommerce-queries') | ||
corpus.push('ecommerce-corpus') | ||
|
||
# start the data synthesis cloud job | ||
synthesis_run = finetuner.synthesize( | ||
query_data='ecommerce-queries', | ||
corpus_data='ecommerce-corpus', | ||
run_name=DATASET_RUN_NAME, | ||
models=synthesis_model_en, | ||
num_relations=10, | ||
) | ||
|
||
# print the name of the run (should be the same as DATASET_RUN_NAME) | ||
print(synthesis_run.name) | ||
|
||
# print logs | ||
for entry in synthesis_run.stream_logs(): | ||
print(entry) | ||
|
||
# download results | ||
train_data_name = synthesis_run.train_data | ||
train_data = DocumentArray.pull(train_data_name) | ||
train_data.summary() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import finetuner | ||
|
||
from finetuner.callback import EvaluationCallback | ||
|
||
# To use finetuner an account for the Jina AI Cloud is required. | ||
finetuner.login() | ||
|
||
DATASET_RUN_NAME = 'ecommerce-synthesis' # should be the same as in data_synthesis.py | ||
|
||
train_dataset_name = finetuner.get_run(DATASET_RUN_NAME).train_data | ||
|
||
training_run = finetuner.fit( | ||
model='sbert-base-en', | ||
train_data=train_dataset_name, | ||
loss='MarginMSELoss', | ||
# This loss function is specific for the unsupervised training | ||
# on generated data. It expects triplet of a query and two | ||
# documents associated with a margin relevance score. | ||
optimizer='Adam', | ||
learning_rate=1e-5, | ||
# choose a small learning rate since the model is already pre-trained | ||
epochs=1, | ||
# for fine-tuning usually one epoch is enough | ||
batch_size=16, | ||
# too high batch-sizes can lead to memory issues | ||
) | ||
|
||
print(training_run.name) | ||
|
||
# print logs | ||
for entry in training_run.stream_logs(): | ||
print(entry) | ||
|
||
# download model | ||
training_run.save_artivact('finetuned_sbert_model') |
Oops, something went wrong.