系列导航

需求

经常写CRUD程序的小伙伴们可能都经历过定义很多Repository接口,分别做对应的实现,依赖注入并使用的场景。有的时候会发现,很多分散的XXXXRepository的逻辑都是基本一致的,于是开始思考是否可以将这些操作抽象出去,当然是可以的,而且被抽象出去的部分是可以不加改变地在今后的任何有此需求的项目中直接引入使用。

那么我们本文的需求就是:如何实现一个可重用的Repository模块。

长文预警,包含大量代码。

目标

实现通用Repository模式并进行验证。

原理和思路

通用的基础在于抽象,抽象的粒度决定了通用的程度,但是同时也决定了使用上的复杂度。对于自己的项目而言,抽象到什么程度最合适,需要自己去权衡,也许后面某个时候我会决定自己去实现一个完善的Repository库提供出来(事实上已经有很多人这样做了,我们甚至可以直接下载Nuget包进行使用,但是自己亲手去实现的过程能让你更好地去理解其中的原理,也理解如何开发一个通用的类库。)

总体思路是:在Application中定义相关的接口,在Infrastructure中实现基类的功能。

实现

通用Repository实现

对于要如何去设计一个通用的Repository库,实际上涉及的面非常多,尤其是在获取数据的时候。而且根据每个人的习惯,实现起来的方式是有比较大的差别的,尤其是关于泛型接口到底需要提供哪些方法,每个人都有自己的理解,这里我只演示基本的思路,而且尽量保持简单,关于更复杂和更全面的实现,GIthub上有很多已经写好的库可以去学习和参考,我会列在下面:

很显然,第一步要去做的是在Application/Common/Interfaces中增加一个IRepository<T>的定义用于适用不同类型的实体,然后在Infrastructure/Persistence/Repositories中创建一个基类RepositoryBase<T>实现这个接口,并有办法能提供一致的对外方法签名。

  • IRepository.cs
namespace TodoList.Application.Common.Interfaces;

public interface IRepository<T> where T : class
{
}
  • RepositoryBase.cs
using Microsoft.EntityFrameworkCore;
using TodoList.Application.Common.Interfaces;

namespace TodoList.Infrastructure.Persistence.Repositories;

public class RepositoryBase<T> : IRepository<T> where T : class
{
    private readonly TodoListDbContext _dbContext;

    public RepositoryBase(TodoListDbContext dbContext) => _dbContext = dbContext;
}

在动手实际定义IRepository<T>之前,先思考一下:对数据库的操作都会出现哪些情况:

新增实体(Create)

新增实体在Repository层面的逻辑很简单,传入一个实体对象,然后保存到数据库就可以了,没有其他特殊的需求。

  • IRepository.cs
// 省略其他...
// Create相关操作接口
Task<T> AddAsync(T entity, CancellationToken cancellationToken = default);
  • RepositoryBase.cs
// 省略其他...
public async Task<T> AddAsync(T entity, CancellationToken cancellationToken = default)
{
    await _dbContext.Set<T>().AddAsync(entity, cancellationToken);
    await _dbContext.SaveChangesAsync(cancellationToken);

    return entity;
}

更新实体(Update)

和新增实体类似,但是更新时一般是单个实体对象去操作。

  • IRepository.cs
// 省略其他...
// Update相关操作接口
Task UpdateAsync(T entity, CancellationToken cancellationToken = default);
  • RepositoryBase.cs
// 省略其他...
public async Task UpdateAsync(T entity, CancellationToken cancellationToken = default)
{
    // 对于一般的更新而言,都是Attach到实体上的,只需要设置该实体的State为Modified就可以了
    _dbContext.Entry(entity).State = EntityState.Modified;
    await _dbContext.SaveChangesAsync(cancellationToken);
}

删除实体(Delete)

对于删除实体,可能会出现两种情况:删除一个实体;或者删除一组实体。

  • IRepository.cs
// 省略其他...
// Delete相关操作接口,这里根据key删除对象的接口需要用到一个获取对象的方法
ValueTask<T?> GetAsync(object key);
Task DeleteAsync(object key);
Task DeleteAsync(T entity, CancellationToken cancellationToken = default);
Task DeleteRangeAsync(IEnumerable<T> entities, CancellationToken cancellationToken = default);
  • RepositoryBase.cs
// 省略其他...
public virtual ValueTask<T?> GetAsync(object key) => _dbContext.Set<T>().FindAsync(key);

public async Task DeleteAsync(object key)
{
    var entity = await GetAsync(key);
    if (entity is not null)
    {
        await DeleteAsync(entity);
    }
}

public async Task DeleteAsync(T entity, CancellationToken cancellationToken = default)
{
    _dbContext.Set<T>().Remove(entity);
    await _dbContext.SaveChangesAsync(cancellationToken);
}

public async Task DeleteRangeAsync(IEnumerable<T> entities, CancellationToken cancellationToken = default)
{
    _dbContext.Set<T>().RemoveRange(entities);
    await _dbContext.SaveChangesAsync(cancellationToken);
}

获取实体(Retrieve)

对于如何获取实体,是最复杂的一部分。我们不仅要考虑通过什么方式获取哪些数据,还需要考虑获取的数据有没有特殊的要求比如排序、分页、数据对象类型的转换之类的问题。

具体来说,比如下面这一个典型的LINQ查询语句:

var results = await _context.A.Join(_context.B, a => a.Id, b => b.aId, (a, b) => new
    {
        // ...
    })
    .Where(ab => ab.Name == "name" && ab.Date == DateTime.Now)
    .Select(ab => new
    {
        // ...
    })
    .OrderBy(o => o.Date)
    .Skip(20 * 1)
    .Take(20)
    .ToListAsync();

可以将整个查询结构分割成以下几个组成部分,而且每个部分基本都是以lambda表达式的方式表示的,这转化成建模的话,可以使用Expression相关的对象来表示:

  1. 查询数据集准备过程,在这个过程中可能会出现Include/Join/GroupJoin/GroupBy等等类似的关键字,它们的作用是构建一个用于接下来将要进行查询的数据集。
  2. Where子句,用于过滤查询集合。
  3. Select子句,用于转换原始数据类型到我们想要的结果类型。
  4. Order子句,用于对结果集进行排序,这里可能会包含类似:OrderBy/OrderByDescending/ThenBy/ThenByDescending等关键字。
  5. Paging子句,用于对结果集进行后端分页返回,一般都是Skip/Take一起使用。
  6. 其他子句,多数是条件控制,比如AsNoTracking/SplitQuery等等。

为了保持我们的演示不会过于复杂,我会做一些取舍。在这里的实现我参考了Edi.WangMoonglade中的相关实现。有兴趣的小伙伴也可以去找一下一个更完整的实现:Ardalis.Specification

首先来定义一个简单的ISpecification来表示查询的各类条件:

using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query;

namespace TodoList.Application.Common.Interfaces;

public interface ISpecification<T>
{
    // 查询条件子句
    Expression<Func<T, bool>> Criteria { get; }
    // Include子句
    Func<IQueryable<T>, IIncludableQueryable<T, object>> Include { get; }
    // OrderBy子句
    Expression<Func<T, object>> OrderBy { get; }
    // OrderByDescending子句
    Expression<Func<T, object>> OrderByDescending { get; }

    // 分页相关属性
    int Take { get; }
    int Skip { get; }
    bool IsPagingEnabled { get; }
}

并实现这个泛型接口,放在Application/Common中:

using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query;
using TodoList.Application.Common.Interfaces;

namespace TodoList.Application.Common;

public abstract class SpecificationBase<T> : ISpecification<T>
{
    protected SpecificationBase() { }
    protected SpecificationBase(Expression<Func<T, bool>> criteria) => Criteria = criteria;

    public Expression<Func<T, bool>> Criteria { get; private set; }
    public Func<IQueryable<T>, IIncludableQueryable<T, object>> Include { get; private set; }
    public List<string> IncludeStrings { get; } = new();
    public Expression<Func<T, object>> OrderBy { get; private set; }
    public Expression<Func<T, object>> OrderByDescending { get; private set; }

    public int Take { get; private set; }
    public int Skip { get; private set; }
    public bool IsPagingEnabled { get; private set; }

    public void AddCriteria(Expression<Func<T, bool>> criteria) => Criteria = Criteria is not null ? Criteria.AndAlso(criteria) : criteria;

    protected virtual void AddInclude(Func<IQueryable<T>, IIncludableQueryable<T, object>> includeExpression) => Include = includeExpression;
    protected virtual void AddInclude(string includeString) => IncludeStrings.Add(includeString);

    protected virtual void ApplyPaging(int skip, int take)
    {
        Skip = skip;
        Take = take;
        IsPagingEnabled = true;
    }

    protected virtual void ApplyOrderBy(Expression<Func<T, object>> orderByExpression) => OrderBy = orderByExpression;
    protected virtual void ApplyOrderByDescending(Expression<Func<T, object>> orderByDescendingExpression) => OrderByDescending = orderByDescendingExpression;
}

// https://stackoverflow.com/questions/457316/combining-two-expressions-expressionfunct-bool
public static class ExpressionExtensions
{
    public static Expression<Func<T, bool>> AndAlso<T>(this Expression<Func<T, bool>> expr1, Expression<Func<T, bool>> expr2)
    {
        var parameter = Expression.Parameter(typeof(T));

        var leftVisitor = new ReplaceExpressionVisitor(expr1.Parameters[0], parameter);
        var left = leftVisitor.Visit(expr1.Body);

        var rightVisitor = new ReplaceExpressionVisitor(expr2.Parameters[0], parameter);
        var right = rightVisitor.Visit(expr2.Body);

        return Expression.Lambda<Func<T, bool>>(
            Expression.AndAlso(left ?? throw new InvalidOperationException(),
                right ?? throw new InvalidOperationException()), parameter);
    }

    private class ReplaceExpressionVisitor : ExpressionVisitor
    {
        private readonly Expression _oldValue;
        private readonly Expression _newValue;

        public ReplaceExpressionVisitor(Expression oldValue, Expression newValue)
        {
            _oldValue = oldValue;
            _newValue = newValue;
        }

        public override Expression Visit(Expression node) => node == _oldValue ? _newValue : base.Visit(node);
    }
}

为了在RepositoryBase中能够把所有的Spcification串起来形成查询子句,我们还需要定义一个用于组织Specification的SpecificationEvaluator类:

using TodoList.Application.Common.Interfaces;

namespace TodoList.Application.Common;

public class SpecificationEvaluator<T> where T : class
{
    public static IQueryable<T> GetQuery(IQueryable<T> inputQuery, ISpecification<T>? specification)
    {
        var query = inputQuery;

        if (specification?.Criteria is not null)
        {
            query = query.Where(specification.Criteria);
        }

        if (specification?.Include is not null)
        {
            query = specification.Include(query);
        }

        if (specification?.OrderBy is not null)
        {
            query = query.OrderBy(specification.OrderBy);
        }
        else if (specification?.OrderByDescending is not null)
        {
            query = query.OrderByDescending(specification.OrderByDescending);
        }

        if (specification?.IsPagingEnabled != false)
        {
            query = query.Skip(specification!.Skip).Take(specification.Take);
        }

        return query;
    }
}

IRepository中添加查询相关的接口,大致可以分为以下这几类接口,每类中又可能存在同步接口和异步接口:

  • IRepository.cs
// 省略其他...
// 1. 查询基础操作接口
IQueryable<T> GetAsQueryable();
IQueryable<T> GetAsQueryable(ISpecification<T> spec);

// 2. 查询数量相关接口
int Count(ISpecification<T>? spec = null);
int Count(Expression<Func<T, bool>> condition);
Task<int> CountAsync(ISpecification<T>? spec);

// 3. 查询存在性相关接口
bool Any(ISpecification<T>? spec);
bool Any(Expression<Func<T, bool>>? condition = null);

// 4. 根据条件获取原始实体类型数据相关接口
Task<T?> GetAsync(Expression<Func<T, bool>> condition);
Task<IReadOnlyList<T>> GetAsync();
Task<IReadOnlyList<T>> GetAsync(ISpecification<T>? spec);

// 5. 根据条件获取映射实体类型数据相关接口,涉及到Group相关操作也在其中,使用selector来传入映射的表达式
TResult? SelectFirstOrDefault<TResult>(ISpecification<T>? spec, Expression<Func<T, TResult>> selector);
Task<TResult?> SelectFirstOrDefaultAsync<TResult>(ISpecification<T>? spec, Expression<Func<T, TResult>> selector);

Task<IReadOnlyList<TResult>> SelectAsync<TResult>(Expression<Func<T, TResult>> selector);
Task<IReadOnlyList<TResult>> SelectAsync<TResult>(ISpecification<T>? spec, Expression<Func<T, TResult>> selector);
Task<IReadOnlyList<TResult>> SelectAsync<TGroup, TResult>(Expression<Func<T, TGroup>> groupExpression, Expression<Func<IGrouping<TGroup, T>, TResult>> selector, ISpecification<T>? spec = null);

有了这些基础,我们就可以去Infrastructure/Persistence/Repositories中实现RepositoryBase类剩下的关于查询部分的代码了:

  • RepositoryBase.cs
// 省略其他...
// 1. 查询基础操作接口实现
public IQueryable<T> GetAsQueryable()
    => _dbContext.Set<T>();

public IQueryable<T> GetAsQueryable(ISpecification<T> spec)
    => ApplySpecification(spec);

// 2. 查询数量相关接口实现
public int Count(Expression<Func<T, bool>> condition)
    => _dbContext.Set<T>().Count(condition);

public int Count(ISpecification<T>? spec = null)
    => null != spec ? ApplySpecification(spec).Count() : _dbContext.Set<T>().Count();

public Task<int> CountAsync(ISpecification<T>? spec)
    => ApplySpecification(spec).CountAsync();

// 3. 查询存在性相关接口实现
public bool Any(ISpecification<T>? spec)
    => ApplySpecification(spec).Any();

public bool Any(Expression<Func<T, bool>>? condition = null)
    => null != condition ? _dbContext.Set<T>().Any(condition) : _dbContext.Set<T>().Any();

// 4. 根据条件获取原始实体类型数据相关接口实现
public async Task<T?> GetAsync(Expression<Func<T, bool>> condition)
    => await _dbContext.Set<T>().FirstOrDefaultAsync(condition);

public async Task<IReadOnlyList<T>> GetAsync()
    => await _dbContext.Set<T>().AsNoTracking().ToListAsync();

public async Task<IReadOnlyList<T>> GetAsync(ISpecification<T>? spec)
    => await ApplySpecification(spec).AsNoTracking().ToListAsync();

// 5. 根据条件获取映射实体类型数据相关接口实现
public TResult? SelectFirstOrDefault<TResult>(ISpecification<T>? spec, Expression<Func<T, TResult>> selector)
    => ApplySpecification(spec).AsNoTracking().Select(selector).FirstOrDefault();

public Task<TResult?> SelectFirstOrDefaultAsync<TResult>(ISpecification<T>? spec, Expression<Func<T, TResult>> selector)
    => ApplySpecification(spec).AsNoTracking().Select(selector).FirstOrDefaultAsync();

public async Task<IReadOnlyList<TResult>> SelectAsync<TResult>(Expression<Func<T, TResult>> selector)
    => await _dbContext.Set<T>().AsNoTracking().Select(selector).ToListAsync();

public async Task<IReadOnlyList<TResult>> SelectAsync<TResult>(ISpecification<T>? spec, Expression<Func<T, TResult>> selector)
    => await ApplySpecification(spec).AsNoTracking().Select(selector).ToListAsync();

public async Task<IReadOnlyList<TResult>> SelectAsync<TGroup, TResult>(
    Expression<Func<T, TGroup>> groupExpression,
    Expression<Func<IGrouping<TGroup, T>, TResult>> selector,
    ISpecification<T>? spec = null)
    => null != spec ?
        await ApplySpecification(spec).AsNoTracking().GroupBy(groupExpression).Select(selector).ToListAsync() :
        await _dbContext.Set<T>().AsNoTracking().GroupBy(groupExpression).Select(selector).ToListAsync();

// 用于拼接所有Specification的辅助方法,接收一个`IQuerybale<T>对象(通常是数据集合)
// 和一个当前实体定义的Specification对象,并返回一个`IQueryable<T>`对象为子句执行后的结果。
private IQueryable<T> ApplySpecification(ISpecification<T>? spec)
    => SpecificationEvaluator<T>.GetQuery(_dbContext.Set<T>().AsQueryable(), spec);

引入使用

为了验证通用Repsitory的用法,我们可以先在Infrastructure/DependencyInjection.cs中进行依赖注入:

// in AddInfrastructure, 省略其他
services.AddScoped(typeof(IRepository<>), typeof(RepositoryBase<>));

验证

用于初步验证(主要是查询接口),我们在Application项目里新建文件夹TodoItems/Specs,创建一个TodoItemSpec类:

  • TodoItemSpec.cs
using TodoList.Application.Common;
using TodoList.Domain.Entities;
using TodoList.Domain.Enums;

namespace TodoList.Application.TodoItems.Specs;

public sealed class TodoItemSpec : SpecificationBase<TodoItem>
{
    public TodoItemSpec(bool done, PriorityLevel priority) : base(t => t.Done == done && t.Priority == priority)
    {
    }
}

然后我们临时使用示例接口WeatherForecastController,通过日志来看一下查询的正确性。

private readonly IRepository<TodoItem> _repository;
private readonly ILogger<WeatherForecastController> _logger;

// 为了验证,临时在这注入IRepository<TodoItem>对象,验证完后撤销修改
public WeatherForecastController(IRepository<TodoItem> repository, ILogger<WeatherForecastController> logger)
{
    _repository = repository;
    _logger = logger;
}

Get方法里增加这段逻辑用于观察日志输出:

// 记录日志
_logger.LogInformation($"maybe this log is provided by Serilog...");

var spec = new TodoItemSpec(true, PriorityLevel.High);
var items = _repository.GetAsync(spec).Result;

foreach (var item in items)
{
    _logger.LogInformation($"item: {item.Id} - {item.Title} - {item.Priority}");
}

启动Api项目然后请求示例接口,观察控制台输出:

# 以上省略,Controller日志开始...
[16:49:59 INF] maybe this log is provided by Serilog...
[16:49:59 INF] Entity Framework Core 6.0.1 initialized 'TodoListDbContext' using provider 'Microsoft.EntityFrameworkCore.SqlServer:6.0.1' with options: MigrationsAssembly=TodoList.Infrastructure, Version=1.0.0.0, Culture=neutral, PublicKeyToken=null 
[16:49:59 INF] Executed DbCommand (51ms) [Parameters=[@__done_0='?' (DbType = Boolean), @__priority_1='?' (DbType = Int32)], CommandType='Text', CommandTimeout='30']
SELECT [t].[Id], [t].[Created], [t].[CreatedBy], [t].[Done], [t].[LastModified], [t].[LastModifiedBy], [t].[ListId], [t].[Priority], [t].[Title]
FROM [TodoItems] AS [t]
WHERE ([t].[Done] = @__done_0) AND ([t].[Priority] = @__priority_1)
# 下面这句是我们之前初始化数据库的种子数据,可以参考上一篇文章结尾的验证截图。
[16:49:59 INF] item: 87f1ddf1-e6cd-4113-74ed-08d9c5112f6b - Apples - High
[16:49:59 INF] Executing ObjectResult, writing value of type 'TodoList.Api.WeatherForecast[]'.
[16:49:59 INF] Executed action TodoList.Api.Controllers.WeatherForecastController.Get (TodoList.Api) in 160.5517ms

总结

在本文中,我大致演示了实现一个通用Repository基础框架的过程。实际上关于Repository的组织与实现有很多种实现方法,每个人的关注点和思路都会有不同,但是大的方向基本都是这样,无非是抽象的粒度和提供的接口的方便程度不同。有兴趣的像伙伴可以仔细研究一下参考资料里的第2个实现,也可以从Nuget直接下载在项目中引用使用。

感谢大家耐心看完。从下一篇文章开始,我们就进入喜闻乐见的CRUD环节。

参考资料

  1. Moonglade from Edi Wang
  2. Ardalis.Specification