思路
一种 vector<int>, vector<int>
的 map 做法,代码稍长一些但是特别好想。
先考虑 k = 1 k = 1 k = 1 时的做法,后面 k k k 增大时在原基础上改动一点就可以了。
看到左右括号,我们容易想到遇到左括号加 1 1 1 ,右括号减 1 1 1 ,记第到 i i i 个位置的前缀和为 s u m i sum_i s u m i 。显然,区间 [ l , r ] [l,r] [ l , r ] 若是合法的,必须满足以下条件:
左括号和右括号的数量相等。
在任何位置,前面的右括号个数不能超过左括号。
用数学的形式表达,就是:
s u m r − s u m l − 1 = 0 sum_r - sum_{l-1} = 0 s u m r − s u m l − 1 = 0 。
对于 [ l , r ] [l, r] [ l , r ] 之间的任意一点 i i i ,满足 s u m i − s u m l − 1 ≥ 0 sum_i - sum_{l - 1} \geq 0 s u m i − s u m l − 1 ≥ 0 。
移项,得:
s u m r = s u m l − 1 sum_r = sum_{l-1} s u m r = s u m l − 1 。
对于 [ l , r ] [l, r] [ l , r ] 之间的任意一点 i i i ,满足 s u m i ≥ s u m l − 1 sum_i \geq sum_{l - 1} s u m i ≥ s u m l − 1 。
我们分开考虑每个条件。
若假设 l l l 为当前计算区间的左端点,那么若存在一点 i i i 满足 s u m i < s u m l − 1 sum_i < sum_{l - 1} s u m i < s u m l − 1 且 i ≥ l i \geq l i ≥ l ,那么 i i i 以及后面的所有点和 l l l 所构成的区间都是不合法的。那么我们找到 l l l 右边第一个比 s u m l − 1 sum_{l - 1} s u m l − 1 小的位置 R R R (可以用单调栈预处理),那么 R R R 以前的位置都可能合法。条件二的限制就解决了。结合条件一,我们预处理每个 s u m sum s u m 值都出现过的坐标,用 vector 存起来排序,在 s u m l − 1 sum_{l - 1} s u m l − 1 查找 [ i , R − 1 ] [i, R - 1] [ i , R − 1 ] 之间有几个坐标,加到答案即可。
对于增加的 k k k ,R R R 取 k k k 个数里的最小值,由于 k ≤ 10 k \leq 10 k ≤ 1 0 ,我们直接把 k k k 个 s u m sum s u m 放到一个 vector 里并做成 map,每次用 vector 查就好了,自己想想就清楚了。
时间复杂度 O ( n k log n ) O(nk \log n) O ( n k log n ) 。
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 #include <bits/stdc++.h> #ifdef AT_contest #include <atcoder/all> using namespace atcoder;#endif using namespace std;#define fi first #define se second #define pb push_back #define sz(x) (x).size() #define all(x) (x).begin(), (x).end() #define pii pair<int, int> #define mpii map<int, int> #define vi vector<int> #define fr front #define bk back #define ls(x) (x << 1) #define rs(x) (x << 1 | 1) inline int read () { int x = 0 , f = 1 ; char ch = getchar (); while (ch < '0' || ch > '9' ) {if (ch == '-' ) f = -1 ; ch = getchar ();} while (ch >= '0' && ch <= '9' ) {x = x * 10 + ch - 48 ; ch = getchar ();} return x * f; } #define inf 0x7fffffff #define INF 0x3f3f3f3f3f3f3f3fll #if defined(int) #define RETURN_MAIN signed #endif #if !defined(int) #define RETURN_MAIN int #endif char s[12 ][50010 ];int sum[12 ][50010 ], rmax[12 ][50010 ];map<vi, vi> mp; RETURN_MAIN main () { int k = read (), n = read (); for (int i = 1 ; i <= k; i++) scanf ("%s" , s[i] + 1 ); for (int i = 1 ; i <= k; i++) for (int j = 1 ; j <= n; j++) { if (s[i][j] == '(' ) sum[i][j] = sum[i][j - 1 ] + 1 ; else sum[i][j] = sum[i][j - 1 ] - 1 ; } for (int i = 1 ; i <= k; i++) sum[i][n + 1 ] = -inf; for (int i = 1 ; i <= k; i++) { stack<int > st; st.push (n + 1 ); for (int j = n; j >= 0 ; j--) { while (!st.empty () && sum[i][st.top ()] >= sum[i][j]) st.pop (); rmax[i][j] = st.top (); st.push (j); } } for (int j = 0 ; j <= n; j++) { vi tmp; for (int i = 1 ; i <= k; i++) tmp.pb (sum[i][j]); mp[tmp].pb (j); } int ans = 0 ; for (int j = 1 ; j <= n; j++) { int R = inf; for (int i = 1 ; i <= k; i++) R = min (R, rmax[i][j - 1 ] - 1 ); vi tmp; for (int i = 1 ; i <= k; i++) tmp.pb (sum[i][j - 1 ]); if (R <= j) continue ; vi pos = mp[tmp]; int vl = lower_bound (all (pos), j) - pos.begin (), vr = upper_bound (all (pos), R) - pos.begin (); ans += vr - vl; } cout << ans; return 0 ; }