@@ -5066,6 +5066,34 @@ struct ggml_tensor * ggml_repeat(
     return result;
 }
 
+// ggml_repeat2
+
+struct ggml_tensor * ggml_repeat2(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_can_repeat(a, b));
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    if (ggml_are_same_shape(a, b) && !is_node) {
+        return a;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+
+    result->op   = GGML_OP_REPEAT2;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_abs
 
 struct ggml_tensor * ggml_abs_impl(
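A hedged usage sketch of the new op (not part of the patch): it assumes a ggml build that already includes this change; the helper name and tensor shapes are invented for illustration.

    // sketch only -- assumes the ggml.h from this branch declares ggml_repeat2
    #include "ggml.h"

    static struct ggml_tensor * repeat2_example(struct ggml_context * ctx) {
        // hypothetical shapes: broadcast one 64x128 slice to 4 slices along dim 2
        struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 128, 1);
        struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 128, 4);

        // b only supplies the target shape, exactly as with ggml_repeat
        return ggml_repeat2(ctx, a, b);
    }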
@@ -8847,6 +8875,87 @@ static void ggml_compute_forward_repeat(
     }
 }
 
+// ggml_compute_forward_repeat2
+
+static void ggml_compute_forward_repeat2_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    int i2k2 = 0;
+
+    // TODO: maybe this is not optimal?
+    for (int i3 = 0; i3 < nr3; i3++) {
+        for (int k3 = 0; k3 < ne03; k3++, i2k2 = 0) {
+            for (int i2 = 0; i2 < nr2; i2++) {
+                for (int k2 = 0; k2 < ne02; k2++, i2k2++) {
+                    for (int i1 = 0; i1 < nr1; i1++) {
+                        for (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                ggml_vec_cpy_f32(ne00,
+                                        (float *) ((char *)  dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0),
+                                        (float *) ((char *) src0->data + (k3)*nb03 + (i2k2 / nr2)*nb02 + (k1)*nb01));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat2(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_repeat2_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_abs
 
 static void ggml_compute_forward_abs_f32(
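One reading of the kernel above (my interpretation, not stated in the patch): the source offset ignores i0, i1 and i3, so dims 0, 1 and 3 are plain tiling, and only the dim-2 indexing differs. There the source slice is picked as (i2*ne02 + k2)/nr2 via the i2k2 counter, so each src0 slice is copied nr2 times back-to-back (a,a,b,b) rather than the whole block of slices being tiled nr2 times (a,b,a,b), which is what a straight repeat would do. The standalone sketch below only illustrates that index mapping; the names are mine, not the library's.

    // prints the assumed dim-2 mapping: tiling -> a,b,a,b   repeat2-style -> a,a,b,b
    #include <stdio.h>

    int main(void) {
        const int ne02 = 2; // source slices along dim 2
        const int nr2  = 2; // repeat factor along dim 2
        for (int i2 = 0; i2 < nr2; i2++) {
            for (int k2 = 0; k2 < ne02; k2++) {
                const int dst_slice = i2*ne02 + k2;     // destination slice index
                printf("dst %d <- tiling: src %d, repeat2: src %d\n",
                       dst_slice, k2, dst_slice/nr2);   // k2 = tiled, /nr2 = back-to-back
            }
        }
        return 0;
    }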
@@ -13336,6 +13445,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_repeat(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_REPEAT2:
+            {
+                ggml_compute_forward_repeat2(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_ABS:
             {
                 ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -13684,6 +13797,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 GGML_ASSERT(false); // TODO: implement
             } break;
         case GGML_OP_REPEAT:
+        case GGML_OP_REPEAT2:
             {
                 // necessary for llama
                 if (src0->grad) {
@@ -14568,6 +14682,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 case GGML_OP_SUM_ROWS:
                 case GGML_OP_MEAN:
                 case GGML_OP_REPEAT:
+                case GGML_OP_REPEAT2:
                 case GGML_OP_ABS:
                 case GGML_OP_SGN:
                 case GGML_OP_NEG: