How to Add Custom Middleware to a gRPC Server

1. Objective:

a. Write a custom middleware that catches global "500" errors (panics) in a gRPC server. In gRPC the counterpart of an HTTP 500 is the Internal status code.

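b. gRPC middleware is different from Gin's HTTP middleware. In Gin you register middleware with Use, passing a gin.HandlerFunc; a gRPC server has no Use method. Instead, gRPC middleware takes the form of interceptors, which are passed as server options to grpc.NewServer. Here we use the go-grpc-middleware package to chain several interceptors together.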

c. The gRPC client only initiates calls and does not run service handlers, so it does not need this panic-catching middleware.

2. Writing middleware

Install the go-grpc-middleware dependency:

go get github.com/grpc-ecosystem/go-grpc-middleware

Here is the complete middleware code. The return types are fixed: each constructor must return a grpc.StreamServerInterceptor or a grpc.UnaryServerInterceptor, because those are the function types that go-grpc-middleware chains together.

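package middlewares

import (
	"context"
	"log"
	"runtime/debug"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// StreamGSError500 catches panics (fatal errors) in streaming RPC handlers.
// address is only used to tag the log output.
func StreamGSError500(address string) grpc.StreamServerInterceptor {
	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
		defer func() {
			// Use r, not err: "err := recover()" would shadow the named result,
			// and the panic would be swallowed with a nil error.
			if r := recover(); r != nil {
				// Print the panic value and the stack trace.
				log.Printf("[%s] panic in %s: %v\n%s", address, info.FullMethod, r, debug.Stack())
				// Return the gRPC equivalent of an HTTP 500 to the client.
				err = status.Errorf(codes.Internal, "internal server error")
			}
		}()

		return handler(srv, stream)
	}
}

// UnaryGSError500 catches panics (fatal errors) in unary (simple) RPC handlers.
// address is only used to tag the log output.
func UnaryGSError500(address string) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
		defer func() {
			if r := recover(); r != nil {
				// Print the panic value and the stack trace.
				log.Printf("[%s] panic in %s: %v\n%s", address, info.FullMethod, r, debug.Stack())
				// Overwrite the named results so the client gets an Internal error.
				resp = nil
				err = status.Errorf(codes.Internal, "internal server error")
			}
		}()

		return handler(ctx, req)
	}
}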
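A common pitfall with this pattern: writing if err := recover(); err != nil inside the deferred function declares a new err that shadows the named result, so the panic is printed but the interceptor still returns a nil error. The standalone sketch below (standard library only; riskyHandler, shadowed and assigned are hypothetical names) contrasts that with assigning the recovered value to the named result.

```go
package main

import "fmt"

// riskyHandler stands in for a gRPC handler that panics (hypothetical).
func riskyHandler() error {
	panic("boom")
}

// shadowed stores recover() in a NEW variable, so the named result err is
// never set and the caller sees no error at all.
func shadowed() (err error) {
	defer func() {
		if err := recover(); err != nil {
			// The panic is only visible here (printed, in the original code).
		}
	}()
	return riskyHandler()
}

// assigned stores the panic in r and assigns the named result, so the caller
// receives an error (a real interceptor would use codes.Internal here).
func assigned() (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("internal error: %v", r)
		}
	}()
	return riskyHandler()
}

func main() {
	fmt.Println(shadowed()) // <nil>
	fmt.Println(assigned()) // internal error: boom
}
```

The same rule applies to the unary interceptor's response value: the deferred function can only change what the caller receives through named results.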

3. Adding the middleware when starting the gRPC server

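Imports:

import (
	"log"
	"net"

	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
	grpcRecovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
	"google.golang.org/grpc"

	"example.com/yourapp/middlewares" // hypothetical path: use your own module's import path
)

Startup:

	address := "127.0.0.1:9600"

	// Instantiate the gRPC server and register the interceptors.
	// Interceptors in a chain run left to right: the first one listed is the outermost.
	grpcServer := grpc.NewServer(
		grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( // stream interceptors
			middlewares.StreamGSError500(address),
			grpcRecovery.StreamServerInterceptor(),
		)),
		grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( // unary (simple) interceptors
			middlewares.UnaryGSError500(address),
			grpcRecovery.UnaryServerInterceptor(), // the unary chain needs the unary variant
		)),
	)

	// Register your services on grpcServer here, then start serving.
	lis, err := net.Listen("tcp", address)
	if err != nil {
		log.Fatalf("failed to listen on %s: %v", address, err)
	}
	if err := grpcServer.Serve(lis); err != nil {
		log.Fatalf("gRPC server stopped: %v", err)
	}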
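The order inside a chain matters. go-grpc-middleware documents that ChainUnaryServer(one, two, three) executes one before two before three, so the first interceptor listed is the outermost wrapper. With two recovery layers, the innermost one (grpcRecovery in this chain) recovers a handler panic first and, by default, turns it into a codes.Internal error, so UnaryGSError500 never sees the panic itself. The stdlib sketch below models that nesting; Handler, Interceptor, chain and recoverAs are hypothetical simplifications, not the library's implementation.

```go
package main

import "fmt"

// Handler and Interceptor are simplified stand-ins for grpc.UnaryHandler and
// grpc.UnaryServerInterceptor (hypothetical; context and request omitted).
type Handler func() (string, error)
type Interceptor func(next Handler) (string, error)

// chain nests interceptors so that ics[0] is the outermost wrapper,
// matching the left-to-right order documented for ChainUnaryServer.
func chain(ics []Interceptor, final Handler) Handler {
	h := final
	for i := len(ics) - 1; i >= 0; i-- {
		ic, next := ics[i], h
		h = func() (string, error) { return ic(next) }
	}
	return h
}

// recoverAs returns an interceptor that turns a panic into an error tagged
// with the interceptor's name, so we can see which layer caught it.
func recoverAs(name string) Interceptor {
	return func(next Handler) (resp string, err error) {
		defer func() {
			if r := recover(); r != nil {
				err = fmt.Errorf("%s recovered: %v", name, r)
			}
		}()
		return next()
	}
}

func main() {
	panicky := func() (string, error) { panic("boom") }
	h := chain([]Interceptor{recoverAs("UnaryGSError500"), recoverAs("grpcRecovery")}, panicky)
	_, err := h()
	fmt.Println(err) // grpcRecovery recovered: boom
}
```

If you want the custom interceptor to be the one that logs the panic, you could keep only one recovery layer, or list grpcRecovery first so that it becomes the outer fallback.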
