Fix discrete/continuous mixup bug; move eval rng inside loop for consistent eval rng each step
smsharma committed Feb 12, 2024
1 parent 161ed9c commit e6fa6f6
Showing 1 changed file: train.py (117 changes: 58 additions, 59 deletions).
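The core of the fix is the rotation-augmentation dispatch: the "continuous" and "discrete" branches had their samplers swapped, so continuous mode drew from the fixed 90-degree angles while discrete mode drew uniform angles. Below is a minimal standalone sketch of the corrected logic, assuming rot_angles_90 holds the four multiples of 90 degrees in radians and using a hypothetical batch size; it illustrates the pattern in the diff, not the full train.py code.

    import jax
    import numpy as np

    # Assumed values for illustration; in train.py these come from the config and data pipeline.
    rng = jax.random.PRNGKey(0)
    batch_size = 8
    rot_angles_90 = np.array([0.0, 0.5 * np.pi, np.pi, 1.5 * np.pi])  # assumed: 90-degree multiples, in radians

    def sample_rotation_angles(rng, augment_rotate_type, batch_size):
        # Continuous mode: any angle in [0, 2*pi); discrete mode: one of the 90-degree rotations.
        if augment_rotate_type == "continuous":
            return jax.random.uniform(rng, shape=(batch_size,)) * 2 * np.pi
        elif augment_rotate_type == "discrete":
            return jax.random.choice(rng, rot_angles_90, shape=(batch_size,))
        else:
            raise ValueError(f"Invalid augment_rotate_type: {augment_rotate_type}")

    angles = sample_rotation_angles(rng, "continuous", batch_size)  # shape (8,), values in [0, 2*pi)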
@@ -238,15 +238,14 @@ def train(config: ConfigDict, workdir: str = "./logging/") -> train_state.TrainS
     with trange(config.training.n_train_steps) as steps:
         for step in steps:
 
-            # Use same rng for eval every time
-            # NOTE: Should move this down inside the eval if
-            rng_eval = jax.random.PRNGKey(config.seed)
-
             # Eval portion
 
             # Evaluate before starting, hence no `and (step != 0)`
             if (step % config.training.eval_every_steps == 0) and (jax.process_index() == 0):
 
+                # Use same rng for eval every time
+                rng_eval = jax.random.PRNGKey(config.seed)
+
                 # Log step at which evaluating
                 logging.info(f"Evaluating at step {step}")
 
@@ -272,9 +271,9 @@ def train(config: ConfigDict, workdir: str = "./logging/") -> train_state.TrainS
                     # Rotations
                     rng_eval, _ = jax.random.split(rng_eval)
                     if config.data.augment_rotate_type == "continuous":
-                        rotation_angles = jax.random.choice(rng_eval, rot_angles_90, shape=(images.shape[0],)) # Angles in radians
-                    elif config.data.augment_rotate_type == "discrete":
                         rotation_angles = jax.random.uniform(rng_eval, shape=(images.shape[0],)) * 2 * np.pi # Angles in radians
+                    elif config.data.augment_rotate_type == "discrete":
+                        rotation_angles = jax.random.choice(rng_eval, rot_angles_90, shape=(images.shape[0],)) # Angles in radians
                     else:
                         raise ValueError(f"Invalid augment_rotate_type: {config.data.augment_rotate_type}")
                     images = jax.vmap(partial(rotate, mode='constant', cval=1.))(images, rotation_angles)
@@ -325,65 +324,65 @@ def serialize_metrics(metrics):
 
                 ckpt_mgr.save(step, state_ckpt, save_kwargs={'save_args': save_args}, metrics=summary)
 
-            # # Train portion
-
-            # rng, rng_aug = jax.random.split(rng)
-            # images, captions = next(batches)
-            # images = np.array(images)
-
-            # # Augment images through random rotations and flips
-            # if config.data.augment_rotate:
-
-            #     # Rotations
-            #     rng_aug, _ = jax.random.split(rng_aug)
-            #     if config.data.augment_rotate_type == "continuous":
-            #         rotation_angles = jax.random.choice(rng_aug, rot_angles_90, shape=(images.shape[0],)) # Angles in radians
-            #     elif config.data.augment_rotate_type == "discrete":
-            #         rotation_angles = jax.random.uniform(rng_aug, shape=(images.shape[0],)) * 2 * np.pi # Angles in radians
-            #     else:
-            #         raise ValueError(f"Invalid augment_rotate_type: {config.data.augment_rotate_type}")
-            #     images = jax.vmap(partial(rotate, mode='constant', cval=1.))(images, rotation_angles)
-
-            #     # Flips
-            #     rng_aug, _ = jax.random.split(rng_aug)
-            #     images = jax.vmap(partial(random_flip_up_down, key=rng_aug))(image=images)
-
-            #     rng_aug, _ = jax.random.split(rng_aug)
-            #     images = jax.vmap(partial(random_flip_left_right, key=rng_aug))(image=images)
-
-            # # Augment images through random crops
-            # # Otherwise, they'll be downsampled to the vision model's image size
-            # if config.data.augment_crop:
-            #     rng_aug, _ = jax.random.split(rng_aug)
-            #     images = jax.vmap(random_crop, in_axes=(None,0,None))(rng_aug, images, (model.config.vision_config.image_size, model.config.vision_config.image_size, 3))
-
-            # # NOTE: Image arrays should be ints in the range [0, 255] here
-            # captions = process_truncate_captions(captions, rng_aug, max_length_words=max_length_words, use_sum1=config.sum1.use_sum1, df_sum_merged=df_sum_merged)
-            # inputs = processor(text=captions, images=(images * 255.).astype(np.uint8), return_tensors="np", padding="max_length", truncation=True, max_length=model.config.text_config.max_length)
-            # batch = inputs.data
+            # Train portion
+
+            rng, rng_aug = jax.random.split(rng)
+            images, captions = next(batches)
+            images = np.array(images)
+
+            # Augment images through random rotations and flips
+            if config.data.augment_rotate:
+
+                # Rotations
+                rng_aug, _ = jax.random.split(rng_aug)
+                if config.data.augment_rotate_type == "continuous":
+                    rotation_angles = jax.random.uniform(rng_aug, shape=(images.shape[0],)) * 2 * np.pi # Angles in radians
+                elif config.data.augment_rotate_type == "discrete":
+                    rotation_angles = jax.random.choice(rng_aug, rot_angles_90, shape=(images.shape[0],)) # Angles in radians
+                else:
+                    raise ValueError(f"Invalid augment_rotate_type: {config.data.augment_rotate_type}")
+                images = jax.vmap(partial(rotate, mode='constant', cval=1.))(images, rotation_angles)
+
+                # Flips
+                rng_aug, _ = jax.random.split(rng_aug)
+                images = jax.vmap(partial(random_flip_up_down, key=rng_aug))(image=images)
+
+                rng_aug, _ = jax.random.split(rng_aug)
+                images = jax.vmap(partial(random_flip_left_right, key=rng_aug))(image=images)
+
+            # Augment images through random crops
+            # Otherwise, they'll be downsampled to the vision model's image size
+            if config.data.augment_crop:
+                rng_aug, _ = jax.random.split(rng_aug)
+                images = jax.vmap(random_crop, in_axes=(None,0,None))(rng_aug, images, (model.config.vision_config.image_size, model.config.vision_config.image_size, 3))
+
+            # NOTE: Image arrays should be ints in the range [0, 255] here
+            captions = process_truncate_captions(captions, rng_aug, max_length_words=max_length_words, use_sum1=config.sum1.use_sum1, df_sum_merged=df_sum_merged)
+            inputs = processor(text=captions, images=(images * 255.).astype(np.uint8), return_tensors="np", padding="max_length", truncation=True, max_length=model.config.text_config.max_length)
+            batch = inputs.data
 
-            # # Optionally shuffle "pixel_values" within batch
-            # if config.data.shuffle_within_batch:
-            #     batch["pixel_values"] = jax.random.permutation(rng, batch["pixel_values"], axis=0)
+            # Optionally shuffle "pixel_values" within batch
+            if config.data.shuffle_within_batch:
+                batch["pixel_values"] = jax.random.permutation(rng, batch["pixel_values"], axis=0)
 
-            # # Split batch across devices
-            # batch = jax.tree_map(lambda x: np.split(x, num_local_devices, axis=0), batch)
-            # batch = jax.tree_map(lambda x: np.array(x, dtype=dtype), batch)
+            # Split batch across devices
+            batch = jax.tree_map(lambda x: np.split(x, num_local_devices, axis=0), batch)
+            batch = jax.tree_map(lambda x: np.array(x, dtype=dtype), batch)
 
-            # pstate, metrics = train_step(pstate, np.array(batch["input_ids"]), np.array(batch["pixel_values"]), np.array(batch["attention_mask"]), config.training.loss_type)
-            # steps.set_postfix(val=unreplicate(metrics["loss"]))
-            # train_metrics.append(metrics)
+            pstate, metrics = train_step(pstate, np.array(batch["input_ids"]), np.array(batch["pixel_values"]), np.array(batch["attention_mask"]), config.training.loss_type)
+            steps.set_postfix(val=unreplicate(metrics["loss"]))
+            train_metrics.append(metrics)
 
-            # # Log periodically
-            # if (step % config.training.log_every_steps == 0) and (step != 0) and (jax.process_index() == 0):
-            #     train_metrics = common_utils.get_metrics(train_metrics)
-            #     summary = {f"train/{k}": v for k, v in jax.tree_map(lambda x: x.mean(), train_metrics).items()}
+            # Log periodically
+            if (step % config.training.log_every_steps == 0) and (step != 0) and (jax.process_index() == 0):
+                train_metrics = common_utils.get_metrics(train_metrics)
+                summary = {f"train/{k}": v for k, v in jax.tree_map(lambda x: x.mean(), train_metrics).items()}
 
-            #     writer.write_scalars(step, summary)
-            #     train_metrics = []
+                writer.write_scalars(step, summary)
+                train_metrics = []
 
-            #     if config.wandb.log_train:
-            #         wandb.log({"train/step": step, **summary})
+                if config.wandb.log_train:
+                    wandb.log({"train/step": step, **summary})
 
     logging.info("All done! Have a great day.")
 
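The other half of the change re-creates the eval PRNG key from the fixed seed inside the eval branch, so every evaluation uses the same rng sequence while the training rng keeps advancing. A minimal sketch of that pattern, with hypothetical stand-ins for the config values (not the actual train.py loop):

    import jax

    seed = 42          # stands in for config.seed
    eval_every = 100   # stands in for config.training.eval_every_steps
    n_steps = 300

    rng = jax.random.PRNGKey(seed)  # training rng, advances every step

    for step in range(n_steps):
        if step % eval_every == 0:
            # Re-seed inside the eval branch: every evaluation sees the same key,
            # so eval-time augmentations are reproducible across evaluations.
            rng_eval = jax.random.PRNGKey(seed)
            rng_eval, _ = jax.random.split(rng_eval)
            # ... run evaluation with rng_eval ...

        # Training rng keeps advancing, so train-time augmentations differ per step.
        rng, rng_aug = jax.random.split(rng)
        # ... run one training step with rng_aug ...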
