DiffSharp

Stochastic gradient descent is a stochastic variant of the gradient descent algorithm, used to minimize loss functions that take the form of a sum

$Q(\mathbf{w}) = \sum_{i=1}^{d} Q_i(\mathbf{w}) \; ,$

where $$\mathbf{w}$$ is a "weight" vector that is being optimized. The component $$Q_i$$ is the contribution of the $$i$$-th sample to the overall loss $$Q$$, which is to be minimized using a training set of $$d$$ samples.

Using the standard gradient descent algorithm, $$Q$$ can be minimized by the iteration

$\begin{eqnarray*} \mathbf{w}_{t+1} &=& \mathbf{w}_t - \eta \nabla Q(\mathbf{w}_t) \\ &=& \mathbf{w}_t - \eta \sum_{i=1}^{d} \nabla Q_i(\mathbf{w}_t) \; , \end{eqnarray*}$

where $$\eta > 0$$ is a step size. This "batch" update rule has to evaluate the gradient of the full loss in each step, so the evaluation time of a step is proportional to the size $$d$$ of the training set.

Alternatively, in stochastic gradient descent, $$Q$$ is minimized using

$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla Q_i (\mathbf{w}_t) \; ,$

updating the weights $$\mathbf{w}$$ in each step using just one sample $$i$$ chosen at random from the training set. This is advantageous for large training sets, because it makes the evaluation time of each step independent of $$d$$. Another advantage is that samples can be processed on the fly, as in an online learning task.
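To make the difference between the two update rules concrete, here is a minimal sketch in plain F# that does not use DiffSharp. Everything in it is an assumption made for illustration: the toy data, the scalar model $$y = w x$$, the squared-error form of $$Q_i$$ with its hand-written derivative, and the helper names `gradQi`, `batchStep`, `stochasticStep` and `iterate`.

```fsharp
// Hypothetical toy problem (not from the text): fit y = w * x with a scalar weight w,
// using the squared error Q_i(w) = (y_i - w * x_i)^2 for each sample.
let samples = [| (1.0, 2.0); (2.0, 4.0); (3.0, 6.0) |]   // (x_i, y_i), generated with w = 2

// Derivative of one summand: dQ_i/dw = -2 * x_i * (y_i - w * x_i)
let gradQi (w: float) (x: float, y: float) = -2.0 * x * (y - w * x)

// Batch step: sums the derivatives over all d samples, so its cost grows with d.
let batchStep (eta: float) (w: float) =
    w - eta * Array.sumBy (gradQi w) samples

// Stochastic step: uses one randomly chosen sample, so its cost does not depend on d.
let rng = System.Random(0)   // fixed seed only to make runs repeatable
let stochasticStep (eta: float) (w: float) =
    let i = rng.Next(samples.Length)
    w - eta * gradQi w samples.[i]

// Apply a step function n times, starting from w0.
let iterate step n w0 = Seq.fold (fun w _ -> step w) w0 (seq { 1 .. n })

let wBatch = iterate (batchStep 0.01) 200 0.0
let wSgd = iterate (stochasticStep 0.01) 2000 0.0
printfn "batch: %f, stochastic: %f" wBatch wSgd
```

Because the toy data are noise-free, every sample pulls $$w$$ towards the same value, so both variants approach $$w = 2$$ here. With noisy data the stochastic iterates would keep fluctuating around the minimizer when the step size is held constant.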

In practice, the constant step size $$\eta$$ is usually replaced by a decreasing sequence of step sizes $$\eta_t$$, so that the iteration can converge.
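One common way to make this requirement precise is the classical Robbins–Monro pair of conditions, given here as background rather than as something the implementation below relies on:

$\sum_{t=1}^{\infty} \eta_t = \infty \quad \text{and} \quad \sum_{t=1}^{\infty} \eta_t^2 < \infty \; .$

A schedule such as $$\eta_t = \eta_0 / t$$, with a hypothetical initial step size $$\eta_0 > 0$$, satisfies both: the steps shrink slowly enough to reach any point, yet fast enough to damp the noise introduced by sampling.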

Let's implement stochastic gradient descent with a constant step size.

    open DiffSharp.AD.Float64

    let rnd = new System.Random()

    // Stochastic gradient descent
    // f: function, w0: starting weights, eta: step size, epsilon: threshold, t: training set
    let sgd f w0 (eta:D) epsilon (t:(DV*DV)[]) =
        let rec desc w =
            let x, y = t.[rnd.Next(t.Length)]
            let g = grad (fun wi -> DV.l2norm (y - (f wi x))) w
            if DV.l2norm g < epsilon then w else desc (w - eta * g)
        desc w0

In this implementation $$Q_i$$ has the form

$Q_i(\mathbf{w}) = \left\Vert \mathbf{y}_i - f_{\mathbf{w}} (\mathbf{x}_i) \right\Vert \; ,$

where $$f_{\mathbf{w}} : \mathbb{R}^n \to \mathbb{R}^m$$ is a model function for our data (parameterized by $$\mathbf{w}$$) and $$\mathbf{x}_i \in \mathbb{R}^n$$ and $$\mathbf{y}_i \in \mathbb{R}^m$$ are the input–output pair of the $$i$$-th sample in the training set. Finding the parameters $$\mathbf{w}$$ minimizing $$Q(\mathbf{w}) = \sum_{i=1}^{d} Q_i (\mathbf{w})$$ thus fits the model function $$f_{\mathbf{w}}$$ to our data.

We can test this by fitting a curve

$f_{\mathbf{w}} (x) = w_1 x^2 + w_2 x + w_3$

to the points $$(0.5, 2), (3.2, 1), (5.2, 4)$$.

    // Model function
    let inline f (w:DV) (x:DV) = w.[0] * x.[0] * x.[0] + w.[1] * x.[0] + w.[2]

    // Points
    let points = [| (0.5, 2.); (3.2, 1.); (5.2, 4.) |]

    // Construct training set using the points
    let train = Array.map (fun x -> (toDV [fst x]), (toDV [snd x])) points

    // Find w minimizing the error of fit
    let wopt = sgd f (toDV [0.; 0.; 0.]) (D 0.0001) (D 0.01) train

    val wopt : DV = DV [|0.3854891266; -1.761000211; 2.731836811|]
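As a quick sanity check on this result, the sketch below evaluates the fitted curve at the three training points using plain floats. The weights are copied from the output above, and the names `wReported`, `model`, `pts` and `residuals` are introduced here only for illustration. Since samples are drawn at random, another run of `sgd` would generally return different weights.

```fsharp
// Weights copied from the output shown above; a different run would give other values.
let wReported = [| 0.3854891266; -1.761000211; 2.731836811 |]

// The same quadratic model as f, written with plain floats instead of DV.
let model (w: float[]) (x: float) = w.[0] * x * x + w.[1] * x + w.[2]

// Training points (x_i, y_i) from the text.
let pts = [| (0.5, 2.0); (3.2, 1.0); (5.2, 4.0) |]

// Absolute residual |y_i - f_w(x_i)| for each training point.
let residuals = pts |> Array.map (fun (x, y) -> abs (y - model wReported x))

residuals |> Array.iteri (fun i r -> printfn "point %d: residual %.4f" i r)
```

For the reported weights each residual is below 0.06, so the curve passes close to all three points without interpolating them exactly.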

We can plot the points in the training set and the fitted curve.

    open FSharp.Charting

    Chart.Combine(
        [ Chart.Line([ for x in 0. .. 0.1 .. 6. -> (x, float <| f wopt (toDV [x])) ])
          Chart.Point(points, MarkerSize = 10) ])