[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)

* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_*

* fix double tree_util
Sanchit Gandhi 2022-09-09 14:18:56 +01:00 committed by GitHub
parent 22f7218560
commit e6f221c8d4
17 changed files with 49 additions and 49 deletions
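
For context: `jax.tree_map`, `jax.tree_flatten`, and the other top-level `jax.tree_*` functions are deprecated aliases of their `jax.tree_util` counterparts, so this commit is a pure rename with no behavioural change. A minimal, illustrative sketch of the substitution (not taken from this diff; the `params` pytree below is a made-up example):

import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}  # toy pytree for illustration

# before (deprecated alias): shapes = jax.tree_map(lambda x: x.shape, params)
# after (same semantics, non-deprecated path):
shapes = jax.tree_util.tree_map(lambda x: x.shape, params)  # {'b': (2,), 'w': (2, 2)}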

View File

@@ -1011,7 +1011,7 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
-params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir), params=params)
tokenizer.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir))
if training_args.push_to_hub:
@@ -1064,7 +1064,7 @@ def main():
if metrics:
# normalize metrics
metrics = get_metrics(metrics)
-metrics = jax.tree_map(jnp.mean, metrics)
+metrics = jax.tree_util.tree_map(jnp.mean, metrics)
# compute ROUGE metrics
generations = []

View File

@@ -781,7 +781,7 @@ def main():
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
try:
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
@@ -824,7 +824,7 @@ def main():
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)
try:
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])

View File

@@ -827,9 +827,9 @@ def main():
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer")
-eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
# Update progress bar
epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
@@ -841,7 +841,7 @@ def main():
if cur_step % training_args.save_steps == 0 and cur_step > 0:
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
-params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
@@ -867,9 +867,9 @@ def main():
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer")
-eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
try:
perplexity = math.exp(eval_metrics["loss"])

View File

@@ -940,7 +940,7 @@ def main():
# get eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# Update progress bar
epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")
@@ -952,7 +952,7 @@ def main():
if cur_step % training_args.save_steps == 0 and cur_step > 0:
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
-params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
@@ -978,7 +978,7 @@ def main():
# get eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}

View File

@@ -902,7 +902,7 @@ def main():
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# compute ROUGE metrics
rouge_desc = ""
@@ -923,7 +923,7 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
-params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
@@ -957,7 +957,7 @@ def main():
# normalize prediction metrics
pred_metrics = get_metrics(pred_metrics)
-pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
+pred_metrics = jax.tree_util.tree_map(jnp.mean, pred_metrics)
# compute ROUGE metrics
rouge_desc = ""

View File

@@ -542,7 +542,7 @@ def main():
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# Print metrics and update progress bar
eval_step_progress_bar.close()
@@ -560,7 +560,7 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
-params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)

View File

@@ -104,7 +104,7 @@ class DataCollator:
def __call__(self, batch):
batch = self.collate_fn(batch)
-batch = jax.tree_map(shard, batch)
+batch = jax.tree_util.tree_map(shard, batch)
return batch
def collate_fn(self, features):

View File

@@ -608,9 +608,9 @@ if __name__ == "__main__":
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer")
-eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
# Update progress bar
steps.desc = (
@@ -624,7 +624,7 @@ if __name__ == "__main__":
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
-params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(
training_args.output_dir,
params=params,

View File

@@ -551,7 +551,7 @@ def main():
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# Print metrics and update progress bar
eval_step_progress_bar.close()

View File

@@ -481,7 +481,7 @@ def main():
param_spec = set_partitions(unfreeze(model.params))
# Get the PyTree for opt_state, we don't actually initialize the opt_state yet.
-params_shapes = jax.tree_map(lambda x: x.shape, model.params)
+params_shapes = jax.tree_util.tree_map(lambda x: x.shape, model.params)
state_shapes = jax.eval_shape(get_initial_state, params_shapes)
# get PartitionSpec for opt_state, this is very specific to adamw
@@ -492,7 +492,7 @@ def main():
return param_spec
return None
-opt_state_spec, param_spec = jax.tree_map(
+opt_state_spec, param_spec = jax.tree_util.tree_map(
get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
)
@@ -506,7 +506,7 @@ def main():
# hack: move the inital params to CPU to free up device memory
# TODO: allow loading weights on CPU in pre-trained model
-model.params = jax.tree_map(lambda x: np.asarray(x), model.params)
+model.params = jax.tree_util.tree_map(lambda x: np.asarray(x), model.params)
# mesh defination
mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count())
@@ -636,7 +636,7 @@ def main():
# normalize eval metrics
eval_metrics = stack_forest(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
try:
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])

View File

@@ -591,7 +591,7 @@ def main():
# get eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# Update progress bar
epochs.write(
@@ -606,7 +606,7 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
-params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub)

View File

@@ -674,9 +674,9 @@ if __name__ == "__main__":
eval_metrics.append(metrics)
eval_metrics_np = get_metrics(eval_metrics)
-eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np)
+eval_metrics_np = jax.tree_util.tree_map(jnp.sum, eval_metrics_np)
eval_normalizer = eval_metrics_np.pop("normalizer")
-eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
+eval_summary = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
# Update progress bar
epochs.desc = (

View File

@@ -699,7 +699,7 @@ class FlaxGenerationMixin:
else:
return tensor[batch_indices, beam_indices]
-return jax.tree_map(gather_fn, nested)
+return jax.tree_util.tree_map(gather_fn, nested)
# init values
max_length = max_length if max_length is not None else self.config.max_length
@@ -788,7 +788,7 @@ class FlaxGenerationMixin:
model_outputs = model(input_token, params=params, **state.model_kwargs)
logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
-cache = jax.tree_map(
+cache = jax.tree_util.tree_map(
lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
)
@@ -874,7 +874,7 @@ class FlaxGenerationMixin:
# With these, gather the top k beam-associated caches.
next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
-model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
+model_outputs["past_key_values"] = jax.tree_util.tree_map(lambda x: flatten_beam_dim(x), next_cache)
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
return BeamSearchState(

View File

@@ -253,7 +253,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
raise
# check if we have bf16 weights
-is_type_bf16 = flatten_dict(jax.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
+is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
if any(is_type_bf16):
# convert all weights to fp32 if the are bf16 since torch.from_numpy can-not handle bf16
# and bf16 is not fully supported in PT yet.
@@ -261,7 +261,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
"Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
"before loading those in PyTorch model."
)
-flax_state = jax.tree_map(
+flax_state = jax.tree_util.tree_map(
lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
)

View File

@@ -303,10 +303,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
return param
if mask is None:
-return jax.tree_map(conditional_cast, params)
+return jax.tree_util.tree_map(conditional_cast, params)
flat_params = flatten_dict(params)
-flat_mask, _ = jax.tree_flatten(mask)
+flat_mask, _ = jax.tree_util.tree_flatten(mask)
for masked, key in zip(flat_mask, flat_params.keys()):
if masked:
@@ -900,7 +900,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
)
# dictionary of key: dtypes for the model params
-param_dtypes = jax.tree_map(lambda x: x.dtype, state)
+param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state)
# extract keys of parameters not in jnp.float32
fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]

View File

@@ -90,7 +90,7 @@ def flatten_nested_dict(params, parent_key="", sep="/"):
def to_f32(params):
-return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, params)
+return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, params)
def copy_attn_layer(hf_attn_layer, pt_attn_layer):
@@ -398,7 +398,7 @@ if __name__ == "__main__":
# Load from checkpoint and convert params to float-32
variables = checkpoints.restore_checkpoint(args.owlvit_checkpoint, target=None)["optimizer"]["target"]
-flax_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, variables)
+flax_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, variables)
del variables
# Convert CLIP backbone

View File

@@ -776,7 +776,7 @@ class FlaxModelTesterMixin:
for model_class in self.all_model_classes:
# check if all params are still in float32 when dtype of computation is half-precision
model = model_class(config, dtype=jnp.float16)
-types = jax.tree_map(lambda x: x.dtype, model.params)
+types = jax.tree_util.tree_map(lambda x: x.dtype, model.params)
types = flatten_dict(types)
for name, type_ in types.items():
@@ -790,7 +790,7 @@ class FlaxModelTesterMixin:
# cast all params to bf16
params = model.to_bf16(model.params)
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
# test if all params are in bf16
for name, type_ in types.items():
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
@@ -802,7 +802,7 @@ class FlaxModelTesterMixin:
mask = unflatten_dict(mask)
params = model.to_bf16(model.params, mask)
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
# test if all params are in bf16 except key
for name, type_ in types.items():
if name == key:
@@ -818,7 +818,7 @@ class FlaxModelTesterMixin:
# cast all params to fp16
params = model.to_fp16(model.params)
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
# test if all params are in fp16
for name, type_ in types.items():
self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
@@ -830,7 +830,7 @@ class FlaxModelTesterMixin:
mask = unflatten_dict(mask)
params = model.to_fp16(model.params, mask)
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
# test if all params are in fp16 except key
for name, type_ in types.items():
if name == key:
@@ -849,7 +849,7 @@ class FlaxModelTesterMixin:
params = model.to_fp32(params)
# test if all params are in fp32
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
for name, type_ in types.items():
self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
@@ -864,7 +864,7 @@ class FlaxModelTesterMixin:
params = model.to_fp32(params, mask)
# test if all params are in fp32 except key
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
for name, type_ in types.items():
if name == key:
self.assertEqual(type_, jnp.float16, msg=f"param {name} should be in fp16.")
@@ -884,7 +884,7 @@ class FlaxModelTesterMixin:
# load the weights again and check if they are still in fp16
model = model_class.from_pretrained(tmpdirname)
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, model.params))
for name, type_ in types.items():
self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
@@ -901,7 +901,7 @@ class FlaxModelTesterMixin:
# load the weights again and check if they are still in fp16
model = model_class.from_pretrained(tmpdirname)
-types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params))
+types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, model.params))
for name, type_ in types.items():
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")