Skip to content

Commit a03018d

Browse files
committed
feat(pip_repository): Enable PyPI dep cycles
This patch adjusts the pip_repository interface to accept a new parameter, `composite_libs`: a mapping from a composite name to the list of PyPI package names which form a dependency cycle and must be installed together. The intuition behind this design is that a dependency cycle {a <-> b} is implemented simply by installing both a and b at once. Hence a dependency graph {c -> a, c -> b} has the same effect. If we modify the installation of a and b to remove their mutual dependency, and generate a c which dominates a and b (i.e. depends on both), we can then modify the `requirement()` and `whl_requirement()` helper functions to recognize the requirements a and b and provide a reference to c instead.
1 parent c72c7bc commit a03018d

9 files changed

Lines changed: 153 additions & 39 deletions

File tree

python/pip_install/pip_repository.bzl

Lines changed: 74 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,6 @@ def _create_repository_execution_environment(rctx):
226226

227227
return env
228228

229-
_BUILD_FILE_CONTENTS = """\
230-
package(default_visibility = ["//visibility:public"])
231-
232-
# Ensure the `requirements.bzl` source can be accessed by stardoc, since users load() from it
233-
exports_files(["requirements.bzl"])
234-
"""
235-
236229
def locked_requirements_label(ctx, attr):
237230
"""Get the preferred label for a locked requirements file based on platform.
238231
@@ -360,15 +353,16 @@ def _pip_repository_bzlmod_impl(rctx):
360353

361354
repo_name = rctx.attr.name.split("~")[-1]
362355

363-
build_contents = _BUILD_FILE_CONTENTS
364-
365356
if rctx.attr.incompatible_generate_aliases:
366357
_pkg_aliases(rctx, repo_name, bzl_packages)
358+
build_footer = ""
367359
else:
368-
build_contents += _bzlmod_pkg_aliases(repo_name, bzl_packages)
360+
build_footer = _bzlmod_pkg_aliases(repo_name, bzl_packages)
369361

370-
rctx.file("BUILD.bazel", build_contents)
371-
rctx.template("requirements.bzl", rctx.attr._template, substitutions = {
362+
rctx.file("BUILD.bazel", rctx.attr._build_template, substitutions = {
363+
"%%FOOTER%%": build_footer,
364+
})
365+
rctx.template("requirements.bzl", rctx.attr._requirements_template, substitutions = {
372366
"%%ALL_REQUIREMENTS%%": _format_repr_list([
373367
"@{}//{}".format(repo_name, p) if rctx.attr.incompatible_generate_aliases else "@{}_{}//:pkg".format(rctx.attr.name, p)
374368
for p in bzl_packages
@@ -406,9 +400,12 @@ wheels are fetched/built only for the targets specified by 'build/run/test'.
406400
allow_single_file = True,
407401
doc = "Override the requirements_lock attribute when the host platform is Windows",
408402
),
409-
"_template": attr.label(
403+
"_requirements_template": attr.label(
410404
default = ":pip_repository_requirements_bzlmod.bzl.tmpl",
411405
),
406+
"_build_template": attr.label(
407+
default = ":pip_repository_build.bazel.tmpl",
408+
),
412409
}
413410

414411
pip_repository_bzlmod = repository_rule(
@@ -422,11 +419,46 @@ def _pip_repository_impl(rctx):
422419
content = rctx.read(requirements_txt)
423420
parsed_requirements_txt = parse_requirements(content)
424421

425-
packages = [(_clean_pkg_name(name), requirement) for name, requirement in parsed_requirements_txt.requirements]
422+
# Apply name normalizations to the composite libs def once
423+
composite_libs = {
424+
_clean_pkg_name(name): [_clean_pkg_name(it) for it in components]
425+
for name, components in rctx.attr.composite_libs.items()
426+
}
426427

427-
bzl_packages = sorted([name for name, _ in packages])
428+
# Ditto for requirements defs
429+
requirements = {
430+
_clean_pkg_name(name): requirement
431+
for name, requirement in parsed_requirements_txt.requirements
432+
}
433+
434+
# Map normalized package names to a composite
435+
composite_mapping = {
436+
name: composite_name
437+
for composite_name, names in composite_libs.items()
438+
for name in names
439+
}
440+
441+
# Normal packages are defined by a single requirement.
442+
# We will deal with composites shortly.
443+
normal_packages = [
444+
(name, requirement)
445+
for name, requirement in requirements.items()
446+
if name not in composite_mapping
447+
]
448+
449+
# Composite packages are a cluster which can only be depended on together
450+
composite_packages = {
451+
_clean_pkg_name(composite_name): [
452+
(rname, requirements[rname])
453+
for rname in composite_components
454+
]
455+
for composite_name, composite_components in rctx.attr.composite_libs.items()
456+
}
457+
458+
bzl_packages = sorted([name for name, _ in requirements.items()])
428459

429460
imports = [
461+
'load("@rules_python//python:defs.bzl", "py_library")',
430462
'load("@rules_python//python/pip_install:pip_repository.bzl", "whl_library")',
431463
]
432464

@@ -463,8 +495,14 @@ def _pip_repository_impl(rctx):
463495
if rctx.attr.incompatible_generate_aliases:
464496
_pkg_aliases(rctx, rctx.attr.name, bzl_packages)
465497

466-
rctx.file("BUILD.bazel", _BUILD_FILE_CONTENTS)
467-
rctx.template("requirements.bzl", rctx.attr._template, substitutions = {
498+
rctx.template("lib.bzl", rctx.attr._lib_template, substitutions = {
499+
"%%NAME%%": rctx.attr.name,
500+
})
501+
rctx.template("BUILD.bazel", rctx.attr._build_template, substitutions = {
502+
"%%NAME%%": rctx.attr.name,
503+
"%%FOOTER%%": "",
504+
})
505+
rctx.template("requirements.bzl", rctx.attr._requirements_template, substitutions = {
468506
"%%ALL_REQUIREMENTS%%": _format_repr_list([
469507
"@{}//{}".format(rctx.attr.name, p) if rctx.attr.incompatible_generate_aliases else "@{}_{}//:pkg".format(rctx.attr.name, p)
470508
for p in bzl_packages
@@ -475,13 +513,15 @@ def _pip_repository_impl(rctx):
475513
]),
476514
"%%ANNOTATIONS%%": _format_dict(_repr_dict(annotations)),
477515
"%%CONFIG%%": _format_dict(_repr_dict(config)),
516+
"%%CLUSTERS%%": _format_dict(_repr_dict(composite_packages)),
517+
"%%CLUSTER_MAPPINGS%%": _format_dict(_repr_dict(composite_mapping)),
478518
"%%EXTRA_PIP_ARGS%%": json.encode(options),
479519
"%%IMPORTS%%": "\n".join(sorted(imports)),
480520
"%%NAME%%": rctx.attr.name,
481521
"%%PACKAGES%%": _format_repr_list(
482522
[
483523
("{}_{}".format(rctx.attr.name, p), r)
484-
for p, r in packages
524+
for p, r in normal_packages
485525
],
486526
),
487527
"%%REQUIREMENTS_LOCK%%": str(requirements_txt),
@@ -602,9 +642,18 @@ wheels are fetched/built only for the targets specified by 'build/run/test'.
602642
allow_single_file = True,
603643
doc = "Override the requirements_lock attribute when the host platform is Windows",
604644
),
605-
"_template": attr.label(
645+
"composite_libs": attr.string_list_dict(
646+
doc = "Groups of requirements which represent dependency cycles and must be treated as composites.",
647+
),
648+
"_requirements_template": attr.label(
606649
default = ":pip_repository_requirements.bzl.tmpl",
607650
),
651+
"_build_template": attr.label(
652+
default = ":pip_repository_build.bazel.tmpl",
653+
),
654+
"_lib_template": attr.label(
655+
default = ":pip_repository_lib.bzl.tmpl",
656+
),
608657
}
609658

610659
pip_repository_attrs.update(**common_attrs)
@@ -673,6 +722,8 @@ def _whl_library_impl(rctx):
673722
"--annotation",
674723
rctx.path(rctx.attr.annotation),
675724
])
725+
for d in rctx.attr.skip_deps:
726+
args.extend(["--skip", d])
676727

677728
args = _parse_optional_attrs(rctx, args)
678729

@@ -705,6 +756,10 @@ whl_library_attrs = {
705756
mandatory = True,
706757
doc = "Python requirement string describing the package to make available",
707758
),
759+
"skip_deps": attr.string_list(
760+
doc = "List of requirements to skip due to clustering",
761+
default = [],
762+
),
708763
}
709764

710765
whl_library_attrs.update(**common_attrs)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
load(":lib.bzl", "install_clusters")
2+
3+
package(default_visibility = ["//visibility:public"])
4+
5+
# Ensure the `requirements.bzl` source can be accessed by stardoc, since users load() from it
6+
exports_files(["requirements.bzl"])
7+
8+
install_clusters()
9+
10+
%%FOOTER%%
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
load("@rules_python//python:defs.bzl", "py_library")
2+
load(":requirements.bzl", "requirement", "whl_requirement", "requirement_clusters")
3+
4+
5+
def install_clusters():
6+
for cname, components in requirement_clusters.items():
7+
py_library(
8+
name = cname,
9+
deps = [requirement(c, use_clusters=False) for c, _ in components]
10+
)
11+
native.filegroup(
12+
name = "whl_" + cname,
13+
data = [whl_requirement(c, use_clusters=False) for c, _ in components]
14+
)

python/pip_install/pip_repository_requirements.bzl.tmpl

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,40 @@ all_requirements = %%ALL_REQUIREMENTS%%
1111
all_whl_requirements = %%ALL_WHL_REQUIREMENTS%%
1212

1313
_packages = %%PACKAGES%%
14+
_cluster_mappings = %%CLUSTER_MAPPINGS%%
15+
requirement_clusters = %%CLUSTERS%%
1416
_config = %%CONFIG%%
1517
_annotations = %%ANNOTATIONS%%
1618

1719
def _clean_name(name):
1820
return name.replace("-", "_").replace(".", "_").lower()
1921

20-
def requirement(name):
21-
return "@%%NAME%%_" + _clean_name(name) + "//:pkg"
22+
def requirement(name, use_clusters=True):
23+
cname = _clean_name(name)
24+
if cname in _cluster_mappings and use_clusters:
25+
return "@%%NAME%%//:" + _cluster_mappings[cname]
26+
else:
27+
return "@%%NAME%%_" + cname + "//:pkg"
2228

23-
def whl_requirement(name):
24-
return "@%%NAME%%_" + _clean_name(name) + "//:whl"
29+
def whl_requirement(name, use_clusters=True):
30+
cname = _clean_name(name)
31+
if cname in _cluster_mappings and use_clusters:
32+
return "@%%NAME%%//:whl_" + _cluster_mappings[cname]
33+
return "@%%NAME%%_" + cname + "//:whl"
2534

2635
def data_requirement(name):
36+
cname = _clean_name(name)
2737
return "@%%NAME%%_" + _clean_name(name) + "//:data"
2838

2939
def dist_info_requirement(name):
40+
cname = _clean_name(name)
3041
return "@%%NAME%%_" + _clean_name(name) + "//:dist_info"
3142

3243
def entry_point(pkg, script = None):
44+
cname = _clean_name(pkg)
3345
if not script:
3446
script = pkg
35-
return "@%%NAME%%_" + _clean_name(pkg) + "//:rules_python_wheel_entry_point_" + script
47+
return "@%%NAME%%_" + cname + "//:rules_python_wheel_entry_point_" + script
3648

3749
def _get_annotation(requirement):
3850
# This expects to parse `setuptools==58.2.0 --hash=sha256:2551203ae6955b9876741a26ab3e767bb3242dafe86a32a749ea0d78b6792f11`
@@ -43,10 +55,24 @@ def _get_annotation(requirement):
4355
def install_deps(**whl_library_kwargs):
4456
whl_config = dict(_config)
4557
whl_config.update(whl_library_kwargs)
46-
for name, requirement in _packages:
58+
# Install normal requirements
59+
for name, spec in _packages:
4760
whl_library(
4861
name = name,
49-
requirement = requirement,
50-
annotation = _get_annotation(requirement),
62+
requirement = spec,
63+
annotation = _get_annotation(spec),
5164
**whl_config
5265
)
66+
# And deal with requirement_clusters
67+
for cname, components in requirement_clusters.items():
68+
# Generate the component libraries
69+
cnames = [c[0] for c in components]
70+
for rname, spec in components:
71+
name = "%%NAME%%_" + rname
72+
whl_library(
73+
name = name,
74+
requirement = spec,
75+
annotation = _get_annotation(spec),
76+
skip_deps = cnames,
77+
**whl_config,
78+
)

python/pip_install/tools/dependency_resolver/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-

python/pip_install/tools/lib/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-

python/pip_install/tools/lib/annotation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ class Annotation(OrderedDict):
2323
"""A python representation of `@rules_python//python:pip.bzl%package_annotation`"""
2424

2525
def __init__(self, content: Dict[str, Any]) -> None:
26-
2726
missing = []
2827
ordered_content = OrderedDict()
2928
for field in (

python/pip_install/tools/lib/annotations_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525

2626
class AnnotationsTestCase(unittest.TestCase):
27-
2827
maxDiff = None
2928

3029
def test_annotations_constructor(self) -> None:

python/pip_install/tools/wheel_installer/wheel_installer.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def _generate_copy_commands(src, dest, is_executable=False) -> str:
195195

196196
def _generate_build_file_contents(
197197
name: str,
198+
repo_prefix: str,
198199
dependencies: List[str],
199200
whl_file_deps: List[str],
200201
data_exclude: List[str],
@@ -241,6 +242,7 @@ def _generate_build_file_contents(
241242
"""\
242243
load("@rules_python//python:defs.bzl", "py_library", "py_binary")
243244
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
245+
load("@{repo_prefix}//:requirements.bzl", "requirement", "whl_requirement")
244246
245247
package(default_visibility = ["//visibility:public"])
246248
@@ -272,6 +274,7 @@ def _generate_build_file_contents(
272274
)
273275
""".format(
274276
name=name,
277+
repo_prefix=repo_prefix.rstrip("_"),
275278
dependencies=",".join(sorted(dependencies)),
276279
data_exclude=json.dumps(sorted(data_exclude)),
277280
whl_file_label=bazel.WHEEL_FILE_LABEL,
@@ -297,6 +300,7 @@ def _extract_wheel(
297300
repo_prefix: str,
298301
installation_dir: Path = Path("."),
299302
annotation: Optional[annotation.Annotation] = None,
303+
skip_deps: List[str] = [],
300304
) -> None:
301305
"""Extracts wheel into given directory and creates py_library and filegroup targets.
302306
@@ -318,15 +322,21 @@ def _extract_wheel(
318322
extras_requested = extras[whl.name] if whl.name in extras else set()
319323
# Packages may create dependency cycles when specifying optional-dependencies / 'extras'.
320324
# Example: github.com/google/etils/blob/a0b71032095db14acf6b33516bca6d885fe09e35/pyproject.toml#L32.
321-
self_edge_dep = set([whl.name])
322-
whl_deps = sorted(whl.dependencies(extras_requested) - self_edge_dep)
325+
to_skip = {bazel.sanitise_name(it, "") for it in [whl.name] + skip_deps}
326+
deps = {bazel.sanitise_name(it, "") for it in whl.dependencies(extras_requested)}
327+
whl_deps = sorted(deps - to_skip)
328+
print(
329+
"While building %s\n\tDeps: %r\n\tSkipping: %r\n\tEffective: %r"
330+
% (
331+
whl.name,
332+
deps,
333+
to_skip,
334+
whl_deps,
335+
)
336+
)
323337

324-
sanitised_dependencies = [
325-
bazel.sanitised_repo_library_label(d, repo_prefix=repo_prefix) for d in whl_deps
326-
]
327-
sanitised_wheel_file_dependencies = [
328-
bazel.sanitised_repo_file_label(d, repo_prefix=repo_prefix) for d in whl_deps
329-
]
338+
sanitised_dependencies = ["requirement(%r)" % d for d in whl_deps]
339+
sanitised_wheel_file_dependencies = ["whl_requirement(%r)" % d for d in whl_deps]
330340

331341
entry_points = []
332342
for name, (module, attribute) in sorted(whl.entry_points().items()):
@@ -370,6 +380,7 @@ def _extract_wheel(
370380

371381
contents = _generate_build_file_contents(
372382
name=bazel.PY_LIBRARY_LABEL,
383+
repo_prefix=repo_prefix,
373384
dependencies=sanitised_dependencies,
374385
whl_file_deps=sanitised_wheel_file_dependencies,
375386
data_exclude=data_exclude,
@@ -396,6 +407,7 @@ def main() -> None:
396407
type=annotation.annotation_from_str_path,
397408
help="A json encoded file containing annotations for rendered packages.",
398409
)
410+
parser.add_argument("--skip", action="append", dest="skip_deps", default=[])
399411
arguments.parse_common_args(parser)
400412
args = parser.parse_args()
401413
deserialized_args = dict(vars(args))
@@ -443,6 +455,7 @@ def main() -> None:
443455
enable_implicit_namespace_pkgs=args.enable_implicit_namespace_pkgs,
444456
repo_prefix=args.repo_prefix,
445457
annotation=args.annotation,
458+
skip_deps=args.skip_deps,
446459
)
447460

448461

0 commit comments

Comments
 (0)