#include <bits/stdc++.h>
using namespace std;

// Fast integer reader for stdin: skips every non-digit character (a '-' seen
// while skipping makes the result negative), then accumulates decimal digits
// into `a`. `ch` is an int so that EOF is distinguishable from real characters;
// on EOF the skip loop stops instead of calling getchar() forever.
template <typename T> inline void read(T &a) {
    T w = 1;
    a = 0;
    int ch = getchar();
    for (; ch != EOF && (ch < '0' || ch > '9'); ch = getchar())
        if (ch == '-') w = -1;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        a = (a * 10) + (ch - '0');
    a *= w;
}

// From here on every `int` is a 64-bit long long (which is why the entry point
// is declared `signed main`).
#define int long long
template <typename T> inline void ckmax(T &a, T b) { a = a > b ? a : b; }
template <typename T> inline void ckmin(T &a, T b) { a = a < b ? a : b; }
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define mii map<int, int>
#define pii pair<int, int>
#define vi vector<int>
#define si set<int>
#define ins insert
#define LL long long
#define era erase
#define Debug(x) cout << #x << " = " << x << endl
#define For(i,l,r) for (int i = l; i <= r; ++i)
#define foR(i,l,r) for (int i = l; i >= r; --i)

const int N = 5e4 + 10;

// Adjacency lists in forward-star form: head[u] is the index of u's first edge,
// e[i].nxt the next one; index 0 terminates a list because cnt is
// pre-incremented and therefore starts at 1.
// NOTE(review): e[] holds 2N-1 edge records, i.e. roughly N undirected edges
// when each is added in both directions -- confirm against the input limits.
struct edge {
    int to, nxt;
    LL dis;
} e[N << 1];
int head[N], cnt;

// Appends the directed edge u -> v with weight w.
inline void add(int u, int v, LL w) {
    e[++cnt].to = v;
    e[cnt].nxt = head[u];
    head[u] = cnt;
    e[cnt].dis = w;
}

int n, m, S, T;

// Priority-queue entry for Dijkstra: vertex id and its tentative distance.
// operator< is inverted so std::priority_queue<Node> pops the smallest dis first.
struct Node {
    int id;
    LL dis;
    Node(int Id = 0, LL Dis = 0) { id = Id, dis = Dis; }
    bool operator<(const Node &a) const { return dis > a.dis; }
};
LL dis[2][N], f[2][N]; bool vis[N]; inline void Dij (int op) { priority_queue <Node> q; memset (dis[op], 0x3f, sizeof dis[op]); memset (vis, 0, sizeof vis); if (op == 0) q.push(Node(S, 0)), dis[op][S] = 0, f[op][S] = 1; else q.push(Node(T, 0)), dis[op][T] = 0, f[op][T] = 1; while (!q.empty()) { Node tmp = q.top(); q.pop(); if (vis[tmp.id]) continue; vis[tmp.id] = 1; int u = tmp.id; for (int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if (dis[op][v] == dis[op][u] + e[i].dis) f[op][v] += f[op][u]; if (dis[op][v] > dis[op][u] + e[i].dis) { dis[op][v] = dis[op][u] + e[i].dis; f[op][v] = f[op][u]; q.push(Node(v, dis[op][v])); } } } }
int shortest, F[N]; map<LL, bitset <N>> t; bitset <N> g[2][N]; bool InShort (int u) { return (dis[0][u] + dis[1][u] == shortest); } int du[N];
void Topu (int op) { static queue <int> q; while (!q.empty()) q.pop();
For (u, 1, n) for (int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if (dis[op][u] + e[i].dis + dis[op ^ 1][v] == shortest) du[v]++; } For (i, 1, n) { g[op][i].set(), g[op][i][0] = g[op][i][i] = 0; if (!du[i]) q.push(i); }
while (!q.empty()) { int u = q.front(); q.pop(); for (int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if (dis[op][u] + e[i].dis + dis[op ^ 1][v] == shortest) { du[v]--; g[op][v] &= g[op][u]; if (!du[v]) q.push(v); } } }
}
signed main() { read(n), read(m), read(S), read(T); For (i, 1, m) { int u, v; LL w; read(u), read(v), read(w); add(u, v, w), add(v, u, w); } Dij(0), Dij(1); shortest = dis[0][T]; if (!f[0][T]) return printf ("%lld\n", 1ll * n * (n - 1) / 2), 0; For (i, 1, n) {if (InShort(i)) { F[i] = f[0][i] * f[1][i]; } t[F[i]].set(i); } Topu(0), Topu(1); LL Ans = 0; For (i, 1, n) Ans += (t[F[T] - F[i]] & g[0][i] & g[1][i]).count(); printf ("%lld\n", Ans >> 1); }
|