Google JAX
Tipus | biblioteca Python |
---|---|
Versió inicial | |
Versió estable | |
Llicència | Llicència Apache, versió 2.0 |
Equip | |
Desenvolupador(s) | Peter Hawkins (en) , Matthew Johnson (en) i Jacob VanderPlas (en) |
Més informació | |
Lloc web | jax.readthedocs.io… |
| |
Google JAX és un marc d'aprenentatge automàtic per transformar funcions numèriques.[1][2] Es descriu com reunir una versió modificada d'autograd (obtenció automàtica de la funció de gradient mitjançant la diferenciació d'una funció) i XLA de TensorFlow (àlgebra lineal accelerada). Està dissenyat per seguir l'estructura i el flux de treball de NumPy tan de prop com sigui possible i funciona amb diversos marcs existents com TensorFlow i PyTorch.[3][4] Les funcions principals de JAX són:
- grau: diferenciació automàtica
- jit: compilació
- vmap: vectorització automàtica
- pmap: programació SPMD
Funció grau
[modifica]El codi següent mostra la diferenciació automàtica de la funció de graduació .
# imports
from jax import grad
import jax.numpy as jnp
# define the logistic function
def logistic(x):
return jnp.exp(x) / (jnp.exp(x) + 1)
# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)
# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
Funcio jit
[modifica]El codi següent mostra l'optimització de la funció jit mitjançant la fusió.
# imports
from jax import jit
import jax.numpy as jnp
# define the cube function
def cube(x):
return x * x * x
# generate data
x = jnp.ones((10000, 10000))
# create the jit version of the cube function
jit_cube = jit(cube)
# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)
Funció vmap
[modifica]El codi següent mostra la vectorització de la funció vmap.
# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp
# define function
def grads(self, inputs):
in_grad_partial = partial(self._net_grads, self._net_params)
grad_vmap = vmap(in_grad_partial)
rich_grads = grad_vmap(inputs)
flat_grads = np.asarray(self._flatten_batch(rich_grads))
assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
return flat_grads
Funció pmap
[modifica]El codi següent mostra la paral·lelització de la funció pmap per a la multiplicació de matrius.
# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp
# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)
# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)
# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)
Biblioteques que utilitzen JAX
[modifica]Diverses biblioteques de Python utilitzen JAX com a backend, incloent:
- Flax, una biblioteca de xarxes neuronals d'alt nivell desenvolupada inicialment per Google Brain.
- Equinox, una biblioteca que gira al voltant de la idea de representar funcions parametritzades (incloses les xarxes neuronals) com a PyTrees. Va ser creat per Patrick Kidger.
- Diffrax, una biblioteca per a la solució numèrica d'equacions diferencials, com ara equacions diferencials ordinàries i equacions diferencials estocàstiques.
- Optax, una biblioteca per al processament i optimització de gradients desenvolupada per DeepMind.
- Lineax, una biblioteca per resoldre numèricament sistemes lineals i mínims quadrats lineals.
- RLax, una biblioteca per desenvolupar agents d'aprenentatge de reforç desenvolupada per DeepMind.
- jraph, una biblioteca per a xarxes neuronals gràfics, desenvolupada per DeepMind.
- jaxtyping, una biblioteca per afegir anotacions de tipus per a la forma i el tipus de dades ("dtype") de matrius o tensors.
Referències
[modifica]- ↑ Frostig, Roy; Johnson, Matthew James; Leary, Chris MLsys, 02-02-2018, pàg. 1–3.
- ↑ «Using JAX to accelerate our research» (en anglès). www.deepmind.com. Arxivat de l'original el 2022-06-18. [Consulta: 18 juny 2022].
- ↑ Lynley, Matthew. «Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta» (en anglès americà). Business Insider. Arxivat de l'original el 2022-06-21. [Consulta: 21 juny 2022].
- ↑ «Why is Google's JAX so popular?» (en anglès americà). Analytics India Magazine, 25-04-2022. Arxivat de l'original el 2022-06-18. [Consulta: 18 juny 2022].