DiffSharp


Stochastic Gradient Descent

Stochastic gradient descent is a stochastic variant of the gradient descent algorithm, used for minimizing loss functions that have the form of a sum

\[ Q(\mathbf{w}) = \sum_{i=1}^{d} Q_i(\mathbf{w}) \; ,\]

where \(\mathbf{w}\) is a "weight" vector that is being optimized. The component \(Q_i\) is the contribution of the \(i\)-th sample to the overall loss \(Q\), which is to be minimized using a training set of \(d\) samples.
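As a concrete illustration, here is a minimal plain F# sketch (no DiffSharp) of a loss with this structure. It assumes, for illustration only, a squared-error contribution \(Q_i(\mathbf{w}) = (y_i - f_{\mathbf{w}}(x_i))^2\) for a quadratic model and the three points fitted later on this page; `sampleLoss` and `totalLoss` are hypothetical names.

```fsharp
// Per-sample loss Q_i(w) = (y_i - f_w(x_i))^2 for the quadratic model f_w(x) = w1 x^2 + w2 x + w3
let sampleLoss (w: float[]) (x: float, y: float) =
    let r = y - (w.[0] * x * x + w.[1] * x + w.[2])
    r * r

// Overall loss Q(w): the sum of the contributions of all d samples
let totalLoss (samples: (float * float)[]) (w: float[]) =
    samples |> Array.sumBy (sampleLoss w)

let samples = [| 0.5, 2.0; 3.2, 1.0; 5.2, 4.0 |]
printfn "%f" (totalLoss samples [| 0.0; 0.0; 0.0 |])
```

At \(\mathbf{w} = \mathbf{0}\) every prediction is 0, so each \(Q_i\) is just \(y_i^2\) and the total is 4 + 1 + 16 = 21.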

Using the standard gradient descent algorithm, \(Q\) can be minimized by the iteration

\[ \begin{aligned} \mathbf{w}_{t+1} &= \mathbf{w}_t - \eta \nabla Q(\mathbf{w}_t) \\ &= \mathbf{w}_t - \eta \sum_{i=1}^{d} \nabla Q_i(\mathbf{w}_t) \; , \end{aligned}\]

where \(\eta > 0\) is a step size. This "batch" update rule has to evaluate the gradient of every \(Q_i\) in each step, so the cost of a single step is proportional to the size \(d\) of the training set.
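To make the cost of a batch step concrete, here is a minimal plain F# sketch (no DiffSharp) of one step of this rule. It assumes a squared-error loss \(Q_i(\mathbf{w}) = (y_i - f_{\mathbf{w}}(x_i))^2\) for the quadratic model and the three points used later on this page, with the gradient written out by hand; `features`, `sampleGrad`, `addVec` and `batchStep` are hypothetical helpers.

```fsharp
// Features of the quadratic model f_w(x) = w1 x^2 + w2 x + w3, i.e. the vector (x^2, x, 1)
let features (x: float) = [| x * x; x; 1.0 |]

// Gradient of Q_i(w) = (y - f_w(x))^2 with respect to w: -2 (y - f_w(x)) (x^2, x, 1)
let sampleGrad (w: float[]) (x: float, y: float) =
    let v = features x
    let r = y - Array.fold2 (fun acc wi vi -> acc + wi * vi) 0.0 w v
    v |> Array.map (fun vi -> -2.0 * r * vi)

// Element-wise sum of two vectors
let addVec (a: float[]) (b: float[]) = Array.map2 (+) a b

// One batch step: sum the gradients of all d samples, then move against the sum
let batchStep (eta: float) (samples: (float * float)[]) (w: float[]) =
    let g = samples |> Array.map (sampleGrad w) |> Array.reduce addVec
    Array.map2 (fun wi gi -> wi - eta * gi) w g

let samples = [| 0.5, 2.0; 3.2, 1.0; 5.2, 4.0 |]
let w1 = batchStep 0.0001 samples [| 0.0; 0.0; 0.0 |]
printfn "%A" w1
```

Starting from \(\mathbf{w} = \mathbf{0}\) with \(\eta = 0.0001\), the summed gradient is \((-237.8, -50, -14)\), so the step gives approximately \(\mathbf{w} = (0.02378, 0.005, 0.0014)\). Every step visits all \(d\) samples.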

Alternatively, in stochastic gradient descent, \(Q\) is minimized using

\[ \mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla Q_i (\mathbf{w}_t) \; ,\]

updating the weights \(\mathbf{w}\) in each step using just one sample \(i\), chosen at random from the training set. This is advantageous for large training sets, because the cost of each step no longer depends on \(d\). Another advantage is that samples can be processed on the fly as they arrive, as in online learning.
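A matching sketch of one stochastic step, under the same assumptions (plain F#, squared-error \(Q_i\), hypothetical helper names such as `sgdStep`). It also checks a useful property: when \(i\) is drawn uniformly, the expected value of \(\nabla Q_i(\mathbf{w})\) is \(\frac{1}{d} \nabla Q(\mathbf{w})\), so on average a stochastic step points in the same direction as a batch step, scaled by \(1/d\).

```fsharp
// Features of the quadratic model f_w(x) = w1 x^2 + w2 x + w3
let features (x: float) = [| x * x; x; 1.0 |]

// Gradient of the squared error of one sample: -2 (y - f_w(x)) (x^2, x, 1)
let sampleGrad (w: float[]) (x: float, y: float) =
    let v = features x
    let r = y - Array.fold2 (fun acc wi vi -> acc + wi * vi) 0.0 w v
    v |> Array.map (fun vi -> -2.0 * r * vi)

let addVec (a: float[]) (b: float[]) = Array.map2 (+) a b

// One stochastic step: a single randomly chosen sample, so the cost does not depend on d
let sgdStep (rnd: System.Random) (eta: float) (samples: (float * float)[]) (w: float[]) =
    let g = sampleGrad w samples.[rnd.Next(samples.Length)]
    Array.map2 (fun wi gi -> wi - eta * gi) w g

let samples = [| 0.5, 2.0; 3.2, 1.0; 5.2, 4.0 |]
let w0 = [| 0.0; 0.0; 0.0 |]

// Average of the per-sample gradients: the full gradient divided by d
let meanGrad =
    samples
    |> Array.map (sampleGrad w0)
    |> Array.reduce addVec
    |> Array.map (fun gi -> gi / float samples.Length)

let w1 = sgdStep (System.Random(42)) 0.0001 samples w0
printfn "mean gradient: %A, after one step: %A" meanGrad w1
```

The seed 42 is arbitrary; it only makes the choice of sample repeatable.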

In practice, to ensure convergence, the constant \(\eta\) is usually replaced by a decreasing sequence of step sizes \(\eta_t\).
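One common schedule, given here as an example rather than a prescription, is \(\eta_t = \eta_0 / (1 + t/\tau)\). It satisfies the classical Robbins–Monro conditions \(\sum_t \eta_t = \infty\) and \(\sum_t \eta_t^2 < \infty\), under which, given suitable assumptions on \(Q\), stochastic gradient descent converges: the steps shrink fast enough to damp the noise of single-sample gradients, yet their sum diverges, so the iterates are not stopped short of the minimum. A plain F# sketch, where `stepSize`, `eta0` and `tau` are hypothetical names:

```fsharp
// Decreasing step size eta_t = eta0 / (1 + t / tau):
// eta0 is the initial step size, and the step size has halved after tau steps
let stepSize (eta0: float) (tau: float) (t: int) = eta0 / (1.0 + float t / tau)

// First few step sizes for eta0 = 0.01 and tau = 2
let schedule = [ for t in 0 .. 4 -> stepSize 0.01 2.0 t ]
printfn "%A" schedule
```

In the update, \(\eta_t\) simply replaces \(\eta\): \(\mathbf{w}_{t+1} = \mathbf{w}_t - \eta_t \nabla Q_i(\mathbf{w}_t)\). Since \(\eta_t\) decays like \(1/t\), its sum diverges like the harmonic series, while the sum of its squares converges.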

Let's implement stochastic gradient descent with a constant step size.

open DiffSharp.AD.Float64

let rnd = new System.Random()

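// Stochastic gradient descent
// f: model function, w0: starting weights, eta: step size, epsilon: stopping threshold, t: training set
let sgd (f:DV->DV->D) w0 (eta:D) epsilon (t:(DV*DV)[]) =
    let rec desc w =
        // Pick one training sample (x, y) uniformly at random
        let x, y = t.[rnd.Next(t.Length)]
        // Gradient of the squared error of this sample (DV * DV is the dot product)
        let g = grad (fun wi -> let r = y - (f wi x) in r * r) w
        // Stop when the gradient is small, otherwise step against it
        if DV.l2norm g < epsilon then w else desc (w - eta * g)
    desc w0

In this implementation \(Q_i\) is the squared error

\[ Q_i(\mathbf{w}) = \left\Vert \mathbf{y}_i - f_{\mathbf{w}} (\mathbf{x}_i) \right\Vert^2 \; ,\]

where \(f_{\mathbf{w}} : \mathbb{R}^n \to \mathbb{R}^m\) is a model function for our data (parameterized by \(\mathbf{w}\)) and \(\mathbf{x}_i \in \mathbb{R}^n\) and \(\mathbf{y}_i \in \mathbb{R}^m\) are the input and output of the \(i\)-th sample in the training set. Finding the parameters \(\mathbf{w}\) that minimize \(Q(\mathbf{w}) = \sum_{i=1}^{d} Q_i (\mathbf{w})\) thus fits the model function \(f_{\mathbf{w}}\) to our data. The norm is squared so that the stopping test can succeed: the gradient of the squared error shrinks to zero as a sample is fitted, whereas the gradient of the plain norm keeps the magnitude of \(\nabla_{\mathbf{w}} f_{\mathbf{w}}\), which is at least 1 for the quadratic model below, so the threshold \(\epsilon\) could never be reached.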
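In sgd, grad differentiates the per-sample loss automatically, so no derivative is written by hand. Assuming the squared error and the quadratic model fitted below, the gradient works out to \(\nabla Q_i(\mathbf{w}) = -2\,(y_i - f_{\mathbf{w}}(x_i))\,(x_i^2, x_i, 1)\). A plain F# sketch (independent of DiffSharp; `loss`, `analyticGrad` and `numericGrad` are hypothetical helpers) checks this formula against central finite differences, a common sanity test for gradient code; the sample and weights are arbitrary test values.

```fsharp
// Squared error of one sample (x, y) for the quadratic model f_w(x) = w1 x^2 + w2 x + w3
let loss (x: float, y: float) (w: float[]) =
    let r = y - (w.[0] * x * x + w.[1] * x + w.[2])
    r * r

// Hand-derived gradient of that loss: -2 (y - f_w(x)) (x^2, x, 1)
let analyticGrad (x: float, y: float) (w: float[]) =
    let r = y - (w.[0] * x * x + w.[1] * x + w.[2])
    [| -2.0 * r * x * x; -2.0 * r * x; -2.0 * r |]

// Central differences: (Q(w + h e_j) - Q(w - h e_j)) / (2h) for each coordinate j
let numericGrad (q: float[] -> float) (w: float[]) (h: float) =
    Array.init w.Length (fun j ->
        let shifted delta = Array.mapi (fun k wk -> if k = j then wk + delta else wk) w
        (q (shifted h) - q (shifted (-h))) / (2.0 * h))

let sample = (3.2, 1.0)
let w = [| 0.3; -1.5; 2.5 |]
let ga = analyticGrad sample w
let gn = numericGrad (loss sample) w 1e-6
printfn "analytic: %A\nnumeric:  %A" ga gn
```

Because this loss is quadratic in \(\mathbf{w}\), central differences are exact here up to rounding, so the two vectors agree to many digits.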

We can test this by fitting a curve

\[ f_{\mathbf{w}} (x) = w_1 x^2 + w_2 x + w_3\]

to the points \((0.5, 2), (3.2, 1), (5.2, 4)\).

// Model function
let inline f (w:DV) (x:DV) =
    w.[0] * x.[0] * x.[0] + w.[1] * x.[0] + w.[2] 

// Points
let points = [|0.5, 2.
               3.2, 1.
               5.2, 4.|]

// Construct training set using the points
let train = Array.map (fun x -> (toDV [fst x]), (toDV [snd x])) points

// Find w minimizing the error of fit
let wopt = sgd f (toDV [0.; 0.; 0.]) (D 0.0001) (D 0.01) train
val wopt : DV = DV [|0.3854891266; -1.761000211; 2.731836811|]
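Because the three \(x\) values are distinct, exactly one quadratic passes through all three points, so the minimum of \(Q\) is 0 and it is attained at a unique \(\mathbf{w}\). A minimal plain F# sketch (independent of DiffSharp; `det3`, `v` and `wExact` are hypothetical names) computes that reference solution with Cramer's rule, giving a yardstick for the SGD result.

```fsharp
// Determinant of a 3x3 matrix by cofactor expansion along the first row
let det3 (m: float[,]) =
    let a = m.[0, 0] * (m.[1, 1] * m.[2, 2] - m.[1, 2] * m.[2, 1])
    let b = m.[0, 1] * (m.[1, 0] * m.[2, 2] - m.[1, 2] * m.[2, 0])
    let c = m.[0, 2] * (m.[1, 0] * m.[2, 1] - m.[1, 1] * m.[2, 0])
    a - b + c

let points = [| 0.5, 2.0; 3.2, 1.0; 5.2, 4.0 |]

// Row i of V is (x_i^2, x_i, 1); y holds the target values
let v = Array2D.init 3 3 (fun i j -> (fst points.[i]) ** float (2 - j))
let y = points |> Array.map snd

// Cramer's rule: w_j = det(V with column j replaced by y) / det(V)
let wExact =
    let dV = det3 v
    Array.init 3 (fun j ->
        let vj = Array2D.init 3 3 (fun i k -> if k = j then y.[i] else v.[i, k])
        det3 vj / dV)

printfn "%A" wExact
```

This yields approximately \(\mathbf{w} = (0.3980, -1.8428, 2.8219)\); the weights found by sgd above are close to, but not exactly, these values.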

We can plot the points in the training set and the fitted curve.

open FSharp.Charting

Chart.Combine([Chart.Line([for x in 0. .. 0.1 .. 6. -> (x, float <| f wopt (toDV [x]))])
               Chart.Point(points, MarkerSize = 10)])
Chart: the training points and the fitted curve.
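Since rnd is not seeded, sgd samples a different sequence of points on each run, so the weights it returns vary from run to run. For a repeatable comparison, here is a plain F# sketch that does not use DiffSharp. Its assumptions differ from sgd above: the seed is fixed, the gradient of the squared error \((y_i - f_{\mathbf{w}}(x_i))^2\) is written out by hand, it runs a fixed number of steps instead of a threshold test, and it uses a larger step size \(\eta = 0.001\). A step on sample \(i\) multiplies that sample's residual by \(1 - 2\eta \Vert (x_i^2, x_i, 1) \Vert^2\), so \(\eta\) has to stay below about 0.0013 for the point at \(x = 5.2\). The name `sgdFixed` is hypothetical.

```fsharp
// Features of the quadratic model f_w(x) = w1 x^2 + w2 x + w3
let features (x: float) = [| x * x; x; 1.0 |]

// Gradient of the squared error (y - f_w(x))^2 of one sample
let sampleGrad (w: float[]) (x: float, y: float) =
    let v = features x
    let r = y - Array.fold2 (fun acc wi vi -> acc + wi * vi) 0.0 w v
    v |> Array.map (fun vi -> -2.0 * r * vi)

// SGD with a constant step size, a fixed seed and a fixed number of steps
let sgdFixed (seed: int) (eta: float) (steps: int) (samples: (float * float)[]) (w0: float[]) =
    let rnd = System.Random(seed)
    let mutable w = Array.copy w0
    for _ in 1 .. steps do
        let g = sampleGrad w samples.[rnd.Next(samples.Length)]
        w <- Array.map2 (fun wi gi -> wi - eta * gi) w g
    w

let points = [| 0.5, 2.0; 3.2, 1.0; 5.2, 4.0 |]
let wFit = sgdFixed 1 0.001 1000000 points [| 0.0; 0.0; 0.0 |]
printfn "%A" wFit
```

Because the three points can be fitted exactly, every residual can be driven to zero, and with this step size the weights settle on the interpolating quadratic, approximately \((0.3980, -1.8428, 2.8219)\).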