autodiff

The autodiff module can be used to differentiate scalar-valued functions of n variables, i.e. functions of the form:

$f : \mathbb{R}^n \rightarrow \mathbb{R}$

Simple expressions are differentiated as follows:

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // SIMPLE EXPRESSIONS // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ let g = Graph::new(); let a = 1.; let b = 2.; let x = g.var(69.); let y = g.var(420.); // Define a function. let f = a + b + (x * y).exp(); // Accumulate the gradient. let gradient = f.accumulate(); println!("z = {}", f.value); println!("dz/dx = {}", gradient.wrt(&x)); println!("dz/dy = {}", gradient.wrt(&y)); println!("grad = {:?}", gradient.wrt(&[x, y]));

Block expressions are also supported:

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // BLOCK EXPRESSIONS // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ let g = Graph::new(); let x = g.var(69.); let y = g.var(420.); let block = { let z = x.sin() + y.tan(); z.exp() }; let grad = block.accumulate(); println!("f = {}", block.value); println!("df/dx = {}", grad.wrt(&x)); println!("df/dy = {}", grad.wrt(&y)); println!("grad = {:?}", grad.wrt(&[x, y]));

Closures and functions are also supported:

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // CLOSURES // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ let g = Graph::new(); let x = g.var(1.); let y = g.var(2.); let closure = || (x * y).cosh() / (x.tanh() * y.sinh()); let grad = closure().accumulate(); println!("z = {}", closure().value); println!("dz/dx = {}", grad.wrt(&x)); println!("dz/dy = {}", grad.wrt(&y)); println!("grad = {:?}", grad.wrt(&[x, y]));
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // PROPER FUNCTIONS. // // Note that you can also add many variables via vectors, slices, arrays, etc. // This is where the `autodiff` crate really shines, as it allows // you to differentiate functions of any number of variables and // computing gradients for large functions using AD rather than // finite-difference quotients is significantly faster and has no error. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Function to differentiate: // f = x^(y + cos(1)) - atanh(z) / 2 + 1 // at x = 3, y = 2, z = 1. #[rustfmt::skip] fn function<'v>(variables: &[Variable<'v>], constants: &[f64]) -> Variable<'v> { variables[0].powf(variables[1] + constants[0].cos()) - variables[2].atanh() / constants[1] + constants[0] } // New graph. let graph = Graph::new(); // Variables and constants. let variables = graph.vars(&[3.0, 2.0, 1.0]); let constants = [1., 2.]; // Evaluate and differentiate the function. let result = function(&variables, &constants); let gradient = result.accumulate(); // Print the graph length. println!("Graph length: {}", graph.len()); println!("{:?}", gradient.wrt(&variables)); // Print the graphviz output. // You can copy and paste this into your Graphviz viewer of choice. println!("{}", graphviz(&graph, &variables));