Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def format_coord(x, y):

def plot_fc_weights(self):
if not self.args.no_fc_weights:
num_buckets = self.model.feature_set.num_ls_buckets
num_buckets = self.model.num_ls_buckets
fig, axs = plt.subplots(3, num_buckets, dpi=self.dpi)

extra_info = ""
Expand All @@ -343,9 +343,7 @@ def get_l1_weights(bucket_id, l1):

for i in range(N):
l1_weights[2 * i] = l1_weights_[i][self.sorted_input_neurons]
l1_weights[2 * i + 1] = l1_weights_[i][
self.M + self.sorted_input_neurons
]
l1_weights[2 * i + 1] = l1_weights_[i+N][self.sorted_input_neurons]
return l1_weights, N

def get_l2_weights(bucket_id, l2):
Expand Down Expand Up @@ -376,6 +374,8 @@ def get_l2_weights(bucket_id, l2):
self.model.layer_stacks.get_coalesced_layer_stacks()
):
l1_weights, N = get_l1_weights(bucket_id, l1)
# truncate to prevent matshow rendering a blank plot
truncated_l1_weights = l1_weights[:, :16]
l2_weights = get_l2_weights(bucket_id, l2)
output_weights = output.weight.data.numpy()

Expand All @@ -384,7 +384,7 @@ def get_l2_weights(bucket_id, l2):

ax = axs[0, bucket_id]
im = ax.matshow(
np.abs(l1_weights) if plot_abs else l1_weights,
np.abs(truncated_l1_weights) if plot_abs else truncated_l1_weights,
vmin=vmin,
vmax=vmax,
cmap=cmap,
Expand All @@ -410,7 +410,7 @@ def get_l2_weights(bucket_id, l2):
)

row_names = ["bucket {}".format(i) for i in range(num_buckets)]
col_names = ["l1", "l2", "output"]
col_names = ["l1 (truncated)", "l2", "output"]
for i in range(3):
for j in range(num_buckets):
ax = axs[i, j]
Expand Down Expand Up @@ -482,7 +482,7 @@ def plot_fc_biases(self):
self.ref_model.layer_stacks.get_coalesced_layer_stacks()
)

num_buckets = self.model.feature_set.num_ls_buckets
num_buckets = self.model.num_ls_buckets
fig, axs = plt.subplots(3, num_buckets, dpi=self.dpi)
extra_info = ""
if self.args.sort_input_neurons:
Expand Down