
Adapting the script classification_with_grn_and_vsn to be Backend-Agnostic #2023

Merged
6 commits merged on Jan 17, 2025

Conversation

Humbulani1234
Contributor

This PR adapts the script classification_with_grn_and_vsn.py from the structured_data examples to run on the JAX and PyTorch backends as well, i.e., it makes the script backend-agnostic.

Approach:

  • Modified the model architecture: removed the preprocessing layers from the model and applied them in the tf.data.Dataset pipeline instead. Refer to the Keras documentation for more.
  • Generated the .ipynb and .md files
  • Matched the 95% accuracy achieved by the original model
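A minimal sketch of the idea in the first bullet: categorical values are turned into integer indices in the data pipeline before they reach the model (in the actual script this is done with Keras preprocessing layers such as `StringLookup` applied through `tf.data.Dataset.map`), so the model itself contains no backend-specific preprocessing. The vocabulary and values below are illustrative assumptions, not the script's real data:

```python
# Hypothetical vocabulary for one categorical column; the real script builds
# vocabularies from the CSV data and uses Keras preprocessing layers instead.
VOCAB = ["private", "self-employed", "government"]

# Index 0 is reserved for out-of-vocabulary values, mirroring StringLookup's
# default behavior.
TOKEN_TO_INDEX = {token: i + 1 for i, token in enumerate(VOCAB)}

def encode(value):
    """Map a raw categorical value to an integer index before it reaches the model."""
    return TOKEN_TO_INDEX.get(value, 0)

batch = ["private", "government", "unknown"]
encoded = [encode(v) for v in batch]  # [1, 3, 0]
```

Because the encoding happens in the data pipeline, the model only ever sees integer tensors, which every backend can consume.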

Contributor

@hertschuh hertschuh left a comment


@Humbulani1234 ,

Thank you for the PR!

Comment on lines 185 to 198
"""
Clean the directory for the downloaded files except the .tar.gz file and
also remove the empty directories
"""

subprocess.run(
f'find {extracted_path} -type f ! -name "*.tar.gz" -exec rm -f {{}} +',
shell=True,
check=True,
)
subprocess.run(
f"find {extracted_path} -type d -empty -exec rmdir {{}} +", shell=True, check=True
)

Contributor

This could apply to any dataset, but I find this an unnecessary distraction. Can you remove?

Contributor Author

Removed

Comment on lines 450 to 494
# The reason for the per-backend calculation is that I couldn't find an
# equivalent backend-agnostic Keras operation. In the following case there's
# a keras.ops.matmul, but it was returning errors. I could have used the
# TensorFlow matmul for all backends, but due to jax jit tracing it results
# in an error.
def matmul_dependent_on_backend(tensor_1, tensor_2):
    """
    Function for executing matmul for each backend.
    """
    # jax backend
    if keras.backend.backend() == "jax":
        import jax.numpy as jnp

        result = jnp.sum(tensor_1 * tensor_2, axis=1)
    elif keras.backend.backend() == "torch":
        result = torch.sum(tensor_1 * tensor_2, dim=1)
    # tensorflow backend
    elif keras.backend.backend() == "tensorflow":
        result = keras.ops.squeeze(
            tf.matmul(tensor_1, tensor_2, transpose_a=True), axis=1
        )
    # unsupported backend exception
    else:
        raise ValueError(
            "Unsupported backend: {}".format(keras.backend.backend())
        )
    return result

# jax backend
if keras.backend.backend() == "jax":
    # These repetitive imports are intentional, to emphasize the backend
    # separation
    import jax.numpy as jnp

    result_jax = matmul_dependent_on_backend(v, x)
    return result_jax
# torch backend
if keras.backend.backend() == "torch":
    import torch

    result_torch = matmul_dependent_on_backend(v, x)
    return result_torch
# tensorflow backend
if keras.backend.backend() == "tensorflow":
    import tensorflow as tf

    result_tf = keras.ops.squeeze(tf.matmul(v, x, transpose_a=True), axis=1)
    return result_tf
Contributor

This definitely should not be needed.

What is the issue with keras.ops.squeeze(keras.ops.matmul(keras.ops.transpose(v), x), axis=1)?

Contributor Author

After careful thought, I've made it work. I also struggled with the Keras docstring of the op keras.ops.transpose; I don't think `axes` is explicit about the permutations. I had to read the TensorFlow docs to get a clear picture. Nonetheless, it is resolved.
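For the record, the reviewer's suggestion works because, for each example, `squeeze(matmul(transpose(v), x))` and `sum(v * x, axis=...)` compute the same weighted sum over the feature axis, which is why the backend-specific branches were never needed. A minimal pure-Python sketch of that identity (the shapes `v: (n, 1)` and `x: (n, d)` are assumptions for illustration, not the script's actual tensors):

```python
def matmul_then_squeeze(v, x):
    """transpose(v) is (1, n); matmul with x of shape (n, d) gives (1, d); squeeze -> (d,)."""
    n, d = len(x), len(x[0])
    vt = [v[i][0] for i in range(n)]  # transpose of an (n, 1) column vector
    return [sum(vt[i] * x[i][j] for i in range(n)) for j in range(d)]

def broadcast_sum(v, x):
    """v (n, 1) * x (n, d) broadcasts to (n, d); summing over axis 0 gives (d,)."""
    n, d = len(x), len(x[0])
    prod = [[v[i][0] * x[i][j] for j in range(d)] for i in range(n)]
    return [sum(prod[i][j] for i in range(n)) for j in range(d)]

v = [[1.0], [2.0], [3.0]]
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
assert matmul_then_squeeze(v, x) == broadcast_sum(v, x) == [22.0, 28.0]
```

With batched tensors the same identity holds per example, which is what the single `keras.ops` expression exploits across all three backends.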

Comment on lines +496 to +498
# to remove the build warnings
def build(self):
    self.built = True
Contributor

Thanks!

Contributor Author

Added more of this.

@@ -415,7 +520,7 @@ def create_model(encoding_size):
 learning_rate = 0.001
 dropout_rate = 0.15
 batch_size = 265
-num_epochs = 20
+num_epochs = 1  # maybe adjusted to a desired value
Contributor

Please revert after you're done testing.

Contributor Author

Reverted to 20, but left the comment.

@@ -108,13 +109,37 @@
     "income_level",
 ]

-data_url = "https://archive.ics.uci.edu/static/public/20/census+income.zip"
+data_url = "https://archive.ics.uci.edu/static/public/117/census+income+kdd.zip"
Contributor

Why the change in dataset?

It seems like the original dataset was easier to handle.

Contributor Author

The dataset description in the script says it must have 41 input features: 7 numerical features and 34 categorical features. The original dataset only had 14 features, and its target variable was <=50K or >50K, whereas in the script it is -50000 or 50000+.

@Humbulani1234
Contributor Author

PR sent addressing the comments.

@@ -415,7 +471,7 @@ def create_model(encoding_size):
 learning_rate = 0.001
 dropout_rate = 0.15
 batch_size = 265
-num_epochs = 20
+num_epochs = 1  # may be adjusted to a desired value
Contributor

Was this a different one or was it not actually reverted? (Please change back to 20)

Contributor Author

Changed. It was an oversight.

@Humbulani1234
Contributor Author

PR sent. Together with the .md and .ipynb files.

Contributor

@hertschuh hertschuh left a comment

LGTM. Thank you for the port!

@hertschuh hertschuh merged commit fcf47ee into keras-team:master Jan 17, 2025
3 checks passed