git.wh0rd.org Git - home.git/blob - .bin/git-rb-catchup
git-rb-catchup: add option to automatically stop at last failing rebase
[home.git] / .bin / git-rb-catchup
1 #!/usr/bin/env python3
2
3 """Helper to automatically rebase onto latest commit possible."""
4
5 import argparse
6 import subprocess
7 import sys
8
9
def git(args, **kwargs):
    """Invoke git with |args| and return the CompletedProcess.

    Unless the caller overrides them via |kwargs|, a non-zero exit raises
    CalledProcessError and stdout/stderr are captured as UTF-8 text.
    """
    options = {
        'check': True,
        'capture_output': True,
        'encoding': 'utf-8',
    }
    # Caller-supplied keyword arguments win over the defaults above.
    options.update(kwargs)
    return subprocess.run(['git'] + args, **options)
16
17
def rebase(target):
    """Try to rebase the current branch onto |target|.

    Returns:
      True if the rebase completed.  False if git reported a failure; any
      in-progress rebase is aborted first so the branch is left as it was.

    Ctrl-C aborts the rebase, prints 'aborted', and exits with status 1.
    """
    try:
        git(['rebase', target])
    except KeyboardInterrupt:
        # The interrupt may arrive before git has started a rebase, so an
        # abort failure must not mask the exit.
        git(['rebase', '--abort'], check=False)
        print('aborted')
        sys.exit(1)
    except subprocess.CalledProcessError:
        # If the rebase failed before starting (e.g. an invalid target) there
        # is nothing to abort; don't let that turn into a crash.
        git(['rebase', '--abort'], check=False)
        return False
    return True
30
31
def rebase_bisect(lbranch, rbranch, behind, leave_rebase=False):
    """Try to rebase branch as close to |rbranch| as possible.

    Binary-searches positions 0..|behind| where position N means
    |rbranch|~N: 0 is the tip of |rbranch| and |behind| is where the branch
    currently sits.  Like any bisection this assumes that once a rebase onto
    some commit fails, rebasing onto any newer commit fails too.

    A successful attempt leaves the branch rebased onto that commit, and a
    failed attempt is aborted.  So when the search ends the branch sits on
    the newest commit that rebased cleanly.

    Args:
      lbranch: Local branch name (not used by the current implementation).
      rbranch: Branch to rebase onto.
      behind: Number of commits the local branch is behind |rbranch|.
      leave_rebase: If True, re-run the rebase onto the oldest failing commit
        and leave it in progress so the conflicts can be resolved by hand.
    """
    def attempt(pos):
        """Rebase onto |rbranch|~|pos|, printing progress; return success."""
        target = f'{rbranch}~{pos}'
        print(f'Rebasing onto {target} ... ', end='', flush=True)
        ret = rebase(target)
        print('OK' if ret else 'failed')
        return ret

    # Invariant: |rbranch|~pmax rebases cleanly (initially the current base,
    # trivially so).  |rbranch|~pmin has either failed or is the untested tip.
    pmin = 0
    pmax = behind
    tried = set()
    while pmin < pmax:
        mid = pmin + (pmax - pmin) // 2
        if mid in tried:
            # Only pmin can be selected twice, and it is already known to
            # fail: the search has converged.  Don't repeat the rebase.
            break
        tried.add(mid)
        if attempt(mid):
            pmax = mid
        else:
            pmin = mid

    if not pmax:
        print('All caught up!')
        return

    # pmax > 0 means at least one attempt failed, and pmin is the oldest
    # position that failed.
    last_target = f'{rbranch}~{pmin}'
    if leave_rebase:
        print('Restarting', last_target)
        result = git(['rebase', last_target], check=False)
        print(result.stdout.strip())
        # git reports the commit it could not apply on stderr.
        errors = result.stderr.strip()
        if errors:
            print(errors, file=sys.stderr)
    else:
        print('Found first failure', last_target)
72
73
def get_ahead_behind(lbranch, rbranch):
    """Count commits unique to each side of |lbranch|...|rbranch|.

    Returns:
      A two-element list [ahead, behind]: the number of commits reachable only
      from |lbranch|, then the number reachable only from |rbranch|.
    """
    symmetric_range = f'{lbranch}...{rbranch}'
    counts = git(['rev-list', '--left-right', '--count', symmetric_range]).stdout
    return list(map(int, counts.split()))
78
79
def get_tracking_branch(branch):
    """Return remote branch that |branch| is tracking.

    Returns:
      "<remote>/<branch>" for a remote upstream, the bare branch name when the
      upstream is in the local repository (remote "."), the raw merge ref when
      no remote is configured, or None if |branch| has no upstream at all.
    """
    def config(key):
        """Return the local config value for |key|, or '' if it is unset."""
        result = git(['config', '--local', key], check=False)
        # `git config` exits 1 when the key is not found; that just means "no
        # value".  Any other failure is a real error.
        if result.returncode == 1:
            return ''
        result.check_returncode()
        return result.stdout.strip()

    merge = config(f'branch.{branch}.merge')
    if not merge:
        return None

    remote = config(f'branch.{branch}.remote')
    if not remote:
        return merge

    prefix = 'refs/heads/'
    if merge.startswith(prefix):
        merge = merge[len(prefix):]
    if remote == '.':
        # "." means the upstream is a branch in this repository, so there is
        # no remote name to prepend ("./name" is not a valid revision).
        return merge
    return f'{remote}/{merge}'
94
def get_local_branch():
    """Return the name of the currently checked-out branch.

    The result may be empty; the git-branch documentation says
    `--show-current` prints nothing on a detached HEAD.
    """
    completed = git(['branch', '--show-current'])
    return completed.stdout.strip()
98
99
def get_parser():
    """Build the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(description=__doc__)
    # Options are registered in the same order as before so --help is unchanged.
    parser.add_argument(
        '--skip-initial-rebase-latest',
        action='store_false',
        dest='initial_rebase',
        default=True,
        help='skip initial rebase attempt onto the latest branch',
    )
    parser.add_argument(
        '--leave-at-last-failed-rebase',
        action='store_true',
        dest='leave_rebase',
        default=False,
        help='leave tree state at last failing rebase',
    )
    parser.add_argument('branch', nargs='?', help='branch to rebase onto')
    return parser
115
116
def main(argv):
    """The main entry point for scripts.

    Args:
      argv: Command-line arguments, excluding the program name.

    Returns:
      The process exit status: 1 if the local or remote branch cannot be
      resolved, otherwise 0.
    """
    parser = get_parser()
    opts = parser.parse_args(argv)

    lbranch = get_local_branch()
    print(f'Local branch resolved to "{lbranch}"')
    if not lbranch:
        print('Unable to resolve local branch', file=sys.stderr)
        return 1

    rbranch = opts.branch or get_tracking_branch(lbranch)
    if not rbranch:
        # Without an upstream there is nothing to compare against.  Passing
        # None on would produce a bogus "<branch>...None" range for git.
        print(f'Unable to resolve remote branch for "{lbranch}"; '
              'specify one on the command line', file=sys.stderr)
        return 1
    print(f'Remote branch resolved to "{rbranch}"')

    ahead, behind = get_ahead_behind(lbranch, rbranch)
    print(f'Branch is {ahead} commits ahead and {behind} commits behind')

    if not behind:
        print('Up-to-date!')
    elif not ahead:
        print('Fast forwarding ...')
        # Name |rbranch| explicitly.  A bare `git merge` merges the configured
        # upstream, which is the wrong branch when the user named a different
        # one.  Since we are not ahead, a fast-forward is always possible.
        git(['merge', '--ff-only', rbranch])
    else:
        if opts.initial_rebase:
            print(f'Trying to rebase onto latest {rbranch} ... ', end='', flush=True)
            if rebase(rbranch):
                print('OK!')
                return 0
            print('failed; falling back to bisect')
        rebase_bisect(lbranch, rbranch, behind, leave_rebase=opts.leave_rebase)
    return 0
150
151
if __name__ == '__main__':
    # Hand main()'s return value to the shell as the exit status; sys.exit()
    # treats None as success.
    cli_args = sys.argv[1:]
    sys.exit(main(cli_args))