Minor formatting fixes for sort_and_normalize_weights

This commit is contained in:
Colin Basnett
2025-05-06 19:08:20 -07:00
parent 16471344f0
commit f6fa646a63

View File

@@ -115,22 +115,27 @@ class Psk(object):
def sort_and_normalize_weights(self):
self.weights.sort(key=lambda x: x.point_index)
weight_index = 0
weight_total = len(self.weights)
while weight_index < weight_total:
point_index = self.weights[weight_index].point_index
weight_sum = self.weights[weight_index].weight
point_weight_total = 1
point_index: int = self.weights[weight_index].point_index
weight_sum: float = self.weights[weight_index].weight
# Count the number of weights with contiguous point indices and sum the total weights.
for w in range(weight_index + 1, weight_total):
if point_index != self.weights[w].point_index:
# Calculate the sum of weights for the current point_index.
for i in range(weight_index + 1, weight_total):
if self.weights[i].point_index != point_index:
break
weight_sum += self.weights[i].weight
point_weight_total += 1
weight_sum += self.weights[w].weight
# Now normalize the weights against the sum of all weights.
for weight in self.weights[weight_index:weight_index+point_weight_total]:
weight.weight /= weight_sum
# Increment
# Normalize the weights for the current point_index.
for i in range(weight_index, weight_index + point_weight_total):
self.weights[i].weight /= weight_sum
# Move to the next group of weights.
weight_index += point_weight_total
def __init__(self):