SwichBack_add #81

Open
wants to merge 12 commits into base: main
8 changes: 7 additions & 1 deletion .gitignore
@@ -183,6 +183,12 @@ airflow.yaml

# temporary
examples/tests
outputs

### d3rlpy logs
d3rlpy_logs/
d3rlpy_logs/

# datasets
replay_benchmarks/data
# logs and checkpoints
replay_benchmarks/artifacts
28 changes: 28 additions & 0 deletions Dockerfile
@@ -0,0 +1,28 @@
FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel as builder
RUN apt-get update \
&& apt-get install -y --no-install-recommends apt-utils \
&& apt-get install libgomp1 build-essential pandoc -y \
&& apt-get install git -y --no-install-recommends

COPY --from=openjdk:11-jre-slim /usr/local/openjdk-11 /usr/local/openjdk-11
ENV JAVA_HOME /usr/local/openjdk-11
RUN update-alternatives --install /usr/bin/java java /usr/local/openjdk-11/bin/java 1

WORKDIR /home

RUN pip install --no-cache-dir --upgrade pip wheel poetry==1.5.1 poetry-dynamic-versioning \
&& python -m poetry config virtualenvs.create false
COPY . RePlay-Accelerated/
RUN cd RePlay-Accelerated && ./poetry_wrapper.sh install --all-extras

RUN pip install --upgrade torch
RUN pip install rs_datasets
RUN pip install Ninja==1.11.1.1
RUN pip install -U tensorboard


# ENV CUDACXX=/usr/local/cuda/bin/nvcc
# RUN cd /home/RecSys/cutlass && mkdir build && cd build && cmake .. -DCUTLASS_NVCC_ARCHS=75


CMD ["/bin/bash"]
34 changes: 34 additions & 0 deletions experiments/README.md
@@ -0,0 +1,34 @@
# Experiment: Replacing the Classification Head `nn.Linear` with `SwitchBackLinear`

## Description
In this experiment, the `nn.Linear` layer is replaced with `SwitchBackLinear`, which performs forward passes in `int8`. This speeds up model training.
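
A minimal sketch of what this swap amounts to, assuming the `SwitchBackLinear` layer shipped with `bitsandbytes` (the `head` attribute name below is a hypothetical placeholder; in this repository the replacement is driven by the acceleration config rather than done by hand):

```python
import torch.nn as nn

try:
    from bitsandbytes.nn import SwitchBackLinear  # int8 matmul linear layer
except ImportError:  # fall back to a plain fp32 layer if bitsandbytes is unavailable
    SwitchBackLinear = nn.Linear


def swap_head(model: nn.Module, attr: str = "head") -> nn.Module:
    """Replace the classification head `nn.Linear` with `SwitchBackLinear`."""
    old = getattr(model, attr)
    new = SwitchBackLinear(old.in_features, old.out_features, bias=old.bias is not None)
    setattr(model, attr, new)
    return model
```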

## Steps to run

1. Pull the Docker image:
```bash
docker pull dmitryredkosk/bitsandbytes_recsys_clear
```

2. Copy the contents of `config_ml20_swichback.yaml` into the main configuration file `config.yaml`.

3. Run the main script:
```bash
python RePlay-Accelerated/main.py
```

## Verifying the layer replacement

To confirm that `SwitchBackLinear` has successfully replaced `nn.Linear`, you can print the initialized model in the `run` method of `RePlay-Accelerated/replay_benchmarks/train_runner.py`:
```python
def run(self):
    """Execute the training pipeline."""
    train_dataloader, val_dataloader, prediction_dataloader = (
        self._load_dataloaders()
    )

    logging.info("Initializing model...")
    model = self._initialize_model()

    print(model)
```
44 changes: 44 additions & 0 deletions main.py
@@ -0,0 +1,44 @@
"""Main module"""

import os
import logging
import warnings
import yaml
import argparse

from replay_benchmarks.utils.conf import load_config, seed_everything
from replay_benchmarks import TrainRunner, InferRunner

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
warnings.filterwarnings("ignore")


def main() -> None:
    config_dir = "./replay_benchmarks/configs"
    base_config_path = os.path.join(config_dir, "config.yaml")
    # parser = argparse.ArgumentParser()
    # parser.add_argument('--base_config_path', type=str, help='Path to config file')
    # args = parser.parse_args()
    # base_config_path = os.path.abspath(args.base_config_path)
    # config_dir = os.path.dirname(base_config_path)
    config = load_config(base_config_path, config_dir)
    logging.info("Configuration:\n%s", yaml.dump(config))

    seed_everything(config["env"]["SEED"])
    logging.info(f"Fixing seed: {config['env']['SEED']}")

    if config["mode"]["name"] == "train":
        runner = TrainRunner(config)
    elif config["mode"]["name"] == "infer":
        runner = InferRunner(config)
    else:
        raise ValueError(f"Unsupported mode: {config['mode']}")

    runner.run()


if __name__ == "__main__":
    main()
93 changes: 92 additions & 1 deletion replay/models/nn/sequential/bert4rec/lightning.py
@@ -31,8 +31,13 @@ def __init__(
loss_sample_count: Optional[int] = None,
negative_sampling_strategy: str = "global_uniform",
negatives_sharing: bool = False,
n_buckets: int = 100,
bucket_size_x: int = 100,
bucket_size_y: int = 100,
mix_x: bool = False,
optimizer_factory: OptimizerFactory = FatOptimizerFactory(),
lr_scheduler_factory: Optional[LRSchedulerFactory] = None,
acceleration_config: Optional[dict] = None
):
"""
:param tensor_schema (TensorSchema): Tensor schema of features.
@@ -64,9 +69,19 @@
Default: ``global_uniform``.
:param negatives_sharing: Apply negative sharing in calculating sampled logits.
Default: ``False``.
:param n_buckets: Number of buckets for the SCE loss.
Default: ``100``.
:param bucket_size_x: Size of the x buckets for the SCE loss.
Default: ``100``.
:param bucket_size_y: Size of the y buckets for the SCE loss.
Default: ``100``.
:param mix_x: Mix state embeddings with a random matrix for the SCE loss.
Default: ``False``.
:param optimizer_factory: Optimizer factory.
Default: ``FatOptimizerFactory``.
:param lr_scheduler_factory: Learning rate schedule factory.
Default: ``None``
:param acceleration_config: Parameters for acceleration.
Default: ``None``.
"""
super().__init__()
@@ -81,7 +96,19 @@ def __init__(
dropout=dropout_rate,
enable_positional_embedding=enable_positional_embedding,
enable_embedding_tying=enable_embedding_tying,
acceleration_config=acceleration_config
)

if acceleration_config:
if acceleration_config["dtype"] == "fp32":
pass
elif acceleration_config["dtype"] == "bf16":
self._model = self._model.to(torch.bfloat16)
elif acceleration_config["dtype"] == "fp16":
self._model = self._model.to(torch.float16)
else:
raise ValueError(f"dtype in acceleration config is not supported")

self._loss_type = loss_type
self._loss_sample_count = loss_sample_count
self._negative_sampling_strategy = negative_sampling_strategy
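
For orientation, a hedged sketch of how these new arguments might be passed when constructing the module (the `Bert4Rec` class and import path are taken from the file being modified; the `tensor_schema` object is assumed to be built elsewhere, as in the existing benchmarks, and its construction is not shown here):

```python
from replay.models.nn.sequential.bert4rec import Bert4Rec

model = Bert4Rec(
    tensor_schema=tensor_schema,            # assumed to exist; construction omitted
    loss_type="SCE",                        # enables _compute_loss_scalable_ce
    n_buckets=100,
    bucket_size_x=100,
    bucket_size_y=100,
    mix_x=False,
    acceleration_config={"dtype": "bf16"},  # one of "fp32", "bf16", "fp16"
)
```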
@@ -90,6 +117,10 @@
self._lr_scheduler_factory = lr_scheduler_factory
self._loss = self._create_loss()
self._schema = tensor_schema
self._n_buckets = n_buckets
self._bucket_size_x = bucket_size_x
self._bucket_size_y = bucket_size_y
self._mix_x = mix_x
assert negative_sampling_strategy in {"global_uniform", "inbatch"}

item_count = tensor_schema.item_id_features.item().cardinality
@@ -207,6 +238,8 @@ def _compute_loss(self, batch: Bert4RecTrainingBatch) -> torch.Tensor:
loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
elif self._loss_type == "CE":
loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled
elif self._loss_type == "SCE":
loss_func = self._compute_loss_scalable_ce
else:
msg = f"Not supported loss type: {self._loss_type}"
raise ValueError(msg)
@@ -325,6 +358,64 @@ def _compute_loss_ce_sampled(
labels_flat = torch.zeros(positive_logits.size(0), dtype=torch.long, device=padding_mask.device)
loss = self._loss(logits, labels_flat)
return loss

def _compute_loss_scalable_ce(
self,
feature_tensors: TensorMap,
positive_labels: torch.LongTensor,
padding_mask: torch.BoolTensor,
tokens_mask: torch.BoolTensor,
) -> torch.Tensor:

labels_mask = (~padding_mask) + tokens_mask
masked_tokens = ~labels_mask

pad_token = feature_tensors[self._schema.item_id_feature_name].view(-1)[~padding_mask.view(-1)][0]
emb = self._model.forward_step(feature_tensors, padding_mask, tokens_mask)
hd = torch.tensor(emb.shape[-1])

x = emb.view(-1, hd)
y = positive_labels.view(-1)
w = self.get_all_embeddings()["item_embedding"]

correct_class_logits_ = (x * torch.index_select(w, dim=0, index=y)).sum(dim=1) # (bs,)

with torch.no_grad():
if self._mix_x:
omega = 1/torch.sqrt(torch.sqrt(hd)) * torch.randn(x.shape[0], self._n_buckets, device=x.device)
buckets = omega.T @ x
del omega
else:
buckets = 1/torch.sqrt(torch.sqrt(hd)) * torch.randn(self._n_buckets, hd, device=x.device) # (n_b, hd)

with torch.no_grad():
x_bucket = buckets @ x.T # (n_b, hd) x (hd, b) -> (n_b, b)
x_bucket[:, ~padding_mask.view(-1)] = float('-inf')
_, top_x_bucket = torch.topk(x_bucket, dim=1, k=self._bucket_size_x) # (n_b, bs_x)
del x_bucket

y_bucket = buckets @ w.T # (n_b, hd) x (hd, n_cl) -> (n_b, n_cl)

y_bucket[:, pad_token] = float('-inf')
_, top_y_bucket = torch.topk(y_bucket, dim=1, k=self._bucket_size_y) # (n_b, bs_y)
del y_bucket

x_bucket = torch.gather(x, 0, top_x_bucket.view(-1, 1).expand(-1, hd)).view(self._n_buckets, self._bucket_size_x, hd) # (n_b, bs_x, hd)
y_bucket = torch.gather(w, 0, top_y_bucket.view(-1, 1).expand(-1, hd)).view(self._n_buckets, self._bucket_size_y, hd) # (n_b, bs_y, hd)

wrong_class_logits = (x_bucket @ y_bucket.transpose(-1, -2)) # (n_b, bs_x, bs_y)
mask = torch.index_select(y, dim=0, index=top_x_bucket.view(-1)).view(self._n_buckets, self._bucket_size_x)[:, :, None] == top_y_bucket[:, None, :] # (n_b, bs_x, bs_y)
wrong_class_logits = wrong_class_logits.masked_fill(mask, float('-inf')) # (n_b, bs_x, bs_y)
correct_class_logits = torch.index_select(correct_class_logits_, dim=0, index=top_x_bucket.view(-1)).view(self._n_buckets, self._bucket_size_x)[:, :, None] # (n_b, bs_x, 1)
logits = torch.cat((wrong_class_logits, correct_class_logits), dim=2) # (n_b, bs_x, bs_y + 1)

loss_ = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), (logits.shape[-1] - 1) * torch.ones(logits.shape[0] * logits.shape[1], dtype=torch.int64, device=logits.device), reduction='none') # (n_b * bs_x,)
loss = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)
loss.scatter_reduce_(0, top_x_bucket.view(-1), loss_, reduce='amax', include_self=False)
loss = loss[(loss != 0) & (masked_tokens).view(-1)]
loss = torch.mean(loss)

return loss

def _get_sampled_logits(
self,
@@ -412,7 +503,7 @@ def _create_loss(self) -> Union[torch.nn.BCEWithLogitsLoss, torch.nn.CrossEntropyLoss]:
if self._loss_type == "BCE":
return torch.nn.BCEWithLogitsLoss(reduction="sum")

if self._loss_type == "CE":
if self._loss_type == "CE" or self._loss_type == "SCE":
return torch.nn.CrossEntropyLoss()

msg = "Not supported loss_type"
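
To make the bucketing in `_compute_loss_scalable_ce` easier to follow, here is a tiny self-contained illustration of the core idea (shapes and names are local to this sketch, not part of the PR): state and item embeddings are projected onto shared random bucket directions, only the top-k entries per bucket are kept, and logits are then computed inside each bucket instead of over the full item catalogue.

```python
import torch

torch.manual_seed(0)
b, n_items, hd = 8, 32, 4        # states in the batch, catalogue size, hidden dim
n_buckets, k_x, k_y = 3, 4, 6    # number of buckets, per-bucket top-k sizes

x = torch.randn(b, hd)           # state embeddings (model outputs)
w = torch.randn(n_items, hd)     # item embeddings (classification weights)

# Shared random projections, scaled by 1/sqrt(sqrt(hd)) as in the PR code.
buckets = torch.randn(n_buckets, hd) / hd ** 0.25

top_x = (buckets @ x.T).topk(k_x, dim=1).indices  # (n_buckets, k_x) state indices
top_y = (buckets @ w.T).topk(k_y, dim=1).indices  # (n_buckets, k_y) item indices

# Per-bucket logits: (n_buckets, k_x, k_y) instead of the full (b, n_items) matrix.
logits = x[top_x] @ w[top_y].transpose(-1, -2)
print(logits.shape)  # torch.Size([3, 4, 6])
```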