* accept custom device_mesh
* fix device_map
* assert that num_heads % tp_size == 0
* todo.
* ReplicateParallel
* handle tied weights
* handle dtensor in save_pretrained with safe_serialization
* tp test works
* doesnt work
* fix shard_and_distribute_module's rank should be local_rank
* tp=4 is correct
* dp+tp is broken
* todo allreduce with dtensors on another dim is annoying
* workaround to sync dp grads when using dtensors
* loading a checkpoint works
* wandb and compare losses with different tp/dp
* cleaning
* cleaning
* .
* .
* logs
* CP2 DP2 no mask works after commenting attn_mask and is_causal from scaled_dot_product_attention
* DP=2 TP=2 now works even with tied embeddings
* model.parameters() and model.module.parameters() are empty..
* reformat sanity_check_tensor_sync
* set atol=1e-4 for CP to pass
* try populate _parameters from named_modules
* refactors
TP2 DP2 works
CP2 DP2 works
* is_causal=True and pack sequences, no attn mask, and preshuffle dataset
* fix packing
* CP=4 doesn't work
* fix labels and position_ids for CP
* DP CP works with transformers 🥳🥳🥳
* refactor
* add example cp
* fixup
* revert sdpa changes
* example cleared
* add CP, DP to the mesh init
* nit
* clean
* use `ALL_PARALLEL_STYLES`
* style
* FSDP works
* log on 1 rank
* .
* fix?
* FSDP1 also has .parameters() bug
* reported gradnorm when using FSDP1 is wrong, but loss is correct so it's okay
* .
* style and fixup
* move stuff around
* fix tests
* style
* let's make it a check
* add missing licences
* warning should be an info
* tp plan should not be NONE
* test all
* god damn it
* test all
---------
Co-authored-by: nouamanetazi <nouamane98@gmail.com>
* accept custom device_mesh
* fix device_map
* assert that num_heads % tp_size == 0
* todo.
* ReplicateParallel
* handle tied weights
* handle dtensor in save_pretrained with safe_serialization
* tp test works
* doesnt work
* fix shard_and_distribute_module's rank should be local_rank
* tp=4 is correct
* dp+tp is broken
* todo allreduce with dtensors on another dim is annoying
* workaround to sync dp grads when using dtensors
* loading a checkpoint works
* wandb and compare losses with different tp/dp
* cleaning
* cleaning
* .
* .
* logs
* CP2 DP2 no mask works after commenting attn_mask and is_causal from scaled_dot_product_attention
* DP=2 TP=2 now works even with tied embeddings
* model.parameters() and model.module.parameters() are empty..
* reformat sanity_check_tensor_sync
* set atol=1e-4 for CP to pass
* try populate _parameters from named_modules
* refactors
TP2 DP2 works
CP2 DP2 works
* is_causal=True and pack sequences, no attn mask, and preshuffle dataset
* fix packing
* CP=4 doesn't work
* fix labels and position_ids for CP
* DP CP works with transformers 🥳🥳🥳
* refactor
* add example cp
* fixup
* revert sdpa changes
* example cleared
* add CP, DP to the mesh init
* nit
* clean
* use `ALL_PARALLEL_STYLES`
* style
* FSDP works
* log on 1 rank
* .
* fix?
* FSDP1 also has .parameters() bug
* reported gradnorm when using FSDP1 is wrong, but loss is correct so it's okay
* .
* style and fixup
* move stuff around
* fix tests
* style
* let's make it a check
* warning should be an info
---------
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>