diff --git a/docs/index.md b/docs/index.md
index 1d89399..66c4489 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -11,9 +11,9 @@ This case study looks at gradient descent, and the application of gradient desce
 
 ## Gradient Descent for Function Fitting
 
-At the time of writing (September 2022), [Stable Diffusion](https://stablediffusionweb.com/) is one of the newest, and best, text-to-image programs. Give it a try! Enter some text and see what image yous can get it to produce. It's certainly impressive, though the results are still sometimes a little bit odd.
+At the time of writing (September 2022), [Stable Diffusion](https://stablediffusionweb.com/) is one of the newest, and best, text-to-image programs. Give it a try! Enter some text and see what image you can get it to produce. It's certainly impressive, though the results are still sometimes a little bit odd.
 
-At it's core, Stable Diffusion and similar programs such as Midjourney, are functions. Remember the core idea of a function is that you put something and get something back. In this case you put in text and get back an image. What makes these functions particularly interesting is that parts of the function are learned from data. The data consists of example of text and images associated with them. The general shape of the function is fixed but many parts of it, called weights, are adjusted so that, given input, the output becomes closer to that in data used for learning.
+At their core, Stable Diffusion and similar programs, such as Midjourney, are functions. Remember the core idea of a function: you put something in and get something back. In this case you put in text and get back an image. What makes these functions particularly interesting is that parts of the function are learned from data. The data consists of pairs of text and associated images. The general shape of the function is fixed, but many parts of it, called weights, are adjusted so that, for a given input text, the output becomes closer to the associated image.
 
 An example will help make this clearer. Consider the function below. We'll call this function the *model*.
 
@@ -32,11 +32,13 @@ The model is a function with two parameters:
 
 (Note that I'm defining functions so that the code is closer to the mathematics, but we could equally use a Scala method.)
 
-You can play with the demo below, to see how changing the value of `a` changes the model.
+You can play with the demo below. See how changing the value of `a` changes the model.
 
 @:doodle(draw-basic-plot, Sine.drawBasicPlot)
 
-Now imagine we have some data, which are pairs of `x` and `y` values. For each `x` value we have the `y` value we'd like the model to produce. We can adjust the value of `a` to bring the model closer or further away from the output. To quantify how good a choice we've made for `a`, we can look at the distance between the model output and the `y` value for each data point in our data set. We'll call this the *loss function* or just the *loss*. The demo below allows you to adjust `a` and see how the the loss changes for some randomly choosen data. You should note that you can increase and decrease the loss by changing `a`.
+Now imagine we have some data, which are pairs of `x` and `y` values. For each `x` value we have the `y` value we'd like the model to produce. We can adjust the value of `a` to bring the model closer to or further from the output.
+
+To quantify how good a choice we've made for `a`, we can look at the distance between the model output and the `y` value for each data point in our data set. We'll call this the *loss function* or just the *loss*. The demo below allows you to adjust `a` and see how the loss changes for some randomly chosen data. You should note that you can increase and decrease the loss by changing `a`.
 
 @:doodle(draw-error-plot, Sine.drawErrorPlot)
 
@@ -53,7 +55,7 @@ Now the final piece of the puzzle is to tell the computer how to adjust the para
 
 We use the term *differentiation* for finding the gradient of a function.
 
-To really formalize this we need to be a bit more precise about about the error function. For one particular data point, the loss is
+To really formalize this we need to be a bit more precise about the loss function. For one particular data point, the loss is
 
 $$ pointLoss(f, a, x, y) = (f(x, a) - y)^2 $$
 
@@ -69,7 +71,46 @@ val pointLoss: ((Double, Double) => Double, Double, Double, Double) => Double =
 
 This means the loss for a single point is always non-negative.
 
-
 Now we just have to sum up the loss over all the data points to get what we commonly call the loss.
 
 $$ loss(data, f, a) = \sum_{pt \in data}pointLoss(f, a, pt.x, pt.y)$$
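+
+To make this concrete in code, here is a minimal Scala sketch of the whole pipeline. It assumes the model is the sine function `f(x, a) = sin(a * x)`, as the `Sine` demos suggest, and it represents data points with a hypothetical `Point` case class; the case study's actual definitions may differ.
+
+```scala
+final case class Point(x: Double, y: Double)
+
+// The model, assumed here to be f(x, a) = sin(a * x).
+val f: (Double, Double) => Double =
+  (x, a) => math.sin(a * x)
+
+// The loss for a single data point: the squared difference between
+// the model's output and the observed y value.
+val pointLoss: ((Double, Double) => Double, Double, Double, Double) => Double =
+  (f, a, x, y) => {
+    val error = f(x, a) - y
+    error * error
+  }
+
+// The total loss: the sum of the per-point losses over the data set.
+val loss: (List[Point], (Double, Double) => Double, Double) => Double =
+  (data, f, a) => data.map(pt => pointLoss(f, a, pt.x, pt.y)).sum
+```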
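+
+The gradient descent update itself can be sketched in the same style. This version approximates the gradient of the loss with respect to `a` numerically, with a central finite difference; an implementation could equally derive the gradient analytically. The names `gradient`, `step`, and `learningRate` are illustrative, not taken from the case study.
+
+```scala
+// Numerically approximate the gradient of the loss with respect to a,
+// using a central finite difference with a small step h.
+def gradient(data: List[Point], f: (Double, Double) => Double, a: Double): Double = {
+  val h = 1e-6
+  (loss(data, f, a + h) - loss(data, f, a - h)) / (2 * h)
+}
+
+// One step of gradient descent: adjust a against the gradient by a small
+// amount (the learning rate), which should reduce the loss.
+def step(data: List[Point], f: (Double, Double) => Double, a: Double, learningRate: Double): Double =
+  a - learningRate * gradient(data, f, a)
+```
+
+Iterating `step` from a starting guess for `a` should then decrease the loss, which is the behaviour the demos above let you explore by hand.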