Vegapit
How to use Torch in Rust with tch-rs
Rust MachineLearning Torch Fri, 19 Jul 2019 12:18:00 +0000

Thanks to the diligent work of Laurent Mazare on his tch-rs crate, the Rust community can now enjoy easy access to the powerful Torch neural net framework. As an avid user of both Rust and Torch, I can say that stumbling on this repo has been nothing but a belated birthday present. In this post, I would like to walk through two examples that present its most fundamental functionalities.

Getting gradients from a computational graph

As it is always easier to understand a concept through a real-world application, let's implement a useful script that relies on Torch's computational graph capabilities: a European option pricer, based on the Black76 formula, that handles first and second-order derivatives. If you are not familiar with what a financial option is, do not worry one bit. The important takeaway is the set of principles used to derive the gradients, so that you can apply them in your own settings.

To achieve this, we need two functions: the standard Normal cumulative distribution function and the Black76 option price formula. Since we need the gradients of these functions, they must be implemented with Tensor operations so that Torch can build the computational graph in the background.

use tch::Tensor;

// Standard Normal CDF, N(x) = 0.5 * (1 + erf(x / sqrt(2))), built from Tensor ops
fn norm_cdf(x: &Tensor) -> Tensor {
    0.5 * ( 1.0 + ( x / Tensor::from(2.0).sqrt() ).erf() )
}

// Black76 price; epsilon is 1 for a call and -1 for a put
fn black76(epsilon: &Tensor, f: &Tensor, k: &Tensor, t: &Tensor, sigma: &Tensor, r: &Tensor) -> Tensor {
    let d1 = ((f/k).log() + Tensor::from(0.5) * sigma.pow(2.0) * t) / ( sigma * t.sqrt() );
    let d2 = &d1 - sigma * t.sqrt();
    epsilon * (-r * t).exp() * ( f * norm_cdf(&(epsilon * d1)) - k * norm_cdf(&(epsilon * d2)) )
}

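For reference, this is the standard Black76 formula the two functions encode, written with F = f, K = k and T = t. Setting ε = 1 prices a call and ε = -1 a put, and Φ is the standard Normal CDF that norm_cdf implements:

```latex
\Phi(x) = \tfrac{1}{2}\left[1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right]

d_1 = \frac{\ln(F/K) + \tfrac{1}{2}\sigma^2 T}{\sigma\sqrt{T}}, \qquad d_2 = d_1 - \sigma\sqrt{T}

P = \epsilon\, e^{-rT}\left[F\,\Phi(\epsilon d_1) - K\,\Phi(\epsilon d_2)\right]
```
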
The black76 function returns the price of an option given its set of characteristics. It is useful to extract the first derivatives of the price w.r.t. the underlying price f and the volatility sigma, which are called delta and vega respectively. To do so, we use the run_backward function as follows:

// f and sigma must be created with .set_requires_grad(true) so that Torch tracks gradients for them
let price = &black76(&epsilon, &f, &k, &t, &sigma, &r); // 11.8049

let price_grad = Tensor::run_backward(&[price], &[&f, &sigma], true, true);
let delta = &price_grad[0]; // 0.5540
let vega = &price_grad[1]; // 39.0554

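Black76 has closed-form Greeks, so the autograd values can be cross-checked without Torch. The sketch below is plain std-only Rust rather than tch-rs. The erf approximation (Abramowitz and Stegun 7.1.26), the helper names and the inputs are our assumptions. The post does not print its inputs. However, epsilon = 1, f = k = 100, t = 1 and r = 0.01 from the solver example below, together with sigma = 0.3, are consistent with the quoted price, delta and vega.

```rust
use std::f64::consts::{PI, SQRT_2};

/// Standard Normal CDF using the Abramowitz & Stegun 7.1.26 erf
/// approximation (absolute error in erf below about 1.5e-7).
fn norm_cdf(x: f64) -> f64 {
    let z = (x / SQRT_2).abs();
    let t = 1.0 / (1.0 + 0.3275911 * z);
    let poly = t * (0.254829592
        + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
    let erf_abs = 1.0 - poly * (-z * z).exp();
    let erf = if x >= 0.0 { erf_abs } else { -erf_abs };
    0.5 * (1.0 + erf)
}

/// Standard Normal density.
fn norm_pdf(x: f64) -> f64 {
    (-0.5 * x * x).exp() / (2.0 * PI).sqrt()
}

/// Closed-form Black76 price, delta (d/df) and vega (d/dsigma).
fn black76_first_order(epsilon: f64, f: f64, k: f64, t: f64, sigma: f64, r: f64) -> (f64, f64, f64) {
    let sqrt_t = t.sqrt();
    let d1 = ((f / k).ln() + 0.5 * sigma * sigma * t) / (sigma * sqrt_t);
    let d2 = d1 - sigma * sqrt_t;
    let discount = (-r * t).exp();
    let price = epsilon * discount * (f * norm_cdf(epsilon * d1) - k * norm_cdf(epsilon * d2));
    let delta = epsilon * discount * norm_cdf(epsilon * d1);
    // Vega is the same for calls and puts
    let vega = discount * f * norm_pdf(d1) * sqrt_t;
    (price, delta, vega)
}

fn main() {
    let (price, delta, vega) = black76_first_order(1.0, 100.0, 100.0, 1.0, 0.3, 0.01);
    println!("price {:.4}, delta {:.4}, vega {:.4}", price, delta, vega);
}
```
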
Torch calculates the gradients of the tensors in the first array parameter w.r.t. the tensors in the second array parameter. The first boolean parameter (keep_graph) tells Torch whether we intend to reuse this computational graph for further calculations. The second boolean parameter (create_graph) tells Torch whether to build a computational graph for the resulting gradients, so that their own gradients can be calculated in turn. We set both to true here. We want the derivative of delta w.r.t. the underlying price f, called gamma, and the derivatives of vega w.r.t. the underlying price f and the volatility sigma, called vanna and volga respectively. To get them, we simply call run_backward two more times:

let delta_grad = Tensor::run_backward(&[delta], &[&f], true, false);
let gamma = &delta_grad[0]; // 0.013

let vega_grad = Tensor::run_backward(&[vega], &[&f,&sigma], false, false);
let vanna = &vega_grad[0]; // 0.1953
let volga = &vega_grad[1]; // -2.9292
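The second-order values have closed forms too. Here is a second std-only sketch, again plain Rust rather than tch-rs, using the same assumed inputs with sigma = 0.3. It applies the textbook Black76 expressions for gamma, vanna and volga, which need only the Normal density. All three are identical for calls and puts, so epsilon drops out.

```rust
use std::f64::consts::PI;

/// Standard Normal density.
fn norm_pdf(x: f64) -> f64 {
    (-0.5 * x * x).exp() / (2.0 * PI).sqrt()
}

/// Closed-form Black76 gamma, vanna and volga (the same for calls and puts).
fn black76_second_order(f: f64, k: f64, t: f64, sigma: f64, r: f64) -> (f64, f64, f64) {
    let sqrt_t = t.sqrt();
    let d1 = ((f / k).ln() + 0.5 * sigma * sigma * t) / (sigma * sqrt_t);
    let d2 = d1 - sigma * sqrt_t;
    let discount = (-r * t).exp();
    let vega = discount * f * norm_pdf(d1) * sqrt_t;
    // gamma = d(delta)/df
    let gamma = discount * norm_pdf(d1) / (f * sigma * sqrt_t);
    // vanna = d(vega)/df
    let vanna = -discount * norm_pdf(d1) * d2 / sigma;
    // volga = d(vega)/dsigma
    let volga = vega * d1 * d2 / sigma;
    (gamma, vanna, volga)
}

fn main() {
    let (gamma, vanna, volga) = black76_second_order(100.0, 100.0, 1.0, 0.3, 0.01);
    println!("gamma {:.4}, vanna {:.4}, volga {:.4}", gamma, vanna, volga);
}
```
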

In the call that computes gamma, the first boolean (keep_graph) has to be set to true to force Torch to retain the graph for the final vega gradient calculation. The last call no longer needs the graph, so both of its booleans can be false. There we have it: a Torch-powered European option pricer script with derivatives.

Optimising a function through backpropagation

On derivatives exchanges, we can observe option prices and derive all but one of their characteristics from the contract specification: the exception is the volatility sigma. Our new goal is to implement a solver that derives sigma from the option price and the other characteristics. To achieve this, we need to remember one crucial aspect of closures in Rust: as opposed to functions, they can capture variables defined in their enclosing scope. Here they are in action:

fn greeter() -> impl Fn(&str) -> String {
    let greeting = "Hello";
    move |x| {
        format!("{} {}!", greeting, x)
    }
}

fn main() {
    let f = greeter();
    println!("{}", f("everybody")); // Hello everybody!
    println!("{}", f("world")); // Hello world!
}

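The greeter closure only reads what it captured. For contrast, here is a small std-only sketch. The counter function and the String parameter are ours, not the post's. It sets the greeter-style closure, which implements Fn, beside one that mutates its captured state and therefore implements only FnMut.

```rust
/// Only reads its captured `greeting`, so the closure implements `Fn`.
fn greeter(greeting: String) -> impl Fn(&str) -> String {
    move |name| format!("{} {}!", greeting, name)
}

/// Mutates its captured `count`, so the closure implements only `FnMut`.
fn counter() -> impl FnMut() -> u32 {
    let mut count = 0;
    move || {
        count += 1;
        count
    }
}

fn main() {
    let hello = greeter(String::from("Hello"));
    println!("{}", hello("world"));
    println!("{}", hello("everybody"));

    // Calling an FnMut closure takes &mut self, so the binding must be mutable
    let mut next = counter();
    let first = next();
    let second = next();
    println!("{} {}", first, second);
}
```

Declaring counter's return type as impl Fn would not compile, because incrementing count requires a mutable receiver.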
In the greeter function, the immutable variable greeting is defined and moved into a closure that is then returned. greeting lives on inside the returned closure, which implements the Fn trait because calling it only requires an immutable receiver (&self): it reads greeting but never modifies it. Let's use exactly the same principle to solve our problem:

use tch::nn;

fn func_builder(p: nn::Path) -> impl Fn(&Tensor, &Tensor, &Tensor, &Tensor, &Tensor) -> Tensor {
    let sigma = p.randn_standard("sigma", &[1]); // trainable, randomly initialised
    move |epsilon, f, k, t, r| {
        black76(epsilon, f, k, t, &sigma, r)
    }
}

The variable we need to solve for (sigma) is a Tensor created through the Path struct, which points to the variable store used over the course of the computation. As in the previous example, the variable is moved into the closure returned by the function. We are now ready to set up our optimisation loop in the main function:

use tch::nn::OptimizerConfig; // brings .build() into scope

fn main() {
    let vs = nn::VarStore::new( tch::Device::Cpu );
    let black76_volsolver = func_builder( vs.root() );
    // backward_step needs a mutable optimizer
    let mut opt = nn::Adam::default().build(&vs, 1e-2).unwrap();

    let epsilon = Tensor::from(1f64);
    let f = Tensor::from(100f64);
    let k = Tensor::from(100f64);
    let t = Tensor::from(1f64);
    let r = Tensor::from(0.01);
    let price = Tensor::from(11.805);

    loop {
        let square_loss = (black76_volsolver(&epsilon, &f, &k, &t, &r) - &price).pow(2f64).sum();
        println!("{}", f64::from(&square_loss) );
        if f64::from(&square_loss) < 0.001 {
            break;
        }
        // Backpropagate the loss and let Adam update sigma
        opt.backward_step(&square_loss);
    }

    // Read the solved volatility back from the VarStore and reprice
    let sigma = vs.root().get("sigma").unwrap();
    let calc_price = f64::from( black76(&epsilon, &f, &k, &t, &sigma, &r) );
    assert!( (calc_price - f64::from(price)).abs() < 0.01 );
}

First, we define a VarStore on the CPU. We pass its root Path to our func_builder call and collect the closure that will be used to solve for the volatility. Then we build an Optimizer that uses the Adam optimisation algorithm and is linked to our VarStore. After defining the known option characteristics and the target price, we backpropagate the loss in a loop with backward_step, exiting once the square loss falls below 0.001. Finally, we reprice the option with the solved sigma and check the result against our target to confirm success.

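Adam on a squared loss is one way to invert the pricer. As an independent cross-check outside Torch, the sketch below swaps in plain bisection, a different technique from the one used in the post, on a std-only f64 price. It relies on the Black76 price increasing with sigma, since vega is positive. The [1e-6, 5.0] bracket, the 100 iterations and the erf approximation (Abramowitz and Stegun 7.1.26) are our assumptions. With the target price of 11.805 and the inputs above, it should recover a volatility close to 0.3.

```rust
use std::f64::consts::SQRT_2;

/// Standard Normal CDF via the Abramowitz & Stegun 7.1.26 erf approximation.
fn norm_cdf(x: f64) -> f64 {
    let z = (x / SQRT_2).abs();
    let t = 1.0 / (1.0 + 0.3275911 * z);
    let poly = t * (0.254829592
        + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
    let erf_abs = 1.0 - poly * (-z * z).exp();
    let erf = if x >= 0.0 { erf_abs } else { -erf_abs };
    0.5 * (1.0 + erf)
}

/// Plain f64 Black76 price; epsilon is 1 for a call and -1 for a put.
fn black76(epsilon: f64, f: f64, k: f64, t: f64, sigma: f64, r: f64) -> f64 {
    let sqrt_t = t.sqrt();
    let d1 = ((f / k).ln() + 0.5 * sigma * sigma * t) / (sigma * sqrt_t);
    let d2 = d1 - sigma * sqrt_t;
    epsilon * (-r * t).exp() * (f * norm_cdf(epsilon * d1) - k * norm_cdf(epsilon * d2))
}

/// Implied volatility by bisection; returns None if the target price is not bracketed.
fn implied_vol(price: f64, epsilon: f64, f: f64, k: f64, t: f64, r: f64) -> Option<f64> {
    // Pricing error as a function of sigma; increasing because vega > 0
    let error = |sigma: f64| black76(epsilon, f, k, t, sigma, r) - price;
    let (mut lo, mut hi) = (1e-6, 5.0);
    if error(lo) > 0.0 || error(hi) < 0.0 {
        return None;
    }
    for _ in 0..100 {
        let mid = 0.5 * (lo + hi);
        if error(mid) < 0.0 {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    Some(0.5 * (lo + hi))
}

fn main() {
    match implied_vol(11.805, 1.0, 100.0, 100.0, 1.0, 0.01) {
        Some(sigma) => println!("implied volatility {:.4}", sigma),
        None => println!("target price outside the bracket"),
    }
}
```

Bisection trades speed for robustness: it needs no derivative and cannot overshoot, whereas a gradient-based optimiser like Adam needs a learning rate and a stopping threshold.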
The cargo project containing the full code can be found here
