Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using Semmle.Util;
using Semmle.Util.Logging;

namespace Semmle.Extraction.CSharp.DependencyFetching
{
internal sealed partial class FeedManager : IDisposable
{
private readonly ILogger logger;
private readonly IDotNet dotnet;
private readonly DependencyDirectory emptyPackageDirectory;

public ImmutableHashSet<string> PrivateRegistryFeeds { get; }
public bool HasPrivateRegistryFeeds { get; }
public bool CheckNugetFeedResponsiveness { get; } = EnvironmentVariables.GetBooleanOptOut(EnvironmentVariableNames.CheckNugetFeedResponsiveness);

public FeedManager(ILogger logger, IDotNet dotnet, DependabotProxy? dependabotProxy)
{
this.logger = logger;
this.dotnet = dotnet;
PrivateRegistryFeeds = dependabotProxy?.RegistryURLs.ToImmutableHashSet() ?? [];
HasPrivateRegistryFeeds = PrivateRegistryFeeds.Count > 0;
emptyPackageDirectory = new DependencyDirectory("empty", "empty package", logger);
}

private IEnumerable<string> GetFeeds(Func<IList<string>> getNugetFeeds)
{
var results = getNugetFeeds();
var regex = EnabledNugetFeed();
foreach (var result in results)
{
var match = regex.Match(result);
if (!match.Success)
{
logger.LogError($"Failed to parse feed from '{result}'");
continue;
}

var url = match.Groups[1].Value;
if (!url.StartsWith("https://", StringComparison.InvariantCultureIgnoreCase) &&
!url.StartsWith("http://", StringComparison.InvariantCultureIgnoreCase))
{
logger.LogInfo($"Skipping feed '{url}' as it is not a valid URL.");
continue;
}

if (!string.IsNullOrWhiteSpace(url))
{
yield return url;
}
}
}

public IEnumerable<string> GetFeedsFromFolder(string folderPath) =>
GetFeeds(() => dotnet.GetNugetFeedsFromFolder(folderPath));


public IEnumerable<string> GetFeedsFromNugetConfig(string nugetConfigPath) =>
GetFeeds(() => dotnet.GetNugetFeeds(nugetConfigPath));

private string FeedsToRestoreArgument(IEnumerable<string> feeds)
{
// If there are no feeds, we want to override any default feeds that `dotnet restore` would use by passing a dummy source argument.
if (!feeds.Any())
{
return $" -s \"{emptyPackageDirectory.DirInfo.FullName}\"";
}

// Add package sources. If any are present, they override all sources specified in
// the configuration file(s).
var feedArgs = new StringBuilder();
foreach (var feed in feeds)
{
feedArgs.Append($" -s \"{feed}\"");
}

return feedArgs.ToString();
}

/// <summary>
/// Constructs the list of NuGet sources to use for this restore.
/// (1) Use the feeds we get from `dotnet nuget list source`
/// (2) Use private registries, if they are configured
/// </summary>
/// <param name="path">Path to project/solution</param>
/// <param name="reachableFeeds">The set of reachable NuGet feeds.</param>
/// <returns>A string representing the NuGet sources argument for the restore command.</returns>
public string? MakeRestoreSourcesArgument(string path, HashSet<string> reachableFeeds)
{
// Do not construct an set of explicit NuGet sources to use for restore.
if (!CheckNugetFeedResponsiveness && !HasPrivateRegistryFeeds)
{
return null;
}

// Find the path specific feeds.
var folder = FileUtils.GetDirectoryName(path, logger);
var feedsToConsider = folder is not null ? GetFeedsFromFolder(folder).ToHashSet() : new HashSet<string>();

if (HasPrivateRegistryFeeds)
{
feedsToConsider.UnionWith(PrivateRegistryFeeds);
}

var feedsToUse = CheckNugetFeedResponsiveness
? feedsToConsider.Where(reachableFeeds.Contains)
: feedsToConsider;

return FeedsToRestoreArgument(feedsToUse);
}

[GeneratedRegex(@"^E\s(.*)$", RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.Singleline)]
private static partial Regex EnabledNugetFeed();

public void Dispose()
{
emptyPackageDirectory.Dispose();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,13 @@ internal sealed partial class NugetPackageRestorer : IDisposable
private readonly IDiagnosticsWriter diagnosticsWriter;
private readonly DependencyDirectory legacyPackageDirectory;
private readonly DependencyDirectory missingPackageDirectory;
private readonly DependencyDirectory emptyPackageDirectory;
private readonly ILogger logger;
private readonly ICompilationInfoContainer compilationInfoContainer;
private readonly bool checkNugetFeedResponsiveness = EnvironmentVariables.GetBooleanOptOut(EnvironmentVariableNames.CheckNugetFeedResponsiveness);
private readonly ImmutableHashSet<string> privateRegistryFeeds;
private readonly bool hasPrivateRegistryFeeds;
private readonly FeedManager feedManager;

public DependencyDirectory PackageDirectory { get; }


public NugetPackageRestorer(
FileProvider fileProvider,
FileContent fileContent,
Expand All @@ -50,16 +48,14 @@ public NugetPackageRestorer(
this.fileContent = fileContent;
this.dotnet = dotnet;
this.dependabotProxy = dependabotProxy;
this.privateRegistryFeeds = dependabotProxy?.RegistryURLs.ToImmutableHashSet() ?? [];
this.hasPrivateRegistryFeeds = privateRegistryFeeds.Count > 0;
this.diagnosticsWriter = diagnosticsWriter;
this.logger = logger;
this.compilationInfoContainer = compilationInfoContainer;

PackageDirectory = new DependencyDirectory("packages", "package", logger);
legacyPackageDirectory = new DependencyDirectory("legacypackages", "legacy package", logger);
missingPackageDirectory = new DependencyDirectory("missingpackages", "missing package", logger);
emptyPackageDirectory = new DependencyDirectory("empty", "empty package", logger);
feedManager = new FeedManager(logger, dotnet, dependabotProxy);
}

public string? TryRestore(string package)
Expand Down Expand Up @@ -118,8 +114,8 @@ public DirectoryInfo[] GetOrderedPackageVersionSubDirectories(string packagePath
public HashSet<AssemblyLookupLocation> Restore()
{
var assemblyLookupLocations = new HashSet<AssemblyLookupLocation>();
logger.LogInfo($"Checking NuGet feed responsiveness: {checkNugetFeedResponsiveness}");
compilationInfoContainer.CompilationInfos.Add(("NuGet feed responsiveness checked", checkNugetFeedResponsiveness ? "1" : "0"));
logger.LogInfo($"Checking NuGet feed responsiveness: {feedManager.CheckNugetFeedResponsiveness}");
compilationInfoContainer.CompilationInfos.Add(("NuGet feed responsiveness checked", feedManager.CheckNugetFeedResponsiveness ? "1" : "0"));

HashSet<string> explicitFeeds = [];
HashSet<string> reachableFeeds = [];
Expand All @@ -131,7 +127,7 @@ public HashSet<AssemblyLookupLocation> Restore()
// (including inherited ones) from other locations on the host outside of the working directory.
(explicitFeeds, var allFeeds) = GetAllFeeds();

if (checkNugetFeedResponsiveness)
if (feedManager.CheckNugetFeedResponsiveness)
{
var inheritedFeeds = allFeeds.Except(explicitFeeds).ToHashSet();

Expand Down Expand Up @@ -215,7 +211,7 @@ public HashSet<AssemblyLookupLocation> Restore()

var usedPackageNames = GetAllUsedPackageDirNames(dependencies);

var missingPackageLocation = checkNugetFeedResponsiveness
var missingPackageLocation = feedManager.CheckNugetFeedResponsiveness
? DownloadMissingPackagesFromSpecificFeeds(usedPackageNames, explicitFeeds)
: DownloadMissingPackages(usedPackageNames);

Expand Down Expand Up @@ -264,7 +260,7 @@ private List<string> GetReachableNuGetFeeds(HashSet<string> feedsToCheck, bool i

private bool IsDefaultFeedReachable()
{
if (checkNugetFeedResponsiveness)
if (feedManager.CheckNugetFeedResponsiveness)
{
var (initialTimeout, tryCount) = GetFeedRequestSettings(isFallback: false);
return IsFeedReachable(PublicNugetOrgFeed, initialTimeout, tryCount, out var _);
Expand Down Expand Up @@ -321,7 +317,7 @@ private IEnumerable<string> RestoreSolutions(HashSet<string> reachableFeeds, out
var projects = fileProvider.Solutions.SelectMany(solution =>
{
logger.LogInfo($"Restoring solution {solution}...");
var nugetSources = MakeRestoreSourcesArgument(solution, reachableFeeds);
var nugetSources = feedManager.MakeRestoreSourcesArgument(solution, reachableFeeds);
var res = dotnet.Restore(new(solution, PackageDirectory.DirInfo.FullName, ForceDotnetRefAssemblyFetching: true, NugetSources: nugetSources, TargetWindows: isWindows));
if (res.Success)
{
Expand All @@ -346,57 +342,6 @@ private IEnumerable<string> RestoreSolutions(HashSet<string> reachableFeeds, out
return projects;
}

private string FeedsToRestoreArgument(IEnumerable<string> feeds)
{
// If there are no feeds, we want to override any default feeds that `dotnet restore` would use by passing a dummy source argument.
if (!feeds.Any())
{
return $" -s \"{emptyPackageDirectory.DirInfo.FullName}\"";
}

// Add package sources. If any are present, they override all sources specified in
// the configuration file(s).
var feedArgs = new StringBuilder();
foreach (var feed in feeds)
{
feedArgs.Append($" -s \"{feed}\"");
}

return feedArgs.ToString();
}

/// <summary>
/// Constructs the list of NuGet sources to use for this restore.
/// (1) Use the feeds we get from `dotnet nuget list source`
/// (2) Use private registries, if they are configured
/// </summary>
/// <param name="path">Path to project/solution</param>
/// <param name="reachableFeeds">The set of reachable NuGet feeds.</param>
/// <returns>A string representing the NuGet sources argument for the restore command.</returns>
private string? MakeRestoreSourcesArgument(string path, HashSet<string> reachableFeeds)
{
// Do not construct an set of explicit NuGet sources to use for restore.
if (!checkNugetFeedResponsiveness && !hasPrivateRegistryFeeds)
{
return null;
}

// Find the path specific feeds.
var folder = GetDirectoryName(path);
var feedsToConsider = folder is not null ? GetFeeds(() => dotnet.GetNugetFeedsFromFolder(folder)).ToHashSet() : [];

if (hasPrivateRegistryFeeds)
{
feedsToConsider.UnionWith(privateRegistryFeeds);
}

var feedsToUse = checkNugetFeedResponsiveness
? feedsToConsider.Where(reachableFeeds.Contains)
: feedsToConsider;

return FeedsToRestoreArgument(feedsToUse);
}

/// <summary>
/// Executes `dotnet restore` on all projects in projects.
/// This is done in parallel for performance reasons.
Expand All @@ -421,7 +366,7 @@ private void RestoreProjects(IEnumerable<string> projects, HashSet<string> reach
foreach (var project in projectGroup)
{
logger.LogInfo($"Restoring project {project}...");
var nugetSources = MakeRestoreSourcesArgument(project, reachableFeeds);
var nugetSources = feedManager.MakeRestoreSourcesArgument(project, reachableFeeds);
var res = dotnet.Restore(new(project, PackageDirectory.DirInfo.FullName, ForceDotnetRefAssemblyFetching: true, NugetSources: nugetSources, TargetWindows: isWindows));
assets.AddDependenciesRange(res.AssetsFilePaths);
lock (sync)
Expand Down Expand Up @@ -899,47 +844,6 @@ private void EmitUnreachableFeedsDiagnostics(bool allFeedsReachable)
compilationInfoContainer.CompilationInfos.Add(("All NuGet feeds reachable", allFeedsReachable ? "1" : "0"));
}

private IEnumerable<string> GetFeeds(Func<IList<string>> getNugetFeeds)
{
var results = getNugetFeeds();
var regex = EnabledNugetFeed();
foreach (var result in results)
{
var match = regex.Match(result);
if (!match.Success)
{
logger.LogError($"Failed to parse feed from '{result}'");
continue;
}

var url = match.Groups[1].Value;
if (!url.StartsWith("https://", StringComparison.InvariantCultureIgnoreCase) &&
!url.StartsWith("http://", StringComparison.InvariantCultureIgnoreCase))
{
logger.LogInfo($"Skipping feed '{url}' as it is not a valid URL.");
continue;
}

if (!string.IsNullOrWhiteSpace(url))
{
yield return url;
}
}
}

private string? GetDirectoryName(string path)
{
try
{
return new FileInfo(path).Directory?.FullName;
}
catch (Exception exc)
{
logger.LogWarning($"Failed to get directory of '{path}': {exc}");
}
return null;
}

private (HashSet<string> explicitFeeds, HashSet<string> allFeeds) GetAllFeeds()
{
var nugetConfigs = fileProvider.NugetConfigs;
Expand Down Expand Up @@ -981,7 +885,7 @@ private IEnumerable<string> GetFeeds(Func<IList<string>> getNugetFeeds)

// Find feeds that are explicitly configured in the NuGet configuration files that we found.
var explicitFeeds = nugetConfigs
.SelectMany(config => GetFeeds(() => dotnet.GetNugetFeeds(config)))
.SelectMany(config => feedManager.GetFeedsFromNugetConfig(config))
.ToHashSet();

if (explicitFeeds.Count > 0)
Expand All @@ -995,10 +899,10 @@ private IEnumerable<string> GetFeeds(Func<IList<string>> getNugetFeeds)

// If private package registries are configured for C#, then consider those
// in addition to the ones that are configured in `nuget.config` files.
if (hasPrivateRegistryFeeds)
if (feedManager.HasPrivateRegistryFeeds)
{
logger.LogInfo($"Found {privateRegistryFeeds.Count} private registry feeds configured for C#: {string.Join(", ", privateRegistryFeeds.OrderBy(f => f))}");
explicitFeeds.UnionWith(privateRegistryFeeds);
logger.LogInfo($"Found {feedManager.PrivateRegistryFeeds.Count} private registry feeds configured for C#: {string.Join(", ", feedManager.PrivateRegistryFeeds.OrderBy(f => f))}");
explicitFeeds.UnionWith(feedManager.PrivateRegistryFeeds);
}

HashSet<string> allFeeds = [];
Expand All @@ -1008,15 +912,15 @@ private IEnumerable<string> GetFeeds(Func<IList<string>> getNugetFeeds)

// Obtain the list of feeds from the root source directory.
// If a NuGet file is present it will be respected, otherwise we will just get the machine/environment specific feeds.
var nugetFeedsFromRoot = GetFeeds(() => dotnet.GetNugetFeedsFromFolder(fileProvider.SourceDir.FullName));
var nugetFeedsFromRoot = feedManager.GetFeedsFromFolder(fileProvider.SourceDir.FullName);
allFeeds.UnionWith(nugetFeedsFromRoot);

if (nugetConfigs.Count > 0)
{
var nugetConfigFeeds = nugetConfigs
.Select(GetDirectoryName)
.Select(path => FileUtils.GetDirectoryName(path, logger))
.Where(folder => folder != null)
.SelectMany(folder => GetFeeds(() => dotnet.GetNugetFeedsFromFolder(folder!)))
.SelectMany(folder => feedManager.GetFeedsFromFolder(folder!))
.ToHashSet();

allFeeds.UnionWith(nugetConfigFeeds);
Expand All @@ -1036,15 +940,12 @@ private IEnumerable<string> GetFeeds(Func<IList<string>> getNugetFeeds)
[GeneratedRegex(@"^(.+)\.(\d+\.\d+\.\d+(-(.+))?)$", RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.Singleline)]
private static partial Regex LegacyNugetPackage();

[GeneratedRegex(@"^E\s(.*)$", RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.Singleline)]
private static partial Regex EnabledNugetFeed();

public void Dispose()
{
PackageDirectory?.Dispose();
legacyPackageDirectory?.Dispose();
missingPackageDirectory?.Dispose();
emptyPackageDirectory?.Dispose();
feedManager.Dispose();
}

/// <summary>
Expand Down
Loading
Loading