K优解

Time Limit: 20 Sec Memory Limit: 512 MB

Description

给定n个行数,每行m个。在每行中选出一个数来,求出前 k 小的异或和。

Input

第一行 3 个正整数 n,m,k。

接下来 n 行,每行 m 个非负整数,第 i 行第 j 个为权值a[i][j]。

Output

一行一个数表示答案。

Sample Input

3 2 2
  11 21
  9 25
  17 19

Sample Output

2

HINT

n*m<=300000,k<=300000,保证m^n>=k,a[i][j]均不超过10^9

Solution

先对于每个 i,将每行的 a[i][1]~a[i][m] 从小到大排序,再将按照其元素差值多关键字排序(共m-1个关键字)。

那么我们知道,最小的方案肯定是所有行都取第一个。由于其有一些特殊,我们先抛开这个方案。
  我们知道,次小的方案是**(2,1,1,1…),把这个状态加入堆,由较优方案扩展较劣方案**,对于每一个状态,我们记录其扩展到第几行,以及取第几个元素

已经得到前 k 优的方案时,当前所有方案中还未扩展的最好的方案x(其最后扩展位置为 i),就是第 k+1 优

从方案x,我们可以扩展出几个较劣解

1、x 的第 i 个元素不取m:将 i 行取的元素增加1(扩展位置为 i

2、i + 1 <= n:将 i+1 行取为2(扩展位置为 i+1

3、x 的第 i 个元素取为2i + 1 <= n:将 i 行取为1,i+1 行取为2(扩展位置为 i+1

由此,每个解都可由唯一的优于它的解扩展得来。

用个维护一下,每次取出最小的即可。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include<bits/stdc++.h>
using namespace std;
typedef long long s64;

const int ONE = 300005;
const int MOD = 1e9 + 7;

int n, m, k;
vector <int> A[ONE];
int id[ONE];
s64 Ans;

struct power
{
s64 val;
int pt, id;
bool operator <(power a) const
{
return a.val < val;
}
};
priority_queue <power> q;

int cmp(int a, int b)
{
for(int i = 1; i < m; i++)
{
if(A[a][i + 1] - A[a][i] < A[b][i + 1] - A[b][i]) return 1;
if(A[a][i + 1] - A[a][i] > A[b][i + 1] - A[b][i]) return 0;
}
return 0;
}

int get()
{
int res=1,Q=1; char c;
while( (c=getchar())<48 || c>57)
if(c=='-')Q=-1;
if(Q) res=c-48;
while((c=getchar())>=48 && c<=57)
res=res*10+c-48;
return res*Q;
}

int main()
{
n = get(); m = get(); k = get();
for(int i = 1; i <= n; i++)
{
A[i].push_back(0);
for(int j = 1; j <= m; j++)
A[i].push_back(get());
sort(A[i].begin(), A[i].end());
id[i] = i;
}

sort(id + 1, id + n + 1, cmp);

s64 res = 0;
for(int i = 1; i <= n; i++) res += A[i][1];
Ans = res;

q.push((power){res - A[id[1]][1] + A[id[1]][2], 1, 2});

for(int i = 2; i <= k; i++)
{
power u = q.top(); q.pop();
Ans ^= u.val;

if(u.id + 1 <= m)
q.push((power){u.val - A[id[u.pt]][u.id] + A[id[u.pt]][u.id + 1], u.pt, u.id + 1});
if(u.pt + 1 <= n && 2 <= m)
q.push((power){u.val - A[id[u.pt + 1]][1] + A[id[u.pt + 1]][2], u.pt + 1, 2});
if(u.pt + 1 <= n && u.id == 2)
q.push((power){u.val - A[id[u.pt]][2] + A[id[u.pt]][1] - A[id[u.pt + 1]][1] + A[id[u.pt + 1]][2], u.pt + 1, 2});
}

printf("%lld", Ans);
}