]> git.wh0rd.org - home.git/blob - .bin/git-rb-catchup
git-rb-catchup: handle running in non-git dirs better
[home.git] / .bin / git-rb-catchup
1 #!/usr/bin/env python3
2 # Distributed under the terms of the GNU General Public License v2 or later.
3
4 """Helper to automatically rebase onto latest commit possible.
5
6 Helpful when you have a branch tracking an old commit, and a lot of conflicting
7 changes have landed in the latest branch, but you still want to update.
8
9 A single rebase to the latest commit will require addressing all the different
10 changes at once which can be difficult, overwhelming, and error-prone. Instead,
11 if you rebased onto each intermediate conflicting point, you'd break up the work
12 into smaller pieces, and be able to run tests to make sure things were still OK.
13 """
14
15 import argparse
16 import os
17 from pathlib import Path
18 import subprocess
19 import sys
20 from typing import List, Tuple, Union
21
22
# The f-strings and subprocess.run(capture_output=...) used below need a modern
# interpreter; fail early with a clear message instead of a confusing error.
assert sys.version_info >= (3, 7), f'Need Python 3.7+, not {sys.version_info}'
24
25
def git(args: List[str], **kwargs) -> subprocess.CompletedProcess:
    """Invoke git with |args| and return the completed process.

    Args:
      args: Arguments placed after the `git` program name.
      **kwargs: Passed through to subprocess.run().  Unless the caller says
        otherwise, failures raise (check=True), stdout/stderr are captured,
        and output is decoded as UTF-8.

    Returns:
      The subprocess.CompletedProcess for the git invocation.
    """
    # Caller supplied settings come last so they override the defaults.
    options = {
        'check': True,
        'capture_output': True,
        'encoding': 'utf-8',
        **kwargs,
    }
    command = ['git'] + args
    # pylint: disable=subprocess-run-check
    return subprocess.run(command, **options)
33
34
def rebase(target: str) -> bool:
    """Try to rebase the current branch onto |target|.

    Args:
      target: Commit-ish handed to `git rebase`.

    Returns:
      True if the rebase completed.  False if git reported a failure, in which
      case the partially applied rebase is aborted before returning.

    Raises:
      SystemExit: If the user interrupts the rebase (Ctrl-C); the rebase is
        aborted on a best-effort basis first.
      subprocess.CalledProcessError: If aborting a failed rebase itself fails.
    """
    try:
        git(['rebase', target])
    except KeyboardInterrupt:
        # The interrupt can arrive before git has set up (or after it has torn
        # down) the rebase state, in which case --abort fails.  Don't let that
        # turn a clean user abort into a traceback.
        git(['rebase', '--abort'], check=False)
        print('aborted')
        sys.exit(1)
    except subprocess.CalledProcessError:
        # Only a failing git command means "this target conflicts"; anything
        # else (missing git binary, programming errors) should propagate.
        git(['rebase', '--abort'])
        return False
    return True
47
48
def rebase_bisect(lbranch: str,
                  rbranch: str,
                  behind: int,
                  leave_rebase: bool = False,
                  force_checkout: bool = False):
    """Bisect towards the newest commit of |rbranch| we can rebase onto.

    Args:
      lbranch: The local branch being rebased.
      rbranch: The branch to catch up with; candidates are `|rbranch|~N`.
      behind: How many commits |lbranch| is behind |rbranch|.
      leave_rebase: Once the search is over, rerun the rebase onto the failing
        target and leave its state in the tree for manual resolution.
      force_checkout: Check out the target and then |lbranch| again (with -f)
        before every rebase attempt.
    """
    def tick(text: str) -> None:
        """Emit a progress marker immediately, without a newline."""
        print(text, end='', flush=True)

    def attempt(pos: int) -> bool:
        """Try rebasing onto the commit |pos| steps back from |rbranch|."""
        target = f'{rbranch}~{pos}'
        print(f'Rebasing onto {target} ', end='')
        tick('.')
        # Checking out these branches directly helps clobber orphaned files,
        # but is usually unnecessary, and can slow down the overall process.
        if force_checkout:
            git(['checkout', '-f', target])
        tick('.')
        if force_checkout:
            git(['checkout', '-f', lbranch])
        tick('. ')
        succeeded = rebase(target)
        print('OK' if succeeded else 'failed')
        return succeeded

    # Offsets count backwards from the tip of |rbranch|: |newest| starts at the
    # tip itself while |oldest| starts at the commit we are currently based on.
    newest = 0
    oldest = behind
    previous = None
    first_fail = 0
    while True:
        probe = newest + (oldest - newest) // 2
        still_searching = probe != previous and newest <= probe < oldest
        if not still_searching:
            break
        if attempt(probe):
            # Success: only commits newer than the probe are left to try.
            oldest = probe
        else:
            # Failure: remember it and retreat towards our current base.
            first_fail = max(first_fail, probe)
            newest = probe
        previous = probe

    if not (newest or oldest):
        print('All caught up!')
        return

    last_target = f'{rbranch}~{first_fail}'
    if not leave_rebase:
        print('Found first failure', last_target)
        return

    print('Restarting', last_target)
    # This rebase is allowed to fail: the point is to leave it in the tree.
    result = git(['rebase', last_target], check=False)
    print(result.stdout.strip())
97
98
def get_ahead_behind(lbranch: str, rbranch: str) -> Tuple[int, int]:
    """Return number of commits |lbranch| is ahead & behind relative to |rbranch|.

    Only the first-parent history is counted (`git rev-list --first-parent`).

    Args:
      lbranch: The local branch.
      rbranch: The branch to compare against.

    Returns:
      An (ahead, behind) tuple of commit counts.

    Raises:
      subprocess.CalledProcessError: If git fails, e.g. for an unknown ref.
      ValueError: If git's output is not exactly two integers.
    """
    output = git(
        ['rev-list', '--first-parent', '--left-right', '--count',
         f'{lbranch}...{rbranch}']).stdout
    # With --left-right --count the output is "<left> <right>", i.e. the
    # commits only in |lbranch| followed by the commits only in |rbranch|.
    ahead, behind = (int(x) for x in output.split())
    return ahead, behind
105
106
def get_tracking_branch(branch: str) -> Union[str, None]:
    """Return branch that |branch| is tracking.

    Args:
      branch: Name of the local branch to look up.

    Returns:
      `<remote>/<name>` for a branch tracking a remote, the configured merge
      ref itself for a branch tracking another local branch, or None when no
      upstream is configured.

    Raises:
      subprocess.CalledProcessError: If `git config` fails for any reason other
        than the key being unset.
    """
    def config(key: str) -> str:
        """Return the local config value for |key|, or '' if it is unset."""
        result = git(['config', '--local', key], check=False)
        # `git config` exits 1 when the key does not exist, which is an
        # expected state for branches without an upstream.  Anything else is a
        # real error and should surface.
        if result.returncode not in (0, 1):
            result.check_returncode()
        return result.stdout.strip()

    merge = config(f'branch.{branch}.merge')
    if not merge:
        return None

    remote = config(f'branch.{branch}.remote')
    # A remote of "." means the upstream lives in this repository, so the merge
    # ref is already usable as-is; prefixing it would produce a bogus "./name".
    if not remote or remote == '.':
        return merge

    prefix = 'refs/heads/'
    if merge.startswith(prefix):
        merge = merge[len(prefix):]
    return f'{remote}/{merge}'
120
121
def get_local_branch() -> str:
    """Return the name of the branch currently checked out.

    This is whatever `git branch --show-current` prints with surrounding
    whitespace removed, so it is empty when git prints nothing.
    """
    result = git(['branch', '--show-current'])
    name = result.stdout.strip()
    return name
125
126
def get_parser() -> argparse.ArgumentParser:
    """Build the command line parser.

    Returns:
      A parser with three boolean switches (stored as `initial_rebase`,
      `leave_rebase` and `force_checkout`) and one optional positional
      `branch` argument.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    # (flag, dest, action, default, help) -- registered in this order.
    switches = (
        ('--skip-initial-rebase-latest', 'initial_rebase', 'store_false', True,
         'skip initial rebase attempt onto the latest branch'),
        ('--leave-at-last-failed-rebase', 'leave_rebase', 'store_true', False,
         'leave tree state at last failing rebase'),
        ('--checkout-before-rebase', 'force_checkout', 'store_true', False,
         'force checkout before rebasing to target (to cleanup orphans)'),
    )
    for flag, dest, action, default, help_text in switches:
        parser.add_argument(
            flag, dest=dest, action=action, default=default, help=help_text)

    parser.add_argument('branch', nargs='?', help='branch to rebase onto')
    return parser
148
149
def main(argv: List[str]) -> int:
    """The main entry point for scripts.

    Resolves the local branch and the branch to catch up with, then either
    reports that nothing needs doing, fast forwards, or rebases (falling back
    to a bisecting rebase when rebasing onto the latest commit fails).

    Args:
      argv: Command line arguments, without the program name.

    Returns:
      The process exit status: 0 on success, 1 if the local or tracking branch
      cannot be resolved.
    """
    parser = get_parser()
    opts = parser.parse_args(argv)

    try:
        lbranch = get_local_branch()
    except subprocess.CalledProcessError as e:
        # Most commonly: not inside a git checkout.  Show git's own complaint.
        sys.exit(f'{os.path.basename(sys.argv[0])}: {Path.cwd()}:\n{e}\n{e.stderr.strip()}')
    print(f'Local branch resolved to "{lbranch}".')
    if not lbranch:
        print('Unable to resolve local branch', file=sys.stderr)
        return 1

    if opts.branch:
        rbranch = opts.branch
    else:
        try:
            rbranch = get_tracking_branch(lbranch)
        except subprocess.CalledProcessError:
            # `git config` exits non-zero for keys that are not set, so treat a
            # failed lookup the same as "no upstream configured".
            rbranch = None
        if not rbranch:
            print(f'Unable to resolve tracking branch for "{lbranch}"; '
                  'specify the branch to rebase onto',
                  file=sys.stderr)
            return 1
    print(f'Tracking branch resolved to "{rbranch}".')

    ahead, behind = get_ahead_behind(lbranch, rbranch)
    print(f'Branch is {ahead} commits ahead and {behind} commits behind.')
    print('NB: Counts for the first parent in merge history, not all commits.')

    if not behind:
        print('Up-to-date!')
    elif not ahead:
        print('Fast forwarding ...')
        # Name the target explicitly: a bare `git merge` always uses the
        # configured upstream, which is wrong when a branch was given on the
        # command line.
        git(['merge', rbranch])
    else:
        if opts.initial_rebase:
            print(f'Trying to rebase onto latest {rbranch} ... ',
                  end='', flush=True)
            if rebase(rbranch):
                print('OK!')
                return 0
            print('failed; falling back to bisect')
        rebase_bisect(lbranch, rbranch, behind, leave_rebase=opts.leave_rebase,
                      force_checkout=opts.force_checkout)

    return 0
191
192
if __name__ == '__main__':
    # Run only when executed as a script; main()'s return value becomes the
    # process exit status.
    sys.exit(main(sys.argv[1:]))