/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.saml.attribute.resolver.impl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import javax.annotation.Nonnull;
import javax.annotation.concurrent.ThreadSafe;

import org.opensaml.core.xml.XMLObject;
import org.opensaml.messaging.context.navigate.ChildContextLookup;
import org.opensaml.profile.context.ProfileRequestContext;
import org.opensaml.profile.context.navigate.InboundMessageContextLookup;
import org.opensaml.saml.common.messaging.context.SAMLMetadataContext;
import org.opensaml.saml.common.messaging.context.SAMLPeerEntityContext;
import org.opensaml.saml.ext.saml2mdattr.EntityAttributes;
import org.opensaml.saml.saml2.metadata.EntitiesDescriptor;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.slf4j.Logger;

import com.google.common.collect.Multimap;

import net.shibboleth.idp.attribute.AttributesMapContainer;
import net.shibboleth.idp.attribute.IdPAttribute;
import net.shibboleth.idp.attribute.IdPAttributeValue;
import net.shibboleth.idp.attribute.resolver.AbstractDataConnector;
import net.shibboleth.idp.attribute.resolver.DataConnector;
import net.shibboleth.idp.attribute.resolver.ResolutionException;
import net.shibboleth.idp.attribute.resolver.context.AttributeResolutionContext;
import net.shibboleth.idp.attribute.resolver.context.AttributeResolverWorkContext;
import net.shibboleth.shared.annotation.constraint.NotLive;
import net.shibboleth.shared.annotation.constraint.Unmodifiable;
import net.shibboleth.shared.collection.CollectionSupport;
import net.shibboleth.shared.logic.Constraint;
import net.shibboleth.shared.primitive.LoggerFactory;

/**
 * A {@link DataConnector} that returns the decoded {@link EntityAttributes}
 * from a peer's metadata.
 * 
 * <p>Decoded attributes are read from the object metadata ({@link AttributesMapContainer}) attached to
 * the peer's {@link EntityDescriptor} and to each enclosing {@link EntitiesDescriptor}. This class only
 * reads those containers; it does not populate them. Entity-level attributes are collected first, then
 * each enclosing group walking outward. When the same attribute ID occurs more than once, the value
 * lists are concatenated in that order. Duplicate values are not removed.</p>
 * 
 * @since 5.0.0
 */
@ThreadSafe
public class EntityAttributesDataConnector extends AbstractDataConnector {

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(EntityAttributesDataConnector.class);

    /** Metadata context lookup strategy. */
    @Nonnull private Function<ProfileRequestContext,SAMLMetadataContext> metadataContextLookupStrategy;

    /**
     * Constructor.
     * 
     * <p>The default metadata lookup strategy navigates from the inbound message context to its
     * {@link SAMLPeerEntityContext} child, then to that context's {@link SAMLMetadataContext} child.</p>
     */
    public EntityAttributesDataConnector() {
        metadataContextLookupStrategy = new ChildContextLookup<>(SAMLMetadataContext.class).compose(
                new ChildContextLookup<>(SAMLPeerEntityContext.class).compose(
                        new InboundMessageContextLookup()));
    }
    
    /**
     * Set the lookup strategy for the {@link SAMLMetadataContext} to pull tags from.
     * 
     * @param strategy lookup strategy; must not be null
     */
    public void setMetadataContextLookupStrategy(
            @Nonnull final Function<ProfileRequestContext,SAMLMetadataContext> strategy) {
        metadataContextLookupStrategy = Constraint.isNotNull(strategy, "SAMLMetadataContext strategy cannot be null");
    }
    
    /**
     * {@inheritDoc}
     * 
     * <p>Returns an empty map when no {@link EntityDescriptor} can be located or when no tags are attached
     * to it or its enclosing groups. Otherwise returns an unmodifiable snapshot mapping attribute ID to
     * the merged attribute.</p>
     */
    @Override
    @Nonnull @Unmodifiable @NotLive protected Map<String, IdPAttribute> doDataConnectorResolve(
            @Nonnull final AttributeResolutionContext resolutionContext,
            @Nonnull final AttributeResolverWorkContext workContext) throws ResolutionException {
        
        final SAMLMetadataContext metadataContext = metadataContextLookupStrategy.apply(
                resolutionContext.getProfileRequestContextLookupStrategy().apply(resolutionContext));
        final EntityDescriptor entity = metadataContext != null ? metadataContext.getEntityDescriptor() : null;
        if (entity == null) {
            log.debug("{} Specified metadata source was absent.", getLogPrefix());
            return CollectionSupport.emptyMap();
        }

        final Map<String,IdPAttribute> results = new HashMap<>();

        try {
            // Most specific first: the entity itself, then each enclosing group walking outward,
            // so merged value lists are ordered entity-before-group.
            resolveMappedTags(entity, results);
            XMLObject parent = entity.getParent();
            while (parent instanceof EntitiesDescriptor entities) {
                resolveMappedTags(entities, results);
                parent = entities.getParent();
            }
        } catch (final CloneNotSupportedException e) {
            throw new ResolutionException(e);
        }

        if (results.isEmpty()) {
            log.trace("{} No entity attributes resolved", getLogPrefix());
            return CollectionSupport.emptyMap();
        }
        
        log.trace("{} Resolved attributes: {}", getLogPrefix(), results);
        // Honour the @Unmodifiable/@NotLive contract instead of handing out the internal accumulator.
        return Map.copyOf(results);
    }
   
    /**
     * Pull in mapped tags as resolved attributes.
     * 
     * <p>Only the first {@link AttributesMapContainer} in the object's metadata is consulted. An attribute
     * whose ID is not yet in {@code results} is added as a clone. Otherwise its values are appended to the
     * existing entry's values.</p>
     *
     * @param parent parent object containing mapped tags
     * @param results accumulator for results; entries may be added or have their values replaced
     *
     * @throws CloneNotSupportedException if cloning the mapped tag fails
     */ 
    private void resolveMappedTags(@Nonnull final XMLObject parent, @Nonnull final Map<String,IdPAttribute> results)
            throws CloneNotSupportedException {
        final List<AttributesMapContainer> containerList = parent.getObjectMetadata().get(AttributesMapContainer.class);
        if (containerList == null || containerList.isEmpty()) {
            return;
        }

        final Multimap<String,IdPAttribute> tags = containerList.get(0).get();
        for (final IdPAttribute attribute : tags.values()) {
            final IdPAttribute existing = results.get(attribute.getId());
            if (existing == null) {
                // Store a clone, never the cached original: a later occurrence of the same ID
                // mutates this entry's values via setValues() below.
                results.put(attribute.getId(), attribute.clone());
            } else {
                final List<IdPAttributeValue> union = new ArrayList<>(existing.getValues());
                union.addAll(attribute.getValues());
                existing.setValues(union);
            }
        }
    }
    
}
