Rolling Hash
給 A, B 兩字串,
詢問每個長度為 A, B 字串長度公因數的子字串
是否可以藉由複製自己多次,重組為 A, B
且是否為 A, B 的共同子字串。
用 時間枚舉每個公因數長度子字串, 檢查是否合法。
總體複雜度
利用 Rolling Hash 性質:
假設原字串 Hash 為 ,字串長為 ,那麼
其中, 為 hash 使用的基底, 為 hash 使用的模數 為字串 A 重複 2 遍的 hash 結果。
可以利用上面的性質和快速冪的概念,
在 時間內取得 A 重複 N 遍的 hash 值。
實作方式可以參考程式第 72
行。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
// CF182D - 37
// Submmit ID: 28497065
// Time: 30 ms
#pragma optimize ("O3")
#pragma target ("avx")
#include <bits/stdc++.h>
#define UT long long int
#define PII std::pair< UT, UT >
struct RollingHash {
#define MAXN 100010
#define PB push_back
#define F first
#define S second
static const PII p, q;
PII h[MAXN], cache[MAXN];
inline int mulmod (int a, int b, int m) {
int d, r;
if(a==0 or b==0) return 0;
if(a==1 or b==1) return (a>b?a:b);
__asm__("mull %4;"
"divl %2"
: "=d" (r), "=a"(d)
: "r"(m), "a"(a), "d"(b));
return r;
}
inline void get_pow(const int n) {
const PII &base = p;
const PII &P = q;
PII *h = this->cache;
h[0].F=h[0].S=1;
for(int i=1; i<=n; ++i) {
h[i].F = this->mulmod(h[i-1].F, base.F, P.F);
h[i].S = this->mulmod(h[i-1].S, base.S, P.S);
}
}
inline void get_hash(std::string &s) {
// rev 代表是否要反向建 Hash
PII *h = this->h;
h[0].F=h[0].S=h[s.length()+1].S=h[s.length()+1].F=0;
++h; // index is shifted
for(int i=0; i!=s.length(); ++i ){
h[i].F = (this->mulmod(h[i-1].F,p.F,q.F)+(UT)(s[i]-'a')+1LL)%q.F;
h[i].S = (this->mulmod(h[i-1].S,p.S,q.S)+(UT)(s[i]-'a')+1LL)%q.S;
}
}
// pw 代表 在 get_pow() 得到的 hash 值
// 下面的函數用來取得 從原字串 i 位置開始長度 n 子字串的 hash 值
std::pair<UT,UT> partial_hash(int i, int n) {
const PII &pw = cache[n];
PII *h = this->h;
++h; //shift index
UT temp1, temp2;//Lazy dog...
temp1 = ((q.F - this->mulmod(h[i-1].F,pw.F,q.F)) + h[i+n-1].F)%q.F;
temp2 = ((q.S - this->mulmod(h[i-1].S,pw.S,q.S)) + h[i+n-1].S)%q.S;
return std::make_pair(temp1, temp2);
}
std::pair<UT,UT> double_self(int n, std::pair<UT,UT> target) {
// string: *A -> *2A
// hash : (p^n+1)*A, where n is length of A
//std::pair<UT,UT> target = this->partial_hash(i, n);
target.F = (target.F + this->mulmod(target.F,this->cache[n].F,q.F))%q.F;
target.S = (target.S + this->mulmod(target.S,this->cache[n].S,q.S))%q.S;
return target;
}
std::pair<UT,UT> extend_self(int i, int n, int t) {
std::pair<UT,UT> base = this->partial_hash(i,n);
std::pair<UT,UT> res(0,0); // 一次
int crr_times = 0;
int base_times = 1;
while(t) {
if (t&1) {
res.F = this->mulmod(res.F,cache[base_times*n].F,q.F);
res.S = this->mulmod(res.S,cache[base_times*n].S,q.S);
res.F = (res.F+base.F)%q.F;
res.S = (res.S+base.S)%q.S;
crr_times += base_times;
}
base = this->double_self(n*base_times, base); // A -> AA -> AAAA
t>>=1; base_times<<=1; // base length
}
return res;
}
//#undef PB
//#undef F
//#undef S
#undef MAXN
};
const std::pair<UT,UT> RollingHash::p(311, 337);
const std::pair<UT,UT> RollingHash::q(10000103, 10000121);
//#undef UT
//#undef PII
using namespace std;
RollingHash ah, bh;
char buff[100010];
int main(void) {
fgets(buff,100005,stdin);
strtok(buff,"\n");
string a(buff);
fgets(buff,100005,stdin);
strtok(buff,"\n");
string b(buff);
ah.get_pow((int)a.length());
bh.get_pow((int)b.length());
ah.get_hash(a);
bh.get_hash(b);
vector<int> cd;
int la = a.length(), lb = b.length();
int lmin = min(la, lb);
for (int i=1; i<=lmin; ++i) {
if (la % i == 0 && lb % i == 0) // 公因數
cd.push_back(i);
}
int counter = 0;
pair<UT,UT> ta = ah.partial_hash(0, a.length());
pair<UT,UT> tb = bh.partial_hash(0, b.length());
for (vector<int>::iterator v=cd.begin(); v!=cd.end(); ++v) {
const int &i = *v;
pair<UT,UT> ppa = ah.extend_self(0, i, a.length()/i);
pair<UT,UT> ppb = bh.extend_self(0, i, b.length()/i);
pair<UT,UT> tta = ah.partial_hash(0, i);
pair<UT,UT> ttb = bh.partial_hash(0, i);
if (ta==ppa && tb==ppb && tta==ttb) ++counter;
}
printf("%d\n", counter);
return 0;
}