001/*
002 * Copyright (c) 2007-2022 The Cascading Authors. All Rights Reserved.
003 *
004 * Project and contact information: https://cascading.wensel.net/
005 *
006 * This file is part of the Cascading project.
007 *
008 * Licensed under the Apache License, Version 2.0 (the "License");
009 * you may not use this file except in compliance with the License.
010 * You may obtain a copy of the License at
011 *
012 *     http://www.apache.org/licenses/LICENSE-2.0
013 *
014 * Unless required by applicable law or agreed to in writing, software
015 * distributed under the License is distributed on an "AS IS" BASIS,
016 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
017 * See the License for the specific language governing permissions and
018 * limitations under the License.
019 */
020
021package cascading.nested.core.aggregate;
022
023import java.beans.ConstructorProperties;
024import java.util.function.Consumer;
025
026import cascading.tuple.Fields;
027import cascading.tuple.Tuple;
028import cascading.tuple.type.CoercibleType;
029
030/**
031 * Class AverageDoubleNestedAggregate is a @{link cascading.nested.core.NestedAggregate} implementation for averaging
032 * the elements collected from the parent container object.
033 * <p>
034 * Optionally null values can be ignored or counted against the average.
035 *
036 * @param <Node>
037 */
038public class AverageDoubleNestedAggregate<Node> extends BaseNumberNestedAggregate<Node, Double, BaseNumberNestedAggregate.BaseContext<Double, Node>>
039  {
040  public enum Include
041    {
042      ALL,
043      NO_NULLS
044    }
045
046  public static class Context<Node> extends BaseContext<Double, Node>
047    {
048    final Consumer<Double> aggregate;
049
050    int count = 0;
051    double sum = 0D;
052
053    public Context( BaseNumberNestedAggregate<Node, Double, BaseContext<Double, Node>> aggregateFunction, CoercibleType<Node> coercibleType, Include include )
054      {
055      super( aggregateFunction, coercibleType );
056
057      switch( include )
058        {
059        case ALL:
060          aggregate = this::aggregateAll;
061          break;
062        case NO_NULLS:
063          aggregate = this::aggregateNoNulls;
064          break;
065        default:
066          throw new IllegalArgumentException( "unknown include type, got: " + include );
067        }
068      }
069
070    @Override
071    protected void aggregateFilteredValue( Double value )
072      {
073      aggregate.accept( value );
074      }
075
076    protected void aggregateNoNulls( Double value )
077      {
078      if( value == null )
079        return;
080
081      count++;
082      sum += value;
083      }
084
085    protected void aggregateAll( Double value )
086      {
087      count++;
088
089      if( value == null )
090        return;
091
092      sum += value;
093      }
094
095    @Override
096    protected void completeAggregateValue( Tuple results )
097      {
098      results.set( 0, sum / count );
099      }
100
101    @Override
102    public void reset()
103      {
104      count = 0;
105      sum = 0D;
106      super.reset();
107      }
108    }
109
110  final protected Include include;
111
112  /**
113   * @param declaredFields
114   */
115  @ConstructorProperties({"declaredFields"})
116  public AverageDoubleNestedAggregate( Fields declaredFields )
117    {
118    this( declaredFields, Include.ALL );
119    }
120
121  @ConstructorProperties({"declaredFields", "include"})
122  public AverageDoubleNestedAggregate( Fields declaredFields, Include include )
123    {
124    super( declaredFields, Double.TYPE );
125    this.include = include;
126    }
127
128  @Override
129  protected boolean discardNullValues()
130    {
131    return false;
132    }
133
134  @Override
135  public Context<Node> createContext( CoercibleType<Node> nestedCoercibleType )
136    {
137    return new Context<>( this, nestedCoercibleType, include );
138    }
139  }