Add Regularizers 1 #216
L1.java (new file)

@@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.regularizers;

import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/**
 * A regularizer that applies an L1 (Lasso: least absolute shrinkage and selection operator)
 * regularization penalty.
 *
 * <p>The L1 regularization penalty is computed as: <code>loss = l1 * reduceSum(abs(x))</code>
 *
 * @param <R> the data type for the weights
 */
public class L1<R extends TNumber> extends L1L2<R> {

  /**
   * Creates a regularizer that applies an L1 regularization penalty of {@link
   * #DEFAULT_REGULARIZATION_PENALTY}
   *
   * @param tf the TensorFlow Ops
   * @param type the data type for the weights
   */
  public L1(Ops tf, Class<R> type) {
    this(tf, DEFAULT_REGULARIZATION_PENALTY, type);
  }

  /**
   * Creates a regularizer that applies an L1 regularization penalty
   *
   * @param tf the TensorFlow Ops
   * @param l1 the L1 regularization penalty
   * @param type the data type for the weights
   * @throws IllegalArgumentException if the l1 regularization factor is NaN or infinite
   */
  public L1(Ops tf, float l1, Class<R> type) {
    super(tf, l1, null, type);
  }
}
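A minimal usage sketch, not part of this PR: it assumes an eager session obtained via Ops.create() and the TFloat32 type from org.tensorflow.types, and the L1Example class name is invented for illustration. The expected penalty in the comment follows from the formula in the class javadoc.

import org.tensorflow.Operand;
import org.tensorflow.framework.regularizers.L1;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class L1Example {
  public static void main(String[] args) {
    Ops tf = Ops.create(); // eager execution
    L1<TFloat32> regularizer = new L1<>(tf, 0.01f, TFloat32.class);
    Operand<TFloat32> weights = tf.constant(new float[] {-1f, 2f, -3f});
    // penalty = 0.01 * (|-1| + |2| + |-3|) = 0.01 * 6 = 0.06
    Operand<TFloat32> penalty = regularizer.call(weights);
  }
}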
L1L2.java (new file)

@@ -0,0 +1,132 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.regularizers;

import org.tensorflow.Operand;
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/**
 * A regularizer that applies both L1 and L2 regularization penalties.
 *
 * <p>The L1 regularization penalty is computed as:
 *
 * <pre>loss = l1 * reduceSum(abs(x))</pre>
 *
 * <p>The L2 regularization penalty is computed as:
 *
 * <pre>loss = l2 * reduceSum(square(x))</pre>
 *
 * <p>The difference between this class and {@link L1_L2} is that {@link L1_L2} uses the default
 * regularization penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas this class defaults to
 * 0.
 *
 * @param <R> the data type for the weights
 */
public class L1L2<R extends TNumber> extends Regularizer<R> {
  private final Float l1;
  private final Float l2;

[Review thread on these fields: a question about the motivation for storing these as Float rather than a primitive; on second thought the nullable fields were deemed unnecessary — 👍 done.]

  /**
   * Creates an L1L2 regularizer with no L1 or L2 penalties (both default to 0)
   *
   * @param tf the TensorFlow Ops
   * @param type the data type for the weights
   */
  public L1L2(Ops tf, Class<R> type) {
    this(tf, null, null, type);
  }

  /**
   * Creates an L1L2 regularizer
   *
   * @param tf the TensorFlow Ops
   * @param l1 L1 regularization factor, if null it is set to 0.
   * @param l2 L2 regularization factor, if null it is set to 0.
   * @param type the data type for the weights
   * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN}
   *     or {@link Float#isInfinite}
   */
  public L1L2(Ops tf, Float l1, Float l2, Class<R> type) {
    super(tf, type);
    if (l1 != null) {
      if (l1.isNaN() || l1.isInfinite()) {
        throw new IllegalArgumentException(
            String.format(
                "L1 Value: %f is not a valid regularization penalty number; a positive/negative infinity or NaN is not a proper value",
                l1));
      }
      this.l1 = l1;
    } else {
      this.l1 = 0f;
    }
    if (l2 != null) {
      if (l2.isNaN() || l2.isInfinite()) {
        throw new IllegalArgumentException(
            String.format(
                "L2 Value: %f is not a valid regularization penalty number; a positive/negative infinity or NaN is not a proper value",
                l2));
      }
      this.l2 = l2;
    } else {
      this.l2 = 0f;
    }
  }

  /** {@inheritDoc} */
  @Override
  public Operand<R> call(Operand<R> input) {
    Ops tf = getTF();
    // The constructor coerces null factors to 0, so only zero checks are needed here.
    if (this.getL1() == 0.f && this.getL2() == 0.f) {
      return tf.dtypes.cast(tf.constant(0), input.type());
    }
    Operand<R> regularization = tf.dtypes.cast(tf.constant(0), input.type());

    if (this.getL1() != 0.f) {
      Operand<R> l1Op = tf.dtypes.cast(tf.constant(this.getL1()), input.type());
      Operand<R> abs = tf.math.abs(input);
      Operand<R> reduceSum = tf.reduceSum(abs, LossesHelper.allAxes(tf, input));
      regularization = tf.math.add(regularization, tf.math.mul(l1Op, reduceSum));
    }

    if (this.getL2() != 0.f) {
      Operand<R> l2Op = tf.dtypes.cast(tf.constant(this.getL2()), input.type());
      Operand<R> sqr = tf.math.square(input);

[Review thread: sqr was originally computed with tf.math.abs rather than tf.math.square; the author replied "Fixed." — 👍 done.]

      Operand<R> reduceSum = tf.reduceSum(sqr, LossesHelper.allAxes(tf, input));
      regularization = tf.math.add(regularization, tf.math.mul(l2Op, reduceSum));
    }

    return regularization;
  }

  /**
   * Gets the L1 regularization factor
   *
   * @return the L1 regularization factor
   */
  public Float getL1() {
    return l1;
  }

  /**
   * Gets the L2 regularization factor
   *
   * @return the L2 regularization factor
   */
  public Float getL2() {
    return l2;
  }
}
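To make the combined penalty concrete, a small worked sketch (same eager Ops.create() and TFloat32 assumptions as the earlier example; the values in the comments follow directly from the two formulas in the class javadoc):

Ops tf = Ops.create();
L1L2<TFloat32> reg = new L1L2<>(tf, 0.01f, 0.02f, TFloat32.class);
Operand<TFloat32> w = tf.constant(new float[] {1f, -2f, 3f});
Operand<TFloat32> penalty = reg.call(w);
// L1 term: 0.01 * (|1| + |-2| + |3|) = 0.01 * 6  = 0.06
// L2 term: 0.02 * (1 + 4 + 9)        = 0.02 * 14 = 0.28
// total penalty                      = 0.06 + 0.28 = 0.34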
L1_L2.java (new file)

@@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.regularizers;

import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/**
 * A regularizer that applies both L1 and L2 regularization penalties.
 *
 * <p>The L1 regularization penalty is computed as:
 *
 * <pre>loss = l1 * reduceSum(abs(x))</pre>
 *
 * <p>The L2 regularization penalty is computed as:
 *
 * <pre>loss = l2 * reduceSum(square(x))</pre>
 *
 * <p>The difference between this class and {@link L1L2} is that this class uses the default
 * regularization penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults
 * to 0.
 *
 * @param <R> the data type for the weights
 */
public class L1_L2<R extends TNumber> extends L1L2<R> {

  /**
   * Creates a regularizer that applies L1 and L2 regularization penalties of {@link
   * #DEFAULT_REGULARIZATION_PENALTY}
   *
   * @param tf the TensorFlow Ops
   * @param type the data type for the weights
   */
  public L1_L2(Ops tf, Class<R> type) {
    this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY, type);
  }

  /**
   * Creates a regularizer that applies L1 and L2 regularization penalties
   *
   * @param tf the TensorFlow Ops
   * @param l1 the L1 regularization penalty, if null it is set to {@link
   *     #DEFAULT_REGULARIZATION_PENALTY}
   * @param l2 the L2 regularization penalty, if null it is set to {@link
   *     #DEFAULT_REGULARIZATION_PENALTY}
   * @param type the data type for the weights
   * @throws IllegalArgumentException if the l1 or l2 regularization factor is NaN or infinite
   */
  public L1_L2(Ops tf, Float l1, Float l2, Class<R> type) {
    super(tf,
        l1 == null ? DEFAULT_REGULARIZATION_PENALTY : l1,
        l2 == null ? DEFAULT_REGULARIZATION_PENALTY : l2,
        type);
  }
}
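A short sketch contrasting the no-factor behavior of the two classes (assuming eager Ops.create() and TFloat32 as above):

Ops tf = Ops.create();
// L1L2 without explicit factors: both default to 0, so call() contributes nothing.
L1L2<TFloat32> zeroReg = new L1L2<>(tf, TFloat32.class);
// L1_L2 without explicit factors: both default to DEFAULT_REGULARIZATION_PENALTY (0.01f).
L1_L2<TFloat32> defaultReg = new L1_L2<>(tf, TFloat32.class);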
L2.java (new file)

@@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.regularizers;

import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/**
 * A regularizer that applies an L2 (Ridge Regression) regularization penalty.
 *
 * <p>The L2 regularization penalty is computed as: <code>loss = l2 * reduceSum(square(x))</code>
 *
 * @param <R> the data type for the operands and result
 */
public class L2<R extends TNumber> extends L1L2<R> {

  /**
   * Creates a regularizer that applies an L2 regularization penalty of {@link
   * #DEFAULT_REGULARIZATION_PENALTY}
   *
   * @param tf the TensorFlow Ops
   * @param type the data type for the weights
   */
  public L2(Ops tf, Class<R> type) {
    this(tf, DEFAULT_REGULARIZATION_PENALTY, type);
  }

  /**
   * Creates a regularizer that applies an L2 regularization penalty
   *
   * @param tf the TensorFlow Ops
   * @param l2 the L2 regularization penalty
   * @param type the data type for the weights
   * @throws IllegalArgumentException if the l2 regularization factor is NaN or infinite
   */
  public L2(Ops tf, float l2, Class<R> type) {
    super(tf, null, l2, type);
  }
}
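And a matching sketch for L2 alone (same assumptions as the previous examples):

Ops tf = Ops.create();
L2<TFloat32> reg = new L2<>(tf, 0.02f, TFloat32.class);
Operand<TFloat32> w = tf.constant(new float[] {3f, 4f});
// penalty = 0.02 * (3*3 + 4*4) = 0.02 * 25 = 0.5
Operand<TFloat32> penalty = reg.call(w);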
Regularizer.java (new file)

@@ -0,0 +1,92 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.regularizers;

import org.tensorflow.Operand;
import org.tensorflow.framework.losses.Loss;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/**
 * Base class for Regularizers
 *
 * <p>Regularizers allow you to apply penalties on layer parameters or layer activity during
 * optimization. These penalties are summed into the loss function that the network optimizes.
 *
 * @param <R> the data type of the operands and result
 */
public abstract class Regularizer<R extends TNumber> {
[Review thread: a question about the motivation for part of this class declaration; answered with "I will fix." — 👍 done.]

  public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f;

  private final Ops tf;
  private final String name;
  protected Class<R> type;
[Review thread on the constructor docs: "To me, this is one of those header docs that should be either elaborated or omitted." — 👍 done.]

  /**
   * Creates a Regularizer whose name defaults to the simple class name
   *
   * @param tf the TensorFlow ops
   * @param type the data type for the operands and result
   */
  protected Regularizer(Ops tf, Class<R> type) {
    this(tf, null, type);
  }

  /**
   * Creates a Regularizer
   *
   * @param tf the TensorFlow ops
   * @param name the name of this regularizer; if null, the simple class name is used
   * @param type the data type for the operands and result
   */
  protected Regularizer(Ops tf, String name, Class<R> type) {
    this.tf = tf;
    this.type = type;
    this.name = name == null ? this.getClass().getSimpleName() : name;
  }

  /**
   * Returns this Regularizer as a Loss. This is a convenience for using a regularizer to
   * regularize a loss; only sampleWeights are applied to the regularizer.
   *
   * @return this Regularizer as a Loss
   */
  public Loss asLoss() {
    return new RegularizerLoss<>(this.tf, this);
  }

  /**
   * Computes a regularization penalty from an input.
   *
   * @param input the weighted input
   * @return the result of computing the regularization penalty
   */
  public abstract Operand<R> call(Operand<R> input);

  /**
   * Gets the TensorFlow Ops
   *
   * @return the TensorFlow Ops
   */
  public Ops getTF() {
    return tf;
  }

  /**
   * Gets the name for this regularizer
   *
   * @return the name for this regularizer
   */
  public String getName() {
    return name;
  }
}
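Since call is the only abstract method, a subclass needs just a constructor and one override. A hypothetical sketch (the LInfinity class and its factor field are invented for illustration; it assumes tf.reduceMax is available alongside the tf.reduceSum used above, and the same LossesHelper.allAxes helper used in L1L2):

/** Hypothetical regularizer: penalty = factor * max(|x|) over all elements. */
class LInfinity<R extends TNumber> extends Regularizer<R> {

  private final float factor;

  LInfinity(Ops tf, float factor, Class<R> type) {
    super(tf, type); // name defaults to the simple class name, "LInfinity"
    this.factor = factor;
  }

  @Override
  public Operand<R> call(Operand<R> input) {
    Ops tf = getTF();
    Operand<R> f = tf.dtypes.cast(tf.constant(factor), input.type());
    // Largest absolute entry, reduced over all axes to a scalar.
    Operand<R> maxAbs = tf.reduceMax(tf.math.abs(input), LossesHelper.allAxes(tf, input));
    return tf.math.mul(f, maxAbs);
  }
}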
Review discussion

Reviewer: The default of no penalty doesn't seem useful. I wonder if L1_L2 was created later in Python to provide useful defaults? Is it possible we should omit L1L2 on the Java side?

Author: Actually it looks like l1_l2 was the afterthought in Keras, as it is defined as a method. I am ok with adding the l1_l2 defaults to L1L2, and perhaps adding a static method to L1L2 to create an L1L2 instance with the current l1_l2 defaults. Perhaps something like: [code snippet did not load]

Reviewer: That general approach seems fine to me. I do like the L1L2 class name better. The method name l1_l2 is off-Java-naming-standard and is descriptive only by reference to Python Keras. Perhaps instead overload the name create and use default values for the variant with no loss factors provided?

Author: create or createDefault? Alternatively, I could change the constructors that do not specify the l1, l2 values.

Reviewer: I would lean toward changing the no-specified-penalty constructors to use the defaults instead of 0.f. But if factory methods felt more comfortable to you, and if you felt that fit some overall direction of the code base, I could be comfortable with that.

Author: Yeah, I started down the createDefault path, but realized the idiomatic pattern is just overloaded create methods.

Author: I am changing L1L2(Ops tf) to use the DEFAULT_REGULARIZATION_PENALTY and eliminating class L1_L2.

👍 done
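If that resolution landed as described, the follow-up change would look something like this sketch (not the merged diff): the no-penalty L1L2 constructor adopts the shared default, and the L1_L2 class is removed.

/**
 * Creates an L1L2 regularizer, with l1 and l2 both set to {@link #DEFAULT_REGULARIZATION_PENALTY}.
 */
public L1L2(Ops tf, Class<R> type) {
  this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY, type);
}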