[QST] Deploy a Transformers4rec model with pre-trained embeddings #394

Open
mvidela31 opened this issue Jan 10, 2025 · 0 comments

❓ Questions & Help

Hi everyone,

I tried to deploy a Transformers4rec model that uses pre-trained embeddings, following the Transformers4rec with pre-trained embeddings example and the transformers-next-item-prediction-with-pretrained-embeddings.ipynb notebook (for TensorFlow Merlin Models). However, there seem to be problems tracing the PyTorch model with pre-trained embeddings.

Details

Based on the above examples, I put together the following script:

# Imports (following the referenced Merlin / Transformers4rec examples)
import cudf
import numpy as np
import torch

from merlin.dataloader.ops.embeddings import EmbeddingOperator
from merlin.io import Dataset
from merlin.systems.dag import Ensemble
from merlin.systems.dag.ops.pytorch import PredictPyTorch
from merlin.table import TensorTable, TorchColumn
from merlin.table.conversions import convert_col
from transformers4rec import torch as tr
from transformers4rec.config.trainer import T4RecTrainingArguments
from transformers4rec.torch import Trainer
from transformers4rec.torch.utils.data_utils import MerlinDataLoader

# Testing dataset and schema
data = tr.data.music_streaming_testing_data
schema = data.merlin_schema.select_by_name([
    "item_id",
    "item_category",
    "item_recency",
    "item_genres",
])

batch_size, max_length, pretrained_dim = 128, 20, 16

# Random pre-trained embedding table for item_id
item_cardinality = schema["item_id"].int_domain.max + 1
np_emb_item_id = np.random.rand(item_cardinality, pretrained_dim)
embeddings_op = EmbeddingOperator(
    np_emb_item_id, lookup_key="item_id", embedding_name="pretrained_item_id_embeddings"
)

# set dataloader with pre-trained embeddings
data_loader = MerlinDataLoader.from_schema(
    schema,
    Dataset(data.path, schema=schema),
    max_sequence_length=max_length,
    batch_size=batch_size,
    transforms=[embeddings_op],
    shuffle=False,
)

# set the model schema from data-loader
model_schema = data_loader.output_schema
inputs = tr.TabularSequenceFeatures.from_schema(
    model_schema,
    max_sequence_length=max_length,
    pretrained_output_dims=8,
    normalizer="layer-norm",
    d_output=64,
    masking="mlm",
)
transformer_config = tr.XLNetConfig.build(64, 4, 2, 20)
task = tr.NextItemPredictionTask(weight_tying=True)
model = transformer_config.to_torch_model(inputs, task, max_sequence_length=max_length)

args = T4RecTrainingArguments(
    output_dir=".",
    max_steps=5,
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size // 2,
    max_sequence_length=max_length,
    fp16=False,
    report_to=[],
    debug=["r"],
)

# Explicitly pass the merlin dataloader with pre-trained embeddings
recsys_trainer = Trainer(
    model=model,
    args=args,
    schema=schema,
    train_dataloader=data_loader,
    eval_dataloader=data_loader,
    compute_metrics=True,
)

recsys_trainer.train()
eval_metrics = recsys_trainer.evaluate(eval_dataset=data.path, metric_key_prefix="eval")

### Model export
topk = 20
model.top_k = topk
model.eval()

# Example inputs for tracing, read directly from the parquet file
df = cudf.read_parquet(data.path, columns=model.input_schema.column_names)
table = TensorTable.from_df(df.loc[:10])
for column in table.columns:
    table[column] = convert_col(table[column], TorchColumn)
model_input_dict = table.to_dict()

traced_model = torch.jit.trace(model, model_input_dict, strict=True)
input_schema = model.input_schema
output_schema = model.output_schema

torch_op = schema.column_names >> embeddings_op >> PredictPyTorch(
    traced_model, input_schema, output_schema
)

ensemble = Ensemble(torch_op, schema)
ens_config, node_configs = ensemble.export(".")

As you can see below, a matrix shape mismatch error is raised when trying to trace the PyTorch model:

Traceback (most recent call last):
  File "/opt/ml/code/train.py", line 899, in test_trainer_with_pretrained_embeddings
    traced_model = torch.jit.trace(model, model_input_dict, strict=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 794, in trace
    return trace_module(
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 1056, in trace_module
    module._c._create_method_from_trace(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/model/base.py", line 581, in forward
    head(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/model/base.py", line 382, in forward
    body_outputs = self.body(body_outputs, training=training, testing=testing, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
    return super().__call__(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/block/base.py", line 256, in forward
    input = module(input, training=training, testing=testing)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
    return super().__call__(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/tabular/base.py", line 392, in __call__
    outputs = super().__call__(inputs, *args, **kwargs)  # noqa
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/features/sequence.py", line 259, in forward
    outputs = self.projection_module(outputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
    return super().__call__(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/block/base.py", line 252, in forward
    input = module(input, **filtered_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
    return super().__call__(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/block/base.py", line 260, in forward
    input = module(input)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (220x128 and 136x64)

It seems that the torch.jit.trace() call never sees the pre-trained embeddings provided by the dataloader: the example inputs are read directly from the parquet file and do not pass through the EmbeddingOperator, so the concatenated features end up 128-dimensional while the projection layer was built for 136 inputs (128 plus the 8 projected pre-trained dimensions), which matches the mat1/mat2 shapes in the error above.
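
A minimal sketch along these lines illustrates the mismatch; the next(iter(data_loader)) unpacking below is an assumption (it supposes the loader yields (features, targets) pairs), everything else comes from the script above:

# Sketch: compare the columns the model expects with the keys actually passed to trace().
print(sorted(model.input_schema.column_names))  # should list the pre-trained embedding column
print(sorted(model_input_dict.keys()))          # built from the raw parquet, so that column is missing

# Hypothetical workaround: trace with a batch produced by the MerlinDataLoader itself,
# since that loader already applies the EmbeddingOperator
# (assumes the loader yields (features, targets) pairs).
features, _ = next(iter(data_loader))
traced_model = torch.jit.trace(model, features, strict=True)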

Do you have any suggestions on how to deploy a Transformers4rec model with pre-trained embeddings on Triton Inference Server?

Thanks for your amazing work!
