diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 7d6d159f5ef..acbae4dac6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -364,10 +364,10 @@ public final class Ops { public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -390,8 +390,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java index 80f62eb5acc..b1c0b16c972 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java @@ -15,6 +15,7 @@ package org.tensorflow; +import org.tensorflow.TensorPrinter.Options; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.Shaped; import org.tensorflow.op.Op; @@ -56,7 +57,7 @@ public interface Operand extends Op, Shaped { /** * Returns the tensor at this operand. * - * Only works when running in an eager execution + *

Only works when running in an eager execution * * @return the tensor * @throws IllegalStateException if this is an operand of a graph @@ -66,14 +67,34 @@ default T asTensor() { } /** - * Returns the tensor type of this operand + * Returns the String representation of the tensor elements at this operand. + * + * @return the String representation of the tensor elements + * @throws IllegalStateException if this is an operand of a graph + */ + default String print() { + return asTensor().print(); + } + + /** + * Returns the String representation of the tensor elements at this operand. + * + * @param options overrides the default configuration + * @return the String representation of the tensor elements + * @throws IllegalStateException if this is an operand of a graph */ + default String print(Options options) { + return asTensor().print(options); + } + + /** Returns the tensor type of this operand */ default Class type() { return asOutput().type(); } /** - * Returns the (possibly partially known) shape of the tensor referred to by the {@link Output} of this operand. + * Returns the (possibly partially known) shape of the tensor referred to by the {@link Output} of + * this operand. */ @Override default Shape shape() { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index fc1275229bf..640fabb2e96 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -15,20 +15,21 @@ package org.tensorflow; -import java.util.function.Consumer; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.Shaped; import org.tensorflow.ndarray.buffer.ByteDataBuffer; import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.family.TType; +import java.util.function.Consumer; + /** * A statically typed multi-dimensional array. * - *

There are two categories of tensors in TensorFlow Java: {@link TType typed tensors} and - * {@link RawTensor raw tensors}. The former maps the tensor native memory to an - * n-dimensional typed data space, allowing direct I/O operations from the JVM, while the latter - * is only a reference to a native tensor allowing basic operations and flat data access.

+ *

There are two categories of tensors in TensorFlow Java: {@link TType typed tensors} and {@link + * RawTensor raw tensors}. The former maps the tensor native memory to an n-dimensional typed data + * space, allowing direct I/O operations from the JVM, while the latter is only a reference to a + * native tensor allowing basic operations and flat data access. * *

WARNING: Resources consumed by the Tensor object must be explicitly freed by * invoking the {@link #close()} method when the object is no longer needed. For example, using a @@ -39,6 +40,7 @@ * doSomethingWith(t); * } * } + * *

Instances of a Tensor are not thread-safe. */ public interface Tensor extends Shaped, AutoCloseable { @@ -54,9 +56,9 @@ public interface Tensor extends Shaped, AutoCloseable { * @param shape shape of the tensor * @return an allocated but uninitialized tensor * @throws IllegalArgumentException if elements of the given {@code type} are of variable length - * (e.g. strings) - * @throws IllegalArgumentException if {@code shape} is totally or partially - * {@link Shape#hasUnknownDimension() unknown} + * (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially {@link + * Shape#hasUnknownDimension() unknown} * @throws IllegalStateException if tensor failed to be allocated */ static T of(Class type, Shape shape) { @@ -67,27 +69,27 @@ static T of(Class type, Shape shape) { * Allocates a tensor of a given datatype, shape and size. * *

This method is identical to {@link #of(Class, Shape)}, except that the final size of the - * tensor can be explicitly set instead of computing it from the datatype and shape, which could be - * larger than the actual space required to store the data but not smaller. + * tensor can be explicitly set instead of computing it from the datatype and shape, which could + * be larger than the actual space required to store the data but not smaller. * * @param the tensor type * @param type the tensor type class * @param shape shape of the tensor * @param size size in bytes of the tensor or -1 to compute the size from the shape * @return an allocated but uninitialized tensor - * @see #of(Class, Shape) * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to - * store the tensor data - * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given - * {@code type} are of variable length (e.g. strings) - * @throws IllegalArgumentException if {@code shape} is totally or partially - * {@link Shape#hasUnknownDimension() unknown} + * store the tensor data + * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given {@code + * type} are of variable length (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially {@link + * Shape#hasUnknownDimension() unknown} * @throws IllegalStateException if tensor failed to be allocated + * @see #of(Class, Shape) */ static T of(Class type, Shape shape, long size) { RawTensor tensor = RawTensor.allocate(type, shape, size); try { - return (T)tensor.asTypedTensor(); + return (T) tensor.asTypedTensor(); } catch (Exception e) { tensor.close(); throw e; @@ -114,12 +116,13 @@ static T of(Class type, Shape shape, long size) { * @param the tensor type * @param type the tensor type class * @param shape shape of the tensor - * @param dataInitializer method receiving accessor to the allocated tensor data for initialization + * @param dataInitializer method receiving accessor to the allocated tensor data for + * initialization * @return an allocated and initialized tensor * @throws IllegalArgumentException if elements of the given {@code type} are of variable length - * (e.g. strings) - * @throws IllegalArgumentException if {@code shape} is totally or partially - * {@link Shape#hasUnknownDimension() unknown} + * (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially {@link + * Shape#hasUnknownDimension() unknown} * @throws IllegalStateException if tensor failed to be allocated */ static T of(Class type, Shape shape, Consumer dataInitializer) { @@ -129,28 +132,30 @@ static T of(Class type, Shape shape, Consumer dataInitia /** * Allocates a tensor of a given datatype, shape and size. * - *

This method is identical to {@link #of(Class, Shape, Consumer)}, except that the final - * size for the tensor can be explicitly set instead of being computed from the datatype and shape. + *

This method is identical to {@link #of(Class, Shape, Consumer)}, except that the final size + * for the tensor can be explicitly set instead of being computed from the datatype and shape. * - *

This could be useful for tensor types that stores data but also metadata in the tensor memory, - * such as the lookup table in a tensor of strings. + *

This could be useful for tensor types that stores data but also metadata in the tensor + * memory, such as the lookup table in a tensor of strings. * * @param the tensor type * @param type the tensor type class * @param shape shape of the tensor * @param size size in bytes of the tensor or -1 to compute the size from the shape - * @param dataInitializer method receiving accessor to the allocated tensor data for initialization + * @param dataInitializer method receiving accessor to the allocated tensor data for + * initialization * @return an allocated and initialized tensor - * @see #of(Class, Shape, long, Consumer) * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to - * store the tensor data - * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given - * {@code type} are of variable length (e.g. strings) - * @throws IllegalArgumentException if {@code shape} is totally or partially - * {@link Shape#hasUnknownDimension() unknown} + * store the tensor data + * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given {@code + * type} are of variable length (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially {@link + * Shape#hasUnknownDimension() unknown} * @throws IllegalStateException if tensor failed to be allocated + * @see #of(Class, Shape, long, Consumer) */ - static T of(Class type, Shape shape, long size, Consumer dataInitializer) { + static T of( + Class type, Shape shape, long size, Consumer dataInitializer) { T tensor = of(type, shape, size); try { dataInitializer.accept(tensor); @@ -172,34 +177,65 @@ static T of(Class type, Shape shape, long size, Consumer * @param shape the tensor shape. * @param rawData a buffer containing the tensor raw data. * @throws IllegalArgumentException if {@code rawData} is not large enough to contain the tensor - * data - * @throws IllegalArgumentException if {@code shape} is totally or partially - * {@link Shape#hasUnknownDimension() unknown} + * data + * @throws IllegalArgumentException if {@code shape} is totally or partially {@link + * Shape#hasUnknownDimension() unknown} * @throws IllegalStateException if tensor failed to be allocated with the given parameters */ static T of(Class type, Shape shape, ByteDataBuffer rawData) { - return of(type, shape, rawData.size(), t -> rawData.copyTo(t.asRawTensor().data(), rawData.size())); + return of( + type, shape, rawData.size(), t -> rawData.copyTo(t.asRawTensor().data(), rawData.size())); } + /** Returns the {@link DataType} of elements stored in the tensor. */ + DataType dataType(); + + /** Returns the size, in bytes, of the tensor data. */ + long numBytes(); + /** - * Returns the {@link DataType} of elements stored in the tensor. + * Gets the String representation of the tensor elements + * + * @return the String representation of the tensor elements */ - DataType dataType(); + default String print() { + return new TensorPrinter(this).print(); + } /** - * Returns the size, in bytes, of the tensor data. + * Gets the String representation of the tensor elements + * + * @param printOptions the options for the {@link TensorPrinter} object. + * @return the String representation of the tensor elements */ - long numBytes(); + default String print(TensorPrinter.Options printOptions) { + return new TensorPrinter(this, printOptions).print(); + } /** - * Returns the shape of the tensor. + * Get a {@link TensorPrinter} for this tensor. + * + * @return the {@link TensorPrinter} for this tensor. */ - @Override - Shape shape(); + default TensorPrinter printer() { + return new TensorPrinter(this); + } /** - * Returns a raw (untyped) representation of this tensor + * Get a {@link TensorPrinter} for this tensor. + * + * @param printOptions the options for the {@link TensorPrinter} object. + * @return the {@link TensorPrinter} for this tensor. */ + default TensorPrinter printer(TensorPrinter.Options printOptions) { + return new TensorPrinter(this, printOptions); + } + + /** Returns the shape of the tensor. */ + @Override + Shape shape(); + + /** Returns a raw (untyped) representation of this tensor */ RawTensor asRawTensor(); /** @@ -212,4 +248,22 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData */ @Override void close(); + + class ToStringOptions { + + private Integer maxWidth; + + private ToStringOptions() {} + + /** + * Sets the maximum width of the output in characters. + * + * @param maxWidth the maximum width of the output in characters ({@code null} if unlimited). + * This limit may surpassed if the first or last element are too long. + */ + public ToStringOptions maxWidth(Integer maxWidth) { + this.maxWidth = maxWidth; + return this; + } + } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorPrinter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorPrinter.java new file mode 100644 index 00000000000..86e7e09f0e5 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorPrinter.java @@ -0,0 +1,421 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.StringJoiner; + +/** Utility class to print the contents of a Tensor */ +public class TensorPrinter { + + private final Tensor tensor; + private Options options; + + private String endOfLineSep; + private String separatorString; + private String openSet, closeSet; + private String indentString = " "; + + /** + * Creates a TensorPrinter + * + * @param tensor the tensor + */ + public TensorPrinter(Tensor tensor) { + this(tensor, Options.create()); + } + + /** + * Creates a TensorPrinter + * + * @param options the {@link Options} for formatting the print + * @param tensor the tensor + */ + public TensorPrinter(Tensor tensor, Options options) { + this.tensor = tensor; + setOptions(options); + } + + /** + * Gets a printable String for the Tensor data + * + * @return a printable String for the Tensor data + */ + public String print() { + Tensor tmpTensor = tensor; + if (tmpTensor instanceof RawTensor) { + tmpTensor = ((RawTensor) tensor).asTypedTensor(); + } + if (!(tmpTensor instanceof NdArray)) { + // TODO can this ever happen? + return dumpRawTensor(tmpTensor.asRawTensor()); + } + + NdArray ndArray = (NdArray) tmpTensor; + Iterator> iterator = ndArray.scalars().iterator(); + Shape shape = tmpTensor.shape(); + if (shape.numDimensions() == 0) { + if (!iterator.hasNext()) { + return ""; + } + return String.valueOf(iterator.next().getObject()); + } + return formatString(iterator, tmpTensor.dataType(), shape, 0); + } + + /** + * Prints the raw tensor in cases where a Tensor does not inherit from NDArray. + * + * @param tensor the tensor. + * @return the printable raw tensor. + */ + private String dumpRawTensor(RawTensor tensor) { + StringBuilder sb = new StringBuilder(); + sb.append("actual : ").append(tensor).append('\n'); + sb.append("dataType : ").append(tensor.dataType()).append('\n'); + sb.append("class : ").append(tensor.getClass()).append('\n'); + sb.append('\n'); + return sb.toString(); + } + + /** + * @param iterator an iterator over the scalars + * @param dataType the data type for the tensor + * @param shape the shape of the tensor + * @param dimension the current dimension being processed + * @return the String representation of the tensor data at {@code dimension} + */ + private String formatString( + Iterator> iterator, DataType dataType, Shape shape, int dimension) { + + if (dimension < shape.numDimensions() - 1) { + StringJoiner joiner = + new StringJoiner( + endOfLineSep + "\n", + indent(dimension) + openSet + "\n", + "\n" + indent(dimension) + closeSet); + for (long i = 0, size = shape.size(dimension); i < size; ++i) { + String element = formatString(iterator, dataType, shape, dimension + 1); + joiner.add(element); + } + return joiner.toString(); + } + if (options.maxWidth == null) { + StringJoiner joiner = + new StringJoiner(separatorString, indent(dimension) + openSet, closeSet); + for (long i = 0, size = shape.size(dimension); i < size; ++i) { + Object element = iterator.next().getObject(); + joiner.add(elementToString(dataType, element)); + } + + return joiner.toString(); + } + List lengths = new ArrayList<>(); + StringJoiner joiner = new StringJoiner(separatorString, indent(dimension) + openSet, closeSet); + int lengthBefore = closeSet.length(); + for (long i = 0, size = shape.size(dimension); i < size; ++i) { + Object element = iterator.next().getObject(); + joiner.add(elementToString(dataType, element)); + int addedLength = joiner.length() - lengthBefore; + lengths.add(addedLength); + lengthBefore += addedLength; + } + return truncateWidth(joiner.toString(), options.maxWidth, lengths); + } + + /** + * Convert an element of a tensor to string, in a way that may depend on the data type. + * + * @param dataType the tensor's data type + * @param data the element + * @return the element's string representation + */ + private String elementToString(DataType dataType, Object data) { + if (dataType == DataType.DT_STRING) { + return '"' + data.toString() + '"'; + } else if (options.numDecimals != null + && (dataType == DataType.DT_DOUBLE || dataType == DataType.DT_FLOAT)) { + String format = "%." + options.numDecimals + "f"; + return String.format(format, data); + } else { + return data.toString(); + } + } + + /** + * Truncates the width of a String if it's too long, inserting "{@code ...}" in place of the + * removed data. + * + * @param input the input to truncate + * @param maxWidth the maximum width of the output in characters + * @param lengths the lengths of elements inside input + * @return the (potentially) truncated output + */ + private String truncateWidth(String input, int maxWidth, List lengths) { + if (input.length() <= maxWidth) { + return input; + } + StringBuilder output = new StringBuilder(input); + int midPoint = (maxWidth / 2) - 1; + int width = 0; + int indexOfElementToRemove = lengths.size() - 1; + int widthBeforeElementToRemove = 0; + for (int i = 0, size = lengths.size(); i < size; ++i) { + width += lengths.get(i); + if (width > midPoint) { + indexOfElementToRemove = i; + break; + } + widthBeforeElementToRemove = width; + } + if (indexOfElementToRemove == 0) { + // Cannot remove first element + return input; + } + output.insert(widthBeforeElementToRemove, separatorString + "..."); + widthBeforeElementToRemove += (separatorString + "...").length(); + width = output.length(); + while (width > maxWidth) { + if (indexOfElementToRemove == 0) { + // Cannot remove first element + break; + } else if (indexOfElementToRemove == lengths.size() - 1) { + // Cannot remove last element + --indexOfElementToRemove; + continue; + } + Integer length = lengths.remove(indexOfElementToRemove); + output.delete(widthBeforeElementToRemove, widthBeforeElementToRemove + length); + width = output.length(); + } + if (output.length() < input.length()) { + return output.toString(); + } + // Do not insert ellipses if it increases the length + return input; + } + + /** + * Gets the indent string based on the indent level + * + * @param level the level of indent + * @return the indentation string + */ + private String indent(int level) { + if (level <= 0) { + return ""; + } + StringBuilder result = new StringBuilder(level * 2); + for (int i = 0; i < level; ++i) { + result.append(indentString); + } + return result.toString(); + } + + /** + * Gets the tensor + * + * @return the tensor + */ + public Tensor getTensor() { + return tensor; + } + + /** + * Gets the Options + * + * @return the Options + */ + public Options getOptions() { + return options; + } + + /** + * Sets the options for formatting the print string. + * + * @param options the options + */ + public final void setOptions(Options options) { + this.options = options == null ? Options.create() : options; + + switch (this.options.enclosure) { + case BRACES: + openSet = "{"; + closeSet = "}"; + break; + + case PARENS: + openSet = "("; + closeSet = ")"; + break; + case BRACKETS: + default: + openSet = "["; + closeSet = "]"; + break; + } + endOfLineSep = this.options.trailingSeparator ? String.valueOf(this.options.separator) : ""; + separatorString = String.format("%c ", this.options.separator); + if (this.options.indentSize != null) { + String format = "%" + this.options.indentSize + "s"; + indentString = String.format(format, ""); + } + } + + /** Contains the options for TensorPrint */ + public static class Options { + public static char DEFAULT_SEPARATOR = ','; + public static Enclosure DEFAULT_ENCLOSURE = Enclosure.BRACKETS; + + /** The max width of a single line */ + public Integer maxWidth; + /** The element separator character, default is {@link #DEFAULT_SEPARATOR} */ + public char separator = DEFAULT_SEPARATOR; + /** + * the number of digits after the decimal point for floating point numbers, null means to use + * the default format + */ + public Integer numDecimals; + + /** The number of spaces for each indent space */ + public Integer indentSize; + + /** + * Indicator whether a trailing separator is present at the end of an inner set, default is + * false. + */ + public boolean trailingSeparator; + + /** The set of characters that enclose sets, default is {@link #DEFAULT_ENCLOSURE} */ + public Enclosure enclosure = DEFAULT_ENCLOSURE; + + /** Creates an Options instance. */ + Options() {} + + /** + * Creates an Options instance. + * + * @return this Options instance. + */ + public static Options create() { + return new Options(); + } + + /** + * Sets the maxWidth property + * + * @param maxWidth the maximum width of a line + * @return this Options instance. + */ + public Options maxWidth(int maxWidth) { + this.maxWidth = maxWidth; + return this; + } + + /** + * Sets the maxWidth property + * + * @param separator the item separator character, this is the character to separate each + * element. Default is {@link #DEFAULT_SEPARATOR}. + * @return this Options instance. + */ + public Options separator(char separator) { + this.separator = separator; + return this; + } + + /** + * Sets the enclosure to {@link Enclosure#BRACES} {@code {}}. + * + * @return this Options instance. + */ + public Options encloseWithBraces() { + this.enclosure = Enclosure.BRACES; + return this; + } + + /** + * Sets the enclosure to {@link Enclosure#BRACKETS}, {@code []}. + * + * @return this Options instance. + */ + public Options encloseWithBrackets() { + this.enclosure = Enclosure.BRACKETS; + return this; + } + + /** + * Sets the enclosure to {@link Enclosure#PARENS}, {@code ()}. + * + * @return this Options instance. + */ + public Options encloseWithParens() { + this.enclosure = Enclosure.PARENS; + return this; + } + + /** + * Sets the trailingSeparator property. + * + * @param trailingSeparator whether or not to add the separator after inner sets. + * @return this Options instance. + */ + public Options trailingSeparator(boolean trailingSeparator) { + this.trailingSeparator = trailingSeparator; + return this; + } + + /** + * Sets the number of digits after the decimal point. Only applies to floating point data types. + * + *

The default is to use either {@link String#valueOf(float)} or {@link + * String#valueOf(double)} + * + * @param numDecimals the number of digits after the decimal point. + * @return this Options instance. + */ + public Options numDecimals(int numDecimals) { + this.numDecimals = numDecimals; + return this; + } + + /** + * Sets the number of spaces for the indent + * + * @param indentSize the number of spaces for the indent + * @return this Options instance. + */ + public Options indentSize(int indentSize) { + this.indentSize = indentSize; + return this; + } + + /** Enumerator for specifying the character pairs for enclosing sets. */ + public enum Enclosure { + /** Enclose the set in brackets, {@code []} */ + BRACKETS, + /** Enclose the set in curly braces, {@code {}} */ + BRACES, + /** Enclose the set in parenthesis, {@code ()} */ + PARENS + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorPrinterTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorPrinterTest.java new file mode 100644 index 00000000000..04b3a0effd9 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorPrinterTest.java @@ -0,0 +1,405 @@ +package org.tensorflow; + +import org.junit.jupiter.api.Test; +import org.tensorflow.TensorPrinter.Options; +import org.tensorflow.ndarray.BooleanNdArray; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.IntNdArray; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class TensorPrinterTest { + + @Test + public void testPrint1D() { + int[] ints = {1, 2, 3}; + float[] floats = {1f, 2f, 3f}; + double[] doubles = {1d, 2d, 3d}; + long[] longs = {1L, 2L, 3L}; + boolean[] bools = {true, false, true}; + String[] strings = {"A", "B", "C"}; + + try (TInt32 tints = TInt32.vectorOf(ints); + TFloat32 tfloats = TFloat32.vectorOf(floats); + TFloat64 tdoubles = TFloat64.vectorOf(doubles); + TInt64 tlongs = TInt64.vectorOf(longs); + TBool tbools = TBool.vectorOf(bools); + TString tstrings = TString.vectorOf(strings)) { + + TensorPrinter tp = new TensorPrinter(tints); + assertEquals("[1, 2, 3]", tp.print()); + + tp = new TensorPrinter(tfloats); + assertEquals("[1.0, 2.0, 3.0]", tp.print()); + + tp = new TensorPrinter(tdoubles); + assertEquals("[1.0, 2.0, 3.0]", tp.print()); + + tp = new TensorPrinter(tlongs); + assertEquals("[1, 2, 3]", tp.print()); + + tp = new TensorPrinter(tbools); + assertEquals("[true, false, true]", tp.print()); + + tp = new TensorPrinter(tstrings); + assertEquals("[\"A\", \"B\", \"C\"]", tp.print()); + } + } + + @Test + public void testPrint1DSemiColon() { + int[] ints = {1, 2, 3}; + float[] floats = {1f, 2f, 3f}; + double[] doubles = {1d, 2d, 3d}; + long[] longs = {1L, 2L, 3L}; + boolean[] bools = {true, false, true}; + String[] strings = {"A", "B", "C"}; + + try (TInt32 tints = TInt32.vectorOf(ints); + TFloat32 tfloats = TFloat32.vectorOf(floats); + TFloat64 tdoubles = TFloat64.vectorOf(doubles); + TInt64 tlongs = TInt64.vectorOf(longs); + TBool tbools = TBool.vectorOf(bools); + TString tstrings = TString.vectorOf(strings)) { + + Options options = Options.create().separator(':'); + TensorPrinter tp = new TensorPrinter(tints, options); + assertEquals("[1: 2: 3]", tp.print()); + + tp = new TensorPrinter(tfloats, options); + assertEquals("[1.0: 2.0: 3.0]", tp.print()); + + tp = new TensorPrinter(tdoubles, options); + assertEquals("[1.0: 2.0: 3.0]", tp.print()); + + tp = new TensorPrinter(tlongs, options); + assertEquals("[1: 2: 3]", tp.print()); + + tp = new TensorPrinter(tbools, options); + assertEquals("[true: false: true]", tp.print()); + + tp = new TensorPrinter(tstrings, options); + assertEquals("[\"A\": \"B\": \"C\"]", tp.print()); + } + } + + + @Test + public void testPrint1D2Decimal() { + int[] ints = {1, 2, 3}; + float[] floats = {1f, 2f, 3f}; + double[] doubles = {1d, 2d, 3d}; + long[] longs = {1L, 2L, 3L}; + boolean[] bools = {true, false, true}; + String[] strings = {"A", "B", "C"}; + + try (TInt32 tints = TInt32.vectorOf(ints); + TFloat32 tfloats = TFloat32.vectorOf(floats); + TFloat64 tdoubles = TFloat64.vectorOf(doubles); + TInt64 tlongs = TInt64.vectorOf(longs); + TBool tbools = TBool.vectorOf(bools); + TString tstrings = TString.vectorOf(strings)) { + + Options options = Options.create().numDecimals(2); + + TensorPrinter tp = new TensorPrinter(tints, options); + assertEquals("[1, 2, 3]", tp.print()); + + tp = new TensorPrinter(tfloats, options); + assertEquals("[1.00, 2.00, 3.00]", tp.print()); + + tp = new TensorPrinter(tdoubles, options); + assertEquals("[1.00, 2.00, 3.00]", tp.print()); + + tp = new TensorPrinter(tlongs, options); + assertEquals("[1, 2, 3]", tp.print()); + + tp = new TensorPrinter(tbools, options); + assertEquals("[true, false, true]", tp.print()); + + tp = new TensorPrinter(tstrings, options); + assertEquals("[\"A\", \"B\", \"C\"]", tp.print()); + } + } + + @Test + public void testPrint2D() { + int[][] ints = {{1, 2}, {3, 4}}; + float[][] floats = {{1f, 2f}, {3f, 4f}}; + double[][] doubles = {{1d, 2d}, {3d, 4d}}; + long[][] longs = {{1L, 2L}, {3L, 4L}}; + boolean[][] bools = {{true, false}, {true, false}}; + String[][] strings = {{"A", "B"}, {"C", "D"}}; + + IntNdArray iMatrix = StdArrays.ndCopyOf(ints); + FloatNdArray fMatrix = StdArrays.ndCopyOf(floats); + DoubleNdArray dMatrix = StdArrays.ndCopyOf(doubles); + LongNdArray lMatrix = StdArrays.ndCopyOf(longs); + BooleanNdArray bMatrix = StdArrays.ndCopyOf(bools); + NdArray sMatrix = StdArrays.ndCopyOf(strings); + + try (TInt32 tints = TInt32.tensorOf(iMatrix); + TFloat32 tfloats = TFloat32.tensorOf(fMatrix); + TFloat64 tdoubles = TFloat64.tensorOf(dMatrix); + TInt64 tlongs = TInt64.tensorOf(lMatrix); + TBool tbools = TBool.tensorOf(bMatrix); + TString tstrings = TString.tensorOf(sMatrix)) { + + TensorPrinter tp = new TensorPrinter(tints); + assertEquals("[\n [1, 2]\n [3, 4]\n]", tp.print()); + + tp = new TensorPrinter(tfloats); + assertEquals("[\n [1.0, 2.0]\n [3.0, 4.0]\n]", tp.print()); + + tp = new TensorPrinter(tdoubles); + assertEquals("[\n [1.0, 2.0]\n [3.0, 4.0]\n]", tp.print()); + + tp = new TensorPrinter(tlongs); + assertEquals("[\n [1, 2]\n [3, 4]\n]", tp.print()); + + tp = new TensorPrinter(tbools); + assertEquals("[\n [true, false]\n [true, false]\n]", tp.print()); + + tp = new TensorPrinter(tstrings); + assertEquals("[\n [\"A\", \"B\"]\n [\"C\", \"D\"]\n]", tp.print()); + } + } + + @Test + public void testPrint3D() { + int[][][] ints = {{{1}, {2}}, {{3}, {4}}}; + float[][][] floats = {{{1f}, {2f}}, {{3f}, {4f}}}; + double[][][] doubles = {{{1d}, {2d}}, {{3d}, {4d}}}; + long[][][] longs = {{{1L}, {2L}}, {{3L}, {4L}}}; + boolean[][][] bools = {{{true}, {false}}, {{true}, {false}}}; + String[][][] strings = {{{"A"}, {"B"}}, {{"C"}, {"D"}}}; + + IntNdArray iMatrix = StdArrays.ndCopyOf(ints); + FloatNdArray fMatrix = StdArrays.ndCopyOf(floats); + DoubleNdArray dMatrix = StdArrays.ndCopyOf(doubles); + LongNdArray lMatrix = StdArrays.ndCopyOf(longs); + BooleanNdArray bMatrix = StdArrays.ndCopyOf(bools); + NdArray sMatrix = StdArrays.ndCopyOf(strings); + + try (TInt32 tints = TInt32.tensorOf(iMatrix); + TFloat32 tfloats = TFloat32.tensorOf(fMatrix); + TFloat64 tdoubles = TFloat64.tensorOf(dMatrix); + TInt64 tlongs = TInt64.tensorOf(lMatrix); + TBool tbools = TBool.tensorOf(bMatrix); + TString tstrings = TString.tensorOf(sMatrix)) { + + TensorPrinter tp = new TensorPrinter(tints); + assertEquals("[\n [\n [1]\n [2]\n ]\n [\n [3]\n [4]\n ]\n]", tp.print()); + + tp = new TensorPrinter(tfloats); + assertEquals( + "[\n [\n [1.0]\n [2.0]\n ]\n [\n [3.0]\n [4.0]\n ]\n]", tp.print()); + + tp = new TensorPrinter(tdoubles); + assertEquals( + "[\n [\n [1.0]\n [2.0]\n ]\n [\n [3.0]\n [4.0]\n ]\n]", tp.print()); + + tp = new TensorPrinter(tlongs); + assertEquals("[\n [\n [1]\n [2]\n ]\n [\n [3]\n [4]\n ]\n]", tp.print()); + + tp = new TensorPrinter(tbools); + assertEquals( + "[\n [\n [true]\n [false]\n ]\n [\n [true]\n [false]\n ]\n]", tp.print()); + + tp = new TensorPrinter(tstrings); + assertEquals( + "[\n [\n [\"A\"]\n [\"B\"]\n ]\n [\n [\"C\"]\n [\"D\"]\n ]\n]", + tp.print()); + } + } + + @Test + public void testPrint2DBrace() { + int[][] ints = {{1, 2}, {3, 4}}; + float[][] floats = {{1f, 2f}, {3f, 4f}}; + double[][] doubles = {{1d, 2d}, {3d, 4d}}; + long[][] longs = {{1L, 2L}, {3L, 4L}}; + boolean[][] bools = {{true, false}, {true, false}}; + String[][] strings = {{"A", "B"}, {"C", "D"}}; + + IntNdArray iMatrix = StdArrays.ndCopyOf(ints); + FloatNdArray fMatrix = StdArrays.ndCopyOf(floats); + DoubleNdArray dMatrix = StdArrays.ndCopyOf(doubles); + LongNdArray lMatrix = StdArrays.ndCopyOf(longs); + BooleanNdArray bMatrix = StdArrays.ndCopyOf(bools); + NdArray sMatrix = StdArrays.ndCopyOf(strings); + + try (TInt32 tints = TInt32.tensorOf(iMatrix); + TFloat32 tfloats = TFloat32.tensorOf(fMatrix); + TFloat64 tdoubles = TFloat64.tensorOf(dMatrix); + TInt64 tlongs = TInt64.tensorOf(lMatrix); + TBool tbools = TBool.tensorOf(bMatrix); + TString tstrings = TString.tensorOf(sMatrix)) { + + Options options = Options.create().encloseWithBraces(); + + TensorPrinter tp = new TensorPrinter(tints, options); + assertEquals("{\n {1, 2}\n {3, 4}\n}", tp.print()); + + tp = new TensorPrinter(tfloats, options); + assertEquals("{\n {1.0, 2.0}\n {3.0, 4.0}\n}", tp.print()); + + tp = new TensorPrinter(tdoubles, options); + assertEquals("{\n {1.0, 2.0}\n {3.0, 4.0}\n}", tp.print()); + + tp = new TensorPrinter(tlongs, options); + assertEquals("{\n {1, 2}\n {3, 4}\n}", tp.print()); + + tp = new TensorPrinter(tbools, options); + assertEquals("{\n {true, false}\n {true, false}\n}", tp.print()); + + tp = new TensorPrinter(tstrings, options); + assertEquals("{\n {\"A\", \"B\"}\n {\"C\", \"D\"}\n}", tp.print()); + } + } + + @Test + public void testPrint2DParen() { + int[][] ints = {{1, 2}, {3, 4}}; + float[][] floats = {{1f, 2f}, {3f, 4f}}; + double[][] doubles = {{1d, 2d}, {3d, 4d}}; + long[][] longs = {{1L, 2L}, {3L, 4L}}; + boolean[][] bools = {{true, false}, {true, false}}; + String[][] strings = {{"A", "B"}, {"C", "D"}}; + + IntNdArray iMatrix = StdArrays.ndCopyOf(ints); + FloatNdArray fMatrix = StdArrays.ndCopyOf(floats); + DoubleNdArray dMatrix = StdArrays.ndCopyOf(doubles); + LongNdArray lMatrix = StdArrays.ndCopyOf(longs); + BooleanNdArray bMatrix = StdArrays.ndCopyOf(bools); + NdArray sMatrix = StdArrays.ndCopyOf(strings); + + try (TInt32 tints = TInt32.tensorOf(iMatrix); + TFloat32 tfloats = TFloat32.tensorOf(fMatrix); + TFloat64 tdoubles = TFloat64.tensorOf(dMatrix); + TInt64 tlongs = TInt64.tensorOf(lMatrix); + TBool tbools = TBool.tensorOf(bMatrix); + TString tstrings = TString.tensorOf(sMatrix)) { + + Options options = Options.create().encloseWithParens(); + + TensorPrinter tp = new TensorPrinter(tints, options); + assertEquals("(\n (1, 2)\n (3, 4)\n)", tp.print()); + + tp = new TensorPrinter(tfloats, options); + assertEquals("(\n (1.0, 2.0)\n (3.0, 4.0)\n)", tp.print()); + + tp = new TensorPrinter(tdoubles, options); + assertEquals("(\n (1.0, 2.0)\n (3.0, 4.0)\n)", tp.print()); + + tp = new TensorPrinter(tlongs, options); + assertEquals("(\n (1, 2)\n (3, 4)\n)", tp.print()); + + tp = new TensorPrinter(tbools, options); + assertEquals("(\n (true, false)\n (true, false)\n)", tp.print()); + + tp = new TensorPrinter(tstrings, options); + assertEquals("(\n (\"A\", \"B\")\n (\"C\", \"D\")\n)", tp.print()); + } + } + + @Test + public void testPrint2DtrailingSeparator() { + int[][] ints = {{1, 2}, {3, 4}}; + float[][] floats = {{1f, 2f}, {3f, 4f}}; + double[][] doubles = {{1d, 2d}, {3d, 4d}}; + long[][] longs = {{1L, 2L}, {3L, 4L}}; + boolean[][] bools = {{true, false}, {true, false}}; + String[][] strings = {{"A", "B"}, {"C", "D"}}; + + IntNdArray iMatrix = StdArrays.ndCopyOf(ints); + FloatNdArray fMatrix = StdArrays.ndCopyOf(floats); + DoubleNdArray dMatrix = StdArrays.ndCopyOf(doubles); + LongNdArray lMatrix = StdArrays.ndCopyOf(longs); + BooleanNdArray bMatrix = StdArrays.ndCopyOf(bools); + NdArray sMatrix = StdArrays.ndCopyOf(strings); + + try (TInt32 tints = TInt32.tensorOf(iMatrix); + TFloat32 tfloats = TFloat32.tensorOf(fMatrix); + TFloat64 tdoubles = TFloat64.tensorOf(dMatrix); + TInt64 tlongs = TInt64.tensorOf(lMatrix); + TBool tbools = TBool.tensorOf(bMatrix); + TString tstrings = TString.tensorOf(sMatrix)) { + + Options options = Options.create().trailingSeparator(true); + + TensorPrinter tp = new TensorPrinter(tints, options); + assertEquals("[\n [1, 2],\n [3, 4]\n]", tp.print()); + + tp = new TensorPrinter(tfloats, options); + assertEquals("[\n [1.0, 2.0],\n [3.0, 4.0]\n]", tp.print()); + + tp = new TensorPrinter(tdoubles, options); + assertEquals("[\n [1.0, 2.0],\n [3.0, 4.0]\n]", tp.print()); + + tp = new TensorPrinter(tlongs, options); + assertEquals("[\n [1, 2],\n [3, 4]\n]", tp.print()); + + tp = new TensorPrinter(tbools, options); + assertEquals("[\n [true, false],\n [true, false]\n]", tp.print()); + + tp = new TensorPrinter(tstrings, options); + assertEquals("[\n [\"A\", \"B\"],\n [\"C\", \"D\"]\n]", tp.print()); + } + } + + @Test + public void testPrint2DtrailingIndent() { + int[][] ints = {{1, 2}, {3, 4}}; + float[][] floats = {{1f, 2f}, {3f, 4f}}; + double[][] doubles = {{1d, 2d}, {3d, 4d}}; + long[][] longs = {{1L, 2L}, {3L, 4L}}; + boolean[][] bools = {{true, false}, {true, false}}; + String[][] strings = {{"A", "B"}, {"C", "D"}}; + + IntNdArray iMatrix = StdArrays.ndCopyOf(ints); + FloatNdArray fMatrix = StdArrays.ndCopyOf(floats); + DoubleNdArray dMatrix = StdArrays.ndCopyOf(doubles); + LongNdArray lMatrix = StdArrays.ndCopyOf(longs); + BooleanNdArray bMatrix = StdArrays.ndCopyOf(bools); + NdArray sMatrix = StdArrays.ndCopyOf(strings); + + try (TInt32 tints = TInt32.tensorOf(iMatrix); + TFloat32 tfloats = TFloat32.tensorOf(fMatrix); + TFloat64 tdoubles = TFloat64.tensorOf(dMatrix); + TInt64 tlongs = TInt64.tensorOf(lMatrix); + TBool tbools = TBool.tensorOf(bMatrix); + TString tstrings = TString.tensorOf(sMatrix)) { + + Options options = Options.create().indentSize(1); + + TensorPrinter tp = new TensorPrinter(tints, options); + assertEquals("[\n [1, 2]\n [3, 4]\n]", tp.print()); + + tp = new TensorPrinter(tfloats, options); + assertEquals("[\n [1.0, 2.0]\n [3.0, 4.0]\n]", tp.print()); + + tp = new TensorPrinter(tdoubles, options); + assertEquals("[\n [1.0, 2.0]\n [3.0, 4.0]\n]", tp.print()); + + tp = new TensorPrinter(tlongs, options); + assertEquals("[\n [1, 2]\n [3, 4]\n]", tp.print()); + + tp = new TensorPrinter(tbools, options); + assertEquals("[\n [true, false]\n [true, false]\n]", tp.print()); + + tp = new TensorPrinter(tstrings, options); + assertEquals("[\n [\"A\", \"B\"]\n [\"C\", \"D\"]\n]", tp.print()); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java index 9415a986222..ffc121b8155 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java @@ -15,21 +15,8 @@ package org.tensorflow; -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; - -import java.nio.Buffer; -import java.nio.BufferUnderflowException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.nio.IntBuffer; -import java.nio.LongBuffer; import org.junit.jupiter.api.Test; +import org.tensorflow.TensorPrinter.Options; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.FloatNdArray; @@ -50,11 +37,47 @@ import org.tensorflow.types.TString; import org.tensorflow.types.TUint8; +import java.nio.Buffer; +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + /** Unit tests for {@link org.tensorflow.Tensor}. */ public class TensorTest { private static final double EPSILON = 1e-7; private static final float EPSILON_F = 1e-7f; + // Workaround for cross compiliation + // (e.g., javac -source 1.9 -target 1.8). + // + // In Java 8 and prior, subclasses of java.nio.Buffer (e.g., java.nio.DoubleBuffer) inherited the + // "flip()" and "clear()" methods from java.nio.Buffer resulting in the signature: + // Buffer flip(); + // In Java 9 these subclasses had their own methods like: + // DoubleBuffer flip(); + // As a result, compiling for 1.9 source for a target of JDK 1.8 would result in errors at runtime + // like: + // + // java.lang.NoSuchMethodError: java.nio.DoubleBuffer.flip()Ljava/nio/DoubleBuffer + private static void flipBuffer(Buffer buf) { + buf.flip(); + } + + // See comment for flipBuffer() + private static void clearBuffer(Buffer buf) { + buf.clear(); + } + @Test public void createWithRawData() { double[] doubles = {1d, 2d, 3d, 4d}; @@ -66,7 +89,7 @@ public void createWithRawData() { Shape strings_shape = Shape.scalar(); byte[] strings_; // raw TF_STRING try (TString t = TString.tensorOf(NdArrays.scalarOfObject(strings))) { - strings_ = new byte[(int)t.numBytes()]; + strings_ = new byte[(int) t.numBytes()]; t.asRawTensor().data().read(strings_); } @@ -86,8 +109,11 @@ public void createWithRawData() { // validate creating a tensor using a direct byte buffer (in host order) { - DoubleBuffer buf = ByteBuffer.allocateDirect(8 * doubles.length).order(ByteOrder.nativeOrder()) - .asDoubleBuffer().put(doubles); + DoubleBuffer buf = + ByteBuffer.allocateDirect(8 * doubles.length) + .order(ByteOrder.nativeOrder()) + .asDoubleBuffer() + .put(doubles); try (TFloat64 t = TFloat64.tensorOf(doubles_shape, d -> d.write(DataBuffers.of(buf)))) { double[] actual = new double[doubles.length]; t.read(DataBuffers.of(actual)); @@ -140,10 +166,10 @@ public void createFromBufferWithNonNativeByteOrder() { @Test public void createWithTypedBuffer() { - IntBuffer ints = IntBuffer.wrap(new int[]{1, 2, 3, 4}); - FloatBuffer floats = FloatBuffer.wrap(new float[]{1f, 2f, 3f, 4f}); - DoubleBuffer doubles = DoubleBuffer.wrap(new double[]{1d, 2d, 3d, 4d}); - LongBuffer longs = LongBuffer.wrap(new long[]{1L, 2L, 3L, 4L}); + IntBuffer ints = IntBuffer.wrap(new int[] {1, 2, 3, 4}); + FloatBuffer floats = FloatBuffer.wrap(new float[] {1f, 2f, 3f, 4f}); + DoubleBuffer doubles = DoubleBuffer.wrap(new double[] {1d, 2d, 3d, 4d}); + LongBuffer longs = LongBuffer.wrap(new long[] {1L, 2L, 3L, 4L}); // validate creating a tensor using a typed buffer { @@ -243,7 +269,7 @@ public void readFromRawData() { // validate the use of direct buffers { ByteBuffer bbuf = - ByteBuffer.allocateDirect((int)tdoubles.numBytes()).order(ByteOrder.nativeOrder()); + ByteBuffer.allocateDirect((int) tdoubles.numBytes()).order(ByteOrder.nativeOrder()); tdoubles.asRawTensor().data().copyTo(DataBuffers.of(bbuf), tdoubles.numBytes()); assertEquals(doubles[0], bbuf.asDoubleBuffer().get(0), EPSILON); } @@ -251,13 +277,17 @@ public void readFromRawData() { // validate byte order conversion { DoubleBuffer foreignBuf = - ByteBuffer.allocate((int)tdoubles.numBytes()) + ByteBuffer.allocate((int) tdoubles.numBytes()) .order( ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN) .asDoubleBuffer(); - tdoubles.asRawTensor().data().asDoubles().copyTo(DataBuffers.of(foreignBuf), foreignBuf.capacity()); + tdoubles + .asRawTensor() + .data() + .asDoubles() + .copyTo(DataBuffers.of(foreignBuf), foreignBuf.capacity()); double[] actual = new double[foreignBuf.remaining()]; foreignBuf.get(actual); assertArrayEquals(doubles, actual, EPSILON); @@ -320,7 +350,7 @@ public void scalars() { @Test public void nDimensional() { - DoubleNdArray vector = StdArrays.ndCopyOf(new double[]{1.414, 2.718, 3.1415}); + DoubleNdArray vector = StdArrays.ndCopyOf(new double[] {1.414, 2.718, 3.1415}); try (TFloat64 t = TFloat64.tensorOf(vector)) { assertEquals(TFloat64.class, t.type()); assertEquals(DataType.DT_DOUBLE, t.dataType()); @@ -329,7 +359,7 @@ public void nDimensional() { assertEquals(vector, t); } - IntNdArray matrix = StdArrays.ndCopyOf(new int[][]{{1, 2, 3}, {4, 5, 6}}); + IntNdArray matrix = StdArrays.ndCopyOf(new int[][] {{1, 2, 3}, {4, 5, 6}}); try (TInt32 t = TInt32.tensorOf(matrix)) { assertEquals(TInt32.class, t.type()); assertEquals(DataType.DT_INT32, t.dataType()); @@ -339,9 +369,11 @@ public void nDimensional() { assertEquals(matrix, t); } - LongNdArray threeD = StdArrays.ndCopyOf(new long[][][]{ - {{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}}, - }); + LongNdArray threeD = + StdArrays.ndCopyOf( + new long[][][] { + {{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}}, + }); try (TInt64 t = TInt64.tensorOf(threeD)) { assertEquals(TInt64.class, t.type()); assertEquals(DataType.DT_INT64, t.dataType()); @@ -352,11 +384,13 @@ public void nDimensional() { assertEquals(threeD, t); } - BooleanNdArray fourD = StdArrays.ndCopyOf(new boolean[][][][]{ - {{{false, false, false, true}, {false, false, true, false}}}, - {{{false, false, true, true}, {false, true, false, false}}}, - {{{false, true, false, true}, {false, true, true, false}}}, - }); + BooleanNdArray fourD = + StdArrays.ndCopyOf( + new boolean[][][][] { + {{{false, false, false, true}, {false, false, true, false}}}, + {{{false, false, true, true}, {false, true, false, false}}}, + {{{false, true, false, true}, {false, true, true, false}}}, + }); try (TBool t = TBool.tensorOf(fourD)) { assertEquals(TBool.class, t.type()); assertEquals(DataType.DT_BOOL, t.dataType()); @@ -387,7 +421,9 @@ public void testNDimensionalStringTensor() { } NdArray byteMatrix = NdArrays.ofObjects(byte[].class, matrix.shape()); - matrix.scalars().forEachIndexed((i, s) -> byteMatrix.setObject(s.getObject().getBytes(UTF_8), i)); + matrix + .scalars() + .forEachIndexed((i, s) -> byteMatrix.setObject(s.getObject().getBytes(UTF_8), i)); try (TString t = TString.tensorOfBytes(byteMatrix)) { assertEquals(TString.class, t.type()); assertEquals(DataType.DT_STRING, t.dataType()); @@ -512,9 +548,10 @@ public void fromHandle() { // // An exception is made for this test, where the pitfalls of this is avoided by not calling // close() on both Tensors. - final FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{{1, 2, 3}, {4, 5, 6}}); + final FloatNdArray matrix = StdArrays.ndCopyOf(new float[][] {{1, 2, 3}, {4, 5, 6}}); try (TFloat32 src = TFloat32.tensorOf(matrix)) { - TFloat32 cpy = (TFloat32)RawTensor.fromHandle(src.asRawTensor().nativeHandle()).asTypedTensor(); + TFloat32 cpy = + (TFloat32) RawTensor.fromHandle(src.asRawTensor().nativeHandle()).asTypedTensor(); assertEquals(src.type(), cpy.type()); assertEquals(src.dataType(), cpy.dataType()); assertEquals(src.shape().numDimensions(), cpy.shape().numDimensions()); @@ -541,24 +578,55 @@ public void gracefullyFailCreationFromNullArrayForStringTensor() { } } - // Workaround for cross compiliation - // (e.g., javac -source 1.9 -target 1.8). - // - // In Java 8 and prior, subclasses of java.nio.Buffer (e.g., java.nio.DoubleBuffer) inherited the - // "flip()" and "clear()" methods from java.nio.Buffer resulting in the signature: - // Buffer flip(); - // In Java 9 these subclasses had their own methods like: - // DoubleBuffer flip(); - // As a result, compiling for 1.9 source for a target of JDK 1.8 would result in errors at runtime - // like: - // - // java.lang.NoSuchMethodError: java.nio.DoubleBuffer.flip()Ljava/nio/DoubleBuffer - private static void flipBuffer(Buffer buf) { - buf.flip(); - } - - // See comment for flipBuffer() - private static void clearBuffer(Buffer buf) { - buf.clear(); + @Test + public void testPrint() { + try (TInt32 t = TInt32.vectorOf(3, 0, 1)) { + String actual = t.print(); + assertEquals("[3, 0, 1]", actual); + } + try (TInt32 t = TInt32.vectorOf(3, 0, 1)) { + String actual = t.print(Options.create().maxWidth(5)); + // Cannot remove first or last element + assertEquals("[3, 0, 1]", actual); + } + try (TInt32 t = TInt32.vectorOf(3, 0, 1)) { + String actual = t.print(Options.create().maxWidth(6)); + // Do not insert ellipses if it increases the length + assertEquals("[3, 0, 1]", actual); + } + try (TInt32 t = TInt32.vectorOf(3, 0, 1, 2)) { + String actual = t.print(Options.create().maxWidth(11)); + // Limit may be surpassed if first or last element are too long + assertEquals("[3, ..., 2]", actual); + } + try (TInt32 t = TInt32.vectorOf(3, 0, 1, 2)) { + String actual = t.print(Options.create().maxWidth(12)); + assertEquals("[3, 0, 1, 2]", actual); + } + try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{1, 2, 3}, {3, 2, 1}}))) { + String actual = t.print(Options.create().maxWidth(12)); + assertEquals("[\n" + " [1, 2, 3]\n" + " [3, 2, 1]\n" + "]", actual); + } + try (RawTensor t = TInt32.vectorOf(3, 0, 1, 2).asRawTensor()) { + String actual = t.print(Options.create().maxWidth(12).encloseWithBraces()); + assertEquals("{3, 0, 1, 2}", actual); + } + // different data types + try (RawTensor t = TFloat32.vectorOf(3.0101f, 0, 1.5f, 2).asRawTensor()) { + String actual = t.print(); + assertEquals("[3.0101, 0.0, 1.5, 2.0]", actual); + } + try (RawTensor t = TFloat64.vectorOf(3.0101, 0, 1.5, 2).asRawTensor()) { + String actual = t.print(); + assertEquals("[3.0101, 0.0, 1.5, 2.0]", actual); + } + try (RawTensor t = TBool.vectorOf(true, true, false, true).asRawTensor()) { + String actual = t.print(); + assertEquals("[true, true, false, true]", actual); + } + try (RawTensor t = TString.vectorOf("a", "b", "c").asRawTensor()) { + String actual = t.print(); + assertEquals("[\"a\", \"b\", \"c\"]", actual); + } } }