Skip to content

Commit

Permalink
Creates Patch
Browse files Browse the repository at this point in the history
This creates the Patch concept along with some start of usages. There is a more
specialized ParamPatch for the standard parameter additive patches and a Scaled,
Basic, and LoRA implementation. The patches can be created directly, by
comparing models, and from gradients.

This is an initial step. Following this, there are a few pieces of work that
could be considered:
1. DJL Serving Python engine specific patch implementation
2. LoRA for full training
3. Make BasicParamPatch from Optimizer (including gradients, momentum, and lr)
  • Loading branch information
zachgk committed Jul 19, 2023
1 parent 1397b2c commit 817484e
Show file tree
Hide file tree
Showing 8 changed files with 552 additions and 0 deletions.
135 changes: 135 additions & 0 deletions api/src/main/java/ai/djl/patch/BasicParamPatch.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.patch;

import ai.djl.Model;
import ai.djl.ndarray.NDArray;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.training.GradientCollector;
import ai.djl.util.Pair;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** The basic implementation of a {@link ParamPatch}. */
public class BasicParamPatch extends ParamPatch {

Map<String, NDArray> data;

/**
* Constructs a {@link BasicParamPatch} with patching data.
*
* @param data the patching data
*/
public BasicParamPatch(Map<String, NDArray> data) {
this.data = data;
}

/**
* Makes a patch by comparing two models.
*
* @param source the source model
* @param target the target model
* @return a patch that would transform the source model to the target model
*/
public static BasicParamPatch makePatch(Model source, Model target) {
return BasicParamPatch.makePatch(source.getBlock(), target.getBlock());
}

/**
* Makes a patch by comparing two blocks.
*
* @param source the source block
* @param target the target block
* @return a patch that would transform the source block to the target block
*/
public static BasicParamPatch makePatch(Block source, Block target) {
return BasicParamPatch.makePatch(source.getParameters(), target.getParameters());
}

/**
* Makes a patch by comparing two {@link ParameterList}s.
*
* @param source the source {@link ParameterList}
* @param target the target {@link ParameterList}
* @return a patch that would transform the source {@link ParameterList} to the target {@link
* ParameterList}.
*/
public static BasicParamPatch makePatch(ParameterList source, ParameterList target) {
Map<String, NDArray> data = new ConcurrentHashMap<>(source.size());
for (Pair<String, Parameter> sourcePair : source) {
String key = sourcePair.getKey();
NDArray patchValue = target.get(key).getArray().sub(sourcePair.getValue().getArray());
data.put(key, patchValue);
}
return new BasicParamPatch(data);
}

/**
* Makes a patch from gradients.
*
* <p>This does not include learning rates or any other data from the {@link
* ai.djl.training.optimizer.Optimizer}.
*
* <p>Making the patch does not modify the existing gradients. After this, you can call {@link
* GradientCollector#zeroGradients()} to clear the gradients.
*
* @param block the block for which to collect gradients
* @param gradientCollector the {@link GradientCollector} of the gradients
* @return the gradients as a {@link BasicParamPatch}.
*/
public static BasicParamPatch makePatch(Block block, GradientCollector gradientCollector) {
ParameterList params = block.getParameters();
Map<String, NDArray> data = new ConcurrentHashMap<>(params.size());
for (Pair<String, Parameter> param : params) {
String key = param.getKey();
// Get gradient * -1 to account for gradient being subtracted from param
NDArray patchValue = param.getValue().getArray().getGradient().duplicate().mul(-1);
data.put(key, patchValue);
}
return new BasicParamPatch(data);
}

/**
* Makes a patch from gradients.
*
* <p>This does not include learning rates or any other data from the {@link
* ai.djl.training.optimizer.Optimizer}.
*
* <p>Making the patch does not modify the existing gradients. After this, you can call {@link
* GradientCollector#zeroGradients()} to clear the gradients.
*
* @param model the model for which to collect gradients
* @param gradientCollector the {@link GradientCollector} of the gradients
* @return the gradients as a {@link BasicParamPatch}.
*/
public static BasicParamPatch makePatch(Model model, GradientCollector gradientCollector) {
return makePatch(model.getBlock(), gradientCollector);
}

/** {@inheritDoc} */
@Override
public NDArray getPatch(String paramName) {
return data.get(paramName).duplicate();
}

/** {@inheritDoc} */
@Override
public void close() {
for (NDArray d : data.values()) {
d.close();
}
}
}
58 changes: 58 additions & 0 deletions api/src/main/java/ai/djl/patch/LoRA.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.patch;

import ai.djl.ndarray.NDArray;
import ai.djl.util.Pair;

import java.util.Map;

/**
* A {@link ParamPatch} based on low-rank adapters.
*
* <p>Based on the paper <a href="https://arxiv.org/abs/2106.09685">LoRA: Low-Rank Adaptation of
* Large Language Models</a>.
*
* <p>TODO This support for LoRA is still a placeholder and needs effective code for creating and
* training
*/
public class LoRA extends ParamPatch {

/** Data of type map from param name to (A, B) pair. */
Map<String, Pair<NDArray, NDArray>> data;

/**
* Constructs a {@link LoRA}.
*
* @param data the data to patch with
*/
public LoRA(Map<String, Pair<NDArray, NDArray>> data) {
this.data = data;
}

/** {@inheritDoc} */
@Override
public NDArray getPatch(String paramName) {
Pair<NDArray, NDArray> d = data.get(paramName);
return d.getKey().get(paramName).matMul(d.getValue().get(paramName));
}

/** {@inheritDoc} */
@Override
public void close() {
for (Pair<NDArray, NDArray> d : data.values()) {
d.getKey().close();
d.getValue().close();
}
}
}
99 changes: 99 additions & 0 deletions api/src/main/java/ai/djl/patch/ParamPatch.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.patch;

import ai.djl.Model;
import ai.djl.ndarray.NDArray;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.util.Pair;

/**
* A standard {@link Patch} that only adds to {@link Parameter}s.
*
* <p>To create a param patch, see {@link BasicParamPatch}.
*/
public abstract class ParamPatch extends Patch {

/**
* Scales the patch by a scalar multiplier.
*
* @param scale the scalar multiplier for each patch NDArray
* @return a new patch that is a scaled version of this patch
*/
public ParamPatch scale(float scale) {
return new ScaledParamPatch(scale, this);
}

/**
* Returns a new {@link ParamPatch} that is the additive inverse of this patch.
*
* <p>It is equivalent to scaling by -1.
*
* @return a new {@link ParamPatch} that is the additive inverse of this patch
*/
public ParamPatch reverse() {
return scale(-1);
}

/**
* Returns the patch {@link NDArray} for a particular paramName.
*
* @param paramName the parameter path in a {@link ParameterList}.
* @return the patch array
*/
public abstract NDArray getPatch(String paramName);

/**
* Applies the part of this patch to a particular {@link Parameter}.
*
* @param paramName the parameter path in a {@link ParameterList}.
* @param param the {@link Parameter} to patch
*/
public void apply(String paramName, Parameter param) {
NDArray p = getPatch(paramName).duplicate();
param.getArray().addi(p);
p.close();
}

/**
* Applies this patch to a {@link ParameterList}.
*
* @param params the params to patch
*/
public void apply(ParameterList params) {
for (Pair<String, Parameter> param : params) {
apply(param.getKey(), param.getValue());
}
}

/**
* Applies this patch to a {@link Block}.
*
* @param block the block to patch
*/
public void apply(Block block) {
apply(block.getParameters());
}

/**
* Applies this patch to a {@link Model}.
*
* @param model the model to patch
*/
@Override
public void apply(Model model) {
apply(model.getBlock());
}
}
34 changes: 34 additions & 0 deletions api/src/main/java/ai/djl/patch/Patch.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.patch;

import ai.djl.Model;

/**
* A method for modifying a {@link Model}.
*
* <p>The most standard form is the {@link ParamPatch}.
*/
public abstract class Patch implements AutoCloseable {

/**
* Applies this patch to a model.
*
* @param model the model to update with the patch
*/
public abstract void apply(Model model);

/** {@inheritDoc} */
@Override
public abstract void close();
}
55 changes: 55 additions & 0 deletions api/src/main/java/ai/djl/patch/ScaledParamPatch.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.patch;

import ai.djl.ndarray.NDArray;

/**
* Constructs a {@link ScaledParamPatch} to scale a {@link ParamPatch} by a scalar multiplier.
*
* @see ParamPatch#scale(float)
*/
public class ScaledParamPatch extends ParamPatch {

float scale;
ParamPatch base;

/**
* Constructs a {@link ScaledParamPatch}.
*
* @param scale the scalar multiplier
* @param base the {@link ParamPatch} to scale
*/
public ScaledParamPatch(float scale, ParamPatch base) {
if (base instanceof ScaledParamPatch) {
ScaledParamPatch sbase = (ScaledParamPatch) base;
this.scale = scale * sbase.scale;
this.base = sbase.base;
} else {
this.scale = scale;
this.base = base;
}
}

/** {@inheritDoc} */
@Override
public NDArray getPatch(String paramName) {
return base.getPatch(paramName).muli(scale);
}

/** {@inheritDoc} */
@Override
public void close() {
base.close();
}
}
Loading

0 comments on commit 817484e

Please sign in to comment.