Skip to content

Commit 3e97b94

Browse files
committed
Fix a few more issues with i8x.
1 parent ac96142 commit 3e97b94

File tree

1 file changed

+57
-39
lines changed

1 file changed

+57
-39
lines changed

nnc/Store.swift

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1705,13 +1705,15 @@ private let i8xEncode:
17051705
return 1
17061706
}
17071707

1708-
private func i8xDecodeImpl(
1709-
_ data: UnsafeRawPointer?, _ dataSize: Int, _ dataType: Int32,
1710-
_ dimensions: UnsafePointer<Int32>?, _ dimensionCount: Int32, _ identifier: UInt32,
1711-
_ context: UnsafeMutableRawPointer?, _ params: ccv_nnc_tensor_param_t,
1712-
_ tensorOut: UnsafeMutablePointer<UnsafeMutablePointer<ccv_nnc_tensor_t>?>?,
1713-
_ decoded: UnsafeMutableRawPointer?, _ decodedSize: UnsafeMutablePointer<Int>?
1714-
) -> Int32 {
1708+
private let i8xDecode:
1709+
@convention(c) (
1710+
UnsafeRawPointer?, Int, Int32, UnsafePointer<Int32>?, Int32, UInt32, UnsafeMutableRawPointer?,
1711+
ccv_nnc_tensor_param_t, UnsafeMutablePointer<UnsafeMutablePointer<ccv_nnc_tensor_t>?>?,
1712+
UnsafeMutableRawPointer?, UnsafeMutablePointer<Int>?
1713+
) -> Int32 = {
1714+
data, dataSize, dataType, dimensions, dimensionCount, identifier, context, params, tensorOut,
1715+
decoded, decodedSize
1716+
in
17151717
guard identifier == 0x8a1e9b else { return 0 }
17161718
guard
17171719
dataType == Int32(CCV_64F) || dataType == Int32(CCV_32F) || dataType == Int32(CCV_16F)
@@ -1770,7 +1772,7 @@ private func i8xDecodeImpl(
17701772
return 1
17711773
}
17721774

1773-
private let i8xDecode:
1775+
private let i8xDecodeJit:
17741776
@convention(c) (
17751777
UnsafeRawPointer?, Int, Int32, UnsafePointer<Int32>?, Int32, UInt32, UnsafeMutableRawPointer?,
17761778
ccv_nnc_tensor_param_t, UnsafeMutablePointer<UnsafeMutablePointer<ccv_nnc_tensor_t>?>?,
@@ -1779,18 +1781,6 @@ private let i8xDecode:
17791781
data, dataSize, dataType, dimensions, dimensionCount, identifier, context, params, tensorOut,
17801782
decoded, decodedSize
17811783
in
1782-
return i8xDecodeImpl(
1783-
data, dataSize, dataType, dimensions, dimensionCount, identifier, context, params,
1784-
tensorOut, decoded, decodedSize)
1785-
}
1786-
1787-
private func i8xDecodeJitImpl(
1788-
_ data: UnsafeRawPointer?, _ dataSize: Int, _ dataType: Int32,
1789-
_ dimensions: UnsafePointer<Int32>?, _ dimensionCount: Int32, _ identifier: UInt32,
1790-
_ context: UnsafeMutableRawPointer?, _ params: ccv_nnc_tensor_param_t,
1791-
_ tensorOut: UnsafeMutablePointer<UnsafeMutablePointer<ccv_nnc_tensor_t>?>?,
1792-
_ decoded: UnsafeMutableRawPointer?, _ decodedSize: UnsafeMutablePointer<Int>?
1793-
) -> Int32 {
17941784
guard identifier == 0x8a1e9b else { return 0 }
17951785
guard
17961786
dataType == Int32(CCV_64F) || dataType == Int32(CCV_32F) || dataType == Int32(CCV_16F)
@@ -1808,7 +1798,7 @@ private func i8xDecodeJitImpl(
18081798
decodedSize[0] = dataSize
18091799
return 1
18101800
}
1811-
return i8xDecodeImpl(
1801+
return i8xDecode(
18121802
data, dataSize, dataType, dimensions, dimensionCount, identifier, context, params,
18131803
tensorOut, decoded, decodedSize)
18141804
}
@@ -1817,14 +1807,14 @@ private func i8xDecodeJitImpl(
18171807
numberOfElements *= Int(dimensions[i])
18181808
}
18191809
guard TensorShape(dims: params.dim).reduce(1, *) == numberOfElements else {
1820-
return i8xDecodeImpl(
1810+
return i8xDecode(
18211811
data, dataSize, dataType, dimensions, dimensionCount, identifier, context, params,
18221812
tensorOut, decoded, decodedSize)
18231813
}
18241814
let rowwiseParams = ccv_nnc_tensor_8i_rowwise(params)
18251815
let encodedDataSize = ccv_nnc_tensor_data_size_without_padding(rowwiseParams)
18261816
guard dataSize >= encodedDataSize && decodedSize[0] >= encodedDataSize else {
1827-
return i8xDecodeImpl(
1817+
return i8xDecode(
18281818
data, dataSize, dataType, dimensions, dimensionCount, identifier, context, params,
18291819
tensorOut, decoded, decodedSize)
18301820
}
@@ -1836,20 +1826,6 @@ private func i8xDecodeJitImpl(
18361826
return 1
18371827
}
18381828

1839-
private let i8xDecodeJit:
1840-
@convention(c) (
1841-
UnsafeRawPointer?, Int, Int32, UnsafePointer<Int32>?, Int32, UInt32, UnsafeMutableRawPointer?,
1842-
ccv_nnc_tensor_param_t, UnsafeMutablePointer<UnsafeMutablePointer<ccv_nnc_tensor_t>?>?,
1843-
UnsafeMutableRawPointer?, UnsafeMutablePointer<Int>?
1844-
) -> Int32 = {
1845-
data, dataSize, dataType, dimensions, dimensionCount, identifier, context, params, tensorOut,
1846-
decoded, decodedSize
1847-
in
1848-
return i8xDecodeJitImpl(
1849-
data, dataSize, dataType, dimensions, dimensionCount, identifier, context, params,
1850-
tensorOut, decoded, decodedSize)
1851-
}
1852-
18531829
private let fpzipEncode:
18541830
@convention(c) (
18551831
UnsafeRawPointer?, Int, Int32, UnsafePointer<Int32>?, Int32, UnsafeMutableRawPointer?,
@@ -2371,6 +2347,25 @@ private let q8pAndEzm7Encode:
23712347
identifier)
23722348
}
23732349

2350+
private let i8xAndEzm7Encode:
2351+
@convention(c) (
2352+
UnsafeRawPointer?, Int, Int32, UnsafePointer<Int32>?, Int32, UnsafeMutableRawPointer?,
2353+
UnsafeMutableRawPointer?, UnsafeMutablePointer<Int>?,
2354+
UnsafeMutablePointer<ccv_nnc_tensor_param_t>?, UnsafeMutablePointer<UInt32>?
2355+
) -> Int32 = {
2356+
data, dataSize, dataType, dimensions, dimensionCount, context, encoded, encodedSize, params,
2357+
identifier
2358+
in
2359+
guard
2360+
i8xEncode(
2361+
data, dataSize, dataType, dimensions, dimensionCount, context, encoded, encodedSize, params,
2362+
identifier) == 0
2363+
else { return 1 }
2364+
return ezm7Encode(
2365+
data, dataSize, dataType, dimensions, dimensionCount, context, encoded, encodedSize, params,
2366+
identifier)
2367+
}
2368+
23742369
private let fpzipAndZipEncode:
23752370
@convention(c) (
23762371
UnsafeRawPointer?, Int, Int32, UnsafePointer<Int32>?, Int32, UnsafeMutableRawPointer?,
@@ -2486,6 +2481,25 @@ private let q8pAndEzm7EncodeWithExternalStore:
24862481
identifier)
24872482
}
24882483

2484+
private let i8xAndEzm7EncodeWithExternalStore:
2485+
@convention(c) (
2486+
UnsafeRawPointer?, Int, Int32, UnsafePointer<Int32>?, Int32, UnsafeMutableRawPointer?,
2487+
UnsafeMutableRawPointer?, UnsafeMutablePointer<Int>?,
2488+
UnsafeMutablePointer<ccv_nnc_tensor_param_t>?, UnsafeMutablePointer<UInt32>?
2489+
) -> Int32 = {
2490+
data, dataSize, dataType, dimensions, dimensionCount, context, encoded, encodedSize, params,
2491+
identifier
2492+
in
2493+
guard
2494+
i8xEncodeWithExternalStore(
2495+
data, dataSize, dataType, dimensions, dimensionCount, context, encoded, encodedSize, params,
2496+
identifier) == 0
2497+
else { return 1 }
2498+
return ezm7EncodeWithExternalStore(
2499+
data, dataSize, dataType, dimensions, dimensionCount, context, encoded, encodedSize, params,
2500+
identifier)
2501+
}
2502+
24892503
private let ezm7EncodeWithExternalStore:
24902504
@convention(c) (
24912505
UnsafeRawPointer?, Int, Int32, UnsafePointer<Int32>?, Int32, UnsafeMutableRawPointer?,
@@ -3437,7 +3451,7 @@ private let i8xDecodeJitWithExternalStore:
34373451
let offset = Int(data.load(as: UInt64.self))
34383452
let length = Int((data + MemoryLayout<UInt64>.size).load(as: UInt64.self))
34393453
let mappedData = store.loadBytes(offset: offset, length: length)
3440-
return i8xDecodeJitImpl(
3454+
return i8xDecodeJit(
34413455
mappedData, length, dataType, dimensions, dimensionCount, identifier, context, params,
34423456
tensorOut, decoded, decodedSize)
34433457
}
@@ -4751,7 +4765,7 @@ private let i8xDecodeWithExternalStore:
47514765
let offset = Int(data.load(as: UInt64.self))
47524766
let length = Int((data + MemoryLayout<UInt64>.size).load(as: UInt64.self))
47534767
let mappedData = store.loadBytes(offset: offset, length: length)
4754-
return i8xDecodeImpl(
4768+
return i8xDecode(
47554769
mappedData, length, dataType, dimensions, dimensionCount, identifier, context, params,
47564770
tensorOut, decoded, decodedSize)
47574771
}
@@ -5132,6 +5146,8 @@ extension DynamicGraph {
51325146
return q7pAndEzm7EncodeWithExternalStore
51335147
} else if contains(.ezm7) && contains(.q8p) {
51345148
return q8pAndEzm7EncodeWithExternalStore
5149+
} else if contains(.ezm7) && contains(.i8x) {
5150+
return i8xAndEzm7EncodeWithExternalStore
51355151
} else if contains(.ezm7) {
51365152
return ezm7EncodeWithExternalStore
51375153
} else if contains(.q4p) {
@@ -5165,6 +5181,8 @@ extension DynamicGraph {
51655181
return q7pAndEzm7Encode // Prefer q7p, if it is longer (because 256 palette), use ezm7.
51665182
} else if contains(.ezm7) && contains(.q8p) {
51675183
return q8pAndEzm7Encode // Prefer q8p, if it is longer (because 256 palette), use ezm7.
5184+
} else if contains(.ezm7) && contains(.i8x) {
5185+
return i8xAndEzm7Encode // Prefer i8x, if it is longer (because 256 palette), use ezm7.
51685186
} else if contains(.ezm7) {
51695187
// .ezm7 is not supported with other lossless formats
51705188
guard self == .ezm7 else { return nil }

0 commit comments

Comments
 (0)