From e725f830b3a62e964a06dcf87d9aac48027b00a5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Aug 2024 07:42:58 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- conda_classic_solver/__init__.py | 3 + conda_classic_solver/_logic.py | 11 +- conda_classic_solver/logic.py | 29 ++--- conda_classic_solver/resolve.py | 187 ++++++++----------------------- conda_classic_solver/solve.py | 118 ++++++------------- tests/test_logic.py | 14 +-- tests/test_solvers.py | 3 + 7 files changed, 107 insertions(+), 258 deletions(-) diff --git a/conda_classic_solver/__init__.py b/conda_classic_solver/__init__.py index 8acdb4d..a56346b 100644 --- a/conda_classic_solver/__init__.py +++ b/conda_classic_solver/__init__.py @@ -1,3 +1,6 @@ +# Copyright (C) 2022 Anaconda, Inc +# Copyright (C) 2023 conda +# SPDX-License-Identifier: BSD-3-Clause # Copyright (C) 2012 Anaconda, Inc # Copyright (C) 2023 conda # SPDX-License-Identifier: BSD-3-Clause diff --git a/conda_classic_solver/_logic.py b/conda_classic_solver/_logic.py index 35380bb..bf6d0be 100644 --- a/conda_classic_solver/_logic.py +++ b/conda_classic_solver/_logic.py @@ -1,3 +1,6 @@ +# Copyright (C) 2022 Anaconda, Inc +# Copyright (C) 2023 conda +# SPDX-License-Identifier: BSD-3-Clause # Copyright (C) 2012 Anaconda, Inc # Copyright (C) 2023 conda # SPDX-License-Identifier: BSD-3-Clause @@ -578,9 +581,7 @@ def BDD(self, lits, coeffs, nterms, lo, hi, polarity): # avoid calling self.assign here via add_new_clauses=True. # If we want to translate parts of the code to a compiled language, # self.BDD (+ its downward call stack) is the prime candidate! - ret[call_stack_pop()] = ITE( - abs(LA), thi, tlo, polarity, add_new_clauses=True - ) + ret[call_stack_pop()] = ITE(abs(LA), thi, tlo, polarity, add_new_clauses=True) return ret[target] def LinearBound(self, lits, coeffs, lo, hi, preprocess, polarity): @@ -591,9 +592,7 @@ def LinearBound(self, lits, coeffs, lo, hi, preprocess, polarity): nterms = len(coeffs) if nterms and coeffs[-1] > hi: nprune = sum(c > hi for c in coeffs) - log.log( - TRACE, "Eliminating %d/%d terms for bound violation", nprune, nterms - ) + log.log(TRACE, "Eliminating %d/%d terms for bound violation", nprune, nterms) nterms -= nprune else: nprune = 0 diff --git a/conda_classic_solver/logic.py b/conda_classic_solver/logic.py index f378b30..cdde227 100644 --- a/conda_classic_solver/logic.py +++ b/conda_classic_solver/logic.py @@ -1,3 +1,6 @@ +# Copyright (C) 2022 Anaconda, Inc +# Copyright (C) 2023 conda +# SPDX-License-Identifier: BSD-3-Clause # Copyright (C) 2012 Anaconda, Inc # Copyright (C) 2023 conda # SPDX-License-Identifier: BSD-3-Clause @@ -166,14 +169,10 @@ def Any(self, vals, polarity=None, name=None): return self._eval(self._clauses.Any, (list(vals),), (), polarity, name) def AtMostOne_NSQ(self, vals, polarity=None, name=None): - return self._eval( - self._clauses.AtMostOne_NSQ, (list(vals),), (), polarity, name - ) + return self._eval(self._clauses.AtMostOne_NSQ, (list(vals),), (), polarity, name) def AtMostOne_BDD(self, vals, polarity=None, name=None): - return self._eval( - self._clauses.AtMostOne_BDD, (list(vals),), (), polarity, name - ) + return self._eval(self._clauses.AtMostOne_BDD, (list(vals),), (), polarity, name) def AtMostOne(self, vals, polarity=None, name=None): vals = list(vals) @@ -185,14 +184,10 @@ def AtMostOne(self, vals, polarity=None, name=None): return self._eval(what, (vals,), (), polarity, name) def ExactlyOne_NSQ(self, vals, polarity=None, name=None): - return self._eval( - self._clauses.ExactlyOne_NSQ, (list(vals),), (), polarity, name - ) + return self._eval(self._clauses.ExactlyOne_NSQ, (list(vals),), (), polarity, name) def ExactlyOne_BDD(self, vals, polarity=None, name=None): - return self._eval( - self._clauses.ExactlyOne_BDD, (list(vals),), (), polarity, name - ) + return self._eval(self._clauses.ExactlyOne_BDD, (list(vals),), (), polarity, name) def ExactlyOne(self, vals, polarity=None, name=None): vals = list(vals) @@ -231,17 +226,11 @@ def sat(self, additional=None, includeIf=False, names=False, limit=0): return set() if names else [] if additional: additional = (tuple(self.names.get(c, c) for c in cc) for cc in additional) - solution = self._clauses.sat( - additional=additional, includeIf=includeIf, limit=limit - ) + solution = self._clauses.sat(additional=additional, includeIf=includeIf, limit=limit) if solution is None: return None if names: - return { - nm - for nm in (self.indices.get(s) for s in solution) - if nm and nm[0] != "!" - } + return {nm for nm in (self.indices.get(s) for s in solution) if nm and nm[0] != "!"} return solution def itersolve(self, constraints=None, m=None): diff --git a/conda_classic_solver/resolve.py b/conda_classic_solver/resolve.py index d396ca9..c6bced2 100644 --- a/conda_classic_solver/resolve.py +++ b/conda_classic_solver/resolve.py @@ -1,3 +1,6 @@ +# Copyright (C) 2022 Anaconda, Inc +# Copyright (C) 2023 conda +# SPDX-License-Identifier: BSD-3-Clause # Copyright (C) 2012 Anaconda, Inc # Copyright (C) 2023 conda # SPDX-License-Identifier: BSD-3-Clause @@ -76,9 +79,7 @@ def try_out_solver(sat_solver): try: try_out_solver(sat_solver) except Exception as e: - log.warning( - "Could not run SAT solver through interface '%s'.", sat_solver_choice - ) + log.warning("Could not run SAT solver through interface '%s'.", sat_solver_choice) log.debug("SAT interface error due to: %s", e, exc_info=True) else: log.debug("Using SAT solver interface '%s'.", sat_solver_choice) @@ -95,9 +96,7 @@ def try_out_solver(sat_solver): else: log.debug("Falling back to SAT solver interface '%s'.", sat_solver_choice) return sat_solver - raise CondaDependencyError( - "Cannot run solver. No functioning SAT implementations available." - ) + raise CondaDependencyError("Cannot run solver. No functioning SAT implementations available.") def exactness_and_number_of_deps(resolve_obj, ms): @@ -122,9 +121,7 @@ def __init__(self, index, processed=False, channels=()): self.index = index self.channels = channels - self._channel_priorities_map = ( - self._make_channel_priorities(channels) if channels else {} - ) + self._channel_priorities_map = self._make_channel_priorities(channels) if channels else {} self._channel_priority = context.channel_priority self._solver_ignore_timestamps = context.solver_ignore_timestamps @@ -152,10 +149,7 @@ def __init__(self, index, processed=False, channels=()): self._system_precs = { _ for _ in index - if ( - hasattr(_, "package_type") - and _.package_type == PackageType.VIRTUAL_SYSTEM - ) + if (hasattr(_, "package_type") and _.package_type == PackageType.VIRTUAL_SYSTEM) } # sorting these in reverse order is effectively prioritizing @@ -187,9 +181,7 @@ def default_filter(self, features=None, filter=None): else: filter.clear() - filter.update( - {make_feature_record(fstr): False for fstr in self.trackers.keys()} - ) + filter.update({make_feature_record(fstr): False for fstr in self.trackers.keys()}) if features: filter.update({make_feature_record(fstr): True for fstr in features}) return filter @@ -215,11 +207,7 @@ def v_(spec): return v_ms_(spec) if isinstance(spec, MatchSpec) else v_fkey_(spec) def v_ms_(ms): - return ( - optional - and ms.optional - or any(v_fkey_(fkey) for fkey in self.find_matches(ms)) - ) + return optional and ms.optional or any(v_fkey_(fkey) for fkey in self.find_matches(ms)) def v_fkey_(prec): val = filter.get(prec) @@ -256,15 +244,11 @@ def is_valid_prec(prec): if val is None: filter_out[prec] = False try: - has_valid_deps = all( - is_valid_spec(ms) for ms in self.ms_depends(prec) - ) + has_valid_deps = all(is_valid_spec(ms) for ms in self.ms_depends(prec)) except InvalidSpec: val = filter_out[prec] = "invalid dep specs" else: - val = filter_out[prec] = ( - False if has_valid_deps else "invalid depends specs" - ) + val = filter_out[prec] = False if has_valid_deps else "invalid depends specs" return not val return is_valid(spec_or_prec) @@ -330,17 +314,13 @@ def verify_specs(self, specs): else: non_tf_specs.append(ms) bad_deps.extend( - (spec,) - for spec in non_tf_specs - if (not spec.optional and not self.find_matches(spec)) + (spec,) for spec in non_tf_specs if (not spec.optional and not self.find_matches(spec)) ) if bad_deps: raise ResolvePackageNotFound(bad_deps) return tuple(non_tf_specs), feature_names - def _classify_bad_deps( - self, bad_deps, specs_to_add, history_specs, strict_channel_priority - ): + def _classify_bad_deps(self, bad_deps, specs_to_add, history_specs, strict_channel_priority): classes = { "python": set(), "request_conflict_with_history": set(), @@ -361,8 +341,7 @@ def _classify_bad_deps( if python_first_specs: python_spec = python_first_specs[0] if not ( - set(self.find_matches(python_spec)) - & set(self.find_matches(chain[-1])) + set(self.find_matches(python_spec)) & set(self.find_matches(chain[-1])) ): classes["python"].add( ( @@ -372,9 +351,7 @@ def _classify_bad_deps( ) elif chain[-1].name.startswith("__"): version = [_ for _ in self._system_precs if _.name == chain[-1].name] - virtual_package_version = ( - version[0].version if version else "not available" - ) + virtual_package_version = version[0].version if version else "not available" classes["virtual_package"].add((tuple(chain), virtual_package_version)) elif chain[0] in specs_to_add: match = False @@ -386,23 +363,15 @@ def _classify_bad_deps( match = True if not match: - classes["direct"].add( - (tuple(chain), str(MatchSpec(chain[0], target=None))) - ) + classes["direct"].add((tuple(chain), str(MatchSpec(chain[0], target=None)))) else: - if len(chain) > 1 or any( - len(c) >= 1 and c[0] == chain[0] for c in bad_deps - ): - classes["direct"].add( - (tuple(chain), str(MatchSpec(chain[0], target=None))) - ) + if len(chain) > 1 or any(len(c) >= 1 and c[0] == chain[0] for c in bad_deps): + classes["direct"].add((tuple(chain), str(MatchSpec(chain[0], target=None)))) if classes["python"]: # filter out plain single-entry python conflicts. The python section explains these. classes["direct"] = [ - _ - for _ in classes["direct"] - if _[1].startswith("python ") or len(_[0]) > 1 + _ for _ in classes["direct"] if _[1].startswith("python ") or len(_[0]) > 1 ] return classes @@ -426,9 +395,7 @@ def find_conflicts(self, specs, specs_to_add=None, history_specs=None): strict_channel_priority = context.channel_priority == ChannelPriority.STRICT raise UnsatisfiableError(bad_deps, strict=strict_channel_priority) - def breadth_first_search_for_dep_graph( - self, root_spec, target_name, dep_graph, num_targets=1 - ): + def breadth_first_search_for_dep_graph(self, root_spec, target_name, dep_graph, num_targets=1): """Return shorted path from root_spec to target_name""" queue = [] queue.append([root_spec]) @@ -584,9 +551,7 @@ def build_conflict_map(self, specs, specs_to_add=None, history_specs=None): lroots = [_ for _ in roots] current_shortest_chain = [] shortest_node = None - requested_spec_unsat = frozenset(nodes).intersection( - {_.name for _ in roots} - ) + requested_spec_unsat = frozenset(nodes).intersection({_.name for _ in roots}) if requested_spec_unsat: chains.append([_ for _ in roots if _.name in requested_spec_unsat]) shortest_node = chains[-1][0] @@ -628,19 +593,10 @@ def _get_strict_channel(self, package_name): channel_name = self._strict_channel_cache[package_name] except KeyError: if package_name in self.groups: - all_channel_names = { - prec.channel.name for prec in self.groups[package_name] - } - by_cp = { - self._channel_priorities_map.get(cn, 1): cn - for cn in all_channel_names - } - highest_priority = sorted(by_cp)[ - 0 - ] # highest priority is the lowest number - channel_name = self._strict_channel_cache[package_name] = by_cp[ - highest_priority - ] + all_channel_names = {prec.channel.name for prec in self.groups[package_name]} + by_cp = {self._channel_priorities_map.get(cn, 1): cn for cn in all_channel_names} + highest_priority = sorted(by_cp)[0] # highest priority is the lowest number + channel_name = self._strict_channel_cache[package_name] = by_cp[highest_priority] return channel_name @memoizemethod @@ -662,9 +618,7 @@ def _get_package_pool(self, specs): return pool @time_recorder(module_name=__name__) - def get_reduced_index( - self, explicit_specs, sort_by_exactness=True, exit_on_conflict=False - ): + def get_reduced_index(self, explicit_specs, sort_by_exactness=True, exit_on_conflict=False): # TODO: fix this import; this is bad from conda.core.subdir_data import make_feature_record @@ -728,16 +682,12 @@ def filter_group(_specs): explicit_spec_package_pool.get(name) and prec not in explicit_spec_package_pool[name] ): - filter_out[prec] = ( - f"incompatible with required spec {top_level_spec}" - ) + filter_out[prec] = f"incompatible with required spec {top_level_spec}" continue unsatisfiable_dep_specs = set() for ms in self.ms_depends(prec): if not ms.optional and not any( - rec - for rec in self.find_matches(ms) - if not filter_out.get(rec, False) + rec for rec in self.find_matches(ms) if not filter_out.get(rec, False) ): unsatisfiable_dep_specs.add(ms) if unsatisfiable_dep_specs: @@ -810,9 +760,7 @@ def filter_group(_specs): return {} # Determine all valid packages in the dependency graph - reduced_index2 = { - prec: prec for prec in (make_feature_record(fstr) for fstr in features) - } + reduced_index2 = {prec: prec for prec in (make_feature_record(fstr) for fstr in features)} specs_by_name_seed = {} for s in explicit_specs: specs_by_name_seed[s.name] = specs_by_name_seed.get(s.name, []) + [s] @@ -827,9 +775,7 @@ def filter_group(_specs): strict_channel_name = self._get_strict_channel(add_these_precs2[0].name) add_these_precs2 = tuple( - prec - for prec in add_these_precs2 - if prec.channel.name == strict_channel_name + prec for prec in add_these_precs2 if prec.channel.name == strict_channel_name ) reduced_index2.update((prec, prec) for prec in add_these_precs2) @@ -844,9 +790,7 @@ def filter_group(_specs): dep_specs = set(self.ms_depends(pkg)) for dep in dep_specs: specs = specs_by_name.get(dep.name, []) - if dep not in specs and ( - not specs or dep.strictness >= specs[0].strictness - ): + if dep not in specs and (not specs or dep.strictness >= specs[0].strictness): specs.insert(0, dep) specs_by_name[dep.name] = specs @@ -856,17 +800,14 @@ def filter_group(_specs): # specs_added = [] ms = dep_specs.pop() seen_specs.add(ms) - for dep_pkg in ( - _ for _ in self.find_matches(ms) if _ not in reduced_index2 - ): + for dep_pkg in (_ for _ in self.find_matches(ms) if _ not in reduced_index2): if not self.valid2(dep_pkg, filter_out): continue # expand the reduced index if not using strict channel priority, # or if using it and this package is in the appropriate channel if not strict_channel_priority or ( - self._get_strict_channel(dep_pkg.name) - == dep_pkg.channel.name + self._get_strict_channel(dep_pkg.name) == dep_pkg.channel.name ): reduced_index2[dep_pkg] = dep_pkg @@ -958,18 +899,14 @@ def _make_channel_priorities(channels): priorities_map = {} for priority_counter, chn in enumerate( itertools.chain.from_iterable( - (Channel(cc) for cc in c._channels) - if isinstance(c, MultiChannel) - else (c,) + (Channel(cc) for cc in c._channels) if isinstance(c, MultiChannel) else (c,) for c in (Channel(c) for c in channels) ) ): channel_name = chn.name if channel_name in priorities_map: continue - priorities_map[channel_name] = min( - priority_counter, MAX_CHANNEL_PRIORITY - 1 - ) + priorities_map[channel_name] = min(priority_counter, MAX_CHANNEL_PRIORITY - 1) return priorities_map def get_pkgs(self, ms, emptyok=False): # pragma: no cover @@ -1005,9 +942,7 @@ def push_MatchSpec(self, C, spec): simple = spec._is_single() nm = spec.get_exact_value("name") tf = frozenset( - _tf - for _tf in (f.strip() for f in spec.get_exact_value("track_features") or ()) - if _tf + _tf for _tf in (f.strip() for f in spec.get_exact_value("track_features") or ()) if _tf ) if nm: @@ -1060,9 +995,7 @@ def gen_clauses(self): C.Require(C.Or, nkey, self.push_MatchSpec(C, ms)) if log.isEnabledFor(DEBUG): - log.debug( - "gen_clauses returning with clause count: %d", C.get_clause_count() - ) + log.debug("gen_clauses returning with clause count: %d", C.get_clause_count()) return C def generate_spec_constraints(self, C, specs): @@ -1087,9 +1020,7 @@ def generate_feature_count(self, C): return result def generate_update_count(self, C, specs): - return { - "!" + ms.target: 1 for ms in specs if ms.target and C.from_name(ms.target) - } + return {"!" + ms.target: 1 for ms in specs if ms.target and C.from_name(ms.target)} def generate_feature_metric(self, C): eq = {} # a C.minimize() objective: dict[varname, coeff] @@ -1102,14 +1033,10 @@ def generate_feature_metric(self, C): prec_feats = {self.to_sat_name(prec): set(prec.features) for prec in group} active_feats = set.union(*prec_feats.values()).intersection(self.trackers) for feat in active_feats: - clause_id_for_feature = self.push_MatchSpec( - C, MatchSpec(track_features=feat) - ) + clause_id_for_feature = self.push_MatchSpec(C, MatchSpec(track_features=feat)) for prec_sat_name, features in prec_feats.items(): if feat not in features: - feature_metric_id = self.to_feature_metric_id( - prec_sat_name, feat - ) + feature_metric_id = self.to_feature_metric_id(prec_sat_name, feat) C.name_var( C.And(prec_sat_name, clause_id_for_feature), feature_metric_id, @@ -1272,9 +1199,7 @@ def mysat(specs, add_if=False): C = r2.gen_clauses() # This first result is just a single unsatisfiable core. There may be several. final_unsat_specs = tuple( - minimal_unsatisfiable_subset( - specs, sat=mysat, explicit_specs=explicit_specs - ) + minimal_unsatisfiable_subset(specs, sat=mysat, explicit_specs=explicit_specs) ) else: final_unsat_specs = None @@ -1327,14 +1252,8 @@ def get_(name, snames): get_(MatchSpec(spec).name, snames) if len(snames) < len(sat_name_map): limit = snames - xtra = [ - rec - for sat_name, rec in sat_name_map.items() - if rec["name"] not in snames - ] - log.debug( - "Limiting solver to the following packages: %s", ", ".join(limit) - ) + xtra = [rec for sat_name, rec in sat_name_map.items() if rec["name"] not in snames] + log.debug("Limiting solver to the following packages: %s", ", ".join(limit)) if xtra: log.debug("Packages to be preserved: %s", xtra) return limit, xtra @@ -1363,9 +1282,7 @@ def install_specs(self, specs, installed, update_deps=True): # TODO: fix target here spec = MatchSpec(name=name, target=prec.dist_str()) else: - spec = MatchSpec( - name=name, version=version, build=build, channel=schannel - ) + spec = MatchSpec(name=name, version=version, build=build, channel=schannel) specs.insert(0, spec) return tuple(specs), preserve @@ -1464,9 +1381,7 @@ def solve( if not_found_packages: raise ResolvePackageNotFound(not_found_packages) elif wrong_version_packages: - raise UnsatisfiableError( - [[d] for d in wrong_version_packages], chains=False - ) + raise UnsatisfiableError([[d] for d in wrong_version_packages], chains=False) if should_retry_solve: # We don't want to call find_conflicts until our last try. # This jumps back out to conda/cli/install.py, where the @@ -1485,9 +1400,7 @@ def mysat(specs, add_if=False): # Return a solution of packages def clean(sol): return [ - q - for q in (C.from_index(s) for s in sol) - if q and q[0] != "!" and "@" not in q + q for q in (C.from_index(s) for s in sol) if q and q[0] != "!" and "@" not in q ] def is_converged(solution): @@ -1544,9 +1457,7 @@ def is_converged(solution): # Requested packages: maximize versions log.debug("Solve: maximize versions of requested packages") - eq_req_c, eq_req_v, eq_req_b, eq_req_a, eq_req_t = r2.generate_version_metrics( - C, specr - ) + eq_req_c, eq_req_v, eq_req_b, eq_req_a, eq_req_t = r2.generate_version_metrics(C, specr) solution, obj3a = C.minimize(eq_req_c, solution) solution, obj3 = C.minimize(eq_req_v, solution) log.debug("Initial package channel/version metric: %d/%d", obj3a, obj3) @@ -1666,6 +1577,4 @@ def is_converged(solution): # for psol in psolutions] # return sorted(Dist(stripfeat(dname)) for dname in psolutions[0]) - return sorted( - (new_index[sat_name] for sat_name in psolutions[0]), key=lambda x: x.name - ) + return sorted((new_index[sat_name] for sat_name in psolutions[0]), key=lambda x: x.name) diff --git a/conda_classic_solver/solve.py b/conda_classic_solver/solve.py index 2ce3737..77a7c57 100644 --- a/conda_classic_solver/solve.py +++ b/conda_classic_solver/solve.py @@ -1,3 +1,6 @@ +# Copyright (C) 2022 Anaconda, Inc +# Copyright (C) 2023 conda +# SPDX-License-Identifier: BSD-3-Clause # Copyright (C) 2012 Anaconda, Inc # Copyright (C) 2023 conda # SPDX-License-Identifier: BSD-3-Clause @@ -123,16 +126,11 @@ def solve_final_state( deps_modifier = context.deps_modifier else: deps_modifier = DepsModifier(str(deps_modifier).lower()) - ignore_pinned = ( - context.ignore_pinned if ignore_pinned is NULL else ignore_pinned - ) + ignore_pinned = context.ignore_pinned if ignore_pinned is NULL else ignore_pinned force_remove = context.force_remove if force_remove is NULL else force_remove log.debug( - "solving prefix %s\n" - " specs_to_remove: %s\n" - " specs_to_add: %s\n" - " prune: %s", + "solving prefix %s\n" " specs_to_remove: %s\n" " specs_to_add: %s\n" " prune: %s", self.prefix, self.specs_to_remove, self.specs_to_add, @@ -214,15 +212,11 @@ def solve_final_state( ssc = self._add_specs(ssc) solution_precs = copy.copy(ssc.solution_precs) - pre_packages = self.get_request_package_in_solution( - ssc.solution_precs, ssc.specs_map - ) + pre_packages = self.get_request_package_in_solution(ssc.solution_precs, ssc.specs_map) ssc = self._find_inconsistent_packages(ssc) # this will prune precs that are deps of precs that get removed due to conflicts ssc = self._run_sat(ssc) - post_packages = self.get_request_package_in_solution( - ssc.solution_precs, ssc.specs_map - ) + post_packages = self.get_request_package_in_solution(ssc.solution_precs, ssc.specs_map) if ssc.update_modifier == UpdateModifier.UPDATE_SPECS: constrained = self.get_constrained_packages( @@ -235,17 +229,14 @@ def solve_final_state( # if there were any conflicts, we need to add their orphaned deps back in if ssc.add_back_map: orphan_precs = ( - set(solution_precs) - - set(ssc.solution_precs) - - set(ssc.add_back_map) + set(solution_precs) - set(ssc.solution_precs) - set(ssc.add_back_map) ) solution_prec_names = [_.name for _ in ssc.solution_precs] ssc.solution_precs.extend( [ _ for _ in orphan_precs - if _.name not in ssc.specs_map - and _.name not in solution_prec_names + if _.name not in ssc.specs_map and _.name not in solution_prec_names ] ) @@ -287,9 +278,7 @@ def determine_constricting_specs(self, spec, solution_precs): else: constricting.append((prec.name, m_dep)) - hard_constricting = [ - i for i in constricting if i[1].version.matcher_vo <= highest_version - ] + hard_constricting = [i for i in constricting if i[1].version.matcher_vo <= highest_version] if len(hard_constricting) == 0: return None @@ -339,9 +328,7 @@ def empty_package_list(pkg): if pkg.name.startswith("__"): # ignore virtual packages continue current_version = max(i[1] for i in pre_packages[pkg.name]) - if current_version == max( - i.version for i in index_keys if i.name == pkg.name - ): + if current_version == max(i.version for i in index_keys if i.name == pkg.name): continue else: if post_packages == pre_packages: @@ -368,9 +355,7 @@ def _collect_all_metadata(self, ssc): "console_shortcut", "powershell_shortcut", ): - if pkg_name not in ssc.specs_map and ssc.prefix_data.get( - pkg_name, None - ): + if pkg_name not in ssc.specs_map and ssc.prefix_data.get(pkg_name, None): ssc.specs_map[pkg_name] = MatchSpec(pkg_name) # Add virtual packages so they are taken into account by the solver @@ -416,9 +401,7 @@ def _remove_specs(self, ssc): # SAT for spec removal determination, we can use the PrefixGraph and simple tree # traversal if we're careful about how we handle features. We still invoke sat via # `r.solve()` later. - _track_fts_specs = ( - spec for spec in self.specs_to_remove if "track_features" in spec - ) + _track_fts_specs = (spec for spec in self.specs_to_remove if "track_features" in spec) feature_names = set( chain.from_iterable( spec.get_raw_value("track_features") for spec in _track_fts_specs @@ -539,9 +522,7 @@ def _package_has_updates(self, ssc, spec, installed_pool): else spec ) - def _should_freeze( - self, ssc, target_prec, conflict_specs, explicit_pool, installed_pool - ): + def _should_freeze(self, ssc, target_prec, conflict_specs, explicit_pool, installed_pool): # never, ever freeze anything if we have no history. if not ssc.specs_from_history_map: return False @@ -577,19 +558,13 @@ def _add_specs(self, ssc): # Ignore installed specs on prune. installed_specs = () else: - installed_specs = [ - record.to_match_spec() for record in ssc.prefix_data.iter_records() - ] + installed_specs = [record.to_match_spec() for record in ssc.prefix_data.iter_records()] - conflict_specs = ( - ssc.r.get_conflicting_specs(installed_specs, self.specs_to_add) or tuple() - ) + conflict_specs = ssc.r.get_conflicting_specs(installed_specs, self.specs_to_add) or tuple() conflict_specs = {spec.name for spec in conflict_specs} for pkg_name, spec in ssc.specs_map.items(): - matches_for_spec = tuple( - prec for prec in ssc.solution_precs if spec.match(prec) - ) + matches_for_spec = tuple(prec for prec in ssc.solution_precs if spec.match(prec)) if matches_for_spec: if len(matches_for_spec) != 1: raise CondaError( @@ -625,18 +600,14 @@ def _add_specs(self, ssc): target=target_prec.dist_str(), ) else: - ssc.specs_map[pkg_name] = MatchSpec( - pkg_name, target=target_prec.dist_str() - ) + ssc.specs_map[pkg_name] = MatchSpec(pkg_name, target=target_prec.dist_str()) pin_overrides = set() for s in ssc.pinned_specs: if s.name in explicit_pool: if s.name not in self.specs_to_add_names and not ssc.ignore_pinned: ssc.specs_map[s.name] = MatchSpec(s, optional=False) - elif explicit_pool[s.name] & ssc.r._get_package_pool([s]).get( - s.name, set() - ): + elif explicit_pool[s.name] & ssc.r._get_package_pool([s]).get(s.name, set()): ssc.specs_map[s.name] = MatchSpec(s, optional=False) pin_overrides.add(s.name) else: @@ -653,9 +624,7 @@ def _add_specs(self, ssc): # optimal output all the time. It would probably also get rid of the need # to retry with an unfrozen (UPDATE_SPECS) solve. if ssc.update_modifier == UpdateModifier.FREEZE_INSTALLED: - precs = [ - _ for _ in ssc.prefix_data.iter_records() if _.name not in ssc.specs_map - ] + precs = [_ for _ in ssc.prefix_data.iter_records() if _.name not in ssc.specs_map] for prec in precs: if prec.name not in conflict_specs: ssc.specs_map[prec.name] = prec.to_match_spec() @@ -744,9 +713,7 @@ def _add_specs(self, ssc): # anything here - let python float when it hasn't been explicitly specified python_spec = ssc.specs_map.get("python", MatchSpec("python")) if not python_spec.get("version"): - pinned_version = ( - get_major_minor_version(python_prefix_rec.version) + ".*" - ) + pinned_version = get_major_minor_version(python_prefix_rec.version) + ".*" python_spec = MatchSpec(python_spec, version=pinned_version) spec_set = (python_spec,) + tuple(self.specs_to_add) @@ -770,9 +737,7 @@ def _add_specs(self, ssc): # add in explicitly requested specs from specs_to_add # this overrides any name-matching spec already in the spec map - ssc.specs_map.update( - (s.name, s) for s in self.specs_to_add if s.name not in pin_overrides - ) + ssc.specs_map.update((s.name, s) for s in self.specs_to_add if s.name not in pin_overrides) # As a business rule, we never want to downgrade conda below the current version, # unless that's requested explicitly by the user (which we actively discourage). @@ -780,13 +745,9 @@ def _add_specs(self, ssc): conda_prefix_rec = ssc.prefix_data.get("conda") if conda_prefix_rec: version_req = f">={conda_prefix_rec.version}" - conda_requested_explicitly = any( - s.name == "conda" for s in self.specs_to_add - ) + conda_requested_explicitly = any(s.name == "conda" for s in self.specs_to_add) conda_spec = ssc.specs_map["conda"] - conda_in_specs_to_add_version = ssc.specs_map.get("conda", {}).get( - "version" - ) + conda_in_specs_to_add_version = ssc.specs_map.get("conda", {}).get("version") if not conda_in_specs_to_add_version: conda_spec = MatchSpec(conda_spec, version=version_req) if context.auto_update_conda and not conda_requested_explicitly: @@ -822,10 +783,7 @@ def _run_sat(self, ssc): # several times, each time making modifications to loosen constraints. conflicting_specs = set( - ssc.r.get_conflicting_specs( - tuple(final_environment_specs), self.specs_to_add - ) - or [] + ssc.r.get_conflicting_specs(tuple(final_environment_specs), self.specs_to_add) or [] ) while conflicting_specs: specs_modified = False @@ -851,9 +809,9 @@ def _run_sat(self, ssc): if conflicting_pinned_specs.get(True): in_specs_map = grouped_specs.get(True, ()) pinned_conflicts = conflicting_pinned_specs.get(True, ()) - in_specs_map_or_specs_to_add = ( - set(in_specs_map) | set(self.specs_to_add) - ) - set(pinned_conflicts) + in_specs_map_or_specs_to_add = (set(in_specs_map) | set(self.specs_to_add)) - set( + pinned_conflicts + ) raise SpecsConfigurationConflictError( sorted(s.__str__() for s in in_specs_map_or_specs_to_add), @@ -872,9 +830,7 @@ def _run_sat(self, ssc): ssc.specs_map[spec.name] = neutered_spec if specs_modified: conflicting_specs = set( - ssc.r.get_conflicting_specs( - tuple(final_environment_specs), self.specs_to_add - ) + ssc.r.get_conflicting_specs(tuple(final_environment_specs), self.specs_to_add) ) else: # Let r.solve() use r.find_conflicts() to report conflict chains. @@ -916,9 +872,7 @@ def _run_sat(self, ssc): if not spec: # filter out solution precs and reinsert the conflict. Any resolution # of the conflict should be explicit (i.e. it must be in ssc.specs_map) - ssc.solution_precs = [ - _ for _ in ssc.solution_precs if _.name != name - ] + ssc.solution_precs = [_ for _ in ssc.solution_precs if _.name != name] ssc.solution_precs.append(prec) final_environment_specs.add(spec) @@ -952,9 +906,7 @@ def _post_sat_handling(self, ssc): } remove_before_adding_back = {prec.name for prec in only_add_these} _no_deps_solution = IndexedSet( - prec - for prec in _no_deps_solution - if prec.name not in remove_before_adding_back + prec for prec in _no_deps_solution if prec.name not in remove_before_adding_back ) _no_deps_solution |= only_add_these ssc.solution_precs = _no_deps_solution @@ -993,9 +945,7 @@ def _post_sat_handling(self, ssc): for node in removed_nodes if node.name not in specs_to_remove_names ) - ssc.solution_precs = tuple( - PrefixGraph((*graph.graph, *filter(None, add_back))).graph - ) + ssc.solution_precs = tuple(PrefixGraph((*graph.graph, *filter(None, add_back))).graph) # TODO: check if solution is satisfiable, and emit warning if it's not @@ -1013,9 +963,7 @@ def _post_sat_handling(self, ssc): update_names = set() for spec in self.specs_to_add: node = graph.get_node_by_name(spec.name) - update_names.update( - ancest_rec.name for ancest_rec in graph.all_ancestors(node) - ) + update_names.update(ancest_rec.name for ancest_rec in graph.all_ancestors(node)) specs_map = {name: MatchSpec(name) for name in update_names} # Remove pinned_specs and any python spec (due to major-minor pinning business rule). diff --git a/tests/test_logic.py b/tests/test_logic.py index 85d8f08..ff26b6e 100644 --- a/tests/test_logic.py +++ b/tests/test_logic.py @@ -1,3 +1,6 @@ +# Copyright (C) 2022 Anaconda, Inc +# Copyright (C) 2023 conda +# SPDX-License-Identifier: BSD-3-Clause # Copyright (C) 2012 Anaconda, Inc # SPDX-License-Identifier: BSD-3-Clause from itertools import chain, combinations, permutations, product @@ -119,8 +122,7 @@ def my_TEST(Mfunc, Cfunc, mmin, mmax, is_iter): Cpos.new_var(nm) Cneg.new_var(nm) ij2 = tuple( - C.from_index(k) if isinstance(k, int) and k not in {TRUE, FALSE} else k - for k in ij + C.from_index(k) if isinstance(k, int) and k not in {TRUE, FALSE} else k for k in ij ) if is_iter: x = Cfunc.__get__(C, Clauses)(ij2) @@ -322,14 +324,10 @@ def test_LinearBound(): Cpos.Require(Cpos.LinearBound, eq, rhs[0], rhs[1]) Cneg.Prevent(Cneg.LinearBound, eq, rhs[0], rhs[1]) if x != FALSE: - for _, sol in zip( - range(max_iter), C.itersolve([] if x == TRUE else [(x,)], N) - ): + for _, sol in zip(range(max_iter), C.itersolve([] if x == TRUE else [(x,)], N)): assert rhs[0] <= my_EVAL(eq2, sol) <= rhs[1], C.as_list() if x != TRUE: - for _, sol in zip( - range(max_iter), C.itersolve([] if x == TRUE else [(C.Not(x),)], N) - ): + for _, sol in zip(range(max_iter), C.itersolve([] if x == TRUE else [(C.Not(x),)], N)): assert not (rhs[0] <= my_EVAL(eq2, sol) <= rhs[1]), C.as_list() for _, sol in zip(range(max_iter), Cpos.itersolve([], N)): assert rhs[0] <= my_EVAL(eq2, sol) <= rhs[1], ("Cpos", Cpos.as_list()) diff --git a/tests/test_solvers.py b/tests/test_solvers.py index bbe6e58..edede5c 100644 --- a/tests/test_solvers.py +++ b/tests/test_solvers.py @@ -1,3 +1,6 @@ +# Copyright (C) 2022 Anaconda, Inc +# Copyright (C) 2023 conda +# SPDX-License-Identifier: BSD-3-Clause # Copyright (C) 2012 Anaconda, Inc # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations