Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-02 03:01:07 +06:00
[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)
* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_*
* fix double tree_util
This commit is contained in:
parent 22f7218560
commit e6f221c8d4
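Context for the change: the top-level `jax.tree_*` aliases (e.g. `jax.tree_map`) are deprecated in favour of the equivalent functions under `jax.tree_util`, so the substitution applied throughout this commit is mechanical and behaviour-preserving. A minimal sketch of the pattern; the nested `params` dict here is illustrative, not taken from the repository:

import jax
import jax.numpy as jnp

params = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)}}

# old, deprecated spelling:
# shapes = jax.tree_map(lambda x: x.shape, params)

# new spelling used throughout this commit; same result, same tree structure:
shapes = jax.tree_util.tree_map(lambda x: x.shape, params)
print(shapes)  # {'dense': {'bias': (2,), 'kernel': (2, 2)}}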
@@ -1011,7 +1011,7 @@ def main():

 # save checkpoint after each epoch and push checkpoint to the hub
 if jax.process_index() == 0:
-    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+    params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
     model.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir), params=params)
     tokenizer.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir))
     if training_args.push_to_hub:
@@ -1064,7 +1064,7 @@ def main():
 if metrics:
     # normalize metrics
     metrics = get_metrics(metrics)
-    metrics = jax.tree_map(jnp.mean, metrics)
+    metrics = jax.tree_util.tree_map(jnp.mean, metrics)

     # compute ROUGE metrics
     generations = []
@@ -781,7 +781,7 @@ def main():

 # normalize eval metrics
 eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

 try:
     eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
@@ -824,7 +824,7 @@ def main():

 # normalize eval metrics
 eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)

 try:
     eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
@@ -827,9 +827,9 @@ def main():

 # normalize eval metrics
 eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
 eval_normalizer = eval_metrics.pop("normalizer")
-eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)

 # Update progress bar
 epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
@@ -841,7 +841,7 @@ def main():
 if cur_step % training_args.save_steps == 0 and cur_step > 0:
     # save checkpoint after each epoch and push checkpoint to the hub
     if jax.process_index() == 0:
-        params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+        params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
         model.save_pretrained(training_args.output_dir, params=params)
         tokenizer.save_pretrained(training_args.output_dir)
         if training_args.push_to_hub:
@@ -867,9 +867,9 @@ def main():

 # normalize eval metrics
 eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
 eval_normalizer = eval_metrics.pop("normalizer")
-eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)

 try:
     perplexity = math.exp(eval_metrics["loss"])
@@ -940,7 +940,7 @@ def main():

 # get eval metrics
 eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

 # Update progress bar
 epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")
@@ -952,7 +952,7 @@ def main():
 if cur_step % training_args.save_steps == 0 and cur_step > 0:
     # save checkpoint after each epoch and push checkpoint to the hub
     if jax.process_index() == 0:
-        params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+        params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
         model.save_pretrained(training_args.output_dir, params=params)
         tokenizer.save_pretrained(training_args.output_dir)
         if training_args.push_to_hub:
@@ -978,7 +978,7 @@ def main():

 # get eval metrics
 eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)

 if jax.process_index() == 0:
     eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
@@ -902,7 +902,7 @@ def main():

 # normalize eval metrics
 eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

 # compute ROUGE metrics
 rouge_desc = ""
@@ -923,7 +923,7 @@ def main():

 # save checkpoint after each epoch and push checkpoint to the hub
 if jax.process_index() == 0:
-    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+    params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
     model.save_pretrained(training_args.output_dir, params=params)
     tokenizer.save_pretrained(training_args.output_dir)
     if training_args.push_to_hub:
@@ -957,7 +957,7 @@ def main():

 # normalize prediction metrics
 pred_metrics = get_metrics(pred_metrics)
-pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
+pred_metrics = jax.tree_util.tree_map(jnp.mean, pred_metrics)

 # compute ROUGE metrics
 rouge_desc = ""
@@ -542,7 +542,7 @@ def main():

 # normalize eval metrics
 eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

 # Print metrics and update progress bar
 eval_step_progress_bar.close()
@@ -560,7 +560,7 @@ def main():

 # save checkpoint after each epoch and push checkpoint to the hub
 if jax.process_index() == 0:
-    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+    params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
     model.save_pretrained(training_args.output_dir, params=params)
     if training_args.push_to_hub:
         repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
@@ -104,7 +104,7 @@ class DataCollator:

     def __call__(self, batch):
         batch = self.collate_fn(batch)
-        batch = jax.tree_map(shard, batch)
+        batch = jax.tree_util.tree_map(shard, batch)
         return batch

     def collate_fn(self, features):
@@ -608,9 +608,9 @@ if __name__ == "__main__":

 # normalize eval metrics
 eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
 eval_normalizer = eval_metrics.pop("normalizer")
-eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)

 # Update progress bar
 steps.desc = (
@@ -624,7 +624,7 @@ if __name__ == "__main__":

 # save checkpoint after each epoch and push checkpoint to the hub
 if jax.process_index() == 0:
-    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+    params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
     model.save_pretrained(
         training_args.output_dir,
         params=params,
@@ -551,7 +551,7 @@ def main():
 # normalize eval metrics
 eval_metrics = get_metrics(eval_metrics)

-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

 # Print metrics and update progress bar
 eval_step_progress_bar.close()
@@ -481,7 +481,7 @@ def main():
 param_spec = set_partitions(unfreeze(model.params))

 # Get the PyTree for opt_state, we don't actually initialize the opt_state yet.
-params_shapes = jax.tree_map(lambda x: x.shape, model.params)
+params_shapes = jax.tree_util.tree_map(lambda x: x.shape, model.params)
 state_shapes = jax.eval_shape(get_initial_state, params_shapes)

 # get PartitionSpec for opt_state, this is very specific to adamw
@@ -492,7 +492,7 @@ def main():
         return param_spec
     return None

-opt_state_spec, param_spec = jax.tree_map(
+opt_state_spec, param_spec = jax.tree_util.tree_map(
     get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
 )
@@ -506,7 +506,7 @@ def main():

 # hack: move the inital params to CPU to free up device memory
 # TODO: allow loading weights on CPU in pre-trained model
-model.params = jax.tree_map(lambda x: np.asarray(x), model.params)
+model.params = jax.tree_util.tree_map(lambda x: np.asarray(x), model.params)

 # mesh defination
 mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count())
@@ -636,7 +636,7 @@ def main():

 # normalize eval metrics
 eval_metrics = stack_forest(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

 try:
     eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
@@ -591,7 +591,7 @@ def main():

 # get eval metrics
 eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

 # Update progress bar
 epochs.write(
@@ -606,7 +606,7 @@ def main():

 # save checkpoint after each epoch and push checkpoint to the hub
 if jax.process_index() == 0:
-    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+    params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
     model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub)

@@ -674,9 +674,9 @@ if __name__ == "__main__":
 eval_metrics.append(metrics)

 eval_metrics_np = get_metrics(eval_metrics)
-eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np)
+eval_metrics_np = jax.tree_util.tree_map(jnp.sum, eval_metrics_np)
 eval_normalizer = eval_metrics_np.pop("normalizer")
-eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
+eval_summary = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)

 # Update progress bar
 epochs.desc = (
@@ -699,7 +699,7 @@ class FlaxGenerationMixin:
         else:
             return tensor[batch_indices, beam_indices]

-    return jax.tree_map(gather_fn, nested)
+    return jax.tree_util.tree_map(gather_fn, nested)

 # init values
 max_length = max_length if max_length is not None else self.config.max_length
@@ -788,7 +788,7 @@ class FlaxGenerationMixin:
 model_outputs = model(input_token, params=params, **state.model_kwargs)

 logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
-cache = jax.tree_map(
+cache = jax.tree_util.tree_map(
     lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
 )
@@ -874,7 +874,7 @@ class FlaxGenerationMixin:
 # With these, gather the top k beam-associated caches.
 next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
 next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
-model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
+model_outputs["past_key_values"] = jax.tree_util.tree_map(lambda x: flatten_beam_dim(x), next_cache)
 next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

 return BeamSearchState(
@@ -253,7 +253,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
     raise

 # check if we have bf16 weights
-is_type_bf16 = flatten_dict(jax.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
+is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
 if any(is_type_bf16):
     # convert all weights to fp32 if the are bf16 since torch.from_numpy can-not handle bf16
     # and bf16 is not fully supported in PT yet.
@@ -261,7 +261,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
     "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
     "before loading those in PyTorch model."
 )
-flax_state = jax.tree_map(
+flax_state = jax.tree_util.tree_map(
     lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
 )
@@ -303,10 +303,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         return param

     if mask is None:
-        return jax.tree_map(conditional_cast, params)
+        return jax.tree_util.tree_map(conditional_cast, params)

     flat_params = flatten_dict(params)
-    flat_mask, _ = jax.tree_flatten(mask)
+    flat_mask, _ = jax.tree_util.tree_flatten(mask)

     for masked, key in zip(flat_mask, flat_params.keys()):
         if masked:
@@ -900,7 +900,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
 )

 # dictionary of key: dtypes for the model params
-param_dtypes = jax.tree_map(lambda x: x.dtype, state)
+param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state)
 # extract keys of parameters not in jnp.float32
 fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
 bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
@@ -90,7 +90,7 @@ def flatten_nested_dict(params, parent_key="", sep="/"):


 def to_f32(params):
-    return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, params)
+    return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, params)


 def copy_attn_layer(hf_attn_layer, pt_attn_layer):
@@ -398,7 +398,7 @@ if __name__ == "__main__":

 # Load from checkpoint and convert params to float-32
 variables = checkpoints.restore_checkpoint(args.owlvit_checkpoint, target=None)["optimizer"]["target"]
-flax_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, variables)
+flax_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, variables)
 del variables

 # Convert CLIP backbone
@@ -776,7 +776,7 @@ class FlaxModelTesterMixin:
 for model_class in self.all_model_classes:
     # check if all params are still in float32 when dtype of computation is half-precision
     model = model_class(config, dtype=jnp.float16)
-    types = jax.tree_map(lambda x: x.dtype, model.params)
+    types = jax.tree_util.tree_map(lambda x: x.dtype, model.params)
     types = flatten_dict(types)

     for name, type_ in types.items():
@@ -790,7 +790,7 @@ class FlaxModelTesterMixin:

 # cast all params to bf16
 params = model.to_bf16(model.params)
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
 # test if all params are in bf16
 for name, type_ in types.items():
     self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
@@ -802,7 +802,7 @@ class FlaxModelTesterMixin:
 mask = unflatten_dict(mask)

 params = model.to_bf16(model.params, mask)
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
 # test if all params are in bf16 except key
 for name, type_ in types.items():
     if name == key:
@@ -818,7 +818,7 @@ class FlaxModelTesterMixin:

 # cast all params to fp16
 params = model.to_fp16(model.params)
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
 # test if all params are in fp16
 for name, type_ in types.items():
     self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
@@ -830,7 +830,7 @@ class FlaxModelTesterMixin:
 mask = unflatten_dict(mask)

 params = model.to_fp16(model.params, mask)
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
 # test if all params are in fp16 except key
 for name, type_ in types.items():
     if name == key:
@@ -849,7 +849,7 @@ class FlaxModelTesterMixin:
 params = model.to_fp32(params)

 # test if all params are in fp32
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
 for name, type_ in types.items():
     self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
@@ -864,7 +864,7 @@ class FlaxModelTesterMixin:
 params = model.to_fp32(params, mask)

 # test if all params are in fp32 except key
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
 for name, type_ in types.items():
     if name == key:
         self.assertEqual(type_, jnp.float16, msg=f"param {name} should be in fp16.")
@@ -884,7 +884,7 @@ class FlaxModelTesterMixin:

 # load the weights again and check if they are still in fp16
 model = model_class.from_pretrained(tmpdirname)
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, model.params))
 for name, type_ in types.items():
     self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
@@ -901,7 +901,7 @@ class FlaxModelTesterMixin:

 # load the weights again and check if they are still in fp16
 model = model_class.from_pretrained(tmpdirname)
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, model.params))
 for name, type_ in types.items():
     self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
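A note on the checkpoint-saving pattern that recurs in the hunks above: `jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))` keeps the first replica of every pmap-replicated parameter and copies it to host memory before `save_pretrained`. A minimal sketch under the assumption that each leaf carries a leading device axis; the `replicated_params` dict is illustrative, not from the repository:

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
# Illustrative stand-in for pmap-replicated parameters: every leaf has a
# leading axis of size n_dev holding one copy per device.
replicated_params = {"kernel": jnp.stack([jnp.ones((3, 3))] * n_dev)}

# Keep the first replica of each leaf and pull it back to host memory.
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], replicated_params))
print({k: v.shape for k, v in params.items()})  # {'kernel': (3, 3)}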