DiffSharp


Gradient Descent

Gradient descent is an optimization algorithm that finds a local minimum of a scalar-valued function near a starting point by taking successive steps in the direction of the negative of the gradient.

For a function \(f: \mathbb{R}^n \to \mathbb{R}\), starting from an initial point \(\mathbf{x}_0\), the method works by computing successive points in the function domain

\[ \mathbf{x}_{n + 1} = \mathbf{x}_n - \eta \left( \nabla f \right)_{\mathbf{x}_n} \; ,\]

where \(\eta > 0\) is a small step size and \(\left( \nabla f \right)_{\mathbf{x}_n}\) is the gradient of \(f\) evaluated at \(\mathbf{x}_n\). The successive values of the function

\[ f(\mathbf{x}_0) \ge f(\mathbf{x}_1) \ge f(\mathbf{x}_2) \ge \dots\]

are non-increasing, and the sequence \(\mathbf{x}_n\) usually converges to a local minimum.
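
For example, for the one-dimensional quadratic \(f(x) = x^2\) the gradient is \(2x\), so the update becomes

\[ x_{n + 1} = x_n - 2 \eta x_n = (1 - 2\eta)\, x_n \; ,\]

which converges to the minimum at \(x = 0\) for any step size \(0 < \eta < 1\).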

In practice, a fixed step size \(\eta\) often gives suboptimal performance; adaptive methods instead select a suitable step size on each iteration, for example via a line search.

The following code implements gradient descent with fixed step size, stopping when the norm of the gradient falls below a given threshold.

open DiffSharp.AD.Float64

// Gradient descent
// f: function, x0: starting point, eta: step size, epsilon: threshold
let gd f x0 eta epsilon =
    let rec desc x =
        let g = grad f x
        if DV.l2norm g < epsilon then x else desc (x - eta * g)
    desc x0
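
As an illustration of the adaptive step sizes mentioned above, the following is a minimal sketch of the same loop with a backtracking (Armijo) line search. It is not part of the DiffSharp API; gdBacktracking, eta0 (initial trial step), beta (shrinkage factor), and c (sufficient-decrease constant) are illustrative names introduced here.

open DiffSharp.AD.Float64

// Gradient descent with backtracking (Armijo) line search -- a sketch
// f: function, x0: starting point, eta0: initial trial step size,
// beta: step shrinkage factor (e.g. 0.5), c: sufficient-decrease constant
// (e.g. 1e-4), epsilon: threshold
let gdBacktracking f x0 (eta0:D) (beta:D) (c:D) epsilon =
    let rec desc (x:DV) =
        let g = grad f x
        if DV.l2norm g < epsilon then x
        else
            // Shrink the step until it decreases f by at least c * eta * ||g||^2
            let rec search (eta:D) =
                if f (x - eta * g) <= f x - c * eta * (g * g) then eta
                else search (beta * eta)
            desc (x - (search eta0) * g)
    desc x0

The Armijo condition accepts a step only if it decreases \(f\) by at least a fraction \(c\) of the first-order estimate \(\eta \left\| \left( \nabla f \right)_{\mathbf{x}_n} \right\|^2\); otherwise the step is shrunk by the factor \(\beta\) and tried again.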

Let's find a minimum of \(f(x, y) = \sin x + \cos y\).

let inline f (x:DV) = sin x.[0] + cos x.[1]

// Find the minimum of f
// Start from (1, 1), step size 0.9, threshold 0.00001
let xmin = gd f (toDV [1.; 1.]) (D 0.9) (D 0.00001)
let fxmin = f xmin

val xmin : DV = DV [|-1.570790759; 3.141591964|]
val fxmin : D = D -2.0

A minimum, \(f(x, y) = -2\), is found at \((x, y) = \left(-\frac{\pi}{2}, \pi\right)\).
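
This agrees with a direct calculation: \(\nabla f = (\cos x, \, -\sin y)\) vanishes at \(\left(-\frac{\pi}{2}, \pi\right)\), where \(f = -1 - 1 = -2\). As a quick numerical check (a sketch using the definitions above; gmin is just an illustrative name), the gradient at the computed point should have a norm below the threshold passed to gd:

// Gradient at the computed minimum; its norm should be below the
// 0.00001 threshold used in the call to gd above
let gmin = grad f xmin
let gminNorm = DV.l2norm gmin

Note that \(f\) has infinitely many local minima of the same value, at \(x = -\frac{\pi}{2} + 2k\pi\), \(y = \pi + 2m\pi\) for integers \(k, m\); started from a different point, gd will generally converge to a different one of them.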
