今天读某公众号推送的一篇文章,题目是:求无序数组中的中位数.最简单的方法莫非就是排序,然后直接print中位数.
《算法导论》第9章 中位数和顺序统计量 正好讲到了这个问题,可以在不排序的情况下求解,并且时间复杂度接近O(n)
概念
找出数组中的最大值、最小值和中位数问题都可以一般化为选择问题:从一个由n个互异的元素构成的集合中选择第i个顺序统计量问题
第i个顺序统计量指集合中第i小的元素,所以:
- 最小值是第1个顺序统计量(i=1)
- 最大值是第n个顺序统计量(i=n)
- 当n为奇数时,中位数的i=(n+1)/2;当n为偶数时,中位数的i=n/2和n/2+1
期望为线性时间的选择算法
该算法以快速排序算法为原型,采用分治法,基本思路是:任意选择一个元素作为key,基于key将数组分为两部分.左部分元素均小于等于key,右部分元素均大于key.如果key的下标idx正好等于(n+1)/2,那么key即为中位数.否则若idx<(n+1)/2,那么递归去处理右部分,反之处理左部分
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from random import randint
def randomized_partition(A, low, high):
rand_n = randint(low, high)
key = A[rand_n]
A[rand_n] = A[high]
A[high] = key
i = low - 1
j = low
tmp = 0
while j < high:
if A[j] < key:
i += 1
tmp = A[i]
A[i] = A[j]
A[j] = tmp
j += 1
A[high] = A[i + 1]
A[i + 1] = key
return i + 1
def randomized_select(A, low, high, i):
if low == high:
return A[low]
q = randomized_partition(A, low, high)
k = q-low+1
if i == k:
return A[q]
elif i < k:
return randomized_select(A, low, q-1, i)
else:
return randomized_select(A, q+1, high, i-k)
if __name__ == '__main__':
A = [4,2,3,1,7]
# i取4,表示求解A中的中位数
print randomized_select(A, 0, len(A)-1, 4)
不同于快速排序会递归处理划分的两边,而randomized_select只处理换分的一边.经证明randomized_select的期望运行时间为O(n)
最坏情况为线性时间的选择算法
和randomized_select一样,select算法也是通过递归划分来寻找所需元素.但是该算法能保证得到对数组的一个好的划分.根据《算法导论》中描述,select算法的步骤为:
- 1.将n个元素的输入数组划分为[n/5]组,每组5个元素,且至多只有一组由剩下的n mod 5个元素组成
- 2.寻找[n/5]组中每一组的中位数:首先对每组元素进行插入排序,然后确定每一组的有序元素的中位数
- 3.对第2步中找出的[n/5]个中位数,递归调用select以找出其中位数x(如果有偶数个中位数,为了方便,取较小那个中位数)
- 4.利用修改过的partition,按中位数的中位数x对输入数组进行划分,确定x在数组中的位置k
- 5.如果i==k,返回x.否则,i<k,处理低区.反之在高区寻找i-k小的元素
#!/usr/bin/env python
# -*- coding: utf-8 -*-
def partition(A, low, high, key):
idx = 0
for i in xrange(low, high):
if A[i] == key:
idx = i
break
swap(A, idx, high)
i = low - 1
j = low
while j < high:
if A[j] < key:
i += 1
swap(A, i, j)
j += 1
swap(A, i+1, high)
return i + 1
def insert_sort(A, low, high):
i = low + 1
while i <= high:
key = A[i]
k = i - 1
while k >= low and A[k] > key:
A[k + 1] = A[k]
k -= 1
A[k + 1] = key
i += 1
def swap(A, a, b):
tmp = A[a]
A[a] = A[b]
A[b] = tmp
def select(A, low, high, i):
if high-low < 5:
insert_sort(A, low, high)
return A[low + i - 1]
group = (high - low + 5) / 5
for j in xrange(group):
left = low + j*5
right = (low + j*5 + 4) if (low + j*5 + 4) < high else high
mid = (left + right)/2
insert_sort(A, left, right)
swap(A, low+j, mid)
key = select(A, low, low+group-1, (group+1)/2)
key_idx = partition(A, low, high, key)
k = key_idx - low + 1
if k == i:
return A[key_idx]
elif k > i:
return select(A, low, key_idx-1, i)
else:
return select(A, key_idx+1, high, i-k)
if __name__ == '__main__':
A = [32,23,12,67,45,78,10,39,9,58]
for i in xrange(1, 11):
print select(A, 0, len(A)-1, i)
脑洞大开:利用最小堆
(依据待字闺中微信公众号推送的文章)
首先,将数组的前(n+1)/2个元素建立一个最小堆.然后对于下一个元素,和堆顶元素比较,如果小于等于就丢弃之.接着看下一个元素,如果大于,则用该元素取代该顶,再调整堆.重复直至数组为空时,堆顶元素即为中位数
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import heapq
def heap_select(A, lens):
h = []
for j in xrange((lens+1)/2):
heapq.heappush(h, A[j])
top = 0
for j in xrange((lens+1)/2, lens):
top = heapq.heappop(h)
if A[j] <= top:
heapq.heappush(h, top)
continue
else:
heapq.heappush(h, A[j])
return heapq.heappop(h)
if __name__ == '__main__':
A = [4,5,1,3,2]
print heap_select(A, len(A))
参考:
- 《算法导论》Chapter 9
- 微信公众号:待字闺中