每次把当前全局最大值减 1、全局最小值加 1,做 k 次(如果中途已不能再缩小差值则提前结束),问此时最大值与最小值的差。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from collections import Counter
from math import ceil

def range_after_k(nums, k):
n = len(nums)
if n <= 1:
return 0
s = sum(nums)
lo_target = s // n
hi_target = (s + n - 1) // n # ceil(s/n)

# 最少需要的步数,若 k 足够大则直接给出最终差值
need = sum(max(0, a - hi_target) for a in nums)
if k >= need:
return 0 if s % n == 0 else 1

cnt = Counter(nums)
lo = min(cnt)
hi = max(cnt)

while k > 0 and lo < hi:
while lo < hi and cnt[lo] == 0:
lo += 1
while lo < hi and cnt[hi] == 0:
hi -= 1
if lo >= hi or hi - lo <= 1:
break

move = min(cnt[lo], cnt[hi], k)
# 从两端各“搬” move 个
cnt[lo] -= move
cnt[lo + 1] += move
cnt[hi] -= move
cnt[hi - 1] += move
k -= move

if cnt[lo] == 0:
lo += 1
if cnt[hi] == 0:
hi -= 1

return max(0, hi - lo)

复杂度

  • 构建计数 O(n),双指针整体移动次数不超过值域跨度 U,循环为 O(U + 批次数);批量搬运使得即使 k 很大也不会按步数线性运行。
  • 若值域很大且不便计数,可用有序多重集/平衡树(如 C++ multiset):每次弹出最小和最大各 1 个、修改后再插入,做 k 次,复杂度 O(k log n)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include <bits/stdc++.h>
using namespace std;

long long diff_after_k(vector<long long>& a, long long k) {
int n = (int)a.size();
if (n <= 1) return 0;
multiset<long long> ms(a.begin(), a.end());

while (k-- > 0) {
auto it_lo = ms.begin();
auto it_hi = prev(ms.end());
long long lo = *it_lo;
long long hi = *it_hi;
if (hi - lo <= 1) break;

// 移除一端各 1 个,再各向中间移动 1 后插回
ms.erase(it_lo);
it_hi = prev(ms.end()); // 先移除最小后再次获取最大迭代器
long long cur_hi = *it_hi;
ms.erase(it_hi);

ms.insert(lo + 1);
ms.insert(cur_hi - 1);
}

if (ms.empty()) return 0;
long long lo = *ms.begin();
long long hi = *prev(ms.end());
return max(0LL, hi - lo);
}

int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n; long long k;
if (!(cin >> n >> k)) return 0;
vector<long long> a(n);
for (int i = 0; i < n; ++i) cin >> a[i];
cout << diff_after_k(a, k) << "\n";
return 0;
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import heapq
from collections import Counter

def diff_after_k(nums, k):
n = len(nums)
if n <= 1:
return 0

# 大 k 提前判定:是否已能做到最终的 0/1 差
s = sum(nums)
ceil_mean = (s + n - 1) // n
need = sum(max(0, a - ceil_mean) for a in nums) # 需要的最少步数
if k >= need:
return 0 if s % n == 0 else 1

# 双堆 + 懒删除
cnt = Counter(nums)
min_heap = nums[:]
max_heap = [-x for x in nums]
heapq.heapify(min_heap)
heapq.heapify(max_heap)

def get_min_val():
while min_heap and cnt[min_heap[0]] == 0:
heapq.heappop(min_heap)
return min_heap[0] if min_heap else None

def get_max_val():
while max_heap and cnt[-max_heap[0]] == 0:
heapq.heappop(max_heap)
return -max_heap[0] if max_heap else None

while k > 0:
lo = get_min_val()
hi = get_max_val()
if lo is None or hi is None or hi - lo <= 1:
break

# 从两端各拿一个,向中间移动 1
cnt[lo] -= 1
cnt[hi] -= 1

lo1 = lo + 1
hi1 = hi - 1
cnt[lo1] += 1
cnt[hi1] += 1

heapq.heappush(min_heap, lo1)
heapq.heappush(min_heap, hi1)
heapq.heappush(max_heap, -lo1)
heapq.heappush(max_heap, -hi1)

k -= 1

lo = get_min_val()
hi = get_max_val()
if lo is None or hi is None:
return 0
return max(0, hi - lo)

# 示例
if __name__ == "__main__":
# 输入格式示例:n k\n a1 a2 ... an
import sys
data = sys.stdin.read().strip().split()
if data:
it = iter(data)
n = int(next(it)); k = int(next(it))
nums = [int(next(it)) for _ in range(n)]
print(diff_after_k(nums, k))