`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html
so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`.
Co-authored-by: Peter Hawkins <phawkins@google.com>
* refactor: change default block_size
* fix: return tf to origin
* fix: change files to origin
* rebase
* rebase
* rebase
* rebase
* rebase
* rebase
* rebase
* rebase
* refactor: add min block_size to files
* reformat: add min block_size for run_clm tf
* Result of black 23.1
* Update target to Python 3.7
* Switch flake8 to ruff
* Configure isort
* Configure isort
* Apply isort with line limit
* Put the right black version
* adapt black in check copies
* Fix copies
* remove sum for list flattening
* change to chain(*)
* make chain object a list
* delete empty lines
per sgugger's suggestions
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Nicholas Broad <nicholas@nmbroad.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>