-
Notifications
You must be signed in to change notification settings - Fork 659
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This creates the Patch concept along with some start of usages. There is a more specialized ParamPatch for the standard parameter additive patches and a Scaled, Basic, and LoRA implementation. The patches can be created directly, by comparing models, and from gradients. This is an initial step. Following this, there are a few pieces of work that could be considered: 1. DJL Serving Python engine specific patch implementation 2. LoRA for full training 3. Make BasicParamPatch from Optimizer (including gradients, momentum, and lr)
- Loading branch information
Showing
8 changed files
with
552 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.patch; | ||
|
||
import ai.djl.Model; | ||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.nn.Block; | ||
import ai.djl.nn.Parameter; | ||
import ai.djl.nn.ParameterList; | ||
import ai.djl.training.GradientCollector; | ||
import ai.djl.util.Pair; | ||
|
||
import java.util.Map; | ||
import java.util.concurrent.ConcurrentHashMap; | ||
|
||
/** The basic implementation of a {@link ParamPatch}. */ | ||
public class BasicParamPatch extends ParamPatch { | ||
|
||
Map<String, NDArray> data; | ||
|
||
/** | ||
* Constructs a {@link BasicParamPatch} with patching data. | ||
* | ||
* @param data the patching data | ||
*/ | ||
public BasicParamPatch(Map<String, NDArray> data) { | ||
this.data = data; | ||
} | ||
|
||
/** | ||
* Makes a patch by comparing two models. | ||
* | ||
* @param source the source model | ||
* @param target the target model | ||
* @return a patch that would transform the source model to the target model | ||
*/ | ||
public static BasicParamPatch makePatch(Model source, Model target) { | ||
return BasicParamPatch.makePatch(source.getBlock(), target.getBlock()); | ||
} | ||
|
||
/** | ||
* Makes a patch by comparing two blocks. | ||
* | ||
* @param source the source block | ||
* @param target the target block | ||
* @return a patch that would transform the source block to the target block | ||
*/ | ||
public static BasicParamPatch makePatch(Block source, Block target) { | ||
return BasicParamPatch.makePatch(source.getParameters(), target.getParameters()); | ||
} | ||
|
||
/** | ||
* Makes a patch by comparing two {@link ParameterList}s. | ||
* | ||
* @param source the source {@link ParameterList} | ||
* @param target the target {@link ParameterList} | ||
* @return a patch that would transform the source {@link ParameterList} to the target {@link | ||
* ParameterList}. | ||
*/ | ||
public static BasicParamPatch makePatch(ParameterList source, ParameterList target) { | ||
Map<String, NDArray> data = new ConcurrentHashMap<>(source.size()); | ||
for (Pair<String, Parameter> sourcePair : source) { | ||
String key = sourcePair.getKey(); | ||
NDArray patchValue = target.get(key).getArray().sub(sourcePair.getValue().getArray()); | ||
data.put(key, patchValue); | ||
} | ||
return new BasicParamPatch(data); | ||
} | ||
|
||
/** | ||
* Makes a patch from gradients. | ||
* | ||
* <p>This does not include learning rates or any other data from the {@link | ||
* ai.djl.training.optimizer.Optimizer}. | ||
* | ||
* <p>Making the patch does not modify the existing gradients. After this, you can call {@link | ||
* GradientCollector#zeroGradients()} to clear the gradients. | ||
* | ||
* @param block the block for which to collect gradients | ||
* @param gradientCollector the {@link GradientCollector} of the gradients | ||
* @return the gradients as a {@link BasicParamPatch}. | ||
*/ | ||
public static BasicParamPatch makePatch(Block block, GradientCollector gradientCollector) { | ||
ParameterList params = block.getParameters(); | ||
Map<String, NDArray> data = new ConcurrentHashMap<>(params.size()); | ||
for (Pair<String, Parameter> param : params) { | ||
String key = param.getKey(); | ||
// Get gradient * -1 to account for gradient being subtracted from param | ||
NDArray patchValue = param.getValue().getArray().getGradient().duplicate().mul(-1); | ||
data.put(key, patchValue); | ||
} | ||
return new BasicParamPatch(data); | ||
} | ||
|
||
/** | ||
* Makes a patch from gradients. | ||
* | ||
* <p>This does not include learning rates or any other data from the {@link | ||
* ai.djl.training.optimizer.Optimizer}. | ||
* | ||
* <p>Making the patch does not modify the existing gradients. After this, you can call {@link | ||
* GradientCollector#zeroGradients()} to clear the gradients. | ||
* | ||
* @param model the model for which to collect gradients | ||
* @param gradientCollector the {@link GradientCollector} of the gradients | ||
* @return the gradients as a {@link BasicParamPatch}. | ||
*/ | ||
public static BasicParamPatch makePatch(Model model, GradientCollector gradientCollector) { | ||
return makePatch(model.getBlock(), gradientCollector); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public NDArray getPatch(String paramName) { | ||
return data.get(paramName).duplicate(); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public void close() { | ||
for (NDArray d : data.values()) { | ||
d.close(); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.patch; | ||
|
||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.util.Pair; | ||
|
||
import java.util.Map; | ||
|
||
/** | ||
* A {@link ParamPatch} based on low-rank adapters. | ||
* | ||
* <p>Based on the paper <a href="https://arxiv.org/abs/2106.09685">LoRA: Low-Rank Adaptation of | ||
* Large Language Models</a>. | ||
* | ||
* <p>TODO This support for LoRA is still a placeholder and needs effective code for creating and | ||
* training | ||
*/ | ||
public class LoRA extends ParamPatch { | ||
|
||
/** Data of type map from param name to (A, B) pair. */ | ||
Map<String, Pair<NDArray, NDArray>> data; | ||
|
||
/** | ||
* Constructs a {@link LoRA}. | ||
* | ||
* @param data the data to patch with | ||
*/ | ||
public LoRA(Map<String, Pair<NDArray, NDArray>> data) { | ||
this.data = data; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public NDArray getPatch(String paramName) { | ||
Pair<NDArray, NDArray> d = data.get(paramName); | ||
return d.getKey().get(paramName).matMul(d.getValue().get(paramName)); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public void close() { | ||
for (Pair<NDArray, NDArray> d : data.values()) { | ||
d.getKey().close(); | ||
d.getValue().close(); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.patch; | ||
|
||
import ai.djl.Model; | ||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.nn.Block; | ||
import ai.djl.nn.Parameter; | ||
import ai.djl.nn.ParameterList; | ||
import ai.djl.util.Pair; | ||
|
||
/** | ||
* A standard {@link Patch} that only adds to {@link Parameter}s. | ||
* | ||
* <p>To create a param patch, see {@link BasicParamPatch}. | ||
*/ | ||
public abstract class ParamPatch extends Patch { | ||
|
||
/** | ||
* Scales the patch by a scalar multiplier. | ||
* | ||
* @param scale the scalar multiplier for each patch NDArray | ||
* @return a new patch that is a scaled version of this patch | ||
*/ | ||
public ParamPatch scale(float scale) { | ||
return new ScaledParamPatch(scale, this); | ||
} | ||
|
||
/** | ||
* Returns a new {@link ParamPatch} that is the additive inverse of this patch. | ||
* | ||
* <p>It is equivalent to scaling by -1. | ||
* | ||
* @return a new {@link ParamPatch} that is the additive inverse of this patch | ||
*/ | ||
public ParamPatch reverse() { | ||
return scale(-1); | ||
} | ||
|
||
/** | ||
* Returns the patch {@link NDArray} for a particular paramName. | ||
* | ||
* @param paramName the parameter path in a {@link ParameterList}. | ||
* @return the patch array | ||
*/ | ||
public abstract NDArray getPatch(String paramName); | ||
|
||
/** | ||
* Applies the part of this patch to a particular {@link Parameter}. | ||
* | ||
* @param paramName the parameter path in a {@link ParameterList}. | ||
* @param param the {@link Parameter} to patch | ||
*/ | ||
public void apply(String paramName, Parameter param) { | ||
NDArray p = getPatch(paramName).duplicate(); | ||
param.getArray().addi(p); | ||
p.close(); | ||
} | ||
|
||
/** | ||
* Applies this patch to a {@link ParameterList}. | ||
* | ||
* @param params the params to patch | ||
*/ | ||
public void apply(ParameterList params) { | ||
for (Pair<String, Parameter> param : params) { | ||
apply(param.getKey(), param.getValue()); | ||
} | ||
} | ||
|
||
/** | ||
* Applies this patch to a {@link Block}. | ||
* | ||
* @param block the block to patch | ||
*/ | ||
public void apply(Block block) { | ||
apply(block.getParameters()); | ||
} | ||
|
||
/** | ||
* Applies this patch to a {@link Model}. | ||
* | ||
* @param model the model to patch | ||
*/ | ||
@Override | ||
public void apply(Model model) { | ||
apply(model.getBlock()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.patch; | ||
|
||
import ai.djl.Model; | ||
|
||
/** | ||
* A method for modifying a {@link Model}. | ||
* | ||
* <p>The most standard form is the {@link ParamPatch}. | ||
*/ | ||
public abstract class Patch implements AutoCloseable { | ||
|
||
/** | ||
* Applies this patch to a model. | ||
* | ||
* @param model the model to update with the patch | ||
*/ | ||
public abstract void apply(Model model); | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public abstract void close(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.patch; | ||
|
||
import ai.djl.ndarray.NDArray; | ||
|
||
/** | ||
* Constructs a {@link ScaledParamPatch} to scale a {@link ParamPatch} by a scalar multiplier. | ||
* | ||
* @see ParamPatch#scale(float) | ||
*/ | ||
public class ScaledParamPatch extends ParamPatch { | ||
|
||
float scale; | ||
ParamPatch base; | ||
|
||
/** | ||
* Constructs a {@link ScaledParamPatch}. | ||
* | ||
* @param scale the scalar multiplier | ||
* @param base the {@link ParamPatch} to scale | ||
*/ | ||
public ScaledParamPatch(float scale, ParamPatch base) { | ||
if (base instanceof ScaledParamPatch) { | ||
ScaledParamPatch sbase = (ScaledParamPatch) base; | ||
this.scale = scale * sbase.scale; | ||
this.base = sbase.base; | ||
} else { | ||
this.scale = scale; | ||
this.base = base; | ||
} | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public NDArray getPatch(String paramName) { | ||
return base.getPatch(paramName).muli(scale); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public void close() { | ||
base.close(); | ||
} | ||
} |
Oops, something went wrong.