diff --git a/visualize.py b/visualize.py index bd6f488e..40694b0e 100644 --- a/visualize.py +++ b/visualize.py @@ -317,7 +317,7 @@ def format_coord(x, y): def plot_fc_weights(self): if not self.args.no_fc_weights: - num_buckets = self.model.feature_set.num_ls_buckets + num_buckets = self.model.num_ls_buckets fig, axs = plt.subplots(3, num_buckets, dpi=self.dpi) extra_info = "" @@ -343,9 +343,7 @@ def get_l1_weights(bucket_id, l1): for i in range(N): l1_weights[2 * i] = l1_weights_[i][self.sorted_input_neurons] - l1_weights[2 * i + 1] = l1_weights_[i][ - self.M + self.sorted_input_neurons - ] + l1_weights[2 * i + 1] = l1_weights_[i+N][self.sorted_input_neurons] return l1_weights, N def get_l2_weights(bucket_id, l2): @@ -376,6 +374,8 @@ def get_l2_weights(bucket_id, l2): self.model.layer_stacks.get_coalesced_layer_stacks() ): l1_weights, N = get_l1_weights(bucket_id, l1) + # truncate to prevent matshow rendering a blank plot + truncated_l1_weights = l1_weights[:, :16] l2_weights = get_l2_weights(bucket_id, l2) output_weights = output.weight.data.numpy() @@ -384,7 +384,7 @@ def get_l2_weights(bucket_id, l2): ax = axs[0, bucket_id] im = ax.matshow( - np.abs(l1_weights) if plot_abs else l1_weights, + np.abs(truncated_l1_weights) if plot_abs else truncated_l1_weights, vmin=vmin, vmax=vmax, cmap=cmap, @@ -410,7 +410,7 @@ def get_l2_weights(bucket_id, l2): ) row_names = ["bucket {}".format(i) for i in range(num_buckets)] - col_names = ["l1", "l2", "output"] + col_names = ["l1 (truncated)", "l2", "output"] for i in range(3): for j in range(num_buckets): ax = axs[i, j] @@ -482,7 +482,7 @@ def plot_fc_biases(self): self.ref_model.layer_stacks.get_coalesced_layer_stacks() ) - num_buckets = self.model.feature_set.num_ls_buckets + num_buckets = self.model.num_ls_buckets fig, axs = plt.subplots(3, num_buckets, dpi=self.dpi) extra_info = "" if self.args.sort_input_neurons: