def diff_after_k(nums, k):
    """Return the max-min gap of ``nums`` after at most ``k`` balancing moves.

    One move takes 1 from a current maximum and gives 1 to a current
    minimum.  Moves stop early once the gap is 0 or 1, because a further
    move could not shrink it.

    Args:
        nums: Sequence of integers.  It is not modified.
        k: Maximum number of moves to perform.  Values <= 0 mean no moves.

    Returns:
        The difference between the largest and smallest value after the
        moves.  Returns 0 when ``nums`` has fewer than two elements.
    """
    # Imported locally so the function does not rely on module-level imports
    # that are not visible in this part of the file.
    import heapq
    from collections import Counter

    n = len(nums)
    if n <= 1:
        return 0

    # Early exit for large k: check whether the fully balanced state (every
    # value equal to the floor or ceiling of the mean) is reachable.
    total = sum(nums)
    floor_mean, remainder = divmod(total, n)
    ceil_mean = floor_mean + (1 if remainder else 0)
    # Each move removes at most one unit of surplus above the ceiling and at
    # most one unit of deficit below the floor, so the number of moves needed
    # is the larger of the two.  Counting only the surplus under-estimates
    # inputs such as [0, 3, 3, 3].
    surplus = sum(a - ceil_mean for a in nums if a > ceil_mean)
    deficit = sum(floor_mean - a for a in nums if a < floor_mean)
    if k >= max(surplus, deficit):
        return 1 if remainder else 0

    # Two heaps with lazy deletion: ``cnt`` holds the live multiplicity of
    # each value; heap entries whose value has multiplicity 0 are stale and
    # are discarded when they surface at the top.
    cnt = Counter(nums)
    min_heap = list(nums)
    max_heap = [-x for x in nums]
    heapq.heapify(min_heap)
    heapq.heapify(max_heap)

    def get_min_val():
        """Return the smallest live value, or None if no value is live."""
        while min_heap and cnt[min_heap[0]] == 0:
            heapq.heappop(min_heap)
        return min_heap[0] if min_heap else None

    def get_max_val():
        """Return the largest live value, or None if no value is live."""
        while max_heap and cnt[-max_heap[0]] == 0:
            heapq.heappop(max_heap)
        return -max_heap[0] if max_heap else None

    while k > 0:
        lo = get_min_val()
        hi = get_max_val()
        if lo is None or hi is None or hi - lo <= 1:
            break

        # Take one element from each end and move it 1 toward the middle.
        cnt[lo] -= 1
        cnt[hi] -= 1
        raised = lo + 1
        lowered = hi - 1
        cnt[raised] += 1
        cnt[lowered] += 1

        # The shifted values must be visible to both heaps; either of them
        # can later become the global minimum or maximum.
        for value in (raised, lowered):
            heapq.heappush(min_heap, value)
            heapq.heappush(max_heap, -value)

        k -= 1  # one move consumed

    lo = get_min_val()
    hi = get_max_val()
    if lo is None or hi is None:
        return 0
    return max(0, hi - lo)
# Example command-line driver.
if __name__ == "__main__":
    # Expected stdin layout: "n k" followed by the n values a1 a2 ... an,
    # separated by any whitespace.
    import sys

    tokens = sys.stdin.read().split()
    if tokens:
        stream = iter(tokens)
        n = int(next(stream))
        k = int(next(stream))
        nums = []
        for _ in range(n):
            nums.append(int(next(stream)))
        print(diff_after_k(nums, k))