
Commit

Found bug (had to do with max words)
smsharma committed Feb 13, 2024
1 parent e6fa6f6 commit 461c7d9
Showing 4 changed files with 387 additions and 169 deletions.
4 changes: 2 additions & 2 deletions configs/base.py
@@ -32,7 +32,7 @@ def get_config():
data.augment_subsample_text = False
data.max_length_words = 77
data.tfrecords_dir = "tfrecords_v5"
data.caption_type = "abstract" # "abstract" or "summary"
data.caption_type = "summary" # "abstract" or "summary"
data.shuffle_within_batch = False
data.data_dir = "/n/holyscratch01/iaifi_lab/smsharma/hubble_data/"

@@ -51,7 +51,7 @@ def get_config():
training.ckpt_best_metric_best_mode = "min" # "max" or "min"
training.ckpt_keep_top_n = 3 # Save the top `ckpt_keep_top_n` checkpoints based on `ckpt_best_metric`
training.load_ckpt = True
training.ckpt_run_name = "ancient-pine-88"
training.ckpt_run_name = "vermilion-lantern-101"


# Sum1 options
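The two flipped fields above control which caption source the data pipeline uses (data.caption_type) and which run's checkpoint is restored (training.ckpt_run_name). As a rough illustration of the first one, here is a minimal sketch of a selector that branches on caption_type and clips to data.max_length_words; the helper and the record layout are hypothetical and only mirror how these config fields appear to be used, not code taken from the repo:

from ml_collections import config_dict

def select_caption(config, record):
    # Hypothetical helper: pick the caption source named by config.data.caption_type.
    if config.data.caption_type == "summary":
        caption = record["summary"]   # short generated summary
    else:
        caption = record["abstract"]  # full proposal abstract
    # Clip to the configured word budget (the config sets data.max_length_words = 77).
    words = caption.split()
    return " ".join(words[: config.data.max_length_words])

cfg = config_dict.ConfigDict({"data": {"caption_type": "summary", "max_length_words": 77}})
print(select_caption(cfg, {"summary": "A short generated summary.",
                           "abstract": "A much longer proposal abstract."}))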
83 changes: 53 additions & 30 deletions notebooks/xx_debugging_eval.ipynb
@@ -10,9 +10,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024-02-12 18:01:07.414616: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-02-12 18:01:07.414658: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-02-12 18:01:07.415937: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
"2024-02-13 12:01:17.992001: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-02-13 12:01:17.992050: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-02-13 12:01:17.993373: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
]
}
],
@@ -81,7 +81,8 @@
"import optax \n",
"from flax.training import train_state\n",
"import flax\n",
"import orbax\n",
"import flax.training.orbax_utils\n",
"import orbax.checkpoint\n",
"\n",
"replicate = flax.jax_utils.replicate\n",
"unreplicate = flax.jax_utils.unreplicate\n",
@@ -108,7 +109,7 @@
"from ml_collections.config_dict import ConfigDict\n",
"\n",
"logging_dir = '../logging/proposals/'\n",
"run_name = 'ancient-pine-88'\n",
"run_name = 'vermilion-lantern-101'\n",
"\n",
"config_file = \"{}/{}/config.yaml\".format(logging_dir, run_name)\n",
"\n",
@@ -120,7 +121,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 6,
"id": "37674e0a",
"metadata": {},
"outputs": [],
@@ -150,7 +151,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 7,
"id": "c783dec2-b2f3-454c-8782-576da5bf6761",
"metadata": {},
"outputs": [],
@@ -164,12 +165,12 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 8,
"id": "73e60d7d-6854-4c2a-8d61-28f4b7fd1bbd",
"metadata": {},
"outputs": [],
"source": [
"run_labels = ['ancient-pine-88',]\n",
"run_labels = ['vermilion-lantern-101',]\n",
"run_legends = [\"Fine-tune (abstracts)\"]\n",
"\n",
"data_type = [\"abstract\"]\n",
@@ -178,7 +179,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 9,
"id": "976345c4-c4a8-4d0e-9679-731b65312b29",
"metadata": {},
"outputs": [],
@@ -217,7 +218,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 12,
"id": "583d4643-8c1e-4749-be73-87422986c232",
"metadata": {},
"outputs": [],
@@ -228,7 +229,7 @@
"from models.losses import softmax_loss\n",
"\n",
"@partial(jax.pmap, axis_name=\"batch\")\n",
"def get_features(state, input_ids, pixel_values, attention_mask):\n",
"def eval_step(state, input_ids, pixel_values, attention_mask):\n",
"\n",
" # captions_feat = model.get_text_features(input_ids, attention_mask, params=state.params)\n",
" # images_feat = model.get_image_features(pixel_values, params=state.params)\n",
@@ -248,6 +249,9 @@
" \n",
" return metrics\n",
" \n",
"# Rotation angles in rad\n",
"rot_angles_90 = np.array([0.0, np.pi / 2, np.pi, 3 * np.pi / 2])\n",
"\n",
"def get_features_ds(state, ds, truncate=False):\n",
"\n",
" batches = iter(ds)\n",
@@ -260,23 +264,17 @@
"\n",
" retrieval_eval_metrics = []\n",
"\n",
" rng_eval = jax.random.PRNGKey(42)\n",
"\n",
" for (images, captions) in tqdm(batches, total=total_batches):\n",
" if current_batch == total_batches - 1:\n",
" break\n",
" \n",
" images = np.array(images)\n",
"\n",
" if truncate:\n",
" captions = process_truncate_captions(captions, jax.random.PRNGKey(onp.random.randint(99999)), max_length_words=config.data.max_length_words)\n",
" else:\n",
" captions = captions.numpy().tolist()\n",
" captions = [c.decode('utf-8') for c in captions]\n",
"\n",
" rng_eval = jax.random.PRNGKey(onp.random.randint(99999))\n",
" \n",
" # Rotations\n",
" rng_eval, _ = jax.random.split(rng_eval)\n",
" rotation_angles = jax.random.uniform(rng_eval, shape=(images.shape[0],)) * 2 * np.pi # Angles in radians\n",
" rotation_angles = jax.random.choice(rng_eval, rot_angles_90, shape=(images.shape[0],)) # Angles in radians\n",
" images = jax.vmap(partial(rotate, mode='constant', cval=1.))(images, rotation_angles)\n",
" \n",
" # Flips\n",
@@ -288,12 +286,20 @@
"\n",
" images = jax.vmap(random_crop, in_axes=(None,0,None))(rng_eval, images, (model.config.vision_config.image_size, model.config.vision_config.image_size, 3))\n",
"\n",
" input = processor(text=captions, images=(images * 255.).astype(np.uint8), return_tensors=\"np\", padding=\"max_length\", truncation=True, max_length=77)\n",
" if truncate:\n",
" captions = process_truncate_captions(captions, rng_eval, max_length_words=config.data.max_length_words)\n",
" else:\n",
" captions = captions.numpy().tolist()\n",
" captions = [c.decode('utf-8') for c in captions]\n",
"\n",
" inputs = processor(text=captions, images=(images * 255.).astype(np.uint8), return_tensors=\"np\", padding=\"max_length\", truncation=True, max_length=77)\n",
"\n",
" batch = inputs.data\n",
" \n",
" batch = jax.tree_map(lambda x: np.split(x, num_local_devices, axis=0), input.data)\n",
" batch = jax.tree_map(lambda x: np.split(x, num_local_devices, axis=0), batch)\n",
" batch = jax.tree_map(lambda x: np.array(x, dtype=np.float32), batch)\n",
"\n",
" metrics = get_features(replicate(state), np.array(batch[\"input_ids\"]), np.array(batch[\"pixel_values\"]), np.array(batch[\"attention_mask\"]))\n",
" metrics = eval_step(replicate(state), np.array(batch[\"input_ids\"]), np.array(batch[\"pixel_values\"]), np.array(batch[\"attention_mask\"]))\n",
"\n",
" retrieval_eval_metrics.append(metrics)\n",
" \n",
@@ -304,11 +310,29 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 13,
"id": "404f1ac0-3e64-4d6b-8c24-8a3ad8510453",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.\n",
" 97%|█████████▋| 30/31 [01:47<00:03, 3.59s/it]\n",
"100%|██████████| 1/1 [01:52<00:00, 112.96s/it]\n"
]
}
],
"source": [
"\n",
"accuracy_lists = []\n",
"for idx, run_name in enumerate(tqdm(run_labels[:])):\n",
"\n",
@@ -334,16 +358,15 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 15,
"id": "592da267-3405-44ed-89a2-25cf47f74d1b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fine-tune (abstracts) {'val/loss': 3.2589102, 'val/top_10_accuracy': 0.66666675, 'val/top_1_accuracy': 0.20033331, 'val/top_20_accuracy': 0.79533327, 'val/top_5_accuracy': 0.522}\n",
"Fine-tune (summaries) {'val/loss': 3.3698084, 'val/top_10_accuracy': 0.625, 'val/top_1_accuracy': 0.22099997, 'val/top_20_accuracy': 0.74966675, 'val/top_5_accuracy': 0.4993333}\n"
"Fine-tune (abstracts) {'val/loss': 3.26416, 'val/top_10_accuracy': 0.6536666, 'val/top_1_accuracy': 0.2053333, 'val/top_20_accuracy': 0.79233336, 'val/top_5_accuracy': 0.52033335}\n"
]
}
],
@@ -358,7 +381,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "56b36156",
"id": "e83efb0a",
"metadata": {},
"outputs": [],
"source": []
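Summarizing the notebook edits: the pmapped step is renamed from get_features to eval_step, a single rng_eval = jax.random.PRNGKey(42) is created once before the loop and split per batch, rotation angles are drawn from the four right angles in rot_angles_90 instead of a uniform distribution, and the caption truncation via process_truncate_captions now happens after the image augmentations using the same split rng_eval. Below is a self-contained sketch of that per-batch ordering; truncate_captions is a hypothetical stand-in for process_truncate_captions (whose implementation is not shown in this diff), and the augmentation step is reduced to drawing rotation angles so the snippet runs on its own:

import numpy as np
import jax

def truncate_captions(captions, rng, max_length_words=77):
    # Hypothetical stand-in: keep a random window of at most max_length_words words.
    out = []
    for caption in captions:
        words = caption.split()
        if len(words) <= max_length_words:
            out.append(caption)
            continue
        rng, sub = jax.random.split(rng)
        start = int(jax.random.randint(sub, (), 0, len(words) - max_length_words + 1))
        out.append(" ".join(words[start:start + max_length_words]))
    return out

rot_angles_90 = np.array([0.0, np.pi / 2, np.pi, 3 * np.pi / 2])  # right-angle rotations only
rng_eval = jax.random.PRNGKey(42)  # one fixed seed for the whole eval pass

# Toy batch standing in for (images, captions) from the tf.data iterator.
images = np.ones((4, 224, 224, 3), dtype=np.float32)
captions = ["word " * 200, "token " * 50, "galaxy " * 120, "survey " * 10]

# 1) Augment the images first (rotation shown; flips and crops follow the same pattern).
rng_eval, sub = jax.random.split(rng_eval)
angles = jax.random.choice(sub, rot_angles_90, shape=(images.shape[0],))

# 2) Only afterwards truncate the captions to the word budget, matching the new ordering
#    in the notebook (the old loop truncated at the top and re-seeded the RNG every batch).
captions = truncate_captions(captions, rng_eval, max_length_words=77)
print(np.asarray(angles), [len(c.split()) for c in captions])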
12 changes: 6 additions & 6 deletions scripts/submit_train.sh
@@ -28,14 +28,14 @@ alias jupyter=/n/holystore01/LABS/iaifi_lab/Users/smsharma/envs/$ENV/bin/jupyter
cd /n/holystore01/LABS/iaifi_lab/Users/smsharma/multimodal-data/

# # Core runs
# /n/holystore01/LABS/iaifi_lab/Users/smsharma/envs/$ENV/bin/python -u train.py --config ./configs/base.py # Base config, with summary captions
# /n/holystore01/LABS/iaifi_lab/Users/smsharma/envs/$ENV/bin/python -u train.py --config ./configs/base.py --config.clip.transfer_head=True # Fine-tune just head
# /n/holystore01/LABS/iaifi_lab/Users/smsharma/envs/$ENV/bin/python -u train.py --config ./configs/base.py --config.data.caption_type="summary" # Base config, with summary captions
# /n/holystore01/LABS/iaifi_lab/Users/smsharma/envs/$ENV/bin/python -u train.py --config ./configs/base.py --config.data.caption_type="abstract" --config.data.augment_subsample_text=True # Full abstract

/n/holystore01/LABS/iaifi_lab/Users/smsharma/envs/$ENV/bin/python -u train.py --config ./configs/base.py --config.data.caption_type="abstract" --config.data.augment_subsample_text=True # Full abstract
/n/holystore01/LABS/iaifi_lab/Users/smsharma/envs/$ENV/bin/python -u train.py --config ./configs/base.py --config.clip.random_init_text=True --config.clip.random_init_vision=True # From scratch
# /n/holystore01/LABS/iaifi_lab/Users/smsharma/envs/$ENV/bin/python -u train.py --config ./configs/base.py --config.clip.transfer_head=True # Fine-tune just head
# /n/holystore01/LABS/iaifi_lab/Users/smsharma/envs/$ENV/bin/python -u train.py --config ./configs/base.py --config.clip.random_init_text=True --config.clip.random_init_vision=True # From scratch

# /n/holystore01/LABS/iaifi_lab/Users/smsharma/envs/$ENV/bin/python -u train.py --config ./configs/base.py --config.data.shuffle_within_batch=True # Shuffle within batch
# /n/holystore01/LABS/iaifi_lab/Users/smsharma/envs/$ENV/bin/python -u train.py --config ./configs/base.py --config.sum1.use_sum1=True --config.sum1.sum1_filename="summary_sum1_v3" # Base config, with summary captions
/n/holystore01/LABS/iaifi_lab/Users/smsharma/envs/$ENV/bin/python -u train.py --config ./configs/base.py --config.data.shuffle_within_batch=True # Shuffle within batch
/n/holystore01/LABS/iaifi_lab/Users/smsharma/envs/$ENV/bin/python -u train.py --config ./configs/base.py --config.sum1.use_sum1=True --config.sum1.sum1_filename="summary_sum1_v3" # Base config, with summary captions

# # Additional runs
# /n/holystore01/LABS/iaifi_lab/Users/smsharma/envs/$ENV/bin/python -u train.py --config ./configs/base.py --config.optim.learning_rate=1e-6 # Lower LR
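The submit script now launches the shuffle-within-batch and Sum1 runs and comments out the full-abstract and from-scratch runs. The --config.<field>=<value> overrides follow the ml_collections config-flags convention; a minimal sketch of how a train.py entry point typically exposes them is below (the repo's actual train.py is not shown in this diff, so treat the wiring as an assumption):

# Hypothetical minimal entry point wiring the --config overrides used in submit_train.sh.
from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", default="./configs/base.py")

def main(_):
    config = FLAGS.config
    # Command-line overrides such as --config.data.shuffle_within_batch=True or
    # --config.sum1.use_sum1=True are applied before main() runs.
    print(config.data.caption_type, config.data.shuffle_within_batch)

if __name__ == "__main__":
    app.run(main)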
