Skip to content

Commit b5c7c78

Browse files
committed
Updates based on comments from PR.
Removed generic from Regularizer class and changed the call method to define the generic return based on the weights parameter. Added static method l1_l2() to L1L2 class. Fixed JavaDoc comments.
1 parent b446618 commit b5c7c78

File tree

12 files changed

+145
-138
lines changed

12 files changed

+145
-138
lines changed

tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,23 @@
1515
package org.tensorflow.framework.regularizers;
1616

1717
import org.tensorflow.op.Ops;
18-
import org.tensorflow.types.family.TNumber;
1918

2019
/**
21-
* A regularizer that applies an L1 or Lasso(least absolute shrinkage and selection operator) Regression,
22-
* regularization penalty.
20+
* A regularizer that applies an L1 or Lasso(least absolute shrinkage and selection operator)
21+
* Regression, regularization penalty.
2322
*
2423
* <p>The L1 regularization penalty is computed as: <code>loss = l1 * reduceSum(abs(x))</code>
25-
*
26-
* @param <R> the data type for the weights
2724
*/
28-
public class L1<R extends TNumber> extends L1L2<R> {
25+
public class L1 extends L1L2 {
2926

3027
/**
3128
* Create a regularizer that applies an L1 regularization penalty of {@link
3229
* #DEFAULT_REGULARIZATION_PENALTY}
3330
*
3431
* @param tf the TensorFlow Ops
3532
*/
36-
public L1(Ops tf, Class<R> type) {
37-
this(tf, DEFAULT_REGULARIZATION_PENALTY, type);
33+
public L1(Ops tf) {
34+
this(tf, DEFAULT_REGULARIZATION_PENALTY);
3835
}
3936

4037
/**
@@ -44,7 +41,7 @@ public L1(Ops tf, Class<R> type) {
4441
* @param l1 the L1 regularization penalty
4542
* @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite.
4643
*/
47-
public L1(Ops tf, float l1, Class<R> type) {
48-
super(tf, l1, null, type);
44+
public L1(Ops tf, float l1) {
45+
super(tf, l1, null);
4946
}
5047
}

tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,19 @@
3232
*
3333
* <p>The difference between this class and the {@link L1_L2} is use of the default regularization
3434
* penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0.
35-
*
36-
* @param <R> the data type for the weights
3735
*/
38-
public class L1L2<R extends TNumber> extends Regularizer<R> {
36+
public class L1L2 extends Regularizer {
3937

40-
private final Float l1;
41-
private final Float l2;
38+
private final float l1;
39+
private final float l2;
4240

4341
/**
44-
* Creates an L1L2 regularizer with no l1 or l2 penalty with default penal
42+
* Creates an L1L2 regularizer with no l1 or l2 penalty with zero penalty
4543
*
4644
* @param tf the TensorFlow Ops
47-
* @param type the data type for the weights
4845
*/
49-
public L1L2(Ops tf, Class<R> type) {
50-
this(tf, null, null, type);
46+
public L1L2(Ops tf) {
47+
this(tf, null, null);
5148
}
5249

5350
/**
@@ -56,12 +53,11 @@ public L1L2(Ops tf, Class<R> type) {
5653
* @param tf the TensorFlow Ops
5754
* @param l1 L1 regularization factor, if null it is set to 0.
5855
* @param l2 L2 regularization factor, if null it is set to 0.
59-
* @param type the data type for the weights
6056
* @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN}
6157
* of {@link Float#isInfinite}
6258
*/
63-
public L1L2(Ops tf, Float l1, Float l2, Class<R> type) {
64-
super(tf, type);
59+
public L1L2(Ops tf, Float l1, Float l2) {
60+
super(tf);
6561
if (l1 != null) {
6662
if (l1.isNaN() || l1.isInfinite()) {
6763
throw new IllegalArgumentException(
@@ -86,25 +82,29 @@ public L1L2(Ops tf, Float l1, Float l2, Class<R> type) {
8682
}
8783
}
8884

85+
public static L1L2 l1_l2(Ops tf) {
86+
return new L1L2(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY);
87+
}
88+
8989
/** {@inheritDoc} */
9090
@Override
91-
public Operand<R> call(Operand<R> input) {
91+
public <R extends TNumber> Operand<R> call(Operand<R> input) {
9292
Ops tf = getTF();
93-
if (this.getL1() == null && this.getL2() == null) {
93+
if (this.getL1() == 0f && this.getL2() == 0f) {
9494
return tf.dtypes.cast(tf.constant(0), input.type());
9595
}
9696
Operand<R> regularization = tf.dtypes.cast(tf.constant(0), input.type());
9797

98-
if (this.getL1() != null && this.getL1() != 0.f) {
98+
if (this.getL1() != 0.f) {
9999
Operand<R> l1Op = tf.dtypes.cast(tf.constant(this.getL1()), input.type());
100100
Operand<R> abs = tf.math.abs(input);
101101
Operand<R> reduceSum = tf.reduceSum(abs, LossesHelper.allAxes(tf, input));
102102
regularization = tf.math.add(regularization, tf.math.mul(l1Op, reduceSum));
103103
}
104104

105-
if (this.getL2() != null && this.getL2() != 0.f) {
105+
if (this.getL2() != 0.f) {
106106
Operand<R> l2Op = tf.dtypes.cast(tf.constant(this.getL2()), input.type());
107-
Operand<R> sqr = tf.math.abs(input);
107+
Operand<R> sqr = tf.math.square(input);
108108
Operand<R> reduceSum = tf.reduceSum(sqr, LossesHelper.allAxes(tf, input));
109109
regularization = tf.math.add(regularization, tf.math.mul(l2Op, reduceSum));
110110
}
@@ -117,7 +117,7 @@ public Operand<R> call(Operand<R> input) {
117117
*
118118
* @return the L1 regularization factor
119119
*/
120-
public Float getL1() {
120+
public float getL1() {
121121
return l1;
122122
}
123123

@@ -126,7 +126,7 @@ public Float getL1() {
126126
*
127127
* @return the L2 regularization factor
128128
*/
129-
public Float getL2() {
129+
public float getL2() {
130130
return l2;
131131
}
132132
}

tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
package org.tensorflow.framework.regularizers;
1616

1717
import org.tensorflow.op.Ops;
18-
import org.tensorflow.types.family.TNumber;
1918

2019
/**
2120
* A regularizer that applies both L1 and L2 regularization penalties.
@@ -30,33 +29,33 @@
3029
*
3130
* <p>The difference between this class and the {@link L1L2} is use of the default regularization
3231
* penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0.
33-
*
34-
* @param <R> the data type for the weights
3532
*/
36-
public class L1_L2<R extends TNumber> extends L1L2<R> {
33+
public class L1_L2 extends L1L2 {
3734

3835
/**
3936
* Creates a regularizer that applies an L1 and l2 regularization penalty of {@link
4037
* #DEFAULT_REGULARIZATION_PENALTY}
4138
*
4239
* @param tf the TensorFlow Ops
4340
*/
44-
public L1_L2(Ops tf, Class<R> type) {
45-
this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY, type);
41+
public L1_L2(Ops tf) {
42+
this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY);
4643
}
4744

4845
/**
4946
* Creates a regularizer that applies an L1 and l2 regularization penalty
5047
*
5148
* @param tf the TensorFlow Ops
52-
* @param l1 the L1 regularization penalty
53-
* @param l2 the L2 regularization penalty
49+
* @param l1 the L1 regularization penalty. If null, then l1 will be set to {@link
50+
* #DEFAULT_REGULARIZATION_PENALTY}.
51+
* @param l2 the L2 regularization penalty. If null, then l2 will be set to {@link
52+
* #DEFAULT_REGULARIZATION_PENALTY}.
5453
* @throws IllegalArgumentException if the l1 or l2 regularization factor is NaN or is infinite.
5554
*/
56-
public L1_L2(Ops tf, Float l1, Float l2, Class<R> type) {
57-
super(tf,
58-
l1 == null ? DEFAULT_REGULARIZATION_PENALTY : l1,
59-
l2 == null ? DEFAULT_REGULARIZATION_PENALTY : l2,
60-
type);
55+
public L1_L2(Ops tf, Float l1, Float l2) {
56+
super(
57+
tf,
58+
l1 == null ? DEFAULT_REGULARIZATION_PENALTY : l1,
59+
l2 == null ? DEFAULT_REGULARIZATION_PENALTY : l2);
6160
}
6261
}

tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,22 @@
1515
package org.tensorflow.framework.regularizers;
1616

1717
import org.tensorflow.op.Ops;
18-
import org.tensorflow.types.family.TNumber;
1918

2019
/**
2120
* A regularizer that applies a L2 (Ridge Regression) regularization penalty.
2221
*
2322
* <p>The L2 regularization penalty is computed as: <code>loss = l2 * reduceSum(square(x))</code>
24-
*
25-
* @param <R> the data type for the operands and result
2623
*/
27-
public class L2<R extends TNumber> extends L1L2<R> {
24+
public class L2 extends L1L2 {
2825

2926
/**
3027
* Create a regularizer that applies an L2 regularization penalty of {@link
3128
* #DEFAULT_REGULARIZATION_PENALTY}
3229
*
3330
* @param tf the TensorFlow Ops
3431
*/
35-
public L2(Ops tf, Class<R> type) {
36-
this(tf, DEFAULT_REGULARIZATION_PENALTY, type);
32+
public L2(Ops tf) {
33+
this(tf, DEFAULT_REGULARIZATION_PENALTY);
3734
}
3835

3936
/**
@@ -43,7 +40,7 @@ public L2(Ops tf, Class<R> type) {
4340
* @param l2 the L2 regularization penalty
4441
* @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite.
4542
*/
46-
public L2(Ops tf, float l2, Class<R> type) {
47-
super(tf, null, l2, type);
43+
public L2(Ops tf, float l2) {
44+
super(tf, null, l2);
4845
}
4946
}

tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,33 +24,31 @@
2424
*
2525
* <p>Regularizers allow you to apply penalties on layer parameters or layer activity during
2626
* optimization. These penalties are summed into the loss function that the network optimizes.
27-
*
28-
* @param <R> the data type of the operands and result
2927
*/
30-
public abstract class Regularizer<R extends TNumber> {
28+
public abstract class Regularizer {
3129

3230
public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f;
3331

3432
private final Ops tf;
3533
private final String name;
36-
protected Class<R> type;
3734

3835
/**
39-
* Creates a Regularizer
36+
* Creates a Regularizer, using {@link Class#getSimpleName()} for the name
4037
*
4138
* @param tf the TensorFlow ops.
4239
*/
43-
protected Regularizer(Ops tf, Class<R> type) {
44-
this(tf, null, type);
40+
protected Regularizer(Ops tf) {
41+
this(tf, null);
4542
}
4643
/**
4744
* Creates a Regularizer
4845
*
4946
* @param tf the TensorFlow ops.
47+
* @param name the name of this regularizer, if null use {@link Class#getSimpleName()} for the
48+
* name.
5049
*/
51-
protected Regularizer(Ops tf, String name, Class<R> type) {
50+
protected Regularizer(Ops tf, String name) {
5251
this.tf = tf;
53-
this.type = type;
5452
this.name = name == null ? this.getClass().getSimpleName() : name;
5553
}
5654

@@ -61,7 +59,7 @@ protected Regularizer(Ops tf, String name, Class<R> type) {
6159
* @return this Regularizer as a Loss
6260
*/
6361
public Loss asLoss() {
64-
return new RegularizerLoss<>(this.tf, this);
62+
return new RegularizerLoss(this.tf, this);
6563
}
6664

6765
/**
@@ -70,7 +68,7 @@ public Loss asLoss() {
7068
* @param input the weighted input
7169
* @return the result of computing the regularization penalty
7270
*/
73-
public abstract Operand<R> call(Operand<R> input);
71+
public abstract <R extends TNumber> Operand<R> call(Operand<R> input);
7472

7573
/**
7674
* Gets the TensorFlow Ops

tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,24 @@
1919
import org.tensorflow.op.Ops;
2020
import org.tensorflow.types.family.TNumber;
2121

22-
import static org.tensorflow.framework.utils.CastHelper.cast;
23-
2422
/**
2523
* A Regularizer call wrapped as a Loss instance
2624
*
2725
* <p>This class facilitates using a regularizer as a loss, only <code>sampleWeights</code> are
2826
* regularized.
29-
*
30-
* @param <R> the datatype for the weights type
3127
*/
32-
class RegularizerLoss<R extends TNumber> extends Loss {
28+
class RegularizerLoss extends Loss {
29+
30+
private final Regularizer regularizer;
3331

34-
private final Regularizer<R> regularizer;
35-
private final Class<R> type;
3632
/**
3733
* Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link
3834
* Loss#REDUCTION_DEFAULT}
3935
*
4036
* @param tf the TensorFlow Ops
37+
* @param regularizer the regularizer used to calculate the loss
4138
*/
42-
public RegularizerLoss(Ops tf, Regularizer<R> regularizer) {
39+
public RegularizerLoss(Ops tf, Regularizer regularizer) {
4340
this(tf, null, regularizer);
4441
}
4542

@@ -48,11 +45,11 @@ public RegularizerLoss(Ops tf, Regularizer<R> regularizer) {
4845
*
4946
* @param tf the TensorFlow Ops
5047
* @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}.
48+
* @param regularizer the regularizer used to calculate the loss
5149
*/
52-
public RegularizerLoss(Ops tf, String name, Regularizer<R> regularizer) {
50+
public RegularizerLoss(Ops tf, String name, Regularizer regularizer) {
5351
super(tf, name);
5452
this.regularizer = regularizer;
55-
this.type = regularizer.type;
5653
}
5754

5855
/** {@inheritDoc} */
@@ -62,7 +59,6 @@ public <T extends TNumber> Operand<T> call(
6259
if (sampleWeights == null) {
6360
throw new IllegalArgumentException("sampleWeights cannot be null");
6461
}
65-
Operand<R> result = regularizer.call(cast(getTF(), sampleWeights, type));
66-
return cast(tf, result, sampleWeights.type());
62+
return regularizer.call(sampleWeights);
6763
}
6864
}

0 commit comments

Comments (0)