Merge pull request #19 from boostcampaitech7/version_1.1
Feat: add --testing argument_parser and add requirements.txt
Showing 8 changed files with 227 additions and 13 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
data/ | ||
__pycache__/ | ||
__pycache__/ | ||
models/ | ||
outputs/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
## Version description | ||
- This version builds on version_1.0, Junseong's (준성) restructuring of the provided BaseLineCode, which is treated here as version_0.0. | ||
- As recorded in https://github.com/boostcampaitech7/level2-mrc-nlp-01/pull/10, `Sparse_retrieval.py` was slightly modified so that the retriever performance evaluation works correctly. | ||
- Added `requirements.txt`. | ||
- Code change: train.py | ||
``` | ||
from datasets import load_from_disk, load_metric | ||
``` | ||
The import above was changed to | ||
``` | ||
from datasets import load_from_disk | ||
from evaluate import load as load_metric | ||
``` | ||
as shown above, because `datasets` 3.x no longer provides `load_metric`. | ||
- Code change: config.py | ||
``` | ||
class Config: | ||
    def __init__(self, config_dict=None, path='./config.yaml'): | ||
``` | ||
The default path above was changed to | ||
``` | ||
class Config: | ||
    def __init__(self, config_dict=None, path='../config.yaml'): | ||
``` | ||
to match the current file structure. | ||
- Code change: config.yaml | ||
Regarding paths: | ||
invoking `train.py` from the CLI by its full path, | ||
`/data/ephemeral/home/project/version_1.1/src/train.py`, caused a `permission denied` error. | ||
The scripts are therefore run from a terminal inside the src folder, and in config.yaml | ||
the entry that used to read `./data/train_dataset` | ||
was changed to `../data/train_dataset`. | ||
The original command, `python train.py --output_dir ./models/train_dataset --do_train`, | ||
now works. | ||
- Feature added: `--testing` | ||
An argument parser option was added so that testing can run without the full dataset. | ||
- `train.py`, `inference.py`, and `sparse_retrieval.py` were modified so that the option applies to each of them. | ||
- Specifically, the code that loads the datasets and wikipedia_documents now loads only 1% of the data. | ||
- Code change: added `__pycache__` to `.gitignore`. | ||
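The `--testing` behaviour described above can be sketched roughly as follows. This is a minimal illustration, not the code from this commit: it assumes plain `argparse` (the real scripts may use a different parser), and `build_parser`, `take_fraction`, and `load_examples` are hypothetical names introduced only for this example.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser; the real scripts may define many more arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", default="./outputs")
    parser.add_argument(
        "--testing",
        action="store_true",
        help="Load only a small fraction of the data for a quick smoke test.",
    )
    return parser


def take_fraction(items, fraction=0.01):
    """Return the leading `fraction` of `items`, keeping at least one element."""
    if not items:
        return items
    keep = max(1, int(len(items) * fraction))
    return items[:keep]


def load_examples(all_examples, testing):
    """Subsample only when --testing was passed; otherwise use everything."""
    return take_fraction(all_examples) if testing else all_examples


# Simulate `python train.py --testing` without touching sys.argv.
args = build_parser().parse_args(["--testing"])
examples = load_examples(list(range(1000)), args.testing)
print(len(examples))
```

With a Hugging Face `datasets.Dataset` the slicing step would typically become `dataset.select(range(keep))` instead of list slicing; whether the commit does it that way is not shown here.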
## Specific changes applied (GitHub Issue) | ||
- Ljb issue 09 #10 | ||
# How to run | ||
- For train | ||
python train.py --output_dir ./models/train_dataset --do_train `--testing` (optional) | ||
- For eval | ||
python train.py --output_dir ./outputs/train_dataset --do_eval `--testing` (optional) | ||
- For inference | ||
python inference.py --output_dir ./outputs/test_dataset/ --dataset_name ../data/test_dataset --do_predict `--testing` (optional) | ||
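The commands above rely on the shell's working directory being the src folder, since paths such as `../config.yaml` and `../data/train_dataset` are resolved against it. As a possible alternative (not what this commit does), a path can be resolved against the location of the script file instead. The sketch below assumes a hypothetical layout with `config.yaml` one level above `src/`; `resolve_from_file` is an illustrative helper, not part of the repository.

```python
from pathlib import Path


def resolve_from_file(relative_path: str, anchor_file: str) -> Path:
    """Resolve `relative_path` against the folder containing `anchor_file`."""
    return (Path(anchor_file).resolve().parent / relative_path).resolve()


# Hypothetical layout: <project>/config.yaml and <project>/src/config.py.
# Inside a real module, `anchor_file` would usually be `__file__`.
config_path = resolve_from_file("../config.yaml", "/project/src/config.py")
print(config_path)
```

Resolved this way, the result does not depend on the directory the command is launched from, at the cost of tying the default to the repository layout.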
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,3 +40,6 @@ dataQA: | |
# - method: your-augementation-method | ||
# params: | ||
# p: 0.0 | ||
|
||
testing: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
accelerate==1.0.1 | ||
aiohappyeyeballs==2.4.3 | ||
aiohttp==3.10.10 | ||
aiosignal==1.3.1 | ||
asttokens==2.0.5 # Ensure this version or the version you need | ||
astunparse==1.6.3 | ||
async-timeout==4.0.3 | ||
attrs==23.1.0 | ||
backcall==0.2.0 | ||
beautifulsoup4==4.11.1 | ||
boltons==23.0.0 | ||
brotlipy==0.7.0 | ||
certifi==2022.6.15 | ||
cffi==1.15.1 | ||
cryptography | ||
chardet==4.0.0 | ||
charset-normalizer==2.1.1 | ||
click==8.1.3 | ||
cryptography==38.0.0 | ||
datasets==3.0.1 | ||
debugpy==1.6.3 | ||
decorator==5.1.1 | ||
dill==0.3.8 | ||
dnspython==2.4.2 | ||
entrypoints==0.3 | ||
evaluate==0.4.3 | ||
exceptiongroup==1.0.0 | ||
executing==0.8.3 | ||
expecttest==0.1.6 | ||
filelock==3.8.0 | ||
frozenlist==1.4.1 | ||
fsspec==2023.9.2 | ||
hypothesis==6.87.1 | ||
idna==3.3 | ||
ipykernel==6.15.2 | ||
ipython==8.6.0 | ||
jedi==0.18.1 | ||
Jinja2==3.0.3 | ||
jsonpatch==1.32 | ||
jsonpointer==2.3 | ||
jupyter-client==7.3.5 | ||
jupyter_core==4.10.0 | ||
libarchive-c==4.0 | ||
MarkupSafe==2.1.1 | ||
matplotlib-inline==0.1.6 | ||
more-itertools==8.14.0 | ||
mpmath==1.2.1 | ||
multidict==6.1.0 | ||
multiprocess==0.70.16 | ||
nest_asyncio==1.5.5 | ||
networkx==2.8.4 | ||
numpy==1.26.0 | ||
packaging==21.3 | ||
pandas==2.2.3 | ||
parso==0.8.3 | ||
pexpect==4.8.0 | ||
pickleshare==0.7.5 | ||
Pillow==9.2.0 | ||
pkginfo==1.8.3 | ||
platformdirs==2.5.2 | ||
pluggy==1.0.0 | ||
prompt-toolkit==3.0.31 | ||
propcache==0.2.0 | ||
psutil==5.9.1 | ||
ptyprocess==0.7.0 | ||
pure-eval==0.2.2 | ||
pyarrow==15.0.0 | ||
pycparser==2.21 | ||
Pygments==2.12.0 | ||
pyOpenSSL==22.1.0 | ||
PySocks==1.7.1 | ||
python-dateutil==2.8.2 | ||
python-etcd==0.4.5 | ||
pytz==2022.1 | ||
PyYAML==6.0 | ||
pyzmq==23.1.0 | ||
regex==2024.9.11 | ||
requests==2.32.3 | ||
ruamel.yaml==0.17.21 | ||
ruamel.yaml.clib==0.2.6 | ||
safetensors==0.4.3 | ||
six==1.16.0 | ||
sortedcontainers==2.4.0 | ||
soupsieve==2.3.2.post1 | ||
stack-data==0.2.0 | ||
sympy==1.10.1 | ||
tokenizers | ||
tomli==2.0.1 | ||
toolz==0.12.0 | ||
torch==2.1.0 | ||
torchaudio==2.1.0 | ||
torchelastic==0.2.2 | ||
torchvision==0.16.0 | ||
tornado==6.2 | ||
tqdm==4.66.3 | ||
traitlets==5.5.0 | ||
transformers==4.45.2 | ||
triton==2.1.0 | ||
truststore==0.8.0 | ||
types-dataclasses==0.6.6 | ||
typing_extensions==4.1.1 | ||
tzdata==2024.2 | ||
urllib3==1.26.11 | ||
wcwidth==0.2.5 | ||
xxhash==3.5.0 | ||
yarl==1.12.0 | ||
faiss-gpu | ||
scikit-learn |