-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adapting the script classification_with_grn_and_vsn to be Backend-Agnostic #2023
Adapting the script classification_with_grn_and_vsn to be Backend-Agnostic #2023
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the PR!
""" | ||
Clean the directory for the downloaded files except the .tar.gz file and | ||
also remove the empty directories | ||
""" | ||
|
||
subprocess.run( | ||
f'find {extracted_path} -type f ! -name "*.tar.gz" -exec rm -f {{}} +', | ||
shell=True, | ||
check=True, | ||
) | ||
subprocess.run( | ||
f"find {extracted_path} -type d -empty -exec rmdir {{}} +", shell=True, check=True | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could apply to any dataset, but I find this an unnecessary distraction. Can you remove?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed
# The reason for each individual backend calculation is that I couldn't find | ||
# the equivalent keras operation that is backend-agnostic. In the following case there,s | ||
# a keras.ops.matmul but it was returning errors. I could have used the tensorflow matmul | ||
# for all backends, but due to jax jit tracing it results in an error. | ||
def matmul_dependent_on_backend(tensor_1, tensor_2): | ||
""" | ||
Function for executing matmul for each backend. | ||
""" | ||
# jax backend | ||
if keras.backend.backend() == "jax": | ||
import jax.numpy as jnp | ||
|
||
result = jnp.sum(tensor_1 * tensor_2, axis=1) | ||
elif keras.backend.backend() == "torch": | ||
result = torch.sum(tensor_1 * tensor_2, dim=1) | ||
# tensorflow backend | ||
elif keras.backend.backend() == "tensorflow": | ||
result = keras.ops.squeeze(tf.matmul(tensor_1, tensor_2, transpose_a=True), axis=1) | ||
# unsupported backend exception | ||
else: | ||
raise ValueError( | ||
"Unsupported backend: {}".format(keras.backend.backend()) | ||
) | ||
return result | ||
|
||
# jax backend | ||
if keras.backend.backend() == "jax": | ||
# This repetative imports are intentional to force the idea of backend | ||
# separation | ||
import jax.numpy as jnp | ||
|
||
result_jax = matmul_dependent_on_backend(v, x) | ||
return result_jax | ||
# torch backend | ||
if keras.backend.backend() == "torch": | ||
import torch | ||
|
||
result_torch = matmul_dependent_on_backend(v, x) | ||
return result_torch | ||
# tensorflow backend | ||
if keras.backend.backend() == "tensorflow": | ||
import tensorflow as tf | ||
|
||
result_tf = keras.ops.squeeze(tf.matmul(v, x, transpose_a=True), axis=1) | ||
return result_tf |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This definitely should not be needed.
What is the issue with keras.ops.squeeze(keras.ops.matmul(keras.ops.transpose(v), x), axis=1)
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After careful thought, I've made it to work. I also struggled with the Keras doc string of the op keras.transpose, I don't think axes is explicit about the permutations. I had to read tensorflow doc to have a clear picture. But, nonetheless, it is resolved.
# to remove the build warnings | ||
def build(self): | ||
self.built = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added more of this.
@@ -415,7 +520,7 @@ def create_model(encoding_size): | |||
learning_rate = 0.001 | |||
dropout_rate = 0.15 | |||
batch_size = 265 | |||
num_epochs = 20 | |||
num_epochs = 1 # maybe adjusted to a desired value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please revert after you're done testing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reverted to 20, but left the comment.
@@ -108,13 +109,37 @@ | |||
"income_level", | |||
] | |||
|
|||
data_url = "https://archive.ics.uci.edu/static/public/20/census+income.zip" | |||
data_url = "https://archive.ics.uci.edu/static/public/117/census+income+kdd.zip" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why the change in dataset?
It seems like the original dataset was easier to handle.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The dataset description in the script says it must have 41 input features: 7 numerical features and 34 categorical features.
. The original dataset only had 14 features
and its target variable was in <= or >=50k
, whereas in the script it is in -5000 or +5000
PR sent addressing the comments. |
@@ -415,7 +471,7 @@ def create_model(encoding_size): | |||
learning_rate = 0.001 | |||
dropout_rate = 0.15 | |||
batch_size = 265 | |||
num_epochs = 20 | |||
num_epochs = 1 # may be adjusted to a desired value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was this a different one or was it not actually reverted? (Please change back to 20)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed. It was an over-sight.
PR sent. Together with the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you for the port!
This PR adapts the script
classification_with_grn_and_vsn.py
fromstructured_data
examples tojax
andtorch
, i.e., make the script Backend-Agnostic.Approach:
Preprocessing
layers and use them with thetf.data.Datasets
. Refer to this Keras doc string for more..ipynb
and.md
files95%
accuracy in the original model