Tensors and differentiable operations in Rust

### 16 releases (5 breaking)

0.8.0 (Jun 25, 2018); earlier releases: Mar 27, 2018; Mar 15, 2018; Dec 2, 2017; Nov 25, 2017


This library provides differentiable operations and tensors. The current backend is rust-ndarray.

## Features

### Lazy, side-effect-free tensors in Rust

Tensors do not hold values themselves. Instead, each tensor is a node in a computation graph, and the graph can be evaluated eagerly at any time.
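To make the idea concrete, here is a minimal, hypothetical sketch of a lazy graph in plain Rust. The names `Node`, `placeholder`, `constant`, `add`, `mul`, and `eval` are illustrative assumptions and are not autograd's API. Building `z` only records the computation, and nothing is computed until `eval` is called with values for the placeholders.

```rust
// Hypothetical sketch: a "tensor" is a graph node describing a computation, not a value.
use std::collections::HashMap;
use std::rc::Rc;

#[derive(Debug)]
enum Node {
    Placeholder(&'static str),
    Const(f64),
    Add(Rc<Node>, Rc<Node>),
    Mul(Rc<Node>, Rc<Node>),
}

type Tensor = Rc<Node>;

fn placeholder(name: &'static str) -> Tensor { Rc::new(Node::Placeholder(name)) }
fn constant(v: f64) -> Tensor { Rc::new(Node::Const(v)) }
fn add(a: &Tensor, b: &Tensor) -> Tensor { Rc::new(Node::Add(a.clone(), b.clone())) }
fn mul(a: &Tensor, b: &Tensor) -> Tensor { Rc::new(Node::Mul(a.clone(), b.clone())) }

/// Evaluate a node on demand; constructing the graph computed nothing.
fn eval(t: &Tensor, feeds: &HashMap<&str, f64>) -> Result<f64, String> {
    match t.as_ref() {
        Node::Placeholder(name) => feeds
            .get(name)
            .copied()
            .ok_or_else(|| format!("placeholder `{}` was not fed", name)),
        Node::Const(v) => Ok(*v),
        Node::Add(a, b) => Ok(eval(a, feeds)? + eval(b, feeds)?),
        Node::Mul(a, b) => Ok(eval(a, feeds)? * eval(b, feeds)?),
    }
}

fn main() {
    let x = placeholder("x");
    // z = 2*x*x + 1: at this point z is only a description of the computation.
    let z = add(&mul(&mul(&constant(2.0), &x), &x), &constant(1.0));

    let mut feeds = HashMap::new();
    println!("{:?}", eval(&z, &feeds)); // Err(...): `x` has not been fed yet
    feeds.insert("x", 3.0);
    println!("{:?}", eval(&z, &feeds)); // Ok(19.0)
}
```

Because nodes are shared through `Rc`, the same subgraph can be reused in several expressions without copying.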

### Gradients using reverse-mode automatic differentiation

Higher-order derivatives are supported, and defining your own differentiable operations is straightforward.
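For readers new to reverse mode, the following is a minimal sketch of the technique on scalars, using a tape (Wengert list). `Tape`, `Op`, and their methods are hypothetical names, not autograd's internals. The forward pass records each operation; the backward pass seeds the output's adjoint with 1 and walks the tape in reverse, accumulating partial derivatives. Since the adjoints here are plain `f64` values, this sketch only yields first-order derivatives; supporting higher-order derivatives, as autograd does, requires building the backward pass out of graph nodes so it can itself be differentiated.

```rust
// Minimal reverse-mode AD on a tape (hypothetical sketch, not autograd's implementation).
#[derive(Clone, Copy)]
enum Op {
    Leaf,
    Add(usize, usize),
    Mul(usize, usize),
}

struct Tape {
    ops: Vec<Op>,
    vals: Vec<f64>,
}

impl Tape {
    fn new() -> Self { Tape { ops: Vec::new(), vals: Vec::new() } }
    fn push(&mut self, op: Op, val: f64) -> usize {
        self.ops.push(op);
        self.vals.push(val);
        self.vals.len() - 1
    }
    fn var(&mut self, v: f64) -> usize { self.push(Op::Leaf, v) }
    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.vals[a] + self.vals[b];
        self.push(Op::Add(a, b), v)
    }
    fn mul(&mut self, a: usize, b: usize) -> usize {
        let v = self.vals[a] * self.vals[b];
        self.push(Op::Mul(a, b), v)
    }
    /// Backward pass: seed d(out)/d(out) = 1, then visit nodes in reverse tape order.
    fn grad(&self, out: usize) -> Vec<f64> {
        let mut adj = vec![0.0; self.vals.len()];
        adj[out] = 1.0;
        for i in (0..=out).rev() {
            match self.ops[i] {
                Op::Leaf => {}
                Op::Add(a, b) => {
                    adj[a] += adj[i];
                    adj[b] += adj[i];
                }
                Op::Mul(a, b) => {
                    // Product rule: each factor receives the other factor's value.
                    adj[a] += adj[i] * self.vals[b];
                    adj[b] += adj[i] * self.vals[a];
                }
            }
        }
        adj
    }
}

fn main() {
    let mut t = Tape::new();
    let (x, y) = (t.var(2.0), t.var(5.0));
    let (two, three, one) = (t.var(2.0), t.var(3.0), t.var(1.0));
    // z = 2*x*x + 3*y + 1
    let xx = t.mul(x, x);
    let two_xx = t.mul(two, xx);
    let three_y = t.mul(three, y);
    let s = t.add(two_xx, three_y);
    let z = t.add(s, one);
    let g = t.grad(z);
    println!("dz/dx = {}, dz/dy = {}", g[x], g[y]); // dz/dx = 8, dz/dy = 3
}
```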

### Runtime

The graph execution engine is implemented in pure Rust,
so it can be compiled to WebAssembly with little or no modification. GPUs are not currently supported.

## Examples

Here we are computing partial derivatives of `z = 2x^2 + 3y + 1`.

``````
extern crate autograd as ag;

let ref x = ag::placeholder(&[]);
let ref y = ag::placeholder(&[]);
let ref z = 2*x*x + 3*y + 1;

// dz/dy
let ref gy = ag::grad(&[z], &[y])[0];
println!("{:?}", gy.eval(&[]));   // => Ok(3.)

// dz/dx (requires feeding a value for the placeholder `x`)
let ref gx = ag::grad(&[z], &[x])[0];
println!("{:?}", gx.eval(&[(x, &ag::ndarray::arr0(2.))]));  // => Ok(8.)

// ddz/dx (differentiates `z` again)
let ref ggx = ag::grad(&[gx], &[x])[0];
println!("{:?}", ggx.eval(&[]));  // => Ok(4.)
``````
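Working the derivatives by hand confirms the printed values. Only $\partial z/\partial x$ depends on $x$, which is consistent with `gx` needing a value for `x` while `gy` and `ggx` are evaluated with an empty feed list:

$$
z = 2x^2 + 3y + 1,\qquad
\frac{\partial z}{\partial y} = 3,\qquad
\frac{\partial z}{\partial x} = 4x \;\Rightarrow\; \left.\frac{\partial z}{\partial x}\right|_{x=2} = 8,\qquad
\frac{\partial^2 z}{\partial x^2} = 4.
$$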

Another example: softmax regression with Adam for MNIST digit classification.

``````
// This achieves 0.918 test accuracy after 3 epochs,
// 0.27 sec/epoch on 2.7GHz Intel Core i5 (blas feature is disabled)

let ref x = ag::placeholder(&[-1, 28*28]);
let ref y = ag::placeholder(&[-1]);
let ref w = ag::variable(ag::ndarray_ext::glorot_uniform(&[28*28, 10]));
let ref b = ag::variable(ag::ndarray_ext::zeros(&[1, 10]));
let ref z = ag::matmul(x, w) + b;
let ref loss = ag::reduce_mean(&ag::sparse_softmax_cross_entropy(z, y), &[0, 1], false);
let ref params = [w, b];
let ref predictions = ag::argmax(z, -1, true);
let ref accuracy = ag::reduce_mean(&ag::equal(predictions, y), &[0], false);

// -- dataset --
let ((x_train, y_train), (x_test, y_test)) = dataset::load();

// -- training loop --
// (`max_epoch`, the Adam setup that produces `update_ops`, and minibatch slicing are elided here)
for epoch in 0..max_epoch {
    ...
    ag::eval(update_ops, &[(x, &x_batch), (y, &y_batch)]);
}
``````
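To show what the loss, prediction, and optimizer in this example compute, here is a hypothetical per-example sketch on plain `Vec<f64>`, independent of autograd. `softmax`, `sparse_softmax_cross_entropy`, `argmax`, and `Adam` are illustrative re-implementations that only share names or ideas with the graph ops above. The gradient of sparse softmax cross-entropy with respect to the logits is `softmax(z) - onehot(label)`. The Adam hyperparameters (0.9, 0.999, 1e-8) are the commonly used defaults and are an assumption, not necessarily autograd's.

```rust
// Hypothetical per-example sketch of the MNIST loss and Adam update; not autograd's code.

/// Numerically stable softmax: subtract the max logit before exponentiating.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// "Sparse" means the label is a class index rather than a one-hot vector.
/// Returns (loss, d loss / d logits), where the gradient is softmax(z) - onehot(label).
fn sparse_softmax_cross_entropy(logits: &[f64], label: usize) -> (f64, Vec<f64>) {
    let p = softmax(logits);
    let loss = -p[label].ln();
    let mut grad = p;
    grad[label] -= 1.0;
    (loss, grad)
}

/// Predicted class = index of the largest logit.
fn argmax(v: &[f64]) -> usize {
    let mut best = 0;
    for i in 1..v.len() {
        if v[i] > v[best] {
            best = i;
        }
    }
    best
}

/// Adam state for a single scalar parameter.
struct Adam { lr: f64, b1: f64, b2: f64, eps: f64, m: f64, v: f64, t: i32 }

impl Adam {
    fn new(lr: f64) -> Self {
        Adam { lr, b1: 0.9, b2: 0.999, eps: 1e-8, m: 0.0, v: 0.0, t: 0 }
    }
    fn step(&mut self, param: &mut f64, grad: f64) {
        self.t += 1;
        // Exponential moving averages of the gradient and its square.
        self.m = self.b1 * self.m + (1.0 - self.b1) * grad;
        self.v = self.b2 * self.v + (1.0 - self.b2) * grad * grad;
        // Bias correction for the zero-initialized averages.
        let m_hat = self.m / (1.0 - self.b1.powi(self.t));
        let v_hat = self.v / (1.0 - self.b2.powi(self.t));
        *param -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
    }
}

fn main() {
    let logits = [2.0, 1.0, 0.1];
    let (loss, grad) = sparse_softmax_cross_entropy(&logits, 0);
    println!("loss = {:.4}, grad = {:?}, prediction = {}", loss, grad, argmax(&logits));

    let mut w = 0.5;
    let mut adam = Adam::new(0.001);
    adam.step(&mut w, 0.2);
    println!("w after one Adam step = {}", w); // about 0.499: the first step moves by roughly lr
}
```

The max-subtraction in `softmax` does not change the result mathematically but prevents `exp` from overflowing on large logits.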

For more, see the documentation or the examples.

#### Dependencies


- build script: `build.rs`
- build dependency: cc 1.0