Skip to content

Commit 1af4552

Browse files
committed
delete class L1_L2
modified Float to float for l1 and l2 parameters Change ctor L1L2(Ops tf) to use DEFAULT_REGULARIZATION_PENALTY for l1/l2 parameters Fix JavaDoc
1 parent 8c79214 commit 1af4552

File tree

8 files changed

+33
-238
lines changed

8 files changed

+33
-238
lines changed

tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,6 @@ public L1(Ops tf) {
4242
* @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite.
4343
*/
4444
public L1(Ops tf, float l1) {
45-
super(tf, l1, null);
45+
super(tf, l1, 0f);
4646
}
4747
}

tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030
*
3131
* <pre>loss = l2 * reduceSum(square(x))</pre>
3232
*
33-
* <p>The difference between this class and the {@link L1_L2} is use of the default regularization
34-
* penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0.
3533
*/
3634
public class L1L2 extends Regularizer {
3735

@@ -44,7 +42,7 @@ public class L1L2 extends Regularizer {
4442
* @param tf the TensorFlow Ops
4543
*/
4644
public L1L2(Ops tf) {
47-
this(tf, null, null);
45+
this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY);
4846
}
4947

5048
/**
@@ -56,42 +54,25 @@ public L1L2(Ops tf) {
5654
* @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN}
5755
* of {@link Float#isInfinite}
5856
*/
59-
public L1L2(Ops tf, Float l1, Float l2) {
57+
public L1L2(Ops tf, float l1, float l2) {
6058
super(tf);
61-
if (l1 != null) {
62-
if (l1.isNaN() || l1.isInfinite()) {
63-
throw new IllegalArgumentException(
64-
String.format(
65-
"L1 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value",
66-
l1));
67-
}
68-
this.l1 = l1;
69-
} else {
70-
this.l1 = 0f;
59+
if (Float.isNaN(l1) || Float.isInfinite(l1)) {
60+
throw new IllegalArgumentException(
61+
String.format(
62+
"L1 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value",
63+
l1));
7164
}
72-
if (l2 != null) {
73-
if (l2.isNaN() || l2.isInfinite()) {
74-
throw new IllegalArgumentException(
75-
String.format(
76-
"L2 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value",
77-
l2));
78-
}
79-
this.l2 = l2;
80-
} else {
81-
this.l2 = 0f;
65+
this.l1 = l1;
66+
67+
if (Float.isNaN(l2) || Float.isInfinite(l2)) {
68+
throw new IllegalArgumentException(
69+
String.format(
70+
"L2 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value",
71+
l2));
8272
}
73+
this.l2 = l2;
8374
}
8475

85-
/**
86-
* Creates an L1L2 instance using {@link #DEFAULT_REGULARIZATION_PENALTY} for the l1 and l2
87-
* values.
88-
*
89-
* @param tf the TensorFlow Ops
90-
* @return a L1L2 instance using {@link #DEFAULT_REGULARIZATION_PENALTY} for the l1 and l2 values.
91-
*/
92-
public static L1L2 create(Ops tf) {
93-
return new L1L2(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY);
94-
}
9576

9677
/** {@inheritDoc} */
9778
@Override

tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java

Lines changed: 0 additions & 61 deletions
This file was deleted.

tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,6 @@ public L2(Ops tf) {
4141
* @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite.
4242
*/
4343
public L2(Ops tf, float l2) {
44-
super(tf, null, l2);
44+
super(tf, 0f, l2);
4545
}
4646
}

tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ public Loss asLoss() {
6767
*
6868
* @param input the weighted input
6969
* @return the result of computing the regularization penalty
70+
* @param <R> the data type of the input and result
7071
*/
7172
public abstract <R extends TNumber> Operand<R> call(Operand<R> input);
7273

tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,21 @@ public void testCreate() {
2121
assertEquals(0.2f, instance.getL1());
2222
assertEquals(0.3f, instance.getL2());
2323

24-
instance = new L1L2(tf, null, null);
24+
instance = new L1L2(tf, 0, 0);
2525
assertEquals(0.f, instance.getL1());
2626
assertEquals(0.f, instance.getL2());
2727

28-
instance = new L1L2(tf, 0.5f, null);
28+
instance = new L1L2(tf, 0.5f, 0);
2929
assertEquals(0.5f, instance.getL1());
3030
assertEquals(0.f, instance.getL2());
3131

32-
instance = new L1L2(tf, null, 0.5f);
32+
instance = new L1L2(tf, 0, 0.5f);
3333
assertEquals(0.f, instance.getL1());
3434
assertEquals(0.5f, instance.getL2());
35+
36+
instance = new L1L2(tf);
37+
assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1());
38+
assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2());
3539
}
3640
}
3741

@@ -42,16 +46,16 @@ public void testCallDefaultsConstant() {
4246
Ops tf = session.getTF();
4347
L1L2 instance = new L1L2(tf);
4448
Operand<TFloat32> result = instance.call(tf.constant(555f));
45-
session.evaluate(0f, result);
49+
session.evaluate(3085.8f, result);
4650
}
4751
}
4852

4953
@Test
50-
public void testCallL1L20() {
54+
public void testCallL1L2_0() {
5155
for (TestSession.Mode tfMode : tfModes)
5256
try (TestSession session = TestSession.createTestSession(tfMode)) {
5357
Ops tf = session.getTF();
54-
L1L2 instance = new L1L2(tf);
58+
L1L2 instance = new L1L2(tf, 0, 0);
5559
Operand<TFloat32> weights =
5660
tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}});
5761
Operand<TFloat32> result = instance.call(weights);
@@ -90,11 +94,11 @@ public void testCallL1L2TFloat64() {
9094
}
9195

9296
@Test
93-
public void testCallL2Null() {
97+
public void testCallL2_0() {
9498
for (TestSession.Mode tfMode : tfModes)
9599
try (TestSession session = TestSession.createTestSession(tfMode)) {
96100
Ops tf = session.getTF();
97-
L1L2 instance = new L1L2(tf, 0.01f, null);
101+
L1L2 instance = new L1L2(tf, 0.01f, 0);
98102
float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
99103
Operand<TFloat32> weights = tf.constant(w);
100104
Operand<TFloat32> result = instance.call(weights);
@@ -104,11 +108,11 @@ public void testCallL2Null() {
104108
}
105109

106110
@Test
107-
public void testCallL1Null() {
111+
public void testCallL1_0() {
108112
for (TestSession.Mode tfMode : tfModes)
109113
try (TestSession session = TestSession.createTestSession(tfMode)) {
110114
Ops tf = session.getTF();
111-
L1L2 instance = new L1L2(tf, null, 0.02f);
115+
L1L2 instance = new L1L2(tf, 0, 0.02f);
112116
double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}};
113117
Operand<TFloat64> weights = tf.constant(w);
114118
Operand<TFloat64> result = instance.call(weights);

tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java

Lines changed: 0 additions & 130 deletions
This file was deleted.

tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public void testCreate() {
1414
for (TestSession.Mode tfMode : tfModes)
1515
try (TestSession session = TestSession.createTestSession(tfMode)) {
1616
Ops tf = session.getTF();
17-
L1L2 regularizer = new L1L2(tf, 0.01f, null);
17+
L1L2 regularizer = new L1L2(tf, 0.01f, 0f);
1818
float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
1919
Operand<TFloat32> weights = tf.constant(w);
2020
Operand<TFloat32> regularizerResult = regularizer.call(weights);

0 commit comments

Comments
 (0)