@@ -87,6 +87,7 @@ struct gru_config {
     // Resource reuse info
     static const unsigned io_type = io_parallel;
     static const unsigned reuse_factor = 1;
+    static const bool pytorch_order = false;
     static const bool store_weights_in_bram = false;

     // Activation
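
Note (not part of the patch): pytorch_order defaults to false, so existing Keras-converted models keep the (z, r) gate packing; a converter targeting torch.nn.GRU, which exports gates as (r, z), would flip it. A minimal user-side sketch, assuming the usual hls4ml idiom of shadowing defaults in a derived config struct (struct name hypothetical):

    // Hypothetical specialization: select PyTorch (r, z) gate packing
    // by shadowing the default flag from nnet::gru_config.
    struct my_gru_config : nnet::gru_config {
        static const bool pytorch_order = true;
    };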
@@ -133,7 +134,10 @@ void gru_cell(data_T x[CONFIG_T::n_in], res_T h[CONFIG_T::n_units],
     hls_register typename CONFIG_T::accum_t hadamard_r_h[CONFIG_T::n_units];
     #pragma unroll recurrent_unroll_factor
     for (int i = 0; i < (CONFIG_T::n_units); i++) {
-        hadamard_r_h[i] = z_r_act[i + CONFIG_T::n_units] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
+        if (CONFIG_T::pytorch_order)
+            hadamard_r_h[i] = z_r_act[i] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
+        else
+            hadamard_r_h[i] = z_r_act[i + CONFIG_T::n_units] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
     }

     // The candidate state; X * W_{hx} + hadamard(r(t), h_(t-1)) * W_{hh} + b_{h}
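
Note (not part of the patch): z_r_act packs both gates back-to-back, and the two frameworks disagree on the packing. Keras stores (z, r), so the reset gate r lives at offset n_units; PyTorch stores (r, z), so r starts at offset 0. A standalone sketch of that index mapping (helper names hypothetical):

    #include <cstddef>

    // Offset of each gate inside the fused z_r buffer.
    // Keras packs (z, r); PyTorch packs (r, z).
    constexpr std::size_t reset_offset(std::size_t n_units, bool pytorch_order) {
        return pytorch_order ? 0 : n_units;
    }
    constexpr std::size_t update_offset(std::size_t n_units, bool pytorch_order) {
        return pytorch_order ? n_units : 0;
    }

    static_assert(reset_offset(16, false) == 16, "Keras: r follows z");
    static_assert(reset_offset(16, true) == 0, "PyTorch: r comes first");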
@@ -152,7 +156,11 @@ void gru_cell(data_T x[CONFIG_T::n_in], res_T h[CONFIG_T::n_units],
     // Update state
     #pragma unroll recurrent_unroll_factor
     for (int i = 0; i < (CONFIG_T::n_units); i++) {
-        h[i] = static_cast<res_T>(h_cand_act[i] * (1 - z_r_act[i]) + h[i] * z_r_act[i]);
+        if (CONFIG_T::pytorch_order)
+            h[i] = static_cast<res_T>(h_cand_act[i] * (1 - z_r_act[i + CONFIG_T::n_units]) +
+                                      h[i] * z_r_act[i + CONFIG_T::n_units]);
+        else
+            h[i] = static_cast<res_T>(h_cand_act[i] * (1 - z_r_act[i]) + h[i] * z_r_act[i]);
     }
 }

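Note (not part of the patch): both branches implement the standard GRU interpolation h_t = (1 - z_t) * h_cand_t + z_t * h_(t-1); only the index of the update gate z inside the fused buffer changes between orderings. A plain floating-point model of the loop above (function name hypothetical):

    #include <cstddef>

    // Float model of the state update: blend the candidate state and the
    // previous state with the update gate z, whose position in the fused
    // buffer depends on the gate ordering.
    void gru_update_state(const float *h_cand_act, const float *z_r_act,
                          float *h, std::size_t n_units, bool pytorch_order) {
        const std::size_t z_off = pytorch_order ? n_units : 0;
        for (std::size_t i = 0; i < n_units; i++) {
            const float z = z_r_act[i + z_off];
            h[i] = h_cand_act[i] * (1.0f - z) + h[i] * z;
        }
    }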
@@ -315,6 +323,131 @@ void simple_rnn(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], res_T res[C
         }
     }
 }
+// ----------------------
+// SimpleRNN with pytorch biases
+// ----------------------
+
+struct simpleRNN_pytorch_config {
+    // Internal data type definitions
+    typedef float weight_t;
+    typedef float bias_t;
+    typedef float accum_t;
+
+    // Layer Sizes
+    static const unsigned n_in = 1;
+    static const unsigned n_out = 1;
+    static const unsigned n_outputs = 1;
+    static const unsigned n_timesteps = 1;
+    static const bool return_sequences = false;
+
+    // Resource reuse info
+    static const unsigned io_type = io_parallel;
+    static const unsigned reuse_factor = 1;
+    static const bool store_weights_in_bram = false;
+
+    // Activation
+    template <class x_T, class y_T, class config_T> using activation_recr = nnet::activation::relu<x_T, y_T, config_T>;
+
+    template <class x_T, class y_T, class config_T> using activation = nnet::activation::relu<x_T, y_T, config_T>;
+};
+
+template <class data_T, class res_T, typename CONFIG_T>
+void simple_rnn_pytorch_cell(data_T inputs[CONFIG_T::n_in], res_T hidden_state[CONFIG_T::n_out],
+                             res_T hidden_state_o[CONFIG_T::n_out],
+                             const typename CONFIG_T::weight_t kernel[CONFIG_T::n_in * CONFIG_T::n_out],
+                             const typename CONFIG_T::weight_t rec_kernel[CONFIG_T::n_out * CONFIG_T::n_out],
+                             const typename CONFIG_T::bias_t bias[CONFIG_T::n_out],
+                             const typename CONFIG_T::bias_t rec_bias[CONFIG_T::n_out]) {
+    // Weight multiplication
+    typename CONFIG_T::accum_t afterW[CONFIG_T::n_out] hls_register;
+    multiply_W<data_T, typename CONFIG_T::accum_t, typename CONFIG_T::weight_t, CONFIG_T::n_in, CONFIG_T::n_out>(
+        inputs, afterW, kernel);
+
+    // Bias addition
+    typename CONFIG_T::accum_t afterBias[CONFIG_T::n_out] hls_register;
+    add_bias<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, typename CONFIG_T::bias_t, CONFIG_T::n_out>(
+        afterW, afterBias, bias);
+
+    // Hidden state
+    typename CONFIG_T::accum_t hiddenCand[CONFIG_T::n_out] hls_register;
+    multiply_U<data_T, typename CONFIG_T::accum_t, typename CONFIG_T::weight_t, CONFIG_T::n_out>(hidden_state, hiddenCand,
+                                                                                                 rec_kernel);
+
+    // Hidden state bias addition
+    typename CONFIG_T::accum_t hiddenBias[CONFIG_T::n_out] hls_register;
+    add_bias<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, typename CONFIG_T::bias_t, CONFIG_T::n_out>(
+        hiddenCand, hiddenBias, rec_bias);
+
+    // Vector addition
+    typename CONFIG_T::accum_t afterAdd[CONFIG_T::n_out];
+    add_vectors<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, CONFIG_T::n_out>(afterBias, hiddenBias, afterAdd);
+
+    // Activation
+    CONFIG_T::template activation<typename CONFIG_T::accum_t, data_T, typename CONFIG_T::ACT_CONFIG_T>::activation(
+        afterAdd, hidden_state_o);
+}
+
+template <class data_T, class res_T, typename CONFIG_T>
+void simple_rnn_pytorch(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in],
+                        res_T res[CONFIG_T::n_outputs * CONFIG_T::n_out],
+                        const typename CONFIG_T::weight_t kernel[CONFIG_T::n_in * CONFIG_T::n_out],
+                        const typename CONFIG_T::weight_t rec_kernel[CONFIG_T::n_out * CONFIG_T::n_out],
+                        const typename CONFIG_T::bias_t bias[CONFIG_T::n_out],
+                        const typename CONFIG_T::bias_t rec_bias[CONFIG_T::n_out]) {
+    res_T hidden_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1] hls_register;
+    res_T hidden_state_temp[CONFIG_T::n_out] hls_register;
+    res_T h[CONFIG_T::n_out] hls_register;
+    data_T in[CONFIG_T::n_in] hls_register;
+
+    // Set the initial hidden state (output) to zero
+INIT_LOOP:
+    #pragma unroll
+    for (int x = 0; x < CONFIG_T::n_out; x++) {
+        hidden_state[x][0] = 0;
+    }
+
+    #pragma disable_loop_pipelining
+    for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
+
+        // Data at current time step
+        #pragma unroll
+        for (int x = 0; x < CONFIG_T::n_in; x++) {
+            in[x] = data[x + i * CONFIG_T::n_in];
+        }
+
+        // Hidden state at current time step
+        #pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            hidden_state_temp[x] = hidden_state[x][i];
+        }
+
+        // Do SimpleRNN
+        simple_rnn_pytorch_cell<data_T, res_T, CONFIG_T>(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias);
+
+        // Write result
+        #pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            hidden_state[x][i + 1] = h[x];
+        }
+    }
+
+    if (CONFIG_T::return_sequences == 0) {
+        // Output when return_sequences is false
+        #pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            res[x] = hidden_state[x][CONFIG_T::n_timesteps];
+        }
+    } else {
+        // Output when return_sequences is true
+        #pragma unroll
+        for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
+            #pragma unroll
+            for (int h = 0; h < CONFIG_T::n_out; h++) {
+                res[x * CONFIG_T::n_out + h] = hidden_state[h][x + 1];
+            }
+        }
+    }
+}

 // ----------------------
 // LSTM
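
Note (not part of the patch): a hypothetical call site for the new function, with invented names and sizes; in a real build the generated config would also supply ACT_CONFIG_T and fixed-point data types, which are omitted here. This sketches a 2-input, 4-unit SimpleRNN unrolled over 3 timesteps that returns only the final hidden state.

    // Hypothetical top-level wrapper (all names and sizes invented).
    struct my_rnn_config : nnet::simpleRNN_pytorch_config {
        static const unsigned n_in = 2;
        static const unsigned n_out = 4;
        static const unsigned n_outputs = 1;
        static const unsigned n_timesteps = 3;
        static const bool return_sequences = false;
    };

    void rnn_top(float data[3 * 2], float res[1 * 4], const float kernel[2 * 4],
                 const float rec_kernel[4 * 4], const float bias[4], const float rec_bias[4]) {
        nnet::simple_rnn_pytorch<float, float, my_rnn_config>(data, res, kernel, rec_kernel,
                                                              bias, rec_bias);
    }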